dev新增短期记忆功能 (#47)

* dev新增短期记忆功能

* dev新增短期记忆功能

* dev新增短期记忆功能

* dev新增短期记忆功能

* dev新增短期记忆功能

* dev新增短期记忆功能

* dev新增短期记忆功能
This commit is contained in:
lixinyue11
2026-01-07 16:36:11 +08:00
committed by GitHub
parent 5fe8043ff8
commit bcb3d587a1
9 changed files with 765 additions and 45 deletions

View File

@@ -24,6 +24,7 @@ from . import (
memory_storage_controller, memory_storage_controller,
memory_dashboard_controller, memory_dashboard_controller,
memory_reflection_controller, memory_reflection_controller,
memory_short_term_controller,
api_key_controller, api_key_controller,
release_share_controller, release_share_controller,
public_share_controller, public_share_controller,
@@ -71,6 +72,7 @@ manager_router.include_router(emotion_controller.router)
manager_router.include_router(emotion_config_controller.router) manager_router.include_router(emotion_config_controller.router)
manager_router.include_router(prompt_optimizer_controller.router) manager_router.include_router(prompt_optimizer_controller.router)
manager_router.include_router(memory_reflection_controller.router) manager_router.include_router(memory_reflection_controller.router)
manager_router.include_router(memory_short_term_controller.router)
manager_router.include_router(tool_controller.router) manager_router.include_router(tool_controller.router)
manager_router.include_router(memory_forget_controller.router) manager_router.include_router(memory_forget_controller.router)
manager_router.include_router(home_page_controller.router) manager_router.include_router(home_page_controller.router)

View File

@@ -0,0 +1,44 @@
from fastapi import APIRouter, Depends, HTTPException, status
from app.core.logging_config import get_api_logger
from app.core.response_utils import success
from app.db import get_db
from app.dependencies import get_current_user
from app.models.user_model import User
from app.services.memory_storage_service import search_entity
from app.services.memory_short_service import ShortService,LongService
from dotenv import load_dotenv
from sqlalchemy.orm import Session
from typing import Optional
load_dotenv()
api_logger = get_api_logger()
router = APIRouter(
prefix="/memory/short",
tags=["Memory"],
)
@router.get("/short_term")
async def short_term_configs(
end_user_id: str,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
):
# 获取短期记忆数据
short_term=ShortService(end_user_id)
short_result=short_term.get_short_databasets()
short_count=short_term.get_short_count()
long_term=LongService(end_user_id)
long_result=long_term.get_long_databasets()
entity_result = await search_entity(end_user_id)
result = {
'short_term': short_result,
'long_term': long_result,
'entity': entity_result.get('num', 0),
"retrieval_number":short_count,
"long_term_number":len(long_result)
}
return success(data=result, msg="短期记忆系统数据获取成功")

View File

@@ -7,13 +7,20 @@ LangChain Agent 封装
- 支持流式输出 - 支持流式输出
- 使用 RedBearLLM 支持多提供商 - 使用 RedBearLLM 支持多提供商
""" """
import os
import time import time
from typing import Any, AsyncGenerator, Dict, List, Optional, Sequence from typing import Any, AsyncGenerator, Dict, List, Optional, Sequence
from app.db import get_db
from app.core.logging_config import get_business_logger from app.core.logging_config import get_business_logger
from app.core.memory.agent.utils.redis_tool import store from app.core.memory.agent.utils.redis_tool import store
from app.core.models import RedBearLLM, RedBearModelConfig from app.core.models import RedBearLLM, RedBearModelConfig
from app.models.models_model import ModelType from app.models.models_model import ModelType
from app.repositories.memory_short_repository import LongTermMemoryRepository
from app.services.memory_agent_service import (
get_end_user_connected_config,
)
from app.services.memory_konwledges_server import write_rag from app.services.memory_konwledges_server import write_rag
from app.services.task_service import get_task_memory_write_result from app.services.task_service import get_task_memory_write_result
from app.tasks import write_message_task from app.tasks import write_message_task
@@ -96,7 +103,8 @@ class LangChainAgent:
"temperature": temperature, "temperature": temperature,
"streaming": streaming, "streaming": streaming,
"tool_count": len(self.tools), "tool_count": len(self.tools),
"tool_names": [tool.name for tool in self.tools] if self.tools else [] "tool_names": [tool.name for tool in self.tools] if self.tools else [],
"tool_count": len(self.tools)
} }
) )
@@ -137,11 +145,8 @@ class LangChainAgent:
messages.append(HumanMessage(content=user_content)) messages.append(HumanMessage(content=user_content))
return messages return messages
async def term_memory_save(self,messages,end_user_end,aimessages): async def term_memory_save(self,messages,end_user_end,aimessages):
""" '''短长期存储redis为不影响正常使用6句一段话存储用户名加一个前缀当数据存够6条返回给neo4j'''
短长期存储redis为不影响正常使用6句一段话存储用户名加一个前缀当数据存够6条返回给neo4j
"""
end_user_end=f"Term_{end_user_end}" end_user_end=f"Term_{end_user_end}"
print(messages) print(messages)
print(aimessages) print(aimessages)
@@ -155,17 +160,18 @@ class LangChainAgent:
store.delete_duplicate_sessions() store.delete_duplicate_sessions()
# logger.info(f'Redis_Agent:{end_user_end};{session_id}') # logger.info(f'Redis_Agent:{end_user_end};{session_id}')
return session_id return session_id
async def term_memory_redis_read(self,end_user_end): async def term_memory_redis_read(self,end_user_end):
end_user_end = f"Term_{end_user_end}" end_user_end = f"Term_{end_user_end}"
history = store.find_user_apply_group(end_user_end, end_user_end, end_user_end) history = store.find_user_apply_group(end_user_end, end_user_end, end_user_end)
# logger.info(f'Redis_Agent:{end_user_end};{history}') # logger.info(f'Redis_Agent:{end_user_end};{history}')
messagss_list=[] messagss_list=[]
retrieved_content=[]
for messages in history: for messages in history:
query = messages.get("Query") query = messages.get("Query")
aimessages = messages.get("Answer") aimessages = messages.get("Answer")
messagss_list.append(f'用户:{query}。AI回复:{aimessages}') messagss_list.append(f'用户:{query}。AI回复:{aimessages}')
return messagss_list retrieved_content.append({query: aimessages})
return messagss_list,retrieved_content
async def write(self,storage_type,end_user_id,message,user_rag_memory_id,actual_end_user_id,content,actual_config_id): async def write(self,storage_type,end_user_id,message,user_rag_memory_id,actual_end_user_id,content,actual_config_id):
@@ -205,7 +211,6 @@ class LangChainAgent:
# If config_id is None, try to get from end_user's connected config # If config_id is None, try to get from end_user's connected config
if actual_config_id is None and end_user_id: if actual_config_id is None and end_user_id:
try: try:
from app.db import get_db
from app.services.memory_agent_service import ( from app.services.memory_agent_service import (
get_end_user_connected_config, get_end_user_connected_config,
) )
@@ -223,11 +228,26 @@ class LangChainAgent:
logger.info(f'写入类型{storage_type,str(end_user_id), message, str(user_rag_memory_id)}') logger.info(f'写入类型{storage_type,str(end_user_id), message, str(user_rag_memory_id)}')
print(f'写入类型{storage_type,str(end_user_id), message, str(user_rag_memory_id)}') print(f'写入类型{storage_type,str(end_user_id), message, str(user_rag_memory_id)}')
history_term_memory=await self.term_memory_redis_read(end_user_id) history_term_memory_result = await self.term_memory_redis_read(end_user_id)
history_term_memory = history_term_memory_result[0]
db_for_memory = next(get_db())
if memory_flag: if memory_flag:
if len(history_term_memory)>=4 and storage_type != "rag": if len(history_term_memory)>=4 and storage_type != "rag":
history_term_memory=';'.join(history_term_memory) history_term_memory = ';'.join(history_term_memory)
logger.info(f'写入短长期:{storage_type, str(end_user_id), history_term_memory, str(user_rag_memory_id)}') retrieved_content = history_term_memory_result[1]
print(retrieved_content)
# 为长期记忆操作获取新的数据库连接
try:
repo = LongTermMemoryRepository(db_for_memory)
repo.upsert(end_user_id, retrieved_content)
logger.info(
f'写入短长期:{storage_type, str(end_user_id), history_term_memory, str(user_rag_memory_id)}')
except Exception as e:
logger.error(f"Failed to write to LongTermMemory: {e}")
raise
finally:
db_for_memory.close()
await self.write(storage_type,end_user_id,history_term_memory,user_rag_memory_id,actual_end_user_id,history_term_memory,actual_config_id) await self.write(storage_type,end_user_id,history_term_memory,user_rag_memory_id,actual_end_user_id,history_term_memory,actual_config_id)
await self.write(storage_type,end_user_id,message,user_rag_memory_id,actual_end_user_id,message,actual_config_id) await self.write(storage_type,end_user_id,message,user_rag_memory_id,actual_end_user_id,message,actual_config_id)
try: try:
@@ -316,10 +336,6 @@ class LangChainAgent:
# If config_id is None, try to get from end_user's connected config # If config_id is None, try to get from end_user's connected config
if actual_config_id is None and end_user_id: if actual_config_id is None and end_user_id:
try: try:
from app.db import get_db
from app.services.memory_agent_service import (
get_end_user_connected_config,
)
db = next(get_db()) db = next(get_db())
try: try:
connected_config = get_end_user_connected_config(end_user_id, db) connected_config = get_end_user_connected_config(end_user_id, db)
@@ -331,14 +347,24 @@ class LangChainAgent:
except Exception as e: except Exception as e:
logger.warning(f"Failed to get db session: {e}") logger.warning(f"Failed to get db session: {e}")
history_term_memory = await self.term_memory_redis_read(end_user_id) history_term_memory_result = await self.term_memory_redis_read(end_user_id)
history_term_memory = history_term_memory_result[0]
if memory_flag: if memory_flag:
if len(history_term_memory) >= 4 and storage_type != "rag": if len(history_term_memory) >= 4 and storage_type != "rag":
history_term_memory = ';'.join(history_term_memory) history_term_memory = ';'.join(history_term_memory)
logger.info( retrieved_content = history_term_memory_result[1]
f'写入短长期:{storage_type, str(end_user_id), history_term_memory, str(user_rag_memory_id)}') db_for_memory = next(get_db())
await self.write(storage_type, end_user_id, history_term_memory, user_rag_memory_id, end_user_id, try:
history_term_memory, actual_config_id) repo = LongTermMemoryRepository(db_for_memory)
repo.upsert(end_user_id, retrieved_content)
logger.info(
f'写入短长期:{storage_type, str(end_user_id), history_term_memory, str(user_rag_memory_id)}')
await self.write(storage_type, end_user_id, history_term_memory, user_rag_memory_id, end_user_id,
history_term_memory, actual_config_id)
except Exception as e:
logger.error(f"Failed to write to long term memory: {e}")
finally:
db_for_memory.close()
await self.write(storage_type, end_user_id, message, user_rag_memory_id, end_user_id, message, actual_config_id) await self.write(storage_type, end_user_id, message, user_rag_memory_id, end_user_id, message, actual_config_id)
try: try:

