Fix/memory bug fix (#161)
* 图谱数据量限制数量去掉 * 图谱数据量限制数量去掉 * 图谱数据量限制数量去掉 * 用户详情优化 * 用户详情优化 * 用户详情优化 * 用户详情优化 * 用户详情优化 * 用户详情优化 * 读取的接口,去掉全局锁 * 输出数组 * 反思优化1.0(优化隐私输出、时间检索) * 反思优化1.0(优化隐私输出、时间检索) * 反思优化1.0(优化隐私输出、时间检索) * 反思优化测试接口 * 反思优化测试接口 * 读取接口内层嵌套BUG修复 * 读取接口内层嵌套BUG修复 * 读取接口内层嵌套BUG修复 * 读取接口内层嵌套BUG修复 * 读取接口内层嵌套BUG修复
This commit is contained in:
@@ -1,3 +1,4 @@
|
||||
import os
|
||||
import json
|
||||
import time
|
||||
from app.core.logging_config import get_agent_logger
|
||||
@@ -13,7 +14,7 @@ from app.core.memory.agent.utils.session_tools import SessionService
|
||||
from app.core.memory.agent.utils.template_tools import TemplateService
|
||||
from app.core.memory.agent.services.optimized_llm_service import LLMServiceMixin
|
||||
|
||||
template_root = PROJECT_ROOT_ + '/agent/utils/prompt'
|
||||
template_root = os.path.join(PROJECT_ROOT_, 'agent', 'utils', 'prompt')
|
||||
db_session = next(get_db())
|
||||
logger = get_agent_logger(__name__)
|
||||
|
||||
@@ -35,11 +36,16 @@ async def Split_The_Problem(state: ReadState) -> ReadState:
|
||||
memory_config = state.get('memory_config', None)
|
||||
|
||||
history = await SessionService(store).get_history(group_id, group_id, group_id)
|
||||
|
||||
# 生成 JSON schema 以指导 LLM 输出正确格式
|
||||
json_schema = ProblemExtensionResponse.model_json_schema()
|
||||
|
||||
system_prompt = await problem_service.template_service.render_template(
|
||||
template_name='problem_breakdown_prompt.jinja2',
|
||||
operation_name='split_the_problem',
|
||||
history=history,
|
||||
sentence=content
|
||||
sentence=content,
|
||||
json_schema=json_schema
|
||||
)
|
||||
|
||||
try:
|
||||
@@ -147,11 +153,16 @@ async def Problem_Extension(state: ReadState) -> ReadState:
|
||||
data = []
|
||||
|
||||
history = await SessionService(store).get_history(group_id, group_id, group_id)
|
||||
|
||||
# 生成 JSON schema 以指导 LLM 输出正确格式
|
||||
json_schema = ProblemExtensionResponse.model_json_schema()
|
||||
|
||||
system_prompt = await problem_service.template_service.render_template(
|
||||
template_name='Problem_Extension_prompt.jinja2',
|
||||
operation_name='problem_extension',
|
||||
history=history,
|
||||
questions=databasets
|
||||
questions=databasets,
|
||||
json_schema=json_schema
|
||||
)
|
||||
|
||||
try:
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
|
||||
|
||||
import os
|
||||
import time
|
||||
|
||||
from app.core.logging_config import get_agent_logger, log_time
|
||||
@@ -19,7 +20,7 @@ from app.core.memory.agent.utils.session_tools import SessionService
|
||||
from app.core.memory.agent.utils.template_tools import TemplateService
|
||||
from app.core.memory.agent.services.optimized_llm_service import LLMServiceMixin
|
||||
|
||||
template_root = PROJECT_ROOT_ + '/agent/utils/prompt'
|
||||
template_root = os.path.join(PROJECT_ROOT_, 'agent', 'utils', 'prompt')
|
||||
logger = get_agent_logger(__name__)
|
||||
db_session = next(get_db())
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
|
||||
import os
|
||||
from app.core.logging_config import get_agent_logger
|
||||
from app.db import get_db
|
||||
|
||||
@@ -12,7 +12,7 @@ from app.core.memory.agent.utils.session_tools import SessionService
|
||||
from app.core.memory.agent.utils.template_tools import TemplateService
|
||||
from app.core.memory.agent.services.optimized_llm_service import LLMServiceMixin
|
||||
|
||||
template_root = PROJECT_ROOT_ + '/agent/utils/prompt'
|
||||
template_root = os.path.join(PROJECT_ROOT_, 'agent', 'utils', 'prompt')
|
||||
db_session = next(get_db())
|
||||
logger = get_agent_logger(__name__)
|
||||
|
||||
@@ -26,60 +26,130 @@ class VerificationNodeService(LLMServiceMixin):
|
||||
# 创建全局服务实例
|
||||
verification_service = VerificationNodeService()
|
||||
|
||||
async def Verify_prompt(state: ReadState,messages_deal):
|
||||
async def Verify_prompt(state: ReadState, messages_deal: VerificationResult):
|
||||
"""处理验证结果并生成输出格式"""
|
||||
storage_type = state.get('storage_type', '')
|
||||
user_rag_memory_id = state.get('user_rag_memory_id', '')
|
||||
data = state.get('data', '')
|
||||
|
||||
# 将 VerificationItem 对象转换为字典列表
|
||||
verified_data = []
|
||||
if messages_deal.expansion_issue:
|
||||
for item in messages_deal.expansion_issue:
|
||||
if hasattr(item, 'model_dump'):
|
||||
verified_data.append(item.model_dump())
|
||||
elif isinstance(item, dict):
|
||||
verified_data.append(item)
|
||||
|
||||
Verify_result = {
|
||||
"status": messages_deal.split_result,
|
||||
"verified_data": messages_deal.expansion_issue,
|
||||
"verified_data": verified_data,
|
||||
"storage_type": storage_type,
|
||||
"user_rag_memory_id": user_rag_memory_id,
|
||||
"_intermediate": {
|
||||
"type": "verification",
|
||||
"title": "Data Verification",
|
||||
"result": messages_deal.split_result,
|
||||
"reason": messages_deal.reason,
|
||||
"query": data,
|
||||
"verified_count": len(messages_deal.expansion_issue),
|
||||
"reason": messages_deal.reason or "验证完成",
|
||||
"query": messages_deal.query,
|
||||
"verified_count": len(verified_data),
|
||||
"storage_type": storage_type,
|
||||
"user_rag_memory_id": user_rag_memory_id
|
||||
}
|
||||
}
|
||||
return Verify_result
|
||||
async def Verify(state: ReadState):
|
||||
content = state.get('data', '')
|
||||
group_id = state.get('group_id', '')
|
||||
memory_config = state.get('memory_config', None)
|
||||
logger.info("=== Verify 节点开始执行 ===")
|
||||
try:
|
||||
content = state.get('data', '')
|
||||
group_id = state.get('group_id', '')
|
||||
memory_config = state.get('memory_config', None)
|
||||
|
||||
logger.info(f"Verify: content={content[:50] if content else 'empty'}..., group_id={group_id}")
|
||||
|
||||
history = await SessionService(store).get_history(group_id, group_id, group_id)
|
||||
history = await SessionService(store).get_history(group_id, group_id, group_id)
|
||||
logger.info(f"Verify: 获取历史记录完成,history length={len(history)}")
|
||||
|
||||
retrieve = state.get("retrieve", '')
|
||||
retrieve = retrieve.get("Expansion_issue", [])
|
||||
messages = {
|
||||
"Query": content,
|
||||
"Expansion_issue": retrieve
|
||||
}
|
||||
|
||||
system_prompt = await verification_service.template_service.render_template(
|
||||
template_name='split_verify_prompt.jinja2',
|
||||
operation_name='split_verify_prompt',
|
||||
history=history,
|
||||
sentence=messages
|
||||
)
|
||||
|
||||
# 使用优化的LLM服务
|
||||
structured = await verification_service.call_llm_structured(
|
||||
state=state,
|
||||
db_session=db_session,
|
||||
system_prompt=system_prompt,
|
||||
response_model=VerificationResult,
|
||||
fallback_value={
|
||||
"split_result": "fail",
|
||||
"expansion_issue": [],
|
||||
"reason": "验证失败"
|
||||
retrieve = state.get("retrieve", {})
|
||||
logger.info(f"Verify: retrieve data type={type(retrieve)}, keys={retrieve.keys() if isinstance(retrieve, dict) else 'N/A'}")
|
||||
|
||||
retrieve_expansion = retrieve.get("Expansion_issue", []) if isinstance(retrieve, dict) else []
|
||||
logger.info(f"Verify: Expansion_issue length={len(retrieve_expansion)}")
|
||||
|
||||
messages = {
|
||||
"Query": content,
|
||||
"Expansion_issue": retrieve_expansion
|
||||
}
|
||||
)
|
||||
|
||||
result = await Verify_prompt(state, structured)
|
||||
return {"verify": result}
|
||||
|
||||
logger.info("Verify: 开始渲染模板")
|
||||
|
||||
# 生成 JSON schema 以指导 LLM 输出正确格式
|
||||
json_schema = VerificationResult.model_json_schema()
|
||||
|
||||
system_prompt = await verification_service.template_service.render_template(
|
||||
template_name='split_verify_prompt.jinja2',
|
||||
operation_name='split_verify_prompt',
|
||||
history=history,
|
||||
sentence=messages,
|
||||
json_schema=json_schema
|
||||
)
|
||||
logger.info(f"Verify: 模板渲染完成,prompt length={len(system_prompt)}")
|
||||
|
||||
# 使用优化的LLM服务,添加超时保护
|
||||
logger.info("Verify: 开始调用 LLM")
|
||||
try:
|
||||
# 添加 asyncio.wait_for 超时包裹,防止无限等待
|
||||
# 超时时间设置为 150 秒(比 LLM 配置的 120 秒稍长)
|
||||
import asyncio
|
||||
structured = await asyncio.wait_for(
|
||||
verification_service.call_llm_structured(
|
||||
state=state,
|
||||
db_session=db_session,
|
||||
system_prompt=system_prompt,
|
||||
response_model=VerificationResult,
|
||||
fallback_value={
|
||||
"query": content,
|
||||
"history": history if isinstance(history, list) else [],
|
||||
"expansion_issue": [],
|
||||
"split_result": "failed",
|
||||
"reason": "验证失败或超时"
|
||||
}
|
||||
),
|
||||
timeout=150.0 # 150秒超时
|
||||
)
|
||||
logger.info(f"Verify: LLM 调用完成,result={structured}")
|
||||
except asyncio.TimeoutError:
|
||||
logger.error("Verify: LLM 调用超时(150秒),使用 fallback 值")
|
||||
structured = VerificationResult(
|
||||
query=content,
|
||||
history=history if isinstance(history, list) else [],
|
||||
expansion_issue=[],
|
||||
split_result="failed",
|
||||
reason="LLM调用超时"
|
||||
)
|
||||
|
||||
result = await Verify_prompt(state, structured)
|
||||
logger.info("=== Verify 节点执行完成 ===")
|
||||
return {"verify": result}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Verify 节点执行失败: {e}", exc_info=True)
|
||||
# 返回失败的验证结果
|
||||
return {
|
||||
"verify": {
|
||||
"status": "failed",
|
||||
"verified_data": [],
|
||||
"storage_type": state.get('storage_type', ''),
|
||||
"user_rag_memory_id": state.get('user_rag_memory_id', ''),
|
||||
"_intermediate": {
|
||||
"type": "verification",
|
||||
"title": "Data Verification",
|
||||
"result": "failed",
|
||||
"reason": f"验证过程出错: {str(e)}",
|
||||
"query": state.get('data', ''),
|
||||
"verified_count": 0,
|
||||
"storage_type": state.get('storage_type', ''),
|
||||
"user_rag_memory_id": state.get('user_rag_memory_id', '')
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -59,6 +59,7 @@ async def make_read_graph():
|
||||
workflow.add_conditional_edges("Retrieve", Retrieve_continue)
|
||||
workflow.add_edge("Retrieve_Summary", END)
|
||||
workflow.add_conditional_edges("Verify", Verify_continue)
|
||||
|
||||
workflow.add_edge("Summary_fails", END)
|
||||
workflow.add_edge("Summary", END)
|
||||
|
||||
|
||||
@@ -45,18 +45,17 @@ def Retrieve_continue(state) -> Literal["Verify", "Retrieve_Summary"]:
|
||||
return 'Retrieve_Summary' # Default based on business logic
|
||||
def Verify_continue(state: ReadState) -> Literal["Summary", "Summary_fails", "content_input"]:
|
||||
status=state.get('verify', '')['status']
|
||||
loop_count = counter.get_total()
|
||||
print(status)
|
||||
# loop_count = counter.get_total()
|
||||
if "success" in status:
|
||||
counter.reset()
|
||||
# counter.reset()
|
||||
return "Summary"
|
||||
elif "failed" in status:
|
||||
if loop_count < 2: # Maximum loop count is 3
|
||||
return "content_input"
|
||||
else:
|
||||
counter.reset()
|
||||
return "Summary_fails"
|
||||
# else:
|
||||
# # Add default return value to avoid returning None
|
||||
# counter.reset()
|
||||
# return "Summary" # Default based on business requirements
|
||||
# if loop_count < 2: # Maximum loop count is 3
|
||||
# return "content_input"
|
||||
# else:
|
||||
# counter.reset()
|
||||
return "Summary_fails"
|
||||
else:
|
||||
# Add default return value to avoid returning None
|
||||
# counter.reset()
|
||||
return "Summary" # Default based on business requirements
|
||||
|
||||
@@ -4,11 +4,29 @@ from typing import List, Optional, Dict, Any
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class VerificationItem(BaseModel):
|
||||
"""Individual verification item for a query-answer pair."""
|
||||
|
||||
query_small: str = Field(..., description="子问题")
|
||||
answer_small: str = Field(..., description="子问题的回答")
|
||||
status: str = Field(..., description="验证状态:True 或 False")
|
||||
query_answer: str = Field(..., description="问题的答案(与 answer_small 相同)")
|
||||
|
||||
|
||||
class VerificationResult(BaseModel):
|
||||
"""Result model for verification operation."""
|
||||
|
||||
query: str
|
||||
expansion_issue: List[Dict[str, Any]]
|
||||
split_result: str
|
||||
reason: Optional[str] = None
|
||||
history: List[Dict[str, Any]] = Field(default_factory=list)
|
||||
query: str = Field(..., description="原始查询问题")
|
||||
history: List[Dict[str, Any]] = Field(default_factory=list, description="历史对话记录")
|
||||
expansion_issue: List[VerificationItem] = Field(
|
||||
default_factory=list,
|
||||
description="验证后的数据列表,包含所有通过验证的问答对"
|
||||
)
|
||||
split_result: str = Field(
|
||||
...,
|
||||
description="验证结果状态:success(expansion_issue 非空)或 failed(expansion_issue 为空)"
|
||||
)
|
||||
reason: Optional[str] = Field(
|
||||
None,
|
||||
description="验证结果的说明和分析"
|
||||
)
|
||||
|
||||
@@ -162,7 +162,7 @@ class OptimizedLLMService:
|
||||
return fallback_value
|
||||
elif isinstance(fallback_value, dict):
|
||||
return response_model(**fallback_value)
|
||||
|
||||
|
||||
# 尝试创建空的响应模型
|
||||
if hasattr(response_model, 'root'):
|
||||
# RootModel类型
|
||||
@@ -170,7 +170,7 @@ class OptimizedLLMService:
|
||||
else:
|
||||
# 普通BaseModel类型
|
||||
return response_model()
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"创建降级响应失败: {e}")
|
||||
# 最后的降级策略
|
||||
|
||||
@@ -42,19 +42,33 @@
|
||||
如果状态是TRUE保留这条数据,否则需不需要这条数据
|
||||
### 第五步 输出格式
|
||||
按照json的形式输出
|
||||
{"data":"Query":原来Query的字段,"history":原来的history字段,
|
||||
"expansion_issue":以为列表的形式存储验证之后的数据比如[
|
||||
{"query_small": query_small,
|
||||
"answer_small": answer_small,,
|
||||
"status": 回答的结果是否符合query_small,填写状态,
|
||||
"query_answer": answer_small},
|
||||
{"query":"原来Query的字段",
|
||||
"history":"原来的history字段",
|
||||
"expansion_issue":以列表的形式存储验证之后的数据比如[
|
||||
{
|
||||
"query_small": "张曼婷生日是什么时候?",
|
||||
"answer_small": "张曼婷喜欢绘画。",
|
||||
"status": "True",
|
||||
"query_answer": "张曼 婷喜欢绘画。"
|
||||
},{}......]
|
||||
,
|
||||
"split_result":如果expansion_issue是空的列表返回failed,不是空列表返回success,
|
||||
"reason": 为以上分析完之后的结果给一个说明
|
||||
}
|
||||
"query_small": "子问题",
|
||||
"answer_small": "子问题的回答",
|
||||
"status": "True或False,表示回答是否符合query_small",
|
||||
"query_answer": "问题的答案(与answer_small相同)"
|
||||
},
|
||||
{
|
||||
"query_small": "张曼婷生日是什么时候?",
|
||||
"answer_small": "张曼婷喜欢绘画。",
|
||||
"status": "False",
|
||||
"query_answer": "张曼婷喜欢绘画。"
|
||||
}
|
||||
],
|
||||
"split_result":"如果expansion_issue是空的列表返回failed,不是空列表返回success",
|
||||
"reason": "为以上分析完之后的结果给一个说明"
|
||||
}
|
||||
|
||||
**输出格式要求**
|
||||
**CRITICAL JSON FORMATTING REQUIREMENTS:**
|
||||
1. Use only standard ASCII double quotes (") for JSON structure - never use Chinese quotation marks ("") or other Unicode quotes
|
||||
2. If the extracted statement text contains quotation marks, escape them properly using backslashes (\")
|
||||
3. Ensure all JSON strings are properly closed and comma-separated
|
||||
4. Do not include line breaks within JSON string values
|
||||
5. The output language should always be the same as the input language
|
||||
|
||||
**JSON Schema:**
|
||||
{{ json_schema }}
|
||||
Reference in New Issue
Block a user