diff --git a/api/app/controllers/memory_agent_controller.py b/api/app/controllers/memory_agent_controller.py index ccf93d68..e3d2bf92 100644 --- a/api/app/controllers/memory_agent_controller.py +++ b/api/app/controllers/memory_agent_controller.py @@ -1,28 +1,29 @@ from typing import List, Optional +from dotenv import load_dotenv +from fastapi import APIRouter, Depends, File, Form, Query, UploadFile, Header +from sqlalchemy.orm import Session +from starlette.responses import StreamingResponse + from app.cache.memory.interest_memory import InterestMemoryCache from app.celery_app import celery_app from app.core.error_codes import BizCode from app.core.language_utils import get_language_from_header from app.core.logging_config import get_api_logger +from app.core.memory.agent.utils.redis_tool import store +from app.core.memory.agent.utils.session_tools import SessionService from app.core.rag.llm.cv_model import QWenCV from app.core.response_utils import fail, success from app.db import get_db from app.dependencies import cur_workspace_access_guard, get_current_user from app.models import ModelApiKey from app.models.user_model import User -from app.core.memory.agent.utils.session_tools import SessionService -from app.core.memory.agent.utils.redis_tool import store -from app.repositories import knowledge_repository, WorkspaceRepository +from app.repositories import knowledge_repository from app.schemas.memory_agent_schema import UserInput, Write_UserInput from app.schemas.response_schema import ApiResponse from app.services import task_service, workspace_service from app.services.memory_agent_service import MemoryAgentService from app.services.model_service import ModelConfigService -from dotenv import load_dotenv -from fastapi import APIRouter, Depends, File, Form, Query, UploadFile,Header -from sqlalchemy.orm import Session -from starlette.responses import StreamingResponse load_dotenv() api_logger = get_api_logger() @@ -37,7 +38,7 @@ router = APIRouter( @router.get("/health/status", response_model=ApiResponse) async def get_health_status( - current_user: User = Depends(get_current_user) + current_user: User = Depends(get_current_user) ): """ Get latest health status written by Celery periodic task @@ -55,8 +56,9 @@ async def get_health_status( @router.get("/download_log") async def download_log( - log_type: str = Query("file", regex="^(file|transmission)$", description="日志类型: file=完整文件, transmission=实时流式传输"), - current_user: User = Depends(get_current_user) + log_type: str = Query("file", regex="^(file|transmission)$", + description="日志类型: file=完整文件, transmission=实时流式传输"), + current_user: User = Depends(get_current_user) ): """ Download or stream agent service log file @@ -75,16 +77,16 @@ async def download_log( - transmission mode: StreamingResponse with SSE """ api_logger.info(f"Log download requested with log_type={log_type}") - + # Validate log_type parameter (FastAPI Query regex already validates, but explicit check for clarity) if log_type not in ["file", "transmission"]: api_logger.warning(f"Invalid log_type parameter: {log_type}") return fail( - BizCode.BAD_REQUEST, - "无效的log_type参数", + BizCode.BAD_REQUEST, + "无效的log_type参数", "log_type必须是'file'或'transmission'" ) - + # Route to appropriate mode if log_type == "file": # File mode: Return complete log file content @@ -119,10 +121,10 @@ async def download_log( @router.post("/writer_service", response_model=ApiResponse) @cur_workspace_access_guard() async def write_server( - user_input: Write_UserInput, - language_type: str = Header(default=None, alias="X-Language-Type"), - db: Session = Depends(get_db), - current_user: User = Depends(get_current_user) + user_input: Write_UserInput, + language_type: str = Header(default=None, alias="X-Language-Type"), + db: Session = Depends(get_db), + current_user: User = Depends(get_current_user) ): """ Write service endpoint - processes write operations synchronously @@ -136,11 +138,11 @@ async def write_server( """ # 使用集中化的语言校验 language = get_language_from_header(language_type) - + config_id = user_input.config_id workspace_id = current_user.current_workspace_id api_logger.info(f"Write service: workspace_id={workspace_id}, config_id={config_id}, language_type={language}") - + # 获取 storage_type,如果为 None 则使用默认值 storage_type = workspace_service.get_workspace_storage_type( db=db, @@ -149,7 +151,7 @@ async def write_server( ) if storage_type is None: storage_type = 'neo4j' user_rag_memory_id = '' - + # 如果 storage_type 是 rag,必须确保有有效的 user_rag_memory_id if storage_type == 'rag': if workspace_id: @@ -161,13 +163,15 @@ async def write_server( if knowledge: user_rag_memory_id = str(knowledge.id) else: - api_logger.warning(f"未找到名为 'USER_RAG_MERORY' 的知识库,workspace_id: {workspace_id},将使用 neo4j 存储") + api_logger.warning( + f"未找到名为 'USER_RAG_MERORY' 的知识库,workspace_id: {workspace_id},将使用 neo4j 存储") storage_type = 'neo4j' else: api_logger.warning("workspace_id 为空,无法使用 rag 存储,将使用 neo4j 存储") storage_type = 'neo4j' - - api_logger.info(f"Write service requested for group {user_input.end_user_id}, storage_type: {storage_type}, user_rag_memory_id: {user_rag_memory_id}") + + api_logger.info( + f"Write service requested for group {user_input.end_user_id}, storage_type: {storage_type}, user_rag_memory_id: {user_rag_memory_id}") try: messages_list = memory_agent_service.get_messages_list(user_input) result = await memory_agent_service.write_memory( @@ -175,7 +179,7 @@ async def write_server( messages_list, config_id, db, - storage_type, + storage_type, user_rag_memory_id, language ) @@ -195,10 +199,10 @@ async def write_server( @router.post("/writer_service_async", response_model=ApiResponse) @cur_workspace_access_guard() async def write_server_async( - user_input: Write_UserInput, - language_type: str = Header(default=None, alias="X-Language-Type"), - db: Session = Depends(get_db), - current_user: User = Depends(get_current_user) + user_input: Write_UserInput, + language_type: str = Header(default=None, alias="X-Language-Type"), + db: Session = Depends(get_db), + current_user: User = Depends(get_current_user) ): """ Async write service endpoint - enqueues write processing to Celery @@ -213,10 +217,11 @@ async def write_server_async( """ # 使用集中化的语言校验 language = get_language_from_header(language_type) - + config_id = user_input.config_id workspace_id = current_user.current_workspace_id - api_logger.info(f"Async write service: workspace_id={workspace_id}, config_id={config_id}, language_type={language}") + api_logger.info( + f"Async write service: workspace_id={workspace_id}, config_id={config_id}, language_type={language}") # 获取 storage_type,如果为 None 则使用默认值 storage_type = workspace_service.get_workspace_storage_type( @@ -244,7 +249,7 @@ async def write_server_async( args=[user_input.end_user_id, messages_list, config_id, storage_type, user_rag_memory_id, language] ) api_logger.info(f"Write task queued: {task.id}") - + return success(data={"task_id": task.id}, msg="写入任务已提交") except Exception as e: api_logger.error(f"Async write operation failed: {str(e)}") @@ -254,9 +259,9 @@ async def write_server_async( @router.post("/read_service", response_model=ApiResponse) @cur_workspace_access_guard() async def read_server( - user_input: UserInput, - db: Session = Depends(get_db), - current_user: User = Depends(get_current_user) + user_input: UserInput, + db: Session = Depends(get_db), + current_user: User = Depends(get_current_user) ): """ Read service endpoint - processes read operations synchronously @@ -291,8 +296,9 @@ async def read_server( ) if knowledge: user_rag_memory_id = str(knowledge.id) - - api_logger.info(f"Read service: group={user_input.end_user_id}, storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}, workspace_id={workspace_id}") + + api_logger.info( + f"Read service: group={user_input.end_user_id}, storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}, workspace_id={workspace_id}") try: result = await memory_agent_service.read_memory( user_input.end_user_id, @@ -306,7 +312,8 @@ async def read_server( ) if str(user_input.search_switch) == "2": retrieve_info = result['answer'] - history = await SessionService(store).get_history(user_input.end_user_id, user_input.end_user_id, user_input.end_user_id) + history = await SessionService(store).get_history(user_input.end_user_id, user_input.end_user_id, + user_input.end_user_id) query = user_input.message # 调用 memory_agent_service 的方法生成最终答案 @@ -319,7 +326,7 @@ async def read_server( db=db ) if "信息不足,无法回答" in result['answer']: - result['answer']=retrieve_info + result['answer'] = retrieve_info return success(data=result, msg="回复对话消息成功") except BaseException as e: # Handle ExceptionGroup from TaskGroup (Python 3.11+) or BaseExceptionGroup @@ -335,9 +342,10 @@ async def read_server( @router.post("/file", response_model=ApiResponse) async def file_update( files: List[UploadFile] = File(..., description="要上传的文件"), - model_id:str = Form(..., description="模型ID"), + model_id: str = Form(..., description="模型ID"), metadata: Optional[str] = Form(None, description="文件元数据 (JSON格式)"), - current_user: User = Depends(get_current_user) + current_user: User = Depends(get_current_user), + db: Session = Depends(get_db), ): """ 文件上传接口 - 支持图片识别 @@ -350,9 +358,6 @@ async def file_update( Returns: 文件处理结果 """ - - db_gen = get_db() # get_db 通常是一个生成器 - db = next(db_gen) api_logger.info(f"File upload requested, file count: {len(files)}") config = ModelConfigService.get_model_by_id(db=db, model_id=model_id) apiConfig: ModelApiKey = config.api_keys[0] @@ -361,7 +366,7 @@ async def file_update( for file in files: api_logger.debug(f"Processing file: {file.filename}, content_type: {file.content_type}") content = await file.read() - + if file.content_type and file.content_type.startswith("image/"): vision_model = QWenCV( key=apiConfig.api_key, @@ -375,12 +380,12 @@ async def file_update( else: api_logger.warning(f"Unsupported file type: {file.content_type}") file_content.append(f"[不支持的文件类型: {file.content_type}]") - + result_text = ';'.join(file_content) api_logger.info(f"File processing completed, result length: {len(result_text)}") - + return success(data=result_text, msg="转换文本成功") - + except Exception as e: api_logger.error(f"File processing failed: {str(e)}", exc_info=True) return fail(BizCode.INTERNAL_ERROR, "转换文本失败", str(e)) @@ -430,8 +435,8 @@ async def read_server_async( @router.get("/read_result/", response_model=ApiResponse) async def get_read_task_result( - task_id: str, - current_user: User = Depends(get_current_user) + task_id: str, + current_user: User = Depends(get_current_user) ): """ Get the status and result of an async read task @@ -452,7 +457,7 @@ async def get_read_task_result( try: result = task_service.get_task_memory_read_result(task_id) status = result.get("status") - + if status == "SUCCESS": # 任务成功完成 task_result = result.get("result", {}) @@ -470,7 +475,7 @@ async def get_read_task_result( else: # 旧格式:直接返回结果 return success(data=task_result, msg="查询任务已完成") - + elif status == "FAILURE": # 任务失败 error_info = result.get("result", "Unknown error") @@ -479,7 +484,7 @@ async def get_read_task_result( else: error_msg = str(error_info) return fail(BizCode.INTERNAL_ERROR, "查询任务失败", error_msg) - + elif status in ["PENDING", "STARTED"]: # 任务进行中 return success( @@ -499,7 +504,7 @@ async def get_read_task_result( }, msg=f"任务状态: {status}" ) - + except Exception as e: api_logger.error(f"Read task status check failed: {str(e)}", exc_info=True) return fail(BizCode.INTERNAL_ERROR, "任务状态查询失败", str(e)) @@ -507,8 +512,8 @@ async def get_read_task_result( @router.get("/write_result/", response_model=ApiResponse) async def get_write_task_result( - task_id: str, - current_user: User = Depends(get_current_user) + task_id: str, + current_user: User = Depends(get_current_user) ): """ Get the status and result of an async write task @@ -529,7 +534,7 @@ async def get_write_task_result( try: result = task_service.get_task_memory_write_result(task_id) status = result.get("status") - + if status == "SUCCESS": # 任务成功完成 task_result = result.get("result", {}) @@ -547,7 +552,7 @@ async def get_write_task_result( else: # 旧格式:直接返回结果 return success(data=task_result, msg="写入任务已完成") - + elif status == "FAILURE": # 任务失败 error_info = result.get("result", "Unknown error") @@ -556,7 +561,7 @@ async def get_write_task_result( else: error_msg = str(error_info) return fail(BizCode.INTERNAL_ERROR, "写入任务失败", error_msg) - + elif status in ["PENDING", "STARTED"]: # 任务进行中 return success( @@ -576,7 +581,7 @@ async def get_write_task_result( }, msg=f"任务状态: {status}" ) - + except Exception as e: api_logger.error(f"Write task status check failed: {str(e)}", exc_info=True) return fail(BizCode.INTERNAL_ERROR, "任务状态查询失败", str(e)) @@ -584,9 +589,9 @@ async def get_write_task_result( @router.post("/status_type", response_model=ApiResponse) async def status_type( - user_input: Write_UserInput, - db: Session = Depends(get_db), - current_user: User = Depends(get_current_user) + user_input: Write_UserInput, + db: Session = Depends(get_db), + current_user: User = Depends(get_current_user) ): """ Determine the type of user message (read or write) @@ -629,9 +634,10 @@ async def status_type( @router.get("/stats/types", response_model=ApiResponse) async def get_knowledge_type_stats_api( - end_user_id: Optional[str] = Query(None, description="用户ID(可选)"), - only_active: bool = Query(True, description="仅统计有效记录(status=1)"), - current_user: User = Depends(get_current_user) + end_user_id: Optional[str] = Query(None, description="用户ID(可选)"), + only_active: bool = Query(True, description="仅统计有效记录(status=1)"), + current_user: User = Depends(get_current_user), + db: Session = Depends(get_db), ): """ 统计当前空间下各知识库类型的数量,包含 General | Web | Third-party | Folder。 @@ -640,14 +646,9 @@ async def get_knowledge_type_stats_api( - 知识库类型根据当前用户的 current_workspace_id 过滤 - 如果用户没有当前工作空间,对应的统计返回 0 """ - api_logger.info(f"Knowledge type stats requested for workspace_id: {current_user.current_workspace_id}, end_user_id: {end_user_id}") + api_logger.info( + f"Knowledge type stats requested for workspace_id: {current_user.current_workspace_id}, end_user_id: {end_user_id}") try: - from app.db import get_db - - # 获取数据库会话 - db_gen = get_db() - db = next(db_gen) - # 调用service层函数 result = await memory_agent_service.get_knowledge_type_stats( end_user_id=end_user_id, @@ -655,7 +656,7 @@ async def get_knowledge_type_stats_api( current_workspace_id=current_user.current_workspace_id, db=db ) - + return success(data=result, msg="获取知识库类型统计成功") except Exception as e: api_logger.error(f"Knowledge type stats failed: {str(e)}") @@ -664,11 +665,11 @@ async def get_knowledge_type_stats_api( @router.get("/analytics/interest_distribution/by_user", response_model=ApiResponse) async def get_interest_distribution_by_user_api( - end_user_id: str = Query(..., description="用户ID(必填)"), - limit: int = Query(5, le=5, description="返回兴趣标签数量限制,最多5个"), - language_type: str = Header(default=None, alias="X-Language-Type"), - current_user: User = Depends(get_current_user), - db: Session = Depends(get_db), + end_user_id: str = Query(..., description="用户ID(必填)"), + limit: int = Query(5, le=5, description="返回兴趣标签数量限制,最多5个"), + language_type: str = Header(default=None, alias="X-Language-Type"), + current_user: User = Depends(get_current_user), + db: Session = Depends(get_db), ): """ 获取指定用户的兴趣分布标签 @@ -716,9 +717,9 @@ async def get_interest_distribution_by_user_api( @router.get("/analytics/user_profile", response_model=ApiResponse) async def get_user_profile_api( - end_user_id: Optional[str] = Query(None, description="用户ID(可选)"), - db: Session = Depends(get_db), - current_user: User = Depends(get_current_user) + end_user_id: Optional[str] = Query(None, description="用户ID(可选)"), + db: Session = Depends(get_db), + current_user: User = Depends(get_current_user) ): """ 获取用户详情,包含: @@ -756,17 +757,17 @@ async def get_user_profile_api( # ): # """ # Get parsed API documentation (Public endpoint - no authentication required) - + # Args: # file_path: Optional path to API docs file. If None, uses default path. - + # Returns: # Parsed API documentation including title, meta info, and sections # """ # api_logger.info(f"API docs requested, file_path: {file_path or 'default'}") # try: # result = await memory_agent_service.get_api_docs(file_path) - + # if result.get("success"): # return success(msg=result["msg"], data=result["data"]) # else: @@ -782,9 +783,9 @@ async def get_user_profile_api( @router.get("/end_user/{end_user_id}/connected_config", response_model=ApiResponse) async def get_end_user_connected_config( - end_user_id: str, - db: Session = Depends(get_db), - current_user: User = Depends(get_current_user) + end_user_id: str, + db: Session = Depends(get_db), + current_user: User = Depends(get_current_user) ): """ 获取终端用户关联的记忆配置 @@ -803,9 +804,9 @@ async def get_end_user_connected_config( from app.services.memory_agent_service import ( get_end_user_connected_config as get_config, ) - + api_logger.info(f"Getting connected config for end_user: {end_user_id}") - + try: result = get_config(end_user_id, db) return success(data=result, msg="获取终端用户关联配置成功") @@ -814,4 +815,4 @@ async def get_end_user_connected_config( return fail(BizCode.NOT_FOUND, str(e)) except Exception as e: api_logger.error(f"Failed to get end user connected config: {str(e)}", exc_info=True) - return fail(BizCode.INTERNAL_ERROR, "获取终端用户关联配置失败", str(e)) \ No newline at end of file + return fail(BizCode.INTERNAL_ERROR, "获取终端用户关联配置失败", str(e)) diff --git a/api/app/core/memory/agent/langgraph_graph/nodes/problem_nodes.py b/api/app/core/memory/agent/langgraph_graph/nodes/problem_nodes.py index ac1fb9a6..c8cc0460 100644 --- a/api/app/core/memory/agent/langgraph_graph/nodes/problem_nodes.py +++ b/api/app/core/memory/agent/langgraph_graph/nodes/problem_nodes.py @@ -1,10 +1,10 @@ -import os import json +import os import time -from app.core.logging_config import get_agent_logger -from app.db import get_db +from app.core.logging_config import get_agent_logger from app.core.memory.agent.models.problem_models import ProblemExtensionResponse +from app.core.memory.agent.services.optimized_llm_service import LLMServiceMixin from app.core.memory.agent.utils.llm_tools import ( PROJECT_ROOT_, ReadState, @@ -12,10 +12,9 @@ from app.core.memory.agent.utils.llm_tools import ( from app.core.memory.agent.utils.redis_tool import store from app.core.memory.agent.utils.session_tools import SessionService from app.core.memory.agent.utils.template_tools import TemplateService -from app.core.memory.agent.services.optimized_llm_service import LLMServiceMixin +from app.db import get_db_context template_root = os.path.join(PROJECT_ROOT_, 'memory', 'agent', 'utils', 'prompt') -db_session = next(get_db()) logger = get_agent_logger(__name__) @@ -53,13 +52,14 @@ async def Split_The_Problem(state: ReadState) -> ReadState: try: # 使用优化的LLM服务 - structured = await problem_service.call_llm_structured( - state=state, - db_session=db_session, - system_prompt=system_prompt, - response_model=ProblemExtensionResponse, - fallback_value=[] - ) + with get_db_context() as db_session: + structured = await problem_service.call_llm_structured( + state=state, + db_session=db_session, + system_prompt=system_prompt, + response_model=ProblemExtensionResponse, + fallback_value=[] + ) # 添加更详细的日志记录 logger.info(f"Split_The_Problem: 开始处理问题分解,内容长度: {len(content)}") @@ -171,13 +171,14 @@ async def Problem_Extension(state: ReadState) -> ReadState: try: # 使用优化的LLM服务 - response_content = await problem_service.call_llm_structured( - state=state, - db_session=db_session, - system_prompt=system_prompt, - response_model=ProblemExtensionResponse, - fallback_value=[] - ) + with get_db_context() as db_session: + response_content = await problem_service.call_llm_structured( + state=state, + db_session=db_session, + system_prompt=system_prompt, + response_model=ProblemExtensionResponse, + fallback_value=[] + ) logger.info(f"Problem_Extension: 开始处理问题扩展,问题数量: {len(databasets)}") diff --git a/api/app/core/memory/agent/langgraph_graph/nodes/retrieve_nodes.py b/api/app/core/memory/agent/langgraph_graph/nodes/retrieve_nodes.py index 1880357c..06539ad1 100644 --- a/api/app/core/memory/agent/langgraph_graph/nodes/retrieve_nodes.py +++ b/api/app/core/memory/agent/langgraph_graph/nodes/retrieve_nodes.py @@ -6,31 +6,26 @@ import os # ===== 第三方库 ===== from langchain.agents import create_agent from langchain_openai import ChatOpenAI + from app.core.logging_config import get_agent_logger -from app.db import get_db, get_db_context - -from app.schemas import model_schema -from app.services.memory_config_service import MemoryConfigService -from app.services.model_service import ModelConfigService - -from app.core.memory.agent.services.search_service import SearchService -from app.core.memory.agent.utils.llm_tools import ( - COUNTState, - ReadState, - deduplicate_entries, - merge_to_key_value_pairs, -) from app.core.memory.agent.langgraph_graph.tools.tool import ( create_hybrid_retrieval_tool_sync, create_time_retrieval_tool, extract_tool_message_content, ) - +from app.core.memory.agent.services.search_service import SearchService +from app.core.memory.agent.utils.llm_tools import ( + ReadState, + deduplicate_entries, + merge_to_key_value_pairs, +) from app.core.rag.nlp.search import knowledge_retrieval +from app.db import get_db_context +from app.schemas import model_schema +from app.services.memory_config_service import MemoryConfigService +from app.services.model_service import ModelConfigService logger = get_agent_logger(__name__) -db = next(get_db()) - async def rag_config(state): @@ -50,10 +45,12 @@ async def rag_config(state): "reranker_top_k": 10 } return kb_config -async def rag_knowledge(state,question): + + +async def rag_knowledge(state, question): kb_config = await rag_config(state) end_user_id = state.get('end_user_id', '') - user_rag_memory_id=state.get("user_rag_memory_id",'') + user_rag_memory_id = state.get("user_rag_memory_id", '') retrieve_chunks_result = knowledge_retrieval(question, kb_config, [str(end_user_id)]) try: retrieval_knowledge = [i.page_content for i in retrieve_chunks_result] @@ -61,13 +58,13 @@ async def rag_knowledge(state,question): cleaned_query = question raw_results = clean_content logger.info(f" Using RAG storage with memory_id={user_rag_memory_id}") - except Exception : - retrieval_knowledge=[] + except Exception: + retrieval_knowledge = [] clean_content = '' raw_results = '' cleaned_query = question logger.info(f"No content retrieved from knowledge base: {user_rag_memory_id}") - return retrieval_knowledge,clean_content,cleaned_query,raw_results + return retrieval_knowledge, clean_content, cleaned_query, raw_results async def llm_infomation(state: ReadState) -> ReadState: @@ -113,7 +110,7 @@ async def clean_databases(data) -> str: # 收集所有内容 content_list = [] - + # 处理重排序结果 reranked = results.get('reranked_results', {}) if reranked: @@ -141,7 +138,6 @@ async def clean_databases(data) -> str: elif isinstance(item, str): text_parts.append(item) - return '\n'.join(text_parts).strip() except Exception as e: @@ -150,23 +146,23 @@ async def clean_databases(data) -> str: async def retrieve_nodes(state: ReadState) -> ReadState: - ''' 模型信息 ''' - problem_extension=state.get('problem_extension', '')['context'] - storage_type=state.get('storage_type', '') - user_rag_memory_id=state.get('user_rag_memory_id', '') - end_user_id=state.get('end_user_id', '') + problem_extension = state.get('problem_extension', '')['context'] + storage_type = state.get('storage_type', '') + user_rag_memory_id = state.get('user_rag_memory_id', '') + end_user_id = state.get('end_user_id', '') memory_config = state.get('memory_config', None) - original=state.get('data', '') - problem_list=[] - for key,values in problem_extension.items(): + original = state.get('data', '') + problem_list = [] + for key, values in problem_extension.items(): for data in values: problem_list.append(data) logger.info(f"Retrieve: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}") + # 创建异步任务处理单个问题 async def process_question_nodes(idx, question): try: @@ -244,7 +240,7 @@ async def retrieve_nodes(state: ReadState) -> ReadState: send_verify = [] for i, j in zip(keys, val, strict=False): - if j!=['']: + if j != ['']: send_verify.append({ "Query_small": i, "Answer_Small": j @@ -257,15 +253,13 @@ async def retrieve_nodes(state: ReadState) -> ReadState: } logger.info(f"Collected {len(intermediate_outputs)} intermediate outputs from search results") - return {'retrieve':dup_databases} - - + return {'retrieve': dup_databases} async def retrieve(state: ReadState) -> ReadState: # 从state中获取end_user_id import time - start=time.time() + start = time.time() problem_extension = state.get('problem_extension', '')['context'] storage_type = state.get('storage_type', '') user_rag_memory_id = state.get('user_rag_memory_id', '') @@ -283,6 +277,7 @@ async def retrieve(state: ReadState) -> ReadState: with get_db_context() as db: # 使用同步数据库上下文管理器 config_service = MemoryConfigService(db) return await llm_infomation(state) + llm_config = await get_llm_info() api_key_obj = llm_config.api_keys[0] api_key = api_key_obj.api_key @@ -296,11 +291,11 @@ async def retrieve(state: ReadState) -> ReadState: ) time_retrieval_tool = create_time_retrieval_tool(end_user_id) - search_params = { "end_user_id": end_user_id, "return_raw_results": True } - hybrid_retrieval=create_hybrid_retrieval_tool_sync(memory_config, **search_params) + search_params = {"end_user_id": end_user_id, "return_raw_results": True} + hybrid_retrieval = create_hybrid_retrieval_tool_sync(memory_config, **search_params) agent = create_agent( llm, - tools=[time_retrieval_tool,hybrid_retrieval], + tools=[time_retrieval_tool, hybrid_retrieval], system_prompt=f"我是检索专家,可以根据适合的工具进行检索。当前使用的end_user_id是: {end_user_id}" ) @@ -314,7 +309,8 @@ async def retrieve(state: ReadState) -> ReadState: async with SEMAPHORE: # 限制并发 try: if storage_type == "rag" and user_rag_memory_id: - retrieval_knowledge, clean_content, cleaned_query, raw_results = await rag_knowledge(state, question) + retrieval_knowledge, clean_content, cleaned_query, raw_results = await rag_knowledge(state, + question) else: cleaned_query = question # 使用 asyncio 在线程池中运行同步的 agent.invoke @@ -413,5 +409,3 @@ async def retrieve(state: ReadState) -> ReadState: # json.dump(dup_databases, f, indent=4) logger.info(f"Collected {len(intermediate_outputs)} intermediate outputs from search results") return {'retrieve': dup_databases} - - diff --git a/api/app/core/memory/agent/langgraph_graph/nodes/summary_nodes.py b/api/app/core/memory/agent/langgraph_graph/nodes/summary_nodes.py index cf832add..87606bf8 100644 --- a/api/app/core/memory/agent/langgraph_graph/nodes/summary_nodes.py +++ b/api/app/core/memory/agent/langgraph_graph/nodes/summary_nodes.py @@ -1,5 +1,3 @@ - - import os import time @@ -18,22 +16,24 @@ from app.core.memory.agent.utils.redis_tool import store from app.core.memory.agent.utils.session_tools import SessionService from app.core.memory.agent.utils.template_tools import TemplateService from app.core.rag.nlp.search import knowledge_retrieval - -from app.db import get_db +from app.db import get_db_context template_root = os.path.join(PROJECT_ROOT_, 'memory', 'agent', 'utils', 'prompt') logger = get_agent_logger(__name__) -db_session = next(get_db()) + class SummaryNodeService(LLMServiceMixin): """总结节点服务类""" - + def __init__(self): super().__init__() self.template_service = TemplateService(template_root) + # 创建全局服务实例 summary_service = SummaryNodeService() + + async def rag_config(state): user_rag_memory_id = state.get('user_rag_memory_id', '') kb_config = { @@ -51,10 +51,12 @@ async def rag_config(state): "reranker_top_k": 10 } return kb_config -async def rag_knowledge(state,question): + + +async def rag_knowledge(state, question): kb_config = await rag_config(state) end_user_id = state.get('end_user_id', '') - user_rag_memory_id=state.get("user_rag_memory_id",'') + user_rag_memory_id = state.get("user_rag_memory_id", '') retrieve_chunks_result = knowledge_retrieval(question, kb_config, [str(end_user_id)]) try: retrieval_knowledge = [i.page_content for i in retrieve_chunks_result] @@ -62,25 +64,28 @@ async def rag_knowledge(state,question): cleaned_query = question raw_results = clean_content logger.info(f" Using RAG storage with memory_id={user_rag_memory_id}") - except Exception : - retrieval_knowledge=[] + except Exception: + retrieval_knowledge = [] clean_content = '' raw_results = '' cleaned_query = question logger.info(f"No content retrieved from knowledge base: {user_rag_memory_id}") - return retrieval_knowledge,clean_content,cleaned_query,raw_results + return retrieval_knowledge, clean_content, cleaned_query, raw_results + async def summary_history(state: ReadState) -> ReadState: end_user_id = state.get("end_user_id", '') history = await SessionService(store).get_history(end_user_id, end_user_id, end_user_id) return history -async def summary_llm(state: ReadState, history, retrieve_info, template_name, operation_name, response_model,search_mode) -> str: + +async def summary_llm(state: ReadState, history, retrieve_info, template_name, operation_name, response_model, + search_mode) -> str: """ 增强的summary_llm函数,包含更好的错误处理和数据验证 """ data = state.get("data", '') - + # 构建系统提示词 if str(search_mode) == "0": system_prompt = await summary_service.template_service.render_template( @@ -99,18 +104,19 @@ async def summary_llm(state: ReadState, history, retrieve_info, template_name, o ) try: # 使用优化的LLM服务进行结构化输出 - structured = await summary_service.call_llm_structured( - state=state, - db_session=db_session, - system_prompt=system_prompt, - response_model=response_model, - fallback_value=None - ) + with get_db_context() as db_session: + structured = await summary_service.call_llm_structured( + state=state, + db_session=db_session, + system_prompt=system_prompt, + response_model=response_model, + fallback_value=None + ) # 验证结构化响应 if structured is None: logger.warning("LLM返回None,使用默认回答") return "信息不足,无法回答" - + # 根据操作类型提取答案 if operation_name == "summary": aimessages = getattr(structured, 'query_answer', None) or "信息不足,无法回答" @@ -121,16 +127,16 @@ async def summary_llm(state: ReadState, history, retrieve_info, template_name, o else: logger.warning("结构化响应缺少data字段") aimessages = "信息不足,无法回答" - + # 验证答案不为空 if not aimessages or aimessages.strip() == "": aimessages = "信息不足,无法回答" - + return aimessages - + except Exception as e: logger.error(f"结构化输出失败: {e}", exc_info=True) - + # 尝试非结构化输出作为fallback try: logger.info("尝试非结构化输出作为fallback") @@ -140,7 +146,7 @@ async def summary_llm(state: ReadState, history, retrieve_info, template_name, o system_prompt=system_prompt, fallback_message="信息不足,无法回答" ) - + if response and response.strip(): # 简单清理响应 cleaned_response = response.strip() @@ -148,16 +154,17 @@ async def summary_llm(state: ReadState, history, retrieve_info, template_name, o if cleaned_response.startswith('```'): lines = cleaned_response.split('\n') cleaned_response = '\n'.join(lines[1:-1]) - + return cleaned_response else: return "信息不足,无法回答" - + except Exception as fallback_error: logger.error(f"Fallback也失败: {fallback_error}") return "信息不足,无法回答" -async def summary_redis_save(state: ReadState,aimessages) -> ReadState: + +async def summary_redis_save(state: ReadState, aimessages) -> ReadState: data = state.get("data", '') end_user_id = state.get("end_user_id", '') await SessionService(store).save_session( @@ -169,10 +176,12 @@ async def summary_redis_save(state: ReadState,aimessages) -> ReadState: ) await SessionService(store).cleanup_duplicates() logger.info(f"sessionid: {aimessages} 写入成功") -async def summary_prompt(state: ReadState,aimessages,raw_results) -> ReadState: - storage_type=state.get("storage_type",'') - user_rag_memory_id=state.get("user_rag_memory_id",'') - data=state.get("data", '') + + +async def summary_prompt(state: ReadState, aimessages, raw_results) -> ReadState: + storage_type = state.get("storage_type", '') + user_rag_memory_id = state.get("user_rag_memory_id", '') + data = state.get("data", '') input_summary = { "status": "success", "summary_result": aimessages, @@ -189,14 +198,14 @@ async def summary_prompt(state: ReadState,aimessages,raw_results) -> ReadState: "user_rag_memory_id": user_rag_memory_id } } - retrieve={ + retrieve = { "status": "success", "summary_result": aimessages, "storage_type": storage_type, "user_rag_memory_id": user_rag_memory_id, "_intermediate": { "type": "retrieval_summary", - "title":"快速检索", + "title": "快速检索", "summary": aimessages, "query": data, "storage_type": storage_type, @@ -204,17 +213,18 @@ async def summary_prompt(state: ReadState,aimessages,raw_results) -> ReadState: } } - return input_summary,retrieve + return input_summary, retrieve + async def Input_Summary(state: ReadState) -> ReadState: - start=time.time() - storage_type=state.get("storage_type",'') + start = time.time() + storage_type = state.get("storage_type", '') memory_config = state.get('memory_config', None) - user_rag_memory_id=state.get("user_rag_memory_id",'') - data=state.get("data", '') - end_user_id=state.get("end_user_id", '') + user_rag_memory_id = state.get("user_rag_memory_id", '') + data = state.get("data", '') + end_user_id = state.get("end_user_id", '') logger.info(f"Input_Summary: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}") - history = await summary_history( state) + history = await summary_history(state) search_params = { "end_user_id": end_user_id, "question": data, @@ -223,12 +233,13 @@ async def Input_Summary(state: ReadState) -> ReadState: } try: - if storage_type!="rag": - retrieve_info, question, raw_results = await SearchService().execute_hybrid_search(**search_params, memory_config=memory_config) + if storage_type != "rag": + retrieve_info, question, raw_results = await SearchService().execute_hybrid_search(**search_params, + memory_config=memory_config) else: retrieval_knowledge, retrieve_info, question, raw_results = await rag_knowledge(state, data) except Exception as e: - logger.error( f"Input_Summary: hybrid_search failed, using empty results: {e}", exc_info=True ) + logger.error(f"Input_Summary: hybrid_search failed, using empty results: {e}", exc_info=True) retrieve_info, question, raw_results = "", data, [] try: # aimessages=await summary_llm(state,history,retrieve_info,'Retrieve_Summary_prompt.jinja2', @@ -237,8 +248,8 @@ async def Input_Summary(state: ReadState) -> ReadState: summary_result = await summary_prompt(state, retrieve_info, retrieve_info) summary = summary_result[0] except Exception as e: - logger.error( f"Input_Summary failed: {e}", exc_info=True ) - summary= { + logger.error(f"Input_Summary failed: {e}", exc_info=True) + summary = { "status": "fail", "summary_result": "信息不足,无法回答", "storage_type": storage_type, @@ -251,30 +262,31 @@ async def Input_Summary(state: ReadState) -> ReadState: except Exception: duration = 0.0 log_time('检索', duration) - return {"summary":summary} + return {"summary": summary} -async def Retrieve_Summary(state: ReadState)-> ReadState: - retrieve=state.get("retrieve", '') - history = await summary_history( state) + +async def Retrieve_Summary(state: ReadState) -> ReadState: + retrieve = state.get("retrieve", '') + history = await summary_history(state) import json - with open("检索.json","w",encoding='utf-8') as f: + with open("检索.json", "w", encoding='utf-8') as f: f.write(json.dumps(retrieve, indent=4, ensure_ascii=False)) - retrieve=retrieve.get("Expansion_issue", []) - start=time.time() - retrieve_info_str=[] + retrieve = retrieve.get("Expansion_issue", []) + start = time.time() + retrieve_info_str = [] for data in retrieve: - if data=='': - retrieve_info_str='' + if data == '': + retrieve_info_str = '' else: for key, value in data.items(): - if key=='Answer_Small': + if key == 'Answer_Small': for i in value: retrieve_info_str.append(i) - retrieve_info_str=list(set(retrieve_info_str)) - retrieve_info_str='\n'.join(retrieve_info_str) + retrieve_info_str = list(set(retrieve_info_str)) + retrieve_info_str = '\n'.join(retrieve_info_str) - aimessages=await summary_llm(state,history,retrieve_info_str, - 'direct_summary_prompt.jinja2','retrieve_summary',RetrieveSummaryResponse,"1") + aimessages = await summary_llm(state, history, retrieve_info_str, + 'direct_summary_prompt.jinja2', 'retrieve_summary', RetrieveSummaryResponse, "1") if '信息不足,无法回答' not in str(aimessages) or str(aimessages) != "": await summary_redis_save(state, aimessages) if aimessages == '': @@ -286,33 +298,33 @@ async def Retrieve_Summary(state: ReadState)-> ReadState: except Exception: duration = 0.0 log_time('Retrieval summary', duration) - + # 修复协程调用 - 先await,然后访问返回值 summary_result = await summary_prompt(state, aimessages, retrieve_info_str) summary = summary_result[1] - return {"summary":summary} + return {"summary": summary} -async def Summary(state: ReadState)-> ReadState: - start=time.time() +async def Summary(state: ReadState) -> ReadState: + start = time.time() query = state.get("data", '') - verify=state.get("verify", '') - verify_expansion_issue=verify.get("verified_data", '') - retrieve_info_str='' + verify = state.get("verify", '') + verify_expansion_issue = verify.get("verified_data", '') + retrieve_info_str = '' for data in verify_expansion_issue: for key, value in data.items(): - if key=='answer_small': + if key == 'answer_small': for i in value: - retrieve_info_str+=i+'\n' - history=await summary_history(state) + retrieve_info_str += i + '\n' + history = await summary_history(state) data = { "query": query, "history": history, "retrieve_info": retrieve_info_str } - aimessages=await summary_llm(state,history,data, - 'summary_prompt.jinja2','summary',SummaryResponse,0) + aimessages = await summary_llm(state, history, data, + 'summary_prompt.jinja2', 'summary', SummaryResponse, 0) if '信息不足,无法回答' not in str(aimessages) or str(aimessages) != "": await summary_redis_save(state, aimessages) @@ -327,10 +339,12 @@ async def Summary(state: ReadState)-> ReadState: # 修复协程调用 - 先await,然后访问返回值 summary_result = await summary_prompt(state, aimessages, retrieve_info_str) summary = summary_result[1] - return {"summary":summary} -async def Summary_fails(state: ReadState)-> ReadState: - storage_type=state.get("storage_type", '') - user_rag_memory_id=state.get("user_rag_memory_id", '') + return {"summary": summary} + + +async def Summary_fails(state: ReadState) -> ReadState: + storage_type = state.get("storage_type", '') + user_rag_memory_id = state.get("user_rag_memory_id", '') history = await summary_history(state) query = state.get("data", '') verify = state.get("verify", '') @@ -346,12 +360,12 @@ async def Summary_fails(state: ReadState)-> ReadState: "history": history, "retrieve_info": retrieve_info_str } - aimessages = await summary_llm(state, history, data, - 'fail_summary_prompt.jinja2', 'summary', SummaryResponse, 0) - result= { + aimessages = await summary_llm(state, history, data, + 'fail_summary_prompt.jinja2', 'summary', SummaryResponse, 0) + result = { "status": "success", "summary_result": aimessages, "storage_type": storage_type, "user_rag_memory_id": user_rag_memory_id } - return {"summary":result} \ No newline at end of file + return {"summary": result} diff --git a/api/app/core/memory/agent/langgraph_graph/nodes/verification_nodes.py b/api/app/core/memory/agent/langgraph_graph/nodes/verification_nodes.py index b809faf2..3f7b491e 100644 --- a/api/app/core/memory/agent/langgraph_graph/nodes/verification_nodes.py +++ b/api/app/core/memory/agent/langgraph_graph/nodes/verification_nodes.py @@ -1,8 +1,9 @@ +import asyncio import os -from app.core.logging_config import get_agent_logger -from app.db import get_db +from app.core.logging_config import get_agent_logger from app.core.memory.agent.models.verification_models import VerificationResult +from app.core.memory.agent.services.optimized_llm_service import LLMServiceMixin from app.core.memory.agent.utils.llm_tools import ( PROJECT_ROOT_, ReadState, @@ -10,28 +11,30 @@ from app.core.memory.agent.utils.llm_tools import ( from app.core.memory.agent.utils.redis_tool import store from app.core.memory.agent.utils.session_tools import SessionService from app.core.memory.agent.utils.template_tools import TemplateService -from app.core.memory.agent.services.optimized_llm_service import LLMServiceMixin +from app.db import get_db_context template_root = os.path.join(PROJECT_ROOT_, 'memory', 'agent', 'utils', 'prompt') -db_session = next(get_db()) logger = get_agent_logger(__name__) + class VerificationNodeService(LLMServiceMixin): """验证节点服务类""" - + def __init__(self): super().__init__() self.template_service = TemplateService(template_root) + # 创建全局服务实例 verification_service = VerificationNodeService() + async def Verify_prompt(state: ReadState, messages_deal: VerificationResult): """处理验证结果并生成输出格式""" storage_type = state.get('storage_type', '') user_rag_memory_id = state.get('user_rag_memory_id', '') data = state.get('data', '') - + # 将 VerificationItem 对象转换为字典列表 verified_data = [] if messages_deal.expansion_issue: @@ -40,7 +43,7 @@ async def Verify_prompt(state: ReadState, messages_deal: VerificationResult): verified_data.append(item.model_dump()) elif isinstance(item, dict): verified_data.append(item) - + Verify_result = { "status": messages_deal.split_result, "verified_data": verified_data, @@ -58,34 +61,37 @@ async def Verify_prompt(state: ReadState, messages_deal: VerificationResult): } } return Verify_result + + async def Verify(state: ReadState): logger.info("=== Verify 节点开始执行 ===") try: content = state.get('data', '') end_user_id = state.get('end_user_id', '') memory_config = state.get('memory_config', None) - + logger.info(f"Verify: content={content[:50] if content else 'empty'}..., end_user_id={end_user_id}") history = await SessionService(store).get_history(end_user_id, end_user_id, end_user_id) logger.info(f"Verify: 获取历史记录完成,history length={len(history)}") retrieve = state.get("retrieve", {}) - logger.info(f"Verify: retrieve data type={type(retrieve)}, keys={retrieve.keys() if isinstance(retrieve, dict) else 'N/A'}") - + logger.info( + f"Verify: retrieve data type={type(retrieve)}, keys={retrieve.keys() if isinstance(retrieve, dict) else 'N/A'}") + retrieve_expansion = retrieve.get("Expansion_issue", []) if isinstance(retrieve, dict) else [] logger.info(f"Verify: Expansion_issue length={len(retrieve_expansion)}") - + messages = { "Query": content, "Expansion_issue": retrieve_expansion } logger.info("Verify: 开始渲染模板") - + # 生成 JSON schema 以指导 LLM 输出正确格式 json_schema = VerificationResult.model_json_schema() - + system_prompt = await verification_service.template_service.render_template( template_name='split_verify_prompt.jinja2', operation_name='split_verify_prompt', @@ -94,29 +100,30 @@ async def Verify(state: ReadState): json_schema=json_schema ) logger.info(f"Verify: 模板渲染完成,prompt length={len(system_prompt)}") - + # 使用优化的LLM服务,添加超时保护 logger.info("Verify: 开始调用 LLM") try: # 添加 asyncio.wait_for 超时包裹,防止无限等待 # 超时时间设置为 150 秒(比 LLM 配置的 120 秒稍长) - import asyncio - structured = await asyncio.wait_for( - verification_service.call_llm_structured( - state=state, - db_session=db_session, - system_prompt=system_prompt, - response_model=VerificationResult, - fallback_value={ - "query": content, - "history": history if isinstance(history, list) else [], - "expansion_issue": [], - "split_result": "failed", - "reason": "验证失败或超时" - } - ), - timeout=150.0 # 150秒超时 - ) + + with get_db_context() as db_session: + structured = await asyncio.wait_for( + verification_service.call_llm_structured( + state=state, + db_session=db_session, + system_prompt=system_prompt, + response_model=VerificationResult, + fallback_value={ + "query": content, + "history": history if isinstance(history, list) else [], + "expansion_issue": [], + "split_result": "failed", + "reason": "验证失败或超时" + } + ), + timeout=150.0 # 150秒超时 + ) logger.info(f"Verify: LLM 调用完成,result={structured}") except asyncio.TimeoutError: logger.error("Verify: LLM 调用超时(150秒),使用 fallback 值") @@ -127,11 +134,11 @@ async def Verify(state: ReadState): split_result="failed", reason="LLM调用超时" ) - + result = await Verify_prompt(state, structured) logger.info("=== Verify 节点执行完成 ===") return {"verify": result} - + except Exception as e: logger.error(f"Verify 节点执行失败: {e}", exc_info=True) # 返回失败的验证结果 @@ -152,4 +159,4 @@ async def Verify(state: ReadState): "user_rag_memory_id": state.get('user_rag_memory_id', '') } } - } \ No newline at end of file + } diff --git a/api/app/core/memory/agent/langgraph_graph/read_graph.py b/api/app/core/memory/agent/langgraph_graph/read_graph.py index 3476d0ec..cba1b230 100644 --- a/api/app/core/memory/agent/langgraph_graph/read_graph.py +++ b/api/app/core/memory/agent/langgraph_graph/read_graph.py @@ -5,7 +5,6 @@ from langchain_core.messages import HumanMessage from langgraph.constants import START, END from langgraph.graph import StateGraph - from app.db import get_db from app.services.memory_config_service import MemoryConfigService @@ -32,7 +31,6 @@ from app.core.memory.agent.langgraph_graph.routing.routers import ( ) - @asynccontextmanager async def make_read_graph(): """创建并返回 LangGraph 工作流""" @@ -49,7 +47,7 @@ async def make_read_graph(): workflow.add_node("Retrieve_Summary", Retrieve_Summary) workflow.add_node("Summary", Summary) workflow.add_node("Summary_fails", Summary_fails) - + # 添加边 workflow.add_edge(START, "content_input") workflow.add_conditional_edges("content_input", Split_continue) @@ -62,20 +60,20 @@ async def make_read_graph(): workflow.add_edge("Summary_fails", END) workflow.add_edge("Summary", END) - '''-----''' # workflow.add_edge("Retrieve", END) - + # 编译工作流 graph = workflow.compile() yield graph - + except Exception as e: print(f"创建工作流失败: {e}") raise finally: print("工作流创建完成") + async def main(): """主函数 - 运行工作流""" message = "昨天有什么好看的电影" @@ -92,17 +90,19 @@ async def main(): service_name="MemoryAgentService" ) import time - start=time.time() + start = time.time() try: async with make_read_graph() as graph: config = {"configurable": {"thread_id": end_user_id}} # 初始状态 - 包含所有必要字段 - initial_state = {"messages": [HumanMessage(content=message)] ,"search_switch":search_switch,"end_user_id":end_user_id - ,"storage_type":storage_type,"user_rag_memory_id":user_rag_memory_id,"memory_config":memory_config} + initial_state = {"messages": [HumanMessage(content=message)], "search_switch": search_switch, + "end_user_id": end_user_id + , "storage_type": storage_type, "user_rag_memory_id": user_rag_memory_id, + "memory_config": memory_config} # 获取节点更新信息 _intermediate_outputs = [] summary = '' - + async for update_event in graph.astream( initial_state, stream_mode="updates", @@ -110,7 +110,7 @@ async def main(): ): for node_name, node_data in update_event.items(): print(f"处理节点: {node_name}") - + # 处理不同Summary节点的返回结构 if 'Summary' in node_name: if 'InputSummary' in node_data and 'summary_result' in node_data['InputSummary']: @@ -125,23 +125,22 @@ async def main(): spit_data = node_data.get('spit_data', {}).get('_intermediate', None) if spit_data and spit_data != [] and spit_data != {}: _intermediate_outputs.append(spit_data) - + # Problem_Extension 节点 problem_extension = node_data.get('problem_extension', {}).get('_intermediate', None) if problem_extension and problem_extension != [] and problem_extension != {}: _intermediate_outputs.append(problem_extension) - + # Retrieve 节点 retrieve_node = node_data.get('retrieve', {}).get('_intermediate_outputs', None) if retrieve_node and retrieve_node != [] and retrieve_node != {}: _intermediate_outputs.extend(retrieve_node) - + # Verify 节点 verify_n = node_data.get('verify', {}).get('_intermediate', None) if verify_n and verify_n != [] and verify_n != {}: _intermediate_outputs.append(verify_n) - # Summary 节点 summary_n = node_data.get('summary', {}).get('_intermediate', None) if summary_n and summary_n != [] and summary_n != {}: @@ -161,17 +160,20 @@ async def main(): # print(f"=== 最终摘要 ===") print(summary) - + except Exception as e: import traceback traceback.print_exc() + finally: + db_session.close() - end=time.time() - print(100*'y') - print(f"总耗时: {end-start}s") - print(100*'y') + end = time.time() + print(100 * 'y') + print(f"总耗时: {end - start}s") + print(100 * 'y') if __name__ == "__main__": import asyncio + asyncio.run(main()) diff --git a/api/app/core/memory/agent/utils/llm_client_pool.py b/api/app/core/memory/agent/utils/llm_client_pool.py deleted file mode 100644 index fddd54f6..00000000 --- a/api/app/core/memory/agent/utils/llm_client_pool.py +++ /dev/null @@ -1,56 +0,0 @@ - -import asyncio -from typing import Dict, Optional -from app.core.memory.utils.llm.llm_utils import get_llm_client_fast -from app.db import get_db -from app.core.logging_config import get_agent_logger - -logger = get_agent_logger(__name__) - -class LLMClientPool: - """LLM客户端连接池""" - - def __init__(self, max_size: int = 5): - self.max_size = max_size - self.pools: Dict[str, asyncio.Queue] = {} - self.active_clients: Dict[str, int] = {} - - async def get_client(self, llm_model_id: str): - """获取LLM客户端""" - if llm_model_id not in self.pools: - self.pools[llm_model_id] = asyncio.Queue(maxsize=self.max_size) - self.active_clients[llm_model_id] = 0 - - pool = self.pools[llm_model_id] - - try: - # 尝试从池中获取客户端 - client = pool.get_nowait() - logger.debug(f"从池中获取LLM客户端: {llm_model_id}") - return client - except asyncio.QueueEmpty: - # 池为空,创建新客户端 - if self.active_clients[llm_model_id] < self.max_size: - db_session = next(get_db()) - client = get_llm_client_fast(llm_model_id, db_session) - self.active_clients[llm_model_id] += 1 - logger.debug(f"创建新LLM客户端: {llm_model_id}") - return client - else: - # 等待可用客户端 - logger.debug(f"等待LLM客户端可用: {llm_model_id}") - return await pool.get() - - async def return_client(self, llm_model_id: str, client): - """归还LLM客户端到池中""" - if llm_model_id in self.pools: - try: - self.pools[llm_model_id].put_nowait(client) - logger.debug(f"归还LLM客户端到池: {llm_model_id}") - except asyncio.QueueFull: - # 池已满,丢弃客户端 - self.active_clients[llm_model_id] -= 1 - logger.debug(f"池已满,丢弃LLM客户端: {llm_model_id}") - -# 全局客户端池 -llm_client_pool = LLMClientPool() diff --git a/api/app/core/workflow/nodes/agent/node.py b/api/app/core/workflow/nodes/agent/node.py index 3fbbbdbc..8959e27c 100644 --- a/api/app/core/workflow/nodes/agent/node.py +++ b/api/app/core/workflow/nodes/agent/node.py @@ -14,7 +14,7 @@ from app.core.workflow.engine.state_manager import WorkflowState from app.core.workflow.engine.variable_pool import VariablePool from app.core.workflow.nodes.base_node import BaseNode from app.core.workflow.variable.base_variable import VariableType -from app.db import get_db +from app.db import get_db_context from app.models import AppRelease from app.services.draft_run_service import AgentRunService @@ -39,7 +39,7 @@ class AgentNode(BaseNode): def _output_types(self) -> dict[str, VariableType]: return {"output": VariableType.STRING} - def _prepare_agent(self, variable_pool: VariablePool) -> tuple[AgentRunService, AppRelease, str]: + def _prepare_agent(self, variable_pool: VariablePool) -> tuple[AppRelease, str]: """准备 Agent(公共逻辑) Args: @@ -57,17 +57,17 @@ class AgentNode(BaseNode): if not agent_id: raise ValueError(f"节点 {self.node_id} 缺少 agent_id 配置") - db = next(get_db()) - release = db.query(AppRelease).filter( - AppRelease.id == agent_id - ).first() + with get_db_context() as db: + release = db.query(AppRelease).filter( + AppRelease.id == agent_id + ).first() if not release: raise ValueError(f"Agent 不存在: {agent_id}") - draft_service = AgentRunService(db) + - return draft_service, release, message + return release, message async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]: """非流式执行 @@ -79,19 +79,21 @@ class AgentNode(BaseNode): Returns: 状态更新字典 """ - draft_service, release, message = self._prepare_agent(variable_pool) + release, message = self._prepare_agent(variable_pool) logger.info(f"节点 {self.node_id} 开始执行 Agent 调用(非流式)") - - # 执行 Agent(非流式) - result = await draft_service.run( - agent_config=release.config, - model_config=None, - message=message, - workspace_id=variable_pool.get_value("sys.workspace_id"), - user_id=state.get("user_id"), - variables=variable_pool.get_all_conversation_vars() - ) + with get_db_context() as db: + draft_service = AgentRunService(db) + + # 执行 Agent(非流式) + result = await draft_service.run( + agent_config=release.config, + model_config=None, + message=message, + workspace_id=variable_pool.get_value("sys.workspace_id"), + user_id=state.get("user_id"), + variables=variable_pool.get_all_conversation_vars() + ) response = result.get("response", "") @@ -118,34 +120,35 @@ class AgentNode(BaseNode): Yields: 流式事件字典 """ - draft_service, release, message = self._prepare_agent(variable_pool) + release, message = self._prepare_agent(variable_pool) logger.info(f"节点 {self.node_id} 开始执行 Agent 调用(流式)") # 累积完整响应 full_response = "" - + with get_db_context() as db: + draft_service = AgentRunService(db) # 执行 Agent(流式) - async for chunk in draft_service.run_stream( - agent_config=release.config, - model_config=None, - message=message, - workspace_id=variable_pool.get_value("sys.workspace_id"), - user_id=state.get("user_id"), - variables=variable_pool.get_all_conversation_vars() - ): - # 提取内容 - content = chunk.get("content", "") - full_response += content - - # 流式返回每个 chunk - yield { - "type": "chunk", - "node_id": self.node_id, - "content": content, - "full_content": full_response, - "meta_data": chunk.get("meta_data", {}) - } + async for chunk in draft_service.run_stream( + agent_config=release.config, + model_config=None, + message=message, + workspace_id=variable_pool.get_value("sys.workspace_id"), + user_id=state.get("user_id"), + variables=variable_pool.get_all_conversation_vars() + ): + # 提取内容 + content = chunk.get("content", "") + full_response += content + + # 流式返回每个 chunk + yield { + "type": "chunk", + "node_id": self.node_id, + "content": content, + "full_content": full_response, + "meta_data": chunk.get("meta_data", {}) + } logger.info(f"节点 {self.node_id} Agent 调用完成,输出长度: {len(full_response)}") diff --git a/api/app/services/draft_run_service.py b/api/app/services/draft_run_service.py index bb68c815..5026bf27 100644 --- a/api/app/services/draft_run_service.py +++ b/api/app/services/draft_run_service.py @@ -22,6 +22,7 @@ from app.core.error_codes import BizCode from app.core.exceptions import BusinessException from app.core.logging_config import get_business_logger from app.core.rag.nlp.search import knowledge_retrieval +from app.db import get_db_context from app.models import AgentConfig, ModelConfig from app.repositories.tool_repository import ToolRepository from app.schemas.app_schema import FileInput @@ -103,9 +104,7 @@ def create_long_term_memory_tool( """ logger.info(f" 长期记忆工具被调用!question={question}, user={end_user_id}") try: - from app.db import get_db - db = next(get_db()) - try: + with get_db_context() as db: memory_content = asyncio.run( MemoryAgentService().read_memory( end_user_id=end_user_id, @@ -127,9 +126,6 @@ def create_long_term_memory_tool( logger.info(f"读取任务状态:{status}") if memory_content: memory_content = memory_content['answer'] - - finally: - db.close() logger.info(f'用户ID:Agent:{end_user_id}') logger.debug("调用长期记忆 API", extra={"question": question, "end_user_id": end_user_id}) diff --git a/api/app/services/memory_agent_service.py b/api/app/services/memory_agent_service.py index 16aee283..f272c541 100644 --- a/api/app/services/memory_agent_service.py +++ b/api/app/services/memory_agent_service.py @@ -13,7 +13,6 @@ TODO: Refactor get_end_user_connected_config """ import json import os -import re import time import uuid from typing import Any, AsyncGenerator, Dict, List, Optional @@ -35,12 +34,10 @@ from app.core.memory.agent.utils.messages_tools import ( reorder_output_results, ) from app.core.memory.agent.utils.type_classifier import status_typle -from app.core.memory.agent.utils.write_tools import write # 新增:直接导入 write 函数 -from app.core.memory.analytics.hot_memory_tags import get_hot_memory_tags, get_interest_distribution +from app.core.memory.analytics.hot_memory_tags import get_interest_distribution from app.core.memory.utils.llm.llm_utils import MemoryClientFactory from app.db import get_db_context from app.models.knowledge_model import Knowledge, KnowledgeType -from app.repositories.memory_short_repository import ShortTermMemoryRepository from app.repositories.neo4j.neo4j_connector import Neo4jConnector from app.schemas.memory_agent_schema import Write_UserInput from app.schemas.memory_config_schema import ConfigurationError @@ -69,7 +66,8 @@ class MemoryAgentService: logger.info(f"Write operation successful for group {end_user_id} with config_id {config_id}") # 记录成功的操作 if audit_logger: - audit_logger.log_operation(operation="WRITE", config_id=config_id, end_user_id=end_user_id, success=True, + audit_logger.log_operation(operation="WRITE", config_id=config_id, end_user_id=end_user_id, + success=True, duration=duration, details={"message_length": len(message)}) return context else: @@ -88,8 +86,6 @@ class MemoryAgentService: raise ValueError(f"写入失败: {messages}") - - def extract_tool_call_info(self, event: Dict) -> bool: """Extract tool call information from event""" last_message = event["messages"][-1] @@ -271,7 +267,8 @@ class MemoryAgentService: logger.info("Log streaming completed, cleaning up resources") # LogStreamer uses context manager for file handling, so cleanup is automatic - async def write_memory(self, end_user_id: str, messages: list[dict], config_id: Optional[uuid.UUID]|int, db: Session, storage_type: str, user_rag_memory_id: str, language: str = "zh") -> str: + async def write_memory(self, end_user_id: str, messages: list[dict], config_id: Optional[uuid.UUID] | int, + db: Session, storage_type: str, user_rag_memory_id: str, language: str = "zh") -> str: """ Process write operation with config_id @@ -300,7 +297,8 @@ class MemoryAgentService: config_id = connected_config.get("memory_config_id") logger.info(f"Resolved config from end_user: config_id={config_id}, workspace_id={workspace_id}") if config_id is None and workspace_id is None: - raise ValueError(f"No memory configuration found for end_user {end_user_id}. Please ensure the user has a connected memory configuration.") + raise ValueError( + f"No memory configuration found for end_user {end_user_id}. Please ensure the user has a connected memory configuration.") except Exception as e: if "No memory configuration found" in str(e): raise # Re-raise our specific error @@ -331,7 +329,8 @@ class MemoryAgentService: # Log failed operation if audit_logger: duration = time.time() - start_time - audit_logger.log_operation(operation="WRITE", config_id=config_id, end_user_id=end_user_id, success=False, duration=duration, error=error_msg) + audit_logger.log_operation(operation="WRITE", config_id=config_id, end_user_id=end_user_id, + success=False, duration=duration, error=error_msg) raise ValueError(error_msg) @@ -351,9 +350,9 @@ class MemoryAgentService: langchain_messages.append(HumanMessage(content=msg['content'])) elif msg['role'] == 'assistant': langchain_messages.append(AIMessage(content=msg['content'])) - print(100*'-') + print(100 * '-') print(langchain_messages) - print(100*'-') + print(100 * '-') # 初始状态 - 包含所有必要字段 initial_state = { "messages": langchain_messages, @@ -375,29 +374,28 @@ class MemoryAgentService: contents = massages.get('write_result') # Convert messages back to string for logging message_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages]) - return self.writer_messages_deal(massagesstatus, start_time, end_user_id, config_id, message_text, contents) + return self.writer_messages_deal(massagesstatus, start_time, end_user_id, config_id, message_text, + contents) except Exception as e: # Ensure proper error handling and logging error_msg = f"Write operation failed: {str(e)}" logger.error(error_msg) if audit_logger: duration = time.time() - start_time - audit_logger.log_operation(operation="WRITE", config_id=config_id, end_user_id=end_user_id, success=False, duration=duration, error=error_msg) + audit_logger.log_operation(operation="WRITE", config_id=config_id, end_user_id=end_user_id, + success=False, duration=duration, error=error_msg) raise ValueError(error_msg) - - - async def read_memory( - self, - end_user_id: str, - message: str, - history: List[Dict], - search_switch: str, - config_id: Optional[uuid.UUID]|int, - db: Session, - storage_type: str, - user_rag_memory_id: str) -> Dict: + self, + end_user_id: str, + message: str, + history: List[Dict], + search_switch: str, + config_id: Optional[uuid.UUID] | int, + db: Session, + storage_type: str, + user_rag_memory_id: str) -> Dict: """ Process read operation with config_id @@ -425,7 +423,7 @@ class MemoryAgentService: import time start_time = time.time() - ori_message= message + ori_message = message # Resolve config_id and workspace_id # Always get workspace_id from end_user for fallback, even if config_id is provided @@ -437,7 +435,8 @@ class MemoryAgentService: config_id = connected_config.get("memory_config_id") logger.info(f"Resolved config from end_user: config_id={config_id}, workspace_id={workspace_id}") if config_id is None and workspace_id is None: - raise ValueError(f"No memory configuration found for end_user {end_user_id}. Please ensure the user has a connected memory configuration.") + raise ValueError( + f"No memory configuration found for end_user {end_user_id}. Please ensure the user has a connected memory configuration.") except Exception as e: if "No memory configuration found" in str(e): raise # Re-raise our specific error @@ -454,7 +453,6 @@ class MemoryAgentService: except ImportError: audit_logger = None - config_load_start = time.time() try: # Use a separate database session to avoid transaction failures @@ -562,34 +560,35 @@ class MemoryAgentService: from app.repositories.memory_short_repository import ( ShortTermMemoryRepository, ) - + retrieved_content = [] repo = ShortTermMemoryRepository(db) - + if str(search_switch) != "2": for intermediate in _intermediate_outputs: logger.debug(f"处理中间结果: {intermediate}") intermediate_type = intermediate.get('type', '') - + if intermediate_type == "search_result": query = intermediate.get('query', '') raw_results = intermediate.get('raw_results', {}) try: reranked_results = raw_results.get('reranked_results', []) - statements = [statement['statement'] for statement in reranked_results.get('statements', [])] + statements = [statement['statement'] for statement in + reranked_results.get('statements', [])] except Exception: statements = [] - + # 去重 statements = list(set(statements)) - + if query and statements: retrieved_content.append({query: statements}) - + # 如果 retrieved_content 为空,设置为空字符串 if retrieved_content == []: retrieved_content = '' - + # 只有当回答不是"信息不足"且不是快速检索时才保存 if '信息不足,无法回答。' != str(summary) and str(search_switch).strip() != "2": # 使用 upsert 方法 @@ -602,15 +601,17 @@ class MemoryAgentService: ) logger.info(f"成功保存短期记忆: end_user_id={end_user_id}, search_switch={search_switch}") else: - logger.debug(f"跳过保存短期记忆: summary={summary[:50] if summary else 'None'}, search_switch={search_switch}") - + logger.debug( + f"跳过保存短期记忆: summary={summary[:50] if summary else 'None'}, search_switch={search_switch}") + except Exception as save_error: # 保存失败不应该影响主流程,只记录错误 logger.error(f"保存短期记忆失败: {str(save_error)}", exc_info=True) # Log successful operation total_time = time.time() - start_time - logger.info(f"[PERF] read_memory completed successfully in {total_time:.4f}s (config: {config_load_time:.4f}s, graph: {graph_exec_time:.4f}s)") + logger.info( + f"[PERF] read_memory completed successfully in {total_time:.4f}s (config: {config_load_time:.4f}s, graph: {graph_exec_time:.4f}s)") if audit_logger: duration = time.time() - start_time audit_logger.log_operation( @@ -641,7 +642,6 @@ class MemoryAgentService: ) raise ValueError(error_msg) - def get_messages_list(self, user_input: Write_UserInput) -> list[dict]: """ Get standardized message list from user input. @@ -657,41 +657,43 @@ class MemoryAgentService: """ from app.core.logging_config import get_api_logger logger = get_api_logger() - + if len(user_input.messages) == 0: logger.error("Validation failed: Message list cannot be empty") raise ValueError("Message list cannot be empty") - + for idx, msg in enumerate(user_input.messages): if not isinstance(msg, dict): logger.error(f"Validation failed: Message {idx} is not a dict: {type(msg)}") - raise ValueError(f"Message format error: Message must be a dictionary. Error message index: {idx}, type: {type(msg)}") - + raise ValueError( + f"Message format error: Message must be a dictionary. Error message index: {idx}, type: {type(msg)}") + if 'role' not in msg: logger.error(f"Validation failed: Message {idx} missing 'role' field: {msg}") raise ValueError(f"Message format error: Message must contain 'role' field. Error message index: {idx}") - + if 'content' not in msg: logger.error(f"Validation failed: Message {idx} missing 'content' field: {msg}") - raise ValueError(f"Message format error: Message must contain 'content' field. Error message index: {idx}") - + raise ValueError( + f"Message format error: Message must contain 'content' field. Error message index: {idx}") + if msg['role'] not in ['user', 'assistant']: logger.error(f"Validation failed: Message {idx} invalid role: {msg['role']}") raise ValueError(f"Role must be 'user' or 'assistant', got: {msg['role']}. Message index: {idx}") - + if not msg['content'] or not msg['content'].strip(): logger.error(f"Validation failed: Message {idx} content is empty") raise ValueError(f"Message content cannot be empty. Message index: {idx}, role: {msg['role']}") - + logger.info(f"Validation successful: Structured message list, count: {len(user_input.messages)}") return user_input.messages async def classify_message_type( - self, - message: str, - config_id: UUID, - db: Session, - workspace_id: Optional[UUID] = None + self, + message: str, + config_id: UUID, + db: Session, + workspace_id: Optional[UUID] = None ) -> Dict: """ Determine the type of user message (read or write) @@ -719,14 +721,15 @@ class MemoryAgentService: status = await status_typle(message, memory_config.llm_model_id) logger.debug(f"Message type: {status}") return status + async def generate_summary_from_retrieve( - self, - end_user_id: str, - retrieve_info: str, - history: List[Dict], - query: str, - config_id: str, - db: Session + self, + end_user_id: str, + retrieve_info: str, + history: List[Dict], + query: str, + config_id: str, + db: Session ) -> str: """ 基于检索信息、历史对话和查询生成最终答案 @@ -761,9 +764,9 @@ class MemoryAgentService: if config_id is None: raise ValueError(f"Unable to determine memory configuration for end_user {end_user_id}: {e}") # If config_id was provided, continue without workspace_id fallback - + logger.info(f"Generating summary from retrieve info for query: {query[:50]}...") - + try: # 加载配置 config_service = MemoryConfigService(db) @@ -772,7 +775,7 @@ class MemoryAgentService: workspace_id=workspace_id, service_name="MemoryAgentService" ) - + # 导入必要的模块 from app.core.memory.agent.langgraph_graph.nodes.summary_nodes import ( summary_llm, @@ -780,13 +783,13 @@ class MemoryAgentService: from app.core.memory.agent.models.summary_models import ( RetrieveSummaryResponse, ) - + # 构建状态对象 state = { "data": query, "memory_config": memory_config } - + # 直接调用 summary_llm 函数 answer = await summary_llm( state=state, @@ -797,21 +800,20 @@ class MemoryAgentService: response_model=RetrieveSummaryResponse, search_mode="1" ) - + logger.info(f"Successfully generated summary: {answer[:100] if answer else 'None'}...") return answer if answer else "信息不足,无法回答。" - + except Exception as e: logger.error(f"生成摘要失败: {str(e)}", exc_info=True) return "信息不足,无法回答。" - async def get_knowledge_type_stats( - self, - end_user_id: Optional[str] = None, - only_active: bool = True, - current_workspace_id: Optional[uuid.UUID] = None, - db: Session = None + self, + db: Session, + end_user_id: Optional[str] = None, + only_active: bool = True, + current_workspace_id: Optional[uuid.UUID] = None ) -> Dict[str, Any]: """ 统计知识库类型分布,包含: @@ -837,11 +839,6 @@ class MemoryAgentService: # 1. 统计 PostgreSQL 中的知识库类型 try: - if db is None: - from app.db import get_db - db_gen = get_db() - db = next(db_gen) - # 初始化所有标准类型为 0 for kb_type in KnowledgeType: result[kb_type.value] = 0 @@ -881,21 +878,19 @@ class MemoryAgentService: # 3. 计算知识库类型总和(不包括 memory) result["total"] = ( - result.get("General", 0) + - result.get("Web", 0) + - result.get("Third-party", 0) + - result.get("Folder", 0) + result.get("General", 0) + + result.get("Web", 0) + + result.get("Third-party", 0) + + result.get("Folder", 0) ) return result - - async def get_interest_distribution_by_user( - self, - end_user_id: Optional[str] = None, - limit: int = 5, - language: str = "zh" + self, + end_user_id: Optional[str] = None, + limit: int = 5, + language: str = "zh" ) -> List[Dict[str, Any]]: """ 获取指定用户的兴趣分布标签。 @@ -921,13 +916,12 @@ class MemoryAgentService: logger.error(f"兴趣分布标签查询失败: {e}") raise Exception(f"兴趣分布标签查询失败: {e}") - async def get_user_profile( - self, - end_user_id: Optional[str] = None, - current_user_id: Optional[str] = None, - llm_id: Optional[str] = None, - db: Session = None + self, + end_user_id: Optional[str] = None, + current_user_id: Optional[str] = None, + llm_id: Optional[str] = None, + db: Session = None ) -> Dict[str, Any]: """ 获取用户详情,包含: @@ -1017,7 +1011,8 @@ class MemoryAgentService: # 定义标签提取的结构 class UserTags(BaseModel): - tags: list[str] = Field(..., description="3个描述用户特征的标签,如:产品设计师、旅行爱好者、摄影发烧友") + tags: list[str] = Field(..., + description="3个描述用户特征的标签,如:产品设计师、旅行爱好者、摄影发烧友") messages = [ { @@ -1160,7 +1155,6 @@ def get_end_user_connected_config(end_user_id: str, db: Session) -> Dict[str, An ValueError: 当终端用户不存在或应用未发布时 """ import json as json_module - import uuid from sqlalchemy import select @@ -1192,14 +1186,14 @@ def get_end_user_connected_config(end_user_id: str, db: Session) -> Dict[str, An # 3. 兼容旧数据:如果 memory_config_id 为空,从 AppRelease.config 获取并回填 memory_config_id_to_use = end_user.memory_config_id - + # 如果已有 memory_config_id,直接使用 # 如果新创建enduser,enduser.memory_config_id 必定为none # 那么使用从release中获取memory_config_id为预期行为,并且回填到 # end_user.memory_config_id if not memory_config_id_to_use: logger.info(f"end_user.memory_config_id is None, migrating from AppRelease.config") - + # 获取最新发布版本 stmt = ( select(AppRelease) @@ -1208,10 +1202,10 @@ def get_end_user_connected_config(end_user_id: str, db: Session) -> Dict[str, An ) # TODO: change to current_release_id latest_release = db.scalars(stmt).first() - + if latest_release: config = latest_release.config or {} - + # 如果 config 是字符串,解析为字典 if isinstance(config, str): try: @@ -1219,22 +1213,22 @@ def get_end_user_connected_config(end_user_id: str, db: Session) -> Dict[str, An except json_module.JSONDecodeError: logger.warning(f"Failed to parse config JSON for release {latest_release.id}") config = {} - + # 使用 MemoryConfigService 的提取方法 memory_config_service = MemoryConfigService(db) legacy_config_id, is_legacy_int = memory_config_service.extract_memory_config_id( app_type=app.type, config=config ) - + if legacy_config_id: # 验证提取的 config_id 是否存在于数据库中 from app.models.memory_config_model import MemoryConfig as MemoryConfigModel existing_config = db.get(MemoryConfigModel, legacy_config_id) - + if existing_config: memory_config_id_to_use = legacy_config_id - + # 回填到 end_user 表(lazy update) end_user.memory_config_id = memory_config_id_to_use db.commit() @@ -1268,7 +1262,8 @@ def get_end_user_connected_config(end_user_id: str, db: Session) -> Dict[str, An "workspace_id": str(app.workspace_id) } - logger.info(f"Successfully retrieved connected config: memory_config_id={memory_config_id}, workspace_id={app.workspace_id}") + logger.info( + f"Successfully retrieved connected config: memory_config_id={memory_config_id}, workspace_id={app.workspace_id}") return result @@ -1312,7 +1307,7 @@ def get_end_users_connected_configs_batch(end_user_ids: List[str], db: Session) # 1. 批量查询所有 end_user 及其 app_id 和 memory_config_id end_users = db.query(EndUser).filter(EndUser.id.in_(end_user_ids)).all() - + # 创建映射 - 保留 EndUser 对象引用以便回填 end_user_map = {str(eu.id): eu for eu in end_users} user_data = {str(eu.id): {"app_id": eu.app_id, "memory_config_id": eu.memory_config_id} for eu in end_users} @@ -1336,15 +1331,15 @@ def get_end_users_connected_configs_batch(end_user_ids: List[str], db: Session) # 3. 对于没有 memory_config_id 的用户,尝试从 AppRelease.config 提取 users_needing_migration = [ - (end_user_id, data["app_id"]) - for end_user_id, data in user_data.items() + (end_user_id, data["app_id"]) + for end_user_id, data in user_data.items() if not data["memory_config_id"] ] - + if users_needing_migration: # 批量获取相关应用的最新发布版本 migration_app_ids = list(set(app_id for _, app_id in users_needing_migration)) - + # 查询每个应用的最新活跃发布版本 app_latest_releases = {} for app_id in migration_app_ids: @@ -1357,18 +1352,18 @@ def get_end_users_connected_configs_batch(end_user_ids: List[str], db: Session) latest_release = db.scalars(stmt).first() if latest_release: app_latest_releases[app_id] = latest_release - + # 为每个需要迁移的用户提取 memory_config_id config_service = MemoryConfigService(db) users_to_backfill = [] # [(end_user, memory_config_id), ...] - + for end_user_id, app_id in users_needing_migration: latest_release = app_latest_releases.get(app_id) if not latest_release: continue - + config = latest_release.config or {} - + # 如果 config 是字符串,解析为字典 if isinstance(config, str): try: @@ -1376,21 +1371,21 @@ def get_end_users_connected_configs_batch(end_user_ids: List[str], db: Session) except json_module.JSONDecodeError: logger.warning(f"Failed to parse config JSON for release {latest_release.id}") continue - + # 使用 MemoryConfigService 的提取方法 app = app_map.get(app_id) if not app: continue - + legacy_config_id, is_legacy_int = config_service.extract_memory_config_id( app_type=app.type, config=config ) - + if legacy_config_id: # 更新 user_data 中的 memory_config_id user_data[end_user_id]["memory_config_id"] = legacy_config_id - + # 记录需要回填的用户(稍后验证配置存在后再回填) end_user = end_user_map.get(end_user_id) if end_user: @@ -1399,7 +1394,7 @@ def get_end_users_connected_configs_batch(end_user_ids: List[str], db: Session) logger.info( f"Legacy int config detected for end_user {end_user_id}, will use workspace default" ) - + # 验证提取的 config_id 是否存在于数据库中 if users_to_backfill: config_ids_to_validate = list(set(cid for _, cid in users_to_backfill)) @@ -1407,17 +1402,17 @@ def get_end_users_connected_configs_batch(end_user_ids: List[str], db: Session) MemoryConfig.config_id.in_(config_ids_to_validate) ).all() valid_config_ids = {mc.config_id for mc in existing_configs} - + # 只回填存在的配置 valid_backfills = [ - (eu, cid) for eu, cid in users_to_backfill + (eu, cid) for eu, cid in users_to_backfill if cid in valid_config_ids ] invalid_backfills = [ - (eu, cid) for eu, cid in users_to_backfill + (eu, cid) for eu, cid in users_to_backfill if cid not in valid_config_ids ] - + if invalid_backfills: invalid_ids = [str(cid) for _, cid in invalid_backfills] logger.warning( @@ -1426,7 +1421,7 @@ def get_end_users_connected_configs_batch(end_user_ids: List[str], db: Session) # 清除 user_data 中无效的 config_id for eu, cid in invalid_backfills: user_data[str(eu.id)]["memory_config_id"] = None - + # 批量回填 end_user.memory_config_id if valid_backfills: for end_user, memory_config_id in valid_backfills: @@ -1437,7 +1432,7 @@ def get_end_users_connected_configs_batch(end_user_ids: List[str], db: Session) # 4. 收集需要查询的 memory_config_id 和需要回退的 workspace_id direct_config_ids = [] workspace_fallback_users = [] # [(end_user_id, workspace_id), ...] - + for end_user_id, data in user_data.items(): if data["memory_config_id"]: direct_config_ids.append(data["memory_config_id"]) @@ -1455,7 +1450,7 @@ def get_end_users_connected_configs_batch(end_user_ids: List[str], db: Session) # 6. 获取工作空间默认配置(需要逐个查询,因为 get_workspace_default_config 有复杂逻辑) workspace_default_configs = {} unique_workspace_ids = list(set(ws_id for _, ws_id in workspace_fallback_users)) - + if unique_workspace_ids: config_service = MemoryConfigService(db) for workspace_id in unique_workspace_ids: @@ -1466,11 +1461,11 @@ def get_end_users_connected_configs_batch(end_user_ids: List[str], db: Session) # 7. 构建最终结果 for end_user_id, data in user_data.items(): memory_config = None - + # 优先使用 end_user 直接分配的配置 if data["memory_config_id"]: memory_config = config_id_to_config.get(data["memory_config_id"]) - + # 回退到工作空间默认配置 if not memory_config: workspace_id = app_to_workspace.get(data["app_id"]) @@ -1486,4 +1481,4 @@ def get_end_users_connected_configs_batch(end_user_ids: List[str], db: Session) result[end_user_id] = {"memory_config_id": None, "memory_config_name": None} logger.info(f"Successfully retrieved {len(result)} connected configs") - return result \ No newline at end of file + return result diff --git a/api/app/services/memory_konwledges_server.py b/api/app/services/memory_konwledges_server.py index 420f7ca1..b8961d33 100644 --- a/api/app/services/memory_konwledges_server.py +++ b/api/app/services/memory_konwledges_server.py @@ -1,45 +1,42 @@ # 修改 memory_konwledges_server.py 文件 -import asyncio import os -import re import uuid from pathlib import Path from typing import Optional -from pydantic import BaseModel, Field +from fastapi import HTTPException, status +from pydantic import BaseModel +from sqlalchemy.orm import Session +from app.celery_app import celery_app +from app.core.config import settings +from app.core.logging_config import get_api_logger from app.core.rag.models.chunk import DocumentChunk from app.core.rag.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory from app.core.response_utils import success -from app.db import get_db -from app.schemas import file_schema, document_schema -from fastapi import APIRouter, Depends, HTTPException, status, File, UploadFile, Query +from app.db import get_db_context from app.models.document_model import Document -import uuid -from sqlalchemy.orm import Session -from fastapi import HTTPException, status - -from app.core.config import settings from app.models.user_model import User +from app.schemas import file_schema, document_schema from app.schemas.file_schema import CustomTextFileCreate from app.services import document_service, file_service, knowledge_service -from app.celery_app import celery_app -from app.core.logging_config import get_api_logger -from app.schemas.file_schema import CustomTextFileCreate -from app.db import get_db + # 创建一个简单的用户类用于测试 api_logger = get_api_logger() + class ChunkCreate(BaseModel): content: str + + class SimpleUser: def __init__(self, user_id: str): # 确保ID是UUID类型 self.id = user_id self.username = user_id -'''解析''' + async def parse_document_by_id(document_id: uuid.UUID, db: Session, current_user: User): """ 解析指定文档 @@ -120,7 +117,7 @@ async def parse_document_by_id(document_id: uuid.UUID, db: Session, current_user api_logger.error(f"文档解析失败: document_id={document_id} - {str(e)}") raise -'''获取块ID''' + async def get_document_chunks( kb_id: uuid.UUID, document_id: uuid.UUID, @@ -198,7 +195,7 @@ async def get_document_chunks( return success(data=result, msg="文档块列表查询成功") -'''查找文档ID''' + def find_document_id_by_kb_and_filename( db: Session, kb_id: str, @@ -231,7 +228,7 @@ def find_document_id_by_kb_and_filename( except Exception as e: return None -'''获取知识库ID''' + def find_documents_by_kb_id( db: Session, kb_id: str, @@ -268,18 +265,14 @@ def find_documents_by_kb_id( except Exception as e: return [] -''''上传文件''' + async def memory_konwledges_up( kb_id: str, parent_id: str, create_data: file_schema.CustomTextFileCreate, - db: Session = Depends(get_db), - current_user: SimpleUser = None, # 修改为SimpleUser + db: Session, + current_user: SimpleUser, ): - # 如果没有提供current_user,则创建一个默认的 - if current_user is None: - current_user = SimpleUser("5d27df0b-7eec-4fa6-9f8b-0f9b7e852f60") - content_bytes = create_data.content.encode('utf-8') file_size = len(content_bytes) print(f"file size: {file_size} byte") @@ -350,8 +343,6 @@ async def memory_konwledges_up( return success(data=document_schema.Document.model_validate(db_document), msg="custom text upload successful") -'''添加新块''' - async def create_document_chunk( kb_id: uuid.UUID, @@ -417,7 +408,7 @@ async def create_document_chunk( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"查询文档块失败: {error_msg}" ) - + sort_id = sort_id + 1 # 5. 创建文档块 @@ -450,6 +441,7 @@ async def create_document_chunk( return success(data=chunk, msg="文档块创建成功") + async def write_rag(end_user_id, message, user_rag_memory_id): """ 将消息写入 RAG 知识库 @@ -483,15 +475,12 @@ async def write_rag(end_user_id, message, user_rag_memory_id): detail=f"知识库ID格式无效: {user_rag_memory_id}" ) - db_gen = get_db() - db = next(db_gen) - - try: + with get_db_context() as db: create_data = CustomTextFileCreate(title=end_user_id, content=message) current_user = SimpleUser(user_rag_memory_id) # 检查文档是否已存在 document = find_document_id_by_kb_and_filename(db=db, kb_id=user_rag_memory_id, file_name=f"{end_user_id}.txt") - print('======',document) + print('======', document) api_logger.info(f"查找文档结果: document_id={document}") if document is not None: # 文档已存在,直接添加新块 @@ -528,6 +517,3 @@ async def write_rag(end_user_id, message, user_rag_memory_id): else: api_logger.error(f"创建文档后无法找到文档ID: end_user_id={end_user_id}") return result - finally: - # 确保数据库会话被关闭 - db.close() \ No newline at end of file diff --git a/api/app/services/user_memory_service.py b/api/app/services/user_memory_service.py index db5051d2..8bacc112 100644 --- a/api/app/services/user_memory_service.py +++ b/api/app/services/user_memory_service.py @@ -21,8 +21,7 @@ from app.repositories.end_user_repository import EndUserRepository from app.repositories.neo4j.cypher_queries import Graph_Node_query from app.repositories.neo4j.neo4j_connector import Neo4jConnector from app.schemas.memory_episodic_schema import EmotionSubject, EmotionType, type_mapping -from app.services.implicit_memory_service import ImplicitMemoryService -from app.services.memory_base_service import MemoryBaseService, MemoryTransService +from app.services.memory_base_service import MemoryBaseService from app.services.memory_config_service import MemoryConfigService from app.services.memory_perceptual_service import MemoryPerceptualService from app.services.memory_short_service import ShortService @@ -1167,7 +1166,6 @@ async def analytics_user_summary(end_user_id: Optional[str] = None, language: st from app.core.language_utils import validate_language from app.core.memory.utils.prompt.prompt_utils import render_user_summary_prompt - from app.db import get_db from app.repositories.end_user_repository import EndUserRepository # 验证语言参数 @@ -1178,8 +1176,7 @@ async def analytics_user_summary(end_user_id: Optional[str] = None, language: st if end_user_id: try: # 获取数据库会话并查询用户信息 - db = next(get_db()) - try: + with get_db_context() as db: repo = EndUserRepository(db) end_user = repo.get_by_id(uuid.UUID(end_user_id)) if end_user and end_user.other_name: @@ -1187,8 +1184,7 @@ async def analytics_user_summary(end_user_id: Optional[str] = None, language: st logger.info(f"使用 other_name 作为用户显示名称: {user_display_name}") else: logger.info(f"用户 {end_user_id} 的 other_name 为空,使用默认称呼: {user_display_name}") - finally: - db.close() + except Exception as e: logger.warning(f"获取用户 other_name 失败,使用默认称呼: {str(e)}")