Merge branch 'develop' of github.com:SuanmoSuanyangTechnology/MemoryBear into develop

This commit is contained in:
Mark
2026-01-07 17:49:45 +08:00
21 changed files with 1534 additions and 55 deletions

View File

@@ -24,6 +24,7 @@ from . import (
memory_storage_controller, memory_storage_controller,
memory_dashboard_controller, memory_dashboard_controller,
memory_reflection_controller, memory_reflection_controller,
memory_short_term_controller,
api_key_controller, api_key_controller,
release_share_controller, release_share_controller,
public_share_controller, public_share_controller,
@@ -71,6 +72,7 @@ manager_router.include_router(emotion_controller.router)
manager_router.include_router(emotion_config_controller.router) manager_router.include_router(emotion_config_controller.router)
manager_router.include_router(prompt_optimizer_controller.router) manager_router.include_router(prompt_optimizer_controller.router)
manager_router.include_router(memory_reflection_controller.router) manager_router.include_router(memory_reflection_controller.router)
manager_router.include_router(memory_short_term_controller.router)
manager_router.include_router(tool_controller.router) manager_router.include_router(tool_controller.router)
manager_router.include_router(memory_forget_controller.router) manager_router.include_router(memory_forget_controller.router)
manager_router.include_router(home_page_controller.router) manager_router.include_router(home_page_controller.router)

View File

@@ -0,0 +1,255 @@
import uuid
from typing import Optional
from fastapi import APIRouter, Depends, Query
from sqlalchemy.orm import Session
from app.core.error_codes import BizCode
from app.core.logging_config import get_api_logger
from app.core.response_utils import success, fail
from app.db import get_db
from app.dependencies import get_current_user
from app.models import User
from app.models.memory_perceptual_model import PerceptualType
from app.schemas.memory_perceptual_schema import (
PerceptualQuerySchema,
PerceptualFilter
)
from app.schemas.response_schema import ApiResponse
from app.services.memory_perceptual_service import MemoryPerceptualService
api_logger = get_api_logger()
router = APIRouter(
prefix="/memory/perceptual",
tags=["Perceptual Memory System"],
dependencies=[Depends(get_current_user)]
)
@router.get("/{group_id}/count", response_model=ApiResponse)
def get_memory_count(
group_id: uuid.UUID,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db)
):
"""Retrieve perceptual memory statistics for a user group.
Args:
group_id: ID of the user group (usually end_user_id in this context)
current_user: Current authenticated user
db: Database session
Returns:
ApiResponse: Response containing memory count statistics
"""
api_logger.info(f"Fetching perceptual memory statistics: user={current_user.username}, group_id={group_id}")
try:
service = MemoryPerceptualService(db)
count_stats = service.get_memory_count(group_id)
api_logger.info(f"Memory statistics fetched successfully: total={count_stats.get('total', 0)}")
return success(
data=count_stats,
msg="Memory statistics retrieved successfully"
)
except Exception as e:
api_logger.error(f"Failed to fetch memory statistics: group_id={group_id}, error={str(e)}")
return fail(
code=BizCode.INTERNAL_ERROR,
msg="Failed to fetch memory statistics",
)
@router.get("/{group_id}/last_visual", response_model=ApiResponse)
def get_last_visual_memory(
group_id: uuid.UUID,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db)
):
"""Retrieve the most recent VISION-type memory for a user.
Args:
group_id: ID of the user group
current_user: Current authenticated user
db: Database session
Returns:
ApiResponse: Metadata of the latest visual memory
"""
api_logger.info(f"Fetching latest visual memory: user={current_user.username}, group_id={group_id}")
try:
service = MemoryPerceptualService(db)
visual_memory = service.get_latest_visual_memory(group_id)
if visual_memory is None:
api_logger.info(f"No visual memory found: group_id={group_id}")
return success(
data=None,
msg="No visual memory available"
)
api_logger.info(f"Latest visual memory retrieved successfully: file={visual_memory.get('file_name')}")
return success(
data=visual_memory,
msg="Latest visual memory retrieved successfully"
)
except Exception as e:
api_logger.error(f"Failed to fetch latest visual memory: group_id={group_id}, error={str(e)}")
return fail(
code=BizCode.INTERNAL_ERROR,
msg="Failed to fetch latest visual memory",
)
@router.get("/{group_id}/last_listen", response_model=ApiResponse)
def get_last_memory_listen(
group_id: uuid.UUID,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db)
):
"""Retrieve the most recent AUDIO-type memory for a user.
Args:
group_id: ID of the user group
current_user: Current authenticated user
db: Database session
Returns:
ApiResponse: Metadata of the latest audio memory
"""
api_logger.info(f"Fetching latest audio memory: user={current_user.username}, group_id={group_id}")
try:
service = MemoryPerceptualService(db)
audio_memory = service.get_latest_audio_memory(group_id)
if audio_memory is None:
api_logger.info(f"No audio memory found: group_id={group_id}")
return success(
data=None,
msg="No audio memory available"
)
api_logger.info(f"Latest audio memory retrieved successfully: file={audio_memory.get('file_name')}")
return success(
data=audio_memory,
msg="Latest audio memory retrieved successfully"
)
except Exception as e:
api_logger.error(f"Failed to fetch latest audio memory: group_id={group_id}, error={str(e)}")
return fail(
code=BizCode.INTERNAL_ERROR,
msg="Failed to fetch latest audio memory",
)
@router.get("/{group_id}/last_text", response_model=ApiResponse)
def get_last_text_memory(
group_id: uuid.UUID,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db)
):
"""Retrieve the most recent TEXT-type memory for a user.
Args:
group_id: ID of the user group
current_user: Current authenticated user
db: Database session
Returns:
ApiResponse: Metadata of the latest text memory
"""
api_logger.info(f"Fetching latest text memory: user={current_user.username}, group_id={group_id}")
try:
# 调用服务层获取最近的文本记忆
service = MemoryPerceptualService(db)
text_memory = service.get_latest_text_memory(group_id)
if text_memory is None:
api_logger.info(f"No text memory found: group_id={group_id}")
return success(
data=None,
msg="No text memory available"
)
api_logger.info(f"Latest text memory retrieved successfully: file={text_memory.get('file_name')}")
return success(
data=text_memory,
msg="Latest text memory retrieved successfully"
)
except Exception as e:
api_logger.error(f"Failed to fetch latest text memory: group_id={group_id}, error={str(e)}")
return fail(
code=BizCode.INTERNAL_ERROR,
msg="Failed to fetch latest text memory",
)
@router.get("/{group_id}/timeline", response_model=ApiResponse)
def get_memory_time_line(
group_id: uuid.UUID,
perceptual_type: Optional[PerceptualType] = Query(None, description="感知类型过滤"),
page: int = Query(1, ge=1, description="页码"),
page_size: int = Query(10, ge=1, le=100, description="每页大小"),
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db)
):
"""Retrieve a timeline of perceptual memories for a user group.
Args:
group_id: ID of the user group
perceptual_type: Optional filter for perceptual type
page: Page number for pagination
page_size: Number of items per page
current_user: Current authenticated user
db: Database session
Returns:
ApiResponse: Timeline data of perceptual memories
"""
api_logger.info(
f"Fetching perceptual memory timeline: user={current_user.username}, "
f"group_id={group_id}, type={perceptual_type}, page={page}"
)
try:
query = PerceptualQuerySchema(
filter=PerceptualFilter(type=perceptual_type),
page=page,
page_size=page_size
)
service = MemoryPerceptualService(db)
timeline_data = service.get_time_line(group_id, query)
api_logger.info(
f"Perceptual memory timeline retrieved successfully: total={timeline_data.total}, "
f"returned={len(timeline_data.memories)}"
)
return success(
data=timeline_data.model_dump(),
msg="Perceptual memory timeline retrieved successfully"
)
except Exception as e:
api_logger.error(
f"Failed to fetch perceptual memory timeline: group_id={group_id}, "
f"error={str(e)}"
)
return fail(
code=BizCode.INTERNAL_ERROR,
msg="Failed to fetch perceptual memory timeline",
)

View File

@@ -0,0 +1,44 @@
from fastapi import APIRouter, Depends, HTTPException, status
from app.core.logging_config import get_api_logger
from app.core.response_utils import success
from app.db import get_db
from app.dependencies import get_current_user
from app.models.user_model import User
from app.services.memory_storage_service import search_entity
from app.services.memory_short_service import ShortService,LongService
from dotenv import load_dotenv
from sqlalchemy.orm import Session
from typing import Optional
load_dotenv()
api_logger = get_api_logger()
router = APIRouter(
prefix="/memory/short",
tags=["Memory"],
)
@router.get("/short_term")
async def short_term_configs(
end_user_id: str,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
):
# 获取短期记忆数据
short_term=ShortService(end_user_id)
short_result=short_term.get_short_databasets()
short_count=short_term.get_short_count()
long_term=LongService(end_user_id)
long_result=long_term.get_long_databasets()
entity_result = await search_entity(end_user_id)
result = {
'short_term': short_result,
'long_term': long_result,
'entity': entity_result.get('num', 0),
"retrieval_number":short_count,
"long_term_number":len(long_result)
}
return success(data=result, msg="短期记忆系统数据获取成功")

