refactor(memory): restructure memory system and improve configuration management
- Remove deprecated main.py entry point from memory module - Reorganize imports across controllers and services for consistency - Update emotion controller to pass db session instead of config_id to services - Enhance memory agent controller with db session parameter for status_type and user_profile endpoints - Refactor memory agent service to accept db parameter in classify_message_type method - Improve configuration handling in celery_app by removing automatic database reload - Update all memory-related services to use centralized config management - Standardize import ordering and remove unused imports across 50+ files - Add pilot_run_service for new pilot execution workflow - Refactor extraction engine, reflection engine, and search services for better modularity - Update LLM utilities and embedder configuration for improved flexibility - Enhance type classifier and verification tools with better error handling - Improve memory evaluation modules (LOCOMO, LongMemEval, MemSciQA) with consistent patterns
This commit is contained in:
@@ -3,7 +3,6 @@ Memory Configuration Service
|
||||
|
||||
Centralized configuration loading and management for memory services.
|
||||
This service eliminates code duplication between MemoryAgentService and MemoryStorageService.
|
||||
Database session management is handled internally.
|
||||
"""
|
||||
|
||||
import time
|
||||
@@ -57,7 +56,7 @@ def _validate_config_id(config_id):
|
||||
invalid_value=config_id,
|
||||
)
|
||||
return parsed_id
|
||||
except ValueError as e:
|
||||
except ValueError:
|
||||
raise InvalidConfigError(
|
||||
f"Invalid configuration ID format: '{config_id}'",
|
||||
field_name="config_id",
|
||||
@@ -77,19 +76,29 @@ class MemoryConfigService:
|
||||
|
||||
This class provides a single implementation of configuration loading logic
|
||||
that can be shared across multiple services, eliminating code duplication.
|
||||
Database session management is handled internally.
|
||||
|
||||
Usage:
|
||||
config_service = MemoryConfigService(db)
|
||||
memory_config = config_service.load_memory_config(config_id)
|
||||
model_config = config_service.get_model_config(model_id)
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def __init__(self, db: Session):
|
||||
"""Initialize the service with a database session.
|
||||
|
||||
Args:
|
||||
db: SQLAlchemy database session
|
||||
"""
|
||||
self.db = db
|
||||
|
||||
def load_memory_config(
|
||||
self,
|
||||
config_id: int,
|
||||
service_name: str = "MemoryConfigService",
|
||||
) -> MemoryConfig:
|
||||
"""
|
||||
Load memory configuration from database by config_id.
|
||||
|
||||
This method manages its own database session internally.
|
||||
|
||||
Args:
|
||||
config_id: Configuration ID from database
|
||||
service_name: Name of the calling service (for logging purposes)
|
||||
@@ -100,27 +109,6 @@ class MemoryConfigService:
|
||||
Raises:
|
||||
ConfigurationError: If validation fails
|
||||
"""
|
||||
from app.db import get_db
|
||||
|
||||
db_gen = get_db()
|
||||
db = next(db_gen)
|
||||
|
||||
try:
|
||||
return MemoryConfigService._load_memory_config_with_db(
|
||||
config_id=config_id,
|
||||
db=db,
|
||||
service_name=service_name,
|
||||
)
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
@staticmethod
|
||||
def _load_memory_config_with_db(
|
||||
config_id: int,
|
||||
db: Session,
|
||||
service_name: str = "MemoryConfigService",
|
||||
) -> MemoryConfig:
|
||||
"""Internal method that loads memory configuration with an existing db session."""
|
||||
start_time = time.time()
|
||||
|
||||
config_logger.info(
|
||||
@@ -137,7 +125,7 @@ class MemoryConfigService:
|
||||
try:
|
||||
validated_config_id = _validate_config_id(config_id)
|
||||
|
||||
result = DataConfigRepository.get_config_with_workspace(db, validated_config_id)
|
||||
result = DataConfigRepository.get_config_with_workspace(self.db, validated_config_id)
|
||||
if not result:
|
||||
elapsed_ms = (time.time() - start_time) * 1000
|
||||
config_logger.error(
|
||||
@@ -160,7 +148,7 @@ class MemoryConfigService:
|
||||
embedding_uuid = validate_embedding_model(
|
||||
validated_config_id,
|
||||
memory_config.embedding_id,
|
||||
db,
|
||||
self.db,
|
||||
workspace.tenant_id,
|
||||
workspace.id,
|
||||
)
|
||||
@@ -169,7 +157,7 @@ class MemoryConfigService:
|
||||
llm_uuid, llm_name = validate_and_resolve_model_id(
|
||||
memory_config.llm_id,
|
||||
"llm",
|
||||
db,
|
||||
self.db,
|
||||
workspace.tenant_id,
|
||||
required=True,
|
||||
config_id=validated_config_id,
|
||||
@@ -183,7 +171,7 @@ class MemoryConfigService:
|
||||
rerank_uuid, rerank_name = validate_and_resolve_model_id(
|
||||
memory_config.rerank_id,
|
||||
"rerank",
|
||||
db,
|
||||
self.db,
|
||||
workspace.tenant_id,
|
||||
required=False,
|
||||
config_id=validated_config_id,
|
||||
@@ -194,7 +182,7 @@ class MemoryConfigService:
|
||||
embedding_name, _ = validate_model_exists_and_active(
|
||||
embedding_uuid,
|
||||
"embedding",
|
||||
db,
|
||||
self.db,
|
||||
workspace.tenant_id,
|
||||
config_id=validated_config_id,
|
||||
workspace_id=workspace.id,
|
||||
@@ -220,6 +208,25 @@ class MemoryConfigService:
|
||||
reflexion_range=memory_config.reflexion_range or "retrieval",
|
||||
reflexion_baseline=memory_config.baseline or "time",
|
||||
loaded_at=datetime.now(),
|
||||
# Pipeline config: Deduplication
|
||||
enable_llm_dedup_blockwise=bool(memory_config.enable_llm_dedup_blockwise) if memory_config.enable_llm_dedup_blockwise is not None else False,
|
||||
enable_llm_disambiguation=bool(memory_config.enable_llm_disambiguation) if memory_config.enable_llm_disambiguation is not None else False,
|
||||
deep_retrieval=bool(memory_config.deep_retrieval) if memory_config.deep_retrieval is not None else True,
|
||||
t_type_strict=float(memory_config.t_type_strict) if memory_config.t_type_strict is not None else 0.8,
|
||||
t_name_strict=float(memory_config.t_name_strict) if memory_config.t_name_strict is not None else 0.8,
|
||||
t_overall=float(memory_config.t_overall) if memory_config.t_overall is not None else 0.8,
|
||||
# Pipeline config: Statement extraction
|
||||
statement_granularity=int(memory_config.statement_granularity) if memory_config.statement_granularity is not None else 2,
|
||||
include_dialogue_context=bool(memory_config.include_dialogue_context) if memory_config.include_dialogue_context is not None else False,
|
||||
max_dialogue_context_chars=int(memory_config.max_context) if memory_config.max_context is not None else 1000,
|
||||
# Pipeline config: Forgetting engine
|
||||
lambda_time=float(memory_config.lambda_time) if memory_config.lambda_time is not None else 0.5,
|
||||
lambda_mem=float(memory_config.lambda_mem) if memory_config.lambda_mem is not None else 0.5,
|
||||
offset=float(memory_config.offset) if memory_config.offset is not None else 0.0,
|
||||
# Pipeline config: Pruning
|
||||
pruning_enabled=bool(memory_config.pruning_enabled) if memory_config.pruning_enabled is not None else False,
|
||||
pruning_scene=memory_config.pruning_scene or "education",
|
||||
pruning_threshold=float(memory_config.pruning_threshold) if memory_config.pruning_threshold is not None else 0.5,
|
||||
)
|
||||
|
||||
elapsed_ms = (time.time() - start_time) * 1000
|
||||
@@ -262,3 +269,131 @@ class MemoryConfigService:
|
||||
raise
|
||||
else:
|
||||
raise ConfigurationError(f"Failed to load configuration {config_id}: {e}")
|
||||
|
||||
def get_model_config(self, model_id: str) -> dict:
|
||||
"""Get LLM model configuration by ID.
|
||||
|
||||
Args:
|
||||
model_id: Model ID to look up
|
||||
|
||||
Returns:
|
||||
Dict with model configuration including api_key, base_url, etc.
|
||||
"""
|
||||
from app.core.config import settings
|
||||
from app.models.models_model import ModelApiKey
|
||||
from app.services.model_service import ModelConfigService as ModelSvc
|
||||
from fastapi import status
|
||||
from fastapi.exceptions import HTTPException
|
||||
|
||||
config = ModelSvc.get_model_by_id(db=self.db, model_id=model_id)
|
||||
if not config:
|
||||
logger.warning(f"Model ID {model_id} not found")
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="模型ID不存在")
|
||||
|
||||
api_config: ModelApiKey = config.api_keys[0]
|
||||
|
||||
return {
|
||||
"model_name": api_config.model_name,
|
||||
"provider": api_config.provider,
|
||||
"api_key": api_config.api_key,
|
||||
"base_url": api_config.api_base,
|
||||
"model_config_id": api_config.model_config_id,
|
||||
"type": config.type,
|
||||
"timeout": settings.LLM_TIMEOUT,
|
||||
"max_retries": settings.LLM_MAX_RETRIES,
|
||||
}
|
||||
|
||||
def get_embedder_config(self, embedding_id: str) -> dict:
|
||||
"""Get embedding model configuration by ID.
|
||||
|
||||
Args:
|
||||
embedding_id: Embedding model ID to look up
|
||||
|
||||
Returns:
|
||||
Dict with embedder configuration including api_key, base_url, etc.
|
||||
"""
|
||||
from app.models.models_model import ModelApiKey
|
||||
from app.services.model_service import ModelConfigService as ModelSvc
|
||||
from fastapi import status
|
||||
from fastapi.exceptions import HTTPException
|
||||
|
||||
config = ModelSvc.get_model_by_id(db=self.db, model_id=embedding_id)
|
||||
if not config:
|
||||
logger.warning(f"Embedding model ID {embedding_id} not found")
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="嵌入模型ID不存在")
|
||||
|
||||
api_config: ModelApiKey = config.api_keys[0]
|
||||
|
||||
return {
|
||||
"model_name": api_config.model_name,
|
||||
"provider": api_config.provider,
|
||||
"api_key": api_config.api_key,
|
||||
"base_url": api_config.api_base,
|
||||
"model_config_id": api_config.model_config_id,
|
||||
"type": config.type,
|
||||
"timeout": 120.0,
|
||||
"max_retries": 5,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def get_pipeline_config(memory_config: MemoryConfig):
|
||||
"""Build ExtractionPipelineConfig from MemoryConfig.
|
||||
|
||||
Args:
|
||||
memory_config: MemoryConfig object containing all pipeline settings.
|
||||
|
||||
Returns:
|
||||
ExtractionPipelineConfig with deduplication, statement extraction,
|
||||
and forgetting engine settings.
|
||||
"""
|
||||
from app.core.memory.models.variate_config import (
|
||||
DedupConfig,
|
||||
ExtractionPipelineConfig,
|
||||
ForgettingEngineConfig,
|
||||
StatementExtractionConfig,
|
||||
)
|
||||
|
||||
dedup_config = DedupConfig(
|
||||
enable_llm_dedup_blockwise=memory_config.enable_llm_dedup_blockwise,
|
||||
enable_llm_disambiguation=memory_config.enable_llm_disambiguation,
|
||||
fuzzy_name_threshold_strict=memory_config.t_name_strict,
|
||||
fuzzy_type_threshold_strict=memory_config.t_type_strict,
|
||||
fuzzy_overall_threshold=memory_config.t_overall,
|
||||
)
|
||||
|
||||
stmt_config = StatementExtractionConfig(
|
||||
statement_granularity=memory_config.statement_granularity,
|
||||
include_dialogue_context=memory_config.include_dialogue_context,
|
||||
max_dialogue_context_chars=memory_config.max_dialogue_context_chars,
|
||||
)
|
||||
|
||||
forget_config = ForgettingEngineConfig(
|
||||
offset=memory_config.offset,
|
||||
lambda_time=memory_config.lambda_time,
|
||||
lambda_mem=memory_config.lambda_mem,
|
||||
)
|
||||
|
||||
return ExtractionPipelineConfig(
|
||||
statement_extraction=stmt_config,
|
||||
deduplication=dedup_config,
|
||||
forgetting_engine=forget_config,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_pruning_config(memory_config: MemoryConfig) -> dict:
|
||||
"""Retrieve semantic pruning config from MemoryConfig.
|
||||
|
||||
Args:
|
||||
memory_config: MemoryConfig object containing pruning settings.
|
||||
|
||||
Returns:
|
||||
Dict suitable for PruningConfig.model_validate with keys:
|
||||
- pruning_switch: bool
|
||||
- pruning_scene: str
|
||||
- pruning_threshold: float
|
||||
"""
|
||||
return {
|
||||
"pruning_switch": memory_config.pruning_enabled,
|
||||
"pruning_scene": memory_config.pruning_scene,
|
||||
"pruning_threshold": memory_config.pruning_threshold,
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user