From 6718553bf4f4ceb105734a98a8b69645e00a34ff Mon Sep 17 00:00:00 2001 From: lixinyue11 <94037597+lixinyue11@users.noreply.github.com> Date: Sat, 28 Feb 2026 18:47:08 +0800 Subject: [PATCH] Fix/develop memory rag (#419) * fix_rag/fast summary * fix_rag/fast summary --- .../langgraph_graph/nodes/summary_nodes.py | 49 ++++++++++++++++--- 1 file changed, 43 insertions(+), 6 deletions(-) diff --git a/api/app/core/memory/agent/langgraph_graph/nodes/summary_nodes.py b/api/app/core/memory/agent/langgraph_graph/nodes/summary_nodes.py index 0144c0e9..cf832add 100644 --- a/api/app/core/memory/agent/langgraph_graph/nodes/summary_nodes.py +++ b/api/app/core/memory/agent/langgraph_graph/nodes/summary_nodes.py @@ -17,6 +17,8 @@ from app.core.memory.agent.utils.llm_tools import ( from app.core.memory.agent.utils.redis_tool import store from app.core.memory.agent.utils.session_tools import SessionService from app.core.memory.agent.utils.template_tools import TemplateService +from app.core.rag.nlp.search import knowledge_retrieval + from app.db import get_db template_root = os.path.join(PROJECT_ROOT_, 'memory', 'agent', 'utils', 'prompt') @@ -32,6 +34,41 @@ class SummaryNodeService(LLMServiceMixin): # 创建全局服务实例 summary_service = SummaryNodeService() +async def rag_config(state): + user_rag_memory_id = state.get('user_rag_memory_id', '') + kb_config = { + "knowledge_bases": [ + { + "kb_id": user_rag_memory_id, + "similarity_threshold": 0.7, + "vector_similarity_weight": 0.5, + "top_k": 10, + "retrieve_type": "participle" + } + ], + "merge_strategy": "weight", + "reranker_id": os.getenv('reranker_id'), + "reranker_top_k": 10 + } + return kb_config +async def rag_knowledge(state,question): + kb_config = await rag_config(state) + end_user_id = state.get('end_user_id', '') + user_rag_memory_id=state.get("user_rag_memory_id",'') + retrieve_chunks_result = knowledge_retrieval(question, kb_config, [str(end_user_id)]) + try: + retrieval_knowledge = [i.page_content for i in retrieve_chunks_result] + clean_content = '\n\n'.join(retrieval_knowledge) + cleaned_query = question + raw_results = clean_content + logger.info(f" Using RAG storage with memory_id={user_rag_memory_id}") + except Exception : + retrieval_knowledge=[] + clean_content = '' + raw_results = '' + cleaned_query = question + logger.info(f"No content retrieved from knowledge base: {user_rag_memory_id}") + return retrieval_knowledge,clean_content,cleaned_query,raw_results async def summary_history(state: ReadState) -> ReadState: end_user_id = state.get("end_user_id", '') @@ -71,7 +108,7 @@ async def summary_llm(state: ReadState, history, retrieve_info, template_name, o ) # 验证结构化响应 if structured is None: - logger.warning(f"LLM返回None,使用默认回答") + logger.warning("LLM返回None,使用默认回答") return "信息不足,无法回答" # 根据操作类型提取答案 @@ -82,7 +119,7 @@ async def summary_llm(state: ReadState, history, retrieve_info, template_name, o if hasattr(structured, 'data') and structured.data: aimessages = getattr(structured.data, 'query_answer', None) or "信息不足,无法回答" else: - logger.warning(f"结构化响应缺少data字段") + logger.warning("结构化响应缺少data字段") aimessages = "信息不足,无法回答" # 验证答案不为空 @@ -186,12 +223,13 @@ async def Input_Summary(state: ReadState) -> ReadState: } try: - retrieve_info, question, raw_results = await SearchService().execute_hybrid_search(**search_params, memory_config=memory_config) + if storage_type!="rag": + retrieve_info, question, raw_results = await SearchService().execute_hybrid_search(**search_params, memory_config=memory_config) + else: + retrieval_knowledge, retrieve_info, question, raw_results = await rag_knowledge(state, data) except Exception as e: logger.error( f"Input_Summary: hybrid_search failed, using empty results: {e}", exc_info=True ) retrieve_info, question, raw_results = "", data, [] - - try: # aimessages=await summary_llm(state,history,retrieve_info,'Retrieve_Summary_prompt.jinja2', # 'input_summary',RetrieveSummaryResponse) @@ -290,7 +328,6 @@ async def Summary(state: ReadState)-> ReadState: summary_result = await summary_prompt(state, aimessages, retrieve_info_str) summary = summary_result[1] return {"summary":summary} - async def Summary_fails(state: ReadState)-> ReadState: storage_type=state.get("storage_type", '') user_rag_memory_id=state.get("user_rag_memory_id", '')