Merge branch 'develop' of github.com:SuanmoSuanyangTechnology/MemoryBear into develop

# Conflicts:
#	api/pyproject.toml
This commit is contained in:
Mark
2026-01-22 10:22:40 +08:00
44 changed files with 1524 additions and 884 deletions

View File

@@ -334,7 +334,13 @@ step6: Log In to the Frontend Interface.
## License
This project is licensed under the Apache License 2.0. For details, see the LICENSE file.
## Acknowledgements & Community
- Feedback & Issues: Please submit an Issue in the repository for bug reports or discussions.
- Contributions Welcome: When submitting a Pull Request, please create a feature branch and follow conventional commit message guidelines.
- Contact: If you are interested in contributing or collaborating, feel free to reach out at tianyou_hubm@redbearai.com
## Community & Support
Join our community to ask questions, share your work, and connect with fellow developers.
- **GitHub Issues**: Report bugs, request features, or track known issues via [GitHub Issues](https://github.com/SuanmoSuanyangTechnology/MemoryBear/issues).
- **GitHub Pull Requests**: Contribute code improvements or fixes through [Pull Requests](https://github.com/SuanmoSuanyangTechnology/MemoryBear/pulls).
- **GitHub Discussions**: Ask questions, share ideas, and engage with the community in [GitHub Discussions](https://github.com/SuanmoSuanyangTechnology/MemoryBear/discussions).
- **WeChat**: Scan the QR code below to join our WeChat community group.
- ![wecom-temp-114020-47fe87a75da439f09f5dc93a01593046](https://github.com/user-attachments/assets/8c81885c-4134-40d5-96e2-7f78cc082dc6)
- **Contact**: If you are interested in contributing or collaborating, feel free to reach out at tianyou_hubm@redbearai.com

View File

@@ -1,4 +1,5 @@
import os
import platform
from datetime import timedelta
from urllib.parse import quote
@@ -14,27 +15,12 @@ celery_app = Celery(
backend=f"redis://:{quote(settings.REDIS_PASSWORD)}@{settings.REDIS_HOST}:{settings.REDIS_PORT}/{settings.CELERY_BACKEND}",
)
# 配置使用本地队列,避免与远程 worker 冲突
celery_app.conf.task_default_queue = 'localhost_test_wyl'
celery_app.conf.task_default_exchange = 'localhost_test_wyl'
celery_app.conf.task_default_routing_key = 'localhost_test_wyl'
# Default queue for unrouted tasks
celery_app.conf.task_default_queue = 'memory_tasks'
# macOS 兼容性配置
import platform
if platform.system() == 'Darwin': # macOS
# 设置环境变量解决 fork 问题
if platform.system() == 'Darwin':
os.environ.setdefault('OBJC_DISABLE_INITIALIZE_FORK_SAFETY', 'YES')
# 使用 solo 池避免多进程问题
celery_app.conf.worker_pool = 'solo'
# 设置唯一的节点名称
import socket
import time
hostname = socket.gethostname()
timestamp = int(time.time())
celery_app.conf.worker_name = f"celery@{hostname}-{timestamp}"
# Celery 配置
celery_app.conf.update(
@@ -52,36 +38,47 @@ celery_app.conf.update(
task_ignore_result=False,
# 超时设置
task_time_limit=30 * 60, # 30 分钟硬超时
task_soft_time_limit=25 * 60, # 25 分钟软超时
task_time_limit=1800, # 30分钟硬超时
task_soft_time_limit=1500, # 25分钟软超时
# Worker 设置 - 针对 macOS 优化
worker_prefetch_multiplier=1, # 减少预取任务数,避免内存堆积
worker_max_tasks_per_child=10, # 大幅减少每个 worker 执行的任务数,频繁重启防止内存泄漏
worker_max_memory_per_child=200000, # 200MB 内存限制,超过后重启 worker
# Worker 设置 (per-worker settings are in docker-compose command line)
worker_prefetch_multiplier=1, # Don't hoard tasks, fairer distribution
# 结果过期时间
result_expires=3600, # 结果保存 1 小时
result_expires=3600, # 结果保存1小时
# 任务确认设置
task_acks_late=True, # 任务完成后才确认,避免任务丢失
worker_disable_rate_limits=True, # 禁用速率限制
task_acks_late=True,
task_reject_on_worker_lost=True,
worker_disable_rate_limits=True,
# 任务路由(可选,用于不同队列)
# task_routes={
# 'app.core.rag.tasks.parse_document': {'queue': 'document_processing'},
# 'app.core.memory.agent.read_message': {'queue': 'memory_processing'},
# 'app.core.memory.agent.write_message': {'queue': 'memory_processing'},
# 'tasks.process_item': {'queue': 'default'},
# },
# FLower setting
worker_send_task_events=True,
task_send_sent_event=True,
# task routing
task_routes={
# Memory tasks → memory_tasks queue (threads worker)
'app.core.memory.agent.read_message_priority': {'queue': 'memory_tasks'},
'app.core.memory.agent.read_message': {'queue': 'memory_tasks'},
'app.core.memory.agent.write_message': {'queue': 'memory_tasks'},
# Document tasks → document_tasks queue (prefork worker)
'app.core.rag.tasks.parse_document': {'queue': 'document_tasks'},
'app.core.rag.tasks.build_graphrag_for_kb': {'queue': 'document_tasks'},
# Beat/periodic tasks → document_tasks queue (prefork worker)
'app.tasks.workspace_reflection_task': {'queue': 'document_tasks'},
'app.tasks.regenerate_memory_cache': {'queue': 'document_tasks'},
'app.tasks.run_forgetting_cycle_task': {'queue': 'document_tasks'},
'app.controllers.memory_storage_controller.search_all': {'queue': 'document_tasks'},
},
)
# 自动发现任务模块
celery_app.autodiscover_tasks(['app'])
# Celery Beat schedule for periodic tasks
reflection_schedule = timedelta(seconds=settings.REFLECTION_INTERVAL_SECONDS)
health_schedule = timedelta(seconds=settings.HEALTH_CHECK_SECONDS)
memory_increment_schedule = timedelta(hours=settings.MEMORY_INCREMENT_INTERVAL_HOURS)
memory_cache_regeneration_schedule = timedelta(hours=settings.MEMORY_CACHE_REGENERATION_HOURS)
workspace_reflection_schedule = timedelta(seconds=30) # 每30秒运行一次settings.REFLECTION_INTERVAL_TIME
@@ -89,12 +86,6 @@ forgetting_cycle_schedule = timedelta(hours=24) # 每24小时运行一次遗忘
# 构建定时任务配置
beat_schedule_config = {
# "check-read-service": {
# "task": "app.core.memory.agent.health.check_read_service",
# "schedule": health_schedule,
# "args": (),
# },
"run-workspace-reflection": {
"task": "app.tasks.workspace_reflection_task",
"schedule": workspace_reflection_schedule,

View File

@@ -9,7 +9,9 @@ from app.db import get_db
from app.dependencies import cur_workspace_access_guard, get_current_user
from app.models import ModelApiKey
from app.models.user_model import User
from app.repositories import knowledge_repository
from app.core.memory.agent.utils.session_tools import SessionService
from app.core.memory.agent.utils.redis_tool import store
from app.repositories import knowledge_repository, WorkspaceRepository
from app.schemas.memory_agent_schema import UserInput, Write_UserInput
from app.schemas.response_schema import ApiResponse
from app.services import task_service, workspace_service
@@ -160,9 +162,12 @@ async def write_server(
api_logger.info(f"Write service requested for group {user_input.group_id}, storage_type: {storage_type}, user_rag_memory_id: {user_rag_memory_id}")
try:
# 获取标准化的消息列表
messages_list = memory_agent_service.get_messages_list(user_input)
result = await memory_agent_service.write_memory(
user_input.group_id,
user_input.message,
messages_list, # 传递结构化消息列表
config_id,
db,
storage_type,
@@ -219,9 +224,12 @@ async def write_server_async(
if knowledge: user_rag_memory_id = str(knowledge.id)
api_logger.info(f"Async write: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}")
try:
# 获取标准化的消息列表
messages_list = memory_agent_service.get_messages_list(user_input)
task = celery_app.send_task(
"app.core.memory.agent.write_message",
args=[user_input.group_id, user_input.message, config_id, storage_type, user_rag_memory_id]
args=[user_input.group_id, messages_list, config_id, storage_type, user_rag_memory_id]
)
api_logger.info(f"Write task queued: {task.id}")
@@ -285,6 +293,19 @@ async def read_server(
storage_type,
user_rag_memory_id
)
if str(user_input.search_switch) == "2":
retrieve_info = result['answer']
history = await SessionService(store).get_history(user_input.group_id, user_input.group_id, user_input.group_id)
query = user_input.message
# 调用 memory_agent_service 的方法生成最终答案
result['answer'] = await memory_agent_service.generate_summary_from_retrieve(
retrieve_info=retrieve_info,
history=history,
query=query,
config_id=config_id,
db=db
)
return success(data=result, msg="回复对话消息成功")
except BaseException as e:
# Handle ExceptionGroup from TaskGroup (Python 3.11+) or BaseExceptionGroup
@@ -564,8 +585,23 @@ async def status_type(
"""
api_logger.info(f"Status type check requested for group {user_input.group_id}")
try:
# 获取标准化的消息列表
messages_list = memory_agent_service.get_messages_list(user_input)
# 将消息列表转换为字符串用于分类
# 只取最后一条用户消息进行分类
last_user_message = ""
for msg in reversed(messages_list):
if msg.get('role') == 'user':
last_user_message = msg.get('content', '')
break
if not last_user_message:
# 如果没有用户消息,使用所有消息的内容
last_user_message = " ".join([msg.get('content', '') for msg in messages_list])
result = await memory_agent_service.classify_message_type(
user_input.message,
last_user_message,
user_input.config_id,
db
)
@@ -616,8 +652,10 @@ async def get_knowledge_type_stats_api(
@router.get("/analytics/hot_memory_tags/by_user", response_model=ApiResponse)
async def get_hot_memory_tags_by_user_api(
end_user_id: Optional[str] = Query(None, description="用户ID可选"),
language_type: Optional[str] ="zh",
limit: int = Query(20, description="返回标签数量限制"),
current_user: User = Depends(get_current_user)
current_user: User = Depends(get_current_user),
db: Session=Depends(get_db),
):
"""
获取指定用户的热门记忆标签
@@ -628,10 +666,22 @@ async def get_hot_memory_tags_by_user_api(
...
]
"""
workspace_id=current_user.current_workspace_id
workspace_repo = WorkspaceRepository(db)
workspace_models = workspace_repo.get_workspace_models_configs(workspace_id)
if workspace_models:
model_id = workspace_models.get("llm", None)
else:
model_id = None
api_logger.info(f"Hot memory tags by user requested: end_user_id={end_user_id}")
try:
result = await memory_agent_service.get_hot_memory_tags_by_user(
end_user_id=end_user_id,
language_type=language_type,
model_id=model_id,
limit=limit
)
return success(data=result, msg="获取热门记忆标签成功")
@@ -647,7 +697,7 @@ async def get_user_profile_api(
current_user: User = Depends(get_current_user)
):
"""
获取用户详情,包含:
获取工作空间下Popular Memory Tags,包含:
- name: 用户名字(直接使用 end_user_id
- tags: 3个用户特征标签从语句和实体中LLM总结
- hot_tags: 4个热门记忆标签

View File

@@ -5,7 +5,6 @@ from app.core.response_utils import success
from app.db import get_db
from app.dependencies import get_current_user
from app.models.user_model import User
from app.schemas.memory_agent_schema import End_User_Information
from app.schemas.response_schema import ApiResponse
from app.services import memory_dashboard_service, memory_storage_service, workspace_service
@@ -40,54 +39,7 @@ def get_workspace_total_end_users(
api_logger.info(f"成功获取最新用户总数: total_num={total_end_users.get('total_num', 0)}")
return success(data=total_end_users, msg="用户数量获取成功")
@router.post("/update/end_users", response_model=ApiResponse)
async def update_workspace_end_users(
user_input: End_User_Information,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""
更新工作空间的宿主信息
"""
username = user_input.end_user_name # 要更新的用户名
end_user_input_id = user_input.id # 宿主ID
workspace_id = current_user.current_workspace_id
api_logger.info(f"用户 {current_user.username} 请求更新工作空间 {workspace_id} 的宿主信息")
api_logger.info(f"更新参数: username={username}, end_user_id={end_user_input_id}")
try:
# 导入更新函数
from app.repositories.end_user_repository import update_end_user_other_name
import uuid
# 转换 end_user_id 为 UUID 类型
end_user_uuid = uuid.UUID(end_user_input_id)
# 直接更新数据库中的 other_name 字段
updated_count = update_end_user_other_name(
db=db,
end_user_id=end_user_uuid,
other_name=username
)
api_logger.info(f"成功更新宿主 {end_user_input_id} 的 other_name 为: {username}")
return success(
data={
"updated_count": updated_count,
"end_user_id": end_user_input_id,
"updated_other_name": username
},
msg=f"成功更新 {updated_count} 个宿主的信息"
)
except Exception as e:
api_logger.error(f"更新宿主信息失败: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"更新宿主信息失败: {str(e)}"
)

View File

@@ -20,6 +20,7 @@ router = APIRouter(
@router.get("/short_term")
async def short_term_configs(
end_user_id: str,
language_type:Optional[str] = "zh",
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
):

View File

@@ -28,7 +28,6 @@ from app.services.memory_storage_service import (
search_dialogue,
search_edges,
search_entity,
search_entity_graph,
search_statement,
)
from fastapi import APIRouter, Depends
@@ -412,21 +411,7 @@ async def search_entity_edges(
api_logger.error(f"Search edges failed: {str(e)}")
return fail(BizCode.INTERNAL_ERROR, "边查询失败", str(e))
@router.get("/search/entity_graph", response_model=ApiResponse)
async def search_for_entity_graph(
end_user_id: Optional[str] = None,
current_user: User = Depends(get_current_user),
) -> dict:
"""
搜索所有实体之间的关系网络
"""
api_logger.info(f"Search entity graph requested for end_user_id: {end_user_id}")
try:
result = await search_entity_graph(end_user_id)
return success(data=result, msg="查询成功")
except Exception as e:
api_logger.error(f"Search entity graph failed: {str(e)}")
return fail(BizCode.INTERNAL_ERROR, "实体图查询失败", str(e))
@router.get("/analytics/hot_memory_tags", response_model=ApiResponse)

View File

@@ -12,6 +12,7 @@ from app.core.logging_config import get_api_logger
from app.core.response_utils import success, fail
from app.core.error_codes import BizCode
from app.core.api_key_utils import timestamp_to_datetime
from app.services.memory_base_service import Translation_English
from app.services.user_memory_service import (
UserMemoryService,
analytics_memory_types,
@@ -20,7 +21,7 @@ from app.services.user_memory_service import (
from app.services.memory_entity_relationship_service import MemoryEntityService,MemoryEmotion,MemoryInteraction
from app.schemas.response_schema import ApiResponse
from app.schemas.memory_storage_schema import GenerateCacheRequest
from app.repositories.workspace_repository import WorkspaceRepository
from app.schemas.end_user_schema import (
EndUserProfileResponse,
EndUserProfileUpdate,
@@ -44,6 +45,7 @@ router = APIRouter(
@router.get("/analytics/memory_insight/report", response_model=ApiResponse)
async def get_memory_insight_report_api(
end_user_id: str,
language_type: str = "zh",
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
) -> dict:
@@ -53,10 +55,18 @@ async def get_memory_insight_report_api(
此接口仅查询数据库中已缓存的记忆洞察数据,不执行生成操作。
如需生成新的洞察报告,请使用专门的生成接口。
"""
workspace_id = current_user.current_workspace_id
workspace_repo = WorkspaceRepository(db)
workspace_models = workspace_repo.get_workspace_models_configs(workspace_id)
if workspace_models:
model_id = workspace_models.get("llm", None)
else:
model_id = None
api_logger.info(f"记忆洞察报告查询请求: end_user_id={end_user_id}, user={current_user.username}")
try:
# 调用服务层获取缓存数据
result = await user_memory_service.get_cached_memory_insight(db, end_user_id)
result = await user_memory_service.get_cached_memory_insight(db, end_user_id,model_id,language_type)
if result["is_cached"]:
api_logger.info(f"成功返回缓存的记忆洞察报告: end_user_id={end_user_id}")
@@ -72,6 +82,7 @@ async def get_memory_insight_report_api(
@router.get("/analytics/user_summary", response_model=ApiResponse)
async def get_user_summary_api(
end_user_id: str,
language_type: str="zh",
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
) -> dict:
@@ -81,10 +92,18 @@ async def get_user_summary_api(
此接口仅查询数据库中已缓存的用户摘要数据,不执行生成操作。
如需生成新的用户摘要,请使用专门的生成接口。
"""
workspace_id = current_user.current_workspace_id
workspace_repo = WorkspaceRepository(db)
workspace_models = workspace_repo.get_workspace_models_configs(workspace_id)
if workspace_models:
model_id = workspace_models.get("llm", None)
else:
model_id = None
api_logger.info(f"用户摘要查询请求: end_user_id={end_user_id}, user={current_user.username}")
try:
# 调用服务层获取缓存数据
result = await user_memory_service.get_cached_user_summary(db, end_user_id)
result = await user_memory_service.get_cached_user_summary(db, end_user_id,model_id,language_type)
if result["is_cached"]:
api_logger.info(f"成功返回缓存的用户摘要: end_user_id={end_user_id}")
@@ -253,7 +272,6 @@ async def get_graph_data_api(
depth=depth,
center_node_id=center_node_id
)
# 检查是否有错误消息
if "message" in result and result["statistics"]["total_nodes"] == 0:
api_logger.warning(f"图数据查询返回空结果: {result.get('message')}")
@@ -278,7 +296,13 @@ async def get_end_user_profile(
db: Session = Depends(get_db),
) -> dict:
workspace_id = current_user.current_workspace_id
workspace_repo = WorkspaceRepository(db)
workspace_models = workspace_repo.get_workspace_models_configs(workspace_id)
if workspace_models:
model_id = workspace_models.get("llm", None)
else:
model_id = None
# 检查用户是否已选择工作空间
if workspace_id is None:
api_logger.warning(f"用户 {current_user.username} 尝试查询用户信息但未选择工作空间")
@@ -296,7 +320,6 @@ async def get_end_user_profile(
if not end_user:
api_logger.warning(f"终端用户不存在: end_user_id={end_user_id}")
return fail(BizCode.INVALID_PARAMETER, "终端用户不存在", f"end_user_id={end_user_id}")
# 构建响应数据
profile_data = EndUserProfileResponse(
id=end_user.id,
@@ -328,12 +351,11 @@ async def update_end_user_profile(
该接口可以更新用户的姓名、职位、部门、联系方式、电话和入职日期等信息。
所有字段都是可选的,只更新提供的字段。
"""
workspace_id = current_user.current_workspace_id
end_user_id = profile_update.end_user_id
# 检查用户是否已选择工作空间
# 验证工作空间
if workspace_id is None:
api_logger.warning(f"用户 {current_user.username} 尝试更新用户信息但未选择工作空间")
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
@@ -343,65 +365,41 @@ async def update_end_user_profile(
f"workspace={workspace_id}"
)
try:
# 查询终端用户
end_user = db.query(EndUser).filter(EndUser.id == end_user_id).first()
# 调用 Service 层处理业务逻辑
result = user_memory_service.update_end_user_profile(db, end_user_id, profile_update)
if not end_user:
api_logger.warning(f"终端用户不存在: end_user_id={end_user_id}")
return fail(BizCode.INVALID_PARAMETER, "终端用户不存在", f"end_user_id={end_user_id}")
# 更新字段(只更新提供的字段,排除 end_user_id
# 允许 None 值来重置字段(如 hire_date
update_data = profile_update.model_dump(exclude_unset=True, exclude={'end_user_id'})
# 特殊处理 hire_date如果提供了时间戳转换为 DateTime
if 'hire_date' in update_data:
hire_date_timestamp = update_data['hire_date']
if hire_date_timestamp is not None:
update_data['hire_date'] = timestamp_to_datetime(hire_date_timestamp)
# 如果是 None保持 None允许清空
for field, value in update_data.items():
setattr(end_user, field, value)
# 更新 updated_at 时间戳
end_user.updated_at = datetime.datetime.now()
# 更新 updatetime_profile 为当前时间
end_user.updatetime_profile = datetime.datetime.now()
# 提交更改
db.commit()
db.refresh(end_user)
# 构建响应数据
profile_data = EndUserProfileResponse(
id=end_user.id,
other_name=end_user.other_name,
position=end_user.position,
department=end_user.department,
contact=end_user.contact,
phone=end_user.phone,
hire_date=end_user.hire_date,
updatetime_profile=end_user.updatetime_profile
)
api_logger.info(f"成功更新用户信息: end_user_id={end_user_id}, updated_fields={list(update_data.keys())}")
return success(data=UserMemoryService.convert_profile_to_dict_with_timestamp(profile_data), msg="更新成功")
except Exception as e:
db.rollback()
api_logger.error(f"用户信息更新失败: end_user_id={end_user_id}, error={str(e)}")
return fail(BizCode.INTERNAL_ERROR, "用户信息更新失败", str(e))
if result["success"]:
api_logger.info(f"成功更新用户信息: end_user_id={end_user_id}")
return success(data=result["data"], msg="更新成功")
else:
error_msg = result["error"]
api_logger.error(f"用户信息更新失败: end_user_id={end_user_id}, error={error_msg}")
# 根据错误类型映射到合适的业务错误码
if error_msg == "终端用户不存在":
return fail(BizCode.USER_NOT_FOUND, "终端用户不存在", error_msg)
elif error_msg == "无效的用户ID格式":
return fail(BizCode.INVALID_USER_ID, "无效的用户ID格式", error_msg)
else:
# 只有未预期的错误才使用 INTERNAL_ERROR
return fail(BizCode.INTERNAL_ERROR, "用户信息更新失败", error_msg)
@router.get("/memory_space/timeline_memories", response_model=ApiResponse)
async def memory_space_timeline_of_shared_memories(id: str, label: str,
async def memory_space_timeline_of_shared_memories(id: str, label: str,language_type: str="zh",
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
):
workspace_id=current_user.current_workspace_id
workspace_repo = WorkspaceRepository(db)
workspace_models = workspace_repo.get_workspace_models_configs(workspace_id)
if workspace_models:
model_id = workspace_models.get("llm", None)
else:
model_id = None
MemoryEntity = MemoryEntityService(id, label)
timeline_memories_result = await MemoryEntity.get_timeline_memories_server()
timeline_memories_result = await MemoryEntity.get_timeline_memories_server(model_id, language_type)
return success(data=timeline_memories_result, msg="共同记忆时间线")
@router.get("/memory_space/relationship_evolution", response_model=ApiResponse)
async def memory_space_relationship_evolution(id: str, label: str,

View File

@@ -145,44 +145,98 @@ class LangChainAgent:
messages.append(HumanMessage(content=user_content))
return messages
async def term_memory_save(self,messages,end_user_end,aimessages):
'''短长期存储redis为不影响正常使用6句一段话存储用户名加一个前缀当数据存够6条返回给neo4j'''
end_user_end=f"Term_{end_user_end}"
print(messages)
print(aimessages)
session_id = store.save_session(
userid=end_user_end,
messages=messages,
apply_id=end_user_end,
group_id=end_user_end,
aimessages=aimessages
)
store.delete_duplicate_sessions()
# logger.info(f'Redis_Agent:{end_user_end};{session_id}')
return session_id
async def term_memory_redis_read(self,end_user_end):
end_user_end = f"Term_{end_user_end}"
history = store.find_user_apply_group(end_user_end, end_user_end, end_user_end)
# logger.info(f'Redis_Agent:{end_user_end};{history}')
messagss_list=[]
retrieved_content=[]
for messages in history:
query = messages.get("Query")
aimessages = messages.get("Answer")
messagss_list.append(f'用户:{query}。AI回复:{aimessages}')
retrieved_content.append({query: aimessages})
return messagss_list,retrieved_content
# TODO 乐力齐 - 累积多组对话批量写入功能已禁用
# async def term_memory_save(self,messages,end_user_end,aimessages):
# '''短长期存储redis为不影响正常使用6句一段话存储用户名加一个前缀当数据存够6条返回给neo4j'''
# end_user_end=f"Term_{end_user_end}"
# print(messages)
# print(aimessages)
# session_id = store.save_session(
# userid=end_user_end,
# messages=messages,
# apply_id=end_user_end,
# group_id=end_user_end,
# aimessages=aimessages
# )
# store.delete_duplicate_sessions()
# # logger.info(f'Redis_Agent:{end_user_end};{session_id}')
# return session_id
# TODO 乐力齐 - 累积多组对话批量写入功能已禁用
# async def term_memory_redis_read(self,end_user_end):
# end_user_end = f"Term_{end_user_end}"
# history = store.find_user_apply_group(end_user_end, end_user_end, end_user_end)
# # logger.info(f'Redis_Agent:{end_user_end};{history}')
# messagss_list=[]
# retrieved_content=[]
# for messages in history:
# query = messages.get("Query")
# aimessages = messages.get("Answer")
# messagss_list.append(f'用户:{query}。AI回复:{aimessages}')
# retrieved_content.append({query: aimessages})
# return messagss_list,retrieved_content
async def write(self,storage_type,end_user_id,message,user_rag_memory_id,actual_end_user_id,content,actual_config_id):
async def write(self, storage_type, end_user_id, user_message, ai_message, user_rag_memory_id, actual_end_user_id, actual_config_id):
"""
写入记忆(支持结构化消息)
Args:
storage_type: 存储类型 (neo4j/rag)
end_user_id: 终端用户ID
user_message: 用户消息内容
ai_message: AI 回复内容
user_rag_memory_id: RAG 记忆ID
actual_end_user_id: 实际用户ID
actual_config_id: 配置ID
逻辑说明:
- RAG 模式:组合 user_message 和 ai_message 为字符串格式,保持原有逻辑不变
- Neo4j 模式:使用结构化消息列表
1. 如果 user_message 和 ai_message 都不为空:创建配对消息 [user, assistant]
2. 如果只有 user_message创建单条用户消息 [user](用于历史记忆场景)
3. 每条消息会被转换为独立的 Chunk保留 speaker 字段
"""
if storage_type == "rag":
await write_rag(end_user_id, message, user_rag_memory_id)
# RAG 模式:组合消息为字符串格式(保持原有逻辑)
combined_message = f"user: {user_message}\nassistant: {ai_message}"
await write_rag(end_user_id, combined_message, user_rag_memory_id)
logger.info(f'RAG_Agent:{end_user_id};{user_rag_memory_id}')
else:
write_id = write_message_task.delay(actual_end_user_id, content, actual_config_id, storage_type,
user_rag_memory_id)
# Neo4j 模式:使用结构化消息列表
structured_messages = []
# 始终添加用户消息(如果不为空)
if user_message:
structured_messages.append({"role": "user", "content": user_message})
# 只有当 AI 回复不为空时才添加 assistant 消息
if ai_message:
structured_messages.append({"role": "assistant", "content": ai_message})
# 如果没有消息,直接返回
if not structured_messages:
logger.warning(f"No messages to write for user {actual_end_user_id}")
return
# 调用 Celery 任务,传递结构化消息列表
# 数据流:
# 1. structured_messages 传递给 write_message_task
# 2. write_message_task 调用 memory_agent_service.write_memory
# 3. write_memory 调用 write_tools.write传递 messages 参数
# 4. write_tools.write 调用 get_chunked_dialogs传递 messages 参数
# 5. get_chunked_dialogs 为每条消息创建独立的 Chunk设置 speaker 字段
# 6. 每个 Chunk 保存到 Neo4j包含 speaker 字段
logger.info(f"[WRITE] Submitting Celery task - user={actual_end_user_id}, messages={len(structured_messages)}, config={actual_config_id}")
write_id = write_message_task.delay(
actual_end_user_id, # group_id: 用户ID
structured_messages, # message: 结构化消息列表 [{"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]
actual_config_id, # config_id: 配置ID
storage_type, # storage_type: "neo4j"
user_rag_memory_id # user_rag_memory_id: RAG记忆IDNeo4j模式下不使用
)
logger.info(f"[WRITE] Celery task submitted - task_id={write_id}")
write_status = get_task_memory_write_result(str(write_id))
logger.info(f'Agent:{actual_end_user_id};{write_status}')
logger.info(f'[WRITE] Task result - user={actual_end_user_id}, status={write_status}')
async def chat(
self,
@@ -227,29 +281,30 @@ class LangChainAgent:
actual_end_user_id = end_user_id if end_user_id is not None else "unknown"
logger.info(f'写入类型{storage_type,str(end_user_id), message, str(user_rag_memory_id)}')
print(f'写入类型{storage_type,str(end_user_id), message, str(user_rag_memory_id)}')
# # TODO 乐力齐,在长短期记忆存储的时候再使用此代码
# history_term_memory_result = await self.term_memory_redis_read(end_user_id)
# history_term_memory = history_term_memory_result[0]
# db_for_memory = next(get_db())
# if memory_flag:
# if len(history_term_memory)>=4 and storage_type != "rag":
# history_term_memory = ';'.join(history_term_memory)
# retrieved_content = history_term_memory_result[1]
# print(retrieved_content)
# # 为长期记忆操作获取新的数据库连接
# try:
# repo = LongTermMemoryRepository(db_for_memory)
# repo.upsert(end_user_id, retrieved_content)
# logger.info(
# f'写入短长期:{storage_type, str(end_user_id), history_term_memory, str(user_rag_memory_id)}')
# except Exception as e:
# logger.error(f"Failed to write to LongTermMemory: {e}")
# raise
# finally:
# db_for_memory.close()
history_term_memory_result = await self.term_memory_redis_read(end_user_id)
history_term_memory = history_term_memory_result[0]
db_for_memory = next(get_db())
if memory_flag:
if len(history_term_memory)>=4 and storage_type != "rag":
history_term_memory = ';'.join(history_term_memory)
retrieved_content = history_term_memory_result[1]
print(retrieved_content)
# 为长期记忆操作获取新的数据库连接
try:
repo = LongTermMemoryRepository(db_for_memory)
repo.upsert(end_user_id, retrieved_content)
logger.info(
f'写入短长期:{storage_type, str(end_user_id), history_term_memory, str(user_rag_memory_id)}')
except Exception as e:
logger.error(f"Failed to write to LongTermMemory: {e}")
raise
finally:
db_for_memory.close()
await self.write(storage_type,end_user_id,history_term_memory,user_rag_memory_id,actual_end_user_id,history_term_memory,actual_config_id)
await self.write(storage_type,end_user_id,message,user_rag_memory_id,actual_end_user_id,message,actual_config_id)
# # 长期记忆写入(
# await self.write(storage_type, actual_end_user_id, history_term_memory, "", user_rag_memory_id, actual_end_user_id, actual_config_id)
# # 注意:不在这里写入用户消息,等 AI 回复后一起写入
try:
# 准备消息列表
messages = self._prepare_messages(message, history, context)
@@ -277,8 +332,10 @@ class LangChainAgent:
elapsed_time = time.time() - start_time
if memory_flag:
await self.write(storage_type,end_user_id,content,user_rag_memory_id,actual_end_user_id,content,actual_config_id)
await self.term_memory_save(message_chat,end_user_id,content)
# AI 回复写入(用户消息和 AI 回复配对,一次性写入完整对话)
await self.write(storage_type, actual_end_user_id, message_chat, content, user_rag_memory_id, actual_end_user_id, actual_config_id)
# TODO 乐力齐 - 累积多组对话批量写入功能已禁用
# await self.term_memory_save(message_chat, end_user_id, content)
response = {
"content": content,
"model": self.model_name,
@@ -346,27 +403,27 @@ class LangChainAgent:
db.close()
except Exception as e:
logger.warning(f"Failed to get db session: {e}")
# # TODO 乐力齐
# history_term_memory_result = await self.term_memory_redis_read(end_user_id)
# history_term_memory = history_term_memory_result[0]
# if memory_flag:
# if len(history_term_memory) >= 4 and storage_type != "rag":
# history_term_memory = ';'.join(history_term_memory)
# retrieved_content = history_term_memory_result[1]
# db_for_memory = next(get_db())
# try:
# repo = LongTermMemoryRepository(db_for_memory)
# repo.upsert(end_user_id, retrieved_content)
# logger.info(
# f'写入短长期:{storage_type, str(end_user_id), history_term_memory, str(user_rag_memory_id)}')
# # 长期记忆写入
# await self.write(storage_type, end_user_id, history_term_memory, "", user_rag_memory_id, end_user_id, actual_config_id)
# except Exception as e:
# logger.error(f"Failed to write to long term memory: {e}")
# finally:
# db_for_memory.close()
history_term_memory_result = await self.term_memory_redis_read(end_user_id)
history_term_memory = history_term_memory_result[0]
if memory_flag:
if len(history_term_memory) >= 4 and storage_type != "rag":
history_term_memory = ';'.join(history_term_memory)
retrieved_content = history_term_memory_result[1]
db_for_memory = next(get_db())
try:
repo = LongTermMemoryRepository(db_for_memory)
repo.upsert(end_user_id, retrieved_content)
logger.info(
f'写入短长期:{storage_type, str(end_user_id), history_term_memory, str(user_rag_memory_id)}')
await self.write(storage_type, end_user_id, history_term_memory, user_rag_memory_id, end_user_id,
history_term_memory, actual_config_id)
except Exception as e:
logger.error(f"Failed to write to long term memory: {e}")
finally:
db_for_memory.close()
await self.write(storage_type, end_user_id, message, user_rag_memory_id, end_user_id, message, actual_config_id)
# 注意:不在这里写入用户消息,等 AI 回复后一起写入
try:
# 准备消息列表
messages = self._prepare_messages(message, history, context)
@@ -418,8 +475,10 @@ class LangChainAgent:
logger.debug(f"Agent 流式完成,共 {chunk_count} 个事件")
if memory_flag:
await self.write(storage_type, end_user_id,full_content, user_rag_memory_id, end_user_id,full_content, actual_config_id)
await self.term_memory_save(message_chat, end_user_id, full_content)
# AI 回复写入(用户消息和 AI 回复配对,一次性写入完整对话)
await self.write(storage_type, end_user_id, message_chat, full_content, user_rag_memory_id, end_user_id, actual_config_id)
# TODO 乐力齐 - 累积多组对话批量写入功能已禁用
# await self.term_memory_save(message_chat, end_user_id, full_content)
except Exception as e:
logger.error(f"Agent astream_events 失败: {str(e)}", exc_info=True)

View File

@@ -18,16 +18,19 @@ template_root = os.path.join(PROJECT_ROOT_, 'agent', 'utils', 'prompt')
db_session = next(get_db())
logger = get_agent_logger(__name__)
class ProblemNodeService(LLMServiceMixin):
"""问题处理节点服务类"""
def __init__(self):
super().__init__()
self.template_service = TemplateService(template_root)
# 创建全局服务实例
problem_service = ProblemNodeService()
async def Split_The_Problem(state: ReadState) -> ReadState:
"""问题分解节点"""
# 从状态中获取数据
@@ -36,10 +39,10 @@ async def Split_The_Problem(state: ReadState) -> ReadState:
memory_config = state.get('memory_config', None)
history = await SessionService(store).get_history(group_id, group_id, group_id)
# 生成 JSON schema 以指导 LLM 输出正确格式
json_schema = ProblemExtensionResponse.model_json_schema()
system_prompt = await problem_service.template_service.render_template(
template_name='problem_breakdown_prompt.jinja2',
operation_name='split_the_problem',
@@ -47,7 +50,7 @@ async def Split_The_Problem(state: ReadState) -> ReadState:
sentence=content,
json_schema=json_schema
)
try:
# 使用优化的LLM服务
structured = await problem_service.call_llm_structured(
@@ -57,10 +60,10 @@ async def Split_The_Problem(state: ReadState) -> ReadState:
response_model=ProblemExtensionResponse,
fallback_value=[]
)
# 添加更详细的日志记录
logger.info(f"Split_The_Problem: 开始处理问题分解,内容长度: {len(content)}")
# 验证结构化响应
if not structured or not hasattr(structured, 'root'):
logger.warning("Split_The_Problem: 结构化响应为空或格式不正确")
@@ -73,17 +76,17 @@ async def Split_The_Problem(state: ReadState) -> ReadState:
[item.model_dump() for item in structured.root],
ensure_ascii=False
)
split_result_dict = []
for index, item in enumerate(json.loads(split_result)):
split_data = {
"id": f"Q{index+1}",
"id": f"Q{index + 1}",
"question": item['extended_question'],
"type": item['type'],
"reason": item['reason']
}
split_result_dict.append(split_data)
logger.info(f"Split_The_Problem: 成功生成 {len(structured.root) if structured.root else 0} 个分解项")
result = {
@@ -96,13 +99,13 @@ async def Split_The_Problem(state: ReadState) -> ReadState:
"original_query": content
}
}
except Exception as e:
logger.error(
f"Split_The_Problem failed: {e}",
exc_info=True
)
# 提供更详细的错误信息
error_details = {
"error_type": type(e).__name__,
@@ -110,9 +113,9 @@ async def Split_The_Problem(state: ReadState) -> ReadState:
"content_length": len(content),
"llm_model_id": memory_config.llm_model_id if memory_config else None
}
logger.error(f"Split_The_Problem error details: {error_details}")
# 创建默认的空结果
result = {
"context": json.dumps([], ensure_ascii=False),
@@ -126,10 +129,11 @@ async def Split_The_Problem(state: ReadState) -> ReadState:
"error": error_details
}
}
# 返回更新后的状态包含spit_context字段
return {"spit_data": result}
async def Problem_Extension(state: ReadState) -> ReadState:
"""问题扩展节点"""
# 获取原始数据和分解结果
@@ -153,10 +157,10 @@ async def Problem_Extension(state: ReadState) -> ReadState:
data = []
history = await SessionService(store).get_history(group_id, group_id, group_id)
# 生成 JSON schema 以指导 LLM 输出正确格式
json_schema = ProblemExtensionResponse.model_json_schema()
system_prompt = await problem_service.template_service.render_template(
template_name='Problem_Extension_prompt.jinja2',
operation_name='problem_extension',
@@ -242,7 +246,4 @@ async def Problem_Extension(state: ReadState) -> ReadState:
}
}
return {"problem_extension": result}
return {"problem_extension": result}

View File

@@ -4,12 +4,11 @@ import os
import time
from app.core.logging_config import get_agent_logger, log_time
from app.db import get_db
from app.core.memory.agent.models.summary_models import (
RetrieveSummaryResponse,
SummaryResponse,
)
from app.core.memory.agent.services.optimized_llm_service import LLMServiceMixin
from app.core.memory.agent.services.search_service import SearchService
from app.core.memory.agent.utils.llm_tools import (
PROJECT_ROOT_,
@@ -18,7 +17,7 @@ from app.core.memory.agent.utils.llm_tools import (
from app.core.memory.agent.utils.redis_tool import store
from app.core.memory.agent.utils.session_tools import SessionService
from app.core.memory.agent.utils.template_tools import TemplateService
from app.core.memory.agent.services.optimized_llm_service import LLMServiceMixin
from app.db import get_db
template_root = os.path.join(PROJECT_ROOT_, 'agent', 'utils', 'prompt')
logger = get_agent_logger(__name__)
@@ -182,7 +181,8 @@ async def Input_Summary(state: ReadState) -> ReadState:
search_params = {
"group_id": group_id,
"question": data,
"return_raw_results": True
"return_raw_results": True,
"include": ["summaries"] # Only search summary nodes for faster performance
}
try:

View File

@@ -9,22 +9,29 @@ async def write_node(state: WriteState) -> WriteState:
Write data to the database/file system.
Args:
ctx: FastMCP context for dependency injection
content: Data content to write
user_id: User identifier
apply_id: Application identifier
group_id: Group identifier
memory_config: MemoryConfig object containing all configuration
state: WriteState containing messages, group_id, and memory_config
Returns:
dict: Contains 'status', 'saved_to', and 'data' fields
dict: Contains 'write_result' with status and data fields
"""
content=state.get('data','')
group_id=state.get('group_id','')
memory_config=state.get('memory_config', '')
messages = state.get('messages', [])
group_id = state.get('group_id', '')
memory_config = state.get('memory_config', '')
# Convert LangChain messages to structured format expected by write()
structured_messages = []
for msg in messages:
if hasattr(msg, 'type') and hasattr(msg, 'content'):
# Map LangChain message types to role names
role = 'user' if msg.type == 'human' else 'assistant' if msg.type == 'ai' else msg.type
structured_messages.append({
"role": role,
"content": msg.content # content is now guaranteed to be a string
})
try:
result=await write(
content=content,
result = await write(
messages=structured_messages,
user_id=group_id,
apply_id=group_id,
group_id=group_id,
@@ -32,18 +39,17 @@ async def write_node(state: WriteState) -> WriteState:
)
logger.info(f"Write completed successfully! Config: {memory_config.config_name}")
write_result= {
write_result = {
"status": "success",
"data": content,
"data": structured_messages,
"config_id": memory_config.config_id,
"config_name": memory_config.config_name,
}
return {"write_result":write_result}
return {"write_result": write_result}
except Exception as e:
logger.error(f"Data_write failed: {e}", exc_info=True)
write_result= {
write_result = {
"status": "error",
"message": str(e),
}

View File

@@ -59,7 +59,6 @@ async def make_read_graph():
workflow.add_conditional_edges("Retrieve", Retrieve_continue)
workflow.add_edge("Retrieve_Summary", END)
workflow.add_conditional_edges("Verify", Verify_continue)
workflow.add_edge("Summary_fails", END)
workflow.add_edge("Summary", END)

View File

@@ -14,7 +14,6 @@ from app.db import get_db
from app.core.logging_config import get_agent_logger
from app.core.memory.agent.utils.llm_tools import WriteState
from app.core.memory.agent.langgraph_graph.nodes.write_nodes import write_node
from app.core.memory.agent.langgraph_graph.nodes.data_nodes import content_input_write
from app.services.memory_config_service import MemoryConfigService
warnings.filterwarnings("ignore", category=RuntimeWarning)
@@ -27,18 +26,12 @@ async def make_write_graph():
"""
Create a write graph workflow for memory operations.
Args:
user_id: User identifier
tools: MCP tools loaded from session
apply_id: Application identifier
group_id: Group identifier
memory_config: MemoryConfig object containing all configuration
The workflow directly processes messages from the initial state
and saves them to Neo4j storage.
"""
workflow = StateGraph(WriteState)
workflow.add_node("content_input", content_input_write)
workflow.add_node("save_neo4j", write_node)
workflow.add_edge(START, "content_input")
workflow.add_edge("content_input", "save_neo4j")
workflow.add_edge(START, "save_neo4j")
workflow.add_edge("save_neo4j", END)
graph = workflow.compile()

View File

@@ -162,7 +162,7 @@ class OptimizedLLMService:
return fallback_value
elif isinstance(fallback_value, dict):
return response_model(**fallback_value)
# 尝试创建空的响应模型
if hasattr(response_model, 'root'):
# RootModel类型
@@ -170,7 +170,7 @@ class OptimizedLLMService:
else:
# 普通BaseModel类型
return response_model()
except Exception as e:
logger.error(f"创建降级响应失败: {e}")
# 最后的降级策略

View File

@@ -12,32 +12,49 @@ async def get_chunked_dialogs(
group_id: str = "group_1",
user_id: str = "user1",
apply_id: str = "applyid",
content: str = "这是用户的输入",
messages: list = None,
ref_id: str = "wyl_20251027",
config_id: str = None
) -> List[DialogData]:
"""Generate chunks from all test data entries using the specified chunker strategy.
"""Generate chunks from structured messages using the specified chunker strategy.
Args:
chunker_strategy: The chunking strategy to use (default: RecursiveChunker)
group_id: Group identifier
user_id: User identifier
apply_id: Application identifier
content: Dialog content
messages: Structured message list [{"role": "user", "content": "..."}, ...]
ref_id: Reference identifier
config_id: Configuration ID for processing
Returns:
List of DialogData objects with generated chunks for each test entry
List of DialogData objects with generated chunks
"""
dialog_data_list = []
messages = []
messages.append(ConversationMessage(role="用户", msg=content))
# Create DialogData
conversation_context = ConversationContext(msgs=messages)
# Create DialogData with group_id based on the entry's id for uniqueness
from app.core.logging_config import get_agent_logger
logger = get_agent_logger(__name__)
if not messages or not isinstance(messages, list) or len(messages) == 0:
raise ValueError("messages parameter must be a non-empty list")
conversation_messages = []
for idx, msg in enumerate(messages):
if not isinstance(msg, dict) or 'role' not in msg or 'content' not in msg:
raise ValueError(f"Message {idx} format error: must contain 'role' and 'content' fields")
role = msg['role']
content = msg['content']
if role not in ['user', 'assistant']:
raise ValueError(f"Message {idx} role must be 'user' or 'assistant', got: {role}")
if content.strip():
conversation_messages.append(ConversationMessage(role=role, msg=content.strip()))
if not conversation_messages:
raise ValueError("Message list cannot be empty after filtering")
conversation_context = ConversationContext(msgs=conversation_messages)
dialog_data = DialogData(
context=conversation_context,
ref_id=ref_id,
@@ -46,25 +63,11 @@ async def get_chunked_dialogs(
apply_id=apply_id,
config_id=config_id
)
# Create DialogueChunker and process the dialogue
chunker = DialogueChunker(chunker_strategy)
extracted_chunks = await chunker.process_dialogue(dialog_data)
dialog_data.chunks = extracted_chunks
logger.info(f"DialogData created with {len(extracted_chunks)} chunks")
dialog_data_list.append(dialog_data)
# Convert to dict with datetime serialized
def serialize_datetime(obj):
if isinstance(obj, datetime):
return obj.isoformat()
raise TypeError(f"Object of type {obj.__class__.__name__} is not JSON serializable")
combined_output = [dd.model_dump() for dd in dialog_data_list]
print(dialog_data_list)
# with open(os.path.join(os.path.dirname(__file__), "chunker_test_output.txt"), "w", encoding="utf-8") as f:
# json.dump(combined_output, f, ensure_ascii=False, indent=4, default=serialize_datetime)
return dialog_data_list
return [dialog_data]

View File

@@ -29,25 +29,22 @@ logger = get_agent_logger(__name__)
async def write(
content: str,
user_id: str,
apply_id: str,
group_id: str,
memory_config: MemoryConfig,
messages: list,
ref_id: str = "wyl20251027",
) -> None:
"""
Execute the complete knowledge extraction pipeline.
Only MemoryConfig is needed - LLM and embedding clients are constructed
internally from the config.
Args:
content: Dialogue content to process
user_id: User identifier
apply_id: Application identifier
group_id: Group identifier
memory_config: MemoryConfig object containing all configuration
messages: Structured message list [{"role": "user", "content": "..."}, ...]
ref_id: Reference ID, defaults to "wyl20251027"
"""
# Extract config values
@@ -89,7 +86,7 @@ async def write(
group_id=group_id,
user_id=user_id,
apply_id=apply_id,
content=content,
messages=messages,
ref_id=ref_id,
config_id=config_id,
)

View File

@@ -4,6 +4,7 @@ import os
import asyncio
import json
import numpy as np
import logging
# Fix tokenizer parallelism warning
os.environ["TOKENIZERS_PARALLELISM"] = "false"
@@ -23,28 +24,29 @@ from app.core.memory.models.message_models import DialogData, Chunk
try:
from app.core.memory.llm_tools.openai_client import OpenAIClient
except Exception:
# 在测试或无可用依赖(如 langfuse环境下允许惰性导入
OpenAIClient = Any
# Initialize logger
logger = logging.getLogger(__name__)
class LLMChunker:
"""基于LLM的智能分块策略"""
"""LLM-based intelligent chunking strategy"""
def __init__(self, llm_client: OpenAIClient, chunk_size: int = 1000):
self.llm_client = llm_client
self.chunk_size = chunk_size
async def __call__(self, text: str) -> List[Any]:
# 使用LLM分析文本结构并进行智能分块
prompt = f"""
请将以下文本分割成语义连贯的段落。每个段落应该围绕一个主题,长度大约在{self.chunk_size}字符左右。
请以JSON格式返回结果包含chunks数组每个chunk有text字段。
Split the following text into semantically coherent paragraphs. Each paragraph should focus on one topic, approximately {self.chunk_size} characters long.
Return results in JSON format with a chunks array, each chunk having a text field.
文本内容:
Text content:
{text[:5000]}
"""
messages = [
{"role": "system", "content": "你是一个专业的文本分析助手,擅长将长文本分割成语义连贯的段落。"},
{"role": "system", "content": "You are a professional text analysis assistant, skilled at splitting long texts into semantically coherent paragraphs."},
{"role": "user", "content": prompt}
]
@@ -171,8 +173,6 @@ class ChunkerClient:
base_chunk_size=self.chunk_size,
)
elif chunker_config.chunker_strategy == "SentenceChunker":
# 某些 chonkie 版本的 SentenceChunker 不支持 tokenizer_or_token_counter 参数
# 为了兼容不同版本,这里仅传递广泛支持的参数
self.chunker = SentenceChunker(
chunk_size=self.chunk_size,
chunk_overlap=self.chunk_overlap,
@@ -186,100 +186,93 @@ class ChunkerClient:
async def generate_chunks(self, dialogue: DialogData):
"""
生成分块,支持异步操作
Generate chunks following 1 Message = 1 Chunk strategy.
Each message creates one chunk, directly inheriting role information.
If a message is too long, it will be split into multiple sub-chunks,
each maintaining the same speaker.
Raises:
ValueError: If dialogue has no messages or chunking fails
"""
try:
# 预处理文本:确保对话标记格式统一
content = dialogue.content
content = content.replace('AI', 'AI:').replace('用户:', '用户:') # 统一冒号
content = re.sub(r'(\n\s*)+\n', '\n\n', content) # 合并多个空行
if hasattr(self.chunker, '__call__') and not asyncio.iscoroutinefunction(self.chunker.__call__):
# 同步分块器
chunks = self.chunker(content)
# Validate dialogue has messages
if not dialogue.context or not dialogue.context.msgs:
raise ValueError(
f"Dialogue {dialogue.ref_id} has no messages. "
f"Cannot generate chunks from empty dialogue."
)
dialogue.chunks = []
# 按消息分块:每个消息创建一个或多个 chunk直接继承角色
for msg_idx, msg in enumerate(dialogue.context.msgs):
# Validate message has required attributes
if not hasattr(msg, 'role') or not hasattr(msg, 'msg'):
raise ValueError(
f"Message {msg_idx} in dialogue {dialogue.ref_id} "
f"missing 'role' or 'msg' attribute"
)
msg_content = msg.msg.strip()
# Skip empty messages
if not msg_content:
continue
# 如果消息太长,可以进一步分块
if len(msg_content) > self.chunk_size:
# 对单个消息的内容进行分块
try:
sub_chunks = self.chunker(msg_content)
except Exception as e:
raise ValueError(
f"Failed to chunk long message {msg_idx} in dialogue {dialogue.ref_id}: {e}"
)
for idx, sub_chunk in enumerate(sub_chunks):
sub_chunk_text = sub_chunk.text if hasattr(sub_chunk, 'text') else str(sub_chunk)
sub_chunk_text = sub_chunk_text.strip()
if len(sub_chunk_text) < (self.min_characters_per_chunk or 50):
continue
chunk = Chunk(
content=f"{msg.role}: {sub_chunk_text}",
speaker=msg.role, # 直接继承角色
metadata={
"message_index": msg_idx,
"message_role": msg.role,
"sub_chunk_index": idx,
"total_sub_chunks": len(sub_chunks),
"chunker_strategy": self.chunker_config.chunker_strategy,
},
)
dialogue.chunks.append(chunk)
else:
# 异步分块器如LLMChunker
chunks = await self.chunker(content)
# 过滤空块和过小的块
valid_chunks = []
for c in chunks:
chunk_text = getattr(c, 'text', str(c)) if not isinstance(c, str) else c
if isinstance(chunk_text, str) and len(chunk_text.strip()) >= (self.min_characters_per_chunk or 50):
valid_chunks.append(c)
dialogue.chunks = [
Chunk(
content=c.text if hasattr(c, 'text') else str(c),
# 消息不长,直接作为一个 chunk
chunk = Chunk(
content=f"{msg.role}: {msg_content}",
speaker=msg.role, # 直接继承角色
metadata={
"start_index": getattr(c, "start_index", None),
"end_index": getattr(c, "end_index", None),
"message_index": msg_idx,
"message_role": msg.role,
"chunker_strategy": self.chunker_config.chunker_strategy,
},
)
for c in valid_chunks
]
return dialogue
except Exception as e:
print(f"分块失败: {e}")
# 改进的后备方案:尝试按对话回合分割
try:
# 简单的按对话分割
dialogue_pattern = r'(AI:|用户:)(.*?)(?=AI:|用户:|$)'
matches = re.findall(dialogue_pattern, dialogue.content, re.DOTALL)
class SimpleChunk:
def __init__(self, text, start_index, end_index):
self.text = text
self.start_index = start_index
self.end_index = end_index
chunks = []
current_chunk = ""
current_start = 0
for match in matches:
speaker, ct = match[0], match[1].strip()
turn_text = f"{speaker} {ct}"
if len(current_chunk) + len(turn_text) > (self.chunk_size or 500):
if current_chunk:
chunks.append(SimpleChunk(current_chunk, current_start, current_start + len(current_chunk)))
current_chunk = turn_text
current_start = dialogue.content.find(turn_text, current_start)
else:
current_chunk += ("\n" + turn_text) if current_chunk else turn_text
if current_chunk:
chunks.append(SimpleChunk(current_chunk, current_start, current_start + len(current_chunk)))
dialogue.chunks = [
Chunk(
content=c.text,
metadata={
"start_index": c.start_index,
"end_index": c.end_index,
"chunker_strategy": "DialogueTurnFallback",
},
)
for c in chunks
]
except Exception:
# 最后的手段:单一大块
dialogue.chunks = [Chunk(
content=dialogue.content,
metadata={"chunker_strategy": "SingleChunkFallback"},
)]
return dialogue
dialogue.chunks.append(chunk)
# Validate we generated at least one chunk
if not dialogue.chunks:
raise ValueError(
f"No valid chunks generated for dialogue {dialogue.ref_id}. "
f"All messages were either empty or too short. "
f"Messages count: {len(dialogue.context.msgs)}"
)
return dialogue
def evaluate_chunking(self, dialogue: DialogData) -> dict:
"""
评估分块质量
"""
"""Evaluate chunking quality."""
if not getattr(dialogue, 'chunks', None):
return {}
@@ -304,11 +297,8 @@ class ChunkerClient:
return metrics
def save_chunking_results(self, dialogue: DialogData, output_path: str):
"""
保存分块结果到文件,文件名包含策略名称
"""
"""Save chunking results to file with strategy name in filename."""
strategy_name = self.chunker_config.chunker_strategy
# 在文件名中添加策略名称
base_name, ext = os.path.splitext(output_path)
strategy_output_path = f"{base_name}_{strategy_name}{ext}"

View File

@@ -92,8 +92,6 @@ class OpenAIClient(LLMClient):
config["callbacks"] = [self.langfuse_handler]
response = await chain.ainvoke({"messages": messages}, config=config)
logger.debug(f"LLM 响应成功: {len(str(response))} 字符")
return response
except Exception as e:
@@ -149,13 +147,10 @@ class OpenAIClient(LLMClient):
config=config
)
logger.debug(f"使用 PydanticOutputParser 解析成功")
return parsed
except Exception as e:
logger.warning(
f"PydanticOutputParser 解析失败,尝试其他方法: {e}"
)
logger.debug(f"PydanticOutputParser 解析失败,尝试备用方法: {e}")
# 方法 2: 使用 LangChain 的 with_structured_output
template = """{question}"""
@@ -173,13 +168,17 @@ class OpenAIClient(LLMClient):
# 验证并返回结果
try:
return response_model.model_validate(parsed)
result = response_model.model_validate(parsed)
return result
except Exception:
# 如果已经是 Pydantic 实例,直接返回
if hasattr(parsed, "model_dump"):
return parsed
# 尝试从 JSON 解析
return response_model.model_validate_json(json.dumps(parsed))
result = response_model.model_validate_json(json.dumps(parsed))
return result
else:
logger.warning("with_structured_output 方法不可用")
except Exception as e:
logger.error(f"结构化输出失败: {e}")

View File

@@ -224,6 +224,7 @@ class StatementNode(Node):
chunk_id: ID of the parent chunk this statement belongs to
stmt_type: Type of the statement (from ontology)
statement: The actual statement text content
speaker: Optional speaker identifier ('用户' for user messages, 'AI' for AI responses)
emotion_intensity: Optional emotion intensity (0.0-1.0) - displayed on node
emotion_target: Optional emotion target (person or object name)
emotion_subject: Optional emotion subject (self/other/object)
@@ -249,6 +250,12 @@ class StatementNode(Node):
stmt_type: str = Field(..., description="Type of the statement")
statement: str = Field(..., description="The statement text content")
# Speaker identification
speaker: Optional[str] = Field(
None,
description="Speaker identifier: 'user' for user messages, 'assistant' for AI responses"
)
# Emotion fields (ordered as requested, emotion_intensity first for display)
emotion_intensity: Optional[float] = Field(
None,

View File

@@ -25,10 +25,10 @@ class ConversationMessage(BaseModel):
"""Represents a single message in a conversation.
Attributes:
role: Role of the speaker (e.g., '用户' for user, 'AI' for assistant)
role: Role of the speaker (e.g., 'user' for user, 'assistant' for AI assistant)
msg: Text content of the message
"""
role: str = Field(..., description="The role of the speaker (e.g., '用户', 'AI').")
role: str = Field(..., description="The role of the speaker (e.g., 'user', 'assistant').")
msg: str = Field(..., description="The text content of the message.")
@@ -57,6 +57,7 @@ class Statement(BaseModel):
chunk_id: ID of the parent chunk this statement belongs to
group_id: Optional group ID for multi-tenancy
statement: The actual statement text content
speaker: Optional speaker identifier ('用户' for user, 'AI' for AI responses)
statement_embedding: Optional embedding vector for the statement
stmt_type: Type of the statement (from ontology)
temporal_info: Temporal information extracted from the statement
@@ -74,6 +75,7 @@ class Statement(BaseModel):
chunk_id: str = Field(..., description="ID of the parent chunk this statement belongs to.")
group_id: Optional[str] = Field(None, description="ID of the group this statement belongs to.")
statement: str = Field(..., description="The text content of the statement.")
speaker: Optional[str] = Field(None, description="Speaker identifier: 'user' for user messages, 'assistant' for AI responses")
statement_embedding: Optional[List[float]] = Field(None, description="The embedding vector of the statement.")
stmt_type: StatementType = Field(..., description="The type of the statement.")
temporal_info: TemporalInfo = Field(..., description="The temporal information of the statement.")
@@ -118,36 +120,36 @@ class Chunk(BaseModel):
Attributes:
id: Unique identifier for the chunk
text: List of messages in the chunk
content: The content of the chunk as a formatted string
speaker: The speaker/role for this chunk (user/assistant)
statements: List of statements extracted from this chunk
chunk_embedding: Optional embedding vector for the chunk
metadata: Additional metadata as key-value pairs
"""
id: str = Field(default_factory=lambda: uuid4().hex, description="A unique identifier for the chunk.")
text: List[ConversationMessage] = Field(default_factory=list, description="A list of messages in the chunk.")
content: str = Field(..., description="The content of the chunk as a string.")
speaker: Optional[str] = Field(None, description="The speaker/role for this chunk (user/assistant).")
statements: List[Statement] = Field(default_factory=list, description="A list of statements in the chunk.")
chunk_embedding: Optional[List[float]] = Field(None, description="The embedding vector of the chunk.")
metadata: Dict[str, Any] = Field(default_factory=dict, description="Additional metadata for the chunk.")
@classmethod
def from_messages(cls, messages: List[ConversationMessage], metadata: Optional[Dict[str, Any]] = None):
"""Create a chunk from a list of messages.
def from_single_message(cls, message: ConversationMessage, metadata: Optional[Dict[str, Any]] = None):
"""Create a chunk from a single message (1 Message = 1 Chunk).
Args:
messages: List of conversation messages
message: Single conversation message
metadata: Optional metadata dictionary
Returns:
Chunk instance with formatted content
Chunk instance with speaker directly from message.role
"""
if metadata is None:
metadata = {}
# Generate content from messages
content = "\n".join([f"{msg.role}: {msg.msg}" for msg in messages])
return cls(text=messages, content=content, metadata=metadata)
return cls(
content=f"{message.role}: {message.msg}",
speaker=message.role,
metadata=metadata or {}
)
class DialogData(BaseModel):
"""Represents the complete data structure for a dialog record.

View File

@@ -550,7 +550,7 @@ class ExtractionOrchestrator:
self, dialog_data_list: List[DialogData]
) -> List[Dict[str, Any]]:
"""
从对话中提取情绪信息(优化版:全局陈述句级并行)
从对话中提取情绪信息(仅针对用户消息,全局陈述句级并行)
Args:
dialog_data_list: 对话数据列表
@@ -558,7 +558,7 @@ class ExtractionOrchestrator:
Returns:
情绪信息映射列表,每个对话对应一个字典
"""
logger.info("开始情绪信息提取(全局陈述句级并行")
logger.info("开始情绪信息提取(仅处理用户消息")
# 收集所有陈述句及其配置
all_statements = []
@@ -597,15 +597,22 @@ class ExtractionOrchestrator:
if not data_config or not data_config.emotion_enabled:
logger.info("情绪提取未启用,跳过")
return [{} for _ in dialog_data_list]
# 收集所有陈述句(只收集 speaker 为 "user" 的)
total_statements = 0
filtered_statements = 0
# 收集所有陈述句
for d_idx, dialog in enumerate(dialog_data_list):
for chunk in dialog.chunks:
for statement in chunk.statements:
all_statements.append((statement, data_config))
statement_metadata.append((d_idx, statement.id))
total_statements += 1
# 只处理用户的陈述句 (role 为 "user")
if hasattr(statement, 'speaker') and statement.speaker == "user":
all_statements.append((statement, data_config))
statement_metadata.append((d_idx, statement.id))
filtered_statements += 1
logger.info(f"收集到 {len(all_statements)} 个陈述句,开始全局并行提取情绪")
logger.info(f"总陈述句: {total_statements}, 用户陈述句: {filtered_statements}, 开始全局并行提取情绪")
# 初始化情绪提取服务
from app.services.emotion_extraction_service import EmotionExtractionService
@@ -1033,6 +1040,7 @@ class ExtractionOrchestrator:
apply_id=dialog_data.apply_id,
run_id=dialog_data.run_id, # 使用 dialog_data 的 run_id
statement=statement.statement,
speaker=getattr(statement, 'speaker', None), # 添加 speaker 字段
statement_embedding=statement.statement_embedding,
valid_at=statement.temporal_validity.valid_at if hasattr(statement, 'temporal_validity') and statement.temporal_validity else None,
invalid_at=statement.temporal_validity.invalid_at if hasattr(statement, 'temporal_validity') and statement.temporal_validity else None,

View File

@@ -22,12 +22,12 @@ class DialogueChunker:
Args:
chunker_strategy: The chunking strategy to use (default: RecursiveChunker)
Options include: SemanticChunker, RecursiveChunker, LateChunker, NeuralChunker
Options: SemanticChunker, RecursiveChunker, LateChunker, NeuralChunker
"""
self.chunker_strategy = chunker_strategy
chunker_config_dict = get_chunker_config(chunker_strategy)
self.chunker_config = ChunkerConfig.model_validate(chunker_config_dict)
# 对于 LLMChunker需要传入 llm_client
if self.chunker_config.chunker_strategy == "LLMChunker":
self.chunker_client = ChunkerClient(self.chunker_config, llm_client)
else:
@@ -41,29 +41,19 @@ class DialogueChunker:
Returns:
A list of Chunk objects
Raises:
ValueError: If chunking fails or returns empty chunks
"""
result_dialogue = await self.chunker_client.generate_chunks(dialogue)
# Defensive fallback: ensure at least one chunk is returned for non-empty content
try:
chunks = result_dialogue.chunks
except Exception:
chunks = []
chunks = result_dialogue.chunks
if not chunks or len(chunks) == 0:
# If the dialogue has content, return a single fallback chunk built from messages
content_str = getattr(result_dialogue, "content", "") or getattr(dialogue, "content", "")
if content_str and len(content_str.strip()) > 0:
fallback_chunk = Chunk.from_messages(
dialogue.context.msgs,
metadata={
"fallback": "single_chunk",
"chunker_strategy": self.chunker_config.chunker_strategy,
"source": "DialogueChunkerFallback",
},
)
return [fallback_chunk]
# No content: return empty list
return []
raise ValueError(
f"Chunking failed: No chunks generated for dialogue {dialogue.ref_id}. "
f"Messages: {len(dialogue.context.msgs) if dialogue.context else 0}, "
f"Strategy: {self.chunker_config.chunker_strategy}"
)
return chunks
@@ -72,22 +62,25 @@ class DialogueChunker:
Args:
dialogue: The processed DialogData object with chunks
output_path: Optional path to save the output (default: chunker_output_{strategy}.txt)
output_path: Optional path to save the output
Returns:
The path where the output was saved
"""
if not output_path:
output_path = os.path.join(os.path.dirname(__file__), "..", "..",
f"chunker_output_{self.chunker_strategy.lower()}.txt")
output_path = os.path.join(
os.path.dirname(__file__), "..", "..",
f"chunker_output_{self.chunker_strategy.lower()}.txt"
)
output_lines = []
output_lines.append(f"=== Chunking Results ({self.chunker_strategy}) ===")
output_lines.append(f"Dialogue ID: {dialogue.ref_id}")
output_lines.append(f"Original conversation has {len(dialogue.context.msgs)} messages")
output_lines.append(f"Total characters: {len(dialogue.content)}")
output_lines.append(f"Generated {len(dialogue.chunks)} chunks:")
output_lines = [
f"=== Chunking Results ({self.chunker_strategy}) ===",
f"Dialogue ID: {dialogue.ref_id}",
f"Original conversation has {len(dialogue.context.msgs)} messages",
f"Total characters: {len(dialogue.content)}",
f"Generated {len(dialogue.chunks)} chunks:"
]
for i, chunk in enumerate(dialogue.chunks):
output_lines.append(f" Chunk {i+1}: {len(chunk.content)} characters")
output_lines.append(f" Content preview: {chunk.content}...")

View File

@@ -5,8 +5,6 @@ from datetime import datetime
from typing import Any, Dict, List, Optional
from app.core.memory.models.message_models import DialogData, Statement
#避免在测试收集阶段因为 OpenAIClient 间接引入 langfuse 导致 ModuleNotFoundError 。这只是类型注解与导入时机的调整,不改变实现。
from app.core.memory.models.variate_config import StatementExtractionConfig
from app.core.memory.utils.data.ontology import (
LABEL_DEFINITIONS,
@@ -22,11 +20,10 @@ logger = logging.getLogger(__name__)
class ExtractedStatement(BaseModel):
"""Schema for extracted statement from LLM"""
statement: str = Field(..., description="The extracted statement text")
statement_type: str = Field(..., description="FACT, OPINION,SUGGESTION or PREDICTION")
statement_type: str = Field(..., description="FACT, OPINION, SUGGESTION or PREDICTION")
temporal_type: str = Field(..., description="STATIC, DYNAMIC, ATEMPORAL")
relevence: str = Field(..., description="RELEVANT or IRRELEVANT")
# 统一使用 StatementExtractionResponse 作为 LLM 的结构化返回(仅语句)
class StatementExtractionResponse(BaseModel):
statements: List[ExtractedStatement] = Field(default_factory=list, description="List of extracted statements")
@@ -58,10 +55,9 @@ class StatementExtractionResponse(BaseModel):
return v
class StatementExtractor:
"""Class for extracting statements from dialog chunks using LLM (relations separated)"""
"""Class for extracting statements from dialog chunks using LLM"""
def __init__(self, llm_client: Any, config: StatementExtractionConfig = None):
# 避免在测试收集阶段因为 OpenAIClient 间接引入 langfuse 导致 ModuleNotFoundError 。这只是类型注解与导入时机的调整,不改变实现。
"""Initialize the StatementExtractor with an LLM client and configuration
Args:
@@ -71,6 +67,21 @@ class StatementExtractor:
self.llm_client = llm_client
self.config = config or StatementExtractionConfig()
def _get_speaker_from_chunk(self, chunk) -> Optional[str]:
"""Get speaker directly from Chunk
Args:
chunk: Chunk object containing speaker field
Returns:
Speaker role ("user"/"assistant") or None if cannot be determined
"""
if hasattr(chunk, 'speaker') and chunk.speaker:
return chunk.speaker
logger.warning(f"Chunk {getattr(chunk, 'id', 'unknown')} has no speaker field or is empty")
return None
async def _extract_statements(self, chunk, group_id: Optional[str] = None, dialogue_content: str = None) -> List[Statement]:
"""Process a single chunk and return extracted statements
@@ -82,10 +93,12 @@ class StatementExtractor:
Returns:
List of ExtractedStatement objects extracted from the chunk
"""
# Prepare the chunk content for processing
chunk_content = chunk.content
if not chunk_content or len(chunk_content.strip()) < 5:
logger.warning(f"Chunk {chunk.id} content too short or empty, skipping")
return []
# Render the prompt using helper function
prompt_content = await render_statement_extraction_prompt(
chunk_content=chunk_content,
definitions=LABEL_DEFINITIONS,
@@ -136,7 +149,9 @@ class StatementExtractor:
relevence_info = RelevenceInfo[relevence_str] if relevence_str in RelevenceInfo.__members__ else RelevenceInfo.RELEVANT
except (KeyError, ValueError):
relevence_info = RelevenceInfo.RELEVANT
chunk_speaker = self._get_speaker_from_chunk(chunk)
chunk_statement = Statement(
statement=extracted_stmt.statement,
stmt_type=stmt_type,
@@ -144,7 +159,9 @@ class StatementExtractor:
relevence_info=relevence_info,
chunk_id=chunk.id,
group_id=group_id,
speaker=chunk_speaker,
)
chunk_statements.append(chunk_statement)
# 分离强弱关系分类:不在句子提取阶段进行,也不写入 chunk.metadata
@@ -226,12 +243,7 @@ class StatementExtractor:
return output_path
def save_relations(self, dialogs: List[DialogData], output_path: str = None) -> str:
"""按对话分组聚合强/弱关系并写入 TXT 文件。
- 每个对话单独成段:输出该对话的 `Dialog ID`、`Group ID`、`Content`
- 在该对话段内再分为 Strong Relations / Weak Relations 两部分
- Strong: 逐条输出 `Chunk ID` 与 `Triple`
- Weak: 逐条输出 `Chunk ID` 与 `Entity`
"""
"""Group and aggregate strong/weak relations by dialogue and write to TXT file."""
print("\n=== Relations Classify ===")
# 使用全局配置的输出路径

View File

@@ -89,14 +89,15 @@ def validate_model_exists_and_active(
start_time = time.time()
try:
# First check if model exists at all (without tenant filtering)
model_without_tenant = ModelConfigRepository.get_by_id(db, model_id, tenant_id=None)
# Then check with tenant filtering
# OPTIMIZED: Single query with tenant filter
# We'll check tenant mismatch in the error handling
model = ModelConfigRepository.get_by_id(db, model_id, tenant_id)
elapsed_ms = (time.time() - start_time) * 1000
if not model:
# Model not found with tenant filter - check if it exists without filter
model_without_tenant = ModelConfigRepository.get_by_id(db, model_id, tenant_id=None)
if model_without_tenant:
# Model exists but belongs to different tenant
logger.warning(
@@ -208,8 +209,11 @@ def validate_embedding_model(
db: Session,
tenant_id: Optional[UUID] = None,
workspace_id: Optional[UUID] = None
) -> UUID:
"""Validate that embedding model is available and return its UUID.
) -> tuple[UUID, str]:
"""Validate that embedding model is available and return its UUID and name.
Returns:
Tuple of (embedding_uuid, embedding_name)
Raises:
InvalidConfigError: If embedding_id is not provided or invalid
@@ -225,14 +229,19 @@ def validate_embedding_model(
workspace_id=workspace_id
)
embedding_uuid, _ = validate_and_resolve_model_id(
embedding_uuid, embedding_name = validate_and_resolve_model_id(
embedding_id, "embedding", db, tenant_id, required=True,
config_id=config_id, workspace_id=workspace_id
)
print(100*'-')
print(embedding_uuid)
print(_)
print(100*'-')
logger.debug(
"Embedding model validated",
extra={
"embedding_uuid": str(embedding_uuid),
"embedding_name": embedding_name,
"config_id": config_id
}
)
if embedding_uuid is None:
raise InvalidConfigError(
@@ -243,7 +252,7 @@ def validate_embedding_model(
workspace_id=workspace_id
)
return embedding_uuid
return embedding_uuid, embedding_name
def validate_llm_model(

View File

@@ -104,38 +104,6 @@ class DataConfigRepository:
r.statement AS statement
"""
# Entity graph within group (source node, edge, target node)
SEARCH_FOR_ENTITY_GRAPH = """
MATCH (n:ExtractedEntity)-[r]->(m:ExtractedEntity)
WHERE n.group_id = $group_id
RETURN
{
entity_idx: n.entity_idx,
connect_strength: n.connect_strength,
description: n.description,
entity_type: n.entity_type,
name: n.name,
fact_summary: COALESCE(n.fact_summary, ''),
id: n.id
} AS sourceNode,
{
rel_id: elementId(r),
source_id: startNode(r).id,
target_id: endNode(r).id,
predicate: r.predicate,
statement_id: r.statement_id,
statement: r.statement
} AS edge,
{
entity_idx: m.entity_idx,
connect_strength: m.connect_strength,
description: m.description,
entity_type: m.entity_type,
name: m.name,
fact_summary: COALESCE(m.fact_summary, ''),
id: m.id
} AS targetNode
"""
@staticmethod
def update_reflection_config(
db: Session,

View File

@@ -276,42 +276,6 @@ def get_end_user_by_id(db: Session, end_user_id: uuid.UUID) -> Optional[EndUser]
end_user = repo.get_end_user_by_id(end_user_id)
return end_user
def update_end_user_other_name(
db: Session,
end_user_id: uuid.UUID,
other_name: str
) -> int:
"""
通过 end_user_id 更新 end_user 表中的 other_name 字段
Args:
db: 数据库会话
end_user_id: 宿主ID
other_name: 要更新的用户名
Returns:
int: 更新的记录数
"""
try:
# 执行更新
updated_count = (
db.query(EndUser)
.filter(EndUser.id == end_user_id)
.update(
{EndUser.other_name: other_name},
synchronize_session=False
)
)
db.commit()
db_logger.info(f"成功更新宿主 {end_user_id} 的 other_name 为: {other_name}")
return updated_count
except Exception as e:
db.rollback()
db_logger.error(f"更新宿主 {end_user_id} 的 other_name 时出错: {str(e)}")
raise
# 新增的缓存操作函数(保持与类方法一致的接口)
def get_by_id(db: Session, end_user_id: uuid.UUID) -> Optional[EndUser]:
"""根据ID获取终端用户用于缓存操作"""

View File

@@ -101,6 +101,8 @@ async def add_statement_nodes(statements: List[StatementNode], connector: Neo4jC
# "entities": [entity.model_dump() for entity in statement.triplet_extraction_info.entities] if statement.triplet_extraction_info else []
# }) if statement.triplet_extraction_info else json.dumps({"triplets": [], "entities": []}),
"statement_embedding": statement.statement_embedding if statement.statement_embedding else None,
# 添加 speaker 字段(用于基于角色的情绪提取)
"speaker": statement.speaker if hasattr(statement, 'speaker') else None,
# 添加情绪字段处理
"emotion_type": statement.emotion_type,
"emotion_intensity": statement.emotion_intensity,
@@ -163,7 +165,9 @@ async def add_chunk_nodes(chunks: List[ChunkNode], connector: Neo4jConnector) ->
"chunk_embedding": chunk.chunk_embedding if chunk.chunk_embedding else None,
"sequence_number": chunk.sequence_number,
"start_index": metadata.get("start_index"),
"end_index": metadata.get("end_index")
"end_index": metadata.get("end_index"),
# 添加 speaker 字段(用于基于角色的情绪提取)
"speaker": chunk.speaker if hasattr(chunk, 'speaker') else None
}
flattened_chunks.append(flattened_chunk)

View File

@@ -305,12 +305,19 @@ async def search_graph(
results[key] = _deduplicate_results(results[key])
# 更新知识节点的激活值Statement, ExtractedEntity, MemorySummary
results = await _update_search_results_activation(
connector=connector,
results=results,
group_id=group_id
# Skip activation updates if only searching summaries (optimization)
needs_activation_update = any(
key in include and key in results and results[key]
for key in ['statements', 'entities', 'chunks']
)
if needs_activation_update:
results = await _update_search_results_activation(
connector=connector,
results=results,
group_id=group_id
)
return results
@@ -339,7 +346,7 @@ async def search_graph_by_embedding(
embed_start = time.time()
embeddings = await embedder_client.response([query_text])
embed_time = time.time() - embed_start
print(f"[PERF] Embedding generation took: {embed_time:.4f}s")
logger.info(f"[PERF] Embedding generation took: {embed_time:.4f}s")
if not embeddings or not embeddings[0]:
return {"statements": [], "chunks": [], "entities": [], "summaries": []}
@@ -393,7 +400,7 @@ async def search_graph_by_embedding(
query_start = time.time()
task_results = await asyncio.gather(*tasks, return_exceptions=True)
query_time = time.time() - query_start
print(f"[PERF] Neo4j queries (parallel) took: {query_time:.4f}s")
logger.info(f"[PERF] Neo4j queries (parallel) took: {query_time:.4f}s")
# Build results dictionary
results: Dict[str, List[Dict[str, Any]]] = {
@@ -417,14 +424,23 @@ async def search_graph_by_embedding(
results[key] = _deduplicate_results(results[key])
# 更新知识节点的激活值Statement, ExtractedEntity, MemorySummary
update_start = time.time()
results = await _update_search_results_activation(
connector=connector,
results=results,
group_id=group_id
# Skip activation updates if only searching summaries (optimization)
needs_activation_update = any(
key in include and key in results and results[key]
for key in ['statements', 'entities', 'chunks']
)
update_time = time.time() - update_start
print(f"[PERF] Activation value updates took: {update_time:.4f}s")
if needs_activation_update:
update_start = time.time()
results = await _update_search_results_activation(
connector=connector,
results=results,
group_id=group_id
)
update_time = time.time() - update_start
logger.info(f"[PERF] Activation value updates took: {update_time:.4f}s")
else:
logger.info(f"[PERF] Skipping activation updates (only summaries)")
return results
async def get_dedup_candidates_for_entities( # 适配新版查询:使用全文索引按名称检索候选实体
@@ -535,7 +551,7 @@ async def search_graph_by_keyword_temporal(
- Returns up to 'limit' statements
"""
if not query_text:
print(f"query_text不能为空")
logger.warning(f"query_text cannot be empty")
return {"statements": []}
statements = await connector.execute_query(
SEARCH_STATEMENTS_BY_KEYWORD_TEMPORAL,
@@ -549,7 +565,7 @@ async def search_graph_by_keyword_temporal(
invalid_date=invalid_date,
limit=limit,
)
print(f"查询结果为:\n{statements}")
logger.debug(f"Temporal keyword search results: {len(statements)} statements found")
# 更新 Statement 节点的激活值
results = {"statements": statements}
@@ -594,9 +610,9 @@ async def search_graph_by_temporal(
limit=limit,
)
print(f"查询语句为:\n{SEARCH_STATEMENTS_BY_TEMPORAL}")
print(f"查询参数为:\n{{group_id: {group_id}, apply_id: {apply_id}, user_id: {user_id}, start_date: {start_date}, end_date: {end_date}, valid_date: {valid_date}, invalid_date: {invalid_date}, limit: {limit}}}")
print(f"查询结果为:\n{statements}")
logger.debug(f"Temporal search query: {SEARCH_STATEMENTS_BY_TEMPORAL}")
logger.debug(f"Query params: group_id={group_id}, apply_id={apply_id}, user_id={user_id}, start_date={start_date}, end_date={end_date}, valid_date={valid_date}, invalid_date={invalid_date}, limit={limit}")
logger.debug(f"Temporal search results: {len(statements)} statements found")
# 更新 Statement 节点的激活值
results = {"statements": statements}
@@ -623,7 +639,7 @@ async def search_graph_by_dialog_id(
- Returns up to 'limit' dialogues
"""
if not dialog_id:
print(f"dialog_id不能为空")
logger.warning(f"dialog_id cannot be empty")
return {"dialogues": []}
dialogues = await connector.execute_query(
@@ -642,7 +658,7 @@ async def search_graph_by_chunk_id(
limit: int = 1,
) -> Dict[str, List[Dict[str, Any]]]:
if not chunk_id:
print(f"chunk_id不能为空")
logger.warning(f"chunk_id cannot be empty")
return {"chunks": []}
chunks = await connector.execute_query(
SEARCH_CHUNK_BY_CHUNK_ID,
@@ -679,9 +695,9 @@ async def search_graph_by_created_at(
limit=limit,
)
print(f"查询语句为:\n{SEARCH_STATEMENTS_BY_CREATED_AT}")
print(f"查询参数为:\n{{group_id: {group_id}, apply_id: {apply_id}, user_id: {user_id}, created_at: {created_at}, limit: {limit}}}")
print(f"查询结果为:\n{statements}")
logger.debug(f"Search by created_at query: {SEARCH_STATEMENTS_BY_CREATED_AT}")
logger.debug(f"Query params: group_id={group_id}, apply_id={apply_id}, user_id={user_id}, created_at={created_at}, limit={limit}")
logger.debug(f"Search results: {len(statements)} statements found")
# 更新 Statement 节点的激活值
results = {"statements": statements}
@@ -719,9 +735,9 @@ async def search_graph_by_valid_at(
limit=limit,
)
print(f"查询语句为:\n{SEARCH_STATEMENTS_BY_VALID_AT}")
print(f"查询参数为:\n{{group_id: {group_id}, apply_id: {apply_id}, user_id: {user_id}, valid_at: {valid_at}, limit: {limit}}}")
print(f"查询结果为:\n{statements}")
logger.debug(f"Search by valid_at query: {SEARCH_STATEMENTS_BY_VALID_AT}")
logger.debug(f"Query params: group_id={group_id}, apply_id={apply_id}, user_id={user_id}, valid_at={valid_at}, limit={limit}")
logger.debug(f"Search results: {len(statements)} statements found")
# 更新 Statement 节点的激活值
results = {"statements": statements}
@@ -759,9 +775,9 @@ async def search_graph_g_created_at(
limit=limit,
)
print(f"查询语句为:\n{SEARCH_STATEMENTS_G_CREATED_AT}")
print(f"查询参数为:\n{{group_id: {group_id}, apply_id: {apply_id}, user_id: {user_id}, created_at: {created_at}, limit: {limit}}}")
print(f"查询结果为:\n{statements}")
logger.debug(f"Search greater than created_at query: {SEARCH_STATEMENTS_G_CREATED_AT}")
logger.debug(f"Query params: group_id={group_id}, apply_id={apply_id}, user_id={user_id}, created_at={created_at}, limit={limit}")
logger.debug(f"Search results: {len(statements)} statements found")
# 更新 Statement 节点的激活值
results = {"statements": statements}
@@ -799,9 +815,9 @@ async def search_graph_g_valid_at(
limit=limit,
)
print(f"查询语句为:\n{SEARCH_STATEMENTS_G_VALID_AT}")
print(f"查询参数为:\n{{group_id: {group_id}, apply_id: {apply_id}, user_id: {user_id}, valid_at: {valid_at}, limit: {limit}}}")
print(f"查询结果为:\n{statements}")
logger.debug(f"Search greater than valid_at query: {SEARCH_STATEMENTS_G_VALID_AT}")
logger.debug(f"Query params: group_id={group_id}, apply_id={apply_id}, user_id={user_id}, valid_at={valid_at}, limit={limit}")
logger.debug(f"Search results: {len(statements)} statements found")
# 更新 Statement 节点的激活值
results = {"statements": statements}
@@ -839,9 +855,9 @@ async def search_graph_l_created_at(
limit=limit,
)
print(f"查询语句为:\n{SEARCH_STATEMENTS_L_CREATED_AT}")
print(f"查询参数为:\n{{group_id: {group_id}, apply_id: {apply_id}, user_id: {user_id}, created_at: {created_at}, limit: {limit}}}")
print(f"查询结果为:\n{statements}")
logger.debug(f"Search less than created_at query: {SEARCH_STATEMENTS_L_CREATED_AT}")
logger.debug(f"Query params: group_id={group_id}, apply_id={apply_id}, user_id={user_id}, created_at={created_at}, limit={limit}")
logger.debug(f"Search results: {len(statements)} statements found")
# 更新 Statement 节点的激活值
results = {"statements": statements}
@@ -879,9 +895,9 @@ async def search_graph_l_valid_at(
limit=limit,
)
print(f"查询语句为:\n{SEARCH_STATEMENTS_L_VALID_AT}")
print(f"查询参数为:\n{{group_id: {group_id}, apply_id: {apply_id}, user_id: {user_id}, valid_at: {valid_at}, limit: {limit}}}")
print(f"查询结果为:\n{statements}")
logger.debug(f"Search less than valid_at query: {SEARCH_STATEMENTS_L_VALID_AT}")
logger.debug(f"Query params: group_id={group_id}, apply_id={apply_id}, user_id={user_id}, valid_at={valid_at}, limit={limit}")
logger.debug(f"Search results: {len(statements)} statements found")
# 更新 Statement 节点的激活值
results = {"statements": statements}

View File

@@ -11,6 +11,7 @@ class EmotionTagsRequest(BaseModel):
start_date: Optional[str] = Field(None, description="开始日期ISO格式2024-01-01")
end_date: Optional[str] = Field(None, description="结束日期ISO格式2024-12-31")
limit: int = Field(10, ge=1, le=100, description="返回数量限制")
language_type: Optional[str] = Field("zh", description="语言类型zh/en")
class EmotionWordcloudRequest(BaseModel):
@@ -18,20 +19,24 @@ class EmotionWordcloudRequest(BaseModel):
group_id: str = Field(..., description="组ID")
emotion_type: Optional[str] = Field(None, description="情绪类型过滤joy/sadness/anger/fear/surprise/neutral")
limit: int = Field(50, ge=1, le=200, description="返回词语数量")
language_type: Optional[str] = Field("zh", description="语言类型zh/en")
class EmotionHealthRequest(BaseModel):
"""获取情绪健康指数请求"""
group_id: str = Field(..., description="组ID")
time_range: str = Field("30d", description="时间范围7d/30d/90d")
language_type: Optional[str] = Field("zh", description="语言类型zh/en")
class EmotionSuggestionsRequest(BaseModel):
"""获取个性化情绪建议请求"""
group_id: str = Field(..., description="组ID")
config_id: Optional[int] = Field(None, description="配置ID用于指定LLM模型")
language_type: Optional[str] = Field("zh", description="语言类型zh/en")
class EmotionGenerateSuggestionsRequest(BaseModel):
"""生成个性化情绪建议请求"""
end_user_id: str = Field(..., description="终端用户ID")
language_type: Optional[str] = Field("zh", description="语言类型zh/en")

View File

@@ -44,6 +44,7 @@ class EndUserProfileResponse(BaseModel):
updatetime_profile: Optional[datetime.datetime] = Field(description="核心档案信息最后更新时间", default=None)
class EndUserProfileUpdate(BaseModel):
"""终端用户基本信息更新请求模型"""
end_user_id: str = Field(description="终端用户ID")

View File

@@ -12,10 +12,6 @@ class UserInput(BaseModel):
class Write_UserInput(BaseModel):
message: str
messages: list[dict]
group_id: str
config_id: Optional[str] = None
class End_User_Information(BaseModel):
end_user_name: str # 这是要更新的用户名
id: str # 宿主ID用于匹配条件

View File

@@ -51,6 +51,7 @@ class EpisodicMemoryOverviewRequest(BaseModel):
"""情景记忆总览查询请求"""
end_user_id: str = Field(..., description="终端用户ID")
language_type: Optional[str] = Field("zh", description="语言类型zh/en")
time_range: str = Field(
default="all",
description="时间范围筛选可选值all, today, this_week, this_month"
@@ -70,3 +71,4 @@ class EpisodicMemoryDetailsRequest(BaseModel):
end_user_id: str = Field(..., description="终端用户ID")
summary_id: str = Field(..., description="情景记忆摘要ID")
language_type: Optional[str] = Field("zh", description="语言类型zh/en")

View File

@@ -1,15 +1,19 @@
"""
显性记忆的请求和响应模型
"""
from typing import Optional
from pydantic import BaseModel, Field
class ExplicitMemoryOverviewRequest(BaseModel):
"""显性记忆总览查询请求"""
end_user_id: str = Field(..., description="终端用户ID")
language_type: Optional[str] = Field("zh", description="语言类型zh/en")
class ExplicitMemoryDetailsRequest(BaseModel):
"""显性记忆详情查询请求"""
end_user_id: str = Field(..., description="终端用户ID")
memory_id: str = Field(..., description="记忆ID情景记忆或语义记忆的ID")
language_type: Optional[str] = Field("zh", description="语言类型zh/en")

View File

@@ -10,11 +10,6 @@ import time
import uuid
from typing import Any, AsyncGenerator, Dict, List, Optional
from langchain.tools import tool
from pydantic import BaseModel, Field
from sqlalchemy import select
from sqlalchemy.orm import Session
from app.celery_app import celery_app
from app.core.error_codes import BizCode
from app.core.exceptions import BusinessException
@@ -28,6 +23,10 @@ from app.services.langchain_tool_server import Search
from app.services.memory_agent_service import MemoryAgentService
from app.services.model_parameter_merger import ModelParameterMerger
from app.services.tool_service import ToolService
from langchain.tools import tool
from pydantic import BaseModel, Field
from sqlalchemy import select
from sqlalchemy.orm import Session
logger = get_business_logger()
class KnowledgeRetrievalInput(BaseModel):
@@ -107,9 +106,9 @@ def create_long_term_memory_tool(memory_config: Dict[str, Any], end_user_id: str
"app.core.memory.agent.read_message",
args=[end_user_id, question, [], "1", config_id, storage_type, user_rag_memory_id]
)
result = task_service.get_task_memory_read_result(task.id)
status = result.get("status")
logger.info(f"读取任务状态:{status}")
# result = task_service.get_task_memory_read_result(task.id)
# status = result.get("status")
# logger.info(f"读取任务状态:{status}")
finally:
db.close()

View File

@@ -10,26 +10,32 @@ import re
import time
import uuid
from typing import Any, AsyncGenerator, Dict, List, Optional
import redis
from langchain_core.messages import HumanMessage
import redis
from app.core.config import settings
from app.core.logging_config import get_config_logger, get_logger
from app.core.memory.agent.langgraph_graph.read_graph import make_read_graph
from app.core.memory.agent.langgraph_graph.write_graph import make_write_graph
from app.core.memory.agent.logger_file.log_streamer import LogStreamer
from app.core.memory.agent.utils.messages_tools import merge_multiple_search_results, reorder_output_results
from app.core.memory.agent.utils.messages_tools import (
merge_multiple_search_results,
reorder_output_results,
)
from app.core.memory.agent.utils.type_classifier import status_typle
from app.core.memory.agent.utils.write_tools import write # 新增:直接导入 write 函数
from app.core.memory.analytics.hot_memory_tags import get_hot_memory_tags
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.db import get_db_context
from app.models.knowledge_model import Knowledge, KnowledgeType
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.schemas.memory_agent_schema import Write_UserInput
from app.schemas.memory_config_schema import ConfigurationError
from app.services.memory_base_service import Translation_English
from app.services.memory_config_service import MemoryConfigService
from app.services.memory_konwledges_server import (
write_rag,
)
from langchain_core.messages import HumanMessage
from pydantic import BaseModel, Field
from sqlalchemy import func
from sqlalchemy.orm import Session
@@ -259,13 +265,13 @@ class MemoryAgentService:
logger.info("Log streaming completed, cleaning up resources")
# LogStreamer uses context manager for file handling, so cleanup is automatic
async def write_memory(self, group_id: str, message: str, config_id: Optional[str], db: Session, storage_type: str, user_rag_memory_id: str) -> str:
async def write_memory(self, group_id: str, messages: list[dict], config_id: Optional[str], db: Session, storage_type: str, user_rag_memory_id: str) -> str:
"""
Process write operation with config_id
Args:
group_id: Group identifier (also used as end_user_id)
message: Message to write
messages: Structured message list [{"role": "user", "content": "..."}, ...]
config_id: Configuration ID from database
db: SQLAlchemy database session
storage_type: Storage type (neo4j or rag)
@@ -286,7 +292,7 @@ class MemoryAgentService:
raise ValueError(f"No memory configuration found for end_user {group_id}. Please ensure the user has a connected memory configuration.")
except Exception as e:
if "No memory configuration found" in str(e):
raise # Re-raise our specific error
raise
logger.error(f"Failed to get connected config for end_user {group_id}: {e}")
raise ValueError(f"Unable to determine memory configuration for end_user {group_id}: {e}")
@@ -314,14 +320,28 @@ class MemoryAgentService:
try:
if storage_type == "rag":
result = await write_rag(group_id, message, user_rag_memory_id)
# For RAG storage, convert messages to single string
message_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages])
result = await write_rag(group_id, message_text, user_rag_memory_id)
return result
else:
async with make_write_graph() as graph:
config = {"configurable": {"thread_id": group_id}}
# Convert structured messages to LangChain messages
langchain_messages = []
for msg in messages:
if msg['role'] == 'user':
langchain_messages.append(HumanMessage(content=msg['content']))
elif msg['role'] == 'assistant':
from langchain_core.messages import AIMessage
langchain_messages.append(AIMessage(content=msg['content']))
# 初始状态 - 包含所有必要字段
initial_state = {"messages": [HumanMessage(content=message)], "group_id": group_id,
"memory_config": memory_config}
initial_state = {
"messages": langchain_messages,
"group_id": group_id,
"memory_config": memory_config
}
# 获取节点更新信息
async for update_event in graph.astream(
@@ -334,7 +354,9 @@ class MemoryAgentService:
massages = node_data
massagesstatus = massages.get('write_result')['status']
contents = massages.get('write_result')
return self.writer_messages_deal(massagesstatus, start_time, group_id, config_id, message, contents)
# Convert messages back to string for logging
message_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages])
return self.writer_messages_deal(massagesstatus, start_time, group_id, config_id, message_text, contents)
except Exception as e:
# Ensure proper error handling and logging
error_msg = f"Write operation failed: {str(e)}"
@@ -385,6 +407,7 @@ class MemoryAgentService:
import time
start_time = time.time()
logger.info(f"[PERF] read_memory started for group_id={group_id}, search_switch={search_switch}")
# Resolve config_id if None using end_user's connected config
if config_id is None:
@@ -408,13 +431,15 @@ class MemoryAgentService:
audit_logger = None
config_load_start = time.time()
try:
config_service = MemoryConfigService(db)
memory_config = config_service.load_memory_config(
config_id=config_id,
service_name="MemoryAgentService"
)
logger.info(f"Configuration loaded successfully: {memory_config.config_name}")
config_load_time = time.time() - config_load_start
logger.info(f"[PERF] Configuration loaded in {config_load_time:.4f}s: {memory_config.config_name}")
except ConfigurationError as e:
error_msg = f"Failed to load configuration for config_id: {config_id}: {e}"
logger.error(error_msg)
@@ -438,6 +463,7 @@ class MemoryAgentService:
logger.debug(f"Group ID:{group_id}, Message:{message}, History:{history}, Config ID:{config_id}")
# Step 3: Initialize MCP client and execute read workflow
graph_exec_start = time.time()
try:
async with make_read_graph() as graph:
config = {"configurable": {"thread_id": group_id}}
@@ -494,12 +520,68 @@ class MemoryAgentService:
if summary_n and summary_n != [] and summary_n != {}:
_intermediate_outputs.append(summary_n)
graph_exec_time = time.time() - graph_exec_start
logger.info(f"[PERF] Graph execution completed in {graph_exec_time:.4f}s")
_intermediate_outputs = [item for item in _intermediate_outputs if item and item != [] and item != {}]
optimized_outputs = merge_multiple_search_results(_intermediate_outputs)
result = reorder_output_results(optimized_outputs)
# 保存短期记忆到数据库
# 只有 search_switch 不为 "2"(快速检索)时才保存
try:
from app.repositories.memory_short_repository import ShortTermMemoryRepository
retrieved_content = []
repo = ShortTermMemoryRepository(db)
if str(search_switch) != "2":
for intermediate in _intermediate_outputs:
logger.debug(f"处理中间结果: {intermediate}")
intermediate_type = intermediate.get('type', '')
if intermediate_type == "search_result":
query = intermediate.get('query', '')
raw_results = intermediate.get('raw_results', {})
reranked_results = raw_results.get('reranked_results', [])
try:
statements = [statement['statement'] for statement in reranked_results.get('statements', [])]
except Exception:
statements = []
# 去重
statements = list(set(statements))
if query and statements:
retrieved_content.append({query: statements})
# 如果 retrieved_content 为空,设置为空字符串
if retrieved_content == []:
retrieved_content = ''
# 只有当回答不是"信息不足"且不是快速检索时才保存
if '信息不足,无法回答。' != str(summary) and str(search_switch).strip() != "2":
# 使用 upsert 方法
repo.upsert(
end_user_id=group_id,
messages=message,
aimessages=summary,
retrieved_content=retrieved_content,
search_switch=str(search_switch)
)
logger.info(f"成功保存短期记忆: group_id={group_id}, search_switch={search_switch}")
else:
logger.debug(f"跳过保存短期记忆: summary={summary[:50] if summary else 'None'}, search_switch={search_switch}")
except Exception as save_error:
# 保存失败不应该影响主流程,只记录错误
logger.error(f"保存短期记忆失败: {str(save_error)}", exc_info=True)
# Log successful operation
total_time = time.time() - start_time
logger.info(f"[PERF] read_memory completed successfully in {total_time:.4f}s (config: {config_load_time:.4f}s, graph: {graph_exec_time:.4f}s)")
if audit_logger:
duration = time.time() - start_time
audit_logger.log_operation(
@@ -517,7 +599,8 @@ class MemoryAgentService:
except Exception as e:
# Ensure proper error handling and logging
error_msg = f"Read operation failed: {str(e)}"
logger.error(error_msg)
total_time = time.time() - start_time
logger.error(f"[PERF] read_memory failed after {total_time:.4f}s: {error_msg}")
if audit_logger:
duration = time.time() - start_time
audit_logger.log_operation(
@@ -530,7 +613,49 @@ class MemoryAgentService:
)
raise ValueError(error_msg)
def get_messages_list(self, user_input: Write_UserInput) -> list[dict]:
"""
Get standardized message list from user input.
Args:
user_input: Write_UserInput object
Returns:
list[dict]: Message list, each message contains role and content
Raises:
ValueError: If messages is empty or format is incorrect
"""
from app.core.logging_config import get_api_logger
logger = get_api_logger()
if len(user_input.messages) == 0:
logger.error("Validation failed: Message list cannot be empty")
raise ValueError("Message list cannot be empty")
for idx, msg in enumerate(user_input.messages):
if not isinstance(msg, dict):
logger.error(f"Validation failed: Message {idx} is not a dict: {type(msg)}")
raise ValueError(f"Message format error: Message must be a dictionary. Error message index: {idx}, type: {type(msg)}")
if 'role' not in msg:
logger.error(f"Validation failed: Message {idx} missing 'role' field: {msg}")
raise ValueError(f"Message format error: Message must contain 'role' field. Error message index: {idx}")
if 'content' not in msg:
logger.error(f"Validation failed: Message {idx} missing 'content' field: {msg}")
raise ValueError(f"Message format error: Message must contain 'content' field. Error message index: {idx}")
if msg['role'] not in ['user', 'assistant']:
logger.error(f"Validation failed: Message {idx} invalid role: {msg['role']}")
raise ValueError(f"Role must be 'user' or 'assistant', got: {msg['role']}. Message index: {idx}")
if not msg['content'] or not msg['content'].strip():
logger.error(f"Validation failed: Message {idx} content is empty")
raise ValueError(f"Message content cannot be empty. Message index: {idx}, role: {msg['role']}")
logger.info(f"Validation successful: Structured message list, count: {len(user_input.messages)}")
return user_input.messages
async def classify_message_type(self, message: str, config_id: int, db: Session) -> Dict:
"""
@@ -558,7 +683,67 @@ class MemoryAgentService:
logger.debug(f"Message type: {status}")
return status
# ==================== 新增的三个接口方法 ====================
async def generate_summary_from_retrieve(
self,
retrieve_info: str,
history: List[Dict],
query: str,
config_id: str,
db: Session
) -> str:
"""
基于检索信息、历史对话和查询生成最终答案
使用 Retrieve_Summary_prompt.jinja2 模板调用大模型生成答案
Args:
retrieve_info: 检索到的信息
history: 历史对话记录
query: 用户查询
config_id: 配置ID
db: 数据库会话
Returns:
生成的答案文本
"""
logger.info(f"Generating summary from retrieve info for query: {query[:50]}...")
try:
# 加载配置
config_service = MemoryConfigService(db)
memory_config = config_service.load_memory_config(
config_id=config_id,
service_name="MemoryAgentService"
)
# 导入必要的模块
from app.core.memory.agent.langgraph_graph.nodes.summary_nodes import summary_llm
from app.core.memory.agent.models.summary_models import RetrieveSummaryResponse
# 构建状态对象
state = {
"data": query,
"memory_config": memory_config
}
# 直接调用 summary_llm 函数
answer = await summary_llm(
state=state,
history=history,
retrieve_info=retrieve_info,
template_name='Retrieve_Summary_prompt.jinja2',
operation_name='retrieve_summary',
response_model=RetrieveSummaryResponse,
search_mode="1"
)
logger.info(f"Successfully generated summary: {answer[:100] if answer else 'None'}...")
return answer if answer else "信息不足,无法回答。"
except Exception as e:
logger.error(f"生成摘要失败: {str(e)}", exc_info=True)
return "信息不足,无法回答。"
async def get_knowledge_type_stats(
self,
@@ -692,7 +877,9 @@ class MemoryAgentService:
async def get_hot_memory_tags_by_user(
self,
end_user_id: Optional[str] = None,
limit: int = 20
limit: int = 20,
model_id: Optional[str] = None,
language_type: Optional[str] = "zh"
) -> List[Dict[str, Any]]:
"""
获取指定用户的热门记忆标签
@@ -710,7 +897,13 @@ class MemoryAgentService:
try:
# by_user=False 表示按 group_id 查询在Neo4j中group_id就是用户维度
tags = await get_hot_memory_tags(end_user_id, limit=limit, by_user=False)
payload = [{"name": t, "frequency": f} for t, f in tags]
payload=[]
for tag, freq in tags:
if language_type!="zh":
tag=await Translation_English(model_id, tag)
payload.append({"name": tag, "frequency": freq})
else:
payload.append({"name": tag, "frequency": freq})
return payload
except Exception as e:
logger.error(f"热门记忆标签查询失败: {e}")
@@ -1024,7 +1217,7 @@ def get_end_users_connected_configs_batch(end_user_ids: List[str], db: Session)
"""
from app.models.app_release_model import AppRelease
from app.models.end_user_model import EndUser
from app.models.memory_config_model import MemoryConfig
from app.models.data_config_model import DataConfig
from sqlalchemy import select
logger.info(f"Batch getting connected configs for {len(end_user_ids)} end_users")
@@ -1082,8 +1275,8 @@ def get_end_users_connected_configs_batch(end_user_ids: List[str], db: Session)
# 批量查询 memory_config_name
config_id_to_name = {}
if memory_config_ids:
memory_configs = db.query(MemoryConfig).filter(MemoryConfig.id.in_(memory_config_ids)).all()
config_id_to_name = {str(mc.id): mc.config_name for mc in memory_configs}
memory_configs = db.query(DataConfig).filter(DataConfig.config_id.in_(memory_config_ids)).all()
config_id_to_name = {str(mc.config_id): mc.config_name for mc in memory_configs}
# 4. 构建最终结果
for end_user_id, app_id in user_to_app.items():
@@ -1100,7 +1293,7 @@ def get_end_users_connected_configs_batch(end_user_ids: List[str], db: Session)
memory_config_id = memory_obj.get('memory_content') if isinstance(memory_obj, dict) else None
# 获取配置名称
memory_config_name = config_id_to_name.get(memory_config_id) if memory_config_id else None
memory_config_name = config_id_to_name.get(str(memory_config_id)) if memory_config_id else None
result[end_user_id] = {
"memory_config_id": memory_config_id,

View File

@@ -3,17 +3,268 @@ Memory Base Service
提供记忆服务的基础功能和共享辅助方法。
"""
import asyncio
import re
from datetime import datetime
from typing import Optional
from pydantic import BaseModel
from app.core.logging_config import get_logger
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.services.emotion_analytics_service import EmotionAnalyticsService
from app.core.memory.llm_tools.openai_client import OpenAIClient
from app.core.models.base import RedBearModelConfig
from app.services.memory_config_service import MemoryConfigService
from app.db import get_db_context
logger = get_logger(__name__)
class TranslationResponse(BaseModel):
"""翻译响应模型"""
data: str
class MemoryTransService:
"""记忆翻译服务,提供中英文翻译功能"""
def __init__(self, llm_client=None, model_id: Optional[str] = None):
"""
初始化翻译服务
Args:
llm_client: LLM客户端实例或模型ID字符串可选
model_id: 模型ID用于初始化LLM客户端可选
Note:
- 如果llm_client是字符串会被当作model_id使用
- 如果同时提供llm_client和model_id优先使用llm_client
"""
# 处理llm_client参数如果是字符串当作model_id
if isinstance(llm_client, str):
self.model_id = llm_client
self.llm_client = None
else:
self.llm_client = llm_client
self.model_id = model_id
self._initialized = False
def _ensure_llm_client(self):
"""确保LLM客户端已初始化"""
if self._initialized:
return
if self.llm_client is None:
if self.model_id:
with get_db_context() as db:
config_service = MemoryConfigService(db)
model_config = config_service.get_model_config(self.model_id)
extra_params = {
"temperature": 0.2,
"max_tokens": 400,
"top_p": 0.8,
"stream": False,
}
self.llm_client = OpenAIClient(
RedBearModelConfig(
model_name=model_config.get("model_name"),
provider=model_config.get("provider"),
api_key=model_config.get("api_key"),
base_url=model_config.get("base_url"),
timeout=model_config.get("timeout", 30),
max_retries=model_config.get("max_retries", 3),
extra_params=extra_params
),
type_=model_config.get("type")
)
else:
raise ValueError("必须提供 llm_client 或 model_id 之一")
self._initialized = True
async def translate_to_english(self, text: str) -> str:
"""
将中文翻译为英文
Args:
text: 要翻译的中文文本
Returns:
翻译后的英文文本
"""
self._ensure_llm_client()
translation_messages = [
{
"role": "user",
"content": f"{text}\n\n中文翻译为英文,输出格式为{{\"data\":\"翻译后的内容\"}}"
}
]
try:
response = await self.llm_client.response_structured(
messages=translation_messages,
response_model=TranslationResponse
)
return response.data
except Exception as e:
logger.error(f"翻译失败: {str(e)}")
return text # 翻译失败时返回原文
async def is_english(self, text: str) -> bool:
"""
检查文本是否为英文
Args:
text: 要检查的文本(必须是字符串)
Returns:
True 如果文本主要是英文False 否则
Note:
- 只接受字符串类型
- 检查是否主要由英文字母和常见标点组成
- 允许数字、空格和常见标点符号
"""
if not isinstance(text, str):
raise TypeError(f"is_english 只接受字符串类型,收到: {type(text).__name__}")
if not text.strip():
return True # 空字符串视为英文
# 更宽松的英文检查:允许字母、数字、空格和常见标点
# 如果文本中英文字符占比超过 80%,认为是英文
english_chars = sum(1 for c in text if c.isascii() and (c.isalnum() or c.isspace() or c in '.,!?;:\'"()-'))
total_chars = len(text)
if total_chars == 0:
return True
return (english_chars / total_chars) >= 0.8
async def Translate(self, text: str, target_language: str = "en") -> str:
"""
通用翻译方法(保持向后兼容)
Args:
text: 要翻译的文本
target_language: 目标语言,"en"表示英文,"zh"表示中文
Returns:
翻译后的文本
"""
if target_language == "en":
return await self.translate_to_english(text)
else:
logger.warning(f"不支持的目标语言: {target_language},返回原文")
return text
# 测试翻译服务
async def Translation_English(modid, text, fields=None):
"""
将数据翻译为英文(支持字段级翻译)
Args:
modid: 模型ID
text: 要翻译的数据(可以是字符串、字典或列表)
fields: 需要翻译的字段列表(可选)
如果为None默认翻译: ['content', 'summary', 'statement', 'description',
'name', 'aliases', 'caption', 'emotion_keywords']
Returns:
翻译后的数据,保持原有结构
Note:
- 对于字符串:直接翻译
- 对于列表:递归处理每个元素,保持列表长度和索引不变
- 对于字典只翻译指定字段fields参数
- 对于其他类型:原样返回
"""
trans_service = MemoryTransService(modid)
# 处理字符串类型
if isinstance(text, str):
# 空字符串直接返回
if not text.strip():
return text
try:
is_eng = await trans_service.is_english(text)
if not is_eng:
english_result = await trans_service.Translate(text)
return english_result
return text
except Exception as e:
logger.warning(f"翻译字符串失败: {e}")
return text
# 处理列表类型
elif isinstance(text, list):
english_result = []
for item in text:
# 递归处理列表中的每个元素
if isinstance(item, str):
# 字符串元素:检查是否需要翻译
if not item.strip():
english_result.append(item)
continue
try:
is_eng = await trans_service.is_english(item)
if not is_eng:
translated = await trans_service.Translate(item)
english_result.append(translated)
else:
# 保留英文项,不改变列表长度
english_result.append(item)
except Exception as e:
logger.warning(f"翻译列表项失败: {e}")
english_result.append(item)
elif isinstance(item, dict):
# 字典元素:递归调用自己处理字典
translated_dict = await Translation_English(modid, item, fields)
english_result.append(translated_dict)
elif isinstance(item, list):
# 嵌套列表:递归处理
translated_list = await Translation_English(modid, item, fields)
english_result.append(translated_list)
else:
# 其他类型(数字、布尔值等):原样保留
english_result.append(item)
return english_result
# 处理字典类型
elif isinstance(text, dict):
# 确定要翻译的字段
if fields is None:
# 默认翻译字段
fields = [
'content', 'summary', 'statement', 'description',
'name', 'aliases', 'caption', 'emotion_keywords',
'text', 'title', 'label', 'type' # 添加常用字段
]
# 创建副本,避免修改原始数据
result = text.copy()
for field in fields:
if field in result and result[field] is not None:
# 递归翻译字段值(可能是字符串、列表或嵌套字典)
try:
result[field] = await Translation_English(modid, result[field], fields)
except Exception as e:
logger.warning(f"翻译字段 {field} 失败: {e}")
# 翻译失败时保留原值
continue
return result
# 其他类型数字、布尔值、None等原样返回
else:
return text
class MemoryBaseService:
"""记忆服务基类,提供共享的辅助方法"""
@@ -294,4 +545,4 @@ class MemoryBaseService:
except Exception as e:
logger.error(f"获取遗忘记忆数量时出错: {str(e)}", exc_info=True)
return 0
return 0

View File

@@ -125,7 +125,11 @@ class MemoryConfigService:
try:
validated_config_id = _validate_config_id(config_id)
# Step 1: Get config and workspace
db_query_start = time.time()
result = DataConfigRepository.get_config_with_workspace(self.db, validated_config_id)
db_query_time = time.time() - db_query_start
logger.info(f"[PERF] Config+Workspace query: {db_query_time:.4f}s")
if not result:
elapsed_ms = (time.time() - start_time) * 1000
config_logger.error(
@@ -144,16 +148,20 @@ class MemoryConfigService:
memory_config, workspace = result
# Validate embedding model
embedding_uuid = validate_embedding_model(
# Step 2: Validate embedding model (returns both UUID and name)
embed_start = time.time()
embedding_uuid, embedding_name = validate_embedding_model(
validated_config_id,
memory_config.embedding_id,
self.db,
workspace.tenant_id,
workspace.id,
)
embed_time = time.time() - embed_start
logger.info(f"[PERF] Embedding validation: {embed_time:.4f}s")
# Resolve LLM model
# Step 3: Resolve LLM model
llm_start = time.time()
llm_uuid, llm_name = validate_and_resolve_model_id(
memory_config.llm_id,
"llm",
@@ -163,8 +171,11 @@ class MemoryConfigService:
config_id=validated_config_id,
workspace_id=workspace.id,
)
llm_time = time.time() - llm_start
logger.info(f"[PERF] LLM validation: {llm_time:.4f}s")
# Resolve optional rerank model
# Step 4: Resolve optional rerank model
rerank_start = time.time()
rerank_uuid = None
rerank_name = None
if memory_config.rerank_id:
@@ -177,16 +188,12 @@ class MemoryConfigService:
config_id=validated_config_id,
workspace_id=workspace.id,
)
rerank_time = time.time() - rerank_start
if memory_config.rerank_id:
logger.info(f"[PERF] Rerank validation: {rerank_time:.4f}s")
# Get embedding model name
embedding_name, _ = validate_model_exists_and_active(
embedding_uuid,
"embedding",
self.db,
workspace.tenant_id,
config_id=validated_config_id,
workspace_id=workspace.id,
)
# Note: embedding_name is now returned from validate_embedding_model above
# No need for redundant query!
# Create immutable MemoryConfig object
config = MemoryConfig(

View File

@@ -16,6 +16,7 @@ import json
from datetime import datetime
from app.schemas.memory_episodic_schema import EmotionType
from app.services.memory_base_service import Translation_English
logger = logging.getLogger(__name__)
@@ -24,7 +25,7 @@ class MemoryEntityService:
self.id = id
self.table = table
self.connector = Neo4jConnector()
async def get_timeline_memories_server(self):
async def get_timeline_memories_server(self,model_id, language_type):
"""
获取时间线记忆数据
@@ -48,10 +49,10 @@ class MemoryEntityService:
logger.info(f"获取时间线记忆数据 - ID: {self.id}, Table: {self.table}")
# 根据表类型选择查询
if self.table == 'Statement':
if self.table == 'Statement':
# Statement只需要输入ID使用简化查询
results = await self.connector.execute_query(Memory_Timeline_Statement, id=self.id)
elif self.table == 'ExtractedEntity':
elif self.table == 'ExtractedEntity':
# ExtractedEntity类型查询
results = await self.connector.execute_query(Memory_Timeline_ExtractedEntity, id=self.id)
else:
@@ -62,7 +63,7 @@ class MemoryEntityService:
logger.info(f"时间线查询结果类型: {type(results)}, 长度: {len(results) if isinstance(results, list) else 'N/A'}")
# 处理查询结果
timeline_data = self._process_timeline_results(results)
timeline_data =await self._process_timeline_results(results, model_id, language_type)
logger.info(f"成功获取时间线记忆数据: 总计 {len(timeline_data.get('timelines_memory', []))}")
@@ -71,12 +72,14 @@ class MemoryEntityService:
except Exception as e:
logger.error(f"获取时间线记忆数据失败: {str(e)}", exc_info=True)
return str(e)
def _process_timeline_results(self, results: List[Dict[str, Any]]) -> Dict[str, Any]:
async def _process_timeline_results(self, results: List[Dict[str, Any]], model_id: str, language_type: str) -> Dict[str, Any]:
"""
处理时间线查询结果
Args:
results: Neo4j查询结果
model_id: 模型ID用于翻译
language_type: 语言类型 ('zh' 或其他)
Returns:
处理后的时间线数据字典
@@ -104,19 +107,19 @@ class MemoryEntityService:
# 处理MemorySummary
summary = data.get('MemorySummary')
if summary is not None:
processed_summary = self._process_field_value(summary, "MemorySummary")
processed_summary = await self._process_field_value(summary, "MemorySummary")
memory_summary_list.extend(processed_summary)
# 处理Statement
statement = data.get('statement')
if statement is not None:
processed_statement = self._process_field_value(statement, "Statement")
processed_statement = await self._process_field_value(statement, "Statement")
statement_list.extend(processed_statement)
# 处理ExtractedEntity
extracted_entity = data.get('ExtractedEntity')
if extracted_entity is not None:
processed_entity = self._process_field_value(extracted_entity, "ExtractedEntity")
processed_entity = await self._process_field_value(extracted_entity, "ExtractedEntity")
extracted_entity_list.extend(processed_entity)
# 去重 - 现在处理的是字典列表,需要更智能的去重
@@ -128,6 +131,8 @@ class MemoryEntityService:
all_timeline_data = memory_summary_list + statement_list
all_timeline_data = self._merge_same_text_items(all_timeline_data)
# 如果需要翻译(非中文),对整个结果进行翻译
result = {
"MemorySummary": memory_summary_list,
"Statement": statement_list,
@@ -233,7 +238,7 @@ class MemoryEntityService:
except Exception:
return False
def _process_field_value(self, value: Any, field_name: str) -> List[Dict[str, Any]]:
async def _process_field_value(self, value: Any, field_name: str) -> List[Dict[str, Any]]:
"""
处理字段值,支持字符串、列表等类型
@@ -251,13 +256,13 @@ class MemoryEntityService:
# 如果是列表,处理每个元素
for item in value:
if self._is_valid_item(item):
processed_item = self._process_single_item(item)
processed_item = await self._process_single_item(item)
if processed_item:
processed_values.append(processed_item)
elif isinstance(value, dict):
# 如果是字典,直接处理
if self._is_valid_item(value):
processed_item = self._process_single_item(value)
processed_item = await self._process_single_item(value)
if processed_item:
processed_values.append(processed_item)
elif isinstance(value, str):
@@ -304,7 +309,7 @@ class MemoryEntityService:
return (str(item).strip() != '' and
"MemorySummaryChunk" not in str(item))
def _process_single_item(self, item: Dict[str, Any]) -> Optional[Dict[str, Any]]:
async def _process_single_item(self, item: Dict[str, Any]) -> Optional[Dict[str, Any]]:
"""
处理单个项目
@@ -369,6 +374,117 @@ class MemoryEntityService:
logger.warning(f"转换时间格式失败: {e}, 原始值: {dt}")
return str(dt) if dt is not None else None
async def _translate_list(
self,
data_list: List[Dict[str, Any]],
model_id: str,
fields: List[str]
) -> List[Dict[str, Any]]:
"""
翻译列表中每个字典的指定字段(并发有限度以降低整体延迟)
Args:
data_list: 要翻译的字典列表
model_id: 模型ID
fields: 需要翻译的字段列表
Returns:
翻译后的字典列表
"""
# 空列表或无字段时直接返回
if not data_list or not fields:
return data_list
import asyncio
# 并发限制,避免一次性发起过多请求
# 可根据实际情况调整(建议 5-10
concurrency_limit = 5
semaphore = asyncio.Semaphore(concurrency_limit)
async def translate_single_field(
index: int,
field: str,
value: Any,
) -> Optional[tuple]:
"""
翻译单个字段并返回 (索引, 字段名, 翻译结果)
Returns:
(index, field, translated_value) 或 None如果跳过
"""
# 跳过空值
if value is None or value == "":
return None
# 统一转成字符串再翻译,防止非字符串类型导致错误
text = str(value)
try:
async with semaphore:
# 调用 Translation_English 进行翻译
# 注意Translation_English 的参数顺序是 (model_id, text)
translated = await Translation_English(model_id, text)
# 如果翻译结果为空,保留原值
if translated is None or translated == "":
return None
return index, field, translated
except Exception as e:
logger.warning(f"翻译字段 {field} (索引 {index}) 失败: {e}")
return None
# 构造所有需要翻译的任务
tasks = []
for idx, item in enumerate(data_list):
# 防御性检查:确保 item 是字典
if not isinstance(item, dict):
continue
for field in fields:
if field not in item:
continue
value = item.get(field)
# 对于 None 或空字符串的值,直接跳过,不创建任务
if value is None or value == "":
continue
tasks.append(
asyncio.create_task(
translate_single_field(idx, field, value)
)
)
# 如果没有需要翻译的任务,直接返回原列表
if not tasks:
return data_list
# 使用 gather 并发执行翻译任务(受 semaphore 限制)
# return_exceptions=True 可以防止单个任务失败导致整体失败
results = await asyncio.gather(*tasks, return_exceptions=True)
# 创建深拷贝以避免修改原始数据
translated_list = [item.copy() if isinstance(item, dict) else item for item in data_list]
# 将翻译结果回填到列表
for result in results:
# 跳过 None 结果和异常
if result is None or isinstance(result, Exception):
if isinstance(result, Exception):
logger.warning(f"翻译任务异常: {result}")
continue
idx, field, translated = result
# 防御性检查索引范围
if 0 <= idx < len(translated_list) and isinstance(translated_list[idx], dict):
translated_list[idx][field] = translated
return translated_list
@@ -426,15 +542,19 @@ class MemoryEmotion:
# 如果解析失败,返回原始字符串
return iso_string
async def get_emotion(self) -> Dict[str, Any]:
async def get_emotion(self, model_id: str = None, language_type: str = 'zh') -> Dict[str, Any]:
"""
获取情绪随时间变化数据
Args:
model_id: 模型ID用于翻译
language_type: 语言类型 ('zh' 或其他)
Returns:
包含情绪数据的字典
"""
try:
logger.info(f"获取情绪数据 - ID: {self.id}, Table: {self.table}")
logger.info(f"获取情绪数据 - ID: {self.id}, Table: {self.table}, language_type={language_type}")
if self.table == 'Statement':
results = await self.connector.execute_query(Memory_Space_Emotion_Statement, id=self.id)
@@ -450,6 +570,10 @@ class MemoryEmotion:
# 转换Neo4j类型
final_data = self._convert_neo4j_types(emotion_data)
# 如果需要翻译(非中文)
if language_type != 'zh' and model_id and final_data:
final_data = await self._translate_emotion_data(final_data, model_id)
logger.info(f"成功获取 {len(final_data)} 条情绪数据")
return final_data
@@ -590,16 +714,14 @@ class MemoryInteraction:
"""
try:
logger.info(f"获取交互数据 - ID: {self.id}, Table: {self.table}")
ori_data= await self.connector.execute_query(Memory_Space_Entity, id=self.id)
if ori_data!=[]:
# name = ori_data[0]['name']
group_id = ori_data[0]['group_id']
group_id = [i['group_id'] for i in ori_data][0]
Space_User = await self.connector.execute_query(Memory_Space_User, group_id=group_id)
if not Space_User:
return []
user_id=Space_User[0]['id']
results = await self.connector.execute_query(Memory_Space_Associative, id=self.id,user_id=user_id)

View File

@@ -506,27 +506,6 @@ async def search_edges(end_user_id: Optional[str] = None) -> List[Dict[str, Any]
return result
async def search_entity_graph(end_user_id: Optional[str] = None) -> Dict[str, Any]:
"""搜索所有实体之间的关系网络group 维度)。"""
result = await _neo4j_connector.execute_query(
DataConfigRepository.SEARCH_FOR_ENTITY_GRAPH,
group_id=end_user_id,
)
# 对source_node 和 target_node 的 fact_summary进行截取只截取前三条的内容需要提取前三条“来源”
for item in result:
source_fact = item["sourceNode"]["fact_summary"]
target_fact = item["targetNode"]["fact_summary"]
# 截取前三条“来源”
item["sourceNode"]["fact_summary"] = source_fact.split("\n")[:4] if source_fact else []
item["targetNode"]["fact_summary"] = target_fact.split("\n")[:4] if target_fact else []
# 与现有返回风格保持一致,携带搜索类型、数量与详情
data = {
"search_for": "entity_graph",
"num": len(result),
"detials": result,
}
return data
async def analytics_hot_memory_tags(
db: Session,

View File

@@ -18,7 +18,7 @@ from app.repositories.end_user_repository import EndUserRepository
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.schemas.memory_episodic_schema import EmotionSubject, EmotionType, type_mapping
from app.services.implicit_memory_service import ImplicitMemoryService
from app.services.memory_base_service import MemoryBaseService
from app.services.memory_base_service import MemoryBaseService, MemoryTransService, Translation_English
from app.services.memory_config_service import MemoryConfigService
from app.services.memory_perceptual_service import MemoryPerceptualService
from app.services.memory_short_service import ShortService
@@ -357,10 +357,107 @@ class UserMemoryService:
data[key] = UserMemoryService._datetime_to_timestamp(original_value)
return data
def update_end_user_profile(
self,
db: Session,
end_user_id: str,
profile_update: Any
) -> Dict[str, Any]:
"""
更新终端用户的基本信息
Args:
db: 数据库会话
end_user_id: 终端用户ID (UUID)
profile_update: 包含更新字段的 Pydantic 模型
Returns:
{
"success": bool,
"data": dict, # 更新后的用户档案数据
"error": Optional[str]
}
"""
try:
# 转换为UUID并查询用户
user_uuid = uuid.UUID(end_user_id)
repo = EndUserRepository(db)
end_user = repo.get_by_id(user_uuid)
if not end_user:
logger.warning(f"终端用户不存在: end_user_id={end_user_id}")
return {
"success": False,
"data": None,
"error": "终端用户不存在"
}
# 获取更新数据(排除 end_user_id 字段)
update_data = profile_update.model_dump(exclude_unset=True, exclude={'end_user_id'})
# 特殊处理 hire_date如果提供了时间戳转换为 DateTime
if 'hire_date' in update_data:
hire_date_timestamp = update_data['hire_date']
if hire_date_timestamp is not None:
from app.core.api_key_utils import timestamp_to_datetime
update_data['hire_date'] = timestamp_to_datetime(hire_date_timestamp)
# 如果是 None保持 None允许清空
# 更新字段
for field, value in update_data.items():
setattr(end_user, field, value)
# 更新时间戳
end_user.updated_at = datetime.now()
end_user.updatetime_profile = datetime.now()
# 提交更改
db.commit()
db.refresh(end_user)
# 构建响应数据
from app.schemas.end_user_schema import EndUserProfileResponse
profile_data = EndUserProfileResponse(
id=end_user.id,
other_name=end_user.other_name,
position=end_user.position,
department=end_user.department,
contact=end_user.contact,
phone=end_user.phone,
hire_date=end_user.hire_date,
updatetime_profile=end_user.updatetime_profile
)
logger.info(f"成功更新用户信息: end_user_id={end_user_id}, updated_fields={list(update_data.keys())}")
return {
"success": True,
"data": self.convert_profile_to_dict_with_timestamp(profile_data),
"error": None
}
except ValueError:
logger.error(f"无效的 end_user_id 格式: {end_user_id}")
return {
"success": False,
"data": None,
"error": "无效的用户ID格式"
}
except Exception as e:
db.rollback()
logger.error(f"用户信息更新失败: end_user_id={end_user_id}, error={str(e)}")
return {
"success": False,
"data": None,
"error": str(e)
}
async def get_cached_memory_insight(
self,
db: Session,
end_user_id: str
end_user_id: str,
model_id: str,
language_type: str
) -> Dict[str, Any]:
"""
从数据库获取缓存的记忆洞察(四个维度)
@@ -419,11 +516,18 @@ class UserMemoryService:
key_findings_array = []
logger.info(f"成功获取 end_user_id {end_user_id} 的缓存记忆洞察(四维度)")
memory_insight=end_user.memory_insight
behavior_pattern=end_user.behavior_pattern
growth_trajectory=end_user.growth_trajectory
if language_type!='zh':
memory_insight=await Translation_English(model_id,memory_insight)
behavior_pattern=await Translation_English(model_id,behavior_pattern)
growth_trajectory=await Translation_English(model_id,growth_trajectory)
return {
"memory_insight": end_user.memory_insight, # 总体概述存储在 memory_insight
"behavior_pattern": end_user.behavior_pattern,
"memory_insight":memory_insight, # 总体概述存储在 memory_insight
"behavior_pattern":behavior_pattern,
"key_findings": key_findings_array, # 返回数组
"growth_trajectory": end_user.growth_trajectory,
"growth_trajectory": growth_trajectory,
"updated_at": self._datetime_to_timestamp(end_user.memory_insight_updated_at),
"is_cached": True
}
@@ -457,7 +561,9 @@ class UserMemoryService:
async def get_cached_user_summary(
self,
db: Session,
end_user_id: str
end_user_id: str,
model_id:str,
language_type:str="zh"
) -> Dict[str, Any]:
"""
从数据库获取缓存的用户摘要(四个部分)
@@ -481,7 +587,6 @@ class UserMemoryService:
user_uuid = uuid.UUID(end_user_id)
repo = EndUserRepository(db)
end_user = repo.get_by_id(user_uuid)
if not end_user:
logger.warning(f"未找到 end_user_id 为 {end_user_id} 的用户")
return {
@@ -495,20 +600,29 @@ class UserMemoryService:
}
# 检查是否有缓存数据(至少有一个字段不为空)
user_summary=end_user.user_summary
personality_traits=end_user.personality_traits
core_values=end_user.core_values
one_sentence_summary=end_user.one_sentence_summary
if language_type!='zh':
user_summary=await Translation_English(model_id, user_summary)
personality_traits = await Translation_English(model_id, personality_traits)
core_values = await Translation_English(model_id, core_values)
one_sentence_summary = await Translation_English(model_id, one_sentence_summary)
has_cache = any([
end_user.user_summary,
end_user.personality_traits,
end_user.core_values,
end_user.one_sentence_summary
user_summary,
personality_traits,
core_values,
one_sentence_summary
])
if has_cache:
logger.info(f"成功获取 end_user_id {end_user_id} 的缓存用户摘要")
return {
"user_summary": end_user.user_summary,
"personality": end_user.personality_traits,
"core_values": end_user.core_values,
"one_sentence": end_user.one_sentence_summary,
"user_summary": user_summary,
"personality": personality_traits,
"core_values":core_values,
"one_sentence": one_sentence_summary,
"updated_at": self._datetime_to_timestamp(end_user.user_summary_updated_at),
"is_cached": True
}
@@ -1367,7 +1481,6 @@ async def analytics_memory_types(
return memory_types
async def analytics_graph_data(
db: Session,
end_user_id: str,
@@ -1557,7 +1670,7 @@ async def analytics_graph_data(
f"成功获取图数据: end_user_id={end_user_id}, "
f"nodes={len(nodes)}, edges={len(edges)}"
)
return {
"nodes": nodes,
"edges": edges,
@@ -1606,11 +1719,7 @@ async def _extract_node_properties(label: str, properties: Dict[str, Any],node_
# 获取该节点类型的白名单字段
allowed_fields = field_whitelist.get(label, [])
# 如果没有定义白名单,返回空字典(或者可以返回所有字段)
# if not allowed_fields:
# # 对于未定义的节点类型,只返回基本字段
# allowed_fields = ["name", "created_at", "caption"]
count_neo4j=f"""MATCH (n)-[r]-(m) WHERE elementId(n) ="{node_id}" RETURN count(r) AS rel_count;"""
node_results = await (_neo4j_connector.execute_query(count_neo4j))
# 提取白名单中的字段
@@ -1618,13 +1727,12 @@ async def _extract_node_properties(label: str, properties: Dict[str, Any],node_
for field in allowed_fields:
if field in properties:
value = properties[field]
if str(field) == 'entity_type':
if str(field) == 'entity_type':
value=type_mapping.get(value,'')
if str(field)=="emotion_type":
value=EmotionType.EMOTION_MAPPING.get(value)
if str(field)=="emotion_subject":
if str(field)=="emotion_subject":
value=EmotionSubject.SUBJECT_MAPPING.get(value)
# 清理 Neo4j 特殊类型
filtered_props[field] = _clean_neo4j_value(value)
filtered_props['associative_memory']=[i['rel_count'] for i in node_results][0]
return filtered_props

View File

@@ -425,24 +425,7 @@ def read_message_task(self, group_id: str, message: str, history: List[Dict[str,
db.close()
try:
# 使用 nest_asyncio 来避免事件循环冲突
try:
import nest_asyncio
nest_asyncio.apply()
except ImportError:
pass
# 尝试获取现有事件循环,如果不存在则创建新的
try:
loop = asyncio.get_event_loop()
if loop.is_closed():
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
except RuntimeError:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
result = loop.run_until_complete(_run())
result = asyncio.run(_run())
elapsed_time = time.time() - start_time
return {
@@ -455,7 +438,6 @@ def read_message_task(self, group_id: str, message: str, history: List[Dict[str,
}
except BaseException as e:
elapsed_time = time.time() - start_time
# Handle ExceptionGroup from TaskGroup
if hasattr(e, 'exceptions'):
error_messages = [f"{type(sub_e).__name__}: {str(sub_e)}" for sub_e in e.exceptions]
detailed_error = "; ".join(error_messages)
@@ -472,13 +454,19 @@ def read_message_task(self, group_id: str, message: str, history: List[Dict[str,
@celery_app.task(name="app.core.memory.agent.write_message", bind=True)
def write_message_task(self, group_id: str, message: str, config_id: str,storage_type:str,user_rag_memory_id:str) -> Dict[str, Any]:
def write_message_task(self, group_id: str, message, config_id: str, storage_type: str, user_rag_memory_id: str) -> Dict[str, Any]:
"""Celery task to process a write message via MemoryAgentService.
支持两种消息格式:
1. 字符串格式向后兼容message="user: xxx\nassistant: yyy"
2. 结构化消息列表推荐message=[{"role": "user", "content": "xxx"}, {"role": "assistant", "content": "yyy"}]
Args:
group_id: Group ID for the memory agent (also used as end_user_id)
message: Message to write
message: Message to write (str or list[dict])
config_id: Optional configuration ID
storage_type: Storage type (neo4j/rag)
user_rag_memory_id: RAG memory ID
Returns:
Dict containing the result and metadata
@@ -522,24 +510,7 @@ def write_message_task(self, group_id: str, message: str, config_id: str,storage
db.close()
try:
# 使用 nest_asyncio 来避免事件循环冲突
try:
import nest_asyncio
nest_asyncio.apply()
except ImportError:
pass
# 尝试获取现有事件循环,如果不存在则创建新的
try:
loop = asyncio.get_event_loop()
if loop.is_closed():
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
except RuntimeError:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
result = loop.run_until_complete(_run())
result = asyncio.run(_run())
elapsed_time = time.time() - start_time
logger.info(f"[CELERY WRITE] Task completed successfully - elapsed_time={elapsed_time:.2f}s, task_id={self.request.id}")
@@ -554,7 +525,6 @@ def write_message_task(self, group_id: str, message: str, config_id: str,storage
}
except BaseException as e:
elapsed_time = time.time() - start_time
# Handle ExceptionGroup from TaskGroup
if hasattr(e, 'exceptions'):
error_messages = [f"{type(sub_e).__name__}: {str(sub_e)}" for sub_e in e.exceptions]
detailed_error = "; ".join(error_messages)
@@ -594,53 +564,53 @@ def reflection_timer_task() -> None:
"""
reflection_engine()
@celery_app.task(name="app.core.memory.agent.health.check_read_service")
def check_read_service_task() -> Dict[str, str]:
"""Call read_service and write latest status to Redis.
# unused task
# @celery_app.task(name="app.core.memory.agent.health.check_read_service")
# def check_read_service_task() -> Dict[str, str]:
# """Call read_service and write latest status to Redis.
Returns status data dict that gets written to Redis.
"""
client = redis.Redis(
host=settings.REDIS_HOST,
port=settings.REDIS_PORT,
db=settings.REDIS_DB,
password=settings.REDIS_PASSWORD if settings.REDIS_PASSWORD else None
)
try:
api_url = f"http://{settings.SERVER_IP}:8000/api/memory/read_service"
payload = {
"user_id": "健康检查",
"apply_id": "健康检查",
"group_id": "健康检查",
"message": "你好",
"history": [],
"search_switch": "2",
}
resp = requests.post(api_url, json=payload, timeout=15)
ok = resp.status_code == 200
status = "Success" if ok else "Fail"
msg = "接口请求成功" if ok else f"接口请求失败: {resp.status_code}"
error = "" if ok else resp.text
code = 0 if ok else 500
except Exception as e:
status = "Fail"
msg = "接口请求失败"
error = str(e)
code = 500
# Returns status data dict that gets written to Redis.
# """
# client = redis.Redis(
# host=settings.REDIS_HOST,
# port=settings.REDIS_PORT,
# db=settings.REDIS_DB,
# password=settings.REDIS_PASSWORD if settings.REDIS_PASSWORD else None
# )
# try:
# api_url = f"http://{settings.SERVER_IP}:8000/api/memory/read_service"
# payload = {
# "user_id": "健康检查",
# "apply_id": "健康检查",
# "group_id": "健康检查",
# "message": "你好",
# "history": [],
# "search_switch": "2",
# }
# resp = requests.post(api_url, json=payload, timeout=15)
# ok = resp.status_code == 200
# status = "Success" if ok else "Fail"
# msg = "接口请求成功" if ok else f"接口请求失败: {resp.status_code}"
# error = "" if ok else resp.text
# code = 0 if ok else 500
# except Exception as e:
# status = "Fail"
# msg = "接口请求失败"
# error = str(e)
# code = 500
data = {
"status": status,
"msg": msg,
"error": error,
"code": str(code),
"time": str(int(time.time())),
}
# data = {
# "status": status,
# "msg": msg,
# "error": error,
# "code": str(code),
# "time": str(int(time.time())),
# }
client.hset("memsci:health:read_service", mapping=data)
client.expire("memsci:health:read_service", int(settings.HEALTH_CHECK_SECONDS))
# client.hset("memsci:health:read_service", mapping=data)
# client.expire("memsci:health:read_service", int(settings.HEALTH_CHECK_SECONDS))
return data
# return data
@celery_app.task(name="app.controllers.memory_storage_controller.search_all")
@@ -905,24 +875,7 @@ def regenerate_memory_cache(self) -> Dict[str, Any]:
}
try:
# 使用 nest_asyncio 来避免事件循环冲突
try:
import nest_asyncio
nest_asyncio.apply()
except ImportError:
pass
# 尝试获取现有事件循环,如果不存在则创建新的
try:
loop = asyncio.get_event_loop()
if loop.is_closed():
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
except RuntimeError:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
result = loop.run_until_complete(_run())
result = asyncio.run(_run())
elapsed_time = time.time() - start_time
result["elapsed_time"] = elapsed_time
result["task_id"] = self.request.id
@@ -1049,24 +1002,7 @@ def workspace_reflection_task(self) -> Dict[str, Any]:
}
try:
# 使用 nest_asyncio 来避免事件循环冲突
try:
import nest_asyncio
nest_asyncio.apply()
except ImportError:
pass
# 尝试获取现有事件循环,如果不存在则创建新的
try:
loop = asyncio.get_event_loop()
if loop.is_closed():
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
except RuntimeError:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
result = loop.run_until_complete(_run())
result = asyncio.run(_run())
elapsed_time = time.time() - start_time
result["elapsed_time"] = elapsed_time
result["task_id"] = self.request.id
@@ -1142,11 +1078,4 @@ def run_forgetting_cycle_task(self, config_id: Optional[int] = None) -> Dict[str
"duration_seconds": duration
}
# 运行异步函数
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
result = loop.run_until_complete(_run())
return result
finally:
loop.close()
return asyncio.run(_run())

View File

@@ -7,10 +7,6 @@ services:
- "8002:8000"
env_file:
- .env
environment:
- SERVER_IP=0.0.0.0
# 如果代码里必须要 MCP_SERVER_URL可以先注释或指向占位
# - MCP_SERVER_URL=
volumes:
- ./files:/files
- /etc/localtime:/etc/localtime:ro
@@ -19,20 +15,53 @@ services:
networks:
- default
- celery
depends_on:
- worker-memory
- worker-document
# Celery worker
worker:
# Memory worker - Memory read/write tasks (threads pool for asyncio)
worker-memory:
image: redbear-mem-open:latest
container_name: worker
container_name: worker-memory
env_file:
- .env
volumes:
- ./files:/files
- /etc/localtime:/etc/localtime:ro
command: celery -A app.celery_worker.celery_app worker --loglevel=info
command: celery -A app.celery_worker.celery_app worker -E --loglevel=info --pool=threads --concurrency=100 --queues=memory_tasks -n memory_worker@%h
restart: unless-stopped
networks:
- celery
# Document worker - Document parsing tasks (prefork for CPU-bound)
worker-document:
image: redbear-mem-open:latest
container_name: worker-document
env_file:
- .env
volumes:
- ./files:/files
- /etc/localtime:/etc/localtime:ro
command: celery -A app.celery_worker.celery_app worker -E --loglevel=info --pool=prefork --concurrency=4 --queues=document_tasks --max-tasks-per-child=100 -n document_worker@%h
restart: unless-stopped
networks:
- celery
# Celery Beat - scheduler
beat:
image: redbear-mem-open:latest
container_name: celery-beat
env_file:
- .env
volumes:
- ./files:/files
- /etc/localtime:/etc/localtime:ro
command: celery -A app.celery_worker.celery_app beat --loglevel=info
restart: unless-stopped
networks:
- celery
depends_on:
- worker-memory
networks:
celery:

View File

@@ -139,6 +139,7 @@ dependencies = [
"xlrd==2.0.2",
"deprecated>=1.3.1",
"oss2>=2.19.1",
"flower>=2.0.1",
]
[tool.pytest.ini_options]

View File

@@ -6,6 +6,7 @@ async-timeout==5.0.1
bcrypt==5.0.0
billiard==4.2.2
celery==5.5.3
flower==2.0.1
cffi==2.0.0
click==8.3.0
click-didyoumean==0.3.1