读取接口内层嵌套BUG修复

This commit is contained in:
lixinyue
2026-01-20 18:51:18 +08:00
parent a634565296
commit 398964c747

View File

@@ -1,4 +1,4 @@
import os
from app.core.logging_config import get_agent_logger
from app.db import get_db
@@ -12,7 +12,7 @@ from app.core.memory.agent.utils.session_tools import SessionService
from app.core.memory.agent.utils.template_tools import TemplateService
from app.core.memory.agent.services.optimized_llm_service import LLMServiceMixin
template_root = PROJECT_ROOT_ + '/agent/utils/prompt'
template_root = os.path.join(PROJECT_ROOT_, 'agent', 'utils', 'prompt')
db_session = next(get_db())
logger = get_agent_logger(__name__)
@@ -48,38 +48,90 @@ async def Verify_prompt(state: ReadState,messages_deal):
}
return Verify_result
async def Verify(state: ReadState):
content = state.get('data', '')
group_id = state.get('group_id', '')
memory_config = state.get('memory_config', None)
logger.info("=== Verify 节点开始执行 ===")
try:
content = state.get('data', '')
group_id = state.get('group_id', '')
memory_config = state.get('memory_config', None)
logger.info(f"Verify: content={content[:50] if content else 'empty'}..., group_id={group_id}")
history = await SessionService(store).get_history(group_id, group_id, group_id)
history = await SessionService(store).get_history(group_id, group_id, group_id)
logger.info(f"Verify: 获取历史记录完成history length={len(history)}")
retrieve = state.get("retrieve", '')
retrieve = retrieve.get("Expansion_issue", [])
messages = {
"Query": content,
"Expansion_issue": retrieve
}
system_prompt = await verification_service.template_service.render_template(
template_name='split_verify_prompt.jinja2',
operation_name='split_verify_prompt',
history=history,
sentence=messages
)
# 使用优化的LLM服务
structured = await verification_service.call_llm_structured(
state=state,
db_session=db_session,
system_prompt=system_prompt,
response_model=VerificationResult,
fallback_value={
"split_result": "fail",
"expansion_issue": [],
"reason": "验证失败"
retrieve = state.get("retrieve", {})
logger.info(f"Verify: retrieve data type={type(retrieve)}, keys={retrieve.keys() if isinstance(retrieve, dict) else 'N/A'}")
retrieve_expansion = retrieve.get("Expansion_issue", []) if isinstance(retrieve, dict) else []
logger.info(f"Verify: Expansion_issue length={len(retrieve_expansion)}")
messages = {
"Query": content,
"Expansion_issue": retrieve_expansion
}
)
result = await Verify_prompt(state, structured)
return {"verify": result}
logger.info("Verify: 开始渲染模板")
system_prompt = await verification_service.template_service.render_template(
template_name='split_verify_prompt.jinja2',
operation_name='split_verify_prompt',
history=history,
sentence=messages
)
logger.info(f"Verify: 模板渲染完成prompt length={len(system_prompt)}")
# 使用优化的LLM服务添加超时保护
logger.info("Verify: 开始调用 LLM")
try:
# 添加 asyncio.wait_for 超时包裹,防止无限等待
# 超时时间设置为 150 秒(比 LLM 配置的 120 秒稍长)
import asyncio
structured = await asyncio.wait_for(
verification_service.call_llm_structured(
state=state,
db_session=db_session,
system_prompt=system_prompt,
response_model=VerificationResult,
fallback_value={
"query": content, # 添加必填的 query 字段
"split_result": "fail",
"expansion_issue": [],
"reason": "验证失败"
}
),
timeout=150.0 # 150秒超时
)
logger.info(f"Verify: LLM 调用完成result={structured}")
except asyncio.TimeoutError:
logger.error("Verify: LLM 调用超时150秒使用 fallback 值")
structured = VerificationResult(
query=content,
split_result="fail",
expansion_issue=[],
reason="LLM调用超时"
)
result = await Verify_prompt(state, structured)
logger.info("=== Verify 节点执行完成 ===")
return {"verify": result}
except Exception as e:
logger.error(f"Verify 节点执行失败: {e}", exc_info=True)
# 返回失败的验证结果
return {
"verify": {
"status": "failed",
"verified_data": [],
"storage_type": state.get('storage_type', ''),
"user_rag_memory_id": state.get('user_rag_memory_id', ''),
"_intermediate": {
"type": "verification",
"title": "Data Verification",
"result": "failed",
"reason": f"验证过程出错: {str(e)}",
"query": state.get('data', ''),
"verified_count": 0,
"storage_type": state.get('storage_type', ''),
"user_rag_memory_id": state.get('user_rag_memory_id', '')
}
}
}