Fix/memory bug fix (#171)
This commit is contained in:
0
api/app/__init__.py
Normal file
0
api/app/__init__.py
Normal file
@@ -12,6 +12,7 @@ from fastapi import APIRouter, Depends, Query, HTTPException, status
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import Optional
|
||||
from sqlalchemy.orm import Session
|
||||
from uuid import UUID
|
||||
|
||||
from app.core.response_utils import success
|
||||
from app.dependencies import get_current_user
|
||||
@@ -32,11 +33,11 @@ router = APIRouter(
|
||||
|
||||
class EmotionConfigQuery(BaseModel):
|
||||
"""情绪配置查询请求模型"""
|
||||
config_id: int = Field(..., description="配置ID")
|
||||
config_id: UUID = Field(..., description="配置ID")
|
||||
|
||||
class EmotionConfigUpdate(BaseModel):
|
||||
"""情绪配置更新请求模型"""
|
||||
config_id: int = Field(..., description="配置ID")
|
||||
config_id: UUID = Field(..., description="配置ID")
|
||||
emotion_enabled: bool = Field(..., description="是否启用情绪提取")
|
||||
emotion_model_id: Optional[str] = Field(None, description="情绪分析专用模型ID")
|
||||
emotion_extract_keywords: bool = Field(..., description="是否提取情绪关键词")
|
||||
@@ -45,7 +46,7 @@ class EmotionConfigUpdate(BaseModel):
|
||||
|
||||
@router.get("/read_config", response_model=ApiResponse)
|
||||
def get_emotion_config(
|
||||
config_id: int = Query(..., description="配置ID"),
|
||||
config_id: UUID = Query(..., description="配置ID"),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
|
||||
@@ -53,7 +53,7 @@ async def get_emotion_tags(
|
||||
api_logger.info(
|
||||
f"用户 {current_user.username} 请求获取情绪标签统计",
|
||||
extra={
|
||||
"group_id": request.group_id,
|
||||
"end_user_id": request.end_user_id,
|
||||
"emotion_type": request.emotion_type,
|
||||
"start_date": request.start_date,
|
||||
"end_date": request.end_date,
|
||||
@@ -63,7 +63,7 @@ async def get_emotion_tags(
|
||||
|
||||
# 调用服务层
|
||||
data = await emotion_service.get_emotion_tags(
|
||||
end_user_id=request.group_id,
|
||||
end_user_id=request.end_user_id,
|
||||
emotion_type=request.emotion_type,
|
||||
start_date=request.start_date,
|
||||
end_date=request.end_date,
|
||||
@@ -73,7 +73,7 @@ async def get_emotion_tags(
|
||||
api_logger.info(
|
||||
"情绪标签统计获取成功",
|
||||
extra={
|
||||
"group_id": request.group_id,
|
||||
"end_user_id": request.end_user_id,
|
||||
"total_count": data.get("total_count", 0),
|
||||
"tags_count": len(data.get("tags", []))
|
||||
}
|
||||
@@ -84,7 +84,7 @@ async def get_emotion_tags(
|
||||
except Exception as e:
|
||||
api_logger.error(
|
||||
f"获取情绪标签统计失败: {str(e)}",
|
||||
extra={"group_id": request.group_id},
|
||||
extra={"end_user_id": request.end_user_id},
|
||||
exc_info=True
|
||||
)
|
||||
raise HTTPException(
|
||||
@@ -105,7 +105,7 @@ async def get_emotion_wordcloud(
|
||||
api_logger.info(
|
||||
f"用户 {current_user.username} 请求获取情绪词云数据",
|
||||
extra={
|
||||
"group_id": request.group_id,
|
||||
"end_user_id": request.end_user_id,
|
||||
"emotion_type": request.emotion_type,
|
||||
"limit": request.limit
|
||||
}
|
||||
@@ -113,7 +113,7 @@ async def get_emotion_wordcloud(
|
||||
|
||||
# 调用服务层
|
||||
data = await emotion_service.get_emotion_wordcloud(
|
||||
end_user_id=request.group_id,
|
||||
end_user_id=request.end_user_id,
|
||||
emotion_type=request.emotion_type,
|
||||
limit=request.limit
|
||||
)
|
||||
@@ -121,7 +121,7 @@ async def get_emotion_wordcloud(
|
||||
api_logger.info(
|
||||
"情绪词云数据获取成功",
|
||||
extra={
|
||||
"group_id": request.group_id,
|
||||
"end_user_id": request.end_user_id,
|
||||
"total_keywords": data.get("total_keywords", 0)
|
||||
}
|
||||
)
|
||||
@@ -131,7 +131,7 @@ async def get_emotion_wordcloud(
|
||||
except Exception as e:
|
||||
api_logger.error(
|
||||
f"获取情绪词云数据失败: {str(e)}",
|
||||
extra={"group_id": request.group_id},
|
||||
extra={"end_user_id": request.end_user_id},
|
||||
exc_info=True
|
||||
)
|
||||
raise HTTPException(
|
||||
@@ -159,21 +159,21 @@ async def get_emotion_health(
|
||||
api_logger.info(
|
||||
f"用户 {current_user.username} 请求获取情绪健康指数",
|
||||
extra={
|
||||
"group_id": request.group_id,
|
||||
"end_user_id": request.end_user_id,
|
||||
"time_range": request.time_range
|
||||
}
|
||||
)
|
||||
|
||||
# 调用服务层
|
||||
data = await emotion_service.calculate_emotion_health_index(
|
||||
end_user_id=request.group_id,
|
||||
end_user_id=request.end_user_id,
|
||||
time_range=request.time_range
|
||||
)
|
||||
|
||||
api_logger.info(
|
||||
"情绪健康指数获取成功",
|
||||
extra={
|
||||
"group_id": request.group_id,
|
||||
"end_user_id": request.end_user_id,
|
||||
"health_score": data.get("health_score", 0),
|
||||
"level": data.get("level", "未知")
|
||||
}
|
||||
@@ -186,7 +186,7 @@ async def get_emotion_health(
|
||||
except Exception as e:
|
||||
api_logger.error(
|
||||
f"获取情绪健康指数失败: {str(e)}",
|
||||
extra={"group_id": request.group_id},
|
||||
extra={"end_user_id": request.end_user_id},
|
||||
exc_info=True
|
||||
)
|
||||
raise HTTPException(
|
||||
@@ -206,7 +206,7 @@ async def get_emotion_suggestions(
|
||||
"""获取个性化情绪建议(从缓存读取)
|
||||
|
||||
Args:
|
||||
request: 包含 group_id 和可选的 config_id
|
||||
request: 包含 end_user_id 和可选的 config_id
|
||||
db: 数据库会话
|
||||
current_user: 当前用户
|
||||
|
||||
@@ -217,22 +217,22 @@ async def get_emotion_suggestions(
|
||||
api_logger.info(
|
||||
f"用户 {current_user.username} 请求获取个性化情绪建议(缓存)",
|
||||
extra={
|
||||
"group_id": request.group_id,
|
||||
"end_user_id": request.end_user_id,
|
||||
"config_id": request.config_id
|
||||
}
|
||||
)
|
||||
|
||||
# 从缓存获取建议
|
||||
data = await emotion_service.get_cached_suggestions(
|
||||
end_user_id=request.group_id,
|
||||
end_user_id=request.end_user_id,
|
||||
db=db
|
||||
)
|
||||
|
||||
if data is None:
|
||||
# 缓存不存在或已过期
|
||||
api_logger.info(
|
||||
f"用户 {request.group_id} 的建议缓存不存在或已过期",
|
||||
extra={"group_id": request.group_id}
|
||||
f"用户 {request.end_user_id} 的建议缓存不存在或已过期",
|
||||
extra={"end_user_id": request.end_user_id}
|
||||
)
|
||||
return fail(
|
||||
BizCode.NOT_FOUND,
|
||||
@@ -243,7 +243,7 @@ async def get_emotion_suggestions(
|
||||
api_logger.info(
|
||||
"个性化建议获取成功(缓存)",
|
||||
extra={
|
||||
"group_id": request.group_id,
|
||||
"end_user_id": request.end_user_id,
|
||||
"suggestions_count": len(data.get("suggestions", []))
|
||||
}
|
||||
)
|
||||
@@ -253,7 +253,7 @@ async def get_emotion_suggestions(
|
||||
except Exception as e:
|
||||
api_logger.error(
|
||||
f"获取个性化建议失败: {str(e)}",
|
||||
extra={"group_id": request.group_id},
|
||||
extra={"end_user_id": request.end_user_id},
|
||||
exc_info=True
|
||||
)
|
||||
raise HTTPException(
|
||||
|
||||
@@ -122,10 +122,10 @@ def validate_confidence_threshold(threshold: float) -> None:
|
||||
raise ValueError("confidence_threshold must be between 0.0 and 1.0")
|
||||
|
||||
|
||||
@router.get("/preferences/{user_id}", response_model=ApiResponse)
|
||||
@router.get("/preferences/{end_user_id}", response_model=ApiResponse)
|
||||
@cur_workspace_access_guard()
|
||||
async def get_preference_tags(
|
||||
user_id: str,
|
||||
end_user_id: str,
|
||||
confidence_threshold: float = Query(0.5, ge=0.0, le=1.0, description="Minimum confidence threshold"),
|
||||
tag_category: Optional[str] = Query(None, description="Filter by tag category"),
|
||||
start_date: Optional[datetime] = Query(None, description="Filter start date"),
|
||||
@@ -137,7 +137,7 @@ async def get_preference_tags(
|
||||
Get user preference tags from cache.
|
||||
|
||||
Args:
|
||||
user_id: Target user ID
|
||||
end_user_id: Target end user ID
|
||||
confidence_threshold: Minimum confidence score (0.0-1.0)
|
||||
tag_category: Optional category filter
|
||||
start_date: Optional start date filter
|
||||
@@ -146,20 +146,20 @@ async def get_preference_tags(
|
||||
Returns:
|
||||
List of preference tags from cache
|
||||
"""
|
||||
api_logger.info(f"Preference tags requested for user: {user_id} (from cache)")
|
||||
api_logger.info(f"Preference tags requested for user: {end_user_id} (from cache)")
|
||||
|
||||
try:
|
||||
# Validate inputs
|
||||
validate_user_id(user_id)
|
||||
validate_user_id(end_user_id)
|
||||
|
||||
# Create service with user-specific config
|
||||
service = ImplicitMemoryService(db=db, end_user_id=user_id)
|
||||
service = ImplicitMemoryService(db=db, end_user_id=end_user_id)
|
||||
|
||||
# Get cached profile
|
||||
cached_profile = await service.get_cached_profile(end_user_id=user_id, db=db)
|
||||
cached_profile = await service.get_cached_profile(end_user_id=end_user_id, db=db)
|
||||
|
||||
if cached_profile is None:
|
||||
api_logger.info(f"用户 {user_id} 的画像缓存不存在或已过期")
|
||||
api_logger.info(f"用户 {end_user_id} 的画像缓存不存在或已过期")
|
||||
return fail(
|
||||
BizCode.NOT_FOUND,
|
||||
"画像缓存不存在或已过期,请右上角刷新生成新画像",
|
||||
@@ -192,17 +192,17 @@ async def get_preference_tags(
|
||||
|
||||
filtered_preferences.append(pref)
|
||||
|
||||
api_logger.info(f"Retrieved {len(filtered_preferences)} preference tags for user: {user_id} (from cache)")
|
||||
api_logger.info(f"Retrieved {len(filtered_preferences)} preference tags for user: {end_user_id} (from cache)")
|
||||
return success(data=filtered_preferences, msg="偏好标签获取成功(缓存)")
|
||||
|
||||
except Exception as e:
|
||||
return handle_implicit_memory_error(e, "偏好标签获取", user_id)
|
||||
return handle_implicit_memory_error(e, "偏好标签获取", end_user_id)
|
||||
|
||||
|
||||
@router.get("/portrait/{user_id}", response_model=ApiResponse)
|
||||
@router.get("/portrait/{end_user_id}", response_model=ApiResponse)
|
||||
@cur_workspace_access_guard()
|
||||
async def get_dimension_portrait(
|
||||
user_id: str,
|
||||
end_user_id: str,
|
||||
include_history: bool = Query(False, description="Include historical trends"),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
@@ -211,26 +211,26 @@ async def get_dimension_portrait(
|
||||
Get user's four-dimension personality portrait from cache.
|
||||
|
||||
Args:
|
||||
user_id: Target user ID
|
||||
end_user_id: Target end user ID
|
||||
include_history: Whether to include historical trend data (ignored for cached data)
|
||||
|
||||
Returns:
|
||||
Four-dimension personality portrait from cache
|
||||
"""
|
||||
api_logger.info(f"Dimension portrait requested for user: {user_id} (from cache)")
|
||||
api_logger.info(f"Dimension portrait requested for user: {end_user_id} (from cache)")
|
||||
|
||||
try:
|
||||
# Validate inputs
|
||||
validate_user_id(user_id)
|
||||
validate_user_id(end_user_id)
|
||||
|
||||
# Create service with user-specific config
|
||||
service = ImplicitMemoryService(db=db, end_user_id=user_id)
|
||||
service = ImplicitMemoryService(db=db, end_user_id=end_user_id)
|
||||
|
||||
# Get cached profile
|
||||
cached_profile = await service.get_cached_profile(end_user_id=user_id, db=db)
|
||||
cached_profile = await service.get_cached_profile(end_user_id=end_user_id, db=db)
|
||||
|
||||
if cached_profile is None:
|
||||
api_logger.info(f"用户 {user_id} 的画像缓存不存在或已过期")
|
||||
api_logger.info(f"用户 {end_user_id} 的画像缓存不存在或已过期")
|
||||
return fail(
|
||||
BizCode.NOT_FOUND,
|
||||
"画像缓存不存在或已过期,请右上角刷新生成新画像",
|
||||
@@ -240,17 +240,17 @@ async def get_dimension_portrait(
|
||||
# Extract portrait from cache
|
||||
portrait = cached_profile.get("portrait", {})
|
||||
|
||||
api_logger.info(f"Dimension portrait retrieved for user: {user_id} (from cache)")
|
||||
api_logger.info(f"Dimension portrait retrieved for user: {end_user_id} (from cache)")
|
||||
return success(data=portrait, msg="四维画像获取成功(缓存)")
|
||||
|
||||
except Exception as e:
|
||||
return handle_implicit_memory_error(e, "四维画像获取", user_id)
|
||||
return handle_implicit_memory_error(e, "四维画像获取", end_user_id)
|
||||
|
||||
|
||||
@router.get("/interest-areas/{user_id}", response_model=ApiResponse)
|
||||
@router.get("/interest-areas/{end_user_id}", response_model=ApiResponse)
|
||||
@cur_workspace_access_guard()
|
||||
async def get_interest_area_distribution(
|
||||
user_id: str,
|
||||
end_user_id: str,
|
||||
include_trends: bool = Query(False, description="Include trend analysis"),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
@@ -259,26 +259,26 @@ async def get_interest_area_distribution(
|
||||
Get user's interest area distribution from cache.
|
||||
|
||||
Args:
|
||||
user_id: Target user ID
|
||||
end_user_id: Target end user ID
|
||||
include_trends: Whether to include trend analysis data (ignored for cached data)
|
||||
|
||||
Returns:
|
||||
Interest area distribution from cache
|
||||
"""
|
||||
api_logger.info(f"Interest area distribution requested for user: {user_id} (from cache)")
|
||||
api_logger.info(f"Interest area distribution requested for user: {end_user_id} (from cache)")
|
||||
|
||||
try:
|
||||
# Validate inputs
|
||||
validate_user_id(user_id)
|
||||
validate_user_id(end_user_id)
|
||||
|
||||
# Create service with user-specific config
|
||||
service = ImplicitMemoryService(db=db, end_user_id=user_id)
|
||||
service = ImplicitMemoryService(db=db, end_user_id=end_user_id)
|
||||
|
||||
# Get cached profile
|
||||
cached_profile = await service.get_cached_profile(end_user_id=user_id, db=db)
|
||||
cached_profile = await service.get_cached_profile(end_user_id=end_user_id, db=db)
|
||||
|
||||
if cached_profile is None:
|
||||
api_logger.info(f"用户 {user_id} 的画像缓存不存在或已过期")
|
||||
api_logger.info(f"用户 {end_user_id} 的画像缓存不存在或已过期")
|
||||
return fail(
|
||||
BizCode.NOT_FOUND,
|
||||
"画像缓存不存在或已过期,请右上角刷新生成新画像",
|
||||
@@ -288,17 +288,17 @@ async def get_interest_area_distribution(
|
||||
# Extract interest areas from cache
|
||||
interest_areas = cached_profile.get("interest_areas", {})
|
||||
|
||||
api_logger.info(f"Interest area distribution retrieved for user: {user_id} (from cache)")
|
||||
api_logger.info(f"Interest area distribution retrieved for user: {end_user_id} (from cache)")
|
||||
return success(data=interest_areas, msg="兴趣领域分布获取成功(缓存)")
|
||||
|
||||
except Exception as e:
|
||||
return handle_implicit_memory_error(e, "兴趣领域分布获取", user_id)
|
||||
return handle_implicit_memory_error(e, "兴趣领域分布获取", end_user_id)
|
||||
|
||||
|
||||
@router.get("/habits/{user_id}", response_model=ApiResponse)
|
||||
@router.get("/habits/{end_user_id}", response_model=ApiResponse)
|
||||
@cur_workspace_access_guard()
|
||||
async def get_behavior_habits(
|
||||
user_id: str,
|
||||
end_user_id: str,
|
||||
confidence_level: Optional[str] = Query(None, regex="^(high|medium|low)$", description="Filter by confidence level"),
|
||||
frequency_pattern: Optional[str] = Query(None, regex="^(daily|weekly|monthly|seasonal|occasional|event_triggered)$", description="Filter by frequency pattern"),
|
||||
time_period: Optional[str] = Query(None, regex="^(current|past)$", description="Filter by time period"),
|
||||
@@ -309,7 +309,7 @@ async def get_behavior_habits(
|
||||
Get user's behavioral habits from cache.
|
||||
|
||||
Args:
|
||||
user_id: Target user ID
|
||||
end_user_id: Target end user ID
|
||||
confidence_level: Filter by confidence level (high, medium, low)
|
||||
frequency_pattern: Filter by frequency pattern (daily, weekly, monthly, seasonal, occasional, event_triggered)
|
||||
time_period: Filter by time period (current, past)
|
||||
@@ -317,20 +317,20 @@ async def get_behavior_habits(
|
||||
Returns:
|
||||
List of behavioral habits from cache
|
||||
"""
|
||||
api_logger.info(f"Behavior habits requested for user: {user_id} (from cache)")
|
||||
api_logger.info(f"Behavior habits requested for user: {end_user_id} (from cache)")
|
||||
|
||||
try:
|
||||
# Validate inputs
|
||||
validate_user_id(user_id)
|
||||
validate_user_id(end_user_id)
|
||||
|
||||
# Create service with user-specific config
|
||||
service = ImplicitMemoryService(db=db, end_user_id=user_id)
|
||||
service = ImplicitMemoryService(db=db, end_user_id=end_user_id)
|
||||
|
||||
# Get cached profile
|
||||
cached_profile = await service.get_cached_profile(end_user_id=user_id, db=db)
|
||||
cached_profile = await service.get_cached_profile(end_user_id=end_user_id, db=db)
|
||||
|
||||
if cached_profile is None:
|
||||
api_logger.info(f"用户 {user_id} 的画像缓存不存在或已过期")
|
||||
api_logger.info(f"用户 {end_user_id} 的画像缓存不存在或已过期")
|
||||
return fail(
|
||||
BizCode.NOT_FOUND,
|
||||
"画像缓存不存在或已过期,请右上角刷新生成新画像",
|
||||
@@ -368,11 +368,11 @@ async def get_behavior_habits(
|
||||
|
||||
filtered_habits.append(habit)
|
||||
|
||||
api_logger.info(f"Retrieved {len(filtered_habits)} behavior habits for user: {user_id} (from cache)")
|
||||
api_logger.info(f"Retrieved {len(filtered_habits)} behavior habits for user: {end_user_id} (from cache)")
|
||||
return success(data=filtered_habits, msg="行为习惯获取成功(缓存)")
|
||||
|
||||
except Exception as e:
|
||||
return handle_implicit_memory_error(e, "行为习惯获取", user_id)
|
||||
return handle_implicit_memory_error(e, "行为习惯获取", end_user_id)
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -125,7 +125,7 @@ async def write_server(
|
||||
Write service endpoint - processes write operations synchronously
|
||||
|
||||
Args:
|
||||
user_input: Write request containing message and group_id
|
||||
user_input: Write request containing message and end_user_id
|
||||
|
||||
Returns:
|
||||
Response with write operation status
|
||||
@@ -160,19 +160,18 @@ async def write_server(
|
||||
api_logger.warning("workspace_id 为空,无法使用 rag 存储,将使用 neo4j 存储")
|
||||
storage_type = 'neo4j'
|
||||
|
||||
api_logger.info(f"Write service requested for group {user_input.group_id}, storage_type: {storage_type}, user_rag_memory_id: {user_rag_memory_id}")
|
||||
api_logger.info(f"Write service requested for group {user_input.end_user_id}, storage_type: {storage_type}, user_rag_memory_id: {user_rag_memory_id}")
|
||||
try:
|
||||
# 获取标准化的消息列表
|
||||
messages_list = memory_agent_service.get_messages_list(user_input)
|
||||
|
||||
result = await memory_agent_service.write_memory(
|
||||
user_input.group_id,
|
||||
messages_list, # 传递结构化消息列表
|
||||
user_input.end_user_id,
|
||||
messages_list,
|
||||
config_id,
|
||||
db,
|
||||
storage_type,
|
||||
user_rag_memory_id
|
||||
)
|
||||
|
||||
return success(data=result, msg="写入成功")
|
||||
except BaseException as e:
|
||||
# Handle ExceptionGroup from TaskGroup (Python 3.11+) or BaseExceptionGroup
|
||||
@@ -196,7 +195,7 @@ async def write_server_async(
|
||||
Async write service endpoint - enqueues write processing to Celery
|
||||
|
||||
Args:
|
||||
user_input: Write request containing message and group_id
|
||||
user_input: Write request containing message and end_user_id
|
||||
|
||||
Returns:
|
||||
Task ID for tracking async operation
|
||||
@@ -226,10 +225,10 @@ async def write_server_async(
|
||||
try:
|
||||
# 获取标准化的消息列表
|
||||
messages_list = memory_agent_service.get_messages_list(user_input)
|
||||
|
||||
|
||||
task = celery_app.send_task(
|
||||
"app.core.memory.agent.write_message",
|
||||
args=[user_input.group_id, messages_list, config_id, storage_type, user_rag_memory_id]
|
||||
args=[user_input.end_user_id, messages_list, config_id, storage_type, user_rag_memory_id]
|
||||
)
|
||||
api_logger.info(f"Write task queued: {task.id}")
|
||||
|
||||
@@ -255,7 +254,7 @@ async def read_server(
|
||||
- "2": Direct answer based on context
|
||||
|
||||
Args:
|
||||
user_input: Read request with message, history, search_switch, and group_id
|
||||
user_input: Read request with message, history, search_switch, and end_user_id
|
||||
|
||||
Returns:
|
||||
Response with query answer
|
||||
@@ -277,12 +276,13 @@ async def read_server(
|
||||
name="USER_RAG_MERORY",
|
||||
workspace_id=workspace_id
|
||||
)
|
||||
if knowledge: user_rag_memory_id = str(knowledge.id)
|
||||
if knowledge:
|
||||
user_rag_memory_id = str(knowledge.id)
|
||||
|
||||
api_logger.info(f"Read service: group={user_input.group_id}, storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}, workspace_id={workspace_id}")
|
||||
api_logger.info(f"Read service: group={user_input.end_user_id}, storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}, workspace_id={workspace_id}")
|
||||
try:
|
||||
result = await memory_agent_service.read_memory(
|
||||
user_input.group_id,
|
||||
user_input.end_user_id,
|
||||
user_input.message,
|
||||
user_input.history,
|
||||
user_input.search_switch,
|
||||
@@ -293,12 +293,12 @@ async def read_server(
|
||||
)
|
||||
if str(user_input.search_switch) == "2":
|
||||
retrieve_info = result['answer']
|
||||
history = await SessionService(store).get_history(user_input.group_id, user_input.group_id, user_input.group_id)
|
||||
history = await SessionService(store).get_history(user_input.end_user_id, user_input.end_user_id, user_input.end_user_id)
|
||||
query = user_input.message
|
||||
|
||||
|
||||
# 调用 memory_agent_service 的方法生成最终答案
|
||||
result['answer'] = await memory_agent_service.generate_summary_from_retrieve(
|
||||
group_id=user_input.group_id,
|
||||
end_user_id=user_input.end_user_id,
|
||||
retrieve_info=retrieve_info,
|
||||
history=history,
|
||||
query=query,
|
||||
@@ -404,7 +404,7 @@ async def read_server_async(
|
||||
try:
|
||||
task = celery_app.send_task(
|
||||
"app.core.memory.agent.read_message",
|
||||
args=[user_input.group_id, user_input.message, user_input.history, user_input.search_switch,
|
||||
args=[user_input.end_user_id, user_input.message, user_input.history, user_input.search_switch,
|
||||
config_id, storage_type, user_rag_memory_id]
|
||||
)
|
||||
api_logger.info(f"Read task queued: {task.id}")
|
||||
@@ -448,7 +448,7 @@ async def get_read_task_result(
|
||||
return success(
|
||||
data={
|
||||
"result": task_result.get("result"),
|
||||
"group_id": task_result.get("group_id"),
|
||||
"end_user_id": task_result.get("end_user_id"),
|
||||
"elapsed_time": task_result.get("elapsed_time"),
|
||||
"task_id": task_id
|
||||
},
|
||||
@@ -525,7 +525,7 @@ async def get_write_task_result(
|
||||
return success(
|
||||
data={
|
||||
"result": task_result.get("result"),
|
||||
"group_id": task_result.get("group_id"),
|
||||
"end_user_id": task_result.get("end_user_id"),
|
||||
"elapsed_time": task_result.get("elapsed_time"),
|
||||
"task_id": task_id
|
||||
},
|
||||
@@ -579,16 +579,16 @@ async def status_type(
|
||||
Determine the type of user message (read or write)
|
||||
|
||||
Args:
|
||||
user_input: Request containing user message and group_id
|
||||
user_input: Request containing user message and end_user_id
|
||||
|
||||
Returns:
|
||||
Type classification result
|
||||
"""
|
||||
api_logger.info(f"Status type check requested for group {user_input.group_id}")
|
||||
api_logger.info(f"Status type check requested for group {user_input.end_user_id}")
|
||||
try:
|
||||
# 获取标准化的消息列表
|
||||
messages_list = memory_agent_service.get_messages_list(user_input)
|
||||
|
||||
|
||||
# 将消息列表转换为字符串用于分类
|
||||
# 只取最后一条用户消息进行分类
|
||||
last_user_message = ""
|
||||
@@ -596,11 +596,11 @@ async def status_type(
|
||||
if msg.get('role') == 'user':
|
||||
last_user_message = msg.get('content', '')
|
||||
break
|
||||
|
||||
|
||||
if not last_user_message:
|
||||
# 如果没有用户消息,使用所有消息的内容
|
||||
last_user_message = " ".join([msg.get('content', '') for msg in messages_list])
|
||||
|
||||
|
||||
result = await memory_agent_service.classify_message_type(
|
||||
last_user_message,
|
||||
user_input.config_id,
|
||||
@@ -625,7 +625,7 @@ async def get_knowledge_type_stats_api(
|
||||
会对缺失类型补 0,返回字典形式。
|
||||
可选按状态过滤。
|
||||
- 知识库类型根据当前用户的 current_workspace_id 过滤
|
||||
- memory 是 Neo4j 中 Chunk 的数量,根据 end_user_id (group_id) 过滤
|
||||
- memory 是 Neo4j 中 Chunk 的数量,根据 end_user_id (end_user_id) 过滤
|
||||
- 如果用户没有当前工作空间或未提供 end_user_id,对应的统计返回 0
|
||||
"""
|
||||
api_logger.info(f"Knowledge type stats requested for workspace_id: {current_user.current_workspace_id}, end_user_id: {end_user_id}")
|
||||
@@ -698,7 +698,7 @@ async def get_user_profile_api(
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
获取工作空间下Popular Memory Tags,包含:
|
||||
获取用户详情,包含:
|
||||
- name: 用户名字(直接使用 end_user_id)
|
||||
- tags: 3个用户特征标签(从语句和实体中LLM总结)
|
||||
- hot_tags: 4个热门记忆标签
|
||||
|
||||
@@ -11,6 +11,7 @@
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
from sqlalchemy.orm import Session
|
||||
@@ -106,7 +107,7 @@ async def trigger_forgetting_cycle(
|
||||
# 调用服务层执行遗忘周期
|
||||
report = await forget_service.trigger_forgetting_cycle(
|
||||
db=db,
|
||||
group_id=end_user_id, # 服务层方法的参数名是 group_id
|
||||
end_user_id=end_user_id, # 服务层方法的参数名是 end_user_id
|
||||
max_merge_batch_size=payload.max_merge_batch_size,
|
||||
min_days_since_access=payload.min_days_since_access,
|
||||
config_id=config_id
|
||||
@@ -128,7 +129,7 @@ async def trigger_forgetting_cycle(
|
||||
|
||||
@router.get("/read_config", response_model=ApiResponse)
|
||||
async def read_forgetting_config(
|
||||
config_id: int,
|
||||
config_id: UUID,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
@@ -236,7 +237,7 @@ async def update_forgetting_config(
|
||||
|
||||
@router.get("/stats", response_model=ApiResponse)
|
||||
async def get_forgetting_stats(
|
||||
group_id: Optional[str] = None,
|
||||
end_user_id: Optional[str] = None,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
@@ -246,7 +247,7 @@ async def get_forgetting_stats(
|
||||
返回知识层节点统计、激活值分布等信息。
|
||||
|
||||
Args:
|
||||
group_id: 组ID(即 end_user_id,可选)
|
||||
end_user_id: 组ID(即 end_user_id,可选)
|
||||
current_user: 当前用户
|
||||
db: 数据库会话
|
||||
|
||||
@@ -260,20 +261,20 @@ async def get_forgetting_stats(
|
||||
api_logger.warning(f"用户 {current_user.username} 尝试获取遗忘引擎统计但未选择工作空间")
|
||||
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
|
||||
|
||||
# 如果提供了 group_id,通过它获取 config_id
|
||||
# 如果提供了 end_user_id,通过它获取 config_id
|
||||
config_id = None
|
||||
if group_id:
|
||||
if end_user_id:
|
||||
try:
|
||||
from app.services.memory_agent_service import get_end_user_connected_config
|
||||
|
||||
connected_config = get_end_user_connected_config(group_id, db)
|
||||
connected_config = get_end_user_connected_config(end_user_id, db)
|
||||
config_id = connected_config.get("memory_config_id")
|
||||
|
||||
if config_id is None:
|
||||
api_logger.warning(f"终端用户 {group_id} 未关联记忆配置")
|
||||
return fail(BizCode.INVALID_PARAMETER, f"终端用户 {group_id} 未关联记忆配置", "memory_config_id is None")
|
||||
api_logger.warning(f"终端用户 {end_user_id} 未关联记忆配置")
|
||||
return fail(BizCode.INVALID_PARAMETER, f"终端用户 {end_user_id} 未关联记忆配置", "memory_config_id is None")
|
||||
|
||||
api_logger.debug(f"通过 group_id={group_id} 获取到 config_id={config_id}")
|
||||
api_logger.debug(f"通过 end_user_id={end_user_id} 获取到 config_id={config_id}")
|
||||
except ValueError as e:
|
||||
api_logger.warning(f"获取终端用户配置失败: {str(e)}")
|
||||
return fail(BizCode.INVALID_PARAMETER, str(e), "ValueError")
|
||||
@@ -283,14 +284,14 @@ async def get_forgetting_stats(
|
||||
|
||||
api_logger.info(
|
||||
f"用户 {current_user.username} 在工作空间 {workspace_id} 请求获取遗忘引擎统计: "
|
||||
f"group_id={group_id}, config_id={config_id}"
|
||||
f"end_user_id={end_user_id}, config_id={config_id}"
|
||||
)
|
||||
|
||||
try:
|
||||
# 调用服务层获取统计信息
|
||||
stats = await forget_service.get_forgetting_stats(
|
||||
db=db,
|
||||
group_id=group_id,
|
||||
end_user_id=end_user_id,
|
||||
config_id=config_id
|
||||
)
|
||||
|
||||
|
||||
@@ -27,27 +27,27 @@ router = APIRouter(
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{group_id}/count", response_model=ApiResponse)
|
||||
@router.get("/{end_user_id}/count", response_model=ApiResponse)
|
||||
def get_memory_count(
|
||||
group_id: uuid.UUID,
|
||||
end_user_id: uuid.UUID,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""Retrieve perceptual memory statistics for a user group.
|
||||
|
||||
Args:
|
||||
group_id: ID of the user group (usually end_user_id in this context)
|
||||
end_user_id: ID of the user group (usually end_user_id in this context)
|
||||
current_user: Current authenticated user
|
||||
db: Database session
|
||||
|
||||
Returns:
|
||||
ApiResponse: Response containing memory count statistics
|
||||
"""
|
||||
api_logger.info(f"Fetching perceptual memory statistics: user={current_user.username}, group_id={group_id}")
|
||||
api_logger.info(f"Fetching perceptual memory statistics: user={current_user.username}, end_user_id={end_user_id}")
|
||||
|
||||
try:
|
||||
service = MemoryPerceptualService(db)
|
||||
count_stats = service.get_memory_count(group_id)
|
||||
count_stats = service.get_memory_count(end_user_id)
|
||||
|
||||
api_logger.info(f"Memory statistics fetched successfully: total={count_stats.get('total', 0)}")
|
||||
|
||||
@@ -57,37 +57,37 @@ def get_memory_count(
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
api_logger.error(f"Failed to fetch memory statistics: group_id={group_id}, error={str(e)}")
|
||||
api_logger.error(f"Failed to fetch memory statistics: end_user_id={end_user_id}, error={str(e)}")
|
||||
return fail(
|
||||
code=BizCode.INTERNAL_ERROR,
|
||||
msg="Failed to fetch memory statistics",
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{group_id}/last_visual", response_model=ApiResponse)
|
||||
@router.get("/{end_user_id}/last_visual", response_model=ApiResponse)
|
||||
def get_last_visual_memory(
|
||||
group_id: uuid.UUID,
|
||||
end_user_id: uuid.UUID,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""Retrieve the most recent VISION-type memory for a user.
|
||||
|
||||
Args:
|
||||
group_id: ID of the user group
|
||||
end_user_id: ID of the user group
|
||||
current_user: Current authenticated user
|
||||
db: Database session
|
||||
|
||||
Returns:
|
||||
ApiResponse: Metadata of the latest visual memory
|
||||
"""
|
||||
api_logger.info(f"Fetching latest visual memory: user={current_user.username}, group_id={group_id}")
|
||||
api_logger.info(f"Fetching latest visual memory: user={current_user.username}, end_user_id={end_user_id}")
|
||||
|
||||
try:
|
||||
service = MemoryPerceptualService(db)
|
||||
visual_memory = service.get_latest_visual_memory(group_id)
|
||||
visual_memory = service.get_latest_visual_memory(end_user_id)
|
||||
|
||||
if visual_memory is None:
|
||||
api_logger.info(f"No visual memory found: group_id={group_id}")
|
||||
api_logger.info(f"No visual memory found: end_user_id={end_user_id}")
|
||||
return success(
|
||||
data=None,
|
||||
msg="No visual memory available"
|
||||
@@ -101,37 +101,37 @@ def get_last_visual_memory(
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
api_logger.error(f"Failed to fetch latest visual memory: group_id={group_id}, error={str(e)}")
|
||||
api_logger.error(f"Failed to fetch latest visual memory: end_user_id={end_user_id}, error={str(e)}")
|
||||
return fail(
|
||||
code=BizCode.INTERNAL_ERROR,
|
||||
msg="Failed to fetch latest visual memory",
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{group_id}/last_listen", response_model=ApiResponse)
|
||||
@router.get("/{end_user_id}/last_listen", response_model=ApiResponse)
|
||||
def get_last_memory_listen(
|
||||
group_id: uuid.UUID,
|
||||
end_user_id: uuid.UUID,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""Retrieve the most recent AUDIO-type memory for a user.
|
||||
|
||||
Args:
|
||||
group_id: ID of the user group
|
||||
end_user_id: ID of the user group
|
||||
current_user: Current authenticated user
|
||||
db: Database session
|
||||
|
||||
Returns:
|
||||
ApiResponse: Metadata of the latest audio memory
|
||||
"""
|
||||
api_logger.info(f"Fetching latest audio memory: user={current_user.username}, group_id={group_id}")
|
||||
api_logger.info(f"Fetching latest audio memory: user={current_user.username}, end_user_id={end_user_id}")
|
||||
|
||||
try:
|
||||
service = MemoryPerceptualService(db)
|
||||
audio_memory = service.get_latest_audio_memory(group_id)
|
||||
audio_memory = service.get_latest_audio_memory(end_user_id)
|
||||
|
||||
if audio_memory is None:
|
||||
api_logger.info(f"No audio memory found: group_id={group_id}")
|
||||
api_logger.info(f"No audio memory found: end_user_id={end_user_id}")
|
||||
return success(
|
||||
data=None,
|
||||
msg="No audio memory available"
|
||||
@@ -145,38 +145,38 @@ def get_last_memory_listen(
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
api_logger.error(f"Failed to fetch latest audio memory: group_id={group_id}, error={str(e)}")
|
||||
api_logger.error(f"Failed to fetch latest audio memory: end_user_id={end_user_id}, error={str(e)}")
|
||||
return fail(
|
||||
code=BizCode.INTERNAL_ERROR,
|
||||
msg="Failed to fetch latest audio memory",
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{group_id}/last_text", response_model=ApiResponse)
|
||||
@router.get("/{end_user_id}/last_text", response_model=ApiResponse)
|
||||
def get_last_text_memory(
|
||||
group_id: uuid.UUID,
|
||||
end_user_id: uuid.UUID,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""Retrieve the most recent TEXT-type memory for a user.
|
||||
|
||||
Args:
|
||||
group_id: ID of the user group
|
||||
end_user_id: ID of the user group
|
||||
current_user: Current authenticated user
|
||||
db: Database session
|
||||
|
||||
Returns:
|
||||
ApiResponse: Metadata of the latest text memory
|
||||
"""
|
||||
api_logger.info(f"Fetching latest text memory: user={current_user.username}, group_id={group_id}")
|
||||
api_logger.info(f"Fetching latest text memory: user={current_user.username}, end_user_id={end_user_id}")
|
||||
|
||||
try:
|
||||
# 调用服务层获取最近的文本记忆
|
||||
service = MemoryPerceptualService(db)
|
||||
text_memory = service.get_latest_text_memory(group_id)
|
||||
text_memory = service.get_latest_text_memory(end_user_id)
|
||||
|
||||
if text_memory is None:
|
||||
api_logger.info(f"No text memory found: group_id={group_id}")
|
||||
api_logger.info(f"No text memory found: end_user_id={end_user_id}")
|
||||
return success(
|
||||
data=None,
|
||||
msg="No text memory available"
|
||||
@@ -190,16 +190,16 @@ def get_last_text_memory(
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
api_logger.error(f"Failed to fetch latest text memory: group_id={group_id}, error={str(e)}")
|
||||
api_logger.error(f"Failed to fetch latest text memory: end_user_id={end_user_id}, error={str(e)}")
|
||||
return fail(
|
||||
code=BizCode.INTERNAL_ERROR,
|
||||
msg="Failed to fetch latest text memory",
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{group_id}/timeline", response_model=ApiResponse)
|
||||
@router.get("/{end_user_id}/timeline", response_model=ApiResponse)
|
||||
def get_memory_time_line(
|
||||
group_id: uuid.UUID,
|
||||
end_user_id: uuid.UUID,
|
||||
perceptual_type: Optional[PerceptualType] = Query(None, description="感知类型过滤"),
|
||||
page: int = Query(1, ge=1, description="页码"),
|
||||
page_size: int = Query(10, ge=1, le=100, description="每页大小"),
|
||||
@@ -209,7 +209,7 @@ def get_memory_time_line(
|
||||
"""Retrieve a timeline of perceptual memories for a user group.
|
||||
|
||||
Args:
|
||||
group_id: ID of the user group
|
||||
end_user_id: ID of the user group
|
||||
perceptual_type: Optional filter for perceptual type
|
||||
page: Page number for pagination
|
||||
page_size: Number of items per page
|
||||
@@ -221,7 +221,7 @@ def get_memory_time_line(
|
||||
"""
|
||||
api_logger.info(
|
||||
f"Fetching perceptual memory timeline: user={current_user.username}, "
|
||||
f"group_id={group_id}, type={perceptual_type}, page={page}"
|
||||
f"end_user_id={end_user_id}, type={perceptual_type}, page={page}"
|
||||
)
|
||||
|
||||
try:
|
||||
@@ -232,7 +232,7 @@ def get_memory_time_line(
|
||||
)
|
||||
|
||||
service = MemoryPerceptualService(db)
|
||||
timeline_data = service.get_time_line(group_id, query)
|
||||
timeline_data = service.get_time_line(end_user_id, query)
|
||||
|
||||
api_logger.info(
|
||||
f"Perceptual memory timeline retrieved successfully: total={timeline_data.total}, "
|
||||
@@ -246,7 +246,7 @@ def get_memory_time_line(
|
||||
|
||||
except Exception as e:
|
||||
api_logger.error(
|
||||
f"Failed to fetch perceptual memory timeline: group_id={group_id}, "
|
||||
f"Failed to fetch perceptual memory timeline: end_user_id={end_user_id}, "
|
||||
f"error={str(e)}"
|
||||
)
|
||||
return fail(
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import asyncio
|
||||
import time
|
||||
import uuid
|
||||
from uuid import UUID
|
||||
|
||||
from app.core.logging_config import get_api_logger
|
||||
from app.core.memory.storage_services.reflection_engine.self_reflexion import (
|
||||
@@ -11,7 +12,7 @@ from app.core.response_utils import success
|
||||
from app.db import get_db
|
||||
from app.dependencies import get_current_user
|
||||
from app.models.user_model import User
|
||||
from app.repositories.data_config_repository import DataConfigRepository
|
||||
from app.repositories.memory_config_repository import MemoryConfigRepository
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.schemas.memory_reflection_schemas import Memory_Reflection
|
||||
from app.services.memory_reflection_service import (
|
||||
@@ -50,7 +51,7 @@ async def save_reflection_config(
|
||||
|
||||
api_logger.info(f"用户 {current_user.username} 保存反思配置,config_id: {config_id}")
|
||||
|
||||
data_config = DataConfigRepository.update_reflection_config(
|
||||
memory_config = MemoryConfigRepository.update_reflection_config(
|
||||
db,
|
||||
config_id=config_id,
|
||||
enable_self_reflexion=request.reflection_enabled,
|
||||
@@ -63,17 +64,17 @@ async def save_reflection_config(
|
||||
)
|
||||
|
||||
db.commit()
|
||||
db.refresh(data_config)
|
||||
db.refresh(memory_config)
|
||||
|
||||
reflection_result={
|
||||
"config_id": data_config.config_id,
|
||||
"enable_self_reflexion": data_config.enable_self_reflexion,
|
||||
"iteration_period": data_config.iteration_period,
|
||||
"reflexion_range": data_config.reflexion_range,
|
||||
"baseline": data_config.baseline,
|
||||
"reflection_model_id": data_config.reflection_model_id,
|
||||
"memory_verify": data_config.memory_verify,
|
||||
"quality_assessment": data_config.quality_assessment}
|
||||
"config_id": memory_config.config_id,
|
||||
"enable_self_reflexion": memory_config.enable_self_reflexion,
|
||||
"iteration_period": memory_config.iteration_period,
|
||||
"reflexion_range": memory_config.reflexion_range,
|
||||
"baseline": memory_config.baseline,
|
||||
"reflection_model_id": memory_config.reflection_model_id,
|
||||
"memory_verify": memory_config.memory_verify,
|
||||
"quality_assessment": memory_config.quality_assessment}
|
||||
|
||||
return success(data=reflection_result, msg="反思配置成功")
|
||||
|
||||
@@ -111,14 +112,14 @@ async def start_workspace_reflection(
|
||||
reflection_results = []
|
||||
|
||||
for data in result['apps_detailed_info']:
|
||||
if data['data_configs'] == []:
|
||||
if data['memory_configs'] == []:
|
||||
continue
|
||||
|
||||
releases = data['releases']
|
||||
data_configs = data['data_configs']
|
||||
memory_configs = data['memory_configs']
|
||||
end_users = data['end_users']
|
||||
|
||||
for base, config, user in zip(releases, data_configs, end_users):
|
||||
for base, config, user in zip(releases, memory_configs, end_users):
|
||||
# 安全地转换为整数,处理空字符串和None的情况
|
||||
print(base['config'])
|
||||
try:
|
||||
@@ -156,14 +157,14 @@ async def start_workspace_reflection(
|
||||
|
||||
@router.get("/reflection/configs")
|
||||
async def start_reflection_configs(
|
||||
config_id: int,
|
||||
config_id: uuid.UUID,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
"""通过config_id查询data_config表中的反思配置信息"""
|
||||
"""通过config_id查询memory_config表中的反思配置信息"""
|
||||
try:
|
||||
api_logger.info(f"用户 {current_user.username} 查询反思配置,config_id: {config_id}")
|
||||
result = DataConfigRepository.query_reflection_config_by_id(db, config_id)
|
||||
result = MemoryConfigRepository.query_reflection_config_by_id(db, config_id)
|
||||
# 构建返回数据
|
||||
reflection_config = {
|
||||
"config_id": result.config_id,
|
||||
@@ -191,7 +192,7 @@ async def start_reflection_configs(
|
||||
|
||||
@router.get("/reflection/run")
|
||||
async def reflection_run(
|
||||
config_id: int,
|
||||
config_id: UUID,
|
||||
language_type: str = Header(default="zh", alias="X-Language-Type"),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
@@ -200,8 +201,8 @@ async def reflection_run(
|
||||
|
||||
api_logger.info(f"用户 {current_user.username} 查询反思配置,config_id: {config_id}")
|
||||
|
||||
# 使用DataConfigRepository查询反思配置
|
||||
result = DataConfigRepository.query_reflection_config_by_id(db, config_id)
|
||||
# 使用MemoryConfigRepository查询反思配置
|
||||
result = MemoryConfigRepository.query_reflection_config_by_id(db, config_id)
|
||||
if not result:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import os
|
||||
from typing import Optional
|
||||
from uuid import UUID
|
||||
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.logging_config import get_api_logger
|
||||
@@ -160,7 +161,7 @@ def create_config(
|
||||
|
||||
@router.delete("/delete_config", response_model=ApiResponse) # 删除数据库中的内容(按配置名称)
|
||||
def delete_config(
|
||||
config_id: str,
|
||||
config_id: UUID,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
@@ -232,7 +233,7 @@ def update_config_extracted(
|
||||
|
||||
@router.get("/read_config_extracted", response_model=ApiResponse) # 通过查询参数读取某条配置(固定路径) 没有意义的话就删除
|
||||
def read_config_extracted(
|
||||
config_id: str,
|
||||
config_id: UUID,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
|
||||
@@ -20,18 +20,18 @@ router = APIRouter(
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{group_id}/count", response_model=ApiResponse)
|
||||
@router.get("/{end_user_id}/count", response_model=ApiResponse)
|
||||
def get_memory_count(
|
||||
group_id: uuid.UUID,
|
||||
end_user_id: uuid.UUID,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
pass
|
||||
|
||||
|
||||
@router.get("/{group_id}/conversations", response_model=ApiResponse)
|
||||
@router.get("/{end_user_id}/conversations", response_model=ApiResponse)
|
||||
def get_conversations(
|
||||
group_id: uuid.UUID,
|
||||
end_user_id: uuid.UUID,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
@@ -39,7 +39,7 @@ def get_conversations(
|
||||
Retrieve all conversations for the current user in a specific group.
|
||||
|
||||
Args:
|
||||
group_id (UUID): The group identifier.
|
||||
end_user_id (UUID): The group identifier.
|
||||
current_user (User, optional): The authenticated user.
|
||||
db (Session, optional): SQLAlchemy session.
|
||||
|
||||
@@ -53,7 +53,7 @@ def get_conversations(
|
||||
"""
|
||||
conversation_service = ConversationService(db)
|
||||
conversations = conversation_service.get_user_conversations(
|
||||
group_id
|
||||
end_user_id
|
||||
)
|
||||
return success(data=[
|
||||
{
|
||||
@@ -63,7 +63,7 @@ def get_conversations(
|
||||
], msg="get conversations success")
|
||||
|
||||
|
||||
@router.get("/{group_id}/messages", response_model=ApiResponse)
|
||||
@router.get("/{end_user_id}/messages", response_model=ApiResponse)
|
||||
def get_messages(
|
||||
conversation_id: uuid.UUID,
|
||||
current_user: User = Depends(get_current_user),
|
||||
@@ -100,7 +100,7 @@ def get_messages(
|
||||
return success(data=messages, msg="get conversation history success")
|
||||
|
||||
|
||||
@router.get("/{group_id}/detail", response_model=ApiResponse)
|
||||
@router.get("/{end_user_id}/detail", response_model=ApiResponse)
|
||||
async def get_conversation_detail(
|
||||
conversation_id: uuid.UUID,
|
||||
current_user: User = Depends(get_current_user),
|
||||
|
||||
@@ -39,7 +39,7 @@ async def write_memory_api_service(
|
||||
|
||||
Stores memory content for the specified end user using the Memory API Service.
|
||||
"""
|
||||
logger.info(f"Memory write request - end_user_id: {payload.end_user_id}")
|
||||
logger.info(f"Memory write request - end_user_id: {payload.end_user_id}, tenant_id: {api_key_auth.tenant_id}")
|
||||
|
||||
memory_api_service = MemoryAPIService(db)
|
||||
|
||||
|
||||
@@ -135,27 +135,27 @@ async def generate_cache_api(
|
||||
api_logger.warning(f"用户 {current_user.username} 尝试生成缓存但未选择工作空间")
|
||||
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
|
||||
|
||||
group_id = request.end_user_id
|
||||
end_user_id = request.end_user_id
|
||||
|
||||
api_logger.info(
|
||||
f"缓存生成请求: user={current_user.username}, workspace={workspace_id}, "
|
||||
f"end_user_id={group_id if group_id else '全部用户'}"
|
||||
f"end_user_id={end_user_id if end_user_id else '全部用户'}"
|
||||
)
|
||||
|
||||
try:
|
||||
if group_id:
|
||||
if end_user_id:
|
||||
# 为单个用户生成
|
||||
api_logger.info(f"开始为单个用户生成缓存: end_user_id={group_id}")
|
||||
api_logger.info(f"开始为单个用户生成缓存: end_user_id={end_user_id}")
|
||||
|
||||
# 生成记忆洞察
|
||||
insight_result = await user_memory_service.generate_and_cache_insight(db, group_id, workspace_id)
|
||||
insight_result = await user_memory_service.generate_and_cache_insight(db, end_user_id, workspace_id)
|
||||
|
||||
# 生成用户摘要
|
||||
summary_result = await user_memory_service.generate_and_cache_summary(db, group_id, workspace_id)
|
||||
summary_result = await user_memory_service.generate_and_cache_summary(db, end_user_id, workspace_id)
|
||||
|
||||
# 构建响应
|
||||
result = {
|
||||
"end_user_id": group_id,
|
||||
"end_user_id": end_user_id,
|
||||
"insight_success": insight_result["success"],
|
||||
"summary_success": summary_result["success"],
|
||||
"errors": []
|
||||
@@ -175,9 +175,9 @@ async def generate_cache_api(
|
||||
|
||||
# 记录结果
|
||||
if result["insight_success"] and result["summary_success"]:
|
||||
api_logger.info(f"成功为用户 {group_id} 生成缓存")
|
||||
api_logger.info(f"成功为用户 {end_user_id} 生成缓存")
|
||||
else:
|
||||
api_logger.warning(f"用户 {group_id} 的缓存生成部分失败: {result['errors']}")
|
||||
api_logger.warning(f"用户 {end_user_id} 的缓存生成部分失败: {result['errors']}")
|
||||
|
||||
return success(data=result, msg="生成完成")
|
||||
|
||||
|
||||
@@ -155,13 +155,13 @@ class LangChainAgent:
|
||||
# userid=end_user_end,
|
||||
# messages=messages,
|
||||
# apply_id=end_user_end,
|
||||
# group_id=end_user_end,
|
||||
# end_user_id=end_user_end,
|
||||
# aimessages=aimessages
|
||||
# )
|
||||
# store.delete_duplicate_sessions()
|
||||
# # logger.info(f'Redis_Agent:{end_user_end};{session_id}')
|
||||
# return session_id
|
||||
|
||||
|
||||
# TODO 乐力齐 - 累积多组对话批量写入功能已禁用
|
||||
# async def term_memory_redis_read(self,end_user_end):
|
||||
# end_user_end = f"Term_{end_user_end}"
|
||||
@@ -179,7 +179,7 @@ class LangChainAgent:
|
||||
async def write(self, storage_type, end_user_id, user_message, ai_message, user_rag_memory_id, actual_end_user_id, actual_config_id):
|
||||
"""
|
||||
写入记忆(支持结构化消息)
|
||||
|
||||
|
||||
Args:
|
||||
storage_type: 存储类型 (neo4j/rag)
|
||||
end_user_id: 终端用户ID
|
||||
@@ -188,7 +188,7 @@ class LangChainAgent:
|
||||
user_rag_memory_id: RAG 记忆ID
|
||||
actual_end_user_id: 实际用户ID
|
||||
actual_config_id: 配置ID
|
||||
|
||||
|
||||
逻辑说明:
|
||||
- RAG 模式:组合 user_message 和 ai_message 为字符串格式,保持原有逻辑不变
|
||||
- Neo4j 模式:使用结构化消息列表
|
||||
@@ -204,20 +204,20 @@ class LangChainAgent:
|
||||
else:
|
||||
# Neo4j 模式:使用结构化消息列表
|
||||
structured_messages = []
|
||||
|
||||
|
||||
# 始终添加用户消息(如果不为空)
|
||||
if user_message:
|
||||
structured_messages.append({"role": "user", "content": user_message})
|
||||
|
||||
|
||||
# 只有当 AI 回复不为空时才添加 assistant 消息
|
||||
if ai_message:
|
||||
structured_messages.append({"role": "assistant", "content": ai_message})
|
||||
|
||||
|
||||
# 如果没有消息,直接返回
|
||||
if not structured_messages:
|
||||
logger.warning(f"No messages to write for user {actual_end_user_id}")
|
||||
return
|
||||
|
||||
|
||||
# 调用 Celery 任务,传递结构化消息列表
|
||||
# 数据流:
|
||||
# 1. structured_messages 传递给 write_message_task
|
||||
@@ -228,7 +228,7 @@ class LangChainAgent:
|
||||
# 6. 每个 Chunk 保存到 Neo4j,包含 speaker 字段
|
||||
logger.info(f"[WRITE] Submitting Celery task - user={actual_end_user_id}, messages={len(structured_messages)}, config={actual_config_id}")
|
||||
write_id = write_message_task.delay(
|
||||
actual_end_user_id, # group_id: 用户ID
|
||||
actual_end_user_id, # end_user_id: 用户ID
|
||||
structured_messages, # message: 结构化消息列表 [{"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]
|
||||
actual_config_id, # config_id: 配置ID
|
||||
storage_type, # storage_type: "neo4j"
|
||||
|
||||
@@ -35,10 +35,10 @@ async def Split_The_Problem(state: ReadState) -> ReadState:
|
||||
"""问题分解节点"""
|
||||
# 从状态中获取数据
|
||||
content = state.get('data', '')
|
||||
group_id = state.get('group_id', '')
|
||||
end_user_id = state.get('end_user_id', '')
|
||||
memory_config = state.get('memory_config', None)
|
||||
|
||||
history = await SessionService(store).get_history(group_id, group_id, group_id)
|
||||
history = await SessionService(store).get_history(end_user_id, end_user_id, end_user_id)
|
||||
|
||||
# 生成 JSON schema 以指导 LLM 输出正确格式
|
||||
json_schema = ProblemExtensionResponse.model_json_schema()
|
||||
@@ -140,7 +140,7 @@ async def Problem_Extension(state: ReadState) -> ReadState:
|
||||
start = time.time()
|
||||
content = state.get('data', '')
|
||||
data = state.get('spit_data', '')['context']
|
||||
group_id = state.get('group_id', '')
|
||||
end_user_id = state.get('end_user_id', '')
|
||||
storage_type = state.get('storage_type', '')
|
||||
user_rag_memory_id = state.get('user_rag_memory_id', '')
|
||||
memory_config = state.get('memory_config', None)
|
||||
@@ -156,7 +156,7 @@ async def Problem_Extension(state: ReadState) -> ReadState:
|
||||
databasets = {}
|
||||
data = []
|
||||
|
||||
history = await SessionService(store).get_history(group_id, group_id, group_id)
|
||||
history = await SessionService(store).get_history(end_user_id, end_user_id, end_user_id)
|
||||
|
||||
# 生成 JSON schema 以指导 LLM 输出正确格式
|
||||
json_schema = ProblemExtensionResponse.model_json_schema()
|
||||
|
||||
@@ -52,9 +52,9 @@ async def rag_config(state):
|
||||
return kb_config
|
||||
async def rag_knowledge(state,question):
|
||||
kb_config = await rag_config(state)
|
||||
group_id = state.get('group_id', '')
|
||||
end_user_id = state.get('end_user_id', '')
|
||||
user_rag_memory_id=state.get("user_rag_memory_id",'')
|
||||
retrieve_chunks_result = knowledge_retrieval(question, kb_config, [str(group_id)])
|
||||
retrieve_chunks_result = knowledge_retrieval(question, kb_config, [str(end_user_id)])
|
||||
try:
|
||||
retrieval_knowledge = [i.page_content for i in retrieve_chunks_result]
|
||||
clean_content = '\n\n'.join(retrieval_knowledge)
|
||||
@@ -159,7 +159,7 @@ async def retrieve_nodes(state: ReadState) -> ReadState:
|
||||
problem_extension=state.get('problem_extension', '')['context']
|
||||
storage_type=state.get('storage_type', '')
|
||||
user_rag_memory_id=state.get('user_rag_memory_id', '')
|
||||
group_id=state.get('group_id', '')
|
||||
end_user_id=state.get('end_user_id', '')
|
||||
memory_config = state.get('memory_config', None)
|
||||
original=state.get('data', '')
|
||||
problem_list=[]
|
||||
@@ -172,7 +172,7 @@ async def retrieve_nodes(state: ReadState) -> ReadState:
|
||||
try:
|
||||
# Prepare search parameters based on storage type
|
||||
search_params = {
|
||||
"group_id": group_id,
|
||||
"end_user_id": end_user_id,
|
||||
"question": question,
|
||||
"return_raw_results": True
|
||||
}
|
||||
@@ -263,13 +263,13 @@ async def retrieve_nodes(state: ReadState) -> ReadState:
|
||||
|
||||
|
||||
async def retrieve(state: ReadState) -> ReadState:
|
||||
# 从state中获取group_id
|
||||
# 从state中获取end_user_id
|
||||
import time
|
||||
start=time.time()
|
||||
problem_extension = state.get('problem_extension', '')['context']
|
||||
storage_type = state.get('storage_type', '')
|
||||
user_rag_memory_id = state.get('user_rag_memory_id', '')
|
||||
group_id = state.get('group_id', '')
|
||||
end_user_id = state.get('end_user_id', '')
|
||||
memory_config = state.get('memory_config', None)
|
||||
original = state.get('data', '')
|
||||
problem_list = []
|
||||
@@ -295,13 +295,13 @@ async def retrieve(state: ReadState) -> ReadState:
|
||||
temperature=0.2,
|
||||
)
|
||||
|
||||
time_retrieval_tool = create_time_retrieval_tool(group_id)
|
||||
search_params = { "group_id": group_id, "return_raw_results": True }
|
||||
time_retrieval_tool = create_time_retrieval_tool(end_user_id)
|
||||
search_params = { "end_user_id": end_user_id, "return_raw_results": True }
|
||||
hybrid_retrieval=create_hybrid_retrieval_tool_sync(memory_config, **search_params)
|
||||
agent = create_agent(
|
||||
llm,
|
||||
tools=[time_retrieval_tool,hybrid_retrieval],
|
||||
system_prompt=f"我是检索专家,可以根据适合的工具进行检索。当前使用的group_id是: {group_id}"
|
||||
system_prompt=f"我是检索专家,可以根据适合的工具进行检索。当前使用的end_user_id是: {end_user_id}"
|
||||
)
|
||||
|
||||
# 创建异步任务处理单个问题
|
||||
|
||||
@@ -34,8 +34,8 @@ class SummaryNodeService(LLMServiceMixin):
|
||||
summary_service = SummaryNodeService()
|
||||
|
||||
async def summary_history(state: ReadState) -> ReadState:
|
||||
group_id = state.get("group_id", '')
|
||||
history = await SessionService(store).get_history(group_id, group_id, group_id)
|
||||
end_user_id = state.get("end_user_id", '')
|
||||
history = await SessionService(store).get_history(end_user_id, end_user_id, end_user_id)
|
||||
return history
|
||||
|
||||
async def summary_llm(state: ReadState, history, retrieve_info, template_name, operation_name, response_model,search_mode) -> str:
|
||||
@@ -122,12 +122,12 @@ async def summary_llm(state: ReadState, history, retrieve_info, template_name, o
|
||||
|
||||
async def summary_redis_save(state: ReadState,aimessages) -> ReadState:
|
||||
data = state.get("data", '')
|
||||
group_id = state.get("group_id", '')
|
||||
end_user_id = state.get("end_user_id", '')
|
||||
await SessionService(store).save_session(
|
||||
user_id=group_id,
|
||||
user_id=end_user_id,
|
||||
query=data,
|
||||
apply_id=group_id,
|
||||
group_id=group_id,
|
||||
apply_id=end_user_id,
|
||||
end_user_id=end_user_id,
|
||||
ai_response=aimessages
|
||||
)
|
||||
await SessionService(store).cleanup_duplicates()
|
||||
@@ -175,11 +175,11 @@ async def Input_Summary(state: ReadState) -> ReadState:
|
||||
memory_config = state.get('memory_config', None)
|
||||
user_rag_memory_id=state.get("user_rag_memory_id",'')
|
||||
data=state.get("data", '')
|
||||
group_id=state.get("group_id", '')
|
||||
end_user_id=state.get("end_user_id", '')
|
||||
logger.info(f"Input_Summary: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}")
|
||||
history = await summary_history( state)
|
||||
search_params = {
|
||||
"group_id": group_id,
|
||||
"end_user_id": end_user_id,
|
||||
"question": data,
|
||||
"return_raw_results": True,
|
||||
"include": ["summaries"] # Only search summary nodes for faster performance
|
||||
|
||||
@@ -62,12 +62,12 @@ async def Verify(state: ReadState):
|
||||
logger.info("=== Verify 节点开始执行 ===")
|
||||
try:
|
||||
content = state.get('data', '')
|
||||
group_id = state.get('group_id', '')
|
||||
end_user_id = state.get('end_user_id', '')
|
||||
memory_config = state.get('memory_config', None)
|
||||
|
||||
logger.info(f"Verify: content={content[:50] if content else 'empty'}..., group_id={group_id}")
|
||||
logger.info(f"Verify: content={content[:50] if content else 'empty'}..., end_user_id={end_user_id}")
|
||||
|
||||
history = await SessionService(store).get_history(group_id, group_id, group_id)
|
||||
history = await SessionService(store).get_history(end_user_id, end_user_id, end_user_id)
|
||||
logger.info(f"Verify: 获取历史记录完成,history length={len(history)}")
|
||||
|
||||
retrieve = state.get("retrieve", {})
|
||||
|
||||
@@ -1,23 +1,24 @@
|
||||
|
||||
from app.core.memory.agent.utils.llm_tools import WriteState
|
||||
from app.core.memory.agent.utils.llm_tools import WriteState
|
||||
from app.core.memory.agent.utils.write_tools import write
|
||||
from app.core.logging_config import get_agent_logger
|
||||
|
||||
logger = get_agent_logger(__name__)
|
||||
|
||||
|
||||
async def write_node(state: WriteState) -> WriteState:
|
||||
"""
|
||||
Write data to the database/file system.
|
||||
|
||||
Args:
|
||||
state: WriteState containing messages, group_id, and memory_config
|
||||
state: WriteState containing messages, end_user_id, and memory_config
|
||||
|
||||
Returns:
|
||||
dict: Contains 'write_result' with status and data fields
|
||||
"""
|
||||
messages = state.get('messages', [])
|
||||
group_id = state.get('group_id', '')
|
||||
end_user_id = state.get('end_user_id', '')
|
||||
memory_config = state.get('memory_config', '')
|
||||
|
||||
|
||||
# Convert LangChain messages to structured format expected by write()
|
||||
structured_messages = []
|
||||
for msg in messages:
|
||||
@@ -28,13 +29,11 @@ async def write_node(state: WriteState) -> WriteState:
|
||||
"role": role,
|
||||
"content": msg.content # content is now guaranteed to be a string
|
||||
})
|
||||
|
||||
|
||||
try:
|
||||
result = await write(
|
||||
messages=structured_messages,
|
||||
user_id=group_id,
|
||||
apply_id=group_id,
|
||||
group_id=group_id,
|
||||
end_user_id=end_user_id,
|
||||
memory_config=memory_config,
|
||||
)
|
||||
logger.info(f"Write completed successfully! Config: {memory_config.config_name}")
|
||||
|
||||
@@ -79,7 +79,7 @@ async def make_read_graph():
|
||||
async def main():
|
||||
"""主函数 - 运行工作流"""
|
||||
message = "昨天有什么好看的电影"
|
||||
group_id = '88a459f5_text09' # 组ID
|
||||
end_user_id = '88a459f5_text09' # 组ID
|
||||
storage_type = 'neo4j' # 存储类型
|
||||
search_switch = '1' # 搜索开关
|
||||
user_rag_memory_id = 'wwwwwwww' # 用户RAG记忆ID
|
||||
@@ -95,9 +95,9 @@ async def main():
|
||||
start=time.time()
|
||||
try:
|
||||
async with make_read_graph() as graph:
|
||||
config = {"configurable": {"thread_id": group_id}}
|
||||
config = {"configurable": {"thread_id": end_user_id}}
|
||||
# 初始状态 - 包含所有必要字段
|
||||
initial_state = {"messages": [HumanMessage(content=message)] ,"search_switch":search_switch,"group_id":group_id
|
||||
initial_state = {"messages": [HumanMessage(content=message)] ,"search_switch":search_switch,"end_user_id":end_user_id
|
||||
,"storage_type":storage_type,"user_rag_memory_id":user_rag_memory_id,"memory_config":memory_config}
|
||||
# 获取节点更新信息
|
||||
_intermediate_outputs = []
|
||||
|
||||
@@ -48,11 +48,11 @@ def extract_tool_message_content(response):
|
||||
class TimeRetrievalInput(BaseModel):
|
||||
"""时间检索工具的输入模式"""
|
||||
context: str = Field(description="用户输入的查询内容")
|
||||
group_id: str = Field(default="88a459f5_text09", description="组ID,用于过滤搜索结果")
|
||||
end_user_id: str = Field(default="88a459f5_text09", description="组ID,用于过滤搜索结果")
|
||||
|
||||
def create_time_retrieval_tool(group_id: str):
|
||||
def create_time_retrieval_tool(end_user_id: str):
|
||||
"""
|
||||
创建一个带有特定group_id的TimeRetrieval工具(同步版本),用于按时间范围搜索语句(Statements)
|
||||
创建一个带有特定end_user_id的TimeRetrieval工具(同步版本),用于按时间范围搜索语句(Statements)
|
||||
"""
|
||||
|
||||
def clean_temporal_result_fields(data):
|
||||
@@ -93,26 +93,26 @@ def create_time_retrieval_tool(group_id: str):
|
||||
return data
|
||||
|
||||
@tool
|
||||
def TimeRetrievalWithGroupId(context: str, start_date: str = None, end_date: str = None, group_id_param: str = None, clean_output: bool = True) -> str:
|
||||
def TimeRetrievalWithGroupId(context: str, start_date: str = None, end_date: str = None, end_user_id_param: str = None, clean_output: bool = True) -> str:
|
||||
"""
|
||||
优化的时间检索工具,只结合时间范围搜索(同步版本),自动过滤不需要的元数据字段
|
||||
显式接收参数:
|
||||
- context: 查询上下文内容
|
||||
- start_date: 开始时间(可选,格式:YYYY-MM-DD)
|
||||
- end_date: 结束时间(可选,格式:YYYY-MM-DD)
|
||||
- group_id_param: 组ID(可选,用于覆盖默认组ID)
|
||||
- end_user_id_param: 组ID(可选,用于覆盖默认组ID)
|
||||
- clean_output: 是否清理输出中的元数据字段
|
||||
-end_date 需要根据用户的描述获取结束的时间,输出格式用strftime("%Y-%m-%d")
|
||||
"""
|
||||
async def _async_search():
|
||||
# 使用传入的参数或默认值
|
||||
actual_group_id = group_id_param or group_id
|
||||
actual_end_user_id = end_user_id_param or end_user_id
|
||||
actual_end_date = end_date or datetime.now().strftime("%Y-%m-%d")
|
||||
actual_start_date = start_date or (datetime.now() - timedelta(days=7)).strftime("%Y-%m-%d")
|
||||
|
||||
# 基本时间搜索
|
||||
results = await search_by_temporal(
|
||||
group_id=actual_group_id,
|
||||
end_user_id=actual_end_user_id,
|
||||
start_date=actual_start_date,
|
||||
end_date=actual_end_date,
|
||||
limit=10
|
||||
@@ -147,7 +147,7 @@ def create_time_retrieval_tool(group_id: str):
|
||||
# 关键词时间搜索
|
||||
results = await search_by_keyword_temporal(
|
||||
query_text=context,
|
||||
group_id=group_id,
|
||||
end_user_id=end_user_id,
|
||||
start_date=actual_start_date,
|
||||
end_date=actual_end_date,
|
||||
limit=15
|
||||
@@ -172,7 +172,7 @@ def create_hybrid_retrieval_tool_async(memory_config, **search_params):
|
||||
|
||||
Args:
|
||||
memory_config: 内存配置对象
|
||||
**search_params: 搜索参数,包含group_id, limit, include等
|
||||
**search_params: 搜索参数,包含end_user_id, limit, include等
|
||||
"""
|
||||
|
||||
def clean_result_fields(data):
|
||||
@@ -211,7 +211,7 @@ def create_hybrid_retrieval_tool_async(memory_config, **search_params):
|
||||
context: str,
|
||||
search_type: str = "hybrid",
|
||||
limit: int = 10,
|
||||
group_id: str = None,
|
||||
end_user_id: str = None,
|
||||
rerank_alpha: float = 0.6,
|
||||
use_forgetting_rerank: bool = False,
|
||||
use_llm_rerank: bool = False,
|
||||
@@ -224,7 +224,7 @@ def create_hybrid_retrieval_tool_async(memory_config, **search_params):
|
||||
context: 查询内容
|
||||
search_type: 搜索类型 ('keyword', 'embedding', 'hybrid')
|
||||
limit: 结果数量限制
|
||||
group_id: 组ID,用于过滤搜索结果
|
||||
end_user_id: 组ID,用于过滤搜索结果
|
||||
rerank_alpha: 重排序权重参数
|
||||
use_forgetting_rerank: 是否使用遗忘重排序
|
||||
use_llm_rerank: 是否使用LLM重排序
|
||||
@@ -238,7 +238,7 @@ def create_hybrid_retrieval_tool_async(memory_config, **search_params):
|
||||
final_params = {
|
||||
"query_text": context,
|
||||
"search_type": search_type,
|
||||
"group_id": group_id or search_params.get("group_id"),
|
||||
"end_user_id": end_user_id or search_params.get("end_user_id"),
|
||||
"limit": limit or search_params.get("limit", 10),
|
||||
"include": search_params.get("include", ["summaries", "statements", "chunks", "entities"]),
|
||||
"output_path": None, # 不保存到文件
|
||||
@@ -291,7 +291,7 @@ def create_hybrid_retrieval_tool_sync(memory_config, **search_params):
|
||||
context: str,
|
||||
search_type: str = "hybrid",
|
||||
limit: int = 10,
|
||||
group_id: str = None,
|
||||
end_user_id: str = None,
|
||||
clean_output: bool = True
|
||||
) -> str:
|
||||
"""
|
||||
@@ -301,7 +301,7 @@ def create_hybrid_retrieval_tool_sync(memory_config, **search_params):
|
||||
context: 查询内容
|
||||
search_type: 搜索类型 ('keyword', 'embedding', 'hybrid')
|
||||
limit: 结果数量限制
|
||||
group_id: 组ID,用于过滤搜索结果
|
||||
end_user_id: 组ID,用于过滤搜索结果
|
||||
clean_output: 是否清理输出中的元数据字段
|
||||
"""
|
||||
async def _async_search():
|
||||
@@ -311,7 +311,7 @@ def create_hybrid_retrieval_tool_sync(memory_config, **search_params):
|
||||
"context": context,
|
||||
"search_type": search_type,
|
||||
"limit": limit,
|
||||
"group_id": group_id,
|
||||
"end_user_id": end_user_id,
|
||||
"clean_output": clean_output
|
||||
})
|
||||
|
||||
|
||||
@@ -14,6 +14,7 @@ from app.db import get_db
|
||||
from app.core.logging_config import get_agent_logger
|
||||
from app.core.memory.agent.utils.llm_tools import WriteState
|
||||
from app.core.memory.agent.langgraph_graph.nodes.write_nodes import write_node
|
||||
from app.core.memory.agent.langgraph_graph.nodes.data_nodes import content_input_write
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
|
||||
warnings.filterwarnings("ignore", category=RuntimeWarning)
|
||||
@@ -26,9 +27,21 @@ async def make_write_graph():
|
||||
"""
|
||||
Create a write graph workflow for memory operations.
|
||||
|
||||
The workflow directly processes messages from the initial state
|
||||
and saves them to Neo4j storage.
|
||||
Args:
|
||||
user_id: User identifier
|
||||
tools: MCP tools loaded from session
|
||||
apply_id: Application identifier
|
||||
end_user_id: Group identifier
|
||||
memory_config: MemoryConfig object containing all configuration
|
||||
"""
|
||||
# workflow = StateGraph(WriteState)
|
||||
# workflow.add_node("content_input", content_input_write)
|
||||
# workflow.add_node("save_neo4j", write_node)
|
||||
# workflow.add_edge(START, "content_input")
|
||||
# workflow.add_edge("content_input", "save_neo4j")
|
||||
# workflow.add_edge("save_neo4j", END)
|
||||
#
|
||||
# graph = workflow.compile()
|
||||
workflow = StateGraph(WriteState)
|
||||
workflow.add_node("save_neo4j", write_node)
|
||||
workflow.add_edge(START, "save_neo4j")
|
||||
@@ -42,7 +55,7 @@ async def make_write_graph():
|
||||
async def main():
|
||||
"""主函数 - 运行工作流"""
|
||||
message = "今天周一"
|
||||
group_id = 'new_2025test1103' # 组ID
|
||||
end_user_id = 'new_2025test1103' # 组ID
|
||||
|
||||
|
||||
# 获取数据库会话
|
||||
@@ -54,9 +67,9 @@ async def main():
|
||||
)
|
||||
try:
|
||||
async with make_write_graph() as graph:
|
||||
config = {"configurable": {"thread_id": group_id}}
|
||||
config = {"configurable": {"thread_id": end_user_id}}
|
||||
# 初始状态 - 包含所有必要字段
|
||||
initial_state = {"messages": [HumanMessage(content=message)], "group_id": group_id, "memory_config": memory_config}
|
||||
initial_state = {"messages": [HumanMessage(content=message)], "end_user_id": end_user_id, "memory_config": memory_config}
|
||||
|
||||
# 获取节点更新信息
|
||||
async for update_event in graph.astream(
|
||||
|
||||
@@ -24,7 +24,7 @@ class ParameterBuilder:
|
||||
tool_call_id: str,
|
||||
search_switch: str,
|
||||
apply_id: str,
|
||||
group_id: str,
|
||||
end_user_id: str,
|
||||
storage_type: Optional[str] = None,
|
||||
user_rag_memory_id: Optional[str] = None
|
||||
) -> Dict[str, Any]:
|
||||
@@ -44,7 +44,7 @@ class ParameterBuilder:
|
||||
tool_call_id: Extracted tool call identifier
|
||||
search_switch: Search routing parameter
|
||||
apply_id: Application identifier
|
||||
group_id: Group identifier
|
||||
end_user_id: Group identifier
|
||||
storage_type: Storage type for the workspace (optional)
|
||||
user_rag_memory_id: User RAG memory ID for knowledge base retrieval (optional)
|
||||
|
||||
@@ -55,7 +55,7 @@ class ParameterBuilder:
|
||||
base_args = {
|
||||
"usermessages": tool_call_id,
|
||||
"apply_id": apply_id,
|
||||
"group_id": group_id
|
||||
"end_user_id": end_user_id
|
||||
}
|
||||
|
||||
# Always add storage_type and user_rag_memory_id (with defaults if None)
|
||||
|
||||
@@ -91,7 +91,7 @@ class SearchService:
|
||||
|
||||
async def execute_hybrid_search(
|
||||
self,
|
||||
group_id: str,
|
||||
end_user_id: str,
|
||||
question: str,
|
||||
limit: int = 5,
|
||||
search_type: str = "hybrid",
|
||||
@@ -105,7 +105,7 @@ class SearchService:
|
||||
Execute hybrid search and return clean content.
|
||||
|
||||
Args:
|
||||
group_id: Group identifier for filtering results
|
||||
end_user_id: Group identifier for filtering results
|
||||
question: Search query text
|
||||
limit: Maximum number of results to return (default: 5)
|
||||
search_type: Type of search - "hybrid", "keyword", or "embedding" (default: "hybrid")
|
||||
@@ -130,7 +130,7 @@ class SearchService:
|
||||
answer = await run_hybrid_search(
|
||||
query_text=cleaned_query,
|
||||
search_type=search_type,
|
||||
group_id=group_id,
|
||||
end_user_id=end_user_id,
|
||||
limit=limit,
|
||||
include=include,
|
||||
output_path=output_path,
|
||||
@@ -186,7 +186,7 @@ class SearchService:
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Search failed for query '{question}' in group '{group_id}': {e}",
|
||||
f"Search failed for query '{question}' in group '{end_user_id}': {e}",
|
||||
exc_info=True
|
||||
)
|
||||
# Return empty results on failure
|
||||
|
||||
@@ -59,7 +59,7 @@ class SessionService:
|
||||
self,
|
||||
user_id: str,
|
||||
apply_id: str,
|
||||
group_id: str
|
||||
end_user_id: str
|
||||
) -> List[dict]:
|
||||
"""
|
||||
Retrieve conversation history from Redis.
|
||||
@@ -67,20 +67,20 @@ class SessionService:
|
||||
Args:
|
||||
user_id: User identifier
|
||||
apply_id: Application identifier
|
||||
group_id: Group identifier
|
||||
end_user_id: Group identifier
|
||||
|
||||
Returns:
|
||||
List of conversation history items with Query and Answer keys
|
||||
Returns empty list if no history found or on error
|
||||
"""
|
||||
try:
|
||||
history = self.store.find_user_apply_group(user_id, apply_id, group_id)
|
||||
history = self.store.find_user_apply_group(user_id, apply_id, end_user_id)
|
||||
|
||||
# Validate history structure
|
||||
if not isinstance(history, list):
|
||||
logger.warning(
|
||||
f"Invalid history format for user {user_id}, "
|
||||
f"apply {apply_id}, group {group_id}: expected list, got {type(history)}"
|
||||
f"apply {apply_id}, group {end_user_id}: expected list, got {type(history)}"
|
||||
)
|
||||
return []
|
||||
|
||||
@@ -89,7 +89,7 @@ class SessionService:
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to retrieve history for user {user_id}, "
|
||||
f"apply {apply_id}, group {group_id}: {e}",
|
||||
f"apply {apply_id}, group {end_user_id}: {e}",
|
||||
exc_info=True
|
||||
)
|
||||
# Return empty list on error to allow execution to continue
|
||||
@@ -100,7 +100,7 @@ class SessionService:
|
||||
user_id: str,
|
||||
query: str,
|
||||
apply_id: str,
|
||||
group_id: str,
|
||||
end_user_id: str,
|
||||
ai_response: str
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
@@ -110,7 +110,7 @@ class SessionService:
|
||||
user_id: User identifier
|
||||
query: User query/message
|
||||
apply_id: Application identifier
|
||||
group_id: Group identifier
|
||||
end_user_id: Group identifier
|
||||
ai_response: AI response/answer
|
||||
|
||||
Returns:
|
||||
@@ -131,7 +131,7 @@ class SessionService:
|
||||
userid=user_id,
|
||||
messages=query,
|
||||
apply_id=apply_id,
|
||||
group_id=group_id,
|
||||
end_user_id=end_user_id,
|
||||
aimessages=ai_response
|
||||
)
|
||||
|
||||
@@ -152,7 +152,7 @@ class SessionService:
|
||||
Duplicates are identified by matching:
|
||||
- sessionid
|
||||
- user_id (id field)
|
||||
- group_id
|
||||
- end_user_id
|
||||
- messages
|
||||
- aimessages
|
||||
|
||||
|
||||
@@ -9,9 +9,7 @@ from app.core.memory.models.message_models import DialogData, ConversationContex
|
||||
|
||||
async def get_chunked_dialogs(
|
||||
chunker_strategy: str = "RecursiveChunker",
|
||||
group_id: str = "group_1",
|
||||
user_id: str = "user1",
|
||||
apply_id: str = "applyid",
|
||||
end_user_id: str = "group_1",
|
||||
messages: list = None,
|
||||
ref_id: str = "wyl_20251027",
|
||||
config_id: str = None
|
||||
@@ -20,9 +18,7 @@ async def get_chunked_dialogs(
|
||||
|
||||
Args:
|
||||
chunker_strategy: The chunking strategy to use (default: RecursiveChunker)
|
||||
group_id: Group identifier
|
||||
user_id: User identifier
|
||||
apply_id: Application identifier
|
||||
end_user_id: Group identifier
|
||||
messages: Structured message list [{"role": "user", "content": "..."}, ...]
|
||||
ref_id: Reference identifier
|
||||
config_id: Configuration ID for processing
|
||||
@@ -32,42 +28,40 @@ async def get_chunked_dialogs(
|
||||
"""
|
||||
from app.core.logging_config import get_agent_logger
|
||||
logger = get_agent_logger(__name__)
|
||||
|
||||
|
||||
if not messages or not isinstance(messages, list) or len(messages) == 0:
|
||||
raise ValueError("messages parameter must be a non-empty list")
|
||||
|
||||
|
||||
conversation_messages = []
|
||||
|
||||
|
||||
for idx, msg in enumerate(messages):
|
||||
if not isinstance(msg, dict) or 'role' not in msg or 'content' not in msg:
|
||||
raise ValueError(f"Message {idx} format error: must contain 'role' and 'content' fields")
|
||||
|
||||
|
||||
role = msg['role']
|
||||
content = msg['content']
|
||||
|
||||
|
||||
if role not in ['user', 'assistant']:
|
||||
raise ValueError(f"Message {idx} role must be 'user' or 'assistant', got: {role}")
|
||||
|
||||
|
||||
if content.strip():
|
||||
conversation_messages.append(ConversationMessage(role=role, msg=content.strip()))
|
||||
|
||||
|
||||
if not conversation_messages:
|
||||
raise ValueError("Message list cannot be empty after filtering")
|
||||
|
||||
|
||||
conversation_context = ConversationContext(msgs=conversation_messages)
|
||||
dialog_data = DialogData(
|
||||
context=conversation_context,
|
||||
ref_id=ref_id,
|
||||
group_id=group_id,
|
||||
user_id=user_id,
|
||||
apply_id=apply_id,
|
||||
end_user_id=end_user_id,
|
||||
config_id=config_id
|
||||
)
|
||||
|
||||
|
||||
chunker = DialogueChunker(chunker_strategy)
|
||||
extracted_chunks = await chunker.process_dialogue(dialog_data)
|
||||
dialog_data.chunks = extracted_chunks
|
||||
|
||||
|
||||
logger.info(f"DialogData created with {len(extracted_chunks)} chunks")
|
||||
|
||||
return [dialog_data]
|
||||
|
||||
@@ -13,13 +13,11 @@ class WriteState(TypedDict):
|
||||
Langgrapg Writing TypedDict
|
||||
'''
|
||||
messages: Annotated[list[AnyMessage], add_messages]
|
||||
user_id:str
|
||||
apply_id:str
|
||||
group_id:str
|
||||
end_user_id: str
|
||||
errors: list[dict] # Track errors: [{"tool": "tool_name", "error": "message"}]
|
||||
memory_config: object
|
||||
write_result: dict
|
||||
data:str
|
||||
data: str
|
||||
|
||||
class ReadState(TypedDict):
|
||||
"""
|
||||
@@ -29,7 +27,7 @@ class ReadState(TypedDict):
|
||||
messages: 消息列表,支持自动追加
|
||||
loop_count: 遍历次数
|
||||
search_switch: 搜索类型开关
|
||||
group_id: 组标识
|
||||
end_user_id: 组标识
|
||||
config_id: 配置ID,用于过滤结果
|
||||
data: 从content_input_node传递的内容数据
|
||||
spit_data: 从Split_The_Problem传递的分解结果
|
||||
@@ -40,7 +38,7 @@ class ReadState(TypedDict):
|
||||
messages: Annotated[list[AnyMessage], add_messages] # 消息追加模式
|
||||
loop_count: int
|
||||
search_switch: str
|
||||
group_id: str
|
||||
end_user_id: str
|
||||
config_id: str
|
||||
data: str # 新增字段用于传递内容
|
||||
spit_data: dict # 新增字段用于传递问题分解结果
|
||||
|
||||
@@ -28,7 +28,7 @@ class RedisSessionStore:
|
||||
return text
|
||||
|
||||
# 修改后的 save_session 方法
|
||||
def save_session(self, userid, messages, aimessages, apply_id, group_id):
|
||||
def save_session(self, userid, messages, aimessages, apply_id, end_user_id):
|
||||
"""
|
||||
写入一条会话数据,返回 session_id
|
||||
优化版本:确保写入时间不超过1秒
|
||||
@@ -46,7 +46,7 @@ class RedisSessionStore:
|
||||
"id": self.uudi,
|
||||
"sessionid": userid,
|
||||
"apply_id": apply_id,
|
||||
"group_id": group_id,
|
||||
"end_user_id": end_user_id,
|
||||
"messages": messages,
|
||||
"aimessages": aimessages,
|
||||
"starttime": starttime
|
||||
@@ -67,7 +67,7 @@ class RedisSessionStore:
|
||||
def save_sessions_batch(self, sessions_data):
|
||||
"""
|
||||
批量写入多条会话数据,返回 session_id 列表
|
||||
sessions_data: list of dict, 每个 dict 包含 userid, messages, aimessages, apply_id, group_id
|
||||
sessions_data: list of dict, 每个 dict 包含 userid, messages, aimessages, apply_id, end_user_id
|
||||
优化版本:批量操作,大幅提升性能
|
||||
"""
|
||||
try:
|
||||
@@ -83,7 +83,7 @@ class RedisSessionStore:
|
||||
"id": self.uudi,
|
||||
"sessionid": session.get('userid'),
|
||||
"apply_id": session.get('apply_id'),
|
||||
"group_id": session.get('group_id'),
|
||||
"end_user_id": session.get('end_user_id'),
|
||||
"messages": session.get('messages'),
|
||||
"aimessages": session.get('aimessages'),
|
||||
"starttime": starttime
|
||||
@@ -108,9 +108,9 @@ class RedisSessionStore:
|
||||
data = self.r.hgetall(key)
|
||||
return data if data else None
|
||||
|
||||
def get_session_apply_group(self, sessionid, apply_id, group_id):
|
||||
def get_session_apply_group(self, sessionid, apply_id, end_user_id):
|
||||
"""
|
||||
根据 sessionid、apply_id 和 group_id 三个条件查询会话数据
|
||||
根据 sessionid、apply_id 和 end_user_id 三个条件查询会话数据
|
||||
"""
|
||||
result_items = []
|
||||
|
||||
@@ -124,7 +124,7 @@ class RedisSessionStore:
|
||||
# 检查三个条件是否都匹配
|
||||
if (data.get('sessionid') == sessionid and
|
||||
data.get('apply_id') == apply_id and
|
||||
data.get('group_id') == group_id):
|
||||
data.get('end_user_id') == end_user_id):
|
||||
result_items.append(data)
|
||||
|
||||
return result_items
|
||||
@@ -172,7 +172,7 @@ class RedisSessionStore:
|
||||
def delete_duplicate_sessions(self):
|
||||
"""
|
||||
删除重复会话数据,条件:
|
||||
"sessionid"、"user_id"、"group_id"、"messages"、"aimessages" 五个字段都相同的只保留一个,其他删除
|
||||
"sessionid"、"user_id"、"end_user_id"、"messages"、"aimessages" 五个字段都相同的只保留一个,其他删除
|
||||
优化版本:使用 pipeline 批量操作,确保在1秒内完成
|
||||
"""
|
||||
import time
|
||||
@@ -202,12 +202,12 @@ class RedisSessionStore:
|
||||
# 获取五个字段的值
|
||||
sessionid = data.get('sessionid', '')
|
||||
user_id = data.get('id', '')
|
||||
group_id = data.get('group_id', '')
|
||||
end_user_id = data.get('end_user_id', '')
|
||||
messages = data.get('messages', '')
|
||||
aimessages = data.get('aimessages', '')
|
||||
|
||||
# 用五元组作为唯一标识
|
||||
identifier = (sessionid, user_id, group_id, messages, aimessages)
|
||||
identifier = (sessionid, user_id, end_user_id, messages, aimessages)
|
||||
|
||||
if identifier in seen:
|
||||
# 重复,标记为待删除
|
||||
@@ -248,9 +248,9 @@ class RedisSessionStore:
|
||||
result_items = []
|
||||
return (result_items)
|
||||
|
||||
def find_user_apply_group(self, sessionid, apply_id, group_id):
|
||||
def find_user_apply_group(self, sessionid, apply_id, end_user_id):
|
||||
"""
|
||||
根据 sessionid、apply_id 和 group_id 三个条件查询会话数据,返回最新的6条
|
||||
根据 sessionid、apply_id 和 end_user_id 三个条件查询会话数据,返回最新的6条
|
||||
"""
|
||||
import time
|
||||
start_time = time.time()
|
||||
@@ -276,7 +276,7 @@ class RedisSessionStore:
|
||||
# 检查是否符合三个条件
|
||||
|
||||
if (data.get('apply_id') == apply_id and
|
||||
data.get('group_id') == group_id):
|
||||
data.get('end_user_id') == end_user_id):
|
||||
# 支持模糊匹配 sessionid 或者完全匹配
|
||||
if sessionid in data.get('sessionid', '') or data.get('sessionid') == sessionid:
|
||||
matched_items.append({
|
||||
|
||||
@@ -59,7 +59,7 @@ class SessionService:
|
||||
self,
|
||||
user_id: str,
|
||||
apply_id: str,
|
||||
group_id: str
|
||||
end_user_id: str
|
||||
) -> List[dict]:
|
||||
"""
|
||||
Retrieve conversation history from Redis.
|
||||
@@ -67,20 +67,20 @@ class SessionService:
|
||||
Args:
|
||||
user_id: User identifier
|
||||
apply_id: Application identifier
|
||||
group_id: Group identifier
|
||||
end_user_id: Group identifier
|
||||
|
||||
Returns:
|
||||
List of conversation history items with Query and Answer keys
|
||||
Returns empty list if no history found or on error
|
||||
"""
|
||||
try:
|
||||
history = self.store.find_user_apply_group(user_id, apply_id, group_id)
|
||||
history = self.store.find_user_apply_group(user_id, apply_id, end_user_id)
|
||||
|
||||
# Validate history structure
|
||||
if not isinstance(history, list):
|
||||
logger.warning(
|
||||
f"Invalid history format for user {user_id}, "
|
||||
f"apply {apply_id}, group {group_id}: expected list, got {type(history)}"
|
||||
f"apply {apply_id}, group {end_user_id}: expected list, got {type(history)}"
|
||||
)
|
||||
return []
|
||||
|
||||
@@ -89,7 +89,7 @@ class SessionService:
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to retrieve history for user {user_id}, "
|
||||
f"apply {apply_id}, group {group_id}: {e}",
|
||||
f"apply {apply_id}, group {end_user_id}: {e}",
|
||||
exc_info=True
|
||||
)
|
||||
# Return empty list on error to allow execution to continue
|
||||
@@ -100,7 +100,7 @@ class SessionService:
|
||||
user_id: str,
|
||||
query: str,
|
||||
apply_id: str,
|
||||
group_id: str,
|
||||
end_user_id: str,
|
||||
ai_response: str
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
@@ -110,7 +110,7 @@ class SessionService:
|
||||
user_id: User identifier
|
||||
query: User query/message
|
||||
apply_id: Application identifier
|
||||
group_id: Group identifier
|
||||
end_user_id: Group identifier
|
||||
ai_response: AI response/answer
|
||||
|
||||
Returns:
|
||||
@@ -131,7 +131,7 @@ class SessionService:
|
||||
userid=user_id,
|
||||
messages=query,
|
||||
apply_id=apply_id,
|
||||
group_id=group_id,
|
||||
end_user_id=end_user_id,
|
||||
aimessages=ai_response
|
||||
)
|
||||
|
||||
@@ -152,7 +152,7 @@ class SessionService:
|
||||
Duplicates are identified by matching:
|
||||
- sessionid
|
||||
- user_id (id field)
|
||||
- group_id
|
||||
- end_user_id
|
||||
- messages
|
||||
- aimessages
|
||||
|
||||
|
||||
@@ -29,20 +29,18 @@ logger = get_agent_logger(__name__)
|
||||
|
||||
|
||||
async def write(
|
||||
user_id: str,
|
||||
apply_id: str,
|
||||
group_id: str,
|
||||
end_user_id: str,
|
||||
memory_config: MemoryConfig,
|
||||
messages: list,
|
||||
ref_id: str = "wyl20251027",
|
||||
) -> None:
|
||||
"""
|
||||
Execute the complete knowledge extraction pipeline.
|
||||
|
||||
|
||||
Args:
|
||||
user_id: User identifier
|
||||
apply_id: Application identifier
|
||||
group_id: Group identifier
|
||||
end_user_id: Group identifier
|
||||
memory_config: MemoryConfig object containing all configuration
|
||||
messages: Structured message list [{"role": "user", "content": "..."}, ...]
|
||||
ref_id: Reference ID, defaults to "wyl20251027"
|
||||
@@ -51,14 +49,14 @@ async def write(
|
||||
embedding_model_id = str(memory_config.embedding_model_id)
|
||||
chunker_strategy = memory_config.chunker_strategy
|
||||
config_id = str(memory_config.config_id)
|
||||
|
||||
|
||||
logger.info("=== MemSci Knowledge Extraction Pipeline ===")
|
||||
logger.info(f"Config: {memory_config.config_name} (ID: {config_id})")
|
||||
logger.info(f"Workspace: {memory_config.workspace_name}")
|
||||
logger.info(f"LLM model: {memory_config.llm_model_name}")
|
||||
logger.info(f"Embedding model: {memory_config.embedding_model_name}")
|
||||
logger.info(f"Chunker strategy: {chunker_strategy}")
|
||||
logger.info(f"Group ID: {group_id}")
|
||||
logger.info(f"end_user_id ID: {end_user_id}")
|
||||
|
||||
# Construct clients from memory_config using factory pattern with db session
|
||||
with get_db_context() as db:
|
||||
@@ -83,9 +81,7 @@ async def write(
|
||||
step_start = time.time()
|
||||
chunked_dialogs = await get_chunked_dialogs(
|
||||
chunker_strategy=chunker_strategy,
|
||||
group_id=group_id,
|
||||
user_id=user_id,
|
||||
apply_id=apply_id,
|
||||
end_user_id=end_user_id,
|
||||
messages=messages,
|
||||
ref_id=ref_id,
|
||||
config_id=config_id,
|
||||
|
||||
@@ -16,13 +16,13 @@ class FilteredTags(BaseModel):
|
||||
"""用于接收LLM筛选后的核心标签列表的模型。"""
|
||||
meaningful_tags: List[str] = Field(..., description="从原始列表中筛选出的具有核心代表意义的名词列表。")
|
||||
|
||||
async def filter_tags_with_llm(tags: List[str], group_id: str) -> List[str]:
|
||||
async def filter_tags_with_llm(tags: List[str], end_user_id: str) -> List[str]:
|
||||
"""
|
||||
使用LLM筛选标签列表,仅保留具有代表性的核心名词。
|
||||
|
||||
Args:
|
||||
tags: 原始标签列表
|
||||
group_id: 用户组ID,用于获取配置
|
||||
end_user_id: 用户组ID,用于获取配置
|
||||
|
||||
Returns:
|
||||
筛选后的标签列表
|
||||
@@ -37,12 +37,12 @@ async def filter_tags_with_llm(tags: List[str], group_id: str) -> List[str]:
|
||||
get_end_user_connected_config,
|
||||
)
|
||||
|
||||
connected_config = get_end_user_connected_config(group_id, db)
|
||||
connected_config = get_end_user_connected_config(end_user_id, db)
|
||||
config_id = connected_config.get("memory_config_id")
|
||||
|
||||
if not config_id:
|
||||
raise ValueError(
|
||||
f"No memory_config_id found for group_id: {group_id}. "
|
||||
f"No memory_config_id found for end_user_id: {end_user_id}. "
|
||||
"Please ensure the user has a valid memory configuration."
|
||||
)
|
||||
|
||||
@@ -87,7 +87,7 @@ async def filter_tags_with_llm(tags: List[str], group_id: str) -> List[str]:
|
||||
|
||||
async def get_raw_tags_from_db(
|
||||
connector: Neo4jConnector,
|
||||
group_id: str,
|
||||
end_user_id: str,
|
||||
limit: int,
|
||||
by_user: bool = False
|
||||
) -> List[Tuple[str, int]]:
|
||||
@@ -99,9 +99,9 @@ async def get_raw_tags_from_db(
|
||||
|
||||
Args:
|
||||
connector: Neo4j连接器实例
|
||||
group_id: 如果by_user=False,则为group_id;如果by_user=True,则为user_id
|
||||
end_user_id: 如果by_user=False,则为end_user_id;如果by_user=True,则为user_id
|
||||
limit: 返回的标签数量限制
|
||||
by_user: 是否按user_id查询(默认False,按group_id查询)
|
||||
by_user: 是否按user_id查询(默认False,按end_user_id查询)
|
||||
|
||||
Returns:
|
||||
List[Tuple[str, int]]: 标签名称和频率的元组列表
|
||||
@@ -119,7 +119,7 @@ async def get_raw_tags_from_db(
|
||||
else:
|
||||
query = (
|
||||
"MATCH (e:ExtractedEntity) "
|
||||
"WHERE e.group_id = $id AND e.entity_type <> '人物' AND e.name IS NOT NULL AND NOT e.name IN $names_to_exclude "
|
||||
"WHERE e.end_user_id = $id AND e.entity_type <> '人物' AND e.name IS NOT NULL AND NOT e.name IN $names_to_exclude "
|
||||
"RETURN e.name AS name, count(e) AS frequency "
|
||||
"ORDER BY frequency DESC "
|
||||
"LIMIT $limit"
|
||||
@@ -128,44 +128,44 @@ async def get_raw_tags_from_db(
|
||||
# 使用项目的Neo4jConnector执行查询
|
||||
results = await connector.execute_query(
|
||||
query,
|
||||
id=group_id,
|
||||
id=end_user_id,
|
||||
limit=limit,
|
||||
names_to_exclude=names_to_exclude
|
||||
)
|
||||
|
||||
return [(record["name"], record["frequency"]) for record in results]
|
||||
|
||||
async def get_hot_memory_tags(group_id: str, limit: int = 40, by_user: bool = False) -> List[Tuple[str, int]]:
|
||||
async def get_hot_memory_tags(end_user_id: str, limit: int = 40, by_user: bool = False) -> List[Tuple[str, int]]:
|
||||
"""
|
||||
获取原始标签,然后使用LLM进行筛选,返回最终的热门标签列表。
|
||||
查询更多的标签(limit=40)给LLM提供更丰富的上下文进行筛选。
|
||||
|
||||
Args:
|
||||
group_id: 必需参数。如果by_user=False,则为group_id;如果by_user=True,则为user_id
|
||||
end_user_id: 必需参数。如果by_user=False,则为end_user_id;如果by_user=True,则为user_id
|
||||
limit: 返回的标签数量限制
|
||||
by_user: 是否按user_id查询(默认False,按group_id查询)
|
||||
by_user: 是否按user_id查询(默认False,按end_user_id查询)
|
||||
|
||||
Raises:
|
||||
ValueError: 如果group_id未提供或为空
|
||||
ValueError: 如果end_user_id未提供或为空
|
||||
"""
|
||||
# 验证group_id必须提供且不为空
|
||||
if not group_id or not group_id.strip():
|
||||
# 验证end_user_id必须提供且不为空
|
||||
if not end_user_id or not end_user_id.strip():
|
||||
raise ValueError(
|
||||
"group_id is required. Please provide a valid group_id or user_id."
|
||||
"end_user_id is required. Please provide a valid end_user_id or user_id."
|
||||
)
|
||||
|
||||
# 使用项目的Neo4jConnector
|
||||
connector = Neo4jConnector()
|
||||
try:
|
||||
# 1. 从数据库获取原始排名靠前的标签
|
||||
raw_tags_with_freq = await get_raw_tags_from_db(connector, group_id, limit, by_user=by_user)
|
||||
raw_tags_with_freq = await get_raw_tags_from_db(connector, end_user_id, limit, by_user=by_user)
|
||||
if not raw_tags_with_freq:
|
||||
return []
|
||||
|
||||
raw_tag_names = [tag for tag, freq in raw_tags_with_freq]
|
||||
|
||||
# 2. 初始化LLM客户端并使用LLM筛选出有意义的标签
|
||||
meaningful_tag_names = await filter_tags_with_llm(raw_tag_names, group_id)
|
||||
meaningful_tag_names = await filter_tags_with_llm(raw_tag_names, end_user_id)
|
||||
|
||||
# 3. 根据LLM的筛选结果,构建最终的标签列表(保留原始频率和顺序)
|
||||
final_tags = []
|
||||
|
||||
@@ -75,8 +75,8 @@ class MemoryDataSource:
|
||||
start_date = time_range.start_date if time_range else None
|
||||
end_date = time_range.end_date if time_range else None
|
||||
|
||||
summary_dicts = await self.memory_summary_repo.find_by_group_id(
|
||||
group_id=user_id,
|
||||
summary_dicts = await self.memory_summary_repo.find_by_end_user_id(
|
||||
end_user_id=user_id,
|
||||
limit=limit,
|
||||
start_date=start_date,
|
||||
end_date=end_date
|
||||
|
||||
@@ -41,7 +41,7 @@ DIALOGUE_EMBEDDING_SEARCH = """
|
||||
WITH $embedding AS q
|
||||
MATCH (d:Dialogue)
|
||||
WHERE d.dialog_embedding IS NOT NULL
|
||||
AND ($group_id IS NULL OR d.group_id = $group_id)
|
||||
AND ($end_user_id IS NULL OR d.end_user_id = $end_user_id)
|
||||
WITH d, q, d.dialog_embedding AS v
|
||||
WITH d,
|
||||
reduce(dot = 0.0, i IN range(0, size(q)-1) | dot + toFloat(q[i]) * toFloat(v[i])) AS dot,
|
||||
@@ -50,7 +50,7 @@ WITH d,
|
||||
WITH d, CASE WHEN qnorm = 0 OR vnorm = 0 THEN 0.0 ELSE dot / (qnorm * vnorm) END AS score
|
||||
WHERE score > $threshold
|
||||
RETURN d.id AS dialog_id,
|
||||
d.group_id AS group_id,
|
||||
d.end_user_id AS end_user_id,
|
||||
d.content AS content,
|
||||
d.created_at AS created_at,
|
||||
d.expired_at AS expired_at,
|
||||
|
||||
@@ -36,7 +36,7 @@ from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
|
||||
async def ingest_contexts_via_full_pipeline(
|
||||
contexts: List[str],
|
||||
group_id: str,
|
||||
end_user_id: str,
|
||||
chunker_strategy: str | None = None,
|
||||
embedding_name: str | None = None,
|
||||
save_chunk_output: bool = False,
|
||||
@@ -48,7 +48,7 @@ async def ingest_contexts_via_full_pipeline(
|
||||
This function mirrors the steps in main(), but starts from raw text contexts.
|
||||
Args:
|
||||
contexts: List of dialogue texts, each containing lines like "role: message".
|
||||
group_id: Group ID to assign to generated DialogData and graph nodes.
|
||||
end_user_id: Group ID to assign to generated DialogData and graph nodes.
|
||||
chunker_strategy: Optional chunker strategy; defaults to SELECTED_CHUNKER_STRATEGY.
|
||||
embedding_name: Optional embedding model ID; defaults to SELECTED_EMBEDDING_ID.
|
||||
save_chunk_output: If True, write chunked DialogData list to a JSON file for debugging.
|
||||
@@ -109,7 +109,7 @@ async def ingest_contexts_via_full_pipeline(
|
||||
dialog = DialogData(
|
||||
context=context_model,
|
||||
ref_id=f"pipeline_item_{idx}",
|
||||
group_id=group_id,
|
||||
end_user_id=end_user_id,
|
||||
user_id="default_user",
|
||||
apply_id="default_application",
|
||||
)
|
||||
@@ -318,16 +318,16 @@ async def handle_context_processing(args):
|
||||
print("No contexts provided for processing.")
|
||||
return False
|
||||
|
||||
return await main_from_contexts(contexts, args.context_group_id)
|
||||
return await main_from_contexts(contexts, args.context_end_user_id)
|
||||
|
||||
|
||||
async def main_from_contexts(contexts: List[str], group_id: str):
|
||||
async def main_from_contexts(contexts: List[str], end_user_id: str):
|
||||
"""Run the pipeline from provided dialogue contexts instead of test data."""
|
||||
print("=== Running pipeline from provided contexts ===")
|
||||
|
||||
success = await ingest_contexts_via_full_pipeline(
|
||||
contexts=contexts,
|
||||
group_id=group_id,
|
||||
end_user_id=end_user_id,
|
||||
chunker_strategy=SELECTED_CHUNKER_STRATEGY,
|
||||
embedding_name=SELECTED_EMBEDDING_ID,
|
||||
save_chunk_output=True
|
||||
|
||||
@@ -47,7 +47,7 @@ from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
|
||||
from app.core.memory.utils.definitions import (
|
||||
PROJECT_ROOT,
|
||||
SELECTED_EMBEDDING_ID,
|
||||
SELECTED_GROUP_ID,
|
||||
SELECTED_end_user_id,
|
||||
SELECTED_LLM_ID,
|
||||
)
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
@@ -59,7 +59,7 @@ from app.services.memory_config_service import MemoryConfigService
|
||||
|
||||
async def run_locomo_benchmark(
|
||||
sample_size: int = 20,
|
||||
group_id: Optional[str] = None,
|
||||
end_user_id: Optional[str] = None,
|
||||
search_type: str = "hybrid",
|
||||
search_limit: int = 12,
|
||||
context_char_budget: int = 8000,
|
||||
@@ -85,7 +85,7 @@ async def run_locomo_benchmark(
|
||||
|
||||
Args:
|
||||
sample_size: Number of QA pairs to evaluate (from first conversation)
|
||||
group_id: Database group ID for retrieval (uses default if None)
|
||||
end_user_id: Database group ID for retrieval (uses default if None)
|
||||
search_type: "keyword", "embedding", or "hybrid"
|
||||
search_limit: Max documents to retrieve per query
|
||||
context_char_budget: Max characters for context
|
||||
@@ -96,8 +96,8 @@ async def run_locomo_benchmark(
|
||||
Returns:
|
||||
Dictionary with evaluation results including metrics, timing, and samples
|
||||
"""
|
||||
# Use default group_id if not provided
|
||||
group_id = group_id or SELECTED_GROUP_ID
|
||||
# Use default end_user_id if not provided
|
||||
end_user_id = end_user_id or SELECTED_end_user_id
|
||||
|
||||
# Determine data path
|
||||
data_path = os.path.join(PROJECT_ROOT, "data", "locomo10.json")
|
||||
@@ -110,7 +110,7 @@ async def run_locomo_benchmark(
|
||||
print(f"{'='*60}")
|
||||
print("📊 Configuration:")
|
||||
print(f" Sample size: {sample_size}")
|
||||
print(f" Group ID: {group_id}")
|
||||
print(f" Group ID: {end_user_id}")
|
||||
print(f" Search type: {search_type}")
|
||||
print(f" Search limit: {search_limit}")
|
||||
print(f" Context budget: {context_char_budget} chars")
|
||||
@@ -134,7 +134,7 @@ async def run_locomo_benchmark(
|
||||
# Step 2: Extract conversations and ingest if needed
|
||||
if skip_ingest:
|
||||
print("⏭️ Skipping data ingestion (using existing data in Neo4j)")
|
||||
print(f" Group ID: {group_id}\n")
|
||||
print(f" Group ID: {end_user_id}\n")
|
||||
else:
|
||||
print("💾 Checking database ingestion...")
|
||||
try:
|
||||
@@ -142,10 +142,10 @@ async def run_locomo_benchmark(
|
||||
print(f"📝 Extracted {len(conversations)} conversations")
|
||||
|
||||
# Always ingest for now (ingestion check not implemented)
|
||||
print(f"🔄 Ingesting conversations into group '{group_id}'...")
|
||||
print(f"🔄 Ingesting conversations into group '{end_user_id}'...")
|
||||
success = await ingest_conversations_if_needed(
|
||||
conversations=conversations,
|
||||
group_id=group_id,
|
||||
end_user_id=end_user_id,
|
||||
reset=reset_group
|
||||
)
|
||||
|
||||
@@ -224,7 +224,7 @@ async def run_locomo_benchmark(
|
||||
try:
|
||||
retrieved_info = await retrieve_relevant_information(
|
||||
question=question,
|
||||
group_id=group_id,
|
||||
end_user_id=end_user_id,
|
||||
search_type=search_type,
|
||||
search_limit=search_limit,
|
||||
connector=connector,
|
||||
@@ -409,7 +409,7 @@ async def run_locomo_benchmark(
|
||||
"sample_size": len(qa_items),
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"params": {
|
||||
"group_id": group_id,
|
||||
"end_user_id": end_user_id,
|
||||
"search_type": search_type,
|
||||
"search_limit": search_limit,
|
||||
"context_char_budget": context_char_budget,
|
||||
@@ -467,7 +467,7 @@ def main():
|
||||
help="Number of QA pairs to evaluate"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--group_id",
|
||||
"--end_user_id",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Database group ID for retrieval (uses default if not specified)"
|
||||
@@ -516,7 +516,7 @@ def main():
|
||||
# Run benchmark
|
||||
result = asyncio.run(run_locomo_benchmark(
|
||||
sample_size=args.sample_size,
|
||||
group_id=args.group_id,
|
||||
end_user_id=args.end_user_id,
|
||||
search_type=args.search_type,
|
||||
search_limit=args.search_limit,
|
||||
context_char_budget=args.context_char_budget,
|
||||
|
||||
@@ -556,7 +556,7 @@ async def run_enhanced_evaluation():
|
||||
search_results = await run_hybrid_search(
|
||||
query_text=q,
|
||||
search_type="hybrid",
|
||||
group_id="locomo_sk",
|
||||
end_user_id="locomo_sk",
|
||||
limit=20,
|
||||
include=["statements", "chunks", "entities", "summaries"],
|
||||
alpha=0.6, # BM25权重
|
||||
|
||||
@@ -348,7 +348,7 @@ def select_and_format_information(
|
||||
|
||||
async def retrieve_relevant_information(
|
||||
question: str,
|
||||
group_id: str,
|
||||
end_user_id: str,
|
||||
search_type: str,
|
||||
search_limit: int,
|
||||
connector: Any,
|
||||
@@ -368,7 +368,7 @@ async def retrieve_relevant_information(
|
||||
|
||||
Args:
|
||||
question: Question to search for
|
||||
group_id: Database group ID (identifies which conversation memory to search)
|
||||
end_user_id: Database group ID (identifies which conversation memory to search)
|
||||
search_type: "keyword", "embedding", or "hybrid"
|
||||
search_limit: Max memory pieces to retrieve
|
||||
connector: Neo4j connector instance
|
||||
@@ -396,7 +396,7 @@ async def retrieve_relevant_information(
|
||||
connector=connector,
|
||||
embedder_client=embedder,
|
||||
query_text=question,
|
||||
group_id=group_id,
|
||||
end_user_id=end_user_id,
|
||||
limit=search_limit,
|
||||
include=["chunks", "statements", "entities", "summaries"],
|
||||
)
|
||||
@@ -455,7 +455,7 @@ async def retrieve_relevant_information(
|
||||
search_results = await search_graph(
|
||||
connector=connector,
|
||||
q=question,
|
||||
group_id=group_id,
|
||||
end_user_id=end_user_id,
|
||||
limit=search_limit
|
||||
)
|
||||
|
||||
@@ -491,7 +491,7 @@ async def retrieve_relevant_information(
|
||||
search_results = await run_hybrid_search(
|
||||
query_text=question,
|
||||
search_type=search_type,
|
||||
group_id=group_id,
|
||||
end_user_id=end_user_id,
|
||||
limit=search_limit,
|
||||
include=["chunks", "statements", "entities", "summaries"],
|
||||
output_path=None,
|
||||
@@ -524,7 +524,7 @@ async def retrieve_relevant_information(
|
||||
connector=connector,
|
||||
embedder_client=embedder,
|
||||
query_text=question,
|
||||
group_id=group_id,
|
||||
end_user_id=end_user_id,
|
||||
limit=search_limit,
|
||||
include=["chunks", "statements", "entities", "summaries"],
|
||||
)
|
||||
@@ -584,7 +584,7 @@ async def retrieve_relevant_information(
|
||||
|
||||
async def ingest_conversations_if_needed(
|
||||
conversations: List[str],
|
||||
group_id: str,
|
||||
end_user_id: str,
|
||||
reset: bool = False
|
||||
) -> bool:
|
||||
"""
|
||||
@@ -603,7 +603,7 @@ async def ingest_conversations_if_needed(
|
||||
Args:
|
||||
conversations: List of raw conversation texts from LoCoMo dataset
|
||||
Example: ["User: I went to Paris. AI: When was that?", ...]
|
||||
group_id: Target group ID for database storage
|
||||
end_user_id: Target group ID for database storage
|
||||
reset: Whether to clear existing data first (not implemented in wrapper)
|
||||
|
||||
Returns:
|
||||
@@ -617,7 +617,7 @@ async def ingest_conversations_if_needed(
|
||||
try:
|
||||
success = await ingest_contexts_via_full_pipeline(
|
||||
contexts=conversations,
|
||||
group_id=group_id,
|
||||
end_user_id=end_user_id,
|
||||
save_chunk_output=True
|
||||
)
|
||||
return success
|
||||
|
||||
@@ -249,7 +249,7 @@ def get_search_params_by_category(category: str):
|
||||
|
||||
async def run_locomo_eval(
|
||||
sample_size: int = 1,
|
||||
group_id: str | None = None,
|
||||
end_user_id: str | None = None,
|
||||
search_limit: int = 8,
|
||||
context_char_budget: int = 4000, # 保持默认值不变
|
||||
llm_temperature: float = 0.0,
|
||||
@@ -262,7 +262,7 @@ async def run_locomo_eval(
|
||||
) -> Dict[str, Any]:
|
||||
|
||||
# 函数内部使用三路检索逻辑,但保持参数签名不变
|
||||
group_id = group_id or SELECTED_GROUP_ID
|
||||
end_user_id = end_user_id or SELECTED_end_user_id
|
||||
data_path = os.path.join(PROJECT_ROOT, "data", "locomo10.json")
|
||||
if not os.path.exists(data_path):
|
||||
data_path = os.path.join(os.getcwd(), "data", "locomo10.json")
|
||||
@@ -340,7 +340,7 @@ async def run_locomo_eval(
|
||||
|
||||
# 关键修复:强制重新摄入纯净的对话数据
|
||||
print("🔄 强制重新摄入纯净的对话数据...")
|
||||
await ingest_contexts_via_full_pipeline(contents, group_id, save_chunk_output=True)
|
||||
await ingest_contexts_via_full_pipeline(contents, end_user_id, save_chunk_output=True)
|
||||
|
||||
# 使用异步LLM客户端
|
||||
with get_db_context() as db:
|
||||
@@ -405,7 +405,7 @@ async def run_locomo_eval(
|
||||
connector=connector,
|
||||
embedder_client=embedder,
|
||||
query_text=q,
|
||||
group_id=group_id,
|
||||
end_user_id=end_user_id,
|
||||
limit=adjusted_limit,
|
||||
include=["chunks", "statements", "entities", "summaries"], # 修复:使用正确的类型
|
||||
)
|
||||
@@ -456,7 +456,7 @@ async def run_locomo_eval(
|
||||
search_results = await search_graph(
|
||||
connector=connector,
|
||||
q=q,
|
||||
group_id=group_id,
|
||||
end_user_id=end_user_id,
|
||||
limit=adjusted_limit
|
||||
)
|
||||
dialogs = search_results.get("dialogues", [])
|
||||
@@ -486,7 +486,7 @@ async def run_locomo_eval(
|
||||
search_results = await run_hybrid_search(
|
||||
query_text=q,
|
||||
search_type=search_type,
|
||||
group_id=group_id,
|
||||
end_user_id=end_user_id,
|
||||
limit=adjusted_limit,
|
||||
include=["chunks", "statements", "entities", "summaries"],
|
||||
output_path=None,
|
||||
@@ -524,7 +524,7 @@ async def run_locomo_eval(
|
||||
connector=connector,
|
||||
embedder_client=embedder,
|
||||
query_text=q,
|
||||
group_id=group_id,
|
||||
end_user_id=end_user_id,
|
||||
limit=adjusted_limit,
|
||||
include=["chunks", "statements", "entities", "summaries"],
|
||||
)
|
||||
@@ -597,7 +597,7 @@ async def run_locomo_eval(
|
||||
"dialogues": [
|
||||
{
|
||||
"uuid": d.get("uuid", ""),
|
||||
"group_id": d.get("group_id", ""),
|
||||
"end_user_id": d.get("end_user_id", ""),
|
||||
"content": d.get("content", "")[:200] + "..." if len(d.get("content", "")) > 200 else d.get("content", ""),
|
||||
"score": d.get("score", 0.0)
|
||||
}
|
||||
@@ -795,7 +795,7 @@ async def run_locomo_eval(
|
||||
},
|
||||
"samples": samples,
|
||||
"params": {
|
||||
"group_id": group_id,
|
||||
"end_user_id": end_user_id,
|
||||
"search_limit": search_limit,
|
||||
"context_char_budget": context_char_budget,
|
||||
"search_type": search_type,
|
||||
@@ -825,7 +825,7 @@ async def run_locomo_eval(
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Run LoCoMo evaluation with Qwen search")
|
||||
parser.add_argument("--sample_size", type=int, default=1, help="Number of samples to evaluate")
|
||||
parser.add_argument("--group_id", type=str, default=None, help="Group ID for retrieval")
|
||||
parser.add_argument("--end_user_id", type=str, default=None, help="Group ID for retrieval")
|
||||
parser.add_argument("--search_limit", type=int, default=8, help="Search limit per query")
|
||||
parser.add_argument("--context_char_budget", type=int, default=12000, help="Max characters for context")
|
||||
parser.add_argument("--llm_temperature", type=float, default=0.0, help="LLM temperature")
|
||||
@@ -841,7 +841,7 @@ def main():
|
||||
|
||||
result = asyncio.run(run_locomo_eval(
|
||||
sample_size=args.sample_size,
|
||||
group_id=args.group_id,
|
||||
end_user_id=args.end_user_id,
|
||||
search_limit=args.search_limit,
|
||||
context_char_budget=args.context_char_budget,
|
||||
llm_temperature=args.llm_temperature,
|
||||
|
||||
@@ -524,11 +524,11 @@ def generate_query_keywords_cn(question: str) -> List[str]:
|
||||
|
||||
|
||||
# 通过别名匹配进行实体关键词检索(多token合并)
|
||||
async def _search_entities_by_aliases(connector: Neo4jConnector, tokens: List[str], group_id: str | None, limit: int) -> List[Dict[str, Any]]:
|
||||
async def _search_entities_by_aliases(connector: Neo4jConnector, tokens: List[str], end_user_id: str | None, limit: int) -> List[Dict[str, Any]]:
|
||||
results: List[Dict[str, Any]] = []
|
||||
try:
|
||||
for tok in tokens:
|
||||
rows = await connector.execute_query(SEARCH_ENTITIES_BY_NAME, q=tok, group_id=group_id, limit=limit)
|
||||
rows = await connector.execute_query(SEARCH_ENTITIES_BY_NAME, q=tok, end_user_id=end_user_id, limit=limit)
|
||||
if rows:
|
||||
results.extend(rows)
|
||||
except Exception:
|
||||
@@ -548,15 +548,15 @@ async def _search_entities_by_aliases(connector: Neo4jConnector, tokens: List[st
|
||||
# 通过对话/陈述中的entity_ids反查实体名称
|
||||
_FETCH_ENTITIES_BY_IDS = """
|
||||
MATCH (e:ExtractedEntity)
|
||||
WHERE e.id IN $ids AND ($group_id IS NULL OR e.group_id = $group_id)
|
||||
RETURN e.id AS id, e.name AS name, e.group_id AS group_id, e.entity_type AS entity_type
|
||||
WHERE e.id IN $ids AND ($end_user_id IS NULL OR e.end_user_id = $end_user_id)
|
||||
RETURN e.id AS id, e.name AS name, e.end_user_id AS end_user_id, e.entity_type AS entity_type
|
||||
"""
|
||||
|
||||
async def _fetch_entities_by_ids(connector: Neo4jConnector, ids: List[str], group_id: str | None) -> List[Dict[str, Any]]:
|
||||
async def _fetch_entities_by_ids(connector: Neo4jConnector, ids: List[str], end_user_id: str | None) -> List[Dict[str, Any]]:
|
||||
if not ids:
|
||||
return []
|
||||
try:
|
||||
rows = await connector.execute_query(_FETCH_ENTITIES_BY_IDS, ids=list({i for i in ids if i}), group_id=group_id)
|
||||
rows = await connector.execute_query(_FETCH_ENTITIES_BY_IDS, ids=list({i for i in ids if i}), end_user_id=end_user_id)
|
||||
return rows or []
|
||||
except Exception:
|
||||
return []
|
||||
@@ -566,18 +566,18 @@ async def _fetch_entities_by_ids(connector: Neo4jConnector, ids: List[str], grou
|
||||
_TIME_ENTITY_SEARCH = """
|
||||
MATCH (e:ExtractedEntity)
|
||||
WHERE e.entity_type CONTAINS "TIME" OR e.entity_type CONTAINS "DATE" OR e.name =~ $date_pattern
|
||||
AND ($group_id IS NULL OR e.group_id = $group_id)
|
||||
RETURN e.id AS id, e.name AS name, e.group_id AS group_id, e.entity_type AS entity_type
|
||||
AND ($end_user_id IS NULL OR e.end_user_id = $end_user_id)
|
||||
RETURN e.id AS id, e.name AS name, e.end_user_id AS end_user_id, e.entity_type AS entity_type
|
||||
LIMIT $limit
|
||||
"""
|
||||
|
||||
async def _search_time_entities(connector: Neo4jConnector, group_id: str | None, limit: int = 5) -> List[Dict[str, Any]]:
|
||||
async def _search_time_entities(connector: Neo4jConnector, end_user_id: str | None, limit: int = 5) -> List[Dict[str, Any]]:
|
||||
"""专门搜索时间相关的实体"""
|
||||
try:
|
||||
date_pattern = r".*\d{4}.*|.*\d{1,2}月\d{1,2}日.*"
|
||||
rows = await connector.execute_query(_TIME_ENTITY_SEARCH,
|
||||
date_pattern=date_pattern,
|
||||
group_id=group_id,
|
||||
end_user_id=end_user_id,
|
||||
limit=limit)
|
||||
return rows or []
|
||||
except Exception:
|
||||
@@ -624,7 +624,7 @@ def _resolve_relative_times_cn_en(text: str, anchor: datetime) -> str:
|
||||
|
||||
async def run_longmemeval_test(
|
||||
sample_size: int = 3,
|
||||
group_id: str = "longmemeval_zh_bak_3",
|
||||
end_user_id: str = "longmemeval_zh_bak_3",
|
||||
search_limit: int = 8,
|
||||
context_char_budget: int = 4000,
|
||||
llm_temperature: float = 0.0,
|
||||
@@ -678,13 +678,13 @@ async def run_longmemeval_test(
|
||||
contexts.extend(selected)
|
||||
|
||||
print(f"📥 摄入 {len(contexts)} 个上下文到数据库")
|
||||
if reset_group_before_ingest and group_id:
|
||||
if reset_group_before_ingest and end_user_id:
|
||||
try:
|
||||
_tmp_conn = Neo4jConnector()
|
||||
await _tmp_conn.delete_group(group_id)
|
||||
print(f"🧹 已清空组 {group_id} 的历史图数据")
|
||||
await _tmp_conn.delete_group(end_user_id)
|
||||
print(f"🧹 已清空组 {end_user_id} 的历史图数据")
|
||||
except Exception as _e:
|
||||
print(f"⚠️ 清空组数据失败(忽略继续): {group_id} - {_e}")
|
||||
print(f"⚠️ 清空组数据失败(忽略继续): {end_user_id} - {_e}")
|
||||
finally:
|
||||
try:
|
||||
await _tmp_conn.close()
|
||||
@@ -696,7 +696,7 @@ async def run_longmemeval_test(
|
||||
else:
|
||||
await _ingest_fn(
|
||||
contexts,
|
||||
group_id,
|
||||
end_user_id,
|
||||
save_chunk_output=save_chunk_output,
|
||||
save_chunk_output_path=save_chunk_output_path,
|
||||
)
|
||||
@@ -751,7 +751,7 @@ async def run_longmemeval_test(
|
||||
connector=connector,
|
||||
embedder_client=embedder,
|
||||
query_text=question,
|
||||
group_id=group_id,
|
||||
end_user_id=end_user_id,
|
||||
limit=search_limit,
|
||||
include=["chunks", "statements", "entities", "summaries"],
|
||||
)
|
||||
@@ -796,7 +796,7 @@ async def run_longmemeval_test(
|
||||
search_results = await search_graph(
|
||||
connector=connector,
|
||||
q=question,
|
||||
group_id=group_id,
|
||||
end_user_id=end_user_id,
|
||||
limit=search_limit,
|
||||
)
|
||||
chunks = search_results.get("chunks", [])
|
||||
@@ -831,7 +831,7 @@ async def run_longmemeval_test(
|
||||
connector=connector,
|
||||
embedder_client=embedder,
|
||||
query_text=question,
|
||||
group_id=group_id,
|
||||
end_user_id=end_user_id,
|
||||
limit=search_limit,
|
||||
include=["chunks", "statements", "entities", "summaries"],
|
||||
)
|
||||
@@ -849,7 +849,7 @@ async def run_longmemeval_test(
|
||||
kw_res = await search_graph(
|
||||
connector=connector,
|
||||
q=question,
|
||||
group_id=group_id,
|
||||
end_user_id=end_user_id,
|
||||
limit=search_limit,
|
||||
)
|
||||
if isinstance(kw_res, dict):
|
||||
@@ -860,7 +860,7 @@ async def run_longmemeval_test(
|
||||
# 时间推理问题的特殊处理
|
||||
if is_temporal:
|
||||
# 专门搜索时间实体
|
||||
time_entities = await _search_time_entities(connector, group_id, search_limit//2)
|
||||
time_entities = await _search_time_entities(connector, end_user_id, search_limit//2)
|
||||
if time_entities:
|
||||
kw_entities.extend(time_entities)
|
||||
# 添加时间相关关键词检索
|
||||
@@ -870,7 +870,7 @@ async def run_longmemeval_test(
|
||||
time_res = await search_graph(
|
||||
connector=connector,
|
||||
q=tk,
|
||||
group_id=group_id,
|
||||
end_user_id=end_user_id,
|
||||
limit=2,
|
||||
)
|
||||
if isinstance(time_res, dict):
|
||||
@@ -881,7 +881,7 @@ async def run_longmemeval_test(
|
||||
|
||||
# 中文关键词拆分后做别名匹配
|
||||
cn_tokens = _extract_cn_tokens(question)
|
||||
alias_entities = await _search_entities_by_aliases(connector, cn_tokens, group_id, search_limit)
|
||||
alias_entities = await _search_entities_by_aliases(connector, cn_tokens, end_user_id, search_limit)
|
||||
if alias_entities:
|
||||
kw_entities.extend(alias_entities)
|
||||
|
||||
@@ -895,7 +895,7 @@ async def run_longmemeval_test(
|
||||
except Exception:
|
||||
pass
|
||||
if ids:
|
||||
id_entities = await _fetch_entities_by_ids(connector, ids, group_id)
|
||||
id_entities = await _fetch_entities_by_ids(connector, ids, end_user_id)
|
||||
if id_entities:
|
||||
kw_entities.extend(id_entities)
|
||||
|
||||
@@ -909,7 +909,7 @@ async def run_longmemeval_test(
|
||||
sub_res = await search_graph(
|
||||
connector=connector,
|
||||
q=str(kw),
|
||||
group_id=group_id,
|
||||
end_user_id=end_user_id,
|
||||
limit=max(3, search_limit // 2),
|
||||
)
|
||||
if isinstance(sub_res, dict):
|
||||
@@ -928,7 +928,7 @@ async def run_longmemeval_test(
|
||||
opt_res = await search_graph(
|
||||
connector=connector,
|
||||
q=str(opt),
|
||||
group_id=group_id,
|
||||
end_user_id=end_user_id,
|
||||
limit=max(3, search_limit // 2),
|
||||
)
|
||||
if isinstance(opt_res, dict):
|
||||
@@ -1010,7 +1010,7 @@ async def run_longmemeval_test(
|
||||
kw_fallback = await search_graph(
|
||||
connector=connector,
|
||||
q=question,
|
||||
group_id=group_id,
|
||||
end_user_id=end_user_id,
|
||||
limit=max(search_limit, 5),
|
||||
)
|
||||
fb_dialogs = kw_fallback.get("dialogues", []) or []
|
||||
@@ -1224,7 +1224,7 @@ async def run_longmemeval_test(
|
||||
"count_avg": statistics.mean(per_query_context_counts) if per_query_context_counts else 0.0,
|
||||
},
|
||||
"params": {
|
||||
"group_id": group_id,
|
||||
"end_user_id": end_user_id,
|
||||
"search_limit": search_limit,
|
||||
"context_char_budget": context_char_budget,
|
||||
"search_type": search_type,
|
||||
@@ -1307,7 +1307,7 @@ def main():
|
||||
result = asyncio.run(
|
||||
run_longmemeval_test(
|
||||
sample_size=sample_size,
|
||||
group_id=args.group_id,
|
||||
end_user_id=args.end_user_id,
|
||||
search_limit=args.search_limit,
|
||||
context_char_budget=args.context_char_budget,
|
||||
llm_temperature=args.llm_temperature,
|
||||
|
||||
@@ -498,11 +498,11 @@ def smart_context_selection(contexts: List[str], question: str, max_chars: int =
|
||||
|
||||
|
||||
# 通过别名匹配进行实体关键词检索(多token合并)
|
||||
async def _search_entities_by_aliases(connector: Neo4jConnector, tokens: List[str], group_id: str | None, limit: int) -> List[Dict[str, Any]]:
|
||||
async def _search_entities_by_aliases(connector: Neo4jConnector, tokens: List[str], end_user_id: str | None, limit: int) -> List[Dict[str, Any]]:
|
||||
results: List[Dict[str, Any]] = []
|
||||
try:
|
||||
for tok in tokens:
|
||||
rows = await connector.execute_query(SEARCH_ENTITIES_BY_NAME, q=tok, group_id=group_id, limit=limit)
|
||||
rows = await connector.execute_query(SEARCH_ENTITIES_BY_NAME, q=tok, end_user_id=end_user_id, limit=limit)
|
||||
if rows:
|
||||
results.extend(rows)
|
||||
except Exception:
|
||||
@@ -522,15 +522,15 @@ async def _search_entities_by_aliases(connector: Neo4jConnector, tokens: List[st
|
||||
# 通过对话/陈述中的entity_ids反查实体名称
|
||||
_FETCH_ENTITIES_BY_IDS = """
|
||||
MATCH (e:ExtractedEntity)
|
||||
WHERE e.id IN $ids AND ($group_id IS NULL OR e.group_id = $group_id)
|
||||
RETURN e.id AS id, e.name AS name, e.group_id AS group_id, e.entity_type AS entity_type
|
||||
WHERE e.id IN $ids AND ($end_user_id IS NULL OR e.end_user_id = $end_user_id)
|
||||
RETURN e.id AS id, e.name AS name, e.end_user_id AS end_user_id, e.entity_type AS entity_type
|
||||
"""
|
||||
|
||||
async def _fetch_entities_by_ids(connector: Neo4jConnector, ids: List[str], group_id: str | None) -> List[Dict[str, Any]]:
|
||||
async def _fetch_entities_by_ids(connector: Neo4jConnector, ids: List[str], end_user_id: str | None) -> List[Dict[str, Any]]:
|
||||
if not ids:
|
||||
return []
|
||||
try:
|
||||
rows = await connector.execute_query(_FETCH_ENTITIES_BY_IDS, ids=list({i for i in ids if i}), group_id=group_id)
|
||||
rows = await connector.execute_query(_FETCH_ENTITIES_BY_IDS, ids=list({i for i in ids if i}), end_user_id=end_user_id)
|
||||
return rows or []
|
||||
except Exception:
|
||||
return []
|
||||
@@ -540,18 +540,18 @@ async def _fetch_entities_by_ids(connector: Neo4jConnector, ids: List[str], grou
|
||||
_TIME_ENTITY_SEARCH = """
|
||||
MATCH (e:ExtractedEntity)
|
||||
WHERE e.entity_type CONTAINS "TIME" OR e.entity_type CONTAINS "DATE" OR e.name =~ $date_pattern
|
||||
AND ($group_id IS NULL OR e.group_id = $group_id)
|
||||
RETURN e.id AS id, e.name AS name, e.group_id AS group_id, e.entity_type AS entity_type
|
||||
AND ($end_user_id IS NULL OR e.end_user_id = $end_user_id)
|
||||
RETURN e.id AS id, e.name AS name, e.end_user_id AS end_user_id, e.entity_type AS entity_type
|
||||
LIMIT $limit
|
||||
"""
|
||||
|
||||
async def _search_time_entities(connector: Neo4jConnector, group_id: str | None, limit: int = 5) -> List[Dict[str, Any]]:
|
||||
async def _search_time_entities(connector: Neo4jConnector, end_user_id: str | None, limit: int = 5) -> List[Dict[str, Any]]:
|
||||
"""专门搜索时间相关的实体"""
|
||||
try:
|
||||
date_pattern = r".*\d{4}.*|.*\d{1,2}月\d{1,2}日.*"
|
||||
rows = await connector.execute_query(_TIME_ENTITY_SEARCH,
|
||||
date_pattern=date_pattern,
|
||||
group_id=group_id,
|
||||
end_user_id=end_user_id,
|
||||
limit=limit)
|
||||
return rows or []
|
||||
except Exception:
|
||||
@@ -559,25 +559,25 @@ async def _search_time_entities(connector: Neo4jConnector, group_id: str | None,
|
||||
|
||||
|
||||
# 技术术语专门检索
|
||||
async def _search_tech_terms(connector: Neo4jConnector, question: str, group_id: str | None, limit: int = 3) -> List[Dict[str, Any]]:
|
||||
async def _search_tech_terms(connector: Neo4jConnector, question: str, end_user_id: str | None, limit: int = 3) -> List[Dict[str, Any]]:
|
||||
"""专门搜索技术术语相关的实体"""
|
||||
tech_entities = []
|
||||
try:
|
||||
# GPS相关
|
||||
if any(term in question for term in ["GPS", "导航", "定位系统"]):
|
||||
gps_rows = await connector.execute_query(SEARCH_ENTITIES_BY_NAME, q="GPS", group_id=group_id, limit=limit)
|
||||
gps_rows = await connector.execute_query(SEARCH_ENTITIES_BY_NAME, q="GPS", end_user_id=end_user_id, limit=limit)
|
||||
if gps_rows:
|
||||
tech_entities.extend(gps_rows)
|
||||
|
||||
# 活动相关
|
||||
if any(term in question for term in ["工作坊", "研讨会", "网络研讨会"]):
|
||||
workshop_rows = await connector.execute_query(SEARCH_ENTITIES_BY_NAME, q="工作坊", group_id=group_id, limit=limit)
|
||||
workshop_rows = await connector.execute_query(SEARCH_ENTITIES_BY_NAME, q="工作坊", end_user_id=end_user_id, limit=limit)
|
||||
if workshop_rows:
|
||||
tech_entities.extend(workshop_rows)
|
||||
|
||||
# 时间顺序相关
|
||||
if any(term in question for term in ["先", "后", "第一个"]):
|
||||
time_rows = await connector.execute_query(SEARCH_ENTITIES_BY_NAME, q="第一次", group_id=group_id, limit=limit)
|
||||
time_rows = await connector.execute_query(SEARCH_ENTITIES_BY_NAME, q="第一次", end_user_id=end_user_id, limit=limit)
|
||||
if time_rows:
|
||||
tech_entities.extend(time_rows)
|
||||
|
||||
@@ -627,7 +627,7 @@ def _resolve_relative_times_cn_en(text: str, anchor: datetime) -> str:
|
||||
|
||||
async def run_longmemeval_test(
|
||||
sample_size: int = 3,
|
||||
group_id: str = "longmemeval_zh_bak_2",
|
||||
end_user_id: str = "longmemeval_zh_bak_2",
|
||||
search_limit: int = 8,
|
||||
context_char_budget: int = 4000,
|
||||
llm_temperature: float = 0.0,
|
||||
@@ -707,7 +707,7 @@ async def run_longmemeval_test(
|
||||
connector=connector,
|
||||
embedder_client=embedder,
|
||||
query_text=question,
|
||||
group_id=group_id,
|
||||
end_user_id=end_user_id,
|
||||
limit=search_limit,
|
||||
include=["dialogues", "statements", "entities"],
|
||||
)
|
||||
@@ -746,7 +746,7 @@ async def run_longmemeval_test(
|
||||
search_results = await search_graph(
|
||||
connector=connector,
|
||||
q=question,
|
||||
group_id=group_id,
|
||||
end_user_id=end_user_id,
|
||||
limit=search_limit,
|
||||
)
|
||||
dialogs = search_results.get("dialogues", [])
|
||||
@@ -776,7 +776,7 @@ async def run_longmemeval_test(
|
||||
connector=connector,
|
||||
embedder_client=embedder,
|
||||
query_text=question,
|
||||
group_id=group_id,
|
||||
end_user_id=end_user_id,
|
||||
limit=search_limit,
|
||||
include=["dialogues", "statements", "entities"],
|
||||
)
|
||||
@@ -792,7 +792,7 @@ async def run_longmemeval_test(
|
||||
kw_res = await search_graph(
|
||||
connector=connector,
|
||||
q=question,
|
||||
group_id=group_id,
|
||||
end_user_id=end_user_id,
|
||||
limit=search_limit,
|
||||
)
|
||||
if isinstance(kw_res, dict):
|
||||
@@ -801,14 +801,14 @@ async def run_longmemeval_test(
|
||||
kw_entities = kw_res.get("entities", []) or []
|
||||
|
||||
# 技术术语专门检索
|
||||
tech_entities = await _search_tech_terms(connector, question, group_id, search_limit//2)
|
||||
tech_entities = await _search_tech_terms(connector, question, end_user_id, search_limit//2)
|
||||
if tech_entities:
|
||||
kw_entities.extend(tech_entities)
|
||||
|
||||
# 时间推理问题的特殊处理
|
||||
if is_temporal:
|
||||
# 专门搜索时间实体
|
||||
time_entities = await _search_time_entities(connector, group_id, search_limit//2)
|
||||
time_entities = await _search_time_entities(connector, end_user_id, search_limit//2)
|
||||
if time_entities:
|
||||
kw_entities.extend(time_entities)
|
||||
# 添加时间相关关键词检索
|
||||
@@ -818,7 +818,7 @@ async def run_longmemeval_test(
|
||||
time_res = await search_graph(
|
||||
connector=connector,
|
||||
q=tk,
|
||||
group_id=group_id,
|
||||
end_user_id=end_user_id,
|
||||
limit=2,
|
||||
)
|
||||
if isinstance(time_res, dict):
|
||||
@@ -829,7 +829,7 @@ async def run_longmemeval_test(
|
||||
|
||||
# 中文关键词拆分后做别名匹配
|
||||
cn_tokens = generate_query_keywords_cn(question) # 使用增强版关键词提取
|
||||
alias_entities = await _search_entities_by_aliases(connector, cn_tokens, group_id, search_limit)
|
||||
alias_entities = await _search_entities_by_aliases(connector, cn_tokens, end_user_id, search_limit)
|
||||
if alias_entities:
|
||||
kw_entities.extend(alias_entities)
|
||||
|
||||
@@ -843,7 +843,7 @@ async def run_longmemeval_test(
|
||||
except Exception:
|
||||
pass
|
||||
if ids:
|
||||
id_entities = await _fetch_entities_by_ids(connector, ids, group_id)
|
||||
id_entities = await _fetch_entities_by_ids(connector, ids, end_user_id)
|
||||
if id_entities:
|
||||
kw_entities.extend(id_entities)
|
||||
|
||||
@@ -857,7 +857,7 @@ async def run_longmemeval_test(
|
||||
sub_res = await search_graph(
|
||||
connector=connector,
|
||||
q=str(kw),
|
||||
group_id=group_id,
|
||||
end_user_id=end_user_id,
|
||||
limit=max(3, search_limit // 2),
|
||||
)
|
||||
if isinstance(sub_res, dict):
|
||||
@@ -876,7 +876,7 @@ async def run_longmemeval_test(
|
||||
opt_res = await search_graph(
|
||||
connector=connector,
|
||||
q=str(opt),
|
||||
group_id=group_id,
|
||||
end_user_id=end_user_id,
|
||||
limit=max(3, search_limit // 2),
|
||||
)
|
||||
if isinstance(opt_res, dict):
|
||||
@@ -971,7 +971,7 @@ async def run_longmemeval_test(
|
||||
kw_fallback = await search_graph(
|
||||
connector=connector,
|
||||
q=question,
|
||||
group_id=group_id,
|
||||
end_user_id=end_user_id,
|
||||
limit=max(search_limit, 5),
|
||||
)
|
||||
fb_dialogs = kw_fallback.get("dialogues", []) or []
|
||||
@@ -1199,7 +1199,7 @@ async def run_longmemeval_test(
|
||||
"count_avg": statistics.mean(per_query_context_counts) if per_query_context_counts else 0.0,
|
||||
},
|
||||
"params": {
|
||||
"group_id": group_id,
|
||||
"end_user_id": end_user_id,
|
||||
"search_limit": search_limit,
|
||||
"context_char_budget": context_char_budget,
|
||||
"search_type": search_type,
|
||||
@@ -1278,7 +1278,7 @@ def main():
|
||||
result = asyncio.run(
|
||||
run_longmemeval_test(
|
||||
sample_size=sample_size,
|
||||
group_id=args.group_id,
|
||||
end_user_id=args.end_user_id,
|
||||
search_limit=args.search_limit,
|
||||
context_char_budget=args.context_char_budget,
|
||||
llm_temperature=args.llm_temperature,
|
||||
|
||||
@@ -135,8 +135,8 @@ def _combine_dialogues_for_hybrid(results: Dict[str, Any]) -> List[Dict[str, Any
|
||||
return merged
|
||||
|
||||
|
||||
async def run_memsciqa_eval(sample_size: int = 1, group_id: str | None = None, search_limit: int = 8, context_char_budget: int = 4000, llm_temperature: float = 0.0, llm_max_tokens: int = 64, search_type: str = "hybrid", memory_config: "MemoryConfig" = None) -> Dict[str, Any]:
|
||||
group_id = group_id or SELECTED_GROUP_ID
|
||||
async def run_memsciqa_eval(sample_size: int = 1, end_user_id: str | None = None, search_limit: int = 8, context_char_budget: int = 4000, llm_temperature: float = 0.0, llm_max_tokens: int = 64, search_type: str = "hybrid", memory_config: "MemoryConfig" = None) -> Dict[str, Any]:
|
||||
end_user_id = end_user_id or SELECTED_GROUP_ID
|
||||
# Load data
|
||||
data_path = os.path.join(PROJECT_ROOT, "data", "msc_self_instruct.jsonl")
|
||||
if not os.path.exists(data_path):
|
||||
@@ -147,7 +147,7 @@ async def run_memsciqa_eval(sample_size: int = 1, group_id: str | None = None, s
|
||||
# 改为:每条样本仅摄入一个上下文(完整对话转录),避免多上下文摄入
|
||||
# 说明:memsciqa 数据集的每个样本天然只有一个对话,保持按样本一上下文的策略
|
||||
contexts: List[str] = [build_context_from_dialog(item) for item in items]
|
||||
await ingest_contexts_via_full_pipeline(contexts, group_id)
|
||||
await ingest_contexts_via_full_pipeline(contexts, end_user_id)
|
||||
|
||||
# LLM client (使用异步调用)
|
||||
with get_db_context() as db:
|
||||
@@ -173,7 +173,7 @@ async def run_memsciqa_eval(sample_size: int = 1, group_id: str | None = None, s
|
||||
results = await run_hybrid_search(
|
||||
query_text=question,
|
||||
search_type=search_type,
|
||||
group_id=group_id,
|
||||
end_user_id=end_user_id,
|
||||
limit=search_limit,
|
||||
include=["dialogues", "statements", "entities"],
|
||||
output_path=None,
|
||||
@@ -298,7 +298,7 @@ def main():
|
||||
load_dotenv()
|
||||
parser = argparse.ArgumentParser(description="Evaluate DMR (memsciqa) with graph search and Qwen")
|
||||
parser.add_argument("--sample-size", type=int, default=1, help="评测样本数量")
|
||||
parser.add_argument("--group-id", type=str, default=None, help="可选 group_id,默认取 runtime.json")
|
||||
parser.add_argument("--group-id", type=str, default=None, help="可选 end_user_id,默认取 runtime.json")
|
||||
parser.add_argument("--search-limit", type=int, default=8, help="每类检索最大返回数")
|
||||
parser.add_argument("--context-char-budget", type=int, default=4000, help="上下文字符预算")
|
||||
parser.add_argument("--llm-temperature", type=float, default=0.0, help="LLM 温度")
|
||||
@@ -309,7 +309,7 @@ def main():
|
||||
result = asyncio.run(
|
||||
run_memsciqa_eval(
|
||||
sample_size=args.sample_size,
|
||||
group_id=args.group_id,
|
||||
end_user_id=args.end_user_id,
|
||||
search_limit=args.search_limit,
|
||||
context_char_budget=args.context_char_budget,
|
||||
llm_temperature=args.llm_temperature,
|
||||
|
||||
@@ -199,7 +199,7 @@ def load_dataset_memsciqa(data_path: str) -> List[Dict[str, Any]]:
|
||||
|
||||
async def run_memsciqa_test(
|
||||
sample_size: int = 3,
|
||||
group_id: str | None = None,
|
||||
end_user_id: str | None = None,
|
||||
search_limit: int = 8,
|
||||
context_char_budget: int = 4000,
|
||||
llm_temperature: float = 0.0,
|
||||
@@ -217,7 +217,7 @@ async def run_memsciqa_test(
|
||||
"""
|
||||
|
||||
# 默认使用指定的 memsci 组 ID
|
||||
group_id = group_id or "group_memsci"
|
||||
end_user_id = end_user_id or "group_memsci"
|
||||
|
||||
# 数据路径解析(项目根与当前工作目录兜底)
|
||||
if not data_path:
|
||||
@@ -283,7 +283,7 @@ async def run_memsciqa_test(
|
||||
connector=connector,
|
||||
embedder_client=embedder,
|
||||
query_text=question,
|
||||
group_id=group_id,
|
||||
end_user_id=end_user_id,
|
||||
limit=search_limit,
|
||||
include=["chunks", "statements", "entities", "summaries"], # 使用 chunks 而不是 dialogues
|
||||
)
|
||||
@@ -292,7 +292,7 @@ async def run_memsciqa_test(
|
||||
results = await search_graph(
|
||||
connector=connector,
|
||||
q=question,
|
||||
group_id=group_id,
|
||||
end_user_id=end_user_id,
|
||||
limit=search_limit,
|
||||
include=["chunks", "statements", "entities", "summaries"], # 使用 chunks 而不是 dialogues
|
||||
)
|
||||
@@ -500,7 +500,7 @@ async def run_memsciqa_test(
|
||||
},
|
||||
"samples": samples,
|
||||
"params": {
|
||||
"group_id": group_id,
|
||||
"end_user_id": end_user_id,
|
||||
"search_limit": search_limit,
|
||||
"context_char_budget": context_char_budget,
|
||||
"llm_temperature": llm_temperature,
|
||||
@@ -543,7 +543,7 @@ def main():
|
||||
result = asyncio.run(
|
||||
run_memsciqa_test(
|
||||
sample_size=sample_size,
|
||||
group_id=args.group_id,
|
||||
end_user_id=args.end_user_id,
|
||||
search_limit=args.search_limit,
|
||||
context_char_budget=args.context_char_budget,
|
||||
llm_temperature=args.llm_temperature,
|
||||
|
||||
@@ -26,7 +26,7 @@ async def run(
|
||||
dataset: str,
|
||||
sample_size: int,
|
||||
reset_group: bool,
|
||||
group_id: str | None,
|
||||
end_user_id: str | None,
|
||||
judge_model: str | None = None,
|
||||
search_limit: int | None = None,
|
||||
context_char_budget: int | None = None,
|
||||
@@ -37,17 +37,17 @@ async def run(
|
||||
max_contexts_per_item: int | None = None,
|
||||
) -> Dict[str, Any]:
|
||||
# 恢复原始风格:统一入口做路由,并沿用各数据集既有默认
|
||||
group_id = group_id or SELECTED_GROUP_ID
|
||||
end_user_id = end_user_id or SELECTED_GROUP_ID
|
||||
|
||||
if reset_group:
|
||||
connector = Neo4jConnector()
|
||||
try:
|
||||
await connector.delete_group(group_id)
|
||||
await connector.delete_group(end_user_id)
|
||||
finally:
|
||||
await connector.close()
|
||||
|
||||
if dataset == "locomo":
|
||||
kwargs: Dict[str, Any] = {"sample_size": sample_size, "group_id": group_id}
|
||||
kwargs: Dict[str, Any] = {"sample_size": sample_size, "end_user_id": end_user_id}
|
||||
if search_limit is not None:
|
||||
kwargs["search_limit"] = search_limit
|
||||
if context_char_budget is not None:
|
||||
@@ -61,7 +61,7 @@ async def run(
|
||||
return await run_locomo_eval(**kwargs)
|
||||
|
||||
if dataset == "memsciqa":
|
||||
kwargs: Dict[str, Any] = {"sample_size": sample_size, "group_id": group_id}
|
||||
kwargs: Dict[str, Any] = {"sample_size": sample_size, "end_user_id": end_user_id}
|
||||
if search_limit is not None:
|
||||
kwargs["search_limit"] = search_limit
|
||||
if context_char_budget is not None:
|
||||
@@ -75,7 +75,7 @@ async def run(
|
||||
return await run_memsciqa_eval(**kwargs)
|
||||
|
||||
if dataset == "longmemeval":
|
||||
kwargs: Dict[str, Any] = {"sample_size": sample_size, "group_id": group_id}
|
||||
kwargs: Dict[str, Any] = {"sample_size": sample_size, "end_user_id": end_user_id}
|
||||
if search_limit is not None:
|
||||
kwargs["search_limit"] = search_limit
|
||||
if context_char_budget is not None:
|
||||
@@ -99,8 +99,8 @@ def main():
|
||||
parser = argparse.ArgumentParser(description="统一评估入口:memsciqa / longmemeval / locomo")
|
||||
parser.add_argument("--dataset", choices=["memsciqa", "longmemeval", "locomo"], required=True)
|
||||
parser.add_argument("--sample-size", type=int, default=1, help="先用一条数据跑通")
|
||||
parser.add_argument("--reset-group", action="store_true", help="运行前清空当前 group_id 的图数据")
|
||||
parser.add_argument("--group-id", type=str, default=None, help="可选 group_id,默认取 runtime.json")
|
||||
parser.add_argument("--reset-group", action="store_true", help="运行前清空当前 end_user_id 的图数据")
|
||||
parser.add_argument("--group-id", type=str, default=None, help="可选 end_user_id,默认取 runtime.json")
|
||||
parser.add_argument("--judge-model", type=str, default=None, help="可选:longmemeval 判别式评测模型名")
|
||||
parser.add_argument("--search-limit", type=int, default=None, help="检索返回的对话节点数量上限(不提供则使用各脚本默认)")
|
||||
parser.add_argument("--context-char-budget", type=int, default=None, help="上下文字符预算(不提供则使用各脚本默认)")
|
||||
@@ -117,7 +117,7 @@ def main():
|
||||
args.dataset,
|
||||
args.sample_size,
|
||||
args.reset_group,
|
||||
args.group_id,
|
||||
args.end_user_id,
|
||||
args.judge_model,
|
||||
args.search_limit,
|
||||
args.context_char_budget,
|
||||
|
||||
@@ -187,11 +187,11 @@ class ChunkerClient:
|
||||
async def generate_chunks(self, dialogue: DialogData):
|
||||
"""
|
||||
Generate chunks following 1 Message = 1 Chunk strategy.
|
||||
|
||||
|
||||
Each message creates one chunk, directly inheriting role information.
|
||||
If a message is too long, it will be split into multiple sub-chunks,
|
||||
each maintaining the same speaker.
|
||||
|
||||
|
||||
Raises:
|
||||
ValueError: If dialogue has no messages or chunking fails
|
||||
"""
|
||||
@@ -201,9 +201,9 @@ class ChunkerClient:
|
||||
f"Dialogue {dialogue.ref_id} has no messages. "
|
||||
f"Cannot generate chunks from empty dialogue."
|
||||
)
|
||||
|
||||
|
||||
dialogue.chunks = []
|
||||
|
||||
|
||||
# 按消息分块:每个消息创建一个或多个 chunk,直接继承角色
|
||||
for msg_idx, msg in enumerate(dialogue.context.msgs):
|
||||
# Validate message has required attributes
|
||||
@@ -212,13 +212,13 @@ class ChunkerClient:
|
||||
f"Message {msg_idx} in dialogue {dialogue.ref_id} "
|
||||
f"missing 'role' or 'msg' attribute"
|
||||
)
|
||||
|
||||
|
||||
msg_content = msg.msg.strip()
|
||||
|
||||
|
||||
# Skip empty messages
|
||||
if not msg_content:
|
||||
continue
|
||||
|
||||
|
||||
# 如果消息太长,可以进一步分块
|
||||
if len(msg_content) > self.chunk_size:
|
||||
# 对单个消息的内容进行分块
|
||||
@@ -228,14 +228,14 @@ class ChunkerClient:
|
||||
raise ValueError(
|
||||
f"Failed to chunk long message {msg_idx} in dialogue {dialogue.ref_id}: {e}"
|
||||
)
|
||||
|
||||
|
||||
for idx, sub_chunk in enumerate(sub_chunks):
|
||||
sub_chunk_text = sub_chunk.text if hasattr(sub_chunk, 'text') else str(sub_chunk)
|
||||
sub_chunk_text = sub_chunk_text.strip()
|
||||
|
||||
|
||||
if len(sub_chunk_text) < (self.min_characters_per_chunk or 50):
|
||||
continue
|
||||
|
||||
|
||||
chunk = Chunk(
|
||||
content=f"{msg.role}: {sub_chunk_text}",
|
||||
speaker=msg.role, # 直接继承角色
|
||||
@@ -260,7 +260,7 @@ class ChunkerClient:
|
||||
},
|
||||
)
|
||||
dialogue.chunks.append(chunk)
|
||||
|
||||
|
||||
# Validate we generated at least one chunk
|
||||
if not dialogue.chunks:
|
||||
raise ValueError(
|
||||
@@ -268,7 +268,7 @@ class ChunkerClient:
|
||||
f"All messages were either empty or too short. "
|
||||
f"Messages count: {len(dialogue.context.msgs)}"
|
||||
)
|
||||
|
||||
|
||||
return dialogue
|
||||
|
||||
def evaluate_chunking(self, dialogue: DialogData) -> dict:
|
||||
|
||||
@@ -72,7 +72,7 @@ class TemporalSearchParams(BaseModel):
|
||||
"""Parameters for temporal search queries in the knowledge graph.
|
||||
|
||||
Attributes:
|
||||
group_id: Group ID to filter search results (default: 'test')
|
||||
end_user_id: Group ID to filter search results (default: 'test')
|
||||
apply_id: Application ID to filter search results
|
||||
user_id: User ID to filter search results
|
||||
start_date: Start date for temporal filtering (format: 'YYYY-MM-DD')
|
||||
@@ -81,7 +81,7 @@ class TemporalSearchParams(BaseModel):
|
||||
invalid_date: Date when memory should be invalid (format: 'YYYY-MM-DD')
|
||||
limit: Maximum number of results to return (default: 3)
|
||||
"""
|
||||
group_id: Optional[str] = Field("test", description="The group ID to filter the search.")
|
||||
end_user_id: Optional[str] = Field("test", description="The group ID to filter the search.")
|
||||
apply_id: Optional[str] = Field(None, description="The apply ID to filter the search.")
|
||||
user_id: Optional[str] = Field(None, description="The user ID to filter the search.")
|
||||
start_date: Optional[str] = Field(None, description="The start date for the search.")
|
||||
|
||||
@@ -103,9 +103,7 @@ class Edge(BaseModel):
|
||||
id: Unique identifier for the edge
|
||||
source: ID of the source node
|
||||
target: ID of the target node
|
||||
group_id: Group ID for multi-tenancy
|
||||
user_id: User ID for user-specific data
|
||||
apply_id: Application ID for application-specific data
|
||||
end_user_id: End user ID for multi-tenancy
|
||||
run_id: Unique identifier for the pipeline run that created this edge
|
||||
created_at: Timestamp when the edge was created (system perspective)
|
||||
expired_at: Optional timestamp when the edge expires (system perspective)
|
||||
@@ -113,9 +111,7 @@ class Edge(BaseModel):
|
||||
id: str = Field(default_factory=lambda: uuid4().hex, description="A unique identifier for the edge.")
|
||||
source: str = Field(..., description="The ID of the source node.")
|
||||
target: str = Field(..., description="The ID of the target node.")
|
||||
group_id: str = Field(..., description="The group ID of the edge.")
|
||||
user_id: str = Field(..., description="The user ID of the edge.")
|
||||
apply_id: str = Field(..., description="The apply ID of the edge.")
|
||||
end_user_id: str = Field(..., description="The end user ID of the edge.")
|
||||
run_id: str = Field(default_factory=lambda: uuid4().hex, description="Unique identifier for this pipeline run.")
|
||||
created_at: datetime = Field(..., description="The valid time of the edge from system perspective.")
|
||||
expired_at: Optional[datetime] = Field(None, description="The expired time of the edge from system perspective.")
|
||||
@@ -185,18 +181,14 @@ class Node(BaseModel):
|
||||
Attributes:
|
||||
id: Unique identifier for the node
|
||||
name: Name of the node
|
||||
group_id: Group ID for multi-tenancy
|
||||
user_id: User ID for user-specific data
|
||||
apply_id: Application ID for application-specific data
|
||||
end_user_id: End user ID for multi-tenancy
|
||||
run_id: Unique identifier for the pipeline run that created this node
|
||||
created_at: Timestamp when the node was created (system perspective)
|
||||
expired_at: Optional timestamp when the node expires (system perspective)
|
||||
"""
|
||||
id: str = Field(..., description="The unique identifier for the node.")
|
||||
name: str = Field(..., description="The name of the node.")
|
||||
group_id: str = Field(..., description="The group ID of the node.")
|
||||
user_id: str = Field(..., description="The user ID of the edge.")
|
||||
apply_id: str = Field(..., description="The apply ID of the edge.")
|
||||
end_user_id: str = Field(..., description="The end user ID of the node.")
|
||||
run_id: str = Field(default_factory=lambda: uuid4().hex, description="Unique identifier for this pipeline run.")
|
||||
created_at: datetime = Field(..., description="The valid time of the node from system perspective.")
|
||||
expired_at: Optional[datetime] = Field(None, description="The expired time of the node from system perspective.")
|
||||
|
||||
@@ -55,7 +55,7 @@ class Statement(BaseModel):
|
||||
Attributes:
|
||||
id: Unique identifier for the statement
|
||||
chunk_id: ID of the parent chunk this statement belongs to
|
||||
group_id: Optional group ID for multi-tenancy
|
||||
end_user_id: Optional group ID for multi-tenancy
|
||||
statement: The actual statement text content
|
||||
speaker: Optional speaker identifier ('用户' for user, 'AI' for AI responses)
|
||||
statement_embedding: Optional embedding vector for the statement
|
||||
@@ -73,7 +73,7 @@ class Statement(BaseModel):
|
||||
"""
|
||||
id: str = Field(default_factory=lambda: uuid4().hex, description="A unique identifier for the statement.")
|
||||
chunk_id: str = Field(..., description="ID of the parent chunk this statement belongs to.")
|
||||
group_id: Optional[str] = Field(None, description="ID of the group this statement belongs to.")
|
||||
end_user_id: Optional[str] = Field(None, description="ID of the group this statement belongs to.")
|
||||
statement: str = Field(..., description="The text content of the statement.")
|
||||
speaker: Optional[str] = Field(None, description="Speaker identifier: 'user' for user messages, 'assistant' for AI responses")
|
||||
statement_embedding: Optional[List[float]] = Field(None, description="The embedding vector of the statement.")
|
||||
@@ -159,9 +159,7 @@ class DialogData(BaseModel):
|
||||
context: Full conversation context
|
||||
dialog_embedding: Optional embedding vector for the entire dialog
|
||||
ref_id: Reference ID linking to external dialog system
|
||||
group_id: Group ID for multi-tenancy
|
||||
user_id: User ID for user-specific data
|
||||
apply_id: Application ID for application-specific data
|
||||
end_user_id: End user ID for multi-tenancy
|
||||
created_at: Timestamp when the dialog was created
|
||||
expired_at: Timestamp when the dialog expires (default: far future)
|
||||
metadata: Additional metadata as key-value pairs
|
||||
@@ -175,9 +173,7 @@ class DialogData(BaseModel):
|
||||
context: ConversationContext = Field(..., description="The full conversation context as a single string.")
|
||||
dialog_embedding: Optional[List[float]] = Field(None, description="The embedding vector of the dialog.")
|
||||
ref_id: str = Field(..., description="Refer to external dialog id. This is used to link to the original dialog.")
|
||||
group_id: str = Field(default=..., description="Group ID of dialogue data")
|
||||
user_id: str = Field(..., description="USER ID of dialogue data")
|
||||
apply_id: str = Field(..., description="APPLY ID of dialogue data")
|
||||
end_user_id: str = Field(default=..., description="End user ID of dialogue data")
|
||||
run_id: str = Field(default_factory=lambda: uuid4().hex, description="Unique identifier for this pipeline run.")
|
||||
created_at: datetime = Field(default_factory=datetime.now, description="The timestamp when the dialog was created.")
|
||||
expired_at: datetime = Field(default_factory=lambda: datetime(9999, 12, 31), description="The timestamp when the dialog expires.")
|
||||
@@ -250,11 +246,11 @@ class DialogData(BaseModel):
|
||||
return []
|
||||
|
||||
def assign_group_id_to_statements(self) -> None:
|
||||
"""Assign this dialog's group_id to all statements in all chunks.
|
||||
"""Assign this dialog's end_user_id to all statements in all chunks.
|
||||
|
||||
This method updates statements that don't have a group_id set.
|
||||
This method updates statements that don't have a end_user_id set.
|
||||
"""
|
||||
for chunk in self.chunks:
|
||||
for statement in chunk.statements:
|
||||
if statement.group_id is None:
|
||||
statement.group_id = self.group_id
|
||||
if statement.end_user_id is None:
|
||||
statement.end_user_id = self.end_user_id
|
||||
|
||||
@@ -6,6 +6,7 @@ import os
|
||||
import time
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional
|
||||
from uuid import UUID
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.schemas.memory_config_schema import MemoryConfig
|
||||
@@ -396,13 +397,13 @@ def rerank_with_activation(
|
||||
return reranked
|
||||
|
||||
|
||||
def log_search_query(query_text: str, search_type: str, group_id: str | None, limit: int, include: List[str], log_file: str = None):
|
||||
def log_search_query(query_text: str, search_type: str, end_user_id: str | None, limit: int, include: List[str], log_file: str = None):
|
||||
"""Log search query information using the logger.
|
||||
|
||||
Args:
|
||||
query_text: The search query text
|
||||
search_type: Type of search (keyword, embedding, hybrid)
|
||||
group_id: Group identifier for filtering
|
||||
end_user_id: Group identifier for filtering
|
||||
limit: Maximum number of results
|
||||
include: List of result types to include
|
||||
log_file: Deprecated parameter, kept for backward compatibility
|
||||
@@ -413,7 +414,7 @@ def log_search_query(query_text: str, search_type: str, group_id: str | None, li
|
||||
# Log using the standard logger
|
||||
logger.info(
|
||||
f"Search query: query='{cleaned_query}', type={search_type}, "
|
||||
f"group_id={group_id}, limit={limit}, include={include}"
|
||||
f"end_user_id={end_user_id}, limit={limit}, include={include}"
|
||||
)
|
||||
|
||||
|
||||
@@ -672,7 +673,7 @@ def apply_reranker_placeholder(
|
||||
async def run_hybrid_search(
|
||||
query_text: str,
|
||||
search_type: str,
|
||||
group_id: str | None,
|
||||
end_user_id: str | None,
|
||||
limit: int,
|
||||
include: List[str],
|
||||
output_path: str | None,
|
||||
@@ -715,7 +716,7 @@ async def run_hybrid_search(
|
||||
}
|
||||
|
||||
# Log the search query
|
||||
log_search_query(query_text, search_type, group_id, limit, include)
|
||||
log_search_query(query_text, search_type, end_user_id, limit, include)
|
||||
|
||||
connector = Neo4jConnector()
|
||||
results = {}
|
||||
@@ -732,7 +733,7 @@ async def run_hybrid_search(
|
||||
search_graph(
|
||||
connector=connector,
|
||||
q=query_text,
|
||||
group_id=group_id,
|
||||
end_user_id=end_user_id,
|
||||
limit=limit,
|
||||
include=include
|
||||
)
|
||||
@@ -769,7 +770,7 @@ async def run_hybrid_search(
|
||||
connector=connector,
|
||||
embedder_client=embedder,
|
||||
query_text=query_text,
|
||||
group_id=group_id,
|
||||
end_user_id=end_user_id,
|
||||
limit=limit,
|
||||
include=include,
|
||||
)
|
||||
@@ -916,9 +917,7 @@ async def run_hybrid_search(
|
||||
|
||||
|
||||
async def search_by_temporal(
|
||||
group_id: Optional[str] = "test",
|
||||
apply_id: Optional[str] = None,
|
||||
user_id: Optional[str] = None,
|
||||
end_user_id: Optional[str] = "test",
|
||||
start_date: Optional[str] = None,
|
||||
end_date: Optional[str] = None,
|
||||
valid_date: Optional[str] = None,
|
||||
@@ -929,7 +928,7 @@ async def search_by_temporal(
|
||||
Temporal search across Statements.
|
||||
|
||||
- Matches statements created between start_date and end_date
|
||||
- Optionally filters by group_id
|
||||
- Optionally filters by end_user_id
|
||||
- Returns up to 'limit' statements
|
||||
"""
|
||||
connector = Neo4jConnector()
|
||||
@@ -939,9 +938,7 @@ async def search_by_temporal(
|
||||
end_date = normalize_date_safe(end_date)
|
||||
|
||||
params = TemporalSearchParams.model_validate({
|
||||
"group_id": group_id,
|
||||
"apply_id": apply_id,
|
||||
"user_id": user_id,
|
||||
"end_user_id": end_user_id,
|
||||
"start_date": start_date,
|
||||
"end_date": end_date,
|
||||
"valid_date": valid_date,
|
||||
@@ -950,9 +947,7 @@ async def search_by_temporal(
|
||||
})
|
||||
statements = await search_graph_by_temporal(
|
||||
connector=connector,
|
||||
group_id=params.group_id,
|
||||
apply_id=params.apply_id,
|
||||
user_id=params.user_id,
|
||||
end_user_id=params.end_user_id,
|
||||
start_date=params.start_date,
|
||||
end_date=params.end_date,
|
||||
valid_date=params.valid_date,
|
||||
@@ -964,9 +959,7 @@ async def search_by_temporal(
|
||||
|
||||
async def search_by_keyword_temporal(
|
||||
query_text: str,
|
||||
group_id: Optional[str] = "test",
|
||||
apply_id: Optional[str] = None,
|
||||
user_id: Optional[str] = None,
|
||||
end_user_id: Optional[str] = "test",
|
||||
start_date: Optional[str] = None,
|
||||
end_date: Optional[str] = None,
|
||||
valid_date: Optional[str] = None,
|
||||
@@ -987,9 +980,7 @@ async def search_by_keyword_temporal(
|
||||
invalid_date = normalize_date_safe(invalid_date)
|
||||
|
||||
params = TemporalSearchParams.model_validate({
|
||||
"group_id": group_id,
|
||||
"apply_id": apply_id,
|
||||
"user_id": user_id,
|
||||
"end_user_id": end_user_id,
|
||||
"start_date": start_date,
|
||||
"end_date": end_date,
|
||||
"valid_date": valid_date,
|
||||
@@ -999,9 +990,7 @@ async def search_by_keyword_temporal(
|
||||
statements = await search_graph_by_keyword_temporal(
|
||||
connector=connector,
|
||||
query_text=query_text,
|
||||
group_id=params.group_id,
|
||||
apply_id=params.apply_id,
|
||||
user_id=params.user_id,
|
||||
end_user_id=params.end_user_id,
|
||||
start_date=params.start_date,
|
||||
end_date=params.end_date,
|
||||
valid_date=params.valid_date,
|
||||
@@ -1013,7 +1002,7 @@ async def search_by_keyword_temporal(
|
||||
|
||||
async def search_chunk_by_chunk_id(
|
||||
chunk_id: str,
|
||||
group_id: Optional[str] = "test",
|
||||
end_user_id: Optional[str] = "test",
|
||||
limit: int = 1,
|
||||
):
|
||||
"""
|
||||
@@ -1023,7 +1012,7 @@ async def search_chunk_by_chunk_id(
|
||||
chunks = await search_graph_by_chunk_id(
|
||||
connector=connector,
|
||||
chunk_id=chunk_id,
|
||||
group_id=group_id,
|
||||
end_user_id=end_user_id,
|
||||
limit=limit
|
||||
)
|
||||
return {"chunks": chunks}
|
||||
|
||||
@@ -555,8 +555,8 @@ class DataPreprocessor:
|
||||
dialog_id = item.get('dialog_id', item.get('ref_id', item.get('id', f'dialog_{i}')))
|
||||
|
||||
|
||||
# 获取group_id,如果不存在则生成默认值
|
||||
group_id = item.get('group_id', f'group_default_{i}')
|
||||
# 获取end_user_id,如果不存在则生成默认值
|
||||
end_user_id = item.get('end_user_id', f'group_default_{i}')
|
||||
user_id = item.get('user_id', f'user_default_{i}')
|
||||
apply_id = item.get('apply_id', f'apply_default_{i}')
|
||||
|
||||
@@ -574,7 +574,7 @@ class DataPreprocessor:
|
||||
dialog_data = DialogData(
|
||||
context=context,
|
||||
ref_id=dialog_id,
|
||||
group_id=group_id,
|
||||
end_user_id=end_user_id,
|
||||
user_id=user_id,
|
||||
apply_id=apply_id,
|
||||
metadata=metadata
|
||||
@@ -644,7 +644,7 @@ class DataPreprocessor:
|
||||
|
||||
context = ConversationContext(msgs=messages)
|
||||
dialog_id = item.get('dialog_id', item.get('ref_id', item.get('id', f'dialog_{i}')))
|
||||
group_id = item.get('group_id', f'group_default_{i}')
|
||||
end_user_id = item.get('end_user_id', f'group_default_{i}')
|
||||
user_id = item.get('user_id', f'user_default_{i}')
|
||||
apply_id = item.get('apply_id', f'apply_default_{i}')
|
||||
|
||||
@@ -657,7 +657,7 @@ class DataPreprocessor:
|
||||
dialog_data = DialogData(
|
||||
context=context,
|
||||
ref_id=dialog_id,
|
||||
group_id=group_id,
|
||||
end_user_id=end_user_id,
|
||||
user_id=user_id,
|
||||
apply_id=apply_id,
|
||||
metadata=metadata
|
||||
|
||||
@@ -199,7 +199,7 @@ def accurate_match(
|
||||
entity_nodes: List[ExtractedEntityNode]
|
||||
) -> Tuple[List[ExtractedEntityNode], Dict[str, str], Dict[str, Dict]]:
|
||||
"""
|
||||
精确匹配:按 (group_id, name, entity_type) 合并实体并建立重定向与合并记录。
|
||||
精确匹配:按 (end_user_id, name, entity_type) 合并实体并建立重定向与合并记录。
|
||||
返回: (deduped_entities, id_redirect, exact_merge_map)
|
||||
"""
|
||||
exact_merge_map: Dict[str, Dict] = {}
|
||||
@@ -210,8 +210,8 @@ def accurate_match(
|
||||
for ent in entity_nodes:
|
||||
name_norm = (getattr(ent, "name", "") or "").strip()
|
||||
type_norm = (getattr(ent, "entity_type", "") or "").strip()
|
||||
key = f"{getattr(ent, 'group_id', None)}|{name_norm}|{type_norm}"
|
||||
# 为避免跨业务组误并,明确以 group_id 为范围边界
|
||||
key = f"{getattr(ent, 'end_user_id', None)}|{name_norm}|{type_norm}"
|
||||
# 为避免跨业务组误并,明确以 end_user_id 为范围边界
|
||||
if key not in canonical_map:
|
||||
canonical_map[key] = ent
|
||||
id_redirect[ent.id] = ent.id
|
||||
@@ -223,11 +223,11 @@ def accurate_match(
|
||||
id_redirect[ent.id] = canonical.id
|
||||
# 记录精确匹配的合并项(使用规范化键,避免外层变量误用)
|
||||
try:
|
||||
k = f"{canonical.group_id}|{(canonical.name or '').strip()}|{(canonical.entity_type or '').strip()}"
|
||||
k = f"{canonical.end_user_id}|{(canonical.name or '').strip()}|{(canonical.entity_type or '').strip()}"
|
||||
if k not in exact_merge_map:
|
||||
exact_merge_map[k] = {
|
||||
"canonical_id": canonical.id,
|
||||
"group_id": canonical.group_id,
|
||||
"end_user_id": canonical.end_user_id,
|
||||
"name": canonical.name,
|
||||
"entity_type": canonical.entity_type,
|
||||
"merged_ids": set(),
|
||||
@@ -596,7 +596,7 @@ def fuzzy_match(
|
||||
b = deduped_entities[j]
|
||||
|
||||
# 跳过不同业务组的实体
|
||||
if getattr(a, "group_id", None) != getattr(b, "group_id", None):
|
||||
if getattr(a, "end_user_id", None) != getattr(b, "end_user_id", None):
|
||||
j += 1
|
||||
continue
|
||||
|
||||
@@ -671,7 +671,7 @@ def fuzzy_match(
|
||||
merge_reason = "[别名匹配]" if alias_match_merge else "[模糊]"
|
||||
merge_reason = "[别名匹配]" if alias_match_merge else "[模糊]"
|
||||
fuzzy_merge_records.append(
|
||||
f"{merge_reason} 规范实体 {a.id} ({a.group_id}|{a.name}|{a.entity_type}) <- 合并实体 {b.id} ({b.group_id}|{b.name}|{b.entity_type}) | "
|
||||
f"{merge_reason} 规范实体 {a.id} ({a.end_user_id}|{a.name}|{a.entity_type}) <- 合并实体 {b.id} ({b.end_user_id}|{b.name}|{b.entity_type}) | "
|
||||
f"s_name={s_name:.3f}, s_type={s_type:.3f}, overall={overall:.3f}, exact_alias={has_exact_match}"
|
||||
)
|
||||
except Exception:
|
||||
@@ -779,7 +779,7 @@ async def LLM_decision( # 决策中包含去重和消歧的功能
|
||||
# 记录 LLM 融合日志
|
||||
try:
|
||||
llm_records.append(
|
||||
f"[LLM融合] 规范实体 {a.id} ({a.group_id}|{a.name}|{a.entity_type}) <- 合并实体 {b.id} ({b.group_id}|{b.name}|{b.entity_type})"
|
||||
f"[LLM融合] 规范实体 {a.id} ({a.end_user_id}|{a.name}|{a.entity_type}) <- 合并实体 {b.id} ({b.end_user_id}|{b.name}|{b.entity_type})"
|
||||
)
|
||||
# 详细的“同类名称相似”记录改由 LLM 去重模块统一生成以携带 conf/reason
|
||||
except Exception:
|
||||
@@ -847,7 +847,7 @@ async def LLM_disamb_decision(
|
||||
id_redirect[k] = a.id
|
||||
try:
|
||||
disamb_records.append(
|
||||
f"[DISAMB合并应用] 规范实体 {a.id} ({a.group_id}|{a.name}|{a.entity_type}) <- 合并实体 {b.id} ({b.group_id}|{b.name}|{b.entity_type})"
|
||||
f"[DISAMB合并应用] 规范实体 {a.id} ({a.end_user_id}|{a.name}|{a.entity_type}) <- 合并实体 {b.id} ({b.end_user_id}|{b.name}|{b.entity_type})"
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
@@ -174,7 +174,7 @@ async def _judge_pair(
|
||||
pass
|
||||
# 3. 构建LLM判断的“上下文信息”(规则层计算的所有特征) 判断上下文特征有助于实体消歧首先判断的类型关系
|
||||
ctx = {
|
||||
"same_group": getattr(a, "group_id", None) == getattr(b, "group_id", None),
|
||||
"same_group": getattr(a, "end_user_id", None) == getattr(b, "end_user_id", None),
|
||||
"type_ok": _simple_type_ok(getattr(a, "entity_type", None), getattr(b, "entity_type", None)),
|
||||
"type_similarity": _type_similarity(getattr(a, "entity_type", None), getattr(b, "entity_type", None)),
|
||||
"name_text_sim": name_text_sim,
|
||||
@@ -235,7 +235,7 @@ async def _judge_pair_disamb(
|
||||
except Exception:
|
||||
pass
|
||||
ctx = {
|
||||
"same_group": getattr(a, "group_id", None) == getattr(b, "group_id", None),
|
||||
"same_group": getattr(a, "end_user_id", None) == getattr(b, "end_user_id", None),
|
||||
"type_ok": _simple_type_ok(getattr(a, "entity_type", None), getattr(b, "entity_type", None)),
|
||||
"name_text_sim": name_text_sim,
|
||||
"name_embed_sim": name_embed_sim,
|
||||
@@ -317,8 +317,8 @@ async def llm_dedup_entities( # 保留对偶判断作为子流程,是为了
|
||||
a = entity_nodes[i]
|
||||
for j in range(i + 1, len(entity_nodes)):
|
||||
b = entity_nodes[j]
|
||||
# 规则1:必须属于同一组(group_id相同,不同组的实体不重复)
|
||||
if getattr(a, "group_id", None) != getattr(b, "group_id", None):
|
||||
# 规则1:必须属于同一组(end_user_id相同,不同组的实体不重复)
|
||||
if getattr(a, "end_user_id", None) != getattr(b, "end_user_id", None):
|
||||
continue
|
||||
# 规则2:类型必须兼容(调用_simple_type_ok判断)
|
||||
if not _simple_type_ok(getattr(a, "entity_type", None), getattr(b, "entity_type", None)):
|
||||
@@ -474,7 +474,7 @@ async def llm_dedup_entities_iterative_blocks( # 迭代分块并发 LLM 去重
|
||||
- max_rounds: upper bound for iterative passes (default 3)
|
||||
- auto_merge_threshold: decision confidence for auto-merge when no co-occurrence (default 0.90)
|
||||
- co_ctx_threshold: lower threshold when co-occurrence is detected (default 0.83)
|
||||
- shuffle_each_round: whether to shuffle entities within group_id each round to vary block composition
|
||||
- shuffle_each_round: whether to shuffle entities within end_user_id each round to vary block composition
|
||||
|
||||
Returns:
|
||||
- global_redirect: dict losing_id -> canonical_id accumulated across rounds
|
||||
@@ -509,7 +509,7 @@ async def llm_dedup_entities_iterative_blocks( # 迭代分块并发 LLM 去重
|
||||
|
||||
def _partition_blocks(nodes: List[ExtractedEntityNode]) -> List[List[ExtractedEntityNode]]:
|
||||
"""
|
||||
按 group_id 分块,避免跨组实体在同一块,减少无效候选对
|
||||
按 end_user_id 分块,避免跨组实体在同一块,减少无效候选对
|
||||
|
||||
Args:
|
||||
nodes: 实体节点列表
|
||||
@@ -519,7 +519,7 @@ async def llm_dedup_entities_iterative_blocks( # 迭代分块并发 LLM 去重
|
||||
"""
|
||||
groups: Dict[str, List[ExtractedEntityNode]] = {}
|
||||
for e in nodes:
|
||||
gid = getattr(e, "group_id", None)
|
||||
gid = getattr(e, "end_user_id", None)
|
||||
groups.setdefault(str(gid), []).append(e)
|
||||
blocks: List[List[ExtractedEntityNode]] = []
|
||||
for gid, arr in groups.items():
|
||||
@@ -559,7 +559,7 @@ async def llm_dedup_entities_iterative_blocks( # 迭代分块并发 LLM 去重
|
||||
# Collapse nodes to canonical reps before each round to avoid redundant comparisons
|
||||
# 步骤1:折叠实体(合并已确定的重复实体,减少后续计算量)
|
||||
current_nodes = _collapse_nodes(current_nodes)
|
||||
# 步骤2:分块(按group_id分块,避免跨组处理)
|
||||
# 步骤2:分块(按end_user_id分块,避免跨组处理)
|
||||
blocks = _partition_blocks(current_nodes)
|
||||
if not blocks: # 无块可处理(实体已全部折叠),退出循环
|
||||
break
|
||||
@@ -645,7 +645,7 @@ async def llm_disambiguate_pairs_iterative(
|
||||
a = entity_nodes[i]
|
||||
b = entity_nodes[j]
|
||||
# 必须同组
|
||||
if getattr(a, "group_id", None) != getattr(b, "group_id", None):
|
||||
if getattr(a, "end_user_id", None) != getattr(b, "end_user_id", None):
|
||||
continue
|
||||
ta = getattr(a, "entity_type", None)
|
||||
tb = getattr(b, "entity_type", None)
|
||||
|
||||
@@ -61,7 +61,7 @@ def _row_to_entity(row: Dict[str, Any]) -> ExtractedEntityNode:
|
||||
return ExtractedEntityNode(
|
||||
id=row.get("id"),
|
||||
name=row.get("name") or "",
|
||||
group_id=row.get("group_id") or "",
|
||||
end_user_id=row.get("end_user_id") or "",
|
||||
user_id=row.get("user_id") or "",
|
||||
apply_id=row.get("apply_id") or "",
|
||||
created_at=_parse_dt(row.get("created_at")),
|
||||
@@ -79,7 +79,7 @@ def _row_to_entity(row: Dict[str, Any]) -> ExtractedEntityNode:
|
||||
|
||||
async def second_layer_dedup_and_merge_with_neo4j( # 二层去重的核心逻辑,与 Neo4j 中同组实体联合去重
|
||||
connector: Neo4jConnector,
|
||||
group_id: str, # 用于定位neo4j中同一组的实体,确保只在同组内去重
|
||||
end_user_id: str, # 用于定位neo4j中同一组的实体,确保只在同组内去重
|
||||
entity_nodes: List[ExtractedEntityNode], # 输入的实体节点列表,包含待去重的实体
|
||||
statement_entity_edges: List[StatementEntityEdge], # 输入的语句实体边列表,用于处理实体之间的关系
|
||||
entity_entity_edges: List[EntityEntityEdge], # 输入的实体实体边列表,用于处理实体之间的关系
|
||||
@@ -88,7 +88,7 @@ async def second_layer_dedup_and_merge_with_neo4j( # 二层去重的核心逻辑
|
||||
) -> Tuple[List[ExtractedEntityNode], List[StatementEntityEdge], List[EntityEntityEdge]]:
|
||||
"""
|
||||
第二层去重消歧:
|
||||
- 以第一层结果为索引,检索相同 group_id 下的 DB 候选实体
|
||||
- 以第一层结果为索引,检索相同 end_user_id 下的 DB 候选实体
|
||||
- 将 DB 候选与当前实体集合联合,按既有精确/模糊/LLM 决策进行融合
|
||||
- 返回融合后的实体与重定向后的边(边已指向规范 ID,优先 DB ID)
|
||||
"""
|
||||
@@ -102,7 +102,7 @@ async def second_layer_dedup_and_merge_with_neo4j( # 二层去重的核心逻辑
|
||||
|
||||
]
|
||||
candidates_map = await get_dedup_candidates_for_entities( # 从 Neo4j 中查询候选实体,并将结果赋值给candidates_map(等待异步操作完成)。
|
||||
connector=connector, group_id=group_id,
|
||||
connector=connector, end_user_id=end_user_id,
|
||||
entities=incoming_rows, # 传入参数:第一层实体的核心信息(作为查询索引)
|
||||
use_contains_fallback=True # 传入参数:启用 “包含关系” 作为匹配失败的降级策略(若精确匹配无结果,用包含关系召回候选),与src\database\cypher_queries.py的307产生联动
|
||||
)
|
||||
|
||||
@@ -57,11 +57,11 @@ async def dedup_layers_and_merge_and_return(
|
||||
if pipeline_config is None:
|
||||
raise ValueError("pipeline_config is required for dedup_layers_and_merge_and_return")
|
||||
|
||||
# 先探测 group_id,决定报告写入策略
|
||||
group_id: Optional[str] = None
|
||||
# 先探测 end_user_id,决定报告写入策略
|
||||
end_user_id: Optional[str] = None
|
||||
for dd in dialog_data_list:
|
||||
group_id = getattr(dd, "group_id", None)
|
||||
if group_id:
|
||||
end_user_id = getattr(dd, "end_user_id", None)
|
||||
if end_user_id:
|
||||
break
|
||||
|
||||
# 第一层去重消歧
|
||||
@@ -82,11 +82,11 @@ async def dedup_layers_and_merge_and_return(
|
||||
|
||||
# 第二层去重消歧:与 Neo4j 中同组实体联合融合
|
||||
try:
|
||||
if group_id:
|
||||
if end_user_id:
|
||||
if connector:
|
||||
fused_entity_nodes, fused_statement_entity_edges, fused_entity_entity_edges = await second_layer_dedup_and_merge_with_neo4j(
|
||||
connector=connector,
|
||||
group_id=group_id,
|
||||
end_user_id=end_user_id,
|
||||
entity_nodes=dedup_entity_nodes,
|
||||
statement_entity_edges=dedup_statement_entity_edges,
|
||||
entity_entity_edges=dedup_entity_entity_edges,
|
||||
@@ -96,7 +96,7 @@ async def dedup_layers_and_merge_and_return(
|
||||
else:
|
||||
print("Skip second-layer dedup: missing connector")
|
||||
else:
|
||||
print("Skip second-layer dedup: missing group_id")
|
||||
print("Skip second-layer dedup: missing end_user_id")
|
||||
except Exception as e:
|
||||
print(f"Second-layer dedup failed: {e}")
|
||||
|
||||
|
||||
@@ -287,7 +287,7 @@ class ExtractionOrchestrator:
|
||||
for d_idx, dialog in enumerate(dialog_data_list):
|
||||
dialogue_content = dialog.content if self.config.statement_extraction.include_dialogue_context else None
|
||||
for c_idx, chunk in enumerate(dialog.chunks):
|
||||
all_chunks.append((chunk, dialog.group_id, dialogue_content))
|
||||
all_chunks.append((chunk, dialog.end_user_id, dialogue_content))
|
||||
chunk_metadata.append((d_idx, c_idx))
|
||||
|
||||
logger.info(f"收集到 {len(all_chunks)} 个分块,开始全局并行提取")
|
||||
@@ -299,9 +299,9 @@ class ExtractionOrchestrator:
|
||||
# 全局并行处理所有分块
|
||||
async def extract_for_chunk(chunk_data, chunk_index):
|
||||
nonlocal completed_chunks
|
||||
chunk, group_id, dialogue_content = chunk_data
|
||||
chunk, end_user_id, dialogue_content = chunk_data
|
||||
try:
|
||||
statements = await self.statement_extractor._extract_statements(chunk, group_id, dialogue_content)
|
||||
statements = await self.statement_extractor._extract_statements(chunk, end_user_id, dialogue_content)
|
||||
|
||||
# 流式输出:每提取完一个分块的陈述句,立即发送进度
|
||||
# 注意:只在试运行模式下发送陈述句详情,正式模式不发送
|
||||
@@ -569,32 +569,32 @@ class ExtractionOrchestrator:
|
||||
if dialog_data_list and hasattr(dialog_data_list[0], 'config_id'):
|
||||
config_id = dialog_data_list[0].config_id
|
||||
|
||||
# 加载DataConfig
|
||||
data_config = None
|
||||
# 加载MemoryConfig
|
||||
memory_config = None
|
||||
if config_id:
|
||||
try:
|
||||
from app.db import SessionLocal
|
||||
from app.repositories.data_config_repository import DataConfigRepository
|
||||
from app.repositories.memory_config_repository import MemoryConfigRepository
|
||||
|
||||
db = SessionLocal()
|
||||
try:
|
||||
data_config = DataConfigRepository.get_by_id(db, config_id)
|
||||
memory_config = MemoryConfigRepository.get_by_id(db, config_id)
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
if data_config and not data_config.emotion_enabled:
|
||||
if memory_config and not memory_config.emotion_enabled:
|
||||
logger.info("情绪提取已在配置中禁用,跳过情绪提取")
|
||||
return [{} for _ in dialog_data_list]
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"加载DataConfig失败: {e},将跳过情绪提取")
|
||||
logger.warning(f"加载MemoryConfig失败: {e},将跳过情绪提取")
|
||||
return [{} for _ in dialog_data_list]
|
||||
else:
|
||||
logger.info("未找到config_id,跳过情绪提取")
|
||||
return [{} for _ in dialog_data_list]
|
||||
|
||||
# 如果配置未启用情绪提取,直接返回空映射
|
||||
if not data_config or not data_config.emotion_enabled:
|
||||
if not memory_config or not memory_config.emotion_enabled:
|
||||
logger.info("情绪提取未启用,跳过")
|
||||
return [{} for _ in dialog_data_list]
|
||||
|
||||
@@ -608,7 +608,7 @@ class ExtractionOrchestrator:
|
||||
total_statements += 1
|
||||
# 只处理用户的陈述句 (role 为 "user")
|
||||
if hasattr(statement, 'speaker') and statement.speaker == "user":
|
||||
all_statements.append((statement, data_config))
|
||||
all_statements.append((statement, memory_config))
|
||||
statement_metadata.append((d_idx, statement.id))
|
||||
filtered_statements += 1
|
||||
|
||||
@@ -617,7 +617,7 @@ class ExtractionOrchestrator:
|
||||
# 初始化情绪提取服务
|
||||
from app.services.emotion_extraction_service import EmotionExtractionService
|
||||
emotion_service = EmotionExtractionService(
|
||||
llm_id=data_config.emotion_model_id if data_config.emotion_model_id else None
|
||||
llm_id=memory_config.emotion_model_id if memory_config.emotion_model_id else None
|
||||
)
|
||||
|
||||
# 全局并行处理所有陈述句
|
||||
@@ -992,9 +992,7 @@ class ExtractionOrchestrator:
|
||||
id=dialog_data.id,
|
||||
name=f"Dialog_{dialog_data.id}", # 添加必需的 name 字段
|
||||
ref_id=dialog_data.ref_id,
|
||||
group_id=dialog_data.group_id,
|
||||
user_id=dialog_data.user_id,
|
||||
apply_id=dialog_data.apply_id,
|
||||
end_user_id=dialog_data.end_user_id,
|
||||
run_id=dialog_data.run_id, # 使用 dialog_data 的 run_id
|
||||
content=dialog_data.context.content if dialog_data.context else "",
|
||||
dialog_embedding=dialog_data.dialog_embedding if hasattr(dialog_data, 'dialog_embedding') else None,
|
||||
@@ -1012,9 +1010,7 @@ class ExtractionOrchestrator:
|
||||
id=chunk.id,
|
||||
name=f"Chunk_{chunk.id}", # 添加必需的 name 字段
|
||||
dialog_id=dialog_data.id,
|
||||
group_id=dialog_data.group_id,
|
||||
user_id=dialog_data.user_id,
|
||||
apply_id=dialog_data.apply_id,
|
||||
end_user_id=dialog_data.end_user_id,
|
||||
run_id=dialog_data.run_id, # 使用 dialog_data 的 run_id
|
||||
content=chunk.content,
|
||||
chunk_embedding=chunk.chunk_embedding,
|
||||
@@ -1035,9 +1031,7 @@ class ExtractionOrchestrator:
|
||||
stmt_type=getattr(statement, 'stmt_type', 'general'), # 添加必需的 stmt_type 字段
|
||||
temporal_info=getattr(statement, 'temporal_info', TemporalInfo.ATEMPORAL), # 添加必需的 temporal_info 字段
|
||||
connect_strength=statement.connect_strength if statement.connect_strength is not None else 'Strong', # 添加必需的 connect_strength 字段
|
||||
group_id=dialog_data.group_id,
|
||||
user_id=dialog_data.user_id,
|
||||
apply_id=dialog_data.apply_id,
|
||||
end_user_id=dialog_data.end_user_id,
|
||||
run_id=dialog_data.run_id, # 使用 dialog_data 的 run_id
|
||||
statement=statement.statement,
|
||||
speaker=getattr(statement, 'speaker', None), # 添加 speaker 字段
|
||||
@@ -1060,9 +1054,7 @@ class ExtractionOrchestrator:
|
||||
statement_chunk_edge = StatementChunkEdge(
|
||||
source=statement.id,
|
||||
target=chunk.id,
|
||||
group_id=dialog_data.group_id,
|
||||
user_id=dialog_data.user_id,
|
||||
apply_id=dialog_data.apply_id,
|
||||
end_user_id=dialog_data.end_user_id,
|
||||
run_id=dialog_data.run_id, # 使用 dialog_data 的 run_id
|
||||
created_at=dialog_data.created_at,
|
||||
)
|
||||
@@ -1095,9 +1087,7 @@ class ExtractionOrchestrator:
|
||||
aliases=getattr(entity, 'aliases', []) or [], # 传递从三元组提取阶段获取的aliases
|
||||
name_embedding=getattr(entity, 'name_embedding', None),
|
||||
is_explicit_memory=getattr(entity, 'is_explicit_memory', False), # 新增:传递语义记忆标记
|
||||
group_id=dialog_data.group_id,
|
||||
user_id=dialog_data.user_id,
|
||||
apply_id=dialog_data.apply_id,
|
||||
end_user_id=dialog_data.end_user_id,
|
||||
run_id=dialog_data.run_id, # 使用 dialog_data 的 run_id
|
||||
created_at=dialog_data.created_at,
|
||||
expired_at=dialog_data.expired_at,
|
||||
@@ -1112,9 +1102,7 @@ class ExtractionOrchestrator:
|
||||
source=statement.id,
|
||||
target=entity.id,
|
||||
connect_strength=entity_connect_strength if entity_connect_strength is not None else 'Strong',
|
||||
group_id=dialog_data.group_id,
|
||||
user_id=dialog_data.user_id,
|
||||
apply_id=dialog_data.apply_id,
|
||||
end_user_id=dialog_data.end_user_id,
|
||||
run_id=dialog_data.run_id, # 使用 dialog_data 的 run_id
|
||||
created_at=dialog_data.created_at,
|
||||
)
|
||||
@@ -1134,9 +1122,7 @@ class ExtractionOrchestrator:
|
||||
relation_type=triplet.predicate,
|
||||
statement=statement.statement,
|
||||
source_statement_id=statement.id,
|
||||
group_id=dialog_data.group_id,
|
||||
user_id=dialog_data.user_id,
|
||||
apply_id=dialog_data.apply_id,
|
||||
end_user_id=dialog_data.end_user_id,
|
||||
run_id=dialog_data.run_id, # 使用 dialog_data 的 run_id
|
||||
created_at=dialog_data.created_at,
|
||||
expired_at=dialog_data.expired_at,
|
||||
@@ -1763,14 +1749,14 @@ class ExtractionOrchestrator:
|
||||
|
||||
async def get_chunked_dialogs(
|
||||
chunker_strategy: str = "RecursiveChunker",
|
||||
group_id: str = "group_1",
|
||||
end_user_id: str = "group_1",
|
||||
indices: Optional[List[int]] = None,
|
||||
) -> List[DialogData]:
|
||||
"""从测试数据生成分块对话
|
||||
|
||||
Args:
|
||||
chunker_strategy: 分块策略(默认: RecursiveChunker)
|
||||
group_id: 组ID
|
||||
end_user_id: 组ID
|
||||
indices: 要处理的数据索引列表(可选)
|
||||
|
||||
Returns:
|
||||
@@ -1834,7 +1820,7 @@ async def get_chunked_dialogs(
|
||||
dialog_data = DialogData(
|
||||
context=conversation_context,
|
||||
ref_id=data['id'],
|
||||
group_id=group_id,
|
||||
end_user_id=end_user_id,
|
||||
metadata=dialog_metadata,
|
||||
)
|
||||
|
||||
@@ -1936,7 +1922,7 @@ async def get_chunked_dialogs_from_preprocessed(
|
||||
|
||||
async def get_chunked_dialogs_with_preprocessing(
|
||||
chunker_strategy: str = "RecursiveChunker",
|
||||
group_id: str = "default",
|
||||
end_user_id: str = "default",
|
||||
user_id: str = "default",
|
||||
apply_id: str = "default",
|
||||
indices: Optional[List[int]] = None,
|
||||
@@ -1948,7 +1934,7 @@ async def get_chunked_dialogs_with_preprocessing(
|
||||
|
||||
Args:
|
||||
chunker_strategy: 分块策略
|
||||
group_id: 组ID
|
||||
end_user_id: 组ID
|
||||
user_id: 用户ID
|
||||
apply_id: 应用ID
|
||||
indices: 要处理的数据索引列表
|
||||
@@ -1976,11 +1962,9 @@ async def get_chunked_dialogs_with_preprocessing(
|
||||
indices=indices,
|
||||
)
|
||||
|
||||
# 设置 group_id, user_id, apply_id
|
||||
# 设置 end_user_id
|
||||
for dd in preprocessed_data:
|
||||
dd.group_id = group_id
|
||||
dd.user_id = user_id
|
||||
dd.apply_id = apply_id
|
||||
dd.end_user_id = end_user_id
|
||||
|
||||
# 步骤2: 语义剪枝
|
||||
try:
|
||||
|
||||
@@ -193,9 +193,9 @@ async def _process_chunk_summary(
|
||||
node = MemorySummaryNode(
|
||||
id=uuid4().hex,
|
||||
name=title if title else f"MemorySummaryChunk_{chunk.id}",
|
||||
group_id=dialog.group_id,
|
||||
user_id=dialog.user_id,
|
||||
apply_id=dialog.apply_id,
|
||||
end_user_id=dialog.end_user_id,
|
||||
user_id=dialog.end_user_id,
|
||||
apply_id=dialog.end_user_id,
|
||||
run_id=dialog.run_id, # 使用 dialog 的 run_id
|
||||
created_at=datetime.now(),
|
||||
expired_at=datetime(9999, 12, 31),
|
||||
|
||||
@@ -82,12 +82,12 @@ class StatementExtractor:
|
||||
logger.warning(f"Chunk {getattr(chunk, 'id', 'unknown')} has no speaker field or is empty")
|
||||
return None
|
||||
|
||||
async def _extract_statements(self, chunk, group_id: Optional[str] = None, dialogue_content: str = None) -> List[Statement]:
|
||||
async def _extract_statements(self, chunk, end_user_id: Optional[str] = None, dialogue_content: str = None) -> List[Statement]:
|
||||
"""Process a single chunk and return extracted statements
|
||||
|
||||
Args:
|
||||
chunk: Chunk object to process
|
||||
group_id: Group ID to assign to all statements in this chunk
|
||||
end_user_id: Group ID to assign to all statements in this chunk
|
||||
dialogue_content: Full dialogue content to provide as context
|
||||
|
||||
Returns:
|
||||
@@ -158,7 +158,7 @@ class StatementExtractor:
|
||||
temporal_info=temporal_type,
|
||||
relevence_info=relevence_info,
|
||||
chunk_id=chunk.id,
|
||||
group_id=group_id,
|
||||
end_user_id=end_user_id,
|
||||
speaker=chunk_speaker,
|
||||
)
|
||||
|
||||
@@ -184,10 +184,10 @@ class StatementExtractor:
|
||||
|
||||
logger.info(f"Processing {len(chunks_to_process)} chunks for statement extraction")
|
||||
|
||||
# Process all chunks concurrently, passing the group_id and dialogue content from dialog_data
|
||||
# Process all chunks concurrently, passing the end_user_id and dialogue content from dialog_data
|
||||
dialogue_content = dialog_data.content if self.config.include_dialogue_context else None
|
||||
results = await asyncio.gather(
|
||||
*[self._extract_statements(chunk, dialog_data.group_id, dialogue_content) for chunk in chunks_to_process],
|
||||
*[self._extract_statements(chunk, dialog_data.end_user_id, dialogue_content) for chunk in chunks_to_process],
|
||||
return_exceptions=True
|
||||
)
|
||||
|
||||
@@ -225,7 +225,7 @@ class StatementExtractor:
|
||||
for i, statement in enumerate(statements, 1):
|
||||
f.write(f"Statement {i}:\n")
|
||||
f.write(f"Id: {statement.id}\n")
|
||||
f.write(f"Group Id: {statement.group_id}\n")
|
||||
f.write(f"Group Id: {statement.end_user_id}\n")
|
||||
f.write(f"Content: {statement.statement}\n")
|
||||
f.write(f"Type: {statement.stmt_type.value}\n")
|
||||
f.write(f"Temporal Info: {statement.temporal_info.value}\n")
|
||||
@@ -298,7 +298,7 @@ class StatementExtractor:
|
||||
|
||||
dialog_sections.append({
|
||||
"dialog_id": dialog.ref_id,
|
||||
"group_id": dialog.group_id,
|
||||
"end_user_id": dialog.end_user_id,
|
||||
"content": dialog.content if getattr(dialog, "content", None) else "",
|
||||
"strong": strong_relations,
|
||||
"weak": weak_relations,
|
||||
@@ -312,7 +312,7 @@ class StatementExtractor:
|
||||
for idx, section in enumerate(dialog_sections, 1):
|
||||
f.write(f"Dialog {idx}:\n")
|
||||
f.write(f"Dialog ID: {section.get('dialog_id', '')}\n")
|
||||
f.write(f"Group ID: {section.get('group_id', '')}\n")
|
||||
f.write(f"Group ID: {section.get('end_user_id', '')}\n")
|
||||
f.write("Content:\n")
|
||||
f.write(f"{section.get('content', '')}\n")
|
||||
f.write("-" * 40 + "\n\n")
|
||||
|
||||
@@ -132,7 +132,7 @@ class TemporalExtractor:
|
||||
prompt_logger.info("")
|
||||
prompt_logger.info("=== TEMPORAL EXTRACTION RESULTS ===")
|
||||
prompt_logger.info(
|
||||
f"[Temporal] Dialog ref_id={getattr(dialog_data, 'ref_id', None)}, group_id={getattr(dialog_data, 'group_id', None)}"
|
||||
f"[Temporal] Dialog ref_id={getattr(dialog_data, 'ref_id', None)}, end_user_id={getattr(dialog_data, 'end_user_id', None)}"
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
@@ -116,7 +116,7 @@ class TripletExtractor:
|
||||
logger.info(f"Processing {len(all_statements)} statements for triplet extraction...")
|
||||
try:
|
||||
prompt_logger.info(
|
||||
f"[Triplet] Dialog ref_id={getattr(dialog_data, 'ref_id', None)}, group_id={getattr(dialog_data, 'group_id', None)}, statements_to_process={len(all_statements)}"
|
||||
f"[Triplet] Dialog ref_id={getattr(dialog_data, 'ref_id', None)}, end_user_id={getattr(dialog_data, 'end_user_id', None)}, statements_to_process={len(all_statements)}"
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
@@ -75,7 +75,7 @@ class AccessHistoryManager:
|
||||
self,
|
||||
node_id: str,
|
||||
node_label: str,
|
||||
group_id: Optional[str] = None,
|
||||
end_user_id: Optional[str] = None,
|
||||
current_time: Optional[datetime] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
@@ -91,7 +91,7 @@ class AccessHistoryManager:
|
||||
Args:
|
||||
node_id: 节点ID
|
||||
node_label: 节点标签(Statement, ExtractedEntity, MemorySummary)
|
||||
group_id: 组ID(可选,用于过滤)
|
||||
end_user_id: 组ID(可选,用于过滤)
|
||||
current_time: 当前时间(可选,默认使用系统时间)
|
||||
|
||||
Returns:
|
||||
@@ -123,7 +123,7 @@ class AccessHistoryManager:
|
||||
for attempt in range(self.max_retries):
|
||||
try:
|
||||
# 步骤1:读取当前节点状态
|
||||
node_data = await self._fetch_node(node_id, node_label, group_id)
|
||||
node_data = await self._fetch_node(node_id, node_label, end_user_id)
|
||||
|
||||
if not node_data:
|
||||
raise ValueError(
|
||||
@@ -142,7 +142,7 @@ class AccessHistoryManager:
|
||||
node_id=node_id,
|
||||
node_label=node_label,
|
||||
update_data=update_data,
|
||||
group_id=group_id
|
||||
end_user_id=end_user_id
|
||||
)
|
||||
|
||||
logger.info(
|
||||
@@ -172,7 +172,7 @@ class AccessHistoryManager:
|
||||
self,
|
||||
node_ids: List[str],
|
||||
node_label: str,
|
||||
group_id: Optional[str] = None,
|
||||
end_user_id: Optional[str] = None,
|
||||
current_time: Optional[datetime] = None
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
@@ -184,7 +184,7 @@ class AccessHistoryManager:
|
||||
Args:
|
||||
node_ids: 节点ID列表
|
||||
node_label: 节点标签(所有节点必须是同一类型)
|
||||
group_id: 组ID(可选)
|
||||
end_user_id: 组ID(可选)
|
||||
current_time: 当前时间(可选)
|
||||
|
||||
Returns:
|
||||
@@ -202,7 +202,7 @@ class AccessHistoryManager:
|
||||
task = self.record_access(
|
||||
node_id=node_id,
|
||||
node_label=node_label,
|
||||
group_id=group_id,
|
||||
end_user_id=end_user_id,
|
||||
current_time=current_time
|
||||
)
|
||||
tasks.append(task)
|
||||
@@ -235,7 +235,7 @@ class AccessHistoryManager:
|
||||
self,
|
||||
node_id: str,
|
||||
node_label: str,
|
||||
group_id: Optional[str] = None
|
||||
end_user_id: Optional[str] = None
|
||||
) -> Tuple[ConsistencyCheckResult, Optional[str]]:
|
||||
"""
|
||||
检查节点数据的一致性
|
||||
@@ -249,14 +249,14 @@ class AccessHistoryManager:
|
||||
Args:
|
||||
node_id: 节点ID
|
||||
node_label: 节点标签
|
||||
group_id: 组ID(可选)
|
||||
end_user_id: 组ID(可选)
|
||||
|
||||
Returns:
|
||||
Tuple[ConsistencyCheckResult, Optional[str]]:
|
||||
- 一致性检查结果枚举
|
||||
- 错误描述(如果不一致)
|
||||
"""
|
||||
node_data = await self._fetch_node(node_id, node_label, group_id)
|
||||
node_data = await self._fetch_node(node_id, node_label, end_user_id)
|
||||
|
||||
if not node_data:
|
||||
return ConsistencyCheckResult.CONSISTENT, None
|
||||
@@ -305,7 +305,7 @@ class AccessHistoryManager:
|
||||
async def check_batch_consistency(
|
||||
self,
|
||||
node_label: str,
|
||||
group_id: Optional[str] = None,
|
||||
end_user_id: Optional[str] = None,
|
||||
limit: int = 1000
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
@@ -313,7 +313,7 @@ class AccessHistoryManager:
|
||||
|
||||
Args:
|
||||
node_label: 节点标签
|
||||
group_id: 组ID(可选)
|
||||
end_user_id: 组ID(可选)
|
||||
limit: 检查的最大节点数
|
||||
|
||||
Returns:
|
||||
@@ -329,16 +329,16 @@ class AccessHistoryManager:
|
||||
MATCH (n:{node_label})
|
||||
WHERE n.access_history IS NOT NULL
|
||||
"""
|
||||
if group_id:
|
||||
query += " AND n.group_id = $group_id"
|
||||
if end_user_id:
|
||||
query += " AND n.end_user_id = $end_user_id"
|
||||
query += """
|
||||
RETURN n.id as id
|
||||
LIMIT $limit
|
||||
"""
|
||||
|
||||
params = {"limit": limit}
|
||||
if group_id:
|
||||
params["group_id"] = group_id
|
||||
if end_user_id:
|
||||
params["end_user_id"] = end_user_id
|
||||
|
||||
results = await self.connector.execute_query(query, **params)
|
||||
node_ids = [r['id'] for r in results]
|
||||
@@ -351,7 +351,7 @@ class AccessHistoryManager:
|
||||
result, message = await self.check_consistency(
|
||||
node_id=node_id,
|
||||
node_label=node_label,
|
||||
group_id=group_id
|
||||
end_user_id=end_user_id
|
||||
)
|
||||
|
||||
if result == ConsistencyCheckResult.CONSISTENT:
|
||||
@@ -387,7 +387,7 @@ class AccessHistoryManager:
|
||||
self,
|
||||
node_id: str,
|
||||
node_label: str,
|
||||
group_id: Optional[str] = None
|
||||
end_user_id: Optional[str] = None
|
||||
) -> bool:
|
||||
"""
|
||||
自动修复节点的数据不一致问题
|
||||
@@ -401,7 +401,7 @@ class AccessHistoryManager:
|
||||
Args:
|
||||
node_id: 节点ID
|
||||
node_label: 节点标签
|
||||
group_id: 组ID(可选)
|
||||
end_user_id: 组ID(可选)
|
||||
|
||||
Returns:
|
||||
bool: 修复成功返回True,否则返回False
|
||||
@@ -411,7 +411,7 @@ class AccessHistoryManager:
|
||||
result, message = await self.check_consistency(
|
||||
node_id=node_id,
|
||||
node_label=node_label,
|
||||
group_id=group_id
|
||||
end_user_id=end_user_id
|
||||
)
|
||||
|
||||
if result == ConsistencyCheckResult.CONSISTENT:
|
||||
@@ -419,7 +419,7 @@ class AccessHistoryManager:
|
||||
return True
|
||||
|
||||
# 获取节点数据
|
||||
node_data = await self._fetch_node(node_id, node_label, group_id)
|
||||
node_data = await self._fetch_node(node_id, node_label, end_user_id)
|
||||
if not node_data:
|
||||
logger.error(f"节点不存在,无法修复: {node_label}[{node_id}]")
|
||||
return False
|
||||
@@ -457,8 +457,8 @@ class AccessHistoryManager:
|
||||
query = f"""
|
||||
MATCH (n:{node_label} {{id: $node_id}})
|
||||
"""
|
||||
if group_id:
|
||||
query += " WHERE n.group_id = $group_id"
|
||||
if end_user_id:
|
||||
query += " WHERE n.end_user_id = $end_user_id"
|
||||
query += """
|
||||
SET n += $repair_data
|
||||
RETURN n
|
||||
@@ -468,8 +468,8 @@ class AccessHistoryManager:
|
||||
'node_id': node_id,
|
||||
'repair_data': repair_data
|
||||
}
|
||||
if group_id:
|
||||
params['group_id'] = group_id
|
||||
if end_user_id:
|
||||
params['end_user_id'] = end_user_id
|
||||
|
||||
await self.connector.execute_query(query, **params)
|
||||
|
||||
@@ -491,7 +491,7 @@ class AccessHistoryManager:
|
||||
self,
|
||||
node_id: str,
|
||||
node_label: str,
|
||||
group_id: Optional[str] = None
|
||||
end_user_id: Optional[str] = None
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
获取节点数据
|
||||
@@ -499,7 +499,7 @@ class AccessHistoryManager:
|
||||
Args:
|
||||
node_id: 节点ID
|
||||
node_label: 节点标签
|
||||
group_id: 组ID(可选)
|
||||
end_user_id: 组ID(可选)
|
||||
|
||||
Returns:
|
||||
Optional[Dict[str, Any]]: 节点数据,如果不存在返回None
|
||||
@@ -507,8 +507,8 @@ class AccessHistoryManager:
|
||||
query = f"""
|
||||
MATCH (n:{node_label} {{id: $node_id}})
|
||||
"""
|
||||
if group_id:
|
||||
query += " WHERE n.group_id = $group_id"
|
||||
if end_user_id:
|
||||
query += " WHERE n.end_user_id = $end_user_id"
|
||||
query += """
|
||||
RETURN n.id as id,
|
||||
n.importance_score as importance_score,
|
||||
@@ -519,8 +519,8 @@ class AccessHistoryManager:
|
||||
"""
|
||||
|
||||
params = {'node_id': node_id}
|
||||
if group_id:
|
||||
params['group_id'] = group_id
|
||||
if end_user_id:
|
||||
params['end_user_id'] = end_user_id
|
||||
|
||||
results = await self.connector.execute_query(query, **params)
|
||||
|
||||
@@ -585,7 +585,7 @@ class AccessHistoryManager:
|
||||
node_id: str,
|
||||
node_label: str,
|
||||
update_data: Dict[str, Any],
|
||||
group_id: Optional[str] = None
|
||||
end_user_id: Optional[str] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
原子性更新节点(使用乐观锁)
|
||||
@@ -597,7 +597,7 @@ class AccessHistoryManager:
|
||||
node_id: 节点ID
|
||||
node_label: 节点标签
|
||||
update_data: 更新数据
|
||||
group_id: 组ID(可选)
|
||||
end_user_id: 组ID(可选)
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: 更新后的节点数据
|
||||
@@ -606,13 +606,13 @@ class AccessHistoryManager:
|
||||
RuntimeError: 如果更新失败或发生版本冲突
|
||||
"""
|
||||
# 定义事务函数
|
||||
async def update_transaction(tx, node_id, node_label, update_data, group_id):
|
||||
async def update_transaction(tx, node_id, node_label, update_data, end_user_id):
|
||||
# 步骤1:读取当前节点并获取版本号
|
||||
read_query = f"""
|
||||
MATCH (n:{node_label} {{id: $node_id}})
|
||||
"""
|
||||
if group_id:
|
||||
read_query += " WHERE n.group_id = $group_id"
|
||||
if end_user_id:
|
||||
read_query += " WHERE n.end_user_id = $end_user_id"
|
||||
read_query += """
|
||||
RETURN n.id as id,
|
||||
n.version as version,
|
||||
@@ -624,8 +624,8 @@ class AccessHistoryManager:
|
||||
"""
|
||||
|
||||
read_params = {'node_id': node_id}
|
||||
if group_id:
|
||||
read_params['group_id'] = group_id
|
||||
if end_user_id:
|
||||
read_params['end_user_id'] = end_user_id
|
||||
|
||||
read_result = await tx.run(read_query, **read_params)
|
||||
current_node = await read_result.single()
|
||||
@@ -656,8 +656,8 @@ class AccessHistoryManager:
|
||||
|
||||
# 构建 WHERE 子句
|
||||
where_conditions = []
|
||||
if group_id:
|
||||
where_conditions.append("n.group_id = $group_id")
|
||||
if end_user_id:
|
||||
where_conditions.append("n.end_user_id = $end_user_id")
|
||||
|
||||
# 添加版本检查
|
||||
if current_version > 0:
|
||||
@@ -695,8 +695,8 @@ class AccessHistoryManager:
|
||||
'last_access_time': update_data['last_access_time'],
|
||||
'access_count': update_data['access_count']
|
||||
}
|
||||
if group_id:
|
||||
update_params['group_id'] = group_id
|
||||
if end_user_id:
|
||||
update_params['end_user_id'] = end_user_id
|
||||
|
||||
update_result = await tx.run(update_query, **update_params)
|
||||
updated_node = await update_result.single()
|
||||
@@ -720,7 +720,7 @@ class AccessHistoryManager:
|
||||
node_id=node_id,
|
||||
node_label=node_label,
|
||||
update_data=update_data,
|
||||
group_id=group_id
|
||||
end_user_id=end_user_id
|
||||
)
|
||||
return result
|
||||
except Exception as e:
|
||||
|
||||
@@ -11,9 +11,10 @@ Functions:
|
||||
|
||||
import logging
|
||||
from typing import Optional, Dict, Any
|
||||
from uuid import UUID
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.repositories.data_config_repository import DataConfigRepository
|
||||
from app.repositories.memory_config_repository import MemoryConfigRepository
|
||||
from app.core.memory.storage_services.forgetting_engine.actr_calculator import ACTRCalculator
|
||||
|
||||
|
||||
@@ -61,12 +62,12 @@ def calculate_forgetting_rate(lambda_time: float, lambda_mem: float) -> float:
|
||||
|
||||
def load_actr_config_from_db(
|
||||
db: Session,
|
||||
config_id: Optional[int] = None
|
||||
config_id: Optional[UUID] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
从数据库加载 ACT-R 配置参数
|
||||
|
||||
从 PostgreSQL 的 data_config 表读取配置参数,
|
||||
从 PostgreSQL 的 memory_config 表读取配置参数,
|
||||
并计算派生参数(如 forgetting_rate)。
|
||||
|
||||
Args:
|
||||
@@ -99,7 +100,7 @@ def load_actr_config_from_db(
|
||||
|
||||
# 从数据库加载配置
|
||||
try:
|
||||
repository = DataConfigRepository()
|
||||
repository = MemoryConfigRepository()
|
||||
db_config = repository.get_by_id(db, config_id)
|
||||
|
||||
if db_config is None:
|
||||
@@ -150,7 +151,7 @@ def load_actr_config_from_db(
|
||||
|
||||
def create_actr_calculator_from_config(
|
||||
db: Session,
|
||||
config_id: Optional[int] = None
|
||||
config_id: Optional[UUID] = None
|
||||
) -> ACTRCalculator:
|
||||
"""
|
||||
从数据库配置创建 ACTRCalculator 实例
|
||||
@@ -168,11 +169,6 @@ def create_actr_calculator_from_config(
|
||||
ValueError: 如果指定的 config_id 不存在
|
||||
|
||||
Examples:
|
||||
>>> from sqlalchemy.orm import Session
|
||||
>>> db = Session()
|
||||
>>> calculator = create_actr_calculator_from_config(db, config_id=1)
|
||||
>>> # 使用计算器
|
||||
>>> activation = calculator.calculate_memory_activation(...)
|
||||
"""
|
||||
# 加载配置
|
||||
config = load_actr_config_from_db(db, config_id)
|
||||
|
||||
@@ -16,6 +16,7 @@ Classes:
|
||||
|
||||
import logging
|
||||
from typing import Dict, Any, Optional
|
||||
from uuid import UUID
|
||||
from datetime import datetime
|
||||
|
||||
from app.core.memory.storage_services.forgetting_engine.forgetting_strategy import ForgettingStrategy
|
||||
@@ -66,10 +67,10 @@ class ForgettingScheduler:
|
||||
|
||||
async def run_forgetting_cycle(
|
||||
self,
|
||||
group_id: Optional[str] = None,
|
||||
end_user_id: Optional[str] = None,
|
||||
max_merge_batch_size: int = 100,
|
||||
min_days_since_access: int = 30,
|
||||
config_id: Optional[int] = None,
|
||||
config_id: Optional[UUID] = None,
|
||||
db = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
@@ -77,7 +78,7 @@ class ForgettingScheduler:
|
||||
|
||||
|
||||
Args:
|
||||
group_id: 组 ID(可选,用于过滤特定组的节点)
|
||||
end_user_id: 组 ID(可选,用于过滤特定组的节点)
|
||||
max_merge_batch_size: 单次最大融合节点对数(默认 100)
|
||||
min_days_since_access: 最小未访问天数(默认 30 天)
|
||||
config_id: 配置ID(可选,用于获取 llm_id)
|
||||
@@ -107,19 +108,19 @@ class ForgettingScheduler:
|
||||
start_time_iso = start_time.isoformat()
|
||||
|
||||
logger.info(
|
||||
f"开始遗忘周期: group_id={group_id}, "
|
||||
f"开始遗忘周期: end_user_id={end_user_id}, "
|
||||
f"max_batch={max_merge_batch_size}, "
|
||||
f"min_days={min_days_since_access}"
|
||||
)
|
||||
|
||||
try:
|
||||
# 步骤1:统计遗忘前的节点数量
|
||||
nodes_before = await self._count_knowledge_nodes(group_id)
|
||||
nodes_before = await self._count_knowledge_nodes(end_user_id)
|
||||
logger.info(f"遗忘前节点总数: {nodes_before}")
|
||||
|
||||
# 步骤2:识别可遗忘的节点对
|
||||
forgettable_pairs = await self.forgetting_strategy.find_forgettable_nodes(
|
||||
group_id=group_id,
|
||||
end_user_id=end_user_id,
|
||||
min_days_since_access=min_days_since_access
|
||||
)
|
||||
|
||||
@@ -213,7 +214,7 @@ class ForgettingScheduler:
|
||||
'statement_text': pair['statement_text'],
|
||||
'statement_activation': pair['statement_activation'],
|
||||
'statement_importance': pair['statement_importance'],
|
||||
'group_id': group_id
|
||||
'end_user_id': end_user_id
|
||||
}
|
||||
|
||||
entity_node = {
|
||||
@@ -222,7 +223,7 @@ class ForgettingScheduler:
|
||||
'entity_type': pair['entity_type'],
|
||||
'entity_activation': pair['entity_activation'],
|
||||
'entity_importance': pair['entity_importance'],
|
||||
'group_id': group_id
|
||||
'end_user_id': end_user_id
|
||||
}
|
||||
|
||||
# 融合节点
|
||||
@@ -262,7 +263,7 @@ class ForgettingScheduler:
|
||||
continue
|
||||
|
||||
# 步骤6:统计遗忘后的节点数量
|
||||
nodes_after = await self._count_knowledge_nodes(group_id)
|
||||
nodes_after = await self._count_knowledge_nodes(end_user_id)
|
||||
logger.info(f"遗忘后节点总数: {nodes_after}")
|
||||
|
||||
# 步骤7:生成遗忘报告
|
||||
@@ -315,7 +316,7 @@ class ForgettingScheduler:
|
||||
|
||||
async def _count_knowledge_nodes(
|
||||
self,
|
||||
group_id: Optional[str] = None
|
||||
end_user_id: Optional[str] = None
|
||||
) -> int:
|
||||
"""
|
||||
统计知识层节点总数
|
||||
@@ -323,7 +324,7 @@ class ForgettingScheduler:
|
||||
统计 Statement、ExtractedEntity 和 MemorySummary 节点的总数。
|
||||
|
||||
Args:
|
||||
group_id: 组 ID(可选,用于过滤特定组的节点)
|
||||
end_user_id: 组 ID(可选,用于过滤特定组的节点)
|
||||
|
||||
Returns:
|
||||
int: 知识层节点总数
|
||||
@@ -333,16 +334,16 @@ class ForgettingScheduler:
|
||||
WHERE (n:Statement OR n:ExtractedEntity OR n:MemorySummary)
|
||||
"""
|
||||
|
||||
if group_id:
|
||||
query += " AND n.group_id = $group_id"
|
||||
if end_user_id:
|
||||
query += " AND n.end_user_id = $end_user_id"
|
||||
|
||||
query += """
|
||||
RETURN count(n) as total
|
||||
"""
|
||||
|
||||
params = {}
|
||||
if group_id:
|
||||
params['group_id'] = group_id
|
||||
if end_user_id:
|
||||
end_user_id['end_user_id'] = end_user_id
|
||||
|
||||
results = await self.connector.execute_query(query, **params)
|
||||
|
||||
|
||||
@@ -13,6 +13,7 @@ Classes:
|
||||
|
||||
import logging
|
||||
from typing import List, Dict, Any, Optional
|
||||
from uuid import UUID
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
@@ -90,7 +91,7 @@ class ForgettingStrategy:
|
||||
|
||||
async def find_forgettable_nodes(
|
||||
self,
|
||||
group_id: Optional[str] = None,
|
||||
end_user_id: Optional[str] = None,
|
||||
min_days_since_access: int = 30
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
@@ -102,7 +103,7 @@ class ForgettingStrategy:
|
||||
3. Statement 和 Entity 之间存在关系边
|
||||
|
||||
Args:
|
||||
group_id: 组 ID(可选,用于过滤特定组的节点)
|
||||
end_user_id: 组 ID(可选,用于过滤特定组的节点)
|
||||
min_days_since_access: 最小未访问天数(默认 30 天)
|
||||
|
||||
Returns:
|
||||
@@ -136,8 +137,8 @@ class ForgettingStrategy:
|
||||
AND (e.entity_type IS NULL OR e.entity_type <> 'Person')
|
||||
"""
|
||||
|
||||
if group_id:
|
||||
query += " AND s.group_id = $group_id AND e.group_id = $group_id"
|
||||
if end_user_id:
|
||||
query += " AND s.end_user_id = $end_user_id AND e.end_user_id = $end_user_id"
|
||||
|
||||
query += """
|
||||
RETURN s.id as statement_id,
|
||||
@@ -159,8 +160,8 @@ class ForgettingStrategy:
|
||||
'threshold': self.forgetting_threshold,
|
||||
'cutoff_time': cutoff_time_iso
|
||||
}
|
||||
if group_id:
|
||||
params['group_id'] = group_id
|
||||
if end_user_id:
|
||||
params['end_user_id'] = end_user_id
|
||||
|
||||
results = await self.connector.execute_query(query, **params)
|
||||
|
||||
@@ -176,7 +177,7 @@ class ForgettingStrategy:
|
||||
self,
|
||||
statement_node: Dict[str, Any],
|
||||
entity_node: Dict[str, Any],
|
||||
config_id: Optional[int] = None,
|
||||
config_id: Optional[UUID] = None,
|
||||
db = None
|
||||
) -> str:
|
||||
"""
|
||||
@@ -247,8 +248,8 @@ class ForgettingStrategy:
|
||||
entity_activation = entity_node['entity_activation']
|
||||
entity_importance = entity_node['entity_importance']
|
||||
|
||||
# 获取 group_id(从 statement 或 entity 节点)
|
||||
group_id = statement_node.get('group_id') or entity_node.get('group_id')
|
||||
# 获取 end_user_id(从 statement 或 entity 节点)
|
||||
end_user_id = statement_node.get('end_user_id') or entity_node.get('end_user_id')
|
||||
|
||||
# 生成摘要内容
|
||||
summary_text = await self._generate_summary(
|
||||
@@ -325,7 +326,7 @@ class ForgettingStrategy:
|
||||
last_access_time: $current_time,
|
||||
access_count: 1,
|
||||
version: 1,
|
||||
group_id: $group_id,
|
||||
end_user_id: $end_user_id,
|
||||
created_at: datetime($current_time),
|
||||
merged_at: datetime($current_time)
|
||||
})
|
||||
@@ -423,7 +424,7 @@ class ForgettingStrategy:
|
||||
'inherited_activation': inherited_activation,
|
||||
'inherited_importance': inherited_importance,
|
||||
'current_time': current_time_iso,
|
||||
'group_id': group_id
|
||||
'end_user_id': end_user_id
|
||||
}
|
||||
|
||||
try:
|
||||
@@ -462,7 +463,7 @@ class ForgettingStrategy:
|
||||
statement_text: str,
|
||||
entity_name: str,
|
||||
entity_type: str,
|
||||
config_id: Optional[int] = None,
|
||||
config_id: Optional[UUID] = None,
|
||||
db = None
|
||||
) -> str:
|
||||
"""
|
||||
@@ -527,7 +528,7 @@ class ForgettingStrategy:
|
||||
statement_text, entity_name, entity_type
|
||||
)
|
||||
|
||||
async def _get_llm_client(self, db, config_id: int):
|
||||
async def _get_llm_client(self, db, config_id: UUID):
|
||||
"""
|
||||
从数据库获取 LLM 客户端
|
||||
|
||||
@@ -539,11 +540,11 @@ class ForgettingStrategy:
|
||||
LLM 客户端实例,如果无法获取则返回 None
|
||||
"""
|
||||
try:
|
||||
from app.repositories.data_config_repository import DataConfigRepository
|
||||
from app.repositories.memory_config_repository import MemoryConfigRepository
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
|
||||
# 从数据库读取配置
|
||||
repository = DataConfigRepository()
|
||||
repository = MemoryConfigRepository()
|
||||
db_config = repository.get_by_id(db, config_id)
|
||||
|
||||
if db_config is None or db_config.llm_id is None:
|
||||
|
||||
@@ -37,7 +37,7 @@ __all__ = [
|
||||
async def run_hybrid_search(
|
||||
query_text: str,
|
||||
search_type: str = "hybrid",
|
||||
group_id: str | None = None,
|
||||
end_user_id: str | None = None,
|
||||
apply_id: str | None = None,
|
||||
user_id: str | None = None,
|
||||
limit: int = 50,
|
||||
@@ -54,7 +54,7 @@ async def run_hybrid_search(
|
||||
Args:
|
||||
query_text: 查询文本
|
||||
search_type: 搜索类型("hybrid", "keyword", "semantic")
|
||||
group_id: 组ID过滤
|
||||
end_user_id: 组ID过滤
|
||||
apply_id: 应用ID过滤
|
||||
user_id: 用户ID过滤
|
||||
limit: 每个类别的最大结果数
|
||||
@@ -104,7 +104,7 @@ async def run_hybrid_search(
|
||||
# 执行搜索
|
||||
result = await strategy.search(
|
||||
query_text=query_text,
|
||||
group_id=group_id,
|
||||
end_user_id=end_user_id,
|
||||
limit=limit,
|
||||
include=include,
|
||||
alpha=alpha,
|
||||
|
||||
@@ -77,7 +77,7 @@
|
||||
# async def search(
|
||||
# self,
|
||||
# query_text: str,
|
||||
# group_id: Optional[str] = None,
|
||||
# end_user_id: Optional[str] = None,
|
||||
# limit: int = 50,
|
||||
# include: Optional[List[str]] = None,
|
||||
# **kwargs
|
||||
@@ -86,7 +86,7 @@
|
||||
|
||||
# Args:
|
||||
# query_text: 查询文本
|
||||
# group_id: 可选的组ID过滤
|
||||
# end_user_id: 可选的组ID过滤
|
||||
# limit: 每个类别的最大结果数
|
||||
# include: 要包含的搜索类别列表
|
||||
# **kwargs: 其他搜索参数(如alpha, use_forgetting_curve)
|
||||
@@ -94,7 +94,7 @@
|
||||
# Returns:
|
||||
# SearchResult: 搜索结果对象
|
||||
# """
|
||||
# logger.info(f"执行混合搜索: query='{query_text}', group_id={group_id}, limit={limit}")
|
||||
# logger.info(f"执行混合搜索: query='{query_text}', end_user_id={end_user_id}, limit={limit}")
|
||||
|
||||
# # 从kwargs中获取参数
|
||||
# alpha = kwargs.get("alpha", self.alpha)
|
||||
@@ -107,14 +107,14 @@
|
||||
# # 并行执行关键词搜索和语义搜索
|
||||
# keyword_result = await self.keyword_strategy.search(
|
||||
# query_text=query_text,
|
||||
# group_id=group_id,
|
||||
# end_user_id=end_user_id,
|
||||
# limit=limit,
|
||||
# include=include_list
|
||||
# )
|
||||
|
||||
# semantic_result = await self.semantic_strategy.search(
|
||||
# query_text=query_text,
|
||||
# group_id=group_id,
|
||||
# end_user_id=end_user_id,
|
||||
# limit=limit,
|
||||
# include=include_list
|
||||
# )
|
||||
@@ -139,7 +139,7 @@
|
||||
# metadata = self._create_metadata(
|
||||
# query_text=query_text,
|
||||
# search_type="hybrid",
|
||||
# group_id=group_id,
|
||||
# end_user_id=end_user_id,
|
||||
# limit=limit,
|
||||
# include=include_list,
|
||||
# alpha=alpha,
|
||||
@@ -165,7 +165,7 @@
|
||||
# metadata=self._create_metadata(
|
||||
# query_text=query_text,
|
||||
# search_type="hybrid",
|
||||
# group_id=group_id,
|
||||
# end_user_id=end_user_id,
|
||||
# limit=limit,
|
||||
# error=str(e)
|
||||
# )
|
||||
|
||||
@@ -44,7 +44,7 @@ class KeywordSearchStrategy(SearchStrategy):
|
||||
async def search(
|
||||
self,
|
||||
query_text: str,
|
||||
group_id: Optional[str] = None,
|
||||
end_user_id: Optional[str] = None,
|
||||
limit: int = 50,
|
||||
include: Optional[List[str]] = None,
|
||||
**kwargs
|
||||
@@ -53,7 +53,7 @@ class KeywordSearchStrategy(SearchStrategy):
|
||||
|
||||
Args:
|
||||
query_text: 查询文本
|
||||
group_id: 可选的组ID过滤
|
||||
end_user_id: 可选的组ID过滤
|
||||
limit: 每个类别的最大结果数
|
||||
include: 要包含的搜索类别列表
|
||||
**kwargs: 其他搜索参数
|
||||
@@ -61,7 +61,7 @@ class KeywordSearchStrategy(SearchStrategy):
|
||||
Returns:
|
||||
SearchResult: 搜索结果对象
|
||||
"""
|
||||
logger.info(f"执行关键词搜索: query='{query_text}', group_id={group_id}, limit={limit}")
|
||||
logger.info(f"执行关键词搜索: query='{query_text}', end_user_id={end_user_id}, limit={limit}")
|
||||
|
||||
# 获取有效的搜索类别
|
||||
include_list = self._get_include_list(include)
|
||||
@@ -75,7 +75,7 @@ class KeywordSearchStrategy(SearchStrategy):
|
||||
results_dict = await search_graph(
|
||||
connector=self.connector,
|
||||
q=query_text,
|
||||
group_id=group_id,
|
||||
end_user_id=end_user_id,
|
||||
limit=limit,
|
||||
include=include_list
|
||||
)
|
||||
@@ -84,7 +84,7 @@ class KeywordSearchStrategy(SearchStrategy):
|
||||
metadata = self._create_metadata(
|
||||
query_text=query_text,
|
||||
search_type="keyword",
|
||||
group_id=group_id,
|
||||
end_user_id=end_user_id,
|
||||
limit=limit,
|
||||
include=include_list
|
||||
)
|
||||
@@ -115,7 +115,7 @@ class KeywordSearchStrategy(SearchStrategy):
|
||||
metadata=self._create_metadata(
|
||||
query_text=query_text,
|
||||
search_type="keyword",
|
||||
group_id=group_id,
|
||||
end_user_id=end_user_id,
|
||||
limit=limit,
|
||||
error=str(e)
|
||||
)
|
||||
|
||||
@@ -58,7 +58,7 @@ class SearchStrategy(ABC):
|
||||
async def search(
|
||||
self,
|
||||
query_text: str,
|
||||
group_id: Optional[str] = None,
|
||||
end_user_id: Optional[str] = None,
|
||||
limit: int = 50,
|
||||
include: Optional[List[str]] = None,
|
||||
**kwargs
|
||||
@@ -67,7 +67,7 @@ class SearchStrategy(ABC):
|
||||
|
||||
Args:
|
||||
query_text: 查询文本
|
||||
group_id: 可选的组ID过滤
|
||||
end_user_id: 可选的组ID过滤
|
||||
limit: 每个类别的最大结果数
|
||||
include: 要包含的搜索类别列表(statements, chunks, entities, summaries)
|
||||
**kwargs: 其他搜索参数
|
||||
@@ -81,7 +81,7 @@ class SearchStrategy(ABC):
|
||||
self,
|
||||
query_text: str,
|
||||
search_type: str,
|
||||
group_id: Optional[str] = None,
|
||||
end_user_id: Optional[str] = None,
|
||||
limit: int = 50,
|
||||
**kwargs
|
||||
) -> Dict[str, Any]:
|
||||
@@ -90,7 +90,7 @@ class SearchStrategy(ABC):
|
||||
Args:
|
||||
query_text: 查询文本
|
||||
search_type: 搜索类型
|
||||
group_id: 组ID
|
||||
end_user_id: 组ID
|
||||
limit: 结果限制
|
||||
**kwargs: 其他元数据
|
||||
|
||||
@@ -100,7 +100,7 @@ class SearchStrategy(ABC):
|
||||
metadata = {
|
||||
"query": query_text,
|
||||
"search_type": search_type,
|
||||
"group_id": group_id,
|
||||
"end_user_id": end_user_id,
|
||||
"limit": limit,
|
||||
"timestamp": datetime.now().isoformat()
|
||||
}
|
||||
|
||||
@@ -85,7 +85,7 @@ class SemanticSearchStrategy(SearchStrategy):
|
||||
async def search(
|
||||
self,
|
||||
query_text: str,
|
||||
group_id: Optional[str] = None,
|
||||
end_user_id: Optional[str] = None,
|
||||
limit: int = 50,
|
||||
include: Optional[List[str]] = None,
|
||||
**kwargs
|
||||
@@ -94,7 +94,7 @@ class SemanticSearchStrategy(SearchStrategy):
|
||||
|
||||
Args:
|
||||
query_text: 查询文本
|
||||
group_id: 可选的组ID过滤
|
||||
end_user_id: 可选的组ID过滤
|
||||
limit: 每个类别的最大结果数
|
||||
include: 要包含的搜索类别列表
|
||||
**kwargs: 其他搜索参数
|
||||
@@ -102,7 +102,7 @@ class SemanticSearchStrategy(SearchStrategy):
|
||||
Returns:
|
||||
SearchResult: 搜索结果对象
|
||||
"""
|
||||
logger.info(f"执行语义搜索: query='{query_text}', group_id={group_id}, limit={limit}")
|
||||
logger.info(f"执行语义搜索: query='{query_text}', end_user_id={end_user_id}, limit={limit}")
|
||||
|
||||
# 获取有效的搜索类别
|
||||
include_list = self._get_include_list(include)
|
||||
@@ -119,7 +119,7 @@ class SemanticSearchStrategy(SearchStrategy):
|
||||
connector=self.connector,
|
||||
embedder_client=self.embedder_client,
|
||||
query_text=query_text,
|
||||
group_id=group_id,
|
||||
end_user_id=end_user_id,
|
||||
limit=limit,
|
||||
include=include_list
|
||||
)
|
||||
@@ -128,7 +128,7 @@ class SemanticSearchStrategy(SearchStrategy):
|
||||
metadata = self._create_metadata(
|
||||
query_text=query_text,
|
||||
search_type="semantic",
|
||||
group_id=group_id,
|
||||
end_user_id=end_user_id,
|
||||
limit=limit,
|
||||
include=include_list
|
||||
)
|
||||
@@ -159,7 +159,7 @@ class SemanticSearchStrategy(SearchStrategy):
|
||||
metadata=self._create_metadata(
|
||||
query_text=query_text,
|
||||
search_type="semantic",
|
||||
group_id=group_id,
|
||||
end_user_id=end_user_id,
|
||||
limit=limit,
|
||||
error=str(e)
|
||||
)
|
||||
|
||||
@@ -23,7 +23,7 @@ async def _load_(data: List[Any]) -> List[Dict]:
|
||||
target_keys = [
|
||||
"id",
|
||||
"statement",
|
||||
"group_id",
|
||||
"end_user_id",
|
||||
"chunk_id",
|
||||
"created_at",
|
||||
"expired_at",
|
||||
@@ -75,7 +75,7 @@ async def get_data(result):
|
||||
"""
|
||||
EXCLUDE_FIELDS = {
|
||||
"user_id",
|
||||
"group_id",
|
||||
"end_user_id",
|
||||
"entity_type",
|
||||
"connect_strength",
|
||||
"relationship_type",
|
||||
|
||||
@@ -62,7 +62,7 @@ class ConfigAuditLogger:
|
||||
self,
|
||||
config_id: str,
|
||||
user_id: Optional[str] = None,
|
||||
group_id: Optional[str] = None,
|
||||
end_user_id: Optional[str] = None,
|
||||
success: bool = True,
|
||||
details: Optional[Dict[str, Any]] = None
|
||||
):
|
||||
@@ -72,14 +72,14 @@ class ConfigAuditLogger:
|
||||
Args:
|
||||
config_id: 配置 ID
|
||||
user_id: 用户 ID(可选)
|
||||
group_id: 组 ID(可选)
|
||||
end_user_id: 组 ID(可选)
|
||||
success: 是否成功
|
||||
details: 详细信息(可选)
|
||||
"""
|
||||
result = "SUCCESS" if success else "FAILED"
|
||||
msg = (
|
||||
f"CONFIG_LOAD config_id={config_id} "
|
||||
f"user={user_id or 'N/A'} group={group_id or 'N/A'} "
|
||||
f"user={user_id or 'N/A'} group={end_user_id or 'N/A'} "
|
||||
f"result={result}"
|
||||
)
|
||||
if details:
|
||||
@@ -121,7 +121,7 @@ class ConfigAuditLogger:
|
||||
self,
|
||||
operation: str,
|
||||
config_id: str,
|
||||
group_id: str,
|
||||
end_user_id: str,
|
||||
success: bool = True,
|
||||
duration: Optional[float] = None,
|
||||
error: Optional[str] = None,
|
||||
@@ -133,7 +133,7 @@ class ConfigAuditLogger:
|
||||
Args:
|
||||
operation: 操作类型(WRITE, READ 等)
|
||||
config_id: 配置 ID
|
||||
group_id: 组 ID
|
||||
end_user_id: 组 ID
|
||||
success: 是否成功
|
||||
duration: 操作耗时(秒)
|
||||
error: 错误信息(可选)
|
||||
@@ -142,7 +142,7 @@ class ConfigAuditLogger:
|
||||
result = "SUCCESS" if success else "FAILED"
|
||||
msg = (
|
||||
f"{operation.upper()} config_id={config_id} "
|
||||
f"group={group_id} result={result}"
|
||||
f"group={end_user_id} result={result}"
|
||||
)
|
||||
if duration is not None:
|
||||
msg += f" duration={duration:.2f}s"
|
||||
|
||||
@@ -4,7 +4,7 @@ from enum import StrEnum, auto
|
||||
class Field(StrEnum):
|
||||
CONTENT_KEY = "page_content"
|
||||
METADATA_KEY = "metadata"
|
||||
GROUP_KEY = "group_id"
|
||||
GROUP_KEY = "end_user_id"
|
||||
VECTOR = auto()
|
||||
# Sparse Vector aims to support full text search
|
||||
SPARSE_VECTOR = auto()
|
||||
|
||||
@@ -26,7 +26,7 @@ logger = get_config_logger()
|
||||
|
||||
|
||||
def _parse_model_id(model_id: Union[str, UUID, None], model_type: str,
|
||||
config_id: Optional[int] = None, workspace_id: Optional[UUID] = None) -> Optional[UUID]:
|
||||
config_id: Optional[UUID] = None, workspace_id: Optional[UUID] = None) -> Optional[UUID]:
|
||||
"""Parse model ID from string or UUID."""
|
||||
if model_id is None:
|
||||
return None
|
||||
@@ -59,7 +59,7 @@ def validate_model_exists_and_active(
|
||||
model_type: str,
|
||||
db: Session,
|
||||
tenant_id: Optional[UUID] = None,
|
||||
config_id: Optional[int] = None,
|
||||
config_id: Optional[UUID] = None,
|
||||
workspace_id: Optional[UUID] = None
|
||||
) -> tuple[str, bool]:
|
||||
"""Validate that a model exists and is active.
|
||||
@@ -166,7 +166,7 @@ def validate_and_resolve_model_id(
|
||||
db: Session,
|
||||
tenant_id: Optional[UUID] = None,
|
||||
required: bool = False,
|
||||
config_id: Optional[int] = None,
|
||||
config_id: Optional[UUID] = None,
|
||||
workspace_id: Optional[UUID] = None
|
||||
) -> tuple[Optional[UUID], Optional[str]]:
|
||||
"""Validate and resolve a model ID, checking existence and active status.
|
||||
@@ -204,7 +204,7 @@ def validate_and_resolve_model_id(
|
||||
|
||||
|
||||
def validate_embedding_model(
|
||||
config_id: int,
|
||||
config_id: UUID,
|
||||
embedding_id: Union[str, UUID, None],
|
||||
db: Session,
|
||||
tenant_id: Optional[UUID] = None,
|
||||
@@ -256,7 +256,7 @@ def validate_embedding_model(
|
||||
|
||||
|
||||
def validate_llm_model(
|
||||
config_id: int,
|
||||
config_id: UUID,
|
||||
llm_id: Union[str, UUID, None],
|
||||
db: Session,
|
||||
tenant_id: Optional[UUID] = None,
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import uuid
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import Field
|
||||
from typing import Literal
|
||||
@@ -11,7 +12,7 @@ class MemoryReadNodeConfig(BaseNodeConfig):
|
||||
...
|
||||
)
|
||||
|
||||
config_id: int = Field(
|
||||
config_id: UUID = Field(
|
||||
...
|
||||
)
|
||||
|
||||
@@ -26,6 +27,6 @@ class MemoryWriteNodeConfig(BaseNodeConfig):
|
||||
...
|
||||
)
|
||||
|
||||
config_id: int = Field(
|
||||
config_id: UUID = Field(
|
||||
...
|
||||
)
|
||||
|
||||
@@ -22,7 +22,7 @@ class MemoryReadNode(BaseNode):
|
||||
raise RuntimeError("End user id is required")
|
||||
|
||||
return await MemoryAgentService().read_memory(
|
||||
group_id=end_user_id,
|
||||
end_user_id=end_user_id,
|
||||
message=self._render_template(self.typed_config.message, state),
|
||||
config_id=str(self.typed_config.config_id),
|
||||
search_switch=self.typed_config.search_switch,
|
||||
|
||||
@@ -18,7 +18,7 @@ from .appshare_model import AppShare
|
||||
from .release_share_model import ReleaseShare
|
||||
from .conversation_model import Conversation, Message
|
||||
from .api_key_model import ApiKey, ApiKeyLog, ApiKeyType
|
||||
from .data_config_model import DataConfig
|
||||
from .memory_config_model import MemoryConfig
|
||||
from .multi_agent_model import MultiAgentConfig, AgentInvocation
|
||||
from .workflow_model import WorkflowConfig, WorkflowExecution, WorkflowNodeExecution
|
||||
from .retrieval_info import RetrievalInfo
|
||||
@@ -57,7 +57,7 @@ __all__ = [
|
||||
"ApiKey",
|
||||
"ApiKeyLog",
|
||||
"ApiKeyType",
|
||||
"DataConfig",
|
||||
"MemoryConfig",
|
||||
"MultiAgentConfig",
|
||||
"AgentInvocation",
|
||||
"WorkflowConfig",
|
||||
|
||||
@@ -1,88 +0,0 @@
|
||||
import datetime
|
||||
from sqlalchemy import Column, String, Boolean, DateTime, Integer, Float
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from app.db import Base
|
||||
|
||||
|
||||
class DataConfig(Base):
|
||||
"""数据配置表 - 用于存储记忆系统的配置参数"""
|
||||
__tablename__ = "data_config"
|
||||
|
||||
# 主键
|
||||
config_id = Column(Integer, primary_key=True, autoincrement=True, comment="配置ID")
|
||||
|
||||
# 基本信息
|
||||
config_name = Column(String, nullable=False, comment="配置名称")
|
||||
config_desc = Column(String, nullable=True, comment="配置描述")
|
||||
|
||||
# 组织信息
|
||||
workspace_id = Column(UUID(as_uuid=True), nullable=True, comment="工作空间ID")
|
||||
group_id = Column(String, nullable=True, comment="组ID")
|
||||
user_id = Column(String, nullable=True, comment="用户ID")
|
||||
apply_id = Column(String, nullable=True, comment="应用ID")
|
||||
|
||||
# 模型选择(从workspace继承)
|
||||
llm_id = Column(String, nullable=True, comment="LLM模型配置ID")
|
||||
embedding_id = Column(String, nullable=True, comment="嵌入模型配置ID")
|
||||
rerank_id = Column(String, nullable=True, comment="重排序模型配置ID")
|
||||
|
||||
# 记忆萃取引擎配置
|
||||
enable_llm_dedup_blockwise = Column(Boolean, default=True, comment="启用LLM决策去重")
|
||||
enable_llm_disambiguation = Column(Boolean, default=True, comment="启用LLM决策消歧")
|
||||
deep_retrieval = Column(Boolean, default=True, comment="深度检索开关")
|
||||
|
||||
# 阈值配置 (0-1 之间的浮点数)
|
||||
t_type_strict = Column(Float, default=0.8, comment="类型严格阈值")
|
||||
t_name_strict = Column(Float, default=0.8, comment="名称严格阈值")
|
||||
t_overall = Column(Float, default=0.8, comment="综合阈值")
|
||||
|
||||
# 状态配置
|
||||
state = Column(Boolean, default=False, comment="配置使用状态")
|
||||
|
||||
# 分块策略
|
||||
chunker_strategy = Column(String, default="RecursiveChunker", comment="分块策略")
|
||||
|
||||
# 剪枝配置
|
||||
pruning_enabled = Column(Boolean, default=False, comment="是否启动智能语义剪枝")
|
||||
pruning_scene = Column(String, nullable=True, comment="智能剪枝场景:education/online_service/outbound")
|
||||
pruning_threshold = Column(Float, nullable=True, comment="智能语义剪枝阈值(0-0.9)")
|
||||
|
||||
# 自我反思配置
|
||||
enable_self_reflexion = Column(Boolean, default=False, comment="是否启用自我反思")
|
||||
iteration_period = Column(String, default="3", comment="反思迭代周期")
|
||||
reflexion_range = Column(String, default="partial", comment="反思范围:部分/全部")
|
||||
baseline = Column(String, default="TIME", comment="基线:时间/事实/时间和事实")
|
||||
reflection_model_id = Column(String, nullable=True, comment="反思模型ID")
|
||||
memory_verify = Column(Boolean, default=True, comment="记忆验证")
|
||||
quality_assessment = Column(Boolean, default=True, comment="质量评估")
|
||||
|
||||
# 遗忘引擎配置
|
||||
statement_granularity = Column(Integer, default=2, comment="陈述提取颗粒度,挡位 1/2/3")
|
||||
include_dialogue_context = Column(Boolean, default=False, comment="是否包含对话上下文")
|
||||
max_context = Column(Integer, default=1000, comment="对话语境中包含字符的最大数量")
|
||||
lambda_time = Column("lambda_time", Float, default=0.5, comment="最低保持度,0-1 小数")
|
||||
lambda_mem = Column("lambda_mem", Float, default=0.5, comment="遗忘率,0-1 小数")
|
||||
offset = Column("offset", Float, default=0.0, comment="偏移度,0-1 小数")
|
||||
|
||||
# ACT-R 遗忘引擎配置
|
||||
decay_constant = Column(Float, default=0.5, comment="ACT-R衰减常数d,默认0.5")
|
||||
forgetting_threshold = Column(Float, default=0.3, comment="遗忘阈值,默认0.3")
|
||||
forgetting_interval_hours = Column(Integer, default=24, comment="遗忘周期间隔(小时),默认24")
|
||||
enable_llm_summary = Column(Boolean, default=True, comment="是否使用LLM生成摘要,默认True")
|
||||
max_merge_batch_size = Column(Integer, default=100, comment="单次最大融合节点对数,默认100")
|
||||
max_history_length = Column(Integer, default=100, comment="访问历史最大长度,默认100")
|
||||
min_days_since_access = Column(Integer, default=30, comment="最小未访问天数,默认30")
|
||||
|
||||
# 情绪引擎配置
|
||||
emotion_enabled = Column(Boolean, default=True, comment="是否启用情绪提取")
|
||||
emotion_model_id = Column(String, nullable=True, comment="情绪分析专用模型ID")
|
||||
emotion_extract_keywords = Column(Boolean, default=True, comment="是否提取情绪关键词")
|
||||
emotion_min_intensity = Column(Float, default=0.1, comment="最小情绪强度阈值")
|
||||
emotion_enable_subject = Column(Boolean, default=True, comment="是否启用主体分类")
|
||||
|
||||
# 时间戳
|
||||
created_at = Column(DateTime, default=datetime.datetime.now, comment="创建时间")
|
||||
updated_at = Column(DateTime, default=datetime.datetime.now, onupdate=datetime.datetime.now, comment="更新时间")
|
||||
|
||||
def __repr__(self):
|
||||
return f"<DataConfig(config_id={self.config_id}, config_name={self.config_name})>"
|
||||
@@ -1,39 +1,88 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Memory Configuration Model - Backward Compatibility
|
||||
import datetime
|
||||
from sqlalchemy import Column, String, Boolean, DateTime, Integer, Float
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from app.db import Base
|
||||
|
||||
This module provides backward compatibility for imports.
|
||||
All classes have been moved to app.schemas.memory_config_schema.
|
||||
|
||||
DEPRECATED: Import from app.schemas.memory_config_schema instead.
|
||||
"""
|
||||
class MemoryConfig(Base):
|
||||
"""记忆配置表 - 用于存储记忆系统的配置参数"""
|
||||
__tablename__ = "memory_config"
|
||||
|
||||
# Re-export for backward compatibility
|
||||
from app.schemas.memory_config_schema import (
|
||||
ConfigurationError,
|
||||
InvalidConfigError,
|
||||
MemoryConfig,
|
||||
MemoryConfigValidation,
|
||||
ModelInactiveError,
|
||||
ModelNotFoundError,
|
||||
ModelValidation,
|
||||
WorkspaceNotFoundError,
|
||||
WorkspaceValidation,
|
||||
validate_memory_config_data,
|
||||
validate_model_data,
|
||||
validate_workspace_data,
|
||||
)
|
||||
# 主键
|
||||
config_id = Column(UUID(as_uuid=True), primary_key=True, comment="配置ID")
|
||||
|
||||
__all__ = [
|
||||
"ConfigurationError",
|
||||
"InvalidConfigError",
|
||||
"MemoryConfig",
|
||||
"MemoryConfigValidation",
|
||||
"ModelInactiveError",
|
||||
"ModelNotFoundError",
|
||||
"ModelValidation",
|
||||
"WorkspaceNotFoundError",
|
||||
"WorkspaceValidation",
|
||||
"validate_memory_config_data",
|
||||
"validate_model_data",
|
||||
"validate_workspace_data",
|
||||
]
|
||||
# 基本信息
|
||||
config_name = Column(String, nullable=False, comment="配置名称")
|
||||
config_desc = Column(String, nullable=True, comment="配置描述")
|
||||
|
||||
# 组织信息
|
||||
workspace_id = Column(UUID(as_uuid=True), nullable=True, comment="工作空间ID")
|
||||
end_user_id = Column(String, nullable=True, comment="组ID")
|
||||
user_id = Column(String, nullable=True, comment="用户ID")
|
||||
apply_id = Column(String, nullable=True, comment="应用ID")
|
||||
|
||||
# 模型选择(从workspace继承)
|
||||
llm_id = Column(String, nullable=True, comment="LLM模型配置ID")
|
||||
embedding_id = Column(String, nullable=True, comment="嵌入模型配置ID")
|
||||
rerank_id = Column(String, nullable=True, comment="重排序模型配置ID")
|
||||
|
||||
# 记忆萃取引擎配置
|
||||
enable_llm_dedup_blockwise = Column(Boolean, default=True, comment="启用LLM决策去重")
|
||||
enable_llm_disambiguation = Column(Boolean, default=True, comment="启用LLM决策消歧")
|
||||
deep_retrieval = Column(Boolean, default=True, comment="深度检索开关")
|
||||
|
||||
# 阈值配置 (0-1 之间的浮点数)
|
||||
t_type_strict = Column(Float, default=0.8, comment="类型严格阈值")
|
||||
t_name_strict = Column(Float, default=0.8, comment="名称严格阈值")
|
||||
t_overall = Column(Float, default=0.8, comment="综合阈值")
|
||||
|
||||
# 状态配置
|
||||
state = Column(Boolean, default=False, comment="配置使用状态")
|
||||
|
||||
# 分块策略
|
||||
chunker_strategy = Column(String, default="RecursiveChunker", comment="分块策略")
|
||||
|
||||
# 剪枝配置
|
||||
pruning_enabled = Column(Boolean, default=False, comment="是否启动智能语义剪枝")
|
||||
pruning_scene = Column(String, nullable=True, comment="智能剪枝场景:education/online_service/outbound")
|
||||
pruning_threshold = Column(Float, nullable=True, comment="智能语义剪枝阈值(0-0.9)")
|
||||
|
||||
# 自我反思配置
|
||||
enable_self_reflexion = Column(Boolean, default=False, comment="是否启用自我反思")
|
||||
iteration_period = Column(String, default="3", comment="反思迭代周期")
|
||||
reflexion_range = Column(String, default="partial", comment="反思范围:部分/全部")
|
||||
baseline = Column(String, default="TIME", comment="基线:时间/事实/时间和事实")
|
||||
reflection_model_id = Column(String, nullable=True, comment="反思模型ID")
|
||||
memory_verify = Column(Boolean, default=True, comment="记忆验证")
|
||||
quality_assessment = Column(Boolean, default=True, comment="质量评估")
|
||||
|
||||
# 遗忘引擎配置
|
||||
statement_granularity = Column(Integer, default=2, comment="陈述提取颗粒度,挡位 1/2/3")
|
||||
include_dialogue_context = Column(Boolean, default=False, comment="是否包含对话上下文")
|
||||
max_context = Column(Integer, default=1000, comment="对话语境中包含字符的最大数量")
|
||||
lambda_time = Column("lambda_time", Float, default=0.5, comment="最低保持度,0-1 小数")
|
||||
lambda_mem = Column("lambda_mem", Float, default=0.5, comment="遗忘率,0-1 小数")
|
||||
offset = Column("offset", Float, default=0.0, comment="偏移度,0-1 小数")
|
||||
|
||||
# ACT-R 遗忘引擎配置
|
||||
decay_constant = Column(Float, default=0.5, comment="ACT-R衰减常数d,默认0.5")
|
||||
forgetting_threshold = Column(Float, default=0.3, comment="遗忘阈值,默认0.3")
|
||||
forgetting_interval_hours = Column(Integer, default=24, comment="遗忘周期间隔(小时),默认24")
|
||||
enable_llm_summary = Column(Boolean, default=True, comment="是否使用LLM生成摘要,默认True")
|
||||
max_merge_batch_size = Column(Integer, default=100, comment="单次最大融合节点对数,默认100")
|
||||
max_history_length = Column(Integer, default=100, comment="访问历史最大长度,默认100")
|
||||
min_days_since_access = Column(Integer, default=30, comment="最小未访问天数,默认30")
|
||||
|
||||
# 情绪引擎配置
|
||||
emotion_enabled = Column(Boolean, default=True, comment="是否启用情绪提取")
|
||||
emotion_model_id = Column(String, nullable=True, comment="情绪分析专用模型ID")
|
||||
emotion_extract_keywords = Column(Boolean, default=True, comment="是否提取情绪关键词")
|
||||
emotion_min_intensity = Column(Float, default=0.1, comment="最小情绪强度阈值")
|
||||
emotion_enable_subject = Column(Boolean, default=True, comment="是否启用主体分类")
|
||||
|
||||
# 时间戳
|
||||
created_at = Column(DateTime, default=datetime.datetime.now, comment="创建时间")
|
||||
updated_at = Column(DateTime, default=datetime.datetime.now, onupdate=datetime.datetime.now, comment="更新时间")
|
||||
|
||||
def __repr__(self):
|
||||
return f"<MemoryConfig(config_id={self.config_id}, config_name={self.config_name})>"
|
||||
|
||||
@@ -16,7 +16,7 @@ class PerceptualType(IntEnum):
|
||||
CONVERSATION = 4
|
||||
|
||||
|
||||
class FileStorageType(IntEnum):
|
||||
class FileStorageService(IntEnum):
|
||||
LOCAL = 1
|
||||
REMOTE = 2
|
||||
|
||||
|
||||
@@ -1,18 +1,19 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""数据配置Repository模块
|
||||
"""记忆配置Repository模块
|
||||
|
||||
本模块提供data_config表的数据访问层,使用SQLAlchemy ORM进行数据库操作。
|
||||
本模块提供memory_config表的数据访问层,使用SQLAlchemy ORM进行数据库操作。
|
||||
包括CRUD操作和Neo4j Cypher查询常量。
|
||||
|
||||
Classes:
|
||||
DataConfigRepository: 数据配置仓储类,提供CRUD操作
|
||||
MemoryConfigRepository: 记忆配置仓储类,提供CRUD操作
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from uuid import UUID
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
from app.core.exceptions import BusinessException
|
||||
from app.core.logging_config import get_config_logger, get_db_logger
|
||||
from app.models.data_config_model import DataConfig
|
||||
from app.models.memory_config_model import MemoryConfig
|
||||
from app.schemas.memory_storage_schema import (
|
||||
ConfigKey,
|
||||
ConfigParamsCreate,
|
||||
@@ -28,11 +29,11 @@ db_logger = get_db_logger()
|
||||
# 获取配置专用日志器
|
||||
config_logger = get_config_logger()
|
||||
|
||||
TABLE_NAME = "data_config"
|
||||
class DataConfigRepository:
|
||||
"""数据配置Repository
|
||||
TABLE_NAME = "memory_config"
|
||||
class MemoryConfigRepository:
|
||||
"""记忆配置Repository
|
||||
|
||||
提供data_config表的数据访问方法,包括:
|
||||
提供memory_config表的数据访问方法,包括:
|
||||
- SQLAlchemy ORM 数据库操作
|
||||
- Neo4j Cypher查询常量
|
||||
"""
|
||||
@@ -41,48 +42,48 @@ class DataConfigRepository:
|
||||
|
||||
# Dialogue count by group
|
||||
SEARCH_FOR_DIALOGUE = """
|
||||
MATCH (n:Dialogue) WHERE n.group_id = $group_id RETURN COUNT(n) AS num
|
||||
MATCH (n:Dialogue) WHERE n.end_user_id = $end_user_id RETURN COUNT(n) AS num
|
||||
"""
|
||||
|
||||
# Chunk count by group
|
||||
SEARCH_FOR_CHUNK = """
|
||||
MATCH (n:Chunk) WHERE n.group_id = $group_id RETURN COUNT(n) AS num
|
||||
MATCH (n:Chunk) WHERE n.end_user_id = $end_user_id RETURN COUNT(n) AS num
|
||||
"""
|
||||
|
||||
# Statement count by group
|
||||
SEARCH_FOR_STATEMENT = """
|
||||
MATCH (n:Statement) WHERE n.group_id = $group_id RETURN COUNT(n) AS num
|
||||
MATCH (n:Statement) WHERE n.end_user_id = $end_user_id RETURN COUNT(n) AS num
|
||||
"""
|
||||
|
||||
# ExtractedEntity count by group
|
||||
SEARCH_FOR_ENTITY = """
|
||||
MATCH (n:ExtractedEntity) WHERE n.group_id = $group_id RETURN COUNT(n) AS num
|
||||
MATCH (n:ExtractedEntity) WHERE n.end_user_id = $end_user_id RETURN COUNT(n) AS num
|
||||
"""
|
||||
|
||||
# All counts by label and total
|
||||
SEARCH_FOR_ALL = """
|
||||
OPTIONAL MATCH (n:Dialogue) WHERE n.group_id = $group_id RETURN 'Dialogue' AS Label, COUNT(n) AS Count
|
||||
OPTIONAL MATCH (n:Dialogue) WHERE n.end_user_id = $end_user_id RETURN 'Dialogue' AS Label, COUNT(n) AS Count
|
||||
UNION ALL
|
||||
OPTIONAL MATCH (n:Chunk) WHERE n.group_id = $group_id RETURN 'Chunk' AS Label, COUNT(n) AS Count
|
||||
OPTIONAL MATCH (n:Chunk) WHERE n.end_user_id = $end_user_id RETURN 'Chunk' AS Label, COUNT(n) AS Count
|
||||
UNION ALL
|
||||
OPTIONAL MATCH (n:Statement) WHERE n.group_id = $group_id RETURN 'Statement' AS Label, COUNT(n) AS Count
|
||||
OPTIONAL MATCH (n:Statement) WHERE n.end_user_id = $end_user_id RETURN 'Statement' AS Label, COUNT(n) AS Count
|
||||
UNION ALL
|
||||
OPTIONAL MATCH (n:ExtractedEntity) WHERE n.group_id = $group_id RETURN 'ExtractedEntity' AS Label, COUNT(n) AS Count
|
||||
OPTIONAL MATCH (n:ExtractedEntity) WHERE n.end_user_id = $end_user_id RETURN 'ExtractedEntity' AS Label, COUNT(n) AS Count
|
||||
UNION ALL
|
||||
OPTIONAL MATCH (n) WHERE n.group_id = $group_id RETURN 'ALL' AS Label, COUNT(n) AS Count
|
||||
OPTIONAL MATCH (n) WHERE n.end_user_id = $end_user_id RETURN 'ALL' AS Label, COUNT(n) AS Count
|
||||
"""
|
||||
|
||||
# Extracted entity details within group/app/user
|
||||
SEARCH_FOR_DETIALS = """
|
||||
MATCH (n:ExtractedEntity)
|
||||
WHERE n.group_id = $group_id
|
||||
WHERE n.end_user_id = $end_user_id
|
||||
RETURN n.entity_idx AS entity_idx,
|
||||
n.connect_strength AS connect_strength,
|
||||
n.description AS description,
|
||||
n.entity_type AS entity_type,
|
||||
n.name AS name,
|
||||
COALESCE(n.fact_summary, '') AS fact_summary,
|
||||
n.group_id AS group_id,
|
||||
n.end_user_id AS end_user_id,
|
||||
n.apply_id AS apply_id,
|
||||
n.user_id AS user_id,
|
||||
n.id AS id
|
||||
@@ -91,9 +92,9 @@ class DataConfigRepository:
|
||||
# Edges between extracted entities within group/app/user
|
||||
SEARCH_FOR_EDGES = """
|
||||
MATCH (n:ExtractedEntity)-[r]->(m:ExtractedEntity)
|
||||
WHERE n.group_id = $group_id
|
||||
WHERE n.end_user_id = $end_user_id
|
||||
RETURN
|
||||
r.group_id AS group_id,
|
||||
r.end_user_id AS end_user_id,
|
||||
r.apply_id AS apply_id,
|
||||
r.user_id AS user_id,
|
||||
elementId(r) AS rel_id,
|
||||
@@ -107,7 +108,7 @@ class DataConfigRepository:
|
||||
@staticmethod
|
||||
def update_reflection_config(
|
||||
db: Session,
|
||||
config_id: int,
|
||||
config_id: uuid.UUID,
|
||||
enable_self_reflexion: bool,
|
||||
iteration_period: str,
|
||||
reflexion_range: str,
|
||||
@@ -115,7 +116,7 @@ class DataConfigRepository:
|
||||
reflection_model_id: str,
|
||||
memory_verify: bool,
|
||||
quality_assessment: bool
|
||||
) -> DataConfig:
|
||||
) -> MemoryConfig:
|
||||
"""构建反思配置更新语句(SQLAlchemy text() 命名参数)
|
||||
|
||||
Args:
|
||||
@@ -130,28 +131,28 @@ class DataConfigRepository:
|
||||
config_id: 配置ID
|
||||
|
||||
Returns:
|
||||
Data
|
||||
MemoryConfig
|
||||
|
||||
Raises:
|
||||
ValueError: 没有字段需要更新时抛出
|
||||
"""
|
||||
db_logger.debug(f"构建反思配置更新语句: config_id={config_id}")
|
||||
stmt = select(DataConfig).where(DataConfig.config_id == config_id)
|
||||
data_config_obj = db.scalars(stmt).first()
|
||||
if not data_config_obj:
|
||||
stmt = select(MemoryConfig).where(MemoryConfig.config_id == config_id)
|
||||
memory_config_obj = db.scalars(stmt).first()
|
||||
if not memory_config_obj:
|
||||
raise BusinessException
|
||||
data_config_obj.enable_self_reflexion = enable_self_reflexion
|
||||
data_config_obj.iteration_period = iteration_period
|
||||
data_config_obj.reflexion_range = reflexion_range
|
||||
data_config_obj.baseline = baseline
|
||||
data_config_obj.reflection_model_id = reflection_model_id
|
||||
data_config_obj.memory_verify = memory_verify
|
||||
data_config_obj.quality_assessment = quality_assessment
|
||||
memory_config_obj.enable_self_reflexion = enable_self_reflexion
|
||||
memory_config_obj.iteration_period = iteration_period
|
||||
memory_config_obj.reflexion_range = reflexion_range
|
||||
memory_config_obj.baseline = baseline
|
||||
memory_config_obj.reflection_model_id = reflection_model_id
|
||||
memory_config_obj.memory_verify = memory_verify
|
||||
memory_config_obj.quality_assessment = quality_assessment
|
||||
|
||||
return data_config_obj
|
||||
return memory_config_obj
|
||||
|
||||
@staticmethod
|
||||
def query_reflection_config_by_id(db: Session, config_id: int) -> DataConfig:
|
||||
def query_reflection_config_by_id(db: Session, config_id: uuid.UUID) -> MemoryConfig:
|
||||
"""构建反思配置查询语句,通过config_id查询反思配置(SQLAlchemy text() 命名参数)
|
||||
|
||||
Args:
|
||||
@@ -162,13 +163,13 @@ class DataConfigRepository:
|
||||
Tuple[str, Dict]: (SQL查询字符串, 参数字典)
|
||||
"""
|
||||
db_logger.debug(f"构建反思配置查询语句: config_id={config_id}")
|
||||
stmt = select(DataConfig).where(DataConfig.config_id == config_id)
|
||||
data_config = db.scalars(stmt).first()
|
||||
if not data_config:
|
||||
stmt = select(MemoryConfig).where(MemoryConfig.config_id == config_id)
|
||||
memory_config = db.scalars(stmt).first()
|
||||
if not memory_config:
|
||||
raise RuntimeError("reflection config not found")
|
||||
return data_config
|
||||
return memory_config
|
||||
@staticmethod
|
||||
def query_reflection_config_by_workspace_id(db: Session, workspace_id: uuid.UUID) -> DataConfig:
|
||||
def query_reflection_config_by_workspace_id(db: Session, workspace_id: uuid.UUID) -> MemoryConfig:
|
||||
"""构建查询所有配置的语句(SQLAlchemy text() 命名参数)
|
||||
|
||||
Args:
|
||||
@@ -180,11 +181,11 @@ class DataConfigRepository:
|
||||
"""
|
||||
db_logger.debug(f"构建查询所有配置语句: workspace_id={workspace_id}")
|
||||
|
||||
stmt = select(DataConfig).where(DataConfig.workspace_id == workspace_id)
|
||||
data_config = db.scalars(stmt).first()
|
||||
if not data_config:
|
||||
stmt = select(MemoryConfig).where(MemoryConfig.workspace_id == workspace_id)
|
||||
memory_config = db.scalars(stmt).first()
|
||||
if not memory_config:
|
||||
raise RuntimeError("reflection config not found")
|
||||
return data_config
|
||||
return memory_config
|
||||
|
||||
|
||||
@staticmethod
|
||||
@@ -208,20 +209,21 @@ class DataConfigRepository:
|
||||
return query, params
|
||||
|
||||
@staticmethod
|
||||
def create(db: Session, params: ConfigParamsCreate) -> DataConfig:
|
||||
"""创建数据配置
|
||||
def create(db: Session, params: ConfigParamsCreate) -> MemoryConfig:
|
||||
"""创建记忆配置
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
params: 配置参数创建模型
|
||||
|
||||
Returns:
|
||||
DataConfig: 创建的配置对象
|
||||
MemoryConfig: 创建的配置对象
|
||||
"""
|
||||
db_logger.debug(f"创建数据配置: config_name={params.config_name}, workspace_id={params.workspace_id}")
|
||||
db_logger.debug(f"创建记忆配置: config_name={params.config_name}, workspace_id={params.workspace_id}")
|
||||
|
||||
try:
|
||||
db_config = DataConfig(
|
||||
db_config = MemoryConfig(
|
||||
config_id=uuid.uuid4(),
|
||||
config_name=params.config_name,
|
||||
config_desc=params.config_desc,
|
||||
workspace_id=params.workspace_id,
|
||||
@@ -232,16 +234,16 @@ class DataConfigRepository:
|
||||
db.add(db_config)
|
||||
db.flush() # 获取自增ID但不提交事务
|
||||
|
||||
db_logger.info(f"数据配置已添加到会话: {db_config.config_name} (ID: {db_config.config_id})")
|
||||
db_logger.info(f"记忆配置已添加到会话: {db_config.config_name} (ID: {db_config.config_id})")
|
||||
return db_config
|
||||
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
db_logger.error(f"创建数据配置失败: {params.config_name} - {str(e)}")
|
||||
db_logger.error(f"创建记忆配置失败: {params.config_name} - {str(e)}")
|
||||
raise
|
||||
|
||||
@staticmethod
|
||||
def update(db: Session, update: ConfigUpdate) -> Optional[DataConfig]:
|
||||
def update(db: Session, update: ConfigUpdate) -> Optional[MemoryConfig]:
|
||||
"""更新基础配置
|
||||
|
||||
Args:
|
||||
@@ -249,17 +251,17 @@ class DataConfigRepository:
|
||||
update: 配置更新模型
|
||||
|
||||
Returns:
|
||||
Optional[DataConfig]: 更新后的配置对象,不存在则返回None
|
||||
Optional[MemoryConfig]: 更新后的配置对象,不存在则返回None
|
||||
|
||||
Raises:
|
||||
ValueError: 没有字段需要更新时抛出
|
||||
"""
|
||||
db_logger.debug(f"更新数据配置: config_id={update.config_id}")
|
||||
db_logger.debug(f"更新记忆配置: config_id={update.config_id}")
|
||||
|
||||
try:
|
||||
db_config = db.query(DataConfig).filter(DataConfig.config_id == update.config_id).first()
|
||||
db_config = db.query(MemoryConfig).filter(MemoryConfig.config_id == update.config_id).first()
|
||||
if not db_config:
|
||||
db_logger.warning(f"数据配置不存在: config_id={update.config_id}")
|
||||
db_logger.warning(f"记忆配置不存在: config_id={update.config_id}")
|
||||
return None
|
||||
|
||||
# 更新字段
|
||||
@@ -277,17 +279,17 @@ class DataConfigRepository:
|
||||
db.commit()
|
||||
db.refresh(db_config)
|
||||
|
||||
db_logger.info(f"数据配置更新成功: {db_config.config_name} (ID: {update.config_id})")
|
||||
db_logger.info(f"记忆配置更新成功: {db_config.config_name} (ID: {update.config_id})")
|
||||
return db_config
|
||||
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
db_logger.error(f"更新数据配置失败: config_id={update.config_id} - {str(e)}")
|
||||
db_logger.error(f"更新记忆配置失败: config_id={update.config_id} - {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
@staticmethod
|
||||
def update_extracted(db: Session, update: ConfigUpdateExtracted) -> Optional[DataConfig]:
|
||||
def update_extracted(db: Session, update: ConfigUpdateExtracted) -> Optional[MemoryConfig]:
|
||||
"""更新记忆萃取引擎配置
|
||||
|
||||
Args:
|
||||
@@ -295,7 +297,7 @@ class DataConfigRepository:
|
||||
update: 萃取配置更新模型
|
||||
|
||||
Returns:
|
||||
Optional[DataConfig]: 更新后的配置对象,不存在则返回None
|
||||
Optional[MemoryConfig]: 更新后的配置对象,不存在则返回None
|
||||
|
||||
Raises:
|
||||
ValueError: 没有字段需要更新时抛出
|
||||
@@ -303,9 +305,9 @@ class DataConfigRepository:
|
||||
db_logger.debug(f"更新萃取配置: config_id={update.config_id}")
|
||||
|
||||
try:
|
||||
db_config = db.query(DataConfig).filter(DataConfig.config_id == update.config_id).first()
|
||||
db_config = db.query(MemoryConfig).filter(MemoryConfig.config_id == update.config_id).first()
|
||||
if not db_config:
|
||||
db_logger.warning(f"数据配置不存在: config_id={update.config_id}")
|
||||
db_logger.warning(f"记忆配置不存在: config_id={update.config_id}")
|
||||
return None
|
||||
|
||||
# 更新字段映射
|
||||
@@ -360,7 +362,7 @@ class DataConfigRepository:
|
||||
raise
|
||||
|
||||
@staticmethod
|
||||
def update_forget(db: Session, update: ConfigUpdateForget) -> Optional[DataConfig]:
|
||||
def update_forget(db: Session, update: ConfigUpdateForget) -> Optional[MemoryConfig]:
|
||||
"""更新遗忘引擎配置
|
||||
|
||||
Args:
|
||||
@@ -368,7 +370,7 @@ class DataConfigRepository:
|
||||
update: 遗忘配置更新模型
|
||||
|
||||
Returns:
|
||||
Optional[DataConfig]: 更新后的配置对象,不存在则返回None
|
||||
Optional[MemoryConfig]: 更新后的配置对象,不存在则返回None
|
||||
|
||||
Raises:
|
||||
ValueError: 没有字段需要更新时抛出
|
||||
@@ -376,9 +378,9 @@ class DataConfigRepository:
|
||||
db_logger.debug(f"更新遗忘配置: config_id={update.config_id}")
|
||||
|
||||
try:
|
||||
db_config = db.query(DataConfig).filter(DataConfig.config_id == update.config_id).first()
|
||||
db_config = db.query(MemoryConfig).filter(MemoryConfig.config_id == update.config_id).first()
|
||||
if not db_config:
|
||||
db_logger.warning(f"数据配置不存在: config_id={update.config_id}")
|
||||
db_logger.warning(f"记忆配置不存在: config_id={update.config_id}")
|
||||
return None
|
||||
|
||||
# 更新字段
|
||||
@@ -408,7 +410,7 @@ class DataConfigRepository:
|
||||
raise
|
||||
|
||||
@staticmethod
|
||||
def get_extracted_config(db: Session, config_id: int) -> Optional[Dict]:
|
||||
def get_extracted_config(db: Session, config_id: UUID) -> Optional[Dict]:
|
||||
"""获取萃取配置,通过主键查询某条配置
|
||||
|
||||
Args:
|
||||
@@ -421,7 +423,7 @@ class DataConfigRepository:
|
||||
db_logger.debug(f"查询萃取配置: config_id={config_id}")
|
||||
|
||||
try:
|
||||
db_config = db.query(DataConfig).filter(DataConfig.config_id == config_id).first()
|
||||
db_config = db.query(MemoryConfig).filter(MemoryConfig.config_id == config_id).first()
|
||||
if not db_config:
|
||||
db_logger.debug(f"萃取配置不存在: config_id={config_id}")
|
||||
return None
|
||||
@@ -457,7 +459,7 @@ class DataConfigRepository:
|
||||
raise
|
||||
|
||||
@staticmethod
|
||||
def get_forget_config(db: Session, config_id: int) -> Optional[Dict]:
|
||||
def get_forget_config(db: Session, config_id: UUID) -> Optional[Dict]:
|
||||
"""获取遗忘配置,通过主键查询某条配置
|
||||
|
||||
Args:
|
||||
@@ -470,7 +472,7 @@ class DataConfigRepository:
|
||||
db_logger.debug(f"查询遗忘配置: config_id={config_id}")
|
||||
|
||||
try:
|
||||
db_config = db.query(DataConfig).filter(DataConfig.config_id == config_id).first()
|
||||
db_config = db.query(MemoryConfig).filter(MemoryConfig.config_id == config_id).first()
|
||||
if not db_config:
|
||||
db_logger.debug(f"遗忘配置不存在: config_id={config_id}")
|
||||
return None
|
||||
@@ -489,39 +491,39 @@ class DataConfigRepository:
|
||||
raise
|
||||
|
||||
@staticmethod
|
||||
def get_by_id(db: Session, config_id: int) -> Optional[DataConfig]:
|
||||
"""根据ID获取数据配置
|
||||
def get_by_id(db: Session, config_id: uuid.UUID) -> Optional[MemoryConfig]:
|
||||
"""根据ID获取记忆配置
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
config_id: 配置ID
|
||||
|
||||
Returns:
|
||||
Optional[DataConfig]: 配置对象,不存在则返回None
|
||||
Optional[MemoryConfig]: 配置对象,不存在则返回None
|
||||
"""
|
||||
db_logger.debug(f"根据ID查询数据配置: config_id={config_id}")
|
||||
db_logger.debug(f"根据ID查询记忆配置: config_id={config_id}")
|
||||
|
||||
try:
|
||||
config = db.query(DataConfig).filter(DataConfig.config_id == config_id).first()
|
||||
config = db.query(MemoryConfig).filter(MemoryConfig.config_id == config_id).first()
|
||||
|
||||
if config:
|
||||
db_logger.debug(f"数据配置查询成功: {config.config_name} (ID: {config_id})")
|
||||
db_logger.debug(f"记忆配置查询成功: {config.config_name} (ID: {config_id})")
|
||||
else:
|
||||
db_logger.debug(f"数据配置不存在: config_id={config_id}")
|
||||
db_logger.debug(f"记忆配置不存在: config_id={config_id}")
|
||||
return config
|
||||
except Exception as e:
|
||||
db_logger.error(f"根据ID查询数据配置失败: config_id={config_id} - {str(e)}")
|
||||
db_logger.error(f"根据ID查询记忆配置失败: config_id={config_id} - {str(e)}")
|
||||
raise
|
||||
@staticmethod
|
||||
def get_config_with_workspace(db: Session, config_id: int) -> Optional[tuple]:
|
||||
"""Get data config and its associated workspace information
|
||||
def get_config_with_workspace(db: Session, config_id: uuid.UUID) -> Optional[tuple]:
|
||||
"""Get memory config and its associated workspace information
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
config_id: Configuration ID
|
||||
|
||||
Returns:
|
||||
Optional[tuple]: (DataConfig, Workspace) tuple, None if not found
|
||||
Optional[tuple]: (MemoryConfig, Workspace) tuple, None if not found
|
||||
|
||||
Raises:
|
||||
ValueError: Raised when config exists but workspace doesn't
|
||||
@@ -541,19 +543,19 @@ class DataConfigRepository:
|
||||
}
|
||||
)
|
||||
|
||||
db_logger.debug(f"Querying data config and workspace: config_id={config_id}")
|
||||
db_logger.debug(f"Querying memory config and workspace: config_id={config_id}")
|
||||
|
||||
try:
|
||||
# Use join query to get both config and workspace
|
||||
result = db.query(DataConfig, Workspace).join(
|
||||
Workspace, DataConfig.workspace_id == Workspace.id
|
||||
).filter(DataConfig.config_id == config_id).first()
|
||||
result = db.query(MemoryConfig, Workspace).join(
|
||||
Workspace, MemoryConfig.workspace_id == Workspace.id
|
||||
).filter(MemoryConfig.config_id == config_id).first()
|
||||
|
||||
elapsed_ms = (time.time() - start_time) * 1000
|
||||
|
||||
if not result:
|
||||
# Check if config exists but workspace is missing
|
||||
config_only = db.query(DataConfig).filter(DataConfig.config_id == config_id).first()
|
||||
config_only = db.query(MemoryConfig).filter(MemoryConfig.config_id == config_id).first()
|
||||
if config_only:
|
||||
if config_only.workspace_id is None:
|
||||
config_logger.error(
|
||||
@@ -566,7 +568,7 @@ class DataConfigRepository:
|
||||
"elapsed_ms": elapsed_ms
|
||||
}
|
||||
)
|
||||
db_logger.error(f"Data config {config_id} has no associated workspace ID")
|
||||
db_logger.error(f"Memory config {config_id} has no associated workspace ID")
|
||||
raise ValueError(f"Configuration {config_id} has no associated workspace")
|
||||
else:
|
||||
config_logger.error(
|
||||
@@ -579,7 +581,7 @@ class DataConfigRepository:
|
||||
"elapsed_ms": elapsed_ms
|
||||
}
|
||||
)
|
||||
db_logger.error(f"Data config {config_id} references non-existent workspace {config_only.workspace_id}")
|
||||
db_logger.error(f"Memory config {config_id} references non-existent workspace {config_only.workspace_id}")
|
||||
raise ValueError(f"Workspace {config_only.workspace_id} not found for configuration {config_id}")
|
||||
|
||||
config_logger.debug(
|
||||
@@ -591,7 +593,7 @@ class DataConfigRepository:
|
||||
"elapsed_ms": elapsed_ms
|
||||
}
|
||||
)
|
||||
db_logger.debug(f"Data config not found: config_id={config_id}")
|
||||
db_logger.debug(f"Memory config not found: config_id={config_id}")
|
||||
return None
|
||||
|
||||
config, workspace = result
|
||||
@@ -611,7 +613,7 @@ class DataConfigRepository:
|
||||
}
|
||||
)
|
||||
|
||||
db_logger.debug(f"Data config and workspace query successful: config={config.config_name}, workspace={workspace.name}")
|
||||
db_logger.debug(f"Memory config and workspace query successful: config={config.config_name}, workspace={workspace.name}")
|
||||
return (config, workspace)
|
||||
|
||||
except ValueError:
|
||||
@@ -633,10 +635,10 @@ class DataConfigRepository:
|
||||
exc_info=True
|
||||
)
|
||||
|
||||
db_logger.error(f"Failed to query data config and workspace: config_id={config_id} - {str(e)}")
|
||||
db_logger.error(f"Failed to query memory config and workspace: config_id={config_id} - {str(e)}")
|
||||
raise
|
||||
@staticmethod
|
||||
def get_all(db: Session, workspace_id: Optional[uuid.UUID] = None) -> List[DataConfig]:
|
||||
def get_all(db: Session, workspace_id: Optional[uuid.UUID] = None) -> List[MemoryConfig]:
|
||||
"""获取所有配置参数
|
||||
|
||||
Args:
|
||||
@@ -644,17 +646,17 @@ class DataConfigRepository:
|
||||
workspace_id: 工作空间ID,用于过滤查询结果
|
||||
|
||||
Returns:
|
||||
List[DataConfig]: 配置列表
|
||||
List[MemoryConfig]: 配置列表
|
||||
"""
|
||||
db_logger.debug(f"查询所有配置: workspace_id={workspace_id}")
|
||||
|
||||
try:
|
||||
query = db.query(DataConfig)
|
||||
query = db.query(MemoryConfig)
|
||||
|
||||
if workspace_id:
|
||||
query = query.filter(DataConfig.workspace_id == workspace_id)
|
||||
query = query.filter(MemoryConfig.workspace_id == workspace_id)
|
||||
|
||||
configs = query.order_by(desc(DataConfig.updated_at)).all()
|
||||
configs = query.order_by(desc(MemoryConfig.updated_at)).all()
|
||||
|
||||
db_logger.debug(f"配置列表查询成功: 数量={len(configs)}")
|
||||
return configs
|
||||
@@ -664,8 +666,8 @@ class DataConfigRepository:
|
||||
raise
|
||||
|
||||
@staticmethod
|
||||
def delete(db: Session, config_id: int) -> bool:
|
||||
"""删除数据配置
|
||||
def delete(db: Session, config_id: uuid.UUID) -> bool:
|
||||
"""删除记忆配置
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
@@ -674,22 +676,22 @@ class DataConfigRepository:
|
||||
Returns:
|
||||
bool: 删除成功返回True,配置不存在返回False
|
||||
"""
|
||||
db_logger.debug(f"删除数据配置: config_id={config_id}")
|
||||
db_logger.debug(f"删除记忆配置: config_id={config_id}")
|
||||
|
||||
try:
|
||||
db_config = db.query(DataConfig).filter(DataConfig.config_id == config_id).first()
|
||||
db_config = db.query(MemoryConfig).filter(MemoryConfig.config_id == config_id).first()
|
||||
if not db_config:
|
||||
db_logger.warning(f"数据配置不存在: config_id={config_id}")
|
||||
db_logger.warning(f"记忆配置不存在: config_id={config_id}")
|
||||
return False
|
||||
|
||||
db.delete(db_config)
|
||||
db.commit()
|
||||
|
||||
db_logger.info(f"数据配置删除成功: config_id={config_id}")
|
||||
db_logger.info(f"记忆配置删除成功: config_id={config_id}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
db_logger.error(f"删除数据配置失败: config_id={config_id} - {str(e)}")
|
||||
db_logger.error(f"删除记忆配置失败: config_id={config_id} - {str(e)}")
|
||||
raise
|
||||
|
||||
@@ -6,7 +6,7 @@ from sqlalchemy import and_, desc
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.logging_config import get_db_logger
|
||||
from app.models.memory_perceptual_model import MemoryPerceptualModel, PerceptualType, FileStorageType
|
||||
from app.models.memory_perceptual_model import MemoryPerceptualModel, PerceptualType, FileStorageService
|
||||
from app.schemas.memory_perceptual_schema import PerceptualQuerySchema
|
||||
|
||||
db_logger = get_db_logger()
|
||||
@@ -28,7 +28,7 @@ class MemoryPerceptualRepository:
|
||||
file_ext: str,
|
||||
summary: Optional[str] = None,
|
||||
meta_data: Optional[dict] = None,
|
||||
storage_service: FileStorageType = FileStorageType.LOCAL
|
||||
storage_service: FileStorageService = FileStorageService.LOCAL
|
||||
|
||||
) -> MemoryPerceptualModel:
|
||||
|
||||
|
||||
@@ -32,7 +32,7 @@ async def add_chunk_statement_edges(chunks: List[Chunk], connector: Neo4jConnect
|
||||
"id": stable_edge_id,
|
||||
"source": chunk.id,
|
||||
"target": stmt.id,
|
||||
"group_id": getattr(stmt, 'group_id', None),
|
||||
"end_user_id": getattr(stmt, 'end_user_id', None),
|
||||
"user_id":getattr(stmt, 'user_id', None),
|
||||
"apply_id": getattr(stmt, 'apply_id', None),
|
||||
"run_id": getattr(stmt, 'run_id', None) or getattr(chunk, 'run_id', None),
|
||||
@@ -83,7 +83,7 @@ async def add_memory_summary_statement_edges(summaries: List[MemorySummaryNode],
|
||||
edges.append({
|
||||
"summary_id": s.id,
|
||||
"chunk_id": chunk_id,
|
||||
"group_id": s.group_id,
|
||||
"end_user_id": s.end_user_id,
|
||||
"run_id": s.run_id,
|
||||
"created_at": s.created_at.isoformat() if s.created_at else None,
|
||||
"expired_at": s.expired_at.isoformat() if s.expired_at else None,
|
||||
|
||||
@@ -6,10 +6,10 @@ from app.core.memory.models.graph_models import DialogueNode, StatementNode, Chu
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
|
||||
|
||||
async def delete_all_nodes(group_id: str, connector: Neo4jConnector):
|
||||
async def delete_all_nodes(end_user_id: str, connector: Neo4jConnector):
|
||||
"""Delete all nodes in the database."""
|
||||
result = await connector.execute_query(f"MATCH (n {{group_id: '{group_id}'}}) DETACH DELETE n")
|
||||
print(f"All group_id: {group_id} node and edge deleted successfully")
|
||||
result = await connector.execute_query(f"MATCH (n {{end_user_id: '{end_user_id}'}}) DETACH DELETE n")
|
||||
print(f"All end_user_id: {end_user_id} node and edge deleted successfully")
|
||||
return result
|
||||
|
||||
async def add_dialogue_nodes(dialogues: List[DialogueNode], connector: Neo4jConnector) -> Optional[List[str]]:
|
||||
@@ -32,9 +32,7 @@ async def add_dialogue_nodes(dialogues: List[DialogueNode], connector: Neo4jConn
|
||||
for dialogue in dialogues:
|
||||
flattened_dialogues.append({
|
||||
"id": dialogue.id,
|
||||
"group_id": dialogue.group_id,
|
||||
"user_id": dialogue.user_id,
|
||||
"apply_id": dialogue.apply_id,
|
||||
"end_user_id": dialogue.end_user_id,
|
||||
"run_id": dialogue.run_id,
|
||||
"ref_id": dialogue.ref_id,
|
||||
"name": dialogue.name,
|
||||
@@ -79,9 +77,7 @@ async def add_statement_nodes(statements: List[StatementNode], connector: Neo4jC
|
||||
flattened_statement = {
|
||||
"id": statement.id,
|
||||
"name": statement.name,
|
||||
"group_id": statement.group_id,
|
||||
"user_id": statement.user_id,
|
||||
"apply_id": statement.apply_id,
|
||||
"end_user_id": statement.end_user_id,
|
||||
"run_id": statement.run_id,
|
||||
"chunk_id": statement.chunk_id,
|
||||
# "created_at": statement.created_at.isoformat(),
|
||||
@@ -154,9 +150,7 @@ async def add_chunk_nodes(chunks: List[ChunkNode], connector: Neo4jConnector) ->
|
||||
flattened_chunk = {
|
||||
"id": chunk.id,
|
||||
"name": chunk.name,
|
||||
"group_id": chunk.group_id,
|
||||
"user_id": chunk.user_id,
|
||||
"apply_id": chunk.apply_id,
|
||||
"end_user_id": chunk.end_user_id,
|
||||
"run_id": chunk.run_id,
|
||||
"created_at": chunk.created_at.isoformat() if chunk.created_at else None,
|
||||
"expired_at": chunk.expired_at.isoformat() if chunk.expired_at else None,
|
||||
@@ -206,9 +200,7 @@ async def add_memory_summary_nodes(summaries: List[MemorySummaryNode], connector
|
||||
flattened.append({
|
||||
"id": s.id,
|
||||
"name": s.name,
|
||||
"group_id": s.group_id,
|
||||
"user_id": s.user_id,
|
||||
"apply_id": s.apply_id,
|
||||
"end_user_id": s.end_user_id,
|
||||
"run_id": s.run_id,
|
||||
"created_at": s.created_at.isoformat() if s.created_at else None,
|
||||
"expired_at": s.expired_at.isoformat() if s.expired_at else None,
|
||||
|
||||
@@ -152,7 +152,7 @@ class BaseNeo4jRepository(BaseRepository[T]):
|
||||
|
||||
Example:
|
||||
>>> results = await repository.find(
|
||||
... {"group_id": "group_123", "user_id": "user_456"},
|
||||
... {"end_user_id": "group_123", "user_id": "user_456"},
|
||||
... limit=50
|
||||
... )
|
||||
"""
|
||||
|
||||
@@ -3,9 +3,7 @@ DIALOGUE_NODE_SAVE = """
|
||||
UNWIND $dialogues AS dialogue
|
||||
MERGE (n:Dialogue {id: dialogue.id})
|
||||
SET n.uuid = coalesce(n.uuid, dialogue.id),
|
||||
n.group_id = dialogue.group_id,
|
||||
n.user_id = dialogue.user_id,
|
||||
n.apply_id = dialogue.apply_id,
|
||||
n.end_user_id = dialogue.end_user_id,
|
||||
n.run_id = dialogue.run_id,
|
||||
n.ref_id = dialogue.ref_id,
|
||||
n.created_at = dialogue.created_at,
|
||||
@@ -22,9 +20,7 @@ SET s += {
|
||||
id: statement.id,
|
||||
run_id: statement.run_id,
|
||||
chunk_id: statement.chunk_id,
|
||||
group_id: statement.group_id,
|
||||
user_id: statement.user_id,
|
||||
apply_id: statement.apply_id,
|
||||
end_user_id: statement.end_user_id,
|
||||
stmt_type: statement.stmt_type,
|
||||
statement: statement.statement,
|
||||
emotion_intensity: statement.emotion_intensity,
|
||||
@@ -54,9 +50,7 @@ MERGE (c:Chunk {id: chunk.id})
|
||||
SET c += {
|
||||
id: chunk.id,
|
||||
name: chunk.name,
|
||||
group_id: chunk.group_id,
|
||||
user_id: chunk.user_id,
|
||||
apply_id: chunk.apply_id,
|
||||
end_user_id: chunk.end_user_id,
|
||||
run_id: chunk.run_id,
|
||||
created_at: chunk.created_at,
|
||||
expired_at: chunk.expired_at,
|
||||
@@ -76,9 +70,7 @@ EXTRACTED_ENTITY_NODE_SAVE = """
|
||||
UNWIND $entities AS entity
|
||||
MERGE (e:ExtractedEntity {id: entity.id})
|
||||
SET e.name = CASE WHEN entity.name IS NOT NULL AND entity.name <> '' THEN entity.name ELSE e.name END,
|
||||
e.group_id = CASE WHEN entity.group_id IS NOT NULL AND entity.group_id <> '' THEN entity.group_id ELSE e.group_id END,
|
||||
e.user_id = CASE WHEN entity.user_id IS NOT NULL AND entity.user_id <> '' THEN entity.user_id ELSE e.user_id END,
|
||||
e.apply_id = CASE WHEN entity.apply_id IS NOT NULL AND entity.apply_id <> '' THEN entity.apply_id ELSE e.apply_id END,
|
||||
e.end_user_id = CASE WHEN entity.end_user_id IS NOT NULL AND entity.end_user_id <> '' THEN entity.end_user_id ELSE e.end_user_id END,
|
||||
e.run_id = CASE WHEN entity.run_id IS NOT NULL AND entity.run_id <> '' THEN entity.run_id ELSE e.run_id END,
|
||||
e.created_at = CASE
|
||||
WHEN entity.created_at IS NOT NULL AND (e.created_at IS NULL OR entity.created_at < e.created_at)
|
||||
@@ -134,9 +126,9 @@ RETURN e.id AS uuid
|
||||
# Add back ENTITY_RELATIONSHIP_SAVE to be used by graph_saver.save_entities_and_relationships
|
||||
ENTITY_RELATIONSHIP_SAVE = """
|
||||
UNWIND $relationships AS rel
|
||||
// Match entities by stable id within group, do not constrain by run_id
|
||||
MATCH (subject:ExtractedEntity {id: rel.source_id, group_id: rel.group_id})
|
||||
MATCH (object:ExtractedEntity {id: rel.target_id, group_id: rel.group_id})
|
||||
// Match entities by stable id within end_user_id, do not constrain by run_id
|
||||
MATCH (subject:ExtractedEntity {id: rel.source_id, end_user_id: rel.end_user_id})
|
||||
MATCH (object:ExtractedEntity {id: rel.target_id, end_user_id: rel.end_user_id})
|
||||
// Avoid duplicate edges across runs for the same endpoints
|
||||
MERGE (subject)-[r:EXTRACTED_RELATIONSHIP]->(object)
|
||||
SET r.predicate = rel.predicate,
|
||||
@@ -148,7 +140,7 @@ SET r.predicate = rel.predicate,
|
||||
r.created_at = rel.created_at,
|
||||
r.expired_at = rel.expired_at,
|
||||
r.run_id = rel.run_id,
|
||||
r.group_id = rel.group_id
|
||||
r.end_user_id = rel.end_user_id
|
||||
RETURN elementId(r) AS uuid
|
||||
"""
|
||||
|
||||
@@ -160,7 +152,7 @@ UNWIND $weak_entities AS entity
|
||||
MERGE (e:ExtractedEntity {id: entity.id, run_id: entity.run_id})
|
||||
SET e += {
|
||||
name: entity.name,
|
||||
group_id: entity.group_id,
|
||||
end_user_id: entity.end_user_id,
|
||||
run_id: entity.run_id,
|
||||
description: entity.description,
|
||||
chunk_id: entity.chunk_id,
|
||||
@@ -175,11 +167,11 @@ RETURN e.id AS id
|
||||
SAVE_STRONG_TRIPLE_ENTITIES = """
|
||||
UNWIND $items AS item
|
||||
MERGE (s:ExtractedEntity {id: item.source_id, run_id: item.run_id})
|
||||
SET s += {name: item.subject, group_id: item.group_id, run_id: item.run_id}
|
||||
SET s += {name: item.subject, end_user_id: item.end_user_id, run_id: item.run_id}
|
||||
// Independent strong flag
|
||||
SET s.is_strong = true
|
||||
MERGE (o:ExtractedEntity {id: item.target_id, run_id: item.run_id})
|
||||
SET o += {name: item.object, group_id: item.group_id, run_id: item.run_id}
|
||||
SET o += {name: item.object, end_user_id: item.end_user_id, run_id: item.run_id}
|
||||
// Independent strong flag
|
||||
SET o.is_strong = true
|
||||
"""
|
||||
@@ -194,7 +186,7 @@ DIALOGUE_STATEMENT_EDGE_SAVE = """
|
||||
// 仅按端点去重,关系属性可更新
|
||||
MERGE (dialogue)-[e:MENTIONS]->(statement)
|
||||
SET e.uuid = edge.id,
|
||||
e.group_id = edge.group_id,
|
||||
e.end_user_id = edge.end_user_id,
|
||||
e.created_at = edge.created_at,
|
||||
e.expired_at = edge.expired_at
|
||||
RETURN e.uuid AS uuid
|
||||
@@ -208,7 +200,7 @@ CHUNK_STATEMENT_EDGE_SAVE = """
|
||||
MATCH (statement:Statement {id: edge.source, run_id: edge.run_id})
|
||||
MATCH (chunk:Chunk {id: edge.target, run_id: edge.run_id})
|
||||
MERGE (chunk)-[e:CONTAINS {id: edge.id}]->(statement)
|
||||
SET e.group_id = edge.group_id,
|
||||
SET e.end_user_id = edge.end_user_id,
|
||||
e.run_id = edge.run_id,
|
||||
e.created_at = edge.created_at,
|
||||
e.expired_at = edge.expired_at
|
||||
@@ -218,13 +210,12 @@ CHUNK_STATEMENT_EDGE_SAVE = """
|
||||
STATEMENT_ENTITY_EDGE_SAVE = """
|
||||
UNWIND $relationships AS rel
|
||||
// Statement nodes are per-run; keep run_id constraint on statements
|
||||
// Statement nodes are per-run; keep run_id constraint on statements
|
||||
MATCH (statement:Statement {id: rel.source, run_id: rel.run_id})
|
||||
// Entities are shared across runs within a group; do not constrain by run_id
|
||||
MATCH (entity:ExtractedEntity {id: rel.target, group_id: rel.group_id})
|
||||
// Entities are shared across runs within end_user_id; do not constrain by run_id
|
||||
MATCH (entity:ExtractedEntity {id: rel.target, end_user_id: rel.end_user_id})
|
||||
// Avoid duplicate edges across runs for same endpoints
|
||||
MERGE (statement)-[r:REFERENCES_ENTITY]->(entity)
|
||||
SET r.group_id = rel.group_id,
|
||||
SET r.end_user_id = rel.end_user_id,
|
||||
r.run_id = rel.run_id,
|
||||
r.created_at = rel.created_at,
|
||||
r.expired_at = rel.expired_at,
|
||||
@@ -236,10 +227,10 @@ ENTITY_EMBEDDING_SEARCH = """
|
||||
CALL db.index.vector.queryNodes('entity_embedding_index', $limit * 100, $embedding)
|
||||
YIELD node AS e, score
|
||||
WHERE e.name_embedding IS NOT NULL
|
||||
AND ($group_id IS NULL OR e.group_id = $group_id)
|
||||
AND ($end_user_id IS NULL OR e.end_user_id = $end_user_id)
|
||||
RETURN e.id AS id,
|
||||
e.name AS name,
|
||||
e.group_id AS group_id,
|
||||
e.end_user_id AS end_user_id,
|
||||
e.entity_type AS entity_type,
|
||||
COALESCE(e.activation_value, e.importance_score, 0.5) AS activation_value,
|
||||
COALESCE(e.importance_score, 0.5) AS importance_score,
|
||||
@@ -254,10 +245,10 @@ STATEMENT_EMBEDDING_SEARCH = """
|
||||
CALL db.index.vector.queryNodes('statement_embedding_index', $limit * 100, $embedding)
|
||||
YIELD node AS s, score
|
||||
WHERE s.statement_embedding IS NOT NULL
|
||||
AND ($group_id IS NULL OR s.group_id = $group_id)
|
||||
AND ($end_user_id IS NULL OR s.end_user_id = $end_user_id)
|
||||
RETURN s.id AS id,
|
||||
s.statement AS statement,
|
||||
s.group_id AS group_id,
|
||||
s.end_user_id AS end_user_id,
|
||||
s.chunk_id AS chunk_id,
|
||||
s.created_at AS created_at,
|
||||
s.expired_at AS expired_at,
|
||||
@@ -277,9 +268,9 @@ CHUNK_EMBEDDING_SEARCH = """
|
||||
CALL db.index.vector.queryNodes('chunk_embedding_index', $limit * 100, $embedding)
|
||||
YIELD node AS c, score
|
||||
WHERE c.chunk_embedding IS NOT NULL
|
||||
AND ($group_id IS NULL OR c.group_id = $group_id)
|
||||
AND ($end_user_id IS NULL OR c.end_user_id = $end_user_id)
|
||||
RETURN c.id AS chunk_id,
|
||||
c.group_id AS group_id,
|
||||
c.end_user_id AS end_user_id,
|
||||
c.content AS content,
|
||||
c.dialog_id AS dialog_id,
|
||||
COALESCE(c.activation_value, 0.5) AS activation_value,
|
||||
@@ -292,12 +283,12 @@ LIMIT $limit
|
||||
|
||||
SEARCH_STATEMENTS_BY_KEYWORD = """
|
||||
CALL db.index.fulltext.queryNodes("statementsFulltext", $q) YIELD node AS s, score
|
||||
WHERE ($group_id IS NULL OR s.group_id = $group_id)
|
||||
WHERE ($end_user_id IS NULL OR s.end_user_id = $end_user_id)
|
||||
OPTIONAL MATCH (c:Chunk)-[:CONTAINS]->(s)
|
||||
OPTIONAL MATCH (s)-[:REFERENCES_ENTITY]->(e:ExtractedEntity)
|
||||
RETURN s.id AS id,
|
||||
s.statement AS statement,
|
||||
s.group_id AS group_id,
|
||||
s.end_user_id AS end_user_id,
|
||||
s.chunk_id AS chunk_id,
|
||||
s.created_at AS created_at,
|
||||
s.expired_at AS expired_at,
|
||||
@@ -316,15 +307,13 @@ LIMIT $limit
|
||||
# 查询实体名称包含指定字符串的实体
|
||||
SEARCH_ENTITIES_BY_NAME = """
|
||||
CALL db.index.fulltext.queryNodes("entitiesFulltext", $q) YIELD node AS e, score
|
||||
WHERE ($group_id IS NULL OR e.group_id = $group_id)
|
||||
WHERE ($end_user_id IS NULL OR e.end_user_id = $end_user_id)
|
||||
OPTIONAL MATCH (s:Statement)-[:REFERENCES_ENTITY]->(e)
|
||||
OPTIONAL MATCH (c:Chunk)-[:CONTAINS]->(s)
|
||||
RETURN e.id AS id,
|
||||
e.name AS name,
|
||||
e.group_id AS group_id,
|
||||
e.end_user_id AS end_user_id,
|
||||
e.entity_type AS entity_type,
|
||||
e.apply_id AS apply_id,
|
||||
e.user_id AS user_id,
|
||||
e.created_at AS created_at,
|
||||
e.expired_at AS expired_at,
|
||||
e.entity_idx AS entity_idx,
|
||||
@@ -347,11 +336,11 @@ LIMIT $limit
|
||||
|
||||
SEARCH_CHUNKS_BY_CONTENT = """
|
||||
CALL db.index.fulltext.queryNodes("chunksFulltext", $q) YIELD node AS c, score
|
||||
WHERE ($group_id IS NULL OR c.group_id = $group_id)
|
||||
WHERE ($end_user_id IS NULL OR c.end_user_id = $end_user_id)
|
||||
OPTIONAL MATCH (c)-[:CONTAINS]->(s:Statement)
|
||||
OPTIONAL MATCH (s)-[:REFERENCES_ENTITY]->(e:ExtractedEntity)
|
||||
RETURN c.id AS chunk_id,
|
||||
c.group_id AS group_id,
|
||||
c.end_user_id AS end_user_id,
|
||||
c.content AS content,
|
||||
c.dialog_id AS dialog_id,
|
||||
c.sequence_number AS sequence_number,
|
||||
@@ -413,10 +402,10 @@ LIMIT $limit
|
||||
|
||||
SEARCH_DIALOGUE_BY_DIALOG_ID = """
|
||||
MATCH (d:Dialogue)
|
||||
WHERE ($group_id IS NULL OR d.group_id = $group_id)
|
||||
WHERE ($end_user_id IS NULL OR d.end_user_id = $end_user_id)
|
||||
AND d.id = $dialog_id
|
||||
RETURN d.id AS dialog_id,
|
||||
d.group_id AS group_id,
|
||||
d.end_user_id AS end_user_id,
|
||||
d.content AS content,
|
||||
d.created_at AS created_at,
|
||||
d.expired_at AS expired_at
|
||||
@@ -426,10 +415,10 @@ LIMIT $limit
|
||||
|
||||
SEARCH_CHUNK_BY_CHUNK_ID = """
|
||||
MATCH (c:Chunk)
|
||||
WHERE ($group_id IS NULL OR c.group_id = $group_id)
|
||||
WHERE ($end_user_id IS NULL OR c.end_user_id = $end_user_id)
|
||||
AND c.id = $chunk_id
|
||||
RETURN c.id AS chunk_id,
|
||||
c.group_id AS group_id,
|
||||
c.end_user_id AS end_user_id,
|
||||
c.content AS content,
|
||||
c.dialog_id AS dialog_id,
|
||||
c.created_at AS created_at,
|
||||
@@ -441,18 +430,14 @@ LIMIT $limit
|
||||
|
||||
SEARCH_STATEMENTS_BY_TEMPORAL = """
|
||||
MATCH (s:Statement)
|
||||
WHERE ($group_id IS NULL OR s.group_id = $group_id)
|
||||
AND ($apply_id IS NULL OR s.apply_id = $apply_id)
|
||||
AND ($user_id IS NULL OR s.user_id = $user_id)
|
||||
WHERE ($end_user_id IS NULL OR s.end_user_id = $end_user_id)
|
||||
AND ((($start_date IS NULL OR datetime(s.created_at) >= datetime($start_date))
|
||||
AND ($end_date IS NULL OR datetime(s.created_at) <= datetime($end_date)))
|
||||
OR (($valid_date IS NULL OR (s.valid_at IS NOT NULL AND datetime(s.valid_at) >= datetime($valid_date)))
|
||||
AND ($invalid_date IS NULL OR (s.invalid_at IS NOT NULL AND datetime(s.invalid_at) <= datetime($invalid_date)))))
|
||||
RETURN s.id AS id,
|
||||
s.statement AS statement,
|
||||
s.group_id AS group_id,
|
||||
s.apply_id AS apply_id,
|
||||
s.user_id AS user_id,
|
||||
s.end_user_id AS end_user_id,
|
||||
s.chunk_id AS chunk_id,
|
||||
s.created_at AS created_at,
|
||||
s.valid_at AS valid_at,
|
||||
@@ -468,9 +453,7 @@ LIMIT $limit
|
||||
|
||||
SEARCH_STATEMENTS_BY_KEYWORD_TEMPORAL = """
|
||||
CALL db.index.fulltext.queryNodes("statementsFulltext", $q) YIELD node AS s, score
|
||||
WHERE ($group_id IS NULL OR s.group_id = $group_id)
|
||||
AND ($apply_id IS NULL OR s.apply_id = $apply_id)
|
||||
AND ($user_id IS NULL OR s.user_id = $user_id)
|
||||
WHERE ($end_user_id IS NULL OR s.end_user_id = $end_user_id)
|
||||
AND ((($start_date IS NULL OR (s.created_at IS NOT NULL AND datetime(s.created_at) >= datetime($start_date)))
|
||||
AND ($end_date IS NULL OR (s.created_at IS NOT NULL AND datetime(s.created_at) <= datetime($end_date))))
|
||||
OR (($valid_date IS NULL OR (s.valid_at IS NOT NULL AND datetime(s.valid_at) >= datetime($valid_date)))
|
||||
@@ -479,9 +462,7 @@ OPTIONAL MATCH (c:Chunk)-[:CONTAINS]->(s)
|
||||
OPTIONAL MATCH (s)-[:REFERENCES_ENTITY]->(e:ExtractedEntity)
|
||||
RETURN s.id AS id,
|
||||
s.statement AS statement,
|
||||
s.group_id AS group_id,
|
||||
s.apply_id AS apply_id,
|
||||
s.user_id AS user_id,
|
||||
s.end_user_id AS end_user_id,
|
||||
s.chunk_id AS chunk_id,
|
||||
s.created_at AS created_at,
|
||||
s.valid_at AS valid_at,
|
||||
@@ -499,15 +480,11 @@ LIMIT $limit
|
||||
|
||||
SEARCH_STATEMENTS_BY_CREATED_AT = """
|
||||
MATCH (n:Statement)
|
||||
WHERE ($group_id IS NULL OR n.group_id = $group_id)
|
||||
AND ($apply_id IS NULL OR n.apply_id = $apply_id)
|
||||
AND ($user_id IS NULL OR n.user_id = $user_id)
|
||||
WHERE ($end_user_id IS NULL OR n.end_user_id = $end_user_id)
|
||||
AND ($created_at IS NOT NULL AND date(substring(n.created_at, 0, 10)) = date($created_at))
|
||||
RETURN n.id AS id,
|
||||
n.statement AS statement,
|
||||
n.group_id AS group_id,
|
||||
n.apply_id AS apply_id,
|
||||
n.user_id AS user_id,
|
||||
n.end_user_id AS end_user_id,
|
||||
n.chunk_id AS chunk_id,
|
||||
n.created_at AS created_at,
|
||||
n.valid_at AS valid_at,
|
||||
@@ -519,15 +496,11 @@ LIMIT $limit
|
||||
|
||||
SEARCH_STATEMENTS_BY_VALID_AT = """
|
||||
MATCH (n:Statement)
|
||||
WHERE ($group_id IS NULL OR n.group_id = $group_id)
|
||||
AND ($apply_id IS NULL OR n.apply_id = $apply_id)
|
||||
AND ($user_id IS NULL OR n.user_id = $user_id)
|
||||
WHERE ($end_user_id IS NULL OR n.end_user_id = $end_user_id)
|
||||
AND ($valid_at IS NOT NULL AND date(substring(n.valid_at, 0, 10)) = date($valid_at))
|
||||
RETURN n.id AS id,
|
||||
n.statement AS statement,
|
||||
n.group_id AS group_id,
|
||||
n.apply_id AS apply_id,
|
||||
n.user_id AS user_id,
|
||||
n.end_user_id AS end_user_id,
|
||||
n.chunk_id AS chunk_id,
|
||||
n.created_at AS created_at,
|
||||
n.valid_at AS valid_at,
|
||||
@@ -539,15 +512,11 @@ LIMIT $limit
|
||||
|
||||
SEARCH_STATEMENTS_G_CREATED_AT = """
|
||||
MATCH (n:Statement)
|
||||
WHERE ($group_id IS NULL OR n.group_id = $group_id)
|
||||
AND ($apply_id IS NULL OR n.apply_id = $apply_id)
|
||||
AND ($user_id IS NULL OR n.user_id = $user_id)
|
||||
WHERE ($end_user_id IS NULL OR n.end_user_id = $end_user_id)
|
||||
AND ($created_at IS NOT NULL AND date(substring(n.created_at, 0, 19)) = date($created_at))
|
||||
RETURN n.id AS id,
|
||||
n.statement AS statement,
|
||||
n.group_id AS group_id,
|
||||
n.apply_id AS apply_id,
|
||||
n.user_id AS user_id,
|
||||
n.end_user_id AS end_user_id,
|
||||
n.chunk_id AS chunk_id,
|
||||
n.created_at AS created_at,
|
||||
n.valid_at AS valid_at,
|
||||
@@ -559,15 +528,11 @@ LIMIT $limit
|
||||
|
||||
SEARCH_STATEMENTS_L_CREATED_AT = """
|
||||
MATCH (n:Statement)
|
||||
WHERE ($group_id IS NULL OR n.group_id = $group_id)
|
||||
AND ($apply_id IS NULL OR n.apply_id = $apply_id)
|
||||
AND ($user_id IS NULL OR n.user_id = $user_id)
|
||||
WHERE ($end_user_id IS NULL OR n.end_user_id = $end_user_id)
|
||||
AND ($created_at IS NOT NULL AND date(substring(n.created_at, 0, 19)) < date($created_at))
|
||||
RETURN n.id AS id,
|
||||
n.statement AS statement,
|
||||
n.group_id AS group_id,
|
||||
n.apply_id AS apply_id,
|
||||
n.user_id AS user_id,
|
||||
n.end_user_id AS end_user_id,
|
||||
n.chunk_id AS chunk_id,
|
||||
n.created_at AS created_at,
|
||||
n.valid_at AS valid_at,
|
||||
@@ -579,15 +544,11 @@ LIMIT $limit
|
||||
|
||||
SEARCH_STATEMENTS_G_VALID_AT = """
|
||||
MATCH (n:Statement)
|
||||
WHERE ($group_id IS NULL OR n.group_id = $group_id)
|
||||
AND ($apply_id IS NULL OR n.apply_id = $apply_id)
|
||||
AND ($user_id IS NULL OR n.user_id = $user_id)
|
||||
WHERE ($end_user_id IS NULL OR n.end_user_id = $end_user_id)
|
||||
AND ($valid_at IS NOT NULL AND date(substring(n.valid_at, 0, 10)) > date($valid_at))
|
||||
RETURN n.id AS id,
|
||||
n.statement AS statement,
|
||||
n.group_id AS group_id,
|
||||
n.apply_id AS apply_id,
|
||||
n.user_id AS user_id,
|
||||
n.end_user_id AS end_user_id,
|
||||
n.chunk_id AS chunk_id,
|
||||
n.created_at AS created_at,
|
||||
n.valid_at AS valid_at,
|
||||
@@ -599,15 +560,11 @@ LIMIT $limit
|
||||
|
||||
SEARCH_STATEMENTS_L_VALID_AT = """
|
||||
MATCH (n:Statement)
|
||||
WHERE ($group_id IS NULL OR n.group_id = $group_id)
|
||||
AND ($apply_id IS NULL OR n.apply_id = $apply_id)
|
||||
AND ($user_id IS NULL OR n.user_id = $user_id)
|
||||
WHERE ($end_user_id IS NULL OR n.end_user_id = $end_user_id)
|
||||
AND ($valid_at IS NOT NULL AND date(substring(n.valid_at, 0, 10)) < date($valid_at))
|
||||
RETURN n.id AS id,
|
||||
n.statement AS statement,
|
||||
n.group_id AS group_id,
|
||||
n.apply_id AS apply_id,
|
||||
n.user_id AS user_id,
|
||||
n.end_user_id AS end_user_id,
|
||||
n.chunk_id AS chunk_id,
|
||||
n.created_at AS created_at,
|
||||
n.valid_at AS valid_at,
|
||||
@@ -665,18 +622,18 @@ LIMIT $limit
|
||||
|
||||
# 根据id修改句子的invalid_at的值
|
||||
UPDATE_STATEMENT_INVALID_AT = """
|
||||
MATCH (n:Statement {group_id: $group_id, id: $id})
|
||||
MATCH (n:Statement {end_user_id: $end_user_id, id: $id})
|
||||
SET n.invalid_at = $new_invalid_at
|
||||
"""
|
||||
|
||||
# MemorySummary keyword search using fulltext index
|
||||
SEARCH_MEMORY_SUMMARIES_BY_KEYWORD = """
|
||||
CALL db.index.fulltext.queryNodes("summariesFulltext", $q) YIELD node AS m, score
|
||||
WHERE ($group_id IS NULL OR m.group_id = $group_id)
|
||||
WHERE ($end_user_id IS NULL OR m.end_user_id = $end_user_id)
|
||||
OPTIONAL MATCH (m)-[:DERIVED_FROM_STATEMENT]->(s:Statement)
|
||||
RETURN m.id AS id,
|
||||
m.name AS name,
|
||||
m.group_id AS group_id,
|
||||
m.end_user_id AS end_user_id,
|
||||
m.dialog_id AS dialog_id,
|
||||
m.chunk_ids AS chunk_ids,
|
||||
m.content AS content,
|
||||
@@ -695,10 +652,10 @@ MEMORY_SUMMARY_EMBEDDING_SEARCH = """
|
||||
CALL db.index.vector.queryNodes('summary_embedding_index', $limit * 100, $embedding)
|
||||
YIELD node AS m, score
|
||||
WHERE m.summary_embedding IS NOT NULL
|
||||
AND ($group_id IS NULL OR m.group_id = $group_id)
|
||||
AND ($end_user_id IS NULL OR m.end_user_id = $end_user_id)
|
||||
RETURN m.id AS id,
|
||||
m.name AS name,
|
||||
m.group_id AS group_id,
|
||||
m.end_user_id AS end_user_id,
|
||||
m.dialog_id AS dialog_id,
|
||||
m.chunk_ids AS chunk_ids,
|
||||
m.content AS content,
|
||||
@@ -718,9 +675,7 @@ MERGE (m:MemorySummary {id: summary.id})
|
||||
SET m += {
|
||||
id: summary.id,
|
||||
name: summary.name,
|
||||
group_id: summary.group_id,
|
||||
user_id: summary.user_id,
|
||||
apply_id: summary.apply_id,
|
||||
end_user_id: summary.end_user_id,
|
||||
run_id: summary.run_id,
|
||||
created_at: summary.created_at,
|
||||
expired_at: summary.expired_at,
|
||||
@@ -745,7 +700,7 @@ MATCH (ms:MemorySummary {id: e.summary_id, run_id: e.run_id})
|
||||
MATCH (c:Chunk {id: e.chunk_id, run_id: e.run_id})
|
||||
MATCH (c)-[:CONTAINS]->(s:Statement {run_id: e.run_id})
|
||||
MERGE (ms)-[r:DERIVED_FROM_STATEMENT]->(s)
|
||||
SET r.group_id = e.group_id,
|
||||
SET r.end_user_id = e.end_user_id,
|
||||
r.run_id = e.run_id,
|
||||
r.created_at = e.created_at,
|
||||
r.expired_at = e.expired_at
|
||||
@@ -774,7 +729,7 @@ FOREACH (rel IN CASE WHEN r IS NOT NULL THEN [r] ELSE [] END |
|
||||
source_statement_id: rel.source_statement_id,
|
||||
valid_at: rel.valid_at,
|
||||
invalid_at: rel.invalid_at,
|
||||
group_id: rel.group_id,
|
||||
end_user_id: rel.end_user_id,
|
||||
user_id: rel.user_id,
|
||||
apply_id: rel.apply_id,
|
||||
run_id: rel.run_id,
|
||||
@@ -796,7 +751,7 @@ FOREACH (rel IN CASE WHEN r IS NOT NULL THEN [r] ELSE [] END |
|
||||
source_statement_id: rel.source_statement_id,
|
||||
valid_at: rel.valid_at,
|
||||
invalid_at: rel.invalid_at,
|
||||
group_id: rel.group_id,
|
||||
end_user_id: rel.end_user_id,
|
||||
user_id: rel.user_id,
|
||||
apply_id: rel.apply_id,
|
||||
run_id: rel.run_id,
|
||||
@@ -814,7 +769,7 @@ RETURN count(losing) as deleted
|
||||
|
||||
neo4j_statement_part = '''
|
||||
MATCH (n:Statement)
|
||||
WHERE n.group_id = "{}"
|
||||
WHERE n.end_user_id = "{}"
|
||||
AND datetime(n.created_at) >= datetime() - duration('P3D')
|
||||
RETURN
|
||||
n.statement as statement_name,
|
||||
@@ -824,7 +779,7 @@ RETURN
|
||||
'''
|
||||
neo4j_statement_all = '''
|
||||
MATCH (n:Statement)
|
||||
WHERE n.group_id = "{}"
|
||||
WHERE n.end_user_id = "{}"
|
||||
RETURN
|
||||
n.statement as statement_name,
|
||||
n.id as statement_id
|
||||
@@ -832,7 +787,7 @@ RETURN
|
||||
'''
|
||||
neo4j_query_part = """
|
||||
MATCH (n)-[r]-(m:ExtractedEntity)
|
||||
WHERE n.group_id = "{}"
|
||||
WHERE n.end_user_id = "{}"
|
||||
AND datetime(n.created_at) >= datetime() - duration('P3D')
|
||||
WITH DISTINCT m
|
||||
OPTIONAL MATCH (m)-[rel]-(other:ExtractedEntity)
|
||||
@@ -853,7 +808,7 @@ neo4j_query_part = """
|
||||
"""
|
||||
neo4j_query_all = """
|
||||
MATCH (n)-[r]-(m:ExtractedEntity)
|
||||
WHERE n.group_id = "{}"
|
||||
WHERE n.end_user_id = "{}"
|
||||
WITH DISTINCT m
|
||||
OPTIONAL MATCH (m)-[rel]-(other:ExtractedEntity)
|
||||
RETURN
|
||||
@@ -1027,14 +982,14 @@ RETURN DISTINCT
|
||||
|
||||
Memory_Space_User="""
|
||||
MATCH (n)-[r]->(m)
|
||||
WHERE n.group_id = $group_id AND m.name="用户"
|
||||
WHERE n.end_user_id = $end_user_id AND m.name="用户"
|
||||
return DISTINCT elementId(m) as id
|
||||
"""
|
||||
Memory_Space_Entity="""
|
||||
MATCH (n)-[]-(m)
|
||||
WHERE elementId(m) = $id AND m.entity_type = "Person"
|
||||
RETURN
|
||||
DISTINCT m.name as name,m.group_id as group_id
|
||||
DISTINCT m.name as name,m.end_user_id as end_user_id
|
||||
"""
|
||||
Memory_Space_Associative="""
|
||||
MATCH (u)-[]-(x)-[]-(h)
|
||||
|
||||
@@ -19,7 +19,7 @@ class DialogRepository(BaseNeo4jRepository[DialogueNode]):
|
||||
"""对话仓储
|
||||
|
||||
管理对话节点的创建、查询、更新和删除操作。
|
||||
提供按group_id、user_id、ref_id等条件查询对话的方法。
|
||||
提供按end_user_id、user_id、ref_id等条件查询对话的方法。
|
||||
|
||||
Attributes:
|
||||
connector: Neo4j连接器实例
|
||||
@@ -54,17 +54,17 @@ class DialogRepository(BaseNeo4jRepository[DialogueNode]):
|
||||
|
||||
return DialogueNode(**n)
|
||||
|
||||
async def find_by_group_id(self, group_id: str, limit: int = 100) -> List[DialogueNode]:
|
||||
"""根据group_id查询对话
|
||||
async def find_by_end_user_id(self, end_user_id: str, limit: int = 100) -> List[DialogueNode]:
|
||||
"""根据end_user_id查询对话
|
||||
|
||||
Args:
|
||||
group_id: 组ID
|
||||
end_user_id: 组ID
|
||||
limit: 返回结果的最大数量
|
||||
|
||||
Returns:
|
||||
List[DialogueNode]: 对话列表
|
||||
"""
|
||||
return await self.find({"group_id": group_id}, limit=limit)
|
||||
return await self.find({"end_user_id": end_user_id}, limit=limit)
|
||||
|
||||
async def find_by_user_id(self, user_id: str, limit: int = 100) -> List[DialogueNode]:
|
||||
"""根据user_id查询对话
|
||||
@@ -94,14 +94,14 @@ class DialogRepository(BaseNeo4jRepository[DialogueNode]):
|
||||
|
||||
async def find_by_group_and_user(
|
||||
self,
|
||||
group_id: str,
|
||||
end_user_id: str,
|
||||
user_id: str,
|
||||
limit: int = 100
|
||||
) -> List[DialogueNode]:
|
||||
"""根据group_id和user_id查询对话
|
||||
"""根据end_user_id和user_id查询对话
|
||||
|
||||
Args:
|
||||
group_id: 组ID
|
||||
end_user_id: 组ID
|
||||
user_id: 用户ID
|
||||
limit: 返回结果的最大数量
|
||||
|
||||
@@ -109,20 +109,20 @@ class DialogRepository(BaseNeo4jRepository[DialogueNode]):
|
||||
List[DialogueNode]: 对话列表
|
||||
"""
|
||||
return await self.find(
|
||||
{"group_id": group_id, "user_id": user_id},
|
||||
{"end_user_id": end_user_id, "user_id": user_id},
|
||||
limit=limit
|
||||
)
|
||||
|
||||
async def find_recent_dialogs(
|
||||
self,
|
||||
group_id: str,
|
||||
end_user_id: str,
|
||||
days: int = 7,
|
||||
limit: int = 100
|
||||
) -> List[DialogueNode]:
|
||||
"""查询最近的对话
|
||||
|
||||
Args:
|
||||
group_id: 组ID
|
||||
end_user_id: 组ID
|
||||
days: 查询最近多少天的对话
|
||||
limit: 返回结果的最大数量
|
||||
|
||||
@@ -131,7 +131,7 @@ class DialogRepository(BaseNeo4jRepository[DialogueNode]):
|
||||
"""
|
||||
query = f"""
|
||||
MATCH (n:{self.node_label})
|
||||
WHERE n.group_id = $group_id
|
||||
WHERE n.end_user_id = $end_user_id
|
||||
AND n.created_at >= datetime() - duration({{days: $days}})
|
||||
RETURN n
|
||||
ORDER BY n.created_at DESC
|
||||
@@ -139,7 +139,7 @@ class DialogRepository(BaseNeo4jRepository[DialogueNode]):
|
||||
"""
|
||||
results = await self.connector.execute_query(
|
||||
query,
|
||||
group_id=group_id,
|
||||
end_user_id=end_user_id,
|
||||
days=days,
|
||||
limit=limit
|
||||
)
|
||||
@@ -164,22 +164,22 @@ class DialogRepository(BaseNeo4jRepository[DialogueNode]):
|
||||
async def find_by_config_and_group(
|
||||
self,
|
||||
config_id: str,
|
||||
group_id: str,
|
||||
end_user_id: str,
|
||||
limit: int = 100
|
||||
) -> List[DialogueNode]:
|
||||
"""根据config_id和group_id查询对话
|
||||
"""根据config_id和end_user_id查询对话
|
||||
|
||||
支持按配置ID和组ID同时过滤,确保只返回使用特定配置处理的对话。
|
||||
|
||||
Args:
|
||||
config_id: 配置ID
|
||||
group_id: 组ID
|
||||
end_user_id: 组ID
|
||||
limit: 返回结果的最大数量
|
||||
|
||||
Returns:
|
||||
List[DialogueNode]: 对话列表
|
||||
"""
|
||||
return await self.find(
|
||||
{"config_id": config_id, "group_id": group_id},
|
||||
{"config_id": config_id, "end_user_id": end_user_id},
|
||||
limit=limit
|
||||
)
|
||||
|
||||
@@ -40,7 +40,7 @@ class EmotionRepository:
|
||||
|
||||
async def get_emotion_tags(
|
||||
self,
|
||||
group_id: str,
|
||||
end_user_id: str,
|
||||
emotion_type: Optional[str] = None,
|
||||
start_date: Optional[str] = None,
|
||||
end_date: Optional[str] = None,
|
||||
@@ -51,7 +51,7 @@ class EmotionRepository:
|
||||
查询指定用户的情绪类型分布,包括计数、百分比和平均强度。
|
||||
|
||||
Args:
|
||||
group_id: 用户组ID(宿主ID)
|
||||
end_user_id: 用户组ID(宿主ID)
|
||||
emotion_type: 可选的情绪类型过滤(joy/sadness/anger/fear/surprise/neutral)
|
||||
start_date: 可选的开始日期(ISO格式字符串)
|
||||
end_date: 可选的结束日期(ISO格式字符串)
|
||||
@@ -65,8 +65,8 @@ class EmotionRepository:
|
||||
- avg_intensity: 平均强度
|
||||
"""
|
||||
# 构建查询条件
|
||||
where_clauses = ["s.group_id = $group_id", "s.emotion_type IS NOT NULL"]
|
||||
params = {"group_id": group_id, "limit": limit}
|
||||
where_clauses = ["s.end_user_id = $end_user_id", "s.emotion_type IS NOT NULL"]
|
||||
params = {"end_user_id": end_user_id, "limit": limit}
|
||||
|
||||
if emotion_type:
|
||||
where_clauses.append("s.emotion_type = $emotion_type")
|
||||
@@ -119,7 +119,7 @@ class EmotionRepository:
|
||||
|
||||
async def get_emotion_wordcloud(
|
||||
self,
|
||||
group_id: str,
|
||||
end_user_id: str,
|
||||
emotion_type: Optional[str] = None,
|
||||
limit: int = 50
|
||||
) -> List[Dict[str, Any]]:
|
||||
@@ -128,7 +128,7 @@ class EmotionRepository:
|
||||
查询情绪关键词及其频率,用于生成词云可视化。
|
||||
|
||||
Args:
|
||||
group_id: 用户组ID(宿主ID)
|
||||
end_user_id: 用户组ID(宿主ID)
|
||||
emotion_type: 可选的情绪类型过滤
|
||||
limit: 返回关键词的最大数量
|
||||
|
||||
@@ -140,8 +140,8 @@ class EmotionRepository:
|
||||
- avg_intensity: 平均强度
|
||||
"""
|
||||
# 构建查询条件
|
||||
where_clauses = ["s.group_id = $group_id", "s.emotion_keywords IS NOT NULL"]
|
||||
params = {"group_id": group_id, "limit": limit}
|
||||
where_clauses = ["s.end_user_id = $end_user_id", "s.emotion_keywords IS NOT NULL"]
|
||||
params = {"end_user_id": end_user_id, "limit": limit}
|
||||
|
||||
if emotion_type:
|
||||
where_clauses.append("s.emotion_type = $emotion_type")
|
||||
@@ -186,7 +186,7 @@ class EmotionRepository:
|
||||
|
||||
async def get_emotions_in_range(
|
||||
self,
|
||||
group_id: str,
|
||||
end_user_id: str,
|
||||
time_range: str = "30d"
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""获取时间范围内的情绪数据
|
||||
@@ -194,7 +194,7 @@ class EmotionRepository:
|
||||
查询指定时间范围内的所有情绪数据,用于健康指数计算。
|
||||
|
||||
Args:
|
||||
group_id: 用户组ID(宿主ID)
|
||||
end_user_id: 用户组ID(宿主ID)
|
||||
time_range: 时间范围(7d/30d/90d)
|
||||
|
||||
Returns:
|
||||
@@ -214,7 +214,7 @@ class EmotionRepository:
|
||||
# 优化的 Cypher 查询:使用字符串比较避免时区问题
|
||||
query = """
|
||||
MATCH (s:Statement)
|
||||
WHERE s.group_id = $group_id
|
||||
WHERE s.end_user_id = $end_user_id
|
||||
AND s.emotion_type IS NOT NULL
|
||||
AND s.created_at >= $start_date
|
||||
RETURN s.id as statement_id,
|
||||
@@ -227,7 +227,7 @@ class EmotionRepository:
|
||||
try:
|
||||
results = await self.connector.execute_query(
|
||||
query,
|
||||
group_id=group_id,
|
||||
end_user_id=end_user_id,
|
||||
start_date=start_date
|
||||
)
|
||||
formatted_results = [
|
||||
|
||||
@@ -44,9 +44,7 @@ async def save_entities_and_relationships(
|
||||
'created_at': edge.created_at.isoformat(),
|
||||
'expired_at': edge.expired_at.isoformat(),
|
||||
'run_id': edge.run_id,
|
||||
'group_id': edge.group_id,
|
||||
'user_id': edge.user_id,
|
||||
'apply_id': edge.apply_id,
|
||||
'end_user_id': edge.end_user_id,
|
||||
}
|
||||
all_relationships.append(relationship)
|
||||
|
||||
@@ -101,9 +99,7 @@ async def save_statement_chunk_edges(
|
||||
"id": edge.id,
|
||||
"source": edge.source,
|
||||
"target": edge.target,
|
||||
"group_id": edge.group_id,
|
||||
"user_id": edge.user_id,
|
||||
"apply_id": edge.apply_id,
|
||||
"end_user_id": edge.end_user_id,
|
||||
"run_id": edge.run_id,
|
||||
"created_at": edge.created_at.isoformat() if edge.created_at else None,
|
||||
"expired_at": edge.expired_at.isoformat() if edge.expired_at else None,
|
||||
@@ -132,9 +128,7 @@ async def save_statement_entity_edges(
|
||||
edge_data = {
|
||||
"source": edge.source,
|
||||
"target": edge.target,
|
||||
"group_id": edge.group_id,
|
||||
"user_id": edge.user_id,
|
||||
"apply_id": edge.apply_id,
|
||||
"end_user_id": edge.end_user_id,
|
||||
"run_id": edge.run_id,
|
||||
"connect_strength": edge.connect_strength,
|
||||
"created_at": edge.created_at.isoformat() if edge.created_at else None,
|
||||
|
||||
@@ -33,7 +33,7 @@ async def _update_activation_values_batch(
|
||||
connector: Neo4jConnector,
|
||||
nodes: List[Dict[str, Any]],
|
||||
node_label: str,
|
||||
group_id: Optional[str] = None,
|
||||
end_user_id: Optional[str] = None,
|
||||
max_retries: int = 3
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
@@ -46,7 +46,7 @@ async def _update_activation_values_batch(
|
||||
connector: Neo4j连接器
|
||||
nodes: 节点列表,每个节点必须包含 'id' 字段
|
||||
node_label: 节点标签(Statement, ExtractedEntity, MemorySummary)
|
||||
group_id: 组ID(可选)
|
||||
end_user_id: 组ID(可选)
|
||||
max_retries: 最大重试次数
|
||||
|
||||
Returns:
|
||||
@@ -97,7 +97,7 @@ async def _update_activation_values_batch(
|
||||
updated_nodes = await access_manager.record_batch_access(
|
||||
node_ids=unique_node_ids,
|
||||
node_label=node_label,
|
||||
group_id=group_id
|
||||
end_user_id=end_user_id
|
||||
)
|
||||
|
||||
logger.info(
|
||||
@@ -118,7 +118,7 @@ async def _update_activation_values_batch(
|
||||
async def _update_search_results_activation(
|
||||
connector: Neo4jConnector,
|
||||
results: Dict[str, List[Dict[str, Any]]],
|
||||
group_id: Optional[str] = None
|
||||
end_user_id: Optional[str] = None
|
||||
) -> Dict[str, List[Dict[str, Any]]]:
|
||||
"""
|
||||
更新搜索结果中所有知识节点的激活值
|
||||
@@ -129,7 +129,7 @@ async def _update_search_results_activation(
|
||||
Args:
|
||||
connector: Neo4j连接器
|
||||
results: 搜索结果字典,包含不同类型节点的列表
|
||||
group_id: 组ID(可选)
|
||||
end_user_id: 组ID(可选)
|
||||
|
||||
Returns:
|
||||
Dict[str, List[Dict[str, Any]]]: 更新后的搜索结果
|
||||
@@ -152,7 +152,7 @@ async def _update_search_results_activation(
|
||||
connector=connector,
|
||||
nodes=results[key],
|
||||
node_label=label,
|
||||
group_id=group_id
|
||||
end_user_id=end_user_id
|
||||
)
|
||||
)
|
||||
update_keys.append(key)
|
||||
@@ -218,7 +218,7 @@ async def _update_search_results_activation(
|
||||
async def search_graph(
|
||||
connector: Neo4jConnector,
|
||||
q: str,
|
||||
group_id: Optional[str] = None,
|
||||
end_user_id: Optional[str] = None,
|
||||
limit: int = 50,
|
||||
include: List[str] = None,
|
||||
) -> Dict[str, List[Dict[str, Any]]]:
|
||||
@@ -236,7 +236,7 @@ async def search_graph(
|
||||
Args:
|
||||
connector: Neo4j connector
|
||||
q: Query text
|
||||
group_id: Optional group filter
|
||||
end_user_id: Optional group filter
|
||||
limit: Max results per category
|
||||
include: List of categories to search (default: all)
|
||||
|
||||
@@ -254,7 +254,7 @@ async def search_graph(
|
||||
tasks.append(connector.execute_query(
|
||||
SEARCH_STATEMENTS_BY_KEYWORD,
|
||||
q=q,
|
||||
group_id=group_id,
|
||||
end_user_id=end_user_id,
|
||||
limit=limit,
|
||||
))
|
||||
task_keys.append("statements")
|
||||
@@ -263,7 +263,7 @@ async def search_graph(
|
||||
tasks.append(connector.execute_query(
|
||||
SEARCH_ENTITIES_BY_NAME,
|
||||
q=q,
|
||||
group_id=group_id,
|
||||
end_user_id=end_user_id,
|
||||
limit=limit,
|
||||
))
|
||||
task_keys.append("entities")
|
||||
@@ -272,7 +272,7 @@ async def search_graph(
|
||||
tasks.append(connector.execute_query(
|
||||
SEARCH_CHUNKS_BY_CONTENT,
|
||||
q=q,
|
||||
group_id=group_id,
|
||||
end_user_id=end_user_id,
|
||||
limit=limit,
|
||||
))
|
||||
task_keys.append("chunks")
|
||||
@@ -281,7 +281,7 @@ async def search_graph(
|
||||
tasks.append(connector.execute_query(
|
||||
SEARCH_MEMORY_SUMMARIES_BY_KEYWORD,
|
||||
q=q,
|
||||
group_id=group_id,
|
||||
end_user_id=end_user_id,
|
||||
limit=limit,
|
||||
))
|
||||
task_keys.append("summaries")
|
||||
@@ -310,12 +310,12 @@ async def search_graph(
|
||||
key in include and key in results and results[key]
|
||||
for key in ['statements', 'entities', 'chunks']
|
||||
)
|
||||
|
||||
|
||||
if needs_activation_update:
|
||||
results = await _update_search_results_activation(
|
||||
connector=connector,
|
||||
results=results,
|
||||
group_id=group_id
|
||||
end_user_id=end_user_id
|
||||
)
|
||||
|
||||
return results
|
||||
@@ -325,7 +325,7 @@ async def search_graph_by_embedding(
|
||||
connector: Neo4jConnector,
|
||||
embedder_client,
|
||||
query_text: str,
|
||||
group_id: Optional[str] = None,
|
||||
end_user_id: Optional[str] = None,
|
||||
limit: int = 50,
|
||||
include: List[str] = ["statements", "chunks", "entities","summaries"],
|
||||
) -> Dict[str, List[Dict[str, Any]]]:
|
||||
@@ -337,7 +337,7 @@ async def search_graph_by_embedding(
|
||||
|
||||
- Computes query embedding with the provided embedder_client
|
||||
- Ranks by cosine similarity in Cypher
|
||||
- Filters by group_id if provided
|
||||
- Filters by end_user_id if provided
|
||||
- Returns up to 'limit' per included type
|
||||
"""
|
||||
import time
|
||||
@@ -346,7 +346,7 @@ async def search_graph_by_embedding(
|
||||
embed_start = time.time()
|
||||
embeddings = await embedder_client.response([query_text])
|
||||
embed_time = time.time() - embed_start
|
||||
logger.info(f"[PERF] Embedding generation took: {embed_time:.4f}s")
|
||||
print(f"[PERF] Embedding generation took: {embed_time:.4f}s")
|
||||
|
||||
if not embeddings or not embeddings[0]:
|
||||
return {"statements": [], "chunks": [], "entities": [], "summaries": []}
|
||||
@@ -361,7 +361,7 @@ async def search_graph_by_embedding(
|
||||
tasks.append(connector.execute_query(
|
||||
STATEMENT_EMBEDDING_SEARCH,
|
||||
embedding=embedding,
|
||||
group_id=group_id,
|
||||
end_user_id=end_user_id,
|
||||
limit=limit,
|
||||
))
|
||||
task_keys.append("statements")
|
||||
@@ -371,7 +371,7 @@ async def search_graph_by_embedding(
|
||||
tasks.append(connector.execute_query(
|
||||
CHUNK_EMBEDDING_SEARCH,
|
||||
embedding=embedding,
|
||||
group_id=group_id,
|
||||
end_user_id=end_user_id,
|
||||
limit=limit,
|
||||
))
|
||||
task_keys.append("chunks")
|
||||
@@ -381,7 +381,7 @@ async def search_graph_by_embedding(
|
||||
tasks.append(connector.execute_query(
|
||||
ENTITY_EMBEDDING_SEARCH,
|
||||
embedding=embedding,
|
||||
group_id=group_id,
|
||||
end_user_id=end_user_id,
|
||||
limit=limit,
|
||||
))
|
||||
task_keys.append("entities")
|
||||
@@ -391,7 +391,7 @@ async def search_graph_by_embedding(
|
||||
tasks.append(connector.execute_query(
|
||||
MEMORY_SUMMARY_EMBEDDING_SEARCH,
|
||||
embedding=embedding,
|
||||
group_id=group_id,
|
||||
end_user_id=end_user_id,
|
||||
limit=limit,
|
||||
))
|
||||
task_keys.append("summaries")
|
||||
@@ -400,7 +400,7 @@ async def search_graph_by_embedding(
|
||||
query_start = time.time()
|
||||
task_results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
query_time = time.time() - query_start
|
||||
logger.info(f"[PERF] Neo4j queries (parallel) took: {query_time:.4f}s")
|
||||
print(f"[PERF] Neo4j queries (parallel) took: {query_time:.4f}s")
|
||||
|
||||
# Build results dictionary
|
||||
results: Dict[str, List[Dict[str, Any]]] = {
|
||||
@@ -429,13 +429,13 @@ async def search_graph_by_embedding(
|
||||
key in include and key in results and results[key]
|
||||
for key in ['statements', 'entities', 'chunks']
|
||||
)
|
||||
|
||||
|
||||
if needs_activation_update:
|
||||
update_start = time.time()
|
||||
results = await _update_search_results_activation(
|
||||
connector=connector,
|
||||
results=results,
|
||||
group_id=group_id
|
||||
end_user_id=end_user_id
|
||||
)
|
||||
update_time = time.time() - update_start
|
||||
logger.info(f"[PERF] Activation value updates took: {update_time:.4f}s")
|
||||
@@ -445,7 +445,7 @@ async def search_graph_by_embedding(
|
||||
return results
|
||||
async def get_dedup_candidates_for_entities( # 适配新版查询:使用全文索引按名称检索候选实体
|
||||
connector: Neo4jConnector,
|
||||
group_id: str,
|
||||
end_user_id: str,
|
||||
entities: List[Dict[str, Any]],
|
||||
use_contains_fallback: bool = True,
|
||||
batch_size: int = 500,
|
||||
@@ -453,7 +453,7 @@ async def get_dedup_candidates_for_entities( # 适配新版查询:使用全
|
||||
) -> Dict[str, List[Dict[str, Any]]]:
|
||||
"""
|
||||
为第二层去重消歧批量检索候选实体(适配新版 cypher_queries):
|
||||
- 使用全文索引查询 `SEARCH_ENTITIES_BY_NAME` 按 (group_id, name) 检索候选;
|
||||
- 使用全文索引查询 `SEARCH_ENTITIES_BY_NAME` 按 (end_user_id, name) 检索候选;
|
||||
- 保留并发控制与返回结构(incoming_id -> [db_entity_props...]);
|
||||
- 若提供 `entity_type`,在本地对返回结果做类型过滤;
|
||||
- `use_contains_fallback` 保留形参以兼容,必要时可扩展二次查询策略。
|
||||
@@ -477,7 +477,7 @@ async def get_dedup_candidates_for_entities( # 适配新版查询:使用全
|
||||
rows = await connector.execute_query(
|
||||
SEARCH_ENTITIES_BY_NAME,
|
||||
q=name,
|
||||
group_id=group_id,
|
||||
end_user_id=end_user_id,
|
||||
limit=100,
|
||||
)
|
||||
except Exception:
|
||||
@@ -501,7 +501,7 @@ async def get_dedup_candidates_for_entities( # 适配新版查询:使用全
|
||||
rows = await connector.execute_query(
|
||||
SEARCH_ENTITIES_BY_NAME,
|
||||
q=name.lower(),
|
||||
group_id=group_id,
|
||||
end_user_id=end_user_id,
|
||||
limit=100,
|
||||
)
|
||||
for r in rows:
|
||||
@@ -532,9 +532,7 @@ async def get_dedup_candidates_for_entities( # 适配新版查询:使用全
|
||||
async def search_graph_by_keyword_temporal(
|
||||
connector: Neo4jConnector,
|
||||
query_text: str,
|
||||
group_id: Optional[str] = None,
|
||||
apply_id: Optional[str] = None,
|
||||
user_id: Optional[str] = None,
|
||||
end_user_id: Optional[str] = None,
|
||||
start_date: Optional[str] = None,
|
||||
end_date: Optional[str] = None,
|
||||
valid_date: Optional[str] = None,
|
||||
@@ -547,32 +545,30 @@ async def search_graph_by_keyword_temporal(
|
||||
INTEGRATED: Updates activation values for Statement nodes before returning results
|
||||
|
||||
- Matches statements containing query_text created between start_date and end_date
|
||||
- Optionally filters by group_id, apply_id, user_id
|
||||
- Optionally filters by end_user_id, apply_id, user_id
|
||||
- Returns up to 'limit' statements
|
||||
"""
|
||||
if not query_text:
|
||||
logger.warning(f"query_text cannot be empty")
|
||||
print(f"query_text不能为空")
|
||||
return {"statements": []}
|
||||
statements = await connector.execute_query(
|
||||
SEARCH_STATEMENTS_BY_KEYWORD_TEMPORAL,
|
||||
q=query_text,
|
||||
group_id=group_id,
|
||||
apply_id=apply_id,
|
||||
user_id=user_id,
|
||||
end_user_id=end_user_id,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
valid_date=valid_date,
|
||||
invalid_date=invalid_date,
|
||||
limit=limit,
|
||||
)
|
||||
logger.debug(f"Temporal keyword search results: {len(statements)} statements found")
|
||||
print(f"查询结果为:\n{statements}")
|
||||
|
||||
# 更新 Statement 节点的激活值
|
||||
results = {"statements": statements}
|
||||
results = await _update_search_results_activation(
|
||||
connector=connector,
|
||||
results=results,
|
||||
group_id=group_id
|
||||
end_user_id=end_user_id
|
||||
)
|
||||
|
||||
return results
|
||||
@@ -580,9 +576,7 @@ async def search_graph_by_keyword_temporal(
|
||||
|
||||
async def search_graph_by_temporal(
|
||||
connector: Neo4jConnector,
|
||||
group_id: Optional[str] = None,
|
||||
apply_id: Optional[str] = None,
|
||||
user_id: Optional[str] = None,
|
||||
end_user_id: Optional[str] = None,
|
||||
start_date: Optional[str] = None,
|
||||
end_date: Optional[str] = None,
|
||||
valid_date: Optional[str] = None,
|
||||
@@ -595,14 +589,12 @@ async def search_graph_by_temporal(
|
||||
INTEGRATED: Updates activation values for Statement nodes before returning results
|
||||
|
||||
- Matches statements created between start_date and end_date
|
||||
- Optionally filters by group_id, apply_id, user_id
|
||||
- Optionally filters by end_user_id
|
||||
- Returns up to 'limit' statements
|
||||
"""
|
||||
statements = await connector.execute_query(
|
||||
SEARCH_STATEMENTS_BY_TEMPORAL,
|
||||
group_id=group_id,
|
||||
apply_id=apply_id,
|
||||
user_id=user_id,
|
||||
end_user_id=end_user_id,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
valid_date=valid_date,
|
||||
@@ -610,16 +602,16 @@ async def search_graph_by_temporal(
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
logger.debug(f"Temporal search query: {SEARCH_STATEMENTS_BY_TEMPORAL}")
|
||||
logger.debug(f"Query params: group_id={group_id}, apply_id={apply_id}, user_id={user_id}, start_date={start_date}, end_date={end_date}, valid_date={valid_date}, invalid_date={invalid_date}, limit={limit}")
|
||||
logger.debug(f"Temporal search results: {len(statements)} statements found")
|
||||
print(f"查询语句为:\n{SEARCH_STATEMENTS_BY_TEMPORAL}")
|
||||
print(f"查询参数为:\n{{end_user_id: {end_user_id}, start_date: {start_date}, end_date: {end_date}, valid_date: {valid_date}, invalid_date: {invalid_date}, limit: {limit}}}")
|
||||
print(f"查询结果为:\n{statements}")
|
||||
|
||||
# 更新 Statement 节点的激活值
|
||||
results = {"statements": statements}
|
||||
results = await _update_search_results_activation(
|
||||
connector=connector,
|
||||
results=results,
|
||||
group_id=group_id
|
||||
end_user_id=end_user_id
|
||||
)
|
||||
|
||||
return results
|
||||
@@ -628,23 +620,23 @@ async def search_graph_by_temporal(
|
||||
async def search_graph_by_dialog_id(
|
||||
connector: Neo4jConnector,
|
||||
dialog_id: str,
|
||||
group_id: Optional[str] = None,
|
||||
end_user_id: Optional[str] = None,
|
||||
limit: int = 1,
|
||||
) -> Dict[str, List[Dict[str, Any]]]:
|
||||
"""
|
||||
Temporal search across Dialogues.
|
||||
|
||||
- Matches dialogues with dialog_id
|
||||
- Optionally filters by group_id
|
||||
- Optionally filters by end_user_id
|
||||
- Returns up to 'limit' dialogues
|
||||
"""
|
||||
if not dialog_id:
|
||||
logger.warning(f"dialog_id cannot be empty")
|
||||
print(f"dialog_id不能为空")
|
||||
return {"dialogues": []}
|
||||
|
||||
dialogues = await connector.execute_query(
|
||||
SEARCH_DIALOGUE_BY_DIALOG_ID,
|
||||
group_id=group_id,
|
||||
end_user_id=end_user_id,
|
||||
dialog_id=dialog_id,
|
||||
limit=limit,
|
||||
)
|
||||
@@ -654,15 +646,15 @@ async def search_graph_by_dialog_id(
|
||||
async def search_graph_by_chunk_id(
|
||||
connector: Neo4jConnector,
|
||||
chunk_id : str,
|
||||
group_id: Optional[str] = None,
|
||||
end_user_id: Optional[str] = None,
|
||||
limit: int = 1,
|
||||
) -> Dict[str, List[Dict[str, Any]]]:
|
||||
if not chunk_id:
|
||||
logger.warning(f"chunk_id cannot be empty")
|
||||
print(f"chunk_id不能为空")
|
||||
return {"chunks": []}
|
||||
chunks = await connector.execute_query(
|
||||
SEARCH_CHUNK_BY_CHUNK_ID,
|
||||
group_id=group_id,
|
||||
end_user_id=end_user_id,
|
||||
chunk_id=chunk_id,
|
||||
limit=limit,
|
||||
)
|
||||
@@ -671,9 +663,9 @@ async def search_graph_by_chunk_id(
|
||||
|
||||
async def search_graph_by_created_at(
|
||||
connector: Neo4jConnector,
|
||||
group_id: Optional[str] = None,
|
||||
apply_id: Optional[str] = None,
|
||||
user_id: Optional[str] = None,
|
||||
end_user_id: Optional[str] = None,
|
||||
|
||||
|
||||
created_at: Optional[str] = None,
|
||||
limit: int = 1,
|
||||
) -> Dict[str, List[Dict[str, Any]]]:
|
||||
@@ -683,37 +675,37 @@ async def search_graph_by_created_at(
|
||||
INTEGRATED: Updates activation values for Statement nodes before returning results
|
||||
|
||||
- Matches statements created at created_at
|
||||
- Optionally filters by group_id, apply_id, user_id
|
||||
- Optionally filters by end_user_id, apply_id, user_id
|
||||
- Returns up to 'limit' statements
|
||||
"""
|
||||
statements = await connector.execute_query(
|
||||
SEARCH_STATEMENTS_BY_CREATED_AT,
|
||||
group_id=group_id,
|
||||
apply_id=apply_id,
|
||||
user_id=user_id,
|
||||
end_user_id=end_user_id,
|
||||
|
||||
|
||||
created_at=created_at,
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
logger.debug(f"Search by created_at query: {SEARCH_STATEMENTS_BY_CREATED_AT}")
|
||||
logger.debug(f"Query params: group_id={group_id}, apply_id={apply_id}, user_id={user_id}, created_at={created_at}, limit={limit}")
|
||||
logger.debug(f"Search results: {len(statements)} statements found")
|
||||
print(f"查询语句为:\n{SEARCH_STATEMENTS_BY_CREATED_AT}")
|
||||
print(f"查询参数为:\n{{end_user_id: {end_user_id} created_at: {created_at}, limit: {limit}}}")
|
||||
print(f"查询结果为:\n{statements}")
|
||||
|
||||
# 更新 Statement 节点的激活值
|
||||
results = {"statements": statements}
|
||||
results = await _update_search_results_activation(
|
||||
connector=connector,
|
||||
results=results,
|
||||
group_id=group_id
|
||||
end_user_id=end_user_id
|
||||
)
|
||||
|
||||
return results
|
||||
|
||||
async def search_graph_by_valid_at(
|
||||
connector: Neo4jConnector,
|
||||
group_id: Optional[str] = None,
|
||||
apply_id: Optional[str] = None,
|
||||
user_id: Optional[str] = None,
|
||||
end_user_id: Optional[str] = None,
|
||||
|
||||
|
||||
valid_at: Optional[str] = None,
|
||||
limit: int = 1,
|
||||
) -> Dict[str, List[Dict[str, Any]]]:
|
||||
@@ -723,37 +715,37 @@ async def search_graph_by_valid_at(
|
||||
INTEGRATED: Updates activation values for Statement nodes before returning results
|
||||
|
||||
- Matches statements valid at valid_at
|
||||
- Optionally filters by group_id, apply_id, user_id
|
||||
- Optionally filters by end_user_id, apply_id, user_id
|
||||
- Returns up to 'limit' statements
|
||||
"""
|
||||
statements = await connector.execute_query(
|
||||
SEARCH_STATEMENTS_BY_VALID_AT,
|
||||
group_id=group_id,
|
||||
apply_id=apply_id,
|
||||
user_id=user_id,
|
||||
end_user_id=end_user_id,
|
||||
|
||||
|
||||
valid_at=valid_at,
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
logger.debug(f"Search by valid_at query: {SEARCH_STATEMENTS_BY_VALID_AT}")
|
||||
logger.debug(f"Query params: group_id={group_id}, apply_id={apply_id}, user_id={user_id}, valid_at={valid_at}, limit={limit}")
|
||||
logger.debug(f"Search results: {len(statements)} statements found")
|
||||
print(f"查询语句为:\n{SEARCH_STATEMENTS_BY_VALID_AT}")
|
||||
print(f"查询参数为:\n{{end_user_id: {end_user_id}, valid_at: {valid_at}, limit: {limit}}}")
|
||||
print(f"查询结果为:\n{statements}")
|
||||
|
||||
# 更新 Statement 节点的激活值
|
||||
results = {"statements": statements}
|
||||
results = await _update_search_results_activation(
|
||||
connector=connector,
|
||||
results=results,
|
||||
group_id=group_id
|
||||
end_user_id=end_user_id
|
||||
)
|
||||
|
||||
return results
|
||||
|
||||
async def search_graph_g_created_at(
|
||||
connector: Neo4jConnector,
|
||||
group_id: Optional[str] = None,
|
||||
apply_id: Optional[str] = None,
|
||||
user_id: Optional[str] = None,
|
||||
end_user_id: Optional[str] = None,
|
||||
|
||||
|
||||
created_at: Optional[str] = None,
|
||||
limit: int = 1,
|
||||
) -> Dict[str, List[Dict[str, Any]]]:
|
||||
@@ -763,37 +755,37 @@ async def search_graph_g_created_at(
|
||||
INTEGRATED: Updates activation values for Statement nodes before returning results
|
||||
|
||||
- Matches statements created at created_at
|
||||
- Optionally filters by group_id, apply_id, user_id
|
||||
- Optionally filters by end_user_id, apply_id, user_id
|
||||
- Returns up to 'limit' statements
|
||||
"""
|
||||
statements = await connector.execute_query(
|
||||
SEARCH_STATEMENTS_G_CREATED_AT,
|
||||
group_id=group_id,
|
||||
apply_id=apply_id,
|
||||
user_id=user_id,
|
||||
end_user_id=end_user_id,
|
||||
|
||||
|
||||
created_at=created_at,
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
logger.debug(f"Search greater than created_at query: {SEARCH_STATEMENTS_G_CREATED_AT}")
|
||||
logger.debug(f"Query params: group_id={group_id}, apply_id={apply_id}, user_id={user_id}, created_at={created_at}, limit={limit}")
|
||||
logger.debug(f"Search results: {len(statements)} statements found")
|
||||
print(f"查询语句为:\n{SEARCH_STATEMENTS_G_CREATED_AT}")
|
||||
print(f"查询参数为:\n{{end_user_id: {end_user_id}, created_at: {created_at}, limit: {limit}}}")
|
||||
print(f"查询结果为:\n{statements}")
|
||||
|
||||
# 更新 Statement 节点的激活值
|
||||
results = {"statements": statements}
|
||||
results = await _update_search_results_activation(
|
||||
connector=connector,
|
||||
results=results,
|
||||
group_id=group_id
|
||||
end_user_id=end_user_id
|
||||
)
|
||||
|
||||
return results
|
||||
|
||||
async def search_graph_g_valid_at(
|
||||
connector: Neo4jConnector,
|
||||
group_id: Optional[str] = None,
|
||||
apply_id: Optional[str] = None,
|
||||
user_id: Optional[str] = None,
|
||||
end_user_id: Optional[str] = None,
|
||||
|
||||
|
||||
valid_at: Optional[str] = None,
|
||||
limit: int = 1,
|
||||
) -> Dict[str, List[Dict[str, Any]]]:
|
||||
@@ -803,37 +795,37 @@ async def search_graph_g_valid_at(
|
||||
INTEGRATED: Updates activation values for Statement nodes before returning results
|
||||
|
||||
- Matches statements valid at valid_at
|
||||
- Optionally filters by group_id, apply_id, user_id
|
||||
- Optionally filters by end_user_id, apply_id, user_id
|
||||
- Returns up to 'limit' statements
|
||||
"""
|
||||
statements = await connector.execute_query(
|
||||
SEARCH_STATEMENTS_G_VALID_AT,
|
||||
group_id=group_id,
|
||||
apply_id=apply_id,
|
||||
user_id=user_id,
|
||||
end_user_id=end_user_id,
|
||||
|
||||
|
||||
valid_at=valid_at,
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
logger.debug(f"Search greater than valid_at query: {SEARCH_STATEMENTS_G_VALID_AT}")
|
||||
logger.debug(f"Query params: group_id={group_id}, apply_id={apply_id}, user_id={user_id}, valid_at={valid_at}, limit={limit}")
|
||||
logger.debug(f"Search results: {len(statements)} statements found")
|
||||
print(f"查询语句为:\n{SEARCH_STATEMENTS_G_VALID_AT}")
|
||||
print(f"查询参数为:\n{{end_user_id: {end_user_id}, valid_at: {valid_at}, limit: {limit}}}")
|
||||
print(f"查询结果为:\n{statements}")
|
||||
|
||||
# 更新 Statement 节点的激活值
|
||||
results = {"statements": statements}
|
||||
results = await _update_search_results_activation(
|
||||
connector=connector,
|
||||
results=results,
|
||||
group_id=group_id
|
||||
end_user_id=end_user_id
|
||||
)
|
||||
|
||||
return results
|
||||
|
||||
async def search_graph_l_created_at(
|
||||
connector: Neo4jConnector,
|
||||
group_id: Optional[str] = None,
|
||||
apply_id: Optional[str] = None,
|
||||
user_id: Optional[str] = None,
|
||||
end_user_id: Optional[str] = None,
|
||||
|
||||
|
||||
created_at: Optional[str] = None,
|
||||
limit: int = 1,
|
||||
) -> Dict[str, List[Dict[str, Any]]]:
|
||||
@@ -843,37 +835,37 @@ async def search_graph_l_created_at(
|
||||
INTEGRATED: Updates activation values for Statement nodes before returning results
|
||||
|
||||
- Matches statements created at created_at
|
||||
- Optionally filters by group_id, apply_id, user_id
|
||||
- Optionally filters by end_user_id, apply_id, user_id
|
||||
- Returns up to 'limit' statements
|
||||
"""
|
||||
statements = await connector.execute_query(
|
||||
SEARCH_STATEMENTS_L_CREATED_AT,
|
||||
group_id=group_id,
|
||||
apply_id=apply_id,
|
||||
user_id=user_id,
|
||||
end_user_id=end_user_id,
|
||||
|
||||
|
||||
created_at=created_at,
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
logger.debug(f"Search less than created_at query: {SEARCH_STATEMENTS_L_CREATED_AT}")
|
||||
logger.debug(f"Query params: group_id={group_id}, apply_id={apply_id}, user_id={user_id}, created_at={created_at}, limit={limit}")
|
||||
logger.debug(f"Search results: {len(statements)} statements found")
|
||||
print(f"查询语句为:\n{SEARCH_STATEMENTS_L_CREATED_AT}")
|
||||
print(f"查询参数为:\n{{end_user_id: {end_user_id}, created_at: {created_at}, limit: {limit}}}")
|
||||
print(f"查询结果为:\n{statements}")
|
||||
|
||||
# 更新 Statement 节点的激活值
|
||||
results = {"statements": statements}
|
||||
results = await _update_search_results_activation(
|
||||
connector=connector,
|
||||
results=results,
|
||||
group_id=group_id
|
||||
end_user_id=end_user_id
|
||||
)
|
||||
|
||||
return results
|
||||
|
||||
async def search_graph_l_valid_at(
|
||||
connector: Neo4jConnector,
|
||||
group_id: Optional[str] = None,
|
||||
apply_id: Optional[str] = None,
|
||||
user_id: Optional[str] = None,
|
||||
end_user_id: Optional[str] = None,
|
||||
|
||||
|
||||
valid_at: Optional[str] = None,
|
||||
limit: int = 1,
|
||||
) -> Dict[str, List[Dict[str, Any]]]:
|
||||
@@ -883,28 +875,28 @@ async def search_graph_l_valid_at(
|
||||
INTEGRATED: Updates activation values for Statement nodes before returning results
|
||||
|
||||
- Matches statements valid at valid_at
|
||||
- Optionally filters by group_id, apply_id, user_id
|
||||
- Optionally filters by end_user_id, apply_id, user_id
|
||||
- Returns up to 'limit' statements
|
||||
"""
|
||||
statements = await connector.execute_query(
|
||||
SEARCH_STATEMENTS_L_VALID_AT,
|
||||
group_id=group_id,
|
||||
apply_id=apply_id,
|
||||
user_id=user_id,
|
||||
end_user_id=end_user_id,
|
||||
|
||||
|
||||
valid_at=valid_at,
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
logger.debug(f"Search less than valid_at query: {SEARCH_STATEMENTS_L_VALID_AT}")
|
||||
logger.debug(f"Query params: group_id={group_id}, apply_id={apply_id}, user_id={user_id}, valid_at={valid_at}, limit={limit}")
|
||||
logger.debug(f"Search results: {len(statements)} statements found")
|
||||
print(f"查询语句为:\n{SEARCH_STATEMENTS_L_VALID_AT}")
|
||||
print(f"查询参数为:\n{{end_user_id: {end_user_id}, valid_at: {valid_at}, limit: {limit}}}")
|
||||
print(f"查询结果为:\n{statements}")
|
||||
|
||||
# 更新 Statement 节点的激活值
|
||||
results = {"statements": statements}
|
||||
results = await _update_search_results_activation(
|
||||
connector=connector,
|
||||
results=results,
|
||||
group_id=group_id
|
||||
end_user_id=end_user_id
|
||||
)
|
||||
|
||||
return results
|
||||
|
||||
@@ -18,7 +18,7 @@ class MemorySummaryRepository(BaseNeo4jRepository):
|
||||
"""Memory Summary Repository
|
||||
|
||||
Manages CRUD operations for MemorySummary nodes.
|
||||
Provides methods to query summaries by group_id, user_id, and time ranges.
|
||||
Provides methods to query summaries by end_user_id, user_id, and time ranges.
|
||||
|
||||
Attributes:
|
||||
connector: Neo4j connector instance
|
||||
@@ -51,17 +51,17 @@ class MemorySummaryRepository(BaseNeo4jRepository):
|
||||
|
||||
return dict(n)
|
||||
|
||||
async def find_by_group_id(
|
||||
async def find_by_end_user_id(
|
||||
self,
|
||||
group_id: str,
|
||||
end_user_id: str,
|
||||
limit: int = 1000,
|
||||
start_date: Optional[datetime] = None,
|
||||
end_date: Optional[datetime] = None
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Query memory summaries by group_id
|
||||
"""Query memory summaries by end_user_id
|
||||
|
||||
Args:
|
||||
group_id: Group ID to filter by
|
||||
end_user_id: Group ID to filter by
|
||||
limit: Maximum number of results to return
|
||||
start_date: Optional start date filter
|
||||
end_date: Optional end date filter
|
||||
@@ -71,10 +71,10 @@ class MemorySummaryRepository(BaseNeo4jRepository):
|
||||
"""
|
||||
query = f"""
|
||||
MATCH (n:{self.node_label})
|
||||
WHERE n.group_id = $group_id
|
||||
WHERE n.end_user_id = $end_user_id
|
||||
"""
|
||||
|
||||
params = {"group_id": group_id, "limit": limit}
|
||||
params = {"end_user_id": end_user_id, "limit": limit}
|
||||
|
||||
# Add date range filters if provided
|
||||
if start_date:
|
||||
@@ -139,16 +139,16 @@ class MemorySummaryRepository(BaseNeo4jRepository):
|
||||
|
||||
async def find_by_group_and_user(
|
||||
self,
|
||||
group_id: str,
|
||||
end_user_id: str,
|
||||
user_id: str,
|
||||
limit: int = 1000,
|
||||
start_date: Optional[datetime] = None,
|
||||
end_date: Optional[datetime] = None
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Query memory summaries by both group_id and user_id
|
||||
"""Query memory summaries by both end_user_id and user_id
|
||||
|
||||
Args:
|
||||
group_id: Group ID to filter by
|
||||
end_user_id: Group ID to filter by
|
||||
user_id: User ID to filter by
|
||||
limit: Maximum number of results to return
|
||||
start_date: Optional start date filter
|
||||
@@ -159,10 +159,10 @@ class MemorySummaryRepository(BaseNeo4jRepository):
|
||||
"""
|
||||
query = f"""
|
||||
MATCH (n:{self.node_label})
|
||||
WHERE n.group_id = $group_id AND n.user_id = $user_id
|
||||
WHERE n.end_user_id = $end_user_id AND n.user_id = $user_id
|
||||
"""
|
||||
|
||||
params = {"group_id": group_id, "user_id": user_id, "limit": limit}
|
||||
params = {"end_user_id": end_user_id, "user_id": user_id, "limit": limit}
|
||||
|
||||
# Add date range filters if provided
|
||||
if start_date:
|
||||
@@ -184,14 +184,14 @@ class MemorySummaryRepository(BaseNeo4jRepository):
|
||||
|
||||
async def find_recent_summaries(
|
||||
self,
|
||||
group_id: str,
|
||||
end_user_id: str,
|
||||
days: int = 7,
|
||||
limit: int = 1000
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Query recent memory summaries
|
||||
|
||||
Args:
|
||||
group_id: Group ID to filter by
|
||||
end_user_id: Group ID to filter by
|
||||
days: Number of recent days to query
|
||||
limit: Maximum number of results to return
|
||||
|
||||
@@ -200,7 +200,7 @@ class MemorySummaryRepository(BaseNeo4jRepository):
|
||||
"""
|
||||
query = f"""
|
||||
MATCH (n:{self.node_label})
|
||||
WHERE n.group_id = $group_id
|
||||
WHERE n.end_user_id = $end_user_id
|
||||
AND n.created_at >= datetime() - duration({{days: $days}})
|
||||
RETURN n
|
||||
ORDER BY n.created_at DESC
|
||||
@@ -209,7 +209,7 @@ class MemorySummaryRepository(BaseNeo4jRepository):
|
||||
|
||||
results = await self.connector.execute_query(
|
||||
query,
|
||||
group_id=group_id,
|
||||
end_user_id=end_user_id,
|
||||
days=days,
|
||||
limit=limit
|
||||
)
|
||||
@@ -217,14 +217,14 @@ class MemorySummaryRepository(BaseNeo4jRepository):
|
||||
|
||||
async def find_by_content_keywords(
|
||||
self,
|
||||
group_id: str,
|
||||
end_user_id: str,
|
||||
keywords: List[str],
|
||||
limit: int = 100
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Query memory summaries by content keywords
|
||||
|
||||
Args:
|
||||
group_id: Group ID to filter by
|
||||
end_user_id: Group ID to filter by
|
||||
keywords: List of keywords to search for in content
|
||||
limit: Maximum number of results to return
|
||||
|
||||
@@ -233,7 +233,7 @@ class MemorySummaryRepository(BaseNeo4jRepository):
|
||||
"""
|
||||
# Build keyword search conditions
|
||||
keyword_conditions = []
|
||||
params = {"group_id": group_id, "limit": limit}
|
||||
params = {"end_user_id": end_user_id, "limit": limit}
|
||||
|
||||
for i, keyword in enumerate(keywords):
|
||||
keyword_conditions.append(f"toLower(n.content) CONTAINS toLower($keyword_{i})")
|
||||
@@ -243,7 +243,7 @@ class MemorySummaryRepository(BaseNeo4jRepository):
|
||||
|
||||
query = f"""
|
||||
MATCH (n:{self.node_label})
|
||||
WHERE n.group_id = $group_id
|
||||
WHERE n.end_user_id = $end_user_id
|
||||
AND ({keyword_filter})
|
||||
RETURN n
|
||||
ORDER BY n.created_at DESC
|
||||
@@ -253,21 +253,21 @@ class MemorySummaryRepository(BaseNeo4jRepository):
|
||||
results = await self.connector.execute_query(query, **params)
|
||||
return [self._map_to_dict(r) for r in results]
|
||||
|
||||
async def get_summary_count_by_group(self, group_id: str) -> int:
|
||||
async def get_summary_count_by_group(self, end_user_id: str) -> int:
|
||||
"""Get count of memory summaries for a group
|
||||
|
||||
Args:
|
||||
group_id: Group ID to count summaries for
|
||||
end_user_id: Group ID to count summaries for
|
||||
|
||||
Returns:
|
||||
int: Number of memory summaries
|
||||
"""
|
||||
query = f"""
|
||||
MATCH (n:{self.node_label})
|
||||
WHERE n.group_id = $group_id
|
||||
WHERE n.end_user_id = $end_user_id
|
||||
RETURN count(n) as count
|
||||
"""
|
||||
|
||||
results = await self.connector.execute_query(query, group_id=group_id)
|
||||
results = await self.connector.execute_query(query, end_user_id=end_user_id)
|
||||
return results[0]['count'] if results else 0
|
||||
|
||||
@@ -70,11 +70,7 @@ class Neo4jConnector:
|
||||
List[Dict[str, Any]]: 查询结果列表,每个元素是一个字典
|
||||
|
||||
Example:
|
||||
>>> connector = Neo4jConnector()
|
||||
>>> results = await connector.execute_query(
|
||||
... "MATCH (n:Person {name: $name}) RETURN n",
|
||||
... name="Alice"
|
||||
... )
|
||||
|
||||
"""
|
||||
result = await self.driver.execute_query(
|
||||
query,
|
||||
@@ -98,17 +94,7 @@ class Neo4jConnector:
|
||||
Any: 事务函数的返回值
|
||||
|
||||
Example:
|
||||
>>> async def create_node(tx, name):
|
||||
... result = await tx.run(
|
||||
... "CREATE (n:Person {name: $name}) RETURN n",
|
||||
... name=name
|
||||
... )
|
||||
... return await result.single()
|
||||
>>>
|
||||
>>> connector = Neo4jConnector()
|
||||
>>> result = await connector.execute_write_transaction(
|
||||
... create_node, name="Alice"
|
||||
... )
|
||||
|
||||
"""
|
||||
async with self.driver.session(database="neo4j") as session:
|
||||
return await session.execute_write(transaction_func, **kwargs)
|
||||
@@ -126,45 +112,33 @@ class Neo4jConnector:
|
||||
Any: 事务函数的返回值
|
||||
|
||||
Example:
|
||||
>>> async def get_node(tx, name):
|
||||
... result = await tx.run(
|
||||
... "MATCH (n:Person {name: $name}) RETURN n",
|
||||
... name=name
|
||||
... )
|
||||
... return await result.single()
|
||||
>>>
|
||||
>>> connector = Neo4jConnector()
|
||||
>>> result = await connector.execute_read_transaction(
|
||||
... get_node, name="Alice"
|
||||
... )
|
||||
|
||||
"""
|
||||
async with self.driver.session(database="neo4j") as session:
|
||||
return await session.execute_read(transaction_func, **kwargs)
|
||||
|
||||
async def delete_group(self, group_id: str):
|
||||
async def delete_group(self, end_user_id: str):
|
||||
"""删除指定组的所有数据
|
||||
|
||||
删除所有属于指定group_id的节点和边。
|
||||
删除所有属于指定end_user_id的节点和边。
|
||||
这是一个危险操作,会永久删除数据。
|
||||
|
||||
Args:
|
||||
group_id: 要删除的组ID
|
||||
end_user_id: 要删除的组ID
|
||||
|
||||
Example:
|
||||
>>> connector = Neo4jConnector()
|
||||
>>> await connector.delete_group("group_123")
|
||||
Group group_123 deleted.
|
||||
"""
|
||||
# 删除节点(DETACH DELETE会同时删除相关的边)
|
||||
await self.driver.execute_query(
|
||||
"MATCH (n) WHERE n.group_id = $group_id DETACH DELETE n",
|
||||
"MATCH (n) WHERE n.end_user_id = $end_user_id DETACH DELETE n",
|
||||
database="neo4j",
|
||||
group_id=group_id
|
||||
end_user_id=end_user_id
|
||||
)
|
||||
# 删除独立的边(如果有的话)
|
||||
await self.driver.execute_query(
|
||||
"MATCH ()-[r]->() WHERE r.group_id = $group_id DELETE r",
|
||||
"MATCH ()-[r]->() WHERE r.end_user_id = $end_user_id DELETE r",
|
||||
database="neo4j",
|
||||
group_id=group_id
|
||||
end_user_id=end_user_id
|
||||
)
|
||||
print(f"Group {group_id} deleted.")
|
||||
print(f"Group {end_user_id} deleted.")
|
||||
|
||||
@@ -20,7 +20,7 @@ class StatementRepository(BaseNeo4jRepository[StatementNode]):
|
||||
"""陈述句仓储
|
||||
|
||||
管理陈述句节点的创建、查询、更新和删除操作。
|
||||
提供按chunk_id、group_id、向量相似度等条件查询陈述句的方法。
|
||||
提供按chunk_id、end_user_id、向量相似度等条件查询陈述句的方法。
|
||||
|
||||
Attributes:
|
||||
connector: Neo4j连接器实例
|
||||
|
||||
@@ -299,6 +299,18 @@ class AppRelease(BaseModel):
|
||||
created_at: datetime.datetime
|
||||
updated_at: datetime.datetime
|
||||
|
||||
@field_validator("config", mode="before")
|
||||
@classmethod
|
||||
def parse_config(cls, v):
|
||||
"""处理 config 字段,如果是字符串则解析为字典"""
|
||||
if isinstance(v, str):
|
||||
import json
|
||||
try:
|
||||
return json.loads(v)
|
||||
except json.JSONDecodeError:
|
||||
return {}
|
||||
return v if v is not None else {}
|
||||
|
||||
@field_serializer("created_at", when_used="json")
|
||||
def _serialize_created_at(self, dt: datetime.datetime):
|
||||
return int(dt.timestamp() * 1000) if dt else None
|
||||
|
||||
@@ -1,11 +1,12 @@
|
||||
"""情绪分析相关的请求和响应模型"""
|
||||
|
||||
from typing import Optional
|
||||
from uuid import UUID
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
class EmotionTagsRequest(BaseModel):
|
||||
"""获取情绪标签统计请求"""
|
||||
group_id: str = Field(..., description="组ID")
|
||||
end_user_id: str = Field(..., description="组ID")
|
||||
emotion_type: Optional[str] = Field(None, description="情绪类型过滤(joy/sadness/anger/fear/surprise/neutral)")
|
||||
start_date: Optional[str] = Field(None, description="开始日期(ISO格式,如:2024-01-01)")
|
||||
end_date: Optional[str] = Field(None, description="结束日期(ISO格式,如:2024-12-31)")
|
||||
@@ -14,14 +15,14 @@ class EmotionTagsRequest(BaseModel):
|
||||
|
||||
class EmotionWordcloudRequest(BaseModel):
|
||||
"""获取情绪词云数据请求"""
|
||||
group_id: str = Field(..., description="组ID")
|
||||
end_user_id: str = Field(..., description="组ID")
|
||||
emotion_type: Optional[str] = Field(None, description="情绪类型过滤(joy/sadness/anger/fear/surprise/neutral)")
|
||||
limit: int = Field(50, ge=1, le=200, description="返回词语数量")
|
||||
|
||||
|
||||
class EmotionHealthRequest(BaseModel):
|
||||
"""获取情绪健康指数请求"""
|
||||
group_id: str = Field(..., description="组ID")
|
||||
end_user_id: str = Field(..., description="组ID")
|
||||
time_range: str = Field("30d", description="时间范围(7d/30d/90d)")
|
||||
|
||||
|
||||
@@ -29,8 +30,8 @@ class EmotionHealthRequest(BaseModel):
|
||||
|
||||
class EmotionSuggestionsRequest(BaseModel):
|
||||
"""获取个性化情绪建议请求"""
|
||||
group_id: str = Field(..., description="组ID")
|
||||
config_id: Optional[int] = Field(None, description="配置ID(用于指定LLM模型)")
|
||||
end_user_id: str = Field(..., description="组ID")
|
||||
config_id: Optional[UUID] = Field(None, description="配置ID(用于指定LLM模型)")
|
||||
|
||||
|
||||
class EmotionGenerateSuggestionsRequest(BaseModel):
|
||||
|
||||
@@ -7,11 +7,11 @@ class UserInput(BaseModel):
|
||||
message: str
|
||||
history: list[dict]
|
||||
search_switch: str
|
||||
group_id: str
|
||||
end_user_id: str
|
||||
config_id: Optional[str] = None
|
||||
|
||||
|
||||
class Write_UserInput(BaseModel):
|
||||
messages: list[dict]
|
||||
group_id: str
|
||||
config_id: Optional[str] = None
|
||||
end_user_id: str
|
||||
config_id: Optional[str] = None
|
||||
@@ -35,7 +35,7 @@ class ConfigurationError(Exception):
|
||||
def __init__(
|
||||
self,
|
||||
message: str,
|
||||
config_id: Optional[int] = None,
|
||||
config_id: Optional[UUID] = None,
|
||||
workspace_id: Optional[UUID] = None,
|
||||
context: Optional[Dict[str, Any]] = None,
|
||||
):
|
||||
@@ -72,7 +72,7 @@ class WorkspaceNotFoundError(ConfigurationError):
|
||||
def __init__(
|
||||
self,
|
||||
workspace_id: UUID,
|
||||
config_id: Optional[int] = None,
|
||||
config_id: Optional[UUID] = None,
|
||||
message: Optional[str] = None,
|
||||
):
|
||||
if message is None:
|
||||
@@ -89,7 +89,7 @@ class ModelNotFoundError(ConfigurationError):
|
||||
self,
|
||||
model_id: Union[str, UUID],
|
||||
model_type: str,
|
||||
config_id: Optional[int] = None,
|
||||
config_id: Optional[UUID] = None,
|
||||
workspace_id: Optional[UUID] = None,
|
||||
message: Optional[str] = None,
|
||||
):
|
||||
@@ -112,7 +112,7 @@ class ModelInactiveError(ConfigurationError):
|
||||
model_id: Union[str, UUID],
|
||||
model_name: str,
|
||||
model_type: str,
|
||||
config_id: Optional[int] = None,
|
||||
config_id: Optional[UUID] = None,
|
||||
workspace_id: Optional[UUID] = None,
|
||||
message: Optional[str] = None,
|
||||
):
|
||||
@@ -136,7 +136,7 @@ class InvalidConfigError(ConfigurationError):
|
||||
message: str,
|
||||
field_name: Optional[str] = None,
|
||||
invalid_value: Optional[Any] = None,
|
||||
config_id: Optional[int] = None,
|
||||
config_id: Optional[UUID] = None,
|
||||
workspace_id: Optional[UUID] = None,
|
||||
):
|
||||
context = {}
|
||||
@@ -155,7 +155,7 @@ class InvalidConfigError(ConfigurationError):
|
||||
class MemoryConfigValidation(BaseModel):
|
||||
"""Pydantic model for validating memory configuration data from database."""
|
||||
|
||||
config_id: int = Field(..., gt=0, description="Configuration ID must be positive")
|
||||
config_id: UUID = Field(..., description="Configuration ID (UUID)")
|
||||
config_name: str = Field(..., min_length=1, max_length=255)
|
||||
workspace_id: UUID = Field(..., description="Workspace UUID")
|
||||
workspace_name: str = Field(..., min_length=1, max_length=255)
|
||||
@@ -275,7 +275,7 @@ class ModelValidation(BaseModel):
|
||||
|
||||
|
||||
def validate_memory_config_data(
|
||||
config_data: Dict[str, Any], config_id: Optional[int] = None
|
||||
config_data: Dict[str, Any], config_id: Optional[UUID] = None
|
||||
) -> MemoryConfigValidation:
|
||||
"""Validate memory configuration data using Pydantic model."""
|
||||
try:
|
||||
@@ -302,7 +302,7 @@ def validate_memory_config_data(
|
||||
|
||||
|
||||
def validate_workspace_data(
|
||||
workspace_data: Dict[str, Any], config_id: Optional[int] = None
|
||||
workspace_data: Dict[str, Any], config_id: Optional[UUID] = None
|
||||
) -> WorkspaceValidation:
|
||||
"""Validate workspace data using Pydantic model."""
|
||||
try:
|
||||
@@ -331,7 +331,7 @@ def validate_workspace_data(
|
||||
|
||||
|
||||
def validate_model_data(
|
||||
model_data: Dict[str, Any], config_id: Optional[int] = None
|
||||
model_data: Dict[str, Any], config_id: Optional[UUID] = None
|
||||
) -> ModelValidation:
|
||||
"""Validate model data using Pydantic model."""
|
||||
try:
|
||||
@@ -364,7 +364,7 @@ def validate_model_data(
|
||||
class MemoryConfig:
|
||||
"""Immutable memory configuration loaded from database."""
|
||||
|
||||
config_id: int
|
||||
config_id: UUID
|
||||
config_name: str
|
||||
workspace_id: UUID
|
||||
workspace_name: str
|
||||
|
||||
@@ -4,7 +4,7 @@ from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.models.memory_perceptual_model import PerceptualType, FileStorageType
|
||||
from app.models.memory_perceptual_model import PerceptualType, FileStorageService
|
||||
|
||||
|
||||
class PerceptualFilter(BaseModel):
|
||||
@@ -38,12 +38,14 @@ class PerceptualMemoryItem(BaseModel):
|
||||
"""感知记忆项"""
|
||||
id: uuid.UUID = Field(..., description="Unique memory ID")
|
||||
perceptual_type: PerceptualType = Field(..., description="Type of perception, e.g., text, audio, or video")
|
||||
storage_service: FileStorageService = Field(..., description="Storage service for file")
|
||||
file_path: str = Field(..., description="File path in the storage service")
|
||||
file_ext: str = Field(..., description="File extension")
|
||||
file_name: str = Field(..., description="File name")
|
||||
file_ext: str = Field(..., description="File extension")
|
||||
summary: Optional[str] = Field(None, description="summary")
|
||||
storage_type: FileStorageType = Field(..., description="Storage type for file")
|
||||
meta_data: Optional[dict] = Field(None, description="Metadata information")
|
||||
created_time: int = Field(..., description="create time")
|
||||
|
||||
topic: str = Field(..., description="topic")
|
||||
domain: str = Field(..., description="domain")
|
||||
keywords: list[str] = Field(..., description="keywords")
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import Optional
|
||||
from uuid import UUID
|
||||
from enum import Enum
|
||||
|
||||
|
||||
@@ -9,7 +10,7 @@ class OptimizationStrategy(str, Enum):
|
||||
ACCURACY_FIRST = "accuracy_first"
|
||||
BALANCED = "balanced"
|
||||
class Memory_Reflection(BaseModel):
|
||||
config_id: Optional[int] = None
|
||||
config_id: Optional[UUID] = None
|
||||
reflection_enabled: bool
|
||||
reflection_period_in_hours: str
|
||||
reflexion_range: Optional[str] = "partial"
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
"""
|
||||
所有的内容是放错误地方了,应该放在models
|
||||
|
||||
"""
|
||||
|
||||
from typing import Any, Optional, List, Dict, Literal, Union
|
||||
@@ -8,20 +8,8 @@ import uuid
|
||||
from pydantic import BaseModel, Field, ConfigDict, field_validator, model_validator
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# 原 UserInput 相关 Schema (保留原有功能)
|
||||
# ============================================================================
|
||||
class UserInput(BaseModel):
|
||||
message: str
|
||||
history: list[dict]
|
||||
search_switch: str
|
||||
group_id: str
|
||||
|
||||
|
||||
class Write_UserInput(BaseModel):
|
||||
message: str
|
||||
group_id: str
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# 从 json_schema.py 迁移的 Schema
|
||||
@@ -159,7 +147,7 @@ class ReflexionResultSchema(BaseModel):
|
||||
# Composite key identifying a config row
|
||||
class ConfigKey(BaseModel): # 配置参数键模型
|
||||
model_config = ConfigDict(populate_by_name=True, extra="forbid")
|
||||
config_id: int = Field("config_id", description="配置唯一标识(字符串)")
|
||||
config_id: uuid.UUID = Field("config_id", description="配置唯一标识(UUID)")
|
||||
user_id: str = Field("user_id", description="用户标识(字符串)")
|
||||
apply_id: str = Field("apply_id", description="应用或场景标识(字符串)")
|
||||
|
||||
@@ -250,17 +238,17 @@ class ConfigParamsCreate(BaseModel): # 创建配置参数模型(仅 body,
|
||||
class ConfigParamsDelete(BaseModel): # 删除配置参数模型(请求体)
|
||||
model_config = ConfigDict(populate_by_name=True, extra="forbid")
|
||||
# config_name: str = Field("配置名称", description="配置名称(字符串)")
|
||||
config_id: int = Field("配置ID", description="配置ID(字符串)")
|
||||
config_id: uuid.UUID = Field("配置ID", description="配置ID(UUID)")
|
||||
|
||||
|
||||
class ConfigUpdate(BaseModel): # 更新记忆萃取引擎配置参数时使用的模型
|
||||
config_id: Optional[int] = None
|
||||
config_id: Optional[uuid.UUID] = None
|
||||
config_name: str = Field("配置名称", description="配置名称(字符串)")
|
||||
config_desc: str = Field("配置描述", description="配置描述(字符串)")
|
||||
|
||||
|
||||
class ConfigUpdateExtracted(BaseModel): # 更新记忆萃取引擎配置参数时使用的模型
|
||||
config_id: Optional[int] = None
|
||||
config_id: Optional[uuid.UUID] = None
|
||||
llm_id: Optional[str] = Field(None, description="LLM模型配置ID")
|
||||
embedding_id: Optional[str] = Field(None, description="嵌入模型配置ID")
|
||||
rerank_id: Optional[str] = Field(None, description="重排序模型配置ID")
|
||||
@@ -327,14 +315,14 @@ class ConfigUpdateExtracted(BaseModel): # 更新记忆萃取引擎配置参数
|
||||
|
||||
class ConfigUpdateForget(BaseModel): # 更新遗忘引擎配置参数时使用的模型
|
||||
# 遗忘引擎配置参数更新模型
|
||||
config_id: Optional[int] = None
|
||||
config_id: Optional[uuid.UUID] = None
|
||||
lambda_time: Optional[float] = Field(0.5, ge=0.0, le=1.0, description="最低保持度,0-1 小数;默认 0.5")
|
||||
lambda_mem: Optional[float] = Field(0.5, ge=0.0, le=1.0, description="遗忘率,0-1 小数;默认 0.5")
|
||||
offset: Optional[float] = Field(0.0, ge=0.0, le=1.0, description="偏移度,0-1 小数;默认 0.0")
|
||||
|
||||
|
||||
class ConfigPilotRun(BaseModel): # 试运行触发请求模型
|
||||
config_id: int = Field(..., description="配置ID(唯一)")
|
||||
config_id: uuid.UUID = Field(..., description="配置ID(唯一)")
|
||||
dialogue_text: str = Field(..., description="前端传入的对话文本,格式如 '用户: ...\nAI: ...' 可多行,试运行必填")
|
||||
model_config = ConfigDict(populate_by_name=True, extra="forbid")
|
||||
|
||||
@@ -342,7 +330,7 @@ class ConfigPilotRun(BaseModel): # 试运行触发请求模型
|
||||
class ConfigFilter(BaseModel): # 查询配置参数时使用的模型
|
||||
model_config = ConfigDict(populate_by_name=True, extra="forbid")
|
||||
|
||||
config_id: Optional[int] = None
|
||||
config_id: Optional[uuid.UUID] = None
|
||||
user_id: Optional[str] = None
|
||||
apply_id: Optional[str] = None
|
||||
|
||||
@@ -418,7 +406,7 @@ class ForgettingConfigResponse(BaseModel):
|
||||
"""遗忘引擎配置响应模型"""
|
||||
model_config = ConfigDict(populate_by_name=True, extra="forbid")
|
||||
|
||||
config_id: int = Field(..., description="配置ID")
|
||||
config_id: uuid.UUID = Field(..., description="配置ID")
|
||||
decay_constant: float = Field(..., description="衰减常数 d")
|
||||
lambda_time: float = Field(..., description="时间衰减参数")
|
||||
lambda_mem: float = Field(..., description="记忆衰减参数")
|
||||
@@ -436,7 +424,7 @@ class ForgettingConfigUpdateRequest(BaseModel):
|
||||
"""遗忘引擎配置更新请求模型"""
|
||||
model_config = ConfigDict(populate_by_name=True, extra="forbid")
|
||||
|
||||
config_id: int = Field(..., description="配置ID")
|
||||
config_id: uuid.UUID = Field(..., description="配置ID")
|
||||
decay_constant: Optional[float] = Field(None, ge=0.0, le=1.0, description="衰减常数 d")
|
||||
lambda_time: Optional[float] = Field(None, ge=0.0, le=1.0, description="时间衰减参数")
|
||||
lambda_mem: Optional[float] = Field(None, ge=0.0, le=1.0, description="记忆衰减参数")
|
||||
@@ -511,7 +499,7 @@ class ForgettingCurveRequest(BaseModel):
|
||||
|
||||
importance_score: float = Field(0.5, ge=0.0, le=1.0, description="重要性分数(0-1)")
|
||||
days: int = Field(60, ge=1, le=365, description="模拟天数(默认60天)")
|
||||
config_id: Optional[int] = Field(None, description="配置ID(可选,如果为None则使用默认配置)")
|
||||
config_id: Optional[uuid.UUID] = Field(None, description="配置ID(可选,如果为None则使用默认配置)")
|
||||
|
||||
|
||||
class ForgettingCurveResponse(BaseModel):
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from pydantic import BaseModel, Field, field_serializer, ConfigDict
|
||||
from pydantic import BaseModel, Field, field_serializer, field_validator, ConfigDict
|
||||
from typing import Optional, List, Dict, Any
|
||||
import datetime
|
||||
import uuid
|
||||
@@ -91,6 +91,18 @@ class ModelApiKey(ModelApiKeyBase):
|
||||
created_at: datetime.datetime
|
||||
updated_at: datetime.datetime
|
||||
|
||||
@field_validator("config", mode="before")
|
||||
@classmethod
|
||||
def parse_config(cls, v):
|
||||
"""处理 config 字段,如果是字符串则解析为字典"""
|
||||
if isinstance(v, str):
|
||||
import json
|
||||
try:
|
||||
return json.loads(v)
|
||||
except json.JSONDecodeError:
|
||||
return {}
|
||||
return v
|
||||
|
||||
@field_serializer("created_at", when_used="json")
|
||||
def _serialize_created_at(self, dt: datetime.datetime):
|
||||
return int(dt.timestamp() * 1000) if dt else None
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import uuid
|
||||
import datetime
|
||||
from typing import Optional, List, Dict, Any
|
||||
from pydantic import BaseModel, Field, ConfigDict, field_serializer
|
||||
from pydantic import BaseModel, Field, ConfigDict, field_serializer, field_validator
|
||||
|
||||
|
||||
# ---------- Input Schemas ----------
|
||||
@@ -88,6 +88,18 @@ class SharedReleaseInfo(BaseModel):
|
||||
# 嵌入配置
|
||||
allow_embed: bool
|
||||
|
||||
@field_validator("config", mode="before")
|
||||
@classmethod
|
||||
def parse_config(cls, v):
|
||||
"""处理 config 字段,如果是字符串则解析为字典"""
|
||||
if isinstance(v, str):
|
||||
import json
|
||||
try:
|
||||
return json.loads(v)
|
||||
except json.JSONDecodeError:
|
||||
return {}
|
||||
return v if v is not None else {}
|
||||
|
||||
|
||||
class EmbedCode(BaseModel):
|
||||
"""嵌入代码"""
|
||||
|
||||
@@ -92,7 +92,7 @@ def create_long_term_memory_tool(memory_config: Dict[str, Any], end_user_id: str
|
||||
try:
|
||||
memory_content = asyncio.run(
|
||||
MemoryAgentService().read_memory(
|
||||
group_id=end_user_id,
|
||||
end_user_id=end_user_id,
|
||||
message=question,
|
||||
history=[],
|
||||
search_switch="2",
|
||||
|
||||
@@ -75,7 +75,7 @@ class EmotionAnalyticsService:
|
||||
|
||||
# 调用仓储层查询
|
||||
tags = await self.emotion_repo.get_emotion_tags(
|
||||
group_id=end_user_id,
|
||||
end_user_id=end_user_id,
|
||||
emotion_type=emotion_type,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
@@ -157,7 +157,7 @@ class EmotionAnalyticsService:
|
||||
|
||||
# 调用仓储层查询
|
||||
keywords = await self.emotion_repo.get_emotion_wordcloud(
|
||||
group_id=end_user_id,
|
||||
end_user_id=end_user_id,
|
||||
emotion_type=emotion_type,
|
||||
limit=limit
|
||||
)
|
||||
@@ -339,7 +339,7 @@ class EmotionAnalyticsService:
|
||||
|
||||
# 获取时间范围内的情绪数据
|
||||
emotions = await self.emotion_repo.get_emotions_in_range(
|
||||
group_id=end_user_id,
|
||||
end_user_id=end_user_id,
|
||||
time_range=time_range
|
||||
)
|
||||
|
||||
@@ -505,7 +505,7 @@ class EmotionAnalyticsService:
|
||||
)
|
||||
config_service = MemoryConfigService(db)
|
||||
memory_config = config_service.load_memory_config(
|
||||
config_id=int(config_id),
|
||||
config_id=(config_id),
|
||||
service_name="EmotionAnalyticsService.generate_emotion_suggestions"
|
||||
)
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
@@ -519,7 +519,7 @@ class EmotionAnalyticsService:
|
||||
|
||||
# 3. 获取情绪数据用于模式分析
|
||||
emotions = await self.emotion_repo.get_emotions_in_range(
|
||||
group_id=end_user_id,
|
||||
end_user_id=end_user_id,
|
||||
time_range="30d"
|
||||
)
|
||||
|
||||
@@ -598,13 +598,13 @@ class EmotionAnalyticsService:
|
||||
# 查询用户的实体和标签
|
||||
query = """
|
||||
MATCH (e:Entity)
|
||||
WHERE e.group_id = $group_id
|
||||
WHERE e.end_user_id = $end_user_id
|
||||
RETURN e.name as name, e.type as type
|
||||
ORDER BY e.created_at DESC
|
||||
LIMIT 20
|
||||
"""
|
||||
|
||||
entities = await connector.execute_query(query, group_id=end_user_id)
|
||||
entities = await connector.execute_query(query, end_user_id=end_user_id)
|
||||
|
||||
# 提取兴趣标签
|
||||
interests = [e["name"] for e in entities if e.get("type") in ["INTEREST", "HOBBY"]][:5]
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user