From af7b9ee41c448a359a5d0c22f29e35ee79195889 Mon Sep 17 00:00:00 2001 From: lixinyue11 <94037597+lixinyue11@users.noreply.github.com> Date: Tue, 20 Jan 2026 19:14:59 +0800 Subject: [PATCH] Fix/memory bug fix (#161) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * 图谱数据量限制数量去掉 * 图谱数据量限制数量去掉 * 图谱数据量限制数量去掉 * 用户详情优化 * 用户详情优化 * 用户详情优化 * 用户详情优化 * 用户详情优化 * 用户详情优化 * 读取的接口,去掉全局锁 * 输出数组 * 反思优化1.0(优化隐私输出、时间检索) * 反思优化1.0(优化隐私输出、时间检索) * 反思优化1.0(优化隐私输出、时间检索) * 反思优化测试接口 * 反思优化测试接口 * 读取接口内层嵌套BUG修复 * 读取接口内层嵌套BUG修复 * 读取接口内层嵌套BUG修复 * 读取接口内层嵌套BUG修复 * 读取接口内层嵌套BUG修复 --- .../langgraph_graph/nodes/problem_nodes.py | 17 +- .../langgraph_graph/nodes/summary_nodes.py | 3 +- .../nodes/verification_nodes.py | 148 +++++++++++++----- .../agent/langgraph_graph/read_graph.py | 1 + .../agent/langgraph_graph/routing/routers.py | 23 ++- .../agent/models/verification_models.py | 28 +++- .../agent/services/optimized_llm_service.py | 4 +- .../utils/prompt/split_verify_prompt.jinja2 | 44 ++++-- 8 files changed, 191 insertions(+), 77 deletions(-) diff --git a/api/app/core/memory/agent/langgraph_graph/nodes/problem_nodes.py b/api/app/core/memory/agent/langgraph_graph/nodes/problem_nodes.py index 0c68a47e..e02ef62b 100644 --- a/api/app/core/memory/agent/langgraph_graph/nodes/problem_nodes.py +++ b/api/app/core/memory/agent/langgraph_graph/nodes/problem_nodes.py @@ -1,3 +1,4 @@ +import os import json import time from app.core.logging_config import get_agent_logger @@ -13,7 +14,7 @@ from app.core.memory.agent.utils.session_tools import SessionService from app.core.memory.agent.utils.template_tools import TemplateService from app.core.memory.agent.services.optimized_llm_service import LLMServiceMixin -template_root = PROJECT_ROOT_ + '/agent/utils/prompt' +template_root = os.path.join(PROJECT_ROOT_, 'agent', 'utils', 'prompt') db_session = next(get_db()) logger = get_agent_logger(__name__) @@ -35,11 +36,16 @@ async def Split_The_Problem(state: ReadState) -> ReadState: memory_config = state.get('memory_config', None) history = await SessionService(store).get_history(group_id, group_id, group_id) + + # 生成 JSON schema 以指导 LLM 输出正确格式 + json_schema = ProblemExtensionResponse.model_json_schema() + system_prompt = await problem_service.template_service.render_template( template_name='problem_breakdown_prompt.jinja2', operation_name='split_the_problem', history=history, - sentence=content + sentence=content, + json_schema=json_schema ) try: @@ -147,11 +153,16 @@ async def Problem_Extension(state: ReadState) -> ReadState: data = [] history = await SessionService(store).get_history(group_id, group_id, group_id) + + # 生成 JSON schema 以指导 LLM 输出正确格式 + json_schema = ProblemExtensionResponse.model_json_schema() + system_prompt = await problem_service.template_service.render_template( template_name='Problem_Extension_prompt.jinja2', operation_name='problem_extension', history=history, - questions=databasets + questions=databasets, + json_schema=json_schema ) try: diff --git a/api/app/core/memory/agent/langgraph_graph/nodes/summary_nodes.py b/api/app/core/memory/agent/langgraph_graph/nodes/summary_nodes.py index 7b727da5..0d0b57b0 100644 --- a/api/app/core/memory/agent/langgraph_graph/nodes/summary_nodes.py +++ b/api/app/core/memory/agent/langgraph_graph/nodes/summary_nodes.py @@ -1,5 +1,6 @@ +import os import time from app.core.logging_config import get_agent_logger, log_time @@ -19,7 +20,7 @@ from app.core.memory.agent.utils.session_tools import SessionService from app.core.memory.agent.utils.template_tools import TemplateService from app.core.memory.agent.services.optimized_llm_service import LLMServiceMixin -template_root = PROJECT_ROOT_ + '/agent/utils/prompt' +template_root = os.path.join(PROJECT_ROOT_, 'agent', 'utils', 'prompt') logger = get_agent_logger(__name__) db_session = next(get_db()) diff --git a/api/app/core/memory/agent/langgraph_graph/nodes/verification_nodes.py b/api/app/core/memory/agent/langgraph_graph/nodes/verification_nodes.py index f3a39afb..dac7ea14 100644 --- a/api/app/core/memory/agent/langgraph_graph/nodes/verification_nodes.py +++ b/api/app/core/memory/agent/langgraph_graph/nodes/verification_nodes.py @@ -1,4 +1,4 @@ - +import os from app.core.logging_config import get_agent_logger from app.db import get_db @@ -12,7 +12,7 @@ from app.core.memory.agent.utils.session_tools import SessionService from app.core.memory.agent.utils.template_tools import TemplateService from app.core.memory.agent.services.optimized_llm_service import LLMServiceMixin -template_root = PROJECT_ROOT_ + '/agent/utils/prompt' +template_root = os.path.join(PROJECT_ROOT_, 'agent', 'utils', 'prompt') db_session = next(get_db()) logger = get_agent_logger(__name__) @@ -26,60 +26,130 @@ class VerificationNodeService(LLMServiceMixin): # 创建全局服务实例 verification_service = VerificationNodeService() -async def Verify_prompt(state: ReadState,messages_deal): +async def Verify_prompt(state: ReadState, messages_deal: VerificationResult): + """处理验证结果并生成输出格式""" storage_type = state.get('storage_type', '') user_rag_memory_id = state.get('user_rag_memory_id', '') data = state.get('data', '') + + # 将 VerificationItem 对象转换为字典列表 + verified_data = [] + if messages_deal.expansion_issue: + for item in messages_deal.expansion_issue: + if hasattr(item, 'model_dump'): + verified_data.append(item.model_dump()) + elif isinstance(item, dict): + verified_data.append(item) + Verify_result = { "status": messages_deal.split_result, - "verified_data": messages_deal.expansion_issue, + "verified_data": verified_data, "storage_type": storage_type, "user_rag_memory_id": user_rag_memory_id, "_intermediate": { "type": "verification", "title": "Data Verification", "result": messages_deal.split_result, - "reason": messages_deal.reason, - "query": data, - "verified_count": len(messages_deal.expansion_issue), + "reason": messages_deal.reason or "验证完成", + "query": messages_deal.query, + "verified_count": len(verified_data), "storage_type": storage_type, "user_rag_memory_id": user_rag_memory_id } } return Verify_result async def Verify(state: ReadState): - content = state.get('data', '') - group_id = state.get('group_id', '') - memory_config = state.get('memory_config', None) + logger.info("=== Verify 节点开始执行 ===") + try: + content = state.get('data', '') + group_id = state.get('group_id', '') + memory_config = state.get('memory_config', None) + + logger.info(f"Verify: content={content[:50] if content else 'empty'}..., group_id={group_id}") - history = await SessionService(store).get_history(group_id, group_id, group_id) + history = await SessionService(store).get_history(group_id, group_id, group_id) + logger.info(f"Verify: 获取历史记录完成,history length={len(history)}") - retrieve = state.get("retrieve", '') - retrieve = retrieve.get("Expansion_issue", []) - messages = { - "Query": content, - "Expansion_issue": retrieve - } - - system_prompt = await verification_service.template_service.render_template( - template_name='split_verify_prompt.jinja2', - operation_name='split_verify_prompt', - history=history, - sentence=messages - ) - - # 使用优化的LLM服务 - structured = await verification_service.call_llm_structured( - state=state, - db_session=db_session, - system_prompt=system_prompt, - response_model=VerificationResult, - fallback_value={ - "split_result": "fail", - "expansion_issue": [], - "reason": "验证失败" + retrieve = state.get("retrieve", {}) + logger.info(f"Verify: retrieve data type={type(retrieve)}, keys={retrieve.keys() if isinstance(retrieve, dict) else 'N/A'}") + + retrieve_expansion = retrieve.get("Expansion_issue", []) if isinstance(retrieve, dict) else [] + logger.info(f"Verify: Expansion_issue length={len(retrieve_expansion)}") + + messages = { + "Query": content, + "Expansion_issue": retrieve_expansion } - ) - - result = await Verify_prompt(state, structured) - return {"verify": result} \ No newline at end of file + + logger.info("Verify: 开始渲染模板") + + # 生成 JSON schema 以指导 LLM 输出正确格式 + json_schema = VerificationResult.model_json_schema() + + system_prompt = await verification_service.template_service.render_template( + template_name='split_verify_prompt.jinja2', + operation_name='split_verify_prompt', + history=history, + sentence=messages, + json_schema=json_schema + ) + logger.info(f"Verify: 模板渲染完成,prompt length={len(system_prompt)}") + + # 使用优化的LLM服务,添加超时保护 + logger.info("Verify: 开始调用 LLM") + try: + # 添加 asyncio.wait_for 超时包裹,防止无限等待 + # 超时时间设置为 150 秒(比 LLM 配置的 120 秒稍长) + import asyncio + structured = await asyncio.wait_for( + verification_service.call_llm_structured( + state=state, + db_session=db_session, + system_prompt=system_prompt, + response_model=VerificationResult, + fallback_value={ + "query": content, + "history": history if isinstance(history, list) else [], + "expansion_issue": [], + "split_result": "failed", + "reason": "验证失败或超时" + } + ), + timeout=150.0 # 150秒超时 + ) + logger.info(f"Verify: LLM 调用完成,result={structured}") + except asyncio.TimeoutError: + logger.error("Verify: LLM 调用超时(150秒),使用 fallback 值") + structured = VerificationResult( + query=content, + history=history if isinstance(history, list) else [], + expansion_issue=[], + split_result="failed", + reason="LLM调用超时" + ) + + result = await Verify_prompt(state, structured) + logger.info("=== Verify 节点执行完成 ===") + return {"verify": result} + + except Exception as e: + logger.error(f"Verify 节点执行失败: {e}", exc_info=True) + # 返回失败的验证结果 + return { + "verify": { + "status": "failed", + "verified_data": [], + "storage_type": state.get('storage_type', ''), + "user_rag_memory_id": state.get('user_rag_memory_id', ''), + "_intermediate": { + "type": "verification", + "title": "Data Verification", + "result": "failed", + "reason": f"验证过程出错: {str(e)}", + "query": state.get('data', ''), + "verified_count": 0, + "storage_type": state.get('storage_type', ''), + "user_rag_memory_id": state.get('user_rag_memory_id', '') + } + } + } \ No newline at end of file diff --git a/api/app/core/memory/agent/langgraph_graph/read_graph.py b/api/app/core/memory/agent/langgraph_graph/read_graph.py index 19011a5f..c01889a9 100644 --- a/api/app/core/memory/agent/langgraph_graph/read_graph.py +++ b/api/app/core/memory/agent/langgraph_graph/read_graph.py @@ -59,6 +59,7 @@ async def make_read_graph(): workflow.add_conditional_edges("Retrieve", Retrieve_continue) workflow.add_edge("Retrieve_Summary", END) workflow.add_conditional_edges("Verify", Verify_continue) + workflow.add_edge("Summary_fails", END) workflow.add_edge("Summary", END) diff --git a/api/app/core/memory/agent/langgraph_graph/routing/routers.py b/api/app/core/memory/agent/langgraph_graph/routing/routers.py index c0b01be1..004e03b3 100644 --- a/api/app/core/memory/agent/langgraph_graph/routing/routers.py +++ b/api/app/core/memory/agent/langgraph_graph/routing/routers.py @@ -45,18 +45,17 @@ def Retrieve_continue(state) -> Literal["Verify", "Retrieve_Summary"]: return 'Retrieve_Summary' # Default based on business logic def Verify_continue(state: ReadState) -> Literal["Summary", "Summary_fails", "content_input"]: status=state.get('verify', '')['status'] - loop_count = counter.get_total() - print(status) + # loop_count = counter.get_total() if "success" in status: - counter.reset() + # counter.reset() return "Summary" elif "failed" in status: - if loop_count < 2: # Maximum loop count is 3 - return "content_input" - else: - counter.reset() - return "Summary_fails" - # else: - # # Add default return value to avoid returning None - # counter.reset() - # return "Summary" # Default based on business requirements + # if loop_count < 2: # Maximum loop count is 3 + # return "content_input" + # else: + # counter.reset() + return "Summary_fails" + else: + # Add default return value to avoid returning None + # counter.reset() + return "Summary" # Default based on business requirements diff --git a/api/app/core/memory/agent/models/verification_models.py b/api/app/core/memory/agent/models/verification_models.py index bd8896b3..abdce040 100644 --- a/api/app/core/memory/agent/models/verification_models.py +++ b/api/app/core/memory/agent/models/verification_models.py @@ -4,11 +4,29 @@ from typing import List, Optional, Dict, Any from pydantic import BaseModel, Field +class VerificationItem(BaseModel): + """Individual verification item for a query-answer pair.""" + + query_small: str = Field(..., description="子问题") + answer_small: str = Field(..., description="子问题的回答") + status: str = Field(..., description="验证状态:True 或 False") + query_answer: str = Field(..., description="问题的答案(与 answer_small 相同)") + + class VerificationResult(BaseModel): """Result model for verification operation.""" - query: str - expansion_issue: List[Dict[str, Any]] - split_result: str - reason: Optional[str] = None - history: List[Dict[str, Any]] = Field(default_factory=list) + query: str = Field(..., description="原始查询问题") + history: List[Dict[str, Any]] = Field(default_factory=list, description="历史对话记录") + expansion_issue: List[VerificationItem] = Field( + default_factory=list, + description="验证后的数据列表,包含所有通过验证的问答对" + ) + split_result: str = Field( + ..., + description="验证结果状态:success(expansion_issue 非空)或 failed(expansion_issue 为空)" + ) + reason: Optional[str] = Field( + None, + description="验证结果的说明和分析" + ) diff --git a/api/app/core/memory/agent/services/optimized_llm_service.py b/api/app/core/memory/agent/services/optimized_llm_service.py index 6942d421..68919c4a 100644 --- a/api/app/core/memory/agent/services/optimized_llm_service.py +++ b/api/app/core/memory/agent/services/optimized_llm_service.py @@ -162,7 +162,7 @@ class OptimizedLLMService: return fallback_value elif isinstance(fallback_value, dict): return response_model(**fallback_value) - + # 尝试创建空的响应模型 if hasattr(response_model, 'root'): # RootModel类型 @@ -170,7 +170,7 @@ class OptimizedLLMService: else: # 普通BaseModel类型 return response_model() - + except Exception as e: logger.error(f"创建降级响应失败: {e}") # 最后的降级策略 diff --git a/api/app/core/memory/agent/utils/prompt/split_verify_prompt.jinja2 b/api/app/core/memory/agent/utils/prompt/split_verify_prompt.jinja2 index d6ad8cab..5d10304a 100644 --- a/api/app/core/memory/agent/utils/prompt/split_verify_prompt.jinja2 +++ b/api/app/core/memory/agent/utils/prompt/split_verify_prompt.jinja2 @@ -42,19 +42,33 @@ 如果状态是TRUE保留这条数据,否则需不需要这条数据 ### 第五步 输出格式 按照json的形式输出 -{"data":"Query":原来Query的字段,"history":原来的history字段, -"expansion_issue":以为列表的形式存储验证之后的数据比如[ -{"query_small": query_small, - "answer_small": answer_small,, - "status": 回答的结果是否符合query_small,填写状态, - "query_answer": answer_small}, +{"query":"原来Query的字段", +"history":"原来的history字段", +"expansion_issue":以列表的形式存储验证之后的数据比如[ { - "query_small": "张曼婷生日是什么时候?", - "answer_small": "张曼婷喜欢绘画。", - "status": "True", - "query_answer": "张曼 婷喜欢绘画。" - },{}......] -, - "split_result":如果expansion_issue是空的列表返回failed,不是空列表返回success, - "reason": 为以上分析完之后的结果给一个说明 - } \ No newline at end of file + "query_small": "子问题", + "answer_small": "子问题的回答", + "status": "True或False,表示回答是否符合query_small", + "query_answer": "问题的答案(与answer_small相同)" +}, +{ + "query_small": "张曼婷生日是什么时候?", + "answer_small": "张曼婷喜欢绘画。", + "status": "False", + "query_answer": "张曼婷喜欢绘画。" +} +], +"split_result":"如果expansion_issue是空的列表返回failed,不是空列表返回success", +"reason": "为以上分析完之后的结果给一个说明" +} + +**输出格式要求** +**CRITICAL JSON FORMATTING REQUIREMENTS:** +1. Use only standard ASCII double quotes (") for JSON structure - never use Chinese quotation marks ("") or other Unicode quotes +2. If the extracted statement text contains quotation marks, escape them properly using backslashes (\") +3. Ensure all JSON strings are properly closed and comma-separated +4. Do not include line breaks within JSON string values +5. The output language should always be the same as the input language + +**JSON Schema:** +{{ json_schema }} \ No newline at end of file