From a634565296d677fe5357ed0036d1ea9d34895e82 Mon Sep 17 00:00:00 2001 From: lixinyue <2569494688@qq.com> Date: Tue, 20 Jan 2026 18:46:53 +0800 Subject: [PATCH] =?UTF-8?q?=E8=AF=BB=E5=8F=96=E6=8E=A5=E5=8F=A3=E5=86=85?= =?UTF-8?q?=E5=B1=82=E5=B5=8C=E5=A5=97BUG=E4=BF=AE=E5=A4=8D?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../langgraph_graph/nodes/problem_nodes.py | 17 +- .../langgraph_graph/nodes/summary_nodes.py | 3 +- .../agent/langgraph_graph/read_graph.py | 1 + .../agent/langgraph_graph/routing/routers.py | 3 + .../agent/services/optimized_llm_service.py | 29 +- .../core/memory/llm_tools/openai_client.py | 297 ++++++++++++++++-- 6 files changed, 311 insertions(+), 39 deletions(-) diff --git a/api/app/core/memory/agent/langgraph_graph/nodes/problem_nodes.py b/api/app/core/memory/agent/langgraph_graph/nodes/problem_nodes.py index 0c68a47e..e02ef62b 100644 --- a/api/app/core/memory/agent/langgraph_graph/nodes/problem_nodes.py +++ b/api/app/core/memory/agent/langgraph_graph/nodes/problem_nodes.py @@ -1,3 +1,4 @@ +import os import json import time from app.core.logging_config import get_agent_logger @@ -13,7 +14,7 @@ from app.core.memory.agent.utils.session_tools import SessionService from app.core.memory.agent.utils.template_tools import TemplateService from app.core.memory.agent.services.optimized_llm_service import LLMServiceMixin -template_root = PROJECT_ROOT_ + '/agent/utils/prompt' +template_root = os.path.join(PROJECT_ROOT_, 'agent', 'utils', 'prompt') db_session = next(get_db()) logger = get_agent_logger(__name__) @@ -35,11 +36,16 @@ async def Split_The_Problem(state: ReadState) -> ReadState: memory_config = state.get('memory_config', None) history = await SessionService(store).get_history(group_id, group_id, group_id) + + # 生成 JSON schema 以指导 LLM 输出正确格式 + json_schema = ProblemExtensionResponse.model_json_schema() + system_prompt = await problem_service.template_service.render_template( template_name='problem_breakdown_prompt.jinja2', operation_name='split_the_problem', history=history, - sentence=content + sentence=content, + json_schema=json_schema ) try: @@ -147,11 +153,16 @@ async def Problem_Extension(state: ReadState) -> ReadState: data = [] history = await SessionService(store).get_history(group_id, group_id, group_id) + + # 生成 JSON schema 以指导 LLM 输出正确格式 + json_schema = ProblemExtensionResponse.model_json_schema() + system_prompt = await problem_service.template_service.render_template( template_name='Problem_Extension_prompt.jinja2', operation_name='problem_extension', history=history, - questions=databasets + questions=databasets, + json_schema=json_schema ) try: diff --git a/api/app/core/memory/agent/langgraph_graph/nodes/summary_nodes.py b/api/app/core/memory/agent/langgraph_graph/nodes/summary_nodes.py index 7b727da5..0d0b57b0 100644 --- a/api/app/core/memory/agent/langgraph_graph/nodes/summary_nodes.py +++ b/api/app/core/memory/agent/langgraph_graph/nodes/summary_nodes.py @@ -1,5 +1,6 @@ +import os import time from app.core.logging_config import get_agent_logger, log_time @@ -19,7 +20,7 @@ from app.core.memory.agent.utils.session_tools import SessionService from app.core.memory.agent.utils.template_tools import TemplateService from app.core.memory.agent.services.optimized_llm_service import LLMServiceMixin -template_root = PROJECT_ROOT_ + '/agent/utils/prompt' +template_root = os.path.join(PROJECT_ROOT_, 'agent', 'utils', 'prompt') logger = get_agent_logger(__name__) db_session = next(get_db()) diff --git a/api/app/core/memory/agent/langgraph_graph/read_graph.py b/api/app/core/memory/agent/langgraph_graph/read_graph.py index 19011a5f..c01889a9 100644 --- a/api/app/core/memory/agent/langgraph_graph/read_graph.py +++ b/api/app/core/memory/agent/langgraph_graph/read_graph.py @@ -59,6 +59,7 @@ async def make_read_graph(): workflow.add_conditional_edges("Retrieve", Retrieve_continue) workflow.add_edge("Retrieve_Summary", END) workflow.add_conditional_edges("Verify", Verify_continue) + workflow.add_edge("Summary_fails", END) workflow.add_edge("Summary", END) diff --git a/api/app/core/memory/agent/langgraph_graph/routing/routers.py b/api/app/core/memory/agent/langgraph_graph/routing/routers.py index 004e03b3..151ce1c5 100644 --- a/api/app/core/memory/agent/langgraph_graph/routing/routers.py +++ b/api/app/core/memory/agent/langgraph_graph/routing/routers.py @@ -45,6 +45,9 @@ def Retrieve_continue(state) -> Literal["Verify", "Retrieve_Summary"]: return 'Retrieve_Summary' # Default based on business logic def Verify_continue(state: ReadState) -> Literal["Summary", "Summary_fails", "content_input"]: status=state.get('verify', '')['status'] + print(100*'-') + print(status) + print(100*'-') # loop_count = counter.get_total() if "success" in status: # counter.reset() diff --git a/api/app/core/memory/agent/services/optimized_llm_service.py b/api/app/core/memory/agent/services/optimized_llm_service.py index 6942d421..fce1cd76 100644 --- a/api/app/core/memory/agent/services/optimized_llm_service.py +++ b/api/app/core/memory/agent/services/optimized_llm_service.py @@ -162,22 +162,35 @@ class OptimizedLLMService: return fallback_value elif isinstance(fallback_value, dict): return response_model(**fallback_value) + elif isinstance(fallback_value, list): + # 对于 RootModel[List[...]] 类型,直接传入列表 + if hasattr(response_model, 'model_fields') and 'root' in response_model.model_fields: + return response_model(root=fallback_value) + # 或者尝试直接传入(Pydantic v2 的 RootModel 支持) + return response_model(fallback_value) # 尝试创建空的响应模型 - if hasattr(response_model, 'root'): - # RootModel类型 + # 检查是否是 RootModel 类型(通过检查 __pydantic_root_model__ 属性) + if hasattr(response_model, '__pydantic_root_model__') and response_model.__pydantic_root_model__: + # RootModel类型 - 传入空列表 + logger.debug(f"创建 RootModel 类型的空响应: {response_model.__name__}") return response_model([]) else: - # 普通BaseModel类型 + # 普通BaseModel类型 - 尝试无参数构造 + logger.debug(f"创建普通 BaseModel 类型的空响应: {response_model.__name__}") return response_model() except Exception as e: - logger.error(f"创建降级响应失败: {e}") + logger.error(f"创建降级响应失败: {e}", exc_info=True) # 最后的降级策略 - if hasattr(response_model, 'root'): - return response_model([]) - else: - return response_model() + try: + if hasattr(response_model, '__pydantic_root_model__') and response_model.__pydantic_root_model__: + return response_model([]) + else: + return response_model() + except Exception as final_error: + logger.error(f"最终降级策略也失败: {final_error}") + raise def clear_cache(self): """清理客户端缓存""" diff --git a/api/app/core/memory/llm_tools/openai_client.py b/api/app/core/memory/llm_tools/openai_client.py index dce7b495..93c0efd6 100644 --- a/api/app/core/memory/llm_tools/openai_client.py +++ b/api/app/core/memory/llm_tools/openai_client.py @@ -100,6 +100,41 @@ class OpenAIClient(LLMClient): logger.error(f"LLM 调用失败: {e}") raise LLMClientException(f"LLM 调用失败: {e}") from e + async def response(self, messages: List[Dict[str, str]], **kwargs) -> str: + """ + 简单响应接口实现(用于fallback机制) + + Args: + messages: 消息列表 + **kwargs: 额外参数 + + Returns: + LLM 响应文本 + + Raises: + LLMClientException: LLM 调用失败 + """ + try: + template = """{messages}""" + prompt = ChatPromptTemplate.from_template(template) + chain = prompt | self.client + + # 添加 Langfuse 回调(如果可用) + config = {} + if self.langfuse_handler: + config["callbacks"] = [self.langfuse_handler] + + response = await chain.ainvoke({"messages": messages}, config=config) + + # 提取文本内容 + if hasattr(response, "content"): + return str(response.content) + return str(response) + + except Exception as e: + logger.error(f"LLM 调用失败: {e}") + raise LLMClientException(f"LLM 调用失败: {e}") from e + async def response_structured( self, messages: List[Dict[str, str]], @@ -131,44 +166,206 @@ class OpenAIClient(LLMClient): if self.langfuse_handler: config["callbacks"] = [self.langfuse_handler] - # 方法 1: 使用 PydanticOutputParser + template = """{question}""" + prompt = ChatPromptTemplate.from_template(template) + + # 对于 DashScope 等不支持 with_structured_output 的模型,优先使用手动JSON解析 + # 这样可以避免不必要的尝试和错误 + if self.provider: #.lower() == "dashscope" + logger.info("DashScope 模型,直接使用手动JSON解析方法") + try: + # 获取原始响应,添加超时保护 + chain = prompt | self.client + response = await asyncio.wait_for( + chain.ainvoke({"question": question_text}, config=config), + timeout=self.timeout + ) + + # 提取响应文本 + response_text = "" + if hasattr(response, "content"): + response_text = str(response.content) + else: + response_text = str(response) + + logger.debug(f"LLM原始响应长度: {len(response_text)}") + + # 尝试提取JSON内容 + json_text = response_text.strip() + + # 如果响应包含markdown代码块,提取其中的JSON + if "```json" in json_text: + json_text = json_text.split("```json")[1].split("```")[0].strip() + elif "```" in json_text: + json_text = json_text.split("```")[1].split("```")[0].strip() + + # 尝试修复常见的JSON格式问题 + # 1. 移除可能的BOM标记 + json_text = json_text.lstrip('\ufeff') + + # 2. 如果JSON被截断(缺少结尾的 ] 或 }),尝试修复 + if json_text.startswith('[') and not json_text.rstrip().endswith(']'): + logger.warning("检测到JSON数组被截断,尝试修复") + # 找到最后一个完整的对象 + last_complete_brace = json_text.rfind('}') + if last_complete_brace > 0: + json_text = json_text[:last_complete_brace + 1] + ']' + logger.info(f"修复后的JSON长度: {len(json_text)}") + elif json_text.startswith('{') and not json_text.rstrip().endswith('}'): + logger.warning("检测到JSON对象被截断,尝试修复") + # 找到最后一个完整的字段 + last_complete_brace = json_text.rfind('}') + if last_complete_brace > 0: + json_text = json_text[:last_complete_brace + 1] + logger.info(f"修复后的JSON长度: {len(json_text)}") + + # 解析JSON + try: + parsed_dict = json.loads(json_text) + logger.debug(f"JSON解析成功,类型: {type(parsed_dict)}") + + # 如果是列表,记录第一个元素的结构 + if isinstance(parsed_dict, list) and len(parsed_dict) > 0: + logger.debug(f"第一个元素的键: {list(parsed_dict[0].keys()) if isinstance(parsed_dict[0], dict) else 'not a dict'}") + + # 尝试字段映射转换(处理LLM返回格式不匹配的情况) + if isinstance(parsed_dict, list): + transformed_list = [] + for item in parsed_dict: + if isinstance(item, dict): + transformed_item = {} + + # 常见的字段映射规则 + field_mappings = { + 'question': ['extended_question', 'question', 'query'], + 'original_question': ['original_question', 'original', 'source_question'], + 'extended_question': ['extended_question', 'question', 'query', 'extended'], + 'type': ['type', 'category', 'question_type'], + 'reason': ['reason', 'explanation', 'rationale'], + 'query': ['query', 'question', 'text'], + 'split_result': ['split_result', 'result', 'status'], + 'expansion_issue': ['expansion_issue', 'issues', 'expansions'], + } + + # 对于每个期望的字段,尝试从多个可能的源字段中获取 + for target_field, source_fields in field_mappings.items(): + for source_field in source_fields: + if source_field in item: + transformed_item[target_field] = item[source_field] + break + + # 特殊处理:如果只有 'question' 但缺少 'original_question' 和 'extended_question' + if 'question' in item and 'original_question' not in transformed_item: + transformed_item['original_question'] = item['question'] + if 'question' in item and 'extended_question' not in transformed_item: + transformed_item['extended_question'] = item['question'] + + # 保留原始字段(如果没有被映射) + for key, value in item.items(): + if key not in transformed_item: + transformed_item[key] = value + + transformed_list.append(transformed_item) + else: + transformed_list.append(item) + + logger.info(f"字段映射完成,尝试重新验证") + logger.debug(f"转换后的数据: {transformed_list}") + + try: + return response_model.model_validate(transformed_list) + except Exception as retry_error: + logger.error(f"字段映射后仍然验证失败: {retry_error}") + logger.error(f"完整的LLM响应: {response_text}") + logger.error(f"原始解析字典: {parsed_dict}") + logger.error(f"转换后的字典: {transformed_list}") + raise + else: + # 非列表类型,记录并抛出原始错误 + logger.error(f"完整的LLM响应: {response_text}") + logger.error(f"解析后的字典: {parsed_dict}") + raise + except json.JSONDecodeError as je: + logger.error(f"JSON解析失败: {je}") + logger.error(f"问题位置附近的文本: {json_text[max(0, je.pos-50):min(len(json_text), je.pos+50)]}") + + # 尝试更激进的修复:逐行解析,找到有效的JSON部分 + logger.info("尝试逐行解析JSON") + lines = json_text.split('\n') + for i in range(len(lines), 0, -1): + try: + partial_json = '\n'.join(lines[:i]) + if partial_json.startswith('['): + partial_json = partial_json.rstrip().rstrip(',') + ']' + elif partial_json.startswith('{'): + partial_json = partial_json.rstrip().rstrip(',') + '}' + + parsed_dict = json.loads(partial_json) + logger.info(f"成功解析部分JSON(前{i}行)") + return response_model.model_validate(parsed_dict) + except: + continue + + # 如果所有尝试都失败,抛出原始错误 + raise LLMClientException(f"JSON解析失败: {je}") from je + + except asyncio.TimeoutError: + logger.error(f"LLM调用超时({self.timeout}秒)") + raise LLMClientException(f"LLM调用超时({self.timeout}秒)") + except LLMClientException: + raise + except Exception as e: + logger.error(f"手动JSON解析失败: {e}", exc_info=True) + raise LLMClientException(f"手动JSON解析失败: {e}") from e + + + + + + # 方法 1: 使用 PydanticOutputParser(适用于支持的模型) if PydanticOutputParser is not None: try: parser = PydanticOutputParser(pydantic_object=response_model) format_instructions = parser.get_format_instructions() - prompt = ChatPromptTemplate.from_template( + prompt_with_instructions = ChatPromptTemplate.from_template( "{question}\n{format_instructions}" ) - chain = prompt | self.client | parser + chain = prompt_with_instructions | self.client | parser - parsed = await chain.ainvoke( - { - "question": question_text, - "format_instructions": format_instructions, - }, - config=config + parsed = await asyncio.wait_for( + chain.ainvoke( + { + "question": question_text, + "format_instructions": format_instructions, + }, + config=config + ), + timeout=self.timeout ) logger.debug(f"使用 PydanticOutputParser 解析成功") return parsed + except asyncio.TimeoutError: + logger.error(f"PydanticOutputParser 调用超时({self.timeout}秒)") + raise LLMClientException(f"LLM调用超时({self.timeout}秒)") except Exception as e: logger.warning( f"PydanticOutputParser 解析失败,尝试其他方法: {e}" ) - # 方法 2: 使用 LangChain 的 with_structured_output - template = """{question}""" - prompt = ChatPromptTemplate.from_template(template) - - try: - with_so = getattr(self.client, "with_structured_output", None) - - if callable(with_so): + # 方法 2: 使用 LangChain 的 with_structured_output (如果支持) + with_so = getattr(self.client, "with_structured_output", None) + + if callable(with_so): + try: structured_chain = prompt | with_so(response_model, strict=True) - parsed = await structured_chain.ainvoke( - {"question": question_text}, - config=config + parsed = await asyncio.wait_for( + structured_chain.ainvoke( + {"question": question_text}, + config=config + ), + timeout=self.timeout ) # 验证并返回结果 @@ -181,14 +378,60 @@ class OpenAIClient(LLMClient): # 尝试从 JSON 解析 return response_model.model_validate_json(json.dumps(parsed)) - except Exception as e: - logger.error(f"结构化输出失败: {e}") - raise LLMClientException(f"结构化输出失败: {e}") from e + except asyncio.TimeoutError: + logger.error(f"with_structured_output 调用超时({self.timeout}秒)") + raise LLMClientException(f"LLM调用超时({self.timeout}秒)") + except NotImplementedError: + logger.warning( + f"模型 {self.model_name} 不支持 with_structured_output,使用手动JSON解析" + ) + except Exception as e: + logger.warning(f"with_structured_output 失败: {e},尝试手动解析") - # 如果所有方法都失败,抛出异常 - raise LLMClientException( - "无法生成结构化输出,所有解析方法均失败" - ) + # 方法 3: 手动JSON解析(fallback方法) + logger.info("使用手动JSON解析方法(fallback)") + try: + # 获取原始响应 + chain = prompt | self.client + response = await asyncio.wait_for( + chain.ainvoke({"question": question_text}, config=config), + timeout=self.timeout + ) + + # 提取响应文本 + response_text = "" + if hasattr(response, "content"): + response_text = str(response.content) + else: + response_text = str(response) + + logger.debug(f"LLM原始响应: {response_text[:500]}...") + + # 尝试提取JSON内容 + json_text = response_text.strip() + + # 如果响应包含markdown代码块,提取其中的JSON + if "```json" in json_text: + json_text = json_text.split("```json")[1].split("```")[0].strip() + elif "```" in json_text: + json_text = json_text.split("```")[1].split("```")[0].strip() + + # 解析JSON + parsed_dict = json.loads(json_text) + logger.debug(f"JSON解析成功: {parsed_dict}") + + # 验证并创建Pydantic模型 + return response_model.model_validate(parsed_dict) + + except asyncio.TimeoutError: + logger.error(f"手动JSON解析调用超时({self.timeout}秒)") + raise LLMClientException(f"LLM调用超时({self.timeout}秒)") + except json.JSONDecodeError as je: + logger.error(f"JSON解析失败: {je}, 原始文本: {json_text[:200]}...") + raise LLMClientException(f"JSON解析失败: {je}") from je + except Exception as e: + logger.error(f"手动JSON解析失败: {e}") + raise LLMClientException(f"手动JSON解析失败: {e}") from e except LLMClientException: raise