Files
MemoryBear/api/app/services/memory_config_service.py
2026-01-28 14:30:27 +08:00

441 lines
18 KiB
Python

"""
Memory Configuration Service
Centralized configuration loading and management for memory services.
This service eliminates code duplication between MemoryAgentService and MemoryStorageService.
"""
import time
from datetime import datetime
from app.models.memory_config_model import MemoryConfig as MemoryConfigModel
from sqlalchemy import select
from app.core.logging_config import get_config_logger, get_logger
from app.core.validators.memory_config_validators import (
validate_and_resolve_model_id,
validate_embedding_model,
validate_model_exists_and_active,
)
from app.repositories.memory_config_repository import MemoryConfigRepository
from app.schemas.memory_config_schema import (
ConfigurationError,
InvalidConfigError,
MemoryConfig,
ModelInactiveError,
ModelNotFoundError,
)
from sqlalchemy.orm import Session
from uuid import UUID
# Module-level loggers shared by everything in this file:
# - `logger` carries the plain, human-readable service messages.
# - `config_logger` receives the records emitted with structured `extra`
#   fields (operation, service, config_id, ...) while a configuration loads.
logger = get_logger(__name__)
config_logger = get_config_logger()
import uuid
def _resolve_legacy_config_id(legacy_id: int, db: Session = None):
    """Map a positive integer (legacy) configuration ID to the current one.

    Args:
        legacy_id: Positive integer ID supplied by the caller.
        db: Optional SQLAlchemy session. When None, no lookup is performed.

    Returns:
        The ``config_id`` of the first ``MemoryConfigModel`` row whose
        ``config_id_old`` equals ``str(legacy_id)``; otherwise ``legacy_id``
        unchanged.
    """
    if db is None:
        return legacy_id

    # NOTE(review): the integer branch used to match on `config_id_old` while
    # the numeric-string branch matched on `user_id`, so 5 and "5" could
    # resolve to different rows. Both now use `config_id_old` -- confirm that
    # no caller relied on the `user_id` match.
    stmt = select(MemoryConfigModel).where(
        MemoryConfigModel.config_id_old == str(legacy_id)
    )
    row = db.execute(stmt).scalars().first()
    if row:
        logger.info(
            f"Found config_id {row.config_id} for legacy config_id {legacy_id}"
        )
        return row.config_id
    return legacy_id


def _validate_config_id(config_id, db: Session = None):
    """Validate a configuration ID and normalise it (UUID or integer forms).

    Accepted inputs:
        * ``uuid.UUID`` -- returned as is.
        * positive ``int`` -- treated as a legacy ID and, when ``db`` is
          given, translated to the stored ``config_id`` if a row matches.
        * ``str`` -- stripped, then parsed as a UUID first and as a positive
          integer second (the integer form is resolved like an ``int``).

    Args:
        config_id: The raw identifier to validate.
        db: Optional SQLAlchemy session used to resolve legacy integer IDs.

    Returns:
        A ``uuid.UUID``, the ``config_id`` found in the database, or the
        validated integer when no database match (or no session) exists.

    Raises:
        InvalidConfigError: If the value is None, non-positive, not parseable
            as a UUID or integer, or of an unsupported type.
    """
    if isinstance(config_id, uuid.UUID):
        return config_id

    if config_id is None:
        raise InvalidConfigError(
            "Configuration ID cannot be None",
            field_name="config_id",
            invalid_value=config_id,
        )

    if isinstance(config_id, int):
        if config_id <= 0:
            raise InvalidConfigError(
                f"Configuration ID must be positive: {config_id}",
                field_name="config_id",
                invalid_value=config_id,
            )
        return _resolve_legacy_config_id(config_id, db)

    if isinstance(config_id, str):
        candidate = config_id.strip()

        # Try parsing as UUID first.
        try:
            return uuid.UUID(candidate)
        except ValueError:
            pass

        # Fall back to integer parsing. Only the parse itself is guarded, so
        # a later validation error or a ValueError raised by the database
        # lookup is never re-labelled as a "format" problem.
        try:
            parsed_id = int(candidate)
        except ValueError as exc:
            raise InvalidConfigError(
                f"Invalid configuration ID format: '{config_id}' (must be UUID or positive integer)",
                field_name="config_id",
                invalid_value=config_id,
            ) from exc

        if parsed_id <= 0:
            raise InvalidConfigError(
                f"Configuration ID must be positive: {parsed_id}",
                field_name="config_id",
                invalid_value=config_id,
            )
        return _resolve_legacy_config_id(parsed_id, db)

    raise InvalidConfigError(
        f"Invalid type for configuration ID: expected UUID, int or str, got {type(config_id).__name__}",
        field_name="config_id",
        invalid_value=config_id,
    )
