refactor(api): improve memory service dependency injection and code organization
- Update ShortService and LongService constructors to accept db Session parameter for proper dependency injection instead of using module-level db instance - Reorganize imports in memory_short_term_controller.py following PEP 8 conventions (stdlib, third-party, local imports) - Add comprehensive docstrings with type hints to ShortService and LongService methods for better code documentation - Fix import organization in memory_short_service.py to group related imports and improve readability - Reorganize imports in user_memory_service.py to follow consistent import ordering patterns - Update ShortService instantiation in analytics_memory_types to pass db parameter - Remove module-level db instance initialization in favor of caller-managed database session lifecycle - Add type annotations to method signatures (end_user_id: str, db: Session, return types) - Improve code formatting and spacing consistency across memory service files
This commit is contained in:
@@ -1,22 +1,37 @@
|
||||
from typing import Dict, List
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.logging_config import get_api_logger
|
||||
from app.db import get_db
|
||||
from app.repositories.memory_short_repository import LongTermMemoryRepository
|
||||
from app.repositories.memory_short_repository import ShortTermMemoryRepository
|
||||
|
||||
from app.repositories.memory_short_repository import (
|
||||
LongTermMemoryRepository,
|
||||
ShortTermMemoryRepository,
|
||||
)
|
||||
|
||||
api_logger = get_api_logger()
|
||||
db=next(get_db())
|
||||
|
||||
|
||||
class ShortService:
|
||||
def __init__(self, end_user_id):
|
||||
def __init__(self, end_user_id: str, db: Session) -> None:
|
||||
"""Service for short-term memory queries.
|
||||
|
||||
Args:
|
||||
end_user_id: The end user identifier to query memories for.
|
||||
db: SQLAlchemy database session (caller-managed lifecycle).
|
||||
"""
|
||||
self.short_repo = ShortTermMemoryRepository(db)
|
||||
self.end_user_id = end_user_id
|
||||
|
||||
def get_short_databasets(self):
|
||||
def get_short_databasets(self) -> List[Dict]:
|
||||
"""Retrieve the latest short-term memory entries for the user.
|
||||
|
||||
Returns:
|
||||
List[Dict]: List of memory dicts with retrieval, message, and answer keys.
|
||||
"""
|
||||
short_memories = self.short_repo.get_latest_by_user_id(self.end_user_id, 3)
|
||||
short_result = []
|
||||
for memory in short_memories:
|
||||
deep_expanded = {} # Create a new dictionary for each memory
|
||||
deep_expanded = {}
|
||||
messages = memory.messages
|
||||
aimessages = memory.aimessages
|
||||
retrieved_content = memory.retrieved_content or []
|
||||
@@ -27,23 +42,41 @@ class ShortService:
|
||||
for item in retrieved_content:
|
||||
if isinstance(item, dict):
|
||||
for key, values in item.items():
|
||||
retrieval_source.append({"query": key, "retrieval": values,"source":"上下文记忆"})
|
||||
retrieval_source.append({"query": key, "retrieval": values, "source": "上下文记忆"})
|
||||
|
||||
deep_expanded['retrieval'] = retrieval_source
|
||||
deep_expanded['message'] = messages # 修正拼写错误
|
||||
deep_expanded['message'] = messages
|
||||
deep_expanded['answer'] = aimessages
|
||||
short_result.append(deep_expanded)
|
||||
return short_result
|
||||
def get_short_count(self):
|
||||
|
||||
def get_short_count(self) -> int:
|
||||
"""Count total short-term memory entries for the user.
|
||||
|
||||
Returns:
|
||||
int: Number of short-term memory records.
|
||||
"""
|
||||
short_count = self.short_repo.count_by_user_id(self.end_user_id)
|
||||
return short_count
|
||||
|
||||
|
||||
class LongService:
|
||||
def __init__(self, end_user_id):
|
||||
def __init__(self, end_user_id: str, db: Session) -> None:
|
||||
"""Service for long-term memory queries.
|
||||
|
||||
Args:
|
||||
end_user_id: The end user identifier to query memories for.
|
||||
db: SQLAlchemy database session (caller-managed lifecycle).
|
||||
"""
|
||||
self.long_repo = LongTermMemoryRepository(db)
|
||||
self.end_user_id = end_user_id
|
||||
def get_long_databasets(self):
|
||||
# 获取长期记忆数据
|
||||
|
||||
def get_long_databasets(self) -> List[Dict]:
|
||||
"""Retrieve long-term memory retrieval data for the user.
|
||||
|
||||
Returns:
|
||||
List[Dict]: List of dicts with query and retrieval keys.
|
||||
"""
|
||||
long_memories = self.long_repo.get_by_user_id(self.end_user_id, 1)
|
||||
|
||||
long_result = []
|
||||
|
||||
Reference in New Issue
Block a user