Fix/memory mcp2 1 (#170)
* 去掉MCP框架,重构 * 去掉MCP框架,重构 * 去掉MCP框架,重构 * 去掉MCP框架,重构 * 去掉MCP框架,重构 * 去掉MCP框架,重构 * 去掉MCP框架,重构 * feat(celery): add comprehensive logging to worker and write task - Initialize logging system in Celery worker entry point with LoggingConfig - Add logger instance and startup message to celery_worker.py - Reorganize imports in tasks.py for better readability and consistency - Add detailed logging to write_message_task for debugging and monitoring - Log task start with group_id, config_id, and storage_type parameters - Log service execution and completion status with results - Add exception handling with error logging and stack trace capture - Log task completion time and Celery task ID for performance tracking - Improves observability and troubleshooting of async task execution * 去掉MCP框架,重构 * 去掉MCP框架,重构 * 快速检索,需要在接口部分添加LLM整合 * 快速检索,需要在接口部分添加LLM整合 --------- Co-authored-by: Ke Sun <kesun5@illinois.edu>
This commit is contained in:
@@ -9,6 +9,8 @@ from app.db import get_db
|
|||||||
from app.dependencies import cur_workspace_access_guard, get_current_user
|
from app.dependencies import cur_workspace_access_guard, get_current_user
|
||||||
from app.models import ModelApiKey
|
from app.models import ModelApiKey
|
||||||
from app.models.user_model import User
|
from app.models.user_model import User
|
||||||
|
from app.core.memory.agent.utils.session_tools import SessionService
|
||||||
|
from app.core.memory.agent.utils.redis_tool import store
|
||||||
from app.repositories import knowledge_repository, WorkspaceRepository
|
from app.repositories import knowledge_repository, WorkspaceRepository
|
||||||
from app.schemas.memory_agent_schema import UserInput, Write_UserInput
|
from app.schemas.memory_agent_schema import UserInput, Write_UserInput
|
||||||
from app.schemas.response_schema import ApiResponse
|
from app.schemas.response_schema import ApiResponse
|
||||||
@@ -291,6 +293,19 @@ async def read_server(
|
|||||||
storage_type,
|
storage_type,
|
||||||
user_rag_memory_id
|
user_rag_memory_id
|
||||||
)
|
)
|
||||||
|
if str(user_input.search_switch) == "2":
|
||||||
|
retrieve_info = result['answer']
|
||||||
|
history = await SessionService(store).get_history(user_input.group_id, user_input.group_id, user_input.group_id)
|
||||||
|
query = user_input.message
|
||||||
|
|
||||||
|
# 调用 memory_agent_service 的方法生成最终答案
|
||||||
|
result['answer'] = await memory_agent_service.generate_summary_from_retrieve(
|
||||||
|
retrieve_info=retrieve_info,
|
||||||
|
history=history,
|
||||||
|
query=query,
|
||||||
|
config_id=config_id,
|
||||||
|
db=db
|
||||||
|
)
|
||||||
return success(data=result, msg="回复对话消息成功")
|
return success(data=result, msg="回复对话消息成功")
|
||||||
except BaseException as e:
|
except BaseException as e:
|
||||||
# Handle ExceptionGroup from TaskGroup (Python 3.11+) or BaseExceptionGroup
|
# Handle ExceptionGroup from TaskGroup (Python 3.11+) or BaseExceptionGroup
|
||||||
|
|||||||
@@ -18,16 +18,19 @@ template_root = os.path.join(PROJECT_ROOT_, 'agent', 'utils', 'prompt')
|
|||||||
db_session = next(get_db())
|
db_session = next(get_db())
|
||||||
logger = get_agent_logger(__name__)
|
logger = get_agent_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class ProblemNodeService(LLMServiceMixin):
|
class ProblemNodeService(LLMServiceMixin):
|
||||||
"""问题处理节点服务类"""
|
"""问题处理节点服务类"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.template_service = TemplateService(template_root)
|
self.template_service = TemplateService(template_root)
|
||||||
|
|
||||||
|
|
||||||
# 创建全局服务实例
|
# 创建全局服务实例
|
||||||
problem_service = ProblemNodeService()
|
problem_service = ProblemNodeService()
|
||||||
|
|
||||||
|
|
||||||
async def Split_The_Problem(state: ReadState) -> ReadState:
|
async def Split_The_Problem(state: ReadState) -> ReadState:
|
||||||
"""问题分解节点"""
|
"""问题分解节点"""
|
||||||
# 从状态中获取数据
|
# 从状态中获取数据
|
||||||
@@ -36,10 +39,10 @@ async def Split_The_Problem(state: ReadState) -> ReadState:
|
|||||||
memory_config = state.get('memory_config', None)
|
memory_config = state.get('memory_config', None)
|
||||||
|
|
||||||
history = await SessionService(store).get_history(group_id, group_id, group_id)
|
history = await SessionService(store).get_history(group_id, group_id, group_id)
|
||||||
|
|
||||||
# 生成 JSON schema 以指导 LLM 输出正确格式
|
# 生成 JSON schema 以指导 LLM 输出正确格式
|
||||||
json_schema = ProblemExtensionResponse.model_json_schema()
|
json_schema = ProblemExtensionResponse.model_json_schema()
|
||||||
|
|
||||||
system_prompt = await problem_service.template_service.render_template(
|
system_prompt = await problem_service.template_service.render_template(
|
||||||
template_name='problem_breakdown_prompt.jinja2',
|
template_name='problem_breakdown_prompt.jinja2',
|
||||||
operation_name='split_the_problem',
|
operation_name='split_the_problem',
|
||||||
@@ -47,7 +50,7 @@ async def Split_The_Problem(state: ReadState) -> ReadState:
|
|||||||
sentence=content,
|
sentence=content,
|
||||||
json_schema=json_schema
|
json_schema=json_schema
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 使用优化的LLM服务
|
# 使用优化的LLM服务
|
||||||
structured = await problem_service.call_llm_structured(
|
structured = await problem_service.call_llm_structured(
|
||||||
@@ -57,10 +60,10 @@ async def Split_The_Problem(state: ReadState) -> ReadState:
|
|||||||
response_model=ProblemExtensionResponse,
|
response_model=ProblemExtensionResponse,
|
||||||
fallback_value=[]
|
fallback_value=[]
|
||||||
)
|
)
|
||||||
|
|
||||||
# 添加更详细的日志记录
|
# 添加更详细的日志记录
|
||||||
logger.info(f"Split_The_Problem: 开始处理问题分解,内容长度: {len(content)}")
|
logger.info(f"Split_The_Problem: 开始处理问题分解,内容长度: {len(content)}")
|
||||||
|
|
||||||
# 验证结构化响应
|
# 验证结构化响应
|
||||||
if not structured or not hasattr(structured, 'root'):
|
if not structured or not hasattr(structured, 'root'):
|
||||||
logger.warning("Split_The_Problem: 结构化响应为空或格式不正确")
|
logger.warning("Split_The_Problem: 结构化响应为空或格式不正确")
|
||||||
@@ -73,17 +76,17 @@ async def Split_The_Problem(state: ReadState) -> ReadState:
|
|||||||
[item.model_dump() for item in structured.root],
|
[item.model_dump() for item in structured.root],
|
||||||
ensure_ascii=False
|
ensure_ascii=False
|
||||||
)
|
)
|
||||||
|
|
||||||
split_result_dict = []
|
split_result_dict = []
|
||||||
for index, item in enumerate(json.loads(split_result)):
|
for index, item in enumerate(json.loads(split_result)):
|
||||||
split_data = {
|
split_data = {
|
||||||
"id": f"Q{index+1}",
|
"id": f"Q{index + 1}",
|
||||||
"question": item['extended_question'],
|
"question": item['extended_question'],
|
||||||
"type": item['type'],
|
"type": item['type'],
|
||||||
"reason": item['reason']
|
"reason": item['reason']
|
||||||
}
|
}
|
||||||
split_result_dict.append(split_data)
|
split_result_dict.append(split_data)
|
||||||
|
|
||||||
logger.info(f"Split_The_Problem: 成功生成 {len(structured.root) if structured.root else 0} 个分解项")
|
logger.info(f"Split_The_Problem: 成功生成 {len(structured.root) if structured.root else 0} 个分解项")
|
||||||
|
|
||||||
result = {
|
result = {
|
||||||
@@ -96,13 +99,13 @@ async def Split_The_Problem(state: ReadState) -> ReadState:
|
|||||||
"original_query": content
|
"original_query": content
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(
|
logger.error(
|
||||||
f"Split_The_Problem failed: {e}",
|
f"Split_The_Problem failed: {e}",
|
||||||
exc_info=True
|
exc_info=True
|
||||||
)
|
)
|
||||||
|
|
||||||
# 提供更详细的错误信息
|
# 提供更详细的错误信息
|
||||||
error_details = {
|
error_details = {
|
||||||
"error_type": type(e).__name__,
|
"error_type": type(e).__name__,
|
||||||
@@ -110,9 +113,9 @@ async def Split_The_Problem(state: ReadState) -> ReadState:
|
|||||||
"content_length": len(content),
|
"content_length": len(content),
|
||||||
"llm_model_id": memory_config.llm_model_id if memory_config else None
|
"llm_model_id": memory_config.llm_model_id if memory_config else None
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.error(f"Split_The_Problem error details: {error_details}")
|
logger.error(f"Split_The_Problem error details: {error_details}")
|
||||||
|
|
||||||
# 创建默认的空结果
|
# 创建默认的空结果
|
||||||
result = {
|
result = {
|
||||||
"context": json.dumps([], ensure_ascii=False),
|
"context": json.dumps([], ensure_ascii=False),
|
||||||
@@ -126,10 +129,11 @@ async def Split_The_Problem(state: ReadState) -> ReadState:
|
|||||||
"error": error_details
|
"error": error_details
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
# 返回更新后的状态,包含spit_context字段
|
# 返回更新后的状态,包含spit_context字段
|
||||||
return {"spit_data": result}
|
return {"spit_data": result}
|
||||||
|
|
||||||
|
|
||||||
async def Problem_Extension(state: ReadState) -> ReadState:
|
async def Problem_Extension(state: ReadState) -> ReadState:
|
||||||
"""问题扩展节点"""
|
"""问题扩展节点"""
|
||||||
# 获取原始数据和分解结果
|
# 获取原始数据和分解结果
|
||||||
@@ -153,10 +157,10 @@ async def Problem_Extension(state: ReadState) -> ReadState:
|
|||||||
data = []
|
data = []
|
||||||
|
|
||||||
history = await SessionService(store).get_history(group_id, group_id, group_id)
|
history = await SessionService(store).get_history(group_id, group_id, group_id)
|
||||||
|
|
||||||
# 生成 JSON schema 以指导 LLM 输出正确格式
|
# 生成 JSON schema 以指导 LLM 输出正确格式
|
||||||
json_schema = ProblemExtensionResponse.model_json_schema()
|
json_schema = ProblemExtensionResponse.model_json_schema()
|
||||||
|
|
||||||
system_prompt = await problem_service.template_service.render_template(
|
system_prompt = await problem_service.template_service.render_template(
|
||||||
template_name='Problem_Extension_prompt.jinja2',
|
template_name='Problem_Extension_prompt.jinja2',
|
||||||
operation_name='problem_extension',
|
operation_name='problem_extension',
|
||||||
@@ -242,7 +246,4 @@ async def Problem_Extension(state: ReadState) -> ReadState:
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return {"problem_extension": result}
|
return {"problem_extension": result}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@@ -59,7 +59,6 @@ async def make_read_graph():
|
|||||||
workflow.add_conditional_edges("Retrieve", Retrieve_continue)
|
workflow.add_conditional_edges("Retrieve", Retrieve_continue)
|
||||||
workflow.add_edge("Retrieve_Summary", END)
|
workflow.add_edge("Retrieve_Summary", END)
|
||||||
workflow.add_conditional_edges("Verify", Verify_continue)
|
workflow.add_conditional_edges("Verify", Verify_continue)
|
||||||
|
|
||||||
workflow.add_edge("Summary_fails", END)
|
workflow.add_edge("Summary_fails", END)
|
||||||
workflow.add_edge("Summary", END)
|
workflow.add_edge("Summary", END)
|
||||||
|
|
||||||
|
|||||||
@@ -162,7 +162,7 @@ class OptimizedLLMService:
|
|||||||
return fallback_value
|
return fallback_value
|
||||||
elif isinstance(fallback_value, dict):
|
elif isinstance(fallback_value, dict):
|
||||||
return response_model(**fallback_value)
|
return response_model(**fallback_value)
|
||||||
|
|
||||||
# 尝试创建空的响应模型
|
# 尝试创建空的响应模型
|
||||||
if hasattr(response_model, 'root'):
|
if hasattr(response_model, 'root'):
|
||||||
# RootModel类型
|
# RootModel类型
|
||||||
@@ -170,7 +170,7 @@ class OptimizedLLMService:
|
|||||||
else:
|
else:
|
||||||
# 普通BaseModel类型
|
# 普通BaseModel类型
|
||||||
return response_model()
|
return response_model()
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"创建降级响应失败: {e}")
|
logger.error(f"创建降级响应失败: {e}")
|
||||||
# 最后的降级策略
|
# 最后的降级策略
|
||||||
|
|||||||
@@ -683,7 +683,67 @@ class MemoryAgentService:
|
|||||||
logger.debug(f"Message type: {status}")
|
logger.debug(f"Message type: {status}")
|
||||||
return status
|
return status
|
||||||
|
|
||||||
# ==================== 新增的三个接口方法 ====================
|
async def generate_summary_from_retrieve(
|
||||||
|
self,
|
||||||
|
retrieve_info: str,
|
||||||
|
history: List[Dict],
|
||||||
|
query: str,
|
||||||
|
config_id: str,
|
||||||
|
db: Session
|
||||||
|
) -> str:
|
||||||
|
"""
|
||||||
|
基于检索信息、历史对话和查询生成最终答案
|
||||||
|
|
||||||
|
使用 Retrieve_Summary_prompt.jinja2 模板调用大模型生成答案
|
||||||
|
|
||||||
|
Args:
|
||||||
|
retrieve_info: 检索到的信息
|
||||||
|
history: 历史对话记录
|
||||||
|
query: 用户查询
|
||||||
|
config_id: 配置ID
|
||||||
|
db: 数据库会话
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
生成的答案文本
|
||||||
|
"""
|
||||||
|
logger.info(f"Generating summary from retrieve info for query: {query[:50]}...")
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 加载配置
|
||||||
|
config_service = MemoryConfigService(db)
|
||||||
|
memory_config = config_service.load_memory_config(
|
||||||
|
config_id=config_id,
|
||||||
|
service_name="MemoryAgentService"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 导入必要的模块
|
||||||
|
from app.core.memory.agent.langgraph_graph.nodes.summary_nodes import summary_llm
|
||||||
|
from app.core.memory.agent.models.summary_models import RetrieveSummaryResponse
|
||||||
|
|
||||||
|
# 构建状态对象
|
||||||
|
state = {
|
||||||
|
"data": query,
|
||||||
|
"memory_config": memory_config
|
||||||
|
}
|
||||||
|
|
||||||
|
# 直接调用 summary_llm 函数
|
||||||
|
answer = await summary_llm(
|
||||||
|
state=state,
|
||||||
|
history=history,
|
||||||
|
retrieve_info=retrieve_info,
|
||||||
|
template_name='Retrieve_Summary_prompt.jinja2',
|
||||||
|
operation_name='retrieve_summary',
|
||||||
|
response_model=RetrieveSummaryResponse,
|
||||||
|
search_mode="1"
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(f"Successfully generated summary: {answer[:100] if answer else 'None'}...")
|
||||||
|
return answer if answer else "信息不足,无法回答。"
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"生成摘要失败: {str(e)}", exc_info=True)
|
||||||
|
return "信息不足,无法回答。"
|
||||||
|
|
||||||
|
|
||||||
async def get_knowledge_type_stats(
|
async def get_knowledge_type_stats(
|
||||||
self,
|
self,
|
||||||
|
|||||||
Reference in New Issue
Block a user