diff --git a/api/app/controllers/memory_reflection_controller.py b/api/app/controllers/memory_reflection_controller.py index 5a32372a..f827eaaf 100644 --- a/api/app/controllers/memory_reflection_controller.py +++ b/api/app/controllers/memory_reflection_controller.py @@ -1,3 +1,19 @@ +""" +Memory Reflection Controller + +This module provides REST API endpoints for managing memory reflection configurations +and operations. It handles reflection engine setup, configuration management, and +execution of self-reflection processes across memory systems. + +Key Features: +- Reflection configuration management (save, retrieve, update) +- Workspace-wide reflection execution across multiple applications +- Individual configuration-based reflection runs +- Multi-language support for reflection outputs +- Integration with Neo4j memory storage and LLM models +- Comprehensive error handling and logging +""" + import asyncio import time import uuid @@ -28,9 +44,13 @@ from sqlalchemy.orm import Session from app.utils.config_utils import resolve_config_id +# Load environment variables for configuration load_dotenv() + +# Initialize API logger for request tracking and debugging api_logger = get_api_logger() +# Configure router with prefix and tags for API organization router = APIRouter( prefix="/memory", tags=["Memory"], @@ -43,7 +63,38 @@ async def save_reflection_config( current_user: User = Depends(get_current_user), db: Session = Depends(get_db), ) -> dict: - """Save reflection configuration to data_comfig table""" + """ + Save reflection configuration to memory config table + + Persists reflection engine configuration settings to the data_config table, + including reflection parameters, model settings, and evaluation criteria. + Validates configuration parameters and ensures data consistency. + + Args: + request: Memory reflection configuration data including: + - config_id: Configuration identifier to update + - reflection_enabled: Whether reflection is enabled + - reflection_period_in_hours: Reflection execution interval + - reflexion_range: Scope of reflection (partial/all) + - baseline: Reflection strategy (time/fact/hybrid) + - reflection_model_id: LLM model for reflection operations + - memory_verify: Enable memory verification checks + - quality_assessment: Enable quality assessment evaluation + current_user: Authenticated user saving the configuration + db: Database session for data operations + + Returns: + dict: Success response with saved reflection configuration data + + Raises: + HTTPException 400: If config_id is missing or parameters are invalid + HTTPException 500: If configuration save operation fails + + Database Operations: + - Updates memory_config table with reflection settings + - Commits transaction and refreshes entity + - Maintains configuration consistency + """ try: config_id = request.config_id config_id = resolve_config_id(config_id, db) @@ -54,6 +105,7 @@ async def save_reflection_config( ) api_logger.info(f"用户 {current_user.username} 保存反思配置,config_id: {config_id}") + # Update reflection configuration in database memory_config = MemoryConfigRepository.update_reflection_config( db, config_id=config_id, @@ -66,6 +118,7 @@ async def save_reflection_config( quality_assessment=request.quality_assessment ) + # Commit transaction and refresh entity db.commit() db.refresh(memory_config) @@ -102,13 +155,55 @@ async def start_workspace_reflection( current_user: User = Depends(get_current_user), db: Session = Depends(get_db), ) -> dict: - """启动工作空间中所有匹配应用的反思功能""" + """ + Start reflection functionality for all matching applications in workspace + + Initiates reflection processes across all applications within the user's current + workspace that have valid memory configurations. Processes each application's + configurations and associated end users, executing reflection operations + with proper error isolation and transaction management. + + This endpoint serves as a workspace-wide reflection orchestrator, ensuring + that reflection failures for individual users don't affect other operations. + + Args: + current_user: Authenticated user initiating workspace reflection + db: Database session for configuration queries + + Returns: + dict: Success response with reflection results for all processed applications: + - app_id: Application identifier + - config_id: Memory configuration identifier + - end_user_id: End user identifier + - reflection_result: Individual reflection operation result + + Processing Logic: + 1. Retrieve all applications in the current workspace + 2. Filter applications with valid memory configurations + 3. For each configuration, find matching releases + 4. Execute reflection for each end user with isolated transactions + 5. Aggregate results with error handling per user + + Error Handling: + - Individual user reflection failures are isolated + - Failed operations are logged and included in results + - Database transactions are isolated per user to prevent cascading failures + - Comprehensive error reporting for debugging + + Raises: + HTTPException 500: If workspace reflection initialization fails + + Performance Notes: + - Uses independent database sessions for each user operation + - Prevents transaction failures from affecting other users + - Comprehensive logging for operation tracking + """ workspace_id = current_user.current_workspace_id try: api_logger.info(f"用户 {current_user.username} 启动workspace反思,workspace_id: {workspace_id}") - # 使用独立的数据库会话来获取工作空间应用详情,避免事务失败 + # Use independent database session to get workspace app details, avoiding transaction failures from app.db import get_db_context with get_db_context() as query_db: service = WorkspaceAppService(query_db) @@ -116,8 +211,9 @@ async def start_workspace_reflection( reflection_results = [] + # Process each application in the workspace for data in result['apps_detailed_info']: - # 跳过没有配置的应用 + # Skip applications without configurations if not data['memory_configs']: api_logger.debug(f"应用 {data['id']} 没有memory_configs,跳过") continue @@ -126,22 +222,22 @@ async def start_workspace_reflection( memory_configs = data['memory_configs'] end_users = data['end_users'] - # 为每个配置和用户组合执行反思 + # Execute reflection for each configuration and user combination for config in memory_configs: config_id_str = str(config['config_id']) - # 找到匹配此配置的所有release + # Find all releases matching this configuration matching_releases = [r for r in releases if str(r['config']) == config_id_str] if not matching_releases: api_logger.debug(f"配置 {config_id_str} 没有匹配的release") continue - # 为每个用户执行反思 - 使用独立的数据库会话 + # Execute reflection for each user - using independent database sessions for user in end_users: api_logger.info(f"为用户 {user['id']} 启动反思,config_id: {config_id_str}") - # 为每个用户创建独立的数据库会话,避免事务失败影响其他用户 + # Create independent database session for each user to avoid transaction failure impact with get_db_context() as user_db: try: reflection_service = MemoryReflectionService(user_db) @@ -184,14 +280,51 @@ async def start_reflection_configs( current_user: User = Depends(get_current_user), db: Session = Depends(get_db), ) -> dict: - """通过config_id查询memory_config表中的反思配置信息""" + """ + Query reflection configuration information by config_id + + Retrieves detailed reflection configuration settings from the memory_config + table for a specific configuration ID. Provides comprehensive reflection + parameters including model settings, evaluation criteria, and operational flags. + + Args: + config_id: Configuration identifier (UUID or integer) to query + current_user: Authenticated user making the request + db: Database session for data operations + + Returns: + dict: Success response with detailed reflection configuration: + - config_id: Resolved configuration identifier + - reflection_enabled: Whether reflection is enabled for this config + - reflection_period_in_hours: Reflection execution interval + - reflexion_range: Scope of reflection operations (partial/all) + - baseline: Reflection strategy (time/fact/hybrid) + - reflection_model_id: LLM model identifier for reflection + - memory_verify: Memory verification flag + - quality_assessment: Quality assessment flag + + Database Operations: + - Queries memory_config table by resolved config_id + - Retrieves all reflection-related configuration fields + - Resolves configuration ID for consistent formatting + + Raises: + HTTPException 404: If configuration with specified ID is not found + HTTPException 500: If configuration query operation fails + + ID Resolution: + - Supports both UUID and integer config_id formats + - Automatically resolves to appropriate internal format + - Maintains consistency across different ID representations + """ config_id = resolve_config_id(config_id, db) try: config_id=resolve_config_id(config_id,db) api_logger.info(f"用户 {current_user.username} 查询反思配置,config_id: {config_id}") result = MemoryConfigRepository.query_reflection_config_by_id(db, config_id) memory_config_id = resolve_config_id(result.config_id, db) - # 构建返回数据 + + # Build response data with comprehensive configuration details reflection_config = { "config_id": memory_config_id, "reflection_enabled": result.enable_self_reflexion, @@ -204,10 +337,12 @@ async def start_reflection_configs( } api_logger.info(f"成功查询反思配置,config_id: {config_id}") return success(data=reflection_config, msg="反思配置查询成功") - + api_logger.info(f"Successfully queried reflection config, config_id: {config_id}") + return success(data=reflection_config, msg="Reflection configuration query successful") + except HTTPException: - # 重新抛出HTTP异常 + # Re-raise HTTP exceptions without modification raise except Exception as e: api_logger.error(f"查询反思配置失败: {str(e)}") @@ -223,13 +358,66 @@ async def reflection_run( current_user: User = Depends(get_current_user), db: Session = Depends(get_db), ) -> dict: - """Activate the reflection function for all matching applications in the workspace""" - # 使用集中化的语言校验 + """ + Execute reflection engine with specified configuration + + Runs the reflection engine using configuration parameters from the database. + Validates model availability, sets up the reflection engine with proper + configuration, and executes the reflection process with multi-language support. + + This endpoint provides a test run capability for reflection configurations, + allowing users to validate their reflection settings and see results before + deploying to production environments. + + Args: + config_id: Configuration identifier (UUID or integer) for reflection settings + language_type: Language preference header for output localization (optional) + current_user: Authenticated user executing the reflection + db: Database session for configuration queries + + Returns: + dict: Success response with reflection execution results including: + - baseline: Reflection strategy used + - source_data: Input data processed + - memory_verifies: Memory verification results (if enabled) + - quality_assessments: Quality assessment results (if enabled) + - reflexion_data: Generated reflection insights and solutions + + Configuration Validation: + - Verifies configuration exists in database + - Validates LLM model availability + - Falls back to default model if specified model is unavailable + - Ensures all required parameters are properly set + + Reflection Engine Setup: + - Creates ReflectionConfig with database parameters + - Initializes Neo4j connector for memory access + - Sets up ReflectionEngine with validated model + - Configures language preferences for output + + Error Handling: + - Model validation with fallback to default + - Configuration validation and error reporting + - Comprehensive logging for debugging + - Graceful handling of missing configurations + + Raises: + HTTPException 404: If configuration is not found + HTTPException 500: If reflection execution fails + + Performance Notes: + - Direct database query for configuration retrieval + - Model validation to prevent runtime failures + - Efficient reflection engine initialization + - Language-aware output processing + """ + # Use centralized language validation for consistent localization language = get_language_from_header(language_type) api_logger.info(f"用户 {current_user.username} 查询反思配置,config_id: {config_id}") config_id = resolve_config_id(config_id, db) - # 使用MemoryConfigRepository查询反思配置 + + # Query reflection configuration using MemoryConfigRepository result = MemoryConfigRepository.query_reflection_config_by_id(db, config_id) if not result: raise HTTPException( @@ -239,7 +427,7 @@ async def reflection_run( api_logger.info(f"成功查询反思配置,config_id: {config_id}") - # 验证模型ID是否存在 + # Validate model ID existence model_id = result.reflection_model_id if model_id: try: @@ -250,6 +438,7 @@ async def reflection_run( # 可以设置为None,让反思引擎使用默认模型 model_id = None + # Create reflection configuration with database parameters config = ReflectionConfig( enabled=result.enable_self_reflexion, iteration_period=result.iteration_period, @@ -262,11 +451,13 @@ async def reflection_run( model_id=model_id, language_type=language_type ) + + # Initialize Neo4j connector and reflection engine connector = Neo4jConnector() engine = ReflectionEngine( config=config, neo4j_connector=connector, - llm_client=model_id # 传入验证后的 model_id + llm_client=model_id # Pass validated model_id ) result=await (engine.reflection_run()) diff --git a/api/app/controllers/memory_short_term_controller.py b/api/app/controllers/memory_short_term_controller.py index 0acac6ce..b69406a8 100644 --- a/api/app/controllers/memory_short_term_controller.py +++ b/api/app/controllers/memory_short_term_controller.py @@ -1,3 +1,18 @@ +""" +Memory Short Term Controller + +This module provides REST API endpoints for managing short-term and long-term memory +data retrieval and analysis. It handles memory system statistics, data aggregation, +and provides comprehensive memory insights for end users. + +Key Features: +- Short-term memory data retrieval and statistics +- Long-term memory data aggregation +- Entity count integration +- Multi-language response support +- Memory system analytics and reporting +""" + from typing import Optional from dotenv import load_dotenv @@ -13,9 +28,13 @@ from app.models.user_model import User from app.services.memory_short_service import LongService, ShortService from app.services.memory_storage_service import search_entity +# Load environment variables for configuration load_dotenv() + +# Initialize API logger for request tracking and debugging api_logger = get_api_logger() +# Configure router with prefix and tags for API organization router = APIRouter( prefix="/memory/short", tags=["Memory"], @@ -27,24 +46,73 @@ async def short_term_configs( current_user: User = Depends(get_current_user), db: Session = Depends(get_db), ): - # 使用集中化的语言校验 + """ + Retrieve comprehensive short-term and long-term memory statistics + + Provides a comprehensive overview of memory system data for a specific end user, + including short-term memory entries, long-term memory aggregations, entity counts, + and retrieval statistics. Supports multi-language responses based on request headers. + + This endpoint serves as a central dashboard for memory system analytics, combining + data from multiple memory subsystems to provide a holistic view of user memory state. + + Args: + end_user_id: Unique identifier for the end user whose memory data to retrieve + language_type: Language preference header for response localization (optional) + current_user: Authenticated user making the request (injected by dependency) + db: Database session for data operations (injected by dependency) + + Returns: + dict: Success response containing comprehensive memory statistics: + - short_term: List of short-term memory entries with detailed data + - long_term: List of long-term memory aggregations and summaries + - entity: Count of entities associated with the end user + - retrieval_number: Total count of short-term memory retrievals + - long_term_number: Total count of long-term memory entries + + Response Structure: + { + "code": 200, + "msg": "Short-term memory system data retrieved successfully", + "data": { + "short_term": [...], # Short-term memory entries + "long_term": [...], # Long-term memory data + "entity": 42, # Entity count + "retrieval_number": 156, # Short-term retrieval count + "long_term_number": 23 # Long-term memory count + } + } + + Raises: + HTTPException: If end_user_id is invalid or data retrieval fails + + Performance Notes: + - Combines multiple service calls for comprehensive data + - Entity search is performed asynchronously for better performance + - Response time depends on memory data volume for the specified user + """ + # Use centralized language validation for consistent localization language = get_language_from_header(language_type) - # 获取短期记忆数据 - short_term=ShortService(end_user_id, db) - short_result=short_term.get_short_databasets() - short_count=short_term.get_short_count() + # Retrieve short-term memory data and statistics + short_term = ShortService(end_user_id, db) + short_result = short_term.get_short_databasets() # Get short-term memory entries + short_count = short_term.get_short_count() # Get short-term retrieval count - long_term=LongService(end_user_id, db) - long_result=long_term.get_long_databasets() + # Retrieve long-term memory data and aggregations + long_term = LongService(end_user_id, db) + long_result = long_term.get_long_databasets() # Get long-term memory entries + # Get entity count for the specified end user entity_result = await search_entity(end_user_id) + + # Compile comprehensive memory statistics response result = { - 'short_term': short_result, - 'long_term': long_result, - 'entity': entity_result.get('num', 0), - "retrieval_number":short_count, - "long_term_number":len(long_result) + 'short_term': short_result, # Short-term memory entries + 'long_term': long_result, # Long-term memory data + 'entity': entity_result.get('num', 0), # Entity count (default to 0 if not found) + "retrieval_number": short_count, # Short-term retrieval statistics + "long_term_number": len(long_result) # Long-term memory entry count } return success(data=result, msg="短期记忆系统数据获取成功") \ No newline at end of file diff --git a/api/app/core/memory/agent/langgraph_graph/nodes/data_nodes.py b/api/app/core/memory/agent/langgraph_graph/nodes/data_nodes.py index 6595a2ce..829f26c4 100644 --- a/api/app/core/memory/agent/langgraph_graph/nodes/data_nodes.py +++ b/api/app/core/memory/agent/langgraph_graph/nodes/data_nodes.py @@ -2,15 +2,36 @@ from app.core.memory.agent.utils.llm_tools import ReadState, WriteState def content_input_node(state: ReadState) -> ReadState: - """开始节点 - 提取内容并保持状态信息""" + """ + Start node - Extract content and maintain state information + + Extracts the content from the first message in the state and returns it + as the data field while preserving all other state information. + + Args: + state: ReadState containing messages and other state data + + Returns: + ReadState: Updated state with extracted content in data field + """ content = state['messages'][0].content if state.get('messages') else '' - # 返回内容并保持所有状态信息 + # Return content and maintain all state information return {"data": content} def content_input_write(state: WriteState) -> WriteState: - """开始节点 - 提取内容并保持状态信息""" + """ + Start node - Extract content and maintain state information for write operations + + Extracts the content from the first message in the state for write operations. + + Args: + state: WriteState containing messages and other state data + + Returns: + WriteState: Updated state with extracted content in data field + """ content = state['messages'][0].content if state.get('messages') else '' - # 返回内容并保持所有状态信息 + # Return content and maintain all state information return {"data": content} \ No newline at end of file diff --git a/api/app/core/memory/agent/langgraph_graph/nodes/problem_nodes.py b/api/app/core/memory/agent/langgraph_graph/nodes/problem_nodes.py index 784e5802..3030669c 100644 --- a/api/app/core/memory/agent/langgraph_graph/nodes/problem_nodes.py +++ b/api/app/core/memory/agent/langgraph_graph/nodes/problem_nodes.py @@ -19,19 +19,39 @@ logger = get_agent_logger(__name__) class ProblemNodeService(LLMServiceMixin): - """问题处理节点服务类""" + """ + Problem processing node service class + + Handles problem decomposition and extension operations using LLM services. + Inherits from LLMServiceMixin to provide structured LLM calling capabilities. + + Attributes: + template_service: Service for rendering Jinja2 templates + """ def __init__(self): super().__init__() self.template_service = TemplateService(template_root) -# 创建全局服务实例 +# Create global service instance problem_service = ProblemNodeService() async def Split_The_Problem(state: ReadState) -> ReadState: - """问题分解节点""" + """ + Problem decomposition node + + Breaks down complex user queries into smaller, more manageable sub-problems. + Uses LLM to analyze the input and generate structured problem decomposition + with question types and reasoning. + + Args: + state: ReadState containing user input and configuration + + Returns: + ReadState: Updated state with problem decomposition results + """ # 从状态中获取数据 content = state.get('data', '') end_user_id = state.get('end_user_id', '') @@ -64,7 +84,7 @@ async def Split_The_Problem(state: ReadState) -> ReadState: # 添加更详细的日志记录 logger.info(f"Split_The_Problem: 开始处理问题分解,内容长度: {len(content)}") - # 验证结构化响应 + # Validate structured response if not structured or not hasattr(structured, 'root'): logger.warning("Split_The_Problem: 结构化响应为空或格式不正确") split_result = json.dumps([], ensure_ascii=False) @@ -106,7 +126,7 @@ async def Split_The_Problem(state: ReadState) -> ReadState: exc_info=True ) - # 提供更详细的错误信息 + # Provide more detailed error information error_details = { "error_type": type(e).__name__, "error_message": str(e), @@ -116,7 +136,7 @@ async def Split_The_Problem(state: ReadState) -> ReadState: logger.error(f"Split_The_Problem error details: {error_details}") - # 创建默认的空结果 + # Create default empty result result = { "context": json.dumps([], ensure_ascii=False), "original": content, @@ -130,13 +150,25 @@ async def Split_The_Problem(state: ReadState) -> ReadState: } } - # 返回更新后的状态,包含spit_context字段 + # Return updated state including spit_context field return {"spit_data": result} async def Problem_Extension(state: ReadState) -> ReadState: - """问题扩展节点""" - # 获取原始数据和分解结果 + """ + Problem extension node + + Extends the decomposed problems from Split_The_Problem node by generating + additional related questions and organizing them by original question. + Uses LLM to create comprehensive question extensions for better memory retrieval. + + Args: + state: ReadState containing decomposed problems and configuration + + Returns: + ReadState: Updated state with extended problem results + """ + # Get original data and decomposition results start = time.time() content = state.get('data', '') data = state.get('spit_data', '')['context'] @@ -182,7 +214,7 @@ async def Problem_Extension(state: ReadState) -> ReadState: logger.info(f"Problem_Extension: 开始处理问题扩展,问题数量: {len(databasets)}") - # 验证结构化响应 + # Validate structured response if not response_content or not hasattr(response_content, 'root'): logger.warning("Problem_Extension: 结构化响应为空或格式不正确") aggregated_dict = {} @@ -216,7 +248,7 @@ async def Problem_Extension(state: ReadState) -> ReadState: exc_info=True ) - # 提供更详细的错误信息 + # Provide more detailed error information error_details = { "error_type": type(e).__name__, "error_message": str(e), diff --git a/api/app/core/memory/agent/langgraph_graph/nodes/retrieve_nodes.py b/api/app/core/memory/agent/langgraph_graph/nodes/retrieve_nodes.py index 06539ad1..f2cd0d3d 100644 --- a/api/app/core/memory/agent/langgraph_graph/nodes/retrieve_nodes.py +++ b/api/app/core/memory/agent/langgraph_graph/nodes/retrieve_nodes.py @@ -29,6 +29,18 @@ logger = get_agent_logger(__name__) async def rag_config(state): + """ + Configure RAG (Retrieval-Augmented Generation) settings + + Creates configuration for knowledge base retrieval including similarity thresholds, + weights, and reranker settings. + + Args: + state: Current state containing user_rag_memory_id + + Returns: + dict: RAG configuration dictionary + """ user_rag_memory_id = state.get('user_rag_memory_id', '') kb_config = { "knowledge_bases": [ @@ -48,6 +60,19 @@ async def rag_config(state): async def rag_knowledge(state, question): + """ + Retrieve knowledge using RAG approach + + Performs knowledge retrieval from configured knowledge bases using the + provided question and returns formatted results. + + Args: + state: Current state containing configuration + question: Question to search for + + Returns: + tuple: (retrieval_knowledge, clean_content, cleaned_query, raw_results) + """ kb_config = await rag_config(state) end_user_id = state.get('end_user_id', '') user_rag_memory_id = state.get("user_rag_memory_id", '') @@ -68,12 +93,24 @@ async def rag_knowledge(state, question): async def llm_infomation(state: ReadState) -> ReadState: + """ + Get LLM configuration information from state + + Retrieves model configuration details including model ID and tenant ID + from the memory configuration in the current state. + + Args: + state: ReadState containing memory configuration + + Returns: + ReadState: Model configuration as Pydantic model + """ memory_config = state.get('memory_config', None) model_id = memory_config.llm_model_id tenant_id = memory_config.tenant_id - # 使用现有的 memory_config 而不是重新查询数据库 - # 或者使用线程安全的数据库访问 + # Use existing memory_config instead of re-querying database + # or use thread-safe database access with get_db_context() as db: result_orm = ModelConfigService.get_model_by_id(db=db, model_id=model_id, tenant_id=tenant_id) result_pydantic = model_schema.ModelConfig.model_validate(result_orm) @@ -82,16 +119,20 @@ async def llm_infomation(state: ReadState) -> ReadState: async def clean_databases(data) -> str: """ - 简化的数据库搜索结果清理函数 + Simplified database search result cleaning function + + Processes and cleans search results from various sources including + reranked results and time-based search results. Extracts text content + from structured data and returns as formatted string. Args: - data: 搜索结果数据 + data: Search result data (can be string, dict, or other types) Returns: - 清理后的内容字符串 + str: Cleaned content string """ try: - # 解析JSON字符串 + # Parse JSON string if isinstance(data, str): try: data = json.loads(data) @@ -101,24 +142,24 @@ async def clean_databases(data) -> str: if not isinstance(data, dict): return str(data) - # 获取结果数据 + # Get result data # with open("搜索结果.json","w",encoding='utf-8') as f: # f.write(json.dumps(data, indent=4, ensure_ascii=False)) results = data.get('results', data) if not isinstance(results, dict): return str(results) - # 收集所有内容 + # Collect all content content_list = [] - # 处理重排序结果 + # Process reranked results reranked = results.get('reranked_results', {}) if reranked: for category in ['summaries', 'statements', 'chunks', 'entities']: items = reranked.get(category, []) if isinstance(items, list): content_list.extend(items) - # 处理时间搜索结果 + # Process time search results time_search = results.get('time_search', {}) if time_search: if isinstance(time_search, dict): @@ -128,7 +169,7 @@ async def clean_databases(data) -> str: elif isinstance(time_search, list): content_list.extend(time_search) - # 提取文本内容 + # Extract text content text_parts = [] for item in content_list: if isinstance(item, dict): @@ -146,10 +187,19 @@ async def clean_databases(data) -> str: async def retrieve_nodes(state: ReadState) -> ReadState: - ''' - - 模型信息 - ''' + """ + Retrieve information using simplified search approach + + Processes extended problems from previous nodes and performs retrieval + using either RAG or hybrid search based on storage type. Handles concurrent + processing of multiple questions and deduplicates results. + + Args: + state: ReadState containing problem extensions and configuration + + Returns: + ReadState: Updated state with retrieval results and intermediate outputs + """ problem_extension = state.get('problem_extension', '')['context'] storage_type = state.get('storage_type', '') @@ -163,7 +213,7 @@ async def retrieve_nodes(state: ReadState) -> ReadState: problem_list.append(data) logger.info(f"Retrieve: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}") - # 创建异步任务处理单个问题 + # Create async task to process individual questions async def process_question_nodes(idx, question): try: # Prepare search parameters based on storage type @@ -209,7 +259,7 @@ async def retrieve_nodes(state: ReadState) -> ReadState: } } - # 并发处理所有问题 + # Process all questions concurrently tasks = [process_question_nodes(idx, question) for idx, question in enumerate(problem_list)] databases_anser = await asyncio.gather(*tasks) databases_data = { @@ -257,7 +307,20 @@ async def retrieve_nodes(state: ReadState) -> ReadState: async def retrieve(state: ReadState) -> ReadState: - # 从state中获取end_user_id + """ + Advanced retrieve function using LangChain agents and tools + + Uses LangChain agents with specialized retrieval tools (time-based and hybrid) + to perform sophisticated information retrieval. Supports both RAG and traditional + memory storage approaches with concurrent processing and result deduplication. + + Args: + state: ReadState containing problem extensions and configuration + + Returns: + ReadState: Updated state with retrieval results and intermediate outputs + """ + # Get end_user_id from state import time start = time.time() problem_extension = state.get('problem_extension', '')['context'] @@ -299,21 +362,21 @@ async def retrieve(state: ReadState) -> ReadState: system_prompt=f"我是检索专家,可以根据适合的工具进行检索。当前使用的end_user_id是: {end_user_id}" ) - # 创建异步任务处理单个问题 + # Create async task to process individual questions import asyncio - # 在模块级别定义信号量,限制最大并发数 - SEMAPHORE = asyncio.Semaphore(5) # 限制最多5个并发数据库操作 + # Define semaphore at module level to limit maximum concurrency + SEMAPHORE = asyncio.Semaphore(5) # Limit to maximum 5 concurrent database operations async def process_question(idx, question): - async with SEMAPHORE: # 限制并发 + async with SEMAPHORE: # Limit concurrency try: if storage_type == "rag" and user_rag_memory_id: retrieval_knowledge, clean_content, cleaned_query, raw_results = await rag_knowledge(state, question) else: cleaned_query = question - # 使用 asyncio 在线程池中运行同步的 agent.invoke + # Use asyncio to run synchronous agent.invoke in thread pool import asyncio response = await asyncio.get_event_loop().run_in_executor( None, @@ -362,7 +425,7 @@ async def retrieve(state: ReadState) -> ReadState: } } - # 并发处理所有问题 + # Process all questions concurrently import asyncio tasks = [process_question(idx, question) for idx, question in enumerate(problem_list)] databases_anser = await asyncio.gather(*tasks) diff --git a/api/app/core/memory/agent/langgraph_graph/nodes/summary_nodes.py b/api/app/core/memory/agent/langgraph_graph/nodes/summary_nodes.py index 87606bf8..030acc9a 100644 --- a/api/app/core/memory/agent/langgraph_graph/nodes/summary_nodes.py +++ b/api/app/core/memory/agent/langgraph_graph/nodes/summary_nodes.py @@ -23,18 +23,39 @@ logger = get_agent_logger(__name__) class SummaryNodeService(LLMServiceMixin): - """总结节点服务类""" + """ + Summary node service class + + Handles summary generation operations using LLM services. Inherits from + LLMServiceMixin to provide structured LLM calling capabilities for + generating summaries from retrieved information. + + Attributes: + template_service: Service for rendering Jinja2 templates + """ def __init__(self): super().__init__() self.template_service = TemplateService(template_root) -# 创建全局服务实例 +# Create global service instance summary_service = SummaryNodeService() async def rag_config(state): + """ + Configure RAG (Retrieval-Augmented Generation) settings for summary operations + + Creates configuration for knowledge base retrieval including similarity thresholds, + weights, and reranker settings specifically for summary generation. + + Args: + state: Current state containing user_rag_memory_id + + Returns: + dict: RAG configuration dictionary with knowledge base settings + """ user_rag_memory_id = state.get('user_rag_memory_id', '') kb_config = { "knowledge_bases": [ @@ -54,6 +75,23 @@ async def rag_config(state): async def rag_knowledge(state, question): + """ + Retrieve knowledge using RAG approach for summary generation + + Performs knowledge retrieval from configured knowledge bases using the + provided question and returns formatted results for summary processing. + + Args: + state: Current state containing configuration + question: Question to search for in knowledge base + + Returns: + tuple: (retrieval_knowledge, clean_content, cleaned_query, raw_results) + - retrieval_knowledge: List of retrieved knowledge chunks + - clean_content: Formatted content string + - cleaned_query: Processed query string + - raw_results: Raw retrieval results + """ kb_config = await rag_config(state) end_user_id = state.get('end_user_id', '') user_rag_memory_id = state.get("user_rag_memory_id", '') @@ -74,6 +112,18 @@ async def rag_knowledge(state, question): async def summary_history(state: ReadState) -> ReadState: + """ + Retrieve conversation history for summary context + + Gets the conversation history for the current user to provide context + for summary generation operations. + + Args: + state: ReadState containing end_user_id + + Returns: + ReadState: Conversation history data + """ end_user_id = state.get("end_user_id", '') history = await SessionService(store).get_history(end_user_id, end_user_id, end_user_id) return history @@ -82,11 +132,26 @@ async def summary_history(state: ReadState) -> ReadState: async def summary_llm(state: ReadState, history, retrieve_info, template_name, operation_name, response_model, search_mode) -> str: """ - 增强的summary_llm函数,包含更好的错误处理和数据验证 + Enhanced summary_llm function with better error handling and data validation + + Generates summaries using LLM with structured output. Includes fallback mechanisms + for handling LLM failures and provides robust error recovery. + + Args: + state: ReadState containing current context + history: Conversation history for context + retrieve_info: Retrieved information to summarize + template_name: Jinja2 template name for prompt generation + operation_name: Type of operation (summary, input_summary, retrieve_summary) + response_model: Pydantic model for structured output + search_mode: Search mode flag ("0" for simple, "1" for complex) + + Returns: + str: Generated summary text or fallback message """ data = state.get("data", '') - # 构建系统提示词 + # Build system prompt if str(search_mode) == "0": system_prompt = await summary_service.template_service.render_template( template_name=template_name, @@ -103,7 +168,7 @@ async def summary_llm(state: ReadState, history, retrieve_info, template_name, o retrieve_info=retrieve_info ) try: - # 使用优化的LLM服务进行结构化输出 + # Use optimized LLM service for structured output with get_db_context() as db_session: structured = await summary_service.call_llm_structured( state=state, @@ -112,23 +177,23 @@ async def summary_llm(state: ReadState, history, retrieve_info, template_name, o response_model=response_model, fallback_value=None ) - # 验证结构化响应 + # Validate structured response if structured is None: logger.warning("LLM返回None,使用默认回答") return "信息不足,无法回答" - # 根据操作类型提取答案 + # Extract answer based on operation type if operation_name == "summary": aimessages = getattr(structured, 'query_answer', None) or "信息不足,无法回答" else: - # 处理RetrieveSummaryResponse + # Handle RetrieveSummaryResponse if hasattr(structured, 'data') and structured.data: aimessages = getattr(structured.data, 'query_answer', None) or "信息不足,无法回答" else: logger.warning("结构化响应缺少data字段") aimessages = "信息不足,无法回答" - # 验证答案不为空 + # Validate answer is not empty if not aimessages or aimessages.strip() == "": aimessages = "信息不足,无法回答" @@ -137,7 +202,7 @@ async def summary_llm(state: ReadState, history, retrieve_info, template_name, o except Exception as e: logger.error(f"结构化输出失败: {e}", exc_info=True) - # 尝试非结构化输出作为fallback + # Try unstructured output as fallback try: logger.info("尝试非结构化输出作为fallback") response = await summary_service.call_llm_simple( @@ -148,9 +213,9 @@ async def summary_llm(state: ReadState, history, retrieve_info, template_name, o ) if response and response.strip(): - # 简单清理响应 + # Simple response cleaning cleaned_response = response.strip() - # 移除可能的JSON标记 + # Remove possible JSON markers if cleaned_response.startswith('```'): lines = cleaned_response.split('\n') cleaned_response = '\n'.join(lines[1:-1]) @@ -165,6 +230,19 @@ async def summary_llm(state: ReadState, history, retrieve_info, template_name, o async def summary_redis_save(state: ReadState, aimessages) -> ReadState: + """ + Save summary results to Redis session storage + + Stores the generated summary and user query in Redis for session management + and conversation history tracking. + + Args: + state: ReadState containing user and query information + aimessages: Generated summary message to save + + Returns: + ReadState: Updated state after saving to Redis + """ data = state.get("data", '') end_user_id = state.get("end_user_id", '') await SessionService(store).save_session( @@ -179,6 +257,20 @@ async def summary_redis_save(state: ReadState, aimessages) -> ReadState: async def summary_prompt(state: ReadState, aimessages, raw_results) -> ReadState: + """ + Format summary results for different output types + + Creates structured output formats for both input summary and retrieval summary + operations, including metadata and intermediate results for frontend display. + + Args: + state: ReadState containing storage and user information + aimessages: Generated summary message + raw_results: Raw search/retrieval results + + Returns: + tuple: (input_summary, retrieve_summary) formatted result dictionaries + """ storage_type = state.get("storage_type", '') user_rag_memory_id = state.get("user_rag_memory_id", '') data = state.get("data", '') @@ -217,6 +309,19 @@ async def summary_prompt(state: ReadState, aimessages, raw_results) -> ReadState async def Input_Summary(state: ReadState) -> ReadState: + """ + Generate quick input summary from retrieved information + + Performs fast retrieval and generates a quick summary response for user queries. + This function prioritizes speed by only searching summary nodes and provides + immediate feedback to users. + + Args: + state: ReadState containing user query, storage configuration, and context + + Returns: + ReadState: Dictionary containing summary results with status and metadata + """ start = time.time() storage_type = state.get("storage_type", '') memory_config = state.get('memory_config', None) @@ -266,6 +371,19 @@ async def Input_Summary(state: ReadState) -> ReadState: async def Retrieve_Summary(state: ReadState) -> ReadState: + """ + Generate comprehensive summary from retrieved expansion issues + + Processes retrieved expansion issues and generates a detailed summary using LLM. + This function handles complex retrieval results and provides comprehensive answers + based on expanded query results. + + Args: + state: ReadState containing retrieve data with expansion issues + + Returns: + ReadState: Dictionary containing comprehensive summary results + """ retrieve = state.get("retrieve", '') history = await summary_history(state) import json @@ -299,13 +417,26 @@ async def Retrieve_Summary(state: ReadState) -> ReadState: duration = 0.0 log_time('Retrieval summary', duration) - # 修复协程调用 - 先await,然后访问返回值 + # Fixed coroutine call - await first, then access return value summary_result = await summary_prompt(state, aimessages, retrieve_info_str) summary = summary_result[1] return {"summary": summary} async def Summary(state: ReadState) -> ReadState: + """ + Generate final comprehensive summary from verified data + + Creates the final summary using verified expansion issues and conversation history. + This function processes verified data to generate the most comprehensive and + accurate response to user queries. + + Args: + state: ReadState containing verified data and query information + + Returns: + ReadState: Dictionary containing final summary results + """ start = time.time() query = state.get("data", '') verify = state.get("verify", '') @@ -336,13 +467,26 @@ async def Summary(state: ReadState) -> ReadState: duration = 0.0 log_time('Retrieval summary', duration) - # 修复协程调用 - 先await,然后访问返回值 + # Fixed coroutine call - await first, then access return value summary_result = await summary_prompt(state, aimessages, retrieve_info_str) summary = summary_result[1] return {"summary": summary} async def Summary_fails(state: ReadState) -> ReadState: + """ + Generate fallback summary when normal summary process fails + + Provides a fallback summary generation mechanism when the standard summary + process encounters errors or fails to produce satisfactory results. Uses + a specialized failure template to handle edge cases. + + Args: + state: ReadState containing verified data and failure context + + Returns: + ReadState: Dictionary containing fallback summary results + """ storage_type = state.get("storage_type", '') user_rag_memory_id = state.get("user_rag_memory_id", '') history = await summary_history(state) diff --git a/api/app/core/memory/agent/langgraph_graph/nodes/verification_nodes.py b/api/app/core/memory/agent/langgraph_graph/nodes/verification_nodes.py index 3f7b491e..3a04b411 100644 --- a/api/app/core/memory/agent/langgraph_graph/nodes/verification_nodes.py +++ b/api/app/core/memory/agent/langgraph_graph/nodes/verification_nodes.py @@ -18,24 +18,46 @@ logger = get_agent_logger(__name__) class VerificationNodeService(LLMServiceMixin): - """验证节点服务类""" + """ + Verification node service class + + Handles data verification operations using LLM services. Inherits from + LLMServiceMixin to provide structured LLM calling capabilities for + verifying and validating retrieved information. + + Attributes: + template_service: Service for rendering Jinja2 templates + """ def __init__(self): super().__init__() self.template_service = TemplateService(template_root) -# 创建全局服务实例 +# Create global service instance verification_service = VerificationNodeService() async def Verify_prompt(state: ReadState, messages_deal: VerificationResult): - """处理验证结果并生成输出格式""" + """ + Process verification results and generate output format + + Transforms VerificationResult objects into structured output format suitable + for frontend consumption. Handles conversion of VerificationItem objects to + dictionary format and adds metadata for tracking. + + Args: + state: ReadState containing storage and user configuration + messages_deal: VerificationResult containing verification outcomes + + Returns: + dict: Formatted verification result with status and metadata + """ storage_type = state.get('storage_type', '') user_rag_memory_id = state.get('user_rag_memory_id', '') data = state.get('data', '') - # 将 VerificationItem 对象转换为字典列表 + # Convert VerificationItem objects to dictionary list verified_data = [] if messages_deal.expansion_issue: for item in messages_deal.expansion_issue: @@ -89,7 +111,7 @@ async def Verify(state: ReadState): logger.info("Verify: 开始渲染模板") - # 生成 JSON schema 以指导 LLM 输出正确格式 + # Generate JSON schema to guide LLM output format json_schema = VerificationResult.model_json_schema() system_prompt = await verification_service.template_service.render_template( @@ -104,8 +126,8 @@ async def Verify(state: ReadState): # 使用优化的LLM服务,添加超时保护 logger.info("Verify: 开始调用 LLM") try: - # 添加 asyncio.wait_for 超时包裹,防止无限等待 - # 超时时间设置为 150 秒(比 LLM 配置的 120 秒稍长) + # Add asyncio.wait_for timeout wrapper to prevent infinite waiting + # Timeout set to 150 seconds (slightly longer than LLM config's 120 seconds) with get_db_context() as db_session: structured = await asyncio.wait_for( @@ -122,7 +144,7 @@ async def Verify(state: ReadState): "reason": "验证失败或超时" } ), - timeout=150.0 # 150秒超时 + timeout=150.0 # 150 second timeout ) logger.info(f"Verify: LLM 调用完成,result={structured}") except asyncio.TimeoutError: diff --git a/api/app/core/memory/agent/langgraph_graph/read_graph.py b/api/app/core/memory/agent/langgraph_graph/read_graph.py index cba1b230..bddae618 100644 --- a/api/app/core/memory/agent/langgraph_graph/read_graph.py +++ b/api/app/core/memory/agent/langgraph_graph/read_graph.py @@ -33,7 +33,19 @@ from app.core.memory.agent.langgraph_graph.routing.routers import ( @asynccontextmanager async def make_read_graph(): - """创建并返回 LangGraph 工作流""" + """ + Create and return a LangGraph workflow for memory reading operations + + Builds a state graph workflow that handles memory retrieval, problem analysis, + verification, and summarization. The workflow includes nodes for content input, + problem splitting, retrieval, verification, and various summary operations. + + Yields: + StateGraph: Compiled LangGraph workflow for memory reading + + Raises: + Exception: If workflow creation fails + """ try: # Build workflow graph workflow = StateGraph(ReadState) @@ -48,7 +60,7 @@ async def make_read_graph(): workflow.add_node("Summary", Summary) workflow.add_node("Summary_fails", Summary_fails) - # 添加边 + # Add edges to define workflow flow workflow.add_edge(START, "content_input") workflow.add_conditional_edges("content_input", Split_continue) workflow.add_edge("Input_Summary", END) @@ -63,7 +75,7 @@ async def make_read_graph(): '''-----''' # workflow.add_edge("Retrieve", END) - # 编译工作流 + # Compile workflow graph = workflow.compile() yield graph @@ -72,108 +84,3 @@ async def make_read_graph(): raise finally: print("工作流创建完成") - - -async def main(): - """主函数 - 运行工作流""" - message = "昨天有什么好看的电影" - end_user_id = '88a459f5_text09' # 组ID - storage_type = 'neo4j' # 存储类型 - search_switch = '1' # 搜索开关 - user_rag_memory_id = 'wwwwwwww' # 用户RAG记忆ID - - # 获取数据库会话 - db_session = next(get_db()) - config_service = MemoryConfigService(db_session) - memory_config = config_service.load_memory_config( - config_id=17, # 改为整数 - service_name="MemoryAgentService" - ) - import time - start = time.time() - try: - async with make_read_graph() as graph: - config = {"configurable": {"thread_id": end_user_id}} - # 初始状态 - 包含所有必要字段 - initial_state = {"messages": [HumanMessage(content=message)], "search_switch": search_switch, - "end_user_id": end_user_id - , "storage_type": storage_type, "user_rag_memory_id": user_rag_memory_id, - "memory_config": memory_config} - # 获取节点更新信息 - _intermediate_outputs = [] - summary = '' - - async for update_event in graph.astream( - initial_state, - stream_mode="updates", - config=config - ): - for node_name, node_data in update_event.items(): - print(f"处理节点: {node_name}") - - # 处理不同Summary节点的返回结构 - if 'Summary' in node_name: - if 'InputSummary' in node_data and 'summary_result' in node_data['InputSummary']: - summary = node_data['InputSummary']['summary_result'] - elif 'RetrieveSummary' in node_data and 'summary_result' in node_data['RetrieveSummary']: - summary = node_data['RetrieveSummary']['summary_result'] - elif 'summary' in node_data and 'summary_result' in node_data['summary']: - summary = node_data['summary']['summary_result'] - elif 'SummaryFails' in node_data and 'summary_result' in node_data['SummaryFails']: - summary = node_data['SummaryFails']['summary_result'] - - spit_data = node_data.get('spit_data', {}).get('_intermediate', None) - if spit_data and spit_data != [] and spit_data != {}: - _intermediate_outputs.append(spit_data) - - # Problem_Extension 节点 - problem_extension = node_data.get('problem_extension', {}).get('_intermediate', None) - if problem_extension and problem_extension != [] and problem_extension != {}: - _intermediate_outputs.append(problem_extension) - - # Retrieve 节点 - retrieve_node = node_data.get('retrieve', {}).get('_intermediate_outputs', None) - if retrieve_node and retrieve_node != [] and retrieve_node != {}: - _intermediate_outputs.extend(retrieve_node) - - # Verify 节点 - verify_n = node_data.get('verify', {}).get('_intermediate', None) - if verify_n and verify_n != [] and verify_n != {}: - _intermediate_outputs.append(verify_n) - - # Summary 节点 - summary_n = node_data.get('summary', {}).get('_intermediate', None) - if summary_n and summary_n != [] and summary_n != {}: - _intermediate_outputs.append(summary_n) - - # # 过滤掉空值 - # _intermediate_outputs = [item for item in _intermediate_outputs if item and item != [] and item != {}] - # - # # 优化搜索结果 - # print("=== 开始优化搜索结果 ===") - # optimized_outputs = merge_multiple_search_results(_intermediate_outputs) - # result=reorder_output_results(optimized_outputs) - # # 保存优化后的结果到文件 - # with open('_intermediate_outputs_optimized.json', 'w', encoding='utf-8') as f: - # import json - # f.write(json.dumps(result, indent=4, ensure_ascii=False)) - # - print(f"=== 最终摘要 ===") - print(summary) - - except Exception as e: - import traceback - traceback.print_exc() - finally: - db_session.close() - - end = time.time() - print(100 * 'y') - print(f"总耗时: {end - start}s") - print(100 * 'y') - - -if __name__ == "__main__": - import asyncio - - asyncio.run(main()) diff --git a/api/app/core/memory/agent/langgraph_graph/routing/write_router.py b/api/app/core/memory/agent/langgraph_graph/routing/write_router.py index 895f61ac..ddb6ca3e 100644 --- a/api/app/core/memory/agent/langgraph_graph/routing/write_router.py +++ b/api/app/core/memory/agent/langgraph_graph/routing/write_router.py @@ -22,57 +22,73 @@ logger = get_agent_logger(__name__) template_root = os.path.join(PROJECT_ROOT_, 'memory', 'agent', 'utils', 'prompt') async def write_rag_agent(end_user_id, user_message, ai_message, user_rag_memory_id): - # RAG 模式:组合消息为字符串格式(保持原有逻辑) + """ + Write messages to RAG storage system + + Combines user and AI messages into a single string format and stores them + in the RAG (Retrieval-Augmented Generation) knowledge base for future retrieval. + + Args: + end_user_id: User identifier for the conversation + user_message: User's input message content + ai_message: AI's response message content + user_rag_memory_id: RAG memory identifier for storage location + """ + # RAG mode: combine messages into string format (maintain original logic) combined_message = f"user: {user_message}\nassistant: {ai_message}" await write_rag(end_user_id, combined_message, user_rag_memory_id) logger.info(f'RAG_Agent:{end_user_id};{user_rag_memory_id}') async def write(storage_type, end_user_id, user_message, ai_message, user_rag_memory_id, actual_end_user_id, actual_config_id, long_term_messages=[]): """ - 写入记忆(支持结构化消息) - + Write memory with structured message support + + Handles memory writing operations for different storage types (Neo4j/RAG). + Supports both individual message pairs and batch long-term message processing. + Args: - storage_type: 存储类型 (neo4j/rag) - end_user_id: 终端用户ID - user_message: 用户消息内容 - ai_message: AI 回复内容 - user_rag_memory_id: RAG 记忆ID - actual_end_user_id: 实际用户ID - actual_config_id: 配置ID - - 逻辑说明: - - RAG 模式:组合 user_message 和 ai_message 为字符串格式,保持原有逻辑不变 - - Neo4j 模式:使用结构化消息列表 - 1. 如果 user_message 和 ai_message 都不为空:创建配对消息 [user, assistant] - 2. 如果只有 user_message:创建单条用户消息 [user](用于历史记忆场景) - 3. 每条消息会被转换为独立的 Chunk,保留 speaker 字段 + storage_type: Storage type identifier ("neo4j" or "rag") + end_user_id: Terminal user identifier + user_message: User message content + ai_message: AI response content + user_rag_memory_id: RAG memory identifier + actual_end_user_id: Actual user identifier for storage + actual_config_id: Configuration identifier + long_term_messages: Optional list of structured messages for batch processing + + Logic explanation: + - RAG mode: Combines user_message and ai_message into string format, maintains original logic + - Neo4j mode: Uses structured message lists + 1. If both user_message and ai_message are not empty: Creates paired messages [user, assistant] + 2. If only user_message exists: Creates single user message [user] (for historical memory scenarios) + 3. Each message is converted to independent Chunk, preserving speaker field """ db = next(get_db()) try: actual_config_id = resolve_config_id(actual_config_id, db) - # Neo4j 模式:使用结构化消息列表 + # Neo4j mode: Use structured message lists structured_messages = [] - # 始终添加用户消息(如果不为空) + # Always add user message (if not empty) if isinstance(user_message, str) and user_message.strip() != "": structured_messages.append({"role": "user", "content": user_message}) - # 只有当 AI 回复不为空时才添加 assistant 消息 + # Only add assistant message when AI reply is not empty if isinstance(ai_message, str) and ai_message.strip() != "": structured_messages.append({"role": "assistant", "content": ai_message}) - # 如果提供了 long_term_messages,使用它替代 structured_messages + # If long_term_messages provided, use it to replace structured_messages if long_term_messages and isinstance(long_term_messages, list): structured_messages = long_term_messages elif long_term_messages and isinstance(long_term_messages, str): - # 如果是 JSON 字符串,先解析 + # If it's a JSON string, parse it first try: structured_messages = json.loads(long_term_messages) except json.JSONDecodeError: logger.error(f"Failed to parse long_term_messages as JSON: {long_term_messages}") - # 如果没有消息,直接返回 + # If no messages, return directly if not structured_messages: logger.warning(f"No messages to write for user {actual_end_user_id}") return @@ -80,11 +96,11 @@ async def write(storage_type, end_user_id, user_message, ai_message, user_rag_me logger.info( f"[WRITE] Submitting Celery task - user={actual_end_user_id}, messages={len(structured_messages)}, config={actual_config_id}") write_id = write_message_task.delay( - actual_end_user_id, # end_user_id: 用户ID - structured_messages, # message: JSON 字符串格式的消息列表 - str(actual_config_id), # config_id: 配置ID字符串 + actual_end_user_id, # end_user_id: User ID + structured_messages, # message: JSON string format message list + str(actual_config_id), # config_id: Configuration ID string storage_type, # storage_type: "neo4j" - user_rag_memory_id or "" # user_rag_memory_id: RAG记忆ID(Neo4j模式下不使用) + user_rag_memory_id or "" # user_rag_memory_id: RAG memory ID (not used in Neo4j mode) ) logger.info(f"[WRITE] Celery task submitted - task_id={write_id}") write_status = get_task_memory_write_result(str(write_id)) @@ -93,6 +109,20 @@ async def write(storage_type, end_user_id, user_message, ai_message, user_rag_me db.close() async def term_memory_save(long_term_messages,actual_config_id,end_user_id,type,scope): + """ + Save long-term memory data to database + + Handles the storage of long-term memory data based on different strategies + (chunk-based or aggregate-based) and manages the transition from short-term + to long-term memory storage. + + Args: + long_term_messages: Long-term message data to be saved + actual_config_id: Configuration identifier for memory settings + end_user_id: User identifier for memory association + type: Memory storage strategy type (STRATEGY_CHUNK or STRATEGY_AGGREGATE) + scope: Scope/window size for memory processing + """ with get_db_context() as db_session: repo = LongTermMemoryRepository(db_session) @@ -113,16 +143,20 @@ async def term_memory_save(long_term_messages,actual_config_id,end_user_id,type, -'''根据窗口''' +"""Window-based dialogue processing""" async def window_dialogue(end_user_id,langchain_messages,memory_config,scope): - ''' - 根据窗口获取redis数据,写入neo4j: - Args: - end_user_id: 终端用户ID - memory_config: 内存配置对象 - langchain_messages:原始数据LIST - scope:窗口大小 - ''' + """ + Process dialogue based on window size and write to Neo4j + + Manages conversation data based on a sliding window approach. When the window + reaches the specified scope size, it triggers long-term memory storage to Neo4j. + + Args: + end_user_id: Terminal user identifier + memory_config: Memory configuration object containing settings + langchain_messages: Original message data list + scope: Window size determining when to trigger long-term storage + """ scope=scope is_end_user_id = count_store.get_sessions_count(end_user_id) if is_end_user_id is not False: @@ -135,7 +169,7 @@ async def window_dialogue(end_user_id,langchain_messages,memory_config,scope): elif int(is_end_user_id) == int(scope): logger.info('写入长期记忆NEO4J') formatted_messages = (redis_messages) - # 获取 config_id(如果 memory_config 是对象,提取 config_id;否则直接使用) + # Get config_id (if memory_config is an object, extract config_id; otherwise use directly) if hasattr(memory_config, 'config_id'): config_id = memory_config.config_id else: @@ -148,14 +182,19 @@ async def window_dialogue(end_user_id,langchain_messages,memory_config,scope): count_store.save_sessions_count(end_user_id, 1, langchain_messages) -"""根据时间""" +"""Time-based memory processing""" async def memory_long_term_storage(end_user_id,memory_config,time): - ''' - 根据时间获取redis数据,写入neo4j: - Args: - end_user_id: 终端用户ID - memory_config: 内存配置对象 - ''' + """ + Process memory storage based on time intervals and write to Neo4j + + Retrieves Redis data based on time intervals and writes it to Neo4j for + long-term storage. This function handles time-based memory consolidation. + + Args: + end_user_id: Terminal user identifier + memory_config: Memory configuration object containing settings + time: Time interval for data retrieval + """ long_time_data = write_store.find_user_recent_sessions(end_user_id, time) format_messages = (long_time_data) messages=[] @@ -166,19 +205,25 @@ async def memory_long_term_storage(end_user_id,memory_config,time): if format_messages!=[]: await write(AgentMemory_Long_Term.STORAGE_NEO4J, end_user_id, "", "", None, end_user_id, memory_config, messages) -'''聚合判断''' +"""Aggregation judgment processing""" async def aggregate_judgment(end_user_id: str, ori_messages: list, memory_config) -> dict: """ - 聚合判断函数:判断输入句子和历史消息是否描述同一事件 + Aggregation judgment function: determine if input sentence and historical messages describe the same event + + Uses LLM-based analysis to determine whether new messages should be aggregated with existing + historical data or stored as separate events. This helps optimize memory storage and retrieval. Args: - end_user_id: 终端用户ID - ori_messages: 原始消息列表,格式如 [{"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}] - memory_config: 内存配置对象 + end_user_id: Terminal user identifier + ori_messages: Original message list, format like [{"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}] + memory_config: Memory configuration object containing LLM settings + + Returns: + dict: Aggregation judgment result containing is_same_event flag and processed output """ try: - # 1. 获取历史会话数据(使用新方法) + # 1. Get historical session data (using new method) result = write_store.get_all_sessions_by_end_user_id(end_user_id) history = await format_parsing(result) if not result: diff --git a/api/app/core/memory/agent/langgraph_graph/tools/tool.py b/api/app/core/memory/agent/langgraph_graph/tools/tool.py index fcbb18e3..bee77ddf 100644 --- a/api/app/core/memory/agent/langgraph_graph/tools/tool.py +++ b/api/app/core/memory/agent/langgraph_graph/tools/tool.py @@ -13,30 +13,43 @@ from app.core.memory.src.search import ( ) def extract_tool_message_content(response): - """从agent响应中提取ToolMessage内容和工具名称""" + """ + Extract ToolMessage content and tool names from agent response + + Parses agent response messages to extract tool execution results and metadata. + Handles JSON parsing and provides structured access to tool output data. + + Args: + response: Agent response dictionary containing messages + + Returns: + dict: Dictionary containing tool_name and parsed content, or None if no tool message found + - tool_name: Name of the executed tool + - content: Parsed tool execution result (JSON or raw text) + """ messages = response.get('messages', []) for message in messages: if hasattr(message, 'tool_call_id') and hasattr(message, 'content'): - # 这是一个ToolMessage + # This is a ToolMessage tool_content = message.content tool_name = None - # 尝试获取工具名称 + # Try to get tool name if hasattr(message, 'name'): tool_name = message.name elif hasattr(message, 'tool_name'): tool_name = message.tool_name try: - # 解析JSON内容 + # Parse JSON content parsed_content = json.loads(tool_content) return { 'tool_name': tool_name, 'content': parsed_content } except json.JSONDecodeError: - # 如果不是JSON格式,直接返回内容 + # If not JSON format, return content directly return { 'tool_name': tool_name, 'content': tool_content @@ -46,26 +59,48 @@ def extract_tool_message_content(response): class TimeRetrievalInput(BaseModel): - """时间检索工具的输入模式""" + """ + Input schema for time retrieval tool + + Defines the expected input parameters for time-based retrieval operations. + Used for validation and documentation of tool parameters. + + Attributes: + context: User input query content for search + end_user_id: Group ID for filtering search results, defaults to test user + """ context: str = Field(description="用户输入的查询内容") end_user_id: str = Field(default="88a459f5_text09", description="组ID,用于过滤搜索结果") def create_time_retrieval_tool(end_user_id: str): """ - 创建一个带有特定end_user_id的TimeRetrieval工具(同步版本),用于按时间范围搜索语句(Statements) + Create a TimeRetrieval tool with specific end_user_id (synchronous version) for searching statements by time range + + Creates a specialized time-based retrieval tool that searches for statements within + specified time ranges. Includes field cleaning functionality to remove unnecessary + metadata from search results. + + Args: + end_user_id: User identifier for scoping search results + + Returns: + function: Configured TimeRetrievalWithGroupId tool function """ def clean_temporal_result_fields(data): """ - 清理时间搜索结果中不需要的字段,并修改结构 + Clean unnecessary fields from temporal search results and modify structure + + Removes metadata fields that are not needed for end-user consumption and + restructures the response format for better usability. Args: - data: 要清理的数据 + data: Data to be cleaned (dict, list, or other types) Returns: - 清理后的数据 + Cleaned data with unnecessary fields removed """ - # 需要过滤的字段列表 + # List of fields to filter out fields_to_remove = { 'id', 'apply_id', 'user_id', 'chunk_id', 'created_at', 'valid_at', 'invalid_at', 'statement_ids' @@ -75,9 +110,9 @@ def create_time_retrieval_tool(end_user_id: str): cleaned = {} for key, value in data.items(): if key == 'statements' and isinstance(value, dict) and 'statements' in value: - # 将 statements: {"statements": [...]} 改为 time_search: {"statements": [...]} + # Change statements: {"statements": [...]} to time_search: {"statements": [...]} cleaned_value = clean_temporal_result_fields(value) - # 进一步将内部的 statements 改为 time_search + # Further change internal statements to time_search if 'statements' in cleaned_value: cleaned['results'] = { 'time_search': cleaned_value['statements'] @@ -95,22 +130,29 @@ def create_time_retrieval_tool(end_user_id: str): @tool def TimeRetrievalWithGroupId(context: str, start_date: str = None, end_date: str = None, end_user_id_param: str = None, clean_output: bool = True) -> str: """ - 优化的时间检索工具,只结合时间范围搜索(同步版本),自动过滤不需要的元数据字段 - 显式接收参数: - - context: 查询上下文内容 - - start_date: 开始时间(可选,格式:YYYY-MM-DD) - - end_date: 结束时间(可选,格式:YYYY-MM-DD) - - end_user_id_param: 组ID(可选,用于覆盖默认组ID) - - clean_output: 是否清理输出中的元数据字段 - -end_date 需要根据用户的描述获取结束的时间,输出格式用strftime("%Y-%m-%d") + Optimized time retrieval tool, combines time range search only (synchronous version), automatically filters unnecessary metadata fields + + Performs time-based search operations with automatic metadata filtering. Supports + flexible date range specification and provides clean, user-friendly output. + + Explicit parameters: + - context: Query context content + - start_date: Start time (optional, format: YYYY-MM-DD) + - end_date: End time (optional, format: YYYY-MM-DD) + - end_user_id_param: Group ID (optional, overrides default group ID) + - clean_output: Whether to clean metadata fields from output + - end_date needs to be obtained based on user description, output format uses strftime("%Y-%m-%d") + + Returns: + str: JSON formatted search results with temporal data """ async def _async_search(): - # 使用传入的参数或默认值 + # Use passed parameters or default values actual_end_user_id = end_user_id_param or end_user_id actual_end_date = end_date or datetime.now().strftime("%Y-%m-%d") actual_start_date = start_date or (datetime.now() - timedelta(days=7)).strftime("%Y-%m-%d") - # 基本时间搜索 + # Basic time search results = await search_by_temporal( end_user_id=actual_end_user_id, start_date=actual_start_date, @@ -118,7 +160,7 @@ def create_time_retrieval_tool(end_user_id: str): limit=10 ) - # 清理结果中不需要的字段 + # Clean unnecessary fields from results if clean_output: cleaned_results = clean_temporal_result_fields(results) else: @@ -131,20 +173,28 @@ def create_time_retrieval_tool(end_user_id: str): @tool def KeywordTimeRetrieval(context: str, days_back: int = 7, start_date: str = None, end_date: str = None, clean_output: bool = True) -> str: """ - 优化的关键词时间检索工具,结合关键词和时间范围搜索(同步版本),自动过滤不需要的元数据字段 - 显式接收参数: - - context: 查询内容 - - days_back: 向前搜索的天数,默认7天 - - start_date: 开始时间(可选,格式:YYYY-MM-DD) - - end_date: 结束时间(可选,格式:YYYY-MM-DD) - - clean_output: 是否清理输出中的元数据字段 - - end_date 需要根据用户的描述获取结束的时间,输出格式用strftime("%Y-%m-%d") + Optimized keyword time retrieval tool, combines keyword and time range search (synchronous version), automatically filters unnecessary metadata fields + + Performs combined keyword and temporal search operations with automatic metadata + filtering. Provides more targeted search results by combining content relevance + with time-based filtering. + + Explicit parameters: + - context: Query content for keyword matching + - days_back: Number of days to search backwards, default 7 days + - start_date: Start time (optional, format: YYYY-MM-DD) + - end_date: End time (optional, format: YYYY-MM-DD) + - clean_output: Whether to clean metadata fields from output + - end_date needs to be obtained based on user description, output format uses strftime("%Y-%m-%d") + + Returns: + str: JSON formatted search results combining keyword and temporal data """ async def _async_search(): actual_end_date = end_date or datetime.now().strftime("%Y-%m-%d") actual_start_date = start_date or (datetime.now() - timedelta(days=days_back)).strftime("%Y-%m-%d") - # 关键词时间搜索 + # Keyword time search results = await search_by_keyword_temporal( query_text=context, end_user_id=end_user_id, @@ -153,7 +203,7 @@ def create_time_retrieval_tool(end_user_id: str): limit=15 ) - # 清理结果中不需要的字段 + # Clean unnecessary fields from results if clean_output: cleaned_results = clean_temporal_result_fields(results) else: @@ -168,25 +218,34 @@ def create_time_retrieval_tool(end_user_id: str): def create_hybrid_retrieval_tool_async(memory_config, **search_params): """ - 创建混合检索工具,使用run_hybrid_search进行混合检索,优化输出格式并过滤不需要的字段 + Create hybrid retrieval tool using run_hybrid_search for hybrid retrieval, optimize output format and filter unnecessary fields + + Creates an advanced hybrid search tool that combines multiple search strategies + (keyword, vector, hybrid) with automatic result cleaning and formatting. Args: - memory_config: 内存配置对象 - **search_params: 搜索参数,包含end_user_id, limit, include等 + memory_config: Memory configuration object containing LLM and search settings + **search_params: Search parameters including end_user_id, limit, include, etc. + + Returns: + function: Configured HybridSearch tool function with async capabilities """ def clean_result_fields(data): """ - 递归清理结果中不需要的字段 + Recursively clean unnecessary fields from results + + Removes metadata fields that are not needed for end-user consumption, + improving readability and reducing response size. Args: - data: 要清理的数据(可能是字典、列表或其他类型) + data: Data to be cleaned (can be dict, list, or other types) Returns: - 清理后的数据 + Cleaned data with unnecessary fields removed """ - # 需要过滤的字段列表 - # TODO: fact_summary 功能暂时禁用,待后续开发完善后启用 + # List of fields to filter out + # TODO: fact_summary functionality temporarily disabled, will be enabled after future development fields_to_remove = { 'invalid_at', 'valid_at', 'chunk_id_from_rel', 'entity_ids', 'expired_at', 'created_at', 'chunk_id', 'id', 'apply_id', @@ -194,17 +253,17 @@ def create_hybrid_retrieval_tool_async(memory_config, **search_params): } if isinstance(data, dict): - # 对字典进行清理 + # Clean dictionary cleaned = {} for key, value in data.items(): if key not in fields_to_remove: - cleaned[key] = clean_result_fields(value) # 递归清理嵌套数据 + cleaned[key] = clean_result_fields(value) # Recursively clean nested data return cleaned elif isinstance(data, list): - # 对列表中的每个元素进行清理 + # Clean each element in list return [clean_result_fields(item) for item in data] else: - # 其他类型直接返回 + # Return other types directly return data @tool @@ -216,49 +275,55 @@ def create_hybrid_retrieval_tool_async(memory_config, **search_params): rerank_alpha: float = 0.6, use_forgetting_rerank: bool = False, use_llm_rerank: bool = False, - clean_output: bool = True # 新增:是否清理输出字段 + clean_output: bool = True # New: whether to clean output fields ) -> str: """ - 优化的混合检索工具,支持关键词、向量和混合搜索,自动过滤不需要的元数据字段 + Optimized hybrid retrieval tool, supports keyword, vector and hybrid search, automatically filters unnecessary metadata fields + + Provides comprehensive search capabilities combining multiple search strategies + with intelligent result ranking and automatic metadata filtering for clean output. Args: - context: 查询内容 - search_type: 搜索类型 ('keyword', 'embedding', 'hybrid') - limit: 结果数量限制 - end_user_id: 组ID,用于过滤搜索结果 - rerank_alpha: 重排序权重参数 - use_forgetting_rerank: 是否使用遗忘重排序 - use_llm_rerank: 是否使用LLM重排序 - clean_output: 是否清理输出中的元数据字段 + context: Query content for search + search_type: Search type ('keyword', 'embedding', 'hybrid') + limit: Result quantity limit + end_user_id: Group ID for filtering search results + rerank_alpha: Reranking weight parameter for result scoring + use_forgetting_rerank: Whether to use forgetting-based reranking + use_llm_rerank: Whether to use LLM-based reranking + clean_output: Whether to clean metadata fields from output + + Returns: + str: JSON formatted comprehensive search results """ try: - # 导入run_hybrid_search函数 + # Import run_hybrid_search function from app.core.memory.src.search import run_hybrid_search - # 合并参数,优先使用传入的参数 + # Merge parameters, prioritize passed parameters final_params = { "query_text": context, "search_type": search_type, "end_user_id": end_user_id or search_params.get("end_user_id"), "limit": limit or search_params.get("limit", 10), "include": search_params.get("include", ["summaries", "statements", "chunks", "entities"]), - "output_path": None, # 不保存到文件 + "output_path": None, # Don't save to file "memory_config": memory_config, "rerank_alpha": rerank_alpha, "use_forgetting_rerank": use_forgetting_rerank, "use_llm_rerank": use_llm_rerank } - # 执行混合检索 + # Execute hybrid retrieval raw_results = await run_hybrid_search(**final_params) - # 清理结果中不需要的字段 + # Clean unnecessary fields from results if clean_output: cleaned_results = clean_result_fields(raw_results) else: cleaned_results = raw_results - # 格式化返回结果 + # Format return results formatted_results = { "search_query": context, "search_type": search_type, @@ -281,11 +346,17 @@ def create_hybrid_retrieval_tool_async(memory_config, **search_params): def create_hybrid_retrieval_tool_sync(memory_config, **search_params): """ - 创建同步版本的混合检索工具,优化输出格式并过滤不需要的字段 + Create synchronous version of hybrid retrieval tool, optimize output format and filter unnecessary fields + + Creates a synchronous wrapper around the async hybrid search functionality, + making it compatible with synchronous tool execution environments. Args: - memory_config: 内存配置对象 - **search_params: 搜索参数 + memory_config: Memory configuration object containing search settings + **search_params: Search parameters for configuration + + Returns: + function: Configured HybridSearchSync tool function """ @tool def HybridSearchSync( @@ -296,17 +367,23 @@ def create_hybrid_retrieval_tool_sync(memory_config, **search_params): clean_output: bool = True ) -> str: """ - 优化的混合检索工具(同步版本),自动过滤不需要的元数据字段 + Optimized hybrid retrieval tool (synchronous version), automatically filters unnecessary metadata fields + + Provides the same hybrid search capabilities as the async version but in a + synchronous execution context. Automatically handles async-to-sync conversion. Args: - context: 查询内容 - search_type: 搜索类型 ('keyword', 'embedding', 'hybrid') - limit: 结果数量限制 - end_user_id: 组ID,用于过滤搜索结果 - clean_output: 是否清理输出中的元数据字段 + context: Query content for search + search_type: Search type ('keyword', 'embedding', 'hybrid') + limit: Result quantity limit + end_user_id: Group ID for filtering search results + clean_output: Whether to clean metadata fields from output + + Returns: + str: JSON formatted search results """ async def _async_search(): - # 创建异步工具并执行 + # Create async tool and execute async_tool = create_hybrid_retrieval_tool_async(memory_config, **search_params) return await async_tool.ainvoke({ "context": context, diff --git a/api/app/core/memory/agent/langgraph_graph/tools/write_tool.py b/api/app/core/memory/agent/langgraph_graph/tools/write_tool.py index 9ce581ee..aa5e09a6 100644 --- a/api/app/core/memory/agent/langgraph_graph/tools/write_tool.py +++ b/api/app/core/memory/agent/langgraph_graph/tools/write_tool.py @@ -3,14 +3,20 @@ import json from langchain_core.messages import HumanMessage, AIMessage async def format_parsing(messages: list,type:str='string'): """ - 格式化解析消息列表 + Format and parse message lists into different output types + + Processes message lists from storage and converts them into either string format + or dictionary format based on the specified type parameter. Handles JSON parsing + and role-based message organization. Args: - messages: 消息列表 - type: 返回类型 ('string' 或 'dict') + messages: List of message objects from storage containing message data + type: Return type specification ('string' for text format, 'dict' for key-value pairs) Returns: - 格式化后的消息列表 + list: Formatted message list in the specified format + - 'string': List of formatted text messages with role prefixes + - 'dict': List of dictionaries mapping user messages to AI responses """ result = [] user=[] @@ -40,6 +46,18 @@ async def format_parsing(messages: list,type:str='string'): return result async def messages_parse(messages: list | dict): + """ + Parse messages from storage format into user-AI conversation pairs + + Extracts and organizes conversation data from stored message format, + separating user and AI messages and pairing them for database storage. + + Args: + messages: List or dictionary containing stored message data with Query fields + + Returns: + list: List of dictionaries containing user-AI message pairs for database storage + """ user=[] ai=[] database=[] @@ -58,6 +76,19 @@ async def messages_parse(messages: list | dict): async def agent_chat_messages(user_content,ai_content): + """ + Create structured chat message format for agent conversations + + Formats user and AI content into a standardized message structure suitable + for agent processing and storage. Creates role-based message objects. + + Args: + user_content: User's message content string + ai_content: AI's response content string + + Returns: + list: List of structured message dictionaries with role and content fields + """ messages = [ { "role": "user", diff --git a/api/app/core/memory/agent/langgraph_graph/write_graph.py b/api/app/core/memory/agent/langgraph_graph/write_graph.py index 1134acc7..15009955 100644 --- a/api/app/core/memory/agent/langgraph_graph/write_graph.py +++ b/api/app/core/memory/agent/langgraph_graph/write_graph.py @@ -43,6 +43,19 @@ async def make_write_graph(): yield graph async def long_term_storage(long_term_type:str="chunk",langchain_messages:list=[],memory_config:str='',end_user_id:str='',scope:int=6): + """ + Handle long-term memory storage with different strategies + + Supports multiple storage strategies including chunk-based, time-based, + and aggregate judgment approaches for long-term memory persistence. + + Args: + long_term_type: Storage strategy type ('chunk', 'time', 'aggregate') + langchain_messages: List of messages to store + memory_config: Memory configuration identifier + end_user_id: User group identifier + scope: Scope parameter for chunk-based storage (default: 6) + """ from app.core.memory.agent.langgraph_graph.routing.write_router import memory_long_term_storage, window_dialogue,aggregate_judgment from app.core.memory.agent.utils.redis_tool import write_store write_store.save_session_write(end_user_id, (langchain_messages)) @@ -53,26 +66,40 @@ async def long_term_storage(long_term_type:str="chunk",langchain_messages:list=[ config_id=memory_config, # 改为整数 service_name="MemoryAgentService" ) - if long_term_type=='chunk': - '''方案一:对话窗口6轮对话''' + if long_term_type==AgentMemory_Long_Term.STRATEGY_CHUNK: + '''Strategy 1: Dialogue window with 6 rounds of conversation''' await window_dialogue(end_user_id,langchain_messages,memory_config,scope) - if long_term_type=='time': - """时间""" - await memory_long_term_storage(end_user_id, memory_config,5) - if long_term_type=='aggregate': - """方案三:聚合判断""" + if long_term_type==AgentMemory_Long_Term.STRATEGY_TIME: + """Time-based strategy""" + await memory_long_term_storage(end_user_id, memory_config,AgentMemory_Long_Term.TIME_SCOPE) + if long_term_type==AgentMemory_Long_Term.STRATEGY_AGGREGATE: + """Strategy 3: Aggregate judgment""" await aggregate_judgment(end_user_id, langchain_messages, memory_config) async def write_long_term(storage_type,end_user_id,message_chat,aimessages,user_rag_memory_id,actual_config_id): + """ + Write long-term memory with different storage types + + Handles both RAG-based storage and traditional memory storage approaches. + For traditional storage, uses chunk-based strategy with paired user-AI messages. + + Args: + storage_type: Type of storage (RAG or traditional) + end_user_id: User group identifier + message_chat: User message content + aimessages: AI response messages + user_rag_memory_id: RAG memory identifier + actual_config_id: Actual configuration ID + """ from app.core.memory.agent.langgraph_graph.routing.write_router import write_rag_agent from app.core.memory.agent.langgraph_graph.routing.write_router import term_memory_save from app.core.memory.agent.langgraph_graph.tools.write_tool import agent_chat_messages if storage_type == AgentMemory_Long_Term.STORAGE_RAG: await write_rag_agent(end_user_id, message_chat, aimessages, user_rag_memory_id) else: - # AI 回复写入(用户消息和 AI 回复配对,一次性写入完整对话) + # AI reply writing (user messages and AI replies paired, written as complete dialogue at once) CHUNK = AgentMemory_Long_Term.STRATEGY_CHUNK SCOPE = AgentMemory_Long_Term.DEFAULT_SCOPE long_term_messages = await agent_chat_messages(message_chat, aimessages) diff --git a/api/app/core/memory/storage_services/reflection_engine/self_reflexion.py b/api/app/core/memory/storage_services/reflection_engine/self_reflexion.py index 09c7ef3d..b2a594c6 100644 --- a/api/app/core/memory/storage_services/reflection_engine/self_reflexion.py +++ b/api/app/core/memory/storage_services/reflection_engine/self_reflexion.py @@ -1,11 +1,11 @@ """ -自我反思引擎实现 +Self-Reflection Engine Implementation -该模块实现了记忆系统的自我反思功能,包括: -1. 基于时间的反思 - 根据时间周期触发反思 -2. 基于事实的反思 - 检测记忆冲突并解决 -3. 综合反思 - 整合多种反思策略 -4. 反思结果应用 - 更新记忆库 +This module implements the self-reflection functionality of the memory system, including: +1. Time-based reflection - Triggers reflection based on time cycles +2. Fact-based reflection - Detects and resolves memory conflicts +3. Comprehensive reflection - Integrates multiple reflection strategies +4. Reflection result application - Updates memory database """ import asyncio @@ -38,7 +38,7 @@ from app.schemas.memory_storage_schema import ( ) from pydantic import BaseModel -# 配置日志 +# Configure logging _root_logger = logging.getLogger() if not _root_logger.handlers: logging.basicConfig( @@ -49,35 +49,62 @@ else: _root_logger.setLevel(logging.INFO) class TranslationResponse(BaseModel): - """翻译响应模型""" + """Translation response model for language conversion""" data: str + class ReflectionRange(str, Enum): - """反思范围枚举""" - PARTIAL = "partial" # 从检索结果中反思 - ALL = "all" # 从整个数据库中反思 + """ + Reflection range enumeration + + Defines the scope of data to be included in reflection operations. + """ + PARTIAL = "partial" # Reflect from retrieval results + ALL = "all" # Reflect from entire database class ReflectionBaseline(str, Enum): - """反思基线枚举""" - TIME = "TIME" # 基于时间的反思 - FACT = "FACT" # 基于事实的反思 - HYBRID = "HYBRID" # 混合反思 + """ + Reflection baseline enumeration + + Defines the strategy or approach used for reflection operations. + """ + TIME = "TIME" # Time-based reflection + FACT = "FACT" # Fact-based reflection + HYBRID = "HYBRID" # Hybrid reflection combining multiple strategies class ReflectionConfig(BaseModel): - """反思引擎配置""" + """ + Reflection engine configuration + + Defines all configuration parameters for the reflection engine including + operation modes, model settings, and evaluation criteria. + + Attributes: + enabled: Whether reflection engine is enabled + iteration_period: Reflection cycle period (e.g., "3" hours) + reflexion_range: Scope of reflection (PARTIAL or ALL) + baseline: Reflection strategy (TIME, FACT, or HYBRID) + model_id: LLM model identifier for reflection operations + end_user_id: User identifier for scoped operations + output_example: Example output format for guidance + memory_verify: Enable memory verification checks + quality_assessment: Enable quality assessment evaluation + violation_handling_strategy: Strategy for handling violations + language_type: Language type for output ("zh" or "en") + """ enabled: bool = False - iteration_period: str = "3" # 反思周期 + iteration_period: str = "3" # Reflection cycle period reflexion_range: ReflectionRange = ReflectionRange.PARTIAL baseline: ReflectionBaseline = ReflectionBaseline.TIME - model_id: Optional[str] = None # 模型ID + model_id: Optional[str] = None # Model ID end_user_id: Optional[str] = None - output_example: Optional[str] = None # 输出示例 + output_example: Optional[str] = None # Output example - # 评估相关字段 - memory_verify: bool = True # 记忆验证 - quality_assessment: bool = True # 质量评估 - violation_handling_strategy: str = "warn" # 违规处理策略 + # Evaluation related fields + memory_verify: bool = True # Memory verification + quality_assessment: bool = True # Quality assessment + violation_handling_strategy: str = "warn" # Violation handling strategy language_type: str = "zh" class Config: @@ -85,7 +112,21 @@ class ReflectionConfig(BaseModel): class ReflectionResult(BaseModel): - """反思结果""" + """ + Reflection operation result + + Contains comprehensive information about the outcome of a reflection operation + including success status, metrics, and execution details. + + Attributes: + success: Whether the reflection operation succeeded + message: Descriptive message about the operation result + conflicts_found: Number of conflicts detected during reflection + conflicts_resolved: Number of conflicts successfully resolved + memories_updated: Number of memory entries updated in database + execution_time: Total time taken for the reflection operation + details: Additional details about the operation (optional) + """ success: bool message: str conflicts_found: int = 0 @@ -97,9 +138,22 @@ class ReflectionResult(BaseModel): class ReflectionEngine: """ - 自我反思引擎 - - 负责执行记忆系统的自我反思,包括冲突检测、冲突解决和记忆更新。 + Self-Reflection Engine + + Responsible for executing memory system self-reflection operations including + conflict detection, conflict resolution, and memory updates. Supports multiple + reflection strategies and provides comprehensive result tracking. + + The engine can operate in different modes: + - Time-based: Reflects on memories within specific time periods + - Fact-based: Detects and resolves factual conflicts in memories + - Hybrid: Combines multiple reflection strategies + + Attributes: + config: Reflection engine configuration + neo4j_connector: Neo4j database connector + llm_client: Language model client for analysis + Various function handlers for data processing and prompt rendering """ def __init__( @@ -115,18 +169,21 @@ class ReflectionEngine: update_query: Optional[str] = None ): """ - 初始化反思引擎 + Initialize reflection engine + + Sets up the reflection engine with configuration and optional dependencies. + Uses lazy initialization to avoid circular imports and optimize startup time. Args: - config: 反思引擎配置 - neo4j_connector: Neo4j 连接器(可选) - llm_client: LLM 客户端(可选) - get_data_func: 获取数据的函数(可选) - render_evaluate_prompt_func: 渲染评估提示词的函数(可选) - render_reflexion_prompt_func: 渲染反思提示词的函数(可选) - conflict_schema: 冲突结果 Schema(可选) - reflexion_schema: 反思结果 Schema(可选) - update_query: 更新查询语句(可选) + config: Reflection engine configuration object + neo4j_connector: Neo4j connector instance (optional, will be created if not provided) + llm_client: LLM client instance (optional, will be created if not provided) + get_data_func: Function for retrieving data (optional, uses default if not provided) + render_evaluate_prompt_func: Function for rendering evaluation prompts (optional) + render_reflexion_prompt_func: Function for rendering reflection prompts (optional) + conflict_schema: Schema for conflict result validation (optional) + reflexion_schema: Schema for reflection result validation (optional) + update_query: Query string for database updates (optional) """ self.config = config self.neo4j_connector = neo4j_connector @@ -137,14 +194,20 @@ class ReflectionEngine: self.conflict_schema = conflict_schema self.reflexion_schema = reflexion_schema self.update_query = update_query - self._semaphore = asyncio.Semaphore(5) # 默认并发数为5 + self._semaphore = asyncio.Semaphore(5) # Default concurrency limit of 5 - # 延迟导入以避免循环依赖 + # Lazy import to avoid circular dependencies self._lazy_init_done = False def _lazy_init(self): - """延迟初始化,避免循环导入""" + """ + Lazy initialization to avoid circular imports + + Initializes dependencies only when needed, preventing circular import issues + and optimizing startup performance. Sets up default implementations for + any components not provided during construction. + """ if self._lazy_init_done: return @@ -158,7 +221,7 @@ class ReflectionEngine: factory = MemoryClientFactory(db) self.llm_client = factory.get_llm_client(self.config.model_id) elif isinstance(self.llm_client, str): - # 如果 llm_client 是字符串(model_id),则用它初始化客户端 + # If llm_client is a string (model_id), use it to initialize the client from app.core.memory.utils.llm.llm_utils import MemoryClientFactory from app.db import get_db_context from app.services.memory_config_service import MemoryConfigService @@ -172,10 +235,10 @@ class ReflectionEngine: model_config = config_service.get_model_config(model_id) extra_params={ - "temperature": 0.2, # 降低温度提高响应速度和一致性 - "max_tokens": 600, # 限制最大token数 - "top_p": 0.8, # 优化采样参数 - "stream": False, # 确保非流式输出以获得最快响应 + "temperature": 0.2, # Lower temperature for faster response and consistency + "max_tokens": 600, # Limit maximum token count + "top_p": 0.8, # Optimize sampling parameters + "stream": False, # Ensure non-streaming output for fastest response } self.llm_client = OpenAIClient(RedBearModelConfig( @@ -191,7 +254,7 @@ class ReflectionEngine: if self.get_data_func is None: self.get_data_func = get_data - # 导入get_data_statement函数 + # Import get_data_statement function if not hasattr(self, 'get_data_statement'): self.get_data_statement = get_data_statement @@ -223,13 +286,20 @@ class ReflectionEngine: async def execute_reflection(self, host_id) -> ReflectionResult: """ - 执行完整的反思流程 + Execute complete reflection workflow + + Performs the full reflection process including data retrieval, conflict detection, + conflict resolution, and memory updates. This is the main entry point for + reflection operations. + Args: - host_id: 主机ID + host_id: Host identifier for scoping reflection operations + Returns: - ReflectionResult: 反思结果 + ReflectionResult: Comprehensive result of the reflection operation including + success status, conflict metrics, and execution time """ - # 延迟初始化 + # Lazy initialization self._lazy_init() if not self.config.enabled: @@ -243,7 +313,7 @@ class ReflectionEngine: print(self.config.baseline, self.config.memory_verify, self.config.quality_assessment) try: - # 1. 获取反思数据 + # 1. Get reflection data reflexion_data, statement_databasets = await self._get_reflexion_data(host_id) if not reflexion_data: return ReflectionResult( @@ -252,7 +322,7 @@ class ReflectionEngine: execution_time=asyncio.get_event_loop().time() - start_time ) - # 2. 检测冲突(基于事实的反思) + # 2. Detect conflicts (fact-based reflection) conflict_data = await self._detect_conflicts(reflexion_data, statement_databasets) conflict_list=[] for i in conflict_data: @@ -261,7 +331,7 @@ class ReflectionEngine: conflicts_found=0 - # 3. 解决冲突 + # 3. Resolve conflicts solved_data = await self._resolve_conflicts(conflict_list, statement_databasets) if not solved_data: @@ -276,7 +346,7 @@ class ReflectionEngine: logging.info(f"解决了 {conflicts_resolved} 个冲突") - # 4. 应用反思结果(更新记忆库) + # 4. Apply reflection results (update memory database) memories_updated=await self._apply_reflection_results(solved_data) execution_time = asyncio.get_event_loop().time() - start_time @@ -302,7 +372,19 @@ class ReflectionEngine: ) async def Translate(self, text): - # 翻译中文为英文 + """ + Translate Chinese text to English + + Uses the configured LLM to translate Chinese text to English with structured output. + Provides consistent translation format for reflection results. + + Args: + text: Chinese text to be translated + + Returns: + str: Translated English text + """ + # Translate Chinese to English translation_messages = [ { "role": "user", @@ -316,6 +398,19 @@ class ReflectionEngine: ) return response.data async def extract_translation(self,data): + """ + Extract and translate reflection data to English + + Processes reflection data structure and translates all Chinese content to English. + Handles nested data structures including memory verifications, quality assessments, + and reflection data while preserving the original structure. + + Args: + data: Dictionary containing reflection data with Chinese content + + Returns: + dict: Translated data structure with English content + """ end_datas={} end_datas['source_data']=await self.Translate(data['source_data']) quality_assessments = [] @@ -350,6 +445,18 @@ class ReflectionEngine: return end_datas async def reflection_run(self): + """ + Execute reflection workflow with comprehensive data processing + + Performs a complete reflection operation including conflict detection, resolution, + and result formatting. Supports both Chinese and English output based on + configuration settings. + + Returns: + dict: Comprehensive reflection results including source data, memory verifications, + quality assessments, and reflection data. Results are translated to English + if language_type is set to 'en'. + """ self._lazy_init() start_time = time.time() memory_verifies_flag = self.config.memory_verify @@ -367,7 +474,7 @@ class ReflectionEngine: result_data['source_data'] = "我是 2023 年春天去北京工作的,后来基本一直都在北京上班,也没怎么换过城市。不过后来公司调整,2024 年上半年我被调到上海待了差不多半年,那段时间每天都是在上海办公室打卡。当时入职资料用的还是我之前的身份信息,身份证号是 11010119950308123X,银行卡是 6222023847595898,这些一直没变。对了,其实我 从 2023 年开始就一直在北京生活,从来没有长期离开过北京,上海那段更多算是远程配合" # 2. 检测冲突(基于事实的反思) conflict_data = await self._detect_conflicts(databasets, source_data) - # 遍历数据提取字段 + # Traverse data to extract fields quality_assessments = [] memory_verifies = [] for item in conflict_data: @@ -375,9 +482,9 @@ class ReflectionEngine: memory_verifies.append(item['memory_verify']) result_data['memory_verifies'] = memory_verifies result_data['quality_assessments'] = quality_assessments - conflicts_found = 0 # 初始化为整数0而不是空字符串 + conflicts_found = 0 # Initialize as integer 0 instead of empty string REMOVE_KEYS = {"created_at", "expired_at","relationship","predicate","statement_id","id","statement_id","relationship_statement_id"} - # Clearn conflict_data,And memory_verify和quality_assessment + # Clean conflict_data, and memory_verify and quality_assessment cleaned_conflict_data = [] for item in conflict_data: cleaned_item = { @@ -389,7 +496,7 @@ class ReflectionEngine: for item in conflict_data: cleaned_data = [] for row in item.get("data", []): - # 删除 created_at / expired_at + # Remove created_at / expired_at cleaned_row = { k: v for k, v in row.items() @@ -402,7 +509,7 @@ class ReflectionEngine: } cleaned_conflict_data_.append(cleaned_item) print(cleaned_conflict_data_) - # 3. 解决冲突 + # 3. Resolve conflicts solved_data = await self._resolve_conflicts(cleaned_conflict_data_, source_data) if not solved_data: return ReflectionResult( @@ -413,7 +520,7 @@ class ReflectionEngine: ) reflexion_data = [] - # 遍历数据提取reflexion字段 + # Traverse data to extract reflexion fields for item in solved_data: if 'results' in item: for result in item['results']: @@ -431,15 +538,24 @@ class ReflectionEngine: async def extract_fields_from_json(self): - """从example.json中提取source_data和databasets字段""" + """ + Extract source_data and databasets fields from example.json + + Reads reflection example data from the example.json file and extracts + the source data and database statements for testing and demonstration purposes. + + Returns: + tuple: (source_data, databasets) extracted from the example file + Returns empty lists if file reading fails + """ prompt_dir = os.path.join(os.path.dirname(__file__), "example") try: - # 读取JSON文件 + # Read JSON file with open(prompt_dir + '/example.json', 'r', encoding='utf-8') as f: data = json.loads(f.read()) - # 提取memory_verify下的字段 + # Extract fields under memory_verify memory_verify = data.get("memory_verify", {}) source_data = memory_verify.get("source_data", []) databasets = memory_verify.get("databasets", []) @@ -451,15 +567,17 @@ class ReflectionEngine: async def _get_reflexion_data(self, host_id: uuid.UUID) -> List[Any]: """ - 获取反思数据 - - 根据配置的反思范围获取需要反思的记忆数据。 + Get reflection data from database + + Retrieves memory data for reflection based on the configured reflection range. + Supports both partial (from retrieval results) and full (entire database) modes. Args: - host_id: 主机ID + host_id: Host UUID identifier for scoping data retrieval Returns: - List[Any]: 反思数据列表 + tuple: (reflexion_data, statement_data) containing memory data for reflection + Returns empty lists if query fails """ print("=== 获取反思数据 ===") @@ -484,26 +602,29 @@ class ReflectionEngine: async def _detect_conflicts(self, data: List[Any], statement_databasets: List[Any]) -> List[Any]: """ - 检测冲突(基于事实的反思) - - 使用 LLM 分析记忆数据,检测其中的冲突。 + Detect conflicts (fact-based reflection) + + Uses LLM to analyze memory data and detect conflicts within the memories. + Performs comprehensive conflict detection including memory verification and + quality assessment based on configuration settings. Args: - data: 待检测的记忆数据 + data: Memory data to be analyzed for conflicts + statement_databasets: Statement database records for context Returns: - List[Any]: 冲突记忆列表 + List[Any]: List of detected conflicts with detailed analysis """ if not data: return [] - # 数据预处理:如果数据量太少,直接返回无冲突 + # Data preprocessing: if data is too small, return no conflicts directly if len(data) < 2: logging.info("数据量不足,无需检测冲突") return [] - # 使用转换后的数据 - # print("转换后的数据:", data[:2] if len(data) > 2 else data) # 只打印前2条避免日志过长 + # Use converted data + # print("Converted data:", data[:2] if len(data) > 2 else data) # Only print first 2 to avoid long logs memory_verify = self.config.memory_verify logging.info("====== 冲突检测开始 ======") @@ -512,7 +633,7 @@ class ReflectionEngine: language_type=self.config.language_type try: - # 渲染冲突检测提示词 + # Render conflict detection prompt rendered_prompt = await self.render_evaluate_prompt_func( data, self.conflict_schema, @@ -526,7 +647,7 @@ class ReflectionEngine: messages = [{"role": "user", "content": rendered_prompt}] logging.info(f"提示词长度: {len(rendered_prompt)}") - # 调用 LLM 进行冲突检测 + # Call LLM for conflict detection response = await self.llm_client.response_structured( messages, self.conflict_schema @@ -539,7 +660,7 @@ class ReflectionEngine: logging.error("LLM 冲突检测输出解析失败") return [] - # 标准化返回格式 + # Standardize return format if isinstance(response, BaseModel): return [response.model_dump()] elif hasattr(response, 'dict'): @@ -553,15 +674,17 @@ class ReflectionEngine: async def _resolve_conflicts(self, conflicts: List[Any], statement_databasets: List[Any]) -> List[Any]: """ - 解决冲突 - - 使用 LLM 对检测到的冲突进行反思和解决。 + Resolve detected conflicts + + Uses LLM to perform reflection and resolution on detected conflicts. + Processes conflicts in parallel for efficiency while respecting concurrency limits. Args: - conflicts: 冲突列表 + conflicts: List of conflicts to be resolved + statement_databasets: Statement database records for context Returns: - List[Any]: 解决方案列表 + List[Any]: List of resolution solutions with reflection results """ if not conflicts: return [] @@ -570,12 +693,12 @@ class ReflectionEngine: baseline = self.config.baseline memory_verify = self.config.memory_verify - # 并行处理每个冲突 + # Process each conflict in parallel async def _resolve_one(conflict: Any) -> Optional[Dict[str, Any]]: - """解决单个冲突""" + """Resolve a single conflict""" async with self._semaphore: try: - # 渲染反思提示词 + # Render reflection prompt rendered_prompt = await self.render_reflexion_prompt_func( [conflict], self.reflexion_schema, @@ -587,7 +710,7 @@ class ReflectionEngine: messages = [{"role": "user", "content": rendered_prompt}] - # 调用 LLM 进行反思 + # Call LLM for reflection response = await self.llm_client.response_structured( messages, self.reflexion_schema @@ -596,7 +719,7 @@ class ReflectionEngine: if not response: return None - # 标准化返回格式 + # Standardize return format if isinstance(response, BaseModel): return response.model_dump() elif hasattr(response, 'dict'): @@ -610,11 +733,11 @@ class ReflectionEngine: logging.warning(f"解决单个冲突失败: {e}") return None - # 并发执行所有冲突解决任务 + # Execute all conflict resolution tasks concurrently tasks = [_resolve_one(conflict) for conflict in conflicts] results = await asyncio.gather(*tasks, return_exceptions=False) - # 过滤掉失败的结果 + # Filter out failed results solved = [r for r in results if r is not None] logging.info(f"成功解决 {len(solved)}/{len(conflicts)} 个冲突") @@ -626,15 +749,16 @@ class ReflectionEngine: solved_data: List[Dict[str, Any]] ) -> int: """ - 应用反思结果(更新记忆库) - - 将解决冲突后的记忆更新到 Neo4j 数据库中。 + Apply reflection results (update memory database) + + Updates the Neo4j database with resolved conflicts and reflection results. + Processes the solved data and applies changes to the memory storage system. Args: - solved_data: 解决方案列表 + solved_data: List of resolved conflict solutions with reflection data Returns: - int: 成功更新的记忆数量 + int: Number of successfully updated memory entries """ changes = extract_and_process_changes(solved_data) success_count = await neo4j_data(changes) @@ -642,80 +766,86 @@ class ReflectionEngine: - # 基于时间的反思方法 + # Time-based reflection methods async def time_based_reflection( self, host_id: uuid.UUID, time_period: Optional[str] = None ) -> ReflectionResult: """ - 基于时间的反思 - - 根据时间周期触发反思,检查在指定时间段内的记忆。 + Time-based reflection + + Triggers reflection based on time cycles, checking memories within + specified time periods. Uses the configured iteration period if + no specific time period is provided. Args: - host_id: 主机ID - time_period: 时间周期(如"三小时"),如果不提供则使用配置中的值 + host_id: Host UUID identifier for scoping reflection + time_period: Time period (e.g., "three hours"), uses config value if not provided Returns: - ReflectionResult: 反思结果 + ReflectionResult: Comprehensive reflection operation result """ period = time_period or self.config.iteration_period logging.info(f"执行基于时间的反思,周期: {period}") - # 使用标准反思流程 + # Use standard reflection workflow return await self.execute_reflection(host_id) - # 基于事实的反思方法 + # Fact-based reflection methods async def fact_based_reflection( self, host_id: uuid.UUID ) -> ReflectionResult: """ - 基于事实的反思 - - 检测记忆中的事实冲突并解决。 + Fact-based reflection + + Detects and resolves factual conflicts within memories. Analyzes + memory data for inconsistencies and contradictions that need resolution. Args: - host_id: 主机ID + host_id: Host UUID identifier for scoping reflection Returns: - ReflectionResult: 反思结果 + ReflectionResult: Comprehensive reflection operation result """ logging.info("执行基于事实的反思") - # 使用标准反思流程 + # Use standard reflection workflow return await self.execute_reflection(host_id) - # 综合反思方法 + # Comprehensive reflection methods async def comprehensive_reflection( self, host_id: uuid.UUID ) -> ReflectionResult: """ - 综合反思 - - 整合基于时间和基于事实的反思策略。 + Comprehensive reflection + + Integrates time-based and fact-based reflection strategies based on + the configured baseline. Supports hybrid approaches that combine + multiple reflection methodologies. Args: - host_id: 主机ID + host_id: Host UUID identifier for scoping reflection Returns: - ReflectionResult: 反思结果 + ReflectionResult: Comprehensive reflection operation result combining + multiple strategies if using hybrid baseline """ logging.info("执行综合反思") - # 根据配置的基线选择反思策略 + # Choose reflection strategy based on configured baseline if self.config.baseline == ReflectionBaseline.TIME: return await self.time_based_reflection(host_id) elif self.config.baseline == ReflectionBaseline.FACT: return await self.fact_based_reflection(host_id) elif self.config.baseline == ReflectionBaseline.HYBRID: - # 混合策略:先执行基于时间的反思,再执行基于事实的反思 + # Hybrid strategy: execute time-based reflection first, then fact-based reflection time_result = await self.time_based_reflection(host_id) fact_result = await self.fact_based_reflection(host_id) - # 合并结果 + # Merge results return ReflectionResult( success=time_result.success and fact_result.success, message=f"时间反思: {time_result.message}; 事实反思: {fact_result.message}", diff --git a/api/app/core/memory/utils/data/text_utils.py b/api/app/core/memory/utils/data/text_utils.py index 133990f7..d0b10f97 100644 --- a/api/app/core/memory/utils/data/text_utils.py +++ b/api/app/core/memory/utils/data/text_utils.py @@ -2,9 +2,17 @@ import json def escape_lucene_query(query: str) -> str: - """Escape Lucene special characters in a free-text query. - - This prevents ParseException when using Neo4j full-text procedures. + """ + Escape special characters in Lucene queries + + Prevents ParseException when using Neo4j full-text search procedures. + Escapes all Lucene reserved special characters and operators. + + Args: + query: Original query string + + Returns: + str: Escaped query string safe for Lucene search """ if query is None: return "" @@ -22,11 +30,21 @@ def escape_lucene_query(query: str) -> str: return s def extract_plain_query(query_input: str) -> str: - """Extract clean, plain-text query from various input forms. - + """ + Extract clean plain-text query from various input forms + + Handles the following cases: - Strips surrounding quotes and whitespace - If input looks like JSON, prefers the 'original' field - - Fallbacks to the raw string when parsing fails + - Falls back to raw string when parsing fails + - Handles dictionary-type input + - Best-effort unescape common escape characters + + Args: + query_input: Query input in various forms (string, dict, etc.) + + Returns: + str: Extracted plain-text query string """ if query_input is None: return "" diff --git a/api/app/core/memory/utils/data/time_utils.py b/api/app/core/memory/utils/data/time_utils.py index c6791dfc..763c642c 100644 --- a/api/app/core/memory/utils/data/time_utils.py +++ b/api/app/core/memory/utils/data/time_utils.py @@ -4,7 +4,13 @@ from datetime import datetime def validate_date_format(date_str: str) -> bool: """ - Validate if the date string is in the format YYYY-MM-DD. + Validate if date string conforms to YYYY-MM-DD format + + Args: + date_str: Date string to validate + + Returns: + bool: True if format is correct, False otherwise """ pattern = r"^\d{4}-\d{1,2}-\d{1,2}$" return bool(re.match(pattern, date_str)) @@ -41,7 +47,20 @@ def normalize_date(date_str: str) -> str: def preprocess_date_string(date_str: str) -> str: - """预处理日期字符串,处理特殊格式""" + """ + 预处理日期字符串,处理特殊格式 + + 处理以下特殊格式: + - 年份后直接跟月份没有分隔符的格式(如 "20259/28") + - 无分隔符的纯数字格式(如 "20251028", "251028") + - 混合分隔符,统一为 "-" + + Args: + date_str: 原始日期字符串 + + Returns: + str: 预处理后的日期字符串,格式为 "YYYY-MM-DD" 或 "YYYY-MM" + """ # 处理类似 "20259/28" 的格式(年份后直接跟月份没有分隔) match = re.match(r'^(\d{4,5})[/\.\-_]?(\d{1,2})[/\.\-_]?(\d{1,2})$', date_str) @@ -78,7 +97,23 @@ def preprocess_date_string(date_str: str) -> str: def fallback_parse(date_str: str) -> str: - """备选解析方案""" + """ + 备选日期解析方案 + + 当智能解析失败时,尝试使用预定义的日期格式进行解析。 + 支持多种常见的日期格式,包括: + - YYYY-MM-DD, YYYY/MM/DD, YYYY.MM.DD + - YYYYMMDD, YYMMDD + - MM-DD-YYYY, MM/DD/YYYY, MM.DD.YYYY + - DD-MM-YYYY, DD/MM/YYYY, DD.MM.YYYY + - YYYY-MM, YYYY/MM, YYYY.MM + + Args: + date_str: 待解析的日期字符串 + + Returns: + str: 标准化后的日期字符串(YYYY-MM-DD格式),解析失败时返回原字符串 + """ # 尝试常见的日期格式[citation:4][citation:5] formats_to_try = [ diff --git a/api/app/schemas/memory_agent_schema.py b/api/app/schemas/memory_agent_schema.py index 1a5017eb..26a7390b 100644 --- a/api/app/schemas/memory_agent_schema.py +++ b/api/app/schemas/memory_agent_schema.py @@ -25,5 +25,6 @@ class AgentMemory_Long_Term(ABC): STRATEGY_CHUNK = "chunk" STRATEGY_TIME = "time" DEFAULT_SCOPE = 6 + TIME_SCOPE=5