读取接口内层嵌套BUG修复
This commit is contained in:
@@ -1,3 +1,4 @@
|
||||
import os
|
||||
import json
|
||||
import time
|
||||
from app.core.logging_config import get_agent_logger
|
||||
@@ -13,7 +14,7 @@ from app.core.memory.agent.utils.session_tools import SessionService
|
||||
from app.core.memory.agent.utils.template_tools import TemplateService
|
||||
from app.core.memory.agent.services.optimized_llm_service import LLMServiceMixin
|
||||
|
||||
template_root = PROJECT_ROOT_ + '/agent/utils/prompt'
|
||||
template_root = os.path.join(PROJECT_ROOT_, 'agent', 'utils', 'prompt')
|
||||
db_session = next(get_db())
|
||||
logger = get_agent_logger(__name__)
|
||||
|
||||
@@ -35,11 +36,16 @@ async def Split_The_Problem(state: ReadState) -> ReadState:
|
||||
memory_config = state.get('memory_config', None)
|
||||
|
||||
history = await SessionService(store).get_history(group_id, group_id, group_id)
|
||||
|
||||
# 生成 JSON schema 以指导 LLM 输出正确格式
|
||||
json_schema = ProblemExtensionResponse.model_json_schema()
|
||||
|
||||
system_prompt = await problem_service.template_service.render_template(
|
||||
template_name='problem_breakdown_prompt.jinja2',
|
||||
operation_name='split_the_problem',
|
||||
history=history,
|
||||
sentence=content
|
||||
sentence=content,
|
||||
json_schema=json_schema
|
||||
)
|
||||
|
||||
try:
|
||||
@@ -147,11 +153,16 @@ async def Problem_Extension(state: ReadState) -> ReadState:
|
||||
data = []
|
||||
|
||||
history = await SessionService(store).get_history(group_id, group_id, group_id)
|
||||
|
||||
# 生成 JSON schema 以指导 LLM 输出正确格式
|
||||
json_schema = ProblemExtensionResponse.model_json_schema()
|
||||
|
||||
system_prompt = await problem_service.template_service.render_template(
|
||||
template_name='Problem_Extension_prompt.jinja2',
|
||||
operation_name='problem_extension',
|
||||
history=history,
|
||||
questions=databasets
|
||||
questions=databasets,
|
||||
json_schema=json_schema
|
||||
)
|
||||
|
||||
try:
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
|
||||
|
||||
import os
|
||||
import time
|
||||
|
||||
from app.core.logging_config import get_agent_logger, log_time
|
||||
@@ -19,7 +20,7 @@ from app.core.memory.agent.utils.session_tools import SessionService
|
||||
from app.core.memory.agent.utils.template_tools import TemplateService
|
||||
from app.core.memory.agent.services.optimized_llm_service import LLMServiceMixin
|
||||
|
||||
template_root = PROJECT_ROOT_ + '/agent/utils/prompt'
|
||||
template_root = os.path.join(PROJECT_ROOT_, 'agent', 'utils', 'prompt')
|
||||
logger = get_agent_logger(__name__)
|
||||
db_session = next(get_db())
|
||||
|
||||
|
||||
@@ -59,6 +59,7 @@ async def make_read_graph():
|
||||
workflow.add_conditional_edges("Retrieve", Retrieve_continue)
|
||||
workflow.add_edge("Retrieve_Summary", END)
|
||||
workflow.add_conditional_edges("Verify", Verify_continue)
|
||||
|
||||
workflow.add_edge("Summary_fails", END)
|
||||
workflow.add_edge("Summary", END)
|
||||
|
||||
|
||||
@@ -45,6 +45,9 @@ def Retrieve_continue(state) -> Literal["Verify", "Retrieve_Summary"]:
|
||||
return 'Retrieve_Summary' # Default based on business logic
|
||||
def Verify_continue(state: ReadState) -> Literal["Summary", "Summary_fails", "content_input"]:
|
||||
status=state.get('verify', '')['status']
|
||||
print(100*'-')
|
||||
print(status)
|
||||
print(100*'-')
|
||||
# loop_count = counter.get_total()
|
||||
if "success" in status:
|
||||
# counter.reset()
|
||||
|
||||
@@ -162,22 +162,35 @@ class OptimizedLLMService:
|
||||
return fallback_value
|
||||
elif isinstance(fallback_value, dict):
|
||||
return response_model(**fallback_value)
|
||||
elif isinstance(fallback_value, list):
|
||||
# 对于 RootModel[List[...]] 类型,直接传入列表
|
||||
if hasattr(response_model, 'model_fields') and 'root' in response_model.model_fields:
|
||||
return response_model(root=fallback_value)
|
||||
# 或者尝试直接传入(Pydantic v2 的 RootModel 支持)
|
||||
return response_model(fallback_value)
|
||||
|
||||
# 尝试创建空的响应模型
|
||||
if hasattr(response_model, 'root'):
|
||||
# RootModel类型
|
||||
# 检查是否是 RootModel 类型(通过检查 __pydantic_root_model__ 属性)
|
||||
if hasattr(response_model, '__pydantic_root_model__') and response_model.__pydantic_root_model__:
|
||||
# RootModel类型 - 传入空列表
|
||||
logger.debug(f"创建 RootModel 类型的空响应: {response_model.__name__}")
|
||||
return response_model([])
|
||||
else:
|
||||
# 普通BaseModel类型
|
||||
# 普通BaseModel类型 - 尝试无参数构造
|
||||
logger.debug(f"创建普通 BaseModel 类型的空响应: {response_model.__name__}")
|
||||
return response_model()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"创建降级响应失败: {e}")
|
||||
logger.error(f"创建降级响应失败: {e}", exc_info=True)
|
||||
# 最后的降级策略
|
||||
if hasattr(response_model, 'root'):
|
||||
return response_model([])
|
||||
else:
|
||||
return response_model()
|
||||
try:
|
||||
if hasattr(response_model, '__pydantic_root_model__') and response_model.__pydantic_root_model__:
|
||||
return response_model([])
|
||||
else:
|
||||
return response_model()
|
||||
except Exception as final_error:
|
||||
logger.error(f"最终降级策略也失败: {final_error}")
|
||||
raise
|
||||
|
||||
def clear_cache(self):
|
||||
"""清理客户端缓存"""
|
||||
|
||||
@@ -100,6 +100,41 @@ class OpenAIClient(LLMClient):
|
||||
logger.error(f"LLM 调用失败: {e}")
|
||||
raise LLMClientException(f"LLM 调用失败: {e}") from e
|
||||
|
||||
async def response(self, messages: List[Dict[str, str]], **kwargs) -> str:
    """
    Simple text-response interface (used by the fallback mechanism).

    Sends ``messages`` to the underlying chat client and returns the reply
    as plain text. The LLM call is bounded by ``self.timeout`` (when the
    instance defines one), matching the timeout protection used by
    ``response_structured`` in this class.

    Args:
        messages: List of message dicts to send.
        **kwargs: Extra parameters; accepted for interface compatibility
            and not used by this implementation.

    Returns:
        The LLM response text: ``str(response.content)`` when the response
        object has a ``content`` attribute, otherwise ``str(response)``.

    Raises:
        LLMClientException: If the call times out or fails for any other
            reason. The original exception is chained as the cause.
    """
    # Local import keeps this method self-contained regardless of how the
    # module's import block is arranged.
    import asyncio

    # Fall back to "no timeout" if the instance does not define one, so the
    # method behaves exactly as before in that case.
    timeout = getattr(self, "timeout", None)

    try:
        # NOTE(review): the whole `messages` list is passed as the single
        # "{messages}" template variable, so it is presumably rendered via
        # its string form rather than as separate chat turns - confirm this
        # is intended.
        template = """{messages}"""
        prompt = ChatPromptTemplate.from_template(template)
        chain = prompt | self.client

        # Attach the Langfuse callback when a handler is configured.
        config = {}
        if self.langfuse_handler:
            config["callbacks"] = [self.langfuse_handler]

        # Bound the call so a stalled provider cannot hang the fallback path.
        response = await asyncio.wait_for(
            chain.ainvoke({"messages": messages}, config=config),
            timeout=timeout,
        )

        # Extract the text content.
        if hasattr(response, "content"):
            return str(response.content)
        return str(response)

    except asyncio.TimeoutError as e:
        logger.error(f"LLM调用超时({timeout}秒)")
        raise LLMClientException(f"LLM调用超时({timeout}秒)") from e
    except Exception as e:
        logger.error(f"LLM 调用失败: {e}")
        raise LLMClientException(f"LLM 调用失败: {e}") from e
