diff --git a/api/app/core/config.py b/api/app/core/config.py index 5f4f91c4..774f4a0f 100644 --- a/api/app/core/config.py +++ b/api/app/core/config.py @@ -146,6 +146,8 @@ class Settings: # Celery configuration (internal) CELERY_BROKER: int = int(os.getenv("CELERY_BROKER", "1")) CELERY_BACKEND: int = int(os.getenv("CELERY_BACKEND", "2")) + BROKER_URL: str = os.getenv("BROKER_URL", f"redis://:{REDIS_PASSWORD}@{REDIS_HOST}:{REDIS_PORT}/{CELERY_BROKER}") + RESULT_BACKEND: str = os.getenv("RESULT_BACKEND", f"redis://:{REDIS_PASSWORD}@{REDIS_HOST}:{REDIS_PORT}/{CELERY_BACKEND}") REFLECTION_INTERVAL_SECONDS: float = float(os.getenv("REFLECTION_INTERVAL_SECONDS", "300")) HEALTH_CHECK_SECONDS: float = float(os.getenv("HEALTH_CHECK_SECONDS", "600")) MEMORY_INCREMENT_INTERVAL_HOURS: float = float(os.getenv("MEMORY_INCREMENT_INTERVAL_HOURS", "24")) diff --git a/api/app/core/memory/agent/__init__.py b/api/app/core/memory/agent/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/api/app/core/memory/agent/config/problem_extension_config.py b/api/app/core/memory/agent/config/problem_extension_config.py new file mode 100644 index 00000000..73a6d779 --- /dev/null +++ b/api/app/core/memory/agent/config/problem_extension_config.py @@ -0,0 +1,46 @@ + +""" +Problem_Extension优化配置 + +在应用启动时应用这些优化配置 +""" + +# 优化配置 +PROBLEM_EXTENSION_CONFIG = { + # 缓存配置 + "cache_enabled": True, + "cache_ttl": 3600, # 1小时 + + # 超时配置 + "llm_timeout": 8.0, # 8秒超时 + "max_retries": 1, # 最多重试1次 + + # 批处理配置 + "max_questions_per_batch": 10, + "batch_timeout": 15.0, + + # 性能监控 + "monitoring_enabled": True, + "slow_query_threshold": 10.0, # 10秒为慢查询 + + # 连接池配置 + "client_pool_size": 3, + + # 简化模式 + "use_simplified_prompt": True, + "skip_history_for_simple_queries": True, +} + +def apply_optimizations(): + """应用优化配置""" + import os + + # 设置环境变量 + for key, value in PROBLEM_EXTENSION_CONFIG.items(): + env_key = f"PROBLEM_EXTENSION_{key.upper()}" + os.environ[env_key] = str(value) + + print("✅ Problem_Extension优化配置已应用") + +if __name__ == "__main__": + apply_optimizations() diff --git a/api/app/core/memory/agent/langgraph_graph/__init__.py b/api/app/core/memory/agent/langgraph_graph/__init__.py deleted file mode 100644 index a0596e38..00000000 --- a/api/app/core/memory/agent/langgraph_graph/__init__.py +++ /dev/null @@ -1,16 +0,0 @@ -""" -LangGraph Graph package for memory agent. - -This package provides the LangGraph workflow orchestrator with modular -node implementations, routing logic, and state management. - -Package structure: -- read_graph: Main graph factory for read operations -- write_graph: Main graph factory for write operations -- nodes: LangGraph node implementations -- routing: State routing logic -- state: State management utilities -""" -from app.core.memory.agent.langgraph_graph.read_graph import make_read_graph - -__all__ = ['make_read_graph'] \ No newline at end of file diff --git a/api/app/core/memory/agent/langgraph_graph/nodes/__init__.py b/api/app/core/memory/agent/langgraph_graph/nodes/__init__.py index 4e808919..231a167c 100644 --- a/api/app/core/memory/agent/langgraph_graph/nodes/__init__.py +++ b/api/app/core/memory/agent/langgraph_graph/nodes/__init__.py @@ -4,7 +4,7 @@ LangGraph node implementations. This module contains custom node implementations for the LangGraph workflow. """ -from app.core.memory.agent.langgraph_graph.nodes.tool_node import ToolExecutionNode -from app.core.memory.agent.langgraph_graph.nodes.input_node import create_input_message - -__all__ = ["ToolExecutionNode", "create_input_message"] +# from app.core.memory.agent.langgraph_graph.nodes.tool_node import ToolExecutionNode +# from app.core.memory.agent.langgraph_graph.nodes.input_node import create_input_message +# +# __all__ = ["ToolExecutionNode", "create_input_message"] diff --git a/api/app/core/memory/agent/langgraph_graph/nodes/data_nodes.py b/api/app/core/memory/agent/langgraph_graph/nodes/data_nodes.py new file mode 100644 index 00000000..6595a2ce --- /dev/null +++ b/api/app/core/memory/agent/langgraph_graph/nodes/data_nodes.py @@ -0,0 +1,16 @@ +from app.core.memory.agent.utils.llm_tools import ReadState, WriteState + + +def content_input_node(state: ReadState) -> ReadState: + """开始节点 - 提取内容并保持状态信息""" + + content = state['messages'][0].content if state.get('messages') else '' + # 返回内容并保持所有状态信息 + return {"data": content} + +def content_input_write(state: WriteState) -> WriteState: + """开始节点 - 提取内容并保持状态信息""" + + content = state['messages'][0].content if state.get('messages') else '' + # 返回内容并保持所有状态信息 + return {"data": content} \ No newline at end of file diff --git a/api/app/core/memory/agent/langgraph_graph/nodes/input_node.py b/api/app/core/memory/agent/langgraph_graph/nodes/input_node.py deleted file mode 100644 index 3eed497f..00000000 --- a/api/app/core/memory/agent/langgraph_graph/nodes/input_node.py +++ /dev/null @@ -1,150 +0,0 @@ -""" -Input node for LangGraph workflow entry point. - -This module provides the create_input_message function which processes initial -user input with multimodal support and creates the first tool call message. -""" - -import logging -import re -import uuid -from datetime import datetime -from typing import Any, Dict - -from app.core.memory.agent.utils.multimodal import MultimodalProcessor -from app.schemas.memory_config_schema import MemoryConfig -from langchain_core.messages import AIMessage - -logger = logging.getLogger(__name__) - - -async def create_input_message( - state: Dict[str, Any], - tool_name: str, - session_id: str, - search_switch: str, - apply_id: str, - group_id: str, - multimodal_processor: MultimodalProcessor, - memory_config: MemoryConfig, -) -> Dict[str, Any]: - """ - Create initial tool call message from user input. - - This function: - 1. Extracts the last message content from state - 2. Processes multimodal inputs (images/audio) using the multimodal processor - 3. Generates a unique message ID - 4. Extracts namespace from session_id - 5. Handles verified_data extraction for backward compatibility - 6. Returns AIMessage with complete tool_calls structure - - Args: - state: LangGraph state dictionary containing messages - tool_name: Name of the tool to invoke (typically "Split_The_Problem") - session_id: Session identifier (format: "call_id_{namespace}") - search_switch: Search routing parameter - apply_id: Application identifier - group_id: Group identifier - multimodal_processor: Processor for handling image/audio inputs - memory_config: MemoryConfig object containing all configuration - - Returns: - State update with AIMessage containing tool_call - - Examples: - >>> state = {"messages": [HumanMessage(content="What is AI?")]} - >>> result = await create_input_message( - ... state, "Split_The_Problem", "call_id_user123", "0", "app1", "group1", processor, config - ... ) - >>> result["messages"][0].tool_calls[0]["name"] - 'Split_The_Problem' - """ - messages = state.get("messages", []) - - # Extract last message content - if messages: - last_message = messages[-1].content if hasattr(messages[-1], 'content') else str(messages[-1]) - else: - logger.warning("[create_input_message] No messages in state, using empty string") - last_message = "" - - logger.debug(f"[create_input_message] Original input: {last_message[:100]}...") - - # Process multimodal input (images/audio) - try: - processed_content = await multimodal_processor.process_input(last_message) - if processed_content != last_message: - logger.info( - f"[create_input_message] Multimodal processing converted input " - f"from {len(last_message)} to {len(processed_content)} chars" - ) - last_message = processed_content - except Exception as e: - logger.error( - f"[create_input_message] Multimodal processing failed: {e}", - exc_info=True - ) - # Continue with original content - - # Generate unique message ID - uuid_str = uuid.uuid4() - time_str = datetime.now().strftime("%Y-%m-%d %H:%M:%S") - - # Extract namespace from session_id - # Expected format: "call_id_{namespace}" or similar - try: - namespace = str(session_id).split('_id_')[1] - except (IndexError, AttributeError): - logger.warning( - f"[create_input_message] Could not extract namespace from session_id: {session_id}" - ) - namespace = "unknown" - - # Handle verified_data extraction (backward compatibility) - # This regex-based extraction is kept for compatibility with existing data formats - if 'verified_data' in str(last_message): - try: - messages_last = str(last_message).replace('\\n', '').replace('\\', '') - query_match = re.findall(r'"query": "(.*?)",', messages_last) - if query_match: - last_message = query_match[0] - logger.debug( - f"[create_input_message] Extracted query from verified_data: {last_message}" - ) - except Exception as e: - logger.warning( - f"[create_input_message] Failed to extract query from verified_data: {e}" - ) - - # Construct tool call message - tool_call_id = f"{session_id}_{uuid_str}" - - logger.info( - f"[create_input_message] Creating tool call for '{tool_name}' " - f"with ID: {tool_call_id}" - ) - - # Build tool arguments - tool_args = { - "sentence": last_message, - "sessionid": session_id, - "messages_id": str(uuid_str), - "search_switch": search_switch, - "apply_id": apply_id, - "group_id": group_id, - "memory_config": memory_config, - } - - return { - "messages": [ - AIMessage( - content="", - tool_calls=[{ - "name": tool_name, - "args": tool_args, - "id": tool_call_id - }] - ) - ] - } diff --git a/api/app/core/memory/agent/langgraph_graph/nodes/problem_nodes.py b/api/app/core/memory/agent/langgraph_graph/nodes/problem_nodes.py new file mode 100644 index 00000000..0c68a47e --- /dev/null +++ b/api/app/core/memory/agent/langgraph_graph/nodes/problem_nodes.py @@ -0,0 +1,237 @@ +import json +import time +from app.core.logging_config import get_agent_logger +from app.db import get_db + +from app.core.memory.agent.models.problem_models import ProblemExtensionResponse +from app.core.memory.agent.utils.llm_tools import ( + PROJECT_ROOT_, + ReadState, +) +from app.core.memory.agent.utils.redis_tool import store +from app.core.memory.agent.utils.session_tools import SessionService +from app.core.memory.agent.utils.template_tools import TemplateService +from app.core.memory.agent.services.optimized_llm_service import LLMServiceMixin + +template_root = PROJECT_ROOT_ + '/agent/utils/prompt' +db_session = next(get_db()) +logger = get_agent_logger(__name__) + +class ProblemNodeService(LLMServiceMixin): + """问题处理节点服务类""" + + def __init__(self): + super().__init__() + self.template_service = TemplateService(template_root) + +# 创建全局服务实例 +problem_service = ProblemNodeService() + +async def Split_The_Problem(state: ReadState) -> ReadState: + """问题分解节点""" + # 从状态中获取数据 + content = state.get('data', '') + group_id = state.get('group_id', '') + memory_config = state.get('memory_config', None) + + history = await SessionService(store).get_history(group_id, group_id, group_id) + system_prompt = await problem_service.template_service.render_template( + template_name='problem_breakdown_prompt.jinja2', + operation_name='split_the_problem', + history=history, + sentence=content + ) + + try: + # 使用优化的LLM服务 + structured = await problem_service.call_llm_structured( + state=state, + db_session=db_session, + system_prompt=system_prompt, + response_model=ProblemExtensionResponse, + fallback_value=[] + ) + + # 添加更详细的日志记录 + logger.info(f"Split_The_Problem: 开始处理问题分解,内容长度: {len(content)}") + + # 验证结构化响应 + if not structured or not hasattr(structured, 'root'): + logger.warning("Split_The_Problem: 结构化响应为空或格式不正确") + split_result = json.dumps([], ensure_ascii=False) + elif not structured.root: + logger.warning("Split_The_Problem: 结构化响应的root为空") + split_result = json.dumps([], ensure_ascii=False) + else: + split_result = json.dumps( + [item.model_dump() for item in structured.root], + ensure_ascii=False + ) + + split_result_dict = [] + for index, item in enumerate(json.loads(split_result)): + split_data = { + "id": f"Q{index+1}", + "question": item['extended_question'], + "type": item['type'], + "reason": item['reason'] + } + split_result_dict.append(split_data) + + logger.info(f"Split_The_Problem: 成功生成 {len(structured.root) if structured.root else 0} 个分解项") + + result = { + "context": split_result, + "original": content, + "_intermediate": { + "type": "problem_split", + "title": "问题拆分", + "data": split_result_dict, + "original_query": content + } + } + + except Exception as e: + logger.error( + f"Split_The_Problem failed: {e}", + exc_info=True + ) + + # 提供更详细的错误信息 + error_details = { + "error_type": type(e).__name__, + "error_message": str(e), + "content_length": len(content), + "llm_model_id": memory_config.llm_model_id if memory_config else None + } + + logger.error(f"Split_The_Problem error details: {error_details}") + + # 创建默认的空结果 + result = { + "context": json.dumps([], ensure_ascii=False), + "original": content, + "error": str(e), + "_intermediate": { + "type": "problem_split", + "title": "问题拆分", + "data": [], + "original_query": content, + "error": error_details + } + } + + # 返回更新后的状态,包含spit_context字段 + return {"spit_data": result} + +async def Problem_Extension(state: ReadState) -> ReadState: + """问题扩展节点""" + # 获取原始数据和分解结果 + start = time.time() + content = state.get('data', '') + data = state.get('spit_data', '')['context'] + group_id = state.get('group_id', '') + storage_type = state.get('storage_type', '') + user_rag_memory_id = state.get('user_rag_memory_id', '') + memory_config = state.get('memory_config', None) + + databasets = {} + try: + data = json.loads(data) + for i in data: + databasets[i['extended_question']] = i['type'] + except (json.JSONDecodeError, KeyError, TypeError) as e: + logger.error(f"Problem_Extension: 数据解析失败: {e}") + # 使用空字典作为fallback + databasets = {} + data = [] + + history = await SessionService(store).get_history(group_id, group_id, group_id) + system_prompt = await problem_service.template_service.render_template( + template_name='Problem_Extension_prompt.jinja2', + operation_name='problem_extension', + history=history, + questions=databasets + ) + + try: + # 使用优化的LLM服务 + response_content = await problem_service.call_llm_structured( + state=state, + db_session=db_session, + system_prompt=system_prompt, + response_model=ProblemExtensionResponse, + fallback_value=[] + ) + + logger.info(f"Problem_Extension: 开始处理问题扩展,问题数量: {len(databasets)}") + + # 验证结构化响应 + if not response_content or not hasattr(response_content, 'root'): + logger.warning("Problem_Extension: 结构化响应为空或格式不正确") + aggregated_dict = {} + elif not response_content.root: + logger.warning("Problem_Extension: 结构化响应的root为空") + aggregated_dict = {} + else: + # Aggregate results by original question + aggregated_dict = {} + for item in response_content.root: + try: + key = getattr(item, "original_question", None) or ( + item.get("original_question") if isinstance(item, dict) else None + ) + value = getattr(item, "extended_question", None) or ( + item.get("extended_question") if isinstance(item, dict) else None + ) + if not key or not value: + logger.warning(f"Problem_Extension: 跳过无效项: key={key}, value={value}") + continue + aggregated_dict.setdefault(key, []).append(value) + except Exception as item_error: + logger.warning(f"Problem_Extension: 处理项目时出错: {item_error}") + continue + + logger.info(f"Problem_Extension: 成功生成 {len(aggregated_dict)} 个扩展问题组") + + except Exception as e: + logger.error( + f"LLM call failed for Problem_Extension: {e}", + exc_info=True + ) + + # 提供更详细的错误信息 + error_details = { + "error_type": type(e).__name__, + "error_message": str(e), + "questions_count": len(databasets), + "llm_model_id": memory_config.llm_model_id if memory_config else None + } + + logger.error(f"Problem_Extension error details: {error_details}") + aggregated_dict = {} + + logger.info("Problem extension") + logger.info(f"Problem extension result: {aggregated_dict}") + + # Emit intermediate output for frontend + print(time.time() - start) + result = { + "context": aggregated_dict, + "original": data, + "storage_type": storage_type, + "user_rag_memory_id": user_rag_memory_id, + "_intermediate": { + "type": "problem_extension", + "title": "问题扩展", + "data": aggregated_dict, + "original_query": content, + "storage_type": storage_type, + "user_rag_memory_id": user_rag_memory_id + } + } + + return {"problem_extension": result} + + + diff --git a/api/app/core/memory/agent/langgraph_graph/nodes/retrieve_nodes.py b/api/app/core/memory/agent/langgraph_graph/nodes/retrieve_nodes.py new file mode 100644 index 00000000..14f8fa8b --- /dev/null +++ b/api/app/core/memory/agent/langgraph_graph/nodes/retrieve_nodes.py @@ -0,0 +1,417 @@ +# ===== 标准库 ===== +import asyncio +import json +import os + +# ===== 第三方库 ===== +from langchain.agents import create_agent +from langchain_openai import ChatOpenAI +from app.core.logging_config import get_agent_logger +from app.db import get_db, get_db_context + +from app.schemas import model_schema +from app.services.memory_config_service import MemoryConfigService +from app.services.model_service import ModelConfigService + +from app.core.memory.agent.services.search_service import SearchService +from app.core.memory.agent.utils.llm_tools import ( + COUNTState, + ReadState, + deduplicate_entries, + merge_to_key_value_pairs, +) +from app.core.memory.agent.langgraph_graph.tools.tool import ( + create_hybrid_retrieval_tool_sync, + create_time_retrieval_tool, + extract_tool_message_content, +) + +from app.core.rag.nlp.search import knowledge_retrieval + +logger = get_agent_logger(__name__) +db = next(get_db()) + + + +async def rag_config(state): + user_rag_memory_id = state.get('user_rag_memory_id', '') + kb_config = { + "knowledge_bases": [ + { + "kb_id": user_rag_memory_id, + "similarity_threshold": 0.7, + "vector_similarity_weight": 0.5, + "top_k": 10, + "retrieve_type": "participle" + } + ], + "merge_strategy": "weight", + "reranker_id": os.getenv('reranker_id'), + "reranker_top_k": 10 + } + return kb_config +async def rag_knowledge(state,question): + kb_config = await rag_config(state) + group_id = state.get('group_id', '') + user_rag_memory_id=state.get("user_rag_memory_id",'') + retrieve_chunks_result = knowledge_retrieval(question, kb_config, [str(group_id)]) + try: + retrieval_knowledge = [i.page_content for i in retrieve_chunks_result] + clean_content = '\n\n'.join(retrieval_knowledge) + cleaned_query = question + raw_results = clean_content + logger.info(f" Using RAG storage with memory_id={user_rag_memory_id}") + except Exception : + retrieval_knowledge=[] + clean_content = '' + raw_results = '' + cleaned_query = question + logger.info(f"No content retrieved from knowledge base: {user_rag_memory_id}") + return retrieval_knowledge,clean_content,cleaned_query,raw_results + + +async def llm_infomation(state: ReadState) -> ReadState: + memory_config = state.get('memory_config', None) + model_id = memory_config.llm_model_id + tenant_id = memory_config.tenant_id + + # 使用现有的 memory_config 而不是重新查询数据库 + # 或者使用线程安全的数据库访问 + with get_db_context() as db: + result_orm = ModelConfigService.get_model_by_id(db=db, model_id=model_id, tenant_id=tenant_id) + result_pydantic = model_schema.ModelConfig.model_validate(result_orm) + return result_pydantic + + +async def clean_databases(data) -> str: + """ + 简化的数据库搜索结果清理函数 + + Args: + data: 搜索结果数据 + + Returns: + 清理后的内容字符串 + """ + try: + # 解析JSON字符串 + if isinstance(data, str): + try: + data = json.loads(data) + except json.JSONDecodeError: + return data + + if not isinstance(data, dict): + return str(data) + + # 获取结果数据 + # with open("搜索结果.json","w",encoding='utf-8') as f: + # f.write(json.dumps(data, indent=4, ensure_ascii=False)) + results = data.get('results', data) + if not isinstance(results, dict): + return str(results) + + # 收集所有内容 + content_list = [] + + # 处理重排序结果 + reranked = results.get('reranked_results', {}) + if reranked: + for category in ['summaries', 'statements', 'chunks', 'entities']: + items = reranked.get(category, []) + if isinstance(items, list): + content_list.extend(items) + # 处理时间搜索结果 + time_search = results.get('time_search', {}) + if time_search: + if isinstance(time_search, dict): + statements = time_search.get('statements', time_search.get('time_search', [])) + if isinstance(statements, list): + content_list.extend(statements) + elif isinstance(time_search, list): + content_list.extend(time_search) + + # 提取文本内容 + text_parts = [] + for item in content_list: + if isinstance(item, dict): + text = item.get('statement') or item.get('content', '') + if text: + text_parts.append(text) + elif isinstance(item, str): + text_parts.append(item) + + + return '\n'.join(text_parts).strip() + + except Exception as e: + logger.error(f"clean_databases failed: {e}", exc_info=True) + return str(data) + + +async def retrieve_nodes(state: ReadState) -> ReadState: + + ''' + + 模型信息 + ''' + + problem_extension=state.get('problem_extension', '')['context'] + storage_type=state.get('storage_type', '') + user_rag_memory_id=state.get('user_rag_memory_id', '') + group_id=state.get('group_id', '') + memory_config = state.get('memory_config', None) + original=state.get('data', '') + problem_list=[] + for key,values in problem_extension.items(): + for data in values: + problem_list.append(data) + logger.info(f"Retrieve: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}") + # 创建异步任务处理单个问题 + async def process_question_nodes(idx, question): + try: + # Prepare search parameters based on storage type + search_params = { + "group_id": group_id, + "question": question, + "return_raw_results": True + } + if storage_type == "rag" and user_rag_memory_id: + retrieval_knowledge, clean_content, cleaned_query, raw_results = await rag_knowledge(state, question) + else: + clean_content, cleaned_query, raw_results = await SearchService().execute_hybrid_search( + **search_params, memory_config=memory_config + ) + + return { + "Query_small": cleaned_query, + "Result_small": clean_content, + "_intermediate": { + "type": "search_result", + "query": cleaned_query, + "raw_results": raw_results, + "index": idx + 1, + "total": len(problem_list) + } + } + + except Exception as e: + logger.error( + f"Retrieve: hybrid_search failed for question '{question}': {e}", + exc_info=True + ) + # Return empty result for this question + return { + "Query_small": question, + "Result_small": "", + "_intermediate": { + "type": "search_result", + "query": question, + "raw_results": [], + "index": idx + 1, + "total": len(problem_list) + } + } + + # 并发处理所有问题 + tasks = [process_question_nodes(idx, question) for idx, question in enumerate(problem_list)] + databases_anser = await asyncio.gather(*tasks) + databases_data = { + "Query": original, + "Expansion_issue": databases_anser + } + + # Collect intermediate outputs before deduplication + intermediate_outputs = [] + for item in databases_anser: + if '_intermediate' in item: + intermediate_outputs.append(item['_intermediate']) + + # Deduplicate and merge results + deduplicated_data = deduplicate_entries(databases_data['Expansion_issue']) + deduplicated_data_merged = merge_to_key_value_pairs( + deduplicated_data, + 'Query_small', + 'Result_small' + ) + + # Restructure for Verify/Retrieve_Summary compatibility + keys, val = [], [] + for item in deduplicated_data_merged: + for items_key, items_value in item.items(): + keys.append(items_key) + val.append(items_value) + + send_verify = [] + for i, j in zip(keys, val, strict=False): + if j!=['']: + send_verify.append({ + "Query_small": i, + "Answer_Small": j + }) + + dup_databases = { + "Query": original, + "Expansion_issue": send_verify, + "_intermediate_outputs": intermediate_outputs # Preserve intermediate outputs + } + + logger.info(f"Collected {len(intermediate_outputs)} intermediate outputs from search results") + return {'retrieve':dup_databases} + + + + +async def retrieve(state: ReadState) -> ReadState: + # 从state中获取group_id + import time + start=time.time() + problem_extension = state.get('problem_extension', '')['context'] + storage_type = state.get('storage_type', '') + user_rag_memory_id = state.get('user_rag_memory_id', '') + group_id = state.get('group_id', '') + memory_config = state.get('memory_config', None) + original = state.get('data', '') + problem_list = [] + for key, values in problem_extension.items(): + for data in values: + problem_list.append(data) + logger.info(f"Retrieve: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}") + databases_anser = [] + + async def get_llm_info(): + with get_db_context() as db: # 使用同步数据库上下文管理器 + config_service = MemoryConfigService(db) + return await llm_infomation(state) + llm_config = await get_llm_info() + api_key_obj = llm_config.api_keys[0] + api_key = api_key_obj.api_key + api_base = api_key_obj.api_base + model_name = api_key_obj.model_name + llm = ChatOpenAI( + model=model_name, + api_key=api_key, + base_url=api_base, + temperature=0.2, + ) + + time_retrieval_tool = create_time_retrieval_tool(group_id) + search_params = { "group_id": group_id, "return_raw_results": True } + hybrid_retrieval=create_hybrid_retrieval_tool_sync(memory_config, **search_params) + agent = create_agent( + llm, + tools=[time_retrieval_tool,hybrid_retrieval], + system_prompt=f"我是检索专家,可以根据适合的工具进行检索。当前使用的group_id是: {group_id}" + ) + + # 创建异步任务处理单个问题 + import asyncio + + # 在模块级别定义信号量,限制最大并发数 + SEMAPHORE = asyncio.Semaphore(5) # 限制最多5个并发数据库操作 + + async def process_question(idx, question): + async with SEMAPHORE: # 限制并发 + try: + if storage_type == "rag" and user_rag_memory_id: + retrieval_knowledge, clean_content, cleaned_query, raw_results = await rag_knowledge(state, question) + else: + cleaned_query = question + # 使用 asyncio 在线程池中运行同步的 agent.invoke + import asyncio + response = await asyncio.get_event_loop().run_in_executor( + None, + lambda: agent.invoke({"messages": question}) + ) + tool_results = extract_tool_message_content(response) + if tool_results == None: + raw_results = [] + clean_content = '' + else: + raw_results = tool_results['content'] + clean_content = await clean_databases(raw_results) + + try: + raw_results = raw_results['results'] + except Exception: + raw_results = [] + + return { + "Query_small": cleaned_query, + "Result_small": clean_content, + "_intermediate": { + "type": "search_result", + "query": cleaned_query, + "raw_results": raw_results, + "index": idx + 1, + "total": len(problem_list) + } + } + + except Exception as e: + logger.error( + f"Retrieve: hybrid_search failed for question '{question}': {e}", + exc_info=True + ) + # Return empty result for this question + return { + "Query_small": question, + "Result_small": "", + "_intermediate": { + "type": "search_result", + "query": question, + "raw_results": [], + "index": idx + 1, + "total": len(problem_list) + } + } + + # 并发处理所有问题 + import asyncio + tasks = [process_question(idx, question) for idx, question in enumerate(problem_list)] + databases_anser = await asyncio.gather(*tasks) + databases_data = { + "Query": original, + "Expansion_issue": databases_anser + } + + # Collect intermediate outputs before deduplication + intermediate_outputs = [] + for item in databases_anser: + if '_intermediate' in item: + intermediate_outputs.append(item['_intermediate']) + + # Deduplicate and merge results + deduplicated_data = deduplicate_entries(databases_data['Expansion_issue']) + deduplicated_data_merged = merge_to_key_value_pairs( + deduplicated_data, + 'Query_small', + 'Result_small' + ) + + # Restructure for Verify/Retrieve_Summary compatibility + keys, val = [], [] + for item in deduplicated_data_merged: + for items_key, items_value in item.items(): + keys.append(items_key) + val.append(items_value) + + send_verify = [] + for i, j in zip(keys, val, strict=False): + if j != ['']: + send_verify.append({ + "Query_small": i, + "Answer_Small": j + }) + + dup_databases = { + "Query": original, + "Expansion_issue": send_verify, + "_intermediate_outputs": intermediate_outputs # Preserve intermediate outputs + } + # with open('retrieve_text.json', 'w') as f: + # json.dump(dup_databases, f, indent=4) + logger.info(f"Collected {len(intermediate_outputs)} intermediate outputs from search results") + return {'retrieve': dup_databases} + + diff --git a/api/app/core/memory/agent/langgraph_graph/nodes/summary_nodes.py b/api/app/core/memory/agent/langgraph_graph/nodes/summary_nodes.py new file mode 100644 index 00000000..7b727da5 --- /dev/null +++ b/api/app/core/memory/agent/langgraph_graph/nodes/summary_nodes.py @@ -0,0 +1,303 @@ + + +import time + +from app.core.logging_config import get_agent_logger, log_time +from app.db import get_db + +from app.core.memory.agent.models.summary_models import ( + RetrieveSummaryResponse, + SummaryResponse, +) +from app.core.memory.agent.services.search_service import SearchService +from app.core.memory.agent.utils.llm_tools import ( + PROJECT_ROOT_, + ReadState, +) +from app.core.memory.agent.utils.redis_tool import store +from app.core.memory.agent.utils.session_tools import SessionService +from app.core.memory.agent.utils.template_tools import TemplateService +from app.core.memory.agent.services.optimized_llm_service import LLMServiceMixin + +template_root = PROJECT_ROOT_ + '/agent/utils/prompt' +logger = get_agent_logger(__name__) +db_session = next(get_db()) + +class SummaryNodeService(LLMServiceMixin): + """总结节点服务类""" + + def __init__(self): + super().__init__() + self.template_service = TemplateService(template_root) + +# 创建全局服务实例 +summary_service = SummaryNodeService() + +async def summary_history(state: ReadState) -> ReadState: + group_id = state.get("group_id", '') + history = await SessionService(store).get_history(group_id, group_id, group_id) + return history + +async def summary_llm(state: ReadState, history, retrieve_info, template_name, operation_name, response_model,search_mode) -> str: + """ + 增强的summary_llm函数,包含更好的错误处理和数据验证 + """ + data = state.get("data", '') + + # 构建系统提示词 + if str(search_mode) == "0": + system_prompt = await summary_service.template_service.render_template( + template_name=template_name, + operation_name=operation_name, + data=retrieve_info, + query=data + ) + else: + system_prompt = await summary_service.template_service.render_template( + template_name=template_name, + operation_name=operation_name, + query=data, + history=history, + retrieve_info=retrieve_info + ) + try: + # 使用优化的LLM服务进行结构化输出 + structured = await summary_service.call_llm_structured( + state=state, + db_session=db_session, + system_prompt=system_prompt, + response_model=response_model, + fallback_value=None + ) + # 验证结构化响应 + if structured is None: + logger.warning(f"LLM返回None,使用默认回答") + return "信息不足,无法回答" + + # 根据操作类型提取答案 + if operation_name == "summary": + aimessages = getattr(structured, 'query_answer', None) or "信息不足,无法回答" + else: + # 处理RetrieveSummaryResponse + if hasattr(structured, 'data') and structured.data: + aimessages = getattr(structured.data, 'query_answer', None) or "信息不足,无法回答" + else: + logger.warning(f"结构化响应缺少data字段") + aimessages = "信息不足,无法回答" + + # 验证答案不为空 + if not aimessages or aimessages.strip() == "": + aimessages = "信息不足,无法回答" + + return aimessages + + except Exception as e: + logger.error(f"结构化输出失败: {e}", exc_info=True) + + # 尝试非结构化输出作为fallback + try: + logger.info("尝试非结构化输出作为fallback") + response = await summary_service.call_llm_simple( + state=state, + db_session=db_session, + system_prompt=system_prompt, + fallback_message="信息不足,无法回答" + ) + + if response and response.strip(): + # 简单清理响应 + cleaned_response = response.strip() + # 移除可能的JSON标记 + if cleaned_response.startswith('```'): + lines = cleaned_response.split('\n') + cleaned_response = '\n'.join(lines[1:-1]) + + return cleaned_response + else: + return "信息不足,无法回答" + + except Exception as fallback_error: + logger.error(f"Fallback也失败: {fallback_error}") + return "信息不足,无法回答" + +async def summary_redis_save(state: ReadState,aimessages) -> ReadState: + data = state.get("data", '') + group_id = state.get("group_id", '') + await SessionService(store).save_session( + user_id=group_id, + query=data, + apply_id=group_id, + group_id=group_id, + ai_response=aimessages + ) + await SessionService(store).cleanup_duplicates() + logger.info(f"sessionid: {aimessages} 写入成功") +async def summary_prompt(state: ReadState,aimessages,raw_results) -> ReadState: + storage_type=state.get("storage_type",'') + user_rag_memory_id=state.get("user_rag_memory_id",'') + data=state.get("data", '') + input_summary = { + "status": "success", + "summary_result": aimessages, + "storage_type": storage_type, + "user_rag_memory_id": user_rag_memory_id, + "_intermediate": { + "type": "input_summary", + "title": "快速答案", + "summary": aimessages, + "query": data, + "raw_results": raw_results, + "search_mode": "quick_search", + "storage_type": storage_type, + "user_rag_memory_id": user_rag_memory_id + } + } + retrieve={ + "status": "success", + "summary_result": aimessages, + "storage_type": storage_type, + "user_rag_memory_id": user_rag_memory_id, + "_intermediate": { + "type": "retrieval_summary", + "title":"快速检索", + "summary": aimessages, + "query": data, + "storage_type": storage_type, + "user_rag_memory_id": user_rag_memory_id + } + } + + return input_summary,retrieve + +async def Input_Summary(state: ReadState) -> ReadState: + start=time.time() + storage_type=state.get("storage_type",'') + memory_config = state.get('memory_config', None) + user_rag_memory_id=state.get("user_rag_memory_id",'') + data=state.get("data", '') + group_id=state.get("group_id", '') + logger.info(f"Input_Summary: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}") + history = await summary_history( state) + search_params = { + "group_id": group_id, + "question": data, + "return_raw_results": True + } + + try: + retrieve_info, question, raw_results = await SearchService().execute_hybrid_search(**search_params, memory_config=memory_config) + except Exception as e: + logger.error( f"Input_Summary: hybrid_search failed, using empty results: {e}", exc_info=True ) + retrieve_info, question, raw_results = "", data, [] + + + try: + # aimessages=await summary_llm(state,history,retrieve_info,'Retrieve_Summary_prompt.jinja2', + # 'input_summary',RetrieveSummaryResponse) + # logger.info(f"快速答案总结==>>:{storage_type}--{user_rag_memory_id}--{aimessages}") + summary_result = await summary_prompt(state, retrieve_info, retrieve_info) + summary = summary_result[0] + except Exception as e: + logger.error( f"Input_Summary failed: {e}", exc_info=True ) + summary= { + "status": "fail", + "summary_result": "信息不足,无法回答", + "storage_type": storage_type, + "user_rag_memory_id": user_rag_memory_id, + "error": str(e) + } + end = time.time() + try: + duration = end - start + except Exception: + duration = 0.0 + log_time('检索', duration) + return {"summary":summary} + +async def Retrieve_Summary(state: ReadState)-> ReadState: + retrieve=state.get("retrieve", '') + history = await summary_history( state) + import json + with open("检索.json","w",encoding='utf-8') as f: + f.write(json.dumps(retrieve, indent=4, ensure_ascii=False)) + retrieve=retrieve.get("Expansion_issue", []) + start=time.time() + retrieve_info_str=[] + for data in retrieve: + if data=='': + retrieve_info_str='' + else: + for key, value in data.items(): + if key=='Answer_Small': + for i in value: + retrieve_info_str.append(i) + retrieve_info_str=list(set(retrieve_info_str)) + retrieve_info_str='\n'.join(retrieve_info_str) + + aimessages=await summary_llm(state,history,retrieve_info_str, + 'Retrieve_Summary_prompt.jinja2','retrieve_summary',RetrieveSummaryResponse,"1") + if '信息不足,无法回答' not in str(aimessages) or str(aimessages) != "": + await summary_redis_save(state, aimessages) + if aimessages == '': + aimessages = '信息不足,无法回答' + logger.info(f"Summary after retrieval: {aimessages}") + end = time.time() + try: + duration = end - start + except Exception: + duration = 0.0 + log_time('Retrieval summary', duration) + + # 修复协程调用 - 先await,然后访问返回值 + summary_result = await summary_prompt(state, aimessages, retrieve_info_str) + summary = summary_result[1] + return {"summary":summary} + + +async def Summary(state: ReadState)-> ReadState: + start=time.time() + query = state.get("data", '') + verify=state.get("verify", '') + verify_expansion_issue=verify.get("verified_data", '') + retrieve_info_str='' + for data in verify_expansion_issue: + for key, value in data.items(): + if key=='answer_small': + for i in value: + retrieve_info_str+=i+'\n' + history=await summary_history(state) + + data = { + "query": query, + "history": history, + "retrieve_info": retrieve_info_str + } + aimessages=await summary_llm(state,history,data, + 'summary_prompt.jinja2','summary',SummaryResponse,0) + + + if '信息不足,无法回答' not in str(aimessages) or str(aimessages) != "": + await summary_redis_save(state, aimessages) + if aimessages == '': + aimessages = '信息不足,无法回答' + try: + duration = time.time() - start + except Exception: + duration = 0.0 + log_time('Retrieval summary', duration) + + # 修复协程调用 - 先await,然后访问返回值 + summary_result = await summary_prompt(state, aimessages, retrieve_info_str) + summary = summary_result[1] + return {"summary":summary} + +async def Summary_fails(state: ReadState)-> ReadState: + storage_type=state.get("storage_type", '') + user_rag_memory_id=state.get("user_rag_memory_id", '') + result= { + "status": "success", + "summary_result": "没有相关数据", + "storage_type": storage_type, + "user_rag_memory_id": user_rag_memory_id + } + return {"summary":result} \ No newline at end of file diff --git a/api/app/core/memory/agent/langgraph_graph/nodes/tool_node.py b/api/app/core/memory/agent/langgraph_graph/nodes/tool_node.py deleted file mode 100644 index 4727fb9c..00000000 --- a/api/app/core/memory/agent/langgraph_graph/nodes/tool_node.py +++ /dev/null @@ -1,234 +0,0 @@ -""" -Tool execution node for LangGraph workflow. - -This module provides the ToolExecutionNode class which wraps tool execution -with parameter transformation logic using the ParameterBuilder service. -""" - -import logging -import time -from typing import Any, Callable, Dict - -from app.core.memory.agent.langgraph_graph.state.extractors import ( - extract_content_payload, - extract_tool_call_id, -) -from app.core.memory.agent.mcp_server.services.parameter_builder import ParameterBuilder -from app.schemas.memory_config_schema import MemoryConfig -from langchain_core.messages import AIMessage -from langgraph.prebuilt import ToolNode - -logger = logging.getLogger(__name__) - - -class ToolExecutionNode: - """ - Custom LangGraph node that wraps tool execution with parameter transformation. - - This node extracts content from previous tool results, transforms parameters - based on tool type using ParameterBuilder, and invokes the tool with the - correct argument structure. - - Attributes: - tool_node: LangGraph ToolNode wrapping the actual tool - id: Node identifier for message IDs - tool_name: Name of the tool being executed - namespace: Namespace for session management - search_switch: Search routing parameter - apply_id: Application identifier - group_id: Group identifier - parameter_builder: Service for building tool-specific arguments - memory_config: MemoryConfig object containing all configuration - """ - - def __init__( - self, - tool: Callable, - node_id: str, - namespace: str, - search_switch: str, - apply_id: str, - group_id: str, - parameter_builder: ParameterBuilder, - storage_type: str, - user_rag_memory_id: str, - memory_config: MemoryConfig, - ): - """ - Initialize the tool execution node. - - Args: - tool: The tool function to execute - node_id: Identifier for this node (used in message IDs) - namespace: Namespace for session management - search_switch: Search routing parameter - apply_id: Application identifier - group_id: Group identifier - parameter_builder: Service for building tool-specific arguments - storage_type: Storage type for the workspace - user_rag_memory_id: User RAG memory identifier - memory_config: MemoryConfig object containing all configuration - """ - self.tool_node = ToolNode([tool]) - self.id = node_id - self.tool_name = tool.name if hasattr(tool, 'name') else str(tool) - self.namespace = namespace - self.search_switch = search_switch - self.apply_id = apply_id - self.group_id = group_id - self.parameter_builder = parameter_builder - self.storage_type = storage_type - self.user_rag_memory_id = user_rag_memory_id - self.memory_config = memory_config - - logger.info( - f"[ToolExecutionNode] Initialized node '{self.id}' for tool '{self.tool_name}'" - ) - - async def __call__(self, state: Dict[str, Any]) -> Dict[str, Any]: - """ - Execute the tool with transformed parameters. - - This method: - 1. Extracts the last message from state - 2. Extracts tool call ID using state extractors - 3. Extracts content payload using state extractors - 4. Builds tool arguments using parameter builder - 5. Constructs AIMessage with tool_calls - 6. Invokes the tool and returns the result - - Args: - state: LangGraph state dictionary - - Returns: - Updated state with tool result in messages - """ - messages = state.get("messages", []) - logger.debug( self.tool_name) - - if not messages: - logger.warning(f"[ToolExecutionNode] {self.id} - No messages in state") - return {"messages": [AIMessage(content="Error: No messages in state")]} - - last_message = messages[-1] - logger.debug( - f"[ToolExecutionNode] {self.id} - Processing message at {time.time()}" - ) - - try: - # Extract tool call ID using state extractors - tool_call_id = extract_tool_call_id(last_message) - logger.debug(f"[ToolExecutionNode] {self.id} - Extracted tool_call_id: {tool_call_id}") - - except ValueError as e: - logger.error( - f"[ToolExecutionNode] {self.id} - Failed to extract tool call ID: {e}" - ) - return {"messages": [AIMessage(content=f"Error: {str(e)}")]} - - try: - # Extract content payload using state extractors - content = extract_content_payload(last_message) - logger.debug( - f"[ToolExecutionNode] {self.id} - Extracted content type: {type(content)}, content_keys: {list(content.keys()) if isinstance(content, dict) else 'N/A'}" - ) - # Log raw message content for debugging - if hasattr(last_message, 'content'): - raw = last_message.content - logger.debug(f"[ToolExecutionNode] {self.id} - Raw message content (first 500 chars): {str(raw)[:500]}") - - except Exception as e: - logger.error( - f"[ToolExecutionNode] {self.id} - Failed to extract content: {e}", - exc_info=True - ) - content = {} - - try: - # Build tool arguments using parameter builder - tool_args = self.parameter_builder.build_tool_args( - tool_name=self.tool_name, - content=content, - tool_call_id=tool_call_id, - search_switch=self.search_switch, - apply_id=self.apply_id, - group_id=self.group_id, - memory_config=self.memory_config, - storage_type=self.storage_type, - user_rag_memory_id=self.user_rag_memory_id, - ) - logger.debug( - f"[ToolExecutionNode] {self.id} - Built tool args with keys: {list(tool_args.keys())}" - ) - - except Exception as e: - logger.error( - f"[ToolExecutionNode] {self.id} - Failed to build tool args: {e}", - exc_info=True - ) - return {"messages": [AIMessage(content=f"Error building arguments: {str(e)}")]} - - # Construct tool input message - tool_input = { - "messages": [ - AIMessage( - content="", - tool_calls=[{ - "name": self.tool_name, - "args": tool_args, - "id": f"{self.id}_{tool_call_id}", - }] - ) - ] - } - - try: - # Invoke the tool - result = await self.tool_node.ainvoke(tool_input) - - logger.debug( - f"[ToolExecutionNode] {self.id} - Tool execution completed" - ) - - # Check for error in tool response - error_entry = None - if result and "messages" in result: - for msg in result["messages"]: - if hasattr(msg, 'content'): - try: - import json - content = msg.content - if isinstance(content, str): - parsed = json.loads(content) - if isinstance(parsed, dict) and "error" in parsed: - error_msg = parsed["error"] - logger.warning( - f"[ToolExecutionNode] {self.id} - Tool returned error: {error_msg}" - ) - error_entry = {"tool": self.tool_name, "error": error_msg, "node_id": self.id} - except (json.JSONDecodeError, TypeError): - pass - - # Return result with error tracking if error was found - if error_entry: - result["errors"] = [error_entry] - - return result - - except Exception as e: - logger.error( - f"[ToolExecutionNode] {self.id} - Tool execution failed: {e}", - exc_info=True - ) - # Track error in state and return error message - from langchain_core.messages import ToolMessage - error_entry = {"tool": self.tool_name, "error": str(e), "node_id": self.id} - return { - "messages": [ - ToolMessage( - content=f"Error executing tool: {str(e)}", - tool_call_id=f"{self.id}_{tool_call_id}" - ) - ], - "errors": [error_entry] - } diff --git a/api/app/core/memory/agent/langgraph_graph/nodes/verification_nodes.py b/api/app/core/memory/agent/langgraph_graph/nodes/verification_nodes.py new file mode 100644 index 00000000..f3a39afb --- /dev/null +++ b/api/app/core/memory/agent/langgraph_graph/nodes/verification_nodes.py @@ -0,0 +1,85 @@ + +from app.core.logging_config import get_agent_logger +from app.db import get_db + +from app.core.memory.agent.models.verification_models import VerificationResult +from app.core.memory.agent.utils.llm_tools import ( + PROJECT_ROOT_, + ReadState, +) +from app.core.memory.agent.utils.redis_tool import store +from app.core.memory.agent.utils.session_tools import SessionService +from app.core.memory.agent.utils.template_tools import TemplateService +from app.core.memory.agent.services.optimized_llm_service import LLMServiceMixin + +template_root = PROJECT_ROOT_ + '/agent/utils/prompt' +db_session = next(get_db()) +logger = get_agent_logger(__name__) + +class VerificationNodeService(LLMServiceMixin): + """验证节点服务类""" + + def __init__(self): + super().__init__() + self.template_service = TemplateService(template_root) + +# 创建全局服务实例 +verification_service = VerificationNodeService() + +async def Verify_prompt(state: ReadState,messages_deal): + storage_type = state.get('storage_type', '') + user_rag_memory_id = state.get('user_rag_memory_id', '') + data = state.get('data', '') + Verify_result = { + "status": messages_deal.split_result, + "verified_data": messages_deal.expansion_issue, + "storage_type": storage_type, + "user_rag_memory_id": user_rag_memory_id, + "_intermediate": { + "type": "verification", + "title": "Data Verification", + "result": messages_deal.split_result, + "reason": messages_deal.reason, + "query": data, + "verified_count": len(messages_deal.expansion_issue), + "storage_type": storage_type, + "user_rag_memory_id": user_rag_memory_id + } + } + return Verify_result +async def Verify(state: ReadState): + content = state.get('data', '') + group_id = state.get('group_id', '') + memory_config = state.get('memory_config', None) + + history = await SessionService(store).get_history(group_id, group_id, group_id) + + retrieve = state.get("retrieve", '') + retrieve = retrieve.get("Expansion_issue", []) + messages = { + "Query": content, + "Expansion_issue": retrieve + } + + system_prompt = await verification_service.template_service.render_template( + template_name='split_verify_prompt.jinja2', + operation_name='split_verify_prompt', + history=history, + sentence=messages + ) + + # 使用优化的LLM服务 + structured = await verification_service.call_llm_structured( + state=state, + db_session=db_session, + system_prompt=system_prompt, + response_model=VerificationResult, + fallback_value={ + "split_result": "fail", + "expansion_issue": [], + "reason": "验证失败" + } + ) + + result = await Verify_prompt(state, structured) + return {"verify": result} \ No newline at end of file diff --git a/api/app/core/memory/agent/langgraph_graph/nodes/write_nodes.py b/api/app/core/memory/agent/langgraph_graph/nodes/write_nodes.py new file mode 100644 index 00000000..0202621e --- /dev/null +++ b/api/app/core/memory/agent/langgraph_graph/nodes/write_nodes.py @@ -0,0 +1,53 @@ + +from app.core.memory.agent.utils.llm_tools import WriteState +from app.core.memory.agent.utils.write_tools import write +from app.core.logging_config import get_agent_logger + +logger = get_agent_logger(__name__) +async def write_node(state: WriteState) -> WriteState: + """ + Write data to the database/file system. + + Args: + ctx: FastMCP context for dependency injection + content: Data content to write + user_id: User identifier + apply_id: Application identifier + group_id: Group identifier + memory_config: MemoryConfig object containing all configuration + + Returns: + dict: Contains 'status', 'saved_to', and 'data' fields + """ + content=state.get('data','') + group_id=state.get('group_id','') + memory_config=state.get('memory_config', '') + try: + result=await write( + content=content, + user_id=group_id, + apply_id=group_id, + group_id=group_id, + memory_config=memory_config, + ) + print('-----------') + print(result) + print('-----------') + logger.info(f"Write completed successfully! Config: {memory_config.config_name}") + + write_result= { + "status": "success", + "data": content, + "config_id": memory_config.config_id, + "config_name": memory_config.config_name, + } + return {"write_result":write_result} + + + except Exception as e: + logger.error(f"Data_write failed: {e}", exc_info=True) + write_result= { + "status": "error", + "message": str(e), + } + return {"write_result": write_result} diff --git a/api/app/core/memory/agent/langgraph_graph/read_graph.py b/api/app/core/memory/agent/langgraph_graph/read_graph.py index c29b5d86..19011a5f 100644 --- a/api/app/core/memory/agent/langgraph_graph/read_graph.py +++ b/api/app/core/memory/agent/langgraph_graph/read_graph.py @@ -1,469 +1,177 @@ -import json -import os -import re -import time -import warnings +#!/usr/bin/env python3 from contextlib import asynccontextmanager -from typing import Literal -from app.core.logging_config import get_agent_logger -from app.core.memory.agent.langgraph_graph.nodes import ( - ToolExecutionNode, - create_input_message, -) -from app.core.memory.agent.mcp_server.services.parameter_builder import ParameterBuilder -from app.core.memory.agent.utils.llm_tools import COUNTState, ReadState -from app.core.memory.agent.utils.multimodal import MultimodalProcessor -from app.schemas.memory_config_schema import MemoryConfig -from dotenv import load_dotenv -from langchain_core.messages import AIMessage -from langgraph.checkpoint.memory import InMemorySaver -from langgraph.constants import END, START +from langchain_core.messages import HumanMessage +from langgraph.constants import START, END from langgraph.graph import StateGraph -from langgraph.prebuilt import ToolNode - -logger = get_agent_logger(__name__) - -warnings.filterwarnings("ignore", category=RuntimeWarning) -load_dotenv() -redishost=os.getenv("REDISHOST") -redisport=os.getenv('REDISPORT') -redisdb=os.getenv('REDISDB') -redispassword=os.getenv('REDISPASSWORD') -counter = COUNTState(limit=3) - -# Update loop count in workflow -async def update_loop_count(state): - """Update loop counter""" - current_count = state.get("loop_count", 0) - return {"loop_count": current_count + 1} -def Verify_continue(state: ReadState) -> Literal["Summary", "Summary_fails", "content_input"]: - messages = state["messages"] +from app.db import get_db +from app.services.memory_config_service import MemoryConfigService - # Add boundary check - if not messages: - return END - counter.add(1) # Increment by 1 +from app.core.memory.agent.utils.llm_tools import ReadState +from app.core.memory.agent.langgraph_graph.nodes.data_nodes import content_input_node +from app.core.memory.agent.langgraph_graph.nodes.problem_nodes import ( + Split_The_Problem, + Problem_Extension, +) +from app.core.memory.agent.langgraph_graph.nodes.retrieve_nodes import ( + retrieve, +) +from app.core.memory.agent.langgraph_graph.nodes.summary_nodes import ( + Input_Summary, + Retrieve_Summary, + Summary_fails, + Summary, +) +from app.core.memory.agent.langgraph_graph.nodes.verification_nodes import Verify +from app.core.memory.agent.langgraph_graph.routing.routers import ( + Split_continue, + Retrieve_continue, + Verify_continue, +) - loop_count = counter.get_total() - logger.debug(f"[should_continue] Current loop count: {loop_count}") - - last_message = messages[-1] - last_message_str = str(last_message).replace('\\', '') - status_tools = re.findall(r'"split_result": "(.*?)"', last_message_str) - logger.debug(f"Status tools: {status_tools}") - - if "success" in status_tools: - counter.reset() - return "Summary" - elif "failed" in status_tools: - if loop_count < 2: # Maximum loop count is 3 - return "content_input" - else: - counter.reset() - return "Summary_fails" - else: - # Add default return value to avoid returning None - counter.reset() - return "Summary" # Default based on business requirements - - -def Retrieve_continue(state) -> Literal["Verify", "Retrieve_Summary"]: - """ - Determine routing based on search_switch value. - - Args: - state: State dictionary containing search_switch - - Returns: - Next node to execute - """ - # Direct dictionary access instead of regex parsing - search_switch = state.get("search_switch") - - # Handle case where search_switch might be in messages - if search_switch is None and "messages" in state: - messages = state.get("messages", []) - if messages: - last_message = messages[-1] - # Try to extract from tool_calls args - if hasattr(last_message, "tool_calls") and last_message.tool_calls: - for tool_call in last_message.tool_calls: - if isinstance(tool_call, dict) and "args" in tool_call: - search_switch = tool_call["args"].get("search_switch") - break - - # Convert to string for comparison if needed - if search_switch is not None: - search_switch = str(search_switch) - if search_switch == '0': - return 'Verify' - elif search_switch == '1': - return 'Retrieve_Summary' - - # Add default return value to avoid returning None - return 'Retrieve_Summary' # Default based on business logic - - -def Split_continue(state) -> Literal["Split_The_Problem", "Input_Summary"]: - """ - Determine routing based on search_switch value. - - Args: - state: State dictionary containing search_switch - - Returns: - Next node to execute - """ - logger.debug(f"Split_continue state: {state}") - - # Direct dictionary access instead of regex parsing - search_switch = state.get("search_switch") - - # Handle case where search_switch might be in messages - if search_switch is None and "messages" in state: - messages = state.get("messages", []) - if messages: - last_message = messages[-1] - # Try to extract from tool_calls args - if hasattr(last_message, "tool_calls") and last_message.tool_calls: - for tool_call in last_message.tool_calls: - if isinstance(tool_call, dict) and "args" in tool_call: - search_switch = tool_call["args"].get("search_switch") - break - - # Convert to string for comparison if needed - if search_switch is not None: - search_switch = str(search_switch) - if search_switch == '2': - return 'Input_Summary' - return 'Split_The_Problem' # Default case - - -class ProblemExtensionNode: - def __init__(self, tool, id, namespace, search_switch, apply_id, group_id, storage_type="", user_rag_memory_id=""): - self.tool_node = ToolNode([tool]) - self.id = id - self.tool_name = tool.name if hasattr(tool, 'name') else str(tool) - self.namespace = namespace - self.search_switch = search_switch - self.apply_id = apply_id - self.group_id = group_id - self.storage_type = storage_type - self.user_rag_memory_id = user_rag_memory_id - - async def __call__(self, state): - messages = state["messages"] - last_message = messages[-1] if messages else "" - logger.debug(f"ProblemExtensionNode {self.id} - Current time: {time.time()} - Message: {last_message}") - if self.tool_name == 'Input_Summary': - tool_call = re.findall("'id': '(.*?)'", str(last_message))[0] - else: - tool_call = str(re.findall(r"tool_call_id=.*?'(.*?)'", str(last_message))[0]).replace('\\', '').split('_id')[1] - - # Try to extract actual content payload from previous tool result - raw_msg = last_message.content if hasattr(last_message, 'content') else str(last_message) - extracted_payload = None - # Capture ToolMessage content field (supports single/double quotes), avoid greedy matching - m = re.search(r"content=(?:\"|\')(.*?)(?:\"|\'),\s*name=", raw_msg, flags=re.S) - if m: - extracted_payload = m.group(1) - else: - # Fallback: use raw string directly - extracted_payload = raw_msg - - # Try to parse content as JSON first - try: - content = json.loads(extracted_payload) - except Exception: - # Try to extract JSON fragment from text and parse - parsed = None - candidates = re.findall(r"[\[{].*[\]}]", extracted_payload, flags=re.S) - for cand in candidates: - try: - parsed = json.loads(cand) - break - except Exception: - continue - # If still fails, use raw string as content - content = parsed if parsed is not None else extracted_payload - - # Build correct parameters based on tool name - tool_args = {} - - if self.tool_name == "Verify": - # Verify tool requires context and usermessages parameters - if isinstance(content, dict): - tool_args["context"] = content - else: - tool_args["context"] = {"content": content} - tool_args["usermessages"] = str(tool_call) - tool_args["apply_id"] = str(self.apply_id) - tool_args["group_id"] = str(self.group_id) - elif self.tool_name == "Retrieve": - # Retrieve tool requires context and usermessages parameters - if isinstance(content, dict): - tool_args["context"] = content - else: - tool_args["context"] = {"content": content} - tool_args["usermessages"] = str(tool_call) - tool_args["search_switch"] = str(self.search_switch) - tool_args["apply_id"] = str(self.apply_id) - tool_args["group_id"] = str(self.group_id) - elif self.tool_name == "Summary": - # Summary tool requires string type context parameter - if isinstance(content, dict): - # Convert dict to JSON string - tool_args["context"] = json.dumps(content, ensure_ascii=False) - else: - tool_args["context"] = str(content) - tool_args["usermessages"] = str(tool_call) - tool_args["apply_id"] = str(self.apply_id) - tool_args["group_id"] = str(self.group_id) - elif self.tool_name == "Summary_fails": - # Summary_fails tool requires string type context parameter - if isinstance(content, dict): - # Convert dict to JSON string - tool_args["context"] = json.dumps(content, ensure_ascii=False) - else: - tool_args["context"] = str(content) - tool_args["usermessages"] = str(tool_call) - tool_args["apply_id"] = str(self.apply_id) - tool_args["group_id"] = str(self.group_id) - elif self.tool_name == 'Input_Summary': - tool_args["context"] = str(last_message) - tool_args["usermessages"] = str(tool_call) - tool_args["search_switch"] = str(self.search_switch) - tool_args["apply_id"] = str(self.apply_id) - tool_args["group_id"] = str(self.group_id) - tool_args["storage_type"] = getattr(self, 'storage_type', "") - tool_args["user_rag_memory_id"] = getattr(self, 'user_rag_memory_id', "") - elif self.tool_name == 'Retrieve_Summary': - # Retrieve_Summary expects dict directly, not JSON string - # content might be a JSON string, try to parse it - if isinstance(content, str): - try: - parsed_content = json.loads(content) - # Check if it has a "context" key - if isinstance(parsed_content, dict) and "context" in parsed_content: - tool_args["context"] = parsed_content["context"] - else: - tool_args["context"] = parsed_content - except json.JSONDecodeError: - # If parsing fails, wrap the string - tool_args["context"] = {"content": content} - elif isinstance(content, dict): - # Check if content has a "context" key that needs unwrapping - if "context" in content: - tool_args["context"] = content["context"] - else: - tool_args["context"] = content - else: - tool_args["context"] = {"content": str(content)} - - tool_args["usermessages"] = str(tool_call) - tool_args["apply_id"] = str(self.apply_id) - tool_args["group_id"] = str(self.group_id) - else: - # Other tools use context parameter - if isinstance(content, dict): - tool_args["context"] = content - else: - tool_args["context"] = {"content": content} - tool_args["usermessages"] = str(tool_call) - tool_args["apply_id"] = str(self.apply_id) - tool_args["group_id"] = str(self.group_id) - - - tool_input = { - "messages": [ - AIMessage( - content="", - tool_calls=[{ - "name": self.tool_name, - "args": tool_args, - "id": self.id + f"{tool_call}", - }] - ) - ] - } - result = await self.tool_node.ainvoke(tool_input) - result_text = str(result) - - return {"messages": [AIMessage(content=result_text)]} @asynccontextmanager -async def make_read_graph(namespace, tools, search_switch, apply_id, group_id, memory_config: MemoryConfig, storage_type=None, user_rag_memory_id=None): - """ - Create a read graph workflow for memory operations. - - Args: - namespace: Namespace identifier - tools: MCP tools loaded from session - search_switch: Search mode switch ("0", "1", or "2") - apply_id: Application identifier - group_id: Group identifier - memory_config: MemoryConfig object containing all configuration - storage_type: Storage type (optional) - user_rag_memory_id: User RAG memory ID (optional) - """ - memory = InMemorySaver() - tool = [i.name for i in tools] - logger.info(f"Initializing read graph with tools: {tool}") - logger.info(f"Using memory_config: {memory_config.config_name} (id={memory_config.config_id})") - - # Extract tool functions - Split_The_Problem_ = next((t for t in tools if t.name == "Split_The_Problem"), None) - Problem_Extension_ = next((t for t in tools if t.name == "Problem_Extension"), None) - Retrieve_ = next((t for t in tools if t.name == "Retrieve"), None) - Verify_ = next((t for t in tools if t.name == "Verify"), None) - Summary_ = next((t for t in tools if t.name == "Summary"), None) - Summary_fails_ = next((t for t in tools if t.name == "Summary_fails"), None) - Retrieve_Summary_ = next((t for t in tools if t.name == "Retrieve_Summary"), None) - Input_Summary_ = next((t for t in tools if t.name == "Input_Summary"), None) - - # Instantiate services - parameter_builder = ParameterBuilder() - multimodal_processor = MultimodalProcessor() - - # Create nodes using new modular components - Split_The_Problem_node = ToolNode([Split_The_Problem_]) - - Problem_Extension_node = ToolExecutionNode( - tool=Problem_Extension_, - node_id="Problem_Extension_id", - namespace=namespace, - search_switch=search_switch, - apply_id=apply_id, - group_id=group_id, - parameter_builder=parameter_builder, - storage_type=storage_type, - user_rag_memory_id=user_rag_memory_id, - memory_config=memory_config, +async def make_read_graph(): + """创建并返回 LangGraph 工作流""" + try: + # Build workflow graph + workflow = StateGraph(ReadState) + workflow.add_node("content_input", content_input_node) + workflow.add_node("Split_The_Problem", Split_The_Problem) + workflow.add_node("Problem_Extension", Problem_Extension) + workflow.add_node("Input_Summary", Input_Summary) + # workflow.add_node("Retrieve", retrieve_nodes) + workflow.add_node("Retrieve", retrieve) + workflow.add_node("Verify", Verify) + workflow.add_node("Retrieve_Summary", Retrieve_Summary) + workflow.add_node("Summary", Summary) + workflow.add_node("Summary_fails", Summary_fails) + + # 添加边 + workflow.add_edge(START, "content_input") + workflow.add_conditional_edges("content_input", Split_continue) + workflow.add_edge("Input_Summary", END) + workflow.add_edge("Split_The_Problem", "Problem_Extension") + workflow.add_edge("Problem_Extension", "Retrieve") + workflow.add_conditional_edges("Retrieve", Retrieve_continue) + workflow.add_edge("Retrieve_Summary", END) + workflow.add_conditional_edges("Verify", Verify_continue) + workflow.add_edge("Summary_fails", END) + workflow.add_edge("Summary", END) + + + '''-----''' + # workflow.add_edge("Retrieve", END) + + # 编译工作流 + graph = workflow.compile() + yield graph + + except Exception as e: + print(f"创建工作流失败: {e}") + raise + finally: + print("工作流创建完成") + +async def main(): + """主函数 - 运行工作流""" + message = "昨天有什么好看的电影" + group_id = '88a459f5_text09' # 组ID + storage_type = 'neo4j' # 存储类型 + search_switch = '1' # 搜索开关 + user_rag_memory_id = 'wwwwwwww' # 用户RAG记忆ID + + # 获取数据库会话 + db_session = next(get_db()) + config_service = MemoryConfigService(db_session) + memory_config = config_service.load_memory_config( + config_id=17, # 改为整数 + service_name="MemoryAgentService" ) + import time + start=time.time() + try: + async with make_read_graph() as graph: + config = {"configurable": {"thread_id": group_id}} + # 初始状态 - 包含所有必要字段 + initial_state = {"messages": [HumanMessage(content=message)] ,"search_switch":search_switch,"group_id":group_id + ,"storage_type":storage_type,"user_rag_memory_id":user_rag_memory_id,"memory_config":memory_config} + # 获取节点更新信息 + _intermediate_outputs = [] + summary = '' + + async for update_event in graph.astream( + initial_state, + stream_mode="updates", + config=config + ): + for node_name, node_data in update_event.items(): + print(f"处理节点: {node_name}") + + # 处理不同Summary节点的返回结构 + if 'Summary' in node_name: + if 'InputSummary' in node_data and 'summary_result' in node_data['InputSummary']: + summary = node_data['InputSummary']['summary_result'] + elif 'RetrieveSummary' in node_data and 'summary_result' in node_data['RetrieveSummary']: + summary = node_data['RetrieveSummary']['summary_result'] + elif 'summary' in node_data and 'summary_result' in node_data['summary']: + summary = node_data['summary']['summary_result'] + elif 'SummaryFails' in node_data and 'summary_result' in node_data['SummaryFails']: + summary = node_data['SummaryFails']['summary_result'] - Retrieve_node = ToolExecutionNode( - tool=Retrieve_, - node_id="Retrieve_id", - namespace=namespace, - search_switch=search_switch, - apply_id=apply_id, - group_id=group_id, - parameter_builder=parameter_builder, - storage_type=storage_type, - user_rag_memory_id=user_rag_memory_id, - memory_config=memory_config, - ) + spit_data = node_data.get('spit_data', {}).get('_intermediate', None) + if spit_data and spit_data != [] and spit_data != {}: + _intermediate_outputs.append(spit_data) + + # Problem_Extension 节点 + problem_extension = node_data.get('problem_extension', {}).get('_intermediate', None) + if problem_extension and problem_extension != [] and problem_extension != {}: + _intermediate_outputs.append(problem_extension) + + # Retrieve 节点 + retrieve_node = node_data.get('retrieve', {}).get('_intermediate_outputs', None) + if retrieve_node and retrieve_node != [] and retrieve_node != {}: + _intermediate_outputs.extend(retrieve_node) + + # Verify 节点 + verify_n = node_data.get('verify', {}).get('_intermediate', None) + if verify_n and verify_n != [] and verify_n != {}: + _intermediate_outputs.append(verify_n) - Verify_node = ToolExecutionNode( - tool=Verify_, - node_id="Verify_id", - namespace=namespace, - search_switch=search_switch, - apply_id=apply_id, - group_id=group_id, - parameter_builder=parameter_builder, - storage_type=storage_type, - user_rag_memory_id=user_rag_memory_id, - memory_config=memory_config, - ) - - Summary_node = ToolExecutionNode( - tool=Summary_, - node_id="Summary_id", - namespace=namespace, - search_switch=search_switch, - apply_id=apply_id, - group_id=group_id, - parameter_builder=parameter_builder, - storage_type=storage_type, - user_rag_memory_id=user_rag_memory_id, - memory_config=memory_config, - ) + + # Summary 节点 + summary_n = node_data.get('summary', {}).get('_intermediate', None) + if summary_n and summary_n != [] and summary_n != {}: + _intermediate_outputs.append(summary_n) - Summary_fails_node = ToolExecutionNode( - tool=Summary_fails_, - node_id="Summary_fails_id", - namespace=namespace, - search_switch=search_switch, - apply_id=apply_id, - group_id=group_id, - parameter_builder=parameter_builder, - storage_type=storage_type, - user_rag_memory_id=user_rag_memory_id, - memory_config=memory_config, - ) + # # 过滤掉空值 + # _intermediate_outputs = [item for item in _intermediate_outputs if item and item != [] and item != {}] + # + # # 优化搜索结果 + # print("=== 开始优化搜索结果 ===") + # optimized_outputs = merge_multiple_search_results(_intermediate_outputs) + # result=reorder_output_results(optimized_outputs) + # # 保存优化后的结果到文件 + # with open('_intermediate_outputs_optimized.json', 'w', encoding='utf-8') as f: + # import json + # f.write(json.dumps(result, indent=4, ensure_ascii=False)) + # + print(f"=== 最终摘要 ===") + print(summary) + + except Exception as e: + import traceback + traceback.print_exc() - Retrieve_Summary_node = ToolExecutionNode( - tool=Retrieve_Summary_, - node_id="Retrieve_Summary_id", - namespace=namespace, - search_switch=search_switch, - apply_id=apply_id, - group_id=group_id, - parameter_builder=parameter_builder, - storage_type=storage_type, - user_rag_memory_id=user_rag_memory_id, - memory_config=memory_config, - ) + end=time.time() + print(100*'y') + print(f"总耗时: {end-start}s") + print(100*'y') - Input_Summary_node = ToolExecutionNode( - tool=Input_Summary_, - node_id="Input_Summary_id", - namespace=namespace, - search_switch=search_switch, - apply_id=apply_id, - group_id=group_id, - parameter_builder=parameter_builder, - storage_type=storage_type, - user_rag_memory_id=user_rag_memory_id, - memory_config=memory_config, - ) - async def content_input_node(state): - state_search_switch = state.get("search_switch", search_switch) - - tool_name = "Input_Summary" if state_search_switch == '2' else "Split_The_Problem" - session_prefix = "input_summary_call_id" if state_search_switch == '2' else "split_call_id" - - return await create_input_message( - state=state, - tool_name=tool_name, - session_id=f"{session_prefix}_{namespace}", - search_switch=search_switch, - apply_id=apply_id, - group_id=group_id, - multimodal_processor=multimodal_processor, - memory_config=memory_config, - ) - - - # Build workflow graph - workflow = StateGraph(ReadState) - workflow.add_node("content_input", content_input_node) - workflow.add_node("Split_The_Problem", Split_The_Problem_node) - workflow.add_node("Problem_Extension", Problem_Extension_node) - workflow.add_node("Retrieve", Retrieve_node) - workflow.add_node("Verify", Verify_node) - workflow.add_node("Summary", Summary_node) - workflow.add_node("Summary_fails", Summary_fails_node) - workflow.add_node("Retrieve_Summary", Retrieve_Summary_node) - workflow.add_node("Input_Summary", Input_Summary_node) - - # Add edges using imported routers - workflow.add_edge(START, "content_input") - workflow.add_conditional_edges("content_input", Split_continue) - workflow.add_edge("Input_Summary", END) - workflow.add_edge("Split_The_Problem", "Problem_Extension") - workflow.add_edge("Problem_Extension", "Retrieve") - workflow.add_conditional_edges("Retrieve", Retrieve_continue) - workflow.add_edge("Retrieve_Summary", END) - workflow.add_conditional_edges("Verify", Verify_continue) - workflow.add_edge("Summary_fails", END) - workflow.add_edge("Summary", END) - - graph = workflow.compile(checkpointer=memory) - yield graph +if __name__ == "__main__": + import asyncio + asyncio.run(main()) diff --git a/api/app/core/memory/agent/langgraph_graph/routing/__init__.py b/api/app/core/memory/agent/langgraph_graph/routing/__init__.py deleted file mode 100644 index a9366bd0..00000000 --- a/api/app/core/memory/agent/langgraph_graph/routing/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -"""LangGraph routing logic.""" - -from app.core.memory.agent.langgraph_graph.routing.routers import ( - Verify_continue, - Retrieve_continue, - Split_continue, -) - -__all__ = [ - "Verify_continue", - "Retrieve_continue", - "Split_continue", -] diff --git a/api/app/core/memory/agent/langgraph_graph/routing/routers.py b/api/app/core/memory/agent/langgraph_graph/routing/routers.py index c8abd544..c0b01be1 100644 --- a/api/app/core/memory/agent/langgraph_graph/routing/routers.py +++ b/api/app/core/memory/agent/langgraph_graph/routing/routers.py @@ -1,123 +1,62 @@ -""" -Routing functions for LangGraph conditional edges. -This module provides routing functions that determine the next node to execute -based on state values. All functions return Literal types for type safety. -""" - -import logging -import re from typing import Literal -from app.core.memory.agent.langgraph_graph.state.extractors import extract_search_switch +from app.core.logging_config import get_agent_logger from app.core.memory.agent.utils.llm_tools import ReadState, COUNTState -logger = logging.getLogger(__name__) -# Global counter for Verify routing +logger = get_agent_logger(__name__) counter = COUNTState(limit=3) - - -def Verify_continue(state: ReadState) -> Literal["Summary", "Summary_fails", "content_input"]: +def Split_continue(state:ReadState) -> Literal["Split_The_Problem", "Input_Summary"]: """ - Determine routing after Verify node based on verification result. - - This function checks the verification result in the last message and routes to: - - Summary: if verification succeeded - - content_input: if verification failed and retry limit not reached - - Summary_fails: if verification failed and retry limit reached - + Determine routing based on search_switch value. + Args: - state: LangGraph state containing messages - + state: State dictionary containing search_switch + Returns: - Next node name as Literal type + Next node to execute """ - messages = state.get("messages", []) - - # Boundary check - if not messages: - logger.warning("[Verify_continue] No messages in state, defaulting to Summary") - counter.reset() - return "Summary" - - # Increment counter - counter.add(1) + logger.debug(f"Split_continue state: {state}") + search_switch = state.get('search_switch', '') + if search_switch is not None: + search_switch = str(search_switch) + if search_switch == '2': + return 'Input_Summary' + return 'Split_The_Problem' # 默认情况 + +def Retrieve_continue(state) -> Literal["Verify", "Retrieve_Summary"]: + """ + Determine routing based on search_switch value. + + Args: + state: State dictionary containing search_switch + + Returns: + Next node to execute + """ + search_switch = state.get('search_switch', '') + if search_switch is not None: + search_switch = str(search_switch) + if search_switch == '0': + return 'Verify' + elif search_switch == '1': + return 'Retrieve_Summary' + return 'Retrieve_Summary' # Default based on business logic +def Verify_continue(state: ReadState) -> Literal["Summary", "Summary_fails", "content_input"]: + status=state.get('verify', '')['status'] loop_count = counter.get_total() - logger.debug(f"[Verify_continue] Current loop count: {loop_count}") - - # Extract verification result from last message - last_message = messages[-1] - last_message_str = str(last_message).replace('\\', '') - status_tools = re.findall(r'"split_result": "(.*?)"', last_message_str) - logger.debug(f"[Verify_continue] Status tools: {status_tools}") - - # Route based on verification result - if "success" in status_tools: + print(status) + if "success" in status: counter.reset() return "Summary" - elif "failed" in status_tools: - if loop_count < 2: # Max retry count is 2 + elif "failed" in status: + if loop_count < 2: # Maximum loop count is 3 return "content_input" else: counter.reset() return "Summary_fails" - else: - # Default to Summary if status is unclear - counter.reset() - return "Summary" - - -def Retrieve_continue(state: dict) -> Literal["Verify", "Retrieve_Summary"]: - """ - Determine routing after Retrieve node based on search_switch value. - - This function routes based on the search_switch parameter: - - search_switch == '0': Route to Verify (verification needed) - - search_switch == '1': Route to Retrieve_Summary (direct summary) - - Args: - state: LangGraph state dictionary - - Returns: - Next node name as Literal type - """ - search_switch = extract_search_switch(state) - - logger.debug(f"[Retrieve_continue] search_switch: {search_switch}") - - if search_switch == '0': - return 'Verify' - elif search_switch == '1': - return 'Retrieve_Summary' - - # Default to Retrieve_Summary - logger.debug("[Retrieve_continue] No valid search_switch, defaulting to Retrieve_Summary") - return 'Retrieve_Summary' - - -def Split_continue(state: dict) -> Literal["Split_The_Problem", "Input_Summary"]: - """ - Determine routing after content_input node based on search_switch value. - - This function routes based on the search_switch parameter: - - search_switch == '2': Route to Input_Summary (direct input summary) - - Otherwise: Route to Split_The_Problem (problem decomposition) - - Args: - state: LangGraph state dictionary - - Returns: - Next node name as Literal type - """ - logger.debug(f"[Split_continue] state keys: {state.keys()}") - - search_switch = extract_search_switch(state) - - logger.debug(f"[Split_continue] search_switch: {search_switch}") - - if search_switch == '2': - return 'Input_Summary' - - # Default to Split_The_Problem - return 'Split_The_Problem' + # else: + # # Add default return value to avoid returning None + # counter.reset() + # return "Summary" # Default based on business requirements diff --git a/api/app/core/memory/agent/langgraph_graph/state/__init__.py b/api/app/core/memory/agent/langgraph_graph/state/__init__.py deleted file mode 100644 index 279c6463..00000000 --- a/api/app/core/memory/agent/langgraph_graph/state/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -"""LangGraph state management utilities.""" - -from app.core.memory.agent.langgraph_graph.state.extractors import ( - extract_search_switch, - extract_tool_call_id, - extract_content_payload, -) - -__all__ = [ - "extract_search_switch", - "extract_tool_call_id", - "extract_content_payload", -] diff --git a/api/app/core/memory/agent/langgraph_graph/state/extractors.py b/api/app/core/memory/agent/langgraph_graph/state/extractors.py deleted file mode 100644 index f5a32f5d..00000000 --- a/api/app/core/memory/agent/langgraph_graph/state/extractors.py +++ /dev/null @@ -1,179 +0,0 @@ -""" -State extraction utilities for type-safe access to LangGraph state values. - -This module provides utility functions for extracting values from LangGraph state -dictionaries with proper error handling and sensible defaults. -""" - -import json -import logging -from typing import Any, Optional - -logger = logging.getLogger(__name__) - -def extract_search_switch(state: dict) -> Optional[str]: - """ - Extract search_switch from state or messages. - """ - - search_switch = state.get("search_switch") - - if search_switch is not None: - return str(search_switch) - - # Try to extract from messages - messages = state.get("messages", []) - if not messages: - return None - - # 从最新的消息开始查找 - for message in reversed(messages): - # 尝试从 tool_calls 中提取 - if hasattr(message, "tool_calls") and message.tool_calls: - for tool_call in message.tool_calls: - if isinstance(tool_call, dict): - # 从 tool_call 的 args 中提取 - if "args" in tool_call and isinstance(tool_call["args"], dict): - search_switch = tool_call["args"].get("search_switch") - if search_switch is not None: - return str(search_switch) - # 直接从 tool_call 中提取 - search_switch = tool_call.get("search_switch") - if search_switch is not None: - return str(search_switch) - - # 尝试从 content 中提取(如果是 JSON 格式) - if hasattr(message, "content"): - try: - import json - if isinstance(message.content, str): - content_data = json.loads(message.content) - if isinstance(content_data, dict): - search_switch = content_data.get("search_switch") - if search_switch is not None: - return str(search_switch) - except (json.JSONDecodeError, ValueError): - pass - - return None - - -def extract_tool_call_id(message: Any) -> str: - """ - Extract tool call ID from message using structured attributes. - - This function extracts the tool call ID from a message object, handling both - direct attribute access and tool_calls list structures. - - Args: - message: Message object (typically ToolMessage or AIMessage) - - Returns: - Tool call ID as string - - Raises: - ValueError: If tool call ID cannot be extracted - - Examples: - >>> message = ToolMessage(content="...", tool_call_id="call_123") - >>> extract_tool_call_id(message) - 'call_123' - """ - # Try direct attribute access for ToolMessage - if hasattr(message, "tool_call_id"): - tool_call_id = message.tool_call_id - if tool_call_id: - return str(tool_call_id) - - # Try extracting from tool_calls list for AIMessage - if hasattr(message, "tool_calls") and message.tool_calls: - tool_call = message.tool_calls[0] - if isinstance(tool_call, dict) and "id" in tool_call: - return str(tool_call["id"]) - - # Try extracting from id attribute - if hasattr(message, "id"): - message_id = message.id - if message_id: - return str(message_id) - - # If all else fails, raise an error - raise ValueError(f"Could not extract tool call ID from message: {type(message)}") - - -def extract_content_payload(message: Any) -> Any: - """ - Extract content payload from ToolMessage, parsing JSON if needed. - - This function extracts the content from a message and attempts to parse it as JSON - if it appears to be a JSON string. It handles various message formats and provides - sensible fallbacks. - - Args: - message: Message object (typically ToolMessage) - - Returns: - Parsed content (dict, list, or str) - - Examples: - >>> message = ToolMessage(content='{"key": "value"}') - >>> extract_content_payload(message) - {'key': 'value'} - - >>> message = ToolMessage(content='plain text') - >>> extract_content_payload(message) - 'plain text' - """ - # Extract raw content - # For ToolMessages (responses from tools), extract from content - if hasattr(message, "content"): - raw_content = message.content - logger.info(f"extract_content_payload: raw_content type={type(raw_content)}, value={str(raw_content)[:500]}") - - # Handle MCP content format: [{'type': 'text', 'text': '...'}] - if isinstance(raw_content, list): - for block in raw_content: - if isinstance(block, dict) and block.get('type') == 'text': - raw_content = block.get('text', '') - logger.info(f"extract_content_payload: extracted text from MCP format: {str(raw_content)[:300]}") - break - - # If content is empty and this is an AIMessage with tool_calls, - # extract from args (this handles the initial tool call from content_input) - if not raw_content and hasattr(message, "tool_calls") and message.tool_calls: - tool_call = message.tool_calls[0] - if isinstance(tool_call, dict) and "args" in tool_call: - return tool_call["args"] - else: - raw_content = str(message) - - # If content is already a dict or list, return it directly - if isinstance(raw_content, (dict, list)): - logger.info(f"extract_content_payload: returning raw dict/list with keys={list(raw_content.keys()) if isinstance(raw_content, dict) else 'list'}") - return raw_content - - # Try to parse as JSON - if isinstance(raw_content, str): - # First, try direct JSON parsing - try: - parsed = json.loads(raw_content) - logger.info(f"extract_content_payload: parsed JSON, keys={list(parsed.keys()) if isinstance(parsed, dict) else 'list'}") - return parsed - except (json.JSONDecodeError, ValueError): - pass - - # If that fails, try to extract JSON from the string - # This handles cases where the content is embedded in a larger string - import re - json_candidates = re.findall(r'[\[{].*[\]}]', raw_content, flags=re.DOTALL) - for candidate in json_candidates: - try: - parsed = json.loads(candidate) - logger.info(f"extract_content_payload: parsed JSON from candidate, keys={list(parsed.keys()) if isinstance(parsed, dict) else 'list'}") - return parsed - except (json.JSONDecodeError, ValueError): - continue - - # If all parsing attempts fail, return the raw content - logger.info(f"extract_content_payload: returning raw content (parsing failed)") - return raw_content diff --git a/api/app/core/memory/agent/langgraph_graph/tools/tool.py b/api/app/core/memory/agent/langgraph_graph/tools/tool.py new file mode 100644 index 00000000..ce6d5dd4 --- /dev/null +++ b/api/app/core/memory/agent/langgraph_graph/tools/tool.py @@ -0,0 +1,320 @@ +import asyncio +import json +from datetime import datetime, timedelta + + +from langchain.tools import tool +from pydantic import BaseModel, Field + + +from app.core.memory.src.search import ( + search_by_temporal, + search_by_keyword_temporal, +) + +def extract_tool_message_content(response): + """从agent响应中提取ToolMessage内容和工具名称""" + messages = response.get('messages', []) + + for message in messages: + if hasattr(message, 'tool_call_id') and hasattr(message, 'content'): + # 这是一个ToolMessage + tool_content = message.content + tool_name = None + + # 尝试获取工具名称 + if hasattr(message, 'name'): + tool_name = message.name + elif hasattr(message, 'tool_name'): + tool_name = message.tool_name + + try: + # 解析JSON内容 + parsed_content = json.loads(tool_content) + return { + 'tool_name': tool_name, + 'content': parsed_content + } + except json.JSONDecodeError: + # 如果不是JSON格式,直接返回内容 + return { + 'tool_name': tool_name, + 'content': tool_content + } + + return None + + +class TimeRetrievalInput(BaseModel): + """时间检索工具的输入模式""" + context: str = Field(description="用户输入的查询内容") + group_id: str = Field(default="88a459f5_text09", description="组ID,用于过滤搜索结果") + +def create_time_retrieval_tool(group_id: str): + """ + 创建一个带有特定group_id的TimeRetrieval工具(同步版本),用于按时间范围搜索语句(Statements) + """ + + def clean_temporal_result_fields(data): + """ + 清理时间搜索结果中不需要的字段,并修改结构 + + Args: + data: 要清理的数据 + + Returns: + 清理后的数据 + """ + # 需要过滤的字段列表 + fields_to_remove = { + 'id', 'apply_id', 'user_id', 'chunk_id', 'created_at', + 'valid_at', 'invalid_at', 'statement_ids' + } + + if isinstance(data, dict): + cleaned = {} + for key, value in data.items(): + if key == 'statements' and isinstance(value, dict) and 'statements' in value: + # 将 statements: {"statements": [...]} 改为 time_search: {"statements": [...]} + cleaned_value = clean_temporal_result_fields(value) + # 进一步将内部的 statements 改为 time_search + if 'statements' in cleaned_value: + cleaned['results'] = { + 'time_search': cleaned_value['statements'] + } + else: + cleaned['results'] = cleaned_value + elif key not in fields_to_remove: + cleaned[key] = clean_temporal_result_fields(value) + return cleaned + elif isinstance(data, list): + return [clean_temporal_result_fields(item) for item in data] + else: + return data + + @tool + def TimeRetrievalWithGroupId(context: str, start_date: str = None, end_date: str = None, group_id_param: str = None, clean_output: bool = True) -> str: + """ + 优化的时间检索工具,只结合时间范围搜索(同步版本),自动过滤不需要的元数据字段 + 显式接收参数: + - context: 查询上下文内容 + - start_date: 开始时间(可选,格式:YYYY-MM-DD) + - end_date: 结束时间(可选,格式:YYYY-MM-DD) + - group_id_param: 组ID(可选,用于覆盖默认组ID) + - clean_output: 是否清理输出中的元数据字段 + -end_date 需要根据用户的描述获取结束的时间,输出格式用strftime("%Y-%m-%d") + """ + async def _async_search(): + # 使用传入的参数或默认值 + actual_group_id = group_id_param or group_id + actual_end_date = end_date or datetime.now().strftime("%Y-%m-%d") + actual_start_date = start_date or (datetime.now() - timedelta(days=7)).strftime("%Y-%m-%d") + + # 基本时间搜索 + results = await search_by_temporal( + group_id=actual_group_id, + start_date=actual_start_date, + end_date=actual_end_date, + limit=10 + ) + + # 清理结果中不需要的字段 + if clean_output: + cleaned_results = clean_temporal_result_fields(results) + else: + cleaned_results = results + + return json.dumps(cleaned_results, ensure_ascii=False, indent=2) + + return asyncio.run(_async_search()) + + @tool + def KeywordTimeRetrieval(context: str, days_back: int = 7, start_date: str = None, end_date: str = None, clean_output: bool = True) -> str: + """ + 优化的关键词时间检索工具,结合关键词和时间范围搜索(同步版本),自动过滤不需要的元数据字段 + 显式接收参数: + - context: 查询内容 + - days_back: 向前搜索的天数,默认7天 + - start_date: 开始时间(可选,格式:YYYY-MM-DD) + - end_date: 结束时间(可选,格式:YYYY-MM-DD) + - clean_output: 是否清理输出中的元数据字段 + - end_date 需要根据用户的描述获取结束的时间,输出格式用strftime("%Y-%m-%d") + """ + async def _async_search(): + actual_end_date = end_date or datetime.now().strftime("%Y-%m-%d") + actual_start_date = start_date or (datetime.now() - timedelta(days=days_back)).strftime("%Y-%m-%d") + + # 关键词时间搜索 + results = await search_by_keyword_temporal( + query_text=context, + group_id=group_id, + start_date=actual_start_date, + end_date=actual_end_date, + limit=15 + ) + + # 清理结果中不需要的字段 + if clean_output: + cleaned_results = clean_temporal_result_fields(results) + else: + cleaned_results = results + + return json.dumps(cleaned_results, ensure_ascii=False, indent=2) + + return asyncio.run(_async_search()) + + return TimeRetrievalWithGroupId + + +def create_hybrid_retrieval_tool_async(memory_config, **search_params): + """ + 创建混合检索工具,使用run_hybrid_search进行混合检索,优化输出格式并过滤不需要的字段 + + Args: + memory_config: 内存配置对象 + **search_params: 搜索参数,包含group_id, limit, include等 + """ + + def clean_result_fields(data): + """ + 递归清理结果中不需要的字段 + + Args: + data: 要清理的数据(可能是字典、列表或其他类型) + + Returns: + 清理后的数据 + """ + # 需要过滤的字段列表 + fields_to_remove = { + 'invalid_at', 'valid_at', 'chunk_id_from_rel', 'entity_ids', + 'expired_at', 'created_at', 'chunk_id', 'id', 'apply_id', + 'user_id', 'statement_ids', 'updated_at',"chunk_ids","fact_summary" + } + + if isinstance(data, dict): + # 对字典进行清理 + cleaned = {} + for key, value in data.items(): + if key not in fields_to_remove: + cleaned[key] = clean_result_fields(value) # 递归清理嵌套数据 + return cleaned + elif isinstance(data, list): + # 对列表中的每个元素进行清理 + return [clean_result_fields(item) for item in data] + else: + # 其他类型直接返回 + return data + + @tool + async def HybridSearch( + context: str, + search_type: str = "hybrid", + limit: int = 10, + group_id: str = None, + rerank_alpha: float = 0.6, + use_forgetting_rerank: bool = False, + use_llm_rerank: bool = False, + clean_output: bool = True # 新增:是否清理输出字段 + ) -> str: + """ + 优化的混合检索工具,支持关键词、向量和混合搜索,自动过滤不需要的元数据字段 + + Args: + context: 查询内容 + search_type: 搜索类型 ('keyword', 'embedding', 'hybrid') + limit: 结果数量限制 + group_id: 组ID,用于过滤搜索结果 + rerank_alpha: 重排序权重参数 + use_forgetting_rerank: 是否使用遗忘重排序 + use_llm_rerank: 是否使用LLM重排序 + clean_output: 是否清理输出中的元数据字段 + """ + try: + # 导入run_hybrid_search函数 + from app.core.memory.src.search import run_hybrid_search + + # 合并参数,优先使用传入的参数 + final_params = { + "query_text": context, + "search_type": search_type, + "group_id": group_id or search_params.get("group_id"), + "limit": limit or search_params.get("limit", 10), + "include": search_params.get("include", ["summaries", "statements", "chunks", "entities"]), + "output_path": None, # 不保存到文件 + "memory_config": memory_config, + "rerank_alpha": rerank_alpha, + "use_forgetting_rerank": use_forgetting_rerank, + "use_llm_rerank": use_llm_rerank + } + + # 执行混合检索 + raw_results = await run_hybrid_search(**final_params) + + # 清理结果中不需要的字段 + if clean_output: + cleaned_results = clean_result_fields(raw_results) + else: + cleaned_results = raw_results + + # 格式化返回结果 + formatted_results = { + "search_query": context, + "search_type": search_type, + "results": cleaned_results + } + + return json.dumps(formatted_results, ensure_ascii=False, indent=2, default=str) + + except Exception as e: + error_result = { + "error": f"混合检索失败: {str(e)}", + "search_query": context, + "search_type": search_type, + "timestamp": datetime.now().isoformat() + } + return json.dumps(error_result, ensure_ascii=False, indent=2) + + return HybridSearch + + +def create_hybrid_retrieval_tool_sync(memory_config, **search_params): + """ + 创建同步版本的混合检索工具,优化输出格式并过滤不需要的字段 + + Args: + memory_config: 内存配置对象 + **search_params: 搜索参数 + """ + @tool + def HybridSearchSync( + context: str, + search_type: str = "hybrid", + limit: int = 10, + group_id: str = None, + clean_output: bool = True + ) -> str: + """ + 优化的混合检索工具(同步版本),自动过滤不需要的元数据字段 + + Args: + context: 查询内容 + search_type: 搜索类型 ('keyword', 'embedding', 'hybrid') + limit: 结果数量限制 + group_id: 组ID,用于过滤搜索结果 + clean_output: 是否清理输出中的元数据字段 + """ + async def _async_search(): + # 创建异步工具并执行 + async_tool = create_hybrid_retrieval_tool_async(memory_config, **search_params) + return await async_tool.ainvoke({ + "context": context, + "search_type": search_type, + "limit": limit, + "group_id": group_id, + "clean_output": clean_output + }) + + return asyncio.run(_async_search()) + + return HybridSearchSync \ No newline at end of file diff --git a/api/app/core/memory/agent/langgraph_graph/write_graph.py b/api/app/core/memory/agent/langgraph_graph/write_graph.py index ae333e84..5a6f1e28 100644 --- a/api/app/core/memory/agent/langgraph_graph/write_graph.py +++ b/api/app/core/memory/agent/langgraph_graph/write_graph.py @@ -1,30 +1,32 @@ + import asyncio -import json import sys import warnings from contextlib import asynccontextmanager -from app.core.logging_config import get_agent_logger -from app.core.memory.agent.utils.llm_tools import WriteState -from app.schemas.memory_config_schema import MemoryConfig -from langchain_core.messages import AIMessage + +from langchain_core.messages import HumanMessage from langgraph.constants import END, START from langgraph.graph import StateGraph -from langgraph.prebuilt import ToolNode + + +from app.db import get_db +from app.core.logging_config import get_agent_logger +from app.core.memory.agent.utils.llm_tools import WriteState +from app.core.memory.agent.langgraph_graph.nodes.write_nodes import write_node +from app.core.memory.agent.langgraph_graph.nodes.data_nodes import content_input_write +from app.services.memory_config_service import MemoryConfigService warnings.filterwarnings("ignore", category=RuntimeWarning) - logger = get_agent_logger(__name__) if sys.platform.startswith("win"): asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy()) - - @asynccontextmanager -async def make_write_graph(user_id, tools, apply_id, group_id, memory_config: MemoryConfig): +async def make_write_graph(): """ Create a write graph workflow for memory operations. - + Args: user_id: User identifier tools: MCP tools loaded from session @@ -32,43 +34,8 @@ async def make_write_graph(user_id, tools, apply_id, group_id, memory_config: Me group_id: Group identifier memory_config: MemoryConfig object containing all configuration """ - logger.info("Loading MCP tools: %s", [t.name for t in tools]) - logger.info(f"Using memory_config: {memory_config.config_name} (id={memory_config.config_id})") - - data_write_tool = next((t for t in tools if t.name == "Data_write"), None) - - if not data_write_tool: - logger.error("Data_write tool not found", exc_info=True) - raise ValueError("Data_write tool not found") - - write_node = ToolNode([data_write_tool]) - - async def call_model(state): - messages = state["messages"] - last_message = messages[-1] - content = last_message[1] if isinstance(last_message, tuple) else last_message.content - - # Call Data_write directly with memory_config - write_params = { - "content": content, - "apply_id": apply_id, - "group_id": group_id, - "user_id": user_id, - "memory_config": memory_config, - } - logger.debug(f"Passing memory_config to Data_write: {memory_config.config_id}") - - write_result = await data_write_tool.ainvoke(write_params) - - if isinstance(write_result, dict): - result_content = write_result.get("data", str(write_result)) - else: - result_content = str(write_result) - logger.info("Write content: %s", result_content) - return {"messages": [AIMessage(content=result_content)]} - workflow = StateGraph(WriteState) - workflow.add_node("content_input", call_model) + workflow.add_node("content_input", content_input_write) workflow.add_node("save_neo4j", write_node) workflow.add_edge(START, "content_input") workflow.add_edge("content_input", "save_neo4j") @@ -76,5 +43,45 @@ async def make_write_graph(user_id, tools, apply_id, group_id, memory_config: Me graph = workflow.compile() - yield graph + + +async def main(): + """主函数 - 运行工作流""" + message = "今天周一" + group_id = 'new_2025test1103' # 组ID + + + # 获取数据库会话 + db_session = next(get_db()) + config_service = MemoryConfigService(db_session) + memory_config = config_service.load_memory_config( + config_id=17, # 改为整数 + service_name="MemoryAgentService" + ) + try: + async with make_write_graph() as graph: + config = {"configurable": {"thread_id": group_id}} + # 初始状态 - 包含所有必要字段 + initial_state = {"messages": [HumanMessage(content=message)], "group_id": group_id, "memory_config": memory_config} + + # 获取节点更新信息 + async for update_event in graph.astream( + initial_state, + stream_mode="updates", + config=config + ): + for node_name, node_data in update_event.items(): + if 'save_neo4j'==node_name: + massages=node_data + massages=massages.get('write_result')['status'] + print(massages) # | 更新数据: {node_data} + + except Exception as e: + import traceback + traceback.print_exc() + + +if __name__ == "__main__": + import asyncio + asyncio.run(main()) \ No newline at end of file diff --git a/api/app/core/memory/agent/mcp_server/__init__.py b/api/app/core/memory/agent/mcp_server/__init__.py deleted file mode 100644 index efd03773..00000000 --- a/api/app/core/memory/agent/mcp_server/__init__.py +++ /dev/null @@ -1,28 +0,0 @@ -""" -MCP Server package for memory agent. - -This package provides the FastMCP server implementation with context-based -dependency injection for tool functions. - -Package structure: -- server: FastMCP server initialization and context setup -- tools: MCP tool implementations -- models: Pydantic response models -- services: Business logic services -""" -# from app.core.memory.agent.mcp_server.server import ( -# mcp, -# initialize_context, -# main, -# get_context_resource -# ) - -# # Import tools to register them (but don't export them) -# from app.core.memory.agent.mcp_server import tools - -# __all__ = [ -# 'mcp', -# 'initialize_context', -# 'main', -# 'get_context_resource', -# ] \ No newline at end of file diff --git a/api/app/core/memory/agent/mcp_server/mcp_instance.py b/api/app/core/memory/agent/mcp_server/mcp_instance.py deleted file mode 100644 index 3a2eeb78..00000000 --- a/api/app/core/memory/agent/mcp_server/mcp_instance.py +++ /dev/null @@ -1,11 +0,0 @@ -""" -MCP Server Instance - -This module contains the FastMCP server instance that is shared across all modules. -It's in a separate file to avoid circular import issues. -""" -from mcp.server.fastmcp import FastMCP - -# Initialize FastMCP server instance -# This instance is shared across all tool modules -mcp = FastMCP('data_flow') diff --git a/api/app/core/memory/agent/mcp_server/server.py b/api/app/core/memory/agent/mcp_server/server.py deleted file mode 100644 index 26f24824..00000000 --- a/api/app/core/memory/agent/mcp_server/server.py +++ /dev/null @@ -1,159 +0,0 @@ -""" -MCP Server initialization with FastMCP context setup. - -This module initializes the FastMCP server and registers shared resources -in the context for dependency injection into tool functions. -""" -import os -import sys - -from app.core.config import settings -from app.core.logging_config import get_agent_logger -from app.core.memory.agent.mcp_server.mcp_instance import mcp -from app.core.memory.agent.mcp_server.services.search_service import SearchService -from app.core.memory.agent.mcp_server.services.session_service import SessionService -from app.core.memory.agent.mcp_server.services.template_service import TemplateService -from app.core.memory.agent.utils.llm_tools import PROJECT_ROOT_ -from app.core.memory.agent.utils.redis_tool import store - -logger = get_agent_logger(__name__) - - -def get_context_resource(ctx, resource_name: str): - """ - Helper function to retrieve a resource from the FastMCP context. - - Args: - ctx: FastMCP Context object (passed to tool functions) - resource_name: Name of the resource to retrieve - - Returns: - The requested resource - - Raises: - AttributeError: If the resource doesn't exist - - Example: - @mcp.tool() - async def my_tool(ctx: Context): - template_service = get_context_resource(ctx, 'template_service') - llm_client = get_context_resource(ctx, 'llm_client') - """ - if not hasattr(ctx, 'fastmcp') or ctx.fastmcp is None: - raise RuntimeError("Context does not have fastmcp attribute") - - if not hasattr(ctx.fastmcp, resource_name): - raise AttributeError( - f"Resource '{resource_name}' not found in context. " - f"Available resources: {[k for k in dir(ctx.fastmcp) if not k.startswith('_')]}" - ) - - return getattr(ctx.fastmcp, resource_name) - - -def initialize_context(): - """ - Initialize and register shared resources in FastMCP context. - - This function sets up all shared resources that will be available - to tool functions via dependency injection through the context parameter. - - Resources are stored as attributes on the FastMCP instance and can be - accessed via ctx.fastmcp in tool functions. - - Resources registered: - - session_store: RedisSessionStore for session management - - llm_client: LLM client for structured API calls - - app_settings: Application settings (renamed to avoid conflict with FastMCP settings) - - template_service: Service for template rendering - - search_service: Service for hybrid search - - session_service: Service for session operations - """ - try: - # Register Redis session store - logger.info("Registering session_store in context") - mcp.session_store = store - - # Note: LLM client is NOT loaded at server startup - # It should be loaded dynamically when needed, with config_id passed explicitly - # to make_write_graph or make_read_graph functions - logger.info("LLM client will be loaded dynamically with config_id when needed") - mcp.llm_client = None # Placeholder - actual client loaded per-request with config_id - - # Register application settings (renamed to avoid conflict with FastMCP's settings) - logger.info("Registering app_settings in context") - mcp.app_settings = settings - - # Register template service - template_root = PROJECT_ROOT_ + '/agent/utils/prompt' - # logger.info(f"Registering template_service in context with root: {template_root}") - template_service = TemplateService(template_root) - mcp.template_service = template_service - - # Register search service - # logger.info("Registering search_service in context") - search_service = SearchService() - mcp.search_service = search_service - - # Register session service - # logger.info("Registering session_service in context") - session_service = SessionService(store) - mcp.session_service = session_service - - # logger.info("All context resources registered successfully") - - except Exception as e: - logger.error(f"Failed to initialize context: {e}", exc_info=True) - raise - - -def main(): - """ - Main entry point for the MCP server. - - Initializes context and starts the server with SSE transport. - """ - try: - logger.info("Starting MCP server initialization") - # Initialize context resources - initialize_context() - - # Import and register tools (imports trigger tool registration) - from app.core.memory.agent.mcp_server.tools import ( # noqa: F401 - data_tools, - problem_tools, - retrieval_tools, - summary_tools, - verification_tools, - ) - - # Tools are registered via imports above - - # Get MCP port from environment (default: 8081) - mcp_port = int(os.getenv("MCP_PORT", "8081")) - logger.info(f"Starting MCP server on {settings.SERVER_IP}:{mcp_port} with SSE transport") - - # Configure DNS rebinding protection for Docker container compatibility - from mcp.server.fastmcp.server import TransportSecuritySettings - - # Disable DNS rebinding protection to allow Docker container hostnames - # This allows containers to connect using service names like 'mcp-server' - mcp.settings.transport_security = TransportSecuritySettings( - enable_dns_rebinding_protection=False, - ) - logger.info("DNS rebinding protection: disabled for Docker container compatibility") - - # logger.info(f"Starting MCP server on {settings.SERVER_IP}:{mcp_port} with SSE transport") - - # Run the server with SSE transport for HTTP connections - import uvicorn - app = mcp.sse_app() - uvicorn.run(app, host=settings.SERVER_IP, port=mcp_port, log_level="info") - - except Exception as e: - logger.error(f"Failed to start MCP server: {e}", exc_info=True) - sys.exit(1) - - -if __name__ == "__main__": - main() diff --git a/api/app/core/memory/agent/mcp_server/tools/__init__.py b/api/app/core/memory/agent/mcp_server/tools/__init__.py deleted file mode 100644 index 5ce04ef3..00000000 --- a/api/app/core/memory/agent/mcp_server/tools/__init__.py +++ /dev/null @@ -1,27 +0,0 @@ -""" -MCP Tools module. - -This module contains all MCP tool implementations organized by functionality. - -Tools are organized into the following modules: -- problem_tools: Question segmentation and extension -- retrieval_tools: Database and context retrieval -- verification_tools: Data verification -- summary_tools: Summarization and summary retrieval -- data_tools: Data type differentiation and writing -""" - -# Import all tool modules to register them with the MCP server -from . import problem_tools -from . import retrieval_tools -from . import verification_tools -from . import summary_tools -from . import data_tools - -__all__ = [ - 'problem_tools', - 'retrieval_tools', - 'verification_tools', - 'summary_tools', - 'data_tools', -] diff --git a/api/app/core/memory/agent/mcp_server/tools/data_tools.py b/api/app/core/memory/agent/mcp_server/tools/data_tools.py deleted file mode 100644 index 631f7fd7..00000000 --- a/api/app/core/memory/agent/mcp_server/tools/data_tools.py +++ /dev/null @@ -1,155 +0,0 @@ -""" -Data Tools for data type differentiation and writing. - -This module contains MCP tools for distinguishing data types and writing data. -""" - -import os - -from app.core.logging_config import get_agent_logger -from app.core.memory.agent.mcp_server.mcp_instance import mcp -from app.core.memory.agent.mcp_server.models.retrieval_models import ( - DistinguishTypeResponse, -) -from app.core.memory.agent.mcp_server.server import get_context_resource -from app.core.memory.agent.utils.write_tools import write -from app.core.memory.utils.llm.llm_utils import MemoryClientFactory -from app.db import get_db_context -from app.schemas.memory_config_schema import MemoryConfig -from mcp.server.fastmcp import Context - -logger = get_agent_logger(__name__) - - -@mcp.tool() -async def Data_type_differentiation( - ctx: Context, - context: str, - memory_config: MemoryConfig, -) -> dict: - """ - Distinguish the type of data (read or write). - - Args: - ctx: FastMCP context for dependency injection - context: Text to analyze for type differentiation - memory_config: MemoryConfig object containing LLM configuration - - Returns: - dict: Contains 'context' with the original text and 'type' field - """ - try: - # Extract services from context - template_service = get_context_resource(ctx, 'template_service') - - # Get LLM client from memory_config using factory pattern - with get_db_context() as db: - factory = MemoryClientFactory(db) - llm_client = factory.get_llm_client_from_config(memory_config) - - # Render template - try: - system_prompt = await template_service.render_template( - template_name='distinguish_types_prompt.jinja2', - operation_name='status_typle', - user_query=context - ) - except Exception as e: - logger.error( - f"Template rendering failed for Data_type_differentiation: {e}", - exc_info=True - ) - return { - "type": "error", - "message": f"Prompt rendering failed: {str(e)}" - } - - # Call LLM with structured response - try: - structured = await llm_client.response_structured( - messages=[{"role": "system", "content": system_prompt}], - response_model=DistinguishTypeResponse - ) - - result = structured.model_dump() - - # Add context to result - result["context"] = context - - return result - - except Exception as e: - logger.error( - f"LLM call failed for Data_type_differentiation: {e}", - exc_info=True - ) - return { - "context": context, - "type": "error", - "message": f"LLM call failed: {str(e)}" - } - - except Exception as e: - logger.error( - f"Data_type_differentiation failed: {e}", - exc_info=True - ) - return { - "context": context, - "type": "error", - "message": str(e) - } - - -@mcp.tool() -async def Data_write( - ctx: Context, - content: str, - user_id: str, - apply_id: str, - group_id: str, - memory_config: MemoryConfig, -) -> dict: - """ - Write data to the database/file system. - - Args: - ctx: FastMCP context for dependency injection - content: Data content to write - user_id: User identifier - apply_id: Application identifier - group_id: Group identifier - memory_config: MemoryConfig object containing all configuration - - Returns: - dict: Contains 'status', 'saved_to', and 'data' fields - """ - try: - # Ensure output directory exists - os.makedirs("data_output", exist_ok=True) - file_path = os.path.join("data_output", "user_data.csv") - - # Write data - clients are constructed inside write() from memory_config - await write( - content=content, - user_id=user_id, - apply_id=apply_id, - group_id=group_id, - memory_config=memory_config, - ) - logger.info(f"Write completed successfully! Config: {memory_config.config_name}") - - return { - "status": "success", - "saved_to": file_path, - "data": content, - "config_id": memory_config.config_id, - "config_name": memory_config.config_name, - } - - except Exception as e: - logger.error(f"Data_write failed: {e}", exc_info=True) - return { - "status": "error", - "message": str(e), - } diff --git a/api/app/core/memory/agent/mcp_server/tools/problem_tools.py b/api/app/core/memory/agent/mcp_server/tools/problem_tools.py deleted file mode 100644 index 49812e38..00000000 --- a/api/app/core/memory/agent/mcp_server/tools/problem_tools.py +++ /dev/null @@ -1,304 +0,0 @@ -""" -Problem Tools for question segmentation and extension. - -This module contains MCP tools for breaking down and extending user questions. -LLM clients are constructed from MemoryConfig when needed. -""" - -import json -import time - -from app.core.logging_config import get_agent_logger, log_time -from app.core.memory.agent.mcp_server.mcp_instance import mcp -from app.core.memory.agent.mcp_server.models.problem_models import ( - ProblemBreakdownResponse, - ProblemExtensionResponse, -) -from app.core.memory.agent.mcp_server.server import get_context_resource -from app.core.memory.agent.utils.messages_tool import Problem_Extension_messages_deal -from app.core.memory.utils.llm.llm_utils import MemoryClientFactory -from app.db import get_db_context -from app.schemas.memory_config_schema import MemoryConfig -from mcp.server.fastmcp import Context - -logger = get_agent_logger(__name__) - - -@mcp.tool() -async def Split_The_Problem( - ctx: Context, - sentence: str, - sessionid: str, - messages_id: str, - apply_id: str, - group_id: str, - memory_config: MemoryConfig, -) -> dict: - """ - Segment the dialogue or sentence into sub-problems. - - Args: - ctx: FastMCP context for dependency injection - sentence: Original sentence to split - sessionid: Session identifier - messages_id: Message identifier - apply_id: Application identifier - group_id: Group identifier - memory_config: MemoryConfig object containing all configuration - - Returns: - dict: Contains 'context' (JSON string of split results) and 'original' sentence - """ - start = time.time() - - try: - # Extract services from context - template_service = get_context_resource(ctx, "template_service") - session_service = get_context_resource(ctx, "session_service") - - # Get LLM client from memory_config - with get_db_context() as db: - factory = MemoryClientFactory(db) - llm_client = factory.get_llm_client_from_config(memory_config) - - # Extract user ID from session - user_id = session_service.resolve_user_id(sessionid) - - # Get conversation history - history = await session_service.get_history(user_id, apply_id, group_id) - # Override with empty list for now (as in original) - history = [] - - # Render template - try: - system_prompt = await template_service.render_template( - template_name='problem_breakdown_prompt.jinja2', - operation_name='split_the_problem', - history=history, - sentence=sentence - ) - except Exception as e: - logger.error( - f"Template rendering failed for Split_The_Problem: {e}", - exc_info=True - ) - return { - "context": json.dumps([], ensure_ascii=False), - "original": sentence, - "error": f"Prompt rendering failed: {str(e)}" - } - - # Call LLM with structured response - try: - structured = await llm_client.response_structured( - messages=[{"role": "system", "content": system_prompt}], - response_model=ProblemBreakdownResponse - ) - - # Handle RootModel response with .root attribute access - if structured is None: - # LLM returned None, use empty list as fallback - split_result = json.dumps([], ensure_ascii=False) - elif hasattr(structured, 'root') and structured.root is not None: - split_result = json.dumps( - [item.model_dump() for item in structured.root], - ensure_ascii=False - ) - elif isinstance(structured, list): - # Fallback: treat structured itself as the list - split_result = json.dumps( - [item.model_dump() for item in structured], - ensure_ascii=False - ) - else: - # Last resort: use empty list - split_result = json.dumps([], ensure_ascii=False) - - except Exception as e: - logger.error( - f"LLM call failed for Split_The_Problem: {e}", - exc_info=True - ) - split_result = json.dumps([], ensure_ascii=False) - - logger.info("Problem splitting") - logger.info(f"Problem split result: {split_result}") - - # Emit intermediate output for frontend - result = { - "context": split_result, - "original": sentence, - "_intermediate": { - "type": "problem_split", - "data": json.loads(split_result) if split_result else [], - "original_query": sentence - } - } - - return result - - except Exception as e: - logger.error( - f"Split_The_Problem failed: {e}", - exc_info=True - ) - return { - "context": json.dumps([], ensure_ascii=False), - "original": sentence, - "error": str(e) - } - - finally: - # Log execution time - end = time.time() - try: - duration = end - start - except Exception: - duration = 0.0 - log_time('Problem splitting', duration) - - -@mcp.tool() -async def Problem_Extension( - ctx: Context, - context: dict, - usermessages: str, - apply_id: str, - group_id: str, - memory_config: MemoryConfig, - storage_type: str = "", - user_rag_memory_id: str = "", -) -> dict: - """ - Extend the problem with additional sub-questions. - - Args: - ctx: FastMCP context for dependency injection - context: Dictionary containing split problem results - usermessages: User messages identifier - apply_id: Application identifier - group_id: Group identifier - memory_config: MemoryConfig object containing all configuration - storage_type: Storage type for the workspace (optional) - user_rag_memory_id: User RAG memory identifier (optional) - - Returns: - dict: Contains 'context' (aggregated questions) and 'original' question - """ - start = time.time() - - try: - # Extract services from context - template_service = get_context_resource(ctx, "template_service") - session_service = get_context_resource(ctx, "session_service") - - # Get LLM client from memory_config - with get_db_context() as db: - factory = MemoryClientFactory(db) - llm_client = factory.get_llm_client_from_config(memory_config) - - # Resolve session ID from usermessages - from app.core.memory.agent.utils.messages_tool import Resolve_username - sessionid = Resolve_username(usermessages) - - # Get conversation history - history = await session_service.get_history(sessionid, apply_id, group_id) - # Override with empty list for now (as in original) - history = [] - - # Process context to extract questions - extent_quest, original = await Problem_Extension_messages_deal(context) - - # Format questions for template rendering - questions_formatted = [] - for msg in extent_quest: - if msg.get("role") == "user": - questions_formatted.append(msg.get("content", "")) - - # Render template - try: - system_prompt = await template_service.render_template( - template_name='Problem_Extension_prompt.jinja2', - operation_name='problem_extension', - history=history, - questions=questions_formatted - ) - except Exception as e: - logger.error( - f"Template rendering failed for Problem_Extension: {e}", - exc_info=True - ) - return { - "context": {}, - "original": original, - "error": f"Prompt rendering failed: {str(e)}" - } - - # Call LLM with structured response - try: - response_content = await llm_client.response_structured( - messages=[{"role": "system", "content": system_prompt}], - response_model=ProblemExtensionResponse - ) - - # Aggregate results by original question - aggregated_dict = {} - for item in response_content.root: - key = getattr(item, "original_question", None) or ( - item.get("original_question") if isinstance(item, dict) else None - ) - value = getattr(item, "extended_question", None) or ( - item.get("extended_question") if isinstance(item, dict) else None - ) - if not key or not value: - continue - aggregated_dict.setdefault(key, []).append(value) - - except Exception as e: - logger.error( - f"LLM call failed for Problem_Extension: {e}", - exc_info=True - ) - aggregated_dict = {} - - logger.info("Problem extension") - logger.info(f"Problem extension result: {aggregated_dict}") - - # Emit intermediate output for frontend - result = { - "context": aggregated_dict, - "original": original, - "storage_type": storage_type, - "user_rag_memory_id": user_rag_memory_id, - "_intermediate": { - "type": "problem_extension", - "data": aggregated_dict, - "original_query": original, - "storage_type": storage_type, - "user_rag_memory_id": user_rag_memory_id - } - } - - return result - - except Exception as e: - logger.error( - f"Problem_Extension failed: {e}", - exc_info=True - ) - return { - "context": {}, - "original": context.get("original", ""), - "storage_type": storage_type, - "user_rag_memory_id": user_rag_memory_id, - "error": str(e) - } - - finally: - # Log execution time - end = time.time() - try: - duration = end - start - except Exception: - duration = 0.0 - log_time('Problem extension', duration) diff --git a/api/app/core/memory/agent/mcp_server/tools/retrieval_tools.py b/api/app/core/memory/agent/mcp_server/tools/retrieval_tools.py deleted file mode 100644 index db18ba04..00000000 --- a/api/app/core/memory/agent/mcp_server/tools/retrieval_tools.py +++ /dev/null @@ -1,294 +0,0 @@ -""" -Retrieval Tools for database and context retrieval. - -This module contains MCP tools for retrieving data using hybrid search. -""" - -import os -import time - -from app.core.logging_config import get_agent_logger, log_time -from app.core.memory.agent.mcp_server.mcp_instance import mcp -from app.core.memory.agent.mcp_server.server import get_context_resource -from app.core.memory.agent.utils.llm_tools import ( - deduplicate_entries, - merge_to_key_value_pairs, -) -from app.core.memory.agent.utils.messages_tool import Retriev_messages_deal -from app.core.rag.nlp.search import knowledge_retrieval -from app.schemas.memory_config_schema import MemoryConfig -from dotenv import load_dotenv -from mcp.server.fastmcp import Context - -load_dotenv() -logger = get_agent_logger(__name__) - - -@mcp.tool() -async def Retrieve( - ctx: Context, - context, - usermessages: str, - apply_id: str, - group_id: str, - memory_config: MemoryConfig, - storage_type: str = "", - user_rag_memory_id: str = "", -) -> dict: - """ - Retrieve data from the database using hybrid search. - - Args: - ctx: FastMCP context for dependency injection - context: Dictionary or string containing query information - usermessages: User messages identifier - apply_id: Application identifier - group_id: Group identifier - memory_config: MemoryConfig object containing all configuration - storage_type: Storage type for the workspace (e.g., 'rag', 'vector') - user_rag_memory_id: User RAG memory identifier - - Returns: - dict: Contains 'context' with Query and Expansion_issue results - """ - kb_config = { - "knowledge_bases": [ - { - "kb_id": user_rag_memory_id, - "similarity_threshold": 0.7, - "vector_similarity_weight": 0.5, - "top_k": 10, - "retrieve_type": "participle" - } - ], - "merge_strategy": "weight", - "reranker_id": os.getenv('reranker_id'), - "reranker_top_k": 10 - } - start = time.time() - logger.info(f"Retrieve: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}") - logger.info(f"Retrieve: context type={type(context)}, context={str(context)[:500]}") - - try: - # Extract services from context - search_service = get_context_resource(ctx, 'search_service') - - databases_anser = [] - - # Handle both dict and string context - if isinstance(context, dict): - # Process dict context with extended questions - all_items = [] - logger.info(f"Retrieve: context keys={list(context.keys())}") - content, original = await Retriev_messages_deal(context) - logger.info(f"Retrieve: after Retriev_messages_deal - content_type={type(content)}, content={str(content)[:300]}") - logger.info(f"Retrieve: original='{original[:100] if original else 'EMPTY'}'") - - if not original: - logger.warning(f"Retrieve: original query is empty! context={context}") - - # Extract all query items from content - # content is like {original_question: [extended_questions...], ...} - for key, values in content.items(): - if isinstance(values, list): - all_items.extend(values) - elif isinstance(values, str): - all_items.append(values) - elif values is not None: - # Fallback: convert non-empty non-list values to string - all_items.append(str(values)) - - # Execute search for each question - for idx, question in enumerate(all_items): - try: - # Prepare search parameters based on storage type - search_params = { - "group_id": group_id, - "question": question, - "return_raw_results": True - } - - # Add storage-specific parameters - if storage_type == "rag" and user_rag_memory_id: - retrieve_chunks_result = knowledge_retrieval(question, kb_config,[str(group_id)]) - try: - retrieval_knowledge = [i.page_content for i in retrieve_chunks_result] - clean_content = '\n\n'.join(retrieval_knowledge) - cleaned_query=question - raw_results=clean_content - logger.info(f" Using RAG storage with memory_id={user_rag_memory_id}") - except: - clean_content = '' - raw_results='' - cleaned_query = question - logger.info(f"No content retrieved from knowledge base: {user_rag_memory_id}") - else: - clean_content, cleaned_query, raw_results = await search_service.execute_hybrid_search( - **search_params, memory_config=memory_config - ) - - databases_anser.append({ - "Query_small": cleaned_query, - "Result_small": clean_content, - "_intermediate": { - "type": "search_result", - "query": cleaned_query, - "raw_results": raw_results, - "index": idx + 1, - "total": len(all_items) - } - }) - except Exception as e: - logger.error( - f"Retrieve: hybrid_search failed for question '{question}': {e}", - exc_info=True - ) - # Continue with empty result for this question - databases_anser.append({ - "Query_small": question, - "Result_small": "" - }) - - # Build initial database data structure - databases_data = { - "Query": original, - "Expansion_issue": databases_anser - } - - # Collect intermediate outputs before deduplication - intermediate_outputs = [] - for item in databases_anser: - if '_intermediate' in item: - intermediate_outputs.append(item['_intermediate']) - - # Deduplicate and merge results - deduplicated_data = deduplicate_entries(databases_data['Expansion_issue']) - deduplicated_data_merged = merge_to_key_value_pairs( - deduplicated_data, - 'Query_small', - 'Result_small' - ) - - # Restructure for Verify/Retrieve_Summary compatibility - keys, val = [], [] - for item in deduplicated_data_merged: - for items_key, items_value in item.items(): - keys.append(items_key) - val.append(items_value) - - send_verify = [] - for i, j in zip(keys, val, strict=False): - send_verify.append({ - "Query_small": i, - "Answer_Small": j - }) - - dup_databases = { - "Query": original, - "Expansion_issue": send_verify, - "_intermediate_outputs": intermediate_outputs # Preserve intermediate outputs - } - - logger.info(f"Collected {len(intermediate_outputs)} intermediate outputs from search results") - - else: - # Handle string context (simple query) - query = str(context).strip() - - try: - # Prepare search parameters based on storage type - search_params = { - "group_id": group_id, - "question": query, - "return_raw_results": True - } - - # Add storage-specific parameters - if storage_type == "rag" and user_rag_memory_id: - retrieve_chunks_result = knowledge_retrieval(query, kb_config,[str(group_id)]) - try: - retrieval_knowledge = [i.page_content for i in retrieve_chunks_result] - clean_content = '\n\n'.join(retrieval_knowledge) - cleaned_query = query - raw_results = clean_content - logger.info(f" Using RAG storage with memory_id={user_rag_memory_id}") - except: - clean_content = '' - raw_results = '' - cleaned_query = query - logger.info(f"No content retrieved from knowledge base: {user_rag_memory_id}") - else: - clean_content, cleaned_query, raw_results = await search_service.execute_hybrid_search( - **search_params, memory_config=memory_config - ) - # Keep structure for Verify/Retrieve_Summary compatibility - dup_databases = { - "Query": cleaned_query, - "Expansion_issue": [{ - "Query_small": cleaned_query, - "Answer_Small": clean_content, - "_intermediate": { - "type": "search_result", - "query": cleaned_query, - "raw_results": raw_results, - "index": 1, - "total": 1 - } - }] - } - except Exception as e: - logger.error( - f"Retrieve: hybrid_search failed for query '{query}': {e}", - exc_info=True - ) - # Return empty results on failure - dup_databases = { - "Query": query, - "Expansion_issue": [] - } - - logger.info( - f"Retrieval: {storage_type}--{user_rag_memory_id}--Query={dup_databases.get('Query', '')}, " - f"Expansion_issue count={len(dup_databases.get('Expansion_issue', []))}" - ) - - # Build result with intermediate outputs - result = { - "context": dup_databases, - "storage_type": storage_type, - "user_rag_memory_id": user_rag_memory_id - } - - # Add intermediate outputs list if they exist - intermediate_outputs = dup_databases.get('_intermediate_outputs', []) - if intermediate_outputs: - result['_intermediates'] = intermediate_outputs - logger.info(f"Adding {len(intermediate_outputs)} intermediate outputs to result") - else: - logger.warning("No intermediate outputs found in dup_databases") - - return result - - except Exception as e: - logger.error( - f"Retrieve failed: {e}", - exc_info=True - ) - return { - "context": { - "Query": "", - "Expansion_issue": [] - }, - "storage_type": storage_type, - "user_rag_memory_id": user_rag_memory_id, - "error": str(e) - } - - finally: - # Log execution time - end = time.time() - try: - duration = end - start - except Exception: - duration = 0.0 - log_time('Retrieval', duration) diff --git a/api/app/core/memory/agent/mcp_server/tools/summary_tools.py b/api/app/core/memory/agent/mcp_server/tools/summary_tools.py deleted file mode 100644 index 0f306572..00000000 --- a/api/app/core/memory/agent/mcp_server/tools/summary_tools.py +++ /dev/null @@ -1,640 +0,0 @@ -""" -Summary Tools for data summarization. - -This module contains MCP tools for summarizing retrieved data and generating responses. -LLM clients are constructed from MemoryConfig when needed. -""" - -import json -import os -import re -import time - -from app.core.logging_config import get_agent_logger, log_time -from app.core.memory.agent.mcp_server.mcp_instance import mcp -from app.core.memory.agent.mcp_server.models.summary_models import ( - RetrieveSummaryResponse, - SummaryResponse, -) -from app.core.memory.agent.mcp_server.server import get_context_resource -from app.core.memory.agent.utils.messages_tool import ( - Resolve_username, - Summary_messages_deal, -) -from app.core.memory.utils.llm.llm_utils import MemoryClientFactory -from app.core.rag.nlp.search import knowledge_retrieval -from app.db import get_db_context -from app.schemas.memory_config_schema import MemoryConfig -from dotenv import load_dotenv -from mcp.server.fastmcp import Context - -load_dotenv() -logger = get_agent_logger(__name__) - - -@mcp.tool() -async def Summary( - ctx: Context, - context: str, - usermessages: str, - apply_id: str, - group_id: str, - memory_config: MemoryConfig, - storage_type: str = "", - user_rag_memory_id: str = "", -) -> dict: - """ - Summarize the verified data. - - Args: - ctx: FastMCP context for dependency injection - context: JSON string containing verified data - usermessages: User messages identifier - apply_id: Application identifier - group_id: Group identifier - memory_config: MemoryConfig object containing all configuration - storage_type: Storage type for the workspace (optional) - user_rag_memory_id: User RAG memory identifier (optional) - - Returns: - dict: Contains 'status' and 'summary_result' - """ - start = time.time() - - try: - # Extract services from context - template_service = get_context_resource(ctx, "template_service") - session_service = get_context_resource(ctx, "session_service") - - # Get LLM client from memory_config - with get_db_context() as db: - factory = MemoryClientFactory(db) - llm_client = factory.get_llm_client_from_config(memory_config) - - # Resolve session ID - sessionid = Resolve_username(usermessages) - - # Process context to extract answer and query - answer_small, query = await Summary_messages_deal(context) - - - start_time= time.time() - history = await session_service.get_history(sessionid, apply_id, group_id) - end_time=time.time() - logger.info(f"Retrieve_Summary-REDIS搜索:{end_time - start_time}") - data = { - "query": query, - "history": history, - "retrieve_info": answer_small - } - - except Exception as e: - logger.error( - f"Summary: initialization failed: {e}", - exc_info=True - ) - return { - "status": "error", - "summary_result": "信息不足,无法回答" - } - - try: - # Render template - system_prompt = await template_service.render_template( - template_name='summary_prompt.jinja2', - operation_name='summary', - data=data, - query=query - ) - except Exception as e: - logger.error( - f"Template rendering failed for Summary: {e}", - exc_info=True - ) - return { - "status": "error", - "message": f"Prompt rendering failed: {str(e)}" - } - - try: - # Call LLM with structured response - structured = await llm_client.response_structured( - messages=[{"role": "system", "content": system_prompt}], - response_model=SummaryResponse - ) - - aimessages = structured.query_answer or "" - - except Exception as e: - logger.error( - f"LLM call failed for Summary: {e}", - exc_info=True - ) - aimessages = "" - - try: - # Save session - if aimessages != "": - await session_service.save_session( - user_id=sessionid, - query=query, - apply_id=apply_id, - group_id=group_id, - ai_response=aimessages - ) - logger.info(f"sessionid: {aimessages} 写入成功") - except Exception as e: - logger.error( - f"sessionid: {sessionid} 写入失败,错误信息:{str(e)}", - exc_info=True - ) - return { - "status": "error", - "message": str(e) - } - - # Cleanup duplicate sessions - await session_service.cleanup_duplicates() - - # Use fallback if empty - if aimessages == '': - aimessages = '信息不足,无法回答' - - logger.info(f"Summary after verification: {aimessages}") - - # Log execution time - end = time.time() - try: - duration = end - start - except Exception: - duration = 0.0 - log_time('Summary', duration) - - return { - "status": "success", - "summary_result": aimessages, - "storage_type": storage_type, - "user_rag_memory_id": user_rag_memory_id - } - - -@mcp.tool() -async def Retrieve_Summary( - ctx: Context, - context: dict, - usermessages: str, - apply_id: str, - group_id: str, - memory_config: MemoryConfig, - storage_type: str = "", - user_rag_memory_id: str = "", -) -> dict: - """ - Summarize data directly from retrieval results. - - Args: - ctx: FastMCP context for dependency injection - context: Dictionary containing Query and Expansion_issue from Retrieve - usermessages: User messages identifier - apply_id: Application identifier - group_id: Group identifier - memory_config: MemoryConfig object containing all configuration - storage_type: Storage type for the workspace (optional) - user_rag_memory_id: User RAG memory identifier (optional) - - Returns: - dict: Contains 'status' and 'summary_result' - """ - start = time.time() - - try: - # Extract services from context - template_service = get_context_resource(ctx, "template_service") - session_service = get_context_resource(ctx, "session_service") - - # Get LLM client from memory_config - with get_db_context() as db: - factory = MemoryClientFactory(db) - llm_client = factory.get_llm_client_from_config(memory_config) - - # Resolve session ID - sessionid = Resolve_username(usermessages) - - - - # Handle both 'content' and 'context' keys (LangGraph uses 'content') - logger.debug(f"Retrieve_Summary: raw context type={type(context)}, keys={list(context.keys()) if isinstance(context, dict) else 'N/A'}") - - if isinstance(context, dict): - if "content" in context: - inner = context["content"] - # If it's a JSON string, parse it - if isinstance(inner, str): - try: - parsed = json.loads(inner) - logger.info("Retrieve_Summary: successfully parsed JSON") - except json.JSONDecodeError: - # Try unescaping first - try: - unescaped = inner.encode('utf-8').decode('unicode_escape') - parsed = json.loads(unescaped) - logger.info("Retrieve_Summary: parsed after unescaping") - except (json.JSONDecodeError, UnicodeDecodeError) as e: - logger.error( - f"Retrieve_Summary: parsing failed even after unescape: {e}" - ) - context_dict = {"Query": "", "Expansion_issue": []} - parsed = None - - if parsed: - # Check if parsed has 'context' wrapper - if isinstance(parsed, dict) and "context" in parsed: - context_dict = parsed["context"] - else: - context_dict = parsed - elif isinstance(inner, dict): - context_dict = inner - else: - context_dict = {"Query": "", "Expansion_issue": []} - elif "context" in context: - context_dict = context["context"] if isinstance(context["context"], dict) else context - else: - context_dict = context - else: - context_dict = {"Query": "", "Expansion_issue": []} - - query = context_dict.get("Query", "") - expansion_issue = context_dict.get("Expansion_issue", []) - - logger.debug(f"Retrieve_Summary: query='{query}', expansion_issue count={len(expansion_issue)}") - logger.debug(f"Retrieve_Summary: expansion_issue={expansion_issue[:2] if expansion_issue else 'empty'}") - - # Extract retrieve_info from expansion_issue - retrieve_info = [] - for item in expansion_issue: - # Check for both Answer_Small and Answer_Small (typo) for backward compatibility - answer = None - if isinstance(item, dict): - if "Answer_Small" in item: - answer = item["Answer_Small"] - - - if answer is not None: - # Handle both string and list formats - if isinstance(answer, list): - # Join list of characters/strings into a single string - retrieve_info.append(''.join(str(x) for x in answer)) - elif isinstance(answer, str): - retrieve_info.append(answer) - else: - retrieve_info.append(str(answer)) - - # Join all retrieve_info into a single string - retrieve_info_str = '\n\n'.join(retrieve_info) if retrieve_info else "" - - start_time=time.time() - history = await session_service.get_history(sessionid, apply_id, group_id) - # Override with empty list for now (as in original) - end_time=time.time() - logger.info(f"Retrieve_Summary-REDIS搜索:{end_time - start_time}") - except Exception as e: - logger.error( - f"Retrieve_Summary: initialization failed: {e}", - exc_info=True - ) - return { - "status": "error", - "summary_result": "信息不足,无法回答" - } - - try: - # Render template - system_prompt = await template_service.render_template( - template_name='Retrieve_Summary_prompt.jinja2', - operation_name='retrieve_summary', - query=query, - history=history, - retrieve_info=retrieve_info_str - ) - except Exception as e: - logger.error( - f"Template rendering failed for Retrieve_Summary: {e}", - exc_info=True - ) - return { - "status": "error", - "message": f"Prompt rendering failed: {str(e)}" - } - - try: - # Call LLM with structured response - structured = await llm_client.response_structured( - messages=[{"role": "system", "content": system_prompt}], - response_model=RetrieveSummaryResponse - ) - - # Handle case where structured response might be None or incomplete - if structured and hasattr(structured, 'data') and structured.data: - aimessages = structured.data.query_answer or "" - else: - logger.warning("Structured response is None or incomplete, using default message") - aimessages = "信息不足,无法回答" - - - # Check for insufficient information response - if '信息不足,无法回答' not in str(aimessages) or str(aimessages)!="": - # Save session - await session_service.save_session( - user_id=sessionid, - query=query, - apply_id=apply_id, - group_id=group_id, - ai_response=aimessages - ) - logger.info(f"sessionid: {aimessages} 写入成功") - except Exception as e: - logger.error( - f"Retrieve_Summary: LLM call failed: {e}", - exc_info=True - ) - aimessages = "" - # Cleanup duplicate sessions - await session_service.cleanup_duplicates() - - # Use fallback if empty - if aimessages == '': - aimessages = '信息不足,无法回答' - - logger.info(f"Summary after retrieval: {aimessages}") - - # Log execution time - end = time.time() - try: - duration = end - start - except Exception: - duration = 0.0 - log_time('Retrieval summary', duration) - - # Emit intermediate output for frontend - return { - "status": "success", - "summary_result": aimessages, - "storage_type": storage_type, - "user_rag_memory_id": user_rag_memory_id, - "_intermediate": { - "type": "retrieval_summary", - "summary": aimessages, - "query": query, - "storage_type": storage_type, - "user_rag_memory_id": user_rag_memory_id - } - } - - -@mcp.tool() -async def Input_Summary( - ctx: Context, - context: str, - usermessages: str, - search_switch: str, - apply_id: str, - group_id: str, - memory_config: MemoryConfig, - storage_type: str = "", - user_rag_memory_id: str = "", -) -> dict: - """ - Generate a quick summary for direct input without verification. - - Args: - ctx: FastMCP context for dependency injection - context: String containing the input sentence - usermessages: User messages identifier - search_switch: Search switch value for routing ('2' for summaries only) - apply_id: Application identifier - group_id: Group identifier - memory_config: MemoryConfig object containing all configuration - storage_type: Storage type for the workspace (e.g., 'rag', 'vector') - user_rag_memory_id: User RAG memory identifier - - Returns: - dict: Contains 'query_answer' with the summary result - """ - start = time.time() - logger.info(f"Input_Summary: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}") - - try: - # Extract services from context - session_service = get_context_resource(ctx, "session_service") - search_service = get_context_resource(ctx, "search_service") - - # Resolve session ID - sessionid = Resolve_username(usermessages) or "" - sessionid = sessionid.replace('call_id_', '') - - start_time=time.time() - history = await session_service.get_history( - str(sessionid), - str(apply_id), - str(group_id) - ) - end_time=time.time() - logger.info(f"Input_Summary-REDIS搜索:{end_time - start_time}") - # Override with empty list for now (as in original) - - # Log the raw context for debugging - logger.info(f"Input_Summary: Received context type={type(context)}, value={context[:200] if isinstance(context, str) else context}") - - # Extract sentence from context - # Context can be a string or might contain the sentence in various formats - try: - # Try to parse as JSON first - if isinstance(context, str) and (context.startswith('{') or context.startswith('[')): - try: - import json - context_dict = json.loads(context) - if isinstance(context_dict, dict): - query = context_dict.get('sentence', context_dict.get('content', context)) - else: - query = context - except json.JSONDecodeError: - # Not valid JSON, try regex - match = re.search(r"'sentence':\s*['\"]?(.*?)['\"]?\s*,", context) - query = match.group(1) if match else context - else: - query = context - except Exception as e: - logger.warning(f"Failed to extract query from context: {e}") - query = context - - # Clean query - query = str(query).strip().strip("\"'") - - logger.debug(f"Input_Summary: Extracted query='{query}' from context type={type(context)}") - - # Execute search based on search_switch and storage_type - try: - logger.info(f"search_switch: {search_switch}, storage_type: {storage_type}") - - # Prepare search parameters based on storage type - search_params = { - "group_id": group_id, - "question": query, - "return_raw_results": True - } - - # Add storage-specific parameters - - # Retrieval - if search_switch == '2': - search_params["include"] = ["summaries"] - if storage_type == "rag" and user_rag_memory_id: - raw_results = [] - retrieve_info = "" - kb_config={ - "knowledge_bases": [ - { - "kb_id": user_rag_memory_id, - "similarity_threshold": 0.7, - "vector_similarity_weight": 0.5, - "top_k": 10, - "retrieve_type": "participle" - } - ], - "merge_strategy": "weight", - "reranker_id":os.getenv('reranker_id'), - "reranker_top_k": 10 - } - - retrieve_chunks_result = knowledge_retrieval(query, kb_config,[str(group_id)]) - try: - retrieval_knowledge = [i.page_content for i in retrieve_chunks_result] - retrieve_info = '\n\n'.join(retrieval_knowledge) - raw_results=[retrieve_info] - logger.info(f"Input_Summary: Using RAG storage with memory_id={user_rag_memory_id}") - except: - retrieve_info='' - raw_results=[''] - logger.info(f"No content retrieved from knowledge base: {user_rag_memory_id}") - else: - retrieve_info, question, raw_results = await search_service.execute_hybrid_search( - **search_params, memory_config=memory_config - ) - logger.info("Input_Summary: Using summary for retrieval") - else: - retrieve_info, question, raw_results = await search_service.execute_hybrid_search( - **search_params, memory_config=memory_config - ) - - except Exception as e: - logger.error( - f"Input_Summary: hybrid_search failed, using empty results: {e}", - exc_info=True - ) - retrieve_info, question, raw_results = "", query, [] - - # Return retrieved information directly without LLM processing - # Use the raw retrieved info as the answer - aimessages = retrieve_info if retrieve_info else "信息不足,无法回答" - - logger.info(f"Quick answer (no LLM): {storage_type}--{user_rag_memory_id}--{aimessages[:500]}...") - - # Emit intermediate output for frontend - return { - "status": "success", - "summary_result": aimessages, - "storage_type": storage_type, - "user_rag_memory_id": user_rag_memory_id, - "_intermediate": { - "type": "input_summary", - "title": "快速答案", - "summary": aimessages, - "query": query, - "raw_results": raw_results, - "search_mode": "quick_search", - "storage_type": storage_type, - "user_rag_memory_id": user_rag_memory_id - } - } - - except Exception as e: - logger.error( - f"Input_Summary failed: {e}", - exc_info=True - ) - return { - "status": "fail", - "summary_result": "信息不足,无法回答", - "storage_type": storage_type, - "user_rag_memory_id": user_rag_memory_id, - "error": str(e) - } - - finally: - # Log execution time - end = time.time() - try: - duration = end - start - except Exception: - duration = 0.0 - log_time('Retrieval', duration) - - -@mcp.tool() -async def Summary_fails( - ctx: Context, - context: str, - usermessages: str, - apply_id: str, - group_id: str, - storage_type: str = "", - user_rag_memory_id: str = "" -) -> dict: - """ - Handle workflow failure when summary cannot be generated. - - Args: - ctx: FastMCP context for dependency injection - context: Failure context string - usermessages: User messages identifier - apply_id: Application identifier - group_id: Group identifier - storage_type: Storage type for the workspace (optional) - user_rag_memory_id: User RAG memory identifier (optional) - - Returns: - dict: Contains 'query_answer' with failure message - """ - try: - # Extract services from context - session_service = get_context_resource(ctx, 'session_service') - - # Parse session ID from usermessages - usermessages_parts = usermessages.split('_')[1:] - sessionid = '_'.join(usermessages_parts[:-1]) - - # Cleanup duplicate sessions - await session_service.cleanup_duplicates() - - logger.info("没有相关数据") - logger.debug(f"Summary_fails called with apply_id: {apply_id}, group_id: {group_id}") - - return { - "status": "success", - "summary_result": "没有相关数据", - "storage_type": storage_type, - "user_rag_memory_id": user_rag_memory_id - } - - except Exception as e: - logger.error( - f"Summary_fails failed: {e}", - exc_info=True - ) - return { - "status": "fail", - "summary_result": "没有相关数据", - "storage_type": storage_type, - "user_rag_memory_id": user_rag_memory_id, - "error": str(e) - } diff --git a/api/app/core/memory/agent/mcp_server/tools/verification_tools.py b/api/app/core/memory/agent/mcp_server/tools/verification_tools.py deleted file mode 100644 index cb6af5bd..00000000 --- a/api/app/core/memory/agent/mcp_server/tools/verification_tools.py +++ /dev/null @@ -1,174 +0,0 @@ -""" -Verification Tools for data verification. - -This module contains MCP tools for verifying retrieved data. -""" -import time - -from app.core.logging_config import get_agent_logger, log_time -from app.core.memory.agent.mcp_server.mcp_instance import mcp -from app.core.memory.agent.mcp_server.server import get_context_resource -from app.core.memory.agent.utils.llm_tools import PROJECT_ROOT_ -from app.core.memory.agent.utils.messages_tool import ( - Resolve_username, - Retrieve_verify_tool_messages_deal, - Verify_messages_deal, -) -from app.core.memory.agent.utils.verify_tool import VerifyTool -from app.schemas.memory_config_schema import MemoryConfig -from jinja2 import Template -from mcp.server.fastmcp import Context - -logger = get_agent_logger(__name__) - - -@mcp.tool() -async def Verify( - ctx: Context, - context: dict, - usermessages: str, - apply_id: str, - group_id: str, - memory_config: MemoryConfig, - storage_type: str = "", - user_rag_memory_id: str = "" -) -> dict: - """ - Verify the retrieved data. - - Args: - ctx: FastMCP context for dependency injection - context: Dictionary containing query and expansion issues - usermessages: User messages identifier - apply_id: Application identifier - group_id: Group identifier - memory_config: MemoryConfig object containing all configuration - storage_type: Storage type for the workspace (optional) - user_rag_memory_id: User RAG memory identifier (optional) - - Returns: - dict: Contains 'status' and 'verified_data' with verification results - """ - start = time.time() - - - try: - # Extract services from context - session_service = get_context_resource(ctx, 'session_service') - - # Load verification prompt template - file_path = PROJECT_ROOT_ + '/agent/utils/prompt/split_verify_prompt.jinja2' - - # Read template file directly (VerifyTool expects raw template content) - from app.core.memory.agent.utils.messages_tool import read_template_file - system_prompt = await read_template_file(file_path) - - - - # Resolve session ID - sessionid = Resolve_username(usermessages) - - # Get conversation history - history = await session_service.get_history(sessionid, apply_id, group_id) - - template = Template(system_prompt) - system_prompt = template.render(history=history, sentence=context) - - # Process context to extract query and results - Query_small, Result_small, query = await Verify_messages_deal(context) - - # Build query list for verification - query_list = [] - for query_small, anser in zip(Query_small, Result_small, strict=False): - query_list.append({ - 'Query_small': query_small, - 'Answer_Small': anser - }) - - messages = { - "Query": query, - "Expansion_issue": query_list - } - - - - # Call verification workflow with LLM model ID from memory_config - verify_tool = VerifyTool( - system_prompt=system_prompt, - verify_data=messages, - llm_model_id=str(memory_config.llm_model_id) - ) - verify_result = await verify_tool.verify() - - # Parse LLM verification result with error handling - try: - messages_deal = await Retrieve_verify_tool_messages_deal( - verify_result, - history, - query - ) - except Exception as e: - logger.error( - f"Retrieve_verify_tool_messages_deal parsing failed: {e}", - exc_info=True - ) - # Fallback to avoid 500 errors - messages_deal = { - "data": { - "query": query, - "expansion_issue": [] - }, - "split_result": "failed", - "reason": str(e), - "history": history, - } - - logger.info(f"Verification result: {messages_deal}") - - # Emit intermediate output for frontend - return { - "status": "success", - "verified_data": messages_deal, - "storage_type": storage_type, - "user_rag_memory_id": user_rag_memory_id, - "_intermediate": { - "type": "verification", - "title": "Data Verification", - "result": messages_deal.get("split_result", "unknown"), - "reason": messages_deal.get("reason", ""), - "query": query, - "verified_count": len(query_list), - "storage_type": storage_type, - "user_rag_memory_id": user_rag_memory_id - } - } - - except Exception as e: - logger.error( - f"Verify failed: {e}", - exc_info=True - ) - return { - "status": "error", - "message": str(e), - "storage_type": storage_type, - "user_rag_memory_id": user_rag_memory_id, - "verified_data": { - "data": { - "query": "", - "expansion_issue": [] - }, - "split_result": "failed", - "reason": str(e), - "history": [], - } - } - - finally: - # Log execution time - end = time.time() - try: - duration = end - start - except Exception: - duration = 0.0 - log_time('Verification', duration) diff --git a/api/app/core/memory/agent/mcp_server/models/__init__.py b/api/app/core/memory/agent/models/__init__.py similarity index 100% rename from api/app/core/memory/agent/mcp_server/models/__init__.py rename to api/app/core/memory/agent/models/__init__.py diff --git a/api/app/core/memory/agent/mcp_server/models/problem_models.py b/api/app/core/memory/agent/models/problem_models.py similarity index 100% rename from api/app/core/memory/agent/mcp_server/models/problem_models.py rename to api/app/core/memory/agent/models/problem_models.py diff --git a/api/app/core/memory/agent/mcp_server/models/retrieval_models.py b/api/app/core/memory/agent/models/retrieval_models.py similarity index 100% rename from api/app/core/memory/agent/mcp_server/models/retrieval_models.py rename to api/app/core/memory/agent/models/retrieval_models.py diff --git a/api/app/core/memory/agent/mcp_server/models/summary_models.py b/api/app/core/memory/agent/models/summary_models.py similarity index 100% rename from api/app/core/memory/agent/mcp_server/models/summary_models.py rename to api/app/core/memory/agent/models/summary_models.py diff --git a/api/app/core/memory/agent/mcp_server/models/verification_models.py b/api/app/core/memory/agent/models/verification_models.py similarity index 100% rename from api/app/core/memory/agent/mcp_server/models/verification_models.py rename to api/app/core/memory/agent/models/verification_models.py diff --git a/api/app/core/memory/agent/multimodal/oss_picture.py b/api/app/core/memory/agent/multimodal/oss_picture.py deleted file mode 100644 index b5b4bd6b..00000000 --- a/api/app/core/memory/agent/multimodal/oss_picture.py +++ /dev/null @@ -1,114 +0,0 @@ -import os -import sys -import traceback - -import requests - -# from qcloud_cos import CosConfig, CosS3Client -# from qcloud_cos.cos_exception import CosClientError, CosServiceError - -# from config.paths import BASE_DIR -BASE_DIR = os.path.dirname(os.path.realpath(sys.argv[0])) - -class OSSUploader: - """对象存储文件上传工具类""" - - def __init__(self, env): - api = { - "test": "https://testlingqi.redbearai.com/api/user/file/common/upload/v2/anon", - "prod": "https://lingqi.redbearai.com/api/user/file/common/upload/v2/anon" - } - self.api = api.get(env, "https://testlingqi.redbearai.com/api/user/file/common/upload/v2/anon") - self.privacy = "false" - self.headers = { - "User-Agent": 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) ' - 'AppleWebKit/537.36 (KHTML, like Gecko)' - ' Chrome/133.0.6833.84 Safari/537.36' - } - - @staticmethod - def _generate_object_key(file_path, prefix='xhs_'): - """ - 生成对象存储的Key - - :param file_path: 本地文件路径 - :param prefix: 存储前缀,用于分类存储 - :return: 生成的对象Key - """ - # 文件md5值.后缀名 - filename = os.path.basename(file_path) - filename = f"{filename}" - - # 组合成完整的对象Key - return f"{prefix}{filename}" - - def upload_image(self, file_name, prefix='jd_'): - """ - 上传文件到COS并返回可访问的URL - - :param file_url: 文件路径 - :param file_name: 文件名称 - :param media_type: 文件类型 - :param prefix: 存储前缀,用于分类存储 - :return: 文件访问URL - """ - # 检查文件是否存在 - - - - file_path = os.path.join(BASE_DIR, file_name) - - # response = requests.get(url, headers=self.headers, stream=True) - - # if response.status_code == 200: - # with open(file_path, "wb") as f: - # for chunk in response.iter_content(1024): # 分块写入,避免内存占用过大 - # f.write(chunk) - # else: - # raise Exception(f"文件下载失败,{file_name}") - - # 生成对象Key - object_key = self._generate_object_key(file_path, prefix +file_name.split('.')[-1]) - - try: - upload_response = requests.post( - self.api, - data={ - "privacy": self.privacy, - "fileName": object_key, - } - ) - - if upload_response.status_code != 200: - raise Exception('上传接口请求失败') - resp = upload_response.json() - name = resp["data"]["name"] - file_url = resp["data"]["path"] - policy = resp["data"]["policy"] - with open(file_path, 'rb') as f: - oss_push_resp = requests.post( - policy["host"], - files={ - "key": policy["dir"], - "OSSAccessKeyId": policy["accessid"], - "name": name, - "policy": policy["policy"], - "success_action_status": 200, - "signature": policy["signature"], - "file": f, - } - ) - if oss_push_resp.status_code == 200: - return file_url - raise Exception("OSS上传失败") - except Exception: - raise Exception(f"上传失败: \n{traceback.format_exc()}") - finally: - print('success') - # os.remove(file_path) - - -if __name__ == '__main__': - cos_uploader = OSSUploader("prod") - url =cos_uploader.upload_image('./example01.jpg') - print(url) diff --git a/api/app/core/memory/agent/multimodal/speech_model.py b/api/app/core/memory/agent/multimodal/speech_model.py deleted file mode 100644 index 2df32dd0..00000000 --- a/api/app/core/memory/agent/multimodal/speech_model.py +++ /dev/null @@ -1,121 +0,0 @@ -import asyncio -import re - -from app.core.memory.agent.utils.llm_tools import PROJECT_ROOT_, picture_model_requests,Picture_recognize, Voice_recognize -from app.core.memory.agent.utils.messages_tool import read_template_file - -import requests -import json -import os -import time -# file_urls = [ -# "https://dashscope.oss-cn-beijing.aliyuncs.com/samples/audio/paraformer/hello_world_female2.wav", -# "https://dashscope.oss-cn-beijing.aliyuncs.com/samples/audio/paraformer/hello_world_male2.wav", -# ] -class Vico_recognition: - def __init__(self,file_urls): - self.api_key='' - self.backend_model_name='' - self.api_base='' - self.file_urls=file_urls - - # 提交文件转写任务,包含待转写文件url列表 - async def submit_task(self) -> str: - self.api_key, self.backend_model_name, self.api_base =await Voice_recognize() - - headers = { - "Authorization": f"Bearer {self.api_key}", - "Content-Type": "application/json", - "X-DashScope-Async": "enable", - } - data = { - "model": self.backend_model_name, - "input": {"file_urls": self.file_urls}, - "parameters": { - "channel_id": [0], - "vocabulary_id": "vocab-Xxxx", - }, - } - # 录音文件转写服务url - service_url = ( - "https://dashscope.aliyuncs.com/api/v1/services/audio/asr/transcription" - ) - response = requests.post( - service_url, headers=headers, data=json.dumps(data) - ) - - # 打印响应内容 - if response.status_code == 200: - return response.json()["output"]["task_id"] - else: - print("task failed!") - print(response.json()) - return None - - async def download_transcription_result(self, transcription_url): - """ - Args: - transcription_url (str): 转写结果文件URL - Returns: - dict: 转写结果内容 - """ - try: - response = requests.get(transcription_url) - response.raise_for_status() - return response.json() - except Exception as e: - print(f"下载转写结果失败: {e}") - return None - - # 循环查询任务状态直到成功 - async def wait_for_complete(self,task_id): - self.api_key, self.backend_model_name, self.api_base = await Voice_recognize() - headers = { - "Authorization": f"Bearer {self.api_key}", - "Content-Type": "application/json", - "X-DashScope-Async": "enable", - } - - pending = True - while pending: - # 查询任务状态服务url - service_url = f"https://dashscope.aliyuncs.com/api/v1/tasks/{task_id}" - response = requests.post( - service_url, headers=headers - ) - if response.status_code == 200: - status = response.json()['output']['task_status'] - if status == 'SUCCEEDED': - print("task succeeded!") - pending = False - return response.json()['output']['results'] - elif status == 'RUNNING' or status == 'PENDING': - pass - else: - print("task failed!") - pending = False - else: - print("query failed!") - pending = False - time.sleep(0.1) - async def run(self): - self.api_key, self.backend_model_name, self.api_base = await Voice_recognize() - task_id=await self.submit_task() - result=await self.wait_for_complete(task_id) - result_context=[] - for i in result: - transcription_url=i['transcription_url'] - print(f"转写URL: {transcription_url}") - - # 下载并打印转写内容 - content = await self.download_transcription_result(transcription_url) - if content: - content=json.dumps(content, indent=2, ensure_ascii=False) - context=re.findall(r'"text": "(.*?)"', content) - result_context.append(context[0]) - result=''.join(result_context) - return (result) - - - - diff --git a/api/app/core/memory/agent/mcp_server/services/__init__.py b/api/app/core/memory/agent/services/__init__.py similarity index 100% rename from api/app/core/memory/agent/mcp_server/services/__init__.py rename to api/app/core/memory/agent/services/__init__.py diff --git a/api/app/core/memory/agent/services/optimized_llm_service.py b/api/app/core/memory/agent/services/optimized_llm_service.py new file mode 100644 index 00000000..6942d421 --- /dev/null +++ b/api/app/core/memory/agent/services/optimized_llm_service.py @@ -0,0 +1,277 @@ +""" +优化的LLM服务类,用于压缩和统一LLM调用 +""" + +import asyncio +from typing import Any, Dict, List, Optional, Type, TypeVar, Union +from pydantic import BaseModel +from sqlalchemy.orm import Session + +from app.core.logging_config import get_agent_logger +from app.core.memory.utils.llm.llm_utils import MemoryClientFactory +from app.core.memory.llm_tools.openai_client import OpenAIClient + +T = TypeVar('T', bound=BaseModel) + +logger = get_agent_logger(__name__) + + +class OptimizedLLMService: + """ + 优化的LLM服务类,提供统一的LLM调用接口 + + 特性: + 1. 客户端复用 - 避免重复创建LLM客户端 + 2. 批量处理 - 支持并发处理多个请求 + 3. 错误处理 - 统一的错误处理和降级策略 + 4. 性能优化 - 缓存和连接池优化 + """ + + def __init__(self, db_session: Session): + self.db_session = db_session + self.client_factory = MemoryClientFactory(db_session) + self._client_cache: Dict[str, OpenAIClient] = {} + + def _get_cached_client(self, llm_model_id: str) -> OpenAIClient: + """获取缓存的LLM客户端,避免重复创建""" + if llm_model_id not in self._client_cache: + self._client_cache[llm_model_id] = self.client_factory.get_llm_client(llm_model_id) + return self._client_cache[llm_model_id] + + async def structured_response( + self, + llm_model_id: str, + system_prompt: str, + response_model: Type[T], + user_message: Optional[str] = None, + fallback_value: Optional[Any] = None + ) -> T: + """ + 统一的结构化响应接口 + + Args: + llm_model_id: LLM模型ID + system_prompt: 系统提示词 + response_model: 响应模型类 + user_message: 用户消息(可选) + fallback_value: 失败时的降级值 + + Returns: + 结构化响应对象 + """ + try: + llm_client = self._get_cached_client(llm_model_id) + + messages = [{"role": "system", "content": system_prompt}] + if user_message: + messages.append({"role": "user", "content": user_message}) + + logger.debug(f"LLM调用: model={llm_model_id}, prompt_length={len(system_prompt)}") + + structured = await llm_client.response_structured( + messages=messages, + response_model=response_model + ) + + if structured is None: + logger.warning(f"LLM返回None,使用降级值") + return self._create_fallback_response(response_model, fallback_value) + + return structured + + except Exception as e: + logger.error(f"结构化响应失败: {e}", exc_info=True) + return self._create_fallback_response(response_model, fallback_value) + + async def batch_structured_response( + self, + llm_model_id: str, + requests: List[Dict[str, Any]], + response_model: Type[T], + max_concurrent: int = 5 + ) -> List[T]: + """ + 批量处理结构化响应 + + Args: + llm_model_id: LLM模型ID + requests: 请求列表,每个请求包含system_prompt等参数 + response_model: 响应模型类 + max_concurrent: 最大并发数 + + Returns: + 结构化响应列表 + """ + semaphore = asyncio.Semaphore(max_concurrent) + + async def process_single_request(request: Dict[str, Any]) -> T: + async with semaphore: + return await self.structured_response( + llm_model_id=llm_model_id, + system_prompt=request.get('system_prompt', ''), + response_model=response_model, + user_message=request.get('user_message'), + fallback_value=request.get('fallback_value') + ) + + tasks = [process_single_request(req) for req in requests] + return await asyncio.gather(*tasks) + + async def simple_response( + self, + llm_model_id: str, + system_prompt: str, + user_message: Optional[str] = None, + fallback_message: str = "信息不足,无法回答" + ) -> str: + """ + 简单的文本响应接口 + + Args: + llm_model_id: LLM模型ID + system_prompt: 系统提示词 + user_message: 用户消息(可选) + fallback_message: 失败时的降级消息 + + Returns: + 响应文本 + """ + try: + llm_client = self._get_cached_client(llm_model_id) + + messages = [{"role": "system", "content": system_prompt}] + if user_message: + messages.append({"role": "user", "content": user_message}) + + response = await llm_client.response(messages=messages) + + if not response or not response.strip(): + return fallback_message + + return response.strip() + + except Exception as e: + logger.error(f"简单响应失败: {e}", exc_info=True) + return fallback_message + + def _create_fallback_response(self, response_model: Type[T], fallback_value: Optional[Any]) -> T: + """创建降级响应""" + try: + if fallback_value is not None: + if isinstance(fallback_value, response_model): + return fallback_value + elif isinstance(fallback_value, dict): + return response_model(**fallback_value) + + # 尝试创建空的响应模型 + if hasattr(response_model, 'root'): + # RootModel类型 + return response_model([]) + else: + # 普通BaseModel类型 + return response_model() + + except Exception as e: + logger.error(f"创建降级响应失败: {e}") + # 最后的降级策略 + if hasattr(response_model, 'root'): + return response_model([]) + else: + return response_model() + + def clear_cache(self): + """清理客户端缓存""" + self._client_cache.clear() + + +class LLMServiceMixin: + """ + LLM服务混入类,为节点提供便捷的LLM调用方法 + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._llm_service: Optional[OptimizedLLMService] = None + + def get_llm_service(self, db_session: Session) -> OptimizedLLMService: + """获取LLM服务实例""" + if self._llm_service is None: + self._llm_service = OptimizedLLMService(db_session) + return self._llm_service + + async def call_llm_structured( + self, + state: Dict[str, Any], + db_session: Session, + system_prompt: str, + response_model: Type[T], + user_message: Optional[str] = None, + fallback_value: Optional[Any] = None + ) -> T: + """ + 便捷的结构化LLM调用方法 + + Args: + state: 状态字典,包含memory_config + db_session: 数据库会话 + system_prompt: 系统提示词 + response_model: 响应模型类 + user_message: 用户消息(可选) + fallback_value: 失败时的降级值 + + Returns: + 结构化响应对象 + """ + memory_config = state.get('memory_config') + if not memory_config: + raise ValueError("State中缺少memory_config") + + llm_model_id = memory_config.llm_model_id + if not llm_model_id: + raise ValueError("Memory config中缺少llm_model_id") + + llm_service = self.get_llm_service(db_session) + return await llm_service.structured_response( + llm_model_id=llm_model_id, + system_prompt=system_prompt, + response_model=response_model, + user_message=user_message, + fallback_value=fallback_value + ) + + async def call_llm_simple( + self, + state: Dict[str, Any], + db_session: Session, + system_prompt: str, + user_message: Optional[str] = None, + fallback_message: str = "信息不足,无法回答" + ) -> str: + """ + 便捷的简单LLM调用方法 + + Args: + state: 状态字典,包含memory_config + db_session: 数据库会话 + system_prompt: 系统提示词 + user_message: 用户消息(可选) + fallback_message: 失败时的降级消息 + + Returns: + 响应文本 + """ + memory_config = state.get('memory_config') + if not memory_config: + raise ValueError("State中缺少memory_config") + + llm_model_id = memory_config.llm_model_id + if not llm_model_id: + raise ValueError("Memory config中缺少llm_model_id") + + llm_service = self.get_llm_service(db_session) + return await llm_service.simple_response( + llm_model_id=llm_model_id, + system_prompt=system_prompt, + user_message=user_message, + fallback_message=fallback_message + ) \ No newline at end of file diff --git a/api/app/core/memory/agent/mcp_server/services/parameter_builder.py b/api/app/core/memory/agent/services/parameter_builder.py similarity index 87% rename from api/app/core/memory/agent/mcp_server/services/parameter_builder.py rename to api/app/core/memory/agent/services/parameter_builder.py index d5305dc6..a58fcf1a 100644 --- a/api/app/core/memory/agent/mcp_server/services/parameter_builder.py +++ b/api/app/core/memory/agent/services/parameter_builder.py @@ -4,22 +4,19 @@ Parameter Builder for constructing tool call arguments. This service provides tool-specific parameter transformation logic to build correct arguments for each tool type. """ - from typing import Any, Dict, Optional - from app.core.logging_config import get_agent_logger -from app.schemas.memory_config_schema import MemoryConfig logger = get_agent_logger(__name__) class ParameterBuilder: """Service for building tool call arguments based on tool type.""" - + def __init__(self): """Initialize the parameter builder.""" logger.info("ParameterBuilder initialized") - + def build_tool_args( self, tool_name: str, @@ -28,9 +25,8 @@ class ParameterBuilder: search_switch: str, apply_id: str, group_id: str, - memory_config: MemoryConfig, storage_type: Optional[str] = None, - user_rag_memory_id: Optional[str] = None, + user_rag_memory_id: Optional[str] = None ) -> Dict[str, Any]: """ Build tool arguments based on tool type. @@ -49,7 +45,6 @@ class ParameterBuilder: search_switch: Search routing parameter apply_id: Application identifier group_id: Group identifier - memory_config: MemoryConfig object containing all configuration storage_type: Storage type for the workspace (optional) user_rag_memory_id: User RAG memory ID for knowledge base retrieval (optional) @@ -60,19 +55,18 @@ class ParameterBuilder: base_args = { "usermessages": tool_call_id, "apply_id": apply_id, - "group_id": group_id, - "memory_config": memory_config, + "group_id": group_id } - + # Always add storage_type and user_rag_memory_id (with defaults if None) base_args["storage_type"] = storage_type if storage_type is not None else "" base_args["user_rag_memory_id"] = user_rag_memory_id if user_rag_memory_id is not None else "" # Tool-specific argument construction - if tool_name in ["Verify", "Summary", "Summary_fails", "Retrieve_Summary", "Problem_Extension"]: - # These tools expect dict context + if tool_name in ["Verify","Summary", "Summary_fails",'Retrieve_Summary']: + # Verify expects dict context return { - "context": content if isinstance(content, dict) else {"content": content}, + "context": content if isinstance(content, dict) else {}, **base_args } diff --git a/api/app/core/memory/agent/mcp_server/services/search_service.py b/api/app/core/memory/agent/services/search_service.py similarity index 75% rename from api/app/core/memory/agent/mcp_server/services/search_service.py rename to api/app/core/memory/agent/services/search_service.py index 47295f87..8a2e7cfe 100644 --- a/api/app/core/memory/agent/mcp_server/services/search_service.py +++ b/api/app/core/memory/agent/services/search_service.py @@ -4,31 +4,21 @@ Search Service for executing hybrid search and processing results. This service provides clean search result processing with content extraction and deduplication. """ - -from typing import TYPE_CHECKING, List, Optional, Tuple +from typing import List, Tuple, Optional from app.core.logging_config import get_agent_logger from app.core.memory.src.search import run_hybrid_search from app.core.memory.utils.data.text_utils import escape_lucene_query -if TYPE_CHECKING: - from app.schemas.memory_config_schema import MemoryConfig logger = get_agent_logger(__name__) class SearchService: """Service for executing hybrid search and processing results.""" - - def __init__(self, memory_config: "MemoryConfig" = None): - """ - Initialize the search service. - - Args: - memory_config: Optional MemoryConfig for embedding model configuration. - If not provided, must be passed to execute_hybrid_search. - """ - self.memory_config = memory_config + + def __init__(self): + """Initialize the search service.""" logger.info("SearchService initialized") def extract_content_from_result(self, result: dict) -> str: @@ -103,49 +93,40 @@ class SearchService: self, group_id: str, question: str, - limit: int = 15, + limit: int = 5, search_type: str = "hybrid", include: Optional[List[str]] = None, - rerank_alpha: float = 0.6, - activation_boost_factor: float = 0.8, + rerank_alpha: float = 0.4, output_path: str = "search_results.json", return_raw_results: bool = False, - memory_config: "MemoryConfig" = None, + memory_config = None ) -> Tuple[str, str, Optional[dict]]: """ - Execute hybrid search with two-stage ranking. - - Stage 1: Filter by content relevance (BM25 + Embedding) - Stage 2: Rerank by activation values (ACTR) + Execute hybrid search and return clean content. Args: - group_id: Group identifier for filtering + group_id: Group identifier for filtering results question: Search query text - limit: Max results per category (default: 15) - search_type: "hybrid", "keyword", or "embedding" (default: "hybrid") - include: Result types (default: ["statements", "chunks", "entities", "summaries"]) - rerank_alpha: BM25 weight (default: 0.6) - activation_boost_factor: Activation impact on memory strength (default: 0.8) - output_path: JSON output path (default: "search_results.json") - return_raw_results: Return full metadata (default: False) - memory_config: MemoryConfig for embedding model + limit: Maximum number of results to return (default: 5) + search_type: Type of search - "hybrid", "keyword", or "embedding" (default: "hybrid") + include: List of result types to include (default: ["statements", "chunks", "entities", "summaries"]) + rerank_alpha: Weight for BM25 scores in reranking (default: 0.4) + output_path: Path to save search results (default: "search_results.json") + return_raw_results: If True, also return the raw search results as third element (default: False) + memory_config: Memory configuration object (required) Returns: - Tuple[str, str, Optional[dict]]: (clean_content, cleaned_query, raw_results) + Tuple of (clean_content, cleaned_query, raw_results) + raw_results is None if return_raw_results=False """ if include is None: include = ["statements", "chunks", "entities", "summaries"] - - # Use provided memory_config or fall back to instance config - config = memory_config or self.memory_config - if not config: - raise ValueError("memory_config is required for search - either pass it to __init__ or execute_hybrid_search") - + # Clean query cleaned_query = self.clean_query(question) - + try: - # Execute search using memory_config + # Execute search answer = await run_hybrid_search( query_text=cleaned_query, search_type=search_type, @@ -153,9 +134,8 @@ class SearchService: limit=limit, include=include, output_path=output_path, - memory_config=config, - rerank_alpha=rerank_alpha, - activation_boost_factor=activation_boost_factor, + memory_config=memory_config, + rerank_alpha=rerank_alpha ) # Extract results based on search type and include parameter diff --git a/api/app/core/memory/agent/mcp_server/services/session_service.py b/api/app/core/memory/agent/services/session_service.py similarity index 100% rename from api/app/core/memory/agent/mcp_server/services/session_service.py rename to api/app/core/memory/agent/services/session_service.py diff --git a/api/app/core/memory/agent/mcp_server/services/template_service.py b/api/app/core/memory/agent/services/template_service.py similarity index 94% rename from api/app/core/memory/agent/mcp_server/services/template_service.py rename to api/app/core/memory/agent/services/template_service.py index 95223f0b..1bf86375 100644 --- a/api/app/core/memory/agent/mcp_server/services/template_service.py +++ b/api/app/core/memory/agent/services/template_service.py @@ -3,12 +3,22 @@ Template Service for loading and rendering Jinja2 templates. This service provides centralized template management with caching and error handling. """ + import os from functools import lru_cache -from typing import Optional -from jinja2 import Environment, FileSystemLoader, Template, TemplateNotFound -from app.core.logging_config import get_agent_logger, log_prompt_rendering +from jinja2 import ( + Environment, + FileSystemLoader, + Template, + TemplateNotFound, +) + +from app.core.logging_config import ( + get_agent_logger, + log_prompt_rendering, +) + logger = get_agent_logger(__name__) diff --git a/api/app/core/memory/agent/utils/__init__.py b/api/app/core/memory/agent/utils/__init__.py deleted file mode 100644 index 2b77e240..00000000 --- a/api/app/core/memory/agent/utils/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -"""Agent utilities.""" - -from app.core.memory.agent.utils.multimodal import MultimodalProcessor - -__all__ = [ - "MultimodalProcessor", -] diff --git a/api/app/core/memory/agent/utils/llm_client_pool.py b/api/app/core/memory/agent/utils/llm_client_pool.py new file mode 100644 index 00000000..fddd54f6 --- /dev/null +++ b/api/app/core/memory/agent/utils/llm_client_pool.py @@ -0,0 +1,56 @@ + +import asyncio +from typing import Dict, Optional +from app.core.memory.utils.llm.llm_utils import get_llm_client_fast +from app.db import get_db +from app.core.logging_config import get_agent_logger + +logger = get_agent_logger(__name__) + +class LLMClientPool: + """LLM客户端连接池""" + + def __init__(self, max_size: int = 5): + self.max_size = max_size + self.pools: Dict[str, asyncio.Queue] = {} + self.active_clients: Dict[str, int] = {} + + async def get_client(self, llm_model_id: str): + """获取LLM客户端""" + if llm_model_id not in self.pools: + self.pools[llm_model_id] = asyncio.Queue(maxsize=self.max_size) + self.active_clients[llm_model_id] = 0 + + pool = self.pools[llm_model_id] + + try: + # 尝试从池中获取客户端 + client = pool.get_nowait() + logger.debug(f"从池中获取LLM客户端: {llm_model_id}") + return client + except asyncio.QueueEmpty: + # 池为空,创建新客户端 + if self.active_clients[llm_model_id] < self.max_size: + db_session = next(get_db()) + client = get_llm_client_fast(llm_model_id, db_session) + self.active_clients[llm_model_id] += 1 + logger.debug(f"创建新LLM客户端: {llm_model_id}") + return client + else: + # 等待可用客户端 + logger.debug(f"等待LLM客户端可用: {llm_model_id}") + return await pool.get() + + async def return_client(self, llm_model_id: str, client): + """归还LLM客户端到池中""" + if llm_model_id in self.pools: + try: + self.pools[llm_model_id].put_nowait(client) + logger.debug(f"归还LLM客户端到池: {llm_model_id}") + except asyncio.QueueFull: + # 池已满,丢弃客户端 + self.active_clients[llm_model_id] -= 1 + logger.debug(f"池已满,丢弃LLM客户端: {llm_model_id}") + +# 全局客户端池 +llm_client_pool = LLMClientPool() diff --git a/api/app/core/memory/agent/utils/llm_tools.py b/api/app/core/memory/agent/utils/llm_tools.py index ec22b628..8dd2f1d3 100644 --- a/api/app/core/memory/agent/utils/llm_tools.py +++ b/api/app/core/memory/agent/utils/llm_tools.py @@ -1,40 +1,12 @@ -import asyncio -import json -import logging import os from collections import defaultdict from typing import Annotated, TypedDict -from app.core.memory.agent.utils.messages_tool import read_template_file -from app.core.memory.utils.config.config_utils import ( - get_picture_config, - get_voice_config, -) - -# Removed global variable imports - use dependency injection instead -from dotenv import load_dotenv from langchain_core.messages import AnyMessage from langgraph.graph import add_messages -from openai import OpenAI PROJECT_ROOT_ = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) -logger = logging.getLogger(__name__) -load_dotenv() - - -async def picture_model_requests(image_url): - ''' - - Args: - image_url: - Returns: - - ''' - file_path = PROJECT_ROOT_ + '/agent/utils/prompt/Template_for_image_recognition_prompt.jinja2 ' - system_prompt = await read_template_file(file_path) - result = await Picture_recognize(image_url,system_prompt) - return (result) class WriteState(TypedDict): ''' Langgrapg Writing TypedDict @@ -44,39 +16,69 @@ class WriteState(TypedDict): apply_id:str group_id:str errors: list[dict] # Track errors: [{"tool": "tool_name", "error": "message"}] + memory_config: object + write_result: dict + data:str class ReadState(TypedDict): - ''' - Langgrapg READING TypedDict - name: - id:user id - loop_count:Traverse times - search_switch:type - config_id: configuration id for filtering results - errors: list of errors that occurred during workflow execution - ''' - messages: Annotated[list[AnyMessage], add_messages] #消息追加的模式增加消息 - name: str - id: str - loop_count:int + """ + LangGraph 工作流状态定义 + + Attributes: + messages: 消息列表,支持自动追加 + loop_count: 遍历次数 + search_switch: 搜索类型开关 + group_id: 组标识 + config_id: 配置ID,用于过滤结果 + data: 从content_input_node传递的内容数据 + spit_data: 从Split_The_Problem传递的分解结果 + tool_calls: 工具调用请求列表 + tool_results: 工具执行结果列表 + memory_config: 内存配置对象 + """ + messages: Annotated[list[AnyMessage], add_messages] # 消息追加模式 + loop_count: int search_switch: str - user_id: str - apply_id: str group_id: str config_id: str - errors: list[dict] # Track errors: [{"tool": "tool_name", "error": "message"}] - - + data: str # 新增字段用于传递内容 + spit_data: dict # 新增字段用于传递问题分解结果 + problem_extension:dict + storage_type: str + user_rag_memory_id: str + llm_id: str + embedding_id: str + memory_config: object # 新增字段用于传递内存配置对象 + retrieve:dict + RetrieveSummary: dict + InputSummary: dict + verify: dict + SummaryFails: dict + summary: dict class COUNTState: - ''' - The number of times the workflow dialogue retrieval content has no correct message recall traversal - ''' + """ + 工作流对话检索内容计数器 + + 用于记录工作流对话检索内容没有正确消息召回遍历的次数。 + """ + def __init__(self, limit: int = 5): + """ + 初始化计数器 + + Args: + limit: 最大计数限制,默认为5 + """ self.total: int = 0 # 当前累加值 self.limit: int = limit # 最大上限 - def add(self, value: int = 1): - """累加数字,如果达到上限就保持最大值""" + def add(self, value: int = 1) -> None: + """ + 累加数字,如果达到上限就保持最大值 + + Args: + value: 要累加的值,默认为1 + """ self.total += value print(f"[COUNTState] 当前值: {self.total}") if self.total >= self.limit: @@ -84,21 +86,19 @@ class COUNTState: self.total = self.limit # 达到上限不再增加 def get_total(self) -> int: - """获取当前累加值""" + """ + 获取当前累加值 + + Returns: + 当前累加值 + """ return self.total - def reset(self): + def reset(self) -> None: """手动重置累加值""" self.total = 0 print("[COUNTState] 已重置为 0") - -def merge_to_key_value_pairs(data, query_key, result_key): - grouped = defaultdict(list) - for item in data: - grouped[item[query_key]].append(item[result_key]) - return [{key: values} for key, values in grouped.items()] - def deduplicate_entries(entries): seen = set() deduped = [] @@ -109,70 +109,37 @@ def deduplicate_entries(entries): deduped.append(entry) return deduped +def merge_to_key_value_pairs(data, query_key, result_key): + grouped = defaultdict(list) + for item in data: + grouped[item[query_key]].append(item[result_key]) + return [{key: values} for key, values in grouped.items()] -async def Picture_recognize(image_path, PROMPT_TICKET_EXTRACTION, picture_model_name: str) -> str: +def convert_extended_question_to_question(data): """ - Updated to eliminate global variables in favor of explicit parameters. - + 递归地将数据中的 extended_question 字段转换为 question 字段 + Args: - image_path: Path to image file - PROMPT_TICKET_EXTRACTION: Extraction prompt - picture_model_name: Picture model name (required, no longer from global variables) + data: 要转换的数据(可能是字典、列表或其他类型) + + Returns: + 转换后的数据 """ - try: - model_config = get_picture_config(picture_model_name) - except Exception as e: - err = f"LLM配置不可用:{str(e)}。请检查 config.json 和 runtime.json。" - logger.error(err) - return err - api_key = os.getenv(model_config["api_key"]) # 从环境变量读取对应后端的 API key - backend_model_name = model_config["llm_name"].split("/")[-1] - api_base=model_config['api_base'] - - logger.info(f"model_name: {backend_model_name}") - logger.info(f"api_key set: {'yes' if api_key else 'no'}") - logger.info(f"base_url: {model_config['api_base']}") - - client = OpenAI( - api_key=api_key, base_url=api_base, - ) - completion = client.chat.completions.create( - model=backend_model_name, - messages=[ - { - "role": "user", - "content": [ - { - "type": "image_url", - "image_url":image_path, - }, - {"type": "text", - "text": PROMPT_TICKET_EXTRACTION} - ] - } - ]) - picture_text = completion.choices[0].message.content - picture_text = picture_text.replace('```json', '').replace('```', '') - picture_text = json.loads(picture_text) - return (picture_text['statement']) - -async def Voice_recognize(voice_model_name: str): - """ - Updated to eliminate global variables in favor of explicit parameters. - - Args: - voice_model_name: Voice model name (required, no longer from global variables) - """ - try: - model_config = get_voice_config(voice_model_name) - except Exception as e: - err = f"LLM配置不可用:{str(e)}。请检查 config.json 和 runtime.json。" - logger.error(err) - return err - api_key = os.getenv(model_config["api_key"]) # 从环境变量读取对应后端的 API key - backend_model_name = model_config["llm_name"].split("/")[-1] - api_base = model_config['api_base'] - return api_key,backend_model_name,api_base - - + if isinstance(data, dict): + # 创建新字典来存储转换后的数据 + converted = {} + for key, value in data.items(): + if key == 'extended_question': + # 将 extended_question 转换为 question + converted['question'] = convert_extended_question_to_question(value) + else: + # 递归处理其他字段 + converted[key] = convert_extended_question_to_question(value) + return converted + elif isinstance(data, list): + # 递归处理列表中的每个元素 + return [convert_extended_question_to_question(item) for item in data] + else: + # 其他类型直接返回 + return data \ No newline at end of file diff --git a/api/app/core/memory/agent/utils/mcp_tools.py b/api/app/core/memory/agent/utils/mcp_tools.py deleted file mode 100644 index 7ede9843..00000000 --- a/api/app/core/memory/agent/utils/mcp_tools.py +++ /dev/null @@ -1,33 +0,0 @@ -import os -from app.core.config import settings - -def get_mcp_server_config(): - """ - Get the MCP server configuration. - - Uses MCP_SERVER_URL environment variable if set (for Docker), - otherwise falls back to SERVER_IP and MCP_PORT (for local development). - """ - # Get MCP port from environment (default: 8081) - mcp_port = os.getenv("MCP_PORT", "8081") - - # In Docker: MCP_SERVER_URL=http://mcp-server:8081 - # In local dev: uses SERVER_IP (127.0.0.1 or localhost) - mcp_server_url = os.getenv("MCP_SERVER_URL") - - if mcp_server_url: - # Docker environment: use full URL from environment - base_url = mcp_server_url - else: - # Local development: build URL from SERVER_IP and MCP_PORT - base_url = f"http://{settings.SERVER_IP}:{mcp_port}" - - mcp_server_config = { - "data_flow": { - "url": f"{base_url}/sse", - "transport": "sse", - "timeout": 15000, - "sse_read_timeout": 15000, - } - } - return mcp_server_config diff --git a/api/app/core/memory/agent/utils/messages_tool.py b/api/app/core/memory/agent/utils/messages_tool.py deleted file mode 100644 index 769e795a..00000000 --- a/api/app/core/memory/agent/utils/messages_tool.py +++ /dev/null @@ -1,260 +0,0 @@ -import json -import logging -import re -from typing import Any, List - -from app.core.logging_config import get_agent_logger -from langchain_core.messages import AnyMessage - -logger = get_agent_logger(__name__) - - -def _to_openai_messages(msgs: List[AnyMessage]) -> List[dict]: - out = [] - for m in msgs: - if hasattr(m, "content"): - out.append({"role": "user", "content": getattr(m, "content", "")}) - elif isinstance(m, dict) and "role" in m and "content" in m: - out.append(m) - else: - out.append({"role": "user", "content": str(m)}) - return out - - -def _extract_content(resp: Any) -> str: - """Extract LLM content and sanitize to raw JSON/text. - - - Supports both object and dict response shapes. - - Removes leading role labels (e.g., "Assistant:"). - - Strips Markdown code fences like ```json ... ```. - - Attempts to isolate the first valid JSON array/object block when extra text is present. - """ - - def _to_text(r: Any) -> str: - try: - # 对象形式: resp.choices[0].message.content - if hasattr(r, "choices") and getattr(r, "choices", None): - msg = r.choices[0].message - if hasattr(msg, "content"): - return msg.content - if isinstance(msg, dict) and "content" in msg: - return msg["content"] - # 字典形式: resp["choices"][0]["message"]["content"] - if isinstance(r, dict): - return r.get("choices", [{}])[0].get("message", {}).get("content", "") - except Exception: - pass - return str(r) - - def _clean_text(text: str) -> str: - s = str(text).strip() - # 移除可能的角色前缀 - s = re.sub(r"^\s*(Assistant|assistant)\s*:\s*", "", s) - # 提取 ```json ... ``` 代码块 - m = re.search(r"```json\s*(.*?)\s*```", s, flags=re.S | re.I) - if m: - s = m.group(1).strip() - # 如果仍然包含多余文本,尝试截取第一个 JSON 数组/对象片段 - if not (s.startswith("{") or s.startswith("[")): - left = s.find("[") - right = s.rfind("]") - if left != -1 and right != -1 and right > left: - s = s[left:right + 1].strip() - else: - left = s.find("{") - right = s.rfind("}") - if left != -1 and right != -1 and right > left: - s = s[left:right + 1].strip() - return s - - raw = _to_text(resp) - return _clean_text(raw) - -def Resolve_username(usermessages): - ''' - Extract username - Args: - usermessages: user name - - Returns: - - ''' - usermessages = usermessages.split('_')[1:] - sessionid = '_'.join(usermessages[:-1]) - return sessionid - - -# TODO: USE app.core.memory.src.utils.render_template instead -async def read_template_file(template_path: str) -> str: - """ - 读取模板文件 - - Args: - template_path: 模板文件路径 - - Returns: - 模板内容字符串 - - Note: - 建议使用 app.core.memory.utils.template_render 中的统一模板渲染功能 - """ - try: - with open(template_path, "r", encoding="utf-8") as f: - return f.read() - except FileNotFoundError: - logger.error(f"模板文件未找到: {template_path}") - raise - except IOError as e: - logger.error(f"读取模板文件失败: {template_path}, 错误: {str(e)}", exc_info=True) - raise - - -async def Problem_Extension_messages_deal(context): - ''' - Extract data - Args: - context: - Returns: - ''' - extent_quest = [] - original = context.get('original', '') - messages = context.get('context', '') - - # Handle empty or non-string messages - if not messages: - return extent_quest, original - - if isinstance(messages, str): - try: - messages = json.loads(messages) - except json.JSONDecodeError: - # If JSON parsing fails, return empty list - return extent_quest, original - - if isinstance(messages, list): - for message in messages: - question = message.get('question', '') - type = message.get('type', '') - extent_quest.append({"role": "user", "content": f"问题:{question};问题类型:{type}"}) - - return extent_quest, original - - -async def Retriev_messages_deal(context): - ''' - Extract data - Args: - context: - Returns: - ''' - logger.info(f"Retriev_messages_deal input: type={type(context)}, value={str(context)[:500]}") - - if isinstance(context, dict): - logger.info(f"Retriev_messages_deal: context is dict with keys={list(context.keys())}") - if 'context' in context or 'original' in context: - content = context.get('context', {}) - original = context.get('original', '') - logger.info(f"Retriev_messages_deal output: content_type={type(content)}, content={str(content)[:300]}, original='{original[:50] if original else ''}'") - return content, original - - # Return empty defaults if context is not a dict or doesn't have expected keys - logger.warning(f"Retriev_messages_deal: context missing expected keys, returning empty defaults") - return {}, '' - -async def Verify_messages_deal(context): - ''' - Extract data - Args: - context: - Returns: - ''' - - query = context['context']['Query'] - Query_small_list = context['context']['Expansion_issue'] - Result_small = [] - Query_small = [] - for i in Query_small_list: - Result_small.append(i['Answer_Small'][0]) - Query_small.append(i['Query_small']) - return Query_small, Result_small, query - - -async def Summary_messages_deal(context): - ''' - Extract data - Args: - context: - Returns: - ''' - messages = str(context).replace('\\n', '').replace('\n', '').replace('\\', '') - query = re.findall(r'"query": (.*?),', messages)[0] - query = query.replace('[', '').replace(']', '').strip() - matches = re.findall(r'"answer_small"\s*:\s*"(\[.*?\])"', messages) - answer_small_texts = [] - for m in matches: - try: - parsed = json.loads(m) - for item in parsed: - answer_small_texts.append(item.strip().replace('\\', '').replace('[', '').replace(']', '')) - except Exception: - answer_small_texts.append(m.strip().replace('\\', '').replace('[', '').replace(']', '')) - - return answer_small_texts, query - - -async def VerifyTool_messages_deal(context): - ''' - Extract data - Args: - context: - Returns: - ''' - messages = str(context).replace('\\n', '').replace('\n', '').replace('\\', '') - content_messages = messages.split('"context":')[1].replace('""', '"') - messages = str(content_messages).split("name='Retrieve'")[0] - query = re.findall('"Query": "(.*?)"', messages)[0] - Query_small = re.findall('"Query_small": "(.*?)"', messages) - Result_small = re.findall('"Result_small": "(.*?)"', messages) - return Query_small, Result_small, query - - -async def Retrieve_Summary_messages_deal(context): - pass - - -async def Retrieve_verify_tool_messages_deal(context, history, query): - ''' - Extract data - Args: - context: - Returns: - ''' - results = [] - # 统一转为字符串,避免 None 或非字符串导致正则报错 - text = str(context) - blocks = re.findall(r'\{(.*?)\}', text, flags=re.S) - for block in blocks: - query_small = re.search(r'"Query_small"\s*:\s*"([^"]*)"', block) - answer_small = re.search(r'"Answer_Small"\s*:\s*(\[[^\]]*\])', block) - status = re.search(r'"status"\s*:\s*"([^"]*)"', block) - query_answer = re.search(r'"Query_answer"\s*:\s*"([^"]*)"', block) - - results.append({ - "query_small": query_small.group(1) if query_small else None, - "answer_small": answer_small.group(1) if answer_small else None, - # 将缺失的 status 统一为空字符串,后续用字符串判定,避免 NoneType 错误 - "status": status.group(1) if status else "", - "query_answer": query_answer.group(1) if query_answer else None - }) - result = [] - for r in results: - # 统一按字符串判定状态,兼容大小写和缺失情况 - status_str = str(r.get('status', '')).strip().lower() - if status_str == 'false': - continue - else: - result.append(r) - split_result = 'failed' if not result else 'success' - result = {"data": {"query": query, "expansion_issue": result}, "split_result": split_result, "reason": "", - "history": history} - return result diff --git a/api/app/core/memory/agent/utils/messages_tools.py b/api/app/core/memory/agent/utils/messages_tools.py new file mode 100644 index 00000000..db95319f --- /dev/null +++ b/api/app/core/memory/agent/utils/messages_tools.py @@ -0,0 +1,194 @@ +from typing import List, Dict, Any +from app.core.logging_config import get_agent_logger + +logger = get_agent_logger(__name__) +async def read_template_file(template_path: str) -> str: + """ + 读取模板文件 + + Args: + template_path: 模板文件路径 + + Returns: + 模板内容字符串 + + Note: + 建议使用 app.core.memory.utils.template_render 中的统一模板渲染功能 + """ + try: + with open(template_path, "r", encoding="utf-8") as f: + return f.read() + except FileNotFoundError: + logger.error(f"模板文件未找到: {template_path}") + raise + except IOError as e: + logger.error(f"读取模板文件失败: {template_path}, 错误: {str(e)}", exc_info=True) + raise + +def reorder_output_results(results: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + """ + 重新排序输出结果,将 retrieval_summary 类型的数据放到最后面 + + Args: + results: 原始输出结果列表 + + Returns: + 重新排序后的结果列表 + """ + retrieval_summaries = [] + other_results = [] + + # 分离 retrieval_summary 和其他类型的结果 + for result in results: + if 'summary' in result.get('type'): + retrieval_summaries.append(result) + else: + other_results.append(result) + + # 将 retrieval_summary 放到最后 + return other_results + retrieval_summaries + +def optimize_search_results(intermediate_outputs): + """ + 优化检索结果,合并多个搜索结果,过滤空结果,统一格式 + + Args: + intermediate_outputs: 原始的中间输出列表 + + Returns: + 优化后的检索结果列表 + """ + optimized_results = [] + + for item in intermediate_outputs: + if not item or item == [] or item == {}: + continue + + # 检查是否是搜索结果类型 + if isinstance(item, dict) and item.get('type') == 'search_result': + raw_results = item.get('raw_results', {}) + + # 如果 raw_results 为空,跳过 + if not raw_results or raw_results == [] or raw_results == {}: + continue + + # 创建优化后的结果结构 + optimized_item = { + "type": "search_result", + "title": f"检索结果 ({item.get('index', 1)}/{item.get('total', 1)})", + "query": item.get('query', ''), + "raw_results": {}, + "index": item.get('index', 1), + "total": item.get('total', 1) + } + + # 合并所有搜索结果类型到一个 raw_results 中 + merged_raw_results = {} + + # 处理 time_search + if 'time_search' in raw_results and raw_results['time_search']: + merged_raw_results['time_search'] = raw_results['time_search'] + + # 处理 keyword_search + if 'keyword_search' in raw_results and raw_results['keyword_search']: + merged_raw_results['keyword_search'] = raw_results['keyword_search'] + + # 处理 embedding_search + if 'embedding_search' in raw_results and raw_results['embedding_search']: + merged_raw_results['embedding_search'] = raw_results['embedding_search'] + + # 处理 combined_summary + if 'combined_summary' in raw_results and raw_results['combined_summary']: + merged_raw_results['combined_summary'] = raw_results['combined_summary'] + + # 处理 reranked_results + if 'reranked_results' in raw_results and raw_results['reranked_results']: + merged_raw_results['reranked_results'] = raw_results['reranked_results'] + + # 如果合并后的结果不为空,添加到优化结果中 + if merged_raw_results: + optimized_item['raw_results'] = merged_raw_results + optimized_results.append(optimized_item) + else: + # 非搜索结果类型,直接添加 + optimized_results.append(item) + + return optimized_results + + +def merge_multiple_search_results(intermediate_outputs): + """ + 将多个搜索结果合并为一个统一的搜索结果 + + Args: + intermediate_outputs: 原始的中间输出列表 + + Returns: + 合并后的结果列表 + """ + search_results = [] + other_results = [] + + # 分离搜索结果和其他结果 + for item in intermediate_outputs: + if isinstance(item, dict) and item.get('type') == 'search_result': + raw_results = item.get('raw_results', {}) + # 只保留有内容的搜索结果 + if raw_results and raw_results != [] and raw_results != {}: + search_results.append(item) + else: + other_results.append(item) + + # 如果没有搜索结果,返回原始结果 + if not search_results: + return intermediate_outputs + + # 如果只有一个搜索结果,优化格式后返回 + if len(search_results) == 1: + optimized = optimize_search_results(search_results) + return other_results + optimized + + # 合并多个搜索结果 + merged_raw_results = {} + all_queries = [] + + for result in search_results: + query = result.get('query', '') + if query: + all_queries.append(query) + + raw_results = result.get('raw_results', {}) + + # 合并各种搜索类型的结果 + for search_type in ['time_search', 'keyword_search', 'embedding_search', 'combined_summary', + 'reranked_results']: + if search_type in raw_results and raw_results[search_type]: + if search_type not in merged_raw_results: + merged_raw_results[search_type] = raw_results[search_type] + else: + # 如果是字典类型,需要合并 + if isinstance(raw_results[search_type], dict) and isinstance(merged_raw_results[search_type], dict): + for key, value in raw_results[search_type].items(): + if key not in merged_raw_results[search_type]: + merged_raw_results[search_type][key] = value + elif isinstance(value, list) and isinstance(merged_raw_results[search_type][key], list): + merged_raw_results[search_type][key].extend(value) + elif isinstance(raw_results[search_type], list): + if isinstance(merged_raw_results[search_type], list): + merged_raw_results[search_type].extend(raw_results[search_type]) + else: + merged_raw_results[search_type] = raw_results[search_type] + + # 创建合并后的结果 + if merged_raw_results: + merged_result = { + "type": "search_result", + "title": f"合并检索结果 (共{len(search_results)}个查询)", + "query": " | ".join(all_queries), + "raw_results": merged_raw_results, + "index": 1, + "total": 1 + } + return other_results + [merged_result] + + return other_results diff --git a/api/app/core/memory/agent/utils/model_tool.py b/api/app/core/memory/agent/utils/model_tool.py deleted file mode 100644 index 969a2a91..00000000 --- a/api/app/core/memory/agent/utils/model_tool.py +++ /dev/null @@ -1,38 +0,0 @@ - - -# project_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) -# sys.path.insert(0, project_root) - -# load_dotenv() - -# async def llm_client_chat(messages: List[dict]) -> str: -# """使用 OpenAI 兼容接口进行对话,返回内容字符串。""" -# try: -# cfg = get_model_config(SELECTED_LLM_ID) -# rb_config = RedBearModelConfig( -# model_name=cfg["model_name"], -# provider=cfg["provider"], -# api_key=cfg["api_key"], -# base_url=cfg["base_url"], -# ) -# client = OpenAIClient(model_config=rb_config, type_="chat") - -# except Exception as e: -# logger.error(f"获取模型配置失败:{e}") -# err = f"获取模型配置失败:{str(e)}。请检查!!!" -# return err -# try: -# response = await client.chat(messages) -# print(f"model_tool's llm_client_chat response ======>:\n {response}") -# return _extract_content(response) -# # return _extract_content(result) -# except Exception as e: -# logger.error(f"LLM调用失败:{str(e)}。请检查 model_name、api_key、api_base 是否正确。") -# return f"LLM调用失败:{str(e)}。请检查 model_name、api_key、api_base 是否正确。" - -# async def main(image_url): -# await llm_client_chat(image_url) -# -# # 运行主函数 -# asyncio.run(main(['https://dashscope.oss-cn-beijing.aliyuncs.com/samples/audio/paraformer/hello_world_male2.wav'])) -# diff --git a/api/app/core/memory/agent/utils/multimodal.py b/api/app/core/memory/agent/utils/multimodal.py deleted file mode 100644 index 0fc52634..00000000 --- a/api/app/core/memory/agent/utils/multimodal.py +++ /dev/null @@ -1,131 +0,0 @@ -""" -Multimodal input processor for handling image and audio content. - -This module provides utilities for detecting and processing multimodal inputs -(images and audio files) by converting them to text using appropriate models. -""" - -import logging -from typing import List - -from app.core.memory.agent.multimodal.speech_model import Vico_recognition -from app.core.memory.agent.utils.llm_tools import picture_model_requests - -logger = logging.getLogger(__name__) - - -class MultimodalProcessor: - """ - Processor for handling multimodal inputs (images and audio). - - This class detects image and audio file paths in input content and converts - them to text using appropriate recognition models. - """ - - # Supported file extensions - IMAGE_EXTENSIONS = ['.jpg', '.png'] - AUDIO_EXTENSIONS = [ - 'aac', 'amr', 'avi', 'flac', 'flv', 'm4a', 'mkv', 'mov', - 'mp3', 'mp4', 'mpeg', 'ogg', 'opus', 'wav', 'webm', 'wma', 'wmv' - ] - - def __init__(self): - """Initialize the multimodal processor.""" - pass - - def is_image(self, content: str) -> bool: - """ - Check if content is an image file path. - - Args: - content: Input string to check - - Returns: - True if content ends with a supported image extension - - Examples: - >>> processor = MultimodalProcessor() - >>> processor.is_image("photo.jpg") - True - >>> processor.is_image("document.pdf") - False - """ - if not isinstance(content, str): - return False - - content_lower = content.lower() - return any(content_lower.endswith(ext) for ext in self.IMAGE_EXTENSIONS) - - def is_audio(self, content: str) -> bool: - """ - Check if content is an audio file path. - - Args: - content: Input string to check - - Returns: - True if content ends with a supported audio extension - - Examples: - >>> processor = MultimodalProcessor() - >>> processor.is_audio("recording.mp3") - True - >>> processor.is_audio("video.mp4") - True - >>> processor.is_audio("document.txt") - False - """ - if not isinstance(content, str): - return False - - content_lower = content.lower() - return any(content_lower.endswith(f'.{ext}') for ext in self.AUDIO_EXTENSIONS) - - async def process_input(self, content: str) -> str: - """ - Process input content, converting images/audio to text if needed. - - This method detects if the input is an image or audio file and converts - it to text using the appropriate recognition model. If processing fails - or the content is not multimodal, it returns the original content. - - Args: - content: Input string (may be file path or regular text) - - Returns: - Text content (original or converted from image/audio) - - Examples: - >>> processor = MultimodalProcessor() - >>> await processor.process_input("photo.jpg") - "Recognized text from image..." - - >>> await processor.process_input("Hello world") - "Hello world" - """ - if not isinstance(content, str): - logger.warning(f"[MultimodalProcessor] Content is not a string: {type(content)}") - return str(content) - - try: - # Check for image input - if self.is_image(content): - logger.info(f"[MultimodalProcessor] Detected image input: {content}") - result = await picture_model_requests(content) - logger.info(f"[MultimodalProcessor] Image recognition result: {result[:100]}...") - return result - - # Check for audio input - if self.is_audio(content): - logger.info(f"[MultimodalProcessor] Detected audio input: {content}") - result = await Vico_recognition([content]).run() - logger.info(f"[MultimodalProcessor] Audio recognition result: {result[:100]}...") - return result - - except Exception as e: - logger.error(f"[MultimodalProcessor] Error processing multimodal input: {e}", exc_info=True) - logger.info("[MultimodalProcessor] Falling back to original content") - return content - - # Return original content if not multimodal - return content diff --git a/api/app/core/memory/agent/utils/performance_monitor.py b/api/app/core/memory/agent/utils/performance_monitor.py new file mode 100644 index 00000000..d2d9fdfa --- /dev/null +++ b/api/app/core/memory/agent/utils/performance_monitor.py @@ -0,0 +1,56 @@ + +import time +import json +from collections import defaultdict +from typing import Dict, List +from app.core.logging_config import get_agent_logger + +logger = get_agent_logger(__name__) + +class ProblemExtensionMonitor: + """Problem_Extension性能监控器""" + + def __init__(self): + self.metrics = defaultdict(list) + self.slow_queries = [] + self.error_count = 0 + + def record_execution(self, duration: float, question_count: int, success: bool): + """记录执行指标""" + self.metrics['durations'].append(duration) + self.metrics['question_counts'].append(question_count) + + if not success: + self.error_count += 1 + + # 记录慢查询(超过10秒) + if duration > 10.0: + self.slow_queries.append({ + 'duration': duration, + 'question_count': question_count, + 'timestamp': time.time() + }) + + def get_stats(self) -> Dict: + """获取统计信息""" + durations = self.metrics['durations'] + if not durations: + return {"message": "暂无数据"} + + return { + "total_executions": len(durations), + "avg_duration": sum(durations) / len(durations), + "max_duration": max(durations), + "min_duration": min(durations), + "slow_queries_count": len(self.slow_queries), + "error_rate": self.error_count / len(durations) if durations else 0, + "recent_slow_queries": self.slow_queries[-5:] # 最近5个慢查询 + } + + def log_stats(self): + """记录统计信息到日志""" + stats = self.get_stats() + logger.info(f"Problem_Extension性能统计: {json.dumps(stats, indent=2)}") + +# 全局监控器实例 +performance_monitor = ProblemExtensionMonitor() diff --git a/api/app/core/memory/agent/utils/prompt/Problem_Extension_prompt_simplified.jinja2 b/api/app/core/memory/agent/utils/prompt/Problem_Extension_prompt_simplified.jinja2 new file mode 100644 index 00000000..a0e21fbd --- /dev/null +++ b/api/app/core/memory/agent/utils/prompt/Problem_Extension_prompt_simplified.jinja2 @@ -0,0 +1,81 @@ + +你是一个高效的问题拆分助手,任务是根据用户提供的原始问题和问题类型,生成可操作的扩展问题,用于精确回答原问题。请严格遵循以下规则: + +角色: +- 你是“问题拆分专家”,专注于逻辑、信息完整性和可操作性。 +- 你能够结合【历史信息】、【上下文】、【背景知识】进行分析,以保持问题拆分的连贯性和相关性。 +- 如果历史信息或上下文与当前问题无关,可忽略。 + +--- + +### 历史信息参考 +在生成扩展问题时,你可以参考以下历史数据(如果提供): +- 历史对话或任务的主题; +- 历史中出现的关键实体(时间、人物、地点、研究主题等); +- 历史中已解答的问题(避免重复); +- 历史推理链(保持逻辑一致性)。 + +> 如果没有提供历史信息,则仅根据当前输入问题进行分析。 +输入历史信息内容:{{history}} + +## User Input +{% if questions is string %} +{{ questions }} +{% else %} +{% for question in questions %} +- {{ question }} +{% endfor %} +{% endif %} + +需求: +- 如果问题是单跳问题(单步可答),直接保留原问题提取重要提问部分作为拆分/扩展问题。 +- 如果问题是多跳问题(需多个信息点才能回答),对问题进行扩展拆分。 +- 扩展问题必须完整覆盖原问题的所有关键要素,包括时间、主体、动作、目标等,不得遗漏。 +- 扩展问题不得冗余:避免重复询问相同信息或过度拆分同一主题。 +- 扩展问题必须高度相关:每个子问题直接服务于原问题,不引入未提及的新概念、人物或细节。 +- 扩展问题必须可操作:每个子问题能在有限资源下独立解答。 +- 子问题数量不超过4个。 +- 拆分问题的时候可以考虑输入的历史内容,以保持逻辑连贯。 + 比如:输入历史信息内容:[{'Query': '4月27日,我和你推荐过一本书,书名是什么?', 'ANswer': '张曼玉推荐了《小王子》'}] + 拆分问题:4月27日,我和你推荐过一本书,书名是什么?,可以拆分为:4月27日,张曼玉推荐过一本书,书名是什么? + + + +输出要求: +- 仅输出 JSON 数组,不要包含任何解释或代码块。 +- 每个元素包含: + - `original_question`: 原始问题 + - `extended_question`: 扩展后的问题 + - `type`: 类型(事实检索/澄清/定义/比较/行动建议) + - `reason`: 生成该扩展问题的简短理由 +- 使用标准 ASCII 双引号,无换行;确保字符串正确关闭并以逗号分隔。 + +示例: +输入: +[ + "问题:今年诺贝尔物理学奖的获奖者是谁,他们因为什么贡献获奖?;问题类型:多跳", +] + +输出: +[ + { + "original_question": "今年诺贝尔物理学奖的获奖者是谁,他们因为什么贡献获奖?", + "extended_question": "今年诺贝尔物理学奖的获奖者有哪些人?", + "type": "多跳", + "reason": "输出原问题的关键要素" + }, + { + "original_question": "今年诺贝尔物理学奖的获奖者是谁,他们因为什么贡献获奖?", + "extended_question": "今年诺贝尔物理学奖的获奖者是因哪些具体贡献获奖的?", + "type": "多跳", + "reason": "输出原问题的关键要素" + } +] +**Output format** +**CRITICAL JSON FORMATTING REQUIREMENTS:** +1. Use only standard ASCII double quotes (") for JSON structure - never use Chinese quotation marks ("") or other Unicode quotes +2. If the extracted statement text contains quotation marks, escape them properly using backslashes (\") +3. Ensure all JSON strings are properly closed and comma-separated +4. Do not include line breaks within JSON string values + +The output language should always be the same as the input language.{{ json_schema }} diff --git a/api/app/core/memory/agent/utils/prompt/Retrieve_Summary_prompt.jinja2 b/api/app/core/memory/agent/utils/prompt/Retrieve_Summary_prompt.jinja2 index 1fa71df3..5fbe8574 100644 --- a/api/app/core/memory/agent/utils/prompt/Retrieve_Summary_prompt.jinja2 +++ b/api/app/core/memory/agent/utils/prompt/Retrieve_Summary_prompt.jinja2 @@ -1,13 +1,10 @@ # 角色 你是一个专业的问答助手,擅长基于检索信息和历史对话回答用户问题。 - # 任务 根据提供的上下文信息回答用户的问题。 - # 输入信息 - 历史对话:{{history}} - 检索信息:{{retrieve_info}} - ## User Query {{query}} diff --git a/api/app/core/memory/agent/utils/prompt/split_verify_prompt.jinja2 b/api/app/core/memory/agent/utils/prompt/split_verify_prompt.jinja2 index f4d4665c..d6ad8cab 100644 --- a/api/app/core/memory/agent/utils/prompt/split_verify_prompt.jinja2 +++ b/api/app/core/memory/agent/utils/prompt/split_verify_prompt.jinja2 @@ -9,8 +9,8 @@ 3. 判断Answer_Small和Query_Small之间分析出来的关系状态 4. 如果是True保留,否则不要相对应的问题和回答 5. 输出,需要严格按照模版 -输入:{{history}} -历史消息:{"history":{{sentence}}} +输入:{{sentence}} +历史消息:{"history":{{history}}} ### 第一步 获取用户的输入 获取用户的输入提取对应的Query_Small和Answer_Small ### 第二步 分析验证 diff --git a/api/app/core/memory/agent/utils/session_tools.py b/api/app/core/memory/agent/utils/session_tools.py new file mode 100644 index 00000000..b2d4f0ff --- /dev/null +++ b/api/app/core/memory/agent/utils/session_tools.py @@ -0,0 +1,169 @@ +""" +Session Service for managing user sessions and conversation history. + +This service provides clean Redis interactions with error handling and +session management utilities. +""" +from typing import List, Optional + +from app.core.logging_config import get_agent_logger +from app.core.memory.agent.utils.redis_tool import RedisSessionStore + + +logger = get_agent_logger(__name__) + + +class SessionService: + """Service for managing user sessions and conversation history.""" + + def __init__(self, store: RedisSessionStore): + """ + Initialize the session service. + + Args: + store: Redis session store instance + """ + self.store = store + logger.info("SessionService initialized") + + def resolve_user_id(self, session_string: str) -> str: + """ + Extract user ID from session string. + + Handles formats like: + - 'call_id_user123' -> 'user123' + - 'prefix_id_user456_suffix' -> 'user456_suffix' + + Args: + session_string: Session identifier string + + Returns: + Extracted user ID + """ + try: + # Split by '_id_' and take everything after it + parts = session_string.split('_id_') + if len(parts) > 1: + return parts[1] + + # Fallback: return original string + return session_string + + except Exception as e: + logger.warning( + f"Failed to parse user ID from session string '{session_string}': {e}" + ) + return session_string + + async def get_history( + self, + user_id: str, + apply_id: str, + group_id: str + ) -> List[dict]: + """ + Retrieve conversation history from Redis. + + Args: + user_id: User identifier + apply_id: Application identifier + group_id: Group identifier + + Returns: + List of conversation history items with Query and Answer keys + Returns empty list if no history found or on error + """ + try: + history = self.store.find_user_apply_group(user_id, apply_id, group_id) + + # Validate history structure + if not isinstance(history, list): + logger.warning( + f"Invalid history format for user {user_id}, " + f"apply {apply_id}, group {group_id}: expected list, got {type(history)}" + ) + return [] + + return history + + except Exception as e: + logger.error( + f"Failed to retrieve history for user {user_id}, " + f"apply {apply_id}, group {group_id}: {e}", + exc_info=True + ) + # Return empty list on error to allow execution to continue + return [] + + async def save_session( + self, + user_id: str, + query: str, + apply_id: str, + group_id: str, + ai_response: str + ) -> Optional[str]: + """ + Save conversation turn to Redis. + + Args: + user_id: User identifier + query: User query/message + apply_id: Application identifier + group_id: Group identifier + ai_response: AI response/answer + + Returns: + Session ID if successful, None on error + """ + try: + # Validate required fields + if not user_id: + logger.warning("Cannot save session: user_id is empty") + return None + + if not query: + logger.warning("Cannot save session: query is empty") + return None + + # Save session + session_id = self.store.save_session( + userid=user_id, + messages=query, + apply_id=apply_id, + group_id=group_id, + aimessages=ai_response + ) + + logger.info(f"Session saved successfully: {session_id}") + return session_id + + except Exception as e: + logger.error( + f"Failed to save session for user {user_id}: {e}", + exc_info=True + ) + return None + + async def cleanup_duplicates(self) -> int: + """ + Remove duplicate session entries. + + Duplicates are identified by matching: + - sessionid + - user_id (id field) + - group_id + - messages + - aimessages + + Returns: + Number of duplicate sessions deleted + """ + try: + deleted_count = self.store.delete_duplicate_sessions() + logger.info(f"Cleaned up {deleted_count} duplicate sessions") + return deleted_count + + except Exception as e: + logger.error(f"Failed to cleanup duplicate sessions: {e}", exc_info=True) + return 0 diff --git a/api/app/core/memory/agent/utils/template_tools.py b/api/app/core/memory/agent/utils/template_tools.py new file mode 100644 index 00000000..854c5383 --- /dev/null +++ b/api/app/core/memory/agent/utils/template_tools.py @@ -0,0 +1,117 @@ +""" +Template Service for loading and rendering Jinja2 templates. + +This service provides centralized template management with caching and error handling. +""" +# 标准库 +import os +from functools import lru_cache + +from jinja2 import Environment, FileSystemLoader, Template, TemplateNotFound + +from app.core.logging_config import get_agent_logger, log_prompt_rendering + + +logger = get_agent_logger(__name__) + + +class TemplateRenderError(Exception): + """Exception raised when template rendering fails.""" + + def __init__(self, template_name: str, error: Exception, variables: dict): + self.template_name = template_name + self.error = error + self.variables = variables + super().__init__( + f"Failed to render template '{template_name}': {str(error)}" + ) + + +class TemplateService: + """Service for loading and rendering Jinja2 templates with caching.""" + + def __init__(self, template_root: str): + """ + Initialize the template service. + + Args: + template_root: Root directory containing template files + """ + self.template_root = template_root + self.env = Environment( + loader=FileSystemLoader(template_root), + autoescape=False # Disable autoescape for prompt templates + ) + logger.info(f"TemplateService initialized with root: {template_root}") + + @lru_cache(maxsize=128) + def _load_template(self, template_name: str) -> Template: + """ + Load a template from disk with caching. + + Args: + template_name: Relative path to template file + + Returns: + Loaded Jinja2 Template object + + Raises: + TemplateNotFound: If template file doesn't exist + """ + try: + return self.env.get_template(template_name) + except TemplateNotFound as e: + expected_path = os.path.join(self.template_root, template_name) + logger.error( + f"Template not found: {template_name}. " + f"Expected path: {expected_path}" + ) + raise + + async def render_template( + self, + template_name: str, + operation_name: str, + **variables + ) -> str: + """ + Load and render a Jinja2 template. + + Args: + template_name: Relative path to template file + operation_name: Name for logging (e.g., "split_the_problem") + **variables: Template variables to render + + Returns: + Rendered template string + + Raises: + TemplateRenderError: If template loading or rendering fails + """ + try: + # Load template (cached) + template = self._load_template(template_name) + + # Render template + rendered = template.render(**variables) + + # Log rendered prompt + log_prompt_rendering(operation_name, rendered) + + return rendered + + except TemplateNotFound as e: + logger.error( + f"Template rendering failed for {operation_name} " + f"({template_name}): Template not found", + exc_info=True + ) + raise TemplateRenderError(template_name, e, variables) + + except Exception as e: + logger.error( + f"Template rendering failed for {operation_name} " + f"({template_name}): {e}", + exc_info=True + ) + raise TemplateRenderError(template_name, e, variables) diff --git a/api/app/core/memory/agent/utils/type_classifier.py b/api/app/core/memory/agent/utils/type_classifier.py index 3e5358bd..f1df6f04 100644 --- a/api/app/core/memory/agent/utils/type_classifier.py +++ b/api/app/core/memory/agent/utils/type_classifier.py @@ -1,10 +1,9 @@ """ Type classification utility for distinguishing read/write operations. """ -from app.core.config import settings from app.core.logging_config import get_agent_logger, log_prompt_rendering from app.core.memory.agent.utils.llm_tools import PROJECT_ROOT_ -from app.core.memory.agent.utils.messages_tool import read_template_file +from app.core.memory.agent.utils.messages_tools import read_template_file from app.core.memory.utils.llm.llm_utils import MemoryClientFactory from app.db import get_db_context from jinja2 import Template diff --git a/api/app/core/memory/agent/utils/write_to_database.py b/api/app/core/memory/agent/utils/write_to_database.py deleted file mode 100644 index bd78fe9d..00000000 --- a/api/app/core/memory/agent/utils/write_to_database.py +++ /dev/null @@ -1,49 +0,0 @@ -import os -import uuid -from datetime import datetime -from typing import Any -from sqlalchemy.orm import Session -import logging -import json - -from app.db import get_db -from app.models.retrieval_info import RetrievalInfo - -logger = logging.getLogger(__name__) - -async def write_to_database(host_id: uuid.UUID, data: Any) -> str: - """ - 将数据写入数据库 - :param host_id: 宿主 ID - :param data: 要写入的数据 - :return: 写入数据库的结果 - """ - # 从数据库会话中获取会话 - db: Session = next(get_db()) - try: - if isinstance(data, (dict, list)): - serialized = json.dumps(data, ensure_ascii=False) - elif isinstance(data, str): - serialized = data - else: - serialized = str(data) - - new_retrieval_info = RetrievalInfo( - # host_id=host_id, - host_id=uuid.UUID("2f6ff1eb-50c7-4765-8e89-e4566be19122"), - retrieve_info=serialized, - created_at=datetime.now() - ) - db.add(new_retrieval_info) - db.commit() - logger.info(f"success to write data to database, host_id: {host_id}, retrieve_info: {serialized}") - return "success to write data to database" - except Exception as e: - db.rollback() - logger.error(f"failed to write data to database, host_id: {host_id}, retrieve_info: {data}, error: {e}") - raise e - finally: - try: - db.close() - except Exception: - pass diff --git a/api/app/core/memory/agent/utils/write_tools.py b/api/app/core/memory/agent/utils/write_tools.py index f09b35e8..53c941ad 100644 --- a/api/app/core/memory/agent/utils/write_tools.py +++ b/api/app/core/memory/agent/utils/write_tools.py @@ -7,14 +7,12 @@ pipeline. Only MemoryConfig is needed - clients are constructed internally. import time from datetime import datetime +from dotenv import load_dotenv + from app.core.logging_config import get_agent_logger from app.core.memory.agent.utils.get_dialogs import get_chunked_dialogs -from app.core.memory.storage_services.extraction_engine.extraction_orchestrator import ( - ExtractionOrchestrator, -) -from app.core.memory.storage_services.extraction_engine.knowledge_extraction.memory_summary import ( - memory_summary_generation, -) +from app.core.memory.storage_services.extraction_engine.extraction_orchestrator import ExtractionOrchestrator +from app.core.memory.storage_services.extraction_engine.knowledge_extraction.memory_summary import memory_summary_generation from app.core.memory.utils.llm.llm_utils import MemoryClientFactory from app.core.memory.utils.log.logging_utils import log_time from app.db import get_db_context @@ -23,7 +21,7 @@ from app.repositories.neo4j.add_nodes import add_memory_summary_nodes from app.repositories.neo4j.graph_saver import save_dialog_and_statements_to_neo4j from app.repositories.neo4j.neo4j_connector import Neo4jConnector from app.schemas.memory_config_schema import MemoryConfig -from dotenv import load_dotenv + load_dotenv() diff --git a/api/app/core/rag/graphrag/general/leiden.py b/api/app/core/rag/graphrag/general/leiden.py index 8238f0f1..a6cf47d0 100644 --- a/api/app/core/rag/graphrag/general/leiden.py +++ b/api/app/core/rag/graphrag/general/leiden.py @@ -8,11 +8,12 @@ Reference: import logging import html from typing import Any, cast -from graspologic.partition import hierarchical_leiden -from graspologic.utils import largest_connected_component +# from graspologic.partition import hierarchical_leiden +# from graspologic.utils import largest_connected_component import networkx as nx from networkx import is_empty - +hierarchical_leiden='' +largest_connected_component='' def _stabilize_graph(graph: nx.Graph) -> nx.Graph: """Ensure an undirected graph with the same relationships will always be read the same way.""" diff --git a/api/app/services/memory_agent_service.py b/api/app/services/memory_agent_service.py index 2d78d796..ccebdcd6 100644 --- a/api/app/services/memory_agent_service.py +++ b/api/app/services/memory_agent_service.py @@ -13,26 +13,26 @@ from threading import Lock from typing import Any, AsyncGenerator, Dict, List, Optional import redis +from langchain_core.messages import HumanMessage + from app.core.config import settings from app.core.logging_config import get_config_logger, get_logger from app.core.memory.agent.langgraph_graph.read_graph import make_read_graph from app.core.memory.agent.langgraph_graph.write_graph import make_write_graph from app.core.memory.agent.logger_file.log_streamer import LogStreamer -from app.core.memory.agent.utils.mcp_tools import get_mcp_server_config +from app.core.memory.agent.utils.messages_tools import merge_multiple_search_results, reorder_output_results from app.core.memory.agent.utils.type_classifier import status_typle from app.core.memory.analytics.hot_memory_tags import get_hot_memory_tags +from app.repositories.memory_short_repository import ShortTermMemoryRepository from app.core.memory.utils.llm.llm_utils import MemoryClientFactory from app.db import get_db_context from app.models.knowledge_model import Knowledge, KnowledgeType -from app.repositories.memory_short_repository import ShortTermMemoryRepository from app.repositories.neo4j.neo4j_connector import Neo4jConnector from app.schemas.memory_config_schema import ConfigurationError from app.services.memory_config_service import MemoryConfigService from app.services.memory_konwledges_server import ( write_rag, ) -from langchain_mcp_adapters.client import MultiServerMCPClient -from langchain_mcp_adapters.tools import load_mcp_tools from pydantic import BaseModel, Field from sqlalchemy import func from sqlalchemy.orm import Session @@ -55,18 +55,16 @@ class MemoryAgentService: self.user_locks: Dict[str, Lock] = {} self.locks_lock = Lock() - def writer_messages_deal(self,messages,start_time,group_id,config_id,message): - messages = str(messages).replace("'", '"').replace('\\n', '').replace('\n', '').replace('\\', '') - countext = re.findall(r'"status": "(.*?)",', messages)[0] + def writer_messages_deal(self, messages, start_time, group_id, config_id, message, context): duration = time.time() - start_time - if countext == 'success': + if str(messages) == 'success': logger.info(f"Write operation successful for group {group_id} with config_id {config_id}") # 记录成功的操作 if audit_logger: audit_logger.log_operation(operation="WRITE", config_id=config_id, group_id=group_id, success=True, duration=duration, details={"message_length": len(message)}) - return countext + return context else: logger.warning(f"Write operation failed for group {group_id}") @@ -150,8 +148,26 @@ class MemoryAgentService: else: status = "unknown" + # Add database connection pool status + try: + from app.db import get_pool_status + pool_status = get_pool_status() + logger.info(f"Database pool status: {pool_status}") + + # Check if pool usage is too high + if pool_status.get("usage_percent", 0) > 80: + logger.warning(f"High database pool usage: {pool_status['usage_percent']}%") + status = "warning" + + except Exception as e: + logger.error(f"Failed to get pool status: {e}") + pool_status = {"error": str(e)} + logger.info(f"Health status: {status}") - return {"status": status} + return { + "status": status, + "database_pool": pool_status + } def get_log_content(self) -> str: """ @@ -308,54 +324,42 @@ class MemoryAgentService: audit_logger.log_operation(operation="WRITE", config_id=config_id, group_id=group_id, success=False, duration=duration, error=error_msg) raise ValueError(error_msg) - mcp_config = get_mcp_server_config() - client = MultiServerMCPClient(mcp_config) - - if storage_type == "rag": - result = await write_rag(group_id, message, user_rag_memory_id) - return result - else: - async with client.session("data_flow") as session: - logger.debug("Connected to MCP Server: data_flow") - tools = await load_mcp_tools(session) - workflow_errors = [] # Track errors from workflow - - # Pass memory_config to the graph workflow - async with make_write_graph(group_id, tools, group_id, group_id, memory_config=memory_config) as graph: - logger.debug("Write graph created successfully") + try: + if storage_type == "rag": + result = await write_rag(group_id, message, user_rag_memory_id) + return result + else: + async with make_write_graph() as graph: config = {"configurable": {"thread_id": group_id}} + # 初始状态 - 包含所有必要字段 + initial_state = {"messages": [HumanMessage(content=message)], "group_id": group_id, + "memory_config": memory_config} - async for event in graph.astream( - {"messages": message, "memory_config": memory_config, "errors": []}, - stream_mode="values", + # 获取节点更新信息 + async for update_event in graph.astream( + initial_state, + stream_mode="updates", config=config ): - messages = event.get('messages') - # Capture any errors from the state - if event.get('errors'): - workflow_errors.extend(event.get('errors', [])) - - # Check for workflow errors - if workflow_errors: - error_details = "; ".join([f"{e['tool']}: {e['error']}" for e in workflow_errors]) - logger.error(f"Write workflow failed with errors: {error_details}") - - if audit_logger: - duration = time.time() - start_time - audit_logger.log_operation( - operation="WRITE", - config_id=config_id, - group_id=group_id, - success=False, - duration=duration, - error=error_details - ) - - raise ValueError(f"Write workflow failed: {error_details}") - - return self.writer_messages_deal(messages, start_time, group_id, config_id, message) - + for node_name, node_data in update_event.items(): + if 'save_neo4j' == node_name: + massages = node_data + massagesstatus = massages.get('write_result')['status'] + contents = massages.get('write_result') + return self.writer_messages_deal(massagesstatus, start_time, group_id, config_id, message, contents) + except Exception as e: + # Ensure proper error handling and logging + error_msg = f"Write operation failed: {str(e)}" + logger.error(error_msg) + if audit_logger: + duration = time.time() - start_time + audit_logger.log_operation(operation="WRITE", config_id=config_id, group_id=group_id, success=False, duration=duration, error=error_msg) + raise ValueError(error_msg) + + + + async def read_memory( self, group_id: str, @@ -394,8 +398,9 @@ class MemoryAgentService: import time start_time = time.time() - ori_message=message end_user_id=group_id + ori_message=message + # Resolve config_id if None using end_user's connected config if config_id is None: try: @@ -408,15 +413,15 @@ class MemoryAgentService: raise # Re-raise our specific error logger.error(f"Failed to get connected config for end_user {group_id}: {e}") raise ValueError(f"Unable to determine memory configuration for end_user {group_id}: {e}") - + logger.info(f"Read operation for group {group_id} with config_id {config_id}") - + # 导入审计日志记录器 try: from app.core.memory.utils.log.audit_logger import audit_logger except ImportError: audit_logger = None - + # Get group lock to prevent concurrent processing group_lock = self.get_group_lock(group_id) @@ -432,7 +437,7 @@ class MemoryAgentService: except ConfigurationError as e: error_msg = f"Failed to load configuration for config_id: {config_id}: {e}" logger.error(error_msg) - + # Log failed operation if audit_logger: duration = time.time() - start_time @@ -444,305 +449,133 @@ class MemoryAgentService: duration=duration, error=error_msg ) - + raise ValueError(error_msg) - + # Step 2: Prepare history history.append({"role": "user", "content": message}) logger.debug(f"Group ID:{group_id}, Message:{message}, History:{history}, Config ID:{config_id}") - + # Step 3: Initialize MCP client and execute read workflow - mcp_config = get_mcp_server_config() - client = MultiServerMCPClient(mcp_config) - - async with client.session('data_flow') as session: - session_start = time.time() - logger.debug("Connected to MCP Server: data_flow") - - tools_start = time.time() - tools = await load_mcp_tools(session) - tools_time = time.time() - tools_start - logger.info(f"[PERF] MCP tools loading took: {tools_time:.4f}s") - - outputs = [] - intermediate_outputs = [] - seen_intermediates = set() # Track seen intermediate outputs to avoid duplicates - - # Pass memory_config to the graph workflow - graph_start = time.time() - async with make_read_graph(group_id, tools, search_switch, group_id, group_id, memory_config=memory_config, storage_type=storage_type, user_rag_memory_id=user_rag_memory_id) as graph: - graph_init_time = time.time() - graph_start - logger.info(f"[PERF] Graph initialization took: {graph_init_time:.4f}s") - - start = time.time() + try: + async with make_read_graph() as graph: config = {"configurable": {"thread_id": group_id}} - workflow_errors = [] # Track errors from workflow - - event_count = 0 - async for event in graph.astream( - {"messages": history, "memory_config": memory_config, "errors": []}, - stream_mode="values", + # 初始状态 - 包含所有必要字段 + initial_state = {"messages": [HumanMessage(content=message)], "search_switch": search_switch, + "group_id": group_id + , "storage_type": storage_type, "user_rag_memory_id": user_rag_memory_id, + "memory_config": memory_config} + # 获取节点更新信息 + _intermediate_outputs = [] + summary = '' + async for update_event in graph.astream( + initial_state, + stream_mode="updates", config=config ): - event_count += 1 - event_start = time.time() - messages = event.get('messages') - # Capture any errors from the state - if event.get('errors'): - workflow_errors.extend(event.get('errors', [])) + for node_name, node_data in update_event.items(): - for msg in messages: - msg_content = msg.content - msg_role = msg.__class__.__name__.lower().replace("message", "") - outputs.append({ - "role": msg_role, - "content": msg_content - }) + # 处理不同Summary节点的返回结构 + if 'Summary' in node_name: + if 'InputSummary' in node_data and 'summary_result' in node_data['InputSummary']: + summary = node_data['InputSummary']['summary_result'] + elif 'RetrieveSummary' in node_data and 'summary_result' in node_data['RetrieveSummary']: + summary = node_data['RetrieveSummary']['summary_result'] + elif 'summary' in node_data and 'summary_result' in node_data['summary']: + summary = node_data['summary']['summary_result'] + elif 'SummaryFails' in node_data and 'summary_result' in node_data['SummaryFails']: + summary = node_data['SummaryFails']['summary_result'] - # Extract intermediate outputs - if hasattr(msg, 'content'): - try: - # Handle MCP content format: [{'type': 'text', 'text': '...'}] - content_to_parse = msg_content - if isinstance(msg_content, list): - for block in msg_content: - if isinstance(block, dict) and block.get('type') == 'text': - content_to_parse = block.get('text', '') - break - else: - continue # No text block found + spit_data = node_data.get('spit_data', {}).get('_intermediate', None) + if spit_data and spit_data != [] and spit_data != {}: + _intermediate_outputs.append(spit_data) - # Try to parse content as JSON - if isinstance(content_to_parse, str): - try: - parsed = json.loads(content_to_parse) - if isinstance(parsed, dict): - # Check for single intermediate output - if '_intermediate' in parsed: - intermediate_data = parsed['_intermediate'] - output_key = self._create_intermediate_key(intermediate_data) + # Problem_Extension 节点 + problem_extension = node_data.get('problem_extension', {}).get('_intermediate', None) + if problem_extension and problem_extension != [] and problem_extension != {}: + _intermediate_outputs.append(problem_extension) - if output_key not in seen_intermediates: - seen_intermediates.add(output_key) - intermediate_outputs.append(self._format_intermediate_output(intermediate_data)) + # Retrieve 节点 + retrieve_node = node_data.get('retrieve', {}).get('_intermediate_outputs', None) + if retrieve_node and retrieve_node != [] and retrieve_node != {}: + _intermediate_outputs.extend(retrieve_node) - # Check for multiple intermediate outputs (from Retrieve) - if '_intermediates' in parsed: - for intermediate_data in parsed['_intermediates']: - output_key = self._create_intermediate_key(intermediate_data) + # Verify 节点 + verify_n = node_data.get('verify', {}).get('_intermediate', None) + if verify_n and verify_n != [] and verify_n != {}: + _intermediate_outputs.append(verify_n) - if output_key not in seen_intermediates: - seen_intermediates.add(output_key) - intermediate_outputs.append(self._format_intermediate_output(intermediate_data)) - except (json.JSONDecodeError, ValueError): - pass - except Exception as e: - logger.debug(f"Failed to extract intermediate output: {e}") - - event_time = time.time() - event_start - logger.info(f"[PERF] Event {event_count} processing took: {event_time:.4f}s") + # Summary 节点 + summary_n = node_data.get('summary', {}).get('_intermediate', None) + if summary_n and summary_n != [] and summary_n != {}: + _intermediate_outputs.append(summary_n) - workflow_duration = time.time() - start - session_duration = time.time() - session_start - logger.info(f"[PERF] Read graph workflow completed in {workflow_duration}s") - logger.info(f"[PERF] Total session duration: {session_duration:.4f}s") - logger.info(f"[PERF] Total events processed: {event_count}") - # Extract final answer - final_answer = "" - for messages in outputs: - if messages['role'] == 'tool': - message = messages['content'] + _intermediate_outputs = [item for item in _intermediate_outputs if item and item != [] and item != {}] - # Handle MCP content format: [{'type': 'text', 'text': '...'}] - if isinstance(message, list): - # Extract text from MCP content blocks - for block in message: - if isinstance(block, dict) and block.get('type') == 'text': - message = block.get('text', '') - break - else: - continue # No text block found + optimized_outputs = merge_multiple_search_results(_intermediate_outputs) + result = reorder_output_results(optimized_outputs) - try: - parsed = json.loads(message) if isinstance(message, str) else message - if isinstance(parsed, dict): - if parsed.get('status') == 'success': - summary_result = parsed.get('summary_result') - if summary_result: - final_answer = summary_result - except (json.JSONDecodeError, ValueError): - pass + # Log successful operation + if audit_logger: + duration = time.time() - start_time + audit_logger.log_operation( + operation="READ", + config_id=config_id, + group_id=group_id, + success=True, + duration=duration + ) - # 记录成功的操作 - total_duration = time.time() - start_time - - # Check for workflow errors - if workflow_errors: - error_details = "; ".join([f"{e['tool']}: {e['error']}" for e in workflow_errors]) - logger.warning(f"Read workflow completed with errors: {error_details}") + retrieved_content = [] + repo = ShortTermMemoryRepository(db) + if str(search_switch).strip() != "2": + for intermediate in result: + intermediate_type = intermediate['type'] + if intermediate_type == "search_result": + query = intermediate['query'] + raw_results = intermediate['raw_results'] + reranked_results = raw_results.get('reranked_results', []) + try: + statements = [statement['statement'] for statement in + reranked_results.get('statements', [])] + except Exception: + statements = [] + statements = list(set(statements)) + retrieved_content.append({query: statements}) + if retrieved_content == []: + retrieved_content = '' + if '信息不足,无法回答。' != str(summary) and str(search_switch).strip() != "2": # and retrieved_content!=[] + # 使用 upsert 方法 + repo.upsert( + end_user_id=end_user_id, # 确保这个变量在作用域内 + messages=ori_message, + aimessages=summary, + retrieved_content=retrieved_content, + search_switch=str(search_switch) + ) + print("写入成功") + return { + "answer": summary, + "intermediate_outputs": result + } + except Exception as e: + # Ensure proper error handling and logging + error_msg = f"Read operation failed: {str(e)}" + logger.error(error_msg) if audit_logger: + duration = time.time() - start_time audit_logger.log_operation( operation="READ", config_id=config_id, group_id=group_id, success=False, - duration=total_duration, - error=error_details, - details={ - "search_switch": search_switch, - "history_length": len(history), - "intermediate_outputs_count": len(intermediate_outputs), - "has_answer": bool(final_answer), - "errors": workflow_errors - } + duration=duration, + error=error_msg ) - - # Raise error if no answer was produced - if not final_answer: - raise ValueError(f"Read workflow failed: {error_details}") - - if audit_logger and not workflow_errors: - audit_logger.log_operation( - operation="READ", - config_id=config_id, - group_id=group_id, - success=True, - duration=total_duration, - details={ - "search_switch": search_switch, - "history_length": len(history), - "intermediate_outputs_count": len(intermediate_outputs), - "has_answer": bool(final_answer) - } - ) - retrieved_content=[] - repo = ShortTermMemoryRepository(db) - if str(search_switch)!="2": - for intermediate in intermediate_outputs: - print(intermediate) - intermediate_type=intermediate['type'] - if intermediate_type=="search_result": - query=intermediate['query'] - raw_results=intermediate['raw_results'] - reranked_results=raw_results.get('reranked_results',[]) - try: - statements=[statement['statement'] for statement in reranked_results.get('statements', [])] - except Exception: - statements=[] - statements=list(set(statements)) - retrieved_content.append({query:statements}) - if retrieved_content==[]: - retrieved_content='' - if '信息不足,无法回答。' != str(final_answer) :#and retrieved_content!=[] - # 使用 upsert 方法 - repo.upsert( - end_user_id=end_user_id, # 确保这个变量在作用域内 - messages=ori_message, - aimessages=final_answer, - retrieved_content=retrieved_content, - search_switch=str(search_switch) - ) - print("写入成功") - - - - return { - "answer": final_answer, - "intermediate_outputs": intermediate_outputs - } + raise ValueError(error_msg) - def _create_intermediate_key(self, output: Dict) -> str: - """ - Create a unique key for an intermediate output to detect duplicates. - - Args: - output: Intermediate output dictionary - - Returns: - Unique string key for this output - """ - output_type = output.get('type', 'unknown') - - if output_type == 'problem_split': - # Use type + original query as key - return f"split:{output.get('original_query', '')}" - elif output_type == 'problem_extension': - # Use type + original query as key - return f"extension:{output.get('original_query', '')}" - elif output_type == 'search_result': - # Use type + query + index as key - return f"search:{output.get('query', '')}:{output.get('index', 0)}" - elif output_type == 'retrieval_summary': - # Use type + query as key - return f"summary:{output.get('query', '')}" - elif output_type == 'verification': - # Use type + query as key - return f"verification:{output.get('query', '')}" - elif output_type == 'input_summary': - # Use type + query as key - return f"input_summary:{output.get('query', '')}" - else: - # Fallback: use JSON representation - import json - return json.dumps(output, sort_keys=True) - - def _format_intermediate_output(self, output: Dict) -> Dict: - """Format intermediate output for frontend display.""" - output_type = output.get('type', 'unknown') - - if output_type == 'problem_split': - return { - 'type': 'problem_split', - 'title': '问题拆分', - 'data': output.get('data', []), - 'original_query': output.get('original_query', '') - } - elif output_type == 'problem_extension': - return { - 'type': 'problem_extension', - 'title': '问题扩展', - 'data': output.get('data', {}), - 'original_query': output.get('original_query', '') - } - elif output_type == 'search_result': - return { - 'type': 'search_result', - 'title': f'检索结果 ({output.get("index", 0)}/{output.get("total", 0)})', - 'query': output.get('query', ''), - 'raw_results': output.get('raw_results', ''), - 'index': output.get('index', 0), - 'total': output.get('total', 0) - } - elif output_type == 'retrieval_summary': - return { - 'type': 'retrieval_summary', - 'title': '检索总结', - 'summary': output.get('summary', ''), - 'query': output.get('query', ''), - 'raw_results': output.get('raw_results'), - } - elif output_type == 'verification': - return { - 'type': 'verification', - 'title': '数据验证', - 'result': output.get('result', 'unknown'), - 'reason': output.get('reason', ''), - 'query': output.get('query', ''), - 'verified_count': output.get('verified_count', 0) - } - elif output_type == 'input_summary': - return { - 'type': 'input_summary', - 'title': '快速答案', - 'summary': output.get('summary', ''), - 'query': output.get('query', ''), - 'raw_results': output.get('raw_results'), - - } - else: - return output async def classify_message_type(self, message: str, config_id: int, db: Session) -> Dict: """ @@ -850,6 +683,7 @@ class MemoryAgentService: # 获取当前空间下的所有宿主 from app.repositories import app_repository, end_user_repository from app.schemas.app_schema import App as AppSchema + from app.schemas.end_user_schema import EndUser as EndUserSchema # 查询应用并转换为 Pydantic 模型 apps_orm = app_repository.get_apps_by_workspace_id(db, current_workspace_id) @@ -1147,43 +981,6 @@ class MemoryAgentService: logger.info("Log streaming completed, cleaning up resources") # LogStreamer uses context manager for file handling, so cleanup is automatic -# async def get_api_docs(self, file_path: Optional[str] = None) -> Dict[str, Any]: -# """ -# Parse and return API documentation - -# Args: -# file_path: Optional path to API docs file. If None, uses default path. - -# Returns: -# Dict containing parsed API documentation or error information -# """ -# try: -# target = file_path or get_default_docs_path() - -# if not os.path.isfile(target): -# return { -# "success": False, -# "msg": "API文档文件不存在", -# "error_code": "DOC_NOT_FOUND", -# "data": {"path": target} -# } - -# data = parse_api_docs(target) -# return { -# "success": True, -# "msg": "解析成功", -# "data": data -# } -# except Exception as e: -# logger.error(f"Failed to parse API docs: {e}") -# return { -# "success": False, -# "msg": "解析失败", -# "error_code": "DOC_PARSE_ERROR", -# "data": {"error": str(e)} -# } - - def get_end_user_connected_config(end_user_id: str, db: Session) -> Dict[str, Any]: """ 获取终端用户关联的记忆配置 @@ -1192,20 +989,18 @@ def get_end_user_connected_config(end_user_id: str, db: Session) -> Dict[str, An 1. 根据 end_user_id 获取用户的 app_id 2. 获取该应用的最新发布版本 3. 从发布版本的 config 字段中提取 memory_config_id - 4. 根据 memory_config_id 查询配置名称 Args: end_user_id: 终端用户ID db: 数据库会话 Returns: - 包含 memory_config_id、config_name 和相关信息的字典 + 包含 memory_config_id 和相关信息的字典 Raises: ValueError: 当终端用户不存在或应用未发布时 """ from app.models.app_release_model import AppRelease - from app.models.data_config_model import DataConfig from app.models.end_user_model import EndUser from sqlalchemy import select @@ -1239,31 +1034,15 @@ def get_end_user_connected_config(end_user_id: str, db: Session) -> Dict[str, An memory_obj = config.get('memory', {}) memory_config_id = memory_obj.get('memory_content') if isinstance(memory_obj, dict) else None - # 4. 根据 memory_config_id 查询配置名称 - config_name = None - if memory_config_id: - try: - # memory_config_id 可能是整数或字符串,需要转换 - config_id = int(memory_config_id) if isinstance(memory_config_id, str) else memory_config_id - data_config = db.query(DataConfig).filter(DataConfig.config_id == config_id).first() - if data_config: - config_name = data_config.config_name - logger.debug(f"Found config_name: {config_name} for config_id: {config_id}") - else: - logger.warning(f"DataConfig not found for config_id: {config_id}") - except (ValueError, TypeError) as e: - logger.warning(f"Invalid memory_config_id format: {memory_config_id}, error: {str(e)}") - result = { "end_user_id": str(end_user_id), "app_id": str(app_id), "release_id": str(latest_release.id), "release_version": latest_release.version, - "memory_config_id": memory_config_id, - "memory_config_name": config_name + "memory_config_id": memory_config_id } - logger.info(f"Successfully retrieved connected config: memory_config_id={memory_config_id}, config_name={config_name}") + logger.info(f"Successfully retrieved connected config: memory_config_id={memory_config_id}") return result @@ -1271,126 +1050,112 @@ def get_end_users_connected_configs_batch(end_user_ids: List[str], db: Session) """ 批量获取多个终端用户关联的记忆配置 - 通过优化的查询减少数据库往返次数: - 1. 一次性查询所有 end_user 及其 app_id - 2. 批量查询所有相关的 app_release - 3. 批量查询所有相关的 data_config + 通过以下流程获取配置: + 1. 批量查询所有 end_user 及其 app_id + 2. 批量获取所有应用的最新发布版本 + 3. 从发布版本的 config 字段中提取 memory_config_id 和 memory_config_name Args: end_user_ids: 终端用户ID列表 db: 数据库会话 Returns: - 字典,key 为 end_user_id,value 为配置信息字典 - 对于查询失败的用户,value 包含 error 字段 + 字典,key 为 end_user_id,value 为包含 memory_config_id 和 memory_config_name 的字典 + 格式: { + "user_id_1": {"memory_config_id": "xxx", "memory_config_name": "xxx"}, + "user_id_2": {"memory_config_id": None, "memory_config_name": None}, + ... + } """ from app.models.app_release_model import AppRelease - from app.models.data_config_model import DataConfig from app.models.end_user_model import EndUser + from app.models.memory_config_model import MemoryConfig from sqlalchemy import select - logger.info(f"Batch getting connected configs for {len(end_user_ids)} end users") + logger.info(f"Batch getting connected configs for {len(end_user_ids)} end_users") result = {} # 1. 批量查询所有 end_user 及其 app_id end_users = db.query(EndUser).filter(EndUser.id.in_(end_user_ids)).all() - # 构建 end_user_id -> end_user 的映射 - end_user_map = {str(user.id): user for user in end_users} + # 创建 end_user_id 到 app_id 的映射 + user_to_app = {str(eu.id): eu.app_id for eu in end_users} - # 记录不存在的用户 - for user_id in end_user_ids: - if user_id not in end_user_map: - result[user_id] = { - "end_user_id": user_id, - "memory_config_id": None, - "memory_config_name": None, - "error": f"终端用户不存在: {user_id}" - } + # 获取所有相关的 app_id + app_ids = list(set(user_to_app.values())) - if not end_users: - logger.warning("No valid end users found") + if not app_ids: + logger.warning("No valid app_ids found for the provided end_user_ids") + # 返回空配置 + for user_id in end_user_ids: + result[user_id] = {"memory_config_id": None, "memory_config_name": None} return result - # 2. 批量查询所有相关应用的最新发布版本 - app_ids = [user.app_id for user in end_users] - + # 2. 批量获取所有应用的最新发布版本 # 使用子查询找到每个 app 的最新版本 - from sqlalchemy import and_ + from sqlalchemy import func - # 查询所有相关的活跃发布版本 - releases = db.query(AppRelease).filter( - and_( - AppRelease.app_id.in_(app_ids), - AppRelease.is_active.is_(True) + subq = ( + select( + AppRelease.app_id, + func.max(AppRelease.version).label('max_version') ) - ).order_by(AppRelease.app_id, AppRelease.version.desc()).all() + .where(AppRelease.app_id.in_(app_ids), AppRelease.is_active.is_(True)) + .group_by(AppRelease.app_id) + .subquery() + ) - # 构建 app_id -> latest_release 的映射(每个 app 只保留最新版本) - app_release_map = {} - for release in releases: - app_id_str = str(release.app_id) - if app_id_str not in app_release_map: - app_release_map[app_id_str] = release + stmt = ( + select(AppRelease) + .join( + subq, + (AppRelease.app_id == subq.c.app_id) & (AppRelease.version == subq.c.max_version) + ) + .where(AppRelease.is_active.is_(True)) + ) - # 3. 收集所有 memory_config_id + latest_releases = db.scalars(stmt).all() + + # 创建 app_id 到 release 的映射 + app_to_release = {str(release.app_id): release for release in latest_releases} + + # 3. 提取所有 memory_config_id memory_config_ids = [] - for release in app_release_map.values(): + for release in latest_releases: config = release.config or {} memory_obj = config.get('memory', {}) memory_config_id = memory_obj.get('memory_content') if isinstance(memory_obj, dict) else None if memory_config_id: - try: - config_id = int(memory_config_id) if isinstance(memory_config_id, str) else memory_config_id - memory_config_ids.append(config_id) - except (ValueError, TypeError): - pass + memory_config_ids.append(memory_config_id) - # 4. 批量查询所有 data_config - config_name_map = {} + # 4. 批量查询 memory_config_name + memory_configs = {} if memory_config_ids: - data_configs = db.query(DataConfig).filter( - DataConfig.config_id.in_(memory_config_ids) - ).all() - config_name_map = {config.config_id: config.config_name for config in data_configs} + configs = db.query(MemoryConfig).filter(MemoryConfig.id.in_(memory_config_ids)).all() + memory_configs = {str(cfg.id): cfg.config_name for cfg in configs} # 5. 组装结果 - for user in end_users: - user_id = str(user.id) - app_id = str(user.app_id) - - # 检查是否有发布版本 - if app_id not in app_release_map: - result[user_id] = { - "end_user_id": user_id, - "memory_config_id": None, - "memory_config_name": None, - "error": f"应用未发布: {app_id}" - } + for user_id in end_user_ids: + app_id = user_to_app.get(user_id) + if not app_id: + result[user_id] = {"memory_config_id": None, "memory_config_name": None} continue - release = app_release_map[app_id] + release = app_to_release.get(str(app_id)) + if not release: + result[user_id] = {"memory_config_id": None, "memory_config_name": None} + continue - # 提取 memory_config_id config = release.config or {} memory_obj = config.get('memory', {}) memory_config_id = memory_obj.get('memory_content') if isinstance(memory_obj, dict) else None - - # 获取 config_name - config_name = None - if memory_config_id: - try: - config_id = int(memory_config_id) if isinstance(memory_config_id, str) else memory_config_id - config_name = config_name_map.get(config_id) - except (ValueError, TypeError): - pass + memory_config_name = memory_configs.get(memory_config_id) if memory_config_id else None result[user_id] = { - "end_user_id": user_id, "memory_config_id": memory_config_id, - "memory_config_name": config_name + "memory_config_name": memory_config_name } - logger.info(f"Successfully retrieved batch configs: total={len(result)}, with_config={sum(1 for v in result.values() if v.get('memory_config_id'))}") + logger.info(f"Successfully retrieved {len(result)} connected configs") return result \ No newline at end of file diff --git a/api/clear_celery_queue.py b/api/clear_celery_queue.py new file mode 100644 index 00000000..e72f7475 --- /dev/null +++ b/api/clear_celery_queue.py @@ -0,0 +1,198 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +清空 Celery 队列中的所有消息 + +这个脚本会删除 Redis 中 Celery 队列的所有待处理任务 +""" + +import redis +from app.core.config import settings +from app.celery_app import celery_app + + +def clear_celery_queue(): + """清空 Celery 队列""" + print("🗑️ 清空 Celery 队列") + print("=" * 50) + + try: + # 连接到 Redis + redis_client = redis.Redis( + host=settings.REDIS_HOST, + port=settings.REDIS_PORT, + password=settings.REDIS_PASSWORD, + db=settings.CELERY_BROKER, + decode_responses=True + ) + + # 测试连接 + redis_client.ping() + print("✅ Redis 连接成功") + + # 队列名称 + queue_name = 'localhost_test_wyl' + + # 获取队列长度 + queue_length = redis_client.llen(queue_name) + print(f"📊 队列 '{queue_name}' 当前长度: {queue_length}") + + if queue_length == 0: + print("✅ 队列已经是空的,无需清理") + return + + # 确认清空 + print(f"\n⚠️ 警告: 即将删除 {queue_length} 个待处理任务") + confirm = input("确认清空队列? (yes/no): ").strip().lower() + + if confirm not in ['yes', 'y']: + print("❌ 操作已取消") + return + + # 删除队列 + deleted_count = redis_client.delete(queue_name) + print(f"✅ 已删除队列,删除了 {deleted_count} 个键") + + # 验证队列已清空 + new_length = redis_client.llen(queue_name) + print(f"📊 队列 '{queue_name}' 新长度: {new_length}") + + if new_length == 0: + print("✅ 队列已成功清空!") + else: + print(f"⚠️ 队列仍有 {new_length} 个任务") + + # 清理结果后端(可选) + print("\n🧹 清理结果后端...") + result_keys = redis_client.keys("celery-task-meta-*") + if result_keys: + deleted_results = redis_client.delete(*result_keys) + print(f"✅ 删除了 {deleted_results} 个任务结果") + else: + print("✅ 没有待清理的任务结果") + + except redis.ConnectionError as e: + print(f"❌ Redis 连接失败: {e}") + except Exception as e: + print(f"❌ 清空队列失败: {e}") + import traceback + traceback.print_exc() + + +def clear_all_celery_data(): + """清空所有 Celery 相关数据(包括结果)""" + print("\n🗑️ 清空所有 Celery 数据") + print("=" * 50) + + try: + # 连接到 Redis + redis_client = redis.Redis( + host=settings.REDIS_HOST, + port=settings.REDIS_PORT, + password=settings.REDIS_PASSWORD, + db=settings.CELERY_BROKER, + decode_responses=True + ) + + # 获取所有 Celery 相关的键 + all_keys = redis_client.keys("*") + celery_keys = [k for k in all_keys if 'celery' in k.lower() or 'localhost_test_wyl' in k] + + print(f"📊 找到 {len(celery_keys)} 个 Celery 相关的键") + + if not celery_keys: + print("✅ 没有 Celery 数据需要清理") + return + + # 显示键列表 + print("\n📋 Celery 相关的键:") + for key in celery_keys[:10]: # 只显示前10个 + print(f" - {key}") + if len(celery_keys) > 10: + print(f" ... 还有 {len(celery_keys) - 10} 个键") + + # 确认清空 + print(f"\n⚠️ 警告: 即将删除 {len(celery_keys)} 个 Celery 相关的键") + confirm = input("确认清空所有 Celery 数据? (yes/no): ").strip().lower() + + if confirm not in ['yes', 'y']: + print("❌ 操作已取消") + return + + # 删除所有键 + if celery_keys: + deleted_count = redis_client.delete(*celery_keys) + print(f"✅ 已删除 {deleted_count} 个键") + + print("✅ 所有 Celery 数据已清空!") + + except Exception as e: + print(f"❌ 清空失败: {e}") + import traceback + traceback.print_exc() + + +def show_queue_info(): + """显示队列信息""" + print("\n📊 队列信息") + print("=" * 50) + + try: + # 连接到 Redis + redis_client = redis.Redis( + host=settings.REDIS_HOST, + port=settings.REDIS_PORT, + password=settings.REDIS_PASSWORD, + db=settings.CELERY_BROKER, + decode_responses=True + ) + + # 队列名称 + queue_name = 'localhost_test_wyl' + + # 获取队列信息 + queue_length = redis_client.llen(queue_name) + print(f"📊 队列 '{queue_name}' 长度: {queue_length}") + + # 获取结果数量 + result_keys = redis_client.keys("celery-task-meta-*") + print(f"📊 任务结果数量: {len(result_keys)}") + + # 获取所有 Celery 键 + all_keys = redis_client.keys("*") + celery_keys = [k for k in all_keys if 'celery' in k.lower() or 'localhost_test_wyl' in k] + print(f"📊 Celery 相关键总数: {len(celery_keys)}") + + except Exception as e: + print(f"❌ 获取队列信息失败: {e}") + + +def main(): + """主函数""" + print("🚀 Celery 队列清理工具") + print("=" * 50) + + while True: + print("\n请选择操作:") + print("1. 查看队列信息") + print("2. 清空队列(只删除待处理任务)") + print("3. 清空所有 Celery 数据(包括结果)") + print("4. 退出") + + choice = input("\n请输入选项 (1-4): ").strip() + + if choice == '1': + show_queue_info() + elif choice == '2': + clear_celery_queue() + elif choice == '3': + clear_all_celery_data() + elif choice == '4': + print("👋 再见!") + break + else: + print("❌ 无效选项,请重新选择") + + +if __name__ == "__main__": + main() \ No newline at end of file