diff --git a/README.md b/README.md index 7d26f7f7..32a779d2 100644 --- a/README.md +++ b/README.md @@ -334,7 +334,12 @@ step6: Log In to the Frontend Interface. ## License This project is licensed under the Apache License 2.0. For details, see the LICENSE file. -## Acknowledgements & Community -- Feedback & Issues: Please submit an Issue in the repository for bug reports or discussions. -- Contributions Welcome: When submitting a Pull Request, please create a feature branch and follow conventional commit message guidelines. -- Contact: If you are interested in contributing or collaborating, feel free to reach out at tianyou_hubm@redbearai.com \ No newline at end of file +## Community & Support + +Join our community to ask questions, share your work, and connect with fellow developers. + +- **GitHub Issues**: Report bugs, request features, or track known issues via [GitHub Issues](https://github.com/redbear-ai/memorybear/issues). +- **GitHub Pull Requests**: Contribute code improvements or fixes through [Pull Requests](https://github.com/redbear-ai/memorybear/pulls). +- **GitHub Discussions**: Ask questions, share ideas, and engage with the community in [GitHub Discussions](https://github.com/redbear-ai/memorybear/discussions). +- **WeChat**: Scan the QR code below to join our WeChat community group. +- **Contact**: If you are interested in contributing or collaborating, feel free to reach out at tianyou_hubm@redbearai.com \ No newline at end of file diff --git a/api/app/celery_app.py b/api/app/celery_app.py index 85ad0643..185d746c 100644 --- a/api/app/celery_app.py +++ b/api/app/celery_app.py @@ -1,4 +1,5 @@ import os +import platform from datetime import timedelta from urllib.parse import quote @@ -14,27 +15,12 @@ celery_app = Celery( backend=f"redis://:{quote(settings.REDIS_PASSWORD)}@{settings.REDIS_HOST}:{settings.REDIS_PORT}/{settings.CELERY_BACKEND}", ) -# 配置使用本地队列,避免与远程 worker 冲突 -celery_app.conf.task_default_queue = 'localhost_test_wyl' -celery_app.conf.task_default_exchange = 'localhost_test_wyl' -celery_app.conf.task_default_routing_key = 'localhost_test_wyl' +# Default queue for unrouted tasks +celery_app.conf.task_default_queue = 'memory_tasks' # macOS 兼容性配置 -import platform - -if platform.system() == 'Darwin': # macOS - # 设置环境变量解决 fork 问题 +if platform.system() == 'Darwin': os.environ.setdefault('OBJC_DISABLE_INITIALIZE_FORK_SAFETY', 'YES') - - # 使用 solo 池避免多进程问题 - celery_app.conf.worker_pool = 'solo' - - # 设置唯一的节点名称 - import socket - import time - hostname = socket.gethostname() - timestamp = int(time.time()) - celery_app.conf.worker_name = f"celery@{hostname}-{timestamp}" # Celery 配置 celery_app.conf.update( @@ -52,36 +38,47 @@ celery_app.conf.update( task_ignore_result=False, # 超时设置 - task_time_limit=30 * 60, # 30 分钟硬超时 - task_soft_time_limit=25 * 60, # 25 分钟软超时 + task_time_limit=1800, # 30分钟硬超时 + task_soft_time_limit=1500, # 25分钟软超时 - # Worker 设置 - 针对 macOS 优化 - worker_prefetch_multiplier=1, # 减少预取任务数,避免内存堆积 - worker_max_tasks_per_child=10, # 大幅减少每个 worker 执行的任务数,频繁重启防止内存泄漏 - worker_max_memory_per_child=200000, # 200MB 内存限制,超过后重启 worker + # Worker 设置 (per-worker settings are in docker-compose command line) + worker_prefetch_multiplier=1, # Don't hoard tasks, fairer distribution # 结果过期时间 - result_expires=3600, # 结果保存 1 小时 + result_expires=3600, # 结果保存1小时 # 任务确认设置 - task_acks_late=True, # 任务完成后才确认,避免任务丢失 - worker_disable_rate_limits=True, # 禁用速率限制 + task_acks_late=True, + task_reject_on_worker_lost=True, + worker_disable_rate_limits=True, - # 任务路由(可选,用于不同队列) - # task_routes={ - # 'app.core.rag.tasks.parse_document': {'queue': 'document_processing'}, - # 'app.core.memory.agent.read_message': {'queue': 'memory_processing'}, - # 'app.core.memory.agent.write_message': {'queue': 'memory_processing'}, - # 'tasks.process_item': {'queue': 'default'}, - # }, + # FLower setting + worker_send_task_events=True, + task_send_sent_event=True, + + # task routing + task_routes={ + # Memory tasks → memory_tasks queue (threads worker) + 'app.core.memory.agent.read_message_priority': {'queue': 'memory_tasks'}, + 'app.core.memory.agent.read_message': {'queue': 'memory_tasks'}, + 'app.core.memory.agent.write_message': {'queue': 'memory_tasks'}, + + # Document tasks → document_tasks queue (prefork worker) + 'app.core.rag.tasks.parse_document': {'queue': 'document_tasks'}, + 'app.core.rag.tasks.build_graphrag_for_kb': {'queue': 'document_tasks'}, + + # Beat/periodic tasks → document_tasks queue (prefork worker) + 'app.tasks.workspace_reflection_task': {'queue': 'document_tasks'}, + 'app.tasks.regenerate_memory_cache': {'queue': 'document_tasks'}, + 'app.tasks.run_forgetting_cycle_task': {'queue': 'document_tasks'}, + 'app.controllers.memory_storage_controller.search_all': {'queue': 'document_tasks'}, + }, ) # 自动发现任务模块 celery_app.autodiscover_tasks(['app']) # Celery Beat schedule for periodic tasks -reflection_schedule = timedelta(seconds=settings.REFLECTION_INTERVAL_SECONDS) -health_schedule = timedelta(seconds=settings.HEALTH_CHECK_SECONDS) memory_increment_schedule = timedelta(hours=settings.MEMORY_INCREMENT_INTERVAL_HOURS) memory_cache_regeneration_schedule = timedelta(hours=settings.MEMORY_CACHE_REGENERATION_HOURS) workspace_reflection_schedule = timedelta(seconds=30) # 每30秒运行一次settings.REFLECTION_INTERVAL_TIME @@ -89,12 +86,6 @@ forgetting_cycle_schedule = timedelta(hours=24) # 每24小时运行一次遗忘 # 构建定时任务配置 beat_schedule_config = { - - # "check-read-service": { - # "task": "app.core.memory.agent.health.check_read_service", - # "schedule": health_schedule, - # "args": (), - # }, "run-workspace-reflection": { "task": "app.tasks.workspace_reflection_task", "schedule": workspace_reflection_schedule, diff --git a/api/app/core/memory/agent/langgraph_graph/nodes/summary_nodes.py b/api/app/core/memory/agent/langgraph_graph/nodes/summary_nodes.py index 0d0b57b0..44f89c6a 100644 --- a/api/app/core/memory/agent/langgraph_graph/nodes/summary_nodes.py +++ b/api/app/core/memory/agent/langgraph_graph/nodes/summary_nodes.py @@ -4,12 +4,11 @@ import os import time from app.core.logging_config import get_agent_logger, log_time -from app.db import get_db - from app.core.memory.agent.models.summary_models import ( RetrieveSummaryResponse, SummaryResponse, ) +from app.core.memory.agent.services.optimized_llm_service import LLMServiceMixin from app.core.memory.agent.services.search_service import SearchService from app.core.memory.agent.utils.llm_tools import ( PROJECT_ROOT_, @@ -18,7 +17,7 @@ from app.core.memory.agent.utils.llm_tools import ( from app.core.memory.agent.utils.redis_tool import store from app.core.memory.agent.utils.session_tools import SessionService from app.core.memory.agent.utils.template_tools import TemplateService -from app.core.memory.agent.services.optimized_llm_service import LLMServiceMixin +from app.db import get_db template_root = os.path.join(PROJECT_ROOT_, 'agent', 'utils', 'prompt') logger = get_agent_logger(__name__) @@ -182,7 +181,8 @@ async def Input_Summary(state: ReadState) -> ReadState: search_params = { "group_id": group_id, "question": data, - "return_raw_results": True + "return_raw_results": True, + "include": ["summaries"] # Only search summary nodes for faster performance } try: diff --git a/api/app/core/validators/memory_config_validators.py b/api/app/core/validators/memory_config_validators.py index 6ccf3ddb..333572e6 100644 --- a/api/app/core/validators/memory_config_validators.py +++ b/api/app/core/validators/memory_config_validators.py @@ -89,14 +89,15 @@ def validate_model_exists_and_active( start_time = time.time() try: - # First check if model exists at all (without tenant filtering) - model_without_tenant = ModelConfigRepository.get_by_id(db, model_id, tenant_id=None) - - # Then check with tenant filtering + # OPTIMIZED: Single query with tenant filter + # We'll check tenant mismatch in the error handling model = ModelConfigRepository.get_by_id(db, model_id, tenant_id) elapsed_ms = (time.time() - start_time) * 1000 if not model: + # Model not found with tenant filter - check if it exists without filter + model_without_tenant = ModelConfigRepository.get_by_id(db, model_id, tenant_id=None) + if model_without_tenant: # Model exists but belongs to different tenant logger.warning( @@ -208,8 +209,11 @@ def validate_embedding_model( db: Session, tenant_id: Optional[UUID] = None, workspace_id: Optional[UUID] = None -) -> UUID: - """Validate that embedding model is available and return its UUID. +) -> tuple[UUID, str]: + """Validate that embedding model is available and return its UUID and name. + + Returns: + Tuple of (embedding_uuid, embedding_name) Raises: InvalidConfigError: If embedding_id is not provided or invalid @@ -225,14 +229,19 @@ def validate_embedding_model( workspace_id=workspace_id ) - embedding_uuid, _ = validate_and_resolve_model_id( + embedding_uuid, embedding_name = validate_and_resolve_model_id( embedding_id, "embedding", db, tenant_id, required=True, config_id=config_id, workspace_id=workspace_id ) - print(100*'-') - print(embedding_uuid) - print(_) - print(100*'-') + + logger.debug( + "Embedding model validated", + extra={ + "embedding_uuid": str(embedding_uuid), + "embedding_name": embedding_name, + "config_id": config_id + } + ) if embedding_uuid is None: raise InvalidConfigError( @@ -243,7 +252,7 @@ def validate_embedding_model( workspace_id=workspace_id ) - return embedding_uuid + return embedding_uuid, embedding_name def validate_llm_model( diff --git a/api/app/repositories/neo4j/graph_search.py b/api/app/repositories/neo4j/graph_search.py index 0b6a27c6..6f5764b4 100644 --- a/api/app/repositories/neo4j/graph_search.py +++ b/api/app/repositories/neo4j/graph_search.py @@ -305,12 +305,19 @@ async def search_graph( results[key] = _deduplicate_results(results[key]) # 更新知识节点的激活值(Statement, ExtractedEntity, MemorySummary) - results = await _update_search_results_activation( - connector=connector, - results=results, - group_id=group_id + # Skip activation updates if only searching summaries (optimization) + needs_activation_update = any( + key in include and key in results and results[key] + for key in ['statements', 'entities', 'chunks'] ) + if needs_activation_update: + results = await _update_search_results_activation( + connector=connector, + results=results, + group_id=group_id + ) + return results @@ -339,7 +346,7 @@ async def search_graph_by_embedding( embed_start = time.time() embeddings = await embedder_client.response([query_text]) embed_time = time.time() - embed_start - print(f"[PERF] Embedding generation took: {embed_time:.4f}s") + logger.info(f"[PERF] Embedding generation took: {embed_time:.4f}s") if not embeddings or not embeddings[0]: return {"statements": [], "chunks": [], "entities": [], "summaries": []} @@ -393,7 +400,7 @@ async def search_graph_by_embedding( query_start = time.time() task_results = await asyncio.gather(*tasks, return_exceptions=True) query_time = time.time() - query_start - print(f"[PERF] Neo4j queries (parallel) took: {query_time:.4f}s") + logger.info(f"[PERF] Neo4j queries (parallel) took: {query_time:.4f}s") # Build results dictionary results: Dict[str, List[Dict[str, Any]]] = { @@ -417,14 +424,23 @@ async def search_graph_by_embedding( results[key] = _deduplicate_results(results[key]) # 更新知识节点的激活值(Statement, ExtractedEntity, MemorySummary) - update_start = time.time() - results = await _update_search_results_activation( - connector=connector, - results=results, - group_id=group_id + # Skip activation updates if only searching summaries (optimization) + needs_activation_update = any( + key in include and key in results and results[key] + for key in ['statements', 'entities', 'chunks'] ) - update_time = time.time() - update_start - print(f"[PERF] Activation value updates took: {update_time:.4f}s") + + if needs_activation_update: + update_start = time.time() + results = await _update_search_results_activation( + connector=connector, + results=results, + group_id=group_id + ) + update_time = time.time() - update_start + logger.info(f"[PERF] Activation value updates took: {update_time:.4f}s") + else: + logger.info(f"[PERF] Skipping activation updates (only summaries)") return results async def get_dedup_candidates_for_entities( # 适配新版查询:使用全文索引按名称检索候选实体 @@ -535,7 +551,7 @@ async def search_graph_by_keyword_temporal( - Returns up to 'limit' statements """ if not query_text: - print(f"query_text不能为空") + logger.warning(f"query_text cannot be empty") return {"statements": []} statements = await connector.execute_query( SEARCH_STATEMENTS_BY_KEYWORD_TEMPORAL, @@ -549,7 +565,7 @@ async def search_graph_by_keyword_temporal( invalid_date=invalid_date, limit=limit, ) - print(f"查询结果为:\n{statements}") + logger.debug(f"Temporal keyword search results: {len(statements)} statements found") # 更新 Statement 节点的激活值 results = {"statements": statements} @@ -594,9 +610,9 @@ async def search_graph_by_temporal( limit=limit, ) - print(f"查询语句为:\n{SEARCH_STATEMENTS_BY_TEMPORAL}") - print(f"查询参数为:\n{{group_id: {group_id}, apply_id: {apply_id}, user_id: {user_id}, start_date: {start_date}, end_date: {end_date}, valid_date: {valid_date}, invalid_date: {invalid_date}, limit: {limit}}}") - print(f"查询结果为:\n{statements}") + logger.debug(f"Temporal search query: {SEARCH_STATEMENTS_BY_TEMPORAL}") + logger.debug(f"Query params: group_id={group_id}, apply_id={apply_id}, user_id={user_id}, start_date={start_date}, end_date={end_date}, valid_date={valid_date}, invalid_date={invalid_date}, limit={limit}") + logger.debug(f"Temporal search results: {len(statements)} statements found") # 更新 Statement 节点的激活值 results = {"statements": statements} @@ -623,7 +639,7 @@ async def search_graph_by_dialog_id( - Returns up to 'limit' dialogues """ if not dialog_id: - print(f"dialog_id不能为空") + logger.warning(f"dialog_id cannot be empty") return {"dialogues": []} dialogues = await connector.execute_query( @@ -642,7 +658,7 @@ async def search_graph_by_chunk_id( limit: int = 1, ) -> Dict[str, List[Dict[str, Any]]]: if not chunk_id: - print(f"chunk_id不能为空") + logger.warning(f"chunk_id cannot be empty") return {"chunks": []} chunks = await connector.execute_query( SEARCH_CHUNK_BY_CHUNK_ID, @@ -679,9 +695,9 @@ async def search_graph_by_created_at( limit=limit, ) - print(f"查询语句为:\n{SEARCH_STATEMENTS_BY_CREATED_AT}") - print(f"查询参数为:\n{{group_id: {group_id}, apply_id: {apply_id}, user_id: {user_id}, created_at: {created_at}, limit: {limit}}}") - print(f"查询结果为:\n{statements}") + logger.debug(f"Search by created_at query: {SEARCH_STATEMENTS_BY_CREATED_AT}") + logger.debug(f"Query params: group_id={group_id}, apply_id={apply_id}, user_id={user_id}, created_at={created_at}, limit={limit}") + logger.debug(f"Search results: {len(statements)} statements found") # 更新 Statement 节点的激活值 results = {"statements": statements} @@ -719,9 +735,9 @@ async def search_graph_by_valid_at( limit=limit, ) - print(f"查询语句为:\n{SEARCH_STATEMENTS_BY_VALID_AT}") - print(f"查询参数为:\n{{group_id: {group_id}, apply_id: {apply_id}, user_id: {user_id}, valid_at: {valid_at}, limit: {limit}}}") - print(f"查询结果为:\n{statements}") + logger.debug(f"Search by valid_at query: {SEARCH_STATEMENTS_BY_VALID_AT}") + logger.debug(f"Query params: group_id={group_id}, apply_id={apply_id}, user_id={user_id}, valid_at={valid_at}, limit={limit}") + logger.debug(f"Search results: {len(statements)} statements found") # 更新 Statement 节点的激活值 results = {"statements": statements} @@ -759,9 +775,9 @@ async def search_graph_g_created_at( limit=limit, ) - print(f"查询语句为:\n{SEARCH_STATEMENTS_G_CREATED_AT}") - print(f"查询参数为:\n{{group_id: {group_id}, apply_id: {apply_id}, user_id: {user_id}, created_at: {created_at}, limit: {limit}}}") - print(f"查询结果为:\n{statements}") + logger.debug(f"Search greater than created_at query: {SEARCH_STATEMENTS_G_CREATED_AT}") + logger.debug(f"Query params: group_id={group_id}, apply_id={apply_id}, user_id={user_id}, created_at={created_at}, limit={limit}") + logger.debug(f"Search results: {len(statements)} statements found") # 更新 Statement 节点的激活值 results = {"statements": statements} @@ -799,9 +815,9 @@ async def search_graph_g_valid_at( limit=limit, ) - print(f"查询语句为:\n{SEARCH_STATEMENTS_G_VALID_AT}") - print(f"查询参数为:\n{{group_id: {group_id}, apply_id: {apply_id}, user_id: {user_id}, valid_at: {valid_at}, limit: {limit}}}") - print(f"查询结果为:\n{statements}") + logger.debug(f"Search greater than valid_at query: {SEARCH_STATEMENTS_G_VALID_AT}") + logger.debug(f"Query params: group_id={group_id}, apply_id={apply_id}, user_id={user_id}, valid_at={valid_at}, limit={limit}") + logger.debug(f"Search results: {len(statements)} statements found") # 更新 Statement 节点的激活值 results = {"statements": statements} @@ -839,9 +855,9 @@ async def search_graph_l_created_at( limit=limit, ) - print(f"查询语句为:\n{SEARCH_STATEMENTS_L_CREATED_AT}") - print(f"查询参数为:\n{{group_id: {group_id}, apply_id: {apply_id}, user_id: {user_id}, created_at: {created_at}, limit: {limit}}}") - print(f"查询结果为:\n{statements}") + logger.debug(f"Search less than created_at query: {SEARCH_STATEMENTS_L_CREATED_AT}") + logger.debug(f"Query params: group_id={group_id}, apply_id={apply_id}, user_id={user_id}, created_at={created_at}, limit={limit}") + logger.debug(f"Search results: {len(statements)} statements found") # 更新 Statement 节点的激活值 results = {"statements": statements} @@ -879,9 +895,9 @@ async def search_graph_l_valid_at( limit=limit, ) - print(f"查询语句为:\n{SEARCH_STATEMENTS_L_VALID_AT}") - print(f"查询参数为:\n{{group_id: {group_id}, apply_id: {apply_id}, user_id: {user_id}, valid_at: {valid_at}, limit: {limit}}}") - print(f"查询结果为:\n{statements}") + logger.debug(f"Search less than valid_at query: {SEARCH_STATEMENTS_L_VALID_AT}") + logger.debug(f"Query params: group_id={group_id}, apply_id={apply_id}, user_id={user_id}, valid_at={valid_at}, limit={limit}") + logger.debug(f"Search results: {len(statements)} statements found") # 更新 Statement 节点的激活值 results = {"statements": statements} diff --git a/api/app/services/draft_run_service.py b/api/app/services/draft_run_service.py index 50934226..46bda5f6 100644 --- a/api/app/services/draft_run_service.py +++ b/api/app/services/draft_run_service.py @@ -10,11 +10,6 @@ import time import uuid from typing import Any, AsyncGenerator, Dict, List, Optional -from langchain.tools import tool -from pydantic import BaseModel, Field -from sqlalchemy import select -from sqlalchemy.orm import Session - from app.celery_app import celery_app from app.core.error_codes import BizCode from app.core.exceptions import BusinessException @@ -28,6 +23,10 @@ from app.services.langchain_tool_server import Search from app.services.memory_agent_service import MemoryAgentService from app.services.model_parameter_merger import ModelParameterMerger from app.services.tool_service import ToolService +from langchain.tools import tool +from pydantic import BaseModel, Field +from sqlalchemy import select +from sqlalchemy.orm import Session logger = get_business_logger() class KnowledgeRetrievalInput(BaseModel): @@ -107,9 +106,9 @@ def create_long_term_memory_tool(memory_config: Dict[str, Any], end_user_id: str "app.core.memory.agent.read_message", args=[end_user_id, question, [], "1", config_id, storage_type, user_rag_memory_id] ) - result = task_service.get_task_memory_read_result(task.id) - status = result.get("status") - logger.info(f"读取任务状态:{status}") + # result = task_service.get_task_memory_read_result(task.id) + # status = result.get("status") + # logger.info(f"读取任务状态:{status}") finally: db.close() diff --git a/api/app/services/memory_agent_service.py b/api/app/services/memory_agent_service.py index 692e9a9a..6748d6c7 100644 --- a/api/app/services/memory_agent_service.py +++ b/api/app/services/memory_agent_service.py @@ -10,15 +10,17 @@ import re import time import uuid from typing import Any, AsyncGenerator, Dict, List, Optional -import redis -from langchain_core.messages import HumanMessage +import redis from app.core.config import settings from app.core.logging_config import get_config_logger, get_logger from app.core.memory.agent.langgraph_graph.read_graph import make_read_graph from app.core.memory.agent.langgraph_graph.write_graph import make_write_graph from app.core.memory.agent.logger_file.log_streamer import LogStreamer -from app.core.memory.agent.utils.messages_tools import merge_multiple_search_results, reorder_output_results +from app.core.memory.agent.utils.messages_tools import ( + merge_multiple_search_results, + reorder_output_results, +) from app.core.memory.agent.utils.type_classifier import status_typle from app.core.memory.agent.utils.write_tools import write # 新增:直接导入 write 函数 from app.core.memory.analytics.hot_memory_tags import get_hot_memory_tags @@ -33,6 +35,7 @@ from app.services.memory_config_service import MemoryConfigService from app.services.memory_konwledges_server import ( write_rag, ) +from langchain_core.messages import HumanMessage from pydantic import BaseModel, Field from sqlalchemy import func from sqlalchemy.orm import Session @@ -404,6 +407,7 @@ class MemoryAgentService: import time start_time = time.time() + logger.info(f"[PERF] read_memory started for group_id={group_id}, search_switch={search_switch}") # Resolve config_id if None using end_user's connected config if config_id is None: @@ -427,13 +431,15 @@ class MemoryAgentService: audit_logger = None + config_load_start = time.time() try: config_service = MemoryConfigService(db) memory_config = config_service.load_memory_config( config_id=config_id, service_name="MemoryAgentService" ) - logger.info(f"Configuration loaded successfully: {memory_config.config_name}") + config_load_time = time.time() - config_load_start + logger.info(f"[PERF] Configuration loaded in {config_load_time:.4f}s: {memory_config.config_name}") except ConfigurationError as e: error_msg = f"Failed to load configuration for config_id: {config_id}: {e}" logger.error(error_msg) @@ -457,6 +463,7 @@ class MemoryAgentService: logger.debug(f"Group ID:{group_id}, Message:{message}, History:{history}, Config ID:{config_id}") # Step 3: Initialize MCP client and execute read workflow + graph_exec_start = time.time() try: async with make_read_graph() as graph: config = {"configurable": {"thread_id": group_id}} @@ -513,6 +520,9 @@ class MemoryAgentService: if summary_n and summary_n != [] and summary_n != {}: _intermediate_outputs.append(summary_n) + graph_exec_time = time.time() - graph_exec_start + logger.info(f"[PERF] Graph execution completed in {graph_exec_time:.4f}s") + _intermediate_outputs = [item for item in _intermediate_outputs if item and item != [] and item != {}] optimized_outputs = merge_multiple_search_results(_intermediate_outputs) @@ -570,6 +580,8 @@ class MemoryAgentService: logger.error(f"保存短期记忆失败: {str(save_error)}", exc_info=True) # Log successful operation + total_time = time.time() - start_time + logger.info(f"[PERF] read_memory completed successfully in {total_time:.4f}s (config: {config_load_time:.4f}s, graph: {graph_exec_time:.4f}s)") if audit_logger: duration = time.time() - start_time audit_logger.log_operation( @@ -587,7 +599,8 @@ class MemoryAgentService: except Exception as e: # Ensure proper error handling and logging error_msg = f"Read operation failed: {str(e)}" - logger.error(error_msg) + total_time = time.time() - start_time + logger.error(f"[PERF] read_memory failed after {total_time:.4f}s: {error_msg}") if audit_logger: duration = time.time() - start_time audit_logger.log_operation( diff --git a/api/app/services/memory_config_service.py b/api/app/services/memory_config_service.py index 09e980a0..0099eb18 100644 --- a/api/app/services/memory_config_service.py +++ b/api/app/services/memory_config_service.py @@ -125,7 +125,11 @@ class MemoryConfigService: try: validated_config_id = _validate_config_id(config_id) + # Step 1: Get config and workspace + db_query_start = time.time() result = DataConfigRepository.get_config_with_workspace(self.db, validated_config_id) + db_query_time = time.time() - db_query_start + logger.info(f"[PERF] Config+Workspace query: {db_query_time:.4f}s") if not result: elapsed_ms = (time.time() - start_time) * 1000 config_logger.error( @@ -144,16 +148,20 @@ class MemoryConfigService: memory_config, workspace = result - # Validate embedding model - embedding_uuid = validate_embedding_model( + # Step 2: Validate embedding model (returns both UUID and name) + embed_start = time.time() + embedding_uuid, embedding_name = validate_embedding_model( validated_config_id, memory_config.embedding_id, self.db, workspace.tenant_id, workspace.id, ) + embed_time = time.time() - embed_start + logger.info(f"[PERF] Embedding validation: {embed_time:.4f}s") - # Resolve LLM model + # Step 3: Resolve LLM model + llm_start = time.time() llm_uuid, llm_name = validate_and_resolve_model_id( memory_config.llm_id, "llm", @@ -163,8 +171,11 @@ class MemoryConfigService: config_id=validated_config_id, workspace_id=workspace.id, ) + llm_time = time.time() - llm_start + logger.info(f"[PERF] LLM validation: {llm_time:.4f}s") - # Resolve optional rerank model + # Step 4: Resolve optional rerank model + rerank_start = time.time() rerank_uuid = None rerank_name = None if memory_config.rerank_id: @@ -177,16 +188,12 @@ class MemoryConfigService: config_id=validated_config_id, workspace_id=workspace.id, ) + rerank_time = time.time() - rerank_start + if memory_config.rerank_id: + logger.info(f"[PERF] Rerank validation: {rerank_time:.4f}s") - # Get embedding model name - embedding_name, _ = validate_model_exists_and_active( - embedding_uuid, - "embedding", - self.db, - workspace.tenant_id, - config_id=validated_config_id, - workspace_id=workspace.id, - ) + # Note: embedding_name is now returned from validate_embedding_model above + # No need for redundant query! # Create immutable MemoryConfig object config = MemoryConfig( diff --git a/api/app/tasks.py b/api/app/tasks.py index e375de35..fa9d1fdf 100644 --- a/api/app/tasks.py +++ b/api/app/tasks.py @@ -425,24 +425,7 @@ def read_message_task(self, group_id: str, message: str, history: List[Dict[str, db.close() try: - # 使用 nest_asyncio 来避免事件循环冲突 - try: - import nest_asyncio - nest_asyncio.apply() - except ImportError: - pass - - # 尝试获取现有事件循环,如果不存在则创建新的 - try: - loop = asyncio.get_event_loop() - if loop.is_closed(): - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - except RuntimeError: - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - - result = loop.run_until_complete(_run()) + result = asyncio.run(_run()) elapsed_time = time.time() - start_time return { @@ -455,7 +438,6 @@ def read_message_task(self, group_id: str, message: str, history: List[Dict[str, } except BaseException as e: elapsed_time = time.time() - start_time - # Handle ExceptionGroup from TaskGroup if hasattr(e, 'exceptions'): error_messages = [f"{type(sub_e).__name__}: {str(sub_e)}" for sub_e in e.exceptions] detailed_error = "; ".join(error_messages) @@ -528,24 +510,7 @@ def write_message_task(self, group_id: str, message, config_id: str, storage_typ db.close() try: - # 使用 nest_asyncio 来避免事件循环冲突 - try: - import nest_asyncio - nest_asyncio.apply() - except ImportError: - pass - - # 尝试获取现有事件循环,如果不存在则创建新的 - try: - loop = asyncio.get_event_loop() - if loop.is_closed(): - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - except RuntimeError: - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - - result = loop.run_until_complete(_run()) + result = asyncio.run(_run()) elapsed_time = time.time() - start_time logger.info(f"[CELERY WRITE] Task completed successfully - elapsed_time={elapsed_time:.2f}s, task_id={self.request.id}") @@ -560,7 +525,6 @@ def write_message_task(self, group_id: str, message, config_id: str, storage_typ } except BaseException as e: elapsed_time = time.time() - start_time - # Handle ExceptionGroup from TaskGroup if hasattr(e, 'exceptions'): error_messages = [f"{type(sub_e).__name__}: {str(sub_e)}" for sub_e in e.exceptions] detailed_error = "; ".join(error_messages) @@ -600,53 +564,53 @@ def reflection_timer_task() -> None: """ reflection_engine() - -@celery_app.task(name="app.core.memory.agent.health.check_read_service") -def check_read_service_task() -> Dict[str, str]: - """Call read_service and write latest status to Redis. +# unused task +# @celery_app.task(name="app.core.memory.agent.health.check_read_service") +# def check_read_service_task() -> Dict[str, str]: +# """Call read_service and write latest status to Redis. - Returns status data dict that gets written to Redis. - """ - client = redis.Redis( - host=settings.REDIS_HOST, - port=settings.REDIS_PORT, - db=settings.REDIS_DB, - password=settings.REDIS_PASSWORD if settings.REDIS_PASSWORD else None - ) - try: - api_url = f"http://{settings.SERVER_IP}:8000/api/memory/read_service" - payload = { - "user_id": "健康检查", - "apply_id": "健康检查", - "group_id": "健康检查", - "message": "你好", - "history": [], - "search_switch": "2", - } - resp = requests.post(api_url, json=payload, timeout=15) - ok = resp.status_code == 200 - status = "Success" if ok else "Fail" - msg = "接口请求成功" if ok else f"接口请求失败: {resp.status_code}" - error = "" if ok else resp.text - code = 0 if ok else 500 - except Exception as e: - status = "Fail" - msg = "接口请求失败" - error = str(e) - code = 500 +# Returns status data dict that gets written to Redis. +# """ +# client = redis.Redis( +# host=settings.REDIS_HOST, +# port=settings.REDIS_PORT, +# db=settings.REDIS_DB, +# password=settings.REDIS_PASSWORD if settings.REDIS_PASSWORD else None +# ) +# try: +# api_url = f"http://{settings.SERVER_IP}:8000/api/memory/read_service" +# payload = { +# "user_id": "健康检查", +# "apply_id": "健康检查", +# "group_id": "健康检查", +# "message": "你好", +# "history": [], +# "search_switch": "2", +# } +# resp = requests.post(api_url, json=payload, timeout=15) +# ok = resp.status_code == 200 +# status = "Success" if ok else "Fail" +# msg = "接口请求成功" if ok else f"接口请求失败: {resp.status_code}" +# error = "" if ok else resp.text +# code = 0 if ok else 500 +# except Exception as e: +# status = "Fail" +# msg = "接口请求失败" +# error = str(e) +# code = 500 - data = { - "status": status, - "msg": msg, - "error": error, - "code": str(code), - "time": str(int(time.time())), - } +# data = { +# "status": status, +# "msg": msg, +# "error": error, +# "code": str(code), +# "time": str(int(time.time())), +# } - client.hset("memsci:health:read_service", mapping=data) - client.expire("memsci:health:read_service", int(settings.HEALTH_CHECK_SECONDS)) +# client.hset("memsci:health:read_service", mapping=data) +# client.expire("memsci:health:read_service", int(settings.HEALTH_CHECK_SECONDS)) - return data +# return data @celery_app.task(name="app.controllers.memory_storage_controller.search_all") @@ -911,24 +875,7 @@ def regenerate_memory_cache(self) -> Dict[str, Any]: } try: - # 使用 nest_asyncio 来避免事件循环冲突 - try: - import nest_asyncio - nest_asyncio.apply() - except ImportError: - pass - - # 尝试获取现有事件循环,如果不存在则创建新的 - try: - loop = asyncio.get_event_loop() - if loop.is_closed(): - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - except RuntimeError: - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - - result = loop.run_until_complete(_run()) + result = asyncio.run(_run()) elapsed_time = time.time() - start_time result["elapsed_time"] = elapsed_time result["task_id"] = self.request.id @@ -1055,24 +1002,7 @@ def workspace_reflection_task(self) -> Dict[str, Any]: } try: - # 使用 nest_asyncio 来避免事件循环冲突 - try: - import nest_asyncio - nest_asyncio.apply() - except ImportError: - pass - - # 尝试获取现有事件循环,如果不存在则创建新的 - try: - loop = asyncio.get_event_loop() - if loop.is_closed(): - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - except RuntimeError: - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - - result = loop.run_until_complete(_run()) + result = asyncio.run(_run()) elapsed_time = time.time() - start_time result["elapsed_time"] = elapsed_time result["task_id"] = self.request.id @@ -1148,11 +1078,4 @@ def run_forgetting_cycle_task(self, config_id: Optional[int] = None) -> Dict[str "duration_seconds": duration } - # 运行异步函数 - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - try: - result = loop.run_until_complete(_run()) - return result - finally: - loop.close() + return asyncio.run(_run()) diff --git a/api/docker-compose.yml b/api/docker-compose.yml index 8bc19f3a..a7337689 100644 --- a/api/docker-compose.yml +++ b/api/docker-compose.yml @@ -7,10 +7,6 @@ services: - "8002:8000" env_file: - .env - environment: - - SERVER_IP=0.0.0.0 - # 如果代码里必须要 MCP_SERVER_URL,可以先注释或指向占位 - # - MCP_SERVER_URL= volumes: - ./files:/files - /etc/localtime:/etc/localtime:ro @@ -19,20 +15,53 @@ services: networks: - default - celery + depends_on: + - worker-memory + - worker-document - # Celery worker - worker: + # Memory worker - Memory read/write tasks (threads pool for asyncio) + worker-memory: image: redbear-mem-open:latest - container_name: worker + container_name: worker-memory env_file: - .env volumes: - ./files:/files - /etc/localtime:/etc/localtime:ro - command: celery -A app.celery_worker.celery_app worker --loglevel=info + command: celery -A app.celery_worker.celery_app worker -E --loglevel=info --pool=threads --concurrency=100 --queues=memory_tasks -n memory_worker@%h restart: unless-stopped networks: - celery + # Document worker - Document parsing tasks (prefork for CPU-bound) + worker-document: + image: redbear-mem-open:latest + container_name: worker-document + env_file: + - .env + volumes: + - ./files:/files + - /etc/localtime:/etc/localtime:ro + command: celery -A app.celery_worker.celery_app worker -E --loglevel=info --pool=prefork --concurrency=4 --queues=document_tasks --max-tasks-per-child=100 -n document_worker@%h + restart: unless-stopped + networks: + - celery + + # Celery Beat - scheduler + beat: + image: redbear-mem-open:latest + container_name: celery-beat + env_file: + - .env + volumes: + - ./files:/files + - /etc/localtime:/etc/localtime:ro + command: celery -A app.celery_worker.celery_app beat --loglevel=info + restart: unless-stopped + networks: + - celery + depends_on: + - worker-memory + networks: celery: diff --git a/api/pyproject.toml b/api/pyproject.toml index 6da684de..414ba372 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -13,6 +13,7 @@ dependencies = [ "bcrypt==5.0.0", "billiard==4.2.2", "celery==5.5.3", + "flower==2.0.1", "cffi==2.0.0", "click==8.3.0", "click-didyoumean==0.3.1", @@ -138,6 +139,7 @@ dependencies = [ "python-calamine>=0.4.0", "xlrd==2.0.2", "deprecated>=1.3.1", + "flower>=2.0.1", ] [tool.pytest.ini_options] diff --git a/api/requirements.txt b/api/requirements.txt index 99252e09..444a194b 100644 --- a/api/requirements.txt +++ b/api/requirements.txt @@ -6,6 +6,7 @@ async-timeout==5.0.1 bcrypt==5.0.0 billiard==4.2.2 celery==5.5.3 +flower==2.0.1 cffi==2.0.0 click==8.3.0 click-didyoumean==0.3.1