View File

@@ -6,6 +6,7 @@ from .document_model import Document
from .file_model import File from .file_model import File
from .generic_file_model import GenericFile from .generic_file_model import GenericFile
from .models_model import ModelConfig, ModelProvider, ModelType, ModelApiKey from .models_model import ModelConfig, ModelProvider, ModelType, ModelApiKey
from .memory_short_model import ShortTermMemory, LongTermMemory
from .knowledgeshare_model import KnowledgeShare from .knowledgeshare_model import KnowledgeShare
from .app_model import App from .app_model import App
from .agent_app_config_model import AgentConfig from .agent_app_config_model import AgentConfig
@@ -67,6 +68,8 @@ __all__ = [
"BuiltinToolConfig", "BuiltinToolConfig",
"CustomToolConfig", "CustomToolConfig",
"MCPToolConfig", "MCPToolConfig",
"ShortTermMemory",
"LongTermMemory",
"ToolExecution", "ToolExecution",
"ToolType", "ToolType",
"ToolStatus", "ToolStatus",

View File

@@ -0,0 +1,60 @@
"""
记忆模型 - 短期记忆和长期记忆表
"""
import uuid
import datetime
from sqlalchemy import Column, String, DateTime, Text, JSON
from sqlalchemy.dialects.postgresql import UUID
from app.db import Base
class ShortTermMemory(Base):
"""短期记忆表
用于存储临时的对话记忆,通常保存较短时间
"""
__tablename__ = "memory_short_term"
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4, index=True, comment="记忆ID")
# 用户信息
end_user_id = Column(String(255), nullable=False, index=True, comment="终端用户ID")
# 对话内容
messages = Column(Text, nullable=False, comment="用户消息内容")
aimessages = Column(Text, nullable=True, comment="AI回复消息内容")
# 搜索开关
search_switch = Column(String(50), nullable=True, comment="搜索开关状态")
# 检索内容 - 存储为JSON格式的列表包含字典 [{}, {}]
retrieved_content = Column(JSON, nullable=True, default=list, comment="检索到的相关内容,格式为[{}, {}]")
# 时间戳
created_at = Column(DateTime, default=datetime.datetime.now, nullable=False, index=True, comment="创建时间")
def __repr__(self):
return f"<ShortTermMemory(id={self.id}, end_user_id={self.end_user_id}, created_at={self.created_at})>"
class LongTermMemory(Base):
"""长期记忆表
用于存储重要的对话记忆,长期保存
"""
__tablename__ = "memory_long_term"
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4, index=True, comment="记忆ID")
# 用户信息
end_user_id = Column(String(255), nullable=False, index=True, comment="终端用户ID")
# 检索内容 - 存储为JSON格式的列表包含字典 [{}, {}]
retrieved_content = Column(JSON, nullable=True, default=list, comment="检索到的相关内容,格式为[{}, {}]")
# 时间戳
created_at = Column(DateTime, default=datetime.datetime.now, nullable=False, index=True, comment="创建时间")
def __repr__(self):
return f"<LongTermMemory(id={self.id}, end_user_id={self.end_user_id}, created_at={self.created_at})>"

