From 66c153f1ad9e21b69ff4c980b5892abeb0a09948 Mon Sep 17 00:00:00 2001 From: Ke Sun Date: Tue, 3 Mar 2026 16:48:34 +0800 Subject: [PATCH] refactor(api): improve memory service dependency injection and code organization - Update ShortService and LongService constructors to accept db Session parameter for proper dependency injection instead of using module-level db instance - Reorganize imports in memory_short_term_controller.py following PEP 8 conventions (stdlib, third-party, local imports) - Add comprehensive docstrings with type hints to ShortService and LongService methods for better code documentation - Fix import organization in memory_short_service.py to group related imports and improve readability - Reorganize imports in user_memory_service.py to follow consistent import ordering patterns - Update ShortService instantiation in analytics_memory_types to pass db parameter - Remove module-level db instance initialization in favor of caller-managed database session lifecycle - Add type annotations to method signatures (end_user_id: str, db: Session, return types) - Improve code formatting and spacing consistency across memory service files --- .../memory_short_term_controller.py | 18 +++--- api/app/services/memory_short_service.py | 61 ++++++++++++++----- api/app/services/user_memory_service.py | 25 +++++--- 3 files changed, 72 insertions(+), 32 deletions(-) diff --git a/api/app/controllers/memory_short_term_controller.py b/api/app/controllers/memory_short_term_controller.py index 1cca266e..0acac6ce 100644 --- a/api/app/controllers/memory_short_term_controller.py +++ b/api/app/controllers/memory_short_term_controller.py @@ -1,16 +1,18 @@ -from fastapi import APIRouter, Depends, HTTPException, status,Header +from typing import Optional + +from dotenv import load_dotenv +from fastapi import APIRouter, Depends, Header, HTTPException, status +from sqlalchemy.orm import Session + from app.core.language_utils import get_language_from_header from app.core.logging_config import get_api_logger from app.core.response_utils import success from app.db import get_db from app.dependencies import get_current_user from app.models.user_model import User - +from app.services.memory_short_service import LongService, ShortService from app.services.memory_storage_service import search_entity -from app.services.memory_short_service import ShortService,LongService -from dotenv import load_dotenv -from sqlalchemy.orm import Session -from typing import Optional + load_dotenv() api_logger = get_api_logger() @@ -29,11 +31,11 @@ async def short_term_configs( language = get_language_from_header(language_type) # 获取短期记忆数据 - short_term=ShortService(end_user_id) + short_term=ShortService(end_user_id, db) short_result=short_term.get_short_databasets() short_count=short_term.get_short_count() - long_term=LongService(end_user_id) + long_term=LongService(end_user_id, db) long_result=long_term.get_long_databasets() entity_result = await search_entity(end_user_id) diff --git a/api/app/services/memory_short_service.py b/api/app/services/memory_short_service.py index fa3870f0..fa9623e0 100644 --- a/api/app/services/memory_short_service.py +++ b/api/app/services/memory_short_service.py @@ -1,22 +1,37 @@ +from typing import Dict, List + +from sqlalchemy.orm import Session from app.core.logging_config import get_api_logger -from app.db import get_db -from app.repositories.memory_short_repository import LongTermMemoryRepository -from app.repositories.memory_short_repository import ShortTermMemoryRepository - +from app.repositories.memory_short_repository import ( + LongTermMemoryRepository, + ShortTermMemoryRepository, +) api_logger = get_api_logger() -db=next(get_db()) + + class ShortService: - def __init__(self, end_user_id): + def __init__(self, end_user_id: str, db: Session) -> None: + """Service for short-term memory queries. + + Args: + end_user_id: The end user identifier to query memories for. + db: SQLAlchemy database session (caller-managed lifecycle). + """ self.short_repo = ShortTermMemoryRepository(db) self.end_user_id = end_user_id - def get_short_databasets(self): + def get_short_databasets(self) -> List[Dict]: + """Retrieve the latest short-term memory entries for the user. + + Returns: + List[Dict]: List of memory dicts with retrieval, message, and answer keys. + """ short_memories = self.short_repo.get_latest_by_user_id(self.end_user_id, 3) short_result = [] for memory in short_memories: - deep_expanded = {} # Create a new dictionary for each memory + deep_expanded = {} messages = memory.messages aimessages = memory.aimessages retrieved_content = memory.retrieved_content or [] @@ -27,23 +42,41 @@ class ShortService: for item in retrieved_content: if isinstance(item, dict): for key, values in item.items(): - retrieval_source.append({"query": key, "retrieval": values,"source":"上下文记忆"}) + retrieval_source.append({"query": key, "retrieval": values, "source": "上下文记忆"}) deep_expanded['retrieval'] = retrieval_source - deep_expanded['message'] = messages # 修正拼写错误 + deep_expanded['message'] = messages deep_expanded['answer'] = aimessages short_result.append(deep_expanded) return short_result - def get_short_count(self): + + def get_short_count(self) -> int: + """Count total short-term memory entries for the user. + + Returns: + int: Number of short-term memory records. + """ short_count = self.short_repo.count_by_user_id(self.end_user_id) return short_count + class LongService: - def __init__(self, end_user_id): + def __init__(self, end_user_id: str, db: Session) -> None: + """Service for long-term memory queries. + + Args: + end_user_id: The end user identifier to query memories for. + db: SQLAlchemy database session (caller-managed lifecycle). + """ self.long_repo = LongTermMemoryRepository(db) self.end_user_id = end_user_id - def get_long_databasets(self): - # 获取长期记忆数据 + + def get_long_databasets(self) -> List[Dict]: + """Retrieve long-term memory retrieval data for the user. + + Returns: + List[Dict]: List of dicts with query and retrieval keys. + """ long_memories = self.long_repo.get_by_user_id(self.end_user_id, 1) long_result = [] diff --git a/api/app/services/user_memory_service.py b/api/app/services/user_memory_service.py index e34756b9..db5051d2 100644 --- a/api/app/services/user_memory_service.py +++ b/api/app/services/user_memory_service.py @@ -10,6 +10,9 @@ from collections import Counter from datetime import datetime from typing import Any, Dict, List, Optional, Tuple +from pydantic import BaseModel, Field +from sqlalchemy.orm import Session + from app.core.logging_config import get_logger from app.core.memory.utils.llm.llm_utils import MemoryClientFactory from app.db import get_db_context @@ -23,8 +26,6 @@ from app.services.memory_base_service import MemoryBaseService, MemoryTransServi from app.services.memory_config_service import MemoryConfigService from app.services.memory_perceptual_service import MemoryPerceptualService from app.services.memory_short_service import ShortService -from pydantic import BaseModel, Field -from sqlalchemy.orm import Session logger = get_logger(__name__) @@ -1035,9 +1036,10 @@ async def analytics_memory_insight_report(end_user_id: Optional[str] = None, lan "growth_trajectory": str # 成长轨迹 } """ - from app.core.memory.utils.prompt.prompt_utils import render_memory_insight_prompt - from app.core.language_utils import validate_language import re + + from app.core.language_utils import validate_language + from app.core.memory.utils.prompt.prompt_utils import render_memory_insight_prompt # 验证语言参数 language = validate_language(language) @@ -1161,11 +1163,12 @@ async def analytics_user_summary(end_user_id: Optional[str] = None, language: st "one_sentence": str } """ - from app.core.memory.utils.prompt.prompt_utils import render_user_summary_prompt - from app.core.language_utils import validate_language - from app.repositories.end_user_repository import EndUserRepository - from app.db import get_db import re + + from app.core.language_utils import validate_language + from app.core.memory.utils.prompt.prompt_utils import render_user_summary_prompt + from app.db import get_db + from app.repositories.end_user_repository import EndUserRepository # 验证语言参数 language = validate_language(language) @@ -1457,7 +1460,7 @@ async def analytics_memory_types( short_term_count = 0 if end_user_id: try: - short_term_service = ShortService(end_user_id) + short_term_service = ShortService(end_user_id, db) short_term_data = short_term_service.get_short_databasets() # 统计 short_term 数组的长度 if short_term_data: @@ -1471,8 +1474,10 @@ async def analytics_memory_types( forgetting_threshold = 0.3 # 默认值 if end_user_id: try: + from app.core.memory.storage_services.forgetting_engine.config_utils import ( + load_actr_config_from_db, + ) from app.services.memory_agent_service import get_end_user_connected_config - from app.core.memory.storage_services.forgetting_engine.config_utils import load_actr_config_from_db # 获取用户关联的 config_id connected_config = get_end_user_connected_config(end_user_id, db)