[changes] Optimize the semantic pruning judgment rules
This commit is contained in:
@@ -107,28 +107,29 @@ def _validate_config_id(config_id, db: Session = None):
|
||||
)
|
||||
|
||||
|
||||
def _load_ontology_classes(db: Session, scene_id, pruning_scene: Optional[str]) -> Optional[list]:
|
||||
"""从 ontology_class 表加载场景类型名称列表,用于注入提示词。
|
||||
def _load_ontology_class_infos(db: Session, scene_id) -> list:
|
||||
"""从 ontology_class 表加载完整本体类型信息(name + description),用于注入剪枝提示词。
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
scene_id: 本体场景 UUID
|
||||
pruning_scene: 语义剪枝场景名称(保留参数,暂未使用)
|
||||
|
||||
Returns:
|
||||
class_name 字符串列表,或 None(无数据时)
|
||||
[{"class_name": ..., "class_description": ...}, ...] 或空列表
|
||||
"""
|
||||
if not scene_id:
|
||||
return None
|
||||
return []
|
||||
try:
|
||||
from app.repositories.ontology_class_repository import OntologyClassRepository
|
||||
repo = OntologyClassRepository(db)
|
||||
classes = repo.get_classes_by_scene(scene_id)
|
||||
names = [c.class_name for c in classes if c.class_name]
|
||||
return names if names else None
|
||||
return [
|
||||
{"class_name": c.class_name, "class_description": c.class_description or ""}
|
||||
for c in classes if c.class_name
|
||||
]
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to load ontology classes for scene_id={scene_id}: {e}")
|
||||
return None
|
||||
logger.warning(f"Failed to load ontology class infos for scene_id={scene_id}: {e}")
|
||||
return []
|
||||
|
||||
|
||||
class MemoryConfigService:
|
||||
@@ -383,7 +384,7 @@ class MemoryConfigService:
|
||||
pruning_threshold=float(memory_config.pruning_threshold) if memory_config.pruning_threshold is not None else 0.5,
|
||||
# Ontology scene association
|
||||
scene_id=memory_config.scene_id,
|
||||
ontology_classes=_load_ontology_classes(self.db, memory_config.scene_id, memory_config.pruning_scene),
|
||||
ontology_class_infos=_load_ontology_class_infos(self.db, memory_config.scene_id),
|
||||
)
|
||||
|
||||
elapsed_ms = (time.time() - start_time) * 1000
|
||||
@@ -550,11 +551,13 @@ class MemoryConfigService:
|
||||
- pruning_switch: bool
|
||||
- pruning_scene: str
|
||||
- pruning_threshold: float
|
||||
- ontology_class_infos: list of {class_name, class_description} dicts
|
||||
"""
|
||||
return {
|
||||
"pruning_switch": memory_config.pruning_enabled,
|
||||
"pruning_scene": memory_config.pruning_scene,
|
||||
"pruning_threshold": memory_config.pruning_threshold,
|
||||
"ontology_class_infos": memory_config.ontology_class_infos or [],
|
||||
}
|
||||
|
||||
def get_ontology_types(self, memory_config: MemoryConfig):
|
||||
|
||||
@@ -121,7 +121,7 @@ async def run_pilot_extraction(
|
||||
"pruning_scene": memory_config.pruning_scene,
|
||||
"pruning_threshold": memory_config.pruning_threshold,
|
||||
"scene_id": str(memory_config.scene_id) if memory_config.scene_id else None,
|
||||
"ontology_classes": memory_config.ontology_classes,
|
||||
"ontology_class_infos": memory_config.ontology_classes,
|
||||
}
|
||||
config = PruningConfig(**pruning_config_dict)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user