refactor(memory): deduplicate assistant alias query and fix case-sensitive placeholder matching
- Extract fetch_neo4j_assistant_aliases() into deduped_and_disamb.py as the single source of truth, replacing the inline Cypher in write_tools and extraction_orchestrator.
- Normalize USER_PLACEHOLDER_NAMES to lowercase and apply .lower() to every comparison so that case variants of placeholder names (e.g. "user", "USER") no longer leak into aliases.
This commit is contained in:
@@ -158,20 +158,13 @@ async def write(
|
|||||||
try:
|
try:
|
||||||
from app.core.memory.storage_services.extraction_engine.deduplication.deduped_and_disamb import (
|
from app.core.memory.storage_services.extraction_engine.deduplication.deduped_and_disamb import (
|
||||||
clean_cross_role_aliases,
|
clean_cross_role_aliases,
|
||||||
|
fetch_neo4j_assistant_aliases,
|
||||||
)
|
)
|
||||||
neo4j_assistant_aliases = set()
|
neo4j_assistant_aliases = set()
|
||||||
if all_entity_nodes:
|
if all_entity_nodes:
|
||||||
_eu_id = all_entity_nodes[0].end_user_id
|
_eu_id = all_entity_nodes[0].end_user_id
|
||||||
if _eu_id:
|
if _eu_id:
|
||||||
_cypher = """
|
neo4j_assistant_aliases = await fetch_neo4j_assistant_aliases(neo4j_connector, _eu_id)
|
||||||
MATCH (e:ExtractedEntity)
|
|
||||||
WHERE e.end_user_id = $end_user_id AND e.name IN ['AI助手', '助手', 'AI Assistant', 'Assistant']
|
|
||||||
RETURN e.aliases AS aliases
|
|
||||||
"""
|
|
||||||
_result = await neo4j_connector.execute_query(_cypher, end_user_id=_eu_id)
|
|
||||||
for _record in (_result or []):
|
|
||||||
for _alias in (_record.get('aliases') or []):
|
|
||||||
neo4j_assistant_aliases.add(_alias.strip().lower())
|
|
||||||
clean_cross_role_aliases(all_entity_nodes, external_assistant_aliases=neo4j_assistant_aliases)
|
clean_cross_role_aliases(all_entity_nodes, external_assistant_aliases=neo4j_assistant_aliases)
|
||||||
logger.info(f"Neo4j 写入前别名清洗完成,AI助手别名排除集大小: {len(neo4j_assistant_aliases)}")
|
logger.info(f"Neo4j 写入前别名清洗完成,AI助手别名排除集大小: {len(neo4j_assistant_aliases)}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|||||||
@@ -286,6 +286,51 @@ def _normalize_special_entity_names(
|
|||||||
ent.aliases = cleaned
|
ent.aliases = cleaned
|
||||||
|
|
||||||
|
|
||||||
|
async def fetch_neo4j_assistant_aliases(neo4j_connector, end_user_id: str) -> set:
|
||||||
|
"""从 Neo4j 查询 AI 助手实体的所有别名(小写归一化)。
|
||||||
|
|
||||||
|
这是助手别名查询的唯一入口,供 write_tools 和 extraction_orchestrator 共用,
|
||||||
|
避免多处维护相同的 Cypher 和名称列表。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
neo4j_connector: Neo4j 连接器实例(需提供 execute_query 方法)
|
||||||
|
end_user_id: 终端用户 ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
小写归一化后的助手别名集合
|
||||||
|
"""
|
||||||
|
import logging
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Start the query name list from the canonical assistant name (_CANONICAL_ASSISTANT_NAME),
|
||||||
|
# so it stays consistent with the name that _normalize_special_entity_names standardizes to
|
||||||
|
query_names = [_CANONICAL_ASSISTANT_NAME] # "AI助手"
|
||||||
|
# Add other common Chinese and English variants of the assistant name
|
||||||
|
query_names.extend(["助手", "AI Assistant", "Assistant"])
|
||||||
|
# 去重
|
||||||
|
query_names = list(dict.fromkeys(query_names))
|
||||||
|
|
||||||
|
cypher = """
|
||||||
|
MATCH (e:ExtractedEntity)
|
||||||
|
WHERE e.end_user_id = $end_user_id AND e.name IN $names
|
||||||
|
RETURN e.aliases AS aliases
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
result = await neo4j_connector.execute_query(
|
||||||
|
cypher, end_user_id=end_user_id, names=query_names
|
||||||
|
)
|
||||||
|
assistant_aliases: set = set()
|
||||||
|
for record in (result or []):
|
||||||
|
for alias in (record.get("aliases") or []):
|
||||||
|
assistant_aliases.add(alias.strip().lower())
|
||||||
|
if assistant_aliases:
|
||||||
|
logger.debug(f"Neo4j 中 AI 助手别名: {assistant_aliases}")
|
||||||
|
return assistant_aliases
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"查询 Neo4j AI 助手别名失败: {e}")
|
||||||
|
return set()
|
||||||
|
|
||||||
|
|
||||||
def clean_cross_role_aliases(
|
def clean_cross_role_aliases(
|
||||||
entity_nodes: List[ExtractedEntityNode],
|
entity_nodes: List[ExtractedEntityNode],
|
||||||
external_assistant_aliases: set = None,
|
external_assistant_aliases: set = None,
|
||||||
|
|||||||
@@ -1400,7 +1400,7 @@ class ExtractionOrchestrator:
|
|||||||
info = EndUserInfoRepository(db).get_by_end_user_id(end_user_uuid)
|
info = EndUserInfoRepository(db).get_by_end_user_id(end_user_uuid)
|
||||||
db_aliases = (info.aliases if info and info.aliases else [])
|
db_aliases = (info.aliases if info and info.aliases else [])
|
||||||
# 过滤掉占位名称
|
# 过滤掉占位名称
|
||||||
db_aliases = [a for a in db_aliases if a.strip() not in self.USER_PLACEHOLDER_NAMES]
|
db_aliases = [a for a in db_aliases if a.strip().lower() not in self.USER_PLACEHOLDER_NAMES]
|
||||||
|
|
||||||
# 合并:已有 + 本轮新增,去重保序
|
# 合并:已有 + 本轮新增,去重保序
|
||||||
merged_aliases = list(db_aliases)
|
merged_aliases = list(db_aliases)
|
||||||
@@ -1440,7 +1440,7 @@ class ExtractionOrchestrator:
|
|||||||
else:
|
else:
|
||||||
first_alias = current_aliases[0].strip() if current_aliases else ""
|
first_alias = current_aliases[0].strip() if current_aliases else ""
|
||||||
# 确保 first_alias 不是占位名称
|
# 确保 first_alias 不是占位名称
|
||||||
if first_alias and first_alias not in self.USER_PLACEHOLDER_NAMES:
|
if first_alias and first_alias.lower() not in self.USER_PLACEHOLDER_NAMES:
|
||||||
db.add(EndUserInfo(
|
db.add(EndUserInfo(
|
||||||
end_user_id=end_user_uuid,
|
end_user_id=end_user_uuid,
|
||||||
other_name=first_alias,
|
other_name=first_alias,
|
||||||
@@ -1457,7 +1457,7 @@ class ExtractionOrchestrator:
|
|||||||
|
|
||||||
|
|
||||||
# 用户实体占位名称,不允许作为 other_name 或出现在 aliases 中
|
# 用户实体占位名称,不允许作为 other_name 或出现在 aliases 中
|
||||||
USER_PLACEHOLDER_NAMES = {'用户', '我', 'User', 'I'}
|
USER_PLACEHOLDER_NAMES = {'用户', '我', 'user', 'i'}
|
||||||
|
|
||||||
def _extract_current_aliases(self, entity_nodes: List[ExtractedEntityNode], dialog_data_list=None) -> List[str]:
|
def _extract_current_aliases(self, entity_nodes: List[ExtractedEntityNode], dialog_data_list=None) -> List[str]:
|
||||||
"""从用户发言的原始实体中提取别名(绕过去重污染)
|
"""从用户发言的原始实体中提取别名(绕过去重污染)
|
||||||
@@ -1490,10 +1490,10 @@ class ExtractionOrchestrator:
|
|||||||
continue
|
continue
|
||||||
for entity in (triplet_info.entities or []):
|
for entity in (triplet_info.entities or []):
|
||||||
ent_name = getattr(entity, 'name', '').strip()
|
ent_name = getattr(entity, 'name', '').strip()
|
||||||
if ent_name in self.USER_PLACEHOLDER_NAMES:
|
if ent_name.lower() in self.USER_PLACEHOLDER_NAMES:
|
||||||
for alias in (getattr(entity, 'aliases', []) or []):
|
for alias in (getattr(entity, 'aliases', []) or []):
|
||||||
a = alias.strip()
|
a = alias.strip()
|
||||||
if a and a not in self.USER_PLACEHOLDER_NAMES and a.lower() not in seen_lower:
|
if a and a.lower() not in self.USER_PLACEHOLDER_NAMES and a.lower() not in seen_lower:
|
||||||
all_user_aliases.append(a)
|
all_user_aliases.append(a)
|
||||||
seen_lower.add(a.lower())
|
seen_lower.add(a.lower())
|
||||||
if all_user_aliases:
|
if all_user_aliases:
|
||||||
@@ -1502,11 +1502,11 @@ class ExtractionOrchestrator:
|
|||||||
|
|
||||||
# 兜底:从去重后的 entity_nodes 提取(旧逻辑)
|
# 兜底:从去重后的 entity_nodes 提取(旧逻辑)
|
||||||
for entity in entity_nodes:
|
for entity in entity_nodes:
|
||||||
if getattr(entity, 'name', '').strip() in self.USER_PLACEHOLDER_NAMES:
|
if getattr(entity, 'name', '').strip().lower() in self.USER_PLACEHOLDER_NAMES:
|
||||||
aliases = getattr(entity, 'aliases', []) or []
|
aliases = getattr(entity, 'aliases', []) or []
|
||||||
filtered = [
|
filtered = [
|
||||||
a for a in aliases
|
a for a in aliases
|
||||||
if a.strip() not in self.USER_PLACEHOLDER_NAMES
|
if a.strip().lower() not in self.USER_PLACEHOLDER_NAMES
|
||||||
]
|
]
|
||||||
if filtered:
|
if filtered:
|
||||||
logger.debug(f"从去重后实体提取到别名(兜底): {filtered}")
|
logger.debug(f"从去重后实体提取到别名(兜底): {filtered}")
|
||||||
@@ -1531,28 +1531,15 @@ class ExtractionOrchestrator:
|
|||||||
logger.debug(f"Neo4j 用户实体 aliases 为空: end_user_id={end_user_id}")
|
logger.debug(f"Neo4j 用户实体 aliases 为空: end_user_id={end_user_id}")
|
||||||
return []
|
return []
|
||||||
# 过滤掉占位名称,防止历史脏数据传播
|
# 过滤掉占位名称,防止历史脏数据传播
|
||||||
filtered = [a for a in aliases if a.strip() not in self.USER_PLACEHOLDER_NAMES]
|
filtered = [a for a in aliases if a.strip().lower() not in self.USER_PLACEHOLDER_NAMES]
|
||||||
return filtered
|
return filtered
|
||||||
|
|
||||||
async def _fetch_neo4j_assistant_aliases(self, end_user_id: str) -> set:
|
async def _fetch_neo4j_assistant_aliases(self, end_user_id: str) -> set:
|
||||||
"""从 Neo4j 查询 AI 助手实体的所有别名(用于从用户别名中排除)"""
|
"""从 Neo4j 查询 AI 助手实体的所有别名(用于从用户别名中排除)"""
|
||||||
cypher = """
|
from app.core.memory.storage_services.extraction_engine.deduplication.deduped_and_disamb import (
|
||||||
MATCH (e:ExtractedEntity)
|
fetch_neo4j_assistant_aliases,
|
||||||
WHERE e.end_user_id = $end_user_id AND e.name IN ['AI助手', '助手', 'AI Assistant', 'Assistant']
|
)
|
||||||
RETURN e.aliases AS aliases
|
return await fetch_neo4j_assistant_aliases(Neo4jConnector(), end_user_id)
|
||||||
"""
|
|
||||||
try:
|
|
||||||
result = await Neo4jConnector().execute_query(cypher, end_user_id=end_user_id)
|
|
||||||
assistant_aliases = set()
|
|
||||||
for record in (result or []):
|
|
||||||
for alias in (record.get('aliases') or []):
|
|
||||||
assistant_aliases.add(alias.strip().lower())
|
|
||||||
if assistant_aliases:
|
|
||||||
logger.debug(f"Neo4j 中 AI 助手别名: {assistant_aliases}")
|
|
||||||
return assistant_aliases
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"查询 Neo4j AI 助手别名失败: {e}")
|
|
||||||
return set()
|
|
||||||
|
|
||||||
def _resolve_other_name(
|
def _resolve_other_name(
|
||||||
self,
|
self,
|
||||||
@@ -1571,16 +1558,16 @@ class ExtractionOrchestrator:
|
|||||||
注意:返回值不允许是占位名称("用户"、"我"、"User"、"I")
|
注意:返回值不允许是占位名称("用户"、"我"、"User"、"I")
|
||||||
"""
|
"""
|
||||||
# 当前值为空或为占位名称时,需要更新
|
# 当前值为空或为占位名称时,需要更新
|
||||||
if not current or not current.strip() or current.strip() in self.USER_PLACEHOLDER_NAMES:
|
if not current or not current.strip() or current.strip().lower() in self.USER_PLACEHOLDER_NAMES:
|
||||||
candidate = current_aliases[0].strip() if current_aliases else None
|
candidate = current_aliases[0].strip() if current_aliases else None
|
||||||
# 确保候选值不是占位名称
|
# 确保候选值不是占位名称
|
||||||
if candidate and candidate in self.USER_PLACEHOLDER_NAMES:
|
if candidate and candidate.lower() in self.USER_PLACEHOLDER_NAMES:
|
||||||
return None
|
return None
|
||||||
return candidate
|
return candidate
|
||||||
if current not in neo4j_aliases:
|
if current not in neo4j_aliases:
|
||||||
candidate = neo4j_aliases[0].strip() if neo4j_aliases else None
|
candidate = neo4j_aliases[0].strip() if neo4j_aliases else None
|
||||||
# 确保候选值不是占位名称
|
# 确保候选值不是占位名称
|
||||||
if candidate and candidate in self.USER_PLACEHOLDER_NAMES:
|
if candidate and candidate.lower() in self.USER_PLACEHOLDER_NAMES:
|
||||||
return None
|
return None
|
||||||
return candidate
|
return candidate
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user