feat(memory): support perception-aware memory writing in workflow and Neo4j nodes
This commit is contained in:
@@ -37,7 +37,7 @@ def _validate_config_id(config_id, db: Session = None):
|
||||
"""Validate configuration ID format (supports both UUID and integer)."""
|
||||
if isinstance(config_id, uuid.UUID):
|
||||
return config_id
|
||||
|
||||
|
||||
if config_id is None:
|
||||
raise InvalidConfigError(
|
||||
"Configuration ID cannot be None",
|
||||
@@ -60,18 +60,18 @@ def _validate_config_id(config_id, db: Session = None):
|
||||
if result:
|
||||
logger.info(f"Found config_id {result.config_id} for user_id {config_id}")
|
||||
return result.config_id
|
||||
|
||||
|
||||
return config_id
|
||||
|
||||
if isinstance(config_id, str):
|
||||
config_id_stripped = config_id.strip()
|
||||
|
||||
|
||||
# Try parsing as UUID first
|
||||
try:
|
||||
return uuid.UUID(config_id_stripped)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
|
||||
# Fall back to integer parsing
|
||||
try:
|
||||
parsed_id = int(config_id_stripped)
|
||||
@@ -81,17 +81,17 @@ def _validate_config_id(config_id, db: Session = None):
|
||||
field_name="config_id",
|
||||
invalid_value=config_id,
|
||||
)
|
||||
|
||||
|
||||
# 如果提供了数据库会话,尝试通过 user_id 查询 config_id
|
||||
if db is not None:
|
||||
# 查询 user_id 匹配的记录
|
||||
stmt = select(MemoryConfigModel).where(MemoryConfigModel.user_id == str(parsed_id))
|
||||
result = db.execute(stmt).scalars().first()
|
||||
|
||||
|
||||
if result:
|
||||
logger.info(f"Found config_id {result.config_id} for user_id {parsed_id}")
|
||||
return result.config_id
|
||||
|
||||
|
||||
return parsed_id
|
||||
except ValueError:
|
||||
raise InvalidConfigError(
|
||||
@@ -154,10 +154,10 @@ class MemoryConfigService:
|
||||
self.db = db
|
||||
|
||||
def load_memory_config(
|
||||
self,
|
||||
config_id: Optional[UUID] = None,
|
||||
workspace_id: Optional[UUID] = None,
|
||||
service_name: str = "MemoryConfigService",
|
||||
self,
|
||||
config_id: Optional[UUID] = None,
|
||||
workspace_id: Optional[UUID] = None,
|
||||
service_name: str = "MemoryConfigService",
|
||||
) -> MemoryConfig:
|
||||
"""
|
||||
Load memory configuration from database with optional fallback.
|
||||
@@ -194,14 +194,14 @@ class MemoryConfigService:
|
||||
try:
|
||||
# Use get_config_with_fallback if workspace_id is provided
|
||||
memory_config = None
|
||||
validated_config_id = None
|
||||
if workspace_id:
|
||||
validated_config_id = None
|
||||
if config_id:
|
||||
try:
|
||||
validated_config_id = _validate_config_id(config_id, self.db)
|
||||
except Exception:
|
||||
validated_config_id = None
|
||||
|
||||
|
||||
memory_config = self.get_config_with_fallback(
|
||||
memory_config_id=validated_config_id,
|
||||
workspace_id=workspace_id
|
||||
@@ -210,7 +210,7 @@ class MemoryConfigService:
|
||||
validated_config_id = _validate_config_id(config_id, self.db)
|
||||
from app.models.memory_config_model import MemoryConfig as MemoryConfigModel
|
||||
memory_config = self.db.get(MemoryConfigModel, validated_config_id)
|
||||
|
||||
|
||||
if not memory_config:
|
||||
elapsed_ms = (time.time() - start_time) * 1000
|
||||
config_logger.error(
|
||||
@@ -233,7 +233,7 @@ class MemoryConfigService:
|
||||
result = MemoryConfigRepository.get_config_with_workspace(self.db, memory_config.config_id)
|
||||
db_query_time = time.time() - db_query_start
|
||||
logger.info(f"[PERF] Config+Workspace query: {db_query_time:.4f}s")
|
||||
|
||||
|
||||
if not result:
|
||||
raise ConfigurationError(
|
||||
f"Workspace not found for config {memory_config.config_id}"
|
||||
@@ -243,10 +243,10 @@ class MemoryConfigService:
|
||||
|
||||
# Helper function to validate model with workspace fallback
|
||||
def _validate_model_with_fallback(
|
||||
model_id: str,
|
||||
model_type: str,
|
||||
workspace_default: str,
|
||||
required: bool = False
|
||||
model_id: str,
|
||||
model_type: str,
|
||||
workspace_default: str,
|
||||
required: bool = False
|
||||
) -> tuple:
|
||||
"""Validate model ID, falling back to workspace default if invalid.
|
||||
|
||||
@@ -275,7 +275,7 @@ class MemoryConfigService:
|
||||
logger.warning(
|
||||
f"{model_type} model validation failed, trying workspace default: {e}"
|
||||
)
|
||||
|
||||
|
||||
# Fallback to workspace default
|
||||
if workspace_default:
|
||||
try:
|
||||
@@ -297,7 +297,7 @@ class MemoryConfigService:
|
||||
logger.error(f"Workspace default {model_type} model also invalid: {e}")
|
||||
if required:
|
||||
raise
|
||||
|
||||
|
||||
if required:
|
||||
raise InvalidConfigError(
|
||||
f"{model_type.title()} model is required but not configured",
|
||||
@@ -306,7 +306,7 @@ class MemoryConfigService:
|
||||
config_id=validated_config_id,
|
||||
workspace_id=workspace.id
|
||||
)
|
||||
|
||||
|
||||
return None, None
|
||||
|
||||
# Step 2: Validate embedding model with workspace fallback
|
||||
@@ -343,6 +343,35 @@ class MemoryConfigService:
|
||||
if memory_config.rerank_id or workspace.rerank:
|
||||
logger.info(f"[PERF] Rerank validation: {rerank_time:.4f}s")
|
||||
|
||||
vision_uuid, vision_name = validate_and_resolve_model_id(
|
||||
memory_config.vision_id,
|
||||
"llm",
|
||||
self.db,
|
||||
workspace.tenant_id,
|
||||
required=False,
|
||||
config_id=validated_config_id,
|
||||
workspace_id=workspace.id,
|
||||
)
|
||||
|
||||
audio_uuid, audio_name = validate_and_resolve_model_id(
|
||||
memory_config.audio_id,
|
||||
"llm",
|
||||
self.db,
|
||||
workspace.tenant_id,
|
||||
required=False,
|
||||
config_id=validated_config_id,
|
||||
workspace_id=workspace.id,
|
||||
)
|
||||
|
||||
video_uuid, video_name = validate_and_resolve_model_id(
|
||||
memory_config.video_id,
|
||||
"llm",
|
||||
self.db,
|
||||
workspace.tenant_id,
|
||||
required=False,
|
||||
config_id=validated_config_id,
|
||||
workspace_id=workspace.id,
|
||||
)
|
||||
# Create immutable MemoryConfig object
|
||||
config = MemoryConfig(
|
||||
config_id=memory_config.config_id,
|
||||
@@ -356,6 +385,12 @@ class MemoryConfigService:
|
||||
embedding_model_name=embedding_name,
|
||||
rerank_model_id=rerank_uuid,
|
||||
rerank_model_name=rerank_name,
|
||||
video_model_id=video_uuid,
|
||||
video_model_name=video_name,
|
||||
vision_model_id=vision_uuid,
|
||||
vision_model_name=vision_name,
|
||||
audio_model_id=audio_uuid,
|
||||
audio_model_name=audio_name,
|
||||
storage_type=workspace.storage_type or "neo4j",
|
||||
chunker_strategy=memory_config.chunker_strategy or "RecursiveChunker",
|
||||
reflexion_enabled=memory_config.enable_self_reflexion or False,
|
||||
@@ -364,24 +399,31 @@ class MemoryConfigService:
|
||||
reflexion_baseline=memory_config.baseline or "Time",
|
||||
loaded_at=datetime.now(),
|
||||
# Pipeline config: Deduplication
|
||||
enable_llm_dedup_blockwise=bool(memory_config.enable_llm_dedup_blockwise) if memory_config.enable_llm_dedup_blockwise is not None else False,
|
||||
enable_llm_disambiguation=bool(memory_config.enable_llm_disambiguation) if memory_config.enable_llm_disambiguation is not None else False,
|
||||
enable_llm_dedup_blockwise=bool(
|
||||
memory_config.enable_llm_dedup_blockwise) if memory_config.enable_llm_dedup_blockwise is not None else False,
|
||||
enable_llm_disambiguation=bool(
|
||||
memory_config.enable_llm_disambiguation) if memory_config.enable_llm_disambiguation is not None else False,
|
||||
deep_retrieval=bool(memory_config.deep_retrieval) if memory_config.deep_retrieval is not None else True,
|
||||
t_type_strict=float(memory_config.t_type_strict) if memory_config.t_type_strict is not None else 0.8,
|
||||
t_name_strict=float(memory_config.t_name_strict) if memory_config.t_name_strict is not None else 0.8,
|
||||
t_overall=float(memory_config.t_overall) if memory_config.t_overall is not None else 0.8,
|
||||
# Pipeline config: Statement extraction
|
||||
statement_granularity=int(memory_config.statement_granularity) if memory_config.statement_granularity is not None else 2,
|
||||
include_dialogue_context=bool(memory_config.include_dialogue_context) if memory_config.include_dialogue_context is not None else False,
|
||||
max_dialogue_context_chars=int(memory_config.max_context) if memory_config.max_context is not None else 1000,
|
||||
statement_granularity=int(
|
||||
memory_config.statement_granularity) if memory_config.statement_granularity is not None else 2,
|
||||
include_dialogue_context=bool(
|
||||
memory_config.include_dialogue_context) if memory_config.include_dialogue_context is not None else False,
|
||||
max_dialogue_context_chars=int(
|
||||
memory_config.max_context) if memory_config.max_context is not None else 1000,
|
||||
# Pipeline config: Forgetting engine
|
||||
lambda_time=float(memory_config.lambda_time) if memory_config.lambda_time is not None else 0.5,
|
||||
lambda_mem=float(memory_config.lambda_mem) if memory_config.lambda_mem is not None else 0.5,
|
||||
offset=float(memory_config.offset) if memory_config.offset is not None else 0.0,
|
||||
# Pipeline config: Pruning
|
||||
pruning_enabled=bool(memory_config.pruning_enabled) if memory_config.pruning_enabled is not None else False,
|
||||
pruning_enabled=bool(
|
||||
memory_config.pruning_enabled) if memory_config.pruning_enabled is not None else False,
|
||||
pruning_scene=memory_config.pruning_scene or "education",
|
||||
pruning_threshold=float(memory_config.pruning_threshold) if memory_config.pruning_threshold is not None else 0.5,
|
||||
pruning_threshold=float(
|
||||
memory_config.pruning_threshold) if memory_config.pruning_threshold is not None else 0.5,
|
||||
# Ontology scene association
|
||||
scene_id=memory_config.scene_id,
|
||||
ontology_class_infos=_load_ontology_class_infos(self.db, memory_config.scene_id),
|
||||
@@ -448,9 +490,9 @@ class MemoryConfigService:
|
||||
if not config:
|
||||
logger.warning(f"Model ID {model_id} not found")
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="模型ID不存在")
|
||||
|
||||
|
||||
api_config: ModelApiKey = config.api_keys[0]
|
||||
|
||||
|
||||
return {
|
||||
"model_name": api_config.model_name,
|
||||
"provider": api_config.provider,
|
||||
@@ -481,9 +523,9 @@ class MemoryConfigService:
|
||||
if not config:
|
||||
logger.warning(f"Embedding model ID {embedding_id} not found")
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="嵌入模型ID不存在")
|
||||
|
||||
|
||||
api_config: ModelApiKey = config.api_keys[0]
|
||||
|
||||
|
||||
return {
|
||||
"model_name": api_config.model_name,
|
||||
"provider": api_config.provider,
|
||||
@@ -571,25 +613,25 @@ class MemoryConfigService:
|
||||
"""
|
||||
from app.core.memory.models.ontology_extraction_models import OntologyTypeList
|
||||
from app.repositories.ontology_class_repository import OntologyClassRepository
|
||||
|
||||
|
||||
if not memory_config.scene_id:
|
||||
logger.debug("No scene_id configured, skipping ontology type fetch")
|
||||
return None
|
||||
|
||||
|
||||
try:
|
||||
ontology_repo = OntologyClassRepository(self.db)
|
||||
ontology_classes = ontology_repo.get_classes_by_scene(memory_config.scene_id)
|
||||
|
||||
|
||||
if not ontology_classes:
|
||||
logger.info(f"No ontology classes found for scene_id: {memory_config.scene_id}")
|
||||
return None
|
||||
|
||||
|
||||
ontology_types = OntologyTypeList.from_db_models(ontology_classes)
|
||||
logger.info(
|
||||
f"Loaded {len(ontology_types.types)} ontology types for scene_id: {memory_config.scene_id}"
|
||||
)
|
||||
return ontology_types
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Failed to fetch ontology types for scene_id {memory_config.scene_id}: {e}",
|
||||
@@ -598,8 +640,8 @@ class MemoryConfigService:
|
||||
return None
|
||||
|
||||
def get_workspace_default_config(
|
||||
self,
|
||||
workspace_id: UUID
|
||||
self,
|
||||
workspace_id: UUID
|
||||
) -> Optional["MemoryConfigModel"]:
|
||||
"""Get workspace default memory config.
|
||||
|
||||
@@ -613,19 +655,19 @@ class MemoryConfigService:
|
||||
Optional[MemoryConfigModel]: Default config or None if no configs exist
|
||||
"""
|
||||
config = MemoryConfigRepository.get_workspace_default(self.db, workspace_id)
|
||||
|
||||
|
||||
if not config:
|
||||
logger.warning(
|
||||
"No active memory config found for workspace fallback",
|
||||
extra={"workspace_id": str(workspace_id)}
|
||||
)
|
||||
|
||||
|
||||
return config
|
||||
|
||||
def get_config_with_fallback(
|
||||
self,
|
||||
memory_config_id: Optional[UUID],
|
||||
workspace_id: UUID
|
||||
self,
|
||||
memory_config_id: Optional[UUID],
|
||||
workspace_id: UUID
|
||||
) -> Optional["MemoryConfigModel"]:
|
||||
"""Get memory config with fallback to workspace default.
|
||||
|
||||
@@ -644,13 +686,13 @@ class MemoryConfigService:
|
||||
"No memory config ID provided, using workspace default",
|
||||
extra={"workspace_id": str(workspace_id)}
|
||||
)
|
||||
|
||||
|
||||
config = MemoryConfigRepository.get_with_fallback(
|
||||
self.db,
|
||||
memory_config_id,
|
||||
workspace_id
|
||||
)
|
||||
|
||||
|
||||
if not config and memory_config_id:
|
||||
logger.warning(
|
||||
"Memory config not found, falling back to workspace default",
|
||||
@@ -659,13 +701,13 @@ class MemoryConfigService:
|
||||
"workspace_id": str(workspace_id)
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
return config
|
||||
|
||||
def delete_config(
|
||||
self,
|
||||
config_id: UUID | int,
|
||||
force: bool = False
|
||||
self,
|
||||
config_id: UUID | int,
|
||||
force: bool = False
|
||||
) -> dict:
|
||||
"""Delete memory config with protection against in-use configs.
|
||||
|
||||
@@ -687,7 +729,7 @@ class MemoryConfigService:
|
||||
from app.core.exceptions import ResourceNotFoundException
|
||||
from app.models.memory_config_model import MemoryConfig as MemoryConfigModel
|
||||
from app.repositories.end_user_repository import EndUserRepository
|
||||
|
||||
|
||||
# 处理旧格式 int 类型的 config_id
|
||||
if isinstance(config_id, int):
|
||||
logger.warning(
|
||||
@@ -699,11 +741,11 @@ class MemoryConfigService:
|
||||
"message": "旧格式配置ID不支持删除操作,请使用新版配置",
|
||||
"legacy_int_id": config_id
|
||||
}
|
||||
|
||||
|
||||
config = self.db.get(MemoryConfigModel, config_id)
|
||||
if not config:
|
||||
raise ResourceNotFoundException("MemoryConfig", str(config_id))
|
||||
|
||||
|
||||
# Check if this is the default config - default configs cannot be deleted
|
||||
if config.is_default:
|
||||
logger.warning(
|
||||
@@ -715,11 +757,11 @@ class MemoryConfigService:
|
||||
"message": "默认配置不允许删除",
|
||||
"is_default": True
|
||||
}
|
||||
|
||||
|
||||
# Use repository to count connected end users
|
||||
end_user_repo = EndUserRepository(self.db)
|
||||
connected_count = end_user_repo.count_by_memory_config_id(config_id)
|
||||
|
||||
|
||||
if connected_count > 0 and not force:
|
||||
logger.warning(
|
||||
"Attempted to delete memory config with connected end users",
|
||||
@@ -728,18 +770,18 @@ class MemoryConfigService:
|
||||
"connected_count": connected_count
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
return {
|
||||
"status": "warning",
|
||||
"message": f"无法删除记忆配置:{connected_count} 个终端用户正在使用此配置",
|
||||
"connected_count": connected_count,
|
||||
"force_required": True
|
||||
}
|
||||
|
||||
|
||||
# Force delete: use repository to clear end user references first
|
||||
if connected_count > 0 and force:
|
||||
cleared_count = end_user_repo.clear_memory_config_id(config_id)
|
||||
|
||||
|
||||
logger.warning(
|
||||
"Force deleting memory config, clearing end user references",
|
||||
extra={
|
||||
@@ -747,11 +789,11 @@ class MemoryConfigService:
|
||||
"cleared_end_users": cleared_count
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
try:
|
||||
self.db.delete(config)
|
||||
self.db.commit()
|
||||
|
||||
|
||||
logger.info(
|
||||
"Memory config deleted",
|
||||
extra={
|
||||
@@ -760,16 +802,16 @@ class MemoryConfigService:
|
||||
"affected_users": connected_count
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
return {
|
||||
"status": "success",
|
||||
"message": "记忆配置删除成功",
|
||||
"affected_users": connected_count
|
||||
}
|
||||
|
||||
|
||||
except IntegrityError as e:
|
||||
self.db.rollback()
|
||||
|
||||
|
||||
# Handle foreign key violation gracefully
|
||||
error_str = str(e.orig) if e.orig else str(e)
|
||||
if "ForeignKeyViolation" in error_str or "foreign key constraint" in error_str.lower():
|
||||
@@ -785,7 +827,7 @@ class MemoryConfigService:
|
||||
"message": "无法删除记忆配置:仍有终端用户引用此配置,请使用 force=true 强制删除",
|
||||
"force_required": True
|
||||
}
|
||||
|
||||
|
||||
# Re-raise other integrity errors
|
||||
logger.error(
|
||||
"Delete failed due to integrity error",
|
||||
@@ -800,9 +842,9 @@ class MemoryConfigService:
|
||||
# ==================== 记忆配置提取方法 ====================
|
||||
|
||||
def extract_memory_config_id(
|
||||
self,
|
||||
app_type: str,
|
||||
config: dict
|
||||
self,
|
||||
app_type: str,
|
||||
config: dict
|
||||
) -> tuple[Optional[uuid.UUID], bool]:
|
||||
"""从发布配置中提取 memory_config_id(根据应用类型分发)
|
||||
|
||||
@@ -828,8 +870,8 @@ class MemoryConfigService:
|
||||
return None, False
|
||||
|
||||
def _extract_memory_config_id_from_agent(
|
||||
self,
|
||||
config: dict
|
||||
self,
|
||||
config: dict
|
||||
) -> tuple[Optional[uuid.UUID], bool]:
|
||||
"""从 Agent 应用配置中提取 memory_config_id
|
||||
|
||||
@@ -888,8 +930,8 @@ class MemoryConfigService:
|
||||
return None, False
|
||||
|
||||
def _extract_memory_config_id_from_workflow(
|
||||
self,
|
||||
config: dict
|
||||
self,
|
||||
config: dict
|
||||
) -> tuple[Optional[uuid.UUID], bool]:
|
||||
"""从 Workflow 应用配置中提取 memory_config_id
|
||||
|
||||
@@ -905,14 +947,14 @@ class MemoryConfigService:
|
||||
- is_legacy_int: 是否检测到旧格式 int 数据
|
||||
"""
|
||||
nodes = config.get("nodes", [])
|
||||
|
||||
|
||||
for node in nodes:
|
||||
node_type = node.get("type", "")
|
||||
|
||||
|
||||
# 检查是否为记忆节点 (support both formats: memory-read/memory-write and MemoryRead/MemoryWrite)
|
||||
if node_type.lower() in ["memoryread", "memorywrite", "memory-read", "memory-write"]:
|
||||
config_id = node.get("config", {}).get("config_id")
|
||||
|
||||
|
||||
if config_id:
|
||||
try:
|
||||
# 处理字符串、UUID 和 int(旧数据兼容)三种情况
|
||||
@@ -937,6 +979,6 @@ class MemoryConfigService:
|
||||
f"工作流记忆节点 config_id 格式无效: node_id={node.get('id')}, "
|
||||
f"node_type={node_type}, error={str(e)}"
|
||||
)
|
||||
|
||||
|
||||
logger.debug("工作流配置中未找到记忆节点")
|
||||
return None, False
|
||||
|
||||
Reference in New Issue
Block a user