Merge remote-tracking branch 'origin/develop' into develop
This commit is contained in:
@@ -1,3 +1,4 @@
|
|||||||
|
import os
|
||||||
import json
|
import json
|
||||||
import time
|
import time
|
||||||
from app.core.logging_config import get_agent_logger
|
from app.core.logging_config import get_agent_logger
|
||||||
@@ -13,7 +14,7 @@ from app.core.memory.agent.utils.session_tools import SessionService
|
|||||||
from app.core.memory.agent.utils.template_tools import TemplateService
|
from app.core.memory.agent.utils.template_tools import TemplateService
|
||||||
from app.core.memory.agent.services.optimized_llm_service import LLMServiceMixin
|
from app.core.memory.agent.services.optimized_llm_service import LLMServiceMixin
|
||||||
|
|
||||||
template_root = PROJECT_ROOT_ + '/agent/utils/prompt'
|
template_root = os.path.join(PROJECT_ROOT_, 'agent', 'utils', 'prompt')
|
||||||
db_session = next(get_db())
|
db_session = next(get_db())
|
||||||
logger = get_agent_logger(__name__)
|
logger = get_agent_logger(__name__)
|
||||||
|
|
||||||
@@ -35,11 +36,16 @@ async def Split_The_Problem(state: ReadState) -> ReadState:
|
|||||||
memory_config = state.get('memory_config', None)
|
memory_config = state.get('memory_config', None)
|
||||||
|
|
||||||
history = await SessionService(store).get_history(group_id, group_id, group_id)
|
history = await SessionService(store).get_history(group_id, group_id, group_id)
|
||||||
|
|
||||||
|
# 生成 JSON schema 以指导 LLM 输出正确格式
|
||||||
|
json_schema = ProblemExtensionResponse.model_json_schema()
|
||||||
|
|
||||||
system_prompt = await problem_service.template_service.render_template(
|
system_prompt = await problem_service.template_service.render_template(
|
||||||
template_name='problem_breakdown_prompt.jinja2',
|
template_name='problem_breakdown_prompt.jinja2',
|
||||||
operation_name='split_the_problem',
|
operation_name='split_the_problem',
|
||||||
history=history,
|
history=history,
|
||||||
sentence=content
|
sentence=content,
|
||||||
|
json_schema=json_schema
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -147,11 +153,16 @@ async def Problem_Extension(state: ReadState) -> ReadState:
|
|||||||
data = []
|
data = []
|
||||||
|
|
||||||
history = await SessionService(store).get_history(group_id, group_id, group_id)
|
history = await SessionService(store).get_history(group_id, group_id, group_id)
|
||||||
|
|
||||||
|
# 生成 JSON schema 以指导 LLM 输出正确格式
|
||||||
|
json_schema = ProblemExtensionResponse.model_json_schema()
|
||||||
|
|
||||||
system_prompt = await problem_service.template_service.render_template(
|
system_prompt = await problem_service.template_service.render_template(
|
||||||
template_name='Problem_Extension_prompt.jinja2',
|
template_name='Problem_Extension_prompt.jinja2',
|
||||||
operation_name='problem_extension',
|
operation_name='problem_extension',
|
||||||
history=history,
|
history=history,
|
||||||
questions=databasets
|
questions=databasets,
|
||||||
|
json_schema=json_schema
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
|
|
||||||
|
|
||||||
|
import os
|
||||||
import time
|
import time
|
||||||
|
|
||||||
from app.core.logging_config import get_agent_logger, log_time
|
from app.core.logging_config import get_agent_logger, log_time
|
||||||
@@ -19,7 +20,7 @@ from app.core.memory.agent.utils.session_tools import SessionService
|
|||||||
from app.core.memory.agent.utils.template_tools import TemplateService
|
from app.core.memory.agent.utils.template_tools import TemplateService
|
||||||
from app.core.memory.agent.services.optimized_llm_service import LLMServiceMixin
|
from app.core.memory.agent.services.optimized_llm_service import LLMServiceMixin
|
||||||
|
|
||||||
template_root = PROJECT_ROOT_ + '/agent/utils/prompt'
|
template_root = os.path.join(PROJECT_ROOT_, 'agent', 'utils', 'prompt')
|
||||||
logger = get_agent_logger(__name__)
|
logger = get_agent_logger(__name__)
|
||||||
db_session = next(get_db())
|
db_session = next(get_db())
|
||||||
|
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
|
import os
|
||||||
from app.core.logging_config import get_agent_logger
|
from app.core.logging_config import get_agent_logger
|
||||||
from app.db import get_db
|
from app.db import get_db
|
||||||
|
|
||||||
@@ -12,7 +12,7 @@ from app.core.memory.agent.utils.session_tools import SessionService
|
|||||||
from app.core.memory.agent.utils.template_tools import TemplateService
|
from app.core.memory.agent.utils.template_tools import TemplateService
|
||||||
from app.core.memory.agent.services.optimized_llm_service import LLMServiceMixin
|
from app.core.memory.agent.services.optimized_llm_service import LLMServiceMixin
|
||||||
|
|
||||||
template_root = PROJECT_ROOT_ + '/agent/utils/prompt'
|
template_root = os.path.join(PROJECT_ROOT_, 'agent', 'utils', 'prompt')
|
||||||
db_session = next(get_db())
|
db_session = next(get_db())
|
||||||
logger = get_agent_logger(__name__)
|
logger = get_agent_logger(__name__)
|
||||||
|
|
||||||
@@ -26,60 +26,130 @@ class VerificationNodeService(LLMServiceMixin):
|
|||||||
# 创建全局服务实例
|
# 创建全局服务实例
|
||||||
verification_service = VerificationNodeService()
|
verification_service = VerificationNodeService()
|
||||||
|
|
||||||
async def Verify_prompt(state: ReadState,messages_deal):
|
async def Verify_prompt(state: ReadState, messages_deal: VerificationResult):
|
||||||
|
"""处理验证结果并生成输出格式"""
|
||||||
storage_type = state.get('storage_type', '')
|
storage_type = state.get('storage_type', '')
|
||||||
user_rag_memory_id = state.get('user_rag_memory_id', '')
|
user_rag_memory_id = state.get('user_rag_memory_id', '')
|
||||||
data = state.get('data', '')
|
data = state.get('data', '')
|
||||||
|
|
||||||
|
# 将 VerificationItem 对象转换为字典列表
|
||||||
|
verified_data = []
|
||||||
|
if messages_deal.expansion_issue:
|
||||||
|
for item in messages_deal.expansion_issue:
|
||||||
|
if hasattr(item, 'model_dump'):
|
||||||
|
verified_data.append(item.model_dump())
|
||||||
|
elif isinstance(item, dict):
|
||||||
|
verified_data.append(item)
|
||||||
|
|
||||||
Verify_result = {
|
Verify_result = {
|
||||||
"status": messages_deal.split_result,
|
"status": messages_deal.split_result,
|
||||||
"verified_data": messages_deal.expansion_issue,
|
"verified_data": verified_data,
|
||||||
"storage_type": storage_type,
|
"storage_type": storage_type,
|
||||||
"user_rag_memory_id": user_rag_memory_id,
|
"user_rag_memory_id": user_rag_memory_id,
|
||||||
"_intermediate": {
|
"_intermediate": {
|
||||||
"type": "verification",
|
"type": "verification",
|
||||||
"title": "Data Verification",
|
"title": "Data Verification",
|
||||||
"result": messages_deal.split_result,
|
"result": messages_deal.split_result,
|
||||||
"reason": messages_deal.reason,
|
"reason": messages_deal.reason or "验证完成",
|
||||||
"query": data,
|
"query": messages_deal.query,
|
||||||
"verified_count": len(messages_deal.expansion_issue),
|
"verified_count": len(verified_data),
|
||||||
"storage_type": storage_type,
|
"storage_type": storage_type,
|
||||||
"user_rag_memory_id": user_rag_memory_id
|
"user_rag_memory_id": user_rag_memory_id
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return Verify_result
|
return Verify_result
|
||||||
async def Verify(state: ReadState):
|
async def Verify(state: ReadState):
|
||||||
content = state.get('data', '')
|
logger.info("=== Verify 节点开始执行 ===")
|
||||||
group_id = state.get('group_id', '')
|
try:
|
||||||
memory_config = state.get('memory_config', None)
|
content = state.get('data', '')
|
||||||
|
group_id = state.get('group_id', '')
|
||||||
|
memory_config = state.get('memory_config', None)
|
||||||
|
|
||||||
history = await SessionService(store).get_history(group_id, group_id, group_id)
|
logger.info(f"Verify: content={content[:50] if content else 'empty'}..., group_id={group_id}")
|
||||||
|
|
||||||
retrieve = state.get("retrieve", '')
|
history = await SessionService(store).get_history(group_id, group_id, group_id)
|
||||||
retrieve = retrieve.get("Expansion_issue", [])
|
logger.info(f"Verify: 获取历史记录完成,history length={len(history)}")
|
||||||
messages = {
|
|
||||||
"Query": content,
|
|
||||||
"Expansion_issue": retrieve
|
|
||||||
}
|
|
||||||
|
|
||||||
system_prompt = await verification_service.template_service.render_template(
|
retrieve = state.get("retrieve", {})
|
||||||
template_name='split_verify_prompt.jinja2',
|
logger.info(f"Verify: retrieve data type={type(retrieve)}, keys={retrieve.keys() if isinstance(retrieve, dict) else 'N/A'}")
|
||||||
operation_name='split_verify_prompt',
|
|
||||||
history=history,
|
|
||||||
sentence=messages
|
|
||||||
)
|
|
||||||
|
|
||||||
# 使用优化的LLM服务
|
retrieve_expansion = retrieve.get("Expansion_issue", []) if isinstance(retrieve, dict) else []
|
||||||
structured = await verification_service.call_llm_structured(
|
logger.info(f"Verify: Expansion_issue length={len(retrieve_expansion)}")
|
||||||
state=state,
|
|
||||||
db_session=db_session,
|
messages = {
|
||||||
system_prompt=system_prompt,
|
"Query": content,
|
||||||
response_model=VerificationResult,
|
"Expansion_issue": retrieve_expansion
|
||||||
fallback_value={
|
|
||||||
"split_result": "fail",
|
|
||||||
"expansion_issue": [],
|
|
||||||
"reason": "验证失败"
|
|
||||||
}
|
}
|
||||||
)
|
|
||||||
|
|
||||||
result = await Verify_prompt(state, structured)
|
logger.info("Verify: 开始渲染模板")
|
||||||
return {"verify": result}
|
|
||||||
|
# 生成 JSON schema 以指导 LLM 输出正确格式
|
||||||
|
json_schema = VerificationResult.model_json_schema()
|
||||||
|
|
||||||
|
system_prompt = await verification_service.template_service.render_template(
|
||||||
|
template_name='split_verify_prompt.jinja2',
|
||||||
|
operation_name='split_verify_prompt',
|
||||||
|
history=history,
|
||||||
|
sentence=messages,
|
||||||
|
json_schema=json_schema
|
||||||
|
)
|
||||||
|
logger.info(f"Verify: 模板渲染完成,prompt length={len(system_prompt)}")
|
||||||
|
|
||||||
|
# 使用优化的LLM服务,添加超时保护
|
||||||
|
logger.info("Verify: 开始调用 LLM")
|
||||||
|
try:
|
||||||
|
# 添加 asyncio.wait_for 超时包裹,防止无限等待
|
||||||
|
# 超时时间设置为 150 秒(比 LLM 配置的 120 秒稍长)
|
||||||
|
import asyncio
|
||||||
|
structured = await asyncio.wait_for(
|
||||||
|
verification_service.call_llm_structured(
|
||||||
|
state=state,
|
||||||
|
db_session=db_session,
|
||||||
|
system_prompt=system_prompt,
|
||||||
|
response_model=VerificationResult,
|
||||||
|
fallback_value={
|
||||||
|
"query": content,
|
||||||
|
"history": history if isinstance(history, list) else [],
|
||||||
|
"expansion_issue": [],
|
||||||
|
"split_result": "failed",
|
||||||
|
"reason": "验证失败或超时"
|
||||||
|
}
|
||||||
|
),
|
||||||
|
timeout=150.0 # 150秒超时
|
||||||
|
)
|
||||||
|
logger.info(f"Verify: LLM 调用完成,result={structured}")
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
logger.error("Verify: LLM 调用超时(150秒),使用 fallback 值")
|
||||||
|
structured = VerificationResult(
|
||||||
|
query=content,
|
||||||
|
history=history if isinstance(history, list) else [],
|
||||||
|
expansion_issue=[],
|
||||||
|
split_result="failed",
|
||||||
|
reason="LLM调用超时"
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await Verify_prompt(state, structured)
|
||||||
|
logger.info("=== Verify 节点执行完成 ===")
|
||||||
|
return {"verify": result}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Verify 节点执行失败: {e}", exc_info=True)
|
||||||
|
# 返回失败的验证结果
|
||||||
|
return {
|
||||||
|
"verify": {
|
||||||
|
"status": "failed",
|
||||||
|
"verified_data": [],
|
||||||
|
"storage_type": state.get('storage_type', ''),
|
||||||
|
"user_rag_memory_id": state.get('user_rag_memory_id', ''),
|
||||||
|
"_intermediate": {
|
||||||
|
"type": "verification",
|
||||||
|
"title": "Data Verification",
|
||||||
|
"result": "failed",
|
||||||
|
"reason": f"验证过程出错: {str(e)}",
|
||||||
|
"query": state.get('data', ''),
|
||||||
|
"verified_count": 0,
|
||||||
|
"storage_type": state.get('storage_type', ''),
|
||||||
|
"user_rag_memory_id": state.get('user_rag_memory_id', '')
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -59,6 +59,7 @@ async def make_read_graph():
|
|||||||
workflow.add_conditional_edges("Retrieve", Retrieve_continue)
|
workflow.add_conditional_edges("Retrieve", Retrieve_continue)
|
||||||
workflow.add_edge("Retrieve_Summary", END)
|
workflow.add_edge("Retrieve_Summary", END)
|
||||||
workflow.add_conditional_edges("Verify", Verify_continue)
|
workflow.add_conditional_edges("Verify", Verify_continue)
|
||||||
|
|
||||||
workflow.add_edge("Summary_fails", END)
|
workflow.add_edge("Summary_fails", END)
|
||||||
workflow.add_edge("Summary", END)
|
workflow.add_edge("Summary", END)
|
||||||
|
|
||||||
|
|||||||
@@ -45,18 +45,17 @@ def Retrieve_continue(state) -> Literal["Verify", "Retrieve_Summary"]:
|
|||||||
return 'Retrieve_Summary' # Default based on business logic
|
return 'Retrieve_Summary' # Default based on business logic
|
||||||
def Verify_continue(state: ReadState) -> Literal["Summary", "Summary_fails", "content_input"]:
|
def Verify_continue(state: ReadState) -> Literal["Summary", "Summary_fails", "content_input"]:
|
||||||
status=state.get('verify', '')['status']
|
status=state.get('verify', '')['status']
|
||||||
loop_count = counter.get_total()
|
# loop_count = counter.get_total()
|
||||||
print(status)
|
|
||||||
if "success" in status:
|
if "success" in status:
|
||||||
counter.reset()
|
# counter.reset()
|
||||||
return "Summary"
|
return "Summary"
|
||||||
elif "failed" in status:
|
elif "failed" in status:
|
||||||
if loop_count < 2: # Maximum loop count is 3
|
# if loop_count < 2: # Maximum loop count is 3
|
||||||
return "content_input"
|
# return "content_input"
|
||||||
else:
|
# else:
|
||||||
counter.reset()
|
# counter.reset()
|
||||||
return "Summary_fails"
|
return "Summary_fails"
|
||||||
# else:
|
else:
|
||||||
# # Add default return value to avoid returning None
|
# Add default return value to avoid returning None
|
||||||
# counter.reset()
|
# counter.reset()
|
||||||
# return "Summary" # Default based on business requirements
|
return "Summary" # Default based on business requirements
|
||||||
|
|||||||
@@ -4,11 +4,29 @@ from typing import List, Optional, Dict, Any
|
|||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
||||||
|
class VerificationItem(BaseModel):
|
||||||
|
"""Individual verification item for a query-answer pair."""
|
||||||
|
|
||||||
|
query_small: str = Field(..., description="子问题")
|
||||||
|
answer_small: str = Field(..., description="子问题的回答")
|
||||||
|
status: str = Field(..., description="验证状态:True 或 False")
|
||||||
|
query_answer: str = Field(..., description="问题的答案(与 answer_small 相同)")
|
||||||
|
|
||||||
|
|
||||||
class VerificationResult(BaseModel):
|
class VerificationResult(BaseModel):
|
||||||
"""Result model for verification operation."""
|
"""Result model for verification operation."""
|
||||||
|
|
||||||
query: str
|
query: str = Field(..., description="原始查询问题")
|
||||||
expansion_issue: List[Dict[str, Any]]
|
history: List[Dict[str, Any]] = Field(default_factory=list, description="历史对话记录")
|
||||||
split_result: str
|
expansion_issue: List[VerificationItem] = Field(
|
||||||
reason: Optional[str] = None
|
default_factory=list,
|
||||||
history: List[Dict[str, Any]] = Field(default_factory=list)
|
description="验证后的数据列表,包含所有通过验证的问答对"
|
||||||
|
)
|
||||||
|
split_result: str = Field(
|
||||||
|
...,
|
||||||
|
description="验证结果状态:success(expansion_issue 非空)或 failed(expansion_issue 为空)"
|
||||||
|
)
|
||||||
|
reason: Optional[str] = Field(
|
||||||
|
None,
|
||||||
|
description="验证结果的说明和分析"
|
||||||
|
)
|
||||||
|
|||||||
@@ -42,19 +42,33 @@
|
|||||||
如果状态是TRUE保留这条数据,否则不需要这条数据
|
如果状态是TRUE保留这条数据,否则不需要这条数据
|
||||||
### 第五步 输出格式
|
### 第五步 输出格式
|
||||||
按照json的形式输出
|
按照json的形式输出
|
||||||
{"data":"Query":原来Query的字段,"history":原来的history字段,
|
{"query":"原来Query的字段",
|
||||||
"expansion_issue":以为列表的形式存储验证之后的数据比如[
|
"history":"原来的history字段",
|
||||||
{"query_small": query_small,
|
"expansion_issue":以列表的形式存储验证之后的数据比如[
|
||||||
"answer_small": answer_small,,
|
|
||||||
"status": 回答的结果是否符合query_small,填写状态,
|
|
||||||
"query_answer": answer_small},
|
|
||||||
{
|
{
|
||||||
"query_small": "张曼婷生日是什么时候?",
|
"query_small": "子问题",
|
||||||
"answer_small": "张曼婷喜欢绘画。",
|
"answer_small": "子问题的回答",
|
||||||
"status": "True",
|
"status": "True或False,表示回答是否符合query_small",
|
||||||
"query_answer": "张曼 婷喜欢绘画。"
|
"query_answer": "问题的答案(与answer_small相同)"
|
||||||
},{}......]
|
},
|
||||||
,
|
{
|
||||||
"split_result":如果expansion_issue是空的列表返回failed,不是空列表返回success,
|
"query_small": "张曼婷生日是什么时候?",
|
||||||
"reason": 为以上分析完之后的结果给一个说明
|
"answer_small": "张曼婷喜欢绘画。",
|
||||||
}
|
"status": "False",
|
||||||
|
"query_answer": "张曼婷喜欢绘画。"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"split_result":"如果expansion_issue是空的列表返回failed,不是空列表返回success",
|
||||||
|
"reason": "为以上分析完之后的结果给一个说明"
|
||||||
|
}
|
||||||
|
|
||||||
|
**输出格式要求**
|
||||||
|
**CRITICAL JSON FORMATTING REQUIREMENTS:**
|
||||||
|
1. Use only standard ASCII double quotes (") for JSON structure - never use Chinese quotation marks (“”) or other Unicode quotes
|
||||||
|
2. If the extracted statement text contains quotation marks, escape them properly using backslashes (\")
|
||||||
|
3. Ensure all JSON strings are properly closed and comma-separated
|
||||||
|
4. Do not include line breaks within JSON string values
|
||||||
|
5. The output language should always be the same as the input language
|
||||||
|
|
||||||
|
**JSON Schema:**
|
||||||
|
{{ json_schema }}
|
||||||
Reference in New Issue
Block a user