Merge branch 'refs/heads/develop' into fix/memory_bug_fix
This commit is contained in:
14
README.md
14
README.md
@@ -334,7 +334,13 @@ step6: Log In to the Frontend Interface.
|
|||||||
## License
|
## License
|
||||||
This project is licensed under the Apache License 2.0. For details, see the LICENSE file.
|
This project is licensed under the Apache License 2.0. For details, see the LICENSE file.
|
||||||
|
|
||||||
## Acknowledgements & Community
|
## Community & Support
|
||||||
- Feedback & Issues: Please submit an Issue in the repository for bug reports or discussions.
|
|
||||||
- Contributions Welcome: When submitting a Pull Request, please create a feature branch and follow conventional commit message guidelines.
|
Join our community to ask questions, share your work, and connect with fellow developers.
|
||||||
- Contact: If you are interested in contributing or collaborating, feel free to reach out at tianyou_hubm@redbearai.com
|
|
||||||
|
- **GitHub Issues**: Report bugs, request features, or track known issues via [GitHub Issues](https://github.com/SuanmoSuanyangTechnology/MemoryBear/issues).
|
||||||
|
- **GitHub Pull Requests**: Contribute code improvements or fixes through [Pull Requests](https://github.com/SuanmoSuanyangTechnology/MemoryBear/pulls).
|
||||||
|
- **GitHub Discussions**: Ask questions, share ideas, and engage with the community in [GitHub Discussions](https://github.com/SuanmoSuanyangTechnology/MemoryBear/discussions).
|
||||||
|
- **WeChat**: Scan the QR code below to join our WeChat community group.
|
||||||
|
- 
|
||||||
|
- **Contact**: If you are interested in contributing or collaborating, feel free to reach out at tianyou_hubm@redbearai.com
|
||||||
|
|||||||
0
api/app/__init__.py
Normal file
0
api/app/__init__.py
Normal file
@@ -1,4 +1,5 @@
|
|||||||
import os
|
import os
|
||||||
|
import platform
|
||||||
from datetime import timedelta
|
from datetime import timedelta
|
||||||
from urllib.parse import quote
|
from urllib.parse import quote
|
||||||
|
|
||||||
@@ -14,28 +15,13 @@ celery_app = Celery(
|
|||||||
backend=f"redis://:{quote(settings.REDIS_PASSWORD)}@{settings.REDIS_HOST}:{settings.REDIS_PORT}/{settings.CELERY_BACKEND}",
|
backend=f"redis://:{quote(settings.REDIS_PASSWORD)}@{settings.REDIS_HOST}:{settings.REDIS_PORT}/{settings.CELERY_BACKEND}",
|
||||||
)
|
)
|
||||||
|
|
||||||
# 配置使用本地队列,避免与远程 worker 冲突
|
# Default queue for unrouted tasks
|
||||||
celery_app.conf.task_default_queue = 'localhost_test_wyl'
|
celery_app.conf.task_default_queue = 'memory_tasks'
|
||||||
celery_app.conf.task_default_exchange = 'localhost_test_wyl'
|
|
||||||
celery_app.conf.task_default_routing_key = 'localhost_test_wyl'
|
|
||||||
|
|
||||||
# macOS 兼容性配置
|
# macOS 兼容性配置
|
||||||
import platform
|
if platform.system() == 'Darwin':
|
||||||
|
|
||||||
if platform.system() == 'Darwin': # macOS
|
|
||||||
# 设置环境变量解决 fork 问题
|
|
||||||
os.environ.setdefault('OBJC_DISABLE_INITIALIZE_FORK_SAFETY', 'YES')
|
os.environ.setdefault('OBJC_DISABLE_INITIALIZE_FORK_SAFETY', 'YES')
|
||||||
|
|
||||||
# 使用 solo 池避免多进程问题
|
|
||||||
celery_app.conf.worker_pool = 'solo'
|
|
||||||
|
|
||||||
# 设置唯一的节点名称
|
|
||||||
import socket
|
|
||||||
import time
|
|
||||||
hostname = socket.gethostname()
|
|
||||||
timestamp = int(time.time())
|
|
||||||
celery_app.conf.worker_name = f"celery@{hostname}-{timestamp}"
|
|
||||||
|
|
||||||
# Celery 配置
|
# Celery 配置
|
||||||
celery_app.conf.update(
|
celery_app.conf.update(
|
||||||
# 序列化
|
# 序列化
|
||||||
@@ -52,36 +38,47 @@ celery_app.conf.update(
|
|||||||
task_ignore_result=False,
|
task_ignore_result=False,
|
||||||
|
|
||||||
# 超时设置
|
# 超时设置
|
||||||
task_time_limit=30 * 60, # 30 分钟硬超时
|
task_time_limit=1800, # 30分钟硬超时
|
||||||
task_soft_time_limit=25 * 60, # 25 分钟软超时
|
task_soft_time_limit=1500, # 25分钟软超时
|
||||||
|
|
||||||
# Worker 设置 - 针对 macOS 优化
|
# Worker 设置 (per-worker settings are in docker-compose command line)
|
||||||
worker_prefetch_multiplier=1, # 减少预取任务数,避免内存堆积
|
worker_prefetch_multiplier=1, # Don't hoard tasks, fairer distribution
|
||||||
worker_max_tasks_per_child=10, # 大幅减少每个 worker 执行的任务数,频繁重启防止内存泄漏
|
|
||||||
worker_max_memory_per_child=200000, # 200MB 内存限制,超过后重启 worker
|
|
||||||
|
|
||||||
# 结果过期时间
|
# 结果过期时间
|
||||||
result_expires=3600, # 结果保存 1 小时
|
result_expires=3600, # 结果保存1小时
|
||||||
|
|
||||||
# 任务确认设置
|
# 任务确认设置
|
||||||
task_acks_late=True, # 任务完成后才确认,避免任务丢失
|
task_acks_late=True,
|
||||||
worker_disable_rate_limits=True, # 禁用速率限制
|
task_reject_on_worker_lost=True,
|
||||||
|
worker_disable_rate_limits=True,
|
||||||
|
|
||||||
# 任务路由(可选,用于不同队列)
|
# FLower setting
|
||||||
# task_routes={
|
worker_send_task_events=True,
|
||||||
# 'app.core.rag.tasks.parse_document': {'queue': 'document_processing'},
|
task_send_sent_event=True,
|
||||||
# 'app.core.memory.agent.read_message': {'queue': 'memory_processing'},
|
|
||||||
# 'app.core.memory.agent.write_message': {'queue': 'memory_processing'},
|
# task routing
|
||||||
# 'tasks.process_item': {'queue': 'default'},
|
task_routes={
|
||||||
# },
|
# Memory tasks → memory_tasks queue (threads worker)
|
||||||
|
'app.core.memory.agent.read_message_priority': {'queue': 'memory_tasks'},
|
||||||
|
'app.core.memory.agent.read_message': {'queue': 'memory_tasks'},
|
||||||
|
'app.core.memory.agent.write_message': {'queue': 'memory_tasks'},
|
||||||
|
|
||||||
|
# Document tasks → document_tasks queue (prefork worker)
|
||||||
|
'app.core.rag.tasks.parse_document': {'queue': 'document_tasks'},
|
||||||
|
'app.core.rag.tasks.build_graphrag_for_kb': {'queue': 'document_tasks'},
|
||||||
|
|
||||||
|
# Beat/periodic tasks → document_tasks queue (prefork worker)
|
||||||
|
'app.tasks.workspace_reflection_task': {'queue': 'document_tasks'},
|
||||||
|
'app.tasks.regenerate_memory_cache': {'queue': 'document_tasks'},
|
||||||
|
'app.tasks.run_forgetting_cycle_task': {'queue': 'document_tasks'},
|
||||||
|
'app.controllers.memory_storage_controller.search_all': {'queue': 'document_tasks'},
|
||||||
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
# 自动发现任务模块
|
# 自动发现任务模块
|
||||||
celery_app.autodiscover_tasks(['app'])
|
celery_app.autodiscover_tasks(['app'])
|
||||||
|
|
||||||
# Celery Beat schedule for periodic tasks
|
# Celery Beat schedule for periodic tasks
|
||||||
reflection_schedule = timedelta(seconds=settings.REFLECTION_INTERVAL_SECONDS)
|
|
||||||
health_schedule = timedelta(seconds=settings.HEALTH_CHECK_SECONDS)
|
|
||||||
memory_increment_schedule = timedelta(hours=settings.MEMORY_INCREMENT_INTERVAL_HOURS)
|
memory_increment_schedule = timedelta(hours=settings.MEMORY_INCREMENT_INTERVAL_HOURS)
|
||||||
memory_cache_regeneration_schedule = timedelta(hours=settings.MEMORY_CACHE_REGENERATION_HOURS)
|
memory_cache_regeneration_schedule = timedelta(hours=settings.MEMORY_CACHE_REGENERATION_HOURS)
|
||||||
workspace_reflection_schedule = timedelta(seconds=30) # 每30秒运行一次settings.REFLECTION_INTERVAL_TIME
|
workspace_reflection_schedule = timedelta(seconds=30) # 每30秒运行一次settings.REFLECTION_INTERVAL_TIME
|
||||||
@@ -89,12 +86,6 @@ forgetting_cycle_schedule = timedelta(hours=24) # 每24小时运行一次遗忘
|
|||||||
|
|
||||||
# 构建定时任务配置
|
# 构建定时任务配置
|
||||||
beat_schedule_config = {
|
beat_schedule_config = {
|
||||||
|
|
||||||
# "check-read-service": {
|
|
||||||
# "task": "app.core.memory.agent.health.check_read_service",
|
|
||||||
# "schedule": health_schedule,
|
|
||||||
# "args": (),
|
|
||||||
# },
|
|
||||||
"run-workspace-reflection": {
|
"run-workspace-reflection": {
|
||||||
"task": "app.tasks.workspace_reflection_task",
|
"task": "app.tasks.workspace_reflection_task",
|
||||||
"schedule": workspace_reflection_schedule,
|
"schedule": workspace_reflection_schedule,
|
||||||
|
|||||||
@@ -9,6 +9,8 @@ from app.db import get_db
|
|||||||
from app.dependencies import cur_workspace_access_guard, get_current_user
|
from app.dependencies import cur_workspace_access_guard, get_current_user
|
||||||
from app.models import ModelApiKey
|
from app.models import ModelApiKey
|
||||||
from app.models.user_model import User
|
from app.models.user_model import User
|
||||||
|
from app.core.memory.agent.utils.session_tools import SessionService
|
||||||
|
from app.core.memory.agent.utils.redis_tool import store
|
||||||
from app.repositories import knowledge_repository, WorkspaceRepository
|
from app.repositories import knowledge_repository, WorkspaceRepository
|
||||||
from app.schemas.memory_agent_schema import UserInput, Write_UserInput
|
from app.schemas.memory_agent_schema import UserInput, Write_UserInput
|
||||||
from app.schemas.response_schema import ApiResponse
|
from app.schemas.response_schema import ApiResponse
|
||||||
@@ -123,7 +125,7 @@ async def write_server(
|
|||||||
Write service endpoint - processes write operations synchronously
|
Write service endpoint - processes write operations synchronously
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
user_input: Write request containing message and group_id
|
user_input: Write request containing message and end_user_id
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Response with write operation status
|
Response with write operation status
|
||||||
@@ -158,11 +160,11 @@ async def write_server(
|
|||||||
api_logger.warning("workspace_id 为空,无法使用 rag 存储,将使用 neo4j 存储")
|
api_logger.warning("workspace_id 为空,无法使用 rag 存储,将使用 neo4j 存储")
|
||||||
storage_type = 'neo4j'
|
storage_type = 'neo4j'
|
||||||
|
|
||||||
api_logger.info(f"Write service requested for group {user_input.group_id}, storage_type: {storage_type}, user_rag_memory_id: {user_rag_memory_id}")
|
api_logger.info(f"Write service requested for group {user_input.end_user_id}, storage_type: {storage_type}, user_rag_memory_id: {user_rag_memory_id}")
|
||||||
try:
|
try:
|
||||||
result = await memory_agent_service.write_memory(
|
result = await memory_agent_service.write_memory(
|
||||||
user_input.group_id,
|
user_input.end_user_id,
|
||||||
user_input.message,
|
user_input.messages,
|
||||||
config_id,
|
config_id,
|
||||||
db,
|
db,
|
||||||
storage_type,
|
storage_type,
|
||||||
@@ -191,7 +193,7 @@ async def write_server_async(
|
|||||||
Async write service endpoint - enqueues write processing to Celery
|
Async write service endpoint - enqueues write processing to Celery
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
user_input: Write request containing message and group_id
|
user_input: Write request containing message and end_user_id
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Task ID for tracking async operation
|
Task ID for tracking async operation
|
||||||
@@ -221,7 +223,7 @@ async def write_server_async(
|
|||||||
try:
|
try:
|
||||||
task = celery_app.send_task(
|
task = celery_app.send_task(
|
||||||
"app.core.memory.agent.write_message",
|
"app.core.memory.agent.write_message",
|
||||||
args=[user_input.group_id, user_input.message, config_id, storage_type, user_rag_memory_id]
|
args=[user_input.end_user_id, user_input.message, config_id, storage_type, user_rag_memory_id]
|
||||||
)
|
)
|
||||||
api_logger.info(f"Write task queued: {task.id}")
|
api_logger.info(f"Write task queued: {task.id}")
|
||||||
|
|
||||||
@@ -247,7 +249,7 @@ async def read_server(
|
|||||||
- "2": Direct answer based on context
|
- "2": Direct answer based on context
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
user_input: Read request with message, history, search_switch, and group_id
|
user_input: Read request with message, history, search_switch, and end_user_id
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Response with query answer
|
Response with query answer
|
||||||
@@ -271,12 +273,13 @@ async def read_server(
|
|||||||
name="USER_RAG_MERORY",
|
name="USER_RAG_MERORY",
|
||||||
workspace_id=workspace_id
|
workspace_id=workspace_id
|
||||||
)
|
)
|
||||||
if knowledge: user_rag_memory_id = str(knowledge.id)
|
if knowledge:
|
||||||
|
user_rag_memory_id = str(knowledge.id)
|
||||||
|
|
||||||
api_logger.info(f"Read service: group={user_input.group_id}, storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}, workspace_id={workspace_id}")
|
api_logger.info(f"Read service: group={user_input.end_user_id}, storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}, workspace_id={workspace_id}")
|
||||||
try:
|
try:
|
||||||
result = await memory_agent_service.read_memory(
|
result = await memory_agent_service.read_memory(
|
||||||
user_input.group_id,
|
user_input.end_user_id,
|
||||||
user_input.message,
|
user_input.message,
|
||||||
user_input.history,
|
user_input.history,
|
||||||
user_input.search_switch,
|
user_input.search_switch,
|
||||||
@@ -285,6 +288,19 @@ async def read_server(
|
|||||||
storage_type,
|
storage_type,
|
||||||
user_rag_memory_id
|
user_rag_memory_id
|
||||||
)
|
)
|
||||||
|
if str(user_input.search_switch) == "2":
|
||||||
|
retrieve_info = result['answer']
|
||||||
|
history = await SessionService(store).get_history(user_input.end_user_id, user_input.end_user_id, user_input.end_user_id)
|
||||||
|
query = user_input.message
|
||||||
|
|
||||||
|
# 调用 memory_agent_service 的方法生成最终答案
|
||||||
|
result['answer'] = await memory_agent_service.generate_summary_from_retrieve(
|
||||||
|
retrieve_info=retrieve_info,
|
||||||
|
history=history,
|
||||||
|
query=query,
|
||||||
|
config_id=config_id,
|
||||||
|
db=db
|
||||||
|
)
|
||||||
return success(data=result, msg="回复对话消息成功")
|
return success(data=result, msg="回复对话消息成功")
|
||||||
except BaseException as e:
|
except BaseException as e:
|
||||||
# Handle ExceptionGroup from TaskGroup (Python 3.11+) or BaseExceptionGroup
|
# Handle ExceptionGroup from TaskGroup (Python 3.11+) or BaseExceptionGroup
|
||||||
@@ -382,7 +398,7 @@ async def read_server_async(
|
|||||||
try:
|
try:
|
||||||
task = celery_app.send_task(
|
task = celery_app.send_task(
|
||||||
"app.core.memory.agent.read_message",
|
"app.core.memory.agent.read_message",
|
||||||
args=[user_input.group_id, user_input.message, user_input.history, user_input.search_switch,
|
args=[user_input.end_user_id, user_input.message, user_input.history, user_input.search_switch,
|
||||||
config_id, storage_type, user_rag_memory_id]
|
config_id, storage_type, user_rag_memory_id]
|
||||||
)
|
)
|
||||||
api_logger.info(f"Read task queued: {task.id}")
|
api_logger.info(f"Read task queued: {task.id}")
|
||||||
@@ -426,7 +442,7 @@ async def get_read_task_result(
|
|||||||
return success(
|
return success(
|
||||||
data={
|
data={
|
||||||
"result": task_result.get("result"),
|
"result": task_result.get("result"),
|
||||||
"group_id": task_result.get("group_id"),
|
"end_user_id": task_result.get("end_user_id"),
|
||||||
"elapsed_time": task_result.get("elapsed_time"),
|
"elapsed_time": task_result.get("elapsed_time"),
|
||||||
"task_id": task_id
|
"task_id": task_id
|
||||||
},
|
},
|
||||||
@@ -503,7 +519,7 @@ async def get_write_task_result(
|
|||||||
return success(
|
return success(
|
||||||
data={
|
data={
|
||||||
"result": task_result.get("result"),
|
"result": task_result.get("result"),
|
||||||
"group_id": task_result.get("group_id"),
|
"end_user_id": task_result.get("end_user_id"),
|
||||||
"elapsed_time": task_result.get("elapsed_time"),
|
"elapsed_time": task_result.get("elapsed_time"),
|
||||||
"task_id": task_id
|
"task_id": task_id
|
||||||
},
|
},
|
||||||
@@ -557,15 +573,30 @@ async def status_type(
|
|||||||
Determine the type of user message (read or write)
|
Determine the type of user message (read or write)
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
user_input: Request containing user message and group_id
|
user_input: Request containing user message and end_user_id
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Type classification result
|
Type classification result
|
||||||
"""
|
"""
|
||||||
api_logger.info(f"Status type check requested for group {user_input.group_id}")
|
api_logger.info(f"Status type check requested for group {user_input.end_user_id}")
|
||||||
try:
|
try:
|
||||||
|
# 获取标准化的消息列表
|
||||||
|
messages_list = memory_agent_service.get_messages_list(user_input)
|
||||||
|
|
||||||
|
# 将消息列表转换为字符串用于分类
|
||||||
|
# 只取最后一条用户消息进行分类
|
||||||
|
last_user_message = ""
|
||||||
|
for msg in reversed(messages_list):
|
||||||
|
if msg.get('role') == 'user':
|
||||||
|
last_user_message = msg.get('content', '')
|
||||||
|
break
|
||||||
|
|
||||||
|
if not last_user_message:
|
||||||
|
# 如果没有用户消息,使用所有消息的内容
|
||||||
|
last_user_message = " ".join([msg.get('content', '') for msg in messages_list])
|
||||||
|
|
||||||
result = await memory_agent_service.classify_message_type(
|
result = await memory_agent_service.classify_message_type(
|
||||||
user_input.message,
|
user_input.messages,
|
||||||
user_input.config_id,
|
user_input.config_id,
|
||||||
db
|
db
|
||||||
)
|
)
|
||||||
@@ -588,7 +619,7 @@ async def get_knowledge_type_stats_api(
|
|||||||
会对缺失类型补 0,返回字典形式。
|
会对缺失类型补 0,返回字典形式。
|
||||||
可选按状态过滤。
|
可选按状态过滤。
|
||||||
- 知识库类型根据当前用户的 current_workspace_id 过滤
|
- 知识库类型根据当前用户的 current_workspace_id 过滤
|
||||||
- memory 是 Neo4j 中 Chunk 的数量,根据 end_user_id (group_id) 过滤
|
- memory 是 Neo4j 中 Chunk 的数量,根据 end_user_id (end_user_id) 过滤
|
||||||
- 如果用户没有当前工作空间或未提供 end_user_id,对应的统计返回 0
|
- 如果用户没有当前工作空间或未提供 end_user_id,对应的统计返回 0
|
||||||
"""
|
"""
|
||||||
api_logger.info(f"Knowledge type stats requested for workspace_id: {current_user.current_workspace_id}, end_user_id: {end_user_id}")
|
api_logger.info(f"Knowledge type stats requested for workspace_id: {current_user.current_workspace_id}, end_user_id: {end_user_id}")
|
||||||
|
|||||||
@@ -5,7 +5,6 @@ from app.core.response_utils import success
|
|||||||
from app.db import get_db
|
from app.db import get_db
|
||||||
from app.dependencies import get_current_user
|
from app.dependencies import get_current_user
|
||||||
from app.models.user_model import User
|
from app.models.user_model import User
|
||||||
from app.schemas.memory_agent_schema import End_User_Information
|
|
||||||
from app.schemas.response_schema import ApiResponse
|
from app.schemas.response_schema import ApiResponse
|
||||||
|
|
||||||
from app.services import memory_dashboard_service, memory_storage_service, workspace_service
|
from app.services import memory_dashboard_service, memory_storage_service, workspace_service
|
||||||
@@ -40,54 +39,7 @@ def get_workspace_total_end_users(
|
|||||||
api_logger.info(f"成功获取最新用户总数: total_num={total_end_users.get('total_num', 0)}")
|
api_logger.info(f"成功获取最新用户总数: total_num={total_end_users.get('total_num', 0)}")
|
||||||
return success(data=total_end_users, msg="用户数量获取成功")
|
return success(data=total_end_users, msg="用户数量获取成功")
|
||||||
|
|
||||||
@router.post("/update/end_users", response_model=ApiResponse)
|
|
||||||
async def update_workspace_end_users(
|
|
||||||
user_input: End_User_Information,
|
|
||||||
db: Session = Depends(get_db),
|
|
||||||
current_user: User = Depends(get_current_user),
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
更新工作空间的宿主信息
|
|
||||||
"""
|
|
||||||
username = user_input.end_user_name # 要更新的用户名
|
|
||||||
end_user_input_id = user_input.id # 宿主ID
|
|
||||||
workspace_id = current_user.current_workspace_id
|
|
||||||
|
|
||||||
api_logger.info(f"用户 {current_user.username} 请求更新工作空间 {workspace_id} 的宿主信息")
|
|
||||||
api_logger.info(f"更新参数: username={username}, end_user_id={end_user_input_id}")
|
|
||||||
|
|
||||||
try:
|
|
||||||
# 导入更新函数
|
|
||||||
from app.repositories.end_user_repository import update_end_user_other_name
|
|
||||||
import uuid
|
|
||||||
|
|
||||||
# 转换 end_user_id 为 UUID 类型
|
|
||||||
end_user_uuid = uuid.UUID(end_user_input_id)
|
|
||||||
|
|
||||||
# 直接更新数据库中的 other_name 字段
|
|
||||||
updated_count = update_end_user_other_name(
|
|
||||||
db=db,
|
|
||||||
end_user_id=end_user_uuid,
|
|
||||||
other_name=username
|
|
||||||
)
|
|
||||||
|
|
||||||
api_logger.info(f"成功更新宿主 {end_user_input_id} 的 other_name 为: {username}")
|
|
||||||
|
|
||||||
return success(
|
|
||||||
data={
|
|
||||||
"updated_count": updated_count,
|
|
||||||
"end_user_id": end_user_input_id,
|
|
||||||
"updated_other_name": username
|
|
||||||
},
|
|
||||||
msg=f"成功更新 {updated_count} 个宿主的信息"
|
|
||||||
)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
api_logger.error(f"更新宿主信息失败: {str(e)}")
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
|
||||||
detail=f"更新宿主信息失败: {str(e)}"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -28,7 +28,6 @@ from app.services.memory_storage_service import (
|
|||||||
search_dialogue,
|
search_dialogue,
|
||||||
search_edges,
|
search_edges,
|
||||||
search_entity,
|
search_entity,
|
||||||
search_entity_graph,
|
|
||||||
search_statement,
|
search_statement,
|
||||||
)
|
)
|
||||||
from fastapi import APIRouter, Depends
|
from fastapi import APIRouter, Depends
|
||||||
@@ -412,21 +411,7 @@ async def search_entity_edges(
|
|||||||
api_logger.error(f"Search edges failed: {str(e)}")
|
api_logger.error(f"Search edges failed: {str(e)}")
|
||||||
return fail(BizCode.INTERNAL_ERROR, "边查询失败", str(e))
|
return fail(BizCode.INTERNAL_ERROR, "边查询失败", str(e))
|
||||||
|
|
||||||
@router.get("/search/entity_graph", response_model=ApiResponse)
|
|
||||||
async def search_for_entity_graph(
|
|
||||||
end_user_id: Optional[str] = None,
|
|
||||||
current_user: User = Depends(get_current_user),
|
|
||||||
) -> dict:
|
|
||||||
"""
|
|
||||||
搜索所有实体之间的关系网络
|
|
||||||
"""
|
|
||||||
api_logger.info(f"Search entity graph requested for end_user_id: {end_user_id}")
|
|
||||||
try:
|
|
||||||
result = await search_entity_graph(end_user_id)
|
|
||||||
return success(data=result, msg="查询成功")
|
|
||||||
except Exception as e:
|
|
||||||
api_logger.error(f"Search entity graph failed: {str(e)}")
|
|
||||||
return fail(BizCode.INTERNAL_ERROR, "实体图查询失败", str(e))
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/analytics/hot_memory_tags", response_model=ApiResponse)
|
@router.get("/analytics/hot_memory_tags", response_model=ApiResponse)
|
||||||
|
|||||||
@@ -39,7 +39,7 @@ async def write_memory_api_service(
|
|||||||
|
|
||||||
Stores memory content for the specified end user using the Memory API Service.
|
Stores memory content for the specified end user using the Memory API Service.
|
||||||
"""
|
"""
|
||||||
logger.info(f"Memory write request - end_user_id: {payload.end_user_id}")
|
logger.info(f"Memory write request - end_user_id: {payload.end_user_id}, tenant_id: {api_key_auth.tenant_id}")
|
||||||
|
|
||||||
memory_api_service = MemoryAPIService(db)
|
memory_api_service = MemoryAPIService(db)
|
||||||
|
|
||||||
@@ -50,6 +50,7 @@ async def write_memory_api_service(
|
|||||||
config_id=payload.config_id,
|
config_id=payload.config_id,
|
||||||
storage_type=payload.storage_type,
|
storage_type=payload.storage_type,
|
||||||
user_rag_memory_id=payload.user_rag_memory_id,
|
user_rag_memory_id=payload.user_rag_memory_id,
|
||||||
|
tenant_id=api_key_auth.tenant_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info(f"Memory write successful for end_user: {payload.end_user_id}")
|
logger.info(f"Memory write successful for end_user: {payload.end_user_id}")
|
||||||
|
|||||||
@@ -351,12 +351,11 @@ async def update_end_user_profile(
|
|||||||
|
|
||||||
该接口可以更新用户的姓名、职位、部门、联系方式、电话和入职日期等信息。
|
该接口可以更新用户的姓名、职位、部门、联系方式、电话和入职日期等信息。
|
||||||
所有字段都是可选的,只更新提供的字段。
|
所有字段都是可选的,只更新提供的字段。
|
||||||
|
|
||||||
"""
|
"""
|
||||||
workspace_id = current_user.current_workspace_id
|
workspace_id = current_user.current_workspace_id
|
||||||
end_user_id = profile_update.end_user_id
|
end_user_id = profile_update.end_user_id
|
||||||
|
|
||||||
# 检查用户是否已选择工作空间
|
# 验证工作空间
|
||||||
if workspace_id is None:
|
if workspace_id is None:
|
||||||
api_logger.warning(f"用户 {current_user.username} 尝试更新用户信息但未选择工作空间")
|
api_logger.warning(f"用户 {current_user.username} 尝试更新用户信息但未选择工作空间")
|
||||||
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
|
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
|
||||||
@@ -366,57 +365,24 @@ async def update_end_user_profile(
|
|||||||
f"workspace={workspace_id}"
|
f"workspace={workspace_id}"
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
# 调用 Service 层处理业务逻辑
|
||||||
# 查询终端用户
|
result = user_memory_service.update_end_user_profile(db, end_user_id, profile_update)
|
||||||
end_user = db.query(EndUser).filter(EndUser.id == end_user_id).first()
|
|
||||||
|
|
||||||
if not end_user:
|
if result["success"]:
|
||||||
api_logger.warning(f"终端用户不存在: end_user_id={end_user_id}")
|
api_logger.info(f"成功更新用户信息: end_user_id={end_user_id}")
|
||||||
return fail(BizCode.INVALID_PARAMETER, "终端用户不存在", f"end_user_id={end_user_id}")
|
return success(data=result["data"], msg="更新成功")
|
||||||
|
else:
|
||||||
|
error_msg = result["error"]
|
||||||
|
api_logger.error(f"用户信息更新失败: end_user_id={end_user_id}, error={error_msg}")
|
||||||
|
|
||||||
# 更新字段(只更新提供的字段,排除 end_user_id)
|
# 根据错误类型映射到合适的业务错误码
|
||||||
# 允许 None 值来重置字段(如 hire_date)
|
if error_msg == "终端用户不存在":
|
||||||
update_data = profile_update.model_dump(exclude_unset=True, exclude={'end_user_id'})
|
return fail(BizCode.USER_NOT_FOUND, "终端用户不存在", error_msg)
|
||||||
|
elif error_msg == "无效的用户ID格式":
|
||||||
# 特殊处理 hire_date:如果提供了时间戳,转换为 DateTime
|
return fail(BizCode.INVALID_USER_ID, "无效的用户ID格式", error_msg)
|
||||||
if 'hire_date' in update_data:
|
else:
|
||||||
hire_date_timestamp = update_data['hire_date']
|
# 只有未预期的错误才使用 INTERNAL_ERROR
|
||||||
if hire_date_timestamp is not None:
|
return fail(BizCode.INTERNAL_ERROR, "用户信息更新失败", error_msg)
|
||||||
update_data['hire_date'] = timestamp_to_datetime(hire_date_timestamp)
|
|
||||||
# 如果是 None,保持 None(允许清空)
|
|
||||||
|
|
||||||
for field, value in update_data.items():
|
|
||||||
setattr(end_user, field, value)
|
|
||||||
|
|
||||||
# 更新 updated_at 时间戳
|
|
||||||
end_user.updated_at = datetime.datetime.now()
|
|
||||||
|
|
||||||
# 更新 updatetime_profile 为当前时间
|
|
||||||
end_user.updatetime_profile = datetime.datetime.now()
|
|
||||||
|
|
||||||
# 提交更改
|
|
||||||
db.commit()
|
|
||||||
db.refresh(end_user)
|
|
||||||
|
|
||||||
# 构建响应数据
|
|
||||||
profile_data = EndUserProfileResponse(
|
|
||||||
id=end_user.id,
|
|
||||||
other_name=end_user.other_name,
|
|
||||||
position=end_user.position,
|
|
||||||
department=end_user.department,
|
|
||||||
contact=end_user.contact,
|
|
||||||
phone=end_user.phone,
|
|
||||||
hire_date=end_user.hire_date,
|
|
||||||
updatetime_profile=end_user.updatetime_profile
|
|
||||||
)
|
|
||||||
|
|
||||||
api_logger.info(f"成功更新用户信息: end_user_id={end_user_id}, updated_fields={list(update_data.keys())}")
|
|
||||||
return success(data=UserMemoryService.convert_profile_to_dict_with_timestamp(profile_data), msg="更新成功")
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
db.rollback()
|
|
||||||
api_logger.error(f"用户信息更新失败: end_user_id={end_user_id}, error={str(e)}")
|
|
||||||
return fail(BizCode.INTERNAL_ERROR, "用户信息更新失败", str(e))
|
|
||||||
|
|
||||||
@router.get("/memory_space/timeline_memories", response_model=ApiResponse)
|
@router.get("/memory_space/timeline_memories", response_model=ApiResponse)
|
||||||
async def memory_space_timeline_of_shared_memories(id: str, label: str,language_type: str="zh",
|
async def memory_space_timeline_of_shared_memories(id: str, label: str,language_type: str="zh",
|
||||||
|
|||||||
@@ -154,7 +154,7 @@ class LangChainAgent:
|
|||||||
userid=end_user_end,
|
userid=end_user_end,
|
||||||
messages=messages,
|
messages=messages,
|
||||||
apply_id=end_user_end,
|
apply_id=end_user_end,
|
||||||
group_id=end_user_end,
|
end_user_id=end_user_end,
|
||||||
aimessages=aimessages
|
aimessages=aimessages
|
||||||
)
|
)
|
||||||
store.delete_duplicate_sessions()
|
store.delete_duplicate_sessions()
|
||||||
@@ -173,16 +173,67 @@ class LangChainAgent:
|
|||||||
retrieved_content.append({query: aimessages})
|
retrieved_content.append({query: aimessages})
|
||||||
return messagss_list,retrieved_content
|
return messagss_list,retrieved_content
|
||||||
|
|
||||||
|
async def write(self, storage_type, end_user_id, user_message, ai_message, user_rag_memory_id, actual_end_user_id, actual_config_id):
|
||||||
|
"""
|
||||||
|
写入记忆(支持结构化消息)
|
||||||
|
|
||||||
async def write(self,storage_type,end_user_id,message,user_rag_memory_id,actual_end_user_id,content,actual_config_id):
|
Args:
|
||||||
|
storage_type: 存储类型 (neo4j/rag)
|
||||||
|
end_user_id: 终端用户ID
|
||||||
|
user_message: 用户消息内容
|
||||||
|
ai_message: AI 回复内容
|
||||||
|
user_rag_memory_id: RAG 记忆ID
|
||||||
|
actual_end_user_id: 实际用户ID
|
||||||
|
actual_config_id: 配置ID
|
||||||
|
|
||||||
|
逻辑说明:
|
||||||
|
- RAG 模式:组合 user_message 和 ai_message 为字符串格式,保持原有逻辑不变
|
||||||
|
- Neo4j 模式:使用结构化消息列表
|
||||||
|
1. 如果 user_message 和 ai_message 都不为空:创建配对消息 [user, assistant]
|
||||||
|
2. 如果只有 user_message:创建单条用户消息 [user](用于历史记忆场景)
|
||||||
|
3. 每条消息会被转换为独立的 Chunk,保留 speaker 字段
|
||||||
|
"""
|
||||||
if storage_type == "rag":
|
if storage_type == "rag":
|
||||||
await write_rag(end_user_id, message, user_rag_memory_id)
|
# RAG 模式:组合消息为字符串格式(保持原有逻辑)
|
||||||
|
combined_message = f"user: {user_message}\nassistant: {ai_message}"
|
||||||
|
await write_rag(end_user_id, combined_message, user_rag_memory_id)
|
||||||
logger.info(f'RAG_Agent:{end_user_id};{user_rag_memory_id}')
|
logger.info(f'RAG_Agent:{end_user_id};{user_rag_memory_id}')
|
||||||
else:
|
else:
|
||||||
write_id = write_message_task.delay(actual_end_user_id, content, actual_config_id, storage_type,
|
# Neo4j 模式:使用结构化消息列表
|
||||||
user_rag_memory_id)
|
structured_messages = []
|
||||||
|
|
||||||
|
# 始终添加用户消息(如果不为空)
|
||||||
|
if user_message:
|
||||||
|
structured_messages.append({"role": "user", "content": user_message})
|
||||||
|
|
||||||
|
# 只有当 AI 回复不为空时才添加 assistant 消息
|
||||||
|
if ai_message:
|
||||||
|
structured_messages.append({"role": "assistant", "content": ai_message})
|
||||||
|
|
||||||
|
# 如果没有消息,直接返回
|
||||||
|
if not structured_messages:
|
||||||
|
logger.warning(f"No messages to write for user {actual_end_user_id}")
|
||||||
|
return
|
||||||
|
|
||||||
|
# 调用 Celery 任务,传递结构化消息列表
|
||||||
|
# 数据流:
|
||||||
|
# 1. structured_messages 传递给 write_message_task
|
||||||
|
# 2. write_message_task 调用 memory_agent_service.write_memory
|
||||||
|
# 3. write_memory 调用 write_tools.write,传递 messages 参数
|
||||||
|
# 4. write_tools.write 调用 get_chunked_dialogs,传递 messages 参数
|
||||||
|
# 5. get_chunked_dialogs 为每条消息创建独立的 Chunk,设置 speaker 字段
|
||||||
|
# 6. 每个 Chunk 保存到 Neo4j,包含 speaker 字段
|
||||||
|
logger.info(f"[WRITE] Submitting Celery task - user={actual_end_user_id}, messages={len(structured_messages)}, config={actual_config_id}")
|
||||||
|
write_id = write_message_task.delay(
|
||||||
|
actual_end_user_id, # group_id: 用户ID
|
||||||
|
structured_messages, # message: 结构化消息列表 [{"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]
|
||||||
|
actual_config_id, # config_id: 配置ID
|
||||||
|
storage_type, # storage_type: "neo4j"
|
||||||
|
user_rag_memory_id # user_rag_memory_id: RAG记忆ID(Neo4j模式下不使用)
|
||||||
|
)
|
||||||
|
logger.info(f"[WRITE] Celery task submitted - task_id={write_id}")
|
||||||
write_status = get_task_memory_write_result(str(write_id))
|
write_status = get_task_memory_write_result(str(write_id))
|
||||||
logger.info(f'Agent:{actual_end_user_id};{write_status}')
|
logger.info(f'[WRITE] Task result - user={actual_end_user_id}, status={write_status}')
|
||||||
|
|
||||||
async def chat(
|
async def chat(
|
||||||
self,
|
self,
|
||||||
@@ -227,29 +278,30 @@ class LangChainAgent:
|
|||||||
actual_end_user_id = end_user_id if end_user_id is not None else "unknown"
|
actual_end_user_id = end_user_id if end_user_id is not None else "unknown"
|
||||||
logger.info(f'写入类型{storage_type,str(end_user_id), message, str(user_rag_memory_id)}')
|
logger.info(f'写入类型{storage_type,str(end_user_id), message, str(user_rag_memory_id)}')
|
||||||
print(f'写入类型{storage_type,str(end_user_id), message, str(user_rag_memory_id)}')
|
print(f'写入类型{storage_type,str(end_user_id), message, str(user_rag_memory_id)}')
|
||||||
|
# # TODO 乐力齐,在长短期记忆存储的时候再使用此代码
|
||||||
|
# history_term_memory_result = await self.term_memory_redis_read(end_user_id)
|
||||||
|
# history_term_memory = history_term_memory_result[0]
|
||||||
|
# db_for_memory = next(get_db())
|
||||||
|
# if memory_flag:
|
||||||
|
# if len(history_term_memory)>=4 and storage_type != "rag":
|
||||||
|
# history_term_memory = ';'.join(history_term_memory)
|
||||||
|
# retrieved_content = history_term_memory_result[1]
|
||||||
|
# print(retrieved_content)
|
||||||
|
# # 为长期记忆操作获取新的数据库连接
|
||||||
|
# try:
|
||||||
|
# repo = LongTermMemoryRepository(db_for_memory)
|
||||||
|
# repo.upsert(end_user_id, retrieved_content)
|
||||||
|
# logger.info(
|
||||||
|
# f'写入短长期:{storage_type, str(end_user_id), history_term_memory, str(user_rag_memory_id)}')
|
||||||
|
# except Exception as e:
|
||||||
|
# logger.error(f"Failed to write to LongTermMemory: {e}")
|
||||||
|
# raise
|
||||||
|
# finally:
|
||||||
|
# db_for_memory.close()
|
||||||
|
|
||||||
history_term_memory_result = await self.term_memory_redis_read(end_user_id)
|
# # 长期记忆写入(
|
||||||
history_term_memory = history_term_memory_result[0]
|
# await self.write(storage_type, actual_end_user_id, history_term_memory, "", user_rag_memory_id, actual_end_user_id, actual_config_id)
|
||||||
db_for_memory = next(get_db())
|
# # 注意:不在这里写入用户消息,等 AI 回复后一起写入
|
||||||
if memory_flag:
|
|
||||||
if len(history_term_memory)>=4 and storage_type != "rag":
|
|
||||||
history_term_memory = ';'.join(history_term_memory)
|
|
||||||
retrieved_content = history_term_memory_result[1]
|
|
||||||
print(retrieved_content)
|
|
||||||
# 为长期记忆操作获取新的数据库连接
|
|
||||||
try:
|
|
||||||
repo = LongTermMemoryRepository(db_for_memory)
|
|
||||||
repo.upsert(end_user_id, retrieved_content)
|
|
||||||
logger.info(
|
|
||||||
f'写入短长期:{storage_type, str(end_user_id), history_term_memory, str(user_rag_memory_id)}')
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Failed to write to LongTermMemory: {e}")
|
|
||||||
raise
|
|
||||||
finally:
|
|
||||||
db_for_memory.close()
|
|
||||||
|
|
||||||
await self.write(storage_type,end_user_id,history_term_memory,user_rag_memory_id,actual_end_user_id,history_term_memory,actual_config_id)
|
|
||||||
await self.write(storage_type,end_user_id,message,user_rag_memory_id,actual_end_user_id,message,actual_config_id)
|
|
||||||
try:
|
try:
|
||||||
# 准备消息列表
|
# 准备消息列表
|
||||||
messages = self._prepare_messages(message, history, context)
|
messages = self._prepare_messages(message, history, context)
|
||||||
@@ -277,8 +329,10 @@ class LangChainAgent:
|
|||||||
|
|
||||||
elapsed_time = time.time() - start_time
|
elapsed_time = time.time() - start_time
|
||||||
if memory_flag:
|
if memory_flag:
|
||||||
await self.write(storage_type,end_user_id,content,user_rag_memory_id,actual_end_user_id,content,actual_config_id)
|
# AI 回复写入(用户消息和 AI 回复配对,一次性写入完整对话)
|
||||||
await self.term_memory_save(message_chat,end_user_id,content)
|
await self.write(storage_type, actual_end_user_id, message_chat, content, user_rag_memory_id, actual_end_user_id, actual_config_id)
|
||||||
|
# TODO 乐力齐 - 累积多组对话批量写入功能已禁用
|
||||||
|
# await self.term_memory_save(message_chat, end_user_id, content)
|
||||||
response = {
|
response = {
|
||||||
"content": content,
|
"content": content,
|
||||||
"model": self.model_name,
|
"model": self.model_name,
|
||||||
@@ -346,27 +400,27 @@ class LangChainAgent:
|
|||||||
db.close()
|
db.close()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Failed to get db session: {e}")
|
logger.warning(f"Failed to get db session: {e}")
|
||||||
|
# # TODO 乐力齐
|
||||||
|
# history_term_memory_result = await self.term_memory_redis_read(end_user_id)
|
||||||
|
# history_term_memory = history_term_memory_result[0]
|
||||||
|
# if memory_flag:
|
||||||
|
# if len(history_term_memory) >= 4 and storage_type != "rag":
|
||||||
|
# history_term_memory = ';'.join(history_term_memory)
|
||||||
|
# retrieved_content = history_term_memory_result[1]
|
||||||
|
# db_for_memory = next(get_db())
|
||||||
|
# try:
|
||||||
|
# repo = LongTermMemoryRepository(db_for_memory)
|
||||||
|
# repo.upsert(end_user_id, retrieved_content)
|
||||||
|
# logger.info(
|
||||||
|
# f'写入短长期:{storage_type, str(end_user_id), history_term_memory, str(user_rag_memory_id)}')
|
||||||
|
# # 长期记忆写入
|
||||||
|
# await self.write(storage_type, end_user_id, history_term_memory, "", user_rag_memory_id, end_user_id, actual_config_id)
|
||||||
|
# except Exception as e:
|
||||||
|
# logger.error(f"Failed to write to long term memory: {e}")
|
||||||
|
# finally:
|
||||||
|
# db_for_memory.close()
|
||||||
|
|
||||||
history_term_memory_result = await self.term_memory_redis_read(end_user_id)
|
# 注意:不在这里写入用户消息,等 AI 回复后一起写入
|
||||||
history_term_memory = history_term_memory_result[0]
|
|
||||||
if memory_flag:
|
|
||||||
if len(history_term_memory) >= 4 and storage_type != "rag":
|
|
||||||
history_term_memory = ';'.join(history_term_memory)
|
|
||||||
retrieved_content = history_term_memory_result[1]
|
|
||||||
db_for_memory = next(get_db())
|
|
||||||
try:
|
|
||||||
repo = LongTermMemoryRepository(db_for_memory)
|
|
||||||
repo.upsert(end_user_id, retrieved_content)
|
|
||||||
logger.info(
|
|
||||||
f'写入短长期:{storage_type, str(end_user_id), history_term_memory, str(user_rag_memory_id)}')
|
|
||||||
await self.write(storage_type, end_user_id, history_term_memory, user_rag_memory_id, end_user_id,
|
|
||||||
history_term_memory, actual_config_id)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Failed to write to long term memory: {e}")
|
|
||||||
finally:
|
|
||||||
db_for_memory.close()
|
|
||||||
|
|
||||||
await self.write(storage_type, end_user_id, message, user_rag_memory_id, end_user_id, message, actual_config_id)
|
|
||||||
try:
|
try:
|
||||||
# 准备消息列表
|
# 准备消息列表
|
||||||
messages = self._prepare_messages(message, history, context)
|
messages = self._prepare_messages(message, history, context)
|
||||||
@@ -418,8 +472,10 @@ class LangChainAgent:
|
|||||||
|
|
||||||
logger.debug(f"Agent 流式完成,共 {chunk_count} 个事件")
|
logger.debug(f"Agent 流式完成,共 {chunk_count} 个事件")
|
||||||
if memory_flag:
|
if memory_flag:
|
||||||
await self.write(storage_type, end_user_id,full_content, user_rag_memory_id, end_user_id,full_content, actual_config_id)
|
# AI 回复写入(用户消息和 AI 回复配对,一次性写入完整对话)
|
||||||
await self.term_memory_save(message_chat, end_user_id, full_content)
|
await self.write(storage_type, end_user_id, message_chat, full_content, user_rag_memory_id, end_user_id, actual_config_id)
|
||||||
|
# TODO 乐力齐 - 累积多组对话批量写入功能已禁用
|
||||||
|
# await self.term_memory_save(message_chat, end_user_id, full_content)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Agent astream_events 失败: {str(e)}", exc_info=True)
|
logger.error(f"Agent astream_events 失败: {str(e)}", exc_info=True)
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ template_root = os.path.join(PROJECT_ROOT_, 'agent', 'utils', 'prompt')
|
|||||||
db_session = next(get_db())
|
db_session = next(get_db())
|
||||||
logger = get_agent_logger(__name__)
|
logger = get_agent_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class ProblemNodeService(LLMServiceMixin):
|
class ProblemNodeService(LLMServiceMixin):
|
||||||
"""问题处理节点服务类"""
|
"""问题处理节点服务类"""
|
||||||
|
|
||||||
@@ -25,17 +26,19 @@ class ProblemNodeService(LLMServiceMixin):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
self.template_service = TemplateService(template_root)
|
self.template_service = TemplateService(template_root)
|
||||||
|
|
||||||
|
|
||||||
# 创建全局服务实例
|
# 创建全局服务实例
|
||||||
problem_service = ProblemNodeService()
|
problem_service = ProblemNodeService()
|
||||||
|
|
||||||
|
|
||||||
async def Split_The_Problem(state: ReadState) -> ReadState:
|
async def Split_The_Problem(state: ReadState) -> ReadState:
|
||||||
"""问题分解节点"""
|
"""问题分解节点"""
|
||||||
# 从状态中获取数据
|
# 从状态中获取数据
|
||||||
content = state.get('data', '')
|
content = state.get('data', '')
|
||||||
group_id = state.get('group_id', '')
|
end_user_id = state.get('end_user_id', '')
|
||||||
memory_config = state.get('memory_config', None)
|
memory_config = state.get('memory_config', None)
|
||||||
|
|
||||||
history = await SessionService(store).get_history(group_id, group_id, group_id)
|
history = await SessionService(store).get_history(end_user_id, end_user_id, end_user_id)
|
||||||
|
|
||||||
# 生成 JSON schema 以指导 LLM 输出正确格式
|
# 生成 JSON schema 以指导 LLM 输出正确格式
|
||||||
json_schema = ProblemExtensionResponse.model_json_schema()
|
json_schema = ProblemExtensionResponse.model_json_schema()
|
||||||
@@ -77,7 +80,7 @@ async def Split_The_Problem(state: ReadState) -> ReadState:
|
|||||||
split_result_dict = []
|
split_result_dict = []
|
||||||
for index, item in enumerate(json.loads(split_result)):
|
for index, item in enumerate(json.loads(split_result)):
|
||||||
split_data = {
|
split_data = {
|
||||||
"id": f"Q{index+1}",
|
"id": f"Q{index + 1}",
|
||||||
"question": item['extended_question'],
|
"question": item['extended_question'],
|
||||||
"type": item['type'],
|
"type": item['type'],
|
||||||
"reason": item['reason']
|
"reason": item['reason']
|
||||||
@@ -130,13 +133,14 @@ async def Split_The_Problem(state: ReadState) -> ReadState:
|
|||||||
# 返回更新后的状态,包含spit_context字段
|
# 返回更新后的状态,包含spit_context字段
|
||||||
return {"spit_data": result}
|
return {"spit_data": result}
|
||||||
|
|
||||||
|
|
||||||
async def Problem_Extension(state: ReadState) -> ReadState:
|
async def Problem_Extension(state: ReadState) -> ReadState:
|
||||||
"""问题扩展节点"""
|
"""问题扩展节点"""
|
||||||
# 获取原始数据和分解结果
|
# 获取原始数据和分解结果
|
||||||
start = time.time()
|
start = time.time()
|
||||||
content = state.get('data', '')
|
content = state.get('data', '')
|
||||||
data = state.get('spit_data', '')['context']
|
data = state.get('spit_data', '')['context']
|
||||||
group_id = state.get('group_id', '')
|
end_user_id = state.get('end_user_id', '')
|
||||||
storage_type = state.get('storage_type', '')
|
storage_type = state.get('storage_type', '')
|
||||||
user_rag_memory_id = state.get('user_rag_memory_id', '')
|
user_rag_memory_id = state.get('user_rag_memory_id', '')
|
||||||
memory_config = state.get('memory_config', None)
|
memory_config = state.get('memory_config', None)
|
||||||
@@ -152,7 +156,7 @@ async def Problem_Extension(state: ReadState) -> ReadState:
|
|||||||
databasets = {}
|
databasets = {}
|
||||||
data = []
|
data = []
|
||||||
|
|
||||||
history = await SessionService(store).get_history(group_id, group_id, group_id)
|
history = await SessionService(store).get_history(end_user_id, end_user_id, end_user_id)
|
||||||
|
|
||||||
# 生成 JSON schema 以指导 LLM 输出正确格式
|
# 生成 JSON schema 以指导 LLM 输出正确格式
|
||||||
json_schema = ProblemExtensionResponse.model_json_schema()
|
json_schema = ProblemExtensionResponse.model_json_schema()
|
||||||
@@ -243,6 +247,3 @@ async def Problem_Extension(state: ReadState) -> ReadState:
|
|||||||
}
|
}
|
||||||
|
|
||||||
return {"problem_extension": result}
|
return {"problem_extension": result}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -52,9 +52,9 @@ async def rag_config(state):
|
|||||||
return kb_config
|
return kb_config
|
||||||
async def rag_knowledge(state,question):
|
async def rag_knowledge(state,question):
|
||||||
kb_config = await rag_config(state)
|
kb_config = await rag_config(state)
|
||||||
group_id = state.get('group_id', '')
|
end_user_id = state.get('end_user_id', '')
|
||||||
user_rag_memory_id=state.get("user_rag_memory_id",'')
|
user_rag_memory_id=state.get("user_rag_memory_id",'')
|
||||||
retrieve_chunks_result = knowledge_retrieval(question, kb_config, [str(group_id)])
|
retrieve_chunks_result = knowledge_retrieval(question, kb_config, [str(end_user_id)])
|
||||||
try:
|
try:
|
||||||
retrieval_knowledge = [i.page_content for i in retrieve_chunks_result]
|
retrieval_knowledge = [i.page_content for i in retrieve_chunks_result]
|
||||||
clean_content = '\n\n'.join(retrieval_knowledge)
|
clean_content = '\n\n'.join(retrieval_knowledge)
|
||||||
@@ -159,7 +159,7 @@ async def retrieve_nodes(state: ReadState) -> ReadState:
|
|||||||
problem_extension=state.get('problem_extension', '')['context']
|
problem_extension=state.get('problem_extension', '')['context']
|
||||||
storage_type=state.get('storage_type', '')
|
storage_type=state.get('storage_type', '')
|
||||||
user_rag_memory_id=state.get('user_rag_memory_id', '')
|
user_rag_memory_id=state.get('user_rag_memory_id', '')
|
||||||
group_id=state.get('group_id', '')
|
end_user_id=state.get('end_user_id', '')
|
||||||
memory_config = state.get('memory_config', None)
|
memory_config = state.get('memory_config', None)
|
||||||
original=state.get('data', '')
|
original=state.get('data', '')
|
||||||
problem_list=[]
|
problem_list=[]
|
||||||
@@ -172,7 +172,7 @@ async def retrieve_nodes(state: ReadState) -> ReadState:
|
|||||||
try:
|
try:
|
||||||
# Prepare search parameters based on storage type
|
# Prepare search parameters based on storage type
|
||||||
search_params = {
|
search_params = {
|
||||||
"group_id": group_id,
|
"end_user_id": end_user_id,
|
||||||
"question": question,
|
"question": question,
|
||||||
"return_raw_results": True
|
"return_raw_results": True
|
||||||
}
|
}
|
||||||
@@ -263,13 +263,13 @@ async def retrieve_nodes(state: ReadState) -> ReadState:
|
|||||||
|
|
||||||
|
|
||||||
async def retrieve(state: ReadState) -> ReadState:
|
async def retrieve(state: ReadState) -> ReadState:
|
||||||
# 从state中获取group_id
|
# 从state中获取end_user_id
|
||||||
import time
|
import time
|
||||||
start=time.time()
|
start=time.time()
|
||||||
problem_extension = state.get('problem_extension', '')['context']
|
problem_extension = state.get('problem_extension', '')['context']
|
||||||
storage_type = state.get('storage_type', '')
|
storage_type = state.get('storage_type', '')
|
||||||
user_rag_memory_id = state.get('user_rag_memory_id', '')
|
user_rag_memory_id = state.get('user_rag_memory_id', '')
|
||||||
group_id = state.get('group_id', '')
|
end_user_id = state.get('end_user_id', '')
|
||||||
memory_config = state.get('memory_config', None)
|
memory_config = state.get('memory_config', None)
|
||||||
original = state.get('data', '')
|
original = state.get('data', '')
|
||||||
problem_list = []
|
problem_list = []
|
||||||
@@ -295,13 +295,13 @@ async def retrieve(state: ReadState) -> ReadState:
|
|||||||
temperature=0.2,
|
temperature=0.2,
|
||||||
)
|
)
|
||||||
|
|
||||||
time_retrieval_tool = create_time_retrieval_tool(group_id)
|
time_retrieval_tool = create_time_retrieval_tool(end_user_id)
|
||||||
search_params = { "group_id": group_id, "return_raw_results": True }
|
search_params = { "end_user_id": end_user_id, "return_raw_results": True }
|
||||||
hybrid_retrieval=create_hybrid_retrieval_tool_sync(memory_config, **search_params)
|
hybrid_retrieval=create_hybrid_retrieval_tool_sync(memory_config, **search_params)
|
||||||
agent = create_agent(
|
agent = create_agent(
|
||||||
llm,
|
llm,
|
||||||
tools=[time_retrieval_tool,hybrid_retrieval],
|
tools=[time_retrieval_tool,hybrid_retrieval],
|
||||||
system_prompt=f"我是检索专家,可以根据适合的工具进行检索。当前使用的group_id是: {group_id}"
|
system_prompt=f"我是检索专家,可以根据适合的工具进行检索。当前使用的end_user_id是: {end_user_id}"
|
||||||
)
|
)
|
||||||
|
|
||||||
# 创建异步任务处理单个问题
|
# 创建异步任务处理单个问题
|
||||||
|
|||||||
@@ -4,12 +4,11 @@ import os
|
|||||||
import time
|
import time
|
||||||
|
|
||||||
from app.core.logging_config import get_agent_logger, log_time
|
from app.core.logging_config import get_agent_logger, log_time
|
||||||
from app.db import get_db
|
|
||||||
|
|
||||||
from app.core.memory.agent.models.summary_models import (
|
from app.core.memory.agent.models.summary_models import (
|
||||||
RetrieveSummaryResponse,
|
RetrieveSummaryResponse,
|
||||||
SummaryResponse,
|
SummaryResponse,
|
||||||
)
|
)
|
||||||
|
from app.core.memory.agent.services.optimized_llm_service import LLMServiceMixin
|
||||||
from app.core.memory.agent.services.search_service import SearchService
|
from app.core.memory.agent.services.search_service import SearchService
|
||||||
from app.core.memory.agent.utils.llm_tools import (
|
from app.core.memory.agent.utils.llm_tools import (
|
||||||
PROJECT_ROOT_,
|
PROJECT_ROOT_,
|
||||||
@@ -18,7 +17,7 @@ from app.core.memory.agent.utils.llm_tools import (
|
|||||||
from app.core.memory.agent.utils.redis_tool import store
|
from app.core.memory.agent.utils.redis_tool import store
|
||||||
from app.core.memory.agent.utils.session_tools import SessionService
|
from app.core.memory.agent.utils.session_tools import SessionService
|
||||||
from app.core.memory.agent.utils.template_tools import TemplateService
|
from app.core.memory.agent.utils.template_tools import TemplateService
|
||||||
from app.core.memory.agent.services.optimized_llm_service import LLMServiceMixin
|
from app.db import get_db
|
||||||
|
|
||||||
template_root = os.path.join(PROJECT_ROOT_, 'agent', 'utils', 'prompt')
|
template_root = os.path.join(PROJECT_ROOT_, 'agent', 'utils', 'prompt')
|
||||||
logger = get_agent_logger(__name__)
|
logger = get_agent_logger(__name__)
|
||||||
@@ -35,8 +34,8 @@ class SummaryNodeService(LLMServiceMixin):
|
|||||||
summary_service = SummaryNodeService()
|
summary_service = SummaryNodeService()
|
||||||
|
|
||||||
async def summary_history(state: ReadState) -> ReadState:
|
async def summary_history(state: ReadState) -> ReadState:
|
||||||
group_id = state.get("group_id", '')
|
end_user_id = state.get("end_user_id", '')
|
||||||
history = await SessionService(store).get_history(group_id, group_id, group_id)
|
history = await SessionService(store).get_history(end_user_id, end_user_id, end_user_id)
|
||||||
return history
|
return history
|
||||||
|
|
||||||
async def summary_llm(state: ReadState, history, retrieve_info, template_name, operation_name, response_model,search_mode) -> str:
|
async def summary_llm(state: ReadState, history, retrieve_info, template_name, operation_name, response_model,search_mode) -> str:
|
||||||
@@ -123,12 +122,12 @@ async def summary_llm(state: ReadState, history, retrieve_info, template_name, o
|
|||||||
|
|
||||||
async def summary_redis_save(state: ReadState,aimessages) -> ReadState:
|
async def summary_redis_save(state: ReadState,aimessages) -> ReadState:
|
||||||
data = state.get("data", '')
|
data = state.get("data", '')
|
||||||
group_id = state.get("group_id", '')
|
end_user_id = state.get("end_user_id", '')
|
||||||
await SessionService(store).save_session(
|
await SessionService(store).save_session(
|
||||||
user_id=group_id,
|
user_id=end_user_id,
|
||||||
query=data,
|
query=data,
|
||||||
apply_id=group_id,
|
apply_id=end_user_id,
|
||||||
group_id=group_id,
|
end_user_id=end_user_id,
|
||||||
ai_response=aimessages
|
ai_response=aimessages
|
||||||
)
|
)
|
||||||
await SessionService(store).cleanup_duplicates()
|
await SessionService(store).cleanup_duplicates()
|
||||||
@@ -176,13 +175,14 @@ async def Input_Summary(state: ReadState) -> ReadState:
|
|||||||
memory_config = state.get('memory_config', None)
|
memory_config = state.get('memory_config', None)
|
||||||
user_rag_memory_id=state.get("user_rag_memory_id",'')
|
user_rag_memory_id=state.get("user_rag_memory_id",'')
|
||||||
data=state.get("data", '')
|
data=state.get("data", '')
|
||||||
group_id=state.get("group_id", '')
|
end_user_id=state.get("end_user_id", '')
|
||||||
logger.info(f"Input_Summary: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}")
|
logger.info(f"Input_Summary: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}")
|
||||||
history = await summary_history( state)
|
history = await summary_history( state)
|
||||||
search_params = {
|
search_params = {
|
||||||
"group_id": group_id,
|
"end_user_id": end_user_id,
|
||||||
"question": data,
|
"question": data,
|
||||||
"return_raw_results": True
|
"return_raw_results": True,
|
||||||
|
"include": ["summaries"] # Only search summary nodes for faster performance
|
||||||
}
|
}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|||||||
@@ -62,12 +62,12 @@ async def Verify(state: ReadState):
|
|||||||
logger.info("=== Verify 节点开始执行 ===")
|
logger.info("=== Verify 节点开始执行 ===")
|
||||||
try:
|
try:
|
||||||
content = state.get('data', '')
|
content = state.get('data', '')
|
||||||
group_id = state.get('group_id', '')
|
end_user_id = state.get('end_user_id', '')
|
||||||
memory_config = state.get('memory_config', None)
|
memory_config = state.get('memory_config', None)
|
||||||
|
|
||||||
logger.info(f"Verify: content={content[:50] if content else 'empty'}..., group_id={group_id}")
|
logger.info(f"Verify: content={content[:50] if content else 'empty'}..., end_user_id={end_user_id}")
|
||||||
|
|
||||||
history = await SessionService(store).get_history(group_id, group_id, group_id)
|
history = await SessionService(store).get_history(end_user_id, end_user_id, end_user_id)
|
||||||
logger.info(f"Verify: 获取历史记录完成,history length={len(history)}")
|
logger.info(f"Verify: 获取历史记录完成,history length={len(history)}")
|
||||||
|
|
||||||
retrieve = state.get("retrieve", {})
|
retrieve = state.get("retrieve", {})
|
||||||
|
|||||||
@@ -9,26 +9,21 @@ async def write_node(state: WriteState) -> WriteState:
|
|||||||
Write data to the database/file system.
|
Write data to the database/file system.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
ctx: FastMCP context for dependency injection
|
|
||||||
content: Data content to write
|
content: Data content to write
|
||||||
user_id: User identifier
|
end_user_id: End user identifier
|
||||||
apply_id: Application identifier
|
|
||||||
group_id: Group identifier
|
|
||||||
memory_config: MemoryConfig object containing all configuration
|
memory_config: MemoryConfig object containing all configuration
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
dict: Contains 'status', 'saved_to', and 'data' fields
|
dict: Contains 'status', 'saved_to', and 'data' fields
|
||||||
"""
|
"""
|
||||||
content=state.get('data','')
|
content=state.get('data','')
|
||||||
group_id=state.get('group_id','')
|
end_user_id=state.get('end_user_id','')
|
||||||
memory_config=state.get('memory_config', '')
|
memory_config=state.get('memory_config', '')
|
||||||
try:
|
try:
|
||||||
result=await write(
|
result=await write(
|
||||||
content=content,
|
end_user_id=end_user_id,
|
||||||
user_id=group_id,
|
|
||||||
apply_id=group_id,
|
|
||||||
group_id=group_id,
|
|
||||||
memory_config=memory_config,
|
memory_config=memory_config,
|
||||||
|
messages=content, # 修复:使用正确的参数名 messages
|
||||||
)
|
)
|
||||||
logger.info(f"Write completed successfully! Config: {memory_config.config_name}")
|
logger.info(f"Write completed successfully! Config: {memory_config.config_name}")
|
||||||
|
|
||||||
|
|||||||
@@ -59,7 +59,6 @@ async def make_read_graph():
|
|||||||
workflow.add_conditional_edges("Retrieve", Retrieve_continue)
|
workflow.add_conditional_edges("Retrieve", Retrieve_continue)
|
||||||
workflow.add_edge("Retrieve_Summary", END)
|
workflow.add_edge("Retrieve_Summary", END)
|
||||||
workflow.add_conditional_edges("Verify", Verify_continue)
|
workflow.add_conditional_edges("Verify", Verify_continue)
|
||||||
|
|
||||||
workflow.add_edge("Summary_fails", END)
|
workflow.add_edge("Summary_fails", END)
|
||||||
workflow.add_edge("Summary", END)
|
workflow.add_edge("Summary", END)
|
||||||
|
|
||||||
@@ -80,7 +79,7 @@ async def make_read_graph():
|
|||||||
async def main():
|
async def main():
|
||||||
"""主函数 - 运行工作流"""
|
"""主函数 - 运行工作流"""
|
||||||
message = "昨天有什么好看的电影"
|
message = "昨天有什么好看的电影"
|
||||||
group_id = '88a459f5_text09' # 组ID
|
end_user_id = '88a459f5_text09' # 组ID
|
||||||
storage_type = 'neo4j' # 存储类型
|
storage_type = 'neo4j' # 存储类型
|
||||||
search_switch = '1' # 搜索开关
|
search_switch = '1' # 搜索开关
|
||||||
user_rag_memory_id = 'wwwwwwww' # 用户RAG记忆ID
|
user_rag_memory_id = 'wwwwwwww' # 用户RAG记忆ID
|
||||||
@@ -96,9 +95,9 @@ async def main():
|
|||||||
start=time.time()
|
start=time.time()
|
||||||
try:
|
try:
|
||||||
async with make_read_graph() as graph:
|
async with make_read_graph() as graph:
|
||||||
config = {"configurable": {"thread_id": group_id}}
|
config = {"configurable": {"thread_id": end_user_id}}
|
||||||
# 初始状态 - 包含所有必要字段
|
# 初始状态 - 包含所有必要字段
|
||||||
initial_state = {"messages": [HumanMessage(content=message)] ,"search_switch":search_switch,"group_id":group_id
|
initial_state = {"messages": [HumanMessage(content=message)] ,"search_switch":search_switch,"end_user_id":end_user_id
|
||||||
,"storage_type":storage_type,"user_rag_memory_id":user_rag_memory_id,"memory_config":memory_config}
|
,"storage_type":storage_type,"user_rag_memory_id":user_rag_memory_id,"memory_config":memory_config}
|
||||||
# 获取节点更新信息
|
# 获取节点更新信息
|
||||||
_intermediate_outputs = []
|
_intermediate_outputs = []
|
||||||
|
|||||||
@@ -48,11 +48,11 @@ def extract_tool_message_content(response):
|
|||||||
class TimeRetrievalInput(BaseModel):
|
class TimeRetrievalInput(BaseModel):
|
||||||
"""时间检索工具的输入模式"""
|
"""时间检索工具的输入模式"""
|
||||||
context: str = Field(description="用户输入的查询内容")
|
context: str = Field(description="用户输入的查询内容")
|
||||||
group_id: str = Field(default="88a459f5_text09", description="组ID,用于过滤搜索结果")
|
end_user_id: str = Field(default="88a459f5_text09", description="组ID,用于过滤搜索结果")
|
||||||
|
|
||||||
def create_time_retrieval_tool(group_id: str):
|
def create_time_retrieval_tool(end_user_id: str):
|
||||||
"""
|
"""
|
||||||
创建一个带有特定group_id的TimeRetrieval工具(同步版本),用于按时间范围搜索语句(Statements)
|
创建一个带有特定end_user_id的TimeRetrieval工具(同步版本),用于按时间范围搜索语句(Statements)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def clean_temporal_result_fields(data):
|
def clean_temporal_result_fields(data):
|
||||||
@@ -93,26 +93,26 @@ def create_time_retrieval_tool(group_id: str):
|
|||||||
return data
|
return data
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
def TimeRetrievalWithGroupId(context: str, start_date: str = None, end_date: str = None, group_id_param: str = None, clean_output: bool = True) -> str:
|
def TimeRetrievalWithGroupId(context: str, start_date: str = None, end_date: str = None, end_user_id_param: str = None, clean_output: bool = True) -> str:
|
||||||
"""
|
"""
|
||||||
优化的时间检索工具,只结合时间范围搜索(同步版本),自动过滤不需要的元数据字段
|
优化的时间检索工具,只结合时间范围搜索(同步版本),自动过滤不需要的元数据字段
|
||||||
显式接收参数:
|
显式接收参数:
|
||||||
- context: 查询上下文内容
|
- context: 查询上下文内容
|
||||||
- start_date: 开始时间(可选,格式:YYYY-MM-DD)
|
- start_date: 开始时间(可选,格式:YYYY-MM-DD)
|
||||||
- end_date: 结束时间(可选,格式:YYYY-MM-DD)
|
- end_date: 结束时间(可选,格式:YYYY-MM-DD)
|
||||||
- group_id_param: 组ID(可选,用于覆盖默认组ID)
|
- end_user_id_param: 组ID(可选,用于覆盖默认组ID)
|
||||||
- clean_output: 是否清理输出中的元数据字段
|
- clean_output: 是否清理输出中的元数据字段
|
||||||
-end_date 需要根据用户的描述获取结束的时间,输出格式用strftime("%Y-%m-%d")
|
-end_date 需要根据用户的描述获取结束的时间,输出格式用strftime("%Y-%m-%d")
|
||||||
"""
|
"""
|
||||||
async def _async_search():
|
async def _async_search():
|
||||||
# 使用传入的参数或默认值
|
# 使用传入的参数或默认值
|
||||||
actual_group_id = group_id_param or group_id
|
actual_end_user_id = end_user_id_param or end_user_id
|
||||||
actual_end_date = end_date or datetime.now().strftime("%Y-%m-%d")
|
actual_end_date = end_date or datetime.now().strftime("%Y-%m-%d")
|
||||||
actual_start_date = start_date or (datetime.now() - timedelta(days=7)).strftime("%Y-%m-%d")
|
actual_start_date = start_date or (datetime.now() - timedelta(days=7)).strftime("%Y-%m-%d")
|
||||||
|
|
||||||
# 基本时间搜索
|
# 基本时间搜索
|
||||||
results = await search_by_temporal(
|
results = await search_by_temporal(
|
||||||
group_id=actual_group_id,
|
end_user_id=actual_end_user_id,
|
||||||
start_date=actual_start_date,
|
start_date=actual_start_date,
|
||||||
end_date=actual_end_date,
|
end_date=actual_end_date,
|
||||||
limit=10
|
limit=10
|
||||||
@@ -147,7 +147,7 @@ def create_time_retrieval_tool(group_id: str):
|
|||||||
# 关键词时间搜索
|
# 关键词时间搜索
|
||||||
results = await search_by_keyword_temporal(
|
results = await search_by_keyword_temporal(
|
||||||
query_text=context,
|
query_text=context,
|
||||||
group_id=group_id,
|
end_user_id=end_user_id,
|
||||||
start_date=actual_start_date,
|
start_date=actual_start_date,
|
||||||
end_date=actual_end_date,
|
end_date=actual_end_date,
|
||||||
limit=15
|
limit=15
|
||||||
@@ -172,7 +172,7 @@ def create_hybrid_retrieval_tool_async(memory_config, **search_params):
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
memory_config: 内存配置对象
|
memory_config: 内存配置对象
|
||||||
**search_params: 搜索参数,包含group_id, limit, include等
|
**search_params: 搜索参数,包含end_user_id, limit, include等
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def clean_result_fields(data):
|
def clean_result_fields(data):
|
||||||
@@ -211,7 +211,7 @@ def create_hybrid_retrieval_tool_async(memory_config, **search_params):
|
|||||||
context: str,
|
context: str,
|
||||||
search_type: str = "hybrid",
|
search_type: str = "hybrid",
|
||||||
limit: int = 10,
|
limit: int = 10,
|
||||||
group_id: str = None,
|
end_user_id: str = None,
|
||||||
rerank_alpha: float = 0.6,
|
rerank_alpha: float = 0.6,
|
||||||
use_forgetting_rerank: bool = False,
|
use_forgetting_rerank: bool = False,
|
||||||
use_llm_rerank: bool = False,
|
use_llm_rerank: bool = False,
|
||||||
@@ -224,7 +224,7 @@ def create_hybrid_retrieval_tool_async(memory_config, **search_params):
|
|||||||
context: 查询内容
|
context: 查询内容
|
||||||
search_type: 搜索类型 ('keyword', 'embedding', 'hybrid')
|
search_type: 搜索类型 ('keyword', 'embedding', 'hybrid')
|
||||||
limit: 结果数量限制
|
limit: 结果数量限制
|
||||||
group_id: 组ID,用于过滤搜索结果
|
end_user_id: 组ID,用于过滤搜索结果
|
||||||
rerank_alpha: 重排序权重参数
|
rerank_alpha: 重排序权重参数
|
||||||
use_forgetting_rerank: 是否使用遗忘重排序
|
use_forgetting_rerank: 是否使用遗忘重排序
|
||||||
use_llm_rerank: 是否使用LLM重排序
|
use_llm_rerank: 是否使用LLM重排序
|
||||||
@@ -238,7 +238,7 @@ def create_hybrid_retrieval_tool_async(memory_config, **search_params):
|
|||||||
final_params = {
|
final_params = {
|
||||||
"query_text": context,
|
"query_text": context,
|
||||||
"search_type": search_type,
|
"search_type": search_type,
|
||||||
"group_id": group_id or search_params.get("group_id"),
|
"end_user_id": end_user_id or search_params.get("end_user_id"),
|
||||||
"limit": limit or search_params.get("limit", 10),
|
"limit": limit or search_params.get("limit", 10),
|
||||||
"include": search_params.get("include", ["summaries", "statements", "chunks", "entities"]),
|
"include": search_params.get("include", ["summaries", "statements", "chunks", "entities"]),
|
||||||
"output_path": None, # 不保存到文件
|
"output_path": None, # 不保存到文件
|
||||||
@@ -291,7 +291,7 @@ def create_hybrid_retrieval_tool_sync(memory_config, **search_params):
|
|||||||
context: str,
|
context: str,
|
||||||
search_type: str = "hybrid",
|
search_type: str = "hybrid",
|
||||||
limit: int = 10,
|
limit: int = 10,
|
||||||
group_id: str = None,
|
end_user_id: str = None,
|
||||||
clean_output: bool = True
|
clean_output: bool = True
|
||||||
) -> str:
|
) -> str:
|
||||||
"""
|
"""
|
||||||
@@ -301,7 +301,7 @@ def create_hybrid_retrieval_tool_sync(memory_config, **search_params):
|
|||||||
context: 查询内容
|
context: 查询内容
|
||||||
search_type: 搜索类型 ('keyword', 'embedding', 'hybrid')
|
search_type: 搜索类型 ('keyword', 'embedding', 'hybrid')
|
||||||
limit: 结果数量限制
|
limit: 结果数量限制
|
||||||
group_id: 组ID,用于过滤搜索结果
|
end_user_id: 组ID,用于过滤搜索结果
|
||||||
clean_output: 是否清理输出中的元数据字段
|
clean_output: 是否清理输出中的元数据字段
|
||||||
"""
|
"""
|
||||||
async def _async_search():
|
async def _async_search():
|
||||||
@@ -311,7 +311,7 @@ def create_hybrid_retrieval_tool_sync(memory_config, **search_params):
|
|||||||
"context": context,
|
"context": context,
|
||||||
"search_type": search_type,
|
"search_type": search_type,
|
||||||
"limit": limit,
|
"limit": limit,
|
||||||
"group_id": group_id,
|
"end_user_id": end_user_id,
|
||||||
"clean_output": clean_output
|
"clean_output": clean_output
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|||||||
@@ -31,7 +31,7 @@ async def make_write_graph():
|
|||||||
user_id: User identifier
|
user_id: User identifier
|
||||||
tools: MCP tools loaded from session
|
tools: MCP tools loaded from session
|
||||||
apply_id: Application identifier
|
apply_id: Application identifier
|
||||||
group_id: Group identifier
|
end_user_id: Group identifier
|
||||||
memory_config: MemoryConfig object containing all configuration
|
memory_config: MemoryConfig object containing all configuration
|
||||||
"""
|
"""
|
||||||
workflow = StateGraph(WriteState)
|
workflow = StateGraph(WriteState)
|
||||||
@@ -49,7 +49,7 @@ async def make_write_graph():
|
|||||||
async def main():
|
async def main():
|
||||||
"""主函数 - 运行工作流"""
|
"""主函数 - 运行工作流"""
|
||||||
message = "今天周一"
|
message = "今天周一"
|
||||||
group_id = 'new_2025test1103' # 组ID
|
end_user_id = 'new_2025test1103' # 组ID
|
||||||
|
|
||||||
|
|
||||||
# 获取数据库会话
|
# 获取数据库会话
|
||||||
@@ -61,9 +61,9 @@ async def main():
|
|||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
async with make_write_graph() as graph:
|
async with make_write_graph() as graph:
|
||||||
config = {"configurable": {"thread_id": group_id}}
|
config = {"configurable": {"thread_id": end_user_id}}
|
||||||
# 初始状态 - 包含所有必要字段
|
# 初始状态 - 包含所有必要字段
|
||||||
initial_state = {"messages": [HumanMessage(content=message)], "group_id": group_id, "memory_config": memory_config}
|
initial_state = {"messages": [HumanMessage(content=message)], "end_user_id": end_user_id, "memory_config": memory_config}
|
||||||
|
|
||||||
# 获取节点更新信息
|
# 获取节点更新信息
|
||||||
async for update_event in graph.astream(
|
async for update_event in graph.astream(
|
||||||
|
|||||||
@@ -24,7 +24,7 @@ class ParameterBuilder:
|
|||||||
tool_call_id: str,
|
tool_call_id: str,
|
||||||
search_switch: str,
|
search_switch: str,
|
||||||
apply_id: str,
|
apply_id: str,
|
||||||
group_id: str,
|
end_user_id: str,
|
||||||
storage_type: Optional[str] = None,
|
storage_type: Optional[str] = None,
|
||||||
user_rag_memory_id: Optional[str] = None
|
user_rag_memory_id: Optional[str] = None
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
@@ -44,7 +44,7 @@ class ParameterBuilder:
|
|||||||
tool_call_id: Extracted tool call identifier
|
tool_call_id: Extracted tool call identifier
|
||||||
search_switch: Search routing parameter
|
search_switch: Search routing parameter
|
||||||
apply_id: Application identifier
|
apply_id: Application identifier
|
||||||
group_id: Group identifier
|
end_user_id: Group identifier
|
||||||
storage_type: Storage type for the workspace (optional)
|
storage_type: Storage type for the workspace (optional)
|
||||||
user_rag_memory_id: User RAG memory ID for knowledge base retrieval (optional)
|
user_rag_memory_id: User RAG memory ID for knowledge base retrieval (optional)
|
||||||
|
|
||||||
@@ -55,7 +55,7 @@ class ParameterBuilder:
|
|||||||
base_args = {
|
base_args = {
|
||||||
"usermessages": tool_call_id,
|
"usermessages": tool_call_id,
|
||||||
"apply_id": apply_id,
|
"apply_id": apply_id,
|
||||||
"group_id": group_id
|
"end_user_id": end_user_id
|
||||||
}
|
}
|
||||||
|
|
||||||
# Always add storage_type and user_rag_memory_id (with defaults if None)
|
# Always add storage_type and user_rag_memory_id (with defaults if None)
|
||||||
|
|||||||
@@ -91,7 +91,7 @@ class SearchService:
|
|||||||
|
|
||||||
async def execute_hybrid_search(
|
async def execute_hybrid_search(
|
||||||
self,
|
self,
|
||||||
group_id: str,
|
end_user_id: str,
|
||||||
question: str,
|
question: str,
|
||||||
limit: int = 5,
|
limit: int = 5,
|
||||||
search_type: str = "hybrid",
|
search_type: str = "hybrid",
|
||||||
@@ -105,7 +105,7 @@ class SearchService:
|
|||||||
Execute hybrid search and return clean content.
|
Execute hybrid search and return clean content.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
group_id: Group identifier for filtering results
|
end_user_id: Group identifier for filtering results
|
||||||
question: Search query text
|
question: Search query text
|
||||||
limit: Maximum number of results to return (default: 5)
|
limit: Maximum number of results to return (default: 5)
|
||||||
search_type: Type of search - "hybrid", "keyword", or "embedding" (default: "hybrid")
|
search_type: Type of search - "hybrid", "keyword", or "embedding" (default: "hybrid")
|
||||||
@@ -130,7 +130,7 @@ class SearchService:
|
|||||||
answer = await run_hybrid_search(
|
answer = await run_hybrid_search(
|
||||||
query_text=cleaned_query,
|
query_text=cleaned_query,
|
||||||
search_type=search_type,
|
search_type=search_type,
|
||||||
group_id=group_id,
|
end_user_id=end_user_id,
|
||||||
limit=limit,
|
limit=limit,
|
||||||
include=include,
|
include=include,
|
||||||
output_path=output_path,
|
output_path=output_path,
|
||||||
@@ -186,7 +186,7 @@ class SearchService:
|
|||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(
|
logger.error(
|
||||||
f"Search failed for query '{question}' in group '{group_id}': {e}",
|
f"Search failed for query '{question}' in group '{end_user_id}': {e}",
|
||||||
exc_info=True
|
exc_info=True
|
||||||
)
|
)
|
||||||
# Return empty results on failure
|
# Return empty results on failure
|
||||||
|
|||||||
@@ -59,7 +59,7 @@ class SessionService:
|
|||||||
self,
|
self,
|
||||||
user_id: str,
|
user_id: str,
|
||||||
apply_id: str,
|
apply_id: str,
|
||||||
group_id: str
|
end_user_id: str
|
||||||
) -> List[dict]:
|
) -> List[dict]:
|
||||||
"""
|
"""
|
||||||
Retrieve conversation history from Redis.
|
Retrieve conversation history from Redis.
|
||||||
@@ -67,20 +67,20 @@ class SessionService:
|
|||||||
Args:
|
Args:
|
||||||
user_id: User identifier
|
user_id: User identifier
|
||||||
apply_id: Application identifier
|
apply_id: Application identifier
|
||||||
group_id: Group identifier
|
end_user_id: Group identifier
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List of conversation history items with Query and Answer keys
|
List of conversation history items with Query and Answer keys
|
||||||
Returns empty list if no history found or on error
|
Returns empty list if no history found or on error
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
history = self.store.find_user_apply_group(user_id, apply_id, group_id)
|
history = self.store.find_user_apply_group(user_id, apply_id, end_user_id)
|
||||||
|
|
||||||
# Validate history structure
|
# Validate history structure
|
||||||
if not isinstance(history, list):
|
if not isinstance(history, list):
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"Invalid history format for user {user_id}, "
|
f"Invalid history format for user {user_id}, "
|
||||||
f"apply {apply_id}, group {group_id}: expected list, got {type(history)}"
|
f"apply {apply_id}, group {end_user_id}: expected list, got {type(history)}"
|
||||||
)
|
)
|
||||||
return []
|
return []
|
||||||
|
|
||||||
@@ -89,7 +89,7 @@ class SessionService:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(
|
logger.error(
|
||||||
f"Failed to retrieve history for user {user_id}, "
|
f"Failed to retrieve history for user {user_id}, "
|
||||||
f"apply {apply_id}, group {group_id}: {e}",
|
f"apply {apply_id}, group {end_user_id}: {e}",
|
||||||
exc_info=True
|
exc_info=True
|
||||||
)
|
)
|
||||||
# Return empty list on error to allow execution to continue
|
# Return empty list on error to allow execution to continue
|
||||||
@@ -100,7 +100,7 @@ class SessionService:
|
|||||||
user_id: str,
|
user_id: str,
|
||||||
query: str,
|
query: str,
|
||||||
apply_id: str,
|
apply_id: str,
|
||||||
group_id: str,
|
end_user_id: str,
|
||||||
ai_response: str
|
ai_response: str
|
||||||
) -> Optional[str]:
|
) -> Optional[str]:
|
||||||
"""
|
"""
|
||||||
@@ -110,7 +110,7 @@ class SessionService:
|
|||||||
user_id: User identifier
|
user_id: User identifier
|
||||||
query: User query/message
|
query: User query/message
|
||||||
apply_id: Application identifier
|
apply_id: Application identifier
|
||||||
group_id: Group identifier
|
end_user_id: Group identifier
|
||||||
ai_response: AI response/answer
|
ai_response: AI response/answer
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@@ -131,7 +131,7 @@ class SessionService:
|
|||||||
userid=user_id,
|
userid=user_id,
|
||||||
messages=query,
|
messages=query,
|
||||||
apply_id=apply_id,
|
apply_id=apply_id,
|
||||||
group_id=group_id,
|
end_user_id=end_user_id,
|
||||||
aimessages=ai_response
|
aimessages=ai_response
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -152,7 +152,7 @@ class SessionService:
|
|||||||
Duplicates are identified by matching:
|
Duplicates are identified by matching:
|
||||||
- sessionid
|
- sessionid
|
||||||
- user_id (id field)
|
- user_id (id field)
|
||||||
- group_id
|
- end_user_id
|
||||||
- messages
|
- messages
|
||||||
- aimessages
|
- aimessages
|
||||||
|
|
||||||
|
|||||||
@@ -9,9 +9,7 @@ from app.core.memory.models.message_models import DialogData, ConversationContex
|
|||||||
|
|
||||||
async def get_chunked_dialogs(
|
async def get_chunked_dialogs(
|
||||||
chunker_strategy: str = "RecursiveChunker",
|
chunker_strategy: str = "RecursiveChunker",
|
||||||
group_id: str = "group_1",
|
end_user_id: str = "group_1",
|
||||||
user_id: str = "user1",
|
|
||||||
apply_id: str = "applyid",
|
|
||||||
content: str = "这是用户的输入",
|
content: str = "这是用户的输入",
|
||||||
ref_id: str = "wyl_20251027",
|
ref_id: str = "wyl_20251027",
|
||||||
config_id: str = None
|
config_id: str = None
|
||||||
@@ -20,9 +18,7 @@ async def get_chunked_dialogs(
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
chunker_strategy: The chunking strategy to use (default: RecursiveChunker)
|
chunker_strategy: The chunking strategy to use (default: RecursiveChunker)
|
||||||
group_id: Group identifier
|
end_user_id: End user identifier
|
||||||
user_id: User identifier
|
|
||||||
apply_id: Application identifier
|
|
||||||
content: Dialog content
|
content: Dialog content
|
||||||
ref_id: Reference identifier
|
ref_id: Reference identifier
|
||||||
config_id: Configuration ID for processing
|
config_id: Configuration ID for processing
|
||||||
@@ -37,13 +33,11 @@ async def get_chunked_dialogs(
|
|||||||
|
|
||||||
# Create DialogData
|
# Create DialogData
|
||||||
conversation_context = ConversationContext(msgs=messages)
|
conversation_context = ConversationContext(msgs=messages)
|
||||||
# Create DialogData with group_id based on the entry's id for uniqueness
|
# Create DialogData with end_user_id
|
||||||
dialog_data = DialogData(
|
dialog_data = DialogData(
|
||||||
context=conversation_context,
|
context=conversation_context,
|
||||||
ref_id=ref_id,
|
ref_id=ref_id,
|
||||||
group_id=group_id,
|
end_user_id=end_user_id,
|
||||||
user_id=user_id,
|
|
||||||
apply_id=apply_id,
|
|
||||||
config_id=config_id
|
config_id=config_id
|
||||||
)
|
)
|
||||||
# Create DialogueChunker and process the dialogue
|
# Create DialogueChunker and process the dialogue
|
||||||
|
|||||||
@@ -12,13 +12,11 @@ class WriteState(TypedDict):
|
|||||||
Langgrapg Writing TypedDict
|
Langgrapg Writing TypedDict
|
||||||
'''
|
'''
|
||||||
messages: Annotated[list[AnyMessage], add_messages]
|
messages: Annotated[list[AnyMessage], add_messages]
|
||||||
user_id:str
|
end_user_id: str
|
||||||
apply_id:str
|
|
||||||
group_id:str
|
|
||||||
errors: list[dict] # Track errors: [{"tool": "tool_name", "error": "message"}]
|
errors: list[dict] # Track errors: [{"tool": "tool_name", "error": "message"}]
|
||||||
memory_config: object
|
memory_config: object
|
||||||
write_result: dict
|
write_result: dict
|
||||||
data:str
|
data: str
|
||||||
|
|
||||||
class ReadState(TypedDict):
|
class ReadState(TypedDict):
|
||||||
"""
|
"""
|
||||||
@@ -28,7 +26,7 @@ class ReadState(TypedDict):
|
|||||||
messages: 消息列表,支持自动追加
|
messages: 消息列表,支持自动追加
|
||||||
loop_count: 遍历次数
|
loop_count: 遍历次数
|
||||||
search_switch: 搜索类型开关
|
search_switch: 搜索类型开关
|
||||||
group_id: 组标识
|
end_user_id: 组标识
|
||||||
config_id: 配置ID,用于过滤结果
|
config_id: 配置ID,用于过滤结果
|
||||||
data: 从content_input_node传递的内容数据
|
data: 从content_input_node传递的内容数据
|
||||||
spit_data: 从Split_The_Problem传递的分解结果
|
spit_data: 从Split_The_Problem传递的分解结果
|
||||||
@@ -39,7 +37,7 @@ class ReadState(TypedDict):
|
|||||||
messages: Annotated[list[AnyMessage], add_messages] # 消息追加模式
|
messages: Annotated[list[AnyMessage], add_messages] # 消息追加模式
|
||||||
loop_count: int
|
loop_count: int
|
||||||
search_switch: str
|
search_switch: str
|
||||||
group_id: str
|
end_user_id: str
|
||||||
config_id: str
|
config_id: str
|
||||||
data: str # 新增字段用于传递内容
|
data: str # 新增字段用于传递内容
|
||||||
spit_data: dict # 新增字段用于传递问题分解结果
|
spit_data: dict # 新增字段用于传递问题分解结果
|
||||||
|
|||||||
@@ -28,7 +28,7 @@ class RedisSessionStore:
|
|||||||
return text
|
return text
|
||||||
|
|
||||||
# 修改后的 save_session 方法
|
# 修改后的 save_session 方法
|
||||||
def save_session(self, userid, messages, aimessages, apply_id, group_id):
|
def save_session(self, userid, messages, aimessages, apply_id, end_user_id):
|
||||||
"""
|
"""
|
||||||
写入一条会话数据,返回 session_id
|
写入一条会话数据,返回 session_id
|
||||||
优化版本:确保写入时间不超过1秒
|
优化版本:确保写入时间不超过1秒
|
||||||
@@ -46,7 +46,7 @@ class RedisSessionStore:
|
|||||||
"id": self.uudi,
|
"id": self.uudi,
|
||||||
"sessionid": userid,
|
"sessionid": userid,
|
||||||
"apply_id": apply_id,
|
"apply_id": apply_id,
|
||||||
"group_id": group_id,
|
"end_user_id": end_user_id,
|
||||||
"messages": messages,
|
"messages": messages,
|
||||||
"aimessages": aimessages,
|
"aimessages": aimessages,
|
||||||
"starttime": starttime
|
"starttime": starttime
|
||||||
@@ -67,7 +67,7 @@ class RedisSessionStore:
|
|||||||
def save_sessions_batch(self, sessions_data):
|
def save_sessions_batch(self, sessions_data):
|
||||||
"""
|
"""
|
||||||
批量写入多条会话数据,返回 session_id 列表
|
批量写入多条会话数据,返回 session_id 列表
|
||||||
sessions_data: list of dict, 每个 dict 包含 userid, messages, aimessages, apply_id, group_id
|
sessions_data: list of dict, 每个 dict 包含 userid, messages, aimessages, apply_id, end_user_id
|
||||||
优化版本:批量操作,大幅提升性能
|
优化版本:批量操作,大幅提升性能
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
@@ -83,7 +83,7 @@ class RedisSessionStore:
|
|||||||
"id": self.uudi,
|
"id": self.uudi,
|
||||||
"sessionid": session.get('userid'),
|
"sessionid": session.get('userid'),
|
||||||
"apply_id": session.get('apply_id'),
|
"apply_id": session.get('apply_id'),
|
||||||
"group_id": session.get('group_id'),
|
"end_user_id": session.get('end_user_id'),
|
||||||
"messages": session.get('messages'),
|
"messages": session.get('messages'),
|
||||||
"aimessages": session.get('aimessages'),
|
"aimessages": session.get('aimessages'),
|
||||||
"starttime": starttime
|
"starttime": starttime
|
||||||
@@ -108,9 +108,9 @@ class RedisSessionStore:
|
|||||||
data = self.r.hgetall(key)
|
data = self.r.hgetall(key)
|
||||||
return data if data else None
|
return data if data else None
|
||||||
|
|
||||||
def get_session_apply_group(self, sessionid, apply_id, group_id):
|
def get_session_apply_group(self, sessionid, apply_id, end_user_id):
|
||||||
"""
|
"""
|
||||||
根据 sessionid、apply_id 和 group_id 三个条件查询会话数据
|
根据 sessionid、apply_id 和 end_user_id 三个条件查询会话数据
|
||||||
"""
|
"""
|
||||||
result_items = []
|
result_items = []
|
||||||
|
|
||||||
@@ -124,7 +124,7 @@ class RedisSessionStore:
|
|||||||
# 检查三个条件是否都匹配
|
# 检查三个条件是否都匹配
|
||||||
if (data.get('sessionid') == sessionid and
|
if (data.get('sessionid') == sessionid and
|
||||||
data.get('apply_id') == apply_id and
|
data.get('apply_id') == apply_id and
|
||||||
data.get('group_id') == group_id):
|
data.get('end_user_id') == end_user_id):
|
||||||
result_items.append(data)
|
result_items.append(data)
|
||||||
|
|
||||||
return result_items
|
return result_items
|
||||||
@@ -172,7 +172,7 @@ class RedisSessionStore:
|
|||||||
def delete_duplicate_sessions(self):
|
def delete_duplicate_sessions(self):
|
||||||
"""
|
"""
|
||||||
删除重复会话数据,条件:
|
删除重复会话数据,条件:
|
||||||
"sessionid"、"user_id"、"group_id"、"messages"、"aimessages" 五个字段都相同的只保留一个,其他删除
|
"sessionid"、"user_id"、"end_user_id"、"messages"、"aimessages" 五个字段都相同的只保留一个,其他删除
|
||||||
优化版本:使用 pipeline 批量操作,确保在1秒内完成
|
优化版本:使用 pipeline 批量操作,确保在1秒内完成
|
||||||
"""
|
"""
|
||||||
import time
|
import time
|
||||||
@@ -202,12 +202,12 @@ class RedisSessionStore:
|
|||||||
# 获取五个字段的值
|
# 获取五个字段的值
|
||||||
sessionid = data.get('sessionid', '')
|
sessionid = data.get('sessionid', '')
|
||||||
user_id = data.get('id', '')
|
user_id = data.get('id', '')
|
||||||
group_id = data.get('group_id', '')
|
end_user_id = data.get('end_user_id', '')
|
||||||
messages = data.get('messages', '')
|
messages = data.get('messages', '')
|
||||||
aimessages = data.get('aimessages', '')
|
aimessages = data.get('aimessages', '')
|
||||||
|
|
||||||
# 用五元组作为唯一标识
|
# 用五元组作为唯一标识
|
||||||
identifier = (sessionid, user_id, group_id, messages, aimessages)
|
identifier = (sessionid, user_id, end_user_id, messages, aimessages)
|
||||||
|
|
||||||
if identifier in seen:
|
if identifier in seen:
|
||||||
# 重复,标记为待删除
|
# 重复,标记为待删除
|
||||||
@@ -248,9 +248,9 @@ class RedisSessionStore:
|
|||||||
result_items = []
|
result_items = []
|
||||||
return (result_items)
|
return (result_items)
|
||||||
|
|
||||||
def find_user_apply_group(self, sessionid, apply_id, group_id):
|
def find_user_apply_group(self, sessionid, apply_id, end_user_id):
|
||||||
"""
|
"""
|
||||||
根据 sessionid、apply_id 和 group_id 三个条件查询会话数据,返回最新的6条
|
根据 sessionid、apply_id 和 end_user_id 三个条件查询会话数据,返回最新的6条
|
||||||
"""
|
"""
|
||||||
import time
|
import time
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
@@ -276,7 +276,7 @@ class RedisSessionStore:
|
|||||||
# 检查是否符合三个条件
|
# 检查是否符合三个条件
|
||||||
|
|
||||||
if (data.get('apply_id') == apply_id and
|
if (data.get('apply_id') == apply_id and
|
||||||
data.get('group_id') == group_id):
|
data.get('end_user_id') == end_user_id):
|
||||||
# 支持模糊匹配 sessionid 或者完全匹配
|
# 支持模糊匹配 sessionid 或者完全匹配
|
||||||
if sessionid in data.get('sessionid', '') or data.get('sessionid') == sessionid:
|
if sessionid in data.get('sessionid', '') or data.get('sessionid') == sessionid:
|
||||||
matched_items.append({
|
matched_items.append({
|
||||||
|
|||||||
@@ -59,7 +59,7 @@ class SessionService:
|
|||||||
self,
|
self,
|
||||||
user_id: str,
|
user_id: str,
|
||||||
apply_id: str,
|
apply_id: str,
|
||||||
group_id: str
|
end_user_id: str
|
||||||
) -> List[dict]:
|
) -> List[dict]:
|
||||||
"""
|
"""
|
||||||
Retrieve conversation history from Redis.
|
Retrieve conversation history from Redis.
|
||||||
@@ -67,20 +67,20 @@ class SessionService:
|
|||||||
Args:
|
Args:
|
||||||
user_id: User identifier
|
user_id: User identifier
|
||||||
apply_id: Application identifier
|
apply_id: Application identifier
|
||||||
group_id: Group identifier
|
end_user_id: Group identifier
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List of conversation history items with Query and Answer keys
|
List of conversation history items with Query and Answer keys
|
||||||
Returns empty list if no history found or on error
|
Returns empty list if no history found or on error
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
history = self.store.find_user_apply_group(user_id, apply_id, group_id)
|
history = self.store.find_user_apply_group(user_id, apply_id, end_user_id)
|
||||||
|
|
||||||
# Validate history structure
|
# Validate history structure
|
||||||
if not isinstance(history, list):
|
if not isinstance(history, list):
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"Invalid history format for user {user_id}, "
|
f"Invalid history format for user {user_id}, "
|
||||||
f"apply {apply_id}, group {group_id}: expected list, got {type(history)}"
|
f"apply {apply_id}, group {end_user_id}: expected list, got {type(history)}"
|
||||||
)
|
)
|
||||||
return []
|
return []
|
||||||
|
|
||||||
@@ -89,7 +89,7 @@ class SessionService:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(
|
logger.error(
|
||||||
f"Failed to retrieve history for user {user_id}, "
|
f"Failed to retrieve history for user {user_id}, "
|
||||||
f"apply {apply_id}, group {group_id}: {e}",
|
f"apply {apply_id}, group {end_user_id}: {e}",
|
||||||
exc_info=True
|
exc_info=True
|
||||||
)
|
)
|
||||||
# Return empty list on error to allow execution to continue
|
# Return empty list on error to allow execution to continue
|
||||||
@@ -100,7 +100,7 @@ class SessionService:
|
|||||||
user_id: str,
|
user_id: str,
|
||||||
query: str,
|
query: str,
|
||||||
apply_id: str,
|
apply_id: str,
|
||||||
group_id: str,
|
end_user_id: str,
|
||||||
ai_response: str
|
ai_response: str
|
||||||
) -> Optional[str]:
|
) -> Optional[str]:
|
||||||
"""
|
"""
|
||||||
@@ -110,7 +110,7 @@ class SessionService:
|
|||||||
user_id: User identifier
|
user_id: User identifier
|
||||||
query: User query/message
|
query: User query/message
|
||||||
apply_id: Application identifier
|
apply_id: Application identifier
|
||||||
group_id: Group identifier
|
end_user_id: Group identifier
|
||||||
ai_response: AI response/answer
|
ai_response: AI response/answer
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@@ -131,7 +131,7 @@ class SessionService:
|
|||||||
userid=user_id,
|
userid=user_id,
|
||||||
messages=query,
|
messages=query,
|
||||||
apply_id=apply_id,
|
apply_id=apply_id,
|
||||||
group_id=group_id,
|
end_user_id=end_user_id,
|
||||||
aimessages=ai_response
|
aimessages=ai_response
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -152,7 +152,7 @@ class SessionService:
|
|||||||
Duplicates are identified by matching:
|
Duplicates are identified by matching:
|
||||||
- sessionid
|
- sessionid
|
||||||
- user_id (id field)
|
- user_id (id field)
|
||||||
- group_id
|
- end_user_id
|
||||||
- messages
|
- messages
|
||||||
- aimessages
|
- aimessages
|
||||||
|
|
||||||
|
|||||||
@@ -29,25 +29,18 @@ logger = get_agent_logger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
async def write(
|
async def write(
|
||||||
content: str,
|
end_user_id: str,
|
||||||
user_id: str,
|
|
||||||
apply_id: str,
|
|
||||||
group_id: str,
|
|
||||||
memory_config: MemoryConfig,
|
memory_config: MemoryConfig,
|
||||||
|
messages: list,
|
||||||
ref_id: str = "wyl20251027",
|
ref_id: str = "wyl20251027",
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Execute the complete knowledge extraction pipeline.
|
Execute the complete knowledge extraction pipeline.
|
||||||
|
|
||||||
Only MemoryConfig is needed - LLM and embedding clients are constructed
|
|
||||||
internally from the config.
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
content: Dialogue content to process
|
end_user_id: End user identifier
|
||||||
user_id: User identifier
|
|
||||||
apply_id: Application identifier
|
|
||||||
group_id: Group identifier
|
|
||||||
memory_config: MemoryConfig object containing all configuration
|
memory_config: MemoryConfig object containing all configuration
|
||||||
|
messages: Structured message list [{"role": "user", "content": "..."}, ...]
|
||||||
ref_id: Reference ID, defaults to "wyl20251027"
|
ref_id: Reference ID, defaults to "wyl20251027"
|
||||||
"""
|
"""
|
||||||
# Extract config values
|
# Extract config values
|
||||||
@@ -61,7 +54,7 @@ async def write(
|
|||||||
logger.info(f"LLM model: {memory_config.llm_model_name}")
|
logger.info(f"LLM model: {memory_config.llm_model_name}")
|
||||||
logger.info(f"Embedding model: {memory_config.embedding_model_name}")
|
logger.info(f"Embedding model: {memory_config.embedding_model_name}")
|
||||||
logger.info(f"Chunker strategy: {chunker_strategy}")
|
logger.info(f"Chunker strategy: {chunker_strategy}")
|
||||||
logger.info(f"Group ID: {group_id}")
|
logger.info(f"End User ID: {end_user_id}")
|
||||||
|
|
||||||
# Construct clients from memory_config using factory pattern with db session
|
# Construct clients from memory_config using factory pattern with db session
|
||||||
with get_db_context() as db:
|
with get_db_context() as db:
|
||||||
@@ -84,12 +77,25 @@ async def write(
|
|||||||
|
|
||||||
# Step 1: Load and chunk data
|
# Step 1: Load and chunk data
|
||||||
step_start = time.time()
|
step_start = time.time()
|
||||||
|
|
||||||
|
# Convert messages list to content string
|
||||||
|
# messages format: [{"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}, ...]
|
||||||
|
if isinstance(messages, list) and len(messages) > 0:
|
||||||
|
# Extract content from the last user message or concatenate all messages
|
||||||
|
if isinstance(messages[-1], dict) and 'content' in messages[-1]:
|
||||||
|
content = messages[-1]['content']
|
||||||
|
else:
|
||||||
|
# Fallback: concatenate all message contents
|
||||||
|
content = " ".join([msg.get('content', '') for msg in messages if isinstance(msg, dict)])
|
||||||
|
elif isinstance(messages, str):
|
||||||
|
content = messages
|
||||||
|
else:
|
||||||
|
content = str(messages)
|
||||||
|
|
||||||
chunked_dialogs = await get_chunked_dialogs(
|
chunked_dialogs = await get_chunked_dialogs(
|
||||||
chunker_strategy=chunker_strategy,
|
chunker_strategy=chunker_strategy,
|
||||||
group_id=group_id,
|
end_user_id=end_user_id,
|
||||||
user_id=user_id,
|
content=content, # 修复:使用 content 参数而不是 messages
|
||||||
apply_id=apply_id,
|
|
||||||
content=content,
|
|
||||||
ref_id=ref_id,
|
ref_id=ref_id,
|
||||||
config_id=config_id,
|
config_id=config_id,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -16,13 +16,13 @@ class FilteredTags(BaseModel):
|
|||||||
"""用于接收LLM筛选后的核心标签列表的模型。"""
|
"""用于接收LLM筛选后的核心标签列表的模型。"""
|
||||||
meaningful_tags: List[str] = Field(..., description="从原始列表中筛选出的具有核心代表意义的名词列表。")
|
meaningful_tags: List[str] = Field(..., description="从原始列表中筛选出的具有核心代表意义的名词列表。")
|
||||||
|
|
||||||
async def filter_tags_with_llm(tags: List[str], group_id: str) -> List[str]:
|
async def filter_tags_with_llm(tags: List[str], end_user_id: str) -> List[str]:
|
||||||
"""
|
"""
|
||||||
使用LLM筛选标签列表,仅保留具有代表性的核心名词。
|
使用LLM筛选标签列表,仅保留具有代表性的核心名词。
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
tags: 原始标签列表
|
tags: 原始标签列表
|
||||||
group_id: 用户组ID,用于获取配置
|
end_user_id: 用户组ID,用于获取配置
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
筛选后的标签列表
|
筛选后的标签列表
|
||||||
@@ -37,12 +37,12 @@ async def filter_tags_with_llm(tags: List[str], group_id: str) -> List[str]:
|
|||||||
get_end_user_connected_config,
|
get_end_user_connected_config,
|
||||||
)
|
)
|
||||||
|
|
||||||
connected_config = get_end_user_connected_config(group_id, db)
|
connected_config = get_end_user_connected_config(end_user_id, db)
|
||||||
config_id = connected_config.get("memory_config_id")
|
config_id = connected_config.get("memory_config_id")
|
||||||
|
|
||||||
if not config_id:
|
if not config_id:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"No memory_config_id found for group_id: {group_id}. "
|
f"No memory_config_id found for end_user_id: {end_user_id}. "
|
||||||
"Please ensure the user has a valid memory configuration."
|
"Please ensure the user has a valid memory configuration."
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -87,7 +87,7 @@ async def filter_tags_with_llm(tags: List[str], group_id: str) -> List[str]:
|
|||||||
|
|
||||||
async def get_raw_tags_from_db(
|
async def get_raw_tags_from_db(
|
||||||
connector: Neo4jConnector,
|
connector: Neo4jConnector,
|
||||||
group_id: str,
|
end_user_id: str,
|
||||||
limit: int,
|
limit: int,
|
||||||
by_user: bool = False
|
by_user: bool = False
|
||||||
) -> List[Tuple[str, int]]:
|
) -> List[Tuple[str, int]]:
|
||||||
@@ -99,9 +99,9 @@ async def get_raw_tags_from_db(
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
connector: Neo4j连接器实例
|
connector: Neo4j连接器实例
|
||||||
group_id: 如果by_user=False,则为group_id;如果by_user=True,则为user_id
|
end_user_id: 如果by_user=False,则为end_user_id;如果by_user=True,则为user_id
|
||||||
limit: 返回的标签数量限制
|
limit: 返回的标签数量限制
|
||||||
by_user: 是否按user_id查询(默认False,按group_id查询)
|
by_user: 是否按user_id查询(默认False,按end_user_id查询)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List[Tuple[str, int]]: 标签名称和频率的元组列表
|
List[Tuple[str, int]]: 标签名称和频率的元组列表
|
||||||
@@ -119,7 +119,7 @@ async def get_raw_tags_from_db(
|
|||||||
else:
|
else:
|
||||||
query = (
|
query = (
|
||||||
"MATCH (e:ExtractedEntity) "
|
"MATCH (e:ExtractedEntity) "
|
||||||
"WHERE e.group_id = $id AND e.entity_type <> '人物' AND e.name IS NOT NULL AND NOT e.name IN $names_to_exclude "
|
"WHERE e.end_user_id = $id AND e.entity_type <> '人物' AND e.name IS NOT NULL AND NOT e.name IN $names_to_exclude "
|
||||||
"RETURN e.name AS name, count(e) AS frequency "
|
"RETURN e.name AS name, count(e) AS frequency "
|
||||||
"ORDER BY frequency DESC "
|
"ORDER BY frequency DESC "
|
||||||
"LIMIT $limit"
|
"LIMIT $limit"
|
||||||
@@ -128,44 +128,44 @@ async def get_raw_tags_from_db(
|
|||||||
# 使用项目的Neo4jConnector执行查询
|
# 使用项目的Neo4jConnector执行查询
|
||||||
results = await connector.execute_query(
|
results = await connector.execute_query(
|
||||||
query,
|
query,
|
||||||
id=group_id,
|
id=end_user_id,
|
||||||
limit=limit,
|
limit=limit,
|
||||||
names_to_exclude=names_to_exclude
|
names_to_exclude=names_to_exclude
|
||||||
)
|
)
|
||||||
|
|
||||||
return [(record["name"], record["frequency"]) for record in results]
|
return [(record["name"], record["frequency"]) for record in results]
|
||||||
|
|
||||||
async def get_hot_memory_tags(group_id: str, limit: int = 40, by_user: bool = False) -> List[Tuple[str, int]]:
|
async def get_hot_memory_tags(end_user_id: str, limit: int = 40, by_user: bool = False) -> List[Tuple[str, int]]:
|
||||||
"""
|
"""
|
||||||
获取原始标签,然后使用LLM进行筛选,返回最终的热门标签列表。
|
获取原始标签,然后使用LLM进行筛选,返回最终的热门标签列表。
|
||||||
查询更多的标签(limit=40)给LLM提供更丰富的上下文进行筛选。
|
查询更多的标签(limit=40)给LLM提供更丰富的上下文进行筛选。
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
group_id: 必需参数。如果by_user=False,则为group_id;如果by_user=True,则为user_id
|
end_user_id: 必需参数。如果by_user=False,则为end_user_id;如果by_user=True,则为user_id
|
||||||
limit: 返回的标签数量限制
|
limit: 返回的标签数量限制
|
||||||
by_user: 是否按user_id查询(默认False,按group_id查询)
|
by_user: 是否按user_id查询(默认False,按end_user_id查询)
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
ValueError: 如果group_id未提供或为空
|
ValueError: 如果end_user_id未提供或为空
|
||||||
"""
|
"""
|
||||||
# 验证group_id必须提供且不为空
|
# 验证end_user_id必须提供且不为空
|
||||||
if not group_id or not group_id.strip():
|
if not end_user_id or not end_user_id.strip():
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"group_id is required. Please provide a valid group_id or user_id."
|
"end_user_id is required. Please provide a valid end_user_id or user_id."
|
||||||
)
|
)
|
||||||
|
|
||||||
# 使用项目的Neo4jConnector
|
# 使用项目的Neo4jConnector
|
||||||
connector = Neo4jConnector()
|
connector = Neo4jConnector()
|
||||||
try:
|
try:
|
||||||
# 1. 从数据库获取原始排名靠前的标签
|
# 1. 从数据库获取原始排名靠前的标签
|
||||||
raw_tags_with_freq = await get_raw_tags_from_db(connector, group_id, limit, by_user=by_user)
|
raw_tags_with_freq = await get_raw_tags_from_db(connector, end_user_id, limit, by_user=by_user)
|
||||||
if not raw_tags_with_freq:
|
if not raw_tags_with_freq:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
raw_tag_names = [tag for tag, freq in raw_tags_with_freq]
|
raw_tag_names = [tag for tag, freq in raw_tags_with_freq]
|
||||||
|
|
||||||
# 2. 初始化LLM客户端并使用LLM筛选出有意义的标签
|
# 2. 初始化LLM客户端并使用LLM筛选出有意义的标签
|
||||||
meaningful_tag_names = await filter_tags_with_llm(raw_tag_names, group_id)
|
meaningful_tag_names = await filter_tags_with_llm(raw_tag_names, end_user_id)
|
||||||
|
|
||||||
# 3. 根据LLM的筛选结果,构建最终的标签列表(保留原始频率和顺序)
|
# 3. 根据LLM的筛选结果,构建最终的标签列表(保留原始频率和顺序)
|
||||||
final_tags = []
|
final_tags = []
|
||||||
|
|||||||
@@ -75,8 +75,8 @@ class MemoryDataSource:
|
|||||||
start_date = time_range.start_date if time_range else None
|
start_date = time_range.start_date if time_range else None
|
||||||
end_date = time_range.end_date if time_range else None
|
end_date = time_range.end_date if time_range else None
|
||||||
|
|
||||||
summary_dicts = await self.memory_summary_repo.find_by_group_id(
|
summary_dicts = await self.memory_summary_repo.find_by_end_user_id(
|
||||||
group_id=user_id,
|
end_user_id=user_id,
|
||||||
limit=limit,
|
limit=limit,
|
||||||
start_date=start_date,
|
start_date=start_date,
|
||||||
end_date=end_date
|
end_date=end_date
|
||||||
|
|||||||
@@ -41,7 +41,7 @@ DIALOGUE_EMBEDDING_SEARCH = """
|
|||||||
WITH $embedding AS q
|
WITH $embedding AS q
|
||||||
MATCH (d:Dialogue)
|
MATCH (d:Dialogue)
|
||||||
WHERE d.dialog_embedding IS NOT NULL
|
WHERE d.dialog_embedding IS NOT NULL
|
||||||
AND ($group_id IS NULL OR d.group_id = $group_id)
|
AND ($end_user_id IS NULL OR d.end_user_id = $end_user_id)
|
||||||
WITH d, q, d.dialog_embedding AS v
|
WITH d, q, d.dialog_embedding AS v
|
||||||
WITH d,
|
WITH d,
|
||||||
reduce(dot = 0.0, i IN range(0, size(q)-1) | dot + toFloat(q[i]) * toFloat(v[i])) AS dot,
|
reduce(dot = 0.0, i IN range(0, size(q)-1) | dot + toFloat(q[i]) * toFloat(v[i])) AS dot,
|
||||||
@@ -50,7 +50,7 @@ WITH d,
|
|||||||
WITH d, CASE WHEN qnorm = 0 OR vnorm = 0 THEN 0.0 ELSE dot / (qnorm * vnorm) END AS score
|
WITH d, CASE WHEN qnorm = 0 OR vnorm = 0 THEN 0.0 ELSE dot / (qnorm * vnorm) END AS score
|
||||||
WHERE score > $threshold
|
WHERE score > $threshold
|
||||||
RETURN d.id AS dialog_id,
|
RETURN d.id AS dialog_id,
|
||||||
d.group_id AS group_id,
|
d.end_user_id AS end_user_id,
|
||||||
d.content AS content,
|
d.content AS content,
|
||||||
d.created_at AS created_at,
|
d.created_at AS created_at,
|
||||||
d.expired_at AS expired_at,
|
d.expired_at AS expired_at,
|
||||||
|
|||||||
@@ -36,7 +36,7 @@ from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
|||||||
|
|
||||||
async def ingest_contexts_via_full_pipeline(
|
async def ingest_contexts_via_full_pipeline(
|
||||||
contexts: List[str],
|
contexts: List[str],
|
||||||
group_id: str,
|
end_user_id: str,
|
||||||
chunker_strategy: str | None = None,
|
chunker_strategy: str | None = None,
|
||||||
embedding_name: str | None = None,
|
embedding_name: str | None = None,
|
||||||
save_chunk_output: bool = False,
|
save_chunk_output: bool = False,
|
||||||
@@ -48,7 +48,7 @@ async def ingest_contexts_via_full_pipeline(
|
|||||||
This function mirrors the steps in main(), but starts from raw text contexts.
|
This function mirrors the steps in main(), but starts from raw text contexts.
|
||||||
Args:
|
Args:
|
||||||
contexts: List of dialogue texts, each containing lines like "role: message".
|
contexts: List of dialogue texts, each containing lines like "role: message".
|
||||||
group_id: Group ID to assign to generated DialogData and graph nodes.
|
end_user_id: Group ID to assign to generated DialogData and graph nodes.
|
||||||
chunker_strategy: Optional chunker strategy; defaults to SELECTED_CHUNKER_STRATEGY.
|
chunker_strategy: Optional chunker strategy; defaults to SELECTED_CHUNKER_STRATEGY.
|
||||||
embedding_name: Optional embedding model ID; defaults to SELECTED_EMBEDDING_ID.
|
embedding_name: Optional embedding model ID; defaults to SELECTED_EMBEDDING_ID.
|
||||||
save_chunk_output: If True, write chunked DialogData list to a JSON file for debugging.
|
save_chunk_output: If True, write chunked DialogData list to a JSON file for debugging.
|
||||||
@@ -109,7 +109,7 @@ async def ingest_contexts_via_full_pipeline(
|
|||||||
dialog = DialogData(
|
dialog = DialogData(
|
||||||
context=context_model,
|
context=context_model,
|
||||||
ref_id=f"pipeline_item_{idx}",
|
ref_id=f"pipeline_item_{idx}",
|
||||||
group_id=group_id,
|
end_user_id=end_user_id,
|
||||||
user_id="default_user",
|
user_id="default_user",
|
||||||
apply_id="default_application",
|
apply_id="default_application",
|
||||||
)
|
)
|
||||||
@@ -318,16 +318,16 @@ async def handle_context_processing(args):
|
|||||||
print("No contexts provided for processing.")
|
print("No contexts provided for processing.")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
return await main_from_contexts(contexts, args.context_group_id)
|
return await main_from_contexts(contexts, args.context_end_user_id)
|
||||||
|
|
||||||
|
|
||||||
async def main_from_contexts(contexts: List[str], group_id: str):
|
async def main_from_contexts(contexts: List[str], end_user_id: str):
|
||||||
"""Run the pipeline from provided dialogue contexts instead of test data."""
|
"""Run the pipeline from provided dialogue contexts instead of test data."""
|
||||||
print("=== Running pipeline from provided contexts ===")
|
print("=== Running pipeline from provided contexts ===")
|
||||||
|
|
||||||
success = await ingest_contexts_via_full_pipeline(
|
success = await ingest_contexts_via_full_pipeline(
|
||||||
contexts=contexts,
|
contexts=contexts,
|
||||||
group_id=group_id,
|
end_user_id=end_user_id,
|
||||||
chunker_strategy=SELECTED_CHUNKER_STRATEGY,
|
chunker_strategy=SELECTED_CHUNKER_STRATEGY,
|
||||||
embedding_name=SELECTED_EMBEDDING_ID,
|
embedding_name=SELECTED_EMBEDDING_ID,
|
||||||
save_chunk_output=True
|
save_chunk_output=True
|
||||||
|
|||||||
@@ -47,7 +47,7 @@ from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
|
|||||||
from app.core.memory.utils.definitions import (
|
from app.core.memory.utils.definitions import (
|
||||||
PROJECT_ROOT,
|
PROJECT_ROOT,
|
||||||
SELECTED_EMBEDDING_ID,
|
SELECTED_EMBEDDING_ID,
|
||||||
SELECTED_GROUP_ID,
|
SELECTED_end_user_id,
|
||||||
SELECTED_LLM_ID,
|
SELECTED_LLM_ID,
|
||||||
)
|
)
|
||||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||||
@@ -59,7 +59,7 @@ from app.services.memory_config_service import MemoryConfigService
|
|||||||
|
|
||||||
async def run_locomo_benchmark(
|
async def run_locomo_benchmark(
|
||||||
sample_size: int = 20,
|
sample_size: int = 20,
|
||||||
group_id: Optional[str] = None,
|
end_user_id: Optional[str] = None,
|
||||||
search_type: str = "hybrid",
|
search_type: str = "hybrid",
|
||||||
search_limit: int = 12,
|
search_limit: int = 12,
|
||||||
context_char_budget: int = 8000,
|
context_char_budget: int = 8000,
|
||||||
@@ -85,7 +85,7 @@ async def run_locomo_benchmark(
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
sample_size: Number of QA pairs to evaluate (from first conversation)
|
sample_size: Number of QA pairs to evaluate (from first conversation)
|
||||||
group_id: Database group ID for retrieval (uses default if None)
|
end_user_id: Database group ID for retrieval (uses default if None)
|
||||||
search_type: "keyword", "embedding", or "hybrid"
|
search_type: "keyword", "embedding", or "hybrid"
|
||||||
search_limit: Max documents to retrieve per query
|
search_limit: Max documents to retrieve per query
|
||||||
context_char_budget: Max characters for context
|
context_char_budget: Max characters for context
|
||||||
@@ -96,8 +96,8 @@ async def run_locomo_benchmark(
|
|||||||
Returns:
|
Returns:
|
||||||
Dictionary with evaluation results including metrics, timing, and samples
|
Dictionary with evaluation results including metrics, timing, and samples
|
||||||
"""
|
"""
|
||||||
# Use default group_id if not provided
|
# Use default end_user_id if not provided
|
||||||
group_id = group_id or SELECTED_GROUP_ID
|
end_user_id = end_user_id or SELECTED_end_user_id
|
||||||
|
|
||||||
# Determine data path
|
# Determine data path
|
||||||
data_path = os.path.join(PROJECT_ROOT, "data", "locomo10.json")
|
data_path = os.path.join(PROJECT_ROOT, "data", "locomo10.json")
|
||||||
@@ -110,7 +110,7 @@ async def run_locomo_benchmark(
|
|||||||
print(f"{'='*60}")
|
print(f"{'='*60}")
|
||||||
print("📊 Configuration:")
|
print("📊 Configuration:")
|
||||||
print(f" Sample size: {sample_size}")
|
print(f" Sample size: {sample_size}")
|
||||||
print(f" Group ID: {group_id}")
|
print(f" Group ID: {end_user_id}")
|
||||||
print(f" Search type: {search_type}")
|
print(f" Search type: {search_type}")
|
||||||
print(f" Search limit: {search_limit}")
|
print(f" Search limit: {search_limit}")
|
||||||
print(f" Context budget: {context_char_budget} chars")
|
print(f" Context budget: {context_char_budget} chars")
|
||||||
@@ -134,7 +134,7 @@ async def run_locomo_benchmark(
|
|||||||
# Step 2: Extract conversations and ingest if needed
|
# Step 2: Extract conversations and ingest if needed
|
||||||
if skip_ingest:
|
if skip_ingest:
|
||||||
print("⏭️ Skipping data ingestion (using existing data in Neo4j)")
|
print("⏭️ Skipping data ingestion (using existing data in Neo4j)")
|
||||||
print(f" Group ID: {group_id}\n")
|
print(f" Group ID: {end_user_id}\n")
|
||||||
else:
|
else:
|
||||||
print("💾 Checking database ingestion...")
|
print("💾 Checking database ingestion...")
|
||||||
try:
|
try:
|
||||||
@@ -142,10 +142,10 @@ async def run_locomo_benchmark(
|
|||||||
print(f"📝 Extracted {len(conversations)} conversations")
|
print(f"📝 Extracted {len(conversations)} conversations")
|
||||||
|
|
||||||
# Always ingest for now (ingestion check not implemented)
|
# Always ingest for now (ingestion check not implemented)
|
||||||
print(f"🔄 Ingesting conversations into group '{group_id}'...")
|
print(f"🔄 Ingesting conversations into group '{end_user_id}'...")
|
||||||
success = await ingest_conversations_if_needed(
|
success = await ingest_conversations_if_needed(
|
||||||
conversations=conversations,
|
conversations=conversations,
|
||||||
group_id=group_id,
|
end_user_id=end_user_id,
|
||||||
reset=reset_group
|
reset=reset_group
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -224,7 +224,7 @@ async def run_locomo_benchmark(
|
|||||||
try:
|
try:
|
||||||
retrieved_info = await retrieve_relevant_information(
|
retrieved_info = await retrieve_relevant_information(
|
||||||
question=question,
|
question=question,
|
||||||
group_id=group_id,
|
end_user_id=end_user_id,
|
||||||
search_type=search_type,
|
search_type=search_type,
|
||||||
search_limit=search_limit,
|
search_limit=search_limit,
|
||||||
connector=connector,
|
connector=connector,
|
||||||
@@ -409,7 +409,7 @@ async def run_locomo_benchmark(
|
|||||||
"sample_size": len(qa_items),
|
"sample_size": len(qa_items),
|
||||||
"timestamp": datetime.now().isoformat(),
|
"timestamp": datetime.now().isoformat(),
|
||||||
"params": {
|
"params": {
|
||||||
"group_id": group_id,
|
"end_user_id": end_user_id,
|
||||||
"search_type": search_type,
|
"search_type": search_type,
|
||||||
"search_limit": search_limit,
|
"search_limit": search_limit,
|
||||||
"context_char_budget": context_char_budget,
|
"context_char_budget": context_char_budget,
|
||||||
@@ -467,7 +467,7 @@ def main():
|
|||||||
help="Number of QA pairs to evaluate"
|
help="Number of QA pairs to evaluate"
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--group_id",
|
"--end_user_id",
|
||||||
type=str,
|
type=str,
|
||||||
default=None,
|
default=None,
|
||||||
help="Database group ID for retrieval (uses default if not specified)"
|
help="Database group ID for retrieval (uses default if not specified)"
|
||||||
@@ -516,7 +516,7 @@ def main():
|
|||||||
# Run benchmark
|
# Run benchmark
|
||||||
result = asyncio.run(run_locomo_benchmark(
|
result = asyncio.run(run_locomo_benchmark(
|
||||||
sample_size=args.sample_size,
|
sample_size=args.sample_size,
|
||||||
group_id=args.group_id,
|
end_user_id=args.end_user_id,
|
||||||
search_type=args.search_type,
|
search_type=args.search_type,
|
||||||
search_limit=args.search_limit,
|
search_limit=args.search_limit,
|
||||||
context_char_budget=args.context_char_budget,
|
context_char_budget=args.context_char_budget,
|
||||||
|
|||||||
@@ -555,7 +555,7 @@ async def run_enhanced_evaluation():
|
|||||||
search_results = await run_hybrid_search(
|
search_results = await run_hybrid_search(
|
||||||
query_text=q,
|
query_text=q,
|
||||||
search_type="hybrid",
|
search_type="hybrid",
|
||||||
group_id="locomo_sk",
|
end_user_id="locomo_sk",
|
||||||
limit=20,
|
limit=20,
|
||||||
include=["statements", "chunks", "entities", "summaries"],
|
include=["statements", "chunks", "entities", "summaries"],
|
||||||
alpha=0.6, # BM25权重
|
alpha=0.6, # BM25权重
|
||||||
|
|||||||
@@ -348,7 +348,7 @@ def select_and_format_information(
|
|||||||
|
|
||||||
async def retrieve_relevant_information(
|
async def retrieve_relevant_information(
|
||||||
question: str,
|
question: str,
|
||||||
group_id: str,
|
end_user_id: str,
|
||||||
search_type: str,
|
search_type: str,
|
||||||
search_limit: int,
|
search_limit: int,
|
||||||
connector: Any,
|
connector: Any,
|
||||||
@@ -368,7 +368,7 @@ async def retrieve_relevant_information(
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
question: Question to search for
|
question: Question to search for
|
||||||
group_id: Database group ID (identifies which conversation memory to search)
|
end_user_id: Database group ID (identifies which conversation memory to search)
|
||||||
search_type: "keyword", "embedding", or "hybrid"
|
search_type: "keyword", "embedding", or "hybrid"
|
||||||
search_limit: Max memory pieces to retrieve
|
search_limit: Max memory pieces to retrieve
|
||||||
connector: Neo4j connector instance
|
connector: Neo4j connector instance
|
||||||
@@ -396,7 +396,7 @@ async def retrieve_relevant_information(
|
|||||||
connector=connector,
|
connector=connector,
|
||||||
embedder_client=embedder,
|
embedder_client=embedder,
|
||||||
query_text=question,
|
query_text=question,
|
||||||
group_id=group_id,
|
end_user_id=end_user_id,
|
||||||
limit=search_limit,
|
limit=search_limit,
|
||||||
include=["chunks", "statements", "entities", "summaries"],
|
include=["chunks", "statements", "entities", "summaries"],
|
||||||
)
|
)
|
||||||
@@ -455,7 +455,7 @@ async def retrieve_relevant_information(
|
|||||||
search_results = await search_graph(
|
search_results = await search_graph(
|
||||||
connector=connector,
|
connector=connector,
|
||||||
q=question,
|
q=question,
|
||||||
group_id=group_id,
|
end_user_id=end_user_id,
|
||||||
limit=search_limit
|
limit=search_limit
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -491,7 +491,7 @@ async def retrieve_relevant_information(
|
|||||||
search_results = await run_hybrid_search(
|
search_results = await run_hybrid_search(
|
||||||
query_text=question,
|
query_text=question,
|
||||||
search_type=search_type,
|
search_type=search_type,
|
||||||
group_id=group_id,
|
end_user_id=end_user_id,
|
||||||
limit=search_limit,
|
limit=search_limit,
|
||||||
include=["chunks", "statements", "entities", "summaries"],
|
include=["chunks", "statements", "entities", "summaries"],
|
||||||
output_path=None,
|
output_path=None,
|
||||||
@@ -524,7 +524,7 @@ async def retrieve_relevant_information(
|
|||||||
connector=connector,
|
connector=connector,
|
||||||
embedder_client=embedder,
|
embedder_client=embedder,
|
||||||
query_text=question,
|
query_text=question,
|
||||||
group_id=group_id,
|
end_user_id=end_user_id,
|
||||||
limit=search_limit,
|
limit=search_limit,
|
||||||
include=["chunks", "statements", "entities", "summaries"],
|
include=["chunks", "statements", "entities", "summaries"],
|
||||||
)
|
)
|
||||||
@@ -584,7 +584,7 @@ async def retrieve_relevant_information(
|
|||||||
|
|
||||||
async def ingest_conversations_if_needed(
|
async def ingest_conversations_if_needed(
|
||||||
conversations: List[str],
|
conversations: List[str],
|
||||||
group_id: str,
|
end_user_id: str,
|
||||||
reset: bool = False
|
reset: bool = False
|
||||||
) -> bool:
|
) -> bool:
|
||||||
"""
|
"""
|
||||||
@@ -603,7 +603,7 @@ async def ingest_conversations_if_needed(
|
|||||||
Args:
|
Args:
|
||||||
conversations: List of raw conversation texts from LoCoMo dataset
|
conversations: List of raw conversation texts from LoCoMo dataset
|
||||||
Example: ["User: I went to Paris. AI: When was that?", ...]
|
Example: ["User: I went to Paris. AI: When was that?", ...]
|
||||||
group_id: Target group ID for database storage
|
end_user_id: Target group ID for database storage
|
||||||
reset: Whether to clear existing data first (not implemented in wrapper)
|
reset: Whether to clear existing data first (not implemented in wrapper)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@@ -617,7 +617,7 @@ async def ingest_conversations_if_needed(
|
|||||||
try:
|
try:
|
||||||
success = await ingest_contexts_via_full_pipeline(
|
success = await ingest_contexts_via_full_pipeline(
|
||||||
contexts=conversations,
|
contexts=conversations,
|
||||||
group_id=group_id,
|
end_user_id=end_user_id,
|
||||||
save_chunk_output=True
|
save_chunk_output=True
|
||||||
)
|
)
|
||||||
return success
|
return success
|
||||||
|
|||||||
@@ -30,7 +30,7 @@ from app.core.memory.storage_services.search import run_hybrid_search
|
|||||||
from app.core.memory.utils.config.definitions import (
|
from app.core.memory.utils.config.definitions import (
|
||||||
PROJECT_ROOT,
|
PROJECT_ROOT,
|
||||||
SELECTED_EMBEDDING_ID,
|
SELECTED_EMBEDDING_ID,
|
||||||
SELECTED_GROUP_ID,
|
SELECTED_end_user_id,
|
||||||
SELECTED_LLM_ID,
|
SELECTED_LLM_ID,
|
||||||
)
|
)
|
||||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||||
@@ -249,7 +249,7 @@ def get_search_params_by_category(category: str):
|
|||||||
|
|
||||||
async def run_locomo_eval(
|
async def run_locomo_eval(
|
||||||
sample_size: int = 1,
|
sample_size: int = 1,
|
||||||
group_id: str | None = None,
|
end_user_id: str | None = None,
|
||||||
search_limit: int = 8,
|
search_limit: int = 8,
|
||||||
context_char_budget: int = 4000, # 保持默认值不变
|
context_char_budget: int = 4000, # 保持默认值不变
|
||||||
llm_temperature: float = 0.0,
|
llm_temperature: float = 0.0,
|
||||||
@@ -262,7 +262,7 @@ async def run_locomo_eval(
|
|||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
|
|
||||||
# 函数内部使用三路检索逻辑,但保持参数签名不变
|
# 函数内部使用三路检索逻辑,但保持参数签名不变
|
||||||
group_id = group_id or SELECTED_GROUP_ID
|
end_user_id = end_user_id or SELECTED_end_user_id
|
||||||
data_path = os.path.join(PROJECT_ROOT, "data", "locomo10.json")
|
data_path = os.path.join(PROJECT_ROOT, "data", "locomo10.json")
|
||||||
if not os.path.exists(data_path):
|
if not os.path.exists(data_path):
|
||||||
data_path = os.path.join(os.getcwd(), "data", "locomo10.json")
|
data_path = os.path.join(os.getcwd(), "data", "locomo10.json")
|
||||||
@@ -340,7 +340,7 @@ async def run_locomo_eval(
|
|||||||
|
|
||||||
# 关键修复:强制重新摄入纯净的对话数据
|
# 关键修复:强制重新摄入纯净的对话数据
|
||||||
print("🔄 强制重新摄入纯净的对话数据...")
|
print("🔄 强制重新摄入纯净的对话数据...")
|
||||||
await ingest_contexts_via_full_pipeline(contents, group_id, save_chunk_output=True)
|
await ingest_contexts_via_full_pipeline(contents, end_user_id, save_chunk_output=True)
|
||||||
|
|
||||||
# 使用异步LLM客户端
|
# 使用异步LLM客户端
|
||||||
with get_db_context() as db:
|
with get_db_context() as db:
|
||||||
@@ -405,7 +405,7 @@ async def run_locomo_eval(
|
|||||||
connector=connector,
|
connector=connector,
|
||||||
embedder_client=embedder,
|
embedder_client=embedder,
|
||||||
query_text=q,
|
query_text=q,
|
||||||
group_id=group_id,
|
end_user_id=end_user_id,
|
||||||
limit=adjusted_limit,
|
limit=adjusted_limit,
|
||||||
include=["chunks", "statements", "entities", "summaries"], # 修复:使用正确的类型
|
include=["chunks", "statements", "entities", "summaries"], # 修复:使用正确的类型
|
||||||
)
|
)
|
||||||
@@ -456,7 +456,7 @@ async def run_locomo_eval(
|
|||||||
search_results = await search_graph(
|
search_results = await search_graph(
|
||||||
connector=connector,
|
connector=connector,
|
||||||
q=q,
|
q=q,
|
||||||
group_id=group_id,
|
end_user_id=end_user_id,
|
||||||
limit=adjusted_limit
|
limit=adjusted_limit
|
||||||
)
|
)
|
||||||
dialogs = search_results.get("dialogues", [])
|
dialogs = search_results.get("dialogues", [])
|
||||||
@@ -486,7 +486,7 @@ async def run_locomo_eval(
|
|||||||
search_results = await run_hybrid_search(
|
search_results = await run_hybrid_search(
|
||||||
query_text=q,
|
query_text=q,
|
||||||
search_type=search_type,
|
search_type=search_type,
|
||||||
group_id=group_id,
|
end_user_id=end_user_id,
|
||||||
limit=adjusted_limit,
|
limit=adjusted_limit,
|
||||||
include=["chunks", "statements", "entities", "summaries"],
|
include=["chunks", "statements", "entities", "summaries"],
|
||||||
output_path=None,
|
output_path=None,
|
||||||
@@ -524,7 +524,7 @@ async def run_locomo_eval(
|
|||||||
connector=connector,
|
connector=connector,
|
||||||
embedder_client=embedder,
|
embedder_client=embedder,
|
||||||
query_text=q,
|
query_text=q,
|
||||||
group_id=group_id,
|
end_user_id=end_user_id,
|
||||||
limit=adjusted_limit,
|
limit=adjusted_limit,
|
||||||
include=["chunks", "statements", "entities", "summaries"],
|
include=["chunks", "statements", "entities", "summaries"],
|
||||||
)
|
)
|
||||||
@@ -597,7 +597,7 @@ async def run_locomo_eval(
|
|||||||
"dialogues": [
|
"dialogues": [
|
||||||
{
|
{
|
||||||
"uuid": d.get("uuid", ""),
|
"uuid": d.get("uuid", ""),
|
||||||
"group_id": d.get("group_id", ""),
|
"end_user_id": d.get("end_user_id", ""),
|
||||||
"content": d.get("content", "")[:200] + "..." if len(d.get("content", "")) > 200 else d.get("content", ""),
|
"content": d.get("content", "")[:200] + "..." if len(d.get("content", "")) > 200 else d.get("content", ""),
|
||||||
"score": d.get("score", 0.0)
|
"score": d.get("score", 0.0)
|
||||||
}
|
}
|
||||||
@@ -795,7 +795,7 @@ async def run_locomo_eval(
|
|||||||
},
|
},
|
||||||
"samples": samples,
|
"samples": samples,
|
||||||
"params": {
|
"params": {
|
||||||
"group_id": group_id,
|
"end_user_id": end_user_id,
|
||||||
"search_limit": search_limit,
|
"search_limit": search_limit,
|
||||||
"context_char_budget": context_char_budget,
|
"context_char_budget": context_char_budget,
|
||||||
"search_type": search_type,
|
"search_type": search_type,
|
||||||
@@ -825,7 +825,7 @@ async def run_locomo_eval(
|
|||||||
def main():
|
def main():
|
||||||
parser = argparse.ArgumentParser(description="Run LoCoMo evaluation with Qwen search")
|
parser = argparse.ArgumentParser(description="Run LoCoMo evaluation with Qwen search")
|
||||||
parser.add_argument("--sample_size", type=int, default=1, help="Number of samples to evaluate")
|
parser.add_argument("--sample_size", type=int, default=1, help="Number of samples to evaluate")
|
||||||
parser.add_argument("--group_id", type=str, default=None, help="Group ID for retrieval")
|
parser.add_argument("--end_user_id", type=str, default=None, help="Group ID for retrieval")
|
||||||
parser.add_argument("--search_limit", type=int, default=8, help="Search limit per query")
|
parser.add_argument("--search_limit", type=int, default=8, help="Search limit per query")
|
||||||
parser.add_argument("--context_char_budget", type=int, default=12000, help="Max characters for context")
|
parser.add_argument("--context_char_budget", type=int, default=12000, help="Max characters for context")
|
||||||
parser.add_argument("--llm_temperature", type=float, default=0.0, help="LLM temperature")
|
parser.add_argument("--llm_temperature", type=float, default=0.0, help="LLM temperature")
|
||||||
@@ -841,7 +841,7 @@ def main():
|
|||||||
|
|
||||||
result = asyncio.run(run_locomo_eval(
|
result = asyncio.run(run_locomo_eval(
|
||||||
sample_size=args.sample_size,
|
sample_size=args.sample_size,
|
||||||
group_id=args.group_id,
|
end_user_id=args.end_user_id,
|
||||||
search_limit=args.search_limit,
|
search_limit=args.search_limit,
|
||||||
context_char_budget=args.context_char_budget,
|
context_char_budget=args.context_char_budget,
|
||||||
llm_temperature=args.llm_temperature,
|
llm_temperature=args.llm_temperature,
|
||||||
|
|||||||
@@ -523,11 +523,11 @@ def generate_query_keywords_cn(question: str) -> List[str]:
|
|||||||
|
|
||||||
|
|
||||||
# 通过别名匹配进行实体关键词检索(多token合并)
|
# 通过别名匹配进行实体关键词检索(多token合并)
|
||||||
async def _search_entities_by_aliases(connector: Neo4jConnector, tokens: List[str], group_id: str | None, limit: int) -> List[Dict[str, Any]]:
|
async def _search_entities_by_aliases(connector: Neo4jConnector, tokens: List[str], end_user_id: str | None, limit: int) -> List[Dict[str, Any]]:
|
||||||
results: List[Dict[str, Any]] = []
|
results: List[Dict[str, Any]] = []
|
||||||
try:
|
try:
|
||||||
for tok in tokens:
|
for tok in tokens:
|
||||||
rows = await connector.execute_query(SEARCH_ENTITIES_BY_NAME, q=tok, group_id=group_id, limit=limit)
|
rows = await connector.execute_query(SEARCH_ENTITIES_BY_NAME, q=tok, end_user_id=end_user_id, limit=limit)
|
||||||
if rows:
|
if rows:
|
||||||
results.extend(rows)
|
results.extend(rows)
|
||||||
except Exception:
|
except Exception:
|
||||||
@@ -547,15 +547,15 @@ async def _search_entities_by_aliases(connector: Neo4jConnector, tokens: List[st
|
|||||||
# 通过对话/陈述中的entity_ids反查实体名称
|
# 通过对话/陈述中的entity_ids反查实体名称
|
||||||
_FETCH_ENTITIES_BY_IDS = """
|
_FETCH_ENTITIES_BY_IDS = """
|
||||||
MATCH (e:ExtractedEntity)
|
MATCH (e:ExtractedEntity)
|
||||||
WHERE e.id IN $ids AND ($group_id IS NULL OR e.group_id = $group_id)
|
WHERE e.id IN $ids AND ($end_user_id IS NULL OR e.end_user_id = $end_user_id)
|
||||||
RETURN e.id AS id, e.name AS name, e.group_id AS group_id, e.entity_type AS entity_type
|
RETURN e.id AS id, e.name AS name, e.end_user_id AS end_user_id, e.entity_type AS entity_type
|
||||||
"""
|
"""
|
||||||
|
|
||||||
async def _fetch_entities_by_ids(connector: Neo4jConnector, ids: List[str], group_id: str | None) -> List[Dict[str, Any]]:
|
async def _fetch_entities_by_ids(connector: Neo4jConnector, ids: List[str], end_user_id: str | None) -> List[Dict[str, Any]]:
|
||||||
if not ids:
|
if not ids:
|
||||||
return []
|
return []
|
||||||
try:
|
try:
|
||||||
rows = await connector.execute_query(_FETCH_ENTITIES_BY_IDS, ids=list({i for i in ids if i}), group_id=group_id)
|
rows = await connector.execute_query(_FETCH_ENTITIES_BY_IDS, ids=list({i for i in ids if i}), end_user_id=end_user_id)
|
||||||
return rows or []
|
return rows or []
|
||||||
except Exception:
|
except Exception:
|
||||||
return []
|
return []
|
||||||
@@ -565,18 +565,18 @@ async def _fetch_entities_by_ids(connector: Neo4jConnector, ids: List[str], grou
|
|||||||
_TIME_ENTITY_SEARCH = """
|
_TIME_ENTITY_SEARCH = """
|
||||||
MATCH (e:ExtractedEntity)
|
MATCH (e:ExtractedEntity)
|
||||||
WHERE e.entity_type CONTAINS "TIME" OR e.entity_type CONTAINS "DATE" OR e.name =~ $date_pattern
|
WHERE e.entity_type CONTAINS "TIME" OR e.entity_type CONTAINS "DATE" OR e.name =~ $date_pattern
|
||||||
AND ($group_id IS NULL OR e.group_id = $group_id)
|
AND ($end_user_id IS NULL OR e.end_user_id = $end_user_id)
|
||||||
RETURN e.id AS id, e.name AS name, e.group_id AS group_id, e.entity_type AS entity_type
|
RETURN e.id AS id, e.name AS name, e.end_user_id AS end_user_id, e.entity_type AS entity_type
|
||||||
LIMIT $limit
|
LIMIT $limit
|
||||||
"""
|
"""
|
||||||
|
|
||||||
async def _search_time_entities(connector: Neo4jConnector, group_id: str | None, limit: int = 5) -> List[Dict[str, Any]]:
|
async def _search_time_entities(connector: Neo4jConnector, end_user_id: str | None, limit: int = 5) -> List[Dict[str, Any]]:
|
||||||
"""专门搜索时间相关的实体"""
|
"""专门搜索时间相关的实体"""
|
||||||
try:
|
try:
|
||||||
date_pattern = r".*\d{4}.*|.*\d{1,2}月\d{1,2}日.*"
|
date_pattern = r".*\d{4}.*|.*\d{1,2}月\d{1,2}日.*"
|
||||||
rows = await connector.execute_query(_TIME_ENTITY_SEARCH,
|
rows = await connector.execute_query(_TIME_ENTITY_SEARCH,
|
||||||
date_pattern=date_pattern,
|
date_pattern=date_pattern,
|
||||||
group_id=group_id,
|
end_user_id=end_user_id,
|
||||||
limit=limit)
|
limit=limit)
|
||||||
return rows or []
|
return rows or []
|
||||||
except Exception:
|
except Exception:
|
||||||
@@ -623,7 +623,7 @@ def _resolve_relative_times_cn_en(text: str, anchor: datetime) -> str:
|
|||||||
|
|
||||||
async def run_longmemeval_test(
|
async def run_longmemeval_test(
|
||||||
sample_size: int = 3,
|
sample_size: int = 3,
|
||||||
group_id: str = "longmemeval_zh_bak_3",
|
end_user_id: str = "longmemeval_zh_bak_3",
|
||||||
search_limit: int = 8,
|
search_limit: int = 8,
|
||||||
context_char_budget: int = 4000,
|
context_char_budget: int = 4000,
|
||||||
llm_temperature: float = 0.0,
|
llm_temperature: float = 0.0,
|
||||||
@@ -677,13 +677,13 @@ async def run_longmemeval_test(
|
|||||||
contexts.extend(selected)
|
contexts.extend(selected)
|
||||||
|
|
||||||
print(f"📥 摄入 {len(contexts)} 个上下文到数据库")
|
print(f"📥 摄入 {len(contexts)} 个上下文到数据库")
|
||||||
if reset_group_before_ingest and group_id:
|
if reset_group_before_ingest and end_user_id:
|
||||||
try:
|
try:
|
||||||
_tmp_conn = Neo4jConnector()
|
_tmp_conn = Neo4jConnector()
|
||||||
await _tmp_conn.delete_group(group_id)
|
await _tmp_conn.delete_group(end_user_id)
|
||||||
print(f"🧹 已清空组 {group_id} 的历史图数据")
|
print(f"🧹 已清空组 {end_user_id} 的历史图数据")
|
||||||
except Exception as _e:
|
except Exception as _e:
|
||||||
print(f"⚠️ 清空组数据失败(忽略继续): {group_id} - {_e}")
|
print(f"⚠️ 清空组数据失败(忽略继续): {end_user_id} - {_e}")
|
||||||
finally:
|
finally:
|
||||||
try:
|
try:
|
||||||
await _tmp_conn.close()
|
await _tmp_conn.close()
|
||||||
@@ -695,7 +695,7 @@ async def run_longmemeval_test(
|
|||||||
else:
|
else:
|
||||||
await _ingest_fn(
|
await _ingest_fn(
|
||||||
contexts,
|
contexts,
|
||||||
group_id,
|
end_user_id,
|
||||||
save_chunk_output=save_chunk_output,
|
save_chunk_output=save_chunk_output,
|
||||||
save_chunk_output_path=save_chunk_output_path,
|
save_chunk_output_path=save_chunk_output_path,
|
||||||
)
|
)
|
||||||
@@ -750,7 +750,7 @@ async def run_longmemeval_test(
|
|||||||
connector=connector,
|
connector=connector,
|
||||||
embedder_client=embedder,
|
embedder_client=embedder,
|
||||||
query_text=question,
|
query_text=question,
|
||||||
group_id=group_id,
|
end_user_id=end_user_id,
|
||||||
limit=search_limit,
|
limit=search_limit,
|
||||||
include=["chunks", "statements", "entities", "summaries"],
|
include=["chunks", "statements", "entities", "summaries"],
|
||||||
)
|
)
|
||||||
@@ -795,7 +795,7 @@ async def run_longmemeval_test(
|
|||||||
search_results = await search_graph(
|
search_results = await search_graph(
|
||||||
connector=connector,
|
connector=connector,
|
||||||
q=question,
|
q=question,
|
||||||
group_id=group_id,
|
end_user_id=end_user_id,
|
||||||
limit=search_limit,
|
limit=search_limit,
|
||||||
)
|
)
|
||||||
chunks = search_results.get("chunks", [])
|
chunks = search_results.get("chunks", [])
|
||||||
@@ -830,7 +830,7 @@ async def run_longmemeval_test(
|
|||||||
connector=connector,
|
connector=connector,
|
||||||
embedder_client=embedder,
|
embedder_client=embedder,
|
||||||
query_text=question,
|
query_text=question,
|
||||||
group_id=group_id,
|
end_user_id=end_user_id,
|
||||||
limit=search_limit,
|
limit=search_limit,
|
||||||
include=["chunks", "statements", "entities", "summaries"],
|
include=["chunks", "statements", "entities", "summaries"],
|
||||||
)
|
)
|
||||||
@@ -848,7 +848,7 @@ async def run_longmemeval_test(
|
|||||||
kw_res = await search_graph(
|
kw_res = await search_graph(
|
||||||
connector=connector,
|
connector=connector,
|
||||||
q=question,
|
q=question,
|
||||||
group_id=group_id,
|
end_user_id=end_user_id,
|
||||||
limit=search_limit,
|
limit=search_limit,
|
||||||
)
|
)
|
||||||
if isinstance(kw_res, dict):
|
if isinstance(kw_res, dict):
|
||||||
@@ -859,7 +859,7 @@ async def run_longmemeval_test(
|
|||||||
# 时间推理问题的特殊处理
|
# 时间推理问题的特殊处理
|
||||||
if is_temporal:
|
if is_temporal:
|
||||||
# 专门搜索时间实体
|
# 专门搜索时间实体
|
||||||
time_entities = await _search_time_entities(connector, group_id, search_limit//2)
|
time_entities = await _search_time_entities(connector, end_user_id, search_limit//2)
|
||||||
if time_entities:
|
if time_entities:
|
||||||
kw_entities.extend(time_entities)
|
kw_entities.extend(time_entities)
|
||||||
# 添加时间相关关键词检索
|
# 添加时间相关关键词检索
|
||||||
@@ -869,7 +869,7 @@ async def run_longmemeval_test(
|
|||||||
time_res = await search_graph(
|
time_res = await search_graph(
|
||||||
connector=connector,
|
connector=connector,
|
||||||
q=tk,
|
q=tk,
|
||||||
group_id=group_id,
|
end_user_id=end_user_id,
|
||||||
limit=2,
|
limit=2,
|
||||||
)
|
)
|
||||||
if isinstance(time_res, dict):
|
if isinstance(time_res, dict):
|
||||||
@@ -880,7 +880,7 @@ async def run_longmemeval_test(
|
|||||||
|
|
||||||
# 中文关键词拆分后做别名匹配
|
# 中文关键词拆分后做别名匹配
|
||||||
cn_tokens = _extract_cn_tokens(question)
|
cn_tokens = _extract_cn_tokens(question)
|
||||||
alias_entities = await _search_entities_by_aliases(connector, cn_tokens, group_id, search_limit)
|
alias_entities = await _search_entities_by_aliases(connector, cn_tokens, end_user_id, search_limit)
|
||||||
if alias_entities:
|
if alias_entities:
|
||||||
kw_entities.extend(alias_entities)
|
kw_entities.extend(alias_entities)
|
||||||
|
|
||||||
@@ -894,7 +894,7 @@ async def run_longmemeval_test(
|
|||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
if ids:
|
if ids:
|
||||||
id_entities = await _fetch_entities_by_ids(connector, ids, group_id)
|
id_entities = await _fetch_entities_by_ids(connector, ids, end_user_id)
|
||||||
if id_entities:
|
if id_entities:
|
||||||
kw_entities.extend(id_entities)
|
kw_entities.extend(id_entities)
|
||||||
|
|
||||||
@@ -908,7 +908,7 @@ async def run_longmemeval_test(
|
|||||||
sub_res = await search_graph(
|
sub_res = await search_graph(
|
||||||
connector=connector,
|
connector=connector,
|
||||||
q=str(kw),
|
q=str(kw),
|
||||||
group_id=group_id,
|
end_user_id=end_user_id,
|
||||||
limit=max(3, search_limit // 2),
|
limit=max(3, search_limit // 2),
|
||||||
)
|
)
|
||||||
if isinstance(sub_res, dict):
|
if isinstance(sub_res, dict):
|
||||||
@@ -927,7 +927,7 @@ async def run_longmemeval_test(
|
|||||||
opt_res = await search_graph(
|
opt_res = await search_graph(
|
||||||
connector=connector,
|
connector=connector,
|
||||||
q=str(opt),
|
q=str(opt),
|
||||||
group_id=group_id,
|
end_user_id=end_user_id,
|
||||||
limit=max(3, search_limit // 2),
|
limit=max(3, search_limit // 2),
|
||||||
)
|
)
|
||||||
if isinstance(opt_res, dict):
|
if isinstance(opt_res, dict):
|
||||||
|
|||||||
@@ -498,11 +498,11 @@ def smart_context_selection(contexts: List[str], question: str, max_chars: int =
|
|||||||
|
|
||||||
|
|
||||||
# 通过别名匹配进行实体关键词检索(多token合并)
|
# 通过别名匹配进行实体关键词检索(多token合并)
|
||||||
async def _search_entities_by_aliases(connector: Neo4jConnector, tokens: List[str], group_id: str | None, limit: int) -> List[Dict[str, Any]]:
|
async def _search_entities_by_aliases(connector: Neo4jConnector, tokens: List[str], end_user_id: str | None, limit: int) -> List[Dict[str, Any]]:
|
||||||
results: List[Dict[str, Any]] = []
|
results: List[Dict[str, Any]] = []
|
||||||
try:
|
try:
|
||||||
for tok in tokens:
|
for tok in tokens:
|
||||||
rows = await connector.execute_query(SEARCH_ENTITIES_BY_NAME, q=tok, group_id=group_id, limit=limit)
|
rows = await connector.execute_query(SEARCH_ENTITIES_BY_NAME, q=tok, end_user_id=end_user_id, limit=limit)
|
||||||
if rows:
|
if rows:
|
||||||
results.extend(rows)
|
results.extend(rows)
|
||||||
except Exception:
|
except Exception:
|
||||||
@@ -522,15 +522,15 @@ async def _search_entities_by_aliases(connector: Neo4jConnector, tokens: List[st
|
|||||||
# 通过对话/陈述中的entity_ids反查实体名称
|
# 通过对话/陈述中的entity_ids反查实体名称
|
||||||
_FETCH_ENTITIES_BY_IDS = """
|
_FETCH_ENTITIES_BY_IDS = """
|
||||||
MATCH (e:ExtractedEntity)
|
MATCH (e:ExtractedEntity)
|
||||||
WHERE e.id IN $ids AND ($group_id IS NULL OR e.group_id = $group_id)
|
WHERE e.id IN $ids AND ($end_user_id IS NULL OR e.end_user_id = $end_user_id)
|
||||||
RETURN e.id AS id, e.name AS name, e.group_id AS group_id, e.entity_type AS entity_type
|
RETURN e.id AS id, e.name AS name, e.end_user_id AS end_user_id, e.entity_type AS entity_type
|
||||||
"""
|
"""
|
||||||
|
|
||||||
async def _fetch_entities_by_ids(connector: Neo4jConnector, ids: List[str], group_id: str | None) -> List[Dict[str, Any]]:
|
async def _fetch_entities_by_ids(connector: Neo4jConnector, ids: List[str], end_user_id: str | None) -> List[Dict[str, Any]]:
|
||||||
if not ids:
|
if not ids:
|
||||||
return []
|
return []
|
||||||
try:
|
try:
|
||||||
rows = await connector.execute_query(_FETCH_ENTITIES_BY_IDS, ids=list({i for i in ids if i}), group_id=group_id)
|
rows = await connector.execute_query(_FETCH_ENTITIES_BY_IDS, ids=list({i for i in ids if i}), end_user_id=end_user_id)
|
||||||
return rows or []
|
return rows or []
|
||||||
except Exception:
|
except Exception:
|
||||||
return []
|
return []
|
||||||
@@ -540,18 +540,18 @@ async def _fetch_entities_by_ids(connector: Neo4jConnector, ids: List[str], grou
|
|||||||
_TIME_ENTITY_SEARCH = """
|
_TIME_ENTITY_SEARCH = """
|
||||||
MATCH (e:ExtractedEntity)
|
MATCH (e:ExtractedEntity)
|
||||||
WHERE e.entity_type CONTAINS "TIME" OR e.entity_type CONTAINS "DATE" OR e.name =~ $date_pattern
|
WHERE e.entity_type CONTAINS "TIME" OR e.entity_type CONTAINS "DATE" OR e.name =~ $date_pattern
|
||||||
AND ($group_id IS NULL OR e.group_id = $group_id)
|
AND ($end_user_id IS NULL OR e.end_user_id = $end_user_id)
|
||||||
RETURN e.id AS id, e.name AS name, e.group_id AS group_id, e.entity_type AS entity_type
|
RETURN e.id AS id, e.name AS name, e.end_user_id AS end_user_id, e.entity_type AS entity_type
|
||||||
LIMIT $limit
|
LIMIT $limit
|
||||||
"""
|
"""
|
||||||
|
|
||||||
async def _search_time_entities(connector: Neo4jConnector, group_id: str | None, limit: int = 5) -> List[Dict[str, Any]]:
|
async def _search_time_entities(connector: Neo4jConnector, end_user_id: str | None, limit: int = 5) -> List[Dict[str, Any]]:
|
||||||
"""专门搜索时间相关的实体"""
|
"""专门搜索时间相关的实体"""
|
||||||
try:
|
try:
|
||||||
date_pattern = r".*\d{4}.*|.*\d{1,2}月\d{1,2}日.*"
|
date_pattern = r".*\d{4}.*|.*\d{1,2}月\d{1,2}日.*"
|
||||||
rows = await connector.execute_query(_TIME_ENTITY_SEARCH,
|
rows = await connector.execute_query(_TIME_ENTITY_SEARCH,
|
||||||
date_pattern=date_pattern,
|
date_pattern=date_pattern,
|
||||||
group_id=group_id,
|
end_user_id=end_user_id,
|
||||||
limit=limit)
|
limit=limit)
|
||||||
return rows or []
|
return rows or []
|
||||||
except Exception:
|
except Exception:
|
||||||
@@ -559,25 +559,25 @@ async def _search_time_entities(connector: Neo4jConnector, group_id: str | None,
|
|||||||
|
|
||||||
|
|
||||||
# 技术术语专门检索
|
# 技术术语专门检索
|
||||||
async def _search_tech_terms(connector: Neo4jConnector, question: str, group_id: str | None, limit: int = 3) -> List[Dict[str, Any]]:
|
async def _search_tech_terms(connector: Neo4jConnector, question: str, end_user_id: str | None, limit: int = 3) -> List[Dict[str, Any]]:
|
||||||
"""专门搜索技术术语相关的实体"""
|
"""专门搜索技术术语相关的实体"""
|
||||||
tech_entities = []
|
tech_entities = []
|
||||||
try:
|
try:
|
||||||
# GPS相关
|
# GPS相关
|
||||||
if any(term in question for term in ["GPS", "导航", "定位系统"]):
|
if any(term in question for term in ["GPS", "导航", "定位系统"]):
|
||||||
gps_rows = await connector.execute_query(SEARCH_ENTITIES_BY_NAME, q="GPS", group_id=group_id, limit=limit)
|
gps_rows = await connector.execute_query(SEARCH_ENTITIES_BY_NAME, q="GPS", end_user_id=end_user_id, limit=limit)
|
||||||
if gps_rows:
|
if gps_rows:
|
||||||
tech_entities.extend(gps_rows)
|
tech_entities.extend(gps_rows)
|
||||||
|
|
||||||
# 活动相关
|
# 活动相关
|
||||||
if any(term in question for term in ["工作坊", "研讨会", "网络研讨会"]):
|
if any(term in question for term in ["工作坊", "研讨会", "网络研讨会"]):
|
||||||
workshop_rows = await connector.execute_query(SEARCH_ENTITIES_BY_NAME, q="工作坊", group_id=group_id, limit=limit)
|
workshop_rows = await connector.execute_query(SEARCH_ENTITIES_BY_NAME, q="工作坊", end_user_id=end_user_id, limit=limit)
|
||||||
if workshop_rows:
|
if workshop_rows:
|
||||||
tech_entities.extend(workshop_rows)
|
tech_entities.extend(workshop_rows)
|
||||||
|
|
||||||
# 时间顺序相关
|
# 时间顺序相关
|
||||||
if any(term in question for term in ["先", "后", "第一个"]):
|
if any(term in question for term in ["先", "后", "第一个"]):
|
||||||
time_rows = await connector.execute_query(SEARCH_ENTITIES_BY_NAME, q="第一次", group_id=group_id, limit=limit)
|
time_rows = await connector.execute_query(SEARCH_ENTITIES_BY_NAME, q="第一次", end_user_id=end_user_id, limit=limit)
|
||||||
if time_rows:
|
if time_rows:
|
||||||
tech_entities.extend(time_rows)
|
tech_entities.extend(time_rows)
|
||||||
|
|
||||||
@@ -627,7 +627,7 @@ def _resolve_relative_times_cn_en(text: str, anchor: datetime) -> str:
|
|||||||
|
|
||||||
async def run_longmemeval_test(
|
async def run_longmemeval_test(
|
||||||
sample_size: int = 3,
|
sample_size: int = 3,
|
||||||
group_id: str = "longmemeval_zh_bak_2",
|
end_user_id: str = "longmemeval_zh_bak_2",
|
||||||
search_limit: int = 8,
|
search_limit: int = 8,
|
||||||
context_char_budget: int = 4000,
|
context_char_budget: int = 4000,
|
||||||
llm_temperature: float = 0.0,
|
llm_temperature: float = 0.0,
|
||||||
@@ -707,7 +707,7 @@ async def run_longmemeval_test(
|
|||||||
connector=connector,
|
connector=connector,
|
||||||
embedder_client=embedder,
|
embedder_client=embedder,
|
||||||
query_text=question,
|
query_text=question,
|
||||||
group_id=group_id,
|
end_user_id=end_user_id,
|
||||||
limit=search_limit,
|
limit=search_limit,
|
||||||
include=["dialogues", "statements", "entities"],
|
include=["dialogues", "statements", "entities"],
|
||||||
)
|
)
|
||||||
@@ -746,7 +746,7 @@ async def run_longmemeval_test(
|
|||||||
search_results = await search_graph(
|
search_results = await search_graph(
|
||||||
connector=connector,
|
connector=connector,
|
||||||
q=question,
|
q=question,
|
||||||
group_id=group_id,
|
end_user_id=end_user_id,
|
||||||
limit=search_limit,
|
limit=search_limit,
|
||||||
)
|
)
|
||||||
dialogs = search_results.get("dialogues", [])
|
dialogs = search_results.get("dialogues", [])
|
||||||
@@ -776,7 +776,7 @@ async def run_longmemeval_test(
|
|||||||
connector=connector,
|
connector=connector,
|
||||||
embedder_client=embedder,
|
embedder_client=embedder,
|
||||||
query_text=question,
|
query_text=question,
|
||||||
group_id=group_id,
|
end_user_id=end_user_id,
|
||||||
limit=search_limit,
|
limit=search_limit,
|
||||||
include=["dialogues", "statements", "entities"],
|
include=["dialogues", "statements", "entities"],
|
||||||
)
|
)
|
||||||
@@ -792,7 +792,7 @@ async def run_longmemeval_test(
|
|||||||
kw_res = await search_graph(
|
kw_res = await search_graph(
|
||||||
connector=connector,
|
connector=connector,
|
||||||
q=question,
|
q=question,
|
||||||
group_id=group_id,
|
end_user_id=end_user_id,
|
||||||
limit=search_limit,
|
limit=search_limit,
|
||||||
)
|
)
|
||||||
if isinstance(kw_res, dict):
|
if isinstance(kw_res, dict):
|
||||||
@@ -801,14 +801,14 @@ async def run_longmemeval_test(
|
|||||||
kw_entities = kw_res.get("entities", []) or []
|
kw_entities = kw_res.get("entities", []) or []
|
||||||
|
|
||||||
# 技术术语专门检索
|
# 技术术语专门检索
|
||||||
tech_entities = await _search_tech_terms(connector, question, group_id, search_limit//2)
|
tech_entities = await _search_tech_terms(connector, question, end_user_id, search_limit//2)
|
||||||
if tech_entities:
|
if tech_entities:
|
||||||
kw_entities.extend(tech_entities)
|
kw_entities.extend(tech_entities)
|
||||||
|
|
||||||
# 时间推理问题的特殊处理
|
# 时间推理问题的特殊处理
|
||||||
if is_temporal:
|
if is_temporal:
|
||||||
# 专门搜索时间实体
|
# 专门搜索时间实体
|
||||||
time_entities = await _search_time_entities(connector, group_id, search_limit//2)
|
time_entities = await _search_time_entities(connector, end_user_id, search_limit//2)
|
||||||
if time_entities:
|
if time_entities:
|
||||||
kw_entities.extend(time_entities)
|
kw_entities.extend(time_entities)
|
||||||
# 添加时间相关关键词检索
|
# 添加时间相关关键词检索
|
||||||
@@ -818,7 +818,7 @@ async def run_longmemeval_test(
|
|||||||
time_res = await search_graph(
|
time_res = await search_graph(
|
||||||
connector=connector,
|
connector=connector,
|
||||||
q=tk,
|
q=tk,
|
||||||
group_id=group_id,
|
end_user_id=end_user_id,
|
||||||
limit=2,
|
limit=2,
|
||||||
)
|
)
|
||||||
if isinstance(time_res, dict):
|
if isinstance(time_res, dict):
|
||||||
@@ -829,7 +829,7 @@ async def run_longmemeval_test(
|
|||||||
|
|
||||||
# 中文关键词拆分后做别名匹配
|
# 中文关键词拆分后做别名匹配
|
||||||
cn_tokens = generate_query_keywords_cn(question) # 使用增强版关键词提取
|
cn_tokens = generate_query_keywords_cn(question) # 使用增强版关键词提取
|
||||||
alias_entities = await _search_entities_by_aliases(connector, cn_tokens, group_id, search_limit)
|
alias_entities = await _search_entities_by_aliases(connector, cn_tokens, end_user_id, search_limit)
|
||||||
if alias_entities:
|
if alias_entities:
|
||||||
kw_entities.extend(alias_entities)
|
kw_entities.extend(alias_entities)
|
||||||
|
|
||||||
@@ -843,7 +843,7 @@ async def run_longmemeval_test(
|
|||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
if ids:
|
if ids:
|
||||||
id_entities = await _fetch_entities_by_ids(connector, ids, group_id)
|
id_entities = await _fetch_entities_by_ids(connector, ids, end_user_id)
|
||||||
if id_entities:
|
if id_entities:
|
||||||
kw_entities.extend(id_entities)
|
kw_entities.extend(id_entities)
|
||||||
|
|
||||||
@@ -857,7 +857,7 @@ async def run_longmemeval_test(
|
|||||||
sub_res = await search_graph(
|
sub_res = await search_graph(
|
||||||
connector=connector,
|
connector=connector,
|
||||||
q=str(kw),
|
q=str(kw),
|
||||||
group_id=group_id,
|
end_user_id=end_user_id,
|
||||||
limit=max(3, search_limit // 2),
|
limit=max(3, search_limit // 2),
|
||||||
)
|
)
|
||||||
if isinstance(sub_res, dict):
|
if isinstance(sub_res, dict):
|
||||||
@@ -876,7 +876,7 @@ async def run_longmemeval_test(
|
|||||||
opt_res = await search_graph(
|
opt_res = await search_graph(
|
||||||
connector=connector,
|
connector=connector,
|
||||||
q=str(opt),
|
q=str(opt),
|
||||||
group_id=group_id,
|
end_user_id=group_id,
|
||||||
limit=max(3, search_limit // 2),
|
limit=max(3, search_limit // 2),
|
||||||
)
|
)
|
||||||
if isinstance(opt_res, dict):
|
if isinstance(opt_res, dict):
|
||||||
|
|||||||
@@ -27,7 +27,7 @@ from app.core.memory.storage_services.search import run_hybrid_search
|
|||||||
from app.core.memory.utils.config.definitions import (
|
from app.core.memory.utils.config.definitions import (
|
||||||
PROJECT_ROOT,
|
PROJECT_ROOT,
|
||||||
SELECTED_EMBEDDING_ID,
|
SELECTED_EMBEDDING_ID,
|
||||||
SELECTED_GROUP_ID,
|
SELECTED_end_user_id,
|
||||||
SELECTED_LLM_ID,
|
SELECTED_LLM_ID,
|
||||||
)
|
)
|
||||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||||
@@ -135,8 +135,8 @@ def _combine_dialogues_for_hybrid(results: Dict[str, Any]) -> List[Dict[str, Any
|
|||||||
return merged
|
return merged
|
||||||
|
|
||||||
|
|
||||||
async def run_memsciqa_eval(sample_size: int = 1, group_id: str | None = None, search_limit: int = 8, context_char_budget: int = 4000, llm_temperature: float = 0.0, llm_max_tokens: int = 64, search_type: str = "hybrid", memory_config: "MemoryConfig" = None) -> Dict[str, Any]:
|
async def run_memsciqa_eval(sample_size: int = 1, end_user_id: str | None = None, search_limit: int = 8, context_char_budget: int = 4000, llm_temperature: float = 0.0, llm_max_tokens: int = 64, search_type: str = "hybrid", memory_config: "MemoryConfig" = None) -> Dict[str, Any]:
|
||||||
group_id = group_id or SELECTED_GROUP_ID
|
end_user_id = end_user_id or SELECTED_end_user_id
|
||||||
# Load data
|
# Load data
|
||||||
data_path = os.path.join(PROJECT_ROOT, "data", "msc_self_instruct.jsonl")
|
data_path = os.path.join(PROJECT_ROOT, "data", "msc_self_instruct.jsonl")
|
||||||
if not os.path.exists(data_path):
|
if not os.path.exists(data_path):
|
||||||
@@ -147,7 +147,7 @@ async def run_memsciqa_eval(sample_size: int = 1, group_id: str | None = None, s
|
|||||||
# 改为:每条样本仅摄入一个上下文(完整对话转录),避免多上下文摄入
|
# 改为:每条样本仅摄入一个上下文(完整对话转录),避免多上下文摄入
|
||||||
# 说明:memsciqa 数据集的每个样本天然只有一个对话,保持按样本一上下文的策略
|
# 说明:memsciqa 数据集的每个样本天然只有一个对话,保持按样本一上下文的策略
|
||||||
contexts: List[str] = [build_context_from_dialog(item) for item in items]
|
contexts: List[str] = [build_context_from_dialog(item) for item in items]
|
||||||
await ingest_contexts_via_full_pipeline(contexts, group_id)
|
await ingest_contexts_via_full_pipeline(contexts, end_user_id)
|
||||||
|
|
||||||
# LLM client (使用异步调用)
|
# LLM client (使用异步调用)
|
||||||
with get_db_context() as db:
|
with get_db_context() as db:
|
||||||
@@ -173,7 +173,7 @@ async def run_memsciqa_eval(sample_size: int = 1, group_id: str | None = None, s
|
|||||||
results = await run_hybrid_search(
|
results = await run_hybrid_search(
|
||||||
query_text=question,
|
query_text=question,
|
||||||
search_type=search_type,
|
search_type=search_type,
|
||||||
group_id=group_id,
|
end_user_id=end_user_id,
|
||||||
limit=search_limit,
|
limit=search_limit,
|
||||||
include=["dialogues", "statements", "entities"],
|
include=["dialogues", "statements", "entities"],
|
||||||
output_path=None,
|
output_path=None,
|
||||||
@@ -298,7 +298,7 @@ def main():
|
|||||||
load_dotenv()
|
load_dotenv()
|
||||||
parser = argparse.ArgumentParser(description="Evaluate DMR (memsciqa) with graph search and Qwen")
|
parser = argparse.ArgumentParser(description="Evaluate DMR (memsciqa) with graph search and Qwen")
|
||||||
parser.add_argument("--sample-size", type=int, default=1, help="评测样本数量")
|
parser.add_argument("--sample-size", type=int, default=1, help="评测样本数量")
|
||||||
parser.add_argument("--group-id", type=str, default=None, help="可选 group_id,默认取 runtime.json")
|
parser.add_argument("--group-id", type=str, default=None, help="可选 end_user_id,默认取 runtime.json")
|
||||||
parser.add_argument("--search-limit", type=int, default=8, help="每类检索最大返回数")
|
parser.add_argument("--search-limit", type=int, default=8, help="每类检索最大返回数")
|
||||||
parser.add_argument("--context-char-budget", type=int, default=4000, help="上下文字符预算")
|
parser.add_argument("--context-char-budget", type=int, default=4000, help="上下文字符预算")
|
||||||
parser.add_argument("--llm-temperature", type=float, default=0.0, help="LLM 温度")
|
parser.add_argument("--llm-temperature", type=float, default=0.0, help="LLM 温度")
|
||||||
@@ -309,7 +309,7 @@ def main():
|
|||||||
result = asyncio.run(
|
result = asyncio.run(
|
||||||
run_memsciqa_eval(
|
run_memsciqa_eval(
|
||||||
sample_size=args.sample_size,
|
sample_size=args.sample_size,
|
||||||
group_id=args.group_id,
|
end_user_id=args.end_user_id,
|
||||||
search_limit=args.search_limit,
|
search_limit=args.search_limit,
|
||||||
context_char_budget=args.context_char_budget,
|
context_char_budget=args.context_char_budget,
|
||||||
llm_temperature=args.llm_temperature,
|
llm_temperature=args.llm_temperature,
|
||||||
|
|||||||
@@ -33,7 +33,7 @@ from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
|
|||||||
from app.core.memory.utils.config.definitions import (
|
from app.core.memory.utils.config.definitions import (
|
||||||
PROJECT_ROOT,
|
PROJECT_ROOT,
|
||||||
SELECTED_EMBEDDING_ID,
|
SELECTED_EMBEDDING_ID,
|
||||||
SELECTED_GROUP_ID,
|
SELECTED_end_user_id,
|
||||||
SELECTED_LLM_ID,
|
SELECTED_LLM_ID,
|
||||||
)
|
)
|
||||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||||
@@ -198,7 +198,7 @@ def load_dataset_memsciqa(data_path: str) -> List[Dict[str, Any]]:
|
|||||||
|
|
||||||
async def run_memsciqa_test(
|
async def run_memsciqa_test(
|
||||||
sample_size: int = 3,
|
sample_size: int = 3,
|
||||||
group_id: str | None = None,
|
end_user_id: str | None = None,
|
||||||
search_limit: int = 8,
|
search_limit: int = 8,
|
||||||
context_char_budget: int = 4000,
|
context_char_budget: int = 4000,
|
||||||
llm_temperature: float = 0.0,
|
llm_temperature: float = 0.0,
|
||||||
@@ -216,7 +216,7 @@ async def run_memsciqa_test(
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
# 默认使用指定的 memsci 组 ID
|
# 默认使用指定的 memsci 组 ID
|
||||||
group_id = group_id or "group_memsci"
|
end_user_id = end_user_id or "group_memsci"
|
||||||
|
|
||||||
# 数据路径解析(项目根与当前工作目录兜底)
|
# 数据路径解析(项目根与当前工作目录兜底)
|
||||||
if not data_path:
|
if not data_path:
|
||||||
@@ -282,7 +282,7 @@ async def run_memsciqa_test(
|
|||||||
connector=connector,
|
connector=connector,
|
||||||
embedder_client=embedder,
|
embedder_client=embedder,
|
||||||
query_text=question,
|
query_text=question,
|
||||||
group_id=group_id,
|
end_user_id=end_user_id,
|
||||||
limit=search_limit,
|
limit=search_limit,
|
||||||
include=["chunks", "statements", "entities", "summaries"], # 使用 chunks 而不是 dialogues
|
include=["chunks", "statements", "entities", "summaries"], # 使用 chunks 而不是 dialogues
|
||||||
)
|
)
|
||||||
@@ -291,7 +291,7 @@ async def run_memsciqa_test(
|
|||||||
results = await search_graph(
|
results = await search_graph(
|
||||||
connector=connector,
|
connector=connector,
|
||||||
q=question,
|
q=question,
|
||||||
group_id=group_id,
|
end_user_id=end_user_id,
|
||||||
limit=search_limit,
|
limit=search_limit,
|
||||||
include=["chunks", "statements", "entities", "summaries"], # 使用 chunks 而不是 dialogues
|
include=["chunks", "statements", "entities", "summaries"], # 使用 chunks 而不是 dialogues
|
||||||
)
|
)
|
||||||
@@ -499,7 +499,7 @@ async def run_memsciqa_test(
|
|||||||
},
|
},
|
||||||
"samples": samples,
|
"samples": samples,
|
||||||
"params": {
|
"params": {
|
||||||
"group_id": group_id,
|
"end_user_id": end_user_id,
|
||||||
"search_limit": search_limit,
|
"search_limit": search_limit,
|
||||||
"context_char_budget": context_char_budget,
|
"context_char_budget": context_char_budget,
|
||||||
"llm_temperature": llm_temperature,
|
"llm_temperature": llm_temperature,
|
||||||
@@ -542,7 +542,7 @@ def main():
|
|||||||
result = asyncio.run(
|
result = asyncio.run(
|
||||||
run_memsciqa_test(
|
run_memsciqa_test(
|
||||||
sample_size=sample_size,
|
sample_size=sample_size,
|
||||||
group_id=args.group_id,
|
end_user_id=args.end_user_id,
|
||||||
search_limit=args.search_limit,
|
search_limit=args.search_limit,
|
||||||
context_char_budget=args.context_char_budget,
|
context_char_budget=args.context_char_budget,
|
||||||
llm_temperature=args.llm_temperature,
|
llm_temperature=args.llm_temperature,
|
||||||
|
|||||||
@@ -15,7 +15,7 @@ except Exception:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||||
from app.core.memory.utils.config.definitions import SELECTED_GROUP_ID, PROJECT_ROOT
|
from app.core.memory.utils.config.definitions import SELECTED_end_user_id, PROJECT_ROOT
|
||||||
|
|
||||||
from app.core.memory.evaluation.memsciqa.evaluate_qa import run_memsciqa_eval
|
from app.core.memory.evaluation.memsciqa.evaluate_qa import run_memsciqa_eval
|
||||||
from app.core.memory.evaluation.longmemeval.qwen_search_eval import run_longmemeval_test
|
from app.core.memory.evaluation.longmemeval.qwen_search_eval import run_longmemeval_test
|
||||||
@@ -26,7 +26,7 @@ async def run(
|
|||||||
dataset: str,
|
dataset: str,
|
||||||
sample_size: int,
|
sample_size: int,
|
||||||
reset_group: bool,
|
reset_group: bool,
|
||||||
group_id: str | None,
|
end_user_id: str | None,
|
||||||
judge_model: str | None = None,
|
judge_model: str | None = None,
|
||||||
search_limit: int | None = None,
|
search_limit: int | None = None,
|
||||||
context_char_budget: int | None = None,
|
context_char_budget: int | None = None,
|
||||||
@@ -37,17 +37,17 @@ async def run(
|
|||||||
max_contexts_per_item: int | None = None,
|
max_contexts_per_item: int | None = None,
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
# 恢复原始风格:统一入口做路由,并沿用各数据集既有默认
|
# 恢复原始风格:统一入口做路由,并沿用各数据集既有默认
|
||||||
group_id = group_id or SELECTED_GROUP_ID
|
end_user_id = end_user_id or SELECTED_end_user_id
|
||||||
|
|
||||||
if reset_group:
|
if reset_group:
|
||||||
connector = Neo4jConnector()
|
connector = Neo4jConnector()
|
||||||
try:
|
try:
|
||||||
await connector.delete_group(group_id)
|
await connector.delete_group(end_user_id)
|
||||||
finally:
|
finally:
|
||||||
await connector.close()
|
await connector.close()
|
||||||
|
|
||||||
if dataset == "locomo":
|
if dataset == "locomo":
|
||||||
kwargs: Dict[str, Any] = {"sample_size": sample_size, "group_id": group_id}
|
kwargs: Dict[str, Any] = {"sample_size": sample_size, "end_user_id": end_user_id}
|
||||||
if search_limit is not None:
|
if search_limit is not None:
|
||||||
kwargs["search_limit"] = search_limit
|
kwargs["search_limit"] = search_limit
|
||||||
if context_char_budget is not None:
|
if context_char_budget is not None:
|
||||||
@@ -61,7 +61,7 @@ async def run(
|
|||||||
return await run_locomo_eval(**kwargs)
|
return await run_locomo_eval(**kwargs)
|
||||||
|
|
||||||
if dataset == "memsciqa":
|
if dataset == "memsciqa":
|
||||||
kwargs: Dict[str, Any] = {"sample_size": sample_size, "group_id": group_id}
|
kwargs: Dict[str, Any] = {"sample_size": sample_size, "end_user_id": end_user_id}
|
||||||
if search_limit is not None:
|
if search_limit is not None:
|
||||||
kwargs["search_limit"] = search_limit
|
kwargs["search_limit"] = search_limit
|
||||||
if context_char_budget is not None:
|
if context_char_budget is not None:
|
||||||
@@ -75,7 +75,7 @@ async def run(
|
|||||||
return await run_memsciqa_eval(**kwargs)
|
return await run_memsciqa_eval(**kwargs)
|
||||||
|
|
||||||
if dataset == "longmemeval":
|
if dataset == "longmemeval":
|
||||||
kwargs: Dict[str, Any] = {"sample_size": sample_size, "group_id": group_id}
|
kwargs: Dict[str, Any] = {"sample_size": sample_size, "end_user_id": end_user_id}
|
||||||
if search_limit is not None:
|
if search_limit is not None:
|
||||||
kwargs["search_limit"] = search_limit
|
kwargs["search_limit"] = search_limit
|
||||||
if context_char_budget is not None:
|
if context_char_budget is not None:
|
||||||
@@ -99,8 +99,8 @@ def main():
|
|||||||
parser = argparse.ArgumentParser(description="统一评估入口:memsciqa / longmemeval / locomo")
|
parser = argparse.ArgumentParser(description="统一评估入口:memsciqa / longmemeval / locomo")
|
||||||
parser.add_argument("--dataset", choices=["memsciqa", "longmemeval", "locomo"], required=True)
|
parser.add_argument("--dataset", choices=["memsciqa", "longmemeval", "locomo"], required=True)
|
||||||
parser.add_argument("--sample-size", type=int, default=1, help="先用一条数据跑通")
|
parser.add_argument("--sample-size", type=int, default=1, help="先用一条数据跑通")
|
||||||
parser.add_argument("--reset-group", action="store_true", help="运行前清空当前 group_id 的图数据")
|
parser.add_argument("--reset-group", action="store_true", help="运行前清空当前 end_user_id 的图数据")
|
||||||
parser.add_argument("--group-id", type=str, default=None, help="可选 group_id,默认取 runtime.json")
|
parser.add_argument("--group-id", type=str, default=None, help="可选 end_user_id,默认取 runtime.json")
|
||||||
parser.add_argument("--judge-model", type=str, default=None, help="可选:longmemeval 判别式评测模型名")
|
parser.add_argument("--judge-model", type=str, default=None, help="可选:longmemeval 判别式评测模型名")
|
||||||
parser.add_argument("--search-limit", type=int, default=None, help="检索返回的对话节点数量上限(不提供则使用各脚本默认)")
|
parser.add_argument("--search-limit", type=int, default=None, help="检索返回的对话节点数量上限(不提供则使用各脚本默认)")
|
||||||
parser.add_argument("--context-char-budget", type=int, default=None, help="上下文字符预算(不提供则使用各脚本默认)")
|
parser.add_argument("--context-char-budget", type=int, default=None, help="上下文字符预算(不提供则使用各脚本默认)")
|
||||||
@@ -117,7 +117,7 @@ def main():
|
|||||||
args.dataset,
|
args.dataset,
|
||||||
args.sample_size,
|
args.sample_size,
|
||||||
args.reset_group,
|
args.reset_group,
|
||||||
args.group_id,
|
args.end_user_id,
|
||||||
args.judge_model,
|
args.judge_model,
|
||||||
args.search_limit,
|
args.search_limit,
|
||||||
args.context_char_budget,
|
args.context_char_budget,
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import os
|
|||||||
import asyncio
|
import asyncio
|
||||||
import json
|
import json
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import logging
|
||||||
|
|
||||||
# Fix tokenizer parallelism warning
|
# Fix tokenizer parallelism warning
|
||||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||||
@@ -23,28 +24,29 @@ from app.core.memory.models.message_models import DialogData, Chunk
|
|||||||
try:
|
try:
|
||||||
from app.core.memory.llm_tools.openai_client import OpenAIClient
|
from app.core.memory.llm_tools.openai_client import OpenAIClient
|
||||||
except Exception:
|
except Exception:
|
||||||
# 在测试或无可用依赖(如 langfuse)环境下,允许惰性导入
|
|
||||||
OpenAIClient = Any
|
OpenAIClient = Any
|
||||||
|
|
||||||
|
# Initialize logger
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class LLMChunker:
|
class LLMChunker:
|
||||||
"""基于LLM的智能分块策略"""
|
"""LLM-based intelligent chunking strategy"""
|
||||||
def __init__(self, llm_client: OpenAIClient, chunk_size: int = 1000):
|
def __init__(self, llm_client: OpenAIClient, chunk_size: int = 1000):
|
||||||
self.llm_client = llm_client
|
self.llm_client = llm_client
|
||||||
self.chunk_size = chunk_size
|
self.chunk_size = chunk_size
|
||||||
|
|
||||||
async def __call__(self, text: str) -> List[Any]:
|
async def __call__(self, text: str) -> List[Any]:
|
||||||
# 使用LLM分析文本结构并进行智能分块
|
|
||||||
prompt = f"""
|
prompt = f"""
|
||||||
请将以下文本分割成语义连贯的段落。每个段落应该围绕一个主题,长度大约在{self.chunk_size}字符左右。
|
Split the following text into semantically coherent paragraphs. Each paragraph should focus on one topic, approximately {self.chunk_size} characters long.
|
||||||
请以JSON格式返回结果,包含chunks数组,每个chunk有text字段。
|
Return results in JSON format with a chunks array, each chunk having a text field.
|
||||||
|
|
||||||
文本内容:
|
Text content:
|
||||||
{text[:5000]}
|
{text[:5000]}
|
||||||
"""
|
"""
|
||||||
|
|
||||||
messages = [
|
messages = [
|
||||||
{"role": "system", "content": "你是一个专业的文本分析助手,擅长将长文本分割成语义连贯的段落。"},
|
{"role": "system", "content": "You are a professional text analysis assistant, skilled at splitting long texts into semantically coherent paragraphs."},
|
||||||
{"role": "user", "content": prompt}
|
{"role": "user", "content": prompt}
|
||||||
]
|
]
|
||||||
|
|
||||||
@@ -171,8 +173,6 @@ class ChunkerClient:
|
|||||||
base_chunk_size=self.chunk_size,
|
base_chunk_size=self.chunk_size,
|
||||||
)
|
)
|
||||||
elif chunker_config.chunker_strategy == "SentenceChunker":
|
elif chunker_config.chunker_strategy == "SentenceChunker":
|
||||||
# 某些 chonkie 版本的 SentenceChunker 不支持 tokenizer_or_token_counter 参数
|
|
||||||
# 为了兼容不同版本,这里仅传递广泛支持的参数
|
|
||||||
self.chunker = SentenceChunker(
|
self.chunker = SentenceChunker(
|
||||||
chunk_size=self.chunk_size,
|
chunk_size=self.chunk_size,
|
||||||
chunk_overlap=self.chunk_overlap,
|
chunk_overlap=self.chunk_overlap,
|
||||||
@@ -186,100 +186,93 @@ class ChunkerClient:
|
|||||||
|
|
||||||
async def generate_chunks(self, dialogue: DialogData):
|
async def generate_chunks(self, dialogue: DialogData):
|
||||||
"""
|
"""
|
||||||
生成分块,支持异步操作
|
Generate chunks following 1 Message = 1 Chunk strategy.
|
||||||
|
|
||||||
|
Each message creates one chunk, directly inheriting role information.
|
||||||
|
If a message is too long, it will be split into multiple sub-chunks,
|
||||||
|
each maintaining the same speaker.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If dialogue has no messages or chunking fails
|
||||||
"""
|
"""
|
||||||
|
# Validate dialogue has messages
|
||||||
|
if not dialogue.context or not dialogue.context.msgs:
|
||||||
|
raise ValueError(
|
||||||
|
f"Dialogue {dialogue.ref_id} has no messages. "
|
||||||
|
f"Cannot generate chunks from empty dialogue."
|
||||||
|
)
|
||||||
|
|
||||||
|
dialogue.chunks = []
|
||||||
|
|
||||||
|
# 按消息分块:每个消息创建一个或多个 chunk,直接继承角色
|
||||||
|
for msg_idx, msg in enumerate(dialogue.context.msgs):
|
||||||
|
# Validate message has required attributes
|
||||||
|
if not hasattr(msg, 'role') or not hasattr(msg, 'msg'):
|
||||||
|
raise ValueError(
|
||||||
|
f"Message {msg_idx} in dialogue {dialogue.ref_id} "
|
||||||
|
f"missing 'role' or 'msg' attribute"
|
||||||
|
)
|
||||||
|
|
||||||
|
msg_content = msg.msg.strip()
|
||||||
|
|
||||||
|
# Skip empty messages
|
||||||
|
if not msg_content:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 如果消息太长,可以进一步分块
|
||||||
|
if len(msg_content) > self.chunk_size:
|
||||||
|
# 对单个消息的内容进行分块
|
||||||
try:
|
try:
|
||||||
# 预处理文本:确保对话标记格式统一
|
sub_chunks = self.chunker(msg_content)
|
||||||
content = dialogue.content
|
except Exception as e:
|
||||||
content = content.replace('AI:', 'AI:').replace('用户:', '用户:') # 统一冒号
|
raise ValueError(
|
||||||
content = re.sub(r'(\n\s*)+\n', '\n\n', content) # 合并多个空行
|
f"Failed to chunk long message {msg_idx} in dialogue {dialogue.ref_id}: {e}"
|
||||||
|
)
|
||||||
|
|
||||||
if hasattr(self.chunker, '__call__') and not asyncio.iscoroutinefunction(self.chunker.__call__):
|
for idx, sub_chunk in enumerate(sub_chunks):
|
||||||
# 同步分块器
|
sub_chunk_text = sub_chunk.text if hasattr(sub_chunk, 'text') else str(sub_chunk)
|
||||||
chunks = self.chunker(content)
|
sub_chunk_text = sub_chunk_text.strip()
|
||||||
else:
|
|
||||||
# 异步分块器(如LLMChunker)
|
|
||||||
chunks = await self.chunker(content)
|
|
||||||
|
|
||||||
# 过滤空块和过小的块
|
if len(sub_chunk_text) < (self.min_characters_per_chunk or 50):
|
||||||
valid_chunks = []
|
continue
|
||||||
for c in chunks:
|
|
||||||
chunk_text = getattr(c, 'text', str(c)) if not isinstance(c, str) else c
|
|
||||||
if isinstance(chunk_text, str) and len(chunk_text.strip()) >= (self.min_characters_per_chunk or 50):
|
|
||||||
valid_chunks.append(c)
|
|
||||||
|
|
||||||
dialogue.chunks = [
|
chunk = Chunk(
|
||||||
Chunk(
|
content=f"{msg.role}: {sub_chunk_text}",
|
||||||
content=c.text if hasattr(c, 'text') else str(c),
|
speaker=msg.role, # 直接继承角色
|
||||||
metadata={
|
metadata={
|
||||||
"start_index": getattr(c, "start_index", None),
|
"message_index": msg_idx,
|
||||||
"end_index": getattr(c, "end_index", None),
|
"message_role": msg.role,
|
||||||
|
"sub_chunk_index": idx,
|
||||||
|
"total_sub_chunks": len(sub_chunks),
|
||||||
"chunker_strategy": self.chunker_config.chunker_strategy,
|
"chunker_strategy": self.chunker_config.chunker_strategy,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
for c in valid_chunks
|
dialogue.chunks.append(chunk)
|
||||||
]
|
|
||||||
return dialogue
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"分块失败: {e}")
|
|
||||||
|
|
||||||
# 改进的后备方案:尝试按对话回合分割
|
|
||||||
try:
|
|
||||||
# 简单的按对话分割
|
|
||||||
dialogue_pattern = r'(AI:|用户:)(.*?)(?=AI:|用户:|$)'
|
|
||||||
matches = re.findall(dialogue_pattern, dialogue.content, re.DOTALL)
|
|
||||||
|
|
||||||
class SimpleChunk:
|
|
||||||
def __init__(self, text, start_index, end_index):
|
|
||||||
self.text = text
|
|
||||||
self.start_index = start_index
|
|
||||||
self.end_index = end_index
|
|
||||||
|
|
||||||
chunks = []
|
|
||||||
current_chunk = ""
|
|
||||||
current_start = 0
|
|
||||||
|
|
||||||
for match in matches:
|
|
||||||
speaker, ct = match[0], match[1].strip()
|
|
||||||
turn_text = f"{speaker} {ct}"
|
|
||||||
|
|
||||||
if len(current_chunk) + len(turn_text) > (self.chunk_size or 500):
|
|
||||||
if current_chunk:
|
|
||||||
chunks.append(SimpleChunk(current_chunk, current_start, current_start + len(current_chunk)))
|
|
||||||
current_chunk = turn_text
|
|
||||||
current_start = dialogue.content.find(turn_text, current_start)
|
|
||||||
else:
|
else:
|
||||||
current_chunk += ("\n" + turn_text) if current_chunk else turn_text
|
# 消息不长,直接作为一个 chunk
|
||||||
|
chunk = Chunk(
|
||||||
if current_chunk:
|
content=f"{msg.role}: {msg_content}",
|
||||||
chunks.append(SimpleChunk(current_chunk, current_start, current_start + len(current_chunk)))
|
speaker=msg.role, # 直接继承角色
|
||||||
|
|
||||||
dialogue.chunks = [
|
|
||||||
Chunk(
|
|
||||||
content=c.text,
|
|
||||||
metadata={
|
metadata={
|
||||||
"start_index": c.start_index,
|
"message_index": msg_idx,
|
||||||
"end_index": c.end_index,
|
"message_role": msg.role,
|
||||||
"chunker_strategy": "DialogueTurnFallback",
|
"chunker_strategy": self.chunker_config.chunker_strategy,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
for c in chunks
|
dialogue.chunks.append(chunk)
|
||||||
]
|
|
||||||
|
|
||||||
except Exception:
|
# Validate we generated at least one chunk
|
||||||
# 最后的手段:单一大块
|
if not dialogue.chunks:
|
||||||
dialogue.chunks = [Chunk(
|
raise ValueError(
|
||||||
content=dialogue.content,
|
f"No valid chunks generated for dialogue {dialogue.ref_id}. "
|
||||||
metadata={"chunker_strategy": "SingleChunkFallback"},
|
f"All messages were either empty or too short. "
|
||||||
)]
|
f"Messages count: {len(dialogue.context.msgs)}"
|
||||||
|
)
|
||||||
|
|
||||||
return dialogue
|
return dialogue
|
||||||
|
|
||||||
def evaluate_chunking(self, dialogue: DialogData) -> dict:
|
def evaluate_chunking(self, dialogue: DialogData) -> dict:
|
||||||
"""
|
"""Evaluate chunking quality."""
|
||||||
评估分块质量
|
|
||||||
"""
|
|
||||||
if not getattr(dialogue, 'chunks', None):
|
if not getattr(dialogue, 'chunks', None):
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
@@ -304,11 +297,8 @@ class ChunkerClient:
|
|||||||
return metrics
|
return metrics
|
||||||
|
|
||||||
def save_chunking_results(self, dialogue: DialogData, output_path: str):
|
def save_chunking_results(self, dialogue: DialogData, output_path: str):
|
||||||
"""
|
"""Save chunking results to file with strategy name in filename."""
|
||||||
保存分块结果到文件,文件名包含策略名称
|
|
||||||
"""
|
|
||||||
strategy_name = self.chunker_config.chunker_strategy
|
strategy_name = self.chunker_config.chunker_strategy
|
||||||
# 在文件名中添加策略名称
|
|
||||||
base_name, ext = os.path.splitext(output_path)
|
base_name, ext = os.path.splitext(output_path)
|
||||||
strategy_output_path = f"{base_name}_{strategy_name}{ext}"
|
strategy_output_path = f"{base_name}_{strategy_name}{ext}"
|
||||||
|
|
||||||
|
|||||||
@@ -92,8 +92,6 @@ class OpenAIClient(LLMClient):
|
|||||||
config["callbacks"] = [self.langfuse_handler]
|
config["callbacks"] = [self.langfuse_handler]
|
||||||
|
|
||||||
response = await chain.ainvoke({"messages": messages}, config=config)
|
response = await chain.ainvoke({"messages": messages}, config=config)
|
||||||
|
|
||||||
logger.debug(f"LLM 响应成功: {len(str(response))} 字符")
|
|
||||||
return response
|
return response
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -149,13 +147,10 @@ class OpenAIClient(LLMClient):
|
|||||||
config=config
|
config=config
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.debug(f"使用 PydanticOutputParser 解析成功")
|
|
||||||
return parsed
|
return parsed
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(
|
logger.debug(f"PydanticOutputParser 解析失败,尝试备用方法: {e}")
|
||||||
f"PydanticOutputParser 解析失败,尝试其他方法: {e}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# 方法 2: 使用 LangChain 的 with_structured_output
|
# 方法 2: 使用 LangChain 的 with_structured_output
|
||||||
template = """{question}"""
|
template = """{question}"""
|
||||||
@@ -173,13 +168,17 @@ class OpenAIClient(LLMClient):
|
|||||||
|
|
||||||
# 验证并返回结果
|
# 验证并返回结果
|
||||||
try:
|
try:
|
||||||
return response_model.model_validate(parsed)
|
result = response_model.model_validate(parsed)
|
||||||
|
return result
|
||||||
except Exception:
|
except Exception:
|
||||||
# 如果已经是 Pydantic 实例,直接返回
|
# 如果已经是 Pydantic 实例,直接返回
|
||||||
if hasattr(parsed, "model_dump"):
|
if hasattr(parsed, "model_dump"):
|
||||||
return parsed
|
return parsed
|
||||||
# 尝试从 JSON 解析
|
# 尝试从 JSON 解析
|
||||||
return response_model.model_validate_json(json.dumps(parsed))
|
result = response_model.model_validate_json(json.dumps(parsed))
|
||||||
|
return result
|
||||||
|
else:
|
||||||
|
logger.warning("with_structured_output 方法不可用")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"结构化输出失败: {e}")
|
logger.error(f"结构化输出失败: {e}")
|
||||||
|
|||||||
@@ -72,7 +72,7 @@ class TemporalSearchParams(BaseModel):
|
|||||||
"""Parameters for temporal search queries in the knowledge graph.
|
"""Parameters for temporal search queries in the knowledge graph.
|
||||||
|
|
||||||
Attributes:
|
Attributes:
|
||||||
group_id: Group ID to filter search results (default: 'test')
|
end_user_id: Group ID to filter search results (default: 'test')
|
||||||
apply_id: Application ID to filter search results
|
apply_id: Application ID to filter search results
|
||||||
user_id: User ID to filter search results
|
user_id: User ID to filter search results
|
||||||
start_date: Start date for temporal filtering (format: 'YYYY-MM-DD')
|
start_date: Start date for temporal filtering (format: 'YYYY-MM-DD')
|
||||||
@@ -81,7 +81,7 @@ class TemporalSearchParams(BaseModel):
|
|||||||
invalid_date: Date when memory should be invalid (format: 'YYYY-MM-DD')
|
invalid_date: Date when memory should be invalid (format: 'YYYY-MM-DD')
|
||||||
limit: Maximum number of results to return (default: 3)
|
limit: Maximum number of results to return (default: 3)
|
||||||
"""
|
"""
|
||||||
group_id: Optional[str] = Field("test", description="The group ID to filter the search.")
|
end_user_id: Optional[str] = Field("test", description="The group ID to filter the search.")
|
||||||
apply_id: Optional[str] = Field(None, description="The apply ID to filter the search.")
|
apply_id: Optional[str] = Field(None, description="The apply ID to filter the search.")
|
||||||
user_id: Optional[str] = Field(None, description="The user ID to filter the search.")
|
user_id: Optional[str] = Field(None, description="The user ID to filter the search.")
|
||||||
start_date: Optional[str] = Field(None, description="The start date for the search.")
|
start_date: Optional[str] = Field(None, description="The start date for the search.")
|
||||||
|
|||||||
@@ -103,9 +103,7 @@ class Edge(BaseModel):
|
|||||||
id: Unique identifier for the edge
|
id: Unique identifier for the edge
|
||||||
source: ID of the source node
|
source: ID of the source node
|
||||||
target: ID of the target node
|
target: ID of the target node
|
||||||
group_id: Group ID for multi-tenancy
|
end_user_id: End user ID for multi-tenancy
|
||||||
user_id: User ID for user-specific data
|
|
||||||
apply_id: Application ID for application-specific data
|
|
||||||
run_id: Unique identifier for the pipeline run that created this edge
|
run_id: Unique identifier for the pipeline run that created this edge
|
||||||
created_at: Timestamp when the edge was created (system perspective)
|
created_at: Timestamp when the edge was created (system perspective)
|
||||||
expired_at: Optional timestamp when the edge expires (system perspective)
|
expired_at: Optional timestamp when the edge expires (system perspective)
|
||||||
@@ -113,9 +111,7 @@ class Edge(BaseModel):
|
|||||||
id: str = Field(default_factory=lambda: uuid4().hex, description="A unique identifier for the edge.")
|
id: str = Field(default_factory=lambda: uuid4().hex, description="A unique identifier for the edge.")
|
||||||
source: str = Field(..., description="The ID of the source node.")
|
source: str = Field(..., description="The ID of the source node.")
|
||||||
target: str = Field(..., description="The ID of the target node.")
|
target: str = Field(..., description="The ID of the target node.")
|
||||||
group_id: str = Field(..., description="The group ID of the edge.")
|
end_user_id: str = Field(..., description="The end user ID of the edge.")
|
||||||
user_id: str = Field(..., description="The user ID of the edge.")
|
|
||||||
apply_id: str = Field(..., description="The apply ID of the edge.")
|
|
||||||
run_id: str = Field(default_factory=lambda: uuid4().hex, description="Unique identifier for this pipeline run.")
|
run_id: str = Field(default_factory=lambda: uuid4().hex, description="Unique identifier for this pipeline run.")
|
||||||
created_at: datetime = Field(..., description="The valid time of the edge from system perspective.")
|
created_at: datetime = Field(..., description="The valid time of the edge from system perspective.")
|
||||||
expired_at: Optional[datetime] = Field(None, description="The expired time of the edge from system perspective.")
|
expired_at: Optional[datetime] = Field(None, description="The expired time of the edge from system perspective.")
|
||||||
@@ -185,18 +181,14 @@ class Node(BaseModel):
|
|||||||
Attributes:
|
Attributes:
|
||||||
id: Unique identifier for the node
|
id: Unique identifier for the node
|
||||||
name: Name of the node
|
name: Name of the node
|
||||||
group_id: Group ID for multi-tenancy
|
end_user_id: End user ID for multi-tenancy
|
||||||
user_id: User ID for user-specific data
|
|
||||||
apply_id: Application ID for application-specific data
|
|
||||||
run_id: Unique identifier for the pipeline run that created this node
|
run_id: Unique identifier for the pipeline run that created this node
|
||||||
created_at: Timestamp when the node was created (system perspective)
|
created_at: Timestamp when the node was created (system perspective)
|
||||||
expired_at: Optional timestamp when the node expires (system perspective)
|
expired_at: Optional timestamp when the node expires (system perspective)
|
||||||
"""
|
"""
|
||||||
id: str = Field(..., description="The unique identifier for the node.")
|
id: str = Field(..., description="The unique identifier for the node.")
|
||||||
name: str = Field(..., description="The name of the node.")
|
name: str = Field(..., description="The name of the node.")
|
||||||
group_id: str = Field(..., description="The group ID of the node.")
|
end_user_id: str = Field(..., description="The end user ID of the node.")
|
||||||
user_id: str = Field(..., description="The user ID of the edge.")
|
|
||||||
apply_id: str = Field(..., description="The apply ID of the edge.")
|
|
||||||
run_id: str = Field(default_factory=lambda: uuid4().hex, description="Unique identifier for this pipeline run.")
|
run_id: str = Field(default_factory=lambda: uuid4().hex, description="Unique identifier for this pipeline run.")
|
||||||
created_at: datetime = Field(..., description="The valid time of the node from system perspective.")
|
created_at: datetime = Field(..., description="The valid time of the node from system perspective.")
|
||||||
expired_at: Optional[datetime] = Field(None, description="The expired time of the node from system perspective.")
|
expired_at: Optional[datetime] = Field(None, description="The expired time of the node from system perspective.")
|
||||||
@@ -224,6 +216,7 @@ class StatementNode(Node):
|
|||||||
chunk_id: ID of the parent chunk this statement belongs to
|
chunk_id: ID of the parent chunk this statement belongs to
|
||||||
stmt_type: Type of the statement (from ontology)
|
stmt_type: Type of the statement (from ontology)
|
||||||
statement: The actual statement text content
|
statement: The actual statement text content
|
||||||
|
speaker: Optional speaker identifier ('用户' for user messages, 'AI' for AI responses)
|
||||||
emotion_intensity: Optional emotion intensity (0.0-1.0) - displayed on node
|
emotion_intensity: Optional emotion intensity (0.0-1.0) - displayed on node
|
||||||
emotion_target: Optional emotion target (person or object name)
|
emotion_target: Optional emotion target (person or object name)
|
||||||
emotion_subject: Optional emotion subject (self/other/object)
|
emotion_subject: Optional emotion subject (self/other/object)
|
||||||
@@ -249,6 +242,12 @@ class StatementNode(Node):
|
|||||||
stmt_type: str = Field(..., description="Type of the statement")
|
stmt_type: str = Field(..., description="Type of the statement")
|
||||||
statement: str = Field(..., description="The statement text content")
|
statement: str = Field(..., description="The statement text content")
|
||||||
|
|
||||||
|
# Speaker identification
|
||||||
|
speaker: Optional[str] = Field(
|
||||||
|
None,
|
||||||
|
description="Speaker identifier: 'user' for user messages, 'assistant' for AI responses"
|
||||||
|
)
|
||||||
|
|
||||||
# Emotion fields (ordered as requested, emotion_intensity first for display)
|
# Emotion fields (ordered as requested, emotion_intensity first for display)
|
||||||
emotion_intensity: Optional[float] = Field(
|
emotion_intensity: Optional[float] = Field(
|
||||||
None,
|
None,
|
||||||
|
|||||||
@@ -25,10 +25,10 @@ class ConversationMessage(BaseModel):
|
|||||||
"""Represents a single message in a conversation.
|
"""Represents a single message in a conversation.
|
||||||
|
|
||||||
Attributes:
|
Attributes:
|
||||||
role: Role of the speaker (e.g., '用户' for user, 'AI' for assistant)
|
role: Role of the speaker (e.g., 'user' for user, 'assistant' for AI assistant)
|
||||||
msg: Text content of the message
|
msg: Text content of the message
|
||||||
"""
|
"""
|
||||||
role: str = Field(..., description="The role of the speaker (e.g., '用户', 'AI').")
|
role: str = Field(..., description="The role of the speaker (e.g., 'user', 'assistant').")
|
||||||
msg: str = Field(..., description="The text content of the message.")
|
msg: str = Field(..., description="The text content of the message.")
|
||||||
|
|
||||||
|
|
||||||
@@ -55,8 +55,9 @@ class Statement(BaseModel):
|
|||||||
Attributes:
|
Attributes:
|
||||||
id: Unique identifier for the statement
|
id: Unique identifier for the statement
|
||||||
chunk_id: ID of the parent chunk this statement belongs to
|
chunk_id: ID of the parent chunk this statement belongs to
|
||||||
group_id: Optional group ID for multi-tenancy
|
end_user_id: Optional group ID for multi-tenancy
|
||||||
statement: The actual statement text content
|
statement: The actual statement text content
|
||||||
|
speaker: Optional speaker identifier ('用户' for user, 'AI' for AI responses)
|
||||||
statement_embedding: Optional embedding vector for the statement
|
statement_embedding: Optional embedding vector for the statement
|
||||||
stmt_type: Type of the statement (from ontology)
|
stmt_type: Type of the statement (from ontology)
|
||||||
temporal_info: Temporal information extracted from the statement
|
temporal_info: Temporal information extracted from the statement
|
||||||
@@ -72,8 +73,9 @@ class Statement(BaseModel):
|
|||||||
"""
|
"""
|
||||||
id: str = Field(default_factory=lambda: uuid4().hex, description="A unique identifier for the statement.")
|
id: str = Field(default_factory=lambda: uuid4().hex, description="A unique identifier for the statement.")
|
||||||
chunk_id: str = Field(..., description="ID of the parent chunk this statement belongs to.")
|
chunk_id: str = Field(..., description="ID of the parent chunk this statement belongs to.")
|
||||||
group_id: Optional[str] = Field(None, description="ID of the group this statement belongs to.")
|
end_user_id: Optional[str] = Field(None, description="ID of the group this statement belongs to.")
|
||||||
statement: str = Field(..., description="The text content of the statement.")
|
statement: str = Field(..., description="The text content of the statement.")
|
||||||
|
speaker: Optional[str] = Field(None, description="Speaker identifier: 'user' for user messages, 'assistant' for AI responses")
|
||||||
statement_embedding: Optional[List[float]] = Field(None, description="The embedding vector of the statement.")
|
statement_embedding: Optional[List[float]] = Field(None, description="The embedding vector of the statement.")
|
||||||
stmt_type: StatementType = Field(..., description="The type of the statement.")
|
stmt_type: StatementType = Field(..., description="The type of the statement.")
|
||||||
temporal_info: TemporalInfo = Field(..., description="The temporal information of the statement.")
|
temporal_info: TemporalInfo = Field(..., description="The temporal information of the statement.")
|
||||||
@@ -118,35 +120,35 @@ class Chunk(BaseModel):
|
|||||||
|
|
||||||
Attributes:
|
Attributes:
|
||||||
id: Unique identifier for the chunk
|
id: Unique identifier for the chunk
|
||||||
text: List of messages in the chunk
|
|
||||||
content: The content of the chunk as a formatted string
|
content: The content of the chunk as a formatted string
|
||||||
|
speaker: The speaker/role for this chunk (user/assistant)
|
||||||
statements: List of statements extracted from this chunk
|
statements: List of statements extracted from this chunk
|
||||||
chunk_embedding: Optional embedding vector for the chunk
|
chunk_embedding: Optional embedding vector for the chunk
|
||||||
metadata: Additional metadata as key-value pairs
|
metadata: Additional metadata as key-value pairs
|
||||||
"""
|
"""
|
||||||
id: str = Field(default_factory=lambda: uuid4().hex, description="A unique identifier for the chunk.")
|
id: str = Field(default_factory=lambda: uuid4().hex, description="A unique identifier for the chunk.")
|
||||||
text: List[ConversationMessage] = Field(default_factory=list, description="A list of messages in the chunk.")
|
|
||||||
content: str = Field(..., description="The content of the chunk as a string.")
|
content: str = Field(..., description="The content of the chunk as a string.")
|
||||||
|
speaker: Optional[str] = Field(None, description="The speaker/role for this chunk (user/assistant).")
|
||||||
statements: List[Statement] = Field(default_factory=list, description="A list of statements in the chunk.")
|
statements: List[Statement] = Field(default_factory=list, description="A list of statements in the chunk.")
|
||||||
chunk_embedding: Optional[List[float]] = Field(None, description="The embedding vector of the chunk.")
|
chunk_embedding: Optional[List[float]] = Field(None, description="The embedding vector of the chunk.")
|
||||||
metadata: Dict[str, Any] = Field(default_factory=dict, description="Additional metadata for the chunk.")
|
metadata: Dict[str, Any] = Field(default_factory=dict, description="Additional metadata for the chunk.")
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_messages(cls, messages: List[ConversationMessage], metadata: Optional[Dict[str, Any]] = None):
|
def from_single_message(cls, message: ConversationMessage, metadata: Optional[Dict[str, Any]] = None):
|
||||||
"""Create a chunk from a list of messages.
|
"""Create a chunk from a single message (1 Message = 1 Chunk).
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
messages: List of conversation messages
|
message: Single conversation message
|
||||||
metadata: Optional metadata dictionary
|
metadata: Optional metadata dictionary
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Chunk instance with formatted content
|
Chunk instance with speaker directly from message.role
|
||||||
"""
|
"""
|
||||||
if metadata is None:
|
return cls(
|
||||||
metadata = {}
|
content=f"{message.role}: {message.msg}",
|
||||||
# Generate content from messages
|
speaker=message.role,
|
||||||
content = "\n".join([f"{msg.role}: {msg.msg}" for msg in messages])
|
metadata=metadata or {}
|
||||||
return cls(text=messages, content=content, metadata=metadata)
|
)
|
||||||
|
|
||||||
|
|
||||||
class DialogData(BaseModel):
|
class DialogData(BaseModel):
|
||||||
@@ -157,9 +159,7 @@ class DialogData(BaseModel):
|
|||||||
context: Full conversation context
|
context: Full conversation context
|
||||||
dialog_embedding: Optional embedding vector for the entire dialog
|
dialog_embedding: Optional embedding vector for the entire dialog
|
||||||
ref_id: Reference ID linking to external dialog system
|
ref_id: Reference ID linking to external dialog system
|
||||||
group_id: Group ID for multi-tenancy
|
end_user_id: End user ID for multi-tenancy
|
||||||
user_id: User ID for user-specific data
|
|
||||||
apply_id: Application ID for application-specific data
|
|
||||||
created_at: Timestamp when the dialog was created
|
created_at: Timestamp when the dialog was created
|
||||||
expired_at: Timestamp when the dialog expires (default: far future)
|
expired_at: Timestamp when the dialog expires (default: far future)
|
||||||
metadata: Additional metadata as key-value pairs
|
metadata: Additional metadata as key-value pairs
|
||||||
@@ -173,9 +173,7 @@ class DialogData(BaseModel):
|
|||||||
context: ConversationContext = Field(..., description="The full conversation context as a single string.")
|
context: ConversationContext = Field(..., description="The full conversation context as a single string.")
|
||||||
dialog_embedding: Optional[List[float]] = Field(None, description="The embedding vector of the dialog.")
|
dialog_embedding: Optional[List[float]] = Field(None, description="The embedding vector of the dialog.")
|
||||||
ref_id: str = Field(..., description="Refer to external dialog id. This is used to link to the original dialog.")
|
ref_id: str = Field(..., description="Refer to external dialog id. This is used to link to the original dialog.")
|
||||||
group_id: str = Field(default=..., description="Group ID of dialogue data")
|
end_user_id: str = Field(default=..., description="End user ID of dialogue data")
|
||||||
user_id: str = Field(..., description="USER ID of dialogue data")
|
|
||||||
apply_id: str = Field(..., description="APPLY ID of dialogue data")
|
|
||||||
run_id: str = Field(default_factory=lambda: uuid4().hex, description="Unique identifier for this pipeline run.")
|
run_id: str = Field(default_factory=lambda: uuid4().hex, description="Unique identifier for this pipeline run.")
|
||||||
created_at: datetime = Field(default_factory=datetime.now, description="The timestamp when the dialog was created.")
|
created_at: datetime = Field(default_factory=datetime.now, description="The timestamp when the dialog was created.")
|
||||||
expired_at: datetime = Field(default_factory=lambda: datetime(9999, 12, 31), description="The timestamp when the dialog expires.")
|
expired_at: datetime = Field(default_factory=lambda: datetime(9999, 12, 31), description="The timestamp when the dialog expires.")
|
||||||
@@ -254,5 +252,5 @@ class DialogData(BaseModel):
|
|||||||
"""
|
"""
|
||||||
for chunk in self.chunks:
|
for chunk in self.chunks:
|
||||||
for statement in chunk.statements:
|
for statement in chunk.statements:
|
||||||
if statement.group_id is None:
|
if statement.end_user_id is None:
|
||||||
statement.group_id = self.group_id
|
statement.end_user_id = self.end_user_id
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ import os
|
|||||||
import time
|
import time
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional
|
from typing import TYPE_CHECKING, Any, Dict, List, Optional
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from app.schemas.memory_config_schema import MemoryConfig
|
from app.schemas.memory_config_schema import MemoryConfig
|
||||||
@@ -396,13 +397,13 @@ def rerank_with_activation(
|
|||||||
return reranked
|
return reranked
|
||||||
|
|
||||||
|
|
||||||
def log_search_query(query_text: str, search_type: str, group_id: str | None, limit: int, include: List[str], log_file: str = None):
|
def log_search_query(query_text: str, search_type: str, end_user_id: str | None, limit: int, include: List[str], log_file: str = None):
|
||||||
"""Log search query information using the logger.
|
"""Log search query information using the logger.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
query_text: The search query text
|
query_text: The search query text
|
||||||
search_type: Type of search (keyword, embedding, hybrid)
|
search_type: Type of search (keyword, embedding, hybrid)
|
||||||
group_id: Group identifier for filtering
|
end_user_id: Group identifier for filtering
|
||||||
limit: Maximum number of results
|
limit: Maximum number of results
|
||||||
include: List of result types to include
|
include: List of result types to include
|
||||||
log_file: Deprecated parameter, kept for backward compatibility
|
log_file: Deprecated parameter, kept for backward compatibility
|
||||||
@@ -413,7 +414,7 @@ def log_search_query(query_text: str, search_type: str, group_id: str | None, li
|
|||||||
# Log using the standard logger
|
# Log using the standard logger
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Search query: query='{cleaned_query}', type={search_type}, "
|
f"Search query: query='{cleaned_query}', type={search_type}, "
|
||||||
f"group_id={group_id}, limit={limit}, include={include}"
|
f"end_user_id={end_user_id}, limit={limit}, include={include}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -672,7 +673,7 @@ def apply_reranker_placeholder(
|
|||||||
async def run_hybrid_search(
|
async def run_hybrid_search(
|
||||||
query_text: str,
|
query_text: str,
|
||||||
search_type: str,
|
search_type: str,
|
||||||
group_id: str | None,
|
end_user_id: str | None,
|
||||||
limit: int,
|
limit: int,
|
||||||
include: List[str],
|
include: List[str],
|
||||||
output_path: str | None,
|
output_path: str | None,
|
||||||
@@ -692,6 +693,9 @@ async def run_hybrid_search(
|
|||||||
# Start overall timing
|
# Start overall timing
|
||||||
search_start_time = time.time()
|
search_start_time = time.time()
|
||||||
latency_metrics = {}
|
latency_metrics = {}
|
||||||
|
print(100*'-')
|
||||||
|
print(memory_config)
|
||||||
|
print(100 * '-')
|
||||||
logger.info(f"using embedding_id:{memory_config.embedding_model_id}...")
|
logger.info(f"using embedding_id:{memory_config.embedding_model_id}...")
|
||||||
|
|
||||||
# Clean and normalize the incoming query before use/logging
|
# Clean and normalize the incoming query before use/logging
|
||||||
@@ -715,7 +719,7 @@ async def run_hybrid_search(
|
|||||||
}
|
}
|
||||||
|
|
||||||
# Log the search query
|
# Log the search query
|
||||||
log_search_query(query_text, search_type, group_id, limit, include)
|
log_search_query(query_text, search_type, end_user_id, limit, include)
|
||||||
|
|
||||||
connector = Neo4jConnector()
|
connector = Neo4jConnector()
|
||||||
results = {}
|
results = {}
|
||||||
@@ -732,7 +736,7 @@ async def run_hybrid_search(
|
|||||||
search_graph(
|
search_graph(
|
||||||
connector=connector,
|
connector=connector,
|
||||||
q=query_text,
|
q=query_text,
|
||||||
group_id=group_id,
|
end_user_id=end_user_id,
|
||||||
limit=limit,
|
limit=limit,
|
||||||
include=include
|
include=include
|
||||||
)
|
)
|
||||||
@@ -769,7 +773,7 @@ async def run_hybrid_search(
|
|||||||
connector=connector,
|
connector=connector,
|
||||||
embedder_client=embedder,
|
embedder_client=embedder,
|
||||||
query_text=query_text,
|
query_text=query_text,
|
||||||
group_id=group_id,
|
end_user_id=end_user_id,
|
||||||
limit=limit,
|
limit=limit,
|
||||||
include=include,
|
include=include,
|
||||||
)
|
)
|
||||||
@@ -916,9 +920,7 @@ async def run_hybrid_search(
|
|||||||
|
|
||||||
|
|
||||||
async def search_by_temporal(
|
async def search_by_temporal(
|
||||||
group_id: Optional[str] = "test",
|
end_user_id: Optional[str] = "test",
|
||||||
apply_id: Optional[str] = None,
|
|
||||||
user_id: Optional[str] = None,
|
|
||||||
start_date: Optional[str] = None,
|
start_date: Optional[str] = None,
|
||||||
end_date: Optional[str] = None,
|
end_date: Optional[str] = None,
|
||||||
valid_date: Optional[str] = None,
|
valid_date: Optional[str] = None,
|
||||||
@@ -929,7 +931,7 @@ async def search_by_temporal(
|
|||||||
Temporal search across Statements.
|
Temporal search across Statements.
|
||||||
|
|
||||||
- Matches statements created between start_date and end_date
|
- Matches statements created between start_date and end_date
|
||||||
- Optionally filters by group_id
|
- Optionally filters by end_user_id
|
||||||
- Returns up to 'limit' statements
|
- Returns up to 'limit' statements
|
||||||
"""
|
"""
|
||||||
connector = Neo4jConnector()
|
connector = Neo4jConnector()
|
||||||
@@ -939,9 +941,7 @@ async def search_by_temporal(
|
|||||||
end_date = normalize_date_safe(end_date)
|
end_date = normalize_date_safe(end_date)
|
||||||
|
|
||||||
params = TemporalSearchParams.model_validate({
|
params = TemporalSearchParams.model_validate({
|
||||||
"group_id": group_id,
|
"end_user_id": end_user_id,
|
||||||
"apply_id": apply_id,
|
|
||||||
"user_id": user_id,
|
|
||||||
"start_date": start_date,
|
"start_date": start_date,
|
||||||
"end_date": end_date,
|
"end_date": end_date,
|
||||||
"valid_date": valid_date,
|
"valid_date": valid_date,
|
||||||
@@ -950,9 +950,7 @@ async def search_by_temporal(
|
|||||||
})
|
})
|
||||||
statements = await search_graph_by_temporal(
|
statements = await search_graph_by_temporal(
|
||||||
connector=connector,
|
connector=connector,
|
||||||
group_id=params.group_id,
|
end_user_id=params.end_user_id,
|
||||||
apply_id=params.apply_id,
|
|
||||||
user_id=params.user_id,
|
|
||||||
start_date=params.start_date,
|
start_date=params.start_date,
|
||||||
end_date=params.end_date,
|
end_date=params.end_date,
|
||||||
valid_date=params.valid_date,
|
valid_date=params.valid_date,
|
||||||
@@ -964,9 +962,7 @@ async def search_by_temporal(
|
|||||||
|
|
||||||
async def search_by_keyword_temporal(
|
async def search_by_keyword_temporal(
|
||||||
query_text: str,
|
query_text: str,
|
||||||
group_id: Optional[str] = "test",
|
end_user_id: Optional[str] = "test",
|
||||||
apply_id: Optional[str] = None,
|
|
||||||
user_id: Optional[str] = None,
|
|
||||||
start_date: Optional[str] = None,
|
start_date: Optional[str] = None,
|
||||||
end_date: Optional[str] = None,
|
end_date: Optional[str] = None,
|
||||||
valid_date: Optional[str] = None,
|
valid_date: Optional[str] = None,
|
||||||
@@ -987,9 +983,7 @@ async def search_by_keyword_temporal(
|
|||||||
invalid_date = normalize_date_safe(invalid_date)
|
invalid_date = normalize_date_safe(invalid_date)
|
||||||
|
|
||||||
params = TemporalSearchParams.model_validate({
|
params = TemporalSearchParams.model_validate({
|
||||||
"group_id": group_id,
|
"end_user_id": end_user_id,
|
||||||
"apply_id": apply_id,
|
|
||||||
"user_id": user_id,
|
|
||||||
"start_date": start_date,
|
"start_date": start_date,
|
||||||
"end_date": end_date,
|
"end_date": end_date,
|
||||||
"valid_date": valid_date,
|
"valid_date": valid_date,
|
||||||
@@ -999,9 +993,7 @@ async def search_by_keyword_temporal(
|
|||||||
statements = await search_graph_by_keyword_temporal(
|
statements = await search_graph_by_keyword_temporal(
|
||||||
connector=connector,
|
connector=connector,
|
||||||
query_text=query_text,
|
query_text=query_text,
|
||||||
group_id=params.group_id,
|
end_user_id=params.end_user_id,
|
||||||
apply_id=params.apply_id,
|
|
||||||
user_id=params.user_id,
|
|
||||||
start_date=params.start_date,
|
start_date=params.start_date,
|
||||||
end_date=params.end_date,
|
end_date=params.end_date,
|
||||||
valid_date=params.valid_date,
|
valid_date=params.valid_date,
|
||||||
@@ -1013,7 +1005,7 @@ async def search_by_keyword_temporal(
|
|||||||
|
|
||||||
async def search_chunk_by_chunk_id(
|
async def search_chunk_by_chunk_id(
|
||||||
chunk_id: str,
|
chunk_id: str,
|
||||||
group_id: Optional[str] = "test",
|
end_user_id: Optional[str] = "test",
|
||||||
limit: int = 1,
|
limit: int = 1,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
@@ -1023,8 +1015,68 @@ async def search_chunk_by_chunk_id(
|
|||||||
chunks = await search_graph_by_chunk_id(
|
chunks = await search_graph_by_chunk_id(
|
||||||
connector=connector,
|
connector=connector,
|
||||||
chunk_id=chunk_id,
|
chunk_id=chunk_id,
|
||||||
group_id=group_id,
|
end_user_id=end_user_id,
|
||||||
limit=limit
|
limit=limit
|
||||||
)
|
)
|
||||||
return {"chunks": chunks}
|
return {"chunks": chunks}
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
# 测试混合检索功能
|
||||||
|
from app.schemas.memory_config_schema import MemoryConfig
|
||||||
|
from app.db import get_db
|
||||||
|
from app.services.memory_config_service import MemoryConfigService
|
||||||
|
|
||||||
|
# 从数据库获取真实配置
|
||||||
|
db = next(get_db())
|
||||||
|
try:
|
||||||
|
config_service = MemoryConfigService(db)
|
||||||
|
|
||||||
|
# 使用 config_id=17 获取配置
|
||||||
|
memory_config = config_service.load_memory_config(config_id=17)
|
||||||
|
|
||||||
|
if not memory_config:
|
||||||
|
print("错误:找不到 config_id=17 的配置")
|
||||||
|
print("请先在数据库中创建配置,或修改 config_id")
|
||||||
|
exit(1)
|
||||||
|
|
||||||
|
print(f"✓ 成功加载配置: {memory_config.config_name}")
|
||||||
|
print(f" - Workspace: {memory_config.workspace_name}")
|
||||||
|
print(f" - LLM Model: {memory_config.llm_model_name}")
|
||||||
|
print(f" - Embedding Model: {memory_config.embedding_model_name}")
|
||||||
|
print(f" - Storage Type: {memory_config.storage_type}")
|
||||||
|
print()
|
||||||
|
|
||||||
|
# 修改这里的参数进行测试
|
||||||
|
test_end_user_id = "021886bc-fab9-4fd5-b607-497b262e0381" # 修改为你的 end_user_id
|
||||||
|
test_query = "小明擅长什么?" # 修改为你的查询
|
||||||
|
|
||||||
|
print(f"开始测试检索...")
|
||||||
|
print(f" - Query: {test_query}")
|
||||||
|
print(f" - End User ID: {test_end_user_id}")
|
||||||
|
print(f" - Search Type: hybrid")
|
||||||
|
print()
|
||||||
|
|
||||||
|
results = asyncio.run(run_hybrid_search(
|
||||||
|
query_text=test_query,
|
||||||
|
search_type="hybrid", # 可选: "keyword", "embedding", "hybrid"
|
||||||
|
end_user_id=test_end_user_id,
|
||||||
|
limit=10,
|
||||||
|
include=["statements", "entities", "chunks", "summaries"],
|
||||||
|
output_path=None,
|
||||||
|
memory_config=memory_config,
|
||||||
|
rerank_alpha=0.6,
|
||||||
|
use_forgetting_rerank=False,
|
||||||
|
use_llm_rerank=False
|
||||||
|
))
|
||||||
|
|
||||||
|
print("=" * 80)
|
||||||
|
print("检索结果:")
|
||||||
|
print("=" * 80)
|
||||||
|
print(results)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"错误: {e}")
|
||||||
|
import traceback
|
||||||
|
traceback.print_exc()
|
||||||
|
finally:
|
||||||
|
db.close()
|
||||||
|
|||||||
@@ -555,8 +555,8 @@ class DataPreprocessor:
|
|||||||
dialog_id = item.get('dialog_id', item.get('ref_id', item.get('id', f'dialog_{i}')))
|
dialog_id = item.get('dialog_id', item.get('ref_id', item.get('id', f'dialog_{i}')))
|
||||||
|
|
||||||
|
|
||||||
# 获取group_id,如果不存在则生成默认值
|
# 获取end_user_id,如果不存在则生成默认值
|
||||||
group_id = item.get('group_id', f'group_default_{i}')
|
end_user_id = item.get('end_user_id', f'group_default_{i}')
|
||||||
user_id = item.get('user_id', f'user_default_{i}')
|
user_id = item.get('user_id', f'user_default_{i}')
|
||||||
apply_id = item.get('apply_id', f'apply_default_{i}')
|
apply_id = item.get('apply_id', f'apply_default_{i}')
|
||||||
|
|
||||||
@@ -574,7 +574,7 @@ class DataPreprocessor:
|
|||||||
dialog_data = DialogData(
|
dialog_data = DialogData(
|
||||||
context=context,
|
context=context,
|
||||||
ref_id=dialog_id,
|
ref_id=dialog_id,
|
||||||
group_id=group_id,
|
end_user_id=end_user_id,
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
apply_id=apply_id,
|
apply_id=apply_id,
|
||||||
metadata=metadata
|
metadata=metadata
|
||||||
@@ -644,7 +644,7 @@ class DataPreprocessor:
|
|||||||
|
|
||||||
context = ConversationContext(msgs=messages)
|
context = ConversationContext(msgs=messages)
|
||||||
dialog_id = item.get('dialog_id', item.get('ref_id', item.get('id', f'dialog_{i}')))
|
dialog_id = item.get('dialog_id', item.get('ref_id', item.get('id', f'dialog_{i}')))
|
||||||
group_id = item.get('group_id', f'group_default_{i}')
|
end_user_id = item.get('end_user_id', f'group_default_{i}')
|
||||||
user_id = item.get('user_id', f'user_default_{i}')
|
user_id = item.get('user_id', f'user_default_{i}')
|
||||||
apply_id = item.get('apply_id', f'apply_default_{i}')
|
apply_id = item.get('apply_id', f'apply_default_{i}')
|
||||||
|
|
||||||
@@ -657,7 +657,7 @@ class DataPreprocessor:
|
|||||||
dialog_data = DialogData(
|
dialog_data = DialogData(
|
||||||
context=context,
|
context=context,
|
||||||
ref_id=dialog_id,
|
ref_id=dialog_id,
|
||||||
group_id=group_id,
|
end_user_id=end_user_id,
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
apply_id=apply_id,
|
apply_id=apply_id,
|
||||||
metadata=metadata
|
metadata=metadata
|
||||||
|
|||||||
@@ -199,7 +199,7 @@ def accurate_match(
|
|||||||
entity_nodes: List[ExtractedEntityNode]
|
entity_nodes: List[ExtractedEntityNode]
|
||||||
) -> Tuple[List[ExtractedEntityNode], Dict[str, str], Dict[str, Dict]]:
|
) -> Tuple[List[ExtractedEntityNode], Dict[str, str], Dict[str, Dict]]:
|
||||||
"""
|
"""
|
||||||
精确匹配:按 (group_id, name, entity_type) 合并实体并建立重定向与合并记录。
|
精确匹配:按 (end_user_id, name, entity_type) 合并实体并建立重定向与合并记录。
|
||||||
返回: (deduped_entities, id_redirect, exact_merge_map)
|
返回: (deduped_entities, id_redirect, exact_merge_map)
|
||||||
"""
|
"""
|
||||||
exact_merge_map: Dict[str, Dict] = {}
|
exact_merge_map: Dict[str, Dict] = {}
|
||||||
@@ -210,8 +210,8 @@ def accurate_match(
|
|||||||
for ent in entity_nodes:
|
for ent in entity_nodes:
|
||||||
name_norm = (getattr(ent, "name", "") or "").strip()
|
name_norm = (getattr(ent, "name", "") or "").strip()
|
||||||
type_norm = (getattr(ent, "entity_type", "") or "").strip()
|
type_norm = (getattr(ent, "entity_type", "") or "").strip()
|
||||||
key = f"{getattr(ent, 'group_id', None)}|{name_norm}|{type_norm}"
|
key = f"{getattr(ent, 'end_user_id', None)}|{name_norm}|{type_norm}"
|
||||||
# 为避免跨业务组误并,明确以 group_id 为范围边界
|
# 为避免跨业务组误并,明确以 end_user_id 为范围边界
|
||||||
if key not in canonical_map:
|
if key not in canonical_map:
|
||||||
canonical_map[key] = ent
|
canonical_map[key] = ent
|
||||||
id_redirect[ent.id] = ent.id
|
id_redirect[ent.id] = ent.id
|
||||||
@@ -223,11 +223,11 @@ def accurate_match(
|
|||||||
id_redirect[ent.id] = canonical.id
|
id_redirect[ent.id] = canonical.id
|
||||||
# 记录精确匹配的合并项(使用规范化键,避免外层变量误用)
|
# 记录精确匹配的合并项(使用规范化键,避免外层变量误用)
|
||||||
try:
|
try:
|
||||||
k = f"{canonical.group_id}|{(canonical.name or '').strip()}|{(canonical.entity_type or '').strip()}"
|
k = f"{canonical.end_user_id}|{(canonical.name or '').strip()}|{(canonical.entity_type or '').strip()}"
|
||||||
if k not in exact_merge_map:
|
if k not in exact_merge_map:
|
||||||
exact_merge_map[k] = {
|
exact_merge_map[k] = {
|
||||||
"canonical_id": canonical.id,
|
"canonical_id": canonical.id,
|
||||||
"group_id": canonical.group_id,
|
"end_user_id": canonical.end_user_id,
|
||||||
"name": canonical.name,
|
"name": canonical.name,
|
||||||
"entity_type": canonical.entity_type,
|
"entity_type": canonical.entity_type,
|
||||||
"merged_ids": set(),
|
"merged_ids": set(),
|
||||||
@@ -596,7 +596,7 @@ def fuzzy_match(
|
|||||||
b = deduped_entities[j]
|
b = deduped_entities[j]
|
||||||
|
|
||||||
# 跳过不同业务组的实体
|
# 跳过不同业务组的实体
|
||||||
if getattr(a, "group_id", None) != getattr(b, "group_id", None):
|
if getattr(a, "end_user_id", None) != getattr(b, "end_user_id", None):
|
||||||
j += 1
|
j += 1
|
||||||
continue
|
continue
|
||||||
|
|
||||||
@@ -671,7 +671,7 @@ def fuzzy_match(
|
|||||||
merge_reason = "[别名匹配]" if alias_match_merge else "[模糊]"
|
merge_reason = "[别名匹配]" if alias_match_merge else "[模糊]"
|
||||||
merge_reason = "[别名匹配]" if alias_match_merge else "[模糊]"
|
merge_reason = "[别名匹配]" if alias_match_merge else "[模糊]"
|
||||||
fuzzy_merge_records.append(
|
fuzzy_merge_records.append(
|
||||||
f"{merge_reason} 规范实体 {a.id} ({a.group_id}|{a.name}|{a.entity_type}) <- 合并实体 {b.id} ({b.group_id}|{b.name}|{b.entity_type}) | "
|
f"{merge_reason} 规范实体 {a.id} ({a.end_user_id}|{a.name}|{a.entity_type}) <- 合并实体 {b.id} ({b.end_user_id}|{b.name}|{b.entity_type}) | "
|
||||||
f"s_name={s_name:.3f}, s_type={s_type:.3f}, overall={overall:.3f}, exact_alias={has_exact_match}"
|
f"s_name={s_name:.3f}, s_type={s_type:.3f}, overall={overall:.3f}, exact_alias={has_exact_match}"
|
||||||
)
|
)
|
||||||
except Exception:
|
except Exception:
|
||||||
@@ -779,7 +779,7 @@ async def LLM_decision( # 决策中包含去重和消歧的功能
|
|||||||
# 记录 LLM 融合日志
|
# 记录 LLM 融合日志
|
||||||
try:
|
try:
|
||||||
llm_records.append(
|
llm_records.append(
|
||||||
f"[LLM融合] 规范实体 {a.id} ({a.group_id}|{a.name}|{a.entity_type}) <- 合并实体 {b.id} ({b.group_id}|{b.name}|{b.entity_type})"
|
f"[LLM融合] 规范实体 {a.id} ({a.end_user_id}|{a.name}|{a.entity_type}) <- 合并实体 {b.id} ({b.end_user_id}|{b.name}|{b.entity_type})"
|
||||||
)
|
)
|
||||||
# 详细的“同类名称相似”记录改由 LLM 去重模块统一生成以携带 conf/reason
|
# 详细的“同类名称相似”记录改由 LLM 去重模块统一生成以携带 conf/reason
|
||||||
except Exception:
|
except Exception:
|
||||||
@@ -847,7 +847,7 @@ async def LLM_disamb_decision(
|
|||||||
id_redirect[k] = a.id
|
id_redirect[k] = a.id
|
||||||
try:
|
try:
|
||||||
disamb_records.append(
|
disamb_records.append(
|
||||||
f"[DISAMB合并应用] 规范实体 {a.id} ({a.group_id}|{a.name}|{a.entity_type}) <- 合并实体 {b.id} ({b.group_id}|{b.name}|{b.entity_type})"
|
f"[DISAMB合并应用] 规范实体 {a.id} ({a.end_user_id}|{a.name}|{a.entity_type}) <- 合并实体 {b.id} ({b.end_user_id}|{b.name}|{b.entity_type})"
|
||||||
)
|
)
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|||||||
@@ -174,7 +174,7 @@ async def _judge_pair(
|
|||||||
pass
|
pass
|
||||||
# 3. 构建LLM判断的“上下文信息”(规则层计算的所有特征) 判断上下文特征有助于实体消歧首先判断的类型关系
|
# 3. 构建LLM判断的“上下文信息”(规则层计算的所有特征) 判断上下文特征有助于实体消歧首先判断的类型关系
|
||||||
ctx = {
|
ctx = {
|
||||||
"same_group": getattr(a, "group_id", None) == getattr(b, "group_id", None),
|
"same_group": getattr(a, "end_user_id", None) == getattr(b, "end_user_id", None),
|
||||||
"type_ok": _simple_type_ok(getattr(a, "entity_type", None), getattr(b, "entity_type", None)),
|
"type_ok": _simple_type_ok(getattr(a, "entity_type", None), getattr(b, "entity_type", None)),
|
||||||
"type_similarity": _type_similarity(getattr(a, "entity_type", None), getattr(b, "entity_type", None)),
|
"type_similarity": _type_similarity(getattr(a, "entity_type", None), getattr(b, "entity_type", None)),
|
||||||
"name_text_sim": name_text_sim,
|
"name_text_sim": name_text_sim,
|
||||||
@@ -235,7 +235,7 @@ async def _judge_pair_disamb(
|
|||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
ctx = {
|
ctx = {
|
||||||
"same_group": getattr(a, "group_id", None) == getattr(b, "group_id", None),
|
"same_group": getattr(a, "end_user_id", None) == getattr(b, "end_user_id", None),
|
||||||
"type_ok": _simple_type_ok(getattr(a, "entity_type", None), getattr(b, "entity_type", None)),
|
"type_ok": _simple_type_ok(getattr(a, "entity_type", None), getattr(b, "entity_type", None)),
|
||||||
"name_text_sim": name_text_sim,
|
"name_text_sim": name_text_sim,
|
||||||
"name_embed_sim": name_embed_sim,
|
"name_embed_sim": name_embed_sim,
|
||||||
@@ -317,8 +317,8 @@ async def llm_dedup_entities( # 保留对偶判断作为子流程,是为了
|
|||||||
a = entity_nodes[i]
|
a = entity_nodes[i]
|
||||||
for j in range(i + 1, len(entity_nodes)):
|
for j in range(i + 1, len(entity_nodes)):
|
||||||
b = entity_nodes[j]
|
b = entity_nodes[j]
|
||||||
# 规则1:必须属于同一组(group_id相同,不同组的实体不重复)
|
# 规则1:必须属于同一组(end_user_id相同,不同组的实体不重复)
|
||||||
if getattr(a, "group_id", None) != getattr(b, "group_id", None):
|
if getattr(a, "end_user_id", None) != getattr(b, "end_user_id", None):
|
||||||
continue
|
continue
|
||||||
# 规则2:类型必须兼容(调用_simple_type_ok判断)
|
# 规则2:类型必须兼容(调用_simple_type_ok判断)
|
||||||
if not _simple_type_ok(getattr(a, "entity_type", None), getattr(b, "entity_type", None)):
|
if not _simple_type_ok(getattr(a, "entity_type", None), getattr(b, "entity_type", None)):
|
||||||
@@ -474,7 +474,7 @@ async def llm_dedup_entities_iterative_blocks( # 迭代分块并发 LLM 去重
|
|||||||
- max_rounds: upper bound for iterative passes (default 3)
|
- max_rounds: upper bound for iterative passes (default 3)
|
||||||
- auto_merge_threshold: decision confidence for auto-merge when no co-occurrence (default 0.90)
|
- auto_merge_threshold: decision confidence for auto-merge when no co-occurrence (default 0.90)
|
||||||
- co_ctx_threshold: lower threshold when co-occurrence is detected (default 0.83)
|
- co_ctx_threshold: lower threshold when co-occurrence is detected (default 0.83)
|
||||||
- shuffle_each_round: whether to shuffle entities within group_id each round to vary block composition
|
- shuffle_each_round: whether to shuffle entities within end_user_id each round to vary block composition
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
- global_redirect: dict losing_id -> canonical_id accumulated across rounds
|
- global_redirect: dict losing_id -> canonical_id accumulated across rounds
|
||||||
@@ -509,7 +509,7 @@ async def llm_dedup_entities_iterative_blocks( # 迭代分块并发 LLM 去重
|
|||||||
|
|
||||||
def _partition_blocks(nodes: List[ExtractedEntityNode]) -> List[List[ExtractedEntityNode]]:
|
def _partition_blocks(nodes: List[ExtractedEntityNode]) -> List[List[ExtractedEntityNode]]:
|
||||||
"""
|
"""
|
||||||
按 group_id 分块,避免跨组实体在同一块,减少无效候选对
|
按 end_user_id 分块,避免跨组实体在同一块,减少无效候选对
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
nodes: 实体节点列表
|
nodes: 实体节点列表
|
||||||
@@ -519,7 +519,7 @@ async def llm_dedup_entities_iterative_blocks( # 迭代分块并发 LLM 去重
|
|||||||
"""
|
"""
|
||||||
groups: Dict[str, List[ExtractedEntityNode]] = {}
|
groups: Dict[str, List[ExtractedEntityNode]] = {}
|
||||||
for e in nodes:
|
for e in nodes:
|
||||||
gid = getattr(e, "group_id", None)
|
gid = getattr(e, "end_user_id", None)
|
||||||
groups.setdefault(str(gid), []).append(e)
|
groups.setdefault(str(gid), []).append(e)
|
||||||
blocks: List[List[ExtractedEntityNode]] = []
|
blocks: List[List[ExtractedEntityNode]] = []
|
||||||
for gid, arr in groups.items():
|
for gid, arr in groups.items():
|
||||||
@@ -559,7 +559,7 @@ async def llm_dedup_entities_iterative_blocks( # 迭代分块并发 LLM 去重
|
|||||||
# Collapse nodes to canonical reps before each round to avoid redundant comparisons
|
# Collapse nodes to canonical reps before each round to avoid redundant comparisons
|
||||||
# 步骤1:折叠实体(合并已确定的重复实体,减少后续计算量)
|
# 步骤1:折叠实体(合并已确定的重复实体,减少后续计算量)
|
||||||
current_nodes = _collapse_nodes(current_nodes)
|
current_nodes = _collapse_nodes(current_nodes)
|
||||||
# 步骤2:分块(按group_id分块,避免跨组处理)
|
# 步骤2:分块(按end_user_id分块,避免跨组处理)
|
||||||
blocks = _partition_blocks(current_nodes)
|
blocks = _partition_blocks(current_nodes)
|
||||||
if not blocks: # 无块可处理(实体已全部折叠),退出循环
|
if not blocks: # 无块可处理(实体已全部折叠),退出循环
|
||||||
break
|
break
|
||||||
@@ -645,7 +645,7 @@ async def llm_disambiguate_pairs_iterative(
|
|||||||
a = entity_nodes[i]
|
a = entity_nodes[i]
|
||||||
b = entity_nodes[j]
|
b = entity_nodes[j]
|
||||||
# 必须同组
|
# 必须同组
|
||||||
if getattr(a, "group_id", None) != getattr(b, "group_id", None):
|
if getattr(a, "end_user_id", None) != getattr(b, "end_user_id", None):
|
||||||
continue
|
continue
|
||||||
ta = getattr(a, "entity_type", None)
|
ta = getattr(a, "entity_type", None)
|
||||||
tb = getattr(b, "entity_type", None)
|
tb = getattr(b, "entity_type", None)
|
||||||
|
|||||||
@@ -61,7 +61,7 @@ def _row_to_entity(row: Dict[str, Any]) -> ExtractedEntityNode:
|
|||||||
return ExtractedEntityNode(
|
return ExtractedEntityNode(
|
||||||
id=row.get("id"),
|
id=row.get("id"),
|
||||||
name=row.get("name") or "",
|
name=row.get("name") or "",
|
||||||
group_id=row.get("group_id") or "",
|
end_user_id=row.get("end_user_id") or "",
|
||||||
user_id=row.get("user_id") or "",
|
user_id=row.get("user_id") or "",
|
||||||
apply_id=row.get("apply_id") or "",
|
apply_id=row.get("apply_id") or "",
|
||||||
created_at=_parse_dt(row.get("created_at")),
|
created_at=_parse_dt(row.get("created_at")),
|
||||||
@@ -79,7 +79,7 @@ def _row_to_entity(row: Dict[str, Any]) -> ExtractedEntityNode:
|
|||||||
|
|
||||||
async def second_layer_dedup_and_merge_with_neo4j( # 二层去重的核心逻辑,与 Neo4j 中同组实体联合去重
|
async def second_layer_dedup_and_merge_with_neo4j( # 二层去重的核心逻辑,与 Neo4j 中同组实体联合去重
|
||||||
connector: Neo4jConnector,
|
connector: Neo4jConnector,
|
||||||
group_id: str, # 用于定位neo4j中同一组的实体,确保只在同组内去重
|
end_user_id: str, # 用于定位neo4j中同一组的实体,确保只在同组内去重
|
||||||
entity_nodes: List[ExtractedEntityNode], # 输入的实体节点列表,包含待去重的实体
|
entity_nodes: List[ExtractedEntityNode], # 输入的实体节点列表,包含待去重的实体
|
||||||
statement_entity_edges: List[StatementEntityEdge], # 输入的语句实体边列表,用于处理实体之间的关系
|
statement_entity_edges: List[StatementEntityEdge], # 输入的语句实体边列表,用于处理实体之间的关系
|
||||||
entity_entity_edges: List[EntityEntityEdge], # 输入的实体实体边列表,用于处理实体之间的关系
|
entity_entity_edges: List[EntityEntityEdge], # 输入的实体实体边列表,用于处理实体之间的关系
|
||||||
@@ -88,7 +88,7 @@ async def second_layer_dedup_and_merge_with_neo4j( # 二层去重的核心逻辑
|
|||||||
) -> Tuple[List[ExtractedEntityNode], List[StatementEntityEdge], List[EntityEntityEdge]]:
|
) -> Tuple[List[ExtractedEntityNode], List[StatementEntityEdge], List[EntityEntityEdge]]:
|
||||||
"""
|
"""
|
||||||
第二层去重消歧:
|
第二层去重消歧:
|
||||||
- 以第一层结果为索引,检索相同 group_id 下的 DB 候选实体
|
- 以第一层结果为索引,检索相同 end_user_id 下的 DB 候选实体
|
||||||
- 将 DB 候选与当前实体集合联合,按既有精确/模糊/LLM 决策进行融合
|
- 将 DB 候选与当前实体集合联合,按既有精确/模糊/LLM 决策进行融合
|
||||||
- 返回融合后的实体与重定向后的边(边已指向规范 ID,优先 DB ID)
|
- 返回融合后的实体与重定向后的边(边已指向规范 ID,优先 DB ID)
|
||||||
"""
|
"""
|
||||||
@@ -102,7 +102,7 @@ async def second_layer_dedup_and_merge_with_neo4j( # 二层去重的核心逻辑
|
|||||||
|
|
||||||
]
|
]
|
||||||
candidates_map = await get_dedup_candidates_for_entities( # 从 Neo4j 中查询候选实体,并将结果赋值给candidates_map(等待异步操作完成)。
|
candidates_map = await get_dedup_candidates_for_entities( # 从 Neo4j 中查询候选实体,并将结果赋值给candidates_map(等待异步操作完成)。
|
||||||
connector=connector, group_id=group_id,
|
connector=connector, end_user_id=end_user_id,
|
||||||
entities=incoming_rows, # 传入参数:第一层实体的核心信息(作为查询索引)
|
entities=incoming_rows, # 传入参数:第一层实体的核心信息(作为查询索引)
|
||||||
use_contains_fallback=True # 传入参数:启用 “包含关系” 作为匹配失败的降级策略(若精确匹配无结果,用包含关系召回候选),与src\database\cypher_queries.py的307产生联动
|
use_contains_fallback=True # 传入参数:启用 “包含关系” 作为匹配失败的降级策略(若精确匹配无结果,用包含关系召回候选),与src\database\cypher_queries.py的307产生联动
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -57,11 +57,11 @@ async def dedup_layers_and_merge_and_return(
|
|||||||
if pipeline_config is None:
|
if pipeline_config is None:
|
||||||
raise ValueError("pipeline_config is required for dedup_layers_and_merge_and_return")
|
raise ValueError("pipeline_config is required for dedup_layers_and_merge_and_return")
|
||||||
|
|
||||||
# 先探测 group_id,决定报告写入策略
|
# 先探测 end_user_id,决定报告写入策略
|
||||||
group_id: Optional[str] = None
|
end_user_id: Optional[str] = None
|
||||||
for dd in dialog_data_list:
|
for dd in dialog_data_list:
|
||||||
group_id = getattr(dd, "group_id", None)
|
end_user_id = getattr(dd, "end_user_id", None)
|
||||||
if group_id:
|
if end_user_id:
|
||||||
break
|
break
|
||||||
|
|
||||||
# 第一层去重消歧
|
# 第一层去重消歧
|
||||||
@@ -82,11 +82,11 @@ async def dedup_layers_and_merge_and_return(
|
|||||||
|
|
||||||
# 第二层去重消歧:与 Neo4j 中同组实体联合融合
|
# 第二层去重消歧:与 Neo4j 中同组实体联合融合
|
||||||
try:
|
try:
|
||||||
if group_id:
|
if end_user_id:
|
||||||
if connector:
|
if connector:
|
||||||
fused_entity_nodes, fused_statement_entity_edges, fused_entity_entity_edges = await second_layer_dedup_and_merge_with_neo4j(
|
fused_entity_nodes, fused_statement_entity_edges, fused_entity_entity_edges = await second_layer_dedup_and_merge_with_neo4j(
|
||||||
connector=connector,
|
connector=connector,
|
||||||
group_id=group_id,
|
end_user_id=end_user_id,
|
||||||
entity_nodes=dedup_entity_nodes,
|
entity_nodes=dedup_entity_nodes,
|
||||||
statement_entity_edges=dedup_statement_entity_edges,
|
statement_entity_edges=dedup_statement_entity_edges,
|
||||||
entity_entity_edges=dedup_entity_entity_edges,
|
entity_entity_edges=dedup_entity_entity_edges,
|
||||||
@@ -96,7 +96,7 @@ async def dedup_layers_and_merge_and_return(
|
|||||||
else:
|
else:
|
||||||
print("Skip second-layer dedup: missing connector")
|
print("Skip second-layer dedup: missing connector")
|
||||||
else:
|
else:
|
||||||
print("Skip second-layer dedup: missing group_id")
|
print("Skip second-layer dedup: missing end_user_id")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Second-layer dedup failed: {e}")
|
print(f"Second-layer dedup failed: {e}")
|
||||||
|
|
||||||
|
|||||||
@@ -287,7 +287,7 @@ class ExtractionOrchestrator:
|
|||||||
for d_idx, dialog in enumerate(dialog_data_list):
|
for d_idx, dialog in enumerate(dialog_data_list):
|
||||||
dialogue_content = dialog.content if self.config.statement_extraction.include_dialogue_context else None
|
dialogue_content = dialog.content if self.config.statement_extraction.include_dialogue_context else None
|
||||||
for c_idx, chunk in enumerate(dialog.chunks):
|
for c_idx, chunk in enumerate(dialog.chunks):
|
||||||
all_chunks.append((chunk, dialog.group_id, dialogue_content))
|
all_chunks.append((chunk, dialog.end_user_id, dialogue_content))
|
||||||
chunk_metadata.append((d_idx, c_idx))
|
chunk_metadata.append((d_idx, c_idx))
|
||||||
|
|
||||||
logger.info(f"收集到 {len(all_chunks)} 个分块,开始全局并行提取")
|
logger.info(f"收集到 {len(all_chunks)} 个分块,开始全局并行提取")
|
||||||
@@ -299,9 +299,9 @@ class ExtractionOrchestrator:
|
|||||||
# 全局并行处理所有分块
|
# 全局并行处理所有分块
|
||||||
async def extract_for_chunk(chunk_data, chunk_index):
|
async def extract_for_chunk(chunk_data, chunk_index):
|
||||||
nonlocal completed_chunks
|
nonlocal completed_chunks
|
||||||
chunk, group_id, dialogue_content = chunk_data
|
chunk, end_user_id, dialogue_content = chunk_data
|
||||||
try:
|
try:
|
||||||
statements = await self.statement_extractor._extract_statements(chunk, group_id, dialogue_content)
|
statements = await self.statement_extractor._extract_statements(chunk, end_user_id, dialogue_content)
|
||||||
|
|
||||||
# 流式输出:每提取完一个分块的陈述句,立即发送进度
|
# 流式输出:每提取完一个分块的陈述句,立即发送进度
|
||||||
# 注意:只在试运行模式下发送陈述句详情,正式模式不发送
|
# 注意:只在试运行模式下发送陈述句详情,正式模式不发送
|
||||||
@@ -550,7 +550,7 @@ class ExtractionOrchestrator:
|
|||||||
self, dialog_data_list: List[DialogData]
|
self, dialog_data_list: List[DialogData]
|
||||||
) -> List[Dict[str, Any]]:
|
) -> List[Dict[str, Any]]:
|
||||||
"""
|
"""
|
||||||
从对话中提取情绪信息(优化版:全局陈述句级并行)
|
从对话中提取情绪信息(仅针对用户消息,全局陈述句级并行)
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
dialog_data_list: 对话数据列表
|
dialog_data_list: 对话数据列表
|
||||||
@@ -558,7 +558,7 @@ class ExtractionOrchestrator:
|
|||||||
Returns:
|
Returns:
|
||||||
情绪信息映射列表,每个对话对应一个字典
|
情绪信息映射列表,每个对话对应一个字典
|
||||||
"""
|
"""
|
||||||
logger.info("开始情绪信息提取(全局陈述句级并行)")
|
logger.info("开始情绪信息提取(仅处理用户消息)")
|
||||||
|
|
||||||
# 收集所有陈述句及其配置
|
# 收集所有陈述句及其配置
|
||||||
all_statements = []
|
all_statements = []
|
||||||
@@ -598,14 +598,21 @@ class ExtractionOrchestrator:
|
|||||||
logger.info("情绪提取未启用,跳过")
|
logger.info("情绪提取未启用,跳过")
|
||||||
return [{} for _ in dialog_data_list]
|
return [{} for _ in dialog_data_list]
|
||||||
|
|
||||||
# 收集所有陈述句
|
# 收集所有陈述句(只收集 speaker 为 "user" 的)
|
||||||
|
total_statements = 0
|
||||||
|
filtered_statements = 0
|
||||||
|
|
||||||
for d_idx, dialog in enumerate(dialog_data_list):
|
for d_idx, dialog in enumerate(dialog_data_list):
|
||||||
for chunk in dialog.chunks:
|
for chunk in dialog.chunks:
|
||||||
for statement in chunk.statements:
|
for statement in chunk.statements:
|
||||||
|
total_statements += 1
|
||||||
|
# 只处理用户的陈述句 (role 为 "user")
|
||||||
|
if hasattr(statement, 'speaker') and statement.speaker == "user":
|
||||||
all_statements.append((statement, data_config))
|
all_statements.append((statement, data_config))
|
||||||
statement_metadata.append((d_idx, statement.id))
|
statement_metadata.append((d_idx, statement.id))
|
||||||
|
filtered_statements += 1
|
||||||
|
|
||||||
logger.info(f"收集到 {len(all_statements)} 个陈述句,开始全局并行提取情绪")
|
logger.info(f"总陈述句: {total_statements}, 用户陈述句: {filtered_statements}, 开始全局并行提取情绪")
|
||||||
|
|
||||||
# 初始化情绪提取服务
|
# 初始化情绪提取服务
|
||||||
from app.services.emotion_extraction_service import EmotionExtractionService
|
from app.services.emotion_extraction_service import EmotionExtractionService
|
||||||
@@ -985,9 +992,7 @@ class ExtractionOrchestrator:
|
|||||||
id=dialog_data.id,
|
id=dialog_data.id,
|
||||||
name=f"Dialog_{dialog_data.id}", # 添加必需的 name 字段
|
name=f"Dialog_{dialog_data.id}", # 添加必需的 name 字段
|
||||||
ref_id=dialog_data.ref_id,
|
ref_id=dialog_data.ref_id,
|
||||||
group_id=dialog_data.group_id,
|
end_user_id=dialog_data.end_user_id,
|
||||||
user_id=dialog_data.user_id,
|
|
||||||
apply_id=dialog_data.apply_id,
|
|
||||||
run_id=dialog_data.run_id, # 使用 dialog_data 的 run_id
|
run_id=dialog_data.run_id, # 使用 dialog_data 的 run_id
|
||||||
content=dialog_data.context.content if dialog_data.context else "",
|
content=dialog_data.context.content if dialog_data.context else "",
|
||||||
dialog_embedding=dialog_data.dialog_embedding if hasattr(dialog_data, 'dialog_embedding') else None,
|
dialog_embedding=dialog_data.dialog_embedding if hasattr(dialog_data, 'dialog_embedding') else None,
|
||||||
@@ -1005,9 +1010,7 @@ class ExtractionOrchestrator:
|
|||||||
id=chunk.id,
|
id=chunk.id,
|
||||||
name=f"Chunk_{chunk.id}", # 添加必需的 name 字段
|
name=f"Chunk_{chunk.id}", # 添加必需的 name 字段
|
||||||
dialog_id=dialog_data.id,
|
dialog_id=dialog_data.id,
|
||||||
group_id=dialog_data.group_id,
|
end_user_id=dialog_data.end_user_id,
|
||||||
user_id=dialog_data.user_id,
|
|
||||||
apply_id=dialog_data.apply_id,
|
|
||||||
run_id=dialog_data.run_id, # 使用 dialog_data 的 run_id
|
run_id=dialog_data.run_id, # 使用 dialog_data 的 run_id
|
||||||
content=chunk.content,
|
content=chunk.content,
|
||||||
chunk_embedding=chunk.chunk_embedding,
|
chunk_embedding=chunk.chunk_embedding,
|
||||||
@@ -1028,11 +1031,10 @@ class ExtractionOrchestrator:
|
|||||||
stmt_type=getattr(statement, 'stmt_type', 'general'), # 添加必需的 stmt_type 字段
|
stmt_type=getattr(statement, 'stmt_type', 'general'), # 添加必需的 stmt_type 字段
|
||||||
temporal_info=getattr(statement, 'temporal_info', TemporalInfo.ATEMPORAL), # 添加必需的 temporal_info 字段
|
temporal_info=getattr(statement, 'temporal_info', TemporalInfo.ATEMPORAL), # 添加必需的 temporal_info 字段
|
||||||
connect_strength=statement.connect_strength if statement.connect_strength is not None else 'Strong', # 添加必需的 connect_strength 字段
|
connect_strength=statement.connect_strength if statement.connect_strength is not None else 'Strong', # 添加必需的 connect_strength 字段
|
||||||
group_id=dialog_data.group_id,
|
end_user_id=dialog_data.end_user_id,
|
||||||
user_id=dialog_data.user_id,
|
|
||||||
apply_id=dialog_data.apply_id,
|
|
||||||
run_id=dialog_data.run_id, # 使用 dialog_data 的 run_id
|
run_id=dialog_data.run_id, # 使用 dialog_data 的 run_id
|
||||||
statement=statement.statement,
|
statement=statement.statement,
|
||||||
|
speaker=getattr(statement, 'speaker', None), # 添加 speaker 字段
|
||||||
statement_embedding=statement.statement_embedding,
|
statement_embedding=statement.statement_embedding,
|
||||||
valid_at=statement.temporal_validity.valid_at if hasattr(statement, 'temporal_validity') and statement.temporal_validity else None,
|
valid_at=statement.temporal_validity.valid_at if hasattr(statement, 'temporal_validity') and statement.temporal_validity else None,
|
||||||
invalid_at=statement.temporal_validity.invalid_at if hasattr(statement, 'temporal_validity') and statement.temporal_validity else None,
|
invalid_at=statement.temporal_validity.invalid_at if hasattr(statement, 'temporal_validity') and statement.temporal_validity else None,
|
||||||
@@ -1052,9 +1054,7 @@ class ExtractionOrchestrator:
|
|||||||
statement_chunk_edge = StatementChunkEdge(
|
statement_chunk_edge = StatementChunkEdge(
|
||||||
source=statement.id,
|
source=statement.id,
|
||||||
target=chunk.id,
|
target=chunk.id,
|
||||||
group_id=dialog_data.group_id,
|
end_user_id=dialog_data.end_user_id,
|
||||||
user_id=dialog_data.user_id,
|
|
||||||
apply_id=dialog_data.apply_id,
|
|
||||||
run_id=dialog_data.run_id, # 使用 dialog_data 的 run_id
|
run_id=dialog_data.run_id, # 使用 dialog_data 的 run_id
|
||||||
created_at=dialog_data.created_at,
|
created_at=dialog_data.created_at,
|
||||||
)
|
)
|
||||||
@@ -1087,9 +1087,7 @@ class ExtractionOrchestrator:
|
|||||||
aliases=getattr(entity, 'aliases', []) or [], # 传递从三元组提取阶段获取的aliases
|
aliases=getattr(entity, 'aliases', []) or [], # 传递从三元组提取阶段获取的aliases
|
||||||
name_embedding=getattr(entity, 'name_embedding', None),
|
name_embedding=getattr(entity, 'name_embedding', None),
|
||||||
is_explicit_memory=getattr(entity, 'is_explicit_memory', False), # 新增:传递语义记忆标记
|
is_explicit_memory=getattr(entity, 'is_explicit_memory', False), # 新增:传递语义记忆标记
|
||||||
group_id=dialog_data.group_id,
|
end_user_id=dialog_data.end_user_id,
|
||||||
user_id=dialog_data.user_id,
|
|
||||||
apply_id=dialog_data.apply_id,
|
|
||||||
run_id=dialog_data.run_id, # 使用 dialog_data 的 run_id
|
run_id=dialog_data.run_id, # 使用 dialog_data 的 run_id
|
||||||
created_at=dialog_data.created_at,
|
created_at=dialog_data.created_at,
|
||||||
expired_at=dialog_data.expired_at,
|
expired_at=dialog_data.expired_at,
|
||||||
@@ -1104,9 +1102,7 @@ class ExtractionOrchestrator:
|
|||||||
source=statement.id,
|
source=statement.id,
|
||||||
target=entity.id,
|
target=entity.id,
|
||||||
connect_strength=entity_connect_strength if entity_connect_strength is not None else 'Strong',
|
connect_strength=entity_connect_strength if entity_connect_strength is not None else 'Strong',
|
||||||
group_id=dialog_data.group_id,
|
end_user_id=dialog_data.end_user_id,
|
||||||
user_id=dialog_data.user_id,
|
|
||||||
apply_id=dialog_data.apply_id,
|
|
||||||
run_id=dialog_data.run_id, # 使用 dialog_data 的 run_id
|
run_id=dialog_data.run_id, # 使用 dialog_data 的 run_id
|
||||||
created_at=dialog_data.created_at,
|
created_at=dialog_data.created_at,
|
||||||
)
|
)
|
||||||
@@ -1126,9 +1122,7 @@ class ExtractionOrchestrator:
|
|||||||
relation_type=triplet.predicate,
|
relation_type=triplet.predicate,
|
||||||
statement=statement.statement,
|
statement=statement.statement,
|
||||||
source_statement_id=statement.id,
|
source_statement_id=statement.id,
|
||||||
group_id=dialog_data.group_id,
|
end_user_id=dialog_data.end_user_id,
|
||||||
user_id=dialog_data.user_id,
|
|
||||||
apply_id=dialog_data.apply_id,
|
|
||||||
run_id=dialog_data.run_id, # 使用 dialog_data 的 run_id
|
run_id=dialog_data.run_id, # 使用 dialog_data 的 run_id
|
||||||
created_at=dialog_data.created_at,
|
created_at=dialog_data.created_at,
|
||||||
expired_at=dialog_data.expired_at,
|
expired_at=dialog_data.expired_at,
|
||||||
@@ -1755,14 +1749,14 @@ class ExtractionOrchestrator:
|
|||||||
|
|
||||||
async def get_chunked_dialogs(
|
async def get_chunked_dialogs(
|
||||||
chunker_strategy: str = "RecursiveChunker",
|
chunker_strategy: str = "RecursiveChunker",
|
||||||
group_id: str = "group_1",
|
end_user_id: str = "group_1",
|
||||||
indices: Optional[List[int]] = None,
|
indices: Optional[List[int]] = None,
|
||||||
) -> List[DialogData]:
|
) -> List[DialogData]:
|
||||||
"""从测试数据生成分块对话
|
"""从测试数据生成分块对话
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
chunker_strategy: 分块策略(默认: RecursiveChunker)
|
chunker_strategy: 分块策略(默认: RecursiveChunker)
|
||||||
group_id: 组ID
|
end_user_id: 组ID
|
||||||
indices: 要处理的数据索引列表(可选)
|
indices: 要处理的数据索引列表(可选)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@@ -1826,7 +1820,7 @@ async def get_chunked_dialogs(
|
|||||||
dialog_data = DialogData(
|
dialog_data = DialogData(
|
||||||
context=conversation_context,
|
context=conversation_context,
|
||||||
ref_id=data['id'],
|
ref_id=data['id'],
|
||||||
group_id=group_id,
|
end_user_id=end_user_id,
|
||||||
metadata=dialog_metadata,
|
metadata=dialog_metadata,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -1928,7 +1922,7 @@ async def get_chunked_dialogs_from_preprocessed(
|
|||||||
|
|
||||||
async def get_chunked_dialogs_with_preprocessing(
|
async def get_chunked_dialogs_with_preprocessing(
|
||||||
chunker_strategy: str = "RecursiveChunker",
|
chunker_strategy: str = "RecursiveChunker",
|
||||||
group_id: str = "default",
|
end_user_id: str = "default",
|
||||||
user_id: str = "default",
|
user_id: str = "default",
|
||||||
apply_id: str = "default",
|
apply_id: str = "default",
|
||||||
indices: Optional[List[int]] = None,
|
indices: Optional[List[int]] = None,
|
||||||
@@ -1940,7 +1934,7 @@ async def get_chunked_dialogs_with_preprocessing(
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
chunker_strategy: 分块策略
|
chunker_strategy: 分块策略
|
||||||
group_id: 组ID
|
end_user_id: 组ID
|
||||||
user_id: 用户ID
|
user_id: 用户ID
|
||||||
apply_id: 应用ID
|
apply_id: 应用ID
|
||||||
indices: 要处理的数据索引列表
|
indices: 要处理的数据索引列表
|
||||||
@@ -1968,11 +1962,9 @@ async def get_chunked_dialogs_with_preprocessing(
|
|||||||
indices=indices,
|
indices=indices,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 设置 group_id, user_id, apply_id
|
# 设置 end_user_id
|
||||||
for dd in preprocessed_data:
|
for dd in preprocessed_data:
|
||||||
dd.group_id = group_id
|
dd.end_user_id = end_user_id
|
||||||
dd.user_id = user_id
|
|
||||||
dd.apply_id = apply_id
|
|
||||||
|
|
||||||
# 步骤2: 语义剪枝
|
# 步骤2: 语义剪枝
|
||||||
try:
|
try:
|
||||||
|
|||||||
@@ -22,12 +22,12 @@ class DialogueChunker:
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
chunker_strategy: The chunking strategy to use (default: RecursiveChunker)
|
chunker_strategy: The chunking strategy to use (default: RecursiveChunker)
|
||||||
Options include: SemanticChunker, RecursiveChunker, LateChunker, NeuralChunker
|
Options: SemanticChunker, RecursiveChunker, LateChunker, NeuralChunker
|
||||||
"""
|
"""
|
||||||
self.chunker_strategy = chunker_strategy
|
self.chunker_strategy = chunker_strategy
|
||||||
chunker_config_dict = get_chunker_config(chunker_strategy)
|
chunker_config_dict = get_chunker_config(chunker_strategy)
|
||||||
self.chunker_config = ChunkerConfig.model_validate(chunker_config_dict)
|
self.chunker_config = ChunkerConfig.model_validate(chunker_config_dict)
|
||||||
# 对于 LLMChunker,需要传入 llm_client
|
|
||||||
if self.chunker_config.chunker_strategy == "LLMChunker":
|
if self.chunker_config.chunker_strategy == "LLMChunker":
|
||||||
self.chunker_client = ChunkerClient(self.chunker_config, llm_client)
|
self.chunker_client = ChunkerClient(self.chunker_config, llm_client)
|
||||||
else:
|
else:
|
||||||
@@ -41,29 +41,19 @@ class DialogueChunker:
|
|||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A list of Chunk objects
|
A list of Chunk objects
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If chunking fails or returns empty chunks
|
||||||
"""
|
"""
|
||||||
result_dialogue = await self.chunker_client.generate_chunks(dialogue)
|
result_dialogue = await self.chunker_client.generate_chunks(dialogue)
|
||||||
# Defensive fallback: ensure at least one chunk is returned for non-empty content
|
|
||||||
try:
|
|
||||||
chunks = result_dialogue.chunks
|
chunks = result_dialogue.chunks
|
||||||
except Exception:
|
|
||||||
chunks = []
|
|
||||||
|
|
||||||
if not chunks or len(chunks) == 0:
|
if not chunks or len(chunks) == 0:
|
||||||
# If the dialogue has content, return a single fallback chunk built from messages
|
raise ValueError(
|
||||||
content_str = getattr(result_dialogue, "content", "") or getattr(dialogue, "content", "")
|
f"Chunking failed: No chunks generated for dialogue {dialogue.ref_id}. "
|
||||||
if content_str and len(content_str.strip()) > 0:
|
f"Messages: {len(dialogue.context.msgs) if dialogue.context else 0}, "
|
||||||
fallback_chunk = Chunk.from_messages(
|
f"Strategy: {self.chunker_config.chunker_strategy}"
|
||||||
dialogue.context.msgs,
|
|
||||||
metadata={
|
|
||||||
"fallback": "single_chunk",
|
|
||||||
"chunker_strategy": self.chunker_config.chunker_strategy,
|
|
||||||
"source": "DialogueChunkerFallback",
|
|
||||||
},
|
|
||||||
)
|
)
|
||||||
return [fallback_chunk]
|
|
||||||
# No content: return empty list
|
|
||||||
return []
|
|
||||||
|
|
||||||
return chunks
|
return chunks
|
||||||
|
|
||||||
@@ -72,22 +62,25 @@ class DialogueChunker:
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
dialogue: The processed DialogData object with chunks
|
dialogue: The processed DialogData object with chunks
|
||||||
output_path: Optional path to save the output (default: chunker_output_{strategy}.txt)
|
output_path: Optional path to save the output
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The path where the output was saved
|
The path where the output was saved
|
||||||
"""
|
"""
|
||||||
if not output_path:
|
if not output_path:
|
||||||
output_path = os.path.join(os.path.dirname(__file__), "..", "..",
|
output_path = os.path.join(
|
||||||
f"chunker_output_{self.chunker_strategy.lower()}.txt")
|
os.path.dirname(__file__), "..", "..",
|
||||||
|
f"chunker_output_{self.chunker_strategy.lower()}.txt"
|
||||||
|
)
|
||||||
|
|
||||||
output_lines = []
|
output_lines = [
|
||||||
output_lines.append(f"=== Chunking Results ({self.chunker_strategy}) ===")
|
f"=== Chunking Results ({self.chunker_strategy}) ===",
|
||||||
output_lines.append(f"Dialogue ID: {dialogue.ref_id}")
|
f"Dialogue ID: {dialogue.ref_id}",
|
||||||
output_lines.append(f"Original conversation has {len(dialogue.context.msgs)} messages")
|
f"Original conversation has {len(dialogue.context.msgs)} messages",
|
||||||
output_lines.append(f"Total characters: {len(dialogue.content)}")
|
f"Total characters: {len(dialogue.content)}",
|
||||||
|
f"Generated {len(dialogue.chunks)} chunks:"
|
||||||
|
]
|
||||||
|
|
||||||
output_lines.append(f"Generated {len(dialogue.chunks)} chunks:")
|
|
||||||
for i, chunk in enumerate(dialogue.chunks):
|
for i, chunk in enumerate(dialogue.chunks):
|
||||||
output_lines.append(f" Chunk {i+1}: {len(chunk.content)} characters")
|
output_lines.append(f" Chunk {i+1}: {len(chunk.content)} characters")
|
||||||
output_lines.append(f" Content preview: {chunk.content}...")
|
output_lines.append(f" Content preview: {chunk.content}...")
|
||||||
|
|||||||
@@ -193,9 +193,9 @@ async def _process_chunk_summary(
|
|||||||
node = MemorySummaryNode(
|
node = MemorySummaryNode(
|
||||||
id=uuid4().hex,
|
id=uuid4().hex,
|
||||||
name=title if title else f"MemorySummaryChunk_{chunk.id}",
|
name=title if title else f"MemorySummaryChunk_{chunk.id}",
|
||||||
group_id=dialog.group_id,
|
end_user_id=dialog.end_user_id,
|
||||||
user_id=dialog.user_id,
|
user_id=dialog.end_user_id,
|
||||||
apply_id=dialog.apply_id,
|
apply_id=dialog.end_user_id,
|
||||||
run_id=dialog.run_id, # 使用 dialog 的 run_id
|
run_id=dialog.run_id, # 使用 dialog 的 run_id
|
||||||
created_at=datetime.now(),
|
created_at=datetime.now(),
|
||||||
expired_at=datetime(9999, 12, 31),
|
expired_at=datetime(9999, 12, 31),
|
||||||
|
|||||||
@@ -5,8 +5,6 @@ from datetime import datetime
|
|||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
from app.core.memory.models.message_models import DialogData, Statement
|
from app.core.memory.models.message_models import DialogData, Statement
|
||||||
|
|
||||||
#避免在测试收集阶段因为 OpenAIClient 间接引入 langfuse 导致 ModuleNotFoundError 。这只是类型注解与导入时机的调整,不改变实现。
|
|
||||||
from app.core.memory.models.variate_config import StatementExtractionConfig
|
from app.core.memory.models.variate_config import StatementExtractionConfig
|
||||||
from app.core.memory.utils.data.ontology import (
|
from app.core.memory.utils.data.ontology import (
|
||||||
LABEL_DEFINITIONS,
|
LABEL_DEFINITIONS,
|
||||||
@@ -22,11 +20,10 @@ logger = logging.getLogger(__name__)
|
|||||||
class ExtractedStatement(BaseModel):
|
class ExtractedStatement(BaseModel):
|
||||||
"""Schema for extracted statement from LLM"""
|
"""Schema for extracted statement from LLM"""
|
||||||
statement: str = Field(..., description="The extracted statement text")
|
statement: str = Field(..., description="The extracted statement text")
|
||||||
statement_type: str = Field(..., description="FACT, OPINION,SUGGESTION or PREDICTION")
|
statement_type: str = Field(..., description="FACT, OPINION, SUGGESTION or PREDICTION")
|
||||||
temporal_type: str = Field(..., description="STATIC, DYNAMIC, ATEMPORAL")
|
temporal_type: str = Field(..., description="STATIC, DYNAMIC, ATEMPORAL")
|
||||||
relevence: str = Field(..., description="RELEVANT or IRRELEVANT")
|
relevence: str = Field(..., description="RELEVANT or IRRELEVANT")
|
||||||
|
|
||||||
# 统一使用 StatementExtractionResponse 作为 LLM 的结构化返回(仅语句)
|
|
||||||
class StatementExtractionResponse(BaseModel):
|
class StatementExtractionResponse(BaseModel):
|
||||||
statements: List[ExtractedStatement] = Field(default_factory=list, description="List of extracted statements")
|
statements: List[ExtractedStatement] = Field(default_factory=list, description="List of extracted statements")
|
||||||
|
|
||||||
@@ -58,10 +55,9 @@ class StatementExtractionResponse(BaseModel):
|
|||||||
return v
|
return v
|
||||||
|
|
||||||
class StatementExtractor:
|
class StatementExtractor:
|
||||||
"""Class for extracting statements from dialog chunks using LLM (relations separated)"""
|
"""Class for extracting statements from dialog chunks using LLM"""
|
||||||
|
|
||||||
def __init__(self, llm_client: Any, config: StatementExtractionConfig = None):
|
def __init__(self, llm_client: Any, config: StatementExtractionConfig = None):
|
||||||
# 避免在测试收集阶段因为 OpenAIClient 间接引入 langfuse 导致 ModuleNotFoundError 。这只是类型注解与导入时机的调整,不改变实现。
|
|
||||||
"""Initialize the StatementExtractor with an LLM client and configuration
|
"""Initialize the StatementExtractor with an LLM client and configuration
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -71,21 +67,38 @@ class StatementExtractor:
|
|||||||
self.llm_client = llm_client
|
self.llm_client = llm_client
|
||||||
self.config = config or StatementExtractionConfig()
|
self.config = config or StatementExtractionConfig()
|
||||||
|
|
||||||
async def _extract_statements(self, chunk, group_id: Optional[str] = None, dialogue_content: str = None) -> List[Statement]:
|
def _get_speaker_from_chunk(self, chunk) -> Optional[str]:
|
||||||
|
"""Get speaker directly from Chunk
|
||||||
|
|
||||||
|
Args:
|
||||||
|
chunk: Chunk object containing speaker field
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Speaker role ("user"/"assistant") or None if cannot be determined
|
||||||
|
"""
|
||||||
|
if hasattr(chunk, 'speaker') and chunk.speaker:
|
||||||
|
return chunk.speaker
|
||||||
|
|
||||||
|
logger.warning(f"Chunk {getattr(chunk, 'id', 'unknown')} has no speaker field or is empty")
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def _extract_statements(self, chunk, end_user_id: Optional[str] = None, dialogue_content: str = None) -> List[Statement]:
|
||||||
"""Process a single chunk and return extracted statements
|
"""Process a single chunk and return extracted statements
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
chunk: Chunk object to process
|
chunk: Chunk object to process
|
||||||
group_id: Group ID to assign to all statements in this chunk
|
end_user_id: Group ID to assign to all statements in this chunk
|
||||||
dialogue_content: Full dialogue content to provide as context
|
dialogue_content: Full dialogue content to provide as context
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List of ExtractedStatement objects extracted from the chunk
|
List of ExtractedStatement objects extracted from the chunk
|
||||||
"""
|
"""
|
||||||
# Prepare the chunk content for processing
|
|
||||||
chunk_content = chunk.content
|
chunk_content = chunk.content
|
||||||
|
|
||||||
# Render the prompt using helper function
|
if not chunk_content or len(chunk_content.strip()) < 5:
|
||||||
|
logger.warning(f"Chunk {chunk.id} content too short or empty, skipping")
|
||||||
|
return []
|
||||||
|
|
||||||
prompt_content = await render_statement_extraction_prompt(
|
prompt_content = await render_statement_extraction_prompt(
|
||||||
chunk_content=chunk_content,
|
chunk_content=chunk_content,
|
||||||
definitions=LABEL_DEFINITIONS,
|
definitions=LABEL_DEFINITIONS,
|
||||||
@@ -137,14 +150,18 @@ class StatementExtractor:
|
|||||||
except (KeyError, ValueError):
|
except (KeyError, ValueError):
|
||||||
relevence_info = RelevenceInfo.RELEVANT
|
relevence_info = RelevenceInfo.RELEVANT
|
||||||
|
|
||||||
|
chunk_speaker = self._get_speaker_from_chunk(chunk)
|
||||||
|
|
||||||
chunk_statement = Statement(
|
chunk_statement = Statement(
|
||||||
statement=extracted_stmt.statement,
|
statement=extracted_stmt.statement,
|
||||||
stmt_type=stmt_type,
|
stmt_type=stmt_type,
|
||||||
temporal_info=temporal_type,
|
temporal_info=temporal_type,
|
||||||
relevence_info=relevence_info,
|
relevence_info=relevence_info,
|
||||||
chunk_id=chunk.id,
|
chunk_id=chunk.id,
|
||||||
group_id=group_id,
|
end_user_id=end_user_id,
|
||||||
|
speaker=chunk_speaker,
|
||||||
)
|
)
|
||||||
|
|
||||||
chunk_statements.append(chunk_statement)
|
chunk_statements.append(chunk_statement)
|
||||||
|
|
||||||
# 分离强弱关系分类:不在句子提取阶段进行,也不写入 chunk.metadata
|
# 分离强弱关系分类:不在句子提取阶段进行,也不写入 chunk.metadata
|
||||||
@@ -167,10 +184,10 @@ class StatementExtractor:
|
|||||||
|
|
||||||
logger.info(f"Processing {len(chunks_to_process)} chunks for statement extraction")
|
logger.info(f"Processing {len(chunks_to_process)} chunks for statement extraction")
|
||||||
|
|
||||||
# Process all chunks concurrently, passing the group_id and dialogue content from dialog_data
|
# Process all chunks concurrently, passing the end_user_id and dialogue content from dialog_data
|
||||||
dialogue_content = dialog_data.content if self.config.include_dialogue_context else None
|
dialogue_content = dialog_data.content if self.config.include_dialogue_context else None
|
||||||
results = await asyncio.gather(
|
results = await asyncio.gather(
|
||||||
*[self._extract_statements(chunk, dialog_data.group_id, dialogue_content) for chunk in chunks_to_process],
|
*[self._extract_statements(chunk, dialog_data.end_user_id, dialogue_content) for chunk in chunks_to_process],
|
||||||
return_exceptions=True
|
return_exceptions=True
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -208,7 +225,7 @@ class StatementExtractor:
|
|||||||
for i, statement in enumerate(statements, 1):
|
for i, statement in enumerate(statements, 1):
|
||||||
f.write(f"Statement {i}:\n")
|
f.write(f"Statement {i}:\n")
|
||||||
f.write(f"Id: {statement.id}\n")
|
f.write(f"Id: {statement.id}\n")
|
||||||
f.write(f"Group Id: {statement.group_id}\n")
|
f.write(f"Group Id: {statement.end_user_id}\n")
|
||||||
f.write(f"Content: {statement.statement}\n")
|
f.write(f"Content: {statement.statement}\n")
|
||||||
f.write(f"Type: {statement.stmt_type.value}\n")
|
f.write(f"Type: {statement.stmt_type.value}\n")
|
||||||
f.write(f"Temporal Info: {statement.temporal_info.value}\n")
|
f.write(f"Temporal Info: {statement.temporal_info.value}\n")
|
||||||
@@ -226,12 +243,7 @@ class StatementExtractor:
|
|||||||
return output_path
|
return output_path
|
||||||
|
|
||||||
def save_relations(self, dialogs: List[DialogData], output_path: str = None) -> str:
|
def save_relations(self, dialogs: List[DialogData], output_path: str = None) -> str:
|
||||||
"""按对话分组聚合强/弱关系并写入 TXT 文件。
|
"""Group and aggregate strong/weak relations by dialogue and write to TXT file."""
|
||||||
- 每个对话单独成段:输出该对话的 `Dialog ID`、`Group ID`、`Content`
|
|
||||||
- 在该对话段内再分为 Strong Relations / Weak Relations 两部分
|
|
||||||
- Strong: 逐条输出 `Chunk ID` 与 `Triple`
|
|
||||||
- Weak: 逐条输出 `Chunk ID` 与 `Entity`
|
|
||||||
"""
|
|
||||||
print("\n=== Relations Classify ===")
|
print("\n=== Relations Classify ===")
|
||||||
|
|
||||||
# 使用全局配置的输出路径
|
# 使用全局配置的输出路径
|
||||||
@@ -286,7 +298,7 @@ class StatementExtractor:
|
|||||||
|
|
||||||
dialog_sections.append({
|
dialog_sections.append({
|
||||||
"dialog_id": dialog.ref_id,
|
"dialog_id": dialog.ref_id,
|
||||||
"group_id": dialog.group_id,
|
"end_user_id": dialog.end_user_id,
|
||||||
"content": dialog.content if getattr(dialog, "content", None) else "",
|
"content": dialog.content if getattr(dialog, "content", None) else "",
|
||||||
"strong": strong_relations,
|
"strong": strong_relations,
|
||||||
"weak": weak_relations,
|
"weak": weak_relations,
|
||||||
@@ -300,7 +312,7 @@ class StatementExtractor:
|
|||||||
for idx, section in enumerate(dialog_sections, 1):
|
for idx, section in enumerate(dialog_sections, 1):
|
||||||
f.write(f"Dialog {idx}:\n")
|
f.write(f"Dialog {idx}:\n")
|
||||||
f.write(f"Dialog ID: {section.get('dialog_id', '')}\n")
|
f.write(f"Dialog ID: {section.get('dialog_id', '')}\n")
|
||||||
f.write(f"Group ID: {section.get('group_id', '')}\n")
|
f.write(f"Group ID: {section.get('end_user_id', '')}\n")
|
||||||
f.write("Content:\n")
|
f.write("Content:\n")
|
||||||
f.write(f"{section.get('content', '')}\n")
|
f.write(f"{section.get('content', '')}\n")
|
||||||
f.write("-" * 40 + "\n\n")
|
f.write("-" * 40 + "\n\n")
|
||||||
|
|||||||
@@ -132,7 +132,7 @@ class TemporalExtractor:
|
|||||||
prompt_logger.info("")
|
prompt_logger.info("")
|
||||||
prompt_logger.info("=== TEMPORAL EXTRACTION RESULTS ===")
|
prompt_logger.info("=== TEMPORAL EXTRACTION RESULTS ===")
|
||||||
prompt_logger.info(
|
prompt_logger.info(
|
||||||
f"[Temporal] Dialog ref_id={getattr(dialog_data, 'ref_id', None)}, group_id={getattr(dialog_data, 'group_id', None)}"
|
f"[Temporal] Dialog ref_id={getattr(dialog_data, 'ref_id', None)}, end_user_id={getattr(dialog_data, 'end_user_id', None)}"
|
||||||
)
|
)
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|||||||
@@ -116,7 +116,7 @@ class TripletExtractor:
|
|||||||
logger.info(f"Processing {len(all_statements)} statements for triplet extraction...")
|
logger.info(f"Processing {len(all_statements)} statements for triplet extraction...")
|
||||||
try:
|
try:
|
||||||
prompt_logger.info(
|
prompt_logger.info(
|
||||||
f"[Triplet] Dialog ref_id={getattr(dialog_data, 'ref_id', None)}, group_id={getattr(dialog_data, 'group_id', None)}, statements_to_process={len(all_statements)}"
|
f"[Triplet] Dialog ref_id={getattr(dialog_data, 'ref_id', None)}, end_user_id={getattr(dialog_data, 'end_user_id', None)}, statements_to_process={len(all_statements)}"
|
||||||
)
|
)
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|||||||
@@ -75,7 +75,7 @@ class AccessHistoryManager:
|
|||||||
self,
|
self,
|
||||||
node_id: str,
|
node_id: str,
|
||||||
node_label: str,
|
node_label: str,
|
||||||
group_id: Optional[str] = None,
|
end_user_id: Optional[str] = None,
|
||||||
current_time: Optional[datetime] = None
|
current_time: Optional[datetime] = None
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
@@ -91,7 +91,7 @@ class AccessHistoryManager:
|
|||||||
Args:
|
Args:
|
||||||
node_id: 节点ID
|
node_id: 节点ID
|
||||||
node_label: 节点标签(Statement, ExtractedEntity, MemorySummary)
|
node_label: 节点标签(Statement, ExtractedEntity, MemorySummary)
|
||||||
group_id: 组ID(可选,用于过滤)
|
end_user_id: 组ID(可选,用于过滤)
|
||||||
current_time: 当前时间(可选,默认使用系统时间)
|
current_time: 当前时间(可选,默认使用系统时间)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@@ -123,7 +123,7 @@ class AccessHistoryManager:
|
|||||||
for attempt in range(self.max_retries):
|
for attempt in range(self.max_retries):
|
||||||
try:
|
try:
|
||||||
# 步骤1:读取当前节点状态
|
# 步骤1:读取当前节点状态
|
||||||
node_data = await self._fetch_node(node_id, node_label, group_id)
|
node_data = await self._fetch_node(node_id, node_label, end_user_id)
|
||||||
|
|
||||||
if not node_data:
|
if not node_data:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
@@ -142,7 +142,7 @@ class AccessHistoryManager:
|
|||||||
node_id=node_id,
|
node_id=node_id,
|
||||||
node_label=node_label,
|
node_label=node_label,
|
||||||
update_data=update_data,
|
update_data=update_data,
|
||||||
group_id=group_id
|
end_user_id=end_user_id
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
@@ -172,7 +172,7 @@ class AccessHistoryManager:
|
|||||||
self,
|
self,
|
||||||
node_ids: List[str],
|
node_ids: List[str],
|
||||||
node_label: str,
|
node_label: str,
|
||||||
group_id: Optional[str] = None,
|
end_user_id: Optional[str] = None,
|
||||||
current_time: Optional[datetime] = None
|
current_time: Optional[datetime] = None
|
||||||
) -> List[Dict[str, Any]]:
|
) -> List[Dict[str, Any]]:
|
||||||
"""
|
"""
|
||||||
@@ -184,7 +184,7 @@ class AccessHistoryManager:
|
|||||||
Args:
|
Args:
|
||||||
node_ids: 节点ID列表
|
node_ids: 节点ID列表
|
||||||
node_label: 节点标签(所有节点必须是同一类型)
|
node_label: 节点标签(所有节点必须是同一类型)
|
||||||
group_id: 组ID(可选)
|
end_user_id: 组ID(可选)
|
||||||
current_time: 当前时间(可选)
|
current_time: 当前时间(可选)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@@ -202,7 +202,7 @@ class AccessHistoryManager:
|
|||||||
task = self.record_access(
|
task = self.record_access(
|
||||||
node_id=node_id,
|
node_id=node_id,
|
||||||
node_label=node_label,
|
node_label=node_label,
|
||||||
group_id=group_id,
|
end_user_id=end_user_id,
|
||||||
current_time=current_time
|
current_time=current_time
|
||||||
)
|
)
|
||||||
tasks.append(task)
|
tasks.append(task)
|
||||||
@@ -235,7 +235,7 @@ class AccessHistoryManager:
|
|||||||
self,
|
self,
|
||||||
node_id: str,
|
node_id: str,
|
||||||
node_label: str,
|
node_label: str,
|
||||||
group_id: Optional[str] = None
|
end_user_id: Optional[str] = None
|
||||||
) -> Tuple[ConsistencyCheckResult, Optional[str]]:
|
) -> Tuple[ConsistencyCheckResult, Optional[str]]:
|
||||||
"""
|
"""
|
||||||
检查节点数据的一致性
|
检查节点数据的一致性
|
||||||
@@ -249,14 +249,14 @@ class AccessHistoryManager:
|
|||||||
Args:
|
Args:
|
||||||
node_id: 节点ID
|
node_id: 节点ID
|
||||||
node_label: 节点标签
|
node_label: 节点标签
|
||||||
group_id: 组ID(可选)
|
end_user_id: 组ID(可选)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Tuple[ConsistencyCheckResult, Optional[str]]:
|
Tuple[ConsistencyCheckResult, Optional[str]]:
|
||||||
- 一致性检查结果枚举
|
- 一致性检查结果枚举
|
||||||
- 错误描述(如果不一致)
|
- 错误描述(如果不一致)
|
||||||
"""
|
"""
|
||||||
node_data = await self._fetch_node(node_id, node_label, group_id)
|
node_data = await self._fetch_node(node_id, node_label, end_user_id)
|
||||||
|
|
||||||
if not node_data:
|
if not node_data:
|
||||||
return ConsistencyCheckResult.CONSISTENT, None
|
return ConsistencyCheckResult.CONSISTENT, None
|
||||||
@@ -305,7 +305,7 @@ class AccessHistoryManager:
|
|||||||
async def check_batch_consistency(
|
async def check_batch_consistency(
|
||||||
self,
|
self,
|
||||||
node_label: str,
|
node_label: str,
|
||||||
group_id: Optional[str] = None,
|
end_user_id: Optional[str] = None,
|
||||||
limit: int = 1000
|
limit: int = 1000
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
@@ -313,7 +313,7 @@ class AccessHistoryManager:
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
node_label: 节点标签
|
node_label: 节点标签
|
||||||
group_id: 组ID(可选)
|
end_user_id: 组ID(可选)
|
||||||
limit: 检查的最大节点数
|
limit: 检查的最大节点数
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@@ -329,16 +329,16 @@ class AccessHistoryManager:
|
|||||||
MATCH (n:{node_label})
|
MATCH (n:{node_label})
|
||||||
WHERE n.access_history IS NOT NULL
|
WHERE n.access_history IS NOT NULL
|
||||||
"""
|
"""
|
||||||
if group_id:
|
if end_user_id:
|
||||||
query += " AND n.group_id = $group_id"
|
query += " AND n.end_user_id = $end_user_id"
|
||||||
query += """
|
query += """
|
||||||
RETURN n.id as id
|
RETURN n.id as id
|
||||||
LIMIT $limit
|
LIMIT $limit
|
||||||
"""
|
"""
|
||||||
|
|
||||||
params = {"limit": limit}
|
params = {"limit": limit}
|
||||||
if group_id:
|
if end_user_id:
|
||||||
params["group_id"] = group_id
|
params["end_user_id"] = end_user_id
|
||||||
|
|
||||||
results = await self.connector.execute_query(query, **params)
|
results = await self.connector.execute_query(query, **params)
|
||||||
node_ids = [r['id'] for r in results]
|
node_ids = [r['id'] for r in results]
|
||||||
@@ -351,7 +351,7 @@ class AccessHistoryManager:
|
|||||||
result, message = await self.check_consistency(
|
result, message = await self.check_consistency(
|
||||||
node_id=node_id,
|
node_id=node_id,
|
||||||
node_label=node_label,
|
node_label=node_label,
|
||||||
group_id=group_id
|
end_user_id=end_user_id
|
||||||
)
|
)
|
||||||
|
|
||||||
if result == ConsistencyCheckResult.CONSISTENT:
|
if result == ConsistencyCheckResult.CONSISTENT:
|
||||||
@@ -387,7 +387,7 @@ class AccessHistoryManager:
|
|||||||
self,
|
self,
|
||||||
node_id: str,
|
node_id: str,
|
||||||
node_label: str,
|
node_label: str,
|
||||||
group_id: Optional[str] = None
|
end_user_id: Optional[str] = None
|
||||||
) -> bool:
|
) -> bool:
|
||||||
"""
|
"""
|
||||||
自动修复节点的数据不一致问题
|
自动修复节点的数据不一致问题
|
||||||
@@ -401,7 +401,7 @@ class AccessHistoryManager:
|
|||||||
Args:
|
Args:
|
||||||
node_id: 节点ID
|
node_id: 节点ID
|
||||||
node_label: 节点标签
|
node_label: 节点标签
|
||||||
group_id: 组ID(可选)
|
end_user_id: 组ID(可选)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
bool: 修复成功返回True,否则返回False
|
bool: 修复成功返回True,否则返回False
|
||||||
@@ -411,7 +411,7 @@ class AccessHistoryManager:
|
|||||||
result, message = await self.check_consistency(
|
result, message = await self.check_consistency(
|
||||||
node_id=node_id,
|
node_id=node_id,
|
||||||
node_label=node_label,
|
node_label=node_label,
|
||||||
group_id=group_id
|
end_user_id=end_user_id
|
||||||
)
|
)
|
||||||
|
|
||||||
if result == ConsistencyCheckResult.CONSISTENT:
|
if result == ConsistencyCheckResult.CONSISTENT:
|
||||||
@@ -419,7 +419,7 @@ class AccessHistoryManager:
|
|||||||
return True
|
return True
|
||||||
|
|
||||||
# 获取节点数据
|
# 获取节点数据
|
||||||
node_data = await self._fetch_node(node_id, node_label, group_id)
|
node_data = await self._fetch_node(node_id, node_label, end_user_id)
|
||||||
if not node_data:
|
if not node_data:
|
||||||
logger.error(f"节点不存在,无法修复: {node_label}[{node_id}]")
|
logger.error(f"节点不存在,无法修复: {node_label}[{node_id}]")
|
||||||
return False
|
return False
|
||||||
@@ -457,8 +457,8 @@ class AccessHistoryManager:
|
|||||||
query = f"""
|
query = f"""
|
||||||
MATCH (n:{node_label} {{id: $node_id}})
|
MATCH (n:{node_label} {{id: $node_id}})
|
||||||
"""
|
"""
|
||||||
if group_id:
|
if end_user_id:
|
||||||
query += " WHERE n.group_id = $group_id"
|
query += " WHERE n.end_user_id = $end_user_id"
|
||||||
query += """
|
query += """
|
||||||
SET n += $repair_data
|
SET n += $repair_data
|
||||||
RETURN n
|
RETURN n
|
||||||
@@ -468,8 +468,8 @@ class AccessHistoryManager:
|
|||||||
'node_id': node_id,
|
'node_id': node_id,
|
||||||
'repair_data': repair_data
|
'repair_data': repair_data
|
||||||
}
|
}
|
||||||
if group_id:
|
if end_user_id:
|
||||||
params['group_id'] = group_id
|
params['end_user_id'] = end_user_id
|
||||||
|
|
||||||
await self.connector.execute_query(query, **params)
|
await self.connector.execute_query(query, **params)
|
||||||
|
|
||||||
@@ -491,7 +491,7 @@ class AccessHistoryManager:
|
|||||||
self,
|
self,
|
||||||
node_id: str,
|
node_id: str,
|
||||||
node_label: str,
|
node_label: str,
|
||||||
group_id: Optional[str] = None
|
end_user_id: Optional[str] = None
|
||||||
) -> Optional[Dict[str, Any]]:
|
) -> Optional[Dict[str, Any]]:
|
||||||
"""
|
"""
|
||||||
获取节点数据
|
获取节点数据
|
||||||
@@ -499,7 +499,7 @@ class AccessHistoryManager:
|
|||||||
Args:
|
Args:
|
||||||
node_id: 节点ID
|
node_id: 节点ID
|
||||||
node_label: 节点标签
|
node_label: 节点标签
|
||||||
group_id: 组ID(可选)
|
end_user_id: 组ID(可选)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Optional[Dict[str, Any]]: 节点数据,如果不存在返回None
|
Optional[Dict[str, Any]]: 节点数据,如果不存在返回None
|
||||||
@@ -507,8 +507,8 @@ class AccessHistoryManager:
|
|||||||
query = f"""
|
query = f"""
|
||||||
MATCH (n:{node_label} {{id: $node_id}})
|
MATCH (n:{node_label} {{id: $node_id}})
|
||||||
"""
|
"""
|
||||||
if group_id:
|
if end_user_id:
|
||||||
query += " WHERE n.group_id = $group_id"
|
query += " WHERE n.end_user_id = $end_user_id"
|
||||||
query += """
|
query += """
|
||||||
RETURN n.id as id,
|
RETURN n.id as id,
|
||||||
n.importance_score as importance_score,
|
n.importance_score as importance_score,
|
||||||
@@ -519,8 +519,8 @@ class AccessHistoryManager:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
params = {'node_id': node_id}
|
params = {'node_id': node_id}
|
||||||
if group_id:
|
if end_user_id:
|
||||||
params['group_id'] = group_id
|
params['end_user_id'] = end_user_id
|
||||||
|
|
||||||
results = await self.connector.execute_query(query, **params)
|
results = await self.connector.execute_query(query, **params)
|
||||||
|
|
||||||
@@ -585,7 +585,7 @@ class AccessHistoryManager:
|
|||||||
node_id: str,
|
node_id: str,
|
||||||
node_label: str,
|
node_label: str,
|
||||||
update_data: Dict[str, Any],
|
update_data: Dict[str, Any],
|
||||||
group_id: Optional[str] = None
|
end_user_id: Optional[str] = None
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
原子性更新节点(使用乐观锁)
|
原子性更新节点(使用乐观锁)
|
||||||
@@ -597,7 +597,7 @@ class AccessHistoryManager:
|
|||||||
node_id: 节点ID
|
node_id: 节点ID
|
||||||
node_label: 节点标签
|
node_label: 节点标签
|
||||||
update_data: 更新数据
|
update_data: 更新数据
|
||||||
group_id: 组ID(可选)
|
end_user_id: 组ID(可选)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Dict[str, Any]: 更新后的节点数据
|
Dict[str, Any]: 更新后的节点数据
|
||||||
@@ -606,13 +606,13 @@ class AccessHistoryManager:
|
|||||||
RuntimeError: 如果更新失败或发生版本冲突
|
RuntimeError: 如果更新失败或发生版本冲突
|
||||||
"""
|
"""
|
||||||
# 定义事务函数
|
# 定义事务函数
|
||||||
async def update_transaction(tx, node_id, node_label, update_data, group_id):
|
async def update_transaction(tx, node_id, node_label, update_data, end_user_id):
|
||||||
# 步骤1:读取当前节点并获取版本号
|
# 步骤1:读取当前节点并获取版本号
|
||||||
read_query = f"""
|
read_query = f"""
|
||||||
MATCH (n:{node_label} {{id: $node_id}})
|
MATCH (n:{node_label} {{id: $node_id}})
|
||||||
"""
|
"""
|
||||||
if group_id:
|
if end_user_id:
|
||||||
read_query += " WHERE n.group_id = $group_id"
|
read_query += " WHERE n.end_user_id = $end_user_id"
|
||||||
read_query += """
|
read_query += """
|
||||||
RETURN n.id as id,
|
RETURN n.id as id,
|
||||||
n.version as version,
|
n.version as version,
|
||||||
@@ -624,8 +624,8 @@ class AccessHistoryManager:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
read_params = {'node_id': node_id}
|
read_params = {'node_id': node_id}
|
||||||
if group_id:
|
if end_user_id:
|
||||||
read_params['group_id'] = group_id
|
read_params['end_user_id'] = end_user_id
|
||||||
|
|
||||||
read_result = await tx.run(read_query, **read_params)
|
read_result = await tx.run(read_query, **read_params)
|
||||||
current_node = await read_result.single()
|
current_node = await read_result.single()
|
||||||
@@ -656,8 +656,8 @@ class AccessHistoryManager:
|
|||||||
|
|
||||||
# 构建 WHERE 子句
|
# 构建 WHERE 子句
|
||||||
where_conditions = []
|
where_conditions = []
|
||||||
if group_id:
|
if end_user_id:
|
||||||
where_conditions.append("n.group_id = $group_id")
|
where_conditions.append("n.end_user_id = $end_user_id")
|
||||||
|
|
||||||
# 添加版本检查
|
# 添加版本检查
|
||||||
if current_version > 0:
|
if current_version > 0:
|
||||||
@@ -695,8 +695,8 @@ class AccessHistoryManager:
|
|||||||
'last_access_time': update_data['last_access_time'],
|
'last_access_time': update_data['last_access_time'],
|
||||||
'access_count': update_data['access_count']
|
'access_count': update_data['access_count']
|
||||||
}
|
}
|
||||||
if group_id:
|
if end_user_id:
|
||||||
update_params['group_id'] = group_id
|
update_params['end_user_id'] = end_user_id
|
||||||
|
|
||||||
update_result = await tx.run(update_query, **update_params)
|
update_result = await tx.run(update_query, **update_params)
|
||||||
updated_node = await update_result.single()
|
updated_node = await update_result.single()
|
||||||
@@ -720,7 +720,7 @@ class AccessHistoryManager:
|
|||||||
node_id=node_id,
|
node_id=node_id,
|
||||||
node_label=node_label,
|
node_label=node_label,
|
||||||
update_data=update_data,
|
update_data=update_data,
|
||||||
group_id=group_id
|
end_user_id=end_user_id
|
||||||
)
|
)
|
||||||
return result
|
return result
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|||||||
@@ -66,7 +66,7 @@ class ForgettingScheduler:
|
|||||||
|
|
||||||
async def run_forgetting_cycle(
|
async def run_forgetting_cycle(
|
||||||
self,
|
self,
|
||||||
group_id: Optional[str] = None,
|
end_user_id: Optional[str] = None,
|
||||||
max_merge_batch_size: int = 100,
|
max_merge_batch_size: int = 100,
|
||||||
min_days_since_access: int = 30,
|
min_days_since_access: int = 30,
|
||||||
config_id: Optional[int] = None,
|
config_id: Optional[int] = None,
|
||||||
@@ -77,7 +77,7 @@ class ForgettingScheduler:
|
|||||||
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
group_id: 组 ID(可选,用于过滤特定组的节点)
|
end_user_id: 组 ID(可选,用于过滤特定组的节点)
|
||||||
max_merge_batch_size: 单次最大融合节点对数(默认 100)
|
max_merge_batch_size: 单次最大融合节点对数(默认 100)
|
||||||
min_days_since_access: 最小未访问天数(默认 30 天)
|
min_days_since_access: 最小未访问天数(默认 30 天)
|
||||||
config_id: 配置ID(可选,用于获取 llm_id)
|
config_id: 配置ID(可选,用于获取 llm_id)
|
||||||
@@ -107,19 +107,19 @@ class ForgettingScheduler:
|
|||||||
start_time_iso = start_time.isoformat()
|
start_time_iso = start_time.isoformat()
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"开始遗忘周期: group_id={group_id}, "
|
f"开始遗忘周期: end_user_id={end_user_id}, "
|
||||||
f"max_batch={max_merge_batch_size}, "
|
f"max_batch={max_merge_batch_size}, "
|
||||||
f"min_days={min_days_since_access}"
|
f"min_days={min_days_since_access}"
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 步骤1:统计遗忘前的节点数量
|
# 步骤1:统计遗忘前的节点数量
|
||||||
nodes_before = await self._count_knowledge_nodes(group_id)
|
nodes_before = await self._count_knowledge_nodes(end_user_id)
|
||||||
logger.info(f"遗忘前节点总数: {nodes_before}")
|
logger.info(f"遗忘前节点总数: {nodes_before}")
|
||||||
|
|
||||||
# 步骤2:识别可遗忘的节点对
|
# 步骤2:识别可遗忘的节点对
|
||||||
forgettable_pairs = await self.forgetting_strategy.find_forgettable_nodes(
|
forgettable_pairs = await self.forgetting_strategy.find_forgettable_nodes(
|
||||||
group_id=group_id,
|
end_user_id=end_user_id,
|
||||||
min_days_since_access=min_days_since_access
|
min_days_since_access=min_days_since_access
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -213,7 +213,7 @@ class ForgettingScheduler:
|
|||||||
'statement_text': pair['statement_text'],
|
'statement_text': pair['statement_text'],
|
||||||
'statement_activation': pair['statement_activation'],
|
'statement_activation': pair['statement_activation'],
|
||||||
'statement_importance': pair['statement_importance'],
|
'statement_importance': pair['statement_importance'],
|
||||||
'group_id': group_id
|
'end_user_id': end_user_id
|
||||||
}
|
}
|
||||||
|
|
||||||
entity_node = {
|
entity_node = {
|
||||||
@@ -222,7 +222,7 @@ class ForgettingScheduler:
|
|||||||
'entity_type': pair['entity_type'],
|
'entity_type': pair['entity_type'],
|
||||||
'entity_activation': pair['entity_activation'],
|
'entity_activation': pair['entity_activation'],
|
||||||
'entity_importance': pair['entity_importance'],
|
'entity_importance': pair['entity_importance'],
|
||||||
'group_id': group_id
|
'end_user_id': end_user_id
|
||||||
}
|
}
|
||||||
|
|
||||||
# 融合节点
|
# 融合节点
|
||||||
@@ -262,7 +262,7 @@ class ForgettingScheduler:
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
# 步骤6:统计遗忘后的节点数量
|
# 步骤6:统计遗忘后的节点数量
|
||||||
nodes_after = await self._count_knowledge_nodes(group_id)
|
nodes_after = await self._count_knowledge_nodes(end_user_id)
|
||||||
logger.info(f"遗忘后节点总数: {nodes_after}")
|
logger.info(f"遗忘后节点总数: {nodes_after}")
|
||||||
|
|
||||||
# 步骤7:生成遗忘报告
|
# 步骤7:生成遗忘报告
|
||||||
@@ -315,7 +315,7 @@ class ForgettingScheduler:
|
|||||||
|
|
||||||
async def _count_knowledge_nodes(
|
async def _count_knowledge_nodes(
|
||||||
self,
|
self,
|
||||||
group_id: Optional[str] = None
|
end_user_id: Optional[str] = None
|
||||||
) -> int:
|
) -> int:
|
||||||
"""
|
"""
|
||||||
统计知识层节点总数
|
统计知识层节点总数
|
||||||
@@ -323,7 +323,7 @@ class ForgettingScheduler:
|
|||||||
统计 Statement、ExtractedEntity 和 MemorySummary 节点的总数。
|
统计 Statement、ExtractedEntity 和 MemorySummary 节点的总数。
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
group_id: 组 ID(可选,用于过滤特定组的节点)
|
end_user_id: 组 ID(可选,用于过滤特定组的节点)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
int: 知识层节点总数
|
int: 知识层节点总数
|
||||||
@@ -333,16 +333,16 @@ class ForgettingScheduler:
|
|||||||
WHERE (n:Statement OR n:ExtractedEntity OR n:MemorySummary)
|
WHERE (n:Statement OR n:ExtractedEntity OR n:MemorySummary)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if group_id:
|
if end_user_id:
|
||||||
query += " AND n.group_id = $group_id"
|
query += " AND n.end_user_id = $end_user_id"
|
||||||
|
|
||||||
query += """
|
query += """
|
||||||
RETURN count(n) as total
|
RETURN count(n) as total
|
||||||
"""
|
"""
|
||||||
|
|
||||||
params = {}
|
params = {}
|
||||||
if group_id:
|
if end_user_id:
|
||||||
params['group_id'] = group_id
|
end_user_id['end_user_id'] = end_user_id
|
||||||
|
|
||||||
results = await self.connector.execute_query(query, **params)
|
results = await self.connector.execute_query(query, **params)
|
||||||
|
|
||||||
|
|||||||
@@ -90,7 +90,7 @@ class ForgettingStrategy:
|
|||||||
|
|
||||||
async def find_forgettable_nodes(
|
async def find_forgettable_nodes(
|
||||||
self,
|
self,
|
||||||
group_id: Optional[str] = None,
|
end_user_id: Optional[str] = None,
|
||||||
min_days_since_access: int = 30
|
min_days_since_access: int = 30
|
||||||
) -> List[Dict[str, Any]]:
|
) -> List[Dict[str, Any]]:
|
||||||
"""
|
"""
|
||||||
@@ -102,7 +102,7 @@ class ForgettingStrategy:
|
|||||||
3. Statement 和 Entity 之间存在关系边
|
3. Statement 和 Entity 之间存在关系边
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
group_id: 组 ID(可选,用于过滤特定组的节点)
|
end_user_id: 组 ID(可选,用于过滤特定组的节点)
|
||||||
min_days_since_access: 最小未访问天数(默认 30 天)
|
min_days_since_access: 最小未访问天数(默认 30 天)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@@ -136,8 +136,8 @@ class ForgettingStrategy:
|
|||||||
AND (e.entity_type IS NULL OR e.entity_type <> 'Person')
|
AND (e.entity_type IS NULL OR e.entity_type <> 'Person')
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if group_id:
|
if end_user_id:
|
||||||
query += " AND s.group_id = $group_id AND e.group_id = $group_id"
|
query += " AND s.end_user_id = $end_user_id AND e.end_user_id = $end_user_id"
|
||||||
|
|
||||||
query += """
|
query += """
|
||||||
RETURN s.id as statement_id,
|
RETURN s.id as statement_id,
|
||||||
@@ -159,8 +159,8 @@ class ForgettingStrategy:
|
|||||||
'threshold': self.forgetting_threshold,
|
'threshold': self.forgetting_threshold,
|
||||||
'cutoff_time': cutoff_time_iso
|
'cutoff_time': cutoff_time_iso
|
||||||
}
|
}
|
||||||
if group_id:
|
if end_user_id:
|
||||||
params['group_id'] = group_id
|
params['end_user_id'] = end_user_id
|
||||||
|
|
||||||
results = await self.connector.execute_query(query, **params)
|
results = await self.connector.execute_query(query, **params)
|
||||||
|
|
||||||
@@ -247,8 +247,8 @@ class ForgettingStrategy:
|
|||||||
entity_activation = entity_node['entity_activation']
|
entity_activation = entity_node['entity_activation']
|
||||||
entity_importance = entity_node['entity_importance']
|
entity_importance = entity_node['entity_importance']
|
||||||
|
|
||||||
# 获取 group_id(从 statement 或 entity 节点)
|
# 获取 end_user_id(从 statement 或 entity 节点)
|
||||||
group_id = statement_node.get('group_id') or entity_node.get('group_id')
|
end_user_id = statement_node.get('end_user_id') or entity_node.get('end_user_id')
|
||||||
|
|
||||||
# 生成摘要内容
|
# 生成摘要内容
|
||||||
summary_text = await self._generate_summary(
|
summary_text = await self._generate_summary(
|
||||||
@@ -325,7 +325,7 @@ class ForgettingStrategy:
|
|||||||
last_access_time: $current_time,
|
last_access_time: $current_time,
|
||||||
access_count: 1,
|
access_count: 1,
|
||||||
version: 1,
|
version: 1,
|
||||||
group_id: $group_id,
|
end_user_id: $end_user_id,
|
||||||
created_at: datetime($current_time),
|
created_at: datetime($current_time),
|
||||||
merged_at: datetime($current_time)
|
merged_at: datetime($current_time)
|
||||||
})
|
})
|
||||||
@@ -423,7 +423,7 @@ class ForgettingStrategy:
|
|||||||
'inherited_activation': inherited_activation,
|
'inherited_activation': inherited_activation,
|
||||||
'inherited_importance': inherited_importance,
|
'inherited_importance': inherited_importance,
|
||||||
'current_time': current_time_iso,
|
'current_time': current_time_iso,
|
||||||
'group_id': group_id
|
'end_user_id': end_user_id
|
||||||
}
|
}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|||||||
@@ -37,7 +37,7 @@ __all__ = [
|
|||||||
async def run_hybrid_search(
|
async def run_hybrid_search(
|
||||||
query_text: str,
|
query_text: str,
|
||||||
search_type: str = "hybrid",
|
search_type: str = "hybrid",
|
||||||
group_id: str | None = None,
|
end_user_id: str | None = None,
|
||||||
apply_id: str | None = None,
|
apply_id: str | None = None,
|
||||||
user_id: str | None = None,
|
user_id: str | None = None,
|
||||||
limit: int = 50,
|
limit: int = 50,
|
||||||
@@ -54,7 +54,7 @@ async def run_hybrid_search(
|
|||||||
Args:
|
Args:
|
||||||
query_text: 查询文本
|
query_text: 查询文本
|
||||||
search_type: 搜索类型("hybrid", "keyword", "semantic")
|
search_type: 搜索类型("hybrid", "keyword", "semantic")
|
||||||
group_id: 组ID过滤
|
end_user_id: 组ID过滤
|
||||||
apply_id: 应用ID过滤
|
apply_id: 应用ID过滤
|
||||||
user_id: 用户ID过滤
|
user_id: 用户ID过滤
|
||||||
limit: 每个类别的最大结果数
|
limit: 每个类别的最大结果数
|
||||||
@@ -104,7 +104,7 @@ async def run_hybrid_search(
|
|||||||
# 执行搜索
|
# 执行搜索
|
||||||
result = await strategy.search(
|
result = await strategy.search(
|
||||||
query_text=query_text,
|
query_text=query_text,
|
||||||
group_id=group_id,
|
end_user_id=end_user_id,
|
||||||
limit=limit,
|
limit=limit,
|
||||||
include=include,
|
include=include,
|
||||||
alpha=alpha,
|
alpha=alpha,
|
||||||
|
|||||||
@@ -77,7 +77,7 @@
|
|||||||
# async def search(
|
# async def search(
|
||||||
# self,
|
# self,
|
||||||
# query_text: str,
|
# query_text: str,
|
||||||
# group_id: Optional[str] = None,
|
# end_user_id: Optional[str] = None,
|
||||||
# limit: int = 50,
|
# limit: int = 50,
|
||||||
# include: Optional[List[str]] = None,
|
# include: Optional[List[str]] = None,
|
||||||
# **kwargs
|
# **kwargs
|
||||||
@@ -86,7 +86,7 @@
|
|||||||
|
|
||||||
# Args:
|
# Args:
|
||||||
# query_text: 查询文本
|
# query_text: 查询文本
|
||||||
# group_id: 可选的组ID过滤
|
# end_user_id: 可选的组ID过滤
|
||||||
# limit: 每个类别的最大结果数
|
# limit: 每个类别的最大结果数
|
||||||
# include: 要包含的搜索类别列表
|
# include: 要包含的搜索类别列表
|
||||||
# **kwargs: 其他搜索参数(如alpha, use_forgetting_curve)
|
# **kwargs: 其他搜索参数(如alpha, use_forgetting_curve)
|
||||||
@@ -94,7 +94,7 @@
|
|||||||
# Returns:
|
# Returns:
|
||||||
# SearchResult: 搜索结果对象
|
# SearchResult: 搜索结果对象
|
||||||
# """
|
# """
|
||||||
# logger.info(f"执行混合搜索: query='{query_text}', group_id={group_id}, limit={limit}")
|
# logger.info(f"执行混合搜索: query='{query_text}', end_user_id={end_user_id}, limit={limit}")
|
||||||
|
|
||||||
# # 从kwargs中获取参数
|
# # 从kwargs中获取参数
|
||||||
# alpha = kwargs.get("alpha", self.alpha)
|
# alpha = kwargs.get("alpha", self.alpha)
|
||||||
@@ -107,14 +107,14 @@
|
|||||||
# # 并行执行关键词搜索和语义搜索
|
# # 并行执行关键词搜索和语义搜索
|
||||||
# keyword_result = await self.keyword_strategy.search(
|
# keyword_result = await self.keyword_strategy.search(
|
||||||
# query_text=query_text,
|
# query_text=query_text,
|
||||||
# group_id=group_id,
|
# end_user_id=end_user_id,
|
||||||
# limit=limit,
|
# limit=limit,
|
||||||
# include=include_list
|
# include=include_list
|
||||||
# )
|
# )
|
||||||
|
|
||||||
# semantic_result = await self.semantic_strategy.search(
|
# semantic_result = await self.semantic_strategy.search(
|
||||||
# query_text=query_text,
|
# query_text=query_text,
|
||||||
# group_id=group_id,
|
# end_user_id=end_user_id,
|
||||||
# limit=limit,
|
# limit=limit,
|
||||||
# include=include_list
|
# include=include_list
|
||||||
# )
|
# )
|
||||||
@@ -139,7 +139,7 @@
|
|||||||
# metadata = self._create_metadata(
|
# metadata = self._create_metadata(
|
||||||
# query_text=query_text,
|
# query_text=query_text,
|
||||||
# search_type="hybrid",
|
# search_type="hybrid",
|
||||||
# group_id=group_id,
|
# end_user_id=end_user_id,
|
||||||
# limit=limit,
|
# limit=limit,
|
||||||
# include=include_list,
|
# include=include_list,
|
||||||
# alpha=alpha,
|
# alpha=alpha,
|
||||||
@@ -165,7 +165,7 @@
|
|||||||
# metadata=self._create_metadata(
|
# metadata=self._create_metadata(
|
||||||
# query_text=query_text,
|
# query_text=query_text,
|
||||||
# search_type="hybrid",
|
# search_type="hybrid",
|
||||||
# group_id=group_id,
|
# end_user_id=end_user_id,
|
||||||
# limit=limit,
|
# limit=limit,
|
||||||
# error=str(e)
|
# error=str(e)
|
||||||
# )
|
# )
|
||||||
|
|||||||
@@ -44,7 +44,7 @@ class KeywordSearchStrategy(SearchStrategy):
|
|||||||
async def search(
|
async def search(
|
||||||
self,
|
self,
|
||||||
query_text: str,
|
query_text: str,
|
||||||
group_id: Optional[str] = None,
|
end_user_id: Optional[str] = None,
|
||||||
limit: int = 50,
|
limit: int = 50,
|
||||||
include: Optional[List[str]] = None,
|
include: Optional[List[str]] = None,
|
||||||
**kwargs
|
**kwargs
|
||||||
@@ -53,7 +53,7 @@ class KeywordSearchStrategy(SearchStrategy):
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
query_text: 查询文本
|
query_text: 查询文本
|
||||||
group_id: 可选的组ID过滤
|
end_user_id: 可选的组ID过滤
|
||||||
limit: 每个类别的最大结果数
|
limit: 每个类别的最大结果数
|
||||||
include: 要包含的搜索类别列表
|
include: 要包含的搜索类别列表
|
||||||
**kwargs: 其他搜索参数
|
**kwargs: 其他搜索参数
|
||||||
@@ -61,7 +61,7 @@ class KeywordSearchStrategy(SearchStrategy):
|
|||||||
Returns:
|
Returns:
|
||||||
SearchResult: 搜索结果对象
|
SearchResult: 搜索结果对象
|
||||||
"""
|
"""
|
||||||
logger.info(f"执行关键词搜索: query='{query_text}', group_id={group_id}, limit={limit}")
|
logger.info(f"执行关键词搜索: query='{query_text}', end_user_id={end_user_id}, limit={limit}")
|
||||||
|
|
||||||
# 获取有效的搜索类别
|
# 获取有效的搜索类别
|
||||||
include_list = self._get_include_list(include)
|
include_list = self._get_include_list(include)
|
||||||
@@ -75,7 +75,7 @@ class KeywordSearchStrategy(SearchStrategy):
|
|||||||
results_dict = await search_graph(
|
results_dict = await search_graph(
|
||||||
connector=self.connector,
|
connector=self.connector,
|
||||||
q=query_text,
|
q=query_text,
|
||||||
group_id=group_id,
|
end_user_id=end_user_id,
|
||||||
limit=limit,
|
limit=limit,
|
||||||
include=include_list
|
include=include_list
|
||||||
)
|
)
|
||||||
@@ -84,7 +84,7 @@ class KeywordSearchStrategy(SearchStrategy):
|
|||||||
metadata = self._create_metadata(
|
metadata = self._create_metadata(
|
||||||
query_text=query_text,
|
query_text=query_text,
|
||||||
search_type="keyword",
|
search_type="keyword",
|
||||||
group_id=group_id,
|
end_user_id=end_user_id,
|
||||||
limit=limit,
|
limit=limit,
|
||||||
include=include_list
|
include=include_list
|
||||||
)
|
)
|
||||||
@@ -115,7 +115,7 @@ class KeywordSearchStrategy(SearchStrategy):
|
|||||||
metadata=self._create_metadata(
|
metadata=self._create_metadata(
|
||||||
query_text=query_text,
|
query_text=query_text,
|
||||||
search_type="keyword",
|
search_type="keyword",
|
||||||
group_id=group_id,
|
end_user_id=end_user_id,
|
||||||
limit=limit,
|
limit=limit,
|
||||||
error=str(e)
|
error=str(e)
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -58,7 +58,7 @@ class SearchStrategy(ABC):
|
|||||||
async def search(
|
async def search(
|
||||||
self,
|
self,
|
||||||
query_text: str,
|
query_text: str,
|
||||||
group_id: Optional[str] = None,
|
end_user_id: Optional[str] = None,
|
||||||
limit: int = 50,
|
limit: int = 50,
|
||||||
include: Optional[List[str]] = None,
|
include: Optional[List[str]] = None,
|
||||||
**kwargs
|
**kwargs
|
||||||
@@ -67,7 +67,7 @@ class SearchStrategy(ABC):
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
query_text: 查询文本
|
query_text: 查询文本
|
||||||
group_id: 可选的组ID过滤
|
end_user_id: 可选的组ID过滤
|
||||||
limit: 每个类别的最大结果数
|
limit: 每个类别的最大结果数
|
||||||
include: 要包含的搜索类别列表(statements, chunks, entities, summaries)
|
include: 要包含的搜索类别列表(statements, chunks, entities, summaries)
|
||||||
**kwargs: 其他搜索参数
|
**kwargs: 其他搜索参数
|
||||||
@@ -81,7 +81,7 @@ class SearchStrategy(ABC):
|
|||||||
self,
|
self,
|
||||||
query_text: str,
|
query_text: str,
|
||||||
search_type: str,
|
search_type: str,
|
||||||
group_id: Optional[str] = None,
|
end_user_id: Optional[str] = None,
|
||||||
limit: int = 50,
|
limit: int = 50,
|
||||||
**kwargs
|
**kwargs
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
@@ -90,7 +90,7 @@ class SearchStrategy(ABC):
|
|||||||
Args:
|
Args:
|
||||||
query_text: 查询文本
|
query_text: 查询文本
|
||||||
search_type: 搜索类型
|
search_type: 搜索类型
|
||||||
group_id: 组ID
|
end_user_id: 组ID
|
||||||
limit: 结果限制
|
limit: 结果限制
|
||||||
**kwargs: 其他元数据
|
**kwargs: 其他元数据
|
||||||
|
|
||||||
@@ -100,7 +100,7 @@ class SearchStrategy(ABC):
|
|||||||
metadata = {
|
metadata = {
|
||||||
"query": query_text,
|
"query": query_text,
|
||||||
"search_type": search_type,
|
"search_type": search_type,
|
||||||
"group_id": group_id,
|
"end_user_id": end_user_id,
|
||||||
"limit": limit,
|
"limit": limit,
|
||||||
"timestamp": datetime.now().isoformat()
|
"timestamp": datetime.now().isoformat()
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -85,7 +85,7 @@ class SemanticSearchStrategy(SearchStrategy):
|
|||||||
async def search(
|
async def search(
|
||||||
self,
|
self,
|
||||||
query_text: str,
|
query_text: str,
|
||||||
group_id: Optional[str] = None,
|
end_user_id: Optional[str] = None,
|
||||||
limit: int = 50,
|
limit: int = 50,
|
||||||
include: Optional[List[str]] = None,
|
include: Optional[List[str]] = None,
|
||||||
**kwargs
|
**kwargs
|
||||||
@@ -94,7 +94,7 @@ class SemanticSearchStrategy(SearchStrategy):
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
query_text: 查询文本
|
query_text: 查询文本
|
||||||
group_id: 可选的组ID过滤
|
end_user_id: 可选的组ID过滤
|
||||||
limit: 每个类别的最大结果数
|
limit: 每个类别的最大结果数
|
||||||
include: 要包含的搜索类别列表
|
include: 要包含的搜索类别列表
|
||||||
**kwargs: 其他搜索参数
|
**kwargs: 其他搜索参数
|
||||||
@@ -102,7 +102,7 @@ class SemanticSearchStrategy(SearchStrategy):
|
|||||||
Returns:
|
Returns:
|
||||||
SearchResult: 搜索结果对象
|
SearchResult: 搜索结果对象
|
||||||
"""
|
"""
|
||||||
logger.info(f"执行语义搜索: query='{query_text}', group_id={group_id}, limit={limit}")
|
logger.info(f"执行语义搜索: query='{query_text}', end_user_id={end_user_id}, limit={limit}")
|
||||||
|
|
||||||
# 获取有效的搜索类别
|
# 获取有效的搜索类别
|
||||||
include_list = self._get_include_list(include)
|
include_list = self._get_include_list(include)
|
||||||
@@ -119,7 +119,7 @@ class SemanticSearchStrategy(SearchStrategy):
|
|||||||
connector=self.connector,
|
connector=self.connector,
|
||||||
embedder_client=self.embedder_client,
|
embedder_client=self.embedder_client,
|
||||||
query_text=query_text,
|
query_text=query_text,
|
||||||
group_id=group_id,
|
end_user_id=end_user_id,
|
||||||
limit=limit,
|
limit=limit,
|
||||||
include=include_list
|
include=include_list
|
||||||
)
|
)
|
||||||
@@ -128,7 +128,7 @@ class SemanticSearchStrategy(SearchStrategy):
|
|||||||
metadata = self._create_metadata(
|
metadata = self._create_metadata(
|
||||||
query_text=query_text,
|
query_text=query_text,
|
||||||
search_type="semantic",
|
search_type="semantic",
|
||||||
group_id=group_id,
|
end_user_id=end_user_id,
|
||||||
limit=limit,
|
limit=limit,
|
||||||
include=include_list
|
include=include_list
|
||||||
)
|
)
|
||||||
@@ -159,7 +159,7 @@ class SemanticSearchStrategy(SearchStrategy):
|
|||||||
metadata=self._create_metadata(
|
metadata=self._create_metadata(
|
||||||
query_text=query_text,
|
query_text=query_text,
|
||||||
search_type="semantic",
|
search_type="semantic",
|
||||||
group_id=group_id,
|
end_user_id=end_user_id,
|
||||||
limit=limit,
|
limit=limit,
|
||||||
error=str(e)
|
error=str(e)
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -23,7 +23,7 @@ async def _load_(data: List[Any]) -> List[Dict]:
|
|||||||
target_keys = [
|
target_keys = [
|
||||||
"id",
|
"id",
|
||||||
"statement",
|
"statement",
|
||||||
"group_id",
|
"end_user_id",
|
||||||
"chunk_id",
|
"chunk_id",
|
||||||
"created_at",
|
"created_at",
|
||||||
"expired_at",
|
"expired_at",
|
||||||
@@ -75,7 +75,7 @@ async def get_data(result):
|
|||||||
"""
|
"""
|
||||||
EXCLUDE_FIELDS = {
|
EXCLUDE_FIELDS = {
|
||||||
"user_id",
|
"user_id",
|
||||||
"group_id",
|
"end_user_id",
|
||||||
"entity_type",
|
"entity_type",
|
||||||
"connect_strength",
|
"connect_strength",
|
||||||
"relationship_type",
|
"relationship_type",
|
||||||
|
|||||||
@@ -62,7 +62,7 @@ class ConfigAuditLogger:
|
|||||||
self,
|
self,
|
||||||
config_id: str,
|
config_id: str,
|
||||||
user_id: Optional[str] = None,
|
user_id: Optional[str] = None,
|
||||||
group_id: Optional[str] = None,
|
end_user_id: Optional[str] = None,
|
||||||
success: bool = True,
|
success: bool = True,
|
||||||
details: Optional[Dict[str, Any]] = None
|
details: Optional[Dict[str, Any]] = None
|
||||||
):
|
):
|
||||||
@@ -72,14 +72,14 @@ class ConfigAuditLogger:
|
|||||||
Args:
|
Args:
|
||||||
config_id: 配置 ID
|
config_id: 配置 ID
|
||||||
user_id: 用户 ID(可选)
|
user_id: 用户 ID(可选)
|
||||||
group_id: 组 ID(可选)
|
end_user_id: 组 ID(可选)
|
||||||
success: 是否成功
|
success: 是否成功
|
||||||
details: 详细信息(可选)
|
details: 详细信息(可选)
|
||||||
"""
|
"""
|
||||||
result = "SUCCESS" if success else "FAILED"
|
result = "SUCCESS" if success else "FAILED"
|
||||||
msg = (
|
msg = (
|
||||||
f"CONFIG_LOAD config_id={config_id} "
|
f"CONFIG_LOAD config_id={config_id} "
|
||||||
f"user={user_id or 'N/A'} group={group_id or 'N/A'} "
|
f"user={user_id or 'N/A'} group={end_user_id or 'N/A'} "
|
||||||
f"result={result}"
|
f"result={result}"
|
||||||
)
|
)
|
||||||
if details:
|
if details:
|
||||||
@@ -121,7 +121,7 @@ class ConfigAuditLogger:
|
|||||||
self,
|
self,
|
||||||
operation: str,
|
operation: str,
|
||||||
config_id: str,
|
config_id: str,
|
||||||
group_id: str,
|
end_user_id: str,
|
||||||
success: bool = True,
|
success: bool = True,
|
||||||
duration: Optional[float] = None,
|
duration: Optional[float] = None,
|
||||||
error: Optional[str] = None,
|
error: Optional[str] = None,
|
||||||
@@ -133,7 +133,7 @@ class ConfigAuditLogger:
|
|||||||
Args:
|
Args:
|
||||||
operation: 操作类型(WRITE, READ 等)
|
operation: 操作类型(WRITE, READ 等)
|
||||||
config_id: 配置 ID
|
config_id: 配置 ID
|
||||||
group_id: 组 ID
|
end_user_id: 组 ID
|
||||||
success: 是否成功
|
success: 是否成功
|
||||||
duration: 操作耗时(秒)
|
duration: 操作耗时(秒)
|
||||||
error: 错误信息(可选)
|
error: 错误信息(可选)
|
||||||
@@ -142,7 +142,7 @@ class ConfigAuditLogger:
|
|||||||
result = "SUCCESS" if success else "FAILED"
|
result = "SUCCESS" if success else "FAILED"
|
||||||
msg = (
|
msg = (
|
||||||
f"{operation.upper()} config_id={config_id} "
|
f"{operation.upper()} config_id={config_id} "
|
||||||
f"group={group_id} result={result}"
|
f"group={end_user_id} result={result}"
|
||||||
)
|
)
|
||||||
if duration is not None:
|
if duration is not None:
|
||||||
msg += f" duration={duration:.2f}s"
|
msg += f" duration={duration:.2f}s"
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ from enum import StrEnum, auto
|
|||||||
class Field(StrEnum):
|
class Field(StrEnum):
|
||||||
CONTENT_KEY = "page_content"
|
CONTENT_KEY = "page_content"
|
||||||
METADATA_KEY = "metadata"
|
METADATA_KEY = "metadata"
|
||||||
GROUP_KEY = "group_id"
|
GROUP_KEY = "end_user_id"
|
||||||
VECTOR = auto()
|
VECTOR = auto()
|
||||||
# Sparse Vector aims to support full text search
|
# Sparse Vector aims to support full text search
|
||||||
SPARSE_VECTOR = auto()
|
SPARSE_VECTOR = auto()
|
||||||
|
|||||||
@@ -89,14 +89,15 @@ def validate_model_exists_and_active(
|
|||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# First check if model exists at all (without tenant filtering)
|
# OPTIMIZED: Single query with tenant filter
|
||||||
model_without_tenant = ModelConfigRepository.get_by_id(db, model_id, tenant_id=None)
|
# We'll check tenant mismatch in the error handling
|
||||||
|
|
||||||
# Then check with tenant filtering
|
|
||||||
model = ModelConfigRepository.get_by_id(db, model_id, tenant_id)
|
model = ModelConfigRepository.get_by_id(db, model_id, tenant_id)
|
||||||
elapsed_ms = (time.time() - start_time) * 1000
|
elapsed_ms = (time.time() - start_time) * 1000
|
||||||
|
|
||||||
if not model:
|
if not model:
|
||||||
|
# Model not found with tenant filter - check if it exists without filter
|
||||||
|
model_without_tenant = ModelConfigRepository.get_by_id(db, model_id, tenant_id=None)
|
||||||
|
|
||||||
if model_without_tenant:
|
if model_without_tenant:
|
||||||
# Model exists but belongs to different tenant
|
# Model exists but belongs to different tenant
|
||||||
logger.warning(
|
logger.warning(
|
||||||
@@ -208,8 +209,11 @@ def validate_embedding_model(
|
|||||||
db: Session,
|
db: Session,
|
||||||
tenant_id: Optional[UUID] = None,
|
tenant_id: Optional[UUID] = None,
|
||||||
workspace_id: Optional[UUID] = None
|
workspace_id: Optional[UUID] = None
|
||||||
) -> UUID:
|
) -> tuple[UUID, str]:
|
||||||
"""Validate that embedding model is available and return its UUID.
|
"""Validate that embedding model is available and return its UUID and name.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (embedding_uuid, embedding_name)
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
InvalidConfigError: If embedding_id is not provided or invalid
|
InvalidConfigError: If embedding_id is not provided or invalid
|
||||||
@@ -225,14 +229,19 @@ def validate_embedding_model(
|
|||||||
workspace_id=workspace_id
|
workspace_id=workspace_id
|
||||||
)
|
)
|
||||||
|
|
||||||
embedding_uuid, _ = validate_and_resolve_model_id(
|
embedding_uuid, embedding_name = validate_and_resolve_model_id(
|
||||||
embedding_id, "embedding", db, tenant_id, required=True,
|
embedding_id, "embedding", db, tenant_id, required=True,
|
||||||
config_id=config_id, workspace_id=workspace_id
|
config_id=config_id, workspace_id=workspace_id
|
||||||
)
|
)
|
||||||
print(100*'-')
|
|
||||||
print(embedding_uuid)
|
logger.debug(
|
||||||
print(_)
|
"Embedding model validated",
|
||||||
print(100*'-')
|
extra={
|
||||||
|
"embedding_uuid": str(embedding_uuid),
|
||||||
|
"embedding_name": embedding_name,
|
||||||
|
"config_id": config_id
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
if embedding_uuid is None:
|
if embedding_uuid is None:
|
||||||
raise InvalidConfigError(
|
raise InvalidConfigError(
|
||||||
@@ -243,7 +252,7 @@ def validate_embedding_model(
|
|||||||
workspace_id=workspace_id
|
workspace_id=workspace_id
|
||||||
)
|
)
|
||||||
|
|
||||||
return embedding_uuid
|
return embedding_uuid, embedding_name
|
||||||
|
|
||||||
|
|
||||||
def validate_llm_model(
|
def validate_llm_model(
|
||||||
|
|||||||
@@ -104,38 +104,6 @@ class DataConfigRepository:
|
|||||||
r.statement AS statement
|
r.statement AS statement
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# Entity graph within group (source node, edge, target node)
|
|
||||||
SEARCH_FOR_ENTITY_GRAPH = """
|
|
||||||
MATCH (n:ExtractedEntity)-[r]->(m:ExtractedEntity)
|
|
||||||
WHERE n.group_id = $group_id
|
|
||||||
RETURN
|
|
||||||
{
|
|
||||||
entity_idx: n.entity_idx,
|
|
||||||
connect_strength: n.connect_strength,
|
|
||||||
description: n.description,
|
|
||||||
entity_type: n.entity_type,
|
|
||||||
name: n.name,
|
|
||||||
fact_summary: COALESCE(n.fact_summary, ''),
|
|
||||||
id: n.id
|
|
||||||
} AS sourceNode,
|
|
||||||
{
|
|
||||||
rel_id: elementId(r),
|
|
||||||
source_id: startNode(r).id,
|
|
||||||
target_id: endNode(r).id,
|
|
||||||
predicate: r.predicate,
|
|
||||||
statement_id: r.statement_id,
|
|
||||||
statement: r.statement
|
|
||||||
} AS edge,
|
|
||||||
{
|
|
||||||
entity_idx: m.entity_idx,
|
|
||||||
connect_strength: m.connect_strength,
|
|
||||||
description: m.description,
|
|
||||||
entity_type: m.entity_type,
|
|
||||||
name: m.name,
|
|
||||||
fact_summary: COALESCE(m.fact_summary, ''),
|
|
||||||
id: m.id
|
|
||||||
} AS targetNode
|
|
||||||
"""
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def update_reflection_config(
|
def update_reflection_config(
|
||||||
db: Session,
|
db: Session,
|
||||||
|
|||||||
@@ -276,42 +276,6 @@ def get_end_user_by_id(db: Session, end_user_id: uuid.UUID) -> Optional[EndUser]
|
|||||||
end_user = repo.get_end_user_by_id(end_user_id)
|
end_user = repo.get_end_user_by_id(end_user_id)
|
||||||
return end_user
|
return end_user
|
||||||
|
|
||||||
def update_end_user_other_name(
|
|
||||||
db: Session,
|
|
||||||
end_user_id: uuid.UUID,
|
|
||||||
other_name: str
|
|
||||||
) -> int:
|
|
||||||
"""
|
|
||||||
通过 end_user_id 更新 end_user 表中的 other_name 字段
|
|
||||||
|
|
||||||
Args:
|
|
||||||
db: 数据库会话
|
|
||||||
end_user_id: 宿主ID
|
|
||||||
other_name: 要更新的用户名
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
int: 更新的记录数
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
# 执行更新
|
|
||||||
updated_count = (
|
|
||||||
db.query(EndUser)
|
|
||||||
.filter(EndUser.id == end_user_id)
|
|
||||||
.update(
|
|
||||||
{EndUser.other_name: other_name},
|
|
||||||
synchronize_session=False
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
db.commit()
|
|
||||||
db_logger.info(f"成功更新宿主 {end_user_id} 的 other_name 为: {other_name}")
|
|
||||||
return updated_count
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
db.rollback()
|
|
||||||
db_logger.error(f"更新宿主 {end_user_id} 的 other_name 时出错: {str(e)}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
# 新增的缓存操作函数(保持与类方法一致的接口)
|
# 新增的缓存操作函数(保持与类方法一致的接口)
|
||||||
def get_by_id(db: Session, end_user_id: uuid.UUID) -> Optional[EndUser]:
|
def get_by_id(db: Session, end_user_id: uuid.UUID) -> Optional[EndUser]:
|
||||||
"""根据ID获取终端用户(用于缓存操作)"""
|
"""根据ID获取终端用户(用于缓存操作)"""
|
||||||
|
|||||||
@@ -32,7 +32,7 @@ async def add_chunk_statement_edges(chunks: List[Chunk], connector: Neo4jConnect
|
|||||||
"id": stable_edge_id,
|
"id": stable_edge_id,
|
||||||
"source": chunk.id,
|
"source": chunk.id,
|
||||||
"target": stmt.id,
|
"target": stmt.id,
|
||||||
"group_id": getattr(stmt, 'group_id', None),
|
"end_user_id": getattr(stmt, 'end_user_id', None),
|
||||||
"user_id":getattr(stmt, 'user_id', None),
|
"user_id":getattr(stmt, 'user_id', None),
|
||||||
"apply_id": getattr(stmt, 'apply_id', None),
|
"apply_id": getattr(stmt, 'apply_id', None),
|
||||||
"run_id": getattr(stmt, 'run_id', None) or getattr(chunk, 'run_id', None),
|
"run_id": getattr(stmt, 'run_id', None) or getattr(chunk, 'run_id', None),
|
||||||
@@ -83,7 +83,7 @@ async def add_memory_summary_statement_edges(summaries: List[MemorySummaryNode],
|
|||||||
edges.append({
|
edges.append({
|
||||||
"summary_id": s.id,
|
"summary_id": s.id,
|
||||||
"chunk_id": chunk_id,
|
"chunk_id": chunk_id,
|
||||||
"group_id": s.group_id,
|
"end_user_id": s.end_user_id,
|
||||||
"run_id": s.run_id,
|
"run_id": s.run_id,
|
||||||
"created_at": s.created_at.isoformat() if s.created_at else None,
|
"created_at": s.created_at.isoformat() if s.created_at else None,
|
||||||
"expired_at": s.expired_at.isoformat() if s.expired_at else None,
|
"expired_at": s.expired_at.isoformat() if s.expired_at else None,
|
||||||
|
|||||||
@@ -6,10 +6,10 @@ from app.core.memory.models.graph_models import DialogueNode, StatementNode, Chu
|
|||||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||||
|
|
||||||
|
|
||||||
async def delete_all_nodes(group_id: str, connector: Neo4jConnector):
|
async def delete_all_nodes(end_user_id: str, connector: Neo4jConnector):
|
||||||
"""Delete all nodes in the database."""
|
"""Delete all nodes in the database."""
|
||||||
result = await connector.execute_query(f"MATCH (n {{group_id: '{group_id}'}}) DETACH DELETE n")
|
result = await connector.execute_query(f"MATCH (n {{end_user_id: '{end_user_id}'}}) DETACH DELETE n")
|
||||||
print(f"All group_id: {group_id} node and edge deleted successfully")
|
print(f"All end_user_id: {end_user_id} node and edge deleted successfully")
|
||||||
return result
|
return result
|
||||||
|
|
||||||
async def add_dialogue_nodes(dialogues: List[DialogueNode], connector: Neo4jConnector) -> Optional[List[str]]:
|
async def add_dialogue_nodes(dialogues: List[DialogueNode], connector: Neo4jConnector) -> Optional[List[str]]:
|
||||||
@@ -32,9 +32,7 @@ async def add_dialogue_nodes(dialogues: List[DialogueNode], connector: Neo4jConn
|
|||||||
for dialogue in dialogues:
|
for dialogue in dialogues:
|
||||||
flattened_dialogues.append({
|
flattened_dialogues.append({
|
||||||
"id": dialogue.id,
|
"id": dialogue.id,
|
||||||
"group_id": dialogue.group_id,
|
"end_user_id": dialogue.end_user_id,
|
||||||
"user_id": dialogue.user_id,
|
|
||||||
"apply_id": dialogue.apply_id,
|
|
||||||
"run_id": dialogue.run_id,
|
"run_id": dialogue.run_id,
|
||||||
"ref_id": dialogue.ref_id,
|
"ref_id": dialogue.ref_id,
|
||||||
"name": dialogue.name,
|
"name": dialogue.name,
|
||||||
@@ -79,9 +77,7 @@ async def add_statement_nodes(statements: List[StatementNode], connector: Neo4jC
|
|||||||
flattened_statement = {
|
flattened_statement = {
|
||||||
"id": statement.id,
|
"id": statement.id,
|
||||||
"name": statement.name,
|
"name": statement.name,
|
||||||
"group_id": statement.group_id,
|
"end_user_id": statement.end_user_id,
|
||||||
"user_id": statement.user_id,
|
|
||||||
"apply_id": statement.apply_id,
|
|
||||||
"run_id": statement.run_id,
|
"run_id": statement.run_id,
|
||||||
"chunk_id": statement.chunk_id,
|
"chunk_id": statement.chunk_id,
|
||||||
# "created_at": statement.created_at.isoformat(),
|
# "created_at": statement.created_at.isoformat(),
|
||||||
@@ -101,6 +97,8 @@ async def add_statement_nodes(statements: List[StatementNode], connector: Neo4jC
|
|||||||
# "entities": [entity.model_dump() for entity in statement.triplet_extraction_info.entities] if statement.triplet_extraction_info else []
|
# "entities": [entity.model_dump() for entity in statement.triplet_extraction_info.entities] if statement.triplet_extraction_info else []
|
||||||
# }) if statement.triplet_extraction_info else json.dumps({"triplets": [], "entities": []}),
|
# }) if statement.triplet_extraction_info else json.dumps({"triplets": [], "entities": []}),
|
||||||
"statement_embedding": statement.statement_embedding if statement.statement_embedding else None,
|
"statement_embedding": statement.statement_embedding if statement.statement_embedding else None,
|
||||||
|
# 添加 speaker 字段(用于基于角色的情绪提取)
|
||||||
|
"speaker": statement.speaker if hasattr(statement, 'speaker') else None,
|
||||||
# 添加情绪字段处理
|
# 添加情绪字段处理
|
||||||
"emotion_type": statement.emotion_type,
|
"emotion_type": statement.emotion_type,
|
||||||
"emotion_intensity": statement.emotion_intensity,
|
"emotion_intensity": statement.emotion_intensity,
|
||||||
@@ -152,9 +150,7 @@ async def add_chunk_nodes(chunks: List[ChunkNode], connector: Neo4jConnector) ->
|
|||||||
flattened_chunk = {
|
flattened_chunk = {
|
||||||
"id": chunk.id,
|
"id": chunk.id,
|
||||||
"name": chunk.name,
|
"name": chunk.name,
|
||||||
"group_id": chunk.group_id,
|
"end_user_id": chunk.end_user_id,
|
||||||
"user_id": chunk.user_id,
|
|
||||||
"apply_id": chunk.apply_id,
|
|
||||||
"run_id": chunk.run_id,
|
"run_id": chunk.run_id,
|
||||||
"created_at": chunk.created_at.isoformat() if chunk.created_at else None,
|
"created_at": chunk.created_at.isoformat() if chunk.created_at else None,
|
||||||
"expired_at": chunk.expired_at.isoformat() if chunk.expired_at else None,
|
"expired_at": chunk.expired_at.isoformat() if chunk.expired_at else None,
|
||||||
@@ -163,7 +159,9 @@ async def add_chunk_nodes(chunks: List[ChunkNode], connector: Neo4jConnector) ->
|
|||||||
"chunk_embedding": chunk.chunk_embedding if chunk.chunk_embedding else None,
|
"chunk_embedding": chunk.chunk_embedding if chunk.chunk_embedding else None,
|
||||||
"sequence_number": chunk.sequence_number,
|
"sequence_number": chunk.sequence_number,
|
||||||
"start_index": metadata.get("start_index"),
|
"start_index": metadata.get("start_index"),
|
||||||
"end_index": metadata.get("end_index")
|
"end_index": metadata.get("end_index"),
|
||||||
|
# 添加 speaker 字段(用于基于角色的情绪提取)
|
||||||
|
"speaker": chunk.speaker if hasattr(chunk, 'speaker') else None
|
||||||
}
|
}
|
||||||
flattened_chunks.append(flattened_chunk)
|
flattened_chunks.append(flattened_chunk)
|
||||||
|
|
||||||
@@ -202,9 +200,7 @@ async def add_memory_summary_nodes(summaries: List[MemorySummaryNode], connector
|
|||||||
flattened.append({
|
flattened.append({
|
||||||
"id": s.id,
|
"id": s.id,
|
||||||
"name": s.name,
|
"name": s.name,
|
||||||
"group_id": s.group_id,
|
"end_user_id": s.end_user_id,
|
||||||
"user_id": s.user_id,
|
|
||||||
"apply_id": s.apply_id,
|
|
||||||
"run_id": s.run_id,
|
"run_id": s.run_id,
|
||||||
"created_at": s.created_at.isoformat() if s.created_at else None,
|
"created_at": s.created_at.isoformat() if s.created_at else None,
|
||||||
"expired_at": s.expired_at.isoformat() if s.expired_at else None,
|
"expired_at": s.expired_at.isoformat() if s.expired_at else None,
|
||||||
|
|||||||
@@ -152,7 +152,7 @@ class BaseNeo4jRepository(BaseRepository[T]):
|
|||||||
|
|
||||||
Example:
|
Example:
|
||||||
>>> results = await repository.find(
|
>>> results = await repository.find(
|
||||||
... {"group_id": "group_123", "user_id": "user_456"},
|
... {"end_user_id": "group_123", "user_id": "user_456"},
|
||||||
... limit=50
|
... limit=50
|
||||||
... )
|
... )
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -3,9 +3,7 @@ DIALOGUE_NODE_SAVE = """
|
|||||||
UNWIND $dialogues AS dialogue
|
UNWIND $dialogues AS dialogue
|
||||||
MERGE (n:Dialogue {id: dialogue.id})
|
MERGE (n:Dialogue {id: dialogue.id})
|
||||||
SET n.uuid = coalesce(n.uuid, dialogue.id),
|
SET n.uuid = coalesce(n.uuid, dialogue.id),
|
||||||
n.group_id = dialogue.group_id,
|
n.end_user_id = dialogue.end_user_id,
|
||||||
n.user_id = dialogue.user_id,
|
|
||||||
n.apply_id = dialogue.apply_id,
|
|
||||||
n.run_id = dialogue.run_id,
|
n.run_id = dialogue.run_id,
|
||||||
n.ref_id = dialogue.ref_id,
|
n.ref_id = dialogue.ref_id,
|
||||||
n.created_at = dialogue.created_at,
|
n.created_at = dialogue.created_at,
|
||||||
@@ -22,9 +20,7 @@ SET s += {
|
|||||||
id: statement.id,
|
id: statement.id,
|
||||||
run_id: statement.run_id,
|
run_id: statement.run_id,
|
||||||
chunk_id: statement.chunk_id,
|
chunk_id: statement.chunk_id,
|
||||||
group_id: statement.group_id,
|
end_user_id: statement.end_user_id,
|
||||||
user_id: statement.user_id,
|
|
||||||
apply_id: statement.apply_id,
|
|
||||||
stmt_type: statement.stmt_type,
|
stmt_type: statement.stmt_type,
|
||||||
statement: statement.statement,
|
statement: statement.statement,
|
||||||
emotion_intensity: statement.emotion_intensity,
|
emotion_intensity: statement.emotion_intensity,
|
||||||
@@ -54,9 +50,7 @@ MERGE (c:Chunk {id: chunk.id})
|
|||||||
SET c += {
|
SET c += {
|
||||||
id: chunk.id,
|
id: chunk.id,
|
||||||
name: chunk.name,
|
name: chunk.name,
|
||||||
group_id: chunk.group_id,
|
end_user_id: chunk.end_user_id,
|
||||||
user_id: chunk.user_id,
|
|
||||||
apply_id: chunk.apply_id,
|
|
||||||
run_id: chunk.run_id,
|
run_id: chunk.run_id,
|
||||||
created_at: chunk.created_at,
|
created_at: chunk.created_at,
|
||||||
expired_at: chunk.expired_at,
|
expired_at: chunk.expired_at,
|
||||||
@@ -76,9 +70,7 @@ EXTRACTED_ENTITY_NODE_SAVE = """
|
|||||||
UNWIND $entities AS entity
|
UNWIND $entities AS entity
|
||||||
MERGE (e:ExtractedEntity {id: entity.id})
|
MERGE (e:ExtractedEntity {id: entity.id})
|
||||||
SET e.name = CASE WHEN entity.name IS NOT NULL AND entity.name <> '' THEN entity.name ELSE e.name END,
|
SET e.name = CASE WHEN entity.name IS NOT NULL AND entity.name <> '' THEN entity.name ELSE e.name END,
|
||||||
e.group_id = CASE WHEN entity.group_id IS NOT NULL AND entity.group_id <> '' THEN entity.group_id ELSE e.group_id END,
|
e.end_user_id = CASE WHEN entity.end_user_id IS NOT NULL AND entity.end_user_id <> '' THEN entity.end_user_id ELSE e.end_user_id END,
|
||||||
e.user_id = CASE WHEN entity.user_id IS NOT NULL AND entity.user_id <> '' THEN entity.user_id ELSE e.user_id END,
|
|
||||||
e.apply_id = CASE WHEN entity.apply_id IS NOT NULL AND entity.apply_id <> '' THEN entity.apply_id ELSE e.apply_id END,
|
|
||||||
e.run_id = CASE WHEN entity.run_id IS NOT NULL AND entity.run_id <> '' THEN entity.run_id ELSE e.run_id END,
|
e.run_id = CASE WHEN entity.run_id IS NOT NULL AND entity.run_id <> '' THEN entity.run_id ELSE e.run_id END,
|
||||||
e.created_at = CASE
|
e.created_at = CASE
|
||||||
WHEN entity.created_at IS NOT NULL AND (e.created_at IS NULL OR entity.created_at < e.created_at)
|
WHEN entity.created_at IS NOT NULL AND (e.created_at IS NULL OR entity.created_at < e.created_at)
|
||||||
@@ -134,9 +126,9 @@ RETURN e.id AS uuid
|
|||||||
# Add back ENTITY_RELATIONSHIP_SAVE to be used by graph_saver.save_entities_and_relationships
|
# Add back ENTITY_RELATIONSHIP_SAVE to be used by graph_saver.save_entities_and_relationships
|
||||||
ENTITY_RELATIONSHIP_SAVE = """
|
ENTITY_RELATIONSHIP_SAVE = """
|
||||||
UNWIND $relationships AS rel
|
UNWIND $relationships AS rel
|
||||||
// Match entities by stable id within group, do not constrain by run_id
|
// Match entities by stable id within end_user_id, do not constrain by run_id
|
||||||
MATCH (subject:ExtractedEntity {id: rel.source_id, group_id: rel.group_id})
|
MATCH (subject:ExtractedEntity {id: rel.source_id, end_user_id: rel.end_user_id})
|
||||||
MATCH (object:ExtractedEntity {id: rel.target_id, group_id: rel.group_id})
|
MATCH (object:ExtractedEntity {id: rel.target_id, end_user_id: rel.end_user_id})
|
||||||
// Avoid duplicate edges across runs for the same endpoints
|
// Avoid duplicate edges across runs for the same endpoints
|
||||||
MERGE (subject)-[r:EXTRACTED_RELATIONSHIP]->(object)
|
MERGE (subject)-[r:EXTRACTED_RELATIONSHIP]->(object)
|
||||||
SET r.predicate = rel.predicate,
|
SET r.predicate = rel.predicate,
|
||||||
@@ -148,7 +140,7 @@ SET r.predicate = rel.predicate,
|
|||||||
r.created_at = rel.created_at,
|
r.created_at = rel.created_at,
|
||||||
r.expired_at = rel.expired_at,
|
r.expired_at = rel.expired_at,
|
||||||
r.run_id = rel.run_id,
|
r.run_id = rel.run_id,
|
||||||
r.group_id = rel.group_id
|
r.end_user_id = rel.end_user_id
|
||||||
RETURN elementId(r) AS uuid
|
RETURN elementId(r) AS uuid
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@@ -160,7 +152,7 @@ UNWIND $weak_entities AS entity
|
|||||||
MERGE (e:ExtractedEntity {id: entity.id, run_id: entity.run_id})
|
MERGE (e:ExtractedEntity {id: entity.id, run_id: entity.run_id})
|
||||||
SET e += {
|
SET e += {
|
||||||
name: entity.name,
|
name: entity.name,
|
||||||
group_id: entity.group_id,
|
end_user_id: entity.end_user_id,
|
||||||
run_id: entity.run_id,
|
run_id: entity.run_id,
|
||||||
description: entity.description,
|
description: entity.description,
|
||||||
chunk_id: entity.chunk_id,
|
chunk_id: entity.chunk_id,
|
||||||
@@ -175,11 +167,11 @@ RETURN e.id AS id
|
|||||||
SAVE_STRONG_TRIPLE_ENTITIES = """
|
SAVE_STRONG_TRIPLE_ENTITIES = """
|
||||||
UNWIND $items AS item
|
UNWIND $items AS item
|
||||||
MERGE (s:ExtractedEntity {id: item.source_id, run_id: item.run_id})
|
MERGE (s:ExtractedEntity {id: item.source_id, run_id: item.run_id})
|
||||||
SET s += {name: item.subject, group_id: item.group_id, run_id: item.run_id}
|
SET s += {name: item.subject, end_user_id: item.end_user_id, run_id: item.run_id}
|
||||||
// Independent strong flag
|
// Independent strong flag
|
||||||
SET s.is_strong = true
|
SET s.is_strong = true
|
||||||
MERGE (o:ExtractedEntity {id: item.target_id, run_id: item.run_id})
|
MERGE (o:ExtractedEntity {id: item.target_id, run_id: item.run_id})
|
||||||
SET o += {name: item.object, group_id: item.group_id, run_id: item.run_id}
|
SET o += {name: item.object, end_user_id: item.end_user_id, run_id: item.run_id}
|
||||||
// Independent strong flag
|
// Independent strong flag
|
||||||
SET o.is_strong = true
|
SET o.is_strong = true
|
||||||
"""
|
"""
|
||||||
@@ -194,7 +186,7 @@ DIALOGUE_STATEMENT_EDGE_SAVE = """
|
|||||||
// 仅按端点去重,关系属性可更新
|
// 仅按端点去重,关系属性可更新
|
||||||
MERGE (dialogue)-[e:MENTIONS]->(statement)
|
MERGE (dialogue)-[e:MENTIONS]->(statement)
|
||||||
SET e.uuid = edge.id,
|
SET e.uuid = edge.id,
|
||||||
e.group_id = edge.group_id,
|
e.end_user_id = edge.end_user_id,
|
||||||
e.created_at = edge.created_at,
|
e.created_at = edge.created_at,
|
||||||
e.expired_at = edge.expired_at
|
e.expired_at = edge.expired_at
|
||||||
RETURN e.uuid AS uuid
|
RETURN e.uuid AS uuid
|
||||||
@@ -208,7 +200,7 @@ CHUNK_STATEMENT_EDGE_SAVE = """
|
|||||||
MATCH (statement:Statement {id: edge.source, run_id: edge.run_id})
|
MATCH (statement:Statement {id: edge.source, run_id: edge.run_id})
|
||||||
MATCH (chunk:Chunk {id: edge.target, run_id: edge.run_id})
|
MATCH (chunk:Chunk {id: edge.target, run_id: edge.run_id})
|
||||||
MERGE (chunk)-[e:CONTAINS {id: edge.id}]->(statement)
|
MERGE (chunk)-[e:CONTAINS {id: edge.id}]->(statement)
|
||||||
SET e.group_id = edge.group_id,
|
SET e.end_user_id = edge.end_user_id,
|
||||||
e.run_id = edge.run_id,
|
e.run_id = edge.run_id,
|
||||||
e.created_at = edge.created_at,
|
e.created_at = edge.created_at,
|
||||||
e.expired_at = edge.expired_at
|
e.expired_at = edge.expired_at
|
||||||
@@ -218,13 +210,12 @@ CHUNK_STATEMENT_EDGE_SAVE = """
|
|||||||
STATEMENT_ENTITY_EDGE_SAVE = """
|
STATEMENT_ENTITY_EDGE_SAVE = """
|
||||||
UNWIND $relationships AS rel
|
UNWIND $relationships AS rel
|
||||||
// Statement nodes are per-run; keep run_id constraint on statements
|
// Statement nodes are per-run; keep run_id constraint on statements
|
||||||
// Statement nodes are per-run; keep run_id constraint on statements
|
|
||||||
MATCH (statement:Statement {id: rel.source, run_id: rel.run_id})
|
MATCH (statement:Statement {id: rel.source, run_id: rel.run_id})
|
||||||
// Entities are shared across runs within a group; do not constrain by run_id
|
// Entities are shared across runs within end_user_id; do not constrain by run_id
|
||||||
MATCH (entity:ExtractedEntity {id: rel.target, group_id: rel.group_id})
|
MATCH (entity:ExtractedEntity {id: rel.target, end_user_id: rel.end_user_id})
|
||||||
// Avoid duplicate edges across runs for same endpoints
|
// Avoid duplicate edges across runs for same endpoints
|
||||||
MERGE (statement)-[r:REFERENCES_ENTITY]->(entity)
|
MERGE (statement)-[r:REFERENCES_ENTITY]->(entity)
|
||||||
SET r.group_id = rel.group_id,
|
SET r.end_user_id = rel.end_user_id,
|
||||||
r.run_id = rel.run_id,
|
r.run_id = rel.run_id,
|
||||||
r.created_at = rel.created_at,
|
r.created_at = rel.created_at,
|
||||||
r.expired_at = rel.expired_at,
|
r.expired_at = rel.expired_at,
|
||||||
@@ -236,10 +227,10 @@ ENTITY_EMBEDDING_SEARCH = """
|
|||||||
CALL db.index.vector.queryNodes('entity_embedding_index', $limit * 100, $embedding)
|
CALL db.index.vector.queryNodes('entity_embedding_index', $limit * 100, $embedding)
|
||||||
YIELD node AS e, score
|
YIELD node AS e, score
|
||||||
WHERE e.name_embedding IS NOT NULL
|
WHERE e.name_embedding IS NOT NULL
|
||||||
AND ($group_id IS NULL OR e.group_id = $group_id)
|
AND ($end_user_id IS NULL OR e.end_user_id = $end_user_id)
|
||||||
RETURN e.id AS id,
|
RETURN e.id AS id,
|
||||||
e.name AS name,
|
e.name AS name,
|
||||||
e.group_id AS group_id,
|
e.end_user_id AS end_user_id,
|
||||||
e.entity_type AS entity_type,
|
e.entity_type AS entity_type,
|
||||||
COALESCE(e.activation_value, e.importance_score, 0.5) AS activation_value,
|
COALESCE(e.activation_value, e.importance_score, 0.5) AS activation_value,
|
||||||
COALESCE(e.importance_score, 0.5) AS importance_score,
|
COALESCE(e.importance_score, 0.5) AS importance_score,
|
||||||
@@ -254,10 +245,10 @@ STATEMENT_EMBEDDING_SEARCH = """
|
|||||||
CALL db.index.vector.queryNodes('statement_embedding_index', $limit * 100, $embedding)
|
CALL db.index.vector.queryNodes('statement_embedding_index', $limit * 100, $embedding)
|
||||||
YIELD node AS s, score
|
YIELD node AS s, score
|
||||||
WHERE s.statement_embedding IS NOT NULL
|
WHERE s.statement_embedding IS NOT NULL
|
||||||
AND ($group_id IS NULL OR s.group_id = $group_id)
|
AND ($end_user_id IS NULL OR s.end_user_id = $end_user_id)
|
||||||
RETURN s.id AS id,
|
RETURN s.id AS id,
|
||||||
s.statement AS statement,
|
s.statement AS statement,
|
||||||
s.group_id AS group_id,
|
s.end_user_id AS end_user_id,
|
||||||
s.chunk_id AS chunk_id,
|
s.chunk_id AS chunk_id,
|
||||||
s.created_at AS created_at,
|
s.created_at AS created_at,
|
||||||
s.expired_at AS expired_at,
|
s.expired_at AS expired_at,
|
||||||
@@ -277,9 +268,9 @@ CHUNK_EMBEDDING_SEARCH = """
|
|||||||
CALL db.index.vector.queryNodes('chunk_embedding_index', $limit * 100, $embedding)
|
CALL db.index.vector.queryNodes('chunk_embedding_index', $limit * 100, $embedding)
|
||||||
YIELD node AS c, score
|
YIELD node AS c, score
|
||||||
WHERE c.chunk_embedding IS NOT NULL
|
WHERE c.chunk_embedding IS NOT NULL
|
||||||
AND ($group_id IS NULL OR c.group_id = $group_id)
|
AND ($end_user_id IS NULL OR c.end_user_id = $end_user_id)
|
||||||
RETURN c.id AS chunk_id,
|
RETURN c.id AS chunk_id,
|
||||||
c.group_id AS group_id,
|
c.end_user_id AS end_user_id,
|
||||||
c.content AS content,
|
c.content AS content,
|
||||||
c.dialog_id AS dialog_id,
|
c.dialog_id AS dialog_id,
|
||||||
COALESCE(c.activation_value, 0.5) AS activation_value,
|
COALESCE(c.activation_value, 0.5) AS activation_value,
|
||||||
@@ -292,12 +283,12 @@ LIMIT $limit
|
|||||||
|
|
||||||
SEARCH_STATEMENTS_BY_KEYWORD = """
|
SEARCH_STATEMENTS_BY_KEYWORD = """
|
||||||
CALL db.index.fulltext.queryNodes("statementsFulltext", $q) YIELD node AS s, score
|
CALL db.index.fulltext.queryNodes("statementsFulltext", $q) YIELD node AS s, score
|
||||||
WHERE ($group_id IS NULL OR s.group_id = $group_id)
|
WHERE ($end_user_id IS NULL OR s.end_user_id = $end_user_id)
|
||||||
OPTIONAL MATCH (c:Chunk)-[:CONTAINS]->(s)
|
OPTIONAL MATCH (c:Chunk)-[:CONTAINS]->(s)
|
||||||
OPTIONAL MATCH (s)-[:REFERENCES_ENTITY]->(e:ExtractedEntity)
|
OPTIONAL MATCH (s)-[:REFERENCES_ENTITY]->(e:ExtractedEntity)
|
||||||
RETURN s.id AS id,
|
RETURN s.id AS id,
|
||||||
s.statement AS statement,
|
s.statement AS statement,
|
||||||
s.group_id AS group_id,
|
s.end_user_id AS end_user_id,
|
||||||
s.chunk_id AS chunk_id,
|
s.chunk_id AS chunk_id,
|
||||||
s.created_at AS created_at,
|
s.created_at AS created_at,
|
||||||
s.expired_at AS expired_at,
|
s.expired_at AS expired_at,
|
||||||
@@ -316,15 +307,13 @@ LIMIT $limit
|
|||||||
# 查询实体名称包含指定字符串的实体
|
# 查询实体名称包含指定字符串的实体
|
||||||
SEARCH_ENTITIES_BY_NAME = """
|
SEARCH_ENTITIES_BY_NAME = """
|
||||||
CALL db.index.fulltext.queryNodes("entitiesFulltext", $q) YIELD node AS e, score
|
CALL db.index.fulltext.queryNodes("entitiesFulltext", $q) YIELD node AS e, score
|
||||||
WHERE ($group_id IS NULL OR e.group_id = $group_id)
|
WHERE ($end_user_id IS NULL OR e.end_user_id = $end_user_id)
|
||||||
OPTIONAL MATCH (s:Statement)-[:REFERENCES_ENTITY]->(e)
|
OPTIONAL MATCH (s:Statement)-[:REFERENCES_ENTITY]->(e)
|
||||||
OPTIONAL MATCH (c:Chunk)-[:CONTAINS]->(s)
|
OPTIONAL MATCH (c:Chunk)-[:CONTAINS]->(s)
|
||||||
RETURN e.id AS id,
|
RETURN e.id AS id,
|
||||||
e.name AS name,
|
e.name AS name,
|
||||||
e.group_id AS group_id,
|
e.end_user_id AS end_user_id,
|
||||||
e.entity_type AS entity_type,
|
e.entity_type AS entity_type,
|
||||||
e.apply_id AS apply_id,
|
|
||||||
e.user_id AS user_id,
|
|
||||||
e.created_at AS created_at,
|
e.created_at AS created_at,
|
||||||
e.expired_at AS expired_at,
|
e.expired_at AS expired_at,
|
||||||
e.entity_idx AS entity_idx,
|
e.entity_idx AS entity_idx,
|
||||||
@@ -347,11 +336,11 @@ LIMIT $limit
|
|||||||
|
|
||||||
SEARCH_CHUNKS_BY_CONTENT = """
|
SEARCH_CHUNKS_BY_CONTENT = """
|
||||||
CALL db.index.fulltext.queryNodes("chunksFulltext", $q) YIELD node AS c, score
|
CALL db.index.fulltext.queryNodes("chunksFulltext", $q) YIELD node AS c, score
|
||||||
WHERE ($group_id IS NULL OR c.group_id = $group_id)
|
WHERE ($end_user_id IS NULL OR c.end_user_id = $end_user_id)
|
||||||
OPTIONAL MATCH (c)-[:CONTAINS]->(s:Statement)
|
OPTIONAL MATCH (c)-[:CONTAINS]->(s:Statement)
|
||||||
OPTIONAL MATCH (s)-[:REFERENCES_ENTITY]->(e:ExtractedEntity)
|
OPTIONAL MATCH (s)-[:REFERENCES_ENTITY]->(e:ExtractedEntity)
|
||||||
RETURN c.id AS chunk_id,
|
RETURN c.id AS chunk_id,
|
||||||
c.group_id AS group_id,
|
c.end_user_id AS end_user_id,
|
||||||
c.content AS content,
|
c.content AS content,
|
||||||
c.dialog_id AS dialog_id,
|
c.dialog_id AS dialog_id,
|
||||||
c.sequence_number AS sequence_number,
|
c.sequence_number AS sequence_number,
|
||||||
@@ -413,10 +402,10 @@ LIMIT $limit
|
|||||||
|
|
||||||
SEARCH_DIALOGUE_BY_DIALOG_ID = """
|
SEARCH_DIALOGUE_BY_DIALOG_ID = """
|
||||||
MATCH (d:Dialogue)
|
MATCH (d:Dialogue)
|
||||||
WHERE ($group_id IS NULL OR d.group_id = $group_id)
|
WHERE ($end_user_id IS NULL OR d.end_user_id = $end_user_id)
|
||||||
AND d.id = $dialog_id
|
AND d.id = $dialog_id
|
||||||
RETURN d.id AS dialog_id,
|
RETURN d.id AS dialog_id,
|
||||||
d.group_id AS group_id,
|
d.end_user_id AS end_user_id,
|
||||||
d.content AS content,
|
d.content AS content,
|
||||||
d.created_at AS created_at,
|
d.created_at AS created_at,
|
||||||
d.expired_at AS expired_at
|
d.expired_at AS expired_at
|
||||||
@@ -426,10 +415,10 @@ LIMIT $limit
|
|||||||
|
|
||||||
SEARCH_CHUNK_BY_CHUNK_ID = """
|
SEARCH_CHUNK_BY_CHUNK_ID = """
|
||||||
MATCH (c:Chunk)
|
MATCH (c:Chunk)
|
||||||
WHERE ($group_id IS NULL OR c.group_id = $group_id)
|
WHERE ($end_user_id IS NULL OR c.end_user_id = $end_user_id)
|
||||||
AND c.id = $chunk_id
|
AND c.id = $chunk_id
|
||||||
RETURN c.id AS chunk_id,
|
RETURN c.id AS chunk_id,
|
||||||
c.group_id AS group_id,
|
c.end_user_id AS end_user_id,
|
||||||
c.content AS content,
|
c.content AS content,
|
||||||
c.dialog_id AS dialog_id,
|
c.dialog_id AS dialog_id,
|
||||||
c.created_at AS created_at,
|
c.created_at AS created_at,
|
||||||
@@ -441,18 +430,14 @@ LIMIT $limit
|
|||||||
|
|
||||||
SEARCH_STATEMENTS_BY_TEMPORAL = """
|
SEARCH_STATEMENTS_BY_TEMPORAL = """
|
||||||
MATCH (s:Statement)
|
MATCH (s:Statement)
|
||||||
WHERE ($group_id IS NULL OR s.group_id = $group_id)
|
WHERE ($end_user_id IS NULL OR s.end_user_id = $end_user_id)
|
||||||
AND ($apply_id IS NULL OR s.apply_id = $apply_id)
|
|
||||||
AND ($user_id IS NULL OR s.user_id = $user_id)
|
|
||||||
AND ((($start_date IS NULL OR datetime(s.created_at) >= datetime($start_date))
|
AND ((($start_date IS NULL OR datetime(s.created_at) >= datetime($start_date))
|
||||||
AND ($end_date IS NULL OR datetime(s.created_at) <= datetime($end_date)))
|
AND ($end_date IS NULL OR datetime(s.created_at) <= datetime($end_date)))
|
||||||
OR (($valid_date IS NULL OR (s.valid_at IS NOT NULL AND datetime(s.valid_at) >= datetime($valid_date)))
|
OR (($valid_date IS NULL OR (s.valid_at IS NOT NULL AND datetime(s.valid_at) >= datetime($valid_date)))
|
||||||
AND ($invalid_date IS NULL OR (s.invalid_at IS NOT NULL AND datetime(s.invalid_at) <= datetime($invalid_date)))))
|
AND ($invalid_date IS NULL OR (s.invalid_at IS NOT NULL AND datetime(s.invalid_at) <= datetime($invalid_date)))))
|
||||||
RETURN s.id AS id,
|
RETURN s.id AS id,
|
||||||
s.statement AS statement,
|
s.statement AS statement,
|
||||||
s.group_id AS group_id,
|
s.end_user_id AS end_user_id,
|
||||||
s.apply_id AS apply_id,
|
|
||||||
s.user_id AS user_id,
|
|
||||||
s.chunk_id AS chunk_id,
|
s.chunk_id AS chunk_id,
|
||||||
s.created_at AS created_at,
|
s.created_at AS created_at,
|
||||||
s.valid_at AS valid_at,
|
s.valid_at AS valid_at,
|
||||||
@@ -468,9 +453,7 @@ LIMIT $limit
|
|||||||
|
|
||||||
SEARCH_STATEMENTS_BY_KEYWORD_TEMPORAL = """
|
SEARCH_STATEMENTS_BY_KEYWORD_TEMPORAL = """
|
||||||
CALL db.index.fulltext.queryNodes("statementsFulltext", $q) YIELD node AS s, score
|
CALL db.index.fulltext.queryNodes("statementsFulltext", $q) YIELD node AS s, score
|
||||||
WHERE ($group_id IS NULL OR s.group_id = $group_id)
|
WHERE ($end_user_id IS NULL OR s.end_user_id = $end_user_id)
|
||||||
AND ($apply_id IS NULL OR s.apply_id = $apply_id)
|
|
||||||
AND ($user_id IS NULL OR s.user_id = $user_id)
|
|
||||||
AND ((($start_date IS NULL OR (s.created_at IS NOT NULL AND datetime(s.created_at) >= datetime($start_date)))
|
AND ((($start_date IS NULL OR (s.created_at IS NOT NULL AND datetime(s.created_at) >= datetime($start_date)))
|
||||||
AND ($end_date IS NULL OR (s.created_at IS NOT NULL AND datetime(s.created_at) <= datetime($end_date))))
|
AND ($end_date IS NULL OR (s.created_at IS NOT NULL AND datetime(s.created_at) <= datetime($end_date))))
|
||||||
OR (($valid_date IS NULL OR (s.valid_at IS NOT NULL AND datetime(s.valid_at) >= datetime($valid_date)))
|
OR (($valid_date IS NULL OR (s.valid_at IS NOT NULL AND datetime(s.valid_at) >= datetime($valid_date)))
|
||||||
@@ -479,9 +462,7 @@ OPTIONAL MATCH (c:Chunk)-[:CONTAINS]->(s)
|
|||||||
OPTIONAL MATCH (s)-[:REFERENCES_ENTITY]->(e:ExtractedEntity)
|
OPTIONAL MATCH (s)-[:REFERENCES_ENTITY]->(e:ExtractedEntity)
|
||||||
RETURN s.id AS id,
|
RETURN s.id AS id,
|
||||||
s.statement AS statement,
|
s.statement AS statement,
|
||||||
s.group_id AS group_id,
|
s.end_user_id AS end_user_id,
|
||||||
s.apply_id AS apply_id,
|
|
||||||
s.user_id AS user_id,
|
|
||||||
s.chunk_id AS chunk_id,
|
s.chunk_id AS chunk_id,
|
||||||
s.created_at AS created_at,
|
s.created_at AS created_at,
|
||||||
s.valid_at AS valid_at,
|
s.valid_at AS valid_at,
|
||||||
@@ -499,15 +480,11 @@ LIMIT $limit
|
|||||||
|
|
||||||
SEARCH_STATEMENTS_BY_CREATED_AT = """
|
SEARCH_STATEMENTS_BY_CREATED_AT = """
|
||||||
MATCH (n:Statement)
|
MATCH (n:Statement)
|
||||||
WHERE ($group_id IS NULL OR n.group_id = $group_id)
|
WHERE ($end_user_id IS NULL OR n.end_user_id = $end_user_id)
|
||||||
AND ($apply_id IS NULL OR n.apply_id = $apply_id)
|
|
||||||
AND ($user_id IS NULL OR n.user_id = $user_id)
|
|
||||||
AND ($created_at IS NOT NULL AND date(substring(n.created_at, 0, 10)) = date($created_at))
|
AND ($created_at IS NOT NULL AND date(substring(n.created_at, 0, 10)) = date($created_at))
|
||||||
RETURN n.id AS id,
|
RETURN n.id AS id,
|
||||||
n.statement AS statement,
|
n.statement AS statement,
|
||||||
n.group_id AS group_id,
|
n.end_user_id AS end_user_id,
|
||||||
n.apply_id AS apply_id,
|
|
||||||
n.user_id AS user_id,
|
|
||||||
n.chunk_id AS chunk_id,
|
n.chunk_id AS chunk_id,
|
||||||
n.created_at AS created_at,
|
n.created_at AS created_at,
|
||||||
n.valid_at AS valid_at,
|
n.valid_at AS valid_at,
|
||||||
@@ -519,15 +496,11 @@ LIMIT $limit
|
|||||||
|
|
||||||
SEARCH_STATEMENTS_BY_VALID_AT = """
|
SEARCH_STATEMENTS_BY_VALID_AT = """
|
||||||
MATCH (n:Statement)
|
MATCH (n:Statement)
|
||||||
WHERE ($group_id IS NULL OR n.group_id = $group_id)
|
WHERE ($end_user_id IS NULL OR n.end_user_id = $end_user_id)
|
||||||
AND ($apply_id IS NULL OR n.apply_id = $apply_id)
|
|
||||||
AND ($user_id IS NULL OR n.user_id = $user_id)
|
|
||||||
AND ($valid_at IS NOT NULL AND date(substring(n.valid_at, 0, 10)) = date($valid_at))
|
AND ($valid_at IS NOT NULL AND date(substring(n.valid_at, 0, 10)) = date($valid_at))
|
||||||
RETURN n.id AS id,
|
RETURN n.id AS id,
|
||||||
n.statement AS statement,
|
n.statement AS statement,
|
||||||
n.group_id AS group_id,
|
n.end_user_id AS end_user_id,
|
||||||
n.apply_id AS apply_id,
|
|
||||||
n.user_id AS user_id,
|
|
||||||
n.chunk_id AS chunk_id,
|
n.chunk_id AS chunk_id,
|
||||||
n.created_at AS created_at,
|
n.created_at AS created_at,
|
||||||
n.valid_at AS valid_at,
|
n.valid_at AS valid_at,
|
||||||
@@ -539,15 +512,11 @@ LIMIT $limit
|
|||||||
|
|
||||||
SEARCH_STATEMENTS_G_CREATED_AT = """
|
SEARCH_STATEMENTS_G_CREATED_AT = """
|
||||||
MATCH (n:Statement)
|
MATCH (n:Statement)
|
||||||
WHERE ($group_id IS NULL OR n.group_id = $group_id)
|
WHERE ($end_user_id IS NULL OR n.end_user_id = $end_user_id)
|
||||||
AND ($apply_id IS NULL OR n.apply_id = $apply_id)
|
|
||||||
AND ($user_id IS NULL OR n.user_id = $user_id)
|
|
||||||
AND ($created_at IS NOT NULL AND date(substring(n.created_at, 0, 19)) = date($created_at))
|
AND ($created_at IS NOT NULL AND date(substring(n.created_at, 0, 19)) = date($created_at))
|
||||||
RETURN n.id AS id,
|
RETURN n.id AS id,
|
||||||
n.statement AS statement,
|
n.statement AS statement,
|
||||||
n.group_id AS group_id,
|
n.end_user_id AS end_user_id,
|
||||||
n.apply_id AS apply_id,
|
|
||||||
n.user_id AS user_id,
|
|
||||||
n.chunk_id AS chunk_id,
|
n.chunk_id AS chunk_id,
|
||||||
n.created_at AS created_at,
|
n.created_at AS created_at,
|
||||||
n.valid_at AS valid_at,
|
n.valid_at AS valid_at,
|
||||||
@@ -559,15 +528,11 @@ LIMIT $limit
|
|||||||
|
|
||||||
SEARCH_STATEMENTS_L_CREATED_AT = """
|
SEARCH_STATEMENTS_L_CREATED_AT = """
|
||||||
MATCH (n:Statement)
|
MATCH (n:Statement)
|
||||||
WHERE ($group_id IS NULL OR n.group_id = $group_id)
|
WHERE ($end_user_id IS NULL OR n.end_user_id = $end_user_id)
|
||||||
AND ($apply_id IS NULL OR n.apply_id = $apply_id)
|
|
||||||
AND ($user_id IS NULL OR n.user_id = $user_id)
|
|
||||||
AND ($created_at IS NOT NULL AND date(substring(n.created_at, 0, 19)) < date($created_at))
|
AND ($created_at IS NOT NULL AND date(substring(n.created_at, 0, 19)) < date($created_at))
|
||||||
RETURN n.id AS id,
|
RETURN n.id AS id,
|
||||||
n.statement AS statement,
|
n.statement AS statement,
|
||||||
n.group_id AS group_id,
|
n.end_user_id AS end_user_id,
|
||||||
n.apply_id AS apply_id,
|
|
||||||
n.user_id AS user_id,
|
|
||||||
n.chunk_id AS chunk_id,
|
n.chunk_id AS chunk_id,
|
||||||
n.created_at AS created_at,
|
n.created_at AS created_at,
|
||||||
n.valid_at AS valid_at,
|
n.valid_at AS valid_at,
|
||||||
@@ -579,15 +544,11 @@ LIMIT $limit
|
|||||||
|
|
||||||
SEARCH_STATEMENTS_G_VALID_AT = """
|
SEARCH_STATEMENTS_G_VALID_AT = """
|
||||||
MATCH (n:Statement)
|
MATCH (n:Statement)
|
||||||
WHERE ($group_id IS NULL OR n.group_id = $group_id)
|
WHERE ($end_user_id IS NULL OR n.end_user_id = $end_user_id)
|
||||||
AND ($apply_id IS NULL OR n.apply_id = $apply_id)
|
|
||||||
AND ($user_id IS NULL OR n.user_id = $user_id)
|
|
||||||
AND ($valid_at IS NOT NULL AND date(substring(n.valid_at, 0, 10)) > date($valid_at))
|
AND ($valid_at IS NOT NULL AND date(substring(n.valid_at, 0, 10)) > date($valid_at))
|
||||||
RETURN n.id AS id,
|
RETURN n.id AS id,
|
||||||
n.statement AS statement,
|
n.statement AS statement,
|
||||||
n.group_id AS group_id,
|
n.end_user_id AS end_user_id,
|
||||||
n.apply_id AS apply_id,
|
|
||||||
n.user_id AS user_id,
|
|
||||||
n.chunk_id AS chunk_id,
|
n.chunk_id AS chunk_id,
|
||||||
n.created_at AS created_at,
|
n.created_at AS created_at,
|
||||||
n.valid_at AS valid_at,
|
n.valid_at AS valid_at,
|
||||||
@@ -599,15 +560,11 @@ LIMIT $limit
|
|||||||
|
|
||||||
SEARCH_STATEMENTS_L_VALID_AT = """
|
SEARCH_STATEMENTS_L_VALID_AT = """
|
||||||
MATCH (n:Statement)
|
MATCH (n:Statement)
|
||||||
WHERE ($group_id IS NULL OR n.group_id = $group_id)
|
WHERE ($end_user_id IS NULL OR n.end_user_id = $end_user_id)
|
||||||
AND ($apply_id IS NULL OR n.apply_id = $apply_id)
|
|
||||||
AND ($user_id IS NULL OR n.user_id = $user_id)
|
|
||||||
AND ($valid_at IS NOT NULL AND date(substring(n.valid_at, 0, 10)) < date($valid_at))
|
AND ($valid_at IS NOT NULL AND date(substring(n.valid_at, 0, 10)) < date($valid_at))
|
||||||
RETURN n.id AS id,
|
RETURN n.id AS id,
|
||||||
n.statement AS statement,
|
n.statement AS statement,
|
||||||
n.group_id AS group_id,
|
n.end_user_id AS end_user_id,
|
||||||
n.apply_id AS apply_id,
|
|
||||||
n.user_id AS user_id,
|
|
||||||
n.chunk_id AS chunk_id,
|
n.chunk_id AS chunk_id,
|
||||||
n.created_at AS created_at,
|
n.created_at AS created_at,
|
||||||
n.valid_at AS valid_at,
|
n.valid_at AS valid_at,
|
||||||
@@ -665,18 +622,18 @@ LIMIT $limit
|
|||||||
|
|
||||||
# 根据id修改句子的invalid_at的值
|
# 根据id修改句子的invalid_at的值
|
||||||
UPDATE_STATEMENT_INVALID_AT = """
|
UPDATE_STATEMENT_INVALID_AT = """
|
||||||
MATCH (n:Statement {group_id: $group_id, id: $id})
|
MATCH (n:Statement {end_user_id: $end_user_id, id: $id})
|
||||||
SET n.invalid_at = $new_invalid_at
|
SET n.invalid_at = $new_invalid_at
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# MemorySummary keyword search using fulltext index
|
# MemorySummary keyword search using fulltext index
|
||||||
SEARCH_MEMORY_SUMMARIES_BY_KEYWORD = """
|
SEARCH_MEMORY_SUMMARIES_BY_KEYWORD = """
|
||||||
CALL db.index.fulltext.queryNodes("summariesFulltext", $q) YIELD node AS m, score
|
CALL db.index.fulltext.queryNodes("summariesFulltext", $q) YIELD node AS m, score
|
||||||
WHERE ($group_id IS NULL OR m.group_id = $group_id)
|
WHERE ($end_user_id IS NULL OR m.end_user_id = $end_user_id)
|
||||||
OPTIONAL MATCH (m)-[:DERIVED_FROM_STATEMENT]->(s:Statement)
|
OPTIONAL MATCH (m)-[:DERIVED_FROM_STATEMENT]->(s:Statement)
|
||||||
RETURN m.id AS id,
|
RETURN m.id AS id,
|
||||||
m.name AS name,
|
m.name AS name,
|
||||||
m.group_id AS group_id,
|
m.end_user_id AS end_user_id,
|
||||||
m.dialog_id AS dialog_id,
|
m.dialog_id AS dialog_id,
|
||||||
m.chunk_ids AS chunk_ids,
|
m.chunk_ids AS chunk_ids,
|
||||||
m.content AS content,
|
m.content AS content,
|
||||||
@@ -695,10 +652,10 @@ MEMORY_SUMMARY_EMBEDDING_SEARCH = """
|
|||||||
CALL db.index.vector.queryNodes('summary_embedding_index', $limit * 100, $embedding)
|
CALL db.index.vector.queryNodes('summary_embedding_index', $limit * 100, $embedding)
|
||||||
YIELD node AS m, score
|
YIELD node AS m, score
|
||||||
WHERE m.summary_embedding IS NOT NULL
|
WHERE m.summary_embedding IS NOT NULL
|
||||||
AND ($group_id IS NULL OR m.group_id = $group_id)
|
AND ($end_user_id IS NULL OR m.end_user_id = $end_user_id)
|
||||||
RETURN m.id AS id,
|
RETURN m.id AS id,
|
||||||
m.name AS name,
|
m.name AS name,
|
||||||
m.group_id AS group_id,
|
m.end_user_id AS end_user_id,
|
||||||
m.dialog_id AS dialog_id,
|
m.dialog_id AS dialog_id,
|
||||||
m.chunk_ids AS chunk_ids,
|
m.chunk_ids AS chunk_ids,
|
||||||
m.content AS content,
|
m.content AS content,
|
||||||
@@ -718,9 +675,7 @@ MERGE (m:MemorySummary {id: summary.id})
|
|||||||
SET m += {
|
SET m += {
|
||||||
id: summary.id,
|
id: summary.id,
|
||||||
name: summary.name,
|
name: summary.name,
|
||||||
group_id: summary.group_id,
|
end_user_id: summary.end_user_id,
|
||||||
user_id: summary.user_id,
|
|
||||||
apply_id: summary.apply_id,
|
|
||||||
run_id: summary.run_id,
|
run_id: summary.run_id,
|
||||||
created_at: summary.created_at,
|
created_at: summary.created_at,
|
||||||
expired_at: summary.expired_at,
|
expired_at: summary.expired_at,
|
||||||
@@ -814,7 +769,7 @@ RETURN count(losing) as deleted
|
|||||||
|
|
||||||
neo4j_statement_part = '''
|
neo4j_statement_part = '''
|
||||||
MATCH (n:Statement)
|
MATCH (n:Statement)
|
||||||
WHERE n.group_id = "{}"
|
WHERE n.end_user_id = "{}"
|
||||||
AND datetime(n.created_at) >= datetime() - duration('P3D')
|
AND datetime(n.created_at) >= datetime() - duration('P3D')
|
||||||
RETURN
|
RETURN
|
||||||
n.statement as statement_name,
|
n.statement as statement_name,
|
||||||
@@ -824,7 +779,7 @@ RETURN
|
|||||||
'''
|
'''
|
||||||
neo4j_statement_all = '''
|
neo4j_statement_all = '''
|
||||||
MATCH (n:Statement)
|
MATCH (n:Statement)
|
||||||
WHERE n.group_id = "{}"
|
WHERE n.end_user_id = "{}"
|
||||||
RETURN
|
RETURN
|
||||||
n.statement as statement_name,
|
n.statement as statement_name,
|
||||||
n.id as statement_id
|
n.id as statement_id
|
||||||
@@ -832,7 +787,7 @@ RETURN
|
|||||||
'''
|
'''
|
||||||
neo4j_query_part = """
|
neo4j_query_part = """
|
||||||
MATCH (n)-[r]-(m:ExtractedEntity)
|
MATCH (n)-[r]-(m:ExtractedEntity)
|
||||||
WHERE n.group_id = "{}"
|
WHERE n.end_user_id = "{}"
|
||||||
AND datetime(n.created_at) >= datetime() - duration('P3D')
|
AND datetime(n.created_at) >= datetime() - duration('P3D')
|
||||||
WITH DISTINCT m
|
WITH DISTINCT m
|
||||||
OPTIONAL MATCH (m)-[rel]-(other:ExtractedEntity)
|
OPTIONAL MATCH (m)-[rel]-(other:ExtractedEntity)
|
||||||
@@ -853,7 +808,7 @@ neo4j_query_part = """
|
|||||||
"""
|
"""
|
||||||
neo4j_query_all = """
|
neo4j_query_all = """
|
||||||
MATCH (n)-[r]-(m:ExtractedEntity)
|
MATCH (n)-[r]-(m:ExtractedEntity)
|
||||||
WHERE n.group_id = "{}"
|
WHERE n.end_user_id = "{}"
|
||||||
WITH DISTINCT m
|
WITH DISTINCT m
|
||||||
OPTIONAL MATCH (m)-[rel]-(other:ExtractedEntity)
|
OPTIONAL MATCH (m)-[rel]-(other:ExtractedEntity)
|
||||||
RETURN
|
RETURN
|
||||||
@@ -1027,14 +982,14 @@ RETURN DISTINCT
|
|||||||
|
|
||||||
Memory_Space_User="""
|
Memory_Space_User="""
|
||||||
MATCH (n)-[r]->(m)
|
MATCH (n)-[r]->(m)
|
||||||
WHERE n.group_id = $group_id AND m.name="用户"
|
WHERE n.end_user_id = $end_user_id AND m.name="用户"
|
||||||
return DISTINCT elementId(m) as id
|
return DISTINCT elementId(m) as id
|
||||||
"""
|
"""
|
||||||
Memory_Space_Entity="""
|
Memory_Space_Entity="""
|
||||||
MATCH (n)-[]-(m)
|
MATCH (n)-[]-(m)
|
||||||
WHERE elementId(m) = $id AND m.entity_type = "Person"
|
WHERE elementId(m) = $id AND m.entity_type = "Person"
|
||||||
RETURN
|
RETURN
|
||||||
DISTINCT m.name as name,m.group_id as group_id
|
DISTINCT m.name as name,m.end_user_id as end_user_id
|
||||||
"""
|
"""
|
||||||
Memory_Space_Associative="""
|
Memory_Space_Associative="""
|
||||||
MATCH (u)-[]-(x)-[]-(h)
|
MATCH (u)-[]-(x)-[]-(h)
|
||||||
|
|||||||
@@ -19,7 +19,7 @@ class DialogRepository(BaseNeo4jRepository[DialogueNode]):
|
|||||||
"""对话仓储
|
"""对话仓储
|
||||||
|
|
||||||
管理对话节点的创建、查询、更新和删除操作。
|
管理对话节点的创建、查询、更新和删除操作。
|
||||||
提供按group_id、user_id、ref_id等条件查询对话的方法。
|
提供按end_user_id、user_id、ref_id等条件查询对话的方法。
|
||||||
|
|
||||||
Attributes:
|
Attributes:
|
||||||
connector: Neo4j连接器实例
|
connector: Neo4j连接器实例
|
||||||
@@ -54,17 +54,17 @@ class DialogRepository(BaseNeo4jRepository[DialogueNode]):
|
|||||||
|
|
||||||
return DialogueNode(**n)
|
return DialogueNode(**n)
|
||||||
|
|
||||||
async def find_by_group_id(self, group_id: str, limit: int = 100) -> List[DialogueNode]:
|
async def find_by_end_user_id(self, end_user_id: str, limit: int = 100) -> List[DialogueNode]:
|
||||||
"""根据group_id查询对话
|
"""根据end_user_id查询对话
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
group_id: 组ID
|
end_user_id: 组ID
|
||||||
limit: 返回结果的最大数量
|
limit: 返回结果的最大数量
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List[DialogueNode]: 对话列表
|
List[DialogueNode]: 对话列表
|
||||||
"""
|
"""
|
||||||
return await self.find({"group_id": group_id}, limit=limit)
|
return await self.find({"end_user_id": end_user_id}, limit=limit)
|
||||||
|
|
||||||
async def find_by_user_id(self, user_id: str, limit: int = 100) -> List[DialogueNode]:
|
async def find_by_user_id(self, user_id: str, limit: int = 100) -> List[DialogueNode]:
|
||||||
"""根据user_id查询对话
|
"""根据user_id查询对话
|
||||||
@@ -94,14 +94,14 @@ class DialogRepository(BaseNeo4jRepository[DialogueNode]):
|
|||||||
|
|
||||||
async def find_by_group_and_user(
|
async def find_by_group_and_user(
|
||||||
self,
|
self,
|
||||||
group_id: str,
|
end_user_id: str,
|
||||||
user_id: str,
|
user_id: str,
|
||||||
limit: int = 100
|
limit: int = 100
|
||||||
) -> List[DialogueNode]:
|
) -> List[DialogueNode]:
|
||||||
"""根据group_id和user_id查询对话
|
"""根据end_user_id和user_id查询对话
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
group_id: 组ID
|
end_user_id: 组ID
|
||||||
user_id: 用户ID
|
user_id: 用户ID
|
||||||
limit: 返回结果的最大数量
|
limit: 返回结果的最大数量
|
||||||
|
|
||||||
@@ -109,20 +109,20 @@ class DialogRepository(BaseNeo4jRepository[DialogueNode]):
|
|||||||
List[DialogueNode]: 对话列表
|
List[DialogueNode]: 对话列表
|
||||||
"""
|
"""
|
||||||
return await self.find(
|
return await self.find(
|
||||||
{"group_id": group_id, "user_id": user_id},
|
{"end_user_id": end_user_id, "user_id": user_id},
|
||||||
limit=limit
|
limit=limit
|
||||||
)
|
)
|
||||||
|
|
||||||
async def find_recent_dialogs(
|
async def find_recent_dialogs(
|
||||||
self,
|
self,
|
||||||
group_id: str,
|
end_user_id: str,
|
||||||
days: int = 7,
|
days: int = 7,
|
||||||
limit: int = 100
|
limit: int = 100
|
||||||
) -> List[DialogueNode]:
|
) -> List[DialogueNode]:
|
||||||
"""查询最近的对话
|
"""查询最近的对话
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
group_id: 组ID
|
end_user_id: 组ID
|
||||||
days: 查询最近多少天的对话
|
days: 查询最近多少天的对话
|
||||||
limit: 返回结果的最大数量
|
limit: 返回结果的最大数量
|
||||||
|
|
||||||
@@ -131,7 +131,7 @@ class DialogRepository(BaseNeo4jRepository[DialogueNode]):
|
|||||||
"""
|
"""
|
||||||
query = f"""
|
query = f"""
|
||||||
MATCH (n:{self.node_label})
|
MATCH (n:{self.node_label})
|
||||||
WHERE n.group_id = $group_id
|
WHERE n.end_user_id = $end_user_id
|
||||||
AND n.created_at >= datetime() - duration({{days: $days}})
|
AND n.created_at >= datetime() - duration({{days: $days}})
|
||||||
RETURN n
|
RETURN n
|
||||||
ORDER BY n.created_at DESC
|
ORDER BY n.created_at DESC
|
||||||
@@ -139,7 +139,7 @@ class DialogRepository(BaseNeo4jRepository[DialogueNode]):
|
|||||||
"""
|
"""
|
||||||
results = await self.connector.execute_query(
|
results = await self.connector.execute_query(
|
||||||
query,
|
query,
|
||||||
group_id=group_id,
|
end_user_id=end_user_id,
|
||||||
days=days,
|
days=days,
|
||||||
limit=limit
|
limit=limit
|
||||||
)
|
)
|
||||||
@@ -164,16 +164,16 @@ class DialogRepository(BaseNeo4jRepository[DialogueNode]):
|
|||||||
async def find_by_config_and_group(
|
async def find_by_config_and_group(
|
||||||
self,
|
self,
|
||||||
config_id: str,
|
config_id: str,
|
||||||
group_id: str,
|
end_user_id: str,
|
||||||
limit: int = 100
|
limit: int = 100
|
||||||
) -> List[DialogueNode]:
|
) -> List[DialogueNode]:
|
||||||
"""根据config_id和group_id查询对话
|
"""根据config_id和end_user_id查询对话
|
||||||
|
|
||||||
支持按配置ID和组ID同时过滤,确保只返回使用特定配置处理的对话。
|
支持按配置ID和组ID同时过滤,确保只返回使用特定配置处理的对话。
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
config_id: 配置ID
|
config_id: 配置ID
|
||||||
group_id: 组ID
|
end_user_id: 组ID
|
||||||
limit: 返回结果的最大数量
|
limit: 返回结果的最大数量
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
|
|||||||
@@ -40,7 +40,7 @@ class EmotionRepository:
|
|||||||
|
|
||||||
async def get_emotion_tags(
|
async def get_emotion_tags(
|
||||||
self,
|
self,
|
||||||
group_id: str,
|
end_user_id: str,
|
||||||
emotion_type: Optional[str] = None,
|
emotion_type: Optional[str] = None,
|
||||||
start_date: Optional[str] = None,
|
start_date: Optional[str] = None,
|
||||||
end_date: Optional[str] = None,
|
end_date: Optional[str] = None,
|
||||||
@@ -51,7 +51,7 @@ class EmotionRepository:
|
|||||||
查询指定用户的情绪类型分布,包括计数、百分比和平均强度。
|
查询指定用户的情绪类型分布,包括计数、百分比和平均强度。
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
group_id: 用户组ID(宿主ID)
|
end_user_id: 用户组ID(宿主ID)
|
||||||
emotion_type: 可选的情绪类型过滤(joy/sadness/anger/fear/surprise/neutral)
|
emotion_type: 可选的情绪类型过滤(joy/sadness/anger/fear/surprise/neutral)
|
||||||
start_date: 可选的开始日期(ISO格式字符串)
|
start_date: 可选的开始日期(ISO格式字符串)
|
||||||
end_date: 可选的结束日期(ISO格式字符串)
|
end_date: 可选的结束日期(ISO格式字符串)
|
||||||
@@ -65,8 +65,8 @@ class EmotionRepository:
|
|||||||
- avg_intensity: 平均强度
|
- avg_intensity: 平均强度
|
||||||
"""
|
"""
|
||||||
# 构建查询条件
|
# 构建查询条件
|
||||||
where_clauses = ["s.group_id = $group_id", "s.emotion_type IS NOT NULL"]
|
where_clauses = ["s.end_user_id = $end_user_id", "s.emotion_type IS NOT NULL"]
|
||||||
params = {"group_id": group_id, "limit": limit}
|
params = {"end_user_id": end_user_id, "limit": limit}
|
||||||
|
|
||||||
if emotion_type:
|
if emotion_type:
|
||||||
where_clauses.append("s.emotion_type = $emotion_type")
|
where_clauses.append("s.emotion_type = $emotion_type")
|
||||||
@@ -119,7 +119,7 @@ class EmotionRepository:
|
|||||||
|
|
||||||
async def get_emotion_wordcloud(
|
async def get_emotion_wordcloud(
|
||||||
self,
|
self,
|
||||||
group_id: str,
|
end_user_id: str,
|
||||||
emotion_type: Optional[str] = None,
|
emotion_type: Optional[str] = None,
|
||||||
limit: int = 50
|
limit: int = 50
|
||||||
) -> List[Dict[str, Any]]:
|
) -> List[Dict[str, Any]]:
|
||||||
@@ -128,7 +128,7 @@ class EmotionRepository:
|
|||||||
查询情绪关键词及其频率,用于生成词云可视化。
|
查询情绪关键词及其频率,用于生成词云可视化。
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
group_id: 用户组ID(宿主ID)
|
end_user_id: 用户组ID(宿主ID)
|
||||||
emotion_type: 可选的情绪类型过滤
|
emotion_type: 可选的情绪类型过滤
|
||||||
limit: 返回关键词的最大数量
|
limit: 返回关键词的最大数量
|
||||||
|
|
||||||
@@ -140,8 +140,8 @@ class EmotionRepository:
|
|||||||
- avg_intensity: 平均强度
|
- avg_intensity: 平均强度
|
||||||
"""
|
"""
|
||||||
# 构建查询条件
|
# 构建查询条件
|
||||||
where_clauses = ["s.group_id = $group_id", "s.emotion_keywords IS NOT NULL"]
|
where_clauses = ["s.end_user_id = $end_user_id", "s.emotion_keywords IS NOT NULL"]
|
||||||
params = {"group_id": group_id, "limit": limit}
|
params = {"end_user_id": end_user_id, "limit": limit}
|
||||||
|
|
||||||
if emotion_type:
|
if emotion_type:
|
||||||
where_clauses.append("s.emotion_type = $emotion_type")
|
where_clauses.append("s.emotion_type = $emotion_type")
|
||||||
@@ -186,7 +186,7 @@ class EmotionRepository:
|
|||||||
|
|
||||||
async def get_emotions_in_range(
|
async def get_emotions_in_range(
|
||||||
self,
|
self,
|
||||||
group_id: str,
|
end_user_id: str,
|
||||||
time_range: str = "30d"
|
time_range: str = "30d"
|
||||||
) -> List[Dict[str, Any]]:
|
) -> List[Dict[str, Any]]:
|
||||||
"""获取时间范围内的情绪数据
|
"""获取时间范围内的情绪数据
|
||||||
@@ -194,7 +194,7 @@ class EmotionRepository:
|
|||||||
查询指定时间范围内的所有情绪数据,用于健康指数计算。
|
查询指定时间范围内的所有情绪数据,用于健康指数计算。
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
group_id: 用户组ID(宿主ID)
|
end_user_id: 用户组ID(宿主ID)
|
||||||
time_range: 时间范围(7d/30d/90d)
|
time_range: 时间范围(7d/30d/90d)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@@ -214,7 +214,7 @@ class EmotionRepository:
|
|||||||
# 优化的 Cypher 查询:使用字符串比较避免时区问题
|
# 优化的 Cypher 查询:使用字符串比较避免时区问题
|
||||||
query = """
|
query = """
|
||||||
MATCH (s:Statement)
|
MATCH (s:Statement)
|
||||||
WHERE s.group_id = $group_id
|
WHERE s.end_user_id = $end_user_id
|
||||||
AND s.emotion_type IS NOT NULL
|
AND s.emotion_type IS NOT NULL
|
||||||
AND s.created_at >= $start_date
|
AND s.created_at >= $start_date
|
||||||
RETURN s.id as statement_id,
|
RETURN s.id as statement_id,
|
||||||
|
|||||||
@@ -44,9 +44,7 @@ async def save_entities_and_relationships(
|
|||||||
'created_at': edge.created_at.isoformat(),
|
'created_at': edge.created_at.isoformat(),
|
||||||
'expired_at': edge.expired_at.isoformat(),
|
'expired_at': edge.expired_at.isoformat(),
|
||||||
'run_id': edge.run_id,
|
'run_id': edge.run_id,
|
||||||
'group_id': edge.group_id,
|
'end_user_id': edge.end_user_id,
|
||||||
'user_id': edge.user_id,
|
|
||||||
'apply_id': edge.apply_id,
|
|
||||||
}
|
}
|
||||||
all_relationships.append(relationship)
|
all_relationships.append(relationship)
|
||||||
|
|
||||||
@@ -101,9 +99,7 @@ async def save_statement_chunk_edges(
|
|||||||
"id": edge.id,
|
"id": edge.id,
|
||||||
"source": edge.source,
|
"source": edge.source,
|
||||||
"target": edge.target,
|
"target": edge.target,
|
||||||
"group_id": edge.group_id,
|
"end_user_id": edge.end_user_id,
|
||||||
"user_id": edge.user_id,
|
|
||||||
"apply_id": edge.apply_id,
|
|
||||||
"run_id": edge.run_id,
|
"run_id": edge.run_id,
|
||||||
"created_at": edge.created_at.isoformat() if edge.created_at else None,
|
"created_at": edge.created_at.isoformat() if edge.created_at else None,
|
||||||
"expired_at": edge.expired_at.isoformat() if edge.expired_at else None,
|
"expired_at": edge.expired_at.isoformat() if edge.expired_at else None,
|
||||||
@@ -132,9 +128,7 @@ async def save_statement_entity_edges(
|
|||||||
edge_data = {
|
edge_data = {
|
||||||
"source": edge.source,
|
"source": edge.source,
|
||||||
"target": edge.target,
|
"target": edge.target,
|
||||||
"group_id": edge.group_id,
|
"end_user_id": edge.end_user_id,
|
||||||
"user_id": edge.user_id,
|
|
||||||
"apply_id": edge.apply_id,
|
|
||||||
"run_id": edge.run_id,
|
"run_id": edge.run_id,
|
||||||
"connect_strength": edge.connect_strength,
|
"connect_strength": edge.connect_strength,
|
||||||
"created_at": edge.created_at.isoformat() if edge.created_at else None,
|
"created_at": edge.created_at.isoformat() if edge.created_at else None,
|
||||||
|
|||||||
@@ -33,7 +33,7 @@ async def _update_activation_values_batch(
|
|||||||
connector: Neo4jConnector,
|
connector: Neo4jConnector,
|
||||||
nodes: List[Dict[str, Any]],
|
nodes: List[Dict[str, Any]],
|
||||||
node_label: str,
|
node_label: str,
|
||||||
group_id: Optional[str] = None,
|
end_user_id: Optional[str] = None,
|
||||||
max_retries: int = 3
|
max_retries: int = 3
|
||||||
) -> List[Dict[str, Any]]:
|
) -> List[Dict[str, Any]]:
|
||||||
"""
|
"""
|
||||||
@@ -46,7 +46,7 @@ async def _update_activation_values_batch(
|
|||||||
connector: Neo4j连接器
|
connector: Neo4j连接器
|
||||||
nodes: 节点列表,每个节点必须包含 'id' 字段
|
nodes: 节点列表,每个节点必须包含 'id' 字段
|
||||||
node_label: 节点标签(Statement, ExtractedEntity, MemorySummary)
|
node_label: 节点标签(Statement, ExtractedEntity, MemorySummary)
|
||||||
group_id: 组ID(可选)
|
end_user_id: 组ID(可选)
|
||||||
max_retries: 最大重试次数
|
max_retries: 最大重试次数
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@@ -97,7 +97,7 @@ async def _update_activation_values_batch(
|
|||||||
updated_nodes = await access_manager.record_batch_access(
|
updated_nodes = await access_manager.record_batch_access(
|
||||||
node_ids=unique_node_ids,
|
node_ids=unique_node_ids,
|
||||||
node_label=node_label,
|
node_label=node_label,
|
||||||
group_id=group_id
|
end_user_id=end_user_id
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
@@ -118,7 +118,7 @@ async def _update_activation_values_batch(
|
|||||||
async def _update_search_results_activation(
|
async def _update_search_results_activation(
|
||||||
connector: Neo4jConnector,
|
connector: Neo4jConnector,
|
||||||
results: Dict[str, List[Dict[str, Any]]],
|
results: Dict[str, List[Dict[str, Any]]],
|
||||||
group_id: Optional[str] = None
|
end_user_id: Optional[str] = None
|
||||||
) -> Dict[str, List[Dict[str, Any]]]:
|
) -> Dict[str, List[Dict[str, Any]]]:
|
||||||
"""
|
"""
|
||||||
更新搜索结果中所有知识节点的激活值
|
更新搜索结果中所有知识节点的激活值
|
||||||
@@ -129,7 +129,7 @@ async def _update_search_results_activation(
|
|||||||
Args:
|
Args:
|
||||||
connector: Neo4j连接器
|
connector: Neo4j连接器
|
||||||
results: 搜索结果字典,包含不同类型节点的列表
|
results: 搜索结果字典,包含不同类型节点的列表
|
||||||
group_id: 组ID(可选)
|
end_user_id: 组ID(可选)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Dict[str, List[Dict[str, Any]]]: 更新后的搜索结果
|
Dict[str, List[Dict[str, Any]]]: 更新后的搜索结果
|
||||||
@@ -152,7 +152,7 @@ async def _update_search_results_activation(
|
|||||||
connector=connector,
|
connector=connector,
|
||||||
nodes=results[key],
|
nodes=results[key],
|
||||||
node_label=label,
|
node_label=label,
|
||||||
group_id=group_id
|
end_user_id=end_user_id
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
update_keys.append(key)
|
update_keys.append(key)
|
||||||
@@ -218,7 +218,7 @@ async def _update_search_results_activation(
|
|||||||
async def search_graph(
|
async def search_graph(
|
||||||
connector: Neo4jConnector,
|
connector: Neo4jConnector,
|
||||||
q: str,
|
q: str,
|
||||||
group_id: Optional[str] = None,
|
end_user_id: Optional[str] = None,
|
||||||
limit: int = 50,
|
limit: int = 50,
|
||||||
include: List[str] = None,
|
include: List[str] = None,
|
||||||
) -> Dict[str, List[Dict[str, Any]]]:
|
) -> Dict[str, List[Dict[str, Any]]]:
|
||||||
@@ -236,7 +236,7 @@ async def search_graph(
|
|||||||
Args:
|
Args:
|
||||||
connector: Neo4j connector
|
connector: Neo4j connector
|
||||||
q: Query text
|
q: Query text
|
||||||
group_id: Optional group filter
|
end_user_id: Optional group filter
|
||||||
limit: Max results per category
|
limit: Max results per category
|
||||||
include: List of categories to search (default: all)
|
include: List of categories to search (default: all)
|
||||||
|
|
||||||
@@ -254,7 +254,7 @@ async def search_graph(
|
|||||||
tasks.append(connector.execute_query(
|
tasks.append(connector.execute_query(
|
||||||
SEARCH_STATEMENTS_BY_KEYWORD,
|
SEARCH_STATEMENTS_BY_KEYWORD,
|
||||||
q=q,
|
q=q,
|
||||||
group_id=group_id,
|
end_user_id=end_user_id,
|
||||||
limit=limit,
|
limit=limit,
|
||||||
))
|
))
|
||||||
task_keys.append("statements")
|
task_keys.append("statements")
|
||||||
@@ -263,7 +263,7 @@ async def search_graph(
|
|||||||
tasks.append(connector.execute_query(
|
tasks.append(connector.execute_query(
|
||||||
SEARCH_ENTITIES_BY_NAME,
|
SEARCH_ENTITIES_BY_NAME,
|
||||||
q=q,
|
q=q,
|
||||||
group_id=group_id,
|
end_user_id=end_user_id,
|
||||||
limit=limit,
|
limit=limit,
|
||||||
))
|
))
|
||||||
task_keys.append("entities")
|
task_keys.append("entities")
|
||||||
@@ -272,7 +272,7 @@ async def search_graph(
|
|||||||
tasks.append(connector.execute_query(
|
tasks.append(connector.execute_query(
|
||||||
SEARCH_CHUNKS_BY_CONTENT,
|
SEARCH_CHUNKS_BY_CONTENT,
|
||||||
q=q,
|
q=q,
|
||||||
group_id=group_id,
|
end_user_id=end_user_id,
|
||||||
limit=limit,
|
limit=limit,
|
||||||
))
|
))
|
||||||
task_keys.append("chunks")
|
task_keys.append("chunks")
|
||||||
@@ -281,7 +281,7 @@ async def search_graph(
|
|||||||
tasks.append(connector.execute_query(
|
tasks.append(connector.execute_query(
|
||||||
SEARCH_MEMORY_SUMMARIES_BY_KEYWORD,
|
SEARCH_MEMORY_SUMMARIES_BY_KEYWORD,
|
||||||
q=q,
|
q=q,
|
||||||
group_id=group_id,
|
end_user_id=end_user_id,
|
||||||
limit=limit,
|
limit=limit,
|
||||||
))
|
))
|
||||||
task_keys.append("summaries")
|
task_keys.append("summaries")
|
||||||
@@ -308,7 +308,7 @@ async def search_graph(
|
|||||||
results = await _update_search_results_activation(
|
results = await _update_search_results_activation(
|
||||||
connector=connector,
|
connector=connector,
|
||||||
results=results,
|
results=results,
|
||||||
group_id=group_id
|
end_user_id=end_user_id
|
||||||
)
|
)
|
||||||
|
|
||||||
return results
|
return results
|
||||||
@@ -318,7 +318,7 @@ async def search_graph_by_embedding(
|
|||||||
connector: Neo4jConnector,
|
connector: Neo4jConnector,
|
||||||
embedder_client,
|
embedder_client,
|
||||||
query_text: str,
|
query_text: str,
|
||||||
group_id: Optional[str] = None,
|
end_user_id: Optional[str] = None,
|
||||||
limit: int = 50,
|
limit: int = 50,
|
||||||
include: List[str] = ["statements", "chunks", "entities","summaries"],
|
include: List[str] = ["statements", "chunks", "entities","summaries"],
|
||||||
) -> Dict[str, List[Dict[str, Any]]]:
|
) -> Dict[str, List[Dict[str, Any]]]:
|
||||||
@@ -330,7 +330,7 @@ async def search_graph_by_embedding(
|
|||||||
|
|
||||||
- Computes query embedding with the provided embedder_client
|
- Computes query embedding with the provided embedder_client
|
||||||
- Ranks by cosine similarity in Cypher
|
- Ranks by cosine similarity in Cypher
|
||||||
- Filters by group_id if provided
|
- Filters by end_user_id if provided
|
||||||
- Returns up to 'limit' per included type
|
- Returns up to 'limit' per included type
|
||||||
"""
|
"""
|
||||||
import time
|
import time
|
||||||
@@ -354,7 +354,7 @@ async def search_graph_by_embedding(
|
|||||||
tasks.append(connector.execute_query(
|
tasks.append(connector.execute_query(
|
||||||
STATEMENT_EMBEDDING_SEARCH,
|
STATEMENT_EMBEDDING_SEARCH,
|
||||||
embedding=embedding,
|
embedding=embedding,
|
||||||
group_id=group_id,
|
end_user_id=end_user_id,
|
||||||
limit=limit,
|
limit=limit,
|
||||||
))
|
))
|
||||||
task_keys.append("statements")
|
task_keys.append("statements")
|
||||||
@@ -364,7 +364,7 @@ async def search_graph_by_embedding(
|
|||||||
tasks.append(connector.execute_query(
|
tasks.append(connector.execute_query(
|
||||||
CHUNK_EMBEDDING_SEARCH,
|
CHUNK_EMBEDDING_SEARCH,
|
||||||
embedding=embedding,
|
embedding=embedding,
|
||||||
group_id=group_id,
|
end_user_id=end_user_id,
|
||||||
limit=limit,
|
limit=limit,
|
||||||
))
|
))
|
||||||
task_keys.append("chunks")
|
task_keys.append("chunks")
|
||||||
@@ -374,7 +374,7 @@ async def search_graph_by_embedding(
|
|||||||
tasks.append(connector.execute_query(
|
tasks.append(connector.execute_query(
|
||||||
ENTITY_EMBEDDING_SEARCH,
|
ENTITY_EMBEDDING_SEARCH,
|
||||||
embedding=embedding,
|
embedding=embedding,
|
||||||
group_id=group_id,
|
end_user_id=end_user_id,
|
||||||
limit=limit,
|
limit=limit,
|
||||||
))
|
))
|
||||||
task_keys.append("entities")
|
task_keys.append("entities")
|
||||||
@@ -384,7 +384,7 @@ async def search_graph_by_embedding(
|
|||||||
tasks.append(connector.execute_query(
|
tasks.append(connector.execute_query(
|
||||||
MEMORY_SUMMARY_EMBEDDING_SEARCH,
|
MEMORY_SUMMARY_EMBEDDING_SEARCH,
|
||||||
embedding=embedding,
|
embedding=embedding,
|
||||||
group_id=group_id,
|
end_user_id=end_user_id,
|
||||||
limit=limit,
|
limit=limit,
|
||||||
))
|
))
|
||||||
task_keys.append("summaries")
|
task_keys.append("summaries")
|
||||||
@@ -421,7 +421,7 @@ async def search_graph_by_embedding(
|
|||||||
results = await _update_search_results_activation(
|
results = await _update_search_results_activation(
|
||||||
connector=connector,
|
connector=connector,
|
||||||
results=results,
|
results=results,
|
||||||
group_id=group_id
|
end_user_id=end_user_id
|
||||||
)
|
)
|
||||||
update_time = time.time() - update_start
|
update_time = time.time() - update_start
|
||||||
print(f"[PERF] Activation value updates took: {update_time:.4f}s")
|
print(f"[PERF] Activation value updates took: {update_time:.4f}s")
|
||||||
@@ -429,7 +429,7 @@ async def search_graph_by_embedding(
|
|||||||
return results
|
return results
|
||||||
async def get_dedup_candidates_for_entities( # 适配新版查询:使用全文索引按名称检索候选实体
|
async def get_dedup_candidates_for_entities( # 适配新版查询:使用全文索引按名称检索候选实体
|
||||||
connector: Neo4jConnector,
|
connector: Neo4jConnector,
|
||||||
group_id: str,
|
end_user_id: str,
|
||||||
entities: List[Dict[str, Any]],
|
entities: List[Dict[str, Any]],
|
||||||
use_contains_fallback: bool = True,
|
use_contains_fallback: bool = True,
|
||||||
batch_size: int = 500,
|
batch_size: int = 500,
|
||||||
@@ -437,7 +437,7 @@ async def get_dedup_candidates_for_entities( # 适配新版查询:使用全
|
|||||||
) -> Dict[str, List[Dict[str, Any]]]:
|
) -> Dict[str, List[Dict[str, Any]]]:
|
||||||
"""
|
"""
|
||||||
为第二层去重消歧批量检索候选实体(适配新版 cypher_queries):
|
为第二层去重消歧批量检索候选实体(适配新版 cypher_queries):
|
||||||
- 使用全文索引查询 `SEARCH_ENTITIES_BY_NAME` 按 (group_id, name) 检索候选;
|
- 使用全文索引查询 `SEARCH_ENTITIES_BY_NAME` 按 (end_user_id, name) 检索候选;
|
||||||
- 保留并发控制与返回结构(incoming_id -> [db_entity_props...]);
|
- 保留并发控制与返回结构(incoming_id -> [db_entity_props...]);
|
||||||
- 若提供 `entity_type`,在本地对返回结果做类型过滤;
|
- 若提供 `entity_type`,在本地对返回结果做类型过滤;
|
||||||
- `use_contains_fallback` 保留形参以兼容,必要时可扩展二次查询策略。
|
- `use_contains_fallback` 保留形参以兼容,必要时可扩展二次查询策略。
|
||||||
@@ -461,7 +461,7 @@ async def get_dedup_candidates_for_entities( # 适配新版查询:使用全
|
|||||||
rows = await connector.execute_query(
|
rows = await connector.execute_query(
|
||||||
SEARCH_ENTITIES_BY_NAME,
|
SEARCH_ENTITIES_BY_NAME,
|
||||||
q=name,
|
q=name,
|
||||||
group_id=group_id,
|
end_user_id=end_user_id,
|
||||||
limit=100,
|
limit=100,
|
||||||
)
|
)
|
||||||
except Exception:
|
except Exception:
|
||||||
@@ -485,7 +485,7 @@ async def get_dedup_candidates_for_entities( # 适配新版查询:使用全
|
|||||||
rows = await connector.execute_query(
|
rows = await connector.execute_query(
|
||||||
SEARCH_ENTITIES_BY_NAME,
|
SEARCH_ENTITIES_BY_NAME,
|
||||||
q=name.lower(),
|
q=name.lower(),
|
||||||
group_id=group_id,
|
end_user_id=end_user_id,
|
||||||
limit=100,
|
limit=100,
|
||||||
)
|
)
|
||||||
for r in rows:
|
for r in rows:
|
||||||
@@ -516,9 +516,7 @@ async def get_dedup_candidates_for_entities( # 适配新版查询:使用全
|
|||||||
async def search_graph_by_keyword_temporal(
|
async def search_graph_by_keyword_temporal(
|
||||||
connector: Neo4jConnector,
|
connector: Neo4jConnector,
|
||||||
query_text: str,
|
query_text: str,
|
||||||
group_id: Optional[str] = None,
|
end_user_id: Optional[str] = None,
|
||||||
apply_id: Optional[str] = None,
|
|
||||||
user_id: Optional[str] = None,
|
|
||||||
start_date: Optional[str] = None,
|
start_date: Optional[str] = None,
|
||||||
end_date: Optional[str] = None,
|
end_date: Optional[str] = None,
|
||||||
valid_date: Optional[str] = None,
|
valid_date: Optional[str] = None,
|
||||||
@@ -531,7 +529,7 @@ async def search_graph_by_keyword_temporal(
|
|||||||
INTEGRATED: Updates activation values for Statement nodes before returning results
|
INTEGRATED: Updates activation values for Statement nodes before returning results
|
||||||
|
|
||||||
- Matches statements containing query_text created between start_date and end_date
|
- Matches statements containing query_text created between start_date and end_date
|
||||||
- Optionally filters by group_id, apply_id, user_id
|
- Optionally filters by end_user_id, apply_id, user_id
|
||||||
- Returns up to 'limit' statements
|
- Returns up to 'limit' statements
|
||||||
"""
|
"""
|
||||||
if not query_text:
|
if not query_text:
|
||||||
@@ -540,9 +538,7 @@ async def search_graph_by_keyword_temporal(
|
|||||||
statements = await connector.execute_query(
|
statements = await connector.execute_query(
|
||||||
SEARCH_STATEMENTS_BY_KEYWORD_TEMPORAL,
|
SEARCH_STATEMENTS_BY_KEYWORD_TEMPORAL,
|
||||||
q=query_text,
|
q=query_text,
|
||||||
group_id=group_id,
|
end_user_id=end_user_id,
|
||||||
apply_id=apply_id,
|
|
||||||
user_id=user_id,
|
|
||||||
start_date=start_date,
|
start_date=start_date,
|
||||||
end_date=end_date,
|
end_date=end_date,
|
||||||
valid_date=valid_date,
|
valid_date=valid_date,
|
||||||
@@ -556,7 +552,7 @@ async def search_graph_by_keyword_temporal(
|
|||||||
results = await _update_search_results_activation(
|
results = await _update_search_results_activation(
|
||||||
connector=connector,
|
connector=connector,
|
||||||
results=results,
|
results=results,
|
||||||
group_id=group_id
|
end_user_id=end_user_id
|
||||||
)
|
)
|
||||||
|
|
||||||
return results
|
return results
|
||||||
@@ -564,9 +560,7 @@ async def search_graph_by_keyword_temporal(
|
|||||||
|
|
||||||
async def search_graph_by_temporal(
|
async def search_graph_by_temporal(
|
||||||
connector: Neo4jConnector,
|
connector: Neo4jConnector,
|
||||||
group_id: Optional[str] = None,
|
end_user_id: Optional[str] = None,
|
||||||
apply_id: Optional[str] = None,
|
|
||||||
user_id: Optional[str] = None,
|
|
||||||
start_date: Optional[str] = None,
|
start_date: Optional[str] = None,
|
||||||
end_date: Optional[str] = None,
|
end_date: Optional[str] = None,
|
||||||
valid_date: Optional[str] = None,
|
valid_date: Optional[str] = None,
|
||||||
@@ -579,14 +573,12 @@ async def search_graph_by_temporal(
|
|||||||
INTEGRATED: Updates activation values for Statement nodes before returning results
|
INTEGRATED: Updates activation values for Statement nodes before returning results
|
||||||
|
|
||||||
- Matches statements created between start_date and end_date
|
- Matches statements created between start_date and end_date
|
||||||
- Optionally filters by group_id, apply_id, user_id
|
- Optionally filters by end_user_id
|
||||||
- Returns up to 'limit' statements
|
- Returns up to 'limit' statements
|
||||||
"""
|
"""
|
||||||
statements = await connector.execute_query(
|
statements = await connector.execute_query(
|
||||||
SEARCH_STATEMENTS_BY_TEMPORAL,
|
SEARCH_STATEMENTS_BY_TEMPORAL,
|
||||||
group_id=group_id,
|
end_user_id=end_user_id,
|
||||||
apply_id=apply_id,
|
|
||||||
user_id=user_id,
|
|
||||||
start_date=start_date,
|
start_date=start_date,
|
||||||
end_date=end_date,
|
end_date=end_date,
|
||||||
valid_date=valid_date,
|
valid_date=valid_date,
|
||||||
@@ -595,7 +587,7 @@ async def search_graph_by_temporal(
|
|||||||
)
|
)
|
||||||
|
|
||||||
print(f"查询语句为:\n{SEARCH_STATEMENTS_BY_TEMPORAL}")
|
print(f"查询语句为:\n{SEARCH_STATEMENTS_BY_TEMPORAL}")
|
||||||
print(f"查询参数为:\n{{group_id: {group_id}, apply_id: {apply_id}, user_id: {user_id}, start_date: {start_date}, end_date: {end_date}, valid_date: {valid_date}, invalid_date: {invalid_date}, limit: {limit}}}")
|
print(f"查询参数为:\n{{end_user_id: {end_user_id}, start_date: {start_date}, end_date: {end_date}, valid_date: {valid_date}, invalid_date: {invalid_date}, limit: {limit}}}")
|
||||||
print(f"查询结果为:\n{statements}")
|
print(f"查询结果为:\n{statements}")
|
||||||
|
|
||||||
# 更新 Statement 节点的激活值
|
# 更新 Statement 节点的激活值
|
||||||
@@ -603,7 +595,7 @@ async def search_graph_by_temporal(
|
|||||||
results = await _update_search_results_activation(
|
results = await _update_search_results_activation(
|
||||||
connector=connector,
|
connector=connector,
|
||||||
results=results,
|
results=results,
|
||||||
group_id=group_id
|
end_user_id=end_user_id
|
||||||
)
|
)
|
||||||
|
|
||||||
return results
|
return results
|
||||||
@@ -612,14 +604,14 @@ async def search_graph_by_temporal(
|
|||||||
async def search_graph_by_dialog_id(
|
async def search_graph_by_dialog_id(
|
||||||
connector: Neo4jConnector,
|
connector: Neo4jConnector,
|
||||||
dialog_id: str,
|
dialog_id: str,
|
||||||
group_id: Optional[str] = None,
|
end_user_id: Optional[str] = None,
|
||||||
limit: int = 1,
|
limit: int = 1,
|
||||||
) -> Dict[str, List[Dict[str, Any]]]:
|
) -> Dict[str, List[Dict[str, Any]]]:
|
||||||
"""
|
"""
|
||||||
Temporal search across Dialogues.
|
Temporal search across Dialogues.
|
||||||
|
|
||||||
- Matches dialogues with dialog_id
|
- Matches dialogues with dialog_id
|
||||||
- Optionally filters by group_id
|
- Optionally filters by end_user_id
|
||||||
- Returns up to 'limit' dialogues
|
- Returns up to 'limit' dialogues
|
||||||
"""
|
"""
|
||||||
if not dialog_id:
|
if not dialog_id:
|
||||||
@@ -628,7 +620,7 @@ async def search_graph_by_dialog_id(
|
|||||||
|
|
||||||
dialogues = await connector.execute_query(
|
dialogues = await connector.execute_query(
|
||||||
SEARCH_DIALOGUE_BY_DIALOG_ID,
|
SEARCH_DIALOGUE_BY_DIALOG_ID,
|
||||||
group_id=group_id,
|
end_user_id=end_user_id,
|
||||||
dialog_id=dialog_id,
|
dialog_id=dialog_id,
|
||||||
limit=limit,
|
limit=limit,
|
||||||
)
|
)
|
||||||
@@ -638,7 +630,7 @@ async def search_graph_by_dialog_id(
|
|||||||
async def search_graph_by_chunk_id(
|
async def search_graph_by_chunk_id(
|
||||||
connector: Neo4jConnector,
|
connector: Neo4jConnector,
|
||||||
chunk_id : str,
|
chunk_id : str,
|
||||||
group_id: Optional[str] = None,
|
end_user_id: Optional[str] = None,
|
||||||
limit: int = 1,
|
limit: int = 1,
|
||||||
) -> Dict[str, List[Dict[str, Any]]]:
|
) -> Dict[str, List[Dict[str, Any]]]:
|
||||||
if not chunk_id:
|
if not chunk_id:
|
||||||
@@ -646,7 +638,7 @@ async def search_graph_by_chunk_id(
|
|||||||
return {"chunks": []}
|
return {"chunks": []}
|
||||||
chunks = await connector.execute_query(
|
chunks = await connector.execute_query(
|
||||||
SEARCH_CHUNK_BY_CHUNK_ID,
|
SEARCH_CHUNK_BY_CHUNK_ID,
|
||||||
group_id=group_id,
|
end_user_id=end_user_id,
|
||||||
chunk_id=chunk_id,
|
chunk_id=chunk_id,
|
||||||
limit=limit,
|
limit=limit,
|
||||||
)
|
)
|
||||||
@@ -655,9 +647,9 @@ async def search_graph_by_chunk_id(
|
|||||||
|
|
||||||
async def search_graph_by_created_at(
|
async def search_graph_by_created_at(
|
||||||
connector: Neo4jConnector,
|
connector: Neo4jConnector,
|
||||||
group_id: Optional[str] = None,
|
end_user_id: Optional[str] = None,
|
||||||
apply_id: Optional[str] = None,
|
|
||||||
user_id: Optional[str] = None,
|
|
||||||
created_at: Optional[str] = None,
|
created_at: Optional[str] = None,
|
||||||
limit: int = 1,
|
limit: int = 1,
|
||||||
) -> Dict[str, List[Dict[str, Any]]]:
|
) -> Dict[str, List[Dict[str, Any]]]:
|
||||||
@@ -667,20 +659,20 @@ async def search_graph_by_created_at(
|
|||||||
INTEGRATED: Updates activation values for Statement nodes before returning results
|
INTEGRATED: Updates activation values for Statement nodes before returning results
|
||||||
|
|
||||||
- Matches statements created at created_at
|
- Matches statements created at created_at
|
||||||
- Optionally filters by group_id, apply_id, user_id
|
- Optionally filters by end_user_id, apply_id, user_id
|
||||||
- Returns up to 'limit' statements
|
- Returns up to 'limit' statements
|
||||||
"""
|
"""
|
||||||
statements = await connector.execute_query(
|
statements = await connector.execute_query(
|
||||||
SEARCH_STATEMENTS_BY_CREATED_AT,
|
SEARCH_STATEMENTS_BY_CREATED_AT,
|
||||||
group_id=group_id,
|
end_user_id=end_user_id,
|
||||||
apply_id=apply_id,
|
|
||||||
user_id=user_id,
|
|
||||||
created_at=created_at,
|
created_at=created_at,
|
||||||
limit=limit,
|
limit=limit,
|
||||||
)
|
)
|
||||||
|
|
||||||
print(f"查询语句为:\n{SEARCH_STATEMENTS_BY_CREATED_AT}")
|
print(f"查询语句为:\n{SEARCH_STATEMENTS_BY_CREATED_AT}")
|
||||||
print(f"查询参数为:\n{{group_id: {group_id}, apply_id: {apply_id}, user_id: {user_id}, created_at: {created_at}, limit: {limit}}}")
|
print(f"查询参数为:\n{{end_user_id: {end_user_id} created_at: {created_at}, limit: {limit}}}")
|
||||||
print(f"查询结果为:\n{statements}")
|
print(f"查询结果为:\n{statements}")
|
||||||
|
|
||||||
# 更新 Statement 节点的激活值
|
# 更新 Statement 节点的激活值
|
||||||
@@ -688,16 +680,16 @@ async def search_graph_by_created_at(
|
|||||||
results = await _update_search_results_activation(
|
results = await _update_search_results_activation(
|
||||||
connector=connector,
|
connector=connector,
|
||||||
results=results,
|
results=results,
|
||||||
group_id=group_id
|
end_user_id=end_user_id
|
||||||
)
|
)
|
||||||
|
|
||||||
return results
|
return results
|
||||||
|
|
||||||
async def search_graph_by_valid_at(
|
async def search_graph_by_valid_at(
|
||||||
connector: Neo4jConnector,
|
connector: Neo4jConnector,
|
||||||
group_id: Optional[str] = None,
|
end_user_id: Optional[str] = None,
|
||||||
apply_id: Optional[str] = None,
|
|
||||||
user_id: Optional[str] = None,
|
|
||||||
valid_at: Optional[str] = None,
|
valid_at: Optional[str] = None,
|
||||||
limit: int = 1,
|
limit: int = 1,
|
||||||
) -> Dict[str, List[Dict[str, Any]]]:
|
) -> Dict[str, List[Dict[str, Any]]]:
|
||||||
@@ -707,20 +699,20 @@ async def search_graph_by_valid_at(
|
|||||||
INTEGRATED: Updates activation values for Statement nodes before returning results
|
INTEGRATED: Updates activation values for Statement nodes before returning results
|
||||||
|
|
||||||
- Matches statements valid at valid_at
|
- Matches statements valid at valid_at
|
||||||
- Optionally filters by group_id, apply_id, user_id
|
- Optionally filters by end_user_id, apply_id, user_id
|
||||||
- Returns up to 'limit' statements
|
- Returns up to 'limit' statements
|
||||||
"""
|
"""
|
||||||
statements = await connector.execute_query(
|
statements = await connector.execute_query(
|
||||||
SEARCH_STATEMENTS_BY_VALID_AT,
|
SEARCH_STATEMENTS_BY_VALID_AT,
|
||||||
group_id=group_id,
|
end_user_id=end_user_id,
|
||||||
apply_id=apply_id,
|
|
||||||
user_id=user_id,
|
|
||||||
valid_at=valid_at,
|
valid_at=valid_at,
|
||||||
limit=limit,
|
limit=limit,
|
||||||
)
|
)
|
||||||
|
|
||||||
print(f"查询语句为:\n{SEARCH_STATEMENTS_BY_VALID_AT}")
|
print(f"查询语句为:\n{SEARCH_STATEMENTS_BY_VALID_AT}")
|
||||||
print(f"查询参数为:\n{{group_id: {group_id}, apply_id: {apply_id}, user_id: {user_id}, valid_at: {valid_at}, limit: {limit}}}")
|
print(f"查询参数为:\n{{end_user_id: {end_user_id}, valid_at: {valid_at}, limit: {limit}}}")
|
||||||
print(f"查询结果为:\n{statements}")
|
print(f"查询结果为:\n{statements}")
|
||||||
|
|
||||||
# 更新 Statement 节点的激活值
|
# 更新 Statement 节点的激活值
|
||||||
@@ -728,16 +720,16 @@ async def search_graph_by_valid_at(
|
|||||||
results = await _update_search_results_activation(
|
results = await _update_search_results_activation(
|
||||||
connector=connector,
|
connector=connector,
|
||||||
results=results,
|
results=results,
|
||||||
group_id=group_id
|
end_user_id=end_user_id
|
||||||
)
|
)
|
||||||
|
|
||||||
return results
|
return results
|
||||||
|
|
||||||
async def search_graph_g_created_at(
|
async def search_graph_g_created_at(
|
||||||
connector: Neo4jConnector,
|
connector: Neo4jConnector,
|
||||||
group_id: Optional[str] = None,
|
end_user_id: Optional[str] = None,
|
||||||
apply_id: Optional[str] = None,
|
|
||||||
user_id: Optional[str] = None,
|
|
||||||
created_at: Optional[str] = None,
|
created_at: Optional[str] = None,
|
||||||
limit: int = 1,
|
limit: int = 1,
|
||||||
) -> Dict[str, List[Dict[str, Any]]]:
|
) -> Dict[str, List[Dict[str, Any]]]:
|
||||||
@@ -747,20 +739,20 @@ async def search_graph_g_created_at(
|
|||||||
INTEGRATED: Updates activation values for Statement nodes before returning results
|
INTEGRATED: Updates activation values for Statement nodes before returning results
|
||||||
|
|
||||||
- Matches statements created at created_at
|
- Matches statements created at created_at
|
||||||
- Optionally filters by group_id, apply_id, user_id
|
- Optionally filters by end_user_id, apply_id, user_id
|
||||||
- Returns up to 'limit' statements
|
- Returns up to 'limit' statements
|
||||||
"""
|
"""
|
||||||
statements = await connector.execute_query(
|
statements = await connector.execute_query(
|
||||||
SEARCH_STATEMENTS_G_CREATED_AT,
|
SEARCH_STATEMENTS_G_CREATED_AT,
|
||||||
group_id=group_id,
|
end_user_id=end_user_id,
|
||||||
apply_id=apply_id,
|
|
||||||
user_id=user_id,
|
|
||||||
created_at=created_at,
|
created_at=created_at,
|
||||||
limit=limit,
|
limit=limit,
|
||||||
)
|
)
|
||||||
|
|
||||||
print(f"查询语句为:\n{SEARCH_STATEMENTS_G_CREATED_AT}")
|
print(f"查询语句为:\n{SEARCH_STATEMENTS_G_CREATED_AT}")
|
||||||
print(f"查询参数为:\n{{group_id: {group_id}, apply_id: {apply_id}, user_id: {user_id}, created_at: {created_at}, limit: {limit}}}")
|
print(f"查询参数为:\n{{end_user_id: {end_user_id}, created_at: {created_at}, limit: {limit}}}")
|
||||||
print(f"查询结果为:\n{statements}")
|
print(f"查询结果为:\n{statements}")
|
||||||
|
|
||||||
# 更新 Statement 节点的激活值
|
# 更新 Statement 节点的激活值
|
||||||
@@ -768,16 +760,16 @@ async def search_graph_g_created_at(
|
|||||||
results = await _update_search_results_activation(
|
results = await _update_search_results_activation(
|
||||||
connector=connector,
|
connector=connector,
|
||||||
results=results,
|
results=results,
|
||||||
group_id=group_id
|
end_user_id=end_user_id
|
||||||
)
|
)
|
||||||
|
|
||||||
return results
|
return results
|
||||||
|
|
||||||
async def search_graph_g_valid_at(
|
async def search_graph_g_valid_at(
|
||||||
connector: Neo4jConnector,
|
connector: Neo4jConnector,
|
||||||
group_id: Optional[str] = None,
|
end_user_id: Optional[str] = None,
|
||||||
apply_id: Optional[str] = None,
|
|
||||||
user_id: Optional[str] = None,
|
|
||||||
valid_at: Optional[str] = None,
|
valid_at: Optional[str] = None,
|
||||||
limit: int = 1,
|
limit: int = 1,
|
||||||
) -> Dict[str, List[Dict[str, Any]]]:
|
) -> Dict[str, List[Dict[str, Any]]]:
|
||||||
@@ -787,20 +779,20 @@ async def search_graph_g_valid_at(
|
|||||||
INTEGRATED: Updates activation values for Statement nodes before returning results
|
INTEGRATED: Updates activation values for Statement nodes before returning results
|
||||||
|
|
||||||
- Matches statements valid at valid_at
|
- Matches statements valid at valid_at
|
||||||
- Optionally filters by group_id, apply_id, user_id
|
- Optionally filters by end_user_id, apply_id, user_id
|
||||||
- Returns up to 'limit' statements
|
- Returns up to 'limit' statements
|
||||||
"""
|
"""
|
||||||
statements = await connector.execute_query(
|
statements = await connector.execute_query(
|
||||||
SEARCH_STATEMENTS_G_VALID_AT,
|
SEARCH_STATEMENTS_G_VALID_AT,
|
||||||
group_id=group_id,
|
end_user_id=end_user_id,
|
||||||
apply_id=apply_id,
|
|
||||||
user_id=user_id,
|
|
||||||
valid_at=valid_at,
|
valid_at=valid_at,
|
||||||
limit=limit,
|
limit=limit,
|
||||||
)
|
)
|
||||||
|
|
||||||
print(f"查询语句为:\n{SEARCH_STATEMENTS_G_VALID_AT}")
|
print(f"查询语句为:\n{SEARCH_STATEMENTS_G_VALID_AT}")
|
||||||
print(f"查询参数为:\n{{group_id: {group_id}, apply_id: {apply_id}, user_id: {user_id}, valid_at: {valid_at}, limit: {limit}}}")
|
print(f"查询参数为:\n{{end_user_id: {end_user_id}, valid_at: {valid_at}, limit: {limit}}}")
|
||||||
print(f"查询结果为:\n{statements}")
|
print(f"查询结果为:\n{statements}")
|
||||||
|
|
||||||
# 更新 Statement 节点的激活值
|
# 更新 Statement 节点的激活值
|
||||||
@@ -808,16 +800,16 @@ async def search_graph_g_valid_at(
|
|||||||
results = await _update_search_results_activation(
|
results = await _update_search_results_activation(
|
||||||
connector=connector,
|
connector=connector,
|
||||||
results=results,
|
results=results,
|
||||||
group_id=group_id
|
end_user_id=end_user_id
|
||||||
)
|
)
|
||||||
|
|
||||||
return results
|
return results
|
||||||
|
|
||||||
async def search_graph_l_created_at(
|
async def search_graph_l_created_at(
|
||||||
connector: Neo4jConnector,
|
connector: Neo4jConnector,
|
||||||
group_id: Optional[str] = None,
|
end_user_id: Optional[str] = None,
|
||||||
apply_id: Optional[str] = None,
|
|
||||||
user_id: Optional[str] = None,
|
|
||||||
created_at: Optional[str] = None,
|
created_at: Optional[str] = None,
|
||||||
limit: int = 1,
|
limit: int = 1,
|
||||||
) -> Dict[str, List[Dict[str, Any]]]:
|
) -> Dict[str, List[Dict[str, Any]]]:
|
||||||
@@ -827,20 +819,20 @@ async def search_graph_l_created_at(
|
|||||||
INTEGRATED: Updates activation values for Statement nodes before returning results
|
INTEGRATED: Updates activation values for Statement nodes before returning results
|
||||||
|
|
||||||
- Matches statements created at created_at
|
- Matches statements created at created_at
|
||||||
- Optionally filters by group_id, apply_id, user_id
|
- Optionally filters by end_user_id, apply_id, user_id
|
||||||
- Returns up to 'limit' statements
|
- Returns up to 'limit' statements
|
||||||
"""
|
"""
|
||||||
statements = await connector.execute_query(
|
statements = await connector.execute_query(
|
||||||
SEARCH_STATEMENTS_L_CREATED_AT,
|
SEARCH_STATEMENTS_L_CREATED_AT,
|
||||||
group_id=group_id,
|
end_user_id=end_user_id,
|
||||||
apply_id=apply_id,
|
|
||||||
user_id=user_id,
|
|
||||||
created_at=created_at,
|
created_at=created_at,
|
||||||
limit=limit,
|
limit=limit,
|
||||||
)
|
)
|
||||||
|
|
||||||
print(f"查询语句为:\n{SEARCH_STATEMENTS_L_CREATED_AT}")
|
print(f"查询语句为:\n{SEARCH_STATEMENTS_L_CREATED_AT}")
|
||||||
print(f"查询参数为:\n{{group_id: {group_id}, apply_id: {apply_id}, user_id: {user_id}, created_at: {created_at}, limit: {limit}}}")
|
print(f"查询参数为:\n{{end_user_id: {end_user_id}, created_at: {created_at}, limit: {limit}}}")
|
||||||
print(f"查询结果为:\n{statements}")
|
print(f"查询结果为:\n{statements}")
|
||||||
|
|
||||||
# 更新 Statement 节点的激活值
|
# 更新 Statement 节点的激活值
|
||||||
@@ -848,16 +840,16 @@ async def search_graph_l_created_at(
|
|||||||
results = await _update_search_results_activation(
|
results = await _update_search_results_activation(
|
||||||
connector=connector,
|
connector=connector,
|
||||||
results=results,
|
results=results,
|
||||||
group_id=group_id
|
end_user_id=end_user_id
|
||||||
)
|
)
|
||||||
|
|
||||||
return results
|
return results
|
||||||
|
|
||||||
async def search_graph_l_valid_at(
|
async def search_graph_l_valid_at(
|
||||||
connector: Neo4jConnector,
|
connector: Neo4jConnector,
|
||||||
group_id: Optional[str] = None,
|
end_user_id: Optional[str] = None,
|
||||||
apply_id: Optional[str] = None,
|
|
||||||
user_id: Optional[str] = None,
|
|
||||||
valid_at: Optional[str] = None,
|
valid_at: Optional[str] = None,
|
||||||
limit: int = 1,
|
limit: int = 1,
|
||||||
) -> Dict[str, List[Dict[str, Any]]]:
|
) -> Dict[str, List[Dict[str, Any]]]:
|
||||||
@@ -867,20 +859,20 @@ async def search_graph_l_valid_at(
|
|||||||
INTEGRATED: Updates activation values for Statement nodes before returning results
|
INTEGRATED: Updates activation values for Statement nodes before returning results
|
||||||
|
|
||||||
- Matches statements valid at valid_at
|
- Matches statements valid at valid_at
|
||||||
- Optionally filters by group_id, apply_id, user_id
|
- Optionally filters by end_user_id, apply_id, user_id
|
||||||
- Returns up to 'limit' statements
|
- Returns up to 'limit' statements
|
||||||
"""
|
"""
|
||||||
statements = await connector.execute_query(
|
statements = await connector.execute_query(
|
||||||
SEARCH_STATEMENTS_L_VALID_AT,
|
SEARCH_STATEMENTS_L_VALID_AT,
|
||||||
group_id=group_id,
|
end_user_id=end_user_id,
|
||||||
apply_id=apply_id,
|
|
||||||
user_id=user_id,
|
|
||||||
valid_at=valid_at,
|
valid_at=valid_at,
|
||||||
limit=limit,
|
limit=limit,
|
||||||
)
|
)
|
||||||
|
|
||||||
print(f"查询语句为:\n{SEARCH_STATEMENTS_L_VALID_AT}")
|
print(f"查询语句为:\n{SEARCH_STATEMENTS_L_VALID_AT}")
|
||||||
print(f"查询参数为:\n{{group_id: {group_id}, apply_id: {apply_id}, user_id: {user_id}, valid_at: {valid_at}, limit: {limit}}}")
|
print(f"查询参数为:\n{{end_user_id: {end_user_id}, valid_at: {valid_at}, limit: {limit}}}")
|
||||||
print(f"查询结果为:\n{statements}")
|
print(f"查询结果为:\n{statements}")
|
||||||
|
|
||||||
# 更新 Statement 节点的激活值
|
# 更新 Statement 节点的激活值
|
||||||
@@ -888,7 +880,7 @@ async def search_graph_l_valid_at(
|
|||||||
results = await _update_search_results_activation(
|
results = await _update_search_results_activation(
|
||||||
connector=connector,
|
connector=connector,
|
||||||
results=results,
|
results=results,
|
||||||
group_id=group_id
|
end_user_id=end_user_id
|
||||||
)
|
)
|
||||||
|
|
||||||
return results
|
return results
|
||||||
|
|||||||
@@ -18,7 +18,7 @@ class MemorySummaryRepository(BaseNeo4jRepository):
|
|||||||
"""Memory Summary Repository
|
"""Memory Summary Repository
|
||||||
|
|
||||||
Manages CRUD operations for MemorySummary nodes.
|
Manages CRUD operations for MemorySummary nodes.
|
||||||
Provides methods to query summaries by group_id, user_id, and time ranges.
|
Provides methods to query summaries by end_user_id, user_id, and time ranges.
|
||||||
|
|
||||||
Attributes:
|
Attributes:
|
||||||
connector: Neo4j connector instance
|
connector: Neo4j connector instance
|
||||||
@@ -51,17 +51,17 @@ class MemorySummaryRepository(BaseNeo4jRepository):
|
|||||||
|
|
||||||
return dict(n)
|
return dict(n)
|
||||||
|
|
||||||
async def find_by_group_id(
|
async def find_by_end_user_id(
|
||||||
self,
|
self,
|
||||||
group_id: str,
|
end_user_id: str,
|
||||||
limit: int = 1000,
|
limit: int = 1000,
|
||||||
start_date: Optional[datetime] = None,
|
start_date: Optional[datetime] = None,
|
||||||
end_date: Optional[datetime] = None
|
end_date: Optional[datetime] = None
|
||||||
) -> List[Dict[str, Any]]:
|
) -> List[Dict[str, Any]]:
|
||||||
"""Query memory summaries by group_id
|
"""Query memory summaries by end_user_id
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
group_id: Group ID to filter by
|
end_user_id: Group ID to filter by
|
||||||
limit: Maximum number of results to return
|
limit: Maximum number of results to return
|
||||||
start_date: Optional start date filter
|
start_date: Optional start date filter
|
||||||
end_date: Optional end date filter
|
end_date: Optional end date filter
|
||||||
@@ -71,10 +71,10 @@ class MemorySummaryRepository(BaseNeo4jRepository):
|
|||||||
"""
|
"""
|
||||||
query = f"""
|
query = f"""
|
||||||
MATCH (n:{self.node_label})
|
MATCH (n:{self.node_label})
|
||||||
WHERE n.group_id = $group_id
|
WHERE n.end_user_id = $end_user_id
|
||||||
"""
|
"""
|
||||||
|
|
||||||
params = {"group_id": group_id, "limit": limit}
|
params = {"end_user_id": end_user_id, "limit": limit}
|
||||||
|
|
||||||
# Add date range filters if provided
|
# Add date range filters if provided
|
||||||
if start_date:
|
if start_date:
|
||||||
@@ -139,16 +139,16 @@ class MemorySummaryRepository(BaseNeo4jRepository):
|
|||||||
|
|
||||||
async def find_by_group_and_user(
|
async def find_by_group_and_user(
|
||||||
self,
|
self,
|
||||||
group_id: str,
|
end_user_id: str,
|
||||||
user_id: str,
|
user_id: str,
|
||||||
limit: int = 1000,
|
limit: int = 1000,
|
||||||
start_date: Optional[datetime] = None,
|
start_date: Optional[datetime] = None,
|
||||||
end_date: Optional[datetime] = None
|
end_date: Optional[datetime] = None
|
||||||
) -> List[Dict[str, Any]]:
|
) -> List[Dict[str, Any]]:
|
||||||
"""Query memory summaries by both group_id and user_id
|
"""Query memory summaries by both end_user_id and user_id
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
group_id: Group ID to filter by
|
end_user_id: Group ID to filter by
|
||||||
user_id: User ID to filter by
|
user_id: User ID to filter by
|
||||||
limit: Maximum number of results to return
|
limit: Maximum number of results to return
|
||||||
start_date: Optional start date filter
|
start_date: Optional start date filter
|
||||||
@@ -159,10 +159,10 @@ class MemorySummaryRepository(BaseNeo4jRepository):
|
|||||||
"""
|
"""
|
||||||
query = f"""
|
query = f"""
|
||||||
MATCH (n:{self.node_label})
|
MATCH (n:{self.node_label})
|
||||||
WHERE n.group_id = $group_id AND n.user_id = $user_id
|
WHERE n.end_user_id = $end_user_id AND n.user_id = $user_id
|
||||||
"""
|
"""
|
||||||
|
|
||||||
params = {"group_id": group_id, "user_id": user_id, "limit": limit}
|
params = {"end_user_id": end_user_id, "user_id": user_id, "limit": limit}
|
||||||
|
|
||||||
# Add date range filters if provided
|
# Add date range filters if provided
|
||||||
if start_date:
|
if start_date:
|
||||||
@@ -184,14 +184,14 @@ class MemorySummaryRepository(BaseNeo4jRepository):
|
|||||||
|
|
||||||
async def find_recent_summaries(
|
async def find_recent_summaries(
|
||||||
self,
|
self,
|
||||||
group_id: str,
|
end_user_id: str,
|
||||||
days: int = 7,
|
days: int = 7,
|
||||||
limit: int = 1000
|
limit: int = 1000
|
||||||
) -> List[Dict[str, Any]]:
|
) -> List[Dict[str, Any]]:
|
||||||
"""Query recent memory summaries
|
"""Query recent memory summaries
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
group_id: Group ID to filter by
|
end_user_id: Group ID to filter by
|
||||||
days: Number of recent days to query
|
days: Number of recent days to query
|
||||||
limit: Maximum number of results to return
|
limit: Maximum number of results to return
|
||||||
|
|
||||||
@@ -200,7 +200,7 @@ class MemorySummaryRepository(BaseNeo4jRepository):
|
|||||||
"""
|
"""
|
||||||
query = f"""
|
query = f"""
|
||||||
MATCH (n:{self.node_label})
|
MATCH (n:{self.node_label})
|
||||||
WHERE n.group_id = $group_id
|
WHERE n.end_user_id = $end_user_id
|
||||||
AND n.created_at >= datetime() - duration({{days: $days}})
|
AND n.created_at >= datetime() - duration({{days: $days}})
|
||||||
RETURN n
|
RETURN n
|
||||||
ORDER BY n.created_at DESC
|
ORDER BY n.created_at DESC
|
||||||
|
|||||||
@@ -141,14 +141,14 @@ class Neo4jConnector:
|
|||||||
async with self.driver.session(database="neo4j") as session:
|
async with self.driver.session(database="neo4j") as session:
|
||||||
return await session.execute_read(transaction_func, **kwargs)
|
return await session.execute_read(transaction_func, **kwargs)
|
||||||
|
|
||||||
async def delete_group(self, group_id: str):
|
async def delete_group(self, end_user_id: str):
|
||||||
"""删除指定组的所有数据
|
"""删除指定组的所有数据
|
||||||
|
|
||||||
删除所有属于指定group_id的节点和边。
|
删除所有属于指定end_user_id的节点和边。
|
||||||
这是一个危险操作,会永久删除数据。
|
这是一个危险操作,会永久删除数据。
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
group_id: 要删除的组ID
|
end_user_id: 要删除的组ID
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
>>> connector = Neo4jConnector()
|
>>> connector = Neo4jConnector()
|
||||||
@@ -157,14 +157,14 @@ class Neo4jConnector:
|
|||||||
"""
|
"""
|
||||||
# 删除节点(DETACH DELETE会同时删除相关的边)
|
# 删除节点(DETACH DELETE会同时删除相关的边)
|
||||||
await self.driver.execute_query(
|
await self.driver.execute_query(
|
||||||
"MATCH (n) WHERE n.group_id = $group_id DETACH DELETE n",
|
"MATCH (n) WHERE n.end_user_id = $end_user_id DETACH DELETE n",
|
||||||
database="neo4j",
|
database="neo4j",
|
||||||
group_id=group_id
|
end_user_id=end_user_id
|
||||||
)
|
)
|
||||||
# 删除独立的边(如果有的话)
|
# 删除独立的边(如果有的话)
|
||||||
await self.driver.execute_query(
|
await self.driver.execute_query(
|
||||||
"MATCH ()-[r]->() WHERE r.group_id = $group_id DELETE r",
|
"MATCH ()-[r]->() WHERE r.end_user_id = $end_user_id DELETE r",
|
||||||
database="neo4j",
|
database="neo4j",
|
||||||
group_id=group_id
|
end_user_id=end_user_id
|
||||||
)
|
)
|
||||||
print(f"Group {group_id} deleted.")
|
print(f"Group {end_user_id} deleted.")
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ class StatementRepository(BaseNeo4jRepository[StatementNode]):
|
|||||||
"""陈述句仓储
|
"""陈述句仓储
|
||||||
|
|
||||||
管理陈述句节点的创建、查询、更新和删除操作。
|
管理陈述句节点的创建、查询、更新和删除操作。
|
||||||
提供按chunk_id、group_id、向量相似度等条件查询陈述句的方法。
|
提供按chunk_id、end_user_id、向量相似度等条件查询陈述句的方法。
|
||||||
|
|
||||||
Attributes:
|
Attributes:
|
||||||
connector: Neo4j连接器实例
|
connector: Neo4j连接器实例
|
||||||
|
|||||||
@@ -7,15 +7,11 @@ class UserInput(BaseModel):
|
|||||||
message: str
|
message: str
|
||||||
history: list[dict]
|
history: list[dict]
|
||||||
search_switch: str
|
search_switch: str
|
||||||
group_id: str
|
end_user_id: str
|
||||||
config_id: Optional[str] = None
|
config_id: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
class Write_UserInput(BaseModel):
|
class Write_UserInput(BaseModel):
|
||||||
message: str
|
messages: list[dict]
|
||||||
group_id: str
|
end_user_id: str
|
||||||
config_id: Optional[str] = None
|
config_id: Optional[str] = None
|
||||||
|
|
||||||
class End_User_Information(BaseModel):
|
|
||||||
end_user_name: str # 这是要更新的用户名
|
|
||||||
id: str # 宿主ID,用于匹配条件
|
|
||||||
|
|||||||
@@ -10,11 +10,6 @@ import time
|
|||||||
import uuid
|
import uuid
|
||||||
from typing import Any, AsyncGenerator, Dict, List, Optional
|
from typing import Any, AsyncGenerator, Dict, List, Optional
|
||||||
|
|
||||||
from langchain.tools import tool
|
|
||||||
from pydantic import BaseModel, Field
|
|
||||||
from sqlalchemy import select
|
|
||||||
from sqlalchemy.orm import Session
|
|
||||||
|
|
||||||
from app.celery_app import celery_app
|
from app.celery_app import celery_app
|
||||||
from app.core.error_codes import BizCode
|
from app.core.error_codes import BizCode
|
||||||
from app.core.exceptions import BusinessException
|
from app.core.exceptions import BusinessException
|
||||||
@@ -28,6 +23,10 @@ from app.services.langchain_tool_server import Search
|
|||||||
from app.services.memory_agent_service import MemoryAgentService
|
from app.services.memory_agent_service import MemoryAgentService
|
||||||
from app.services.model_parameter_merger import ModelParameterMerger
|
from app.services.model_parameter_merger import ModelParameterMerger
|
||||||
from app.services.tool_service import ToolService
|
from app.services.tool_service import ToolService
|
||||||
|
from langchain.tools import tool
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
logger = get_business_logger()
|
logger = get_business_logger()
|
||||||
class KnowledgeRetrievalInput(BaseModel):
|
class KnowledgeRetrievalInput(BaseModel):
|
||||||
@@ -93,7 +92,7 @@ def create_long_term_memory_tool(memory_config: Dict[str, Any], end_user_id: str
|
|||||||
try:
|
try:
|
||||||
memory_content = asyncio.run(
|
memory_content = asyncio.run(
|
||||||
MemoryAgentService().read_memory(
|
MemoryAgentService().read_memory(
|
||||||
group_id=end_user_id,
|
end_user_id=end_user_id,
|
||||||
message=question,
|
message=question,
|
||||||
history=[],
|
history=[],
|
||||||
search_switch="2",
|
search_switch="2",
|
||||||
@@ -107,9 +106,9 @@ def create_long_term_memory_tool(memory_config: Dict[str, Any], end_user_id: str
|
|||||||
"app.core.memory.agent.read_message",
|
"app.core.memory.agent.read_message",
|
||||||
args=[end_user_id, question, [], "1", config_id, storage_type, user_rag_memory_id]
|
args=[end_user_id, question, [], "1", config_id, storage_type, user_rag_memory_id]
|
||||||
)
|
)
|
||||||
result = task_service.get_task_memory_read_result(task.id)
|
# result = task_service.get_task_memory_read_result(task.id)
|
||||||
status = result.get("status")
|
# status = result.get("status")
|
||||||
logger.info(f"读取任务状态:{status}")
|
# logger.info(f"读取任务状态:{status}")
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
db.close()
|
db.close()
|
||||||
|
|||||||
@@ -75,7 +75,7 @@ class EmotionAnalyticsService:
|
|||||||
|
|
||||||
# 调用仓储层查询
|
# 调用仓储层查询
|
||||||
tags = await self.emotion_repo.get_emotion_tags(
|
tags = await self.emotion_repo.get_emotion_tags(
|
||||||
group_id=end_user_id,
|
end_user_id=end_user_id,
|
||||||
emotion_type=emotion_type,
|
emotion_type=emotion_type,
|
||||||
start_date=start_date,
|
start_date=start_date,
|
||||||
end_date=end_date,
|
end_date=end_date,
|
||||||
@@ -157,7 +157,7 @@ class EmotionAnalyticsService:
|
|||||||
|
|
||||||
# 调用仓储层查询
|
# 调用仓储层查询
|
||||||
keywords = await self.emotion_repo.get_emotion_wordcloud(
|
keywords = await self.emotion_repo.get_emotion_wordcloud(
|
||||||
group_id=end_user_id,
|
end_user_id=end_user_id,
|
||||||
emotion_type=emotion_type,
|
emotion_type=emotion_type,
|
||||||
limit=limit
|
limit=limit
|
||||||
)
|
)
|
||||||
@@ -339,7 +339,7 @@ class EmotionAnalyticsService:
|
|||||||
|
|
||||||
# 获取时间范围内的情绪数据
|
# 获取时间范围内的情绪数据
|
||||||
emotions = await self.emotion_repo.get_emotions_in_range(
|
emotions = await self.emotion_repo.get_emotions_in_range(
|
||||||
group_id=end_user_id,
|
end_user_id=end_user_id,
|
||||||
time_range=time_range
|
time_range=time_range
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -519,7 +519,7 @@ class EmotionAnalyticsService:
|
|||||||
|
|
||||||
# 3. 获取情绪数据用于模式分析
|
# 3. 获取情绪数据用于模式分析
|
||||||
emotions = await self.emotion_repo.get_emotions_in_range(
|
emotions = await self.emotion_repo.get_emotions_in_range(
|
||||||
group_id=end_user_id,
|
end_user_id=end_user_id,
|
||||||
time_range="30d"
|
time_range="30d"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -598,13 +598,13 @@ class EmotionAnalyticsService:
|
|||||||
# 查询用户的实体和标签
|
# 查询用户的实体和标签
|
||||||
query = """
|
query = """
|
||||||
MATCH (e:Entity)
|
MATCH (e:Entity)
|
||||||
WHERE e.group_id = $group_id
|
WHERE e.end_user_id = $end_user_id
|
||||||
RETURN e.name as name, e.type as type
|
RETURN e.name as name, e.type as type
|
||||||
ORDER BY e.created_at DESC
|
ORDER BY e.created_at DESC
|
||||||
LIMIT 20
|
LIMIT 20
|
||||||
"""
|
"""
|
||||||
|
|
||||||
entities = await connector.execute_query(query, group_id=end_user_id)
|
entities = await connector.execute_query(query, end_user_id=end_user_id)
|
||||||
|
|
||||||
# 提取兴趣标签
|
# 提取兴趣标签
|
||||||
interests = [e["name"] for e in entities if e.get("type") in ["INTEREST", "HOBBY"]][:5]
|
interests = [e["name"] for e in entities if e.get("type") in ["INTEREST", "HOBBY"]][:5]
|
||||||
|
|||||||
@@ -10,27 +10,34 @@ import re
|
|||||||
import time
|
import time
|
||||||
import uuid
|
import uuid
|
||||||
from typing import Any, AsyncGenerator, Dict, List, Optional
|
from typing import Any, AsyncGenerator, Dict, List, Optional
|
||||||
import redis
|
|
||||||
from langchain_core.messages import HumanMessage
|
|
||||||
|
|
||||||
|
import redis
|
||||||
from app.core.config import settings
|
from app.core.config import settings
|
||||||
from app.core.logging_config import get_config_logger, get_logger
|
from app.core.logging_config import get_config_logger, get_logger
|
||||||
from app.core.memory.agent.langgraph_graph.read_graph import make_read_graph
|
from app.core.memory.agent.langgraph_graph.read_graph import make_read_graph
|
||||||
from app.core.memory.agent.langgraph_graph.write_graph import make_write_graph
|
from app.core.memory.agent.langgraph_graph.write_graph import make_write_graph
|
||||||
from app.core.memory.agent.logger_file.log_streamer import LogStreamer
|
from app.core.memory.agent.logger_file.log_streamer import LogStreamer
|
||||||
from app.core.memory.agent.utils.messages_tools import merge_multiple_search_results, reorder_output_results
|
from app.core.memory.agent.utils.messages_tools import (
|
||||||
|
merge_multiple_search_results,
|
||||||
|
reorder_output_results,
|
||||||
|
)
|
||||||
from app.core.memory.agent.utils.type_classifier import status_typle
|
from app.core.memory.agent.utils.type_classifier import status_typle
|
||||||
|
from app.core.memory.agent.utils.write_tools import write # 新增:直接导入 write 函数
|
||||||
from app.core.memory.analytics.hot_memory_tags import get_hot_memory_tags
|
from app.core.memory.analytics.hot_memory_tags import get_hot_memory_tags
|
||||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||||
from app.db import get_db_context
|
from app.db import get_db_context
|
||||||
from app.models.knowledge_model import Knowledge, KnowledgeType
|
from app.models.knowledge_model import Knowledge, KnowledgeType
|
||||||
|
from app.repositories.memory_short_repository import ShortTermMemoryRepository
|
||||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||||
|
from app.schemas.memory_agent_schema import Write_UserInput
|
||||||
from app.schemas.memory_config_schema import ConfigurationError
|
from app.schemas.memory_config_schema import ConfigurationError
|
||||||
from app.services.memory_base_service import Translation_English
|
from app.services.memory_base_service import Translation_English
|
||||||
from app.services.memory_config_service import MemoryConfigService
|
from app.services.memory_config_service import MemoryConfigService
|
||||||
from app.services.memory_konwledges_server import (
|
from app.services.memory_konwledges_server import (
|
||||||
write_rag,
|
write_rag,
|
||||||
)
|
)
|
||||||
|
from langchain_core.messages import AIMessage
|
||||||
|
from langchain_core.messages import HumanMessage
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
from sqlalchemy import func
|
from sqlalchemy import func
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
@@ -49,25 +56,24 @@ _neo4j_connector = Neo4jConnector()
|
|||||||
class MemoryAgentService:
|
class MemoryAgentService:
|
||||||
"""Service for memory agent operations"""
|
"""Service for memory agent operations"""
|
||||||
|
|
||||||
def writer_messages_deal(self, messages, start_time, group_id, config_id, message, context):
|
def writer_messages_deal(self, messages, start_time, end_user_id, config_id, message, context):
|
||||||
duration = time.time() - start_time
|
duration = time.time() - start_time
|
||||||
|
|
||||||
if str(messages) == 'success':
|
if str(messages) == 'success':
|
||||||
logger.info(f"Write operation successful for group {group_id} with config_id {config_id}")
|
logger.info(f"Write operation successful for group {end_user_id} with config_id {config_id}")
|
||||||
# 记录成功的操作
|
# 记录成功的操作
|
||||||
if audit_logger:
|
if audit_logger:
|
||||||
audit_logger.log_operation(operation="WRITE", config_id=config_id, group_id=group_id, success=True,
|
audit_logger.log_operation(operation="WRITE", config_id=config_id, end_user_id=end_user_id, success=True,
|
||||||
duration=duration, details={"message_length": len(message)})
|
duration=duration, details={"message_length": len(message)})
|
||||||
return context
|
return context
|
||||||
else:
|
else:
|
||||||
logger.warning(f"Write operation failed for group {group_id}")
|
logger.warning(f"Write operation failed for group {end_user_id}")
|
||||||
|
|
||||||
# 记录失败的操作
|
# 记录失败的操作
|
||||||
if audit_logger:
|
if audit_logger:
|
||||||
audit_logger.log_operation(
|
audit_logger.log_operation(
|
||||||
operation="WRITE",
|
operation="WRITE",
|
||||||
config_id=config_id,
|
config_id=config_id,
|
||||||
group_id=group_id,
|
end_user_id=end_user_id,
|
||||||
success=False,
|
success=False,
|
||||||
duration=duration,
|
duration=duration,
|
||||||
error=f"写入失败: {messages[:100]}"
|
error=f"写入失败: {messages[:100]}"
|
||||||
@@ -260,12 +266,12 @@ class MemoryAgentService:
|
|||||||
logger.info("Log streaming completed, cleaning up resources")
|
logger.info("Log streaming completed, cleaning up resources")
|
||||||
# LogStreamer uses context manager for file handling, so cleanup is automatic
|
# LogStreamer uses context manager for file handling, so cleanup is automatic
|
||||||
|
|
||||||
async def write_memory(self, group_id: str, message: str, config_id: Optional[str], db: Session, storage_type: str, user_rag_memory_id: str) -> str:
|
async def write_memory(self, end_user_id: str, messages: str, config_id: Optional[str], db: Session, storage_type: str, user_rag_memory_id: str) -> str:
|
||||||
"""
|
"""
|
||||||
Process write operation with config_id
|
Process write operation with config_id
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
group_id: Group identifier (also used as end_user_id)
|
end_user_id: Group identifier (also used as end_user_id)
|
||||||
message: Message to write
|
message: Message to write
|
||||||
config_id: Configuration ID from database
|
config_id: Configuration ID from database
|
||||||
db: SQLAlchemy database session
|
db: SQLAlchemy database session
|
||||||
@@ -281,15 +287,15 @@ class MemoryAgentService:
|
|||||||
# Resolve config_id if None using end_user's connected config
|
# Resolve config_id if None using end_user's connected config
|
||||||
if config_id is None:
|
if config_id is None:
|
||||||
try:
|
try:
|
||||||
connected_config = get_end_user_connected_config(group_id, db)
|
connected_config = get_end_user_connected_config(end_user_id, db)
|
||||||
config_id = connected_config.get("memory_config_id")
|
config_id = connected_config.get("memory_config_id")
|
||||||
if config_id is None:
|
if config_id is None:
|
||||||
raise ValueError(f"No memory configuration found for end_user {group_id}. Please ensure the user has a connected memory configuration.")
|
raise ValueError(f"No memory configuration found for end_user {end_user_id}. Please ensure the user has a connected memory configuration.")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
if "No memory configuration found" in str(e):
|
if "No memory configuration found" in str(e):
|
||||||
raise # Re-raise our specific error
|
raise # Re-raise our specific error
|
||||||
logger.error(f"Failed to get connected config for end_user {group_id}: {e}")
|
logger.error(f"Failed to get connected config for end_user {end_user_id}: {e}")
|
||||||
raise ValueError(f"Unable to determine memory configuration for end_user {group_id}: {e}")
|
raise ValueError(f"Unable to determine memory configuration for end_user {end_user_id}: {e}")
|
||||||
|
|
||||||
import time
|
import time
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
@@ -309,20 +315,26 @@ class MemoryAgentService:
|
|||||||
# Log failed operation
|
# Log failed operation
|
||||||
if audit_logger:
|
if audit_logger:
|
||||||
duration = time.time() - start_time
|
duration = time.time() - start_time
|
||||||
audit_logger.log_operation(operation="WRITE", config_id=config_id, group_id=group_id, success=False, duration=duration, error=error_msg)
|
audit_logger.log_operation(operation="WRITE", config_id=config_id, end_user_id=end_user_id, success=False, duration=duration, error=error_msg)
|
||||||
|
|
||||||
raise ValueError(error_msg)
|
raise ValueError(error_msg)
|
||||||
|
|
||||||
try:
|
|
||||||
if storage_type == "rag":
|
|
||||||
result = await write_rag(group_id, message, user_rag_memory_id)
|
|
||||||
return result
|
|
||||||
else:
|
|
||||||
async with make_write_graph() as graph:
|
async with make_write_graph() as graph:
|
||||||
config = {"configurable": {"thread_id": group_id}}
|
config = {"configurable": {"thread_id": end_user_id}}
|
||||||
|
# Convert structured messages to LangChain messages
|
||||||
|
langchain_messages = []
|
||||||
|
for msg in messages:
|
||||||
|
if msg['role'] == 'user':
|
||||||
|
langchain_messages.append(HumanMessage(content=msg['content']))
|
||||||
|
elif msg['role'] == 'assistant':
|
||||||
|
langchain_messages.append(AIMessage(content=msg['content']))
|
||||||
|
|
||||||
# 初始状态 - 包含所有必要字段
|
# 初始状态 - 包含所有必要字段
|
||||||
initial_state = {"messages": [HumanMessage(content=message)], "group_id": group_id,
|
initial_state = {
|
||||||
"memory_config": memory_config}
|
"messages": langchain_messages,
|
||||||
|
"end_user_id": end_user_id,
|
||||||
|
"memory_config": memory_config
|
||||||
|
}
|
||||||
|
|
||||||
# 获取节点更新信息
|
# 获取节点更新信息
|
||||||
async for update_event in graph.astream(
|
async for update_event in graph.astream(
|
||||||
@@ -333,32 +345,73 @@ class MemoryAgentService:
|
|||||||
for node_name, node_data in update_event.items():
|
for node_name, node_data in update_event.items():
|
||||||
if 'save_neo4j' == node_name:
|
if 'save_neo4j' == node_name:
|
||||||
massages = node_data
|
massages = node_data
|
||||||
|
print(massages)
|
||||||
massagesstatus = massages.get('write_result')['status']
|
massagesstatus = massages.get('write_result')['status']
|
||||||
contents = massages.get('write_result')
|
contents = massages.get('write_result')
|
||||||
return self.writer_messages_deal(massagesstatus, start_time, group_id, config_id, message, contents)
|
# Convert messages back to string for logging
|
||||||
except Exception as e:
|
message_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages])
|
||||||
# Ensure proper error handling and logging
|
return self.writer_messages_deal(massagesstatus, start_time, end_user_id, config_id, message_text, contents)
|
||||||
error_msg = f"Write operation failed: {str(e)}"
|
|
||||||
logger.error(error_msg)
|
# try:
|
||||||
if audit_logger:
|
# if storage_type == "rag":
|
||||||
duration = time.time() - start_time
|
# # For RAG storage, convert messages to single string
|
||||||
audit_logger.log_operation(operation="WRITE", config_id=config_id, group_id=group_id, success=False, duration=duration, error=error_msg)
|
# message_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages])
|
||||||
raise ValueError(error_msg)
|
# result = await write_rag(end_user_id, message_text, user_rag_memory_id)
|
||||||
|
# return result
|
||||||
|
# else:
|
||||||
|
# async with make_write_graph() as graph:
|
||||||
|
# config = {"configurable": {"thread_id": end_user_id}}
|
||||||
|
# # Convert structured messages to LangChain messages
|
||||||
|
# langchain_messages = []
|
||||||
|
# for msg in messages:
|
||||||
|
# if msg['role'] == 'user':
|
||||||
|
# langchain_messages.append(HumanMessage(content=msg['content']))
|
||||||
|
# elif msg['role'] == 'assistant':
|
||||||
|
# langchain_messages.append(AIMessage(content=msg['content']))
|
||||||
|
#
|
||||||
|
# # 初始状态 - 包含所有必要字段
|
||||||
|
# initial_state = {
|
||||||
|
# "messages": langchain_messages,
|
||||||
|
# "end_user_id": end_user_id,
|
||||||
|
# "memory_config": memory_config
|
||||||
|
# }
|
||||||
|
#
|
||||||
|
# # 获取节点更新信息
|
||||||
|
# async for update_event in graph.astream(
|
||||||
|
# initial_state,
|
||||||
|
# stream_mode="updates",
|
||||||
|
# config=config
|
||||||
|
# ):
|
||||||
|
# for node_name, node_data in update_event.items():
|
||||||
|
# if 'save_neo4j' == node_name:
|
||||||
|
# massages = node_data
|
||||||
|
# massagesstatus = massages.get('write_result')['status']
|
||||||
|
# contents = massages.get('write_result')
|
||||||
|
# # Convert messages back to string for logging
|
||||||
|
# message_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages])
|
||||||
|
# return self.writer_messages_deal(massagesstatus, start_time, end_user_id, config_id, message_text, contents)
|
||||||
|
# except Exception as e:
|
||||||
|
# # Ensure proper error handling and logging
|
||||||
|
# error_msg = f"Write operation failed: {str(e)}"
|
||||||
|
# logger.error(error_msg)
|
||||||
|
# if audit_logger:
|
||||||
|
# duration = time.time() - start_time
|
||||||
|
# audit_logger.log_operation(operation="WRITE", config_id=config_id, end_user_id=end_user_id, success=False, duration=duration, error=error_msg)
|
||||||
|
# raise ValueError(error_msg)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
async def read_memory(
|
async def read_memory(
|
||||||
self,
|
self,
|
||||||
group_id: str,
|
end_user_id: str,
|
||||||
message: str,
|
message: str,
|
||||||
history: List[Dict],
|
history: List[Dict],
|
||||||
search_switch: str,
|
search_switch: str,
|
||||||
config_id: Optional[str],
|
config_id: Optional[str],
|
||||||
db: Session,
|
db: Session,
|
||||||
storage_type: str,
|
storage_type: str,
|
||||||
user_rag_memory_id: str
|
user_rag_memory_id: str) -> Dict:
|
||||||
) -> Dict:
|
|
||||||
"""
|
"""
|
||||||
Process read operation with config_id
|
Process read operation with config_id
|
||||||
|
|
||||||
@@ -368,7 +421,7 @@ class MemoryAgentService:
|
|||||||
- "2": Direct answer based on context
|
- "2": Direct answer based on context
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
group_id: Group identifier (also used as end_user_id)
|
end_user_id: Group identifier (also used as end_user_id)
|
||||||
message: User message
|
message: User message
|
||||||
history: Conversation history
|
history: Conversation history
|
||||||
search_switch: Search mode switch
|
search_switch: Search mode switch
|
||||||
@@ -386,21 +439,22 @@ class MemoryAgentService:
|
|||||||
|
|
||||||
import time
|
import time
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
ori_message= message
|
||||||
|
|
||||||
# Resolve config_id if None using end_user's connected config
|
# Resolve config_id if None using end_user's connected config
|
||||||
if config_id is None:
|
if config_id is None:
|
||||||
try:
|
try:
|
||||||
connected_config = get_end_user_connected_config(group_id, db)
|
connected_config = get_end_user_connected_config(end_user_id, db)
|
||||||
config_id = connected_config.get("memory_config_id")
|
config_id = connected_config.get("memory_config_id")
|
||||||
if config_id is None:
|
if config_id is None:
|
||||||
raise ValueError(f"No memory configuration found for end_user {group_id}. Please ensure the user has a connected memory configuration.")
|
raise ValueError(f"No memory configuration found for end_user {end_user_id}. Please ensure the user has a connected memory configuration.")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
if "No memory configuration found" in str(e):
|
if "No memory configuration found" in str(e):
|
||||||
raise # Re-raise our specific error
|
raise # Re-raise our specific error
|
||||||
logger.error(f"Failed to get connected config for end_user {group_id}: {e}")
|
logger.error(f"Failed to get connected config for end_user {end_user_id}: {e}")
|
||||||
raise ValueError(f"Unable to determine memory configuration for end_user {group_id}: {e}")
|
raise ValueError(f"Unable to determine memory configuration for end_user {end_user_id}: {e}")
|
||||||
|
|
||||||
logger.info(f"Read operation for group {group_id} with config_id {config_id}")
|
logger.info(f"Read operation for group {end_user_id} with config_id {config_id}")
|
||||||
|
|
||||||
# 导入审计日志记录器
|
# 导入审计日志记录器
|
||||||
try:
|
try:
|
||||||
@@ -426,7 +480,7 @@ class MemoryAgentService:
|
|||||||
audit_logger.log_operation(
|
audit_logger.log_operation(
|
||||||
operation="READ",
|
operation="READ",
|
||||||
config_id=config_id,
|
config_id=config_id,
|
||||||
group_id=group_id,
|
end_user_id=end_user_id,
|
||||||
success=False,
|
success=False,
|
||||||
duration=duration,
|
duration=duration,
|
||||||
error=error_msg
|
error=error_msg
|
||||||
@@ -436,15 +490,16 @@ class MemoryAgentService:
|
|||||||
|
|
||||||
# Step 2: Prepare history
|
# Step 2: Prepare history
|
||||||
history.append({"role": "user", "content": message})
|
history.append({"role": "user", "content": message})
|
||||||
logger.debug(f"Group ID:{group_id}, Message:{message}, History:{history}, Config ID:{config_id}")
|
logger.debug(f"Group ID:{end_user_id}, Message:{message}, History:{history}, Config ID:{config_id}")
|
||||||
|
|
||||||
# Step 3: Initialize MCP client and execute read workflow
|
# Step 3: Initialize MCP client and execute read workflow
|
||||||
|
graph_exec_start = time.time()
|
||||||
try:
|
try:
|
||||||
async with make_read_graph() as graph:
|
async with make_read_graph() as graph:
|
||||||
config = {"configurable": {"thread_id": group_id}}
|
config = {"configurable": {"thread_id": end_user_id}}
|
||||||
# 初始状态 - 包含所有必要字段
|
# 初始状态 - 包含所有必要字段
|
||||||
initial_state = {"messages": [HumanMessage(content=message)], "search_switch": search_switch,
|
initial_state = {"messages": [HumanMessage(content=message)], "search_switch": search_switch,
|
||||||
"group_id": group_id
|
"end_user_id": end_user_id
|
||||||
, "storage_type": storage_type, "user_rag_memory_id": user_rag_memory_id,
|
, "storage_type": storage_type, "user_rag_memory_id": user_rag_memory_id,
|
||||||
"memory_config": memory_config}
|
"memory_config": memory_config}
|
||||||
# 获取节点更新信息
|
# 获取节点更新信息
|
||||||
@@ -495,18 +550,72 @@ class MemoryAgentService:
|
|||||||
if summary_n and summary_n != [] and summary_n != {}:
|
if summary_n and summary_n != [] and summary_n != {}:
|
||||||
_intermediate_outputs.append(summary_n)
|
_intermediate_outputs.append(summary_n)
|
||||||
|
|
||||||
|
graph_exec_time = time.time() - graph_exec_start
|
||||||
|
logger.info(f"[PERF] Graph execution completed in {graph_exec_time:.4f}s")
|
||||||
|
|
||||||
_intermediate_outputs = [item for item in _intermediate_outputs if item and item != [] and item != {}]
|
_intermediate_outputs = [item for item in _intermediate_outputs if item and item != [] and item != {}]
|
||||||
|
|
||||||
optimized_outputs = merge_multiple_search_results(_intermediate_outputs)
|
optimized_outputs = merge_multiple_search_results(_intermediate_outputs)
|
||||||
result = reorder_output_results(optimized_outputs)
|
result = reorder_output_results(optimized_outputs)
|
||||||
|
|
||||||
|
# 保存短期记忆到数据库
|
||||||
|
# 只有 search_switch 不为 "2"(快速检索)时才保存
|
||||||
|
try:
|
||||||
|
from app.repositories.memory_short_repository import ShortTermMemoryRepository
|
||||||
|
|
||||||
|
retrieved_content = []
|
||||||
|
repo = ShortTermMemoryRepository(db)
|
||||||
|
|
||||||
|
if str(search_switch) != "2":
|
||||||
|
for intermediate in _intermediate_outputs:
|
||||||
|
logger.debug(f"处理中间结果: {intermediate}")
|
||||||
|
intermediate_type = intermediate.get('type', '')
|
||||||
|
|
||||||
|
if intermediate_type == "search_result":
|
||||||
|
query = intermediate.get('query', '')
|
||||||
|
raw_results = intermediate.get('raw_results', {})
|
||||||
|
reranked_results = raw_results.get('reranked_results', [])
|
||||||
|
|
||||||
|
try:
|
||||||
|
statements = [statement['statement'] for statement in reranked_results.get('statements', [])]
|
||||||
|
except Exception:
|
||||||
|
statements = []
|
||||||
|
|
||||||
|
# 去重
|
||||||
|
statements = list(set(statements))
|
||||||
|
|
||||||
|
if query and statements:
|
||||||
|
retrieved_content.append({query: statements})
|
||||||
|
|
||||||
|
# 如果 retrieved_content 为空,设置为空字符串
|
||||||
|
if retrieved_content == []:
|
||||||
|
retrieved_content = ''
|
||||||
|
|
||||||
|
# 只有当回答不是"信息不足"且不是快速检索时才保存
|
||||||
|
if '信息不足,无法回答。' != str(summary) and str(search_switch).strip() != "2":
|
||||||
|
# 使用 upsert 方法
|
||||||
|
repo.upsert(
|
||||||
|
end_user_id=end_user_id,
|
||||||
|
messages=message,
|
||||||
|
aimessages=summary,
|
||||||
|
retrieved_content=retrieved_content,
|
||||||
|
search_switch=str(search_switch)
|
||||||
|
)
|
||||||
|
logger.info(f"成功保存短期记忆: end_user_id={end_user_id}, search_switch={search_switch}")
|
||||||
|
else:
|
||||||
|
logger.debug(f"跳过保存短期记忆: summary={summary[:50] if summary else 'None'}, search_switch={search_switch}")
|
||||||
|
|
||||||
|
except Exception as save_error:
|
||||||
|
# 保存失败不应该影响主流程,只记录错误
|
||||||
|
logger.error(f"保存短期记忆失败: {str(save_error)}", exc_info=True)
|
||||||
|
|
||||||
# Log successful operation
|
# Log successful operation
|
||||||
if audit_logger:
|
if audit_logger:
|
||||||
duration = time.time() - start_time
|
duration = time.time() - start_time
|
||||||
audit_logger.log_operation(
|
audit_logger.log_operation(
|
||||||
operation="READ",
|
operation="READ",
|
||||||
config_id=config_id,
|
config_id=config_id,
|
||||||
group_id=group_id,
|
end_user_id=end_user_id,
|
||||||
success=True,
|
success=True,
|
||||||
duration=duration
|
duration=duration
|
||||||
)
|
)
|
||||||
@@ -524,14 +633,56 @@ class MemoryAgentService:
|
|||||||
audit_logger.log_operation(
|
audit_logger.log_operation(
|
||||||
operation="READ",
|
operation="READ",
|
||||||
config_id=config_id,
|
config_id=config_id,
|
||||||
group_id=group_id,
|
end_user_id=end_user_id,
|
||||||
success=False,
|
success=False,
|
||||||
duration=duration,
|
duration=duration,
|
||||||
error=error_msg
|
error=error_msg
|
||||||
)
|
)
|
||||||
raise ValueError(error_msg)
|
raise ValueError(error_msg)
|
||||||
|
|
||||||
|
def get_messages_list(self, user_input: Write_UserInput) -> list[dict]:
|
||||||
|
"""
|
||||||
|
Get standardized message list from user input.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_input: Write_UserInput object
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
list[dict]: Message list, each message contains role and content
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If messages is empty or format is incorrect
|
||||||
|
"""
|
||||||
|
from app.core.logging_config import get_api_logger
|
||||||
|
logger = get_api_logger()
|
||||||
|
|
||||||
|
if len(user_input.messages) == 0:
|
||||||
|
logger.error("Validation failed: Message list cannot be empty")
|
||||||
|
raise ValueError("Message list cannot be empty")
|
||||||
|
|
||||||
|
for idx, msg in enumerate(user_input.messages):
|
||||||
|
if not isinstance(msg, dict):
|
||||||
|
logger.error(f"Validation failed: Message {idx} is not a dict: {type(msg)}")
|
||||||
|
raise ValueError(f"Message format error: Message must be a dictionary. Error message index: {idx}, type: {type(msg)}")
|
||||||
|
|
||||||
|
if 'role' not in msg:
|
||||||
|
logger.error(f"Validation failed: Message {idx} missing 'role' field: {msg}")
|
||||||
|
raise ValueError(f"Message format error: Message must contain 'role' field. Error message index: {idx}")
|
||||||
|
|
||||||
|
if 'content' not in msg:
|
||||||
|
logger.error(f"Validation failed: Message {idx} missing 'content' field: {msg}")
|
||||||
|
raise ValueError(f"Message format error: Message must contain 'content' field. Error message index: {idx}")
|
||||||
|
|
||||||
|
if msg['role'] not in ['user', 'assistant']:
|
||||||
|
logger.error(f"Validation failed: Message {idx} invalid role: {msg['role']}")
|
||||||
|
raise ValueError(f"Role must be 'user' or 'assistant', got: {msg['role']}. Message index: {idx}")
|
||||||
|
|
||||||
|
if not msg['content'] or not msg['content'].strip():
|
||||||
|
logger.error(f"Validation failed: Message {idx} content is empty")
|
||||||
|
raise ValueError(f"Message content cannot be empty. Message index: {idx}, role: {msg['role']}")
|
||||||
|
|
||||||
|
logger.info(f"Validation successful: Structured message list, count: {len(user_input.messages)}")
|
||||||
|
return user_input.messages
|
||||||
|
|
||||||
async def classify_message_type(self, message: str, config_id: int, db: Session) -> Dict:
|
async def classify_message_type(self, message: str, config_id: int, db: Session) -> Dict:
|
||||||
"""
|
"""
|
||||||
@@ -559,7 +710,67 @@ class MemoryAgentService:
|
|||||||
logger.debug(f"Message type: {status}")
|
logger.debug(f"Message type: {status}")
|
||||||
return status
|
return status
|
||||||
|
|
||||||
# ==================== 新增的三个接口方法 ====================
|
async def generate_summary_from_retrieve(
|
||||||
|
self,
|
||||||
|
retrieve_info: str,
|
||||||
|
history: List[Dict],
|
||||||
|
query: str,
|
||||||
|
config_id: str,
|
||||||
|
db: Session
|
||||||
|
) -> str:
|
||||||
|
"""
|
||||||
|
基于检索信息、历史对话和查询生成最终答案
|
||||||
|
|
||||||
|
使用 Retrieve_Summary_prompt.jinja2 模板调用大模型生成答案
|
||||||
|
|
||||||
|
Args:
|
||||||
|
retrieve_info: 检索到的信息
|
||||||
|
history: 历史对话记录
|
||||||
|
query: 用户查询
|
||||||
|
config_id: 配置ID
|
||||||
|
db: 数据库会话
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
生成的答案文本
|
||||||
|
"""
|
||||||
|
logger.info(f"Generating summary from retrieve info for query: {query[:50]}...")
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 加载配置
|
||||||
|
config_service = MemoryConfigService(db)
|
||||||
|
memory_config = config_service.load_memory_config(
|
||||||
|
config_id=config_id,
|
||||||
|
service_name="MemoryAgentService"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 导入必要的模块
|
||||||
|
from app.core.memory.agent.langgraph_graph.nodes.summary_nodes import summary_llm
|
||||||
|
from app.core.memory.agent.models.summary_models import RetrieveSummaryResponse
|
||||||
|
|
||||||
|
# 构建状态对象
|
||||||
|
state = {
|
||||||
|
"data": query,
|
||||||
|
"memory_config": memory_config
|
||||||
|
}
|
||||||
|
|
||||||
|
# 直接调用 summary_llm 函数
|
||||||
|
answer = await summary_llm(
|
||||||
|
state=state,
|
||||||
|
history=history,
|
||||||
|
retrieve_info=retrieve_info,
|
||||||
|
template_name='Retrieve_Summary_prompt.jinja2',
|
||||||
|
operation_name='retrieve_summary',
|
||||||
|
response_model=RetrieveSummaryResponse,
|
||||||
|
search_mode="1"
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(f"Successfully generated summary: {answer[:100] if answer else 'None'}...")
|
||||||
|
return answer if answer else "信息不足,无法回答。"
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"生成摘要失败: {str(e)}", exc_info=True)
|
||||||
|
return "信息不足,无法回答。"
|
||||||
|
|
||||||
|
|
||||||
async def get_knowledge_type_stats(
|
async def get_knowledge_type_stats(
|
||||||
self,
|
self,
|
||||||
@@ -571,7 +782,7 @@ class MemoryAgentService:
|
|||||||
"""
|
"""
|
||||||
统计知识库类型分布,包含:
|
统计知识库类型分布,包含:
|
||||||
1. PostgreSQL 中的知识库类型:General, Web, Third-party, Folder(根据 workspace_id 过滤)
|
1. PostgreSQL 中的知识库类型:General, Web, Third-party, Folder(根据 workspace_id 过滤)
|
||||||
2. Neo4j 中的 memory 类型(仅统计 Chunk 数量,根据 end_user_id/group_id 过滤)
|
2. Neo4j 中的 memory 类型(仅统计 Chunk 数量,根据 end_user_id/end_user_id 过滤)
|
||||||
3. total: 所有类型的总和
|
3. total: 所有类型的总和
|
||||||
|
|
||||||
参数:
|
参数:
|
||||||
@@ -657,11 +868,11 @@ class MemoryAgentService:
|
|||||||
for end_user in end_users:
|
for end_user in end_users:
|
||||||
end_user_id_str = str(end_user.id)
|
end_user_id_str = str(end_user.id)
|
||||||
memory_query = """
|
memory_query = """
|
||||||
MATCH (n:Chunk) WHERE n.group_id = $group_id RETURN count(n) AS Count
|
MATCH (n:Chunk) WHERE n.end_user_id = $end_user_id RETURN count(n) AS Count
|
||||||
"""
|
"""
|
||||||
neo4j_result = await _neo4j_connector.execute_query(
|
neo4j_result = await _neo4j_connector.execute_query(
|
||||||
memory_query,
|
memory_query,
|
||||||
group_id=end_user_id_str,
|
end_user_id=end_user_id_str,
|
||||||
)
|
)
|
||||||
chunk_count = neo4j_result[0]["Count"] if neo4j_result else 0
|
chunk_count = neo4j_result[0]["Count"] if neo4j_result else 0
|
||||||
total_chunks += chunk_count
|
total_chunks += chunk_count
|
||||||
@@ -701,7 +912,7 @@ class MemoryAgentService:
|
|||||||
获取指定用户的热门记忆标签
|
获取指定用户的热门记忆标签
|
||||||
|
|
||||||
参数:
|
参数:
|
||||||
- end_user_id: 用户ID(可选),对应Neo4j中的group_id字段
|
- end_user_id: 用户ID(可选),对应Neo4j中的end_user_id字段
|
||||||
- limit: 返回标签数量限制
|
- limit: 返回标签数量限制
|
||||||
|
|
||||||
返回格式:
|
返回格式:
|
||||||
@@ -711,7 +922,7 @@ class MemoryAgentService:
|
|||||||
]
|
]
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
# by_user=False 表示按 group_id 查询(在Neo4j中,group_id就是用户维度)
|
# by_user=False 表示按 end_user_id 查询(在Neo4j中,end_user_id就是用户维度)
|
||||||
tags = await get_hot_memory_tags(end_user_id, limit=limit, by_user=False)
|
tags = await get_hot_memory_tags(end_user_id, limit=limit, by_user=False)
|
||||||
payload=[]
|
payload=[]
|
||||||
for tag, freq in tags:
|
for tag, freq in tags:
|
||||||
@@ -786,21 +997,21 @@ class MemoryAgentService:
|
|||||||
# 查询该用户的语句
|
# 查询该用户的语句
|
||||||
query = (
|
query = (
|
||||||
"MATCH (s:Statement) "
|
"MATCH (s:Statement) "
|
||||||
"WHERE ($group_id IS NULL OR s.group_id = $group_id) AND s.statement IS NOT NULL "
|
"WHERE ($end_user_id IS NULL OR s.end_user_id = $end_user_id) AND s.statement IS NOT NULL "
|
||||||
"RETURN s.statement AS statement "
|
"RETURN s.statement AS statement "
|
||||||
"ORDER BY s.created_at DESC LIMIT 100"
|
"ORDER BY s.created_at DESC LIMIT 100"
|
||||||
)
|
)
|
||||||
rows = await connector.execute_query(query, group_id=end_user_id)
|
rows = await connector.execute_query(query, end_user_id=end_user_id)
|
||||||
statements = [r.get("statement", "") for r in rows if r.get("statement")]
|
statements = [r.get("statement", "") for r in rows if r.get("statement")]
|
||||||
|
|
||||||
# 查询该用户的热门实体
|
# 查询该用户的热门实体
|
||||||
entity_query = (
|
entity_query = (
|
||||||
"MATCH (e:ExtractedEntity) "
|
"MATCH (e:ExtractedEntity) "
|
||||||
"WHERE ($group_id IS NULL OR e.group_id = $group_id) AND e.entity_type <> '人物' AND e.name IS NOT NULL "
|
"WHERE ($end_user_id IS NULL OR e.end_user_id = $end_user_id) AND e.entity_type <> '人物' AND e.name IS NOT NULL "
|
||||||
"RETURN e.name AS name, count(e) AS frequency "
|
"RETURN e.name AS name, count(e) AS frequency "
|
||||||
"ORDER BY frequency DESC LIMIT 20"
|
"ORDER BY frequency DESC LIMIT 20"
|
||||||
)
|
)
|
||||||
entity_rows = await connector.execute_query(entity_query, group_id=end_user_id)
|
entity_rows = await connector.execute_query(entity_query, end_user_id=end_user_id)
|
||||||
entities = [f"{r['name']} ({r['frequency']})" for r in entity_rows]
|
entities = [f"{r['name']} ({r['frequency']})" for r in entity_rows]
|
||||||
|
|
||||||
await connector.close()
|
await connector.close()
|
||||||
@@ -853,14 +1064,14 @@ class MemoryAgentService:
|
|||||||
names_to_exclude = ['AI', 'Caroline', 'Melanie', 'Jon', 'Gina', '用户', 'AI助手', 'John', 'Maria']
|
names_to_exclude = ['AI', 'Caroline', 'Melanie', 'Jon', 'Gina', '用户', 'AI助手', 'John', 'Maria']
|
||||||
hot_tag_query = (
|
hot_tag_query = (
|
||||||
"MATCH (e:ExtractedEntity) "
|
"MATCH (e:ExtractedEntity) "
|
||||||
"WHERE ($group_id IS NULL OR e.group_id = $group_id) AND e.entity_type <> '人物' "
|
"WHERE ($end_user_id IS NULL OR e.end_user_id = $end_user_id) AND e.entity_type <> '人物' "
|
||||||
"AND e.name IS NOT NULL AND NOT e.name IN $names_to_exclude "
|
"AND e.name IS NOT NULL AND NOT e.name IN $names_to_exclude "
|
||||||
"RETURN e.name AS name, count(e) AS frequency "
|
"RETURN e.name AS name, count(e) AS frequency "
|
||||||
"ORDER BY frequency DESC LIMIT 4"
|
"ORDER BY frequency DESC LIMIT 4"
|
||||||
)
|
)
|
||||||
hot_tag_rows = await connector.execute_query(
|
hot_tag_rows = await connector.execute_query(
|
||||||
hot_tag_query,
|
hot_tag_query,
|
||||||
group_id=end_user_id,
|
end_user_id=end_user_id,
|
||||||
names_to_exclude=names_to_exclude
|
names_to_exclude=names_to_exclude
|
||||||
)
|
)
|
||||||
await connector.close()
|
await connector.close()
|
||||||
@@ -1006,6 +1217,10 @@ def get_end_user_connected_config(end_user_id: str, db: Session) -> Dict[str, An
|
|||||||
"memory_config_id": memory_config_id
|
"memory_config_id": memory_config_id
|
||||||
}
|
}
|
||||||
|
|
||||||
|
print(188*'*')
|
||||||
|
print(result)
|
||||||
|
print(188 * '*')
|
||||||
|
|
||||||
logger.info(f"Successfully retrieved connected config: memory_config_id={memory_config_id}")
|
logger.info(f"Successfully retrieved connected config: memory_config_id={memory_config_id}")
|
||||||
return result
|
return result
|
||||||
|
|
||||||
@@ -1033,7 +1248,7 @@ def get_end_users_connected_configs_batch(end_user_ids: List[str], db: Session)
|
|||||||
"""
|
"""
|
||||||
from app.models.app_release_model import AppRelease
|
from app.models.app_release_model import AppRelease
|
||||||
from app.models.end_user_model import EndUser
|
from app.models.end_user_model import EndUser
|
||||||
from app.models.memory_config_model import MemoryConfig
|
from app.models.data_config_model import DataConfig
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
|
|
||||||
logger.info(f"Batch getting connected configs for {len(end_user_ids)} end_users")
|
logger.info(f"Batch getting connected configs for {len(end_user_ids)} end_users")
|
||||||
|
|||||||
@@ -25,7 +25,7 @@ class MemoryAPIService:
|
|||||||
|
|
||||||
This service provides a thin layer that:
|
This service provides a thin layer that:
|
||||||
1. Validates end_user exists and belongs to the authorized workspace
|
1. Validates end_user exists and belongs to the authorized workspace
|
||||||
2. Maps end_user_id to group_id for memory operations
|
2. Maps end_user_id to end_user_id for memory operations
|
||||||
3. Delegates to MemoryAgentService for actual memory read/write operations
|
3. Delegates to MemoryAgentService for actual memory read/write operations
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@@ -115,7 +115,7 @@ class MemoryAPIService:
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
workspace_id: Workspace ID for resource validation
|
workspace_id: Workspace ID for resource validation
|
||||||
end_user_id: End user identifier (used as group_id)
|
end_user_id: End user identifier (used as end_user_id)
|
||||||
message: Message content to store
|
message: Message content to store
|
||||||
config_id: Optional memory configuration ID
|
config_id: Optional memory configuration ID
|
||||||
storage_type: Storage backend (neo4j or rag)
|
storage_type: Storage backend (neo4j or rag)
|
||||||
@@ -133,13 +133,12 @@ class MemoryAPIService:
|
|||||||
# Validate end_user exists and belongs to workspace
|
# Validate end_user exists and belongs to workspace
|
||||||
self.validate_end_user(end_user_id, workspace_id)
|
self.validate_end_user(end_user_id, workspace_id)
|
||||||
|
|
||||||
# Use end_user_id as group_id for memory operations
|
# Use end_user_id as end_user_id for memory operations
|
||||||
group_id = end_user_id
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Delegate to MemoryAgentService
|
# Delegate to MemoryAgentService
|
||||||
result = await MemoryAgentService().write_memory(
|
result = await MemoryAgentService().write_memory(
|
||||||
group_id=group_id,
|
end_user_id=end_user_id,
|
||||||
message=message,
|
message=message,
|
||||||
config_id=config_id,
|
config_id=config_id,
|
||||||
db=self.db,
|
db=self.db,
|
||||||
@@ -186,7 +185,7 @@ class MemoryAPIService:
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
workspace_id: Workspace ID for resource validation
|
workspace_id: Workspace ID for resource validation
|
||||||
end_user_id: End user identifier (used as group_id)
|
end_user_id: End user identifier (used as end_user_id)
|
||||||
message: Query message
|
message: Query message
|
||||||
search_switch: Search mode (0=deep search with verification, 1=deep search, 2=fast search)
|
search_switch: Search mode (0=deep search with verification, 1=deep search, 2=fast search)
|
||||||
config_id: Optional memory configuration ID
|
config_id: Optional memory configuration ID
|
||||||
@@ -205,13 +204,13 @@ class MemoryAPIService:
|
|||||||
# Validate end_user exists and belongs to workspace
|
# Validate end_user exists and belongs to workspace
|
||||||
self.validate_end_user(end_user_id, workspace_id)
|
self.validate_end_user(end_user_id, workspace_id)
|
||||||
|
|
||||||
# Use end_user_id as group_id for memory operations
|
# Use end_user_id as end_user_id for memory operations
|
||||||
group_id = end_user_id
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Delegate to MemoryAgentService
|
# Delegate to MemoryAgentService
|
||||||
result = await MemoryAgentService().read_memory(
|
result = await MemoryAgentService().read_memory(
|
||||||
group_id=group_id,
|
end_user_id=end_user_id,
|
||||||
message=message,
|
message=message,
|
||||||
history=[],
|
history=[],
|
||||||
search_switch=search_switch,
|
search_switch=search_switch,
|
||||||
|
|||||||
@@ -326,7 +326,7 @@ class MemoryBaseService:
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
summary_id: Summary节点的ID
|
summary_id: Summary节点的ID
|
||||||
end_user_id: 终端用户ID (group_id)
|
end_user_id: 终端用户ID (end_user_id)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
最大emotion_intensity对应的emotion_type,如果没有则返回None
|
最大emotion_intensity对应的emotion_type,如果没有则返回None
|
||||||
@@ -334,7 +334,7 @@ class MemoryBaseService:
|
|||||||
try:
|
try:
|
||||||
query = """
|
query = """
|
||||||
MATCH (s:MemorySummary)
|
MATCH (s:MemorySummary)
|
||||||
WHERE elementId(s) = $summary_id AND s.group_id = $group_id
|
WHERE elementId(s) = $summary_id AND s.end_user_id = $end_user_id
|
||||||
MATCH (s)-[:DERIVED_FROM_STATEMENT]->(stmt:Statement)
|
MATCH (s)-[:DERIVED_FROM_STATEMENT]->(stmt:Statement)
|
||||||
WHERE stmt.emotion_type IS NOT NULL
|
WHERE stmt.emotion_type IS NOT NULL
|
||||||
AND stmt.emotion_intensity IS NOT NULL
|
AND stmt.emotion_intensity IS NOT NULL
|
||||||
@@ -347,7 +347,7 @@ class MemoryBaseService:
|
|||||||
result = await self.neo4j_connector.execute_query(
|
result = await self.neo4j_connector.execute_query(
|
||||||
query,
|
query,
|
||||||
summary_id=summary_id,
|
summary_id=summary_id,
|
||||||
group_id=end_user_id
|
end_user_id=end_user_id
|
||||||
)
|
)
|
||||||
|
|
||||||
if result and len(result) > 0:
|
if result and len(result) > 0:
|
||||||
@@ -381,10 +381,10 @@ class MemoryBaseService:
|
|||||||
if end_user_id:
|
if end_user_id:
|
||||||
query = """
|
query = """
|
||||||
MATCH (n:MemorySummary)
|
MATCH (n:MemorySummary)
|
||||||
WHERE n.group_id = $group_id
|
WHERE n.end_user_id = $end_user_id
|
||||||
RETURN count(n) as count
|
RETURN count(n) as count
|
||||||
"""
|
"""
|
||||||
result = await self.neo4j_connector.execute_query(query, group_id=end_user_id)
|
result = await self.neo4j_connector.execute_query(query, end_user_id=end_user_id)
|
||||||
else:
|
else:
|
||||||
query = """
|
query = """
|
||||||
MATCH (n:MemorySummary)
|
MATCH (n:MemorySummary)
|
||||||
@@ -423,12 +423,12 @@ class MemoryBaseService:
|
|||||||
if end_user_id:
|
if end_user_id:
|
||||||
semantic_query = """
|
semantic_query = """
|
||||||
MATCH (e:ExtractedEntity)
|
MATCH (e:ExtractedEntity)
|
||||||
WHERE e.group_id = $group_id AND e.is_explicit_memory = true
|
WHERE e.end_user_id = $end_user_id AND e.is_explicit_memory = true
|
||||||
RETURN count(e) as count
|
RETURN count(e) as count
|
||||||
"""
|
"""
|
||||||
semantic_result = await self.neo4j_connector.execute_query(
|
semantic_result = await self.neo4j_connector.execute_query(
|
||||||
semantic_query,
|
semantic_query,
|
||||||
group_id=end_user_id
|
end_user_id=end_user_id
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
semantic_query = """
|
semantic_query = """
|
||||||
@@ -519,7 +519,7 @@ class MemoryBaseService:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
if end_user_id:
|
if end_user_id:
|
||||||
query += " AND n.group_id = $group_id"
|
query += " AND n.end_user_id = $end_user_id"
|
||||||
|
|
||||||
query += """
|
query += """
|
||||||
RETURN sum(CASE WHEN n.activation_value IS NOT NULL AND n.activation_value < $threshold THEN 1 ELSE 0 END) as low_activation_nodes
|
RETURN sum(CASE WHEN n.activation_value IS NOT NULL AND n.activation_value < $threshold THEN 1 ELSE 0 END) as low_activation_nodes
|
||||||
@@ -528,7 +528,7 @@ class MemoryBaseService:
|
|||||||
# 设置查询参数
|
# 设置查询参数
|
||||||
params = {'threshold': forgetting_threshold}
|
params = {'threshold': forgetting_threshold}
|
||||||
if end_user_id:
|
if end_user_id:
|
||||||
params['group_id'] = end_user_id
|
params['end_user_id'] = end_user_id
|
||||||
|
|
||||||
# 执行查询
|
# 执行查询
|
||||||
result = await self.neo4j_connector.execute_query(query, **params)
|
result = await self.neo4j_connector.execute_query(query, **params)
|
||||||
|
|||||||
@@ -125,7 +125,11 @@ class MemoryConfigService:
|
|||||||
try:
|
try:
|
||||||
validated_config_id = _validate_config_id(config_id)
|
validated_config_id = _validate_config_id(config_id)
|
||||||
|
|
||||||
|
# Step 1: Get config and workspace
|
||||||
|
db_query_start = time.time()
|
||||||
result = DataConfigRepository.get_config_with_workspace(self.db, validated_config_id)
|
result = DataConfigRepository.get_config_with_workspace(self.db, validated_config_id)
|
||||||
|
db_query_time = time.time() - db_query_start
|
||||||
|
logger.info(f"[PERF] Config+Workspace query: {db_query_time:.4f}s")
|
||||||
if not result:
|
if not result:
|
||||||
elapsed_ms = (time.time() - start_time) * 1000
|
elapsed_ms = (time.time() - start_time) * 1000
|
||||||
config_logger.error(
|
config_logger.error(
|
||||||
@@ -144,16 +148,20 @@ class MemoryConfigService:
|
|||||||
|
|
||||||
memory_config, workspace = result
|
memory_config, workspace = result
|
||||||
|
|
||||||
# Validate embedding model
|
# Step 2: Validate embedding model (returns both UUID and name)
|
||||||
embedding_uuid = validate_embedding_model(
|
embed_start = time.time()
|
||||||
|
embedding_uuid, embedding_name = validate_embedding_model(
|
||||||
validated_config_id,
|
validated_config_id,
|
||||||
memory_config.embedding_id,
|
memory_config.embedding_id,
|
||||||
self.db,
|
self.db,
|
||||||
workspace.tenant_id,
|
workspace.tenant_id,
|
||||||
workspace.id,
|
workspace.id,
|
||||||
)
|
)
|
||||||
|
embed_time = time.time() - embed_start
|
||||||
|
logger.info(f"[PERF] Embedding validation: {embed_time:.4f}s")
|
||||||
|
|
||||||
# Resolve LLM model
|
# Step 3: Resolve LLM model
|
||||||
|
llm_start = time.time()
|
||||||
llm_uuid, llm_name = validate_and_resolve_model_id(
|
llm_uuid, llm_name = validate_and_resolve_model_id(
|
||||||
memory_config.llm_id,
|
memory_config.llm_id,
|
||||||
"llm",
|
"llm",
|
||||||
@@ -163,8 +171,11 @@ class MemoryConfigService:
|
|||||||
config_id=validated_config_id,
|
config_id=validated_config_id,
|
||||||
workspace_id=workspace.id,
|
workspace_id=workspace.id,
|
||||||
)
|
)
|
||||||
|
llm_time = time.time() - llm_start
|
||||||
|
logger.info(f"[PERF] LLM validation: {llm_time:.4f}s")
|
||||||
|
|
||||||
# Resolve optional rerank model
|
# Step 4: Resolve optional rerank model
|
||||||
|
rerank_start = time.time()
|
||||||
rerank_uuid = None
|
rerank_uuid = None
|
||||||
rerank_name = None
|
rerank_name = None
|
||||||
if memory_config.rerank_id:
|
if memory_config.rerank_id:
|
||||||
@@ -177,16 +188,12 @@ class MemoryConfigService:
|
|||||||
config_id=validated_config_id,
|
config_id=validated_config_id,
|
||||||
workspace_id=workspace.id,
|
workspace_id=workspace.id,
|
||||||
)
|
)
|
||||||
|
rerank_time = time.time() - rerank_start
|
||||||
|
if memory_config.rerank_id:
|
||||||
|
logger.info(f"[PERF] Rerank validation: {rerank_time:.4f}s")
|
||||||
|
|
||||||
# Get embedding model name
|
# Note: embedding_name is now returned from validate_embedding_model above
|
||||||
embedding_name, _ = validate_model_exists_and_active(
|
# No need for redundant query!
|
||||||
embedding_uuid,
|
|
||||||
"embedding",
|
|
||||||
self.db,
|
|
||||||
workspace.tenant_id,
|
|
||||||
config_id=validated_config_id,
|
|
||||||
workspace_id=workspace.id,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Create immutable MemoryConfig object
|
# Create immutable MemoryConfig object
|
||||||
config = MemoryConfig(
|
config = MemoryConfig(
|
||||||
|
|||||||
@@ -717,8 +717,8 @@ class MemoryInteraction:
|
|||||||
ori_data= await self.connector.execute_query(Memory_Space_Entity, id=self.id)
|
ori_data= await self.connector.execute_query(Memory_Space_Entity, id=self.id)
|
||||||
if ori_data!=[]:
|
if ori_data!=[]:
|
||||||
# name = ori_data[0]['name']
|
# name = ori_data[0]['name']
|
||||||
group_id = [i['group_id'] for i in ori_data][0]
|
end_user_id = [i['end_user_id'] for i in ori_data][0]
|
||||||
Space_User = await self.connector.execute_query(Memory_Space_User, group_id=group_id)
|
Space_User = await self.connector.execute_query(Memory_Space_User, end_user_id=end_user_id)
|
||||||
if not Space_User:
|
if not Space_User:
|
||||||
return []
|
return []
|
||||||
user_id=Space_User[0]['id']
|
user_id=Space_User[0]['id']
|
||||||
|
|||||||
@@ -34,7 +34,7 @@ class MemoryEpisodicService(MemoryBaseService):
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
summary_id: Summary节点的ID
|
summary_id: Summary节点的ID
|
||||||
end_user_id: 终端用户ID (group_id)
|
end_user_id: 终端用户ID (end_user_id)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
(标题, 类型)元组,如果不存在则返回默认值
|
(标题, 类型)元组,如果不存在则返回默认值
|
||||||
@@ -43,14 +43,14 @@ class MemoryEpisodicService(MemoryBaseService):
|
|||||||
# 查询Summary节点的name(作为title)和memory_type(作为type)
|
# 查询Summary节点的name(作为title)和memory_type(作为type)
|
||||||
query = """
|
query = """
|
||||||
MATCH (s:MemorySummary)
|
MATCH (s:MemorySummary)
|
||||||
WHERE elementId(s) = $summary_id AND s.group_id = $group_id
|
WHERE elementId(s) = $summary_id AND s.end_user_id = $end_user_id
|
||||||
RETURN s.name AS title, s.memory_type AS type
|
RETURN s.name AS title, s.memory_type AS type
|
||||||
"""
|
"""
|
||||||
|
|
||||||
result = await self.neo4j_connector.execute_query(
|
result = await self.neo4j_connector.execute_query(
|
||||||
query,
|
query,
|
||||||
summary_id=summary_id,
|
summary_id=summary_id,
|
||||||
group_id=end_user_id
|
end_user_id=end_user_id
|
||||||
)
|
)
|
||||||
|
|
||||||
if not result or len(result) == 0:
|
if not result or len(result) == 0:
|
||||||
@@ -77,7 +77,7 @@ class MemoryEpisodicService(MemoryBaseService):
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
summary_id: Summary节点的ID
|
summary_id: Summary节点的ID
|
||||||
end_user_id: 终端用户ID (group_id)
|
end_user_id: 终端用户ID (end_user_id)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
前3个实体的name属性列表
|
前3个实体的name属性列表
|
||||||
@@ -87,7 +87,7 @@ class MemoryEpisodicService(MemoryBaseService):
|
|||||||
# 按activation_value降序排序,返回前3个
|
# 按activation_value降序排序,返回前3个
|
||||||
query = """
|
query = """
|
||||||
MATCH (s:MemorySummary)
|
MATCH (s:MemorySummary)
|
||||||
WHERE elementId(s) = $summary_id AND s.group_id = $group_id
|
WHERE elementId(s) = $summary_id AND s.end_user_id = $end_user_id
|
||||||
MATCH (s)-[:DERIVED_FROM_STATEMENT]->(stmt:Statement)
|
MATCH (s)-[:DERIVED_FROM_STATEMENT]->(stmt:Statement)
|
||||||
MATCH (stmt)-[:REFERENCES_ENTITY]->(entity:ExtractedEntity)
|
MATCH (stmt)-[:REFERENCES_ENTITY]->(entity:ExtractedEntity)
|
||||||
WHERE entity.activation_value IS NOT NULL
|
WHERE entity.activation_value IS NOT NULL
|
||||||
@@ -99,7 +99,7 @@ class MemoryEpisodicService(MemoryBaseService):
|
|||||||
result = await self.neo4j_connector.execute_query(
|
result = await self.neo4j_connector.execute_query(
|
||||||
query,
|
query,
|
||||||
summary_id=summary_id,
|
summary_id=summary_id,
|
||||||
group_id=end_user_id
|
end_user_id=end_user_id
|
||||||
)
|
)
|
||||||
|
|
||||||
# 提取实体名称
|
# 提取实体名称
|
||||||
@@ -123,7 +123,7 @@ class MemoryEpisodicService(MemoryBaseService):
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
summary_id: Summary节点的ID
|
summary_id: Summary节点的ID
|
||||||
end_user_id: 终端用户ID (group_id)
|
end_user_id: 终端用户ID (end_user_id)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
所有Statement节点的statement属性内容列表
|
所有Statement节点的statement属性内容列表
|
||||||
@@ -132,7 +132,7 @@ class MemoryEpisodicService(MemoryBaseService):
|
|||||||
# 查询Summary节点指向的所有Statement节点
|
# 查询Summary节点指向的所有Statement节点
|
||||||
query = """
|
query = """
|
||||||
MATCH (s:MemorySummary)
|
MATCH (s:MemorySummary)
|
||||||
WHERE elementId(s) = $summary_id AND s.group_id = $group_id
|
WHERE elementId(s) = $summary_id AND s.end_user_id = $end_user_id
|
||||||
MATCH (s)-[:DERIVED_FROM_STATEMENT]->(stmt:Statement)
|
MATCH (s)-[:DERIVED_FROM_STATEMENT]->(stmt:Statement)
|
||||||
WHERE stmt.statement IS NOT NULL AND stmt.statement <> ''
|
WHERE stmt.statement IS NOT NULL AND stmt.statement <> ''
|
||||||
RETURN stmt.statement AS statement
|
RETURN stmt.statement AS statement
|
||||||
@@ -141,7 +141,7 @@ class MemoryEpisodicService(MemoryBaseService):
|
|||||||
result = await self.neo4j_connector.execute_query(
|
result = await self.neo4j_connector.execute_query(
|
||||||
query,
|
query,
|
||||||
summary_id=summary_id,
|
summary_id=summary_id,
|
||||||
group_id=end_user_id
|
end_user_id=end_user_id
|
||||||
)
|
)
|
||||||
|
|
||||||
# 提取statement内容
|
# 提取statement内容
|
||||||
@@ -214,12 +214,12 @@ class MemoryEpisodicService(MemoryBaseService):
|
|||||||
# 1. 先查询所有情景记忆的总数(不受筛选条件限制)
|
# 1. 先查询所有情景记忆的总数(不受筛选条件限制)
|
||||||
total_all_query = """
|
total_all_query = """
|
||||||
MATCH (s:MemorySummary)
|
MATCH (s:MemorySummary)
|
||||||
WHERE s.group_id = $group_id
|
WHERE s.end_user_id = $end_user_id
|
||||||
RETURN count(s) AS total_all
|
RETURN count(s) AS total_all
|
||||||
"""
|
"""
|
||||||
total_all_result = await self.neo4j_connector.execute_query(
|
total_all_result = await self.neo4j_connector.execute_query(
|
||||||
total_all_query,
|
total_all_query,
|
||||||
group_id=end_user_id
|
end_user_id=end_user_id
|
||||||
)
|
)
|
||||||
total_all = total_all_result[0]["total_all"] if total_all_result else 0
|
total_all = total_all_result[0]["total_all"] if total_all_result else 0
|
||||||
|
|
||||||
@@ -229,7 +229,7 @@ class MemoryEpisodicService(MemoryBaseService):
|
|||||||
# 3. 构建Cypher查询
|
# 3. 构建Cypher查询
|
||||||
query = """
|
query = """
|
||||||
MATCH (s:MemorySummary)
|
MATCH (s:MemorySummary)
|
||||||
WHERE s.group_id = $group_id
|
WHERE s.end_user_id = $end_user_id
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# 添加时间范围过滤
|
# 添加时间范围过滤
|
||||||
@@ -248,7 +248,7 @@ class MemoryEpisodicService(MemoryBaseService):
|
|||||||
ORDER BY s.created_at DESC
|
ORDER BY s.created_at DESC
|
||||||
"""
|
"""
|
||||||
|
|
||||||
params = {"group_id": end_user_id}
|
params = {"end_user_id": end_user_id}
|
||||||
if time_filter:
|
if time_filter:
|
||||||
params["time_filter"] = time_filter
|
params["time_filter"] = time_filter
|
||||||
if title_keyword:
|
if title_keyword:
|
||||||
@@ -333,14 +333,14 @@ class MemoryEpisodicService(MemoryBaseService):
|
|||||||
# 1. 查询指定的MemorySummary节点
|
# 1. 查询指定的MemorySummary节点
|
||||||
query = """
|
query = """
|
||||||
MATCH (s:MemorySummary)
|
MATCH (s:MemorySummary)
|
||||||
WHERE elementId(s) = $summary_id AND s.group_id = $group_id
|
WHERE elementId(s) = $summary_id AND s.end_user_id = $end_user_id
|
||||||
RETURN elementId(s) AS id, s.created_at AS created_at
|
RETURN elementId(s) AS id, s.created_at AS created_at
|
||||||
"""
|
"""
|
||||||
|
|
||||||
result = await self.neo4j_connector.execute_query(
|
result = await self.neo4j_connector.execute_query(
|
||||||
query,
|
query,
|
||||||
summary_id=summary_id,
|
summary_id=summary_id,
|
||||||
group_id=end_user_id
|
end_user_id=end_user_id
|
||||||
)
|
)
|
||||||
|
|
||||||
# 2. 如果节点不存在,返回错误
|
# 2. 如果节点不存在,返回错误
|
||||||
|
|||||||
@@ -60,7 +60,7 @@ class MemoryExplicitService(MemoryBaseService):
|
|||||||
# ========== 1. 查询情景记忆(MemorySummary节点) ==========
|
# ========== 1. 查询情景记忆(MemorySummary节点) ==========
|
||||||
episodic_query = """
|
episodic_query = """
|
||||||
MATCH (s:MemorySummary)
|
MATCH (s:MemorySummary)
|
||||||
WHERE s.group_id = $group_id
|
WHERE s.end_user_id = $end_user_id
|
||||||
RETURN elementId(s) AS id,
|
RETURN elementId(s) AS id,
|
||||||
s.name AS title,
|
s.name AS title,
|
||||||
s.content AS content,
|
s.content AS content,
|
||||||
@@ -70,7 +70,7 @@ class MemoryExplicitService(MemoryBaseService):
|
|||||||
|
|
||||||
episodic_result = await self.neo4j_connector.execute_query(
|
episodic_result = await self.neo4j_connector.execute_query(
|
||||||
episodic_query,
|
episodic_query,
|
||||||
group_id=end_user_id
|
end_user_id=end_user_id
|
||||||
)
|
)
|
||||||
|
|
||||||
# 处理情景记忆数据
|
# 处理情景记忆数据
|
||||||
@@ -96,7 +96,7 @@ class MemoryExplicitService(MemoryBaseService):
|
|||||||
# ========== 2. 查询语义记忆(ExtractedEntity节点) ==========
|
# ========== 2. 查询语义记忆(ExtractedEntity节点) ==========
|
||||||
semantic_query = """
|
semantic_query = """
|
||||||
MATCH (e:ExtractedEntity)
|
MATCH (e:ExtractedEntity)
|
||||||
WHERE e.group_id = $group_id
|
WHERE e.end_user_id = $end_user_id
|
||||||
AND e.is_explicit_memory = true
|
AND e.is_explicit_memory = true
|
||||||
RETURN elementId(e) AS id,
|
RETURN elementId(e) AS id,
|
||||||
e.name AS name,
|
e.name AS name,
|
||||||
@@ -107,7 +107,7 @@ class MemoryExplicitService(MemoryBaseService):
|
|||||||
|
|
||||||
semantic_result = await self.neo4j_connector.execute_query(
|
semantic_result = await self.neo4j_connector.execute_query(
|
||||||
semantic_query,
|
semantic_query,
|
||||||
group_id=end_user_id
|
end_user_id=end_user_id
|
||||||
)
|
)
|
||||||
|
|
||||||
# 处理语义记忆数据
|
# 处理语义记忆数据
|
||||||
@@ -189,7 +189,7 @@ class MemoryExplicitService(MemoryBaseService):
|
|||||||
# ========== 1. 先尝试查询情景记忆 ==========
|
# ========== 1. 先尝试查询情景记忆 ==========
|
||||||
episodic_query = """
|
episodic_query = """
|
||||||
MATCH (s:MemorySummary)
|
MATCH (s:MemorySummary)
|
||||||
WHERE elementId(s) = $memory_id AND s.group_id = $group_id
|
WHERE elementId(s) = $memory_id AND s.end_user_id = $end_user_id
|
||||||
RETURN s.name AS title,
|
RETURN s.name AS title,
|
||||||
s.content AS content,
|
s.content AS content,
|
||||||
s.created_at AS created_at
|
s.created_at AS created_at
|
||||||
@@ -198,7 +198,7 @@ class MemoryExplicitService(MemoryBaseService):
|
|||||||
episodic_result = await self.neo4j_connector.execute_query(
|
episodic_result = await self.neo4j_connector.execute_query(
|
||||||
episodic_query,
|
episodic_query,
|
||||||
memory_id=memory_id,
|
memory_id=memory_id,
|
||||||
group_id=end_user_id
|
end_user_id=end_user_id
|
||||||
)
|
)
|
||||||
|
|
||||||
if episodic_result and len(episodic_result) > 0:
|
if episodic_result and len(episodic_result) > 0:
|
||||||
@@ -229,7 +229,7 @@ class MemoryExplicitService(MemoryBaseService):
|
|||||||
semantic_query = """
|
semantic_query = """
|
||||||
MATCH (e:ExtractedEntity)
|
MATCH (e:ExtractedEntity)
|
||||||
WHERE elementId(e) = $memory_id
|
WHERE elementId(e) = $memory_id
|
||||||
AND e.group_id = $group_id
|
AND e.end_user_id = $end_user_id
|
||||||
AND e.is_explicit_memory = true
|
AND e.is_explicit_memory = true
|
||||||
RETURN e.name AS name,
|
RETURN e.name AS name,
|
||||||
e.description AS core_definition,
|
e.description AS core_definition,
|
||||||
@@ -240,7 +240,7 @@ class MemoryExplicitService(MemoryBaseService):
|
|||||||
semantic_result = await self.neo4j_connector.execute_query(
|
semantic_result = await self.neo4j_connector.execute_query(
|
||||||
semantic_query,
|
semantic_query,
|
||||||
memory_id=memory_id,
|
memory_id=memory_id,
|
||||||
group_id=end_user_id
|
end_user_id=end_user_id
|
||||||
)
|
)
|
||||||
|
|
||||||
if semantic_result and len(semantic_result) > 0:
|
if semantic_result and len(semantic_result) > 0:
|
||||||
|
|||||||
@@ -132,7 +132,7 @@ class MemoryForgetService:
|
|||||||
async def _get_knowledge_stats(
|
async def _get_knowledge_stats(
|
||||||
self,
|
self,
|
||||||
connector: Neo4jConnector,
|
connector: Neo4jConnector,
|
||||||
group_id: Optional[str] = None,
|
end_user_id: Optional[str] = None,
|
||||||
forgetting_threshold: float = 0.3
|
forgetting_threshold: float = 0.3
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
@@ -140,7 +140,7 @@ class MemoryForgetService:
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
connector: Neo4j 连接器
|
connector: Neo4j 连接器
|
||||||
group_id: 组ID(可选)
|
end_user_id: 组ID(可选)
|
||||||
forgetting_threshold: 遗忘阈值
|
forgetting_threshold: 遗忘阈值
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@@ -152,8 +152,8 @@ class MemoryForgetService:
|
|||||||
WHERE (n:Statement OR n:ExtractedEntity OR n:MemorySummary)
|
WHERE (n:Statement OR n:ExtractedEntity OR n:MemorySummary)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if group_id:
|
if end_user_id:
|
||||||
query += " AND n.group_id = $group_id"
|
query += " AND n.end_user_id = $end_user_id"
|
||||||
|
|
||||||
query += """
|
query += """
|
||||||
WITH n,
|
WITH n,
|
||||||
@@ -172,8 +172,8 @@ class MemoryForgetService:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
params = {'threshold': forgetting_threshold}
|
params = {'threshold': forgetting_threshold}
|
||||||
if group_id:
|
if end_user_id:
|
||||||
params['group_id'] = group_id
|
params['end_user_id'] = end_user_id
|
||||||
|
|
||||||
results = await connector.execute_query(query, **params)
|
results = await connector.execute_query(query, **params)
|
||||||
|
|
||||||
@@ -200,7 +200,7 @@ class MemoryForgetService:
|
|||||||
async def _get_pending_forgetting_nodes(
|
async def _get_pending_forgetting_nodes(
|
||||||
self,
|
self,
|
||||||
connector: Neo4jConnector,
|
connector: Neo4jConnector,
|
||||||
group_id: str,
|
end_user_id: str,
|
||||||
forgetting_threshold: float,
|
forgetting_threshold: float,
|
||||||
min_days_since_access: int,
|
min_days_since_access: int,
|
||||||
limit: int = 20
|
limit: int = 20
|
||||||
@@ -212,7 +212,7 @@ class MemoryForgetService:
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
connector: Neo4j 连接器
|
connector: Neo4j 连接器
|
||||||
group_id: 组ID
|
end_user_id: 组ID
|
||||||
forgetting_threshold: 遗忘阈值
|
forgetting_threshold: 遗忘阈值
|
||||||
min_days_since_access: 最小未访问天数
|
min_days_since_access: 最小未访问天数
|
||||||
limit: 返回节点数量限制
|
limit: 返回节点数量限制
|
||||||
@@ -229,7 +229,7 @@ class MemoryForgetService:
|
|||||||
query = """
|
query = """
|
||||||
MATCH (n)
|
MATCH (n)
|
||||||
WHERE (n:Statement OR n:ExtractedEntity OR n:MemorySummary)
|
WHERE (n:Statement OR n:ExtractedEntity OR n:MemorySummary)
|
||||||
AND n.group_id = $group_id
|
AND n.end_user_id = $end_user_id
|
||||||
AND n.activation_value IS NOT NULL
|
AND n.activation_value IS NOT NULL
|
||||||
AND n.activation_value < $threshold
|
AND n.activation_value < $threshold
|
||||||
AND n.last_access_time IS NOT NULL
|
AND n.last_access_time IS NOT NULL
|
||||||
@@ -250,7 +250,7 @@ class MemoryForgetService:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
params = {
|
params = {
|
||||||
'group_id': group_id,
|
'end_user_id': end_user_id,
|
||||||
'threshold': forgetting_threshold,
|
'threshold': forgetting_threshold,
|
||||||
'min_access_time_str': min_access_time_str,
|
'min_access_time_str': min_access_time_str,
|
||||||
'limit': limit
|
'limit': limit
|
||||||
@@ -291,7 +291,7 @@ class MemoryForgetService:
|
|||||||
async def trigger_forgetting_cycle(
|
async def trigger_forgetting_cycle(
|
||||||
self,
|
self,
|
||||||
db: Session,
|
db: Session,
|
||||||
group_id: str,
|
end_user_id: str,
|
||||||
max_merge_batch_size: Optional[int] = None,
|
max_merge_batch_size: Optional[int] = None,
|
||||||
min_days_since_access: Optional[int] = None,
|
min_days_since_access: Optional[int] = None,
|
||||||
config_id: Optional[int] = None
|
config_id: Optional[int] = None
|
||||||
@@ -303,10 +303,10 @@ class MemoryForgetService:
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
db: 数据库会话
|
db: 数据库会话
|
||||||
group_id: 组ID(即终端用户ID,必填)
|
end_user_id: 组ID(即终端用户ID,必填)
|
||||||
max_merge_batch_size: 最大融合批次大小(可选)
|
max_merge_batch_size: 最大融合批次大小(可选)
|
||||||
min_days_since_access: 最小未访问天数(可选)
|
min_days_since_access: 最小未访问天数(可选)
|
||||||
config_id: 配置ID(必填,由控制器层通过 group_id 获取)
|
config_id: 配置ID(必填,由控制器层通过 end_user_id 获取)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
dict: 遗忘报告
|
dict: 遗忘报告
|
||||||
@@ -319,7 +319,7 @@ class MemoryForgetService:
|
|||||||
|
|
||||||
# 运行遗忘周期(LLM 客户端将在需要时由 forgetting_strategy 内部获取)
|
# 运行遗忘周期(LLM 客户端将在需要时由 forgetting_strategy 内部获取)
|
||||||
report = await forgetting_scheduler.run_forgetting_cycle(
|
report = await forgetting_scheduler.run_forgetting_cycle(
|
||||||
group_id=group_id,
|
end_user_id=end_user_id,
|
||||||
max_merge_batch_size=max_merge_batch_size,
|
max_merge_batch_size=max_merge_batch_size,
|
||||||
min_days_since_access=min_days_since_access,
|
min_days_since_access=min_days_since_access,
|
||||||
config_id=config_id,
|
config_id=config_id,
|
||||||
@@ -338,7 +338,7 @@ class MemoryForgetService:
|
|||||||
stats_query = """
|
stats_query = """
|
||||||
MATCH (n)
|
MATCH (n)
|
||||||
WHERE (n:Statement OR n:ExtractedEntity OR n:MemorySummary OR n:Chunk)
|
WHERE (n:Statement OR n:ExtractedEntity OR n:MemorySummary OR n:Chunk)
|
||||||
AND n.group_id = $group_id
|
AND n.end_user_id = $end_user_id
|
||||||
RETURN
|
RETURN
|
||||||
count(n) as total_nodes,
|
count(n) as total_nodes,
|
||||||
avg(n.activation_value) as average_activation,
|
avg(n.activation_value) as average_activation,
|
||||||
@@ -347,7 +347,7 @@ class MemoryForgetService:
|
|||||||
|
|
||||||
stats_results = await connector.execute_query(
|
stats_results = await connector.execute_query(
|
||||||
stats_query,
|
stats_query,
|
||||||
group_id=group_id,
|
end_user_id=end_user_id,
|
||||||
threshold=config['forgetting_threshold']
|
threshold=config['forgetting_threshold']
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -364,7 +364,7 @@ class MemoryForgetService:
|
|||||||
# 保存历史记录到数据库
|
# 保存历史记录到数据库
|
||||||
self.history_repository.create(
|
self.history_repository.create(
|
||||||
db=db,
|
db=db,
|
||||||
end_user_id=group_id,
|
end_user_id=end_user_id,
|
||||||
execution_time=execution_time,
|
execution_time=execution_time,
|
||||||
merged_count=report['merged_count'],
|
merged_count=report['merged_count'],
|
||||||
failed_count=report['failed_count'],
|
failed_count=report['failed_count'],
|
||||||
@@ -376,7 +376,7 @@ class MemoryForgetService:
|
|||||||
)
|
)
|
||||||
|
|
||||||
api_logger.info(
|
api_logger.info(
|
||||||
f"已保存遗忘周期历史记录: end_user_id={group_id}, "
|
f"已保存遗忘周期历史记录: end_user_id={end_user_id}, "
|
||||||
f"merged_count={report['merged_count']}"
|
f"merged_count={report['merged_count']}"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -465,7 +465,7 @@ class MemoryForgetService:
|
|||||||
async def get_forgetting_stats(
|
async def get_forgetting_stats(
|
||||||
self,
|
self,
|
||||||
db: Session,
|
db: Session,
|
||||||
group_id: Optional[str] = None,
|
end_user_id: Optional[str] = None,
|
||||||
config_id: Optional[int] = None
|
config_id: Optional[int] = None
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
@@ -475,7 +475,7 @@ class MemoryForgetService:
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
db: 数据库会话
|
db: 数据库会话
|
||||||
group_id: 组ID(可选)
|
end_user_id: 组ID(可选)
|
||||||
config_id: 配置ID(可选,用于获取遗忘阈值)
|
config_id: 配置ID(可选,用于获取遗忘阈值)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@@ -493,8 +493,8 @@ class MemoryForgetService:
|
|||||||
WHERE (n:Statement OR n:ExtractedEntity OR n:MemorySummary OR n:Chunk)
|
WHERE (n:Statement OR n:ExtractedEntity OR n:MemorySummary OR n:Chunk)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if group_id:
|
if end_user_id:
|
||||||
activation_query += " AND n.group_id = $group_id"
|
activation_query += " AND n.end_user_id = $end_user_id"
|
||||||
|
|
||||||
activation_query += """
|
activation_query += """
|
||||||
RETURN
|
RETURN
|
||||||
@@ -506,8 +506,8 @@ class MemoryForgetService:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
params = {'threshold': forgetting_threshold}
|
params = {'threshold': forgetting_threshold}
|
||||||
if group_id:
|
if end_user_id:
|
||||||
params['group_id'] = group_id
|
params['end_user_id'] = end_user_id
|
||||||
|
|
||||||
activation_results = await connector.execute_query(activation_query, **params)
|
activation_results = await connector.execute_query(activation_query, **params)
|
||||||
|
|
||||||
@@ -539,8 +539,8 @@ class MemoryForgetService:
|
|||||||
WHERE (n:Statement OR n:ExtractedEntity OR n:MemorySummary OR n:Chunk)
|
WHERE (n:Statement OR n:ExtractedEntity OR n:MemorySummary OR n:Chunk)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if group_id:
|
if end_user_id:
|
||||||
distribution_query += " AND n.group_id = $group_id"
|
distribution_query += " AND n.end_user_id = $end_user_id"
|
||||||
|
|
||||||
distribution_query += """
|
distribution_query += """
|
||||||
WITH n,
|
WITH n,
|
||||||
@@ -558,8 +558,8 @@ class MemoryForgetService:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
dist_params = {}
|
dist_params = {}
|
||||||
if group_id:
|
if end_user_id:
|
||||||
dist_params['group_id'] = group_id
|
dist_params['end_user_id'] = end_user_id
|
||||||
|
|
||||||
distribution_results = await connector.execute_query(distribution_query, **dist_params)
|
distribution_results = await connector.execute_query(distribution_query, **dist_params)
|
||||||
|
|
||||||
@@ -582,11 +582,11 @@ class MemoryForgetService:
|
|||||||
# 获取最近7个日期的历史趋势数据(每天取最后一次执行)
|
# 获取最近7个日期的历史趋势数据(每天取最后一次执行)
|
||||||
recent_trends = []
|
recent_trends = []
|
||||||
try:
|
try:
|
||||||
if group_id:
|
if end_user_id:
|
||||||
# 查询所有历史记录
|
# 查询所有历史记录
|
||||||
history_records = self.history_repository.get_recent_by_end_user(
|
history_records = self.history_repository.get_recent_by_end_user(
|
||||||
db=db,
|
db=db,
|
||||||
end_user_id=group_id
|
end_user_id=end_user_id
|
||||||
)
|
)
|
||||||
|
|
||||||
# 按日期分组(一天可能有多次执行,取最后一次)
|
# 按日期分组(一天可能有多次执行,取最后一次)
|
||||||
@@ -632,7 +632,7 @@ class MemoryForgetService:
|
|||||||
# 获取待遗忘节点列表(前20个满足遗忘条件的节点)
|
# 获取待遗忘节点列表(前20个满足遗忘条件的节点)
|
||||||
pending_nodes = []
|
pending_nodes = []
|
||||||
try:
|
try:
|
||||||
if group_id:
|
if end_user_id:
|
||||||
# 验证 min_days_since_access 配置值
|
# 验证 min_days_since_access 配置值
|
||||||
min_days = config.get('min_days_since_access')
|
min_days = config.get('min_days_since_access')
|
||||||
if min_days is None or not isinstance(min_days, (int, float)) or min_days < 0:
|
if min_days is None or not isinstance(min_days, (int, float)) or min_days < 0:
|
||||||
@@ -643,7 +643,7 @@ class MemoryForgetService:
|
|||||||
|
|
||||||
pending_nodes = await self._get_pending_forgetting_nodes(
|
pending_nodes = await self._get_pending_forgetting_nodes(
|
||||||
connector=connector,
|
connector=connector,
|
||||||
group_id=group_id,
|
end_user_id=end_user_id,
|
||||||
forgetting_threshold=forgetting_threshold,
|
forgetting_threshold=forgetting_threshold,
|
||||||
min_days_since_access=int(min_days),
|
min_days_since_access=int(min_days),
|
||||||
limit=20
|
limit=20
|
||||||
|
|||||||
@@ -450,12 +450,12 @@ async def create_document_chunk(
|
|||||||
|
|
||||||
return success(data=chunk, msg="文档块创建成功")
|
return success(data=chunk, msg="文档块创建成功")
|
||||||
|
|
||||||
async def write_rag(group_id, message, user_rag_memory_id):
|
async def write_rag(end_user_id, message, user_rag_memory_id):
|
||||||
"""
|
"""
|
||||||
将消息写入 RAG 知识库
|
将消息写入 RAG 知识库
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
group_id: 组ID,用作文件标题
|
end_user_id: 组ID,用作文件标题
|
||||||
message: 消息内容
|
message: 消息内容
|
||||||
user_rag_memory_id: 知识库ID(必须是有效的UUID)
|
user_rag_memory_id: 知识库ID(必须是有效的UUID)
|
||||||
|
|
||||||
@@ -487,10 +487,10 @@ async def write_rag(group_id, message, user_rag_memory_id):
|
|||||||
db = next(db_gen)
|
db = next(db_gen)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
create_data = CustomTextFileCreate(title=group_id, content=message)
|
create_data = CustomTextFileCreate(title=end_user_id, content=message)
|
||||||
current_user = SimpleUser(user_rag_memory_id)
|
current_user = SimpleUser(user_rag_memory_id)
|
||||||
# 检查文档是否已存在
|
# 检查文档是否已存在
|
||||||
document = find_document_id_by_kb_and_filename(db=db, kb_id=user_rag_memory_id, file_name=f"{group_id}.txt")
|
document = find_document_id_by_kb_and_filename(db=db, kb_id=user_rag_memory_id, file_name=f"{end_user_id}.txt")
|
||||||
print('======',document)
|
print('======',document)
|
||||||
api_logger.info(f"查找文档结果: document_id={document}")
|
api_logger.info(f"查找文档结果: document_id={document}")
|
||||||
if document is not None:
|
if document is not None:
|
||||||
@@ -508,7 +508,7 @@ async def write_rag(group_id, message, user_rag_memory_id):
|
|||||||
return result
|
return result
|
||||||
else:
|
else:
|
||||||
# 文档不存在,创建新文档
|
# 文档不存在,创建新文档
|
||||||
api_logger.info(f"文档不存在,创建新文档: group_id={group_id}")
|
api_logger.info(f"文档不存在,创建新文档: end_user_id={end_user_id}")
|
||||||
result = await memory_konwledges_up(
|
result = await memory_konwledges_up(
|
||||||
kb_id=user_rag_memory_id,
|
kb_id=user_rag_memory_id,
|
||||||
parent_id=user_rag_memory_id,
|
parent_id=user_rag_memory_id,
|
||||||
@@ -520,13 +520,13 @@ async def write_rag(group_id, message, user_rag_memory_id):
|
|||||||
new_document_id = find_document_id_by_kb_and_filename(
|
new_document_id = find_document_id_by_kb_and_filename(
|
||||||
db=db,
|
db=db,
|
||||||
kb_id=user_rag_memory_id,
|
kb_id=user_rag_memory_id,
|
||||||
file_name=f"{group_id}.txt"
|
file_name=f"{end_user_id}.txt"
|
||||||
)
|
)
|
||||||
|
|
||||||
if new_document_id:
|
if new_document_id:
|
||||||
await parse_document_by_id(new_document_id, db=db, current_user=current_user)
|
await parse_document_by_id(new_document_id, db=db, current_user=current_user)
|
||||||
else:
|
else:
|
||||||
api_logger.error(f"创建文档后无法找到文档ID: group_id={group_id}")
|
api_logger.error(f"创建文档后无法找到文档ID: end_user_id={end_user_id}")
|
||||||
return result
|
return result
|
||||||
finally:
|
finally:
|
||||||
# 确保数据库会话被关闭
|
# 确保数据库会话被关闭
|
||||||
|
|||||||
@@ -183,7 +183,7 @@ class DataConfigService: # 数据配置服务类(PostgreSQL)
|
|||||||
"config_name": config.config_name,
|
"config_name": config.config_name,
|
||||||
"config_desc": config.config_desc,
|
"config_desc": config.config_desc,
|
||||||
"workspace_id": str(config.workspace_id) if config.workspace_id else None,
|
"workspace_id": str(config.workspace_id) if config.workspace_id else None,
|
||||||
"group_id": config.group_id,
|
"end_user_id": config.end_user_id,
|
||||||
"user_id": config.user_id,
|
"user_id": config.user_id,
|
||||||
"apply_id": config.apply_id,
|
"apply_id": config.apply_id,
|
||||||
"llm_id": config.llm_id,
|
"llm_id": config.llm_id,
|
||||||
@@ -391,7 +391,7 @@ _neo4j_connector = Neo4jConnector()
|
|||||||
async def search_dialogue(end_user_id: Optional[str] = None) -> Dict[str, Any]:
|
async def search_dialogue(end_user_id: Optional[str] = None) -> Dict[str, Any]:
|
||||||
result = await _neo4j_connector.execute_query(
|
result = await _neo4j_connector.execute_query(
|
||||||
DataConfigRepository.SEARCH_FOR_DIALOGUE,
|
DataConfigRepository.SEARCH_FOR_DIALOGUE,
|
||||||
group_id=end_user_id,
|
end_user_id=end_user_id,
|
||||||
)
|
)
|
||||||
data = {"search_for": "dialogue", "num": result[0]["num"]}
|
data = {"search_for": "dialogue", "num": result[0]["num"]}
|
||||||
return data
|
return data
|
||||||
@@ -400,7 +400,7 @@ async def search_dialogue(end_user_id: Optional[str] = None) -> Dict[str, Any]:
|
|||||||
async def search_chunk(end_user_id: Optional[str] = None) -> Dict[str, Any]:
|
async def search_chunk(end_user_id: Optional[str] = None) -> Dict[str, Any]:
|
||||||
result = await _neo4j_connector.execute_query(
|
result = await _neo4j_connector.execute_query(
|
||||||
DataConfigRepository.SEARCH_FOR_CHUNK,
|
DataConfigRepository.SEARCH_FOR_CHUNK,
|
||||||
group_id=end_user_id,
|
end_user_id=end_user_id,
|
||||||
)
|
)
|
||||||
data = {"search_for": "chunk", "num": result[0]["num"]}
|
data = {"search_for": "chunk", "num": result[0]["num"]}
|
||||||
return data
|
return data
|
||||||
@@ -409,7 +409,7 @@ async def search_chunk(end_user_id: Optional[str] = None) -> Dict[str, Any]:
|
|||||||
async def search_statement(end_user_id: Optional[str] = None) -> Dict[str, Any]:
|
async def search_statement(end_user_id: Optional[str] = None) -> Dict[str, Any]:
|
||||||
result = await _neo4j_connector.execute_query(
|
result = await _neo4j_connector.execute_query(
|
||||||
DataConfigRepository.SEARCH_FOR_STATEMENT,
|
DataConfigRepository.SEARCH_FOR_STATEMENT,
|
||||||
group_id=end_user_id,
|
end_user_id=end_user_id,
|
||||||
)
|
)
|
||||||
data = {"search_for": "statement", "num": result[0]["num"]}
|
data = {"search_for": "statement", "num": result[0]["num"]}
|
||||||
return data
|
return data
|
||||||
@@ -418,7 +418,7 @@ async def search_statement(end_user_id: Optional[str] = None) -> Dict[str, Any]:
|
|||||||
async def search_entity(end_user_id: Optional[str] = None) -> Dict[str, Any]:
|
async def search_entity(end_user_id: Optional[str] = None) -> Dict[str, Any]:
|
||||||
result = await _neo4j_connector.execute_query(
|
result = await _neo4j_connector.execute_query(
|
||||||
DataConfigRepository.SEARCH_FOR_ENTITY,
|
DataConfigRepository.SEARCH_FOR_ENTITY,
|
||||||
group_id=end_user_id,
|
end_user_id=end_user_id,
|
||||||
)
|
)
|
||||||
data = {"search_for": "entity", "num": result[0]["num"]}
|
data = {"search_for": "entity", "num": result[0]["num"]}
|
||||||
return data
|
return data
|
||||||
@@ -427,7 +427,7 @@ async def search_entity(end_user_id: Optional[str] = None) -> Dict[str, Any]:
|
|||||||
async def search_all(end_user_id: Optional[str] = None) -> Dict[str, Any]:
|
async def search_all(end_user_id: Optional[str] = None) -> Dict[str, Any]:
|
||||||
result = await _neo4j_connector.execute_query(
|
result = await _neo4j_connector.execute_query(
|
||||||
DataConfigRepository.SEARCH_FOR_ALL,
|
DataConfigRepository.SEARCH_FOR_ALL,
|
||||||
group_id=end_user_id,
|
end_user_id=end_user_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 检查结果是否为空或长度不足
|
# 检查结果是否为空或长度不足
|
||||||
@@ -462,7 +462,7 @@ async def kb_type_distribution(end_user_id: Optional[str] = None) -> Dict[str, A
|
|||||||
"""
|
"""
|
||||||
result = await _neo4j_connector.execute_query(
|
result = await _neo4j_connector.execute_query(
|
||||||
DataConfigRepository.SEARCH_FOR_ALL,
|
DataConfigRepository.SEARCH_FOR_ALL,
|
||||||
group_id=end_user_id,
|
end_user_id=end_user_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 检查结果是否为空或长度不足
|
# 检查结果是否为空或长度不足
|
||||||
@@ -493,7 +493,7 @@ async def kb_type_distribution(end_user_id: Optional[str] = None) -> Dict[str, A
|
|||||||
async def search_detials(end_user_id: Optional[str] = None) -> List[Dict[str, Any]]:
|
async def search_detials(end_user_id: Optional[str] = None) -> List[Dict[str, Any]]:
|
||||||
result = await _neo4j_connector.execute_query(
|
result = await _neo4j_connector.execute_query(
|
||||||
DataConfigRepository.SEARCH_FOR_DETIALS,
|
DataConfigRepository.SEARCH_FOR_DETIALS,
|
||||||
group_id=end_user_id,
|
end_user_id=end_user_id,
|
||||||
)
|
)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
@@ -501,7 +501,7 @@ async def search_detials(end_user_id: Optional[str] = None) -> List[Dict[str, An
|
|||||||
async def search_edges(end_user_id: Optional[str] = None) -> List[Dict[str, Any]]:
|
async def search_edges(end_user_id: Optional[str] = None) -> List[Dict[str, Any]]:
|
||||||
result = await _neo4j_connector.execute_query(
|
result = await _neo4j_connector.execute_query(
|
||||||
DataConfigRepository.SEARCH_FOR_EDGES,
|
DataConfigRepository.SEARCH_FOR_EDGES,
|
||||||
group_id=end_user_id,
|
end_user_id=end_user_id,
|
||||||
)
|
)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
@@ -510,7 +510,7 @@ async def search_entity_graph(end_user_id: Optional[str] = None) -> Dict[str, An
|
|||||||
"""搜索所有实体之间的关系网络(group 维度)。"""
|
"""搜索所有实体之间的关系网络(group 维度)。"""
|
||||||
result = await _neo4j_connector.execute_query(
|
result = await _neo4j_connector.execute_query(
|
||||||
DataConfigRepository.SEARCH_FOR_ENTITY_GRAPH,
|
DataConfigRepository.SEARCH_FOR_ENTITY_GRAPH,
|
||||||
group_id=end_user_id,
|
end_user_id=end_user_id,
|
||||||
)
|
)
|
||||||
# 对source_node 和 target_node 的 fact_summary进行截取,只截取前三条的内容(需要提取前三条“来源”)
|
# 对source_node 和 target_node 的 fact_summary进行截取,只截取前三条的内容(需要提取前三条“来源”)
|
||||||
for item in result:
|
for item in result:
|
||||||
|
|||||||
@@ -91,7 +91,7 @@ async def run_pilot_extraction(
|
|||||||
dialog = DialogData(
|
dialog = DialogData(
|
||||||
context=context,
|
context=context,
|
||||||
ref_id="pilot_dialog_1",
|
ref_id="pilot_dialog_1",
|
||||||
group_id=str(memory_config.workspace_id),
|
end_user_id=str(memory_config.workspace_id),
|
||||||
user_id=str(memory_config.tenant_id),
|
user_id=str(memory_config.tenant_id),
|
||||||
apply_id=str(memory_config.config_id),
|
apply_id=str(memory_config.config_id),
|
||||||
metadata={"source": "pilot_run", "input_type": "frontend_text"},
|
metadata={"source": "pilot_run", "input_type": "frontend_text"},
|
||||||
|
|||||||
@@ -155,10 +155,10 @@ class MemoryInsightHelper:
|
|||||||
"""
|
"""
|
||||||
query = """
|
query = """
|
||||||
MATCH (d:Dialogue)
|
MATCH (d:Dialogue)
|
||||||
WHERE d.group_id = $group_id AND d.created_at IS NOT NULL AND d.created_at <> ''
|
WHERE d.end_user_id = $end_user_id AND d.created_at IS NOT NULL AND d.created_at <> ''
|
||||||
RETURN d.created_at AS creation_time
|
RETURN d.created_at AS creation_time
|
||||||
"""
|
"""
|
||||||
records = await self.neo4j_connector.execute_query(query, group_id=self.user_id)
|
records = await self.neo4j_connector.execute_query(query, end_user_id=self.user_id)
|
||||||
|
|
||||||
if not records:
|
if not records:
|
||||||
return []
|
return []
|
||||||
@@ -211,17 +211,17 @@ class MemoryInsightHelper:
|
|||||||
async def get_social_connections(self) -> dict | None:
|
async def get_social_connections(self) -> dict | None:
|
||||||
"""Find the user with whom the most memories are shared."""
|
"""Find the user with whom the most memories are shared."""
|
||||||
query = """
|
query = """
|
||||||
MATCH (c1:Chunk {group_id: $group_id})
|
MATCH (c1:Chunk {end_user_id: $end_user_id})
|
||||||
OPTIONAL MATCH (c1)-[:CONTAINS]->(s:Statement)
|
OPTIONAL MATCH (c1)-[:CONTAINS]->(s:Statement)
|
||||||
OPTIONAL MATCH (s)<-[:CONTAINS]-(c2:Chunk)
|
OPTIONAL MATCH (s)<-[:CONTAINS]-(c2:Chunk)
|
||||||
WHERE c1.group_id <> c2.group_id AND s IS NOT NULL AND c2 IS NOT NULL
|
WHERE c1.end_user_id <> c2.end_user_id AND s IS NOT NULL AND c2 IS NOT NULL
|
||||||
WITH c2.group_id AS other_user_id, COUNT(DISTINCT s) AS common_statements
|
WITH c2.end_user_id AS other_user_id, COUNT(DISTINCT s) AS common_statements
|
||||||
WHERE common_statements > 0
|
WHERE common_statements > 0
|
||||||
RETURN other_user_id, common_statements
|
RETURN other_user_id, common_statements
|
||||||
ORDER BY common_statements DESC
|
ORDER BY common_statements DESC
|
||||||
LIMIT 1
|
LIMIT 1
|
||||||
"""
|
"""
|
||||||
records = await self.neo4j_connector.execute_query(query, group_id=self.user_id)
|
records = await self.neo4j_connector.execute_query(query, end_user_id=self.user_id)
|
||||||
if not records or not records[0].get("other_user_id"):
|
if not records or not records[0].get("other_user_id"):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@@ -230,7 +230,7 @@ class MemoryInsightHelper:
|
|||||||
|
|
||||||
time_range_query = """
|
time_range_query = """
|
||||||
MATCH (c:Chunk)
|
MATCH (c:Chunk)
|
||||||
WHERE c.group_id IN [$user_id, $other_user_id]
|
WHERE c.end_user_id IN [$user_id, $other_user_id]
|
||||||
RETURN min(c.created_at) AS start_time, max(c.created_at) AS end_time
|
RETURN min(c.created_at) AS start_time, max(c.created_at) AS end_time
|
||||||
"""
|
"""
|
||||||
time_records = await self.neo4j_connector.execute_query(
|
time_records = await self.neo4j_connector.execute_query(
|
||||||
@@ -294,11 +294,11 @@ class UserSummaryHelper:
|
|||||||
"""Fetch recent statements authored by the user/group for context."""
|
"""Fetch recent statements authored by the user/group for context."""
|
||||||
query = (
|
query = (
|
||||||
"MATCH (s:Statement) "
|
"MATCH (s:Statement) "
|
||||||
"WHERE s.group_id = $group_id AND s.statement IS NOT NULL "
|
"WHERE s.end_user_id = $end_user_id AND s.statement IS NOT NULL "
|
||||||
"RETURN s.statement AS statement, s.created_at AS created_at "
|
"RETURN s.statement AS statement, s.created_at AS created_at "
|
||||||
"ORDER BY created_at DESC LIMIT $limit"
|
"ORDER BY created_at DESC LIMIT $limit"
|
||||||
)
|
)
|
||||||
rows = await self.connector.execute_query(query, group_id=self.user_id, limit=limit)
|
rows = await self.connector.execute_query(query, end_user_id=self.user_id, limit=limit)
|
||||||
records = []
|
records = []
|
||||||
for r in rows:
|
for r in rows:
|
||||||
try:
|
try:
|
||||||
@@ -357,6 +357,101 @@ class UserMemoryService:
|
|||||||
data[key] = UserMemoryService._datetime_to_timestamp(original_value)
|
data[key] = UserMemoryService._datetime_to_timestamp(original_value)
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
def update_end_user_profile(
|
||||||
|
self,
|
||||||
|
db: Session,
|
||||||
|
end_user_id: str,
|
||||||
|
profile_update: Any
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
更新终端用户的基本信息
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db: 数据库会话
|
||||||
|
end_user_id: 终端用户ID (UUID)
|
||||||
|
profile_update: 包含更新字段的 Pydantic 模型
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
{
|
||||||
|
"success": bool,
|
||||||
|
"data": dict, # 更新后的用户档案数据
|
||||||
|
"error": Optional[str]
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 转换为UUID并查询用户
|
||||||
|
user_uuid = uuid.UUID(end_user_id)
|
||||||
|
repo = EndUserRepository(db)
|
||||||
|
end_user = repo.get_by_id(user_uuid)
|
||||||
|
|
||||||
|
if not end_user:
|
||||||
|
logger.warning(f"终端用户不存在: end_user_id={end_user_id}")
|
||||||
|
return {
|
||||||
|
"success": False,
|
||||||
|
"data": None,
|
||||||
|
"error": "终端用户不存在"
|
||||||
|
}
|
||||||
|
|
||||||
|
# 获取更新数据(排除 end_user_id 字段)
|
||||||
|
update_data = profile_update.model_dump(exclude_unset=True, exclude={'end_user_id'})
|
||||||
|
|
||||||
|
# 特殊处理 hire_date:如果提供了时间戳,转换为 DateTime
|
||||||
|
if 'hire_date' in update_data:
|
||||||
|
hire_date_timestamp = update_data['hire_date']
|
||||||
|
if hire_date_timestamp is not None:
|
||||||
|
from app.core.api_key_utils import timestamp_to_datetime
|
||||||
|
update_data['hire_date'] = timestamp_to_datetime(hire_date_timestamp)
|
||||||
|
# 如果是 None,保持 None(允许清空)
|
||||||
|
|
||||||
|
# 更新字段
|
||||||
|
for field, value in update_data.items():
|
||||||
|
setattr(end_user, field, value)
|
||||||
|
|
||||||
|
# 更新时间戳
|
||||||
|
end_user.updated_at = datetime.now()
|
||||||
|
end_user.updatetime_profile = datetime.now()
|
||||||
|
|
||||||
|
# 提交更改
|
||||||
|
db.commit()
|
||||||
|
db.refresh(end_user)
|
||||||
|
|
||||||
|
# 构建响应数据
|
||||||
|
from app.schemas.end_user_schema import EndUserProfileResponse
|
||||||
|
profile_data = EndUserProfileResponse(
|
||||||
|
id=end_user.id,
|
||||||
|
other_name=end_user.other_name,
|
||||||
|
position=end_user.position,
|
||||||
|
department=end_user.department,
|
||||||
|
contact=end_user.contact,
|
||||||
|
phone=end_user.phone,
|
||||||
|
hire_date=end_user.hire_date,
|
||||||
|
updatetime_profile=end_user.updatetime_profile
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(f"成功更新用户信息: end_user_id={end_user_id}, updated_fields={list(update_data.keys())}")
|
||||||
|
|
||||||
|
return {
|
||||||
|
"success": True,
|
||||||
|
"data": self.convert_profile_to_dict_with_timestamp(profile_data),
|
||||||
|
"error": None
|
||||||
|
}
|
||||||
|
|
||||||
|
except ValueError:
|
||||||
|
logger.error(f"无效的 end_user_id 格式: {end_user_id}")
|
||||||
|
return {
|
||||||
|
"success": False,
|
||||||
|
"data": None,
|
||||||
|
"error": "无效的用户ID格式"
|
||||||
|
}
|
||||||
|
except Exception as e:
|
||||||
|
db.rollback()
|
||||||
|
logger.error(f"用户信息更新失败: end_user_id={end_user_id}, error={str(e)}")
|
||||||
|
return {
|
||||||
|
"success": False,
|
||||||
|
"data": None,
|
||||||
|
"error": str(e)
|
||||||
|
}
|
||||||
|
|
||||||
async def get_cached_memory_insight(
|
async def get_cached_memory_insight(
|
||||||
self,
|
self,
|
||||||
db: Session,
|
db: Session,
|
||||||
@@ -1057,7 +1152,7 @@ async def analytics_user_summary(end_user_id: Optional[str] = None) -> Dict[str,
|
|||||||
import re
|
import re
|
||||||
|
|
||||||
# 创建 UserSummaryHelper 实例
|
# 创建 UserSummaryHelper 实例
|
||||||
user_summary_tool = UserSummaryHelper(end_user_id or os.getenv("SELECTED_GROUP_ID", "group_123"))
|
user_summary_tool = UserSummaryHelper(end_user_id or os.getenv("SELECTED_end_user_id", "group_123"))
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 1) 收集上下文数据
|
# 1) 收集上下文数据
|
||||||
@@ -1178,10 +1273,10 @@ async def analytics_node_statistics(
|
|||||||
if end_user_id:
|
if end_user_id:
|
||||||
query = f"""
|
query = f"""
|
||||||
MATCH (n:{node_type})
|
MATCH (n:{node_type})
|
||||||
WHERE n.group_id = $group_id
|
WHERE n.end_user_id = $end_user_id
|
||||||
RETURN count(n) as count
|
RETURN count(n) as count
|
||||||
"""
|
"""
|
||||||
result = await _neo4j_connector.execute_query(query, group_id=end_user_id)
|
result = await _neo4j_connector.execute_query(query, end_user_id=end_user_id)
|
||||||
else:
|
else:
|
||||||
query = f"""
|
query = f"""
|
||||||
MATCH (n:{node_type})
|
MATCH (n:{node_type})
|
||||||
@@ -1292,10 +1387,10 @@ async def analytics_memory_types(
|
|||||||
# 查询 Statement 节点数量
|
# 查询 Statement 节点数量
|
||||||
query = """
|
query = """
|
||||||
MATCH (n:Statement)
|
MATCH (n:Statement)
|
||||||
WHERE n.group_id = $group_id
|
WHERE n.end_user_id = $end_user_id
|
||||||
RETURN count(n) as count
|
RETURN count(n) as count
|
||||||
"""
|
"""
|
||||||
result = await _neo4j_connector.execute_query(query, group_id=end_user_id)
|
result = await _neo4j_connector.execute_query(query, end_user_id=end_user_id)
|
||||||
statement_count = result[0]["count"] if result and len(result) > 0 else 0
|
statement_count = result[0]["count"] if result and len(result) > 0 else 0
|
||||||
# 取三分之一作为隐性记忆数量
|
# 取三分之一作为隐性记忆数量
|
||||||
implicit_count = round(statement_count / 3)
|
implicit_count = round(statement_count / 3)
|
||||||
@@ -1409,7 +1504,7 @@ async def analytics_graph_data(
|
|||||||
包含节点、边和统计信息的字典
|
包含节点、边和统计信息的字典
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
# 1. 获取 group_id
|
# 1. 获取 end_user_id
|
||||||
user_uuid = uuid.UUID(end_user_id)
|
user_uuid = uuid.UUID(end_user_id)
|
||||||
repo = EndUserRepository(db)
|
repo = EndUserRepository(db)
|
||||||
end_user = repo.get_by_id(user_uuid)
|
end_user = repo.get_by_id(user_uuid)
|
||||||
@@ -1433,7 +1528,7 @@ async def analytics_graph_data(
|
|||||||
# 基于中心节点的扩展查询
|
# 基于中心节点的扩展查询
|
||||||
node_query = f"""
|
node_query = f"""
|
||||||
MATCH path = (center)-[*1..{depth}]-(connected)
|
MATCH path = (center)-[*1..{depth}]-(connected)
|
||||||
WHERE center.group_id = $group_id
|
WHERE center.end_user_id = $end_user_id
|
||||||
AND elementId(center) = $center_node_id
|
AND elementId(center) = $center_node_id
|
||||||
WITH collect(DISTINCT center) + collect(DISTINCT connected) as all_nodes
|
WITH collect(DISTINCT center) + collect(DISTINCT connected) as all_nodes
|
||||||
UNWIND all_nodes as n
|
UNWIND all_nodes as n
|
||||||
@@ -1444,7 +1539,7 @@ async def analytics_graph_data(
|
|||||||
LIMIT $limit
|
LIMIT $limit
|
||||||
"""
|
"""
|
||||||
node_params = {
|
node_params = {
|
||||||
"group_id": end_user_id,
|
"end_user_id": end_user_id,
|
||||||
"center_node_id": center_node_id,
|
"center_node_id": center_node_id,
|
||||||
"limit": limit
|
"limit": limit
|
||||||
}
|
}
|
||||||
@@ -1452,7 +1547,7 @@ async def analytics_graph_data(
|
|||||||
# 按节点类型过滤查询
|
# 按节点类型过滤查询
|
||||||
node_query = """
|
node_query = """
|
||||||
MATCH (n)
|
MATCH (n)
|
||||||
WHERE n.group_id = $group_id
|
WHERE n.end_user_id = $end_user_id
|
||||||
AND labels(n)[0] IN $node_types
|
AND labels(n)[0] IN $node_types
|
||||||
RETURN
|
RETURN
|
||||||
elementId(n) as id,
|
elementId(n) as id,
|
||||||
@@ -1461,7 +1556,7 @@ async def analytics_graph_data(
|
|||||||
LIMIT $limit
|
LIMIT $limit
|
||||||
"""
|
"""
|
||||||
node_params = {
|
node_params = {
|
||||||
"group_id": end_user_id,
|
"end_user_id": end_user_id,
|
||||||
"node_types": node_types,
|
"node_types": node_types,
|
||||||
"limit": limit
|
"limit": limit
|
||||||
}
|
}
|
||||||
@@ -1469,7 +1564,7 @@ async def analytics_graph_data(
|
|||||||
# 查询所有节点
|
# 查询所有节点
|
||||||
node_query = """
|
node_query = """
|
||||||
MATCH (n)
|
MATCH (n)
|
||||||
WHERE n.group_id = $group_id
|
WHERE n.end_user_id = $end_user_id
|
||||||
RETURN
|
RETURN
|
||||||
elementId(n) as id,
|
elementId(n) as id,
|
||||||
labels(n)[0] as label,
|
labels(n)[0] as label,
|
||||||
@@ -1477,7 +1572,7 @@ async def analytics_graph_data(
|
|||||||
LIMIT $limit
|
LIMIT $limit
|
||||||
"""
|
"""
|
||||||
node_params = {
|
node_params = {
|
||||||
"group_id": end_user_id,
|
"end_user_id": end_user_id,
|
||||||
"limit": limit
|
"limit": limit
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -382,12 +382,12 @@ def build_graphrag_for_kb(kb_id: uuid.UUID):
|
|||||||
|
|
||||||
|
|
||||||
@celery_app.task(name="app.core.memory.agent.read_message", bind=True)
|
@celery_app.task(name="app.core.memory.agent.read_message", bind=True)
|
||||||
def read_message_task(self, group_id: str, message: str, history: List[Dict[str, Any]], search_switch: str, config_id: str,storage_type:str,user_rag_memory_id:str) -> Dict[str, Any]:
|
def read_message_task(self, end_user_id: str, message: str, history: List[Dict[str, Any]], search_switch: str, config_id: str,storage_type:str,user_rag_memory_id:str) -> Dict[str, Any]:
|
||||||
|
|
||||||
"""Celery task to process a read message via MemoryAgentService.
|
"""Celery task to process a read message via MemoryAgentService.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
group_id: Group ID for the memory agent (also used as end_user_id)
|
end_user_id: Group ID for the memory agent (also used as end_user_id)
|
||||||
message: User message to process
|
message: User message to process
|
||||||
history: Conversation history
|
history: Conversation history
|
||||||
search_switch: Search switch parameter
|
search_switch: Search switch parameter
|
||||||
@@ -408,7 +408,7 @@ def read_message_task(self, group_id: str, message: str, history: List[Dict[str,
|
|||||||
from app.services.memory_agent_service import get_end_user_connected_config
|
from app.services.memory_agent_service import get_end_user_connected_config
|
||||||
db = next(get_db())
|
db = next(get_db())
|
||||||
try:
|
try:
|
||||||
connected_config = get_end_user_connected_config(group_id, db)
|
connected_config = get_end_user_connected_config(end_user_id, db)
|
||||||
actual_config_id = connected_config.get("memory_config_id")
|
actual_config_id = connected_config.get("memory_config_id")
|
||||||
finally:
|
finally:
|
||||||
db.close()
|
db.close()
|
||||||
@@ -420,7 +420,7 @@ def read_message_task(self, group_id: str, message: str, history: List[Dict[str,
|
|||||||
db = next(get_db())
|
db = next(get_db())
|
||||||
try:
|
try:
|
||||||
service = MemoryAgentService()
|
service = MemoryAgentService()
|
||||||
return await service.read_memory(group_id, message, history, search_switch, actual_config_id, db, storage_type, user_rag_memory_id)
|
return await service.read_memory(end_user_id, message, history, search_switch, actual_config_id, db, storage_type, user_rag_memory_id)
|
||||||
finally:
|
finally:
|
||||||
db.close()
|
db.close()
|
||||||
|
|
||||||
@@ -448,7 +448,7 @@ def read_message_task(self, group_id: str, message: str, history: List[Dict[str,
|
|||||||
return {
|
return {
|
||||||
"status": "SUCCESS",
|
"status": "SUCCESS",
|
||||||
"result": result,
|
"result": result,
|
||||||
"group_id": group_id,
|
"end_user_id": end_user_id,
|
||||||
"config_id": config_id,
|
"config_id": config_id,
|
||||||
"elapsed_time": elapsed_time,
|
"elapsed_time": elapsed_time,
|
||||||
"task_id": self.request.id
|
"task_id": self.request.id
|
||||||
@@ -464,7 +464,7 @@ def read_message_task(self, group_id: str, message: str, history: List[Dict[str,
|
|||||||
return {
|
return {
|
||||||
"status": "FAILURE",
|
"status": "FAILURE",
|
||||||
"error": detailed_error,
|
"error": detailed_error,
|
||||||
"group_id": group_id,
|
"end_user_id": end_user_id,
|
||||||
"config_id": config_id,
|
"config_id": config_id,
|
||||||
"elapsed_time": elapsed_time,
|
"elapsed_time": elapsed_time,
|
||||||
"task_id": self.request.id
|
"task_id": self.request.id
|
||||||
@@ -472,11 +472,11 @@ def read_message_task(self, group_id: str, message: str, history: List[Dict[str,
|
|||||||
|
|
||||||
|
|
||||||
@celery_app.task(name="app.core.memory.agent.write_message", bind=True)
|
@celery_app.task(name="app.core.memory.agent.write_message", bind=True)
|
||||||
def write_message_task(self, group_id: str, message: str, config_id: str,storage_type:str,user_rag_memory_id:str) -> Dict[str, Any]:
|
def write_message_task(self, end_user_id: str, message: str, config_id: str,storage_type:str,user_rag_memory_id:str) -> Dict[str, Any]:
|
||||||
"""Celery task to process a write message via MemoryAgentService.
|
"""Celery task to process a write message via MemoryAgentService.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
group_id: Group ID for the memory agent (also used as end_user_id)
|
end_user_id: Group ID for the memory agent (also used as end_user_id)
|
||||||
message: Message to write
|
message: Message to write
|
||||||
config_id: Optional configuration ID
|
config_id: Optional configuration ID
|
||||||
|
|
||||||
@@ -489,7 +489,7 @@ def write_message_task(self, group_id: str, message: str, config_id: str,storage
|
|||||||
from app.core.logging_config import get_logger
|
from app.core.logging_config import get_logger
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
logger.info(f"[CELERY WRITE] Starting write task - group_id={group_id}, config_id={config_id}, storage_type={storage_type}")
|
logger.info(f"[CELERY WRITE] Starting write task - end_user_id={end_user_id}, config_id={config_id}, storage_type={storage_type}")
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
|
||||||
# Resolve config_id if None
|
# Resolve config_id if None
|
||||||
@@ -499,7 +499,7 @@ def write_message_task(self, group_id: str, message: str, config_id: str,storage
|
|||||||
from app.services.memory_agent_service import get_end_user_connected_config
|
from app.services.memory_agent_service import get_end_user_connected_config
|
||||||
db = next(get_db())
|
db = next(get_db())
|
||||||
try:
|
try:
|
||||||
connected_config = get_end_user_connected_config(group_id, db)
|
connected_config = get_end_user_connected_config(end_user_id, db)
|
||||||
actual_config_id = connected_config.get("memory_config_id")
|
actual_config_id = connected_config.get("memory_config_id")
|
||||||
finally:
|
finally:
|
||||||
db.close()
|
db.close()
|
||||||
@@ -512,7 +512,7 @@ def write_message_task(self, group_id: str, message: str, config_id: str,storage
|
|||||||
try:
|
try:
|
||||||
logger.info(f"[CELERY WRITE] Executing MemoryAgentService.write_memory")
|
logger.info(f"[CELERY WRITE] Executing MemoryAgentService.write_memory")
|
||||||
service = MemoryAgentService()
|
service = MemoryAgentService()
|
||||||
result = await service.write_memory(group_id, message, actual_config_id, db, storage_type, user_rag_memory_id)
|
result = await service.write_memory(end_user_id, message, actual_config_id, db, storage_type, user_rag_memory_id)
|
||||||
logger.info(f"[CELERY WRITE] Write completed successfully: {result}")
|
logger.info(f"[CELERY WRITE] Write completed successfully: {result}")
|
||||||
return result
|
return result
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -547,7 +547,7 @@ def write_message_task(self, group_id: str, message: str, config_id: str,storage
|
|||||||
return {
|
return {
|
||||||
"status": "SUCCESS",
|
"status": "SUCCESS",
|
||||||
"result": result,
|
"result": result,
|
||||||
"group_id": group_id,
|
"end_user_id": end_user_id,
|
||||||
"config_id": config_id,
|
"config_id": config_id,
|
||||||
"elapsed_time": elapsed_time,
|
"elapsed_time": elapsed_time,
|
||||||
"task_id": self.request.id
|
"task_id": self.request.id
|
||||||
@@ -566,7 +566,7 @@ def write_message_task(self, group_id: str, message: str, config_id: str,storage
|
|||||||
return {
|
return {
|
||||||
"status": "FAILURE",
|
"status": "FAILURE",
|
||||||
"error": detailed_error,
|
"error": detailed_error,
|
||||||
"group_id": group_id,
|
"end_user_id": end_user_id,
|
||||||
"config_id": config_id,
|
"config_id": config_id,
|
||||||
"elapsed_time": elapsed_time,
|
"elapsed_time": elapsed_time,
|
||||||
"task_id": self.request.id
|
"task_id": self.request.id
|
||||||
@@ -612,7 +612,7 @@ def check_read_service_task() -> Dict[str, str]:
|
|||||||
payload = {
|
payload = {
|
||||||
"user_id": "健康检查",
|
"user_id": "健康检查",
|
||||||
"apply_id": "健康检查",
|
"apply_id": "健康检查",
|
||||||
"group_id": "健康检查",
|
"end_user_id": "健康检查",
|
||||||
"message": "你好",
|
"message": "你好",
|
||||||
"history": [],
|
"history": [],
|
||||||
"search_switch": "2",
|
"search_switch": "2",
|
||||||
@@ -1112,7 +1112,7 @@ def run_forgetting_cycle_task(self, config_id: Optional[int] = None) -> Dict[str
|
|||||||
# 运行遗忘周期
|
# 运行遗忘周期
|
||||||
report = await forget_service.trigger_forgetting(
|
report = await forget_service.trigger_forgetting(
|
||||||
db=db,
|
db=db,
|
||||||
group_id=None, # 处理所有组
|
end_user_id=None, # 处理所有组
|
||||||
config_id=config_id
|
config_id=config_id
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -7,10 +7,6 @@ services:
|
|||||||
- "8002:8000"
|
- "8002:8000"
|
||||||
env_file:
|
env_file:
|
||||||
- .env
|
- .env
|
||||||
environment:
|
|
||||||
- SERVER_IP=0.0.0.0
|
|
||||||
# 如果代码里必须要 MCP_SERVER_URL,可以先注释或指向占位
|
|
||||||
# - MCP_SERVER_URL=
|
|
||||||
volumes:
|
volumes:
|
||||||
- ./files:/files
|
- ./files:/files
|
||||||
- /etc/localtime:/etc/localtime:ro
|
- /etc/localtime:/etc/localtime:ro
|
||||||
@@ -19,20 +15,53 @@ services:
|
|||||||
networks:
|
networks:
|
||||||
- default
|
- default
|
||||||
- celery
|
- celery
|
||||||
|
depends_on:
|
||||||
|
- worker-memory
|
||||||
|
- worker-document
|
||||||
|
|
||||||
# Celery worker
|
# Memory worker - Memory read/write tasks (threads pool for asyncio)
|
||||||
worker:
|
worker-memory:
|
||||||
image: redbear-mem-open:latest
|
image: redbear-mem-open:latest
|
||||||
container_name: worker
|
container_name: worker-memory
|
||||||
env_file:
|
env_file:
|
||||||
- .env
|
- .env
|
||||||
volumes:
|
volumes:
|
||||||
- ./files:/files
|
- ./files:/files
|
||||||
- /etc/localtime:/etc/localtime:ro
|
- /etc/localtime:/etc/localtime:ro
|
||||||
command: celery -A app.celery_worker.celery_app worker --loglevel=info
|
command: celery -A app.celery_worker.celery_app worker -E --loglevel=info --pool=threads --concurrency=100 --queues=memory_tasks -n memory_worker@%h
|
||||||
restart: unless-stopped
|
restart: unless-stopped
|
||||||
networks:
|
networks:
|
||||||
- celery
|
- celery
|
||||||
|
|
||||||
|
# Document worker - Document parsing tasks (prefork for CPU-bound)
|
||||||
|
worker-document:
|
||||||
|
image: redbear-mem-open:latest
|
||||||
|
container_name: worker-document
|
||||||
|
env_file:
|
||||||
|
- .env
|
||||||
|
volumes:
|
||||||
|
- ./files:/files
|
||||||
|
- /etc/localtime:/etc/localtime:ro
|
||||||
|
command: celery -A app.celery_worker.celery_app worker -E --loglevel=info --pool=prefork --concurrency=4 --queues=document_tasks --max-tasks-per-child=100 -n document_worker@%h
|
||||||
|
restart: unless-stopped
|
||||||
|
networks:
|
||||||
|
- celery
|
||||||
|
|
||||||
|
# Celery Beat - scheduler
|
||||||
|
beat:
|
||||||
|
image: redbear-mem-open:latest
|
||||||
|
container_name: celery-beat
|
||||||
|
env_file:
|
||||||
|
- .env
|
||||||
|
volumes:
|
||||||
|
- ./files:/files
|
||||||
|
- /etc/localtime:/etc/localtime:ro
|
||||||
|
command: celery -A app.celery_worker.celery_app beat --loglevel=info
|
||||||
|
restart: unless-stopped
|
||||||
|
networks:
|
||||||
|
- celery
|
||||||
|
depends_on:
|
||||||
|
- worker-memory
|
||||||
|
|
||||||
networks:
|
networks:
|
||||||
celery:
|
celery:
|
||||||
|
|||||||
@@ -13,6 +13,7 @@ dependencies = [
|
|||||||
"bcrypt==5.0.0",
|
"bcrypt==5.0.0",
|
||||||
"billiard==4.2.2",
|
"billiard==4.2.2",
|
||||||
"celery==5.5.3",
|
"celery==5.5.3",
|
||||||
|
"flower==2.0.1",
|
||||||
"cffi==2.0.0",
|
"cffi==2.0.0",
|
||||||
"click==8.3.0",
|
"click==8.3.0",
|
||||||
"click-didyoumean==0.3.1",
|
"click-didyoumean==0.3.1",
|
||||||
@@ -138,6 +139,7 @@ dependencies = [
|
|||||||
"python-calamine>=0.4.0",
|
"python-calamine>=0.4.0",
|
||||||
"xlrd==2.0.2",
|
"xlrd==2.0.2",
|
||||||
"deprecated>=1.3.1",
|
"deprecated>=1.3.1",
|
||||||
|
"flower>=2.0.1",
|
||||||
]
|
]
|
||||||
|
|
||||||
[tool.pytest.ini_options]
|
[tool.pytest.ini_options]
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ async-timeout==5.0.1
|
|||||||
bcrypt==5.0.0
|
bcrypt==5.0.0
|
||||||
billiard==4.2.2
|
billiard==4.2.2
|
||||||
celery==5.5.3
|
celery==5.5.3
|
||||||
|
flower==2.0.1
|
||||||
cffi==2.0.0
|
cffi==2.0.0
|
||||||
click==8.3.0
|
click==8.3.0
|
||||||
click-didyoumean==0.3.1
|
click-didyoumean==0.3.1
|
||||||
|
|||||||
Reference in New Issue
Block a user