refactor(api): improve memory service dependency injection and code organization
- Update ShortService and LongService constructors to accept db Session parameter for proper dependency injection instead of using module-level db instance - Reorganize imports in memory_short_term_controller.py following PEP 8 conventions (stdlib, third-party, local imports) - Add comprehensive docstrings with type hints to ShortService and LongService methods for better code documentation - Fix import organization in memory_short_service.py to group related imports and improve readability - Reorganize imports in user_memory_service.py to follow consistent import ordering patterns - Update ShortService instantiation in analytics_memory_types to pass db parameter - Remove module-level db instance initialization in favor of caller-managed database session lifecycle - Add type annotations to method signatures (end_user_id: str, db: Session, return types) - Improve code formatting and spacing consistency across memory service files
This commit is contained in:
@@ -1,16 +1,18 @@
|
|||||||
from fastapi import APIRouter, Depends, HTTPException, status,Header
|
from typing import Optional
|
||||||
|
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
from fastapi import APIRouter, Depends, Header, HTTPException, status
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from app.core.language_utils import get_language_from_header
|
from app.core.language_utils import get_language_from_header
|
||||||
from app.core.logging_config import get_api_logger
|
from app.core.logging_config import get_api_logger
|
||||||
from app.core.response_utils import success
|
from app.core.response_utils import success
|
||||||
from app.db import get_db
|
from app.db import get_db
|
||||||
from app.dependencies import get_current_user
|
from app.dependencies import get_current_user
|
||||||
from app.models.user_model import User
|
from app.models.user_model import User
|
||||||
|
from app.services.memory_short_service import LongService, ShortService
|
||||||
from app.services.memory_storage_service import search_entity
|
from app.services.memory_storage_service import search_entity
|
||||||
from app.services.memory_short_service import ShortService,LongService
|
|
||||||
from dotenv import load_dotenv
|
|
||||||
from sqlalchemy.orm import Session
|
|
||||||
from typing import Optional
|
|
||||||
load_dotenv()
|
load_dotenv()
|
||||||
api_logger = get_api_logger()
|
api_logger = get_api_logger()
|
||||||
|
|
||||||
@@ -29,11 +31,11 @@ async def short_term_configs(
|
|||||||
language = get_language_from_header(language_type)
|
language = get_language_from_header(language_type)
|
||||||
|
|
||||||
# 获取短期记忆数据
|
# 获取短期记忆数据
|
||||||
short_term=ShortService(end_user_id)
|
short_term=ShortService(end_user_id, db)
|
||||||
short_result=short_term.get_short_databasets()
|
short_result=short_term.get_short_databasets()
|
||||||
short_count=short_term.get_short_count()
|
short_count=short_term.get_short_count()
|
||||||
|
|
||||||
long_term=LongService(end_user_id)
|
long_term=LongService(end_user_id, db)
|
||||||
long_result=long_term.get_long_databasets()
|
long_result=long_term.get_long_databasets()
|
||||||
|
|
||||||
entity_result = await search_entity(end_user_id)
|
entity_result = await search_entity(end_user_id)
|
||||||
|
|||||||
@@ -1,22 +1,37 @@
|
|||||||
|
from typing import Dict, List
|
||||||
|
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from app.core.logging_config import get_api_logger
|
from app.core.logging_config import get_api_logger
|
||||||
from app.db import get_db
|
from app.repositories.memory_short_repository import (
|
||||||
from app.repositories.memory_short_repository import LongTermMemoryRepository
|
LongTermMemoryRepository,
|
||||||
from app.repositories.memory_short_repository import ShortTermMemoryRepository
|
ShortTermMemoryRepository,
|
||||||
|
)
|
||||||
|
|
||||||
api_logger = get_api_logger()
|
api_logger = get_api_logger()
|
||||||
db=next(get_db())
|
|
||||||
|
|
||||||
class ShortService:
|
class ShortService:
|
||||||
def __init__(self, end_user_id):
|
def __init__(self, end_user_id: str, db: Session) -> None:
|
||||||
|
"""Service for short-term memory queries.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
end_user_id: The end user identifier to query memories for.
|
||||||
|
db: SQLAlchemy database session (caller-managed lifecycle).
|
||||||
|
"""
|
||||||
self.short_repo = ShortTermMemoryRepository(db)
|
self.short_repo = ShortTermMemoryRepository(db)
|
||||||
self.end_user_id = end_user_id
|
self.end_user_id = end_user_id
|
||||||
|
|
||||||
def get_short_databasets(self):
|
def get_short_databasets(self) -> List[Dict]:
|
||||||
|
"""Retrieve the latest short-term memory entries for the user.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List[Dict]: List of memory dicts with retrieval, message, and answer keys.
|
||||||
|
"""
|
||||||
short_memories = self.short_repo.get_latest_by_user_id(self.end_user_id, 3)
|
short_memories = self.short_repo.get_latest_by_user_id(self.end_user_id, 3)
|
||||||
short_result = []
|
short_result = []
|
||||||
for memory in short_memories:
|
for memory in short_memories:
|
||||||
deep_expanded = {} # Create a new dictionary for each memory
|
deep_expanded = {}
|
||||||
messages = memory.messages
|
messages = memory.messages
|
||||||
aimessages = memory.aimessages
|
aimessages = memory.aimessages
|
||||||
retrieved_content = memory.retrieved_content or []
|
retrieved_content = memory.retrieved_content or []
|
||||||
@@ -30,20 +45,38 @@ class ShortService:
|
|||||||
retrieval_source.append({"query": key, "retrieval": values, "source": "上下文记忆"})
|
retrieval_source.append({"query": key, "retrieval": values, "source": "上下文记忆"})
|
||||||
|
|
||||||
deep_expanded['retrieval'] = retrieval_source
|
deep_expanded['retrieval'] = retrieval_source
|
||||||
deep_expanded['message'] = messages # 修正拼写错误
|
deep_expanded['message'] = messages
|
||||||
deep_expanded['answer'] = aimessages
|
deep_expanded['answer'] = aimessages
|
||||||
short_result.append(deep_expanded)
|
short_result.append(deep_expanded)
|
||||||
return short_result
|
return short_result
|
||||||
def get_short_count(self):
|
|
||||||
|
def get_short_count(self) -> int:
|
||||||
|
"""Count total short-term memory entries for the user.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
int: Number of short-term memory records.
|
||||||
|
"""
|
||||||
short_count = self.short_repo.count_by_user_id(self.end_user_id)
|
short_count = self.short_repo.count_by_user_id(self.end_user_id)
|
||||||
return short_count
|
return short_count
|
||||||
|
|
||||||
|
|
||||||
class LongService:
|
class LongService:
|
||||||
def __init__(self, end_user_id):
|
def __init__(self, end_user_id: str, db: Session) -> None:
|
||||||
|
"""Service for long-term memory queries.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
end_user_id: The end user identifier to query memories for.
|
||||||
|
db: SQLAlchemy database session (caller-managed lifecycle).
|
||||||
|
"""
|
||||||
self.long_repo = LongTermMemoryRepository(db)
|
self.long_repo = LongTermMemoryRepository(db)
|
||||||
self.end_user_id = end_user_id
|
self.end_user_id = end_user_id
|
||||||
def get_long_databasets(self):
|
|
||||||
# 获取长期记忆数据
|
def get_long_databasets(self) -> List[Dict]:
|
||||||
|
"""Retrieve long-term memory retrieval data for the user.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List[Dict]: List of dicts with query and retrieval keys.
|
||||||
|
"""
|
||||||
long_memories = self.long_repo.get_by_user_id(self.end_user_id, 1)
|
long_memories = self.long_repo.get_by_user_id(self.end_user_id, 1)
|
||||||
|
|
||||||
long_result = []
|
long_result = []
|
||||||
|
|||||||
@@ -10,6 +10,9 @@ from collections import Counter
|
|||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Any, Dict, List, Optional, Tuple
|
from typing import Any, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from app.core.logging_config import get_logger
|
from app.core.logging_config import get_logger
|
||||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||||
from app.db import get_db_context
|
from app.db import get_db_context
|
||||||
@@ -23,8 +26,6 @@ from app.services.memory_base_service import MemoryBaseService, MemoryTransServi
|
|||||||
from app.services.memory_config_service import MemoryConfigService
|
from app.services.memory_config_service import MemoryConfigService
|
||||||
from app.services.memory_perceptual_service import MemoryPerceptualService
|
from app.services.memory_perceptual_service import MemoryPerceptualService
|
||||||
from app.services.memory_short_service import ShortService
|
from app.services.memory_short_service import ShortService
|
||||||
from pydantic import BaseModel, Field
|
|
||||||
from sqlalchemy.orm import Session
|
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
@@ -1035,10 +1036,11 @@ async def analytics_memory_insight_report(end_user_id: Optional[str] = None, lan
|
|||||||
"growth_trajectory": str # 成长轨迹
|
"growth_trajectory": str # 成长轨迹
|
||||||
}
|
}
|
||||||
"""
|
"""
|
||||||
from app.core.memory.utils.prompt.prompt_utils import render_memory_insight_prompt
|
|
||||||
from app.core.language_utils import validate_language
|
|
||||||
import re
|
import re
|
||||||
|
|
||||||
|
from app.core.language_utils import validate_language
|
||||||
|
from app.core.memory.utils.prompt.prompt_utils import render_memory_insight_prompt
|
||||||
|
|
||||||
# 验证语言参数
|
# 验证语言参数
|
||||||
language = validate_language(language)
|
language = validate_language(language)
|
||||||
|
|
||||||
@@ -1161,12 +1163,13 @@ async def analytics_user_summary(end_user_id: Optional[str] = None, language: st
|
|||||||
"one_sentence": str
|
"one_sentence": str
|
||||||
}
|
}
|
||||||
"""
|
"""
|
||||||
from app.core.memory.utils.prompt.prompt_utils import render_user_summary_prompt
|
|
||||||
from app.core.language_utils import validate_language
|
|
||||||
from app.repositories.end_user_repository import EndUserRepository
|
|
||||||
from app.db import get_db
|
|
||||||
import re
|
import re
|
||||||
|
|
||||||
|
from app.core.language_utils import validate_language
|
||||||
|
from app.core.memory.utils.prompt.prompt_utils import render_user_summary_prompt
|
||||||
|
from app.db import get_db
|
||||||
|
from app.repositories.end_user_repository import EndUserRepository
|
||||||
|
|
||||||
# 验证语言参数
|
# 验证语言参数
|
||||||
language = validate_language(language)
|
language = validate_language(language)
|
||||||
|
|
||||||
@@ -1457,7 +1460,7 @@ async def analytics_memory_types(
|
|||||||
short_term_count = 0
|
short_term_count = 0
|
||||||
if end_user_id:
|
if end_user_id:
|
||||||
try:
|
try:
|
||||||
short_term_service = ShortService(end_user_id)
|
short_term_service = ShortService(end_user_id, db)
|
||||||
short_term_data = short_term_service.get_short_databasets()
|
short_term_data = short_term_service.get_short_databasets()
|
||||||
# 统计 short_term 数组的长度
|
# 统计 short_term 数组的长度
|
||||||
if short_term_data:
|
if short_term_data:
|
||||||
@@ -1471,8 +1474,10 @@ async def analytics_memory_types(
|
|||||||
forgetting_threshold = 0.3 # 默认值
|
forgetting_threshold = 0.3 # 默认值
|
||||||
if end_user_id:
|
if end_user_id:
|
||||||
try:
|
try:
|
||||||
|
from app.core.memory.storage_services.forgetting_engine.config_utils import (
|
||||||
|
load_actr_config_from_db,
|
||||||
|
)
|
||||||
from app.services.memory_agent_service import get_end_user_connected_config
|
from app.services.memory_agent_service import get_end_user_connected_config
|
||||||
from app.core.memory.storage_services.forgetting_engine.config_utils import load_actr_config_from_db
|
|
||||||
|
|
||||||
# 获取用户关联的 config_id
|
# 获取用户关联的 config_id
|
||||||
connected_config = get_end_user_connected_config(end_user_id, db)
|
connected_config = get_end_user_connected_config(end_user_id, db)
|
||||||
|
|||||||
Reference in New Issue
Block a user