View File

@@ -0,0 +1,503 @@
"""
记忆仓储模块 - 短期记忆和长期记忆的数据访问层
"""
from sqlalchemy.orm import Session
from typing import List, Optional, Dict, Any
import uuid
import datetime
from app.models.memory_short_model import ShortTermMemory, LongTermMemory
from app.core.logging_config import get_db_logger
# 获取数据库专用日志器
db_logger = get_db_logger()
class ShortTermMemoryRepository:
"""短期记忆仓储类"""
def __init__(self, db: Session):
self.db = db
def create(self, end_user_id: str, messages: str, aimessages: str = None, search_switch: str = None, retrieved_content: List[Dict] = None) -> ShortTermMemory:
"""创建短期记忆记录
Args:
end_user_id: 终端用户ID
messages: 用户消息内容
aimessages: AI回复消息内容
search_switch: 搜索开关状态
retrieved_content: 检索到的相关内容,格式为[{}, {}]
Returns:
ShortTermMemory: 创建的短期记忆对象
"""
try:
memory = ShortTermMemory(
end_user_id=end_user_id,
messages=messages,
aimessages=aimessages,
search_switch=search_switch,
retrieved_content=retrieved_content or []
)
self.db.add(memory)
self.db.commit()
self.db.refresh(memory)
db_logger.info(f"成功创建短期记忆记录: {memory.id} for user {end_user_id}")
return memory
except Exception as e:
self.db.rollback()
db_logger.error(f"创建短期记忆记录时出错: {str(e)}")
raise
def count_by_user_id(self,end_user_id: str) -> int:
"""根据ID获取短期记忆记录
Args:
memory_id: 记忆ID
Returns:
Optional[ShortTermMemory]: 记忆对象如果不存在则返回None
"""
try:
count = (
self.db.query(ShortTermMemory)
.filter(ShortTermMemory.end_user_id == end_user_id)
.count()
)
db_logger.debug(f"成功统计用户 {end_user_id} 的短期记忆数量: {count}")
return count
except Exception as e:
self.db.rollback()
db_logger.error(f"查询短期记忆记录 {count} 时出错: {str(e)}")
raise
def get_latest_by_user_id(self, end_user_id: str, limit: int = 5) -> List[ShortTermMemory]:
"""获取用户最新的短期记忆记录
Args:
end_user_id: 终端用户ID
limit: 返回记录数限制默认5条
Returns:
List[ShortTermMemory]: 最新的记忆记录列表,按创建时间倒序
"""
try:
# 使用复合索引 ix_memory_short_term_user_time 优化查询
memories = (
self.db.query(ShortTermMemory)
.filter(ShortTermMemory.end_user_id == end_user_id)
.order_by(ShortTermMemory.created_at.desc())
.limit(limit)
.all()
)
db_logger.info(f"成功查询用户 {end_user_id} 的最新 {len(memories)} 条短期记忆记录")
return memories
except Exception as e:
self.db.rollback()
db_logger.error(f"查询用户 {end_user_id} 的最新短期记忆记录时出错: {str(e)}")
raise
def get_recent_by_user_id(self, end_user_id: str, hours: int = 24) -> List[ShortTermMemory]:
"""获取用户最近指定小时内的短期记忆记录
Args:
end_user_id: 终端用户ID
hours: 时间范围小时默认24小时
Returns:
List[ShortTermMemory]: 记忆记录列表,按创建时间倒序
"""
try:
cutoff_time = datetime.datetime.now() - datetime.timedelta(hours=hours)
# 使用复合索引 ix_memory_short_term_user_time 优化查询
memories = (
self.db.query(ShortTermMemory)
.filter(
ShortTermMemory.end_user_id == end_user_id,
ShortTermMemory.created_at >= cutoff_time
)
.order_by(ShortTermMemory.created_at.desc())
.all()
)
db_logger.info(f"成功查询用户 {end_user_id} 最近 {hours} 小时的 {len(memories)} 条短期记忆记录")
return memories
except Exception as e:
self.db.rollback()
db_logger.error(f"查询用户 {end_user_id} 最近 {hours} 小时的短期记忆记录时出错: {str(e)}")
raise
def delete_by_id(self, memory_id: uuid.UUID) -> bool:
"""删除指定ID的短期记忆记录
Args:
memory_id: 记忆ID
Returns:
bool: 删除成功返回True否则返回False
"""
try:
deleted_count = (
self.db.query(ShortTermMemory)
.filter(ShortTermMemory.id == memory_id)
.delete(synchronize_session=False)
)
self.db.commit()
if deleted_count > 0:
db_logger.info(f"成功删除短期记忆记录 {memory_id}")
return True
else:
db_logger.warning(f"未找到短期记忆记录 {memory_id},无法删除")
return False
except Exception as e:
self.db.rollback()
db_logger.error(f"删除短期记忆记录 {memory_id} 时出错: {str(e)}")
raise
def delete_old_memories(self, days: int = 7) -> int:
"""删除指定天数之前的短期记忆记录
Args:
days: 保留天数默认7天
Returns:
int: 删除的记录数
"""
try:
cutoff_time = datetime.datetime.now() - datetime.timedelta(days=days)
deleted_count = (
self.db.query(ShortTermMemory)
.filter(ShortTermMemory.created_at < cutoff_time)
.delete(synchronize_session=False)
)
self.db.commit()
db_logger.info(f"成功删除 {days} 天前的 {deleted_count} 条短期记忆记录")
return deleted_count
except Exception as e:
self.db.rollback()
db_logger.error(f"删除 {days} 天前的短期记忆记录时出错: {str(e)}")
raise
def upsert(self, end_user_id: str, messages: str, aimessages: str = None, search_switch: str = None, retrieved_content: List[Dict] = None) -> ShortTermMemory:
"""创建或更新短期记忆记录
根据 end_user_id、messages 和 aimessages 查找现有记录:
- 如果找到匹配的记录,则更新 messages、aimessages、search_switch 和 retrieved_content
- 如果没有找到匹配的记录,则创建新记录
Args:
end_user_id: 终端用户ID
messages: 用户消息内容
aimessages: AI回复消息内容
search_switch: 搜索开关状态
retrieved_content: 检索到的相关内容,格式为[{}, {}]
Returns:
ShortTermMemory: 创建或更新的短期记忆对象
"""
try:
# 构建查询条件,使用复合索引 ix_memory_short_term_user_messages 优化查询
query_filters = [
ShortTermMemory.end_user_id == end_user_id,
ShortTermMemory.messages == messages
]
# 如果 aimessages 不为空,则加入查询条件
if aimessages is not None:
query_filters.append(ShortTermMemory.aimessages == aimessages)
else:
# 如果 aimessages 为 None则查找 aimessages 为 NULL 的记录
query_filters.append(ShortTermMemory.aimessages.is_(None))
# 查找现有记录
existing_memory = (
self.db.query(ShortTermMemory)
.filter(*query_filters)
.first()
)
if existing_memory:
# 更新现有记录
existing_memory.messages = messages
existing_memory.aimessages = aimessages
existing_memory.search_switch = search_switch
existing_memory.retrieved_content = retrieved_content or []
self.db.commit()
self.db.refresh(existing_memory)
db_logger.info(f"成功更新短期记忆记录: {existing_memory.id} for user {end_user_id}")
return existing_memory
else:
# 创建新记录
new_memory = ShortTermMemory(
end_user_id=end_user_id,
messages=messages,
aimessages=aimessages,
search_switch=search_switch,
retrieved_content=retrieved_content or []
)
self.db.add(new_memory)
self.db.commit()
self.db.refresh(new_memory)
db_logger.info(f"成功创建新的短期记忆记录: {new_memory.id} for user {end_user_id}")
return new_memory
except Exception as e:
self.db.rollback()
db_logger.error(f"创建或更新短期记忆记录时出错: {str(e)}")
raise
class LongTermMemoryRepository:
"""长期记忆仓储类"""
def __init__(self, db: Session):
self.db = db
def create(self, end_user_id: str, retrieved_content: List[Dict] = None) -> LongTermMemory:
"""创建长期记忆记录
Args:
end_user_id: 终端用户ID
retrieved_content: 检索到的相关内容,格式为[{}, {}]
Returns:
LongTermMemory: 创建的长期记忆对象
"""
try:
memory = LongTermMemory(
end_user_id=end_user_id,
retrieved_content=retrieved_content or []
)
self.db.add(memory)
self.db.commit()
self.db.refresh(memory)
db_logger.info(f"成功创建长期记忆记录: {memory.id} for user {end_user_id}")
return memory
except Exception as e:
self.db.rollback()
db_logger.error(f"创建长期记忆记录时出错: {str(e)}")
raise
def get_by_id(self, memory_id: uuid.UUID) -> Optional[LongTermMemory]:
"""根据ID获取长期记忆记录
Args:
memory_id: 记忆ID
Returns:
Optional[LongTermMemory]: 记忆对象如果不存在则返回None
"""
try:
memory = (
self.db.query(LongTermMemory)
.filter(LongTermMemory.id == memory_id)
.first()
)
if memory:
db_logger.debug(f"成功查询到长期记忆记录 {memory_id}")
else:
db_logger.debug(f"未找到长期记忆记录 {memory_id}")
return memory
except Exception as e:
self.db.rollback()
db_logger.error(f"查询长期记忆记录 {memory_id} 时出错: {str(e)}")
raise
def get_by_user_id(self, end_user_id: str, limit: int = 100, offset: int = 0) -> List[LongTermMemory]:
"""根据用户ID获取长期记忆记录列表
Args:
end_user_id: 终端用户ID
limit: 返回记录数限制默认100
offset: 偏移量默认0
Returns:
List[LongTermMemory]: 记忆记录列表,按创建时间倒序
"""
try:
# 使用复合索引 ix_memory_long_term_user_time 优化查询
memories = (
self.db.query(LongTermMemory)
.filter(LongTermMemory.end_user_id == end_user_id)
.order_by(LongTermMemory.created_at.desc())
.limit(limit)
.offset(offset)
.all()
)
db_logger.info(f"成功查询用户 {end_user_id}{len(memories)} 条长期记忆记录")
return memories
except Exception as e:
self.db.rollback()
db_logger.error(f"查询用户 {end_user_id} 的长期记忆记录时出错: {str(e)}")
raise
def search_by_content(self, end_user_id: str, keyword: str, limit: int = 50) -> List[LongTermMemory]:
"""根据内容关键词搜索长期记忆记录
Args:
end_user_id: 终端用户ID
keyword: 搜索关键词
limit: 返回记录数限制默认50
Returns:
List[LongTermMemory]: 匹配的记忆记录列表,按创建时间倒序
"""
try:
# 使用 GIN 索引 ix_memory_long_term_retrieved_content_gin 优化 JSON 搜索
# 同时使用复合索引 ix_memory_long_term_user_time 优化用户过滤
memories = (
self.db.query(LongTermMemory)
.filter(
LongTermMemory.end_user_id == end_user_id,
LongTermMemory.retrieved_content.astext.contains(keyword)
)
.order_by(LongTermMemory.created_at.desc())
.limit(limit)
.all()
)
db_logger.info(f"成功搜索用户 {end_user_id} 包含关键词 '{keyword}'{len(memories)} 条长期记忆记录")
return memories
except Exception as e:
self.db.rollback()
db_logger.error(f"搜索用户 {end_user_id} 包含关键词 '{keyword}' 的长期记忆记录时出错: {str(e)}")
raise
def delete_by_id(self, memory_id: uuid.UUID) -> bool:
"""删除指定ID的长期记忆记录
Args:
memory_id: 记忆ID
Returns:
bool: 删除成功返回True否则返回False
"""
try:
deleted_count = (
self.db.query(LongTermMemory)
.filter(LongTermMemory.id == memory_id)
.delete(synchronize_session=False)
)
self.db.commit()
if deleted_count > 0:
db_logger.info(f"成功删除长期记忆记录 {memory_id}")
return True
else:
db_logger.warning(f"未找到长期记忆记录 {memory_id},无法删除")
return False
except Exception as e:
self.db.rollback()
db_logger.error(f"删除长期记忆记录 {memory_id} 时出错: {str(e)}")
raise
def count_by_user_id(self, end_user_id: str) -> int:
"""统计用户的长期记忆记录数量
Args:
end_user_id: 终端用户ID
Returns:
int: 记录数量
"""
try:
count = (
self.db.query(LongTermMemory)
.filter(LongTermMemory.end_user_id == end_user_id)
.count()
)
db_logger.debug(f"用户 {end_user_id} 共有 {count} 条长期记忆记录")
return count
except Exception as e:
self.db.rollback()
db_logger.error(f"统计用户 {end_user_id} 的长期记忆记录数量时出错: {str(e)}")
raise
def upsert(self, end_user_id: str, retrieved_content: List[Dict] = None) -> Optional[LongTermMemory]:
"""创建或更新长期记忆记录
根据 end_user_id 和 retrieved_content 判断是否需要写入:
- 如果找到相同的 end_user_id 和 retrieved_content则不写入返回 None
- 如果没有找到相同的记录,则创建新记录
Args:
end_user_id: 终端用户ID
retrieved_content: 检索到的相关内容,格式为[{}, {}]
Returns:
Optional[LongTermMemory]: 创建的长期记忆对象,如果不需要写入则返回 None
"""
try:
retrieved_content = retrieved_content or []
# 优化查询:使用复合索引 ix_memory_long_term_user_time 先过滤用户
# 然后在应用层比较 JSON 内容,避免复杂的数据库 JSON 比较
existing_memories = (
self.db.query(LongTermMemory)
.filter(LongTermMemory.end_user_id == end_user_id)
.order_by(LongTermMemory.created_at.desc())
.limit(100) # 限制查询数量,避免加载过多数据
.all()
)
# 在 Python 中比较 retrieved_content
for memory in existing_memories:
if memory.retrieved_content == retrieved_content:
# 如果找到相同的记录,不写入
db_logger.info(f"长期记忆记录已存在,跳过写入: user {end_user_id}")
return None
# 如果没有找到相同的记录,创建新记录
new_memory = LongTermMemory(
end_user_id=end_user_id,
retrieved_content=retrieved_content
)
self.db.add(new_memory)
self.db.commit()
self.db.refresh(new_memory)
db_logger.info(f"成功创建新的长期记忆记录: {new_memory.id} for user {end_user_id}")
return new_memory
except Exception as e:
self.db.rollback()
db_logger.error(f"创建或更新长期记忆记录时出错: {str(e)}")
raise

