Merge remote-tracking branch 'origin/develop' into develop

This commit is contained in:
lixinyue
2026-01-21 19:16:04 +08:00
37 changed files with 1014 additions and 827 deletions

View File

@@ -334,7 +334,13 @@ step6: Log In to the Frontend Interface.
## License ## License
This project is licensed under the Apache License 2.0. For details, see the LICENSE file. This project is licensed under the Apache License 2.0. For details, see the LICENSE file.
## Acknowledgements & Community ## Community & Support
- Feedback & Issues: Please submit an Issue in the repository for bug reports or discussions.
- Contributions Welcome: When submitting a Pull Request, please create a feature branch and follow conventional commit message guidelines. Join our community to ask questions, share your work, and connect with fellow developers.
- Contact: If you are interested in contributing or collaborating, feel free to reach out at tianyou_hubm@redbearai.com
- **GitHub Issues**: Report bugs, request features, or track known issues via [GitHub Issues](https://github.com/SuanmoSuanyangTechnology/MemoryBear/issues).
- **GitHub Pull Requests**: Contribute code improvements or fixes through [Pull Requests](https://github.com/SuanmoSuanyangTechnology/MemoryBear/pulls).
- **GitHub Discussions**: Ask questions, share ideas, and engage with the community in [GitHub Discussions](https://github.com/SuanmoSuanyangTechnology/MemoryBear/discussions).
- **WeChat**: Scan the QR code below to join our WeChat community group.
- ![wecom-temp-114020-47fe87a75da439f09f5dc93a01593046](https://github.com/user-attachments/assets/8c81885c-4134-40d5-96e2-7f78cc082dc6)
- **Contact**: If you are interested in contributing or collaborating, feel free to reach out at tianyou_hubm@redbearai.com

View File

@@ -1,4 +1,5 @@
import os import os
import platform
from datetime import timedelta from datetime import timedelta
from urllib.parse import quote from urllib.parse import quote
@@ -14,27 +15,12 @@ celery_app = Celery(
backend=f"redis://:{quote(settings.REDIS_PASSWORD)}@{settings.REDIS_HOST}:{settings.REDIS_PORT}/{settings.CELERY_BACKEND}", backend=f"redis://:{quote(settings.REDIS_PASSWORD)}@{settings.REDIS_HOST}:{settings.REDIS_PORT}/{settings.CELERY_BACKEND}",
) )
# 配置使用本地队列,避免与远程 worker 冲突 # Default queue for unrouted tasks
celery_app.conf.task_default_queue = 'localhost_test_wyl' celery_app.conf.task_default_queue = 'memory_tasks'
celery_app.conf.task_default_exchange = 'localhost_test_wyl'
celery_app.conf.task_default_routing_key = 'localhost_test_wyl'
# macOS 兼容性配置 # macOS 兼容性配置
import platform if platform.system() == 'Darwin':
if platform.system() == 'Darwin': # macOS
# 设置环境变量解决 fork 问题
os.environ.setdefault('OBJC_DISABLE_INITIALIZE_FORK_SAFETY', 'YES') os.environ.setdefault('OBJC_DISABLE_INITIALIZE_FORK_SAFETY', 'YES')
# 使用 solo 池避免多进程问题
celery_app.conf.worker_pool = 'solo'
# 设置唯一的节点名称
import socket
import time
hostname = socket.gethostname()
timestamp = int(time.time())
celery_app.conf.worker_name = f"celery@{hostname}-{timestamp}"
# Celery 配置 # Celery 配置
celery_app.conf.update( celery_app.conf.update(
@@ -52,36 +38,47 @@ celery_app.conf.update(
task_ignore_result=False, task_ignore_result=False,
# 超时设置 # 超时设置
task_time_limit=30 * 60, # 30 分钟硬超时 task_time_limit=1800, # 30分钟硬超时
task_soft_time_limit=25 * 60, # 25 分钟软超时 task_soft_time_limit=1500, # 25分钟软超时
# Worker 设置 - 针对 macOS 优化 # Worker 设置 (per-worker settings are in docker-compose command line)
worker_prefetch_multiplier=1, # 减少预取任务数,避免内存堆积 worker_prefetch_multiplier=1, # Don't hoard tasks, fairer distribution
worker_max_tasks_per_child=10, # 大幅减少每个 worker 执行的任务数,频繁重启防止内存泄漏
worker_max_memory_per_child=200000, # 200MB 内存限制,超过后重启 worker
# 结果过期时间 # 结果过期时间
result_expires=3600, # 结果保存 1 小时 result_expires=3600, # 结果保存1小时
# 任务确认设置 # 任务确认设置
task_acks_late=True, # 任务完成后才确认,避免任务丢失 task_acks_late=True,
worker_disable_rate_limits=True, # 禁用速率限制 task_reject_on_worker_lost=True,
worker_disable_rate_limits=True,
# 任务路由(可选,用于不同队列) # FLower setting
# task_routes={ worker_send_task_events=True,
# 'app.core.rag.tasks.parse_document': {'queue': 'document_processing'}, task_send_sent_event=True,
# 'app.core.memory.agent.read_message': {'queue': 'memory_processing'},
# 'app.core.memory.agent.write_message': {'queue': 'memory_processing'}, # task routing
# 'tasks.process_item': {'queue': 'default'}, task_routes={
# }, # Memory tasks → memory_tasks queue (threads worker)
'app.core.memory.agent.read_message_priority': {'queue': 'memory_tasks'},
'app.core.memory.agent.read_message': {'queue': 'memory_tasks'},
'app.core.memory.agent.write_message': {'queue': 'memory_tasks'},
# Document tasks → document_tasks queue (prefork worker)
'app.core.rag.tasks.parse_document': {'queue': 'document_tasks'},
'app.core.rag.tasks.build_graphrag_for_kb': {'queue': 'document_tasks'},
# Beat/periodic tasks → document_tasks queue (prefork worker)
'app.tasks.workspace_reflection_task': {'queue': 'document_tasks'},
'app.tasks.regenerate_memory_cache': {'queue': 'document_tasks'},
'app.tasks.run_forgetting_cycle_task': {'queue': 'document_tasks'},
'app.controllers.memory_storage_controller.search_all': {'queue': 'document_tasks'},
},
) )
# 自动发现任务模块 # 自动发现任务模块
celery_app.autodiscover_tasks(['app']) celery_app.autodiscover_tasks(['app'])
# Celery Beat schedule for periodic tasks # Celery Beat schedule for periodic tasks
reflection_schedule = timedelta(seconds=settings.REFLECTION_INTERVAL_SECONDS)
health_schedule = timedelta(seconds=settings.HEALTH_CHECK_SECONDS)
memory_increment_schedule = timedelta(hours=settings.MEMORY_INCREMENT_INTERVAL_HOURS) memory_increment_schedule = timedelta(hours=settings.MEMORY_INCREMENT_INTERVAL_HOURS)
memory_cache_regeneration_schedule = timedelta(hours=settings.MEMORY_CACHE_REGENERATION_HOURS) memory_cache_regeneration_schedule = timedelta(hours=settings.MEMORY_CACHE_REGENERATION_HOURS)
workspace_reflection_schedule = timedelta(seconds=30) # 每30秒运行一次settings.REFLECTION_INTERVAL_TIME workspace_reflection_schedule = timedelta(seconds=30) # 每30秒运行一次settings.REFLECTION_INTERVAL_TIME
@@ -89,12 +86,6 @@ forgetting_cycle_schedule = timedelta(hours=24) # 每24小时运行一次遗忘
# 构建定时任务配置 # 构建定时任务配置
beat_schedule_config = { beat_schedule_config = {
# "check-read-service": {
# "task": "app.core.memory.agent.health.check_read_service",
# "schedule": health_schedule,
# "args": (),
# },
"run-workspace-reflection": { "run-workspace-reflection": {
"task": "app.tasks.workspace_reflection_task", "task": "app.tasks.workspace_reflection_task",
"schedule": workspace_reflection_schedule, "schedule": workspace_reflection_schedule,

View File

@@ -9,6 +9,8 @@ from app.db import get_db
from app.dependencies import cur_workspace_access_guard, get_current_user from app.dependencies import cur_workspace_access_guard, get_current_user
from app.models import ModelApiKey from app.models import ModelApiKey
from app.models.user_model import User from app.models.user_model import User
from app.core.memory.agent.utils.session_tools import SessionService
from app.core.memory.agent.utils.redis_tool import store
from app.repositories import knowledge_repository, WorkspaceRepository from app.repositories import knowledge_repository, WorkspaceRepository
from app.schemas.memory_agent_schema import UserInput, Write_UserInput from app.schemas.memory_agent_schema import UserInput, Write_UserInput
from app.schemas.response_schema import ApiResponse from app.schemas.response_schema import ApiResponse
@@ -160,9 +162,12 @@ async def write_server(
api_logger.info(f"Write service requested for group {user_input.group_id}, storage_type: {storage_type}, user_rag_memory_id: {user_rag_memory_id}") api_logger.info(f"Write service requested for group {user_input.group_id}, storage_type: {storage_type}, user_rag_memory_id: {user_rag_memory_id}")
try: try:
# 获取标准化的消息列表
messages_list = memory_agent_service.get_messages_list(user_input)
result = await memory_agent_service.write_memory( result = await memory_agent_service.write_memory(
user_input.group_id, user_input.group_id,
user_input.message, messages_list, # 传递结构化消息列表
config_id, config_id,
db, db,
storage_type, storage_type,
@@ -219,9 +224,12 @@ async def write_server_async(
if knowledge: user_rag_memory_id = str(knowledge.id) if knowledge: user_rag_memory_id = str(knowledge.id)
api_logger.info(f"Async write: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}") api_logger.info(f"Async write: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}")
try: try:
# 获取标准化的消息列表
messages_list = memory_agent_service.get_messages_list(user_input)
task = celery_app.send_task( task = celery_app.send_task(
"app.core.memory.agent.write_message", "app.core.memory.agent.write_message",
args=[user_input.group_id, user_input.message, config_id, storage_type, user_rag_memory_id] args=[user_input.group_id, messages_list, config_id, storage_type, user_rag_memory_id]
) )
api_logger.info(f"Write task queued: {task.id}") api_logger.info(f"Write task queued: {task.id}")
@@ -285,6 +293,19 @@ async def read_server(
storage_type, storage_type,
user_rag_memory_id user_rag_memory_id
) )
if str(user_input.search_switch) == "2":
retrieve_info = result['answer']
history = await SessionService(store).get_history(user_input.group_id, user_input.group_id, user_input.group_id)
query = user_input.message
# 调用 memory_agent_service 的方法生成最终答案
result['answer'] = await memory_agent_service.generate_summary_from_retrieve(
retrieve_info=retrieve_info,
history=history,
query=query,
config_id=config_id,
db=db
)
return success(data=result, msg="回复对话消息成功") return success(data=result, msg="回复对话消息成功")
except BaseException as e: except BaseException as e:
# Handle ExceptionGroup from TaskGroup (Python 3.11+) or BaseExceptionGroup # Handle ExceptionGroup from TaskGroup (Python 3.11+) or BaseExceptionGroup
@@ -564,8 +585,23 @@ async def status_type(
""" """
api_logger.info(f"Status type check requested for group {user_input.group_id}") api_logger.info(f"Status type check requested for group {user_input.group_id}")
try: try:
# 获取标准化的消息列表
messages_list = memory_agent_service.get_messages_list(user_input)
# 将消息列表转换为字符串用于分类
# 只取最后一条用户消息进行分类
last_user_message = ""
for msg in reversed(messages_list):
if msg.get('role') == 'user':
last_user_message = msg.get('content', '')
break
if not last_user_message:
# 如果没有用户消息,使用所有消息的内容
last_user_message = " ".join([msg.get('content', '') for msg in messages_list])
result = await memory_agent_service.classify_message_type( result = await memory_agent_service.classify_message_type(
user_input.message, last_user_message,
user_input.config_id, user_input.config_id,
db db
) )
@@ -661,7 +697,7 @@ async def get_user_profile_api(
current_user: User = Depends(get_current_user) current_user: User = Depends(get_current_user)
): ):
""" """
获取用户详情,包含: 获取工作空间下Popular Memory Tags,包含:
- name: 用户名字(直接使用 end_user_id - name: 用户名字(直接使用 end_user_id
- tags: 3个用户特征标签从语句和实体中LLM总结 - tags: 3个用户特征标签从语句和实体中LLM总结
- hot_tags: 4个热门记忆标签 - hot_tags: 4个热门记忆标签

View File

@@ -5,7 +5,6 @@ from app.core.response_utils import success
from app.db import get_db from app.db import get_db
from app.dependencies import get_current_user from app.dependencies import get_current_user
from app.models.user_model import User from app.models.user_model import User
from app.schemas.memory_agent_schema import End_User_Information
from app.schemas.response_schema import ApiResponse from app.schemas.response_schema import ApiResponse
from app.services import memory_dashboard_service, memory_storage_service, workspace_service from app.services import memory_dashboard_service, memory_storage_service, workspace_service
@@ -40,54 +39,7 @@ def get_workspace_total_end_users(
api_logger.info(f"成功获取最新用户总数: total_num={total_end_users.get('total_num', 0)}") api_logger.info(f"成功获取最新用户总数: total_num={total_end_users.get('total_num', 0)}")
return success(data=total_end_users, msg="用户数量获取成功") return success(data=total_end_users, msg="用户数量获取成功")
@router.post("/update/end_users", response_model=ApiResponse)
async def update_workspace_end_users(
user_input: End_User_Information,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""
更新工作空间的宿主信息
"""
username = user_input.end_user_name # 要更新的用户名
end_user_input_id = user_input.id # 宿主ID
workspace_id = current_user.current_workspace_id
api_logger.info(f"用户 {current_user.username} 请求更新工作空间 {workspace_id} 的宿主信息")
api_logger.info(f"更新参数: username={username}, end_user_id={end_user_input_id}")
try:
# 导入更新函数
from app.repositories.end_user_repository import update_end_user_other_name
import uuid
# 转换 end_user_id 为 UUID 类型
end_user_uuid = uuid.UUID(end_user_input_id)
# 直接更新数据库中的 other_name 字段
updated_count = update_end_user_other_name(
db=db,
end_user_id=end_user_uuid,
other_name=username
)
api_logger.info(f"成功更新宿主 {end_user_input_id} 的 other_name 为: {username}")
return success(
data={
"updated_count": updated_count,
"end_user_id": end_user_input_id,
"updated_other_name": username
},
msg=f"成功更新 {updated_count} 个宿主的信息"
)
except Exception as e:
api_logger.error(f"更新宿主信息失败: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"更新宿主信息失败: {str(e)}"
)

View File

@@ -28,7 +28,6 @@ from app.services.memory_storage_service import (
search_dialogue, search_dialogue,
search_edges, search_edges,
search_entity, search_entity,
search_entity_graph,
search_statement, search_statement,
) )
from fastapi import APIRouter, Depends from fastapi import APIRouter, Depends
@@ -412,21 +411,7 @@ async def search_entity_edges(
api_logger.error(f"Search edges failed: {str(e)}") api_logger.error(f"Search edges failed: {str(e)}")
return fail(BizCode.INTERNAL_ERROR, "边查询失败", str(e)) return fail(BizCode.INTERNAL_ERROR, "边查询失败", str(e))
@router.get("/search/entity_graph", response_model=ApiResponse)
async def search_for_entity_graph(
end_user_id: Optional[str] = None,
current_user: User = Depends(get_current_user),
) -> dict:
"""
搜索所有实体之间的关系网络
"""
api_logger.info(f"Search entity graph requested for end_user_id: {end_user_id}")
try:
result = await search_entity_graph(end_user_id)
return success(data=result, msg="查询成功")
except Exception as e:
api_logger.error(f"Search entity graph failed: {str(e)}")
return fail(BizCode.INTERNAL_ERROR, "实体图查询失败", str(e))
@router.get("/analytics/hot_memory_tags", response_model=ApiResponse) @router.get("/analytics/hot_memory_tags", response_model=ApiResponse)

View File

@@ -351,12 +351,11 @@ async def update_end_user_profile(
该接口可以更新用户的姓名、职位、部门、联系方式、电话和入职日期等信息。 该接口可以更新用户的姓名、职位、部门、联系方式、电话和入职日期等信息。
所有字段都是可选的,只更新提供的字段。 所有字段都是可选的,只更新提供的字段。
""" """
workspace_id = current_user.current_workspace_id workspace_id = current_user.current_workspace_id
end_user_id = profile_update.end_user_id end_user_id = profile_update.end_user_id
# 检查用户是否已选择工作空间 # 验证工作空间
if workspace_id is None: if workspace_id is None:
api_logger.warning(f"用户 {current_user.username} 尝试更新用户信息但未选择工作空间") api_logger.warning(f"用户 {current_user.username} 尝试更新用户信息但未选择工作空间")
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None") return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
@@ -366,57 +365,24 @@ async def update_end_user_profile(
f"workspace={workspace_id}" f"workspace={workspace_id}"
) )
try: # 调用 Service 层处理业务逻辑
# 查询终端用户 result = user_memory_service.update_end_user_profile(db, end_user_id, profile_update)
end_user = db.query(EndUser).filter(EndUser.id == end_user_id).first()
if not end_user: if result["success"]:
api_logger.warning(f"终端用户不存在: end_user_id={end_user_id}") api_logger.info(f"成功更新用户信息: end_user_id={end_user_id}")
return fail(BizCode.INVALID_PARAMETER, "终端用户不存在", f"end_user_id={end_user_id}") return success(data=result["data"], msg="更新成功")
else:
# 更新字段(只更新提供的字段,排除 end_user_id error_msg = result["error"]
# 允许 None 值来重置字段(如 hire_date api_logger.error(f"用户信息更新失败: end_user_id={end_user_id}, error={error_msg}")
update_data = profile_update.model_dump(exclude_unset=True, exclude={'end_user_id'})
# 根据错误类型映射到合适的业务错误码
# 特殊处理 hire_date如果提供了时间戳转换为 DateTime if error_msg == "终端用户不存在":
if 'hire_date' in update_data: return fail(BizCode.USER_NOT_FOUND, "终端用户不存在", error_msg)
hire_date_timestamp = update_data['hire_date'] elif error_msg == "无效的用户ID格式":
if hire_date_timestamp is not None: return fail(BizCode.INVALID_USER_ID, "无效的用户ID格式", error_msg)
update_data['hire_date'] = timestamp_to_datetime(hire_date_timestamp) else:
# 如果是 None保持 None允许清空 # 只有未预期的错误才使用 INTERNAL_ERROR
return fail(BizCode.INTERNAL_ERROR, "用户信息更新失败", error_msg)
for field, value in update_data.items():
setattr(end_user, field, value)
# 更新 updated_at 时间戳
end_user.updated_at = datetime.datetime.now()
# 更新 updatetime_profile 为当前时间
end_user.updatetime_profile = datetime.datetime.now()
# 提交更改
db.commit()
db.refresh(end_user)
# 构建响应数据
profile_data = EndUserProfileResponse(
id=end_user.id,
other_name=end_user.other_name,
position=end_user.position,
department=end_user.department,
contact=end_user.contact,
phone=end_user.phone,
hire_date=end_user.hire_date,
updatetime_profile=end_user.updatetime_profile
)
api_logger.info(f"成功更新用户信息: end_user_id={end_user_id}, updated_fields={list(update_data.keys())}")
return success(data=UserMemoryService.convert_profile_to_dict_with_timestamp(profile_data), msg="更新成功")
except Exception as e:
db.rollback()
api_logger.error(f"用户信息更新失败: end_user_id={end_user_id}, error={str(e)}")
return fail(BizCode.INTERNAL_ERROR, "用户信息更新失败", str(e))
@router.get("/memory_space/timeline_memories", response_model=ApiResponse) @router.get("/memory_space/timeline_memories", response_model=ApiResponse)
async def memory_space_timeline_of_shared_memories(id: str, label: str,language_type: str="zh", async def memory_space_timeline_of_shared_memories(id: str, label: str,language_type: str="zh",

View File

@@ -145,44 +145,98 @@ class LangChainAgent:
messages.append(HumanMessage(content=user_content)) messages.append(HumanMessage(content=user_content))
return messages return messages
async def term_memory_save(self,messages,end_user_end,aimessages): # TODO 乐力齐 - 累积多组对话批量写入功能已禁用
'''短长期存储redis为不影响正常使用6句一段话存储用户名加一个前缀当数据存够6条返回给neo4j''' # async def term_memory_save(self,messages,end_user_end,aimessages):
end_user_end=f"Term_{end_user_end}" # '''短长期存储redis为不影响正常使用6句一段话存储用户名加一个前缀当数据存够6条返回给neo4j'''
print(messages) # end_user_end=f"Term_{end_user_end}"
print(aimessages) # print(messages)
session_id = store.save_session( # print(aimessages)
userid=end_user_end, # session_id = store.save_session(
messages=messages, # userid=end_user_end,
apply_id=end_user_end, # messages=messages,
group_id=end_user_end, # apply_id=end_user_end,
aimessages=aimessages # group_id=end_user_end,
) # aimessages=aimessages
store.delete_duplicate_sessions() # )
# logger.info(f'Redis_Agent:{end_user_end};{session_id}') # store.delete_duplicate_sessions()
return session_id # # logger.info(f'Redis_Agent:{end_user_end};{session_id}')
async def term_memory_redis_read(self,end_user_end): # return session_id
end_user_end = f"Term_{end_user_end}"
history = store.find_user_apply_group(end_user_end, end_user_end, end_user_end) # TODO 乐力齐 - 累积多组对话批量写入功能已禁用
# logger.info(f'Redis_Agent:{end_user_end};{history}') # async def term_memory_redis_read(self,end_user_end):
messagss_list=[] # end_user_end = f"Term_{end_user_end}"
retrieved_content=[] # history = store.find_user_apply_group(end_user_end, end_user_end, end_user_end)
for messages in history: # # logger.info(f'Redis_Agent:{end_user_end};{history}')
query = messages.get("Query") # messagss_list=[]
aimessages = messages.get("Answer") # retrieved_content=[]
messagss_list.append(f'用户:{query}。AI回复:{aimessages}') # for messages in history:
retrieved_content.append({query: aimessages}) # query = messages.get("Query")
return messagss_list,retrieved_content # aimessages = messages.get("Answer")
# messagss_list.append(f'用户:{query}。AI回复:{aimessages}')
# retrieved_content.append({query: aimessages})
# return messagss_list,retrieved_content
async def write(self, storage_type, end_user_id, user_message, ai_message, user_rag_memory_id, actual_end_user_id, actual_config_id):
async def write(self,storage_type,end_user_id,message,user_rag_memory_id,actual_end_user_id,content,actual_config_id): """
写入记忆(支持结构化消息)
Args:
storage_type: 存储类型 (neo4j/rag)
end_user_id: 终端用户ID
user_message: 用户消息内容
ai_message: AI 回复内容
user_rag_memory_id: RAG 记忆ID
actual_end_user_id: 实际用户ID
actual_config_id: 配置ID
逻辑说明:
- RAG 模式:组合 user_message 和 ai_message 为字符串格式,保持原有逻辑不变
- Neo4j 模式:使用结构化消息列表
1. 如果 user_message 和 ai_message 都不为空:创建配对消息 [user, assistant]
2. 如果只有 user_message创建单条用户消息 [user](用于历史记忆场景)
3. 每条消息会被转换为独立的 Chunk保留 speaker 字段
"""
if storage_type == "rag": if storage_type == "rag":
await write_rag(end_user_id, message, user_rag_memory_id) # RAG 模式:组合消息为字符串格式(保持原有逻辑)
combined_message = f"user: {user_message}\nassistant: {ai_message}"
await write_rag(end_user_id, combined_message, user_rag_memory_id)
logger.info(f'RAG_Agent:{end_user_id};{user_rag_memory_id}') logger.info(f'RAG_Agent:{end_user_id};{user_rag_memory_id}')
else: else:
write_id = write_message_task.delay(actual_end_user_id, content, actual_config_id, storage_type, # Neo4j 模式:使用结构化消息列表
user_rag_memory_id) structured_messages = []
# 始终添加用户消息(如果不为空)
if user_message:
structured_messages.append({"role": "user", "content": user_message})
# 只有当 AI 回复不为空时才添加 assistant 消息
if ai_message:
structured_messages.append({"role": "assistant", "content": ai_message})
# 如果没有消息,直接返回
if not structured_messages:
logger.warning(f"No messages to write for user {actual_end_user_id}")
return
# 调用 Celery 任务,传递结构化消息列表
# 数据流:
# 1. structured_messages 传递给 write_message_task
# 2. write_message_task 调用 memory_agent_service.write_memory
# 3. write_memory 调用 write_tools.write传递 messages 参数
# 4. write_tools.write 调用 get_chunked_dialogs传递 messages 参数
# 5. get_chunked_dialogs 为每条消息创建独立的 Chunk设置 speaker 字段
# 6. 每个 Chunk 保存到 Neo4j包含 speaker 字段
logger.info(f"[WRITE] Submitting Celery task - user={actual_end_user_id}, messages={len(structured_messages)}, config={actual_config_id}")
write_id = write_message_task.delay(
actual_end_user_id, # group_id: 用户ID
structured_messages, # message: 结构化消息列表 [{"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]
actual_config_id, # config_id: 配置ID
storage_type, # storage_type: "neo4j"
user_rag_memory_id # user_rag_memory_id: RAG记忆IDNeo4j模式下不使用
)
logger.info(f"[WRITE] Celery task submitted - task_id={write_id}")
write_status = get_task_memory_write_result(str(write_id)) write_status = get_task_memory_write_result(str(write_id))
logger.info(f'Agent:{actual_end_user_id};{write_status}') logger.info(f'[WRITE] Task result - user={actual_end_user_id}, status={write_status}')
async def chat( async def chat(
self, self,
@@ -227,29 +281,30 @@ class LangChainAgent:
actual_end_user_id = end_user_id if end_user_id is not None else "unknown" actual_end_user_id = end_user_id if end_user_id is not None else "unknown"
logger.info(f'写入类型{storage_type,str(end_user_id), message, str(user_rag_memory_id)}') logger.info(f'写入类型{storage_type,str(end_user_id), message, str(user_rag_memory_id)}')
print(f'写入类型{storage_type,str(end_user_id), message, str(user_rag_memory_id)}') print(f'写入类型{storage_type,str(end_user_id), message, str(user_rag_memory_id)}')
# # TODO 乐力齐,在长短期记忆存储的时候再使用此代码
# history_term_memory_result = await self.term_memory_redis_read(end_user_id)
# history_term_memory = history_term_memory_result[0]
# db_for_memory = next(get_db())
# if memory_flag:
# if len(history_term_memory)>=4 and storage_type != "rag":
# history_term_memory = ';'.join(history_term_memory)
# retrieved_content = history_term_memory_result[1]
# print(retrieved_content)
# # 为长期记忆操作获取新的数据库连接
# try:
# repo = LongTermMemoryRepository(db_for_memory)
# repo.upsert(end_user_id, retrieved_content)
# logger.info(
# f'写入短长期:{storage_type, str(end_user_id), history_term_memory, str(user_rag_memory_id)}')
# except Exception as e:
# logger.error(f"Failed to write to LongTermMemory: {e}")
# raise
# finally:
# db_for_memory.close()
history_term_memory_result = await self.term_memory_redis_read(end_user_id) # # 长期记忆写入(
history_term_memory = history_term_memory_result[0] # await self.write(storage_type, actual_end_user_id, history_term_memory, "", user_rag_memory_id, actual_end_user_id, actual_config_id)
db_for_memory = next(get_db()) # # 注意:不在这里写入用户消息,等 AI 回复后一起写入
if memory_flag:
if len(history_term_memory)>=4 and storage_type != "rag":
history_term_memory = ';'.join(history_term_memory)
retrieved_content = history_term_memory_result[1]
print(retrieved_content)
# 为长期记忆操作获取新的数据库连接
try:
repo = LongTermMemoryRepository(db_for_memory)
repo.upsert(end_user_id, retrieved_content)
logger.info(
f'写入短长期:{storage_type, str(end_user_id), history_term_memory, str(user_rag_memory_id)}')
except Exception as e:
logger.error(f"Failed to write to LongTermMemory: {e}")
raise
finally:
db_for_memory.close()
await self.write(storage_type,end_user_id,history_term_memory,user_rag_memory_id,actual_end_user_id,history_term_memory,actual_config_id)
await self.write(storage_type,end_user_id,message,user_rag_memory_id,actual_end_user_id,message,actual_config_id)
try: try:
# 准备消息列表 # 准备消息列表
messages = self._prepare_messages(message, history, context) messages = self._prepare_messages(message, history, context)
@@ -277,8 +332,10 @@ class LangChainAgent:
elapsed_time = time.time() - start_time elapsed_time = time.time() - start_time
if memory_flag: if memory_flag:
await self.write(storage_type,end_user_id,content,user_rag_memory_id,actual_end_user_id,content,actual_config_id) # AI 回复写入(用户消息和 AI 回复配对,一次性写入完整对话)
await self.term_memory_save(message_chat,end_user_id,content) await self.write(storage_type, actual_end_user_id, message_chat, content, user_rag_memory_id, actual_end_user_id, actual_config_id)
# TODO 乐力齐 - 累积多组对话批量写入功能已禁用
# await self.term_memory_save(message_chat, end_user_id, content)
response = { response = {
"content": content, "content": content,
"model": self.model_name, "model": self.model_name,
@@ -346,27 +403,27 @@ class LangChainAgent:
db.close() db.close()
except Exception as e: except Exception as e:
logger.warning(f"Failed to get db session: {e}") logger.warning(f"Failed to get db session: {e}")
# # TODO 乐力齐
# history_term_memory_result = await self.term_memory_redis_read(end_user_id)
# history_term_memory = history_term_memory_result[0]
# if memory_flag:
# if len(history_term_memory) >= 4 and storage_type != "rag":
# history_term_memory = ';'.join(history_term_memory)
# retrieved_content = history_term_memory_result[1]
# db_for_memory = next(get_db())
# try:
# repo = LongTermMemoryRepository(db_for_memory)
# repo.upsert(end_user_id, retrieved_content)
# logger.info(
# f'写入短长期:{storage_type, str(end_user_id), history_term_memory, str(user_rag_memory_id)}')
# # 长期记忆写入
# await self.write(storage_type, end_user_id, history_term_memory, "", user_rag_memory_id, end_user_id, actual_config_id)
# except Exception as e:
# logger.error(f"Failed to write to long term memory: {e}")
# finally:
# db_for_memory.close()
history_term_memory_result = await self.term_memory_redis_read(end_user_id) # 注意:不在这里写入用户消息,等 AI 回复后一起写入
history_term_memory = history_term_memory_result[0]
if memory_flag:
if len(history_term_memory) >= 4 and storage_type != "rag":
history_term_memory = ';'.join(history_term_memory)
retrieved_content = history_term_memory_result[1]
db_for_memory = next(get_db())
try:
repo = LongTermMemoryRepository(db_for_memory)
repo.upsert(end_user_id, retrieved_content)
logger.info(
f'写入短长期:{storage_type, str(end_user_id), history_term_memory, str(user_rag_memory_id)}')
await self.write(storage_type, end_user_id, history_term_memory, user_rag_memory_id, end_user_id,
history_term_memory, actual_config_id)
except Exception as e:
logger.error(f"Failed to write to long term memory: {e}")
finally:
db_for_memory.close()
await self.write(storage_type, end_user_id, message, user_rag_memory_id, end_user_id, message, actual_config_id)
try: try:
# 准备消息列表 # 准备消息列表
messages = self._prepare_messages(message, history, context) messages = self._prepare_messages(message, history, context)
@@ -418,8 +475,10 @@ class LangChainAgent:
logger.debug(f"Agent 流式完成,共 {chunk_count} 个事件") logger.debug(f"Agent 流式完成,共 {chunk_count} 个事件")
if memory_flag: if memory_flag:
await self.write(storage_type, end_user_id,full_content, user_rag_memory_id, end_user_id,full_content, actual_config_id) # AI 回复写入(用户消息和 AI 回复配对,一次性写入完整对话)
await self.term_memory_save(message_chat, end_user_id, full_content) await self.write(storage_type, end_user_id, message_chat, full_content, user_rag_memory_id, end_user_id, actual_config_id)
# TODO 乐力齐 - 累积多组对话批量写入功能已禁用
# await self.term_memory_save(message_chat, end_user_id, full_content)
except Exception as e: except Exception as e:
logger.error(f"Agent astream_events 失败: {str(e)}", exc_info=True) logger.error(f"Agent astream_events 失败: {str(e)}", exc_info=True)

View File

@@ -18,16 +18,19 @@ template_root = os.path.join(PROJECT_ROOT_, 'agent', 'utils', 'prompt')
db_session = next(get_db()) db_session = next(get_db())
logger = get_agent_logger(__name__) logger = get_agent_logger(__name__)
class ProblemNodeService(LLMServiceMixin): class ProblemNodeService(LLMServiceMixin):
"""问题处理节点服务类""" """问题处理节点服务类"""
def __init__(self): def __init__(self):
super().__init__() super().__init__()
self.template_service = TemplateService(template_root) self.template_service = TemplateService(template_root)
# 创建全局服务实例 # 创建全局服务实例
problem_service = ProblemNodeService() problem_service = ProblemNodeService()
async def Split_The_Problem(state: ReadState) -> ReadState: async def Split_The_Problem(state: ReadState) -> ReadState:
"""问题分解节点""" """问题分解节点"""
# 从状态中获取数据 # 从状态中获取数据
@@ -36,10 +39,10 @@ async def Split_The_Problem(state: ReadState) -> ReadState:
memory_config = state.get('memory_config', None) memory_config = state.get('memory_config', None)
history = await SessionService(store).get_history(group_id, group_id, group_id) history = await SessionService(store).get_history(group_id, group_id, group_id)
# 生成 JSON schema 以指导 LLM 输出正确格式 # 生成 JSON schema 以指导 LLM 输出正确格式
json_schema = ProblemExtensionResponse.model_json_schema() json_schema = ProblemExtensionResponse.model_json_schema()
system_prompt = await problem_service.template_service.render_template( system_prompt = await problem_service.template_service.render_template(
template_name='problem_breakdown_prompt.jinja2', template_name='problem_breakdown_prompt.jinja2',
operation_name='split_the_problem', operation_name='split_the_problem',
@@ -47,7 +50,7 @@ async def Split_The_Problem(state: ReadState) -> ReadState:
sentence=content, sentence=content,
json_schema=json_schema json_schema=json_schema
) )
try: try:
# 使用优化的LLM服务 # 使用优化的LLM服务
structured = await problem_service.call_llm_structured( structured = await problem_service.call_llm_structured(
@@ -57,10 +60,10 @@ async def Split_The_Problem(state: ReadState) -> ReadState:
response_model=ProblemExtensionResponse, response_model=ProblemExtensionResponse,
fallback_value=[] fallback_value=[]
) )
# 添加更详细的日志记录 # 添加更详细的日志记录
logger.info(f"Split_The_Problem: 开始处理问题分解,内容长度: {len(content)}") logger.info(f"Split_The_Problem: 开始处理问题分解,内容长度: {len(content)}")
# 验证结构化响应 # 验证结构化响应
if not structured or not hasattr(structured, 'root'): if not structured or not hasattr(structured, 'root'):
logger.warning("Split_The_Problem: 结构化响应为空或格式不正确") logger.warning("Split_The_Problem: 结构化响应为空或格式不正确")
@@ -73,17 +76,17 @@ async def Split_The_Problem(state: ReadState) -> ReadState:
[item.model_dump() for item in structured.root], [item.model_dump() for item in structured.root],
ensure_ascii=False ensure_ascii=False
) )
split_result_dict = [] split_result_dict = []
for index, item in enumerate(json.loads(split_result)): for index, item in enumerate(json.loads(split_result)):
split_data = { split_data = {
"id": f"Q{index+1}", "id": f"Q{index + 1}",
"question": item['extended_question'], "question": item['extended_question'],
"type": item['type'], "type": item['type'],
"reason": item['reason'] "reason": item['reason']
} }
split_result_dict.append(split_data) split_result_dict.append(split_data)
logger.info(f"Split_The_Problem: 成功生成 {len(structured.root) if structured.root else 0} 个分解项") logger.info(f"Split_The_Problem: 成功生成 {len(structured.root) if structured.root else 0} 个分解项")
result = { result = {
@@ -96,13 +99,13 @@ async def Split_The_Problem(state: ReadState) -> ReadState:
"original_query": content "original_query": content
} }
} }
except Exception as e: except Exception as e:
logger.error( logger.error(
f"Split_The_Problem failed: {e}", f"Split_The_Problem failed: {e}",
exc_info=True exc_info=True
) )
# 提供更详细的错误信息 # 提供更详细的错误信息
error_details = { error_details = {
"error_type": type(e).__name__, "error_type": type(e).__name__,
@@ -110,9 +113,9 @@ async def Split_The_Problem(state: ReadState) -> ReadState:
"content_length": len(content), "content_length": len(content),
"llm_model_id": memory_config.llm_model_id if memory_config else None "llm_model_id": memory_config.llm_model_id if memory_config else None
} }
logger.error(f"Split_The_Problem error details: {error_details}") logger.error(f"Split_The_Problem error details: {error_details}")
# 创建默认的空结果 # 创建默认的空结果
result = { result = {
"context": json.dumps([], ensure_ascii=False), "context": json.dumps([], ensure_ascii=False),
@@ -126,10 +129,11 @@ async def Split_The_Problem(state: ReadState) -> ReadState:
"error": error_details "error": error_details
} }
} }
# 返回更新后的状态包含spit_context字段 # 返回更新后的状态包含spit_context字段
return {"spit_data": result} return {"spit_data": result}
async def Problem_Extension(state: ReadState) -> ReadState: async def Problem_Extension(state: ReadState) -> ReadState:
"""问题扩展节点""" """问题扩展节点"""
# 获取原始数据和分解结果 # 获取原始数据和分解结果
@@ -153,10 +157,10 @@ async def Problem_Extension(state: ReadState) -> ReadState:
data = [] data = []
history = await SessionService(store).get_history(group_id, group_id, group_id) history = await SessionService(store).get_history(group_id, group_id, group_id)
# 生成 JSON schema 以指导 LLM 输出正确格式 # 生成 JSON schema 以指导 LLM 输出正确格式
json_schema = ProblemExtensionResponse.model_json_schema() json_schema = ProblemExtensionResponse.model_json_schema()
system_prompt = await problem_service.template_service.render_template( system_prompt = await problem_service.template_service.render_template(
template_name='Problem_Extension_prompt.jinja2', template_name='Problem_Extension_prompt.jinja2',
operation_name='problem_extension', operation_name='problem_extension',
@@ -242,7 +246,4 @@ async def Problem_Extension(state: ReadState) -> ReadState:
} }
} }
return {"problem_extension": result} return {"problem_extension": result}

View File

@@ -4,12 +4,11 @@ import os
import time import time
from app.core.logging_config import get_agent_logger, log_time from app.core.logging_config import get_agent_logger, log_time
from app.db import get_db
from app.core.memory.agent.models.summary_models import ( from app.core.memory.agent.models.summary_models import (
RetrieveSummaryResponse, RetrieveSummaryResponse,
SummaryResponse, SummaryResponse,
) )
from app.core.memory.agent.services.optimized_llm_service import LLMServiceMixin
from app.core.memory.agent.services.search_service import SearchService from app.core.memory.agent.services.search_service import SearchService
from app.core.memory.agent.utils.llm_tools import ( from app.core.memory.agent.utils.llm_tools import (
PROJECT_ROOT_, PROJECT_ROOT_,
@@ -18,7 +17,7 @@ from app.core.memory.agent.utils.llm_tools import (
from app.core.memory.agent.utils.redis_tool import store from app.core.memory.agent.utils.redis_tool import store
from app.core.memory.agent.utils.session_tools import SessionService from app.core.memory.agent.utils.session_tools import SessionService
from app.core.memory.agent.utils.template_tools import TemplateService from app.core.memory.agent.utils.template_tools import TemplateService
from app.core.memory.agent.services.optimized_llm_service import LLMServiceMixin from app.db import get_db
template_root = os.path.join(PROJECT_ROOT_, 'agent', 'utils', 'prompt') template_root = os.path.join(PROJECT_ROOT_, 'agent', 'utils', 'prompt')
logger = get_agent_logger(__name__) logger = get_agent_logger(__name__)
@@ -182,7 +181,8 @@ async def Input_Summary(state: ReadState) -> ReadState:
search_params = { search_params = {
"group_id": group_id, "group_id": group_id,
"question": data, "question": data,
"return_raw_results": True "return_raw_results": True,
"include": ["summaries"] # Only search summary nodes for faster performance
} }
try: try:

View File

@@ -9,22 +9,29 @@ async def write_node(state: WriteState) -> WriteState:
Write data to the database/file system. Write data to the database/file system.
Args: Args:
ctx: FastMCP context for dependency injection state: WriteState containing messages, group_id, and memory_config
content: Data content to write
user_id: User identifier
apply_id: Application identifier
group_id: Group identifier
memory_config: MemoryConfig object containing all configuration
Returns: Returns:
dict: Contains 'status', 'saved_to', and 'data' fields dict: Contains 'write_result' with status and data fields
""" """
content=state.get('data','') messages = state.get('messages', [])
group_id=state.get('group_id','') group_id = state.get('group_id', '')
memory_config=state.get('memory_config', '') memory_config = state.get('memory_config', '')
# Convert LangChain messages to structured format expected by write()
structured_messages = []
for msg in messages:
if hasattr(msg, 'type') and hasattr(msg, 'content'):
# Map LangChain message types to role names
role = 'user' if msg.type == 'human' else 'assistant' if msg.type == 'ai' else msg.type
structured_messages.append({
"role": role,
"content": msg.content # content is now guaranteed to be a string
})
try: try:
result=await write( result = await write(
content=content, messages=structured_messages,
user_id=group_id, user_id=group_id,
apply_id=group_id, apply_id=group_id,
group_id=group_id, group_id=group_id,
@@ -32,18 +39,17 @@ async def write_node(state: WriteState) -> WriteState:
) )
logger.info(f"Write completed successfully! Config: {memory_config.config_name}") logger.info(f"Write completed successfully! Config: {memory_config.config_name}")
write_result= { write_result = {
"status": "success", "status": "success",
"data": content, "data": structured_messages,
"config_id": memory_config.config_id, "config_id": memory_config.config_id,
"config_name": memory_config.config_name, "config_name": memory_config.config_name,
} }
return {"write_result":write_result} return {"write_result": write_result}
except Exception as e: except Exception as e:
logger.error(f"Data_write failed: {e}", exc_info=True) logger.error(f"Data_write failed: {e}", exc_info=True)
write_result= { write_result = {
"status": "error", "status": "error",
"message": str(e), "message": str(e),
} }

View File

@@ -59,7 +59,6 @@ async def make_read_graph():
workflow.add_conditional_edges("Retrieve", Retrieve_continue) workflow.add_conditional_edges("Retrieve", Retrieve_continue)
workflow.add_edge("Retrieve_Summary", END) workflow.add_edge("Retrieve_Summary", END)
workflow.add_conditional_edges("Verify", Verify_continue) workflow.add_conditional_edges("Verify", Verify_continue)
workflow.add_edge("Summary_fails", END) workflow.add_edge("Summary_fails", END)
workflow.add_edge("Summary", END) workflow.add_edge("Summary", END)

View File

@@ -14,7 +14,6 @@ from app.db import get_db
from app.core.logging_config import get_agent_logger from app.core.logging_config import get_agent_logger
from app.core.memory.agent.utils.llm_tools import WriteState from app.core.memory.agent.utils.llm_tools import WriteState
from app.core.memory.agent.langgraph_graph.nodes.write_nodes import write_node from app.core.memory.agent.langgraph_graph.nodes.write_nodes import write_node
from app.core.memory.agent.langgraph_graph.nodes.data_nodes import content_input_write
from app.services.memory_config_service import MemoryConfigService from app.services.memory_config_service import MemoryConfigService
warnings.filterwarnings("ignore", category=RuntimeWarning) warnings.filterwarnings("ignore", category=RuntimeWarning)
@@ -27,18 +26,12 @@ async def make_write_graph():
""" """
Create a write graph workflow for memory operations. Create a write graph workflow for memory operations.
Args: The workflow directly processes messages from the initial state
user_id: User identifier and saves them to Neo4j storage.
tools: MCP tools loaded from session
apply_id: Application identifier
group_id: Group identifier
memory_config: MemoryConfig object containing all configuration
""" """
workflow = StateGraph(WriteState) workflow = StateGraph(WriteState)
workflow.add_node("content_input", content_input_write)
workflow.add_node("save_neo4j", write_node) workflow.add_node("save_neo4j", write_node)
workflow.add_edge(START, "content_input") workflow.add_edge(START, "save_neo4j")
workflow.add_edge("content_input", "save_neo4j")
workflow.add_edge("save_neo4j", END) workflow.add_edge("save_neo4j", END)
graph = workflow.compile() graph = workflow.compile()

View File

@@ -162,7 +162,7 @@ class OptimizedLLMService:
return fallback_value return fallback_value
elif isinstance(fallback_value, dict): elif isinstance(fallback_value, dict):
return response_model(**fallback_value) return response_model(**fallback_value)
# 尝试创建空的响应模型 # 尝试创建空的响应模型
if hasattr(response_model, 'root'): if hasattr(response_model, 'root'):
# RootModel类型 # RootModel类型
@@ -170,7 +170,7 @@ class OptimizedLLMService:
else: else:
# 普通BaseModel类型 # 普通BaseModel类型
return response_model() return response_model()
except Exception as e: except Exception as e:
logger.error(f"创建降级响应失败: {e}") logger.error(f"创建降级响应失败: {e}")
# 最后的降级策略 # 最后的降级策略

View File

@@ -12,32 +12,49 @@ async def get_chunked_dialogs(
group_id: str = "group_1", group_id: str = "group_1",
user_id: str = "user1", user_id: str = "user1",
apply_id: str = "applyid", apply_id: str = "applyid",
content: str = "这是用户的输入", messages: list = None,
ref_id: str = "wyl_20251027", ref_id: str = "wyl_20251027",
config_id: str = None config_id: str = None
) -> List[DialogData]: ) -> List[DialogData]:
"""Generate chunks from all test data entries using the specified chunker strategy. """Generate chunks from structured messages using the specified chunker strategy.
Args: Args:
chunker_strategy: The chunking strategy to use (default: RecursiveChunker) chunker_strategy: The chunking strategy to use (default: RecursiveChunker)
group_id: Group identifier group_id: Group identifier
user_id: User identifier user_id: User identifier
apply_id: Application identifier apply_id: Application identifier
content: Dialog content messages: Structured message list [{"role": "user", "content": "..."}, ...]
ref_id: Reference identifier ref_id: Reference identifier
config_id: Configuration ID for processing config_id: Configuration ID for processing
Returns: Returns:
List of DialogData objects with generated chunks for each test entry List of DialogData objects with generated chunks
""" """
dialog_data_list = [] from app.core.logging_config import get_agent_logger
messages = [] logger = get_agent_logger(__name__)
messages.append(ConversationMessage(role="用户", msg=content)) if not messages or not isinstance(messages, list) or len(messages) == 0:
raise ValueError("messages parameter must be a non-empty list")
# Create DialogData
conversation_context = ConversationContext(msgs=messages) conversation_messages = []
# Create DialogData with group_id based on the entry's id for uniqueness
for idx, msg in enumerate(messages):
if not isinstance(msg, dict) or 'role' not in msg or 'content' not in msg:
raise ValueError(f"Message {idx} format error: must contain 'role' and 'content' fields")
role = msg['role']
content = msg['content']
if role not in ['user', 'assistant']:
raise ValueError(f"Message {idx} role must be 'user' or 'assistant', got: {role}")
if content.strip():
conversation_messages.append(ConversationMessage(role=role, msg=content.strip()))
if not conversation_messages:
raise ValueError("Message list cannot be empty after filtering")
conversation_context = ConversationContext(msgs=conversation_messages)
dialog_data = DialogData( dialog_data = DialogData(
context=conversation_context, context=conversation_context,
ref_id=ref_id, ref_id=ref_id,
@@ -46,25 +63,11 @@ async def get_chunked_dialogs(
apply_id=apply_id, apply_id=apply_id,
config_id=config_id config_id=config_id
) )
# Create DialogueChunker and process the dialogue
chunker = DialogueChunker(chunker_strategy) chunker = DialogueChunker(chunker_strategy)
extracted_chunks = await chunker.process_dialogue(dialog_data) extracted_chunks = await chunker.process_dialogue(dialog_data)
dialog_data.chunks = extracted_chunks dialog_data.chunks = extracted_chunks
logger.info(f"DialogData created with {len(extracted_chunks)} chunks")
dialog_data_list.append(dialog_data) return [dialog_data]
# Convert to dict with datetime serialized
def serialize_datetime(obj):
if isinstance(obj, datetime):
return obj.isoformat()
raise TypeError(f"Object of type {obj.__class__.__name__} is not JSON serializable")
combined_output = [dd.model_dump() for dd in dialog_data_list]
print(dialog_data_list)
# with open(os.path.join(os.path.dirname(__file__), "chunker_test_output.txt"), "w", encoding="utf-8") as f:
# json.dump(combined_output, f, ensure_ascii=False, indent=4, default=serialize_datetime)
return dialog_data_list

View File

@@ -29,25 +29,22 @@ logger = get_agent_logger(__name__)
async def write( async def write(
content: str,
user_id: str, user_id: str,
apply_id: str, apply_id: str,
group_id: str, group_id: str,
memory_config: MemoryConfig, memory_config: MemoryConfig,
messages: list,
ref_id: str = "wyl20251027", ref_id: str = "wyl20251027",
) -> None: ) -> None:
""" """
Execute the complete knowledge extraction pipeline. Execute the complete knowledge extraction pipeline.
Only MemoryConfig is needed - LLM and embedding clients are constructed
internally from the config.
Args: Args:
content: Dialogue content to process
user_id: User identifier user_id: User identifier
apply_id: Application identifier apply_id: Application identifier
group_id: Group identifier group_id: Group identifier
memory_config: MemoryConfig object containing all configuration memory_config: MemoryConfig object containing all configuration
messages: Structured message list [{"role": "user", "content": "..."}, ...]
ref_id: Reference ID, defaults to "wyl20251027" ref_id: Reference ID, defaults to "wyl20251027"
""" """
# Extract config values # Extract config values
@@ -89,7 +86,7 @@ async def write(
group_id=group_id, group_id=group_id,
user_id=user_id, user_id=user_id,
apply_id=apply_id, apply_id=apply_id,
content=content, messages=messages,
ref_id=ref_id, ref_id=ref_id,
config_id=config_id, config_id=config_id,
) )

View File

@@ -4,6 +4,7 @@ import os
import asyncio import asyncio
import json import json
import numpy as np import numpy as np
import logging
# Fix tokenizer parallelism warning # Fix tokenizer parallelism warning
os.environ["TOKENIZERS_PARALLELISM"] = "false" os.environ["TOKENIZERS_PARALLELISM"] = "false"
@@ -23,28 +24,29 @@ from app.core.memory.models.message_models import DialogData, Chunk
try: try:
from app.core.memory.llm_tools.openai_client import OpenAIClient from app.core.memory.llm_tools.openai_client import OpenAIClient
except Exception: except Exception:
# 在测试或无可用依赖(如 langfuse环境下允许惰性导入
OpenAIClient = Any OpenAIClient = Any
# Initialize logger
logger = logging.getLogger(__name__)
class LLMChunker: class LLMChunker:
"""基于LLM的智能分块策略""" """LLM-based intelligent chunking strategy"""
def __init__(self, llm_client: OpenAIClient, chunk_size: int = 1000): def __init__(self, llm_client: OpenAIClient, chunk_size: int = 1000):
self.llm_client = llm_client self.llm_client = llm_client
self.chunk_size = chunk_size self.chunk_size = chunk_size
async def __call__(self, text: str) -> List[Any]: async def __call__(self, text: str) -> List[Any]:
# 使用LLM分析文本结构并进行智能分块
prompt = f""" prompt = f"""
请将以下文本分割成语义连贯的段落。每个段落应该围绕一个主题,长度大约在{self.chunk_size}字符左右。 Split the following text into semantically coherent paragraphs. Each paragraph should focus on one topic, approximately {self.chunk_size} characters long.
请以JSON格式返回结果包含chunks数组每个chunk有text字段。 Return results in JSON format with a chunks array, each chunk having a text field.
文本内容: Text content:
{text[:5000]} {text[:5000]}
""" """
messages = [ messages = [
{"role": "system", "content": "你是一个专业的文本分析助手,擅长将长文本分割成语义连贯的段落。"}, {"role": "system", "content": "You are a professional text analysis assistant, skilled at splitting long texts into semantically coherent paragraphs."},
{"role": "user", "content": prompt} {"role": "user", "content": prompt}
] ]
@@ -171,8 +173,6 @@ class ChunkerClient:
base_chunk_size=self.chunk_size, base_chunk_size=self.chunk_size,
) )
elif chunker_config.chunker_strategy == "SentenceChunker": elif chunker_config.chunker_strategy == "SentenceChunker":
# 某些 chonkie 版本的 SentenceChunker 不支持 tokenizer_or_token_counter 参数
# 为了兼容不同版本,这里仅传递广泛支持的参数
self.chunker = SentenceChunker( self.chunker = SentenceChunker(
chunk_size=self.chunk_size, chunk_size=self.chunk_size,
chunk_overlap=self.chunk_overlap, chunk_overlap=self.chunk_overlap,
@@ -186,100 +186,93 @@ class ChunkerClient:
async def generate_chunks(self, dialogue: DialogData): async def generate_chunks(self, dialogue: DialogData):
""" """
生成分块,支持异步操作 Generate chunks following 1 Message = 1 Chunk strategy.
Each message creates one chunk, directly inheriting role information.
If a message is too long, it will be split into multiple sub-chunks,
each maintaining the same speaker.
Raises:
ValueError: If dialogue has no messages or chunking fails
""" """
try: # Validate dialogue has messages
# 预处理文本:确保对话标记格式统一 if not dialogue.context or not dialogue.context.msgs:
content = dialogue.content raise ValueError(
content = content.replace('AI', 'AI:').replace('用户:', '用户:') # 统一冒号 f"Dialogue {dialogue.ref_id} has no messages. "
content = re.sub(r'(\n\s*)+\n', '\n\n', content) # 合并多个空行 f"Cannot generate chunks from empty dialogue."
)
if hasattr(self.chunker, '__call__') and not asyncio.iscoroutinefunction(self.chunker.__call__):
# 同步分块器 dialogue.chunks = []
chunks = self.chunker(content)
# 按消息分块:每个消息创建一个或多个 chunk直接继承角色
for msg_idx, msg in enumerate(dialogue.context.msgs):
# Validate message has required attributes
if not hasattr(msg, 'role') or not hasattr(msg, 'msg'):
raise ValueError(
f"Message {msg_idx} in dialogue {dialogue.ref_id} "
f"missing 'role' or 'msg' attribute"
)
msg_content = msg.msg.strip()
# Skip empty messages
if not msg_content:
continue
# 如果消息太长,可以进一步分块
if len(msg_content) > self.chunk_size:
# 对单个消息的内容进行分块
try:
sub_chunks = self.chunker(msg_content)
except Exception as e:
raise ValueError(
f"Failed to chunk long message {msg_idx} in dialogue {dialogue.ref_id}: {e}"
)
for idx, sub_chunk in enumerate(sub_chunks):
sub_chunk_text = sub_chunk.text if hasattr(sub_chunk, 'text') else str(sub_chunk)
sub_chunk_text = sub_chunk_text.strip()
if len(sub_chunk_text) < (self.min_characters_per_chunk or 50):
continue
chunk = Chunk(
content=f"{msg.role}: {sub_chunk_text}",
speaker=msg.role, # 直接继承角色
metadata={
"message_index": msg_idx,
"message_role": msg.role,
"sub_chunk_index": idx,
"total_sub_chunks": len(sub_chunks),
"chunker_strategy": self.chunker_config.chunker_strategy,
},
)
dialogue.chunks.append(chunk)
else: else:
# 异步分块器如LLMChunker # 消息不长,直接作为一个 chunk
chunks = await self.chunker(content) chunk = Chunk(
content=f"{msg.role}: {msg_content}",
# 过滤空块和过小的块 speaker=msg.role, # 直接继承角色
valid_chunks = []
for c in chunks:
chunk_text = getattr(c, 'text', str(c)) if not isinstance(c, str) else c
if isinstance(chunk_text, str) and len(chunk_text.strip()) >= (self.min_characters_per_chunk or 50):
valid_chunks.append(c)
dialogue.chunks = [
Chunk(
content=c.text if hasattr(c, 'text') else str(c),
metadata={ metadata={
"start_index": getattr(c, "start_index", None), "message_index": msg_idx,
"end_index": getattr(c, "end_index", None), "message_role": msg.role,
"chunker_strategy": self.chunker_config.chunker_strategy, "chunker_strategy": self.chunker_config.chunker_strategy,
}, },
) )
for c in valid_chunks dialogue.chunks.append(chunk)
]
return dialogue # Validate we generated at least one chunk
if not dialogue.chunks:
except Exception as e: raise ValueError(
print(f"分块失败: {e}") f"No valid chunks generated for dialogue {dialogue.ref_id}. "
f"All messages were either empty or too short. "
# 改进的后备方案:尝试按对话回合分割 f"Messages count: {len(dialogue.context.msgs)}"
try: )
# 简单的按对话分割
dialogue_pattern = r'(AI:|用户:)(.*?)(?=AI:|用户:|$)' return dialogue
matches = re.findall(dialogue_pattern, dialogue.content, re.DOTALL)
class SimpleChunk:
def __init__(self, text, start_index, end_index):
self.text = text
self.start_index = start_index
self.end_index = end_index
chunks = []
current_chunk = ""
current_start = 0
for match in matches:
speaker, ct = match[0], match[1].strip()
turn_text = f"{speaker} {ct}"
if len(current_chunk) + len(turn_text) > (self.chunk_size or 500):
if current_chunk:
chunks.append(SimpleChunk(current_chunk, current_start, current_start + len(current_chunk)))
current_chunk = turn_text
current_start = dialogue.content.find(turn_text, current_start)
else:
current_chunk += ("\n" + turn_text) if current_chunk else turn_text
if current_chunk:
chunks.append(SimpleChunk(current_chunk, current_start, current_start + len(current_chunk)))
dialogue.chunks = [
Chunk(
content=c.text,
metadata={
"start_index": c.start_index,
"end_index": c.end_index,
"chunker_strategy": "DialogueTurnFallback",
},
)
for c in chunks
]
except Exception:
# 最后的手段:单一大块
dialogue.chunks = [Chunk(
content=dialogue.content,
metadata={"chunker_strategy": "SingleChunkFallback"},
)]
return dialogue
def evaluate_chunking(self, dialogue: DialogData) -> dict: def evaluate_chunking(self, dialogue: DialogData) -> dict:
""" """Evaluate chunking quality."""
评估分块质量
"""
if not getattr(dialogue, 'chunks', None): if not getattr(dialogue, 'chunks', None):
return {} return {}
@@ -304,11 +297,8 @@ class ChunkerClient:
return metrics return metrics
def save_chunking_results(self, dialogue: DialogData, output_path: str): def save_chunking_results(self, dialogue: DialogData, output_path: str):
""" """Save chunking results to file with strategy name in filename."""
保存分块结果到文件,文件名包含策略名称
"""
strategy_name = self.chunker_config.chunker_strategy strategy_name = self.chunker_config.chunker_strategy
# 在文件名中添加策略名称
base_name, ext = os.path.splitext(output_path) base_name, ext = os.path.splitext(output_path)
strategy_output_path = f"{base_name}_{strategy_name}{ext}" strategy_output_path = f"{base_name}_{strategy_name}{ext}"

View File

@@ -92,8 +92,6 @@ class OpenAIClient(LLMClient):
config["callbacks"] = [self.langfuse_handler] config["callbacks"] = [self.langfuse_handler]
response = await chain.ainvoke({"messages": messages}, config=config) response = await chain.ainvoke({"messages": messages}, config=config)
logger.debug(f"LLM 响应成功: {len(str(response))} 字符")
return response return response
except Exception as e: except Exception as e:
@@ -149,13 +147,10 @@ class OpenAIClient(LLMClient):
config=config config=config
) )
logger.debug(f"使用 PydanticOutputParser 解析成功")
return parsed return parsed
except Exception as e: except Exception as e:
logger.warning( logger.debug(f"PydanticOutputParser 解析失败,尝试备用方法: {e}")
f"PydanticOutputParser 解析失败,尝试其他方法: {e}"
)
# 方法 2: 使用 LangChain 的 with_structured_output # 方法 2: 使用 LangChain 的 with_structured_output
template = """{question}""" template = """{question}"""
@@ -173,13 +168,17 @@ class OpenAIClient(LLMClient):
# 验证并返回结果 # 验证并返回结果
try: try:
return response_model.model_validate(parsed) result = response_model.model_validate(parsed)
return result
except Exception: except Exception:
# 如果已经是 Pydantic 实例,直接返回 # 如果已经是 Pydantic 实例,直接返回
if hasattr(parsed, "model_dump"): if hasattr(parsed, "model_dump"):
return parsed return parsed
# 尝试从 JSON 解析 # 尝试从 JSON 解析
return response_model.model_validate_json(json.dumps(parsed)) result = response_model.model_validate_json(json.dumps(parsed))
return result
else:
logger.warning("with_structured_output 方法不可用")
except Exception as e: except Exception as e:
logger.error(f"结构化输出失败: {e}") logger.error(f"结构化输出失败: {e}")

View File

@@ -224,6 +224,7 @@ class StatementNode(Node):
chunk_id: ID of the parent chunk this statement belongs to chunk_id: ID of the parent chunk this statement belongs to
stmt_type: Type of the statement (from ontology) stmt_type: Type of the statement (from ontology)
statement: The actual statement text content statement: The actual statement text content
speaker: Optional speaker identifier ('用户' for user messages, 'AI' for AI responses)
emotion_intensity: Optional emotion intensity (0.0-1.0) - displayed on node emotion_intensity: Optional emotion intensity (0.0-1.0) - displayed on node
emotion_target: Optional emotion target (person or object name) emotion_target: Optional emotion target (person or object name)
emotion_subject: Optional emotion subject (self/other/object) emotion_subject: Optional emotion subject (self/other/object)
@@ -249,6 +250,12 @@ class StatementNode(Node):
stmt_type: str = Field(..., description="Type of the statement") stmt_type: str = Field(..., description="Type of the statement")
statement: str = Field(..., description="The statement text content") statement: str = Field(..., description="The statement text content")
# Speaker identification
speaker: Optional[str] = Field(
None,
description="Speaker identifier: 'user' for user messages, 'assistant' for AI responses"
)
# Emotion fields (ordered as requested, emotion_intensity first for display) # Emotion fields (ordered as requested, emotion_intensity first for display)
emotion_intensity: Optional[float] = Field( emotion_intensity: Optional[float] = Field(
None, None,

View File

@@ -25,10 +25,10 @@ class ConversationMessage(BaseModel):
"""Represents a single message in a conversation. """Represents a single message in a conversation.
Attributes: Attributes:
role: Role of the speaker (e.g., '用户' for user, 'AI' for assistant) role: Role of the speaker (e.g., 'user' for user, 'assistant' for AI assistant)
msg: Text content of the message msg: Text content of the message
""" """
role: str = Field(..., description="The role of the speaker (e.g., '用户', 'AI').") role: str = Field(..., description="The role of the speaker (e.g., 'user', 'assistant').")
msg: str = Field(..., description="The text content of the message.") msg: str = Field(..., description="The text content of the message.")
@@ -57,6 +57,7 @@ class Statement(BaseModel):
chunk_id: ID of the parent chunk this statement belongs to chunk_id: ID of the parent chunk this statement belongs to
group_id: Optional group ID for multi-tenancy group_id: Optional group ID for multi-tenancy
statement: The actual statement text content statement: The actual statement text content
speaker: Optional speaker identifier ('用户' for user, 'AI' for AI responses)
statement_embedding: Optional embedding vector for the statement statement_embedding: Optional embedding vector for the statement
stmt_type: Type of the statement (from ontology) stmt_type: Type of the statement (from ontology)
temporal_info: Temporal information extracted from the statement temporal_info: Temporal information extracted from the statement
@@ -74,6 +75,7 @@ class Statement(BaseModel):
chunk_id: str = Field(..., description="ID of the parent chunk this statement belongs to.") chunk_id: str = Field(..., description="ID of the parent chunk this statement belongs to.")
group_id: Optional[str] = Field(None, description="ID of the group this statement belongs to.") group_id: Optional[str] = Field(None, description="ID of the group this statement belongs to.")
statement: str = Field(..., description="The text content of the statement.") statement: str = Field(..., description="The text content of the statement.")
speaker: Optional[str] = Field(None, description="Speaker identifier: 'user' for user messages, 'assistant' for AI responses")
statement_embedding: Optional[List[float]] = Field(None, description="The embedding vector of the statement.") statement_embedding: Optional[List[float]] = Field(None, description="The embedding vector of the statement.")
stmt_type: StatementType = Field(..., description="The type of the statement.") stmt_type: StatementType = Field(..., description="The type of the statement.")
temporal_info: TemporalInfo = Field(..., description="The temporal information of the statement.") temporal_info: TemporalInfo = Field(..., description="The temporal information of the statement.")
@@ -118,36 +120,36 @@ class Chunk(BaseModel):
Attributes: Attributes:
id: Unique identifier for the chunk id: Unique identifier for the chunk
text: List of messages in the chunk
content: The content of the chunk as a formatted string content: The content of the chunk as a formatted string
speaker: The speaker/role for this chunk (user/assistant)
statements: List of statements extracted from this chunk statements: List of statements extracted from this chunk
chunk_embedding: Optional embedding vector for the chunk chunk_embedding: Optional embedding vector for the chunk
metadata: Additional metadata as key-value pairs metadata: Additional metadata as key-value pairs
""" """
id: str = Field(default_factory=lambda: uuid4().hex, description="A unique identifier for the chunk.") id: str = Field(default_factory=lambda: uuid4().hex, description="A unique identifier for the chunk.")
text: List[ConversationMessage] = Field(default_factory=list, description="A list of messages in the chunk.")
content: str = Field(..., description="The content of the chunk as a string.") content: str = Field(..., description="The content of the chunk as a string.")
speaker: Optional[str] = Field(None, description="The speaker/role for this chunk (user/assistant).")
statements: List[Statement] = Field(default_factory=list, description="A list of statements in the chunk.") statements: List[Statement] = Field(default_factory=list, description="A list of statements in the chunk.")
chunk_embedding: Optional[List[float]] = Field(None, description="The embedding vector of the chunk.") chunk_embedding: Optional[List[float]] = Field(None, description="The embedding vector of the chunk.")
metadata: Dict[str, Any] = Field(default_factory=dict, description="Additional metadata for the chunk.") metadata: Dict[str, Any] = Field(default_factory=dict, description="Additional metadata for the chunk.")
@classmethod @classmethod
def from_messages(cls, messages: List[ConversationMessage], metadata: Optional[Dict[str, Any]] = None): def from_single_message(cls, message: ConversationMessage, metadata: Optional[Dict[str, Any]] = None):
"""Create a chunk from a list of messages. """Create a chunk from a single message (1 Message = 1 Chunk).
Args: Args:
messages: List of conversation messages message: Single conversation message
metadata: Optional metadata dictionary metadata: Optional metadata dictionary
Returns: Returns:
Chunk instance with formatted content Chunk instance with speaker directly from message.role
""" """
if metadata is None: return cls(
metadata = {} content=f"{message.role}: {message.msg}",
# Generate content from messages speaker=message.role,
content = "\n".join([f"{msg.role}: {msg.msg}" for msg in messages]) metadata=metadata or {}
return cls(text=messages, content=content, metadata=metadata) )
class DialogData(BaseModel): class DialogData(BaseModel):
"""Represents the complete data structure for a dialog record. """Represents the complete data structure for a dialog record.

View File

@@ -550,7 +550,7 @@ class ExtractionOrchestrator:
self, dialog_data_list: List[DialogData] self, dialog_data_list: List[DialogData]
) -> List[Dict[str, Any]]: ) -> List[Dict[str, Any]]:
""" """
从对话中提取情绪信息(优化版:全局陈述句级并行) 从对话中提取情绪信息(仅针对用户消息,全局陈述句级并行)
Args: Args:
dialog_data_list: 对话数据列表 dialog_data_list: 对话数据列表
@@ -558,7 +558,7 @@ class ExtractionOrchestrator:
Returns: Returns:
情绪信息映射列表,每个对话对应一个字典 情绪信息映射列表,每个对话对应一个字典
""" """
logger.info("开始情绪信息提取(全局陈述句级并行") logger.info("开始情绪信息提取(仅处理用户消息")
# 收集所有陈述句及其配置 # 收集所有陈述句及其配置
all_statements = [] all_statements = []
@@ -597,15 +597,22 @@ class ExtractionOrchestrator:
if not data_config or not data_config.emotion_enabled: if not data_config or not data_config.emotion_enabled:
logger.info("情绪提取未启用,跳过") logger.info("情绪提取未启用,跳过")
return [{} for _ in dialog_data_list] return [{} for _ in dialog_data_list]
# 收集所有陈述句(只收集 speaker 为 "user" 的)
total_statements = 0
filtered_statements = 0
# 收集所有陈述句
for d_idx, dialog in enumerate(dialog_data_list): for d_idx, dialog in enumerate(dialog_data_list):
for chunk in dialog.chunks: for chunk in dialog.chunks:
for statement in chunk.statements: for statement in chunk.statements:
all_statements.append((statement, data_config)) total_statements += 1
statement_metadata.append((d_idx, statement.id)) # 只处理用户的陈述句 (role 为 "user")
if hasattr(statement, 'speaker') and statement.speaker == "user":
all_statements.append((statement, data_config))
statement_metadata.append((d_idx, statement.id))
filtered_statements += 1
logger.info(f"收集到 {len(all_statements)} 个陈述句,开始全局并行提取情绪") logger.info(f"总陈述句: {total_statements}, 用户陈述句: {filtered_statements}, 开始全局并行提取情绪")
# 初始化情绪提取服务 # 初始化情绪提取服务
from app.services.emotion_extraction_service import EmotionExtractionService from app.services.emotion_extraction_service import EmotionExtractionService
@@ -1033,6 +1040,7 @@ class ExtractionOrchestrator:
apply_id=dialog_data.apply_id, apply_id=dialog_data.apply_id,
run_id=dialog_data.run_id, # 使用 dialog_data 的 run_id run_id=dialog_data.run_id, # 使用 dialog_data 的 run_id
statement=statement.statement, statement=statement.statement,
speaker=getattr(statement, 'speaker', None), # 添加 speaker 字段
statement_embedding=statement.statement_embedding, statement_embedding=statement.statement_embedding,
valid_at=statement.temporal_validity.valid_at if hasattr(statement, 'temporal_validity') and statement.temporal_validity else None, valid_at=statement.temporal_validity.valid_at if hasattr(statement, 'temporal_validity') and statement.temporal_validity else None,
invalid_at=statement.temporal_validity.invalid_at if hasattr(statement, 'temporal_validity') and statement.temporal_validity else None, invalid_at=statement.temporal_validity.invalid_at if hasattr(statement, 'temporal_validity') and statement.temporal_validity else None,

View File

@@ -22,12 +22,12 @@ class DialogueChunker:
Args: Args:
chunker_strategy: The chunking strategy to use (default: RecursiveChunker) chunker_strategy: The chunking strategy to use (default: RecursiveChunker)
Options include: SemanticChunker, RecursiveChunker, LateChunker, NeuralChunker Options: SemanticChunker, RecursiveChunker, LateChunker, NeuralChunker
""" """
self.chunker_strategy = chunker_strategy self.chunker_strategy = chunker_strategy
chunker_config_dict = get_chunker_config(chunker_strategy) chunker_config_dict = get_chunker_config(chunker_strategy)
self.chunker_config = ChunkerConfig.model_validate(chunker_config_dict) self.chunker_config = ChunkerConfig.model_validate(chunker_config_dict)
# 对于 LLMChunker需要传入 llm_client
if self.chunker_config.chunker_strategy == "LLMChunker": if self.chunker_config.chunker_strategy == "LLMChunker":
self.chunker_client = ChunkerClient(self.chunker_config, llm_client) self.chunker_client = ChunkerClient(self.chunker_config, llm_client)
else: else:
@@ -41,29 +41,19 @@ class DialogueChunker:
Returns: Returns:
A list of Chunk objects A list of Chunk objects
Raises:
ValueError: If chunking fails or returns empty chunks
""" """
result_dialogue = await self.chunker_client.generate_chunks(dialogue) result_dialogue = await self.chunker_client.generate_chunks(dialogue)
# Defensive fallback: ensure at least one chunk is returned for non-empty content chunks = result_dialogue.chunks
try:
chunks = result_dialogue.chunks
except Exception:
chunks = []
if not chunks or len(chunks) == 0: if not chunks or len(chunks) == 0:
# If the dialogue has content, return a single fallback chunk built from messages raise ValueError(
content_str = getattr(result_dialogue, "content", "") or getattr(dialogue, "content", "") f"Chunking failed: No chunks generated for dialogue {dialogue.ref_id}. "
if content_str and len(content_str.strip()) > 0: f"Messages: {len(dialogue.context.msgs) if dialogue.context else 0}, "
fallback_chunk = Chunk.from_messages( f"Strategy: {self.chunker_config.chunker_strategy}"
dialogue.context.msgs, )
metadata={
"fallback": "single_chunk",
"chunker_strategy": self.chunker_config.chunker_strategy,
"source": "DialogueChunkerFallback",
},
)
return [fallback_chunk]
# No content: return empty list
return []
return chunks return chunks
@@ -72,22 +62,25 @@ class DialogueChunker:
Args: Args:
dialogue: The processed DialogData object with chunks dialogue: The processed DialogData object with chunks
output_path: Optional path to save the output (default: chunker_output_{strategy}.txt) output_path: Optional path to save the output
Returns: Returns:
The path where the output was saved The path where the output was saved
""" """
if not output_path: if not output_path:
output_path = os.path.join(os.path.dirname(__file__), "..", "..", output_path = os.path.join(
f"chunker_output_{self.chunker_strategy.lower()}.txt") os.path.dirname(__file__), "..", "..",
f"chunker_output_{self.chunker_strategy.lower()}.txt"
)
output_lines = [] output_lines = [
output_lines.append(f"=== Chunking Results ({self.chunker_strategy}) ===") f"=== Chunking Results ({self.chunker_strategy}) ===",
output_lines.append(f"Dialogue ID: {dialogue.ref_id}") f"Dialogue ID: {dialogue.ref_id}",
output_lines.append(f"Original conversation has {len(dialogue.context.msgs)} messages") f"Original conversation has {len(dialogue.context.msgs)} messages",
output_lines.append(f"Total characters: {len(dialogue.content)}") f"Total characters: {len(dialogue.content)}",
f"Generated {len(dialogue.chunks)} chunks:"
output_lines.append(f"Generated {len(dialogue.chunks)} chunks:") ]
for i, chunk in enumerate(dialogue.chunks): for i, chunk in enumerate(dialogue.chunks):
output_lines.append(f" Chunk {i+1}: {len(chunk.content)} characters") output_lines.append(f" Chunk {i+1}: {len(chunk.content)} characters")
output_lines.append(f" Content preview: {chunk.content}...") output_lines.append(f" Content preview: {chunk.content}...")

View File

@@ -5,8 +5,6 @@ from datetime import datetime
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
from app.core.memory.models.message_models import DialogData, Statement from app.core.memory.models.message_models import DialogData, Statement
#避免在测试收集阶段因为 OpenAIClient 间接引入 langfuse 导致 ModuleNotFoundError 。这只是类型注解与导入时机的调整,不改变实现。
from app.core.memory.models.variate_config import StatementExtractionConfig from app.core.memory.models.variate_config import StatementExtractionConfig
from app.core.memory.utils.data.ontology import ( from app.core.memory.utils.data.ontology import (
LABEL_DEFINITIONS, LABEL_DEFINITIONS,
@@ -22,11 +20,10 @@ logger = logging.getLogger(__name__)
class ExtractedStatement(BaseModel): class ExtractedStatement(BaseModel):
"""Schema for extracted statement from LLM""" """Schema for extracted statement from LLM"""
statement: str = Field(..., description="The extracted statement text") statement: str = Field(..., description="The extracted statement text")
statement_type: str = Field(..., description="FACT, OPINION,SUGGESTION or PREDICTION") statement_type: str = Field(..., description="FACT, OPINION, SUGGESTION or PREDICTION")
temporal_type: str = Field(..., description="STATIC, DYNAMIC, ATEMPORAL") temporal_type: str = Field(..., description="STATIC, DYNAMIC, ATEMPORAL")
relevence: str = Field(..., description="RELEVANT or IRRELEVANT") relevence: str = Field(..., description="RELEVANT or IRRELEVANT")
# 统一使用 StatementExtractionResponse 作为 LLM 的结构化返回(仅语句)
class StatementExtractionResponse(BaseModel): class StatementExtractionResponse(BaseModel):
statements: List[ExtractedStatement] = Field(default_factory=list, description="List of extracted statements") statements: List[ExtractedStatement] = Field(default_factory=list, description="List of extracted statements")
@@ -58,10 +55,9 @@ class StatementExtractionResponse(BaseModel):
return v return v
class StatementExtractor: class StatementExtractor:
"""Class for extracting statements from dialog chunks using LLM (relations separated)""" """Class for extracting statements from dialog chunks using LLM"""
def __init__(self, llm_client: Any, config: StatementExtractionConfig = None): def __init__(self, llm_client: Any, config: StatementExtractionConfig = None):
# 避免在测试收集阶段因为 OpenAIClient 间接引入 langfuse 导致 ModuleNotFoundError 。这只是类型注解与导入时机的调整,不改变实现。
"""Initialize the StatementExtractor with an LLM client and configuration """Initialize the StatementExtractor with an LLM client and configuration
Args: Args:
@@ -71,6 +67,21 @@ class StatementExtractor:
self.llm_client = llm_client self.llm_client = llm_client
self.config = config or StatementExtractionConfig() self.config = config or StatementExtractionConfig()
def _get_speaker_from_chunk(self, chunk) -> Optional[str]:
"""Get speaker directly from Chunk
Args:
chunk: Chunk object containing speaker field
Returns:
Speaker role ("user"/"assistant") or None if cannot be determined
"""
if hasattr(chunk, 'speaker') and chunk.speaker:
return chunk.speaker
logger.warning(f"Chunk {getattr(chunk, 'id', 'unknown')} has no speaker field or is empty")
return None
async def _extract_statements(self, chunk, group_id: Optional[str] = None, dialogue_content: str = None) -> List[Statement]: async def _extract_statements(self, chunk, group_id: Optional[str] = None, dialogue_content: str = None) -> List[Statement]:
"""Process a single chunk and return extracted statements """Process a single chunk and return extracted statements
@@ -82,10 +93,12 @@ class StatementExtractor:
Returns: Returns:
List of ExtractedStatement objects extracted from the chunk List of ExtractedStatement objects extracted from the chunk
""" """
# Prepare the chunk content for processing
chunk_content = chunk.content chunk_content = chunk.content
if not chunk_content or len(chunk_content.strip()) < 5:
logger.warning(f"Chunk {chunk.id} content too short or empty, skipping")
return []
# Render the prompt using helper function
prompt_content = await render_statement_extraction_prompt( prompt_content = await render_statement_extraction_prompt(
chunk_content=chunk_content, chunk_content=chunk_content,
definitions=LABEL_DEFINITIONS, definitions=LABEL_DEFINITIONS,
@@ -136,7 +149,9 @@ class StatementExtractor:
relevence_info = RelevenceInfo[relevence_str] if relevence_str in RelevenceInfo.__members__ else RelevenceInfo.RELEVANT relevence_info = RelevenceInfo[relevence_str] if relevence_str in RelevenceInfo.__members__ else RelevenceInfo.RELEVANT
except (KeyError, ValueError): except (KeyError, ValueError):
relevence_info = RelevenceInfo.RELEVANT relevence_info = RelevenceInfo.RELEVANT
chunk_speaker = self._get_speaker_from_chunk(chunk)
chunk_statement = Statement( chunk_statement = Statement(
statement=extracted_stmt.statement, statement=extracted_stmt.statement,
stmt_type=stmt_type, stmt_type=stmt_type,
@@ -144,7 +159,9 @@ class StatementExtractor:
relevence_info=relevence_info, relevence_info=relevence_info,
chunk_id=chunk.id, chunk_id=chunk.id,
group_id=group_id, group_id=group_id,
speaker=chunk_speaker,
) )
chunk_statements.append(chunk_statement) chunk_statements.append(chunk_statement)
# 分离强弱关系分类:不在句子提取阶段进行,也不写入 chunk.metadata # 分离强弱关系分类:不在句子提取阶段进行,也不写入 chunk.metadata
@@ -226,12 +243,7 @@ class StatementExtractor:
return output_path return output_path
def save_relations(self, dialogs: List[DialogData], output_path: str = None) -> str: def save_relations(self, dialogs: List[DialogData], output_path: str = None) -> str:
"""按对话分组聚合强/弱关系并写入 TXT 文件。 """Group and aggregate strong/weak relations by dialogue and write to TXT file."""
- 每个对话单独成段:输出该对话的 `Dialog ID`、`Group ID`、`Content`
- 在该对话段内再分为 Strong Relations / Weak Relations 两部分
- Strong: 逐条输出 `Chunk ID` 与 `Triple`
- Weak: 逐条输出 `Chunk ID` 与 `Entity`
"""
print("\n=== Relations Classify ===") print("\n=== Relations Classify ===")
# 使用全局配置的输出路径 # 使用全局配置的输出路径

View File

@@ -89,14 +89,15 @@ def validate_model_exists_and_active(
start_time = time.time() start_time = time.time()
try: try:
# First check if model exists at all (without tenant filtering) # OPTIMIZED: Single query with tenant filter
model_without_tenant = ModelConfigRepository.get_by_id(db, model_id, tenant_id=None) # We'll check tenant mismatch in the error handling
# Then check with tenant filtering
model = ModelConfigRepository.get_by_id(db, model_id, tenant_id) model = ModelConfigRepository.get_by_id(db, model_id, tenant_id)
elapsed_ms = (time.time() - start_time) * 1000 elapsed_ms = (time.time() - start_time) * 1000
if not model: if not model:
# Model not found with tenant filter - check if it exists without filter
model_without_tenant = ModelConfigRepository.get_by_id(db, model_id, tenant_id=None)
if model_without_tenant: if model_without_tenant:
# Model exists but belongs to different tenant # Model exists but belongs to different tenant
logger.warning( logger.warning(
@@ -208,8 +209,11 @@ def validate_embedding_model(
db: Session, db: Session,
tenant_id: Optional[UUID] = None, tenant_id: Optional[UUID] = None,
workspace_id: Optional[UUID] = None workspace_id: Optional[UUID] = None
) -> UUID: ) -> tuple[UUID, str]:
"""Validate that embedding model is available and return its UUID. """Validate that embedding model is available and return its UUID and name.
Returns:
Tuple of (embedding_uuid, embedding_name)
Raises: Raises:
InvalidConfigError: If embedding_id is not provided or invalid InvalidConfigError: If embedding_id is not provided or invalid
@@ -225,14 +229,19 @@ def validate_embedding_model(
workspace_id=workspace_id workspace_id=workspace_id
) )
embedding_uuid, _ = validate_and_resolve_model_id( embedding_uuid, embedding_name = validate_and_resolve_model_id(
embedding_id, "embedding", db, tenant_id, required=True, embedding_id, "embedding", db, tenant_id, required=True,
config_id=config_id, workspace_id=workspace_id config_id=config_id, workspace_id=workspace_id
) )
print(100*'-')
print(embedding_uuid) logger.debug(
print(_) "Embedding model validated",
print(100*'-') extra={
"embedding_uuid": str(embedding_uuid),
"embedding_name": embedding_name,
"config_id": config_id
}
)
if embedding_uuid is None: if embedding_uuid is None:
raise InvalidConfigError( raise InvalidConfigError(
@@ -243,7 +252,7 @@ def validate_embedding_model(
workspace_id=workspace_id workspace_id=workspace_id
) )
return embedding_uuid return embedding_uuid, embedding_name
def validate_llm_model( def validate_llm_model(

View File

@@ -104,38 +104,6 @@ class DataConfigRepository:
r.statement AS statement r.statement AS statement
""" """
# Entity graph within group (source node, edge, target node)
SEARCH_FOR_ENTITY_GRAPH = """
MATCH (n:ExtractedEntity)-[r]->(m:ExtractedEntity)
WHERE n.group_id = $group_id
RETURN
{
entity_idx: n.entity_idx,
connect_strength: n.connect_strength,
description: n.description,
entity_type: n.entity_type,
name: n.name,
fact_summary: COALESCE(n.fact_summary, ''),
id: n.id
} AS sourceNode,
{
rel_id: elementId(r),
source_id: startNode(r).id,
target_id: endNode(r).id,
predicate: r.predicate,
statement_id: r.statement_id,
statement: r.statement
} AS edge,
{
entity_idx: m.entity_idx,
connect_strength: m.connect_strength,
description: m.description,
entity_type: m.entity_type,
name: m.name,
fact_summary: COALESCE(m.fact_summary, ''),
id: m.id
} AS targetNode
"""
@staticmethod @staticmethod
def update_reflection_config( def update_reflection_config(
db: Session, db: Session,

View File

@@ -276,42 +276,6 @@ def get_end_user_by_id(db: Session, end_user_id: uuid.UUID) -> Optional[EndUser]
end_user = repo.get_end_user_by_id(end_user_id) end_user = repo.get_end_user_by_id(end_user_id)
return end_user return end_user
def update_end_user_other_name(
db: Session,
end_user_id: uuid.UUID,
other_name: str
) -> int:
"""
通过 end_user_id 更新 end_user 表中的 other_name 字段
Args:
db: 数据库会话
end_user_id: 宿主ID
other_name: 要更新的用户名
Returns:
int: 更新的记录数
"""
try:
# 执行更新
updated_count = (
db.query(EndUser)
.filter(EndUser.id == end_user_id)
.update(
{EndUser.other_name: other_name},
synchronize_session=False
)
)
db.commit()
db_logger.info(f"成功更新宿主 {end_user_id} 的 other_name 为: {other_name}")
return updated_count
except Exception as e:
db.rollback()
db_logger.error(f"更新宿主 {end_user_id} 的 other_name 时出错: {str(e)}")
raise
# 新增的缓存操作函数(保持与类方法一致的接口) # 新增的缓存操作函数(保持与类方法一致的接口)
def get_by_id(db: Session, end_user_id: uuid.UUID) -> Optional[EndUser]: def get_by_id(db: Session, end_user_id: uuid.UUID) -> Optional[EndUser]:
"""根据ID获取终端用户用于缓存操作""" """根据ID获取终端用户用于缓存操作"""

View File

@@ -101,6 +101,8 @@ async def add_statement_nodes(statements: List[StatementNode], connector: Neo4jC
# "entities": [entity.model_dump() for entity in statement.triplet_extraction_info.entities] if statement.triplet_extraction_info else [] # "entities": [entity.model_dump() for entity in statement.triplet_extraction_info.entities] if statement.triplet_extraction_info else []
# }) if statement.triplet_extraction_info else json.dumps({"triplets": [], "entities": []}), # }) if statement.triplet_extraction_info else json.dumps({"triplets": [], "entities": []}),
"statement_embedding": statement.statement_embedding if statement.statement_embedding else None, "statement_embedding": statement.statement_embedding if statement.statement_embedding else None,
# 添加 speaker 字段(用于基于角色的情绪提取)
"speaker": statement.speaker if hasattr(statement, 'speaker') else None,
# 添加情绪字段处理 # 添加情绪字段处理
"emotion_type": statement.emotion_type, "emotion_type": statement.emotion_type,
"emotion_intensity": statement.emotion_intensity, "emotion_intensity": statement.emotion_intensity,
@@ -163,7 +165,9 @@ async def add_chunk_nodes(chunks: List[ChunkNode], connector: Neo4jConnector) ->
"chunk_embedding": chunk.chunk_embedding if chunk.chunk_embedding else None, "chunk_embedding": chunk.chunk_embedding if chunk.chunk_embedding else None,
"sequence_number": chunk.sequence_number, "sequence_number": chunk.sequence_number,
"start_index": metadata.get("start_index"), "start_index": metadata.get("start_index"),
"end_index": metadata.get("end_index") "end_index": metadata.get("end_index"),
# 添加 speaker 字段(用于基于角色的情绪提取)
"speaker": chunk.speaker if hasattr(chunk, 'speaker') else None
} }
flattened_chunks.append(flattened_chunk) flattened_chunks.append(flattened_chunk)

View File

@@ -305,12 +305,19 @@ async def search_graph(
results[key] = _deduplicate_results(results[key]) results[key] = _deduplicate_results(results[key])
# 更新知识节点的激活值Statement, ExtractedEntity, MemorySummary # 更新知识节点的激活值Statement, ExtractedEntity, MemorySummary
results = await _update_search_results_activation( # Skip activation updates if only searching summaries (optimization)
connector=connector, needs_activation_update = any(
results=results, key in include and key in results and results[key]
group_id=group_id for key in ['statements', 'entities', 'chunks']
) )
if needs_activation_update:
results = await _update_search_results_activation(
connector=connector,
results=results,
group_id=group_id
)
return results return results
@@ -339,7 +346,7 @@ async def search_graph_by_embedding(
embed_start = time.time() embed_start = time.time()
embeddings = await embedder_client.response([query_text]) embeddings = await embedder_client.response([query_text])
embed_time = time.time() - embed_start embed_time = time.time() - embed_start
print(f"[PERF] Embedding generation took: {embed_time:.4f}s") logger.info(f"[PERF] Embedding generation took: {embed_time:.4f}s")
if not embeddings or not embeddings[0]: if not embeddings or not embeddings[0]:
return {"statements": [], "chunks": [], "entities": [], "summaries": []} return {"statements": [], "chunks": [], "entities": [], "summaries": []}
@@ -393,7 +400,7 @@ async def search_graph_by_embedding(
query_start = time.time() query_start = time.time()
task_results = await asyncio.gather(*tasks, return_exceptions=True) task_results = await asyncio.gather(*tasks, return_exceptions=True)
query_time = time.time() - query_start query_time = time.time() - query_start
print(f"[PERF] Neo4j queries (parallel) took: {query_time:.4f}s") logger.info(f"[PERF] Neo4j queries (parallel) took: {query_time:.4f}s")
# Build results dictionary # Build results dictionary
results: Dict[str, List[Dict[str, Any]]] = { results: Dict[str, List[Dict[str, Any]]] = {
@@ -417,14 +424,23 @@ async def search_graph_by_embedding(
results[key] = _deduplicate_results(results[key]) results[key] = _deduplicate_results(results[key])
# 更新知识节点的激活值Statement, ExtractedEntity, MemorySummary # 更新知识节点的激活值Statement, ExtractedEntity, MemorySummary
update_start = time.time() # Skip activation updates if only searching summaries (optimization)
results = await _update_search_results_activation( needs_activation_update = any(
connector=connector, key in include and key in results and results[key]
results=results, for key in ['statements', 'entities', 'chunks']
group_id=group_id
) )
update_time = time.time() - update_start
print(f"[PERF] Activation value updates took: {update_time:.4f}s") if needs_activation_update:
update_start = time.time()
results = await _update_search_results_activation(
connector=connector,
results=results,
group_id=group_id
)
update_time = time.time() - update_start
logger.info(f"[PERF] Activation value updates took: {update_time:.4f}s")
else:
logger.info(f"[PERF] Skipping activation updates (only summaries)")
return results return results
async def get_dedup_candidates_for_entities( # 适配新版查询:使用全文索引按名称检索候选实体 async def get_dedup_candidates_for_entities( # 适配新版查询:使用全文索引按名称检索候选实体
@@ -535,7 +551,7 @@ async def search_graph_by_keyword_temporal(
- Returns up to 'limit' statements - Returns up to 'limit' statements
""" """
if not query_text: if not query_text:
print(f"query_text不能为空") logger.warning(f"query_text cannot be empty")
return {"statements": []} return {"statements": []}
statements = await connector.execute_query( statements = await connector.execute_query(
SEARCH_STATEMENTS_BY_KEYWORD_TEMPORAL, SEARCH_STATEMENTS_BY_KEYWORD_TEMPORAL,
@@ -549,7 +565,7 @@ async def search_graph_by_keyword_temporal(
invalid_date=invalid_date, invalid_date=invalid_date,
limit=limit, limit=limit,
) )
print(f"查询结果为:\n{statements}") logger.debug(f"Temporal keyword search results: {len(statements)} statements found")
# 更新 Statement 节点的激活值 # 更新 Statement 节点的激活值
results = {"statements": statements} results = {"statements": statements}
@@ -594,9 +610,9 @@ async def search_graph_by_temporal(
limit=limit, limit=limit,
) )
print(f"查询语句为:\n{SEARCH_STATEMENTS_BY_TEMPORAL}") logger.debug(f"Temporal search query: {SEARCH_STATEMENTS_BY_TEMPORAL}")
print(f"查询参数为:\n{{group_id: {group_id}, apply_id: {apply_id}, user_id: {user_id}, start_date: {start_date}, end_date: {end_date}, valid_date: {valid_date}, invalid_date: {invalid_date}, limit: {limit}}}") logger.debug(f"Query params: group_id={group_id}, apply_id={apply_id}, user_id={user_id}, start_date={start_date}, end_date={end_date}, valid_date={valid_date}, invalid_date={invalid_date}, limit={limit}")
print(f"查询结果为:\n{statements}") logger.debug(f"Temporal search results: {len(statements)} statements found")
# 更新 Statement 节点的激活值 # 更新 Statement 节点的激活值
results = {"statements": statements} results = {"statements": statements}
@@ -623,7 +639,7 @@ async def search_graph_by_dialog_id(
- Returns up to 'limit' dialogues - Returns up to 'limit' dialogues
""" """
if not dialog_id: if not dialog_id:
print(f"dialog_id不能为空") logger.warning(f"dialog_id cannot be empty")
return {"dialogues": []} return {"dialogues": []}
dialogues = await connector.execute_query( dialogues = await connector.execute_query(
@@ -642,7 +658,7 @@ async def search_graph_by_chunk_id(
limit: int = 1, limit: int = 1,
) -> Dict[str, List[Dict[str, Any]]]: ) -> Dict[str, List[Dict[str, Any]]]:
if not chunk_id: if not chunk_id:
print(f"chunk_id不能为空") logger.warning(f"chunk_id cannot be empty")
return {"chunks": []} return {"chunks": []}
chunks = await connector.execute_query( chunks = await connector.execute_query(
SEARCH_CHUNK_BY_CHUNK_ID, SEARCH_CHUNK_BY_CHUNK_ID,
@@ -679,9 +695,9 @@ async def search_graph_by_created_at(
limit=limit, limit=limit,
) )
print(f"查询语句为:\n{SEARCH_STATEMENTS_BY_CREATED_AT}") logger.debug(f"Search by created_at query: {SEARCH_STATEMENTS_BY_CREATED_AT}")
print(f"查询参数为:\n{{group_id: {group_id}, apply_id: {apply_id}, user_id: {user_id}, created_at: {created_at}, limit: {limit}}}") logger.debug(f"Query params: group_id={group_id}, apply_id={apply_id}, user_id={user_id}, created_at={created_at}, limit={limit}")
print(f"查询结果为:\n{statements}") logger.debug(f"Search results: {len(statements)} statements found")
# 更新 Statement 节点的激活值 # 更新 Statement 节点的激活值
results = {"statements": statements} results = {"statements": statements}
@@ -719,9 +735,9 @@ async def search_graph_by_valid_at(
limit=limit, limit=limit,
) )
print(f"查询语句为:\n{SEARCH_STATEMENTS_BY_VALID_AT}") logger.debug(f"Search by valid_at query: {SEARCH_STATEMENTS_BY_VALID_AT}")
print(f"查询参数为:\n{{group_id: {group_id}, apply_id: {apply_id}, user_id: {user_id}, valid_at: {valid_at}, limit: {limit}}}") logger.debug(f"Query params: group_id={group_id}, apply_id={apply_id}, user_id={user_id}, valid_at={valid_at}, limit={limit}")
print(f"查询结果为:\n{statements}") logger.debug(f"Search results: {len(statements)} statements found")
# 更新 Statement 节点的激活值 # 更新 Statement 节点的激活值
results = {"statements": statements} results = {"statements": statements}
@@ -759,9 +775,9 @@ async def search_graph_g_created_at(
limit=limit, limit=limit,
) )
print(f"查询语句为:\n{SEARCH_STATEMENTS_G_CREATED_AT}") logger.debug(f"Search greater than created_at query: {SEARCH_STATEMENTS_G_CREATED_AT}")
print(f"查询参数为:\n{{group_id: {group_id}, apply_id: {apply_id}, user_id: {user_id}, created_at: {created_at}, limit: {limit}}}") logger.debug(f"Query params: group_id={group_id}, apply_id={apply_id}, user_id={user_id}, created_at={created_at}, limit={limit}")
print(f"查询结果为:\n{statements}") logger.debug(f"Search results: {len(statements)} statements found")
# 更新 Statement 节点的激活值 # 更新 Statement 节点的激活值
results = {"statements": statements} results = {"statements": statements}
@@ -799,9 +815,9 @@ async def search_graph_g_valid_at(
limit=limit, limit=limit,
) )
print(f"查询语句为:\n{SEARCH_STATEMENTS_G_VALID_AT}") logger.debug(f"Search greater than valid_at query: {SEARCH_STATEMENTS_G_VALID_AT}")
print(f"查询参数为:\n{{group_id: {group_id}, apply_id: {apply_id}, user_id: {user_id}, valid_at: {valid_at}, limit: {limit}}}") logger.debug(f"Query params: group_id={group_id}, apply_id={apply_id}, user_id={user_id}, valid_at={valid_at}, limit={limit}")
print(f"查询结果为:\n{statements}") logger.debug(f"Search results: {len(statements)} statements found")
# 更新 Statement 节点的激活值 # 更新 Statement 节点的激活值
results = {"statements": statements} results = {"statements": statements}
@@ -839,9 +855,9 @@ async def search_graph_l_created_at(
limit=limit, limit=limit,
) )
print(f"查询语句为:\n{SEARCH_STATEMENTS_L_CREATED_AT}") logger.debug(f"Search less than created_at query: {SEARCH_STATEMENTS_L_CREATED_AT}")
print(f"查询参数为:\n{{group_id: {group_id}, apply_id: {apply_id}, user_id: {user_id}, created_at: {created_at}, limit: {limit}}}") logger.debug(f"Query params: group_id={group_id}, apply_id={apply_id}, user_id={user_id}, created_at={created_at}, limit={limit}")
print(f"查询结果为:\n{statements}") logger.debug(f"Search results: {len(statements)} statements found")
# 更新 Statement 节点的激活值 # 更新 Statement 节点的激活值
results = {"statements": statements} results = {"statements": statements}
@@ -879,9 +895,9 @@ async def search_graph_l_valid_at(
limit=limit, limit=limit,
) )
print(f"查询语句为:\n{SEARCH_STATEMENTS_L_VALID_AT}") logger.debug(f"Search less than valid_at query: {SEARCH_STATEMENTS_L_VALID_AT}")
print(f"查询参数为:\n{{group_id: {group_id}, apply_id: {apply_id}, user_id: {user_id}, valid_at: {valid_at}, limit: {limit}}}") logger.debug(f"Query params: group_id={group_id}, apply_id={apply_id}, user_id={user_id}, valid_at={valid_at}, limit={limit}")
print(f"查询结果为:\n{statements}") logger.debug(f"Search results: {len(statements)} statements found")
# 更新 Statement 节点的激活值 # 更新 Statement 节点的激活值
results = {"statements": statements} results = {"statements": statements}

View File

@@ -12,10 +12,6 @@ class UserInput(BaseModel):
class Write_UserInput(BaseModel): class Write_UserInput(BaseModel):
message: str messages: list[dict]
group_id: str group_id: str
config_id: Optional[str] = None config_id: Optional[str] = None
class End_User_Information(BaseModel):
end_user_name: str # 这是要更新的用户名
id: str # 宿主ID用于匹配条件

View File

@@ -10,11 +10,6 @@ import time
import uuid import uuid
from typing import Any, AsyncGenerator, Dict, List, Optional from typing import Any, AsyncGenerator, Dict, List, Optional
from langchain.tools import tool
from pydantic import BaseModel, Field
from sqlalchemy import select
from sqlalchemy.orm import Session
from app.celery_app import celery_app from app.celery_app import celery_app
from app.core.error_codes import BizCode from app.core.error_codes import BizCode
from app.core.exceptions import BusinessException from app.core.exceptions import BusinessException
@@ -28,6 +23,10 @@ from app.services.langchain_tool_server import Search
from app.services.memory_agent_service import MemoryAgentService from app.services.memory_agent_service import MemoryAgentService
from app.services.model_parameter_merger import ModelParameterMerger from app.services.model_parameter_merger import ModelParameterMerger
from app.services.tool_service import ToolService from app.services.tool_service import ToolService
from langchain.tools import tool
from pydantic import BaseModel, Field
from sqlalchemy import select
from sqlalchemy.orm import Session
logger = get_business_logger() logger = get_business_logger()
class KnowledgeRetrievalInput(BaseModel): class KnowledgeRetrievalInput(BaseModel):
@@ -107,9 +106,9 @@ def create_long_term_memory_tool(memory_config: Dict[str, Any], end_user_id: str
"app.core.memory.agent.read_message", "app.core.memory.agent.read_message",
args=[end_user_id, question, [], "1", config_id, storage_type, user_rag_memory_id] args=[end_user_id, question, [], "1", config_id, storage_type, user_rag_memory_id]
) )
result = task_service.get_task_memory_read_result(task.id) # result = task_service.get_task_memory_read_result(task.id)
status = result.get("status") # status = result.get("status")
logger.info(f"读取任务状态:{status}") # logger.info(f"读取任务状态:{status}")
finally: finally:
db.close() db.close()

View File

@@ -10,27 +10,32 @@ import re
import time import time
import uuid import uuid
from typing import Any, AsyncGenerator, Dict, List, Optional from typing import Any, AsyncGenerator, Dict, List, Optional
import redis
from langchain_core.messages import HumanMessage
import redis
from app.core.config import settings from app.core.config import settings
from app.core.logging_config import get_config_logger, get_logger from app.core.logging_config import get_config_logger, get_logger
from app.core.memory.agent.langgraph_graph.read_graph import make_read_graph from app.core.memory.agent.langgraph_graph.read_graph import make_read_graph
from app.core.memory.agent.langgraph_graph.write_graph import make_write_graph from app.core.memory.agent.langgraph_graph.write_graph import make_write_graph
from app.core.memory.agent.logger_file.log_streamer import LogStreamer from app.core.memory.agent.logger_file.log_streamer import LogStreamer
from app.core.memory.agent.utils.messages_tools import merge_multiple_search_results, reorder_output_results from app.core.memory.agent.utils.messages_tools import (
merge_multiple_search_results,
reorder_output_results,
)
from app.core.memory.agent.utils.type_classifier import status_typle from app.core.memory.agent.utils.type_classifier import status_typle
from app.core.memory.agent.utils.write_tools import write # 新增:直接导入 write 函数
from app.core.memory.analytics.hot_memory_tags import get_hot_memory_tags from app.core.memory.analytics.hot_memory_tags import get_hot_memory_tags
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.db import get_db_context from app.db import get_db_context
from app.models.knowledge_model import Knowledge, KnowledgeType from app.models.knowledge_model import Knowledge, KnowledgeType
from app.repositories.neo4j.neo4j_connector import Neo4jConnector from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.schemas.memory_agent_schema import Write_UserInput
from app.schemas.memory_config_schema import ConfigurationError from app.schemas.memory_config_schema import ConfigurationError
from app.services.memory_base_service import Translation_English from app.services.memory_base_service import Translation_English
from app.services.memory_config_service import MemoryConfigService from app.services.memory_config_service import MemoryConfigService
from app.services.memory_konwledges_server import ( from app.services.memory_konwledges_server import (
write_rag, write_rag,
) )
from langchain_core.messages import HumanMessage
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from sqlalchemy import func from sqlalchemy import func
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
@@ -260,13 +265,13 @@ class MemoryAgentService:
logger.info("Log streaming completed, cleaning up resources") logger.info("Log streaming completed, cleaning up resources")
# LogStreamer uses context manager for file handling, so cleanup is automatic # LogStreamer uses context manager for file handling, so cleanup is automatic
async def write_memory(self, group_id: str, message: str, config_id: Optional[str], db: Session, storage_type: str, user_rag_memory_id: str) -> str: async def write_memory(self, group_id: str, messages: list[dict], config_id: Optional[str], db: Session, storage_type: str, user_rag_memory_id: str) -> str:
""" """
Process write operation with config_id Process write operation with config_id
Args: Args:
group_id: Group identifier (also used as end_user_id) group_id: Group identifier (also used as end_user_id)
message: Message to write messages: Structured message list [{"role": "user", "content": "..."}, ...]
config_id: Configuration ID from database config_id: Configuration ID from database
db: SQLAlchemy database session db: SQLAlchemy database session
storage_type: Storage type (neo4j or rag) storage_type: Storage type (neo4j or rag)
@@ -287,7 +292,7 @@ class MemoryAgentService:
raise ValueError(f"No memory configuration found for end_user {group_id}. Please ensure the user has a connected memory configuration.") raise ValueError(f"No memory configuration found for end_user {group_id}. Please ensure the user has a connected memory configuration.")
except Exception as e: except Exception as e:
if "No memory configuration found" in str(e): if "No memory configuration found" in str(e):
raise # Re-raise our specific error raise
logger.error(f"Failed to get connected config for end_user {group_id}: {e}") logger.error(f"Failed to get connected config for end_user {group_id}: {e}")
raise ValueError(f"Unable to determine memory configuration for end_user {group_id}: {e}") raise ValueError(f"Unable to determine memory configuration for end_user {group_id}: {e}")
@@ -315,14 +320,28 @@ class MemoryAgentService:
try: try:
if storage_type == "rag": if storage_type == "rag":
result = await write_rag(group_id, message, user_rag_memory_id) # For RAG storage, convert messages to single string
message_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages])
result = await write_rag(group_id, message_text, user_rag_memory_id)
return result return result
else: else:
async with make_write_graph() as graph: async with make_write_graph() as graph:
config = {"configurable": {"thread_id": group_id}} config = {"configurable": {"thread_id": group_id}}
# Convert structured messages to LangChain messages
langchain_messages = []
for msg in messages:
if msg['role'] == 'user':
langchain_messages.append(HumanMessage(content=msg['content']))
elif msg['role'] == 'assistant':
from langchain_core.messages import AIMessage
langchain_messages.append(AIMessage(content=msg['content']))
# 初始状态 - 包含所有必要字段 # 初始状态 - 包含所有必要字段
initial_state = {"messages": [HumanMessage(content=message)], "group_id": group_id, initial_state = {
"memory_config": memory_config} "messages": langchain_messages,
"group_id": group_id,
"memory_config": memory_config
}
# 获取节点更新信息 # 获取节点更新信息
async for update_event in graph.astream( async for update_event in graph.astream(
@@ -335,7 +354,9 @@ class MemoryAgentService:
massages = node_data massages = node_data
massagesstatus = massages.get('write_result')['status'] massagesstatus = massages.get('write_result')['status']
contents = massages.get('write_result') contents = massages.get('write_result')
return self.writer_messages_deal(massagesstatus, start_time, group_id, config_id, message, contents) # Convert messages back to string for logging
message_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages])
return self.writer_messages_deal(massagesstatus, start_time, group_id, config_id, message_text, contents)
except Exception as e: except Exception as e:
# Ensure proper error handling and logging # Ensure proper error handling and logging
error_msg = f"Write operation failed: {str(e)}" error_msg = f"Write operation failed: {str(e)}"
@@ -386,6 +407,7 @@ class MemoryAgentService:
import time import time
start_time = time.time() start_time = time.time()
logger.info(f"[PERF] read_memory started for group_id={group_id}, search_switch={search_switch}")
# Resolve config_id if None using end_user's connected config # Resolve config_id if None using end_user's connected config
if config_id is None: if config_id is None:
@@ -409,13 +431,15 @@ class MemoryAgentService:
audit_logger = None audit_logger = None
config_load_start = time.time()
try: try:
config_service = MemoryConfigService(db) config_service = MemoryConfigService(db)
memory_config = config_service.load_memory_config( memory_config = config_service.load_memory_config(
config_id=config_id, config_id=config_id,
service_name="MemoryAgentService" service_name="MemoryAgentService"
) )
logger.info(f"Configuration loaded successfully: {memory_config.config_name}") config_load_time = time.time() - config_load_start
logger.info(f"[PERF] Configuration loaded in {config_load_time:.4f}s: {memory_config.config_name}")
except ConfigurationError as e: except ConfigurationError as e:
error_msg = f"Failed to load configuration for config_id: {config_id}: {e}" error_msg = f"Failed to load configuration for config_id: {config_id}: {e}"
logger.error(error_msg) logger.error(error_msg)
@@ -439,6 +463,7 @@ class MemoryAgentService:
logger.debug(f"Group ID:{group_id}, Message:{message}, History:{history}, Config ID:{config_id}") logger.debug(f"Group ID:{group_id}, Message:{message}, History:{history}, Config ID:{config_id}")
# Step 3: Initialize MCP client and execute read workflow # Step 3: Initialize MCP client and execute read workflow
graph_exec_start = time.time()
try: try:
async with make_read_graph() as graph: async with make_read_graph() as graph:
config = {"configurable": {"thread_id": group_id}} config = {"configurable": {"thread_id": group_id}}
@@ -495,12 +520,68 @@ class MemoryAgentService:
if summary_n and summary_n != [] and summary_n != {}: if summary_n and summary_n != [] and summary_n != {}:
_intermediate_outputs.append(summary_n) _intermediate_outputs.append(summary_n)
graph_exec_time = time.time() - graph_exec_start
logger.info(f"[PERF] Graph execution completed in {graph_exec_time:.4f}s")
_intermediate_outputs = [item for item in _intermediate_outputs if item and item != [] and item != {}] _intermediate_outputs = [item for item in _intermediate_outputs if item and item != [] and item != {}]
optimized_outputs = merge_multiple_search_results(_intermediate_outputs) optimized_outputs = merge_multiple_search_results(_intermediate_outputs)
result = reorder_output_results(optimized_outputs) result = reorder_output_results(optimized_outputs)
# 保存短期记忆到数据库
# 只有 search_switch 不为 "2"(快速检索)时才保存
try:
from app.repositories.memory_short_repository import ShortTermMemoryRepository
retrieved_content = []
repo = ShortTermMemoryRepository(db)
if str(search_switch) != "2":
for intermediate in _intermediate_outputs:
logger.debug(f"处理中间结果: {intermediate}")
intermediate_type = intermediate.get('type', '')
if intermediate_type == "search_result":
query = intermediate.get('query', '')
raw_results = intermediate.get('raw_results', {})
reranked_results = raw_results.get('reranked_results', [])
try:
statements = [statement['statement'] for statement in reranked_results.get('statements', [])]
except Exception:
statements = []
# 去重
statements = list(set(statements))
if query and statements:
retrieved_content.append({query: statements})
# 如果 retrieved_content 为空,设置为空字符串
if retrieved_content == []:
retrieved_content = ''
# 只有当回答不是"信息不足"且不是快速检索时才保存
if '信息不足,无法回答。' != str(summary) and str(search_switch).strip() != "2":
# 使用 upsert 方法
repo.upsert(
end_user_id=group_id,
messages=message,
aimessages=summary,
retrieved_content=retrieved_content,
search_switch=str(search_switch)
)
logger.info(f"成功保存短期记忆: group_id={group_id}, search_switch={search_switch}")
else:
logger.debug(f"跳过保存短期记忆: summary={summary[:50] if summary else 'None'}, search_switch={search_switch}")
except Exception as save_error:
# 保存失败不应该影响主流程,只记录错误
logger.error(f"保存短期记忆失败: {str(save_error)}", exc_info=True)
# Log successful operation # Log successful operation
total_time = time.time() - start_time
logger.info(f"[PERF] read_memory completed successfully in {total_time:.4f}s (config: {config_load_time:.4f}s, graph: {graph_exec_time:.4f}s)")
if audit_logger: if audit_logger:
duration = time.time() - start_time duration = time.time() - start_time
audit_logger.log_operation( audit_logger.log_operation(
@@ -518,7 +599,8 @@ class MemoryAgentService:
except Exception as e: except Exception as e:
# Ensure proper error handling and logging # Ensure proper error handling and logging
error_msg = f"Read operation failed: {str(e)}" error_msg = f"Read operation failed: {str(e)}"
logger.error(error_msg) total_time = time.time() - start_time
logger.error(f"[PERF] read_memory failed after {total_time:.4f}s: {error_msg}")
if audit_logger: if audit_logger:
duration = time.time() - start_time duration = time.time() - start_time
audit_logger.log_operation( audit_logger.log_operation(
@@ -531,7 +613,49 @@ class MemoryAgentService:
) )
raise ValueError(error_msg) raise ValueError(error_msg)
def get_messages_list(self, user_input: Write_UserInput) -> list[dict]:
"""
Get standardized message list from user input.
Args:
user_input: Write_UserInput object
Returns:
list[dict]: Message list, each message contains role and content
Raises:
ValueError: If messages is empty or format is incorrect
"""
from app.core.logging_config import get_api_logger
logger = get_api_logger()
if len(user_input.messages) == 0:
logger.error("Validation failed: Message list cannot be empty")
raise ValueError("Message list cannot be empty")
for idx, msg in enumerate(user_input.messages):
if not isinstance(msg, dict):
logger.error(f"Validation failed: Message {idx} is not a dict: {type(msg)}")
raise ValueError(f"Message format error: Message must be a dictionary. Error message index: {idx}, type: {type(msg)}")
if 'role' not in msg:
logger.error(f"Validation failed: Message {idx} missing 'role' field: {msg}")
raise ValueError(f"Message format error: Message must contain 'role' field. Error message index: {idx}")
if 'content' not in msg:
logger.error(f"Validation failed: Message {idx} missing 'content' field: {msg}")
raise ValueError(f"Message format error: Message must contain 'content' field. Error message index: {idx}")
if msg['role'] not in ['user', 'assistant']:
logger.error(f"Validation failed: Message {idx} invalid role: {msg['role']}")
raise ValueError(f"Role must be 'user' or 'assistant', got: {msg['role']}. Message index: {idx}")
if not msg['content'] or not msg['content'].strip():
logger.error(f"Validation failed: Message {idx} content is empty")
raise ValueError(f"Message content cannot be empty. Message index: {idx}, role: {msg['role']}")
logger.info(f"Validation successful: Structured message list, count: {len(user_input.messages)}")
return user_input.messages
async def classify_message_type(self, message: str, config_id: int, db: Session) -> Dict: async def classify_message_type(self, message: str, config_id: int, db: Session) -> Dict:
""" """
@@ -559,7 +683,67 @@ class MemoryAgentService:
logger.debug(f"Message type: {status}") logger.debug(f"Message type: {status}")
return status return status
# ==================== 新增的三个接口方法 ==================== async def generate_summary_from_retrieve(
self,
retrieve_info: str,
history: List[Dict],
query: str,
config_id: str,
db: Session
) -> str:
"""
基于检索信息、历史对话和查询生成最终答案
使用 Retrieve_Summary_prompt.jinja2 模板调用大模型生成答案
Args:
retrieve_info: 检索到的信息
history: 历史对话记录
query: 用户查询
config_id: 配置ID
db: 数据库会话
Returns:
生成的答案文本
"""
logger.info(f"Generating summary from retrieve info for query: {query[:50]}...")
try:
# 加载配置
config_service = MemoryConfigService(db)
memory_config = config_service.load_memory_config(
config_id=config_id,
service_name="MemoryAgentService"
)
# 导入必要的模块
from app.core.memory.agent.langgraph_graph.nodes.summary_nodes import summary_llm
from app.core.memory.agent.models.summary_models import RetrieveSummaryResponse
# 构建状态对象
state = {
"data": query,
"memory_config": memory_config
}
# 直接调用 summary_llm 函数
answer = await summary_llm(
state=state,
history=history,
retrieve_info=retrieve_info,
template_name='Retrieve_Summary_prompt.jinja2',
operation_name='retrieve_summary',
response_model=RetrieveSummaryResponse,
search_mode="1"
)
logger.info(f"Successfully generated summary: {answer[:100] if answer else 'None'}...")
return answer if answer else "信息不足,无法回答。"
except Exception as e:
logger.error(f"生成摘要失败: {str(e)}", exc_info=True)
return "信息不足,无法回答。"
async def get_knowledge_type_stats( async def get_knowledge_type_stats(
self, self,
@@ -1033,7 +1217,7 @@ def get_end_users_connected_configs_batch(end_user_ids: List[str], db: Session)
""" """
from app.models.app_release_model import AppRelease from app.models.app_release_model import AppRelease
from app.models.end_user_model import EndUser from app.models.end_user_model import EndUser
from app.models.memory_config_model import MemoryConfig from app.models.data_config_model import DataConfig
from sqlalchemy import select from sqlalchemy import select
logger.info(f"Batch getting connected configs for {len(end_user_ids)} end_users") logger.info(f"Batch getting connected configs for {len(end_user_ids)} end_users")
@@ -1091,8 +1275,8 @@ def get_end_users_connected_configs_batch(end_user_ids: List[str], db: Session)
# 批量查询 memory_config_name # 批量查询 memory_config_name
config_id_to_name = {} config_id_to_name = {}
if memory_config_ids: if memory_config_ids:
memory_configs = db.query(MemoryConfig).filter(MemoryConfig.id.in_(memory_config_ids)).all() memory_configs = db.query(DataConfig).filter(DataConfig.config_id.in_(memory_config_ids)).all()
config_id_to_name = {str(mc.id): mc.config_name for mc in memory_configs} config_id_to_name = {str(mc.config_id): mc.config_name for mc in memory_configs}
# 4. 构建最终结果 # 4. 构建最终结果
for end_user_id, app_id in user_to_app.items(): for end_user_id, app_id in user_to_app.items():
@@ -1109,7 +1293,7 @@ def get_end_users_connected_configs_batch(end_user_ids: List[str], db: Session)
memory_config_id = memory_obj.get('memory_content') if isinstance(memory_obj, dict) else None memory_config_id = memory_obj.get('memory_content') if isinstance(memory_obj, dict) else None
# 获取配置名称 # 获取配置名称
memory_config_name = config_id_to_name.get(memory_config_id) if memory_config_id else None memory_config_name = config_id_to_name.get(str(memory_config_id)) if memory_config_id else None
result[end_user_id] = { result[end_user_id] = {
"memory_config_id": memory_config_id, "memory_config_id": memory_config_id,

View File

@@ -125,7 +125,11 @@ class MemoryConfigService:
try: try:
validated_config_id = _validate_config_id(config_id) validated_config_id = _validate_config_id(config_id)
# Step 1: Get config and workspace
db_query_start = time.time()
result = DataConfigRepository.get_config_with_workspace(self.db, validated_config_id) result = DataConfigRepository.get_config_with_workspace(self.db, validated_config_id)
db_query_time = time.time() - db_query_start
logger.info(f"[PERF] Config+Workspace query: {db_query_time:.4f}s")
if not result: if not result:
elapsed_ms = (time.time() - start_time) * 1000 elapsed_ms = (time.time() - start_time) * 1000
config_logger.error( config_logger.error(
@@ -144,16 +148,20 @@ class MemoryConfigService:
memory_config, workspace = result memory_config, workspace = result
# Validate embedding model # Step 2: Validate embedding model (returns both UUID and name)
embedding_uuid = validate_embedding_model( embed_start = time.time()
embedding_uuid, embedding_name = validate_embedding_model(
validated_config_id, validated_config_id,
memory_config.embedding_id, memory_config.embedding_id,
self.db, self.db,
workspace.tenant_id, workspace.tenant_id,
workspace.id, workspace.id,
) )
embed_time = time.time() - embed_start
logger.info(f"[PERF] Embedding validation: {embed_time:.4f}s")
# Resolve LLM model # Step 3: Resolve LLM model
llm_start = time.time()
llm_uuid, llm_name = validate_and_resolve_model_id( llm_uuid, llm_name = validate_and_resolve_model_id(
memory_config.llm_id, memory_config.llm_id,
"llm", "llm",
@@ -163,8 +171,11 @@ class MemoryConfigService:
config_id=validated_config_id, config_id=validated_config_id,
workspace_id=workspace.id, workspace_id=workspace.id,
) )
llm_time = time.time() - llm_start
logger.info(f"[PERF] LLM validation: {llm_time:.4f}s")
# Resolve optional rerank model # Step 4: Resolve optional rerank model
rerank_start = time.time()
rerank_uuid = None rerank_uuid = None
rerank_name = None rerank_name = None
if memory_config.rerank_id: if memory_config.rerank_id:
@@ -177,16 +188,12 @@ class MemoryConfigService:
config_id=validated_config_id, config_id=validated_config_id,
workspace_id=workspace.id, workspace_id=workspace.id,
) )
rerank_time = time.time() - rerank_start
if memory_config.rerank_id:
logger.info(f"[PERF] Rerank validation: {rerank_time:.4f}s")
# Get embedding model name # Note: embedding_name is now returned from validate_embedding_model above
embedding_name, _ = validate_model_exists_and_active( # No need for redundant query!
embedding_uuid,
"embedding",
self.db,
workspace.tenant_id,
config_id=validated_config_id,
workspace_id=workspace.id,
)
# Create immutable MemoryConfig object # Create immutable MemoryConfig object
config = MemoryConfig( config = MemoryConfig(

View File

@@ -506,27 +506,6 @@ async def search_edges(end_user_id: Optional[str] = None) -> List[Dict[str, Any]
return result return result
async def search_entity_graph(end_user_id: Optional[str] = None) -> Dict[str, Any]:
"""搜索所有实体之间的关系网络group 维度)。"""
result = await _neo4j_connector.execute_query(
DataConfigRepository.SEARCH_FOR_ENTITY_GRAPH,
group_id=end_user_id,
)
# 对source_node 和 target_node 的 fact_summary进行截取只截取前三条的内容需要提取前三条“来源”
for item in result:
source_fact = item["sourceNode"]["fact_summary"]
target_fact = item["targetNode"]["fact_summary"]
# 截取前三条“来源”
item["sourceNode"]["fact_summary"] = source_fact.split("\n")[:4] if source_fact else []
item["targetNode"]["fact_summary"] = target_fact.split("\n")[:4] if target_fact else []
# 与现有返回风格保持一致,携带搜索类型、数量与详情
data = {
"search_for": "entity_graph",
"num": len(result),
"detials": result,
}
return data
async def analytics_hot_memory_tags( async def analytics_hot_memory_tags(
db: Session, db: Session,

View File

@@ -357,6 +357,101 @@ class UserMemoryService:
data[key] = UserMemoryService._datetime_to_timestamp(original_value) data[key] = UserMemoryService._datetime_to_timestamp(original_value)
return data return data
def update_end_user_profile(
self,
db: Session,
end_user_id: str,
profile_update: Any
) -> Dict[str, Any]:
"""
更新终端用户的基本信息
Args:
db: 数据库会话
end_user_id: 终端用户ID (UUID)
profile_update: 包含更新字段的 Pydantic 模型
Returns:
{
"success": bool,
"data": dict, # 更新后的用户档案数据
"error": Optional[str]
}
"""
try:
# 转换为UUID并查询用户
user_uuid = uuid.UUID(end_user_id)
repo = EndUserRepository(db)
end_user = repo.get_by_id(user_uuid)
if not end_user:
logger.warning(f"终端用户不存在: end_user_id={end_user_id}")
return {
"success": False,
"data": None,
"error": "终端用户不存在"
}
# 获取更新数据(排除 end_user_id 字段)
update_data = profile_update.model_dump(exclude_unset=True, exclude={'end_user_id'})
# 特殊处理 hire_date如果提供了时间戳转换为 DateTime
if 'hire_date' in update_data:
hire_date_timestamp = update_data['hire_date']
if hire_date_timestamp is not None:
from app.core.api_key_utils import timestamp_to_datetime
update_data['hire_date'] = timestamp_to_datetime(hire_date_timestamp)
# 如果是 None保持 None允许清空
# 更新字段
for field, value in update_data.items():
setattr(end_user, field, value)
# 更新时间戳
end_user.updated_at = datetime.now()
end_user.updatetime_profile = datetime.now()
# 提交更改
db.commit()
db.refresh(end_user)
# 构建响应数据
from app.schemas.end_user_schema import EndUserProfileResponse
profile_data = EndUserProfileResponse(
id=end_user.id,
other_name=end_user.other_name,
position=end_user.position,
department=end_user.department,
contact=end_user.contact,
phone=end_user.phone,
hire_date=end_user.hire_date,
updatetime_profile=end_user.updatetime_profile
)
logger.info(f"成功更新用户信息: end_user_id={end_user_id}, updated_fields={list(update_data.keys())}")
return {
"success": True,
"data": self.convert_profile_to_dict_with_timestamp(profile_data),
"error": None
}
except ValueError:
logger.error(f"无效的 end_user_id 格式: {end_user_id}")
return {
"success": False,
"data": None,
"error": "无效的用户ID格式"
}
except Exception as e:
db.rollback()
logger.error(f"用户信息更新失败: end_user_id={end_user_id}, error={str(e)}")
return {
"success": False,
"data": None,
"error": str(e)
}
async def get_cached_memory_insight( async def get_cached_memory_insight(
self, self,
db: Session, db: Session,

View File

@@ -425,24 +425,7 @@ def read_message_task(self, group_id: str, message: str, history: List[Dict[str,
db.close() db.close()
try: try:
# 使用 nest_asyncio 来避免事件循环冲突 result = asyncio.run(_run())
try:
import nest_asyncio
nest_asyncio.apply()
except ImportError:
pass
# 尝试获取现有事件循环,如果不存在则创建新的
try:
loop = asyncio.get_event_loop()
if loop.is_closed():
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
except RuntimeError:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
result = loop.run_until_complete(_run())
elapsed_time = time.time() - start_time elapsed_time = time.time() - start_time
return { return {
@@ -455,7 +438,6 @@ def read_message_task(self, group_id: str, message: str, history: List[Dict[str,
} }
except BaseException as e: except BaseException as e:
elapsed_time = time.time() - start_time elapsed_time = time.time() - start_time
# Handle ExceptionGroup from TaskGroup
if hasattr(e, 'exceptions'): if hasattr(e, 'exceptions'):
error_messages = [f"{type(sub_e).__name__}: {str(sub_e)}" for sub_e in e.exceptions] error_messages = [f"{type(sub_e).__name__}: {str(sub_e)}" for sub_e in e.exceptions]
detailed_error = "; ".join(error_messages) detailed_error = "; ".join(error_messages)
@@ -472,13 +454,19 @@ def read_message_task(self, group_id: str, message: str, history: List[Dict[str,
@celery_app.task(name="app.core.memory.agent.write_message", bind=True) @celery_app.task(name="app.core.memory.agent.write_message", bind=True)
def write_message_task(self, group_id: str, message: str, config_id: str,storage_type:str,user_rag_memory_id:str) -> Dict[str, Any]: def write_message_task(self, group_id: str, message, config_id: str, storage_type: str, user_rag_memory_id: str) -> Dict[str, Any]:
"""Celery task to process a write message via MemoryAgentService. """Celery task to process a write message via MemoryAgentService.
支持两种消息格式:
1. 字符串格式向后兼容message="user: xxx\nassistant: yyy"
2. 结构化消息列表推荐message=[{"role": "user", "content": "xxx"}, {"role": "assistant", "content": "yyy"}]
Args: Args:
group_id: Group ID for the memory agent (also used as end_user_id) group_id: Group ID for the memory agent (also used as end_user_id)
message: Message to write message: Message to write (str or list[dict])
config_id: Optional configuration ID config_id: Optional configuration ID
storage_type: Storage type (neo4j/rag)
user_rag_memory_id: RAG memory ID
Returns: Returns:
Dict containing the result and metadata Dict containing the result and metadata
@@ -522,24 +510,7 @@ def write_message_task(self, group_id: str, message: str, config_id: str,storage
db.close() db.close()
try: try:
# 使用 nest_asyncio 来避免事件循环冲突 result = asyncio.run(_run())
try:
import nest_asyncio
nest_asyncio.apply()
except ImportError:
pass
# 尝试获取现有事件循环,如果不存在则创建新的
try:
loop = asyncio.get_event_loop()
if loop.is_closed():
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
except RuntimeError:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
result = loop.run_until_complete(_run())
elapsed_time = time.time() - start_time elapsed_time = time.time() - start_time
logger.info(f"[CELERY WRITE] Task completed successfully - elapsed_time={elapsed_time:.2f}s, task_id={self.request.id}") logger.info(f"[CELERY WRITE] Task completed successfully - elapsed_time={elapsed_time:.2f}s, task_id={self.request.id}")
@@ -554,7 +525,6 @@ def write_message_task(self, group_id: str, message: str, config_id: str,storage
} }
except BaseException as e: except BaseException as e:
elapsed_time = time.time() - start_time elapsed_time = time.time() - start_time
# Handle ExceptionGroup from TaskGroup
if hasattr(e, 'exceptions'): if hasattr(e, 'exceptions'):
error_messages = [f"{type(sub_e).__name__}: {str(sub_e)}" for sub_e in e.exceptions] error_messages = [f"{type(sub_e).__name__}: {str(sub_e)}" for sub_e in e.exceptions]
detailed_error = "; ".join(error_messages) detailed_error = "; ".join(error_messages)
@@ -594,53 +564,53 @@ def reflection_timer_task() -> None:
""" """
reflection_engine() reflection_engine()
# unused task
@celery_app.task(name="app.core.memory.agent.health.check_read_service") # @celery_app.task(name="app.core.memory.agent.health.check_read_service")
def check_read_service_task() -> Dict[str, str]: # def check_read_service_task() -> Dict[str, str]:
"""Call read_service and write latest status to Redis. # """Call read_service and write latest status to Redis.
Returns status data dict that gets written to Redis. # Returns status data dict that gets written to Redis.
""" # """
client = redis.Redis( # client = redis.Redis(
host=settings.REDIS_HOST, # host=settings.REDIS_HOST,
port=settings.REDIS_PORT, # port=settings.REDIS_PORT,
db=settings.REDIS_DB, # db=settings.REDIS_DB,
password=settings.REDIS_PASSWORD if settings.REDIS_PASSWORD else None # password=settings.REDIS_PASSWORD if settings.REDIS_PASSWORD else None
) # )
try: # try:
api_url = f"http://{settings.SERVER_IP}:8000/api/memory/read_service" # api_url = f"http://{settings.SERVER_IP}:8000/api/memory/read_service"
payload = { # payload = {
"user_id": "健康检查", # "user_id": "健康检查",
"apply_id": "健康检查", # "apply_id": "健康检查",
"group_id": "健康检查", # "group_id": "健康检查",
"message": "你好", # "message": "你好",
"history": [], # "history": [],
"search_switch": "2", # "search_switch": "2",
} # }
resp = requests.post(api_url, json=payload, timeout=15) # resp = requests.post(api_url, json=payload, timeout=15)
ok = resp.status_code == 200 # ok = resp.status_code == 200
status = "Success" if ok else "Fail" # status = "Success" if ok else "Fail"
msg = "接口请求成功" if ok else f"接口请求失败: {resp.status_code}" # msg = "接口请求成功" if ok else f"接口请求失败: {resp.status_code}"
error = "" if ok else resp.text # error = "" if ok else resp.text
code = 0 if ok else 500 # code = 0 if ok else 500
except Exception as e: # except Exception as e:
status = "Fail" # status = "Fail"
msg = "接口请求失败" # msg = "接口请求失败"
error = str(e) # error = str(e)
code = 500 # code = 500
data = { # data = {
"status": status, # "status": status,
"msg": msg, # "msg": msg,
"error": error, # "error": error,
"code": str(code), # "code": str(code),
"time": str(int(time.time())), # "time": str(int(time.time())),
} # }
client.hset("memsci:health:read_service", mapping=data) # client.hset("memsci:health:read_service", mapping=data)
client.expire("memsci:health:read_service", int(settings.HEALTH_CHECK_SECONDS)) # client.expire("memsci:health:read_service", int(settings.HEALTH_CHECK_SECONDS))
return data # return data
@celery_app.task(name="app.controllers.memory_storage_controller.search_all") @celery_app.task(name="app.controllers.memory_storage_controller.search_all")
@@ -905,24 +875,7 @@ def regenerate_memory_cache(self) -> Dict[str, Any]:
} }
try: try:
# 使用 nest_asyncio 来避免事件循环冲突 result = asyncio.run(_run())
try:
import nest_asyncio
nest_asyncio.apply()
except ImportError:
pass
# 尝试获取现有事件循环,如果不存在则创建新的
try:
loop = asyncio.get_event_loop()
if loop.is_closed():
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
except RuntimeError:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
result = loop.run_until_complete(_run())
elapsed_time = time.time() - start_time elapsed_time = time.time() - start_time
result["elapsed_time"] = elapsed_time result["elapsed_time"] = elapsed_time
result["task_id"] = self.request.id result["task_id"] = self.request.id
@@ -1049,24 +1002,7 @@ def workspace_reflection_task(self) -> Dict[str, Any]:
} }
try: try:
# 使用 nest_asyncio 来避免事件循环冲突 result = asyncio.run(_run())
try:
import nest_asyncio
nest_asyncio.apply()
except ImportError:
pass
# 尝试获取现有事件循环,如果不存在则创建新的
try:
loop = asyncio.get_event_loop()
if loop.is_closed():
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
except RuntimeError:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
result = loop.run_until_complete(_run())
elapsed_time = time.time() - start_time elapsed_time = time.time() - start_time
result["elapsed_time"] = elapsed_time result["elapsed_time"] = elapsed_time
result["task_id"] = self.request.id result["task_id"] = self.request.id
@@ -1142,11 +1078,4 @@ def run_forgetting_cycle_task(self, config_id: Optional[int] = None) -> Dict[str
"duration_seconds": duration "duration_seconds": duration
} }
# 运行异步函数 return asyncio.run(_run())
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
result = loop.run_until_complete(_run())
return result
finally:
loop.close()

View File

@@ -7,10 +7,6 @@ services:
- "8002:8000" - "8002:8000"
env_file: env_file:
- .env - .env
environment:
- SERVER_IP=0.0.0.0
# 如果代码里必须要 MCP_SERVER_URL可以先注释或指向占位
# - MCP_SERVER_URL=
volumes: volumes:
- ./files:/files - ./files:/files
- /etc/localtime:/etc/localtime:ro - /etc/localtime:/etc/localtime:ro
@@ -19,20 +15,53 @@ services:
networks: networks:
- default - default
- celery - celery
depends_on:
- worker-memory
- worker-document
# Celery worker # Memory worker - Memory read/write tasks (threads pool for asyncio)
worker: worker-memory:
image: redbear-mem-open:latest image: redbear-mem-open:latest
container_name: worker container_name: worker-memory
env_file: env_file:
- .env - .env
volumes: volumes:
- ./files:/files - ./files:/files
- /etc/localtime:/etc/localtime:ro - /etc/localtime:/etc/localtime:ro
command: celery -A app.celery_worker.celery_app worker --loglevel=info command: celery -A app.celery_worker.celery_app worker -E --loglevel=info --pool=threads --concurrency=100 --queues=memory_tasks -n memory_worker@%h
restart: unless-stopped restart: unless-stopped
networks: networks:
- celery - celery
# Document worker - Document parsing tasks (prefork for CPU-bound)
worker-document:
image: redbear-mem-open:latest
container_name: worker-document
env_file:
- .env
volumes:
- ./files:/files
- /etc/localtime:/etc/localtime:ro
command: celery -A app.celery_worker.celery_app worker -E --loglevel=info --pool=prefork --concurrency=4 --queues=document_tasks --max-tasks-per-child=100 -n document_worker@%h
restart: unless-stopped
networks:
- celery
# Celery Beat - scheduler
beat:
image: redbear-mem-open:latest
container_name: celery-beat
env_file:
- .env
volumes:
- ./files:/files
- /etc/localtime:/etc/localtime:ro
command: celery -A app.celery_worker.celery_app beat --loglevel=info
restart: unless-stopped
networks:
- celery
depends_on:
- worker-memory
networks: networks:
celery: celery:

View File

@@ -13,6 +13,7 @@ dependencies = [
"bcrypt==5.0.0", "bcrypt==5.0.0",
"billiard==4.2.2", "billiard==4.2.2",
"celery==5.5.3", "celery==5.5.3",
"flower==2.0.1",
"cffi==2.0.0", "cffi==2.0.0",
"click==8.3.0", "click==8.3.0",
"click-didyoumean==0.3.1", "click-didyoumean==0.3.1",
@@ -138,6 +139,7 @@ dependencies = [
"python-calamine>=0.4.0", "python-calamine>=0.4.0",
"xlrd==2.0.2", "xlrd==2.0.2",
"deprecated>=1.3.1", "deprecated>=1.3.1",
"flower>=2.0.1",
] ]
[tool.pytest.ini_options] [tool.pytest.ini_options]

View File

@@ -6,6 +6,7 @@ async-timeout==5.0.1
bcrypt==5.0.0 bcrypt==5.0.0
billiard==4.2.2 billiard==4.2.2
celery==5.5.3 celery==5.5.3
flower==2.0.1
cffi==2.0.0 cffi==2.0.0
click==8.3.0 click==8.3.0
click-didyoumean==0.3.1 click-didyoumean==0.3.1