Fix/develop memory rag (#419)
* fix_rag/fast summary * fix_rag/fast summary
This commit is contained in:
@@ -17,6 +17,8 @@ from app.core.memory.agent.utils.llm_tools import (
|
|||||||
from app.core.memory.agent.utils.redis_tool import store
|
from app.core.memory.agent.utils.redis_tool import store
|
||||||
from app.core.memory.agent.utils.session_tools import SessionService
|
from app.core.memory.agent.utils.session_tools import SessionService
|
||||||
from app.core.memory.agent.utils.template_tools import TemplateService
|
from app.core.memory.agent.utils.template_tools import TemplateService
|
||||||
|
from app.core.rag.nlp.search import knowledge_retrieval
|
||||||
|
|
||||||
from app.db import get_db
|
from app.db import get_db
|
||||||
|
|
||||||
template_root = os.path.join(PROJECT_ROOT_, 'memory', 'agent', 'utils', 'prompt')
|
template_root = os.path.join(PROJECT_ROOT_, 'memory', 'agent', 'utils', 'prompt')
|
||||||
@@ -32,6 +34,41 @@ class SummaryNodeService(LLMServiceMixin):
|
|||||||
|
|
||||||
# 创建全局服务实例
|
# 创建全局服务实例
|
||||||
summary_service = SummaryNodeService()
|
summary_service = SummaryNodeService()
|
||||||
|
async def rag_config(state):
|
||||||
|
user_rag_memory_id = state.get('user_rag_memory_id', '')
|
||||||
|
kb_config = {
|
||||||
|
"knowledge_bases": [
|
||||||
|
{
|
||||||
|
"kb_id": user_rag_memory_id,
|
||||||
|
"similarity_threshold": 0.7,
|
||||||
|
"vector_similarity_weight": 0.5,
|
||||||
|
"top_k": 10,
|
||||||
|
"retrieve_type": "participle"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"merge_strategy": "weight",
|
||||||
|
"reranker_id": os.getenv('reranker_id'),
|
||||||
|
"reranker_top_k": 10
|
||||||
|
}
|
||||||
|
return kb_config
|
||||||
|
async def rag_knowledge(state,question):
|
||||||
|
kb_config = await rag_config(state)
|
||||||
|
end_user_id = state.get('end_user_id', '')
|
||||||
|
user_rag_memory_id=state.get("user_rag_memory_id",'')
|
||||||
|
retrieve_chunks_result = knowledge_retrieval(question, kb_config, [str(end_user_id)])
|
||||||
|
try:
|
||||||
|
retrieval_knowledge = [i.page_content for i in retrieve_chunks_result]
|
||||||
|
clean_content = '\n\n'.join(retrieval_knowledge)
|
||||||
|
cleaned_query = question
|
||||||
|
raw_results = clean_content
|
||||||
|
logger.info(f" Using RAG storage with memory_id={user_rag_memory_id}")
|
||||||
|
except Exception :
|
||||||
|
retrieval_knowledge=[]
|
||||||
|
clean_content = ''
|
||||||
|
raw_results = ''
|
||||||
|
cleaned_query = question
|
||||||
|
logger.info(f"No content retrieved from knowledge base: {user_rag_memory_id}")
|
||||||
|
return retrieval_knowledge,clean_content,cleaned_query,raw_results
|
||||||
|
|
||||||
async def summary_history(state: ReadState) -> ReadState:
|
async def summary_history(state: ReadState) -> ReadState:
|
||||||
end_user_id = state.get("end_user_id", '')
|
end_user_id = state.get("end_user_id", '')
|
||||||
@@ -71,7 +108,7 @@ async def summary_llm(state: ReadState, history, retrieve_info, template_name, o
|
|||||||
)
|
)
|
||||||
# 验证结构化响应
|
# 验证结构化响应
|
||||||
if structured is None:
|
if structured is None:
|
||||||
logger.warning(f"LLM返回None,使用默认回答")
|
logger.warning("LLM返回None,使用默认回答")
|
||||||
return "信息不足,无法回答"
|
return "信息不足,无法回答"
|
||||||
|
|
||||||
# 根据操作类型提取答案
|
# 根据操作类型提取答案
|
||||||
@@ -82,7 +119,7 @@ async def summary_llm(state: ReadState, history, retrieve_info, template_name, o
|
|||||||
if hasattr(structured, 'data') and structured.data:
|
if hasattr(structured, 'data') and structured.data:
|
||||||
aimessages = getattr(structured.data, 'query_answer', None) or "信息不足,无法回答"
|
aimessages = getattr(structured.data, 'query_answer', None) or "信息不足,无法回答"
|
||||||
else:
|
else:
|
||||||
logger.warning(f"结构化响应缺少data字段")
|
logger.warning("结构化响应缺少data字段")
|
||||||
aimessages = "信息不足,无法回答"
|
aimessages = "信息不足,无法回答"
|
||||||
|
|
||||||
# 验证答案不为空
|
# 验证答案不为空
|
||||||
@@ -186,12 +223,13 @@ async def Input_Summary(state: ReadState) -> ReadState:
|
|||||||
}
|
}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
retrieve_info, question, raw_results = await SearchService().execute_hybrid_search(**search_params, memory_config=memory_config)
|
if storage_type!="rag":
|
||||||
|
retrieve_info, question, raw_results = await SearchService().execute_hybrid_search(**search_params, memory_config=memory_config)
|
||||||
|
else:
|
||||||
|
retrieval_knowledge, retrieve_info, question, raw_results = await rag_knowledge(state, data)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error( f"Input_Summary: hybrid_search failed, using empty results: {e}", exc_info=True )
|
logger.error( f"Input_Summary: hybrid_search failed, using empty results: {e}", exc_info=True )
|
||||||
retrieve_info, question, raw_results = "", data, []
|
retrieve_info, question, raw_results = "", data, []
|
||||||
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# aimessages=await summary_llm(state,history,retrieve_info,'Retrieve_Summary_prompt.jinja2',
|
# aimessages=await summary_llm(state,history,retrieve_info,'Retrieve_Summary_prompt.jinja2',
|
||||||
# 'input_summary',RetrieveSummaryResponse)
|
# 'input_summary',RetrieveSummaryResponse)
|
||||||
@@ -290,7 +328,6 @@ async def Summary(state: ReadState)-> ReadState:
|
|||||||
summary_result = await summary_prompt(state, aimessages, retrieve_info_str)
|
summary_result = await summary_prompt(state, aimessages, retrieve_info_str)
|
||||||
summary = summary_result[1]
|
summary = summary_result[1]
|
||||||
return {"summary":summary}
|
return {"summary":summary}
|
||||||
|
|
||||||
async def Summary_fails(state: ReadState)-> ReadState:
|
async def Summary_fails(state: ReadState)-> ReadState:
|
||||||
storage_type=state.get("storage_type", '')
|
storage_type=state.get("storage_type", '')
|
||||||
user_rag_memory_id=state.get("user_rag_memory_id", '')
|
user_rag_memory_id=state.get("user_rag_memory_id", '')
|
||||||
|
|||||||
Reference in New Issue
Block a user