[changes] Extract the unified auxiliary function; downgrade the log; initialize the variables

This commit is contained in:
lanceyq
2026-03-17 17:28:28 +08:00
parent 43130dcbc8
commit e74a74c3fb
4 changed files with 89 additions and 70 deletions

View File

@@ -409,59 +409,24 @@ async def retrieve(state: ReadState) -> ReadState:
reranked = results_dict.get('reranked_results', {})
community_hits = reranked.get('communities', [])
if not community_hits:
# 兼容非 hybrid 路径
community_hits = results_dict.get('communities', [])
community_ids = [c.get('id') for c in community_hits if c.get('id')]
if community_ids:
from app.repositories.neo4j.graph_search import search_graph_community_expand
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
expand_connector = Neo4jConnector()
try:
expand_result = await search_graph_community_expand(
connector=expand_connector,
community_ids=community_ids,
end_user_id=end_user_id,
limit=10,
)
expanded_stmts = expand_result.get('expanded_statements', [])
if expanded_stmts:
# 去重:过滤掉直接检索已经包含的 statement 文本
existing_lines = set(clean_content.splitlines())
expanded_texts = [
s['statement'] for s in expanded_stmts
if s.get('statement') and s['statement'] not in existing_lines
]
if expanded_texts:
clean_content = clean_content + '\n' + '\n'.join(expanded_texts)
# 暂存展开结果,稍后写回 raw_results['results']
_expanded_stmts_to_write = expanded_stmts
logger.info(
f"[Retrieve] 社区展开追加 {len(expanded_stmts)} 条 statements"
f"community_ids={community_ids}"
)
except Exception as expand_err:
logger.warning(f"[Retrieve] 社区展开检索失败,跳过: {expand_err}")
finally:
await expand_connector.close()
if community_hits:
from app.core.memory.agent.services.search_service import expand_communities_to_statements
_expanded_stmts_to_write, new_texts = await expand_communities_to_statements(
community_results=community_hits,
end_user_id=end_user_id,
existing_content=clean_content,
)
if new_texts:
clean_content = clean_content + '\n' + '\n'.join(new_texts)
except Exception as parse_err:
logger.warning(f"[Retrieve] 解析社区命中结果失败,跳过展开: {parse_err}")
try:
raw_results = raw_results['results']
# 写回展开结果,过一遍字段清洗去掉 DateTime 等不可序列化字段
# 写回展开结果,接口返回中可见(已在 helper 中清洗过字段
if _expanded_stmts_to_write and isinstance(raw_results, dict):
_fields_to_remove = {
'invalid_at', 'valid_at', 'chunk_id_from_rel', 'entity_ids',
'expired_at', 'created_at', 'chunk_id', 'apply_id',
'user_id', 'statement_ids', 'updated_at', 'chunk_ids', 'fact_summary'
}
def _clean(obj):
if isinstance(obj, dict):
return {k: _clean(v) for k, v in obj.items() if k not in _fields_to_remove}
if isinstance(obj, list):
return [_clean(i) for i in obj]
return obj
raw_results.setdefault('reranked_results', {})['expanded_statements'] = _clean(_expanded_stmts_to_write)
raw_results.setdefault('reranked_results', {})['expanded_statements'] = _expanded_stmts_to_write
except Exception:
raw_results = []

View File

@@ -348,8 +348,8 @@ async def Input_Summary(state: ReadState) -> ReadState:
if raw_results and isinstance(raw_results, dict):
reranked = raw_results.get('reranked_results', {})
community_hits = reranked.get('communities', [])
logger.info(f"[Input_Summary] community 命中数: {len(community_hits)}, "
f"summary 命中数: {len(reranked.get('summaries', []))}")
logger.debug(f"[Input_Summary] community 命中数: {len(community_hits)}, "
f"summary 命中数: {len(reranked.get('summaries', []))}")
else:
retrieval_knowledge, retrieve_info, question, raw_results = await rag_knowledge(state, data)
except Exception as e:

View File

