dev新增短期记忆功能 (#47)

* dev新增短期记忆功能

* dev新增短期记忆功能

* dev新增短期记忆功能

* dev新增短期记忆功能

* dev新增短期记忆功能

* dev新增短期记忆功能

* dev新增短期记忆功能
This commit is contained in:
lixinyue11
2026-01-07 16:36:11 +08:00
committed by GitHub
parent 5fe8043ff8
commit bcb3d587a1
9 changed files with 765 additions and 45 deletions

View File

@@ -24,6 +24,7 @@ from . import (
memory_storage_controller,
memory_dashboard_controller,
memory_reflection_controller,
memory_short_term_controller,
api_key_controller,
release_share_controller,
public_share_controller,
@@ -71,6 +72,7 @@ manager_router.include_router(emotion_controller.router)
manager_router.include_router(emotion_config_controller.router)
manager_router.include_router(prompt_optimizer_controller.router)
manager_router.include_router(memory_reflection_controller.router)
manager_router.include_router(memory_short_term_controller.router)
manager_router.include_router(tool_controller.router)
manager_router.include_router(memory_forget_controller.router)
manager_router.include_router(home_page_controller.router)

View File

@@ -0,0 +1,44 @@
from fastapi import APIRouter, Depends, HTTPException, status
from app.core.logging_config import get_api_logger
from app.core.response_utils import success
from app.db import get_db
from app.dependencies import get_current_user
from app.models.user_model import User
from app.services.memory_storage_service import search_entity
from app.services.memory_short_service import ShortService,LongService
from dotenv import load_dotenv
from sqlalchemy.orm import Session
from typing import Optional
load_dotenv()
api_logger = get_api_logger()
router = APIRouter(
prefix="/memory/short",
tags=["Memory"],
)
@router.get("/short_term")
async def short_term_configs(
end_user_id: str,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
):
# 获取短期记忆数据
short_term=ShortService(end_user_id)
short_result=short_term.get_short_databasets()
short_count=short_term.get_short_count()
long_term=LongService(end_user_id)
long_result=long_term.get_long_databasets()
entity_result = await search_entity(end_user_id)
result = {
'short_term': short_result,
'long_term': long_result,
'entity': entity_result.get('num', 0),
"retrieval_number":short_count,
"long_term_number":len(long_result)
}
return success(data=result, msg="短期记忆系统数据获取成功")

View File

@@ -7,13 +7,20 @@ LangChain Agent 封装
- 支持流式输出
- 使用 RedBearLLM 支持多提供商
"""
import os
import time
from typing import Any, AsyncGenerator, Dict, List, Optional, Sequence
from app.db import get_db
from app.core.logging_config import get_business_logger
from app.core.memory.agent.utils.redis_tool import store
from app.core.models import RedBearLLM, RedBearModelConfig
from app.models.models_model import ModelType
from app.repositories.memory_short_repository import LongTermMemoryRepository
from app.services.memory_agent_service import (
get_end_user_connected_config,
)
from app.services.memory_konwledges_server import write_rag
from app.services.task_service import get_task_memory_write_result
from app.tasks import write_message_task
@@ -96,7 +103,8 @@ class LangChainAgent:
"temperature": temperature,
"streaming": streaming,
"tool_count": len(self.tools),
"tool_names": [tool.name for tool in self.tools] if self.tools else []
"tool_names": [tool.name for tool in self.tools] if self.tools else [],
"tool_count": len(self.tools)
}
)
@@ -137,11 +145,8 @@ class LangChainAgent:
messages.append(HumanMessage(content=user_content))
return messages
async def term_memory_save(self,messages,end_user_end,aimessages):
"""
短长期存储redis为不影响正常使用6句一段话存储用户名加一个前缀当数据存够6条返回给neo4j
"""
'''短长期存储redis为不影响正常使用6句一段话存储用户名加一个前缀当数据存够6条返回给neo4j'''
end_user_end=f"Term_{end_user_end}"
print(messages)
print(aimessages)
@@ -155,17 +160,18 @@ class LangChainAgent:
store.delete_duplicate_sessions()
# logger.info(f'Redis_Agent:{end_user_end};{session_id}')
return session_id
async def term_memory_redis_read(self,end_user_end):
end_user_end = f"Term_{end_user_end}"
history = store.find_user_apply_group(end_user_end, end_user_end, end_user_end)
# logger.info(f'Redis_Agent:{end_user_end};{history}')
messagss_list=[]
retrieved_content=[]
for messages in history:
query = messages.get("Query")
aimessages = messages.get("Answer")
messagss_list.append(f'用户:{query}。AI回复:{aimessages}')
return messagss_list
retrieved_content.append({query: aimessages})
return messagss_list,retrieved_content
async def write(self,storage_type,end_user_id,message,user_rag_memory_id,actual_end_user_id,content,actual_config_id):
@@ -205,7 +211,6 @@ class LangChainAgent:
# If config_id is None, try to get from end_user's connected config
if actual_config_id is None and end_user_id:
try:
from app.db import get_db
from app.services.memory_agent_service import (
get_end_user_connected_config,
)
@@ -223,11 +228,26 @@ class LangChainAgent:
logger.info(f'写入类型{storage_type,str(end_user_id), message, str(user_rag_memory_id)}')
print(f'写入类型{storage_type,str(end_user_id), message, str(user_rag_memory_id)}')
history_term_memory=await self.term_memory_redis_read(end_user_id)
history_term_memory_result = await self.term_memory_redis_read(end_user_id)
history_term_memory = history_term_memory_result[0]
db_for_memory = next(get_db())
if memory_flag:
if len(history_term_memory)>=4 and storage_type != "rag":
history_term_memory=';'.join(history_term_memory)
logger.info(f'写入短长期:{storage_type, str(end_user_id), history_term_memory, str(user_rag_memory_id)}')
history_term_memory = ';'.join(history_term_memory)
retrieved_content = history_term_memory_result[1]
print(retrieved_content)
# 为长期记忆操作获取新的数据库连接
try:
repo = LongTermMemoryRepository(db_for_memory)
repo.upsert(end_user_id, retrieved_content)
logger.info(
f'写入短长期:{storage_type, str(end_user_id), history_term_memory, str(user_rag_memory_id)}')
except Exception as e:
logger.error(f"Failed to write to LongTermMemory: {e}")
raise
finally:
db_for_memory.close()
await self.write(storage_type,end_user_id,history_term_memory,user_rag_memory_id,actual_end_user_id,history_term_memory,actual_config_id)
await self.write(storage_type,end_user_id,message,user_rag_memory_id,actual_end_user_id,message,actual_config_id)
try:
@@ -316,10 +336,6 @@ class LangChainAgent:
# If config_id is None, try to get from end_user's connected config
if actual_config_id is None and end_user_id:
try:
from app.db import get_db
from app.services.memory_agent_service import (
get_end_user_connected_config,
)
db = next(get_db())
try:
connected_config = get_end_user_connected_config(end_user_id, db)
@@ -331,14 +347,24 @@ class LangChainAgent:
except Exception as e:
logger.warning(f"Failed to get db session: {e}")
history_term_memory = await self.term_memory_redis_read(end_user_id)
history_term_memory_result = await self.term_memory_redis_read(end_user_id)
history_term_memory = history_term_memory_result[0]
if memory_flag:
if len(history_term_memory) >= 4 and storage_type != "rag":
history_term_memory = ';'.join(history_term_memory)
logger.info(
f'写入短长期:{storage_type, str(end_user_id), history_term_memory, str(user_rag_memory_id)}')
await self.write(storage_type, end_user_id, history_term_memory, user_rag_memory_id, end_user_id,
history_term_memory, actual_config_id)
retrieved_content = history_term_memory_result[1]
db_for_memory = next(get_db())
try:
repo = LongTermMemoryRepository(db_for_memory)
repo.upsert(end_user_id, retrieved_content)
logger.info(
f'写入短长期:{storage_type, str(end_user_id), history_term_memory, str(user_rag_memory_id)}')
await self.write(storage_type, end_user_id, history_term_memory, user_rag_memory_id, end_user_id,
history_term_memory, actual_config_id)
except Exception as e:
logger.error(f"Failed to write to long term memory: {e}")
finally:
db_for_memory.close()
await self.write(storage_type, end_user_id, message, user_rag_memory_id, end_user_id, message, actual_config_id)
try:

View File

@@ -6,6 +6,7 @@ from .document_model import Document
from .file_model import File
from .generic_file_model import GenericFile
from .models_model import ModelConfig, ModelProvider, ModelType, ModelApiKey
from .memory_short_model import ShortTermMemory, LongTermMemory
from .knowledgeshare_model import KnowledgeShare
from .app_model import App
from .agent_app_config_model import AgentConfig
@@ -67,6 +68,8 @@ __all__ = [
"BuiltinToolConfig",
"CustomToolConfig",
"MCPToolConfig",
"ShortTermMemory",
"LongTermMemory",
"ToolExecution",
"ToolType",
"ToolStatus",

View File

@@ -0,0 +1,60 @@
"""
记忆模型 - 短期记忆和长期记忆表
"""
import uuid
import datetime
from sqlalchemy import Column, String, DateTime, Text, JSON
from sqlalchemy.dialects.postgresql import UUID
from app.db import Base
class ShortTermMemory(Base):
"""短期记忆表
用于存储临时的对话记忆,通常保存较短时间
"""
__tablename__ = "memory_short_term"
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4, index=True, comment="记忆ID")
# 用户信息
end_user_id = Column(String(255), nullable=False, index=True, comment="终端用户ID")
# 对话内容
messages = Column(Text, nullable=False, comment="用户消息内容")
aimessages = Column(Text, nullable=True, comment="AI回复消息内容")
# 搜索开关
search_switch = Column(String(50), nullable=True, comment="搜索开关状态")
# 检索内容 - 存储为JSON格式的列表包含字典 [{}, {}]
retrieved_content = Column(JSON, nullable=True, default=list, comment="检索到的相关内容,格式为[{}, {}]")
# 时间戳
created_at = Column(DateTime, default=datetime.datetime.now, nullable=False, index=True, comment="创建时间")
def __repr__(self):
return f"<ShortTermMemory(id={self.id}, end_user_id={self.end_user_id}, created_at={self.created_at})>"
class LongTermMemory(Base):
"""长期记忆表
用于存储重要的对话记忆,长期保存
"""
__tablename__ = "memory_long_term"
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4, index=True, comment="记忆ID")
# 用户信息
end_user_id = Column(String(255), nullable=False, index=True, comment="终端用户ID")
# 检索内容 - 存储为JSON格式的列表包含字典 [{}, {}]
retrieved_content = Column(JSON, nullable=True, default=list, comment="检索到的相关内容,格式为[{}, {}]")
# 时间戳
created_at = Column(DateTime, default=datetime.datetime.now, nullable=False, index=True, comment="创建时间")
def __repr__(self):
return f"<LongTermMemory(id={self.id}, end_user_id={self.end_user_id}, created_at={self.created_at})>"

View File

@@ -0,0 +1,503 @@
"""
记忆仓储模块 - 短期记忆和长期记忆的数据访问层
"""
from sqlalchemy.orm import Session
from typing import List, Optional, Dict, Any
import uuid
import datetime
from app.models.memory_short_model import ShortTermMemory, LongTermMemory
from app.core.logging_config import get_db_logger
# 获取数据库专用日志器
db_logger = get_db_logger()
class ShortTermMemoryRepository:
"""短期记忆仓储类"""
def __init__(self, db: Session):
self.db = db
def create(self, end_user_id: str, messages: str, aimessages: str = None, search_switch: str = None, retrieved_content: List[Dict] = None) -> ShortTermMemory:
"""创建短期记忆记录
Args:
end_user_id: 终端用户ID
messages: 用户消息内容
aimessages: AI回复消息内容
search_switch: 搜索开关状态
retrieved_content: 检索到的相关内容,格式为[{}, {}]
Returns:
ShortTermMemory: 创建的短期记忆对象
"""
try:
memory = ShortTermMemory(
end_user_id=end_user_id,
messages=messages,
aimessages=aimessages,
search_switch=search_switch,
retrieved_content=retrieved_content or []
)
self.db.add(memory)
self.db.commit()
self.db.refresh(memory)
db_logger.info(f"成功创建短期记忆记录: {memory.id} for user {end_user_id}")
return memory
except Exception as e:
self.db.rollback()
db_logger.error(f"创建短期记忆记录时出错: {str(e)}")
raise
def count_by_user_id(self,end_user_id: str) -> int:
"""根据ID获取短期记忆记录
Args:
memory_id: 记忆ID
Returns:
Optional[ShortTermMemory]: 记忆对象如果不存在则返回None
"""
try:
count = (
self.db.query(ShortTermMemory)
.filter(ShortTermMemory.end_user_id == end_user_id)
.count()
)
db_logger.debug(f"成功统计用户 {end_user_id} 的短期记忆数量: {count}")
return count
except Exception as e:
self.db.rollback()
db_logger.error(f"查询短期记忆记录 {count} 时出错: {str(e)}")
raise
def get_latest_by_user_id(self, end_user_id: str, limit: int = 5) -> List[ShortTermMemory]:
"""获取用户最新的短期记忆记录
Args:
end_user_id: 终端用户ID
limit: 返回记录数限制默认5条
Returns:
List[ShortTermMemory]: 最新的记忆记录列表,按创建时间倒序
"""
try:
# 使用复合索引 ix_memory_short_term_user_time 优化查询
memories = (
self.db.query(ShortTermMemory)
.filter(ShortTermMemory.end_user_id == end_user_id)
.order_by(ShortTermMemory.created_at.desc())
.limit(limit)
.all()
)
db_logger.info(f"成功查询用户 {end_user_id} 的最新 {len(memories)} 条短期记忆记录")
return memories
except Exception as e:
self.db.rollback()
db_logger.error(f"查询用户 {end_user_id} 的最新短期记忆记录时出错: {str(e)}")
raise
def get_recent_by_user_id(self, end_user_id: str, hours: int = 24) -> List[ShortTermMemory]:
"""获取用户最近指定小时内的短期记忆记录
Args:
end_user_id: 终端用户ID
hours: 时间范围小时默认24小时
Returns:
List[ShortTermMemory]: 记忆记录列表,按创建时间倒序
"""
try:
cutoff_time = datetime.datetime.now() - datetime.timedelta(hours=hours)
# 使用复合索引 ix_memory_short_term_user_time 优化查询
memories = (
self.db.query(ShortTermMemory)
.filter(
ShortTermMemory.end_user_id == end_user_id,
ShortTermMemory.created_at >= cutoff_time
)
.order_by(ShortTermMemory.created_at.desc())
.all()
)
db_logger.info(f"成功查询用户 {end_user_id} 最近 {hours} 小时的 {len(memories)} 条短期记忆记录")
return memories
except Exception as e:
self.db.rollback()
db_logger.error(f"查询用户 {end_user_id} 最近 {hours} 小时的短期记忆记录时出错: {str(e)}")
raise
def delete_by_id(self, memory_id: uuid.UUID) -> bool:
"""删除指定ID的短期记忆记录
Args:
memory_id: 记忆ID
Returns:
bool: 删除成功返回True否则返回False
"""
try:
deleted_count = (
self.db.query(ShortTermMemory)
.filter(ShortTermMemory.id == memory_id)
.delete(synchronize_session=False)
)
self.db.commit()
if deleted_count > 0:
db_logger.info(f"成功删除短期记忆记录 {memory_id}")
return True
else:
db_logger.warning(f"未找到短期记忆记录 {memory_id},无法删除")
return False
except Exception as e:
self.db.rollback()
db_logger.error(f"删除短期记忆记录 {memory_id} 时出错: {str(e)}")
raise
def delete_old_memories(self, days: int = 7) -> int:
"""删除指定天数之前的短期记忆记录
Args:
days: 保留天数默认7天
Returns:
int: 删除的记录数
"""
try:
cutoff_time = datetime.datetime.now() - datetime.timedelta(days=days)
deleted_count = (
self.db.query(ShortTermMemory)
.filter(ShortTermMemory.created_at < cutoff_time)
.delete(synchronize_session=False)
)
self.db.commit()
db_logger.info(f"成功删除 {days} 天前的 {deleted_count} 条短期记忆记录")
return deleted_count
except Exception as e:
self.db.rollback()
db_logger.error(f"删除 {days} 天前的短期记忆记录时出错: {str(e)}")
raise
def upsert(self, end_user_id: str, messages: str, aimessages: str = None, search_switch: str = None, retrieved_content: List[Dict] = None) -> ShortTermMemory:
"""创建或更新短期记忆记录
根据 end_user_id、messages 和 aimessages 查找现有记录:
- 如果找到匹配的记录,则更新 messages、aimessages、search_switch 和 retrieved_content
- 如果没有找到匹配的记录,则创建新记录
Args:
end_user_id: 终端用户ID
messages: 用户消息内容
aimessages: AI回复消息内容
search_switch: 搜索开关状态
retrieved_content: 检索到的相关内容,格式为[{}, {}]
Returns:
ShortTermMemory: 创建或更新的短期记忆对象
"""
try:
# 构建查询条件,使用复合索引 ix_memory_short_term_user_messages 优化查询
query_filters = [
ShortTermMemory.end_user_id == end_user_id,
ShortTermMemory.messages == messages
]
# 如果 aimessages 不为空,则加入查询条件
if aimessages is not None:
query_filters.append(ShortTermMemory.aimessages == aimessages)
else:
# 如果 aimessages 为 None则查找 aimessages 为 NULL 的记录
query_filters.append(ShortTermMemory.aimessages.is_(None))
# 查找现有记录
existing_memory = (
self.db.query(ShortTermMemory)
.filter(*query_filters)
.first()
)
if existing_memory:
# 更新现有记录
existing_memory.messages = messages
existing_memory.aimessages = aimessages
existing_memory.search_switch = search_switch
existing_memory.retrieved_content = retrieved_content or []
self.db.commit()
self.db.refresh(existing_memory)
db_logger.info(f"成功更新短期记忆记录: {existing_memory.id} for user {end_user_id}")
return existing_memory
else:
# 创建新记录
new_memory = ShortTermMemory(
end_user_id=end_user_id,
messages=messages,
aimessages=aimessages,
search_switch=search_switch,
retrieved_content=retrieved_content or []
)
self.db.add(new_memory)
self.db.commit()
self.db.refresh(new_memory)
db_logger.info(f"成功创建新的短期记忆记录: {new_memory.id} for user {end_user_id}")
return new_memory
except Exception as e:
self.db.rollback()
db_logger.error(f"创建或更新短期记忆记录时出错: {str(e)}")
raise
class LongTermMemoryRepository:
"""长期记忆仓储类"""
def __init__(self, db: Session):
self.db = db
def create(self, end_user_id: str, retrieved_content: List[Dict] = None) -> LongTermMemory:
"""创建长期记忆记录
Args:
end_user_id: 终端用户ID
retrieved_content: 检索到的相关内容,格式为[{}, {}]
Returns:
LongTermMemory: 创建的长期记忆对象
"""
try:
memory = LongTermMemory(
end_user_id=end_user_id,
retrieved_content=retrieved_content or []
)
self.db.add(memory)
self.db.commit()
self.db.refresh(memory)
db_logger.info(f"成功创建长期记忆记录: {memory.id} for user {end_user_id}")
return memory
except Exception as e:
self.db.rollback()
db_logger.error(f"创建长期记忆记录时出错: {str(e)}")
raise
def get_by_id(self, memory_id: uuid.UUID) -> Optional[LongTermMemory]:
"""根据ID获取长期记忆记录
Args:
memory_id: 记忆ID
Returns:
Optional[LongTermMemory]: 记忆对象如果不存在则返回None
"""
try:
memory = (
self.db.query(LongTermMemory)
.filter(LongTermMemory.id == memory_id)
.first()
)
if memory:
db_logger.debug(f"成功查询到长期记忆记录 {memory_id}")
else:
db_logger.debug(f"未找到长期记忆记录 {memory_id}")
return memory
except Exception as e:
self.db.rollback()
db_logger.error(f"查询长期记忆记录 {memory_id} 时出错: {str(e)}")
raise
def get_by_user_id(self, end_user_id: str, limit: int = 100, offset: int = 0) -> List[LongTermMemory]:
"""根据用户ID获取长期记忆记录列表
Args:
end_user_id: 终端用户ID
limit: 返回记录数限制默认100
offset: 偏移量默认0
Returns:
List[LongTermMemory]: 记忆记录列表,按创建时间倒序
"""
try:
# 使用复合索引 ix_memory_long_term_user_time 优化查询
memories = (
self.db.query(LongTermMemory)
.filter(LongTermMemory.end_user_id == end_user_id)
.order_by(LongTermMemory.created_at.desc())
.limit(limit)
.offset(offset)
.all()
)
db_logger.info(f"成功查询用户 {end_user_id}{len(memories)} 条长期记忆记录")
return memories
except Exception as e:
self.db.rollback()
db_logger.error(f"查询用户 {end_user_id} 的长期记忆记录时出错: {str(e)}")
raise
def search_by_content(self, end_user_id: str, keyword: str, limit: int = 50) -> List[LongTermMemory]:
"""根据内容关键词搜索长期记忆记录
Args:
end_user_id: 终端用户ID
keyword: 搜索关键词
limit: 返回记录数限制默认50
Returns:
List[LongTermMemory]: 匹配的记忆记录列表,按创建时间倒序
"""
try:
# 使用 GIN 索引 ix_memory_long_term_retrieved_content_gin 优化 JSON 搜索
# 同时使用复合索引 ix_memory_long_term_user_time 优化用户过滤
memories = (
self.db.query(LongTermMemory)
.filter(
LongTermMemory.end_user_id == end_user_id,
LongTermMemory.retrieved_content.astext.contains(keyword)
)
.order_by(LongTermMemory.created_at.desc())
.limit(limit)
.all()
)
db_logger.info(f"成功搜索用户 {end_user_id} 包含关键词 '{keyword}'{len(memories)} 条长期记忆记录")
return memories
except Exception as e:
self.db.rollback()
db_logger.error(f"搜索用户 {end_user_id} 包含关键词 '{keyword}' 的长期记忆记录时出错: {str(e)}")
raise
def delete_by_id(self, memory_id: uuid.UUID) -> bool:
"""删除指定ID的长期记忆记录
Args:
memory_id: 记忆ID
Returns:
bool: 删除成功返回True否则返回False
"""
try:
deleted_count = (
self.db.query(LongTermMemory)
.filter(LongTermMemory.id == memory_id)
.delete(synchronize_session=False)
)
self.db.commit()
if deleted_count > 0:
db_logger.info(f"成功删除长期记忆记录 {memory_id}")
return True
else:
db_logger.warning(f"未找到长期记忆记录 {memory_id},无法删除")
return False
except Exception as e:
self.db.rollback()
db_logger.error(f"删除长期记忆记录 {memory_id} 时出错: {str(e)}")
raise
def count_by_user_id(self, end_user_id: str) -> int:
"""统计用户的长期记忆记录数量
Args:
end_user_id: 终端用户ID
Returns:
int: 记录数量
"""
try:
count = (
self.db.query(LongTermMemory)
.filter(LongTermMemory.end_user_id == end_user_id)
.count()
)
db_logger.debug(f"用户 {end_user_id} 共有 {count} 条长期记忆记录")
return count
except Exception as e:
self.db.rollback()
db_logger.error(f"统计用户 {end_user_id} 的长期记忆记录数量时出错: {str(e)}")
raise
def upsert(self, end_user_id: str, retrieved_content: List[Dict] = None) -> Optional[LongTermMemory]:
"""创建或更新长期记忆记录
根据 end_user_id 和 retrieved_content 判断是否需要写入:
- 如果找到相同的 end_user_id 和 retrieved_content则不写入返回 None
- 如果没有找到相同的记录,则创建新记录
Args:
end_user_id: 终端用户ID
retrieved_content: 检索到的相关内容,格式为[{}, {}]
Returns:
Optional[LongTermMemory]: 创建的长期记忆对象,如果不需要写入则返回 None
"""
try:
retrieved_content = retrieved_content or []
# 优化查询:使用复合索引 ix_memory_long_term_user_time 先过滤用户
# 然后在应用层比较 JSON 内容,避免复杂的数据库 JSON 比较
existing_memories = (
self.db.query(LongTermMemory)
.filter(LongTermMemory.end_user_id == end_user_id)
.order_by(LongTermMemory.created_at.desc())
.limit(100) # 限制查询数量,避免加载过多数据
.all()
)
# 在 Python 中比较 retrieved_content
for memory in existing_memories:
if memory.retrieved_content == retrieved_content:
# 如果找到相同的记录,不写入
db_logger.info(f"长期记忆记录已存在,跳过写入: user {end_user_id}")
return None
# 如果没有找到相同的记录,创建新记录
new_memory = LongTermMemory(
end_user_id=end_user_id,
retrieved_content=retrieved_content
)
self.db.add(new_memory)
self.db.commit()
self.db.refresh(new_memory)
db_logger.info(f"成功创建新的长期记忆记录: {new_memory.id} for user {end_user_id}")
return new_memory
except Exception as e:
self.db.rollback()
db_logger.error(f"创建或更新长期记忆记录时出错: {str(e)}")
raise

View File

@@ -4,6 +4,7 @@ Memory Agent Service
Handles business logic for memory agent operations including read/write services,
health checks, and message type classification.
"""
import datetime
import json
import os
import re
@@ -24,6 +25,7 @@ from app.core.memory.analytics.hot_memory_tags import get_hot_memory_tags
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.db import get_db_context
from app.models.knowledge_model import Knowledge, KnowledgeType
from app.repositories.memory_short_repository import ShortTermMemoryRepository
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.schemas.memory_config_schema import ConfigurationError, MemoryConfig
from app.services.memory_config_service import MemoryConfigService
@@ -393,7 +395,7 @@ class MemoryAgentService:
import time
start_time = time.time()
ori_message=message
# Resolve config_id if None using end_user's connected config
if config_id is None:
try:
@@ -406,15 +408,15 @@ class MemoryAgentService:
raise # Re-raise our specific error
logger.error(f"Failed to get connected config for end_user {group_id}: {e}")
raise ValueError(f"Unable to determine memory configuration for end_user {group_id}: {e}")
logger.info(f"Read operation for group {group_id} with config_id {config_id}")
# 导入审计日志记录器
try:
from app.core.memory.utils.log.audit_logger import audit_logger
except ImportError:
audit_logger = None
# Get group lock to prevent concurrent processing
group_lock = self.get_group_lock(group_id)
@@ -430,7 +432,7 @@ class MemoryAgentService:
except ConfigurationError as e:
error_msg = f"Failed to load configuration for config_id: {config_id}: {e}"
logger.error(error_msg)
# Log failed operation
if audit_logger:
duration = time.time() - start_time
@@ -442,9 +444,9 @@ class MemoryAgentService:
duration=duration,
error=error_msg
)
raise ValueError(error_msg)
# Step 2: Prepare history
history.append({"role": "user", "content": message})
logger.debug(f"Group ID:{group_id}, Message:{message}, History:{history}, Config ID:{config_id}")
@@ -452,7 +454,7 @@ class MemoryAgentService:
# Step 3: Initialize MCP client and execute read workflow
mcp_config = get_mcp_server_config()
client = MultiServerMCPClient(mcp_config)
async with client.session('data_flow') as session:
logger.debug("Connected to MCP Server: data_flow")
tools = await load_mcp_tools(session)
@@ -475,7 +477,7 @@ class MemoryAgentService:
# Capture any errors from the state
if event.get('errors'):
workflow_errors.extend(event.get('errors', []))
for msg in messages:
msg_content = msg.content
msg_role = msg.__class__.__name__.lower().replace("message", "")
@@ -483,7 +485,7 @@ class MemoryAgentService:
"role": msg_role,
"content": msg_content
})
# Extract intermediate outputs
if hasattr(msg, 'content'):
try:
@@ -496,7 +498,7 @@ class MemoryAgentService:
break
else:
continue # No text block found
# Try to parse content as JSON
if isinstance(content_to_parse, str):
try:
@@ -506,16 +508,16 @@ class MemoryAgentService:
if '_intermediate' in parsed:
intermediate_data = parsed['_intermediate']
output_key = self._create_intermediate_key(intermediate_data)
if output_key not in seen_intermediates:
seen_intermediates.add(output_key)
intermediate_outputs.append(self._format_intermediate_output(intermediate_data))
# Check for multiple intermediate outputs (from Retrieve)
if '_intermediates' in parsed:
for intermediate_data in parsed['_intermediates']:
output_key = self._create_intermediate_key(intermediate_data)
if output_key not in seen_intermediates:
seen_intermediates.add(output_key)
intermediate_outputs.append(self._format_intermediate_output(intermediate_data))
@@ -523,7 +525,7 @@ class MemoryAgentService:
pass
except Exception as e:
logger.debug(f"Failed to extract intermediate output: {e}")
workflow_duration = time.time() - start
logger.info(f"Read graph workflow completed in {workflow_duration}s")
@@ -532,7 +534,7 @@ class MemoryAgentService:
for messages in outputs:
if messages['role'] == 'tool':
message = messages['content']
# Handle MCP content format: [{'type': 'text', 'text': '...'}]
if isinstance(message, list):
# Extract text from MCP content blocks
@@ -542,7 +544,7 @@ class MemoryAgentService:
break
else:
continue # No text block found
try:
parsed = json.loads(message) if isinstance(message, str) else message
if isinstance(parsed, dict):
@@ -552,15 +554,15 @@ class MemoryAgentService:
final_answer = summary_result
except (json.JSONDecodeError, ValueError):
pass
# 记录成功的操作
total_duration = time.time() - start_time
# Check for workflow errors
if workflow_errors:
error_details = "; ".join([f"{e['tool']}: {e['error']}" for e in workflow_errors])
logger.warning(f"Read workflow completed with errors: {error_details}")
if audit_logger:
audit_logger.log_operation(
operation="READ",
@@ -577,11 +579,11 @@ class MemoryAgentService:
"errors": workflow_errors
}
)
# Raise error if no answer was produced
if not final_answer:
raise ValueError(f"Read workflow failed: {error_details}")
if audit_logger and not workflow_errors:
audit_logger.log_operation(
operation="READ",
@@ -596,7 +598,31 @@ class MemoryAgentService:
"has_answer": bool(final_answer)
}
)
retrieved_content=[]
repo = ShortTermMemoryRepository(db)
if str(search_switch)!="2":
for intermediate in intermediate_outputs:
intermediate_type=intermediate['type']
if intermediate_type=="search_result":
query=intermediate['query']
raw_results=intermediate['raw_results']
reranked_results=raw_results.get('reranked_results',[])
statements=[statement['statement'] for statement in reranked_results.get('statements', [])]
statements=list(set(statements))
retrieved_content.append({query:statements})
if '信息不足,无法回答' in str(final_answer) or retrieved_content!=[]:
# 使用 upsert 方法
repo.upsert(
end_user_id=group_id, # 确保这个变量在作用域内
messages=ori_message,
aimessages=final_answer,
retrieved_content=retrieved_content,
search_switch=str(search_switch)
)
print("写入成功")
return {
"answer": final_answer,
"intermediate_outputs": intermediate_outputs

View File

@@ -0,0 +1,56 @@
from app.core.logging_config import get_api_logger
from app.db import get_db
from app.repositories.memory_short_repository import LongTermMemoryRepository
from app.repositories.memory_short_repository import ShortTermMemoryRepository
api_logger = get_api_logger()
db=next(get_db())
class ShortService:
def __init__(self, end_user_id):
self.short_repo = ShortTermMemoryRepository(db)
self.end_user_id = end_user_id
def get_short_databasets(self):
short_memories = self.short_repo.get_latest_by_user_id(self.end_user_id, 3)
short_result = []
for memory in short_memories:
deep_expanded = {} # Create a new dictionary for each memory
messages = memory.messages
aimessages = memory.aimessages
retrieved_content = memory.retrieved_content or []
api_logger.debug(f"Retrieved content: {retrieved_content}")
retrieval_source = []
for item in retrieved_content:
if isinstance(item, dict):
for key, values in item.items():
retrieval_source.append({"query": key, "retrieval": values})
deep_expanded['retrieval'] = retrieval_source
deep_expanded['message'] = messages # 修正拼写错误
deep_expanded['answer'] = aimessages
short_result.append(deep_expanded)
return short_result
def get_short_count(self):
short_count = self.short_repo.count_by_user_id(self.end_user_id)
return short_count
class LongService:
def __init__(self, end_user_id):
self.long_repo = LongTermMemoryRepository(db)
self.end_user_id = end_user_id
def get_long_databasets(self):
# 获取长期记忆数据
long_memories = self.long_repo.get_by_user_id(self.end_user_id, 1)
long_result = []
for long_memory in long_memories:
if long_memory.retrieved_content:
for memory_item in long_memory.retrieved_content:
if isinstance(memory_item, dict):
for key, values in memory_item.items():
long_result.append({"query": key, "retrieval": values})
return long_result

View File

@@ -1496,8 +1496,8 @@ def _extract_node_properties(label: str, properties: Dict[str, Any]) -> Dict[str
field_whitelist = {
"Dialogue": ["content", "created_at"],
"Chunk": ["content", "created_at"],
"Statement": ["temporal_info", "stmt_type", "statement", "valid_at", "created_at", "caption"],
"ExtractedEntity": ["description", "name", "entity_type", "created_at", "caption"],
"Statement": ["temporal_info", "stmt_type", "statement", "valid_at", "created_at", "caption","emotion_keywords","emotion_type","emotion_subject"],
"ExtractedEntity": ["description", "name", "entity_type", "created_at", "caption","aliases","connect_strength"],
"MemorySummary": ["summary", "content", "created_at", "caption"] # 添加 content 字段
}