Merge branch 'develop' of github.com:SuanmoSuanyangTechnology/MemoryBear into develop
This commit is contained in:
@@ -24,6 +24,7 @@ from . import (
|
|||||||
memory_storage_controller,
|
memory_storage_controller,
|
||||||
memory_dashboard_controller,
|
memory_dashboard_controller,
|
||||||
memory_reflection_controller,
|
memory_reflection_controller,
|
||||||
|
memory_short_term_controller,
|
||||||
api_key_controller,
|
api_key_controller,
|
||||||
release_share_controller,
|
release_share_controller,
|
||||||
public_share_controller,
|
public_share_controller,
|
||||||
@@ -71,6 +72,7 @@ manager_router.include_router(emotion_controller.router)
|
|||||||
manager_router.include_router(emotion_config_controller.router)
|
manager_router.include_router(emotion_config_controller.router)
|
||||||
manager_router.include_router(prompt_optimizer_controller.router)
|
manager_router.include_router(prompt_optimizer_controller.router)
|
||||||
manager_router.include_router(memory_reflection_controller.router)
|
manager_router.include_router(memory_reflection_controller.router)
|
||||||
|
manager_router.include_router(memory_short_term_controller.router)
|
||||||
manager_router.include_router(tool_controller.router)
|
manager_router.include_router(tool_controller.router)
|
||||||
manager_router.include_router(memory_forget_controller.router)
|
manager_router.include_router(memory_forget_controller.router)
|
||||||
manager_router.include_router(home_page_controller.router)
|
manager_router.include_router(home_page_controller.router)
|
||||||
|
|||||||
255
api/app/controllers/memory_perceptual_controller.py
Normal file
255
api/app/controllers/memory_perceptual_controller.py
Normal file
@@ -0,0 +1,255 @@
|
|||||||
|
import uuid
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Depends, Query
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from app.core.error_codes import BizCode
|
||||||
|
from app.core.logging_config import get_api_logger
|
||||||
|
from app.core.response_utils import success, fail
|
||||||
|
from app.db import get_db
|
||||||
|
from app.dependencies import get_current_user
|
||||||
|
from app.models import User
|
||||||
|
from app.models.memory_perceptual_model import PerceptualType
|
||||||
|
from app.schemas.memory_perceptual_schema import (
|
||||||
|
PerceptualQuerySchema,
|
||||||
|
PerceptualFilter
|
||||||
|
)
|
||||||
|
from app.schemas.response_schema import ApiResponse
|
||||||
|
from app.services.memory_perceptual_service import MemoryPerceptualService
|
||||||
|
|
||||||
|
api_logger = get_api_logger()
|
||||||
|
|
||||||
|
router = APIRouter(
|
||||||
|
prefix="/memory/perceptual",
|
||||||
|
tags=["Perceptual Memory System"],
|
||||||
|
dependencies=[Depends(get_current_user)]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/{group_id}/count", response_model=ApiResponse)
|
||||||
|
def get_memory_count(
|
||||||
|
group_id: uuid.UUID,
|
||||||
|
current_user: User = Depends(get_current_user),
|
||||||
|
db: Session = Depends(get_db)
|
||||||
|
):
|
||||||
|
"""Retrieve perceptual memory statistics for a user group.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
group_id: ID of the user group (usually end_user_id in this context)
|
||||||
|
current_user: Current authenticated user
|
||||||
|
db: Database session
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ApiResponse: Response containing memory count statistics
|
||||||
|
"""
|
||||||
|
api_logger.info(f"Fetching perceptual memory statistics: user={current_user.username}, group_id={group_id}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
service = MemoryPerceptualService(db)
|
||||||
|
count_stats = service.get_memory_count(group_id)
|
||||||
|
|
||||||
|
api_logger.info(f"Memory statistics fetched successfully: total={count_stats.get('total', 0)}")
|
||||||
|
|
||||||
|
return success(
|
||||||
|
data=count_stats,
|
||||||
|
msg="Memory statistics retrieved successfully"
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
api_logger.error(f"Failed to fetch memory statistics: group_id={group_id}, error={str(e)}")
|
||||||
|
return fail(
|
||||||
|
code=BizCode.INTERNAL_ERROR,
|
||||||
|
msg="Failed to fetch memory statistics",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/{group_id}/last_visual", response_model=ApiResponse)
|
||||||
|
def get_last_visual_memory(
|
||||||
|
group_id: uuid.UUID,
|
||||||
|
current_user: User = Depends(get_current_user),
|
||||||
|
db: Session = Depends(get_db)
|
||||||
|
):
|
||||||
|
"""Retrieve the most recent VISION-type memory for a user.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
group_id: ID of the user group
|
||||||
|
current_user: Current authenticated user
|
||||||
|
db: Database session
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ApiResponse: Metadata of the latest visual memory
|
||||||
|
"""
|
||||||
|
api_logger.info(f"Fetching latest visual memory: user={current_user.username}, group_id={group_id}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
service = MemoryPerceptualService(db)
|
||||||
|
visual_memory = service.get_latest_visual_memory(group_id)
|
||||||
|
|
||||||
|
if visual_memory is None:
|
||||||
|
api_logger.info(f"No visual memory found: group_id={group_id}")
|
||||||
|
return success(
|
||||||
|
data=None,
|
||||||
|
msg="No visual memory available"
|
||||||
|
)
|
||||||
|
|
||||||
|
api_logger.info(f"Latest visual memory retrieved successfully: file={visual_memory.get('file_name')}")
|
||||||
|
|
||||||
|
return success(
|
||||||
|
data=visual_memory,
|
||||||
|
msg="Latest visual memory retrieved successfully"
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
api_logger.error(f"Failed to fetch latest visual memory: group_id={group_id}, error={str(e)}")
|
||||||
|
return fail(
|
||||||
|
code=BizCode.INTERNAL_ERROR,
|
||||||
|
msg="Failed to fetch latest visual memory",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/{group_id}/last_listen", response_model=ApiResponse)
|
||||||
|
def get_last_memory_listen(
|
||||||
|
group_id: uuid.UUID,
|
||||||
|
current_user: User = Depends(get_current_user),
|
||||||
|
db: Session = Depends(get_db)
|
||||||
|
):
|
||||||
|
"""Retrieve the most recent AUDIO-type memory for a user.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
group_id: ID of the user group
|
||||||
|
current_user: Current authenticated user
|
||||||
|
db: Database session
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ApiResponse: Metadata of the latest audio memory
|
||||||
|
"""
|
||||||
|
api_logger.info(f"Fetching latest audio memory: user={current_user.username}, group_id={group_id}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
service = MemoryPerceptualService(db)
|
||||||
|
audio_memory = service.get_latest_audio_memory(group_id)
|
||||||
|
|
||||||
|
if audio_memory is None:
|
||||||
|
api_logger.info(f"No audio memory found: group_id={group_id}")
|
||||||
|
return success(
|
||||||
|
data=None,
|
||||||
|
msg="No audio memory available"
|
||||||
|
)
|
||||||
|
|
||||||
|
api_logger.info(f"Latest audio memory retrieved successfully: file={audio_memory.get('file_name')}")
|
||||||
|
|
||||||
|
return success(
|
||||||
|
data=audio_memory,
|
||||||
|
msg="Latest audio memory retrieved successfully"
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
api_logger.error(f"Failed to fetch latest audio memory: group_id={group_id}, error={str(e)}")
|
||||||
|
return fail(
|
||||||
|
code=BizCode.INTERNAL_ERROR,
|
||||||
|
msg="Failed to fetch latest audio memory",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/{group_id}/last_text", response_model=ApiResponse)
|
||||||
|
def get_last_text_memory(
|
||||||
|
group_id: uuid.UUID,
|
||||||
|
current_user: User = Depends(get_current_user),
|
||||||
|
db: Session = Depends(get_db)
|
||||||
|
):
|
||||||
|
"""Retrieve the most recent TEXT-type memory for a user.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
group_id: ID of the user group
|
||||||
|
current_user: Current authenticated user
|
||||||
|
db: Database session
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ApiResponse: Metadata of the latest text memory
|
||||||
|
"""
|
||||||
|
api_logger.info(f"Fetching latest text memory: user={current_user.username}, group_id={group_id}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 调用服务层获取最近的文本记忆
|
||||||
|
service = MemoryPerceptualService(db)
|
||||||
|
text_memory = service.get_latest_text_memory(group_id)
|
||||||
|
|
||||||
|
if text_memory is None:
|
||||||
|
api_logger.info(f"No text memory found: group_id={group_id}")
|
||||||
|
return success(
|
||||||
|
data=None,
|
||||||
|
msg="No text memory available"
|
||||||
|
)
|
||||||
|
|
||||||
|
api_logger.info(f"Latest text memory retrieved successfully: file={text_memory.get('file_name')}")
|
||||||
|
|
||||||
|
return success(
|
||||||
|
data=text_memory,
|
||||||
|
msg="Latest text memory retrieved successfully"
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
api_logger.error(f"Failed to fetch latest text memory: group_id={group_id}, error={str(e)}")
|
||||||
|
return fail(
|
||||||
|
code=BizCode.INTERNAL_ERROR,
|
||||||
|
msg="Failed to fetch latest text memory",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/{group_id}/timeline", response_model=ApiResponse)
|
||||||
|
def get_memory_time_line(
|
||||||
|
group_id: uuid.UUID,
|
||||||
|
perceptual_type: Optional[PerceptualType] = Query(None, description="感知类型过滤"),
|
||||||
|
page: int = Query(1, ge=1, description="页码"),
|
||||||
|
page_size: int = Query(10, ge=1, le=100, description="每页大小"),
|
||||||
|
current_user: User = Depends(get_current_user),
|
||||||
|
db: Session = Depends(get_db)
|
||||||
|
):
|
||||||
|
"""Retrieve a timeline of perceptual memories for a user group.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
group_id: ID of the user group
|
||||||
|
perceptual_type: Optional filter for perceptual type
|
||||||
|
page: Page number for pagination
|
||||||
|
page_size: Number of items per page
|
||||||
|
current_user: Current authenticated user
|
||||||
|
db: Database session
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ApiResponse: Timeline data of perceptual memories
|
||||||
|
"""
|
||||||
|
api_logger.info(
|
||||||
|
f"Fetching perceptual memory timeline: user={current_user.username}, "
|
||||||
|
f"group_id={group_id}, type={perceptual_type}, page={page}"
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
query = PerceptualQuerySchema(
|
||||||
|
filter=PerceptualFilter(type=perceptual_type),
|
||||||
|
page=page,
|
||||||
|
page_size=page_size
|
||||||
|
)
|
||||||
|
|
||||||
|
service = MemoryPerceptualService(db)
|
||||||
|
timeline_data = service.get_time_line(group_id, query)
|
||||||
|
|
||||||
|
api_logger.info(
|
||||||
|
f"Perceptual memory timeline retrieved successfully: total={timeline_data.total}, "
|
||||||
|
f"returned={len(timeline_data.memories)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return success(
|
||||||
|
data=timeline_data.model_dump(),
|
||||||
|
msg="Perceptual memory timeline retrieved successfully"
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
api_logger.error(
|
||||||
|
f"Failed to fetch perceptual memory timeline: group_id={group_id}, "
|
||||||
|
f"error={str(e)}"
|
||||||
|
)
|
||||||
|
return fail(
|
||||||
|
code=BizCode.INTERNAL_ERROR,
|
||||||
|
msg="Failed to fetch perceptual memory timeline",
|
||||||
|
)
|
||||||
44
api/app/controllers/memory_short_term_controller.py
Normal file
44
api/app/controllers/memory_short_term_controller.py
Normal file
@@ -0,0 +1,44 @@
|
|||||||
|
from fastapi import APIRouter, Depends, HTTPException, status
|
||||||
|
from app.core.logging_config import get_api_logger
|
||||||
|
from app.core.response_utils import success
|
||||||
|
from app.db import get_db
|
||||||
|
from app.dependencies import get_current_user
|
||||||
|
from app.models.user_model import User
|
||||||
|
|
||||||
|
from app.services.memory_storage_service import search_entity
|
||||||
|
from app.services.memory_short_service import ShortService,LongService
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
from typing import Optional
|
||||||
|
load_dotenv()
|
||||||
|
api_logger = get_api_logger()
|
||||||
|
|
||||||
|
router = APIRouter(
|
||||||
|
prefix="/memory/short",
|
||||||
|
tags=["Memory"],
|
||||||
|
)
|
||||||
|
@router.get("/short_term")
|
||||||
|
async def short_term_configs(
|
||||||
|
end_user_id: str,
|
||||||
|
current_user: User = Depends(get_current_user),
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
):
|
||||||
|
# 获取短期记忆数据
|
||||||
|
short_term=ShortService(end_user_id)
|
||||||
|
short_result=short_term.get_short_databasets()
|
||||||
|
short_count=short_term.get_short_count()
|
||||||
|
|
||||||
|
long_term=LongService(end_user_id)
|
||||||
|
long_result=long_term.get_long_databasets()
|
||||||
|
|
||||||
|
entity_result = await search_entity(end_user_id)
|
||||||
|
result = {
|
||||||
|
'short_term': short_result,
|
||||||
|
'long_term': long_result,
|
||||||
|
'entity': entity_result.get('num', 0),
|
||||||
|
"retrieval_number":short_count,
|
||||||
|
"long_term_number":len(long_result)
|
||||||
|
}
|
||||||
|
|
||||||
|
return success(data=result, msg="短期记忆系统数据获取成功")
|
||||||
|
|
||||||
@@ -7,13 +7,20 @@ LangChain Agent 封装
|
|||||||
- 支持流式输出
|
- 支持流式输出
|
||||||
- 使用 RedBearLLM 支持多提供商
|
- 使用 RedBearLLM 支持多提供商
|
||||||
"""
|
"""
|
||||||
|
import os
|
||||||
import time
|
import time
|
||||||
from typing import Any, AsyncGenerator, Dict, List, Optional, Sequence
|
from typing import Any, AsyncGenerator, Dict, List, Optional, Sequence
|
||||||
|
|
||||||
|
|
||||||
|
from app.db import get_db
|
||||||
from app.core.logging_config import get_business_logger
|
from app.core.logging_config import get_business_logger
|
||||||
from app.core.memory.agent.utils.redis_tool import store
|
from app.core.memory.agent.utils.redis_tool import store
|
||||||
from app.core.models import RedBearLLM, RedBearModelConfig
|
from app.core.models import RedBearLLM, RedBearModelConfig
|
||||||
from app.models.models_model import ModelType
|
from app.models.models_model import ModelType
|
||||||
|
from app.repositories.memory_short_repository import LongTermMemoryRepository
|
||||||
|
from app.services.memory_agent_service import (
|
||||||
|
get_end_user_connected_config,
|
||||||
|
)
|
||||||
from app.services.memory_konwledges_server import write_rag
|
from app.services.memory_konwledges_server import write_rag
|
||||||
from app.services.task_service import get_task_memory_write_result
|
from app.services.task_service import get_task_memory_write_result
|
||||||
from app.tasks import write_message_task
|
from app.tasks import write_message_task
|
||||||
@@ -96,7 +103,8 @@ class LangChainAgent:
|
|||||||
"temperature": temperature,
|
"temperature": temperature,
|
||||||
"streaming": streaming,
|
"streaming": streaming,
|
||||||
"tool_count": len(self.tools),
|
"tool_count": len(self.tools),
|
||||||
"tool_names": [tool.name for tool in self.tools] if self.tools else []
|
"tool_names": [tool.name for tool in self.tools] if self.tools else [],
|
||||||
|
"tool_count": len(self.tools)
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -137,11 +145,8 @@ class LangChainAgent:
|
|||||||
messages.append(HumanMessage(content=user_content))
|
messages.append(HumanMessage(content=user_content))
|
||||||
|
|
||||||
return messages
|
return messages
|
||||||
|
|
||||||
async def term_memory_save(self,messages,end_user_end,aimessages):
|
async def term_memory_save(self,messages,end_user_end,aimessages):
|
||||||
"""
|
'''短长期存储redis,为不影响正常使用6句一段话,存储用户名加一个前缀,当数据存够6条返回给neo4j'''
|
||||||
短长期存储redis,为不影响正常使用6句一段话,存储用户名加一个前缀,当数据存够6条返回给neo4j
|
|
||||||
"""
|
|
||||||
end_user_end=f"Term_{end_user_end}"
|
end_user_end=f"Term_{end_user_end}"
|
||||||
print(messages)
|
print(messages)
|
||||||
print(aimessages)
|
print(aimessages)
|
||||||
@@ -155,17 +160,18 @@ class LangChainAgent:
|
|||||||
store.delete_duplicate_sessions()
|
store.delete_duplicate_sessions()
|
||||||
# logger.info(f'Redis_Agent:{end_user_end};{session_id}')
|
# logger.info(f'Redis_Agent:{end_user_end};{session_id}')
|
||||||
return session_id
|
return session_id
|
||||||
|
|
||||||
async def term_memory_redis_read(self,end_user_end):
|
async def term_memory_redis_read(self,end_user_end):
|
||||||
end_user_end = f"Term_{end_user_end}"
|
end_user_end = f"Term_{end_user_end}"
|
||||||
history = store.find_user_apply_group(end_user_end, end_user_end, end_user_end)
|
history = store.find_user_apply_group(end_user_end, end_user_end, end_user_end)
|
||||||
# logger.info(f'Redis_Agent:{end_user_end};{history}')
|
# logger.info(f'Redis_Agent:{end_user_end};{history}')
|
||||||
messagss_list=[]
|
messagss_list=[]
|
||||||
|
retrieved_content=[]
|
||||||
for messages in history:
|
for messages in history:
|
||||||
query = messages.get("Query")
|
query = messages.get("Query")
|
||||||
aimessages = messages.get("Answer")
|
aimessages = messages.get("Answer")
|
||||||
messagss_list.append(f'用户:{query}。AI回复:{aimessages}')
|
messagss_list.append(f'用户:{query}。AI回复:{aimessages}')
|
||||||
return messagss_list
|
retrieved_content.append({query: aimessages})
|
||||||
|
return messagss_list,retrieved_content
|
||||||
|
|
||||||
|
|
||||||
async def write(self,storage_type,end_user_id,message,user_rag_memory_id,actual_end_user_id,content,actual_config_id):
|
async def write(self,storage_type,end_user_id,message,user_rag_memory_id,actual_end_user_id,content,actual_config_id):
|
||||||
@@ -205,7 +211,6 @@ class LangChainAgent:
|
|||||||
# If config_id is None, try to get from end_user's connected config
|
# If config_id is None, try to get from end_user's connected config
|
||||||
if actual_config_id is None and end_user_id:
|
if actual_config_id is None and end_user_id:
|
||||||
try:
|
try:
|
||||||
from app.db import get_db
|
|
||||||
from app.services.memory_agent_service import (
|
from app.services.memory_agent_service import (
|
||||||
get_end_user_connected_config,
|
get_end_user_connected_config,
|
||||||
)
|
)
|
||||||
@@ -223,11 +228,26 @@ class LangChainAgent:
|
|||||||
logger.info(f'写入类型{storage_type,str(end_user_id), message, str(user_rag_memory_id)}')
|
logger.info(f'写入类型{storage_type,str(end_user_id), message, str(user_rag_memory_id)}')
|
||||||
print(f'写入类型{storage_type,str(end_user_id), message, str(user_rag_memory_id)}')
|
print(f'写入类型{storage_type,str(end_user_id), message, str(user_rag_memory_id)}')
|
||||||
|
|
||||||
history_term_memory=await self.term_memory_redis_read(end_user_id)
|
history_term_memory_result = await self.term_memory_redis_read(end_user_id)
|
||||||
|
history_term_memory = history_term_memory_result[0]
|
||||||
|
db_for_memory = next(get_db())
|
||||||
if memory_flag:
|
if memory_flag:
|
||||||
if len(history_term_memory)>=4 and storage_type != "rag":
|
if len(history_term_memory)>=4 and storage_type != "rag":
|
||||||
history_term_memory=';'.join(history_term_memory)
|
history_term_memory = ';'.join(history_term_memory)
|
||||||
logger.info(f'写入短长期:{storage_type, str(end_user_id), history_term_memory, str(user_rag_memory_id)}')
|
retrieved_content = history_term_memory_result[1]
|
||||||
|
print(retrieved_content)
|
||||||
|
# 为长期记忆操作获取新的数据库连接
|
||||||
|
try:
|
||||||
|
repo = LongTermMemoryRepository(db_for_memory)
|
||||||
|
repo.upsert(end_user_id, retrieved_content)
|
||||||
|
logger.info(
|
||||||
|
f'写入短长期:{storage_type, str(end_user_id), history_term_memory, str(user_rag_memory_id)}')
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to write to LongTermMemory: {e}")
|
||||||
|
raise
|
||||||
|
finally:
|
||||||
|
db_for_memory.close()
|
||||||
|
|
||||||
await self.write(storage_type,end_user_id,history_term_memory,user_rag_memory_id,actual_end_user_id,history_term_memory,actual_config_id)
|
await self.write(storage_type,end_user_id,history_term_memory,user_rag_memory_id,actual_end_user_id,history_term_memory,actual_config_id)
|
||||||
await self.write(storage_type,end_user_id,message,user_rag_memory_id,actual_end_user_id,message,actual_config_id)
|
await self.write(storage_type,end_user_id,message,user_rag_memory_id,actual_end_user_id,message,actual_config_id)
|
||||||
try:
|
try:
|
||||||
@@ -316,10 +336,6 @@ class LangChainAgent:
|
|||||||
# If config_id is None, try to get from end_user's connected config
|
# If config_id is None, try to get from end_user's connected config
|
||||||
if actual_config_id is None and end_user_id:
|
if actual_config_id is None and end_user_id:
|
||||||
try:
|
try:
|
||||||
from app.db import get_db
|
|
||||||
from app.services.memory_agent_service import (
|
|
||||||
get_end_user_connected_config,
|
|
||||||
)
|
|
||||||
db = next(get_db())
|
db = next(get_db())
|
||||||
try:
|
try:
|
||||||
connected_config = get_end_user_connected_config(end_user_id, db)
|
connected_config = get_end_user_connected_config(end_user_id, db)
|
||||||
@@ -331,14 +347,24 @@ class LangChainAgent:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Failed to get db session: {e}")
|
logger.warning(f"Failed to get db session: {e}")
|
||||||
|
|
||||||
history_term_memory = await self.term_memory_redis_read(end_user_id)
|
history_term_memory_result = await self.term_memory_redis_read(end_user_id)
|
||||||
|
history_term_memory = history_term_memory_result[0]
|
||||||
if memory_flag:
|
if memory_flag:
|
||||||
if len(history_term_memory) >= 4 and storage_type != "rag":
|
if len(history_term_memory) >= 4 and storage_type != "rag":
|
||||||
history_term_memory = ';'.join(history_term_memory)
|
history_term_memory = ';'.join(history_term_memory)
|
||||||
logger.info(
|
retrieved_content = history_term_memory_result[1]
|
||||||
f'写入短长期:{storage_type, str(end_user_id), history_term_memory, str(user_rag_memory_id)}')
|
db_for_memory = next(get_db())
|
||||||
await self.write(storage_type, end_user_id, history_term_memory, user_rag_memory_id, end_user_id,
|
try:
|
||||||
history_term_memory, actual_config_id)
|
repo = LongTermMemoryRepository(db_for_memory)
|
||||||
|
repo.upsert(end_user_id, retrieved_content)
|
||||||
|
logger.info(
|
||||||
|
f'写入短长期:{storage_type, str(end_user_id), history_term_memory, str(user_rag_memory_id)}')
|
||||||
|
await self.write(storage_type, end_user_id, history_term_memory, user_rag_memory_id, end_user_id,
|
||||||
|
history_term_memory, actual_config_id)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to write to long term memory: {e}")
|
||||||
|
finally:
|
||||||
|
db_for_memory.close()
|
||||||
|
|
||||||
await self.write(storage_type, end_user_id, message, user_rag_memory_id, end_user_id, message, actual_config_id)
|
await self.write(storage_type, end_user_id, message, user_rag_memory_id, end_user_id, message, actual_config_id)
|
||||||
try:
|
try:
|
||||||
|
|||||||
@@ -246,7 +246,7 @@ class AccessHistoryManager:
|
|||||||
if not node_data:
|
if not node_data:
|
||||||
return ConsistencyCheckResult.CONSISTENT, None
|
return ConsistencyCheckResult.CONSISTENT, None
|
||||||
|
|
||||||
access_history = node_data.get('access_history', [])
|
access_history = node_data.get('access_history') or []
|
||||||
last_access_time = node_data.get('last_access_time')
|
last_access_time = node_data.get('last_access_time')
|
||||||
access_count = node_data.get('access_count', 0)
|
access_count = node_data.get('access_count', 0)
|
||||||
activation_value = node_data.get('activation_value')
|
activation_value = node_data.get('activation_value')
|
||||||
@@ -409,7 +409,7 @@ class AccessHistoryManager:
|
|||||||
logger.error(f"节点不存在,无法修复: {node_label}[{node_id}]")
|
logger.error(f"节点不存在,无法修复: {node_label}[{node_id}]")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
access_history = node_data.get('access_history', [])
|
access_history = node_data.get('access_history') or []
|
||||||
importance_score = node_data.get('importance_score', 0.5)
|
importance_score = node_data.get('importance_score', 0.5)
|
||||||
|
|
||||||
# 准备修复数据
|
# 准备修复数据
|
||||||
@@ -530,7 +530,7 @@ class AccessHistoryManager:
|
|||||||
Returns:
|
Returns:
|
||||||
Dict[str, Any]: 更新数据,包含所有需要更新的字段
|
Dict[str, Any]: 更新数据,包含所有需要更新的字段
|
||||||
"""
|
"""
|
||||||
access_history = node_data.get('access_history', [])
|
access_history = node_data.get('access_history') or []
|
||||||
importance_score = node_data.get('importance_score', 0.5)
|
importance_score = node_data.get('importance_score', 0.5)
|
||||||
|
|
||||||
# 追加新的访问时间
|
# 追加新的访问时间
|
||||||
|
|||||||
@@ -73,8 +73,10 @@ class HttpContentTypeConfig(BaseModel):
|
|||||||
content_type = info.data.get("content_type")
|
content_type = info.data.get("content_type")
|
||||||
if content_type == HttpContentType.FROM_DATA and not isinstance(v, HttpFormData):
|
if content_type == HttpContentType.FROM_DATA and not isinstance(v, HttpFormData):
|
||||||
raise ValueError("When content_type is 'form-data', data must be of type HttpFormData")
|
raise ValueError("When content_type is 'form-data', data must be of type HttpFormData")
|
||||||
elif content_type in [HttpContentType.JSON, HttpContentType.WWW_FORM] and not isinstance(v, dict):
|
elif content_type in [HttpContentType.JSON] and not isinstance(v, str):
|
||||||
raise ValueError("When content_type is JSON or x-www-form-urlencoded, data must be a object")
|
raise ValueError("When content_type is JSON, data must be of type str")
|
||||||
|
elif content_type in [HttpContentType.WWW_FORM] and not isinstance(v, dict):
|
||||||
|
raise ValueError("When content_type is x-www-form-urlencoded, data must be a object")
|
||||||
elif content_type in [HttpContentType.RAW, HttpContentType.BINARY] and not isinstance(v, str):
|
elif content_type in [HttpContentType.RAW, HttpContentType.BINARY] and not isinstance(v, str):
|
||||||
raise ValueError("When content_type is raw/binary, data must be a string (File descriptor)")
|
raise ValueError("When content_type is raw/binary, data must be a string (File descriptor)")
|
||||||
return v
|
return v
|
||||||
|
|||||||
@@ -120,7 +120,7 @@ class HttpRequestNode(BaseNode):
|
|||||||
return {}
|
return {}
|
||||||
case HttpContentType.JSON:
|
case HttpContentType.JSON:
|
||||||
content["json"] = json.loads(self._render_template(
|
content["json"] = json.loads(self._render_template(
|
||||||
json.dumps(self.typed_config.body.data), state
|
self.typed_config.body.data, state
|
||||||
))
|
))
|
||||||
case HttpContentType.FROM_DATA:
|
case HttpContentType.FROM_DATA:
|
||||||
data = {}
|
data = {}
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ from .document_model import Document
|
|||||||
from .file_model import File
|
from .file_model import File
|
||||||
from .generic_file_model import GenericFile
|
from .generic_file_model import GenericFile
|
||||||
from .models_model import ModelConfig, ModelProvider, ModelType, ModelApiKey
|
from .models_model import ModelConfig, ModelProvider, ModelType, ModelApiKey
|
||||||
|
from .memory_short_model import ShortTermMemory, LongTermMemory
|
||||||
from .knowledgeshare_model import KnowledgeShare
|
from .knowledgeshare_model import KnowledgeShare
|
||||||
from .app_model import App
|
from .app_model import App
|
||||||
from .agent_app_config_model import AgentConfig
|
from .agent_app_config_model import AgentConfig
|
||||||
@@ -67,6 +68,8 @@ __all__ = [
|
|||||||
"BuiltinToolConfig",
|
"BuiltinToolConfig",
|
||||||
"CustomToolConfig",
|
"CustomToolConfig",
|
||||||
"MCPToolConfig",
|
"MCPToolConfig",
|
||||||
|
"ShortTermMemory",
|
||||||
|
"LongTermMemory",
|
||||||
"ToolExecution",
|
"ToolExecution",
|
||||||
"ToolType",
|
"ToolType",
|
||||||
"ToolStatus",
|
"ToolStatus",
|
||||||
|
|||||||
40
api/app/models/memory_perceptual_model.py
Normal file
40
api/app/models/memory_perceptual_model.py
Normal file
@@ -0,0 +1,40 @@
|
|||||||
|
import datetime
|
||||||
|
import uuid
|
||||||
|
from enum import IntEnum
|
||||||
|
|
||||||
|
from sqlalchemy import Column, ForeignKey, Integer, DateTime, String
|
||||||
|
from sqlalchemy.dialects.postgresql import UUID
|
||||||
|
from sqlalchemy.dialects.postgresql import JSONB
|
||||||
|
|
||||||
|
from app.db import Base
|
||||||
|
|
||||||
|
|
||||||
|
class PerceptualType(IntEnum):
|
||||||
|
VISION = 1
|
||||||
|
AUDIO = 2
|
||||||
|
TEXT = 3
|
||||||
|
CONVERSATION = 4
|
||||||
|
|
||||||
|
|
||||||
|
class FileStorageType(IntEnum):
|
||||||
|
LOCAL = 1
|
||||||
|
REMOTE = 2
|
||||||
|
|
||||||
|
|
||||||
|
class MemoryPerceptualModel(Base):
|
||||||
|
__tablename__ = "memory_perceptual"
|
||||||
|
|
||||||
|
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||||
|
end_user_id = Column(UUID(as_uuid=True), ForeignKey("end_users.id"), index=True)
|
||||||
|
|
||||||
|
perceptual_type = Column(Integer, index=True, nullable=False, comment="感知类型")
|
||||||
|
|
||||||
|
storage_service = Column(Integer, default=0, comment="存储服务类型")
|
||||||
|
file_path = Column(String, nullable=False, comment="文件路径")
|
||||||
|
file_name = Column(String, nullable=False, comment="文件名称")
|
||||||
|
file_ext = Column(String, nullable=False, comment="文件后缀名")
|
||||||
|
|
||||||
|
summary = Column(String, comment="摘要")
|
||||||
|
meta_data = Column(JSONB, comment="元信息")
|
||||||
|
|
||||||
|
created_time = Column(DateTime, default=datetime.datetime.now, comment="创建时间")
|
||||||
60
api/app/models/memory_short_model.py
Normal file
60
api/app/models/memory_short_model.py
Normal file
@@ -0,0 +1,60 @@
|
|||||||
|
"""
|
||||||
|
记忆模型 - 短期记忆和长期记忆表
|
||||||
|
"""
|
||||||
|
import uuid
|
||||||
|
import datetime
|
||||||
|
from sqlalchemy import Column, String, DateTime, Text, JSON
|
||||||
|
from sqlalchemy.dialects.postgresql import UUID
|
||||||
|
|
||||||
|
from app.db import Base
|
||||||
|
|
||||||
|
|
||||||
|
class ShortTermMemory(Base):
|
||||||
|
"""短期记忆表
|
||||||
|
|
||||||
|
用于存储临时的对话记忆,通常保存较短时间
|
||||||
|
"""
|
||||||
|
__tablename__ = "memory_short_term"
|
||||||
|
|
||||||
|
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4, index=True, comment="记忆ID")
|
||||||
|
|
||||||
|
# 用户信息
|
||||||
|
end_user_id = Column(String(255), nullable=False, index=True, comment="终端用户ID")
|
||||||
|
|
||||||
|
# 对话内容
|
||||||
|
messages = Column(Text, nullable=False, comment="用户消息内容")
|
||||||
|
aimessages = Column(Text, nullable=True, comment="AI回复消息内容")
|
||||||
|
|
||||||
|
# 搜索开关
|
||||||
|
search_switch = Column(String(50), nullable=True, comment="搜索开关状态")
|
||||||
|
|
||||||
|
# 检索内容 - 存储为JSON格式的列表,包含字典 [{}, {}]
|
||||||
|
retrieved_content = Column(JSON, nullable=True, default=list, comment="检索到的相关内容,格式为[{}, {}]")
|
||||||
|
|
||||||
|
# 时间戳
|
||||||
|
created_at = Column(DateTime, default=datetime.datetime.now, nullable=False, index=True, comment="创建时间")
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return f"<ShortTermMemory(id={self.id}, end_user_id={self.end_user_id}, created_at={self.created_at})>"
|
||||||
|
|
||||||
|
|
||||||
|
class LongTermMemory(Base):
|
||||||
|
"""长期记忆表
|
||||||
|
|
||||||
|
用于存储重要的对话记忆,长期保存
|
||||||
|
"""
|
||||||
|
__tablename__ = "memory_long_term"
|
||||||
|
|
||||||
|
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4, index=True, comment="记忆ID")
|
||||||
|
|
||||||
|
# 用户信息
|
||||||
|
end_user_id = Column(String(255), nullable=False, index=True, comment="终端用户ID")
|
||||||
|
|
||||||
|
# 检索内容 - 存储为JSON格式的列表,包含字典 [{}, {}]
|
||||||
|
retrieved_content = Column(JSON, nullable=True, default=list, comment="检索到的相关内容,格式为[{}, {}]")
|
||||||
|
|
||||||
|
# 时间戳
|
||||||
|
created_at = Column(DateTime, default=datetime.datetime.now, nullable=False, index=True, comment="创建时间")
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return f"<LongTermMemory(id={self.id}, end_user_id={self.end_user_id}, created_at={self.created_at})>"
|
||||||
156
api/app/repositories/memory_perceptual_repository.py
Normal file
156
api/app/repositories/memory_perceptual_repository.py
Normal file
@@ -0,0 +1,156 @@
|
|||||||
|
import uuid
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import List, Tuple, Optional
|
||||||
|
|
||||||
|
from sqlalchemy import and_, desc
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from app.core.logging_config import get_db_logger
|
||||||
|
from app.models.memory_perceptual_model import MemoryPerceptualModel, PerceptualType, FileStorageType
|
||||||
|
from app.schemas.memory_perceptual_schema import PerceptualQuerySchema
|
||||||
|
|
||||||
|
db_logger = get_db_logger()
|
||||||
|
|
||||||
|
|
||||||
|
class MemoryPerceptualRepository:
|
||||||
|
"""Data Access Layer for perceptual memory"""
|
||||||
|
|
||||||
|
def __init__(self, db: Session):
|
||||||
|
self.db = db
|
||||||
|
|
||||||
|
# ==================== Create and update ====================
|
||||||
|
def create_perceptual_memory(
|
||||||
|
self,
|
||||||
|
end_user_id: uuid.UUID,
|
||||||
|
perceptual_type: PerceptualType,
|
||||||
|
file_path: str,
|
||||||
|
file_name: str,
|
||||||
|
file_ext: str,
|
||||||
|
summary: Optional[str] = None,
|
||||||
|
meta_data: Optional[dict] = None,
|
||||||
|
storage_service: FileStorageType = FileStorageType.LOCAL
|
||||||
|
|
||||||
|
) -> MemoryPerceptualModel:
|
||||||
|
|
||||||
|
"""Create perceptual memory"""
|
||||||
|
|
||||||
|
db_logger.debug(f"Creating perceptual memory: end_user_id={end_user_id}, "
|
||||||
|
f"type={perceptual_type}, file={file_name}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
perceptual_memory = MemoryPerceptualModel(
|
||||||
|
end_user_id=end_user_id,
|
||||||
|
perceptual_type=perceptual_type,
|
||||||
|
storage_service=storage_service,
|
||||||
|
file_path=file_path,
|
||||||
|
file_name=file_name,
|
||||||
|
file_ext=file_ext,
|
||||||
|
summary=summary,
|
||||||
|
meta_data=meta_data,
|
||||||
|
created_time=datetime.now()
|
||||||
|
)
|
||||||
|
|
||||||
|
self.db.add(perceptual_memory)
|
||||||
|
self.db.flush()
|
||||||
|
|
||||||
|
db_logger.info(f"Perceptual memory created successfully: id={perceptual_memory.id}, file={file_name}")
|
||||||
|
return perceptual_memory
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
db_logger.error(f"Failed to create perceptual memory: end_user_id={end_user_id} - {str(e)}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
# ==================== Query ====================
|
||||||
|
def get_count_by_user_id(
|
||||||
|
self,
|
||||||
|
end_user_id: uuid.UUID,
|
||||||
|
):
|
||||||
|
db_logger.debug(f"Querying perceptual memory Count: end_user_id={end_user_id}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
count = self.db.query(MemoryPerceptualModel).filter(
|
||||||
|
MemoryPerceptualModel.end_user_id == end_user_id
|
||||||
|
).count()
|
||||||
|
return count
|
||||||
|
except Exception as e:
|
||||||
|
db_logger.error(f"Failed to query perceptual memory count: end_user_id={end_user_id} - {str(e)}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
def get_count_by_type(
|
||||||
|
self,
|
||||||
|
end_user_id: uuid.UUID,
|
||||||
|
perceptual_type: PerceptualType,
|
||||||
|
):
|
||||||
|
db_logger.debug(f"Querying perceptual memory Count: end_user_id={end_user_id}, type={perceptual_type}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
count = self.db.query(MemoryPerceptualModel).filter(
|
||||||
|
MemoryPerceptualModel.end_user_id == end_user_id,
|
||||||
|
MemoryPerceptualModel.perceptual_type == perceptual_type
|
||||||
|
).count()
|
||||||
|
return count
|
||||||
|
except Exception as e:
|
||||||
|
db_logger.error(f"Failed to query perceptual memory count: end_user_id={end_user_id} - {str(e)}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
def get_timeline(
|
||||||
|
self,
|
||||||
|
end_user_id: uuid.UUID,
|
||||||
|
query: PerceptualQuerySchema
|
||||||
|
) -> Tuple[int, List[MemoryPerceptualModel]]:
|
||||||
|
"""Get the timeline of a user's perceptual memories"""
|
||||||
|
db_logger.debug(f"Querying perceptual memory timeline: end_user_id={end_user_id}, filter={query.filter}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
base_query = self.db.query(MemoryPerceptualModel).filter(
|
||||||
|
MemoryPerceptualModel.end_user_id == end_user_id
|
||||||
|
)
|
||||||
|
|
||||||
|
if query.filter.type is not None:
|
||||||
|
base_query = base_query.filter(
|
||||||
|
MemoryPerceptualModel.perceptual_type == query.filter.type
|
||||||
|
)
|
||||||
|
|
||||||
|
total_count = base_query.count()
|
||||||
|
|
||||||
|
memories = base_query.order_by(
|
||||||
|
desc(MemoryPerceptualModel.created_time)
|
||||||
|
).offset(
|
||||||
|
(query.page - 1) * query.page_size
|
||||||
|
).limit(query.page_size).all()
|
||||||
|
|
||||||
|
db_logger.info(
|
||||||
|
f"Perceptual memory timeline query succeeded: end_user_id={end_user_id}, total={total_count}, returned={len(memories)}")
|
||||||
|
return total_count, memories
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
db_logger.error(f"Failed to query perceptual memory timeline: end_user_id={end_user_id} - {str(e)}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
def get_by_type(
|
||||||
|
self,
|
||||||
|
end_user_id: uuid.UUID,
|
||||||
|
perceptual_type: PerceptualType,
|
||||||
|
limit: int = 10,
|
||||||
|
offset: int = 0
|
||||||
|
) -> List[MemoryPerceptualModel]:
|
||||||
|
"""Get memories by perceptual type"""
|
||||||
|
db_logger.debug(f"Querying perceptual memories by type: end_user_id={end_user_id}, type={perceptual_type}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
memories = self.db.query(MemoryPerceptualModel).filter(
|
||||||
|
and_(
|
||||||
|
MemoryPerceptualModel.end_user_id == end_user_id,
|
||||||
|
MemoryPerceptualModel.perceptual_type == perceptual_type
|
||||||
|
)
|
||||||
|
).order_by(
|
||||||
|
desc(MemoryPerceptualModel.created_time)
|
||||||
|
).offset(offset).limit(limit).all()
|
||||||
|
|
||||||
|
db_logger.debug(f"Query by type succeeded: count={len(memories)}")
|
||||||
|
return memories
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
db_logger.error(f"Failed to query perceptual memories by type: end_user_id={end_user_id}, "
|
||||||
|
f"type={perceptual_type} - {str(e)}")
|
||||||
|
raise
|
||||||
503
api/app/repositories/memory_short_repository.py
Normal file
503
api/app/repositories/memory_short_repository.py
Normal file
@@ -0,0 +1,503 @@
|
|||||||
|
"""
|
||||||
|
记忆仓储模块 - 短期记忆和长期记忆的数据访问层
|
||||||
|
"""
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
from typing import List, Optional, Dict, Any
|
||||||
|
import uuid
|
||||||
|
import datetime
|
||||||
|
|
||||||
|
from app.models.memory_short_model import ShortTermMemory, LongTermMemory
|
||||||
|
from app.core.logging_config import get_db_logger
|
||||||
|
|
||||||
|
# 获取数据库专用日志器
|
||||||
|
db_logger = get_db_logger()
|
||||||
|
|
||||||
|
|
||||||
|
class ShortTermMemoryRepository:
|
||||||
|
"""短期记忆仓储类"""
|
||||||
|
|
||||||
|
def __init__(self, db: Session):
|
||||||
|
self.db = db
|
||||||
|
|
||||||
|
def create(self, end_user_id: str, messages: str, aimessages: str = None, search_switch: str = None, retrieved_content: List[Dict] = None) -> ShortTermMemory:
|
||||||
|
"""创建短期记忆记录
|
||||||
|
|
||||||
|
Args:
|
||||||
|
end_user_id: 终端用户ID
|
||||||
|
messages: 用户消息内容
|
||||||
|
aimessages: AI回复消息内容
|
||||||
|
search_switch: 搜索开关状态
|
||||||
|
retrieved_content: 检索到的相关内容,格式为[{}, {}]
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ShortTermMemory: 创建的短期记忆对象
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
memory = ShortTermMemory(
|
||||||
|
end_user_id=end_user_id,
|
||||||
|
messages=messages,
|
||||||
|
aimessages=aimessages,
|
||||||
|
search_switch=search_switch,
|
||||||
|
retrieved_content=retrieved_content or []
|
||||||
|
)
|
||||||
|
|
||||||
|
self.db.add(memory)
|
||||||
|
self.db.commit()
|
||||||
|
self.db.refresh(memory)
|
||||||
|
|
||||||
|
db_logger.info(f"成功创建短期记忆记录: {memory.id} for user {end_user_id}")
|
||||||
|
return memory
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self.db.rollback()
|
||||||
|
db_logger.error(f"创建短期记忆记录时出错: {str(e)}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
def count_by_user_id(self,end_user_id: str) -> int:
|
||||||
|
"""根据ID获取短期记忆记录
|
||||||
|
|
||||||
|
Args:
|
||||||
|
memory_id: 记忆ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Optional[ShortTermMemory]: 记忆对象,如果不存在则返回None
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
count = (
|
||||||
|
self.db.query(ShortTermMemory)
|
||||||
|
.filter(ShortTermMemory.end_user_id == end_user_id)
|
||||||
|
.count()
|
||||||
|
)
|
||||||
|
db_logger.debug(f"成功统计用户 {end_user_id} 的短期记忆数量: {count}")
|
||||||
|
|
||||||
|
return count
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self.db.rollback()
|
||||||
|
db_logger.error(f"查询短期记忆记录 {count} 时出错: {str(e)}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
def get_latest_by_user_id(self, end_user_id: str, limit: int = 5) -> List[ShortTermMemory]:
|
||||||
|
"""获取用户最新的短期记忆记录
|
||||||
|
|
||||||
|
Args:
|
||||||
|
end_user_id: 终端用户ID
|
||||||
|
limit: 返回记录数限制,默认5条
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List[ShortTermMemory]: 最新的记忆记录列表,按创建时间倒序
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 使用复合索引 ix_memory_short_term_user_time 优化查询
|
||||||
|
memories = (
|
||||||
|
self.db.query(ShortTermMemory)
|
||||||
|
.filter(ShortTermMemory.end_user_id == end_user_id)
|
||||||
|
.order_by(ShortTermMemory.created_at.desc())
|
||||||
|
.limit(limit)
|
||||||
|
.all()
|
||||||
|
)
|
||||||
|
|
||||||
|
db_logger.info(f"成功查询用户 {end_user_id} 的最新 {len(memories)} 条短期记忆记录")
|
||||||
|
return memories
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self.db.rollback()
|
||||||
|
db_logger.error(f"查询用户 {end_user_id} 的最新短期记忆记录时出错: {str(e)}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
def get_recent_by_user_id(self, end_user_id: str, hours: int = 24) -> List[ShortTermMemory]:
|
||||||
|
"""获取用户最近指定小时内的短期记忆记录
|
||||||
|
|
||||||
|
Args:
|
||||||
|
end_user_id: 终端用户ID
|
||||||
|
hours: 时间范围(小时),默认24小时
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List[ShortTermMemory]: 记忆记录列表,按创建时间倒序
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
cutoff_time = datetime.datetime.now() - datetime.timedelta(hours=hours)
|
||||||
|
|
||||||
|
# 使用复合索引 ix_memory_short_term_user_time 优化查询
|
||||||
|
memories = (
|
||||||
|
self.db.query(ShortTermMemory)
|
||||||
|
.filter(
|
||||||
|
ShortTermMemory.end_user_id == end_user_id,
|
||||||
|
ShortTermMemory.created_at >= cutoff_time
|
||||||
|
)
|
||||||
|
.order_by(ShortTermMemory.created_at.desc())
|
||||||
|
.all()
|
||||||
|
)
|
||||||
|
|
||||||
|
db_logger.info(f"成功查询用户 {end_user_id} 最近 {hours} 小时的 {len(memories)} 条短期记忆记录")
|
||||||
|
return memories
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self.db.rollback()
|
||||||
|
db_logger.error(f"查询用户 {end_user_id} 最近 {hours} 小时的短期记忆记录时出错: {str(e)}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
def delete_by_id(self, memory_id: uuid.UUID) -> bool:
|
||||||
|
"""删除指定ID的短期记忆记录
|
||||||
|
|
||||||
|
Args:
|
||||||
|
memory_id: 记忆ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: 删除成功返回True,否则返回False
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
deleted_count = (
|
||||||
|
self.db.query(ShortTermMemory)
|
||||||
|
.filter(ShortTermMemory.id == memory_id)
|
||||||
|
.delete(synchronize_session=False)
|
||||||
|
)
|
||||||
|
|
||||||
|
self.db.commit()
|
||||||
|
|
||||||
|
if deleted_count > 0:
|
||||||
|
db_logger.info(f"成功删除短期记忆记录 {memory_id}")
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
db_logger.warning(f"未找到短期记忆记录 {memory_id},无法删除")
|
||||||
|
return False
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self.db.rollback()
|
||||||
|
db_logger.error(f"删除短期记忆记录 {memory_id} 时出错: {str(e)}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
def delete_old_memories(self, days: int = 7) -> int:
|
||||||
|
"""删除指定天数之前的短期记忆记录
|
||||||
|
|
||||||
|
Args:
|
||||||
|
days: 保留天数,默认7天
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
int: 删除的记录数
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
cutoff_time = datetime.datetime.now() - datetime.timedelta(days=days)
|
||||||
|
|
||||||
|
deleted_count = (
|
||||||
|
self.db.query(ShortTermMemory)
|
||||||
|
.filter(ShortTermMemory.created_at < cutoff_time)
|
||||||
|
.delete(synchronize_session=False)
|
||||||
|
)
|
||||||
|
|
||||||
|
self.db.commit()
|
||||||
|
|
||||||
|
db_logger.info(f"成功删除 {days} 天前的 {deleted_count} 条短期记忆记录")
|
||||||
|
return deleted_count
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self.db.rollback()
|
||||||
|
db_logger.error(f"删除 {days} 天前的短期记忆记录时出错: {str(e)}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
def upsert(self, end_user_id: str, messages: str, aimessages: str = None, search_switch: str = None, retrieved_content: List[Dict] = None) -> ShortTermMemory:
|
||||||
|
"""创建或更新短期记忆记录
|
||||||
|
|
||||||
|
根据 end_user_id、messages 和 aimessages 查找现有记录:
|
||||||
|
- 如果找到匹配的记录,则更新 messages、aimessages、search_switch 和 retrieved_content
|
||||||
|
- 如果没有找到匹配的记录,则创建新记录
|
||||||
|
|
||||||
|
Args:
|
||||||
|
end_user_id: 终端用户ID
|
||||||
|
messages: 用户消息内容
|
||||||
|
aimessages: AI回复消息内容
|
||||||
|
search_switch: 搜索开关状态
|
||||||
|
retrieved_content: 检索到的相关内容,格式为[{}, {}]
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ShortTermMemory: 创建或更新的短期记忆对象
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 构建查询条件,使用复合索引 ix_memory_short_term_user_messages 优化查询
|
||||||
|
query_filters = [
|
||||||
|
ShortTermMemory.end_user_id == end_user_id,
|
||||||
|
ShortTermMemory.messages == messages
|
||||||
|
]
|
||||||
|
|
||||||
|
# 如果 aimessages 不为空,则加入查询条件
|
||||||
|
if aimessages is not None:
|
||||||
|
query_filters.append(ShortTermMemory.aimessages == aimessages)
|
||||||
|
else:
|
||||||
|
# 如果 aimessages 为 None,则查找 aimessages 为 NULL 的记录
|
||||||
|
query_filters.append(ShortTermMemory.aimessages.is_(None))
|
||||||
|
|
||||||
|
# 查找现有记录
|
||||||
|
existing_memory = (
|
||||||
|
self.db.query(ShortTermMemory)
|
||||||
|
.filter(*query_filters)
|
||||||
|
.first()
|
||||||
|
)
|
||||||
|
|
||||||
|
if existing_memory:
|
||||||
|
# 更新现有记录
|
||||||
|
existing_memory.messages = messages
|
||||||
|
existing_memory.aimessages = aimessages
|
||||||
|
existing_memory.search_switch = search_switch
|
||||||
|
existing_memory.retrieved_content = retrieved_content or []
|
||||||
|
|
||||||
|
self.db.commit()
|
||||||
|
self.db.refresh(existing_memory)
|
||||||
|
|
||||||
|
db_logger.info(f"成功更新短期记忆记录: {existing_memory.id} for user {end_user_id}")
|
||||||
|
return existing_memory
|
||||||
|
else:
|
||||||
|
# 创建新记录
|
||||||
|
new_memory = ShortTermMemory(
|
||||||
|
end_user_id=end_user_id,
|
||||||
|
messages=messages,
|
||||||
|
aimessages=aimessages,
|
||||||
|
search_switch=search_switch,
|
||||||
|
retrieved_content=retrieved_content or []
|
||||||
|
)
|
||||||
|
|
||||||
|
self.db.add(new_memory)
|
||||||
|
self.db.commit()
|
||||||
|
self.db.refresh(new_memory)
|
||||||
|
|
||||||
|
db_logger.info(f"成功创建新的短期记忆记录: {new_memory.id} for user {end_user_id}")
|
||||||
|
return new_memory
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self.db.rollback()
|
||||||
|
db_logger.error(f"创建或更新短期记忆记录时出错: {str(e)}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
class LongTermMemoryRepository:
|
||||||
|
"""长期记忆仓储类"""
|
||||||
|
|
||||||
|
def __init__(self, db: Session):
|
||||||
|
self.db = db
|
||||||
|
|
||||||
|
def create(self, end_user_id: str, retrieved_content: List[Dict] = None) -> LongTermMemory:
|
||||||
|
"""创建长期记忆记录
|
||||||
|
|
||||||
|
Args:
|
||||||
|
end_user_id: 终端用户ID
|
||||||
|
retrieved_content: 检索到的相关内容,格式为[{}, {}]
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
LongTermMemory: 创建的长期记忆对象
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
memory = LongTermMemory(
|
||||||
|
end_user_id=end_user_id,
|
||||||
|
retrieved_content=retrieved_content or []
|
||||||
|
)
|
||||||
|
|
||||||
|
self.db.add(memory)
|
||||||
|
self.db.commit()
|
||||||
|
self.db.refresh(memory)
|
||||||
|
|
||||||
|
db_logger.info(f"成功创建长期记忆记录: {memory.id} for user {end_user_id}")
|
||||||
|
return memory
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self.db.rollback()
|
||||||
|
db_logger.error(f"创建长期记忆记录时出错: {str(e)}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
def get_by_id(self, memory_id: uuid.UUID) -> Optional[LongTermMemory]:
|
||||||
|
"""根据ID获取长期记忆记录
|
||||||
|
|
||||||
|
Args:
|
||||||
|
memory_id: 记忆ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Optional[LongTermMemory]: 记忆对象,如果不存在则返回None
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
memory = (
|
||||||
|
self.db.query(LongTermMemory)
|
||||||
|
.filter(LongTermMemory.id == memory_id)
|
||||||
|
.first()
|
||||||
|
)
|
||||||
|
|
||||||
|
if memory:
|
||||||
|
db_logger.debug(f"成功查询到长期记忆记录 {memory_id}")
|
||||||
|
else:
|
||||||
|
db_logger.debug(f"未找到长期记忆记录 {memory_id}")
|
||||||
|
|
||||||
|
return memory
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self.db.rollback()
|
||||||
|
db_logger.error(f"查询长期记忆记录 {memory_id} 时出错: {str(e)}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
def get_by_user_id(self, end_user_id: str, limit: int = 100, offset: int = 0) -> List[LongTermMemory]:
|
||||||
|
"""根据用户ID获取长期记忆记录列表
|
||||||
|
|
||||||
|
Args:
|
||||||
|
end_user_id: 终端用户ID
|
||||||
|
limit: 返回记录数限制,默认100
|
||||||
|
offset: 偏移量,默认0
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List[LongTermMemory]: 记忆记录列表,按创建时间倒序
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 使用复合索引 ix_memory_long_term_user_time 优化查询
|
||||||
|
memories = (
|
||||||
|
self.db.query(LongTermMemory)
|
||||||
|
.filter(LongTermMemory.end_user_id == end_user_id)
|
||||||
|
.order_by(LongTermMemory.created_at.desc())
|
||||||
|
.limit(limit)
|
||||||
|
.offset(offset)
|
||||||
|
.all()
|
||||||
|
)
|
||||||
|
|
||||||
|
db_logger.info(f"成功查询用户 {end_user_id} 的 {len(memories)} 条长期记忆记录")
|
||||||
|
return memories
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self.db.rollback()
|
||||||
|
db_logger.error(f"查询用户 {end_user_id} 的长期记忆记录时出错: {str(e)}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
def search_by_content(self, end_user_id: str, keyword: str, limit: int = 50) -> List[LongTermMemory]:
|
||||||
|
"""根据内容关键词搜索长期记忆记录
|
||||||
|
|
||||||
|
Args:
|
||||||
|
end_user_id: 终端用户ID
|
||||||
|
keyword: 搜索关键词
|
||||||
|
limit: 返回记录数限制,默认50
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List[LongTermMemory]: 匹配的记忆记录列表,按创建时间倒序
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 使用 GIN 索引 ix_memory_long_term_retrieved_content_gin 优化 JSON 搜索
|
||||||
|
# 同时使用复合索引 ix_memory_long_term_user_time 优化用户过滤
|
||||||
|
memories = (
|
||||||
|
self.db.query(LongTermMemory)
|
||||||
|
.filter(
|
||||||
|
LongTermMemory.end_user_id == end_user_id,
|
||||||
|
LongTermMemory.retrieved_content.astext.contains(keyword)
|
||||||
|
)
|
||||||
|
.order_by(LongTermMemory.created_at.desc())
|
||||||
|
.limit(limit)
|
||||||
|
.all()
|
||||||
|
)
|
||||||
|
|
||||||
|
db_logger.info(f"成功搜索用户 {end_user_id} 包含关键词 '{keyword}' 的 {len(memories)} 条长期记忆记录")
|
||||||
|
return memories
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self.db.rollback()
|
||||||
|
db_logger.error(f"搜索用户 {end_user_id} 包含关键词 '{keyword}' 的长期记忆记录时出错: {str(e)}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
def delete_by_id(self, memory_id: uuid.UUID) -> bool:
|
||||||
|
"""删除指定ID的长期记忆记录
|
||||||
|
|
||||||
|
Args:
|
||||||
|
memory_id: 记忆ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: 删除成功返回True,否则返回False
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
deleted_count = (
|
||||||
|
self.db.query(LongTermMemory)
|
||||||
|
.filter(LongTermMemory.id == memory_id)
|
||||||
|
.delete(synchronize_session=False)
|
||||||
|
)
|
||||||
|
|
||||||
|
self.db.commit()
|
||||||
|
|
||||||
|
if deleted_count > 0:
|
||||||
|
db_logger.info(f"成功删除长期记忆记录 {memory_id}")
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
db_logger.warning(f"未找到长期记忆记录 {memory_id},无法删除")
|
||||||
|
return False
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self.db.rollback()
|
||||||
|
db_logger.error(f"删除长期记忆记录 {memory_id} 时出错: {str(e)}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
def count_by_user_id(self, end_user_id: str) -> int:
|
||||||
|
"""统计用户的长期记忆记录数量
|
||||||
|
|
||||||
|
Args:
|
||||||
|
end_user_id: 终端用户ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
int: 记录数量
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
count = (
|
||||||
|
self.db.query(LongTermMemory)
|
||||||
|
.filter(LongTermMemory.end_user_id == end_user_id)
|
||||||
|
.count()
|
||||||
|
)
|
||||||
|
|
||||||
|
db_logger.debug(f"用户 {end_user_id} 共有 {count} 条长期记忆记录")
|
||||||
|
return count
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self.db.rollback()
|
||||||
|
db_logger.error(f"统计用户 {end_user_id} 的长期记忆记录数量时出错: {str(e)}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
def upsert(self, end_user_id: str, retrieved_content: List[Dict] = None) -> Optional[LongTermMemory]:
|
||||||
|
"""创建或更新长期记忆记录
|
||||||
|
|
||||||
|
根据 end_user_id 和 retrieved_content 判断是否需要写入:
|
||||||
|
- 如果找到相同的 end_user_id 和 retrieved_content,则不写入,返回 None
|
||||||
|
- 如果没有找到相同的记录,则创建新记录
|
||||||
|
|
||||||
|
Args:
|
||||||
|
end_user_id: 终端用户ID
|
||||||
|
retrieved_content: 检索到的相关内容,格式为[{}, {}]
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Optional[LongTermMemory]: 创建的长期记忆对象,如果不需要写入则返回 None
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
retrieved_content = retrieved_content or []
|
||||||
|
|
||||||
|
# 优化查询:使用复合索引 ix_memory_long_term_user_time 先过滤用户
|
||||||
|
# 然后在应用层比较 JSON 内容,避免复杂的数据库 JSON 比较
|
||||||
|
existing_memories = (
|
||||||
|
self.db.query(LongTermMemory)
|
||||||
|
.filter(LongTermMemory.end_user_id == end_user_id)
|
||||||
|
.order_by(LongTermMemory.created_at.desc())
|
||||||
|
.limit(100) # 限制查询数量,避免加载过多数据
|
||||||
|
.all()
|
||||||
|
)
|
||||||
|
|
||||||
|
# 在 Python 中比较 retrieved_content
|
||||||
|
for memory in existing_memories:
|
||||||
|
if memory.retrieved_content == retrieved_content:
|
||||||
|
# 如果找到相同的记录,不写入
|
||||||
|
db_logger.info(f"长期记忆记录已存在,跳过写入: user {end_user_id}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
# 如果没有找到相同的记录,创建新记录
|
||||||
|
new_memory = LongTermMemory(
|
||||||
|
end_user_id=end_user_id,
|
||||||
|
retrieved_content=retrieved_content
|
||||||
|
)
|
||||||
|
|
||||||
|
self.db.add(new_memory)
|
||||||
|
self.db.commit()
|
||||||
|
self.db.refresh(new_memory)
|
||||||
|
|
||||||
|
db_logger.info(f"成功创建新的长期记忆记录: {new_memory.id} for user {end_user_id}")
|
||||||
|
return new_memory
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self.db.rollback()
|
||||||
|
db_logger.error(f"创建或更新长期记忆记录时出错: {str(e)}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
@@ -722,7 +722,12 @@ SET m += {
|
|||||||
chunk_ids: summary.chunk_ids,
|
chunk_ids: summary.chunk_ids,
|
||||||
content: summary.content,
|
content: summary.content,
|
||||||
summary_embedding: summary.summary_embedding,
|
summary_embedding: summary.summary_embedding,
|
||||||
config_id: summary.config_id
|
config_id: summary.config_id,
|
||||||
|
importance_score: CASE WHEN summary.importance_score IS NOT NULL THEN summary.importance_score ELSE coalesce(m.importance_score, 0.5) END,
|
||||||
|
activation_value: CASE WHEN summary.activation_value IS NOT NULL THEN summary.activation_value ELSE m.activation_value END,
|
||||||
|
access_history: CASE WHEN summary.access_history IS NOT NULL THEN summary.access_history ELSE coalesce(m.access_history, []) END,
|
||||||
|
last_access_time: CASE WHEN summary.last_access_time IS NOT NULL THEN summary.last_access_time ELSE m.last_access_time END,
|
||||||
|
access_count: CASE WHEN summary.access_count IS NOT NULL THEN summary.access_count ELSE coalesce(m.access_count, 0) END
|
||||||
}
|
}
|
||||||
RETURN m.id AS uuid
|
RETURN m.id AS uuid
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -58,7 +58,7 @@ class EntityRepository(BaseNeo4jRepository[ExtractedEntityNode]):
|
|||||||
# 处理 ACT-R 属性 - 确保字段存在且有默认值
|
# 处理 ACT-R 属性 - 确保字段存在且有默认值
|
||||||
n['importance_score'] = n.get('importance_score', 0.5)
|
n['importance_score'] = n.get('importance_score', 0.5)
|
||||||
n['activation_value'] = n.get('activation_value')
|
n['activation_value'] = n.get('activation_value')
|
||||||
n['access_history'] = n.get('access_history', [])
|
n['access_history'] = n.get('access_history') or []
|
||||||
n['last_access_time'] = n.get('last_access_time')
|
n['last_access_time'] = n.get('last_access_time')
|
||||||
n['access_count'] = n.get('access_count', 0)
|
n['access_count'] = n.get('access_count', 0)
|
||||||
|
|
||||||
|
|||||||
@@ -78,7 +78,7 @@ class StatementRepository(BaseNeo4jRepository[StatementNode]):
|
|||||||
# 处理 ACT-R 属性 - 确保字段存在且有默认值
|
# 处理 ACT-R 属性 - 确保字段存在且有默认值
|
||||||
n['importance_score'] = n.get('importance_score', 0.5)
|
n['importance_score'] = n.get('importance_score', 0.5)
|
||||||
n['activation_value'] = n.get('activation_value')
|
n['activation_value'] = n.get('activation_value')
|
||||||
n['access_history'] = n.get('access_history', [])
|
n['access_history'] = n.get('access_history') or []
|
||||||
n['last_access_time'] = n.get('last_access_time')
|
n['last_access_time'] = n.get('last_access_time')
|
||||||
n['access_count'] = n.get('access_count', 0)
|
n['access_count'] = n.get('access_count', 0)
|
||||||
|
|
||||||
|
|||||||
133
api/app/schemas/memory_perceptual_schema.py
Normal file
133
api/app/schemas/memory_perceptual_schema.py
Normal file
@@ -0,0 +1,133 @@
|
|||||||
|
import uuid
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
from app.models.memory_perceptual_model import PerceptualType, FileStorageType
|
||||||
|
|
||||||
|
|
||||||
|
class PerceptualFilter(BaseModel):
|
||||||
|
type: PerceptualType | None = Field(
|
||||||
|
default=None,
|
||||||
|
description="Perceptual type used for filtering the query; optional"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class PerceptualQuerySchema(BaseModel):
|
||||||
|
filter: PerceptualFilter = Field(
|
||||||
|
default_factory=lambda: PerceptualFilter(),
|
||||||
|
description="Query filter containing perceptual type criteria"
|
||||||
|
)
|
||||||
|
|
||||||
|
page: int = Field(
|
||||||
|
default=1,
|
||||||
|
ge=1,
|
||||||
|
description="Page number for pagination, starting from 1"
|
||||||
|
)
|
||||||
|
|
||||||
|
page_size: int = Field(
|
||||||
|
default=10,
|
||||||
|
ge=1,
|
||||||
|
le=100,
|
||||||
|
description="Number of records per page, range 1-100"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class PerceptualMemoryItem(BaseModel):
|
||||||
|
"""感知记忆项"""
|
||||||
|
id: uuid.UUID = Field(..., description="Unique memory ID")
|
||||||
|
perceptual_type: PerceptualType = Field(..., description="Type of perception, e.g., text, audio, or video")
|
||||||
|
file_path: str = Field(..., description="File path in the storage service")
|
||||||
|
file_name: str = Field(..., description="File name")
|
||||||
|
summary: Optional[str] = Field(None, description="摘要")
|
||||||
|
storage_type: FileStorageType = Field(..., description="Storage type for file")
|
||||||
|
created_time: Optional[datetime] = Field(None, description="创建时间")
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
from_attributes = True
|
||||||
|
|
||||||
|
|
||||||
|
class PerceptualTimelineResponse(BaseModel):
|
||||||
|
"""感知记忆时间线响应"""
|
||||||
|
total: int = Field(..., description="总数量")
|
||||||
|
page: int = Field(..., description="当前页码")
|
||||||
|
page_size: int = Field(..., description="每页大小")
|
||||||
|
total_pages: int = Field(..., description="总页数")
|
||||||
|
memories: list[PerceptualMemoryItem] = Field(..., description="记忆列表")
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
from_attributes = True
|
||||||
|
|
||||||
|
|
||||||
|
# --------------------------
|
||||||
|
# TODO: FileMetaData
|
||||||
|
# --------------------------
|
||||||
|
class Identity(BaseModel):
|
||||||
|
title: str
|
||||||
|
filename: str
|
||||||
|
source: str # upload | crawl | system
|
||||||
|
author: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
|
class Semantic(BaseModel):
|
||||||
|
topic: str
|
||||||
|
domain: str
|
||||||
|
difficulty: str # beginner | intermediate | advanced
|
||||||
|
intent: str # informative | instructional | promotional
|
||||||
|
sentiment: str # positive | neutral | negative
|
||||||
|
|
||||||
|
|
||||||
|
class Content(BaseModel):
|
||||||
|
summary: str
|
||||||
|
keywords: list[str]
|
||||||
|
topic: str
|
||||||
|
domain: str
|
||||||
|
|
||||||
|
|
||||||
|
class Usage(BaseModel):
|
||||||
|
target_audience: list[str]
|
||||||
|
use_cases: list[str]
|
||||||
|
|
||||||
|
|
||||||
|
class Stats(BaseModel):
|
||||||
|
duration_sec: Optional[int] = None
|
||||||
|
char_count: int
|
||||||
|
word_count: int
|
||||||
|
|
||||||
|
|
||||||
|
class Processing(BaseModel):
|
||||||
|
transcribed: bool
|
||||||
|
ocr_applied: bool
|
||||||
|
chunked: bool
|
||||||
|
vectorized: bool
|
||||||
|
embedding_model: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
|
class VideoModal(BaseModel):
|
||||||
|
scene: list[str]
|
||||||
|
|
||||||
|
|
||||||
|
class AudioModal(BaseModel):
|
||||||
|
speaker_count: int
|
||||||
|
|
||||||
|
|
||||||
|
class TextModal(BaseModel):
|
||||||
|
section_count: int
|
||||||
|
|
||||||
|
|
||||||
|
class Asset(BaseModel):
|
||||||
|
type: str
|
||||||
|
modality: str # text | audio | video
|
||||||
|
format: str # docx | mp3 | mp4
|
||||||
|
language: str
|
||||||
|
encoding: str
|
||||||
|
|
||||||
|
identity: Identity
|
||||||
|
semantic: Semantic
|
||||||
|
content: Content
|
||||||
|
usage: Usage
|
||||||
|
stats: Stats
|
||||||
|
processing: Processing
|
||||||
|
created_at: str
|
||||||
|
modalities: AudioModal | TextModal | VideoModal
|
||||||
@@ -4,6 +4,7 @@ Memory Agent Service
|
|||||||
Handles business logic for memory agent operations including read/write services,
|
Handles business logic for memory agent operations including read/write services,
|
||||||
health checks, and message type classification.
|
health checks, and message type classification.
|
||||||
"""
|
"""
|
||||||
|
import datetime
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
@@ -24,6 +25,7 @@ from app.core.memory.analytics.hot_memory_tags import get_hot_memory_tags
|
|||||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||||
from app.db import get_db_context
|
from app.db import get_db_context
|
||||||
from app.models.knowledge_model import Knowledge, KnowledgeType
|
from app.models.knowledge_model import Knowledge, KnowledgeType
|
||||||
|
from app.repositories.memory_short_repository import ShortTermMemoryRepository
|
||||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||||
from app.schemas.memory_config_schema import ConfigurationError, MemoryConfig
|
from app.schemas.memory_config_schema import ConfigurationError, MemoryConfig
|
||||||
from app.services.memory_config_service import MemoryConfigService
|
from app.services.memory_config_service import MemoryConfigService
|
||||||
@@ -393,7 +395,7 @@ class MemoryAgentService:
|
|||||||
|
|
||||||
import time
|
import time
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
ori_message=message
|
||||||
# Resolve config_id if None using end_user's connected config
|
# Resolve config_id if None using end_user's connected config
|
||||||
if config_id is None:
|
if config_id is None:
|
||||||
try:
|
try:
|
||||||
@@ -596,6 +598,30 @@ class MemoryAgentService:
|
|||||||
"has_answer": bool(final_answer)
|
"has_answer": bool(final_answer)
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
retrieved_content=[]
|
||||||
|
repo = ShortTermMemoryRepository(db)
|
||||||
|
if str(search_switch)!="2":
|
||||||
|
for intermediate in intermediate_outputs:
|
||||||
|
intermediate_type=intermediate['type']
|
||||||
|
if intermediate_type=="search_result":
|
||||||
|
query=intermediate['query']
|
||||||
|
raw_results=intermediate['raw_results']
|
||||||
|
reranked_results=raw_results.get('reranked_results',[])
|
||||||
|
statements=[statement['statement'] for statement in reranked_results.get('statements', [])]
|
||||||
|
statements=list(set(statements))
|
||||||
|
retrieved_content.append({query:statements})
|
||||||
|
if '信息不足,无法回答' in str(final_answer) or retrieved_content!=[]:
|
||||||
|
# 使用 upsert 方法
|
||||||
|
repo.upsert(
|
||||||
|
end_user_id=group_id, # 确保这个变量在作用域内
|
||||||
|
messages=ori_message,
|
||||||
|
aimessages=final_answer,
|
||||||
|
retrieved_content=retrieved_content,
|
||||||
|
search_switch=str(search_switch)
|
||||||
|
)
|
||||||
|
print("写入成功")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"answer": final_answer,
|
"answer": final_answer,
|
||||||
|
|||||||
166
api/app/services/memory_perceptual_service.py
Normal file
166
api/app/services/memory_perceptual_service.py
Normal file
@@ -0,0 +1,166 @@
|
|||||||
|
import uuid
|
||||||
|
from typing import Dict, Any, Optional
|
||||||
|
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from app.core.error_codes import BizCode
|
||||||
|
from app.core.exceptions import BusinessException
|
||||||
|
from app.core.logging_config import get_business_logger
|
||||||
|
from app.models.memory_perceptual_model import PerceptualType, FileStorageType
|
||||||
|
from app.repositories.memory_perceptual_repository import MemoryPerceptualRepository
|
||||||
|
from app.schemas.memory_perceptual_schema import (
|
||||||
|
PerceptualQuerySchema,
|
||||||
|
PerceptualTimelineResponse,
|
||||||
|
PerceptualMemoryItem,
|
||||||
|
AudioModal, Content, VideoModal, TextModal
|
||||||
|
)
|
||||||
|
|
||||||
|
business_logger = get_business_logger()
|
||||||
|
|
||||||
|
|
||||||
|
class MemoryPerceptualService:
|
||||||
|
def __init__(self, db: Session):
|
||||||
|
self.db = db
|
||||||
|
self.repository = MemoryPerceptualRepository(db)
|
||||||
|
|
||||||
|
def get_memory_count(self, end_user_id: uuid.UUID) -> Dict[str, Any]:
|
||||||
|
"""Retrieve perceptual memory statistics for a user."""
|
||||||
|
business_logger.info(f"Fetching perceptual memory statistics: end_user_id={end_user_id}")
|
||||||
|
try:
|
||||||
|
total_count = self.repository.get_count_by_user_id(end_user_id=end_user_id)
|
||||||
|
|
||||||
|
vision_count = self.repository.get_count_by_type(end_user_id, PerceptualType.VISION)
|
||||||
|
audio_count = self.repository.get_count_by_type(end_user_id, PerceptualType.AUDIO)
|
||||||
|
text_count = self.repository.get_count_by_type(end_user_id, PerceptualType.TEXT)
|
||||||
|
conversation_count = self.repository.get_count_by_type(end_user_id, PerceptualType.CONVERSATION)
|
||||||
|
|
||||||
|
stats = {
|
||||||
|
"total": total_count,
|
||||||
|
"by_type": {
|
||||||
|
"vision": vision_count,
|
||||||
|
"audio": audio_count,
|
||||||
|
"text": text_count,
|
||||||
|
"conversation": conversation_count
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
business_logger.info(f"Memory statistics fetched successfully: total={total_count}")
|
||||||
|
return stats
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
business_logger.error(f"Failed to fetch memory statistics: {str(e)}")
|
||||||
|
raise BusinessException(f"Failed to fetch memory statistics: {str(e)}", BizCode.DB_ERROR)
|
||||||
|
|
||||||
|
def _get_latest_memory_by_type(
|
||||||
|
self,
|
||||||
|
end_user_id: uuid.UUID,
|
||||||
|
perceptual_type: PerceptualType
|
||||||
|
) -> Optional[dict[str, Any]]:
|
||||||
|
"""Internal helper to retrieve the latest memory by type."""
|
||||||
|
business_logger.info(f"Fetching latest {perceptual_type.name.lower()} memory: end_user_id={end_user_id}")
|
||||||
|
try:
|
||||||
|
memories = self.repository.get_by_type(
|
||||||
|
end_user_id=end_user_id,
|
||||||
|
perceptual_type=perceptual_type,
|
||||||
|
limit=1,
|
||||||
|
offset=0
|
||||||
|
)
|
||||||
|
if not memories:
|
||||||
|
business_logger.info(f"No {perceptual_type.name.lower()} memory found: end_user_id={end_user_id}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
memory = memories[0]
|
||||||
|
meta_data = memory.meta_data or {}
|
||||||
|
modalities = meta_data.get("modalities")
|
||||||
|
content = meta_data.get("content")
|
||||||
|
|
||||||
|
if not modalities:
|
||||||
|
raise BusinessException(f"Modalities not defined, perceptual memory_id={memory.id}", BizCode.DB_ERROR)
|
||||||
|
if not content:
|
||||||
|
raise BusinessException(f"Content not defined, perceptual memory_id={memory.id}", BizCode.DB_ERROR)
|
||||||
|
content = Content(**content)
|
||||||
|
match perceptual_type:
|
||||||
|
case PerceptualType.VISION:
|
||||||
|
modal = VideoModal(**modalities)
|
||||||
|
case PerceptualType.AUDIO:
|
||||||
|
modal = AudioModal(**modalities)
|
||||||
|
case PerceptualType.TEXT:
|
||||||
|
modal = TextModal(**modalities)
|
||||||
|
case _:
|
||||||
|
raise BusinessException("Unsupported perceptual type", BizCode.DB_ERROR)
|
||||||
|
detail = modal.model_dump()
|
||||||
|
|
||||||
|
result = {
|
||||||
|
"id": str(memory.id),
|
||||||
|
"file_name": memory.file_name,
|
||||||
|
"file_path": memory.file_path,
|
||||||
|
"storage_type": memory.storage_service,
|
||||||
|
"summary": memory.summary,
|
||||||
|
"keywords": content.keywords,
|
||||||
|
"topic": content.topic,
|
||||||
|
"domain": content.domain,
|
||||||
|
"created_time": memory.created_time.isoformat() if memory.created_time else None,
|
||||||
|
**detail
|
||||||
|
}
|
||||||
|
|
||||||
|
business_logger.info(
|
||||||
|
f"Latest {perceptual_type.name.lower()} memory retrieved successfully: file={memory.file_name}")
|
||||||
|
return result
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
business_logger.error(f"Failed to fetch latest {perceptual_type.name.lower()} memory: {str(e)}")
|
||||||
|
raise BusinessException(f"Failed to fetch latest {perceptual_type.name.lower()} memory: {str(e)}",
|
||||||
|
BizCode.DB_ERROR)
|
||||||
|
|
||||||
|
def get_latest_visual_memory(self, end_user_id: uuid.UUID) -> Optional[Dict[str, Any]]:
|
||||||
|
return self._get_latest_memory_by_type(end_user_id, PerceptualType.VISION)
|
||||||
|
|
||||||
|
def get_latest_audio_memory(self, end_user_id: uuid.UUID) -> Optional[Dict[str, Any]]:
|
||||||
|
return self._get_latest_memory_by_type(end_user_id, PerceptualType.AUDIO)
|
||||||
|
|
||||||
|
def get_latest_text_memory(self, end_user_id: uuid.UUID) -> Optional[Dict[str, Any]]:
|
||||||
|
return self._get_latest_memory_by_type(end_user_id, PerceptualType.TEXT)
|
||||||
|
|
||||||
|
def get_time_line(self, end_user_id: uuid.UUID, query: PerceptualQuerySchema) -> PerceptualTimelineResponse:
|
||||||
|
"""Retrieve a timeline of perceptual memories for a user."""
|
||||||
|
business_logger.info(f"Fetching perceptual memory timeline: "
|
||||||
|
f"end_user_id={end_user_id}, filter={query.filter}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
if query.page < 1:
|
||||||
|
raise BusinessException("Page number must be greater than 0", BizCode.INVALID_PARAMETER)
|
||||||
|
if query.page_size < 1 or query.page_size > 100:
|
||||||
|
raise BusinessException("Page size must be between 1 and 100", BizCode.INVALID_PARAMETER)
|
||||||
|
|
||||||
|
total_count, memories = self.repository.get_timeline(end_user_id, query)
|
||||||
|
|
||||||
|
memory_items = []
|
||||||
|
for memory in memories:
|
||||||
|
memory_item = PerceptualMemoryItem(
|
||||||
|
id=memory.id,
|
||||||
|
perceptual_type=PerceptualType(memory.perceptual_type),
|
||||||
|
file_path=memory.file_path,
|
||||||
|
file_name=memory.file_name,
|
||||||
|
summary=memory.summary,
|
||||||
|
created_time=memory.created_time,
|
||||||
|
storage_type=FileStorageType(memory.storage_service),
|
||||||
|
)
|
||||||
|
memory_items.append(memory_item)
|
||||||
|
|
||||||
|
timeline_response = PerceptualTimelineResponse(
|
||||||
|
total=total_count,
|
||||||
|
page=query.page,
|
||||||
|
page_size=query.page_size,
|
||||||
|
total_pages=(total_count + query.page_size - 1) // query.page_size,
|
||||||
|
memories=memory_items
|
||||||
|
)
|
||||||
|
|
||||||
|
business_logger.info(f"Perceptual memory timeline retrieved successfully: "
|
||||||
|
f"total={total_count}, returned={len(memories)}")
|
||||||
|
return timeline_response
|
||||||
|
|
||||||
|
except BusinessException:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
business_logger.error(f"Failed to fetch perceptual memory timeline: {str(e)}")
|
||||||
|
raise BusinessException(f"Failed to fetch perceptual memory timeline: {str(e)}", BizCode.DB_ERROR)
|
||||||
56
api/app/services/memory_short_service.py
Normal file
56
api/app/services/memory_short_service.py
Normal file
@@ -0,0 +1,56 @@
|
|||||||
|
|
||||||
|
from app.core.logging_config import get_api_logger
|
||||||
|
from app.db import get_db
|
||||||
|
from app.repositories.memory_short_repository import LongTermMemoryRepository
|
||||||
|
from app.repositories.memory_short_repository import ShortTermMemoryRepository
|
||||||
|
|
||||||
|
|
||||||
|
api_logger = get_api_logger()
|
||||||
|
db=next(get_db())
|
||||||
|
class ShortService:
|
||||||
|
def __init__(self, end_user_id):
|
||||||
|
self.short_repo = ShortTermMemoryRepository(db)
|
||||||
|
self.end_user_id = end_user_id
|
||||||
|
|
||||||
|
def get_short_databasets(self):
|
||||||
|
short_memories = self.short_repo.get_latest_by_user_id(self.end_user_id, 3)
|
||||||
|
short_result = []
|
||||||
|
for memory in short_memories:
|
||||||
|
deep_expanded = {} # Create a new dictionary for each memory
|
||||||
|
messages = memory.messages
|
||||||
|
aimessages = memory.aimessages
|
||||||
|
retrieved_content = memory.retrieved_content or []
|
||||||
|
|
||||||
|
api_logger.debug(f"Retrieved content: {retrieved_content}")
|
||||||
|
|
||||||
|
retrieval_source = []
|
||||||
|
for item in retrieved_content:
|
||||||
|
if isinstance(item, dict):
|
||||||
|
for key, values in item.items():
|
||||||
|
retrieval_source.append({"query": key, "retrieval": values})
|
||||||
|
|
||||||
|
deep_expanded['retrieval'] = retrieval_source
|
||||||
|
deep_expanded['message'] = messages # 修正拼写错误
|
||||||
|
deep_expanded['answer'] = aimessages
|
||||||
|
short_result.append(deep_expanded)
|
||||||
|
return short_result
|
||||||
|
def get_short_count(self):
|
||||||
|
short_count = self.short_repo.count_by_user_id(self.end_user_id)
|
||||||
|
return short_count
|
||||||
|
|
||||||
|
class LongService:
|
||||||
|
def __init__(self, end_user_id):
|
||||||
|
self.long_repo = LongTermMemoryRepository(db)
|
||||||
|
self.end_user_id = end_user_id
|
||||||
|
def get_long_databasets(self):
|
||||||
|
# 获取长期记忆数据
|
||||||
|
long_memories = self.long_repo.get_by_user_id(self.end_user_id, 1)
|
||||||
|
|
||||||
|
long_result = []
|
||||||
|
for long_memory in long_memories:
|
||||||
|
if long_memory.retrieved_content:
|
||||||
|
for memory_item in long_memory.retrieved_content:
|
||||||
|
if isinstance(memory_item, dict):
|
||||||
|
for key, values in memory_item.items():
|
||||||
|
long_result.append({"query": key, "retrieval": values})
|
||||||
|
return long_result
|
||||||
@@ -166,6 +166,8 @@ class PromptOptimizerService:
|
|||||||
model_config = self.get_model_config(tenant_id, model_id)
|
model_config = self.get_model_config(tenant_id, model_id)
|
||||||
session_history = self.get_session_message_history(session_id=session_id, user_id=user_id)
|
session_history = self.get_session_message_history(session_id=session_id, user_id=user_id)
|
||||||
|
|
||||||
|
logger.info(f"Prompt optimization started, user_id={user_id}, session_id={session_id}")
|
||||||
|
|
||||||
# Create LLM instance
|
# Create LLM instance
|
||||||
api_config: ModelApiKey = model_config.api_keys[0]
|
api_config: ModelApiKey = model_config.api_keys[0]
|
||||||
llm = RedBearLLM(RedBearModelConfig(
|
llm = RedBearLLM(RedBearModelConfig(
|
||||||
@@ -203,7 +205,6 @@ class PromptOptimizerService:
|
|||||||
|
|
||||||
messages.extend(session_history[:-1]) # last message is current message
|
messages.extend(session_history[:-1]) # last message is current message
|
||||||
messages.extend([(RoleType.USER.value, rendered_user_message)])
|
messages.extend([(RoleType.USER.value, rendered_user_message)])
|
||||||
logger.info(f"Prompt optimization message: {messages}")
|
|
||||||
buffer = ""
|
buffer = ""
|
||||||
prompt_started = False
|
prompt_started = False
|
||||||
prompt_finished = False
|
prompt_finished = False
|
||||||
@@ -250,6 +251,7 @@ class PromptOptimizerService:
|
|||||||
content=desc
|
content=desc
|
||||||
)
|
)
|
||||||
variables = self.parser_prompt_variables(optim_result.get("prompt"))
|
variables = self.parser_prompt_variables(optim_result.get("prompt"))
|
||||||
|
logger.info(f"Prompt optimization completed, user_id={user_id}, session_id={session_id}")
|
||||||
yield {"desc": optim_result.get("desc"), "variables": variables}
|
yield {"desc": optim_result.get("desc"), "variables": variables}
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
|||||||
@@ -1496,8 +1496,8 @@ def _extract_node_properties(label: str, properties: Dict[str, Any]) -> Dict[str
|
|||||||
field_whitelist = {
|
field_whitelist = {
|
||||||
"Dialogue": ["content", "created_at"],
|
"Dialogue": ["content", "created_at"],
|
||||||
"Chunk": ["content", "created_at"],
|
"Chunk": ["content", "created_at"],
|
||||||
"Statement": ["temporal_info", "stmt_type", "statement", "valid_at", "created_at", "caption"],
|
"Statement": ["temporal_info", "stmt_type", "statement", "valid_at", "created_at", "caption","emotion_keywords","emotion_type","emotion_subject"],
|
||||||
"ExtractedEntity": ["description", "name", "entity_type", "created_at", "caption"],
|
"ExtractedEntity": ["description", "name", "entity_type", "created_at", "caption","aliases","connect_strength"],
|
||||||
"MemorySummary": ["summary", "content", "created_at", "caption"] # 添加 content 字段
|
"MemorySummary": ["summary", "content", "created_at", "caption"] # 添加 content 字段
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user