快速检索,需要在接口部分添加LLM整合
This commit is contained in:
@@ -1,3 +1,4 @@
|
|||||||
|
import os
|
||||||
import json
|
import json
|
||||||
import time
|
import time
|
||||||
from app.core.logging_config import get_agent_logger
|
from app.core.logging_config import get_agent_logger
|
||||||
@@ -13,10 +14,11 @@ from app.core.memory.agent.utils.session_tools import SessionService
|
|||||||
from app.core.memory.agent.utils.template_tools import TemplateService
|
from app.core.memory.agent.utils.template_tools import TemplateService
|
||||||
from app.core.memory.agent.services.optimized_llm_service import LLMServiceMixin
|
from app.core.memory.agent.services.optimized_llm_service import LLMServiceMixin
|
||||||
|
|
||||||
template_root = PROJECT_ROOT_ + '/agent/utils/prompt'
|
template_root = os.path.join(PROJECT_ROOT_, 'agent', 'utils', 'prompt')
|
||||||
db_session = next(get_db())
|
db_session = next(get_db())
|
||||||
logger = get_agent_logger(__name__)
|
logger = get_agent_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class ProblemNodeService(LLMServiceMixin):
|
class ProblemNodeService(LLMServiceMixin):
|
||||||
"""问题处理节点服务类"""
|
"""问题处理节点服务类"""
|
||||||
|
|
||||||
@@ -24,9 +26,11 @@ class ProblemNodeService(LLMServiceMixin):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
self.template_service = TemplateService(template_root)
|
self.template_service = TemplateService(template_root)
|
||||||
|
|
||||||
|
|
||||||
# 创建全局服务实例
|
# 创建全局服务实例
|
||||||
problem_service = ProblemNodeService()
|
problem_service = ProblemNodeService()
|
||||||
|
|
||||||
|
|
||||||
async def Split_The_Problem(state: ReadState) -> ReadState:
|
async def Split_The_Problem(state: ReadState) -> ReadState:
|
||||||
"""问题分解节点"""
|
"""问题分解节点"""
|
||||||
# 从状态中获取数据
|
# 从状态中获取数据
|
||||||
@@ -35,11 +39,16 @@ async def Split_The_Problem(state: ReadState) -> ReadState:
|
|||||||
memory_config = state.get('memory_config', None)
|
memory_config = state.get('memory_config', None)
|
||||||
|
|
||||||
history = await SessionService(store).get_history(group_id, group_id, group_id)
|
history = await SessionService(store).get_history(group_id, group_id, group_id)
|
||||||
|
|
||||||
|
# 生成 JSON schema 以指导 LLM 输出正确格式
|
||||||
|
json_schema = ProblemExtensionResponse.model_json_schema()
|
||||||
|
|
||||||
system_prompt = await problem_service.template_service.render_template(
|
system_prompt = await problem_service.template_service.render_template(
|
||||||
template_name='problem_breakdown_prompt.jinja2',
|
template_name='problem_breakdown_prompt.jinja2',
|
||||||
operation_name='split_the_problem',
|
operation_name='split_the_problem',
|
||||||
history=history,
|
history=history,
|
||||||
sentence=content
|
sentence=content,
|
||||||
|
json_schema=json_schema
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -71,7 +80,7 @@ async def Split_The_Problem(state: ReadState) -> ReadState:
|
|||||||
split_result_dict = []
|
split_result_dict = []
|
||||||
for index, item in enumerate(json.loads(split_result)):
|
for index, item in enumerate(json.loads(split_result)):
|
||||||
split_data = {
|
split_data = {
|
||||||
"id": f"Q{index+1}",
|
"id": f"Q{index + 1}",
|
||||||
"question": item['extended_question'],
|
"question": item['extended_question'],
|
||||||
"type": item['type'],
|
"type": item['type'],
|
||||||
"reason": item['reason']
|
"reason": item['reason']
|
||||||
@@ -124,6 +133,7 @@ async def Split_The_Problem(state: ReadState) -> ReadState:
|
|||||||
# 返回更新后的状态,包含spit_context字段
|
# 返回更新后的状态,包含spit_context字段
|
||||||
return {"spit_data": result}
|
return {"spit_data": result}
|
||||||
|
|
||||||
|
|
||||||
async def Problem_Extension(state: ReadState) -> ReadState:
|
async def Problem_Extension(state: ReadState) -> ReadState:
|
||||||
"""问题扩展节点"""
|
"""问题扩展节点"""
|
||||||
# 获取原始数据和分解结果
|
# 获取原始数据和分解结果
|
||||||
@@ -147,11 +157,16 @@ async def Problem_Extension(state: ReadState) -> ReadState:
|
|||||||
data = []
|
data = []
|
||||||
|
|
||||||
history = await SessionService(store).get_history(group_id, group_id, group_id)
|
history = await SessionService(store).get_history(group_id, group_id, group_id)
|
||||||
|
|
||||||
|
# 生成 JSON schema 以指导 LLM 输出正确格式
|
||||||
|
json_schema = ProblemExtensionResponse.model_json_schema()
|
||||||
|
|
||||||
system_prompt = await problem_service.template_service.render_template(
|
system_prompt = await problem_service.template_service.render_template(
|
||||||
template_name='Problem_Extension_prompt.jinja2',
|
template_name='Problem_Extension_prompt.jinja2',
|
||||||
operation_name='problem_extension',
|
operation_name='problem_extension',
|
||||||
history=history,
|
history=history,
|
||||||
questions=databasets
|
questions=databasets,
|
||||||
|
json_schema=json_schema
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -232,6 +247,3 @@ async def Problem_Extension(state: ReadState) -> ReadState:
|
|||||||
}
|
}
|
||||||
|
|
||||||
return {"problem_extension": result}
|
return {"problem_extension": result}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user