快速检索,需要在接口部分添加LLM整合

This commit is contained in:
lixinyue
2026-01-21 18:16:49 +08:00
parent 6e1f6d886d
commit bd5b97e69b

View File

@@ -1,3 +1,4 @@
import os
import json import json
import time import time
from app.core.logging_config import get_agent_logger from app.core.logging_config import get_agent_logger
@@ -13,20 +14,23 @@ from app.core.memory.agent.utils.session_tools import SessionService
from app.core.memory.agent.utils.template_tools import TemplateService from app.core.memory.agent.utils.template_tools import TemplateService
from app.core.memory.agent.services.optimized_llm_service import LLMServiceMixin from app.core.memory.agent.services.optimized_llm_service import LLMServiceMixin
template_root = PROJECT_ROOT_ + '/agent/utils/prompt' template_root = os.path.join(PROJECT_ROOT_, 'agent', 'utils', 'prompt')
db_session = next(get_db()) db_session = next(get_db())
logger = get_agent_logger(__name__) logger = get_agent_logger(__name__)
class ProblemNodeService(LLMServiceMixin): class ProblemNodeService(LLMServiceMixin):
"""问题处理节点服务类""" """问题处理节点服务类"""
def __init__(self): def __init__(self):
super().__init__() super().__init__()
self.template_service = TemplateService(template_root) self.template_service = TemplateService(template_root)
# 创建全局服务实例 # 创建全局服务实例
problem_service = ProblemNodeService() problem_service = ProblemNodeService()
async def Split_The_Problem(state: ReadState) -> ReadState: async def Split_The_Problem(state: ReadState) -> ReadState:
"""问题分解节点""" """问题分解节点"""
# 从状态中获取数据 # 从状态中获取数据
@@ -35,13 +39,18 @@ async def Split_The_Problem(state: ReadState) -> ReadState:
memory_config = state.get('memory_config', None) memory_config = state.get('memory_config', None)
history = await SessionService(store).get_history(group_id, group_id, group_id) history = await SessionService(store).get_history(group_id, group_id, group_id)
# 生成 JSON schema 以指导 LLM 输出正确格式
json_schema = ProblemExtensionResponse.model_json_schema()
system_prompt = await problem_service.template_service.render_template( system_prompt = await problem_service.template_service.render_template(
template_name='problem_breakdown_prompt.jinja2', template_name='problem_breakdown_prompt.jinja2',
operation_name='split_the_problem', operation_name='split_the_problem',
history=history, history=history,
sentence=content sentence=content,
json_schema=json_schema
) )
try: try:
# 使用优化的LLM服务 # 使用优化的LLM服务
structured = await problem_service.call_llm_structured( structured = await problem_service.call_llm_structured(
@@ -51,10 +60,10 @@ async def Split_The_Problem(state: ReadState) -> ReadState:
response_model=ProblemExtensionResponse, response_model=ProblemExtensionResponse,
fallback_value=[] fallback_value=[]
) )
# 添加更详细的日志记录 # 添加更详细的日志记录
logger.info(f"Split_The_Problem: 开始处理问题分解,内容长度: {len(content)}") logger.info(f"Split_The_Problem: 开始处理问题分解,内容长度: {len(content)}")
# 验证结构化响应 # 验证结构化响应
if not structured or not hasattr(structured, 'root'): if not structured or not hasattr(structured, 'root'):
logger.warning("Split_The_Problem: 结构化响应为空或格式不正确") logger.warning("Split_The_Problem: 结构化响应为空或格式不正确")
@@ -67,17 +76,17 @@ async def Split_The_Problem(state: ReadState) -> ReadState:
[item.model_dump() for item in structured.root], [item.model_dump() for item in structured.root],
ensure_ascii=False ensure_ascii=False
) )
split_result_dict = [] split_result_dict = []
for index, item in enumerate(json.loads(split_result)): for index, item in enumerate(json.loads(split_result)):
split_data = { split_data = {
"id": f"Q{index+1}", "id": f"Q{index + 1}",
"question": item['extended_question'], "question": item['extended_question'],
"type": item['type'], "type": item['type'],
"reason": item['reason'] "reason": item['reason']
} }
split_result_dict.append(split_data) split_result_dict.append(split_data)
logger.info(f"Split_The_Problem: 成功生成 {len(structured.root) if structured.root else 0} 个分解项") logger.info(f"Split_The_Problem: 成功生成 {len(structured.root) if structured.root else 0} 个分解项")
result = { result = {
@@ -90,13 +99,13 @@ async def Split_The_Problem(state: ReadState) -> ReadState:
"original_query": content "original_query": content
} }
} }
except Exception as e: except Exception as e:
logger.error( logger.error(
f"Split_The_Problem failed: {e}", f"Split_The_Problem failed: {e}",
exc_info=True exc_info=True
) )
# 提供更详细的错误信息 # 提供更详细的错误信息
error_details = { error_details = {
"error_type": type(e).__name__, "error_type": type(e).__name__,
@@ -104,9 +113,9 @@ async def Split_The_Problem(state: ReadState) -> ReadState:
"content_length": len(content), "content_length": len(content),
"llm_model_id": memory_config.llm_model_id if memory_config else None "llm_model_id": memory_config.llm_model_id if memory_config else None
} }
logger.error(f"Split_The_Problem error details: {error_details}") logger.error(f"Split_The_Problem error details: {error_details}")
# 创建默认的空结果 # 创建默认的空结果
result = { result = {
"context": json.dumps([], ensure_ascii=False), "context": json.dumps([], ensure_ascii=False),
@@ -120,10 +129,11 @@ async def Split_The_Problem(state: ReadState) -> ReadState:
"error": error_details "error": error_details
} }
} }
# 返回更新后的状态包含spit_context字段 # 返回更新后的状态包含spit_context字段
return {"spit_data": result} return {"spit_data": result}
async def Problem_Extension(state: ReadState) -> ReadState: async def Problem_Extension(state: ReadState) -> ReadState:
"""问题扩展节点""" """问题扩展节点"""
# 获取原始数据和分解结果 # 获取原始数据和分解结果
@@ -147,11 +157,16 @@ async def Problem_Extension(state: ReadState) -> ReadState:
data = [] data = []
history = await SessionService(store).get_history(group_id, group_id, group_id) history = await SessionService(store).get_history(group_id, group_id, group_id)
# 生成 JSON schema 以指导 LLM 输出正确格式
json_schema = ProblemExtensionResponse.model_json_schema()
system_prompt = await problem_service.template_service.render_template( system_prompt = await problem_service.template_service.render_template(
template_name='Problem_Extension_prompt.jinja2', template_name='Problem_Extension_prompt.jinja2',
operation_name='problem_extension', operation_name='problem_extension',
history=history, history=history,
questions=databasets questions=databasets,
json_schema=json_schema
) )
try: try:
@@ -231,7 +246,4 @@ async def Problem_Extension(state: ReadState) -> ReadState:
} }
} }
return {"problem_extension": result} return {"problem_extension": result}