Fix/memory bug fix (#171)

This commit is contained in:
lixinyue11
2026-01-26 11:53:34 +08:00
committed by GitHub
parent 714c624dc6
commit 3601737869
119 changed files with 1711 additions and 1695 deletions

0
api/app/__init__.py Normal file
View File

View File

@@ -12,6 +12,7 @@ from fastapi import APIRouter, Depends, Query, HTTPException, status
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from typing import Optional from typing import Optional
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from uuid import UUID
from app.core.response_utils import success from app.core.response_utils import success
from app.dependencies import get_current_user from app.dependencies import get_current_user
@@ -32,11 +33,11 @@ router = APIRouter(
class EmotionConfigQuery(BaseModel): class EmotionConfigQuery(BaseModel):
"""情绪配置查询请求模型""" """情绪配置查询请求模型"""
config_id: int = Field(..., description="配置ID") config_id: UUID = Field(..., description="配置ID")
class EmotionConfigUpdate(BaseModel): class EmotionConfigUpdate(BaseModel):
"""情绪配置更新请求模型""" """情绪配置更新请求模型"""
config_id: int = Field(..., description="配置ID") config_id: UUID = Field(..., description="配置ID")
emotion_enabled: bool = Field(..., description="是否启用情绪提取") emotion_enabled: bool = Field(..., description="是否启用情绪提取")
emotion_model_id: Optional[str] = Field(None, description="情绪分析专用模型ID") emotion_model_id: Optional[str] = Field(None, description="情绪分析专用模型ID")
emotion_extract_keywords: bool = Field(..., description="是否提取情绪关键词") emotion_extract_keywords: bool = Field(..., description="是否提取情绪关键词")
@@ -45,7 +46,7 @@ class EmotionConfigUpdate(BaseModel):
@router.get("/read_config", response_model=ApiResponse) @router.get("/read_config", response_model=ApiResponse)
def get_emotion_config( def get_emotion_config(
config_id: int = Query(..., description="配置ID"), config_id: UUID = Query(..., description="配置ID"),
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
): ):

View File

@@ -53,7 +53,7 @@ async def get_emotion_tags(
api_logger.info( api_logger.info(
f"用户 {current_user.username} 请求获取情绪标签统计", f"用户 {current_user.username} 请求获取情绪标签统计",
extra={ extra={
"group_id": request.group_id, "end_user_id": request.end_user_id,
"emotion_type": request.emotion_type, "emotion_type": request.emotion_type,
"start_date": request.start_date, "start_date": request.start_date,
"end_date": request.end_date, "end_date": request.end_date,
@@ -63,7 +63,7 @@ async def get_emotion_tags(
# 调用服务层 # 调用服务层
data = await emotion_service.get_emotion_tags( data = await emotion_service.get_emotion_tags(
end_user_id=request.group_id, end_user_id=request.end_user_id,
emotion_type=request.emotion_type, emotion_type=request.emotion_type,
start_date=request.start_date, start_date=request.start_date,
end_date=request.end_date, end_date=request.end_date,
@@ -73,7 +73,7 @@ async def get_emotion_tags(
api_logger.info( api_logger.info(
"情绪标签统计获取成功", "情绪标签统计获取成功",
extra={ extra={
"group_id": request.group_id, "end_user_id": request.end_user_id,
"total_count": data.get("total_count", 0), "total_count": data.get("total_count", 0),
"tags_count": len(data.get("tags", [])) "tags_count": len(data.get("tags", []))
} }
@@ -84,7 +84,7 @@ async def get_emotion_tags(
except Exception as e: except Exception as e:
api_logger.error( api_logger.error(
f"获取情绪标签统计失败: {str(e)}", f"获取情绪标签统计失败: {str(e)}",
extra={"group_id": request.group_id}, extra={"end_user_id": request.end_user_id},
exc_info=True exc_info=True
) )
raise HTTPException( raise HTTPException(
@@ -105,7 +105,7 @@ async def get_emotion_wordcloud(
api_logger.info( api_logger.info(
f"用户 {current_user.username} 请求获取情绪词云数据", f"用户 {current_user.username} 请求获取情绪词云数据",
extra={ extra={
"group_id": request.group_id, "end_user_id": request.end_user_id,
"emotion_type": request.emotion_type, "emotion_type": request.emotion_type,
"limit": request.limit "limit": request.limit
} }
@@ -113,7 +113,7 @@ async def get_emotion_wordcloud(
# 调用服务层 # 调用服务层
data = await emotion_service.get_emotion_wordcloud( data = await emotion_service.get_emotion_wordcloud(
end_user_id=request.group_id, end_user_id=request.end_user_id,
emotion_type=request.emotion_type, emotion_type=request.emotion_type,
limit=request.limit limit=request.limit
) )
@@ -121,7 +121,7 @@ async def get_emotion_wordcloud(
api_logger.info( api_logger.info(
"情绪词云数据获取成功", "情绪词云数据获取成功",
extra={ extra={
"group_id": request.group_id, "end_user_id": request.end_user_id,
"total_keywords": data.get("total_keywords", 0) "total_keywords": data.get("total_keywords", 0)
} }
) )
@@ -131,7 +131,7 @@ async def get_emotion_wordcloud(
except Exception as e: except Exception as e:
api_logger.error( api_logger.error(
f"获取情绪词云数据失败: {str(e)}", f"获取情绪词云数据失败: {str(e)}",
extra={"group_id": request.group_id}, extra={"end_user_id": request.end_user_id},
exc_info=True exc_info=True
) )
raise HTTPException( raise HTTPException(
@@ -159,21 +159,21 @@ async def get_emotion_health(
api_logger.info( api_logger.info(
f"用户 {current_user.username} 请求获取情绪健康指数", f"用户 {current_user.username} 请求获取情绪健康指数",
extra={ extra={
"group_id": request.group_id, "end_user_id": request.end_user_id,
"time_range": request.time_range "time_range": request.time_range
} }
) )
# 调用服务层 # 调用服务层
data = await emotion_service.calculate_emotion_health_index( data = await emotion_service.calculate_emotion_health_index(
end_user_id=request.group_id, end_user_id=request.end_user_id,
time_range=request.time_range time_range=request.time_range
) )
api_logger.info( api_logger.info(
"情绪健康指数获取成功", "情绪健康指数获取成功",
extra={ extra={
"group_id": request.group_id, "end_user_id": request.end_user_id,
"health_score": data.get("health_score", 0), "health_score": data.get("health_score", 0),
"level": data.get("level", "未知") "level": data.get("level", "未知")
} }
@@ -186,7 +186,7 @@ async def get_emotion_health(
except Exception as e: except Exception as e:
api_logger.error( api_logger.error(
f"获取情绪健康指数失败: {str(e)}", f"获取情绪健康指数失败: {str(e)}",
extra={"group_id": request.group_id}, extra={"end_user_id": request.end_user_id},
exc_info=True exc_info=True
) )
raise HTTPException( raise HTTPException(
@@ -206,7 +206,7 @@ async def get_emotion_suggestions(
"""获取个性化情绪建议(从缓存读取) """获取个性化情绪建议(从缓存读取)
Args: Args:
request: 包含 group_id 和可选的 config_id request: 包含 end_user_id 和可选的 config_id
db: 数据库会话 db: 数据库会话
current_user: 当前用户 current_user: 当前用户
@@ -217,22 +217,22 @@ async def get_emotion_suggestions(
api_logger.info( api_logger.info(
f"用户 {current_user.username} 请求获取个性化情绪建议(缓存)", f"用户 {current_user.username} 请求获取个性化情绪建议(缓存)",
extra={ extra={
"group_id": request.group_id, "end_user_id": request.end_user_id,
"config_id": request.config_id "config_id": request.config_id
} }
) )
# 从缓存获取建议 # 从缓存获取建议
data = await emotion_service.get_cached_suggestions( data = await emotion_service.get_cached_suggestions(
end_user_id=request.group_id, end_user_id=request.end_user_id,
db=db db=db
) )
if data is None: if data is None:
# 缓存不存在或已过期 # 缓存不存在或已过期
api_logger.info( api_logger.info(
f"用户 {request.group_id} 的建议缓存不存在或已过期", f"用户 {request.end_user_id} 的建议缓存不存在或已过期",
extra={"group_id": request.group_id} extra={"end_user_id": request.end_user_id}
) )
return fail( return fail(
BizCode.NOT_FOUND, BizCode.NOT_FOUND,
@@ -243,7 +243,7 @@ async def get_emotion_suggestions(
api_logger.info( api_logger.info(
"个性化建议获取成功(缓存)", "个性化建议获取成功(缓存)",
extra={ extra={
"group_id": request.group_id, "end_user_id": request.end_user_id,
"suggestions_count": len(data.get("suggestions", [])) "suggestions_count": len(data.get("suggestions", []))
} }
) )
@@ -253,7 +253,7 @@ async def get_emotion_suggestions(
except Exception as e: except Exception as e:
api_logger.error( api_logger.error(
f"获取个性化建议失败: {str(e)}", f"获取个性化建议失败: {str(e)}",
extra={"group_id": request.group_id}, extra={"end_user_id": request.end_user_id},
exc_info=True exc_info=True
) )
raise HTTPException( raise HTTPException(

View File

@@ -122,10 +122,10 @@ def validate_confidence_threshold(threshold: float) -> None:
raise ValueError("confidence_threshold must be between 0.0 and 1.0") raise ValueError("confidence_threshold must be between 0.0 and 1.0")
@router.get("/preferences/{user_id}", response_model=ApiResponse) @router.get("/preferences/{end_user_id}", response_model=ApiResponse)
@cur_workspace_access_guard() @cur_workspace_access_guard()
async def get_preference_tags( async def get_preference_tags(
user_id: str, end_user_id: str,
confidence_threshold: float = Query(0.5, ge=0.0, le=1.0, description="Minimum confidence threshold"), confidence_threshold: float = Query(0.5, ge=0.0, le=1.0, description="Minimum confidence threshold"),
tag_category: Optional[str] = Query(None, description="Filter by tag category"), tag_category: Optional[str] = Query(None, description="Filter by tag category"),
start_date: Optional[datetime] = Query(None, description="Filter start date"), start_date: Optional[datetime] = Query(None, description="Filter start date"),
@@ -137,7 +137,7 @@ async def get_preference_tags(
Get user preference tags from cache. Get user preference tags from cache.
Args: Args:
user_id: Target user ID end_user_id: Target end user ID
confidence_threshold: Minimum confidence score (0.0-1.0) confidence_threshold: Minimum confidence score (0.0-1.0)
tag_category: Optional category filter tag_category: Optional category filter
start_date: Optional start date filter start_date: Optional start date filter
@@ -146,20 +146,20 @@ async def get_preference_tags(
Returns: Returns:
List of preference tags from cache List of preference tags from cache
""" """
api_logger.info(f"Preference tags requested for user: {user_id} (from cache)") api_logger.info(f"Preference tags requested for user: {end_user_id} (from cache)")
try: try:
# Validate inputs # Validate inputs
validate_user_id(user_id) validate_user_id(end_user_id)
# Create service with user-specific config # Create service with user-specific config
service = ImplicitMemoryService(db=db, end_user_id=user_id) service = ImplicitMemoryService(db=db, end_user_id=end_user_id)
# Get cached profile # Get cached profile
cached_profile = await service.get_cached_profile(end_user_id=user_id, db=db) cached_profile = await service.get_cached_profile(end_user_id=end_user_id, db=db)
if cached_profile is None: if cached_profile is None:
api_logger.info(f"用户 {user_id} 的画像缓存不存在或已过期") api_logger.info(f"用户 {end_user_id} 的画像缓存不存在或已过期")
return fail( return fail(
BizCode.NOT_FOUND, BizCode.NOT_FOUND,
"画像缓存不存在或已过期,请右上角刷新生成新画像", "画像缓存不存在或已过期,请右上角刷新生成新画像",
@@ -192,17 +192,17 @@ async def get_preference_tags(
filtered_preferences.append(pref) filtered_preferences.append(pref)
api_logger.info(f"Retrieved {len(filtered_preferences)} preference tags for user: {user_id} (from cache)") api_logger.info(f"Retrieved {len(filtered_preferences)} preference tags for user: {end_user_id} (from cache)")
return success(data=filtered_preferences, msg="偏好标签获取成功(缓存)") return success(data=filtered_preferences, msg="偏好标签获取成功(缓存)")
except Exception as e: except Exception as e:
return handle_implicit_memory_error(e, "偏好标签获取", user_id) return handle_implicit_memory_error(e, "偏好标签获取", end_user_id)
@router.get("/portrait/{user_id}", response_model=ApiResponse) @router.get("/portrait/{end_user_id}", response_model=ApiResponse)
@cur_workspace_access_guard() @cur_workspace_access_guard()
async def get_dimension_portrait( async def get_dimension_portrait(
user_id: str, end_user_id: str,
include_history: bool = Query(False, description="Include historical trends"), include_history: bool = Query(False, description="Include historical trends"),
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_user) current_user: User = Depends(get_current_user)
@@ -211,26 +211,26 @@ async def get_dimension_portrait(
Get user's four-dimension personality portrait from cache. Get user's four-dimension personality portrait from cache.
Args: Args:
user_id: Target user ID end_user_id: Target end user ID
include_history: Whether to include historical trend data (ignored for cached data) include_history: Whether to include historical trend data (ignored for cached data)
Returns: Returns:
Four-dimension personality portrait from cache Four-dimension personality portrait from cache
""" """
api_logger.info(f"Dimension portrait requested for user: {user_id} (from cache)") api_logger.info(f"Dimension portrait requested for user: {end_user_id} (from cache)")
try: try:
# Validate inputs # Validate inputs
validate_user_id(user_id) validate_user_id(end_user_id)
# Create service with user-specific config # Create service with user-specific config
service = ImplicitMemoryService(db=db, end_user_id=user_id) service = ImplicitMemoryService(db=db, end_user_id=end_user_id)
# Get cached profile # Get cached profile
cached_profile = await service.get_cached_profile(end_user_id=user_id, db=db) cached_profile = await service.get_cached_profile(end_user_id=end_user_id, db=db)
if cached_profile is None: if cached_profile is None:
api_logger.info(f"用户 {user_id} 的画像缓存不存在或已过期") api_logger.info(f"用户 {end_user_id} 的画像缓存不存在或已过期")
return fail( return fail(
BizCode.NOT_FOUND, BizCode.NOT_FOUND,
"画像缓存不存在或已过期,请右上角刷新生成新画像", "画像缓存不存在或已过期,请右上角刷新生成新画像",
@@ -240,17 +240,17 @@ async def get_dimension_portrait(
# Extract portrait from cache # Extract portrait from cache
portrait = cached_profile.get("portrait", {}) portrait = cached_profile.get("portrait", {})
api_logger.info(f"Dimension portrait retrieved for user: {user_id} (from cache)") api_logger.info(f"Dimension portrait retrieved for user: {end_user_id} (from cache)")
return success(data=portrait, msg="四维画像获取成功(缓存)") return success(data=portrait, msg="四维画像获取成功(缓存)")
except Exception as e: except Exception as e:
return handle_implicit_memory_error(e, "四维画像获取", user_id) return handle_implicit_memory_error(e, "四维画像获取", end_user_id)
@router.get("/interest-areas/{user_id}", response_model=ApiResponse) @router.get("/interest-areas/{end_user_id}", response_model=ApiResponse)
@cur_workspace_access_guard() @cur_workspace_access_guard()
async def get_interest_area_distribution( async def get_interest_area_distribution(
user_id: str, end_user_id: str,
include_trends: bool = Query(False, description="Include trend analysis"), include_trends: bool = Query(False, description="Include trend analysis"),
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_user) current_user: User = Depends(get_current_user)
@@ -259,26 +259,26 @@ async def get_interest_area_distribution(
Get user's interest area distribution from cache. Get user's interest area distribution from cache.
Args: Args:
user_id: Target user ID end_user_id: Target end user ID
include_trends: Whether to include trend analysis data (ignored for cached data) include_trends: Whether to include trend analysis data (ignored for cached data)
Returns: Returns:
Interest area distribution from cache Interest area distribution from cache
""" """
api_logger.info(f"Interest area distribution requested for user: {user_id} (from cache)") api_logger.info(f"Interest area distribution requested for user: {end_user_id} (from cache)")
try: try:
# Validate inputs # Validate inputs
validate_user_id(user_id) validate_user_id(end_user_id)
# Create service with user-specific config # Create service with user-specific config
service = ImplicitMemoryService(db=db, end_user_id=user_id) service = ImplicitMemoryService(db=db, end_user_id=end_user_id)
# Get cached profile # Get cached profile
cached_profile = await service.get_cached_profile(end_user_id=user_id, db=db) cached_profile = await service.get_cached_profile(end_user_id=end_user_id, db=db)
if cached_profile is None: if cached_profile is None:
api_logger.info(f"用户 {user_id} 的画像缓存不存在或已过期") api_logger.info(f"用户 {end_user_id} 的画像缓存不存在或已过期")
return fail( return fail(
BizCode.NOT_FOUND, BizCode.NOT_FOUND,
"画像缓存不存在或已过期,请右上角刷新生成新画像", "画像缓存不存在或已过期,请右上角刷新生成新画像",
@@ -288,17 +288,17 @@ async def get_interest_area_distribution(
# Extract interest areas from cache # Extract interest areas from cache
interest_areas = cached_profile.get("interest_areas", {}) interest_areas = cached_profile.get("interest_areas", {})
api_logger.info(f"Interest area distribution retrieved for user: {user_id} (from cache)") api_logger.info(f"Interest area distribution retrieved for user: {end_user_id} (from cache)")
return success(data=interest_areas, msg="兴趣领域分布获取成功(缓存)") return success(data=interest_areas, msg="兴趣领域分布获取成功(缓存)")
except Exception as e: except Exception as e:
return handle_implicit_memory_error(e, "兴趣领域分布获取", user_id) return handle_implicit_memory_error(e, "兴趣领域分布获取", end_user_id)
@router.get("/habits/{user_id}", response_model=ApiResponse) @router.get("/habits/{end_user_id}", response_model=ApiResponse)
@cur_workspace_access_guard() @cur_workspace_access_guard()
async def get_behavior_habits( async def get_behavior_habits(
user_id: str, end_user_id: str,
confidence_level: Optional[str] = Query(None, regex="^(high|medium|low)$", description="Filter by confidence level"), confidence_level: Optional[str] = Query(None, regex="^(high|medium|low)$", description="Filter by confidence level"),
frequency_pattern: Optional[str] = Query(None, regex="^(daily|weekly|monthly|seasonal|occasional|event_triggered)$", description="Filter by frequency pattern"), frequency_pattern: Optional[str] = Query(None, regex="^(daily|weekly|monthly|seasonal|occasional|event_triggered)$", description="Filter by frequency pattern"),
time_period: Optional[str] = Query(None, regex="^(current|past)$", description="Filter by time period"), time_period: Optional[str] = Query(None, regex="^(current|past)$", description="Filter by time period"),
@@ -309,7 +309,7 @@ async def get_behavior_habits(
Get user's behavioral habits from cache. Get user's behavioral habits from cache.
Args: Args:
user_id: Target user ID end_user_id: Target end user ID
confidence_level: Filter by confidence level (high, medium, low) confidence_level: Filter by confidence level (high, medium, low)
frequency_pattern: Filter by frequency pattern (daily, weekly, monthly, seasonal, occasional, event_triggered) frequency_pattern: Filter by frequency pattern (daily, weekly, monthly, seasonal, occasional, event_triggered)
time_period: Filter by time period (current, past) time_period: Filter by time period (current, past)
@@ -317,20 +317,20 @@ async def get_behavior_habits(
Returns: Returns:
List of behavioral habits from cache List of behavioral habits from cache
""" """
api_logger.info(f"Behavior habits requested for user: {user_id} (from cache)") api_logger.info(f"Behavior habits requested for user: {end_user_id} (from cache)")
try: try:
# Validate inputs # Validate inputs
validate_user_id(user_id) validate_user_id(end_user_id)
# Create service with user-specific config # Create service with user-specific config
service = ImplicitMemoryService(db=db, end_user_id=user_id) service = ImplicitMemoryService(db=db, end_user_id=end_user_id)
# Get cached profile # Get cached profile
cached_profile = await service.get_cached_profile(end_user_id=user_id, db=db) cached_profile = await service.get_cached_profile(end_user_id=end_user_id, db=db)
if cached_profile is None: if cached_profile is None:
api_logger.info(f"用户 {user_id} 的画像缓存不存在或已过期") api_logger.info(f"用户 {end_user_id} 的画像缓存不存在或已过期")
return fail( return fail(
BizCode.NOT_FOUND, BizCode.NOT_FOUND,
"画像缓存不存在或已过期,请右上角刷新生成新画像", "画像缓存不存在或已过期,请右上角刷新生成新画像",
@@ -368,11 +368,11 @@ async def get_behavior_habits(
filtered_habits.append(habit) filtered_habits.append(habit)
api_logger.info(f"Retrieved {len(filtered_habits)} behavior habits for user: {user_id} (from cache)") api_logger.info(f"Retrieved {len(filtered_habits)} behavior habits for user: {end_user_id} (from cache)")
return success(data=filtered_habits, msg="行为习惯获取成功(缓存)") return success(data=filtered_habits, msg="行为习惯获取成功(缓存)")
except Exception as e: except Exception as e:
return handle_implicit_memory_error(e, "行为习惯获取", user_id) return handle_implicit_memory_error(e, "行为习惯获取", end_user_id)

View File

@@ -125,7 +125,7 @@ async def write_server(
Write service endpoint - processes write operations synchronously Write service endpoint - processes write operations synchronously
Args: Args:
user_input: Write request containing message and group_id user_input: Write request containing message and end_user_id
Returns: Returns:
Response with write operation status Response with write operation status
@@ -160,19 +160,18 @@ async def write_server(
api_logger.warning("workspace_id 为空,无法使用 rag 存储,将使用 neo4j 存储") api_logger.warning("workspace_id 为空,无法使用 rag 存储,将使用 neo4j 存储")
storage_type = 'neo4j' storage_type = 'neo4j'
api_logger.info(f"Write service requested for group {user_input.group_id}, storage_type: {storage_type}, user_rag_memory_id: {user_rag_memory_id}") api_logger.info(f"Write service requested for group {user_input.end_user_id}, storage_type: {storage_type}, user_rag_memory_id: {user_rag_memory_id}")
try: try:
# 获取标准化的消息列表
messages_list = memory_agent_service.get_messages_list(user_input) messages_list = memory_agent_service.get_messages_list(user_input)
result = await memory_agent_service.write_memory( result = await memory_agent_service.write_memory(
user_input.group_id, user_input.end_user_id,
messages_list, # 传递结构化消息列表 messages_list,
config_id, config_id,
db, db,
storage_type, storage_type,
user_rag_memory_id user_rag_memory_id
) )
return success(data=result, msg="写入成功") return success(data=result, msg="写入成功")
except BaseException as e: except BaseException as e:
# Handle ExceptionGroup from TaskGroup (Python 3.11+) or BaseExceptionGroup # Handle ExceptionGroup from TaskGroup (Python 3.11+) or BaseExceptionGroup
@@ -196,7 +195,7 @@ async def write_server_async(
Async write service endpoint - enqueues write processing to Celery Async write service endpoint - enqueues write processing to Celery
Args: Args:
user_input: Write request containing message and group_id user_input: Write request containing message and end_user_id
Returns: Returns:
Task ID for tracking async operation Task ID for tracking async operation
@@ -229,7 +228,7 @@ async def write_server_async(
task = celery_app.send_task( task = celery_app.send_task(
"app.core.memory.agent.write_message", "app.core.memory.agent.write_message",
args=[user_input.group_id, messages_list, config_id, storage_type, user_rag_memory_id] args=[user_input.end_user_id, messages_list, config_id, storage_type, user_rag_memory_id]
) )
api_logger.info(f"Write task queued: {task.id}") api_logger.info(f"Write task queued: {task.id}")
@@ -255,7 +254,7 @@ async def read_server(
- "2": Direct answer based on context - "2": Direct answer based on context
Args: Args:
user_input: Read request with message, history, search_switch, and group_id user_input: Read request with message, history, search_switch, and end_user_id
Returns: Returns:
Response with query answer Response with query answer
@@ -277,12 +276,13 @@ async def read_server(
name="USER_RAG_MERORY", name="USER_RAG_MERORY",
workspace_id=workspace_id workspace_id=workspace_id
) )
if knowledge: user_rag_memory_id = str(knowledge.id) if knowledge:
user_rag_memory_id = str(knowledge.id)
api_logger.info(f"Read service: group={user_input.group_id}, storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}, workspace_id={workspace_id}") api_logger.info(f"Read service: group={user_input.end_user_id}, storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}, workspace_id={workspace_id}")
try: try:
result = await memory_agent_service.read_memory( result = await memory_agent_service.read_memory(
user_input.group_id, user_input.end_user_id,
user_input.message, user_input.message,
user_input.history, user_input.history,
user_input.search_switch, user_input.search_switch,
@@ -293,12 +293,12 @@ async def read_server(
) )
if str(user_input.search_switch) == "2": if str(user_input.search_switch) == "2":
retrieve_info = result['answer'] retrieve_info = result['answer']
history = await SessionService(store).get_history(user_input.group_id, user_input.group_id, user_input.group_id) history = await SessionService(store).get_history(user_input.end_user_id, user_input.end_user_id, user_input.end_user_id)
query = user_input.message query = user_input.message
# 调用 memory_agent_service 的方法生成最终答案 # 调用 memory_agent_service 的方法生成最终答案
result['answer'] = await memory_agent_service.generate_summary_from_retrieve( result['answer'] = await memory_agent_service.generate_summary_from_retrieve(
group_id=user_input.group_id, end_user_id=user_input.end_user_id,
retrieve_info=retrieve_info, retrieve_info=retrieve_info,
history=history, history=history,
query=query, query=query,
@@ -404,7 +404,7 @@ async def read_server_async(
try: try:
task = celery_app.send_task( task = celery_app.send_task(
"app.core.memory.agent.read_message", "app.core.memory.agent.read_message",
args=[user_input.group_id, user_input.message, user_input.history, user_input.search_switch, args=[user_input.end_user_id, user_input.message, user_input.history, user_input.search_switch,
config_id, storage_type, user_rag_memory_id] config_id, storage_type, user_rag_memory_id]
) )
api_logger.info(f"Read task queued: {task.id}") api_logger.info(f"Read task queued: {task.id}")
@@ -448,7 +448,7 @@ async def get_read_task_result(
return success( return success(
data={ data={
"result": task_result.get("result"), "result": task_result.get("result"),
"group_id": task_result.get("group_id"), "end_user_id": task_result.get("end_user_id"),
"elapsed_time": task_result.get("elapsed_time"), "elapsed_time": task_result.get("elapsed_time"),
"task_id": task_id "task_id": task_id
}, },
@@ -525,7 +525,7 @@ async def get_write_task_result(
return success( return success(
data={ data={
"result": task_result.get("result"), "result": task_result.get("result"),
"group_id": task_result.get("group_id"), "end_user_id": task_result.get("end_user_id"),
"elapsed_time": task_result.get("elapsed_time"), "elapsed_time": task_result.get("elapsed_time"),
"task_id": task_id "task_id": task_id
}, },
@@ -579,12 +579,12 @@ async def status_type(
Determine the type of user message (read or write) Determine the type of user message (read or write)
Args: Args:
user_input: Request containing user message and group_id user_input: Request containing user message and end_user_id
Returns: Returns:
Type classification result Type classification result
""" """
api_logger.info(f"Status type check requested for group {user_input.group_id}") api_logger.info(f"Status type check requested for group {user_input.end_user_id}")
try: try:
# 获取标准化的消息列表 # 获取标准化的消息列表
messages_list = memory_agent_service.get_messages_list(user_input) messages_list = memory_agent_service.get_messages_list(user_input)
@@ -625,7 +625,7 @@ async def get_knowledge_type_stats_api(
会对缺失类型补 0返回字典形式。 会对缺失类型补 0返回字典形式。
可选按状态过滤。 可选按状态过滤。
- 知识库类型根据当前用户的 current_workspace_id 过滤 - 知识库类型根据当前用户的 current_workspace_id 过滤
- memory 是 Neo4j 中 Chunk 的数量,根据 end_user_id (group_id) 过滤 - memory 是 Neo4j 中 Chunk 的数量,根据 end_user_id (end_user_id) 过滤
- 如果用户没有当前工作空间或未提供 end_user_id对应的统计返回 0 - 如果用户没有当前工作空间或未提供 end_user_id对应的统计返回 0
""" """
api_logger.info(f"Knowledge type stats requested for workspace_id: {current_user.current_workspace_id}, end_user_id: {end_user_id}") api_logger.info(f"Knowledge type stats requested for workspace_id: {current_user.current_workspace_id}, end_user_id: {end_user_id}")
@@ -698,7 +698,7 @@ async def get_user_profile_api(
current_user: User = Depends(get_current_user) current_user: User = Depends(get_current_user)
): ):
""" """
获取工作空间下Popular Memory Tags,包含: 获取用户详情,包含:
- name: 用户名字(直接使用 end_user_id - name: 用户名字(直接使用 end_user_id
- tags: 3个用户特征标签从语句和实体中LLM总结 - tags: 3个用户特征标签从语句和实体中LLM总结
- hot_tags: 4个热门记忆标签 - hot_tags: 4个热门记忆标签

View File

@@ -11,6 +11,7 @@
""" """
from typing import Optional from typing import Optional
from uuid import UUID
from fastapi import APIRouter, Depends from fastapi import APIRouter, Depends
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
@@ -106,7 +107,7 @@ async def trigger_forgetting_cycle(
# 调用服务层执行遗忘周期 # 调用服务层执行遗忘周期
report = await forget_service.trigger_forgetting_cycle( report = await forget_service.trigger_forgetting_cycle(
db=db, db=db,
group_id=end_user_id, # 服务层方法的参数名是 group_id end_user_id=end_user_id, # 服务层方法的参数名是 end_user_id
max_merge_batch_size=payload.max_merge_batch_size, max_merge_batch_size=payload.max_merge_batch_size,
min_days_since_access=payload.min_days_since_access, min_days_since_access=payload.min_days_since_access,
config_id=config_id config_id=config_id
@@ -128,7 +129,7 @@ async def trigger_forgetting_cycle(
@router.get("/read_config", response_model=ApiResponse) @router.get("/read_config", response_model=ApiResponse)
async def read_forgetting_config( async def read_forgetting_config(
config_id: int, config_id: UUID,
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
db: Session = Depends(get_db) db: Session = Depends(get_db)
): ):
@@ -236,7 +237,7 @@ async def update_forgetting_config(
@router.get("/stats", response_model=ApiResponse) @router.get("/stats", response_model=ApiResponse)
async def get_forgetting_stats( async def get_forgetting_stats(
group_id: Optional[str] = None, end_user_id: Optional[str] = None,
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
db: Session = Depends(get_db) db: Session = Depends(get_db)
): ):
@@ -246,7 +247,7 @@ async def get_forgetting_stats(
返回知识层节点统计、激活值分布等信息。 返回知识层节点统计、激活值分布等信息。
Args: Args:
group_id: 组ID即 end_user_id可选 end_user_id: 组ID即 end_user_id可选
current_user: 当前用户 current_user: 当前用户
db: 数据库会话 db: 数据库会话
@@ -260,20 +261,20 @@ async def get_forgetting_stats(
api_logger.warning(f"用户 {current_user.username} 尝试获取遗忘引擎统计但未选择工作空间") api_logger.warning(f"用户 {current_user.username} 尝试获取遗忘引擎统计但未选择工作空间")
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None") return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
# 如果提供了 group_id通过它获取 config_id # 如果提供了 end_user_id通过它获取 config_id
config_id = None config_id = None
if group_id: if end_user_id:
try: try:
from app.services.memory_agent_service import get_end_user_connected_config from app.services.memory_agent_service import get_end_user_connected_config
connected_config = get_end_user_connected_config(group_id, db) connected_config = get_end_user_connected_config(end_user_id, db)
config_id = connected_config.get("memory_config_id") config_id = connected_config.get("memory_config_id")
if config_id is None: if config_id is None:
api_logger.warning(f"终端用户 {group_id} 未关联记忆配置") api_logger.warning(f"终端用户 {end_user_id} 未关联记忆配置")
return fail(BizCode.INVALID_PARAMETER, f"终端用户 {group_id} 未关联记忆配置", "memory_config_id is None") return fail(BizCode.INVALID_PARAMETER, f"终端用户 {end_user_id} 未关联记忆配置", "memory_config_id is None")
api_logger.debug(f"通过 group_id={group_id} 获取到 config_id={config_id}") api_logger.debug(f"通过 end_user_id={end_user_id} 获取到 config_id={config_id}")
except ValueError as e: except ValueError as e:
api_logger.warning(f"获取终端用户配置失败: {str(e)}") api_logger.warning(f"获取终端用户配置失败: {str(e)}")
return fail(BizCode.INVALID_PARAMETER, str(e), "ValueError") return fail(BizCode.INVALID_PARAMETER, str(e), "ValueError")
@@ -283,14 +284,14 @@ async def get_forgetting_stats(
api_logger.info( api_logger.info(
f"用户 {current_user.username} 在工作空间 {workspace_id} 请求获取遗忘引擎统计: " f"用户 {current_user.username} 在工作空间 {workspace_id} 请求获取遗忘引擎统计: "
f"group_id={group_id}, config_id={config_id}" f"end_user_id={end_user_id}, config_id={config_id}"
) )
try: try:
# 调用服务层获取统计信息 # 调用服务层获取统计信息
stats = await forget_service.get_forgetting_stats( stats = await forget_service.get_forgetting_stats(
db=db, db=db,
group_id=group_id, end_user_id=end_user_id,
config_id=config_id config_id=config_id
) )

View File

@@ -27,27 +27,27 @@ router = APIRouter(
) )
@router.get("/{group_id}/count", response_model=ApiResponse) @router.get("/{end_user_id}/count", response_model=ApiResponse)
def get_memory_count( def get_memory_count(
group_id: uuid.UUID, end_user_id: uuid.UUID,
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
db: Session = Depends(get_db) db: Session = Depends(get_db)
): ):
"""Retrieve perceptual memory statistics for a user group. """Retrieve perceptual memory statistics for a user group.
Args: Args:
group_id: ID of the user group (usually end_user_id in this context) end_user_id: ID of the user group (usually end_user_id in this context)
current_user: Current authenticated user current_user: Current authenticated user
db: Database session db: Database session
Returns: Returns:
ApiResponse: Response containing memory count statistics ApiResponse: Response containing memory count statistics
""" """
api_logger.info(f"Fetching perceptual memory statistics: user={current_user.username}, group_id={group_id}") api_logger.info(f"Fetching perceptual memory statistics: user={current_user.username}, end_user_id={end_user_id}")
try: try:
service = MemoryPerceptualService(db) service = MemoryPerceptualService(db)
count_stats = service.get_memory_count(group_id) count_stats = service.get_memory_count(end_user_id)
api_logger.info(f"Memory statistics fetched successfully: total={count_stats.get('total', 0)}") api_logger.info(f"Memory statistics fetched successfully: total={count_stats.get('total', 0)}")
@@ -57,37 +57,37 @@ def get_memory_count(
) )
except Exception as e: except Exception as e:
api_logger.error(f"Failed to fetch memory statistics: group_id={group_id}, error={str(e)}") api_logger.error(f"Failed to fetch memory statistics: end_user_id={end_user_id}, error={str(e)}")
return fail( return fail(
code=BizCode.INTERNAL_ERROR, code=BizCode.INTERNAL_ERROR,
msg="Failed to fetch memory statistics", msg="Failed to fetch memory statistics",
) )
@router.get("/{group_id}/last_visual", response_model=ApiResponse) @router.get("/{end_user_id}/last_visual", response_model=ApiResponse)
def get_last_visual_memory( def get_last_visual_memory(
group_id: uuid.UUID, end_user_id: uuid.UUID,
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
db: Session = Depends(get_db) db: Session = Depends(get_db)
): ):
"""Retrieve the most recent VISION-type memory for a user. """Retrieve the most recent VISION-type memory for a user.
Args: Args:
group_id: ID of the user group end_user_id: ID of the user group
current_user: Current authenticated user current_user: Current authenticated user
db: Database session db: Database session
Returns: Returns:
ApiResponse: Metadata of the latest visual memory ApiResponse: Metadata of the latest visual memory
""" """
api_logger.info(f"Fetching latest visual memory: user={current_user.username}, group_id={group_id}") api_logger.info(f"Fetching latest visual memory: user={current_user.username}, end_user_id={end_user_id}")
try: try:
service = MemoryPerceptualService(db) service = MemoryPerceptualService(db)
visual_memory = service.get_latest_visual_memory(group_id) visual_memory = service.get_latest_visual_memory(end_user_id)
if visual_memory is None: if visual_memory is None:
api_logger.info(f"No visual memory found: group_id={group_id}") api_logger.info(f"No visual memory found: end_user_id={end_user_id}")
return success( return success(
data=None, data=None,
msg="No visual memory available" msg="No visual memory available"
@@ -101,37 +101,37 @@ def get_last_visual_memory(
) )
except Exception as e: except Exception as e:
api_logger.error(f"Failed to fetch latest visual memory: group_id={group_id}, error={str(e)}") api_logger.error(f"Failed to fetch latest visual memory: end_user_id={end_user_id}, error={str(e)}")
return fail( return fail(
code=BizCode.INTERNAL_ERROR, code=BizCode.INTERNAL_ERROR,
msg="Failed to fetch latest visual memory", msg="Failed to fetch latest visual memory",
) )
@router.get("/{group_id}/last_listen", response_model=ApiResponse) @router.get("/{end_user_id}/last_listen", response_model=ApiResponse)
def get_last_memory_listen( def get_last_memory_listen(
group_id: uuid.UUID, end_user_id: uuid.UUID,
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
db: Session = Depends(get_db) db: Session = Depends(get_db)
): ):
"""Retrieve the most recent AUDIO-type memory for a user. """Retrieve the most recent AUDIO-type memory for a user.
Args: Args:
group_id: ID of the user group end_user_id: ID of the user group
current_user: Current authenticated user current_user: Current authenticated user
db: Database session db: Database session
Returns: Returns:
ApiResponse: Metadata of the latest audio memory ApiResponse: Metadata of the latest audio memory
""" """
api_logger.info(f"Fetching latest audio memory: user={current_user.username}, group_id={group_id}") api_logger.info(f"Fetching latest audio memory: user={current_user.username}, end_user_id={end_user_id}")
try: try:
service = MemoryPerceptualService(db) service = MemoryPerceptualService(db)
audio_memory = service.get_latest_audio_memory(group_id) audio_memory = service.get_latest_audio_memory(end_user_id)
if audio_memory is None: if audio_memory is None:
api_logger.info(f"No audio memory found: group_id={group_id}") api_logger.info(f"No audio memory found: end_user_id={end_user_id}")
return success( return success(
data=None, data=None,
msg="No audio memory available" msg="No audio memory available"
@@ -145,38 +145,38 @@ def get_last_memory_listen(
) )
except Exception as e: except Exception as e:
api_logger.error(f"Failed to fetch latest audio memory: group_id={group_id}, error={str(e)}") api_logger.error(f"Failed to fetch latest audio memory: end_user_id={end_user_id}, error={str(e)}")
return fail( return fail(
code=BizCode.INTERNAL_ERROR, code=BizCode.INTERNAL_ERROR,
msg="Failed to fetch latest audio memory", msg="Failed to fetch latest audio memory",
) )
@router.get("/{group_id}/last_text", response_model=ApiResponse) @router.get("/{end_user_id}/last_text", response_model=ApiResponse)
def get_last_text_memory( def get_last_text_memory(
group_id: uuid.UUID, end_user_id: uuid.UUID,
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
db: Session = Depends(get_db) db: Session = Depends(get_db)
): ):
"""Retrieve the most recent TEXT-type memory for a user. """Retrieve the most recent TEXT-type memory for a user.
Args: Args:
group_id: ID of the user group end_user_id: ID of the user group
current_user: Current authenticated user current_user: Current authenticated user
db: Database session db: Database session
Returns: Returns:
ApiResponse: Metadata of the latest text memory ApiResponse: Metadata of the latest text memory
""" """
api_logger.info(f"Fetching latest text memory: user={current_user.username}, group_id={group_id}") api_logger.info(f"Fetching latest text memory: user={current_user.username}, end_user_id={end_user_id}")
try: try:
# 调用服务层获取最近的文本记忆 # 调用服务层获取最近的文本记忆
service = MemoryPerceptualService(db) service = MemoryPerceptualService(db)
text_memory = service.get_latest_text_memory(group_id) text_memory = service.get_latest_text_memory(end_user_id)
if text_memory is None: if text_memory is None:
api_logger.info(f"No text memory found: group_id={group_id}") api_logger.info(f"No text memory found: end_user_id={end_user_id}")
return success( return success(
data=None, data=None,
msg="No text memory available" msg="No text memory available"
@@ -190,16 +190,16 @@ def get_last_text_memory(
) )
except Exception as e: except Exception as e:
api_logger.error(f"Failed to fetch latest text memory: group_id={group_id}, error={str(e)}") api_logger.error(f"Failed to fetch latest text memory: end_user_id={end_user_id}, error={str(e)}")
return fail( return fail(
code=BizCode.INTERNAL_ERROR, code=BizCode.INTERNAL_ERROR,
msg="Failed to fetch latest text memory", msg="Failed to fetch latest text memory",
) )
@router.get("/{group_id}/timeline", response_model=ApiResponse) @router.get("/{end_user_id}/timeline", response_model=ApiResponse)
def get_memory_time_line( def get_memory_time_line(
group_id: uuid.UUID, end_user_id: uuid.UUID,
perceptual_type: Optional[PerceptualType] = Query(None, description="感知类型过滤"), perceptual_type: Optional[PerceptualType] = Query(None, description="感知类型过滤"),
page: int = Query(1, ge=1, description="页码"), page: int = Query(1, ge=1, description="页码"),
page_size: int = Query(10, ge=1, le=100, description="每页大小"), page_size: int = Query(10, ge=1, le=100, description="每页大小"),
@@ -209,7 +209,7 @@ def get_memory_time_line(
"""Retrieve a timeline of perceptual memories for a user group. """Retrieve a timeline of perceptual memories for a user group.
Args: Args:
group_id: ID of the user group end_user_id: ID of the user group
perceptual_type: Optional filter for perceptual type perceptual_type: Optional filter for perceptual type
page: Page number for pagination page: Page number for pagination
page_size: Number of items per page page_size: Number of items per page
@@ -221,7 +221,7 @@ def get_memory_time_line(
""" """
api_logger.info( api_logger.info(
f"Fetching perceptual memory timeline: user={current_user.username}, " f"Fetching perceptual memory timeline: user={current_user.username}, "
f"group_id={group_id}, type={perceptual_type}, page={page}" f"end_user_id={end_user_id}, type={perceptual_type}, page={page}"
) )
try: try:
@@ -232,7 +232,7 @@ def get_memory_time_line(
) )
service = MemoryPerceptualService(db) service = MemoryPerceptualService(db)
timeline_data = service.get_time_line(group_id, query) timeline_data = service.get_time_line(end_user_id, query)
api_logger.info( api_logger.info(
f"Perceptual memory timeline retrieved successfully: total={timeline_data.total}, " f"Perceptual memory timeline retrieved successfully: total={timeline_data.total}, "
@@ -246,7 +246,7 @@ def get_memory_time_line(
except Exception as e: except Exception as e:
api_logger.error( api_logger.error(
f"Failed to fetch perceptual memory timeline: group_id={group_id}, " f"Failed to fetch perceptual memory timeline: end_user_id={end_user_id}, "
f"error={str(e)}" f"error={str(e)}"
) )
return fail( return fail(

View File

@@ -1,6 +1,7 @@
import asyncio import asyncio
import time import time
import uuid import uuid
from uuid import UUID
from app.core.logging_config import get_api_logger from app.core.logging_config import get_api_logger
from app.core.memory.storage_services.reflection_engine.self_reflexion import ( from app.core.memory.storage_services.reflection_engine.self_reflexion import (
@@ -11,7 +12,7 @@ from app.core.response_utils import success
from app.db import get_db from app.db import get_db
from app.dependencies import get_current_user from app.dependencies import get_current_user
from app.models.user_model import User from app.models.user_model import User
from app.repositories.data_config_repository import DataConfigRepository from app.repositories.memory_config_repository import MemoryConfigRepository
from app.repositories.neo4j.neo4j_connector import Neo4jConnector from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.schemas.memory_reflection_schemas import Memory_Reflection from app.schemas.memory_reflection_schemas import Memory_Reflection
from app.services.memory_reflection_service import ( from app.services.memory_reflection_service import (
@@ -50,7 +51,7 @@ async def save_reflection_config(
api_logger.info(f"用户 {current_user.username} 保存反思配置config_id: {config_id}") api_logger.info(f"用户 {current_user.username} 保存反思配置config_id: {config_id}")
data_config = DataConfigRepository.update_reflection_config( memory_config = MemoryConfigRepository.update_reflection_config(
db, db,
config_id=config_id, config_id=config_id,
enable_self_reflexion=request.reflection_enabled, enable_self_reflexion=request.reflection_enabled,
@@ -63,17 +64,17 @@ async def save_reflection_config(
) )
db.commit() db.commit()
db.refresh(data_config) db.refresh(memory_config)
reflection_result={ reflection_result={
"config_id": data_config.config_id, "config_id": memory_config.config_id,
"enable_self_reflexion": data_config.enable_self_reflexion, "enable_self_reflexion": memory_config.enable_self_reflexion,
"iteration_period": data_config.iteration_period, "iteration_period": memory_config.iteration_period,
"reflexion_range": data_config.reflexion_range, "reflexion_range": memory_config.reflexion_range,
"baseline": data_config.baseline, "baseline": memory_config.baseline,
"reflection_model_id": data_config.reflection_model_id, "reflection_model_id": memory_config.reflection_model_id,
"memory_verify": data_config.memory_verify, "memory_verify": memory_config.memory_verify,
"quality_assessment": data_config.quality_assessment} "quality_assessment": memory_config.quality_assessment}
return success(data=reflection_result, msg="反思配置成功") return success(data=reflection_result, msg="反思配置成功")
@@ -111,14 +112,14 @@ async def start_workspace_reflection(
reflection_results = [] reflection_results = []
for data in result['apps_detailed_info']: for data in result['apps_detailed_info']:
if data['data_configs'] == []: if data['memory_configs'] == []:
continue continue
releases = data['releases'] releases = data['releases']
data_configs = data['data_configs'] memory_configs = data['memory_configs']
end_users = data['end_users'] end_users = data['end_users']
for base, config, user in zip(releases, data_configs, end_users): for base, config, user in zip(releases, memory_configs, end_users):
# 安全地转换为整数处理空字符串和None的情况 # 安全地转换为整数处理空字符串和None的情况
print(base['config']) print(base['config'])
try: try:
@@ -156,14 +157,14 @@ async def start_workspace_reflection(
@router.get("/reflection/configs") @router.get("/reflection/configs")
async def start_reflection_configs( async def start_reflection_configs(
config_id: int, config_id: uuid.UUID,
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
db: Session = Depends(get_db), db: Session = Depends(get_db),
) -> dict: ) -> dict:
"""通过config_id查询data_config表中的反思配置信息""" """通过config_id查询memory_config表中的反思配置信息"""
try: try:
api_logger.info(f"用户 {current_user.username} 查询反思配置config_id: {config_id}") api_logger.info(f"用户 {current_user.username} 查询反思配置config_id: {config_id}")
result = DataConfigRepository.query_reflection_config_by_id(db, config_id) result = MemoryConfigRepository.query_reflection_config_by_id(db, config_id)
# 构建返回数据 # 构建返回数据
reflection_config = { reflection_config = {
"config_id": result.config_id, "config_id": result.config_id,
@@ -191,7 +192,7 @@ async def start_reflection_configs(
@router.get("/reflection/run") @router.get("/reflection/run")
async def reflection_run( async def reflection_run(
config_id: int, config_id: UUID,
language_type: str = Header(default="zh", alias="X-Language-Type"), language_type: str = Header(default="zh", alias="X-Language-Type"),
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
db: Session = Depends(get_db), db: Session = Depends(get_db),
@@ -200,8 +201,8 @@ async def reflection_run(
api_logger.info(f"用户 {current_user.username} 查询反思配置config_id: {config_id}") api_logger.info(f"用户 {current_user.username} 查询反思配置config_id: {config_id}")
# 使用DataConfigRepository查询反思配置 # 使用MemoryConfigRepository查询反思配置
result = DataConfigRepository.query_reflection_config_by_id(db, config_id) result = MemoryConfigRepository.query_reflection_config_by_id(db, config_id)
if not result: if not result:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, status_code=status.HTTP_404_NOT_FOUND,

View File

@@ -1,5 +1,6 @@
import os import os
from typing import Optional from typing import Optional
from uuid import UUID
from app.core.error_codes import BizCode from app.core.error_codes import BizCode
from app.core.logging_config import get_api_logger from app.core.logging_config import get_api_logger
@@ -160,7 +161,7 @@ def create_config(
@router.delete("/delete_config", response_model=ApiResponse) # 删除数据库中的内容(按配置名称) @router.delete("/delete_config", response_model=ApiResponse) # 删除数据库中的内容(按配置名称)
def delete_config( def delete_config(
config_id: str, config_id: UUID,
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
db: Session = Depends(get_db), db: Session = Depends(get_db),
) -> dict: ) -> dict:
@@ -232,7 +233,7 @@ def update_config_extracted(
@router.get("/read_config_extracted", response_model=ApiResponse) # 通过查询参数读取某条配置(固定路径) 没有意义的话就删除 @router.get("/read_config_extracted", response_model=ApiResponse) # 通过查询参数读取某条配置(固定路径) 没有意义的话就删除
def read_config_extracted( def read_config_extracted(
config_id: str, config_id: UUID,
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
db: Session = Depends(get_db), db: Session = Depends(get_db),
) -> dict: ) -> dict:

View File

@@ -20,18 +20,18 @@ router = APIRouter(
) )
@router.get("/{group_id}/count", response_model=ApiResponse) @router.get("/{end_user_id}/count", response_model=ApiResponse)
def get_memory_count( def get_memory_count(
group_id: uuid.UUID, end_user_id: uuid.UUID,
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
db: Session = Depends(get_db) db: Session = Depends(get_db)
): ):
pass pass
@router.get("/{group_id}/conversations", response_model=ApiResponse) @router.get("/{end_user_id}/conversations", response_model=ApiResponse)
def get_conversations( def get_conversations(
group_id: uuid.UUID, end_user_id: uuid.UUID,
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
db: Session = Depends(get_db) db: Session = Depends(get_db)
): ):
@@ -39,7 +39,7 @@ def get_conversations(
Retrieve all conversations for the current user in a specific group. Retrieve all conversations for the current user in a specific group.
Args: Args:
group_id (UUID): The group identifier. end_user_id (UUID): The group identifier.
current_user (User, optional): The authenticated user. current_user (User, optional): The authenticated user.
db (Session, optional): SQLAlchemy session. db (Session, optional): SQLAlchemy session.
@@ -53,7 +53,7 @@ def get_conversations(
""" """
conversation_service = ConversationService(db) conversation_service = ConversationService(db)
conversations = conversation_service.get_user_conversations( conversations = conversation_service.get_user_conversations(
group_id end_user_id
) )
return success(data=[ return success(data=[
{ {
@@ -63,7 +63,7 @@ def get_conversations(
], msg="get conversations success") ], msg="get conversations success")
@router.get("/{group_id}/messages", response_model=ApiResponse) @router.get("/{end_user_id}/messages", response_model=ApiResponse)
def get_messages( def get_messages(
conversation_id: uuid.UUID, conversation_id: uuid.UUID,
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
@@ -100,7 +100,7 @@ def get_messages(
return success(data=messages, msg="get conversation history success") return success(data=messages, msg="get conversation history success")
@router.get("/{group_id}/detail", response_model=ApiResponse) @router.get("/{end_user_id}/detail", response_model=ApiResponse)
async def get_conversation_detail( async def get_conversation_detail(
conversation_id: uuid.UUID, conversation_id: uuid.UUID,
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),

View File

@@ -39,7 +39,7 @@ async def write_memory_api_service(
Stores memory content for the specified end user using the Memory API Service. Stores memory content for the specified end user using the Memory API Service.
""" """
logger.info(f"Memory write request - end_user_id: {payload.end_user_id}") logger.info(f"Memory write request - end_user_id: {payload.end_user_id}, tenant_id: {api_key_auth.tenant_id}")
memory_api_service = MemoryAPIService(db) memory_api_service = MemoryAPIService(db)

View File

@@ -135,27 +135,27 @@ async def generate_cache_api(
api_logger.warning(f"用户 {current_user.username} 尝试生成缓存但未选择工作空间") api_logger.warning(f"用户 {current_user.username} 尝试生成缓存但未选择工作空间")
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None") return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
group_id = request.end_user_id end_user_id = request.end_user_id
api_logger.info( api_logger.info(
f"缓存生成请求: user={current_user.username}, workspace={workspace_id}, " f"缓存生成请求: user={current_user.username}, workspace={workspace_id}, "
f"end_user_id={group_id if group_id else '全部用户'}" f"end_user_id={end_user_id if end_user_id else '全部用户'}"
) )
try: try:
if group_id: if end_user_id:
# 为单个用户生成 # 为单个用户生成
api_logger.info(f"开始为单个用户生成缓存: end_user_id={group_id}") api_logger.info(f"开始为单个用户生成缓存: end_user_id={end_user_id}")
# 生成记忆洞察 # 生成记忆洞察
insight_result = await user_memory_service.generate_and_cache_insight(db, group_id, workspace_id) insight_result = await user_memory_service.generate_and_cache_insight(db, end_user_id, workspace_id)
# 生成用户摘要 # 生成用户摘要
summary_result = await user_memory_service.generate_and_cache_summary(db, group_id, workspace_id) summary_result = await user_memory_service.generate_and_cache_summary(db, end_user_id, workspace_id)
# 构建响应 # 构建响应
result = { result = {
"end_user_id": group_id, "end_user_id": end_user_id,
"insight_success": insight_result["success"], "insight_success": insight_result["success"],
"summary_success": summary_result["success"], "summary_success": summary_result["success"],
"errors": [] "errors": []
@@ -175,9 +175,9 @@ async def generate_cache_api(
# 记录结果 # 记录结果
if result["insight_success"] and result["summary_success"]: if result["insight_success"] and result["summary_success"]:
api_logger.info(f"成功为用户 {group_id} 生成缓存") api_logger.info(f"成功为用户 {end_user_id} 生成缓存")
else: else:
api_logger.warning(f"用户 {group_id} 的缓存生成部分失败: {result['errors']}") api_logger.warning(f"用户 {end_user_id} 的缓存生成部分失败: {result['errors']}")
return success(data=result, msg="生成完成") return success(data=result, msg="生成完成")

View File

@@ -155,7 +155,7 @@ class LangChainAgent:
# userid=end_user_end, # userid=end_user_end,
# messages=messages, # messages=messages,
# apply_id=end_user_end, # apply_id=end_user_end,
# group_id=end_user_end, # end_user_id=end_user_end,
# aimessages=aimessages # aimessages=aimessages
# ) # )
# store.delete_duplicate_sessions() # store.delete_duplicate_sessions()
@@ -228,7 +228,7 @@ class LangChainAgent:
# 6. 每个 Chunk 保存到 Neo4j包含 speaker 字段 # 6. 每个 Chunk 保存到 Neo4j包含 speaker 字段
logger.info(f"[WRITE] Submitting Celery task - user={actual_end_user_id}, messages={len(structured_messages)}, config={actual_config_id}") logger.info(f"[WRITE] Submitting Celery task - user={actual_end_user_id}, messages={len(structured_messages)}, config={actual_config_id}")
write_id = write_message_task.delay( write_id = write_message_task.delay(
actual_end_user_id, # group_id: 用户ID actual_end_user_id, # end_user_id: 用户ID
structured_messages, # message: 结构化消息列表 [{"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}] structured_messages, # message: 结构化消息列表 [{"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]
actual_config_id, # config_id: 配置ID actual_config_id, # config_id: 配置ID
storage_type, # storage_type: "neo4j" storage_type, # storage_type: "neo4j"

View File

@@ -35,10 +35,10 @@ async def Split_The_Problem(state: ReadState) -> ReadState:
"""问题分解节点""" """问题分解节点"""
# 从状态中获取数据 # 从状态中获取数据
content = state.get('data', '') content = state.get('data', '')
group_id = state.get('group_id', '') end_user_id = state.get('end_user_id', '')
memory_config = state.get('memory_config', None) memory_config = state.get('memory_config', None)
history = await SessionService(store).get_history(group_id, group_id, group_id) history = await SessionService(store).get_history(end_user_id, end_user_id, end_user_id)
# 生成 JSON schema 以指导 LLM 输出正确格式 # 生成 JSON schema 以指导 LLM 输出正确格式
json_schema = ProblemExtensionResponse.model_json_schema() json_schema = ProblemExtensionResponse.model_json_schema()
@@ -140,7 +140,7 @@ async def Problem_Extension(state: ReadState) -> ReadState:
start = time.time() start = time.time()
content = state.get('data', '') content = state.get('data', '')
data = state.get('spit_data', '')['context'] data = state.get('spit_data', '')['context']
group_id = state.get('group_id', '') end_user_id = state.get('end_user_id', '')
storage_type = state.get('storage_type', '') storage_type = state.get('storage_type', '')
user_rag_memory_id = state.get('user_rag_memory_id', '') user_rag_memory_id = state.get('user_rag_memory_id', '')
memory_config = state.get('memory_config', None) memory_config = state.get('memory_config', None)
@@ -156,7 +156,7 @@ async def Problem_Extension(state: ReadState) -> ReadState:
databasets = {} databasets = {}
data = [] data = []
history = await SessionService(store).get_history(group_id, group_id, group_id) history = await SessionService(store).get_history(end_user_id, end_user_id, end_user_id)
# 生成 JSON schema 以指导 LLM 输出正确格式 # 生成 JSON schema 以指导 LLM 输出正确格式
json_schema = ProblemExtensionResponse.model_json_schema() json_schema = ProblemExtensionResponse.model_json_schema()

View File

@@ -52,9 +52,9 @@ async def rag_config(state):
return kb_config return kb_config
async def rag_knowledge(state,question): async def rag_knowledge(state,question):
kb_config = await rag_config(state) kb_config = await rag_config(state)
group_id = state.get('group_id', '') end_user_id = state.get('end_user_id', '')
user_rag_memory_id=state.get("user_rag_memory_id",'') user_rag_memory_id=state.get("user_rag_memory_id",'')
retrieve_chunks_result = knowledge_retrieval(question, kb_config, [str(group_id)]) retrieve_chunks_result = knowledge_retrieval(question, kb_config, [str(end_user_id)])
try: try:
retrieval_knowledge = [i.page_content for i in retrieve_chunks_result] retrieval_knowledge = [i.page_content for i in retrieve_chunks_result]
clean_content = '\n\n'.join(retrieval_knowledge) clean_content = '\n\n'.join(retrieval_knowledge)
@@ -159,7 +159,7 @@ async def retrieve_nodes(state: ReadState) -> ReadState:
problem_extension=state.get('problem_extension', '')['context'] problem_extension=state.get('problem_extension', '')['context']
storage_type=state.get('storage_type', '') storage_type=state.get('storage_type', '')
user_rag_memory_id=state.get('user_rag_memory_id', '') user_rag_memory_id=state.get('user_rag_memory_id', '')
group_id=state.get('group_id', '') end_user_id=state.get('end_user_id', '')
memory_config = state.get('memory_config', None) memory_config = state.get('memory_config', None)
original=state.get('data', '') original=state.get('data', '')
problem_list=[] problem_list=[]
@@ -172,7 +172,7 @@ async def retrieve_nodes(state: ReadState) -> ReadState:
try: try:
# Prepare search parameters based on storage type # Prepare search parameters based on storage type
search_params = { search_params = {
"group_id": group_id, "end_user_id": end_user_id,
"question": question, "question": question,
"return_raw_results": True "return_raw_results": True
} }
@@ -263,13 +263,13 @@ async def retrieve_nodes(state: ReadState) -> ReadState:
async def retrieve(state: ReadState) -> ReadState: async def retrieve(state: ReadState) -> ReadState:
# 从state中获取group_id # 从state中获取end_user_id
import time import time
start=time.time() start=time.time()
problem_extension = state.get('problem_extension', '')['context'] problem_extension = state.get('problem_extension', '')['context']
storage_type = state.get('storage_type', '') storage_type = state.get('storage_type', '')
user_rag_memory_id = state.get('user_rag_memory_id', '') user_rag_memory_id = state.get('user_rag_memory_id', '')
group_id = state.get('group_id', '') end_user_id = state.get('end_user_id', '')
memory_config = state.get('memory_config', None) memory_config = state.get('memory_config', None)
original = state.get('data', '') original = state.get('data', '')
problem_list = [] problem_list = []
@@ -295,13 +295,13 @@ async def retrieve(state: ReadState) -> ReadState:
temperature=0.2, temperature=0.2,
) )
time_retrieval_tool = create_time_retrieval_tool(group_id) time_retrieval_tool = create_time_retrieval_tool(end_user_id)
search_params = { "group_id": group_id, "return_raw_results": True } search_params = { "end_user_id": end_user_id, "return_raw_results": True }
hybrid_retrieval=create_hybrid_retrieval_tool_sync(memory_config, **search_params) hybrid_retrieval=create_hybrid_retrieval_tool_sync(memory_config, **search_params)
agent = create_agent( agent = create_agent(
llm, llm,
tools=[time_retrieval_tool,hybrid_retrieval], tools=[time_retrieval_tool,hybrid_retrieval],
system_prompt=f"我是检索专家,可以根据适合的工具进行检索。当前使用的group_id是: {group_id}" system_prompt=f"我是检索专家,可以根据适合的工具进行检索。当前使用的end_user_id是: {end_user_id}"
) )
# 创建异步任务处理单个问题 # 创建异步任务处理单个问题

View File

@@ -34,8 +34,8 @@ class SummaryNodeService(LLMServiceMixin):
summary_service = SummaryNodeService() summary_service = SummaryNodeService()
async def summary_history(state: ReadState) -> ReadState: async def summary_history(state: ReadState) -> ReadState:
group_id = state.get("group_id", '') end_user_id = state.get("end_user_id", '')
history = await SessionService(store).get_history(group_id, group_id, group_id) history = await SessionService(store).get_history(end_user_id, end_user_id, end_user_id)
return history return history
async def summary_llm(state: ReadState, history, retrieve_info, template_name, operation_name, response_model,search_mode) -> str: async def summary_llm(state: ReadState, history, retrieve_info, template_name, operation_name, response_model,search_mode) -> str:
@@ -122,12 +122,12 @@ async def summary_llm(state: ReadState, history, retrieve_info, template_name, o
async def summary_redis_save(state: ReadState,aimessages) -> ReadState: async def summary_redis_save(state: ReadState,aimessages) -> ReadState:
data = state.get("data", '') data = state.get("data", '')
group_id = state.get("group_id", '') end_user_id = state.get("end_user_id", '')
await SessionService(store).save_session( await SessionService(store).save_session(
user_id=group_id, user_id=end_user_id,
query=data, query=data,
apply_id=group_id, apply_id=end_user_id,
group_id=group_id, end_user_id=end_user_id,
ai_response=aimessages ai_response=aimessages
) )
await SessionService(store).cleanup_duplicates() await SessionService(store).cleanup_duplicates()
@@ -175,11 +175,11 @@ async def Input_Summary(state: ReadState) -> ReadState:
memory_config = state.get('memory_config', None) memory_config = state.get('memory_config', None)
user_rag_memory_id=state.get("user_rag_memory_id",'') user_rag_memory_id=state.get("user_rag_memory_id",'')
data=state.get("data", '') data=state.get("data", '')
group_id=state.get("group_id", '') end_user_id=state.get("end_user_id", '')
logger.info(f"Input_Summary: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}") logger.info(f"Input_Summary: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}")
history = await summary_history( state) history = await summary_history( state)
search_params = { search_params = {
"group_id": group_id, "end_user_id": end_user_id,
"question": data, "question": data,
"return_raw_results": True, "return_raw_results": True,
"include": ["summaries"] # Only search summary nodes for faster performance "include": ["summaries"] # Only search summary nodes for faster performance

View File

@@ -62,12 +62,12 @@ async def Verify(state: ReadState):
logger.info("=== Verify 节点开始执行 ===") logger.info("=== Verify 节点开始执行 ===")
try: try:
content = state.get('data', '') content = state.get('data', '')
group_id = state.get('group_id', '') end_user_id = state.get('end_user_id', '')
memory_config = state.get('memory_config', None) memory_config = state.get('memory_config', None)
logger.info(f"Verify: content={content[:50] if content else 'empty'}..., group_id={group_id}") logger.info(f"Verify: content={content[:50] if content else 'empty'}..., end_user_id={end_user_id}")
history = await SessionService(store).get_history(group_id, group_id, group_id) history = await SessionService(store).get_history(end_user_id, end_user_id, end_user_id)
logger.info(f"Verify: 获取历史记录完成history length={len(history)}") logger.info(f"Verify: 获取历史记录完成history length={len(history)}")
retrieve = state.get("retrieve", {}) retrieve = state.get("retrieve", {})

View File

@@ -1,21 +1,22 @@
from app.core.memory.agent.utils.llm_tools import WriteState from app.core.memory.agent.utils.llm_tools import WriteState
from app.core.memory.agent.utils.write_tools import write from app.core.memory.agent.utils.write_tools import write
from app.core.logging_config import get_agent_logger from app.core.logging_config import get_agent_logger
logger = get_agent_logger(__name__) logger = get_agent_logger(__name__)
async def write_node(state: WriteState) -> WriteState: async def write_node(state: WriteState) -> WriteState:
""" """
Write data to the database/file system. Write data to the database/file system.
Args: Args:
state: WriteState containing messages, group_id, and memory_config state: WriteState containing messages, end_user_id, and memory_config
Returns: Returns:
dict: Contains 'write_result' with status and data fields dict: Contains 'write_result' with status and data fields
""" """
messages = state.get('messages', []) messages = state.get('messages', [])
group_id = state.get('group_id', '') end_user_id = state.get('end_user_id', '')
memory_config = state.get('memory_config', '') memory_config = state.get('memory_config', '')
# Convert LangChain messages to structured format expected by write() # Convert LangChain messages to structured format expected by write()
@@ -32,9 +33,7 @@ async def write_node(state: WriteState) -> WriteState:
try: try:
result = await write( result = await write(
messages=structured_messages, messages=structured_messages,
user_id=group_id, end_user_id=end_user_id,
apply_id=group_id,
group_id=group_id,
memory_config=memory_config, memory_config=memory_config,
) )
logger.info(f"Write completed successfully! Config: {memory_config.config_name}") logger.info(f"Write completed successfully! Config: {memory_config.config_name}")

View File

@@ -79,7 +79,7 @@ async def make_read_graph():
async def main(): async def main():
"""主函数 - 运行工作流""" """主函数 - 运行工作流"""
message = "昨天有什么好看的电影" message = "昨天有什么好看的电影"
group_id = '88a459f5_text09' # 组ID end_user_id = '88a459f5_text09' # 组ID
storage_type = 'neo4j' # 存储类型 storage_type = 'neo4j' # 存储类型
search_switch = '1' # 搜索开关 search_switch = '1' # 搜索开关
user_rag_memory_id = 'wwwwwwww' # 用户RAG记忆ID user_rag_memory_id = 'wwwwwwww' # 用户RAG记忆ID
@@ -95,9 +95,9 @@ async def main():
start=time.time() start=time.time()
try: try:
async with make_read_graph() as graph: async with make_read_graph() as graph:
config = {"configurable": {"thread_id": group_id}} config = {"configurable": {"thread_id": end_user_id}}
# 初始状态 - 包含所有必要字段 # 初始状态 - 包含所有必要字段
initial_state = {"messages": [HumanMessage(content=message)] ,"search_switch":search_switch,"group_id":group_id initial_state = {"messages": [HumanMessage(content=message)] ,"search_switch":search_switch,"end_user_id":end_user_id
,"storage_type":storage_type,"user_rag_memory_id":user_rag_memory_id,"memory_config":memory_config} ,"storage_type":storage_type,"user_rag_memory_id":user_rag_memory_id,"memory_config":memory_config}
# 获取节点更新信息 # 获取节点更新信息
_intermediate_outputs = [] _intermediate_outputs = []

View File

@@ -48,11 +48,11 @@ def extract_tool_message_content(response):
class TimeRetrievalInput(BaseModel): class TimeRetrievalInput(BaseModel):
"""时间检索工具的输入模式""" """时间检索工具的输入模式"""
context: str = Field(description="用户输入的查询内容") context: str = Field(description="用户输入的查询内容")
group_id: str = Field(default="88a459f5_text09", description="组ID用于过滤搜索结果") end_user_id: str = Field(default="88a459f5_text09", description="组ID用于过滤搜索结果")
def create_time_retrieval_tool(group_id: str): def create_time_retrieval_tool(end_user_id: str):
""" """
创建一个带有特定group_id的TimeRetrieval工具同步版本用于按时间范围搜索语句(Statements) 创建一个带有特定end_user_id的TimeRetrieval工具同步版本用于按时间范围搜索语句(Statements)
""" """
def clean_temporal_result_fields(data): def clean_temporal_result_fields(data):
@@ -93,26 +93,26 @@ def create_time_retrieval_tool(group_id: str):
return data return data
@tool @tool
def TimeRetrievalWithGroupId(context: str, start_date: str = None, end_date: str = None, group_id_param: str = None, clean_output: bool = True) -> str: def TimeRetrievalWithGroupId(context: str, start_date: str = None, end_date: str = None, end_user_id_param: str = None, clean_output: bool = True) -> str:
""" """
优化的时间检索工具,只结合时间范围搜索(同步版本),自动过滤不需要的元数据字段 优化的时间检索工具,只结合时间范围搜索(同步版本),自动过滤不需要的元数据字段
显式接收参数: 显式接收参数:
- context: 查询上下文内容 - context: 查询上下文内容
- start_date: 开始时间可选格式YYYY-MM-DD - start_date: 开始时间可选格式YYYY-MM-DD
- end_date: 结束时间可选格式YYYY-MM-DD - end_date: 结束时间可选格式YYYY-MM-DD
- group_id_param: 组ID可选用于覆盖默认组ID - end_user_id_param: 组ID可选用于覆盖默认组ID
- clean_output: 是否清理输出中的元数据字段 - clean_output: 是否清理输出中的元数据字段
-end_date 需要根据用户的描述获取结束的时间输出格式用strftime("%Y-%m-%d") -end_date 需要根据用户的描述获取结束的时间输出格式用strftime("%Y-%m-%d")
""" """
async def _async_search(): async def _async_search():
# 使用传入的参数或默认值 # 使用传入的参数或默认值
actual_group_id = group_id_param or group_id actual_end_user_id = end_user_id_param or end_user_id
actual_end_date = end_date or datetime.now().strftime("%Y-%m-%d") actual_end_date = end_date or datetime.now().strftime("%Y-%m-%d")
actual_start_date = start_date or (datetime.now() - timedelta(days=7)).strftime("%Y-%m-%d") actual_start_date = start_date or (datetime.now() - timedelta(days=7)).strftime("%Y-%m-%d")
# 基本时间搜索 # 基本时间搜索
results = await search_by_temporal( results = await search_by_temporal(
group_id=actual_group_id, end_user_id=actual_end_user_id,
start_date=actual_start_date, start_date=actual_start_date,
end_date=actual_end_date, end_date=actual_end_date,
limit=10 limit=10
@@ -147,7 +147,7 @@ def create_time_retrieval_tool(group_id: str):
# 关键词时间搜索 # 关键词时间搜索
results = await search_by_keyword_temporal( results = await search_by_keyword_temporal(
query_text=context, query_text=context,
group_id=group_id, end_user_id=end_user_id,
start_date=actual_start_date, start_date=actual_start_date,
end_date=actual_end_date, end_date=actual_end_date,
limit=15 limit=15
@@ -172,7 +172,7 @@ def create_hybrid_retrieval_tool_async(memory_config, **search_params):
Args: Args:
memory_config: 内存配置对象 memory_config: 内存配置对象
**search_params: 搜索参数,包含group_id, limit, include等 **search_params: 搜索参数,包含end_user_id, limit, include等
""" """
def clean_result_fields(data): def clean_result_fields(data):
@@ -211,7 +211,7 @@ def create_hybrid_retrieval_tool_async(memory_config, **search_params):
context: str, context: str,
search_type: str = "hybrid", search_type: str = "hybrid",
limit: int = 10, limit: int = 10,
group_id: str = None, end_user_id: str = None,
rerank_alpha: float = 0.6, rerank_alpha: float = 0.6,
use_forgetting_rerank: bool = False, use_forgetting_rerank: bool = False,
use_llm_rerank: bool = False, use_llm_rerank: bool = False,
@@ -224,7 +224,7 @@ def create_hybrid_retrieval_tool_async(memory_config, **search_params):
context: 查询内容 context: 查询内容
search_type: 搜索类型 ('keyword', 'embedding', 'hybrid') search_type: 搜索类型 ('keyword', 'embedding', 'hybrid')
limit: 结果数量限制 limit: 结果数量限制
group_id: 组ID用于过滤搜索结果 end_user_id: 组ID用于过滤搜索结果
rerank_alpha: 重排序权重参数 rerank_alpha: 重排序权重参数
use_forgetting_rerank: 是否使用遗忘重排序 use_forgetting_rerank: 是否使用遗忘重排序
use_llm_rerank: 是否使用LLM重排序 use_llm_rerank: 是否使用LLM重排序
@@ -238,7 +238,7 @@ def create_hybrid_retrieval_tool_async(memory_config, **search_params):
final_params = { final_params = {
"query_text": context, "query_text": context,
"search_type": search_type, "search_type": search_type,
"group_id": group_id or search_params.get("group_id"), "end_user_id": end_user_id or search_params.get("end_user_id"),
"limit": limit or search_params.get("limit", 10), "limit": limit or search_params.get("limit", 10),
"include": search_params.get("include", ["summaries", "statements", "chunks", "entities"]), "include": search_params.get("include", ["summaries", "statements", "chunks", "entities"]),
"output_path": None, # 不保存到文件 "output_path": None, # 不保存到文件
@@ -291,7 +291,7 @@ def create_hybrid_retrieval_tool_sync(memory_config, **search_params):
context: str, context: str,
search_type: str = "hybrid", search_type: str = "hybrid",
limit: int = 10, limit: int = 10,
group_id: str = None, end_user_id: str = None,
clean_output: bool = True clean_output: bool = True
) -> str: ) -> str:
""" """
@@ -301,7 +301,7 @@ def create_hybrid_retrieval_tool_sync(memory_config, **search_params):
context: 查询内容 context: 查询内容
search_type: 搜索类型 ('keyword', 'embedding', 'hybrid') search_type: 搜索类型 ('keyword', 'embedding', 'hybrid')
limit: 结果数量限制 limit: 结果数量限制
group_id: 组ID用于过滤搜索结果 end_user_id: 组ID用于过滤搜索结果
clean_output: 是否清理输出中的元数据字段 clean_output: 是否清理输出中的元数据字段
""" """
async def _async_search(): async def _async_search():
@@ -311,7 +311,7 @@ def create_hybrid_retrieval_tool_sync(memory_config, **search_params):
"context": context, "context": context,
"search_type": search_type, "search_type": search_type,
"limit": limit, "limit": limit,
"group_id": group_id, "end_user_id": end_user_id,
"clean_output": clean_output "clean_output": clean_output
}) })

View File

@@ -14,6 +14,7 @@ from app.db import get_db
from app.core.logging_config import get_agent_logger from app.core.logging_config import get_agent_logger
from app.core.memory.agent.utils.llm_tools import WriteState from app.core.memory.agent.utils.llm_tools import WriteState
from app.core.memory.agent.langgraph_graph.nodes.write_nodes import write_node from app.core.memory.agent.langgraph_graph.nodes.write_nodes import write_node
from app.core.memory.agent.langgraph_graph.nodes.data_nodes import content_input_write
from app.services.memory_config_service import MemoryConfigService from app.services.memory_config_service import MemoryConfigService
warnings.filterwarnings("ignore", category=RuntimeWarning) warnings.filterwarnings("ignore", category=RuntimeWarning)
@@ -26,9 +27,21 @@ async def make_write_graph():
""" """
Create a write graph workflow for memory operations. Create a write graph workflow for memory operations.
The workflow directly processes messages from the initial state Args:
and saves them to Neo4j storage. user_id: User identifier
tools: MCP tools loaded from session
apply_id: Application identifier
end_user_id: Group identifier
memory_config: MemoryConfig object containing all configuration
""" """
# workflow = StateGraph(WriteState)
# workflow.add_node("content_input", content_input_write)
# workflow.add_node("save_neo4j", write_node)
# workflow.add_edge(START, "content_input")
# workflow.add_edge("content_input", "save_neo4j")
# workflow.add_edge("save_neo4j", END)
#
# graph = workflow.compile()
workflow = StateGraph(WriteState) workflow = StateGraph(WriteState)
workflow.add_node("save_neo4j", write_node) workflow.add_node("save_neo4j", write_node)
workflow.add_edge(START, "save_neo4j") workflow.add_edge(START, "save_neo4j")
@@ -42,7 +55,7 @@ async def make_write_graph():
async def main(): async def main():
"""主函数 - 运行工作流""" """主函数 - 运行工作流"""
message = "今天周一" message = "今天周一"
group_id = 'new_2025test1103' # 组ID end_user_id = 'new_2025test1103' # 组ID
# 获取数据库会话 # 获取数据库会话
@@ -54,9 +67,9 @@ async def main():
) )
try: try:
async with make_write_graph() as graph: async with make_write_graph() as graph:
config = {"configurable": {"thread_id": group_id}} config = {"configurable": {"thread_id": end_user_id}}
# 初始状态 - 包含所有必要字段 # 初始状态 - 包含所有必要字段
initial_state = {"messages": [HumanMessage(content=message)], "group_id": group_id, "memory_config": memory_config} initial_state = {"messages": [HumanMessage(content=message)], "end_user_id": end_user_id, "memory_config": memory_config}
# 获取节点更新信息 # 获取节点更新信息
async for update_event in graph.astream( async for update_event in graph.astream(

View File

@@ -24,7 +24,7 @@ class ParameterBuilder:
tool_call_id: str, tool_call_id: str,
search_switch: str, search_switch: str,
apply_id: str, apply_id: str,
group_id: str, end_user_id: str,
storage_type: Optional[str] = None, storage_type: Optional[str] = None,
user_rag_memory_id: Optional[str] = None user_rag_memory_id: Optional[str] = None
) -> Dict[str, Any]: ) -> Dict[str, Any]:
@@ -44,7 +44,7 @@ class ParameterBuilder:
tool_call_id: Extracted tool call identifier tool_call_id: Extracted tool call identifier
search_switch: Search routing parameter search_switch: Search routing parameter
apply_id: Application identifier apply_id: Application identifier
group_id: Group identifier end_user_id: Group identifier
storage_type: Storage type for the workspace (optional) storage_type: Storage type for the workspace (optional)
user_rag_memory_id: User RAG memory ID for knowledge base retrieval (optional) user_rag_memory_id: User RAG memory ID for knowledge base retrieval (optional)
@@ -55,7 +55,7 @@ class ParameterBuilder:
base_args = { base_args = {
"usermessages": tool_call_id, "usermessages": tool_call_id,
"apply_id": apply_id, "apply_id": apply_id,
"group_id": group_id "end_user_id": end_user_id
} }
# Always add storage_type and user_rag_memory_id (with defaults if None) # Always add storage_type and user_rag_memory_id (with defaults if None)

View File

@@ -91,7 +91,7 @@ class SearchService:
async def execute_hybrid_search( async def execute_hybrid_search(
self, self,
group_id: str, end_user_id: str,
question: str, question: str,
limit: int = 5, limit: int = 5,
search_type: str = "hybrid", search_type: str = "hybrid",
@@ -105,7 +105,7 @@ class SearchService:
Execute hybrid search and return clean content. Execute hybrid search and return clean content.
Args: Args:
group_id: Group identifier for filtering results end_user_id: Group identifier for filtering results
question: Search query text question: Search query text
limit: Maximum number of results to return (default: 5) limit: Maximum number of results to return (default: 5)
search_type: Type of search - "hybrid", "keyword", or "embedding" (default: "hybrid") search_type: Type of search - "hybrid", "keyword", or "embedding" (default: "hybrid")
@@ -130,7 +130,7 @@ class SearchService:
answer = await run_hybrid_search( answer = await run_hybrid_search(
query_text=cleaned_query, query_text=cleaned_query,
search_type=search_type, search_type=search_type,
group_id=group_id, end_user_id=end_user_id,
limit=limit, limit=limit,
include=include, include=include,
output_path=output_path, output_path=output_path,
@@ -186,7 +186,7 @@ class SearchService:
except Exception as e: except Exception as e:
logger.error( logger.error(
f"Search failed for query '{question}' in group '{group_id}': {e}", f"Search failed for query '{question}' in group '{end_user_id}': {e}",
exc_info=True exc_info=True
) )
# Return empty results on failure # Return empty results on failure

View File

@@ -59,7 +59,7 @@ class SessionService:
self, self,
user_id: str, user_id: str,
apply_id: str, apply_id: str,
group_id: str end_user_id: str
) -> List[dict]: ) -> List[dict]:
""" """
Retrieve conversation history from Redis. Retrieve conversation history from Redis.
@@ -67,20 +67,20 @@ class SessionService:
Args: Args:
user_id: User identifier user_id: User identifier
apply_id: Application identifier apply_id: Application identifier
group_id: Group identifier end_user_id: Group identifier
Returns: Returns:
List of conversation history items with Query and Answer keys List of conversation history items with Query and Answer keys
Returns empty list if no history found or on error Returns empty list if no history found or on error
""" """
try: try:
history = self.store.find_user_apply_group(user_id, apply_id, group_id) history = self.store.find_user_apply_group(user_id, apply_id, end_user_id)
# Validate history structure # Validate history structure
if not isinstance(history, list): if not isinstance(history, list):
logger.warning( logger.warning(
f"Invalid history format for user {user_id}, " f"Invalid history format for user {user_id}, "
f"apply {apply_id}, group {group_id}: expected list, got {type(history)}" f"apply {apply_id}, group {end_user_id}: expected list, got {type(history)}"
) )
return [] return []
@@ -89,7 +89,7 @@ class SessionService:
except Exception as e: except Exception as e:
logger.error( logger.error(
f"Failed to retrieve history for user {user_id}, " f"Failed to retrieve history for user {user_id}, "
f"apply {apply_id}, group {group_id}: {e}", f"apply {apply_id}, group {end_user_id}: {e}",
exc_info=True exc_info=True
) )
# Return empty list on error to allow execution to continue # Return empty list on error to allow execution to continue
@@ -100,7 +100,7 @@ class SessionService:
user_id: str, user_id: str,
query: str, query: str,
apply_id: str, apply_id: str,
group_id: str, end_user_id: str,
ai_response: str ai_response: str
) -> Optional[str]: ) -> Optional[str]:
""" """
@@ -110,7 +110,7 @@ class SessionService:
user_id: User identifier user_id: User identifier
query: User query/message query: User query/message
apply_id: Application identifier apply_id: Application identifier
group_id: Group identifier end_user_id: Group identifier
ai_response: AI response/answer ai_response: AI response/answer
Returns: Returns:
@@ -131,7 +131,7 @@ class SessionService:
userid=user_id, userid=user_id,
messages=query, messages=query,
apply_id=apply_id, apply_id=apply_id,
group_id=group_id, end_user_id=end_user_id,
aimessages=ai_response aimessages=ai_response
) )
@@ -152,7 +152,7 @@ class SessionService:
Duplicates are identified by matching: Duplicates are identified by matching:
- sessionid - sessionid
- user_id (id field) - user_id (id field)
- group_id - end_user_id
- messages - messages
- aimessages - aimessages

View File

@@ -9,9 +9,7 @@ from app.core.memory.models.message_models import DialogData, ConversationContex
async def get_chunked_dialogs( async def get_chunked_dialogs(
chunker_strategy: str = "RecursiveChunker", chunker_strategy: str = "RecursiveChunker",
group_id: str = "group_1", end_user_id: str = "group_1",
user_id: str = "user1",
apply_id: str = "applyid",
messages: list = None, messages: list = None,
ref_id: str = "wyl_20251027", ref_id: str = "wyl_20251027",
config_id: str = None config_id: str = None
@@ -20,9 +18,7 @@ async def get_chunked_dialogs(
Args: Args:
chunker_strategy: The chunking strategy to use (default: RecursiveChunker) chunker_strategy: The chunking strategy to use (default: RecursiveChunker)
group_id: Group identifier end_user_id: Group identifier
user_id: User identifier
apply_id: Application identifier
messages: Structured message list [{"role": "user", "content": "..."}, ...] messages: Structured message list [{"role": "user", "content": "..."}, ...]
ref_id: Reference identifier ref_id: Reference identifier
config_id: Configuration ID for processing config_id: Configuration ID for processing
@@ -58,9 +54,7 @@ async def get_chunked_dialogs(
dialog_data = DialogData( dialog_data = DialogData(
context=conversation_context, context=conversation_context,
ref_id=ref_id, ref_id=ref_id,
group_id=group_id, end_user_id=end_user_id,
user_id=user_id,
apply_id=apply_id,
config_id=config_id config_id=config_id
) )

View File

@@ -13,13 +13,11 @@ class WriteState(TypedDict):
Langgrapg Writing TypedDict Langgrapg Writing TypedDict
''' '''
messages: Annotated[list[AnyMessage], add_messages] messages: Annotated[list[AnyMessage], add_messages]
user_id:str end_user_id: str
apply_id:str
group_id:str
errors: list[dict] # Track errors: [{"tool": "tool_name", "error": "message"}] errors: list[dict] # Track errors: [{"tool": "tool_name", "error": "message"}]
memory_config: object memory_config: object
write_result: dict write_result: dict
data:str data: str
class ReadState(TypedDict): class ReadState(TypedDict):
""" """
@@ -29,7 +27,7 @@ class ReadState(TypedDict):
messages: 消息列表,支持自动追加 messages: 消息列表,支持自动追加
loop_count: 遍历次数 loop_count: 遍历次数
search_switch: 搜索类型开关 search_switch: 搜索类型开关
group_id: 组标识 end_user_id: 组标识
config_id: 配置ID用于过滤结果 config_id: 配置ID用于过滤结果
data: 从content_input_node传递的内容数据 data: 从content_input_node传递的内容数据
spit_data: 从Split_The_Problem传递的分解结果 spit_data: 从Split_The_Problem传递的分解结果
@@ -40,7 +38,7 @@ class ReadState(TypedDict):
messages: Annotated[list[AnyMessage], add_messages] # 消息追加模式 messages: Annotated[list[AnyMessage], add_messages] # 消息追加模式
loop_count: int loop_count: int
search_switch: str search_switch: str
group_id: str end_user_id: str
config_id: str config_id: str
data: str # 新增字段用于传递内容 data: str # 新增字段用于传递内容
spit_data: dict # 新增字段用于传递问题分解结果 spit_data: dict # 新增字段用于传递问题分解结果

View File

@@ -28,7 +28,7 @@ class RedisSessionStore:
return text return text
# 修改后的 save_session 方法 # 修改后的 save_session 方法
def save_session(self, userid, messages, aimessages, apply_id, group_id): def save_session(self, userid, messages, aimessages, apply_id, end_user_id):
""" """
写入一条会话数据,返回 session_id 写入一条会话数据,返回 session_id
优化版本确保写入时间不超过1秒 优化版本确保写入时间不超过1秒
@@ -46,7 +46,7 @@ class RedisSessionStore:
"id": self.uudi, "id": self.uudi,
"sessionid": userid, "sessionid": userid,
"apply_id": apply_id, "apply_id": apply_id,
"group_id": group_id, "end_user_id": end_user_id,
"messages": messages, "messages": messages,
"aimessages": aimessages, "aimessages": aimessages,
"starttime": starttime "starttime": starttime
@@ -67,7 +67,7 @@ class RedisSessionStore:
def save_sessions_batch(self, sessions_data): def save_sessions_batch(self, sessions_data):
""" """
批量写入多条会话数据,返回 session_id 列表 批量写入多条会话数据,返回 session_id 列表
sessions_data: list of dict, 每个 dict 包含 userid, messages, aimessages, apply_id, group_id sessions_data: list of dict, 每个 dict 包含 userid, messages, aimessages, apply_id, end_user_id
优化版本:批量操作,大幅提升性能 优化版本:批量操作,大幅提升性能
""" """
try: try:
@@ -83,7 +83,7 @@ class RedisSessionStore:
"id": self.uudi, "id": self.uudi,
"sessionid": session.get('userid'), "sessionid": session.get('userid'),
"apply_id": session.get('apply_id'), "apply_id": session.get('apply_id'),
"group_id": session.get('group_id'), "end_user_id": session.get('end_user_id'),
"messages": session.get('messages'), "messages": session.get('messages'),
"aimessages": session.get('aimessages'), "aimessages": session.get('aimessages'),
"starttime": starttime "starttime": starttime
@@ -108,9 +108,9 @@ class RedisSessionStore:
data = self.r.hgetall(key) data = self.r.hgetall(key)
return data if data else None return data if data else None
def get_session_apply_group(self, sessionid, apply_id, group_id): def get_session_apply_group(self, sessionid, apply_id, end_user_id):
""" """
根据 sessionid、apply_id 和 group_id 三个条件查询会话数据 根据 sessionid、apply_id 和 end_user_id 三个条件查询会话数据
""" """
result_items = [] result_items = []
@@ -124,7 +124,7 @@ class RedisSessionStore:
# 检查三个条件是否都匹配 # 检查三个条件是否都匹配
if (data.get('sessionid') == sessionid and if (data.get('sessionid') == sessionid and
data.get('apply_id') == apply_id and data.get('apply_id') == apply_id and
data.get('group_id') == group_id): data.get('end_user_id') == end_user_id):
result_items.append(data) result_items.append(data)
return result_items return result_items
@@ -172,7 +172,7 @@ class RedisSessionStore:
def delete_duplicate_sessions(self): def delete_duplicate_sessions(self):
""" """
删除重复会话数据,条件: 删除重复会话数据,条件:
"sessionid""user_id""group_id""messages""aimessages" 五个字段都相同的只保留一个,其他删除 "sessionid""user_id""end_user_id""messages""aimessages" 五个字段都相同的只保留一个,其他删除
优化版本:使用 pipeline 批量操作确保在1秒内完成 优化版本:使用 pipeline 批量操作确保在1秒内完成
""" """
import time import time
@@ -202,12 +202,12 @@ class RedisSessionStore:
# 获取五个字段的值 # 获取五个字段的值
sessionid = data.get('sessionid', '') sessionid = data.get('sessionid', '')
user_id = data.get('id', '') user_id = data.get('id', '')
group_id = data.get('group_id', '') end_user_id = data.get('end_user_id', '')
messages = data.get('messages', '') messages = data.get('messages', '')
aimessages = data.get('aimessages', '') aimessages = data.get('aimessages', '')
# 用五元组作为唯一标识 # 用五元组作为唯一标识
identifier = (sessionid, user_id, group_id, messages, aimessages) identifier = (sessionid, user_id, end_user_id, messages, aimessages)
if identifier in seen: if identifier in seen:
# 重复,标记为待删除 # 重复,标记为待删除
@@ -248,9 +248,9 @@ class RedisSessionStore:
result_items = [] result_items = []
return (result_items) return (result_items)
def find_user_apply_group(self, sessionid, apply_id, group_id): def find_user_apply_group(self, sessionid, apply_id, end_user_id):
""" """
根据 sessionid、apply_id 和 group_id 三个条件查询会话数据返回最新的6条 根据 sessionid、apply_id 和 end_user_id 三个条件查询会话数据返回最新的6条
""" """
import time import time
start_time = time.time() start_time = time.time()
@@ -276,7 +276,7 @@ class RedisSessionStore:
# 检查是否符合三个条件 # 检查是否符合三个条件
if (data.get('apply_id') == apply_id and if (data.get('apply_id') == apply_id and
data.get('group_id') == group_id): data.get('end_user_id') == end_user_id):
# 支持模糊匹配 sessionid 或者完全匹配 # 支持模糊匹配 sessionid 或者完全匹配
if sessionid in data.get('sessionid', '') or data.get('sessionid') == sessionid: if sessionid in data.get('sessionid', '') or data.get('sessionid') == sessionid:
matched_items.append({ matched_items.append({

View File

@@ -59,7 +59,7 @@ class SessionService:
self, self,
user_id: str, user_id: str,
apply_id: str, apply_id: str,
group_id: str end_user_id: str
) -> List[dict]: ) -> List[dict]:
""" """
Retrieve conversation history from Redis. Retrieve conversation history from Redis.
@@ -67,20 +67,20 @@ class SessionService:
Args: Args:
user_id: User identifier user_id: User identifier
apply_id: Application identifier apply_id: Application identifier
group_id: Group identifier end_user_id: Group identifier
Returns: Returns:
List of conversation history items with Query and Answer keys List of conversation history items with Query and Answer keys
Returns empty list if no history found or on error Returns empty list if no history found or on error
""" """
try: try:
history = self.store.find_user_apply_group(user_id, apply_id, group_id) history = self.store.find_user_apply_group(user_id, apply_id, end_user_id)
# Validate history structure # Validate history structure
if not isinstance(history, list): if not isinstance(history, list):
logger.warning( logger.warning(
f"Invalid history format for user {user_id}, " f"Invalid history format for user {user_id}, "
f"apply {apply_id}, group {group_id}: expected list, got {type(history)}" f"apply {apply_id}, group {end_user_id}: expected list, got {type(history)}"
) )
return [] return []
@@ -89,7 +89,7 @@ class SessionService:
except Exception as e: except Exception as e:
logger.error( logger.error(
f"Failed to retrieve history for user {user_id}, " f"Failed to retrieve history for user {user_id}, "
f"apply {apply_id}, group {group_id}: {e}", f"apply {apply_id}, group {end_user_id}: {e}",
exc_info=True exc_info=True
) )
# Return empty list on error to allow execution to continue # Return empty list on error to allow execution to continue
@@ -100,7 +100,7 @@ class SessionService:
user_id: str, user_id: str,
query: str, query: str,
apply_id: str, apply_id: str,
group_id: str, end_user_id: str,
ai_response: str ai_response: str
) -> Optional[str]: ) -> Optional[str]:
""" """
@@ -110,7 +110,7 @@ class SessionService:
user_id: User identifier user_id: User identifier
query: User query/message query: User query/message
apply_id: Application identifier apply_id: Application identifier
group_id: Group identifier end_user_id: Group identifier
ai_response: AI response/answer ai_response: AI response/answer
Returns: Returns:
@@ -131,7 +131,7 @@ class SessionService:
userid=user_id, userid=user_id,
messages=query, messages=query,
apply_id=apply_id, apply_id=apply_id,
group_id=group_id, end_user_id=end_user_id,
aimessages=ai_response aimessages=ai_response
) )
@@ -152,7 +152,7 @@ class SessionService:
Duplicates are identified by matching: Duplicates are identified by matching:
- sessionid - sessionid
- user_id (id field) - user_id (id field)
- group_id - end_user_id
- messages - messages
- aimessages - aimessages

View File

@@ -29,9 +29,7 @@ logger = get_agent_logger(__name__)
async def write( async def write(
user_id: str, end_user_id: str,
apply_id: str,
group_id: str,
memory_config: MemoryConfig, memory_config: MemoryConfig,
messages: list, messages: list,
ref_id: str = "wyl20251027", ref_id: str = "wyl20251027",
@@ -42,7 +40,7 @@ async def write(
Args: Args:
user_id: User identifier user_id: User identifier
apply_id: Application identifier apply_id: Application identifier
group_id: Group identifier end_user_id: Group identifier
memory_config: MemoryConfig object containing all configuration memory_config: MemoryConfig object containing all configuration
messages: Structured message list [{"role": "user", "content": "..."}, ...] messages: Structured message list [{"role": "user", "content": "..."}, ...]
ref_id: Reference ID, defaults to "wyl20251027" ref_id: Reference ID, defaults to "wyl20251027"
@@ -58,7 +56,7 @@ async def write(
logger.info(f"LLM model: {memory_config.llm_model_name}") logger.info(f"LLM model: {memory_config.llm_model_name}")
logger.info(f"Embedding model: {memory_config.embedding_model_name}") logger.info(f"Embedding model: {memory_config.embedding_model_name}")
logger.info(f"Chunker strategy: {chunker_strategy}") logger.info(f"Chunker strategy: {chunker_strategy}")
logger.info(f"Group ID: {group_id}") logger.info(f"end_user_id ID: {end_user_id}")
# Construct clients from memory_config using factory pattern with db session # Construct clients from memory_config using factory pattern with db session
with get_db_context() as db: with get_db_context() as db:
@@ -83,9 +81,7 @@ async def write(
step_start = time.time() step_start = time.time()
chunked_dialogs = await get_chunked_dialogs( chunked_dialogs = await get_chunked_dialogs(
chunker_strategy=chunker_strategy, chunker_strategy=chunker_strategy,
group_id=group_id, end_user_id=end_user_id,
user_id=user_id,
apply_id=apply_id,
messages=messages, messages=messages,
ref_id=ref_id, ref_id=ref_id,
config_id=config_id, config_id=config_id,

View File

@@ -16,13 +16,13 @@ class FilteredTags(BaseModel):
"""用于接收LLM筛选后的核心标签列表的模型。""" """用于接收LLM筛选后的核心标签列表的模型。"""
meaningful_tags: List[str] = Field(..., description="从原始列表中筛选出的具有核心代表意义的名词列表。") meaningful_tags: List[str] = Field(..., description="从原始列表中筛选出的具有核心代表意义的名词列表。")
async def filter_tags_with_llm(tags: List[str], group_id: str) -> List[str]: async def filter_tags_with_llm(tags: List[str], end_user_id: str) -> List[str]:
""" """
使用LLM筛选标签列表仅保留具有代表性的核心名词。 使用LLM筛选标签列表仅保留具有代表性的核心名词。
Args: Args:
tags: 原始标签列表 tags: 原始标签列表
group_id: 用户组ID用于获取配置 end_user_id: 用户组ID用于获取配置
Returns: Returns:
筛选后的标签列表 筛选后的标签列表
@@ -37,12 +37,12 @@ async def filter_tags_with_llm(tags: List[str], group_id: str) -> List[str]:
get_end_user_connected_config, get_end_user_connected_config,
) )
connected_config = get_end_user_connected_config(group_id, db) connected_config = get_end_user_connected_config(end_user_id, db)
config_id = connected_config.get("memory_config_id") config_id = connected_config.get("memory_config_id")
if not config_id: if not config_id:
raise ValueError( raise ValueError(
f"No memory_config_id found for group_id: {group_id}. " f"No memory_config_id found for end_user_id: {end_user_id}. "
"Please ensure the user has a valid memory configuration." "Please ensure the user has a valid memory configuration."
) )
@@ -87,7 +87,7 @@ async def filter_tags_with_llm(tags: List[str], group_id: str) -> List[str]:
async def get_raw_tags_from_db( async def get_raw_tags_from_db(
connector: Neo4jConnector, connector: Neo4jConnector,
group_id: str, end_user_id: str,
limit: int, limit: int,
by_user: bool = False by_user: bool = False
) -> List[Tuple[str, int]]: ) -> List[Tuple[str, int]]:
@@ -99,9 +99,9 @@ async def get_raw_tags_from_db(
Args: Args:
connector: Neo4j连接器实例 connector: Neo4j连接器实例
group_id: 如果by_user=False则为group_id如果by_user=True则为user_id end_user_id: 如果by_user=False则为end_user_id如果by_user=True则为user_id
limit: 返回的标签数量限制 limit: 返回的标签数量限制
by_user: 是否按user_id查询默认Falsegroup_id查询 by_user: 是否按user_id查询默认Falseend_user_id查询
Returns: Returns:
List[Tuple[str, int]]: 标签名称和频率的元组列表 List[Tuple[str, int]]: 标签名称和频率的元组列表
@@ -119,7 +119,7 @@ async def get_raw_tags_from_db(
else: else:
query = ( query = (
"MATCH (e:ExtractedEntity) " "MATCH (e:ExtractedEntity) "
"WHERE e.group_id = $id AND e.entity_type <> '人物' AND e.name IS NOT NULL AND NOT e.name IN $names_to_exclude " "WHERE e.end_user_id = $id AND e.entity_type <> '人物' AND e.name IS NOT NULL AND NOT e.name IN $names_to_exclude "
"RETURN e.name AS name, count(e) AS frequency " "RETURN e.name AS name, count(e) AS frequency "
"ORDER BY frequency DESC " "ORDER BY frequency DESC "
"LIMIT $limit" "LIMIT $limit"
@@ -128,44 +128,44 @@ async def get_raw_tags_from_db(
# 使用项目的Neo4jConnector执行查询 # 使用项目的Neo4jConnector执行查询
results = await connector.execute_query( results = await connector.execute_query(
query, query,
id=group_id, id=end_user_id,
limit=limit, limit=limit,
names_to_exclude=names_to_exclude names_to_exclude=names_to_exclude
) )
return [(record["name"], record["frequency"]) for record in results] return [(record["name"], record["frequency"]) for record in results]
async def get_hot_memory_tags(group_id: str, limit: int = 40, by_user: bool = False) -> List[Tuple[str, int]]: async def get_hot_memory_tags(end_user_id: str, limit: int = 40, by_user: bool = False) -> List[Tuple[str, int]]:
""" """
获取原始标签然后使用LLM进行筛选返回最终的热门标签列表。 获取原始标签然后使用LLM进行筛选返回最终的热门标签列表。
查询更多的标签(limit=40)给LLM提供更丰富的上下文进行筛选。 查询更多的标签(limit=40)给LLM提供更丰富的上下文进行筛选。
Args: Args:
group_id: 必需参数。如果by_user=False则为group_id如果by_user=True则为user_id end_user_id: 必需参数。如果by_user=False则为end_user_id如果by_user=True则为user_id
limit: 返回的标签数量限制 limit: 返回的标签数量限制
by_user: 是否按user_id查询默认Falsegroup_id查询 by_user: 是否按user_id查询默认Falseend_user_id查询
Raises: Raises:
ValueError: 如果group_id未提供或为空 ValueError: 如果end_user_id未提供或为空
""" """
# 验证group_id必须提供且不为空 # 验证end_user_id必须提供且不为空
if not group_id or not group_id.strip(): if not end_user_id or not end_user_id.strip():
raise ValueError( raise ValueError(
"group_id is required. Please provide a valid group_id or user_id." "end_user_id is required. Please provide a valid end_user_id or user_id."
) )
# 使用项目的Neo4jConnector # 使用项目的Neo4jConnector
connector = Neo4jConnector() connector = Neo4jConnector()
try: try:
# 1. 从数据库获取原始排名靠前的标签 # 1. 从数据库获取原始排名靠前的标签
raw_tags_with_freq = await get_raw_tags_from_db(connector, group_id, limit, by_user=by_user) raw_tags_with_freq = await get_raw_tags_from_db(connector, end_user_id, limit, by_user=by_user)
if not raw_tags_with_freq: if not raw_tags_with_freq:
return [] return []
raw_tag_names = [tag for tag, freq in raw_tags_with_freq] raw_tag_names = [tag for tag, freq in raw_tags_with_freq]
# 2. 初始化LLM客户端并使用LLM筛选出有意义的标签 # 2. 初始化LLM客户端并使用LLM筛选出有意义的标签
meaningful_tag_names = await filter_tags_with_llm(raw_tag_names, group_id) meaningful_tag_names = await filter_tags_with_llm(raw_tag_names, end_user_id)
# 3. 根据LLM的筛选结果构建最终的标签列表保留原始频率和顺序 # 3. 根据LLM的筛选结果构建最终的标签列表保留原始频率和顺序
final_tags = [] final_tags = []

View File

@@ -75,8 +75,8 @@ class MemoryDataSource:
start_date = time_range.start_date if time_range else None start_date = time_range.start_date if time_range else None
end_date = time_range.end_date if time_range else None end_date = time_range.end_date if time_range else None
summary_dicts = await self.memory_summary_repo.find_by_group_id( summary_dicts = await self.memory_summary_repo.find_by_end_user_id(
group_id=user_id, end_user_id=user_id,
limit=limit, limit=limit,
start_date=start_date, start_date=start_date,
end_date=end_date end_date=end_date

View File

@@ -41,7 +41,7 @@ DIALOGUE_EMBEDDING_SEARCH = """
WITH $embedding AS q WITH $embedding AS q
MATCH (d:Dialogue) MATCH (d:Dialogue)
WHERE d.dialog_embedding IS NOT NULL WHERE d.dialog_embedding IS NOT NULL
AND ($group_id IS NULL OR d.group_id = $group_id) AND ($end_user_id IS NULL OR d.end_user_id = $end_user_id)
WITH d, q, d.dialog_embedding AS v WITH d, q, d.dialog_embedding AS v
WITH d, WITH d,
reduce(dot = 0.0, i IN range(0, size(q)-1) | dot + toFloat(q[i]) * toFloat(v[i])) AS dot, reduce(dot = 0.0, i IN range(0, size(q)-1) | dot + toFloat(q[i]) * toFloat(v[i])) AS dot,
@@ -50,7 +50,7 @@ WITH d,
WITH d, CASE WHEN qnorm = 0 OR vnorm = 0 THEN 0.0 ELSE dot / (qnorm * vnorm) END AS score WITH d, CASE WHEN qnorm = 0 OR vnorm = 0 THEN 0.0 ELSE dot / (qnorm * vnorm) END AS score
WHERE score > $threshold WHERE score > $threshold
RETURN d.id AS dialog_id, RETURN d.id AS dialog_id,
d.group_id AS group_id, d.end_user_id AS end_user_id,
d.content AS content, d.content AS content,
d.created_at AS created_at, d.created_at AS created_at,
d.expired_at AS expired_at, d.expired_at AS expired_at,

View File

@@ -36,7 +36,7 @@ from app.repositories.neo4j.neo4j_connector import Neo4jConnector
async def ingest_contexts_via_full_pipeline( async def ingest_contexts_via_full_pipeline(
contexts: List[str], contexts: List[str],
group_id: str, end_user_id: str,
chunker_strategy: str | None = None, chunker_strategy: str | None = None,
embedding_name: str | None = None, embedding_name: str | None = None,
save_chunk_output: bool = False, save_chunk_output: bool = False,
@@ -48,7 +48,7 @@ async def ingest_contexts_via_full_pipeline(
This function mirrors the steps in main(), but starts from raw text contexts. This function mirrors the steps in main(), but starts from raw text contexts.
Args: Args:
contexts: List of dialogue texts, each containing lines like "role: message". contexts: List of dialogue texts, each containing lines like "role: message".
group_id: Group ID to assign to generated DialogData and graph nodes. end_user_id: Group ID to assign to generated DialogData and graph nodes.
chunker_strategy: Optional chunker strategy; defaults to SELECTED_CHUNKER_STRATEGY. chunker_strategy: Optional chunker strategy; defaults to SELECTED_CHUNKER_STRATEGY.
embedding_name: Optional embedding model ID; defaults to SELECTED_EMBEDDING_ID. embedding_name: Optional embedding model ID; defaults to SELECTED_EMBEDDING_ID.
save_chunk_output: If True, write chunked DialogData list to a JSON file for debugging. save_chunk_output: If True, write chunked DialogData list to a JSON file for debugging.
@@ -109,7 +109,7 @@ async def ingest_contexts_via_full_pipeline(
dialog = DialogData( dialog = DialogData(
context=context_model, context=context_model,
ref_id=f"pipeline_item_{idx}", ref_id=f"pipeline_item_{idx}",
group_id=group_id, end_user_id=end_user_id,
user_id="default_user", user_id="default_user",
apply_id="default_application", apply_id="default_application",
) )
@@ -318,16 +318,16 @@ async def handle_context_processing(args):
print("No contexts provided for processing.") print("No contexts provided for processing.")
return False return False
return await main_from_contexts(contexts, args.context_group_id) return await main_from_contexts(contexts, args.context_end_user_id)
async def main_from_contexts(contexts: List[str], group_id: str): async def main_from_contexts(contexts: List[str], end_user_id: str):
"""Run the pipeline from provided dialogue contexts instead of test data.""" """Run the pipeline from provided dialogue contexts instead of test data."""
print("=== Running pipeline from provided contexts ===") print("=== Running pipeline from provided contexts ===")
success = await ingest_contexts_via_full_pipeline( success = await ingest_contexts_via_full_pipeline(
contexts=contexts, contexts=contexts,
group_id=group_id, end_user_id=end_user_id,
chunker_strategy=SELECTED_CHUNKER_STRATEGY, chunker_strategy=SELECTED_CHUNKER_STRATEGY,
embedding_name=SELECTED_EMBEDDING_ID, embedding_name=SELECTED_EMBEDDING_ID,
save_chunk_output=True save_chunk_output=True

View File

@@ -47,7 +47,7 @@ from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
from app.core.memory.utils.definitions import ( from app.core.memory.utils.definitions import (
PROJECT_ROOT, PROJECT_ROOT,
SELECTED_EMBEDDING_ID, SELECTED_EMBEDDING_ID,
SELECTED_GROUP_ID, SELECTED_end_user_id,
SELECTED_LLM_ID, SELECTED_LLM_ID,
) )
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
@@ -59,7 +59,7 @@ from app.services.memory_config_service import MemoryConfigService
async def run_locomo_benchmark( async def run_locomo_benchmark(
sample_size: int = 20, sample_size: int = 20,
group_id: Optional[str] = None, end_user_id: Optional[str] = None,
search_type: str = "hybrid", search_type: str = "hybrid",
search_limit: int = 12, search_limit: int = 12,
context_char_budget: int = 8000, context_char_budget: int = 8000,
@@ -85,7 +85,7 @@ async def run_locomo_benchmark(
Args: Args:
sample_size: Number of QA pairs to evaluate (from first conversation) sample_size: Number of QA pairs to evaluate (from first conversation)
group_id: Database group ID for retrieval (uses default if None) end_user_id: Database group ID for retrieval (uses default if None)
search_type: "keyword", "embedding", or "hybrid" search_type: "keyword", "embedding", or "hybrid"
search_limit: Max documents to retrieve per query search_limit: Max documents to retrieve per query
context_char_budget: Max characters for context context_char_budget: Max characters for context
@@ -96,8 +96,8 @@ async def run_locomo_benchmark(
Returns: Returns:
Dictionary with evaluation results including metrics, timing, and samples Dictionary with evaluation results including metrics, timing, and samples
""" """
# Use default group_id if not provided # Use default end_user_id if not provided
group_id = group_id or SELECTED_GROUP_ID end_user_id = end_user_id or SELECTED_end_user_id
# Determine data path # Determine data path
data_path = os.path.join(PROJECT_ROOT, "data", "locomo10.json") data_path = os.path.join(PROJECT_ROOT, "data", "locomo10.json")
@@ -110,7 +110,7 @@ async def run_locomo_benchmark(
print(f"{'='*60}") print(f"{'='*60}")
print("📊 Configuration:") print("📊 Configuration:")
print(f" Sample size: {sample_size}") print(f" Sample size: {sample_size}")
print(f" Group ID: {group_id}") print(f" Group ID: {end_user_id}")
print(f" Search type: {search_type}") print(f" Search type: {search_type}")
print(f" Search limit: {search_limit}") print(f" Search limit: {search_limit}")
print(f" Context budget: {context_char_budget} chars") print(f" Context budget: {context_char_budget} chars")
@@ -134,7 +134,7 @@ async def run_locomo_benchmark(
# Step 2: Extract conversations and ingest if needed # Step 2: Extract conversations and ingest if needed
if skip_ingest: if skip_ingest:
print("⏭️ Skipping data ingestion (using existing data in Neo4j)") print("⏭️ Skipping data ingestion (using existing data in Neo4j)")
print(f" Group ID: {group_id}\n") print(f" Group ID: {end_user_id}\n")
else: else:
print("💾 Checking database ingestion...") print("💾 Checking database ingestion...")
try: try:
@@ -142,10 +142,10 @@ async def run_locomo_benchmark(
print(f"📝 Extracted {len(conversations)} conversations") print(f"📝 Extracted {len(conversations)} conversations")
# Always ingest for now (ingestion check not implemented) # Always ingest for now (ingestion check not implemented)
print(f"🔄 Ingesting conversations into group '{group_id}'...") print(f"🔄 Ingesting conversations into group '{end_user_id}'...")
success = await ingest_conversations_if_needed( success = await ingest_conversations_if_needed(
conversations=conversations, conversations=conversations,
group_id=group_id, end_user_id=end_user_id,
reset=reset_group reset=reset_group
) )
@@ -224,7 +224,7 @@ async def run_locomo_benchmark(
try: try:
retrieved_info = await retrieve_relevant_information( retrieved_info = await retrieve_relevant_information(
question=question, question=question,
group_id=group_id, end_user_id=end_user_id,
search_type=search_type, search_type=search_type,
search_limit=search_limit, search_limit=search_limit,
connector=connector, connector=connector,
@@ -409,7 +409,7 @@ async def run_locomo_benchmark(
"sample_size": len(qa_items), "sample_size": len(qa_items),
"timestamp": datetime.now().isoformat(), "timestamp": datetime.now().isoformat(),
"params": { "params": {
"group_id": group_id, "end_user_id": end_user_id,
"search_type": search_type, "search_type": search_type,
"search_limit": search_limit, "search_limit": search_limit,
"context_char_budget": context_char_budget, "context_char_budget": context_char_budget,
@@ -467,7 +467,7 @@ def main():
help="Number of QA pairs to evaluate" help="Number of QA pairs to evaluate"
) )
parser.add_argument( parser.add_argument(
"--group_id", "--end_user_id",
type=str, type=str,
default=None, default=None,
help="Database group ID for retrieval (uses default if not specified)" help="Database group ID for retrieval (uses default if not specified)"
@@ -516,7 +516,7 @@ def main():
# Run benchmark # Run benchmark
result = asyncio.run(run_locomo_benchmark( result = asyncio.run(run_locomo_benchmark(
sample_size=args.sample_size, sample_size=args.sample_size,
group_id=args.group_id, end_user_id=args.end_user_id,
search_type=args.search_type, search_type=args.search_type,
search_limit=args.search_limit, search_limit=args.search_limit,
context_char_budget=args.context_char_budget, context_char_budget=args.context_char_budget,

View File

@@ -556,7 +556,7 @@ async def run_enhanced_evaluation():
search_results = await run_hybrid_search( search_results = await run_hybrid_search(
query_text=q, query_text=q,
search_type="hybrid", search_type="hybrid",
group_id="locomo_sk", end_user_id="locomo_sk",
limit=20, limit=20,
include=["statements", "chunks", "entities", "summaries"], include=["statements", "chunks", "entities", "summaries"],
alpha=0.6, # BM25权重 alpha=0.6, # BM25权重

View File

@@ -348,7 +348,7 @@ def select_and_format_information(
async def retrieve_relevant_information( async def retrieve_relevant_information(
question: str, question: str,
group_id: str, end_user_id: str,
search_type: str, search_type: str,
search_limit: int, search_limit: int,
connector: Any, connector: Any,
@@ -368,7 +368,7 @@ async def retrieve_relevant_information(
Args: Args:
question: Question to search for question: Question to search for
group_id: Database group ID (identifies which conversation memory to search) end_user_id: Database group ID (identifies which conversation memory to search)
search_type: "keyword", "embedding", or "hybrid" search_type: "keyword", "embedding", or "hybrid"
search_limit: Max memory pieces to retrieve search_limit: Max memory pieces to retrieve
connector: Neo4j connector instance connector: Neo4j connector instance
@@ -396,7 +396,7 @@ async def retrieve_relevant_information(
connector=connector, connector=connector,
embedder_client=embedder, embedder_client=embedder,
query_text=question, query_text=question,
group_id=group_id, end_user_id=end_user_id,
limit=search_limit, limit=search_limit,
include=["chunks", "statements", "entities", "summaries"], include=["chunks", "statements", "entities", "summaries"],
) )
@@ -455,7 +455,7 @@ async def retrieve_relevant_information(
search_results = await search_graph( search_results = await search_graph(
connector=connector, connector=connector,
q=question, q=question,
group_id=group_id, end_user_id=end_user_id,
limit=search_limit limit=search_limit
) )
@@ -491,7 +491,7 @@ async def retrieve_relevant_information(
search_results = await run_hybrid_search( search_results = await run_hybrid_search(
query_text=question, query_text=question,
search_type=search_type, search_type=search_type,
group_id=group_id, end_user_id=end_user_id,
limit=search_limit, limit=search_limit,
include=["chunks", "statements", "entities", "summaries"], include=["chunks", "statements", "entities", "summaries"],
output_path=None, output_path=None,
@@ -524,7 +524,7 @@ async def retrieve_relevant_information(
connector=connector, connector=connector,
embedder_client=embedder, embedder_client=embedder,
query_text=question, query_text=question,
group_id=group_id, end_user_id=end_user_id,
limit=search_limit, limit=search_limit,
include=["chunks", "statements", "entities", "summaries"], include=["chunks", "statements", "entities", "summaries"],
) )
@@ -584,7 +584,7 @@ async def retrieve_relevant_information(
async def ingest_conversations_if_needed( async def ingest_conversations_if_needed(
conversations: List[str], conversations: List[str],
group_id: str, end_user_id: str,
reset: bool = False reset: bool = False
) -> bool: ) -> bool:
""" """
@@ -603,7 +603,7 @@ async def ingest_conversations_if_needed(
Args: Args:
conversations: List of raw conversation texts from LoCoMo dataset conversations: List of raw conversation texts from LoCoMo dataset
Example: ["User: I went to Paris. AI: When was that?", ...] Example: ["User: I went to Paris. AI: When was that?", ...]
group_id: Target group ID for database storage end_user_id: Target group ID for database storage
reset: Whether to clear existing data first (not implemented in wrapper) reset: Whether to clear existing data first (not implemented in wrapper)
Returns: Returns:
@@ -617,7 +617,7 @@ async def ingest_conversations_if_needed(
try: try:
success = await ingest_contexts_via_full_pipeline( success = await ingest_contexts_via_full_pipeline(
contexts=conversations, contexts=conversations,
group_id=group_id, end_user_id=end_user_id,
save_chunk_output=True save_chunk_output=True
) )
return success return success

View File

@@ -249,7 +249,7 @@ def get_search_params_by_category(category: str):
async def run_locomo_eval( async def run_locomo_eval(
sample_size: int = 1, sample_size: int = 1,
group_id: str | None = None, end_user_id: str | None = None,
search_limit: int = 8, search_limit: int = 8,
context_char_budget: int = 4000, # 保持默认值不变 context_char_budget: int = 4000, # 保持默认值不变
llm_temperature: float = 0.0, llm_temperature: float = 0.0,
@@ -262,7 +262,7 @@ async def run_locomo_eval(
) -> Dict[str, Any]: ) -> Dict[str, Any]:
# 函数内部使用三路检索逻辑,但保持参数签名不变 # 函数内部使用三路检索逻辑,但保持参数签名不变
group_id = group_id or SELECTED_GROUP_ID end_user_id = end_user_id or SELECTED_end_user_id
data_path = os.path.join(PROJECT_ROOT, "data", "locomo10.json") data_path = os.path.join(PROJECT_ROOT, "data", "locomo10.json")
if not os.path.exists(data_path): if not os.path.exists(data_path):
data_path = os.path.join(os.getcwd(), "data", "locomo10.json") data_path = os.path.join(os.getcwd(), "data", "locomo10.json")
@@ -340,7 +340,7 @@ async def run_locomo_eval(
# 关键修复:强制重新摄入纯净的对话数据 # 关键修复:强制重新摄入纯净的对话数据
print("🔄 强制重新摄入纯净的对话数据...") print("🔄 强制重新摄入纯净的对话数据...")
await ingest_contexts_via_full_pipeline(contents, group_id, save_chunk_output=True) await ingest_contexts_via_full_pipeline(contents, end_user_id, save_chunk_output=True)
# 使用异步LLM客户端 # 使用异步LLM客户端
with get_db_context() as db: with get_db_context() as db:
@@ -405,7 +405,7 @@ async def run_locomo_eval(
connector=connector, connector=connector,
embedder_client=embedder, embedder_client=embedder,
query_text=q, query_text=q,
group_id=group_id, end_user_id=end_user_id,
limit=adjusted_limit, limit=adjusted_limit,
include=["chunks", "statements", "entities", "summaries"], # 修复:使用正确的类型 include=["chunks", "statements", "entities", "summaries"], # 修复:使用正确的类型
) )
@@ -456,7 +456,7 @@ async def run_locomo_eval(
search_results = await search_graph( search_results = await search_graph(
connector=connector, connector=connector,
q=q, q=q,
group_id=group_id, end_user_id=end_user_id,
limit=adjusted_limit limit=adjusted_limit
) )
dialogs = search_results.get("dialogues", []) dialogs = search_results.get("dialogues", [])
@@ -486,7 +486,7 @@ async def run_locomo_eval(
search_results = await run_hybrid_search( search_results = await run_hybrid_search(
query_text=q, query_text=q,
search_type=search_type, search_type=search_type,
group_id=group_id, end_user_id=end_user_id,
limit=adjusted_limit, limit=adjusted_limit,
include=["chunks", "statements", "entities", "summaries"], include=["chunks", "statements", "entities", "summaries"],
output_path=None, output_path=None,
@@ -524,7 +524,7 @@ async def run_locomo_eval(
connector=connector, connector=connector,
embedder_client=embedder, embedder_client=embedder,
query_text=q, query_text=q,
group_id=group_id, end_user_id=end_user_id,
limit=adjusted_limit, limit=adjusted_limit,
include=["chunks", "statements", "entities", "summaries"], include=["chunks", "statements", "entities", "summaries"],
) )
@@ -597,7 +597,7 @@ async def run_locomo_eval(
"dialogues": [ "dialogues": [
{ {
"uuid": d.get("uuid", ""), "uuid": d.get("uuid", ""),
"group_id": d.get("group_id", ""), "end_user_id": d.get("end_user_id", ""),
"content": d.get("content", "")[:200] + "..." if len(d.get("content", "")) > 200 else d.get("content", ""), "content": d.get("content", "")[:200] + "..." if len(d.get("content", "")) > 200 else d.get("content", ""),
"score": d.get("score", 0.0) "score": d.get("score", 0.0)
} }
@@ -795,7 +795,7 @@ async def run_locomo_eval(
}, },
"samples": samples, "samples": samples,
"params": { "params": {
"group_id": group_id, "end_user_id": end_user_id,
"search_limit": search_limit, "search_limit": search_limit,
"context_char_budget": context_char_budget, "context_char_budget": context_char_budget,
"search_type": search_type, "search_type": search_type,
@@ -825,7 +825,7 @@ async def run_locomo_eval(
def main(): def main():
parser = argparse.ArgumentParser(description="Run LoCoMo evaluation with Qwen search") parser = argparse.ArgumentParser(description="Run LoCoMo evaluation with Qwen search")
parser.add_argument("--sample_size", type=int, default=1, help="Number of samples to evaluate") parser.add_argument("--sample_size", type=int, default=1, help="Number of samples to evaluate")
parser.add_argument("--group_id", type=str, default=None, help="Group ID for retrieval") parser.add_argument("--end_user_id", type=str, default=None, help="Group ID for retrieval")
parser.add_argument("--search_limit", type=int, default=8, help="Search limit per query") parser.add_argument("--search_limit", type=int, default=8, help="Search limit per query")
parser.add_argument("--context_char_budget", type=int, default=12000, help="Max characters for context") parser.add_argument("--context_char_budget", type=int, default=12000, help="Max characters for context")
parser.add_argument("--llm_temperature", type=float, default=0.0, help="LLM temperature") parser.add_argument("--llm_temperature", type=float, default=0.0, help="LLM temperature")
@@ -841,7 +841,7 @@ def main():
result = asyncio.run(run_locomo_eval( result = asyncio.run(run_locomo_eval(
sample_size=args.sample_size, sample_size=args.sample_size,
group_id=args.group_id, end_user_id=args.end_user_id,
search_limit=args.search_limit, search_limit=args.search_limit,
context_char_budget=args.context_char_budget, context_char_budget=args.context_char_budget,
llm_temperature=args.llm_temperature, llm_temperature=args.llm_temperature,

View File

@@ -524,11 +524,11 @@ def generate_query_keywords_cn(question: str) -> List[str]:
# 通过别名匹配进行实体关键词检索多token合并 # 通过别名匹配进行实体关键词检索多token合并
async def _search_entities_by_aliases(connector: Neo4jConnector, tokens: List[str], group_id: str | None, limit: int) -> List[Dict[str, Any]]: async def _search_entities_by_aliases(connector: Neo4jConnector, tokens: List[str], end_user_id: str | None, limit: int) -> List[Dict[str, Any]]:
results: List[Dict[str, Any]] = [] results: List[Dict[str, Any]] = []
try: try:
for tok in tokens: for tok in tokens:
rows = await connector.execute_query(SEARCH_ENTITIES_BY_NAME, q=tok, group_id=group_id, limit=limit) rows = await connector.execute_query(SEARCH_ENTITIES_BY_NAME, q=tok, end_user_id=end_user_id, limit=limit)
if rows: if rows:
results.extend(rows) results.extend(rows)
except Exception: except Exception:
@@ -548,15 +548,15 @@ async def _search_entities_by_aliases(connector: Neo4jConnector, tokens: List[st
# 通过对话/陈述中的entity_ids反查实体名称 # 通过对话/陈述中的entity_ids反查实体名称
_FETCH_ENTITIES_BY_IDS = """ _FETCH_ENTITIES_BY_IDS = """
MATCH (e:ExtractedEntity) MATCH (e:ExtractedEntity)
WHERE e.id IN $ids AND ($group_id IS NULL OR e.group_id = $group_id) WHERE e.id IN $ids AND ($end_user_id IS NULL OR e.end_user_id = $end_user_id)
RETURN e.id AS id, e.name AS name, e.group_id AS group_id, e.entity_type AS entity_type RETURN e.id AS id, e.name AS name, e.end_user_id AS end_user_id, e.entity_type AS entity_type
""" """
async def _fetch_entities_by_ids(connector: Neo4jConnector, ids: List[str], group_id: str | None) -> List[Dict[str, Any]]: async def _fetch_entities_by_ids(connector: Neo4jConnector, ids: List[str], end_user_id: str | None) -> List[Dict[str, Any]]:
if not ids: if not ids:
return [] return []
try: try:
rows = await connector.execute_query(_FETCH_ENTITIES_BY_IDS, ids=list({i for i in ids if i}), group_id=group_id) rows = await connector.execute_query(_FETCH_ENTITIES_BY_IDS, ids=list({i for i in ids if i}), end_user_id=end_user_id)
return rows or [] return rows or []
except Exception: except Exception:
return [] return []
@@ -566,18 +566,18 @@ async def _fetch_entities_by_ids(connector: Neo4jConnector, ids: List[str], grou
_TIME_ENTITY_SEARCH = """ _TIME_ENTITY_SEARCH = """
MATCH (e:ExtractedEntity) MATCH (e:ExtractedEntity)
WHERE e.entity_type CONTAINS "TIME" OR e.entity_type CONTAINS "DATE" OR e.name =~ $date_pattern WHERE e.entity_type CONTAINS "TIME" OR e.entity_type CONTAINS "DATE" OR e.name =~ $date_pattern
AND ($group_id IS NULL OR e.group_id = $group_id) AND ($end_user_id IS NULL OR e.end_user_id = $end_user_id)
RETURN e.id AS id, e.name AS name, e.group_id AS group_id, e.entity_type AS entity_type RETURN e.id AS id, e.name AS name, e.end_user_id AS end_user_id, e.entity_type AS entity_type
LIMIT $limit LIMIT $limit
""" """
async def _search_time_entities(connector: Neo4jConnector, group_id: str | None, limit: int = 5) -> List[Dict[str, Any]]: async def _search_time_entities(connector: Neo4jConnector, end_user_id: str | None, limit: int = 5) -> List[Dict[str, Any]]:
"""专门搜索时间相关的实体""" """专门搜索时间相关的实体"""
try: try:
date_pattern = r".*\d{4}.*|.*\d{1,2}月\d{1,2}日.*" date_pattern = r".*\d{4}.*|.*\d{1,2}月\d{1,2}日.*"
rows = await connector.execute_query(_TIME_ENTITY_SEARCH, rows = await connector.execute_query(_TIME_ENTITY_SEARCH,
date_pattern=date_pattern, date_pattern=date_pattern,
group_id=group_id, end_user_id=end_user_id,
limit=limit) limit=limit)
return rows or [] return rows or []
except Exception: except Exception:
@@ -624,7 +624,7 @@ def _resolve_relative_times_cn_en(text: str, anchor: datetime) -> str:
async def run_longmemeval_test( async def run_longmemeval_test(
sample_size: int = 3, sample_size: int = 3,
group_id: str = "longmemeval_zh_bak_3", end_user_id: str = "longmemeval_zh_bak_3",
search_limit: int = 8, search_limit: int = 8,
context_char_budget: int = 4000, context_char_budget: int = 4000,
llm_temperature: float = 0.0, llm_temperature: float = 0.0,
@@ -678,13 +678,13 @@ async def run_longmemeval_test(
contexts.extend(selected) contexts.extend(selected)
print(f"📥 摄入 {len(contexts)} 个上下文到数据库") print(f"📥 摄入 {len(contexts)} 个上下文到数据库")
if reset_group_before_ingest and group_id: if reset_group_before_ingest and end_user_id:
try: try:
_tmp_conn = Neo4jConnector() _tmp_conn = Neo4jConnector()
await _tmp_conn.delete_group(group_id) await _tmp_conn.delete_group(end_user_id)
print(f"🧹 已清空组 {group_id} 的历史图数据") print(f"🧹 已清空组 {end_user_id} 的历史图数据")
except Exception as _e: except Exception as _e:
print(f"⚠️ 清空组数据失败(忽略继续): {group_id} - {_e}") print(f"⚠️ 清空组数据失败(忽略继续): {end_user_id} - {_e}")
finally: finally:
try: try:
await _tmp_conn.close() await _tmp_conn.close()
@@ -696,7 +696,7 @@ async def run_longmemeval_test(
else: else:
await _ingest_fn( await _ingest_fn(
contexts, contexts,
group_id, end_user_id,
save_chunk_output=save_chunk_output, save_chunk_output=save_chunk_output,
save_chunk_output_path=save_chunk_output_path, save_chunk_output_path=save_chunk_output_path,
) )
@@ -751,7 +751,7 @@ async def run_longmemeval_test(
connector=connector, connector=connector,
embedder_client=embedder, embedder_client=embedder,
query_text=question, query_text=question,
group_id=group_id, end_user_id=end_user_id,
limit=search_limit, limit=search_limit,
include=["chunks", "statements", "entities", "summaries"], include=["chunks", "statements", "entities", "summaries"],
) )
@@ -796,7 +796,7 @@ async def run_longmemeval_test(
search_results = await search_graph( search_results = await search_graph(
connector=connector, connector=connector,
q=question, q=question,
group_id=group_id, end_user_id=end_user_id,
limit=search_limit, limit=search_limit,
) )
chunks = search_results.get("chunks", []) chunks = search_results.get("chunks", [])
@@ -831,7 +831,7 @@ async def run_longmemeval_test(
connector=connector, connector=connector,
embedder_client=embedder, embedder_client=embedder,
query_text=question, query_text=question,
group_id=group_id, end_user_id=end_user_id,
limit=search_limit, limit=search_limit,
include=["chunks", "statements", "entities", "summaries"], include=["chunks", "statements", "entities", "summaries"],
) )
@@ -849,7 +849,7 @@ async def run_longmemeval_test(
kw_res = await search_graph( kw_res = await search_graph(
connector=connector, connector=connector,
q=question, q=question,
group_id=group_id, end_user_id=end_user_id,
limit=search_limit, limit=search_limit,
) )
if isinstance(kw_res, dict): if isinstance(kw_res, dict):
@@ -860,7 +860,7 @@ async def run_longmemeval_test(
# 时间推理问题的特殊处理 # 时间推理问题的特殊处理
if is_temporal: if is_temporal:
# 专门搜索时间实体 # 专门搜索时间实体
time_entities = await _search_time_entities(connector, group_id, search_limit//2) time_entities = await _search_time_entities(connector, end_user_id, search_limit//2)
if time_entities: if time_entities:
kw_entities.extend(time_entities) kw_entities.extend(time_entities)
# 添加时间相关关键词检索 # 添加时间相关关键词检索
@@ -870,7 +870,7 @@ async def run_longmemeval_test(
time_res = await search_graph( time_res = await search_graph(
connector=connector, connector=connector,
q=tk, q=tk,
group_id=group_id, end_user_id=end_user_id,
limit=2, limit=2,
) )
if isinstance(time_res, dict): if isinstance(time_res, dict):
@@ -881,7 +881,7 @@ async def run_longmemeval_test(
# 中文关键词拆分后做别名匹配 # 中文关键词拆分后做别名匹配
cn_tokens = _extract_cn_tokens(question) cn_tokens = _extract_cn_tokens(question)
alias_entities = await _search_entities_by_aliases(connector, cn_tokens, group_id, search_limit) alias_entities = await _search_entities_by_aliases(connector, cn_tokens, end_user_id, search_limit)
if alias_entities: if alias_entities:
kw_entities.extend(alias_entities) kw_entities.extend(alias_entities)
@@ -895,7 +895,7 @@ async def run_longmemeval_test(
except Exception: except Exception:
pass pass
if ids: if ids:
id_entities = await _fetch_entities_by_ids(connector, ids, group_id) id_entities = await _fetch_entities_by_ids(connector, ids, end_user_id)
if id_entities: if id_entities:
kw_entities.extend(id_entities) kw_entities.extend(id_entities)
@@ -909,7 +909,7 @@ async def run_longmemeval_test(
sub_res = await search_graph( sub_res = await search_graph(
connector=connector, connector=connector,
q=str(kw), q=str(kw),
group_id=group_id, end_user_id=end_user_id,
limit=max(3, search_limit // 2), limit=max(3, search_limit // 2),
) )
if isinstance(sub_res, dict): if isinstance(sub_res, dict):
@@ -928,7 +928,7 @@ async def run_longmemeval_test(
opt_res = await search_graph( opt_res = await search_graph(
connector=connector, connector=connector,
q=str(opt), q=str(opt),
group_id=group_id, end_user_id=end_user_id,
limit=max(3, search_limit // 2), limit=max(3, search_limit // 2),
) )
if isinstance(opt_res, dict): if isinstance(opt_res, dict):
@@ -1010,7 +1010,7 @@ async def run_longmemeval_test(
kw_fallback = await search_graph( kw_fallback = await search_graph(
connector=connector, connector=connector,
q=question, q=question,
group_id=group_id, end_user_id=end_user_id,
limit=max(search_limit, 5), limit=max(search_limit, 5),
) )
fb_dialogs = kw_fallback.get("dialogues", []) or [] fb_dialogs = kw_fallback.get("dialogues", []) or []
@@ -1224,7 +1224,7 @@ async def run_longmemeval_test(
"count_avg": statistics.mean(per_query_context_counts) if per_query_context_counts else 0.0, "count_avg": statistics.mean(per_query_context_counts) if per_query_context_counts else 0.0,
}, },
"params": { "params": {
"group_id": group_id, "end_user_id": end_user_id,
"search_limit": search_limit, "search_limit": search_limit,
"context_char_budget": context_char_budget, "context_char_budget": context_char_budget,
"search_type": search_type, "search_type": search_type,
@@ -1307,7 +1307,7 @@ def main():
result = asyncio.run( result = asyncio.run(
run_longmemeval_test( run_longmemeval_test(
sample_size=sample_size, sample_size=sample_size,
group_id=args.group_id, end_user_id=args.end_user_id,
search_limit=args.search_limit, search_limit=args.search_limit,
context_char_budget=args.context_char_budget, context_char_budget=args.context_char_budget,
llm_temperature=args.llm_temperature, llm_temperature=args.llm_temperature,

View File

@@ -498,11 +498,11 @@ def smart_context_selection(contexts: List[str], question: str, max_chars: int =
# 通过别名匹配进行实体关键词检索多token合并 # 通过别名匹配进行实体关键词检索多token合并
async def _search_entities_by_aliases(connector: Neo4jConnector, tokens: List[str], group_id: str | None, limit: int) -> List[Dict[str, Any]]: async def _search_entities_by_aliases(connector: Neo4jConnector, tokens: List[str], end_user_id: str | None, limit: int) -> List[Dict[str, Any]]:
results: List[Dict[str, Any]] = [] results: List[Dict[str, Any]] = []
try: try:
for tok in tokens: for tok in tokens:
rows = await connector.execute_query(SEARCH_ENTITIES_BY_NAME, q=tok, group_id=group_id, limit=limit) rows = await connector.execute_query(SEARCH_ENTITIES_BY_NAME, q=tok, end_user_id=end_user_id, limit=limit)
if rows: if rows:
results.extend(rows) results.extend(rows)
except Exception: except Exception:
@@ -522,15 +522,15 @@ async def _search_entities_by_aliases(connector: Neo4jConnector, tokens: List[st
# 通过对话/陈述中的entity_ids反查实体名称 # 通过对话/陈述中的entity_ids反查实体名称
_FETCH_ENTITIES_BY_IDS = """ _FETCH_ENTITIES_BY_IDS = """
MATCH (e:ExtractedEntity) MATCH (e:ExtractedEntity)
WHERE e.id IN $ids AND ($group_id IS NULL OR e.group_id = $group_id) WHERE e.id IN $ids AND ($end_user_id IS NULL OR e.end_user_id = $end_user_id)
RETURN e.id AS id, e.name AS name, e.group_id AS group_id, e.entity_type AS entity_type RETURN e.id AS id, e.name AS name, e.end_user_id AS end_user_id, e.entity_type AS entity_type
""" """
async def _fetch_entities_by_ids(connector: Neo4jConnector, ids: List[str], group_id: str | None) -> List[Dict[str, Any]]: async def _fetch_entities_by_ids(connector: Neo4jConnector, ids: List[str], end_user_id: str | None) -> List[Dict[str, Any]]:
if not ids: if not ids:
return [] return []
try: try:
rows = await connector.execute_query(_FETCH_ENTITIES_BY_IDS, ids=list({i for i in ids if i}), group_id=group_id) rows = await connector.execute_query(_FETCH_ENTITIES_BY_IDS, ids=list({i for i in ids if i}), end_user_id=end_user_id)
return rows or [] return rows or []
except Exception: except Exception:
return [] return []
@@ -540,18 +540,18 @@ async def _fetch_entities_by_ids(connector: Neo4jConnector, ids: List[str], grou
_TIME_ENTITY_SEARCH = """ _TIME_ENTITY_SEARCH = """
MATCH (e:ExtractedEntity) MATCH (e:ExtractedEntity)
WHERE e.entity_type CONTAINS "TIME" OR e.entity_type CONTAINS "DATE" OR e.name =~ $date_pattern WHERE e.entity_type CONTAINS "TIME" OR e.entity_type CONTAINS "DATE" OR e.name =~ $date_pattern
AND ($group_id IS NULL OR e.group_id = $group_id) AND ($end_user_id IS NULL OR e.end_user_id = $end_user_id)
RETURN e.id AS id, e.name AS name, e.group_id AS group_id, e.entity_type AS entity_type RETURN e.id AS id, e.name AS name, e.end_user_id AS end_user_id, e.entity_type AS entity_type
LIMIT $limit LIMIT $limit
""" """
async def _search_time_entities(connector: Neo4jConnector, group_id: str | None, limit: int = 5) -> List[Dict[str, Any]]: async def _search_time_entities(connector: Neo4jConnector, end_user_id: str | None, limit: int = 5) -> List[Dict[str, Any]]:
"""专门搜索时间相关的实体""" """专门搜索时间相关的实体"""
try: try:
date_pattern = r".*\d{4}.*|.*\d{1,2}月\d{1,2}日.*" date_pattern = r".*\d{4}.*|.*\d{1,2}月\d{1,2}日.*"
rows = await connector.execute_query(_TIME_ENTITY_SEARCH, rows = await connector.execute_query(_TIME_ENTITY_SEARCH,
date_pattern=date_pattern, date_pattern=date_pattern,
group_id=group_id, end_user_id=end_user_id,
limit=limit) limit=limit)
return rows or [] return rows or []
except Exception: except Exception:
@@ -559,25 +559,25 @@ async def _search_time_entities(connector: Neo4jConnector, group_id: str | None,
# 技术术语专门检索 # 技术术语专门检索
async def _search_tech_terms(connector: Neo4jConnector, question: str, group_id: str | None, limit: int = 3) -> List[Dict[str, Any]]: async def _search_tech_terms(connector: Neo4jConnector, question: str, end_user_id: str | None, limit: int = 3) -> List[Dict[str, Any]]:
"""专门搜索技术术语相关的实体""" """专门搜索技术术语相关的实体"""
tech_entities = [] tech_entities = []
try: try:
# GPS相关 # GPS相关
if any(term in question for term in ["GPS", "导航", "定位系统"]): if any(term in question for term in ["GPS", "导航", "定位系统"]):
gps_rows = await connector.execute_query(SEARCH_ENTITIES_BY_NAME, q="GPS", group_id=group_id, limit=limit) gps_rows = await connector.execute_query(SEARCH_ENTITIES_BY_NAME, q="GPS", end_user_id=end_user_id, limit=limit)
if gps_rows: if gps_rows:
tech_entities.extend(gps_rows) tech_entities.extend(gps_rows)
# 活动相关 # 活动相关
if any(term in question for term in ["工作坊", "研讨会", "网络研讨会"]): if any(term in question for term in ["工作坊", "研讨会", "网络研讨会"]):
workshop_rows = await connector.execute_query(SEARCH_ENTITIES_BY_NAME, q="工作坊", group_id=group_id, limit=limit) workshop_rows = await connector.execute_query(SEARCH_ENTITIES_BY_NAME, q="工作坊", end_user_id=end_user_id, limit=limit)
if workshop_rows: if workshop_rows:
tech_entities.extend(workshop_rows) tech_entities.extend(workshop_rows)
# 时间顺序相关 # 时间顺序相关
if any(term in question for term in ["", "", "第一个"]): if any(term in question for term in ["", "", "第一个"]):
time_rows = await connector.execute_query(SEARCH_ENTITIES_BY_NAME, q="第一次", group_id=group_id, limit=limit) time_rows = await connector.execute_query(SEARCH_ENTITIES_BY_NAME, q="第一次", end_user_id=end_user_id, limit=limit)
if time_rows: if time_rows:
tech_entities.extend(time_rows) tech_entities.extend(time_rows)
@@ -627,7 +627,7 @@ def _resolve_relative_times_cn_en(text: str, anchor: datetime) -> str:
async def run_longmemeval_test( async def run_longmemeval_test(
sample_size: int = 3, sample_size: int = 3,
group_id: str = "longmemeval_zh_bak_2", end_user_id: str = "longmemeval_zh_bak_2",
search_limit: int = 8, search_limit: int = 8,
context_char_budget: int = 4000, context_char_budget: int = 4000,
llm_temperature: float = 0.0, llm_temperature: float = 0.0,
@@ -707,7 +707,7 @@ async def run_longmemeval_test(
connector=connector, connector=connector,
embedder_client=embedder, embedder_client=embedder,
query_text=question, query_text=question,
group_id=group_id, end_user_id=end_user_id,
limit=search_limit, limit=search_limit,
include=["dialogues", "statements", "entities"], include=["dialogues", "statements", "entities"],
) )
@@ -746,7 +746,7 @@ async def run_longmemeval_test(
search_results = await search_graph( search_results = await search_graph(
connector=connector, connector=connector,
q=question, q=question,
group_id=group_id, end_user_id=end_user_id,
limit=search_limit, limit=search_limit,
) )
dialogs = search_results.get("dialogues", []) dialogs = search_results.get("dialogues", [])
@@ -776,7 +776,7 @@ async def run_longmemeval_test(
connector=connector, connector=connector,
embedder_client=embedder, embedder_client=embedder,
query_text=question, query_text=question,
group_id=group_id, end_user_id=end_user_id,
limit=search_limit, limit=search_limit,
include=["dialogues", "statements", "entities"], include=["dialogues", "statements", "entities"],
) )
@@ -792,7 +792,7 @@ async def run_longmemeval_test(
kw_res = await search_graph( kw_res = await search_graph(
connector=connector, connector=connector,
q=question, q=question,
group_id=group_id, end_user_id=end_user_id,
limit=search_limit, limit=search_limit,
) )
if isinstance(kw_res, dict): if isinstance(kw_res, dict):
@@ -801,14 +801,14 @@ async def run_longmemeval_test(
kw_entities = kw_res.get("entities", []) or [] kw_entities = kw_res.get("entities", []) or []
# 技术术语专门检索 # 技术术语专门检索
tech_entities = await _search_tech_terms(connector, question, group_id, search_limit//2) tech_entities = await _search_tech_terms(connector, question, end_user_id, search_limit//2)
if tech_entities: if tech_entities:
kw_entities.extend(tech_entities) kw_entities.extend(tech_entities)
# 时间推理问题的特殊处理 # 时间推理问题的特殊处理
if is_temporal: if is_temporal:
# 专门搜索时间实体 # 专门搜索时间实体
time_entities = await _search_time_entities(connector, group_id, search_limit//2) time_entities = await _search_time_entities(connector, end_user_id, search_limit//2)
if time_entities: if time_entities:
kw_entities.extend(time_entities) kw_entities.extend(time_entities)
# 添加时间相关关键词检索 # 添加时间相关关键词检索
@@ -818,7 +818,7 @@ async def run_longmemeval_test(
time_res = await search_graph( time_res = await search_graph(
connector=connector, connector=connector,
q=tk, q=tk,
group_id=group_id, end_user_id=end_user_id,
limit=2, limit=2,
) )
if isinstance(time_res, dict): if isinstance(time_res, dict):
@@ -829,7 +829,7 @@ async def run_longmemeval_test(
# 中文关键词拆分后做别名匹配 # 中文关键词拆分后做别名匹配
cn_tokens = generate_query_keywords_cn(question) # 使用增强版关键词提取 cn_tokens = generate_query_keywords_cn(question) # 使用增强版关键词提取
alias_entities = await _search_entities_by_aliases(connector, cn_tokens, group_id, search_limit) alias_entities = await _search_entities_by_aliases(connector, cn_tokens, end_user_id, search_limit)
if alias_entities: if alias_entities:
kw_entities.extend(alias_entities) kw_entities.extend(alias_entities)
@@ -843,7 +843,7 @@ async def run_longmemeval_test(
except Exception: except Exception:
pass pass
if ids: if ids:
id_entities = await _fetch_entities_by_ids(connector, ids, group_id) id_entities = await _fetch_entities_by_ids(connector, ids, end_user_id)
if id_entities: if id_entities:
kw_entities.extend(id_entities) kw_entities.extend(id_entities)
@@ -857,7 +857,7 @@ async def run_longmemeval_test(
sub_res = await search_graph( sub_res = await search_graph(
connector=connector, connector=connector,
q=str(kw), q=str(kw),
group_id=group_id, end_user_id=end_user_id,
limit=max(3, search_limit // 2), limit=max(3, search_limit // 2),
) )
if isinstance(sub_res, dict): if isinstance(sub_res, dict):
@@ -876,7 +876,7 @@ async def run_longmemeval_test(
opt_res = await search_graph( opt_res = await search_graph(
connector=connector, connector=connector,
q=str(opt), q=str(opt),
group_id=group_id, end_user_id=end_user_id,
limit=max(3, search_limit // 2), limit=max(3, search_limit // 2),
) )
if isinstance(opt_res, dict): if isinstance(opt_res, dict):
@@ -971,7 +971,7 @@ async def run_longmemeval_test(
kw_fallback = await search_graph( kw_fallback = await search_graph(
connector=connector, connector=connector,
q=question, q=question,
group_id=group_id, end_user_id=end_user_id,
limit=max(search_limit, 5), limit=max(search_limit, 5),
) )
fb_dialogs = kw_fallback.get("dialogues", []) or [] fb_dialogs = kw_fallback.get("dialogues", []) or []
@@ -1199,7 +1199,7 @@ async def run_longmemeval_test(
"count_avg": statistics.mean(per_query_context_counts) if per_query_context_counts else 0.0, "count_avg": statistics.mean(per_query_context_counts) if per_query_context_counts else 0.0,
}, },
"params": { "params": {
"group_id": group_id, "end_user_id": end_user_id,
"search_limit": search_limit, "search_limit": search_limit,
"context_char_budget": context_char_budget, "context_char_budget": context_char_budget,
"search_type": search_type, "search_type": search_type,
@@ -1278,7 +1278,7 @@ def main():
result = asyncio.run( result = asyncio.run(
run_longmemeval_test( run_longmemeval_test(
sample_size=sample_size, sample_size=sample_size,
group_id=args.group_id, end_user_id=args.end_user_id,
search_limit=args.search_limit, search_limit=args.search_limit,
context_char_budget=args.context_char_budget, context_char_budget=args.context_char_budget,
llm_temperature=args.llm_temperature, llm_temperature=args.llm_temperature,

View File

@@ -135,8 +135,8 @@ def _combine_dialogues_for_hybrid(results: Dict[str, Any]) -> List[Dict[str, Any
return merged return merged
async def run_memsciqa_eval(sample_size: int = 1, group_id: str | None = None, search_limit: int = 8, context_char_budget: int = 4000, llm_temperature: float = 0.0, llm_max_tokens: int = 64, search_type: str = "hybrid", memory_config: "MemoryConfig" = None) -> Dict[str, Any]: async def run_memsciqa_eval(sample_size: int = 1, end_user_id: str | None = None, search_limit: int = 8, context_char_budget: int = 4000, llm_temperature: float = 0.0, llm_max_tokens: int = 64, search_type: str = "hybrid", memory_config: "MemoryConfig" = None) -> Dict[str, Any]:
group_id = group_id or SELECTED_GROUP_ID end_user_id = end_user_id or SELECTED_GROUP_ID
# Load data # Load data
data_path = os.path.join(PROJECT_ROOT, "data", "msc_self_instruct.jsonl") data_path = os.path.join(PROJECT_ROOT, "data", "msc_self_instruct.jsonl")
if not os.path.exists(data_path): if not os.path.exists(data_path):
@@ -147,7 +147,7 @@ async def run_memsciqa_eval(sample_size: int = 1, group_id: str | None = None, s
# 改为:每条样本仅摄入一个上下文(完整对话转录),避免多上下文摄入 # 改为:每条样本仅摄入一个上下文(完整对话转录),避免多上下文摄入
# 说明memsciqa 数据集的每个样本天然只有一个对话,保持按样本一上下文的策略 # 说明memsciqa 数据集的每个样本天然只有一个对话,保持按样本一上下文的策略
contexts: List[str] = [build_context_from_dialog(item) for item in items] contexts: List[str] = [build_context_from_dialog(item) for item in items]
await ingest_contexts_via_full_pipeline(contexts, group_id) await ingest_contexts_via_full_pipeline(contexts, end_user_id)
# LLM client (使用异步调用) # LLM client (使用异步调用)
with get_db_context() as db: with get_db_context() as db:
@@ -173,7 +173,7 @@ async def run_memsciqa_eval(sample_size: int = 1, group_id: str | None = None, s
results = await run_hybrid_search( results = await run_hybrid_search(
query_text=question, query_text=question,
search_type=search_type, search_type=search_type,
group_id=group_id, end_user_id=end_user_id,
limit=search_limit, limit=search_limit,
include=["dialogues", "statements", "entities"], include=["dialogues", "statements", "entities"],
output_path=None, output_path=None,
@@ -298,7 +298,7 @@ def main():
load_dotenv() load_dotenv()
parser = argparse.ArgumentParser(description="Evaluate DMR (memsciqa) with graph search and Qwen") parser = argparse.ArgumentParser(description="Evaluate DMR (memsciqa) with graph search and Qwen")
parser.add_argument("--sample-size", type=int, default=1, help="评测样本数量") parser.add_argument("--sample-size", type=int, default=1, help="评测样本数量")
parser.add_argument("--group-id", type=str, default=None, help="可选 group_id默认取 runtime.json") parser.add_argument("--group-id", type=str, default=None, help="可选 end_user_id默认取 runtime.json")
parser.add_argument("--search-limit", type=int, default=8, help="每类检索最大返回数") parser.add_argument("--search-limit", type=int, default=8, help="每类检索最大返回数")
parser.add_argument("--context-char-budget", type=int, default=4000, help="上下文字符预算") parser.add_argument("--context-char-budget", type=int, default=4000, help="上下文字符预算")
parser.add_argument("--llm-temperature", type=float, default=0.0, help="LLM 温度") parser.add_argument("--llm-temperature", type=float, default=0.0, help="LLM 温度")
@@ -309,7 +309,7 @@ def main():
result = asyncio.run( result = asyncio.run(
run_memsciqa_eval( run_memsciqa_eval(
sample_size=args.sample_size, sample_size=args.sample_size,
group_id=args.group_id, end_user_id=args.end_user_id,
search_limit=args.search_limit, search_limit=args.search_limit,
context_char_budget=args.context_char_budget, context_char_budget=args.context_char_budget,
llm_temperature=args.llm_temperature, llm_temperature=args.llm_temperature,

View File

@@ -199,7 +199,7 @@ def load_dataset_memsciqa(data_path: str) -> List[Dict[str, Any]]:
async def run_memsciqa_test( async def run_memsciqa_test(
sample_size: int = 3, sample_size: int = 3,
group_id: str | None = None, end_user_id: str | None = None,
search_limit: int = 8, search_limit: int = 8,
context_char_budget: int = 4000, context_char_budget: int = 4000,
llm_temperature: float = 0.0, llm_temperature: float = 0.0,
@@ -217,7 +217,7 @@ async def run_memsciqa_test(
""" """
# 默认使用指定的 memsci 组 ID # 默认使用指定的 memsci 组 ID
group_id = group_id or "group_memsci" end_user_id = end_user_id or "group_memsci"
# 数据路径解析(项目根与当前工作目录兜底) # 数据路径解析(项目根与当前工作目录兜底)
if not data_path: if not data_path:
@@ -283,7 +283,7 @@ async def run_memsciqa_test(
connector=connector, connector=connector,
embedder_client=embedder, embedder_client=embedder,
query_text=question, query_text=question,
group_id=group_id, end_user_id=end_user_id,
limit=search_limit, limit=search_limit,
include=["chunks", "statements", "entities", "summaries"], # 使用 chunks 而不是 dialogues include=["chunks", "statements", "entities", "summaries"], # 使用 chunks 而不是 dialogues
) )
@@ -292,7 +292,7 @@ async def run_memsciqa_test(
results = await search_graph( results = await search_graph(
connector=connector, connector=connector,
q=question, q=question,
group_id=group_id, end_user_id=end_user_id,
limit=search_limit, limit=search_limit,
include=["chunks", "statements", "entities", "summaries"], # 使用 chunks 而不是 dialogues include=["chunks", "statements", "entities", "summaries"], # 使用 chunks 而不是 dialogues
) )
@@ -500,7 +500,7 @@ async def run_memsciqa_test(
}, },
"samples": samples, "samples": samples,
"params": { "params": {
"group_id": group_id, "end_user_id": end_user_id,
"search_limit": search_limit, "search_limit": search_limit,
"context_char_budget": context_char_budget, "context_char_budget": context_char_budget,
"llm_temperature": llm_temperature, "llm_temperature": llm_temperature,
@@ -543,7 +543,7 @@ def main():
result = asyncio.run( result = asyncio.run(
run_memsciqa_test( run_memsciqa_test(
sample_size=sample_size, sample_size=sample_size,
group_id=args.group_id, end_user_id=args.end_user_id,
search_limit=args.search_limit, search_limit=args.search_limit,
context_char_budget=args.context_char_budget, context_char_budget=args.context_char_budget,
llm_temperature=args.llm_temperature, llm_temperature=args.llm_temperature,

View File

@@ -26,7 +26,7 @@ async def run(
dataset: str, dataset: str,
sample_size: int, sample_size: int,
reset_group: bool, reset_group: bool,
group_id: str | None, end_user_id: str | None,
judge_model: str | None = None, judge_model: str | None = None,
search_limit: int | None = None, search_limit: int | None = None,
context_char_budget: int | None = None, context_char_budget: int | None = None,
@@ -37,17 +37,17 @@ async def run(
max_contexts_per_item: int | None = None, max_contexts_per_item: int | None = None,
) -> Dict[str, Any]: ) -> Dict[str, Any]:
# 恢复原始风格:统一入口做路由,并沿用各数据集既有默认 # 恢复原始风格:统一入口做路由,并沿用各数据集既有默认
group_id = group_id or SELECTED_GROUP_ID end_user_id = end_user_id or SELECTED_GROUP_ID
if reset_group: if reset_group:
connector = Neo4jConnector() connector = Neo4jConnector()
try: try:
await connector.delete_group(group_id) await connector.delete_group(end_user_id)
finally: finally:
await connector.close() await connector.close()
if dataset == "locomo": if dataset == "locomo":
kwargs: Dict[str, Any] = {"sample_size": sample_size, "group_id": group_id} kwargs: Dict[str, Any] = {"sample_size": sample_size, "end_user_id": end_user_id}
if search_limit is not None: if search_limit is not None:
kwargs["search_limit"] = search_limit kwargs["search_limit"] = search_limit
if context_char_budget is not None: if context_char_budget is not None:
@@ -61,7 +61,7 @@ async def run(
return await run_locomo_eval(**kwargs) return await run_locomo_eval(**kwargs)
if dataset == "memsciqa": if dataset == "memsciqa":
kwargs: Dict[str, Any] = {"sample_size": sample_size, "group_id": group_id} kwargs: Dict[str, Any] = {"sample_size": sample_size, "end_user_id": end_user_id}
if search_limit is not None: if search_limit is not None:
kwargs["search_limit"] = search_limit kwargs["search_limit"] = search_limit
if context_char_budget is not None: if context_char_budget is not None:
@@ -75,7 +75,7 @@ async def run(
return await run_memsciqa_eval(**kwargs) return await run_memsciqa_eval(**kwargs)
if dataset == "longmemeval": if dataset == "longmemeval":
kwargs: Dict[str, Any] = {"sample_size": sample_size, "group_id": group_id} kwargs: Dict[str, Any] = {"sample_size": sample_size, "end_user_id": end_user_id}
if search_limit is not None: if search_limit is not None:
kwargs["search_limit"] = search_limit kwargs["search_limit"] = search_limit
if context_char_budget is not None: if context_char_budget is not None:
@@ -99,8 +99,8 @@ def main():
parser = argparse.ArgumentParser(description="统一评估入口memsciqa / longmemeval / locomo") parser = argparse.ArgumentParser(description="统一评估入口memsciqa / longmemeval / locomo")
parser.add_argument("--dataset", choices=["memsciqa", "longmemeval", "locomo"], required=True) parser.add_argument("--dataset", choices=["memsciqa", "longmemeval", "locomo"], required=True)
parser.add_argument("--sample-size", type=int, default=1, help="先用一条数据跑通") parser.add_argument("--sample-size", type=int, default=1, help="先用一条数据跑通")
parser.add_argument("--reset-group", action="store_true", help="运行前清空当前 group_id 的图数据") parser.add_argument("--reset-group", action="store_true", help="运行前清空当前 end_user_id 的图数据")
parser.add_argument("--group-id", type=str, default=None, help="可选 group_id默认取 runtime.json") parser.add_argument("--group-id", type=str, default=None, help="可选 end_user_id默认取 runtime.json")
parser.add_argument("--judge-model", type=str, default=None, help="可选longmemeval 判别式评测模型名") parser.add_argument("--judge-model", type=str, default=None, help="可选longmemeval 判别式评测模型名")
parser.add_argument("--search-limit", type=int, default=None, help="检索返回的对话节点数量上限(不提供则使用各脚本默认)") parser.add_argument("--search-limit", type=int, default=None, help="检索返回的对话节点数量上限(不提供则使用各脚本默认)")
parser.add_argument("--context-char-budget", type=int, default=None, help="上下文字符预算(不提供则使用各脚本默认)") parser.add_argument("--context-char-budget", type=int, default=None, help="上下文字符预算(不提供则使用各脚本默认)")
@@ -117,7 +117,7 @@ def main():
args.dataset, args.dataset,
args.sample_size, args.sample_size,
args.reset_group, args.reset_group,
args.group_id, args.end_user_id,
args.judge_model, args.judge_model,
args.search_limit, args.search_limit,
args.context_char_budget, args.context_char_budget,

View File

@@ -72,7 +72,7 @@ class TemporalSearchParams(BaseModel):
"""Parameters for temporal search queries in the knowledge graph. """Parameters for temporal search queries in the knowledge graph.
Attributes: Attributes:
group_id: Group ID to filter search results (default: 'test') end_user_id: Group ID to filter search results (default: 'test')
apply_id: Application ID to filter search results apply_id: Application ID to filter search results
user_id: User ID to filter search results user_id: User ID to filter search results
start_date: Start date for temporal filtering (format: 'YYYY-MM-DD') start_date: Start date for temporal filtering (format: 'YYYY-MM-DD')
@@ -81,7 +81,7 @@ class TemporalSearchParams(BaseModel):
invalid_date: Date when memory should be invalid (format: 'YYYY-MM-DD') invalid_date: Date when memory should be invalid (format: 'YYYY-MM-DD')
limit: Maximum number of results to return (default: 3) limit: Maximum number of results to return (default: 3)
""" """
group_id: Optional[str] = Field("test", description="The group ID to filter the search.") end_user_id: Optional[str] = Field("test", description="The group ID to filter the search.")
apply_id: Optional[str] = Field(None, description="The apply ID to filter the search.") apply_id: Optional[str] = Field(None, description="The apply ID to filter the search.")
user_id: Optional[str] = Field(None, description="The user ID to filter the search.") user_id: Optional[str] = Field(None, description="The user ID to filter the search.")
start_date: Optional[str] = Field(None, description="The start date for the search.") start_date: Optional[str] = Field(None, description="The start date for the search.")

View File

@@ -103,9 +103,7 @@ class Edge(BaseModel):
id: Unique identifier for the edge id: Unique identifier for the edge
source: ID of the source node source: ID of the source node
target: ID of the target node target: ID of the target node
group_id: Group ID for multi-tenancy end_user_id: End user ID for multi-tenancy
user_id: User ID for user-specific data
apply_id: Application ID for application-specific data
run_id: Unique identifier for the pipeline run that created this edge run_id: Unique identifier for the pipeline run that created this edge
created_at: Timestamp when the edge was created (system perspective) created_at: Timestamp when the edge was created (system perspective)
expired_at: Optional timestamp when the edge expires (system perspective) expired_at: Optional timestamp when the edge expires (system perspective)
@@ -113,9 +111,7 @@ class Edge(BaseModel):
id: str = Field(default_factory=lambda: uuid4().hex, description="A unique identifier for the edge.") id: str = Field(default_factory=lambda: uuid4().hex, description="A unique identifier for the edge.")
source: str = Field(..., description="The ID of the source node.") source: str = Field(..., description="The ID of the source node.")
target: str = Field(..., description="The ID of the target node.") target: str = Field(..., description="The ID of the target node.")
group_id: str = Field(..., description="The group ID of the edge.") end_user_id: str = Field(..., description="The end user ID of the edge.")
user_id: str = Field(..., description="The user ID of the edge.")
apply_id: str = Field(..., description="The apply ID of the edge.")
run_id: str = Field(default_factory=lambda: uuid4().hex, description="Unique identifier for this pipeline run.") run_id: str = Field(default_factory=lambda: uuid4().hex, description="Unique identifier for this pipeline run.")
created_at: datetime = Field(..., description="The valid time of the edge from system perspective.") created_at: datetime = Field(..., description="The valid time of the edge from system perspective.")
expired_at: Optional[datetime] = Field(None, description="The expired time of the edge from system perspective.") expired_at: Optional[datetime] = Field(None, description="The expired time of the edge from system perspective.")
@@ -185,18 +181,14 @@ class Node(BaseModel):
Attributes: Attributes:
id: Unique identifier for the node id: Unique identifier for the node
name: Name of the node name: Name of the node
group_id: Group ID for multi-tenancy end_user_id: End user ID for multi-tenancy
user_id: User ID for user-specific data
apply_id: Application ID for application-specific data
run_id: Unique identifier for the pipeline run that created this node run_id: Unique identifier for the pipeline run that created this node
created_at: Timestamp when the node was created (system perspective) created_at: Timestamp when the node was created (system perspective)
expired_at: Optional timestamp when the node expires (system perspective) expired_at: Optional timestamp when the node expires (system perspective)
""" """
id: str = Field(..., description="The unique identifier for the node.") id: str = Field(..., description="The unique identifier for the node.")
name: str = Field(..., description="The name of the node.") name: str = Field(..., description="The name of the node.")
group_id: str = Field(..., description="The group ID of the node.") end_user_id: str = Field(..., description="The end user ID of the node.")
user_id: str = Field(..., description="The user ID of the edge.")
apply_id: str = Field(..., description="The apply ID of the edge.")
run_id: str = Field(default_factory=lambda: uuid4().hex, description="Unique identifier for this pipeline run.") run_id: str = Field(default_factory=lambda: uuid4().hex, description="Unique identifier for this pipeline run.")
created_at: datetime = Field(..., description="The valid time of the node from system perspective.") created_at: datetime = Field(..., description="The valid time of the node from system perspective.")
expired_at: Optional[datetime] = Field(None, description="The expired time of the node from system perspective.") expired_at: Optional[datetime] = Field(None, description="The expired time of the node from system perspective.")

View File

@@ -55,7 +55,7 @@ class Statement(BaseModel):
Attributes: Attributes:
id: Unique identifier for the statement id: Unique identifier for the statement
chunk_id: ID of the parent chunk this statement belongs to chunk_id: ID of the parent chunk this statement belongs to
group_id: Optional group ID for multi-tenancy end_user_id: Optional group ID for multi-tenancy
statement: The actual statement text content statement: The actual statement text content
speaker: Optional speaker identifier ('用户' for user, 'AI' for AI responses) speaker: Optional speaker identifier ('用户' for user, 'AI' for AI responses)
statement_embedding: Optional embedding vector for the statement statement_embedding: Optional embedding vector for the statement
@@ -73,7 +73,7 @@ class Statement(BaseModel):
""" """
id: str = Field(default_factory=lambda: uuid4().hex, description="A unique identifier for the statement.") id: str = Field(default_factory=lambda: uuid4().hex, description="A unique identifier for the statement.")
chunk_id: str = Field(..., description="ID of the parent chunk this statement belongs to.") chunk_id: str = Field(..., description="ID of the parent chunk this statement belongs to.")
group_id: Optional[str] = Field(None, description="ID of the group this statement belongs to.") end_user_id: Optional[str] = Field(None, description="ID of the group this statement belongs to.")
statement: str = Field(..., description="The text content of the statement.") statement: str = Field(..., description="The text content of the statement.")
speaker: Optional[str] = Field(None, description="Speaker identifier: 'user' for user messages, 'assistant' for AI responses") speaker: Optional[str] = Field(None, description="Speaker identifier: 'user' for user messages, 'assistant' for AI responses")
statement_embedding: Optional[List[float]] = Field(None, description="The embedding vector of the statement.") statement_embedding: Optional[List[float]] = Field(None, description="The embedding vector of the statement.")
@@ -159,9 +159,7 @@ class DialogData(BaseModel):
context: Full conversation context context: Full conversation context
dialog_embedding: Optional embedding vector for the entire dialog dialog_embedding: Optional embedding vector for the entire dialog
ref_id: Reference ID linking to external dialog system ref_id: Reference ID linking to external dialog system
group_id: Group ID for multi-tenancy end_user_id: End user ID for multi-tenancy
user_id: User ID for user-specific data
apply_id: Application ID for application-specific data
created_at: Timestamp when the dialog was created created_at: Timestamp when the dialog was created
expired_at: Timestamp when the dialog expires (default: far future) expired_at: Timestamp when the dialog expires (default: far future)
metadata: Additional metadata as key-value pairs metadata: Additional metadata as key-value pairs
@@ -175,9 +173,7 @@ class DialogData(BaseModel):
context: ConversationContext = Field(..., description="The full conversation context as a single string.") context: ConversationContext = Field(..., description="The full conversation context as a single string.")
dialog_embedding: Optional[List[float]] = Field(None, description="The embedding vector of the dialog.") dialog_embedding: Optional[List[float]] = Field(None, description="The embedding vector of the dialog.")
ref_id: str = Field(..., description="Refer to external dialog id. This is used to link to the original dialog.") ref_id: str = Field(..., description="Refer to external dialog id. This is used to link to the original dialog.")
group_id: str = Field(default=..., description="Group ID of dialogue data") end_user_id: str = Field(default=..., description="End user ID of dialogue data")
user_id: str = Field(..., description="USER ID of dialogue data")
apply_id: str = Field(..., description="APPLY ID of dialogue data")
run_id: str = Field(default_factory=lambda: uuid4().hex, description="Unique identifier for this pipeline run.") run_id: str = Field(default_factory=lambda: uuid4().hex, description="Unique identifier for this pipeline run.")
created_at: datetime = Field(default_factory=datetime.now, description="The timestamp when the dialog was created.") created_at: datetime = Field(default_factory=datetime.now, description="The timestamp when the dialog was created.")
expired_at: datetime = Field(default_factory=lambda: datetime(9999, 12, 31), description="The timestamp when the dialog expires.") expired_at: datetime = Field(default_factory=lambda: datetime(9999, 12, 31), description="The timestamp when the dialog expires.")
@@ -250,11 +246,11 @@ class DialogData(BaseModel):
return [] return []
def assign_group_id_to_statements(self) -> None: def assign_group_id_to_statements(self) -> None:
"""Assign this dialog's group_id to all statements in all chunks. """Assign this dialog's end_user_id to all statements in all chunks.
This method updates statements that don't have a group_id set. This method updates statements that don't have a end_user_id set.
""" """
for chunk in self.chunks: for chunk in self.chunks:
for statement in chunk.statements: for statement in chunk.statements:
if statement.group_id is None: if statement.end_user_id is None:
statement.group_id = self.group_id statement.end_user_id = self.end_user_id

View File

@@ -6,6 +6,7 @@ import os
import time import time
from datetime import datetime from datetime import datetime
from typing import TYPE_CHECKING, Any, Dict, List, Optional from typing import TYPE_CHECKING, Any, Dict, List, Optional
from uuid import UUID
if TYPE_CHECKING: if TYPE_CHECKING:
from app.schemas.memory_config_schema import MemoryConfig from app.schemas.memory_config_schema import MemoryConfig
@@ -396,13 +397,13 @@ def rerank_with_activation(
return reranked return reranked
def log_search_query(query_text: str, search_type: str, group_id: str | None, limit: int, include: List[str], log_file: str = None): def log_search_query(query_text: str, search_type: str, end_user_id: str | None, limit: int, include: List[str], log_file: str = None):
"""Log search query information using the logger. """Log search query information using the logger.
Args: Args:
query_text: The search query text query_text: The search query text
search_type: Type of search (keyword, embedding, hybrid) search_type: Type of search (keyword, embedding, hybrid)
group_id: Group identifier for filtering end_user_id: Group identifier for filtering
limit: Maximum number of results limit: Maximum number of results
include: List of result types to include include: List of result types to include
log_file: Deprecated parameter, kept for backward compatibility log_file: Deprecated parameter, kept for backward compatibility
@@ -413,7 +414,7 @@ def log_search_query(query_text: str, search_type: str, group_id: str | None, li
# Log using the standard logger # Log using the standard logger
logger.info( logger.info(
f"Search query: query='{cleaned_query}', type={search_type}, " f"Search query: query='{cleaned_query}', type={search_type}, "
f"group_id={group_id}, limit={limit}, include={include}" f"end_user_id={end_user_id}, limit={limit}, include={include}"
) )
@@ -672,7 +673,7 @@ def apply_reranker_placeholder(
async def run_hybrid_search( async def run_hybrid_search(
query_text: str, query_text: str,
search_type: str, search_type: str,
group_id: str | None, end_user_id: str | None,
limit: int, limit: int,
include: List[str], include: List[str],
output_path: str | None, output_path: str | None,
@@ -715,7 +716,7 @@ async def run_hybrid_search(
} }
# Log the search query # Log the search query
log_search_query(query_text, search_type, group_id, limit, include) log_search_query(query_text, search_type, end_user_id, limit, include)
connector = Neo4jConnector() connector = Neo4jConnector()
results = {} results = {}
@@ -732,7 +733,7 @@ async def run_hybrid_search(
search_graph( search_graph(
connector=connector, connector=connector,
q=query_text, q=query_text,
group_id=group_id, end_user_id=end_user_id,
limit=limit, limit=limit,
include=include include=include
) )
@@ -769,7 +770,7 @@ async def run_hybrid_search(
connector=connector, connector=connector,
embedder_client=embedder, embedder_client=embedder,
query_text=query_text, query_text=query_text,
group_id=group_id, end_user_id=end_user_id,
limit=limit, limit=limit,
include=include, include=include,
) )
@@ -916,9 +917,7 @@ async def run_hybrid_search(
async def search_by_temporal( async def search_by_temporal(
group_id: Optional[str] = "test", end_user_id: Optional[str] = "test",
apply_id: Optional[str] = None,
user_id: Optional[str] = None,
start_date: Optional[str] = None, start_date: Optional[str] = None,
end_date: Optional[str] = None, end_date: Optional[str] = None,
valid_date: Optional[str] = None, valid_date: Optional[str] = None,
@@ -929,7 +928,7 @@ async def search_by_temporal(
Temporal search across Statements. Temporal search across Statements.
- Matches statements created between start_date and end_date - Matches statements created between start_date and end_date
- Optionally filters by group_id - Optionally filters by end_user_id
- Returns up to 'limit' statements - Returns up to 'limit' statements
""" """
connector = Neo4jConnector() connector = Neo4jConnector()
@@ -939,9 +938,7 @@ async def search_by_temporal(
end_date = normalize_date_safe(end_date) end_date = normalize_date_safe(end_date)
params = TemporalSearchParams.model_validate({ params = TemporalSearchParams.model_validate({
"group_id": group_id, "end_user_id": end_user_id,
"apply_id": apply_id,
"user_id": user_id,
"start_date": start_date, "start_date": start_date,
"end_date": end_date, "end_date": end_date,
"valid_date": valid_date, "valid_date": valid_date,
@@ -950,9 +947,7 @@ async def search_by_temporal(
}) })
statements = await search_graph_by_temporal( statements = await search_graph_by_temporal(
connector=connector, connector=connector,
group_id=params.group_id, end_user_id=params.end_user_id,
apply_id=params.apply_id,
user_id=params.user_id,
start_date=params.start_date, start_date=params.start_date,
end_date=params.end_date, end_date=params.end_date,
valid_date=params.valid_date, valid_date=params.valid_date,
@@ -964,9 +959,7 @@ async def search_by_temporal(
async def search_by_keyword_temporal( async def search_by_keyword_temporal(
query_text: str, query_text: str,
group_id: Optional[str] = "test", end_user_id: Optional[str] = "test",
apply_id: Optional[str] = None,
user_id: Optional[str] = None,
start_date: Optional[str] = None, start_date: Optional[str] = None,
end_date: Optional[str] = None, end_date: Optional[str] = None,
valid_date: Optional[str] = None, valid_date: Optional[str] = None,
@@ -987,9 +980,7 @@ async def search_by_keyword_temporal(
invalid_date = normalize_date_safe(invalid_date) invalid_date = normalize_date_safe(invalid_date)
params = TemporalSearchParams.model_validate({ params = TemporalSearchParams.model_validate({
"group_id": group_id, "end_user_id": end_user_id,
"apply_id": apply_id,
"user_id": user_id,
"start_date": start_date, "start_date": start_date,
"end_date": end_date, "end_date": end_date,
"valid_date": valid_date, "valid_date": valid_date,
@@ -999,9 +990,7 @@ async def search_by_keyword_temporal(
statements = await search_graph_by_keyword_temporal( statements = await search_graph_by_keyword_temporal(
connector=connector, connector=connector,
query_text=query_text, query_text=query_text,
group_id=params.group_id, end_user_id=params.end_user_id,
apply_id=params.apply_id,
user_id=params.user_id,
start_date=params.start_date, start_date=params.start_date,
end_date=params.end_date, end_date=params.end_date,
valid_date=params.valid_date, valid_date=params.valid_date,
@@ -1013,7 +1002,7 @@ async def search_by_keyword_temporal(
async def search_chunk_by_chunk_id( async def search_chunk_by_chunk_id(
chunk_id: str, chunk_id: str,
group_id: Optional[str] = "test", end_user_id: Optional[str] = "test",
limit: int = 1, limit: int = 1,
): ):
""" """
@@ -1023,7 +1012,7 @@ async def search_chunk_by_chunk_id(
chunks = await search_graph_by_chunk_id( chunks = await search_graph_by_chunk_id(
connector=connector, connector=connector,
chunk_id=chunk_id, chunk_id=chunk_id,
group_id=group_id, end_user_id=end_user_id,
limit=limit limit=limit
) )
return {"chunks": chunks} return {"chunks": chunks}

View File

@@ -555,8 +555,8 @@ class DataPreprocessor:
dialog_id = item.get('dialog_id', item.get('ref_id', item.get('id', f'dialog_{i}'))) dialog_id = item.get('dialog_id', item.get('ref_id', item.get('id', f'dialog_{i}')))
# 获取group_id如果不存在则生成默认值 # 获取end_user_id如果不存在则生成默认值
group_id = item.get('group_id', f'group_default_{i}') end_user_id = item.get('end_user_id', f'group_default_{i}')
user_id = item.get('user_id', f'user_default_{i}') user_id = item.get('user_id', f'user_default_{i}')
apply_id = item.get('apply_id', f'apply_default_{i}') apply_id = item.get('apply_id', f'apply_default_{i}')
@@ -574,7 +574,7 @@ class DataPreprocessor:
dialog_data = DialogData( dialog_data = DialogData(
context=context, context=context,
ref_id=dialog_id, ref_id=dialog_id,
group_id=group_id, end_user_id=end_user_id,
user_id=user_id, user_id=user_id,
apply_id=apply_id, apply_id=apply_id,
metadata=metadata metadata=metadata
@@ -644,7 +644,7 @@ class DataPreprocessor:
context = ConversationContext(msgs=messages) context = ConversationContext(msgs=messages)
dialog_id = item.get('dialog_id', item.get('ref_id', item.get('id', f'dialog_{i}'))) dialog_id = item.get('dialog_id', item.get('ref_id', item.get('id', f'dialog_{i}')))
group_id = item.get('group_id', f'group_default_{i}') end_user_id = item.get('end_user_id', f'group_default_{i}')
user_id = item.get('user_id', f'user_default_{i}') user_id = item.get('user_id', f'user_default_{i}')
apply_id = item.get('apply_id', f'apply_default_{i}') apply_id = item.get('apply_id', f'apply_default_{i}')
@@ -657,7 +657,7 @@ class DataPreprocessor:
dialog_data = DialogData( dialog_data = DialogData(
context=context, context=context,
ref_id=dialog_id, ref_id=dialog_id,
group_id=group_id, end_user_id=end_user_id,
user_id=user_id, user_id=user_id,
apply_id=apply_id, apply_id=apply_id,
metadata=metadata metadata=metadata

View File

@@ -199,7 +199,7 @@ def accurate_match(
entity_nodes: List[ExtractedEntityNode] entity_nodes: List[ExtractedEntityNode]
) -> Tuple[List[ExtractedEntityNode], Dict[str, str], Dict[str, Dict]]: ) -> Tuple[List[ExtractedEntityNode], Dict[str, str], Dict[str, Dict]]:
""" """
精确匹配:按 (group_id, name, entity_type) 合并实体并建立重定向与合并记录。 精确匹配:按 (end_user_id, name, entity_type) 合并实体并建立重定向与合并记录。
返回: (deduped_entities, id_redirect, exact_merge_map) 返回: (deduped_entities, id_redirect, exact_merge_map)
""" """
exact_merge_map: Dict[str, Dict] = {} exact_merge_map: Dict[str, Dict] = {}
@@ -210,8 +210,8 @@ def accurate_match(
for ent in entity_nodes: for ent in entity_nodes:
name_norm = (getattr(ent, "name", "") or "").strip() name_norm = (getattr(ent, "name", "") or "").strip()
type_norm = (getattr(ent, "entity_type", "") or "").strip() type_norm = (getattr(ent, "entity_type", "") or "").strip()
key = f"{getattr(ent, 'group_id', None)}|{name_norm}|{type_norm}" key = f"{getattr(ent, 'end_user_id', None)}|{name_norm}|{type_norm}"
# 为避免跨业务组误并,明确以 group_id 为范围边界 # 为避免跨业务组误并,明确以 end_user_id 为范围边界
if key not in canonical_map: if key not in canonical_map:
canonical_map[key] = ent canonical_map[key] = ent
id_redirect[ent.id] = ent.id id_redirect[ent.id] = ent.id
@@ -223,11 +223,11 @@ def accurate_match(
id_redirect[ent.id] = canonical.id id_redirect[ent.id] = canonical.id
# 记录精确匹配的合并项(使用规范化键,避免外层变量误用) # 记录精确匹配的合并项(使用规范化键,避免外层变量误用)
try: try:
k = f"{canonical.group_id}|{(canonical.name or '').strip()}|{(canonical.entity_type or '').strip()}" k = f"{canonical.end_user_id}|{(canonical.name or '').strip()}|{(canonical.entity_type or '').strip()}"
if k not in exact_merge_map: if k not in exact_merge_map:
exact_merge_map[k] = { exact_merge_map[k] = {
"canonical_id": canonical.id, "canonical_id": canonical.id,
"group_id": canonical.group_id, "end_user_id": canonical.end_user_id,
"name": canonical.name, "name": canonical.name,
"entity_type": canonical.entity_type, "entity_type": canonical.entity_type,
"merged_ids": set(), "merged_ids": set(),
@@ -596,7 +596,7 @@ def fuzzy_match(
b = deduped_entities[j] b = deduped_entities[j]
# 跳过不同业务组的实体 # 跳过不同业务组的实体
if getattr(a, "group_id", None) != getattr(b, "group_id", None): if getattr(a, "end_user_id", None) != getattr(b, "end_user_id", None):
j += 1 j += 1
continue continue
@@ -671,7 +671,7 @@ def fuzzy_match(
merge_reason = "[别名匹配]" if alias_match_merge else "[模糊]" merge_reason = "[别名匹配]" if alias_match_merge else "[模糊]"
merge_reason = "[别名匹配]" if alias_match_merge else "[模糊]" merge_reason = "[别名匹配]" if alias_match_merge else "[模糊]"
fuzzy_merge_records.append( fuzzy_merge_records.append(
f"{merge_reason} 规范实体 {a.id} ({a.group_id}|{a.name}|{a.entity_type}) <- 合并实体 {b.id} ({b.group_id}|{b.name}|{b.entity_type}) | " f"{merge_reason} 规范实体 {a.id} ({a.end_user_id}|{a.name}|{a.entity_type}) <- 合并实体 {b.id} ({b.end_user_id}|{b.name}|{b.entity_type}) | "
f"s_name={s_name:.3f}, s_type={s_type:.3f}, overall={overall:.3f}, exact_alias={has_exact_match}" f"s_name={s_name:.3f}, s_type={s_type:.3f}, overall={overall:.3f}, exact_alias={has_exact_match}"
) )
except Exception: except Exception:
@@ -779,7 +779,7 @@ async def LLM_decision( # 决策中包含去重和消歧的功能
# 记录 LLM 融合日志 # 记录 LLM 融合日志
try: try:
llm_records.append( llm_records.append(
f"[LLM融合] 规范实体 {a.id} ({a.group_id}|{a.name}|{a.entity_type}) <- 合并实体 {b.id} ({b.group_id}|{b.name}|{b.entity_type})" f"[LLM融合] 规范实体 {a.id} ({a.end_user_id}|{a.name}|{a.entity_type}) <- 合并实体 {b.id} ({b.end_user_id}|{b.name}|{b.entity_type})"
) )
# 详细的“同类名称相似”记录改由 LLM 去重模块统一生成以携带 conf/reason # 详细的“同类名称相似”记录改由 LLM 去重模块统一生成以携带 conf/reason
except Exception: except Exception:
@@ -847,7 +847,7 @@ async def LLM_disamb_decision(
id_redirect[k] = a.id id_redirect[k] = a.id
try: try:
disamb_records.append( disamb_records.append(
f"[DISAMB合并应用] 规范实体 {a.id} ({a.group_id}|{a.name}|{a.entity_type}) <- 合并实体 {b.id} ({b.group_id}|{b.name}|{b.entity_type})" f"[DISAMB合并应用] 规范实体 {a.id} ({a.end_user_id}|{a.name}|{a.entity_type}) <- 合并实体 {b.id} ({b.end_user_id}|{b.name}|{b.entity_type})"
) )
except Exception: except Exception:
pass pass

View File

@@ -174,7 +174,7 @@ async def _judge_pair(
pass pass
# 3. 构建LLM判断的“上下文信息”规则层计算的所有特征 判断上下文特征有助于实体消歧首先判断的类型关系 # 3. 构建LLM判断的“上下文信息”规则层计算的所有特征 判断上下文特征有助于实体消歧首先判断的类型关系
ctx = { ctx = {
"same_group": getattr(a, "group_id", None) == getattr(b, "group_id", None), "same_group": getattr(a, "end_user_id", None) == getattr(b, "end_user_id", None),
"type_ok": _simple_type_ok(getattr(a, "entity_type", None), getattr(b, "entity_type", None)), "type_ok": _simple_type_ok(getattr(a, "entity_type", None), getattr(b, "entity_type", None)),
"type_similarity": _type_similarity(getattr(a, "entity_type", None), getattr(b, "entity_type", None)), "type_similarity": _type_similarity(getattr(a, "entity_type", None), getattr(b, "entity_type", None)),
"name_text_sim": name_text_sim, "name_text_sim": name_text_sim,
@@ -235,7 +235,7 @@ async def _judge_pair_disamb(
except Exception: except Exception:
pass pass
ctx = { ctx = {
"same_group": getattr(a, "group_id", None) == getattr(b, "group_id", None), "same_group": getattr(a, "end_user_id", None) == getattr(b, "end_user_id", None),
"type_ok": _simple_type_ok(getattr(a, "entity_type", None), getattr(b, "entity_type", None)), "type_ok": _simple_type_ok(getattr(a, "entity_type", None), getattr(b, "entity_type", None)),
"name_text_sim": name_text_sim, "name_text_sim": name_text_sim,
"name_embed_sim": name_embed_sim, "name_embed_sim": name_embed_sim,
@@ -317,8 +317,8 @@ async def llm_dedup_entities( # 保留对偶判断作为子流程,是为了
a = entity_nodes[i] a = entity_nodes[i]
for j in range(i + 1, len(entity_nodes)): for j in range(i + 1, len(entity_nodes)):
b = entity_nodes[j] b = entity_nodes[j]
# 规则1必须属于同一组group_id相同不同组的实体不重复 # 规则1必须属于同一组end_user_id相同不同组的实体不重复
if getattr(a, "group_id", None) != getattr(b, "group_id", None): if getattr(a, "end_user_id", None) != getattr(b, "end_user_id", None):
continue continue
# 规则2类型必须兼容调用_simple_type_ok判断 # 规则2类型必须兼容调用_simple_type_ok判断
if not _simple_type_ok(getattr(a, "entity_type", None), getattr(b, "entity_type", None)): if not _simple_type_ok(getattr(a, "entity_type", None), getattr(b, "entity_type", None)):
@@ -474,7 +474,7 @@ async def llm_dedup_entities_iterative_blocks( # 迭代分块并发 LLM 去重
- max_rounds: upper bound for iterative passes (default 3) - max_rounds: upper bound for iterative passes (default 3)
- auto_merge_threshold: decision confidence for auto-merge when no co-occurrence (default 0.90) - auto_merge_threshold: decision confidence for auto-merge when no co-occurrence (default 0.90)
- co_ctx_threshold: lower threshold when co-occurrence is detected (default 0.83) - co_ctx_threshold: lower threshold when co-occurrence is detected (default 0.83)
- shuffle_each_round: whether to shuffle entities within group_id each round to vary block composition - shuffle_each_round: whether to shuffle entities within end_user_id each round to vary block composition
Returns: Returns:
- global_redirect: dict losing_id -> canonical_id accumulated across rounds - global_redirect: dict losing_id -> canonical_id accumulated across rounds
@@ -509,7 +509,7 @@ async def llm_dedup_entities_iterative_blocks( # 迭代分块并发 LLM 去重
def _partition_blocks(nodes: List[ExtractedEntityNode]) -> List[List[ExtractedEntityNode]]: def _partition_blocks(nodes: List[ExtractedEntityNode]) -> List[List[ExtractedEntityNode]]:
""" """
group_id 分块,避免跨组实体在同一块,减少无效候选对 end_user_id 分块,避免跨组实体在同一块,减少无效候选对
Args: Args:
nodes: 实体节点列表 nodes: 实体节点列表
@@ -519,7 +519,7 @@ async def llm_dedup_entities_iterative_blocks( # 迭代分块并发 LLM 去重
""" """
groups: Dict[str, List[ExtractedEntityNode]] = {} groups: Dict[str, List[ExtractedEntityNode]] = {}
for e in nodes: for e in nodes:
gid = getattr(e, "group_id", None) gid = getattr(e, "end_user_id", None)
groups.setdefault(str(gid), []).append(e) groups.setdefault(str(gid), []).append(e)
blocks: List[List[ExtractedEntityNode]] = [] blocks: List[List[ExtractedEntityNode]] = []
for gid, arr in groups.items(): for gid, arr in groups.items():
@@ -559,7 +559,7 @@ async def llm_dedup_entities_iterative_blocks( # 迭代分块并发 LLM 去重
# Collapse nodes to canonical reps before each round to avoid redundant comparisons # Collapse nodes to canonical reps before each round to avoid redundant comparisons
# 步骤1折叠实体合并已确定的重复实体减少后续计算量 # 步骤1折叠实体合并已确定的重复实体减少后续计算量
current_nodes = _collapse_nodes(current_nodes) current_nodes = _collapse_nodes(current_nodes)
# 步骤2分块group_id分块避免跨组处理 # 步骤2分块end_user_id分块避免跨组处理
blocks = _partition_blocks(current_nodes) blocks = _partition_blocks(current_nodes)
if not blocks: # 无块可处理(实体已全部折叠),退出循环 if not blocks: # 无块可处理(实体已全部折叠),退出循环
break break
@@ -645,7 +645,7 @@ async def llm_disambiguate_pairs_iterative(
a = entity_nodes[i] a = entity_nodes[i]
b = entity_nodes[j] b = entity_nodes[j]
# 必须同组 # 必须同组
if getattr(a, "group_id", None) != getattr(b, "group_id", None): if getattr(a, "end_user_id", None) != getattr(b, "end_user_id", None):
continue continue
ta = getattr(a, "entity_type", None) ta = getattr(a, "entity_type", None)
tb = getattr(b, "entity_type", None) tb = getattr(b, "entity_type", None)

View File

@@ -61,7 +61,7 @@ def _row_to_entity(row: Dict[str, Any]) -> ExtractedEntityNode:
return ExtractedEntityNode( return ExtractedEntityNode(
id=row.get("id"), id=row.get("id"),
name=row.get("name") or "", name=row.get("name") or "",
group_id=row.get("group_id") or "", end_user_id=row.get("end_user_id") or "",
user_id=row.get("user_id") or "", user_id=row.get("user_id") or "",
apply_id=row.get("apply_id") or "", apply_id=row.get("apply_id") or "",
created_at=_parse_dt(row.get("created_at")), created_at=_parse_dt(row.get("created_at")),
@@ -79,7 +79,7 @@ def _row_to_entity(row: Dict[str, Any]) -> ExtractedEntityNode:
async def second_layer_dedup_and_merge_with_neo4j( # 二层去重的核心逻辑,与 Neo4j 中同组实体联合去重 async def second_layer_dedup_and_merge_with_neo4j( # 二层去重的核心逻辑,与 Neo4j 中同组实体联合去重
connector: Neo4jConnector, connector: Neo4jConnector,
group_id: str, # 用于定位neo4j中同一组的实体确保只在同组内去重 end_user_id: str, # 用于定位neo4j中同一组的实体确保只在同组内去重
entity_nodes: List[ExtractedEntityNode], # 输入的实体节点列表,包含待去重的实体 entity_nodes: List[ExtractedEntityNode], # 输入的实体节点列表,包含待去重的实体
statement_entity_edges: List[StatementEntityEdge], # 输入的语句实体边列表,用于处理实体之间的关系 statement_entity_edges: List[StatementEntityEdge], # 输入的语句实体边列表,用于处理实体之间的关系
entity_entity_edges: List[EntityEntityEdge], # 输入的实体实体边列表,用于处理实体之间的关系 entity_entity_edges: List[EntityEntityEdge], # 输入的实体实体边列表,用于处理实体之间的关系
@@ -88,7 +88,7 @@ async def second_layer_dedup_and_merge_with_neo4j( # 二层去重的核心逻辑
) -> Tuple[List[ExtractedEntityNode], List[StatementEntityEdge], List[EntityEntityEdge]]: ) -> Tuple[List[ExtractedEntityNode], List[StatementEntityEdge], List[EntityEntityEdge]]:
""" """
第二层去重消歧: 第二层去重消歧:
- 以第一层结果为索引,检索相同 group_id 下的 DB 候选实体 - 以第一层结果为索引,检索相同 end_user_id 下的 DB 候选实体
- 将 DB 候选与当前实体集合联合,按既有精确/模糊/LLM 决策进行融合 - 将 DB 候选与当前实体集合联合,按既有精确/模糊/LLM 决策进行融合
- 返回融合后的实体与重定向后的边(边已指向规范 ID优先 DB ID - 返回融合后的实体与重定向后的边(边已指向规范 ID优先 DB ID
""" """
@@ -102,7 +102,7 @@ async def second_layer_dedup_and_merge_with_neo4j( # 二层去重的核心逻辑
] ]
candidates_map = await get_dedup_candidates_for_entities( # 从 Neo4j 中查询候选实体并将结果赋值给candidates_map等待异步操作完成 candidates_map = await get_dedup_candidates_for_entities( # 从 Neo4j 中查询候选实体并将结果赋值给candidates_map等待异步操作完成
connector=connector, group_id=group_id, connector=connector, end_user_id=end_user_id,
entities=incoming_rows, # 传入参数:第一层实体的核心信息(作为查询索引) entities=incoming_rows, # 传入参数:第一层实体的核心信息(作为查询索引)
use_contains_fallback=True # 传入参数:启用 “包含关系” 作为匹配失败的降级策略若精确匹配无结果用包含关系召回候选与src\database\cypher_queries.py的307产生联动 use_contains_fallback=True # 传入参数:启用 “包含关系” 作为匹配失败的降级策略若精确匹配无结果用包含关系召回候选与src\database\cypher_queries.py的307产生联动
) )

View File

@@ -57,11 +57,11 @@ async def dedup_layers_and_merge_and_return(
if pipeline_config is None: if pipeline_config is None:
raise ValueError("pipeline_config is required for dedup_layers_and_merge_and_return") raise ValueError("pipeline_config is required for dedup_layers_and_merge_and_return")
# 先探测 group_id决定报告写入策略 # 先探测 end_user_id决定报告写入策略
group_id: Optional[str] = None end_user_id: Optional[str] = None
for dd in dialog_data_list: for dd in dialog_data_list:
group_id = getattr(dd, "group_id", None) end_user_id = getattr(dd, "end_user_id", None)
if group_id: if end_user_id:
break break
# 第一层去重消歧 # 第一层去重消歧
@@ -82,11 +82,11 @@ async def dedup_layers_and_merge_and_return(
# 第二层去重消歧:与 Neo4j 中同组实体联合融合 # 第二层去重消歧:与 Neo4j 中同组实体联合融合
try: try:
if group_id: if end_user_id:
if connector: if connector:
fused_entity_nodes, fused_statement_entity_edges, fused_entity_entity_edges = await second_layer_dedup_and_merge_with_neo4j( fused_entity_nodes, fused_statement_entity_edges, fused_entity_entity_edges = await second_layer_dedup_and_merge_with_neo4j(
connector=connector, connector=connector,
group_id=group_id, end_user_id=end_user_id,
entity_nodes=dedup_entity_nodes, entity_nodes=dedup_entity_nodes,
statement_entity_edges=dedup_statement_entity_edges, statement_entity_edges=dedup_statement_entity_edges,
entity_entity_edges=dedup_entity_entity_edges, entity_entity_edges=dedup_entity_entity_edges,
@@ -96,7 +96,7 @@ async def dedup_layers_and_merge_and_return(
else: else:
print("Skip second-layer dedup: missing connector") print("Skip second-layer dedup: missing connector")
else: else:
print("Skip second-layer dedup: missing group_id") print("Skip second-layer dedup: missing end_user_id")
except Exception as e: except Exception as e:
print(f"Second-layer dedup failed: {e}") print(f"Second-layer dedup failed: {e}")

View File

@@ -287,7 +287,7 @@ class ExtractionOrchestrator:
for d_idx, dialog in enumerate(dialog_data_list): for d_idx, dialog in enumerate(dialog_data_list):
dialogue_content = dialog.content if self.config.statement_extraction.include_dialogue_context else None dialogue_content = dialog.content if self.config.statement_extraction.include_dialogue_context else None
for c_idx, chunk in enumerate(dialog.chunks): for c_idx, chunk in enumerate(dialog.chunks):
all_chunks.append((chunk, dialog.group_id, dialogue_content)) all_chunks.append((chunk, dialog.end_user_id, dialogue_content))
chunk_metadata.append((d_idx, c_idx)) chunk_metadata.append((d_idx, c_idx))
logger.info(f"收集到 {len(all_chunks)} 个分块,开始全局并行提取") logger.info(f"收集到 {len(all_chunks)} 个分块,开始全局并行提取")
@@ -299,9 +299,9 @@ class ExtractionOrchestrator:
# 全局并行处理所有分块 # 全局并行处理所有分块
async def extract_for_chunk(chunk_data, chunk_index): async def extract_for_chunk(chunk_data, chunk_index):
nonlocal completed_chunks nonlocal completed_chunks
chunk, group_id, dialogue_content = chunk_data chunk, end_user_id, dialogue_content = chunk_data
try: try:
statements = await self.statement_extractor._extract_statements(chunk, group_id, dialogue_content) statements = await self.statement_extractor._extract_statements(chunk, end_user_id, dialogue_content)
# 流式输出:每提取完一个分块的陈述句,立即发送进度 # 流式输出:每提取完一个分块的陈述句,立即发送进度
# 注意:只在试运行模式下发送陈述句详情,正式模式不发送 # 注意:只在试运行模式下发送陈述句详情,正式模式不发送
@@ -569,32 +569,32 @@ class ExtractionOrchestrator:
if dialog_data_list and hasattr(dialog_data_list[0], 'config_id'): if dialog_data_list and hasattr(dialog_data_list[0], 'config_id'):
config_id = dialog_data_list[0].config_id config_id = dialog_data_list[0].config_id
# 加载DataConfig # 加载MemoryConfig
data_config = None memory_config = None
if config_id: if config_id:
try: try:
from app.db import SessionLocal from app.db import SessionLocal
from app.repositories.data_config_repository import DataConfigRepository from app.repositories.memory_config_repository import MemoryConfigRepository
db = SessionLocal() db = SessionLocal()
try: try:
data_config = DataConfigRepository.get_by_id(db, config_id) memory_config = MemoryConfigRepository.get_by_id(db, config_id)
finally: finally:
db.close() db.close()
if data_config and not data_config.emotion_enabled: if memory_config and not memory_config.emotion_enabled:
logger.info("情绪提取已在配置中禁用,跳过情绪提取") logger.info("情绪提取已在配置中禁用,跳过情绪提取")
return [{} for _ in dialog_data_list] return [{} for _ in dialog_data_list]
except Exception as e: except Exception as e:
logger.warning(f"加载DataConfig失败: {e},将跳过情绪提取") logger.warning(f"加载MemoryConfig失败: {e},将跳过情绪提取")
return [{} for _ in dialog_data_list] return [{} for _ in dialog_data_list]
else: else:
logger.info("未找到config_id跳过情绪提取") logger.info("未找到config_id跳过情绪提取")
return [{} for _ in dialog_data_list] return [{} for _ in dialog_data_list]
# 如果配置未启用情绪提取,直接返回空映射 # 如果配置未启用情绪提取,直接返回空映射
if not data_config or not data_config.emotion_enabled: if not memory_config or not memory_config.emotion_enabled:
logger.info("情绪提取未启用,跳过") logger.info("情绪提取未启用,跳过")
return [{} for _ in dialog_data_list] return [{} for _ in dialog_data_list]
@@ -608,7 +608,7 @@ class ExtractionOrchestrator:
total_statements += 1 total_statements += 1
# 只处理用户的陈述句 (role 为 "user") # 只处理用户的陈述句 (role 为 "user")
if hasattr(statement, 'speaker') and statement.speaker == "user": if hasattr(statement, 'speaker') and statement.speaker == "user":
all_statements.append((statement, data_config)) all_statements.append((statement, memory_config))
statement_metadata.append((d_idx, statement.id)) statement_metadata.append((d_idx, statement.id))
filtered_statements += 1 filtered_statements += 1
@@ -617,7 +617,7 @@ class ExtractionOrchestrator:
# 初始化情绪提取服务 # 初始化情绪提取服务
from app.services.emotion_extraction_service import EmotionExtractionService from app.services.emotion_extraction_service import EmotionExtractionService
emotion_service = EmotionExtractionService( emotion_service = EmotionExtractionService(
llm_id=data_config.emotion_model_id if data_config.emotion_model_id else None llm_id=memory_config.emotion_model_id if memory_config.emotion_model_id else None
) )
# 全局并行处理所有陈述句 # 全局并行处理所有陈述句
@@ -992,9 +992,7 @@ class ExtractionOrchestrator:
id=dialog_data.id, id=dialog_data.id,
name=f"Dialog_{dialog_data.id}", # 添加必需的 name 字段 name=f"Dialog_{dialog_data.id}", # 添加必需的 name 字段
ref_id=dialog_data.ref_id, ref_id=dialog_data.ref_id,
group_id=dialog_data.group_id, end_user_id=dialog_data.end_user_id,
user_id=dialog_data.user_id,
apply_id=dialog_data.apply_id,
run_id=dialog_data.run_id, # 使用 dialog_data 的 run_id run_id=dialog_data.run_id, # 使用 dialog_data 的 run_id
content=dialog_data.context.content if dialog_data.context else "", content=dialog_data.context.content if dialog_data.context else "",
dialog_embedding=dialog_data.dialog_embedding if hasattr(dialog_data, 'dialog_embedding') else None, dialog_embedding=dialog_data.dialog_embedding if hasattr(dialog_data, 'dialog_embedding') else None,
@@ -1012,9 +1010,7 @@ class ExtractionOrchestrator:
id=chunk.id, id=chunk.id,
name=f"Chunk_{chunk.id}", # 添加必需的 name 字段 name=f"Chunk_{chunk.id}", # 添加必需的 name 字段
dialog_id=dialog_data.id, dialog_id=dialog_data.id,
group_id=dialog_data.group_id, end_user_id=dialog_data.end_user_id,
user_id=dialog_data.user_id,
apply_id=dialog_data.apply_id,
run_id=dialog_data.run_id, # 使用 dialog_data 的 run_id run_id=dialog_data.run_id, # 使用 dialog_data 的 run_id
content=chunk.content, content=chunk.content,
chunk_embedding=chunk.chunk_embedding, chunk_embedding=chunk.chunk_embedding,
@@ -1035,9 +1031,7 @@ class ExtractionOrchestrator:
stmt_type=getattr(statement, 'stmt_type', 'general'), # 添加必需的 stmt_type 字段 stmt_type=getattr(statement, 'stmt_type', 'general'), # 添加必需的 stmt_type 字段
temporal_info=getattr(statement, 'temporal_info', TemporalInfo.ATEMPORAL), # 添加必需的 temporal_info 字段 temporal_info=getattr(statement, 'temporal_info', TemporalInfo.ATEMPORAL), # 添加必需的 temporal_info 字段
connect_strength=statement.connect_strength if statement.connect_strength is not None else 'Strong', # 添加必需的 connect_strength 字段 connect_strength=statement.connect_strength if statement.connect_strength is not None else 'Strong', # 添加必需的 connect_strength 字段
group_id=dialog_data.group_id, end_user_id=dialog_data.end_user_id,
user_id=dialog_data.user_id,
apply_id=dialog_data.apply_id,
run_id=dialog_data.run_id, # 使用 dialog_data 的 run_id run_id=dialog_data.run_id, # 使用 dialog_data 的 run_id
statement=statement.statement, statement=statement.statement,
speaker=getattr(statement, 'speaker', None), # 添加 speaker 字段 speaker=getattr(statement, 'speaker', None), # 添加 speaker 字段
@@ -1060,9 +1054,7 @@ class ExtractionOrchestrator:
statement_chunk_edge = StatementChunkEdge( statement_chunk_edge = StatementChunkEdge(
source=statement.id, source=statement.id,
target=chunk.id, target=chunk.id,
group_id=dialog_data.group_id, end_user_id=dialog_data.end_user_id,
user_id=dialog_data.user_id,
apply_id=dialog_data.apply_id,
run_id=dialog_data.run_id, # 使用 dialog_data 的 run_id run_id=dialog_data.run_id, # 使用 dialog_data 的 run_id
created_at=dialog_data.created_at, created_at=dialog_data.created_at,
) )
@@ -1095,9 +1087,7 @@ class ExtractionOrchestrator:
aliases=getattr(entity, 'aliases', []) or [], # 传递从三元组提取阶段获取的aliases aliases=getattr(entity, 'aliases', []) or [], # 传递从三元组提取阶段获取的aliases
name_embedding=getattr(entity, 'name_embedding', None), name_embedding=getattr(entity, 'name_embedding', None),
is_explicit_memory=getattr(entity, 'is_explicit_memory', False), # 新增:传递语义记忆标记 is_explicit_memory=getattr(entity, 'is_explicit_memory', False), # 新增:传递语义记忆标记
group_id=dialog_data.group_id, end_user_id=dialog_data.end_user_id,
user_id=dialog_data.user_id,
apply_id=dialog_data.apply_id,
run_id=dialog_data.run_id, # 使用 dialog_data 的 run_id run_id=dialog_data.run_id, # 使用 dialog_data 的 run_id
created_at=dialog_data.created_at, created_at=dialog_data.created_at,
expired_at=dialog_data.expired_at, expired_at=dialog_data.expired_at,
@@ -1112,9 +1102,7 @@ class ExtractionOrchestrator:
source=statement.id, source=statement.id,
target=entity.id, target=entity.id,
connect_strength=entity_connect_strength if entity_connect_strength is not None else 'Strong', connect_strength=entity_connect_strength if entity_connect_strength is not None else 'Strong',
group_id=dialog_data.group_id, end_user_id=dialog_data.end_user_id,
user_id=dialog_data.user_id,
apply_id=dialog_data.apply_id,
run_id=dialog_data.run_id, # 使用 dialog_data 的 run_id run_id=dialog_data.run_id, # 使用 dialog_data 的 run_id
created_at=dialog_data.created_at, created_at=dialog_data.created_at,
) )
@@ -1134,9 +1122,7 @@ class ExtractionOrchestrator:
relation_type=triplet.predicate, relation_type=triplet.predicate,
statement=statement.statement, statement=statement.statement,
source_statement_id=statement.id, source_statement_id=statement.id,
group_id=dialog_data.group_id, end_user_id=dialog_data.end_user_id,
user_id=dialog_data.user_id,
apply_id=dialog_data.apply_id,
run_id=dialog_data.run_id, # 使用 dialog_data 的 run_id run_id=dialog_data.run_id, # 使用 dialog_data 的 run_id
created_at=dialog_data.created_at, created_at=dialog_data.created_at,
expired_at=dialog_data.expired_at, expired_at=dialog_data.expired_at,
@@ -1763,14 +1749,14 @@ class ExtractionOrchestrator:
async def get_chunked_dialogs( async def get_chunked_dialogs(
chunker_strategy: str = "RecursiveChunker", chunker_strategy: str = "RecursiveChunker",
group_id: str = "group_1", end_user_id: str = "group_1",
indices: Optional[List[int]] = None, indices: Optional[List[int]] = None,
) -> List[DialogData]: ) -> List[DialogData]:
"""从测试数据生成分块对话 """从测试数据生成分块对话
Args: Args:
chunker_strategy: 分块策略(默认: RecursiveChunker chunker_strategy: 分块策略(默认: RecursiveChunker
group_id: 组ID end_user_id: 组ID
indices: 要处理的数据索引列表(可选) indices: 要处理的数据索引列表(可选)
Returns: Returns:
@@ -1834,7 +1820,7 @@ async def get_chunked_dialogs(
dialog_data = DialogData( dialog_data = DialogData(
context=conversation_context, context=conversation_context,
ref_id=data['id'], ref_id=data['id'],
group_id=group_id, end_user_id=end_user_id,
metadata=dialog_metadata, metadata=dialog_metadata,
) )
@@ -1936,7 +1922,7 @@ async def get_chunked_dialogs_from_preprocessed(
async def get_chunked_dialogs_with_preprocessing( async def get_chunked_dialogs_with_preprocessing(
chunker_strategy: str = "RecursiveChunker", chunker_strategy: str = "RecursiveChunker",
group_id: str = "default", end_user_id: str = "default",
user_id: str = "default", user_id: str = "default",
apply_id: str = "default", apply_id: str = "default",
indices: Optional[List[int]] = None, indices: Optional[List[int]] = None,
@@ -1948,7 +1934,7 @@ async def get_chunked_dialogs_with_preprocessing(
Args: Args:
chunker_strategy: 分块策略 chunker_strategy: 分块策略
group_id: 组ID end_user_id: 组ID
user_id: 用户ID user_id: 用户ID
apply_id: 应用ID apply_id: 应用ID
indices: 要处理的数据索引列表 indices: 要处理的数据索引列表
@@ -1976,11 +1962,9 @@ async def get_chunked_dialogs_with_preprocessing(
indices=indices, indices=indices,
) )
# 设置 group_id, user_id, apply_id # 设置 end_user_id
for dd in preprocessed_data: for dd in preprocessed_data:
dd.group_id = group_id dd.end_user_id = end_user_id
dd.user_id = user_id
dd.apply_id = apply_id
# 步骤2: 语义剪枝 # 步骤2: 语义剪枝
try: try:

View File

@@ -193,9 +193,9 @@ async def _process_chunk_summary(
node = MemorySummaryNode( node = MemorySummaryNode(
id=uuid4().hex, id=uuid4().hex,
name=title if title else f"MemorySummaryChunk_{chunk.id}", name=title if title else f"MemorySummaryChunk_{chunk.id}",
group_id=dialog.group_id, end_user_id=dialog.end_user_id,
user_id=dialog.user_id, user_id=dialog.end_user_id,
apply_id=dialog.apply_id, apply_id=dialog.end_user_id,
run_id=dialog.run_id, # 使用 dialog 的 run_id run_id=dialog.run_id, # 使用 dialog 的 run_id
created_at=datetime.now(), created_at=datetime.now(),
expired_at=datetime(9999, 12, 31), expired_at=datetime(9999, 12, 31),

View File

@@ -82,12 +82,12 @@ class StatementExtractor:
logger.warning(f"Chunk {getattr(chunk, 'id', 'unknown')} has no speaker field or is empty") logger.warning(f"Chunk {getattr(chunk, 'id', 'unknown')} has no speaker field or is empty")
return None return None
async def _extract_statements(self, chunk, group_id: Optional[str] = None, dialogue_content: str = None) -> List[Statement]: async def _extract_statements(self, chunk, end_user_id: Optional[str] = None, dialogue_content: str = None) -> List[Statement]:
"""Process a single chunk and return extracted statements """Process a single chunk and return extracted statements
Args: Args:
chunk: Chunk object to process chunk: Chunk object to process
group_id: Group ID to assign to all statements in this chunk end_user_id: Group ID to assign to all statements in this chunk
dialogue_content: Full dialogue content to provide as context dialogue_content: Full dialogue content to provide as context
Returns: Returns:
@@ -158,7 +158,7 @@ class StatementExtractor:
temporal_info=temporal_type, temporal_info=temporal_type,
relevence_info=relevence_info, relevence_info=relevence_info,
chunk_id=chunk.id, chunk_id=chunk.id,
group_id=group_id, end_user_id=end_user_id,
speaker=chunk_speaker, speaker=chunk_speaker,
) )
@@ -184,10 +184,10 @@ class StatementExtractor:
logger.info(f"Processing {len(chunks_to_process)} chunks for statement extraction") logger.info(f"Processing {len(chunks_to_process)} chunks for statement extraction")
# Process all chunks concurrently, passing the group_id and dialogue content from dialog_data # Process all chunks concurrently, passing the end_user_id and dialogue content from dialog_data
dialogue_content = dialog_data.content if self.config.include_dialogue_context else None dialogue_content = dialog_data.content if self.config.include_dialogue_context else None
results = await asyncio.gather( results = await asyncio.gather(
*[self._extract_statements(chunk, dialog_data.group_id, dialogue_content) for chunk in chunks_to_process], *[self._extract_statements(chunk, dialog_data.end_user_id, dialogue_content) for chunk in chunks_to_process],
return_exceptions=True return_exceptions=True
) )
@@ -225,7 +225,7 @@ class StatementExtractor:
for i, statement in enumerate(statements, 1): for i, statement in enumerate(statements, 1):
f.write(f"Statement {i}:\n") f.write(f"Statement {i}:\n")
f.write(f"Id: {statement.id}\n") f.write(f"Id: {statement.id}\n")
f.write(f"Group Id: {statement.group_id}\n") f.write(f"Group Id: {statement.end_user_id}\n")
f.write(f"Content: {statement.statement}\n") f.write(f"Content: {statement.statement}\n")
f.write(f"Type: {statement.stmt_type.value}\n") f.write(f"Type: {statement.stmt_type.value}\n")
f.write(f"Temporal Info: {statement.temporal_info.value}\n") f.write(f"Temporal Info: {statement.temporal_info.value}\n")
@@ -298,7 +298,7 @@ class StatementExtractor:
dialog_sections.append({ dialog_sections.append({
"dialog_id": dialog.ref_id, "dialog_id": dialog.ref_id,
"group_id": dialog.group_id, "end_user_id": dialog.end_user_id,
"content": dialog.content if getattr(dialog, "content", None) else "", "content": dialog.content if getattr(dialog, "content", None) else "",
"strong": strong_relations, "strong": strong_relations,
"weak": weak_relations, "weak": weak_relations,
@@ -312,7 +312,7 @@ class StatementExtractor:
for idx, section in enumerate(dialog_sections, 1): for idx, section in enumerate(dialog_sections, 1):
f.write(f"Dialog {idx}:\n") f.write(f"Dialog {idx}:\n")
f.write(f"Dialog ID: {section.get('dialog_id', '')}\n") f.write(f"Dialog ID: {section.get('dialog_id', '')}\n")
f.write(f"Group ID: {section.get('group_id', '')}\n") f.write(f"Group ID: {section.get('end_user_id', '')}\n")
f.write("Content:\n") f.write("Content:\n")
f.write(f"{section.get('content', '')}\n") f.write(f"{section.get('content', '')}\n")
f.write("-" * 40 + "\n\n") f.write("-" * 40 + "\n\n")

View File

@@ -132,7 +132,7 @@ class TemporalExtractor:
prompt_logger.info("") prompt_logger.info("")
prompt_logger.info("=== TEMPORAL EXTRACTION RESULTS ===") prompt_logger.info("=== TEMPORAL EXTRACTION RESULTS ===")
prompt_logger.info( prompt_logger.info(
f"[Temporal] Dialog ref_id={getattr(dialog_data, 'ref_id', None)}, group_id={getattr(dialog_data, 'group_id', None)}" f"[Temporal] Dialog ref_id={getattr(dialog_data, 'ref_id', None)}, end_user_id={getattr(dialog_data, 'end_user_id', None)}"
) )
except Exception: except Exception:
pass pass

View File

@@ -116,7 +116,7 @@ class TripletExtractor:
logger.info(f"Processing {len(all_statements)} statements for triplet extraction...") logger.info(f"Processing {len(all_statements)} statements for triplet extraction...")
try: try:
prompt_logger.info( prompt_logger.info(
f"[Triplet] Dialog ref_id={getattr(dialog_data, 'ref_id', None)}, group_id={getattr(dialog_data, 'group_id', None)}, statements_to_process={len(all_statements)}" f"[Triplet] Dialog ref_id={getattr(dialog_data, 'ref_id', None)}, end_user_id={getattr(dialog_data, 'end_user_id', None)}, statements_to_process={len(all_statements)}"
) )
except Exception: except Exception:
pass pass

View File

@@ -75,7 +75,7 @@ class AccessHistoryManager:
self, self,
node_id: str, node_id: str,
node_label: str, node_label: str,
group_id: Optional[str] = None, end_user_id: Optional[str] = None,
current_time: Optional[datetime] = None current_time: Optional[datetime] = None
) -> Dict[str, Any]: ) -> Dict[str, Any]:
""" """
@@ -91,7 +91,7 @@ class AccessHistoryManager:
Args: Args:
node_id: 节点ID node_id: 节点ID
node_label: 节点标签Statement, ExtractedEntity, MemorySummary node_label: 节点标签Statement, ExtractedEntity, MemorySummary
group_id: 组ID可选用于过滤 end_user_id: 组ID可选用于过滤
current_time: 当前时间(可选,默认使用系统时间) current_time: 当前时间(可选,默认使用系统时间)
Returns: Returns:
@@ -123,7 +123,7 @@ class AccessHistoryManager:
for attempt in range(self.max_retries): for attempt in range(self.max_retries):
try: try:
# 步骤1读取当前节点状态 # 步骤1读取当前节点状态
node_data = await self._fetch_node(node_id, node_label, group_id) node_data = await self._fetch_node(node_id, node_label, end_user_id)
if not node_data: if not node_data:
raise ValueError( raise ValueError(
@@ -142,7 +142,7 @@ class AccessHistoryManager:
node_id=node_id, node_id=node_id,
node_label=node_label, node_label=node_label,
update_data=update_data, update_data=update_data,
group_id=group_id end_user_id=end_user_id
) )
logger.info( logger.info(
@@ -172,7 +172,7 @@ class AccessHistoryManager:
self, self,
node_ids: List[str], node_ids: List[str],
node_label: str, node_label: str,
group_id: Optional[str] = None, end_user_id: Optional[str] = None,
current_time: Optional[datetime] = None current_time: Optional[datetime] = None
) -> List[Dict[str, Any]]: ) -> List[Dict[str, Any]]:
""" """
@@ -184,7 +184,7 @@ class AccessHistoryManager:
Args: Args:
node_ids: 节点ID列表 node_ids: 节点ID列表
node_label: 节点标签(所有节点必须是同一类型) node_label: 节点标签(所有节点必须是同一类型)
group_id: 组ID可选 end_user_id: 组ID可选
current_time: 当前时间(可选) current_time: 当前时间(可选)
Returns: Returns:
@@ -202,7 +202,7 @@ class AccessHistoryManager:
task = self.record_access( task = self.record_access(
node_id=node_id, node_id=node_id,
node_label=node_label, node_label=node_label,
group_id=group_id, end_user_id=end_user_id,
current_time=current_time current_time=current_time
) )
tasks.append(task) tasks.append(task)
@@ -235,7 +235,7 @@ class AccessHistoryManager:
self, self,
node_id: str, node_id: str,
node_label: str, node_label: str,
group_id: Optional[str] = None end_user_id: Optional[str] = None
) -> Tuple[ConsistencyCheckResult, Optional[str]]: ) -> Tuple[ConsistencyCheckResult, Optional[str]]:
""" """
检查节点数据的一致性 检查节点数据的一致性
@@ -249,14 +249,14 @@ class AccessHistoryManager:
Args: Args:
node_id: 节点ID node_id: 节点ID
node_label: 节点标签 node_label: 节点标签
group_id: 组ID可选 end_user_id: 组ID可选
Returns: Returns:
Tuple[ConsistencyCheckResult, Optional[str]]: Tuple[ConsistencyCheckResult, Optional[str]]:
- 一致性检查结果枚举 - 一致性检查结果枚举
- 错误描述(如果不一致) - 错误描述(如果不一致)
""" """
node_data = await self._fetch_node(node_id, node_label, group_id) node_data = await self._fetch_node(node_id, node_label, end_user_id)
if not node_data: if not node_data:
return ConsistencyCheckResult.CONSISTENT, None return ConsistencyCheckResult.CONSISTENT, None
@@ -305,7 +305,7 @@ class AccessHistoryManager:
async def check_batch_consistency( async def check_batch_consistency(
self, self,
node_label: str, node_label: str,
group_id: Optional[str] = None, end_user_id: Optional[str] = None,
limit: int = 1000 limit: int = 1000
) -> Dict[str, Any]: ) -> Dict[str, Any]:
""" """
@@ -313,7 +313,7 @@ class AccessHistoryManager:
Args: Args:
node_label: 节点标签 node_label: 节点标签
group_id: 组ID可选 end_user_id: 组ID可选
limit: 检查的最大节点数 limit: 检查的最大节点数
Returns: Returns:
@@ -329,16 +329,16 @@ class AccessHistoryManager:
MATCH (n:{node_label}) MATCH (n:{node_label})
WHERE n.access_history IS NOT NULL WHERE n.access_history IS NOT NULL
""" """
if group_id: if end_user_id:
query += " AND n.group_id = $group_id" query += " AND n.end_user_id = $end_user_id"
query += """ query += """
RETURN n.id as id RETURN n.id as id
LIMIT $limit LIMIT $limit
""" """
params = {"limit": limit} params = {"limit": limit}
if group_id: if end_user_id:
params["group_id"] = group_id params["end_user_id"] = end_user_id
results = await self.connector.execute_query(query, **params) results = await self.connector.execute_query(query, **params)
node_ids = [r['id'] for r in results] node_ids = [r['id'] for r in results]
@@ -351,7 +351,7 @@ class AccessHistoryManager:
result, message = await self.check_consistency( result, message = await self.check_consistency(
node_id=node_id, node_id=node_id,
node_label=node_label, node_label=node_label,
group_id=group_id end_user_id=end_user_id
) )
if result == ConsistencyCheckResult.CONSISTENT: if result == ConsistencyCheckResult.CONSISTENT:
@@ -387,7 +387,7 @@ class AccessHistoryManager:
self, self,
node_id: str, node_id: str,
node_label: str, node_label: str,
group_id: Optional[str] = None end_user_id: Optional[str] = None
) -> bool: ) -> bool:
""" """
自动修复节点的数据不一致问题 自动修复节点的数据不一致问题
@@ -401,7 +401,7 @@ class AccessHistoryManager:
Args: Args:
node_id: 节点ID node_id: 节点ID
node_label: 节点标签 node_label: 节点标签
group_id: 组ID可选 end_user_id: 组ID可选
Returns: Returns:
bool: 修复成功返回True否则返回False bool: 修复成功返回True否则返回False
@@ -411,7 +411,7 @@ class AccessHistoryManager:
result, message = await self.check_consistency( result, message = await self.check_consistency(
node_id=node_id, node_id=node_id,
node_label=node_label, node_label=node_label,
group_id=group_id end_user_id=end_user_id
) )
if result == ConsistencyCheckResult.CONSISTENT: if result == ConsistencyCheckResult.CONSISTENT:
@@ -419,7 +419,7 @@ class AccessHistoryManager:
return True return True
# 获取节点数据 # 获取节点数据
node_data = await self._fetch_node(node_id, node_label, group_id) node_data = await self._fetch_node(node_id, node_label, end_user_id)
if not node_data: if not node_data:
logger.error(f"节点不存在,无法修复: {node_label}[{node_id}]") logger.error(f"节点不存在,无法修复: {node_label}[{node_id}]")
return False return False
@@ -457,8 +457,8 @@ class AccessHistoryManager:
query = f""" query = f"""
MATCH (n:{node_label} {{id: $node_id}}) MATCH (n:{node_label} {{id: $node_id}})
""" """
if group_id: if end_user_id:
query += " WHERE n.group_id = $group_id" query += " WHERE n.end_user_id = $end_user_id"
query += """ query += """
SET n += $repair_data SET n += $repair_data
RETURN n RETURN n
@@ -468,8 +468,8 @@ class AccessHistoryManager:
'node_id': node_id, 'node_id': node_id,
'repair_data': repair_data 'repair_data': repair_data
} }
if group_id: if end_user_id:
params['group_id'] = group_id params['end_user_id'] = end_user_id
await self.connector.execute_query(query, **params) await self.connector.execute_query(query, **params)
@@ -491,7 +491,7 @@ class AccessHistoryManager:
self, self,
node_id: str, node_id: str,
node_label: str, node_label: str,
group_id: Optional[str] = None end_user_id: Optional[str] = None
) -> Optional[Dict[str, Any]]: ) -> Optional[Dict[str, Any]]:
""" """
获取节点数据 获取节点数据
@@ -499,7 +499,7 @@ class AccessHistoryManager:
Args: Args:
node_id: 节点ID node_id: 节点ID
node_label: 节点标签 node_label: 节点标签
group_id: 组ID可选 end_user_id: 组ID可选
Returns: Returns:
Optional[Dict[str, Any]]: 节点数据如果不存在返回None Optional[Dict[str, Any]]: 节点数据如果不存在返回None
@@ -507,8 +507,8 @@ class AccessHistoryManager:
query = f""" query = f"""
MATCH (n:{node_label} {{id: $node_id}}) MATCH (n:{node_label} {{id: $node_id}})
""" """
if group_id: if end_user_id:
query += " WHERE n.group_id = $group_id" query += " WHERE n.end_user_id = $end_user_id"
query += """ query += """
RETURN n.id as id, RETURN n.id as id,
n.importance_score as importance_score, n.importance_score as importance_score,
@@ -519,8 +519,8 @@ class AccessHistoryManager:
""" """
params = {'node_id': node_id} params = {'node_id': node_id}
if group_id: if end_user_id:
params['group_id'] = group_id params['end_user_id'] = end_user_id
results = await self.connector.execute_query(query, **params) results = await self.connector.execute_query(query, **params)
@@ -585,7 +585,7 @@ class AccessHistoryManager:
node_id: str, node_id: str,
node_label: str, node_label: str,
update_data: Dict[str, Any], update_data: Dict[str, Any],
group_id: Optional[str] = None end_user_id: Optional[str] = None
) -> Dict[str, Any]: ) -> Dict[str, Any]:
""" """
原子性更新节点(使用乐观锁) 原子性更新节点(使用乐观锁)
@@ -597,7 +597,7 @@ class AccessHistoryManager:
node_id: 节点ID node_id: 节点ID
node_label: 节点标签 node_label: 节点标签
update_data: 更新数据 update_data: 更新数据
group_id: 组ID可选 end_user_id: 组ID可选
Returns: Returns:
Dict[str, Any]: 更新后的节点数据 Dict[str, Any]: 更新后的节点数据
@@ -606,13 +606,13 @@ class AccessHistoryManager:
RuntimeError: 如果更新失败或发生版本冲突 RuntimeError: 如果更新失败或发生版本冲突
""" """
# 定义事务函数 # 定义事务函数
async def update_transaction(tx, node_id, node_label, update_data, group_id): async def update_transaction(tx, node_id, node_label, update_data, end_user_id):
# 步骤1读取当前节点并获取版本号 # 步骤1读取当前节点并获取版本号
read_query = f""" read_query = f"""
MATCH (n:{node_label} {{id: $node_id}}) MATCH (n:{node_label} {{id: $node_id}})
""" """
if group_id: if end_user_id:
read_query += " WHERE n.group_id = $group_id" read_query += " WHERE n.end_user_id = $end_user_id"
read_query += """ read_query += """
RETURN n.id as id, RETURN n.id as id,
n.version as version, n.version as version,
@@ -624,8 +624,8 @@ class AccessHistoryManager:
""" """
read_params = {'node_id': node_id} read_params = {'node_id': node_id}
if group_id: if end_user_id:
read_params['group_id'] = group_id read_params['end_user_id'] = end_user_id
read_result = await tx.run(read_query, **read_params) read_result = await tx.run(read_query, **read_params)
current_node = await read_result.single() current_node = await read_result.single()
@@ -656,8 +656,8 @@ class AccessHistoryManager:
# 构建 WHERE 子句 # 构建 WHERE 子句
where_conditions = [] where_conditions = []
if group_id: if end_user_id:
where_conditions.append("n.group_id = $group_id") where_conditions.append("n.end_user_id = $end_user_id")
# 添加版本检查 # 添加版本检查
if current_version > 0: if current_version > 0:
@@ -695,8 +695,8 @@ class AccessHistoryManager:
'last_access_time': update_data['last_access_time'], 'last_access_time': update_data['last_access_time'],
'access_count': update_data['access_count'] 'access_count': update_data['access_count']
} }
if group_id: if end_user_id:
update_params['group_id'] = group_id update_params['end_user_id'] = end_user_id
update_result = await tx.run(update_query, **update_params) update_result = await tx.run(update_query, **update_params)
updated_node = await update_result.single() updated_node = await update_result.single()
@@ -720,7 +720,7 @@ class AccessHistoryManager:
node_id=node_id, node_id=node_id,
node_label=node_label, node_label=node_label,
update_data=update_data, update_data=update_data,
group_id=group_id end_user_id=end_user_id
) )
return result return result
except Exception as e: except Exception as e:

View File

@@ -11,9 +11,10 @@ Functions:
import logging import logging
from typing import Optional, Dict, Any from typing import Optional, Dict, Any
from uuid import UUID
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from app.repositories.data_config_repository import DataConfigRepository from app.repositories.memory_config_repository import MemoryConfigRepository
from app.core.memory.storage_services.forgetting_engine.actr_calculator import ACTRCalculator from app.core.memory.storage_services.forgetting_engine.actr_calculator import ACTRCalculator
@@ -61,12 +62,12 @@ def calculate_forgetting_rate(lambda_time: float, lambda_mem: float) -> float:
def load_actr_config_from_db( def load_actr_config_from_db(
db: Session, db: Session,
config_id: Optional[int] = None config_id: Optional[UUID] = None
) -> Dict[str, Any]: ) -> Dict[str, Any]:
""" """
从数据库加载 ACT-R 配置参数 从数据库加载 ACT-R 配置参数
从 PostgreSQL 的 data_config 表读取配置参数, 从 PostgreSQL 的 memory_config 表读取配置参数,
并计算派生参数(如 forgetting_rate 并计算派生参数(如 forgetting_rate
Args: Args:
@@ -99,7 +100,7 @@ def load_actr_config_from_db(
# 从数据库加载配置 # 从数据库加载配置
try: try:
repository = DataConfigRepository() repository = MemoryConfigRepository()
db_config = repository.get_by_id(db, config_id) db_config = repository.get_by_id(db, config_id)
if db_config is None: if db_config is None:
@@ -150,7 +151,7 @@ def load_actr_config_from_db(
def create_actr_calculator_from_config( def create_actr_calculator_from_config(
db: Session, db: Session,
config_id: Optional[int] = None config_id: Optional[UUID] = None
) -> ACTRCalculator: ) -> ACTRCalculator:
""" """
从数据库配置创建 ACTRCalculator 实例 从数据库配置创建 ACTRCalculator 实例
@@ -168,11 +169,6 @@ def create_actr_calculator_from_config(
ValueError: 如果指定的 config_id 不存在 ValueError: 如果指定的 config_id 不存在
Examples: Examples:
>>> from sqlalchemy.orm import Session
>>> db = Session()
>>> calculator = create_actr_calculator_from_config(db, config_id=1)
>>> # 使用计算器
>>> activation = calculator.calculate_memory_activation(...)
""" """
# 加载配置 # 加载配置
config = load_actr_config_from_db(db, config_id) config = load_actr_config_from_db(db, config_id)

View File

@@ -16,6 +16,7 @@ Classes:
import logging import logging
from typing import Dict, Any, Optional from typing import Dict, Any, Optional
from uuid import UUID
from datetime import datetime from datetime import datetime
from app.core.memory.storage_services.forgetting_engine.forgetting_strategy import ForgettingStrategy from app.core.memory.storage_services.forgetting_engine.forgetting_strategy import ForgettingStrategy
@@ -66,10 +67,10 @@ class ForgettingScheduler:
async def run_forgetting_cycle( async def run_forgetting_cycle(
self, self,
group_id: Optional[str] = None, end_user_id: Optional[str] = None,
max_merge_batch_size: int = 100, max_merge_batch_size: int = 100,
min_days_since_access: int = 30, min_days_since_access: int = 30,
config_id: Optional[int] = None, config_id: Optional[UUID] = None,
db = None db = None
) -> Dict[str, Any]: ) -> Dict[str, Any]:
""" """
@@ -77,7 +78,7 @@ class ForgettingScheduler:
Args: Args:
group_id: 组 ID可选用于过滤特定组的节点 end_user_id: 组 ID可选用于过滤特定组的节点
max_merge_batch_size: 单次最大融合节点对数(默认 100 max_merge_batch_size: 单次最大融合节点对数(默认 100
min_days_since_access: 最小未访问天数(默认 30 天) min_days_since_access: 最小未访问天数(默认 30 天)
config_id: 配置ID可选用于获取 llm_id config_id: 配置ID可选用于获取 llm_id
@@ -107,19 +108,19 @@ class ForgettingScheduler:
start_time_iso = start_time.isoformat() start_time_iso = start_time.isoformat()
logger.info( logger.info(
f"开始遗忘周期: group_id={group_id}, " f"开始遗忘周期: end_user_id={end_user_id}, "
f"max_batch={max_merge_batch_size}, " f"max_batch={max_merge_batch_size}, "
f"min_days={min_days_since_access}" f"min_days={min_days_since_access}"
) )
try: try:
# 步骤1统计遗忘前的节点数量 # 步骤1统计遗忘前的节点数量
nodes_before = await self._count_knowledge_nodes(group_id) nodes_before = await self._count_knowledge_nodes(end_user_id)
logger.info(f"遗忘前节点总数: {nodes_before}") logger.info(f"遗忘前节点总数: {nodes_before}")
# 步骤2识别可遗忘的节点对 # 步骤2识别可遗忘的节点对
forgettable_pairs = await self.forgetting_strategy.find_forgettable_nodes( forgettable_pairs = await self.forgetting_strategy.find_forgettable_nodes(
group_id=group_id, end_user_id=end_user_id,
min_days_since_access=min_days_since_access min_days_since_access=min_days_since_access
) )
@@ -213,7 +214,7 @@ class ForgettingScheduler:
'statement_text': pair['statement_text'], 'statement_text': pair['statement_text'],
'statement_activation': pair['statement_activation'], 'statement_activation': pair['statement_activation'],
'statement_importance': pair['statement_importance'], 'statement_importance': pair['statement_importance'],
'group_id': group_id 'end_user_id': end_user_id
} }
entity_node = { entity_node = {
@@ -222,7 +223,7 @@ class ForgettingScheduler:
'entity_type': pair['entity_type'], 'entity_type': pair['entity_type'],
'entity_activation': pair['entity_activation'], 'entity_activation': pair['entity_activation'],
'entity_importance': pair['entity_importance'], 'entity_importance': pair['entity_importance'],
'group_id': group_id 'end_user_id': end_user_id
} }
# 融合节点 # 融合节点
@@ -262,7 +263,7 @@ class ForgettingScheduler:
continue continue
# 步骤6统计遗忘后的节点数量 # 步骤6统计遗忘后的节点数量
nodes_after = await self._count_knowledge_nodes(group_id) nodes_after = await self._count_knowledge_nodes(end_user_id)
logger.info(f"遗忘后节点总数: {nodes_after}") logger.info(f"遗忘后节点总数: {nodes_after}")
# 步骤7生成遗忘报告 # 步骤7生成遗忘报告
@@ -315,7 +316,7 @@ class ForgettingScheduler:
async def _count_knowledge_nodes( async def _count_knowledge_nodes(
self, self,
group_id: Optional[str] = None end_user_id: Optional[str] = None
) -> int: ) -> int:
""" """
统计知识层节点总数 统计知识层节点总数
@@ -323,7 +324,7 @@ class ForgettingScheduler:
统计 Statement、ExtractedEntity 和 MemorySummary 节点的总数。 统计 Statement、ExtractedEntity 和 MemorySummary 节点的总数。
Args: Args:
group_id: 组 ID可选用于过滤特定组的节点 end_user_id: 组 ID可选用于过滤特定组的节点
Returns: Returns:
int: 知识层节点总数 int: 知识层节点总数
@@ -333,16 +334,16 @@ class ForgettingScheduler:
WHERE (n:Statement OR n:ExtractedEntity OR n:MemorySummary) WHERE (n:Statement OR n:ExtractedEntity OR n:MemorySummary)
""" """
if group_id: if end_user_id:
query += " AND n.group_id = $group_id" query += " AND n.end_user_id = $end_user_id"
query += """ query += """
RETURN count(n) as total RETURN count(n) as total
""" """
params = {} params = {}
if group_id: if end_user_id:
params['group_id'] = group_id end_user_id['end_user_id'] = end_user_id
results = await self.connector.execute_query(query, **params) results = await self.connector.execute_query(query, **params)

View File

@@ -13,6 +13,7 @@ Classes:
import logging import logging
from typing import List, Dict, Any, Optional from typing import List, Dict, Any, Optional
from uuid import UUID
from datetime import datetime, timedelta from datetime import datetime, timedelta
from app.repositories.neo4j.neo4j_connector import Neo4jConnector from app.repositories.neo4j.neo4j_connector import Neo4jConnector
@@ -90,7 +91,7 @@ class ForgettingStrategy:
async def find_forgettable_nodes( async def find_forgettable_nodes(
self, self,
group_id: Optional[str] = None, end_user_id: Optional[str] = None,
min_days_since_access: int = 30 min_days_since_access: int = 30
) -> List[Dict[str, Any]]: ) -> List[Dict[str, Any]]:
""" """
@@ -102,7 +103,7 @@ class ForgettingStrategy:
3. Statement 和 Entity 之间存在关系边 3. Statement 和 Entity 之间存在关系边
Args: Args:
group_id: 组 ID可选用于过滤特定组的节点 end_user_id: 组 ID可选用于过滤特定组的节点
min_days_since_access: 最小未访问天数(默认 30 天) min_days_since_access: 最小未访问天数(默认 30 天)
Returns: Returns:
@@ -136,8 +137,8 @@ class ForgettingStrategy:
AND (e.entity_type IS NULL OR e.entity_type <> 'Person') AND (e.entity_type IS NULL OR e.entity_type <> 'Person')
""" """
if group_id: if end_user_id:
query += " AND s.group_id = $group_id AND e.group_id = $group_id" query += " AND s.end_user_id = $end_user_id AND e.end_user_id = $end_user_id"
query += """ query += """
RETURN s.id as statement_id, RETURN s.id as statement_id,
@@ -159,8 +160,8 @@ class ForgettingStrategy:
'threshold': self.forgetting_threshold, 'threshold': self.forgetting_threshold,
'cutoff_time': cutoff_time_iso 'cutoff_time': cutoff_time_iso
} }
if group_id: if end_user_id:
params['group_id'] = group_id params['end_user_id'] = end_user_id
results = await self.connector.execute_query(query, **params) results = await self.connector.execute_query(query, **params)
@@ -176,7 +177,7 @@ class ForgettingStrategy:
self, self,
statement_node: Dict[str, Any], statement_node: Dict[str, Any],
entity_node: Dict[str, Any], entity_node: Dict[str, Any],
config_id: Optional[int] = None, config_id: Optional[UUID] = None,
db = None db = None
) -> str: ) -> str:
""" """
@@ -247,8 +248,8 @@ class ForgettingStrategy:
entity_activation = entity_node['entity_activation'] entity_activation = entity_node['entity_activation']
entity_importance = entity_node['entity_importance'] entity_importance = entity_node['entity_importance']
# 获取 group_id从 statement 或 entity 节点) # 获取 end_user_id从 statement 或 entity 节点)
group_id = statement_node.get('group_id') or entity_node.get('group_id') end_user_id = statement_node.get('end_user_id') or entity_node.get('end_user_id')
# 生成摘要内容 # 生成摘要内容
summary_text = await self._generate_summary( summary_text = await self._generate_summary(
@@ -325,7 +326,7 @@ class ForgettingStrategy:
last_access_time: $current_time, last_access_time: $current_time,
access_count: 1, access_count: 1,
version: 1, version: 1,
group_id: $group_id, end_user_id: $end_user_id,
created_at: datetime($current_time), created_at: datetime($current_time),
merged_at: datetime($current_time) merged_at: datetime($current_time)
}) })
@@ -423,7 +424,7 @@ class ForgettingStrategy:
'inherited_activation': inherited_activation, 'inherited_activation': inherited_activation,
'inherited_importance': inherited_importance, 'inherited_importance': inherited_importance,
'current_time': current_time_iso, 'current_time': current_time_iso,
'group_id': group_id 'end_user_id': end_user_id
} }
try: try:
@@ -462,7 +463,7 @@ class ForgettingStrategy:
statement_text: str, statement_text: str,
entity_name: str, entity_name: str,
entity_type: str, entity_type: str,
config_id: Optional[int] = None, config_id: Optional[UUID] = None,
db = None db = None
) -> str: ) -> str:
""" """
@@ -527,7 +528,7 @@ class ForgettingStrategy:
statement_text, entity_name, entity_type statement_text, entity_name, entity_type
) )
async def _get_llm_client(self, db, config_id: int): async def _get_llm_client(self, db, config_id: UUID):
""" """
从数据库获取 LLM 客户端 从数据库获取 LLM 客户端
@@ -539,11 +540,11 @@ class ForgettingStrategy:
LLM 客户端实例,如果无法获取则返回 None LLM 客户端实例,如果无法获取则返回 None
""" """
try: try:
from app.repositories.data_config_repository import DataConfigRepository from app.repositories.memory_config_repository import MemoryConfigRepository
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
# 从数据库读取配置 # 从数据库读取配置
repository = DataConfigRepository() repository = MemoryConfigRepository()
db_config = repository.get_by_id(db, config_id) db_config = repository.get_by_id(db, config_id)
if db_config is None or db_config.llm_id is None: if db_config is None or db_config.llm_id is None:

View File

@@ -37,7 +37,7 @@ __all__ = [
async def run_hybrid_search( async def run_hybrid_search(
query_text: str, query_text: str,
search_type: str = "hybrid", search_type: str = "hybrid",
group_id: str | None = None, end_user_id: str | None = None,
apply_id: str | None = None, apply_id: str | None = None,
user_id: str | None = None, user_id: str | None = None,
limit: int = 50, limit: int = 50,
@@ -54,7 +54,7 @@ async def run_hybrid_search(
Args: Args:
query_text: 查询文本 query_text: 查询文本
search_type: 搜索类型("hybrid", "keyword", "semantic" search_type: 搜索类型("hybrid", "keyword", "semantic"
group_id: 组ID过滤 end_user_id: 组ID过滤
apply_id: 应用ID过滤 apply_id: 应用ID过滤
user_id: 用户ID过滤 user_id: 用户ID过滤
limit: 每个类别的最大结果数 limit: 每个类别的最大结果数
@@ -104,7 +104,7 @@ async def run_hybrid_search(
# 执行搜索 # 执行搜索
result = await strategy.search( result = await strategy.search(
query_text=query_text, query_text=query_text,
group_id=group_id, end_user_id=end_user_id,
limit=limit, limit=limit,
include=include, include=include,
alpha=alpha, alpha=alpha,

View File

@@ -77,7 +77,7 @@
# async def search( # async def search(
# self, # self,
# query_text: str, # query_text: str,
# group_id: Optional[str] = None, # end_user_id: Optional[str] = None,
# limit: int = 50, # limit: int = 50,
# include: Optional[List[str]] = None, # include: Optional[List[str]] = None,
# **kwargs # **kwargs
@@ -86,7 +86,7 @@
# Args: # Args:
# query_text: 查询文本 # query_text: 查询文本
# group_id: 可选的组ID过滤 # end_user_id: 可选的组ID过滤
# limit: 每个类别的最大结果数 # limit: 每个类别的最大结果数
# include: 要包含的搜索类别列表 # include: 要包含的搜索类别列表
# **kwargs: 其他搜索参数如alpha, use_forgetting_curve # **kwargs: 其他搜索参数如alpha, use_forgetting_curve
@@ -94,7 +94,7 @@
# Returns: # Returns:
# SearchResult: 搜索结果对象 # SearchResult: 搜索结果对象
# """ # """
# logger.info(f"执行混合搜索: query='{query_text}', group_id={group_id}, limit={limit}") # logger.info(f"执行混合搜索: query='{query_text}', end_user_id={end_user_id}, limit={limit}")
# # 从kwargs中获取参数 # # 从kwargs中获取参数
# alpha = kwargs.get("alpha", self.alpha) # alpha = kwargs.get("alpha", self.alpha)
@@ -107,14 +107,14 @@
# # 并行执行关键词搜索和语义搜索 # # 并行执行关键词搜索和语义搜索
# keyword_result = await self.keyword_strategy.search( # keyword_result = await self.keyword_strategy.search(
# query_text=query_text, # query_text=query_text,
# group_id=group_id, # end_user_id=end_user_id,
# limit=limit, # limit=limit,
# include=include_list # include=include_list
# ) # )
# semantic_result = await self.semantic_strategy.search( # semantic_result = await self.semantic_strategy.search(
# query_text=query_text, # query_text=query_text,
# group_id=group_id, # end_user_id=end_user_id,
# limit=limit, # limit=limit,
# include=include_list # include=include_list
# ) # )
@@ -139,7 +139,7 @@
# metadata = self._create_metadata( # metadata = self._create_metadata(
# query_text=query_text, # query_text=query_text,
# search_type="hybrid", # search_type="hybrid",
# group_id=group_id, # end_user_id=end_user_id,
# limit=limit, # limit=limit,
# include=include_list, # include=include_list,
# alpha=alpha, # alpha=alpha,
@@ -165,7 +165,7 @@
# metadata=self._create_metadata( # metadata=self._create_metadata(
# query_text=query_text, # query_text=query_text,
# search_type="hybrid", # search_type="hybrid",
# group_id=group_id, # end_user_id=end_user_id,
# limit=limit, # limit=limit,
# error=str(e) # error=str(e)
# ) # )

View File

@@ -44,7 +44,7 @@ class KeywordSearchStrategy(SearchStrategy):
async def search( async def search(
self, self,
query_text: str, query_text: str,
group_id: Optional[str] = None, end_user_id: Optional[str] = None,
limit: int = 50, limit: int = 50,
include: Optional[List[str]] = None, include: Optional[List[str]] = None,
**kwargs **kwargs
@@ -53,7 +53,7 @@ class KeywordSearchStrategy(SearchStrategy):
Args: Args:
query_text: 查询文本 query_text: 查询文本
group_id: 可选的组ID过滤 end_user_id: 可选的组ID过滤
limit: 每个类别的最大结果数 limit: 每个类别的最大结果数
include: 要包含的搜索类别列表 include: 要包含的搜索类别列表
**kwargs: 其他搜索参数 **kwargs: 其他搜索参数
@@ -61,7 +61,7 @@ class KeywordSearchStrategy(SearchStrategy):
Returns: Returns:
SearchResult: 搜索结果对象 SearchResult: 搜索结果对象
""" """
logger.info(f"执行关键词搜索: query='{query_text}', group_id={group_id}, limit={limit}") logger.info(f"执行关键词搜索: query='{query_text}', end_user_id={end_user_id}, limit={limit}")
# 获取有效的搜索类别 # 获取有效的搜索类别
include_list = self._get_include_list(include) include_list = self._get_include_list(include)
@@ -75,7 +75,7 @@ class KeywordSearchStrategy(SearchStrategy):
results_dict = await search_graph( results_dict = await search_graph(
connector=self.connector, connector=self.connector,
q=query_text, q=query_text,
group_id=group_id, end_user_id=end_user_id,
limit=limit, limit=limit,
include=include_list include=include_list
) )
@@ -84,7 +84,7 @@ class KeywordSearchStrategy(SearchStrategy):
metadata = self._create_metadata( metadata = self._create_metadata(
query_text=query_text, query_text=query_text,
search_type="keyword", search_type="keyword",
group_id=group_id, end_user_id=end_user_id,
limit=limit, limit=limit,
include=include_list include=include_list
) )
@@ -115,7 +115,7 @@ class KeywordSearchStrategy(SearchStrategy):
metadata=self._create_metadata( metadata=self._create_metadata(
query_text=query_text, query_text=query_text,
search_type="keyword", search_type="keyword",
group_id=group_id, end_user_id=end_user_id,
limit=limit, limit=limit,
error=str(e) error=str(e)
) )

View File

@@ -58,7 +58,7 @@ class SearchStrategy(ABC):
async def search( async def search(
self, self,
query_text: str, query_text: str,
group_id: Optional[str] = None, end_user_id: Optional[str] = None,
limit: int = 50, limit: int = 50,
include: Optional[List[str]] = None, include: Optional[List[str]] = None,
**kwargs **kwargs
@@ -67,7 +67,7 @@ class SearchStrategy(ABC):
Args: Args:
query_text: 查询文本 query_text: 查询文本
group_id: 可选的组ID过滤 end_user_id: 可选的组ID过滤
limit: 每个类别的最大结果数 limit: 每个类别的最大结果数
include: 要包含的搜索类别列表statements, chunks, entities, summaries include: 要包含的搜索类别列表statements, chunks, entities, summaries
**kwargs: 其他搜索参数 **kwargs: 其他搜索参数
@@ -81,7 +81,7 @@ class SearchStrategy(ABC):
self, self,
query_text: str, query_text: str,
search_type: str, search_type: str,
group_id: Optional[str] = None, end_user_id: Optional[str] = None,
limit: int = 50, limit: int = 50,
**kwargs **kwargs
) -> Dict[str, Any]: ) -> Dict[str, Any]:
@@ -90,7 +90,7 @@ class SearchStrategy(ABC):
Args: Args:
query_text: 查询文本 query_text: 查询文本
search_type: 搜索类型 search_type: 搜索类型
group_id: 组ID end_user_id: 组ID
limit: 结果限制 limit: 结果限制
**kwargs: 其他元数据 **kwargs: 其他元数据
@@ -100,7 +100,7 @@ class SearchStrategy(ABC):
metadata = { metadata = {
"query": query_text, "query": query_text,
"search_type": search_type, "search_type": search_type,
"group_id": group_id, "end_user_id": end_user_id,
"limit": limit, "limit": limit,
"timestamp": datetime.now().isoformat() "timestamp": datetime.now().isoformat()
} }

View File

@@ -85,7 +85,7 @@ class SemanticSearchStrategy(SearchStrategy):
async def search( async def search(
self, self,
query_text: str, query_text: str,
group_id: Optional[str] = None, end_user_id: Optional[str] = None,
limit: int = 50, limit: int = 50,
include: Optional[List[str]] = None, include: Optional[List[str]] = None,
**kwargs **kwargs
@@ -94,7 +94,7 @@ class SemanticSearchStrategy(SearchStrategy):
Args: Args:
query_text: 查询文本 query_text: 查询文本
group_id: 可选的组ID过滤 end_user_id: 可选的组ID过滤
limit: 每个类别的最大结果数 limit: 每个类别的最大结果数
include: 要包含的搜索类别列表 include: 要包含的搜索类别列表
**kwargs: 其他搜索参数 **kwargs: 其他搜索参数
@@ -102,7 +102,7 @@ class SemanticSearchStrategy(SearchStrategy):
Returns: Returns:
SearchResult: 搜索结果对象 SearchResult: 搜索结果对象
""" """
logger.info(f"执行语义搜索: query='{query_text}', group_id={group_id}, limit={limit}") logger.info(f"执行语义搜索: query='{query_text}', end_user_id={end_user_id}, limit={limit}")
# 获取有效的搜索类别 # 获取有效的搜索类别
include_list = self._get_include_list(include) include_list = self._get_include_list(include)
@@ -119,7 +119,7 @@ class SemanticSearchStrategy(SearchStrategy):
connector=self.connector, connector=self.connector,
embedder_client=self.embedder_client, embedder_client=self.embedder_client,
query_text=query_text, query_text=query_text,
group_id=group_id, end_user_id=end_user_id,
limit=limit, limit=limit,
include=include_list include=include_list
) )
@@ -128,7 +128,7 @@ class SemanticSearchStrategy(SearchStrategy):
metadata = self._create_metadata( metadata = self._create_metadata(
query_text=query_text, query_text=query_text,
search_type="semantic", search_type="semantic",
group_id=group_id, end_user_id=end_user_id,
limit=limit, limit=limit,
include=include_list include=include_list
) )
@@ -159,7 +159,7 @@ class SemanticSearchStrategy(SearchStrategy):
metadata=self._create_metadata( metadata=self._create_metadata(
query_text=query_text, query_text=query_text,
search_type="semantic", search_type="semantic",
group_id=group_id, end_user_id=end_user_id,
limit=limit, limit=limit,
error=str(e) error=str(e)
) )

View File

@@ -23,7 +23,7 @@ async def _load_(data: List[Any]) -> List[Dict]:
target_keys = [ target_keys = [
"id", "id",
"statement", "statement",
"group_id", "end_user_id",
"chunk_id", "chunk_id",
"created_at", "created_at",
"expired_at", "expired_at",
@@ -75,7 +75,7 @@ async def get_data(result):
""" """
EXCLUDE_FIELDS = { EXCLUDE_FIELDS = {
"user_id", "user_id",
"group_id", "end_user_id",
"entity_type", "entity_type",
"connect_strength", "connect_strength",
"relationship_type", "relationship_type",

View File

@@ -62,7 +62,7 @@ class ConfigAuditLogger:
self, self,
config_id: str, config_id: str,
user_id: Optional[str] = None, user_id: Optional[str] = None,
group_id: Optional[str] = None, end_user_id: Optional[str] = None,
success: bool = True, success: bool = True,
details: Optional[Dict[str, Any]] = None details: Optional[Dict[str, Any]] = None
): ):
@@ -72,14 +72,14 @@ class ConfigAuditLogger:
Args: Args:
config_id: 配置 ID config_id: 配置 ID
user_id: 用户 ID可选 user_id: 用户 ID可选
group_id: 组 ID可选 end_user_id: 组 ID可选
success: 是否成功 success: 是否成功
details: 详细信息(可选) details: 详细信息(可选)
""" """
result = "SUCCESS" if success else "FAILED" result = "SUCCESS" if success else "FAILED"
msg = ( msg = (
f"CONFIG_LOAD config_id={config_id} " f"CONFIG_LOAD config_id={config_id} "
f"user={user_id or 'N/A'} group={group_id or 'N/A'} " f"user={user_id or 'N/A'} group={end_user_id or 'N/A'} "
f"result={result}" f"result={result}"
) )
if details: if details:
@@ -121,7 +121,7 @@ class ConfigAuditLogger:
self, self,
operation: str, operation: str,
config_id: str, config_id: str,
group_id: str, end_user_id: str,
success: bool = True, success: bool = True,
duration: Optional[float] = None, duration: Optional[float] = None,
error: Optional[str] = None, error: Optional[str] = None,
@@ -133,7 +133,7 @@ class ConfigAuditLogger:
Args: Args:
operation: 操作类型WRITE, READ 等) operation: 操作类型WRITE, READ 等)
config_id: 配置 ID config_id: 配置 ID
group_id: 组 ID end_user_id: 组 ID
success: 是否成功 success: 是否成功
duration: 操作耗时(秒) duration: 操作耗时(秒)
error: 错误信息(可选) error: 错误信息(可选)
@@ -142,7 +142,7 @@ class ConfigAuditLogger:
result = "SUCCESS" if success else "FAILED" result = "SUCCESS" if success else "FAILED"
msg = ( msg = (
f"{operation.upper()} config_id={config_id} " f"{operation.upper()} config_id={config_id} "
f"group={group_id} result={result}" f"group={end_user_id} result={result}"
) )
if duration is not None: if duration is not None:
msg += f" duration={duration:.2f}s" msg += f" duration={duration:.2f}s"

View File

@@ -4,7 +4,7 @@ from enum import StrEnum, auto
class Field(StrEnum): class Field(StrEnum):
CONTENT_KEY = "page_content" CONTENT_KEY = "page_content"
METADATA_KEY = "metadata" METADATA_KEY = "metadata"
GROUP_KEY = "group_id" GROUP_KEY = "end_user_id"
VECTOR = auto() VECTOR = auto()
# Sparse Vector aims to support full text search # Sparse Vector aims to support full text search
SPARSE_VECTOR = auto() SPARSE_VECTOR = auto()

View File

@@ -26,7 +26,7 @@ logger = get_config_logger()
def _parse_model_id(model_id: Union[str, UUID, None], model_type: str, def _parse_model_id(model_id: Union[str, UUID, None], model_type: str,
config_id: Optional[int] = None, workspace_id: Optional[UUID] = None) -> Optional[UUID]: config_id: Optional[UUID] = None, workspace_id: Optional[UUID] = None) -> Optional[UUID]:
"""Parse model ID from string or UUID.""" """Parse model ID from string or UUID."""
if model_id is None: if model_id is None:
return None return None
@@ -59,7 +59,7 @@ def validate_model_exists_and_active(
model_type: str, model_type: str,
db: Session, db: Session,
tenant_id: Optional[UUID] = None, tenant_id: Optional[UUID] = None,
config_id: Optional[int] = None, config_id: Optional[UUID] = None,
workspace_id: Optional[UUID] = None workspace_id: Optional[UUID] = None
) -> tuple[str, bool]: ) -> tuple[str, bool]:
"""Validate that a model exists and is active. """Validate that a model exists and is active.
@@ -166,7 +166,7 @@ def validate_and_resolve_model_id(
db: Session, db: Session,
tenant_id: Optional[UUID] = None, tenant_id: Optional[UUID] = None,
required: bool = False, required: bool = False,
config_id: Optional[int] = None, config_id: Optional[UUID] = None,
workspace_id: Optional[UUID] = None workspace_id: Optional[UUID] = None
) -> tuple[Optional[UUID], Optional[str]]: ) -> tuple[Optional[UUID], Optional[str]]:
"""Validate and resolve a model ID, checking existence and active status. """Validate and resolve a model ID, checking existence and active status.
@@ -204,7 +204,7 @@ def validate_and_resolve_model_id(
def validate_embedding_model( def validate_embedding_model(
config_id: int, config_id: UUID,
embedding_id: Union[str, UUID, None], embedding_id: Union[str, UUID, None],
db: Session, db: Session,
tenant_id: Optional[UUID] = None, tenant_id: Optional[UUID] = None,
@@ -256,7 +256,7 @@ def validate_embedding_model(
def validate_llm_model( def validate_llm_model(
config_id: int, config_id: UUID,
llm_id: Union[str, UUID, None], llm_id: Union[str, UUID, None],
db: Session, db: Session,
tenant_id: Optional[UUID] = None, tenant_id: Optional[UUID] = None,

View File

@@ -1,4 +1,5 @@
import uuid import uuid
from uuid import UUID
from pydantic import Field from pydantic import Field
from typing import Literal from typing import Literal
@@ -11,7 +12,7 @@ class MemoryReadNodeConfig(BaseNodeConfig):
... ...
) )
config_id: int = Field( config_id: UUID = Field(
... ...
) )
@@ -26,6 +27,6 @@ class MemoryWriteNodeConfig(BaseNodeConfig):
... ...
) )
config_id: int = Field( config_id: UUID = Field(
... ...
) )

View File

@@ -22,7 +22,7 @@ class MemoryReadNode(BaseNode):
raise RuntimeError("End user id is required") raise RuntimeError("End user id is required")
return await MemoryAgentService().read_memory( return await MemoryAgentService().read_memory(
group_id=end_user_id, end_user_id=end_user_id,
message=self._render_template(self.typed_config.message, state), message=self._render_template(self.typed_config.message, state),
config_id=str(self.typed_config.config_id), config_id=str(self.typed_config.config_id),
search_switch=self.typed_config.search_switch, search_switch=self.typed_config.search_switch,

View File

@@ -18,7 +18,7 @@ from .appshare_model import AppShare
from .release_share_model import ReleaseShare from .release_share_model import ReleaseShare
from .conversation_model import Conversation, Message from .conversation_model import Conversation, Message
from .api_key_model import ApiKey, ApiKeyLog, ApiKeyType from .api_key_model import ApiKey, ApiKeyLog, ApiKeyType
from .data_config_model import DataConfig from .memory_config_model import MemoryConfig
from .multi_agent_model import MultiAgentConfig, AgentInvocation from .multi_agent_model import MultiAgentConfig, AgentInvocation
from .workflow_model import WorkflowConfig, WorkflowExecution, WorkflowNodeExecution from .workflow_model import WorkflowConfig, WorkflowExecution, WorkflowNodeExecution
from .retrieval_info import RetrievalInfo from .retrieval_info import RetrievalInfo
@@ -57,7 +57,7 @@ __all__ = [
"ApiKey", "ApiKey",
"ApiKeyLog", "ApiKeyLog",
"ApiKeyType", "ApiKeyType",
"DataConfig", "MemoryConfig",
"MultiAgentConfig", "MultiAgentConfig",
"AgentInvocation", "AgentInvocation",
"WorkflowConfig", "WorkflowConfig",

View File

@@ -1,88 +0,0 @@
import datetime
from sqlalchemy import Column, String, Boolean, DateTime, Integer, Float
from sqlalchemy.dialects.postgresql import UUID
from app.db import Base
class DataConfig(Base):
"""数据配置表 - 用于存储记忆系统的配置参数"""
__tablename__ = "data_config"
# 主键
config_id = Column(Integer, primary_key=True, autoincrement=True, comment="配置ID")
# 基本信息
config_name = Column(String, nullable=False, comment="配置名称")
config_desc = Column(String, nullable=True, comment="配置描述")
# 组织信息
workspace_id = Column(UUID(as_uuid=True), nullable=True, comment="工作空间ID")
group_id = Column(String, nullable=True, comment="组ID")
user_id = Column(String, nullable=True, comment="用户ID")
apply_id = Column(String, nullable=True, comment="应用ID")
# 模型选择从workspace继承
llm_id = Column(String, nullable=True, comment="LLM模型配置ID")
embedding_id = Column(String, nullable=True, comment="嵌入模型配置ID")
rerank_id = Column(String, nullable=True, comment="重排序模型配置ID")
# 记忆萃取引擎配置
enable_llm_dedup_blockwise = Column(Boolean, default=True, comment="启用LLM决策去重")
enable_llm_disambiguation = Column(Boolean, default=True, comment="启用LLM决策消歧")
deep_retrieval = Column(Boolean, default=True, comment="深度检索开关")
# 阈值配置 (0-1 之间的浮点数)
t_type_strict = Column(Float, default=0.8, comment="类型严格阈值")
t_name_strict = Column(Float, default=0.8, comment="名称严格阈值")
t_overall = Column(Float, default=0.8, comment="综合阈值")
# 状态配置
state = Column(Boolean, default=False, comment="配置使用状态")
# 分块策略
chunker_strategy = Column(String, default="RecursiveChunker", comment="分块策略")
# 剪枝配置
pruning_enabled = Column(Boolean, default=False, comment="是否启动智能语义剪枝")
pruning_scene = Column(String, nullable=True, comment="智能剪枝场景education/online_service/outbound")
pruning_threshold = Column(Float, nullable=True, comment="智能语义剪枝阈值0-0.9")
# 自我反思配置
enable_self_reflexion = Column(Boolean, default=False, comment="是否启用自我反思")
iteration_period = Column(String, default="3", comment="反思迭代周期")
reflexion_range = Column(String, default="partial", comment="反思范围:部分/全部")
baseline = Column(String, default="TIME", comment="基线:时间/事实/时间和事实")
reflection_model_id = Column(String, nullable=True, comment="反思模型ID")
memory_verify = Column(Boolean, default=True, comment="记忆验证")
quality_assessment = Column(Boolean, default=True, comment="质量评估")
# 遗忘引擎配置
statement_granularity = Column(Integer, default=2, comment="陈述提取颗粒度,挡位 1/2/3")
include_dialogue_context = Column(Boolean, default=False, comment="是否包含对话上下文")
max_context = Column(Integer, default=1000, comment="对话语境中包含字符的最大数量")
lambda_time = Column("lambda_time", Float, default=0.5, comment="最低保持度0-1 小数")
lambda_mem = Column("lambda_mem", Float, default=0.5, comment="遗忘率0-1 小数")
offset = Column("offset", Float, default=0.0, comment="偏移度0-1 小数")
# ACT-R 遗忘引擎配置
decay_constant = Column(Float, default=0.5, comment="ACT-R衰减常数d默认0.5")
forgetting_threshold = Column(Float, default=0.3, comment="遗忘阈值默认0.3")
forgetting_interval_hours = Column(Integer, default=24, comment="遗忘周期间隔小时默认24")
enable_llm_summary = Column(Boolean, default=True, comment="是否使用LLM生成摘要默认True")
max_merge_batch_size = Column(Integer, default=100, comment="单次最大融合节点对数默认100")
max_history_length = Column(Integer, default=100, comment="访问历史最大长度默认100")
min_days_since_access = Column(Integer, default=30, comment="最小未访问天数默认30")
# 情绪引擎配置
emotion_enabled = Column(Boolean, default=True, comment="是否启用情绪提取")
emotion_model_id = Column(String, nullable=True, comment="情绪分析专用模型ID")
emotion_extract_keywords = Column(Boolean, default=True, comment="是否提取情绪关键词")
emotion_min_intensity = Column(Float, default=0.1, comment="最小情绪强度阈值")
emotion_enable_subject = Column(Boolean, default=True, comment="是否启用主体分类")
# 时间戳
created_at = Column(DateTime, default=datetime.datetime.now, comment="创建时间")
updated_at = Column(DateTime, default=datetime.datetime.now, onupdate=datetime.datetime.now, comment="更新时间")
def __repr__(self):
return f"<DataConfig(config_id={self.config_id}, config_name={self.config_name})>"

View File

@@ -1,39 +1,88 @@
# -*- coding: utf-8 -*- import datetime
"""Memory Configuration Model - Backward Compatibility from sqlalchemy import Column, String, Boolean, DateTime, Integer, Float
from sqlalchemy.dialects.postgresql import UUID
from app.db import Base
This module provides backward compatibility for imports.
All classes have been moved to app.schemas.memory_config_schema.
DEPRECATED: Import from app.schemas.memory_config_schema instead. class MemoryConfig(Base):
""" """记忆配置表 - 用于存储记忆系统的配置参数"""
__tablename__ = "memory_config"
# Re-export for backward compatibility # 主键
from app.schemas.memory_config_schema import ( config_id = Column(UUID(as_uuid=True), primary_key=True, comment="配置ID")
ConfigurationError,
InvalidConfigError,
MemoryConfig,
MemoryConfigValidation,
ModelInactiveError,
ModelNotFoundError,
ModelValidation,
WorkspaceNotFoundError,
WorkspaceValidation,
validate_memory_config_data,
validate_model_data,
validate_workspace_data,
)
__all__ = [ # 基本信息
"ConfigurationError", config_name = Column(String, nullable=False, comment="配置名称")
"InvalidConfigError", config_desc = Column(String, nullable=True, comment="配置描述")
"MemoryConfig",
"MemoryConfigValidation", # 组织信息
"ModelInactiveError", workspace_id = Column(UUID(as_uuid=True), nullable=True, comment="工作空间ID")
"ModelNotFoundError", end_user_id = Column(String, nullable=True, comment="组ID")
"ModelValidation", user_id = Column(String, nullable=True, comment="用户ID")
"WorkspaceNotFoundError", apply_id = Column(String, nullable=True, comment="应用ID")
"WorkspaceValidation",
"validate_memory_config_data", # 模型选择从workspace继承
"validate_model_data", llm_id = Column(String, nullable=True, comment="LLM模型配置ID")
"validate_workspace_data", embedding_id = Column(String, nullable=True, comment="嵌入模型配置ID")
] rerank_id = Column(String, nullable=True, comment="重排序模型配置ID")
# 记忆萃取引擎配置
enable_llm_dedup_blockwise = Column(Boolean, default=True, comment="启用LLM决策去重")
enable_llm_disambiguation = Column(Boolean, default=True, comment="启用LLM决策消歧")
deep_retrieval = Column(Boolean, default=True, comment="深度检索开关")
# 阈值配置 (0-1 之间的浮点数)
t_type_strict = Column(Float, default=0.8, comment="类型严格阈值")
t_name_strict = Column(Float, default=0.8, comment="名称严格阈值")
t_overall = Column(Float, default=0.8, comment="综合阈值")
# 状态配置
state = Column(Boolean, default=False, comment="配置使用状态")
# 分块策略
chunker_strategy = Column(String, default="RecursiveChunker", comment="分块策略")
# 剪枝配置
pruning_enabled = Column(Boolean, default=False, comment="是否启动智能语义剪枝")
pruning_scene = Column(String, nullable=True, comment="智能剪枝场景education/online_service/outbound")
pruning_threshold = Column(Float, nullable=True, comment="智能语义剪枝阈值0-0.9")
# 自我反思配置
enable_self_reflexion = Column(Boolean, default=False, comment="是否启用自我反思")
iteration_period = Column(String, default="3", comment="反思迭代周期")
reflexion_range = Column(String, default="partial", comment="反思范围:部分/全部")
baseline = Column(String, default="TIME", comment="基线:时间/事实/时间和事实")
reflection_model_id = Column(String, nullable=True, comment="反思模型ID")
memory_verify = Column(Boolean, default=True, comment="记忆验证")
quality_assessment = Column(Boolean, default=True, comment="质量评估")
# 遗忘引擎配置
statement_granularity = Column(Integer, default=2, comment="陈述提取颗粒度,挡位 1/2/3")
include_dialogue_context = Column(Boolean, default=False, comment="是否包含对话上下文")
max_context = Column(Integer, default=1000, comment="对话语境中包含字符的最大数量")
lambda_time = Column("lambda_time", Float, default=0.5, comment="最低保持度0-1 小数")
lambda_mem = Column("lambda_mem", Float, default=0.5, comment="遗忘率0-1 小数")
offset = Column("offset", Float, default=0.0, comment="偏移度0-1 小数")
# ACT-R 遗忘引擎配置
decay_constant = Column(Float, default=0.5, comment="ACT-R衰减常数d默认0.5")
forgetting_threshold = Column(Float, default=0.3, comment="遗忘阈值默认0.3")
forgetting_interval_hours = Column(Integer, default=24, comment="遗忘周期间隔小时默认24")
enable_llm_summary = Column(Boolean, default=True, comment="是否使用LLM生成摘要默认True")
max_merge_batch_size = Column(Integer, default=100, comment="单次最大融合节点对数默认100")
max_history_length = Column(Integer, default=100, comment="访问历史最大长度默认100")
min_days_since_access = Column(Integer, default=30, comment="最小未访问天数默认30")
# 情绪引擎配置
emotion_enabled = Column(Boolean, default=True, comment="是否启用情绪提取")
emotion_model_id = Column(String, nullable=True, comment="情绪分析专用模型ID")
emotion_extract_keywords = Column(Boolean, default=True, comment="是否提取情绪关键词")
emotion_min_intensity = Column(Float, default=0.1, comment="最小情绪强度阈值")
emotion_enable_subject = Column(Boolean, default=True, comment="是否启用主体分类")
# 时间戳
created_at = Column(DateTime, default=datetime.datetime.now, comment="创建时间")
updated_at = Column(DateTime, default=datetime.datetime.now, onupdate=datetime.datetime.now, comment="更新时间")
def __repr__(self):
return f"<MemoryConfig(config_id={self.config_id}, config_name={self.config_name})>"

View File

@@ -16,7 +16,7 @@ class PerceptualType(IntEnum):
CONVERSATION = 4 CONVERSATION = 4
class FileStorageType(IntEnum): class FileStorageService(IntEnum):
LOCAL = 1 LOCAL = 1
REMOTE = 2 REMOTE = 2

View File

@@ -1,18 +1,19 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
"""数据配置Repository模块 """记忆配置Repository模块
本模块提供data_config表的数据访问层使用SQLAlchemy ORM进行数据库操作 本模块提供memory_config表的数据访问层使用SQLAlchemy ORM进行数据库操作
包括CRUD操作和Neo4j Cypher查询常量 包括CRUD操作和Neo4j Cypher查询常量
Classes: Classes:
DataConfigRepository: 数据配置仓储类提供CRUD操作 MemoryConfigRepository: 记忆配置仓储类提供CRUD操作
""" """
import uuid import uuid
from uuid import UUID
from typing import Dict, List, Optional, Tuple from typing import Dict, List, Optional, Tuple
from app.core.exceptions import BusinessException from app.core.exceptions import BusinessException
from app.core.logging_config import get_config_logger, get_db_logger from app.core.logging_config import get_config_logger, get_db_logger
from app.models.data_config_model import DataConfig from app.models.memory_config_model import MemoryConfig
from app.schemas.memory_storage_schema import ( from app.schemas.memory_storage_schema import (
ConfigKey, ConfigKey,
ConfigParamsCreate, ConfigParamsCreate,
@@ -28,11 +29,11 @@ db_logger = get_db_logger()
# 获取配置专用日志器 # 获取配置专用日志器
config_logger = get_config_logger() config_logger = get_config_logger()
TABLE_NAME = "data_config" TABLE_NAME = "memory_config"
class DataConfigRepository: class MemoryConfigRepository:
"""数据配置Repository """记忆配置Repository
提供data_config表的数据访问方法包括 提供memory_config表的数据访问方法包括
- SQLAlchemy ORM 数据库操作 - SQLAlchemy ORM 数据库操作
- Neo4j Cypher查询常量 - Neo4j Cypher查询常量
""" """
@@ -41,48 +42,48 @@ class DataConfigRepository:
# Dialogue count by group # Dialogue count by group
SEARCH_FOR_DIALOGUE = """ SEARCH_FOR_DIALOGUE = """
MATCH (n:Dialogue) WHERE n.group_id = $group_id RETURN COUNT(n) AS num MATCH (n:Dialogue) WHERE n.end_user_id = $end_user_id RETURN COUNT(n) AS num
""" """
# Chunk count by group # Chunk count by group
SEARCH_FOR_CHUNK = """ SEARCH_FOR_CHUNK = """
MATCH (n:Chunk) WHERE n.group_id = $group_id RETURN COUNT(n) AS num MATCH (n:Chunk) WHERE n.end_user_id = $end_user_id RETURN COUNT(n) AS num
""" """
# Statement count by group # Statement count by group
SEARCH_FOR_STATEMENT = """ SEARCH_FOR_STATEMENT = """
MATCH (n:Statement) WHERE n.group_id = $group_id RETURN COUNT(n) AS num MATCH (n:Statement) WHERE n.end_user_id = $end_user_id RETURN COUNT(n) AS num
""" """
# ExtractedEntity count by group # ExtractedEntity count by group
SEARCH_FOR_ENTITY = """ SEARCH_FOR_ENTITY = """
MATCH (n:ExtractedEntity) WHERE n.group_id = $group_id RETURN COUNT(n) AS num MATCH (n:ExtractedEntity) WHERE n.end_user_id = $end_user_id RETURN COUNT(n) AS num
""" """
# All counts by label and total # All counts by label and total
SEARCH_FOR_ALL = """ SEARCH_FOR_ALL = """
OPTIONAL MATCH (n:Dialogue) WHERE n.group_id = $group_id RETURN 'Dialogue' AS Label, COUNT(n) AS Count OPTIONAL MATCH (n:Dialogue) WHERE n.end_user_id = $end_user_id RETURN 'Dialogue' AS Label, COUNT(n) AS Count
UNION ALL UNION ALL
OPTIONAL MATCH (n:Chunk) WHERE n.group_id = $group_id RETURN 'Chunk' AS Label, COUNT(n) AS Count OPTIONAL MATCH (n:Chunk) WHERE n.end_user_id = $end_user_id RETURN 'Chunk' AS Label, COUNT(n) AS Count
UNION ALL UNION ALL
OPTIONAL MATCH (n:Statement) WHERE n.group_id = $group_id RETURN 'Statement' AS Label, COUNT(n) AS Count OPTIONAL MATCH (n:Statement) WHERE n.end_user_id = $end_user_id RETURN 'Statement' AS Label, COUNT(n) AS Count
UNION ALL UNION ALL
OPTIONAL MATCH (n:ExtractedEntity) WHERE n.group_id = $group_id RETURN 'ExtractedEntity' AS Label, COUNT(n) AS Count OPTIONAL MATCH (n:ExtractedEntity) WHERE n.end_user_id = $end_user_id RETURN 'ExtractedEntity' AS Label, COUNT(n) AS Count
UNION ALL UNION ALL
OPTIONAL MATCH (n) WHERE n.group_id = $group_id RETURN 'ALL' AS Label, COUNT(n) AS Count OPTIONAL MATCH (n) WHERE n.end_user_id = $end_user_id RETURN 'ALL' AS Label, COUNT(n) AS Count
""" """
# Extracted entity details within group/app/user # Extracted entity details within group/app/user
SEARCH_FOR_DETIALS = """ SEARCH_FOR_DETIALS = """
MATCH (n:ExtractedEntity) MATCH (n:ExtractedEntity)
WHERE n.group_id = $group_id WHERE n.end_user_id = $end_user_id
RETURN n.entity_idx AS entity_idx, RETURN n.entity_idx AS entity_idx,
n.connect_strength AS connect_strength, n.connect_strength AS connect_strength,
n.description AS description, n.description AS description,
n.entity_type AS entity_type, n.entity_type AS entity_type,
n.name AS name, n.name AS name,
COALESCE(n.fact_summary, '') AS fact_summary, COALESCE(n.fact_summary, '') AS fact_summary,
n.group_id AS group_id, n.end_user_id AS end_user_id,
n.apply_id AS apply_id, n.apply_id AS apply_id,
n.user_id AS user_id, n.user_id AS user_id,
n.id AS id n.id AS id
@@ -91,9 +92,9 @@ class DataConfigRepository:
# Edges between extracted entities within group/app/user # Edges between extracted entities within group/app/user
SEARCH_FOR_EDGES = """ SEARCH_FOR_EDGES = """
MATCH (n:ExtractedEntity)-[r]->(m:ExtractedEntity) MATCH (n:ExtractedEntity)-[r]->(m:ExtractedEntity)
WHERE n.group_id = $group_id WHERE n.end_user_id = $end_user_id
RETURN RETURN
r.group_id AS group_id, r.end_user_id AS end_user_id,
r.apply_id AS apply_id, r.apply_id AS apply_id,
r.user_id AS user_id, r.user_id AS user_id,
elementId(r) AS rel_id, elementId(r) AS rel_id,
@@ -107,7 +108,7 @@ class DataConfigRepository:
@staticmethod @staticmethod
def update_reflection_config( def update_reflection_config(
db: Session, db: Session,
config_id: int, config_id: uuid.UUID,
enable_self_reflexion: bool, enable_self_reflexion: bool,
iteration_period: str, iteration_period: str,
reflexion_range: str, reflexion_range: str,
@@ -115,7 +116,7 @@ class DataConfigRepository:
reflection_model_id: str, reflection_model_id: str,
memory_verify: bool, memory_verify: bool,
quality_assessment: bool quality_assessment: bool
) -> DataConfig: ) -> MemoryConfig:
"""构建反思配置更新语句SQLAlchemy text() 命名参数) """构建反思配置更新语句SQLAlchemy text() 命名参数)
Args: Args:
@@ -130,28 +131,28 @@ class DataConfigRepository:
config_id: 配置ID config_id: 配置ID
Returns: Returns:
Data MemoryConfig
Raises: Raises:
ValueError: 没有字段需要更新时抛出 ValueError: 没有字段需要更新时抛出
""" """
db_logger.debug(f"构建反思配置更新语句: config_id={config_id}") db_logger.debug(f"构建反思配置更新语句: config_id={config_id}")
stmt = select(DataConfig).where(DataConfig.config_id == config_id) stmt = select(MemoryConfig).where(MemoryConfig.config_id == config_id)
data_config_obj = db.scalars(stmt).first() memory_config_obj = db.scalars(stmt).first()
if not data_config_obj: if not memory_config_obj:
raise BusinessException raise BusinessException
data_config_obj.enable_self_reflexion = enable_self_reflexion memory_config_obj.enable_self_reflexion = enable_self_reflexion
data_config_obj.iteration_period = iteration_period memory_config_obj.iteration_period = iteration_period
data_config_obj.reflexion_range = reflexion_range memory_config_obj.reflexion_range = reflexion_range
data_config_obj.baseline = baseline memory_config_obj.baseline = baseline
data_config_obj.reflection_model_id = reflection_model_id memory_config_obj.reflection_model_id = reflection_model_id
data_config_obj.memory_verify = memory_verify memory_config_obj.memory_verify = memory_verify
data_config_obj.quality_assessment = quality_assessment memory_config_obj.quality_assessment = quality_assessment
return data_config_obj return memory_config_obj
@staticmethod @staticmethod
def query_reflection_config_by_id(db: Session, config_id: int) -> DataConfig: def query_reflection_config_by_id(db: Session, config_id: uuid.UUID) -> MemoryConfig:
"""构建反思配置查询语句通过config_id查询反思配置SQLAlchemy text() 命名参数) """构建反思配置查询语句通过config_id查询反思配置SQLAlchemy text() 命名参数)
Args: Args:
@@ -162,13 +163,13 @@ class DataConfigRepository:
Tuple[str, Dict]: (SQL查询字符串, 参数字典) Tuple[str, Dict]: (SQL查询字符串, 参数字典)
""" """
db_logger.debug(f"构建反思配置查询语句: config_id={config_id}") db_logger.debug(f"构建反思配置查询语句: config_id={config_id}")
stmt = select(DataConfig).where(DataConfig.config_id == config_id) stmt = select(MemoryConfig).where(MemoryConfig.config_id == config_id)
data_config = db.scalars(stmt).first() memory_config = db.scalars(stmt).first()
if not data_config: if not memory_config:
raise RuntimeError("reflection config not found") raise RuntimeError("reflection config not found")
return data_config return memory_config
@staticmethod @staticmethod
def query_reflection_config_by_workspace_id(db: Session, workspace_id: uuid.UUID) -> DataConfig: def query_reflection_config_by_workspace_id(db: Session, workspace_id: uuid.UUID) -> MemoryConfig:
"""构建查询所有配置的语句SQLAlchemy text() 命名参数) """构建查询所有配置的语句SQLAlchemy text() 命名参数)
Args: Args:
@@ -180,11 +181,11 @@ class DataConfigRepository:
""" """
db_logger.debug(f"构建查询所有配置语句: workspace_id={workspace_id}") db_logger.debug(f"构建查询所有配置语句: workspace_id={workspace_id}")
stmt = select(DataConfig).where(DataConfig.workspace_id == workspace_id) stmt = select(MemoryConfig).where(MemoryConfig.workspace_id == workspace_id)
data_config = db.scalars(stmt).first() memory_config = db.scalars(stmt).first()
if not data_config: if not memory_config:
raise RuntimeError("reflection config not found") raise RuntimeError("reflection config not found")
return data_config return memory_config
@staticmethod @staticmethod
@@ -208,20 +209,21 @@ class DataConfigRepository:
return query, params return query, params
@staticmethod @staticmethod
def create(db: Session, params: ConfigParamsCreate) -> DataConfig: def create(db: Session, params: ConfigParamsCreate) -> MemoryConfig:
"""创建数据配置 """创建记忆配置
Args: Args:
db: 数据库会话 db: 数据库会话
params: 配置参数创建模型 params: 配置参数创建模型
Returns: Returns:
DataConfig: 创建的配置对象 MemoryConfig: 创建的配置对象
""" """
db_logger.debug(f"创建数据配置: config_name={params.config_name}, workspace_id={params.workspace_id}") db_logger.debug(f"创建记忆配置: config_name={params.config_name}, workspace_id={params.workspace_id}")
try: try:
db_config = DataConfig( db_config = MemoryConfig(
config_id=uuid.uuid4(),
config_name=params.config_name, config_name=params.config_name,
config_desc=params.config_desc, config_desc=params.config_desc,
workspace_id=params.workspace_id, workspace_id=params.workspace_id,
@@ -232,16 +234,16 @@ class DataConfigRepository:
db.add(db_config) db.add(db_config)
db.flush() # 获取自增ID但不提交事务 db.flush() # 获取自增ID但不提交事务
db_logger.info(f"数据配置已添加到会话: {db_config.config_name} (ID: {db_config.config_id})") db_logger.info(f"记忆配置已添加到会话: {db_config.config_name} (ID: {db_config.config_id})")
return db_config return db_config
except Exception as e: except Exception as e:
db.rollback() db.rollback()
db_logger.error(f"创建数据配置失败: {params.config_name} - {str(e)}") db_logger.error(f"创建记忆配置失败: {params.config_name} - {str(e)}")
raise raise
@staticmethod @staticmethod
def update(db: Session, update: ConfigUpdate) -> Optional[DataConfig]: def update(db: Session, update: ConfigUpdate) -> Optional[MemoryConfig]:
"""更新基础配置 """更新基础配置
Args: Args:
@@ -249,17 +251,17 @@ class DataConfigRepository:
update: 配置更新模型 update: 配置更新模型
Returns: Returns:
Optional[DataConfig]: 更新后的配置对象不存在则返回None Optional[MemoryConfig]: 更新后的配置对象不存在则返回None
Raises: Raises:
ValueError: 没有字段需要更新时抛出 ValueError: 没有字段需要更新时抛出
""" """
db_logger.debug(f"更新数据配置: config_id={update.config_id}") db_logger.debug(f"更新记忆配置: config_id={update.config_id}")
try: try:
db_config = db.query(DataConfig).filter(DataConfig.config_id == update.config_id).first() db_config = db.query(MemoryConfig).filter(MemoryConfig.config_id == update.config_id).first()
if not db_config: if not db_config:
db_logger.warning(f"数据配置不存在: config_id={update.config_id}") db_logger.warning(f"记忆配置不存在: config_id={update.config_id}")
return None return None
# 更新字段 # 更新字段
@@ -277,17 +279,17 @@ class DataConfigRepository:
db.commit() db.commit()
db.refresh(db_config) db.refresh(db_config)
db_logger.info(f"数据配置更新成功: {db_config.config_name} (ID: {update.config_id})") db_logger.info(f"记忆配置更新成功: {db_config.config_name} (ID: {update.config_id})")
return db_config return db_config
except Exception as e: except Exception as e:
db.rollback() db.rollback()
db_logger.error(f"更新数据配置失败: config_id={update.config_id} - {str(e)}") db_logger.error(f"更新记忆配置失败: config_id={update.config_id} - {str(e)}")
raise raise
@staticmethod @staticmethod
def update_extracted(db: Session, update: ConfigUpdateExtracted) -> Optional[DataConfig]: def update_extracted(db: Session, update: ConfigUpdateExtracted) -> Optional[MemoryConfig]:
"""更新记忆萃取引擎配置 """更新记忆萃取引擎配置
Args: Args:
@@ -295,7 +297,7 @@ class DataConfigRepository:
update: 萃取配置更新模型 update: 萃取配置更新模型
Returns: Returns:
Optional[DataConfig]: 更新后的配置对象不存在则返回None Optional[MemoryConfig]: 更新后的配置对象不存在则返回None
Raises: Raises:
ValueError: 没有字段需要更新时抛出 ValueError: 没有字段需要更新时抛出
@@ -303,9 +305,9 @@ class DataConfigRepository:
db_logger.debug(f"更新萃取配置: config_id={update.config_id}") db_logger.debug(f"更新萃取配置: config_id={update.config_id}")
try: try:
db_config = db.query(DataConfig).filter(DataConfig.config_id == update.config_id).first() db_config = db.query(MemoryConfig).filter(MemoryConfig.config_id == update.config_id).first()
if not db_config: if not db_config:
db_logger.warning(f"数据配置不存在: config_id={update.config_id}") db_logger.warning(f"记忆配置不存在: config_id={update.config_id}")
return None return None
# 更新字段映射 # 更新字段映射
@@ -360,7 +362,7 @@ class DataConfigRepository:
raise raise
@staticmethod @staticmethod
def update_forget(db: Session, update: ConfigUpdateForget) -> Optional[DataConfig]: def update_forget(db: Session, update: ConfigUpdateForget) -> Optional[MemoryConfig]:
"""更新遗忘引擎配置 """更新遗忘引擎配置
Args: Args:
@@ -368,7 +370,7 @@ class DataConfigRepository:
update: 遗忘配置更新模型 update: 遗忘配置更新模型
Returns: Returns:
Optional[DataConfig]: 更新后的配置对象不存在则返回None Optional[MemoryConfig]: 更新后的配置对象不存在则返回None
Raises: Raises:
ValueError: 没有字段需要更新时抛出 ValueError: 没有字段需要更新时抛出
@@ -376,9 +378,9 @@ class DataConfigRepository:
db_logger.debug(f"更新遗忘配置: config_id={update.config_id}") db_logger.debug(f"更新遗忘配置: config_id={update.config_id}")
try: try:
db_config = db.query(DataConfig).filter(DataConfig.config_id == update.config_id).first() db_config = db.query(MemoryConfig).filter(MemoryConfig.config_id == update.config_id).first()
if not db_config: if not db_config:
db_logger.warning(f"数据配置不存在: config_id={update.config_id}") db_logger.warning(f"记忆配置不存在: config_id={update.config_id}")
return None return None
# 更新字段 # 更新字段
@@ -408,7 +410,7 @@ class DataConfigRepository:
raise raise
@staticmethod @staticmethod
def get_extracted_config(db: Session, config_id: int) -> Optional[Dict]: def get_extracted_config(db: Session, config_id: UUID) -> Optional[Dict]:
"""获取萃取配置,通过主键查询某条配置 """获取萃取配置,通过主键查询某条配置
Args: Args:
@@ -421,7 +423,7 @@ class DataConfigRepository:
db_logger.debug(f"查询萃取配置: config_id={config_id}") db_logger.debug(f"查询萃取配置: config_id={config_id}")
try: try:
db_config = db.query(DataConfig).filter(DataConfig.config_id == config_id).first() db_config = db.query(MemoryConfig).filter(MemoryConfig.config_id == config_id).first()
if not db_config: if not db_config:
db_logger.debug(f"萃取配置不存在: config_id={config_id}") db_logger.debug(f"萃取配置不存在: config_id={config_id}")
return None return None
@@ -457,7 +459,7 @@ class DataConfigRepository:
raise raise
@staticmethod @staticmethod
def get_forget_config(db: Session, config_id: int) -> Optional[Dict]: def get_forget_config(db: Session, config_id: UUID) -> Optional[Dict]:
"""获取遗忘配置,通过主键查询某条配置 """获取遗忘配置,通过主键查询某条配置
Args: Args:
@@ -470,7 +472,7 @@ class DataConfigRepository:
db_logger.debug(f"查询遗忘配置: config_id={config_id}") db_logger.debug(f"查询遗忘配置: config_id={config_id}")
try: try:
db_config = db.query(DataConfig).filter(DataConfig.config_id == config_id).first() db_config = db.query(MemoryConfig).filter(MemoryConfig.config_id == config_id).first()
if not db_config: if not db_config:
db_logger.debug(f"遗忘配置不存在: config_id={config_id}") db_logger.debug(f"遗忘配置不存在: config_id={config_id}")
return None return None
@@ -489,39 +491,39 @@ class DataConfigRepository:
raise raise
@staticmethod @staticmethod
def get_by_id(db: Session, config_id: int) -> Optional[DataConfig]: def get_by_id(db: Session, config_id: uuid.UUID) -> Optional[MemoryConfig]:
"""根据ID获取数据配置 """根据ID获取记忆配置
Args: Args:
db: 数据库会话 db: 数据库会话
config_id: 配置ID config_id: 配置ID
Returns: Returns:
Optional[DataConfig]: 配置对象不存在则返回None Optional[MemoryConfig]: 配置对象不存在则返回None
""" """
db_logger.debug(f"根据ID查询数据配置: config_id={config_id}") db_logger.debug(f"根据ID查询记忆配置: config_id={config_id}")
try: try:
config = db.query(DataConfig).filter(DataConfig.config_id == config_id).first() config = db.query(MemoryConfig).filter(MemoryConfig.config_id == config_id).first()
if config: if config:
db_logger.debug(f"数据配置查询成功: {config.config_name} (ID: {config_id})") db_logger.debug(f"记忆配置查询成功: {config.config_name} (ID: {config_id})")
else: else:
db_logger.debug(f"数据配置不存在: config_id={config_id}") db_logger.debug(f"记忆配置不存在: config_id={config_id}")
return config return config
except Exception as e: except Exception as e:
db_logger.error(f"根据ID查询数据配置失败: config_id={config_id} - {str(e)}") db_logger.error(f"根据ID查询记忆配置失败: config_id={config_id} - {str(e)}")
raise raise
@staticmethod @staticmethod
def get_config_with_workspace(db: Session, config_id: int) -> Optional[tuple]: def get_config_with_workspace(db: Session, config_id: uuid.UUID) -> Optional[tuple]:
"""Get data config and its associated workspace information """Get memory config and its associated workspace information
Args: Args:
db: Database session db: Database session
config_id: Configuration ID config_id: Configuration ID
Returns: Returns:
Optional[tuple]: (DataConfig, Workspace) tuple, None if not found Optional[tuple]: (MemoryConfig, Workspace) tuple, None if not found
Raises: Raises:
ValueError: Raised when config exists but workspace doesn't ValueError: Raised when config exists but workspace doesn't
@@ -541,19 +543,19 @@ class DataConfigRepository:
} }
) )
db_logger.debug(f"Querying data config and workspace: config_id={config_id}") db_logger.debug(f"Querying memory config and workspace: config_id={config_id}")
try: try:
# Use join query to get both config and workspace # Use join query to get both config and workspace
result = db.query(DataConfig, Workspace).join( result = db.query(MemoryConfig, Workspace).join(
Workspace, DataConfig.workspace_id == Workspace.id Workspace, MemoryConfig.workspace_id == Workspace.id
).filter(DataConfig.config_id == config_id).first() ).filter(MemoryConfig.config_id == config_id).first()
elapsed_ms = (time.time() - start_time) * 1000 elapsed_ms = (time.time() - start_time) * 1000
if not result: if not result:
# Check if config exists but workspace is missing # Check if config exists but workspace is missing
config_only = db.query(DataConfig).filter(DataConfig.config_id == config_id).first() config_only = db.query(MemoryConfig).filter(MemoryConfig.config_id == config_id).first()
if config_only: if config_only:
if config_only.workspace_id is None: if config_only.workspace_id is None:
config_logger.error( config_logger.error(
@@ -566,7 +568,7 @@ class DataConfigRepository:
"elapsed_ms": elapsed_ms "elapsed_ms": elapsed_ms
} }
) )
db_logger.error(f"Data config {config_id} has no associated workspace ID") db_logger.error(f"Memory config {config_id} has no associated workspace ID")
raise ValueError(f"Configuration {config_id} has no associated workspace") raise ValueError(f"Configuration {config_id} has no associated workspace")
else: else:
config_logger.error( config_logger.error(
@@ -579,7 +581,7 @@ class DataConfigRepository:
"elapsed_ms": elapsed_ms "elapsed_ms": elapsed_ms
} }
) )
db_logger.error(f"Data config {config_id} references non-existent workspace {config_only.workspace_id}") db_logger.error(f"Memory config {config_id} references non-existent workspace {config_only.workspace_id}")
raise ValueError(f"Workspace {config_only.workspace_id} not found for configuration {config_id}") raise ValueError(f"Workspace {config_only.workspace_id} not found for configuration {config_id}")
config_logger.debug( config_logger.debug(
@@ -591,7 +593,7 @@ class DataConfigRepository:
"elapsed_ms": elapsed_ms "elapsed_ms": elapsed_ms
} }
) )
db_logger.debug(f"Data config not found: config_id={config_id}") db_logger.debug(f"Memory config not found: config_id={config_id}")
return None return None
config, workspace = result config, workspace = result
@@ -611,7 +613,7 @@ class DataConfigRepository:
} }
) )
db_logger.debug(f"Data config and workspace query successful: config={config.config_name}, workspace={workspace.name}") db_logger.debug(f"Memory config and workspace query successful: config={config.config_name}, workspace={workspace.name}")
return (config, workspace) return (config, workspace)
except ValueError: except ValueError:
@@ -633,10 +635,10 @@ class DataConfigRepository:
exc_info=True exc_info=True
) )
db_logger.error(f"Failed to query data config and workspace: config_id={config_id} - {str(e)}") db_logger.error(f"Failed to query memory config and workspace: config_id={config_id} - {str(e)}")
raise raise
@staticmethod @staticmethod
def get_all(db: Session, workspace_id: Optional[uuid.UUID] = None) -> List[DataConfig]: def get_all(db: Session, workspace_id: Optional[uuid.UUID] = None) -> List[MemoryConfig]:
"""获取所有配置参数 """获取所有配置参数
Args: Args:
@@ -644,17 +646,17 @@ class DataConfigRepository:
workspace_id: 工作空间ID用于过滤查询结果 workspace_id: 工作空间ID用于过滤查询结果
Returns: Returns:
List[DataConfig]: 配置列表 List[MemoryConfig]: 配置列表
""" """
db_logger.debug(f"查询所有配置: workspace_id={workspace_id}") db_logger.debug(f"查询所有配置: workspace_id={workspace_id}")
try: try:
query = db.query(DataConfig) query = db.query(MemoryConfig)
if workspace_id: if workspace_id:
query = query.filter(DataConfig.workspace_id == workspace_id) query = query.filter(MemoryConfig.workspace_id == workspace_id)
configs = query.order_by(desc(DataConfig.updated_at)).all() configs = query.order_by(desc(MemoryConfig.updated_at)).all()
db_logger.debug(f"配置列表查询成功: 数量={len(configs)}") db_logger.debug(f"配置列表查询成功: 数量={len(configs)}")
return configs return configs
@@ -664,8 +666,8 @@ class DataConfigRepository:
raise raise
@staticmethod @staticmethod
def delete(db: Session, config_id: int) -> bool: def delete(db: Session, config_id: uuid.UUID) -> bool:
"""删除数据配置 """删除记忆配置
Args: Args:
db: 数据库会话 db: 数据库会话
@@ -674,22 +676,22 @@ class DataConfigRepository:
Returns: Returns:
bool: 删除成功返回True配置不存在返回False bool: 删除成功返回True配置不存在返回False
""" """
db_logger.debug(f"删除数据配置: config_id={config_id}") db_logger.debug(f"删除记忆配置: config_id={config_id}")
try: try:
db_config = db.query(DataConfig).filter(DataConfig.config_id == config_id).first() db_config = db.query(MemoryConfig).filter(MemoryConfig.config_id == config_id).first()
if not db_config: if not db_config:
db_logger.warning(f"数据配置不存在: config_id={config_id}") db_logger.warning(f"记忆配置不存在: config_id={config_id}")
return False return False
db.delete(db_config) db.delete(db_config)
db.commit() db.commit()
db_logger.info(f"数据配置删除成功: config_id={config_id}") db_logger.info(f"记忆配置删除成功: config_id={config_id}")
return True return True
except Exception as e: except Exception as e:
db.rollback() db.rollback()
db_logger.error(f"删除数据配置失败: config_id={config_id} - {str(e)}") db_logger.error(f"删除记忆配置失败: config_id={config_id} - {str(e)}")
raise raise

View File

@@ -6,7 +6,7 @@ from sqlalchemy import and_, desc
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from app.core.logging_config import get_db_logger from app.core.logging_config import get_db_logger
from app.models.memory_perceptual_model import MemoryPerceptualModel, PerceptualType, FileStorageType from app.models.memory_perceptual_model import MemoryPerceptualModel, PerceptualType, FileStorageService
from app.schemas.memory_perceptual_schema import PerceptualQuerySchema from app.schemas.memory_perceptual_schema import PerceptualQuerySchema
db_logger = get_db_logger() db_logger = get_db_logger()
@@ -28,7 +28,7 @@ class MemoryPerceptualRepository:
file_ext: str, file_ext: str,
summary: Optional[str] = None, summary: Optional[str] = None,
meta_data: Optional[dict] = None, meta_data: Optional[dict] = None,
storage_service: FileStorageType = FileStorageType.LOCAL storage_service: FileStorageService = FileStorageService.LOCAL
) -> MemoryPerceptualModel: ) -> MemoryPerceptualModel:

View File

@@ -32,7 +32,7 @@ async def add_chunk_statement_edges(chunks: List[Chunk], connector: Neo4jConnect
"id": stable_edge_id, "id": stable_edge_id,
"source": chunk.id, "source": chunk.id,
"target": stmt.id, "target": stmt.id,
"group_id": getattr(stmt, 'group_id', None), "end_user_id": getattr(stmt, 'end_user_id', None),
"user_id":getattr(stmt, 'user_id', None), "user_id":getattr(stmt, 'user_id', None),
"apply_id": getattr(stmt, 'apply_id', None), "apply_id": getattr(stmt, 'apply_id', None),
"run_id": getattr(stmt, 'run_id', None) or getattr(chunk, 'run_id', None), "run_id": getattr(stmt, 'run_id', None) or getattr(chunk, 'run_id', None),
@@ -83,7 +83,7 @@ async def add_memory_summary_statement_edges(summaries: List[MemorySummaryNode],
edges.append({ edges.append({
"summary_id": s.id, "summary_id": s.id,
"chunk_id": chunk_id, "chunk_id": chunk_id,
"group_id": s.group_id, "end_user_id": s.end_user_id,
"run_id": s.run_id, "run_id": s.run_id,
"created_at": s.created_at.isoformat() if s.created_at else None, "created_at": s.created_at.isoformat() if s.created_at else None,
"expired_at": s.expired_at.isoformat() if s.expired_at else None, "expired_at": s.expired_at.isoformat() if s.expired_at else None,

View File

@@ -6,10 +6,10 @@ from app.core.memory.models.graph_models import DialogueNode, StatementNode, Chu
from app.repositories.neo4j.neo4j_connector import Neo4jConnector from app.repositories.neo4j.neo4j_connector import Neo4jConnector
async def delete_all_nodes(group_id: str, connector: Neo4jConnector): async def delete_all_nodes(end_user_id: str, connector: Neo4jConnector):
"""Delete all nodes in the database.""" """Delete all nodes in the database."""
result = await connector.execute_query(f"MATCH (n {{group_id: '{group_id}'}}) DETACH DELETE n") result = await connector.execute_query(f"MATCH (n {{end_user_id: '{end_user_id}'}}) DETACH DELETE n")
print(f"All group_id: {group_id} node and edge deleted successfully") print(f"All end_user_id: {end_user_id} node and edge deleted successfully")
return result return result
async def add_dialogue_nodes(dialogues: List[DialogueNode], connector: Neo4jConnector) -> Optional[List[str]]: async def add_dialogue_nodes(dialogues: List[DialogueNode], connector: Neo4jConnector) -> Optional[List[str]]:
@@ -32,9 +32,7 @@ async def add_dialogue_nodes(dialogues: List[DialogueNode], connector: Neo4jConn
for dialogue in dialogues: for dialogue in dialogues:
flattened_dialogues.append({ flattened_dialogues.append({
"id": dialogue.id, "id": dialogue.id,
"group_id": dialogue.group_id, "end_user_id": dialogue.end_user_id,
"user_id": dialogue.user_id,
"apply_id": dialogue.apply_id,
"run_id": dialogue.run_id, "run_id": dialogue.run_id,
"ref_id": dialogue.ref_id, "ref_id": dialogue.ref_id,
"name": dialogue.name, "name": dialogue.name,
@@ -79,9 +77,7 @@ async def add_statement_nodes(statements: List[StatementNode], connector: Neo4jC
flattened_statement = { flattened_statement = {
"id": statement.id, "id": statement.id,
"name": statement.name, "name": statement.name,
"group_id": statement.group_id, "end_user_id": statement.end_user_id,
"user_id": statement.user_id,
"apply_id": statement.apply_id,
"run_id": statement.run_id, "run_id": statement.run_id,
"chunk_id": statement.chunk_id, "chunk_id": statement.chunk_id,
# "created_at": statement.created_at.isoformat(), # "created_at": statement.created_at.isoformat(),
@@ -154,9 +150,7 @@ async def add_chunk_nodes(chunks: List[ChunkNode], connector: Neo4jConnector) ->
flattened_chunk = { flattened_chunk = {
"id": chunk.id, "id": chunk.id,
"name": chunk.name, "name": chunk.name,
"group_id": chunk.group_id, "end_user_id": chunk.end_user_id,
"user_id": chunk.user_id,
"apply_id": chunk.apply_id,
"run_id": chunk.run_id, "run_id": chunk.run_id,
"created_at": chunk.created_at.isoformat() if chunk.created_at else None, "created_at": chunk.created_at.isoformat() if chunk.created_at else None,
"expired_at": chunk.expired_at.isoformat() if chunk.expired_at else None, "expired_at": chunk.expired_at.isoformat() if chunk.expired_at else None,
@@ -206,9 +200,7 @@ async def add_memory_summary_nodes(summaries: List[MemorySummaryNode], connector
flattened.append({ flattened.append({
"id": s.id, "id": s.id,
"name": s.name, "name": s.name,
"group_id": s.group_id, "end_user_id": s.end_user_id,
"user_id": s.user_id,
"apply_id": s.apply_id,
"run_id": s.run_id, "run_id": s.run_id,
"created_at": s.created_at.isoformat() if s.created_at else None, "created_at": s.created_at.isoformat() if s.created_at else None,
"expired_at": s.expired_at.isoformat() if s.expired_at else None, "expired_at": s.expired_at.isoformat() if s.expired_at else None,

View File

@@ -152,7 +152,7 @@ class BaseNeo4jRepository(BaseRepository[T]):
Example: Example:
>>> results = await repository.find( >>> results = await repository.find(
... {"group_id": "group_123", "user_id": "user_456"}, ... {"end_user_id": "group_123", "user_id": "user_456"},
... limit=50 ... limit=50
... ) ... )
""" """

View File

@@ -3,9 +3,7 @@ DIALOGUE_NODE_SAVE = """
UNWIND $dialogues AS dialogue UNWIND $dialogues AS dialogue
MERGE (n:Dialogue {id: dialogue.id}) MERGE (n:Dialogue {id: dialogue.id})
SET n.uuid = coalesce(n.uuid, dialogue.id), SET n.uuid = coalesce(n.uuid, dialogue.id),
n.group_id = dialogue.group_id, n.end_user_id = dialogue.end_user_id,
n.user_id = dialogue.user_id,
n.apply_id = dialogue.apply_id,
n.run_id = dialogue.run_id, n.run_id = dialogue.run_id,
n.ref_id = dialogue.ref_id, n.ref_id = dialogue.ref_id,
n.created_at = dialogue.created_at, n.created_at = dialogue.created_at,
@@ -22,9 +20,7 @@ SET s += {
id: statement.id, id: statement.id,
run_id: statement.run_id, run_id: statement.run_id,
chunk_id: statement.chunk_id, chunk_id: statement.chunk_id,
group_id: statement.group_id, end_user_id: statement.end_user_id,
user_id: statement.user_id,
apply_id: statement.apply_id,
stmt_type: statement.stmt_type, stmt_type: statement.stmt_type,
statement: statement.statement, statement: statement.statement,
emotion_intensity: statement.emotion_intensity, emotion_intensity: statement.emotion_intensity,
@@ -54,9 +50,7 @@ MERGE (c:Chunk {id: chunk.id})
SET c += { SET c += {
id: chunk.id, id: chunk.id,
name: chunk.name, name: chunk.name,
group_id: chunk.group_id, end_user_id: chunk.end_user_id,
user_id: chunk.user_id,
apply_id: chunk.apply_id,
run_id: chunk.run_id, run_id: chunk.run_id,
created_at: chunk.created_at, created_at: chunk.created_at,
expired_at: chunk.expired_at, expired_at: chunk.expired_at,
@@ -76,9 +70,7 @@ EXTRACTED_ENTITY_NODE_SAVE = """
UNWIND $entities AS entity UNWIND $entities AS entity
MERGE (e:ExtractedEntity {id: entity.id}) MERGE (e:ExtractedEntity {id: entity.id})
SET e.name = CASE WHEN entity.name IS NOT NULL AND entity.name <> '' THEN entity.name ELSE e.name END, SET e.name = CASE WHEN entity.name IS NOT NULL AND entity.name <> '' THEN entity.name ELSE e.name END,
e.group_id = CASE WHEN entity.group_id IS NOT NULL AND entity.group_id <> '' THEN entity.group_id ELSE e.group_id END, e.end_user_id = CASE WHEN entity.end_user_id IS NOT NULL AND entity.end_user_id <> '' THEN entity.end_user_id ELSE e.end_user_id END,
e.user_id = CASE WHEN entity.user_id IS NOT NULL AND entity.user_id <> '' THEN entity.user_id ELSE e.user_id END,
e.apply_id = CASE WHEN entity.apply_id IS NOT NULL AND entity.apply_id <> '' THEN entity.apply_id ELSE e.apply_id END,
e.run_id = CASE WHEN entity.run_id IS NOT NULL AND entity.run_id <> '' THEN entity.run_id ELSE e.run_id END, e.run_id = CASE WHEN entity.run_id IS NOT NULL AND entity.run_id <> '' THEN entity.run_id ELSE e.run_id END,
e.created_at = CASE e.created_at = CASE
WHEN entity.created_at IS NOT NULL AND (e.created_at IS NULL OR entity.created_at < e.created_at) WHEN entity.created_at IS NOT NULL AND (e.created_at IS NULL OR entity.created_at < e.created_at)
@@ -134,9 +126,9 @@ RETURN e.id AS uuid
# Add back ENTITY_RELATIONSHIP_SAVE to be used by graph_saver.save_entities_and_relationships # Add back ENTITY_RELATIONSHIP_SAVE to be used by graph_saver.save_entities_and_relationships
ENTITY_RELATIONSHIP_SAVE = """ ENTITY_RELATIONSHIP_SAVE = """
UNWIND $relationships AS rel UNWIND $relationships AS rel
// Match entities by stable id within group, do not constrain by run_id // Match entities by stable id within end_user_id, do not constrain by run_id
MATCH (subject:ExtractedEntity {id: rel.source_id, group_id: rel.group_id}) MATCH (subject:ExtractedEntity {id: rel.source_id, end_user_id: rel.end_user_id})
MATCH (object:ExtractedEntity {id: rel.target_id, group_id: rel.group_id}) MATCH (object:ExtractedEntity {id: rel.target_id, end_user_id: rel.end_user_id})
// Avoid duplicate edges across runs for the same endpoints // Avoid duplicate edges across runs for the same endpoints
MERGE (subject)-[r:EXTRACTED_RELATIONSHIP]->(object) MERGE (subject)-[r:EXTRACTED_RELATIONSHIP]->(object)
SET r.predicate = rel.predicate, SET r.predicate = rel.predicate,
@@ -148,7 +140,7 @@ SET r.predicate = rel.predicate,
r.created_at = rel.created_at, r.created_at = rel.created_at,
r.expired_at = rel.expired_at, r.expired_at = rel.expired_at,
r.run_id = rel.run_id, r.run_id = rel.run_id,
r.group_id = rel.group_id r.end_user_id = rel.end_user_id
RETURN elementId(r) AS uuid RETURN elementId(r) AS uuid
""" """
@@ -160,7 +152,7 @@ UNWIND $weak_entities AS entity
MERGE (e:ExtractedEntity {id: entity.id, run_id: entity.run_id}) MERGE (e:ExtractedEntity {id: entity.id, run_id: entity.run_id})
SET e += { SET e += {
name: entity.name, name: entity.name,
group_id: entity.group_id, end_user_id: entity.end_user_id,
run_id: entity.run_id, run_id: entity.run_id,
description: entity.description, description: entity.description,
chunk_id: entity.chunk_id, chunk_id: entity.chunk_id,
@@ -175,11 +167,11 @@ RETURN e.id AS id
SAVE_STRONG_TRIPLE_ENTITIES = """ SAVE_STRONG_TRIPLE_ENTITIES = """
UNWIND $items AS item UNWIND $items AS item
MERGE (s:ExtractedEntity {id: item.source_id, run_id: item.run_id}) MERGE (s:ExtractedEntity {id: item.source_id, run_id: item.run_id})
SET s += {name: item.subject, group_id: item.group_id, run_id: item.run_id} SET s += {name: item.subject, end_user_id: item.end_user_id, run_id: item.run_id}
// Independent strong flag // Independent strong flag
SET s.is_strong = true SET s.is_strong = true
MERGE (o:ExtractedEntity {id: item.target_id, run_id: item.run_id}) MERGE (o:ExtractedEntity {id: item.target_id, run_id: item.run_id})
SET o += {name: item.object, group_id: item.group_id, run_id: item.run_id} SET o += {name: item.object, end_user_id: item.end_user_id, run_id: item.run_id}
// Independent strong flag // Independent strong flag
SET o.is_strong = true SET o.is_strong = true
""" """
@@ -194,7 +186,7 @@ DIALOGUE_STATEMENT_EDGE_SAVE = """
// 仅按端点去重,关系属性可更新 // 仅按端点去重,关系属性可更新
MERGE (dialogue)-[e:MENTIONS]->(statement) MERGE (dialogue)-[e:MENTIONS]->(statement)
SET e.uuid = edge.id, SET e.uuid = edge.id,
e.group_id = edge.group_id, e.end_user_id = edge.end_user_id,
e.created_at = edge.created_at, e.created_at = edge.created_at,
e.expired_at = edge.expired_at e.expired_at = edge.expired_at
RETURN e.uuid AS uuid RETURN e.uuid AS uuid
@@ -208,7 +200,7 @@ CHUNK_STATEMENT_EDGE_SAVE = """
MATCH (statement:Statement {id: edge.source, run_id: edge.run_id}) MATCH (statement:Statement {id: edge.source, run_id: edge.run_id})
MATCH (chunk:Chunk {id: edge.target, run_id: edge.run_id}) MATCH (chunk:Chunk {id: edge.target, run_id: edge.run_id})
MERGE (chunk)-[e:CONTAINS {id: edge.id}]->(statement) MERGE (chunk)-[e:CONTAINS {id: edge.id}]->(statement)
SET e.group_id = edge.group_id, SET e.end_user_id = edge.end_user_id,
e.run_id = edge.run_id, e.run_id = edge.run_id,
e.created_at = edge.created_at, e.created_at = edge.created_at,
e.expired_at = edge.expired_at e.expired_at = edge.expired_at
@@ -218,13 +210,12 @@ CHUNK_STATEMENT_EDGE_SAVE = """
STATEMENT_ENTITY_EDGE_SAVE = """ STATEMENT_ENTITY_EDGE_SAVE = """
UNWIND $relationships AS rel UNWIND $relationships AS rel
// Statement nodes are per-run; keep run_id constraint on statements // Statement nodes are per-run; keep run_id constraint on statements
// Statement nodes are per-run; keep run_id constraint on statements
MATCH (statement:Statement {id: rel.source, run_id: rel.run_id}) MATCH (statement:Statement {id: rel.source, run_id: rel.run_id})
// Entities are shared across runs within a group; do not constrain by run_id // Entities are shared across runs within end_user_id; do not constrain by run_id
MATCH (entity:ExtractedEntity {id: rel.target, group_id: rel.group_id}) MATCH (entity:ExtractedEntity {id: rel.target, end_user_id: rel.end_user_id})
// Avoid duplicate edges across runs for same endpoints // Avoid duplicate edges across runs for same endpoints
MERGE (statement)-[r:REFERENCES_ENTITY]->(entity) MERGE (statement)-[r:REFERENCES_ENTITY]->(entity)
SET r.group_id = rel.group_id, SET r.end_user_id = rel.end_user_id,
r.run_id = rel.run_id, r.run_id = rel.run_id,
r.created_at = rel.created_at, r.created_at = rel.created_at,
r.expired_at = rel.expired_at, r.expired_at = rel.expired_at,
@@ -236,10 +227,10 @@ ENTITY_EMBEDDING_SEARCH = """
CALL db.index.vector.queryNodes('entity_embedding_index', $limit * 100, $embedding) CALL db.index.vector.queryNodes('entity_embedding_index', $limit * 100, $embedding)
YIELD node AS e, score YIELD node AS e, score
WHERE e.name_embedding IS NOT NULL WHERE e.name_embedding IS NOT NULL
AND ($group_id IS NULL OR e.group_id = $group_id) AND ($end_user_id IS NULL OR e.end_user_id = $end_user_id)
RETURN e.id AS id, RETURN e.id AS id,
e.name AS name, e.name AS name,
e.group_id AS group_id, e.end_user_id AS end_user_id,
e.entity_type AS entity_type, e.entity_type AS entity_type,
COALESCE(e.activation_value, e.importance_score, 0.5) AS activation_value, COALESCE(e.activation_value, e.importance_score, 0.5) AS activation_value,
COALESCE(e.importance_score, 0.5) AS importance_score, COALESCE(e.importance_score, 0.5) AS importance_score,
@@ -254,10 +245,10 @@ STATEMENT_EMBEDDING_SEARCH = """
CALL db.index.vector.queryNodes('statement_embedding_index', $limit * 100, $embedding) CALL db.index.vector.queryNodes('statement_embedding_index', $limit * 100, $embedding)
YIELD node AS s, score YIELD node AS s, score
WHERE s.statement_embedding IS NOT NULL WHERE s.statement_embedding IS NOT NULL
AND ($group_id IS NULL OR s.group_id = $group_id) AND ($end_user_id IS NULL OR s.end_user_id = $end_user_id)
RETURN s.id AS id, RETURN s.id AS id,
s.statement AS statement, s.statement AS statement,
s.group_id AS group_id, s.end_user_id AS end_user_id,
s.chunk_id AS chunk_id, s.chunk_id AS chunk_id,
s.created_at AS created_at, s.created_at AS created_at,
s.expired_at AS expired_at, s.expired_at AS expired_at,
@@ -277,9 +268,9 @@ CHUNK_EMBEDDING_SEARCH = """
CALL db.index.vector.queryNodes('chunk_embedding_index', $limit * 100, $embedding) CALL db.index.vector.queryNodes('chunk_embedding_index', $limit * 100, $embedding)
YIELD node AS c, score YIELD node AS c, score
WHERE c.chunk_embedding IS NOT NULL WHERE c.chunk_embedding IS NOT NULL
AND ($group_id IS NULL OR c.group_id = $group_id) AND ($end_user_id IS NULL OR c.end_user_id = $end_user_id)
RETURN c.id AS chunk_id, RETURN c.id AS chunk_id,
c.group_id AS group_id, c.end_user_id AS end_user_id,
c.content AS content, c.content AS content,
c.dialog_id AS dialog_id, c.dialog_id AS dialog_id,
COALESCE(c.activation_value, 0.5) AS activation_value, COALESCE(c.activation_value, 0.5) AS activation_value,
@@ -292,12 +283,12 @@ LIMIT $limit
SEARCH_STATEMENTS_BY_KEYWORD = """ SEARCH_STATEMENTS_BY_KEYWORD = """
CALL db.index.fulltext.queryNodes("statementsFulltext", $q) YIELD node AS s, score CALL db.index.fulltext.queryNodes("statementsFulltext", $q) YIELD node AS s, score
WHERE ($group_id IS NULL OR s.group_id = $group_id) WHERE ($end_user_id IS NULL OR s.end_user_id = $end_user_id)
OPTIONAL MATCH (c:Chunk)-[:CONTAINS]->(s) OPTIONAL MATCH (c:Chunk)-[:CONTAINS]->(s)
OPTIONAL MATCH (s)-[:REFERENCES_ENTITY]->(e:ExtractedEntity) OPTIONAL MATCH (s)-[:REFERENCES_ENTITY]->(e:ExtractedEntity)
RETURN s.id AS id, RETURN s.id AS id,
s.statement AS statement, s.statement AS statement,
s.group_id AS group_id, s.end_user_id AS end_user_id,
s.chunk_id AS chunk_id, s.chunk_id AS chunk_id,
s.created_at AS created_at, s.created_at AS created_at,
s.expired_at AS expired_at, s.expired_at AS expired_at,
@@ -316,15 +307,13 @@ LIMIT $limit
# 查询实体名称包含指定字符串的实体 # 查询实体名称包含指定字符串的实体
SEARCH_ENTITIES_BY_NAME = """ SEARCH_ENTITIES_BY_NAME = """
CALL db.index.fulltext.queryNodes("entitiesFulltext", $q) YIELD node AS e, score CALL db.index.fulltext.queryNodes("entitiesFulltext", $q) YIELD node AS e, score
WHERE ($group_id IS NULL OR e.group_id = $group_id) WHERE ($end_user_id IS NULL OR e.end_user_id = $end_user_id)
OPTIONAL MATCH (s:Statement)-[:REFERENCES_ENTITY]->(e) OPTIONAL MATCH (s:Statement)-[:REFERENCES_ENTITY]->(e)
OPTIONAL MATCH (c:Chunk)-[:CONTAINS]->(s) OPTIONAL MATCH (c:Chunk)-[:CONTAINS]->(s)
RETURN e.id AS id, RETURN e.id AS id,
e.name AS name, e.name AS name,
e.group_id AS group_id, e.end_user_id AS end_user_id,
e.entity_type AS entity_type, e.entity_type AS entity_type,
e.apply_id AS apply_id,
e.user_id AS user_id,
e.created_at AS created_at, e.created_at AS created_at,
e.expired_at AS expired_at, e.expired_at AS expired_at,
e.entity_idx AS entity_idx, e.entity_idx AS entity_idx,
@@ -347,11 +336,11 @@ LIMIT $limit
SEARCH_CHUNKS_BY_CONTENT = """ SEARCH_CHUNKS_BY_CONTENT = """
CALL db.index.fulltext.queryNodes("chunksFulltext", $q) YIELD node AS c, score CALL db.index.fulltext.queryNodes("chunksFulltext", $q) YIELD node AS c, score
WHERE ($group_id IS NULL OR c.group_id = $group_id) WHERE ($end_user_id IS NULL OR c.end_user_id = $end_user_id)
OPTIONAL MATCH (c)-[:CONTAINS]->(s:Statement) OPTIONAL MATCH (c)-[:CONTAINS]->(s:Statement)
OPTIONAL MATCH (s)-[:REFERENCES_ENTITY]->(e:ExtractedEntity) OPTIONAL MATCH (s)-[:REFERENCES_ENTITY]->(e:ExtractedEntity)
RETURN c.id AS chunk_id, RETURN c.id AS chunk_id,
c.group_id AS group_id, c.end_user_id AS end_user_id,
c.content AS content, c.content AS content,
c.dialog_id AS dialog_id, c.dialog_id AS dialog_id,
c.sequence_number AS sequence_number, c.sequence_number AS sequence_number,
@@ -413,10 +402,10 @@ LIMIT $limit
SEARCH_DIALOGUE_BY_DIALOG_ID = """ SEARCH_DIALOGUE_BY_DIALOG_ID = """
MATCH (d:Dialogue) MATCH (d:Dialogue)
WHERE ($group_id IS NULL OR d.group_id = $group_id) WHERE ($end_user_id IS NULL OR d.end_user_id = $end_user_id)
AND d.id = $dialog_id AND d.id = $dialog_id
RETURN d.id AS dialog_id, RETURN d.id AS dialog_id,
d.group_id AS group_id, d.end_user_id AS end_user_id,
d.content AS content, d.content AS content,
d.created_at AS created_at, d.created_at AS created_at,
d.expired_at AS expired_at d.expired_at AS expired_at
@@ -426,10 +415,10 @@ LIMIT $limit
SEARCH_CHUNK_BY_CHUNK_ID = """ SEARCH_CHUNK_BY_CHUNK_ID = """
MATCH (c:Chunk) MATCH (c:Chunk)
WHERE ($group_id IS NULL OR c.group_id = $group_id) WHERE ($end_user_id IS NULL OR c.end_user_id = $end_user_id)
AND c.id = $chunk_id AND c.id = $chunk_id
RETURN c.id AS chunk_id, RETURN c.id AS chunk_id,
c.group_id AS group_id, c.end_user_id AS end_user_id,
c.content AS content, c.content AS content,
c.dialog_id AS dialog_id, c.dialog_id AS dialog_id,
c.created_at AS created_at, c.created_at AS created_at,
@@ -441,18 +430,14 @@ LIMIT $limit
SEARCH_STATEMENTS_BY_TEMPORAL = """ SEARCH_STATEMENTS_BY_TEMPORAL = """
MATCH (s:Statement) MATCH (s:Statement)
WHERE ($group_id IS NULL OR s.group_id = $group_id) WHERE ($end_user_id IS NULL OR s.end_user_id = $end_user_id)
AND ($apply_id IS NULL OR s.apply_id = $apply_id)
AND ($user_id IS NULL OR s.user_id = $user_id)
AND ((($start_date IS NULL OR datetime(s.created_at) >= datetime($start_date)) AND ((($start_date IS NULL OR datetime(s.created_at) >= datetime($start_date))
AND ($end_date IS NULL OR datetime(s.created_at) <= datetime($end_date))) AND ($end_date IS NULL OR datetime(s.created_at) <= datetime($end_date)))
OR (($valid_date IS NULL OR (s.valid_at IS NOT NULL AND datetime(s.valid_at) >= datetime($valid_date))) OR (($valid_date IS NULL OR (s.valid_at IS NOT NULL AND datetime(s.valid_at) >= datetime($valid_date)))
AND ($invalid_date IS NULL OR (s.invalid_at IS NOT NULL AND datetime(s.invalid_at) <= datetime($invalid_date))))) AND ($invalid_date IS NULL OR (s.invalid_at IS NOT NULL AND datetime(s.invalid_at) <= datetime($invalid_date)))))
RETURN s.id AS id, RETURN s.id AS id,
s.statement AS statement, s.statement AS statement,
s.group_id AS group_id, s.end_user_id AS end_user_id,
s.apply_id AS apply_id,
s.user_id AS user_id,
s.chunk_id AS chunk_id, s.chunk_id AS chunk_id,
s.created_at AS created_at, s.created_at AS created_at,
s.valid_at AS valid_at, s.valid_at AS valid_at,
@@ -468,9 +453,7 @@ LIMIT $limit
SEARCH_STATEMENTS_BY_KEYWORD_TEMPORAL = """ SEARCH_STATEMENTS_BY_KEYWORD_TEMPORAL = """
CALL db.index.fulltext.queryNodes("statementsFulltext", $q) YIELD node AS s, score CALL db.index.fulltext.queryNodes("statementsFulltext", $q) YIELD node AS s, score
WHERE ($group_id IS NULL OR s.group_id = $group_id) WHERE ($end_user_id IS NULL OR s.end_user_id = $end_user_id)
AND ($apply_id IS NULL OR s.apply_id = $apply_id)
AND ($user_id IS NULL OR s.user_id = $user_id)
AND ((($start_date IS NULL OR (s.created_at IS NOT NULL AND datetime(s.created_at) >= datetime($start_date))) AND ((($start_date IS NULL OR (s.created_at IS NOT NULL AND datetime(s.created_at) >= datetime($start_date)))
AND ($end_date IS NULL OR (s.created_at IS NOT NULL AND datetime(s.created_at) <= datetime($end_date)))) AND ($end_date IS NULL OR (s.created_at IS NOT NULL AND datetime(s.created_at) <= datetime($end_date))))
OR (($valid_date IS NULL OR (s.valid_at IS NOT NULL AND datetime(s.valid_at) >= datetime($valid_date))) OR (($valid_date IS NULL OR (s.valid_at IS NOT NULL AND datetime(s.valid_at) >= datetime($valid_date)))
@@ -479,9 +462,7 @@ OPTIONAL MATCH (c:Chunk)-[:CONTAINS]->(s)
OPTIONAL MATCH (s)-[:REFERENCES_ENTITY]->(e:ExtractedEntity) OPTIONAL MATCH (s)-[:REFERENCES_ENTITY]->(e:ExtractedEntity)
RETURN s.id AS id, RETURN s.id AS id,
s.statement AS statement, s.statement AS statement,
s.group_id AS group_id, s.end_user_id AS end_user_id,
s.apply_id AS apply_id,
s.user_id AS user_id,
s.chunk_id AS chunk_id, s.chunk_id AS chunk_id,
s.created_at AS created_at, s.created_at AS created_at,
s.valid_at AS valid_at, s.valid_at AS valid_at,
@@ -499,15 +480,11 @@ LIMIT $limit
SEARCH_STATEMENTS_BY_CREATED_AT = """ SEARCH_STATEMENTS_BY_CREATED_AT = """
MATCH (n:Statement) MATCH (n:Statement)
WHERE ($group_id IS NULL OR n.group_id = $group_id) WHERE ($end_user_id IS NULL OR n.end_user_id = $end_user_id)
AND ($apply_id IS NULL OR n.apply_id = $apply_id)
AND ($user_id IS NULL OR n.user_id = $user_id)
AND ($created_at IS NOT NULL AND date(substring(n.created_at, 0, 10)) = date($created_at)) AND ($created_at IS NOT NULL AND date(substring(n.created_at, 0, 10)) = date($created_at))
RETURN n.id AS id, RETURN n.id AS id,
n.statement AS statement, n.statement AS statement,
n.group_id AS group_id, n.end_user_id AS end_user_id,
n.apply_id AS apply_id,
n.user_id AS user_id,
n.chunk_id AS chunk_id, n.chunk_id AS chunk_id,
n.created_at AS created_at, n.created_at AS created_at,
n.valid_at AS valid_at, n.valid_at AS valid_at,
@@ -519,15 +496,11 @@ LIMIT $limit
SEARCH_STATEMENTS_BY_VALID_AT = """ SEARCH_STATEMENTS_BY_VALID_AT = """
MATCH (n:Statement) MATCH (n:Statement)
WHERE ($group_id IS NULL OR n.group_id = $group_id) WHERE ($end_user_id IS NULL OR n.end_user_id = $end_user_id)
AND ($apply_id IS NULL OR n.apply_id = $apply_id)
AND ($user_id IS NULL OR n.user_id = $user_id)
AND ($valid_at IS NOT NULL AND date(substring(n.valid_at, 0, 10)) = date($valid_at)) AND ($valid_at IS NOT NULL AND date(substring(n.valid_at, 0, 10)) = date($valid_at))
RETURN n.id AS id, RETURN n.id AS id,
n.statement AS statement, n.statement AS statement,
n.group_id AS group_id, n.end_user_id AS end_user_id,
n.apply_id AS apply_id,
n.user_id AS user_id,
n.chunk_id AS chunk_id, n.chunk_id AS chunk_id,
n.created_at AS created_at, n.created_at AS created_at,
n.valid_at AS valid_at, n.valid_at AS valid_at,
@@ -539,15 +512,11 @@ LIMIT $limit
SEARCH_STATEMENTS_G_CREATED_AT = """ SEARCH_STATEMENTS_G_CREATED_AT = """
MATCH (n:Statement) MATCH (n:Statement)
WHERE ($group_id IS NULL OR n.group_id = $group_id) WHERE ($end_user_id IS NULL OR n.end_user_id = $end_user_id)
AND ($apply_id IS NULL OR n.apply_id = $apply_id)
AND ($user_id IS NULL OR n.user_id = $user_id)
AND ($created_at IS NOT NULL AND date(substring(n.created_at, 0, 19)) = date($created_at)) AND ($created_at IS NOT NULL AND date(substring(n.created_at, 0, 19)) = date($created_at))
RETURN n.id AS id, RETURN n.id AS id,
n.statement AS statement, n.statement AS statement,
n.group_id AS group_id, n.end_user_id AS end_user_id,
n.apply_id AS apply_id,
n.user_id AS user_id,
n.chunk_id AS chunk_id, n.chunk_id AS chunk_id,
n.created_at AS created_at, n.created_at AS created_at,
n.valid_at AS valid_at, n.valid_at AS valid_at,
@@ -559,15 +528,11 @@ LIMIT $limit
SEARCH_STATEMENTS_L_CREATED_AT = """ SEARCH_STATEMENTS_L_CREATED_AT = """
MATCH (n:Statement) MATCH (n:Statement)
WHERE ($group_id IS NULL OR n.group_id = $group_id) WHERE ($end_user_id IS NULL OR n.end_user_id = $end_user_id)
AND ($apply_id IS NULL OR n.apply_id = $apply_id)
AND ($user_id IS NULL OR n.user_id = $user_id)
AND ($created_at IS NOT NULL AND date(substring(n.created_at, 0, 19)) < date($created_at)) AND ($created_at IS NOT NULL AND date(substring(n.created_at, 0, 19)) < date($created_at))
RETURN n.id AS id, RETURN n.id AS id,
n.statement AS statement, n.statement AS statement,
n.group_id AS group_id, n.end_user_id AS end_user_id,
n.apply_id AS apply_id,
n.user_id AS user_id,
n.chunk_id AS chunk_id, n.chunk_id AS chunk_id,
n.created_at AS created_at, n.created_at AS created_at,
n.valid_at AS valid_at, n.valid_at AS valid_at,
@@ -579,15 +544,11 @@ LIMIT $limit
SEARCH_STATEMENTS_G_VALID_AT = """ SEARCH_STATEMENTS_G_VALID_AT = """
MATCH (n:Statement) MATCH (n:Statement)
WHERE ($group_id IS NULL OR n.group_id = $group_id) WHERE ($end_user_id IS NULL OR n.end_user_id = $end_user_id)
AND ($apply_id IS NULL OR n.apply_id = $apply_id)
AND ($user_id IS NULL OR n.user_id = $user_id)
AND ($valid_at IS NOT NULL AND date(substring(n.valid_at, 0, 10)) > date($valid_at)) AND ($valid_at IS NOT NULL AND date(substring(n.valid_at, 0, 10)) > date($valid_at))
RETURN n.id AS id, RETURN n.id AS id,
n.statement AS statement, n.statement AS statement,
n.group_id AS group_id, n.end_user_id AS end_user_id,
n.apply_id AS apply_id,
n.user_id AS user_id,
n.chunk_id AS chunk_id, n.chunk_id AS chunk_id,
n.created_at AS created_at, n.created_at AS created_at,
n.valid_at AS valid_at, n.valid_at AS valid_at,
@@ -599,15 +560,11 @@ LIMIT $limit
SEARCH_STATEMENTS_L_VALID_AT = """ SEARCH_STATEMENTS_L_VALID_AT = """
MATCH (n:Statement) MATCH (n:Statement)
WHERE ($group_id IS NULL OR n.group_id = $group_id) WHERE ($end_user_id IS NULL OR n.end_user_id = $end_user_id)
AND ($apply_id IS NULL OR n.apply_id = $apply_id)
AND ($user_id IS NULL OR n.user_id = $user_id)
AND ($valid_at IS NOT NULL AND date(substring(n.valid_at, 0, 10)) < date($valid_at)) AND ($valid_at IS NOT NULL AND date(substring(n.valid_at, 0, 10)) < date($valid_at))
RETURN n.id AS id, RETURN n.id AS id,
n.statement AS statement, n.statement AS statement,
n.group_id AS group_id, n.end_user_id AS end_user_id,
n.apply_id AS apply_id,
n.user_id AS user_id,
n.chunk_id AS chunk_id, n.chunk_id AS chunk_id,
n.created_at AS created_at, n.created_at AS created_at,
n.valid_at AS valid_at, n.valid_at AS valid_at,
@@ -665,18 +622,18 @@ LIMIT $limit
# 根据id修改句子的invalid_at的值 # 根据id修改句子的invalid_at的值
UPDATE_STATEMENT_INVALID_AT = """ UPDATE_STATEMENT_INVALID_AT = """
MATCH (n:Statement {group_id: $group_id, id: $id}) MATCH (n:Statement {end_user_id: $end_user_id, id: $id})
SET n.invalid_at = $new_invalid_at SET n.invalid_at = $new_invalid_at
""" """
# MemorySummary keyword search using fulltext index # MemorySummary keyword search using fulltext index
SEARCH_MEMORY_SUMMARIES_BY_KEYWORD = """ SEARCH_MEMORY_SUMMARIES_BY_KEYWORD = """
CALL db.index.fulltext.queryNodes("summariesFulltext", $q) YIELD node AS m, score CALL db.index.fulltext.queryNodes("summariesFulltext", $q) YIELD node AS m, score
WHERE ($group_id IS NULL OR m.group_id = $group_id) WHERE ($end_user_id IS NULL OR m.end_user_id = $end_user_id)
OPTIONAL MATCH (m)-[:DERIVED_FROM_STATEMENT]->(s:Statement) OPTIONAL MATCH (m)-[:DERIVED_FROM_STATEMENT]->(s:Statement)
RETURN m.id AS id, RETURN m.id AS id,
m.name AS name, m.name AS name,
m.group_id AS group_id, m.end_user_id AS end_user_id,
m.dialog_id AS dialog_id, m.dialog_id AS dialog_id,
m.chunk_ids AS chunk_ids, m.chunk_ids AS chunk_ids,
m.content AS content, m.content AS content,
@@ -695,10 +652,10 @@ MEMORY_SUMMARY_EMBEDDING_SEARCH = """
CALL db.index.vector.queryNodes('summary_embedding_index', $limit * 100, $embedding) CALL db.index.vector.queryNodes('summary_embedding_index', $limit * 100, $embedding)
YIELD node AS m, score YIELD node AS m, score
WHERE m.summary_embedding IS NOT NULL WHERE m.summary_embedding IS NOT NULL
AND ($group_id IS NULL OR m.group_id = $group_id) AND ($end_user_id IS NULL OR m.end_user_id = $end_user_id)
RETURN m.id AS id, RETURN m.id AS id,
m.name AS name, m.name AS name,
m.group_id AS group_id, m.end_user_id AS end_user_id,
m.dialog_id AS dialog_id, m.dialog_id AS dialog_id,
m.chunk_ids AS chunk_ids, m.chunk_ids AS chunk_ids,
m.content AS content, m.content AS content,
@@ -718,9 +675,7 @@ MERGE (m:MemorySummary {id: summary.id})
SET m += { SET m += {
id: summary.id, id: summary.id,
name: summary.name, name: summary.name,
group_id: summary.group_id, end_user_id: summary.end_user_id,
user_id: summary.user_id,
apply_id: summary.apply_id,
run_id: summary.run_id, run_id: summary.run_id,
created_at: summary.created_at, created_at: summary.created_at,
expired_at: summary.expired_at, expired_at: summary.expired_at,
@@ -745,7 +700,7 @@ MATCH (ms:MemorySummary {id: e.summary_id, run_id: e.run_id})
MATCH (c:Chunk {id: e.chunk_id, run_id: e.run_id}) MATCH (c:Chunk {id: e.chunk_id, run_id: e.run_id})
MATCH (c)-[:CONTAINS]->(s:Statement {run_id: e.run_id}) MATCH (c)-[:CONTAINS]->(s:Statement {run_id: e.run_id})
MERGE (ms)-[r:DERIVED_FROM_STATEMENT]->(s) MERGE (ms)-[r:DERIVED_FROM_STATEMENT]->(s)
SET r.group_id = e.group_id, SET r.end_user_id = e.end_user_id,
r.run_id = e.run_id, r.run_id = e.run_id,
r.created_at = e.created_at, r.created_at = e.created_at,
r.expired_at = e.expired_at r.expired_at = e.expired_at
@@ -774,7 +729,7 @@ FOREACH (rel IN CASE WHEN r IS NOT NULL THEN [r] ELSE [] END |
source_statement_id: rel.source_statement_id, source_statement_id: rel.source_statement_id,
valid_at: rel.valid_at, valid_at: rel.valid_at,
invalid_at: rel.invalid_at, invalid_at: rel.invalid_at,
group_id: rel.group_id, end_user_id: rel.end_user_id,
user_id: rel.user_id, user_id: rel.user_id,
apply_id: rel.apply_id, apply_id: rel.apply_id,
run_id: rel.run_id, run_id: rel.run_id,
@@ -796,7 +751,7 @@ FOREACH (rel IN CASE WHEN r IS NOT NULL THEN [r] ELSE [] END |
source_statement_id: rel.source_statement_id, source_statement_id: rel.source_statement_id,
valid_at: rel.valid_at, valid_at: rel.valid_at,
invalid_at: rel.invalid_at, invalid_at: rel.invalid_at,
group_id: rel.group_id, end_user_id: rel.end_user_id,
user_id: rel.user_id, user_id: rel.user_id,
apply_id: rel.apply_id, apply_id: rel.apply_id,
run_id: rel.run_id, run_id: rel.run_id,
@@ -814,7 +769,7 @@ RETURN count(losing) as deleted
neo4j_statement_part = ''' neo4j_statement_part = '''
MATCH (n:Statement) MATCH (n:Statement)
WHERE n.group_id = "{}" WHERE n.end_user_id = "{}"
AND datetime(n.created_at) >= datetime() - duration('P3D') AND datetime(n.created_at) >= datetime() - duration('P3D')
RETURN RETURN
n.statement as statement_name, n.statement as statement_name,
@@ -824,7 +779,7 @@ RETURN
''' '''
neo4j_statement_all = ''' neo4j_statement_all = '''
MATCH (n:Statement) MATCH (n:Statement)
WHERE n.group_id = "{}" WHERE n.end_user_id = "{}"
RETURN RETURN
n.statement as statement_name, n.statement as statement_name,
n.id as statement_id n.id as statement_id
@@ -832,7 +787,7 @@ RETURN
''' '''
neo4j_query_part = """ neo4j_query_part = """
MATCH (n)-[r]-(m:ExtractedEntity) MATCH (n)-[r]-(m:ExtractedEntity)
WHERE n.group_id = "{}" WHERE n.end_user_id = "{}"
AND datetime(n.created_at) >= datetime() - duration('P3D') AND datetime(n.created_at) >= datetime() - duration('P3D')
WITH DISTINCT m WITH DISTINCT m
OPTIONAL MATCH (m)-[rel]-(other:ExtractedEntity) OPTIONAL MATCH (m)-[rel]-(other:ExtractedEntity)
@@ -853,7 +808,7 @@ neo4j_query_part = """
""" """
neo4j_query_all = """ neo4j_query_all = """
MATCH (n)-[r]-(m:ExtractedEntity) MATCH (n)-[r]-(m:ExtractedEntity)
WHERE n.group_id = "{}" WHERE n.end_user_id = "{}"
WITH DISTINCT m WITH DISTINCT m
OPTIONAL MATCH (m)-[rel]-(other:ExtractedEntity) OPTIONAL MATCH (m)-[rel]-(other:ExtractedEntity)
RETURN RETURN
@@ -1027,14 +982,14 @@ RETURN DISTINCT
Memory_Space_User=""" Memory_Space_User="""
MATCH (n)-[r]->(m) MATCH (n)-[r]->(m)
WHERE n.group_id = $group_id AND m.name="用户" WHERE n.end_user_id = $end_user_id AND m.name="用户"
return DISTINCT elementId(m) as id return DISTINCT elementId(m) as id
""" """
Memory_Space_Entity=""" Memory_Space_Entity="""
MATCH (n)-[]-(m) MATCH (n)-[]-(m)
WHERE elementId(m) = $id AND m.entity_type = "Person" WHERE elementId(m) = $id AND m.entity_type = "Person"
RETURN RETURN
DISTINCT m.name as name,m.group_id as group_id DISTINCT m.name as name,m.end_user_id as end_user_id
""" """
Memory_Space_Associative=""" Memory_Space_Associative="""
MATCH (u)-[]-(x)-[]-(h) MATCH (u)-[]-(x)-[]-(h)

View File

@@ -19,7 +19,7 @@ class DialogRepository(BaseNeo4jRepository[DialogueNode]):
"""对话仓储 """对话仓储
管理对话节点的创建、查询、更新和删除操作。 管理对话节点的创建、查询、更新和删除操作。
提供按group_id、user_id、ref_id等条件查询对话的方法。 提供按end_user_id、user_id、ref_id等条件查询对话的方法。
Attributes: Attributes:
connector: Neo4j连接器实例 connector: Neo4j连接器实例
@@ -54,17 +54,17 @@ class DialogRepository(BaseNeo4jRepository[DialogueNode]):
return DialogueNode(**n) return DialogueNode(**n)
async def find_by_group_id(self, group_id: str, limit: int = 100) -> List[DialogueNode]: async def find_by_end_user_id(self, end_user_id: str, limit: int = 100) -> List[DialogueNode]:
"""根据group_id查询对话 """根据end_user_id查询对话
Args: Args:
group_id: 组ID end_user_id: 组ID
limit: 返回结果的最大数量 limit: 返回结果的最大数量
Returns: Returns:
List[DialogueNode]: 对话列表 List[DialogueNode]: 对话列表
""" """
return await self.find({"group_id": group_id}, limit=limit) return await self.find({"end_user_id": end_user_id}, limit=limit)
async def find_by_user_id(self, user_id: str, limit: int = 100) -> List[DialogueNode]: async def find_by_user_id(self, user_id: str, limit: int = 100) -> List[DialogueNode]:
"""根据user_id查询对话 """根据user_id查询对话
@@ -94,14 +94,14 @@ class DialogRepository(BaseNeo4jRepository[DialogueNode]):
async def find_by_group_and_user( async def find_by_group_and_user(
self, self,
group_id: str, end_user_id: str,
user_id: str, user_id: str,
limit: int = 100 limit: int = 100
) -> List[DialogueNode]: ) -> List[DialogueNode]:
"""根据group_id和user_id查询对话 """根据end_user_id和user_id查询对话
Args: Args:
group_id: 组ID end_user_id: 组ID
user_id: 用户ID user_id: 用户ID
limit: 返回结果的最大数量 limit: 返回结果的最大数量
@@ -109,20 +109,20 @@ class DialogRepository(BaseNeo4jRepository[DialogueNode]):
List[DialogueNode]: 对话列表 List[DialogueNode]: 对话列表
""" """
return await self.find( return await self.find(
{"group_id": group_id, "user_id": user_id}, {"end_user_id": end_user_id, "user_id": user_id},
limit=limit limit=limit
) )
async def find_recent_dialogs( async def find_recent_dialogs(
self, self,
group_id: str, end_user_id: str,
days: int = 7, days: int = 7,
limit: int = 100 limit: int = 100
) -> List[DialogueNode]: ) -> List[DialogueNode]:
"""查询最近的对话 """查询最近的对话
Args: Args:
group_id: 组ID end_user_id: 组ID
days: 查询最近多少天的对话 days: 查询最近多少天的对话
limit: 返回结果的最大数量 limit: 返回结果的最大数量
@@ -131,7 +131,7 @@ class DialogRepository(BaseNeo4jRepository[DialogueNode]):
""" """
query = f""" query = f"""
MATCH (n:{self.node_label}) MATCH (n:{self.node_label})
WHERE n.group_id = $group_id WHERE n.end_user_id = $end_user_id
AND n.created_at >= datetime() - duration({{days: $days}}) AND n.created_at >= datetime() - duration({{days: $days}})
RETURN n RETURN n
ORDER BY n.created_at DESC ORDER BY n.created_at DESC
@@ -139,7 +139,7 @@ class DialogRepository(BaseNeo4jRepository[DialogueNode]):
""" """
results = await self.connector.execute_query( results = await self.connector.execute_query(
query, query,
group_id=group_id, end_user_id=end_user_id,
days=days, days=days,
limit=limit limit=limit
) )
@@ -164,22 +164,22 @@ class DialogRepository(BaseNeo4jRepository[DialogueNode]):
async def find_by_config_and_group( async def find_by_config_and_group(
self, self,
config_id: str, config_id: str,
group_id: str, end_user_id: str,
limit: int = 100 limit: int = 100
) -> List[DialogueNode]: ) -> List[DialogueNode]:
"""根据config_id和group_id查询对话 """根据config_id和end_user_id查询对话
支持按配置ID和组ID同时过滤,确保只返回使用特定配置处理的对话。 支持按配置ID和组ID同时过滤,确保只返回使用特定配置处理的对话。
Args: Args:
config_id: 配置ID config_id: 配置ID
group_id: 组ID end_user_id: 组ID
limit: 返回结果的最大数量 limit: 返回结果的最大数量
Returns: Returns:
List[DialogueNode]: 对话列表 List[DialogueNode]: 对话列表
""" """
return await self.find( return await self.find(
{"config_id": config_id, "group_id": group_id}, {"config_id": config_id, "end_user_id": end_user_id},
limit=limit limit=limit
) )

View File

@@ -40,7 +40,7 @@ class EmotionRepository:
async def get_emotion_tags( async def get_emotion_tags(
self, self,
group_id: str, end_user_id: str,
emotion_type: Optional[str] = None, emotion_type: Optional[str] = None,
start_date: Optional[str] = None, start_date: Optional[str] = None,
end_date: Optional[str] = None, end_date: Optional[str] = None,
@@ -51,7 +51,7 @@ class EmotionRepository:
查询指定用户的情绪类型分布,包括计数、百分比和平均强度。 查询指定用户的情绪类型分布,包括计数、百分比和平均强度。
Args: Args:
group_id: 用户组ID宿主ID end_user_id: 用户组ID宿主ID
emotion_type: 可选的情绪类型过滤joy/sadness/anger/fear/surprise/neutral emotion_type: 可选的情绪类型过滤joy/sadness/anger/fear/surprise/neutral
start_date: 可选的开始日期ISO格式字符串 start_date: 可选的开始日期ISO格式字符串
end_date: 可选的结束日期ISO格式字符串 end_date: 可选的结束日期ISO格式字符串
@@ -65,8 +65,8 @@ class EmotionRepository:
- avg_intensity: 平均强度 - avg_intensity: 平均强度
""" """
# 构建查询条件 # 构建查询条件
where_clauses = ["s.group_id = $group_id", "s.emotion_type IS NOT NULL"] where_clauses = ["s.end_user_id = $end_user_id", "s.emotion_type IS NOT NULL"]
params = {"group_id": group_id, "limit": limit} params = {"end_user_id": end_user_id, "limit": limit}
if emotion_type: if emotion_type:
where_clauses.append("s.emotion_type = $emotion_type") where_clauses.append("s.emotion_type = $emotion_type")
@@ -119,7 +119,7 @@ class EmotionRepository:
async def get_emotion_wordcloud( async def get_emotion_wordcloud(
self, self,
group_id: str, end_user_id: str,
emotion_type: Optional[str] = None, emotion_type: Optional[str] = None,
limit: int = 50 limit: int = 50
) -> List[Dict[str, Any]]: ) -> List[Dict[str, Any]]:
@@ -128,7 +128,7 @@ class EmotionRepository:
查询情绪关键词及其频率,用于生成词云可视化。 查询情绪关键词及其频率,用于生成词云可视化。
Args: Args:
group_id: 用户组ID宿主ID end_user_id: 用户组ID宿主ID
emotion_type: 可选的情绪类型过滤 emotion_type: 可选的情绪类型过滤
limit: 返回关键词的最大数量 limit: 返回关键词的最大数量
@@ -140,8 +140,8 @@ class EmotionRepository:
- avg_intensity: 平均强度 - avg_intensity: 平均强度
""" """
# 构建查询条件 # 构建查询条件
where_clauses = ["s.group_id = $group_id", "s.emotion_keywords IS NOT NULL"] where_clauses = ["s.end_user_id = $end_user_id", "s.emotion_keywords IS NOT NULL"]
params = {"group_id": group_id, "limit": limit} params = {"end_user_id": end_user_id, "limit": limit}
if emotion_type: if emotion_type:
where_clauses.append("s.emotion_type = $emotion_type") where_clauses.append("s.emotion_type = $emotion_type")
@@ -186,7 +186,7 @@ class EmotionRepository:
async def get_emotions_in_range( async def get_emotions_in_range(
self, self,
group_id: str, end_user_id: str,
time_range: str = "30d" time_range: str = "30d"
) -> List[Dict[str, Any]]: ) -> List[Dict[str, Any]]:
"""获取时间范围内的情绪数据 """获取时间范围内的情绪数据
@@ -194,7 +194,7 @@ class EmotionRepository:
查询指定时间范围内的所有情绪数据,用于健康指数计算。 查询指定时间范围内的所有情绪数据,用于健康指数计算。
Args: Args:
group_id: 用户组ID宿主ID end_user_id: 用户组ID宿主ID
time_range: 时间范围7d/30d/90d time_range: 时间范围7d/30d/90d
Returns: Returns:
@@ -214,7 +214,7 @@ class EmotionRepository:
# 优化的 Cypher 查询:使用字符串比较避免时区问题 # 优化的 Cypher 查询:使用字符串比较避免时区问题
query = """ query = """
MATCH (s:Statement) MATCH (s:Statement)
WHERE s.group_id = $group_id WHERE s.end_user_id = $end_user_id
AND s.emotion_type IS NOT NULL AND s.emotion_type IS NOT NULL
AND s.created_at >= $start_date AND s.created_at >= $start_date
RETURN s.id as statement_id, RETURN s.id as statement_id,
@@ -227,7 +227,7 @@ class EmotionRepository:
try: try:
results = await self.connector.execute_query( results = await self.connector.execute_query(
query, query,
group_id=group_id, end_user_id=end_user_id,
start_date=start_date start_date=start_date
) )
formatted_results = [ formatted_results = [

View File

@@ -44,9 +44,7 @@ async def save_entities_and_relationships(
'created_at': edge.created_at.isoformat(), 'created_at': edge.created_at.isoformat(),
'expired_at': edge.expired_at.isoformat(), 'expired_at': edge.expired_at.isoformat(),
'run_id': edge.run_id, 'run_id': edge.run_id,
'group_id': edge.group_id, 'end_user_id': edge.end_user_id,
'user_id': edge.user_id,
'apply_id': edge.apply_id,
} }
all_relationships.append(relationship) all_relationships.append(relationship)
@@ -101,9 +99,7 @@ async def save_statement_chunk_edges(
"id": edge.id, "id": edge.id,
"source": edge.source, "source": edge.source,
"target": edge.target, "target": edge.target,
"group_id": edge.group_id, "end_user_id": edge.end_user_id,
"user_id": edge.user_id,
"apply_id": edge.apply_id,
"run_id": edge.run_id, "run_id": edge.run_id,
"created_at": edge.created_at.isoformat() if edge.created_at else None, "created_at": edge.created_at.isoformat() if edge.created_at else None,
"expired_at": edge.expired_at.isoformat() if edge.expired_at else None, "expired_at": edge.expired_at.isoformat() if edge.expired_at else None,
@@ -132,9 +128,7 @@ async def save_statement_entity_edges(
edge_data = { edge_data = {
"source": edge.source, "source": edge.source,
"target": edge.target, "target": edge.target,
"group_id": edge.group_id, "end_user_id": edge.end_user_id,
"user_id": edge.user_id,
"apply_id": edge.apply_id,
"run_id": edge.run_id, "run_id": edge.run_id,
"connect_strength": edge.connect_strength, "connect_strength": edge.connect_strength,
"created_at": edge.created_at.isoformat() if edge.created_at else None, "created_at": edge.created_at.isoformat() if edge.created_at else None,

View File

@@ -33,7 +33,7 @@ async def _update_activation_values_batch(
connector: Neo4jConnector, connector: Neo4jConnector,
nodes: List[Dict[str, Any]], nodes: List[Dict[str, Any]],
node_label: str, node_label: str,
group_id: Optional[str] = None, end_user_id: Optional[str] = None,
max_retries: int = 3 max_retries: int = 3
) -> List[Dict[str, Any]]: ) -> List[Dict[str, Any]]:
""" """
@@ -46,7 +46,7 @@ async def _update_activation_values_batch(
connector: Neo4j连接器 connector: Neo4j连接器
nodes: 节点列表,每个节点必须包含 'id' 字段 nodes: 节点列表,每个节点必须包含 'id' 字段
node_label: 节点标签Statement, ExtractedEntity, MemorySummary node_label: 节点标签Statement, ExtractedEntity, MemorySummary
group_id: 组ID可选 end_user_id: 组ID可选
max_retries: 最大重试次数 max_retries: 最大重试次数
Returns: Returns:
@@ -97,7 +97,7 @@ async def _update_activation_values_batch(
updated_nodes = await access_manager.record_batch_access( updated_nodes = await access_manager.record_batch_access(
node_ids=unique_node_ids, node_ids=unique_node_ids,
node_label=node_label, node_label=node_label,
group_id=group_id end_user_id=end_user_id
) )
logger.info( logger.info(
@@ -118,7 +118,7 @@ async def _update_activation_values_batch(
async def _update_search_results_activation( async def _update_search_results_activation(
connector: Neo4jConnector, connector: Neo4jConnector,
results: Dict[str, List[Dict[str, Any]]], results: Dict[str, List[Dict[str, Any]]],
group_id: Optional[str] = None end_user_id: Optional[str] = None
) -> Dict[str, List[Dict[str, Any]]]: ) -> Dict[str, List[Dict[str, Any]]]:
""" """
更新搜索结果中所有知识节点的激活值 更新搜索结果中所有知识节点的激活值
@@ -129,7 +129,7 @@ async def _update_search_results_activation(
Args: Args:
connector: Neo4j连接器 connector: Neo4j连接器
results: 搜索结果字典,包含不同类型节点的列表 results: 搜索结果字典,包含不同类型节点的列表
group_id: 组ID可选 end_user_id: 组ID可选
Returns: Returns:
Dict[str, List[Dict[str, Any]]]: 更新后的搜索结果 Dict[str, List[Dict[str, Any]]]: 更新后的搜索结果
@@ -152,7 +152,7 @@ async def _update_search_results_activation(
connector=connector, connector=connector,
nodes=results[key], nodes=results[key],
node_label=label, node_label=label,
group_id=group_id end_user_id=end_user_id
) )
) )
update_keys.append(key) update_keys.append(key)
@@ -218,7 +218,7 @@ async def _update_search_results_activation(
async def search_graph( async def search_graph(
connector: Neo4jConnector, connector: Neo4jConnector,
q: str, q: str,
group_id: Optional[str] = None, end_user_id: Optional[str] = None,
limit: int = 50, limit: int = 50,
include: List[str] = None, include: List[str] = None,
) -> Dict[str, List[Dict[str, Any]]]: ) -> Dict[str, List[Dict[str, Any]]]:
@@ -236,7 +236,7 @@ async def search_graph(
Args: Args:
connector: Neo4j connector connector: Neo4j connector
q: Query text q: Query text
group_id: Optional group filter end_user_id: Optional group filter
limit: Max results per category limit: Max results per category
include: List of categories to search (default: all) include: List of categories to search (default: all)
@@ -254,7 +254,7 @@ async def search_graph(
tasks.append(connector.execute_query( tasks.append(connector.execute_query(
SEARCH_STATEMENTS_BY_KEYWORD, SEARCH_STATEMENTS_BY_KEYWORD,
q=q, q=q,
group_id=group_id, end_user_id=end_user_id,
limit=limit, limit=limit,
)) ))
task_keys.append("statements") task_keys.append("statements")
@@ -263,7 +263,7 @@ async def search_graph(
tasks.append(connector.execute_query( tasks.append(connector.execute_query(
SEARCH_ENTITIES_BY_NAME, SEARCH_ENTITIES_BY_NAME,
q=q, q=q,
group_id=group_id, end_user_id=end_user_id,
limit=limit, limit=limit,
)) ))
task_keys.append("entities") task_keys.append("entities")
@@ -272,7 +272,7 @@ async def search_graph(
tasks.append(connector.execute_query( tasks.append(connector.execute_query(
SEARCH_CHUNKS_BY_CONTENT, SEARCH_CHUNKS_BY_CONTENT,
q=q, q=q,
group_id=group_id, end_user_id=end_user_id,
limit=limit, limit=limit,
)) ))
task_keys.append("chunks") task_keys.append("chunks")
@@ -281,7 +281,7 @@ async def search_graph(
tasks.append(connector.execute_query( tasks.append(connector.execute_query(
SEARCH_MEMORY_SUMMARIES_BY_KEYWORD, SEARCH_MEMORY_SUMMARIES_BY_KEYWORD,
q=q, q=q,
group_id=group_id, end_user_id=end_user_id,
limit=limit, limit=limit,
)) ))
task_keys.append("summaries") task_keys.append("summaries")
@@ -315,7 +315,7 @@ async def search_graph(
results = await _update_search_results_activation( results = await _update_search_results_activation(
connector=connector, connector=connector,
results=results, results=results,
group_id=group_id end_user_id=end_user_id
) )
return results return results
@@ -325,7 +325,7 @@ async def search_graph_by_embedding(
connector: Neo4jConnector, connector: Neo4jConnector,
embedder_client, embedder_client,
query_text: str, query_text: str,
group_id: Optional[str] = None, end_user_id: Optional[str] = None,
limit: int = 50, limit: int = 50,
include: List[str] = ["statements", "chunks", "entities","summaries"], include: List[str] = ["statements", "chunks", "entities","summaries"],
) -> Dict[str, List[Dict[str, Any]]]: ) -> Dict[str, List[Dict[str, Any]]]:
@@ -337,7 +337,7 @@ async def search_graph_by_embedding(
- Computes query embedding with the provided embedder_client - Computes query embedding with the provided embedder_client
- Ranks by cosine similarity in Cypher - Ranks by cosine similarity in Cypher
- Filters by group_id if provided - Filters by end_user_id if provided
- Returns up to 'limit' per included type - Returns up to 'limit' per included type
""" """
import time import time
@@ -346,7 +346,7 @@ async def search_graph_by_embedding(
embed_start = time.time() embed_start = time.time()
embeddings = await embedder_client.response([query_text]) embeddings = await embedder_client.response([query_text])
embed_time = time.time() - embed_start embed_time = time.time() - embed_start
logger.info(f"[PERF] Embedding generation took: {embed_time:.4f}s") print(f"[PERF] Embedding generation took: {embed_time:.4f}s")
if not embeddings or not embeddings[0]: if not embeddings or not embeddings[0]:
return {"statements": [], "chunks": [], "entities": [], "summaries": []} return {"statements": [], "chunks": [], "entities": [], "summaries": []}
@@ -361,7 +361,7 @@ async def search_graph_by_embedding(
tasks.append(connector.execute_query( tasks.append(connector.execute_query(
STATEMENT_EMBEDDING_SEARCH, STATEMENT_EMBEDDING_SEARCH,
embedding=embedding, embedding=embedding,
group_id=group_id, end_user_id=end_user_id,
limit=limit, limit=limit,
)) ))
task_keys.append("statements") task_keys.append("statements")
@@ -371,7 +371,7 @@ async def search_graph_by_embedding(
tasks.append(connector.execute_query( tasks.append(connector.execute_query(
CHUNK_EMBEDDING_SEARCH, CHUNK_EMBEDDING_SEARCH,
embedding=embedding, embedding=embedding,
group_id=group_id, end_user_id=end_user_id,
limit=limit, limit=limit,
)) ))
task_keys.append("chunks") task_keys.append("chunks")
@@ -381,7 +381,7 @@ async def search_graph_by_embedding(
tasks.append(connector.execute_query( tasks.append(connector.execute_query(
ENTITY_EMBEDDING_SEARCH, ENTITY_EMBEDDING_SEARCH,
embedding=embedding, embedding=embedding,
group_id=group_id, end_user_id=end_user_id,
limit=limit, limit=limit,
)) ))
task_keys.append("entities") task_keys.append("entities")
@@ -391,7 +391,7 @@ async def search_graph_by_embedding(
tasks.append(connector.execute_query( tasks.append(connector.execute_query(
MEMORY_SUMMARY_EMBEDDING_SEARCH, MEMORY_SUMMARY_EMBEDDING_SEARCH,
embedding=embedding, embedding=embedding,
group_id=group_id, end_user_id=end_user_id,
limit=limit, limit=limit,
)) ))
task_keys.append("summaries") task_keys.append("summaries")
@@ -400,7 +400,7 @@ async def search_graph_by_embedding(
query_start = time.time() query_start = time.time()
task_results = await asyncio.gather(*tasks, return_exceptions=True) task_results = await asyncio.gather(*tasks, return_exceptions=True)
query_time = time.time() - query_start query_time = time.time() - query_start
logger.info(f"[PERF] Neo4j queries (parallel) took: {query_time:.4f}s") print(f"[PERF] Neo4j queries (parallel) took: {query_time:.4f}s")
# Build results dictionary # Build results dictionary
results: Dict[str, List[Dict[str, Any]]] = { results: Dict[str, List[Dict[str, Any]]] = {
@@ -435,7 +435,7 @@ async def search_graph_by_embedding(
results = await _update_search_results_activation( results = await _update_search_results_activation(
connector=connector, connector=connector,
results=results, results=results,
group_id=group_id end_user_id=end_user_id
) )
update_time = time.time() - update_start update_time = time.time() - update_start
logger.info(f"[PERF] Activation value updates took: {update_time:.4f}s") logger.info(f"[PERF] Activation value updates took: {update_time:.4f}s")
@@ -445,7 +445,7 @@ async def search_graph_by_embedding(
return results return results
async def get_dedup_candidates_for_entities( # 适配新版查询:使用全文索引按名称检索候选实体 async def get_dedup_candidates_for_entities( # 适配新版查询:使用全文索引按名称检索候选实体
connector: Neo4jConnector, connector: Neo4jConnector,
group_id: str, end_user_id: str,
entities: List[Dict[str, Any]], entities: List[Dict[str, Any]],
use_contains_fallback: bool = True, use_contains_fallback: bool = True,
batch_size: int = 500, batch_size: int = 500,
@@ -453,7 +453,7 @@ async def get_dedup_candidates_for_entities( # 适配新版查询:使用全
) -> Dict[str, List[Dict[str, Any]]]: ) -> Dict[str, List[Dict[str, Any]]]:
""" """
为第二层去重消歧批量检索候选实体(适配新版 cypher_queries 为第二层去重消歧批量检索候选实体(适配新版 cypher_queries
- 使用全文索引查询 `SEARCH_ENTITIES_BY_NAME` 按 (group_id, name) 检索候选; - 使用全文索引查询 `SEARCH_ENTITIES_BY_NAME` 按 (end_user_id, name) 检索候选;
- 保留并发控制与返回结构incoming_id -> [db_entity_props...] - 保留并发控制与返回结构incoming_id -> [db_entity_props...]
- 若提供 `entity_type`,在本地对返回结果做类型过滤; - 若提供 `entity_type`,在本地对返回结果做类型过滤;
- `use_contains_fallback` 保留形参以兼容,必要时可扩展二次查询策略。 - `use_contains_fallback` 保留形参以兼容,必要时可扩展二次查询策略。
@@ -477,7 +477,7 @@ async def get_dedup_candidates_for_entities( # 适配新版查询:使用全
rows = await connector.execute_query( rows = await connector.execute_query(
SEARCH_ENTITIES_BY_NAME, SEARCH_ENTITIES_BY_NAME,
q=name, q=name,
group_id=group_id, end_user_id=end_user_id,
limit=100, limit=100,
) )
except Exception: except Exception:
@@ -501,7 +501,7 @@ async def get_dedup_candidates_for_entities( # 适配新版查询:使用全
rows = await connector.execute_query( rows = await connector.execute_query(
SEARCH_ENTITIES_BY_NAME, SEARCH_ENTITIES_BY_NAME,
q=name.lower(), q=name.lower(),
group_id=group_id, end_user_id=end_user_id,
limit=100, limit=100,
) )
for r in rows: for r in rows:
@@ -532,9 +532,7 @@ async def get_dedup_candidates_for_entities( # 适配新版查询:使用全
async def search_graph_by_keyword_temporal( async def search_graph_by_keyword_temporal(
connector: Neo4jConnector, connector: Neo4jConnector,
query_text: str, query_text: str,
group_id: Optional[str] = None, end_user_id: Optional[str] = None,
apply_id: Optional[str] = None,
user_id: Optional[str] = None,
start_date: Optional[str] = None, start_date: Optional[str] = None,
end_date: Optional[str] = None, end_date: Optional[str] = None,
valid_date: Optional[str] = None, valid_date: Optional[str] = None,
@@ -547,32 +545,30 @@ async def search_graph_by_keyword_temporal(
INTEGRATED: Updates activation values for Statement nodes before returning results INTEGRATED: Updates activation values for Statement nodes before returning results
- Matches statements containing query_text created between start_date and end_date - Matches statements containing query_text created between start_date and end_date
- Optionally filters by group_id, apply_id, user_id - Optionally filters by end_user_id, apply_id, user_id
- Returns up to 'limit' statements - Returns up to 'limit' statements
""" """
if not query_text: if not query_text:
logger.warning(f"query_text cannot be empty") print(f"query_text不能为空")
return {"statements": []} return {"statements": []}
statements = await connector.execute_query( statements = await connector.execute_query(
SEARCH_STATEMENTS_BY_KEYWORD_TEMPORAL, SEARCH_STATEMENTS_BY_KEYWORD_TEMPORAL,
q=query_text, q=query_text,
group_id=group_id, end_user_id=end_user_id,
apply_id=apply_id,
user_id=user_id,
start_date=start_date, start_date=start_date,
end_date=end_date, end_date=end_date,
valid_date=valid_date, valid_date=valid_date,
invalid_date=invalid_date, invalid_date=invalid_date,
limit=limit, limit=limit,
) )
logger.debug(f"Temporal keyword search results: {len(statements)} statements found") print(f"查询结果为:\n{statements}")
# 更新 Statement 节点的激活值 # 更新 Statement 节点的激活值
results = {"statements": statements} results = {"statements": statements}
results = await _update_search_results_activation( results = await _update_search_results_activation(
connector=connector, connector=connector,
results=results, results=results,
group_id=group_id end_user_id=end_user_id
) )
return results return results
@@ -580,9 +576,7 @@ async def search_graph_by_keyword_temporal(
async def search_graph_by_temporal( async def search_graph_by_temporal(
connector: Neo4jConnector, connector: Neo4jConnector,
group_id: Optional[str] = None, end_user_id: Optional[str] = None,
apply_id: Optional[str] = None,
user_id: Optional[str] = None,
start_date: Optional[str] = None, start_date: Optional[str] = None,
end_date: Optional[str] = None, end_date: Optional[str] = None,
valid_date: Optional[str] = None, valid_date: Optional[str] = None,
@@ -595,14 +589,12 @@ async def search_graph_by_temporal(
INTEGRATED: Updates activation values for Statement nodes before returning results INTEGRATED: Updates activation values for Statement nodes before returning results
- Matches statements created between start_date and end_date - Matches statements created between start_date and end_date
- Optionally filters by group_id, apply_id, user_id - Optionally filters by end_user_id
- Returns up to 'limit' statements - Returns up to 'limit' statements
""" """
statements = await connector.execute_query( statements = await connector.execute_query(
SEARCH_STATEMENTS_BY_TEMPORAL, SEARCH_STATEMENTS_BY_TEMPORAL,
group_id=group_id, end_user_id=end_user_id,
apply_id=apply_id,
user_id=user_id,
start_date=start_date, start_date=start_date,
end_date=end_date, end_date=end_date,
valid_date=valid_date, valid_date=valid_date,
@@ -610,16 +602,16 @@ async def search_graph_by_temporal(
limit=limit, limit=limit,
) )
logger.debug(f"Temporal search query: {SEARCH_STATEMENTS_BY_TEMPORAL}") print(f"查询语句为:\n{SEARCH_STATEMENTS_BY_TEMPORAL}")
logger.debug(f"Query params: group_id={group_id}, apply_id={apply_id}, user_id={user_id}, start_date={start_date}, end_date={end_date}, valid_date={valid_date}, invalid_date={invalid_date}, limit={limit}") print(f"查询参数为:\n{{end_user_id: {end_user_id}, start_date: {start_date}, end_date: {end_date}, valid_date: {valid_date}, invalid_date: {invalid_date}, limit: {limit}}}")
logger.debug(f"Temporal search results: {len(statements)} statements found") print(f"查询结果为:\n{statements}")
# 更新 Statement 节点的激活值 # 更新 Statement 节点的激活值
results = {"statements": statements} results = {"statements": statements}
results = await _update_search_results_activation( results = await _update_search_results_activation(
connector=connector, connector=connector,
results=results, results=results,
group_id=group_id end_user_id=end_user_id
) )
return results return results
@@ -628,23 +620,23 @@ async def search_graph_by_temporal(
async def search_graph_by_dialog_id( async def search_graph_by_dialog_id(
connector: Neo4jConnector, connector: Neo4jConnector,
dialog_id: str, dialog_id: str,
group_id: Optional[str] = None, end_user_id: Optional[str] = None,
limit: int = 1, limit: int = 1,
) -> Dict[str, List[Dict[str, Any]]]: ) -> Dict[str, List[Dict[str, Any]]]:
""" """
Temporal search across Dialogues. Temporal search across Dialogues.
- Matches dialogues with dialog_id - Matches dialogues with dialog_id
- Optionally filters by group_id - Optionally filters by end_user_id
- Returns up to 'limit' dialogues - Returns up to 'limit' dialogues
""" """
if not dialog_id: if not dialog_id:
logger.warning(f"dialog_id cannot be empty") print(f"dialog_id不能为空")
return {"dialogues": []} return {"dialogues": []}
dialogues = await connector.execute_query( dialogues = await connector.execute_query(
SEARCH_DIALOGUE_BY_DIALOG_ID, SEARCH_DIALOGUE_BY_DIALOG_ID,
group_id=group_id, end_user_id=end_user_id,
dialog_id=dialog_id, dialog_id=dialog_id,
limit=limit, limit=limit,
) )
@@ -654,15 +646,15 @@ async def search_graph_by_dialog_id(
async def search_graph_by_chunk_id( async def search_graph_by_chunk_id(
connector: Neo4jConnector, connector: Neo4jConnector,
chunk_id : str, chunk_id : str,
group_id: Optional[str] = None, end_user_id: Optional[str] = None,
limit: int = 1, limit: int = 1,
) -> Dict[str, List[Dict[str, Any]]]: ) -> Dict[str, List[Dict[str, Any]]]:
if not chunk_id: if not chunk_id:
logger.warning(f"chunk_id cannot be empty") print(f"chunk_id不能为空")
return {"chunks": []} return {"chunks": []}
chunks = await connector.execute_query( chunks = await connector.execute_query(
SEARCH_CHUNK_BY_CHUNK_ID, SEARCH_CHUNK_BY_CHUNK_ID,
group_id=group_id, end_user_id=end_user_id,
chunk_id=chunk_id, chunk_id=chunk_id,
limit=limit, limit=limit,
) )
@@ -671,9 +663,9 @@ async def search_graph_by_chunk_id(
async def search_graph_by_created_at( async def search_graph_by_created_at(
connector: Neo4jConnector, connector: Neo4jConnector,
group_id: Optional[str] = None, end_user_id: Optional[str] = None,
apply_id: Optional[str] = None,
user_id: Optional[str] = None,
created_at: Optional[str] = None, created_at: Optional[str] = None,
limit: int = 1, limit: int = 1,
) -> Dict[str, List[Dict[str, Any]]]: ) -> Dict[str, List[Dict[str, Any]]]:
@@ -683,37 +675,37 @@ async def search_graph_by_created_at(
INTEGRATED: Updates activation values for Statement nodes before returning results INTEGRATED: Updates activation values for Statement nodes before returning results
- Matches statements created at created_at - Matches statements created at created_at
- Optionally filters by group_id, apply_id, user_id - Optionally filters by end_user_id, apply_id, user_id
- Returns up to 'limit' statements - Returns up to 'limit' statements
""" """
statements = await connector.execute_query( statements = await connector.execute_query(
SEARCH_STATEMENTS_BY_CREATED_AT, SEARCH_STATEMENTS_BY_CREATED_AT,
group_id=group_id, end_user_id=end_user_id,
apply_id=apply_id,
user_id=user_id,
created_at=created_at, created_at=created_at,
limit=limit, limit=limit,
) )
logger.debug(f"Search by created_at query: {SEARCH_STATEMENTS_BY_CREATED_AT}") print(f"查询语句为:\n{SEARCH_STATEMENTS_BY_CREATED_AT}")
logger.debug(f"Query params: group_id={group_id}, apply_id={apply_id}, user_id={user_id}, created_at={created_at}, limit={limit}") print(f"查询参数为:\n{{end_user_id: {end_user_id} created_at: {created_at}, limit: {limit}}}")
logger.debug(f"Search results: {len(statements)} statements found") print(f"查询结果为:\n{statements}")
# 更新 Statement 节点的激活值 # 更新 Statement 节点的激活值
results = {"statements": statements} results = {"statements": statements}
results = await _update_search_results_activation( results = await _update_search_results_activation(
connector=connector, connector=connector,
results=results, results=results,
group_id=group_id end_user_id=end_user_id
) )
return results return results
async def search_graph_by_valid_at( async def search_graph_by_valid_at(
connector: Neo4jConnector, connector: Neo4jConnector,
group_id: Optional[str] = None, end_user_id: Optional[str] = None,
apply_id: Optional[str] = None,
user_id: Optional[str] = None,
valid_at: Optional[str] = None, valid_at: Optional[str] = None,
limit: int = 1, limit: int = 1,
) -> Dict[str, List[Dict[str, Any]]]: ) -> Dict[str, List[Dict[str, Any]]]:
@@ -723,37 +715,37 @@ async def search_graph_by_valid_at(
INTEGRATED: Updates activation values for Statement nodes before returning results INTEGRATED: Updates activation values for Statement nodes before returning results
- Matches statements valid at valid_at - Matches statements valid at valid_at
- Optionally filters by group_id, apply_id, user_id - Optionally filters by end_user_id, apply_id, user_id
- Returns up to 'limit' statements - Returns up to 'limit' statements
""" """
statements = await connector.execute_query( statements = await connector.execute_query(
SEARCH_STATEMENTS_BY_VALID_AT, SEARCH_STATEMENTS_BY_VALID_AT,
group_id=group_id, end_user_id=end_user_id,
apply_id=apply_id,
user_id=user_id,
valid_at=valid_at, valid_at=valid_at,
limit=limit, limit=limit,
) )
logger.debug(f"Search by valid_at query: {SEARCH_STATEMENTS_BY_VALID_AT}") print(f"查询语句为:\n{SEARCH_STATEMENTS_BY_VALID_AT}")
logger.debug(f"Query params: group_id={group_id}, apply_id={apply_id}, user_id={user_id}, valid_at={valid_at}, limit={limit}") print(f"查询参数为:\n{{end_user_id: {end_user_id} valid_at: {valid_at}, limit: {limit}}}")
logger.debug(f"Search results: {len(statements)} statements found") print(f"查询结果为:\n{statements}")
# 更新 Statement 节点的激活值 # 更新 Statement 节点的激活值
results = {"statements": statements} results = {"statements": statements}
results = await _update_search_results_activation( results = await _update_search_results_activation(
connector=connector, connector=connector,
results=results, results=results,
group_id=group_id end_user_id=end_user_id
) )
return results return results
async def search_graph_g_created_at( async def search_graph_g_created_at(
connector: Neo4jConnector, connector: Neo4jConnector,
group_id: Optional[str] = None, end_user_id: Optional[str] = None,
apply_id: Optional[str] = None,
user_id: Optional[str] = None,
created_at: Optional[str] = None, created_at: Optional[str] = None,
limit: int = 1, limit: int = 1,
) -> Dict[str, List[Dict[str, Any]]]: ) -> Dict[str, List[Dict[str, Any]]]:
@@ -763,37 +755,37 @@ async def search_graph_g_created_at(
INTEGRATED: Updates activation values for Statement nodes before returning results INTEGRATED: Updates activation values for Statement nodes before returning results
- Matches statements created at created_at - Matches statements created at created_at
- Optionally filters by group_id, apply_id, user_id - Optionally filters by end_user_id, apply_id, user_id
- Returns up to 'limit' statements - Returns up to 'limit' statements
""" """
statements = await connector.execute_query( statements = await connector.execute_query(
SEARCH_STATEMENTS_G_CREATED_AT, SEARCH_STATEMENTS_G_CREATED_AT,
group_id=group_id, end_user_id=end_user_id,
apply_id=apply_id,
user_id=user_id,
created_at=created_at, created_at=created_at,
limit=limit, limit=limit,
) )
logger.debug(f"Search greater than created_at query: {SEARCH_STATEMENTS_G_CREATED_AT}") print(f"查询语句为:\n{SEARCH_STATEMENTS_G_CREATED_AT}")
logger.debug(f"Query params: group_id={group_id}, apply_id={apply_id}, user_id={user_id}, created_at={created_at}, limit={limit}") print(f"查询参数为:\n{{end_user_id: {end_user_id}, created_at: {created_at}, limit: {limit}}}")
logger.debug(f"Search results: {len(statements)} statements found") print(f"查询结果为:\n{statements}")
# 更新 Statement 节点的激活值 # 更新 Statement 节点的激活值
results = {"statements": statements} results = {"statements": statements}
results = await _update_search_results_activation( results = await _update_search_results_activation(
connector=connector, connector=connector,
results=results, results=results,
group_id=group_id end_user_id=end_user_id
) )
return results return results
async def search_graph_g_valid_at( async def search_graph_g_valid_at(
connector: Neo4jConnector, connector: Neo4jConnector,
group_id: Optional[str] = None, end_user_id: Optional[str] = None,
apply_id: Optional[str] = None,
user_id: Optional[str] = None,
valid_at: Optional[str] = None, valid_at: Optional[str] = None,
limit: int = 1, limit: int = 1,
) -> Dict[str, List[Dict[str, Any]]]: ) -> Dict[str, List[Dict[str, Any]]]:
@@ -803,37 +795,37 @@ async def search_graph_g_valid_at(
INTEGRATED: Updates activation values for Statement nodes before returning results INTEGRATED: Updates activation values for Statement nodes before returning results
- Matches statements valid at valid_at - Matches statements valid at valid_at
- Optionally filters by group_id, apply_id, user_id - Optionally filters by end_user_id, apply_id, user_id
- Returns up to 'limit' statements - Returns up to 'limit' statements
""" """
statements = await connector.execute_query( statements = await connector.execute_query(
SEARCH_STATEMENTS_G_VALID_AT, SEARCH_STATEMENTS_G_VALID_AT,
group_id=group_id, end_user_id=end_user_id,
apply_id=apply_id,
user_id=user_id,
valid_at=valid_at, valid_at=valid_at,
limit=limit, limit=limit,
) )
logger.debug(f"Search greater than valid_at query: {SEARCH_STATEMENTS_G_VALID_AT}") print(f"查询语句为:\n{SEARCH_STATEMENTS_G_VALID_AT}")
logger.debug(f"Query params: group_id={group_id}, apply_id={apply_id}, user_id={user_id}, valid_at={valid_at}, limit={limit}") print(f"查询参数为:\n{{end_user_id: {end_user_id}, valid_at: {valid_at}, limit: {limit}}}")
logger.debug(f"Search results: {len(statements)} statements found") print(f"查询结果为:\n{statements}")
# 更新 Statement 节点的激活值 # 更新 Statement 节点的激活值
results = {"statements": statements} results = {"statements": statements}
results = await _update_search_results_activation( results = await _update_search_results_activation(
connector=connector, connector=connector,
results=results, results=results,
group_id=group_id end_user_id=end_user_id
) )
return results return results
async def search_graph_l_created_at( async def search_graph_l_created_at(
connector: Neo4jConnector, connector: Neo4jConnector,
group_id: Optional[str] = None, end_user_id: Optional[str] = None,
apply_id: Optional[str] = None,
user_id: Optional[str] = None,
created_at: Optional[str] = None, created_at: Optional[str] = None,
limit: int = 1, limit: int = 1,
) -> Dict[str, List[Dict[str, Any]]]: ) -> Dict[str, List[Dict[str, Any]]]:
@@ -843,37 +835,37 @@ async def search_graph_l_created_at(
INTEGRATED: Updates activation values for Statement nodes before returning results INTEGRATED: Updates activation values for Statement nodes before returning results
- Matches statements created at created_at - Matches statements created at created_at
- Optionally filters by group_id, apply_id, user_id - Optionally filters by end_user_id, apply_id, user_id
- Returns up to 'limit' statements - Returns up to 'limit' statements
""" """
statements = await connector.execute_query( statements = await connector.execute_query(
SEARCH_STATEMENTS_L_CREATED_AT, SEARCH_STATEMENTS_L_CREATED_AT,
group_id=group_id, end_user_id=end_user_id,
apply_id=apply_id,
user_id=user_id,
created_at=created_at, created_at=created_at,
limit=limit, limit=limit,
) )
logger.debug(f"Search less than created_at query: {SEARCH_STATEMENTS_L_CREATED_AT}") print(f"查询语句为:\n{SEARCH_STATEMENTS_L_CREATED_AT}")
logger.debug(f"Query params: group_id={group_id}, apply_id={apply_id}, user_id={user_id}, created_at={created_at}, limit={limit}") print(f"查询参数为:\n{{end_user_id: {end_user_id}, created_at: {created_at}, limit: {limit}}}")
logger.debug(f"Search results: {len(statements)} statements found") print(f"查询结果为:\n{statements}")
# 更新 Statement 节点的激活值 # 更新 Statement 节点的激活值
results = {"statements": statements} results = {"statements": statements}
results = await _update_search_results_activation( results = await _update_search_results_activation(
connector=connector, connector=connector,
results=results, results=results,
group_id=group_id end_user_id=end_user_id
) )
return results return results
async def search_graph_l_valid_at( async def search_graph_l_valid_at(
connector: Neo4jConnector, connector: Neo4jConnector,
group_id: Optional[str] = None, end_user_id: Optional[str] = None,
apply_id: Optional[str] = None,
user_id: Optional[str] = None,
valid_at: Optional[str] = None, valid_at: Optional[str] = None,
limit: int = 1, limit: int = 1,
) -> Dict[str, List[Dict[str, Any]]]: ) -> Dict[str, List[Dict[str, Any]]]:
@@ -883,28 +875,28 @@ async def search_graph_l_valid_at(
INTEGRATED: Updates activation values for Statement nodes before returning results INTEGRATED: Updates activation values for Statement nodes before returning results
- Matches statements valid at valid_at - Matches statements valid at valid_at
- Optionally filters by group_id, apply_id, user_id - Optionally filters by end_user_id, apply_id, user_id
- Returns up to 'limit' statements - Returns up to 'limit' statements
""" """
statements = await connector.execute_query( statements = await connector.execute_query(
SEARCH_STATEMENTS_L_VALID_AT, SEARCH_STATEMENTS_L_VALID_AT,
group_id=group_id, end_user_id=end_user_id,
apply_id=apply_id,
user_id=user_id,
valid_at=valid_at, valid_at=valid_at,
limit=limit, limit=limit,
) )
logger.debug(f"Search less than valid_at query: {SEARCH_STATEMENTS_L_VALID_AT}") print(f"查询语句为:\n{SEARCH_STATEMENTS_L_VALID_AT}")
logger.debug(f"Query params: group_id={group_id}, apply_id={apply_id}, user_id={user_id}, valid_at={valid_at}, limit={limit}") print(f"查询参数为:\n{{end_user_id: {end_user_id}, valid_at: {valid_at}, limit: {limit}}}")
logger.debug(f"Search results: {len(statements)} statements found") print(f"查询结果为:\n{statements}")
# 更新 Statement 节点的激活值 # 更新 Statement 节点的激活值
results = {"statements": statements} results = {"statements": statements}
results = await _update_search_results_activation( results = await _update_search_results_activation(
connector=connector, connector=connector,
results=results, results=results,
group_id=group_id end_user_id=end_user_id
) )
return results return results

View File

@@ -18,7 +18,7 @@ class MemorySummaryRepository(BaseNeo4jRepository):
"""Memory Summary Repository """Memory Summary Repository
Manages CRUD operations for MemorySummary nodes. Manages CRUD operations for MemorySummary nodes.
Provides methods to query summaries by group_id, user_id, and time ranges. Provides methods to query summaries by end_user_id, user_id, and time ranges.
Attributes: Attributes:
connector: Neo4j connector instance connector: Neo4j connector instance
@@ -51,17 +51,17 @@ class MemorySummaryRepository(BaseNeo4jRepository):
return dict(n) return dict(n)
async def find_by_group_id( async def find_by_end_user_id(
self, self,
group_id: str, end_user_id: str,
limit: int = 1000, limit: int = 1000,
start_date: Optional[datetime] = None, start_date: Optional[datetime] = None,
end_date: Optional[datetime] = None end_date: Optional[datetime] = None
) -> List[Dict[str, Any]]: ) -> List[Dict[str, Any]]:
"""Query memory summaries by group_id """Query memory summaries by end_user_id
Args: Args:
group_id: Group ID to filter by end_user_id: Group ID to filter by
limit: Maximum number of results to return limit: Maximum number of results to return
start_date: Optional start date filter start_date: Optional start date filter
end_date: Optional end date filter end_date: Optional end date filter
@@ -71,10 +71,10 @@ class MemorySummaryRepository(BaseNeo4jRepository):
""" """
query = f""" query = f"""
MATCH (n:{self.node_label}) MATCH (n:{self.node_label})
WHERE n.group_id = $group_id WHERE n.end_user_id = $end_user_id
""" """
params = {"group_id": group_id, "limit": limit} params = {"end_user_id": end_user_id, "limit": limit}
# Add date range filters if provided # Add date range filters if provided
if start_date: if start_date:
@@ -139,16 +139,16 @@ class MemorySummaryRepository(BaseNeo4jRepository):
async def find_by_group_and_user( async def find_by_group_and_user(
self, self,
group_id: str, end_user_id: str,
user_id: str, user_id: str,
limit: int = 1000, limit: int = 1000,
start_date: Optional[datetime] = None, start_date: Optional[datetime] = None,
end_date: Optional[datetime] = None end_date: Optional[datetime] = None
) -> List[Dict[str, Any]]: ) -> List[Dict[str, Any]]:
"""Query memory summaries by both group_id and user_id """Query memory summaries by both end_user_id and user_id
Args: Args:
group_id: Group ID to filter by end_user_id: Group ID to filter by
user_id: User ID to filter by user_id: User ID to filter by
limit: Maximum number of results to return limit: Maximum number of results to return
start_date: Optional start date filter start_date: Optional start date filter
@@ -159,10 +159,10 @@ class MemorySummaryRepository(BaseNeo4jRepository):
""" """
query = f""" query = f"""
MATCH (n:{self.node_label}) MATCH (n:{self.node_label})
WHERE n.group_id = $group_id AND n.user_id = $user_id WHERE n.end_user_id = $end_user_id AND n.user_id = $user_id
""" """
params = {"group_id": group_id, "user_id": user_id, "limit": limit} params = {"end_user_id": end_user_id, "user_id": user_id, "limit": limit}
# Add date range filters if provided # Add date range filters if provided
if start_date: if start_date:
@@ -184,14 +184,14 @@ class MemorySummaryRepository(BaseNeo4jRepository):
async def find_recent_summaries( async def find_recent_summaries(
self, self,
group_id: str, end_user_id: str,
days: int = 7, days: int = 7,
limit: int = 1000 limit: int = 1000
) -> List[Dict[str, Any]]: ) -> List[Dict[str, Any]]:
"""Query recent memory summaries """Query recent memory summaries
Args: Args:
group_id: Group ID to filter by end_user_id: Group ID to filter by
days: Number of recent days to query days: Number of recent days to query
limit: Maximum number of results to return limit: Maximum number of results to return
@@ -200,7 +200,7 @@ class MemorySummaryRepository(BaseNeo4jRepository):
""" """
query = f""" query = f"""
MATCH (n:{self.node_label}) MATCH (n:{self.node_label})
WHERE n.group_id = $group_id WHERE n.end_user_id = $end_user_id
AND n.created_at >= datetime() - duration({{days: $days}}) AND n.created_at >= datetime() - duration({{days: $days}})
RETURN n RETURN n
ORDER BY n.created_at DESC ORDER BY n.created_at DESC
@@ -209,7 +209,7 @@ class MemorySummaryRepository(BaseNeo4jRepository):
results = await self.connector.execute_query( results = await self.connector.execute_query(
query, query,
group_id=group_id, end_user_id=end_user_id,
days=days, days=days,
limit=limit limit=limit
) )
@@ -217,14 +217,14 @@ class MemorySummaryRepository(BaseNeo4jRepository):
async def find_by_content_keywords( async def find_by_content_keywords(
self, self,
group_id: str, end_user_id: str,
keywords: List[str], keywords: List[str],
limit: int = 100 limit: int = 100
) -> List[Dict[str, Any]]: ) -> List[Dict[str, Any]]:
"""Query memory summaries by content keywords """Query memory summaries by content keywords
Args: Args:
group_id: Group ID to filter by end_user_id: Group ID to filter by
keywords: List of keywords to search for in content keywords: List of keywords to search for in content
limit: Maximum number of results to return limit: Maximum number of results to return
@@ -233,7 +233,7 @@ class MemorySummaryRepository(BaseNeo4jRepository):
""" """
# Build keyword search conditions # Build keyword search conditions
keyword_conditions = [] keyword_conditions = []
params = {"group_id": group_id, "limit": limit} params = {"end_user_id": end_user_id, "limit": limit}
for i, keyword in enumerate(keywords): for i, keyword in enumerate(keywords):
keyword_conditions.append(f"toLower(n.content) CONTAINS toLower($keyword_{i})") keyword_conditions.append(f"toLower(n.content) CONTAINS toLower($keyword_{i})")
@@ -243,7 +243,7 @@ class MemorySummaryRepository(BaseNeo4jRepository):
query = f""" query = f"""
MATCH (n:{self.node_label}) MATCH (n:{self.node_label})
WHERE n.group_id = $group_id WHERE n.end_user_id = $end_user_id
AND ({keyword_filter}) AND ({keyword_filter})
RETURN n RETURN n
ORDER BY n.created_at DESC ORDER BY n.created_at DESC
@@ -253,21 +253,21 @@ class MemorySummaryRepository(BaseNeo4jRepository):
results = await self.connector.execute_query(query, **params) results = await self.connector.execute_query(query, **params)
return [self._map_to_dict(r) for r in results] return [self._map_to_dict(r) for r in results]
async def get_summary_count_by_group(self, group_id: str) -> int: async def get_summary_count_by_group(self, end_user_id: str) -> int:
"""Get count of memory summaries for a group """Get count of memory summaries for a group
Args: Args:
group_id: Group ID to count summaries for end_user_id: Group ID to count summaries for
Returns: Returns:
int: Number of memory summaries int: Number of memory summaries
""" """
query = f""" query = f"""
MATCH (n:{self.node_label}) MATCH (n:{self.node_label})
WHERE n.group_id = $group_id WHERE n.end_user_id = $end_user_id
RETURN count(n) as count RETURN count(n) as count
""" """
results = await self.connector.execute_query(query, group_id=group_id) results = await self.connector.execute_query(query, end_user_id=end_user_id)
return results[0]['count'] if results else 0 return results[0]['count'] if results else 0

View File

@@ -70,11 +70,7 @@ class Neo4jConnector:
List[Dict[str, Any]]: 查询结果列表,每个元素是一个字典 List[Dict[str, Any]]: 查询结果列表,每个元素是一个字典
Example: Example:
>>> connector = Neo4jConnector()
>>> results = await connector.execute_query(
... "MATCH (n:Person {name: $name}) RETURN n",
... name="Alice"
... )
""" """
result = await self.driver.execute_query( result = await self.driver.execute_query(
query, query,
@@ -98,17 +94,7 @@ class Neo4jConnector:
Any: 事务函数的返回值 Any: 事务函数的返回值
Example: Example:
>>> async def create_node(tx, name):
... result = await tx.run(
... "CREATE (n:Person {name: $name}) RETURN n",
... name=name
... )
... return await result.single()
>>>
>>> connector = Neo4jConnector()
>>> result = await connector.execute_write_transaction(
... create_node, name="Alice"
... )
""" """
async with self.driver.session(database="neo4j") as session: async with self.driver.session(database="neo4j") as session:
return await session.execute_write(transaction_func, **kwargs) return await session.execute_write(transaction_func, **kwargs)
@@ -126,45 +112,33 @@ class Neo4jConnector:
Any: 事务函数的返回值 Any: 事务函数的返回值
Example: Example:
>>> async def get_node(tx, name):
... result = await tx.run(
... "MATCH (n:Person {name: $name}) RETURN n",
... name=name
... )
... return await result.single()
>>>
>>> connector = Neo4jConnector()
>>> result = await connector.execute_read_transaction(
... get_node, name="Alice"
... )
""" """
async with self.driver.session(database="neo4j") as session: async with self.driver.session(database="neo4j") as session:
return await session.execute_read(transaction_func, **kwargs) return await session.execute_read(transaction_func, **kwargs)
async def delete_group(self, group_id: str): async def delete_group(self, end_user_id: str):
"""删除指定组的所有数据 """删除指定组的所有数据
删除所有属于指定group_id的节点和边。 删除所有属于指定end_user_id的节点和边。
这是一个危险操作,会永久删除数据。 这是一个危险操作,会永久删除数据。
Args: Args:
group_id: 要删除的组ID end_user_id: 要删除的组ID
Example: Example:
>>> connector = Neo4jConnector()
>>> await connector.delete_group("group_123")
Group group_123 deleted. Group group_123 deleted.
""" """
# 删除节点DETACH DELETE会同时删除相关的边 # 删除节点DETACH DELETE会同时删除相关的边
await self.driver.execute_query( await self.driver.execute_query(
"MATCH (n) WHERE n.group_id = $group_id DETACH DELETE n", "MATCH (n) WHERE n.end_user_id = $end_user_id DETACH DELETE n",
database="neo4j", database="neo4j",
group_id=group_id end_user_id=end_user_id
) )
# 删除独立的边(如果有的话) # 删除独立的边(如果有的话)
await self.driver.execute_query( await self.driver.execute_query(
"MATCH ()-[r]->() WHERE r.group_id = $group_id DELETE r", "MATCH ()-[r]->() WHERE r.end_user_id = $end_user_id DELETE r",
database="neo4j", database="neo4j",
group_id=group_id end_user_id=end_user_id
) )
print(f"Group {group_id} deleted.") print(f"Group {end_user_id} deleted.")

View File

@@ -20,7 +20,7 @@ class StatementRepository(BaseNeo4jRepository[StatementNode]):
"""陈述句仓储 """陈述句仓储
管理陈述句节点的创建、查询、更新和删除操作。 管理陈述句节点的创建、查询、更新和删除操作。
提供按chunk_id、group_id、向量相似度等条件查询陈述句的方法。 提供按chunk_id、end_user_id、向量相似度等条件查询陈述句的方法。
Attributes: Attributes:
connector: Neo4j连接器实例 connector: Neo4j连接器实例

View File

@@ -299,6 +299,18 @@ class AppRelease(BaseModel):
created_at: datetime.datetime created_at: datetime.datetime
updated_at: datetime.datetime updated_at: datetime.datetime
@field_validator("config", mode="before")
@classmethod
def parse_config(cls, v):
"""处理 config 字段,如果是字符串则解析为字典"""
if isinstance(v, str):
import json
try:
return json.loads(v)
except json.JSONDecodeError:
return {}
return v if v is not None else {}
@field_serializer("created_at", when_used="json") @field_serializer("created_at", when_used="json")
def _serialize_created_at(self, dt: datetime.datetime): def _serialize_created_at(self, dt: datetime.datetime):
return int(dt.timestamp() * 1000) if dt else None return int(dt.timestamp() * 1000) if dt else None

View File

@@ -1,11 +1,12 @@
"""情绪分析相关的请求和响应模型""" """情绪分析相关的请求和响应模型"""
from typing import Optional from typing import Optional
from uuid import UUID
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
class EmotionTagsRequest(BaseModel): class EmotionTagsRequest(BaseModel):
"""获取情绪标签统计请求""" """获取情绪标签统计请求"""
group_id: str = Field(..., description="组ID") end_user_id: str = Field(..., description="组ID")
emotion_type: Optional[str] = Field(None, description="情绪类型过滤joy/sadness/anger/fear/surprise/neutral") emotion_type: Optional[str] = Field(None, description="情绪类型过滤joy/sadness/anger/fear/surprise/neutral")
start_date: Optional[str] = Field(None, description="开始日期ISO格式2024-01-01") start_date: Optional[str] = Field(None, description="开始日期ISO格式2024-01-01")
end_date: Optional[str] = Field(None, description="结束日期ISO格式2024-12-31") end_date: Optional[str] = Field(None, description="结束日期ISO格式2024-12-31")
@@ -14,14 +15,14 @@ class EmotionTagsRequest(BaseModel):
class EmotionWordcloudRequest(BaseModel): class EmotionWordcloudRequest(BaseModel):
"""获取情绪词云数据请求""" """获取情绪词云数据请求"""
group_id: str = Field(..., description="组ID") end_user_id: str = Field(..., description="组ID")
emotion_type: Optional[str] = Field(None, description="情绪类型过滤joy/sadness/anger/fear/surprise/neutral") emotion_type: Optional[str] = Field(None, description="情绪类型过滤joy/sadness/anger/fear/surprise/neutral")
limit: int = Field(50, ge=1, le=200, description="返回词语数量") limit: int = Field(50, ge=1, le=200, description="返回词语数量")
class EmotionHealthRequest(BaseModel): class EmotionHealthRequest(BaseModel):
"""获取情绪健康指数请求""" """获取情绪健康指数请求"""
group_id: str = Field(..., description="组ID") end_user_id: str = Field(..., description="组ID")
time_range: str = Field("30d", description="时间范围7d/30d/90d") time_range: str = Field("30d", description="时间范围7d/30d/90d")
@@ -29,8 +30,8 @@ class EmotionHealthRequest(BaseModel):
class EmotionSuggestionsRequest(BaseModel): class EmotionSuggestionsRequest(BaseModel):
"""获取个性化情绪建议请求""" """获取个性化情绪建议请求"""
group_id: str = Field(..., description="组ID") end_user_id: str = Field(..., description="组ID")
config_id: Optional[int] = Field(None, description="配置ID用于指定LLM模型") config_id: Optional[UUID] = Field(None, description="配置ID用于指定LLM模型")
class EmotionGenerateSuggestionsRequest(BaseModel): class EmotionGenerateSuggestionsRequest(BaseModel):

View File

@@ -7,11 +7,11 @@ class UserInput(BaseModel):
message: str message: str
history: list[dict] history: list[dict]
search_switch: str search_switch: str
group_id: str end_user_id: str
config_id: Optional[str] = None config_id: Optional[str] = None
class Write_UserInput(BaseModel): class Write_UserInput(BaseModel):
messages: list[dict] messages: list[dict]
group_id: str end_user_id: str
config_id: Optional[str] = None config_id: Optional[str] = None

View File

@@ -35,7 +35,7 @@ class ConfigurationError(Exception):
def __init__( def __init__(
self, self,
message: str, message: str,
config_id: Optional[int] = None, config_id: Optional[UUID] = None,
workspace_id: Optional[UUID] = None, workspace_id: Optional[UUID] = None,
context: Optional[Dict[str, Any]] = None, context: Optional[Dict[str, Any]] = None,
): ):
@@ -72,7 +72,7 @@ class WorkspaceNotFoundError(ConfigurationError):
def __init__( def __init__(
self, self,
workspace_id: UUID, workspace_id: UUID,
config_id: Optional[int] = None, config_id: Optional[UUID] = None,
message: Optional[str] = None, message: Optional[str] = None,
): ):
if message is None: if message is None:
@@ -89,7 +89,7 @@ class ModelNotFoundError(ConfigurationError):
self, self,
model_id: Union[str, UUID], model_id: Union[str, UUID],
model_type: str, model_type: str,
config_id: Optional[int] = None, config_id: Optional[UUID] = None,
workspace_id: Optional[UUID] = None, workspace_id: Optional[UUID] = None,
message: Optional[str] = None, message: Optional[str] = None,
): ):
@@ -112,7 +112,7 @@ class ModelInactiveError(ConfigurationError):
model_id: Union[str, UUID], model_id: Union[str, UUID],
model_name: str, model_name: str,
model_type: str, model_type: str,
config_id: Optional[int] = None, config_id: Optional[UUID] = None,
workspace_id: Optional[UUID] = None, workspace_id: Optional[UUID] = None,
message: Optional[str] = None, message: Optional[str] = None,
): ):
@@ -136,7 +136,7 @@ class InvalidConfigError(ConfigurationError):
message: str, message: str,
field_name: Optional[str] = None, field_name: Optional[str] = None,
invalid_value: Optional[Any] = None, invalid_value: Optional[Any] = None,
config_id: Optional[int] = None, config_id: Optional[UUID] = None,
workspace_id: Optional[UUID] = None, workspace_id: Optional[UUID] = None,
): ):
context = {} context = {}
@@ -155,7 +155,7 @@ class InvalidConfigError(ConfigurationError):
class MemoryConfigValidation(BaseModel): class MemoryConfigValidation(BaseModel):
"""Pydantic model for validating memory configuration data from database.""" """Pydantic model for validating memory configuration data from database."""
config_id: int = Field(..., gt=0, description="Configuration ID must be positive") config_id: UUID = Field(..., description="Configuration ID (UUID)")
config_name: str = Field(..., min_length=1, max_length=255) config_name: str = Field(..., min_length=1, max_length=255)
workspace_id: UUID = Field(..., description="Workspace UUID") workspace_id: UUID = Field(..., description="Workspace UUID")
workspace_name: str = Field(..., min_length=1, max_length=255) workspace_name: str = Field(..., min_length=1, max_length=255)
@@ -275,7 +275,7 @@ class ModelValidation(BaseModel):
def validate_memory_config_data( def validate_memory_config_data(
config_data: Dict[str, Any], config_id: Optional[int] = None config_data: Dict[str, Any], config_id: Optional[UUID] = None
) -> MemoryConfigValidation: ) -> MemoryConfigValidation:
"""Validate memory configuration data using Pydantic model.""" """Validate memory configuration data using Pydantic model."""
try: try:
@@ -302,7 +302,7 @@ def validate_memory_config_data(
def validate_workspace_data( def validate_workspace_data(
workspace_data: Dict[str, Any], config_id: Optional[int] = None workspace_data: Dict[str, Any], config_id: Optional[UUID] = None
) -> WorkspaceValidation: ) -> WorkspaceValidation:
"""Validate workspace data using Pydantic model.""" """Validate workspace data using Pydantic model."""
try: try:
@@ -331,7 +331,7 @@ def validate_workspace_data(
def validate_model_data( def validate_model_data(
model_data: Dict[str, Any], config_id: Optional[int] = None model_data: Dict[str, Any], config_id: Optional[UUID] = None
) -> ModelValidation: ) -> ModelValidation:
"""Validate model data using Pydantic model.""" """Validate model data using Pydantic model."""
try: try:
@@ -364,7 +364,7 @@ def validate_model_data(
class MemoryConfig: class MemoryConfig:
"""Immutable memory configuration loaded from database.""" """Immutable memory configuration loaded from database."""
config_id: int config_id: UUID
config_name: str config_name: str
workspace_id: UUID workspace_id: UUID
workspace_name: str workspace_name: str

View File

@@ -4,7 +4,7 @@ from typing import Optional
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from app.models.memory_perceptual_model import PerceptualType, FileStorageType from app.models.memory_perceptual_model import PerceptualType, FileStorageService
class PerceptualFilter(BaseModel): class PerceptualFilter(BaseModel):
@@ -38,12 +38,14 @@ class PerceptualMemoryItem(BaseModel):
"""感知记忆项""" """感知记忆项"""
id: uuid.UUID = Field(..., description="Unique memory ID") id: uuid.UUID = Field(..., description="Unique memory ID")
perceptual_type: PerceptualType = Field(..., description="Type of perception, e.g., text, audio, or video") perceptual_type: PerceptualType = Field(..., description="Type of perception, e.g., text, audio, or video")
storage_service: FileStorageService = Field(..., description="Storage service for file")
file_path: str = Field(..., description="File path in the storage service") file_path: str = Field(..., description="File path in the storage service")
file_ext: str = Field(..., description="File extension")
file_name: str = Field(..., description="File name") file_name: str = Field(..., description="File name")
file_ext: str = Field(..., description="File extension")
summary: Optional[str] = Field(None, description="summary") summary: Optional[str] = Field(None, description="summary")
storage_type: FileStorageType = Field(..., description="Storage type for file") meta_data: Optional[dict] = Field(None, description="Metadata information")
created_time: int = Field(..., description="create time") created_time: int = Field(..., description="create time")
topic: str = Field(..., description="topic") topic: str = Field(..., description="topic")
domain: str = Field(..., description="domain") domain: str = Field(..., description="domain")
keywords: list[str] = Field(..., description="keywords") keywords: list[str] = Field(..., description="keywords")

View File

@@ -1,5 +1,6 @@
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from typing import Optional from typing import Optional
from uuid import UUID
from enum import Enum from enum import Enum
@@ -9,7 +10,7 @@ class OptimizationStrategy(str, Enum):
ACCURACY_FIRST = "accuracy_first" ACCURACY_FIRST = "accuracy_first"
BALANCED = "balanced" BALANCED = "balanced"
class Memory_Reflection(BaseModel): class Memory_Reflection(BaseModel):
config_id: Optional[int] = None config_id: Optional[UUID] = None
reflection_enabled: bool reflection_enabled: bool
reflection_period_in_hours: str reflection_period_in_hours: str
reflexion_range: Optional[str] = "partial" reflexion_range: Optional[str] = "partial"

View File

@@ -1,5 +1,5 @@
""" """
所有的内容是放错误地方了应该放在models
""" """
from typing import Any, Optional, List, Dict, Literal, Union from typing import Any, Optional, List, Dict, Literal, Union
@@ -8,20 +8,8 @@ import uuid
from pydantic import BaseModel, Field, ConfigDict, field_validator, model_validator from pydantic import BaseModel, Field, ConfigDict, field_validator, model_validator
# ============================================================================
# 原 UserInput 相关 Schema (保留原有功能)
# ============================================================================
class UserInput(BaseModel):
message: str
history: list[dict]
search_switch: str
group_id: str
class Write_UserInput(BaseModel):
message: str
group_id: str
# ============================================================================ # ============================================================================
# 从 json_schema.py 迁移的 Schema # 从 json_schema.py 迁移的 Schema
@@ -159,7 +147,7 @@ class ReflexionResultSchema(BaseModel):
# Composite key identifying a config row # Composite key identifying a config row
class ConfigKey(BaseModel): # 配置参数键模型 class ConfigKey(BaseModel): # 配置参数键模型
model_config = ConfigDict(populate_by_name=True, extra="forbid") model_config = ConfigDict(populate_by_name=True, extra="forbid")
config_id: int = Field("config_id", description="配置唯一标识(字符串") config_id: uuid.UUID = Field("config_id", description="配置唯一标识(UUID")
user_id: str = Field("user_id", description="用户标识(字符串)") user_id: str = Field("user_id", description="用户标识(字符串)")
apply_id: str = Field("apply_id", description="应用或场景标识(字符串)") apply_id: str = Field("apply_id", description="应用或场景标识(字符串)")
@@ -250,17 +238,17 @@ class ConfigParamsCreate(BaseModel): # 创建配置参数模型(仅 body
class ConfigParamsDelete(BaseModel): # 删除配置参数模型(请求体) class ConfigParamsDelete(BaseModel): # 删除配置参数模型(请求体)
model_config = ConfigDict(populate_by_name=True, extra="forbid") model_config = ConfigDict(populate_by_name=True, extra="forbid")
# config_name: str = Field("配置名称", description="配置名称(字符串)") # config_name: str = Field("配置名称", description="配置名称(字符串)")
config_id: int = Field("配置ID", description="配置ID字符串") config_id: uuid.UUID = Field("配置ID", description="配置IDUUID")
class ConfigUpdate(BaseModel): # 更新记忆萃取引擎配置参数时使用的模型 class ConfigUpdate(BaseModel): # 更新记忆萃取引擎配置参数时使用的模型
config_id: Optional[int] = None config_id: Optional[uuid.UUID] = None
config_name: str = Field("配置名称", description="配置名称(字符串)") config_name: str = Field("配置名称", description="配置名称(字符串)")
config_desc: str = Field("配置描述", description="配置描述(字符串)") config_desc: str = Field("配置描述", description="配置描述(字符串)")
class ConfigUpdateExtracted(BaseModel): # 更新记忆萃取引擎配置参数时使用的模型 class ConfigUpdateExtracted(BaseModel): # 更新记忆萃取引擎配置参数时使用的模型
config_id: Optional[int] = None config_id: Optional[uuid.UUID] = None
llm_id: Optional[str] = Field(None, description="LLM模型配置ID") llm_id: Optional[str] = Field(None, description="LLM模型配置ID")
embedding_id: Optional[str] = Field(None, description="嵌入模型配置ID") embedding_id: Optional[str] = Field(None, description="嵌入模型配置ID")
rerank_id: Optional[str] = Field(None, description="重排序模型配置ID") rerank_id: Optional[str] = Field(None, description="重排序模型配置ID")
@@ -327,14 +315,14 @@ class ConfigUpdateExtracted(BaseModel): # 更新记忆萃取引擎配置参数
class ConfigUpdateForget(BaseModel): # 更新遗忘引擎配置参数时使用的模型 class ConfigUpdateForget(BaseModel): # 更新遗忘引擎配置参数时使用的模型
# 遗忘引擎配置参数更新模型 # 遗忘引擎配置参数更新模型
config_id: Optional[int] = None config_id: Optional[uuid.UUID] = None
lambda_time: Optional[float] = Field(0.5, ge=0.0, le=1.0, description="最低保持度0-1 小数;默认 0.5") lambda_time: Optional[float] = Field(0.5, ge=0.0, le=1.0, description="最低保持度0-1 小数;默认 0.5")
lambda_mem: Optional[float] = Field(0.5, ge=0.0, le=1.0, description="遗忘率0-1 小数;默认 0.5") lambda_mem: Optional[float] = Field(0.5, ge=0.0, le=1.0, description="遗忘率0-1 小数;默认 0.5")
offset: Optional[float] = Field(0.0, ge=0.0, le=1.0, description="偏移度0-1 小数;默认 0.0") offset: Optional[float] = Field(0.0, ge=0.0, le=1.0, description="偏移度0-1 小数;默认 0.0")
class ConfigPilotRun(BaseModel): # 试运行触发请求模型 class ConfigPilotRun(BaseModel): # 试运行触发请求模型
config_id: int = Field(..., description="配置ID唯一") config_id: uuid.UUID = Field(..., description="配置ID唯一")
dialogue_text: str = Field(..., description="前端传入的对话文本,格式如 '用户: ...\nAI: ...' 可多行,试运行必填") dialogue_text: str = Field(..., description="前端传入的对话文本,格式如 '用户: ...\nAI: ...' 可多行,试运行必填")
model_config = ConfigDict(populate_by_name=True, extra="forbid") model_config = ConfigDict(populate_by_name=True, extra="forbid")
@@ -342,7 +330,7 @@ class ConfigPilotRun(BaseModel): # 试运行触发请求模型
class ConfigFilter(BaseModel): # 查询配置参数时使用的模型 class ConfigFilter(BaseModel): # 查询配置参数时使用的模型
model_config = ConfigDict(populate_by_name=True, extra="forbid") model_config = ConfigDict(populate_by_name=True, extra="forbid")
config_id: Optional[int] = None config_id: Optional[uuid.UUID] = None
user_id: Optional[str] = None user_id: Optional[str] = None
apply_id: Optional[str] = None apply_id: Optional[str] = None
@@ -418,7 +406,7 @@ class ForgettingConfigResponse(BaseModel):
"""遗忘引擎配置响应模型""" """遗忘引擎配置响应模型"""
model_config = ConfigDict(populate_by_name=True, extra="forbid") model_config = ConfigDict(populate_by_name=True, extra="forbid")
config_id: int = Field(..., description="配置ID") config_id: uuid.UUID = Field(..., description="配置ID")
decay_constant: float = Field(..., description="衰减常数 d") decay_constant: float = Field(..., description="衰减常数 d")
lambda_time: float = Field(..., description="时间衰减参数") lambda_time: float = Field(..., description="时间衰减参数")
lambda_mem: float = Field(..., description="记忆衰减参数") lambda_mem: float = Field(..., description="记忆衰减参数")
@@ -436,7 +424,7 @@ class ForgettingConfigUpdateRequest(BaseModel):
"""遗忘引擎配置更新请求模型""" """遗忘引擎配置更新请求模型"""
model_config = ConfigDict(populate_by_name=True, extra="forbid") model_config = ConfigDict(populate_by_name=True, extra="forbid")
config_id: int = Field(..., description="配置ID") config_id: uuid.UUID = Field(..., description="配置ID")
decay_constant: Optional[float] = Field(None, ge=0.0, le=1.0, description="衰减常数 d") decay_constant: Optional[float] = Field(None, ge=0.0, le=1.0, description="衰减常数 d")
lambda_time: Optional[float] = Field(None, ge=0.0, le=1.0, description="时间衰减参数") lambda_time: Optional[float] = Field(None, ge=0.0, le=1.0, description="时间衰减参数")
lambda_mem: Optional[float] = Field(None, ge=0.0, le=1.0, description="记忆衰减参数") lambda_mem: Optional[float] = Field(None, ge=0.0, le=1.0, description="记忆衰减参数")
@@ -511,7 +499,7 @@ class ForgettingCurveRequest(BaseModel):
importance_score: float = Field(0.5, ge=0.0, le=1.0, description="重要性分数0-1") importance_score: float = Field(0.5, ge=0.0, le=1.0, description="重要性分数0-1")
days: int = Field(60, ge=1, le=365, description="模拟天数默认60天") days: int = Field(60, ge=1, le=365, description="模拟天数默认60天")
config_id: Optional[int] = Field(None, description="配置ID可选如果为None则使用默认配置") config_id: Optional[uuid.UUID] = Field(None, description="配置ID可选如果为None则使用默认配置")
class ForgettingCurveResponse(BaseModel): class ForgettingCurveResponse(BaseModel):

View File

@@ -1,4 +1,4 @@
from pydantic import BaseModel, Field, field_serializer, ConfigDict from pydantic import BaseModel, Field, field_serializer, field_validator, ConfigDict
from typing import Optional, List, Dict, Any from typing import Optional, List, Dict, Any
import datetime import datetime
import uuid import uuid
@@ -91,6 +91,18 @@ class ModelApiKey(ModelApiKeyBase):
created_at: datetime.datetime created_at: datetime.datetime
updated_at: datetime.datetime updated_at: datetime.datetime
@field_validator("config", mode="before")
@classmethod
def parse_config(cls, v):
"""处理 config 字段,如果是字符串则解析为字典"""
if isinstance(v, str):
import json
try:
return json.loads(v)
except json.JSONDecodeError:
return {}
return v
@field_serializer("created_at", when_used="json") @field_serializer("created_at", when_used="json")
def _serialize_created_at(self, dt: datetime.datetime): def _serialize_created_at(self, dt: datetime.datetime):
return int(dt.timestamp() * 1000) if dt else None return int(dt.timestamp() * 1000) if dt else None

View File

@@ -1,7 +1,7 @@
import uuid import uuid
import datetime import datetime
from typing import Optional, List, Dict, Any from typing import Optional, List, Dict, Any
from pydantic import BaseModel, Field, ConfigDict, field_serializer from pydantic import BaseModel, Field, ConfigDict, field_serializer, field_validator
# ---------- Input Schemas ---------- # ---------- Input Schemas ----------
@@ -88,6 +88,18 @@ class SharedReleaseInfo(BaseModel):
# 嵌入配置 # 嵌入配置
allow_embed: bool allow_embed: bool
@field_validator("config", mode="before")
@classmethod
def parse_config(cls, v):
"""处理 config 字段,如果是字符串则解析为字典"""
if isinstance(v, str):
import json
try:
return json.loads(v)
except json.JSONDecodeError:
return {}
return v if v is not None else {}
class EmbedCode(BaseModel): class EmbedCode(BaseModel):
"""嵌入代码""" """嵌入代码"""

View File

@@ -92,7 +92,7 @@ def create_long_term_memory_tool(memory_config: Dict[str, Any], end_user_id: str
try: try:
memory_content = asyncio.run( memory_content = asyncio.run(
MemoryAgentService().read_memory( MemoryAgentService().read_memory(
group_id=end_user_id, end_user_id=end_user_id,
message=question, message=question,
history=[], history=[],
search_switch="2", search_switch="2",

View File

@@ -75,7 +75,7 @@ class EmotionAnalyticsService:
# 调用仓储层查询 # 调用仓储层查询
tags = await self.emotion_repo.get_emotion_tags( tags = await self.emotion_repo.get_emotion_tags(
group_id=end_user_id, end_user_id=end_user_id,
emotion_type=emotion_type, emotion_type=emotion_type,
start_date=start_date, start_date=start_date,
end_date=end_date, end_date=end_date,
@@ -157,7 +157,7 @@ class EmotionAnalyticsService:
# 调用仓储层查询 # 调用仓储层查询
keywords = await self.emotion_repo.get_emotion_wordcloud( keywords = await self.emotion_repo.get_emotion_wordcloud(
group_id=end_user_id, end_user_id=end_user_id,
emotion_type=emotion_type, emotion_type=emotion_type,
limit=limit limit=limit
) )
@@ -339,7 +339,7 @@ class EmotionAnalyticsService:
# 获取时间范围内的情绪数据 # 获取时间范围内的情绪数据
emotions = await self.emotion_repo.get_emotions_in_range( emotions = await self.emotion_repo.get_emotions_in_range(
group_id=end_user_id, end_user_id=end_user_id,
time_range=time_range time_range=time_range
) )
@@ -505,7 +505,7 @@ class EmotionAnalyticsService:
) )
config_service = MemoryConfigService(db) config_service = MemoryConfigService(db)
memory_config = config_service.load_memory_config( memory_config = config_service.load_memory_config(
config_id=int(config_id), config_id=(config_id),
service_name="EmotionAnalyticsService.generate_emotion_suggestions" service_name="EmotionAnalyticsService.generate_emotion_suggestions"
) )
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
@@ -519,7 +519,7 @@ class EmotionAnalyticsService:
# 3. 获取情绪数据用于模式分析 # 3. 获取情绪数据用于模式分析
emotions = await self.emotion_repo.get_emotions_in_range( emotions = await self.emotion_repo.get_emotions_in_range(
group_id=end_user_id, end_user_id=end_user_id,
time_range="30d" time_range="30d"
) )
@@ -598,13 +598,13 @@ class EmotionAnalyticsService:
# 查询用户的实体和标签 # 查询用户的实体和标签
query = """ query = """
MATCH (e:Entity) MATCH (e:Entity)
WHERE e.group_id = $group_id WHERE e.end_user_id = $end_user_id
RETURN e.name as name, e.type as type RETURN e.name as name, e.type as type
ORDER BY e.created_at DESC ORDER BY e.created_at DESC
LIMIT 20 LIMIT 20
""" """
entities = await connector.execute_query(query, group_id=end_user_id) entities = await connector.execute_query(query, end_user_id=end_user_id)
# 提取兴趣标签 # 提取兴趣标签
interests = [e["name"] for e in entities if e.get("type") in ["INTEREST", "HOBBY"]][:5] interests = [e["name"] for e in entities if e.get("type") in ["INTEREST", "HOBBY"]][:5]

View File

@@ -8,9 +8,11 @@ Classes:
""" """
from typing import Dict, Any from typing import Dict, Any
from uuid import UUID
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from app.models.data_config_model import DataConfig from app.models.memory_config_model import MemoryConfig
from app.core.logging_config import get_business_logger from app.core.logging_config import get_business_logger
logger = get_business_logger() logger = get_business_logger()
@@ -37,7 +39,7 @@ class EmotionConfigService:
self.db = db self.db = db
logger.info("情绪配置服务初始化完成") logger.info("情绪配置服务初始化完成")
def get_emotion_config(self, config_id: int) -> Dict[str, Any]: def get_emotion_config(self, config_id: UUID) -> Dict[str, Any]:
"""获取情绪引擎配置 """获取情绪引擎配置
查询指定配置ID的情绪相关配置字段。 查询指定配置ID的情绪相关配置字段。
@@ -61,8 +63,8 @@ class EmotionConfigService:
logger.info(f"获取情绪配置: config_id={config_id}") logger.info(f"获取情绪配置: config_id={config_id}")
# 查询配置 # 查询配置
config = self.db.query(DataConfig).filter( config = self.db.query(MemoryConfig).filter(
DataConfig.config_id == config_id MemoryConfig.config_id == config_id
).first() ).first()
if not config: if not config:
@@ -144,7 +146,7 @@ class EmotionConfigService:
def update_emotion_config( def update_emotion_config(
self, self,
config_id: int, config_id: UUID,
config_data: Dict[str, Any] config_data: Dict[str, Any]
) -> Dict[str, Any]: ) -> Dict[str, Any]:
"""更新情绪引擎配置 """更新情绪引擎配置
@@ -173,8 +175,8 @@ class EmotionConfigService:
self.validate_emotion_config(config_data) self.validate_emotion_config(config_data)
# 查询配置 # 查询配置
config = self.db.query(DataConfig).filter( config = self.db.query(MemoryConfig).filter(
DataConfig.config_id == config_id MemoryConfig.config_id == config_id
).first() ).first()
if not config: if not config:

Some files were not shown because too many files have changed in this diff Show More