View File

@@ -7,13 +7,20 @@ LangChain Agent 封装
- 支持流式输出 - 支持流式输出
- 使用 RedBearLLM 支持多提供商 - 使用 RedBearLLM 支持多提供商
""" """
import os
import time import time
from typing import Any, AsyncGenerator, Dict, List, Optional, Sequence from typing import Any, AsyncGenerator, Dict, List, Optional, Sequence
from app.db import get_db
from app.core.logging_config import get_business_logger from app.core.logging_config import get_business_logger
from app.core.memory.agent.utils.redis_tool import store from app.core.memory.agent.utils.redis_tool import store
from app.core.models import RedBearLLM, RedBearModelConfig from app.core.models import RedBearLLM, RedBearModelConfig
from app.models.models_model import ModelType from app.models.models_model import ModelType
from app.repositories.memory_short_repository import LongTermMemoryRepository
from app.services.memory_agent_service import (
get_end_user_connected_config,
)
from app.services.memory_konwledges_server import write_rag from app.services.memory_konwledges_server import write_rag
from app.services.task_service import get_task_memory_write_result from app.services.task_service import get_task_memory_write_result
from app.tasks import write_message_task from app.tasks import write_message_task
@@ -96,7 +103,8 @@ class LangChainAgent:
"temperature": temperature, "temperature": temperature,
"streaming": streaming, "streaming": streaming,
"tool_count": len(self.tools), "tool_count": len(self.tools),
"tool_names": [tool.name for tool in self.tools] if self.tools else [] "tool_names": [tool.name for tool in self.tools] if self.tools else [],
"tool_count": len(self.tools)
} }
) )
@@ -137,11 +145,8 @@ class LangChainAgent:
messages.append(HumanMessage(content=user_content)) messages.append(HumanMessage(content=user_content))
return messages return messages
async def term_memory_save(self,messages,end_user_end,aimessages): async def term_memory_save(self,messages,end_user_end,aimessages):
""" '''短长期存储redis为不影响正常使用6句一段话存储用户名加一个前缀当数据存够6条返回给neo4j'''
短长期存储redis为不影响正常使用6句一段话存储用户名加一个前缀当数据存够6条返回给neo4j
"""
end_user_end=f"Term_{end_user_end}" end_user_end=f"Term_{end_user_end}"
print(messages) print(messages)
print(aimessages) print(aimessages)
@@ -155,17 +160,18 @@ class LangChainAgent:
store.delete_duplicate_sessions() store.delete_duplicate_sessions()
# logger.info(f'Redis_Agent:{end_user_end};{session_id}') # logger.info(f'Redis_Agent:{end_user_end};{session_id}')
return session_id return session_id
async def term_memory_redis_read(self,end_user_end): async def term_memory_redis_read(self,end_user_end):
end_user_end = f"Term_{end_user_end}" end_user_end = f"Term_{end_user_end}"
history = store.find_user_apply_group(end_user_end, end_user_end, end_user_end) history = store.find_user_apply_group(end_user_end, end_user_end, end_user_end)
# logger.info(f'Redis_Agent:{end_user_end};{history}') # logger.info(f'Redis_Agent:{end_user_end};{history}')
messagss_list=[] messagss_list=[]
retrieved_content=[]
for messages in history: for messages in history:
query = messages.get("Query") query = messages.get("Query")
aimessages = messages.get("Answer") aimessages = messages.get("Answer")
messagss_list.append(f'用户:{query}。AI回复:{aimessages}') messagss_list.append(f'用户:{query}。AI回复:{aimessages}')
return messagss_list retrieved_content.append({query: aimessages})
return messagss_list,retrieved_content
async def write(self,storage_type,end_user_id,message,user_rag_memory_id,actual_end_user_id,content,actual_config_id): async def write(self,storage_type,end_user_id,message,user_rag_memory_id,actual_end_user_id,content,actual_config_id):
@@ -205,7 +211,6 @@ class LangChainAgent:
# If config_id is None, try to get from end_user's connected config # If config_id is None, try to get from end_user's connected config
if actual_config_id is None and end_user_id: if actual_config_id is None and end_user_id:
try: try:
from app.db import get_db
from app.services.memory_agent_service import ( from app.services.memory_agent_service import (
get_end_user_connected_config, get_end_user_connected_config,
) )
@@ -223,11 +228,26 @@ class LangChainAgent:
logger.info(f'写入类型{storage_type,str(end_user_id), message, str(user_rag_memory_id)}') logger.info(f'写入类型{storage_type,str(end_user_id), message, str(user_rag_memory_id)}')
print(f'写入类型{storage_type,str(end_user_id), message, str(user_rag_memory_id)}') print(f'写入类型{storage_type,str(end_user_id), message, str(user_rag_memory_id)}')
history_term_memory=await self.term_memory_redis_read(end_user_id) history_term_memory_result = await self.term_memory_redis_read(end_user_id)
history_term_memory = history_term_memory_result[0]
db_for_memory = next(get_db())
if memory_flag: if memory_flag:
if len(history_term_memory)>=4 and storage_type != "rag": if len(history_term_memory)>=4 and storage_type != "rag":
history_term_memory=';'.join(history_term_memory) history_term_memory = ';'.join(history_term_memory)
logger.info(f'写入短长期:{storage_type, str(end_user_id), history_term_memory, str(user_rag_memory_id)}') retrieved_content = history_term_memory_result[1]
print(retrieved_content)
# 为长期记忆操作获取新的数据库连接
try:
repo = LongTermMemoryRepository(db_for_memory)
repo.upsert(end_user_id, retrieved_content)
logger.info(
f'写入短长期:{storage_type, str(end_user_id), history_term_memory, str(user_rag_memory_id)}')
except Exception as e:
logger.error(f"Failed to write to LongTermMemory: {e}")
raise
finally:
db_for_memory.close()
await self.write(storage_type,end_user_id,history_term_memory,user_rag_memory_id,actual_end_user_id,history_term_memory,actual_config_id) await self.write(storage_type,end_user_id,history_term_memory,user_rag_memory_id,actual_end_user_id,history_term_memory,actual_config_id)
await self.write(storage_type,end_user_id,message,user_rag_memory_id,actual_end_user_id,message,actual_config_id) await self.write(storage_type,end_user_id,message,user_rag_memory_id,actual_end_user_id,message,actual_config_id)
try: try:
@@ -316,10 +336,6 @@ class LangChainAgent:
# If config_id is None, try to get from end_user's connected config # If config_id is None, try to get from end_user's connected config
if actual_config_id is None and end_user_id: if actual_config_id is None and end_user_id:
try: try:
from app.db import get_db
from app.services.memory_agent_service import (
get_end_user_connected_config,
)
db = next(get_db()) db = next(get_db())
try: try:
connected_config = get_end_user_connected_config(end_user_id, db) connected_config = get_end_user_connected_config(end_user_id, db)
@@ -331,14 +347,24 @@ class LangChainAgent:
except Exception as e: except Exception as e:
logger.warning(f"Failed to get db session: {e}") logger.warning(f"Failed to get db session: {e}")
history_term_memory = await self.term_memory_redis_read(end_user_id) history_term_memory_result = await self.term_memory_redis_read(end_user_id)
history_term_memory = history_term_memory_result[0]
if memory_flag: if memory_flag:
if len(history_term_memory) >= 4 and storage_type != "rag": if len(history_term_memory) >= 4 and storage_type != "rag":
history_term_memory = ';'.join(history_term_memory) history_term_memory = ';'.join(history_term_memory)
logger.info( retrieved_content = history_term_memory_result[1]
f'写入短长期:{storage_type, str(end_user_id), history_term_memory, str(user_rag_memory_id)}') db_for_memory = next(get_db())
await self.write(storage_type, end_user_id, history_term_memory, user_rag_memory_id, end_user_id, try:
history_term_memory, actual_config_id) repo = LongTermMemoryRepository(db_for_memory)
repo.upsert(end_user_id, retrieved_content)
logger.info(
f'写入短长期:{storage_type, str(end_user_id), history_term_memory, str(user_rag_memory_id)}')
await self.write(storage_type, end_user_id, history_term_memory, user_rag_memory_id, end_user_id,
history_term_memory, actual_config_id)
except Exception as e:
logger.error(f"Failed to write to long term memory: {e}")
finally:
db_for_memory.close()
await self.write(storage_type, end_user_id, message, user_rag_memory_id, end_user_id, message, actual_config_id) await self.write(storage_type, end_user_id, message, user_rag_memory_id, end_user_id, message, actual_config_id)
try: try:

