feat(memory): support perception-aware memory writing in workflow and Neo4j nodes

This commit is contained in:
Eternity
2026-03-23 16:33:25 +08:00
parent 31085ed678
commit 2ff81ba101
22 changed files with 820 additions and 519 deletions

View File

@@ -37,7 +37,7 @@ def _validate_config_id(config_id, db: Session = None):
"""Validate configuration ID format (supports both UUID and integer)."""
if isinstance(config_id, uuid.UUID):
return config_id
if config_id is None:
raise InvalidConfigError(
"Configuration ID cannot be None",
@@ -60,18 +60,18 @@ def _validate_config_id(config_id, db: Session = None):
if result:
logger.info(f"Found config_id {result.config_id} for user_id {config_id}")
return result.config_id
return config_id
if isinstance(config_id, str):
config_id_stripped = config_id.strip()
# Try parsing as UUID first
try:
return uuid.UUID(config_id_stripped)
except ValueError:
pass
# Fall back to integer parsing
try:
parsed_id = int(config_id_stripped)
@@ -81,17 +81,17 @@ def _validate_config_id(config_id, db: Session = None):
field_name="config_id",
invalid_value=config_id,
)
# 如果提供了数据库会话,尝试通过 user_id 查询 config_id
if db is not None:
# 查询 user_id 匹配的记录
stmt = select(MemoryConfigModel).where(MemoryConfigModel.user_id == str(parsed_id))
result = db.execute(stmt).scalars().first()
if result:
logger.info(f"Found config_id {result.config_id} for user_id {parsed_id}")
return result.config_id
return parsed_id
except ValueError:
raise InvalidConfigError(
@@ -154,10 +154,10 @@ class MemoryConfigService:
self.db = db
def load_memory_config(
self,
config_id: Optional[UUID] = None,
workspace_id: Optional[UUID] = None,
service_name: str = "MemoryConfigService",
self,
config_id: Optional[UUID] = None,
workspace_id: Optional[UUID] = None,
service_name: str = "MemoryConfigService",
) -> MemoryConfig:
"""
Load memory configuration from database with optional fallback.
@@ -194,14 +194,14 @@ class MemoryConfigService:
try:
# Use get_config_with_fallback if workspace_id is provided
memory_config = None
validated_config_id = None
if workspace_id:
validated_config_id = None
if config_id:
try:
validated_config_id = _validate_config_id(config_id, self.db)
except Exception:
validated_config_id = None
memory_config = self.get_config_with_fallback(
memory_config_id=validated_config_id,
workspace_id=workspace_id
@@ -210,7 +210,7 @@ class MemoryConfigService:
validated_config_id = _validate_config_id(config_id, self.db)
from app.models.memory_config_model import MemoryConfig as MemoryConfigModel
memory_config = self.db.get(MemoryConfigModel, validated_config_id)
if not memory_config:
elapsed_ms = (time.time() - start_time) * 1000
config_logger.error(
@@ -233,7 +233,7 @@ class MemoryConfigService:
result = MemoryConfigRepository.get_config_with_workspace(self.db, memory_config.config_id)
db_query_time = time.time() - db_query_start
logger.info(f"[PERF] Config+Workspace query: {db_query_time:.4f}s")
if not result:
raise ConfigurationError(
f"Workspace not found for config {memory_config.config_id}"
@@ -243,10 +243,10 @@ class MemoryConfigService:
# Helper function to validate model with workspace fallback
def _validate_model_with_fallback(
model_id: str,
model_type: str,
workspace_default: str,
required: bool = False
model_id: str,
model_type: str,
workspace_default: str,
required: bool = False
) -> tuple:
"""Validate model ID, falling back to workspace default if invalid.
@@ -275,7 +275,7 @@ class MemoryConfigService:
logger.warning(
f"{model_type} model validation failed, trying workspace default: {e}"
)
# Fallback to workspace default
if workspace_default:
try:
@@ -297,7 +297,7 @@ class MemoryConfigService:
logger.error(f"Workspace default {model_type} model also invalid: {e}")
if required:
raise
if required:
raise InvalidConfigError(
f"{model_type.title()} model is required but not configured",
@@ -306,7 +306,7 @@ class MemoryConfigService:
config_id=validated_config_id,
workspace_id=workspace.id
)
return None, None
# Step 2: Validate embedding model with workspace fallback
@@ -343,6 +343,35 @@ class MemoryConfigService:
if memory_config.rerank_id or workspace.rerank:
logger.info(f"[PERF] Rerank validation: {rerank_time:.4f}s")
vision_uuid, vision_name = validate_and_resolve_model_id(
memory_config.vision_id,
"llm",
self.db,
workspace.tenant_id,
required=False,
config_id=validated_config_id,
workspace_id=workspace.id,
)
audio_uuid, audio_name = validate_and_resolve_model_id(
memory_config.audio_id,
"llm",
self.db,
workspace.tenant_id,
required=False,
config_id=validated_config_id,
workspace_id=workspace.id,
)
video_uuid, video_name = validate_and_resolve_model_id(
memory_config.video_id,
"llm",
self.db,
workspace.tenant_id,
required=False,
config_id=validated_config_id,
workspace_id=workspace.id,
)
# Create immutable MemoryConfig object
config = MemoryConfig(
config_id=memory_config.config_id,
@@ -356,6 +385,12 @@ class MemoryConfigService:
embedding_model_name=embedding_name,
rerank_model_id=rerank_uuid,
rerank_model_name=rerank_name,
video_model_id=video_uuid,
video_model_name=video_name,
vision_model_id=vision_uuid,
vision_model_name=vision_name,
audio_model_id=audio_uuid,
audio_model_name=audio_name,
storage_type=workspace.storage_type or "neo4j",
chunker_strategy=memory_config.chunker_strategy or "RecursiveChunker",
reflexion_enabled=memory_config.enable_self_reflexion or False,
@@ -364,24 +399,31 @@ class MemoryConfigService:
reflexion_baseline=memory_config.baseline or "Time",
loaded_at=datetime.now(),
# Pipeline config: Deduplication
enable_llm_dedup_blockwise=bool(memory_config.enable_llm_dedup_blockwise) if memory_config.enable_llm_dedup_blockwise is not None else False,
enable_llm_disambiguation=bool(memory_config.enable_llm_disambiguation) if memory_config.enable_llm_disambiguation is not None else False,
enable_llm_dedup_blockwise=bool(
memory_config.enable_llm_dedup_blockwise) if memory_config.enable_llm_dedup_blockwise is not None else False,
enable_llm_disambiguation=bool(
memory_config.enable_llm_disambiguation) if memory_config.enable_llm_disambiguation is not None else False,
deep_retrieval=bool(memory_config.deep_retrieval) if memory_config.deep_retrieval is not None else True,
t_type_strict=float(memory_config.t_type_strict) if memory_config.t_type_strict is not None else 0.8,
t_name_strict=float(memory_config.t_name_strict) if memory_config.t_name_strict is not None else 0.8,
t_overall=float(memory_config.t_overall) if memory_config.t_overall is not None else 0.8,
# Pipeline config: Statement extraction
statement_granularity=int(memory_config.statement_granularity) if memory_config.statement_granularity is not None else 2,
include_dialogue_context=bool(memory_config.include_dialogue_context) if memory_config.include_dialogue_context is not None else False,
max_dialogue_context_chars=int(memory_config.max_context) if memory_config.max_context is not None else 1000,
statement_granularity=int(
memory_config.statement_granularity) if memory_config.statement_granularity is not None else 2,
include_dialogue_context=bool(
memory_config.include_dialogue_context) if memory_config.include_dialogue_context is not None else False,
max_dialogue_context_chars=int(
memory_config.max_context) if memory_config.max_context is not None else 1000,
# Pipeline config: Forgetting engine
lambda_time=float(memory_config.lambda_time) if memory_config.lambda_time is not None else 0.5,
lambda_mem=float(memory_config.lambda_mem) if memory_config.lambda_mem is not None else 0.5,
offset=float(memory_config.offset) if memory_config.offset is not None else 0.0,
# Pipeline config: Pruning
pruning_enabled=bool(memory_config.pruning_enabled) if memory_config.pruning_enabled is not None else False,
pruning_enabled=bool(
memory_config.pruning_enabled) if memory_config.pruning_enabled is not None else False,
pruning_scene=memory_config.pruning_scene or "education",
pruning_threshold=float(memory_config.pruning_threshold) if memory_config.pruning_threshold is not None else 0.5,
pruning_threshold=float(
memory_config.pruning_threshold) if memory_config.pruning_threshold is not None else 0.5,
# Ontology scene association
scene_id=memory_config.scene_id,
ontology_class_infos=_load_ontology_class_infos(self.db, memory_config.scene_id),
@@ -448,9 +490,9 @@ class MemoryConfigService:
if not config:
logger.warning(f"Model ID {model_id} not found")
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="模型ID不存在")
api_config: ModelApiKey = config.api_keys[0]
return {
"model_name": api_config.model_name,
"provider": api_config.provider,
@@ -481,9 +523,9 @@ class MemoryConfigService:
if not config:
logger.warning(f"Embedding model ID {embedding_id} not found")
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="嵌入模型ID不存在")
api_config: ModelApiKey = config.api_keys[0]
return {
"model_name": api_config.model_name,
"provider": api_config.provider,
@@ -571,25 +613,25 @@ class MemoryConfigService:
"""
from app.core.memory.models.ontology_extraction_models import OntologyTypeList
from app.repositories.ontology_class_repository import OntologyClassRepository
if not memory_config.scene_id:
logger.debug("No scene_id configured, skipping ontology type fetch")
return None
try:
ontology_repo = OntologyClassRepository(self.db)
ontology_classes = ontology_repo.get_classes_by_scene(memory_config.scene_id)
if not ontology_classes:
logger.info(f"No ontology classes found for scene_id: {memory_config.scene_id}")
return None
ontology_types = OntologyTypeList.from_db_models(ontology_classes)
logger.info(
f"Loaded {len(ontology_types.types)} ontology types for scene_id: {memory_config.scene_id}"
)
return ontology_types
except Exception as e:
logger.warning(
f"Failed to fetch ontology types for scene_id {memory_config.scene_id}: {e}",
@@ -598,8 +640,8 @@ class MemoryConfigService:
return None
def get_workspace_default_config(
self,
workspace_id: UUID
self,
workspace_id: UUID
) -> Optional["MemoryConfigModel"]:
"""Get workspace default memory config.
@@ -613,19 +655,19 @@ class MemoryConfigService:
Optional[MemoryConfigModel]: Default config or None if no configs exist
"""
config = MemoryConfigRepository.get_workspace_default(self.db, workspace_id)
if not config:
logger.warning(
"No active memory config found for workspace fallback",
extra={"workspace_id": str(workspace_id)}
)
return config
def get_config_with_fallback(
self,
memory_config_id: Optional[UUID],
workspace_id: UUID
self,
memory_config_id: Optional[UUID],
workspace_id: UUID
) -> Optional["MemoryConfigModel"]:
"""Get memory config with fallback to workspace default.
@@ -644,13 +686,13 @@ class MemoryConfigService:
"No memory config ID provided, using workspace default",
extra={"workspace_id": str(workspace_id)}
)
config = MemoryConfigRepository.get_with_fallback(
self.db,
memory_config_id,
workspace_id
)
if not config and memory_config_id:
logger.warning(
"Memory config not found, falling back to workspace default",
@@ -659,13 +701,13 @@ class MemoryConfigService:
"workspace_id": str(workspace_id)
}
)
return config
def delete_config(
self,
config_id: UUID | int,
force: bool = False
self,
config_id: UUID | int,
force: bool = False
) -> dict:
"""Delete memory config with protection against in-use configs.
@@ -687,7 +729,7 @@ class MemoryConfigService:
from app.core.exceptions import ResourceNotFoundException
from app.models.memory_config_model import MemoryConfig as MemoryConfigModel
from app.repositories.end_user_repository import EndUserRepository
# 处理旧格式 int 类型的 config_id
if isinstance(config_id, int):
logger.warning(
@@ -699,11 +741,11 @@ class MemoryConfigService:
"message": "旧格式配置ID不支持删除操作请使用新版配置",
"legacy_int_id": config_id
}
config = self.db.get(MemoryConfigModel, config_id)
if not config:
raise ResourceNotFoundException("MemoryConfig", str(config_id))
# Check if this is the default config - default configs cannot be deleted
if config.is_default:
logger.warning(
@@ -715,11 +757,11 @@ class MemoryConfigService:
"message": "默认配置不允许删除",
"is_default": True
}
# Use repository to count connected end users
end_user_repo = EndUserRepository(self.db)
connected_count = end_user_repo.count_by_memory_config_id(config_id)
if connected_count > 0 and not force:
logger.warning(
"Attempted to delete memory config with connected end users",
@@ -728,18 +770,18 @@ class MemoryConfigService:
"connected_count": connected_count
}
)
return {
"status": "warning",
"message": f"无法删除记忆配置:{connected_count} 个终端用户正在使用此配置",
"connected_count": connected_count,
"force_required": True
}
# Force delete: use repository to clear end user references first
if connected_count > 0 and force:
cleared_count = end_user_repo.clear_memory_config_id(config_id)
logger.warning(
"Force deleting memory config, clearing end user references",
extra={
@@ -747,11 +789,11 @@ class MemoryConfigService:
"cleared_end_users": cleared_count
}
)
try:
self.db.delete(config)
self.db.commit()
logger.info(
"Memory config deleted",
extra={
@@ -760,16 +802,16 @@ class MemoryConfigService:
"affected_users": connected_count
}
)
return {
"status": "success",
"message": "记忆配置删除成功",
"affected_users": connected_count
}
except IntegrityError as e:
self.db.rollback()
# Handle foreign key violation gracefully
error_str = str(e.orig) if e.orig else str(e)
if "ForeignKeyViolation" in error_str or "foreign key constraint" in error_str.lower():
@@ -785,7 +827,7 @@ class MemoryConfigService:
"message": "无法删除记忆配置:仍有终端用户引用此配置,请使用 force=true 强制删除",
"force_required": True
}
# Re-raise other integrity errors
logger.error(
"Delete failed due to integrity error",
@@ -800,9 +842,9 @@ class MemoryConfigService:
# ==================== 记忆配置提取方法 ====================
def extract_memory_config_id(
self,
app_type: str,
config: dict
self,
app_type: str,
config: dict
) -> tuple[Optional[uuid.UUID], bool]:
"""从发布配置中提取 memory_config_id根据应用类型分发
@@ -828,8 +870,8 @@ class MemoryConfigService:
return None, False
def _extract_memory_config_id_from_agent(
self,
config: dict
self,
config: dict
) -> tuple[Optional[uuid.UUID], bool]:
"""从 Agent 应用配置中提取 memory_config_id
@@ -888,8 +930,8 @@ class MemoryConfigService:
return None, False
def _extract_memory_config_id_from_workflow(
self,
config: dict
self,
config: dict
) -> tuple[Optional[uuid.UUID], bool]:
"""从 Workflow 应用配置中提取 memory_config_id
@@ -905,14 +947,14 @@ class MemoryConfigService:
- is_legacy_int: 是否检测到旧格式 int 数据
"""
nodes = config.get("nodes", [])
for node in nodes:
node_type = node.get("type", "")
# 检查是否为记忆节点 (support both formats: memory-read/memory-write and MemoryRead/MemoryWrite)
if node_type.lower() in ["memoryread", "memorywrite", "memory-read", "memory-write"]:
config_id = node.get("config", {}).get("config_id")
if config_id:
try:
# 处理字符串、UUID 和 int旧数据兼容三种情况
@@ -937,6 +979,6 @@ class MemoryConfigService:
f"工作流记忆节点 config_id 格式无效: node_id={node.get('id')}, "
f"node_type={node_type}, error={str(e)}"
)
logger.debug("工作流配置中未找到记忆节点")
return None, False