Fix/memory bug fix (#171)

This commit is contained in:
lixinyue11
2026-01-26 11:53:34 +08:00
committed by GitHub
parent 714c624dc6
commit 3601737869
119 changed files with 1711 additions and 1695 deletions

0
api/app/__init__.py Normal file
View File

View File

@@ -12,6 +12,7 @@ from fastapi import APIRouter, Depends, Query, HTTPException, status
from pydantic import BaseModel, Field
from typing import Optional
from sqlalchemy.orm import Session
from uuid import UUID
from app.core.response_utils import success
from app.dependencies import get_current_user
@@ -32,11 +33,11 @@ router = APIRouter(
class EmotionConfigQuery(BaseModel):
"""情绪配置查询请求模型"""
config_id: int = Field(..., description="配置ID")
config_id: UUID = Field(..., description="配置ID")
class EmotionConfigUpdate(BaseModel):
"""情绪配置更新请求模型"""
config_id: int = Field(..., description="配置ID")
config_id: UUID = Field(..., description="配置ID")
emotion_enabled: bool = Field(..., description="是否启用情绪提取")
emotion_model_id: Optional[str] = Field(None, description="情绪分析专用模型ID")
emotion_extract_keywords: bool = Field(..., description="是否提取情绪关键词")
@@ -45,7 +46,7 @@ class EmotionConfigUpdate(BaseModel):
@router.get("/read_config", response_model=ApiResponse)
def get_emotion_config(
config_id: int = Query(..., description="配置ID"),
config_id: UUID = Query(..., description="配置ID"),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):

View File

@@ -53,7 +53,7 @@ async def get_emotion_tags(
api_logger.info(
f"用户 {current_user.username} 请求获取情绪标签统计",
extra={
"group_id": request.group_id,
"end_user_id": request.end_user_id,
"emotion_type": request.emotion_type,
"start_date": request.start_date,
"end_date": request.end_date,
@@ -63,7 +63,7 @@ async def get_emotion_tags(
# 调用服务层
data = await emotion_service.get_emotion_tags(
end_user_id=request.group_id,
end_user_id=request.end_user_id,
emotion_type=request.emotion_type,
start_date=request.start_date,
end_date=request.end_date,
@@ -73,7 +73,7 @@ async def get_emotion_tags(
api_logger.info(
"情绪标签统计获取成功",
extra={
"group_id": request.group_id,
"end_user_id": request.end_user_id,
"total_count": data.get("total_count", 0),
"tags_count": len(data.get("tags", []))
}
@@ -84,7 +84,7 @@ async def get_emotion_tags(
except Exception as e:
api_logger.error(
f"获取情绪标签统计失败: {str(e)}",
extra={"group_id": request.group_id},
extra={"end_user_id": request.end_user_id},
exc_info=True
)
raise HTTPException(
@@ -105,7 +105,7 @@ async def get_emotion_wordcloud(
api_logger.info(
f"用户 {current_user.username} 请求获取情绪词云数据",
extra={
"group_id": request.group_id,
"end_user_id": request.end_user_id,
"emotion_type": request.emotion_type,
"limit": request.limit
}
@@ -113,7 +113,7 @@ async def get_emotion_wordcloud(
# 调用服务层
data = await emotion_service.get_emotion_wordcloud(
end_user_id=request.group_id,
end_user_id=request.end_user_id,
emotion_type=request.emotion_type,
limit=request.limit
)
@@ -121,7 +121,7 @@ async def get_emotion_wordcloud(
api_logger.info(
"情绪词云数据获取成功",
extra={
"group_id": request.group_id,
"end_user_id": request.end_user_id,
"total_keywords": data.get("total_keywords", 0)
}
)
@@ -131,7 +131,7 @@ async def get_emotion_wordcloud(
except Exception as e:
api_logger.error(
f"获取情绪词云数据失败: {str(e)}",
extra={"group_id": request.group_id},
extra={"end_user_id": request.end_user_id},
exc_info=True
)
raise HTTPException(
@@ -159,21 +159,21 @@ async def get_emotion_health(
api_logger.info(
f"用户 {current_user.username} 请求获取情绪健康指数",
extra={
"group_id": request.group_id,
"end_user_id": request.end_user_id,
"time_range": request.time_range
}
)
# 调用服务层
data = await emotion_service.calculate_emotion_health_index(
end_user_id=request.group_id,
end_user_id=request.end_user_id,
time_range=request.time_range
)
api_logger.info(
"情绪健康指数获取成功",
extra={
"group_id": request.group_id,
"end_user_id": request.end_user_id,
"health_score": data.get("health_score", 0),
"level": data.get("level", "未知")
}
@@ -186,7 +186,7 @@ async def get_emotion_health(
except Exception as e:
api_logger.error(
f"获取情绪健康指数失败: {str(e)}",
extra={"group_id": request.group_id},
extra={"end_user_id": request.end_user_id},
exc_info=True
)
raise HTTPException(
@@ -206,7 +206,7 @@ async def get_emotion_suggestions(
"""获取个性化情绪建议(从缓存读取)
Args:
request: 包含 group_id 和可选的 config_id
request: 包含 end_user_id 和可选的 config_id
db: 数据库会话
current_user: 当前用户
@@ -217,22 +217,22 @@ async def get_emotion_suggestions(
api_logger.info(
f"用户 {current_user.username} 请求获取个性化情绪建议(缓存)",
extra={
"group_id": request.group_id,
"end_user_id": request.end_user_id,
"config_id": request.config_id
}
)
# 从缓存获取建议
data = await emotion_service.get_cached_suggestions(
end_user_id=request.group_id,
end_user_id=request.end_user_id,
db=db
)
if data is None:
# 缓存不存在或已过期
api_logger.info(
f"用户 {request.group_id} 的建议缓存不存在或已过期",
extra={"group_id": request.group_id}
f"用户 {request.end_user_id} 的建议缓存不存在或已过期",
extra={"end_user_id": request.end_user_id}
)
return fail(
BizCode.NOT_FOUND,
@@ -243,7 +243,7 @@ async def get_emotion_suggestions(
api_logger.info(
"个性化建议获取成功(缓存)",
extra={
"group_id": request.group_id,
"end_user_id": request.end_user_id,
"suggestions_count": len(data.get("suggestions", []))
}
)
@@ -253,7 +253,7 @@ async def get_emotion_suggestions(
except Exception as e:
api_logger.error(
f"获取个性化建议失败: {str(e)}",
extra={"group_id": request.group_id},
extra={"end_user_id": request.end_user_id},
exc_info=True
)
raise HTTPException(

View File

@@ -122,10 +122,10 @@ def validate_confidence_threshold(threshold: float) -> None:
raise ValueError("confidence_threshold must be between 0.0 and 1.0")
@router.get("/preferences/{user_id}", response_model=ApiResponse)
@router.get("/preferences/{end_user_id}", response_model=ApiResponse)
@cur_workspace_access_guard()
async def get_preference_tags(
user_id: str,
end_user_id: str,
confidence_threshold: float = Query(0.5, ge=0.0, le=1.0, description="Minimum confidence threshold"),
tag_category: Optional[str] = Query(None, description="Filter by tag category"),
start_date: Optional[datetime] = Query(None, description="Filter start date"),
@@ -137,7 +137,7 @@ async def get_preference_tags(
Get user preference tags from cache.
Args:
user_id: Target user ID
end_user_id: Target end user ID
confidence_threshold: Minimum confidence score (0.0-1.0)
tag_category: Optional category filter
start_date: Optional start date filter
@@ -146,20 +146,20 @@ async def get_preference_tags(
Returns:
List of preference tags from cache
"""
api_logger.info(f"Preference tags requested for user: {user_id} (from cache)")
api_logger.info(f"Preference tags requested for user: {end_user_id} (from cache)")
try:
# Validate inputs
validate_user_id(user_id)
validate_user_id(end_user_id)
# Create service with user-specific config
service = ImplicitMemoryService(db=db, end_user_id=user_id)
service = ImplicitMemoryService(db=db, end_user_id=end_user_id)
# Get cached profile
cached_profile = await service.get_cached_profile(end_user_id=user_id, db=db)
cached_profile = await service.get_cached_profile(end_user_id=end_user_id, db=db)
if cached_profile is None:
api_logger.info(f"用户 {user_id} 的画像缓存不存在或已过期")
api_logger.info(f"用户 {end_user_id} 的画像缓存不存在或已过期")
return fail(
BizCode.NOT_FOUND,
"画像缓存不存在或已过期,请右上角刷新生成新画像",
@@ -192,17 +192,17 @@ async def get_preference_tags(
filtered_preferences.append(pref)
api_logger.info(f"Retrieved {len(filtered_preferences)} preference tags for user: {user_id} (from cache)")
api_logger.info(f"Retrieved {len(filtered_preferences)} preference tags for user: {end_user_id} (from cache)")
return success(data=filtered_preferences, msg="偏好标签获取成功(缓存)")
except Exception as e:
return handle_implicit_memory_error(e, "偏好标签获取", user_id)
return handle_implicit_memory_error(e, "偏好标签获取", end_user_id)
@router.get("/portrait/{user_id}", response_model=ApiResponse)
@router.get("/portrait/{end_user_id}", response_model=ApiResponse)
@cur_workspace_access_guard()
async def get_dimension_portrait(
user_id: str,
end_user_id: str,
include_history: bool = Query(False, description="Include historical trends"),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
@@ -211,26 +211,26 @@ async def get_dimension_portrait(
Get user's four-dimension personality portrait from cache.
Args:
user_id: Target user ID
end_user_id: Target end user ID
include_history: Whether to include historical trend data (ignored for cached data)
Returns:
Four-dimension personality portrait from cache
"""
api_logger.info(f"Dimension portrait requested for user: {user_id} (from cache)")
api_logger.info(f"Dimension portrait requested for user: {end_user_id} (from cache)")
try:
# Validate inputs
validate_user_id(user_id)
validate_user_id(end_user_id)
# Create service with user-specific config
service = ImplicitMemoryService(db=db, end_user_id=user_id)
service = ImplicitMemoryService(db=db, end_user_id=end_user_id)
# Get cached profile
cached_profile = await service.get_cached_profile(end_user_id=user_id, db=db)
cached_profile = await service.get_cached_profile(end_user_id=end_user_id, db=db)
if cached_profile is None:
api_logger.info(f"用户 {user_id} 的画像缓存不存在或已过期")
api_logger.info(f"用户 {end_user_id} 的画像缓存不存在或已过期")
return fail(
BizCode.NOT_FOUND,
"画像缓存不存在或已过期,请右上角刷新生成新画像",
@@ -240,17 +240,17 @@ async def get_dimension_portrait(
# Extract portrait from cache
portrait = cached_profile.get("portrait", {})
api_logger.info(f"Dimension portrait retrieved for user: {user_id} (from cache)")
api_logger.info(f"Dimension portrait retrieved for user: {end_user_id} (from cache)")
return success(data=portrait, msg="四维画像获取成功(缓存)")
except Exception as e:
return handle_implicit_memory_error(e, "四维画像获取", user_id)
return handle_implicit_memory_error(e, "四维画像获取", end_user_id)
@router.get("/interest-areas/{user_id}", response_model=ApiResponse)
@router.get("/interest-areas/{end_user_id}", response_model=ApiResponse)
@cur_workspace_access_guard()
async def get_interest_area_distribution(
user_id: str,
end_user_id: str,
include_trends: bool = Query(False, description="Include trend analysis"),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
@@ -259,26 +259,26 @@ async def get_interest_area_distribution(
Get user's interest area distribution from cache.
Args:
user_id: Target user ID
end_user_id: Target end user ID
include_trends: Whether to include trend analysis data (ignored for cached data)
Returns:
Interest area distribution from cache
"""
api_logger.info(f"Interest area distribution requested for user: {user_id} (from cache)")
api_logger.info(f"Interest area distribution requested for user: {end_user_id} (from cache)")
try:
# Validate inputs
validate_user_id(user_id)
validate_user_id(end_user_id)
# Create service with user-specific config
service = ImplicitMemoryService(db=db, end_user_id=user_id)
service = ImplicitMemoryService(db=db, end_user_id=end_user_id)
# Get cached profile
cached_profile = await service.get_cached_profile(end_user_id=user_id, db=db)
cached_profile = await service.get_cached_profile(end_user_id=end_user_id, db=db)
if cached_profile is None:
api_logger.info(f"用户 {user_id} 的画像缓存不存在或已过期")
api_logger.info(f"用户 {end_user_id} 的画像缓存不存在或已过期")
return fail(
BizCode.NOT_FOUND,
"画像缓存不存在或已过期,请右上角刷新生成新画像",
@@ -288,17 +288,17 @@ async def get_interest_area_distribution(
# Extract interest areas from cache
interest_areas = cached_profile.get("interest_areas", {})
api_logger.info(f"Interest area distribution retrieved for user: {user_id} (from cache)")
api_logger.info(f"Interest area distribution retrieved for user: {end_user_id} (from cache)")
return success(data=interest_areas, msg="兴趣领域分布获取成功(缓存)")
except Exception as e:
return handle_implicit_memory_error(e, "兴趣领域分布获取", user_id)
return handle_implicit_memory_error(e, "兴趣领域分布获取", end_user_id)
@router.get("/habits/{user_id}", response_model=ApiResponse)
@router.get("/habits/{end_user_id}", response_model=ApiResponse)
@cur_workspace_access_guard()
async def get_behavior_habits(
user_id: str,
end_user_id: str,
confidence_level: Optional[str] = Query(None, regex="^(high|medium|low)$", description="Filter by confidence level"),
frequency_pattern: Optional[str] = Query(None, regex="^(daily|weekly|monthly|seasonal|occasional|event_triggered)$", description="Filter by frequency pattern"),
time_period: Optional[str] = Query(None, regex="^(current|past)$", description="Filter by time period"),
@@ -309,7 +309,7 @@ async def get_behavior_habits(
Get user's behavioral habits from cache.
Args:
user_id: Target user ID
end_user_id: Target end user ID
confidence_level: Filter by confidence level (high, medium, low)
frequency_pattern: Filter by frequency pattern (daily, weekly, monthly, seasonal, occasional, event_triggered)
time_period: Filter by time period (current, past)
@@ -317,20 +317,20 @@ async def get_behavior_habits(
Returns:
List of behavioral habits from cache
"""
api_logger.info(f"Behavior habits requested for user: {user_id} (from cache)")
api_logger.info(f"Behavior habits requested for user: {end_user_id} (from cache)")
try:
# Validate inputs
validate_user_id(user_id)
validate_user_id(end_user_id)
# Create service with user-specific config
service = ImplicitMemoryService(db=db, end_user_id=user_id)
service = ImplicitMemoryService(db=db, end_user_id=end_user_id)
# Get cached profile
cached_profile = await service.get_cached_profile(end_user_id=user_id, db=db)
cached_profile = await service.get_cached_profile(end_user_id=end_user_id, db=db)
if cached_profile is None:
api_logger.info(f"用户 {user_id} 的画像缓存不存在或已过期")
api_logger.info(f"用户 {end_user_id} 的画像缓存不存在或已过期")
return fail(
BizCode.NOT_FOUND,
"画像缓存不存在或已过期,请右上角刷新生成新画像",
@@ -368,11 +368,11 @@ async def get_behavior_habits(
filtered_habits.append(habit)
api_logger.info(f"Retrieved {len(filtered_habits)} behavior habits for user: {user_id} (from cache)")
api_logger.info(f"Retrieved {len(filtered_habits)} behavior habits for user: {end_user_id} (from cache)")
return success(data=filtered_habits, msg="行为习惯获取成功(缓存)")
except Exception as e:
return handle_implicit_memory_error(e, "行为习惯获取", user_id)
return handle_implicit_memory_error(e, "行为习惯获取", end_user_id)

View File

@@ -125,7 +125,7 @@ async def write_server(
Write service endpoint - processes write operations synchronously
Args:
user_input: Write request containing message and group_id
user_input: Write request containing message and end_user_id
Returns:
Response with write operation status
@@ -160,19 +160,18 @@ async def write_server(
api_logger.warning("workspace_id 为空,无法使用 rag 存储,将使用 neo4j 存储")
storage_type = 'neo4j'
api_logger.info(f"Write service requested for group {user_input.group_id}, storage_type: {storage_type}, user_rag_memory_id: {user_rag_memory_id}")
api_logger.info(f"Write service requested for group {user_input.end_user_id}, storage_type: {storage_type}, user_rag_memory_id: {user_rag_memory_id}")
try:
# 获取标准化的消息列表
messages_list = memory_agent_service.get_messages_list(user_input)
result = await memory_agent_service.write_memory(
user_input.group_id,
messages_list, # 传递结构化消息列表
user_input.end_user_id,
messages_list,
config_id,
db,
storage_type,
user_rag_memory_id
)
return success(data=result, msg="写入成功")
except BaseException as e:
# Handle ExceptionGroup from TaskGroup (Python 3.11+) or BaseExceptionGroup
@@ -196,7 +195,7 @@ async def write_server_async(
Async write service endpoint - enqueues write processing to Celery
Args:
user_input: Write request containing message and group_id
user_input: Write request containing message and end_user_id
Returns:
Task ID for tracking async operation
@@ -226,10 +225,10 @@ async def write_server_async(
try:
# 获取标准化的消息列表
messages_list = memory_agent_service.get_messages_list(user_input)
task = celery_app.send_task(
"app.core.memory.agent.write_message",
args=[user_input.group_id, messages_list, config_id, storage_type, user_rag_memory_id]
args=[user_input.end_user_id, messages_list, config_id, storage_type, user_rag_memory_id]
)
api_logger.info(f"Write task queued: {task.id}")
@@ -255,7 +254,7 @@ async def read_server(
- "2": Direct answer based on context
Args:
user_input: Read request with message, history, search_switch, and group_id
user_input: Read request with message, history, search_switch, and end_user_id
Returns:
Response with query answer
@@ -277,12 +276,13 @@ async def read_server(
name="USER_RAG_MERORY",
workspace_id=workspace_id
)
if knowledge: user_rag_memory_id = str(knowledge.id)
if knowledge:
user_rag_memory_id = str(knowledge.id)
api_logger.info(f"Read service: group={user_input.group_id}, storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}, workspace_id={workspace_id}")
api_logger.info(f"Read service: group={user_input.end_user_id}, storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}, workspace_id={workspace_id}")
try:
result = await memory_agent_service.read_memory(
user_input.group_id,
user_input.end_user_id,
user_input.message,
user_input.history,
user_input.search_switch,
@@ -293,12 +293,12 @@ async def read_server(
)
if str(user_input.search_switch) == "2":
retrieve_info = result['answer']
history = await SessionService(store).get_history(user_input.group_id, user_input.group_id, user_input.group_id)
history = await SessionService(store).get_history(user_input.end_user_id, user_input.end_user_id, user_input.end_user_id)
query = user_input.message
# 调用 memory_agent_service 的方法生成最终答案
result['answer'] = await memory_agent_service.generate_summary_from_retrieve(
group_id=user_input.group_id,
end_user_id=user_input.end_user_id,
retrieve_info=retrieve_info,
history=history,
query=query,
@@ -404,7 +404,7 @@ async def read_server_async(
try:
task = celery_app.send_task(
"app.core.memory.agent.read_message",
args=[user_input.group_id, user_input.message, user_input.history, user_input.search_switch,
args=[user_input.end_user_id, user_input.message, user_input.history, user_input.search_switch,
config_id, storage_type, user_rag_memory_id]
)
api_logger.info(f"Read task queued: {task.id}")
@@ -448,7 +448,7 @@ async def get_read_task_result(
return success(
data={
"result": task_result.get("result"),
"group_id": task_result.get("group_id"),
"end_user_id": task_result.get("end_user_id"),
"elapsed_time": task_result.get("elapsed_time"),
"task_id": task_id
},
@@ -525,7 +525,7 @@ async def get_write_task_result(
return success(
data={
"result": task_result.get("result"),
"group_id": task_result.get("group_id"),
"end_user_id": task_result.get("end_user_id"),
"elapsed_time": task_result.get("elapsed_time"),
"task_id": task_id
},
@@ -579,16 +579,16 @@ async def status_type(
Determine the type of user message (read or write)
Args:
user_input: Request containing user message and group_id
user_input: Request containing user message and end_user_id
Returns:
Type classification result
"""
api_logger.info(f"Status type check requested for group {user_input.group_id}")
api_logger.info(f"Status type check requested for group {user_input.end_user_id}")
try:
# 获取标准化的消息列表
messages_list = memory_agent_service.get_messages_list(user_input)
# 将消息列表转换为字符串用于分类
# 只取最后一条用户消息进行分类
last_user_message = ""
@@ -596,11 +596,11 @@ async def status_type(
if msg.get('role') == 'user':
last_user_message = msg.get('content', '')
break
if not last_user_message:
# 如果没有用户消息,使用所有消息的内容
last_user_message = " ".join([msg.get('content', '') for msg in messages_list])
result = await memory_agent_service.classify_message_type(
last_user_message,
user_input.config_id,
@@ -625,7 +625,7 @@ async def get_knowledge_type_stats_api(
会对缺失类型补 0返回字典形式。
可选按状态过滤。
- 知识库类型根据当前用户的 current_workspace_id 过滤
- memory 是 Neo4j 中 Chunk 的数量,根据 end_user_id (group_id) 过滤
- memory 是 Neo4j 中 Chunk 的数量,根据 end_user_id (end_user_id) 过滤
- 如果用户没有当前工作空间或未提供 end_user_id对应的统计返回 0
"""
api_logger.info(f"Knowledge type stats requested for workspace_id: {current_user.current_workspace_id}, end_user_id: {end_user_id}")
@@ -698,7 +698,7 @@ async def get_user_profile_api(
current_user: User = Depends(get_current_user)
):
"""
获取工作空间下Popular Memory Tags,包含:
获取用户详情,包含:
- name: 用户名字(直接使用 end_user_id
- tags: 3个用户特征标签从语句和实体中LLM总结
- hot_tags: 4个热门记忆标签

View File

@@ -11,6 +11,7 @@
"""
from typing import Optional
from uuid import UUID
from fastapi import APIRouter, Depends
from sqlalchemy.orm import Session
@@ -106,7 +107,7 @@ async def trigger_forgetting_cycle(
# 调用服务层执行遗忘周期
report = await forget_service.trigger_forgetting_cycle(
db=db,
group_id=end_user_id, # 服务层方法的参数名是 group_id
end_user_id=end_user_id, # 服务层方法的参数名是 end_user_id
max_merge_batch_size=payload.max_merge_batch_size,
min_days_since_access=payload.min_days_since_access,
config_id=config_id
@@ -128,7 +129,7 @@ async def trigger_forgetting_cycle(
@router.get("/read_config", response_model=ApiResponse)
async def read_forgetting_config(
config_id: int,
config_id: UUID,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db)
):
@@ -236,7 +237,7 @@ async def update_forgetting_config(
@router.get("/stats", response_model=ApiResponse)
async def get_forgetting_stats(
group_id: Optional[str] = None,
end_user_id: Optional[str] = None,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db)
):
@@ -246,7 +247,7 @@ async def get_forgetting_stats(
返回知识层节点统计、激活值分布等信息。
Args:
group_id: 组ID即 end_user_id可选
end_user_id: 组ID即 end_user_id可选
current_user: 当前用户
db: 数据库会话
@@ -260,20 +261,20 @@ async def get_forgetting_stats(
api_logger.warning(f"用户 {current_user.username} 尝试获取遗忘引擎统计但未选择工作空间")
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
# 如果提供了 group_id通过它获取 config_id
# 如果提供了 end_user_id通过它获取 config_id
config_id = None
if group_id:
if end_user_id:
try:
from app.services.memory_agent_service import get_end_user_connected_config
connected_config = get_end_user_connected_config(group_id, db)
connected_config = get_end_user_connected_config(end_user_id, db)
config_id = connected_config.get("memory_config_id")
if config_id is None:
api_logger.warning(f"终端用户 {group_id} 未关联记忆配置")
return fail(BizCode.INVALID_PARAMETER, f"终端用户 {group_id} 未关联记忆配置", "memory_config_id is None")
api_logger.warning(f"终端用户 {end_user_id} 未关联记忆配置")
return fail(BizCode.INVALID_PARAMETER, f"终端用户 {end_user_id} 未关联记忆配置", "memory_config_id is None")
api_logger.debug(f"通过 group_id={group_id} 获取到 config_id={config_id}")
api_logger.debug(f"通过 end_user_id={end_user_id} 获取到 config_id={config_id}")
except ValueError as e:
api_logger.warning(f"获取终端用户配置失败: {str(e)}")
return fail(BizCode.INVALID_PARAMETER, str(e), "ValueError")
@@ -283,14 +284,14 @@ async def get_forgetting_stats(
api_logger.info(
f"用户 {current_user.username} 在工作空间 {workspace_id} 请求获取遗忘引擎统计: "
f"group_id={group_id}, config_id={config_id}"
f"end_user_id={end_user_id}, config_id={config_id}"
)
try:
# 调用服务层获取统计信息
stats = await forget_service.get_forgetting_stats(
db=db,
group_id=group_id,
end_user_id=end_user_id,
config_id=config_id
)

View File

@@ -27,27 +27,27 @@ router = APIRouter(
)
@router.get("/{group_id}/count", response_model=ApiResponse)
@router.get("/{end_user_id}/count", response_model=ApiResponse)
def get_memory_count(
group_id: uuid.UUID,
end_user_id: uuid.UUID,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db)
):
"""Retrieve perceptual memory statistics for a user group.
Args:
group_id: ID of the user group (usually end_user_id in this context)
end_user_id: ID of the user group (usually end_user_id in this context)
current_user: Current authenticated user
db: Database session
Returns:
ApiResponse: Response containing memory count statistics
"""
api_logger.info(f"Fetching perceptual memory statistics: user={current_user.username}, group_id={group_id}")
api_logger.info(f"Fetching perceptual memory statistics: user={current_user.username}, end_user_id={end_user_id}")
try:
service = MemoryPerceptualService(db)
count_stats = service.get_memory_count(group_id)
count_stats = service.get_memory_count(end_user_id)
api_logger.info(f"Memory statistics fetched successfully: total={count_stats.get('total', 0)}")
@@ -57,37 +57,37 @@ def get_memory_count(
)
except Exception as e:
api_logger.error(f"Failed to fetch memory statistics: group_id={group_id}, error={str(e)}")
api_logger.error(f"Failed to fetch memory statistics: end_user_id={end_user_id}, error={str(e)}")
return fail(
code=BizCode.INTERNAL_ERROR,
msg="Failed to fetch memory statistics",
)
@router.get("/{group_id}/last_visual", response_model=ApiResponse)
@router.get("/{end_user_id}/last_visual", response_model=ApiResponse)
def get_last_visual_memory(
group_id: uuid.UUID,
end_user_id: uuid.UUID,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db)
):
"""Retrieve the most recent VISION-type memory for a user.
Args:
group_id: ID of the user group
end_user_id: ID of the user group
current_user: Current authenticated user
db: Database session
Returns:
ApiResponse: Metadata of the latest visual memory
"""
api_logger.info(f"Fetching latest visual memory: user={current_user.username}, group_id={group_id}")
api_logger.info(f"Fetching latest visual memory: user={current_user.username}, end_user_id={end_user_id}")
try:
service = MemoryPerceptualService(db)
visual_memory = service.get_latest_visual_memory(group_id)
visual_memory = service.get_latest_visual_memory(end_user_id)
if visual_memory is None:
api_logger.info(f"No visual memory found: group_id={group_id}")
api_logger.info(f"No visual memory found: end_user_id={end_user_id}")
return success(
data=None,
msg="No visual memory available"
@@ -101,37 +101,37 @@ def get_last_visual_memory(
)
except Exception as e:
api_logger.error(f"Failed to fetch latest visual memory: group_id={group_id}, error={str(e)}")
api_logger.error(f"Failed to fetch latest visual memory: end_user_id={end_user_id}, error={str(e)}")
return fail(
code=BizCode.INTERNAL_ERROR,
msg="Failed to fetch latest visual memory",
)
@router.get("/{group_id}/last_listen", response_model=ApiResponse)
@router.get("/{end_user_id}/last_listen", response_model=ApiResponse)
def get_last_memory_listen(
group_id: uuid.UUID,
end_user_id: uuid.UUID,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db)
):
"""Retrieve the most recent AUDIO-type memory for a user.
Args:
group_id: ID of the user group
end_user_id: ID of the user group
current_user: Current authenticated user
db: Database session
Returns:
ApiResponse: Metadata of the latest audio memory
"""
api_logger.info(f"Fetching latest audio memory: user={current_user.username}, group_id={group_id}")
api_logger.info(f"Fetching latest audio memory: user={current_user.username}, end_user_id={end_user_id}")
try:
service = MemoryPerceptualService(db)
audio_memory = service.get_latest_audio_memory(group_id)
audio_memory = service.get_latest_audio_memory(end_user_id)
if audio_memory is None:
api_logger.info(f"No audio memory found: group_id={group_id}")
api_logger.info(f"No audio memory found: end_user_id={end_user_id}")
return success(
data=None,
msg="No audio memory available"
@@ -145,38 +145,38 @@ def get_last_memory_listen(
)
except Exception as e:
api_logger.error(f"Failed to fetch latest audio memory: group_id={group_id}, error={str(e)}")
api_logger.error(f"Failed to fetch latest audio memory: end_user_id={end_user_id}, error={str(e)}")
return fail(
code=BizCode.INTERNAL_ERROR,
msg="Failed to fetch latest audio memory",
)
@router.get("/{group_id}/last_text", response_model=ApiResponse)
@router.get("/{end_user_id}/last_text", response_model=ApiResponse)
def get_last_text_memory(
group_id: uuid.UUID,
end_user_id: uuid.UUID,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db)
):
"""Retrieve the most recent TEXT-type memory for a user.
Args:
group_id: ID of the user group
end_user_id: ID of the user group
current_user: Current authenticated user
db: Database session
Returns:
ApiResponse: Metadata of the latest text memory
"""
api_logger.info(f"Fetching latest text memory: user={current_user.username}, group_id={group_id}")
api_logger.info(f"Fetching latest text memory: user={current_user.username}, end_user_id={end_user_id}")
try:
# 调用服务层获取最近的文本记忆
service = MemoryPerceptualService(db)
text_memory = service.get_latest_text_memory(group_id)
text_memory = service.get_latest_text_memory(end_user_id)
if text_memory is None:
api_logger.info(f"No text memory found: group_id={group_id}")
api_logger.info(f"No text memory found: end_user_id={end_user_id}")
return success(
data=None,
msg="No text memory available"
@@ -190,16 +190,16 @@ def get_last_text_memory(
)
except Exception as e:
api_logger.error(f"Failed to fetch latest text memory: group_id={group_id}, error={str(e)}")
api_logger.error(f"Failed to fetch latest text memory: end_user_id={end_user_id}, error={str(e)}")
return fail(
code=BizCode.INTERNAL_ERROR,
msg="Failed to fetch latest text memory",
)
@router.get("/{group_id}/timeline", response_model=ApiResponse)
@router.get("/{end_user_id}/timeline", response_model=ApiResponse)
def get_memory_time_line(
group_id: uuid.UUID,
end_user_id: uuid.UUID,
perceptual_type: Optional[PerceptualType] = Query(None, description="感知类型过滤"),
page: int = Query(1, ge=1, description="页码"),
page_size: int = Query(10, ge=1, le=100, description="每页大小"),
@@ -209,7 +209,7 @@ def get_memory_time_line(
"""Retrieve a timeline of perceptual memories for a user group.
Args:
group_id: ID of the user group
end_user_id: ID of the user group
perceptual_type: Optional filter for perceptual type
page: Page number for pagination
page_size: Number of items per page
@@ -221,7 +221,7 @@ def get_memory_time_line(
"""
api_logger.info(
f"Fetching perceptual memory timeline: user={current_user.username}, "
f"group_id={group_id}, type={perceptual_type}, page={page}"
f"end_user_id={end_user_id}, type={perceptual_type}, page={page}"
)
try:
@@ -232,7 +232,7 @@ def get_memory_time_line(
)
service = MemoryPerceptualService(db)
timeline_data = service.get_time_line(group_id, query)
timeline_data = service.get_time_line(end_user_id, query)
api_logger.info(
f"Perceptual memory timeline retrieved successfully: total={timeline_data.total}, "
@@ -246,7 +246,7 @@ def get_memory_time_line(
except Exception as e:
api_logger.error(
f"Failed to fetch perceptual memory timeline: group_id={group_id}, "
f"Failed to fetch perceptual memory timeline: end_user_id={end_user_id}, "
f"error={str(e)}"
)
return fail(

View File

@@ -1,6 +1,7 @@
import asyncio
import time
import uuid
from uuid import UUID
from app.core.logging_config import get_api_logger
from app.core.memory.storage_services.reflection_engine.self_reflexion import (
@@ -11,7 +12,7 @@ from app.core.response_utils import success
from app.db import get_db
from app.dependencies import get_current_user
from app.models.user_model import User
from app.repositories.data_config_repository import DataConfigRepository
from app.repositories.memory_config_repository import MemoryConfigRepository
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.schemas.memory_reflection_schemas import Memory_Reflection
from app.services.memory_reflection_service import (
@@ -50,7 +51,7 @@ async def save_reflection_config(
api_logger.info(f"用户 {current_user.username} 保存反思配置config_id: {config_id}")
data_config = DataConfigRepository.update_reflection_config(
memory_config = MemoryConfigRepository.update_reflection_config(
db,
config_id=config_id,
enable_self_reflexion=request.reflection_enabled,
@@ -63,17 +64,17 @@ async def save_reflection_config(
)
db.commit()
db.refresh(data_config)
db.refresh(memory_config)
reflection_result={
"config_id": data_config.config_id,
"enable_self_reflexion": data_config.enable_self_reflexion,
"iteration_period": data_config.iteration_period,
"reflexion_range": data_config.reflexion_range,
"baseline": data_config.baseline,
"reflection_model_id": data_config.reflection_model_id,
"memory_verify": data_config.memory_verify,
"quality_assessment": data_config.quality_assessment}
"config_id": memory_config.config_id,
"enable_self_reflexion": memory_config.enable_self_reflexion,
"iteration_period": memory_config.iteration_period,
"reflexion_range": memory_config.reflexion_range,
"baseline": memory_config.baseline,
"reflection_model_id": memory_config.reflection_model_id,
"memory_verify": memory_config.memory_verify,
"quality_assessment": memory_config.quality_assessment}
return success(data=reflection_result, msg="反思配置成功")
@@ -111,14 +112,14 @@ async def start_workspace_reflection(
reflection_results = []
for data in result['apps_detailed_info']:
if data['data_configs'] == []:
if data['memory_configs'] == []:
continue
releases = data['releases']
data_configs = data['data_configs']
memory_configs = data['memory_configs']
end_users = data['end_users']
for base, config, user in zip(releases, data_configs, end_users):
for base, config, user in zip(releases, memory_configs, end_users):
# 安全地转换为整数处理空字符串和None的情况
print(base['config'])
try:
@@ -156,14 +157,14 @@ async def start_workspace_reflection(
@router.get("/reflection/configs")
async def start_reflection_configs(
config_id: int,
config_id: uuid.UUID,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
) -> dict:
"""通过config_id查询data_config表中的反思配置信息"""
"""通过config_id查询memory_config表中的反思配置信息"""
try:
api_logger.info(f"用户 {current_user.username} 查询反思配置config_id: {config_id}")
result = DataConfigRepository.query_reflection_config_by_id(db, config_id)
result = MemoryConfigRepository.query_reflection_config_by_id(db, config_id)
# 构建返回数据
reflection_config = {
"config_id": result.config_id,
@@ -191,7 +192,7 @@ async def start_reflection_configs(
@router.get("/reflection/run")
async def reflection_run(
config_id: int,
config_id: UUID,
language_type: str = Header(default="zh", alias="X-Language-Type"),
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
@@ -200,8 +201,8 @@ async def reflection_run(
api_logger.info(f"用户 {current_user.username} 查询反思配置config_id: {config_id}")
# 使用DataConfigRepository查询反思配置
result = DataConfigRepository.query_reflection_config_by_id(db, config_id)
# 使用MemoryConfigRepository查询反思配置
result = MemoryConfigRepository.query_reflection_config_by_id(db, config_id)
if not result:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,

View File

@@ -1,5 +1,6 @@
import os
from typing import Optional
from uuid import UUID
from app.core.error_codes import BizCode
from app.core.logging_config import get_api_logger
@@ -160,7 +161,7 @@ def create_config(
@router.delete("/delete_config", response_model=ApiResponse) # 删除数据库中的内容(按配置名称)
def delete_config(
config_id: str,
config_id: UUID,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
) -> dict:
@@ -232,7 +233,7 @@ def update_config_extracted(
@router.get("/read_config_extracted", response_model=ApiResponse) # 通过查询参数读取某条配置(固定路径) 没有意义的话就删除
def read_config_extracted(
config_id: str,
config_id: UUID,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
) -> dict:

View File

@@ -20,18 +20,18 @@ router = APIRouter(
)
@router.get("/{group_id}/count", response_model=ApiResponse)
@router.get("/{end_user_id}/count", response_model=ApiResponse)
def get_memory_count(
group_id: uuid.UUID,
end_user_id: uuid.UUID,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db)
):
pass
@router.get("/{group_id}/conversations", response_model=ApiResponse)
@router.get("/{end_user_id}/conversations", response_model=ApiResponse)
def get_conversations(
group_id: uuid.UUID,
end_user_id: uuid.UUID,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db)
):
@@ -39,7 +39,7 @@ def get_conversations(
Retrieve all conversations for the current user in a specific group.
Args:
group_id (UUID): The group identifier.
end_user_id (UUID): The group identifier.
current_user (User, optional): The authenticated user.
db (Session, optional): SQLAlchemy session.
@@ -53,7 +53,7 @@ def get_conversations(
"""
conversation_service = ConversationService(db)
conversations = conversation_service.get_user_conversations(
group_id
end_user_id
)
return success(data=[
{
@@ -63,7 +63,7 @@ def get_conversations(
], msg="get conversations success")
@router.get("/{group_id}/messages", response_model=ApiResponse)
@router.get("/{end_user_id}/messages", response_model=ApiResponse)
def get_messages(
conversation_id: uuid.UUID,
current_user: User = Depends(get_current_user),
@@ -100,7 +100,7 @@ def get_messages(
return success(data=messages, msg="get conversation history success")
@router.get("/{group_id}/detail", response_model=ApiResponse)
@router.get("/{end_user_id}/detail", response_model=ApiResponse)
async def get_conversation_detail(
conversation_id: uuid.UUID,
current_user: User = Depends(get_current_user),

View File

@@ -39,7 +39,7 @@ async def write_memory_api_service(
Stores memory content for the specified end user using the Memory API Service.
"""
logger.info(f"Memory write request - end_user_id: {payload.end_user_id}")
logger.info(f"Memory write request - end_user_id: {payload.end_user_id}, tenant_id: {api_key_auth.tenant_id}")
memory_api_service = MemoryAPIService(db)

View File

@@ -135,27 +135,27 @@ async def generate_cache_api(
api_logger.warning(f"用户 {current_user.username} 尝试生成缓存但未选择工作空间")
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
group_id = request.end_user_id
end_user_id = request.end_user_id
api_logger.info(
f"缓存生成请求: user={current_user.username}, workspace={workspace_id}, "
f"end_user_id={group_id if group_id else '全部用户'}"
f"end_user_id={end_user_id if end_user_id else '全部用户'}"
)
try:
if group_id:
if end_user_id:
# 为单个用户生成
api_logger.info(f"开始为单个用户生成缓存: end_user_id={group_id}")
api_logger.info(f"开始为单个用户生成缓存: end_user_id={end_user_id}")
# 生成记忆洞察
insight_result = await user_memory_service.generate_and_cache_insight(db, group_id, workspace_id)
insight_result = await user_memory_service.generate_and_cache_insight(db, end_user_id, workspace_id)
# 生成用户摘要
summary_result = await user_memory_service.generate_and_cache_summary(db, group_id, workspace_id)
summary_result = await user_memory_service.generate_and_cache_summary(db, end_user_id, workspace_id)
# 构建响应
result = {
"end_user_id": group_id,
"end_user_id": end_user_id,
"insight_success": insight_result["success"],
"summary_success": summary_result["success"],
"errors": []
@@ -175,9 +175,9 @@ async def generate_cache_api(
# 记录结果
if result["insight_success"] and result["summary_success"]:
api_logger.info(f"成功为用户 {group_id} 生成缓存")
api_logger.info(f"成功为用户 {end_user_id} 生成缓存")
else:
api_logger.warning(f"用户 {group_id} 的缓存生成部分失败: {result['errors']}")
api_logger.warning(f"用户 {end_user_id} 的缓存生成部分失败: {result['errors']}")
return success(data=result, msg="生成完成")

View File

@@ -155,13 +155,13 @@ class LangChainAgent:
# userid=end_user_end,
# messages=messages,
# apply_id=end_user_end,
# group_id=end_user_end,
# end_user_id=end_user_end,
# aimessages=aimessages
# )
# store.delete_duplicate_sessions()
# # logger.info(f'Redis_Agent:{end_user_end};{session_id}')
# return session_id
# TODO 乐力齐 - 累积多组对话批量写入功能已禁用
# async def term_memory_redis_read(self,end_user_end):
# end_user_end = f"Term_{end_user_end}"
@@ -179,7 +179,7 @@ class LangChainAgent:
async def write(self, storage_type, end_user_id, user_message, ai_message, user_rag_memory_id, actual_end_user_id, actual_config_id):
"""
写入记忆(支持结构化消息)
Args:
storage_type: 存储类型 (neo4j/rag)
end_user_id: 终端用户ID
@@ -188,7 +188,7 @@ class LangChainAgent:
user_rag_memory_id: RAG 记忆ID
actual_end_user_id: 实际用户ID
actual_config_id: 配置ID
逻辑说明:
- RAG 模式:组合 user_message 和 ai_message 为字符串格式,保持原有逻辑不变
- Neo4j 模式:使用结构化消息列表
@@ -204,20 +204,20 @@ class LangChainAgent:
else:
# Neo4j 模式:使用结构化消息列表
structured_messages = []
# 始终添加用户消息(如果不为空)
if user_message:
structured_messages.append({"role": "user", "content": user_message})
# 只有当 AI 回复不为空时才添加 assistant 消息
if ai_message:
structured_messages.append({"role": "assistant", "content": ai_message})
# 如果没有消息,直接返回
if not structured_messages:
logger.warning(f"No messages to write for user {actual_end_user_id}")
return
# 调用 Celery 任务,传递结构化消息列表
# 数据流:
# 1. structured_messages 传递给 write_message_task
@@ -228,7 +228,7 @@ class LangChainAgent:
# 6. 每个 Chunk 保存到 Neo4j包含 speaker 字段
logger.info(f"[WRITE] Submitting Celery task - user={actual_end_user_id}, messages={len(structured_messages)}, config={actual_config_id}")
write_id = write_message_task.delay(
actual_end_user_id, # group_id: 用户ID
actual_end_user_id, # end_user_id: 用户ID
structured_messages, # message: 结构化消息列表 [{"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]
actual_config_id, # config_id: 配置ID
storage_type, # storage_type: "neo4j"

View File

@@ -35,10 +35,10 @@ async def Split_The_Problem(state: ReadState) -> ReadState:
"""问题分解节点"""
# 从状态中获取数据
content = state.get('data', '')
group_id = state.get('group_id', '')
end_user_id = state.get('end_user_id', '')
memory_config = state.get('memory_config', None)
history = await SessionService(store).get_history(group_id, group_id, group_id)
history = await SessionService(store).get_history(end_user_id, end_user_id, end_user_id)
# 生成 JSON schema 以指导 LLM 输出正确格式
json_schema = ProblemExtensionResponse.model_json_schema()
@@ -140,7 +140,7 @@ async def Problem_Extension(state: ReadState) -> ReadState:
start = time.time()
content = state.get('data', '')
data = state.get('spit_data', '')['context']
group_id = state.get('group_id', '')
end_user_id = state.get('end_user_id', '')
storage_type = state.get('storage_type', '')
user_rag_memory_id = state.get('user_rag_memory_id', '')
memory_config = state.get('memory_config', None)
@@ -156,7 +156,7 @@ async def Problem_Extension(state: ReadState) -> ReadState:
databasets = {}
data = []
history = await SessionService(store).get_history(group_id, group_id, group_id)
history = await SessionService(store).get_history(end_user_id, end_user_id, end_user_id)
# 生成 JSON schema 以指导 LLM 输出正确格式
json_schema = ProblemExtensionResponse.model_json_schema()

View File

@@ -52,9 +52,9 @@ async def rag_config(state):
return kb_config
async def rag_knowledge(state,question):
kb_config = await rag_config(state)
group_id = state.get('group_id', '')
end_user_id = state.get('end_user_id', '')
user_rag_memory_id=state.get("user_rag_memory_id",'')
retrieve_chunks_result = knowledge_retrieval(question, kb_config, [str(group_id)])
retrieve_chunks_result = knowledge_retrieval(question, kb_config, [str(end_user_id)])
try:
retrieval_knowledge = [i.page_content for i in retrieve_chunks_result]
clean_content = '\n\n'.join(retrieval_knowledge)
@@ -159,7 +159,7 @@ async def retrieve_nodes(state: ReadState) -> ReadState:
problem_extension=state.get('problem_extension', '')['context']
storage_type=state.get('storage_type', '')
user_rag_memory_id=state.get('user_rag_memory_id', '')
group_id=state.get('group_id', '')
end_user_id=state.get('end_user_id', '')
memory_config = state.get('memory_config', None)
original=state.get('data', '')
problem_list=[]
@@ -172,7 +172,7 @@ async def retrieve_nodes(state: ReadState) -> ReadState:
try:
# Prepare search parameters based on storage type
search_params = {
"group_id": group_id,
"end_user_id": end_user_id,
"question": question,
"return_raw_results": True
}
@@ -263,13 +263,13 @@ async def retrieve_nodes(state: ReadState) -> ReadState:
async def retrieve(state: ReadState) -> ReadState:
# 从state中获取group_id
# 从state中获取end_user_id
import time
start=time.time()
problem_extension = state.get('problem_extension', '')['context']
storage_type = state.get('storage_type', '')
user_rag_memory_id = state.get('user_rag_memory_id', '')
group_id = state.get('group_id', '')
end_user_id = state.get('end_user_id', '')
memory_config = state.get('memory_config', None)
original = state.get('data', '')
problem_list = []
@@ -295,13 +295,13 @@ async def retrieve(state: ReadState) -> ReadState:
temperature=0.2,
)
time_retrieval_tool = create_time_retrieval_tool(group_id)
search_params = { "group_id": group_id, "return_raw_results": True }
time_retrieval_tool = create_time_retrieval_tool(end_user_id)
search_params = { "end_user_id": end_user_id, "return_raw_results": True }
hybrid_retrieval=create_hybrid_retrieval_tool_sync(memory_config, **search_params)
agent = create_agent(
llm,
tools=[time_retrieval_tool,hybrid_retrieval],
system_prompt=f"我是检索专家,可以根据适合的工具进行检索。当前使用的group_id是: {group_id}"
system_prompt=f"我是检索专家,可以根据适合的工具进行检索。当前使用的end_user_id是: {end_user_id}"
)
# 创建异步任务处理单个问题

View File

@@ -34,8 +34,8 @@ class SummaryNodeService(LLMServiceMixin):
summary_service = SummaryNodeService()
async def summary_history(state: ReadState) -> ReadState:
group_id = state.get("group_id", '')
history = await SessionService(store).get_history(group_id, group_id, group_id)
end_user_id = state.get("end_user_id", '')
history = await SessionService(store).get_history(end_user_id, end_user_id, end_user_id)
return history
async def summary_llm(state: ReadState, history, retrieve_info, template_name, operation_name, response_model,search_mode) -> str:
@@ -122,12 +122,12 @@ async def summary_llm(state: ReadState, history, retrieve_info, template_name, o
async def summary_redis_save(state: ReadState,aimessages) -> ReadState:
data = state.get("data", '')
group_id = state.get("group_id", '')
end_user_id = state.get("end_user_id", '')
await SessionService(store).save_session(
user_id=group_id,
user_id=end_user_id,
query=data,
apply_id=group_id,
group_id=group_id,
apply_id=end_user_id,
end_user_id=end_user_id,
ai_response=aimessages
)
await SessionService(store).cleanup_duplicates()
@@ -175,11 +175,11 @@ async def Input_Summary(state: ReadState) -> ReadState:
memory_config = state.get('memory_config', None)
user_rag_memory_id=state.get("user_rag_memory_id",'')
data=state.get("data", '')
group_id=state.get("group_id", '')
end_user_id=state.get("end_user_id", '')
logger.info(f"Input_Summary: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}")
history = await summary_history( state)
search_params = {
"group_id": group_id,
"end_user_id": end_user_id,
"question": data,
"return_raw_results": True,
"include": ["summaries"] # Only search summary nodes for faster performance

View File

@@ -62,12 +62,12 @@ async def Verify(state: ReadState):
logger.info("=== Verify 节点开始执行 ===")
try:
content = state.get('data', '')
group_id = state.get('group_id', '')
end_user_id = state.get('end_user_id', '')
memory_config = state.get('memory_config', None)
logger.info(f"Verify: content={content[:50] if content else 'empty'}..., group_id={group_id}")
logger.info(f"Verify: content={content[:50] if content else 'empty'}..., end_user_id={end_user_id}")
history = await SessionService(store).get_history(group_id, group_id, group_id)
history = await SessionService(store).get_history(end_user_id, end_user_id, end_user_id)
logger.info(f"Verify: 获取历史记录完成history length={len(history)}")
retrieve = state.get("retrieve", {})

View File

@@ -1,23 +1,24 @@
from app.core.memory.agent.utils.llm_tools import WriteState
from app.core.memory.agent.utils.llm_tools import WriteState
from app.core.memory.agent.utils.write_tools import write
from app.core.logging_config import get_agent_logger
logger = get_agent_logger(__name__)
async def write_node(state: WriteState) -> WriteState:
"""
Write data to the database/file system.
Args:
state: WriteState containing messages, group_id, and memory_config
state: WriteState containing messages, end_user_id, and memory_config
Returns:
dict: Contains 'write_result' with status and data fields
"""
messages = state.get('messages', [])
group_id = state.get('group_id', '')
end_user_id = state.get('end_user_id', '')
memory_config = state.get('memory_config', '')
# Convert LangChain messages to structured format expected by write()
structured_messages = []
for msg in messages:
@@ -28,13 +29,11 @@ async def write_node(state: WriteState) -> WriteState:
"role": role,
"content": msg.content # content is now guaranteed to be a string
})
try:
result = await write(
messages=structured_messages,
user_id=group_id,
apply_id=group_id,
group_id=group_id,
end_user_id=end_user_id,
memory_config=memory_config,
)
logger.info(f"Write completed successfully! Config: {memory_config.config_name}")

View File

@@ -79,7 +79,7 @@ async def make_read_graph():
async def main():
"""主函数 - 运行工作流"""
message = "昨天有什么好看的电影"
group_id = '88a459f5_text09' # 组ID
end_user_id = '88a459f5_text09' # 组ID
storage_type = 'neo4j' # 存储类型
search_switch = '1' # 搜索开关
user_rag_memory_id = 'wwwwwwww' # 用户RAG记忆ID
@@ -95,9 +95,9 @@ async def main():
start=time.time()
try:
async with make_read_graph() as graph:
config = {"configurable": {"thread_id": group_id}}
config = {"configurable": {"thread_id": end_user_id}}
# 初始状态 - 包含所有必要字段
initial_state = {"messages": [HumanMessage(content=message)] ,"search_switch":search_switch,"group_id":group_id
initial_state = {"messages": [HumanMessage(content=message)] ,"search_switch":search_switch,"end_user_id":end_user_id
,"storage_type":storage_type,"user_rag_memory_id":user_rag_memory_id,"memory_config":memory_config}
# 获取节点更新信息
_intermediate_outputs = []

View File

@@ -48,11 +48,11 @@ def extract_tool_message_content(response):
class TimeRetrievalInput(BaseModel):
"""时间检索工具的输入模式"""
context: str = Field(description="用户输入的查询内容")
group_id: str = Field(default="88a459f5_text09", description="组ID用于过滤搜索结果")
end_user_id: str = Field(default="88a459f5_text09", description="组ID用于过滤搜索结果")
def create_time_retrieval_tool(group_id: str):
def create_time_retrieval_tool(end_user_id: str):
"""
创建一个带有特定group_id的TimeRetrieval工具同步版本用于按时间范围搜索语句(Statements)
创建一个带有特定end_user_id的TimeRetrieval工具同步版本用于按时间范围搜索语句(Statements)
"""
def clean_temporal_result_fields(data):
@@ -93,26 +93,26 @@ def create_time_retrieval_tool(group_id: str):
return data
@tool
def TimeRetrievalWithGroupId(context: str, start_date: str = None, end_date: str = None, group_id_param: str = None, clean_output: bool = True) -> str:
def TimeRetrievalWithGroupId(context: str, start_date: str = None, end_date: str = None, end_user_id_param: str = None, clean_output: bool = True) -> str:
"""
优化的时间检索工具,只结合时间范围搜索(同步版本),自动过滤不需要的元数据字段
显式接收参数:
- context: 查询上下文内容
- start_date: 开始时间可选格式YYYY-MM-DD
- end_date: 结束时间可选格式YYYY-MM-DD
- group_id_param: 组ID可选用于覆盖默认组ID
- end_user_id_param: 组ID可选用于覆盖默认组ID
- clean_output: 是否清理输出中的元数据字段
-end_date 需要根据用户的描述获取结束的时间输出格式用strftime("%Y-%m-%d")
"""
async def _async_search():
# 使用传入的参数或默认值
actual_group_id = group_id_param or group_id
actual_end_user_id = end_user_id_param or end_user_id
actual_end_date = end_date or datetime.now().strftime("%Y-%m-%d")
actual_start_date = start_date or (datetime.now() - timedelta(days=7)).strftime("%Y-%m-%d")
# 基本时间搜索
results = await search_by_temporal(
group_id=actual_group_id,
end_user_id=actual_end_user_id,
start_date=actual_start_date,
end_date=actual_end_date,
limit=10
@@ -147,7 +147,7 @@ def create_time_retrieval_tool(group_id: str):
# 关键词时间搜索
results = await search_by_keyword_temporal(
query_text=context,
group_id=group_id,
end_user_id=end_user_id,
start_date=actual_start_date,
end_date=actual_end_date,
limit=15
@@ -172,7 +172,7 @@ def create_hybrid_retrieval_tool_async(memory_config, **search_params):
Args:
memory_config: 内存配置对象
**search_params: 搜索参数,包含group_id, limit, include等
**search_params: 搜索参数,包含end_user_id, limit, include等
"""
def clean_result_fields(data):
@@ -211,7 +211,7 @@ def create_hybrid_retrieval_tool_async(memory_config, **search_params):
context: str,
search_type: str = "hybrid",
limit: int = 10,
group_id: str = None,
end_user_id: str = None,
rerank_alpha: float = 0.6,
use_forgetting_rerank: bool = False,
use_llm_rerank: bool = False,
@@ -224,7 +224,7 @@ def create_hybrid_retrieval_tool_async(memory_config, **search_params):
context: 查询内容
search_type: 搜索类型 ('keyword', 'embedding', 'hybrid')
limit: 结果数量限制
group_id: 组ID用于过滤搜索结果
end_user_id: 组ID用于过滤搜索结果
rerank_alpha: 重排序权重参数
use_forgetting_rerank: 是否使用遗忘重排序
use_llm_rerank: 是否使用LLM重排序
@@ -238,7 +238,7 @@ def create_hybrid_retrieval_tool_async(memory_config, **search_params):
final_params = {
"query_text": context,
"search_type": search_type,
"group_id": group_id or search_params.get("group_id"),
"end_user_id": end_user_id or search_params.get("end_user_id"),
"limit": limit or search_params.get("limit", 10),
"include": search_params.get("include", ["summaries", "statements", "chunks", "entities"]),
"output_path": None, # 不保存到文件
@@ -291,7 +291,7 @@ def create_hybrid_retrieval_tool_sync(memory_config, **search_params):
context: str,
search_type: str = "hybrid",
limit: int = 10,
group_id: str = None,
end_user_id: str = None,
clean_output: bool = True
) -> str:
"""
@@ -301,7 +301,7 @@ def create_hybrid_retrieval_tool_sync(memory_config, **search_params):
context: 查询内容
search_type: 搜索类型 ('keyword', 'embedding', 'hybrid')
limit: 结果数量限制
group_id: 组ID用于过滤搜索结果
end_user_id: 组ID用于过滤搜索结果
clean_output: 是否清理输出中的元数据字段
"""
async def _async_search():
@@ -311,7 +311,7 @@ def create_hybrid_retrieval_tool_sync(memory_config, **search_params):
"context": context,
"search_type": search_type,
"limit": limit,
"group_id": group_id,
"end_user_id": end_user_id,
"clean_output": clean_output
})

View File

@@ -14,6 +14,7 @@ from app.db import get_db
from app.core.logging_config import get_agent_logger
from app.core.memory.agent.utils.llm_tools import WriteState
from app.core.memory.agent.langgraph_graph.nodes.write_nodes import write_node
from app.core.memory.agent.langgraph_graph.nodes.data_nodes import content_input_write
from app.services.memory_config_service import MemoryConfigService
warnings.filterwarnings("ignore", category=RuntimeWarning)
@@ -26,9 +27,21 @@ async def make_write_graph():
"""
Create a write graph workflow for memory operations.
The workflow directly processes messages from the initial state
and saves them to Neo4j storage.
Args:
user_id: User identifier
tools: MCP tools loaded from session
apply_id: Application identifier
end_user_id: Group identifier
memory_config: MemoryConfig object containing all configuration
"""
# workflow = StateGraph(WriteState)
# workflow.add_node("content_input", content_input_write)
# workflow.add_node("save_neo4j", write_node)
# workflow.add_edge(START, "content_input")
# workflow.add_edge("content_input", "save_neo4j")
# workflow.add_edge("save_neo4j", END)
#
# graph = workflow.compile()
workflow = StateGraph(WriteState)
workflow.add_node("save_neo4j", write_node)
workflow.add_edge(START, "save_neo4j")
@@ -42,7 +55,7 @@ async def make_write_graph():
async def main():
"""主函数 - 运行工作流"""
message = "今天周一"
group_id = 'new_2025test1103' # 组ID
end_user_id = 'new_2025test1103' # 组ID
# 获取数据库会话
@@ -54,9 +67,9 @@ async def main():
)
try:
async with make_write_graph() as graph:
config = {"configurable": {"thread_id": group_id}}
config = {"configurable": {"thread_id": end_user_id}}
# 初始状态 - 包含所有必要字段
initial_state = {"messages": [HumanMessage(content=message)], "group_id": group_id, "memory_config": memory_config}
initial_state = {"messages": [HumanMessage(content=message)], "end_user_id": end_user_id, "memory_config": memory_config}
# 获取节点更新信息
async for update_event in graph.astream(

View File

@@ -24,7 +24,7 @@ class ParameterBuilder:
tool_call_id: str,
search_switch: str,
apply_id: str,
group_id: str,
end_user_id: str,
storage_type: Optional[str] = None,
user_rag_memory_id: Optional[str] = None
) -> Dict[str, Any]:
@@ -44,7 +44,7 @@ class ParameterBuilder:
tool_call_id: Extracted tool call identifier
search_switch: Search routing parameter
apply_id: Application identifier
group_id: Group identifier
end_user_id: Group identifier
storage_type: Storage type for the workspace (optional)
user_rag_memory_id: User RAG memory ID for knowledge base retrieval (optional)
@@ -55,7 +55,7 @@ class ParameterBuilder:
base_args = {
"usermessages": tool_call_id,
"apply_id": apply_id,
"group_id": group_id
"end_user_id": end_user_id
}
# Always add storage_type and user_rag_memory_id (with defaults if None)

View File

@@ -91,7 +91,7 @@ class SearchService:
async def execute_hybrid_search(
self,
group_id: str,
end_user_id: str,
question: str,
limit: int = 5,
search_type: str = "hybrid",
@@ -105,7 +105,7 @@ class SearchService:
Execute hybrid search and return clean content.
Args:
group_id: Group identifier for filtering results
end_user_id: Group identifier for filtering results
question: Search query text
limit: Maximum number of results to return (default: 5)
search_type: Type of search - "hybrid", "keyword", or "embedding" (default: "hybrid")
@@ -130,7 +130,7 @@ class SearchService:
answer = await run_hybrid_search(
query_text=cleaned_query,
search_type=search_type,
group_id=group_id,
end_user_id=end_user_id,
limit=limit,
include=include,
output_path=output_path,
@@ -186,7 +186,7 @@ class SearchService:
except Exception as e:
logger.error(
f"Search failed for query '{question}' in group '{group_id}': {e}",
f"Search failed for query '{question}' in group '{end_user_id}': {e}",
exc_info=True
)
# Return empty results on failure

View File

@@ -59,7 +59,7 @@ class SessionService:
self,
user_id: str,
apply_id: str,
group_id: str
end_user_id: str
) -> List[dict]:
"""
Retrieve conversation history from Redis.
@@ -67,20 +67,20 @@ class SessionService:
Args:
user_id: User identifier
apply_id: Application identifier
group_id: Group identifier
end_user_id: Group identifier
Returns:
List of conversation history items with Query and Answer keys
Returns empty list if no history found or on error
"""
try:
history = self.store.find_user_apply_group(user_id, apply_id, group_id)
history = self.store.find_user_apply_group(user_id, apply_id, end_user_id)
# Validate history structure
if not isinstance(history, list):
logger.warning(
f"Invalid history format for user {user_id}, "
f"apply {apply_id}, group {group_id}: expected list, got {type(history)}"
f"apply {apply_id}, group {end_user_id}: expected list, got {type(history)}"
)
return []
@@ -89,7 +89,7 @@ class SessionService:
except Exception as e:
logger.error(
f"Failed to retrieve history for user {user_id}, "
f"apply {apply_id}, group {group_id}: {e}",
f"apply {apply_id}, group {end_user_id}: {e}",
exc_info=True
)
# Return empty list on error to allow execution to continue
@@ -100,7 +100,7 @@ class SessionService:
user_id: str,
query: str,
apply_id: str,
group_id: str,
end_user_id: str,
ai_response: str
) -> Optional[str]:
"""
@@ -110,7 +110,7 @@ class SessionService:
user_id: User identifier
query: User query/message
apply_id: Application identifier
group_id: Group identifier
end_user_id: Group identifier
ai_response: AI response/answer
Returns:
@@ -131,7 +131,7 @@ class SessionService:
userid=user_id,
messages=query,
apply_id=apply_id,
group_id=group_id,
end_user_id=end_user_id,
aimessages=ai_response
)
@@ -152,7 +152,7 @@ class SessionService:
Duplicates are identified by matching:
- sessionid
- user_id (id field)
- group_id
- end_user_id
- messages
- aimessages

View File

@@ -9,9 +9,7 @@ from app.core.memory.models.message_models import DialogData, ConversationContex
async def get_chunked_dialogs(
chunker_strategy: str = "RecursiveChunker",
group_id: str = "group_1",
user_id: str = "user1",
apply_id: str = "applyid",
end_user_id: str = "group_1",
messages: list = None,
ref_id: str = "wyl_20251027",
config_id: str = None
@@ -20,9 +18,7 @@ async def get_chunked_dialogs(
Args:
chunker_strategy: The chunking strategy to use (default: RecursiveChunker)
group_id: Group identifier
user_id: User identifier
apply_id: Application identifier
end_user_id: Group identifier
messages: Structured message list [{"role": "user", "content": "..."}, ...]
ref_id: Reference identifier
config_id: Configuration ID for processing
@@ -32,42 +28,40 @@ async def get_chunked_dialogs(
"""
from app.core.logging_config import get_agent_logger
logger = get_agent_logger(__name__)
if not messages or not isinstance(messages, list) or len(messages) == 0:
raise ValueError("messages parameter must be a non-empty list")
conversation_messages = []
for idx, msg in enumerate(messages):
if not isinstance(msg, dict) or 'role' not in msg or 'content' not in msg:
raise ValueError(f"Message {idx} format error: must contain 'role' and 'content' fields")
role = msg['role']
content = msg['content']
if role not in ['user', 'assistant']:
raise ValueError(f"Message {idx} role must be 'user' or 'assistant', got: {role}")
if content.strip():
conversation_messages.append(ConversationMessage(role=role, msg=content.strip()))
if not conversation_messages:
raise ValueError("Message list cannot be empty after filtering")
conversation_context = ConversationContext(msgs=conversation_messages)
dialog_data = DialogData(
context=conversation_context,
ref_id=ref_id,
group_id=group_id,
user_id=user_id,
apply_id=apply_id,
end_user_id=end_user_id,
config_id=config_id
)
chunker = DialogueChunker(chunker_strategy)
extracted_chunks = await chunker.process_dialogue(dialog_data)
dialog_data.chunks = extracted_chunks
logger.info(f"DialogData created with {len(extracted_chunks)} chunks")
return [dialog_data]

View File

@@ -13,13 +13,11 @@ class WriteState(TypedDict):
Langgrapg Writing TypedDict
'''
messages: Annotated[list[AnyMessage], add_messages]
user_id:str
apply_id:str
group_id:str
end_user_id: str
errors: list[dict] # Track errors: [{"tool": "tool_name", "error": "message"}]
memory_config: object
write_result: dict
data:str
data: str
class ReadState(TypedDict):
"""
@@ -29,7 +27,7 @@ class ReadState(TypedDict):
messages: 消息列表,支持自动追加
loop_count: 遍历次数
search_switch: 搜索类型开关
group_id: 组标识
end_user_id: 组标识
config_id: 配置ID用于过滤结果
data: 从content_input_node传递的内容数据
spit_data: 从Split_The_Problem传递的分解结果
@@ -40,7 +38,7 @@ class ReadState(TypedDict):
messages: Annotated[list[AnyMessage], add_messages] # 消息追加模式
loop_count: int
search_switch: str
group_id: str
end_user_id: str
config_id: str
data: str # 新增字段用于传递内容
spit_data: dict # 新增字段用于传递问题分解结果

View File

@@ -28,7 +28,7 @@ class RedisSessionStore:
return text
# 修改后的 save_session 方法
def save_session(self, userid, messages, aimessages, apply_id, group_id):
def save_session(self, userid, messages, aimessages, apply_id, end_user_id):
"""
写入一条会话数据,返回 session_id
优化版本确保写入时间不超过1秒
@@ -46,7 +46,7 @@ class RedisSessionStore:
"id": self.uudi,
"sessionid": userid,
"apply_id": apply_id,
"group_id": group_id,
"end_user_id": end_user_id,
"messages": messages,
"aimessages": aimessages,
"starttime": starttime
@@ -67,7 +67,7 @@ class RedisSessionStore:
def save_sessions_batch(self, sessions_data):
"""
批量写入多条会话数据,返回 session_id 列表
sessions_data: list of dict, 每个 dict 包含 userid, messages, aimessages, apply_id, group_id
sessions_data: list of dict, 每个 dict 包含 userid, messages, aimessages, apply_id, end_user_id
优化版本:批量操作,大幅提升性能
"""
try:
@@ -83,7 +83,7 @@ class RedisSessionStore:
"id": self.uudi,
"sessionid": session.get('userid'),
"apply_id": session.get('apply_id'),
"group_id": session.get('group_id'),
"end_user_id": session.get('end_user_id'),
"messages": session.get('messages'),
"aimessages": session.get('aimessages'),
"starttime": starttime
@@ -108,9 +108,9 @@ class RedisSessionStore:
data = self.r.hgetall(key)
return data if data else None
def get_session_apply_group(self, sessionid, apply_id, group_id):
def get_session_apply_group(self, sessionid, apply_id, end_user_id):
"""
根据 sessionid、apply_id 和 group_id 三个条件查询会话数据
根据 sessionid、apply_id 和 end_user_id 三个条件查询会话数据
"""
result_items = []
@@ -124,7 +124,7 @@ class RedisSessionStore:
# 检查三个条件是否都匹配
if (data.get('sessionid') == sessionid and
data.get('apply_id') == apply_id and
data.get('group_id') == group_id):
data.get('end_user_id') == end_user_id):
result_items.append(data)
return result_items
@@ -172,7 +172,7 @@ class RedisSessionStore:
def delete_duplicate_sessions(self):
"""
删除重复会话数据,条件:
"sessionid""user_id""group_id""messages""aimessages" 五个字段都相同的只保留一个,其他删除
"sessionid""user_id""end_user_id""messages""aimessages" 五个字段都相同的只保留一个,其他删除
优化版本:使用 pipeline 批量操作确保在1秒内完成
"""
import time
@@ -202,12 +202,12 @@ class RedisSessionStore:
# 获取五个字段的值
sessionid = data.get('sessionid', '')
user_id = data.get('id', '')
group_id = data.get('group_id', '')
end_user_id = data.get('end_user_id', '')
messages = data.get('messages', '')
aimessages = data.get('aimessages', '')
# 用五元组作为唯一标识
identifier = (sessionid, user_id, group_id, messages, aimessages)
identifier = (sessionid, user_id, end_user_id, messages, aimessages)
if identifier in seen:
# 重复,标记为待删除
@@ -248,9 +248,9 @@ class RedisSessionStore:
result_items = []
return (result_items)
def find_user_apply_group(self, sessionid, apply_id, group_id):
def find_user_apply_group(self, sessionid, apply_id, end_user_id):
"""
根据 sessionid、apply_id 和 group_id 三个条件查询会话数据返回最新的6条
根据 sessionid、apply_id 和 end_user_id 三个条件查询会话数据返回最新的6条
"""
import time
start_time = time.time()
@@ -276,7 +276,7 @@ class RedisSessionStore:
# 检查是否符合三个条件
if (data.get('apply_id') == apply_id and
data.get('group_id') == group_id):
data.get('end_user_id') == end_user_id):
# 支持模糊匹配 sessionid 或者完全匹配
if sessionid in data.get('sessionid', '') or data.get('sessionid') == sessionid:
matched_items.append({

View File

@@ -59,7 +59,7 @@ class SessionService:
self,
user_id: str,
apply_id: str,
group_id: str
end_user_id: str
) -> List[dict]:
"""
Retrieve conversation history from Redis.
@@ -67,20 +67,20 @@ class SessionService:
Args:
user_id: User identifier
apply_id: Application identifier
group_id: Group identifier
end_user_id: Group identifier
Returns:
List of conversation history items with Query and Answer keys
Returns empty list if no history found or on error
"""
try:
history = self.store.find_user_apply_group(user_id, apply_id, group_id)
history = self.store.find_user_apply_group(user_id, apply_id, end_user_id)
# Validate history structure
if not isinstance(history, list):
logger.warning(
f"Invalid history format for user {user_id}, "
f"apply {apply_id}, group {group_id}: expected list, got {type(history)}"
f"apply {apply_id}, group {end_user_id}: expected list, got {type(history)}"
)
return []
@@ -89,7 +89,7 @@ class SessionService:
except Exception as e:
logger.error(
f"Failed to retrieve history for user {user_id}, "
f"apply {apply_id}, group {group_id}: {e}",
f"apply {apply_id}, group {end_user_id}: {e}",
exc_info=True
)
# Return empty list on error to allow execution to continue
@@ -100,7 +100,7 @@ class SessionService:
user_id: str,
query: str,
apply_id: str,
group_id: str,
end_user_id: str,
ai_response: str
) -> Optional[str]:
"""
@@ -110,7 +110,7 @@ class SessionService:
user_id: User identifier
query: User query/message
apply_id: Application identifier
group_id: Group identifier
end_user_id: Group identifier
ai_response: AI response/answer
Returns:
@@ -131,7 +131,7 @@ class SessionService:
userid=user_id,
messages=query,
apply_id=apply_id,
group_id=group_id,
end_user_id=end_user_id,
aimessages=ai_response
)
@@ -152,7 +152,7 @@ class SessionService:
Duplicates are identified by matching:
- sessionid
- user_id (id field)
- group_id
- end_user_id
- messages
- aimessages

View File

@@ -29,20 +29,18 @@ logger = get_agent_logger(__name__)
async def write(
user_id: str,
apply_id: str,
group_id: str,
end_user_id: str,
memory_config: MemoryConfig,
messages: list,
ref_id: str = "wyl20251027",
) -> None:
"""
Execute the complete knowledge extraction pipeline.
Args:
user_id: User identifier
apply_id: Application identifier
group_id: Group identifier
end_user_id: Group identifier
memory_config: MemoryConfig object containing all configuration
messages: Structured message list [{"role": "user", "content": "..."}, ...]
ref_id: Reference ID, defaults to "wyl20251027"
@@ -51,14 +49,14 @@ async def write(
embedding_model_id = str(memory_config.embedding_model_id)
chunker_strategy = memory_config.chunker_strategy
config_id = str(memory_config.config_id)
logger.info("=== MemSci Knowledge Extraction Pipeline ===")
logger.info(f"Config: {memory_config.config_name} (ID: {config_id})")
logger.info(f"Workspace: {memory_config.workspace_name}")
logger.info(f"LLM model: {memory_config.llm_model_name}")
logger.info(f"Embedding model: {memory_config.embedding_model_name}")
logger.info(f"Chunker strategy: {chunker_strategy}")
logger.info(f"Group ID: {group_id}")
logger.info(f"end_user_id ID: {end_user_id}")
# Construct clients from memory_config using factory pattern with db session
with get_db_context() as db:
@@ -83,9 +81,7 @@ async def write(
step_start = time.time()
chunked_dialogs = await get_chunked_dialogs(
chunker_strategy=chunker_strategy,
group_id=group_id,
user_id=user_id,
apply_id=apply_id,
end_user_id=end_user_id,
messages=messages,
ref_id=ref_id,
config_id=config_id,

View File

@@ -16,13 +16,13 @@ class FilteredTags(BaseModel):
"""用于接收LLM筛选后的核心标签列表的模型。"""
meaningful_tags: List[str] = Field(..., description="从原始列表中筛选出的具有核心代表意义的名词列表。")
async def filter_tags_with_llm(tags: List[str], group_id: str) -> List[str]:
async def filter_tags_with_llm(tags: List[str], end_user_id: str) -> List[str]:
"""
使用LLM筛选标签列表仅保留具有代表性的核心名词。
Args:
tags: 原始标签列表
group_id: 用户组ID用于获取配置
end_user_id: 用户组ID用于获取配置
Returns:
筛选后的标签列表
@@ -37,12 +37,12 @@ async def filter_tags_with_llm(tags: List[str], group_id: str) -> List[str]:
get_end_user_connected_config,
)
connected_config = get_end_user_connected_config(group_id, db)
connected_config = get_end_user_connected_config(end_user_id, db)
config_id = connected_config.get("memory_config_id")
if not config_id:
raise ValueError(
f"No memory_config_id found for group_id: {group_id}. "
f"No memory_config_id found for end_user_id: {end_user_id}. "
"Please ensure the user has a valid memory configuration."
)
@@ -87,7 +87,7 @@ async def filter_tags_with_llm(tags: List[str], group_id: str) -> List[str]:
async def get_raw_tags_from_db(
connector: Neo4jConnector,
group_id: str,
end_user_id: str,
limit: int,
by_user: bool = False
) -> List[Tuple[str, int]]:
@@ -99,9 +99,9 @@ async def get_raw_tags_from_db(
Args:
connector: Neo4j连接器实例
group_id: 如果by_user=False则为group_id如果by_user=True则为user_id
end_user_id: 如果by_user=False则为end_user_id如果by_user=True则为user_id
limit: 返回的标签数量限制
by_user: 是否按user_id查询默认Falsegroup_id查询
by_user: 是否按user_id查询默认Falseend_user_id查询
Returns:
List[Tuple[str, int]]: 标签名称和频率的元组列表
@@ -119,7 +119,7 @@ async def get_raw_tags_from_db(
else:
query = (
"MATCH (e:ExtractedEntity) "
"WHERE e.group_id = $id AND e.entity_type <> '人物' AND e.name IS NOT NULL AND NOT e.name IN $names_to_exclude "
"WHERE e.end_user_id = $id AND e.entity_type <> '人物' AND e.name IS NOT NULL AND NOT e.name IN $names_to_exclude "
"RETURN e.name AS name, count(e) AS frequency "
"ORDER BY frequency DESC "
"LIMIT $limit"
@@ -128,44 +128,44 @@ async def get_raw_tags_from_db(
# 使用项目的Neo4jConnector执行查询
results = await connector.execute_query(
query,
id=group_id,
id=end_user_id,
limit=limit,
names_to_exclude=names_to_exclude
)
return [(record["name"], record["frequency"]) for record in results]
async def get_hot_memory_tags(group_id: str, limit: int = 40, by_user: bool = False) -> List[Tuple[str, int]]:
async def get_hot_memory_tags(end_user_id: str, limit: int = 40, by_user: bool = False) -> List[Tuple[str, int]]:
"""
获取原始标签然后使用LLM进行筛选返回最终的热门标签列表。
查询更多的标签(limit=40)给LLM提供更丰富的上下文进行筛选。
Args:
group_id: 必需参数。如果by_user=False则为group_id如果by_user=True则为user_id
end_user_id: 必需参数。如果by_user=False则为end_user_id如果by_user=True则为user_id
limit: 返回的标签数量限制
by_user: 是否按user_id查询默认Falsegroup_id查询
by_user: 是否按user_id查询默认Falseend_user_id查询
Raises:
ValueError: 如果group_id未提供或为空
ValueError: 如果end_user_id未提供或为空
"""
# 验证group_id必须提供且不为空
if not group_id or not group_id.strip():
# 验证end_user_id必须提供且不为空
if not end_user_id or not end_user_id.strip():
raise ValueError(
"group_id is required. Please provide a valid group_id or user_id."
"end_user_id is required. Please provide a valid end_user_id or user_id."
)
# 使用项目的Neo4jConnector
connector = Neo4jConnector()
try:
# 1. 从数据库获取原始排名靠前的标签
raw_tags_with_freq = await get_raw_tags_from_db(connector, group_id, limit, by_user=by_user)
raw_tags_with_freq = await get_raw_tags_from_db(connector, end_user_id, limit, by_user=by_user)
if not raw_tags_with_freq:
return []
raw_tag_names = [tag for tag, freq in raw_tags_with_freq]
# 2. 初始化LLM客户端并使用LLM筛选出有意义的标签
meaningful_tag_names = await filter_tags_with_llm(raw_tag_names, group_id)
meaningful_tag_names = await filter_tags_with_llm(raw_tag_names, end_user_id)
# 3. 根据LLM的筛选结果构建最终的标签列表保留原始频率和顺序
final_tags = []

View File

@@ -75,8 +75,8 @@ class MemoryDataSource:
start_date = time_range.start_date if time_range else None
end_date = time_range.end_date if time_range else None
summary_dicts = await self.memory_summary_repo.find_by_group_id(
group_id=user_id,
summary_dicts = await self.memory_summary_repo.find_by_end_user_id(
end_user_id=user_id,
limit=limit,
start_date=start_date,
end_date=end_date

View File

@@ -41,7 +41,7 @@ DIALOGUE_EMBEDDING_SEARCH = """
WITH $embedding AS q
MATCH (d:Dialogue)
WHERE d.dialog_embedding IS NOT NULL
AND ($group_id IS NULL OR d.group_id = $group_id)
AND ($end_user_id IS NULL OR d.end_user_id = $end_user_id)
WITH d, q, d.dialog_embedding AS v
WITH d,
reduce(dot = 0.0, i IN range(0, size(q)-1) | dot + toFloat(q[i]) * toFloat(v[i])) AS dot,
@@ -50,7 +50,7 @@ WITH d,
WITH d, CASE WHEN qnorm = 0 OR vnorm = 0 THEN 0.0 ELSE dot / (qnorm * vnorm) END AS score
WHERE score > $threshold
RETURN d.id AS dialog_id,
d.group_id AS group_id,
d.end_user_id AS end_user_id,
d.content AS content,
d.created_at AS created_at,
d.expired_at AS expired_at,

View File

@@ -36,7 +36,7 @@ from app.repositories.neo4j.neo4j_connector import Neo4jConnector
async def ingest_contexts_via_full_pipeline(
contexts: List[str],
group_id: str,
end_user_id: str,
chunker_strategy: str | None = None,
embedding_name: str | None = None,
save_chunk_output: bool = False,
@@ -48,7 +48,7 @@ async def ingest_contexts_via_full_pipeline(
This function mirrors the steps in main(), but starts from raw text contexts.
Args:
contexts: List of dialogue texts, each containing lines like "role: message".
group_id: Group ID to assign to generated DialogData and graph nodes.
end_user_id: Group ID to assign to generated DialogData and graph nodes.
chunker_strategy: Optional chunker strategy; defaults to SELECTED_CHUNKER_STRATEGY.
embedding_name: Optional embedding model ID; defaults to SELECTED_EMBEDDING_ID.
save_chunk_output: If True, write chunked DialogData list to a JSON file for debugging.
@@ -109,7 +109,7 @@ async def ingest_contexts_via_full_pipeline(
dialog = DialogData(
context=context_model,
ref_id=f"pipeline_item_{idx}",
group_id=group_id,
end_user_id=end_user_id,
user_id="default_user",
apply_id="default_application",
)
@@ -318,16 +318,16 @@ async def handle_context_processing(args):
print("No contexts provided for processing.")
return False
return await main_from_contexts(contexts, args.context_group_id)
return await main_from_contexts(contexts, args.context_end_user_id)
async def main_from_contexts(contexts: List[str], group_id: str):
async def main_from_contexts(contexts: List[str], end_user_id: str):
"""Run the pipeline from provided dialogue contexts instead of test data."""
print("=== Running pipeline from provided contexts ===")
success = await ingest_contexts_via_full_pipeline(
contexts=contexts,
group_id=group_id,
end_user_id=end_user_id,
chunker_strategy=SELECTED_CHUNKER_STRATEGY,
embedding_name=SELECTED_EMBEDDING_ID,
save_chunk_output=True

View File

@@ -47,7 +47,7 @@ from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
from app.core.memory.utils.definitions import (
PROJECT_ROOT,
SELECTED_EMBEDDING_ID,
SELECTED_GROUP_ID,
SELECTED_end_user_id,
SELECTED_LLM_ID,
)
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
@@ -59,7 +59,7 @@ from app.services.memory_config_service import MemoryConfigService
async def run_locomo_benchmark(
sample_size: int = 20,
group_id: Optional[str] = None,
end_user_id: Optional[str] = None,
search_type: str = "hybrid",
search_limit: int = 12,
context_char_budget: int = 8000,
@@ -85,7 +85,7 @@ async def run_locomo_benchmark(
Args:
sample_size: Number of QA pairs to evaluate (from first conversation)
group_id: Database group ID for retrieval (uses default if None)
end_user_id: Database group ID for retrieval (uses default if None)
search_type: "keyword", "embedding", or "hybrid"
search_limit: Max documents to retrieve per query
context_char_budget: Max characters for context
@@ -96,8 +96,8 @@ async def run_locomo_benchmark(
Returns:
Dictionary with evaluation results including metrics, timing, and samples
"""
# Use default group_id if not provided
group_id = group_id or SELECTED_GROUP_ID
# Use default end_user_id if not provided
end_user_id = end_user_id or SELECTED_end_user_id
# Determine data path
data_path = os.path.join(PROJECT_ROOT, "data", "locomo10.json")
@@ -110,7 +110,7 @@ async def run_locomo_benchmark(
print(f"{'='*60}")
print("📊 Configuration:")
print(f" Sample size: {sample_size}")
print(f" Group ID: {group_id}")
print(f" Group ID: {end_user_id}")
print(f" Search type: {search_type}")
print(f" Search limit: {search_limit}")
print(f" Context budget: {context_char_budget} chars")
@@ -134,7 +134,7 @@ async def run_locomo_benchmark(
# Step 2: Extract conversations and ingest if needed
if skip_ingest:
print("⏭️ Skipping data ingestion (using existing data in Neo4j)")
print(f" Group ID: {group_id}\n")
print(f" Group ID: {end_user_id}\n")
else:
print("💾 Checking database ingestion...")
try:
@@ -142,10 +142,10 @@ async def run_locomo_benchmark(
print(f"📝 Extracted {len(conversations)} conversations")
# Always ingest for now (ingestion check not implemented)
print(f"🔄 Ingesting conversations into group '{group_id}'...")
print(f"🔄 Ingesting conversations into group '{end_user_id}'...")
success = await ingest_conversations_if_needed(
conversations=conversations,
group_id=group_id,
end_user_id=end_user_id,
reset=reset_group
)
@@ -224,7 +224,7 @@ async def run_locomo_benchmark(
try:
retrieved_info = await retrieve_relevant_information(
question=question,
group_id=group_id,
end_user_id=end_user_id,
search_type=search_type,
search_limit=search_limit,
connector=connector,
@@ -409,7 +409,7 @@ async def run_locomo_benchmark(
"sample_size": len(qa_items),
"timestamp": datetime.now().isoformat(),
"params": {
"group_id": group_id,
"end_user_id": end_user_id,
"search_type": search_type,
"search_limit": search_limit,
"context_char_budget": context_char_budget,
@@ -467,7 +467,7 @@ def main():
help="Number of QA pairs to evaluate"
)
parser.add_argument(
"--group_id",
"--end_user_id",
type=str,
default=None,
help="Database group ID for retrieval (uses default if not specified)"
@@ -516,7 +516,7 @@ def main():
# Run benchmark
result = asyncio.run(run_locomo_benchmark(
sample_size=args.sample_size,
group_id=args.group_id,
end_user_id=args.end_user_id,
search_type=args.search_type,
search_limit=args.search_limit,
context_char_budget=args.context_char_budget,

View File

@@ -556,7 +556,7 @@ async def run_enhanced_evaluation():
search_results = await run_hybrid_search(
query_text=q,
search_type="hybrid",
group_id="locomo_sk",
end_user_id="locomo_sk",
limit=20,
include=["statements", "chunks", "entities", "summaries"],
alpha=0.6, # BM25权重

View File

@@ -348,7 +348,7 @@ def select_and_format_information(
async def retrieve_relevant_information(
question: str,
group_id: str,
end_user_id: str,
search_type: str,
search_limit: int,
connector: Any,
@@ -368,7 +368,7 @@ async def retrieve_relevant_information(
Args:
question: Question to search for
group_id: Database group ID (identifies which conversation memory to search)
end_user_id: Database group ID (identifies which conversation memory to search)
search_type: "keyword", "embedding", or "hybrid"
search_limit: Max memory pieces to retrieve
connector: Neo4j connector instance
@@ -396,7 +396,7 @@ async def retrieve_relevant_information(
connector=connector,
embedder_client=embedder,
query_text=question,
group_id=group_id,
end_user_id=end_user_id,
limit=search_limit,
include=["chunks", "statements", "entities", "summaries"],
)
@@ -455,7 +455,7 @@ async def retrieve_relevant_information(
search_results = await search_graph(
connector=connector,
q=question,
group_id=group_id,
end_user_id=end_user_id,
limit=search_limit
)
@@ -491,7 +491,7 @@ async def retrieve_relevant_information(
search_results = await run_hybrid_search(
query_text=question,
search_type=search_type,
group_id=group_id,
end_user_id=end_user_id,
limit=search_limit,
include=["chunks", "statements", "entities", "summaries"],
output_path=None,
@@ -524,7 +524,7 @@ async def retrieve_relevant_information(
connector=connector,
embedder_client=embedder,
query_text=question,
group_id=group_id,
end_user_id=end_user_id,
limit=search_limit,
include=["chunks", "statements", "entities", "summaries"],
)
@@ -584,7 +584,7 @@ async def retrieve_relevant_information(
async def ingest_conversations_if_needed(
conversations: List[str],
group_id: str,
end_user_id: str,
reset: bool = False
) -> bool:
"""
@@ -603,7 +603,7 @@ async def ingest_conversations_if_needed(
Args:
conversations: List of raw conversation texts from LoCoMo dataset
Example: ["User: I went to Paris. AI: When was that?", ...]
group_id: Target group ID for database storage
end_user_id: Target group ID for database storage
reset: Whether to clear existing data first (not implemented in wrapper)
Returns:
@@ -617,7 +617,7 @@ async def ingest_conversations_if_needed(
try:
success = await ingest_contexts_via_full_pipeline(
contexts=conversations,
group_id=group_id,
end_user_id=end_user_id,
save_chunk_output=True
)
return success

View File

@@ -249,7 +249,7 @@ def get_search_params_by_category(category: str):
async def run_locomo_eval(
sample_size: int = 1,
group_id: str | None = None,
end_user_id: str | None = None,
search_limit: int = 8,
context_char_budget: int = 4000, # 保持默认值不变
llm_temperature: float = 0.0,
@@ -262,7 +262,7 @@ async def run_locomo_eval(
) -> Dict[str, Any]:
# 函数内部使用三路检索逻辑,但保持参数签名不变
group_id = group_id or SELECTED_GROUP_ID
end_user_id = end_user_id or SELECTED_end_user_id
data_path = os.path.join(PROJECT_ROOT, "data", "locomo10.json")
if not os.path.exists(data_path):
data_path = os.path.join(os.getcwd(), "data", "locomo10.json")
@@ -340,7 +340,7 @@ async def run_locomo_eval(
# 关键修复:强制重新摄入纯净的对话数据
print("🔄 强制重新摄入纯净的对话数据...")
await ingest_contexts_via_full_pipeline(contents, group_id, save_chunk_output=True)
await ingest_contexts_via_full_pipeline(contents, end_user_id, save_chunk_output=True)
# 使用异步LLM客户端
with get_db_context() as db:
@@ -405,7 +405,7 @@ async def run_locomo_eval(
connector=connector,
embedder_client=embedder,
query_text=q,
group_id=group_id,
end_user_id=end_user_id,
limit=adjusted_limit,
include=["chunks", "statements", "entities", "summaries"], # 修复:使用正确的类型
)
@@ -456,7 +456,7 @@ async def run_locomo_eval(
search_results = await search_graph(
connector=connector,
q=q,
group_id=group_id,
end_user_id=end_user_id,
limit=adjusted_limit
)
dialogs = search_results.get("dialogues", [])
@@ -486,7 +486,7 @@ async def run_locomo_eval(
search_results = await run_hybrid_search(
query_text=q,
search_type=search_type,
group_id=group_id,
end_user_id=end_user_id,
limit=adjusted_limit,
include=["chunks", "statements", "entities", "summaries"],
output_path=None,
@@ -524,7 +524,7 @@ async def run_locomo_eval(
connector=connector,
embedder_client=embedder,
query_text=q,
group_id=group_id,
end_user_id=end_user_id,
limit=adjusted_limit,
include=["chunks", "statements", "entities", "summaries"],
)
@@ -597,7 +597,7 @@ async def run_locomo_eval(
"dialogues": [
{
"uuid": d.get("uuid", ""),
"group_id": d.get("group_id", ""),
"end_user_id": d.get("end_user_id", ""),
"content": d.get("content", "")[:200] + "..." if len(d.get("content", "")) > 200 else d.get("content", ""),
"score": d.get("score", 0.0)
}
@@ -795,7 +795,7 @@ async def run_locomo_eval(
},
"samples": samples,
"params": {
"group_id": group_id,
"end_user_id": end_user_id,
"search_limit": search_limit,
"context_char_budget": context_char_budget,
"search_type": search_type,
@@ -825,7 +825,7 @@ async def run_locomo_eval(
def main():
parser = argparse.ArgumentParser(description="Run LoCoMo evaluation with Qwen search")
parser.add_argument("--sample_size", type=int, default=1, help="Number of samples to evaluate")
parser.add_argument("--group_id", type=str, default=None, help="Group ID for retrieval")
parser.add_argument("--end_user_id", type=str, default=None, help="Group ID for retrieval")
parser.add_argument("--search_limit", type=int, default=8, help="Search limit per query")
parser.add_argument("--context_char_budget", type=int, default=12000, help="Max characters for context")
parser.add_argument("--llm_temperature", type=float, default=0.0, help="LLM temperature")
@@ -841,7 +841,7 @@ def main():
result = asyncio.run(run_locomo_eval(
sample_size=args.sample_size,
group_id=args.group_id,
end_user_id=args.end_user_id,
search_limit=args.search_limit,
context_char_budget=args.context_char_budget,
llm_temperature=args.llm_temperature,

View File

@@ -524,11 +524,11 @@ def generate_query_keywords_cn(question: str) -> List[str]:
# 通过别名匹配进行实体关键词检索多token合并
async def _search_entities_by_aliases(connector: Neo4jConnector, tokens: List[str], group_id: str | None, limit: int) -> List[Dict[str, Any]]:
async def _search_entities_by_aliases(connector: Neo4jConnector, tokens: List[str], end_user_id: str | None, limit: int) -> List[Dict[str, Any]]:
results: List[Dict[str, Any]] = []
try:
for tok in tokens:
rows = await connector.execute_query(SEARCH_ENTITIES_BY_NAME, q=tok, group_id=group_id, limit=limit)
rows = await connector.execute_query(SEARCH_ENTITIES_BY_NAME, q=tok, end_user_id=end_user_id, limit=limit)
if rows:
results.extend(rows)
except Exception:
@@ -548,15 +548,15 @@ async def _search_entities_by_aliases(connector: Neo4jConnector, tokens: List[st
# 通过对话/陈述中的entity_ids反查实体名称
_FETCH_ENTITIES_BY_IDS = """
MATCH (e:ExtractedEntity)
WHERE e.id IN $ids AND ($group_id IS NULL OR e.group_id = $group_id)
RETURN e.id AS id, e.name AS name, e.group_id AS group_id, e.entity_type AS entity_type
WHERE e.id IN $ids AND ($end_user_id IS NULL OR e.end_user_id = $end_user_id)
RETURN e.id AS id, e.name AS name, e.end_user_id AS end_user_id, e.entity_type AS entity_type
"""
async def _fetch_entities_by_ids(connector: Neo4jConnector, ids: List[str], group_id: str | None) -> List[Dict[str, Any]]:
async def _fetch_entities_by_ids(connector: Neo4jConnector, ids: List[str], end_user_id: str | None) -> List[Dict[str, Any]]:
if not ids:
return []
try:
rows = await connector.execute_query(_FETCH_ENTITIES_BY_IDS, ids=list({i for i in ids if i}), group_id=group_id)
rows = await connector.execute_query(_FETCH_ENTITIES_BY_IDS, ids=list({i for i in ids if i}), end_user_id=end_user_id)
return rows or []
except Exception:
return []
@@ -566,18 +566,18 @@ async def _fetch_entities_by_ids(connector: Neo4jConnector, ids: List[str], grou
_TIME_ENTITY_SEARCH = """
MATCH (e:ExtractedEntity)
WHERE e.entity_type CONTAINS "TIME" OR e.entity_type CONTAINS "DATE" OR e.name =~ $date_pattern
AND ($group_id IS NULL OR e.group_id = $group_id)
RETURN e.id AS id, e.name AS name, e.group_id AS group_id, e.entity_type AS entity_type
AND ($end_user_id IS NULL OR e.end_user_id = $end_user_id)
RETURN e.id AS id, e.name AS name, e.end_user_id AS end_user_id, e.entity_type AS entity_type
LIMIT $limit
"""
async def _search_time_entities(connector: Neo4jConnector, group_id: str | None, limit: int = 5) -> List[Dict[str, Any]]:
async def _search_time_entities(connector: Neo4jConnector, end_user_id: str | None, limit: int = 5) -> List[Dict[str, Any]]:
"""专门搜索时间相关的实体"""
try:
date_pattern = r".*\d{4}.*|.*\d{1,2}月\d{1,2}日.*"
rows = await connector.execute_query(_TIME_ENTITY_SEARCH,
date_pattern=date_pattern,
group_id=group_id,
end_user_id=end_user_id,
limit=limit)
return rows or []
except Exception:
@@ -624,7 +624,7 @@ def _resolve_relative_times_cn_en(text: str, anchor: datetime) -> str:
async def run_longmemeval_test(
sample_size: int = 3,
group_id: str = "longmemeval_zh_bak_3",
end_user_id: str = "longmemeval_zh_bak_3",
search_limit: int = 8,
context_char_budget: int = 4000,
llm_temperature: float = 0.0,
@@ -678,13 +678,13 @@ async def run_longmemeval_test(
contexts.extend(selected)
print(f"📥 摄入 {len(contexts)} 个上下文到数据库")
if reset_group_before_ingest and group_id:
if reset_group_before_ingest and end_user_id:
try:
_tmp_conn = Neo4jConnector()
await _tmp_conn.delete_group(group_id)
print(f"🧹 已清空组 {group_id} 的历史图数据")
await _tmp_conn.delete_group(end_user_id)
print(f"🧹 已清空组 {end_user_id} 的历史图数据")
except Exception as _e:
print(f"⚠️ 清空组数据失败(忽略继续): {group_id} - {_e}")
print(f"⚠️ 清空组数据失败(忽略继续): {end_user_id} - {_e}")
finally:
try:
await _tmp_conn.close()
@@ -696,7 +696,7 @@ async def run_longmemeval_test(
else:
await _ingest_fn(
contexts,
group_id,
end_user_id,
save_chunk_output=save_chunk_output,
save_chunk_output_path=save_chunk_output_path,
)
@@ -751,7 +751,7 @@ async def run_longmemeval_test(
connector=connector,
embedder_client=embedder,
query_text=question,
group_id=group_id,
end_user_id=end_user_id,
limit=search_limit,
include=["chunks", "statements", "entities", "summaries"],
)
@@ -796,7 +796,7 @@ async def run_longmemeval_test(
search_results = await search_graph(
connector=connector,
q=question,
group_id=group_id,
end_user_id=end_user_id,
limit=search_limit,
)
chunks = search_results.get("chunks", [])
@@ -831,7 +831,7 @@ async def run_longmemeval_test(
connector=connector,
embedder_client=embedder,
query_text=question,
group_id=group_id,
end_user_id=end_user_id,
limit=search_limit,
include=["chunks", "statements", "entities", "summaries"],
)
@@ -849,7 +849,7 @@ async def run_longmemeval_test(
kw_res = await search_graph(
connector=connector,
q=question,
group_id=group_id,
end_user_id=end_user_id,
limit=search_limit,
)
if isinstance(kw_res, dict):
@@ -860,7 +860,7 @@ async def run_longmemeval_test(
# 时间推理问题的特殊处理
if is_temporal:
# 专门搜索时间实体
time_entities = await _search_time_entities(connector, group_id, search_limit//2)
time_entities = await _search_time_entities(connector, end_user_id, search_limit//2)
if time_entities:
kw_entities.extend(time_entities)
# 添加时间相关关键词检索
@@ -870,7 +870,7 @@ async def run_longmemeval_test(
time_res = await search_graph(
connector=connector,
q=tk,
group_id=group_id,
end_user_id=end_user_id,
limit=2,
)
if isinstance(time_res, dict):
@@ -881,7 +881,7 @@ async def run_longmemeval_test(
# 中文关键词拆分后做别名匹配
cn_tokens = _extract_cn_tokens(question)
alias_entities = await _search_entities_by_aliases(connector, cn_tokens, group_id, search_limit)
alias_entities = await _search_entities_by_aliases(connector, cn_tokens, end_user_id, search_limit)
if alias_entities:
kw_entities.extend(alias_entities)
@@ -895,7 +895,7 @@ async def run_longmemeval_test(
except Exception:
pass
if ids:
id_entities = await _fetch_entities_by_ids(connector, ids, group_id)
id_entities = await _fetch_entities_by_ids(connector, ids, end_user_id)
if id_entities:
kw_entities.extend(id_entities)
@@ -909,7 +909,7 @@ async def run_longmemeval_test(
sub_res = await search_graph(
connector=connector,
q=str(kw),
group_id=group_id,
end_user_id=end_user_id,
limit=max(3, search_limit // 2),
)
if isinstance(sub_res, dict):
@@ -928,7 +928,7 @@ async def run_longmemeval_test(
opt_res = await search_graph(
connector=connector,
q=str(opt),
group_id=group_id,
end_user_id=end_user_id,
limit=max(3, search_limit // 2),
)
if isinstance(opt_res, dict):
@@ -1010,7 +1010,7 @@ async def run_longmemeval_test(
kw_fallback = await search_graph(
connector=connector,
q=question,
group_id=group_id,
end_user_id=end_user_id,
limit=max(search_limit, 5),
)
fb_dialogs = kw_fallback.get("dialogues", []) or []
@@ -1224,7 +1224,7 @@ async def run_longmemeval_test(
"count_avg": statistics.mean(per_query_context_counts) if per_query_context_counts else 0.0,
},
"params": {
"group_id": group_id,
"end_user_id": end_user_id,
"search_limit": search_limit,
"context_char_budget": context_char_budget,
"search_type": search_type,
@@ -1307,7 +1307,7 @@ def main():
result = asyncio.run(
run_longmemeval_test(
sample_size=sample_size,
group_id=args.group_id,
end_user_id=args.end_user_id,
search_limit=args.search_limit,
context_char_budget=args.context_char_budget,
llm_temperature=args.llm_temperature,

View File

@@ -498,11 +498,11 @@ def smart_context_selection(contexts: List[str], question: str, max_chars: int =
# 通过别名匹配进行实体关键词检索多token合并
async def _search_entities_by_aliases(connector: Neo4jConnector, tokens: List[str], group_id: str | None, limit: int) -> List[Dict[str, Any]]:
async def _search_entities_by_aliases(connector: Neo4jConnector, tokens: List[str], end_user_id: str | None, limit: int) -> List[Dict[str, Any]]:
results: List[Dict[str, Any]] = []
try:
for tok in tokens:
rows = await connector.execute_query(SEARCH_ENTITIES_BY_NAME, q=tok, group_id=group_id, limit=limit)
rows = await connector.execute_query(SEARCH_ENTITIES_BY_NAME, q=tok, end_user_id=end_user_id, limit=limit)
if rows:
results.extend(rows)
except Exception:
@@ -522,15 +522,15 @@ async def _search_entities_by_aliases(connector: Neo4jConnector, tokens: List[st
# 通过对话/陈述中的entity_ids反查实体名称
_FETCH_ENTITIES_BY_IDS = """
MATCH (e:ExtractedEntity)
WHERE e.id IN $ids AND ($group_id IS NULL OR e.group_id = $group_id)
RETURN e.id AS id, e.name AS name, e.group_id AS group_id, e.entity_type AS entity_type
WHERE e.id IN $ids AND ($end_user_id IS NULL OR e.end_user_id = $end_user_id)
RETURN e.id AS id, e.name AS name, e.end_user_id AS end_user_id, e.entity_type AS entity_type
"""
async def _fetch_entities_by_ids(connector: Neo4jConnector, ids: List[str], group_id: str | None) -> List[Dict[str, Any]]:
async def _fetch_entities_by_ids(connector: Neo4jConnector, ids: List[str], end_user_id: str | None) -> List[Dict[str, Any]]:
if not ids:
return []
try:
rows = await connector.execute_query(_FETCH_ENTITIES_BY_IDS, ids=list({i for i in ids if i}), group_id=group_id)
rows = await connector.execute_query(_FETCH_ENTITIES_BY_IDS, ids=list({i for i in ids if i}), end_user_id=end_user_id)
return rows or []
except Exception:
return []
@@ -540,18 +540,18 @@ async def _fetch_entities_by_ids(connector: Neo4jConnector, ids: List[str], grou
_TIME_ENTITY_SEARCH = """
MATCH (e:ExtractedEntity)
WHERE e.entity_type CONTAINS "TIME" OR e.entity_type CONTAINS "DATE" OR e.name =~ $date_pattern
AND ($group_id IS NULL OR e.group_id = $group_id)
RETURN e.id AS id, e.name AS name, e.group_id AS group_id, e.entity_type AS entity_type
AND ($end_user_id IS NULL OR e.end_user_id = $end_user_id)
RETURN e.id AS id, e.name AS name, e.end_user_id AS end_user_id, e.entity_type AS entity_type
LIMIT $limit
"""
async def _search_time_entities(connector: Neo4jConnector, group_id: str | None, limit: int = 5) -> List[Dict[str, Any]]:
async def _search_time_entities(connector: Neo4jConnector, end_user_id: str | None, limit: int = 5) -> List[Dict[str, Any]]:
"""专门搜索时间相关的实体"""
try:
date_pattern = r".*\d{4}.*|.*\d{1,2}月\d{1,2}日.*"
rows = await connector.execute_query(_TIME_ENTITY_SEARCH,
date_pattern=date_pattern,
group_id=group_id,
end_user_id=end_user_id,
limit=limit)
return rows or []
except Exception:
@@ -559,25 +559,25 @@ async def _search_time_entities(connector: Neo4jConnector, group_id: str | None,
# 技术术语专门检索
async def _search_tech_terms(connector: Neo4jConnector, question: str, group_id: str | None, limit: int = 3) -> List[Dict[str, Any]]:
async def _search_tech_terms(connector: Neo4jConnector, question: str, end_user_id: str | None, limit: int = 3) -> List[Dict[str, Any]]:
"""专门搜索技术术语相关的实体"""
tech_entities = []
try:
# GPS相关
if any(term in question for term in ["GPS", "导航", "定位系统"]):
gps_rows = await connector.execute_query(SEARCH_ENTITIES_BY_NAME, q="GPS", group_id=group_id, limit=limit)
gps_rows = await connector.execute_query(SEARCH_ENTITIES_BY_NAME, q="GPS", end_user_id=end_user_id, limit=limit)
if gps_rows:
tech_entities.extend(gps_rows)
# 活动相关
if any(term in question for term in ["工作坊", "研讨会", "网络研讨会"]):
workshop_rows = await connector.execute_query(SEARCH_ENTITIES_BY_NAME, q="工作坊", group_id=group_id, limit=limit)
workshop_rows = await connector.execute_query(SEARCH_ENTITIES_BY_NAME, q="工作坊", end_user_id=end_user_id, limit=limit)
if workshop_rows:
tech_entities.extend(workshop_rows)
# 时间顺序相关
if any(term in question for term in ["", "", "第一个"]):
time_rows = await connector.execute_query(SEARCH_ENTITIES_BY_NAME, q="第一次", group_id=group_id, limit=limit)
time_rows = await connector.execute_query(SEARCH_ENTITIES_BY_NAME, q="第一次", end_user_id=end_user_id, limit=limit)
if time_rows:
tech_entities.extend(time_rows)
@@ -627,7 +627,7 @@ def _resolve_relative_times_cn_en(text: str, anchor: datetime) -> str:
async def run_longmemeval_test(
sample_size: int = 3,
group_id: str = "longmemeval_zh_bak_2",
end_user_id: str = "longmemeval_zh_bak_2",
search_limit: int = 8,
context_char_budget: int = 4000,
llm_temperature: float = 0.0,
@@ -707,7 +707,7 @@ async def run_longmemeval_test(
connector=connector,
embedder_client=embedder,
query_text=question,
group_id=group_id,
end_user_id=end_user_id,
limit=search_limit,
include=["dialogues", "statements", "entities"],
)
@@ -746,7 +746,7 @@ async def run_longmemeval_test(
search_results = await search_graph(
connector=connector,
q=question,
group_id=group_id,
end_user_id=end_user_id,
limit=search_limit,
)
dialogs = search_results.get("dialogues", [])
@@ -776,7 +776,7 @@ async def run_longmemeval_test(
connector=connector,
embedder_client=embedder,
query_text=question,
group_id=group_id,
end_user_id=end_user_id,
limit=search_limit,
include=["dialogues", "statements", "entities"],
)
@@ -792,7 +792,7 @@ async def run_longmemeval_test(
kw_res = await search_graph(
connector=connector,
q=question,
group_id=group_id,
end_user_id=end_user_id,
limit=search_limit,
)
if isinstance(kw_res, dict):
@@ -801,14 +801,14 @@ async def run_longmemeval_test(
kw_entities = kw_res.get("entities", []) or []
# 技术术语专门检索
tech_entities = await _search_tech_terms(connector, question, group_id, search_limit//2)
tech_entities = await _search_tech_terms(connector, question, end_user_id, search_limit//2)
if tech_entities:
kw_entities.extend(tech_entities)
# 时间推理问题的特殊处理
if is_temporal:
# 专门搜索时间实体
time_entities = await _search_time_entities(connector, group_id, search_limit//2)
time_entities = await _search_time_entities(connector, end_user_id, search_limit//2)
if time_entities:
kw_entities.extend(time_entities)
# 添加时间相关关键词检索
@@ -818,7 +818,7 @@ async def run_longmemeval_test(
time_res = await search_graph(
connector=connector,
q=tk,
group_id=group_id,
end_user_id=end_user_id,
limit=2,
)
if isinstance(time_res, dict):
@@ -829,7 +829,7 @@ async def run_longmemeval_test(
# 中文关键词拆分后做别名匹配
cn_tokens = generate_query_keywords_cn(question) # 使用增强版关键词提取
alias_entities = await _search_entities_by_aliases(connector, cn_tokens, group_id, search_limit)
alias_entities = await _search_entities_by_aliases(connector, cn_tokens, end_user_id, search_limit)
if alias_entities:
kw_entities.extend(alias_entities)
@@ -843,7 +843,7 @@ async def run_longmemeval_test(
except Exception:
pass
if ids:
id_entities = await _fetch_entities_by_ids(connector, ids, group_id)
id_entities = await _fetch_entities_by_ids(connector, ids, end_user_id)
if id_entities:
kw_entities.extend(id_entities)
@@ -857,7 +857,7 @@ async def run_longmemeval_test(
sub_res = await search_graph(
connector=connector,
q=str(kw),
group_id=group_id,
end_user_id=end_user_id,
limit=max(3, search_limit // 2),
)
if isinstance(sub_res, dict):
@@ -876,7 +876,7 @@ async def run_longmemeval_test(
opt_res = await search_graph(
connector=connector,
q=str(opt),
group_id=group_id,
end_user_id=end_user_id,
limit=max(3, search_limit // 2),
)
if isinstance(opt_res, dict):
@@ -971,7 +971,7 @@ async def run_longmemeval_test(
kw_fallback = await search_graph(
connector=connector,
q=question,
group_id=group_id,
end_user_id=end_user_id,
limit=max(search_limit, 5),
)
fb_dialogs = kw_fallback.get("dialogues", []) or []
@@ -1199,7 +1199,7 @@ async def run_longmemeval_test(
"count_avg": statistics.mean(per_query_context_counts) if per_query_context_counts else 0.0,
},
"params": {
"group_id": group_id,
"end_user_id": end_user_id,
"search_limit": search_limit,
"context_char_budget": context_char_budget,
"search_type": search_type,
@@ -1278,7 +1278,7 @@ def main():
result = asyncio.run(
run_longmemeval_test(
sample_size=sample_size,
group_id=args.group_id,
end_user_id=args.end_user_id,
search_limit=args.search_limit,
context_char_budget=args.context_char_budget,
llm_temperature=args.llm_temperature,

View File

@@ -135,8 +135,8 @@ def _combine_dialogues_for_hybrid(results: Dict[str, Any]) -> List[Dict[str, Any
return merged
async def run_memsciqa_eval(sample_size: int = 1, group_id: str | None = None, search_limit: int = 8, context_char_budget: int = 4000, llm_temperature: float = 0.0, llm_max_tokens: int = 64, search_type: str = "hybrid", memory_config: "MemoryConfig" = None) -> Dict[str, Any]:
group_id = group_id or SELECTED_GROUP_ID
async def run_memsciqa_eval(sample_size: int = 1, end_user_id: str | None = None, search_limit: int = 8, context_char_budget: int = 4000, llm_temperature: float = 0.0, llm_max_tokens: int = 64, search_type: str = "hybrid", memory_config: "MemoryConfig" = None) -> Dict[str, Any]:
end_user_id = end_user_id or SELECTED_GROUP_ID
# Load data
data_path = os.path.join(PROJECT_ROOT, "data", "msc_self_instruct.jsonl")
if not os.path.exists(data_path):
@@ -147,7 +147,7 @@ async def run_memsciqa_eval(sample_size: int = 1, group_id: str | None = None, s
# 改为:每条样本仅摄入一个上下文(完整对话转录),避免多上下文摄入
# 说明memsciqa 数据集的每个样本天然只有一个对话,保持按样本一上下文的策略
contexts: List[str] = [build_context_from_dialog(item) for item in items]
await ingest_contexts_via_full_pipeline(contexts, group_id)
await ingest_contexts_via_full_pipeline(contexts, end_user_id)
# LLM client (使用异步调用)
with get_db_context() as db:
@@ -173,7 +173,7 @@ async def run_memsciqa_eval(sample_size: int = 1, group_id: str | None = None, s
results = await run_hybrid_search(
query_text=question,
search_type=search_type,
group_id=group_id,
end_user_id=end_user_id,
limit=search_limit,
include=["dialogues", "statements", "entities"],
output_path=None,
@@ -298,7 +298,7 @@ def main():
load_dotenv()
parser = argparse.ArgumentParser(description="Evaluate DMR (memsciqa) with graph search and Qwen")
parser.add_argument("--sample-size", type=int, default=1, help="评测样本数量")
parser.add_argument("--group-id", type=str, default=None, help="可选 group_id默认取 runtime.json")
parser.add_argument("--group-id", type=str, default=None, help="可选 end_user_id默认取 runtime.json")
parser.add_argument("--search-limit", type=int, default=8, help="每类检索最大返回数")
parser.add_argument("--context-char-budget", type=int, default=4000, help="上下文字符预算")
parser.add_argument("--llm-temperature", type=float, default=0.0, help="LLM 温度")
@@ -309,7 +309,7 @@ def main():
result = asyncio.run(
run_memsciqa_eval(
sample_size=args.sample_size,
group_id=args.group_id,
end_user_id=args.end_user_id,
search_limit=args.search_limit,
context_char_budget=args.context_char_budget,
llm_temperature=args.llm_temperature,

View File

@@ -199,7 +199,7 @@ def load_dataset_memsciqa(data_path: str) -> List[Dict[str, Any]]:
async def run_memsciqa_test(
sample_size: int = 3,
group_id: str | None = None,
end_user_id: str | None = None,
search_limit: int = 8,
context_char_budget: int = 4000,
llm_temperature: float = 0.0,
@@ -217,7 +217,7 @@ async def run_memsciqa_test(
"""
# 默认使用指定的 memsci 组 ID
group_id = group_id or "group_memsci"
end_user_id = end_user_id or "group_memsci"
# 数据路径解析(项目根与当前工作目录兜底)
if not data_path:
@@ -283,7 +283,7 @@ async def run_memsciqa_test(
connector=connector,
embedder_client=embedder,
query_text=question,
group_id=group_id,
end_user_id=end_user_id,
limit=search_limit,
include=["chunks", "statements", "entities", "summaries"], # 使用 chunks 而不是 dialogues
)
@@ -292,7 +292,7 @@ async def run_memsciqa_test(
results = await search_graph(
connector=connector,
q=question,
group_id=group_id,
end_user_id=end_user_id,
limit=search_limit,
include=["chunks", "statements", "entities", "summaries"], # 使用 chunks 而不是 dialogues
)
@@ -500,7 +500,7 @@ async def run_memsciqa_test(
},
"samples": samples,
"params": {
"group_id": group_id,
"end_user_id": end_user_id,
"search_limit": search_limit,
"context_char_budget": context_char_budget,
"llm_temperature": llm_temperature,
@@ -543,7 +543,7 @@ def main():
result = asyncio.run(
run_memsciqa_test(
sample_size=sample_size,
group_id=args.group_id,
end_user_id=args.end_user_id,
search_limit=args.search_limit,
context_char_budget=args.context_char_budget,
llm_temperature=args.llm_temperature,

View File

@@ -26,7 +26,7 @@ async def run(
dataset: str,
sample_size: int,
reset_group: bool,
group_id: str | None,
end_user_id: str | None,
judge_model: str | None = None,
search_limit: int | None = None,
context_char_budget: int | None = None,
@@ -37,17 +37,17 @@ async def run(
max_contexts_per_item: int | None = None,
) -> Dict[str, Any]:
# 恢复原始风格:统一入口做路由,并沿用各数据集既有默认
group_id = group_id or SELECTED_GROUP_ID
end_user_id = end_user_id or SELECTED_GROUP_ID
if reset_group:
connector = Neo4jConnector()
try:
await connector.delete_group(group_id)
await connector.delete_group(end_user_id)
finally:
await connector.close()
if dataset == "locomo":
kwargs: Dict[str, Any] = {"sample_size": sample_size, "group_id": group_id}
kwargs: Dict[str, Any] = {"sample_size": sample_size, "end_user_id": end_user_id}
if search_limit is not None:
kwargs["search_limit"] = search_limit
if context_char_budget is not None:
@@ -61,7 +61,7 @@ async def run(
return await run_locomo_eval(**kwargs)
if dataset == "memsciqa":
kwargs: Dict[str, Any] = {"sample_size": sample_size, "group_id": group_id}
kwargs: Dict[str, Any] = {"sample_size": sample_size, "end_user_id": end_user_id}
if search_limit is not None:
kwargs["search_limit"] = search_limit
if context_char_budget is not None:
@@ -75,7 +75,7 @@ async def run(
return await run_memsciqa_eval(**kwargs)
if dataset == "longmemeval":
kwargs: Dict[str, Any] = {"sample_size": sample_size, "group_id": group_id}
kwargs: Dict[str, Any] = {"sample_size": sample_size, "end_user_id": end_user_id}
if search_limit is not None:
kwargs["search_limit"] = search_limit
if context_char_budget is not None:
@@ -99,8 +99,8 @@ def main():
parser = argparse.ArgumentParser(description="统一评估入口memsciqa / longmemeval / locomo")
parser.add_argument("--dataset", choices=["memsciqa", "longmemeval", "locomo"], required=True)
parser.add_argument("--sample-size", type=int, default=1, help="先用一条数据跑通")
parser.add_argument("--reset-group", action="store_true", help="运行前清空当前 group_id 的图数据")
parser.add_argument("--group-id", type=str, default=None, help="可选 group_id默认取 runtime.json")
parser.add_argument("--reset-group", action="store_true", help="运行前清空当前 end_user_id 的图数据")
parser.add_argument("--group-id", type=str, default=None, help="可选 end_user_id默认取 runtime.json")
parser.add_argument("--judge-model", type=str, default=None, help="可选longmemeval 判别式评测模型名")
parser.add_argument("--search-limit", type=int, default=None, help="检索返回的对话节点数量上限(不提供则使用各脚本默认)")
parser.add_argument("--context-char-budget", type=int, default=None, help="上下文字符预算(不提供则使用各脚本默认)")
@@ -117,7 +117,7 @@ def main():
args.dataset,
args.sample_size,
args.reset_group,
args.group_id,
args.end_user_id,
args.judge_model,
args.search_limit,
args.context_char_budget,

View File

@@ -187,11 +187,11 @@ class ChunkerClient:
async def generate_chunks(self, dialogue: DialogData):
"""
Generate chunks following 1 Message = 1 Chunk strategy.
Each message creates one chunk, directly inheriting role information.
If a message is too long, it will be split into multiple sub-chunks,
each maintaining the same speaker.
Raises:
ValueError: If dialogue has no messages or chunking fails
"""
@@ -201,9 +201,9 @@ class ChunkerClient:
f"Dialogue {dialogue.ref_id} has no messages. "
f"Cannot generate chunks from empty dialogue."
)
dialogue.chunks = []
# 按消息分块:每个消息创建一个或多个 chunk直接继承角色
for msg_idx, msg in enumerate(dialogue.context.msgs):
# Validate message has required attributes
@@ -212,13 +212,13 @@ class ChunkerClient:
f"Message {msg_idx} in dialogue {dialogue.ref_id} "
f"missing 'role' or 'msg' attribute"
)
msg_content = msg.msg.strip()
# Skip empty messages
if not msg_content:
continue
# 如果消息太长,可以进一步分块
if len(msg_content) > self.chunk_size:
# 对单个消息的内容进行分块
@@ -228,14 +228,14 @@ class ChunkerClient:
raise ValueError(
f"Failed to chunk long message {msg_idx} in dialogue {dialogue.ref_id}: {e}"
)
for idx, sub_chunk in enumerate(sub_chunks):
sub_chunk_text = sub_chunk.text if hasattr(sub_chunk, 'text') else str(sub_chunk)
sub_chunk_text = sub_chunk_text.strip()
if len(sub_chunk_text) < (self.min_characters_per_chunk or 50):
continue
chunk = Chunk(
content=f"{msg.role}: {sub_chunk_text}",
speaker=msg.role, # 直接继承角色
@@ -260,7 +260,7 @@ class ChunkerClient:
},
)
dialogue.chunks.append(chunk)
# Validate we generated at least one chunk
if not dialogue.chunks:
raise ValueError(
@@ -268,7 +268,7 @@ class ChunkerClient:
f"All messages were either empty or too short. "
f"Messages count: {len(dialogue.context.msgs)}"
)
return dialogue
def evaluate_chunking(self, dialogue: DialogData) -> dict:

View File

@@ -72,7 +72,7 @@ class TemporalSearchParams(BaseModel):
"""Parameters for temporal search queries in the knowledge graph.
Attributes:
group_id: Group ID to filter search results (default: 'test')
end_user_id: Group ID to filter search results (default: 'test')
apply_id: Application ID to filter search results
user_id: User ID to filter search results
start_date: Start date for temporal filtering (format: 'YYYY-MM-DD')
@@ -81,7 +81,7 @@ class TemporalSearchParams(BaseModel):
invalid_date: Date when memory should be invalid (format: 'YYYY-MM-DD')
limit: Maximum number of results to return (default: 3)
"""
group_id: Optional[str] = Field("test", description="The group ID to filter the search.")
end_user_id: Optional[str] = Field("test", description="The group ID to filter the search.")
apply_id: Optional[str] = Field(None, description="The apply ID to filter the search.")
user_id: Optional[str] = Field(None, description="The user ID to filter the search.")
start_date: Optional[str] = Field(None, description="The start date for the search.")

View File

@@ -103,9 +103,7 @@ class Edge(BaseModel):
id: Unique identifier for the edge
source: ID of the source node
target: ID of the target node
group_id: Group ID for multi-tenancy
user_id: User ID for user-specific data
apply_id: Application ID for application-specific data
end_user_id: End user ID for multi-tenancy
run_id: Unique identifier for the pipeline run that created this edge
created_at: Timestamp when the edge was created (system perspective)
expired_at: Optional timestamp when the edge expires (system perspective)
@@ -113,9 +111,7 @@ class Edge(BaseModel):
id: str = Field(default_factory=lambda: uuid4().hex, description="A unique identifier for the edge.")
source: str = Field(..., description="The ID of the source node.")
target: str = Field(..., description="The ID of the target node.")
group_id: str = Field(..., description="The group ID of the edge.")
user_id: str = Field(..., description="The user ID of the edge.")
apply_id: str = Field(..., description="The apply ID of the edge.")
end_user_id: str = Field(..., description="The end user ID of the edge.")
run_id: str = Field(default_factory=lambda: uuid4().hex, description="Unique identifier for this pipeline run.")
created_at: datetime = Field(..., description="The valid time of the edge from system perspective.")
expired_at: Optional[datetime] = Field(None, description="The expired time of the edge from system perspective.")
@@ -185,18 +181,14 @@ class Node(BaseModel):
Attributes:
id: Unique identifier for the node
name: Name of the node
group_id: Group ID for multi-tenancy
user_id: User ID for user-specific data
apply_id: Application ID for application-specific data
end_user_id: End user ID for multi-tenancy
run_id: Unique identifier for the pipeline run that created this node
created_at: Timestamp when the node was created (system perspective)
expired_at: Optional timestamp when the node expires (system perspective)
"""
id: str = Field(..., description="The unique identifier for the node.")
name: str = Field(..., description="The name of the node.")
group_id: str = Field(..., description="The group ID of the node.")
user_id: str = Field(..., description="The user ID of the edge.")
apply_id: str = Field(..., description="The apply ID of the edge.")
end_user_id: str = Field(..., description="The end user ID of the node.")
run_id: str = Field(default_factory=lambda: uuid4().hex, description="Unique identifier for this pipeline run.")
created_at: datetime = Field(..., description="The valid time of the node from system perspective.")
expired_at: Optional[datetime] = Field(None, description="The expired time of the node from system perspective.")

View File

@@ -55,7 +55,7 @@ class Statement(BaseModel):
Attributes:
id: Unique identifier for the statement
chunk_id: ID of the parent chunk this statement belongs to
group_id: Optional group ID for multi-tenancy
end_user_id: Optional group ID for multi-tenancy
statement: The actual statement text content
speaker: Optional speaker identifier ('用户' for user, 'AI' for AI responses)
statement_embedding: Optional embedding vector for the statement
@@ -73,7 +73,7 @@ class Statement(BaseModel):
"""
id: str = Field(default_factory=lambda: uuid4().hex, description="A unique identifier for the statement.")
chunk_id: str = Field(..., description="ID of the parent chunk this statement belongs to.")
group_id: Optional[str] = Field(None, description="ID of the group this statement belongs to.")
end_user_id: Optional[str] = Field(None, description="ID of the group this statement belongs to.")
statement: str = Field(..., description="The text content of the statement.")
speaker: Optional[str] = Field(None, description="Speaker identifier: 'user' for user messages, 'assistant' for AI responses")
statement_embedding: Optional[List[float]] = Field(None, description="The embedding vector of the statement.")
@@ -159,9 +159,7 @@ class DialogData(BaseModel):
context: Full conversation context
dialog_embedding: Optional embedding vector for the entire dialog
ref_id: Reference ID linking to external dialog system
group_id: Group ID for multi-tenancy
user_id: User ID for user-specific data
apply_id: Application ID for application-specific data
end_user_id: End user ID for multi-tenancy
created_at: Timestamp when the dialog was created
expired_at: Timestamp when the dialog expires (default: far future)
metadata: Additional metadata as key-value pairs
@@ -175,9 +173,7 @@ class DialogData(BaseModel):
context: ConversationContext = Field(..., description="The full conversation context as a single string.")
dialog_embedding: Optional[List[float]] = Field(None, description="The embedding vector of the dialog.")
ref_id: str = Field(..., description="Refer to external dialog id. This is used to link to the original dialog.")
group_id: str = Field(default=..., description="Group ID of dialogue data")
user_id: str = Field(..., description="USER ID of dialogue data")
apply_id: str = Field(..., description="APPLY ID of dialogue data")
end_user_id: str = Field(default=..., description="End user ID of dialogue data")
run_id: str = Field(default_factory=lambda: uuid4().hex, description="Unique identifier for this pipeline run.")
created_at: datetime = Field(default_factory=datetime.now, description="The timestamp when the dialog was created.")
expired_at: datetime = Field(default_factory=lambda: datetime(9999, 12, 31), description="The timestamp when the dialog expires.")
@@ -250,11 +246,11 @@ class DialogData(BaseModel):
return []
def assign_group_id_to_statements(self) -> None:
"""Assign this dialog's group_id to all statements in all chunks.
"""Assign this dialog's end_user_id to all statements in all chunks.
This method updates statements that don't have a group_id set.
This method updates statements that don't have a end_user_id set.
"""
for chunk in self.chunks:
for statement in chunk.statements:
if statement.group_id is None:
statement.group_id = self.group_id
if statement.end_user_id is None:
statement.end_user_id = self.end_user_id

View File

@@ -6,6 +6,7 @@ import os
import time
from datetime import datetime
from typing import TYPE_CHECKING, Any, Dict, List, Optional
from uuid import UUID
if TYPE_CHECKING:
from app.schemas.memory_config_schema import MemoryConfig
@@ -396,13 +397,13 @@ def rerank_with_activation(
return reranked
def log_search_query(query_text: str, search_type: str, group_id: str | None, limit: int, include: List[str], log_file: str = None):
def log_search_query(query_text: str, search_type: str, end_user_id: str | None, limit: int, include: List[str], log_file: str = None):
"""Log search query information using the logger.
Args:
query_text: The search query text
search_type: Type of search (keyword, embedding, hybrid)
group_id: Group identifier for filtering
end_user_id: Group identifier for filtering
limit: Maximum number of results
include: List of result types to include
log_file: Deprecated parameter, kept for backward compatibility
@@ -413,7 +414,7 @@ def log_search_query(query_text: str, search_type: str, group_id: str | None, li
# Log using the standard logger
logger.info(
f"Search query: query='{cleaned_query}', type={search_type}, "
f"group_id={group_id}, limit={limit}, include={include}"
f"end_user_id={end_user_id}, limit={limit}, include={include}"
)
@@ -672,7 +673,7 @@ def apply_reranker_placeholder(
async def run_hybrid_search(
query_text: str,
search_type: str,
group_id: str | None,
end_user_id: str | None,
limit: int,
include: List[str],
output_path: str | None,
@@ -715,7 +716,7 @@ async def run_hybrid_search(
}
# Log the search query
log_search_query(query_text, search_type, group_id, limit, include)
log_search_query(query_text, search_type, end_user_id, limit, include)
connector = Neo4jConnector()
results = {}
@@ -732,7 +733,7 @@ async def run_hybrid_search(
search_graph(
connector=connector,
q=query_text,
group_id=group_id,
end_user_id=end_user_id,
limit=limit,
include=include
)
@@ -769,7 +770,7 @@ async def run_hybrid_search(
connector=connector,
embedder_client=embedder,
query_text=query_text,
group_id=group_id,
end_user_id=end_user_id,
limit=limit,
include=include,
)
@@ -916,9 +917,7 @@ async def run_hybrid_search(
async def search_by_temporal(
group_id: Optional[str] = "test",
apply_id: Optional[str] = None,
user_id: Optional[str] = None,
end_user_id: Optional[str] = "test",
start_date: Optional[str] = None,
end_date: Optional[str] = None,
valid_date: Optional[str] = None,
@@ -929,7 +928,7 @@ async def search_by_temporal(
Temporal search across Statements.
- Matches statements created between start_date and end_date
- Optionally filters by group_id
- Optionally filters by end_user_id
- Returns up to 'limit' statements
"""
connector = Neo4jConnector()
@@ -939,9 +938,7 @@ async def search_by_temporal(
end_date = normalize_date_safe(end_date)
params = TemporalSearchParams.model_validate({
"group_id": group_id,
"apply_id": apply_id,
"user_id": user_id,
"end_user_id": end_user_id,
"start_date": start_date,
"end_date": end_date,
"valid_date": valid_date,
@@ -950,9 +947,7 @@ async def search_by_temporal(
})
statements = await search_graph_by_temporal(
connector=connector,
group_id=params.group_id,
apply_id=params.apply_id,
user_id=params.user_id,
end_user_id=params.end_user_id,
start_date=params.start_date,
end_date=params.end_date,
valid_date=params.valid_date,
@@ -964,9 +959,7 @@ async def search_by_temporal(
async def search_by_keyword_temporal(
query_text: str,
group_id: Optional[str] = "test",
apply_id: Optional[str] = None,
user_id: Optional[str] = None,
end_user_id: Optional[str] = "test",
start_date: Optional[str] = None,
end_date: Optional[str] = None,
valid_date: Optional[str] = None,
@@ -987,9 +980,7 @@ async def search_by_keyword_temporal(
invalid_date = normalize_date_safe(invalid_date)
params = TemporalSearchParams.model_validate({
"group_id": group_id,
"apply_id": apply_id,
"user_id": user_id,
"end_user_id": end_user_id,
"start_date": start_date,
"end_date": end_date,
"valid_date": valid_date,
@@ -999,9 +990,7 @@ async def search_by_keyword_temporal(
statements = await search_graph_by_keyword_temporal(
connector=connector,
query_text=query_text,
group_id=params.group_id,
apply_id=params.apply_id,
user_id=params.user_id,
end_user_id=params.end_user_id,
start_date=params.start_date,
end_date=params.end_date,
valid_date=params.valid_date,
@@ -1013,7 +1002,7 @@ async def search_by_keyword_temporal(
async def search_chunk_by_chunk_id(
chunk_id: str,
group_id: Optional[str] = "test",
end_user_id: Optional[str] = "test",
limit: int = 1,
):
"""
@@ -1023,7 +1012,7 @@ async def search_chunk_by_chunk_id(
chunks = await search_graph_by_chunk_id(
connector=connector,
chunk_id=chunk_id,
group_id=group_id,
end_user_id=end_user_id,
limit=limit
)
return {"chunks": chunks}

View File

@@ -555,8 +555,8 @@ class DataPreprocessor:
dialog_id = item.get('dialog_id', item.get('ref_id', item.get('id', f'dialog_{i}')))
# 获取group_id如果不存在则生成默认值
group_id = item.get('group_id', f'group_default_{i}')
# 获取end_user_id如果不存在则生成默认值
end_user_id = item.get('end_user_id', f'group_default_{i}')
user_id = item.get('user_id', f'user_default_{i}')
apply_id = item.get('apply_id', f'apply_default_{i}')
@@ -574,7 +574,7 @@ class DataPreprocessor:
dialog_data = DialogData(
context=context,
ref_id=dialog_id,
group_id=group_id,
end_user_id=end_user_id,
user_id=user_id,
apply_id=apply_id,
metadata=metadata
@@ -644,7 +644,7 @@ class DataPreprocessor:
context = ConversationContext(msgs=messages)
dialog_id = item.get('dialog_id', item.get('ref_id', item.get('id', f'dialog_{i}')))
group_id = item.get('group_id', f'group_default_{i}')
end_user_id = item.get('end_user_id', f'group_default_{i}')
user_id = item.get('user_id', f'user_default_{i}')
apply_id = item.get('apply_id', f'apply_default_{i}')
@@ -657,7 +657,7 @@ class DataPreprocessor:
dialog_data = DialogData(
context=context,
ref_id=dialog_id,
group_id=group_id,
end_user_id=end_user_id,
user_id=user_id,
apply_id=apply_id,
metadata=metadata

View File

@@ -199,7 +199,7 @@ def accurate_match(
entity_nodes: List[ExtractedEntityNode]
) -> Tuple[List[ExtractedEntityNode], Dict[str, str], Dict[str, Dict]]:
"""
精确匹配:按 (group_id, name, entity_type) 合并实体并建立重定向与合并记录。
精确匹配:按 (end_user_id, name, entity_type) 合并实体并建立重定向与合并记录。
返回: (deduped_entities, id_redirect, exact_merge_map)
"""
exact_merge_map: Dict[str, Dict] = {}
@@ -210,8 +210,8 @@ def accurate_match(
for ent in entity_nodes:
name_norm = (getattr(ent, "name", "") or "").strip()
type_norm = (getattr(ent, "entity_type", "") or "").strip()
key = f"{getattr(ent, 'group_id', None)}|{name_norm}|{type_norm}"
# 为避免跨业务组误并,明确以 group_id 为范围边界
key = f"{getattr(ent, 'end_user_id', None)}|{name_norm}|{type_norm}"
# 为避免跨业务组误并,明确以 end_user_id 为范围边界
if key not in canonical_map:
canonical_map[key] = ent
id_redirect[ent.id] = ent.id
@@ -223,11 +223,11 @@ def accurate_match(
id_redirect[ent.id] = canonical.id
# 记录精确匹配的合并项(使用规范化键,避免外层变量误用)
try:
k = f"{canonical.group_id}|{(canonical.name or '').strip()}|{(canonical.entity_type or '').strip()}"
k = f"{canonical.end_user_id}|{(canonical.name or '').strip()}|{(canonical.entity_type or '').strip()}"
if k not in exact_merge_map:
exact_merge_map[k] = {
"canonical_id": canonical.id,
"group_id": canonical.group_id,
"end_user_id": canonical.end_user_id,
"name": canonical.name,
"entity_type": canonical.entity_type,
"merged_ids": set(),
@@ -596,7 +596,7 @@ def fuzzy_match(
b = deduped_entities[j]
# 跳过不同业务组的实体
if getattr(a, "group_id", None) != getattr(b, "group_id", None):
if getattr(a, "end_user_id", None) != getattr(b, "end_user_id", None):
j += 1
continue
@@ -671,7 +671,7 @@ def fuzzy_match(
merge_reason = "[别名匹配]" if alias_match_merge else "[模糊]"
merge_reason = "[别名匹配]" if alias_match_merge else "[模糊]"
fuzzy_merge_records.append(
f"{merge_reason} 规范实体 {a.id} ({a.group_id}|{a.name}|{a.entity_type}) <- 合并实体 {b.id} ({b.group_id}|{b.name}|{b.entity_type}) | "
f"{merge_reason} 规范实体 {a.id} ({a.end_user_id}|{a.name}|{a.entity_type}) <- 合并实体 {b.id} ({b.end_user_id}|{b.name}|{b.entity_type}) | "
f"s_name={s_name:.3f}, s_type={s_type:.3f}, overall={overall:.3f}, exact_alias={has_exact_match}"
)
except Exception:
@@ -779,7 +779,7 @@ async def LLM_decision( # 决策中包含去重和消歧的功能
# 记录 LLM 融合日志
try:
llm_records.append(
f"[LLM融合] 规范实体 {a.id} ({a.group_id}|{a.name}|{a.entity_type}) <- 合并实体 {b.id} ({b.group_id}|{b.name}|{b.entity_type})"
f"[LLM融合] 规范实体 {a.id} ({a.end_user_id}|{a.name}|{a.entity_type}) <- 合并实体 {b.id} ({b.end_user_id}|{b.name}|{b.entity_type})"
)
# 详细的“同类名称相似”记录改由 LLM 去重模块统一生成以携带 conf/reason
except Exception:
@@ -847,7 +847,7 @@ async def LLM_disamb_decision(
id_redirect[k] = a.id
try:
disamb_records.append(
f"[DISAMB合并应用] 规范实体 {a.id} ({a.group_id}|{a.name}|{a.entity_type}) <- 合并实体 {b.id} ({b.group_id}|{b.name}|{b.entity_type})"
f"[DISAMB合并应用] 规范实体 {a.id} ({a.end_user_id}|{a.name}|{a.entity_type}) <- 合并实体 {b.id} ({b.end_user_id}|{b.name}|{b.entity_type})"
)
except Exception:
pass

View File

@@ -174,7 +174,7 @@ async def _judge_pair(
pass
# 3. 构建LLM判断的“上下文信息”规则层计算的所有特征 判断上下文特征有助于实体消歧首先判断的类型关系
ctx = {
"same_group": getattr(a, "group_id", None) == getattr(b, "group_id", None),
"same_group": getattr(a, "end_user_id", None) == getattr(b, "end_user_id", None),
"type_ok": _simple_type_ok(getattr(a, "entity_type", None), getattr(b, "entity_type", None)),
"type_similarity": _type_similarity(getattr(a, "entity_type", None), getattr(b, "entity_type", None)),
"name_text_sim": name_text_sim,
@@ -235,7 +235,7 @@ async def _judge_pair_disamb(
except Exception:
pass
ctx = {
"same_group": getattr(a, "group_id", None) == getattr(b, "group_id", None),
"same_group": getattr(a, "end_user_id", None) == getattr(b, "end_user_id", None),
"type_ok": _simple_type_ok(getattr(a, "entity_type", None), getattr(b, "entity_type", None)),
"name_text_sim": name_text_sim,
"name_embed_sim": name_embed_sim,
@@ -317,8 +317,8 @@ async def llm_dedup_entities( # 保留对偶判断作为子流程,是为了
a = entity_nodes[i]
for j in range(i + 1, len(entity_nodes)):
b = entity_nodes[j]
# 规则1必须属于同一组group_id相同不同组的实体不重复
if getattr(a, "group_id", None) != getattr(b, "group_id", None):
# 规则1必须属于同一组end_user_id相同不同组的实体不重复
if getattr(a, "end_user_id", None) != getattr(b, "end_user_id", None):
continue
# 规则2类型必须兼容调用_simple_type_ok判断
if not _simple_type_ok(getattr(a, "entity_type", None), getattr(b, "entity_type", None)):
@@ -474,7 +474,7 @@ async def llm_dedup_entities_iterative_blocks( # 迭代分块并发 LLM 去重
- max_rounds: upper bound for iterative passes (default 3)
- auto_merge_threshold: decision confidence for auto-merge when no co-occurrence (default 0.90)
- co_ctx_threshold: lower threshold when co-occurrence is detected (default 0.83)
- shuffle_each_round: whether to shuffle entities within group_id each round to vary block composition
- shuffle_each_round: whether to shuffle entities within end_user_id each round to vary block composition
Returns:
- global_redirect: dict losing_id -> canonical_id accumulated across rounds
@@ -509,7 +509,7 @@ async def llm_dedup_entities_iterative_blocks( # 迭代分块并发 LLM 去重
def _partition_blocks(nodes: List[ExtractedEntityNode]) -> List[List[ExtractedEntityNode]]:
"""
group_id 分块,避免跨组实体在同一块,减少无效候选对
end_user_id 分块,避免跨组实体在同一块,减少无效候选对
Args:
nodes: 实体节点列表
@@ -519,7 +519,7 @@ async def llm_dedup_entities_iterative_blocks( # 迭代分块并发 LLM 去重
"""
groups: Dict[str, List[ExtractedEntityNode]] = {}
for e in nodes:
gid = getattr(e, "group_id", None)
gid = getattr(e, "end_user_id", None)
groups.setdefault(str(gid), []).append(e)
blocks: List[List[ExtractedEntityNode]] = []
for gid, arr in groups.items():
@@ -559,7 +559,7 @@ async def llm_dedup_entities_iterative_blocks( # 迭代分块并发 LLM 去重
# Collapse nodes to canonical reps before each round to avoid redundant comparisons
# 步骤1折叠实体合并已确定的重复实体减少后续计算量
current_nodes = _collapse_nodes(current_nodes)
# 步骤2分块group_id分块避免跨组处理
# 步骤2分块end_user_id分块避免跨组处理
blocks = _partition_blocks(current_nodes)
if not blocks: # 无块可处理(实体已全部折叠),退出循环
break
@@ -645,7 +645,7 @@ async def llm_disambiguate_pairs_iterative(
a = entity_nodes[i]
b = entity_nodes[j]
# 必须同组
if getattr(a, "group_id", None) != getattr(b, "group_id", None):
if getattr(a, "end_user_id", None) != getattr(b, "end_user_id", None):
continue
ta = getattr(a, "entity_type", None)
tb = getattr(b, "entity_type", None)

View File

@@ -61,7 +61,7 @@ def _row_to_entity(row: Dict[str, Any]) -> ExtractedEntityNode:
return ExtractedEntityNode(
id=row.get("id"),
name=row.get("name") or "",
group_id=row.get("group_id") or "",
end_user_id=row.get("end_user_id") or "",
user_id=row.get("user_id") or "",
apply_id=row.get("apply_id") or "",
created_at=_parse_dt(row.get("created_at")),
@@ -79,7 +79,7 @@ def _row_to_entity(row: Dict[str, Any]) -> ExtractedEntityNode:
async def second_layer_dedup_and_merge_with_neo4j( # 二层去重的核心逻辑,与 Neo4j 中同组实体联合去重
connector: Neo4jConnector,
group_id: str, # 用于定位neo4j中同一组的实体确保只在同组内去重
end_user_id: str, # 用于定位neo4j中同一组的实体确保只在同组内去重
entity_nodes: List[ExtractedEntityNode], # 输入的实体节点列表,包含待去重的实体
statement_entity_edges: List[StatementEntityEdge], # 输入的语句实体边列表,用于处理实体之间的关系
entity_entity_edges: List[EntityEntityEdge], # 输入的实体实体边列表,用于处理实体之间的关系
@@ -88,7 +88,7 @@ async def second_layer_dedup_and_merge_with_neo4j( # 二层去重的核心逻辑
) -> Tuple[List[ExtractedEntityNode], List[StatementEntityEdge], List[EntityEntityEdge]]:
"""
第二层去重消歧:
- 以第一层结果为索引,检索相同 group_id 下的 DB 候选实体
- 以第一层结果为索引,检索相同 end_user_id 下的 DB 候选实体
- 将 DB 候选与当前实体集合联合,按既有精确/模糊/LLM 决策进行融合
- 返回融合后的实体与重定向后的边(边已指向规范 ID优先 DB ID
"""
@@ -102,7 +102,7 @@ async def second_layer_dedup_and_merge_with_neo4j( # 二层去重的核心逻辑
]
candidates_map = await get_dedup_candidates_for_entities( # 从 Neo4j 中查询候选实体并将结果赋值给candidates_map等待异步操作完成
connector=connector, group_id=group_id,
connector=connector, end_user_id=end_user_id,
entities=incoming_rows, # 传入参数:第一层实体的核心信息(作为查询索引)
use_contains_fallback=True # 传入参数:启用 “包含关系” 作为匹配失败的降级策略若精确匹配无结果用包含关系召回候选与src\database\cypher_queries.py的307产生联动
)

View File

@@ -57,11 +57,11 @@ async def dedup_layers_and_merge_and_return(
if pipeline_config is None:
raise ValueError("pipeline_config is required for dedup_layers_and_merge_and_return")
# 先探测 group_id决定报告写入策略
group_id: Optional[str] = None
# 先探测 end_user_id决定报告写入策略
end_user_id: Optional[str] = None
for dd in dialog_data_list:
group_id = getattr(dd, "group_id", None)
if group_id:
end_user_id = getattr(dd, "end_user_id", None)
if end_user_id:
break
# 第一层去重消歧
@@ -82,11 +82,11 @@ async def dedup_layers_and_merge_and_return(
# 第二层去重消歧:与 Neo4j 中同组实体联合融合
try:
if group_id:
if end_user_id:
if connector:
fused_entity_nodes, fused_statement_entity_edges, fused_entity_entity_edges = await second_layer_dedup_and_merge_with_neo4j(
connector=connector,
group_id=group_id,
end_user_id=end_user_id,
entity_nodes=dedup_entity_nodes,
statement_entity_edges=dedup_statement_entity_edges,
entity_entity_edges=dedup_entity_entity_edges,
@@ -96,7 +96,7 @@ async def dedup_layers_and_merge_and_return(
else:
print("Skip second-layer dedup: missing connector")
else:
print("Skip second-layer dedup: missing group_id")
print("Skip second-layer dedup: missing end_user_id")
except Exception as e:
print(f"Second-layer dedup failed: {e}")

View File

@@ -287,7 +287,7 @@ class ExtractionOrchestrator:
for d_idx, dialog in enumerate(dialog_data_list):
dialogue_content = dialog.content if self.config.statement_extraction.include_dialogue_context else None
for c_idx, chunk in enumerate(dialog.chunks):
all_chunks.append((chunk, dialog.group_id, dialogue_content))
all_chunks.append((chunk, dialog.end_user_id, dialogue_content))
chunk_metadata.append((d_idx, c_idx))
logger.info(f"收集到 {len(all_chunks)} 个分块,开始全局并行提取")
@@ -299,9 +299,9 @@ class ExtractionOrchestrator:
# 全局并行处理所有分块
async def extract_for_chunk(chunk_data, chunk_index):
nonlocal completed_chunks
chunk, group_id, dialogue_content = chunk_data
chunk, end_user_id, dialogue_content = chunk_data
try:
statements = await self.statement_extractor._extract_statements(chunk, group_id, dialogue_content)
statements = await self.statement_extractor._extract_statements(chunk, end_user_id, dialogue_content)
# 流式输出:每提取完一个分块的陈述句,立即发送进度
# 注意:只在试运行模式下发送陈述句详情,正式模式不发送
@@ -569,32 +569,32 @@ class ExtractionOrchestrator:
if dialog_data_list and hasattr(dialog_data_list[0], 'config_id'):
config_id = dialog_data_list[0].config_id
# 加载DataConfig
data_config = None
# 加载MemoryConfig
memory_config = None
if config_id:
try:
from app.db import SessionLocal
from app.repositories.data_config_repository import DataConfigRepository
from app.repositories.memory_config_repository import MemoryConfigRepository
db = SessionLocal()
try:
data_config = DataConfigRepository.get_by_id(db, config_id)
memory_config = MemoryConfigRepository.get_by_id(db, config_id)
finally:
db.close()
if data_config and not data_config.emotion_enabled:
if memory_config and not memory_config.emotion_enabled:
logger.info("情绪提取已在配置中禁用,跳过情绪提取")
return [{} for _ in dialog_data_list]
except Exception as e:
logger.warning(f"加载DataConfig失败: {e},将跳过情绪提取")
logger.warning(f"加载MemoryConfig失败: {e},将跳过情绪提取")
return [{} for _ in dialog_data_list]
else:
logger.info("未找到config_id跳过情绪提取")
return [{} for _ in dialog_data_list]
# 如果配置未启用情绪提取,直接返回空映射
if not data_config or not data_config.emotion_enabled:
if not memory_config or not memory_config.emotion_enabled:
logger.info("情绪提取未启用,跳过")
return [{} for _ in dialog_data_list]
@@ -608,7 +608,7 @@ class ExtractionOrchestrator:
total_statements += 1
# 只处理用户的陈述句 (role 为 "user")
if hasattr(statement, 'speaker') and statement.speaker == "user":
all_statements.append((statement, data_config))
all_statements.append((statement, memory_config))
statement_metadata.append((d_idx, statement.id))
filtered_statements += 1
@@ -617,7 +617,7 @@ class ExtractionOrchestrator:
# 初始化情绪提取服务
from app.services.emotion_extraction_service import EmotionExtractionService
emotion_service = EmotionExtractionService(
llm_id=data_config.emotion_model_id if data_config.emotion_model_id else None
llm_id=memory_config.emotion_model_id if memory_config.emotion_model_id else None
)
# 全局并行处理所有陈述句
@@ -992,9 +992,7 @@ class ExtractionOrchestrator:
id=dialog_data.id,
name=f"Dialog_{dialog_data.id}", # 添加必需的 name 字段
ref_id=dialog_data.ref_id,
group_id=dialog_data.group_id,
user_id=dialog_data.user_id,
apply_id=dialog_data.apply_id,
end_user_id=dialog_data.end_user_id,
run_id=dialog_data.run_id, # 使用 dialog_data 的 run_id
content=dialog_data.context.content if dialog_data.context else "",
dialog_embedding=dialog_data.dialog_embedding if hasattr(dialog_data, 'dialog_embedding') else None,
@@ -1012,9 +1010,7 @@ class ExtractionOrchestrator:
id=chunk.id,
name=f"Chunk_{chunk.id}", # 添加必需的 name 字段
dialog_id=dialog_data.id,
group_id=dialog_data.group_id,
user_id=dialog_data.user_id,
apply_id=dialog_data.apply_id,
end_user_id=dialog_data.end_user_id,
run_id=dialog_data.run_id, # 使用 dialog_data 的 run_id
content=chunk.content,
chunk_embedding=chunk.chunk_embedding,
@@ -1035,9 +1031,7 @@ class ExtractionOrchestrator:
stmt_type=getattr(statement, 'stmt_type', 'general'), # 添加必需的 stmt_type 字段
temporal_info=getattr(statement, 'temporal_info', TemporalInfo.ATEMPORAL), # 添加必需的 temporal_info 字段
connect_strength=statement.connect_strength if statement.connect_strength is not None else 'Strong', # 添加必需的 connect_strength 字段
group_id=dialog_data.group_id,
user_id=dialog_data.user_id,
apply_id=dialog_data.apply_id,
end_user_id=dialog_data.end_user_id,
run_id=dialog_data.run_id, # 使用 dialog_data 的 run_id
statement=statement.statement,
speaker=getattr(statement, 'speaker', None), # 添加 speaker 字段
@@ -1060,9 +1054,7 @@ class ExtractionOrchestrator:
statement_chunk_edge = StatementChunkEdge(
source=statement.id,
target=chunk.id,
group_id=dialog_data.group_id,
user_id=dialog_data.user_id,
apply_id=dialog_data.apply_id,
end_user_id=dialog_data.end_user_id,
run_id=dialog_data.run_id, # 使用 dialog_data 的 run_id
created_at=dialog_data.created_at,
)
@@ -1095,9 +1087,7 @@ class ExtractionOrchestrator:
aliases=getattr(entity, 'aliases', []) or [], # 传递从三元组提取阶段获取的aliases
name_embedding=getattr(entity, 'name_embedding', None),
is_explicit_memory=getattr(entity, 'is_explicit_memory', False), # 新增:传递语义记忆标记
group_id=dialog_data.group_id,
user_id=dialog_data.user_id,
apply_id=dialog_data.apply_id,
end_user_id=dialog_data.end_user_id,
run_id=dialog_data.run_id, # 使用 dialog_data 的 run_id
created_at=dialog_data.created_at,
expired_at=dialog_data.expired_at,
@@ -1112,9 +1102,7 @@ class ExtractionOrchestrator:
source=statement.id,
target=entity.id,
connect_strength=entity_connect_strength if entity_connect_strength is not None else 'Strong',
group_id=dialog_data.group_id,
user_id=dialog_data.user_id,
apply_id=dialog_data.apply_id,
end_user_id=dialog_data.end_user_id,
run_id=dialog_data.run_id, # 使用 dialog_data 的 run_id
created_at=dialog_data.created_at,
)
@@ -1134,9 +1122,7 @@ class ExtractionOrchestrator:
relation_type=triplet.predicate,
statement=statement.statement,
source_statement_id=statement.id,
group_id=dialog_data.group_id,
user_id=dialog_data.user_id,
apply_id=dialog_data.apply_id,
end_user_id=dialog_data.end_user_id,
run_id=dialog_data.run_id, # 使用 dialog_data 的 run_id
created_at=dialog_data.created_at,
expired_at=dialog_data.expired_at,
@@ -1763,14 +1749,14 @@ class ExtractionOrchestrator:
async def get_chunked_dialogs(
chunker_strategy: str = "RecursiveChunker",
group_id: str = "group_1",
end_user_id: str = "group_1",
indices: Optional[List[int]] = None,
) -> List[DialogData]:
"""从测试数据生成分块对话
Args:
chunker_strategy: 分块策略(默认: RecursiveChunker
group_id: 组ID
end_user_id: 组ID
indices: 要处理的数据索引列表(可选)
Returns:
@@ -1834,7 +1820,7 @@ async def get_chunked_dialogs(
dialog_data = DialogData(
context=conversation_context,
ref_id=data['id'],
group_id=group_id,
end_user_id=end_user_id,
metadata=dialog_metadata,
)
@@ -1936,7 +1922,7 @@ async def get_chunked_dialogs_from_preprocessed(
async def get_chunked_dialogs_with_preprocessing(
chunker_strategy: str = "RecursiveChunker",
group_id: str = "default",
end_user_id: str = "default",
user_id: str = "default",
apply_id: str = "default",
indices: Optional[List[int]] = None,
@@ -1948,7 +1934,7 @@ async def get_chunked_dialogs_with_preprocessing(
Args:
chunker_strategy: 分块策略
group_id: 组ID
end_user_id: 组ID
user_id: 用户ID
apply_id: 应用ID
indices: 要处理的数据索引列表
@@ -1976,11 +1962,9 @@ async def get_chunked_dialogs_with_preprocessing(
indices=indices,
)
# 设置 group_id, user_id, apply_id
# 设置 end_user_id
for dd in preprocessed_data:
dd.group_id = group_id
dd.user_id = user_id
dd.apply_id = apply_id
dd.end_user_id = end_user_id
# 步骤2: 语义剪枝
try:

View File

@@ -193,9 +193,9 @@ async def _process_chunk_summary(
node = MemorySummaryNode(
id=uuid4().hex,
name=title if title else f"MemorySummaryChunk_{chunk.id}",
group_id=dialog.group_id,
user_id=dialog.user_id,
apply_id=dialog.apply_id,
end_user_id=dialog.end_user_id,
user_id=dialog.end_user_id,
apply_id=dialog.end_user_id,
run_id=dialog.run_id, # 使用 dialog 的 run_id
created_at=datetime.now(),
expired_at=datetime(9999, 12, 31),

View File

@@ -82,12 +82,12 @@ class StatementExtractor:
logger.warning(f"Chunk {getattr(chunk, 'id', 'unknown')} has no speaker field or is empty")
return None
async def _extract_statements(self, chunk, group_id: Optional[str] = None, dialogue_content: str = None) -> List[Statement]:
async def _extract_statements(self, chunk, end_user_id: Optional[str] = None, dialogue_content: str = None) -> List[Statement]:
"""Process a single chunk and return extracted statements
Args:
chunk: Chunk object to process
group_id: Group ID to assign to all statements in this chunk
end_user_id: Group ID to assign to all statements in this chunk
dialogue_content: Full dialogue content to provide as context
Returns:
@@ -158,7 +158,7 @@ class StatementExtractor:
temporal_info=temporal_type,
relevence_info=relevence_info,
chunk_id=chunk.id,
group_id=group_id,
end_user_id=end_user_id,
speaker=chunk_speaker,
)
@@ -184,10 +184,10 @@ class StatementExtractor:
logger.info(f"Processing {len(chunks_to_process)} chunks for statement extraction")
# Process all chunks concurrently, passing the group_id and dialogue content from dialog_data
# Process all chunks concurrently, passing the end_user_id and dialogue content from dialog_data
dialogue_content = dialog_data.content if self.config.include_dialogue_context else None
results = await asyncio.gather(
*[self._extract_statements(chunk, dialog_data.group_id, dialogue_content) for chunk in chunks_to_process],
*[self._extract_statements(chunk, dialog_data.end_user_id, dialogue_content) for chunk in chunks_to_process],
return_exceptions=True
)
@@ -225,7 +225,7 @@ class StatementExtractor:
for i, statement in enumerate(statements, 1):
f.write(f"Statement {i}:\n")
f.write(f"Id: {statement.id}\n")
f.write(f"Group Id: {statement.group_id}\n")
f.write(f"Group Id: {statement.end_user_id}\n")
f.write(f"Content: {statement.statement}\n")
f.write(f"Type: {statement.stmt_type.value}\n")
f.write(f"Temporal Info: {statement.temporal_info.value}\n")
@@ -298,7 +298,7 @@ class StatementExtractor:
dialog_sections.append({
"dialog_id": dialog.ref_id,
"group_id": dialog.group_id,
"end_user_id": dialog.end_user_id,
"content": dialog.content if getattr(dialog, "content", None) else "",
"strong": strong_relations,
"weak": weak_relations,
@@ -312,7 +312,7 @@ class StatementExtractor:
for idx, section in enumerate(dialog_sections, 1):
f.write(f"Dialog {idx}:\n")
f.write(f"Dialog ID: {section.get('dialog_id', '')}\n")
f.write(f"Group ID: {section.get('group_id', '')}\n")
f.write(f"Group ID: {section.get('end_user_id', '')}\n")
f.write("Content:\n")
f.write(f"{section.get('content', '')}\n")
f.write("-" * 40 + "\n\n")

View File

@@ -132,7 +132,7 @@ class TemporalExtractor:
prompt_logger.info("")
prompt_logger.info("=== TEMPORAL EXTRACTION RESULTS ===")
prompt_logger.info(
f"[Temporal] Dialog ref_id={getattr(dialog_data, 'ref_id', None)}, group_id={getattr(dialog_data, 'group_id', None)}"
f"[Temporal] Dialog ref_id={getattr(dialog_data, 'ref_id', None)}, end_user_id={getattr(dialog_data, 'end_user_id', None)}"
)
except Exception:
pass

View File

@@ -116,7 +116,7 @@ class TripletExtractor:
logger.info(f"Processing {len(all_statements)} statements for triplet extraction...")
try:
prompt_logger.info(
f"[Triplet] Dialog ref_id={getattr(dialog_data, 'ref_id', None)}, group_id={getattr(dialog_data, 'group_id', None)}, statements_to_process={len(all_statements)}"
f"[Triplet] Dialog ref_id={getattr(dialog_data, 'ref_id', None)}, end_user_id={getattr(dialog_data, 'end_user_id', None)}, statements_to_process={len(all_statements)}"
)
except Exception:
pass

View File

@@ -75,7 +75,7 @@ class AccessHistoryManager:
self,
node_id: str,
node_label: str,
group_id: Optional[str] = None,
end_user_id: Optional[str] = None,
current_time: Optional[datetime] = None
) -> Dict[str, Any]:
"""
@@ -91,7 +91,7 @@ class AccessHistoryManager:
Args:
node_id: 节点ID
node_label: 节点标签Statement, ExtractedEntity, MemorySummary
group_id: 组ID可选用于过滤
end_user_id: 组ID可选用于过滤
current_time: 当前时间(可选,默认使用系统时间)
Returns:
@@ -123,7 +123,7 @@ class AccessHistoryManager:
for attempt in range(self.max_retries):
try:
# 步骤1读取当前节点状态
node_data = await self._fetch_node(node_id, node_label, group_id)
node_data = await self._fetch_node(node_id, node_label, end_user_id)
if not node_data:
raise ValueError(
@@ -142,7 +142,7 @@ class AccessHistoryManager:
node_id=node_id,
node_label=node_label,
update_data=update_data,
group_id=group_id
end_user_id=end_user_id
)
logger.info(
@@ -172,7 +172,7 @@ class AccessHistoryManager:
self,
node_ids: List[str],
node_label: str,
group_id: Optional[str] = None,
end_user_id: Optional[str] = None,
current_time: Optional[datetime] = None
) -> List[Dict[str, Any]]:
"""
@@ -184,7 +184,7 @@ class AccessHistoryManager:
Args:
node_ids: 节点ID列表
node_label: 节点标签(所有节点必须是同一类型)
group_id: 组ID可选
end_user_id: 组ID可选
current_time: 当前时间(可选)
Returns:
@@ -202,7 +202,7 @@ class AccessHistoryManager:
task = self.record_access(
node_id=node_id,
node_label=node_label,
group_id=group_id,
end_user_id=end_user_id,
current_time=current_time
)
tasks.append(task)
@@ -235,7 +235,7 @@ class AccessHistoryManager:
self,
node_id: str,
node_label: str,
group_id: Optional[str] = None
end_user_id: Optional[str] = None
) -> Tuple[ConsistencyCheckResult, Optional[str]]:
"""
检查节点数据的一致性
@@ -249,14 +249,14 @@ class AccessHistoryManager:
Args:
node_id: 节点ID
node_label: 节点标签
group_id: 组ID可选
end_user_id: 组ID可选
Returns:
Tuple[ConsistencyCheckResult, Optional[str]]:
- 一致性检查结果枚举
- 错误描述(如果不一致)
"""
node_data = await self._fetch_node(node_id, node_label, group_id)
node_data = await self._fetch_node(node_id, node_label, end_user_id)
if not node_data:
return ConsistencyCheckResult.CONSISTENT, None
@@ -305,7 +305,7 @@ class AccessHistoryManager:
async def check_batch_consistency(
self,
node_label: str,
group_id: Optional[str] = None,
end_user_id: Optional[str] = None,
limit: int = 1000
) -> Dict[str, Any]:
"""
@@ -313,7 +313,7 @@ class AccessHistoryManager:
Args:
node_label: 节点标签
group_id: 组ID可选
end_user_id: 组ID可选
limit: 检查的最大节点数
Returns:
@@ -329,16 +329,16 @@ class AccessHistoryManager:
MATCH (n:{node_label})
WHERE n.access_history IS NOT NULL
"""
if group_id:
query += " AND n.group_id = $group_id"
if end_user_id:
query += " AND n.end_user_id = $end_user_id"
query += """
RETURN n.id as id
LIMIT $limit
"""
params = {"limit": limit}
if group_id:
params["group_id"] = group_id
if end_user_id:
params["end_user_id"] = end_user_id
results = await self.connector.execute_query(query, **params)
node_ids = [r['id'] for r in results]
@@ -351,7 +351,7 @@ class AccessHistoryManager:
result, message = await self.check_consistency(
node_id=node_id,
node_label=node_label,
group_id=group_id
end_user_id=end_user_id
)
if result == ConsistencyCheckResult.CONSISTENT:
@@ -387,7 +387,7 @@ class AccessHistoryManager:
self,
node_id: str,
node_label: str,
group_id: Optional[str] = None
end_user_id: Optional[str] = None
) -> bool:
"""
自动修复节点的数据不一致问题
@@ -401,7 +401,7 @@ class AccessHistoryManager:
Args:
node_id: 节点ID
node_label: 节点标签
group_id: 组ID可选
end_user_id: 组ID可选
Returns:
bool: 修复成功返回True否则返回False
@@ -411,7 +411,7 @@ class AccessHistoryManager:
result, message = await self.check_consistency(
node_id=node_id,
node_label=node_label,
group_id=group_id
end_user_id=end_user_id
)
if result == ConsistencyCheckResult.CONSISTENT:
@@ -419,7 +419,7 @@ class AccessHistoryManager:
return True
# 获取节点数据
node_data = await self._fetch_node(node_id, node_label, group_id)
node_data = await self._fetch_node(node_id, node_label, end_user_id)
if not node_data:
logger.error(f"节点不存在,无法修复: {node_label}[{node_id}]")
return False
@@ -457,8 +457,8 @@ class AccessHistoryManager:
query = f"""
MATCH (n:{node_label} {{id: $node_id}})
"""
if group_id:
query += " WHERE n.group_id = $group_id"
if end_user_id:
query += " WHERE n.end_user_id = $end_user_id"
query += """
SET n += $repair_data
RETURN n
@@ -468,8 +468,8 @@ class AccessHistoryManager:
'node_id': node_id,
'repair_data': repair_data
}
if group_id:
params['group_id'] = group_id
if end_user_id:
params['end_user_id'] = end_user_id
await self.connector.execute_query(query, **params)
@@ -491,7 +491,7 @@ class AccessHistoryManager:
self,
node_id: str,
node_label: str,
group_id: Optional[str] = None
end_user_id: Optional[str] = None
) -> Optional[Dict[str, Any]]:
"""
获取节点数据
@@ -499,7 +499,7 @@ class AccessHistoryManager:
Args:
node_id: 节点ID
node_label: 节点标签
group_id: 组ID可选
end_user_id: 组ID可选
Returns:
Optional[Dict[str, Any]]: 节点数据如果不存在返回None
@@ -507,8 +507,8 @@ class AccessHistoryManager:
query = f"""
MATCH (n:{node_label} {{id: $node_id}})
"""
if group_id:
query += " WHERE n.group_id = $group_id"
if end_user_id:
query += " WHERE n.end_user_id = $end_user_id"
query += """
RETURN n.id as id,
n.importance_score as importance_score,
@@ -519,8 +519,8 @@ class AccessHistoryManager:
"""
params = {'node_id': node_id}
if group_id:
params['group_id'] = group_id
if end_user_id:
params['end_user_id'] = end_user_id
results = await self.connector.execute_query(query, **params)
@@ -585,7 +585,7 @@ class AccessHistoryManager:
node_id: str,
node_label: str,
update_data: Dict[str, Any],
group_id: Optional[str] = None
end_user_id: Optional[str] = None
) -> Dict[str, Any]:
"""
原子性更新节点(使用乐观锁)
@@ -597,7 +597,7 @@ class AccessHistoryManager:
node_id: 节点ID
node_label: 节点标签
update_data: 更新数据
group_id: 组ID可选
end_user_id: 组ID可选
Returns:
Dict[str, Any]: 更新后的节点数据
@@ -606,13 +606,13 @@ class AccessHistoryManager:
RuntimeError: 如果更新失败或发生版本冲突
"""
# 定义事务函数
async def update_transaction(tx, node_id, node_label, update_data, group_id):
async def update_transaction(tx, node_id, node_label, update_data, end_user_id):
# 步骤1读取当前节点并获取版本号
read_query = f"""
MATCH (n:{node_label} {{id: $node_id}})
"""
if group_id:
read_query += " WHERE n.group_id = $group_id"
if end_user_id:
read_query += " WHERE n.end_user_id = $end_user_id"
read_query += """
RETURN n.id as id,
n.version as version,
@@ -624,8 +624,8 @@ class AccessHistoryManager:
"""
read_params = {'node_id': node_id}
if group_id:
read_params['group_id'] = group_id
if end_user_id:
read_params['end_user_id'] = end_user_id
read_result = await tx.run(read_query, **read_params)
current_node = await read_result.single()
@@ -656,8 +656,8 @@ class AccessHistoryManager:
# 构建 WHERE 子句
where_conditions = []
if group_id:
where_conditions.append("n.group_id = $group_id")
if end_user_id:
where_conditions.append("n.end_user_id = $end_user_id")
# 添加版本检查
if current_version > 0:
@@ -695,8 +695,8 @@ class AccessHistoryManager:
'last_access_time': update_data['last_access_time'],
'access_count': update_data['access_count']
}
if group_id:
update_params['group_id'] = group_id
if end_user_id:
update_params['end_user_id'] = end_user_id
update_result = await tx.run(update_query, **update_params)
updated_node = await update_result.single()
@@ -720,7 +720,7 @@ class AccessHistoryManager:
node_id=node_id,
node_label=node_label,
update_data=update_data,
group_id=group_id
end_user_id=end_user_id
)
return result
except Exception as e:

View File

@@ -11,9 +11,10 @@ Functions:
import logging
from typing import Optional, Dict, Any
from uuid import UUID
from sqlalchemy.orm import Session
from app.repositories.data_config_repository import DataConfigRepository
from app.repositories.memory_config_repository import MemoryConfigRepository
from app.core.memory.storage_services.forgetting_engine.actr_calculator import ACTRCalculator
@@ -61,12 +62,12 @@ def calculate_forgetting_rate(lambda_time: float, lambda_mem: float) -> float:
def load_actr_config_from_db(
db: Session,
config_id: Optional[int] = None
config_id: Optional[UUID] = None
) -> Dict[str, Any]:
"""
从数据库加载 ACT-R 配置参数
从 PostgreSQL 的 data_config 表读取配置参数,
从 PostgreSQL 的 memory_config 表读取配置参数,
并计算派生参数(如 forgetting_rate
Args:
@@ -99,7 +100,7 @@ def load_actr_config_from_db(
# 从数据库加载配置
try:
repository = DataConfigRepository()
repository = MemoryConfigRepository()
db_config = repository.get_by_id(db, config_id)
if db_config is None:
@@ -150,7 +151,7 @@ def load_actr_config_from_db(
def create_actr_calculator_from_config(
db: Session,
config_id: Optional[int] = None
config_id: Optional[UUID] = None
) -> ACTRCalculator:
"""
从数据库配置创建 ACTRCalculator 实例
@@ -168,11 +169,6 @@ def create_actr_calculator_from_config(
ValueError: 如果指定的 config_id 不存在
Examples:
>>> from sqlalchemy.orm import Session
>>> db = Session()
>>> calculator = create_actr_calculator_from_config(db, config_id=1)
>>> # 使用计算器
>>> activation = calculator.calculate_memory_activation(...)
"""
# 加载配置
config = load_actr_config_from_db(db, config_id)

View File

@@ -16,6 +16,7 @@ Classes:
import logging
from typing import Dict, Any, Optional
from uuid import UUID
from datetime import datetime
from app.core.memory.storage_services.forgetting_engine.forgetting_strategy import ForgettingStrategy
@@ -66,10 +67,10 @@ class ForgettingScheduler:
async def run_forgetting_cycle(
self,
group_id: Optional[str] = None,
end_user_id: Optional[str] = None,
max_merge_batch_size: int = 100,
min_days_since_access: int = 30,
config_id: Optional[int] = None,
config_id: Optional[UUID] = None,
db = None
) -> Dict[str, Any]:
"""
@@ -77,7 +78,7 @@ class ForgettingScheduler:
Args:
group_id: 组 ID可选用于过滤特定组的节点
end_user_id: 组 ID可选用于过滤特定组的节点
max_merge_batch_size: 单次最大融合节点对数(默认 100
min_days_since_access: 最小未访问天数(默认 30 天)
config_id: 配置ID可选用于获取 llm_id
@@ -107,19 +108,19 @@ class ForgettingScheduler:
start_time_iso = start_time.isoformat()
logger.info(
f"开始遗忘周期: group_id={group_id}, "
f"开始遗忘周期: end_user_id={end_user_id}, "
f"max_batch={max_merge_batch_size}, "
f"min_days={min_days_since_access}"
)
try:
# 步骤1统计遗忘前的节点数量
nodes_before = await self._count_knowledge_nodes(group_id)
nodes_before = await self._count_knowledge_nodes(end_user_id)
logger.info(f"遗忘前节点总数: {nodes_before}")
# 步骤2识别可遗忘的节点对
forgettable_pairs = await self.forgetting_strategy.find_forgettable_nodes(
group_id=group_id,
end_user_id=end_user_id,
min_days_since_access=min_days_since_access
)
@@ -213,7 +214,7 @@ class ForgettingScheduler:
'statement_text': pair['statement_text'],
'statement_activation': pair['statement_activation'],
'statement_importance': pair['statement_importance'],
'group_id': group_id
'end_user_id': end_user_id
}
entity_node = {
@@ -222,7 +223,7 @@ class ForgettingScheduler:
'entity_type': pair['entity_type'],
'entity_activation': pair['entity_activation'],
'entity_importance': pair['entity_importance'],
'group_id': group_id
'end_user_id': end_user_id
}
# 融合节点
@@ -262,7 +263,7 @@ class ForgettingScheduler:
continue
# 步骤6统计遗忘后的节点数量
nodes_after = await self._count_knowledge_nodes(group_id)
nodes_after = await self._count_knowledge_nodes(end_user_id)
logger.info(f"遗忘后节点总数: {nodes_after}")
# 步骤7生成遗忘报告
@@ -315,7 +316,7 @@ class ForgettingScheduler:
async def _count_knowledge_nodes(
self,
group_id: Optional[str] = None
end_user_id: Optional[str] = None
) -> int:
"""
统计知识层节点总数
@@ -323,7 +324,7 @@ class ForgettingScheduler:
统计 Statement、ExtractedEntity 和 MemorySummary 节点的总数。
Args:
group_id: 组 ID可选用于过滤特定组的节点
end_user_id: 组 ID可选用于过滤特定组的节点
Returns:
int: 知识层节点总数
@@ -333,16 +334,16 @@ class ForgettingScheduler:
WHERE (n:Statement OR n:ExtractedEntity OR n:MemorySummary)
"""
if group_id:
query += " AND n.group_id = $group_id"
if end_user_id:
query += " AND n.end_user_id = $end_user_id"
query += """
RETURN count(n) as total
"""
params = {}
if group_id:
params['group_id'] = group_id
if end_user_id:
end_user_id['end_user_id'] = end_user_id
results = await self.connector.execute_query(query, **params)

View File

@@ -13,6 +13,7 @@ Classes:
import logging
from typing import List, Dict, Any, Optional
from uuid import UUID
from datetime import datetime, timedelta
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
@@ -90,7 +91,7 @@ class ForgettingStrategy:
async def find_forgettable_nodes(
self,
group_id: Optional[str] = None,
end_user_id: Optional[str] = None,
min_days_since_access: int = 30
) -> List[Dict[str, Any]]:
"""
@@ -102,7 +103,7 @@ class ForgettingStrategy:
3. Statement 和 Entity 之间存在关系边
Args:
group_id: 组 ID可选用于过滤特定组的节点
end_user_id: 组 ID可选用于过滤特定组的节点
min_days_since_access: 最小未访问天数(默认 30 天)
Returns:
@@ -136,8 +137,8 @@ class ForgettingStrategy:
AND (e.entity_type IS NULL OR e.entity_type <> 'Person')
"""
if group_id:
query += " AND s.group_id = $group_id AND e.group_id = $group_id"
if end_user_id:
query += " AND s.end_user_id = $end_user_id AND e.end_user_id = $end_user_id"
query += """
RETURN s.id as statement_id,
@@ -159,8 +160,8 @@ class ForgettingStrategy:
'threshold': self.forgetting_threshold,
'cutoff_time': cutoff_time_iso
}
if group_id:
params['group_id'] = group_id
if end_user_id:
params['end_user_id'] = end_user_id
results = await self.connector.execute_query(query, **params)
@@ -176,7 +177,7 @@ class ForgettingStrategy:
self,
statement_node: Dict[str, Any],
entity_node: Dict[str, Any],
config_id: Optional[int] = None,
config_id: Optional[UUID] = None,
db = None
) -> str:
"""
@@ -247,8 +248,8 @@ class ForgettingStrategy:
entity_activation = entity_node['entity_activation']
entity_importance = entity_node['entity_importance']
# 获取 group_id从 statement 或 entity 节点)
group_id = statement_node.get('group_id') or entity_node.get('group_id')
# 获取 end_user_id从 statement 或 entity 节点)
end_user_id = statement_node.get('end_user_id') or entity_node.get('end_user_id')
# 生成摘要内容
summary_text = await self._generate_summary(
@@ -325,7 +326,7 @@ class ForgettingStrategy:
last_access_time: $current_time,
access_count: 1,
version: 1,
group_id: $group_id,
end_user_id: $end_user_id,
created_at: datetime($current_time),
merged_at: datetime($current_time)
})
@@ -423,7 +424,7 @@ class ForgettingStrategy:
'inherited_activation': inherited_activation,
'inherited_importance': inherited_importance,
'current_time': current_time_iso,
'group_id': group_id
'end_user_id': end_user_id
}
try:
@@ -462,7 +463,7 @@ class ForgettingStrategy:
statement_text: str,
entity_name: str,
entity_type: str,
config_id: Optional[int] = None,
config_id: Optional[UUID] = None,
db = None
) -> str:
"""
@@ -527,7 +528,7 @@ class ForgettingStrategy:
statement_text, entity_name, entity_type
)
async def _get_llm_client(self, db, config_id: int):
async def _get_llm_client(self, db, config_id: UUID):
"""
从数据库获取 LLM 客户端
@@ -539,11 +540,11 @@ class ForgettingStrategy:
LLM 客户端实例,如果无法获取则返回 None
"""
try:
from app.repositories.data_config_repository import DataConfigRepository
from app.repositories.memory_config_repository import MemoryConfigRepository
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
# 从数据库读取配置
repository = DataConfigRepository()
repository = MemoryConfigRepository()
db_config = repository.get_by_id(db, config_id)
if db_config is None or db_config.llm_id is None:

View File

@@ -37,7 +37,7 @@ __all__ = [
async def run_hybrid_search(
query_text: str,
search_type: str = "hybrid",
group_id: str | None = None,
end_user_id: str | None = None,
apply_id: str | None = None,
user_id: str | None = None,
limit: int = 50,
@@ -54,7 +54,7 @@ async def run_hybrid_search(
Args:
query_text: 查询文本
search_type: 搜索类型("hybrid", "keyword", "semantic"
group_id: 组ID过滤
end_user_id: 组ID过滤
apply_id: 应用ID过滤
user_id: 用户ID过滤
limit: 每个类别的最大结果数
@@ -104,7 +104,7 @@ async def run_hybrid_search(
# 执行搜索
result = await strategy.search(
query_text=query_text,
group_id=group_id,
end_user_id=end_user_id,
limit=limit,
include=include,
alpha=alpha,

View File

@@ -77,7 +77,7 @@
# async def search(
# self,
# query_text: str,
# group_id: Optional[str] = None,
# end_user_id: Optional[str] = None,
# limit: int = 50,
# include: Optional[List[str]] = None,
# **kwargs
@@ -86,7 +86,7 @@
# Args:
# query_text: 查询文本
# group_id: 可选的组ID过滤
# end_user_id: 可选的组ID过滤
# limit: 每个类别的最大结果数
# include: 要包含的搜索类别列表
# **kwargs: 其他搜索参数如alpha, use_forgetting_curve
@@ -94,7 +94,7 @@
# Returns:
# SearchResult: 搜索结果对象
# """
# logger.info(f"执行混合搜索: query='{query_text}', group_id={group_id}, limit={limit}")
# logger.info(f"执行混合搜索: query='{query_text}', end_user_id={end_user_id}, limit={limit}")
# # 从kwargs中获取参数
# alpha = kwargs.get("alpha", self.alpha)
@@ -107,14 +107,14 @@
# # 并行执行关键词搜索和语义搜索
# keyword_result = await self.keyword_strategy.search(
# query_text=query_text,
# group_id=group_id,
# end_user_id=end_user_id,
# limit=limit,
# include=include_list
# )
# semantic_result = await self.semantic_strategy.search(
# query_text=query_text,
# group_id=group_id,
# end_user_id=end_user_id,
# limit=limit,
# include=include_list
# )
@@ -139,7 +139,7 @@
# metadata = self._create_metadata(
# query_text=query_text,
# search_type="hybrid",
# group_id=group_id,
# end_user_id=end_user_id,
# limit=limit,
# include=include_list,
# alpha=alpha,
@@ -165,7 +165,7 @@
# metadata=self._create_metadata(
# query_text=query_text,
# search_type="hybrid",
# group_id=group_id,
# end_user_id=end_user_id,
# limit=limit,
# error=str(e)
# )

View File

@@ -44,7 +44,7 @@ class KeywordSearchStrategy(SearchStrategy):
async def search(
self,
query_text: str,
group_id: Optional[str] = None,
end_user_id: Optional[str] = None,
limit: int = 50,
include: Optional[List[str]] = None,
**kwargs
@@ -53,7 +53,7 @@ class KeywordSearchStrategy(SearchStrategy):
Args:
query_text: 查询文本
group_id: 可选的组ID过滤
end_user_id: 可选的组ID过滤
limit: 每个类别的最大结果数
include: 要包含的搜索类别列表
**kwargs: 其他搜索参数
@@ -61,7 +61,7 @@ class KeywordSearchStrategy(SearchStrategy):
Returns:
SearchResult: 搜索结果对象
"""
logger.info(f"执行关键词搜索: query='{query_text}', group_id={group_id}, limit={limit}")
logger.info(f"执行关键词搜索: query='{query_text}', end_user_id={end_user_id}, limit={limit}")
# 获取有效的搜索类别
include_list = self._get_include_list(include)
@@ -75,7 +75,7 @@ class KeywordSearchStrategy(SearchStrategy):
results_dict = await search_graph(
connector=self.connector,
q=query_text,
group_id=group_id,
end_user_id=end_user_id,
limit=limit,
include=include_list
)
@@ -84,7 +84,7 @@ class KeywordSearchStrategy(SearchStrategy):
metadata = self._create_metadata(
query_text=query_text,
search_type="keyword",
group_id=group_id,
end_user_id=end_user_id,
limit=limit,
include=include_list
)
@@ -115,7 +115,7 @@ class KeywordSearchStrategy(SearchStrategy):
metadata=self._create_metadata(
query_text=query_text,
search_type="keyword",
group_id=group_id,
end_user_id=end_user_id,
limit=limit,
error=str(e)
)

View File

@@ -58,7 +58,7 @@ class SearchStrategy(ABC):
async def search(
self,
query_text: str,
group_id: Optional[str] = None,
end_user_id: Optional[str] = None,
limit: int = 50,
include: Optional[List[str]] = None,
**kwargs
@@ -67,7 +67,7 @@ class SearchStrategy(ABC):
Args:
query_text: 查询文本
group_id: 可选的组ID过滤
end_user_id: 可选的组ID过滤
limit: 每个类别的最大结果数
include: 要包含的搜索类别列表statements, chunks, entities, summaries
**kwargs: 其他搜索参数
@@ -81,7 +81,7 @@ class SearchStrategy(ABC):
self,
query_text: str,
search_type: str,
group_id: Optional[str] = None,
end_user_id: Optional[str] = None,
limit: int = 50,
**kwargs
) -> Dict[str, Any]:
@@ -90,7 +90,7 @@ class SearchStrategy(ABC):
Args:
query_text: 查询文本
search_type: 搜索类型
group_id: 组ID
end_user_id: 组ID
limit: 结果限制
**kwargs: 其他元数据
@@ -100,7 +100,7 @@ class SearchStrategy(ABC):
metadata = {
"query": query_text,
"search_type": search_type,
"group_id": group_id,
"end_user_id": end_user_id,
"limit": limit,
"timestamp": datetime.now().isoformat()
}

View File

@@ -85,7 +85,7 @@ class SemanticSearchStrategy(SearchStrategy):
async def search(
self,
query_text: str,
group_id: Optional[str] = None,
end_user_id: Optional[str] = None,
limit: int = 50,
include: Optional[List[str]] = None,
**kwargs
@@ -94,7 +94,7 @@ class SemanticSearchStrategy(SearchStrategy):
Args:
query_text: 查询文本
group_id: 可选的组ID过滤
end_user_id: 可选的组ID过滤
limit: 每个类别的最大结果数
include: 要包含的搜索类别列表
**kwargs: 其他搜索参数
@@ -102,7 +102,7 @@ class SemanticSearchStrategy(SearchStrategy):
Returns:
SearchResult: 搜索结果对象
"""
logger.info(f"执行语义搜索: query='{query_text}', group_id={group_id}, limit={limit}")
logger.info(f"执行语义搜索: query='{query_text}', end_user_id={end_user_id}, limit={limit}")
# 获取有效的搜索类别
include_list = self._get_include_list(include)
@@ -119,7 +119,7 @@ class SemanticSearchStrategy(SearchStrategy):
connector=self.connector,
embedder_client=self.embedder_client,
query_text=query_text,
group_id=group_id,
end_user_id=end_user_id,
limit=limit,
include=include_list
)
@@ -128,7 +128,7 @@ class SemanticSearchStrategy(SearchStrategy):
metadata = self._create_metadata(
query_text=query_text,
search_type="semantic",
group_id=group_id,
end_user_id=end_user_id,
limit=limit,
include=include_list
)
@@ -159,7 +159,7 @@ class SemanticSearchStrategy(SearchStrategy):
metadata=self._create_metadata(
query_text=query_text,
search_type="semantic",
group_id=group_id,
end_user_id=end_user_id,
limit=limit,
error=str(e)
)

View File

@@ -23,7 +23,7 @@ async def _load_(data: List[Any]) -> List[Dict]:
target_keys = [
"id",
"statement",
"group_id",
"end_user_id",
"chunk_id",
"created_at",
"expired_at",
@@ -75,7 +75,7 @@ async def get_data(result):
"""
EXCLUDE_FIELDS = {
"user_id",
"group_id",
"end_user_id",
"entity_type",
"connect_strength",
"relationship_type",

View File

@@ -62,7 +62,7 @@ class ConfigAuditLogger:
self,
config_id: str,
user_id: Optional[str] = None,
group_id: Optional[str] = None,
end_user_id: Optional[str] = None,
success: bool = True,
details: Optional[Dict[str, Any]] = None
):
@@ -72,14 +72,14 @@ class ConfigAuditLogger:
Args:
config_id: 配置 ID
user_id: 用户 ID可选
group_id: 组 ID可选
end_user_id: 组 ID可选
success: 是否成功
details: 详细信息(可选)
"""
result = "SUCCESS" if success else "FAILED"
msg = (
f"CONFIG_LOAD config_id={config_id} "
f"user={user_id or 'N/A'} group={group_id or 'N/A'} "
f"user={user_id or 'N/A'} group={end_user_id or 'N/A'} "
f"result={result}"
)
if details:
@@ -121,7 +121,7 @@ class ConfigAuditLogger:
self,
operation: str,
config_id: str,
group_id: str,
end_user_id: str,
success: bool = True,
duration: Optional[float] = None,
error: Optional[str] = None,
@@ -133,7 +133,7 @@ class ConfigAuditLogger:
Args:
operation: 操作类型WRITE, READ 等)
config_id: 配置 ID
group_id: 组 ID
end_user_id: 组 ID
success: 是否成功
duration: 操作耗时(秒)
error: 错误信息(可选)
@@ -142,7 +142,7 @@ class ConfigAuditLogger:
result = "SUCCESS" if success else "FAILED"
msg = (
f"{operation.upper()} config_id={config_id} "
f"group={group_id} result={result}"
f"group={end_user_id} result={result}"
)
if duration is not None:
msg += f" duration={duration:.2f}s"

View File

@@ -4,7 +4,7 @@ from enum import StrEnum, auto
class Field(StrEnum):
CONTENT_KEY = "page_content"
METADATA_KEY = "metadata"
GROUP_KEY = "group_id"
GROUP_KEY = "end_user_id"
VECTOR = auto()
# Sparse Vector aims to support full text search
SPARSE_VECTOR = auto()

View File

@@ -26,7 +26,7 @@ logger = get_config_logger()
def _parse_model_id(model_id: Union[str, UUID, None], model_type: str,
config_id: Optional[int] = None, workspace_id: Optional[UUID] = None) -> Optional[UUID]:
config_id: Optional[UUID] = None, workspace_id: Optional[UUID] = None) -> Optional[UUID]:
"""Parse model ID from string or UUID."""
if model_id is None:
return None
@@ -59,7 +59,7 @@ def validate_model_exists_and_active(
model_type: str,
db: Session,
tenant_id: Optional[UUID] = None,
config_id: Optional[int] = None,
config_id: Optional[UUID] = None,
workspace_id: Optional[UUID] = None
) -> tuple[str, bool]:
"""Validate that a model exists and is active.
@@ -166,7 +166,7 @@ def validate_and_resolve_model_id(
db: Session,
tenant_id: Optional[UUID] = None,
required: bool = False,
config_id: Optional[int] = None,
config_id: Optional[UUID] = None,
workspace_id: Optional[UUID] = None
) -> tuple[Optional[UUID], Optional[str]]:
"""Validate and resolve a model ID, checking existence and active status.
@@ -204,7 +204,7 @@ def validate_and_resolve_model_id(
def validate_embedding_model(
config_id: int,
config_id: UUID,
embedding_id: Union[str, UUID, None],
db: Session,
tenant_id: Optional[UUID] = None,
@@ -256,7 +256,7 @@ def validate_embedding_model(
def validate_llm_model(
config_id: int,
config_id: UUID,
llm_id: Union[str, UUID, None],
db: Session,
tenant_id: Optional[UUID] = None,

View File

@@ -1,4 +1,5 @@
import uuid
from uuid import UUID
from pydantic import Field
from typing import Literal
@@ -11,7 +12,7 @@ class MemoryReadNodeConfig(BaseNodeConfig):
...
)
config_id: int = Field(
config_id: UUID = Field(
...
)
@@ -26,6 +27,6 @@ class MemoryWriteNodeConfig(BaseNodeConfig):
...
)
config_id: int = Field(
config_id: UUID = Field(
...
)

View File

@@ -22,7 +22,7 @@ class MemoryReadNode(BaseNode):
raise RuntimeError("End user id is required")
return await MemoryAgentService().read_memory(
group_id=end_user_id,
end_user_id=end_user_id,
message=self._render_template(self.typed_config.message, state),
config_id=str(self.typed_config.config_id),
search_switch=self.typed_config.search_switch,

View File

@@ -18,7 +18,7 @@ from .appshare_model import AppShare
from .release_share_model import ReleaseShare
from .conversation_model import Conversation, Message
from .api_key_model import ApiKey, ApiKeyLog, ApiKeyType
from .data_config_model import DataConfig
from .memory_config_model import MemoryConfig
from .multi_agent_model import MultiAgentConfig, AgentInvocation
from .workflow_model import WorkflowConfig, WorkflowExecution, WorkflowNodeExecution
from .retrieval_info import RetrievalInfo
@@ -57,7 +57,7 @@ __all__ = [
"ApiKey",
"ApiKeyLog",
"ApiKeyType",
"DataConfig",
"MemoryConfig",
"MultiAgentConfig",
"AgentInvocation",
"WorkflowConfig",

View File

@@ -1,88 +0,0 @@
import datetime
from sqlalchemy import Column, String, Boolean, DateTime, Integer, Float
from sqlalchemy.dialects.postgresql import UUID
from app.db import Base
class DataConfig(Base):
"""数据配置表 - 用于存储记忆系统的配置参数"""
__tablename__ = "data_config"
# 主键
config_id = Column(Integer, primary_key=True, autoincrement=True, comment="配置ID")
# 基本信息
config_name = Column(String, nullable=False, comment="配置名称")
config_desc = Column(String, nullable=True, comment="配置描述")
# 组织信息
workspace_id = Column(UUID(as_uuid=True), nullable=True, comment="工作空间ID")
group_id = Column(String, nullable=True, comment="组ID")
user_id = Column(String, nullable=True, comment="用户ID")
apply_id = Column(String, nullable=True, comment="应用ID")
# 模型选择从workspace继承
llm_id = Column(String, nullable=True, comment="LLM模型配置ID")
embedding_id = Column(String, nullable=True, comment="嵌入模型配置ID")
rerank_id = Column(String, nullable=True, comment="重排序模型配置ID")
# 记忆萃取引擎配置
enable_llm_dedup_blockwise = Column(Boolean, default=True, comment="启用LLM决策去重")
enable_llm_disambiguation = Column(Boolean, default=True, comment="启用LLM决策消歧")
deep_retrieval = Column(Boolean, default=True, comment="深度检索开关")
# 阈值配置 (0-1 之间的浮点数)
t_type_strict = Column(Float, default=0.8, comment="类型严格阈值")
t_name_strict = Column(Float, default=0.8, comment="名称严格阈值")
t_overall = Column(Float, default=0.8, comment="综合阈值")
# 状态配置
state = Column(Boolean, default=False, comment="配置使用状态")
# 分块策略
chunker_strategy = Column(String, default="RecursiveChunker", comment="分块策略")
# 剪枝配置
pruning_enabled = Column(Boolean, default=False, comment="是否启动智能语义剪枝")
pruning_scene = Column(String, nullable=True, comment="智能剪枝场景education/online_service/outbound")
pruning_threshold = Column(Float, nullable=True, comment="智能语义剪枝阈值0-0.9")
# 自我反思配置
enable_self_reflexion = Column(Boolean, default=False, comment="是否启用自我反思")
iteration_period = Column(String, default="3", comment="反思迭代周期")
reflexion_range = Column(String, default="partial", comment="反思范围:部分/全部")
baseline = Column(String, default="TIME", comment="基线:时间/事实/时间和事实")
reflection_model_id = Column(String, nullable=True, comment="反思模型ID")
memory_verify = Column(Boolean, default=True, comment="记忆验证")
quality_assessment = Column(Boolean, default=True, comment="质量评估")
# 遗忘引擎配置
statement_granularity = Column(Integer, default=2, comment="陈述提取颗粒度,挡位 1/2/3")
include_dialogue_context = Column(Boolean, default=False, comment="是否包含对话上下文")
max_context = Column(Integer, default=1000, comment="对话语境中包含字符的最大数量")
lambda_time = Column("lambda_time", Float, default=0.5, comment="最低保持度0-1 小数")
lambda_mem = Column("lambda_mem", Float, default=0.5, comment="遗忘率0-1 小数")
offset = Column("offset", Float, default=0.0, comment="偏移度0-1 小数")
# ACT-R 遗忘引擎配置
decay_constant = Column(Float, default=0.5, comment="ACT-R衰减常数d默认0.5")
forgetting_threshold = Column(Float, default=0.3, comment="遗忘阈值默认0.3")
forgetting_interval_hours = Column(Integer, default=24, comment="遗忘周期间隔小时默认24")
enable_llm_summary = Column(Boolean, default=True, comment="是否使用LLM生成摘要默认True")
max_merge_batch_size = Column(Integer, default=100, comment="单次最大融合节点对数默认100")
max_history_length = Column(Integer, default=100, comment="访问历史最大长度默认100")
min_days_since_access = Column(Integer, default=30, comment="最小未访问天数默认30")
# 情绪引擎配置
emotion_enabled = Column(Boolean, default=True, comment="是否启用情绪提取")
emotion_model_id = Column(String, nullable=True, comment="情绪分析专用模型ID")
emotion_extract_keywords = Column(Boolean, default=True, comment="是否提取情绪关键词")
emotion_min_intensity = Column(Float, default=0.1, comment="最小情绪强度阈值")
emotion_enable_subject = Column(Boolean, default=True, comment="是否启用主体分类")
# 时间戳
created_at = Column(DateTime, default=datetime.datetime.now, comment="创建时间")
updated_at = Column(DateTime, default=datetime.datetime.now, onupdate=datetime.datetime.now, comment="更新时间")
def __repr__(self):
return f"<DataConfig(config_id={self.config_id}, config_name={self.config_name})>"

View File

@@ -1,39 +1,88 @@
# -*- coding: utf-8 -*-
"""Memory Configuration Model - Backward Compatibility
import datetime
from sqlalchemy import Column, String, Boolean, DateTime, Integer, Float
from sqlalchemy.dialects.postgresql import UUID
from app.db import Base
This module provides backward compatibility for imports.
All classes have been moved to app.schemas.memory_config_schema.
DEPRECATED: Import from app.schemas.memory_config_schema instead.
"""
class MemoryConfig(Base):
"""记忆配置表 - 用于存储记忆系统的配置参数"""
__tablename__ = "memory_config"
# Re-export for backward compatibility
from app.schemas.memory_config_schema import (
ConfigurationError,
InvalidConfigError,
MemoryConfig,
MemoryConfigValidation,
ModelInactiveError,
ModelNotFoundError,
ModelValidation,
WorkspaceNotFoundError,
WorkspaceValidation,
validate_memory_config_data,
validate_model_data,
validate_workspace_data,
)
# 主键
config_id = Column(UUID(as_uuid=True), primary_key=True, comment="配置ID")
__all__ = [
"ConfigurationError",
"InvalidConfigError",
"MemoryConfig",
"MemoryConfigValidation",
"ModelInactiveError",
"ModelNotFoundError",
"ModelValidation",
"WorkspaceNotFoundError",
"WorkspaceValidation",
"validate_memory_config_data",
"validate_model_data",
"validate_workspace_data",
]
# 基本信息
config_name = Column(String, nullable=False, comment="配置名称")
config_desc = Column(String, nullable=True, comment="配置描述")
# 组织信息
workspace_id = Column(UUID(as_uuid=True), nullable=True, comment="工作空间ID")
end_user_id = Column(String, nullable=True, comment="组ID")
user_id = Column(String, nullable=True, comment="用户ID")
apply_id = Column(String, nullable=True, comment="应用ID")
# 模型选择从workspace继承
llm_id = Column(String, nullable=True, comment="LLM模型配置ID")
embedding_id = Column(String, nullable=True, comment="嵌入模型配置ID")
rerank_id = Column(String, nullable=True, comment="重排序模型配置ID")
# 记忆萃取引擎配置
enable_llm_dedup_blockwise = Column(Boolean, default=True, comment="启用LLM决策去重")
enable_llm_disambiguation = Column(Boolean, default=True, comment="启用LLM决策消歧")
deep_retrieval = Column(Boolean, default=True, comment="深度检索开关")
# 阈值配置 (0-1 之间的浮点数)
t_type_strict = Column(Float, default=0.8, comment="类型严格阈值")
t_name_strict = Column(Float, default=0.8, comment="名称严格阈值")
t_overall = Column(Float, default=0.8, comment="综合阈值")
# 状态配置
state = Column(Boolean, default=False, comment="配置使用状态")
# 分块策略
chunker_strategy = Column(String, default="RecursiveChunker", comment="分块策略")
# 剪枝配置
pruning_enabled = Column(Boolean, default=False, comment="是否启动智能语义剪枝")
pruning_scene = Column(String, nullable=True, comment="智能剪枝场景education/online_service/outbound")
pruning_threshold = Column(Float, nullable=True, comment="智能语义剪枝阈值0-0.9")
# 自我反思配置
enable_self_reflexion = Column(Boolean, default=False, comment="是否启用自我反思")
iteration_period = Column(String, default="3", comment="反思迭代周期")
reflexion_range = Column(String, default="partial", comment="反思范围:部分/全部")
baseline = Column(String, default="TIME", comment="基线:时间/事实/时间和事实")
reflection_model_id = Column(String, nullable=True, comment="反思模型ID")
memory_verify = Column(Boolean, default=True, comment="记忆验证")
quality_assessment = Column(Boolean, default=True, comment="质量评估")
# 遗忘引擎配置
statement_granularity = Column(Integer, default=2, comment="陈述提取颗粒度,挡位 1/2/3")
include_dialogue_context = Column(Boolean, default=False, comment="是否包含对话上下文")
max_context = Column(Integer, default=1000, comment="对话语境中包含字符的最大数量")
lambda_time = Column("lambda_time", Float, default=0.5, comment="最低保持度0-1 小数")
lambda_mem = Column("lambda_mem", Float, default=0.5, comment="遗忘率0-1 小数")
offset = Column("offset", Float, default=0.0, comment="偏移度0-1 小数")
# ACT-R 遗忘引擎配置
decay_constant = Column(Float, default=0.5, comment="ACT-R衰减常数d默认0.5")
forgetting_threshold = Column(Float, default=0.3, comment="遗忘阈值默认0.3")
forgetting_interval_hours = Column(Integer, default=24, comment="遗忘周期间隔小时默认24")
enable_llm_summary = Column(Boolean, default=True, comment="是否使用LLM生成摘要默认True")
max_merge_batch_size = Column(Integer, default=100, comment="单次最大融合节点对数默认100")
max_history_length = Column(Integer, default=100, comment="访问历史最大长度默认100")
min_days_since_access = Column(Integer, default=30, comment="最小未访问天数默认30")
# 情绪引擎配置
emotion_enabled = Column(Boolean, default=True, comment="是否启用情绪提取")
emotion_model_id = Column(String, nullable=True, comment="情绪分析专用模型ID")
emotion_extract_keywords = Column(Boolean, default=True, comment="是否提取情绪关键词")
emotion_min_intensity = Column(Float, default=0.1, comment="最小情绪强度阈值")
emotion_enable_subject = Column(Boolean, default=True, comment="是否启用主体分类")
# 时间戳
created_at = Column(DateTime, default=datetime.datetime.now, comment="创建时间")
updated_at = Column(DateTime, default=datetime.datetime.now, onupdate=datetime.datetime.now, comment="更新时间")
def __repr__(self):
return f"<MemoryConfig(config_id={self.config_id}, config_name={self.config_name})>"

View File

@@ -16,7 +16,7 @@ class PerceptualType(IntEnum):
CONVERSATION = 4
class FileStorageType(IntEnum):
class FileStorageService(IntEnum):
LOCAL = 1
REMOTE = 2

View File

@@ -1,18 +1,19 @@
# -*- coding: utf-8 -*-
"""数据配置Repository模块
"""记忆配置Repository模块
本模块提供data_config表的数据访问层使用SQLAlchemy ORM进行数据库操作
本模块提供memory_config表的数据访问层使用SQLAlchemy ORM进行数据库操作
包括CRUD操作和Neo4j Cypher查询常量
Classes:
DataConfigRepository: 数据配置仓储类提供CRUD操作
MemoryConfigRepository: 记忆配置仓储类提供CRUD操作
"""
import uuid
from uuid import UUID
from typing import Dict, List, Optional, Tuple
from app.core.exceptions import BusinessException
from app.core.logging_config import get_config_logger, get_db_logger
from app.models.data_config_model import DataConfig
from app.models.memory_config_model import MemoryConfig
from app.schemas.memory_storage_schema import (
ConfigKey,
ConfigParamsCreate,
@@ -28,11 +29,11 @@ db_logger = get_db_logger()
# 获取配置专用日志器
config_logger = get_config_logger()
TABLE_NAME = "data_config"
class DataConfigRepository:
"""数据配置Repository
TABLE_NAME = "memory_config"
class MemoryConfigRepository:
"""记忆配置Repository
提供data_config表的数据访问方法包括
提供memory_config表的数据访问方法包括
- SQLAlchemy ORM 数据库操作
- Neo4j Cypher查询常量
"""
@@ -41,48 +42,48 @@ class DataConfigRepository:
# Dialogue count by group
SEARCH_FOR_DIALOGUE = """
MATCH (n:Dialogue) WHERE n.group_id = $group_id RETURN COUNT(n) AS num
MATCH (n:Dialogue) WHERE n.end_user_id = $end_user_id RETURN COUNT(n) AS num
"""
# Chunk count by group
SEARCH_FOR_CHUNK = """
MATCH (n:Chunk) WHERE n.group_id = $group_id RETURN COUNT(n) AS num
MATCH (n:Chunk) WHERE n.end_user_id = $end_user_id RETURN COUNT(n) AS num
"""
# Statement count by group
SEARCH_FOR_STATEMENT = """
MATCH (n:Statement) WHERE n.group_id = $group_id RETURN COUNT(n) AS num
MATCH (n:Statement) WHERE n.end_user_id = $end_user_id RETURN COUNT(n) AS num
"""
# ExtractedEntity count by group
SEARCH_FOR_ENTITY = """
MATCH (n:ExtractedEntity) WHERE n.group_id = $group_id RETURN COUNT(n) AS num
MATCH (n:ExtractedEntity) WHERE n.end_user_id = $end_user_id RETURN COUNT(n) AS num
"""
# All counts by label and total
SEARCH_FOR_ALL = """
OPTIONAL MATCH (n:Dialogue) WHERE n.group_id = $group_id RETURN 'Dialogue' AS Label, COUNT(n) AS Count
OPTIONAL MATCH (n:Dialogue) WHERE n.end_user_id = $end_user_id RETURN 'Dialogue' AS Label, COUNT(n) AS Count
UNION ALL
OPTIONAL MATCH (n:Chunk) WHERE n.group_id = $group_id RETURN 'Chunk' AS Label, COUNT(n) AS Count
OPTIONAL MATCH (n:Chunk) WHERE n.end_user_id = $end_user_id RETURN 'Chunk' AS Label, COUNT(n) AS Count
UNION ALL
OPTIONAL MATCH (n:Statement) WHERE n.group_id = $group_id RETURN 'Statement' AS Label, COUNT(n) AS Count
OPTIONAL MATCH (n:Statement) WHERE n.end_user_id = $end_user_id RETURN 'Statement' AS Label, COUNT(n) AS Count
UNION ALL
OPTIONAL MATCH (n:ExtractedEntity) WHERE n.group_id = $group_id RETURN 'ExtractedEntity' AS Label, COUNT(n) AS Count
OPTIONAL MATCH (n:ExtractedEntity) WHERE n.end_user_id = $end_user_id RETURN 'ExtractedEntity' AS Label, COUNT(n) AS Count
UNION ALL
OPTIONAL MATCH (n) WHERE n.group_id = $group_id RETURN 'ALL' AS Label, COUNT(n) AS Count
OPTIONAL MATCH (n) WHERE n.end_user_id = $end_user_id RETURN 'ALL' AS Label, COUNT(n) AS Count
"""
# Extracted entity details within group/app/user
SEARCH_FOR_DETIALS = """
MATCH (n:ExtractedEntity)
WHERE n.group_id = $group_id
WHERE n.end_user_id = $end_user_id
RETURN n.entity_idx AS entity_idx,
n.connect_strength AS connect_strength,
n.description AS description,
n.entity_type AS entity_type,
n.name AS name,
COALESCE(n.fact_summary, '') AS fact_summary,
n.group_id AS group_id,
n.end_user_id AS end_user_id,
n.apply_id AS apply_id,
n.user_id AS user_id,
n.id AS id
@@ -91,9 +92,9 @@ class DataConfigRepository:
# Edges between extracted entities within group/app/user
SEARCH_FOR_EDGES = """
MATCH (n:ExtractedEntity)-[r]->(m:ExtractedEntity)
WHERE n.group_id = $group_id
WHERE n.end_user_id = $end_user_id
RETURN
r.group_id AS group_id,
r.end_user_id AS end_user_id,
r.apply_id AS apply_id,
r.user_id AS user_id,
elementId(r) AS rel_id,
@@ -107,7 +108,7 @@ class DataConfigRepository:
@staticmethod
def update_reflection_config(
db: Session,
config_id: int,
config_id: uuid.UUID,
enable_self_reflexion: bool,
iteration_period: str,
reflexion_range: str,
@@ -115,7 +116,7 @@ class DataConfigRepository:
reflection_model_id: str,
memory_verify: bool,
quality_assessment: bool
) -> DataConfig:
) -> MemoryConfig:
"""构建反思配置更新语句SQLAlchemy text() 命名参数)
Args:
@@ -130,28 +131,28 @@ class DataConfigRepository:
config_id: 配置ID
Returns:
Data
MemoryConfig
Raises:
ValueError: 没有字段需要更新时抛出
"""
db_logger.debug(f"构建反思配置更新语句: config_id={config_id}")
stmt = select(DataConfig).where(DataConfig.config_id == config_id)
data_config_obj = db.scalars(stmt).first()
if not data_config_obj:
stmt = select(MemoryConfig).where(MemoryConfig.config_id == config_id)
memory_config_obj = db.scalars(stmt).first()
if not memory_config_obj:
raise BusinessException
data_config_obj.enable_self_reflexion = enable_self_reflexion
data_config_obj.iteration_period = iteration_period
data_config_obj.reflexion_range = reflexion_range
data_config_obj.baseline = baseline
data_config_obj.reflection_model_id = reflection_model_id
data_config_obj.memory_verify = memory_verify
data_config_obj.quality_assessment = quality_assessment
memory_config_obj.enable_self_reflexion = enable_self_reflexion
memory_config_obj.iteration_period = iteration_period
memory_config_obj.reflexion_range = reflexion_range
memory_config_obj.baseline = baseline
memory_config_obj.reflection_model_id = reflection_model_id
memory_config_obj.memory_verify = memory_verify
memory_config_obj.quality_assessment = quality_assessment
return data_config_obj
return memory_config_obj
@staticmethod
def query_reflection_config_by_id(db: Session, config_id: int) -> DataConfig:
def query_reflection_config_by_id(db: Session, config_id: uuid.UUID) -> MemoryConfig:
"""构建反思配置查询语句通过config_id查询反思配置SQLAlchemy text() 命名参数)
Args:
@@ -162,13 +163,13 @@ class DataConfigRepository:
Tuple[str, Dict]: (SQL查询字符串, 参数字典)
"""
db_logger.debug(f"构建反思配置查询语句: config_id={config_id}")
stmt = select(DataConfig).where(DataConfig.config_id == config_id)
data_config = db.scalars(stmt).first()
if not data_config:
stmt = select(MemoryConfig).where(MemoryConfig.config_id == config_id)
memory_config = db.scalars(stmt).first()
if not memory_config:
raise RuntimeError("reflection config not found")
return data_config
return memory_config
@staticmethod
def query_reflection_config_by_workspace_id(db: Session, workspace_id: uuid.UUID) -> DataConfig:
def query_reflection_config_by_workspace_id(db: Session, workspace_id: uuid.UUID) -> MemoryConfig:
"""构建查询所有配置的语句SQLAlchemy text() 命名参数)
Args:
@@ -180,11 +181,11 @@ class DataConfigRepository:
"""
db_logger.debug(f"构建查询所有配置语句: workspace_id={workspace_id}")
stmt = select(DataConfig).where(DataConfig.workspace_id == workspace_id)
data_config = db.scalars(stmt).first()
if not data_config:
stmt = select(MemoryConfig).where(MemoryConfig.workspace_id == workspace_id)
memory_config = db.scalars(stmt).first()
if not memory_config:
raise RuntimeError("reflection config not found")
return data_config
return memory_config
@staticmethod
@@ -208,20 +209,21 @@ class DataConfigRepository:
return query, params
@staticmethod
def create(db: Session, params: ConfigParamsCreate) -> DataConfig:
"""创建数据配置
def create(db: Session, params: ConfigParamsCreate) -> MemoryConfig:
"""创建记忆配置
Args:
db: 数据库会话
params: 配置参数创建模型
Returns:
DataConfig: 创建的配置对象
MemoryConfig: 创建的配置对象
"""
db_logger.debug(f"创建数据配置: config_name={params.config_name}, workspace_id={params.workspace_id}")
db_logger.debug(f"创建记忆配置: config_name={params.config_name}, workspace_id={params.workspace_id}")
try:
db_config = DataConfig(
db_config = MemoryConfig(
config_id=uuid.uuid4(),
config_name=params.config_name,
config_desc=params.config_desc,
workspace_id=params.workspace_id,
@@ -232,16 +234,16 @@ class DataConfigRepository:
db.add(db_config)
db.flush() # 获取自增ID但不提交事务
db_logger.info(f"数据配置已添加到会话: {db_config.config_name} (ID: {db_config.config_id})")
db_logger.info(f"记忆配置已添加到会话: {db_config.config_name} (ID: {db_config.config_id})")
return db_config
except Exception as e:
db.rollback()
db_logger.error(f"创建数据配置失败: {params.config_name} - {str(e)}")
db_logger.error(f"创建记忆配置失败: {params.config_name} - {str(e)}")
raise
@staticmethod
def update(db: Session, update: ConfigUpdate) -> Optional[DataConfig]:
def update(db: Session, update: ConfigUpdate) -> Optional[MemoryConfig]:
"""更新基础配置
Args:
@@ -249,17 +251,17 @@ class DataConfigRepository:
update: 配置更新模型
Returns:
Optional[DataConfig]: 更新后的配置对象不存在则返回None
Optional[MemoryConfig]: 更新后的配置对象不存在则返回None
Raises:
ValueError: 没有字段需要更新时抛出
"""
db_logger.debug(f"更新数据配置: config_id={update.config_id}")
db_logger.debug(f"更新记忆配置: config_id={update.config_id}")
try:
db_config = db.query(DataConfig).filter(DataConfig.config_id == update.config_id).first()
db_config = db.query(MemoryConfig).filter(MemoryConfig.config_id == update.config_id).first()
if not db_config:
db_logger.warning(f"数据配置不存在: config_id={update.config_id}")
db_logger.warning(f"记忆配置不存在: config_id={update.config_id}")
return None
# 更新字段
@@ -277,17 +279,17 @@ class DataConfigRepository:
db.commit()
db.refresh(db_config)
db_logger.info(f"数据配置更新成功: {db_config.config_name} (ID: {update.config_id})")
db_logger.info(f"记忆配置更新成功: {db_config.config_name} (ID: {update.config_id})")
return db_config
except Exception as e:
db.rollback()
db_logger.error(f"更新数据配置失败: config_id={update.config_id} - {str(e)}")
db_logger.error(f"更新记忆配置失败: config_id={update.config_id} - {str(e)}")
raise
@staticmethod
def update_extracted(db: Session, update: ConfigUpdateExtracted) -> Optional[DataConfig]:
def update_extracted(db: Session, update: ConfigUpdateExtracted) -> Optional[MemoryConfig]:
"""更新记忆萃取引擎配置
Args:
@@ -295,7 +297,7 @@ class DataConfigRepository:
update: 萃取配置更新模型
Returns:
Optional[DataConfig]: 更新后的配置对象不存在则返回None
Optional[MemoryConfig]: 更新后的配置对象不存在则返回None
Raises:
ValueError: 没有字段需要更新时抛出
@@ -303,9 +305,9 @@ class DataConfigRepository:
db_logger.debug(f"更新萃取配置: config_id={update.config_id}")
try:
db_config = db.query(DataConfig).filter(DataConfig.config_id == update.config_id).first()
db_config = db.query(MemoryConfig).filter(MemoryConfig.config_id == update.config_id).first()
if not db_config:
db_logger.warning(f"数据配置不存在: config_id={update.config_id}")
db_logger.warning(f"记忆配置不存在: config_id={update.config_id}")
return None
# 更新字段映射
@@ -360,7 +362,7 @@ class DataConfigRepository:
raise
@staticmethod
def update_forget(db: Session, update: ConfigUpdateForget) -> Optional[DataConfig]:
def update_forget(db: Session, update: ConfigUpdateForget) -> Optional[MemoryConfig]:
"""更新遗忘引擎配置
Args:
@@ -368,7 +370,7 @@ class DataConfigRepository:
update: 遗忘配置更新模型
Returns:
Optional[DataConfig]: 更新后的配置对象不存在则返回None
Optional[MemoryConfig]: 更新后的配置对象不存在则返回None
Raises:
ValueError: 没有字段需要更新时抛出
@@ -376,9 +378,9 @@ class DataConfigRepository:
db_logger.debug(f"更新遗忘配置: config_id={update.config_id}")
try:
db_config = db.query(DataConfig).filter(DataConfig.config_id == update.config_id).first()
db_config = db.query(MemoryConfig).filter(MemoryConfig.config_id == update.config_id).first()
if not db_config:
db_logger.warning(f"数据配置不存在: config_id={update.config_id}")
db_logger.warning(f"记忆配置不存在: config_id={update.config_id}")
return None
# 更新字段
@@ -408,7 +410,7 @@ class DataConfigRepository:
raise
@staticmethod
def get_extracted_config(db: Session, config_id: int) -> Optional[Dict]:
def get_extracted_config(db: Session, config_id: UUID) -> Optional[Dict]:
"""获取萃取配置,通过主键查询某条配置
Args:
@@ -421,7 +423,7 @@ class DataConfigRepository:
db_logger.debug(f"查询萃取配置: config_id={config_id}")
try:
db_config = db.query(DataConfig).filter(DataConfig.config_id == config_id).first()
db_config = db.query(MemoryConfig).filter(MemoryConfig.config_id == config_id).first()
if not db_config:
db_logger.debug(f"萃取配置不存在: config_id={config_id}")
return None
@@ -457,7 +459,7 @@ class DataConfigRepository:
raise
@staticmethod
def get_forget_config(db: Session, config_id: int) -> Optional[Dict]:
def get_forget_config(db: Session, config_id: UUID) -> Optional[Dict]:
"""获取遗忘配置,通过主键查询某条配置
Args:
@@ -470,7 +472,7 @@ class DataConfigRepository:
db_logger.debug(f"查询遗忘配置: config_id={config_id}")
try:
db_config = db.query(DataConfig).filter(DataConfig.config_id == config_id).first()
db_config = db.query(MemoryConfig).filter(MemoryConfig.config_id == config_id).first()
if not db_config:
db_logger.debug(f"遗忘配置不存在: config_id={config_id}")
return None
@@ -489,39 +491,39 @@ class DataConfigRepository:
raise
@staticmethod
def get_by_id(db: Session, config_id: int) -> Optional[DataConfig]:
"""根据ID获取数据配置
def get_by_id(db: Session, config_id: uuid.UUID) -> Optional[MemoryConfig]:
"""根据ID获取记忆配置
Args:
db: 数据库会话
config_id: 配置ID
Returns:
Optional[DataConfig]: 配置对象不存在则返回None
Optional[MemoryConfig]: 配置对象不存在则返回None
"""
db_logger.debug(f"根据ID查询数据配置: config_id={config_id}")
db_logger.debug(f"根据ID查询记忆配置: config_id={config_id}")
try:
config = db.query(DataConfig).filter(DataConfig.config_id == config_id).first()
config = db.query(MemoryConfig).filter(MemoryConfig.config_id == config_id).first()
if config:
db_logger.debug(f"数据配置查询成功: {config.config_name} (ID: {config_id})")
db_logger.debug(f"记忆配置查询成功: {config.config_name} (ID: {config_id})")
else:
db_logger.debug(f"数据配置不存在: config_id={config_id}")
db_logger.debug(f"记忆配置不存在: config_id={config_id}")
return config
except Exception as e:
db_logger.error(f"根据ID查询数据配置失败: config_id={config_id} - {str(e)}")
db_logger.error(f"根据ID查询记忆配置失败: config_id={config_id} - {str(e)}")
raise
@staticmethod
def get_config_with_workspace(db: Session, config_id: int) -> Optional[tuple]:
"""Get data config and its associated workspace information
def get_config_with_workspace(db: Session, config_id: uuid.UUID) -> Optional[tuple]:
"""Get memory config and its associated workspace information
Args:
db: Database session
config_id: Configuration ID
Returns:
Optional[tuple]: (DataConfig, Workspace) tuple, None if not found
Optional[tuple]: (MemoryConfig, Workspace) tuple, None if not found
Raises:
ValueError: Raised when config exists but workspace doesn't
@@ -541,19 +543,19 @@ class DataConfigRepository:
}
)
db_logger.debug(f"Querying data config and workspace: config_id={config_id}")
db_logger.debug(f"Querying memory config and workspace: config_id={config_id}")
try:
# Use join query to get both config and workspace
result = db.query(DataConfig, Workspace).join(
Workspace, DataConfig.workspace_id == Workspace.id
).filter(DataConfig.config_id == config_id).first()
result = db.query(MemoryConfig, Workspace).join(
Workspace, MemoryConfig.workspace_id == Workspace.id
).filter(MemoryConfig.config_id == config_id).first()
elapsed_ms = (time.time() - start_time) * 1000
if not result:
# Check if config exists but workspace is missing
config_only = db.query(DataConfig).filter(DataConfig.config_id == config_id).first()
config_only = db.query(MemoryConfig).filter(MemoryConfig.config_id == config_id).first()
if config_only:
if config_only.workspace_id is None:
config_logger.error(
@@ -566,7 +568,7 @@ class DataConfigRepository:
"elapsed_ms": elapsed_ms
}
)
db_logger.error(f"Data config {config_id} has no associated workspace ID")
db_logger.error(f"Memory config {config_id} has no associated workspace ID")
raise ValueError(f"Configuration {config_id} has no associated workspace")
else:
config_logger.error(
@@ -579,7 +581,7 @@ class DataConfigRepository:
"elapsed_ms": elapsed_ms
}
)
db_logger.error(f"Data config {config_id} references non-existent workspace {config_only.workspace_id}")
db_logger.error(f"Memory config {config_id} references non-existent workspace {config_only.workspace_id}")
raise ValueError(f"Workspace {config_only.workspace_id} not found for configuration {config_id}")
config_logger.debug(
@@ -591,7 +593,7 @@ class DataConfigRepository:
"elapsed_ms": elapsed_ms
}
)
db_logger.debug(f"Data config not found: config_id={config_id}")
db_logger.debug(f"Memory config not found: config_id={config_id}")
return None
config, workspace = result
@@ -611,7 +613,7 @@ class DataConfigRepository:
}
)
db_logger.debug(f"Data config and workspace query successful: config={config.config_name}, workspace={workspace.name}")
db_logger.debug(f"Memory config and workspace query successful: config={config.config_name}, workspace={workspace.name}")
return (config, workspace)
except ValueError:
@@ -633,10 +635,10 @@ class DataConfigRepository:
exc_info=True
)
db_logger.error(f"Failed to query data config and workspace: config_id={config_id} - {str(e)}")
db_logger.error(f"Failed to query memory config and workspace: config_id={config_id} - {str(e)}")
raise
@staticmethod
def get_all(db: Session, workspace_id: Optional[uuid.UUID] = None) -> List[DataConfig]:
def get_all(db: Session, workspace_id: Optional[uuid.UUID] = None) -> List[MemoryConfig]:
"""获取所有配置参数
Args:
@@ -644,17 +646,17 @@ class DataConfigRepository:
workspace_id: 工作空间ID用于过滤查询结果
Returns:
List[DataConfig]: 配置列表
List[MemoryConfig]: 配置列表
"""
db_logger.debug(f"查询所有配置: workspace_id={workspace_id}")
try:
query = db.query(DataConfig)
query = db.query(MemoryConfig)
if workspace_id:
query = query.filter(DataConfig.workspace_id == workspace_id)
query = query.filter(MemoryConfig.workspace_id == workspace_id)
configs = query.order_by(desc(DataConfig.updated_at)).all()
configs = query.order_by(desc(MemoryConfig.updated_at)).all()
db_logger.debug(f"配置列表查询成功: 数量={len(configs)}")
return configs
@@ -664,8 +666,8 @@ class DataConfigRepository:
raise
@staticmethod
def delete(db: Session, config_id: int) -> bool:
"""删除数据配置
def delete(db: Session, config_id: uuid.UUID) -> bool:
"""删除记忆配置
Args:
db: 数据库会话
@@ -674,22 +676,22 @@ class DataConfigRepository:
Returns:
bool: 删除成功返回True配置不存在返回False
"""
db_logger.debug(f"删除数据配置: config_id={config_id}")
db_logger.debug(f"删除记忆配置: config_id={config_id}")
try:
db_config = db.query(DataConfig).filter(DataConfig.config_id == config_id).first()
db_config = db.query(MemoryConfig).filter(MemoryConfig.config_id == config_id).first()
if not db_config:
db_logger.warning(f"数据配置不存在: config_id={config_id}")
db_logger.warning(f"记忆配置不存在: config_id={config_id}")
return False
db.delete(db_config)
db.commit()
db_logger.info(f"数据配置删除成功: config_id={config_id}")
db_logger.info(f"记忆配置删除成功: config_id={config_id}")
return True
except Exception as e:
db.rollback()
db_logger.error(f"删除数据配置失败: config_id={config_id} - {str(e)}")
db_logger.error(f"删除记忆配置失败: config_id={config_id} - {str(e)}")
raise

View File

@@ -6,7 +6,7 @@ from sqlalchemy import and_, desc
from sqlalchemy.orm import Session
from app.core.logging_config import get_db_logger
from app.models.memory_perceptual_model import MemoryPerceptualModel, PerceptualType, FileStorageType
from app.models.memory_perceptual_model import MemoryPerceptualModel, PerceptualType, FileStorageService
from app.schemas.memory_perceptual_schema import PerceptualQuerySchema
db_logger = get_db_logger()
@@ -28,7 +28,7 @@ class MemoryPerceptualRepository:
file_ext: str,
summary: Optional[str] = None,
meta_data: Optional[dict] = None,
storage_service: FileStorageType = FileStorageType.LOCAL
storage_service: FileStorageService = FileStorageService.LOCAL
) -> MemoryPerceptualModel:

View File

@@ -32,7 +32,7 @@ async def add_chunk_statement_edges(chunks: List[Chunk], connector: Neo4jConnect
"id": stable_edge_id,
"source": chunk.id,
"target": stmt.id,
"group_id": getattr(stmt, 'group_id', None),
"end_user_id": getattr(stmt, 'end_user_id', None),
"user_id":getattr(stmt, 'user_id', None),
"apply_id": getattr(stmt, 'apply_id', None),
"run_id": getattr(stmt, 'run_id', None) or getattr(chunk, 'run_id', None),
@@ -83,7 +83,7 @@ async def add_memory_summary_statement_edges(summaries: List[MemorySummaryNode],
edges.append({
"summary_id": s.id,
"chunk_id": chunk_id,
"group_id": s.group_id,
"end_user_id": s.end_user_id,
"run_id": s.run_id,
"created_at": s.created_at.isoformat() if s.created_at else None,
"expired_at": s.expired_at.isoformat() if s.expired_at else None,

View File

@@ -6,10 +6,10 @@ from app.core.memory.models.graph_models import DialogueNode, StatementNode, Chu
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
async def delete_all_nodes(group_id: str, connector: Neo4jConnector):
async def delete_all_nodes(end_user_id: str, connector: Neo4jConnector):
"""Delete all nodes in the database."""
result = await connector.execute_query(f"MATCH (n {{group_id: '{group_id}'}}) DETACH DELETE n")
print(f"All group_id: {group_id} node and edge deleted successfully")
result = await connector.execute_query(f"MATCH (n {{end_user_id: '{end_user_id}'}}) DETACH DELETE n")
print(f"All end_user_id: {end_user_id} node and edge deleted successfully")
return result
async def add_dialogue_nodes(dialogues: List[DialogueNode], connector: Neo4jConnector) -> Optional[List[str]]:
@@ -32,9 +32,7 @@ async def add_dialogue_nodes(dialogues: List[DialogueNode], connector: Neo4jConn
for dialogue in dialogues:
flattened_dialogues.append({
"id": dialogue.id,
"group_id": dialogue.group_id,
"user_id": dialogue.user_id,
"apply_id": dialogue.apply_id,
"end_user_id": dialogue.end_user_id,
"run_id": dialogue.run_id,
"ref_id": dialogue.ref_id,
"name": dialogue.name,
@@ -79,9 +77,7 @@ async def add_statement_nodes(statements: List[StatementNode], connector: Neo4jC
flattened_statement = {
"id": statement.id,
"name": statement.name,
"group_id": statement.group_id,
"user_id": statement.user_id,
"apply_id": statement.apply_id,
"end_user_id": statement.end_user_id,
"run_id": statement.run_id,
"chunk_id": statement.chunk_id,
# "created_at": statement.created_at.isoformat(),
@@ -154,9 +150,7 @@ async def add_chunk_nodes(chunks: List[ChunkNode], connector: Neo4jConnector) ->
flattened_chunk = {
"id": chunk.id,
"name": chunk.name,
"group_id": chunk.group_id,
"user_id": chunk.user_id,
"apply_id": chunk.apply_id,
"end_user_id": chunk.end_user_id,
"run_id": chunk.run_id,
"created_at": chunk.created_at.isoformat() if chunk.created_at else None,
"expired_at": chunk.expired_at.isoformat() if chunk.expired_at else None,
@@ -206,9 +200,7 @@ async def add_memory_summary_nodes(summaries: List[MemorySummaryNode], connector
flattened.append({
"id": s.id,
"name": s.name,
"group_id": s.group_id,
"user_id": s.user_id,
"apply_id": s.apply_id,
"end_user_id": s.end_user_id,
"run_id": s.run_id,
"created_at": s.created_at.isoformat() if s.created_at else None,
"expired_at": s.expired_at.isoformat() if s.expired_at else None,

View File

@@ -152,7 +152,7 @@ class BaseNeo4jRepository(BaseRepository[T]):
Example:
>>> results = await repository.find(
... {"group_id": "group_123", "user_id": "user_456"},
... {"end_user_id": "group_123", "user_id": "user_456"},
... limit=50
... )
"""

View File

@@ -3,9 +3,7 @@ DIALOGUE_NODE_SAVE = """
UNWIND $dialogues AS dialogue
MERGE (n:Dialogue {id: dialogue.id})
SET n.uuid = coalesce(n.uuid, dialogue.id),
n.group_id = dialogue.group_id,
n.user_id = dialogue.user_id,
n.apply_id = dialogue.apply_id,
n.end_user_id = dialogue.end_user_id,
n.run_id = dialogue.run_id,
n.ref_id = dialogue.ref_id,
n.created_at = dialogue.created_at,
@@ -22,9 +20,7 @@ SET s += {
id: statement.id,
run_id: statement.run_id,
chunk_id: statement.chunk_id,
group_id: statement.group_id,
user_id: statement.user_id,
apply_id: statement.apply_id,
end_user_id: statement.end_user_id,
stmt_type: statement.stmt_type,
statement: statement.statement,
emotion_intensity: statement.emotion_intensity,
@@ -54,9 +50,7 @@ MERGE (c:Chunk {id: chunk.id})
SET c += {
id: chunk.id,
name: chunk.name,
group_id: chunk.group_id,
user_id: chunk.user_id,
apply_id: chunk.apply_id,
end_user_id: chunk.end_user_id,
run_id: chunk.run_id,
created_at: chunk.created_at,
expired_at: chunk.expired_at,
@@ -76,9 +70,7 @@ EXTRACTED_ENTITY_NODE_SAVE = """
UNWIND $entities AS entity
MERGE (e:ExtractedEntity {id: entity.id})
SET e.name = CASE WHEN entity.name IS NOT NULL AND entity.name <> '' THEN entity.name ELSE e.name END,
e.group_id = CASE WHEN entity.group_id IS NOT NULL AND entity.group_id <> '' THEN entity.group_id ELSE e.group_id END,
e.user_id = CASE WHEN entity.user_id IS NOT NULL AND entity.user_id <> '' THEN entity.user_id ELSE e.user_id END,
e.apply_id = CASE WHEN entity.apply_id IS NOT NULL AND entity.apply_id <> '' THEN entity.apply_id ELSE e.apply_id END,
e.end_user_id = CASE WHEN entity.end_user_id IS NOT NULL AND entity.end_user_id <> '' THEN entity.end_user_id ELSE e.end_user_id END,
e.run_id = CASE WHEN entity.run_id IS NOT NULL AND entity.run_id <> '' THEN entity.run_id ELSE e.run_id END,
e.created_at = CASE
WHEN entity.created_at IS NOT NULL AND (e.created_at IS NULL OR entity.created_at < e.created_at)
@@ -134,9 +126,9 @@ RETURN e.id AS uuid
# Add back ENTITY_RELATIONSHIP_SAVE to be used by graph_saver.save_entities_and_relationships
ENTITY_RELATIONSHIP_SAVE = """
UNWIND $relationships AS rel
// Match entities by stable id within group, do not constrain by run_id
MATCH (subject:ExtractedEntity {id: rel.source_id, group_id: rel.group_id})
MATCH (object:ExtractedEntity {id: rel.target_id, group_id: rel.group_id})
// Match entities by stable id within end_user_id, do not constrain by run_id
MATCH (subject:ExtractedEntity {id: rel.source_id, end_user_id: rel.end_user_id})
MATCH (object:ExtractedEntity {id: rel.target_id, end_user_id: rel.end_user_id})
// Avoid duplicate edges across runs for the same endpoints
MERGE (subject)-[r:EXTRACTED_RELATIONSHIP]->(object)
SET r.predicate = rel.predicate,
@@ -148,7 +140,7 @@ SET r.predicate = rel.predicate,
r.created_at = rel.created_at,
r.expired_at = rel.expired_at,
r.run_id = rel.run_id,
r.group_id = rel.group_id
r.end_user_id = rel.end_user_id
RETURN elementId(r) AS uuid
"""
@@ -160,7 +152,7 @@ UNWIND $weak_entities AS entity
MERGE (e:ExtractedEntity {id: entity.id, run_id: entity.run_id})
SET e += {
name: entity.name,
group_id: entity.group_id,
end_user_id: entity.end_user_id,
run_id: entity.run_id,
description: entity.description,
chunk_id: entity.chunk_id,
@@ -175,11 +167,11 @@ RETURN e.id AS id
SAVE_STRONG_TRIPLE_ENTITIES = """
UNWIND $items AS item
MERGE (s:ExtractedEntity {id: item.source_id, run_id: item.run_id})
SET s += {name: item.subject, group_id: item.group_id, run_id: item.run_id}
SET s += {name: item.subject, end_user_id: item.end_user_id, run_id: item.run_id}
// Independent strong flag
SET s.is_strong = true
MERGE (o:ExtractedEntity {id: item.target_id, run_id: item.run_id})
SET o += {name: item.object, group_id: item.group_id, run_id: item.run_id}
SET o += {name: item.object, end_user_id: item.end_user_id, run_id: item.run_id}
// Independent strong flag
SET o.is_strong = true
"""
@@ -194,7 +186,7 @@ DIALOGUE_STATEMENT_EDGE_SAVE = """
// 仅按端点去重,关系属性可更新
MERGE (dialogue)-[e:MENTIONS]->(statement)
SET e.uuid = edge.id,
e.group_id = edge.group_id,
e.end_user_id = edge.end_user_id,
e.created_at = edge.created_at,
e.expired_at = edge.expired_at
RETURN e.uuid AS uuid
@@ -208,7 +200,7 @@ CHUNK_STATEMENT_EDGE_SAVE = """
MATCH (statement:Statement {id: edge.source, run_id: edge.run_id})
MATCH (chunk:Chunk {id: edge.target, run_id: edge.run_id})
MERGE (chunk)-[e:CONTAINS {id: edge.id}]->(statement)
SET e.group_id = edge.group_id,
SET e.end_user_id = edge.end_user_id,
e.run_id = edge.run_id,
e.created_at = edge.created_at,
e.expired_at = edge.expired_at
@@ -218,13 +210,12 @@ CHUNK_STATEMENT_EDGE_SAVE = """
STATEMENT_ENTITY_EDGE_SAVE = """
UNWIND $relationships AS rel
// Statement nodes are per-run; keep run_id constraint on statements
// Statement nodes are per-run; keep run_id constraint on statements
MATCH (statement:Statement {id: rel.source, run_id: rel.run_id})
// Entities are shared across runs within a group; do not constrain by run_id
MATCH (entity:ExtractedEntity {id: rel.target, group_id: rel.group_id})
// Entities are shared across runs within end_user_id; do not constrain by run_id
MATCH (entity:ExtractedEntity {id: rel.target, end_user_id: rel.end_user_id})
// Avoid duplicate edges across runs for same endpoints
MERGE (statement)-[r:REFERENCES_ENTITY]->(entity)
SET r.group_id = rel.group_id,
SET r.end_user_id = rel.end_user_id,
r.run_id = rel.run_id,
r.created_at = rel.created_at,
r.expired_at = rel.expired_at,
@@ -236,10 +227,10 @@ ENTITY_EMBEDDING_SEARCH = """
CALL db.index.vector.queryNodes('entity_embedding_index', $limit * 100, $embedding)
YIELD node AS e, score
WHERE e.name_embedding IS NOT NULL
AND ($group_id IS NULL OR e.group_id = $group_id)
AND ($end_user_id IS NULL OR e.end_user_id = $end_user_id)
RETURN e.id AS id,
e.name AS name,
e.group_id AS group_id,
e.end_user_id AS end_user_id,
e.entity_type AS entity_type,
COALESCE(e.activation_value, e.importance_score, 0.5) AS activation_value,
COALESCE(e.importance_score, 0.5) AS importance_score,
@@ -254,10 +245,10 @@ STATEMENT_EMBEDDING_SEARCH = """
CALL db.index.vector.queryNodes('statement_embedding_index', $limit * 100, $embedding)
YIELD node AS s, score
WHERE s.statement_embedding IS NOT NULL
AND ($group_id IS NULL OR s.group_id = $group_id)
AND ($end_user_id IS NULL OR s.end_user_id = $end_user_id)
RETURN s.id AS id,
s.statement AS statement,
s.group_id AS group_id,
s.end_user_id AS end_user_id,
s.chunk_id AS chunk_id,
s.created_at AS created_at,
s.expired_at AS expired_at,
@@ -277,9 +268,9 @@ CHUNK_EMBEDDING_SEARCH = """
CALL db.index.vector.queryNodes('chunk_embedding_index', $limit * 100, $embedding)
YIELD node AS c, score
WHERE c.chunk_embedding IS NOT NULL
AND ($group_id IS NULL OR c.group_id = $group_id)
AND ($end_user_id IS NULL OR c.end_user_id = $end_user_id)
RETURN c.id AS chunk_id,
c.group_id AS group_id,
c.end_user_id AS end_user_id,
c.content AS content,
c.dialog_id AS dialog_id,
COALESCE(c.activation_value, 0.5) AS activation_value,
@@ -292,12 +283,12 @@ LIMIT $limit
SEARCH_STATEMENTS_BY_KEYWORD = """
CALL db.index.fulltext.queryNodes("statementsFulltext", $q) YIELD node AS s, score
WHERE ($group_id IS NULL OR s.group_id = $group_id)
WHERE ($end_user_id IS NULL OR s.end_user_id = $end_user_id)
OPTIONAL MATCH (c:Chunk)-[:CONTAINS]->(s)
OPTIONAL MATCH (s)-[:REFERENCES_ENTITY]->(e:ExtractedEntity)
RETURN s.id AS id,
s.statement AS statement,
s.group_id AS group_id,
s.end_user_id AS end_user_id,
s.chunk_id AS chunk_id,
s.created_at AS created_at,
s.expired_at AS expired_at,
@@ -316,15 +307,13 @@ LIMIT $limit
# 查询实体名称包含指定字符串的实体
SEARCH_ENTITIES_BY_NAME = """
CALL db.index.fulltext.queryNodes("entitiesFulltext", $q) YIELD node AS e, score
WHERE ($group_id IS NULL OR e.group_id = $group_id)
WHERE ($end_user_id IS NULL OR e.end_user_id = $end_user_id)
OPTIONAL MATCH (s:Statement)-[:REFERENCES_ENTITY]->(e)
OPTIONAL MATCH (c:Chunk)-[:CONTAINS]->(s)
RETURN e.id AS id,
e.name AS name,
e.group_id AS group_id,
e.end_user_id AS end_user_id,
e.entity_type AS entity_type,
e.apply_id AS apply_id,
e.user_id AS user_id,
e.created_at AS created_at,
e.expired_at AS expired_at,
e.entity_idx AS entity_idx,
@@ -347,11 +336,11 @@ LIMIT $limit
SEARCH_CHUNKS_BY_CONTENT = """
CALL db.index.fulltext.queryNodes("chunksFulltext", $q) YIELD node AS c, score
WHERE ($group_id IS NULL OR c.group_id = $group_id)
WHERE ($end_user_id IS NULL OR c.end_user_id = $end_user_id)
OPTIONAL MATCH (c)-[:CONTAINS]->(s:Statement)
OPTIONAL MATCH (s)-[:REFERENCES_ENTITY]->(e:ExtractedEntity)
RETURN c.id AS chunk_id,
c.group_id AS group_id,
c.end_user_id AS end_user_id,
c.content AS content,
c.dialog_id AS dialog_id,
c.sequence_number AS sequence_number,
@@ -413,10 +402,10 @@ LIMIT $limit
SEARCH_DIALOGUE_BY_DIALOG_ID = """
MATCH (d:Dialogue)
WHERE ($group_id IS NULL OR d.group_id = $group_id)
WHERE ($end_user_id IS NULL OR d.end_user_id = $end_user_id)
AND d.id = $dialog_id
RETURN d.id AS dialog_id,
d.group_id AS group_id,
d.end_user_id AS end_user_id,
d.content AS content,
d.created_at AS created_at,
d.expired_at AS expired_at
@@ -426,10 +415,10 @@ LIMIT $limit
SEARCH_CHUNK_BY_CHUNK_ID = """
MATCH (c:Chunk)
WHERE ($group_id IS NULL OR c.group_id = $group_id)
WHERE ($end_user_id IS NULL OR c.end_user_id = $end_user_id)
AND c.id = $chunk_id
RETURN c.id AS chunk_id,
c.group_id AS group_id,
c.end_user_id AS end_user_id,
c.content AS content,
c.dialog_id AS dialog_id,
c.created_at AS created_at,
@@ -441,18 +430,14 @@ LIMIT $limit
SEARCH_STATEMENTS_BY_TEMPORAL = """
MATCH (s:Statement)
WHERE ($group_id IS NULL OR s.group_id = $group_id)
AND ($apply_id IS NULL OR s.apply_id = $apply_id)
AND ($user_id IS NULL OR s.user_id = $user_id)
WHERE ($end_user_id IS NULL OR s.end_user_id = $end_user_id)
AND ((($start_date IS NULL OR datetime(s.created_at) >= datetime($start_date))
AND ($end_date IS NULL OR datetime(s.created_at) <= datetime($end_date)))
OR (($valid_date IS NULL OR (s.valid_at IS NOT NULL AND datetime(s.valid_at) >= datetime($valid_date)))
AND ($invalid_date IS NULL OR (s.invalid_at IS NOT NULL AND datetime(s.invalid_at) <= datetime($invalid_date)))))
RETURN s.id AS id,
s.statement AS statement,
s.group_id AS group_id,
s.apply_id AS apply_id,
s.user_id AS user_id,
s.end_user_id AS end_user_id,
s.chunk_id AS chunk_id,
s.created_at AS created_at,
s.valid_at AS valid_at,
@@ -468,9 +453,7 @@ LIMIT $limit
SEARCH_STATEMENTS_BY_KEYWORD_TEMPORAL = """
CALL db.index.fulltext.queryNodes("statementsFulltext", $q) YIELD node AS s, score
WHERE ($group_id IS NULL OR s.group_id = $group_id)
AND ($apply_id IS NULL OR s.apply_id = $apply_id)
AND ($user_id IS NULL OR s.user_id = $user_id)
WHERE ($end_user_id IS NULL OR s.end_user_id = $end_user_id)
AND ((($start_date IS NULL OR (s.created_at IS NOT NULL AND datetime(s.created_at) >= datetime($start_date)))
AND ($end_date IS NULL OR (s.created_at IS NOT NULL AND datetime(s.created_at) <= datetime($end_date))))
OR (($valid_date IS NULL OR (s.valid_at IS NOT NULL AND datetime(s.valid_at) >= datetime($valid_date)))
@@ -479,9 +462,7 @@ OPTIONAL MATCH (c:Chunk)-[:CONTAINS]->(s)
OPTIONAL MATCH (s)-[:REFERENCES_ENTITY]->(e:ExtractedEntity)
RETURN s.id AS id,
s.statement AS statement,
s.group_id AS group_id,
s.apply_id AS apply_id,
s.user_id AS user_id,
s.end_user_id AS end_user_id,
s.chunk_id AS chunk_id,
s.created_at AS created_at,
s.valid_at AS valid_at,
@@ -499,15 +480,11 @@ LIMIT $limit
SEARCH_STATEMENTS_BY_CREATED_AT = """
MATCH (n:Statement)
WHERE ($group_id IS NULL OR n.group_id = $group_id)
AND ($apply_id IS NULL OR n.apply_id = $apply_id)
AND ($user_id IS NULL OR n.user_id = $user_id)
WHERE ($end_user_id IS NULL OR n.end_user_id = $end_user_id)
AND ($created_at IS NOT NULL AND date(substring(n.created_at, 0, 10)) = date($created_at))
RETURN n.id AS id,
n.statement AS statement,
n.group_id AS group_id,
n.apply_id AS apply_id,
n.user_id AS user_id,
n.end_user_id AS end_user_id,
n.chunk_id AS chunk_id,
n.created_at AS created_at,
n.valid_at AS valid_at,
@@ -519,15 +496,11 @@ LIMIT $limit
SEARCH_STATEMENTS_BY_VALID_AT = """
MATCH (n:Statement)
WHERE ($group_id IS NULL OR n.group_id = $group_id)
AND ($apply_id IS NULL OR n.apply_id = $apply_id)
AND ($user_id IS NULL OR n.user_id = $user_id)
WHERE ($end_user_id IS NULL OR n.end_user_id = $end_user_id)
AND ($valid_at IS NOT NULL AND date(substring(n.valid_at, 0, 10)) = date($valid_at))
RETURN n.id AS id,
n.statement AS statement,
n.group_id AS group_id,
n.apply_id AS apply_id,
n.user_id AS user_id,
n.end_user_id AS end_user_id,
n.chunk_id AS chunk_id,
n.created_at AS created_at,
n.valid_at AS valid_at,
@@ -539,15 +512,11 @@ LIMIT $limit
SEARCH_STATEMENTS_G_CREATED_AT = """
MATCH (n:Statement)
WHERE ($group_id IS NULL OR n.group_id = $group_id)
AND ($apply_id IS NULL OR n.apply_id = $apply_id)
AND ($user_id IS NULL OR n.user_id = $user_id)
WHERE ($end_user_id IS NULL OR n.end_user_id = $end_user_id)
AND ($created_at IS NOT NULL AND date(substring(n.created_at, 0, 19)) = date($created_at))
RETURN n.id AS id,
n.statement AS statement,
n.group_id AS group_id,
n.apply_id AS apply_id,
n.user_id AS user_id,
n.end_user_id AS end_user_id,
n.chunk_id AS chunk_id,
n.created_at AS created_at,
n.valid_at AS valid_at,
@@ -559,15 +528,11 @@ LIMIT $limit
SEARCH_STATEMENTS_L_CREATED_AT = """
MATCH (n:Statement)
WHERE ($group_id IS NULL OR n.group_id = $group_id)
AND ($apply_id IS NULL OR n.apply_id = $apply_id)
AND ($user_id IS NULL OR n.user_id = $user_id)
WHERE ($end_user_id IS NULL OR n.end_user_id = $end_user_id)
AND ($created_at IS NOT NULL AND date(substring(n.created_at, 0, 19)) < date($created_at))
RETURN n.id AS id,
n.statement AS statement,
n.group_id AS group_id,
n.apply_id AS apply_id,
n.user_id AS user_id,
n.end_user_id AS end_user_id,
n.chunk_id AS chunk_id,
n.created_at AS created_at,
n.valid_at AS valid_at,
@@ -579,15 +544,11 @@ LIMIT $limit
SEARCH_STATEMENTS_G_VALID_AT = """
MATCH (n:Statement)
WHERE ($group_id IS NULL OR n.group_id = $group_id)
AND ($apply_id IS NULL OR n.apply_id = $apply_id)
AND ($user_id IS NULL OR n.user_id = $user_id)
WHERE ($end_user_id IS NULL OR n.end_user_id = $end_user_id)
AND ($valid_at IS NOT NULL AND date(substring(n.valid_at, 0, 10)) > date($valid_at))
RETURN n.id AS id,
n.statement AS statement,
n.group_id AS group_id,
n.apply_id AS apply_id,
n.user_id AS user_id,
n.end_user_id AS end_user_id,
n.chunk_id AS chunk_id,
n.created_at AS created_at,
n.valid_at AS valid_at,
@@ -599,15 +560,11 @@ LIMIT $limit
SEARCH_STATEMENTS_L_VALID_AT = """
MATCH (n:Statement)
WHERE ($group_id IS NULL OR n.group_id = $group_id)
AND ($apply_id IS NULL OR n.apply_id = $apply_id)
AND ($user_id IS NULL OR n.user_id = $user_id)
WHERE ($end_user_id IS NULL OR n.end_user_id = $end_user_id)
AND ($valid_at IS NOT NULL AND date(substring(n.valid_at, 0, 10)) < date($valid_at))
RETURN n.id AS id,
n.statement AS statement,
n.group_id AS group_id,
n.apply_id AS apply_id,
n.user_id AS user_id,
n.end_user_id AS end_user_id,
n.chunk_id AS chunk_id,
n.created_at AS created_at,
n.valid_at AS valid_at,
@@ -665,18 +622,18 @@ LIMIT $limit
# 根据id修改句子的invalid_at的值
UPDATE_STATEMENT_INVALID_AT = """
MATCH (n:Statement {group_id: $group_id, id: $id})
MATCH (n:Statement {end_user_id: $end_user_id, id: $id})
SET n.invalid_at = $new_invalid_at
"""
# MemorySummary keyword search using fulltext index
SEARCH_MEMORY_SUMMARIES_BY_KEYWORD = """
CALL db.index.fulltext.queryNodes("summariesFulltext", $q) YIELD node AS m, score
WHERE ($group_id IS NULL OR m.group_id = $group_id)
WHERE ($end_user_id IS NULL OR m.end_user_id = $end_user_id)
OPTIONAL MATCH (m)-[:DERIVED_FROM_STATEMENT]->(s:Statement)
RETURN m.id AS id,
m.name AS name,
m.group_id AS group_id,
m.end_user_id AS end_user_id,
m.dialog_id AS dialog_id,
m.chunk_ids AS chunk_ids,
m.content AS content,
@@ -695,10 +652,10 @@ MEMORY_SUMMARY_EMBEDDING_SEARCH = """
CALL db.index.vector.queryNodes('summary_embedding_index', $limit * 100, $embedding)
YIELD node AS m, score
WHERE m.summary_embedding IS NOT NULL
AND ($group_id IS NULL OR m.group_id = $group_id)
AND ($end_user_id IS NULL OR m.end_user_id = $end_user_id)
RETURN m.id AS id,
m.name AS name,
m.group_id AS group_id,
m.end_user_id AS end_user_id,
m.dialog_id AS dialog_id,
m.chunk_ids AS chunk_ids,
m.content AS content,
@@ -718,9 +675,7 @@ MERGE (m:MemorySummary {id: summary.id})
SET m += {
id: summary.id,
name: summary.name,
group_id: summary.group_id,
user_id: summary.user_id,
apply_id: summary.apply_id,
end_user_id: summary.end_user_id,
run_id: summary.run_id,
created_at: summary.created_at,
expired_at: summary.expired_at,
@@ -745,7 +700,7 @@ MATCH (ms:MemorySummary {id: e.summary_id, run_id: e.run_id})
MATCH (c:Chunk {id: e.chunk_id, run_id: e.run_id})
MATCH (c)-[:CONTAINS]->(s:Statement {run_id: e.run_id})
MERGE (ms)-[r:DERIVED_FROM_STATEMENT]->(s)
SET r.group_id = e.group_id,
SET r.end_user_id = e.end_user_id,
r.run_id = e.run_id,
r.created_at = e.created_at,
r.expired_at = e.expired_at
@@ -774,7 +729,7 @@ FOREACH (rel IN CASE WHEN r IS NOT NULL THEN [r] ELSE [] END |
source_statement_id: rel.source_statement_id,
valid_at: rel.valid_at,
invalid_at: rel.invalid_at,
group_id: rel.group_id,
end_user_id: rel.end_user_id,
user_id: rel.user_id,
apply_id: rel.apply_id,
run_id: rel.run_id,
@@ -796,7 +751,7 @@ FOREACH (rel IN CASE WHEN r IS NOT NULL THEN [r] ELSE [] END |
source_statement_id: rel.source_statement_id,
valid_at: rel.valid_at,
invalid_at: rel.invalid_at,
group_id: rel.group_id,
end_user_id: rel.end_user_id,
user_id: rel.user_id,
apply_id: rel.apply_id,
run_id: rel.run_id,
@@ -814,7 +769,7 @@ RETURN count(losing) as deleted
neo4j_statement_part = '''
MATCH (n:Statement)
WHERE n.group_id = "{}"
WHERE n.end_user_id = "{}"
AND datetime(n.created_at) >= datetime() - duration('P3D')
RETURN
n.statement as statement_name,
@@ -824,7 +779,7 @@ RETURN
'''
neo4j_statement_all = '''
MATCH (n:Statement)
WHERE n.group_id = "{}"
WHERE n.end_user_id = "{}"
RETURN
n.statement as statement_name,
n.id as statement_id
@@ -832,7 +787,7 @@ RETURN
'''
neo4j_query_part = """
MATCH (n)-[r]-(m:ExtractedEntity)
WHERE n.group_id = "{}"
WHERE n.end_user_id = "{}"
AND datetime(n.created_at) >= datetime() - duration('P3D')
WITH DISTINCT m
OPTIONAL MATCH (m)-[rel]-(other:ExtractedEntity)
@@ -853,7 +808,7 @@ neo4j_query_part = """
"""
neo4j_query_all = """
MATCH (n)-[r]-(m:ExtractedEntity)
WHERE n.group_id = "{}"
WHERE n.end_user_id = "{}"
WITH DISTINCT m
OPTIONAL MATCH (m)-[rel]-(other:ExtractedEntity)
RETURN
@@ -1027,14 +982,14 @@ RETURN DISTINCT
Memory_Space_User="""
MATCH (n)-[r]->(m)
WHERE n.group_id = $group_id AND m.name="用户"
WHERE n.end_user_id = $end_user_id AND m.name="用户"
return DISTINCT elementId(m) as id
"""
Memory_Space_Entity="""
MATCH (n)-[]-(m)
WHERE elementId(m) = $id AND m.entity_type = "Person"
RETURN
DISTINCT m.name as name,m.group_id as group_id
DISTINCT m.name as name,m.end_user_id as end_user_id
"""
Memory_Space_Associative="""
MATCH (u)-[]-(x)-[]-(h)

View File

@@ -19,7 +19,7 @@ class DialogRepository(BaseNeo4jRepository[DialogueNode]):
"""对话仓储
管理对话节点的创建、查询、更新和删除操作。
提供按group_id、user_id、ref_id等条件查询对话的方法。
提供按end_user_id、user_id、ref_id等条件查询对话的方法。
Attributes:
connector: Neo4j连接器实例
@@ -54,17 +54,17 @@ class DialogRepository(BaseNeo4jRepository[DialogueNode]):
return DialogueNode(**n)
async def find_by_group_id(self, group_id: str, limit: int = 100) -> List[DialogueNode]:
"""根据group_id查询对话
async def find_by_end_user_id(self, end_user_id: str, limit: int = 100) -> List[DialogueNode]:
"""根据end_user_id查询对话
Args:
group_id: 组ID
end_user_id: 组ID
limit: 返回结果的最大数量
Returns:
List[DialogueNode]: 对话列表
"""
return await self.find({"group_id": group_id}, limit=limit)
return await self.find({"end_user_id": end_user_id}, limit=limit)
async def find_by_user_id(self, user_id: str, limit: int = 100) -> List[DialogueNode]:
"""根据user_id查询对话
@@ -94,14 +94,14 @@ class DialogRepository(BaseNeo4jRepository[DialogueNode]):
async def find_by_group_and_user(
self,
group_id: str,
end_user_id: str,
user_id: str,
limit: int = 100
) -> List[DialogueNode]:
"""根据group_id和user_id查询对话
"""根据end_user_id和user_id查询对话
Args:
group_id: 组ID
end_user_id: 组ID
user_id: 用户ID
limit: 返回结果的最大数量
@@ -109,20 +109,20 @@ class DialogRepository(BaseNeo4jRepository[DialogueNode]):
List[DialogueNode]: 对话列表
"""
return await self.find(
{"group_id": group_id, "user_id": user_id},
{"end_user_id": end_user_id, "user_id": user_id},
limit=limit
)
async def find_recent_dialogs(
self,
group_id: str,
end_user_id: str,
days: int = 7,
limit: int = 100
) -> List[DialogueNode]:
"""查询最近的对话
Args:
group_id: 组ID
end_user_id: 组ID
days: 查询最近多少天的对话
limit: 返回结果的最大数量
@@ -131,7 +131,7 @@ class DialogRepository(BaseNeo4jRepository[DialogueNode]):
"""
query = f"""
MATCH (n:{self.node_label})
WHERE n.group_id = $group_id
WHERE n.end_user_id = $end_user_id
AND n.created_at >= datetime() - duration({{days: $days}})
RETURN n
ORDER BY n.created_at DESC
@@ -139,7 +139,7 @@ class DialogRepository(BaseNeo4jRepository[DialogueNode]):
"""
results = await self.connector.execute_query(
query,
group_id=group_id,
end_user_id=end_user_id,
days=days,
limit=limit
)
@@ -164,22 +164,22 @@ class DialogRepository(BaseNeo4jRepository[DialogueNode]):
async def find_by_config_and_group(
self,
config_id: str,
group_id: str,
end_user_id: str,
limit: int = 100
) -> List[DialogueNode]:
"""根据config_id和group_id查询对话
"""根据config_id和end_user_id查询对话
支持按配置ID和组ID同时过滤,确保只返回使用特定配置处理的对话。
Args:
config_id: 配置ID
group_id: 组ID
end_user_id: 组ID
limit: 返回结果的最大数量
Returns:
List[DialogueNode]: 对话列表
"""
return await self.find(
{"config_id": config_id, "group_id": group_id},
{"config_id": config_id, "end_user_id": end_user_id},
limit=limit
)

View File

@@ -40,7 +40,7 @@ class EmotionRepository:
async def get_emotion_tags(
self,
group_id: str,
end_user_id: str,
emotion_type: Optional[str] = None,
start_date: Optional[str] = None,
end_date: Optional[str] = None,
@@ -51,7 +51,7 @@ class EmotionRepository:
查询指定用户的情绪类型分布,包括计数、百分比和平均强度。
Args:
group_id: 用户组ID宿主ID
end_user_id: 用户组ID宿主ID
emotion_type: 可选的情绪类型过滤joy/sadness/anger/fear/surprise/neutral
start_date: 可选的开始日期ISO格式字符串
end_date: 可选的结束日期ISO格式字符串
@@ -65,8 +65,8 @@ class EmotionRepository:
- avg_intensity: 平均强度
"""
# 构建查询条件
where_clauses = ["s.group_id = $group_id", "s.emotion_type IS NOT NULL"]
params = {"group_id": group_id, "limit": limit}
where_clauses = ["s.end_user_id = $end_user_id", "s.emotion_type IS NOT NULL"]
params = {"end_user_id": end_user_id, "limit": limit}
if emotion_type:
where_clauses.append("s.emotion_type = $emotion_type")
@@ -119,7 +119,7 @@ class EmotionRepository:
async def get_emotion_wordcloud(
self,
group_id: str,
end_user_id: str,
emotion_type: Optional[str] = None,
limit: int = 50
) -> List[Dict[str, Any]]:
@@ -128,7 +128,7 @@ class EmotionRepository:
查询情绪关键词及其频率,用于生成词云可视化。
Args:
group_id: 用户组ID宿主ID
end_user_id: 用户组ID宿主ID
emotion_type: 可选的情绪类型过滤
limit: 返回关键词的最大数量
@@ -140,8 +140,8 @@ class EmotionRepository:
- avg_intensity: 平均强度
"""
# 构建查询条件
where_clauses = ["s.group_id = $group_id", "s.emotion_keywords IS NOT NULL"]
params = {"group_id": group_id, "limit": limit}
where_clauses = ["s.end_user_id = $end_user_id", "s.emotion_keywords IS NOT NULL"]
params = {"end_user_id": end_user_id, "limit": limit}
if emotion_type:
where_clauses.append("s.emotion_type = $emotion_type")
@@ -186,7 +186,7 @@ class EmotionRepository:
async def get_emotions_in_range(
self,
group_id: str,
end_user_id: str,
time_range: str = "30d"
) -> List[Dict[str, Any]]:
"""获取时间范围内的情绪数据
@@ -194,7 +194,7 @@ class EmotionRepository:
查询指定时间范围内的所有情绪数据,用于健康指数计算。
Args:
group_id: 用户组ID宿主ID
end_user_id: 用户组ID宿主ID
time_range: 时间范围7d/30d/90d
Returns:
@@ -214,7 +214,7 @@ class EmotionRepository:
# 优化的 Cypher 查询:使用字符串比较避免时区问题
query = """
MATCH (s:Statement)
WHERE s.group_id = $group_id
WHERE s.end_user_id = $end_user_id
AND s.emotion_type IS NOT NULL
AND s.created_at >= $start_date
RETURN s.id as statement_id,
@@ -227,7 +227,7 @@ class EmotionRepository:
try:
results = await self.connector.execute_query(
query,
group_id=group_id,
end_user_id=end_user_id,
start_date=start_date
)
formatted_results = [

View File

@@ -44,9 +44,7 @@ async def save_entities_and_relationships(
'created_at': edge.created_at.isoformat(),
'expired_at': edge.expired_at.isoformat(),
'run_id': edge.run_id,
'group_id': edge.group_id,
'user_id': edge.user_id,
'apply_id': edge.apply_id,
'end_user_id': edge.end_user_id,
}
all_relationships.append(relationship)
@@ -101,9 +99,7 @@ async def save_statement_chunk_edges(
"id": edge.id,
"source": edge.source,
"target": edge.target,
"group_id": edge.group_id,
"user_id": edge.user_id,
"apply_id": edge.apply_id,
"end_user_id": edge.end_user_id,
"run_id": edge.run_id,
"created_at": edge.created_at.isoformat() if edge.created_at else None,
"expired_at": edge.expired_at.isoformat() if edge.expired_at else None,
@@ -132,9 +128,7 @@ async def save_statement_entity_edges(
edge_data = {
"source": edge.source,
"target": edge.target,
"group_id": edge.group_id,
"user_id": edge.user_id,
"apply_id": edge.apply_id,
"end_user_id": edge.end_user_id,
"run_id": edge.run_id,
"connect_strength": edge.connect_strength,
"created_at": edge.created_at.isoformat() if edge.created_at else None,

View File

@@ -33,7 +33,7 @@ async def _update_activation_values_batch(
connector: Neo4jConnector,
nodes: List[Dict[str, Any]],
node_label: str,
group_id: Optional[str] = None,
end_user_id: Optional[str] = None,
max_retries: int = 3
) -> List[Dict[str, Any]]:
"""
@@ -46,7 +46,7 @@ async def _update_activation_values_batch(
connector: Neo4j连接器
nodes: 节点列表,每个节点必须包含 'id' 字段
node_label: 节点标签Statement, ExtractedEntity, MemorySummary
group_id: 组ID可选
end_user_id: 组ID可选
max_retries: 最大重试次数
Returns:
@@ -97,7 +97,7 @@ async def _update_activation_values_batch(
updated_nodes = await access_manager.record_batch_access(
node_ids=unique_node_ids,
node_label=node_label,
group_id=group_id
end_user_id=end_user_id
)
logger.info(
@@ -118,7 +118,7 @@ async def _update_activation_values_batch(
async def _update_search_results_activation(
connector: Neo4jConnector,
results: Dict[str, List[Dict[str, Any]]],
group_id: Optional[str] = None
end_user_id: Optional[str] = None
) -> Dict[str, List[Dict[str, Any]]]:
"""
更新搜索结果中所有知识节点的激活值
@@ -129,7 +129,7 @@ async def _update_search_results_activation(
Args:
connector: Neo4j连接器
results: 搜索结果字典,包含不同类型节点的列表
group_id: 组ID可选
end_user_id: 组ID可选
Returns:
Dict[str, List[Dict[str, Any]]]: 更新后的搜索结果
@@ -152,7 +152,7 @@ async def _update_search_results_activation(
connector=connector,
nodes=results[key],
node_label=label,
group_id=group_id
end_user_id=end_user_id
)
)
update_keys.append(key)
@@ -218,7 +218,7 @@ async def _update_search_results_activation(
async def search_graph(
connector: Neo4jConnector,
q: str,
group_id: Optional[str] = None,
end_user_id: Optional[str] = None,
limit: int = 50,
include: List[str] = None,
) -> Dict[str, List[Dict[str, Any]]]:
@@ -236,7 +236,7 @@ async def search_graph(
Args:
connector: Neo4j connector
q: Query text
group_id: Optional group filter
end_user_id: Optional group filter
limit: Max results per category
include: List of categories to search (default: all)
@@ -254,7 +254,7 @@ async def search_graph(
tasks.append(connector.execute_query(
SEARCH_STATEMENTS_BY_KEYWORD,
q=q,
group_id=group_id,
end_user_id=end_user_id,
limit=limit,
))
task_keys.append("statements")
@@ -263,7 +263,7 @@ async def search_graph(
tasks.append(connector.execute_query(
SEARCH_ENTITIES_BY_NAME,
q=q,
group_id=group_id,
end_user_id=end_user_id,
limit=limit,
))
task_keys.append("entities")
@@ -272,7 +272,7 @@ async def search_graph(
tasks.append(connector.execute_query(
SEARCH_CHUNKS_BY_CONTENT,
q=q,
group_id=group_id,
end_user_id=end_user_id,
limit=limit,
))
task_keys.append("chunks")
@@ -281,7 +281,7 @@ async def search_graph(
tasks.append(connector.execute_query(
SEARCH_MEMORY_SUMMARIES_BY_KEYWORD,
q=q,
group_id=group_id,
end_user_id=end_user_id,
limit=limit,
))
task_keys.append("summaries")
@@ -310,12 +310,12 @@ async def search_graph(
key in include and key in results and results[key]
for key in ['statements', 'entities', 'chunks']
)
if needs_activation_update:
results = await _update_search_results_activation(
connector=connector,
results=results,
group_id=group_id
end_user_id=end_user_id
)
return results
@@ -325,7 +325,7 @@ async def search_graph_by_embedding(
connector: Neo4jConnector,
embedder_client,
query_text: str,
group_id: Optional[str] = None,
end_user_id: Optional[str] = None,
limit: int = 50,
include: List[str] = ["statements", "chunks", "entities","summaries"],
) -> Dict[str, List[Dict[str, Any]]]:
@@ -337,7 +337,7 @@ async def search_graph_by_embedding(
- Computes query embedding with the provided embedder_client
- Ranks by cosine similarity in Cypher
- Filters by group_id if provided
- Filters by end_user_id if provided
- Returns up to 'limit' per included type
"""
import time
@@ -346,7 +346,7 @@ async def search_graph_by_embedding(
embed_start = time.time()
embeddings = await embedder_client.response([query_text])
embed_time = time.time() - embed_start
logger.info(f"[PERF] Embedding generation took: {embed_time:.4f}s")
print(f"[PERF] Embedding generation took: {embed_time:.4f}s")
if not embeddings or not embeddings[0]:
return {"statements": [], "chunks": [], "entities": [], "summaries": []}
@@ -361,7 +361,7 @@ async def search_graph_by_embedding(
tasks.append(connector.execute_query(
STATEMENT_EMBEDDING_SEARCH,
embedding=embedding,
group_id=group_id,
end_user_id=end_user_id,
limit=limit,
))
task_keys.append("statements")
@@ -371,7 +371,7 @@ async def search_graph_by_embedding(
tasks.append(connector.execute_query(
CHUNK_EMBEDDING_SEARCH,
embedding=embedding,
group_id=group_id,
end_user_id=end_user_id,
limit=limit,
))
task_keys.append("chunks")
@@ -381,7 +381,7 @@ async def search_graph_by_embedding(
tasks.append(connector.execute_query(
ENTITY_EMBEDDING_SEARCH,
embedding=embedding,
group_id=group_id,
end_user_id=end_user_id,
limit=limit,
))
task_keys.append("entities")
@@ -391,7 +391,7 @@ async def search_graph_by_embedding(
tasks.append(connector.execute_query(
MEMORY_SUMMARY_EMBEDDING_SEARCH,
embedding=embedding,
group_id=group_id,
end_user_id=end_user_id,
limit=limit,
))
task_keys.append("summaries")
@@ -400,7 +400,7 @@ async def search_graph_by_embedding(
query_start = time.time()
task_results = await asyncio.gather(*tasks, return_exceptions=True)
query_time = time.time() - query_start
logger.info(f"[PERF] Neo4j queries (parallel) took: {query_time:.4f}s")
print(f"[PERF] Neo4j queries (parallel) took: {query_time:.4f}s")
# Build results dictionary
results: Dict[str, List[Dict[str, Any]]] = {
@@ -429,13 +429,13 @@ async def search_graph_by_embedding(
key in include and key in results and results[key]
for key in ['statements', 'entities', 'chunks']
)
if needs_activation_update:
update_start = time.time()
results = await _update_search_results_activation(
connector=connector,
results=results,
group_id=group_id
end_user_id=end_user_id
)
update_time = time.time() - update_start
logger.info(f"[PERF] Activation value updates took: {update_time:.4f}s")
@@ -445,7 +445,7 @@ async def search_graph_by_embedding(
return results
async def get_dedup_candidates_for_entities( # 适配新版查询:使用全文索引按名称检索候选实体
connector: Neo4jConnector,
group_id: str,
end_user_id: str,
entities: List[Dict[str, Any]],
use_contains_fallback: bool = True,
batch_size: int = 500,
@@ -453,7 +453,7 @@ async def get_dedup_candidates_for_entities( # 适配新版查询:使用全
) -> Dict[str, List[Dict[str, Any]]]:
"""
为第二层去重消歧批量检索候选实体(适配新版 cypher_queries
- 使用全文索引查询 `SEARCH_ENTITIES_BY_NAME` 按 (group_id, name) 检索候选;
- 使用全文索引查询 `SEARCH_ENTITIES_BY_NAME` 按 (end_user_id, name) 检索候选;
- 保留并发控制与返回结构incoming_id -> [db_entity_props...]
- 若提供 `entity_type`,在本地对返回结果做类型过滤;
- `use_contains_fallback` 保留形参以兼容,必要时可扩展二次查询策略。
@@ -477,7 +477,7 @@ async def get_dedup_candidates_for_entities( # 适配新版查询:使用全
rows = await connector.execute_query(
SEARCH_ENTITIES_BY_NAME,
q=name,
group_id=group_id,
end_user_id=end_user_id,
limit=100,
)
except Exception:
@@ -501,7 +501,7 @@ async def get_dedup_candidates_for_entities( # 适配新版查询:使用全
rows = await connector.execute_query(
SEARCH_ENTITIES_BY_NAME,
q=name.lower(),
group_id=group_id,
end_user_id=end_user_id,
limit=100,
)
for r in rows:
@@ -532,9 +532,7 @@ async def get_dedup_candidates_for_entities( # 适配新版查询:使用全
async def search_graph_by_keyword_temporal(
connector: Neo4jConnector,
query_text: str,
group_id: Optional[str] = None,
apply_id: Optional[str] = None,
user_id: Optional[str] = None,
end_user_id: Optional[str] = None,
start_date: Optional[str] = None,
end_date: Optional[str] = None,
valid_date: Optional[str] = None,
@@ -547,32 +545,30 @@ async def search_graph_by_keyword_temporal(
INTEGRATED: Updates activation values for Statement nodes before returning results
- Matches statements containing query_text created between start_date and end_date
- Optionally filters by group_id, apply_id, user_id
- Optionally filters by end_user_id, apply_id, user_id
- Returns up to 'limit' statements
"""
if not query_text:
logger.warning(f"query_text cannot be empty")
print(f"query_text不能为空")
return {"statements": []}
statements = await connector.execute_query(
SEARCH_STATEMENTS_BY_KEYWORD_TEMPORAL,
q=query_text,
group_id=group_id,
apply_id=apply_id,
user_id=user_id,
end_user_id=end_user_id,
start_date=start_date,
end_date=end_date,
valid_date=valid_date,
invalid_date=invalid_date,
limit=limit,
)
logger.debug(f"Temporal keyword search results: {len(statements)} statements found")
print(f"查询结果为:\n{statements}")
# 更新 Statement 节点的激活值
results = {"statements": statements}
results = await _update_search_results_activation(
connector=connector,
results=results,
group_id=group_id
end_user_id=end_user_id
)
return results
@@ -580,9 +576,7 @@ async def search_graph_by_keyword_temporal(
async def search_graph_by_temporal(
connector: Neo4jConnector,
group_id: Optional[str] = None,
apply_id: Optional[str] = None,
user_id: Optional[str] = None,
end_user_id: Optional[str] = None,
start_date: Optional[str] = None,
end_date: Optional[str] = None,
valid_date: Optional[str] = None,
@@ -595,14 +589,12 @@ async def search_graph_by_temporal(
INTEGRATED: Updates activation values for Statement nodes before returning results
- Matches statements created between start_date and end_date
- Optionally filters by group_id, apply_id, user_id
- Optionally filters by end_user_id
- Returns up to 'limit' statements
"""
statements = await connector.execute_query(
SEARCH_STATEMENTS_BY_TEMPORAL,
group_id=group_id,
apply_id=apply_id,
user_id=user_id,
end_user_id=end_user_id,
start_date=start_date,
end_date=end_date,
valid_date=valid_date,
@@ -610,16 +602,16 @@ async def search_graph_by_temporal(
limit=limit,
)
logger.debug(f"Temporal search query: {SEARCH_STATEMENTS_BY_TEMPORAL}")
logger.debug(f"Query params: group_id={group_id}, apply_id={apply_id}, user_id={user_id}, start_date={start_date}, end_date={end_date}, valid_date={valid_date}, invalid_date={invalid_date}, limit={limit}")
logger.debug(f"Temporal search results: {len(statements)} statements found")
print(f"查询语句为:\n{SEARCH_STATEMENTS_BY_TEMPORAL}")
print(f"查询参数为:\n{{end_user_id: {end_user_id}, start_date: {start_date}, end_date: {end_date}, valid_date: {valid_date}, invalid_date: {invalid_date}, limit: {limit}}}")
print(f"查询结果为:\n{statements}")
# 更新 Statement 节点的激活值
results = {"statements": statements}
results = await _update_search_results_activation(
connector=connector,
results=results,
group_id=group_id
end_user_id=end_user_id
)
return results
@@ -628,23 +620,23 @@ async def search_graph_by_temporal(
async def search_graph_by_dialog_id(
connector: Neo4jConnector,
dialog_id: str,
group_id: Optional[str] = None,
end_user_id: Optional[str] = None,
limit: int = 1,
) -> Dict[str, List[Dict[str, Any]]]:
"""
Temporal search across Dialogues.
- Matches dialogues with dialog_id
- Optionally filters by group_id
- Optionally filters by end_user_id
- Returns up to 'limit' dialogues
"""
if not dialog_id:
logger.warning(f"dialog_id cannot be empty")
print(f"dialog_id不能为空")
return {"dialogues": []}
dialogues = await connector.execute_query(
SEARCH_DIALOGUE_BY_DIALOG_ID,
group_id=group_id,
end_user_id=end_user_id,
dialog_id=dialog_id,
limit=limit,
)
@@ -654,15 +646,15 @@ async def search_graph_by_dialog_id(
async def search_graph_by_chunk_id(
connector: Neo4jConnector,
chunk_id : str,
group_id: Optional[str] = None,
end_user_id: Optional[str] = None,
limit: int = 1,
) -> Dict[str, List[Dict[str, Any]]]:
if not chunk_id:
logger.warning(f"chunk_id cannot be empty")
print(f"chunk_id不能为空")
return {"chunks": []}
chunks = await connector.execute_query(
SEARCH_CHUNK_BY_CHUNK_ID,
group_id=group_id,
end_user_id=end_user_id,
chunk_id=chunk_id,
limit=limit,
)
@@ -671,9 +663,9 @@ async def search_graph_by_chunk_id(
async def search_graph_by_created_at(
connector: Neo4jConnector,
group_id: Optional[str] = None,
apply_id: Optional[str] = None,
user_id: Optional[str] = None,
end_user_id: Optional[str] = None,
created_at: Optional[str] = None,
limit: int = 1,
) -> Dict[str, List[Dict[str, Any]]]:
@@ -683,37 +675,37 @@ async def search_graph_by_created_at(
INTEGRATED: Updates activation values for Statement nodes before returning results
- Matches statements created at created_at
- Optionally filters by group_id, apply_id, user_id
- Optionally filters by end_user_id, apply_id, user_id
- Returns up to 'limit' statements
"""
statements = await connector.execute_query(
SEARCH_STATEMENTS_BY_CREATED_AT,
group_id=group_id,
apply_id=apply_id,
user_id=user_id,
end_user_id=end_user_id,
created_at=created_at,
limit=limit,
)
logger.debug(f"Search by created_at query: {SEARCH_STATEMENTS_BY_CREATED_AT}")
logger.debug(f"Query params: group_id={group_id}, apply_id={apply_id}, user_id={user_id}, created_at={created_at}, limit={limit}")
logger.debug(f"Search results: {len(statements)} statements found")
print(f"查询语句为:\n{SEARCH_STATEMENTS_BY_CREATED_AT}")
print(f"查询参数为:\n{{end_user_id: {end_user_id} created_at: {created_at}, limit: {limit}}}")
print(f"查询结果为:\n{statements}")
# 更新 Statement 节点的激活值
results = {"statements": statements}
results = await _update_search_results_activation(
connector=connector,
results=results,
group_id=group_id
end_user_id=end_user_id
)
return results
async def search_graph_by_valid_at(
connector: Neo4jConnector,
group_id: Optional[str] = None,
apply_id: Optional[str] = None,
user_id: Optional[str] = None,
end_user_id: Optional[str] = None,
valid_at: Optional[str] = None,
limit: int = 1,
) -> Dict[str, List[Dict[str, Any]]]:
@@ -723,37 +715,37 @@ async def search_graph_by_valid_at(
INTEGRATED: Updates activation values for Statement nodes before returning results
- Matches statements valid at valid_at
- Optionally filters by group_id, apply_id, user_id
- Optionally filters by end_user_id, apply_id, user_id
- Returns up to 'limit' statements
"""
statements = await connector.execute_query(
SEARCH_STATEMENTS_BY_VALID_AT,
group_id=group_id,
apply_id=apply_id,
user_id=user_id,
end_user_id=end_user_id,
valid_at=valid_at,
limit=limit,
)
logger.debug(f"Search by valid_at query: {SEARCH_STATEMENTS_BY_VALID_AT}")
logger.debug(f"Query params: group_id={group_id}, apply_id={apply_id}, user_id={user_id}, valid_at={valid_at}, limit={limit}")
logger.debug(f"Search results: {len(statements)} statements found")
print(f"查询语句为:\n{SEARCH_STATEMENTS_BY_VALID_AT}")
print(f"查询参数为:\n{{end_user_id: {end_user_id} valid_at: {valid_at}, limit: {limit}}}")
print(f"查询结果为:\n{statements}")
# 更新 Statement 节点的激活值
results = {"statements": statements}
results = await _update_search_results_activation(
connector=connector,
results=results,
group_id=group_id
end_user_id=end_user_id
)
return results
async def search_graph_g_created_at(
connector: Neo4jConnector,
group_id: Optional[str] = None,
apply_id: Optional[str] = None,
user_id: Optional[str] = None,
end_user_id: Optional[str] = None,
created_at: Optional[str] = None,
limit: int = 1,
) -> Dict[str, List[Dict[str, Any]]]:
@@ -763,37 +755,37 @@ async def search_graph_g_created_at(
INTEGRATED: Updates activation values for Statement nodes before returning results
- Matches statements created at created_at
- Optionally filters by group_id, apply_id, user_id
- Optionally filters by end_user_id, apply_id, user_id
- Returns up to 'limit' statements
"""
statements = await connector.execute_query(
SEARCH_STATEMENTS_G_CREATED_AT,
group_id=group_id,
apply_id=apply_id,
user_id=user_id,
end_user_id=end_user_id,
created_at=created_at,
limit=limit,
)
logger.debug(f"Search greater than created_at query: {SEARCH_STATEMENTS_G_CREATED_AT}")
logger.debug(f"Query params: group_id={group_id}, apply_id={apply_id}, user_id={user_id}, created_at={created_at}, limit={limit}")
logger.debug(f"Search results: {len(statements)} statements found")
print(f"查询语句为:\n{SEARCH_STATEMENTS_G_CREATED_AT}")
print(f"查询参数为:\n{{end_user_id: {end_user_id}, created_at: {created_at}, limit: {limit}}}")
print(f"查询结果为:\n{statements}")
# 更新 Statement 节点的激活值
results = {"statements": statements}
results = await _update_search_results_activation(
connector=connector,
results=results,
group_id=group_id
end_user_id=end_user_id
)
return results
async def search_graph_g_valid_at(
connector: Neo4jConnector,
group_id: Optional[str] = None,
apply_id: Optional[str] = None,
user_id: Optional[str] = None,
end_user_id: Optional[str] = None,
valid_at: Optional[str] = None,
limit: int = 1,
) -> Dict[str, List[Dict[str, Any]]]:
@@ -803,37 +795,37 @@ async def search_graph_g_valid_at(
INTEGRATED: Updates activation values for Statement nodes before returning results
- Matches statements valid at valid_at
- Optionally filters by group_id, apply_id, user_id
- Optionally filters by end_user_id, apply_id, user_id
- Returns up to 'limit' statements
"""
statements = await connector.execute_query(
SEARCH_STATEMENTS_G_VALID_AT,
group_id=group_id,
apply_id=apply_id,
user_id=user_id,
end_user_id=end_user_id,
valid_at=valid_at,
limit=limit,
)
logger.debug(f"Search greater than valid_at query: {SEARCH_STATEMENTS_G_VALID_AT}")
logger.debug(f"Query params: group_id={group_id}, apply_id={apply_id}, user_id={user_id}, valid_at={valid_at}, limit={limit}")
logger.debug(f"Search results: {len(statements)} statements found")
print(f"查询语句为:\n{SEARCH_STATEMENTS_G_VALID_AT}")
print(f"查询参数为:\n{{end_user_id: {end_user_id}, valid_at: {valid_at}, limit: {limit}}}")
print(f"查询结果为:\n{statements}")
# 更新 Statement 节点的激活值
results = {"statements": statements}
results = await _update_search_results_activation(
connector=connector,
results=results,
group_id=group_id
end_user_id=end_user_id
)
return results
async def search_graph_l_created_at(
connector: Neo4jConnector,
group_id: Optional[str] = None,
apply_id: Optional[str] = None,
user_id: Optional[str] = None,
end_user_id: Optional[str] = None,
created_at: Optional[str] = None,
limit: int = 1,
) -> Dict[str, List[Dict[str, Any]]]:
@@ -843,37 +835,37 @@ async def search_graph_l_created_at(
INTEGRATED: Updates activation values for Statement nodes before returning results
- Matches statements created at created_at
- Optionally filters by group_id, apply_id, user_id
- Optionally filters by end_user_id, apply_id, user_id
- Returns up to 'limit' statements
"""
statements = await connector.execute_query(
SEARCH_STATEMENTS_L_CREATED_AT,
group_id=group_id,
apply_id=apply_id,
user_id=user_id,
end_user_id=end_user_id,
created_at=created_at,
limit=limit,
)
logger.debug(f"Search less than created_at query: {SEARCH_STATEMENTS_L_CREATED_AT}")
logger.debug(f"Query params: group_id={group_id}, apply_id={apply_id}, user_id={user_id}, created_at={created_at}, limit={limit}")
logger.debug(f"Search results: {len(statements)} statements found")
print(f"查询语句为:\n{SEARCH_STATEMENTS_L_CREATED_AT}")
print(f"查询参数为:\n{{end_user_id: {end_user_id}, created_at: {created_at}, limit: {limit}}}")
print(f"查询结果为:\n{statements}")
# 更新 Statement 节点的激活值
results = {"statements": statements}
results = await _update_search_results_activation(
connector=connector,
results=results,
group_id=group_id
end_user_id=end_user_id
)
return results
async def search_graph_l_valid_at(
connector: Neo4jConnector,
group_id: Optional[str] = None,
apply_id: Optional[str] = None,
user_id: Optional[str] = None,
end_user_id: Optional[str] = None,
valid_at: Optional[str] = None,
limit: int = 1,
) -> Dict[str, List[Dict[str, Any]]]:
@@ -883,28 +875,28 @@ async def search_graph_l_valid_at(
INTEGRATED: Updates activation values for Statement nodes before returning results
- Matches statements valid at valid_at
- Optionally filters by group_id, apply_id, user_id
- Optionally filters by end_user_id, apply_id, user_id
- Returns up to 'limit' statements
"""
statements = await connector.execute_query(
SEARCH_STATEMENTS_L_VALID_AT,
group_id=group_id,
apply_id=apply_id,
user_id=user_id,
end_user_id=end_user_id,
valid_at=valid_at,
limit=limit,
)
logger.debug(f"Search less than valid_at query: {SEARCH_STATEMENTS_L_VALID_AT}")
logger.debug(f"Query params: group_id={group_id}, apply_id={apply_id}, user_id={user_id}, valid_at={valid_at}, limit={limit}")
logger.debug(f"Search results: {len(statements)} statements found")
print(f"查询语句为:\n{SEARCH_STATEMENTS_L_VALID_AT}")
print(f"查询参数为:\n{{end_user_id: {end_user_id}, valid_at: {valid_at}, limit: {limit}}}")
print(f"查询结果为:\n{statements}")
# 更新 Statement 节点的激活值
results = {"statements": statements}
results = await _update_search_results_activation(
connector=connector,
results=results,
group_id=group_id
end_user_id=end_user_id
)
return results

View File

@@ -18,7 +18,7 @@ class MemorySummaryRepository(BaseNeo4jRepository):
"""Memory Summary Repository
Manages CRUD operations for MemorySummary nodes.
Provides methods to query summaries by group_id, user_id, and time ranges.
Provides methods to query summaries by end_user_id, user_id, and time ranges.
Attributes:
connector: Neo4j connector instance
@@ -51,17 +51,17 @@ class MemorySummaryRepository(BaseNeo4jRepository):
return dict(n)
async def find_by_group_id(
async def find_by_end_user_id(
self,
group_id: str,
end_user_id: str,
limit: int = 1000,
start_date: Optional[datetime] = None,
end_date: Optional[datetime] = None
) -> List[Dict[str, Any]]:
"""Query memory summaries by group_id
"""Query memory summaries by end_user_id
Args:
group_id: Group ID to filter by
end_user_id: Group ID to filter by
limit: Maximum number of results to return
start_date: Optional start date filter
end_date: Optional end date filter
@@ -71,10 +71,10 @@ class MemorySummaryRepository(BaseNeo4jRepository):
"""
query = f"""
MATCH (n:{self.node_label})
WHERE n.group_id = $group_id
WHERE n.end_user_id = $end_user_id
"""
params = {"group_id": group_id, "limit": limit}
params = {"end_user_id": end_user_id, "limit": limit}
# Add date range filters if provided
if start_date:
@@ -139,16 +139,16 @@ class MemorySummaryRepository(BaseNeo4jRepository):
async def find_by_group_and_user(
self,
group_id: str,
end_user_id: str,
user_id: str,
limit: int = 1000,
start_date: Optional[datetime] = None,
end_date: Optional[datetime] = None
) -> List[Dict[str, Any]]:
"""Query memory summaries by both group_id and user_id
"""Query memory summaries by both end_user_id and user_id
Args:
group_id: Group ID to filter by
end_user_id: Group ID to filter by
user_id: User ID to filter by
limit: Maximum number of results to return
start_date: Optional start date filter
@@ -159,10 +159,10 @@ class MemorySummaryRepository(BaseNeo4jRepository):
"""
query = f"""
MATCH (n:{self.node_label})
WHERE n.group_id = $group_id AND n.user_id = $user_id
WHERE n.end_user_id = $end_user_id AND n.user_id = $user_id
"""
params = {"group_id": group_id, "user_id": user_id, "limit": limit}
params = {"end_user_id": end_user_id, "user_id": user_id, "limit": limit}
# Add date range filters if provided
if start_date:
@@ -184,14 +184,14 @@ class MemorySummaryRepository(BaseNeo4jRepository):
async def find_recent_summaries(
self,
group_id: str,
end_user_id: str,
days: int = 7,
limit: int = 1000
) -> List[Dict[str, Any]]:
"""Query recent memory summaries
Args:
group_id: Group ID to filter by
end_user_id: Group ID to filter by
days: Number of recent days to query
limit: Maximum number of results to return
@@ -200,7 +200,7 @@ class MemorySummaryRepository(BaseNeo4jRepository):
"""
query = f"""
MATCH (n:{self.node_label})
WHERE n.group_id = $group_id
WHERE n.end_user_id = $end_user_id
AND n.created_at >= datetime() - duration({{days: $days}})
RETURN n
ORDER BY n.created_at DESC
@@ -209,7 +209,7 @@ class MemorySummaryRepository(BaseNeo4jRepository):
results = await self.connector.execute_query(
query,
group_id=group_id,
end_user_id=end_user_id,
days=days,
limit=limit
)
@@ -217,14 +217,14 @@ class MemorySummaryRepository(BaseNeo4jRepository):
async def find_by_content_keywords(
self,
group_id: str,
end_user_id: str,
keywords: List[str],
limit: int = 100
) -> List[Dict[str, Any]]:
"""Query memory summaries by content keywords
Args:
group_id: Group ID to filter by
end_user_id: Group ID to filter by
keywords: List of keywords to search for in content
limit: Maximum number of results to return
@@ -233,7 +233,7 @@ class MemorySummaryRepository(BaseNeo4jRepository):
"""
# Build keyword search conditions
keyword_conditions = []
params = {"group_id": group_id, "limit": limit}
params = {"end_user_id": end_user_id, "limit": limit}
for i, keyword in enumerate(keywords):
keyword_conditions.append(f"toLower(n.content) CONTAINS toLower($keyword_{i})")
@@ -243,7 +243,7 @@ class MemorySummaryRepository(BaseNeo4jRepository):
query = f"""
MATCH (n:{self.node_label})
WHERE n.group_id = $group_id
WHERE n.end_user_id = $end_user_id
AND ({keyword_filter})
RETURN n
ORDER BY n.created_at DESC
@@ -253,21 +253,21 @@ class MemorySummaryRepository(BaseNeo4jRepository):
results = await self.connector.execute_query(query, **params)
return [self._map_to_dict(r) for r in results]
async def get_summary_count_by_group(self, group_id: str) -> int:
async def get_summary_count_by_group(self, end_user_id: str) -> int:
"""Get count of memory summaries for a group
Args:
group_id: Group ID to count summaries for
end_user_id: Group ID to count summaries for
Returns:
int: Number of memory summaries
"""
query = f"""
MATCH (n:{self.node_label})
WHERE n.group_id = $group_id
WHERE n.end_user_id = $end_user_id
RETURN count(n) as count
"""
results = await self.connector.execute_query(query, group_id=group_id)
results = await self.connector.execute_query(query, end_user_id=end_user_id)
return results[0]['count'] if results else 0

View File

@@ -70,11 +70,7 @@ class Neo4jConnector:
List[Dict[str, Any]]: 查询结果列表,每个元素是一个字典
Example:
>>> connector = Neo4jConnector()
>>> results = await connector.execute_query(
... "MATCH (n:Person {name: $name}) RETURN n",
... name="Alice"
... )
"""
result = await self.driver.execute_query(
query,
@@ -98,17 +94,7 @@ class Neo4jConnector:
Any: 事务函数的返回值
Example:
>>> async def create_node(tx, name):
... result = await tx.run(
... "CREATE (n:Person {name: $name}) RETURN n",
... name=name
... )
... return await result.single()
>>>
>>> connector = Neo4jConnector()
>>> result = await connector.execute_write_transaction(
... create_node, name="Alice"
... )
"""
async with self.driver.session(database="neo4j") as session:
return await session.execute_write(transaction_func, **kwargs)
@@ -126,45 +112,33 @@ class Neo4jConnector:
Any: 事务函数的返回值
Example:
>>> async def get_node(tx, name):
... result = await tx.run(
... "MATCH (n:Person {name: $name}) RETURN n",
... name=name
... )
... return await result.single()
>>>
>>> connector = Neo4jConnector()
>>> result = await connector.execute_read_transaction(
... get_node, name="Alice"
... )
"""
async with self.driver.session(database="neo4j") as session:
return await session.execute_read(transaction_func, **kwargs)
async def delete_group(self, group_id: str):
async def delete_group(self, end_user_id: str):
"""删除指定组的所有数据
删除所有属于指定group_id的节点和边。
删除所有属于指定end_user_id的节点和边。
这是一个危险操作,会永久删除数据。
Args:
group_id: 要删除的组ID
end_user_id: 要删除的组ID
Example:
>>> connector = Neo4jConnector()
>>> await connector.delete_group("group_123")
Group group_123 deleted.
"""
# 删除节点DETACH DELETE会同时删除相关的边
await self.driver.execute_query(
"MATCH (n) WHERE n.group_id = $group_id DETACH DELETE n",
"MATCH (n) WHERE n.end_user_id = $end_user_id DETACH DELETE n",
database="neo4j",
group_id=group_id
end_user_id=end_user_id
)
# 删除独立的边(如果有的话)
await self.driver.execute_query(
"MATCH ()-[r]->() WHERE r.group_id = $group_id DELETE r",
"MATCH ()-[r]->() WHERE r.end_user_id = $end_user_id DELETE r",
database="neo4j",
group_id=group_id
end_user_id=end_user_id
)
print(f"Group {group_id} deleted.")
print(f"Group {end_user_id} deleted.")

View File

@@ -20,7 +20,7 @@ class StatementRepository(BaseNeo4jRepository[StatementNode]):
"""陈述句仓储
管理陈述句节点的创建、查询、更新和删除操作。
提供按chunk_id、group_id、向量相似度等条件查询陈述句的方法。
提供按chunk_id、end_user_id、向量相似度等条件查询陈述句的方法。
Attributes:
connector: Neo4j连接器实例

View File

@@ -299,6 +299,18 @@ class AppRelease(BaseModel):
created_at: datetime.datetime
updated_at: datetime.datetime
@field_validator("config", mode="before")
@classmethod
def parse_config(cls, v):
"""处理 config 字段,如果是字符串则解析为字典"""
if isinstance(v, str):
import json
try:
return json.loads(v)
except json.JSONDecodeError:
return {}
return v if v is not None else {}
@field_serializer("created_at", when_used="json")
def _serialize_created_at(self, dt: datetime.datetime):
return int(dt.timestamp() * 1000) if dt else None

View File

@@ -1,11 +1,12 @@
"""情绪分析相关的请求和响应模型"""
from typing import Optional
from uuid import UUID
from pydantic import BaseModel, Field
class EmotionTagsRequest(BaseModel):
"""获取情绪标签统计请求"""
group_id: str = Field(..., description="组ID")
end_user_id: str = Field(..., description="组ID")
emotion_type: Optional[str] = Field(None, description="情绪类型过滤joy/sadness/anger/fear/surprise/neutral")
start_date: Optional[str] = Field(None, description="开始日期ISO格式2024-01-01")
end_date: Optional[str] = Field(None, description="结束日期ISO格式2024-12-31")
@@ -14,14 +15,14 @@ class EmotionTagsRequest(BaseModel):
class EmotionWordcloudRequest(BaseModel):
"""获取情绪词云数据请求"""
group_id: str = Field(..., description="组ID")
end_user_id: str = Field(..., description="组ID")
emotion_type: Optional[str] = Field(None, description="情绪类型过滤joy/sadness/anger/fear/surprise/neutral")
limit: int = Field(50, ge=1, le=200, description="返回词语数量")
class EmotionHealthRequest(BaseModel):
"""获取情绪健康指数请求"""
group_id: str = Field(..., description="组ID")
end_user_id: str = Field(..., description="组ID")
time_range: str = Field("30d", description="时间范围7d/30d/90d")
@@ -29,8 +30,8 @@ class EmotionHealthRequest(BaseModel):
class EmotionSuggestionsRequest(BaseModel):
"""获取个性化情绪建议请求"""
group_id: str = Field(..., description="组ID")
config_id: Optional[int] = Field(None, description="配置ID用于指定LLM模型")
end_user_id: str = Field(..., description="组ID")
config_id: Optional[UUID] = Field(None, description="配置ID用于指定LLM模型")
class EmotionGenerateSuggestionsRequest(BaseModel):

View File

@@ -7,11 +7,11 @@ class UserInput(BaseModel):
message: str
history: list[dict]
search_switch: str
group_id: str
end_user_id: str
config_id: Optional[str] = None
class Write_UserInput(BaseModel):
messages: list[dict]
group_id: str
config_id: Optional[str] = None
end_user_id: str
config_id: Optional[str] = None

View File

@@ -35,7 +35,7 @@ class ConfigurationError(Exception):
def __init__(
self,
message: str,
config_id: Optional[int] = None,
config_id: Optional[UUID] = None,
workspace_id: Optional[UUID] = None,
context: Optional[Dict[str, Any]] = None,
):
@@ -72,7 +72,7 @@ class WorkspaceNotFoundError(ConfigurationError):
def __init__(
self,
workspace_id: UUID,
config_id: Optional[int] = None,
config_id: Optional[UUID] = None,
message: Optional[str] = None,
):
if message is None:
@@ -89,7 +89,7 @@ class ModelNotFoundError(ConfigurationError):
self,
model_id: Union[str, UUID],
model_type: str,
config_id: Optional[int] = None,
config_id: Optional[UUID] = None,
workspace_id: Optional[UUID] = None,
message: Optional[str] = None,
):
@@ -112,7 +112,7 @@ class ModelInactiveError(ConfigurationError):
model_id: Union[str, UUID],
model_name: str,
model_type: str,
config_id: Optional[int] = None,
config_id: Optional[UUID] = None,
workspace_id: Optional[UUID] = None,
message: Optional[str] = None,
):
@@ -136,7 +136,7 @@ class InvalidConfigError(ConfigurationError):
message: str,
field_name: Optional[str] = None,
invalid_value: Optional[Any] = None,
config_id: Optional[int] = None,
config_id: Optional[UUID] = None,
workspace_id: Optional[UUID] = None,
):
context = {}
@@ -155,7 +155,7 @@ class InvalidConfigError(ConfigurationError):
class MemoryConfigValidation(BaseModel):
"""Pydantic model for validating memory configuration data from database."""
config_id: int = Field(..., gt=0, description="Configuration ID must be positive")
config_id: UUID = Field(..., description="Configuration ID (UUID)")
config_name: str = Field(..., min_length=1, max_length=255)
workspace_id: UUID = Field(..., description="Workspace UUID")
workspace_name: str = Field(..., min_length=1, max_length=255)
@@ -275,7 +275,7 @@ class ModelValidation(BaseModel):
def validate_memory_config_data(
config_data: Dict[str, Any], config_id: Optional[int] = None
config_data: Dict[str, Any], config_id: Optional[UUID] = None
) -> MemoryConfigValidation:
"""Validate memory configuration data using Pydantic model."""
try:
@@ -302,7 +302,7 @@ def validate_memory_config_data(
def validate_workspace_data(
workspace_data: Dict[str, Any], config_id: Optional[int] = None
workspace_data: Dict[str, Any], config_id: Optional[UUID] = None
) -> WorkspaceValidation:
"""Validate workspace data using Pydantic model."""
try:
@@ -331,7 +331,7 @@ def validate_workspace_data(
def validate_model_data(
model_data: Dict[str, Any], config_id: Optional[int] = None
model_data: Dict[str, Any], config_id: Optional[UUID] = None
) -> ModelValidation:
"""Validate model data using Pydantic model."""
try:
@@ -364,7 +364,7 @@ def validate_model_data(
class MemoryConfig:
"""Immutable memory configuration loaded from database."""
config_id: int
config_id: UUID
config_name: str
workspace_id: UUID
workspace_name: str

View File

@@ -4,7 +4,7 @@ from typing import Optional
from pydantic import BaseModel, Field
from app.models.memory_perceptual_model import PerceptualType, FileStorageType
from app.models.memory_perceptual_model import PerceptualType, FileStorageService
class PerceptualFilter(BaseModel):
@@ -38,12 +38,14 @@ class PerceptualMemoryItem(BaseModel):
"""感知记忆项"""
id: uuid.UUID = Field(..., description="Unique memory ID")
perceptual_type: PerceptualType = Field(..., description="Type of perception, e.g., text, audio, or video")
storage_service: FileStorageService = Field(..., description="Storage service for file")
file_path: str = Field(..., description="File path in the storage service")
file_ext: str = Field(..., description="File extension")
file_name: str = Field(..., description="File name")
file_ext: str = Field(..., description="File extension")
summary: Optional[str] = Field(None, description="summary")
storage_type: FileStorageType = Field(..., description="Storage type for file")
meta_data: Optional[dict] = Field(None, description="Metadata information")
created_time: int = Field(..., description="create time")
topic: str = Field(..., description="topic")
domain: str = Field(..., description="domain")
keywords: list[str] = Field(..., description="keywords")

View File

@@ -1,5 +1,6 @@
from pydantic import BaseModel, Field
from typing import Optional
from uuid import UUID
from enum import Enum
@@ -9,7 +10,7 @@ class OptimizationStrategy(str, Enum):
ACCURACY_FIRST = "accuracy_first"
BALANCED = "balanced"
class Memory_Reflection(BaseModel):
config_id: Optional[int] = None
config_id: Optional[UUID] = None
reflection_enabled: bool
reflection_period_in_hours: str
reflexion_range: Optional[str] = "partial"

View File

@@ -1,5 +1,5 @@
"""
所有的内容是放错误地方了应该放在models
"""
from typing import Any, Optional, List, Dict, Literal, Union
@@ -8,20 +8,8 @@ import uuid
from pydantic import BaseModel, Field, ConfigDict, field_validator, model_validator
# ============================================================================
# 原 UserInput 相关 Schema (保留原有功能)
# ============================================================================
class UserInput(BaseModel):
message: str
history: list[dict]
search_switch: str
group_id: str
class Write_UserInput(BaseModel):
message: str
group_id: str
# ============================================================================
# 从 json_schema.py 迁移的 Schema
@@ -159,7 +147,7 @@ class ReflexionResultSchema(BaseModel):
# Composite key identifying a config row
class ConfigKey(BaseModel): # 配置参数键模型
model_config = ConfigDict(populate_by_name=True, extra="forbid")
config_id: int = Field("config_id", description="配置唯一标识(字符串")
config_id: uuid.UUID = Field("config_id", description="配置唯一标识(UUID")
user_id: str = Field("user_id", description="用户标识(字符串)")
apply_id: str = Field("apply_id", description="应用或场景标识(字符串)")
@@ -250,17 +238,17 @@ class ConfigParamsCreate(BaseModel): # 创建配置参数模型(仅 body
class ConfigParamsDelete(BaseModel): # 删除配置参数模型(请求体)
model_config = ConfigDict(populate_by_name=True, extra="forbid")
# config_name: str = Field("配置名称", description="配置名称(字符串)")
config_id: int = Field("配置ID", description="配置ID字符串")
config_id: uuid.UUID = Field("配置ID", description="配置IDUUID")
class ConfigUpdate(BaseModel): # 更新记忆萃取引擎配置参数时使用的模型
config_id: Optional[int] = None
config_id: Optional[uuid.UUID] = None
config_name: str = Field("配置名称", description="配置名称(字符串)")
config_desc: str = Field("配置描述", description="配置描述(字符串)")
class ConfigUpdateExtracted(BaseModel): # 更新记忆萃取引擎配置参数时使用的模型
config_id: Optional[int] = None
config_id: Optional[uuid.UUID] = None
llm_id: Optional[str] = Field(None, description="LLM模型配置ID")
embedding_id: Optional[str] = Field(None, description="嵌入模型配置ID")
rerank_id: Optional[str] = Field(None, description="重排序模型配置ID")
@@ -327,14 +315,14 @@ class ConfigUpdateExtracted(BaseModel): # 更新记忆萃取引擎配置参数
class ConfigUpdateForget(BaseModel): # 更新遗忘引擎配置参数时使用的模型
# 遗忘引擎配置参数更新模型
config_id: Optional[int] = None
config_id: Optional[uuid.UUID] = None
lambda_time: Optional[float] = Field(0.5, ge=0.0, le=1.0, description="最低保持度0-1 小数;默认 0.5")
lambda_mem: Optional[float] = Field(0.5, ge=0.0, le=1.0, description="遗忘率0-1 小数;默认 0.5")
offset: Optional[float] = Field(0.0, ge=0.0, le=1.0, description="偏移度0-1 小数;默认 0.0")
class ConfigPilotRun(BaseModel): # 试运行触发请求模型
config_id: int = Field(..., description="配置ID唯一")
config_id: uuid.UUID = Field(..., description="配置ID唯一")
dialogue_text: str = Field(..., description="前端传入的对话文本,格式如 '用户: ...\nAI: ...' 可多行,试运行必填")
model_config = ConfigDict(populate_by_name=True, extra="forbid")
@@ -342,7 +330,7 @@ class ConfigPilotRun(BaseModel): # 试运行触发请求模型
class ConfigFilter(BaseModel): # 查询配置参数时使用的模型
model_config = ConfigDict(populate_by_name=True, extra="forbid")
config_id: Optional[int] = None
config_id: Optional[uuid.UUID] = None
user_id: Optional[str] = None
apply_id: Optional[str] = None
@@ -418,7 +406,7 @@ class ForgettingConfigResponse(BaseModel):
"""遗忘引擎配置响应模型"""
model_config = ConfigDict(populate_by_name=True, extra="forbid")
config_id: int = Field(..., description="配置ID")
config_id: uuid.UUID = Field(..., description="配置ID")
decay_constant: float = Field(..., description="衰减常数 d")
lambda_time: float = Field(..., description="时间衰减参数")
lambda_mem: float = Field(..., description="记忆衰减参数")
@@ -436,7 +424,7 @@ class ForgettingConfigUpdateRequest(BaseModel):
"""遗忘引擎配置更新请求模型"""
model_config = ConfigDict(populate_by_name=True, extra="forbid")
config_id: int = Field(..., description="配置ID")
config_id: uuid.UUID = Field(..., description="配置ID")
decay_constant: Optional[float] = Field(None, ge=0.0, le=1.0, description="衰减常数 d")
lambda_time: Optional[float] = Field(None, ge=0.0, le=1.0, description="时间衰减参数")
lambda_mem: Optional[float] = Field(None, ge=0.0, le=1.0, description="记忆衰减参数")
@@ -511,7 +499,7 @@ class ForgettingCurveRequest(BaseModel):
importance_score: float = Field(0.5, ge=0.0, le=1.0, description="重要性分数0-1")
days: int = Field(60, ge=1, le=365, description="模拟天数默认60天")
config_id: Optional[int] = Field(None, description="配置ID可选如果为None则使用默认配置")
config_id: Optional[uuid.UUID] = Field(None, description="配置ID可选如果为None则使用默认配置")
class ForgettingCurveResponse(BaseModel):

View File

@@ -1,4 +1,4 @@
from pydantic import BaseModel, Field, field_serializer, ConfigDict
from pydantic import BaseModel, Field, field_serializer, field_validator, ConfigDict
from typing import Optional, List, Dict, Any
import datetime
import uuid
@@ -91,6 +91,18 @@ class ModelApiKey(ModelApiKeyBase):
created_at: datetime.datetime
updated_at: datetime.datetime
@field_validator("config", mode="before")
@classmethod
def parse_config(cls, v):
"""处理 config 字段,如果是字符串则解析为字典"""
if isinstance(v, str):
import json
try:
return json.loads(v)
except json.JSONDecodeError:
return {}
return v
@field_serializer("created_at", when_used="json")
def _serialize_created_at(self, dt: datetime.datetime):
return int(dt.timestamp() * 1000) if dt else None

View File

@@ -1,7 +1,7 @@
import uuid
import datetime
from typing import Optional, List, Dict, Any
from pydantic import BaseModel, Field, ConfigDict, field_serializer
from pydantic import BaseModel, Field, ConfigDict, field_serializer, field_validator
# ---------- Input Schemas ----------
@@ -88,6 +88,18 @@ class SharedReleaseInfo(BaseModel):
# 嵌入配置
allow_embed: bool
@field_validator("config", mode="before")
@classmethod
def parse_config(cls, v):
"""处理 config 字段,如果是字符串则解析为字典"""
if isinstance(v, str):
import json
try:
return json.loads(v)
except json.JSONDecodeError:
return {}
return v if v is not None else {}
class EmbedCode(BaseModel):
"""嵌入代码"""

View File

@@ -92,7 +92,7 @@ def create_long_term_memory_tool(memory_config: Dict[str, Any], end_user_id: str
try:
memory_content = asyncio.run(
MemoryAgentService().read_memory(
group_id=end_user_id,
end_user_id=end_user_id,
message=question,
history=[],
search_switch="2",

View File

@@ -75,7 +75,7 @@ class EmotionAnalyticsService:
# 调用仓储层查询
tags = await self.emotion_repo.get_emotion_tags(
group_id=end_user_id,
end_user_id=end_user_id,
emotion_type=emotion_type,
start_date=start_date,
end_date=end_date,
@@ -157,7 +157,7 @@ class EmotionAnalyticsService:
# 调用仓储层查询
keywords = await self.emotion_repo.get_emotion_wordcloud(
group_id=end_user_id,
end_user_id=end_user_id,
emotion_type=emotion_type,
limit=limit
)
@@ -339,7 +339,7 @@ class EmotionAnalyticsService:
# 获取时间范围内的情绪数据
emotions = await self.emotion_repo.get_emotions_in_range(
group_id=end_user_id,
end_user_id=end_user_id,
time_range=time_range
)
@@ -505,7 +505,7 @@ class EmotionAnalyticsService:
)
config_service = MemoryConfigService(db)
memory_config = config_service.load_memory_config(
config_id=int(config_id),
config_id=(config_id),
service_name="EmotionAnalyticsService.generate_emotion_suggestions"
)
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
@@ -519,7 +519,7 @@ class EmotionAnalyticsService:
# 3. 获取情绪数据用于模式分析
emotions = await self.emotion_repo.get_emotions_in_range(
group_id=end_user_id,
end_user_id=end_user_id,
time_range="30d"
)
@@ -598,13 +598,13 @@ class EmotionAnalyticsService:
# 查询用户的实体和标签
query = """
MATCH (e:Entity)
WHERE e.group_id = $group_id
WHERE e.end_user_id = $end_user_id
RETURN e.name as name, e.type as type
ORDER BY e.created_at DESC
LIMIT 20
"""
entities = await connector.execute_query(query, group_id=end_user_id)
entities = await connector.execute_query(query, end_user_id=end_user_id)
# 提取兴趣标签
interests = [e["name"] for e in entities if e.get("type") in ["INTEREST", "HOBBY"]][:5]

Some files were not shown because too many files have changed in this diff Show More