refactor(memory): restructure memory system and improve configuration management

- Remove deprecated main.py entry point from memory module
- Reorganize imports across controllers and services for consistency
- Update emotion controller to pass db session instead of config_id to services
- Enhance memory agent controller with db session parameter for status_type and user_profile endpoints
- Refactor memory agent service to accept db parameter in classify_message_type method
- Improve configuration handling in celery_app by removing automatic database reload
- Update all memory-related services to use centralized config management
- Standardize import ordering and remove unused imports across 50+ files
- Add pilot_run_service for new pilot execution workflow
- Refactor extraction engine, reflection engine, and search services for better modularity
- Update LLM utilities and embedder configuration for improved flexibility
- Enhance type classifier and verification tools with better error handling
- Improve memory evaluation modules (LOCOMO, LongMemEval, MemSciQA) with consistent patterns
This commit is contained in:
Ke Sun
2025-12-23 17:17:04 +08:00
parent 258b88276f
commit 283c64a358
58 changed files with 2171 additions and 1797 deletions

View File

@@ -3,7 +3,6 @@ Memory Configuration Service
Centralized configuration loading and management for memory services.
This service eliminates code duplication between MemoryAgentService and MemoryStorageService.
Database session management is handled internally.
"""
import time
@@ -57,7 +56,7 @@ def _validate_config_id(config_id):
invalid_value=config_id,
)
return parsed_id
except ValueError as e:
except ValueError:
raise InvalidConfigError(
f"Invalid configuration ID format: '{config_id}'",
field_name="config_id",
@@ -77,19 +76,29 @@ class MemoryConfigService:
This class provides a single implementation of configuration loading logic
that can be shared across multiple services, eliminating code duplication.
Database session management is handled internally.
Usage:
config_service = MemoryConfigService(db)
memory_config = config_service.load_memory_config(config_id)
model_config = config_service.get_model_config(model_id)
"""
@staticmethod
def __init__(self, db: Session):
"""Initialize the service with a database session.
Args:
db: SQLAlchemy database session
"""
self.db = db
def load_memory_config(
self,
config_id: int,
service_name: str = "MemoryConfigService",
) -> MemoryConfig:
"""
Load memory configuration from database by config_id.
This method manages its own database session internally.
Args:
config_id: Configuration ID from database
service_name: Name of the calling service (for logging purposes)
@@ -100,27 +109,6 @@ class MemoryConfigService:
Raises:
ConfigurationError: If validation fails
"""
from app.db import get_db
db_gen = get_db()
db = next(db_gen)
try:
return MemoryConfigService._load_memory_config_with_db(
config_id=config_id,
db=db,
service_name=service_name,
)
finally:
db.close()
@staticmethod
def _load_memory_config_with_db(
config_id: int,
db: Session,
service_name: str = "MemoryConfigService",
) -> MemoryConfig:
"""Internal method that loads memory configuration with an existing db session."""
start_time = time.time()
config_logger.info(
@@ -137,7 +125,7 @@ class MemoryConfigService:
try:
validated_config_id = _validate_config_id(config_id)
result = DataConfigRepository.get_config_with_workspace(db, validated_config_id)
result = DataConfigRepository.get_config_with_workspace(self.db, validated_config_id)
if not result:
elapsed_ms = (time.time() - start_time) * 1000
config_logger.error(
@@ -160,7 +148,7 @@ class MemoryConfigService:
embedding_uuid = validate_embedding_model(
validated_config_id,
memory_config.embedding_id,
db,
self.db,
workspace.tenant_id,
workspace.id,
)
@@ -169,7 +157,7 @@ class MemoryConfigService:
llm_uuid, llm_name = validate_and_resolve_model_id(
memory_config.llm_id,
"llm",
db,
self.db,
workspace.tenant_id,
required=True,
config_id=validated_config_id,
@@ -183,7 +171,7 @@ class MemoryConfigService:
rerank_uuid, rerank_name = validate_and_resolve_model_id(
memory_config.rerank_id,
"rerank",
db,
self.db,
workspace.tenant_id,
required=False,
config_id=validated_config_id,
@@ -194,7 +182,7 @@ class MemoryConfigService:
embedding_name, _ = validate_model_exists_and_active(
embedding_uuid,
"embedding",
db,
self.db,
workspace.tenant_id,
config_id=validated_config_id,
workspace_id=workspace.id,
@@ -220,6 +208,25 @@ class MemoryConfigService:
reflexion_range=memory_config.reflexion_range or "retrieval",
reflexion_baseline=memory_config.baseline or "time",
loaded_at=datetime.now(),
# Pipeline config: Deduplication
enable_llm_dedup_blockwise=bool(memory_config.enable_llm_dedup_blockwise) if memory_config.enable_llm_dedup_blockwise is not None else False,
enable_llm_disambiguation=bool(memory_config.enable_llm_disambiguation) if memory_config.enable_llm_disambiguation is not None else False,
deep_retrieval=bool(memory_config.deep_retrieval) if memory_config.deep_retrieval is not None else True,
t_type_strict=float(memory_config.t_type_strict) if memory_config.t_type_strict is not None else 0.8,
t_name_strict=float(memory_config.t_name_strict) if memory_config.t_name_strict is not None else 0.8,
t_overall=float(memory_config.t_overall) if memory_config.t_overall is not None else 0.8,
# Pipeline config: Statement extraction
statement_granularity=int(memory_config.statement_granularity) if memory_config.statement_granularity is not None else 2,
include_dialogue_context=bool(memory_config.include_dialogue_context) if memory_config.include_dialogue_context is not None else False,
max_dialogue_context_chars=int(memory_config.max_context) if memory_config.max_context is not None else 1000,
# Pipeline config: Forgetting engine
lambda_time=float(memory_config.lambda_time) if memory_config.lambda_time is not None else 0.5,
lambda_mem=float(memory_config.lambda_mem) if memory_config.lambda_mem is not None else 0.5,
offset=float(memory_config.offset) if memory_config.offset is not None else 0.0,
# Pipeline config: Pruning
pruning_enabled=bool(memory_config.pruning_enabled) if memory_config.pruning_enabled is not None else False,
pruning_scene=memory_config.pruning_scene or "education",
pruning_threshold=float(memory_config.pruning_threshold) if memory_config.pruning_threshold is not None else 0.5,
)
elapsed_ms = (time.time() - start_time) * 1000
@@ -262,3 +269,131 @@ class MemoryConfigService:
raise
else:
raise ConfigurationError(f"Failed to load configuration {config_id}: {e}")
def get_model_config(self, model_id: str) -> dict:
"""Get LLM model configuration by ID.
Args:
model_id: Model ID to look up
Returns:
Dict with model configuration including api_key, base_url, etc.
"""
from app.core.config import settings
from app.models.models_model import ModelApiKey
from app.services.model_service import ModelConfigService as ModelSvc
from fastapi import status
from fastapi.exceptions import HTTPException
config = ModelSvc.get_model_by_id(db=self.db, model_id=model_id)
if not config:
logger.warning(f"Model ID {model_id} not found")
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="模型ID不存在")
api_config: ModelApiKey = config.api_keys[0]
return {
"model_name": api_config.model_name,
"provider": api_config.provider,
"api_key": api_config.api_key,
"base_url": api_config.api_base,
"model_config_id": api_config.model_config_id,
"type": config.type,
"timeout": settings.LLM_TIMEOUT,
"max_retries": settings.LLM_MAX_RETRIES,
}
def get_embedder_config(self, embedding_id: str) -> dict:
"""Get embedding model configuration by ID.
Args:
embedding_id: Embedding model ID to look up
Returns:
Dict with embedder configuration including api_key, base_url, etc.
"""
from app.models.models_model import ModelApiKey
from app.services.model_service import ModelConfigService as ModelSvc
from fastapi import status
from fastapi.exceptions import HTTPException
config = ModelSvc.get_model_by_id(db=self.db, model_id=embedding_id)
if not config:
logger.warning(f"Embedding model ID {embedding_id} not found")
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="嵌入模型ID不存在")
api_config: ModelApiKey = config.api_keys[0]
return {
"model_name": api_config.model_name,
"provider": api_config.provider,
"api_key": api_config.api_key,
"base_url": api_config.api_base,
"model_config_id": api_config.model_config_id,
"type": config.type,
"timeout": 120.0,
"max_retries": 5,
}
@staticmethod
def get_pipeline_config(memory_config: MemoryConfig):
"""Build ExtractionPipelineConfig from MemoryConfig.
Args:
memory_config: MemoryConfig object containing all pipeline settings.
Returns:
ExtractionPipelineConfig with deduplication, statement extraction,
and forgetting engine settings.
"""
from app.core.memory.models.variate_config import (
DedupConfig,
ExtractionPipelineConfig,
ForgettingEngineConfig,
StatementExtractionConfig,
)
dedup_config = DedupConfig(
enable_llm_dedup_blockwise=memory_config.enable_llm_dedup_blockwise,
enable_llm_disambiguation=memory_config.enable_llm_disambiguation,
fuzzy_name_threshold_strict=memory_config.t_name_strict,
fuzzy_type_threshold_strict=memory_config.t_type_strict,
fuzzy_overall_threshold=memory_config.t_overall,
)
stmt_config = StatementExtractionConfig(
statement_granularity=memory_config.statement_granularity,
include_dialogue_context=memory_config.include_dialogue_context,
max_dialogue_context_chars=memory_config.max_dialogue_context_chars,
)
forget_config = ForgettingEngineConfig(
offset=memory_config.offset,
lambda_time=memory_config.lambda_time,
lambda_mem=memory_config.lambda_mem,
)
return ExtractionPipelineConfig(
statement_extraction=stmt_config,
deduplication=dedup_config,
forgetting_engine=forget_config,
)
@staticmethod
def get_pruning_config(memory_config: MemoryConfig) -> dict:
"""Retrieve semantic pruning config from MemoryConfig.
Args:
memory_config: MemoryConfig object containing pruning settings.
Returns:
Dict suitable for PruningConfig.model_validate with keys:
- pruning_switch: bool
- pruning_scene: str
- pruning_threshold: float
"""
return {
"pruning_switch": memory_config.pruning_enabled,
"pruning_scene": memory_config.pruning_scene,
"pruning_threshold": memory_config.pruning_threshold,
}