Merge remote-tracking branch 'origin/develop' into develop

This commit is contained in:
lixinyue
2026-01-20 20:09:14 +08:00
8 changed files with 191 additions and 77 deletions

View File

@@ -1,3 +1,4 @@
import os
import json
import time
from app.core.logging_config import get_agent_logger
@@ -13,7 +14,7 @@ from app.core.memory.agent.utils.session_tools import SessionService
from app.core.memory.agent.utils.template_tools import TemplateService
from app.core.memory.agent.services.optimized_llm_service import LLMServiceMixin
template_root = PROJECT_ROOT_ + '/agent/utils/prompt'
template_root = os.path.join(PROJECT_ROOT_, 'agent', 'utils', 'prompt')
db_session = next(get_db())
logger = get_agent_logger(__name__)
@@ -35,11 +36,16 @@ async def Split_The_Problem(state: ReadState) -> ReadState:
memory_config = state.get('memory_config', None)
history = await SessionService(store).get_history(group_id, group_id, group_id)
# 生成 JSON schema 以指导 LLM 输出正确格式
json_schema = ProblemExtensionResponse.model_json_schema()
system_prompt = await problem_service.template_service.render_template(
template_name='problem_breakdown_prompt.jinja2',
operation_name='split_the_problem',
history=history,
sentence=content
sentence=content,
json_schema=json_schema
)
try:
@@ -147,11 +153,16 @@ async def Problem_Extension(state: ReadState) -> ReadState:
data = []
history = await SessionService(store).get_history(group_id, group_id, group_id)
# 生成 JSON schema 以指导 LLM 输出正确格式
json_schema = ProblemExtensionResponse.model_json_schema()
system_prompt = await problem_service.template_service.render_template(
template_name='Problem_Extension_prompt.jinja2',
operation_name='problem_extension',
history=history,
questions=databasets
questions=databasets,
json_schema=json_schema
)
try:

View File

@@ -1,5 +1,6 @@
import os
import time
from app.core.logging_config import get_agent_logger, log_time
@@ -19,7 +20,7 @@ from app.core.memory.agent.utils.session_tools import SessionService
from app.core.memory.agent.utils.template_tools import TemplateService
from app.core.memory.agent.services.optimized_llm_service import LLMServiceMixin
template_root = PROJECT_ROOT_ + '/agent/utils/prompt'
template_root = os.path.join(PROJECT_ROOT_, 'agent', 'utils', 'prompt')
logger = get_agent_logger(__name__)
db_session = next(get_db())

View File