|
||||
|
||||
async def response_structured(
|
||||
self,
|
||||
messages: List[Dict[str, str]],
|
||||
@@ -131,44 +166,206 @@ class OpenAIClient(LLMClient):
|
||||
if self.langfuse_handler:
|
||||
config["callbacks"] = [self.langfuse_handler]
|
||||
|
||||
# 方法 1: 使用 PydanticOutputParser
|
||||
template = """{question}"""
|
||||
prompt = ChatPromptTemplate.from_template(template)
|
||||
|
||||
# 对于 DashScope 等不支持 with_structured_output 的模型,优先使用手动JSON解析
|
||||
# 这样可以避免不必要的尝试和错误
|
||||
if self.provider: #.lower() == "dashscope"
|
||||
logger.info("DashScope 模型,直接使用手动JSON解析方法")
|
||||
try:
|
||||
# 获取原始响应,添加超时保护
|
||||
chain = prompt | self.client
|
||||
response = await asyncio.wait_for(
|
||||
chain.ainvoke({"question": question_text}, config=config),
|
||||
timeout=self.timeout
|
||||
)
|
||||
|
||||
# 提取响应文本
|
||||
response_text = ""
|
||||
if hasattr(response, "content"):
|
||||
response_text = str(response.content)
|
||||
else:
|
||||
response_text = str(response)
|
||||
|
||||
logger.debug(f"LLM原始响应长度: {len(response_text)}")
|
||||
|
||||
# 尝试提取JSON内容
|
||||
json_text = response_text.strip()
|
||||
|
||||
# 如果响应包含markdown代码块,提取其中的JSON
|
||||
if "```json" in json_text:
|
||||
json_text = json_text.split("```json")[1].split("```")[0].strip()
|
||||
elif "```" in json_text:
|
||||
json_text = json_text.split("```")[1].split("```")[0].strip()
|
||||
|
||||
# 尝试修复常见的JSON格式问题
|
||||
# 1. 移除可能的BOM标记
|
||||
json_text = json_text.lstrip('\ufeff')
|
||||
|
||||
# 2. 如果JSON被截断(缺少结尾的 ] 或 }),尝试修复
|
||||
if json_text.startswith('[') and not json_text.rstrip().endswith(']'):
|
||||
logger.warning("检测到JSON数组被截断,尝试修复")
|
||||
# 找到最后一个完整的对象
|
||||
last_complete_brace = json_text.rfind('}')
|
||||
if last_complete_brace > 0:
|
||||
json_text = json_text[:last_complete_brace + 1] + ']'
|
||||
logger.info(f"修复后的JSON长度: {len(json_text)}")
|
||||
elif json_text.startswith('{') and not json_text.rstrip().endswith('}'):
|
||||
logger.warning("检测到JSON对象被截断,尝试修复")
|
||||
# 找到最后一个完整的字段
|
||||
last_complete_brace = json_text.rfind('}')
|
||||
if last_complete_brace > 0:
|
||||
json_text = json_text[:last_complete_brace + 1]
|
||||
logger.info(f"修复后的JSON长度: {len(json_text)}")
|
||||
|
||||
# 解析JSON
|
||||
try:
|
||||
parsed_dict = json.loads(json_text)
|
||||
logger.debug(f"JSON解析成功,类型: {type(parsed_dict)}")
|
||||
|
||||
# 如果是列表,记录第一个元素的结构
|
||||
if isinstance(parsed_dict, list) and len(parsed_dict) > 0:
|
||||
logger.debug(f"第一个元素的键: {list(parsed_dict[0].keys()) if isinstance(parsed_dict[0], dict) else 'not a dict'}")
|
||||
|
||||
# 尝试字段映射转换(处理LLM返回格式不匹配的情况)
|
||||
if isinstance(parsed_dict, list):
|
||||
transformed_list = []
|
||||
for item in parsed_dict:
|
||||
if isinstance(item, dict):
|
||||
transformed_item = {}
|
||||
|
||||
# 常见的字段映射规则
|
||||
field_mappings = {
|
||||
'question': ['extended_question', 'question', 'query'],
|
||||
'original_question': ['original_question', 'original', 'source_question'],
|
||||
'extended_question': ['extended_question', 'question', 'query', 'extended'],
|
||||
'type': ['type', 'category', 'question_type'],
|
||||
'reason': ['reason', 'explanation', 'rationale'],
|
||||
'query': ['query', 'question', 'text'],
|
||||
'split_result': ['split_result', 'result', 'status'],
|
||||
'expansion_issue': ['expansion_issue', 'issues', 'expansions'],
|
||||
}
|
||||
|
||||
# 对于每个期望的字段,尝试从多个可能的源字段中获取
|
||||
for target_field, source_fields in field_mappings.items():
|
||||
for source_field in source_fields:
|
||||
if source_field in item:
|
||||
transformed_item[target_field] = item[source_field]
|
||||
break
|
||||
|
||||
# 特殊处理:如果只有 'question' 但缺少 'original_question' 和 'extended_question'
|
||||
if 'question' in item and 'original_question' not in transformed_item:
|
||||
transformed_item['original_question'] = item['question']
|
||||
if 'question' in item and 'extended_question' not in transformed_item:
|
||||
transformed_item['extended_question'] = item['question']
|
||||
|
||||
# 保留原始字段(如果没有被映射)
|
||||
for key, value in item.items():
|
||||
if key not in transformed_item:
|
||||
transformed_item[key] = value
|
||||
|
||||
transformed_list.append(transformed_item)
|
||||
else:
|
||||
transformed_list.append(item)
|
||||
|
||||
logger.info(f"字段映射完成,尝试重新验证")
|
||||
logger.debug(f"转换后的数据: {transformed_list}")
|
||||
|
||||
try:
|
||||
return response_model.model_validate(transformed_list)
|
||||
except Exception as retry_error:
|
||||
logger.error(f"字段映射后仍然验证失败: {retry_error}")
|
||||
logger.error(f"完整的LLM响应: {response_text}")
|
||||
logger.error(f"原始解析字典: {parsed_dict}")
|
||||
logger.error(f"转换后的字典: {transformed_list}")
|
||||
raise
|
||||
else:
|
||||
# 非列表类型,记录并抛出原始错误
|
||||
logger.error(f"完整的LLM响应: {response_text}")
|
||||
logger.error(f"解析后的字典: {parsed_dict}")
|
||||
raise
|
||||
except json.JSONDecodeError as je:
|
||||
logger.error(f"JSON解析失败: {je}")
|
||||
logger.error(f"问题位置附近的文本: {json_text[max(0, je.pos-50):min(len(json_text), je.pos+50)]}")
|
||||
|
||||
# 尝试更激进的修复:逐行解析,找到有效的JSON部分
|
||||
logger.info("尝试逐行解析JSON")
|
||||
lines = json_text.split('\n')
|
||||
for i in range(len(lines), 0, -1):
|
||||
try:
|
||||
partial_json = '\n'.join(lines[:i])
|
||||
if partial_json.startswith('['):
|
||||
partial_json = partial_json.rstrip().rstrip(',') + ']'
|
||||
elif partial_json.startswith('{'):
|
||||
partial_json = partial_json.rstrip().rstrip(',') + '}'
|
||||
|
||||
parsed_dict = json.loads(partial_json)
|
||||
logger.info(f"成功解析部分JSON(前{i}行)")
|
||||
return response_model.model_validate(parsed_dict)
|
||||
except:
|
||||
continue
|
||||
|
||||
# 如果所有尝试都失败,抛出原始错误
|
||||
raise LLMClientException(f"JSON解析失败: {je}") from je
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
logger.error(f"LLM调用超时({self.timeout}秒)")
|
||||
raise LLMClientException(f"LLM调用超时({self.timeout}秒)")
|
||||
except LLMClientException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"手动JSON解析失败: {e}", exc_info=True)
|
||||
raise LLMClientException(f"手动JSON解析失败: {e}") from e
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
# 方法 1: 使用 PydanticOutputParser(适用于支持的模型)
|
||||
if PydanticOutputParser is not None:
|
||||
try:
|
||||
parser = PydanticOutputParser(pydantic_object=response_model)
|
||||
format_instructions = parser.get_format_instructions()
|
||||
prompt = ChatPromptTemplate.from_template(
|
||||
prompt_with_instructions = ChatPromptTemplate.from_template(
|
||||
"{question}\n{format_instructions}"
|
||||
)
|
||||
chain = prompt | self.client | parser
|
||||
chain = prompt_with_instructions | self.client | parser
|
||||
|
||||
parsed = await chain.ainvoke(
|
||||
{
|
||||
"question": question_text,
|
||||
"format_instructions": format_instructions,
|
||||
},
|
||||
config=config
|
||||
parsed = await asyncio.wait_for(
|
||||
chain.ainvoke(
|
||||
{
|
||||
"question": question_text,
|
||||
"format_instructions": format_instructions,
|
||||
},
|
||||
config=config
|
||||
),
|
||||
timeout=self.timeout
|
||||
)
|
||||
|
||||
logger.debug(f"使用 PydanticOutputParser 解析成功")
|
||||
return parsed
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
logger.error(f"PydanticOutputParser 调用超时({self.timeout}秒)")
|
||||
raise LLMClientException(f"LLM调用超时({self.timeout}秒)")
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"PydanticOutputParser 解析失败,尝试其他方法: {e}"
|
||||
)
|
||||
|
||||
# 方法 2: 使用 LangChain 的 with_structured_output
|
||||
template = """{question}"""
|
||||
prompt = ChatPromptTemplate.from_template(template)
|
||||
|
||||
try:
|
||||
with_so = getattr(self.client, "with_structured_output", None)
|
||||
|
||||
if callable(with_so):
|
||||
# 方法 2: 使用 LangChain 的 with_structured_output (如果支持)
|
||||
with_so = getattr(self.client, "with_structured_output", None)
|
||||
|
||||
if callable(with_so):
|
||||
try:
|
||||
structured_chain = prompt | with_so(response_model, strict=True)
|
||||
parsed = await structured_chain.ainvoke(
|
||||
{"question": question_text},
|
||||
config=config
|
||||
parsed = await asyncio.wait_for(
|
||||
structured_chain.ainvoke(
|
||||
{"question": question_text},
|
||||
config=config
|
||||
),
|
||||
timeout=self.timeout
|
||||
)
|
||||
|
||||
# 验证并返回结果
|
||||
@@ -181,14 +378,60 @@ class OpenAIClient(LLMClient):
|
||||
# 尝试从 JSON 解析
|
||||
return response_model.model_validate_json(json.dumps(parsed))
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"结构化输出失败: {e}")
|
||||
raise LLMClientException(f"结构化输出失败: {e}") from e
|
||||
except asyncio.TimeoutError:
|
||||
logger.error(f"with_structured_output 调用超时({self.timeout}秒)")
|
||||
raise LLMClientException(f"LLM调用超时({self.timeout}秒)")
|
||||
except NotImplementedError:
|
||||
logger.warning(
|
||||
f"模型 {self.model_name} 不支持 with_structured_output,使用手动JSON解析"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"with_structured_output 失败: {e},尝试手动解析")
|
||||
|
||||
# 如果所有方法都失败,抛出异常
|
||||
raise LLMClientException(
|
||||
"无法生成结构化输出,所有解析方法均失败"
|
||||
)
|
||||
# 方法 3: 手动JSON解析(fallback方法)
|
||||
logger.info("使用手动JSON解析方法(fallback)")
|
||||
try:
|
||||
# 获取原始响应
|
||||
chain = prompt | self.client
|
||||
response = await asyncio.wait_for(
|
||||
chain.ainvoke({"question": question_text}, config=config),
|
||||
timeout=self.timeout
|
||||
)
|
||||
|
||||
# 提取响应文本
|
||||
response_text = ""
|
||||
if hasattr(response, "content"):
|
||||
response_text = str(response.content)
|
||||
else:
|
||||
response_text = str(response)
|
||||
|
||||
logger.debug(f"LLM原始响应: {response_text[:500]}...")
|
||||
|
||||
# 尝试提取JSON内容
|
||||
json_text = response_text.strip()
|
||||
|
||||
# 如果响应包含markdown代码块,提取其中的JSON
|
||||
if "```json" in json_text:
|
||||
json_text = json_text.split("```json")[1].split("```")[0].strip()
|
||||
elif "```" in json_text:
|
||||
json_text = json_text.split("```")[1].split("```")[0].strip()
|
||||
|
||||
# 解析JSON
|
||||
parsed_dict = json.loads(json_text)
|
||||
logger.debug(f"JSON解析成功: {parsed_dict}")
|
||||
|
||||
# 验证并创建Pydantic模型
|
||||
return response_model.model_validate(parsed_dict)
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
logger.error(f"手动JSON解析调用超时({self.timeout}秒)")
|
||||
raise LLMClientException(f"LLM调用超时({self.timeout}秒)")
|
||||
except json.JSONDecodeError as je:
|
||||
logger.error(f"JSON解析失败: {je}, 原始文本: {json_text[:200]}...")
|
||||
raise LLMClientException(f"JSON解析失败: {je}") from je
|
||||
except Exception as e:
|
||||
logger.error(f"手动JSON解析失败: {e}")
|
||||
raise LLMClientException(f"手动JSON解析失败: {e}") from e
|
||||
|
||||
except LLMClientException:
|
||||
raise
|
||||
|
||||
Reference in New Issue
Block a user