View File

@@ -246,7 +246,7 @@ class AccessHistoryManager:
if not node_data: if not node_data:
return ConsistencyCheckResult.CONSISTENT, None return ConsistencyCheckResult.CONSISTENT, None
access_history = node_data.get('access_history', []) access_history = node_data.get('access_history') or []
last_access_time = node_data.get('last_access_time') last_access_time = node_data.get('last_access_time')
access_count = node_data.get('access_count', 0) access_count = node_data.get('access_count', 0)
activation_value = node_data.get('activation_value') activation_value = node_data.get('activation_value')
@@ -409,7 +409,7 @@ class AccessHistoryManager:
logger.error(f"节点不存在,无法修复: {node_label}[{node_id}]") logger.error(f"节点不存在,无法修复: {node_label}[{node_id}]")
return False return False
access_history = node_data.get('access_history', []) access_history = node_data.get('access_history') or []
importance_score = node_data.get('importance_score', 0.5) importance_score = node_data.get('importance_score', 0.5)
# 准备修复数据 # 准备修复数据
@@ -530,7 +530,7 @@ class AccessHistoryManager:
Returns: Returns:
Dict[str, Any]: 更新数据,包含所有需要更新的字段 Dict[str, Any]: 更新数据,包含所有需要更新的字段
""" """
access_history = node_data.get('access_history', []) access_history = node_data.get('access_history') or []
importance_score = node_data.get('importance_score', 0.5) importance_score = node_data.get('importance_score', 0.5)
# 追加新的访问时间 # 追加新的访问时间

View File

@@ -73,8 +73,10 @@ class HttpContentTypeConfig(BaseModel):
content_type = info.data.get("content_type") content_type = info.data.get("content_type")
if content_type == HttpContentType.FROM_DATA and not isinstance(v, HttpFormData): if content_type == HttpContentType.FROM_DATA and not isinstance(v, HttpFormData):
raise ValueError("When content_type is 'form-data', data must be of type HttpFormData") raise ValueError("When content_type is 'form-data', data must be of type HttpFormData")
elif content_type in [HttpContentType.JSON, HttpContentType.WWW_FORM] and not isinstance(v, dict): elif content_type in [HttpContentType.JSON] and not isinstance(v, str):
raise ValueError("When content_type is JSON or x-www-form-urlencoded, data must be a object") raise ValueError("When content_type is JSON, data must be of type str")
elif content_type in [HttpContentType.WWW_FORM] and not isinstance(v, dict):
raise ValueError("When content_type is x-www-form-urlencoded, data must be a object")
elif content_type in [HttpContentType.RAW, HttpContentType.BINARY] and not isinstance(v, str): elif content_type in [HttpContentType.RAW, HttpContentType.BINARY] and not isinstance(v, str):
raise ValueError("When content_type is raw/binary, data must be a string (File descriptor)") raise ValueError("When content_type is raw/binary, data must be a string (File descriptor)")
return v return v

View File

@@ -120,7 +120,7 @@ class HttpRequestNode(BaseNode):
return {} return {}
case HttpContentType.JSON: case HttpContentType.JSON:
content["json"] = json.loads(self._render_template( content["json"] = json.loads(self._render_template(
json.dumps(self.typed_config.body.data), state self.typed_config.body.data, state
)) ))
case HttpContentType.FROM_DATA: case HttpContentType.FROM_DATA:
data = {} data = {}

View File

@@ -6,6 +6,7 @@ from .document_model import Document
from .file_model import File from .file_model import File
from .generic_file_model import GenericFile from .generic_file_model import GenericFile
from .models_model import ModelConfig, ModelProvider, ModelType, ModelApiKey from .models_model import ModelConfig, ModelProvider, ModelType, ModelApiKey
from .memory_short_model import ShortTermMemory, LongTermMemory
from .knowledgeshare_model import KnowledgeShare from .knowledgeshare_model import KnowledgeShare
from .app_model import App from .app_model import App
from .agent_app_config_model import AgentConfig from .agent_app_config_model import AgentConfig
@@ -67,6 +68,8 @@ __all__ = [
"BuiltinToolConfig", "BuiltinToolConfig",
"CustomToolConfig", "CustomToolConfig",
"MCPToolConfig", "MCPToolConfig",
"ShortTermMemory",
"LongTermMemory",
"ToolExecution", "ToolExecution",
"ToolType", "ToolType",
"ToolStatus", "ToolStatus",

View File

@@ -0,0 +1,40 @@
import datetime
import uuid
from enum import IntEnum
from sqlalchemy import Column, ForeignKey, Integer, DateTime, String
from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.dialects.postgresql import JSONB
from app.db import Base
class PerceptualType(IntEnum):
VISION = 1
AUDIO = 2
TEXT = 3
CONVERSATION = 4
class FileStorageType(IntEnum):
LOCAL = 1
REMOTE = 2
class MemoryPerceptualModel(Base):
__tablename__ = "memory_perceptual"
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
end_user_id = Column(UUID(as_uuid=True), ForeignKey("end_users.id"), index=True)
perceptual_type = Column(Integer, index=True, nullable=False, comment="感知类型")
storage_service = Column(Integer, default=0, comment="存储服务类型")
file_path = Column(String, nullable=False, comment="文件路径")
file_name = Column(String, nullable=False, comment="文件名称")
file_ext = Column(String, nullable=False, comment="文件后缀名")
summary = Column(String, comment="摘要")
meta_data = Column(JSONB, comment="元信息")
created_time = Column(DateTime, default=datetime.datetime.now, comment="创建时间")

View File

@@ -0,0 +1,60 @@
"""
记忆模型 - 短期记忆和长期记忆表
"""
import uuid
import datetime
from sqlalchemy import Column, String, DateTime, Text, JSON
from sqlalchemy.dialects.postgresql import UUID
from app.db import Base
class ShortTermMemory(Base):
"""短期记忆表
用于存储临时的对话记忆,通常保存较短时间
"""
__tablename__ = "memory_short_term"
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4, index=True, comment="记忆ID")
# 用户信息
end_user_id = Column(String(255), nullable=False, index=True, comment="终端用户ID")
# 对话内容
messages = Column(Text, nullable=False, comment="用户消息内容")
aimessages = Column(Text, nullable=True, comment="AI回复消息内容")
# 搜索开关
search_switch = Column(String(50), nullable=True, comment="搜索开关状态")
# 检索内容 - 存储为JSON格式的列表包含字典 [{}, {}]
retrieved_content = Column(JSON, nullable=True, default=list, comment="检索到的相关内容,格式为[{}, {}]")
# 时间戳
created_at = Column(DateTime, default=datetime.datetime.now, nullable=False, index=True, comment="创建时间")
def __repr__(self):
return f"<ShortTermMemory(id={self.id}, end_user_id={self.end_user_id}, created_at={self.created_at})>"
class LongTermMemory(Base):
"""长期记忆表
用于存储重要的对话记忆,长期保存
"""
__tablename__ = "memory_long_term"
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4, index=True, comment="记忆ID")
# 用户信息
end_user_id = Column(String(255), nullable=False, index=True, comment="终端用户ID")
# 检索内容 - 存储为JSON格式的列表包含字典 [{}, {}]
retrieved_content = Column(JSON, nullable=True, default=list, comment="检索到的相关内容,格式为[{}, {}]")
# 时间戳
created_at = Column(DateTime, default=datetime.datetime.now, nullable=False, index=True, comment="创建时间")
def __repr__(self):
return f"<LongTermMemory(id={self.id}, end_user_id={self.end_user_id}, created_at={self.created_at})>"

View File