class MemoryConfigService:
    """Centralized service for memory configuration loading and validation.

    Provides a single implementation of the configuration-loading logic so
    that other services can share it instead of duplicating it.

    Usage:
        config_service = MemoryConfigService(db)
        memory_config = config_service.load_memory_config(config_id)
        model_config = config_service.get_model_config(model_id)
    """

    def __init__(self, db: Session):
        """Initialize the service with a database session.

        Args:
            db: SQLAlchemy database session used by every lookup.
        """
        self.db = db

    @staticmethod
    def _cast_or_default(value, cast, default):
        """Return ``cast(value)``, or ``default`` when ``value`` is None."""
        return default if value is None else cast(value)

    def load_memory_config(
        self,
        config_id: UUID,
        service_name: str = "MemoryConfigService",
    ) -> MemoryConfig:
        """Load memory configuration from the database by ``config_id``.

        The ID is first normalised by ``_validate_config_id`` (which also
        accepts integer and string forms), then the configuration row and its
        workspace are fetched and the embedding, LLM and optional rerank
        models are validated before the ``MemoryConfig`` is assembled.

        Args:
            config_id: Configuration ID from the database.
            service_name: Name of the calling service (for logging purposes).

        Returns:
            MemoryConfig: Configuration object built from the stored row, with
            defaults substituted for unset columns.

        Raises:
            ConfigurationError: If the configuration is not found, or wrapping
                any unexpected error raised while loading.
            ValueError: Propagated unchanged when raised during loading.
        """
        start_time = time.time()
        config_logger.info(
            "Starting memory configuration loading",
            extra={
                "operation": "load_memory_config",
                "service": service_name,
                "config_id": str(config_id),
            },
        )
        logger.info(f"Loading memory configuration from database: config_id={config_id}")

        try:
            validated_config_id = _validate_config_id(config_id, self.db)

            # Step 1: Get config and workspace
            db_query_start = time.time()
            result = MemoryConfigRepository.get_config_with_workspace(
                self.db, validated_config_id
            )
            db_query_time = time.time() - db_query_start
            logger.info(f"[PERF] Config+Workspace query: {db_query_time:.4f}s")

            if not result:
                elapsed_ms = (time.time() - start_time) * 1000
                config_logger.error(
                    "Configuration not found in database",
                    extra={
                        "operation": "load_memory_config",
                        "config_id": str(config_id),
                        "load_result": "not_found",
                        "elapsed_ms": elapsed_ms,
                        "service": service_name,
                    },
                )
                raise ConfigurationError(
                    f"Configuration {config_id} not found in database"
                )

            memory_config, workspace = result

            # Step 2: Validate embedding model (returns both UUID and name,
            # so no separate name query is needed later)
            embed_start = time.time()
            embedding_uuid, embedding_name = validate_embedding_model(
                validated_config_id,
                memory_config.embedding_id,
                self.db,
                workspace.tenant_id,
                workspace.id,
            )
            embed_time = time.time() - embed_start
            logger.info(f"[PERF] Embedding validation: {embed_time:.4f}s")

            # Step 3: Resolve LLM model
            llm_start = time.time()
            llm_uuid, llm_name = validate_and_resolve_model_id(
                memory_config.llm_id,
                "llm",
                self.db,
                workspace.tenant_id,
                required=True,
                config_id=validated_config_id,
                workspace_id=workspace.id,
            )
            llm_time = time.time() - llm_start
            logger.info(f"[PERF] LLM validation: {llm_time:.4f}s")

            # Step 4: Resolve optional rerank model
            rerank_uuid = None
            rerank_name = None
            if memory_config.rerank_id:
                rerank_start = time.time()
                rerank_uuid, rerank_name = validate_and_resolve_model_id(
                    memory_config.rerank_id,
                    "rerank",
                    self.db,
                    workspace.tenant_id,
                    required=False,
                    config_id=validated_config_id,
                    workspace_id=workspace.id,
                )
                rerank_time = time.time() - rerank_start
                logger.info(f"[PERF] Rerank validation: {rerank_time:.4f}s")

            # Unset (None) columns fall back to the defaults given here; an
            # explicit 0 / False stored in the row is preserved.
            or_default = self._cast_or_default
            config = MemoryConfig(
                config_id=memory_config.config_id,
                config_name=memory_config.config_name,
                workspace_id=workspace.id,
                workspace_name=workspace.name,
                tenant_id=workspace.tenant_id,
                llm_model_id=llm_uuid,
                llm_model_name=llm_name,
                embedding_model_id=embedding_uuid,
                embedding_model_name=embedding_name,
                rerank_model_id=rerank_uuid,
                rerank_model_name=rerank_name,
                storage_type=workspace.storage_type or "neo4j",
                chunker_strategy=memory_config.chunker_strategy or "RecursiveChunker",
                reflexion_enabled=memory_config.enable_self_reflexion or False,
                reflexion_iteration_period=int(memory_config.iteration_period or "3"),
                reflexion_range=memory_config.reflexion_range or "partial",
                reflexion_baseline=memory_config.baseline or "Time",
                loaded_at=datetime.now(),
                # Pipeline config: Deduplication
                enable_llm_dedup_blockwise=or_default(
                    memory_config.enable_llm_dedup_blockwise, bool, False
                ),
                enable_llm_disambiguation=or_default(
                    memory_config.enable_llm_disambiguation, bool, False
                ),
                deep_retrieval=or_default(memory_config.deep_retrieval, bool, True),
                t_type_strict=or_default(memory_config.t_type_strict, float, 0.8),
                t_name_strict=or_default(memory_config.t_name_strict, float, 0.8),
                t_overall=or_default(memory_config.t_overall, float, 0.8),
                # Pipeline config: Statement extraction
                statement_granularity=or_default(
                    memory_config.statement_granularity, int, 2
                ),
                include_dialogue_context=or_default(
                    memory_config.include_dialogue_context, bool, False
                ),
                max_dialogue_context_chars=or_default(
                    memory_config.max_context, int, 1000
                ),
                # Pipeline config: Forgetting engine
                lambda_time=or_default(memory_config.lambda_time, float, 0.5),
                lambda_mem=or_default(memory_config.lambda_mem, float, 0.5),
                offset=or_default(memory_config.offset, float, 0.0),
                # Pipeline config: Pruning
                pruning_enabled=or_default(memory_config.pruning_enabled, bool, False),
                pruning_scene=memory_config.pruning_scene or "education",
                pruning_threshold=or_default(
                    memory_config.pruning_threshold, float, 0.5
                ),
            )

            elapsed_ms = (time.time() - start_time) * 1000
            config_logger.info(
                "Memory configuration loaded successfully",
                extra={
                    "operation": "load_memory_config",
                    "service": service_name,
                    # Stringified like every other record so the structured
                    # field always has the same type.
                    "config_id": str(validated_config_id),
                    "config_name": config.config_name,
                    "workspace_id": str(config.workspace_id),
                    "load_result": "success",
                    "elapsed_ms": elapsed_ms,
                },
            )
            logger.info(f"Memory configuration loaded successfully: {config.config_name}")
            return config

        except Exception as e:
            elapsed_ms = (time.time() - start_time) * 1000
            config_logger.error(
                "Failed to load memory configuration",
                extra={
                    "operation": "load_memory_config",
                    "service": service_name,
                    "config_id": str(config_id),
                    "load_result": "error",
                    "error_type": type(e).__name__,
                    "error_message": str(e),
                    "elapsed_ms": elapsed_ms,
                },
                exc_info=True,
            )
            logger.error(f"Failed to load memory configuration {config_id}: {e}")

            # Known error types pass through untouched; anything else is
            # wrapped, keeping the original exception as the explicit cause.
            if isinstance(e, (ConfigurationError, ValueError)):
                raise
            raise ConfigurationError(
                f"Failed to load configuration {config_id}: {e}"
            ) from e

    def _load_api_config(
        self,
        model_id: str,
        *,
        log_label: str,
        not_found_detail: str,
        missing_key_detail: str,
    ) -> dict:
        """Fetch a model and return the fields of its first API key entry.

        Args:
            model_id: Model ID to look up.
            log_label: Prefix used in warning log messages.
            not_found_detail: HTTP error detail used when the model is missing.
            missing_key_detail: HTTP error detail used when the model has no
                API key entries.

        Returns:
            Dict with model_name, provider, api_key, base_url,
            model_config_id and type.

        Raises:
            HTTPException: 404 when the model does not exist or has no API
                key entry to read from.
        """
        # Imported locally, as in the public methods this helper serves.
        from app.models.models_model import ModelApiKey
        from app.services.model_service import ModelConfigService as ModelSvc
        from fastapi import status
        from fastapi.exceptions import HTTPException

        config = ModelSvc.get_model_by_id(db=self.db, model_id=model_id)
        if not config:
            logger.warning(f"{log_label} ID {model_id} not found")
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND, detail=not_found_detail
            )

        # Guard the [0] access: a model without API keys would otherwise
        # surface as a bare IndexError.
        if not config.api_keys:
            logger.warning(f"{log_label} ID {model_id} has no API key configured")
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND, detail=missing_key_detail
            )

        api_config: ModelApiKey = config.api_keys[0]
        return {
            "model_name": api_config.model_name,
            "provider": api_config.provider,
            "api_key": api_config.api_key,
            "base_url": api_config.api_base,
            "model_config_id": api_config.model_config_id,
            "type": config.type,
        }

    def get_model_config(self, model_id: str) -> dict:
        """Get LLM model configuration by ID.

        Args:
            model_id: Model ID to look up.

        Returns:
            Dict with model configuration including api_key, base_url, and
            the timeout / max_retries taken from application settings.

        Raises:
            HTTPException: 404 when the model does not exist or has no API
                key configured.
        """
        from app.core.config import settings

        model_config = self._load_api_config(
            model_id,
            log_label="Model",
            not_found_detail="模型ID不存在",
            missing_key_detail="模型未配置API密钥",
        )
        model_config["timeout"] = settings.LLM_TIMEOUT
        model_config["max_retries"] = settings.LLM_MAX_RETRIES
        return model_config

    def get_embedder_config(self, embedding_id: str) -> dict:
        """Get embedding model configuration by ID.

        Args:
            embedding_id: Embedding model ID to look up.

        Returns:
            Dict with embedder configuration including api_key, base_url, and
            a fixed timeout of 120.0 with max_retries of 5.

        Raises:
            HTTPException: 404 when the model does not exist or has no API
                key configured.
        """
        embedder_config = self._load_api_config(
            embedding_id,
            log_label="Embedding model",
            not_found_detail="嵌入模型ID不存在",
            missing_key_detail="嵌入模型未配置API密钥",
        )
        embedder_config["timeout"] = 120.0
        embedder_config["max_retries"] = 5
        return embedder_config

    @staticmethod
    def get_pipeline_config(memory_config: MemoryConfig):
        """Build ExtractionPipelineConfig from MemoryConfig.

        Args:
            memory_config: MemoryConfig object containing all pipeline settings.

        Returns:
            ExtractionPipelineConfig with deduplication, statement extraction,
            and forgetting engine settings.
        """
        from app.core.memory.models.variate_config import (
            DedupConfig,
            ExtractionPipelineConfig,
            ForgettingEngineConfig,
            StatementExtractionConfig,
        )

        return ExtractionPipelineConfig(
            statement_extraction=StatementExtractionConfig(
                statement_granularity=memory_config.statement_granularity,
                include_dialogue_context=memory_config.include_dialogue_context,
                max_dialogue_context_chars=memory_config.max_dialogue_context_chars,
            ),
            deduplication=DedupConfig(
                enable_llm_dedup_blockwise=memory_config.enable_llm_dedup_blockwise,
                enable_llm_disambiguation=memory_config.enable_llm_disambiguation,
                fuzzy_name_threshold_strict=memory_config.t_name_strict,
                fuzzy_type_threshold_strict=memory_config.t_type_strict,
                fuzzy_overall_threshold=memory_config.t_overall,
            ),
            forgetting_engine=ForgettingEngineConfig(
                offset=memory_config.offset,
                lambda_time=memory_config.lambda_time,
                lambda_mem=memory_config.lambda_mem,
            ),
        )

    @staticmethod
    def get_pruning_config(memory_config: MemoryConfig) -> dict:
        """Retrieve semantic pruning config from MemoryConfig.

        Args:
            memory_config: MemoryConfig object containing pruning settings.

        Returns:
            Dict suitable for PruningConfig.model_validate with keys:
                - pruning_switch: bool
                - pruning_scene: str
                - pruning_threshold: float
        """
        return {
            "pruning_switch": memory_config.pruning_enabled,
            "pruning_scene": memory_config.pruning_scene,
            "pruning_threshold": memory_config.pruning_threshold,
        }