Fix/memory celery fix (#168)

* refactor(celery): optimize task routing and worker configuration

- Simplify Celery queue configuration with single default 'io_tasks' queue
- Implement task routing strategy separating IO-bound and CPU-bound tasks
- Add Flower monitoring support with task event tracking enabled
- Add summary node search optimization to only retrieve summary nodes
- Clean up unused imports and reorganize import statements for consistency
- Update docker-compose configuration to support multi-queue worker setup

* chore(celery): simplify flower configuration and add gevent dependency

* chore(dependencies): add gevent dependency to requirements

- Add gevent==24.11.1 to api/requirements.txt
- Gevent is required for async worker support in Celery
- Complements existing flower and celery configuration

* refactor(celery): simplify async event loop handling and reorganize task queues

- Replace complex nest_asyncio and manual event loop management with asyncio.run() in read_message_task, write_message_task, regenerate_memory_cache, and workspace_reflection_task
- Rename task queues from io_tasks/cpu_tasks to memory_tasks/document_tasks for better semantic clarity
- Update task routing configuration to reflect new queue names for memory agent tasks and document processing tasks
- Remove redundant exception handling comments and simplify error handling logic
- Update README with improved community support section including GitHub Issues, Pull Requests, Discussions, and WeChat community links
- Simplifies event loop management by leveraging asyncio.run() which handles loop creation and cleanup automatically, reducing code complexity and potential race conditions
This commit is contained in:
Ke Sun
2026-01-21 17:58:46 +08:00
committed by GitHub
parent 37ef497f4c
commit c24fb73147
12 changed files with 254 additions and 259 deletions

View File

@@ -334,7 +334,12 @@ step6: Log In to the Frontend Interface.
## License
This project is licensed under the Apache License 2.0. For details, see the LICENSE file.
## Acknowledgements & Community
- Feedback & Issues: Please submit an Issue in the repository for bug reports or discussions.
- Contributions Welcome: When submitting a Pull Request, please create a feature branch and follow conventional commit message guidelines.
- Contact: If you are interested in contributing or collaborating, feel free to reach out at tianyou_hubm@redbearai.com
## Community & Support
Join our community to ask questions, share your work, and connect with fellow developers.
- **GitHub Issues**: Report bugs, request features, or track known issues via [GitHub Issues](https://github.com/redbear-ai/memorybear/issues).
- **GitHub Pull Requests**: Contribute code improvements or fixes through [Pull Requests](https://github.com/redbear-ai/memorybear/pulls).
- **GitHub Discussions**: Ask questions, share ideas, and engage with the community in [GitHub Discussions](https://github.com/redbear-ai/memorybear/discussions).
- **WeChat**: Scan the QR code below to join our WeChat community group.
- **Contact**: If you are interested in contributing or collaborating, feel free to reach out at tianyou_hubm@redbearai.com

View File

@@ -1,4 +1,5 @@
import os
import platform
from datetime import timedelta
from urllib.parse import quote
@@ -14,27 +15,12 @@ celery_app = Celery(
backend=f"redis://:{quote(settings.REDIS_PASSWORD)}@{settings.REDIS_HOST}:{settings.REDIS_PORT}/{settings.CELERY_BACKEND}",
)
# 配置使用本地队列,避免与远程 worker 冲突
celery_app.conf.task_default_queue = 'localhost_test_wyl'
celery_app.conf.task_default_exchange = 'localhost_test_wyl'
celery_app.conf.task_default_routing_key = 'localhost_test_wyl'
# Default queue for unrouted tasks
celery_app.conf.task_default_queue = 'memory_tasks'
# macOS 兼容性配置
import platform
if platform.system() == 'Darwin': # macOS
# 设置环境变量解决 fork 问题
if platform.system() == 'Darwin':
os.environ.setdefault('OBJC_DISABLE_INITIALIZE_FORK_SAFETY', 'YES')
# 使用 solo 池避免多进程问题
celery_app.conf.worker_pool = 'solo'
# 设置唯一的节点名称
import socket
import time
hostname = socket.gethostname()
timestamp = int(time.time())
celery_app.conf.worker_name = f"celery@{hostname}-{timestamp}"
# Celery 配置
celery_app.conf.update(
@@ -52,36 +38,47 @@ celery_app.conf.update(
task_ignore_result=False,
# 超时设置
task_time_limit=30 * 60, # 30 分钟硬超时
task_soft_time_limit=25 * 60, # 25 分钟软超时
task_time_limit=1800, # 30分钟硬超时
task_soft_time_limit=1500, # 25分钟软超时
# Worker 设置 - 针对 macOS 优化
worker_prefetch_multiplier=1, # 减少预取任务数,避免内存堆积
worker_max_tasks_per_child=10, # 大幅减少每个 worker 执行的任务数,频繁重启防止内存泄漏
worker_max_memory_per_child=200000, # 200MB 内存限制,超过后重启 worker
# Worker 设置 (per-worker settings are in docker-compose command line)
worker_prefetch_multiplier=1, # Don't hoard tasks, fairer distribution
# 结果过期时间
result_expires=3600, # 结果保存 1 小时
result_expires=3600, # 结果保存1小时
# 任务确认设置
task_acks_late=True, # 任务完成后才确认,避免任务丢失
worker_disable_rate_limits=True, # 禁用速率限制
task_acks_late=True,
task_reject_on_worker_lost=True,
worker_disable_rate_limits=True,
# 任务路由(可选,用于不同队列)
# task_routes={
# 'app.core.rag.tasks.parse_document': {'queue': 'document_processing'},
# 'app.core.memory.agent.read_message': {'queue': 'memory_processing'},
# 'app.core.memory.agent.write_message': {'queue': 'memory_processing'},
# 'tasks.process_item': {'queue': 'default'},
# },
# FLower setting
worker_send_task_events=True,
task_send_sent_event=True,
# task routing
task_routes={
# Memory tasks → memory_tasks queue (threads worker)
'app.core.memory.agent.read_message_priority': {'queue': 'memory_tasks'},
'app.core.memory.agent.read_message': {'queue': 'memory_tasks'},
'app.core.memory.agent.write_message': {'queue': 'memory_tasks'},
# Document tasks → document_tasks queue (prefork worker)
'app.core.rag.tasks.parse_document': {'queue': 'document_tasks'},
'app.core.rag.tasks.build_graphrag_for_kb': {'queue': 'document_tasks'},
# Beat/periodic tasks → document_tasks queue (prefork worker)
'app.tasks.workspace_reflection_task': {'queue': 'document_tasks'},
'app.tasks.regenerate_memory_cache': {'queue': 'document_tasks'},
'app.tasks.run_forgetting_cycle_task': {'queue': 'document_tasks'},
'app.controllers.memory_storage_controller.search_all': {'queue': 'document_tasks'},
},
)
# 自动发现任务模块
celery_app.autodiscover_tasks(['app'])
# Celery Beat schedule for periodic tasks
reflection_schedule = timedelta(seconds=settings.REFLECTION_INTERVAL_SECONDS)
health_schedule = timedelta(seconds=settings.HEALTH_CHECK_SECONDS)
memory_increment_schedule = timedelta(hours=settings.MEMORY_INCREMENT_INTERVAL_HOURS)
memory_cache_regeneration_schedule = timedelta(hours=settings.MEMORY_CACHE_REGENERATION_HOURS)
workspace_reflection_schedule = timedelta(seconds=30) # 每30秒运行一次settings.REFLECTION_INTERVAL_TIME
@@ -89,12 +86,6 @@ forgetting_cycle_schedule = timedelta(hours=24) # 每24小时运行一次遗忘
# 构建定时任务配置
beat_schedule_config = {
# "check-read-service": {
# "task": "app.core.memory.agent.health.check_read_service",
# "schedule": health_schedule,
# "args": (),
# },
"run-workspace-reflection": {
"task": "app.tasks.workspace_reflection_task",
"schedule": workspace_reflection_schedule,

View File

@@ -4,12 +4,11 @@ import os
import time
from app.core.logging_config import get_agent_logger, log_time
from app.db import get_db
from app.core.memory.agent.models.summary_models import (
RetrieveSummaryResponse,
SummaryResponse,
)
from app.core.memory.agent.services.optimized_llm_service import LLMServiceMixin
from app.core.memory.agent.services.search_service import SearchService
from app.core.memory.agent.utils.llm_tools import (
PROJECT_ROOT_,
@@ -18,7 +17,7 @@ from app.core.memory.agent.utils.llm_tools import (
from app.core.memory.agent.utils.redis_tool import store
from app.core.memory.agent.utils.session_tools import SessionService
from app.core.memory.agent.utils.template_tools import TemplateService
from app.core.memory.agent.services.optimized_llm_service import LLMServiceMixin
from app.db import get_db
template_root = os.path.join(PROJECT_ROOT_, 'agent', 'utils', 'prompt')
logger = get_agent_logger(__name__)
@@ -182,7 +181,8 @@ async def Input_Summary(state: ReadState) -> ReadState:
search_params = {
"group_id": group_id,
"question": data,
"return_raw_results": True
"return_raw_results": True,
"include": ["summaries"] # Only search summary nodes for faster performance
}
try:

View File

@@ -89,14 +89,15 @@ def validate_model_exists_and_active(
start_time = time.time()
try:
# First check if model exists at all (without tenant filtering)
model_without_tenant = ModelConfigRepository.get_by_id(db, model_id, tenant_id=None)
# Then check with tenant filtering
# OPTIMIZED: Single query with tenant filter
# We'll check tenant mismatch in the error handling
model = ModelConfigRepository.get_by_id(db, model_id, tenant_id)
elapsed_ms = (time.time() - start_time) * 1000
if not model:
# Model not found with tenant filter - check if it exists without filter
model_without_tenant = ModelConfigRepository.get_by_id(db, model_id, tenant_id=None)
if model_without_tenant:
# Model exists but belongs to different tenant
logger.warning(
@@ -208,8 +209,11 @@ def validate_embedding_model(
db: Session,
tenant_id: Optional[UUID] = None,
workspace_id: Optional[UUID] = None
) -> UUID:
"""Validate that embedding model is available and return its UUID.
) -> tuple[UUID, str]:
"""Validate that embedding model is available and return its UUID and name.
Returns:
Tuple of (embedding_uuid, embedding_name)
Raises:
InvalidConfigError: If embedding_id is not provided or invalid
@@ -225,14 +229,19 @@ def validate_embedding_model(
workspace_id=workspace_id
)
embedding_uuid, _ = validate_and_resolve_model_id(
embedding_uuid, embedding_name = validate_and_resolve_model_id(
embedding_id, "embedding", db, tenant_id, required=True,
config_id=config_id, workspace_id=workspace_id
)
print(100*'-')
print(embedding_uuid)
print(_)
print(100*'-')
logger.debug(
"Embedding model validated",
extra={
"embedding_uuid": str(embedding_uuid),
"embedding_name": embedding_name,
"config_id": config_id
}
)
if embedding_uuid is None:
raise InvalidConfigError(
@@ -243,7 +252,7 @@ def validate_embedding_model(
workspace_id=workspace_id
)
return embedding_uuid
return embedding_uuid, embedding_name
def validate_llm_model(

View File

@@ -305,12 +305,19 @@ async def search_graph(
results[key] = _deduplicate_results(results[key])
# 更新知识节点的激活值Statement, ExtractedEntity, MemorySummary
results = await _update_search_results_activation(
connector=connector,
results=results,
group_id=group_id
# Skip activation updates if only searching summaries (optimization)
needs_activation_update = any(
key in include and key in results and results[key]
for key in ['statements', 'entities', 'chunks']
)
if needs_activation_update:
results = await _update_search_results_activation(
connector=connector,
results=results,
group_id=group_id
)
return results
@@ -339,7 +346,7 @@ async def search_graph_by_embedding(
embed_start = time.time()
embeddings = await embedder_client.response([query_text])
embed_time = time.time() - embed_start
print(f"[PERF] Embedding generation took: {embed_time:.4f}s")
logger.info(f"[PERF] Embedding generation took: {embed_time:.4f}s")
if not embeddings or not embeddings[0]:
return {"statements": [], "chunks": [], "entities": [], "summaries": []}
@@ -393,7 +400,7 @@ async def search_graph_by_embedding(
query_start = time.time()
task_results = await asyncio.gather(*tasks, return_exceptions=True)
query_time = time.time() - query_start
print(f"[PERF] Neo4j queries (parallel) took: {query_time:.4f}s")
logger.info(f"[PERF] Neo4j queries (parallel) took: {query_time:.4f}s")
# Build results dictionary
results: Dict[str, List[Dict[str, Any]]] = {
@@ -417,14 +424,23 @@ async def search_graph_by_embedding(
results[key] = _deduplicate_results(results[key])
# 更新知识节点的激活值Statement, ExtractedEntity, MemorySummary
update_start = time.time()
results = await _update_search_results_activation(
connector=connector,
results=results,
group_id=group_id
# Skip activation updates if only searching summaries (optimization)
needs_activation_update = any(
key in include and key in results and results[key]
for key in ['statements', 'entities', 'chunks']
)
update_time = time.time() - update_start
print(f"[PERF] Activation value updates took: {update_time:.4f}s")
if needs_activation_update:
update_start = time.time()
results = await _update_search_results_activation(
connector=connector,
results=results,
group_id=group_id
)
update_time = time.time() - update_start
logger.info(f"[PERF] Activation value updates took: {update_time:.4f}s")
else:
logger.info(f"[PERF] Skipping activation updates (only summaries)")
return results
async def get_dedup_candidates_for_entities( # 适配新版查询:使用全文索引按名称检索候选实体
@@ -535,7 +551,7 @@ async def search_graph_by_keyword_temporal(
- Returns up to 'limit' statements
"""
if not query_text:
print(f"query_text不能为空")
logger.warning(f"query_text cannot be empty")
return {"statements": []}
statements = await connector.execute_query(
SEARCH_STATEMENTS_BY_KEYWORD_TEMPORAL,
@@ -549,7 +565,7 @@ async def search_graph_by_keyword_temporal(
invalid_date=invalid_date,
limit=limit,
)
print(f"查询结果为:\n{statements}")
logger.debug(f"Temporal keyword search results: {len(statements)} statements found")
# 更新 Statement 节点的激活值
results = {"statements": statements}
@@ -594,9 +610,9 @@ async def search_graph_by_temporal(
limit=limit,
)
print(f"查询语句为:\n{SEARCH_STATEMENTS_BY_TEMPORAL}")
print(f"查询参数为:\n{{group_id: {group_id}, apply_id: {apply_id}, user_id: {user_id}, start_date: {start_date}, end_date: {end_date}, valid_date: {valid_date}, invalid_date: {invalid_date}, limit: {limit}}}")
print(f"查询结果为:\n{statements}")
logger.debug(f"Temporal search query: {SEARCH_STATEMENTS_BY_TEMPORAL}")
logger.debug(f"Query params: group_id={group_id}, apply_id={apply_id}, user_id={user_id}, start_date={start_date}, end_date={end_date}, valid_date={valid_date}, invalid_date={invalid_date}, limit={limit}")
logger.debug(f"Temporal search results: {len(statements)} statements found")
# 更新 Statement 节点的激活值
results = {"statements": statements}
@@ -623,7 +639,7 @@ async def search_graph_by_dialog_id(
- Returns up to 'limit' dialogues
"""
if not dialog_id:
print(f"dialog_id不能为空")
logger.warning(f"dialog_id cannot be empty")
return {"dialogues": []}
dialogues = await connector.execute_query(
@@ -642,7 +658,7 @@ async def search_graph_by_chunk_id(
limit: int = 1,
) -> Dict[str, List[Dict[str, Any]]]:
if not chunk_id:
print(f"chunk_id不能为空")
logger.warning(f"chunk_id cannot be empty")
return {"chunks": []}
chunks = await connector.execute_query(
SEARCH_CHUNK_BY_CHUNK_ID,
@@ -679,9 +695,9 @@ async def search_graph_by_created_at(
limit=limit,
)
print(f"查询语句为:\n{SEARCH_STATEMENTS_BY_CREATED_AT}")
print(f"查询参数为:\n{{group_id: {group_id}, apply_id: {apply_id}, user_id: {user_id}, created_at: {created_at}, limit: {limit}}}")
print(f"查询结果为:\n{statements}")
logger.debug(f"Search by created_at query: {SEARCH_STATEMENTS_BY_CREATED_AT}")
logger.debug(f"Query params: group_id={group_id}, apply_id={apply_id}, user_id={user_id}, created_at={created_at}, limit={limit}")
logger.debug(f"Search results: {len(statements)} statements found")
# 更新 Statement 节点的激活值
results = {"statements": statements}
@@ -719,9 +735,9 @@ async def search_graph_by_valid_at(
limit=limit,
)
print(f"查询语句为:\n{SEARCH_STATEMENTS_BY_VALID_AT}")
print(f"查询参数为:\n{{group_id: {group_id}, apply_id: {apply_id}, user_id: {user_id}, valid_at: {valid_at}, limit: {limit}}}")
print(f"查询结果为:\n{statements}")
logger.debug(f"Search by valid_at query: {SEARCH_STATEMENTS_BY_VALID_AT}")
logger.debug(f"Query params: group_id={group_id}, apply_id={apply_id}, user_id={user_id}, valid_at={valid_at}, limit={limit}")
logger.debug(f"Search results: {len(statements)} statements found")
# 更新 Statement 节点的激活值
results = {"statements": statements}
@@ -759,9 +775,9 @@ async def search_graph_g_created_at(
limit=limit,
)
print(f"查询语句为:\n{SEARCH_STATEMENTS_G_CREATED_AT}")
print(f"查询参数为:\n{{group_id: {group_id}, apply_id: {apply_id}, user_id: {user_id}, created_at: {created_at}, limit: {limit}}}")
print(f"查询结果为:\n{statements}")
logger.debug(f"Search greater than created_at query: {SEARCH_STATEMENTS_G_CREATED_AT}")
logger.debug(f"Query params: group_id={group_id}, apply_id={apply_id}, user_id={user_id}, created_at={created_at}, limit={limit}")
logger.debug(f"Search results: {len(statements)} statements found")
# 更新 Statement 节点的激活值
results = {"statements": statements}
@@ -799,9 +815,9 @@ async def search_graph_g_valid_at(
limit=limit,
)
print(f"查询语句为:\n{SEARCH_STATEMENTS_G_VALID_AT}")
print(f"查询参数为:\n{{group_id: {group_id}, apply_id: {apply_id}, user_id: {user_id}, valid_at: {valid_at}, limit: {limit}}}")
print(f"查询结果为:\n{statements}")
logger.debug(f"Search greater than valid_at query: {SEARCH_STATEMENTS_G_VALID_AT}")
logger.debug(f"Query params: group_id={group_id}, apply_id={apply_id}, user_id={user_id}, valid_at={valid_at}, limit={limit}")
logger.debug(f"Search results: {len(statements)} statements found")
# 更新 Statement 节点的激活值
results = {"statements": statements}
@@ -839,9 +855,9 @@ async def search_graph_l_created_at(
limit=limit,
)
print(f"查询语句为:\n{SEARCH_STATEMENTS_L_CREATED_AT}")
print(f"查询参数为:\n{{group_id: {group_id}, apply_id: {apply_id}, user_id: {user_id}, created_at: {created_at}, limit: {limit}}}")
print(f"查询结果为:\n{statements}")
logger.debug(f"Search less than created_at query: {SEARCH_STATEMENTS_L_CREATED_AT}")
logger.debug(f"Query params: group_id={group_id}, apply_id={apply_id}, user_id={user_id}, created_at={created_at}, limit={limit}")
logger.debug(f"Search results: {len(statements)} statements found")
# 更新 Statement 节点的激活值
results = {"statements": statements}
@@ -879,9 +895,9 @@ async def search_graph_l_valid_at(
limit=limit,
)
print(f"查询语句为:\n{SEARCH_STATEMENTS_L_VALID_AT}")
print(f"查询参数为:\n{{group_id: {group_id}, apply_id: {apply_id}, user_id: {user_id}, valid_at: {valid_at}, limit: {limit}}}")
print(f"查询结果为:\n{statements}")
logger.debug(f"Search less than valid_at query: {SEARCH_STATEMENTS_L_VALID_AT}")
logger.debug(f"Query params: group_id={group_id}, apply_id={apply_id}, user_id={user_id}, valid_at={valid_at}, limit={limit}")
logger.debug(f"Search results: {len(statements)} statements found")
# 更新 Statement 节点的激活值
results = {"statements": statements}

View File

@@ -10,11 +10,6 @@ import time
import uuid
from typing import Any, AsyncGenerator, Dict, List, Optional
from langchain.tools import tool
from pydantic import BaseModel, Field
from sqlalchemy import select
from sqlalchemy.orm import Session
from app.celery_app import celery_app
from app.core.error_codes import BizCode
from app.core.exceptions import BusinessException
@@ -28,6 +23,10 @@ from app.services.langchain_tool_server import Search
from app.services.memory_agent_service import MemoryAgentService
from app.services.model_parameter_merger import ModelParameterMerger
from app.services.tool_service import ToolService
from langchain.tools import tool
from pydantic import BaseModel, Field
from sqlalchemy import select
from sqlalchemy.orm import Session
logger = get_business_logger()
class KnowledgeRetrievalInput(BaseModel):
@@ -107,9 +106,9 @@ def create_long_term_memory_tool(memory_config: Dict[str, Any], end_user_id: str
"app.core.memory.agent.read_message",
args=[end_user_id, question, [], "1", config_id, storage_type, user_rag_memory_id]
)
result = task_service.get_task_memory_read_result(task.id)
status = result.get("status")
logger.info(f"读取任务状态:{status}")
# result = task_service.get_task_memory_read_result(task.id)
# status = result.get("status")
# logger.info(f"读取任务状态:{status}")
finally:
db.close()

View File

@@ -10,15 +10,17 @@ import re
import time
import uuid
from typing import Any, AsyncGenerator, Dict, List, Optional
import redis
from langchain_core.messages import HumanMessage
import redis
from app.core.config import settings
from app.core.logging_config import get_config_logger, get_logger
from app.core.memory.agent.langgraph_graph.read_graph import make_read_graph
from app.core.memory.agent.langgraph_graph.write_graph import make_write_graph
from app.core.memory.agent.logger_file.log_streamer import LogStreamer
from app.core.memory.agent.utils.messages_tools import merge_multiple_search_results, reorder_output_results
from app.core.memory.agent.utils.messages_tools import (
merge_multiple_search_results,
reorder_output_results,
)
from app.core.memory.agent.utils.type_classifier import status_typle
from app.core.memory.agent.utils.write_tools import write # 新增:直接导入 write 函数
from app.core.memory.analytics.hot_memory_tags import get_hot_memory_tags
@@ -33,6 +35,7 @@ from app.services.memory_config_service import MemoryConfigService
from app.services.memory_konwledges_server import (
write_rag,
)
from langchain_core.messages import HumanMessage
from pydantic import BaseModel, Field
from sqlalchemy import func
from sqlalchemy.orm import Session
@@ -404,6 +407,7 @@ class MemoryAgentService:
import time
start_time = time.time()
logger.info(f"[PERF] read_memory started for group_id={group_id}, search_switch={search_switch}")
# Resolve config_id if None using end_user's connected config
if config_id is None:
@@ -427,13 +431,15 @@ class MemoryAgentService:
audit_logger = None
config_load_start = time.time()
try:
config_service = MemoryConfigService(db)
memory_config = config_service.load_memory_config(
config_id=config_id,
service_name="MemoryAgentService"
)
logger.info(f"Configuration loaded successfully: {memory_config.config_name}")
config_load_time = time.time() - config_load_start
logger.info(f"[PERF] Configuration loaded in {config_load_time:.4f}s: {memory_config.config_name}")
except ConfigurationError as e:
error_msg = f"Failed to load configuration for config_id: {config_id}: {e}"
logger.error(error_msg)
@@ -457,6 +463,7 @@ class MemoryAgentService:
logger.debug(f"Group ID:{group_id}, Message:{message}, History:{history}, Config ID:{config_id}")
# Step 3: Initialize MCP client and execute read workflow
graph_exec_start = time.time()
try:
async with make_read_graph() as graph:
config = {"configurable": {"thread_id": group_id}}
@@ -513,6 +520,9 @@ class MemoryAgentService:
if summary_n and summary_n != [] and summary_n != {}:
_intermediate_outputs.append(summary_n)
graph_exec_time = time.time() - graph_exec_start
logger.info(f"[PERF] Graph execution completed in {graph_exec_time:.4f}s")
_intermediate_outputs = [item for item in _intermediate_outputs if item and item != [] and item != {}]
optimized_outputs = merge_multiple_search_results(_intermediate_outputs)
@@ -570,6 +580,8 @@ class MemoryAgentService:
logger.error(f"保存短期记忆失败: {str(save_error)}", exc_info=True)
# Log successful operation
total_time = time.time() - start_time
logger.info(f"[PERF] read_memory completed successfully in {total_time:.4f}s (config: {config_load_time:.4f}s, graph: {graph_exec_time:.4f}s)")
if audit_logger:
duration = time.time() - start_time
audit_logger.log_operation(
@@ -587,7 +599,8 @@ class MemoryAgentService:
except Exception as e:
# Ensure proper error handling and logging
error_msg = f"Read operation failed: {str(e)}"
logger.error(error_msg)
total_time = time.time() - start_time
logger.error(f"[PERF] read_memory failed after {total_time:.4f}s: {error_msg}")
if audit_logger:
duration = time.time() - start_time
audit_logger.log_operation(

View File

@@ -125,7 +125,11 @@ class MemoryConfigService:
try:
validated_config_id = _validate_config_id(config_id)
# Step 1: Get config and workspace
db_query_start = time.time()
result = DataConfigRepository.get_config_with_workspace(self.db, validated_config_id)
db_query_time = time.time() - db_query_start
logger.info(f"[PERF] Config+Workspace query: {db_query_time:.4f}s")
if not result:
elapsed_ms = (time.time() - start_time) * 1000
config_logger.error(
@@ -144,16 +148,20 @@ class MemoryConfigService:
memory_config, workspace = result
# Validate embedding model
embedding_uuid = validate_embedding_model(
# Step 2: Validate embedding model (returns both UUID and name)
embed_start = time.time()
embedding_uuid, embedding_name = validate_embedding_model(
validated_config_id,
memory_config.embedding_id,
self.db,
workspace.tenant_id,
workspace.id,
)
embed_time = time.time() - embed_start
logger.info(f"[PERF] Embedding validation: {embed_time:.4f}s")
# Resolve LLM model
# Step 3: Resolve LLM model
llm_start = time.time()
llm_uuid, llm_name = validate_and_resolve_model_id(
memory_config.llm_id,
"llm",
@@ -163,8 +171,11 @@ class MemoryConfigService:
config_id=validated_config_id,
workspace_id=workspace.id,
)
llm_time = time.time() - llm_start
logger.info(f"[PERF] LLM validation: {llm_time:.4f}s")
# Resolve optional rerank model
# Step 4: Resolve optional rerank model
rerank_start = time.time()
rerank_uuid = None
rerank_name = None
if memory_config.rerank_id:
@@ -177,16 +188,12 @@ class MemoryConfigService:
config_id=validated_config_id,
workspace_id=workspace.id,
)
rerank_time = time.time() - rerank_start
if memory_config.rerank_id:
logger.info(f"[PERF] Rerank validation: {rerank_time:.4f}s")
# Get embedding model name
embedding_name, _ = validate_model_exists_and_active(
embedding_uuid,
"embedding",
self.db,
workspace.tenant_id,
config_id=validated_config_id,
workspace_id=workspace.id,
)
# Note: embedding_name is now returned from validate_embedding_model above
# No need for redundant query!
# Create immutable MemoryConfig object
config = MemoryConfig(

View File

@@ -425,24 +425,7 @@ def read_message_task(self, group_id: str, message: str, history: List[Dict[str,
db.close()
try:
# 使用 nest_asyncio 来避免事件循环冲突
try:
import nest_asyncio
nest_asyncio.apply()
except ImportError:
pass
# 尝试获取现有事件循环,如果不存在则创建新的
try:
loop = asyncio.get_event_loop()
if loop.is_closed():
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
except RuntimeError:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
result = loop.run_until_complete(_run())
result = asyncio.run(_run())
elapsed_time = time.time() - start_time
return {
@@ -455,7 +438,6 @@ def read_message_task(self, group_id: str, message: str, history: List[Dict[str,
}
except BaseException as e:
elapsed_time = time.time() - start_time
# Handle ExceptionGroup from TaskGroup
if hasattr(e, 'exceptions'):
error_messages = [f"{type(sub_e).__name__}: {str(sub_e)}" for sub_e in e.exceptions]
detailed_error = "; ".join(error_messages)
@@ -528,24 +510,7 @@ def write_message_task(self, group_id: str, message, config_id: str, storage_typ
db.close()
try:
# 使用 nest_asyncio 来避免事件循环冲突
try:
import nest_asyncio
nest_asyncio.apply()
except ImportError:
pass
# 尝试获取现有事件循环,如果不存在则创建新的
try:
loop = asyncio.get_event_loop()
if loop.is_closed():
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
except RuntimeError:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
result = loop.run_until_complete(_run())
result = asyncio.run(_run())
elapsed_time = time.time() - start_time
logger.info(f"[CELERY WRITE] Task completed successfully - elapsed_time={elapsed_time:.2f}s, task_id={self.request.id}")
@@ -560,7 +525,6 @@ def write_message_task(self, group_id: str, message, config_id: str, storage_typ
}
except BaseException as e:
elapsed_time = time.time() - start_time
# Handle ExceptionGroup from TaskGroup
if hasattr(e, 'exceptions'):
error_messages = [f"{type(sub_e).__name__}: {str(sub_e)}" for sub_e in e.exceptions]
detailed_error = "; ".join(error_messages)
@@ -600,53 +564,53 @@ def reflection_timer_task() -> None:
"""
reflection_engine()
@celery_app.task(name="app.core.memory.agent.health.check_read_service")
def check_read_service_task() -> Dict[str, str]:
"""Call read_service and write latest status to Redis.
# unused task
# @celery_app.task(name="app.core.memory.agent.health.check_read_service")
# def check_read_service_task() -> Dict[str, str]:
# """Call read_service and write latest status to Redis.
Returns status data dict that gets written to Redis.
"""
client = redis.Redis(
host=settings.REDIS_HOST,
port=settings.REDIS_PORT,
db=settings.REDIS_DB,
password=settings.REDIS_PASSWORD if settings.REDIS_PASSWORD else None
)
try:
api_url = f"http://{settings.SERVER_IP}:8000/api/memory/read_service"
payload = {
"user_id": "健康检查",
"apply_id": "健康检查",
"group_id": "健康检查",
"message": "你好",
"history": [],
"search_switch": "2",
}
resp = requests.post(api_url, json=payload, timeout=15)
ok = resp.status_code == 200
status = "Success" if ok else "Fail"
msg = "接口请求成功" if ok else f"接口请求失败: {resp.status_code}"
error = "" if ok else resp.text
code = 0 if ok else 500
except Exception as e:
status = "Fail"
msg = "接口请求失败"
error = str(e)
code = 500
# Returns status data dict that gets written to Redis.
# """
# client = redis.Redis(
# host=settings.REDIS_HOST,
# port=settings.REDIS_PORT,
# db=settings.REDIS_DB,
# password=settings.REDIS_PASSWORD if settings.REDIS_PASSWORD else None
# )
# try:
# api_url = f"http://{settings.SERVER_IP}:8000/api/memory/read_service"
# payload = {
# "user_id": "健康检查",
# "apply_id": "健康检查",
# "group_id": "健康检查",
# "message": "你好",
# "history": [],
# "search_switch": "2",
# }
# resp = requests.post(api_url, json=payload, timeout=15)
# ok = resp.status_code == 200
# status = "Success" if ok else "Fail"
# msg = "接口请求成功" if ok else f"接口请求失败: {resp.status_code}"
# error = "" if ok else resp.text
# code = 0 if ok else 500
# except Exception as e:
# status = "Fail"
# msg = "接口请求失败"
# error = str(e)
# code = 500
data = {
"status": status,
"msg": msg,
"error": error,
"code": str(code),
"time": str(int(time.time())),
}
# data = {
# "status": status,
# "msg": msg,
# "error": error,
# "code": str(code),
# "time": str(int(time.time())),
# }
client.hset("memsci:health:read_service", mapping=data)
client.expire("memsci:health:read_service", int(settings.HEALTH_CHECK_SECONDS))
# client.hset("memsci:health:read_service", mapping=data)
# client.expire("memsci:health:read_service", int(settings.HEALTH_CHECK_SECONDS))
return data
# return data
@celery_app.task(name="app.controllers.memory_storage_controller.search_all")
@@ -911,24 +875,7 @@ def regenerate_memory_cache(self) -> Dict[str, Any]:
}
try:
# 使用 nest_asyncio 来避免事件循环冲突
try:
import nest_asyncio
nest_asyncio.apply()
except ImportError:
pass
# 尝试获取现有事件循环,如果不存在则创建新的
try:
loop = asyncio.get_event_loop()
if loop.is_closed():
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
except RuntimeError:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
result = loop.run_until_complete(_run())
result = asyncio.run(_run())
elapsed_time = time.time() - start_time
result["elapsed_time"] = elapsed_time
result["task_id"] = self.request.id
@@ -1055,24 +1002,7 @@ def workspace_reflection_task(self) -> Dict[str, Any]:
}
try:
# 使用 nest_asyncio 来避免事件循环冲突
try:
import nest_asyncio
nest_asyncio.apply()
except ImportError:
pass
# 尝试获取现有事件循环,如果不存在则创建新的
try:
loop = asyncio.get_event_loop()
if loop.is_closed():
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
except RuntimeError:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
result = loop.run_until_complete(_run())
result = asyncio.run(_run())
elapsed_time = time.time() - start_time
result["elapsed_time"] = elapsed_time
result["task_id"] = self.request.id
@@ -1148,11 +1078,4 @@ def run_forgetting_cycle_task(self, config_id: Optional[int] = None) -> Dict[str
"duration_seconds": duration
}
# 运行异步函数
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
result = loop.run_until_complete(_run())
return result
finally:
loop.close()
return asyncio.run(_run())

View File

@@ -7,10 +7,6 @@ services:
- "8002:8000"
env_file:
- .env
environment:
- SERVER_IP=0.0.0.0
# 如果代码里必须要 MCP_SERVER_URL可以先注释或指向占位
# - MCP_SERVER_URL=
volumes:
- ./files:/files
- /etc/localtime:/etc/localtime:ro
@@ -19,20 +15,53 @@ services:
networks:
- default
- celery
depends_on:
- worker-memory
- worker-document
# Celery worker
worker:
# Memory worker - Memory read/write tasks (threads pool for asyncio)
worker-memory:
image: redbear-mem-open:latest
container_name: worker
container_name: worker-memory
env_file:
- .env
volumes:
- ./files:/files
- /etc/localtime:/etc/localtime:ro
command: celery -A app.celery_worker.celery_app worker --loglevel=info
command: celery -A app.celery_worker.celery_app worker -E --loglevel=info --pool=threads --concurrency=100 --queues=memory_tasks -n memory_worker@%h
restart: unless-stopped
networks:
- celery
# Document worker - Document parsing tasks (prefork for CPU-bound)
worker-document:
image: redbear-mem-open:latest
container_name: worker-document
env_file:
- .env
volumes:
- ./files:/files
- /etc/localtime:/etc/localtime:ro
command: celery -A app.celery_worker.celery_app worker -E --loglevel=info --pool=prefork --concurrency=4 --queues=document_tasks --max-tasks-per-child=100 -n document_worker@%h
restart: unless-stopped
networks:
- celery
# Celery Beat - scheduler
beat:
image: redbear-mem-open:latest
container_name: celery-beat
env_file:
- .env
volumes:
- ./files:/files
- /etc/localtime:/etc/localtime:ro
command: celery -A app.celery_worker.celery_app beat --loglevel=info
restart: unless-stopped
networks:
- celery
depends_on:
- worker-memory
networks:
celery:

View File

@@ -13,6 +13,7 @@ dependencies = [
"bcrypt==5.0.0",
"billiard==4.2.2",
"celery==5.5.3",
"flower==2.0.1",
"cffi==2.0.0",
"click==8.3.0",
"click-didyoumean==0.3.1",
@@ -138,6 +139,7 @@ dependencies = [
"python-calamine>=0.4.0",
"xlrd==2.0.2",
"deprecated>=1.3.1",
"flower>=2.0.1",
]
[tool.pytest.ini_options]

View File

@@ -6,6 +6,7 @@ async-timeout==5.0.1
bcrypt==5.0.0
billiard==4.2.2
celery==5.5.3
flower==2.0.1
cffi==2.0.0
click==8.3.0
click-didyoumean==0.3.1