From c52b3600683bda7950a76238e0aed61f28e28346 Mon Sep 17 00:00:00 2001 From: Eternity <61316157+myhMARS@users.noreply.github.com> Date: Wed, 7 Jan 2026 16:00:22 +0800 Subject: [PATCH 1/3] Feature/memory perceptual (#48) * perf(workflow): pass JSON data to HTTP node as a string * perf(prompt_opt): simplify log output * feat(memory): add perceptual memory page API and related database schema * perf(log): clean up API exception log output * perf(memory): simplify perceptual memory timeline response by removing metadata --- .../memory_perceptual_controller.py | 255 ++++++++++++++++++ .../workflow/nodes/http_request/config.py | 6 +- .../core/workflow/nodes/http_request/node.py | 2 +- api/app/models/memory_perceptual_model.py | 40 +++ .../memory_perceptual_repository.py | 156 +++++++++++ api/app/schemas/memory_perceptual_schema.py | 133 +++++++++ api/app/services/memory_perceptual_service.py | 166 ++++++++++++ api/app/services/prompt_optimizer_service.py | 4 +- 8 files changed, 758 insertions(+), 4 deletions(-) create mode 100644 api/app/controllers/memory_perceptual_controller.py create mode 100644 api/app/models/memory_perceptual_model.py create mode 100644 api/app/repositories/memory_perceptual_repository.py create mode 100644 api/app/schemas/memory_perceptual_schema.py create mode 100644 api/app/services/memory_perceptual_service.py diff --git a/api/app/controllers/memory_perceptual_controller.py b/api/app/controllers/memory_perceptual_controller.py new file mode 100644 index 00000000..5154c763 --- /dev/null +++ b/api/app/controllers/memory_perceptual_controller.py @@ -0,0 +1,255 @@ +import uuid +from typing import Optional + +from fastapi import APIRouter, Depends, Query +from sqlalchemy.orm import Session + +from app.core.error_codes import BizCode +from app.core.logging_config import get_api_logger +from app.core.response_utils import success, fail +from app.db import get_db +from app.dependencies import get_current_user +from app.models import User +from app.models.memory_perceptual_model import PerceptualType +from app.schemas.memory_perceptual_schema import ( + PerceptualQuerySchema, + PerceptualFilter +) +from app.schemas.response_schema import ApiResponse +from app.services.memory_perceptual_service import MemoryPerceptualService + +api_logger = get_api_logger() + +router = APIRouter( + prefix="/memory/perceptual", + tags=["Perceptual Memory System"], + dependencies=[Depends(get_current_user)] +) + + +@router.get("/{group_id}/count", response_model=ApiResponse) +def get_memory_count( + group_id: uuid.UUID, + current_user: User = Depends(get_current_user), + db: Session = Depends(get_db) +): + """Retrieve perceptual memory statistics for a user group. + + Args: + group_id: ID of the user group (usually end_user_id in this context) + current_user: Current authenticated user + db: Database session + + Returns: + ApiResponse: Response containing memory count statistics + """ + api_logger.info(f"Fetching perceptual memory statistics: user={current_user.username}, group_id={group_id}") + + try: + service = MemoryPerceptualService(db) + count_stats = service.get_memory_count(group_id) + + api_logger.info(f"Memory statistics fetched successfully: total={count_stats.get('total', 0)}") + + return success( + data=count_stats, + msg="Memory statistics retrieved successfully" + ) + + except Exception as e: + api_logger.error(f"Failed to fetch memory statistics: group_id={group_id}, error={str(e)}") + return fail( + code=BizCode.INTERNAL_ERROR, + msg="Failed to fetch memory statistics", + ) + + +@router.get("/{group_id}/last_visual", response_model=ApiResponse) +def get_last_visual_memory( + group_id: uuid.UUID, + current_user: User = Depends(get_current_user), + db: Session = Depends(get_db) +): + """Retrieve the most recent VISION-type memory for a user. + + Args: + group_id: ID of the user group + current_user: Current authenticated user + db: Database session + + Returns: + ApiResponse: Metadata of the latest visual memory + """ + api_logger.info(f"Fetching latest visual memory: user={current_user.username}, group_id={group_id}") + + try: + service = MemoryPerceptualService(db) + visual_memory = service.get_latest_visual_memory(group_id) + + if visual_memory is None: + api_logger.info(f"No visual memory found: group_id={group_id}") + return success( + data=None, + msg="No visual memory available" + ) + + api_logger.info(f"Latest visual memory retrieved successfully: file={visual_memory.get('file_name')}") + + return success( + data=visual_memory, + msg="Latest visual memory retrieved successfully" + ) + + except Exception as e: + api_logger.error(f"Failed to fetch latest visual memory: group_id={group_id}, error={str(e)}") + return fail( + code=BizCode.INTERNAL_ERROR, + msg="Failed to fetch latest visual memory", + ) + + +@router.get("/{group_id}/last_listen", response_model=ApiResponse) +def get_last_memory_listen( + group_id: uuid.UUID, + current_user: User = Depends(get_current_user), + db: Session = Depends(get_db) +): + """Retrieve the most recent AUDIO-type memory for a user. + + Args: + group_id: ID of the user group + current_user: Current authenticated user + db: Database session + + Returns: + ApiResponse: Metadata of the latest audio memory + """ + api_logger.info(f"Fetching latest audio memory: user={current_user.username}, group_id={group_id}") + + try: + service = MemoryPerceptualService(db) + audio_memory = service.get_latest_audio_memory(group_id) + + if audio_memory is None: + api_logger.info(f"No audio memory found: group_id={group_id}") + return success( + data=None, + msg="No audio memory available" + ) + + api_logger.info(f"Latest audio memory retrieved successfully: file={audio_memory.get('file_name')}") + + return success( + data=audio_memory, + msg="Latest audio memory retrieved successfully" + ) + + except Exception as e: + api_logger.error(f"Failed to fetch latest audio memory: group_id={group_id}, error={str(e)}") + return fail( + code=BizCode.INTERNAL_ERROR, + msg="Failed to fetch latest audio memory", + ) + + +@router.get("/{group_id}/last_text", response_model=ApiResponse) +def get_last_text_memory( + group_id: uuid.UUID, + current_user: User = Depends(get_current_user), + db: Session = Depends(get_db) +): + """Retrieve the most recent TEXT-type memory for a user. + + Args: + group_id: ID of the user group + current_user: Current authenticated user + db: Database session + + Returns: + ApiResponse: Metadata of the latest text memory + """ + api_logger.info(f"Fetching latest text memory: user={current_user.username}, group_id={group_id}") + + try: + # 调用服务层获取最近的文本记忆 + service = MemoryPerceptualService(db) + text_memory = service.get_latest_text_memory(group_id) + + if text_memory is None: + api_logger.info(f"No text memory found: group_id={group_id}") + return success( + data=None, + msg="No text memory available" + ) + + api_logger.info(f"Latest text memory retrieved successfully: file={text_memory.get('file_name')}") + + return success( + data=text_memory, + msg="Latest text memory retrieved successfully" + ) + + except Exception as e: + api_logger.error(f"Failed to fetch latest text memory: group_id={group_id}, error={str(e)}") + return fail( + code=BizCode.INTERNAL_ERROR, + msg="Failed to fetch latest text memory", + ) + + +@router.get("/{group_id}/timeline", response_model=ApiResponse) +def get_memory_time_line( + group_id: uuid.UUID, + perceptual_type: Optional[PerceptualType] = Query(None, description="感知类型过滤"), + page: int = Query(1, ge=1, description="页码"), + page_size: int = Query(10, ge=1, le=100, description="每页大小"), + current_user: User = Depends(get_current_user), + db: Session = Depends(get_db) +): + """Retrieve a timeline of perceptual memories for a user group. + + Args: + group_id: ID of the user group + perceptual_type: Optional filter for perceptual type + page: Page number for pagination + page_size: Number of items per page + current_user: Current authenticated user + db: Database session + + Returns: + ApiResponse: Timeline data of perceptual memories + """ + api_logger.info( + f"Fetching perceptual memory timeline: user={current_user.username}, " + f"group_id={group_id}, type={perceptual_type}, page={page}" + ) + + try: + query = PerceptualQuerySchema( + filter=PerceptualFilter(type=perceptual_type), + page=page, + page_size=page_size + ) + + service = MemoryPerceptualService(db) + timeline_data = service.get_time_line(group_id, query) + + api_logger.info( + f"Perceptual memory timeline retrieved successfully: total={timeline_data.total}, " + f"returned={len(timeline_data.memories)}" + ) + + return success( + data=timeline_data.model_dump(), + msg="Perceptual memory timeline retrieved successfully" + ) + + except Exception as e: + api_logger.error( + f"Failed to fetch perceptual memory timeline: group_id={group_id}, " + f"error={str(e)}" + ) + return fail( + code=BizCode.INTERNAL_ERROR, + msg="Failed to fetch perceptual memory timeline", + ) diff --git a/api/app/core/workflow/nodes/http_request/config.py b/api/app/core/workflow/nodes/http_request/config.py index 6bb7baaf..810a716f 100644 --- a/api/app/core/workflow/nodes/http_request/config.py +++ b/api/app/core/workflow/nodes/http_request/config.py @@ -73,8 +73,10 @@ class HttpContentTypeConfig(BaseModel): content_type = info.data.get("content_type") if content_type == HttpContentType.FROM_DATA and not isinstance(v, HttpFormData): raise ValueError("When content_type is 'form-data', data must be of type HttpFormData") - elif content_type in [HttpContentType.JSON, HttpContentType.WWW_FORM] and not isinstance(v, dict): - raise ValueError("When content_type is JSON or x-www-form-urlencoded, data must be a object") + elif content_type in [HttpContentType.JSON] and not isinstance(v, str): + raise ValueError("When content_type is JSON, data must be of type str") + elif content_type in [HttpContentType.WWW_FORM] and not isinstance(v, dict): + raise ValueError("When content_type is x-www-form-urlencoded, data must be a object") elif content_type in [HttpContentType.RAW, HttpContentType.BINARY] and not isinstance(v, str): raise ValueError("When content_type is raw/binary, data must be a string (File descriptor)") return v diff --git a/api/app/core/workflow/nodes/http_request/node.py b/api/app/core/workflow/nodes/http_request/node.py index 4374d847..2e5de796 100644 --- a/api/app/core/workflow/nodes/http_request/node.py +++ b/api/app/core/workflow/nodes/http_request/node.py @@ -120,7 +120,7 @@ class HttpRequestNode(BaseNode): return {} case HttpContentType.JSON: content["json"] = json.loads(self._render_template( - json.dumps(self.typed_config.body.data), state + self.typed_config.body.data, state )) case HttpContentType.FROM_DATA: data = {} diff --git a/api/app/models/memory_perceptual_model.py b/api/app/models/memory_perceptual_model.py new file mode 100644 index 00000000..59eb0222 --- /dev/null +++ b/api/app/models/memory_perceptual_model.py @@ -0,0 +1,40 @@ +import datetime +import uuid +from enum import IntEnum + +from sqlalchemy import Column, ForeignKey, Integer, DateTime, String +from sqlalchemy.dialects.postgresql import UUID +from sqlalchemy.dialects.postgresql import JSONB + +from app.db import Base + + +class PerceptualType(IntEnum): + VISION = 1 + AUDIO = 2 + TEXT = 3 + CONVERSATION = 4 + + +class FileStorageType(IntEnum): + LOCAL = 1 + REMOTE = 2 + + +class MemoryPerceptualModel(Base): + __tablename__ = "memory_perceptual" + + id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) + end_user_id = Column(UUID(as_uuid=True), ForeignKey("end_users.id"), index=True) + + perceptual_type = Column(Integer, index=True, nullable=False, comment="感知类型") + + storage_service = Column(Integer, default=0, comment="存储服务类型") + file_path = Column(String, nullable=False, comment="文件路径") + file_name = Column(String, nullable=False, comment="文件名称") + file_ext = Column(String, nullable=False, comment="文件后缀名") + + summary = Column(String, comment="摘要") + meta_data = Column(JSONB, comment="元信息") + + created_time = Column(DateTime, default=datetime.datetime.now, comment="创建时间") diff --git a/api/app/repositories/memory_perceptual_repository.py b/api/app/repositories/memory_perceptual_repository.py new file mode 100644 index 00000000..8415c2d0 --- /dev/null +++ b/api/app/repositories/memory_perceptual_repository.py @@ -0,0 +1,156 @@ +import uuid +from datetime import datetime +from typing import List, Tuple, Optional + +from sqlalchemy import and_, desc +from sqlalchemy.orm import Session + +from app.core.logging_config import get_db_logger +from app.models.memory_perceptual_model import MemoryPerceptualModel, PerceptualType, FileStorageType +from app.schemas.memory_perceptual_schema import PerceptualQuerySchema + +db_logger = get_db_logger() + + +class MemoryPerceptualRepository: + """Data Access Layer for perceptual memory""" + + def __init__(self, db: Session): + self.db = db + + # ==================== Create and update ==================== + def create_perceptual_memory( + self, + end_user_id: uuid.UUID, + perceptual_type: PerceptualType, + file_path: str, + file_name: str, + file_ext: str, + summary: Optional[str] = None, + meta_data: Optional[dict] = None, + storage_service: FileStorageType = FileStorageType.LOCAL + + ) -> MemoryPerceptualModel: + + """Create perceptual memory""" + + db_logger.debug(f"Creating perceptual memory: end_user_id={end_user_id}, " + f"type={perceptual_type}, file={file_name}") + + try: + perceptual_memory = MemoryPerceptualModel( + end_user_id=end_user_id, + perceptual_type=perceptual_type, + storage_service=storage_service, + file_path=file_path, + file_name=file_name, + file_ext=file_ext, + summary=summary, + meta_data=meta_data, + created_time=datetime.now() + ) + + self.db.add(perceptual_memory) + self.db.flush() + + db_logger.info(f"Perceptual memory created successfully: id={perceptual_memory.id}, file={file_name}") + return perceptual_memory + + except Exception as e: + db_logger.error(f"Failed to create perceptual memory: end_user_id={end_user_id} - {str(e)}") + raise + + # ==================== Query ==================== + def get_count_by_user_id( + self, + end_user_id: uuid.UUID, + ): + db_logger.debug(f"Querying perceptual memory Count: end_user_id={end_user_id}") + + try: + count = self.db.query(MemoryPerceptualModel).filter( + MemoryPerceptualModel.end_user_id == end_user_id + ).count() + return count + except Exception as e: + db_logger.error(f"Failed to query perceptual memory count: end_user_id={end_user_id} - {str(e)}") + raise + + def get_count_by_type( + self, + end_user_id: uuid.UUID, + perceptual_type: PerceptualType, + ): + db_logger.debug(f"Querying perceptual memory Count: end_user_id={end_user_id}, type={perceptual_type}") + + try: + count = self.db.query(MemoryPerceptualModel).filter( + MemoryPerceptualModel.end_user_id == end_user_id, + MemoryPerceptualModel.perceptual_type == perceptual_type + ).count() + return count + except Exception as e: + db_logger.error(f"Failed to query perceptual memory count: end_user_id={end_user_id} - {str(e)}") + raise + + def get_timeline( + self, + end_user_id: uuid.UUID, + query: PerceptualQuerySchema + ) -> Tuple[int, List[MemoryPerceptualModel]]: + """Get the timeline of a user's perceptual memories""" + db_logger.debug(f"Querying perceptual memory timeline: end_user_id={end_user_id}, filter={query.filter}") + + try: + base_query = self.db.query(MemoryPerceptualModel).filter( + MemoryPerceptualModel.end_user_id == end_user_id + ) + + if query.filter.type is not None: + base_query = base_query.filter( + MemoryPerceptualModel.perceptual_type == query.filter.type + ) + + total_count = base_query.count() + + memories = base_query.order_by( + desc(MemoryPerceptualModel.created_time) + ).offset( + (query.page - 1) * query.page_size + ).limit(query.page_size).all() + + db_logger.info( + f"Perceptual memory timeline query succeeded: end_user_id={end_user_id}, total={total_count}, returned={len(memories)}") + return total_count, memories + + except Exception as e: + db_logger.error(f"Failed to query perceptual memory timeline: end_user_id={end_user_id} - {str(e)}") + raise + + def get_by_type( + self, + end_user_id: uuid.UUID, + perceptual_type: PerceptualType, + limit: int = 10, + offset: int = 0 + ) -> List[MemoryPerceptualModel]: + """Get memories by perceptual type""" + db_logger.debug(f"Querying perceptual memories by type: end_user_id={end_user_id}, type={perceptual_type}") + + try: + memories = self.db.query(MemoryPerceptualModel).filter( + and_( + MemoryPerceptualModel.end_user_id == end_user_id, + MemoryPerceptualModel.perceptual_type == perceptual_type + ) + ).order_by( + desc(MemoryPerceptualModel.created_time) + ).offset(offset).limit(limit).all() + + db_logger.debug(f"Query by type succeeded: count={len(memories)}") + return memories + + except Exception as e: + db_logger.error(f"Failed to query perceptual memories by type: end_user_id={end_user_id}, " + f"type={perceptual_type} - {str(e)}") + raise diff --git a/api/app/schemas/memory_perceptual_schema.py b/api/app/schemas/memory_perceptual_schema.py new file mode 100644 index 00000000..41b74a36 --- /dev/null +++ b/api/app/schemas/memory_perceptual_schema.py @@ -0,0 +1,133 @@ +import uuid +from datetime import datetime +from typing import Optional + +from pydantic import BaseModel, Field + +from app.models.memory_perceptual_model import PerceptualType, FileStorageType + + +class PerceptualFilter(BaseModel): + type: PerceptualType | None = Field( + default=None, + description="Perceptual type used for filtering the query; optional" + ) + + +class PerceptualQuerySchema(BaseModel): + filter: PerceptualFilter = Field( + default_factory=lambda: PerceptualFilter(), + description="Query filter containing perceptual type criteria" + ) + + page: int = Field( + default=1, + ge=1, + description="Page number for pagination, starting from 1" + ) + + page_size: int = Field( + default=10, + ge=1, + le=100, + description="Number of records per page, range 1-100" + ) + + +class PerceptualMemoryItem(BaseModel): + """感知记忆项""" + id: uuid.UUID = Field(..., description="Unique memory ID") + perceptual_type: PerceptualType = Field(..., description="Type of perception, e.g., text, audio, or video") + file_path: str = Field(..., description="File path in the storage service") + file_name: str = Field(..., description="File name") + summary: Optional[str] = Field(None, description="摘要") + storage_type: FileStorageType = Field(..., description="Storage type for file") + created_time: Optional[datetime] = Field(None, description="创建时间") + + class Config: + from_attributes = True + + +class PerceptualTimelineResponse(BaseModel): + """感知记忆时间线响应""" + total: int = Field(..., description="总数量") + page: int = Field(..., description="当前页码") + page_size: int = Field(..., description="每页大小") + total_pages: int = Field(..., description="总页数") + memories: list[PerceptualMemoryItem] = Field(..., description="记忆列表") + + class Config: + from_attributes = True + + +# -------------------------- +# TODO: FileMetaData +# -------------------------- +class Identity(BaseModel): + title: str + filename: str + source: str # upload | crawl | system + author: Optional[str] = None + + +class Semantic(BaseModel): + topic: str + domain: str + difficulty: str # beginner | intermediate | advanced + intent: str # informative | instructional | promotional + sentiment: str # positive | neutral | negative + + +class Content(BaseModel): + summary: str + keywords: list[str] + topic: str + domain: str + + +class Usage(BaseModel): + target_audience: list[str] + use_cases: list[str] + + +class Stats(BaseModel): + duration_sec: Optional[int] = None + char_count: int + word_count: int + + +class Processing(BaseModel): + transcribed: bool + ocr_applied: bool + chunked: bool + vectorized: bool + embedding_model: Optional[str] = None + + +class VideoModal(BaseModel): + scene: list[str] + + +class AudioModal(BaseModel): + speaker_count: int + + +class TextModal(BaseModel): + section_count: int + + +class Asset(BaseModel): + type: str + modality: str # text | audio | video + format: str # docx | mp3 | mp4 + language: str + encoding: str + + identity: Identity + semantic: Semantic + content: Content + usage: Usage + stats: Stats + processing: Processing + created_at: str + modalities: AudioModal | TextModal | VideoModal diff --git a/api/app/services/memory_perceptual_service.py b/api/app/services/memory_perceptual_service.py new file mode 100644 index 00000000..a74dc5a7 --- /dev/null +++ b/api/app/services/memory_perceptual_service.py @@ -0,0 +1,166 @@ +import uuid +from typing import Dict, Any, Optional + +from sqlalchemy.orm import Session + +from app.core.error_codes import BizCode +from app.core.exceptions import BusinessException +from app.core.logging_config import get_business_logger +from app.models.memory_perceptual_model import PerceptualType, FileStorageType +from app.repositories.memory_perceptual_repository import MemoryPerceptualRepository +from app.schemas.memory_perceptual_schema import ( + PerceptualQuerySchema, + PerceptualTimelineResponse, + PerceptualMemoryItem, + AudioModal, Content, VideoModal, TextModal +) + +business_logger = get_business_logger() + + +class MemoryPerceptualService: + def __init__(self, db: Session): + self.db = db + self.repository = MemoryPerceptualRepository(db) + + def get_memory_count(self, end_user_id: uuid.UUID) -> Dict[str, Any]: + """Retrieve perceptual memory statistics for a user.""" + business_logger.info(f"Fetching perceptual memory statistics: end_user_id={end_user_id}") + try: + total_count = self.repository.get_count_by_user_id(end_user_id=end_user_id) + + vision_count = self.repository.get_count_by_type(end_user_id, PerceptualType.VISION) + audio_count = self.repository.get_count_by_type(end_user_id, PerceptualType.AUDIO) + text_count = self.repository.get_count_by_type(end_user_id, PerceptualType.TEXT) + conversation_count = self.repository.get_count_by_type(end_user_id, PerceptualType.CONVERSATION) + + stats = { + "total": total_count, + "by_type": { + "vision": vision_count, + "audio": audio_count, + "text": text_count, + "conversation": conversation_count + } + } + + business_logger.info(f"Memory statistics fetched successfully: total={total_count}") + return stats + + except Exception as e: + business_logger.error(f"Failed to fetch memory statistics: {str(e)}") + raise BusinessException(f"Failed to fetch memory statistics: {str(e)}", BizCode.DB_ERROR) + + def _get_latest_memory_by_type( + self, + end_user_id: uuid.UUID, + perceptual_type: PerceptualType + ) -> Optional[dict[str, Any]]: + """Internal helper to retrieve the latest memory by type.""" + business_logger.info(f"Fetching latest {perceptual_type.name.lower()} memory: end_user_id={end_user_id}") + try: + memories = self.repository.get_by_type( + end_user_id=end_user_id, + perceptual_type=perceptual_type, + limit=1, + offset=0 + ) + if not memories: + business_logger.info(f"No {perceptual_type.name.lower()} memory found: end_user_id={end_user_id}") + return None + + memory = memories[0] + meta_data = memory.meta_data or {} + modalities = meta_data.get("modalities") + content = meta_data.get("content") + + if not modalities: + raise BusinessException(f"Modalities not defined, perceptual memory_id={memory.id}", BizCode.DB_ERROR) + if not content: + raise BusinessException(f"Content not defined, perceptual memory_id={memory.id}", BizCode.DB_ERROR) + content = Content(**content) + match perceptual_type: + case PerceptualType.VISION: + modal = VideoModal(**modalities) + case PerceptualType.AUDIO: + modal = AudioModal(**modalities) + case PerceptualType.TEXT: + modal = TextModal(**modalities) + case _: + raise BusinessException("Unsupported perceptual type", BizCode.DB_ERROR) + detail = modal.model_dump() + + result = { + "id": str(memory.id), + "file_name": memory.file_name, + "file_path": memory.file_path, + "storage_type": memory.storage_service, + "summary": memory.summary, + "keywords": content.keywords, + "topic": content.topic, + "domain": content.domain, + "created_time": memory.created_time.isoformat() if memory.created_time else None, + **detail + } + + business_logger.info( + f"Latest {perceptual_type.name.lower()} memory retrieved successfully: file={memory.file_name}") + return result + + except Exception as e: + business_logger.error(f"Failed to fetch latest {perceptual_type.name.lower()} memory: {str(e)}") + raise BusinessException(f"Failed to fetch latest {perceptual_type.name.lower()} memory: {str(e)}", + BizCode.DB_ERROR) + + def get_latest_visual_memory(self, end_user_id: uuid.UUID) -> Optional[Dict[str, Any]]: + return self._get_latest_memory_by_type(end_user_id, PerceptualType.VISION) + + def get_latest_audio_memory(self, end_user_id: uuid.UUID) -> Optional[Dict[str, Any]]: + return self._get_latest_memory_by_type(end_user_id, PerceptualType.AUDIO) + + def get_latest_text_memory(self, end_user_id: uuid.UUID) -> Optional[Dict[str, Any]]: + return self._get_latest_memory_by_type(end_user_id, PerceptualType.TEXT) + + def get_time_line(self, end_user_id: uuid.UUID, query: PerceptualQuerySchema) -> PerceptualTimelineResponse: + """Retrieve a timeline of perceptual memories for a user.""" + business_logger.info(f"Fetching perceptual memory timeline: " + f"end_user_id={end_user_id}, filter={query.filter}") + + try: + if query.page < 1: + raise BusinessException("Page number must be greater than 0", BizCode.INVALID_PARAMETER) + if query.page_size < 1 or query.page_size > 100: + raise BusinessException("Page size must be between 1 and 100", BizCode.INVALID_PARAMETER) + + total_count, memories = self.repository.get_timeline(end_user_id, query) + + memory_items = [] + for memory in memories: + memory_item = PerceptualMemoryItem( + id=memory.id, + perceptual_type=PerceptualType(memory.perceptual_type), + file_path=memory.file_path, + file_name=memory.file_name, + summary=memory.summary, + created_time=memory.created_time, + storage_type=FileStorageType(memory.storage_service), + ) + memory_items.append(memory_item) + + timeline_response = PerceptualTimelineResponse( + total=total_count, + page=query.page, + page_size=query.page_size, + total_pages=(total_count + query.page_size - 1) // query.page_size, + memories=memory_items + ) + + business_logger.info(f"Perceptual memory timeline retrieved successfully: " + f"total={total_count}, returned={len(memories)}") + return timeline_response + + except BusinessException: + raise + except Exception as e: + business_logger.error(f"Failed to fetch perceptual memory timeline: {str(e)}") + raise BusinessException(f"Failed to fetch perceptual memory timeline: {str(e)}", BizCode.DB_ERROR) diff --git a/api/app/services/prompt_optimizer_service.py b/api/app/services/prompt_optimizer_service.py index b3ac1b79..135ddc5d 100644 --- a/api/app/services/prompt_optimizer_service.py +++ b/api/app/services/prompt_optimizer_service.py @@ -166,6 +166,8 @@ class PromptOptimizerService: model_config = self.get_model_config(tenant_id, model_id) session_history = self.get_session_message_history(session_id=session_id, user_id=user_id) + logger.info(f"Prompt optimization started, user_id={user_id}, session_id={session_id}") + # Create LLM instance api_config: ModelApiKey = model_config.api_keys[0] llm = RedBearLLM(RedBearModelConfig( @@ -203,7 +205,6 @@ class PromptOptimizerService: messages.extend(session_history[:-1]) # last message is current message messages.extend([(RoleType.USER.value, rendered_user_message)]) - logger.info(f"Prompt optimization message: {messages}") buffer = "" prompt_started = False prompt_finished = False @@ -250,6 +251,7 @@ class PromptOptimizerService: content=desc ) variables = self.parser_prompt_variables(optim_result.get("prompt")) + logger.info(f"Prompt optimization completed, user_id={user_id}, session_id={session_id}") yield {"desc": optim_result.get("desc"), "variables": variables} @staticmethod From 5fe8043ff8204835b08eb25766b7a8a521717a8b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B9=90=E5=8A=9B=E9=BD=90?= <162269739+lanceyq@users.noreply.github.com> Date: Wed, 7 Jan 2026 16:00:53 +0800 Subject: [PATCH 2/3] Fix/actr config (#49) * [fix]Remove the LLM * [fix]Failed to restore access history record --- .../forgetting_engine/access_history_manager.py | 6 +++--- api/app/repositories/neo4j/cypher_queries.py | 7 ++++++- api/app/repositories/neo4j/entity_repository.py | 2 +- api/app/repositories/neo4j/statement_repository.py | 2 +- 4 files changed, 11 insertions(+), 6 deletions(-) diff --git a/api/app/core/memory/storage_services/forgetting_engine/access_history_manager.py b/api/app/core/memory/storage_services/forgetting_engine/access_history_manager.py index acc2a717..729a5542 100644 --- a/api/app/core/memory/storage_services/forgetting_engine/access_history_manager.py +++ b/api/app/core/memory/storage_services/forgetting_engine/access_history_manager.py @@ -246,7 +246,7 @@ class AccessHistoryManager: if not node_data: return ConsistencyCheckResult.CONSISTENT, None - access_history = node_data.get('access_history', []) + access_history = node_data.get('access_history') or [] last_access_time = node_data.get('last_access_time') access_count = node_data.get('access_count', 0) activation_value = node_data.get('activation_value') @@ -409,7 +409,7 @@ class AccessHistoryManager: logger.error(f"节点不存在,无法修复: {node_label}[{node_id}]") return False - access_history = node_data.get('access_history', []) + access_history = node_data.get('access_history') or [] importance_score = node_data.get('importance_score', 0.5) # 准备修复数据 @@ -530,7 +530,7 @@ class AccessHistoryManager: Returns: Dict[str, Any]: 更新数据,包含所有需要更新的字段 """ - access_history = node_data.get('access_history', []) + access_history = node_data.get('access_history') or [] importance_score = node_data.get('importance_score', 0.5) # 追加新的访问时间 diff --git a/api/app/repositories/neo4j/cypher_queries.py b/api/app/repositories/neo4j/cypher_queries.py index 259b1325..ed8f6100 100644 --- a/api/app/repositories/neo4j/cypher_queries.py +++ b/api/app/repositories/neo4j/cypher_queries.py @@ -722,7 +722,12 @@ SET m += { chunk_ids: summary.chunk_ids, content: summary.content, summary_embedding: summary.summary_embedding, - config_id: summary.config_id + config_id: summary.config_id, + importance_score: CASE WHEN summary.importance_score IS NOT NULL THEN summary.importance_score ELSE coalesce(m.importance_score, 0.5) END, + activation_value: CASE WHEN summary.activation_value IS NOT NULL THEN summary.activation_value ELSE m.activation_value END, + access_history: CASE WHEN summary.access_history IS NOT NULL THEN summary.access_history ELSE coalesce(m.access_history, []) END, + last_access_time: CASE WHEN summary.last_access_time IS NOT NULL THEN summary.last_access_time ELSE m.last_access_time END, + access_count: CASE WHEN summary.access_count IS NOT NULL THEN summary.access_count ELSE coalesce(m.access_count, 0) END } RETURN m.id AS uuid """ diff --git a/api/app/repositories/neo4j/entity_repository.py b/api/app/repositories/neo4j/entity_repository.py index cb18feca..f4ca35c8 100644 --- a/api/app/repositories/neo4j/entity_repository.py +++ b/api/app/repositories/neo4j/entity_repository.py @@ -58,7 +58,7 @@ class EntityRepository(BaseNeo4jRepository[ExtractedEntityNode]): # 处理 ACT-R 属性 - 确保字段存在且有默认值 n['importance_score'] = n.get('importance_score', 0.5) n['activation_value'] = n.get('activation_value') - n['access_history'] = n.get('access_history', []) + n['access_history'] = n.get('access_history') or [] n['last_access_time'] = n.get('last_access_time') n['access_count'] = n.get('access_count', 0) diff --git a/api/app/repositories/neo4j/statement_repository.py b/api/app/repositories/neo4j/statement_repository.py index 22343e10..cd9f2fac 100644 --- a/api/app/repositories/neo4j/statement_repository.py +++ b/api/app/repositories/neo4j/statement_repository.py @@ -78,7 +78,7 @@ class StatementRepository(BaseNeo4jRepository[StatementNode]): # 处理 ACT-R 属性 - 确保字段存在且有默认值 n['importance_score'] = n.get('importance_score', 0.5) n['activation_value'] = n.get('activation_value') - n['access_history'] = n.get('access_history', []) + n['access_history'] = n.get('access_history') or [] n['last_access_time'] = n.get('last_access_time') n['access_count'] = n.get('access_count', 0) From bcb3d587a1dd98b3c775ffe410fe013dad0ca497 Mon Sep 17 00:00:00 2001 From: lixinyue11 <94037597+lixinyue11@users.noreply.github.com> Date: Wed, 7 Jan 2026 16:36:11 +0800 Subject: [PATCH 3/3] =?UTF-8?q?dev=E6=96=B0=E5=A2=9E=E7=9F=AD=E6=9C=9F?= =?UTF-8?q?=E8=AE=B0=E5=BF=86=E5=8A=9F=E8=83=BD=20(#47)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * dev新增短期记忆功能 * dev新增短期记忆功能 * dev新增短期记忆功能 * dev新增短期记忆功能 * dev新增短期记忆功能 * dev新增短期记忆功能 * dev新增短期记忆功能 --- api/app/controllers/__init__.py | 2 + .../memory_short_term_controller.py | 44 ++ api/app/core/agent/langchain_agent.py | 66 ++- api/app/models/__init__.py | 3 + api/app/models/memory_short_model.py | 60 +++ .../repositories/memory_short_repository.py | 503 ++++++++++++++++++ api/app/services/memory_agent_service.py | 72 ++- api/app/services/memory_short_service.py | 56 ++ api/app/services/user_memory_service.py | 4 +- 9 files changed, 765 insertions(+), 45 deletions(-) create mode 100644 api/app/controllers/memory_short_term_controller.py create mode 100644 api/app/models/memory_short_model.py create mode 100644 api/app/repositories/memory_short_repository.py create mode 100644 api/app/services/memory_short_service.py diff --git a/api/app/controllers/__init__.py b/api/app/controllers/__init__.py index 0b07d0c9..9d4b9248 100644 --- a/api/app/controllers/__init__.py +++ b/api/app/controllers/__init__.py @@ -24,6 +24,7 @@ from . import ( memory_storage_controller, memory_dashboard_controller, memory_reflection_controller, + memory_short_term_controller, api_key_controller, release_share_controller, public_share_controller, @@ -71,6 +72,7 @@ manager_router.include_router(emotion_controller.router) manager_router.include_router(emotion_config_controller.router) manager_router.include_router(prompt_optimizer_controller.router) manager_router.include_router(memory_reflection_controller.router) +manager_router.include_router(memory_short_term_controller.router) manager_router.include_router(tool_controller.router) manager_router.include_router(memory_forget_controller.router) manager_router.include_router(home_page_controller.router) diff --git a/api/app/controllers/memory_short_term_controller.py b/api/app/controllers/memory_short_term_controller.py new file mode 100644 index 00000000..f21a00b6 --- /dev/null +++ b/api/app/controllers/memory_short_term_controller.py @@ -0,0 +1,44 @@ +from fastapi import APIRouter, Depends, HTTPException, status +from app.core.logging_config import get_api_logger +from app.core.response_utils import success +from app.db import get_db +from app.dependencies import get_current_user +from app.models.user_model import User + +from app.services.memory_storage_service import search_entity +from app.services.memory_short_service import ShortService,LongService +from dotenv import load_dotenv +from sqlalchemy.orm import Session +from typing import Optional +load_dotenv() +api_logger = get_api_logger() + +router = APIRouter( + prefix="/memory/short", + tags=["Memory"], +) +@router.get("/short_term") +async def short_term_configs( + end_user_id: str, + current_user: User = Depends(get_current_user), + db: Session = Depends(get_db), +): + # 获取短期记忆数据 + short_term=ShortService(end_user_id) + short_result=short_term.get_short_databasets() + short_count=short_term.get_short_count() + + long_term=LongService(end_user_id) + long_result=long_term.get_long_databasets() + + entity_result = await search_entity(end_user_id) + result = { + 'short_term': short_result, + 'long_term': long_result, + 'entity': entity_result.get('num', 0), + "retrieval_number":short_count, + "long_term_number":len(long_result) + } + + return success(data=result, msg="短期记忆系统数据获取成功") + diff --git a/api/app/core/agent/langchain_agent.py b/api/app/core/agent/langchain_agent.py index ef9a489f..91445b12 100644 --- a/api/app/core/agent/langchain_agent.py +++ b/api/app/core/agent/langchain_agent.py @@ -7,13 +7,20 @@ LangChain Agent 封装 - 支持流式输出 - 使用 RedBearLLM 支持多提供商 """ +import os import time from typing import Any, AsyncGenerator, Dict, List, Optional, Sequence + +from app.db import get_db from app.core.logging_config import get_business_logger from app.core.memory.agent.utils.redis_tool import store from app.core.models import RedBearLLM, RedBearModelConfig from app.models.models_model import ModelType +from app.repositories.memory_short_repository import LongTermMemoryRepository +from app.services.memory_agent_service import ( + get_end_user_connected_config, +) from app.services.memory_konwledges_server import write_rag from app.services.task_service import get_task_memory_write_result from app.tasks import write_message_task @@ -96,7 +103,8 @@ class LangChainAgent: "temperature": temperature, "streaming": streaming, "tool_count": len(self.tools), - "tool_names": [tool.name for tool in self.tools] if self.tools else [] + "tool_names": [tool.name for tool in self.tools] if self.tools else [], + "tool_count": len(self.tools) } ) @@ -137,11 +145,8 @@ class LangChainAgent: messages.append(HumanMessage(content=user_content)) return messages - async def term_memory_save(self,messages,end_user_end,aimessages): - """ - 短长期存储redis,为不影响正常使用6句一段话,存储用户名加一个前缀,当数据存够6条返回给neo4j - """ + '''短长期存储redis,为不影响正常使用6句一段话,存储用户名加一个前缀,当数据存够6条返回给neo4j''' end_user_end=f"Term_{end_user_end}" print(messages) print(aimessages) @@ -155,17 +160,18 @@ class LangChainAgent: store.delete_duplicate_sessions() # logger.info(f'Redis_Agent:{end_user_end};{session_id}') return session_id - async def term_memory_redis_read(self,end_user_end): end_user_end = f"Term_{end_user_end}" history = store.find_user_apply_group(end_user_end, end_user_end, end_user_end) # logger.info(f'Redis_Agent:{end_user_end};{history}') messagss_list=[] + retrieved_content=[] for messages in history: query = messages.get("Query") aimessages = messages.get("Answer") messagss_list.append(f'用户:{query}。AI回复:{aimessages}') - return messagss_list + retrieved_content.append({query: aimessages}) + return messagss_list,retrieved_content async def write(self,storage_type,end_user_id,message,user_rag_memory_id,actual_end_user_id,content,actual_config_id): @@ -205,7 +211,6 @@ class LangChainAgent: # If config_id is None, try to get from end_user's connected config if actual_config_id is None and end_user_id: try: - from app.db import get_db from app.services.memory_agent_service import ( get_end_user_connected_config, ) @@ -223,11 +228,26 @@ class LangChainAgent: logger.info(f'写入类型{storage_type,str(end_user_id), message, str(user_rag_memory_id)}') print(f'写入类型{storage_type,str(end_user_id), message, str(user_rag_memory_id)}') - history_term_memory=await self.term_memory_redis_read(end_user_id) + history_term_memory_result = await self.term_memory_redis_read(end_user_id) + history_term_memory = history_term_memory_result[0] + db_for_memory = next(get_db()) if memory_flag: if len(history_term_memory)>=4 and storage_type != "rag": - history_term_memory=';'.join(history_term_memory) - logger.info(f'写入短长期:{storage_type, str(end_user_id), history_term_memory, str(user_rag_memory_id)}') + history_term_memory = ';'.join(history_term_memory) + retrieved_content = history_term_memory_result[1] + print(retrieved_content) + # 为长期记忆操作获取新的数据库连接 + try: + repo = LongTermMemoryRepository(db_for_memory) + repo.upsert(end_user_id, retrieved_content) + logger.info( + f'写入短长期:{storage_type, str(end_user_id), history_term_memory, str(user_rag_memory_id)}') + except Exception as e: + logger.error(f"Failed to write to LongTermMemory: {e}") + raise + finally: + db_for_memory.close() + await self.write(storage_type,end_user_id,history_term_memory,user_rag_memory_id,actual_end_user_id,history_term_memory,actual_config_id) await self.write(storage_type,end_user_id,message,user_rag_memory_id,actual_end_user_id,message,actual_config_id) try: @@ -316,10 +336,6 @@ class LangChainAgent: # If config_id is None, try to get from end_user's connected config if actual_config_id is None and end_user_id: try: - from app.db import get_db - from app.services.memory_agent_service import ( - get_end_user_connected_config, - ) db = next(get_db()) try: connected_config = get_end_user_connected_config(end_user_id, db) @@ -331,14 +347,24 @@ class LangChainAgent: except Exception as e: logger.warning(f"Failed to get db session: {e}") - history_term_memory = await self.term_memory_redis_read(end_user_id) + history_term_memory_result = await self.term_memory_redis_read(end_user_id) + history_term_memory = history_term_memory_result[0] if memory_flag: if len(history_term_memory) >= 4 and storage_type != "rag": history_term_memory = ';'.join(history_term_memory) - logger.info( - f'写入短长期:{storage_type, str(end_user_id), history_term_memory, str(user_rag_memory_id)}') - await self.write(storage_type, end_user_id, history_term_memory, user_rag_memory_id, end_user_id, - history_term_memory, actual_config_id) + retrieved_content = history_term_memory_result[1] + db_for_memory = next(get_db()) + try: + repo = LongTermMemoryRepository(db_for_memory) + repo.upsert(end_user_id, retrieved_content) + logger.info( + f'写入短长期:{storage_type, str(end_user_id), history_term_memory, str(user_rag_memory_id)}') + await self.write(storage_type, end_user_id, history_term_memory, user_rag_memory_id, end_user_id, + history_term_memory, actual_config_id) + except Exception as e: + logger.error(f"Failed to write to long term memory: {e}") + finally: + db_for_memory.close() await self.write(storage_type, end_user_id, message, user_rag_memory_id, end_user_id, message, actual_config_id) try: diff --git a/api/app/models/__init__.py b/api/app/models/__init__.py index 01dad24e..158e607e 100644 --- a/api/app/models/__init__.py +++ b/api/app/models/__init__.py @@ -6,6 +6,7 @@ from .document_model import Document from .file_model import File from .generic_file_model import GenericFile from .models_model import ModelConfig, ModelProvider, ModelType, ModelApiKey +from .memory_short_model import ShortTermMemory, LongTermMemory from .knowledgeshare_model import KnowledgeShare from .app_model import App from .agent_app_config_model import AgentConfig @@ -67,6 +68,8 @@ __all__ = [ "BuiltinToolConfig", "CustomToolConfig", "MCPToolConfig", + "ShortTermMemory", + "LongTermMemory", "ToolExecution", "ToolType", "ToolStatus", diff --git a/api/app/models/memory_short_model.py b/api/app/models/memory_short_model.py new file mode 100644 index 00000000..6c3b1920 --- /dev/null +++ b/api/app/models/memory_short_model.py @@ -0,0 +1,60 @@ +""" +记忆模型 - 短期记忆和长期记忆表 +""" +import uuid +import datetime +from sqlalchemy import Column, String, DateTime, Text, JSON +from sqlalchemy.dialects.postgresql import UUID + +from app.db import Base + + +class ShortTermMemory(Base): + """短期记忆表 + + 用于存储临时的对话记忆,通常保存较短时间 + """ + __tablename__ = "memory_short_term" + + id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4, index=True, comment="记忆ID") + + # 用户信息 + end_user_id = Column(String(255), nullable=False, index=True, comment="终端用户ID") + + # 对话内容 + messages = Column(Text, nullable=False, comment="用户消息内容") + aimessages = Column(Text, nullable=True, comment="AI回复消息内容") + + # 搜索开关 + search_switch = Column(String(50), nullable=True, comment="搜索开关状态") + + # 检索内容 - 存储为JSON格式的列表,包含字典 [{}, {}] + retrieved_content = Column(JSON, nullable=True, default=list, comment="检索到的相关内容,格式为[{}, {}]") + + # 时间戳 + created_at = Column(DateTime, default=datetime.datetime.now, nullable=False, index=True, comment="创建时间") + + def __repr__(self): + return f"" + + +class LongTermMemory(Base): + """长期记忆表 + + 用于存储重要的对话记忆,长期保存 + """ + __tablename__ = "memory_long_term" + + id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4, index=True, comment="记忆ID") + + # 用户信息 + end_user_id = Column(String(255), nullable=False, index=True, comment="终端用户ID") + + # 检索内容 - 存储为JSON格式的列表,包含字典 [{}, {}] + retrieved_content = Column(JSON, nullable=True, default=list, comment="检索到的相关内容,格式为[{}, {}]") + + # 时间戳 + created_at = Column(DateTime, default=datetime.datetime.now, nullable=False, index=True, comment="创建时间") + + def __repr__(self): + return f"" \ No newline at end of file diff --git a/api/app/repositories/memory_short_repository.py b/api/app/repositories/memory_short_repository.py new file mode 100644 index 00000000..9a6e39c6 --- /dev/null +++ b/api/app/repositories/memory_short_repository.py @@ -0,0 +1,503 @@ +""" +记忆仓储模块 - 短期记忆和长期记忆的数据访问层 +""" +from sqlalchemy.orm import Session +from typing import List, Optional, Dict, Any +import uuid +import datetime + +from app.models.memory_short_model import ShortTermMemory, LongTermMemory +from app.core.logging_config import get_db_logger + +# 获取数据库专用日志器 +db_logger = get_db_logger() + + +class ShortTermMemoryRepository: + """短期记忆仓储类""" + + def __init__(self, db: Session): + self.db = db + + def create(self, end_user_id: str, messages: str, aimessages: str = None, search_switch: str = None, retrieved_content: List[Dict] = None) -> ShortTermMemory: + """创建短期记忆记录 + + Args: + end_user_id: 终端用户ID + messages: 用户消息内容 + aimessages: AI回复消息内容 + search_switch: 搜索开关状态 + retrieved_content: 检索到的相关内容,格式为[{}, {}] + + Returns: + ShortTermMemory: 创建的短期记忆对象 + """ + try: + memory = ShortTermMemory( + end_user_id=end_user_id, + messages=messages, + aimessages=aimessages, + search_switch=search_switch, + retrieved_content=retrieved_content or [] + ) + + self.db.add(memory) + self.db.commit() + self.db.refresh(memory) + + db_logger.info(f"成功创建短期记忆记录: {memory.id} for user {end_user_id}") + return memory + + except Exception as e: + self.db.rollback() + db_logger.error(f"创建短期记忆记录时出错: {str(e)}") + raise + + def count_by_user_id(self,end_user_id: str) -> int: + """根据ID获取短期记忆记录 + + Args: + memory_id: 记忆ID + + Returns: + Optional[ShortTermMemory]: 记忆对象,如果不存在则返回None + """ + try: + count = ( + self.db.query(ShortTermMemory) + .filter(ShortTermMemory.end_user_id == end_user_id) + .count() + ) + db_logger.debug(f"成功统计用户 {end_user_id} 的短期记忆数量: {count}") + + return count + + except Exception as e: + self.db.rollback() + db_logger.error(f"查询短期记忆记录 {count} 时出错: {str(e)}") + raise + + + def get_latest_by_user_id(self, end_user_id: str, limit: int = 5) -> List[ShortTermMemory]: + """获取用户最新的短期记忆记录 + + Args: + end_user_id: 终端用户ID + limit: 返回记录数限制,默认5条 + + Returns: + List[ShortTermMemory]: 最新的记忆记录列表,按创建时间倒序 + """ + try: + # 使用复合索引 ix_memory_short_term_user_time 优化查询 + memories = ( + self.db.query(ShortTermMemory) + .filter(ShortTermMemory.end_user_id == end_user_id) + .order_by(ShortTermMemory.created_at.desc()) + .limit(limit) + .all() + ) + + db_logger.info(f"成功查询用户 {end_user_id} 的最新 {len(memories)} 条短期记忆记录") + return memories + + except Exception as e: + self.db.rollback() + db_logger.error(f"查询用户 {end_user_id} 的最新短期记忆记录时出错: {str(e)}") + raise + + def get_recent_by_user_id(self, end_user_id: str, hours: int = 24) -> List[ShortTermMemory]: + """获取用户最近指定小时内的短期记忆记录 + + Args: + end_user_id: 终端用户ID + hours: 时间范围(小时),默认24小时 + + Returns: + List[ShortTermMemory]: 记忆记录列表,按创建时间倒序 + """ + try: + cutoff_time = datetime.datetime.now() - datetime.timedelta(hours=hours) + + # 使用复合索引 ix_memory_short_term_user_time 优化查询 + memories = ( + self.db.query(ShortTermMemory) + .filter( + ShortTermMemory.end_user_id == end_user_id, + ShortTermMemory.created_at >= cutoff_time + ) + .order_by(ShortTermMemory.created_at.desc()) + .all() + ) + + db_logger.info(f"成功查询用户 {end_user_id} 最近 {hours} 小时的 {len(memories)} 条短期记忆记录") + return memories + + except Exception as e: + self.db.rollback() + db_logger.error(f"查询用户 {end_user_id} 最近 {hours} 小时的短期记忆记录时出错: {str(e)}") + raise + + def delete_by_id(self, memory_id: uuid.UUID) -> bool: + """删除指定ID的短期记忆记录 + + Args: + memory_id: 记忆ID + + Returns: + bool: 删除成功返回True,否则返回False + """ + try: + deleted_count = ( + self.db.query(ShortTermMemory) + .filter(ShortTermMemory.id == memory_id) + .delete(synchronize_session=False) + ) + + self.db.commit() + + if deleted_count > 0: + db_logger.info(f"成功删除短期记忆记录 {memory_id}") + return True + else: + db_logger.warning(f"未找到短期记忆记录 {memory_id},无法删除") + return False + + except Exception as e: + self.db.rollback() + db_logger.error(f"删除短期记忆记录 {memory_id} 时出错: {str(e)}") + raise + + def delete_old_memories(self, days: int = 7) -> int: + """删除指定天数之前的短期记忆记录 + + Args: + days: 保留天数,默认7天 + + Returns: + int: 删除的记录数 + """ + try: + cutoff_time = datetime.datetime.now() - datetime.timedelta(days=days) + + deleted_count = ( + self.db.query(ShortTermMemory) + .filter(ShortTermMemory.created_at < cutoff_time) + .delete(synchronize_session=False) + ) + + self.db.commit() + + db_logger.info(f"成功删除 {days} 天前的 {deleted_count} 条短期记忆记录") + return deleted_count + + except Exception as e: + self.db.rollback() + db_logger.error(f"删除 {days} 天前的短期记忆记录时出错: {str(e)}") + raise + + def upsert(self, end_user_id: str, messages: str, aimessages: str = None, search_switch: str = None, retrieved_content: List[Dict] = None) -> ShortTermMemory: + """创建或更新短期记忆记录 + + 根据 end_user_id、messages 和 aimessages 查找现有记录: + - 如果找到匹配的记录,则更新 messages、aimessages、search_switch 和 retrieved_content + - 如果没有找到匹配的记录,则创建新记录 + + Args: + end_user_id: 终端用户ID + messages: 用户消息内容 + aimessages: AI回复消息内容 + search_switch: 搜索开关状态 + retrieved_content: 检索到的相关内容,格式为[{}, {}] + + Returns: + ShortTermMemory: 创建或更新的短期记忆对象 + """ + try: + # 构建查询条件,使用复合索引 ix_memory_short_term_user_messages 优化查询 + query_filters = [ + ShortTermMemory.end_user_id == end_user_id, + ShortTermMemory.messages == messages + ] + + # 如果 aimessages 不为空,则加入查询条件 + if aimessages is not None: + query_filters.append(ShortTermMemory.aimessages == aimessages) + else: + # 如果 aimessages 为 None,则查找 aimessages 为 NULL 的记录 + query_filters.append(ShortTermMemory.aimessages.is_(None)) + + # 查找现有记录 + existing_memory = ( + self.db.query(ShortTermMemory) + .filter(*query_filters) + .first() + ) + + if existing_memory: + # 更新现有记录 + existing_memory.messages = messages + existing_memory.aimessages = aimessages + existing_memory.search_switch = search_switch + existing_memory.retrieved_content = retrieved_content or [] + + self.db.commit() + self.db.refresh(existing_memory) + + db_logger.info(f"成功更新短期记忆记录: {existing_memory.id} for user {end_user_id}") + return existing_memory + else: + # 创建新记录 + new_memory = ShortTermMemory( + end_user_id=end_user_id, + messages=messages, + aimessages=aimessages, + search_switch=search_switch, + retrieved_content=retrieved_content or [] + ) + + self.db.add(new_memory) + self.db.commit() + self.db.refresh(new_memory) + + db_logger.info(f"成功创建新的短期记忆记录: {new_memory.id} for user {end_user_id}") + return new_memory + + except Exception as e: + self.db.rollback() + db_logger.error(f"创建或更新短期记忆记录时出错: {str(e)}") + raise + + +class LongTermMemoryRepository: + """长期记忆仓储类""" + + def __init__(self, db: Session): + self.db = db + + def create(self, end_user_id: str, retrieved_content: List[Dict] = None) -> LongTermMemory: + """创建长期记忆记录 + + Args: + end_user_id: 终端用户ID + retrieved_content: 检索到的相关内容,格式为[{}, {}] + + Returns: + LongTermMemory: 创建的长期记忆对象 + """ + try: + memory = LongTermMemory( + end_user_id=end_user_id, + retrieved_content=retrieved_content or [] + ) + + self.db.add(memory) + self.db.commit() + self.db.refresh(memory) + + db_logger.info(f"成功创建长期记忆记录: {memory.id} for user {end_user_id}") + return memory + + except Exception as e: + self.db.rollback() + db_logger.error(f"创建长期记忆记录时出错: {str(e)}") + raise + + def get_by_id(self, memory_id: uuid.UUID) -> Optional[LongTermMemory]: + """根据ID获取长期记忆记录 + + Args: + memory_id: 记忆ID + + Returns: + Optional[LongTermMemory]: 记忆对象,如果不存在则返回None + """ + try: + memory = ( + self.db.query(LongTermMemory) + .filter(LongTermMemory.id == memory_id) + .first() + ) + + if memory: + db_logger.debug(f"成功查询到长期记忆记录 {memory_id}") + else: + db_logger.debug(f"未找到长期记忆记录 {memory_id}") + + return memory + + except Exception as e: + self.db.rollback() + db_logger.error(f"查询长期记忆记录 {memory_id} 时出错: {str(e)}") + raise + + def get_by_user_id(self, end_user_id: str, limit: int = 100, offset: int = 0) -> List[LongTermMemory]: + """根据用户ID获取长期记忆记录列表 + + Args: + end_user_id: 终端用户ID + limit: 返回记录数限制,默认100 + offset: 偏移量,默认0 + + Returns: + List[LongTermMemory]: 记忆记录列表,按创建时间倒序 + """ + try: + # 使用复合索引 ix_memory_long_term_user_time 优化查询 + memories = ( + self.db.query(LongTermMemory) + .filter(LongTermMemory.end_user_id == end_user_id) + .order_by(LongTermMemory.created_at.desc()) + .limit(limit) + .offset(offset) + .all() + ) + + db_logger.info(f"成功查询用户 {end_user_id} 的 {len(memories)} 条长期记忆记录") + return memories + + except Exception as e: + self.db.rollback() + db_logger.error(f"查询用户 {end_user_id} 的长期记忆记录时出错: {str(e)}") + raise + + def search_by_content(self, end_user_id: str, keyword: str, limit: int = 50) -> List[LongTermMemory]: + """根据内容关键词搜索长期记忆记录 + + Args: + end_user_id: 终端用户ID + keyword: 搜索关键词 + limit: 返回记录数限制,默认50 + + Returns: + List[LongTermMemory]: 匹配的记忆记录列表,按创建时间倒序 + """ + try: + # 使用 GIN 索引 ix_memory_long_term_retrieved_content_gin 优化 JSON 搜索 + # 同时使用复合索引 ix_memory_long_term_user_time 优化用户过滤 + memories = ( + self.db.query(LongTermMemory) + .filter( + LongTermMemory.end_user_id == end_user_id, + LongTermMemory.retrieved_content.astext.contains(keyword) + ) + .order_by(LongTermMemory.created_at.desc()) + .limit(limit) + .all() + ) + + db_logger.info(f"成功搜索用户 {end_user_id} 包含关键词 '{keyword}' 的 {len(memories)} 条长期记忆记录") + return memories + + except Exception as e: + self.db.rollback() + db_logger.error(f"搜索用户 {end_user_id} 包含关键词 '{keyword}' 的长期记忆记录时出错: {str(e)}") + raise + + def delete_by_id(self, memory_id: uuid.UUID) -> bool: + """删除指定ID的长期记忆记录 + + Args: + memory_id: 记忆ID + + Returns: + bool: 删除成功返回True,否则返回False + """ + try: + deleted_count = ( + self.db.query(LongTermMemory) + .filter(LongTermMemory.id == memory_id) + .delete(synchronize_session=False) + ) + + self.db.commit() + + if deleted_count > 0: + db_logger.info(f"成功删除长期记忆记录 {memory_id}") + return True + else: + db_logger.warning(f"未找到长期记忆记录 {memory_id},无法删除") + return False + + except Exception as e: + self.db.rollback() + db_logger.error(f"删除长期记忆记录 {memory_id} 时出错: {str(e)}") + raise + + def count_by_user_id(self, end_user_id: str) -> int: + """统计用户的长期记忆记录数量 + + Args: + end_user_id: 终端用户ID + + Returns: + int: 记录数量 + """ + try: + count = ( + self.db.query(LongTermMemory) + .filter(LongTermMemory.end_user_id == end_user_id) + .count() + ) + + db_logger.debug(f"用户 {end_user_id} 共有 {count} 条长期记忆记录") + return count + + except Exception as e: + self.db.rollback() + db_logger.error(f"统计用户 {end_user_id} 的长期记忆记录数量时出错: {str(e)}") + raise + + def upsert(self, end_user_id: str, retrieved_content: List[Dict] = None) -> Optional[LongTermMemory]: + """创建或更新长期记忆记录 + + 根据 end_user_id 和 retrieved_content 判断是否需要写入: + - 如果找到相同的 end_user_id 和 retrieved_content,则不写入,返回 None + - 如果没有找到相同的记录,则创建新记录 + + Args: + end_user_id: 终端用户ID + retrieved_content: 检索到的相关内容,格式为[{}, {}] + + Returns: + Optional[LongTermMemory]: 创建的长期记忆对象,如果不需要写入则返回 None + """ + try: + retrieved_content = retrieved_content or [] + + # 优化查询:使用复合索引 ix_memory_long_term_user_time 先过滤用户 + # 然后在应用层比较 JSON 内容,避免复杂的数据库 JSON 比较 + existing_memories = ( + self.db.query(LongTermMemory) + .filter(LongTermMemory.end_user_id == end_user_id) + .order_by(LongTermMemory.created_at.desc()) + .limit(100) # 限制查询数量,避免加载过多数据 + .all() + ) + + # 在 Python 中比较 retrieved_content + for memory in existing_memories: + if memory.retrieved_content == retrieved_content: + # 如果找到相同的记录,不写入 + db_logger.info(f"长期记忆记录已存在,跳过写入: user {end_user_id}") + return None + + # 如果没有找到相同的记录,创建新记录 + new_memory = LongTermMemory( + end_user_id=end_user_id, + retrieved_content=retrieved_content + ) + + self.db.add(new_memory) + self.db.commit() + self.db.refresh(new_memory) + + db_logger.info(f"成功创建新的长期记忆记录: {new_memory.id} for user {end_user_id}") + return new_memory + + except Exception as e: + self.db.rollback() + db_logger.error(f"创建或更新长期记忆记录时出错: {str(e)}") + raise + + diff --git a/api/app/services/memory_agent_service.py b/api/app/services/memory_agent_service.py index 8193da8a..d44408fe 100644 --- a/api/app/services/memory_agent_service.py +++ b/api/app/services/memory_agent_service.py @@ -4,6 +4,7 @@ Memory Agent Service Handles business logic for memory agent operations including read/write services, health checks, and message type classification. """ +import datetime import json import os import re @@ -24,6 +25,7 @@ from app.core.memory.analytics.hot_memory_tags import get_hot_memory_tags from app.core.memory.utils.llm.llm_utils import MemoryClientFactory from app.db import get_db_context from app.models.knowledge_model import Knowledge, KnowledgeType +from app.repositories.memory_short_repository import ShortTermMemoryRepository from app.repositories.neo4j.neo4j_connector import Neo4jConnector from app.schemas.memory_config_schema import ConfigurationError, MemoryConfig from app.services.memory_config_service import MemoryConfigService @@ -393,7 +395,7 @@ class MemoryAgentService: import time start_time = time.time() - + ori_message=message # Resolve config_id if None using end_user's connected config if config_id is None: try: @@ -406,15 +408,15 @@ class MemoryAgentService: raise # Re-raise our specific error logger.error(f"Failed to get connected config for end_user {group_id}: {e}") raise ValueError(f"Unable to determine memory configuration for end_user {group_id}: {e}") - + logger.info(f"Read operation for group {group_id} with config_id {config_id}") - + # 导入审计日志记录器 try: from app.core.memory.utils.log.audit_logger import audit_logger except ImportError: audit_logger = None - + # Get group lock to prevent concurrent processing group_lock = self.get_group_lock(group_id) @@ -430,7 +432,7 @@ class MemoryAgentService: except ConfigurationError as e: error_msg = f"Failed to load configuration for config_id: {config_id}: {e}" logger.error(error_msg) - + # Log failed operation if audit_logger: duration = time.time() - start_time @@ -442,9 +444,9 @@ class MemoryAgentService: duration=duration, error=error_msg ) - + raise ValueError(error_msg) - + # Step 2: Prepare history history.append({"role": "user", "content": message}) logger.debug(f"Group ID:{group_id}, Message:{message}, History:{history}, Config ID:{config_id}") @@ -452,7 +454,7 @@ class MemoryAgentService: # Step 3: Initialize MCP client and execute read workflow mcp_config = get_mcp_server_config() client = MultiServerMCPClient(mcp_config) - + async with client.session('data_flow') as session: logger.debug("Connected to MCP Server: data_flow") tools = await load_mcp_tools(session) @@ -475,7 +477,7 @@ class MemoryAgentService: # Capture any errors from the state if event.get('errors'): workflow_errors.extend(event.get('errors', [])) - + for msg in messages: msg_content = msg.content msg_role = msg.__class__.__name__.lower().replace("message", "") @@ -483,7 +485,7 @@ class MemoryAgentService: "role": msg_role, "content": msg_content }) - + # Extract intermediate outputs if hasattr(msg, 'content'): try: @@ -496,7 +498,7 @@ class MemoryAgentService: break else: continue # No text block found - + # Try to parse content as JSON if isinstance(content_to_parse, str): try: @@ -506,16 +508,16 @@ class MemoryAgentService: if '_intermediate' in parsed: intermediate_data = parsed['_intermediate'] output_key = self._create_intermediate_key(intermediate_data) - + if output_key not in seen_intermediates: seen_intermediates.add(output_key) intermediate_outputs.append(self._format_intermediate_output(intermediate_data)) - + # Check for multiple intermediate outputs (from Retrieve) if '_intermediates' in parsed: for intermediate_data in parsed['_intermediates']: output_key = self._create_intermediate_key(intermediate_data) - + if output_key not in seen_intermediates: seen_intermediates.add(output_key) intermediate_outputs.append(self._format_intermediate_output(intermediate_data)) @@ -523,7 +525,7 @@ class MemoryAgentService: pass except Exception as e: logger.debug(f"Failed to extract intermediate output: {e}") - + workflow_duration = time.time() - start logger.info(f"Read graph workflow completed in {workflow_duration}s") @@ -532,7 +534,7 @@ class MemoryAgentService: for messages in outputs: if messages['role'] == 'tool': message = messages['content'] - + # Handle MCP content format: [{'type': 'text', 'text': '...'}] if isinstance(message, list): # Extract text from MCP content blocks @@ -542,7 +544,7 @@ class MemoryAgentService: break else: continue # No text block found - + try: parsed = json.loads(message) if isinstance(message, str) else message if isinstance(parsed, dict): @@ -552,15 +554,15 @@ class MemoryAgentService: final_answer = summary_result except (json.JSONDecodeError, ValueError): pass - + # 记录成功的操作 total_duration = time.time() - start_time - + # Check for workflow errors if workflow_errors: error_details = "; ".join([f"{e['tool']}: {e['error']}" for e in workflow_errors]) logger.warning(f"Read workflow completed with errors: {error_details}") - + if audit_logger: audit_logger.log_operation( operation="READ", @@ -577,11 +579,11 @@ class MemoryAgentService: "errors": workflow_errors } ) - + # Raise error if no answer was produced if not final_answer: raise ValueError(f"Read workflow failed: {error_details}") - + if audit_logger and not workflow_errors: audit_logger.log_operation( operation="READ", @@ -596,7 +598,31 @@ class MemoryAgentService: "has_answer": bool(final_answer) } ) - + retrieved_content=[] + repo = ShortTermMemoryRepository(db) + if str(search_switch)!="2": + for intermediate in intermediate_outputs: + intermediate_type=intermediate['type'] + if intermediate_type=="search_result": + query=intermediate['query'] + raw_results=intermediate['raw_results'] + reranked_results=raw_results.get('reranked_results',[]) + statements=[statement['statement'] for statement in reranked_results.get('statements', [])] + statements=list(set(statements)) + retrieved_content.append({query:statements}) + if '信息不足,无法回答' in str(final_answer) or retrieved_content!=[]: + # 使用 upsert 方法 + repo.upsert( + end_user_id=group_id, # 确保这个变量在作用域内 + messages=ori_message, + aimessages=final_answer, + retrieved_content=retrieved_content, + search_switch=str(search_switch) + ) + print("写入成功") + + + return { "answer": final_answer, "intermediate_outputs": intermediate_outputs diff --git a/api/app/services/memory_short_service.py b/api/app/services/memory_short_service.py new file mode 100644 index 00000000..ac9f86e0 --- /dev/null +++ b/api/app/services/memory_short_service.py @@ -0,0 +1,56 @@ + +from app.core.logging_config import get_api_logger +from app.db import get_db +from app.repositories.memory_short_repository import LongTermMemoryRepository +from app.repositories.memory_short_repository import ShortTermMemoryRepository + + +api_logger = get_api_logger() +db=next(get_db()) +class ShortService: + def __init__(self, end_user_id): + self.short_repo = ShortTermMemoryRepository(db) + self.end_user_id = end_user_id + + def get_short_databasets(self): + short_memories = self.short_repo.get_latest_by_user_id(self.end_user_id, 3) + short_result = [] + for memory in short_memories: + deep_expanded = {} # Create a new dictionary for each memory + messages = memory.messages + aimessages = memory.aimessages + retrieved_content = memory.retrieved_content or [] + + api_logger.debug(f"Retrieved content: {retrieved_content}") + + retrieval_source = [] + for item in retrieved_content: + if isinstance(item, dict): + for key, values in item.items(): + retrieval_source.append({"query": key, "retrieval": values}) + + deep_expanded['retrieval'] = retrieval_source + deep_expanded['message'] = messages # 修正拼写错误 + deep_expanded['answer'] = aimessages + short_result.append(deep_expanded) + return short_result + def get_short_count(self): + short_count = self.short_repo.count_by_user_id(self.end_user_id) + return short_count + +class LongService: + def __init__(self, end_user_id): + self.long_repo = LongTermMemoryRepository(db) + self.end_user_id = end_user_id + def get_long_databasets(self): + # 获取长期记忆数据 + long_memories = self.long_repo.get_by_user_id(self.end_user_id, 1) + + long_result = [] + for long_memory in long_memories: + if long_memory.retrieved_content: + for memory_item in long_memory.retrieved_content: + if isinstance(memory_item, dict): + for key, values in memory_item.items(): + long_result.append({"query": key, "retrieval": values}) + return long_result diff --git a/api/app/services/user_memory_service.py b/api/app/services/user_memory_service.py index 40851835..25577cbf 100644 --- a/api/app/services/user_memory_service.py +++ b/api/app/services/user_memory_service.py @@ -1496,8 +1496,8 @@ def _extract_node_properties(label: str, properties: Dict[str, Any]) -> Dict[str field_whitelist = { "Dialogue": ["content", "created_at"], "Chunk": ["content", "created_at"], - "Statement": ["temporal_info", "stmt_type", "statement", "valid_at", "created_at", "caption"], - "ExtractedEntity": ["description", "name", "entity_type", "created_at", "caption"], + "Statement": ["temporal_info", "stmt_type", "statement", "valid_at", "created_at", "caption","emotion_keywords","emotion_type","emotion_subject"], + "ExtractedEntity": ["description", "name", "entity_type", "created_at", "caption","aliases","connect_strength"], "MemorySummary": ["summary", "content", "created_at", "caption"] # 添加 content 字段 }