Merge branch 'refs/heads/develop' into fix/memory_bug_fix

This commit is contained in:
lixinyue
2026-01-21 20:35:04 +08:00
100 changed files with 1783 additions and 1546 deletions

View File

@@ -334,7 +334,13 @@ step6: Log In to the Frontend Interface.
## License
This project is licensed under the Apache License 2.0. For details, see the LICENSE file.
## Acknowledgements & Community
- Feedback & Issues: Please submit an Issue in the repository for bug reports or discussions.
- Contributions Welcome: When submitting a Pull Request, please create a feature branch and follow conventional commit message guidelines.
- Contact: If you are interested in contributing or collaborating, feel free to reach out at tianyou_hubm@redbearai.com
## Community & Support
Join our community to ask questions, share your work, and connect with fellow developers.
- **GitHub Issues**: Report bugs, request features, or track known issues via [GitHub Issues](https://github.com/SuanmoSuanyangTechnology/MemoryBear/issues).
- **GitHub Pull Requests**: Contribute code improvements or fixes through [Pull Requests](https://github.com/SuanmoSuanyangTechnology/MemoryBear/pulls).
- **GitHub Discussions**: Ask questions, share ideas, and engage with the community in [GitHub Discussions](https://github.com/SuanmoSuanyangTechnology/MemoryBear/discussions).
- **WeChat**: Scan the QR code below to join our WeChat community group.
- ![wecom-temp-114020-47fe87a75da439f09f5dc93a01593046](https://github.com/user-attachments/assets/8c81885c-4134-40d5-96e2-7f78cc082dc6)
- **Contact**: If you are interested in contributing or collaborating, feel free to reach out at tianyou_hubm@redbearai.com

0
api/app/__init__.py Normal file
View File

View File

@@ -1,4 +1,5 @@
import os
import platform
from datetime import timedelta
from urllib.parse import quote
@@ -14,27 +15,12 @@ celery_app = Celery(
backend=f"redis://:{quote(settings.REDIS_PASSWORD)}@{settings.REDIS_HOST}:{settings.REDIS_PORT}/{settings.CELERY_BACKEND}",
)
# 配置使用本地队列,避免与远程 worker 冲突
celery_app.conf.task_default_queue = 'localhost_test_wyl'
celery_app.conf.task_default_exchange = 'localhost_test_wyl'
celery_app.conf.task_default_routing_key = 'localhost_test_wyl'
# Default queue for unrouted tasks
celery_app.conf.task_default_queue = 'memory_tasks'
# macOS 兼容性配置
import platform
if platform.system() == 'Darwin': # macOS
# 设置环境变量解决 fork 问题
if platform.system() == 'Darwin':
os.environ.setdefault('OBJC_DISABLE_INITIALIZE_FORK_SAFETY', 'YES')
# 使用 solo 池避免多进程问题
celery_app.conf.worker_pool = 'solo'
# 设置唯一的节点名称
import socket
import time
hostname = socket.gethostname()
timestamp = int(time.time())
celery_app.conf.worker_name = f"celery@{hostname}-{timestamp}"
# Celery 配置
celery_app.conf.update(
@@ -52,36 +38,47 @@ celery_app.conf.update(
task_ignore_result=False,
# 超时设置
task_time_limit=30 * 60, # 30 分钟硬超时
task_soft_time_limit=25 * 60, # 25 分钟软超时
task_time_limit=1800, # 30分钟硬超时
task_soft_time_limit=1500, # 25分钟软超时
# Worker 设置 - 针对 macOS 优化
worker_prefetch_multiplier=1, # 减少预取任务数,避免内存堆积
worker_max_tasks_per_child=10, # 大幅减少每个 worker 执行的任务数,频繁重启防止内存泄漏
worker_max_memory_per_child=200000, # 200MB 内存限制,超过后重启 worker
# Worker 设置 (per-worker settings are in docker-compose command line)
worker_prefetch_multiplier=1, # Don't hoard tasks, fairer distribution
# 结果过期时间
result_expires=3600, # 结果保存 1 小时
result_expires=3600, # 结果保存1小时
# 任务确认设置
task_acks_late=True, # 任务完成后才确认,避免任务丢失
worker_disable_rate_limits=True, # 禁用速率限制
task_acks_late=True,
task_reject_on_worker_lost=True,
worker_disable_rate_limits=True,
# 任务路由(可选,用于不同队列)
# task_routes={
# 'app.core.rag.tasks.parse_document': {'queue': 'document_processing'},
# 'app.core.memory.agent.read_message': {'queue': 'memory_processing'},
# 'app.core.memory.agent.write_message': {'queue': 'memory_processing'},
# 'tasks.process_item': {'queue': 'default'},
# },
# FLower setting
worker_send_task_events=True,
task_send_sent_event=True,
# task routing
task_routes={
# Memory tasks → memory_tasks queue (threads worker)
'app.core.memory.agent.read_message_priority': {'queue': 'memory_tasks'},
'app.core.memory.agent.read_message': {'queue': 'memory_tasks'},
'app.core.memory.agent.write_message': {'queue': 'memory_tasks'},
# Document tasks → document_tasks queue (prefork worker)
'app.core.rag.tasks.parse_document': {'queue': 'document_tasks'},
'app.core.rag.tasks.build_graphrag_for_kb': {'queue': 'document_tasks'},
# Beat/periodic tasks → document_tasks queue (prefork worker)
'app.tasks.workspace_reflection_task': {'queue': 'document_tasks'},
'app.tasks.regenerate_memory_cache': {'queue': 'document_tasks'},
'app.tasks.run_forgetting_cycle_task': {'queue': 'document_tasks'},
'app.controllers.memory_storage_controller.search_all': {'queue': 'document_tasks'},
},
)
# 自动发现任务模块
celery_app.autodiscover_tasks(['app'])
# Celery Beat schedule for periodic tasks
reflection_schedule = timedelta(seconds=settings.REFLECTION_INTERVAL_SECONDS)
health_schedule = timedelta(seconds=settings.HEALTH_CHECK_SECONDS)
memory_increment_schedule = timedelta(hours=settings.MEMORY_INCREMENT_INTERVAL_HOURS)
memory_cache_regeneration_schedule = timedelta(hours=settings.MEMORY_CACHE_REGENERATION_HOURS)
workspace_reflection_schedule = timedelta(seconds=30) # 每30秒运行一次settings.REFLECTION_INTERVAL_TIME
@@ -89,12 +86,6 @@ forgetting_cycle_schedule = timedelta(hours=24) # 每24小时运行一次遗忘
# 构建定时任务配置
beat_schedule_config = {
# "check-read-service": {
# "task": "app.core.memory.agent.health.check_read_service",
# "schedule": health_schedule,
# "args": (),
# },
"run-workspace-reflection": {
"task": "app.tasks.workspace_reflection_task",
"schedule": workspace_reflection_schedule,

View File

@@ -9,6 +9,8 @@ from app.db import get_db
from app.dependencies import cur_workspace_access_guard, get_current_user
from app.models import ModelApiKey
from app.models.user_model import User
from app.core.memory.agent.utils.session_tools import SessionService
from app.core.memory.agent.utils.redis_tool import store
from app.repositories import knowledge_repository, WorkspaceRepository
from app.schemas.memory_agent_schema import UserInput, Write_UserInput
from app.schemas.response_schema import ApiResponse
@@ -123,7 +125,7 @@ async def write_server(
Write service endpoint - processes write operations synchronously
Args:
user_input: Write request containing message and group_id
user_input: Write request containing message and end_user_id
Returns:
Response with write operation status
@@ -158,11 +160,11 @@ async def write_server(
api_logger.warning("workspace_id 为空,无法使用 rag 存储,将使用 neo4j 存储")
storage_type = 'neo4j'
api_logger.info(f"Write service requested for group {user_input.group_id}, storage_type: {storage_type}, user_rag_memory_id: {user_rag_memory_id}")
api_logger.info(f"Write service requested for group {user_input.end_user_id}, storage_type: {storage_type}, user_rag_memory_id: {user_rag_memory_id}")
try:
result = await memory_agent_service.write_memory(
user_input.group_id,
user_input.message,
user_input.end_user_id,
user_input.messages,
config_id,
db,
storage_type,
@@ -191,7 +193,7 @@ async def write_server_async(
Async write service endpoint - enqueues write processing to Celery
Args:
user_input: Write request containing message and group_id
user_input: Write request containing message and end_user_id
Returns:
Task ID for tracking async operation
@@ -221,7 +223,7 @@ async def write_server_async(
try:
task = celery_app.send_task(
"app.core.memory.agent.write_message",
args=[user_input.group_id, user_input.message, config_id, storage_type, user_rag_memory_id]
args=[user_input.end_user_id, user_input.message, config_id, storage_type, user_rag_memory_id]
)
api_logger.info(f"Write task queued: {task.id}")
@@ -247,7 +249,7 @@ async def read_server(
- "2": Direct answer based on context
Args:
user_input: Read request with message, history, search_switch, and group_id
user_input: Read request with message, history, search_switch, and end_user_id
Returns:
Response with query answer
@@ -271,12 +273,13 @@ async def read_server(
name="USER_RAG_MERORY",
workspace_id=workspace_id
)
if knowledge: user_rag_memory_id = str(knowledge.id)
if knowledge:
user_rag_memory_id = str(knowledge.id)
api_logger.info(f"Read service: group={user_input.group_id}, storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}, workspace_id={workspace_id}")
api_logger.info(f"Read service: group={user_input.end_user_id}, storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}, workspace_id={workspace_id}")
try:
result = await memory_agent_service.read_memory(
user_input.group_id,
user_input.end_user_id,
user_input.message,
user_input.history,
user_input.search_switch,
@@ -285,6 +288,19 @@ async def read_server(
storage_type,
user_rag_memory_id
)
if str(user_input.search_switch) == "2":
retrieve_info = result['answer']
history = await SessionService(store).get_history(user_input.end_user_id, user_input.end_user_id, user_input.end_user_id)
query = user_input.message
# 调用 memory_agent_service 的方法生成最终答案
result['answer'] = await memory_agent_service.generate_summary_from_retrieve(
retrieve_info=retrieve_info,
history=history,
query=query,
config_id=config_id,
db=db
)
return success(data=result, msg="回复对话消息成功")
except BaseException as e:
# Handle ExceptionGroup from TaskGroup (Python 3.11+) or BaseExceptionGroup
@@ -382,7 +398,7 @@ async def read_server_async(
try:
task = celery_app.send_task(
"app.core.memory.agent.read_message",
args=[user_input.group_id, user_input.message, user_input.history, user_input.search_switch,
args=[user_input.end_user_id, user_input.message, user_input.history, user_input.search_switch,
config_id, storage_type, user_rag_memory_id]
)
api_logger.info(f"Read task queued: {task.id}")
@@ -426,7 +442,7 @@ async def get_read_task_result(
return success(
data={
"result": task_result.get("result"),
"group_id": task_result.get("group_id"),
"end_user_id": task_result.get("end_user_id"),
"elapsed_time": task_result.get("elapsed_time"),
"task_id": task_id
},
@@ -503,7 +519,7 @@ async def get_write_task_result(
return success(
data={
"result": task_result.get("result"),
"group_id": task_result.get("group_id"),
"end_user_id": task_result.get("end_user_id"),
"elapsed_time": task_result.get("elapsed_time"),
"task_id": task_id
},
@@ -557,15 +573,30 @@ async def status_type(
Determine the type of user message (read or write)
Args:
user_input: Request containing user message and group_id
user_input: Request containing user message and end_user_id
Returns:
Type classification result
"""
api_logger.info(f"Status type check requested for group {user_input.group_id}")
api_logger.info(f"Status type check requested for group {user_input.end_user_id}")
try:
# 获取标准化的消息列表
messages_list = memory_agent_service.get_messages_list(user_input)
# 将消息列表转换为字符串用于分类
# 只取最后一条用户消息进行分类
last_user_message = ""
for msg in reversed(messages_list):
if msg.get('role') == 'user':
last_user_message = msg.get('content', '')
break
if not last_user_message:
# 如果没有用户消息,使用所有消息的内容
last_user_message = " ".join([msg.get('content', '') for msg in messages_list])
result = await memory_agent_service.classify_message_type(
user_input.message,
user_input.messages,
user_input.config_id,
db
)
@@ -588,7 +619,7 @@ async def get_knowledge_type_stats_api(
会对缺失类型补 0返回字典形式。
可选按状态过滤。
- 知识库类型根据当前用户的 current_workspace_id 过滤
- memory 是 Neo4j 中 Chunk 的数量,根据 end_user_id (group_id) 过滤
- memory 是 Neo4j 中 Chunk 的数量,根据 end_user_id (end_user_id) 过滤
- 如果用户没有当前工作空间或未提供 end_user_id对应的统计返回 0
"""
api_logger.info(f"Knowledge type stats requested for workspace_id: {current_user.current_workspace_id}, end_user_id: {end_user_id}")

View File

@@ -5,7 +5,6 @@ from app.core.response_utils import success
from app.db import get_db
from app.dependencies import get_current_user
from app.models.user_model import User
from app.schemas.memory_agent_schema import End_User_Information
from app.schemas.response_schema import ApiResponse
from app.services import memory_dashboard_service, memory_storage_service, workspace_service
@@ -40,54 +39,7 @@ def get_workspace_total_end_users(
api_logger.info(f"成功获取最新用户总数: total_num={total_end_users.get('total_num', 0)}")
return success(data=total_end_users, msg="用户数量获取成功")
@router.post("/update/end_users", response_model=ApiResponse)
async def update_workspace_end_users(
user_input: End_User_Information,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""
更新工作空间的宿主信息
"""
username = user_input.end_user_name # 要更新的用户名
end_user_input_id = user_input.id # 宿主ID
workspace_id = current_user.current_workspace_id
api_logger.info(f"用户 {current_user.username} 请求更新工作空间 {workspace_id} 的宿主信息")
api_logger.info(f"更新参数: username={username}, end_user_id={end_user_input_id}")
try:
# 导入更新函数
from app.repositories.end_user_repository import update_end_user_other_name
import uuid
# 转换 end_user_id 为 UUID 类型
end_user_uuid = uuid.UUID(end_user_input_id)
# 直接更新数据库中的 other_name 字段
updated_count = update_end_user_other_name(
db=db,
end_user_id=end_user_uuid,
other_name=username
)
api_logger.info(f"成功更新宿主 {end_user_input_id} 的 other_name 为: {username}")
return success(
data={
"updated_count": updated_count,
"end_user_id": end_user_input_id,
"updated_other_name": username
},
msg=f"成功更新 {updated_count} 个宿主的信息"
)
except Exception as e:
api_logger.error(f"更新宿主信息失败: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"更新宿主信息失败: {str(e)}"
)

View File

@@ -28,7 +28,6 @@ from app.services.memory_storage_service import (
search_dialogue,
search_edges,
search_entity,
search_entity_graph,
search_statement,
)
from fastapi import APIRouter, Depends
@@ -412,21 +411,7 @@ async def search_entity_edges(
api_logger.error(f"Search edges failed: {str(e)}")
return fail(BizCode.INTERNAL_ERROR, "边查询失败", str(e))
@router.get("/search/entity_graph", response_model=ApiResponse)
async def search_for_entity_graph(
end_user_id: Optional[str] = None,
current_user: User = Depends(get_current_user),
) -> dict:
"""
搜索所有实体之间的关系网络
"""
api_logger.info(f"Search entity graph requested for end_user_id: {end_user_id}")
try:
result = await search_entity_graph(end_user_id)
return success(data=result, msg="查询成功")
except Exception as e:
api_logger.error(f"Search entity graph failed: {str(e)}")
return fail(BizCode.INTERNAL_ERROR, "实体图查询失败", str(e))
@router.get("/analytics/hot_memory_tags", response_model=ApiResponse)

View File

@@ -39,7 +39,7 @@ async def write_memory_api_service(
Stores memory content for the specified end user using the Memory API Service.
"""
logger.info(f"Memory write request - end_user_id: {payload.end_user_id}")
logger.info(f"Memory write request - end_user_id: {payload.end_user_id}, tenant_id: {api_key_auth.tenant_id}")
memory_api_service = MemoryAPIService(db)
@@ -50,6 +50,7 @@ async def write_memory_api_service(
config_id=payload.config_id,
storage_type=payload.storage_type,
user_rag_memory_id=payload.user_rag_memory_id,
tenant_id=api_key_auth.tenant_id,
)
logger.info(f"Memory write successful for end_user: {payload.end_user_id}")

View File

@@ -351,12 +351,11 @@ async def update_end_user_profile(
该接口可以更新用户的姓名、职位、部门、联系方式、电话和入职日期等信息。
所有字段都是可选的,只更新提供的字段。
"""
workspace_id = current_user.current_workspace_id
end_user_id = profile_update.end_user_id
# 检查用户是否已选择工作空间
# 验证工作空间
if workspace_id is None:
api_logger.warning(f"用户 {current_user.username} 尝试更新用户信息但未选择工作空间")
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
@@ -366,57 +365,24 @@ async def update_end_user_profile(
f"workspace={workspace_id}"
)
try:
# 查询终端用户
end_user = db.query(EndUser).filter(EndUser.id == end_user_id).first()
# 调用 Service 层处理业务逻辑
result = user_memory_service.update_end_user_profile(db, end_user_id, profile_update)
if not end_user:
api_logger.warning(f"终端用户不存在: end_user_id={end_user_id}")
return fail(BizCode.INVALID_PARAMETER, "终端用户不存在", f"end_user_id={end_user_id}")
# 更新字段(只更新提供的字段,排除 end_user_id
# 允许 None 值来重置字段(如 hire_date
update_data = profile_update.model_dump(exclude_unset=True, exclude={'end_user_id'})
# 特殊处理 hire_date如果提供了时间戳转换为 DateTime
if 'hire_date' in update_data:
hire_date_timestamp = update_data['hire_date']
if hire_date_timestamp is not None:
update_data['hire_date'] = timestamp_to_datetime(hire_date_timestamp)
# 如果是 None保持 None允许清空
for field, value in update_data.items():
setattr(end_user, field, value)
# 更新 updated_at 时间戳
end_user.updated_at = datetime.datetime.now()
# 更新 updatetime_profile 为当前时间
end_user.updatetime_profile = datetime.datetime.now()
# 提交更改
db.commit()
db.refresh(end_user)
# 构建响应数据
profile_data = EndUserProfileResponse(
id=end_user.id,
other_name=end_user.other_name,
position=end_user.position,
department=end_user.department,
contact=end_user.contact,
phone=end_user.phone,
hire_date=end_user.hire_date,
updatetime_profile=end_user.updatetime_profile
)
api_logger.info(f"成功更新用户信息: end_user_id={end_user_id}, updated_fields={list(update_data.keys())}")
return success(data=UserMemoryService.convert_profile_to_dict_with_timestamp(profile_data), msg="更新成功")
except Exception as e:
db.rollback()
api_logger.error(f"用户信息更新失败: end_user_id={end_user_id}, error={str(e)}")
return fail(BizCode.INTERNAL_ERROR, "用户信息更新失败", str(e))
if result["success"]:
api_logger.info(f"成功更新用户信息: end_user_id={end_user_id}")
return success(data=result["data"], msg="更新成功")
else:
error_msg = result["error"]
api_logger.error(f"用户信息更新失败: end_user_id={end_user_id}, error={error_msg}")
# 根据错误类型映射到合适的业务错误码
if error_msg == "终端用户不存在":
return fail(BizCode.USER_NOT_FOUND, "终端用户不存在", error_msg)
elif error_msg == "无效的用户ID格式":
return fail(BizCode.INVALID_USER_ID, "无效的用户ID格式", error_msg)
else:
# 只有未预期的错误才使用 INTERNAL_ERROR
return fail(BizCode.INTERNAL_ERROR, "用户信息更新失败", error_msg)
@router.get("/memory_space/timeline_memories", response_model=ApiResponse)
async def memory_space_timeline_of_shared_memories(id: str, label: str,language_type: str="zh",

View File

@@ -154,7 +154,7 @@ class LangChainAgent:
userid=end_user_end,
messages=messages,
apply_id=end_user_end,
group_id=end_user_end,
end_user_id=end_user_end,
aimessages=aimessages
)
store.delete_duplicate_sessions()
@@ -173,16 +173,67 @@ class LangChainAgent:
retrieved_content.append({query: aimessages})
return messagss_list,retrieved_content
async def write(self, storage_type, end_user_id, user_message, ai_message, user_rag_memory_id, actual_end_user_id, actual_config_id):
"""
写入记忆(支持结构化消息)
async def write(self,storage_type,end_user_id,message,user_rag_memory_id,actual_end_user_id,content,actual_config_id):
Args:
storage_type: 存储类型 (neo4j/rag)
end_user_id: 终端用户ID
user_message: 用户消息内容
ai_message: AI 回复内容
user_rag_memory_id: RAG 记忆ID
actual_end_user_id: 实际用户ID
actual_config_id: 配置ID
逻辑说明:
- RAG 模式:组合 user_message 和 ai_message 为字符串格式,保持原有逻辑不变
- Neo4j 模式:使用结构化消息列表
1. 如果 user_message 和 ai_message 都不为空:创建配对消息 [user, assistant]
2. 如果只有 user_message创建单条用户消息 [user](用于历史记忆场景)
3. 每条消息会被转换为独立的 Chunk保留 speaker 字段
"""
if storage_type == "rag":
await write_rag(end_user_id, message, user_rag_memory_id)
# RAG 模式:组合消息为字符串格式(保持原有逻辑)
combined_message = f"user: {user_message}\nassistant: {ai_message}"
await write_rag(end_user_id, combined_message, user_rag_memory_id)
logger.info(f'RAG_Agent:{end_user_id};{user_rag_memory_id}')
else:
write_id = write_message_task.delay(actual_end_user_id, content, actual_config_id, storage_type,
user_rag_memory_id)
# Neo4j 模式:使用结构化消息列表
structured_messages = []
# 始终添加用户消息(如果不为空)
if user_message:
structured_messages.append({"role": "user", "content": user_message})
# 只有当 AI 回复不为空时才添加 assistant 消息
if ai_message:
structured_messages.append({"role": "assistant", "content": ai_message})
# 如果没有消息,直接返回
if not structured_messages:
logger.warning(f"No messages to write for user {actual_end_user_id}")
return
# 调用 Celery 任务,传递结构化消息列表
# 数据流:
# 1. structured_messages 传递给 write_message_task
# 2. write_message_task 调用 memory_agent_service.write_memory
# 3. write_memory 调用 write_tools.write传递 messages 参数
# 4. write_tools.write 调用 get_chunked_dialogs传递 messages 参数
# 5. get_chunked_dialogs 为每条消息创建独立的 Chunk设置 speaker 字段
# 6. 每个 Chunk 保存到 Neo4j包含 speaker 字段
logger.info(f"[WRITE] Submitting Celery task - user={actual_end_user_id}, messages={len(structured_messages)}, config={actual_config_id}")
write_id = write_message_task.delay(
actual_end_user_id, # group_id: 用户ID
structured_messages, # message: 结构化消息列表 [{"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]
actual_config_id, # config_id: 配置ID
storage_type, # storage_type: "neo4j"
user_rag_memory_id # user_rag_memory_id: RAG记忆IDNeo4j模式下不使用
)
logger.info(f"[WRITE] Celery task submitted - task_id={write_id}")
write_status = get_task_memory_write_result(str(write_id))
logger.info(f'Agent:{actual_end_user_id};{write_status}')
logger.info(f'[WRITE] Task result - user={actual_end_user_id}, status={write_status}')
async def chat(
self,
@@ -227,29 +278,30 @@ class LangChainAgent:
actual_end_user_id = end_user_id if end_user_id is not None else "unknown"
logger.info(f'写入类型{storage_type,str(end_user_id), message, str(user_rag_memory_id)}')
print(f'写入类型{storage_type,str(end_user_id), message, str(user_rag_memory_id)}')
# # TODO 乐力齐,在长短期记忆存储的时候再使用此代码
# history_term_memory_result = await self.term_memory_redis_read(end_user_id)
# history_term_memory = history_term_memory_result[0]
# db_for_memory = next(get_db())
# if memory_flag:
# if len(history_term_memory)>=4 and storage_type != "rag":
# history_term_memory = ';'.join(history_term_memory)
# retrieved_content = history_term_memory_result[1]
# print(retrieved_content)
# # 为长期记忆操作获取新的数据库连接
# try:
# repo = LongTermMemoryRepository(db_for_memory)
# repo.upsert(end_user_id, retrieved_content)
# logger.info(
# f'写入短长期:{storage_type, str(end_user_id), history_term_memory, str(user_rag_memory_id)}')
# except Exception as e:
# logger.error(f"Failed to write to LongTermMemory: {e}")
# raise
# finally:
# db_for_memory.close()
history_term_memory_result = await self.term_memory_redis_read(end_user_id)
history_term_memory = history_term_memory_result[0]
db_for_memory = next(get_db())
if memory_flag:
if len(history_term_memory)>=4 and storage_type != "rag":
history_term_memory = ';'.join(history_term_memory)
retrieved_content = history_term_memory_result[1]
print(retrieved_content)
# 为长期记忆操作获取新的数据库连接
try:
repo = LongTermMemoryRepository(db_for_memory)
repo.upsert(end_user_id, retrieved_content)
logger.info(
f'写入短长期:{storage_type, str(end_user_id), history_term_memory, str(user_rag_memory_id)}')
except Exception as e:
logger.error(f"Failed to write to LongTermMemory: {e}")
raise
finally:
db_for_memory.close()
await self.write(storage_type,end_user_id,history_term_memory,user_rag_memory_id,actual_end_user_id,history_term_memory,actual_config_id)
await self.write(storage_type,end_user_id,message,user_rag_memory_id,actual_end_user_id,message,actual_config_id)
# # 长期记忆写入(
# await self.write(storage_type, actual_end_user_id, history_term_memory, "", user_rag_memory_id, actual_end_user_id, actual_config_id)
# # 注意:不在这里写入用户消息,等 AI 回复后一起写入
try:
# 准备消息列表
messages = self._prepare_messages(message, history, context)
@@ -277,8 +329,10 @@ class LangChainAgent:
elapsed_time = time.time() - start_time
if memory_flag:
await self.write(storage_type,end_user_id,content,user_rag_memory_id,actual_end_user_id,content,actual_config_id)
await self.term_memory_save(message_chat,end_user_id,content)
# AI 回复写入(用户消息和 AI 回复配对,一次性写入完整对话)
await self.write(storage_type, actual_end_user_id, message_chat, content, user_rag_memory_id, actual_end_user_id, actual_config_id)
# TODO 乐力齐 - 累积多组对话批量写入功能已禁用
# await self.term_memory_save(message_chat, end_user_id, content)
response = {
"content": content,
"model": self.model_name,
@@ -346,27 +400,27 @@ class LangChainAgent:
db.close()
except Exception as e:
logger.warning(f"Failed to get db session: {e}")
# # TODO 乐力齐
# history_term_memory_result = await self.term_memory_redis_read(end_user_id)
# history_term_memory = history_term_memory_result[0]
# if memory_flag:
# if len(history_term_memory) >= 4 and storage_type != "rag":
# history_term_memory = ';'.join(history_term_memory)
# retrieved_content = history_term_memory_result[1]
# db_for_memory = next(get_db())
# try:
# repo = LongTermMemoryRepository(db_for_memory)
# repo.upsert(end_user_id, retrieved_content)
# logger.info(
# f'写入短长期:{storage_type, str(end_user_id), history_term_memory, str(user_rag_memory_id)}')
# # 长期记忆写入
# await self.write(storage_type, end_user_id, history_term_memory, "", user_rag_memory_id, end_user_id, actual_config_id)
# except Exception as e:
# logger.error(f"Failed to write to long term memory: {e}")
# finally:
# db_for_memory.close()
history_term_memory_result = await self.term_memory_redis_read(end_user_id)
history_term_memory = history_term_memory_result[0]
if memory_flag:
if len(history_term_memory) >= 4 and storage_type != "rag":
history_term_memory = ';'.join(history_term_memory)
retrieved_content = history_term_memory_result[1]
db_for_memory = next(get_db())
try:
repo = LongTermMemoryRepository(db_for_memory)
repo.upsert(end_user_id, retrieved_content)
logger.info(
f'写入短长期:{storage_type, str(end_user_id), history_term_memory, str(user_rag_memory_id)}')
await self.write(storage_type, end_user_id, history_term_memory, user_rag_memory_id, end_user_id,
history_term_memory, actual_config_id)
except Exception as e:
logger.error(f"Failed to write to long term memory: {e}")
finally:
db_for_memory.close()
await self.write(storage_type, end_user_id, message, user_rag_memory_id, end_user_id, message, actual_config_id)
# 注意:不在这里写入用户消息,等 AI 回复后一起写入
try:
# 准备消息列表
messages = self._prepare_messages(message, history, context)
@@ -418,8 +472,10 @@ class LangChainAgent:
logger.debug(f"Agent 流式完成,共 {chunk_count} 个事件")
if memory_flag:
await self.write(storage_type, end_user_id,full_content, user_rag_memory_id, end_user_id,full_content, actual_config_id)
await self.term_memory_save(message_chat, end_user_id, full_content)
# AI 回复写入(用户消息和 AI 回复配对,一次性写入完整对话)
await self.write(storage_type, end_user_id, message_chat, full_content, user_rag_memory_id, end_user_id, actual_config_id)
# TODO 乐力齐 - 累积多组对话批量写入功能已禁用
# await self.term_memory_save(message_chat, end_user_id, full_content)
except Exception as e:
logger.error(f"Agent astream_events 失败: {str(e)}", exc_info=True)

View File

@@ -18,28 +18,31 @@ template_root = os.path.join(PROJECT_ROOT_, 'agent', 'utils', 'prompt')
db_session = next(get_db())
logger = get_agent_logger(__name__)
class ProblemNodeService(LLMServiceMixin):
"""问题处理节点服务类"""
def __init__(self):
super().__init__()
self.template_service = TemplateService(template_root)
# 创建全局服务实例
problem_service = ProblemNodeService()
async def Split_The_Problem(state: ReadState) -> ReadState:
"""问题分解节点"""
# 从状态中获取数据
content = state.get('data', '')
group_id = state.get('group_id', '')
end_user_id = state.get('end_user_id', '')
memory_config = state.get('memory_config', None)
history = await SessionService(store).get_history(group_id, group_id, group_id)
history = await SessionService(store).get_history(end_user_id, end_user_id, end_user_id)
# 生成 JSON schema 以指导 LLM 输出正确格式
json_schema = ProblemExtensionResponse.model_json_schema()
system_prompt = await problem_service.template_service.render_template(
template_name='problem_breakdown_prompt.jinja2',
operation_name='split_the_problem',
@@ -47,7 +50,7 @@ async def Split_The_Problem(state: ReadState) -> ReadState:
sentence=content,
json_schema=json_schema
)
try:
# 使用优化的LLM服务
structured = await problem_service.call_llm_structured(
@@ -57,10 +60,10 @@ async def Split_The_Problem(state: ReadState) -> ReadState:
response_model=ProblemExtensionResponse,
fallback_value=[]
)
# 添加更详细的日志记录
logger.info(f"Split_The_Problem: 开始处理问题分解,内容长度: {len(content)}")
# 验证结构化响应
if not structured or not hasattr(structured, 'root'):
logger.warning("Split_The_Problem: 结构化响应为空或格式不正确")
@@ -73,17 +76,17 @@ async def Split_The_Problem(state: ReadState) -> ReadState:
[item.model_dump() for item in structured.root],
ensure_ascii=False
)
split_result_dict = []
for index, item in enumerate(json.loads(split_result)):
split_data = {
"id": f"Q{index+1}",
"id": f"Q{index + 1}",
"question": item['extended_question'],
"type": item['type'],
"reason": item['reason']
}
split_result_dict.append(split_data)
logger.info(f"Split_The_Problem: 成功生成 {len(structured.root) if structured.root else 0} 个分解项")
result = {
@@ -96,13 +99,13 @@ async def Split_The_Problem(state: ReadState) -> ReadState:
"original_query": content
}
}
except Exception as e:
logger.error(
f"Split_The_Problem failed: {e}",
exc_info=True
)
# 提供更详细的错误信息
error_details = {
"error_type": type(e).__name__,
@@ -110,9 +113,9 @@ async def Split_The_Problem(state: ReadState) -> ReadState:
"content_length": len(content),
"llm_model_id": memory_config.llm_model_id if memory_config else None
}
logger.error(f"Split_The_Problem error details: {error_details}")
# 创建默认的空结果
result = {
"context": json.dumps([], ensure_ascii=False),
@@ -126,17 +129,18 @@ async def Split_The_Problem(state: ReadState) -> ReadState:
"error": error_details
}
}
# 返回更新后的状态包含spit_context字段
return {"spit_data": result}
async def Problem_Extension(state: ReadState) -> ReadState:
"""问题扩展节点"""
# 获取原始数据和分解结果
start = time.time()
content = state.get('data', '')
data = state.get('spit_data', '')['context']
group_id = state.get('group_id', '')
end_user_id = state.get('end_user_id', '')
storage_type = state.get('storage_type', '')
user_rag_memory_id = state.get('user_rag_memory_id', '')
memory_config = state.get('memory_config', None)
@@ -152,11 +156,11 @@ async def Problem_Extension(state: ReadState) -> ReadState:
databasets = {}
data = []
history = await SessionService(store).get_history(group_id, group_id, group_id)
history = await SessionService(store).get_history(end_user_id, end_user_id, end_user_id)
# 生成 JSON schema 以指导 LLM 输出正确格式
json_schema = ProblemExtensionResponse.model_json_schema()
system_prompt = await problem_service.template_service.render_template(
template_name='Problem_Extension_prompt.jinja2',
operation_name='problem_extension',
@@ -242,7 +246,4 @@ async def Problem_Extension(state: ReadState) -> ReadState:
}
}
return {"problem_extension": result}
return {"problem_extension": result}

View File

@@ -52,9 +52,9 @@ async def rag_config(state):
return kb_config
async def rag_knowledge(state,question):
kb_config = await rag_config(state)
group_id = state.get('group_id', '')
end_user_id = state.get('end_user_id', '')
user_rag_memory_id=state.get("user_rag_memory_id",'')
retrieve_chunks_result = knowledge_retrieval(question, kb_config, [str(group_id)])
retrieve_chunks_result = knowledge_retrieval(question, kb_config, [str(end_user_id)])
try:
retrieval_knowledge = [i.page_content for i in retrieve_chunks_result]
clean_content = '\n\n'.join(retrieval_knowledge)
@@ -159,7 +159,7 @@ async def retrieve_nodes(state: ReadState) -> ReadState:
problem_extension=state.get('problem_extension', '')['context']
storage_type=state.get('storage_type', '')
user_rag_memory_id=state.get('user_rag_memory_id', '')
group_id=state.get('group_id', '')
end_user_id=state.get('end_user_id', '')
memory_config = state.get('memory_config', None)
original=state.get('data', '')
problem_list=[]
@@ -172,7 +172,7 @@ async def retrieve_nodes(state: ReadState) -> ReadState:
try:
# Prepare search parameters based on storage type
search_params = {
"group_id": group_id,
"end_user_id": end_user_id,
"question": question,
"return_raw_results": True
}
@@ -263,13 +263,13 @@ async def retrieve_nodes(state: ReadState) -> ReadState:
async def retrieve(state: ReadState) -> ReadState:
# 从state中获取group_id
# 从state中获取end_user_id
import time
start=time.time()
problem_extension = state.get('problem_extension', '')['context']
storage_type = state.get('storage_type', '')
user_rag_memory_id = state.get('user_rag_memory_id', '')
group_id = state.get('group_id', '')
end_user_id = state.get('end_user_id', '')
memory_config = state.get('memory_config', None)
original = state.get('data', '')
problem_list = []
@@ -295,13 +295,13 @@ async def retrieve(state: ReadState) -> ReadState:
temperature=0.2,
)
time_retrieval_tool = create_time_retrieval_tool(group_id)
search_params = { "group_id": group_id, "return_raw_results": True }
time_retrieval_tool = create_time_retrieval_tool(end_user_id)
search_params = { "end_user_id": end_user_id, "return_raw_results": True }
hybrid_retrieval=create_hybrid_retrieval_tool_sync(memory_config, **search_params)
agent = create_agent(
llm,
tools=[time_retrieval_tool,hybrid_retrieval],
system_prompt=f"我是检索专家,可以根据适合的工具进行检索。当前使用的group_id是: {group_id}"
system_prompt=f"我是检索专家,可以根据适合的工具进行检索。当前使用的end_user_id是: {end_user_id}"
)
# 创建异步任务处理单个问题

View File

@@ -4,12 +4,11 @@ import os
import time
from app.core.logging_config import get_agent_logger, log_time
from app.db import get_db
from app.core.memory.agent.models.summary_models import (
RetrieveSummaryResponse,
SummaryResponse,
)
from app.core.memory.agent.services.optimized_llm_service import LLMServiceMixin
from app.core.memory.agent.services.search_service import SearchService
from app.core.memory.agent.utils.llm_tools import (
PROJECT_ROOT_,
@@ -18,7 +17,7 @@ from app.core.memory.agent.utils.llm_tools import (
from app.core.memory.agent.utils.redis_tool import store
from app.core.memory.agent.utils.session_tools import SessionService
from app.core.memory.agent.utils.template_tools import TemplateService
from app.core.memory.agent.services.optimized_llm_service import LLMServiceMixin
from app.db import get_db
template_root = os.path.join(PROJECT_ROOT_, 'agent', 'utils', 'prompt')
logger = get_agent_logger(__name__)
@@ -35,8 +34,8 @@ class SummaryNodeService(LLMServiceMixin):
summary_service = SummaryNodeService()
async def summary_history(state: ReadState) -> ReadState:
group_id = state.get("group_id", '')
history = await SessionService(store).get_history(group_id, group_id, group_id)
end_user_id = state.get("end_user_id", '')
history = await SessionService(store).get_history(end_user_id, end_user_id, end_user_id)
return history
async def summary_llm(state: ReadState, history, retrieve_info, template_name, operation_name, response_model,search_mode) -> str:
@@ -123,12 +122,12 @@ async def summary_llm(state: ReadState, history, retrieve_info, template_name, o
async def summary_redis_save(state: ReadState,aimessages) -> ReadState:
data = state.get("data", '')
group_id = state.get("group_id", '')
end_user_id = state.get("end_user_id", '')
await SessionService(store).save_session(
user_id=group_id,
user_id=end_user_id,
query=data,
apply_id=group_id,
group_id=group_id,
apply_id=end_user_id,
end_user_id=end_user_id,
ai_response=aimessages
)
await SessionService(store).cleanup_duplicates()
@@ -176,13 +175,14 @@ async def Input_Summary(state: ReadState) -> ReadState:
memory_config = state.get('memory_config', None)
user_rag_memory_id=state.get("user_rag_memory_id",'')
data=state.get("data", '')
group_id=state.get("group_id", '')
end_user_id=state.get("end_user_id", '')
logger.info(f"Input_Summary: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}")
history = await summary_history( state)
search_params = {
"group_id": group_id,
"end_user_id": end_user_id,
"question": data,
"return_raw_results": True
"return_raw_results": True,
"include": ["summaries"] # Only search summary nodes for faster performance
}
try:

View File

@@ -62,12 +62,12 @@ async def Verify(state: ReadState):
logger.info("=== Verify 节点开始执行 ===")
try:
content = state.get('data', '')
group_id = state.get('group_id', '')
end_user_id = state.get('end_user_id', '')
memory_config = state.get('memory_config', None)
logger.info(f"Verify: content={content[:50] if content else 'empty'}..., group_id={group_id}")
logger.info(f"Verify: content={content[:50] if content else 'empty'}..., end_user_id={end_user_id}")
history = await SessionService(store).get_history(group_id, group_id, group_id)
history = await SessionService(store).get_history(end_user_id, end_user_id, end_user_id)
logger.info(f"Verify: 获取历史记录完成history length={len(history)}")
retrieve = state.get("retrieve", {})

View File

@@ -9,26 +9,21 @@ async def write_node(state: WriteState) -> WriteState:
Write data to the database/file system.
Args:
ctx: FastMCP context for dependency injection
content: Data content to write
user_id: User identifier
apply_id: Application identifier
group_id: Group identifier
end_user_id: End user identifier
memory_config: MemoryConfig object containing all configuration
Returns:
dict: Contains 'status', 'saved_to', and 'data' fields
"""
content=state.get('data','')
group_id=state.get('group_id','')
end_user_id=state.get('end_user_id','')
memory_config=state.get('memory_config', '')
try:
result=await write(
content=content,
user_id=group_id,
apply_id=group_id,
group_id=group_id,
end_user_id=end_user_id,
memory_config=memory_config,
messages=content, # 修复:使用正确的参数名 messages
)
logger.info(f"Write completed successfully! Config: {memory_config.config_name}")

View File

@@ -59,7 +59,6 @@ async def make_read_graph():
workflow.add_conditional_edges("Retrieve", Retrieve_continue)
workflow.add_edge("Retrieve_Summary", END)
workflow.add_conditional_edges("Verify", Verify_continue)
workflow.add_edge("Summary_fails", END)
workflow.add_edge("Summary", END)
@@ -80,7 +79,7 @@ async def make_read_graph():
async def main():
"""主函数 - 运行工作流"""
message = "昨天有什么好看的电影"
group_id = '88a459f5_text09' # 组ID
end_user_id = '88a459f5_text09' # 组ID
storage_type = 'neo4j' # 存储类型
search_switch = '1' # 搜索开关
user_rag_memory_id = 'wwwwwwww' # 用户RAG记忆ID
@@ -96,9 +95,9 @@ async def main():
start=time.time()
try:
async with make_read_graph() as graph:
config = {"configurable": {"thread_id": group_id}}
config = {"configurable": {"thread_id": end_user_id}}
# 初始状态 - 包含所有必要字段
initial_state = {"messages": [HumanMessage(content=message)] ,"search_switch":search_switch,"group_id":group_id
initial_state = {"messages": [HumanMessage(content=message)] ,"search_switch":search_switch,"end_user_id":end_user_id
,"storage_type":storage_type,"user_rag_memory_id":user_rag_memory_id,"memory_config":memory_config}
# 获取节点更新信息
_intermediate_outputs = []

View File

@@ -48,11 +48,11 @@ def extract_tool_message_content(response):
class TimeRetrievalInput(BaseModel):
"""时间检索工具的输入模式"""
context: str = Field(description="用户输入的查询内容")
group_id: str = Field(default="88a459f5_text09", description="组ID用于过滤搜索结果")
end_user_id: str = Field(default="88a459f5_text09", description="组ID用于过滤搜索结果")
def create_time_retrieval_tool(group_id: str):
def create_time_retrieval_tool(end_user_id: str):
"""
创建一个带有特定group_id的TimeRetrieval工具同步版本用于按时间范围搜索语句(Statements)
创建一个带有特定end_user_id的TimeRetrieval工具同步版本用于按时间范围搜索语句(Statements)
"""
def clean_temporal_result_fields(data):
@@ -93,26 +93,26 @@ def create_time_retrieval_tool(group_id: str):
return data
@tool
def TimeRetrievalWithGroupId(context: str, start_date: str = None, end_date: str = None, group_id_param: str = None, clean_output: bool = True) -> str:
def TimeRetrievalWithGroupId(context: str, start_date: str = None, end_date: str = None, end_user_id_param: str = None, clean_output: bool = True) -> str:
"""
优化的时间检索工具,只结合时间范围搜索(同步版本),自动过滤不需要的元数据字段
显式接收参数:
- context: 查询上下文内容
- start_date: 开始时间可选格式YYYY-MM-DD
- end_date: 结束时间可选格式YYYY-MM-DD
- group_id_param: 组ID可选用于覆盖默认组ID
- end_user_id_param: 组ID可选用于覆盖默认组ID
- clean_output: 是否清理输出中的元数据字段
-end_date 需要根据用户的描述获取结束的时间输出格式用strftime("%Y-%m-%d")
"""
async def _async_search():
# 使用传入的参数或默认值
actual_group_id = group_id_param or group_id
actual_end_user_id = end_user_id_param or end_user_id
actual_end_date = end_date or datetime.now().strftime("%Y-%m-%d")
actual_start_date = start_date or (datetime.now() - timedelta(days=7)).strftime("%Y-%m-%d")
# 基本时间搜索
results = await search_by_temporal(
group_id=actual_group_id,
end_user_id=actual_end_user_id,
start_date=actual_start_date,
end_date=actual_end_date,
limit=10
@@ -147,7 +147,7 @@ def create_time_retrieval_tool(group_id: str):
# 关键词时间搜索
results = await search_by_keyword_temporal(
query_text=context,
group_id=group_id,
end_user_id=end_user_id,
start_date=actual_start_date,
end_date=actual_end_date,
limit=15
@@ -172,7 +172,7 @@ def create_hybrid_retrieval_tool_async(memory_config, **search_params):
Args:
memory_config: 内存配置对象
**search_params: 搜索参数,包含group_id, limit, include等
**search_params: 搜索参数,包含end_user_id, limit, include等
"""
def clean_result_fields(data):
@@ -211,7 +211,7 @@ def create_hybrid_retrieval_tool_async(memory_config, **search_params):
context: str,
search_type: str = "hybrid",
limit: int = 10,
group_id: str = None,
end_user_id: str = None,
rerank_alpha: float = 0.6,
use_forgetting_rerank: bool = False,
use_llm_rerank: bool = False,
@@ -224,7 +224,7 @@ def create_hybrid_retrieval_tool_async(memory_config, **search_params):
context: 查询内容
search_type: 搜索类型 ('keyword', 'embedding', 'hybrid')
limit: 结果数量限制
group_id: 组ID用于过滤搜索结果
end_user_id: 组ID用于过滤搜索结果
rerank_alpha: 重排序权重参数
use_forgetting_rerank: 是否使用遗忘重排序
use_llm_rerank: 是否使用LLM重排序
@@ -238,7 +238,7 @@ def create_hybrid_retrieval_tool_async(memory_config, **search_params):
final_params = {
"query_text": context,
"search_type": search_type,
"group_id": group_id or search_params.get("group_id"),
"end_user_id": end_user_id or search_params.get("end_user_id"),
"limit": limit or search_params.get("limit", 10),
"include": search_params.get("include", ["summaries", "statements", "chunks", "entities"]),
"output_path": None, # 不保存到文件
@@ -291,7 +291,7 @@ def create_hybrid_retrieval_tool_sync(memory_config, **search_params):
context: str,
search_type: str = "hybrid",
limit: int = 10,
group_id: str = None,
end_user_id: str = None,
clean_output: bool = True
) -> str:
"""
@@ -301,7 +301,7 @@ def create_hybrid_retrieval_tool_sync(memory_config, **search_params):
context: 查询内容
search_type: 搜索类型 ('keyword', 'embedding', 'hybrid')
limit: 结果数量限制
group_id: 组ID用于过滤搜索结果
end_user_id: 组ID用于过滤搜索结果
clean_output: 是否清理输出中的元数据字段
"""
async def _async_search():
@@ -311,7 +311,7 @@ def create_hybrid_retrieval_tool_sync(memory_config, **search_params):
"context": context,
"search_type": search_type,
"limit": limit,
"group_id": group_id,
"end_user_id": end_user_id,
"clean_output": clean_output
})

View File

@@ -31,7 +31,7 @@ async def make_write_graph():
user_id: User identifier
tools: MCP tools loaded from session
apply_id: Application identifier
group_id: Group identifier
end_user_id: Group identifier
memory_config: MemoryConfig object containing all configuration
"""
workflow = StateGraph(WriteState)
@@ -49,7 +49,7 @@ async def make_write_graph():
async def main():
"""主函数 - 运行工作流"""
message = "今天周一"
group_id = 'new_2025test1103' # 组ID
end_user_id = 'new_2025test1103' # 组ID
# 获取数据库会话
@@ -61,9 +61,9 @@ async def main():
)
try:
async with make_write_graph() as graph:
config = {"configurable": {"thread_id": group_id}}
config = {"configurable": {"thread_id": end_user_id}}
# 初始状态 - 包含所有必要字段
initial_state = {"messages": [HumanMessage(content=message)], "group_id": group_id, "memory_config": memory_config}
initial_state = {"messages": [HumanMessage(content=message)], "end_user_id": end_user_id, "memory_config": memory_config}
# 获取节点更新信息
async for update_event in graph.astream(

View File

@@ -162,7 +162,7 @@ class OptimizedLLMService:
return fallback_value
elif isinstance(fallback_value, dict):
return response_model(**fallback_value)
# 尝试创建空的响应模型
if hasattr(response_model, 'root'):
# RootModel类型
@@ -170,7 +170,7 @@ class OptimizedLLMService:
else:
# 普通BaseModel类型
return response_model()
except Exception as e:
logger.error(f"创建降级响应失败: {e}")
# 最后的降级策略

View File

@@ -24,7 +24,7 @@ class ParameterBuilder:
tool_call_id: str,
search_switch: str,
apply_id: str,
group_id: str,
end_user_id: str,
storage_type: Optional[str] = None,
user_rag_memory_id: Optional[str] = None
) -> Dict[str, Any]:
@@ -44,7 +44,7 @@ class ParameterBuilder:
tool_call_id: Extracted tool call identifier
search_switch: Search routing parameter
apply_id: Application identifier
group_id: Group identifier
end_user_id: Group identifier
storage_type: Storage type for the workspace (optional)
user_rag_memory_id: User RAG memory ID for knowledge base retrieval (optional)
@@ -55,7 +55,7 @@ class ParameterBuilder:
base_args = {
"usermessages": tool_call_id,
"apply_id": apply_id,
"group_id": group_id
"end_user_id": end_user_id
}
# Always add storage_type and user_rag_memory_id (with defaults if None)

View File

@@ -91,7 +91,7 @@ class SearchService:
async def execute_hybrid_search(
self,
group_id: str,
end_user_id: str,
question: str,
limit: int = 5,
search_type: str = "hybrid",
@@ -105,7 +105,7 @@ class SearchService:
Execute hybrid search and return clean content.
Args:
group_id: Group identifier for filtering results
end_user_id: Group identifier for filtering results
question: Search query text
limit: Maximum number of results to return (default: 5)
search_type: Type of search - "hybrid", "keyword", or "embedding" (default: "hybrid")
@@ -130,7 +130,7 @@ class SearchService:
answer = await run_hybrid_search(
query_text=cleaned_query,
search_type=search_type,
group_id=group_id,
end_user_id=end_user_id,
limit=limit,
include=include,
output_path=output_path,
@@ -186,7 +186,7 @@ class SearchService:
except Exception as e:
logger.error(
f"Search failed for query '{question}' in group '{group_id}': {e}",
f"Search failed for query '{question}' in group '{end_user_id}': {e}",
exc_info=True
)
# Return empty results on failure

View File

@@ -59,7 +59,7 @@ class SessionService:
self,
user_id: str,
apply_id: str,
group_id: str
end_user_id: str
) -> List[dict]:
"""
Retrieve conversation history from Redis.
@@ -67,20 +67,20 @@ class SessionService:
Args:
user_id: User identifier
apply_id: Application identifier
group_id: Group identifier
end_user_id: Group identifier
Returns:
List of conversation history items with Query and Answer keys
Returns empty list if no history found or on error
"""
try:
history = self.store.find_user_apply_group(user_id, apply_id, group_id)
history = self.store.find_user_apply_group(user_id, apply_id, end_user_id)
# Validate history structure
if not isinstance(history, list):
logger.warning(
f"Invalid history format for user {user_id}, "
f"apply {apply_id}, group {group_id}: expected list, got {type(history)}"
f"apply {apply_id}, group {end_user_id}: expected list, got {type(history)}"
)
return []
@@ -89,7 +89,7 @@ class SessionService:
except Exception as e:
logger.error(
f"Failed to retrieve history for user {user_id}, "
f"apply {apply_id}, group {group_id}: {e}",
f"apply {apply_id}, group {end_user_id}: {e}",
exc_info=True
)
# Return empty list on error to allow execution to continue
@@ -100,7 +100,7 @@ class SessionService:
user_id: str,
query: str,
apply_id: str,
group_id: str,
end_user_id: str,
ai_response: str
) -> Optional[str]:
"""
@@ -110,7 +110,7 @@ class SessionService:
user_id: User identifier
query: User query/message
apply_id: Application identifier
group_id: Group identifier
end_user_id: Group identifier
ai_response: AI response/answer
Returns:
@@ -131,7 +131,7 @@ class SessionService:
userid=user_id,
messages=query,
apply_id=apply_id,
group_id=group_id,
end_user_id=end_user_id,
aimessages=ai_response
)
@@ -152,7 +152,7 @@ class SessionService:
Duplicates are identified by matching:
- sessionid
- user_id (id field)
- group_id
- end_user_id
- messages
- aimessages

View File

@@ -9,9 +9,7 @@ from app.core.memory.models.message_models import DialogData, ConversationContex
async def get_chunked_dialogs(
chunker_strategy: str = "RecursiveChunker",
group_id: str = "group_1",
user_id: str = "user1",
apply_id: str = "applyid",
end_user_id: str = "group_1",
content: str = "这是用户的输入",
ref_id: str = "wyl_20251027",
config_id: str = None
@@ -20,9 +18,7 @@ async def get_chunked_dialogs(
Args:
chunker_strategy: The chunking strategy to use (default: RecursiveChunker)
group_id: Group identifier
user_id: User identifier
apply_id: Application identifier
end_user_id: End user identifier
content: Dialog content
ref_id: Reference identifier
config_id: Configuration ID for processing
@@ -37,13 +33,11 @@ async def get_chunked_dialogs(
# Create DialogData
conversation_context = ConversationContext(msgs=messages)
# Create DialogData with group_id based on the entry's id for uniqueness
# Create DialogData with end_user_id
dialog_data = DialogData(
context=conversation_context,
ref_id=ref_id,
group_id=group_id,
user_id=user_id,
apply_id=apply_id,
end_user_id=end_user_id,
config_id=config_id
)
# Create DialogueChunker and process the dialogue

View File

@@ -12,13 +12,11 @@ class WriteState(TypedDict):
Langgrapg Writing TypedDict
'''
messages: Annotated[list[AnyMessage], add_messages]
user_id:str
apply_id:str
group_id:str
end_user_id: str
errors: list[dict] # Track errors: [{"tool": "tool_name", "error": "message"}]
memory_config: object
write_result: dict
data:str
data: str
class ReadState(TypedDict):
"""
@@ -28,7 +26,7 @@ class ReadState(TypedDict):
messages: 消息列表,支持自动追加
loop_count: 遍历次数
search_switch: 搜索类型开关
group_id: 组标识
end_user_id: 组标识
config_id: 配置ID用于过滤结果
data: 从content_input_node传递的内容数据
spit_data: 从Split_The_Problem传递的分解结果
@@ -39,7 +37,7 @@ class ReadState(TypedDict):
messages: Annotated[list[AnyMessage], add_messages] # 消息追加模式
loop_count: int
search_switch: str
group_id: str
end_user_id: str
config_id: str
data: str # 新增字段用于传递内容
spit_data: dict # 新增字段用于传递问题分解结果

View File

@@ -28,7 +28,7 @@ class RedisSessionStore:
return text
# 修改后的 save_session 方法
def save_session(self, userid, messages, aimessages, apply_id, group_id):
def save_session(self, userid, messages, aimessages, apply_id, end_user_id):
"""
写入一条会话数据,返回 session_id
优化版本确保写入时间不超过1秒
@@ -46,7 +46,7 @@ class RedisSessionStore:
"id": self.uudi,
"sessionid": userid,
"apply_id": apply_id,
"group_id": group_id,
"end_user_id": end_user_id,
"messages": messages,
"aimessages": aimessages,
"starttime": starttime
@@ -67,7 +67,7 @@ class RedisSessionStore:
def save_sessions_batch(self, sessions_data):
"""
批量写入多条会话数据,返回 session_id 列表
sessions_data: list of dict, 每个 dict 包含 userid, messages, aimessages, apply_id, group_id
sessions_data: list of dict, 每个 dict 包含 userid, messages, aimessages, apply_id, end_user_id
优化版本:批量操作,大幅提升性能
"""
try:
@@ -83,7 +83,7 @@ class RedisSessionStore:
"id": self.uudi,
"sessionid": session.get('userid'),
"apply_id": session.get('apply_id'),
"group_id": session.get('group_id'),
"end_user_id": session.get('end_user_id'),
"messages": session.get('messages'),
"aimessages": session.get('aimessages'),
"starttime": starttime
@@ -108,9 +108,9 @@ class RedisSessionStore:
data = self.r.hgetall(key)
return data if data else None
def get_session_apply_group(self, sessionid, apply_id, group_id):
def get_session_apply_group(self, sessionid, apply_id, end_user_id):
"""
根据 sessionid、apply_id 和 group_id 三个条件查询会话数据
根据 sessionid、apply_id 和 end_user_id 三个条件查询会话数据
"""
result_items = []
@@ -124,7 +124,7 @@ class RedisSessionStore:
# 检查三个条件是否都匹配
if (data.get('sessionid') == sessionid and
data.get('apply_id') == apply_id and
data.get('group_id') == group_id):
data.get('end_user_id') == end_user_id):
result_items.append(data)
return result_items
@@ -172,7 +172,7 @@ class RedisSessionStore:
def delete_duplicate_sessions(self):
"""
删除重复会话数据,条件:
"sessionid""user_id""group_id""messages""aimessages" 五个字段都相同的只保留一个,其他删除
"sessionid""user_id""end_user_id""messages""aimessages" 五个字段都相同的只保留一个,其他删除
优化版本:使用 pipeline 批量操作确保在1秒内完成
"""
import time
@@ -202,12 +202,12 @@ class RedisSessionStore:
# 获取五个字段的值
sessionid = data.get('sessionid', '')
user_id = data.get('id', '')
group_id = data.get('group_id', '')
end_user_id = data.get('end_user_id', '')
messages = data.get('messages', '')
aimessages = data.get('aimessages', '')
# 用五元组作为唯一标识
identifier = (sessionid, user_id, group_id, messages, aimessages)
identifier = (sessionid, user_id, end_user_id, messages, aimessages)
if identifier in seen:
# 重复,标记为待删除
@@ -248,9 +248,9 @@ class RedisSessionStore:
result_items = []
return (result_items)
def find_user_apply_group(self, sessionid, apply_id, group_id):
def find_user_apply_group(self, sessionid, apply_id, end_user_id):
"""
根据 sessionid、apply_id 和 group_id 三个条件查询会话数据返回最新的6条
根据 sessionid、apply_id 和 end_user_id 三个条件查询会话数据返回最新的6条
"""
import time
start_time = time.time()
@@ -276,7 +276,7 @@ class RedisSessionStore:
# 检查是否符合三个条件
if (data.get('apply_id') == apply_id and
data.get('group_id') == group_id):
data.get('end_user_id') == end_user_id):
# 支持模糊匹配 sessionid 或者完全匹配
if sessionid in data.get('sessionid', '') or data.get('sessionid') == sessionid:
matched_items.append({

View File

@@ -59,7 +59,7 @@ class SessionService:
self,
user_id: str,
apply_id: str,
group_id: str
end_user_id: str
) -> List[dict]:
"""
Retrieve conversation history from Redis.
@@ -67,20 +67,20 @@ class SessionService:
Args:
user_id: User identifier
apply_id: Application identifier
group_id: Group identifier
end_user_id: Group identifier
Returns:
List of conversation history items with Query and Answer keys
Returns empty list if no history found or on error
"""
try:
history = self.store.find_user_apply_group(user_id, apply_id, group_id)
history = self.store.find_user_apply_group(user_id, apply_id, end_user_id)
# Validate history structure
if not isinstance(history, list):
logger.warning(
f"Invalid history format for user {user_id}, "
f"apply {apply_id}, group {group_id}: expected list, got {type(history)}"
f"apply {apply_id}, group {end_user_id}: expected list, got {type(history)}"
)
return []
@@ -89,7 +89,7 @@ class SessionService:
except Exception as e:
logger.error(
f"Failed to retrieve history for user {user_id}, "
f"apply {apply_id}, group {group_id}: {e}",
f"apply {apply_id}, group {end_user_id}: {e}",
exc_info=True
)
# Return empty list on error to allow execution to continue
@@ -100,7 +100,7 @@ class SessionService:
user_id: str,
query: str,
apply_id: str,
group_id: str,
end_user_id: str,
ai_response: str
) -> Optional[str]:
"""
@@ -110,7 +110,7 @@ class SessionService:
user_id: User identifier
query: User query/message
apply_id: Application identifier
group_id: Group identifier
end_user_id: Group identifier
ai_response: AI response/answer
Returns:
@@ -131,7 +131,7 @@ class SessionService:
userid=user_id,
messages=query,
apply_id=apply_id,
group_id=group_id,
end_user_id=end_user_id,
aimessages=ai_response
)
@@ -152,7 +152,7 @@ class SessionService:
Duplicates are identified by matching:
- sessionid
- user_id (id field)
- group_id
- end_user_id
- messages
- aimessages

View File

@@ -29,25 +29,18 @@ logger = get_agent_logger(__name__)
async def write(
content: str,
user_id: str,
apply_id: str,
group_id: str,
end_user_id: str,
memory_config: MemoryConfig,
messages: list,
ref_id: str = "wyl20251027",
) -> None:
"""
Execute the complete knowledge extraction pipeline.
Only MemoryConfig is needed - LLM and embedding clients are constructed
internally from the config.
Args:
content: Dialogue content to process
user_id: User identifier
apply_id: Application identifier
group_id: Group identifier
end_user_id: End user identifier
memory_config: MemoryConfig object containing all configuration
messages: Structured message list [{"role": "user", "content": "..."}, ...]
ref_id: Reference ID, defaults to "wyl20251027"
"""
# Extract config values
@@ -61,7 +54,7 @@ async def write(
logger.info(f"LLM model: {memory_config.llm_model_name}")
logger.info(f"Embedding model: {memory_config.embedding_model_name}")
logger.info(f"Chunker strategy: {chunker_strategy}")
logger.info(f"Group ID: {group_id}")
logger.info(f"End User ID: {end_user_id}")
# Construct clients from memory_config using factory pattern with db session
with get_db_context() as db:
@@ -84,12 +77,25 @@ async def write(
# Step 1: Load and chunk data
step_start = time.time()
# Convert messages list to content string
# messages format: [{"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}, ...]
if isinstance(messages, list) and len(messages) > 0:
# Extract content from the last user message or concatenate all messages
if isinstance(messages[-1], dict) and 'content' in messages[-1]:
content = messages[-1]['content']
else:
# Fallback: concatenate all message contents
content = " ".join([msg.get('content', '') for msg in messages if isinstance(msg, dict)])
elif isinstance(messages, str):
content = messages
else:
content = str(messages)
chunked_dialogs = await get_chunked_dialogs(
chunker_strategy=chunker_strategy,
group_id=group_id,
user_id=user_id,
apply_id=apply_id,
content=content,
end_user_id=end_user_id,
content=content, # 修复:使用 content 参数而不是 messages
ref_id=ref_id,
config_id=config_id,
)

View File

@@ -16,13 +16,13 @@ class FilteredTags(BaseModel):
"""用于接收LLM筛选后的核心标签列表的模型。"""
meaningful_tags: List[str] = Field(..., description="从原始列表中筛选出的具有核心代表意义的名词列表。")
async def filter_tags_with_llm(tags: List[str], group_id: str) -> List[str]:
async def filter_tags_with_llm(tags: List[str], end_user_id: str) -> List[str]:
"""
使用LLM筛选标签列表仅保留具有代表性的核心名词。
Args:
tags: 原始标签列表
group_id: 用户组ID用于获取配置
end_user_id: 用户组ID用于获取配置
Returns:
筛选后的标签列表
@@ -37,12 +37,12 @@ async def filter_tags_with_llm(tags: List[str], group_id: str) -> List[str]:
get_end_user_connected_config,
)
connected_config = get_end_user_connected_config(group_id, db)
connected_config = get_end_user_connected_config(end_user_id, db)
config_id = connected_config.get("memory_config_id")
if not config_id:
raise ValueError(
f"No memory_config_id found for group_id: {group_id}. "
f"No memory_config_id found for end_user_id: {end_user_id}. "
"Please ensure the user has a valid memory configuration."
)
@@ -87,7 +87,7 @@ async def filter_tags_with_llm(tags: List[str], group_id: str) -> List[str]:
async def get_raw_tags_from_db(
connector: Neo4jConnector,
group_id: str,
end_user_id: str,
limit: int,
by_user: bool = False
) -> List[Tuple[str, int]]:
@@ -99,9 +99,9 @@ async def get_raw_tags_from_db(
Args:
connector: Neo4j连接器实例
group_id: 如果by_user=False则为group_id如果by_user=True则为user_id
end_user_id: 如果by_user=False则为end_user_id如果by_user=True则为user_id
limit: 返回的标签数量限制
by_user: 是否按user_id查询默认Falsegroup_id查询
by_user: 是否按user_id查询默认Falseend_user_id查询
Returns:
List[Tuple[str, int]]: 标签名称和频率的元组列表
@@ -119,7 +119,7 @@ async def get_raw_tags_from_db(
else:
query = (
"MATCH (e:ExtractedEntity) "
"WHERE e.group_id = $id AND e.entity_type <> '人物' AND e.name IS NOT NULL AND NOT e.name IN $names_to_exclude "
"WHERE e.end_user_id = $id AND e.entity_type <> '人物' AND e.name IS NOT NULL AND NOT e.name IN $names_to_exclude "
"RETURN e.name AS name, count(e) AS frequency "
"ORDER BY frequency DESC "
"LIMIT $limit"
@@ -128,44 +128,44 @@ async def get_raw_tags_from_db(
# 使用项目的Neo4jConnector执行查询
results = await connector.execute_query(
query,
id=group_id,
id=end_user_id,
limit=limit,
names_to_exclude=names_to_exclude
)
return [(record["name"], record["frequency"]) for record in results]
async def get_hot_memory_tags(group_id: str, limit: int = 40, by_user: bool = False) -> List[Tuple[str, int]]:
async def get_hot_memory_tags(end_user_id: str, limit: int = 40, by_user: bool = False) -> List[Tuple[str, int]]:
"""
获取原始标签然后使用LLM进行筛选返回最终的热门标签列表。
查询更多的标签(limit=40)给LLM提供更丰富的上下文进行筛选。
Args:
group_id: 必需参数。如果by_user=False则为group_id如果by_user=True则为user_id
end_user_id: 必需参数。如果by_user=False则为end_user_id如果by_user=True则为user_id
limit: 返回的标签数量限制
by_user: 是否按user_id查询默认Falsegroup_id查询
by_user: 是否按user_id查询默认Falseend_user_id查询
Raises:
ValueError: 如果group_id未提供或为空
ValueError: 如果end_user_id未提供或为空
"""
# 验证group_id必须提供且不为空
if not group_id or not group_id.strip():
# 验证end_user_id必须提供且不为空
if not end_user_id or not end_user_id.strip():
raise ValueError(
"group_id is required. Please provide a valid group_id or user_id."
"end_user_id is required. Please provide a valid end_user_id or user_id."
)
# 使用项目的Neo4jConnector
connector = Neo4jConnector()
try:
# 1. 从数据库获取原始排名靠前的标签
raw_tags_with_freq = await get_raw_tags_from_db(connector, group_id, limit, by_user=by_user)
raw_tags_with_freq = await get_raw_tags_from_db(connector, end_user_id, limit, by_user=by_user)
if not raw_tags_with_freq:
return []
raw_tag_names = [tag for tag, freq in raw_tags_with_freq]
# 2. 初始化LLM客户端并使用LLM筛选出有意义的标签
meaningful_tag_names = await filter_tags_with_llm(raw_tag_names, group_id)
meaningful_tag_names = await filter_tags_with_llm(raw_tag_names, end_user_id)
# 3. 根据LLM的筛选结果构建最终的标签列表保留原始频率和顺序
final_tags = []

View File

@@ -75,8 +75,8 @@ class MemoryDataSource:
start_date = time_range.start_date if time_range else None
end_date = time_range.end_date if time_range else None
summary_dicts = await self.memory_summary_repo.find_by_group_id(
group_id=user_id,
summary_dicts = await self.memory_summary_repo.find_by_end_user_id(
end_user_id=user_id,
limit=limit,
start_date=start_date,
end_date=end_date

View File

@@ -41,7 +41,7 @@ DIALOGUE_EMBEDDING_SEARCH = """
WITH $embedding AS q
MATCH (d:Dialogue)
WHERE d.dialog_embedding IS NOT NULL
AND ($group_id IS NULL OR d.group_id = $group_id)
AND ($end_user_id IS NULL OR d.end_user_id = $end_user_id)
WITH d, q, d.dialog_embedding AS v
WITH d,
reduce(dot = 0.0, i IN range(0, size(q)-1) | dot + toFloat(q[i]) * toFloat(v[i])) AS dot,
@@ -50,7 +50,7 @@ WITH d,
WITH d, CASE WHEN qnorm = 0 OR vnorm = 0 THEN 0.0 ELSE dot / (qnorm * vnorm) END AS score
WHERE score > $threshold
RETURN d.id AS dialog_id,
d.group_id AS group_id,
d.end_user_id AS end_user_id,
d.content AS content,
d.created_at AS created_at,
d.expired_at AS expired_at,

View File

@@ -36,7 +36,7 @@ from app.repositories.neo4j.neo4j_connector import Neo4jConnector
async def ingest_contexts_via_full_pipeline(
contexts: List[str],
group_id: str,
end_user_id: str,
chunker_strategy: str | None = None,
embedding_name: str | None = None,
save_chunk_output: bool = False,
@@ -48,7 +48,7 @@ async def ingest_contexts_via_full_pipeline(
This function mirrors the steps in main(), but starts from raw text contexts.
Args:
contexts: List of dialogue texts, each containing lines like "role: message".
group_id: Group ID to assign to generated DialogData and graph nodes.
end_user_id: Group ID to assign to generated DialogData and graph nodes.
chunker_strategy: Optional chunker strategy; defaults to SELECTED_CHUNKER_STRATEGY.
embedding_name: Optional embedding model ID; defaults to SELECTED_EMBEDDING_ID.
save_chunk_output: If True, write chunked DialogData list to a JSON file for debugging.
@@ -109,7 +109,7 @@ async def ingest_contexts_via_full_pipeline(
dialog = DialogData(
context=context_model,
ref_id=f"pipeline_item_{idx}",
group_id=group_id,
end_user_id=end_user_id,
user_id="default_user",
apply_id="default_application",
)
@@ -318,16 +318,16 @@ async def handle_context_processing(args):
print("No contexts provided for processing.")
return False
return await main_from_contexts(contexts, args.context_group_id)
return await main_from_contexts(contexts, args.context_end_user_id)
async def main_from_contexts(contexts: List[str], group_id: str):
async def main_from_contexts(contexts: List[str], end_user_id: str):
"""Run the pipeline from provided dialogue contexts instead of test data."""
print("=== Running pipeline from provided contexts ===")
success = await ingest_contexts_via_full_pipeline(
contexts=contexts,
group_id=group_id,
end_user_id=end_user_id,
chunker_strategy=SELECTED_CHUNKER_STRATEGY,
embedding_name=SELECTED_EMBEDDING_ID,
save_chunk_output=True

View File

@@ -47,7 +47,7 @@ from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
from app.core.memory.utils.definitions import (
PROJECT_ROOT,
SELECTED_EMBEDDING_ID,
SELECTED_GROUP_ID,
SELECTED_end_user_id,
SELECTED_LLM_ID,
)
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
@@ -59,7 +59,7 @@ from app.services.memory_config_service import MemoryConfigService
async def run_locomo_benchmark(
sample_size: int = 20,
group_id: Optional[str] = None,
end_user_id: Optional[str] = None,
search_type: str = "hybrid",
search_limit: int = 12,
context_char_budget: int = 8000,
@@ -85,7 +85,7 @@ async def run_locomo_benchmark(
Args:
sample_size: Number of QA pairs to evaluate (from first conversation)
group_id: Database group ID for retrieval (uses default if None)
end_user_id: Database group ID for retrieval (uses default if None)
search_type: "keyword", "embedding", or "hybrid"
search_limit: Max documents to retrieve per query
context_char_budget: Max characters for context
@@ -96,8 +96,8 @@ async def run_locomo_benchmark(
Returns:
Dictionary with evaluation results including metrics, timing, and samples
"""
# Use default group_id if not provided
group_id = group_id or SELECTED_GROUP_ID
# Use default end_user_id if not provided
end_user_id = end_user_id or SELECTED_end_user_id
# Determine data path
data_path = os.path.join(PROJECT_ROOT, "data", "locomo10.json")
@@ -110,7 +110,7 @@ async def run_locomo_benchmark(
print(f"{'='*60}")
print("📊 Configuration:")
print(f" Sample size: {sample_size}")
print(f" Group ID: {group_id}")
print(f" Group ID: {end_user_id}")
print(f" Search type: {search_type}")
print(f" Search limit: {search_limit}")
print(f" Context budget: {context_char_budget} chars")
@@ -134,7 +134,7 @@ async def run_locomo_benchmark(
# Step 2: Extract conversations and ingest if needed
if skip_ingest:
print("⏭️ Skipping data ingestion (using existing data in Neo4j)")
print(f" Group ID: {group_id}\n")
print(f" Group ID: {end_user_id}\n")
else:
print("💾 Checking database ingestion...")
try:
@@ -142,10 +142,10 @@ async def run_locomo_benchmark(
print(f"📝 Extracted {len(conversations)} conversations")
# Always ingest for now (ingestion check not implemented)
print(f"🔄 Ingesting conversations into group '{group_id}'...")
print(f"🔄 Ingesting conversations into group '{end_user_id}'...")
success = await ingest_conversations_if_needed(
conversations=conversations,
group_id=group_id,
end_user_id=end_user_id,
reset=reset_group
)
@@ -224,7 +224,7 @@ async def run_locomo_benchmark(
try:
retrieved_info = await retrieve_relevant_information(
question=question,
group_id=group_id,
end_user_id=end_user_id,
search_type=search_type,
search_limit=search_limit,
connector=connector,
@@ -409,7 +409,7 @@ async def run_locomo_benchmark(
"sample_size": len(qa_items),
"timestamp": datetime.now().isoformat(),
"params": {
"group_id": group_id,
"end_user_id": end_user_id,
"search_type": search_type,
"search_limit": search_limit,
"context_char_budget": context_char_budget,
@@ -467,7 +467,7 @@ def main():
help="Number of QA pairs to evaluate"
)
parser.add_argument(
"--group_id",
"--end_user_id",
type=str,
default=None,
help="Database group ID for retrieval (uses default if not specified)"
@@ -516,7 +516,7 @@ def main():
# Run benchmark
result = asyncio.run(run_locomo_benchmark(
sample_size=args.sample_size,
group_id=args.group_id,
end_user_id=args.end_user_id,
search_type=args.search_type,
search_limit=args.search_limit,
context_char_budget=args.context_char_budget,

View File

@@ -555,7 +555,7 @@ async def run_enhanced_evaluation():
search_results = await run_hybrid_search(
query_text=q,
search_type="hybrid",
group_id="locomo_sk",
end_user_id="locomo_sk",
limit=20,
include=["statements", "chunks", "entities", "summaries"],
alpha=0.6, # BM25权重

View File

@@ -348,7 +348,7 @@ def select_and_format_information(
async def retrieve_relevant_information(
question: str,
group_id: str,
end_user_id: str,
search_type: str,
search_limit: int,
connector: Any,
@@ -368,7 +368,7 @@ async def retrieve_relevant_information(
Args:
question: Question to search for
group_id: Database group ID (identifies which conversation memory to search)
end_user_id: Database group ID (identifies which conversation memory to search)
search_type: "keyword", "embedding", or "hybrid"
search_limit: Max memory pieces to retrieve
connector: Neo4j connector instance
@@ -396,7 +396,7 @@ async def retrieve_relevant_information(
connector=connector,
embedder_client=embedder,
query_text=question,
group_id=group_id,
end_user_id=end_user_id,
limit=search_limit,
include=["chunks", "statements", "entities", "summaries"],
)
@@ -455,7 +455,7 @@ async def retrieve_relevant_information(
search_results = await search_graph(
connector=connector,
q=question,
group_id=group_id,
end_user_id=end_user_id,
limit=search_limit
)
@@ -491,7 +491,7 @@ async def retrieve_relevant_information(
search_results = await run_hybrid_search(
query_text=question,
search_type=search_type,
group_id=group_id,
end_user_id=end_user_id,
limit=search_limit,
include=["chunks", "statements", "entities", "summaries"],
output_path=None,
@@ -524,7 +524,7 @@ async def retrieve_relevant_information(
connector=connector,
embedder_client=embedder,
query_text=question,
group_id=group_id,
end_user_id=end_user_id,
limit=search_limit,
include=["chunks", "statements", "entities", "summaries"],
)
@@ -584,7 +584,7 @@ async def retrieve_relevant_information(
async def ingest_conversations_if_needed(
conversations: List[str],
group_id: str,
end_user_id: str,
reset: bool = False
) -> bool:
"""
@@ -603,7 +603,7 @@ async def ingest_conversations_if_needed(
Args:
conversations: List of raw conversation texts from LoCoMo dataset
Example: ["User: I went to Paris. AI: When was that?", ...]
group_id: Target group ID for database storage
end_user_id: Target group ID for database storage
reset: Whether to clear existing data first (not implemented in wrapper)
Returns:
@@ -617,7 +617,7 @@ async def ingest_conversations_if_needed(
try:
success = await ingest_contexts_via_full_pipeline(
contexts=conversations,
group_id=group_id,
end_user_id=end_user_id,
save_chunk_output=True
)
return success

View File

@@ -30,7 +30,7 @@ from app.core.memory.storage_services.search import run_hybrid_search
from app.core.memory.utils.config.definitions import (
PROJECT_ROOT,
SELECTED_EMBEDDING_ID,
SELECTED_GROUP_ID,
SELECTED_end_user_id,
SELECTED_LLM_ID,
)
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
@@ -249,7 +249,7 @@ def get_search_params_by_category(category: str):
async def run_locomo_eval(
sample_size: int = 1,
group_id: str | None = None,
end_user_id: str | None = None,
search_limit: int = 8,
context_char_budget: int = 4000, # 保持默认值不变
llm_temperature: float = 0.0,
@@ -262,7 +262,7 @@ async def run_locomo_eval(
) -> Dict[str, Any]:
# 函数内部使用三路检索逻辑,但保持参数签名不变
group_id = group_id or SELECTED_GROUP_ID
end_user_id = end_user_id or SELECTED_end_user_id
data_path = os.path.join(PROJECT_ROOT, "data", "locomo10.json")
if not os.path.exists(data_path):
data_path = os.path.join(os.getcwd(), "data", "locomo10.json")
@@ -340,7 +340,7 @@ async def run_locomo_eval(
# 关键修复:强制重新摄入纯净的对话数据
print("🔄 强制重新摄入纯净的对话数据...")
await ingest_contexts_via_full_pipeline(contents, group_id, save_chunk_output=True)
await ingest_contexts_via_full_pipeline(contents, end_user_id, save_chunk_output=True)
# 使用异步LLM客户端
with get_db_context() as db:
@@ -405,7 +405,7 @@ async def run_locomo_eval(
connector=connector,
embedder_client=embedder,
query_text=q,
group_id=group_id,
end_user_id=end_user_id,
limit=adjusted_limit,
include=["chunks", "statements", "entities", "summaries"], # 修复:使用正确的类型
)
@@ -456,7 +456,7 @@ async def run_locomo_eval(
search_results = await search_graph(
connector=connector,
q=q,
group_id=group_id,
end_user_id=end_user_id,
limit=adjusted_limit
)
dialogs = search_results.get("dialogues", [])
@@ -486,7 +486,7 @@ async def run_locomo_eval(
search_results = await run_hybrid_search(
query_text=q,
search_type=search_type,
group_id=group_id,
end_user_id=end_user_id,
limit=adjusted_limit,
include=["chunks", "statements", "entities", "summaries"],
output_path=None,
@@ -524,7 +524,7 @@ async def run_locomo_eval(
connector=connector,
embedder_client=embedder,
query_text=q,
group_id=group_id,
end_user_id=end_user_id,
limit=adjusted_limit,
include=["chunks", "statements", "entities", "summaries"],
)
@@ -597,7 +597,7 @@ async def run_locomo_eval(
"dialogues": [
{
"uuid": d.get("uuid", ""),
"group_id": d.get("group_id", ""),
"end_user_id": d.get("end_user_id", ""),
"content": d.get("content", "")[:200] + "..." if len(d.get("content", "")) > 200 else d.get("content", ""),
"score": d.get("score", 0.0)
}
@@ -795,7 +795,7 @@ async def run_locomo_eval(
},
"samples": samples,
"params": {
"group_id": group_id,
"end_user_id": end_user_id,
"search_limit": search_limit,
"context_char_budget": context_char_budget,
"search_type": search_type,
@@ -825,7 +825,7 @@ async def run_locomo_eval(
def main():
parser = argparse.ArgumentParser(description="Run LoCoMo evaluation with Qwen search")
parser.add_argument("--sample_size", type=int, default=1, help="Number of samples to evaluate")
parser.add_argument("--group_id", type=str, default=None, help="Group ID for retrieval")
parser.add_argument("--end_user_id", type=str, default=None, help="Group ID for retrieval")
parser.add_argument("--search_limit", type=int, default=8, help="Search limit per query")
parser.add_argument("--context_char_budget", type=int, default=12000, help="Max characters for context")
parser.add_argument("--llm_temperature", type=float, default=0.0, help="LLM temperature")
@@ -841,7 +841,7 @@ def main():
result = asyncio.run(run_locomo_eval(
sample_size=args.sample_size,
group_id=args.group_id,
end_user_id=args.end_user_id,
search_limit=args.search_limit,
context_char_budget=args.context_char_budget,
llm_temperature=args.llm_temperature,

View File

@@ -523,11 +523,11 @@ def generate_query_keywords_cn(question: str) -> List[str]:
# 通过别名匹配进行实体关键词检索多token合并
async def _search_entities_by_aliases(connector: Neo4jConnector, tokens: List[str], group_id: str | None, limit: int) -> List[Dict[str, Any]]:
async def _search_entities_by_aliases(connector: Neo4jConnector, tokens: List[str], end_user_id: str | None, limit: int) -> List[Dict[str, Any]]:
results: List[Dict[str, Any]] = []
try:
for tok in tokens:
rows = await connector.execute_query(SEARCH_ENTITIES_BY_NAME, q=tok, group_id=group_id, limit=limit)
rows = await connector.execute_query(SEARCH_ENTITIES_BY_NAME, q=tok, end_user_id=end_user_id, limit=limit)
if rows:
results.extend(rows)
except Exception:
@@ -547,15 +547,15 @@ async def _search_entities_by_aliases(connector: Neo4jConnector, tokens: List[st
# 通过对话/陈述中的entity_ids反查实体名称
_FETCH_ENTITIES_BY_IDS = """
MATCH (e:ExtractedEntity)
WHERE e.id IN $ids AND ($group_id IS NULL OR e.group_id = $group_id)
RETURN e.id AS id, e.name AS name, e.group_id AS group_id, e.entity_type AS entity_type
WHERE e.id IN $ids AND ($end_user_id IS NULL OR e.end_user_id = $end_user_id)
RETURN e.id AS id, e.name AS name, e.end_user_id AS end_user_id, e.entity_type AS entity_type
"""
async def _fetch_entities_by_ids(connector: Neo4jConnector, ids: List[str], group_id: str | None) -> List[Dict[str, Any]]:
async def _fetch_entities_by_ids(connector: Neo4jConnector, ids: List[str], end_user_id: str | None) -> List[Dict[str, Any]]:
if not ids:
return []
try:
rows = await connector.execute_query(_FETCH_ENTITIES_BY_IDS, ids=list({i for i in ids if i}), group_id=group_id)
rows = await connector.execute_query(_FETCH_ENTITIES_BY_IDS, ids=list({i for i in ids if i}), end_user_id=end_user_id)
return rows or []
except Exception:
return []
@@ -565,18 +565,18 @@ async def _fetch_entities_by_ids(connector: Neo4jConnector, ids: List[str], grou
_TIME_ENTITY_SEARCH = """
MATCH (e:ExtractedEntity)
WHERE e.entity_type CONTAINS "TIME" OR e.entity_type CONTAINS "DATE" OR e.name =~ $date_pattern
AND ($group_id IS NULL OR e.group_id = $group_id)
RETURN e.id AS id, e.name AS name, e.group_id AS group_id, e.entity_type AS entity_type
AND ($end_user_id IS NULL OR e.end_user_id = $end_user_id)
RETURN e.id AS id, e.name AS name, e.end_user_id AS end_user_id, e.entity_type AS entity_type
LIMIT $limit
"""
async def _search_time_entities(connector: Neo4jConnector, group_id: str | None, limit: int = 5) -> List[Dict[str, Any]]:
async def _search_time_entities(connector: Neo4jConnector, end_user_id: str | None, limit: int = 5) -> List[Dict[str, Any]]:
"""专门搜索时间相关的实体"""
try:
date_pattern = r".*\d{4}.*|.*\d{1,2}月\d{1,2}日.*"
rows = await connector.execute_query(_TIME_ENTITY_SEARCH,
date_pattern=date_pattern,
group_id=group_id,
end_user_id=end_user_id,
limit=limit)
return rows or []
except Exception:
@@ -623,7 +623,7 @@ def _resolve_relative_times_cn_en(text: str, anchor: datetime) -> str:
async def run_longmemeval_test(
sample_size: int = 3,
group_id: str = "longmemeval_zh_bak_3",
end_user_id: str = "longmemeval_zh_bak_3",
search_limit: int = 8,
context_char_budget: int = 4000,
llm_temperature: float = 0.0,
@@ -677,13 +677,13 @@ async def run_longmemeval_test(
contexts.extend(selected)
print(f"📥 摄入 {len(contexts)} 个上下文到数据库")
if reset_group_before_ingest and group_id:
if reset_group_before_ingest and end_user_id:
try:
_tmp_conn = Neo4jConnector()
await _tmp_conn.delete_group(group_id)
print(f"🧹 已清空组 {group_id} 的历史图数据")
await _tmp_conn.delete_group(end_user_id)
print(f"🧹 已清空组 {end_user_id} 的历史图数据")
except Exception as _e:
print(f"⚠️ 清空组数据失败(忽略继续): {group_id} - {_e}")
print(f"⚠️ 清空组数据失败(忽略继续): {end_user_id} - {_e}")
finally:
try:
await _tmp_conn.close()
@@ -695,7 +695,7 @@ async def run_longmemeval_test(
else:
await _ingest_fn(
contexts,
group_id,
end_user_id,
save_chunk_output=save_chunk_output,
save_chunk_output_path=save_chunk_output_path,
)
@@ -750,7 +750,7 @@ async def run_longmemeval_test(
connector=connector,
embedder_client=embedder,
query_text=question,
group_id=group_id,
end_user_id=end_user_id,
limit=search_limit,
include=["chunks", "statements", "entities", "summaries"],
)
@@ -795,7 +795,7 @@ async def run_longmemeval_test(
search_results = await search_graph(
connector=connector,
q=question,
group_id=group_id,
end_user_id=end_user_id,
limit=search_limit,
)
chunks = search_results.get("chunks", [])
@@ -830,7 +830,7 @@ async def run_longmemeval_test(
connector=connector,
embedder_client=embedder,
query_text=question,
group_id=group_id,
end_user_id=end_user_id,
limit=search_limit,
include=["chunks", "statements", "entities", "summaries"],
)
@@ -848,7 +848,7 @@ async def run_longmemeval_test(
kw_res = await search_graph(
connector=connector,
q=question,
group_id=group_id,
end_user_id=end_user_id,
limit=search_limit,
)
if isinstance(kw_res, dict):
@@ -859,7 +859,7 @@ async def run_longmemeval_test(
# 时间推理问题的特殊处理
if is_temporal:
# 专门搜索时间实体
time_entities = await _search_time_entities(connector, group_id, search_limit//2)
time_entities = await _search_time_entities(connector, end_user_id, search_limit//2)
if time_entities:
kw_entities.extend(time_entities)
# 添加时间相关关键词检索
@@ -869,7 +869,7 @@ async def run_longmemeval_test(
time_res = await search_graph(
connector=connector,
q=tk,
group_id=group_id,
end_user_id=end_user_id,
limit=2,
)
if isinstance(time_res, dict):
@@ -880,7 +880,7 @@ async def run_longmemeval_test(
# 中文关键词拆分后做别名匹配
cn_tokens = _extract_cn_tokens(question)
alias_entities = await _search_entities_by_aliases(connector, cn_tokens, group_id, search_limit)
alias_entities = await _search_entities_by_aliases(connector, cn_tokens, end_user_id, search_limit)
if alias_entities:
kw_entities.extend(alias_entities)
@@ -894,7 +894,7 @@ async def run_longmemeval_test(
except Exception:
pass
if ids:
id_entities = await _fetch_entities_by_ids(connector, ids, group_id)
id_entities = await _fetch_entities_by_ids(connector, ids, end_user_id)
if id_entities:
kw_entities.extend(id_entities)
@@ -908,7 +908,7 @@ async def run_longmemeval_test(
sub_res = await search_graph(
connector=connector,
q=str(kw),
group_id=group_id,
end_user_id=end_user_id,
limit=max(3, search_limit // 2),
)
if isinstance(sub_res, dict):
@@ -927,7 +927,7 @@ async def run_longmemeval_test(
opt_res = await search_graph(
connector=connector,
q=str(opt),
group_id=group_id,
end_user_id=end_user_id,
limit=max(3, search_limit // 2),
)
if isinstance(opt_res, dict):

View File

@@ -498,11 +498,11 @@ def smart_context_selection(contexts: List[str], question: str, max_chars: int =
# 通过别名匹配进行实体关键词检索多token合并
async def _search_entities_by_aliases(connector: Neo4jConnector, tokens: List[str], group_id: str | None, limit: int) -> List[Dict[str, Any]]:
async def _search_entities_by_aliases(connector: Neo4jConnector, tokens: List[str], end_user_id: str | None, limit: int) -> List[Dict[str, Any]]:
results: List[Dict[str, Any]] = []
try:
for tok in tokens:
rows = await connector.execute_query(SEARCH_ENTITIES_BY_NAME, q=tok, group_id=group_id, limit=limit)
rows = await connector.execute_query(SEARCH_ENTITIES_BY_NAME, q=tok, end_user_id=end_user_id, limit=limit)
if rows:
results.extend(rows)
except Exception:
@@ -522,15 +522,15 @@ async def _search_entities_by_aliases(connector: Neo4jConnector, tokens: List[st
# 通过对话/陈述中的entity_ids反查实体名称
_FETCH_ENTITIES_BY_IDS = """
MATCH (e:ExtractedEntity)
WHERE e.id IN $ids AND ($group_id IS NULL OR e.group_id = $group_id)
RETURN e.id AS id, e.name AS name, e.group_id AS group_id, e.entity_type AS entity_type
WHERE e.id IN $ids AND ($end_user_id IS NULL OR e.end_user_id = $end_user_id)
RETURN e.id AS id, e.name AS name, e.end_user_id AS end_user_id, e.entity_type AS entity_type
"""
async def _fetch_entities_by_ids(connector: Neo4jConnector, ids: List[str], group_id: str | None) -> List[Dict[str, Any]]:
async def _fetch_entities_by_ids(connector: Neo4jConnector, ids: List[str], end_user_id: str | None) -> List[Dict[str, Any]]:
if not ids:
return []
try:
rows = await connector.execute_query(_FETCH_ENTITIES_BY_IDS, ids=list({i for i in ids if i}), group_id=group_id)
rows = await connector.execute_query(_FETCH_ENTITIES_BY_IDS, ids=list({i for i in ids if i}), end_user_id=end_user_id)
return rows or []
except Exception:
return []
@@ -540,18 +540,18 @@ async def _fetch_entities_by_ids(connector: Neo4jConnector, ids: List[str], grou
_TIME_ENTITY_SEARCH = """
MATCH (e:ExtractedEntity)
WHERE e.entity_type CONTAINS "TIME" OR e.entity_type CONTAINS "DATE" OR e.name =~ $date_pattern
AND ($group_id IS NULL OR e.group_id = $group_id)
RETURN e.id AS id, e.name AS name, e.group_id AS group_id, e.entity_type AS entity_type
AND ($end_user_id IS NULL OR e.end_user_id = $end_user_id)
RETURN e.id AS id, e.name AS name, e.end_user_id AS end_user_id, e.entity_type AS entity_type
LIMIT $limit
"""
async def _search_time_entities(connector: Neo4jConnector, group_id: str | None, limit: int = 5) -> List[Dict[str, Any]]:
async def _search_time_entities(connector: Neo4jConnector, end_user_id: str | None, limit: int = 5) -> List[Dict[str, Any]]:
"""专门搜索时间相关的实体"""
try:
date_pattern = r".*\d{4}.*|.*\d{1,2}月\d{1,2}日.*"
rows = await connector.execute_query(_TIME_ENTITY_SEARCH,
date_pattern=date_pattern,
group_id=group_id,
end_user_id=end_user_id,
limit=limit)
return rows or []
except Exception:
@@ -559,25 +559,25 @@ async def _search_time_entities(connector: Neo4jConnector, group_id: str | None,
# 技术术语专门检索
async def _search_tech_terms(connector: Neo4jConnector, question: str, group_id: str | None, limit: int = 3) -> List[Dict[str, Any]]:
async def _search_tech_terms(connector: Neo4jConnector, question: str, end_user_id: str | None, limit: int = 3) -> List[Dict[str, Any]]:
"""专门搜索技术术语相关的实体"""
tech_entities = []
try:
# GPS相关
if any(term in question for term in ["GPS", "导航", "定位系统"]):
gps_rows = await connector.execute_query(SEARCH_ENTITIES_BY_NAME, q="GPS", group_id=group_id, limit=limit)
gps_rows = await connector.execute_query(SEARCH_ENTITIES_BY_NAME, q="GPS", end_user_id=end_user_id, limit=limit)
if gps_rows:
tech_entities.extend(gps_rows)
# 活动相关
if any(term in question for term in ["工作坊", "研讨会", "网络研讨会"]):
workshop_rows = await connector.execute_query(SEARCH_ENTITIES_BY_NAME, q="工作坊", group_id=group_id, limit=limit)
workshop_rows = await connector.execute_query(SEARCH_ENTITIES_BY_NAME, q="工作坊", end_user_id=end_user_id, limit=limit)
if workshop_rows:
tech_entities.extend(workshop_rows)
# 时间顺序相关
if any(term in question for term in ["", "", "第一个"]):
time_rows = await connector.execute_query(SEARCH_ENTITIES_BY_NAME, q="第一次", group_id=group_id, limit=limit)
time_rows = await connector.execute_query(SEARCH_ENTITIES_BY_NAME, q="第一次", end_user_id=end_user_id, limit=limit)
if time_rows:
tech_entities.extend(time_rows)
@@ -627,7 +627,7 @@ def _resolve_relative_times_cn_en(text: str, anchor: datetime) -> str:
async def run_longmemeval_test(
sample_size: int = 3,
group_id: str = "longmemeval_zh_bak_2",
end_user_id: str = "longmemeval_zh_bak_2",
search_limit: int = 8,
context_char_budget: int = 4000,
llm_temperature: float = 0.0,
@@ -707,7 +707,7 @@ async def run_longmemeval_test(
connector=connector,
embedder_client=embedder,
query_text=question,
group_id=group_id,
end_user_id=end_user_id,
limit=search_limit,
include=["dialogues", "statements", "entities"],
)
@@ -746,7 +746,7 @@ async def run_longmemeval_test(
search_results = await search_graph(
connector=connector,
q=question,
group_id=group_id,
end_user_id=end_user_id,
limit=search_limit,
)
dialogs = search_results.get("dialogues", [])
@@ -776,7 +776,7 @@ async def run_longmemeval_test(
connector=connector,
embedder_client=embedder,
query_text=question,
group_id=group_id,
end_user_id=end_user_id,
limit=search_limit,
include=["dialogues", "statements", "entities"],
)
@@ -792,7 +792,7 @@ async def run_longmemeval_test(
kw_res = await search_graph(
connector=connector,
q=question,
group_id=group_id,
end_user_id=end_user_id,
limit=search_limit,
)
if isinstance(kw_res, dict):
@@ -801,14 +801,14 @@ async def run_longmemeval_test(
kw_entities = kw_res.get("entities", []) or []
# 技术术语专门检索
tech_entities = await _search_tech_terms(connector, question, group_id, search_limit//2)
tech_entities = await _search_tech_terms(connector, question, end_user_id, search_limit//2)
if tech_entities:
kw_entities.extend(tech_entities)
# 时间推理问题的特殊处理
if is_temporal:
# 专门搜索时间实体
time_entities = await _search_time_entities(connector, group_id, search_limit//2)
time_entities = await _search_time_entities(connector, end_user_id, search_limit//2)
if time_entities:
kw_entities.extend(time_entities)
# 添加时间相关关键词检索
@@ -818,7 +818,7 @@ async def run_longmemeval_test(
time_res = await search_graph(
connector=connector,
q=tk,
group_id=group_id,
end_user_id=end_user_id,
limit=2,
)
if isinstance(time_res, dict):
@@ -829,7 +829,7 @@ async def run_longmemeval_test(
# 中文关键词拆分后做别名匹配
cn_tokens = generate_query_keywords_cn(question) # 使用增强版关键词提取
alias_entities = await _search_entities_by_aliases(connector, cn_tokens, group_id, search_limit)
alias_entities = await _search_entities_by_aliases(connector, cn_tokens, end_user_id, search_limit)
if alias_entities:
kw_entities.extend(alias_entities)
@@ -843,7 +843,7 @@ async def run_longmemeval_test(
except Exception:
pass
if ids:
id_entities = await _fetch_entities_by_ids(connector, ids, group_id)
id_entities = await _fetch_entities_by_ids(connector, ids, end_user_id)
if id_entities:
kw_entities.extend(id_entities)
@@ -857,7 +857,7 @@ async def run_longmemeval_test(
sub_res = await search_graph(
connector=connector,
q=str(kw),
group_id=group_id,
end_user_id=end_user_id,
limit=max(3, search_limit // 2),
)
if isinstance(sub_res, dict):
@@ -876,7 +876,7 @@ async def run_longmemeval_test(
opt_res = await search_graph(
connector=connector,
q=str(opt),
group_id=group_id,
end_user_id=group_id,
limit=max(3, search_limit // 2),
)
if isinstance(opt_res, dict):

View File

@@ -27,7 +27,7 @@ from app.core.memory.storage_services.search import run_hybrid_search
from app.core.memory.utils.config.definitions import (
PROJECT_ROOT,
SELECTED_EMBEDDING_ID,
SELECTED_GROUP_ID,
SELECTED_end_user_id,
SELECTED_LLM_ID,
)
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
@@ -135,8 +135,8 @@ def _combine_dialogues_for_hybrid(results: Dict[str, Any]) -> List[Dict[str, Any
return merged
async def run_memsciqa_eval(sample_size: int = 1, group_id: str | None = None, search_limit: int = 8, context_char_budget: int = 4000, llm_temperature: float = 0.0, llm_max_tokens: int = 64, search_type: str = "hybrid", memory_config: "MemoryConfig" = None) -> Dict[str, Any]:
group_id = group_id or SELECTED_GROUP_ID
async def run_memsciqa_eval(sample_size: int = 1, end_user_id: str | None = None, search_limit: int = 8, context_char_budget: int = 4000, llm_temperature: float = 0.0, llm_max_tokens: int = 64, search_type: str = "hybrid", memory_config: "MemoryConfig" = None) -> Dict[str, Any]:
end_user_id = end_user_id or SELECTED_end_user_id
# Load data
data_path = os.path.join(PROJECT_ROOT, "data", "msc_self_instruct.jsonl")
if not os.path.exists(data_path):
@@ -147,7 +147,7 @@ async def run_memsciqa_eval(sample_size: int = 1, group_id: str | None = None, s
# 改为:每条样本仅摄入一个上下文(完整对话转录),避免多上下文摄入
# 说明memsciqa 数据集的每个样本天然只有一个对话,保持按样本一上下文的策略
contexts: List[str] = [build_context_from_dialog(item) for item in items]
await ingest_contexts_via_full_pipeline(contexts, group_id)
await ingest_contexts_via_full_pipeline(contexts, end_user_id)
# LLM client (使用异步调用)
with get_db_context() as db:
@@ -173,7 +173,7 @@ async def run_memsciqa_eval(sample_size: int = 1, group_id: str | None = None, s
results = await run_hybrid_search(
query_text=question,
search_type=search_type,
group_id=group_id,
end_user_id=end_user_id,
limit=search_limit,
include=["dialogues", "statements", "entities"],
output_path=None,
@@ -298,7 +298,7 @@ def main():
load_dotenv()
parser = argparse.ArgumentParser(description="Evaluate DMR (memsciqa) with graph search and Qwen")
parser.add_argument("--sample-size", type=int, default=1, help="评测样本数量")
parser.add_argument("--group-id", type=str, default=None, help="可选 group_id默认取 runtime.json")
parser.add_argument("--group-id", type=str, default=None, help="可选 end_user_id默认取 runtime.json")
parser.add_argument("--search-limit", type=int, default=8, help="每类检索最大返回数")
parser.add_argument("--context-char-budget", type=int, default=4000, help="上下文字符预算")
parser.add_argument("--llm-temperature", type=float, default=0.0, help="LLM 温度")
@@ -309,7 +309,7 @@ def main():
result = asyncio.run(
run_memsciqa_eval(
sample_size=args.sample_size,
group_id=args.group_id,
end_user_id=args.end_user_id,
search_limit=args.search_limit,
context_char_budget=args.context_char_budget,
llm_temperature=args.llm_temperature,

View File

@@ -33,7 +33,7 @@ from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
from app.core.memory.utils.config.definitions import (
PROJECT_ROOT,
SELECTED_EMBEDDING_ID,
SELECTED_GROUP_ID,
SELECTED_end_user_id,
SELECTED_LLM_ID,
)
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
@@ -198,7 +198,7 @@ def load_dataset_memsciqa(data_path: str) -> List[Dict[str, Any]]:
async def run_memsciqa_test(
sample_size: int = 3,
group_id: str | None = None,
end_user_id: str | None = None,
search_limit: int = 8,
context_char_budget: int = 4000,
llm_temperature: float = 0.0,
@@ -216,7 +216,7 @@ async def run_memsciqa_test(
"""
# 默认使用指定的 memsci 组 ID
group_id = group_id or "group_memsci"
end_user_id = end_user_id or "group_memsci"
# 数据路径解析(项目根与当前工作目录兜底)
if not data_path:
@@ -282,7 +282,7 @@ async def run_memsciqa_test(
connector=connector,
embedder_client=embedder,
query_text=question,
group_id=group_id,
end_user_id=end_user_id,
limit=search_limit,
include=["chunks", "statements", "entities", "summaries"], # 使用 chunks 而不是 dialogues
)
@@ -291,7 +291,7 @@ async def run_memsciqa_test(
results = await search_graph(
connector=connector,
q=question,
group_id=group_id,
end_user_id=end_user_id,
limit=search_limit,
include=["chunks", "statements", "entities", "summaries"], # 使用 chunks 而不是 dialogues
)
@@ -499,7 +499,7 @@ async def run_memsciqa_test(
},
"samples": samples,
"params": {
"group_id": group_id,
"end_user_id": end_user_id,
"search_limit": search_limit,
"context_char_budget": context_char_budget,
"llm_temperature": llm_temperature,
@@ -542,7 +542,7 @@ def main():
result = asyncio.run(
run_memsciqa_test(
sample_size=sample_size,
group_id=args.group_id,
end_user_id=args.end_user_id,
search_limit=args.search_limit,
context_char_budget=args.context_char_budget,
llm_temperature=args.llm_temperature,

View File

@@ -15,7 +15,7 @@ except Exception:
return None
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.core.memory.utils.config.definitions import SELECTED_GROUP_ID, PROJECT_ROOT
from app.core.memory.utils.config.definitions import SELECTED_end_user_id, PROJECT_ROOT
from app.core.memory.evaluation.memsciqa.evaluate_qa import run_memsciqa_eval
from app.core.memory.evaluation.longmemeval.qwen_search_eval import run_longmemeval_test
@@ -26,7 +26,7 @@ async def run(
dataset: str,
sample_size: int,
reset_group: bool,
group_id: str | None,
end_user_id: str | None,
judge_model: str | None = None,
search_limit: int | None = None,
context_char_budget: int | None = None,
@@ -37,17 +37,17 @@ async def run(
max_contexts_per_item: int | None = None,
) -> Dict[str, Any]:
# 恢复原始风格:统一入口做路由,并沿用各数据集既有默认
group_id = group_id or SELECTED_GROUP_ID
end_user_id = end_user_id or SELECTED_end_user_id
if reset_group:
connector = Neo4jConnector()
try:
await connector.delete_group(group_id)
await connector.delete_group(end_user_id)
finally:
await connector.close()
if dataset == "locomo":
kwargs: Dict[str, Any] = {"sample_size": sample_size, "group_id": group_id}
kwargs: Dict[str, Any] = {"sample_size": sample_size, "end_user_id": end_user_id}
if search_limit is not None:
kwargs["search_limit"] = search_limit
if context_char_budget is not None:
@@ -61,7 +61,7 @@ async def run(
return await run_locomo_eval(**kwargs)
if dataset == "memsciqa":
kwargs: Dict[str, Any] = {"sample_size": sample_size, "group_id": group_id}
kwargs: Dict[str, Any] = {"sample_size": sample_size, "end_user_id": end_user_id}
if search_limit is not None:
kwargs["search_limit"] = search_limit
if context_char_budget is not None:
@@ -75,7 +75,7 @@ async def run(
return await run_memsciqa_eval(**kwargs)
if dataset == "longmemeval":
kwargs: Dict[str, Any] = {"sample_size": sample_size, "group_id": group_id}
kwargs: Dict[str, Any] = {"sample_size": sample_size, "end_user_id": end_user_id}
if search_limit is not None:
kwargs["search_limit"] = search_limit
if context_char_budget is not None:
@@ -99,8 +99,8 @@ def main():
parser = argparse.ArgumentParser(description="统一评估入口memsciqa / longmemeval / locomo")
parser.add_argument("--dataset", choices=["memsciqa", "longmemeval", "locomo"], required=True)
parser.add_argument("--sample-size", type=int, default=1, help="先用一条数据跑通")
parser.add_argument("--reset-group", action="store_true", help="运行前清空当前 group_id 的图数据")
parser.add_argument("--group-id", type=str, default=None, help="可选 group_id默认取 runtime.json")
parser.add_argument("--reset-group", action="store_true", help="运行前清空当前 end_user_id 的图数据")
parser.add_argument("--group-id", type=str, default=None, help="可选 end_user_id默认取 runtime.json")
parser.add_argument("--judge-model", type=str, default=None, help="可选longmemeval 判别式评测模型名")
parser.add_argument("--search-limit", type=int, default=None, help="检索返回的对话节点数量上限(不提供则使用各脚本默认)")
parser.add_argument("--context-char-budget", type=int, default=None, help="上下文字符预算(不提供则使用各脚本默认)")
@@ -117,7 +117,7 @@ def main():
args.dataset,
args.sample_size,
args.reset_group,
args.group_id,
args.end_user_id,
args.judge_model,
args.search_limit,
args.context_char_budget,

View File

@@ -4,6 +4,7 @@ import os
import asyncio
import json
import numpy as np
import logging
# Fix tokenizer parallelism warning
os.environ["TOKENIZERS_PARALLELISM"] = "false"
@@ -23,28 +24,29 @@ from app.core.memory.models.message_models import DialogData, Chunk
try:
from app.core.memory.llm_tools.openai_client import OpenAIClient
except Exception:
# 在测试或无可用依赖(如 langfuse环境下允许惰性导入
OpenAIClient = Any
# Initialize logger
logger = logging.getLogger(__name__)
class LLMChunker:
"""基于LLM的智能分块策略"""
"""LLM-based intelligent chunking strategy"""
def __init__(self, llm_client: OpenAIClient, chunk_size: int = 1000):
self.llm_client = llm_client
self.chunk_size = chunk_size
async def __call__(self, text: str) -> List[Any]:
# 使用LLM分析文本结构并进行智能分块
prompt = f"""
请将以下文本分割成语义连贯的段落。每个段落应该围绕一个主题,长度大约在{self.chunk_size}字符左右。
请以JSON格式返回结果包含chunks数组每个chunk有text字段。
Split the following text into semantically coherent paragraphs. Each paragraph should focus on one topic, approximately {self.chunk_size} characters long.
Return results in JSON format with a chunks array, each chunk having a text field.
文本内容:
Text content:
{text[:5000]}
"""
messages = [
{"role": "system", "content": "你是一个专业的文本分析助手,擅长将长文本分割成语义连贯的段落。"},
{"role": "system", "content": "You are a professional text analysis assistant, skilled at splitting long texts into semantically coherent paragraphs."},
{"role": "user", "content": prompt}
]
@@ -171,8 +173,6 @@ class ChunkerClient:
base_chunk_size=self.chunk_size,
)
elif chunker_config.chunker_strategy == "SentenceChunker":
# 某些 chonkie 版本的 SentenceChunker 不支持 tokenizer_or_token_counter 参数
# 为了兼容不同版本,这里仅传递广泛支持的参数
self.chunker = SentenceChunker(
chunk_size=self.chunk_size,
chunk_overlap=self.chunk_overlap,
@@ -186,100 +186,93 @@ class ChunkerClient:
async def generate_chunks(self, dialogue: DialogData):
"""
生成分块,支持异步操作
Generate chunks following 1 Message = 1 Chunk strategy.
Each message creates one chunk, directly inheriting role information.
If a message is too long, it will be split into multiple sub-chunks,
each maintaining the same speaker.
Raises:
ValueError: If dialogue has no messages or chunking fails
"""
try:
# 预处理文本:确保对话标记格式统一
content = dialogue.content
content = content.replace('AI', 'AI:').replace('用户:', '用户:') # 统一冒号
content = re.sub(r'(\n\s*)+\n', '\n\n', content) # 合并多个空行
if hasattr(self.chunker, '__call__') and not asyncio.iscoroutinefunction(self.chunker.__call__):
# 同步分块器
chunks = self.chunker(content)
# Validate dialogue has messages
if not dialogue.context or not dialogue.context.msgs:
raise ValueError(
f"Dialogue {dialogue.ref_id} has no messages. "
f"Cannot generate chunks from empty dialogue."
)
dialogue.chunks = []
# 按消息分块:每个消息创建一个或多个 chunk直接继承角色
for msg_idx, msg in enumerate(dialogue.context.msgs):
# Validate message has required attributes
if not hasattr(msg, 'role') or not hasattr(msg, 'msg'):
raise ValueError(
f"Message {msg_idx} in dialogue {dialogue.ref_id} "
f"missing 'role' or 'msg' attribute"
)
msg_content = msg.msg.strip()
# Skip empty messages
if not msg_content:
continue
# 如果消息太长,可以进一步分块
if len(msg_content) > self.chunk_size:
# 对单个消息的内容进行分块
try:
sub_chunks = self.chunker(msg_content)
except Exception as e:
raise ValueError(
f"Failed to chunk long message {msg_idx} in dialogue {dialogue.ref_id}: {e}"
)
for idx, sub_chunk in enumerate(sub_chunks):
sub_chunk_text = sub_chunk.text if hasattr(sub_chunk, 'text') else str(sub_chunk)
sub_chunk_text = sub_chunk_text.strip()
if len(sub_chunk_text) < (self.min_characters_per_chunk or 50):
continue
chunk = Chunk(
content=f"{msg.role}: {sub_chunk_text}",
speaker=msg.role, # 直接继承角色
metadata={
"message_index": msg_idx,
"message_role": msg.role,
"sub_chunk_index": idx,
"total_sub_chunks": len(sub_chunks),
"chunker_strategy": self.chunker_config.chunker_strategy,
},
)
dialogue.chunks.append(chunk)
else:
# 异步分块器如LLMChunker
chunks = await self.chunker(content)
# 过滤空块和过小的块
valid_chunks = []
for c in chunks:
chunk_text = getattr(c, 'text', str(c)) if not isinstance(c, str) else c
if isinstance(chunk_text, str) and len(chunk_text.strip()) >= (self.min_characters_per_chunk or 50):
valid_chunks.append(c)
dialogue.chunks = [
Chunk(
content=c.text if hasattr(c, 'text') else str(c),
# 消息不长,直接作为一个 chunk
chunk = Chunk(
content=f"{msg.role}: {msg_content}",
speaker=msg.role, # 直接继承角色
metadata={
"start_index": getattr(c, "start_index", None),
"end_index": getattr(c, "end_index", None),
"message_index": msg_idx,
"message_role": msg.role,
"chunker_strategy": self.chunker_config.chunker_strategy,
},
)
for c in valid_chunks
]
return dialogue
except Exception as e:
print(f"分块失败: {e}")
# 改进的后备方案:尝试按对话回合分割
try:
# 简单的按对话分割
dialogue_pattern = r'(AI:|用户:)(.*?)(?=AI:|用户:|$)'
matches = re.findall(dialogue_pattern, dialogue.content, re.DOTALL)
class SimpleChunk:
def __init__(self, text, start_index, end_index):
self.text = text
self.start_index = start_index
self.end_index = end_index
chunks = []
current_chunk = ""
current_start = 0
for match in matches:
speaker, ct = match[0], match[1].strip()
turn_text = f"{speaker} {ct}"
if len(current_chunk) + len(turn_text) > (self.chunk_size or 500):
if current_chunk:
chunks.append(SimpleChunk(current_chunk, current_start, current_start + len(current_chunk)))
current_chunk = turn_text
current_start = dialogue.content.find(turn_text, current_start)
else:
current_chunk += ("\n" + turn_text) if current_chunk else turn_text
if current_chunk:
chunks.append(SimpleChunk(current_chunk, current_start, current_start + len(current_chunk)))
dialogue.chunks = [
Chunk(
content=c.text,
metadata={
"start_index": c.start_index,
"end_index": c.end_index,
"chunker_strategy": "DialogueTurnFallback",
},
)
for c in chunks
]
except Exception:
# 最后的手段:单一大块
dialogue.chunks = [Chunk(
content=dialogue.content,
metadata={"chunker_strategy": "SingleChunkFallback"},
)]
return dialogue
dialogue.chunks.append(chunk)
# Validate we generated at least one chunk
if not dialogue.chunks:
raise ValueError(
f"No valid chunks generated for dialogue {dialogue.ref_id}. "
f"All messages were either empty or too short. "
f"Messages count: {len(dialogue.context.msgs)}"
)
return dialogue
def evaluate_chunking(self, dialogue: DialogData) -> dict:
"""
评估分块质量
"""
"""Evaluate chunking quality."""
if not getattr(dialogue, 'chunks', None):
return {}
@@ -304,11 +297,8 @@ class ChunkerClient:
return metrics
def save_chunking_results(self, dialogue: DialogData, output_path: str):
"""
保存分块结果到文件,文件名包含策略名称
"""
"""Save chunking results to file with strategy name in filename."""
strategy_name = self.chunker_config.chunker_strategy
# 在文件名中添加策略名称
base_name, ext = os.path.splitext(output_path)
strategy_output_path = f"{base_name}_{strategy_name}{ext}"

View File

@@ -92,8 +92,6 @@ class OpenAIClient(LLMClient):
config["callbacks"] = [self.langfuse_handler]
response = await chain.ainvoke({"messages": messages}, config=config)
logger.debug(f"LLM 响应成功: {len(str(response))} 字符")
return response
except Exception as e:
@@ -149,13 +147,10 @@ class OpenAIClient(LLMClient):
config=config
)
logger.debug(f"使用 PydanticOutputParser 解析成功")
return parsed
except Exception as e:
logger.warning(
f"PydanticOutputParser 解析失败,尝试其他方法: {e}"
)
logger.debug(f"PydanticOutputParser 解析失败,尝试备用方法: {e}")
# 方法 2: 使用 LangChain 的 with_structured_output
template = """{question}"""
@@ -173,13 +168,17 @@ class OpenAIClient(LLMClient):
# 验证并返回结果
try:
return response_model.model_validate(parsed)
result = response_model.model_validate(parsed)
return result
except Exception:
# 如果已经是 Pydantic 实例,直接返回
if hasattr(parsed, "model_dump"):
return parsed
# 尝试从 JSON 解析
return response_model.model_validate_json(json.dumps(parsed))
result = response_model.model_validate_json(json.dumps(parsed))
return result
else:
logger.warning("with_structured_output 方法不可用")
except Exception as e:
logger.error(f"结构化输出失败: {e}")

View File

@@ -72,7 +72,7 @@ class TemporalSearchParams(BaseModel):
"""Parameters for temporal search queries in the knowledge graph.
Attributes:
group_id: Group ID to filter search results (default: 'test')
end_user_id: Group ID to filter search results (default: 'test')
apply_id: Application ID to filter search results
user_id: User ID to filter search results
start_date: Start date for temporal filtering (format: 'YYYY-MM-DD')
@@ -81,7 +81,7 @@ class TemporalSearchParams(BaseModel):
invalid_date: Date when memory should be invalid (format: 'YYYY-MM-DD')
limit: Maximum number of results to return (default: 3)
"""
group_id: Optional[str] = Field("test", description="The group ID to filter the search.")
end_user_id: Optional[str] = Field("test", description="The group ID to filter the search.")
apply_id: Optional[str] = Field(None, description="The apply ID to filter the search.")
user_id: Optional[str] = Field(None, description="The user ID to filter the search.")
start_date: Optional[str] = Field(None, description="The start date for the search.")

View File

@@ -103,9 +103,7 @@ class Edge(BaseModel):
id: Unique identifier for the edge
source: ID of the source node
target: ID of the target node
group_id: Group ID for multi-tenancy
user_id: User ID for user-specific data
apply_id: Application ID for application-specific data
end_user_id: End user ID for multi-tenancy
run_id: Unique identifier for the pipeline run that created this edge
created_at: Timestamp when the edge was created (system perspective)
expired_at: Optional timestamp when the edge expires (system perspective)
@@ -113,9 +111,7 @@ class Edge(BaseModel):
id: str = Field(default_factory=lambda: uuid4().hex, description="A unique identifier for the edge.")
source: str = Field(..., description="The ID of the source node.")
target: str = Field(..., description="The ID of the target node.")
group_id: str = Field(..., description="The group ID of the edge.")
user_id: str = Field(..., description="The user ID of the edge.")
apply_id: str = Field(..., description="The apply ID of the edge.")
end_user_id: str = Field(..., description="The end user ID of the edge.")
run_id: str = Field(default_factory=lambda: uuid4().hex, description="Unique identifier for this pipeline run.")
created_at: datetime = Field(..., description="The valid time of the edge from system perspective.")
expired_at: Optional[datetime] = Field(None, description="The expired time of the edge from system perspective.")
@@ -185,18 +181,14 @@ class Node(BaseModel):
Attributes:
id: Unique identifier for the node
name: Name of the node
group_id: Group ID for multi-tenancy
user_id: User ID for user-specific data
apply_id: Application ID for application-specific data
end_user_id: End user ID for multi-tenancy
run_id: Unique identifier for the pipeline run that created this node
created_at: Timestamp when the node was created (system perspective)
expired_at: Optional timestamp when the node expires (system perspective)
"""
id: str = Field(..., description="The unique identifier for the node.")
name: str = Field(..., description="The name of the node.")
group_id: str = Field(..., description="The group ID of the node.")
user_id: str = Field(..., description="The user ID of the edge.")
apply_id: str = Field(..., description="The apply ID of the edge.")
end_user_id: str = Field(..., description="The end user ID of the node.")
run_id: str = Field(default_factory=lambda: uuid4().hex, description="Unique identifier for this pipeline run.")
created_at: datetime = Field(..., description="The valid time of the node from system perspective.")
expired_at: Optional[datetime] = Field(None, description="The expired time of the node from system perspective.")
@@ -224,6 +216,7 @@ class StatementNode(Node):
chunk_id: ID of the parent chunk this statement belongs to
stmt_type: Type of the statement (from ontology)
statement: The actual statement text content
speaker: Optional speaker identifier ('用户' for user messages, 'AI' for AI responses)
emotion_intensity: Optional emotion intensity (0.0-1.0) - displayed on node
emotion_target: Optional emotion target (person or object name)
emotion_subject: Optional emotion subject (self/other/object)
@@ -249,6 +242,12 @@ class StatementNode(Node):
stmt_type: str = Field(..., description="Type of the statement")
statement: str = Field(..., description="The statement text content")
# Speaker identification
speaker: Optional[str] = Field(
None,
description="Speaker identifier: 'user' for user messages, 'assistant' for AI responses"
)
# Emotion fields (ordered as requested, emotion_intensity first for display)
emotion_intensity: Optional[float] = Field(
None,

View File

@@ -25,10 +25,10 @@ class ConversationMessage(BaseModel):
"""Represents a single message in a conversation.
Attributes:
role: Role of the speaker (e.g., '用户' for user, 'AI' for assistant)
role: Role of the speaker (e.g., 'user' for user, 'assistant' for AI assistant)
msg: Text content of the message
"""
role: str = Field(..., description="The role of the speaker (e.g., '用户', 'AI').")
role: str = Field(..., description="The role of the speaker (e.g., 'user', 'assistant').")
msg: str = Field(..., description="The text content of the message.")
@@ -55,8 +55,9 @@ class Statement(BaseModel):
Attributes:
id: Unique identifier for the statement
chunk_id: ID of the parent chunk this statement belongs to
group_id: Optional group ID for multi-tenancy
end_user_id: Optional group ID for multi-tenancy
statement: The actual statement text content
speaker: Optional speaker identifier ('用户' for user, 'AI' for AI responses)
statement_embedding: Optional embedding vector for the statement
stmt_type: Type of the statement (from ontology)
temporal_info: Temporal information extracted from the statement
@@ -72,8 +73,9 @@ class Statement(BaseModel):
"""
id: str = Field(default_factory=lambda: uuid4().hex, description="A unique identifier for the statement.")
chunk_id: str = Field(..., description="ID of the parent chunk this statement belongs to.")
group_id: Optional[str] = Field(None, description="ID of the group this statement belongs to.")
end_user_id: Optional[str] = Field(None, description="ID of the group this statement belongs to.")
statement: str = Field(..., description="The text content of the statement.")
speaker: Optional[str] = Field(None, description="Speaker identifier: 'user' for user messages, 'assistant' for AI responses")
statement_embedding: Optional[List[float]] = Field(None, description="The embedding vector of the statement.")
stmt_type: StatementType = Field(..., description="The type of the statement.")
temporal_info: TemporalInfo = Field(..., description="The temporal information of the statement.")
@@ -118,36 +120,36 @@ class Chunk(BaseModel):
Attributes:
id: Unique identifier for the chunk
text: List of messages in the chunk
content: The content of the chunk as a formatted string
speaker: The speaker/role for this chunk (user/assistant)
statements: List of statements extracted from this chunk
chunk_embedding: Optional embedding vector for the chunk
metadata: Additional metadata as key-value pairs
"""
id: str = Field(default_factory=lambda: uuid4().hex, description="A unique identifier for the chunk.")
text: List[ConversationMessage] = Field(default_factory=list, description="A list of messages in the chunk.")
content: str = Field(..., description="The content of the chunk as a string.")
speaker: Optional[str] = Field(None, description="The speaker/role for this chunk (user/assistant).")
statements: List[Statement] = Field(default_factory=list, description="A list of statements in the chunk.")
chunk_embedding: Optional[List[float]] = Field(None, description="The embedding vector of the chunk.")
metadata: Dict[str, Any] = Field(default_factory=dict, description="Additional metadata for the chunk.")
@classmethod
def from_messages(cls, messages: List[ConversationMessage], metadata: Optional[Dict[str, Any]] = None):
"""Create a chunk from a list of messages.
def from_single_message(cls, message: ConversationMessage, metadata: Optional[Dict[str, Any]] = None):
"""Create a chunk from a single message (1 Message = 1 Chunk).
Args:
messages: List of conversation messages
message: Single conversation message
metadata: Optional metadata dictionary
Returns:
Chunk instance with formatted content
Chunk instance with speaker directly from message.role
"""
if metadata is None:
metadata = {}
# Generate content from messages
content = "\n".join([f"{msg.role}: {msg.msg}" for msg in messages])
return cls(text=messages, content=content, metadata=metadata)
return cls(
content=f"{message.role}: {message.msg}",
speaker=message.role,
metadata=metadata or {}
)
class DialogData(BaseModel):
"""Represents the complete data structure for a dialog record.
@@ -157,9 +159,7 @@ class DialogData(BaseModel):
context: Full conversation context
dialog_embedding: Optional embedding vector for the entire dialog
ref_id: Reference ID linking to external dialog system
group_id: Group ID for multi-tenancy
user_id: User ID for user-specific data
apply_id: Application ID for application-specific data
end_user_id: End user ID for multi-tenancy
created_at: Timestamp when the dialog was created
expired_at: Timestamp when the dialog expires (default: far future)
metadata: Additional metadata as key-value pairs
@@ -173,9 +173,7 @@ class DialogData(BaseModel):
context: ConversationContext = Field(..., description="The full conversation context as a single string.")
dialog_embedding: Optional[List[float]] = Field(None, description="The embedding vector of the dialog.")
ref_id: str = Field(..., description="Refer to external dialog id. This is used to link to the original dialog.")
group_id: str = Field(default=..., description="Group ID of dialogue data")
user_id: str = Field(..., description="USER ID of dialogue data")
apply_id: str = Field(..., description="APPLY ID of dialogue data")
end_user_id: str = Field(default=..., description="End user ID of dialogue data")
run_id: str = Field(default_factory=lambda: uuid4().hex, description="Unique identifier for this pipeline run.")
created_at: datetime = Field(default_factory=datetime.now, description="The timestamp when the dialog was created.")
expired_at: datetime = Field(default_factory=lambda: datetime(9999, 12, 31), description="The timestamp when the dialog expires.")
@@ -254,5 +252,5 @@ class DialogData(BaseModel):
"""
for chunk in self.chunks:
for statement in chunk.statements:
if statement.group_id is None:
statement.group_id = self.group_id
if statement.end_user_id is None:
statement.end_user_id = self.end_user_id

View File

@@ -6,6 +6,7 @@ import os
import time
from datetime import datetime
from typing import TYPE_CHECKING, Any, Dict, List, Optional
from uuid import UUID
if TYPE_CHECKING:
from app.schemas.memory_config_schema import MemoryConfig
@@ -396,13 +397,13 @@ def rerank_with_activation(
return reranked
def log_search_query(query_text: str, search_type: str, group_id: str | None, limit: int, include: List[str], log_file: str = None):
def log_search_query(query_text: str, search_type: str, end_user_id: str | None, limit: int, include: List[str], log_file: str = None):
"""Log search query information using the logger.
Args:
query_text: The search query text
search_type: Type of search (keyword, embedding, hybrid)
group_id: Group identifier for filtering
end_user_id: Group identifier for filtering
limit: Maximum number of results
include: List of result types to include
log_file: Deprecated parameter, kept for backward compatibility
@@ -413,7 +414,7 @@ def log_search_query(query_text: str, search_type: str, group_id: str | None, li
# Log using the standard logger
logger.info(
f"Search query: query='{cleaned_query}', type={search_type}, "
f"group_id={group_id}, limit={limit}, include={include}"
f"end_user_id={end_user_id}, limit={limit}, include={include}"
)
@@ -672,7 +673,7 @@ def apply_reranker_placeholder(
async def run_hybrid_search(
query_text: str,
search_type: str,
group_id: str | None,
end_user_id: str | None,
limit: int,
include: List[str],
output_path: str | None,
@@ -692,6 +693,9 @@ async def run_hybrid_search(
# Start overall timing
search_start_time = time.time()
latency_metrics = {}
print(100*'-')
print(memory_config)
print(100 * '-')
logger.info(f"using embedding_id:{memory_config.embedding_model_id}...")
# Clean and normalize the incoming query before use/logging
@@ -715,7 +719,7 @@ async def run_hybrid_search(
}
# Log the search query
log_search_query(query_text, search_type, group_id, limit, include)
log_search_query(query_text, search_type, end_user_id, limit, include)
connector = Neo4jConnector()
results = {}
@@ -732,7 +736,7 @@ async def run_hybrid_search(
search_graph(
connector=connector,
q=query_text,
group_id=group_id,
end_user_id=end_user_id,
limit=limit,
include=include
)
@@ -769,7 +773,7 @@ async def run_hybrid_search(
connector=connector,
embedder_client=embedder,
query_text=query_text,
group_id=group_id,
end_user_id=end_user_id,
limit=limit,
include=include,
)
@@ -916,9 +920,7 @@ async def run_hybrid_search(
async def search_by_temporal(
group_id: Optional[str] = "test",
apply_id: Optional[str] = None,
user_id: Optional[str] = None,
end_user_id: Optional[str] = "test",
start_date: Optional[str] = None,
end_date: Optional[str] = None,
valid_date: Optional[str] = None,
@@ -929,7 +931,7 @@ async def search_by_temporal(
Temporal search across Statements.
- Matches statements created between start_date and end_date
- Optionally filters by group_id
- Optionally filters by end_user_id
- Returns up to 'limit' statements
"""
connector = Neo4jConnector()
@@ -939,9 +941,7 @@ async def search_by_temporal(
end_date = normalize_date_safe(end_date)
params = TemporalSearchParams.model_validate({
"group_id": group_id,
"apply_id": apply_id,
"user_id": user_id,
"end_user_id": end_user_id,
"start_date": start_date,
"end_date": end_date,
"valid_date": valid_date,
@@ -950,9 +950,7 @@ async def search_by_temporal(
})
statements = await search_graph_by_temporal(
connector=connector,
group_id=params.group_id,
apply_id=params.apply_id,
user_id=params.user_id,
end_user_id=params.end_user_id,
start_date=params.start_date,
end_date=params.end_date,
valid_date=params.valid_date,
@@ -964,9 +962,7 @@ async def search_by_temporal(
async def search_by_keyword_temporal(
query_text: str,
group_id: Optional[str] = "test",
apply_id: Optional[str] = None,
user_id: Optional[str] = None,
end_user_id: Optional[str] = "test",
start_date: Optional[str] = None,
end_date: Optional[str] = None,
valid_date: Optional[str] = None,
@@ -987,9 +983,7 @@ async def search_by_keyword_temporal(
invalid_date = normalize_date_safe(invalid_date)
params = TemporalSearchParams.model_validate({
"group_id": group_id,
"apply_id": apply_id,
"user_id": user_id,
"end_user_id": end_user_id,
"start_date": start_date,
"end_date": end_date,
"valid_date": valid_date,
@@ -999,9 +993,7 @@ async def search_by_keyword_temporal(
statements = await search_graph_by_keyword_temporal(
connector=connector,
query_text=query_text,
group_id=params.group_id,
apply_id=params.apply_id,
user_id=params.user_id,
end_user_id=params.end_user_id,
start_date=params.start_date,
end_date=params.end_date,
valid_date=params.valid_date,
@@ -1013,7 +1005,7 @@ async def search_by_keyword_temporal(
async def search_chunk_by_chunk_id(
chunk_id: str,
group_id: Optional[str] = "test",
end_user_id: Optional[str] = "test",
limit: int = 1,
):
"""
@@ -1023,8 +1015,68 @@ async def search_chunk_by_chunk_id(
chunks = await search_graph_by_chunk_id(
connector=connector,
chunk_id=chunk_id,
group_id=group_id,
end_user_id=end_user_id,
limit=limit
)
return {"chunks": chunks}
if __name__ == '__main__':
# 测试混合检索功能
from app.schemas.memory_config_schema import MemoryConfig
from app.db import get_db
from app.services.memory_config_service import MemoryConfigService
# 从数据库获取真实配置
db = next(get_db())
try:
config_service = MemoryConfigService(db)
# 使用 config_id=17 获取配置
memory_config = config_service.load_memory_config(config_id=17)
if not memory_config:
print("错误:找不到 config_id=17 的配置")
print("请先在数据库中创建配置,或修改 config_id")
exit(1)
print(f"✓ 成功加载配置: {memory_config.config_name}")
print(f" - Workspace: {memory_config.workspace_name}")
print(f" - LLM Model: {memory_config.llm_model_name}")
print(f" - Embedding Model: {memory_config.embedding_model_name}")
print(f" - Storage Type: {memory_config.storage_type}")
print()
# 修改这里的参数进行测试
test_end_user_id = "021886bc-fab9-4fd5-b607-497b262e0381" # 修改为你的 end_user_id
test_query = "小明擅长什么?" # 修改为你的查询
print(f"开始测试检索...")
print(f" - Query: {test_query}")
print(f" - End User ID: {test_end_user_id}")
print(f" - Search Type: hybrid")
print()
results = asyncio.run(run_hybrid_search(
query_text=test_query,
search_type="hybrid", # 可选: "keyword", "embedding", "hybrid"
end_user_id=test_end_user_id,
limit=10,
include=["statements", "entities", "chunks", "summaries"],
output_path=None,
memory_config=memory_config,
rerank_alpha=0.6,
use_forgetting_rerank=False,
use_llm_rerank=False
))
print("=" * 80)
print("检索结果:")
print("=" * 80)
print(results)
except Exception as e:
print(f"错误: {e}")
import traceback
traceback.print_exc()
finally:
db.close()

View File

@@ -555,8 +555,8 @@ class DataPreprocessor:
dialog_id = item.get('dialog_id', item.get('ref_id', item.get('id', f'dialog_{i}')))
# 获取group_id如果不存在则生成默认值
group_id = item.get('group_id', f'group_default_{i}')
# 获取end_user_id如果不存在则生成默认值
end_user_id = item.get('end_user_id', f'group_default_{i}')
user_id = item.get('user_id', f'user_default_{i}')
apply_id = item.get('apply_id', f'apply_default_{i}')
@@ -574,7 +574,7 @@ class DataPreprocessor:
dialog_data = DialogData(
context=context,
ref_id=dialog_id,
group_id=group_id,
end_user_id=end_user_id,
user_id=user_id,
apply_id=apply_id,
metadata=metadata
@@ -644,7 +644,7 @@ class DataPreprocessor:
context = ConversationContext(msgs=messages)
dialog_id = item.get('dialog_id', item.get('ref_id', item.get('id', f'dialog_{i}')))
group_id = item.get('group_id', f'group_default_{i}')
end_user_id = item.get('end_user_id', f'group_default_{i}')
user_id = item.get('user_id', f'user_default_{i}')
apply_id = item.get('apply_id', f'apply_default_{i}')
@@ -657,7 +657,7 @@ class DataPreprocessor:
dialog_data = DialogData(
context=context,
ref_id=dialog_id,
group_id=group_id,
end_user_id=end_user_id,
user_id=user_id,
apply_id=apply_id,
metadata=metadata

View File

@@ -199,7 +199,7 @@ def accurate_match(
entity_nodes: List[ExtractedEntityNode]
) -> Tuple[List[ExtractedEntityNode], Dict[str, str], Dict[str, Dict]]:
"""
精确匹配:按 (group_id, name, entity_type) 合并实体并建立重定向与合并记录。
精确匹配:按 (end_user_id, name, entity_type) 合并实体并建立重定向与合并记录。
返回: (deduped_entities, id_redirect, exact_merge_map)
"""
exact_merge_map: Dict[str, Dict] = {}
@@ -210,8 +210,8 @@ def accurate_match(
for ent in entity_nodes:
name_norm = (getattr(ent, "name", "") or "").strip()
type_norm = (getattr(ent, "entity_type", "") or "").strip()
key = f"{getattr(ent, 'group_id', None)}|{name_norm}|{type_norm}"
# 为避免跨业务组误并,明确以 group_id 为范围边界
key = f"{getattr(ent, 'end_user_id', None)}|{name_norm}|{type_norm}"
# 为避免跨业务组误并,明确以 end_user_id 为范围边界
if key not in canonical_map:
canonical_map[key] = ent
id_redirect[ent.id] = ent.id
@@ -223,11 +223,11 @@ def accurate_match(
id_redirect[ent.id] = canonical.id
# 记录精确匹配的合并项(使用规范化键,避免外层变量误用)
try:
k = f"{canonical.group_id}|{(canonical.name or '').strip()}|{(canonical.entity_type or '').strip()}"
k = f"{canonical.end_user_id}|{(canonical.name or '').strip()}|{(canonical.entity_type or '').strip()}"
if k not in exact_merge_map:
exact_merge_map[k] = {
"canonical_id": canonical.id,
"group_id": canonical.group_id,
"end_user_id": canonical.end_user_id,
"name": canonical.name,
"entity_type": canonical.entity_type,
"merged_ids": set(),
@@ -596,7 +596,7 @@ def fuzzy_match(
b = deduped_entities[j]
# 跳过不同业务组的实体
if getattr(a, "group_id", None) != getattr(b, "group_id", None):
if getattr(a, "end_user_id", None) != getattr(b, "end_user_id", None):
j += 1
continue
@@ -671,7 +671,7 @@ def fuzzy_match(
merge_reason = "[别名匹配]" if alias_match_merge else "[模糊]"
merge_reason = "[别名匹配]" if alias_match_merge else "[模糊]"
fuzzy_merge_records.append(
f"{merge_reason} 规范实体 {a.id} ({a.group_id}|{a.name}|{a.entity_type}) <- 合并实体 {b.id} ({b.group_id}|{b.name}|{b.entity_type}) | "
f"{merge_reason} 规范实体 {a.id} ({a.end_user_id}|{a.name}|{a.entity_type}) <- 合并实体 {b.id} ({b.end_user_id}|{b.name}|{b.entity_type}) | "
f"s_name={s_name:.3f}, s_type={s_type:.3f}, overall={overall:.3f}, exact_alias={has_exact_match}"
)
except Exception:
@@ -779,7 +779,7 @@ async def LLM_decision( # 决策中包含去重和消歧的功能
# 记录 LLM 融合日志
try:
llm_records.append(
f"[LLM融合] 规范实体 {a.id} ({a.group_id}|{a.name}|{a.entity_type}) <- 合并实体 {b.id} ({b.group_id}|{b.name}|{b.entity_type})"
f"[LLM融合] 规范实体 {a.id} ({a.end_user_id}|{a.name}|{a.entity_type}) <- 合并实体 {b.id} ({b.end_user_id}|{b.name}|{b.entity_type})"
)
# 详细的“同类名称相似”记录改由 LLM 去重模块统一生成以携带 conf/reason
except Exception:
@@ -847,7 +847,7 @@ async def LLM_disamb_decision(
id_redirect[k] = a.id
try:
disamb_records.append(
f"[DISAMB合并应用] 规范实体 {a.id} ({a.group_id}|{a.name}|{a.entity_type}) <- 合并实体 {b.id} ({b.group_id}|{b.name}|{b.entity_type})"
f"[DISAMB合并应用] 规范实体 {a.id} ({a.end_user_id}|{a.name}|{a.entity_type}) <- 合并实体 {b.id} ({b.end_user_id}|{b.name}|{b.entity_type})"
)
except Exception:
pass

View File

@@ -174,7 +174,7 @@ async def _judge_pair(
pass
# 3. 构建LLM判断的“上下文信息”规则层计算的所有特征 判断上下文特征有助于实体消歧首先判断的类型关系
ctx = {
"same_group": getattr(a, "group_id", None) == getattr(b, "group_id", None),
"same_group": getattr(a, "end_user_id", None) == getattr(b, "end_user_id", None),
"type_ok": _simple_type_ok(getattr(a, "entity_type", None), getattr(b, "entity_type", None)),
"type_similarity": _type_similarity(getattr(a, "entity_type", None), getattr(b, "entity_type", None)),
"name_text_sim": name_text_sim,
@@ -235,7 +235,7 @@ async def _judge_pair_disamb(
except Exception:
pass
ctx = {
"same_group": getattr(a, "group_id", None) == getattr(b, "group_id", None),
"same_group": getattr(a, "end_user_id", None) == getattr(b, "end_user_id", None),
"type_ok": _simple_type_ok(getattr(a, "entity_type", None), getattr(b, "entity_type", None)),
"name_text_sim": name_text_sim,
"name_embed_sim": name_embed_sim,
@@ -317,8 +317,8 @@ async def llm_dedup_entities( # 保留对偶判断作为子流程,是为了
a = entity_nodes[i]
for j in range(i + 1, len(entity_nodes)):
b = entity_nodes[j]
# 规则1必须属于同一组group_id相同不同组的实体不重复
if getattr(a, "group_id", None) != getattr(b, "group_id", None):
# 规则1必须属于同一组end_user_id相同不同组的实体不重复
if getattr(a, "end_user_id", None) != getattr(b, "end_user_id", None):
continue
# 规则2类型必须兼容调用_simple_type_ok判断
if not _simple_type_ok(getattr(a, "entity_type", None), getattr(b, "entity_type", None)):
@@ -474,7 +474,7 @@ async def llm_dedup_entities_iterative_blocks( # 迭代分块并发 LLM 去重
- max_rounds: upper bound for iterative passes (default 3)
- auto_merge_threshold: decision confidence for auto-merge when no co-occurrence (default 0.90)
- co_ctx_threshold: lower threshold when co-occurrence is detected (default 0.83)
- shuffle_each_round: whether to shuffle entities within group_id each round to vary block composition
- shuffle_each_round: whether to shuffle entities within end_user_id each round to vary block composition
Returns:
- global_redirect: dict losing_id -> canonical_id accumulated across rounds
@@ -509,7 +509,7 @@ async def llm_dedup_entities_iterative_blocks( # 迭代分块并发 LLM 去重
def _partition_blocks(nodes: List[ExtractedEntityNode]) -> List[List[ExtractedEntityNode]]:
"""
group_id 分块,避免跨组实体在同一块,减少无效候选对
end_user_id 分块,避免跨组实体在同一块,减少无效候选对
Args:
nodes: 实体节点列表
@@ -519,7 +519,7 @@ async def llm_dedup_entities_iterative_blocks( # 迭代分块并发 LLM 去重
"""
groups: Dict[str, List[ExtractedEntityNode]] = {}
for e in nodes:
gid = getattr(e, "group_id", None)
gid = getattr(e, "end_user_id", None)
groups.setdefault(str(gid), []).append(e)
blocks: List[List[ExtractedEntityNode]] = []
for gid, arr in groups.items():
@@ -559,7 +559,7 @@ async def llm_dedup_entities_iterative_blocks( # 迭代分块并发 LLM 去重
# Collapse nodes to canonical reps before each round to avoid redundant comparisons
# 步骤1折叠实体合并已确定的重复实体减少后续计算量
current_nodes = _collapse_nodes(current_nodes)
# 步骤2分块group_id分块避免跨组处理
# 步骤2分块end_user_id分块避免跨组处理
blocks = _partition_blocks(current_nodes)
if not blocks: # 无块可处理(实体已全部折叠),退出循环
break
@@ -645,7 +645,7 @@ async def llm_disambiguate_pairs_iterative(
a = entity_nodes[i]
b = entity_nodes[j]
# 必须同组
if getattr(a, "group_id", None) != getattr(b, "group_id", None):
if getattr(a, "end_user_id", None) != getattr(b, "end_user_id", None):
continue
ta = getattr(a, "entity_type", None)
tb = getattr(b, "entity_type", None)

View File

@@ -61,7 +61,7 @@ def _row_to_entity(row: Dict[str, Any]) -> ExtractedEntityNode:
return ExtractedEntityNode(
id=row.get("id"),
name=row.get("name") or "",
group_id=row.get("group_id") or "",
end_user_id=row.get("end_user_id") or "",
user_id=row.get("user_id") or "",
apply_id=row.get("apply_id") or "",
created_at=_parse_dt(row.get("created_at")),
@@ -79,7 +79,7 @@ def _row_to_entity(row: Dict[str, Any]) -> ExtractedEntityNode:
async def second_layer_dedup_and_merge_with_neo4j( # 二层去重的核心逻辑,与 Neo4j 中同组实体联合去重
connector: Neo4jConnector,
group_id: str, # 用于定位neo4j中同一组的实体确保只在同组内去重
end_user_id: str, # 用于定位neo4j中同一组的实体确保只在同组内去重
entity_nodes: List[ExtractedEntityNode], # 输入的实体节点列表,包含待去重的实体
statement_entity_edges: List[StatementEntityEdge], # 输入的语句实体边列表,用于处理实体之间的关系
entity_entity_edges: List[EntityEntityEdge], # 输入的实体实体边列表,用于处理实体之间的关系
@@ -88,7 +88,7 @@ async def second_layer_dedup_and_merge_with_neo4j( # 二层去重的核心逻辑
) -> Tuple[List[ExtractedEntityNode], List[StatementEntityEdge], List[EntityEntityEdge]]:
"""
第二层去重消歧:
- 以第一层结果为索引,检索相同 group_id 下的 DB 候选实体
- 以第一层结果为索引,检索相同 end_user_id 下的 DB 候选实体
- 将 DB 候选与当前实体集合联合,按既有精确/模糊/LLM 决策进行融合
- 返回融合后的实体与重定向后的边(边已指向规范 ID优先 DB ID
"""
@@ -102,7 +102,7 @@ async def second_layer_dedup_and_merge_with_neo4j( # 二层去重的核心逻辑
]
candidates_map = await get_dedup_candidates_for_entities( # 从 Neo4j 中查询候选实体并将结果赋值给candidates_map等待异步操作完成
connector=connector, group_id=group_id,
connector=connector, end_user_id=end_user_id,
entities=incoming_rows, # 传入参数:第一层实体的核心信息(作为查询索引)
use_contains_fallback=True # 传入参数:启用 “包含关系” 作为匹配失败的降级策略若精确匹配无结果用包含关系召回候选与src\database\cypher_queries.py的307产生联动
)

View File

@@ -57,11 +57,11 @@ async def dedup_layers_and_merge_and_return(
if pipeline_config is None:
raise ValueError("pipeline_config is required for dedup_layers_and_merge_and_return")
# 先探测 group_id决定报告写入策略
group_id: Optional[str] = None
# 先探测 end_user_id决定报告写入策略
end_user_id: Optional[str] = None
for dd in dialog_data_list:
group_id = getattr(dd, "group_id", None)
if group_id:
end_user_id = getattr(dd, "end_user_id", None)
if end_user_id:
break
# 第一层去重消歧
@@ -82,11 +82,11 @@ async def dedup_layers_and_merge_and_return(
# 第二层去重消歧:与 Neo4j 中同组实体联合融合
try:
if group_id:
if end_user_id:
if connector:
fused_entity_nodes, fused_statement_entity_edges, fused_entity_entity_edges = await second_layer_dedup_and_merge_with_neo4j(
connector=connector,
group_id=group_id,
end_user_id=end_user_id,
entity_nodes=dedup_entity_nodes,
statement_entity_edges=dedup_statement_entity_edges,
entity_entity_edges=dedup_entity_entity_edges,
@@ -96,7 +96,7 @@ async def dedup_layers_and_merge_and_return(
else:
print("Skip second-layer dedup: missing connector")
else:
print("Skip second-layer dedup: missing group_id")
print("Skip second-layer dedup: missing end_user_id")
except Exception as e:
print(f"Second-layer dedup failed: {e}")

View File

@@ -287,7 +287,7 @@ class ExtractionOrchestrator:
for d_idx, dialog in enumerate(dialog_data_list):
dialogue_content = dialog.content if self.config.statement_extraction.include_dialogue_context else None
for c_idx, chunk in enumerate(dialog.chunks):
all_chunks.append((chunk, dialog.group_id, dialogue_content))
all_chunks.append((chunk, dialog.end_user_id, dialogue_content))
chunk_metadata.append((d_idx, c_idx))
logger.info(f"收集到 {len(all_chunks)} 个分块,开始全局并行提取")
@@ -299,9 +299,9 @@ class ExtractionOrchestrator:
# 全局并行处理所有分块
async def extract_for_chunk(chunk_data, chunk_index):
nonlocal completed_chunks
chunk, group_id, dialogue_content = chunk_data
chunk, end_user_id, dialogue_content = chunk_data
try:
statements = await self.statement_extractor._extract_statements(chunk, group_id, dialogue_content)
statements = await self.statement_extractor._extract_statements(chunk, end_user_id, dialogue_content)
# 流式输出:每提取完一个分块的陈述句,立即发送进度
# 注意:只在试运行模式下发送陈述句详情,正式模式不发送
@@ -550,7 +550,7 @@ class ExtractionOrchestrator:
self, dialog_data_list: List[DialogData]
) -> List[Dict[str, Any]]:
"""
从对话中提取情绪信息(优化版:全局陈述句级并行)
从对话中提取情绪信息(仅针对用户消息,全局陈述句级并行)
Args:
dialog_data_list: 对话数据列表
@@ -558,7 +558,7 @@ class ExtractionOrchestrator:
Returns:
情绪信息映射列表,每个对话对应一个字典
"""
logger.info("开始情绪信息提取(全局陈述句级并行")
logger.info("开始情绪信息提取(仅处理用户消息")
# 收集所有陈述句及其配置
all_statements = []
@@ -597,15 +597,22 @@ class ExtractionOrchestrator:
if not data_config or not data_config.emotion_enabled:
logger.info("情绪提取未启用,跳过")
return [{} for _ in dialog_data_list]
# 收集所有陈述句(只收集 speaker 为 "user" 的)
total_statements = 0
filtered_statements = 0
# 收集所有陈述句
for d_idx, dialog in enumerate(dialog_data_list):
for chunk in dialog.chunks:
for statement in chunk.statements:
all_statements.append((statement, data_config))
statement_metadata.append((d_idx, statement.id))
total_statements += 1
# 只处理用户的陈述句 (role 为 "user")
if hasattr(statement, 'speaker') and statement.speaker == "user":
all_statements.append((statement, data_config))
statement_metadata.append((d_idx, statement.id))
filtered_statements += 1
logger.info(f"收集到 {len(all_statements)} 个陈述句,开始全局并行提取情绪")
logger.info(f"总陈述句: {total_statements}, 用户陈述句: {filtered_statements}, 开始全局并行提取情绪")
# 初始化情绪提取服务
from app.services.emotion_extraction_service import EmotionExtractionService
@@ -985,9 +992,7 @@ class ExtractionOrchestrator:
id=dialog_data.id,
name=f"Dialog_{dialog_data.id}", # 添加必需的 name 字段
ref_id=dialog_data.ref_id,
group_id=dialog_data.group_id,
user_id=dialog_data.user_id,
apply_id=dialog_data.apply_id,
end_user_id=dialog_data.end_user_id,
run_id=dialog_data.run_id, # 使用 dialog_data 的 run_id
content=dialog_data.context.content if dialog_data.context else "",
dialog_embedding=dialog_data.dialog_embedding if hasattr(dialog_data, 'dialog_embedding') else None,
@@ -1005,9 +1010,7 @@ class ExtractionOrchestrator:
id=chunk.id,
name=f"Chunk_{chunk.id}", # 添加必需的 name 字段
dialog_id=dialog_data.id,
group_id=dialog_data.group_id,
user_id=dialog_data.user_id,
apply_id=dialog_data.apply_id,
end_user_id=dialog_data.end_user_id,
run_id=dialog_data.run_id, # 使用 dialog_data 的 run_id
content=chunk.content,
chunk_embedding=chunk.chunk_embedding,
@@ -1028,11 +1031,10 @@ class ExtractionOrchestrator:
stmt_type=getattr(statement, 'stmt_type', 'general'), # 添加必需的 stmt_type 字段
temporal_info=getattr(statement, 'temporal_info', TemporalInfo.ATEMPORAL), # 添加必需的 temporal_info 字段
connect_strength=statement.connect_strength if statement.connect_strength is not None else 'Strong', # 添加必需的 connect_strength 字段
group_id=dialog_data.group_id,
user_id=dialog_data.user_id,
apply_id=dialog_data.apply_id,
end_user_id=dialog_data.end_user_id,
run_id=dialog_data.run_id, # 使用 dialog_data 的 run_id
statement=statement.statement,
speaker=getattr(statement, 'speaker', None), # 添加 speaker 字段
statement_embedding=statement.statement_embedding,
valid_at=statement.temporal_validity.valid_at if hasattr(statement, 'temporal_validity') and statement.temporal_validity else None,
invalid_at=statement.temporal_validity.invalid_at if hasattr(statement, 'temporal_validity') and statement.temporal_validity else None,
@@ -1052,9 +1054,7 @@ class ExtractionOrchestrator:
statement_chunk_edge = StatementChunkEdge(
source=statement.id,
target=chunk.id,
group_id=dialog_data.group_id,
user_id=dialog_data.user_id,
apply_id=dialog_data.apply_id,
end_user_id=dialog_data.end_user_id,
run_id=dialog_data.run_id, # 使用 dialog_data 的 run_id
created_at=dialog_data.created_at,
)
@@ -1087,9 +1087,7 @@ class ExtractionOrchestrator:
aliases=getattr(entity, 'aliases', []) or [], # 传递从三元组提取阶段获取的aliases
name_embedding=getattr(entity, 'name_embedding', None),
is_explicit_memory=getattr(entity, 'is_explicit_memory', False), # 新增:传递语义记忆标记
group_id=dialog_data.group_id,
user_id=dialog_data.user_id,
apply_id=dialog_data.apply_id,
end_user_id=dialog_data.end_user_id,
run_id=dialog_data.run_id, # 使用 dialog_data 的 run_id
created_at=dialog_data.created_at,
expired_at=dialog_data.expired_at,
@@ -1104,9 +1102,7 @@ class ExtractionOrchestrator:
source=statement.id,
target=entity.id,
connect_strength=entity_connect_strength if entity_connect_strength is not None else 'Strong',
group_id=dialog_data.group_id,
user_id=dialog_data.user_id,
apply_id=dialog_data.apply_id,
end_user_id=dialog_data.end_user_id,
run_id=dialog_data.run_id, # 使用 dialog_data 的 run_id
created_at=dialog_data.created_at,
)
@@ -1126,9 +1122,7 @@ class ExtractionOrchestrator:
relation_type=triplet.predicate,
statement=statement.statement,
source_statement_id=statement.id,
group_id=dialog_data.group_id,
user_id=dialog_data.user_id,
apply_id=dialog_data.apply_id,
end_user_id=dialog_data.end_user_id,
run_id=dialog_data.run_id, # 使用 dialog_data 的 run_id
created_at=dialog_data.created_at,
expired_at=dialog_data.expired_at,
@@ -1755,14 +1749,14 @@ class ExtractionOrchestrator:
async def get_chunked_dialogs(
chunker_strategy: str = "RecursiveChunker",
group_id: str = "group_1",
end_user_id: str = "group_1",
indices: Optional[List[int]] = None,
) -> List[DialogData]:
"""从测试数据生成分块对话
Args:
chunker_strategy: 分块策略(默认: RecursiveChunker
group_id: 组ID
end_user_id: 组ID
indices: 要处理的数据索引列表(可选)
Returns:
@@ -1826,7 +1820,7 @@ async def get_chunked_dialogs(
dialog_data = DialogData(
context=conversation_context,
ref_id=data['id'],
group_id=group_id,
end_user_id=end_user_id,
metadata=dialog_metadata,
)
@@ -1928,7 +1922,7 @@ async def get_chunked_dialogs_from_preprocessed(
async def get_chunked_dialogs_with_preprocessing(
chunker_strategy: str = "RecursiveChunker",
group_id: str = "default",
end_user_id: str = "default",
user_id: str = "default",
apply_id: str = "default",
indices: Optional[List[int]] = None,
@@ -1940,7 +1934,7 @@ async def get_chunked_dialogs_with_preprocessing(
Args:
chunker_strategy: 分块策略
group_id: 组ID
end_user_id: 组ID
user_id: 用户ID
apply_id: 应用ID
indices: 要处理的数据索引列表
@@ -1968,11 +1962,9 @@ async def get_chunked_dialogs_with_preprocessing(
indices=indices,
)
# 设置 group_id, user_id, apply_id
# 设置 end_user_id
for dd in preprocessed_data:
dd.group_id = group_id
dd.user_id = user_id
dd.apply_id = apply_id
dd.end_user_id = end_user_id
# 步骤2: 语义剪枝
try:

View File

@@ -22,12 +22,12 @@ class DialogueChunker:
Args:
chunker_strategy: The chunking strategy to use (default: RecursiveChunker)
Options include: SemanticChunker, RecursiveChunker, LateChunker, NeuralChunker
Options: SemanticChunker, RecursiveChunker, LateChunker, NeuralChunker
"""
self.chunker_strategy = chunker_strategy
chunker_config_dict = get_chunker_config(chunker_strategy)
self.chunker_config = ChunkerConfig.model_validate(chunker_config_dict)
# 对于 LLMChunker需要传入 llm_client
if self.chunker_config.chunker_strategy == "LLMChunker":
self.chunker_client = ChunkerClient(self.chunker_config, llm_client)
else:
@@ -41,29 +41,19 @@ class DialogueChunker:
Returns:
A list of Chunk objects
Raises:
ValueError: If chunking fails or returns empty chunks
"""
result_dialogue = await self.chunker_client.generate_chunks(dialogue)
# Defensive fallback: ensure at least one chunk is returned for non-empty content
try:
chunks = result_dialogue.chunks
except Exception:
chunks = []
chunks = result_dialogue.chunks
if not chunks or len(chunks) == 0:
# If the dialogue has content, return a single fallback chunk built from messages
content_str = getattr(result_dialogue, "content", "") or getattr(dialogue, "content", "")
if content_str and len(content_str.strip()) > 0:
fallback_chunk = Chunk.from_messages(
dialogue.context.msgs,
metadata={
"fallback": "single_chunk",
"chunker_strategy": self.chunker_config.chunker_strategy,
"source": "DialogueChunkerFallback",
},
)
return [fallback_chunk]
# No content: return empty list
return []
raise ValueError(
f"Chunking failed: No chunks generated for dialogue {dialogue.ref_id}. "
f"Messages: {len(dialogue.context.msgs) if dialogue.context else 0}, "
f"Strategy: {self.chunker_config.chunker_strategy}"
)
return chunks
@@ -72,22 +62,25 @@ class DialogueChunker:
Args:
dialogue: The processed DialogData object with chunks
output_path: Optional path to save the output (default: chunker_output_{strategy}.txt)
output_path: Optional path to save the output
Returns:
The path where the output was saved
"""
if not output_path:
output_path = os.path.join(os.path.dirname(__file__), "..", "..",
f"chunker_output_{self.chunker_strategy.lower()}.txt")
output_path = os.path.join(
os.path.dirname(__file__), "..", "..",
f"chunker_output_{self.chunker_strategy.lower()}.txt"
)
output_lines = []
output_lines.append(f"=== Chunking Results ({self.chunker_strategy}) ===")
output_lines.append(f"Dialogue ID: {dialogue.ref_id}")
output_lines.append(f"Original conversation has {len(dialogue.context.msgs)} messages")
output_lines.append(f"Total characters: {len(dialogue.content)}")
output_lines.append(f"Generated {len(dialogue.chunks)} chunks:")
output_lines = [
f"=== Chunking Results ({self.chunker_strategy}) ===",
f"Dialogue ID: {dialogue.ref_id}",
f"Original conversation has {len(dialogue.context.msgs)} messages",
f"Total characters: {len(dialogue.content)}",
f"Generated {len(dialogue.chunks)} chunks:"
]
for i, chunk in enumerate(dialogue.chunks):
output_lines.append(f" Chunk {i+1}: {len(chunk.content)} characters")
output_lines.append(f" Content preview: {chunk.content}...")

View File

@@ -193,9 +193,9 @@ async def _process_chunk_summary(
node = MemorySummaryNode(
id=uuid4().hex,
name=title if title else f"MemorySummaryChunk_{chunk.id}",
group_id=dialog.group_id,
user_id=dialog.user_id,
apply_id=dialog.apply_id,
end_user_id=dialog.end_user_id,
user_id=dialog.end_user_id,
apply_id=dialog.end_user_id,
run_id=dialog.run_id, # 使用 dialog 的 run_id
created_at=datetime.now(),
expired_at=datetime(9999, 12, 31),

View File

@@ -5,8 +5,6 @@ from datetime import datetime
from typing import Any, Dict, List, Optional
from app.core.memory.models.message_models import DialogData, Statement
#避免在测试收集阶段因为 OpenAIClient 间接引入 langfuse 导致 ModuleNotFoundError 。这只是类型注解与导入时机的调整,不改变实现。
from app.core.memory.models.variate_config import StatementExtractionConfig
from app.core.memory.utils.data.ontology import (
LABEL_DEFINITIONS,
@@ -22,11 +20,10 @@ logger = logging.getLogger(__name__)
class ExtractedStatement(BaseModel):
"""Schema for extracted statement from LLM"""
statement: str = Field(..., description="The extracted statement text")
statement_type: str = Field(..., description="FACT, OPINION,SUGGESTION or PREDICTION")
statement_type: str = Field(..., description="FACT, OPINION, SUGGESTION or PREDICTION")
temporal_type: str = Field(..., description="STATIC, DYNAMIC, ATEMPORAL")
relevence: str = Field(..., description="RELEVANT or IRRELEVANT")
# 统一使用 StatementExtractionResponse 作为 LLM 的结构化返回(仅语句)
class StatementExtractionResponse(BaseModel):
statements: List[ExtractedStatement] = Field(default_factory=list, description="List of extracted statements")
@@ -58,10 +55,9 @@ class StatementExtractionResponse(BaseModel):
return v
class StatementExtractor:
"""Class for extracting statements from dialog chunks using LLM (relations separated)"""
"""Class for extracting statements from dialog chunks using LLM"""
def __init__(self, llm_client: Any, config: StatementExtractionConfig = None):
# 避免在测试收集阶段因为 OpenAIClient 间接引入 langfuse 导致 ModuleNotFoundError 。这只是类型注解与导入时机的调整,不改变实现。
"""Initialize the StatementExtractor with an LLM client and configuration
Args:
@@ -71,21 +67,38 @@ class StatementExtractor:
self.llm_client = llm_client
self.config = config or StatementExtractionConfig()
async def _extract_statements(self, chunk, group_id: Optional[str] = None, dialogue_content: str = None) -> List[Statement]:
def _get_speaker_from_chunk(self, chunk) -> Optional[str]:
"""Get speaker directly from Chunk
Args:
chunk: Chunk object containing speaker field
Returns:
Speaker role ("user"/"assistant") or None if cannot be determined
"""
if hasattr(chunk, 'speaker') and chunk.speaker:
return chunk.speaker
logger.warning(f"Chunk {getattr(chunk, 'id', 'unknown')} has no speaker field or is empty")
return None
async def _extract_statements(self, chunk, end_user_id: Optional[str] = None, dialogue_content: str = None) -> List[Statement]:
"""Process a single chunk and return extracted statements
Args:
chunk: Chunk object to process
group_id: Group ID to assign to all statements in this chunk
end_user_id: Group ID to assign to all statements in this chunk
dialogue_content: Full dialogue content to provide as context
Returns:
List of ExtractedStatement objects extracted from the chunk
"""
# Prepare the chunk content for processing
chunk_content = chunk.content
if not chunk_content or len(chunk_content.strip()) < 5:
logger.warning(f"Chunk {chunk.id} content too short or empty, skipping")
return []
# Render the prompt using helper function
prompt_content = await render_statement_extraction_prompt(
chunk_content=chunk_content,
definitions=LABEL_DEFINITIONS,
@@ -136,15 +149,19 @@ class StatementExtractor:
relevence_info = RelevenceInfo[relevence_str] if relevence_str in RelevenceInfo.__members__ else RelevenceInfo.RELEVANT
except (KeyError, ValueError):
relevence_info = RelevenceInfo.RELEVANT
chunk_speaker = self._get_speaker_from_chunk(chunk)
chunk_statement = Statement(
statement=extracted_stmt.statement,
stmt_type=stmt_type,
temporal_info=temporal_type,
relevence_info=relevence_info,
chunk_id=chunk.id,
group_id=group_id,
end_user_id=end_user_id,
speaker=chunk_speaker,
)
chunk_statements.append(chunk_statement)
# 分离强弱关系分类:不在句子提取阶段进行,也不写入 chunk.metadata
@@ -167,10 +184,10 @@ class StatementExtractor:
logger.info(f"Processing {len(chunks_to_process)} chunks for statement extraction")
# Process all chunks concurrently, passing the group_id and dialogue content from dialog_data
# Process all chunks concurrently, passing the end_user_id and dialogue content from dialog_data
dialogue_content = dialog_data.content if self.config.include_dialogue_context else None
results = await asyncio.gather(
*[self._extract_statements(chunk, dialog_data.group_id, dialogue_content) for chunk in chunks_to_process],
*[self._extract_statements(chunk, dialog_data.end_user_id, dialogue_content) for chunk in chunks_to_process],
return_exceptions=True
)
@@ -208,7 +225,7 @@ class StatementExtractor:
for i, statement in enumerate(statements, 1):
f.write(f"Statement {i}:\n")
f.write(f"Id: {statement.id}\n")
f.write(f"Group Id: {statement.group_id}\n")
f.write(f"Group Id: {statement.end_user_id}\n")
f.write(f"Content: {statement.statement}\n")
f.write(f"Type: {statement.stmt_type.value}\n")
f.write(f"Temporal Info: {statement.temporal_info.value}\n")
@@ -226,12 +243,7 @@ class StatementExtractor:
return output_path
def save_relations(self, dialogs: List[DialogData], output_path: str = None) -> str:
"""按对话分组聚合强/弱关系并写入 TXT 文件。
- 每个对话单独成段:输出该对话的 `Dialog ID`、`Group ID`、`Content`
- 在该对话段内再分为 Strong Relations / Weak Relations 两部分
- Strong: 逐条输出 `Chunk ID` 与 `Triple`
- Weak: 逐条输出 `Chunk ID` 与 `Entity`
"""
"""Group and aggregate strong/weak relations by dialogue and write to TXT file."""
print("\n=== Relations Classify ===")
# 使用全局配置的输出路径
@@ -286,7 +298,7 @@ class StatementExtractor:
dialog_sections.append({
"dialog_id": dialog.ref_id,
"group_id": dialog.group_id,
"end_user_id": dialog.end_user_id,
"content": dialog.content if getattr(dialog, "content", None) else "",
"strong": strong_relations,
"weak": weak_relations,
@@ -300,7 +312,7 @@ class StatementExtractor:
for idx, section in enumerate(dialog_sections, 1):
f.write(f"Dialog {idx}:\n")
f.write(f"Dialog ID: {section.get('dialog_id', '')}\n")
f.write(f"Group ID: {section.get('group_id', '')}\n")
f.write(f"Group ID: {section.get('end_user_id', '')}\n")
f.write("Content:\n")
f.write(f"{section.get('content', '')}\n")
f.write("-" * 40 + "\n\n")

View File

@@ -132,7 +132,7 @@ class TemporalExtractor:
prompt_logger.info("")
prompt_logger.info("=== TEMPORAL EXTRACTION RESULTS ===")
prompt_logger.info(
f"[Temporal] Dialog ref_id={getattr(dialog_data, 'ref_id', None)}, group_id={getattr(dialog_data, 'group_id', None)}"
f"[Temporal] Dialog ref_id={getattr(dialog_data, 'ref_id', None)}, end_user_id={getattr(dialog_data, 'end_user_id', None)}"
)
except Exception:
pass

View File

@@ -116,7 +116,7 @@ class TripletExtractor:
logger.info(f"Processing {len(all_statements)} statements for triplet extraction...")
try:
prompt_logger.info(
f"[Triplet] Dialog ref_id={getattr(dialog_data, 'ref_id', None)}, group_id={getattr(dialog_data, 'group_id', None)}, statements_to_process={len(all_statements)}"
f"[Triplet] Dialog ref_id={getattr(dialog_data, 'ref_id', None)}, end_user_id={getattr(dialog_data, 'end_user_id', None)}, statements_to_process={len(all_statements)}"
)
except Exception:
pass

View File

@@ -75,7 +75,7 @@ class AccessHistoryManager:
self,
node_id: str,
node_label: str,
group_id: Optional[str] = None,
end_user_id: Optional[str] = None,
current_time: Optional[datetime] = None
) -> Dict[str, Any]:
"""
@@ -91,7 +91,7 @@ class AccessHistoryManager:
Args:
node_id: 节点ID
node_label: 节点标签Statement, ExtractedEntity, MemorySummary
group_id: 组ID可选用于过滤
end_user_id: 组ID可选用于过滤
current_time: 当前时间(可选,默认使用系统时间)
Returns:
@@ -123,7 +123,7 @@ class AccessHistoryManager:
for attempt in range(self.max_retries):
try:
# 步骤1读取当前节点状态
node_data = await self._fetch_node(node_id, node_label, group_id)
node_data = await self._fetch_node(node_id, node_label, end_user_id)
if not node_data:
raise ValueError(
@@ -142,7 +142,7 @@ class AccessHistoryManager:
node_id=node_id,
node_label=node_label,
update_data=update_data,
group_id=group_id
end_user_id=end_user_id
)
logger.info(
@@ -172,7 +172,7 @@ class AccessHistoryManager:
self,
node_ids: List[str],
node_label: str,
group_id: Optional[str] = None,
end_user_id: Optional[str] = None,
current_time: Optional[datetime] = None
) -> List[Dict[str, Any]]:
"""
@@ -184,7 +184,7 @@ class AccessHistoryManager:
Args:
node_ids: 节点ID列表
node_label: 节点标签(所有节点必须是同一类型)
group_id: 组ID可选
end_user_id: 组ID可选
current_time: 当前时间(可选)
Returns:
@@ -202,7 +202,7 @@ class AccessHistoryManager:
task = self.record_access(
node_id=node_id,
node_label=node_label,
group_id=group_id,
end_user_id=end_user_id,
current_time=current_time
)
tasks.append(task)
@@ -235,7 +235,7 @@ class AccessHistoryManager:
self,
node_id: str,
node_label: str,
group_id: Optional[str] = None
end_user_id: Optional[str] = None
) -> Tuple[ConsistencyCheckResult, Optional[str]]:
"""
检查节点数据的一致性
@@ -249,14 +249,14 @@ class AccessHistoryManager:
Args:
node_id: 节点ID
node_label: 节点标签
group_id: 组ID可选
end_user_id: 组ID可选
Returns:
Tuple[ConsistencyCheckResult, Optional[str]]:
- 一致性检查结果枚举
- 错误描述(如果不一致)
"""
node_data = await self._fetch_node(node_id, node_label, group_id)
node_data = await self._fetch_node(node_id, node_label, end_user_id)
if not node_data:
return ConsistencyCheckResult.CONSISTENT, None
@@ -305,7 +305,7 @@ class AccessHistoryManager:
async def check_batch_consistency(
self,
node_label: str,
group_id: Optional[str] = None,
end_user_id: Optional[str] = None,
limit: int = 1000
) -> Dict[str, Any]:
"""
@@ -313,7 +313,7 @@ class AccessHistoryManager:
Args:
node_label: 节点标签
group_id: 组ID可选
end_user_id: 组ID可选
limit: 检查的最大节点数
Returns:
@@ -329,16 +329,16 @@ class AccessHistoryManager:
MATCH (n:{node_label})
WHERE n.access_history IS NOT NULL
"""
if group_id:
query += " AND n.group_id = $group_id"
if end_user_id:
query += " AND n.end_user_id = $end_user_id"
query += """
RETURN n.id as id
LIMIT $limit
"""
params = {"limit": limit}
if group_id:
params["group_id"] = group_id
if end_user_id:
params["end_user_id"] = end_user_id
results = await self.connector.execute_query(query, **params)
node_ids = [r['id'] for r in results]
@@ -351,7 +351,7 @@ class AccessHistoryManager:
result, message = await self.check_consistency(
node_id=node_id,
node_label=node_label,
group_id=group_id
end_user_id=end_user_id
)
if result == ConsistencyCheckResult.CONSISTENT:
@@ -387,7 +387,7 @@ class AccessHistoryManager:
self,
node_id: str,
node_label: str,
group_id: Optional[str] = None
end_user_id: Optional[str] = None
) -> bool:
"""
自动修复节点的数据不一致问题
@@ -401,7 +401,7 @@ class AccessHistoryManager:
Args:
node_id: 节点ID
node_label: 节点标签
group_id: 组ID可选
end_user_id: 组ID可选
Returns:
bool: 修复成功返回True否则返回False
@@ -411,7 +411,7 @@ class AccessHistoryManager:
result, message = await self.check_consistency(
node_id=node_id,
node_label=node_label,
group_id=group_id
end_user_id=end_user_id
)
if result == ConsistencyCheckResult.CONSISTENT:
@@ -419,7 +419,7 @@ class AccessHistoryManager:
return True
# 获取节点数据
node_data = await self._fetch_node(node_id, node_label, group_id)
node_data = await self._fetch_node(node_id, node_label, end_user_id)
if not node_data:
logger.error(f"节点不存在,无法修复: {node_label}[{node_id}]")
return False
@@ -457,8 +457,8 @@ class AccessHistoryManager:
query = f"""
MATCH (n:{node_label} {{id: $node_id}})
"""
if group_id:
query += " WHERE n.group_id = $group_id"
if end_user_id:
query += " WHERE n.end_user_id = $end_user_id"
query += """
SET n += $repair_data
RETURN n
@@ -468,8 +468,8 @@ class AccessHistoryManager:
'node_id': node_id,
'repair_data': repair_data
}
if group_id:
params['group_id'] = group_id
if end_user_id:
params['end_user_id'] = end_user_id
await self.connector.execute_query(query, **params)
@@ -491,7 +491,7 @@ class AccessHistoryManager:
self,
node_id: str,
node_label: str,
group_id: Optional[str] = None
end_user_id: Optional[str] = None
) -> Optional[Dict[str, Any]]:
"""
获取节点数据
@@ -499,7 +499,7 @@ class AccessHistoryManager:
Args:
node_id: 节点ID
node_label: 节点标签
group_id: 组ID可选
end_user_id: 组ID可选
Returns:
Optional[Dict[str, Any]]: 节点数据如果不存在返回None
@@ -507,8 +507,8 @@ class AccessHistoryManager:
query = f"""
MATCH (n:{node_label} {{id: $node_id}})
"""
if group_id:
query += " WHERE n.group_id = $group_id"
if end_user_id:
query += " WHERE n.end_user_id = $end_user_id"
query += """
RETURN n.id as id,
n.importance_score as importance_score,
@@ -519,8 +519,8 @@ class AccessHistoryManager:
"""
params = {'node_id': node_id}
if group_id:
params['group_id'] = group_id
if end_user_id:
params['end_user_id'] = end_user_id
results = await self.connector.execute_query(query, **params)
@@ -585,7 +585,7 @@ class AccessHistoryManager:
node_id: str,
node_label: str,
update_data: Dict[str, Any],
group_id: Optional[str] = None
end_user_id: Optional[str] = None
) -> Dict[str, Any]:
"""
原子性更新节点(使用乐观锁)
@@ -597,7 +597,7 @@ class AccessHistoryManager:
node_id: 节点ID
node_label: 节点标签
update_data: 更新数据
group_id: 组ID可选
end_user_id: 组ID可选
Returns:
Dict[str, Any]: 更新后的节点数据
@@ -606,13 +606,13 @@ class AccessHistoryManager:
RuntimeError: 如果更新失败或发生版本冲突
"""
# 定义事务函数
async def update_transaction(tx, node_id, node_label, update_data, group_id):
async def update_transaction(tx, node_id, node_label, update_data, end_user_id):
# 步骤1读取当前节点并获取版本号
read_query = f"""
MATCH (n:{node_label} {{id: $node_id}})
"""
if group_id:
read_query += " WHERE n.group_id = $group_id"
if end_user_id:
read_query += " WHERE n.end_user_id = $end_user_id"
read_query += """
RETURN n.id as id,
n.version as version,
@@ -624,8 +624,8 @@ class AccessHistoryManager:
"""
read_params = {'node_id': node_id}
if group_id:
read_params['group_id'] = group_id
if end_user_id:
read_params['end_user_id'] = end_user_id
read_result = await tx.run(read_query, **read_params)
current_node = await read_result.single()
@@ -656,8 +656,8 @@ class AccessHistoryManager:
# 构建 WHERE 子句
where_conditions = []
if group_id:
where_conditions.append("n.group_id = $group_id")
if end_user_id:
where_conditions.append("n.end_user_id = $end_user_id")
# 添加版本检查
if current_version > 0:
@@ -695,8 +695,8 @@ class AccessHistoryManager:
'last_access_time': update_data['last_access_time'],
'access_count': update_data['access_count']
}
if group_id:
update_params['group_id'] = group_id
if end_user_id:
update_params['end_user_id'] = end_user_id
update_result = await tx.run(update_query, **update_params)
updated_node = await update_result.single()
@@ -720,7 +720,7 @@ class AccessHistoryManager:
node_id=node_id,
node_label=node_label,
update_data=update_data,
group_id=group_id
end_user_id=end_user_id
)
return result
except Exception as e:

View File

@@ -66,7 +66,7 @@ class ForgettingScheduler:
async def run_forgetting_cycle(
self,
group_id: Optional[str] = None,
end_user_id: Optional[str] = None,
max_merge_batch_size: int = 100,
min_days_since_access: int = 30,
config_id: Optional[int] = None,
@@ -77,7 +77,7 @@ class ForgettingScheduler:
Args:
group_id: 组 ID可选用于过滤特定组的节点
end_user_id: 组 ID可选用于过滤特定组的节点
max_merge_batch_size: 单次最大融合节点对数(默认 100
min_days_since_access: 最小未访问天数(默认 30 天)
config_id: 配置ID可选用于获取 llm_id
@@ -107,19 +107,19 @@ class ForgettingScheduler:
start_time_iso = start_time.isoformat()
logger.info(
f"开始遗忘周期: group_id={group_id}, "
f"开始遗忘周期: end_user_id={end_user_id}, "
f"max_batch={max_merge_batch_size}, "
f"min_days={min_days_since_access}"
)
try:
# 步骤1统计遗忘前的节点数量
nodes_before = await self._count_knowledge_nodes(group_id)
nodes_before = await self._count_knowledge_nodes(end_user_id)
logger.info(f"遗忘前节点总数: {nodes_before}")
# 步骤2识别可遗忘的节点对
forgettable_pairs = await self.forgetting_strategy.find_forgettable_nodes(
group_id=group_id,
end_user_id=end_user_id,
min_days_since_access=min_days_since_access
)
@@ -213,7 +213,7 @@ class ForgettingScheduler:
'statement_text': pair['statement_text'],
'statement_activation': pair['statement_activation'],
'statement_importance': pair['statement_importance'],
'group_id': group_id
'end_user_id': end_user_id
}
entity_node = {
@@ -222,7 +222,7 @@ class ForgettingScheduler:
'entity_type': pair['entity_type'],
'entity_activation': pair['entity_activation'],
'entity_importance': pair['entity_importance'],
'group_id': group_id
'end_user_id': end_user_id
}
# 融合节点
@@ -262,7 +262,7 @@ class ForgettingScheduler:
continue
# 步骤6统计遗忘后的节点数量
nodes_after = await self._count_knowledge_nodes(group_id)
nodes_after = await self._count_knowledge_nodes(end_user_id)
logger.info(f"遗忘后节点总数: {nodes_after}")
# 步骤7生成遗忘报告
@@ -315,7 +315,7 @@ class ForgettingScheduler:
async def _count_knowledge_nodes(
self,
group_id: Optional[str] = None
end_user_id: Optional[str] = None
) -> int:
"""
统计知识层节点总数
@@ -323,7 +323,7 @@ class ForgettingScheduler:
统计 Statement、ExtractedEntity 和 MemorySummary 节点的总数。
Args:
group_id: 组 ID可选用于过滤特定组的节点
end_user_id: 组 ID可选用于过滤特定组的节点
Returns:
int: 知识层节点总数
@@ -333,16 +333,16 @@ class ForgettingScheduler:
WHERE (n:Statement OR n:ExtractedEntity OR n:MemorySummary)
"""
if group_id:
query += " AND n.group_id = $group_id"
if end_user_id:
query += " AND n.end_user_id = $end_user_id"
query += """
RETURN count(n) as total
"""
params = {}
if group_id:
params['group_id'] = group_id
if end_user_id:
end_user_id['end_user_id'] = end_user_id
results = await self.connector.execute_query(query, **params)

View File

@@ -90,7 +90,7 @@ class ForgettingStrategy:
async def find_forgettable_nodes(
self,
group_id: Optional[str] = None,
end_user_id: Optional[str] = None,
min_days_since_access: int = 30
) -> List[Dict[str, Any]]:
"""
@@ -102,7 +102,7 @@ class ForgettingStrategy:
3. Statement 和 Entity 之间存在关系边
Args:
group_id: 组 ID可选用于过滤特定组的节点
end_user_id: 组 ID可选用于过滤特定组的节点
min_days_since_access: 最小未访问天数(默认 30 天)
Returns:
@@ -136,8 +136,8 @@ class ForgettingStrategy:
AND (e.entity_type IS NULL OR e.entity_type <> 'Person')
"""
if group_id:
query += " AND s.group_id = $group_id AND e.group_id = $group_id"
if end_user_id:
query += " AND s.end_user_id = $end_user_id AND e.end_user_id = $end_user_id"
query += """
RETURN s.id as statement_id,
@@ -159,8 +159,8 @@ class ForgettingStrategy:
'threshold': self.forgetting_threshold,
'cutoff_time': cutoff_time_iso
}
if group_id:
params['group_id'] = group_id
if end_user_id:
params['end_user_id'] = end_user_id
results = await self.connector.execute_query(query, **params)
@@ -247,8 +247,8 @@ class ForgettingStrategy:
entity_activation = entity_node['entity_activation']
entity_importance = entity_node['entity_importance']
# 获取 group_id从 statement 或 entity 节点)
group_id = statement_node.get('group_id') or entity_node.get('group_id')
# 获取 end_user_id从 statement 或 entity 节点)
end_user_id = statement_node.get('end_user_id') or entity_node.get('end_user_id')
# 生成摘要内容
summary_text = await self._generate_summary(
@@ -325,7 +325,7 @@ class ForgettingStrategy:
last_access_time: $current_time,
access_count: 1,
version: 1,
group_id: $group_id,
end_user_id: $end_user_id,
created_at: datetime($current_time),
merged_at: datetime($current_time)
})
@@ -423,7 +423,7 @@ class ForgettingStrategy:
'inherited_activation': inherited_activation,
'inherited_importance': inherited_importance,
'current_time': current_time_iso,
'group_id': group_id
'end_user_id': end_user_id
}
try:

View File

@@ -37,7 +37,7 @@ __all__ = [
async def run_hybrid_search(
query_text: str,
search_type: str = "hybrid",
group_id: str | None = None,
end_user_id: str | None = None,
apply_id: str | None = None,
user_id: str | None = None,
limit: int = 50,
@@ -54,7 +54,7 @@ async def run_hybrid_search(
Args:
query_text: 查询文本
search_type: 搜索类型("hybrid", "keyword", "semantic"
group_id: 组ID过滤
end_user_id: 组ID过滤
apply_id: 应用ID过滤
user_id: 用户ID过滤
limit: 每个类别的最大结果数
@@ -104,7 +104,7 @@ async def run_hybrid_search(
# 执行搜索
result = await strategy.search(
query_text=query_text,
group_id=group_id,
end_user_id=end_user_id,
limit=limit,
include=include,
alpha=alpha,

View File

@@ -77,7 +77,7 @@
# async def search(
# self,
# query_text: str,
# group_id: Optional[str] = None,
# end_user_id: Optional[str] = None,
# limit: int = 50,
# include: Optional[List[str]] = None,
# **kwargs
@@ -86,7 +86,7 @@
# Args:
# query_text: 查询文本
# group_id: 可选的组ID过滤
# end_user_id: 可选的组ID过滤
# limit: 每个类别的最大结果数
# include: 要包含的搜索类别列表
# **kwargs: 其他搜索参数如alpha, use_forgetting_curve
@@ -94,7 +94,7 @@
# Returns:
# SearchResult: 搜索结果对象
# """
# logger.info(f"执行混合搜索: query='{query_text}', group_id={group_id}, limit={limit}")
# logger.info(f"执行混合搜索: query='{query_text}', end_user_id={end_user_id}, limit={limit}")
# # 从kwargs中获取参数
# alpha = kwargs.get("alpha", self.alpha)
@@ -107,14 +107,14 @@
# # 并行执行关键词搜索和语义搜索
# keyword_result = await self.keyword_strategy.search(
# query_text=query_text,
# group_id=group_id,
# end_user_id=end_user_id,
# limit=limit,
# include=include_list
# )
# semantic_result = await self.semantic_strategy.search(
# query_text=query_text,
# group_id=group_id,
# end_user_id=end_user_id,
# limit=limit,
# include=include_list
# )
@@ -139,7 +139,7 @@
# metadata = self._create_metadata(
# query_text=query_text,
# search_type="hybrid",
# group_id=group_id,
# end_user_id=end_user_id,
# limit=limit,
# include=include_list,
# alpha=alpha,
@@ -165,7 +165,7 @@
# metadata=self._create_metadata(
# query_text=query_text,
# search_type="hybrid",
# group_id=group_id,
# end_user_id=end_user_id,
# limit=limit,
# error=str(e)
# )

View File

@@ -44,7 +44,7 @@ class KeywordSearchStrategy(SearchStrategy):
async def search(
self,
query_text: str,
group_id: Optional[str] = None,
end_user_id: Optional[str] = None,
limit: int = 50,
include: Optional[List[str]] = None,
**kwargs
@@ -53,7 +53,7 @@ class KeywordSearchStrategy(SearchStrategy):
Args:
query_text: 查询文本
group_id: 可选的组ID过滤
end_user_id: 可选的组ID过滤
limit: 每个类别的最大结果数
include: 要包含的搜索类别列表
**kwargs: 其他搜索参数
@@ -61,7 +61,7 @@ class KeywordSearchStrategy(SearchStrategy):
Returns:
SearchResult: 搜索结果对象
"""
logger.info(f"执行关键词搜索: query='{query_text}', group_id={group_id}, limit={limit}")
logger.info(f"执行关键词搜索: query='{query_text}', end_user_id={end_user_id}, limit={limit}")
# 获取有效的搜索类别
include_list = self._get_include_list(include)
@@ -75,7 +75,7 @@ class KeywordSearchStrategy(SearchStrategy):
results_dict = await search_graph(
connector=self.connector,
q=query_text,
group_id=group_id,
end_user_id=end_user_id,
limit=limit,
include=include_list
)
@@ -84,7 +84,7 @@ class KeywordSearchStrategy(SearchStrategy):
metadata = self._create_metadata(
query_text=query_text,
search_type="keyword",
group_id=group_id,
end_user_id=end_user_id,
limit=limit,
include=include_list
)
@@ -115,7 +115,7 @@ class KeywordSearchStrategy(SearchStrategy):
metadata=self._create_metadata(
query_text=query_text,
search_type="keyword",
group_id=group_id,
end_user_id=end_user_id,
limit=limit,
error=str(e)
)

View File

@@ -58,7 +58,7 @@ class SearchStrategy(ABC):
async def search(
self,
query_text: str,
group_id: Optional[str] = None,
end_user_id: Optional[str] = None,
limit: int = 50,
include: Optional[List[str]] = None,
**kwargs
@@ -67,7 +67,7 @@ class SearchStrategy(ABC):
Args:
query_text: 查询文本
group_id: 可选的组ID过滤
end_user_id: 可选的组ID过滤
limit: 每个类别的最大结果数
include: 要包含的搜索类别列表statements, chunks, entities, summaries
**kwargs: 其他搜索参数
@@ -81,7 +81,7 @@ class SearchStrategy(ABC):
self,
query_text: str,
search_type: str,
group_id: Optional[str] = None,
end_user_id: Optional[str] = None,
limit: int = 50,
**kwargs
) -> Dict[str, Any]:
@@ -90,7 +90,7 @@ class SearchStrategy(ABC):
Args:
query_text: 查询文本
search_type: 搜索类型
group_id: 组ID
end_user_id: 组ID
limit: 结果限制
**kwargs: 其他元数据
@@ -100,7 +100,7 @@ class SearchStrategy(ABC):
metadata = {
"query": query_text,
"search_type": search_type,
"group_id": group_id,
"end_user_id": end_user_id,
"limit": limit,
"timestamp": datetime.now().isoformat()
}

View File

@@ -85,7 +85,7 @@ class SemanticSearchStrategy(SearchStrategy):
async def search(
self,
query_text: str,
group_id: Optional[str] = None,
end_user_id: Optional[str] = None,
limit: int = 50,
include: Optional[List[str]] = None,
**kwargs
@@ -94,7 +94,7 @@ class SemanticSearchStrategy(SearchStrategy):
Args:
query_text: 查询文本
group_id: 可选的组ID过滤
end_user_id: 可选的组ID过滤
limit: 每个类别的最大结果数
include: 要包含的搜索类别列表
**kwargs: 其他搜索参数
@@ -102,7 +102,7 @@ class SemanticSearchStrategy(SearchStrategy):
Returns:
SearchResult: 搜索结果对象
"""
logger.info(f"执行语义搜索: query='{query_text}', group_id={group_id}, limit={limit}")
logger.info(f"执行语义搜索: query='{query_text}', end_user_id={end_user_id}, limit={limit}")
# 获取有效的搜索类别
include_list = self._get_include_list(include)
@@ -119,7 +119,7 @@ class SemanticSearchStrategy(SearchStrategy):
connector=self.connector,
embedder_client=self.embedder_client,
query_text=query_text,
group_id=group_id,
end_user_id=end_user_id,
limit=limit,
include=include_list
)
@@ -128,7 +128,7 @@ class SemanticSearchStrategy(SearchStrategy):
metadata = self._create_metadata(
query_text=query_text,
search_type="semantic",
group_id=group_id,
end_user_id=end_user_id,
limit=limit,
include=include_list
)
@@ -159,7 +159,7 @@ class SemanticSearchStrategy(SearchStrategy):
metadata=self._create_metadata(
query_text=query_text,
search_type="semantic",
group_id=group_id,
end_user_id=end_user_id,
limit=limit,
error=str(e)
)

View File

@@ -23,7 +23,7 @@ async def _load_(data: List[Any]) -> List[Dict]:
target_keys = [
"id",
"statement",
"group_id",
"end_user_id",
"chunk_id",
"created_at",
"expired_at",
@@ -75,7 +75,7 @@ async def get_data(result):
"""
EXCLUDE_FIELDS = {
"user_id",
"group_id",
"end_user_id",
"entity_type",
"connect_strength",
"relationship_type",

View File

@@ -62,7 +62,7 @@ class ConfigAuditLogger:
self,
config_id: str,
user_id: Optional[str] = None,
group_id: Optional[str] = None,
end_user_id: Optional[str] = None,
success: bool = True,
details: Optional[Dict[str, Any]] = None
):
@@ -72,14 +72,14 @@ class ConfigAuditLogger:
Args:
config_id: 配置 ID
user_id: 用户 ID可选
group_id: 组 ID可选
end_user_id: 组 ID可选
success: 是否成功
details: 详细信息(可选)
"""
result = "SUCCESS" if success else "FAILED"
msg = (
f"CONFIG_LOAD config_id={config_id} "
f"user={user_id or 'N/A'} group={group_id or 'N/A'} "
f"user={user_id or 'N/A'} group={end_user_id or 'N/A'} "
f"result={result}"
)
if details:
@@ -121,7 +121,7 @@ class ConfigAuditLogger:
self,
operation: str,
config_id: str,
group_id: str,
end_user_id: str,
success: bool = True,
duration: Optional[float] = None,
error: Optional[str] = None,
@@ -133,7 +133,7 @@ class ConfigAuditLogger:
Args:
operation: 操作类型WRITE, READ 等)
config_id: 配置 ID
group_id: 组 ID
end_user_id: 组 ID
success: 是否成功
duration: 操作耗时(秒)
error: 错误信息(可选)
@@ -142,7 +142,7 @@ class ConfigAuditLogger:
result = "SUCCESS" if success else "FAILED"
msg = (
f"{operation.upper()} config_id={config_id} "
f"group={group_id} result={result}"
f"group={end_user_id} result={result}"
)
if duration is not None:
msg += f" duration={duration:.2f}s"

View File

@@ -4,7 +4,7 @@ from enum import StrEnum, auto
class Field(StrEnum):
CONTENT_KEY = "page_content"
METADATA_KEY = "metadata"
GROUP_KEY = "group_id"
GROUP_KEY = "end_user_id"
VECTOR = auto()
# Sparse Vector aims to support full text search
SPARSE_VECTOR = auto()

View File

@@ -89,14 +89,15 @@ def validate_model_exists_and_active(
start_time = time.time()
try:
# First check if model exists at all (without tenant filtering)
model_without_tenant = ModelConfigRepository.get_by_id(db, model_id, tenant_id=None)
# Then check with tenant filtering
# OPTIMIZED: Single query with tenant filter
# We'll check tenant mismatch in the error handling
model = ModelConfigRepository.get_by_id(db, model_id, tenant_id)
elapsed_ms = (time.time() - start_time) * 1000
if not model:
# Model not found with tenant filter - check if it exists without filter
model_without_tenant = ModelConfigRepository.get_by_id(db, model_id, tenant_id=None)
if model_without_tenant:
# Model exists but belongs to different tenant
logger.warning(
@@ -208,8 +209,11 @@ def validate_embedding_model(
db: Session,
tenant_id: Optional[UUID] = None,
workspace_id: Optional[UUID] = None
) -> UUID:
"""Validate that embedding model is available and return its UUID.
) -> tuple[UUID, str]:
"""Validate that embedding model is available and return its UUID and name.
Returns:
Tuple of (embedding_uuid, embedding_name)
Raises:
InvalidConfigError: If embedding_id is not provided or invalid
@@ -225,14 +229,19 @@ def validate_embedding_model(
workspace_id=workspace_id
)
embedding_uuid, _ = validate_and_resolve_model_id(
embedding_uuid, embedding_name = validate_and_resolve_model_id(
embedding_id, "embedding", db, tenant_id, required=True,
config_id=config_id, workspace_id=workspace_id
)
print(100*'-')
print(embedding_uuid)
print(_)
print(100*'-')
logger.debug(
"Embedding model validated",
extra={
"embedding_uuid": str(embedding_uuid),
"embedding_name": embedding_name,
"config_id": config_id
}
)
if embedding_uuid is None:
raise InvalidConfigError(
@@ -243,7 +252,7 @@ def validate_embedding_model(
workspace_id=workspace_id
)
return embedding_uuid
return embedding_uuid, embedding_name
def validate_llm_model(

View File

@@ -104,38 +104,6 @@ class DataConfigRepository:
r.statement AS statement
"""
# Entity graph within group (source node, edge, target node)
SEARCH_FOR_ENTITY_GRAPH = """
MATCH (n:ExtractedEntity)-[r]->(m:ExtractedEntity)
WHERE n.group_id = $group_id
RETURN
{
entity_idx: n.entity_idx,
connect_strength: n.connect_strength,
description: n.description,
entity_type: n.entity_type,
name: n.name,
fact_summary: COALESCE(n.fact_summary, ''),
id: n.id
} AS sourceNode,
{
rel_id: elementId(r),
source_id: startNode(r).id,
target_id: endNode(r).id,
predicate: r.predicate,
statement_id: r.statement_id,
statement: r.statement
} AS edge,
{
entity_idx: m.entity_idx,
connect_strength: m.connect_strength,
description: m.description,
entity_type: m.entity_type,
name: m.name,
fact_summary: COALESCE(m.fact_summary, ''),
id: m.id
} AS targetNode
"""
@staticmethod
def update_reflection_config(
db: Session,

View File

@@ -276,42 +276,6 @@ def get_end_user_by_id(db: Session, end_user_id: uuid.UUID) -> Optional[EndUser]
end_user = repo.get_end_user_by_id(end_user_id)
return end_user
def update_end_user_other_name(
db: Session,
end_user_id: uuid.UUID,
other_name: str
) -> int:
"""
通过 end_user_id 更新 end_user 表中的 other_name 字段
Args:
db: 数据库会话
end_user_id: 宿主ID
other_name: 要更新的用户名
Returns:
int: 更新的记录数
"""
try:
# 执行更新
updated_count = (
db.query(EndUser)
.filter(EndUser.id == end_user_id)
.update(
{EndUser.other_name: other_name},
synchronize_session=False
)
)
db.commit()
db_logger.info(f"成功更新宿主 {end_user_id} 的 other_name 为: {other_name}")
return updated_count
except Exception as e:
db.rollback()
db_logger.error(f"更新宿主 {end_user_id} 的 other_name 时出错: {str(e)}")
raise
# 新增的缓存操作函数(保持与类方法一致的接口)
def get_by_id(db: Session, end_user_id: uuid.UUID) -> Optional[EndUser]:
"""根据ID获取终端用户用于缓存操作"""

View File

@@ -32,7 +32,7 @@ async def add_chunk_statement_edges(chunks: List[Chunk], connector: Neo4jConnect
"id": stable_edge_id,
"source": chunk.id,
"target": stmt.id,
"group_id": getattr(stmt, 'group_id', None),
"end_user_id": getattr(stmt, 'end_user_id', None),
"user_id":getattr(stmt, 'user_id', None),
"apply_id": getattr(stmt, 'apply_id', None),
"run_id": getattr(stmt, 'run_id', None) or getattr(chunk, 'run_id', None),
@@ -83,7 +83,7 @@ async def add_memory_summary_statement_edges(summaries: List[MemorySummaryNode],
edges.append({
"summary_id": s.id,
"chunk_id": chunk_id,
"group_id": s.group_id,
"end_user_id": s.end_user_id,
"run_id": s.run_id,
"created_at": s.created_at.isoformat() if s.created_at else None,
"expired_at": s.expired_at.isoformat() if s.expired_at else None,

View File

@@ -6,10 +6,10 @@ from app.core.memory.models.graph_models import DialogueNode, StatementNode, Chu
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
async def delete_all_nodes(group_id: str, connector: Neo4jConnector):
async def delete_all_nodes(end_user_id: str, connector: Neo4jConnector):
"""Delete all nodes in the database."""
result = await connector.execute_query(f"MATCH (n {{group_id: '{group_id}'}}) DETACH DELETE n")
print(f"All group_id: {group_id} node and edge deleted successfully")
result = await connector.execute_query(f"MATCH (n {{end_user_id: '{end_user_id}'}}) DETACH DELETE n")
print(f"All end_user_id: {end_user_id} node and edge deleted successfully")
return result
async def add_dialogue_nodes(dialogues: List[DialogueNode], connector: Neo4jConnector) -> Optional[List[str]]:
@@ -32,9 +32,7 @@ async def add_dialogue_nodes(dialogues: List[DialogueNode], connector: Neo4jConn
for dialogue in dialogues:
flattened_dialogues.append({
"id": dialogue.id,
"group_id": dialogue.group_id,
"user_id": dialogue.user_id,
"apply_id": dialogue.apply_id,
"end_user_id": dialogue.end_user_id,
"run_id": dialogue.run_id,
"ref_id": dialogue.ref_id,
"name": dialogue.name,
@@ -79,9 +77,7 @@ async def add_statement_nodes(statements: List[StatementNode], connector: Neo4jC
flattened_statement = {
"id": statement.id,
"name": statement.name,
"group_id": statement.group_id,
"user_id": statement.user_id,
"apply_id": statement.apply_id,
"end_user_id": statement.end_user_id,
"run_id": statement.run_id,
"chunk_id": statement.chunk_id,
# "created_at": statement.created_at.isoformat(),
@@ -101,6 +97,8 @@ async def add_statement_nodes(statements: List[StatementNode], connector: Neo4jC
# "entities": [entity.model_dump() for entity in statement.triplet_extraction_info.entities] if statement.triplet_extraction_info else []
# }) if statement.triplet_extraction_info else json.dumps({"triplets": [], "entities": []}),
"statement_embedding": statement.statement_embedding if statement.statement_embedding else None,
# 添加 speaker 字段(用于基于角色的情绪提取)
"speaker": statement.speaker if hasattr(statement, 'speaker') else None,
# 添加情绪字段处理
"emotion_type": statement.emotion_type,
"emotion_intensity": statement.emotion_intensity,
@@ -152,9 +150,7 @@ async def add_chunk_nodes(chunks: List[ChunkNode], connector: Neo4jConnector) ->
flattened_chunk = {
"id": chunk.id,
"name": chunk.name,
"group_id": chunk.group_id,
"user_id": chunk.user_id,
"apply_id": chunk.apply_id,
"end_user_id": chunk.end_user_id,
"run_id": chunk.run_id,
"created_at": chunk.created_at.isoformat() if chunk.created_at else None,
"expired_at": chunk.expired_at.isoformat() if chunk.expired_at else None,
@@ -163,7 +159,9 @@ async def add_chunk_nodes(chunks: List[ChunkNode], connector: Neo4jConnector) ->
"chunk_embedding": chunk.chunk_embedding if chunk.chunk_embedding else None,
"sequence_number": chunk.sequence_number,
"start_index": metadata.get("start_index"),
"end_index": metadata.get("end_index")
"end_index": metadata.get("end_index"),
# 添加 speaker 字段(用于基于角色的情绪提取)
"speaker": chunk.speaker if hasattr(chunk, 'speaker') else None
}
flattened_chunks.append(flattened_chunk)
@@ -202,9 +200,7 @@ async def add_memory_summary_nodes(summaries: List[MemorySummaryNode], connector
flattened.append({
"id": s.id,
"name": s.name,
"group_id": s.group_id,
"user_id": s.user_id,
"apply_id": s.apply_id,
"end_user_id": s.end_user_id,
"run_id": s.run_id,
"created_at": s.created_at.isoformat() if s.created_at else None,
"expired_at": s.expired_at.isoformat() if s.expired_at else None,

View File

@@ -152,7 +152,7 @@ class BaseNeo4jRepository(BaseRepository[T]):
Example:
>>> results = await repository.find(
... {"group_id": "group_123", "user_id": "user_456"},
... {"end_user_id": "group_123", "user_id": "user_456"},
... limit=50
... )
"""

View File

@@ -3,9 +3,7 @@ DIALOGUE_NODE_SAVE = """
UNWIND $dialogues AS dialogue
MERGE (n:Dialogue {id: dialogue.id})
SET n.uuid = coalesce(n.uuid, dialogue.id),
n.group_id = dialogue.group_id,
n.user_id = dialogue.user_id,
n.apply_id = dialogue.apply_id,
n.end_user_id = dialogue.end_user_id,
n.run_id = dialogue.run_id,
n.ref_id = dialogue.ref_id,
n.created_at = dialogue.created_at,
@@ -22,9 +20,7 @@ SET s += {
id: statement.id,
run_id: statement.run_id,
chunk_id: statement.chunk_id,
group_id: statement.group_id,
user_id: statement.user_id,
apply_id: statement.apply_id,
end_user_id: statement.end_user_id,
stmt_type: statement.stmt_type,
statement: statement.statement,
emotion_intensity: statement.emotion_intensity,
@@ -54,9 +50,7 @@ MERGE (c:Chunk {id: chunk.id})
SET c += {
id: chunk.id,
name: chunk.name,
group_id: chunk.group_id,
user_id: chunk.user_id,
apply_id: chunk.apply_id,
end_user_id: chunk.end_user_id,
run_id: chunk.run_id,
created_at: chunk.created_at,
expired_at: chunk.expired_at,
@@ -76,9 +70,7 @@ EXTRACTED_ENTITY_NODE_SAVE = """
UNWIND $entities AS entity
MERGE (e:ExtractedEntity {id: entity.id})
SET e.name = CASE WHEN entity.name IS NOT NULL AND entity.name <> '' THEN entity.name ELSE e.name END,
e.group_id = CASE WHEN entity.group_id IS NOT NULL AND entity.group_id <> '' THEN entity.group_id ELSE e.group_id END,
e.user_id = CASE WHEN entity.user_id IS NOT NULL AND entity.user_id <> '' THEN entity.user_id ELSE e.user_id END,
e.apply_id = CASE WHEN entity.apply_id IS NOT NULL AND entity.apply_id <> '' THEN entity.apply_id ELSE e.apply_id END,
e.end_user_id = CASE WHEN entity.end_user_id IS NOT NULL AND entity.end_user_id <> '' THEN entity.end_user_id ELSE e.end_user_id END,
e.run_id = CASE WHEN entity.run_id IS NOT NULL AND entity.run_id <> '' THEN entity.run_id ELSE e.run_id END,
e.created_at = CASE
WHEN entity.created_at IS NOT NULL AND (e.created_at IS NULL OR entity.created_at < e.created_at)
@@ -134,9 +126,9 @@ RETURN e.id AS uuid
# Add back ENTITY_RELATIONSHIP_SAVE to be used by graph_saver.save_entities_and_relationships
ENTITY_RELATIONSHIP_SAVE = """
UNWIND $relationships AS rel
// Match entities by stable id within group, do not constrain by run_id
MATCH (subject:ExtractedEntity {id: rel.source_id, group_id: rel.group_id})
MATCH (object:ExtractedEntity {id: rel.target_id, group_id: rel.group_id})
// Match entities by stable id within end_user_id, do not constrain by run_id
MATCH (subject:ExtractedEntity {id: rel.source_id, end_user_id: rel.end_user_id})
MATCH (object:ExtractedEntity {id: rel.target_id, end_user_id: rel.end_user_id})
// Avoid duplicate edges across runs for the same endpoints
MERGE (subject)-[r:EXTRACTED_RELATIONSHIP]->(object)
SET r.predicate = rel.predicate,
@@ -148,7 +140,7 @@ SET r.predicate = rel.predicate,
r.created_at = rel.created_at,
r.expired_at = rel.expired_at,
r.run_id = rel.run_id,
r.group_id = rel.group_id
r.end_user_id = rel.end_user_id
RETURN elementId(r) AS uuid
"""
@@ -160,7 +152,7 @@ UNWIND $weak_entities AS entity
MERGE (e:ExtractedEntity {id: entity.id, run_id: entity.run_id})
SET e += {
name: entity.name,
group_id: entity.group_id,
end_user_id: entity.end_user_id,
run_id: entity.run_id,
description: entity.description,
chunk_id: entity.chunk_id,
@@ -175,11 +167,11 @@ RETURN e.id AS id
SAVE_STRONG_TRIPLE_ENTITIES = """
UNWIND $items AS item
MERGE (s:ExtractedEntity {id: item.source_id, run_id: item.run_id})
SET s += {name: item.subject, group_id: item.group_id, run_id: item.run_id}
SET s += {name: item.subject, end_user_id: item.end_user_id, run_id: item.run_id}
// Independent strong flag
SET s.is_strong = true
MERGE (o:ExtractedEntity {id: item.target_id, run_id: item.run_id})
SET o += {name: item.object, group_id: item.group_id, run_id: item.run_id}
SET o += {name: item.object, end_user_id: item.end_user_id, run_id: item.run_id}
// Independent strong flag
SET o.is_strong = true
"""
@@ -194,7 +186,7 @@ DIALOGUE_STATEMENT_EDGE_SAVE = """
// 仅按端点去重,关系属性可更新
MERGE (dialogue)-[e:MENTIONS]->(statement)
SET e.uuid = edge.id,
e.group_id = edge.group_id,
e.end_user_id = edge.end_user_id,
e.created_at = edge.created_at,
e.expired_at = edge.expired_at
RETURN e.uuid AS uuid
@@ -208,7 +200,7 @@ CHUNK_STATEMENT_EDGE_SAVE = """
MATCH (statement:Statement {id: edge.source, run_id: edge.run_id})
MATCH (chunk:Chunk {id: edge.target, run_id: edge.run_id})
MERGE (chunk)-[e:CONTAINS {id: edge.id}]->(statement)
SET e.group_id = edge.group_id,
SET e.end_user_id = edge.end_user_id,
e.run_id = edge.run_id,
e.created_at = edge.created_at,
e.expired_at = edge.expired_at
@@ -218,13 +210,12 @@ CHUNK_STATEMENT_EDGE_SAVE = """
STATEMENT_ENTITY_EDGE_SAVE = """
UNWIND $relationships AS rel
// Statement nodes are per-run; keep run_id constraint on statements
// Statement nodes are per-run; keep run_id constraint on statements
MATCH (statement:Statement {id: rel.source, run_id: rel.run_id})
// Entities are shared across runs within a group; do not constrain by run_id
MATCH (entity:ExtractedEntity {id: rel.target, group_id: rel.group_id})
// Entities are shared across runs within end_user_id; do not constrain by run_id
MATCH (entity:ExtractedEntity {id: rel.target, end_user_id: rel.end_user_id})
// Avoid duplicate edges across runs for same endpoints
MERGE (statement)-[r:REFERENCES_ENTITY]->(entity)
SET r.group_id = rel.group_id,
SET r.end_user_id = rel.end_user_id,
r.run_id = rel.run_id,
r.created_at = rel.created_at,
r.expired_at = rel.expired_at,
@@ -236,10 +227,10 @@ ENTITY_EMBEDDING_SEARCH = """
CALL db.index.vector.queryNodes('entity_embedding_index', $limit * 100, $embedding)
YIELD node AS e, score
WHERE e.name_embedding IS NOT NULL
AND ($group_id IS NULL OR e.group_id = $group_id)
AND ($end_user_id IS NULL OR e.end_user_id = $end_user_id)
RETURN e.id AS id,
e.name AS name,
e.group_id AS group_id,
e.end_user_id AS end_user_id,
e.entity_type AS entity_type,
COALESCE(e.activation_value, e.importance_score, 0.5) AS activation_value,
COALESCE(e.importance_score, 0.5) AS importance_score,
@@ -254,10 +245,10 @@ STATEMENT_EMBEDDING_SEARCH = """
CALL db.index.vector.queryNodes('statement_embedding_index', $limit * 100, $embedding)
YIELD node AS s, score
WHERE s.statement_embedding IS NOT NULL
AND ($group_id IS NULL OR s.group_id = $group_id)
AND ($end_user_id IS NULL OR s.end_user_id = $end_user_id)
RETURN s.id AS id,
s.statement AS statement,
s.group_id AS group_id,
s.end_user_id AS end_user_id,
s.chunk_id AS chunk_id,
s.created_at AS created_at,
s.expired_at AS expired_at,
@@ -277,9 +268,9 @@ CHUNK_EMBEDDING_SEARCH = """
CALL db.index.vector.queryNodes('chunk_embedding_index', $limit * 100, $embedding)
YIELD node AS c, score
WHERE c.chunk_embedding IS NOT NULL
AND ($group_id IS NULL OR c.group_id = $group_id)
AND ($end_user_id IS NULL OR c.end_user_id = $end_user_id)
RETURN c.id AS chunk_id,
c.group_id AS group_id,
c.end_user_id AS end_user_id,
c.content AS content,
c.dialog_id AS dialog_id,
COALESCE(c.activation_value, 0.5) AS activation_value,
@@ -292,12 +283,12 @@ LIMIT $limit
SEARCH_STATEMENTS_BY_KEYWORD = """
CALL db.index.fulltext.queryNodes("statementsFulltext", $q) YIELD node AS s, score
WHERE ($group_id IS NULL OR s.group_id = $group_id)
WHERE ($end_user_id IS NULL OR s.end_user_id = $end_user_id)
OPTIONAL MATCH (c:Chunk)-[:CONTAINS]->(s)
OPTIONAL MATCH (s)-[:REFERENCES_ENTITY]->(e:ExtractedEntity)
RETURN s.id AS id,
s.statement AS statement,
s.group_id AS group_id,
s.end_user_id AS end_user_id,
s.chunk_id AS chunk_id,
s.created_at AS created_at,
s.expired_at AS expired_at,
@@ -316,15 +307,13 @@ LIMIT $limit
# 查询实体名称包含指定字符串的实体
SEARCH_ENTITIES_BY_NAME = """
CALL db.index.fulltext.queryNodes("entitiesFulltext", $q) YIELD node AS e, score
WHERE ($group_id IS NULL OR e.group_id = $group_id)
WHERE ($end_user_id IS NULL OR e.end_user_id = $end_user_id)
OPTIONAL MATCH (s:Statement)-[:REFERENCES_ENTITY]->(e)
OPTIONAL MATCH (c:Chunk)-[:CONTAINS]->(s)
RETURN e.id AS id,
e.name AS name,
e.group_id AS group_id,
e.end_user_id AS end_user_id,
e.entity_type AS entity_type,
e.apply_id AS apply_id,
e.user_id AS user_id,
e.created_at AS created_at,
e.expired_at AS expired_at,
e.entity_idx AS entity_idx,
@@ -347,11 +336,11 @@ LIMIT $limit
SEARCH_CHUNKS_BY_CONTENT = """
CALL db.index.fulltext.queryNodes("chunksFulltext", $q) YIELD node AS c, score
WHERE ($group_id IS NULL OR c.group_id = $group_id)
WHERE ($end_user_id IS NULL OR c.end_user_id = $end_user_id)
OPTIONAL MATCH (c)-[:CONTAINS]->(s:Statement)
OPTIONAL MATCH (s)-[:REFERENCES_ENTITY]->(e:ExtractedEntity)
RETURN c.id AS chunk_id,
c.group_id AS group_id,
c.end_user_id AS end_user_id,
c.content AS content,
c.dialog_id AS dialog_id,
c.sequence_number AS sequence_number,
@@ -413,10 +402,10 @@ LIMIT $limit
SEARCH_DIALOGUE_BY_DIALOG_ID = """
MATCH (d:Dialogue)
WHERE ($group_id IS NULL OR d.group_id = $group_id)
WHERE ($end_user_id IS NULL OR d.end_user_id = $end_user_id)
AND d.id = $dialog_id
RETURN d.id AS dialog_id,
d.group_id AS group_id,
d.end_user_id AS end_user_id,
d.content AS content,
d.created_at AS created_at,
d.expired_at AS expired_at
@@ -426,10 +415,10 @@ LIMIT $limit
SEARCH_CHUNK_BY_CHUNK_ID = """
MATCH (c:Chunk)
WHERE ($group_id IS NULL OR c.group_id = $group_id)
WHERE ($end_user_id IS NULL OR c.end_user_id = $end_user_id)
AND c.id = $chunk_id
RETURN c.id AS chunk_id,
c.group_id AS group_id,
c.end_user_id AS end_user_id,
c.content AS content,
c.dialog_id AS dialog_id,
c.created_at AS created_at,
@@ -441,18 +430,14 @@ LIMIT $limit
SEARCH_STATEMENTS_BY_TEMPORAL = """
MATCH (s:Statement)
WHERE ($group_id IS NULL OR s.group_id = $group_id)
AND ($apply_id IS NULL OR s.apply_id = $apply_id)
AND ($user_id IS NULL OR s.user_id = $user_id)
WHERE ($end_user_id IS NULL OR s.end_user_id = $end_user_id)
AND ((($start_date IS NULL OR datetime(s.created_at) >= datetime($start_date))
AND ($end_date IS NULL OR datetime(s.created_at) <= datetime($end_date)))
OR (($valid_date IS NULL OR (s.valid_at IS NOT NULL AND datetime(s.valid_at) >= datetime($valid_date)))
AND ($invalid_date IS NULL OR (s.invalid_at IS NOT NULL AND datetime(s.invalid_at) <= datetime($invalid_date)))))
RETURN s.id AS id,
s.statement AS statement,
s.group_id AS group_id,
s.apply_id AS apply_id,
s.user_id AS user_id,
s.end_user_id AS end_user_id,
s.chunk_id AS chunk_id,
s.created_at AS created_at,
s.valid_at AS valid_at,
@@ -468,9 +453,7 @@ LIMIT $limit
SEARCH_STATEMENTS_BY_KEYWORD_TEMPORAL = """
CALL db.index.fulltext.queryNodes("statementsFulltext", $q) YIELD node AS s, score
WHERE ($group_id IS NULL OR s.group_id = $group_id)
AND ($apply_id IS NULL OR s.apply_id = $apply_id)
AND ($user_id IS NULL OR s.user_id = $user_id)
WHERE ($end_user_id IS NULL OR s.end_user_id = $end_user_id)
AND ((($start_date IS NULL OR (s.created_at IS NOT NULL AND datetime(s.created_at) >= datetime($start_date)))
AND ($end_date IS NULL OR (s.created_at IS NOT NULL AND datetime(s.created_at) <= datetime($end_date))))
OR (($valid_date IS NULL OR (s.valid_at IS NOT NULL AND datetime(s.valid_at) >= datetime($valid_date)))
@@ -479,9 +462,7 @@ OPTIONAL MATCH (c:Chunk)-[:CONTAINS]->(s)
OPTIONAL MATCH (s)-[:REFERENCES_ENTITY]->(e:ExtractedEntity)
RETURN s.id AS id,
s.statement AS statement,
s.group_id AS group_id,
s.apply_id AS apply_id,
s.user_id AS user_id,
s.end_user_id AS end_user_id,
s.chunk_id AS chunk_id,
s.created_at AS created_at,
s.valid_at AS valid_at,
@@ -499,15 +480,11 @@ LIMIT $limit
SEARCH_STATEMENTS_BY_CREATED_AT = """
MATCH (n:Statement)
WHERE ($group_id IS NULL OR n.group_id = $group_id)
AND ($apply_id IS NULL OR n.apply_id = $apply_id)
AND ($user_id IS NULL OR n.user_id = $user_id)
WHERE ($end_user_id IS NULL OR n.end_user_id = $end_user_id)
AND ($created_at IS NOT NULL AND date(substring(n.created_at, 0, 10)) = date($created_at))
RETURN n.id AS id,
n.statement AS statement,
n.group_id AS group_id,
n.apply_id AS apply_id,
n.user_id AS user_id,
n.end_user_id AS end_user_id,
n.chunk_id AS chunk_id,
n.created_at AS created_at,
n.valid_at AS valid_at,
@@ -519,15 +496,11 @@ LIMIT $limit
SEARCH_STATEMENTS_BY_VALID_AT = """
MATCH (n:Statement)
WHERE ($group_id IS NULL OR n.group_id = $group_id)
AND ($apply_id IS NULL OR n.apply_id = $apply_id)
AND ($user_id IS NULL OR n.user_id = $user_id)
WHERE ($end_user_id IS NULL OR n.end_user_id = $end_user_id)
AND ($valid_at IS NOT NULL AND date(substring(n.valid_at, 0, 10)) = date($valid_at))
RETURN n.id AS id,
n.statement AS statement,
n.group_id AS group_id,
n.apply_id AS apply_id,
n.user_id AS user_id,
n.end_user_id AS end_user_id,
n.chunk_id AS chunk_id,
n.created_at AS created_at,
n.valid_at AS valid_at,
@@ -539,15 +512,11 @@ LIMIT $limit
SEARCH_STATEMENTS_G_CREATED_AT = """
MATCH (n:Statement)
WHERE ($group_id IS NULL OR n.group_id = $group_id)
AND ($apply_id IS NULL OR n.apply_id = $apply_id)
AND ($user_id IS NULL OR n.user_id = $user_id)
WHERE ($end_user_id IS NULL OR n.end_user_id = $end_user_id)
AND ($created_at IS NOT NULL AND date(substring(n.created_at, 0, 19)) = date($created_at))
RETURN n.id AS id,
n.statement AS statement,
n.group_id AS group_id,
n.apply_id AS apply_id,
n.user_id AS user_id,
n.end_user_id AS end_user_id,
n.chunk_id AS chunk_id,
n.created_at AS created_at,
n.valid_at AS valid_at,
@@ -559,15 +528,11 @@ LIMIT $limit
SEARCH_STATEMENTS_L_CREATED_AT = """
MATCH (n:Statement)
WHERE ($group_id IS NULL OR n.group_id = $group_id)
AND ($apply_id IS NULL OR n.apply_id = $apply_id)
AND ($user_id IS NULL OR n.user_id = $user_id)
WHERE ($end_user_id IS NULL OR n.end_user_id = $end_user_id)
AND ($created_at IS NOT NULL AND date(substring(n.created_at, 0, 19)) < date($created_at))
RETURN n.id AS id,
n.statement AS statement,
n.group_id AS group_id,
n.apply_id AS apply_id,
n.user_id AS user_id,
n.end_user_id AS end_user_id,
n.chunk_id AS chunk_id,
n.created_at AS created_at,
n.valid_at AS valid_at,
@@ -579,15 +544,11 @@ LIMIT $limit
SEARCH_STATEMENTS_G_VALID_AT = """
MATCH (n:Statement)
WHERE ($group_id IS NULL OR n.group_id = $group_id)
AND ($apply_id IS NULL OR n.apply_id = $apply_id)
AND ($user_id IS NULL OR n.user_id = $user_id)
WHERE ($end_user_id IS NULL OR n.end_user_id = $end_user_id)
AND ($valid_at IS NOT NULL AND date(substring(n.valid_at, 0, 10)) > date($valid_at))
RETURN n.id AS id,
n.statement AS statement,
n.group_id AS group_id,
n.apply_id AS apply_id,
n.user_id AS user_id,
n.end_user_id AS end_user_id,
n.chunk_id AS chunk_id,
n.created_at AS created_at,
n.valid_at AS valid_at,
@@ -599,15 +560,11 @@ LIMIT $limit
SEARCH_STATEMENTS_L_VALID_AT = """
MATCH (n:Statement)
WHERE ($group_id IS NULL OR n.group_id = $group_id)
AND ($apply_id IS NULL OR n.apply_id = $apply_id)
AND ($user_id IS NULL OR n.user_id = $user_id)
WHERE ($end_user_id IS NULL OR n.end_user_id = $end_user_id)
AND ($valid_at IS NOT NULL AND date(substring(n.valid_at, 0, 10)) < date($valid_at))
RETURN n.id AS id,
n.statement AS statement,
n.group_id AS group_id,
n.apply_id AS apply_id,
n.user_id AS user_id,
n.end_user_id AS end_user_id,
n.chunk_id AS chunk_id,
n.created_at AS created_at,
n.valid_at AS valid_at,
@@ -665,18 +622,18 @@ LIMIT $limit
# 根据id修改句子的invalid_at的值
UPDATE_STATEMENT_INVALID_AT = """
MATCH (n:Statement {group_id: $group_id, id: $id})
MATCH (n:Statement {end_user_id: $end_user_id, id: $id})
SET n.invalid_at = $new_invalid_at
"""
# MemorySummary keyword search using fulltext index
SEARCH_MEMORY_SUMMARIES_BY_KEYWORD = """
CALL db.index.fulltext.queryNodes("summariesFulltext", $q) YIELD node AS m, score
WHERE ($group_id IS NULL OR m.group_id = $group_id)
WHERE ($end_user_id IS NULL OR m.end_user_id = $end_user_id)
OPTIONAL MATCH (m)-[:DERIVED_FROM_STATEMENT]->(s:Statement)
RETURN m.id AS id,
m.name AS name,
m.group_id AS group_id,
m.end_user_id AS end_user_id,
m.dialog_id AS dialog_id,
m.chunk_ids AS chunk_ids,
m.content AS content,
@@ -695,10 +652,10 @@ MEMORY_SUMMARY_EMBEDDING_SEARCH = """
CALL db.index.vector.queryNodes('summary_embedding_index', $limit * 100, $embedding)
YIELD node AS m, score
WHERE m.summary_embedding IS NOT NULL
AND ($group_id IS NULL OR m.group_id = $group_id)
AND ($end_user_id IS NULL OR m.end_user_id = $end_user_id)
RETURN m.id AS id,
m.name AS name,
m.group_id AS group_id,
m.end_user_id AS end_user_id,
m.dialog_id AS dialog_id,
m.chunk_ids AS chunk_ids,
m.content AS content,
@@ -718,9 +675,7 @@ MERGE (m:MemorySummary {id: summary.id})
SET m += {
id: summary.id,
name: summary.name,
group_id: summary.group_id,
user_id: summary.user_id,
apply_id: summary.apply_id,
end_user_id: summary.end_user_id,
run_id: summary.run_id,
created_at: summary.created_at,
expired_at: summary.expired_at,
@@ -814,7 +769,7 @@ RETURN count(losing) as deleted
neo4j_statement_part = '''
MATCH (n:Statement)
WHERE n.group_id = "{}"
WHERE n.end_user_id = "{}"
AND datetime(n.created_at) >= datetime() - duration('P3D')
RETURN
n.statement as statement_name,
@@ -824,7 +779,7 @@ RETURN
'''
neo4j_statement_all = '''
MATCH (n:Statement)
WHERE n.group_id = "{}"
WHERE n.end_user_id = "{}"
RETURN
n.statement as statement_name,
n.id as statement_id
@@ -832,7 +787,7 @@ RETURN
'''
neo4j_query_part = """
MATCH (n)-[r]-(m:ExtractedEntity)
WHERE n.group_id = "{}"
WHERE n.end_user_id = "{}"
AND datetime(n.created_at) >= datetime() - duration('P3D')
WITH DISTINCT m
OPTIONAL MATCH (m)-[rel]-(other:ExtractedEntity)
@@ -853,7 +808,7 @@ neo4j_query_part = """
"""
neo4j_query_all = """
MATCH (n)-[r]-(m:ExtractedEntity)
WHERE n.group_id = "{}"
WHERE n.end_user_id = "{}"
WITH DISTINCT m
OPTIONAL MATCH (m)-[rel]-(other:ExtractedEntity)
RETURN
@@ -1027,14 +982,14 @@ RETURN DISTINCT
Memory_Space_User="""
MATCH (n)-[r]->(m)
WHERE n.group_id = $group_id AND m.name="用户"
WHERE n.end_user_id = $end_user_id AND m.name="用户"
return DISTINCT elementId(m) as id
"""
Memory_Space_Entity="""
MATCH (n)-[]-(m)
WHERE elementId(m) = $id AND m.entity_type = "Person"
RETURN
DISTINCT m.name as name,m.group_id as group_id
DISTINCT m.name as name,m.end_user_id as end_user_id
"""
Memory_Space_Associative="""
MATCH (u)-[]-(x)-[]-(h)

View File

@@ -19,7 +19,7 @@ class DialogRepository(BaseNeo4jRepository[DialogueNode]):
"""对话仓储
管理对话节点的创建、查询、更新和删除操作。
提供按group_id、user_id、ref_id等条件查询对话的方法。
提供按end_user_id、user_id、ref_id等条件查询对话的方法。
Attributes:
connector: Neo4j连接器实例
@@ -54,17 +54,17 @@ class DialogRepository(BaseNeo4jRepository[DialogueNode]):
return DialogueNode(**n)
async def find_by_group_id(self, group_id: str, limit: int = 100) -> List[DialogueNode]:
"""根据group_id查询对话
async def find_by_end_user_id(self, end_user_id: str, limit: int = 100) -> List[DialogueNode]:
"""根据end_user_id查询对话
Args:
group_id: 组ID
end_user_id: 组ID
limit: 返回结果的最大数量
Returns:
List[DialogueNode]: 对话列表
"""
return await self.find({"group_id": group_id}, limit=limit)
return await self.find({"end_user_id": end_user_id}, limit=limit)
async def find_by_user_id(self, user_id: str, limit: int = 100) -> List[DialogueNode]:
"""根据user_id查询对话
@@ -94,14 +94,14 @@ class DialogRepository(BaseNeo4jRepository[DialogueNode]):
async def find_by_group_and_user(
self,
group_id: str,
end_user_id: str,
user_id: str,
limit: int = 100
) -> List[DialogueNode]:
"""根据group_id和user_id查询对话
"""根据end_user_id和user_id查询对话
Args:
group_id: 组ID
end_user_id: 组ID
user_id: 用户ID
limit: 返回结果的最大数量
@@ -109,20 +109,20 @@ class DialogRepository(BaseNeo4jRepository[DialogueNode]):
List[DialogueNode]: 对话列表
"""
return await self.find(
{"group_id": group_id, "user_id": user_id},
{"end_user_id": end_user_id, "user_id": user_id},
limit=limit
)
async def find_recent_dialogs(
self,
group_id: str,
end_user_id: str,
days: int = 7,
limit: int = 100
) -> List[DialogueNode]:
"""查询最近的对话
Args:
group_id: 组ID
end_user_id: 组ID
days: 查询最近多少天的对话
limit: 返回结果的最大数量
@@ -131,7 +131,7 @@ class DialogRepository(BaseNeo4jRepository[DialogueNode]):
"""
query = f"""
MATCH (n:{self.node_label})
WHERE n.group_id = $group_id
WHERE n.end_user_id = $end_user_id
AND n.created_at >= datetime() - duration({{days: $days}})
RETURN n
ORDER BY n.created_at DESC
@@ -139,7 +139,7 @@ class DialogRepository(BaseNeo4jRepository[DialogueNode]):
"""
results = await self.connector.execute_query(
query,
group_id=group_id,
end_user_id=end_user_id,
days=days,
limit=limit
)
@@ -164,16 +164,16 @@ class DialogRepository(BaseNeo4jRepository[DialogueNode]):
async def find_by_config_and_group(
self,
config_id: str,
group_id: str,
end_user_id: str,
limit: int = 100
) -> List[DialogueNode]:
"""根据config_id和group_id查询对话
"""根据config_id和end_user_id查询对话
支持按配置ID和组ID同时过滤,确保只返回使用特定配置处理的对话。
Args:
config_id: 配置ID
group_id: 组ID
end_user_id: 组ID
limit: 返回结果的最大数量
Returns:

View File

@@ -40,7 +40,7 @@ class EmotionRepository:
async def get_emotion_tags(
self,
group_id: str,
end_user_id: str,
emotion_type: Optional[str] = None,
start_date: Optional[str] = None,
end_date: Optional[str] = None,
@@ -51,7 +51,7 @@ class EmotionRepository:
查询指定用户的情绪类型分布,包括计数、百分比和平均强度。
Args:
group_id: 用户组ID宿主ID
end_user_id: 用户组ID宿主ID
emotion_type: 可选的情绪类型过滤joy/sadness/anger/fear/surprise/neutral
start_date: 可选的开始日期ISO格式字符串
end_date: 可选的结束日期ISO格式字符串
@@ -65,8 +65,8 @@ class EmotionRepository:
- avg_intensity: 平均强度
"""
# 构建查询条件
where_clauses = ["s.group_id = $group_id", "s.emotion_type IS NOT NULL"]
params = {"group_id": group_id, "limit": limit}
where_clauses = ["s.end_user_id = $end_user_id", "s.emotion_type IS NOT NULL"]
params = {"end_user_id": end_user_id, "limit": limit}
if emotion_type:
where_clauses.append("s.emotion_type = $emotion_type")
@@ -119,7 +119,7 @@ class EmotionRepository:
async def get_emotion_wordcloud(
self,
group_id: str,
end_user_id: str,
emotion_type: Optional[str] = None,
limit: int = 50
) -> List[Dict[str, Any]]:
@@ -128,7 +128,7 @@ class EmotionRepository:
查询情绪关键词及其频率,用于生成词云可视化。
Args:
group_id: 用户组ID宿主ID
end_user_id: 用户组ID宿主ID
emotion_type: 可选的情绪类型过滤
limit: 返回关键词的最大数量
@@ -140,8 +140,8 @@ class EmotionRepository:
- avg_intensity: 平均强度
"""
# 构建查询条件
where_clauses = ["s.group_id = $group_id", "s.emotion_keywords IS NOT NULL"]
params = {"group_id": group_id, "limit": limit}
where_clauses = ["s.end_user_id = $end_user_id", "s.emotion_keywords IS NOT NULL"]
params = {"end_user_id": end_user_id, "limit": limit}
if emotion_type:
where_clauses.append("s.emotion_type = $emotion_type")
@@ -186,7 +186,7 @@ class EmotionRepository:
async def get_emotions_in_range(
self,
group_id: str,
end_user_id: str,
time_range: str = "30d"
) -> List[Dict[str, Any]]:
"""获取时间范围内的情绪数据
@@ -194,7 +194,7 @@ class EmotionRepository:
查询指定时间范围内的所有情绪数据,用于健康指数计算。
Args:
group_id: 用户组ID宿主ID
end_user_id: 用户组ID宿主ID
time_range: 时间范围7d/30d/90d
Returns:
@@ -214,7 +214,7 @@ class EmotionRepository:
# 优化的 Cypher 查询:使用字符串比较避免时区问题
query = """
MATCH (s:Statement)
WHERE s.group_id = $group_id
WHERE s.end_user_id = $end_user_id
AND s.emotion_type IS NOT NULL
AND s.created_at >= $start_date
RETURN s.id as statement_id,

View File

@@ -44,9 +44,7 @@ async def save_entities_and_relationships(
'created_at': edge.created_at.isoformat(),
'expired_at': edge.expired_at.isoformat(),
'run_id': edge.run_id,
'group_id': edge.group_id,
'user_id': edge.user_id,
'apply_id': edge.apply_id,
'end_user_id': edge.end_user_id,
}
all_relationships.append(relationship)
@@ -101,9 +99,7 @@ async def save_statement_chunk_edges(
"id": edge.id,
"source": edge.source,
"target": edge.target,
"group_id": edge.group_id,
"user_id": edge.user_id,
"apply_id": edge.apply_id,
"end_user_id": edge.end_user_id,
"run_id": edge.run_id,
"created_at": edge.created_at.isoformat() if edge.created_at else None,
"expired_at": edge.expired_at.isoformat() if edge.expired_at else None,
@@ -132,9 +128,7 @@ async def save_statement_entity_edges(
edge_data = {
"source": edge.source,
"target": edge.target,
"group_id": edge.group_id,
"user_id": edge.user_id,
"apply_id": edge.apply_id,
"end_user_id": edge.end_user_id,
"run_id": edge.run_id,
"connect_strength": edge.connect_strength,
"created_at": edge.created_at.isoformat() if edge.created_at else None,

View File

@@ -33,7 +33,7 @@ async def _update_activation_values_batch(
connector: Neo4jConnector,
nodes: List[Dict[str, Any]],
node_label: str,
group_id: Optional[str] = None,
end_user_id: Optional[str] = None,
max_retries: int = 3
) -> List[Dict[str, Any]]:
"""
@@ -46,7 +46,7 @@ async def _update_activation_values_batch(
connector: Neo4j连接器
nodes: 节点列表,每个节点必须包含 'id' 字段
node_label: 节点标签Statement, ExtractedEntity, MemorySummary
group_id: 组ID可选
end_user_id: 组ID可选
max_retries: 最大重试次数
Returns:
@@ -97,7 +97,7 @@ async def _update_activation_values_batch(
updated_nodes = await access_manager.record_batch_access(
node_ids=unique_node_ids,
node_label=node_label,
group_id=group_id
end_user_id=end_user_id
)
logger.info(
@@ -118,7 +118,7 @@ async def _update_activation_values_batch(
async def _update_search_results_activation(
connector: Neo4jConnector,
results: Dict[str, List[Dict[str, Any]]],
group_id: Optional[str] = None
end_user_id: Optional[str] = None
) -> Dict[str, List[Dict[str, Any]]]:
"""
更新搜索结果中所有知识节点的激活值
@@ -129,7 +129,7 @@ async def _update_search_results_activation(
Args:
connector: Neo4j连接器
results: 搜索结果字典,包含不同类型节点的列表
group_id: 组ID可选
end_user_id: 组ID可选
Returns:
Dict[str, List[Dict[str, Any]]]: 更新后的搜索结果
@@ -152,7 +152,7 @@ async def _update_search_results_activation(
connector=connector,
nodes=results[key],
node_label=label,
group_id=group_id
end_user_id=end_user_id
)
)
update_keys.append(key)
@@ -218,7 +218,7 @@ async def _update_search_results_activation(
async def search_graph(
connector: Neo4jConnector,
q: str,
group_id: Optional[str] = None,
end_user_id: Optional[str] = None,
limit: int = 50,
include: List[str] = None,
) -> Dict[str, List[Dict[str, Any]]]:
@@ -236,7 +236,7 @@ async def search_graph(
Args:
connector: Neo4j connector
q: Query text
group_id: Optional group filter
end_user_id: Optional group filter
limit: Max results per category
include: List of categories to search (default: all)
@@ -254,7 +254,7 @@ async def search_graph(
tasks.append(connector.execute_query(
SEARCH_STATEMENTS_BY_KEYWORD,
q=q,
group_id=group_id,
end_user_id=end_user_id,
limit=limit,
))
task_keys.append("statements")
@@ -263,7 +263,7 @@ async def search_graph(
tasks.append(connector.execute_query(
SEARCH_ENTITIES_BY_NAME,
q=q,
group_id=group_id,
end_user_id=end_user_id,
limit=limit,
))
task_keys.append("entities")
@@ -272,7 +272,7 @@ async def search_graph(
tasks.append(connector.execute_query(
SEARCH_CHUNKS_BY_CONTENT,
q=q,
group_id=group_id,
end_user_id=end_user_id,
limit=limit,
))
task_keys.append("chunks")
@@ -281,7 +281,7 @@ async def search_graph(
tasks.append(connector.execute_query(
SEARCH_MEMORY_SUMMARIES_BY_KEYWORD,
q=q,
group_id=group_id,
end_user_id=end_user_id,
limit=limit,
))
task_keys.append("summaries")
@@ -308,7 +308,7 @@ async def search_graph(
results = await _update_search_results_activation(
connector=connector,
results=results,
group_id=group_id
end_user_id=end_user_id
)
return results
@@ -318,7 +318,7 @@ async def search_graph_by_embedding(
connector: Neo4jConnector,
embedder_client,
query_text: str,
group_id: Optional[str] = None,
end_user_id: Optional[str] = None,
limit: int = 50,
include: List[str] = ["statements", "chunks", "entities","summaries"],
) -> Dict[str, List[Dict[str, Any]]]:
@@ -330,7 +330,7 @@ async def search_graph_by_embedding(
- Computes query embedding with the provided embedder_client
- Ranks by cosine similarity in Cypher
- Filters by group_id if provided
- Filters by end_user_id if provided
- Returns up to 'limit' per included type
"""
import time
@@ -354,7 +354,7 @@ async def search_graph_by_embedding(
tasks.append(connector.execute_query(
STATEMENT_EMBEDDING_SEARCH,
embedding=embedding,
group_id=group_id,
end_user_id=end_user_id,
limit=limit,
))
task_keys.append("statements")
@@ -364,7 +364,7 @@ async def search_graph_by_embedding(
tasks.append(connector.execute_query(
CHUNK_EMBEDDING_SEARCH,
embedding=embedding,
group_id=group_id,
end_user_id=end_user_id,
limit=limit,
))
task_keys.append("chunks")
@@ -374,7 +374,7 @@ async def search_graph_by_embedding(
tasks.append(connector.execute_query(
ENTITY_EMBEDDING_SEARCH,
embedding=embedding,
group_id=group_id,
end_user_id=end_user_id,
limit=limit,
))
task_keys.append("entities")
@@ -384,7 +384,7 @@ async def search_graph_by_embedding(
tasks.append(connector.execute_query(
MEMORY_SUMMARY_EMBEDDING_SEARCH,
embedding=embedding,
group_id=group_id,
end_user_id=end_user_id,
limit=limit,
))
task_keys.append("summaries")
@@ -421,7 +421,7 @@ async def search_graph_by_embedding(
results = await _update_search_results_activation(
connector=connector,
results=results,
group_id=group_id
end_user_id=end_user_id
)
update_time = time.time() - update_start
print(f"[PERF] Activation value updates took: {update_time:.4f}s")
@@ -429,7 +429,7 @@ async def search_graph_by_embedding(
return results
async def get_dedup_candidates_for_entities( # 适配新版查询:使用全文索引按名称检索候选实体
connector: Neo4jConnector,
group_id: str,
end_user_id: str,
entities: List[Dict[str, Any]],
use_contains_fallback: bool = True,
batch_size: int = 500,
@@ -437,7 +437,7 @@ async def get_dedup_candidates_for_entities( # 适配新版查询:使用全
) -> Dict[str, List[Dict[str, Any]]]:
"""
为第二层去重消歧批量检索候选实体(适配新版 cypher_queries
- 使用全文索引查询 `SEARCH_ENTITIES_BY_NAME` 按 (group_id, name) 检索候选;
- 使用全文索引查询 `SEARCH_ENTITIES_BY_NAME` 按 (end_user_id, name) 检索候选;
- 保留并发控制与返回结构incoming_id -> [db_entity_props...]
- 若提供 `entity_type`,在本地对返回结果做类型过滤;
- `use_contains_fallback` 保留形参以兼容,必要时可扩展二次查询策略。
@@ -461,7 +461,7 @@ async def get_dedup_candidates_for_entities( # 适配新版查询:使用全
rows = await connector.execute_query(
SEARCH_ENTITIES_BY_NAME,
q=name,
group_id=group_id,
end_user_id=end_user_id,
limit=100,
)
except Exception:
@@ -485,7 +485,7 @@ async def get_dedup_candidates_for_entities( # 适配新版查询:使用全
rows = await connector.execute_query(
SEARCH_ENTITIES_BY_NAME,
q=name.lower(),
group_id=group_id,
end_user_id=end_user_id,
limit=100,
)
for r in rows:
@@ -516,9 +516,7 @@ async def get_dedup_candidates_for_entities( # 适配新版查询:使用全
async def search_graph_by_keyword_temporal(
connector: Neo4jConnector,
query_text: str,
group_id: Optional[str] = None,
apply_id: Optional[str] = None,
user_id: Optional[str] = None,
end_user_id: Optional[str] = None,
start_date: Optional[str] = None,
end_date: Optional[str] = None,
valid_date: Optional[str] = None,
@@ -531,7 +529,7 @@ async def search_graph_by_keyword_temporal(
INTEGRATED: Updates activation values for Statement nodes before returning results
- Matches statements containing query_text created between start_date and end_date
- Optionally filters by group_id, apply_id, user_id
- Optionally filters by end_user_id, apply_id, user_id
- Returns up to 'limit' statements
"""
if not query_text:
@@ -540,9 +538,7 @@ async def search_graph_by_keyword_temporal(
statements = await connector.execute_query(
SEARCH_STATEMENTS_BY_KEYWORD_TEMPORAL,
q=query_text,
group_id=group_id,
apply_id=apply_id,
user_id=user_id,
end_user_id=end_user_id,
start_date=start_date,
end_date=end_date,
valid_date=valid_date,
@@ -556,7 +552,7 @@ async def search_graph_by_keyword_temporal(
results = await _update_search_results_activation(
connector=connector,
results=results,
group_id=group_id
end_user_id=end_user_id
)
return results
@@ -564,9 +560,7 @@ async def search_graph_by_keyword_temporal(
async def search_graph_by_temporal(
connector: Neo4jConnector,
group_id: Optional[str] = None,
apply_id: Optional[str] = None,
user_id: Optional[str] = None,
end_user_id: Optional[str] = None,
start_date: Optional[str] = None,
end_date: Optional[str] = None,
valid_date: Optional[str] = None,
@@ -579,14 +573,12 @@ async def search_graph_by_temporal(
INTEGRATED: Updates activation values for Statement nodes before returning results
- Matches statements created between start_date and end_date
- Optionally filters by group_id, apply_id, user_id
- Optionally filters by end_user_id
- Returns up to 'limit' statements
"""
statements = await connector.execute_query(
SEARCH_STATEMENTS_BY_TEMPORAL,
group_id=group_id,
apply_id=apply_id,
user_id=user_id,
end_user_id=end_user_id,
start_date=start_date,
end_date=end_date,
valid_date=valid_date,
@@ -595,7 +587,7 @@ async def search_graph_by_temporal(
)
print(f"查询语句为:\n{SEARCH_STATEMENTS_BY_TEMPORAL}")
print(f"查询参数为:\n{{group_id: {group_id}, apply_id: {apply_id}, user_id: {user_id}, start_date: {start_date}, end_date: {end_date}, valid_date: {valid_date}, invalid_date: {invalid_date}, limit: {limit}}}")
print(f"查询参数为:\n{{end_user_id: {end_user_id}, start_date: {start_date}, end_date: {end_date}, valid_date: {valid_date}, invalid_date: {invalid_date}, limit: {limit}}}")
print(f"查询结果为:\n{statements}")
# 更新 Statement 节点的激活值
@@ -603,7 +595,7 @@ async def search_graph_by_temporal(
results = await _update_search_results_activation(
connector=connector,
results=results,
group_id=group_id
end_user_id=end_user_id
)
return results
@@ -612,14 +604,14 @@ async def search_graph_by_temporal(
async def search_graph_by_dialog_id(
connector: Neo4jConnector,
dialog_id: str,
group_id: Optional[str] = None,
end_user_id: Optional[str] = None,
limit: int = 1,
) -> Dict[str, List[Dict[str, Any]]]:
"""
Temporal search across Dialogues.
- Matches dialogues with dialog_id
- Optionally filters by group_id
- Optionally filters by end_user_id
- Returns up to 'limit' dialogues
"""
if not dialog_id:
@@ -628,7 +620,7 @@ async def search_graph_by_dialog_id(
dialogues = await connector.execute_query(
SEARCH_DIALOGUE_BY_DIALOG_ID,
group_id=group_id,
end_user_id=end_user_id,
dialog_id=dialog_id,
limit=limit,
)
@@ -638,7 +630,7 @@ async def search_graph_by_dialog_id(
async def search_graph_by_chunk_id(
connector: Neo4jConnector,
chunk_id : str,
group_id: Optional[str] = None,
end_user_id: Optional[str] = None,
limit: int = 1,
) -> Dict[str, List[Dict[str, Any]]]:
if not chunk_id:
@@ -646,7 +638,7 @@ async def search_graph_by_chunk_id(
return {"chunks": []}
chunks = await connector.execute_query(
SEARCH_CHUNK_BY_CHUNK_ID,
group_id=group_id,
end_user_id=end_user_id,
chunk_id=chunk_id,
limit=limit,
)
@@ -655,9 +647,9 @@ async def search_graph_by_chunk_id(
async def search_graph_by_created_at(
connector: Neo4jConnector,
group_id: Optional[str] = None,
apply_id: Optional[str] = None,
user_id: Optional[str] = None,
end_user_id: Optional[str] = None,
created_at: Optional[str] = None,
limit: int = 1,
) -> Dict[str, List[Dict[str, Any]]]:
@@ -667,20 +659,20 @@ async def search_graph_by_created_at(
INTEGRATED: Updates activation values for Statement nodes before returning results
- Matches statements created at created_at
- Optionally filters by group_id, apply_id, user_id
- Optionally filters by end_user_id, apply_id, user_id
- Returns up to 'limit' statements
"""
statements = await connector.execute_query(
SEARCH_STATEMENTS_BY_CREATED_AT,
group_id=group_id,
apply_id=apply_id,
user_id=user_id,
end_user_id=end_user_id,
created_at=created_at,
limit=limit,
)
print(f"查询语句为:\n{SEARCH_STATEMENTS_BY_CREATED_AT}")
print(f"查询参数为:\n{{group_id: {group_id}, apply_id: {apply_id}, user_id: {user_id}, created_at: {created_at}, limit: {limit}}}")
print(f"查询参数为:\n{{end_user_id: {end_user_id} created_at: {created_at}, limit: {limit}}}")
print(f"查询结果为:\n{statements}")
# 更新 Statement 节点的激活值
@@ -688,16 +680,16 @@ async def search_graph_by_created_at(
results = await _update_search_results_activation(
connector=connector,
results=results,
group_id=group_id
end_user_id=end_user_id
)
return results
async def search_graph_by_valid_at(
connector: Neo4jConnector,
group_id: Optional[str] = None,
apply_id: Optional[str] = None,
user_id: Optional[str] = None,
end_user_id: Optional[str] = None,
valid_at: Optional[str] = None,
limit: int = 1,
) -> Dict[str, List[Dict[str, Any]]]:
@@ -707,20 +699,20 @@ async def search_graph_by_valid_at(
INTEGRATED: Updates activation values for Statement nodes before returning results
- Matches statements valid at valid_at
- Optionally filters by group_id, apply_id, user_id
- Optionally filters by end_user_id, apply_id, user_id
- Returns up to 'limit' statements
"""
statements = await connector.execute_query(
SEARCH_STATEMENTS_BY_VALID_AT,
group_id=group_id,
apply_id=apply_id,
user_id=user_id,
end_user_id=end_user_id,
valid_at=valid_at,
limit=limit,
)
print(f"查询语句为:\n{SEARCH_STATEMENTS_BY_VALID_AT}")
print(f"查询参数为:\n{{group_id: {group_id}, apply_id: {apply_id}, user_id: {user_id}, valid_at: {valid_at}, limit: {limit}}}")
print(f"查询参数为:\n{{end_user_id: {end_user_id} valid_at: {valid_at}, limit: {limit}}}")
print(f"查询结果为:\n{statements}")
# 更新 Statement 节点的激活值
@@ -728,16 +720,16 @@ async def search_graph_by_valid_at(
results = await _update_search_results_activation(
connector=connector,
results=results,
group_id=group_id
end_user_id=end_user_id
)
return results
async def search_graph_g_created_at(
connector: Neo4jConnector,
group_id: Optional[str] = None,
apply_id: Optional[str] = None,
user_id: Optional[str] = None,
end_user_id: Optional[str] = None,
created_at: Optional[str] = None,
limit: int = 1,
) -> Dict[str, List[Dict[str, Any]]]:
@@ -747,20 +739,20 @@ async def search_graph_g_created_at(
INTEGRATED: Updates activation values for Statement nodes before returning results
- Matches statements created at created_at
- Optionally filters by group_id, apply_id, user_id
- Optionally filters by end_user_id, apply_id, user_id
- Returns up to 'limit' statements
"""
statements = await connector.execute_query(
SEARCH_STATEMENTS_G_CREATED_AT,
group_id=group_id,
apply_id=apply_id,
user_id=user_id,
end_user_id=end_user_id,
created_at=created_at,
limit=limit,
)
print(f"查询语句为:\n{SEARCH_STATEMENTS_G_CREATED_AT}")
print(f"查询参数为:\n{{group_id: {group_id}, apply_id: {apply_id}, user_id: {user_id}, created_at: {created_at}, limit: {limit}}}")
print(f"查询参数为:\n{{end_user_id: {end_user_id}, created_at: {created_at}, limit: {limit}}}")
print(f"查询结果为:\n{statements}")
# 更新 Statement 节点的激活值
@@ -768,16 +760,16 @@ async def search_graph_g_created_at(
results = await _update_search_results_activation(
connector=connector,
results=results,
group_id=group_id
end_user_id=end_user_id
)
return results
async def search_graph_g_valid_at(
connector: Neo4jConnector,
group_id: Optional[str] = None,
apply_id: Optional[str] = None,
user_id: Optional[str] = None,
end_user_id: Optional[str] = None,
valid_at: Optional[str] = None,
limit: int = 1,
) -> Dict[str, List[Dict[str, Any]]]:
@@ -787,20 +779,20 @@ async def search_graph_g_valid_at(
INTEGRATED: Updates activation values for Statement nodes before returning results
- Matches statements valid at valid_at
- Optionally filters by group_id, apply_id, user_id
- Optionally filters by end_user_id, apply_id, user_id
- Returns up to 'limit' statements
"""
statements = await connector.execute_query(
SEARCH_STATEMENTS_G_VALID_AT,
group_id=group_id,
apply_id=apply_id,
user_id=user_id,
end_user_id=end_user_id,
valid_at=valid_at,
limit=limit,
)
print(f"查询语句为:\n{SEARCH_STATEMENTS_G_VALID_AT}")
print(f"查询参数为:\n{{group_id: {group_id}, apply_id: {apply_id}, user_id: {user_id}, valid_at: {valid_at}, limit: {limit}}}")
print(f"查询参数为:\n{{end_user_id: {end_user_id}, valid_at: {valid_at}, limit: {limit}}}")
print(f"查询结果为:\n{statements}")
# 更新 Statement 节点的激活值
@@ -808,16 +800,16 @@ async def search_graph_g_valid_at(
results = await _update_search_results_activation(
connector=connector,
results=results,
group_id=group_id
end_user_id=end_user_id
)
return results
async def search_graph_l_created_at(
connector: Neo4jConnector,
group_id: Optional[str] = None,
apply_id: Optional[str] = None,
user_id: Optional[str] = None,
end_user_id: Optional[str] = None,
created_at: Optional[str] = None,
limit: int = 1,
) -> Dict[str, List[Dict[str, Any]]]:
@@ -827,20 +819,20 @@ async def search_graph_l_created_at(
INTEGRATED: Updates activation values for Statement nodes before returning results
- Matches statements created at created_at
- Optionally filters by group_id, apply_id, user_id
- Optionally filters by end_user_id, apply_id, user_id
- Returns up to 'limit' statements
"""
statements = await connector.execute_query(
SEARCH_STATEMENTS_L_CREATED_AT,
group_id=group_id,
apply_id=apply_id,
user_id=user_id,
end_user_id=end_user_id,
created_at=created_at,
limit=limit,
)
print(f"查询语句为:\n{SEARCH_STATEMENTS_L_CREATED_AT}")
print(f"查询参数为:\n{{group_id: {group_id}, apply_id: {apply_id}, user_id: {user_id}, created_at: {created_at}, limit: {limit}}}")
print(f"查询参数为:\n{{end_user_id: {end_user_id}, created_at: {created_at}, limit: {limit}}}")
print(f"查询结果为:\n{statements}")
# 更新 Statement 节点的激活值
@@ -848,16 +840,16 @@ async def search_graph_l_created_at(
results = await _update_search_results_activation(
connector=connector,
results=results,
group_id=group_id
end_user_id=end_user_id
)
return results
async def search_graph_l_valid_at(
connector: Neo4jConnector,
group_id: Optional[str] = None,
apply_id: Optional[str] = None,
user_id: Optional[str] = None,
end_user_id: Optional[str] = None,
valid_at: Optional[str] = None,
limit: int = 1,
) -> Dict[str, List[Dict[str, Any]]]:
@@ -867,20 +859,20 @@ async def search_graph_l_valid_at(
INTEGRATED: Updates activation values for Statement nodes before returning results
- Matches statements valid at valid_at
- Optionally filters by group_id, apply_id, user_id
- Optionally filters by end_user_id, apply_id, user_id
- Returns up to 'limit' statements
"""
statements = await connector.execute_query(
SEARCH_STATEMENTS_L_VALID_AT,
group_id=group_id,
apply_id=apply_id,
user_id=user_id,
end_user_id=end_user_id,
valid_at=valid_at,
limit=limit,
)
print(f"查询语句为:\n{SEARCH_STATEMENTS_L_VALID_AT}")
print(f"查询参数为:\n{{group_id: {group_id}, apply_id: {apply_id}, user_id: {user_id}, valid_at: {valid_at}, limit: {limit}}}")
print(f"查询参数为:\n{{end_user_id: {end_user_id}, valid_at: {valid_at}, limit: {limit}}}")
print(f"查询结果为:\n{statements}")
# 更新 Statement 节点的激活值
@@ -888,7 +880,7 @@ async def search_graph_l_valid_at(
results = await _update_search_results_activation(
connector=connector,
results=results,
group_id=group_id
end_user_id=end_user_id
)
return results

View File

@@ -18,7 +18,7 @@ class MemorySummaryRepository(BaseNeo4jRepository):
"""Memory Summary Repository
Manages CRUD operations for MemorySummary nodes.
Provides methods to query summaries by group_id, user_id, and time ranges.
Provides methods to query summaries by end_user_id, user_id, and time ranges.
Attributes:
connector: Neo4j connector instance
@@ -51,17 +51,17 @@ class MemorySummaryRepository(BaseNeo4jRepository):
return dict(n)
async def find_by_group_id(
async def find_by_end_user_id(
self,
group_id: str,
end_user_id: str,
limit: int = 1000,
start_date: Optional[datetime] = None,
end_date: Optional[datetime] = None
) -> List[Dict[str, Any]]:
"""Query memory summaries by group_id
"""Query memory summaries by end_user_id
Args:
group_id: Group ID to filter by
end_user_id: Group ID to filter by
limit: Maximum number of results to return
start_date: Optional start date filter
end_date: Optional end date filter
@@ -71,10 +71,10 @@ class MemorySummaryRepository(BaseNeo4jRepository):
"""
query = f"""
MATCH (n:{self.node_label})
WHERE n.group_id = $group_id
WHERE n.end_user_id = $end_user_id
"""
params = {"group_id": group_id, "limit": limit}
params = {"end_user_id": end_user_id, "limit": limit}
# Add date range filters if provided
if start_date:
@@ -139,16 +139,16 @@ class MemorySummaryRepository(BaseNeo4jRepository):
async def find_by_group_and_user(
self,
group_id: str,
end_user_id: str,
user_id: str,
limit: int = 1000,
start_date: Optional[datetime] = None,
end_date: Optional[datetime] = None
) -> List[Dict[str, Any]]:
"""Query memory summaries by both group_id and user_id
"""Query memory summaries by both end_user_id and user_id
Args:
group_id: Group ID to filter by
end_user_id: Group ID to filter by
user_id: User ID to filter by
limit: Maximum number of results to return
start_date: Optional start date filter
@@ -159,10 +159,10 @@ class MemorySummaryRepository(BaseNeo4jRepository):
"""
query = f"""
MATCH (n:{self.node_label})
WHERE n.group_id = $group_id AND n.user_id = $user_id
WHERE n.end_user_id = $end_user_id AND n.user_id = $user_id
"""
params = {"group_id": group_id, "user_id": user_id, "limit": limit}
params = {"end_user_id": end_user_id, "user_id": user_id, "limit": limit}
# Add date range filters if provided
if start_date:
@@ -184,14 +184,14 @@ class MemorySummaryRepository(BaseNeo4jRepository):
async def find_recent_summaries(
self,
group_id: str,
end_user_id: str,
days: int = 7,
limit: int = 1000
) -> List[Dict[str, Any]]:
"""Query recent memory summaries
Args:
group_id: Group ID to filter by
end_user_id: Group ID to filter by
days: Number of recent days to query
limit: Maximum number of results to return
@@ -200,7 +200,7 @@ class MemorySummaryRepository(BaseNeo4jRepository):
"""
query = f"""
MATCH (n:{self.node_label})
WHERE n.group_id = $group_id
WHERE n.end_user_id = $end_user_id
AND n.created_at >= datetime() - duration({{days: $days}})
RETURN n
ORDER BY n.created_at DESC

View File

@@ -141,14 +141,14 @@ class Neo4jConnector:
async with self.driver.session(database="neo4j") as session:
return await session.execute_read(transaction_func, **kwargs)
async def delete_group(self, group_id: str):
async def delete_group(self, end_user_id: str):
"""删除指定组的所有数据
删除所有属于指定group_id的节点和边。
删除所有属于指定end_user_id的节点和边。
这是一个危险操作,会永久删除数据。
Args:
group_id: 要删除的组ID
end_user_id: 要删除的组ID
Example:
>>> connector = Neo4jConnector()
@@ -157,14 +157,14 @@ class Neo4jConnector:
"""
# 删除节点DETACH DELETE会同时删除相关的边
await self.driver.execute_query(
"MATCH (n) WHERE n.group_id = $group_id DETACH DELETE n",
"MATCH (n) WHERE n.end_user_id = $end_user_id DETACH DELETE n",
database="neo4j",
group_id=group_id
end_user_id=end_user_id
)
# 删除独立的边(如果有的话)
await self.driver.execute_query(
"MATCH ()-[r]->() WHERE r.group_id = $group_id DELETE r",
"MATCH ()-[r]->() WHERE r.end_user_id = $end_user_id DELETE r",
database="neo4j",
group_id=group_id
end_user_id=end_user_id
)
print(f"Group {group_id} deleted.")
print(f"Group {end_user_id} deleted.")

View File

@@ -20,7 +20,7 @@ class StatementRepository(BaseNeo4jRepository[StatementNode]):
"""陈述句仓储
管理陈述句节点的创建、查询、更新和删除操作。
提供按chunk_id、group_id、向量相似度等条件查询陈述句的方法。
提供按chunk_id、end_user_id、向量相似度等条件查询陈述句的方法。
Attributes:
connector: Neo4j连接器实例

View File

@@ -7,15 +7,11 @@ class UserInput(BaseModel):
message: str
history: list[dict]
search_switch: str
group_id: str
end_user_id: str
config_id: Optional[str] = None
class Write_UserInput(BaseModel):
message: str
group_id: str
messages: list[dict]
end_user_id: str
config_id: Optional[str] = None
class End_User_Information(BaseModel):
end_user_name: str # 这是要更新的用户名
id: str # 宿主ID用于匹配条件

View File

@@ -10,11 +10,6 @@ import time
import uuid
from typing import Any, AsyncGenerator, Dict, List, Optional
from langchain.tools import tool
from pydantic import BaseModel, Field
from sqlalchemy import select
from sqlalchemy.orm import Session
from app.celery_app import celery_app
from app.core.error_codes import BizCode
from app.core.exceptions import BusinessException
@@ -28,6 +23,10 @@ from app.services.langchain_tool_server import Search
from app.services.memory_agent_service import MemoryAgentService
from app.services.model_parameter_merger import ModelParameterMerger
from app.services.tool_service import ToolService
from langchain.tools import tool
from pydantic import BaseModel, Field
from sqlalchemy import select
from sqlalchemy.orm import Session
logger = get_business_logger()
class KnowledgeRetrievalInput(BaseModel):
@@ -93,7 +92,7 @@ def create_long_term_memory_tool(memory_config: Dict[str, Any], end_user_id: str
try:
memory_content = asyncio.run(
MemoryAgentService().read_memory(
group_id=end_user_id,
end_user_id=end_user_id,
message=question,
history=[],
search_switch="2",
@@ -107,9 +106,9 @@ def create_long_term_memory_tool(memory_config: Dict[str, Any], end_user_id: str
"app.core.memory.agent.read_message",
args=[end_user_id, question, [], "1", config_id, storage_type, user_rag_memory_id]
)
result = task_service.get_task_memory_read_result(task.id)
status = result.get("status")
logger.info(f"读取任务状态:{status}")
# result = task_service.get_task_memory_read_result(task.id)
# status = result.get("status")
# logger.info(f"读取任务状态:{status}")
finally:
db.close()

View File

@@ -75,7 +75,7 @@ class EmotionAnalyticsService:
# 调用仓储层查询
tags = await self.emotion_repo.get_emotion_tags(
group_id=end_user_id,
end_user_id=end_user_id,
emotion_type=emotion_type,
start_date=start_date,
end_date=end_date,
@@ -157,7 +157,7 @@ class EmotionAnalyticsService:
# 调用仓储层查询
keywords = await self.emotion_repo.get_emotion_wordcloud(
group_id=end_user_id,
end_user_id=end_user_id,
emotion_type=emotion_type,
limit=limit
)
@@ -339,7 +339,7 @@ class EmotionAnalyticsService:
# 获取时间范围内的情绪数据
emotions = await self.emotion_repo.get_emotions_in_range(
group_id=end_user_id,
end_user_id=end_user_id,
time_range=time_range
)
@@ -519,7 +519,7 @@ class EmotionAnalyticsService:
# 3. 获取情绪数据用于模式分析
emotions = await self.emotion_repo.get_emotions_in_range(
group_id=end_user_id,
end_user_id=end_user_id,
time_range="30d"
)
@@ -598,13 +598,13 @@ class EmotionAnalyticsService:
# 查询用户的实体和标签
query = """
MATCH (e:Entity)
WHERE e.group_id = $group_id
WHERE e.end_user_id = $end_user_id
RETURN e.name as name, e.type as type
ORDER BY e.created_at DESC
LIMIT 20
"""
entities = await connector.execute_query(query, group_id=end_user_id)
entities = await connector.execute_query(query, end_user_id=end_user_id)
# 提取兴趣标签
interests = [e["name"] for e in entities if e.get("type") in ["INTEREST", "HOBBY"]][:5]

View File

@@ -10,27 +10,34 @@ import re
import time
import uuid
from typing import Any, AsyncGenerator, Dict, List, Optional
import redis
from langchain_core.messages import HumanMessage
import redis
from app.core.config import settings
from app.core.logging_config import get_config_logger, get_logger
from app.core.memory.agent.langgraph_graph.read_graph import make_read_graph
from app.core.memory.agent.langgraph_graph.write_graph import make_write_graph
from app.core.memory.agent.logger_file.log_streamer import LogStreamer
from app.core.memory.agent.utils.messages_tools import merge_multiple_search_results, reorder_output_results
from app.core.memory.agent.utils.messages_tools import (
merge_multiple_search_results,
reorder_output_results,
)
from app.core.memory.agent.utils.type_classifier import status_typle
from app.core.memory.agent.utils.write_tools import write # 新增:直接导入 write 函数
from app.core.memory.analytics.hot_memory_tags import get_hot_memory_tags
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.db import get_db_context
from app.models.knowledge_model import Knowledge, KnowledgeType
from app.repositories.memory_short_repository import ShortTermMemoryRepository
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.schemas.memory_agent_schema import Write_UserInput
from app.schemas.memory_config_schema import ConfigurationError
from app.services.memory_base_service import Translation_English
from app.services.memory_config_service import MemoryConfigService
from app.services.memory_konwledges_server import (
write_rag,
)
from langchain_core.messages import AIMessage
from langchain_core.messages import HumanMessage
from pydantic import BaseModel, Field
from sqlalchemy import func
from sqlalchemy.orm import Session
@@ -49,25 +56,24 @@ _neo4j_connector = Neo4jConnector()
class MemoryAgentService:
"""Service for memory agent operations"""
def writer_messages_deal(self, messages, start_time, group_id, config_id, message, context):
def writer_messages_deal(self, messages, start_time, end_user_id, config_id, message, context):
duration = time.time() - start_time
if str(messages) == 'success':
logger.info(f"Write operation successful for group {group_id} with config_id {config_id}")
logger.info(f"Write operation successful for group {end_user_id} with config_id {config_id}")
# 记录成功的操作
if audit_logger:
audit_logger.log_operation(operation="WRITE", config_id=config_id, group_id=group_id, success=True,
audit_logger.log_operation(operation="WRITE", config_id=config_id, end_user_id=end_user_id, success=True,
duration=duration, details={"message_length": len(message)})
return context
else:
logger.warning(f"Write operation failed for group {group_id}")
logger.warning(f"Write operation failed for group {end_user_id}")
# 记录失败的操作
if audit_logger:
audit_logger.log_operation(
operation="WRITE",
config_id=config_id,
group_id=group_id,
end_user_id=end_user_id,
success=False,
duration=duration,
error=f"写入失败: {messages[:100]}"
@@ -260,12 +266,12 @@ class MemoryAgentService:
logger.info("Log streaming completed, cleaning up resources")
# LogStreamer uses context manager for file handling, so cleanup is automatic
async def write_memory(self, group_id: str, message: str, config_id: Optional[str], db: Session, storage_type: str, user_rag_memory_id: str) -> str:
async def write_memory(self, end_user_id: str, messages: str, config_id: Optional[str], db: Session, storage_type: str, user_rag_memory_id: str) -> str:
"""
Process write operation with config_id
Args:
group_id: Group identifier (also used as end_user_id)
end_user_id: Group identifier (also used as end_user_id)
message: Message to write
config_id: Configuration ID from database
db: SQLAlchemy database session
@@ -281,15 +287,15 @@ class MemoryAgentService:
# Resolve config_id if None using end_user's connected config
if config_id is None:
try:
connected_config = get_end_user_connected_config(group_id, db)
connected_config = get_end_user_connected_config(end_user_id, db)
config_id = connected_config.get("memory_config_id")
if config_id is None:
raise ValueError(f"No memory configuration found for end_user {group_id}. Please ensure the user has a connected memory configuration.")
raise ValueError(f"No memory configuration found for end_user {end_user_id}. Please ensure the user has a connected memory configuration.")
except Exception as e:
if "No memory configuration found" in str(e):
raise # Re-raise our specific error
logger.error(f"Failed to get connected config for end_user {group_id}: {e}")
raise ValueError(f"Unable to determine memory configuration for end_user {group_id}: {e}")
logger.error(f"Failed to get connected config for end_user {end_user_id}: {e}")
raise ValueError(f"Unable to determine memory configuration for end_user {end_user_id}: {e}")
import time
start_time = time.time()
@@ -309,56 +315,103 @@ class MemoryAgentService:
# Log failed operation
if audit_logger:
duration = time.time() - start_time
audit_logger.log_operation(operation="WRITE", config_id=config_id, group_id=group_id, success=False, duration=duration, error=error_msg)
audit_logger.log_operation(operation="WRITE", config_id=config_id, end_user_id=end_user_id, success=False, duration=duration, error=error_msg)
raise ValueError(error_msg)
try:
if storage_type == "rag":
result = await write_rag(group_id, message, user_rag_memory_id)
return result
else:
async with make_write_graph() as graph:
config = {"configurable": {"thread_id": group_id}}
# 初始状态 - 包含所有必要字段
initial_state = {"messages": [HumanMessage(content=message)], "group_id": group_id,
"memory_config": memory_config}
async with make_write_graph() as graph:
config = {"configurable": {"thread_id": end_user_id}}
# Convert structured messages to LangChain messages
langchain_messages = []
for msg in messages:
if msg['role'] == 'user':
langchain_messages.append(HumanMessage(content=msg['content']))
elif msg['role'] == 'assistant':
langchain_messages.append(AIMessage(content=msg['content']))
# 获取节点更新信息
async for update_event in graph.astream(
initial_state,
stream_mode="updates",
config=config
):
for node_name, node_data in update_event.items():
if 'save_neo4j' == node_name:
massages = node_data
massagesstatus = massages.get('write_result')['status']
contents = massages.get('write_result')
return self.writer_messages_deal(massagesstatus, start_time, group_id, config_id, message, contents)
except Exception as e:
# Ensure proper error handling and logging
error_msg = f"Write operation failed: {str(e)}"
logger.error(error_msg)
if audit_logger:
duration = time.time() - start_time
audit_logger.log_operation(operation="WRITE", config_id=config_id, group_id=group_id, success=False, duration=duration, error=error_msg)
raise ValueError(error_msg)
# 初始状态 - 包含所有必要字段
initial_state = {
"messages": langchain_messages,
"end_user_id": end_user_id,
"memory_config": memory_config
}
# 获取节点更新信息
async for update_event in graph.astream(
initial_state,
stream_mode="updates",
config=config
):
for node_name, node_data in update_event.items():
if 'save_neo4j' == node_name:
massages = node_data
print(massages)
massagesstatus = massages.get('write_result')['status']
contents = massages.get('write_result')
# Convert messages back to string for logging
message_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages])
return self.writer_messages_deal(massagesstatus, start_time, end_user_id, config_id, message_text, contents)
# try:
# if storage_type == "rag":
# # For RAG storage, convert messages to single string
# message_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages])
# result = await write_rag(end_user_id, message_text, user_rag_memory_id)
# return result
# else:
# async with make_write_graph() as graph:
# config = {"configurable": {"thread_id": end_user_id}}
# # Convert structured messages to LangChain messages
# langchain_messages = []
# for msg in messages:
# if msg['role'] == 'user':
# langchain_messages.append(HumanMessage(content=msg['content']))
# elif msg['role'] == 'assistant':
# langchain_messages.append(AIMessage(content=msg['content']))
#
# # 初始状态 - 包含所有必要字段
# initial_state = {
# "messages": langchain_messages,
# "end_user_id": end_user_id,
# "memory_config": memory_config
# }
#
# # 获取节点更新信息
# async for update_event in graph.astream(
# initial_state,
# stream_mode="updates",
# config=config
# ):
# for node_name, node_data in update_event.items():
# if 'save_neo4j' == node_name:
# massages = node_data
# massagesstatus = massages.get('write_result')['status']
# contents = massages.get('write_result')
# # Convert messages back to string for logging
# message_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages])
# return self.writer_messages_deal(massagesstatus, start_time, end_user_id, config_id, message_text, contents)
# except Exception as e:
# # Ensure proper error handling and logging
# error_msg = f"Write operation failed: {str(e)}"
# logger.error(error_msg)
# if audit_logger:
# duration = time.time() - start_time
# audit_logger.log_operation(operation="WRITE", config_id=config_id, end_user_id=end_user_id, success=False, duration=duration, error=error_msg)
# raise ValueError(error_msg)
async def read_memory(
self,
group_id: str,
end_user_id: str,
message: str,
history: List[Dict],
search_switch: str,
config_id: Optional[str],
db: Session,
storage_type: str,
user_rag_memory_id: str
) -> Dict:
user_rag_memory_id: str) -> Dict:
"""
Process read operation with config_id
@@ -368,7 +421,7 @@ class MemoryAgentService:
- "2": Direct answer based on context
Args:
group_id: Group identifier (also used as end_user_id)
end_user_id: Group identifier (also used as end_user_id)
message: User message
history: Conversation history
search_switch: Search mode switch
@@ -386,21 +439,22 @@ class MemoryAgentService:
import time
start_time = time.time()
ori_message= message
# Resolve config_id if None using end_user's connected config
if config_id is None:
try:
connected_config = get_end_user_connected_config(group_id, db)
connected_config = get_end_user_connected_config(end_user_id, db)
config_id = connected_config.get("memory_config_id")
if config_id is None:
raise ValueError(f"No memory configuration found for end_user {group_id}. Please ensure the user has a connected memory configuration.")
raise ValueError(f"No memory configuration found for end_user {end_user_id}. Please ensure the user has a connected memory configuration.")
except Exception as e:
if "No memory configuration found" in str(e):
raise # Re-raise our specific error
logger.error(f"Failed to get connected config for end_user {group_id}: {e}")
raise ValueError(f"Unable to determine memory configuration for end_user {group_id}: {e}")
logger.error(f"Failed to get connected config for end_user {end_user_id}: {e}")
raise ValueError(f"Unable to determine memory configuration for end_user {end_user_id}: {e}")
logger.info(f"Read operation for group {group_id} with config_id {config_id}")
logger.info(f"Read operation for group {end_user_id} with config_id {config_id}")
# 导入审计日志记录器
try:
@@ -426,7 +480,7 @@ class MemoryAgentService:
audit_logger.log_operation(
operation="READ",
config_id=config_id,
group_id=group_id,
end_user_id=end_user_id,
success=False,
duration=duration,
error=error_msg
@@ -436,15 +490,16 @@ class MemoryAgentService:
# Step 2: Prepare history
history.append({"role": "user", "content": message})
logger.debug(f"Group ID:{group_id}, Message:{message}, History:{history}, Config ID:{config_id}")
logger.debug(f"Group ID:{end_user_id}, Message:{message}, History:{history}, Config ID:{config_id}")
# Step 3: Initialize MCP client and execute read workflow
graph_exec_start = time.time()
try:
async with make_read_graph() as graph:
config = {"configurable": {"thread_id": group_id}}
config = {"configurable": {"thread_id": end_user_id}}
# 初始状态 - 包含所有必要字段
initial_state = {"messages": [HumanMessage(content=message)], "search_switch": search_switch,
"group_id": group_id
"end_user_id": end_user_id
, "storage_type": storage_type, "user_rag_memory_id": user_rag_memory_id,
"memory_config": memory_config}
# 获取节点更新信息
@@ -495,18 +550,72 @@ class MemoryAgentService:
if summary_n and summary_n != [] and summary_n != {}:
_intermediate_outputs.append(summary_n)
graph_exec_time = time.time() - graph_exec_start
logger.info(f"[PERF] Graph execution completed in {graph_exec_time:.4f}s")
_intermediate_outputs = [item for item in _intermediate_outputs if item and item != [] and item != {}]
optimized_outputs = merge_multiple_search_results(_intermediate_outputs)
result = reorder_output_results(optimized_outputs)
# 保存短期记忆到数据库
# 只有 search_switch 不为 "2"(快速检索)时才保存
try:
from app.repositories.memory_short_repository import ShortTermMemoryRepository
retrieved_content = []
repo = ShortTermMemoryRepository(db)
if str(search_switch) != "2":
for intermediate in _intermediate_outputs:
logger.debug(f"处理中间结果: {intermediate}")
intermediate_type = intermediate.get('type', '')
if intermediate_type == "search_result":
query = intermediate.get('query', '')
raw_results = intermediate.get('raw_results', {})
reranked_results = raw_results.get('reranked_results', [])
try:
statements = [statement['statement'] for statement in reranked_results.get('statements', [])]
except Exception:
statements = []
# 去重
statements = list(set(statements))
if query and statements:
retrieved_content.append({query: statements})
# 如果 retrieved_content 为空,设置为空字符串
if retrieved_content == []:
retrieved_content = ''
# 只有当回答不是"信息不足"且不是快速检索时才保存
if '信息不足,无法回答。' != str(summary) and str(search_switch).strip() != "2":
# 使用 upsert 方法
repo.upsert(
end_user_id=end_user_id,
messages=message,
aimessages=summary,
retrieved_content=retrieved_content,
search_switch=str(search_switch)
)
logger.info(f"成功保存短期记忆: end_user_id={end_user_id}, search_switch={search_switch}")
else:
logger.debug(f"跳过保存短期记忆: summary={summary[:50] if summary else 'None'}, search_switch={search_switch}")
except Exception as save_error:
# 保存失败不应该影响主流程,只记录错误
logger.error(f"保存短期记忆失败: {str(save_error)}", exc_info=True)
# Log successful operation
if audit_logger:
duration = time.time() - start_time
audit_logger.log_operation(
operation="READ",
config_id=config_id,
group_id=group_id,
end_user_id=end_user_id,
success=True,
duration=duration
)
@@ -524,14 +633,56 @@ class MemoryAgentService:
audit_logger.log_operation(
operation="READ",
config_id=config_id,
group_id=group_id,
end_user_id=end_user_id,
success=False,
duration=duration,
error=error_msg
)
raise ValueError(error_msg)
def get_messages_list(self, user_input: Write_UserInput) -> list[dict]:
"""
Get standardized message list from user input.
Args:
user_input: Write_UserInput object
Returns:
list[dict]: Message list, each message contains role and content
Raises:
ValueError: If messages is empty or format is incorrect
"""
from app.core.logging_config import get_api_logger
logger = get_api_logger()
if len(user_input.messages) == 0:
logger.error("Validation failed: Message list cannot be empty")
raise ValueError("Message list cannot be empty")
for idx, msg in enumerate(user_input.messages):
if not isinstance(msg, dict):
logger.error(f"Validation failed: Message {idx} is not a dict: {type(msg)}")
raise ValueError(f"Message format error: Message must be a dictionary. Error message index: {idx}, type: {type(msg)}")
if 'role' not in msg:
logger.error(f"Validation failed: Message {idx} missing 'role' field: {msg}")
raise ValueError(f"Message format error: Message must contain 'role' field. Error message index: {idx}")
if 'content' not in msg:
logger.error(f"Validation failed: Message {idx} missing 'content' field: {msg}")
raise ValueError(f"Message format error: Message must contain 'content' field. Error message index: {idx}")
if msg['role'] not in ['user', 'assistant']:
logger.error(f"Validation failed: Message {idx} invalid role: {msg['role']}")
raise ValueError(f"Role must be 'user' or 'assistant', got: {msg['role']}. Message index: {idx}")
if not msg['content'] or not msg['content'].strip():
logger.error(f"Validation failed: Message {idx} content is empty")
raise ValueError(f"Message content cannot be empty. Message index: {idx}, role: {msg['role']}")
logger.info(f"Validation successful: Structured message list, count: {len(user_input.messages)}")
return user_input.messages
async def classify_message_type(self, message: str, config_id: int, db: Session) -> Dict:
"""
@@ -559,7 +710,67 @@ class MemoryAgentService:
logger.debug(f"Message type: {status}")
return status
# ==================== 新增的三个接口方法 ====================
async def generate_summary_from_retrieve(
self,
retrieve_info: str,
history: List[Dict],
query: str,
config_id: str,
db: Session
) -> str:
"""
基于检索信息、历史对话和查询生成最终答案
使用 Retrieve_Summary_prompt.jinja2 模板调用大模型生成答案
Args:
retrieve_info: 检索到的信息
history: 历史对话记录
query: 用户查询
config_id: 配置ID
db: 数据库会话
Returns:
生成的答案文本
"""
logger.info(f"Generating summary from retrieve info for query: {query[:50]}...")
try:
# 加载配置
config_service = MemoryConfigService(db)
memory_config = config_service.load_memory_config(
config_id=config_id,
service_name="MemoryAgentService"
)
# 导入必要的模块
from app.core.memory.agent.langgraph_graph.nodes.summary_nodes import summary_llm
from app.core.memory.agent.models.summary_models import RetrieveSummaryResponse
# 构建状态对象
state = {
"data": query,
"memory_config": memory_config
}
# 直接调用 summary_llm 函数
answer = await summary_llm(
state=state,
history=history,
retrieve_info=retrieve_info,
template_name='Retrieve_Summary_prompt.jinja2',
operation_name='retrieve_summary',
response_model=RetrieveSummaryResponse,
search_mode="1"
)
logger.info(f"Successfully generated summary: {answer[:100] if answer else 'None'}...")
return answer if answer else "信息不足,无法回答。"
except Exception as e:
logger.error(f"生成摘要失败: {str(e)}", exc_info=True)
return "信息不足,无法回答。"
async def get_knowledge_type_stats(
self,
@@ -571,7 +782,7 @@ class MemoryAgentService:
"""
统计知识库类型分布,包含:
1. PostgreSQL 中的知识库类型General, Web, Third-party, Folder根据 workspace_id 过滤)
2. Neo4j 中的 memory 类型(仅统计 Chunk 数量,根据 end_user_id/group_id 过滤)
2. Neo4j 中的 memory 类型(仅统计 Chunk 数量,根据 end_user_id/end_user_id 过滤)
3. total: 所有类型的总和
参数:
@@ -657,11 +868,11 @@ class MemoryAgentService:
for end_user in end_users:
end_user_id_str = str(end_user.id)
memory_query = """
MATCH (n:Chunk) WHERE n.group_id = $group_id RETURN count(n) AS Count
MATCH (n:Chunk) WHERE n.end_user_id = $end_user_id RETURN count(n) AS Count
"""
neo4j_result = await _neo4j_connector.execute_query(
memory_query,
group_id=end_user_id_str,
end_user_id=end_user_id_str,
)
chunk_count = neo4j_result[0]["Count"] if neo4j_result else 0
total_chunks += chunk_count
@@ -701,7 +912,7 @@ class MemoryAgentService:
获取指定用户的热门记忆标签
参数:
- end_user_id: 用户ID可选对应Neo4j中的group_id字段
- end_user_id: 用户ID可选对应Neo4j中的end_user_id字段
- limit: 返回标签数量限制
返回格式:
@@ -711,7 +922,7 @@ class MemoryAgentService:
]
"""
try:
# by_user=False 表示按 group_id 查询在Neo4j中group_id就是用户维度
# by_user=False 表示按 end_user_id 查询在Neo4j中end_user_id就是用户维度
tags = await get_hot_memory_tags(end_user_id, limit=limit, by_user=False)
payload=[]
for tag, freq in tags:
@@ -786,21 +997,21 @@ class MemoryAgentService:
# 查询该用户的语句
query = (
"MATCH (s:Statement) "
"WHERE ($group_id IS NULL OR s.group_id = $group_id) AND s.statement IS NOT NULL "
"WHERE ($end_user_id IS NULL OR s.end_user_id = $end_user_id) AND s.statement IS NOT NULL "
"RETURN s.statement AS statement "
"ORDER BY s.created_at DESC LIMIT 100"
)
rows = await connector.execute_query(query, group_id=end_user_id)
rows = await connector.execute_query(query, end_user_id=end_user_id)
statements = [r.get("statement", "") for r in rows if r.get("statement")]
# 查询该用户的热门实体
entity_query = (
"MATCH (e:ExtractedEntity) "
"WHERE ($group_id IS NULL OR e.group_id = $group_id) AND e.entity_type <> '人物' AND e.name IS NOT NULL "
"WHERE ($end_user_id IS NULL OR e.end_user_id = $end_user_id) AND e.entity_type <> '人物' AND e.name IS NOT NULL "
"RETURN e.name AS name, count(e) AS frequency "
"ORDER BY frequency DESC LIMIT 20"
)
entity_rows = await connector.execute_query(entity_query, group_id=end_user_id)
entity_rows = await connector.execute_query(entity_query, end_user_id=end_user_id)
entities = [f"{r['name']} ({r['frequency']})" for r in entity_rows]
await connector.close()
@@ -853,14 +1064,14 @@ class MemoryAgentService:
names_to_exclude = ['AI', 'Caroline', 'Melanie', 'Jon', 'Gina', '用户', 'AI助手', 'John', 'Maria']
hot_tag_query = (
"MATCH (e:ExtractedEntity) "
"WHERE ($group_id IS NULL OR e.group_id = $group_id) AND e.entity_type <> '人物' "
"WHERE ($end_user_id IS NULL OR e.end_user_id = $end_user_id) AND e.entity_type <> '人物' "
"AND e.name IS NOT NULL AND NOT e.name IN $names_to_exclude "
"RETURN e.name AS name, count(e) AS frequency "
"ORDER BY frequency DESC LIMIT 4"
)
hot_tag_rows = await connector.execute_query(
hot_tag_query,
group_id=end_user_id,
end_user_id=end_user_id,
names_to_exclude=names_to_exclude
)
await connector.close()
@@ -1006,6 +1217,10 @@ def get_end_user_connected_config(end_user_id: str, db: Session) -> Dict[str, An
"memory_config_id": memory_config_id
}
print(188*'*')
print(result)
print(188 * '*')
logger.info(f"Successfully retrieved connected config: memory_config_id={memory_config_id}")
return result
@@ -1033,7 +1248,7 @@ def get_end_users_connected_configs_batch(end_user_ids: List[str], db: Session)
"""
from app.models.app_release_model import AppRelease
from app.models.end_user_model import EndUser
from app.models.memory_config_model import MemoryConfig
from app.models.data_config_model import DataConfig
from sqlalchemy import select
logger.info(f"Batch getting connected configs for {len(end_user_ids)} end_users")
@@ -1046,10 +1261,10 @@ def get_end_users_connected_configs_batch(end_user_ids: List[str], db: Session)
# 1. 批量查询所有 end_user 及其 app_id
end_users = db.query(EndUser).filter(EndUser.id.in_(end_user_ids)).all()
# 创建 end_user_id -> app_id 的映射
user_to_app = {str(eu.id): eu.app_id for eu in end_users}
# 记录未找到的用户
found_user_ids = set(user_to_app.keys())
missing_user_ids = set(end_user_ids) - found_user_ids
@@ -1097,7 +1312,7 @@ def get_end_users_connected_configs_batch(end_user_ids: List[str], db: Session)
# 4. 构建最终结果
for end_user_id, app_id in user_to_app.items():
release = app_to_release.get(app_id)
if not release:
logger.warning(f"No active release found for app: {app_id} (end_user: {end_user_id})")
result[end_user_id] = {"memory_config_id": None, "memory_config_name": None}

View File

@@ -25,7 +25,7 @@ class MemoryAPIService:
This service provides a thin layer that:
1. Validates end_user exists and belongs to the authorized workspace
2. Maps end_user_id to group_id for memory operations
2. Maps end_user_id to end_user_id for memory operations
3. Delegates to MemoryAgentService for actual memory read/write operations
"""
@@ -68,7 +68,7 @@ class MemoryAPIService:
)
end_user = self.db.query(EndUser).filter(EndUser.id == end_user_uuid).first()
if not end_user:
logger.warning(f"End user not found: {end_user_id}")
raise ResourceNotFoundException(
@@ -115,7 +115,7 @@ class MemoryAPIService:
Args:
workspace_id: Workspace ID for resource validation
end_user_id: End user identifier (used as group_id)
end_user_id: End user identifier (used as end_user_id)
message: Message content to store
config_id: Optional memory configuration ID
storage_type: Storage backend (neo4j or rag)
@@ -133,13 +133,12 @@ class MemoryAPIService:
# Validate end_user exists and belongs to workspace
self.validate_end_user(end_user_id, workspace_id)
# Use end_user_id as group_id for memory operations
group_id = end_user_id
# Use end_user_id as end_user_id for memory operations
try:
# Delegate to MemoryAgentService
result = await MemoryAgentService().write_memory(
group_id=group_id,
end_user_id=end_user_id,
message=message,
config_id=config_id,
db=self.db,
@@ -186,7 +185,7 @@ class MemoryAPIService:
Args:
workspace_id: Workspace ID for resource validation
end_user_id: End user identifier (used as group_id)
end_user_id: End user identifier (used as end_user_id)
message: Query message
search_switch: Search mode (0=deep search with verification, 1=deep search, 2=fast search)
config_id: Optional memory configuration ID
@@ -205,13 +204,13 @@ class MemoryAPIService:
# Validate end_user exists and belongs to workspace
self.validate_end_user(end_user_id, workspace_id)
# Use end_user_id as group_id for memory operations
group_id = end_user_id
# Use end_user_id as end_user_id for memory operations
try:
# Delegate to MemoryAgentService
result = await MemoryAgentService().read_memory(
group_id=group_id,
end_user_id=end_user_id,
message=message,
history=[],
search_switch=search_switch,

View File

@@ -326,7 +326,7 @@ class MemoryBaseService:
Args:
summary_id: Summary节点的ID
end_user_id: 终端用户ID (group_id)
end_user_id: 终端用户ID (end_user_id)
Returns:
最大emotion_intensity对应的emotion_type如果没有则返回None
@@ -334,7 +334,7 @@ class MemoryBaseService:
try:
query = """
MATCH (s:MemorySummary)
WHERE elementId(s) = $summary_id AND s.group_id = $group_id
WHERE elementId(s) = $summary_id AND s.end_user_id = $end_user_id
MATCH (s)-[:DERIVED_FROM_STATEMENT]->(stmt:Statement)
WHERE stmt.emotion_type IS NOT NULL
AND stmt.emotion_intensity IS NOT NULL
@@ -347,7 +347,7 @@ class MemoryBaseService:
result = await self.neo4j_connector.execute_query(
query,
summary_id=summary_id,
group_id=end_user_id
end_user_id=end_user_id
)
if result and len(result) > 0:
@@ -381,10 +381,10 @@ class MemoryBaseService:
if end_user_id:
query = """
MATCH (n:MemorySummary)
WHERE n.group_id = $group_id
WHERE n.end_user_id = $end_user_id
RETURN count(n) as count
"""
result = await self.neo4j_connector.execute_query(query, group_id=end_user_id)
result = await self.neo4j_connector.execute_query(query, end_user_id=end_user_id)
else:
query = """
MATCH (n:MemorySummary)
@@ -423,12 +423,12 @@ class MemoryBaseService:
if end_user_id:
semantic_query = """
MATCH (e:ExtractedEntity)
WHERE e.group_id = $group_id AND e.is_explicit_memory = true
WHERE e.end_user_id = $end_user_id AND e.is_explicit_memory = true
RETURN count(e) as count
"""
semantic_result = await self.neo4j_connector.execute_query(
semantic_query,
group_id=end_user_id
end_user_id=end_user_id
)
else:
semantic_query = """
@@ -519,7 +519,7 @@ class MemoryBaseService:
"""
if end_user_id:
query += " AND n.group_id = $group_id"
query += " AND n.end_user_id = $end_user_id"
query += """
RETURN sum(CASE WHEN n.activation_value IS NOT NULL AND n.activation_value < $threshold THEN 1 ELSE 0 END) as low_activation_nodes
@@ -528,7 +528,7 @@ class MemoryBaseService:
# 设置查询参数
params = {'threshold': forgetting_threshold}
if end_user_id:
params['group_id'] = end_user_id
params['end_user_id'] = end_user_id
# 执行查询
result = await self.neo4j_connector.execute_query(query, **params)

View File

@@ -125,7 +125,11 @@ class MemoryConfigService:
try:
validated_config_id = _validate_config_id(config_id)
# Step 1: Get config and workspace
db_query_start = time.time()
result = DataConfigRepository.get_config_with_workspace(self.db, validated_config_id)
db_query_time = time.time() - db_query_start
logger.info(f"[PERF] Config+Workspace query: {db_query_time:.4f}s")
if not result:
elapsed_ms = (time.time() - start_time) * 1000
config_logger.error(
@@ -144,16 +148,20 @@ class MemoryConfigService:
memory_config, workspace = result
# Validate embedding model
embedding_uuid = validate_embedding_model(
# Step 2: Validate embedding model (returns both UUID and name)
embed_start = time.time()
embedding_uuid, embedding_name = validate_embedding_model(
validated_config_id,
memory_config.embedding_id,
self.db,
workspace.tenant_id,
workspace.id,
)
embed_time = time.time() - embed_start
logger.info(f"[PERF] Embedding validation: {embed_time:.4f}s")
# Resolve LLM model
# Step 3: Resolve LLM model
llm_start = time.time()
llm_uuid, llm_name = validate_and_resolve_model_id(
memory_config.llm_id,
"llm",
@@ -163,8 +171,11 @@ class MemoryConfigService:
config_id=validated_config_id,
workspace_id=workspace.id,
)
llm_time = time.time() - llm_start
logger.info(f"[PERF] LLM validation: {llm_time:.4f}s")
# Resolve optional rerank model
# Step 4: Resolve optional rerank model
rerank_start = time.time()
rerank_uuid = None
rerank_name = None
if memory_config.rerank_id:
@@ -177,16 +188,12 @@ class MemoryConfigService:
config_id=validated_config_id,
workspace_id=workspace.id,
)
rerank_time = time.time() - rerank_start
if memory_config.rerank_id:
logger.info(f"[PERF] Rerank validation: {rerank_time:.4f}s")
# Get embedding model name
embedding_name, _ = validate_model_exists_and_active(
embedding_uuid,
"embedding",
self.db,
workspace.tenant_id,
config_id=validated_config_id,
workspace_id=workspace.id,
)
# Note: embedding_name is now returned from validate_embedding_model above
# No need for redundant query!
# Create immutable MemoryConfig object
config = MemoryConfig(

View File

@@ -717,8 +717,8 @@ class MemoryInteraction:
ori_data= await self.connector.execute_query(Memory_Space_Entity, id=self.id)
if ori_data!=[]:
# name = ori_data[0]['name']
group_id = [i['group_id'] for i in ori_data][0]
Space_User = await self.connector.execute_query(Memory_Space_User, group_id=group_id)
end_user_id = [i['end_user_id'] for i in ori_data][0]
Space_User = await self.connector.execute_query(Memory_Space_User, end_user_id=end_user_id)
if not Space_User:
return []
user_id=Space_User[0]['id']

View File

@@ -34,7 +34,7 @@ class MemoryEpisodicService(MemoryBaseService):
Args:
summary_id: Summary节点的ID
end_user_id: 终端用户ID (group_id)
end_user_id: 终端用户ID (end_user_id)
Returns:
(标题, 类型)元组,如果不存在则返回默认值
@@ -43,14 +43,14 @@ class MemoryEpisodicService(MemoryBaseService):
# 查询Summary节点的name(作为title)和memory_type(作为type)
query = """
MATCH (s:MemorySummary)
WHERE elementId(s) = $summary_id AND s.group_id = $group_id
WHERE elementId(s) = $summary_id AND s.end_user_id = $end_user_id
RETURN s.name AS title, s.memory_type AS type
"""
result = await self.neo4j_connector.execute_query(
query,
summary_id=summary_id,
group_id=end_user_id
end_user_id=end_user_id
)
if not result or len(result) == 0:
@@ -77,7 +77,7 @@ class MemoryEpisodicService(MemoryBaseService):
Args:
summary_id: Summary节点的ID
end_user_id: 终端用户ID (group_id)
end_user_id: 终端用户ID (end_user_id)
Returns:
前3个实体的name属性列表
@@ -87,7 +87,7 @@ class MemoryEpisodicService(MemoryBaseService):
# 按activation_value降序排序,返回前3个
query = """
MATCH (s:MemorySummary)
WHERE elementId(s) = $summary_id AND s.group_id = $group_id
WHERE elementId(s) = $summary_id AND s.end_user_id = $end_user_id
MATCH (s)-[:DERIVED_FROM_STATEMENT]->(stmt:Statement)
MATCH (stmt)-[:REFERENCES_ENTITY]->(entity:ExtractedEntity)
WHERE entity.activation_value IS NOT NULL
@@ -99,7 +99,7 @@ class MemoryEpisodicService(MemoryBaseService):
result = await self.neo4j_connector.execute_query(
query,
summary_id=summary_id,
group_id=end_user_id
end_user_id=end_user_id
)
# 提取实体名称
@@ -123,7 +123,7 @@ class MemoryEpisodicService(MemoryBaseService):
Args:
summary_id: Summary节点的ID
end_user_id: 终端用户ID (group_id)
end_user_id: 终端用户ID (end_user_id)
Returns:
所有Statement节点的statement属性内容列表
@@ -132,7 +132,7 @@ class MemoryEpisodicService(MemoryBaseService):
# 查询Summary节点指向的所有Statement节点
query = """
MATCH (s:MemorySummary)
WHERE elementId(s) = $summary_id AND s.group_id = $group_id
WHERE elementId(s) = $summary_id AND s.end_user_id = $end_user_id
MATCH (s)-[:DERIVED_FROM_STATEMENT]->(stmt:Statement)
WHERE stmt.statement IS NOT NULL AND stmt.statement <> ''
RETURN stmt.statement AS statement
@@ -141,7 +141,7 @@ class MemoryEpisodicService(MemoryBaseService):
result = await self.neo4j_connector.execute_query(
query,
summary_id=summary_id,
group_id=end_user_id
end_user_id=end_user_id
)
# 提取statement内容
@@ -214,12 +214,12 @@ class MemoryEpisodicService(MemoryBaseService):
# 1. 先查询所有情景记忆的总数(不受筛选条件限制)
total_all_query = """
MATCH (s:MemorySummary)
WHERE s.group_id = $group_id
WHERE s.end_user_id = $end_user_id
RETURN count(s) AS total_all
"""
total_all_result = await self.neo4j_connector.execute_query(
total_all_query,
group_id=end_user_id
end_user_id=end_user_id
)
total_all = total_all_result[0]["total_all"] if total_all_result else 0
@@ -229,7 +229,7 @@ class MemoryEpisodicService(MemoryBaseService):
# 3. 构建Cypher查询
query = """
MATCH (s:MemorySummary)
WHERE s.group_id = $group_id
WHERE s.end_user_id = $end_user_id
"""
# 添加时间范围过滤
@@ -248,7 +248,7 @@ class MemoryEpisodicService(MemoryBaseService):
ORDER BY s.created_at DESC
"""
params = {"group_id": end_user_id}
params = {"end_user_id": end_user_id}
if time_filter:
params["time_filter"] = time_filter
if title_keyword:
@@ -333,14 +333,14 @@ class MemoryEpisodicService(MemoryBaseService):
# 1. 查询指定的MemorySummary节点
query = """
MATCH (s:MemorySummary)
WHERE elementId(s) = $summary_id AND s.group_id = $group_id
WHERE elementId(s) = $summary_id AND s.end_user_id = $end_user_id
RETURN elementId(s) AS id, s.created_at AS created_at
"""
result = await self.neo4j_connector.execute_query(
query,
summary_id=summary_id,
group_id=end_user_id
end_user_id=end_user_id
)
# 2. 如果节点不存在,返回错误

View File

@@ -60,7 +60,7 @@ class MemoryExplicitService(MemoryBaseService):
# ========== 1. 查询情景记忆MemorySummary节点 ==========
episodic_query = """
MATCH (s:MemorySummary)
WHERE s.group_id = $group_id
WHERE s.end_user_id = $end_user_id
RETURN elementId(s) AS id,
s.name AS title,
s.content AS content,
@@ -70,7 +70,7 @@ class MemoryExplicitService(MemoryBaseService):
episodic_result = await self.neo4j_connector.execute_query(
episodic_query,
group_id=end_user_id
end_user_id=end_user_id
)
# 处理情景记忆数据
@@ -96,7 +96,7 @@ class MemoryExplicitService(MemoryBaseService):
# ========== 2. 查询语义记忆ExtractedEntity节点 ==========
semantic_query = """
MATCH (e:ExtractedEntity)
WHERE e.group_id = $group_id
WHERE e.end_user_id = $end_user_id
AND e.is_explicit_memory = true
RETURN elementId(e) AS id,
e.name AS name,
@@ -107,7 +107,7 @@ class MemoryExplicitService(MemoryBaseService):
semantic_result = await self.neo4j_connector.execute_query(
semantic_query,
group_id=end_user_id
end_user_id=end_user_id
)
# 处理语义记忆数据
@@ -189,7 +189,7 @@ class MemoryExplicitService(MemoryBaseService):
# ========== 1. 先尝试查询情景记忆 ==========
episodic_query = """
MATCH (s:MemorySummary)
WHERE elementId(s) = $memory_id AND s.group_id = $group_id
WHERE elementId(s) = $memory_id AND s.end_user_id = $end_user_id
RETURN s.name AS title,
s.content AS content,
s.created_at AS created_at
@@ -198,7 +198,7 @@ class MemoryExplicitService(MemoryBaseService):
episodic_result = await self.neo4j_connector.execute_query(
episodic_query,
memory_id=memory_id,
group_id=end_user_id
end_user_id=end_user_id
)
if episodic_result and len(episodic_result) > 0:
@@ -229,7 +229,7 @@ class MemoryExplicitService(MemoryBaseService):
semantic_query = """
MATCH (e:ExtractedEntity)
WHERE elementId(e) = $memory_id
AND e.group_id = $group_id
AND e.end_user_id = $end_user_id
AND e.is_explicit_memory = true
RETURN e.name AS name,
e.description AS core_definition,
@@ -240,7 +240,7 @@ class MemoryExplicitService(MemoryBaseService):
semantic_result = await self.neo4j_connector.execute_query(
semantic_query,
memory_id=memory_id,
group_id=end_user_id
end_user_id=end_user_id
)
if semantic_result and len(semantic_result) > 0:

View File

@@ -132,7 +132,7 @@ class MemoryForgetService:
async def _get_knowledge_stats(
self,
connector: Neo4jConnector,
group_id: Optional[str] = None,
end_user_id: Optional[str] = None,
forgetting_threshold: float = 0.3
) -> Dict[str, Any]:
"""
@@ -140,7 +140,7 @@ class MemoryForgetService:
Args:
connector: Neo4j 连接器
group_id: 组ID可选
end_user_id: 组ID可选
forgetting_threshold: 遗忘阈值
Returns:
@@ -152,8 +152,8 @@ class MemoryForgetService:
WHERE (n:Statement OR n:ExtractedEntity OR n:MemorySummary)
"""
if group_id:
query += " AND n.group_id = $group_id"
if end_user_id:
query += " AND n.end_user_id = $end_user_id"
query += """
WITH n,
@@ -172,8 +172,8 @@ class MemoryForgetService:
"""
params = {'threshold': forgetting_threshold}
if group_id:
params['group_id'] = group_id
if end_user_id:
params['end_user_id'] = end_user_id
results = await connector.execute_query(query, **params)
@@ -200,7 +200,7 @@ class MemoryForgetService:
async def _get_pending_forgetting_nodes(
self,
connector: Neo4jConnector,
group_id: str,
end_user_id: str,
forgetting_threshold: float,
min_days_since_access: int,
limit: int = 20
@@ -212,7 +212,7 @@ class MemoryForgetService:
Args:
connector: Neo4j 连接器
group_id: 组ID
end_user_id: 组ID
forgetting_threshold: 遗忘阈值
min_days_since_access: 最小未访问天数
limit: 返回节点数量限制
@@ -229,7 +229,7 @@ class MemoryForgetService:
query = """
MATCH (n)
WHERE (n:Statement OR n:ExtractedEntity OR n:MemorySummary)
AND n.group_id = $group_id
AND n.end_user_id = $end_user_id
AND n.activation_value IS NOT NULL
AND n.activation_value < $threshold
AND n.last_access_time IS NOT NULL
@@ -250,7 +250,7 @@ class MemoryForgetService:
"""
params = {
'group_id': group_id,
'end_user_id': end_user_id,
'threshold': forgetting_threshold,
'min_access_time_str': min_access_time_str,
'limit': limit
@@ -291,7 +291,7 @@ class MemoryForgetService:
async def trigger_forgetting_cycle(
self,
db: Session,
group_id: str,
end_user_id: str,
max_merge_batch_size: Optional[int] = None,
min_days_since_access: Optional[int] = None,
config_id: Optional[int] = None
@@ -303,10 +303,10 @@ class MemoryForgetService:
Args:
db: 数据库会话
group_id: 组ID即终端用户ID必填
end_user_id: 组ID即终端用户ID必填
max_merge_batch_size: 最大融合批次大小(可选)
min_days_since_access: 最小未访问天数(可选)
config_id: 配置ID必填由控制器层通过 group_id 获取)
config_id: 配置ID必填由控制器层通过 end_user_id 获取)
Returns:
dict: 遗忘报告
@@ -319,7 +319,7 @@ class MemoryForgetService:
# 运行遗忘周期LLM 客户端将在需要时由 forgetting_strategy 内部获取)
report = await forgetting_scheduler.run_forgetting_cycle(
group_id=group_id,
end_user_id=end_user_id,
max_merge_batch_size=max_merge_batch_size,
min_days_since_access=min_days_since_access,
config_id=config_id,
@@ -338,7 +338,7 @@ class MemoryForgetService:
stats_query = """
MATCH (n)
WHERE (n:Statement OR n:ExtractedEntity OR n:MemorySummary OR n:Chunk)
AND n.group_id = $group_id
AND n.end_user_id = $end_user_id
RETURN
count(n) as total_nodes,
avg(n.activation_value) as average_activation,
@@ -347,7 +347,7 @@ class MemoryForgetService:
stats_results = await connector.execute_query(
stats_query,
group_id=group_id,
end_user_id=end_user_id,
threshold=config['forgetting_threshold']
)
@@ -364,7 +364,7 @@ class MemoryForgetService:
# 保存历史记录到数据库
self.history_repository.create(
db=db,
end_user_id=group_id,
end_user_id=end_user_id,
execution_time=execution_time,
merged_count=report['merged_count'],
failed_count=report['failed_count'],
@@ -376,7 +376,7 @@ class MemoryForgetService:
)
api_logger.info(
f"已保存遗忘周期历史记录: end_user_id={group_id}, "
f"已保存遗忘周期历史记录: end_user_id={end_user_id}, "
f"merged_count={report['merged_count']}"
)
@@ -465,7 +465,7 @@ class MemoryForgetService:
async def get_forgetting_stats(
self,
db: Session,
group_id: Optional[str] = None,
end_user_id: Optional[str] = None,
config_id: Optional[int] = None
) -> Dict[str, Any]:
"""
@@ -475,7 +475,7 @@ class MemoryForgetService:
Args:
db: 数据库会话
group_id: 组ID可选
end_user_id: 组ID可选
config_id: 配置ID可选用于获取遗忘阈值
Returns:
@@ -493,8 +493,8 @@ class MemoryForgetService:
WHERE (n:Statement OR n:ExtractedEntity OR n:MemorySummary OR n:Chunk)
"""
if group_id:
activation_query += " AND n.group_id = $group_id"
if end_user_id:
activation_query += " AND n.end_user_id = $end_user_id"
activation_query += """
RETURN
@@ -506,8 +506,8 @@ class MemoryForgetService:
"""
params = {'threshold': forgetting_threshold}
if group_id:
params['group_id'] = group_id
if end_user_id:
params['end_user_id'] = end_user_id
activation_results = await connector.execute_query(activation_query, **params)
@@ -539,8 +539,8 @@ class MemoryForgetService:
WHERE (n:Statement OR n:ExtractedEntity OR n:MemorySummary OR n:Chunk)
"""
if group_id:
distribution_query += " AND n.group_id = $group_id"
if end_user_id:
distribution_query += " AND n.end_user_id = $end_user_id"
distribution_query += """
WITH n,
@@ -558,8 +558,8 @@ class MemoryForgetService:
"""
dist_params = {}
if group_id:
dist_params['group_id'] = group_id
if end_user_id:
dist_params['end_user_id'] = end_user_id
distribution_results = await connector.execute_query(distribution_query, **dist_params)
@@ -582,11 +582,11 @@ class MemoryForgetService:
# 获取最近7个日期的历史趋势数据每天取最后一次执行
recent_trends = []
try:
if group_id:
if end_user_id:
# 查询所有历史记录
history_records = self.history_repository.get_recent_by_end_user(
db=db,
end_user_id=group_id
end_user_id=end_user_id
)
# 按日期分组(一天可能有多次执行,取最后一次)
@@ -632,7 +632,7 @@ class MemoryForgetService:
# 获取待遗忘节点列表前20个满足遗忘条件的节点
pending_nodes = []
try:
if group_id:
if end_user_id:
# 验证 min_days_since_access 配置值
min_days = config.get('min_days_since_access')
if min_days is None or not isinstance(min_days, (int, float)) or min_days < 0:
@@ -643,7 +643,7 @@ class MemoryForgetService:
pending_nodes = await self._get_pending_forgetting_nodes(
connector=connector,
group_id=group_id,
end_user_id=end_user_id,
forgetting_threshold=forgetting_threshold,
min_days_since_access=int(min_days),
limit=20

View File

@@ -450,12 +450,12 @@ async def create_document_chunk(
return success(data=chunk, msg="文档块创建成功")
async def write_rag(group_id, message, user_rag_memory_id):
async def write_rag(end_user_id, message, user_rag_memory_id):
"""
将消息写入 RAG 知识库
Args:
group_id: 组ID用作文件标题
end_user_id: 组ID用作文件标题
message: 消息内容
user_rag_memory_id: 知识库ID必须是有效的UUID
@@ -487,10 +487,10 @@ async def write_rag(group_id, message, user_rag_memory_id):
db = next(db_gen)
try:
create_data = CustomTextFileCreate(title=group_id, content=message)
create_data = CustomTextFileCreate(title=end_user_id, content=message)
current_user = SimpleUser(user_rag_memory_id)
# 检查文档是否已存在
document = find_document_id_by_kb_and_filename(db=db, kb_id=user_rag_memory_id, file_name=f"{group_id}.txt")
document = find_document_id_by_kb_and_filename(db=db, kb_id=user_rag_memory_id, file_name=f"{end_user_id}.txt")
print('======',document)
api_logger.info(f"查找文档结果: document_id={document}")
if document is not None:
@@ -508,7 +508,7 @@ async def write_rag(group_id, message, user_rag_memory_id):
return result
else:
# 文档不存在,创建新文档
api_logger.info(f"文档不存在,创建新文档: group_id={group_id}")
api_logger.info(f"文档不存在,创建新文档: end_user_id={end_user_id}")
result = await memory_konwledges_up(
kb_id=user_rag_memory_id,
parent_id=user_rag_memory_id,
@@ -520,13 +520,13 @@ async def write_rag(group_id, message, user_rag_memory_id):
new_document_id = find_document_id_by_kb_and_filename(
db=db,
kb_id=user_rag_memory_id,
file_name=f"{group_id}.txt"
file_name=f"{end_user_id}.txt"
)
if new_document_id:
await parse_document_by_id(new_document_id, db=db, current_user=current_user)
else:
api_logger.error(f"创建文档后无法找到文档ID: group_id={group_id}")
api_logger.error(f"创建文档后无法找到文档ID: end_user_id={end_user_id}")
return result
finally:
# 确保数据库会话被关闭

View File

@@ -183,7 +183,7 @@ class DataConfigService: # 数据配置服务类PostgreSQL
"config_name": config.config_name,
"config_desc": config.config_desc,
"workspace_id": str(config.workspace_id) if config.workspace_id else None,
"group_id": config.group_id,
"end_user_id": config.end_user_id,
"user_id": config.user_id,
"apply_id": config.apply_id,
"llm_id": config.llm_id,
@@ -391,7 +391,7 @@ _neo4j_connector = Neo4jConnector()
async def search_dialogue(end_user_id: Optional[str] = None) -> Dict[str, Any]:
result = await _neo4j_connector.execute_query(
DataConfigRepository.SEARCH_FOR_DIALOGUE,
group_id=end_user_id,
end_user_id=end_user_id,
)
data = {"search_for": "dialogue", "num": result[0]["num"]}
return data
@@ -400,7 +400,7 @@ async def search_dialogue(end_user_id: Optional[str] = None) -> Dict[str, Any]:
async def search_chunk(end_user_id: Optional[str] = None) -> Dict[str, Any]:
result = await _neo4j_connector.execute_query(
DataConfigRepository.SEARCH_FOR_CHUNK,
group_id=end_user_id,
end_user_id=end_user_id,
)
data = {"search_for": "chunk", "num": result[0]["num"]}
return data
@@ -409,7 +409,7 @@ async def search_chunk(end_user_id: Optional[str] = None) -> Dict[str, Any]:
async def search_statement(end_user_id: Optional[str] = None) -> Dict[str, Any]:
result = await _neo4j_connector.execute_query(
DataConfigRepository.SEARCH_FOR_STATEMENT,
group_id=end_user_id,
end_user_id=end_user_id,
)
data = {"search_for": "statement", "num": result[0]["num"]}
return data
@@ -418,7 +418,7 @@ async def search_statement(end_user_id: Optional[str] = None) -> Dict[str, Any]:
async def search_entity(end_user_id: Optional[str] = None) -> Dict[str, Any]:
result = await _neo4j_connector.execute_query(
DataConfigRepository.SEARCH_FOR_ENTITY,
group_id=end_user_id,
end_user_id=end_user_id,
)
data = {"search_for": "entity", "num": result[0]["num"]}
return data
@@ -427,7 +427,7 @@ async def search_entity(end_user_id: Optional[str] = None) -> Dict[str, Any]:
async def search_all(end_user_id: Optional[str] = None) -> Dict[str, Any]:
result = await _neo4j_connector.execute_query(
DataConfigRepository.SEARCH_FOR_ALL,
group_id=end_user_id,
end_user_id=end_user_id,
)
# 检查结果是否为空或长度不足
@@ -462,7 +462,7 @@ async def kb_type_distribution(end_user_id: Optional[str] = None) -> Dict[str, A
"""
result = await _neo4j_connector.execute_query(
DataConfigRepository.SEARCH_FOR_ALL,
group_id=end_user_id,
end_user_id=end_user_id,
)
# 检查结果是否为空或长度不足
@@ -493,7 +493,7 @@ async def kb_type_distribution(end_user_id: Optional[str] = None) -> Dict[str, A
async def search_detials(end_user_id: Optional[str] = None) -> List[Dict[str, Any]]:
result = await _neo4j_connector.execute_query(
DataConfigRepository.SEARCH_FOR_DETIALS,
group_id=end_user_id,
end_user_id=end_user_id,
)
return result
@@ -501,7 +501,7 @@ async def search_detials(end_user_id: Optional[str] = None) -> List[Dict[str, An
async def search_edges(end_user_id: Optional[str] = None) -> List[Dict[str, Any]]:
result = await _neo4j_connector.execute_query(
DataConfigRepository.SEARCH_FOR_EDGES,
group_id=end_user_id,
end_user_id=end_user_id,
)
return result
@@ -510,7 +510,7 @@ async def search_entity_graph(end_user_id: Optional[str] = None) -> Dict[str, An
"""搜索所有实体之间的关系网络group 维度)。"""
result = await _neo4j_connector.execute_query(
DataConfigRepository.SEARCH_FOR_ENTITY_GRAPH,
group_id=end_user_id,
end_user_id=end_user_id,
)
# 对source_node 和 target_node 的 fact_summary进行截取只截取前三条的内容需要提取前三条“来源”
for item in result:

View File

@@ -91,7 +91,7 @@ async def run_pilot_extraction(
dialog = DialogData(
context=context,
ref_id="pilot_dialog_1",
group_id=str(memory_config.workspace_id),
end_user_id=str(memory_config.workspace_id),
user_id=str(memory_config.tenant_id),
apply_id=str(memory_config.config_id),
metadata={"source": "pilot_run", "input_type": "frontend_text"},

View File

@@ -155,10 +155,10 @@ class MemoryInsightHelper:
"""
query = """
MATCH (d:Dialogue)
WHERE d.group_id = $group_id AND d.created_at IS NOT NULL AND d.created_at <> ''
WHERE d.end_user_id = $end_user_id AND d.created_at IS NOT NULL AND d.created_at <> ''
RETURN d.created_at AS creation_time
"""
records = await self.neo4j_connector.execute_query(query, group_id=self.user_id)
records = await self.neo4j_connector.execute_query(query, end_user_id=self.user_id)
if not records:
return []
@@ -211,17 +211,17 @@ class MemoryInsightHelper:
async def get_social_connections(self) -> dict | None:
"""Find the user with whom the most memories are shared."""
query = """
MATCH (c1:Chunk {group_id: $group_id})
MATCH (c1:Chunk {end_user_id: $end_user_id})
OPTIONAL MATCH (c1)-[:CONTAINS]->(s:Statement)
OPTIONAL MATCH (s)<-[:CONTAINS]-(c2:Chunk)
WHERE c1.group_id <> c2.group_id AND s IS NOT NULL AND c2 IS NOT NULL
WITH c2.group_id AS other_user_id, COUNT(DISTINCT s) AS common_statements
WHERE c1.end_user_id <> c2.end_user_id AND s IS NOT NULL AND c2 IS NOT NULL
WITH c2.end_user_id AS other_user_id, COUNT(DISTINCT s) AS common_statements
WHERE common_statements > 0
RETURN other_user_id, common_statements
ORDER BY common_statements DESC
LIMIT 1
"""
records = await self.neo4j_connector.execute_query(query, group_id=self.user_id)
records = await self.neo4j_connector.execute_query(query, end_user_id=self.user_id)
if not records or not records[0].get("other_user_id"):
return None
@@ -230,7 +230,7 @@ class MemoryInsightHelper:
time_range_query = """
MATCH (c:Chunk)
WHERE c.group_id IN [$user_id, $other_user_id]
WHERE c.end_user_id IN [$user_id, $other_user_id]
RETURN min(c.created_at) AS start_time, max(c.created_at) AS end_time
"""
time_records = await self.neo4j_connector.execute_query(
@@ -294,11 +294,11 @@ class UserSummaryHelper:
"""Fetch recent statements authored by the user/group for context."""
query = (
"MATCH (s:Statement) "
"WHERE s.group_id = $group_id AND s.statement IS NOT NULL "
"WHERE s.end_user_id = $end_user_id AND s.statement IS NOT NULL "
"RETURN s.statement AS statement, s.created_at AS created_at "
"ORDER BY created_at DESC LIMIT $limit"
)
rows = await self.connector.execute_query(query, group_id=self.user_id, limit=limit)
rows = await self.connector.execute_query(query, end_user_id=self.user_id, limit=limit)
records = []
for r in rows:
try:
@@ -357,6 +357,101 @@ class UserMemoryService:
data[key] = UserMemoryService._datetime_to_timestamp(original_value)
return data
def update_end_user_profile(
self,
db: Session,
end_user_id: str,
profile_update: Any
) -> Dict[str, Any]:
"""
更新终端用户的基本信息
Args:
db: 数据库会话
end_user_id: 终端用户ID (UUID)
profile_update: 包含更新字段的 Pydantic 模型
Returns:
{
"success": bool,
"data": dict, # 更新后的用户档案数据
"error": Optional[str]
}
"""
try:
# 转换为UUID并查询用户
user_uuid = uuid.UUID(end_user_id)
repo = EndUserRepository(db)
end_user = repo.get_by_id(user_uuid)
if not end_user:
logger.warning(f"终端用户不存在: end_user_id={end_user_id}")
return {
"success": False,
"data": None,
"error": "终端用户不存在"
}
# 获取更新数据(排除 end_user_id 字段)
update_data = profile_update.model_dump(exclude_unset=True, exclude={'end_user_id'})
# 特殊处理 hire_date如果提供了时间戳转换为 DateTime
if 'hire_date' in update_data:
hire_date_timestamp = update_data['hire_date']
if hire_date_timestamp is not None:
from app.core.api_key_utils import timestamp_to_datetime
update_data['hire_date'] = timestamp_to_datetime(hire_date_timestamp)
# 如果是 None保持 None允许清空
# 更新字段
for field, value in update_data.items():
setattr(end_user, field, value)
# 更新时间戳
end_user.updated_at = datetime.now()
end_user.updatetime_profile = datetime.now()
# 提交更改
db.commit()
db.refresh(end_user)
# 构建响应数据
from app.schemas.end_user_schema import EndUserProfileResponse
profile_data = EndUserProfileResponse(
id=end_user.id,
other_name=end_user.other_name,
position=end_user.position,
department=end_user.department,
contact=end_user.contact,
phone=end_user.phone,
hire_date=end_user.hire_date,
updatetime_profile=end_user.updatetime_profile
)
logger.info(f"成功更新用户信息: end_user_id={end_user_id}, updated_fields={list(update_data.keys())}")
return {
"success": True,
"data": self.convert_profile_to_dict_with_timestamp(profile_data),
"error": None
}
except ValueError:
logger.error(f"无效的 end_user_id 格式: {end_user_id}")
return {
"success": False,
"data": None,
"error": "无效的用户ID格式"
}
except Exception as e:
db.rollback()
logger.error(f"用户信息更新失败: end_user_id={end_user_id}, error={str(e)}")
return {
"success": False,
"data": None,
"error": str(e)
}
async def get_cached_memory_insight(
self,
db: Session,
@@ -1057,7 +1152,7 @@ async def analytics_user_summary(end_user_id: Optional[str] = None) -> Dict[str,
import re
# 创建 UserSummaryHelper 实例
user_summary_tool = UserSummaryHelper(end_user_id or os.getenv("SELECTED_GROUP_ID", "group_123"))
user_summary_tool = UserSummaryHelper(end_user_id or os.getenv("SELECTED_end_user_id", "group_123"))
try:
# 1) 收集上下文数据
@@ -1178,10 +1273,10 @@ async def analytics_node_statistics(
if end_user_id:
query = f"""
MATCH (n:{node_type})
WHERE n.group_id = $group_id
WHERE n.end_user_id = $end_user_id
RETURN count(n) as count
"""
result = await _neo4j_connector.execute_query(query, group_id=end_user_id)
result = await _neo4j_connector.execute_query(query, end_user_id=end_user_id)
else:
query = f"""
MATCH (n:{node_type})
@@ -1292,10 +1387,10 @@ async def analytics_memory_types(
# 查询 Statement 节点数量
query = """
MATCH (n:Statement)
WHERE n.group_id = $group_id
WHERE n.end_user_id = $end_user_id
RETURN count(n) as count
"""
result = await _neo4j_connector.execute_query(query, group_id=end_user_id)
result = await _neo4j_connector.execute_query(query, end_user_id=end_user_id)
statement_count = result[0]["count"] if result and len(result) > 0 else 0
# 取三分之一作为隐性记忆数量
implicit_count = round(statement_count / 3)
@@ -1409,7 +1504,7 @@ async def analytics_graph_data(
包含节点、边和统计信息的字典
"""
try:
# 1. 获取 group_id
# 1. 获取 end_user_id
user_uuid = uuid.UUID(end_user_id)
repo = EndUserRepository(db)
end_user = repo.get_by_id(user_uuid)
@@ -1433,7 +1528,7 @@ async def analytics_graph_data(
# 基于中心节点的扩展查询
node_query = f"""
MATCH path = (center)-[*1..{depth}]-(connected)
WHERE center.group_id = $group_id
WHERE center.end_user_id = $end_user_id
AND elementId(center) = $center_node_id
WITH collect(DISTINCT center) + collect(DISTINCT connected) as all_nodes
UNWIND all_nodes as n
@@ -1444,7 +1539,7 @@ async def analytics_graph_data(
LIMIT $limit
"""
node_params = {
"group_id": end_user_id,
"end_user_id": end_user_id,
"center_node_id": center_node_id,
"limit": limit
}
@@ -1452,7 +1547,7 @@ async def analytics_graph_data(
# 按节点类型过滤查询
node_query = """
MATCH (n)
WHERE n.group_id = $group_id
WHERE n.end_user_id = $end_user_id
AND labels(n)[0] IN $node_types
RETURN
elementId(n) as id,
@@ -1461,7 +1556,7 @@ async def analytics_graph_data(
LIMIT $limit
"""
node_params = {
"group_id": end_user_id,
"end_user_id": end_user_id,
"node_types": node_types,
"limit": limit
}
@@ -1469,7 +1564,7 @@ async def analytics_graph_data(
# 查询所有节点
node_query = """
MATCH (n)
WHERE n.group_id = $group_id
WHERE n.end_user_id = $end_user_id
RETURN
elementId(n) as id,
labels(n)[0] as label,
@@ -1477,7 +1572,7 @@ async def analytics_graph_data(
LIMIT $limit
"""
node_params = {
"group_id": end_user_id,
"end_user_id": end_user_id,
"limit": limit
}

View File

@@ -382,12 +382,12 @@ def build_graphrag_for_kb(kb_id: uuid.UUID):
@celery_app.task(name="app.core.memory.agent.read_message", bind=True)
def read_message_task(self, group_id: str, message: str, history: List[Dict[str, Any]], search_switch: str, config_id: str,storage_type:str,user_rag_memory_id:str) -> Dict[str, Any]:
def read_message_task(self, end_user_id: str, message: str, history: List[Dict[str, Any]], search_switch: str, config_id: str,storage_type:str,user_rag_memory_id:str) -> Dict[str, Any]:
"""Celery task to process a read message via MemoryAgentService.
Args:
group_id: Group ID for the memory agent (also used as end_user_id)
end_user_id: Group ID for the memory agent (also used as end_user_id)
message: User message to process
history: Conversation history
search_switch: Search switch parameter
@@ -408,7 +408,7 @@ def read_message_task(self, group_id: str, message: str, history: List[Dict[str,
from app.services.memory_agent_service import get_end_user_connected_config
db = next(get_db())
try:
connected_config = get_end_user_connected_config(group_id, db)
connected_config = get_end_user_connected_config(end_user_id, db)
actual_config_id = connected_config.get("memory_config_id")
finally:
db.close()
@@ -420,7 +420,7 @@ def read_message_task(self, group_id: str, message: str, history: List[Dict[str,
db = next(get_db())
try:
service = MemoryAgentService()
return await service.read_memory(group_id, message, history, search_switch, actual_config_id, db, storage_type, user_rag_memory_id)
return await service.read_memory(end_user_id, message, history, search_switch, actual_config_id, db, storage_type, user_rag_memory_id)
finally:
db.close()
@@ -448,7 +448,7 @@ def read_message_task(self, group_id: str, message: str, history: List[Dict[str,
return {
"status": "SUCCESS",
"result": result,
"group_id": group_id,
"end_user_id": end_user_id,
"config_id": config_id,
"elapsed_time": elapsed_time,
"task_id": self.request.id
@@ -464,7 +464,7 @@ def read_message_task(self, group_id: str, message: str, history: List[Dict[str,
return {
"status": "FAILURE",
"error": detailed_error,
"group_id": group_id,
"end_user_id": end_user_id,
"config_id": config_id,
"elapsed_time": elapsed_time,
"task_id": self.request.id
@@ -472,11 +472,11 @@ def read_message_task(self, group_id: str, message: str, history: List[Dict[str,
@celery_app.task(name="app.core.memory.agent.write_message", bind=True)
def write_message_task(self, group_id: str, message: str, config_id: str,storage_type:str,user_rag_memory_id:str) -> Dict[str, Any]:
def write_message_task(self, end_user_id: str, message: str, config_id: str,storage_type:str,user_rag_memory_id:str) -> Dict[str, Any]:
"""Celery task to process a write message via MemoryAgentService.
Args:
group_id: Group ID for the memory agent (also used as end_user_id)
end_user_id: Group ID for the memory agent (also used as end_user_id)
message: Message to write
config_id: Optional configuration ID
@@ -489,7 +489,7 @@ def write_message_task(self, group_id: str, message: str, config_id: str,storage
from app.core.logging_config import get_logger
logger = get_logger(__name__)
logger.info(f"[CELERY WRITE] Starting write task - group_id={group_id}, config_id={config_id}, storage_type={storage_type}")
logger.info(f"[CELERY WRITE] Starting write task - end_user_id={end_user_id}, config_id={config_id}, storage_type={storage_type}")
start_time = time.time()
# Resolve config_id if None
@@ -499,7 +499,7 @@ def write_message_task(self, group_id: str, message: str, config_id: str,storage
from app.services.memory_agent_service import get_end_user_connected_config
db = next(get_db())
try:
connected_config = get_end_user_connected_config(group_id, db)
connected_config = get_end_user_connected_config(end_user_id, db)
actual_config_id = connected_config.get("memory_config_id")
finally:
db.close()
@@ -512,7 +512,7 @@ def write_message_task(self, group_id: str, message: str, config_id: str,storage
try:
logger.info(f"[CELERY WRITE] Executing MemoryAgentService.write_memory")
service = MemoryAgentService()
result = await service.write_memory(group_id, message, actual_config_id, db, storage_type, user_rag_memory_id)
result = await service.write_memory(end_user_id, message, actual_config_id, db, storage_type, user_rag_memory_id)
logger.info(f"[CELERY WRITE] Write completed successfully: {result}")
return result
except Exception as e:
@@ -547,7 +547,7 @@ def write_message_task(self, group_id: str, message: str, config_id: str,storage
return {
"status": "SUCCESS",
"result": result,
"group_id": group_id,
"end_user_id": end_user_id,
"config_id": config_id,
"elapsed_time": elapsed_time,
"task_id": self.request.id
@@ -566,7 +566,7 @@ def write_message_task(self, group_id: str, message: str, config_id: str,storage
return {
"status": "FAILURE",
"error": detailed_error,
"group_id": group_id,
"end_user_id": end_user_id,
"config_id": config_id,
"elapsed_time": elapsed_time,
"task_id": self.request.id
@@ -612,7 +612,7 @@ def check_read_service_task() -> Dict[str, str]:
payload = {
"user_id": "健康检查",
"apply_id": "健康检查",
"group_id": "健康检查",
"end_user_id": "健康检查",
"message": "你好",
"history": [],
"search_switch": "2",
@@ -1112,7 +1112,7 @@ def run_forgetting_cycle_task(self, config_id: Optional[int] = None) -> Dict[str
# 运行遗忘周期
report = await forget_service.trigger_forgetting(
db=db,
group_id=None, # 处理所有组
end_user_id=None, # 处理所有组
config_id=config_id
)

View File

@@ -7,10 +7,6 @@ services:
- "8002:8000"
env_file:
- .env
environment:
- SERVER_IP=0.0.0.0
# 如果代码里必须要 MCP_SERVER_URL可以先注释或指向占位
# - MCP_SERVER_URL=
volumes:
- ./files:/files
- /etc/localtime:/etc/localtime:ro
@@ -19,20 +15,53 @@ services:
networks:
- default
- celery
depends_on:
- worker-memory
- worker-document
# Celery worker
worker:
# Memory worker - Memory read/write tasks (threads pool for asyncio)
worker-memory:
image: redbear-mem-open:latest
container_name: worker
container_name: worker-memory
env_file:
- .env
volumes:
- ./files:/files
- /etc/localtime:/etc/localtime:ro
command: celery -A app.celery_worker.celery_app worker --loglevel=info
command: celery -A app.celery_worker.celery_app worker -E --loglevel=info --pool=threads --concurrency=100 --queues=memory_tasks -n memory_worker@%h
restart: unless-stopped
networks:
- celery
# Document worker - Document parsing tasks (prefork for CPU-bound)
worker-document:
image: redbear-mem-open:latest
container_name: worker-document
env_file:
- .env
volumes:
- ./files:/files
- /etc/localtime:/etc/localtime:ro
command: celery -A app.celery_worker.celery_app worker -E --loglevel=info --pool=prefork --concurrency=4 --queues=document_tasks --max-tasks-per-child=100 -n document_worker@%h
restart: unless-stopped
networks:
- celery
# Celery Beat - scheduler
beat:
image: redbear-mem-open:latest
container_name: celery-beat
env_file:
- .env
volumes:
- ./files:/files
- /etc/localtime:/etc/localtime:ro
command: celery -A app.celery_worker.celery_app beat --loglevel=info
restart: unless-stopped
networks:
- celery
depends_on:
- worker-memory
networks:
celery:

View File

@@ -13,6 +13,7 @@ dependencies = [
"bcrypt==5.0.0",
"billiard==4.2.2",
"celery==5.5.3",
"flower==2.0.1",
"cffi==2.0.0",
"click==8.3.0",
"click-didyoumean==0.3.1",
@@ -138,6 +139,7 @@ dependencies = [
"python-calamine>=0.4.0",
"xlrd==2.0.2",
"deprecated>=1.3.1",
"flower>=2.0.1",
]
[tool.pytest.ini_options]

View File

@@ -6,6 +6,7 @@ async-timeout==5.0.1
bcrypt==5.0.0
billiard==4.2.2
celery==5.5.3
flower==2.0.1
cffi==2.0.0
click==8.3.0
click-didyoumean==0.3.1