@@ -0,0 +1,156 @@
import uuid
from datetime import datetime
from typing import List, Tuple, Optional
from sqlalchemy import and_, desc
from sqlalchemy.orm import Session
from app.core.logging_config import get_db_logger
from app.models.memory_perceptual_model import MemoryPerceptualModel, PerceptualType, FileStorageType
from app.schemas.memory_perceptual_schema import PerceptualQuerySchema
db_logger = get_db_logger()
class MemoryPerceptualRepository:
"""Data Access Layer for perceptual memory"""
def __init__(self, db: Session):
self.db = db
# ==================== Create and update ====================
def create_perceptual_memory(
self,
end_user_id: uuid.UUID,
perceptual_type: PerceptualType,
file_path: str,
file_name: str,
file_ext: str,
summary: Optional[str] = None,
meta_data: Optional[dict] = None,
storage_service: FileStorageType = FileStorageType.LOCAL
) -> MemoryPerceptualModel:
"""Create perceptual memory"""
db_logger.debug(f"Creating perceptual memory: end_user_id={end_user_id}, "
f"type={perceptual_type}, file={file_name}")
try:
perceptual_memory = MemoryPerceptualModel(
end_user_id=end_user_id,
perceptual_type=perceptual_type,
storage_service=storage_service,
file_path=file_path,
file_name=file_name,
file_ext=file_ext,
summary=summary,
meta_data=meta_data,
created_time=datetime.now()
)
self.db.add(perceptual_memory)
self.db.flush()
db_logger.info(f"Perceptual memory created successfully: id={perceptual_memory.id}, file={file_name}")
return perceptual_memory
except Exception as e:
db_logger.error(f"Failed to create perceptual memory: end_user_id={end_user_id} - {str(e)}")
raise
# ==================== Query ====================
def get_count_by_user_id(
self,
end_user_id: uuid.UUID,
):
db_logger.debug(f"Querying perceptual memory Count: end_user_id={end_user_id}")
try:
count = self.db.query(MemoryPerceptualModel).filter(
MemoryPerceptualModel.end_user_id == end_user_id
).count()
return count
except Exception as e:
db_logger.error(f"Failed to query perceptual memory count: end_user_id={end_user_id} - {str(e)}")
raise
def get_count_by_type(
self,
end_user_id: uuid.UUID,
perceptual_type: PerceptualType,
):
db_logger.debug(f"Querying perceptual memory Count: end_user_id={end_user_id}, type={perceptual_type}")
try:
count = self.db.query(MemoryPerceptualModel).filter(
MemoryPerceptualModel.end_user_id == end_user_id,
MemoryPerceptualModel.perceptual_type == perceptual_type
).count()
return count
except Exception as e:
db_logger.error(f"Failed to query perceptual memory count: end_user_id={end_user_id} - {str(e)}")
raise
def get_timeline(
self,
end_user_id: uuid.UUID,
query: PerceptualQuerySchema
) -> Tuple[int, List[MemoryPerceptualModel]]:
"""Get the timeline of a user's perceptual memories"""
db_logger.debug(f"Querying perceptual memory timeline: end_user_id={end_user_id}, filter={query.filter}")
try:
base_query = self.db.query(MemoryPerceptualModel).filter(
MemoryPerceptualModel.end_user_id == end_user_id
)
if query.filter.type is not None:
base_query = base_query.filter(
MemoryPerceptualModel.perceptual_type == query.filter.type
)
total_count = base_query.count()
memories = base_query.order_by(
desc(MemoryPerceptualModel.created_time)
).offset(
(query.page - 1) * query.page_size
).limit(query.page_size).all()
db_logger.info(
f"Perceptual memory timeline query succeeded: end_user_id={end_user_id}, total={total_count}, returned={len(memories)}")
return total_count, memories
except Exception as e:
db_logger.error(f"Failed to query perceptual memory timeline: end_user_id={end_user_id} - {str(e)}")
raise
def get_by_type(
self,
end_user_id: uuid.UUID,
perceptual_type: PerceptualType,
limit: int = 10,
offset: int = 0
) -> List[MemoryPerceptualModel]:
"""Get memories by perceptual type"""
db_logger.debug(f"Querying perceptual memories by type: end_user_id={end_user_id}, type={perceptual_type}")
try:
memories = self.db.query(MemoryPerceptualModel).filter(
and_(
MemoryPerceptualModel.end_user_id == end_user_id,
MemoryPerceptualModel.perceptual_type == perceptual_type
)
).order_by(
desc(MemoryPerceptualModel.created_time)
).offset(offset).limit(limit).all()
db_logger.debug(f"Query by type succeeded: count={len(memories)}")
return memories
except Exception as e:
db_logger.error(f"Failed to query perceptual memories by type: end_user_id={end_user_id}, "
f"type={perceptual_type} - {str(e)}")
raise

View File