@@ -13,6 +13,72 @@ from app.core.memory.utils.data.text_utils import escape_lucene_query
logger = get_agent_logger(__name__)
# 需要从展开结果中过滤的字段(含 Neo4j DateTime不可 JSON 序列化)
_EXPAND_FIELDS_TO_REMOVE = {
'invalid_at', 'valid_at', 'chunk_id_from_rel', 'entity_ids',
'expired_at', 'created_at', 'chunk_id', 'apply_id',
'user_id', 'statement_ids', 'updated_at', 'chunk_ids', 'fact_summary'
}
def _clean_expand_fields(obj):
"""递归过滤展开结果中不可序列化的字段DateTime 等)。"""
if isinstance(obj, dict):
return {k: _clean_expand_fields(v) for k, v in obj.items() if k not in _EXPAND_FIELDS_TO_REMOVE}
if isinstance(obj, list):
return [_clean_expand_fields(i) for i in obj]
return obj
async def expand_communities_to_statements(
community_results: List[dict],
end_user_id: str,
existing_content: str = "",
limit: int = 10,
) -> Tuple[List[dict], List[str]]:
"""
社区展开 helper给定命中的 community 列表,拉取关联 Statement。
- 对展开结果去重(过滤已在 existing_content 中出现的文本)
- 过滤不可序列化字段
- 返回 (cleaned_expanded_stmts, new_texts)
- cleaned_expanded_stmts: 可直接写回 raw_results 的列表
- new_texts: 去重后新增的 statement 文本列表,用于追加到 clean_content
"""
community_ids = [r.get("id") for r in community_results if r.get("id")]
if not community_ids or not end_user_id:
return [], []
from app.repositories.neo4j.graph_search import search_graph_community_expand
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
connector = Neo4jConnector()
try:
result = await search_graph_community_expand(
connector=connector,
community_ids=community_ids,
end_user_id=end_user_id,
limit=limit,
)
except Exception as e:
logger.warning(f"[expand_communities] 社区展开检索失败,跳过: {e}")
return [], []
finally:
await connector.close()
expanded_stmts = result.get("expanded_statements", [])
if not expanded_stmts:
return [], []
existing_lines = set(existing_content.splitlines())
new_texts = [
s["statement"] for s in expanded_stmts
if s.get("statement") and s["statement"] not in existing_lines
]
cleaned = _clean_expand_fields(expanded_stmts)
logger.info(f"[expand_communities] 展开 {len(expanded_stmts)} 条 statements新增 {len(new_texts)}community_ids={community_ids}")
return cleaned, new_texts
class SearchService:
"""Service for executing hybrid search and processing results."""
@@ -190,28 +256,11 @@ class SearchService:
if search_type == "hybrid"
else answer.get('communities', [])
)
community_ids = [
r.get("id") for r in community_results if r.get("id")
]
if community_ids and end_user_id:
from app.repositories.neo4j.graph_search import search_graph_community_expand
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
expand_connector = Neo4jConnector()
try:
expand_result = await search_graph_community_expand(
connector=expand_connector,
community_ids=community_ids,
end_user_id=end_user_id,
limit=10,
)
expanded_stmts = expand_result.get("expanded_statements", [])
if expanded_stmts:
answer_list.extend(expanded_stmts)
logger.info(f"社区展开检索追加 {len(expanded_stmts)} 条 statements")
except Exception as e:
logger.warning(f"社区展开检索失败,跳过: {e}")
finally:
await expand_connector.close()
cleaned_stmts, new_texts = await expand_communities_to_statements(
community_results=community_results,
end_user_id=end_user_id,
)
answer_list.extend(cleaned_stmts)
# Extract clean content from all results按类型传入 node_type 区分 community
content_list = []

View File

@@ -727,6 +727,11 @@ async def run_hybrid_search(
keyword_task = None
embedding_task = None
keyword_task = None
embedding_task = None
keyword_results: Dict[str, List] = {}
embedding_results: Dict[str, List] = {}
if search_type in ["keyword", "hybrid"]:
# Keyword-based search
logger.info("[PERF] Starting keyword search...")