View File

@@ -4,6 +4,7 @@ Memory Agent Service
Handles business logic for memory agent operations including read/write services, Handles business logic for memory agent operations including read/write services,
health checks, and message type classification. health checks, and message type classification.
""" """
import datetime
import json import json
import os import os
import re import re
@@ -24,6 +25,7 @@ from app.core.memory.analytics.hot_memory_tags import get_hot_memory_tags
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.db import get_db_context from app.db import get_db_context
from app.models.knowledge_model import Knowledge, KnowledgeType from app.models.knowledge_model import Knowledge, KnowledgeType
from app.repositories.memory_short_repository import ShortTermMemoryRepository
from app.repositories.neo4j.neo4j_connector import Neo4jConnector from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.schemas.memory_config_schema import ConfigurationError, MemoryConfig from app.schemas.memory_config_schema import ConfigurationError, MemoryConfig
from app.services.memory_config_service import MemoryConfigService from app.services.memory_config_service import MemoryConfigService
@@ -393,7 +395,7 @@ class MemoryAgentService:
import time import time
start_time = time.time() start_time = time.time()
ori_message=message
# Resolve config_id if None using end_user's connected config # Resolve config_id if None using end_user's connected config
if config_id is None: if config_id is None:
try: try:
@@ -406,15 +408,15 @@ class MemoryAgentService:
raise # Re-raise our specific error raise # Re-raise our specific error
logger.error(f"Failed to get connected config for end_user {group_id}: {e}") logger.error(f"Failed to get connected config for end_user {group_id}: {e}")
raise ValueError(f"Unable to determine memory configuration for end_user {group_id}: {e}") raise ValueError(f"Unable to determine memory configuration for end_user {group_id}: {e}")
logger.info(f"Read operation for group {group_id} with config_id {config_id}") logger.info(f"Read operation for group {group_id} with config_id {config_id}")
# 导入审计日志记录器 # 导入审计日志记录器
try: try:
from app.core.memory.utils.log.audit_logger import audit_logger from app.core.memory.utils.log.audit_logger import audit_logger
except ImportError: except ImportError:
audit_logger = None audit_logger = None
# Get group lock to prevent concurrent processing # Get group lock to prevent concurrent processing
group_lock = self.get_group_lock(group_id) group_lock = self.get_group_lock(group_id)
@@ -430,7 +432,7 @@ class MemoryAgentService:
except ConfigurationError as e: except ConfigurationError as e:
error_msg = f"Failed to load configuration for config_id: {config_id}: {e}" error_msg = f"Failed to load configuration for config_id: {config_id}: {e}"
logger.error(error_msg) logger.error(error_msg)
# Log failed operation # Log failed operation
if audit_logger: if audit_logger:
duration = time.time() - start_time duration = time.time() - start_time
@@ -442,9 +444,9 @@ class MemoryAgentService:
duration=duration, duration=duration,
error=error_msg error=error_msg
) )
raise ValueError(error_msg) raise ValueError(error_msg)
# Step 2: Prepare history # Step 2: Prepare history
history.append({"role": "user", "content": message}) history.append({"role": "user", "content": message})
logger.debug(f"Group ID:{group_id}, Message:{message}, History:{history}, Config ID:{config_id}") logger.debug(f"Group ID:{group_id}, Message:{message}, History:{history}, Config ID:{config_id}")
@@ -452,7 +454,7 @@ class MemoryAgentService:
# Step 3: Initialize MCP client and execute read workflow # Step 3: Initialize MCP client and execute read workflow
mcp_config = get_mcp_server_config() mcp_config = get_mcp_server_config()
client = MultiServerMCPClient(mcp_config) client = MultiServerMCPClient(mcp_config)
async with client.session('data_flow') as session: async with client.session('data_flow') as session:
logger.debug("Connected to MCP Server: data_flow") logger.debug("Connected to MCP Server: data_flow")
tools = await load_mcp_tools(session) tools = await load_mcp_tools(session)
@@ -475,7 +477,7 @@ class MemoryAgentService:
# Capture any errors from the state # Capture any errors from the state
if event.get('errors'): if event.get('errors'):
workflow_errors.extend(event.get('errors', [])) workflow_errors.extend(event.get('errors', []))
for msg in messages: for msg in messages:
msg_content = msg.content msg_content = msg.content
msg_role = msg.__class__.__name__.lower().replace("message", "") msg_role = msg.__class__.__name__.lower().replace("message", "")
@@ -483,7 +485,7 @@ class MemoryAgentService:
"role": msg_role, "role": msg_role,
"content": msg_content "content": msg_content
}) })
# Extract intermediate outputs # Extract intermediate outputs
if hasattr(msg, 'content'): if hasattr(msg, 'content'):
try: try:
@@ -496,7 +498,7 @@ class MemoryAgentService:
break break
else: else:
continue # No text block found continue # No text block found
# Try to parse content as JSON # Try to parse content as JSON
if isinstance(content_to_parse, str): if isinstance(content_to_parse, str):
try: try:
@@ -506,16 +508,16 @@ class MemoryAgentService:
if '_intermediate' in parsed: if '_intermediate' in parsed:
intermediate_data = parsed['_intermediate'] intermediate_data = parsed['_intermediate']
output_key = self._create_intermediate_key(intermediate_data) output_key = self._create_intermediate_key(intermediate_data)
if output_key not in seen_intermediates: if output_key not in seen_intermediates:
seen_intermediates.add(output_key) seen_intermediates.add(output_key)
intermediate_outputs.append(self._format_intermediate_output(intermediate_data)) intermediate_outputs.append(self._format_intermediate_output(intermediate_data))
# Check for multiple intermediate outputs (from Retrieve) # Check for multiple intermediate outputs (from Retrieve)
if '_intermediates' in parsed: if '_intermediates' in parsed:
for intermediate_data in parsed['_intermediates']: for intermediate_data in parsed['_intermediates']:
output_key = self._create_intermediate_key(intermediate_data) output_key = self._create_intermediate_key(intermediate_data)
if output_key not in seen_intermediates: if output_key not in seen_intermediates:
seen_intermediates.add(output_key) seen_intermediates.add(output_key)
intermediate_outputs.append(self._format_intermediate_output(intermediate_data)) intermediate_outputs.append(self._format_intermediate_output(intermediate_data))
@@ -523,7 +525,7 @@ class MemoryAgentService:
pass pass
except Exception as e: except Exception as e:
logger.debug(f"Failed to extract intermediate output: {e}") logger.debug(f"Failed to extract intermediate output: {e}")
workflow_duration = time.time() - start workflow_duration = time.time() - start
logger.info(f"Read graph workflow completed in {workflow_duration}s") logger.info(f"Read graph workflow completed in {workflow_duration}s")
@@ -532,7 +534,7 @@ class MemoryAgentService:
for messages in outputs: for messages in outputs:
if messages['role'] == 'tool': if messages['role'] == 'tool':
message = messages['content'] message = messages['content']
# Handle MCP content format: [{'type': 'text', 'text': '...'}] # Handle MCP content format: [{'type': 'text', 'text': '...'}]
if isinstance(message, list): if isinstance(message, list):
# Extract text from MCP content blocks # Extract text from MCP content blocks
@@ -542,7 +544,7 @@ class MemoryAgentService:
break break
else: else:
continue # No text block found continue # No text block found
try: try:
parsed = json.loads(message) if isinstance(message, str) else message parsed = json.loads(message) if isinstance(message, str) else message
if isinstance(parsed, dict): if isinstance(parsed, dict):
@@ -552,15 +554,15 @@ class MemoryAgentService:
final_answer = summary_result final_answer = summary_result
except (json.JSONDecodeError, ValueError): except (json.JSONDecodeError, ValueError):
pass pass
# 记录成功的操作 # 记录成功的操作
total_duration = time.time() - start_time total_duration = time.time() - start_time
# Check for workflow errors # Check for workflow errors
if workflow_errors: if workflow_errors:
error_details = "; ".join([f"{e['tool']}: {e['error']}" for e in workflow_errors]) error_details = "; ".join([f"{e['tool']}: {e['error']}" for e in workflow_errors])
logger.warning(f"Read workflow completed with errors: {error_details}") logger.warning(f"Read workflow completed with errors: {error_details}")
if audit_logger: if audit_logger:
audit_logger.log_operation( audit_logger.log_operation(
operation="READ", operation="READ",
@@ -577,11 +579,11 @@ class MemoryAgentService:
"errors": workflow_errors "errors": workflow_errors
} }
) )
# Raise error if no answer was produced # Raise error if no answer was produced
if not final_answer: if not final_answer:
raise ValueError(f"Read workflow failed: {error_details}") raise ValueError(f"Read workflow failed: {error_details}")
if audit_logger and not workflow_errors: if audit_logger and not workflow_errors:
audit_logger.log_operation( audit_logger.log_operation(
operation="READ", operation="READ",
@@ -596,7 +598,31 @@ class MemoryAgentService:
"has_answer": bool(final_answer) "has_answer": bool(final_answer)
} }
) )
retrieved_content=[]
repo = ShortTermMemoryRepository(db)
if str(search_switch)!="2":
for intermediate in intermediate_outputs:
intermediate_type=intermediate['type']
if intermediate_type=="search_result":
query=intermediate['query']
raw_results=intermediate['raw_results']
reranked_results=raw_results.get('reranked_results',[])
statements=[statement['statement'] for statement in reranked_results.get('statements', [])]
statements=list(set(statements))
retrieved_content.append({query:statements})
if '信息不足,无法回答' in str(final_answer) or retrieved_content!=[]:
# 使用 upsert 方法
repo.upsert(
end_user_id=group_id, # 确保这个变量在作用域内
messages=ori_message,
aimessages=final_answer,
retrieved_content=retrieved_content,
search_switch=str(search_switch)
)
print("写入成功")
return { return {
"answer": final_answer, "answer": final_answer,
"intermediate_outputs": intermediate_outputs "intermediate_outputs": intermediate_outputs

View File

@@ -0,0 +1,56 @@
from app.core.logging_config import get_api_logger
from app.db import get_db
from app.repositories.memory_short_repository import LongTermMemoryRepository
from app.repositories.memory_short_repository import ShortTermMemoryRepository
api_logger = get_api_logger()
db=next(get_db())
class ShortService:
def __init__(self, end_user_id):
self.short_repo = ShortTermMemoryRepository(db)
self.end_user_id = end_user_id
def get_short_databasets(self):
short_memories = self.short_repo.get_latest_by_user_id(self.end_user_id, 3)
short_result = []
for memory in short_memories:
deep_expanded = {} # Create a new dictionary for each memory
messages = memory.messages
aimessages = memory.aimessages
retrieved_content = memory.retrieved_content or []
api_logger.debug(f"Retrieved content: {retrieved_content}")
retrieval_source = []
for item in retrieved_content:
if isinstance(item, dict):
for key, values in item.items():
retrieval_source.append({"query": key, "retrieval": values})
deep_expanded['retrieval'] = retrieval_source
deep_expanded['message'] = messages # 修正拼写错误
deep_expanded['answer'] = aimessages
short_result.append(deep_expanded)
return short_result
def get_short_count(self):
short_count = self.short_repo.count_by_user_id(self.end_user_id)
return short_count
class LongService:
def __init__(self, end_user_id):
self.long_repo = LongTermMemoryRepository(db)
self.end_user_id = end_user_id
def get_long_databasets(self):
# 获取长期记忆数据
long_memories = self.long_repo.get_by_user_id(self.end_user_id, 1)
long_result = []
for long_memory in long_memories:
if long_memory.retrieved_content:
for memory_item in long_memory.retrieved_content:
if isinstance(memory_item, dict):
for key, values in memory_item.items():
long_result.append({"query": key, "retrieval": values})
return long_result

View File

@@ -1496,8 +1496,8 @@ def _extract_node_properties(label: str, properties: Dict[str, Any]) -> Dict[str
field_whitelist = { field_whitelist = {
"Dialogue": ["content", "created_at"], "Dialogue": ["content", "created_at"],
"Chunk": ["content", "created_at"], "Chunk": ["content", "created_at"],
"Statement": ["temporal_info", "stmt_type", "statement", "valid_at", "created_at", "caption"], "Statement": ["temporal_info", "stmt_type", "statement", "valid_at", "created_at", "caption","emotion_keywords","emotion_type","emotion_subject"],
"ExtractedEntity": ["description", "name", "entity_type", "created_at", "caption"], "ExtractedEntity": ["description", "name", "entity_type", "created_at", "caption","aliases","connect_strength"],
"MemorySummary": ["summary", "content", "created_at", "caption"] # 添加 content 字段 "MemorySummary": ["summary", "content", "created_at", "caption"] # 添加 content 字段
} }