@@ -0,0 +1,503 @@
"""
记忆仓储模块 - 短期记忆和长期记忆的数据访问层
"""
from sqlalchemy.orm import Session
from typing import List, Optional, Dict, Any
import uuid
import datetime
from app.models.memory_short_model import ShortTermMemory, LongTermMemory
from app.core.logging_config import get_db_logger
# 获取数据库专用日志器
db_logger = get_db_logger()
class ShortTermMemoryRepository:
"""短期记忆仓储类"""
def __init__(self, db: Session):
self.db = db
def create(self, end_user_id: str, messages: str, aimessages: str = None, search_switch: str = None, retrieved_content: List[Dict] = None) -> ShortTermMemory:
"""创建短期记忆记录
Args:
end_user_id: 终端用户ID
messages: 用户消息内容
aimessages: AI回复消息内容
search_switch: 搜索开关状态
retrieved_content: 检索到的相关内容,格式为[{}, {}]
Returns:
ShortTermMemory: 创建的短期记忆对象
"""
try:
memory = ShortTermMemory(
end_user_id=end_user_id,
messages=messages,
aimessages=aimessages,
search_switch=search_switch,
retrieved_content=retrieved_content or []
)
self.db.add(memory)
self.db.commit()
self.db.refresh(memory)
db_logger.info(f"成功创建短期记忆记录: {memory.id} for user {end_user_id}")
return memory
except Exception as e:
self.db.rollback()
db_logger.error(f"创建短期记忆记录时出错: {str(e)}")
raise
def count_by_user_id(self,end_user_id: str) -> int:
"""根据ID获取短期记忆记录
Args:
memory_id: 记忆ID
Returns:
Optional[ShortTermMemory]: 记忆对象如果不存在则返回None
"""
try:
count = (
self.db.query(ShortTermMemory)
.filter(ShortTermMemory.end_user_id == end_user_id)
.count()
)
db_logger.debug(f"成功统计用户 {end_user_id} 的短期记忆数量: {count}")
return count
except Exception as e:
self.db.rollback()
db_logger.error(f"查询短期记忆记录 {count} 时出错: {str(e)}")
raise
def get_latest_by_user_id(self, end_user_id: str, limit: int = 5) -> List[ShortTermMemory]:
"""获取用户最新的短期记忆记录
Args:
end_user_id: 终端用户ID
limit: 返回记录数限制默认5条
Returns:
List[ShortTermMemory]: 最新的记忆记录列表,按创建时间倒序
"""
try:
# 使用复合索引 ix_memory_short_term_user_time 优化查询
memories = (
self.db.query(ShortTermMemory)
.filter(ShortTermMemory.end_user_id == end_user_id)
.order_by(ShortTermMemory.created_at.desc())
.limit(limit)
.all()
)
db_logger.info(f"成功查询用户 {end_user_id} 的最新 {len(memories)} 条短期记忆记录")
return memories
except Exception as e:
self.db.rollback()
db_logger.error(f"查询用户 {end_user_id} 的最新短期记忆记录时出错: {str(e)}")
raise
def get_recent_by_user_id(self, end_user_id: str, hours: int = 24) -> List[ShortTermMemory]:
"""获取用户最近指定小时内的短期记忆记录
Args:
end_user_id: 终端用户ID
hours: 时间范围小时默认24小时
Returns:
List[ShortTermMemory]: 记忆记录列表,按创建时间倒序
"""
try:
cutoff_time = datetime.datetime.now() - datetime.timedelta(hours=hours)
# 使用复合索引 ix_memory_short_term_user_time 优化查询
memories = (
self.db.query(ShortTermMemory)
.filter(
ShortTermMemory.end_user_id == end_user_id,
ShortTermMemory.created_at >= cutoff_time
)
.order_by(ShortTermMemory.created_at.desc())
.all()
)
db_logger.info(f"成功查询用户 {end_user_id} 最近 {hours} 小时的 {len(memories)} 条短期记忆记录")
return memories
except Exception as e:
self.db.rollback()
db_logger.error(f"查询用户 {end_user_id} 最近 {hours} 小时的短期记忆记录时出错: {str(e)}")
raise
def delete_by_id(self, memory_id: uuid.UUID) -> bool:
"""删除指定ID的短期记忆记录
Args:
memory_id: 记忆ID
Returns:
bool: 删除成功返回True否则返回False
"""
try:
deleted_count = (
self.db.query(ShortTermMemory)
.filter(ShortTermMemory.id == memory_id)
.delete(synchronize_session=False)
)
self.db.commit()
if deleted_count > 0:
db_logger.info(f"成功删除短期记忆记录 {memory_id}")
return True
else:
db_logger.warning(f"未找到短期记忆记录 {memory_id},无法删除")
return False
except Exception as e:
self.db.rollback()
db_logger.error(f"删除短期记忆记录 {memory_id} 时出错: {str(e)}")
raise
def delete_old_memories(self, days: int = 7) -> int:
"""删除指定天数之前的短期记忆记录
Args:
days: 保留天数默认7天
Returns:
int: 删除的记录数
"""
try:
cutoff_time = datetime.datetime.now() - datetime.timedelta(days=days)
deleted_count = (
self.db.query(ShortTermMemory)
.filter(ShortTermMemory.created_at < cutoff_time)
.delete(synchronize_session=False)
)
self.db.commit()
db_logger.info(f"成功删除 {days} 天前的 {deleted_count} 条短期记忆记录")
return deleted_count
except Exception as e:
self.db.rollback()
db_logger.error(f"删除 {days} 天前的短期记忆记录时出错: {str(e)}")
raise
def upsert(self, end_user_id: str, messages: str, aimessages: str = None, search_switch: str = None, retrieved_content: List[Dict] = None) -> ShortTermMemory:
"""创建或更新短期记忆记录
根据 end_user_id、messages 和 aimessages 查找现有记录:
- 如果找到匹配的记录,则更新 messages、aimessages、search_switch 和 retrieved_content
- 如果没有找到匹配的记录,则创建新记录
Args:
end_user_id: 终端用户ID
messages: 用户消息内容
aimessages: AI回复消息内容
search_switch: 搜索开关状态
retrieved_content: 检索到的相关内容,格式为[{}, {}]
Returns:
ShortTermMemory: 创建或更新的短期记忆对象
"""
try:
# 构建查询条件,使用复合索引 ix_memory_short_term_user_messages 优化查询
query_filters = [
ShortTermMemory.end_user_id == end_user_id,
ShortTermMemory.messages == messages
]
# 如果 aimessages 不为空,则加入查询条件
if aimessages is not None:
query_filters.append(ShortTermMemory.aimessages == aimessages)
else:
# 如果 aimessages 为 None则查找 aimessages 为 NULL 的记录
query_filters.append(ShortTermMemory.aimessages.is_(None))
# 查找现有记录
existing_memory = (
self.db.query(ShortTermMemory)
.filter(*query_filters)
.first()
)
if existing_memory:
# 更新现有记录
existing_memory.messages = messages
existing_memory.aimessages = aimessages
existing_memory.search_switch = search_switch
existing_memory.retrieved_content = retrieved_content or []
self.db.commit()
self.db.refresh(existing_memory)
db_logger.info(f"成功更新短期记忆记录: {existing_memory.id} for user {end_user_id}")
return existing_memory
else:
# 创建新记录
new_memory = ShortTermMemory(
end_user_id=end_user_id,
messages=messages,
aimessages=aimessages,
search_switch=search_switch,
retrieved_content=retrieved_content or []
)
self.db.add(new_memory)
self.db.commit()
self.db.refresh(new_memory)
db_logger.info(f"成功创建新的短期记忆记录: {new_memory.id} for user {end_user_id}")
return new_memory
except Exception as e:
self.db.rollback()
db_logger.error(f"创建或更新短期记忆记录时出错: {str(e)}")
raise
class LongTermMemoryRepository:
"""长期记忆仓储类"""
def __init__(self, db: Session):
self.db = db
def create(self, end_user_id: str, retrieved_content: List[Dict] = None) -> LongTermMemory:
"""创建长期记忆记录
Args:
end_user_id: 终端用户ID
retrieved_content: 检索到的相关内容,格式为[{}, {}]
Returns:
LongTermMemory: 创建的长期记忆对象
"""
try:
memory = LongTermMemory(
end_user_id=end_user_id,
retrieved_content=retrieved_content or []
)
self.db.add(memory)
self.db.commit()
self.db.refresh(memory)
db_logger.info(f"成功创建长期记忆记录: {memory.id} for user {end_user_id}")
return memory
except Exception as e:
self.db.rollback()
db_logger.error(f"创建长期记忆记录时出错: {str(e)}")
raise
def get_by_id(self, memory_id: uuid.UUID) -> Optional[LongTermMemory]:
"""根据ID获取长期记忆记录
Args:
memory_id: 记忆ID
Returns:
Optional[LongTermMemory]: 记忆对象如果不存在则返回None
"""
try:
memory = (
self.db.query(LongTermMemory)
.filter(LongTermMemory.id == memory_id)
.first()
)
if memory:
db_logger.debug(f"成功查询到长期记忆记录 {memory_id}")
else:
db_logger.debug(f"未找到长期记忆记录 {memory_id}")
return memory
except Exception as e:
self.db.rollback()
db_logger.error(f"查询长期记忆记录 {memory_id} 时出错: {str(e)}")
raise
def get_by_user_id(self, end_user_id: str, limit: int = 100, offset: int = 0) -> List[LongTermMemory]:
"""根据用户ID获取长期记忆记录列表
Args:
end_user_id: 终端用户ID
limit: 返回记录数限制默认100
offset: 偏移量默认0
Returns:
List[LongTermMemory]: 记忆记录列表,按创建时间倒序
"""
try:
# 使用复合索引 ix_memory_long_term_user_time 优化查询
memories = (
self.db.query(LongTermMemory)
.filter(LongTermMemory.end_user_id == end_user_id)
.order_by(LongTermMemory.created_at.desc())
.limit(limit)
.offset(offset)
.all()
)
db_logger.info(f"成功查询用户 {end_user_id}{len(memories)} 条长期记忆记录")
return memories
except Exception as e:
self.db.rollback()
db_logger.error(f"查询用户 {end_user_id} 的长期记忆记录时出错: {str(e)}")
raise
def search_by_content(self, end_user_id: str, keyword: str, limit: int = 50) -> List[LongTermMemory]:
"""根据内容关键词搜索长期记忆记录
Args:
end_user_id: 终端用户ID
keyword: 搜索关键词
limit: 返回记录数限制默认50
Returns:
List[LongTermMemory]: 匹配的记忆记录列表,按创建时间倒序
"""
try:
# 使用 GIN 索引 ix_memory_long_term_retrieved_content_gin 优化 JSON 搜索
# 同时使用复合索引 ix_memory_long_term_user_time 优化用户过滤
memories = (
self.db.query(LongTermMemory)
.filter(
LongTermMemory.end_user_id == end_user_id,
LongTermMemory.retrieved_content.astext.contains(keyword)
)
.order_by(LongTermMemory.created_at.desc())
.limit(limit)
.all()
)
db_logger.info(f"成功搜索用户 {end_user_id} 包含关键词 '{keyword}'{len(memories)} 条长期记忆记录")
return memories
except Exception as e:
self.db.rollback()
db_logger.error(f"搜索用户 {end_user_id} 包含关键词 '{keyword}' 的长期记忆记录时出错: {str(e)}")
raise
def delete_by_id(self, memory_id: uuid.UUID) -> bool:
"""删除指定ID的长期记忆记录
Args:
memory_id: 记忆ID
Returns:
bool: 删除成功返回True否则返回False
"""
try:
deleted_count = (
self.db.query(LongTermMemory)
.filter(LongTermMemory.id == memory_id)
.delete(synchronize_session=False)
)
self.db.commit()
if deleted_count > 0:
db_logger.info(f"成功删除长期记忆记录 {memory_id}")
return True
else:
db_logger.warning(f"未找到长期记忆记录 {memory_id},无法删除")
return False
except Exception as e:
self.db.rollback()
db_logger.error(f"删除长期记忆记录 {memory_id} 时出错: {str(e)}")
raise
def count_by_user_id(self, end_user_id: str) -> int:
"""统计用户的长期记忆记录数量
Args:
end_user_id: 终端用户ID
Returns:
int: 记录数量
"""
try:
count = (
self.db.query(LongTermMemory)
.filter(LongTermMemory.end_user_id == end_user_id)
.count()
)
db_logger.debug(f"用户 {end_user_id} 共有 {count} 条长期记忆记录")
return count
except Exception as e:
self.db.rollback()
db_logger.error(f"统计用户 {end_user_id} 的长期记忆记录数量时出错: {str(e)}")
raise
def upsert(self, end_user_id: str, retrieved_content: List[Dict] = None) -> Optional[LongTermMemory]:
"""创建或更新长期记忆记录
根据 end_user_id 和 retrieved_content 判断是否需要写入:
- 如果找到相同的 end_user_id 和 retrieved_content则不写入返回 None
- 如果没有找到相同的记录,则创建新记录
Args:
end_user_id: 终端用户ID
retrieved_content: 检索到的相关内容,格式为[{}, {}]
Returns:
Optional[LongTermMemory]: 创建的长期记忆对象,如果不需要写入则返回 None
"""
try:
retrieved_content = retrieved_content or []
# 优化查询:使用复合索引 ix_memory_long_term_user_time 先过滤用户
# 然后在应用层比较 JSON 内容,避免复杂的数据库 JSON 比较
existing_memories = (
self.db.query(LongTermMemory)
.filter(LongTermMemory.end_user_id == end_user_id)
.order_by(LongTermMemory.created_at.desc())
.limit(100) # 限制查询数量,避免加载过多数据
.all()
)
# 在 Python 中比较 retrieved_content
for memory in existing_memories:
if memory.retrieved_content == retrieved_content:
# 如果找到相同的记录,不写入
db_logger.info(f"长期记忆记录已存在,跳过写入: user {end_user_id}")
return None
# 如果没有找到相同的记录,创建新记录
new_memory = LongTermMemory(
end_user_id=end_user_id,
retrieved_content=retrieved_content
)
self.db.add(new_memory)
self.db.commit()
self.db.refresh(new_memory)
db_logger.info(f"成功创建新的长期记忆记录: {new_memory.id} for user {end_user_id}")
return new_memory
except Exception as e:
self.db.rollback()
db_logger.error(f"创建或更新长期记忆记录时出错: {str(e)}")
raise