@@ -1,4 +1,4 @@
import os
from app.core.logging_config import get_agent_logger
from app.db import get_db
@@ -12,7 +12,7 @@ from app.core.memory.agent.utils.session_tools import SessionService
from app.core.memory.agent.utils.template_tools import TemplateService
from app.core.memory.agent.services.optimized_llm_service import LLMServiceMixin
template_root = PROJECT_ROOT_ + '/agent/utils/prompt'
template_root = os.path.join(PROJECT_ROOT_, 'agent', 'utils', 'prompt')
db_session = next(get_db())
logger = get_agent_logger(__name__)
@@ -26,60 +26,130 @@ class VerificationNodeService(LLMServiceMixin):
# 创建全局服务实例
verification_service = VerificationNodeService()
async def Verify_prompt(state: ReadState,messages_deal):
async def Verify_prompt(state: ReadState, messages_deal: VerificationResult):
"""处理验证结果并生成输出格式"""
storage_type = state.get('storage_type', '')
user_rag_memory_id = state.get('user_rag_memory_id', '')
data = state.get('data', '')
# 将 VerificationItem 对象转换为字典列表
verified_data = []
if messages_deal.expansion_issue:
for item in messages_deal.expansion_issue:
if hasattr(item, 'model_dump'):
verified_data.append(item.model_dump())
elif isinstance(item, dict):
verified_data.append(item)
Verify_result = {
"status": messages_deal.split_result,
"verified_data": messages_deal.expansion_issue,
"verified_data": verified_data,
"storage_type": storage_type,
"user_rag_memory_id": user_rag_memory_id,
"_intermediate": {
"type": "verification",
"title": "Data Verification",
"result": messages_deal.split_result,
"reason": messages_deal.reason,
"query": data,
"verified_count": len(messages_deal.expansion_issue),
"reason": messages_deal.reason or "验证完成",
"query": messages_deal.query,
"verified_count": len(verified_data),
"storage_type": storage_type,
"user_rag_memory_id": user_rag_memory_id
}
}
return Verify_result
async def Verify(state: ReadState):
content = state.get('data', '')
group_id = state.get('group_id', '')
memory_config = state.get('memory_config', None)
logger.info("=== Verify 节点开始执行 ===")
try:
content = state.get('data', '')
group_id = state.get('group_id', '')
memory_config = state.get('memory_config', None)
logger.info(f"Verify: content={content[:50] if content else 'empty'}..., group_id={group_id}")
history = await SessionService(store).get_history(group_id, group_id, group_id)
history = await SessionService(store).get_history(group_id, group_id, group_id)
logger.info(f"Verify: 获取历史记录完成history length={len(history)}")
retrieve = state.get("retrieve", '')
retrieve = retrieve.get("Expansion_issue", [])
messages = {
"Query": content,
"Expansion_issue": retrieve
}
system_prompt = await verification_service.template_service.render_template(
template_name='split_verify_prompt.jinja2',
operation_name='split_verify_prompt',
history=history,
sentence=messages
)
# 使用优化的LLM服务
structured = await verification_service.call_llm_structured(
state=state,
db_session=db_session,
system_prompt=system_prompt,
response_model=VerificationResult,
fallback_value={
"split_result": "fail",
"expansion_issue": [],
"reason": "验证失败"
retrieve = state.get("retrieve", {})
logger.info(f"Verify: retrieve data type={type(retrieve)}, keys={retrieve.keys() if isinstance(retrieve, dict) else 'N/A'}")
retrieve_expansion = retrieve.get("Expansion_issue", []) if isinstance(retrieve, dict) else []
logger.info(f"Verify: Expansion_issue length={len(retrieve_expansion)}")
messages = {
"Query": content,
"Expansion_issue": retrieve_expansion
}
)
result = await Verify_prompt(state, structured)
return {"verify": result}
logger.info("Verify: 开始渲染模板")
# 生成 JSON schema 以指导 LLM 输出正确格式
json_schema = VerificationResult.model_json_schema()
system_prompt = await verification_service.template_service.render_template(
template_name='split_verify_prompt.jinja2',
operation_name='split_verify_prompt',
history=history,
sentence=messages,
json_schema=json_schema
)
logger.info(f"Verify: 模板渲染完成prompt length={len(system_prompt)}")
# 使用优化的LLM服务添加超时保护
logger.info("Verify: 开始调用 LLM")
try:
# 添加 asyncio.wait_for 超时包裹,防止无限等待
# 超时时间设置为 150 秒(比 LLM 配置的 120 秒稍长)
import asyncio
structured = await asyncio.wait_for(
verification_service.call_llm_structured(
state=state,
db_session=db_session,
system_prompt=system_prompt,
response_model=VerificationResult,
fallback_value={
"query": content,
"history": history if isinstance(history, list) else [],
"expansion_issue": [],
"split_result": "failed",
"reason": "验证失败或超时"
}
),
timeout=150.0 # 150秒超时
)
logger.info(f"Verify: LLM 调用完成result={structured}")
except asyncio.TimeoutError:
logger.error("Verify: LLM 调用超时150秒使用 fallback 值")
structured = VerificationResult(
query=content,
history=history if isinstance(history, list) else [],
expansion_issue=[],
split_result="failed",
reason="LLM调用超时"
)
result = await Verify_prompt(state, structured)
logger.info("=== Verify 节点执行完成 ===")
return {"verify": result}
except Exception as e:
logger.error(f"Verify 节点执行失败: {e}", exc_info=True)
# 返回失败的验证结果
return {
"verify": {
"status": "failed",
"verified_data": [],
"storage_type": state.get('storage_type', ''),
"user_rag_memory_id": state.get('user_rag_memory_id', ''),
"_intermediate": {
"type": "verification",
"title": "Data Verification",
"result": "failed",
"reason": f"验证过程出错: {str(e)}",
"query": state.get('data', ''),
"verified_count": 0,
"storage_type": state.get('storage_type', ''),
"user_rag_memory_id": state.get('user_rag_memory_id', '')
}
}
}

View File

@@ -59,6 +59,7 @@ async def make_read_graph():
workflow.add_conditional_edges("Retrieve", Retrieve_continue)
workflow.add_edge("Retrieve_Summary", END)
workflow.add_conditional_edges("Verify", Verify_continue)
workflow.add_edge("Summary_fails", END)
workflow.add_edge("Summary", END)

View File

@@ -45,18 +45,17 @@ def Retrieve_continue(state) -> Literal["Verify", "Retrieve_Summary"]:
return 'Retrieve_Summary' # Default based on business logic
def Verify_continue(state: ReadState) -> Literal["Summary", "Summary_fails", "content_input"]:
status=state.get('verify', '')['status']
loop_count = counter.get_total()
print(status)
# loop_count = counter.get_total()
if "success" in status:
counter.reset()
# counter.reset()
return "Summary"
elif "failed" in status:
if loop_count < 2: # Maximum loop count is 3
return "content_input"
else:
counter.reset()
return "Summary_fails"
# else:
# # Add default return value to avoid returning None
# counter.reset()
# return "Summary" # Default based on business requirements
# if loop_count < 2: # Maximum loop count is 3
# return "content_input"
# else:
# counter.reset()
return "Summary_fails"
else:
# Add default return value to avoid returning None
# counter.reset()
return "Summary" # Default based on business requirements

View File

@@ -4,11 +4,29 @@ from typing import List, Optional, Dict, Any
from pydantic import BaseModel, Field
class VerificationItem(BaseModel):
"""Individual verification item for a query-answer pair."""
query_small: str = Field(..., description="子问题")
answer_small: str = Field(..., description="子问题的回答")
status: str = Field(..., description="验证状态True 或 False")
query_answer: str = Field(..., description="问题的答案(与 answer_small 相同)")
class VerificationResult(BaseModel):
"""Result model for verification operation."""
query: str
expansion_issue: List[Dict[str, Any]]
split_result: str
reason: Optional[str] = None
history: List[Dict[str, Any]] = Field(default_factory=list)
query: str = Field(..., description="原始查询问题")
history: List[Dict[str, Any]] = Field(default_factory=list, description="历史对话记录")
expansion_issue: List[VerificationItem] = Field(
default_factory=list,
description="验证后的数据列表,包含所有通过验证的问答对"
)
split_result: str = Field(
...,
description="验证结果状态successexpansion_issue 非空)或 failedexpansion_issue 为空)"
)
reason: Optional[str] = Field(
None,
description="验证结果的说明和分析"
)

View File

@@ -162,7 +162,7 @@ class OptimizedLLMService:
return fallback_value
elif isinstance(fallback_value, dict):
return response_model(**fallback_value)
# 尝试创建空的响应模型
if hasattr(response_model, 'root'):
# RootModel类型
@@ -170,7 +170,7 @@ class OptimizedLLMService:
else:
# 普通BaseModel类型
return response_model()
except Exception as e:
logger.error(f"创建降级响应失败: {e}")
# 最后的降级策略

View File

@@ -42,19 +42,33 @@
如果状态是TRUE保留这条数据否则需不需要这条数据
### 第五步 输出格式
按照json的形式输出
{"data":"Query":原来Query的字段"history":原来的history字段
"expansion_issue":以为列表的形式存储验证之后的数据比如[
{"query_small": query_small,
"answer_small": answer_small,,
"status": 回答的结果是否符合query_small填写状态,
"query_answer": answer_small},
{"query":"原来Query的字段",
"history":"原来的history字段",
"expansion_issue":以列表的形式存储验证之后的数据比如[
{
"query_small": "张曼婷生日是什么时候?",
"answer_small": "张曼婷喜欢绘画。",
"status": "True",
"query_answer": "张曼 婷喜欢绘画。"
},{}......]
,
"split_result":如果expansion_issue是空的列表返回failed不是空列表返回success,
"reason": 为以上分析完之后的结果给一个说明
}
"query_small": "子问题",
"answer_small": "子问题的回答",
"status": "True或False表示回答是否符合query_small",
"query_answer": "问题的答案与answer_small相同"
},
{
"query_small": "张曼婷生日是什么时候?",
"answer_small": "张曼婷喜欢绘画。",
"status": "False",
"query_answer": "张曼婷喜欢绘画。"
}
],
"split_result":"如果expansion_issue是空的列表返回failed不是空列表返回success",
"reason": "为以上分析完之后的结果给一个说明"
}
**输出格式要求**
**CRITICAL JSON FORMATTING REQUIREMENTS:**
1. Use only standard ASCII double quotes (") for JSON structure - never use Chinese quotation marks ("") or other Unicode quotes
2. If the extracted statement text contains quotation marks, escape them properly using backslashes (\")
3. Ensure all JSON strings are properly closed and comma-separated
4. Do not include line breaks within JSON string values
5. The output language should always be the same as the input language
**JSON Schema:**
{{ json_schema }}