refactor(api): improve memory service dependency injection and code organization

- Update ShortService and LongService constructors to accept db Session parameter for proper dependency injection instead of using module-level db instance
- Reorganize imports in memory_short_term_controller.py following PEP 8 conventions (stdlib, third-party, local imports)
- Add comprehensive docstrings with type hints to ShortService and LongService methods for better code documentation
- Fix import organization in memory_short_service.py to group related imports and improve readability
- Reorganize imports in user_memory_service.py to follow consistent import ordering patterns
- Update ShortService instantiation in analytics_memory_types to pass db parameter
- Remove module-level db instance initialization in favor of caller-managed database session lifecycle
- Add type annotations to method signatures (end_user_id: str, db: Session, return types)
- Improve code formatting and spacing consistency across memory service files
This commit is contained in:
Ke Sun
2026-03-03 16:48:34 +08:00
parent c6c7a1827c
commit 66c153f1ad
3 changed files with 72 additions and 32 deletions

View File

@@ -1,16 +1,18 @@
from fastapi import APIRouter, Depends, HTTPException, status,Header from typing import Optional
from dotenv import load_dotenv
from fastapi import APIRouter, Depends, Header, HTTPException, status
from sqlalchemy.orm import Session
from app.core.language_utils import get_language_from_header from app.core.language_utils import get_language_from_header
from app.core.logging_config import get_api_logger from app.core.logging_config import get_api_logger
from app.core.response_utils import success from app.core.response_utils import success
from app.db import get_db from app.db import get_db
from app.dependencies import get_current_user from app.dependencies import get_current_user
from app.models.user_model import User from app.models.user_model import User
from app.services.memory_short_service import LongService, ShortService
from app.services.memory_storage_service import search_entity from app.services.memory_storage_service import search_entity
from app.services.memory_short_service import ShortService,LongService
from dotenv import load_dotenv
from sqlalchemy.orm import Session
from typing import Optional
load_dotenv() load_dotenv()
api_logger = get_api_logger() api_logger = get_api_logger()
@@ -29,11 +31,11 @@ async def short_term_configs(
language = get_language_from_header(language_type) language = get_language_from_header(language_type)
# 获取短期记忆数据 # 获取短期记忆数据
short_term=ShortService(end_user_id) short_term=ShortService(end_user_id, db)
short_result=short_term.get_short_databasets() short_result=short_term.get_short_databasets()
short_count=short_term.get_short_count() short_count=short_term.get_short_count()
long_term=LongService(end_user_id) long_term=LongService(end_user_id, db)
long_result=long_term.get_long_databasets() long_result=long_term.get_long_databasets()
entity_result = await search_entity(end_user_id) entity_result = await search_entity(end_user_id)

View File

@@ -1,22 +1,37 @@
from typing import Dict, List
from sqlalchemy.orm import Session
from app.core.logging_config import get_api_logger from app.core.logging_config import get_api_logger
from app.db import get_db from app.repositories.memory_short_repository import (
from app.repositories.memory_short_repository import LongTermMemoryRepository LongTermMemoryRepository,
from app.repositories.memory_short_repository import ShortTermMemoryRepository ShortTermMemoryRepository,
)
api_logger = get_api_logger() api_logger = get_api_logger()
db=next(get_db())
class ShortService: class ShortService:
def __init__(self, end_user_id): def __init__(self, end_user_id: str, db: Session) -> None:
"""Service for short-term memory queries.
Args:
end_user_id: The end user identifier to query memories for.
db: SQLAlchemy database session (caller-managed lifecycle).
"""
self.short_repo = ShortTermMemoryRepository(db) self.short_repo = ShortTermMemoryRepository(db)
self.end_user_id = end_user_id self.end_user_id = end_user_id
def get_short_databasets(self): def get_short_databasets(self) -> List[Dict]:
"""Retrieve the latest short-term memory entries for the user.
Returns:
List[Dict]: List of memory dicts with retrieval, message, and answer keys.
"""
short_memories = self.short_repo.get_latest_by_user_id(self.end_user_id, 3) short_memories = self.short_repo.get_latest_by_user_id(self.end_user_id, 3)
short_result = [] short_result = []
for memory in short_memories: for memory in short_memories:
deep_expanded = {} # Create a new dictionary for each memory deep_expanded = {}
messages = memory.messages messages = memory.messages
aimessages = memory.aimessages aimessages = memory.aimessages
retrieved_content = memory.retrieved_content or [] retrieved_content = memory.retrieved_content or []
@@ -30,20 +45,38 @@ class ShortService:
retrieval_source.append({"query": key, "retrieval": values, "source": "上下文记忆"}) retrieval_source.append({"query": key, "retrieval": values, "source": "上下文记忆"})
deep_expanded['retrieval'] = retrieval_source deep_expanded['retrieval'] = retrieval_source
deep_expanded['message'] = messages # 修正拼写错误 deep_expanded['message'] = messages
deep_expanded['answer'] = aimessages deep_expanded['answer'] = aimessages
short_result.append(deep_expanded) short_result.append(deep_expanded)
return short_result return short_result
def get_short_count(self):
def get_short_count(self) -> int:
"""Count total short-term memory entries for the user.
Returns:
int: Number of short-term memory records.
"""
short_count = self.short_repo.count_by_user_id(self.end_user_id) short_count = self.short_repo.count_by_user_id(self.end_user_id)
return short_count return short_count
class LongService: class LongService:
def __init__(self, end_user_id): def __init__(self, end_user_id: str, db: Session) -> None:
"""Service for long-term memory queries.
Args:
end_user_id: The end user identifier to query memories for.
db: SQLAlchemy database session (caller-managed lifecycle).
"""
self.long_repo = LongTermMemoryRepository(db) self.long_repo = LongTermMemoryRepository(db)
self.end_user_id = end_user_id self.end_user_id = end_user_id
def get_long_databasets(self):
# 获取长期记忆数据 def get_long_databasets(self) -> List[Dict]:
"""Retrieve long-term memory retrieval data for the user.
Returns:
List[Dict]: List of dicts with query and retrieval keys.
"""
long_memories = self.long_repo.get_by_user_id(self.end_user_id, 1) long_memories = self.long_repo.get_by_user_id(self.end_user_id, 1)
long_result = [] long_result = []

View File

@@ -10,6 +10,9 @@ from collections import Counter
from datetime import datetime from datetime import datetime
from typing import Any, Dict, List, Optional, Tuple from typing import Any, Dict, List, Optional, Tuple
from pydantic import BaseModel, Field
from sqlalchemy.orm import Session
from app.core.logging_config import get_logger from app.core.logging_config import get_logger
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.db import get_db_context from app.db import get_db_context
@@ -23,8 +26,6 @@ from app.services.memory_base_service import MemoryBaseService, MemoryTransServi
from app.services.memory_config_service import MemoryConfigService from app.services.memory_config_service import MemoryConfigService
from app.services.memory_perceptual_service import MemoryPerceptualService from app.services.memory_perceptual_service import MemoryPerceptualService
from app.services.memory_short_service import ShortService from app.services.memory_short_service import ShortService
from pydantic import BaseModel, Field
from sqlalchemy.orm import Session
logger = get_logger(__name__) logger = get_logger(__name__)
@@ -1035,10 +1036,11 @@ async def analytics_memory_insight_report(end_user_id: Optional[str] = None, lan
"growth_trajectory": str # 成长轨迹 "growth_trajectory": str # 成长轨迹
} }
""" """
from app.core.memory.utils.prompt.prompt_utils import render_memory_insight_prompt
from app.core.language_utils import validate_language
import re import re
from app.core.language_utils import validate_language
from app.core.memory.utils.prompt.prompt_utils import render_memory_insight_prompt
# 验证语言参数 # 验证语言参数
language = validate_language(language) language = validate_language(language)
@@ -1161,12 +1163,13 @@ async def analytics_user_summary(end_user_id: Optional[str] = None, language: st
"one_sentence": str "one_sentence": str
} }
""" """
from app.core.memory.utils.prompt.prompt_utils import render_user_summary_prompt
from app.core.language_utils import validate_language
from app.repositories.end_user_repository import EndUserRepository
from app.db import get_db
import re import re
from app.core.language_utils import validate_language
from app.core.memory.utils.prompt.prompt_utils import render_user_summary_prompt
from app.db import get_db
from app.repositories.end_user_repository import EndUserRepository
# 验证语言参数 # 验证语言参数
language = validate_language(language) language = validate_language(language)
@@ -1457,7 +1460,7 @@ async def analytics_memory_types(
short_term_count = 0 short_term_count = 0
if end_user_id: if end_user_id:
try: try:
short_term_service = ShortService(end_user_id) short_term_service = ShortService(end_user_id, db)
short_term_data = short_term_service.get_short_databasets() short_term_data = short_term_service.get_short_databasets()
# 统计 short_term 数组的长度 # 统计 short_term 数组的长度
if short_term_data: if short_term_data:
@@ -1471,8 +1474,10 @@ async def analytics_memory_types(
forgetting_threshold = 0.3 # 默认值 forgetting_threshold = 0.3 # 默认值
if end_user_id: if end_user_id:
try: try:
from app.core.memory.storage_services.forgetting_engine.config_utils import (
load_actr_config_from_db,
)
from app.services.memory_agent_service import get_end_user_connected_config from app.services.memory_agent_service import get_end_user_connected_config
from app.core.memory.storage_services.forgetting_engine.config_utils import load_actr_config_from_db
# 获取用户关联的 config_id # 获取用户关联的 config_id
connected_config = get_end_user_connected_config(end_user_id, db) connected_config = get_end_user_connected_config(end_user_id, db)