View File

@@ -722,7 +722,12 @@ SET m += {
chunk_ids: summary.chunk_ids, chunk_ids: summary.chunk_ids,
content: summary.content, content: summary.content,
summary_embedding: summary.summary_embedding, summary_embedding: summary.summary_embedding,
config_id: summary.config_id config_id: summary.config_id,
importance_score: CASE WHEN summary.importance_score IS NOT NULL THEN summary.importance_score ELSE coalesce(m.importance_score, 0.5) END,
activation_value: CASE WHEN summary.activation_value IS NOT NULL THEN summary.activation_value ELSE m.activation_value END,
access_history: CASE WHEN summary.access_history IS NOT NULL THEN summary.access_history ELSE coalesce(m.access_history, []) END,
last_access_time: CASE WHEN summary.last_access_time IS NOT NULL THEN summary.last_access_time ELSE m.last_access_time END,
access_count: CASE WHEN summary.access_count IS NOT NULL THEN summary.access_count ELSE coalesce(m.access_count, 0) END
} }
RETURN m.id AS uuid RETURN m.id AS uuid
""" """

View File

@@ -58,7 +58,7 @@ class EntityRepository(BaseNeo4jRepository[ExtractedEntityNode]):
# 处理 ACT-R 属性 - 确保字段存在且有默认值 # 处理 ACT-R 属性 - 确保字段存在且有默认值
n['importance_score'] = n.get('importance_score', 0.5) n['importance_score'] = n.get('importance_score', 0.5)
n['activation_value'] = n.get('activation_value') n['activation_value'] = n.get('activation_value')
n['access_history'] = n.get('access_history', []) n['access_history'] = n.get('access_history') or []
n['last_access_time'] = n.get('last_access_time') n['last_access_time'] = n.get('last_access_time')
n['access_count'] = n.get('access_count', 0) n['access_count'] = n.get('access_count', 0)

View File

@@ -78,7 +78,7 @@ class StatementRepository(BaseNeo4jRepository[StatementNode]):
# 处理 ACT-R 属性 - 确保字段存在且有默认值 # 处理 ACT-R 属性 - 确保字段存在且有默认值
n['importance_score'] = n.get('importance_score', 0.5) n['importance_score'] = n.get('importance_score', 0.5)
n['activation_value'] = n.get('activation_value') n['activation_value'] = n.get('activation_value')
n['access_history'] = n.get('access_history', []) n['access_history'] = n.get('access_history') or []
n['last_access_time'] = n.get('last_access_time') n['last_access_time'] = n.get('last_access_time')
n['access_count'] = n.get('access_count', 0) n['access_count'] = n.get('access_count', 0)

View File

@@ -0,0 +1,133 @@
import uuid
from datetime import datetime
from typing import Optional
from pydantic import BaseModel, Field
from app.models.memory_perceptual_model import PerceptualType, FileStorageType
class PerceptualFilter(BaseModel):
type: PerceptualType | None = Field(
default=None,
description="Perceptual type used for filtering the query; optional"
)
class PerceptualQuerySchema(BaseModel):
filter: PerceptualFilter = Field(
default_factory=lambda: PerceptualFilter(),
description="Query filter containing perceptual type criteria"
)
page: int = Field(
default=1,
ge=1,
description="Page number for pagination, starting from 1"
)
page_size: int = Field(
default=10,
ge=1,
le=100,
description="Number of records per page, range 1-100"
)
class PerceptualMemoryItem(BaseModel):
"""感知记忆项"""
id: uuid.UUID = Field(..., description="Unique memory ID")
perceptual_type: PerceptualType = Field(..., description="Type of perception, e.g., text, audio, or video")
file_path: str = Field(..., description="File path in the storage service")
file_name: str = Field(..., description="File name")
summary: Optional[str] = Field(None, description="摘要")
storage_type: FileStorageType = Field(..., description="Storage type for file")
created_time: Optional[datetime] = Field(None, description="创建时间")
class Config:
from_attributes = True
class PerceptualTimelineResponse(BaseModel):
"""感知记忆时间线响应"""
total: int = Field(..., description="总数量")
page: int = Field(..., description="当前页码")
page_size: int = Field(..., description="每页大小")
total_pages: int = Field(..., description="总页数")
memories: list[PerceptualMemoryItem] = Field(..., description="记忆列表")
class Config:
from_attributes = True
# --------------------------
# TODO: FileMetaData
# --------------------------
class Identity(BaseModel):
title: str
filename: str
source: str # upload | crawl | system
author: Optional[str] = None
class Semantic(BaseModel):
topic: str
domain: str
difficulty: str # beginner | intermediate | advanced
intent: str # informative | instructional | promotional
sentiment: str # positive | neutral | negative
class Content(BaseModel):
summary: str
keywords: list[str]
topic: str
domain: str
class Usage(BaseModel):
target_audience: list[str]
use_cases: list[str]
class Stats(BaseModel):
duration_sec: Optional[int] = None
char_count: int
word_count: int
class Processing(BaseModel):
transcribed: bool
ocr_applied: bool
chunked: bool
vectorized: bool
embedding_model: Optional[str] = None
class VideoModal(BaseModel):
scene: list[str]
class AudioModal(BaseModel):
speaker_count: int
class TextModal(BaseModel):
section_count: int
class Asset(BaseModel):
type: str
modality: str # text | audio | video
format: str # docx | mp3 | mp4
language: str
encoding: str
identity: Identity
semantic: Semantic
content: Content
usage: Usage
stats: Stats
processing: Processing
created_at: str
modalities: AudioModal | TextModal | VideoModal

View File

