[changes] Extract the unified auxiliary function; downgrade the log; initialize the variables
This commit is contained in:
@@ -409,59 +409,24 @@ async def retrieve(state: ReadState) -> ReadState:
|
|||||||
reranked = results_dict.get('reranked_results', {})
|
reranked = results_dict.get('reranked_results', {})
|
||||||
community_hits = reranked.get('communities', [])
|
community_hits = reranked.get('communities', [])
|
||||||
if not community_hits:
|
if not community_hits:
|
||||||
# 兼容非 hybrid 路径
|
|
||||||
community_hits = results_dict.get('communities', [])
|
community_hits = results_dict.get('communities', [])
|
||||||
community_ids = [c.get('id') for c in community_hits if c.get('id')]
|
if community_hits:
|
||||||
if community_ids:
|
from app.core.memory.agent.services.search_service import expand_communities_to_statements
|
||||||
from app.repositories.neo4j.graph_search import search_graph_community_expand
|
_expanded_stmts_to_write, new_texts = await expand_communities_to_statements(
|
||||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
community_results=community_hits,
|
||||||
expand_connector = Neo4jConnector()
|
end_user_id=end_user_id,
|
||||||
try:
|
existing_content=clean_content,
|
||||||
expand_result = await search_graph_community_expand(
|
)
|
||||||
connector=expand_connector,
|
if new_texts:
|
||||||
community_ids=community_ids,
|
clean_content = clean_content + '\n' + '\n'.join(new_texts)
|
||||||
end_user_id=end_user_id,
|
|
||||||
limit=10,
|
|
||||||
)
|
|
||||||
expanded_stmts = expand_result.get('expanded_statements', [])
|
|
||||||
if expanded_stmts:
|
|
||||||
# 去重:过滤掉直接检索已经包含的 statement 文本
|
|
||||||
existing_lines = set(clean_content.splitlines())
|
|
||||||
expanded_texts = [
|
|
||||||
s['statement'] for s in expanded_stmts
|
|
||||||
if s.get('statement') and s['statement'] not in existing_lines
|
|
||||||
]
|
|
||||||
if expanded_texts:
|
|
||||||
clean_content = clean_content + '\n' + '\n'.join(expanded_texts)
|
|
||||||
# 暂存展开结果,稍后写回 raw_results['results']
|
|
||||||
_expanded_stmts_to_write = expanded_stmts
|
|
||||||
logger.info(
|
|
||||||
f"[Retrieve] 社区展开追加 {len(expanded_stmts)} 条 statements,"
|
|
||||||
f"community_ids={community_ids}"
|
|
||||||
)
|
|
||||||
except Exception as expand_err:
|
|
||||||
logger.warning(f"[Retrieve] 社区展开检索失败,跳过: {expand_err}")
|
|
||||||
finally:
|
|
||||||
await expand_connector.close()
|
|
||||||
except Exception as parse_err:
|
except Exception as parse_err:
|
||||||
logger.warning(f"[Retrieve] 解析社区命中结果失败,跳过展开: {parse_err}")
|
logger.warning(f"[Retrieve] 解析社区命中结果失败,跳过展开: {parse_err}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
raw_results = raw_results['results']
|
raw_results = raw_results['results']
|
||||||
# 写回展开结果,过一遍字段清洗去掉 DateTime 等不可序列化字段
|
# 写回展开结果,接口返回中可见(已在 helper 中清洗过字段)
|
||||||
if _expanded_stmts_to_write and isinstance(raw_results, dict):
|
if _expanded_stmts_to_write and isinstance(raw_results, dict):
|
||||||
_fields_to_remove = {
|
raw_results.setdefault('reranked_results', {})['expanded_statements'] = _expanded_stmts_to_write
|
||||||
'invalid_at', 'valid_at', 'chunk_id_from_rel', 'entity_ids',
|
|
||||||
'expired_at', 'created_at', 'chunk_id', 'apply_id',
|
|
||||||
'user_id', 'statement_ids', 'updated_at', 'chunk_ids', 'fact_summary'
|
|
||||||
}
|
|
||||||
def _clean(obj):
|
|
||||||
if isinstance(obj, dict):
|
|
||||||
return {k: _clean(v) for k, v in obj.items() if k not in _fields_to_remove}
|
|
||||||
if isinstance(obj, list):
|
|
||||||
return [_clean(i) for i in obj]
|
|
||||||
return obj
|
|
||||||
raw_results.setdefault('reranked_results', {})['expanded_statements'] = _clean(_expanded_stmts_to_write)
|
|
||||||
except Exception:
|
except Exception:
|
||||||
raw_results = []
|
raw_results = []
|
||||||
|
|
||||||
|
|||||||
@@ -348,8 +348,8 @@ async def Input_Summary(state: ReadState) -> ReadState:
|
|||||||
if raw_results and isinstance(raw_results, dict):
|
if raw_results and isinstance(raw_results, dict):
|
||||||
reranked = raw_results.get('reranked_results', {})
|
reranked = raw_results.get('reranked_results', {})
|
||||||
community_hits = reranked.get('communities', [])
|
community_hits = reranked.get('communities', [])
|
||||||
logger.info(f"[Input_Summary] community 命中数: {len(community_hits)}, "
|
logger.debug(f"[Input_Summary] community 命中数: {len(community_hits)}, "
|
||||||
f"summary 命中数: {len(reranked.get('summaries', []))}")
|
f"summary 命中数: {len(reranked.get('summaries', []))}")
|
||||||
else:
|
else:
|
||||||
retrieval_knowledge, retrieve_info, question, raw_results = await rag_knowledge(state, data)
|
retrieval_knowledge, retrieve_info, question, raw_results = await rag_knowledge(state, data)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|||||||
@@ -13,6 +13,72 @@ from app.core.memory.utils.data.text_utils import escape_lucene_query
|
|||||||
|
|
||||||
logger = get_agent_logger(__name__)
|
logger = get_agent_logger(__name__)
|
||||||
|
|
||||||
|
# 需要从展开结果中过滤的字段(含 Neo4j DateTime,不可 JSON 序列化)
|
||||||
|
_EXPAND_FIELDS_TO_REMOVE = {
|
||||||
|
'invalid_at', 'valid_at', 'chunk_id_from_rel', 'entity_ids',
|
||||||
|
'expired_at', 'created_at', 'chunk_id', 'apply_id',
|
||||||
|
'user_id', 'statement_ids', 'updated_at', 'chunk_ids', 'fact_summary'
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _clean_expand_fields(obj):
|
||||||
|
"""递归过滤展开结果中不可序列化的字段(DateTime 等)。"""
|
||||||
|
if isinstance(obj, dict):
|
||||||
|
return {k: _clean_expand_fields(v) for k, v in obj.items() if k not in _EXPAND_FIELDS_TO_REMOVE}
|
||||||
|
if isinstance(obj, list):
|
||||||
|
return [_clean_expand_fields(i) for i in obj]
|
||||||
|
return obj
|
||||||
|
|
||||||
|
|
||||||
|
async def expand_communities_to_statements(
|
||||||
|
community_results: List[dict],
|
||||||
|
end_user_id: str,
|
||||||
|
existing_content: str = "",
|
||||||
|
limit: int = 10,
|
||||||
|
) -> Tuple[List[dict], List[str]]:
|
||||||
|
"""
|
||||||
|
社区展开 helper:给定命中的 community 列表,拉取关联 Statement。
|
||||||
|
|
||||||
|
- 对展开结果去重(过滤已在 existing_content 中出现的文本)
|
||||||
|
- 过滤不可序列化字段
|
||||||
|
- 返回 (cleaned_expanded_stmts, new_texts)
|
||||||
|
- cleaned_expanded_stmts: 可直接写回 raw_results 的列表
|
||||||
|
- new_texts: 去重后新增的 statement 文本列表,用于追加到 clean_content
|
||||||
|
"""
|
||||||
|
community_ids = [r.get("id") for r in community_results if r.get("id")]
|
||||||
|
if not community_ids or not end_user_id:
|
||||||
|
return [], []
|
||||||
|
|
||||||
|
from app.repositories.neo4j.graph_search import search_graph_community_expand
|
||||||
|
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||||
|
|
||||||
|
connector = Neo4jConnector()
|
||||||
|
try:
|
||||||
|
result = await search_graph_community_expand(
|
||||||
|
connector=connector,
|
||||||
|
community_ids=community_ids,
|
||||||
|
end_user_id=end_user_id,
|
||||||
|
limit=limit,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"[expand_communities] 社区展开检索失败,跳过: {e}")
|
||||||
|
return [], []
|
||||||
|
finally:
|
||||||
|
await connector.close()
|
||||||
|
|
||||||
|
expanded_stmts = result.get("expanded_statements", [])
|
||||||
|
if not expanded_stmts:
|
||||||
|
return [], []
|
||||||
|
|
||||||
|
existing_lines = set(existing_content.splitlines())
|
||||||
|
new_texts = [
|
||||||
|
s["statement"] for s in expanded_stmts
|
||||||
|
if s.get("statement") and s["statement"] not in existing_lines
|
||||||
|
]
|
||||||
|
cleaned = _clean_expand_fields(expanded_stmts)
|
||||||
|
logger.info(f"[expand_communities] 展开 {len(expanded_stmts)} 条 statements,新增 {len(new_texts)} 条,community_ids={community_ids}")
|
||||||
|
return cleaned, new_texts
|
||||||
|
|
||||||
|
|
||||||
class SearchService:
|
class SearchService:
|
||||||
"""Service for executing hybrid search and processing results."""
|
"""Service for executing hybrid search and processing results."""
|
||||||
@@ -190,28 +256,11 @@ class SearchService:
|
|||||||
if search_type == "hybrid"
|
if search_type == "hybrid"
|
||||||
else answer.get('communities', [])
|
else answer.get('communities', [])
|
||||||
)
|
)
|
||||||
community_ids = [
|
cleaned_stmts, new_texts = await expand_communities_to_statements(
|
||||||
r.get("id") for r in community_results if r.get("id")
|
community_results=community_results,
|
||||||
]
|
end_user_id=end_user_id,
|
||||||
if community_ids and end_user_id:
|
)
|
||||||
from app.repositories.neo4j.graph_search import search_graph_community_expand
|
answer_list.extend(cleaned_stmts)
|
||||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
|
||||||
expand_connector = Neo4jConnector()
|
|
||||||
try:
|
|
||||||
expand_result = await search_graph_community_expand(
|
|
||||||
connector=expand_connector,
|
|
||||||
community_ids=community_ids,
|
|
||||||
end_user_id=end_user_id,
|
|
||||||
limit=10,
|
|
||||||
)
|
|
||||||
expanded_stmts = expand_result.get("expanded_statements", [])
|
|
||||||
if expanded_stmts:
|
|
||||||
answer_list.extend(expanded_stmts)
|
|
||||||
logger.info(f"社区展开检索追加 {len(expanded_stmts)} 条 statements")
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"社区展开检索失败,跳过: {e}")
|
|
||||||
finally:
|
|
||||||
await expand_connector.close()
|
|
||||||
|
|
||||||
# Extract clean content from all results,按类型传入 node_type 区分 community
|
# Extract clean content from all results,按类型传入 node_type 区分 community
|
||||||
content_list = []
|
content_list = []
|
||||||
|
|||||||
@@ -727,6 +727,11 @@ async def run_hybrid_search(
|
|||||||
keyword_task = None
|
keyword_task = None
|
||||||
embedding_task = None
|
embedding_task = None
|
||||||
|
|
||||||
|
keyword_task = None
|
||||||
|
embedding_task = None
|
||||||
|
keyword_results: Dict[str, List] = {}
|
||||||
|
embedding_results: Dict[str, List] = {}
|
||||||
|
|
||||||
if search_type in ["keyword", "hybrid"]:
|
if search_type in ["keyword", "hybrid"]:
|
||||||
# Keyword-based search
|
# Keyword-based search
|
||||||
logger.info("[PERF] Starting keyword search...")
|
logger.info("[PERF] Starting keyword search...")
|
||||||
|
|||||||
Reference in New Issue
Block a user