Files
MemoryBear/api/app/services/memory_config_service.py

717 lines
28 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
Memory Configuration Service
Centralized configuration loading and management for memory services.
This service eliminates code duplication between MemoryAgentService and MemoryStorageService.
"""
import time
import uuid
from datetime import datetime
from typing import TYPE_CHECKING, Optional
from uuid import UUID
from sqlalchemy import select
from sqlalchemy.orm import Session
from app.core.logging_config import get_config_logger, get_logger
from app.core.validators.memory_config_validators import (
validate_and_resolve_model_id,
validate_embedding_model,
)
from app.models.memory_config_model import MemoryConfig as MemoryConfigModel
from app.repositories.memory_config_repository import MemoryConfigRepository
from app.schemas.memory_config_schema import (
ConfigurationError,
InvalidConfigError,
MemoryConfig,
)
if TYPE_CHECKING:
from app.models.memory_config_model import MemoryConfig as MemoryConfigModel
logger = get_logger(__name__)
config_logger = get_config_logger()
def _validate_config_id(config_id, db: Session = None):
"""Validate configuration ID format (supports both UUID and integer)."""
if isinstance(config_id, uuid.UUID):
return config_id
if config_id is None:
raise InvalidConfigError(
"Configuration ID cannot be None",
field_name="config_id",
invalid_value=config_id,
)
if isinstance(config_id, int):
if config_id <= 0:
raise InvalidConfigError(
f"Configuration ID must be positive: {config_id}",
field_name="config_id",
invalid_value=config_id,
)
# 如果提供了数据库会话,尝试通过 user_id 查询 config_id
if db is not None:
# 查询 user_id 匹配的记录
stmt = select(MemoryConfigModel).where(MemoryConfigModel.config_id_old == str(config_id))
result = db.execute(stmt).scalars().first()
if result:
logger.info(f"Found config_id {result.config_id} for user_id {config_id}")
return result.config_id
return config_id
if isinstance(config_id, str):
config_id_stripped = config_id.strip()
# Try parsing as UUID first
try:
return uuid.UUID(config_id_stripped)
except ValueError:
pass
# Fall back to integer parsing
try:
parsed_id = int(config_id_stripped)
if parsed_id <= 0:
raise InvalidConfigError(
f"Configuration ID must be positive: {parsed_id}",
field_name="config_id",
invalid_value=config_id,
)
# 如果提供了数据库会话,尝试通过 user_id 查询 config_id
if db is not None:
# 查询 user_id 匹配的记录
stmt = select(MemoryConfigModel).where(MemoryConfigModel.user_id == str(parsed_id))
result = db.execute(stmt).scalars().first()
if result:
logger.info(f"Found config_id {result.config_id} for user_id {parsed_id}")
return result.config_id
return parsed_id
except ValueError:
raise InvalidConfigError(
f"Invalid configuration ID format: '{config_id}' (must be UUID or positive integer)",
field_name="config_id",
invalid_value=config_id,
)
raise InvalidConfigError(
f"Invalid type for configuration ID: expected UUID, int or str, got {type(config_id).__name__}",
field_name="config_id",
invalid_value=config_id,
)
class MemoryConfigService:
"""
Centralized service for memory configuration loading and validation.
This class provides a single implementation of configuration loading logic
that can be shared across multiple services, eliminating code duplication.
Usage:
config_service = MemoryConfigService(db)
memory_config = config_service.load_memory_config(config_id)
model_config = config_service.get_model_config(model_id)
"""
def __init__(self, db: Session):
"""Initialize the service with a database session.
Args:
db: SQLAlchemy database session
"""
self.db = db
def load_memory_config(
self,
config_id: Optional[UUID] = None,
workspace_id: Optional[UUID] = None,
service_name: str = "MemoryConfigService",
) -> MemoryConfig:
"""
Load memory configuration from database with optional fallback.
If config_id is provided, attempts to load that config directly.
If config_id is None or not found and workspace_id is provided,
falls back to the workspace's default configuration.
Args:
config_id: Configuration ID (UUID) from database (optional)
workspace_id: Workspace ID for fallback lookup (optional)
service_name: Name of the calling service (for logging purposes)
Returns:
MemoryConfig: Immutable configuration object
Raises:
ConfigurationError: If no valid configuration can be found
"""
start_time = time.time()
config_logger.info(
"Starting memory configuration loading",
extra={
"operation": "load_memory_config",
"service": service_name,
"config_id": str(config_id) if config_id else None,
"workspace_id": str(workspace_id) if workspace_id else None,
},
)
logger.info(f"Loading memory configuration from database: config_id={config_id}, workspace_id={workspace_id}")
try:
# Use get_config_with_fallback if workspace_id is provided
memory_config = None
if workspace_id:
validated_config_id = None
if config_id:
try:
validated_config_id = _validate_config_id(config_id, self.db)
except Exception:
validated_config_id = None
memory_config = self.get_config_with_fallback(
memory_config_id=validated_config_id,
workspace_id=workspace_id
)
elif config_id:
validated_config_id = _validate_config_id(config_id, self.db)
from app.models.memory_config_model import MemoryConfig as MemoryConfigModel
memory_config = self.db.get(MemoryConfigModel, validated_config_id)
if not memory_config:
elapsed_ms = (time.time() - start_time) * 1000
config_logger.error(
"Configuration not found in database",
extra={
"operation": "load_memory_config",
"config_id": str(config_id) if config_id else None,
"workspace_id": str(workspace_id) if workspace_id else None,
"load_result": "not_found",
"elapsed_ms": elapsed_ms,
"service": service_name,
},
)
raise ConfigurationError(
f"Configuration not found: config_id={config_id}, workspace_id={workspace_id}"
)
# Get workspace for the config
db_query_start = time.time()
result = MemoryConfigRepository.get_config_with_workspace(self.db, memory_config.config_id)
db_query_time = time.time() - db_query_start
logger.info(f"[PERF] Config+Workspace query: {db_query_time:.4f}s")
if not result:
raise ConfigurationError(
f"Workspace not found for config {memory_config.config_id}"
)
memory_config, workspace = result
# Step 2: Validate embedding model (returns both UUID and name)
embed_start = time.time()
embedding_uuid, embedding_name = validate_embedding_model(
validated_config_id,
memory_config.embedding_id,
self.db,
workspace.tenant_id,
workspace.id,
)
embed_time = time.time() - embed_start
logger.info(f"[PERF] Embedding validation: {embed_time:.4f}s")
# Step 3: Resolve LLM model
llm_start = time.time()
llm_uuid, llm_name = validate_and_resolve_model_id(
memory_config.llm_id,
"llm",
self.db,
workspace.tenant_id,
required=True,
config_id=validated_config_id,
workspace_id=workspace.id,
)
llm_time = time.time() - llm_start
logger.info(f"[PERF] LLM validation: {llm_time:.4f}s")
# Step 4: Resolve optional rerank model
rerank_start = time.time()
rerank_uuid = None
rerank_name = None
if memory_config.rerank_id:
rerank_uuid, rerank_name = validate_and_resolve_model_id(
memory_config.rerank_id,
"rerank",
self.db,
workspace.tenant_id,
required=False,
config_id=validated_config_id,
workspace_id=workspace.id,
)
rerank_time = time.time() - rerank_start
if memory_config.rerank_id:
logger.info(f"[PERF] Rerank validation: {rerank_time:.4f}s")
# Note: embedding_name is now returned from validate_embedding_model above
# No need for redundant query!
# Create immutable MemoryConfig object
config = MemoryConfig(
config_id=memory_config.config_id,
config_name=memory_config.config_name,
workspace_id=workspace.id,
workspace_name=workspace.name,
tenant_id=workspace.tenant_id,
llm_model_id=llm_uuid,
llm_model_name=llm_name,
embedding_model_id=embedding_uuid,
embedding_model_name=embedding_name,
rerank_model_id=rerank_uuid,
rerank_model_name=rerank_name,
storage_type=workspace.storage_type or "neo4j",
chunker_strategy=memory_config.chunker_strategy or "RecursiveChunker",
reflexion_enabled=memory_config.enable_self_reflexion or False,
reflexion_iteration_period=int(memory_config.iteration_period or "3"),
reflexion_range=memory_config.reflexion_range or "partial",
reflexion_baseline=memory_config.baseline or "Time",
loaded_at=datetime.now(),
# Pipeline config: Deduplication
enable_llm_dedup_blockwise=bool(memory_config.enable_llm_dedup_blockwise) if memory_config.enable_llm_dedup_blockwise is not None else False,
enable_llm_disambiguation=bool(memory_config.enable_llm_disambiguation) if memory_config.enable_llm_disambiguation is not None else False,
deep_retrieval=bool(memory_config.deep_retrieval) if memory_config.deep_retrieval is not None else True,
t_type_strict=float(memory_config.t_type_strict) if memory_config.t_type_strict is not None else 0.8,
t_name_strict=float(memory_config.t_name_strict) if memory_config.t_name_strict is not None else 0.8,
t_overall=float(memory_config.t_overall) if memory_config.t_overall is not None else 0.8,
# Pipeline config: Statement extraction
statement_granularity=int(memory_config.statement_granularity) if memory_config.statement_granularity is not None else 2,
include_dialogue_context=bool(memory_config.include_dialogue_context) if memory_config.include_dialogue_context is not None else False,
max_dialogue_context_chars=int(memory_config.max_context) if memory_config.max_context is not None else 1000,
# Pipeline config: Forgetting engine
lambda_time=float(memory_config.lambda_time) if memory_config.lambda_time is not None else 0.5,
lambda_mem=float(memory_config.lambda_mem) if memory_config.lambda_mem is not None else 0.5,
offset=float(memory_config.offset) if memory_config.offset is not None else 0.0,
# Pipeline config: Pruning
pruning_enabled=bool(memory_config.pruning_enabled) if memory_config.pruning_enabled is not None else False,
pruning_scene=memory_config.pruning_scene or "education",
pruning_threshold=float(memory_config.pruning_threshold) if memory_config.pruning_threshold is not None else 0.5,
# Ontology scene association
scene_id=memory_config.scene_id,
)
elapsed_ms = (time.time() - start_time) * 1000
config_logger.info(
"Memory configuration loaded successfully",
extra={
"operation": "load_memory_config",
"service": service_name,
"config_id": validated_config_id,
"config_name": config.config_name,
"workspace_id": str(config.workspace_id),
"load_result": "success",
"elapsed_ms": elapsed_ms,
},
)
logger.info(f"Memory configuration loaded successfully: {config.config_name}")
return config
except Exception as e:
elapsed_ms = (time.time() - start_time) * 1000
config_logger.error(
"Failed to load memory configuration",
extra={
"operation": "load_memory_config",
"service": service_name,
"config_id": config_id,
"load_result": "error",
"error_type": type(e).__name__,
"error_message": str(e),
"elapsed_ms": elapsed_ms,
},
exc_info=True,
)
logger.error(f"Failed to load memory configuration {config_id}: {e}")
if isinstance(e, (ConfigurationError, ValueError)):
raise
else:
raise ConfigurationError(f"Failed to load configuration {config_id}: {e}")
def get_model_config(self, model_id: str) -> dict:
"""Get LLM model configuration by ID.
Args:
model_id: Model ID to look up
Returns:
Dict with model configuration including api_key, base_url, etc.
"""
from fastapi import status
from fastapi.exceptions import HTTPException
from app.core.config import settings
from app.models.models_model import ModelApiKey
from app.services.model_service import ModelConfigService as ModelSvc
config = ModelSvc.get_model_by_id(db=self.db, model_id=model_id)
if not config:
logger.warning(f"Model ID {model_id} not found")
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="模型ID不存在")
api_config: ModelApiKey = config.api_keys[0]
return {
"model_name": api_config.model_name,
"provider": api_config.provider,
"api_key": api_config.api_key,
"base_url": api_config.api_base,
"model_config_id": str(config.id),
"type": config.type,
"timeout": settings.LLM_TIMEOUT,
"max_retries": settings.LLM_MAX_RETRIES,
}
def get_embedder_config(self, embedding_id: str) -> dict:
"""Get embedding model configuration by ID.
Args:
embedding_id: Embedding model ID to look up
Returns:
Dict with embedder configuration including api_key, base_url, etc.
"""
from fastapi import status
from fastapi.exceptions import HTTPException
from app.models.models_model import ModelApiKey
from app.services.model_service import ModelConfigService as ModelSvc
config = ModelSvc.get_model_by_id(db=self.db, model_id=embedding_id)
if not config:
logger.warning(f"Embedding model ID {embedding_id} not found")
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="嵌入模型ID不存在")
api_config: ModelApiKey = config.api_keys[0]
return {
"model_name": api_config.model_name,
"provider": api_config.provider,
"api_key": api_config.api_key,
"base_url": api_config.api_base,
"model_config_id": str(config.id),
"type": config.type,
"timeout": 120.0,
"max_retries": 5,
}
@staticmethod
def get_pipeline_config(memory_config: MemoryConfig):
"""Build ExtractionPipelineConfig from MemoryConfig.
Args:
memory_config: MemoryConfig object containing all pipeline settings.
Returns:
ExtractionPipelineConfig with deduplication, statement extraction,
and forgetting engine settings.
"""
from app.core.memory.models.variate_config import (
DedupConfig,
ExtractionPipelineConfig,
ForgettingEngineConfig,
StatementExtractionConfig,
)
dedup_config = DedupConfig(
enable_llm_dedup_blockwise=memory_config.enable_llm_dedup_blockwise,
enable_llm_disambiguation=memory_config.enable_llm_disambiguation,
fuzzy_name_threshold_strict=memory_config.t_name_strict,
fuzzy_type_threshold_strict=memory_config.t_type_strict,
fuzzy_overall_threshold=memory_config.t_overall,
)
stmt_config = StatementExtractionConfig(
statement_granularity=memory_config.statement_granularity,
include_dialogue_context=memory_config.include_dialogue_context,
max_dialogue_context_chars=memory_config.max_dialogue_context_chars,
)
forget_config = ForgettingEngineConfig(
offset=memory_config.offset,
lambda_time=memory_config.lambda_time,
lambda_mem=memory_config.lambda_mem,
)
return ExtractionPipelineConfig(
statement_extraction=stmt_config,
deduplication=dedup_config,
forgetting_engine=forget_config,
)
@staticmethod
def get_pruning_config(memory_config: MemoryConfig) -> dict:
"""Retrieve semantic pruning config from MemoryConfig.
Args:
memory_config: MemoryConfig object containing pruning settings.
Returns:
Dict suitable for PruningConfig.model_validate with keys:
- pruning_switch: bool
- pruning_scene: str
- pruning_threshold: float
"""
return {
"pruning_switch": memory_config.pruning_enabled,
"pruning_scene": memory_config.pruning_scene,
"pruning_threshold": memory_config.pruning_threshold,
}
def get_ontology_types(self, memory_config: MemoryConfig):
"""Fetch ontology types for the memory configuration's scene.
Args:
memory_config: MemoryConfig object containing scene_id
Returns:
OntologyTypeList if scene_id is valid and has types, None otherwise
"""
from app.core.memory.models.ontology_extraction_models import OntologyTypeList
from app.repositories.ontology_class_repository import OntologyClassRepository
if not memory_config.scene_id:
logger.debug("No scene_id configured, skipping ontology type fetch")
return None
try:
ontology_repo = OntologyClassRepository(self.db)
ontology_classes = ontology_repo.get_by_scene(memory_config.scene_id)
if not ontology_classes:
logger.info(f"No ontology classes found for scene_id: {memory_config.scene_id}")
return None
ontology_types = OntologyTypeList.from_db_models(ontology_classes)
logger.info(
f"Loaded {len(ontology_types.types)} ontology types for scene_id: {memory_config.scene_id}"
)
return ontology_types
except Exception as e:
logger.warning(
f"Failed to fetch ontology types for scene_id {memory_config.scene_id}: {e}",
exc_info=True
)
return None
def get_workspace_default_config(
self,
workspace_id: UUID
) -> Optional["MemoryConfigModel"]:
"""Get workspace default memory config.
Returns the config marked as default for the workspace. If no explicit
default exists, falls back to the first active config ordered by creation time.
Args:
workspace_id: Workspace ID
Returns:
Optional[MemoryConfigModel]: Default config or None if no configs exist
"""
from sqlalchemy import select
from app.models.memory_config_model import MemoryConfig as MemoryConfigModel
# First, try to find the explicitly marked default config
stmt = (
select(MemoryConfigModel)
.where(
MemoryConfigModel.workspace_id == workspace_id,
MemoryConfigModel.is_default.is_(True),
MemoryConfigModel.state.is_(True),
)
.limit(1)
)
config = self.db.scalars(stmt).first()
if config:
return config
# Fallback: get the oldest active config if no explicit default
stmt = (
select(MemoryConfigModel)
.where(
MemoryConfigModel.workspace_id == workspace_id,
MemoryConfigModel.state.is_(True),
)
.order_by(MemoryConfigModel.created_at.asc())
.limit(1)
)
config = self.db.scalars(stmt).first()
if not config:
logger.warning(
"No active memory config found for workspace fallback",
extra={"workspace_id": str(workspace_id)}
)
return config
def get_config_with_fallback(
self,
memory_config_id: Optional[UUID],
workspace_id: UUID
) -> Optional["MemoryConfigModel"]:
"""Get memory config with fallback to workspace default.
Implements graceful degradation: if the provided config_id is None or
the config doesn't exist, falls back to the workspace's default config.
Args:
memory_config_id: Memory config ID (can be None)
workspace_id: Workspace ID for fallback lookup
Returns:
Optional[MemoryConfigModel]: Memory config or None if no fallback available
"""
from app.models.memory_config_model import MemoryConfig as MemoryConfigModel
if not memory_config_id:
logger.debug(
"No memory config ID provided, using workspace default",
extra={"workspace_id": str(workspace_id)}
)
return self.get_workspace_default_config(workspace_id)
config = self.db.get(MemoryConfigModel, memory_config_id)
if config:
return config
logger.warning(
"Memory config not found, falling back to workspace default",
extra={
"missing_config_id": str(memory_config_id),
"workspace_id": str(workspace_id)
}
)
return self.get_workspace_default_config(workspace_id)
def delete_config(
self,
config_id: UUID | int,
force: bool = False
) -> dict:
"""Delete memory config with protection against in-use configs.
Implements delete protection: prevents accidental deletion of configs
that are actively being used by end users or marked as default.
Args:
config_id: Memory config ID to delete (UUID or legacy int)
force: If True, delete even if end users are connected
Returns:
Dict with status, message, and affected_users count
Raises:
ResourceNotFoundException: If config doesn't exist
"""
from app.core.exceptions import ResourceNotFoundException
from app.models.memory_config_model import MemoryConfig as MemoryConfigModel
# 处理旧格式 int 类型的 config_id
if isinstance(config_id, int):
logger.warning(
"Attempted to delete legacy int config_id",
extra={"config_id": config_id}
)
return {
"status": "error",
"message": "旧格式配置ID不支持删除操作请使用新版配置",
"legacy_int_id": config_id
}
config = self.db.get(MemoryConfigModel, config_id)
if not config:
raise ResourceNotFoundException("MemoryConfig", str(config_id))
# Check if this is the default config - default configs cannot be deleted
if config.is_default:
logger.warning(
"Attempted to delete default memory config",
extra={"config_id": str(config_id)}
)
return {
"status": "error",
"message": "默认配置不允许删除",
"is_default": True
}
# TODO: add back delete warning
# # Count connected end users
# end_user_repo = EndUserRepository(self.db)
# connected_count = end_user_repo.count_by_memory_config_id(config_id)
# if connected_count > 0 and not force:
# logger.warning(
# "Attempted to delete memory config with connected end users",
# extra={
# "config_id": str(config_id),
# "connected_count": connected_count
# }
# )
# return {
# "status": "warning",
# "message": f"Cannot delete memory config: {connected_count} end users are using it",
# "connected_count": connected_count,
# "force_required": True
# }
# # Force delete: clear end user references first
# if connected_count > 0 and force:
# cleared_count = end_user_repo.clear_memory_config_id(config_id)
# logger.warning(
# "Force deleting memory config",
# extra={
# "config_id": str(config_id),
# "cleared_end_users": cleared_count
# }
# )
connected_count = 0
self.db.delete(config)
self.db.commit()
logger.info(
"Memory config deleted",
extra={
"config_id": str(config_id),
"force": force,
"affected_users": connected_count
}
)
return {
"status": "success",
"message": "Memory config deleted successfully",
"affected_users": connected_count
}