@@ -4,6 +4,7 @@ Memory Agent Service
Handles business logic for memory agent operations including read/write services, Handles business logic for memory agent operations including read/write services,
health checks, and message type classification. health checks, and message type classification.
""" """
import datetime
import json import json
import os import os
import re import re
@@ -24,6 +25,7 @@ from app.core.memory.analytics.hot_memory_tags import get_hot_memory_tags
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.db import get_db_context from app.db import get_db_context
from app.models.knowledge_model import Knowledge, KnowledgeType from app.models.knowledge_model import Knowledge, KnowledgeType
from app.repositories.memory_short_repository import ShortTermMemoryRepository
from app.repositories.neo4j.neo4j_connector import Neo4jConnector from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.schemas.memory_config_schema import ConfigurationError, MemoryConfig from app.schemas.memory_config_schema import ConfigurationError, MemoryConfig
from app.services.memory_config_service import MemoryConfigService from app.services.memory_config_service import MemoryConfigService
@@ -393,7 +395,7 @@ class MemoryAgentService:
import time import time
start_time = time.time() start_time = time.time()
ori_message=message
# Resolve config_id if None using end_user's connected config # Resolve config_id if None using end_user's connected config
if config_id is None: if config_id is None:
try: try:
@@ -406,15 +408,15 @@ class MemoryAgentService:
raise # Re-raise our specific error raise # Re-raise our specific error
logger.error(f"Failed to get connected config for end_user {group_id}: {e}") logger.error(f"Failed to get connected config for end_user {group_id}: {e}")
raise ValueError(f"Unable to determine memory configuration for end_user {group_id}: {e}") raise ValueError(f"Unable to determine memory configuration for end_user {group_id}: {e}")
logger.info(f"Read operation for group {group_id} with config_id {config_id}") logger.info(f"Read operation for group {group_id} with config_id {config_id}")
# 导入审计日志记录器 # 导入审计日志记录器
try: try:
from app.core.memory.utils.log.audit_logger import audit_logger from app.core.memory.utils.log.audit_logger import audit_logger
except ImportError: except ImportError:
audit_logger = None audit_logger = None
# Get group lock to prevent concurrent processing # Get group lock to prevent concurrent processing
group_lock = self.get_group_lock(group_id) group_lock = self.get_group_lock(group_id)
@@ -430,7 +432,7 @@ class MemoryAgentService:
except ConfigurationError as e: except ConfigurationError as e:
error_msg = f"Failed to load configuration for config_id: {config_id}: {e}" error_msg = f"Failed to load configuration for config_id: {config_id}: {e}"
logger.error(error_msg) logger.error(error_msg)
# Log failed operation # Log failed operation
if audit_logger: if audit_logger:
duration = time.time() - start_time duration = time.time() - start_time
@@ -442,9 +444,9 @@ class MemoryAgentService:
duration=duration, duration=duration,
error=error_msg error=error_msg
) )
raise ValueError(error_msg) raise ValueError(error_msg)
# Step 2: Prepare history # Step 2: Prepare history
history.append({"role": "user", "content": message}) history.append({"role": "user", "content": message})
logger.debug(f"Group ID:{group_id}, Message:{message}, History:{history}, Config ID:{config_id}") logger.debug(f"Group ID:{group_id}, Message:{message}, History:{history}, Config ID:{config_id}")
@@ -452,7 +454,7 @@ class MemoryAgentService:
# Step 3: Initialize MCP client and execute read workflow # Step 3: Initialize MCP client and execute read workflow
mcp_config = get_mcp_server_config() mcp_config = get_mcp_server_config()
client = MultiServerMCPClient(mcp_config) client = MultiServerMCPClient(mcp_config)
async with client.session('data_flow') as session: async with client.session('data_flow') as session:
logger.debug("Connected to MCP Server: data_flow") logger.debug("Connected to MCP Server: data_flow")
tools = await load_mcp_tools(session) tools = await load_mcp_tools(session)
@@ -475,7 +477,7 @@ class MemoryAgentService:
# Capture any errors from the state # Capture any errors from the state
if event.get('errors'): if event.get('errors'):
workflow_errors.extend(event.get('errors', [])) workflow_errors.extend(event.get('errors', []))
for msg in messages: for msg in messages:
msg_content = msg.content msg_content = msg.content
msg_role = msg.__class__.__name__.lower().replace("message", "") msg_role = msg.__class__.__name__.lower().replace("message", "")
@@ -483,7 +485,7 @@ class MemoryAgentService:
"role": msg_role, "role": msg_role,
"content": msg_content "content": msg_content
}) })
# Extract intermediate outputs # Extract intermediate outputs
if hasattr(msg, 'content'): if hasattr(msg, 'content'):
try: try:
@@ -496,7 +498,7 @@ class MemoryAgentService:
break break
else: else:
continue # No text block found continue # No text block found
# Try to parse content as JSON # Try to parse content as JSON
if isinstance(content_to_parse, str): if isinstance(content_to_parse, str):
try: try:
@@ -506,16 +508,16 @@ class MemoryAgentService:
if '_intermediate' in parsed: if '_intermediate' in parsed:
intermediate_data = parsed['_intermediate'] intermediate_data = parsed['_intermediate']
output_key = self._create_intermediate_key(intermediate_data) output_key = self._create_intermediate_key(intermediate_data)
if output_key not in seen_intermediates: if output_key not in seen_intermediates:
seen_intermediates.add(output_key) seen_intermediates.add(output_key)
intermediate_outputs.append(self._format_intermediate_output(intermediate_data)) intermediate_outputs.append(self._format_intermediate_output(intermediate_data))
# Check for multiple intermediate outputs (from Retrieve) # Check for multiple intermediate outputs (from Retrieve)
if '_intermediates' in parsed: if '_intermediates' in parsed:
for intermediate_data in parsed['_intermediates']: for intermediate_data in parsed['_intermediates']:
output_key = self._create_intermediate_key(intermediate_data) output_key = self._create_intermediate_key(intermediate_data)
if output_key not in seen_intermediates: if output_key not in seen_intermediates:
seen_intermediates.add(output_key) seen_intermediates.add(output_key)
intermediate_outputs.append(self._format_intermediate_output(intermediate_data)) intermediate_outputs.append(self._format_intermediate_output(intermediate_data))
@@ -523,7 +525,7 @@ class MemoryAgentService:
pass pass
except Exception as e: except Exception as e:
logger.debug(f"Failed to extract intermediate output: {e}") logger.debug(f"Failed to extract intermediate output: {e}")
workflow_duration = time.time() - start workflow_duration = time.time() - start
logger.info(f"Read graph workflow completed in {workflow_duration}s") logger.info(f"Read graph workflow completed in {workflow_duration}s")
@@ -532,7 +534,7 @@ class MemoryAgentService:
for messages in outputs: for messages in outputs:
if messages['role'] == 'tool': if messages['role'] == 'tool':
message = messages['content'] message = messages['content']
# Handle MCP content format: [{'type': 'text', 'text': '...'}] # Handle MCP content format: [{'type': 'text', 'text': '...'}]
if isinstance(message, list): if isinstance(message, list):
# Extract text from MCP content blocks # Extract text from MCP content blocks
@@ -542,7 +544,7 @@ class MemoryAgentService:
break break
else: else:
continue # No text block found continue # No text block found
try: try:
parsed = json.loads(message) if isinstance(message, str) else message parsed = json.loads(message) if isinstance(message, str) else message
if isinstance(parsed, dict): if isinstance(parsed, dict):
@@ -552,15 +554,15 @@ class MemoryAgentService:
final_answer = summary_result final_answer = summary_result
except (json.JSONDecodeError, ValueError): except (json.JSONDecodeError, ValueError):
pass pass
# 记录成功的操作 # 记录成功的操作
total_duration = time.time() - start_time total_duration = time.time() - start_time
# Check for workflow errors # Check for workflow errors
if workflow_errors: if workflow_errors:
error_details = "; ".join([f"{e['tool']}: {e['error']}" for e in workflow_errors]) error_details = "; ".join([f"{e['tool']}: {e['error']}" for e in workflow_errors])
logger.warning(f"Read workflow completed with errors: {error_details}") logger.warning(f"Read workflow completed with errors: {error_details}")
if audit_logger: if audit_logger:
audit_logger.log_operation( audit_logger.log_operation(
operation="READ", operation="READ",
@@ -577,11 +579,11 @@ class MemoryAgentService:
"errors": workflow_errors "errors": workflow_errors
} }
) )
# Raise error if no answer was produced # Raise error if no answer was produced
if not final_answer: if not final_answer:
raise ValueError(f"Read workflow failed: {error_details}") raise ValueError(f"Read workflow failed: {error_details}")
if audit_logger and not workflow_errors: if audit_logger and not workflow_errors:
audit_logger.log_operation( audit_logger.log_operation(
operation="READ", operation="READ",
@@ -596,7 +598,31 @@ class MemoryAgentService:
"has_answer": bool(final_answer) "has_answer": bool(final_answer)
} }
) )
retrieved_content=[]
repo = ShortTermMemoryRepository(db)
if str(search_switch)!="2":
for intermediate in intermediate_outputs:
intermediate_type=intermediate['type']
if intermediate_type=="search_result":
query=intermediate['query']
raw_results=intermediate['raw_results']
reranked_results=raw_results.get('reranked_results',[])
statements=[statement['statement'] for statement in reranked_results.get('statements', [])]
statements=list(set(statements))
retrieved_content.append({query:statements})
if '信息不足,无法回答' in str(final_answer) or retrieved_content!=[]:
# 使用 upsert 方法
repo.upsert(
end_user_id=group_id, # 确保这个变量在作用域内
messages=ori_message,
aimessages=final_answer,
retrieved_content=retrieved_content,
search_switch=str(search_switch)
)
print("写入成功")
return { return {
"answer": final_answer, "answer": final_answer,
"intermediate_outputs": intermediate_outputs "intermediate_outputs": intermediate_outputs

View File

@@ -0,0 +1,166 @@
import uuid
from typing import Dict, Any, Optional
from sqlalchemy.orm import Session
from app.core.error_codes import BizCode
from app.core.exceptions import BusinessException
from app.core.logging_config import get_business_logger
from app.models.memory_perceptual_model import PerceptualType, FileStorageType
from app.repositories.memory_perceptual_repository import MemoryPerceptualRepository
from app.schemas.memory_perceptual_schema import (
PerceptualQuerySchema,
PerceptualTimelineResponse,
PerceptualMemoryItem,
AudioModal, Content, VideoModal, TextModal
)
business_logger = get_business_logger()
class MemoryPerceptualService:
def __init__(self, db: Session):
self.db = db
self.repository = MemoryPerceptualRepository(db)
def get_memory_count(self, end_user_id: uuid.UUID) -> Dict[str, Any]:
"""Retrieve perceptual memory statistics for a user."""
business_logger.info(f"Fetching perceptual memory statistics: end_user_id={end_user_id}")
try:
total_count = self.repository.get_count_by_user_id(end_user_id=end_user_id)
vision_count = self.repository.get_count_by_type(end_user_id, PerceptualType.VISION)
audio_count = self.repository.get_count_by_type(end_user_id, PerceptualType.AUDIO)
text_count = self.repository.get_count_by_type(end_user_id, PerceptualType.TEXT)
conversation_count = self.repository.get_count_by_type(end_user_id, PerceptualType.CONVERSATION)
stats = {
"total": total_count,
"by_type": {
"vision": vision_count,
"audio": audio_count,
"text": text_count,
"conversation": conversation_count
}
}
business_logger.info(f"Memory statistics fetched successfully: total={total_count}")
return stats
except Exception as e:
business_logger.error(f"Failed to fetch memory statistics: {str(e)}")
raise BusinessException(f"Failed to fetch memory statistics: {str(e)}", BizCode.DB_ERROR)
def _get_latest_memory_by_type(
self,
end_user_id: uuid.UUID,
perceptual_type: PerceptualType
) -> Optional[dict[str, Any]]:
"""Internal helper to retrieve the latest memory by type."""
business_logger.info(f"Fetching latest {perceptual_type.name.lower()} memory: end_user_id={end_user_id}")
try:
memories = self.repository.get_by_type(
end_user_id=end_user_id,
perceptual_type=perceptual_type,
limit=1,
offset=0
)
if not memories:
business_logger.info(f"No {perceptual_type.name.lower()} memory found: end_user_id={end_user_id}")
return None
memory = memories[0]
meta_data = memory.meta_data or {}
modalities = meta_data.get("modalities")
content = meta_data.get("content")
if not modalities:
raise BusinessException(f"Modalities not defined, perceptual memory_id={memory.id}", BizCode.DB_ERROR)
if not content:
raise BusinessException(f"Content not defined, perceptual memory_id={memory.id}", BizCode.DB_ERROR)
content = Content(**content)
match perceptual_type:
case PerceptualType.VISION:
modal = VideoModal(**modalities)
case PerceptualType.AUDIO:
modal = AudioModal(**modalities)
case PerceptualType.TEXT:
modal = TextModal(**modalities)
case _:
raise BusinessException("Unsupported perceptual type", BizCode.DB_ERROR)
detail = modal.model_dump()
result = {
"id": str(memory.id),
"file_name": memory.file_name,
"file_path": memory.file_path,
"storage_type": memory.storage_service,
"summary": memory.summary,
"keywords": content.keywords,
"topic": content.topic,
"domain": content.domain,
"created_time": memory.created_time.isoformat() if memory.created_time else None,
**detail
}
business_logger.info(
f"Latest {perceptual_type.name.lower()} memory retrieved successfully: file={memory.file_name}")
return result
except Exception as e:
business_logger.error(f"Failed to fetch latest {perceptual_type.name.lower()} memory: {str(e)}")
raise BusinessException(f"Failed to fetch latest {perceptual_type.name.lower()} memory: {str(e)}",
BizCode.DB_ERROR)
def get_latest_visual_memory(self, end_user_id: uuid.UUID) -> Optional[Dict[str, Any]]:
return self._get_latest_memory_by_type(end_user_id, PerceptualType.VISION)
def get_latest_audio_memory(self, end_user_id: uuid.UUID) -> Optional[Dict[str, Any]]:
return self._get_latest_memory_by_type(end_user_id, PerceptualType.AUDIO)
def get_latest_text_memory(self, end_user_id: uuid.UUID) -> Optional[Dict[str, Any]]:
return self._get_latest_memory_by_type(end_user_id, PerceptualType.TEXT)
def get_time_line(self, end_user_id: uuid.UUID, query: PerceptualQuerySchema) -> PerceptualTimelineResponse:
"""Retrieve a timeline of perceptual memories for a user."""
business_logger.info(f"Fetching perceptual memory timeline: "
f"end_user_id={end_user_id}, filter={query.filter}")
try:
if query.page < 1:
raise BusinessException("Page number must be greater than 0", BizCode.INVALID_PARAMETER)
if query.page_size < 1 or query.page_size > 100:
raise BusinessException("Page size must be between 1 and 100", BizCode.INVALID_PARAMETER)
total_count, memories = self.repository.get_timeline(end_user_id, query)
memory_items = []
for memory in memories:
memory_item = PerceptualMemoryItem(
id=memory.id,
perceptual_type=PerceptualType(memory.perceptual_type),
file_path=memory.file_path,
file_name=memory.file_name,
summary=memory.summary,
created_time=memory.created_time,
storage_type=FileStorageType(memory.storage_service),
)
memory_items.append(memory_item)
timeline_response = PerceptualTimelineResponse(
total=total_count,
page=query.page,
page_size=query.page_size,
total_pages=(total_count + query.page_size - 1) // query.page_size,
memories=memory_items
)
business_logger.info(f"Perceptual memory timeline retrieved successfully: "
f"total={total_count}, returned={len(memories)}")
return timeline_response
except BusinessException:
raise
except Exception as e:
business_logger.error(f"Failed to fetch perceptual memory timeline: {str(e)}")
raise BusinessException(f"Failed to fetch perceptual memory timeline: {str(e)}", BizCode.DB_ERROR)

View File

@@ -0,0 +1,56 @@
from app.core.logging_config import get_api_logger
from app.db import get_db
from app.repositories.memory_short_repository import LongTermMemoryRepository
from app.repositories.memory_short_repository import ShortTermMemoryRepository
api_logger = get_api_logger()
db=next(get_db())
class ShortService:
def __init__(self, end_user_id):
self.short_repo = ShortTermMemoryRepository(db)
self.end_user_id = end_user_id
def get_short_databasets(self):
short_memories = self.short_repo.get_latest_by_user_id(self.end_user_id, 3)
short_result = []
for memory in short_memories:
deep_expanded = {} # Create a new dictionary for each memory
messages = memory.messages
aimessages = memory.aimessages
retrieved_content = memory.retrieved_content or []
api_logger.debug(f"Retrieved content: {retrieved_content}")
retrieval_source = []
for item in retrieved_content:
if isinstance(item, dict):
for key, values in item.items():
retrieval_source.append({"query": key, "retrieval": values})
deep_expanded['retrieval'] = retrieval_source
deep_expanded['message'] = messages # 修正拼写错误
deep_expanded['answer'] = aimessages
short_result.append(deep_expanded)
return short_result
def get_short_count(self):
short_count = self.short_repo.count_by_user_id(self.end_user_id)
return short_count
class LongService:
def __init__(self, end_user_id):
self.long_repo = LongTermMemoryRepository(db)
self.end_user_id = end_user_id
def get_long_databasets(self):
# 获取长期记忆数据
long_memories = self.long_repo.get_by_user_id(self.end_user_id, 1)
long_result = []
for long_memory in long_memories:
if long_memory.retrieved_content:
for memory_item in long_memory.retrieved_content:
if isinstance(memory_item, dict):
for key, values in memory_item.items():
long_result.append({"query": key, "retrieval": values})
return long_result

View File

@@ -166,6 +166,8 @@ class PromptOptimizerService:
model_config = self.get_model_config(tenant_id, model_id) model_config = self.get_model_config(tenant_id, model_id)
session_history = self.get_session_message_history(session_id=session_id, user_id=user_id) session_history = self.get_session_message_history(session_id=session_id, user_id=user_id)
logger.info(f"Prompt optimization started, user_id={user_id}, session_id={session_id}")
# Create LLM instance # Create LLM instance
api_config: ModelApiKey = model_config.api_keys[0] api_config: ModelApiKey = model_config.api_keys[0]
llm = RedBearLLM(RedBearModelConfig( llm = RedBearLLM(RedBearModelConfig(
@@ -203,7 +205,6 @@ class PromptOptimizerService:
messages.extend(session_history[:-1]) # last message is current message messages.extend(session_history[:-1]) # last message is current message
messages.extend([(RoleType.USER.value, rendered_user_message)]) messages.extend([(RoleType.USER.value, rendered_user_message)])
logger.info(f"Prompt optimization message: {messages}")
buffer = "" buffer = ""
prompt_started = False prompt_started = False
prompt_finished = False prompt_finished = False
@@ -250,6 +251,7 @@ class PromptOptimizerService:
content=desc content=desc
) )
variables = self.parser_prompt_variables(optim_result.get("prompt")) variables = self.parser_prompt_variables(optim_result.get("prompt"))
logger.info(f"Prompt optimization completed, user_id={user_id}, session_id={session_id}")
yield {"desc": optim_result.get("desc"), "variables": variables} yield {"desc": optim_result.get("desc"), "variables": variables}
@staticmethod @staticmethod

View File

@@ -1496,8 +1496,8 @@ def _extract_node_properties(label: str, properties: Dict[str, Any]) -> Dict[str
field_whitelist = { field_whitelist = {
"Dialogue": ["content", "created_at"], "Dialogue": ["content", "created_at"],
"Chunk": ["content", "created_at"], "Chunk": ["content", "created_at"],
"Statement": ["temporal_info", "stmt_type", "statement", "valid_at", "created_at", "caption"], "Statement": ["temporal_info", "stmt_type", "statement", "valid_at", "created_at", "caption","emotion_keywords","emotion_type","emotion_subject"],
"ExtractedEntity": ["description", "name", "entity_type", "created_at", "caption"], "ExtractedEntity": ["description", "name", "entity_type", "created_at", "caption","aliases","connect_strength"],
"MemorySummary": ["summary", "content", "created_at", "caption"] # 添加 content 字段 "MemorySummary": ["summary", "content", "created_at", "caption"] # 添加 content 字段
} }