Fix/memory celery fix (#168)
* refactor(celery): optimize task routing and worker configuration - Simplify Celery queue configuration with single default 'io_tasks' queue - Implement task routing strategy separating IO-bound and CPU-bound tasks - Add Flower monitoring support with task event tracking enabled - Add summary node search optimization to only retrieve summary nodes - Clean up unused imports and reorganize import statements for consistency - Update docker-compose configuration to support multi-queue worker setup * chore(celery): simplify flower configuration and add gevent dependency * chore(dependencies): add gevent dependency to requirements - Add gevent==24.11.1 to api/requirements.txt - Gevent is required for async worker support in Celery - Complements existing flower and celery configuration * refactor(celery): simplify async event loop handling and reorganize task queues - Replace complex nest_asyncio and manual event loop management with asyncio.run() in read_message_task, write_message_task, regenerate_memory_cache, and workspace_reflection_task - Rename task queues from io_tasks/cpu_tasks to memory_tasks/document_tasks for better semantic clarity - Update task routing configuration to reflect new queue names for memory agent tasks and document processing tasks - Remove redundant exception handling comments and simplify error handling logic - Update README with improved community support section including GitHub Issues, Pull Requests, Discussions, and WeChat community links - Simplifies event loop management by leveraging asyncio.run() which handles loop creation and cleanup automatically, reducing code complexity and potential race conditions
This commit is contained in:
13
README.md
13
README.md
@@ -334,7 +334,12 @@ step6: Log In to the Frontend Interface.
|
|||||||
## License
|
## License
|
||||||
This project is licensed under the Apache License 2.0. For details, see the LICENSE file.
|
This project is licensed under the Apache License 2.0. For details, see the LICENSE file.
|
||||||
|
|
||||||
## Acknowledgements & Community
|
## Community & Support
|
||||||
- Feedback & Issues: Please submit an Issue in the repository for bug reports or discussions.
|
|
||||||
- Contributions Welcome: When submitting a Pull Request, please create a feature branch and follow conventional commit message guidelines.
|
Join our community to ask questions, share your work, and connect with fellow developers.
|
||||||
- Contact: If you are interested in contributing or collaborating, feel free to reach out at tianyou_hubm@redbearai.com
|
|
||||||
|
- **GitHub Issues**: Report bugs, request features, or track known issues via [GitHub Issues](https://github.com/redbear-ai/memorybear/issues).
|
||||||
|
- **GitHub Pull Requests**: Contribute code improvements or fixes through [Pull Requests](https://github.com/redbear-ai/memorybear/pulls).
|
||||||
|
- **GitHub Discussions**: Ask questions, share ideas, and engage with the community in [GitHub Discussions](https://github.com/redbear-ai/memorybear/discussions).
|
||||||
|
- **WeChat**: Scan the QR code below to join our WeChat community group.
|
||||||
|
- **Contact**: If you are interested in contributing or collaborating, feel free to reach out at tianyou_hubm@redbearai.com
|
||||||
@@ -1,4 +1,5 @@
|
|||||||
import os
|
import os
|
||||||
|
import platform
|
||||||
from datetime import timedelta
|
from datetime import timedelta
|
||||||
from urllib.parse import quote
|
from urllib.parse import quote
|
||||||
|
|
||||||
@@ -14,27 +15,12 @@ celery_app = Celery(
|
|||||||
backend=f"redis://:{quote(settings.REDIS_PASSWORD)}@{settings.REDIS_HOST}:{settings.REDIS_PORT}/{settings.CELERY_BACKEND}",
|
backend=f"redis://:{quote(settings.REDIS_PASSWORD)}@{settings.REDIS_HOST}:{settings.REDIS_PORT}/{settings.CELERY_BACKEND}",
|
||||||
)
|
)
|
||||||
|
|
||||||
# 配置使用本地队列,避免与远程 worker 冲突
|
# Default queue for unrouted tasks
|
||||||
celery_app.conf.task_default_queue = 'localhost_test_wyl'
|
celery_app.conf.task_default_queue = 'memory_tasks'
|
||||||
celery_app.conf.task_default_exchange = 'localhost_test_wyl'
|
|
||||||
celery_app.conf.task_default_routing_key = 'localhost_test_wyl'
|
|
||||||
|
|
||||||
# macOS 兼容性配置
|
# macOS 兼容性配置
|
||||||
import platform
|
if platform.system() == 'Darwin':
|
||||||
|
|
||||||
if platform.system() == 'Darwin': # macOS
|
|
||||||
# 设置环境变量解决 fork 问题
|
|
||||||
os.environ.setdefault('OBJC_DISABLE_INITIALIZE_FORK_SAFETY', 'YES')
|
os.environ.setdefault('OBJC_DISABLE_INITIALIZE_FORK_SAFETY', 'YES')
|
||||||
|
|
||||||
# 使用 solo 池避免多进程问题
|
|
||||||
celery_app.conf.worker_pool = 'solo'
|
|
||||||
|
|
||||||
# 设置唯一的节点名称
|
|
||||||
import socket
|
|
||||||
import time
|
|
||||||
hostname = socket.gethostname()
|
|
||||||
timestamp = int(time.time())
|
|
||||||
celery_app.conf.worker_name = f"celery@{hostname}-{timestamp}"
|
|
||||||
|
|
||||||
# Celery 配置
|
# Celery 配置
|
||||||
celery_app.conf.update(
|
celery_app.conf.update(
|
||||||
@@ -52,36 +38,47 @@ celery_app.conf.update(
|
|||||||
task_ignore_result=False,
|
task_ignore_result=False,
|
||||||
|
|
||||||
# 超时设置
|
# 超时设置
|
||||||
task_time_limit=30 * 60, # 30 分钟硬超时
|
task_time_limit=1800, # 30分钟硬超时
|
||||||
task_soft_time_limit=25 * 60, # 25 分钟软超时
|
task_soft_time_limit=1500, # 25分钟软超时
|
||||||
|
|
||||||
# Worker 设置 - 针对 macOS 优化
|
# Worker 设置 (per-worker settings are in docker-compose command line)
|
||||||
worker_prefetch_multiplier=1, # 减少预取任务数,避免内存堆积
|
worker_prefetch_multiplier=1, # Don't hoard tasks, fairer distribution
|
||||||
worker_max_tasks_per_child=10, # 大幅减少每个 worker 执行的任务数,频繁重启防止内存泄漏
|
|
||||||
worker_max_memory_per_child=200000, # 200MB 内存限制,超过后重启 worker
|
|
||||||
|
|
||||||
# 结果过期时间
|
# 结果过期时间
|
||||||
result_expires=3600, # 结果保存 1 小时
|
result_expires=3600, # 结果保存1小时
|
||||||
|
|
||||||
# 任务确认设置
|
# 任务确认设置
|
||||||
task_acks_late=True, # 任务完成后才确认,避免任务丢失
|
task_acks_late=True,
|
||||||
worker_disable_rate_limits=True, # 禁用速率限制
|
task_reject_on_worker_lost=True,
|
||||||
|
worker_disable_rate_limits=True,
|
||||||
|
|
||||||
# 任务路由(可选,用于不同队列)
|
# FLower setting
|
||||||
# task_routes={
|
worker_send_task_events=True,
|
||||||
# 'app.core.rag.tasks.parse_document': {'queue': 'document_processing'},
|
task_send_sent_event=True,
|
||||||
# 'app.core.memory.agent.read_message': {'queue': 'memory_processing'},
|
|
||||||
# 'app.core.memory.agent.write_message': {'queue': 'memory_processing'},
|
# task routing
|
||||||
# 'tasks.process_item': {'queue': 'default'},
|
task_routes={
|
||||||
# },
|
# Memory tasks → memory_tasks queue (threads worker)
|
||||||
|
'app.core.memory.agent.read_message_priority': {'queue': 'memory_tasks'},
|
||||||
|
'app.core.memory.agent.read_message': {'queue': 'memory_tasks'},
|
||||||
|
'app.core.memory.agent.write_message': {'queue': 'memory_tasks'},
|
||||||
|
|
||||||
|
# Document tasks → document_tasks queue (prefork worker)
|
||||||
|
'app.core.rag.tasks.parse_document': {'queue': 'document_tasks'},
|
||||||
|
'app.core.rag.tasks.build_graphrag_for_kb': {'queue': 'document_tasks'},
|
||||||
|
|
||||||
|
# Beat/periodic tasks → document_tasks queue (prefork worker)
|
||||||
|
'app.tasks.workspace_reflection_task': {'queue': 'document_tasks'},
|
||||||
|
'app.tasks.regenerate_memory_cache': {'queue': 'document_tasks'},
|
||||||
|
'app.tasks.run_forgetting_cycle_task': {'queue': 'document_tasks'},
|
||||||
|
'app.controllers.memory_storage_controller.search_all': {'queue': 'document_tasks'},
|
||||||
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
# 自动发现任务模块
|
# 自动发现任务模块
|
||||||
celery_app.autodiscover_tasks(['app'])
|
celery_app.autodiscover_tasks(['app'])
|
||||||
|
|
||||||
# Celery Beat schedule for periodic tasks
|
# Celery Beat schedule for periodic tasks
|
||||||
reflection_schedule = timedelta(seconds=settings.REFLECTION_INTERVAL_SECONDS)
|
|
||||||
health_schedule = timedelta(seconds=settings.HEALTH_CHECK_SECONDS)
|
|
||||||
memory_increment_schedule = timedelta(hours=settings.MEMORY_INCREMENT_INTERVAL_HOURS)
|
memory_increment_schedule = timedelta(hours=settings.MEMORY_INCREMENT_INTERVAL_HOURS)
|
||||||
memory_cache_regeneration_schedule = timedelta(hours=settings.MEMORY_CACHE_REGENERATION_HOURS)
|
memory_cache_regeneration_schedule = timedelta(hours=settings.MEMORY_CACHE_REGENERATION_HOURS)
|
||||||
workspace_reflection_schedule = timedelta(seconds=30) # 每30秒运行一次settings.REFLECTION_INTERVAL_TIME
|
workspace_reflection_schedule = timedelta(seconds=30) # 每30秒运行一次settings.REFLECTION_INTERVAL_TIME
|
||||||
@@ -89,12 +86,6 @@ forgetting_cycle_schedule = timedelta(hours=24) # 每24小时运行一次遗忘
|
|||||||
|
|
||||||
# 构建定时任务配置
|
# 构建定时任务配置
|
||||||
beat_schedule_config = {
|
beat_schedule_config = {
|
||||||
|
|
||||||
# "check-read-service": {
|
|
||||||
# "task": "app.core.memory.agent.health.check_read_service",
|
|
||||||
# "schedule": health_schedule,
|
|
||||||
# "args": (),
|
|
||||||
# },
|
|
||||||
"run-workspace-reflection": {
|
"run-workspace-reflection": {
|
||||||
"task": "app.tasks.workspace_reflection_task",
|
"task": "app.tasks.workspace_reflection_task",
|
||||||
"schedule": workspace_reflection_schedule,
|
"schedule": workspace_reflection_schedule,
|
||||||
|
|||||||
@@ -4,12 +4,11 @@ import os
|
|||||||
import time
|
import time
|
||||||
|
|
||||||
from app.core.logging_config import get_agent_logger, log_time
|
from app.core.logging_config import get_agent_logger, log_time
|
||||||
from app.db import get_db
|
|
||||||
|
|
||||||
from app.core.memory.agent.models.summary_models import (
|
from app.core.memory.agent.models.summary_models import (
|
||||||
RetrieveSummaryResponse,
|
RetrieveSummaryResponse,
|
||||||
SummaryResponse,
|
SummaryResponse,
|
||||||
)
|
)
|
||||||
|
from app.core.memory.agent.services.optimized_llm_service import LLMServiceMixin
|
||||||
from app.core.memory.agent.services.search_service import SearchService
|
from app.core.memory.agent.services.search_service import SearchService
|
||||||
from app.core.memory.agent.utils.llm_tools import (
|
from app.core.memory.agent.utils.llm_tools import (
|
||||||
PROJECT_ROOT_,
|
PROJECT_ROOT_,
|
||||||
@@ -18,7 +17,7 @@ from app.core.memory.agent.utils.llm_tools import (
|
|||||||
from app.core.memory.agent.utils.redis_tool import store
|
from app.core.memory.agent.utils.redis_tool import store
|
||||||
from app.core.memory.agent.utils.session_tools import SessionService
|
from app.core.memory.agent.utils.session_tools import SessionService
|
||||||
from app.core.memory.agent.utils.template_tools import TemplateService
|
from app.core.memory.agent.utils.template_tools import TemplateService
|
||||||
from app.core.memory.agent.services.optimized_llm_service import LLMServiceMixin
|
from app.db import get_db
|
||||||
|
|
||||||
template_root = os.path.join(PROJECT_ROOT_, 'agent', 'utils', 'prompt')
|
template_root = os.path.join(PROJECT_ROOT_, 'agent', 'utils', 'prompt')
|
||||||
logger = get_agent_logger(__name__)
|
logger = get_agent_logger(__name__)
|
||||||
@@ -182,7 +181,8 @@ async def Input_Summary(state: ReadState) -> ReadState:
|
|||||||
search_params = {
|
search_params = {
|
||||||
"group_id": group_id,
|
"group_id": group_id,
|
||||||
"question": data,
|
"question": data,
|
||||||
"return_raw_results": True
|
"return_raw_results": True,
|
||||||
|
"include": ["summaries"] # Only search summary nodes for faster performance
|
||||||
}
|
}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|||||||
@@ -89,14 +89,15 @@ def validate_model_exists_and_active(
|
|||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# First check if model exists at all (without tenant filtering)
|
# OPTIMIZED: Single query with tenant filter
|
||||||
model_without_tenant = ModelConfigRepository.get_by_id(db, model_id, tenant_id=None)
|
# We'll check tenant mismatch in the error handling
|
||||||
|
|
||||||
# Then check with tenant filtering
|
|
||||||
model = ModelConfigRepository.get_by_id(db, model_id, tenant_id)
|
model = ModelConfigRepository.get_by_id(db, model_id, tenant_id)
|
||||||
elapsed_ms = (time.time() - start_time) * 1000
|
elapsed_ms = (time.time() - start_time) * 1000
|
||||||
|
|
||||||
if not model:
|
if not model:
|
||||||
|
# Model not found with tenant filter - check if it exists without filter
|
||||||
|
model_without_tenant = ModelConfigRepository.get_by_id(db, model_id, tenant_id=None)
|
||||||
|
|
||||||
if model_without_tenant:
|
if model_without_tenant:
|
||||||
# Model exists but belongs to different tenant
|
# Model exists but belongs to different tenant
|
||||||
logger.warning(
|
logger.warning(
|
||||||
@@ -208,8 +209,11 @@ def validate_embedding_model(
|
|||||||
db: Session,
|
db: Session,
|
||||||
tenant_id: Optional[UUID] = None,
|
tenant_id: Optional[UUID] = None,
|
||||||
workspace_id: Optional[UUID] = None
|
workspace_id: Optional[UUID] = None
|
||||||
) -> UUID:
|
) -> tuple[UUID, str]:
|
||||||
"""Validate that embedding model is available and return its UUID.
|
"""Validate that embedding model is available and return its UUID and name.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (embedding_uuid, embedding_name)
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
InvalidConfigError: If embedding_id is not provided or invalid
|
InvalidConfigError: If embedding_id is not provided or invalid
|
||||||
@@ -225,14 +229,19 @@ def validate_embedding_model(
|
|||||||
workspace_id=workspace_id
|
workspace_id=workspace_id
|
||||||
)
|
)
|
||||||
|
|
||||||
embedding_uuid, _ = validate_and_resolve_model_id(
|
embedding_uuid, embedding_name = validate_and_resolve_model_id(
|
||||||
embedding_id, "embedding", db, tenant_id, required=True,
|
embedding_id, "embedding", db, tenant_id, required=True,
|
||||||
config_id=config_id, workspace_id=workspace_id
|
config_id=config_id, workspace_id=workspace_id
|
||||||
)
|
)
|
||||||
print(100*'-')
|
|
||||||
print(embedding_uuid)
|
logger.debug(
|
||||||
print(_)
|
"Embedding model validated",
|
||||||
print(100*'-')
|
extra={
|
||||||
|
"embedding_uuid": str(embedding_uuid),
|
||||||
|
"embedding_name": embedding_name,
|
||||||
|
"config_id": config_id
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
if embedding_uuid is None:
|
if embedding_uuid is None:
|
||||||
raise InvalidConfigError(
|
raise InvalidConfigError(
|
||||||
@@ -243,7 +252,7 @@ def validate_embedding_model(
|
|||||||
workspace_id=workspace_id
|
workspace_id=workspace_id
|
||||||
)
|
)
|
||||||
|
|
||||||
return embedding_uuid
|
return embedding_uuid, embedding_name
|
||||||
|
|
||||||
|
|
||||||
def validate_llm_model(
|
def validate_llm_model(
|
||||||
|
|||||||
@@ -305,12 +305,19 @@ async def search_graph(
|
|||||||
results[key] = _deduplicate_results(results[key])
|
results[key] = _deduplicate_results(results[key])
|
||||||
|
|
||||||
# 更新知识节点的激活值(Statement, ExtractedEntity, MemorySummary)
|
# 更新知识节点的激活值(Statement, ExtractedEntity, MemorySummary)
|
||||||
results = await _update_search_results_activation(
|
# Skip activation updates if only searching summaries (optimization)
|
||||||
connector=connector,
|
needs_activation_update = any(
|
||||||
results=results,
|
key in include and key in results and results[key]
|
||||||
group_id=group_id
|
for key in ['statements', 'entities', 'chunks']
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if needs_activation_update:
|
||||||
|
results = await _update_search_results_activation(
|
||||||
|
connector=connector,
|
||||||
|
results=results,
|
||||||
|
group_id=group_id
|
||||||
|
)
|
||||||
|
|
||||||
return results
|
return results
|
||||||
|
|
||||||
|
|
||||||
@@ -339,7 +346,7 @@ async def search_graph_by_embedding(
|
|||||||
embed_start = time.time()
|
embed_start = time.time()
|
||||||
embeddings = await embedder_client.response([query_text])
|
embeddings = await embedder_client.response([query_text])
|
||||||
embed_time = time.time() - embed_start
|
embed_time = time.time() - embed_start
|
||||||
print(f"[PERF] Embedding generation took: {embed_time:.4f}s")
|
logger.info(f"[PERF] Embedding generation took: {embed_time:.4f}s")
|
||||||
|
|
||||||
if not embeddings or not embeddings[0]:
|
if not embeddings or not embeddings[0]:
|
||||||
return {"statements": [], "chunks": [], "entities": [], "summaries": []}
|
return {"statements": [], "chunks": [], "entities": [], "summaries": []}
|
||||||
@@ -393,7 +400,7 @@ async def search_graph_by_embedding(
|
|||||||
query_start = time.time()
|
query_start = time.time()
|
||||||
task_results = await asyncio.gather(*tasks, return_exceptions=True)
|
task_results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||||
query_time = time.time() - query_start
|
query_time = time.time() - query_start
|
||||||
print(f"[PERF] Neo4j queries (parallel) took: {query_time:.4f}s")
|
logger.info(f"[PERF] Neo4j queries (parallel) took: {query_time:.4f}s")
|
||||||
|
|
||||||
# Build results dictionary
|
# Build results dictionary
|
||||||
results: Dict[str, List[Dict[str, Any]]] = {
|
results: Dict[str, List[Dict[str, Any]]] = {
|
||||||
@@ -417,14 +424,23 @@ async def search_graph_by_embedding(
|
|||||||
results[key] = _deduplicate_results(results[key])
|
results[key] = _deduplicate_results(results[key])
|
||||||
|
|
||||||
# 更新知识节点的激活值(Statement, ExtractedEntity, MemorySummary)
|
# 更新知识节点的激活值(Statement, ExtractedEntity, MemorySummary)
|
||||||
update_start = time.time()
|
# Skip activation updates if only searching summaries (optimization)
|
||||||
results = await _update_search_results_activation(
|
needs_activation_update = any(
|
||||||
connector=connector,
|
key in include and key in results and results[key]
|
||||||
results=results,
|
for key in ['statements', 'entities', 'chunks']
|
||||||
group_id=group_id
|
|
||||||
)
|
)
|
||||||
update_time = time.time() - update_start
|
|
||||||
print(f"[PERF] Activation value updates took: {update_time:.4f}s")
|
if needs_activation_update:
|
||||||
|
update_start = time.time()
|
||||||
|
results = await _update_search_results_activation(
|
||||||
|
connector=connector,
|
||||||
|
results=results,
|
||||||
|
group_id=group_id
|
||||||
|
)
|
||||||
|
update_time = time.time() - update_start
|
||||||
|
logger.info(f"[PERF] Activation value updates took: {update_time:.4f}s")
|
||||||
|
else:
|
||||||
|
logger.info(f"[PERF] Skipping activation updates (only summaries)")
|
||||||
|
|
||||||
return results
|
return results
|
||||||
async def get_dedup_candidates_for_entities( # 适配新版查询:使用全文索引按名称检索候选实体
|
async def get_dedup_candidates_for_entities( # 适配新版查询:使用全文索引按名称检索候选实体
|
||||||
@@ -535,7 +551,7 @@ async def search_graph_by_keyword_temporal(
|
|||||||
- Returns up to 'limit' statements
|
- Returns up to 'limit' statements
|
||||||
"""
|
"""
|
||||||
if not query_text:
|
if not query_text:
|
||||||
print(f"query_text不能为空")
|
logger.warning(f"query_text cannot be empty")
|
||||||
return {"statements": []}
|
return {"statements": []}
|
||||||
statements = await connector.execute_query(
|
statements = await connector.execute_query(
|
||||||
SEARCH_STATEMENTS_BY_KEYWORD_TEMPORAL,
|
SEARCH_STATEMENTS_BY_KEYWORD_TEMPORAL,
|
||||||
@@ -549,7 +565,7 @@ async def search_graph_by_keyword_temporal(
|
|||||||
invalid_date=invalid_date,
|
invalid_date=invalid_date,
|
||||||
limit=limit,
|
limit=limit,
|
||||||
)
|
)
|
||||||
print(f"查询结果为:\n{statements}")
|
logger.debug(f"Temporal keyword search results: {len(statements)} statements found")
|
||||||
|
|
||||||
# 更新 Statement 节点的激活值
|
# 更新 Statement 节点的激活值
|
||||||
results = {"statements": statements}
|
results = {"statements": statements}
|
||||||
@@ -594,9 +610,9 @@ async def search_graph_by_temporal(
|
|||||||
limit=limit,
|
limit=limit,
|
||||||
)
|
)
|
||||||
|
|
||||||
print(f"查询语句为:\n{SEARCH_STATEMENTS_BY_TEMPORAL}")
|
logger.debug(f"Temporal search query: {SEARCH_STATEMENTS_BY_TEMPORAL}")
|
||||||
print(f"查询参数为:\n{{group_id: {group_id}, apply_id: {apply_id}, user_id: {user_id}, start_date: {start_date}, end_date: {end_date}, valid_date: {valid_date}, invalid_date: {invalid_date}, limit: {limit}}}")
|
logger.debug(f"Query params: group_id={group_id}, apply_id={apply_id}, user_id={user_id}, start_date={start_date}, end_date={end_date}, valid_date={valid_date}, invalid_date={invalid_date}, limit={limit}")
|
||||||
print(f"查询结果为:\n{statements}")
|
logger.debug(f"Temporal search results: {len(statements)} statements found")
|
||||||
|
|
||||||
# 更新 Statement 节点的激活值
|
# 更新 Statement 节点的激活值
|
||||||
results = {"statements": statements}
|
results = {"statements": statements}
|
||||||
@@ -623,7 +639,7 @@ async def search_graph_by_dialog_id(
|
|||||||
- Returns up to 'limit' dialogues
|
- Returns up to 'limit' dialogues
|
||||||
"""
|
"""
|
||||||
if not dialog_id:
|
if not dialog_id:
|
||||||
print(f"dialog_id不能为空")
|
logger.warning(f"dialog_id cannot be empty")
|
||||||
return {"dialogues": []}
|
return {"dialogues": []}
|
||||||
|
|
||||||
dialogues = await connector.execute_query(
|
dialogues = await connector.execute_query(
|
||||||
@@ -642,7 +658,7 @@ async def search_graph_by_chunk_id(
|
|||||||
limit: int = 1,
|
limit: int = 1,
|
||||||
) -> Dict[str, List[Dict[str, Any]]]:
|
) -> Dict[str, List[Dict[str, Any]]]:
|
||||||
if not chunk_id:
|
if not chunk_id:
|
||||||
print(f"chunk_id不能为空")
|
logger.warning(f"chunk_id cannot be empty")
|
||||||
return {"chunks": []}
|
return {"chunks": []}
|
||||||
chunks = await connector.execute_query(
|
chunks = await connector.execute_query(
|
||||||
SEARCH_CHUNK_BY_CHUNK_ID,
|
SEARCH_CHUNK_BY_CHUNK_ID,
|
||||||
@@ -679,9 +695,9 @@ async def search_graph_by_created_at(
|
|||||||
limit=limit,
|
limit=limit,
|
||||||
)
|
)
|
||||||
|
|
||||||
print(f"查询语句为:\n{SEARCH_STATEMENTS_BY_CREATED_AT}")
|
logger.debug(f"Search by created_at query: {SEARCH_STATEMENTS_BY_CREATED_AT}")
|
||||||
print(f"查询参数为:\n{{group_id: {group_id}, apply_id: {apply_id}, user_id: {user_id}, created_at: {created_at}, limit: {limit}}}")
|
logger.debug(f"Query params: group_id={group_id}, apply_id={apply_id}, user_id={user_id}, created_at={created_at}, limit={limit}")
|
||||||
print(f"查询结果为:\n{statements}")
|
logger.debug(f"Search results: {len(statements)} statements found")
|
||||||
|
|
||||||
# 更新 Statement 节点的激活值
|
# 更新 Statement 节点的激活值
|
||||||
results = {"statements": statements}
|
results = {"statements": statements}
|
||||||
@@ -719,9 +735,9 @@ async def search_graph_by_valid_at(
|
|||||||
limit=limit,
|
limit=limit,
|
||||||
)
|
)
|
||||||
|
|
||||||
print(f"查询语句为:\n{SEARCH_STATEMENTS_BY_VALID_AT}")
|
logger.debug(f"Search by valid_at query: {SEARCH_STATEMENTS_BY_VALID_AT}")
|
||||||
print(f"查询参数为:\n{{group_id: {group_id}, apply_id: {apply_id}, user_id: {user_id}, valid_at: {valid_at}, limit: {limit}}}")
|
logger.debug(f"Query params: group_id={group_id}, apply_id={apply_id}, user_id={user_id}, valid_at={valid_at}, limit={limit}")
|
||||||
print(f"查询结果为:\n{statements}")
|
logger.debug(f"Search results: {len(statements)} statements found")
|
||||||
|
|
||||||
# 更新 Statement 节点的激活值
|
# 更新 Statement 节点的激活值
|
||||||
results = {"statements": statements}
|
results = {"statements": statements}
|
||||||
@@ -759,9 +775,9 @@ async def search_graph_g_created_at(
|
|||||||
limit=limit,
|
limit=limit,
|
||||||
)
|
)
|
||||||
|
|
||||||
print(f"查询语句为:\n{SEARCH_STATEMENTS_G_CREATED_AT}")
|
logger.debug(f"Search greater than created_at query: {SEARCH_STATEMENTS_G_CREATED_AT}")
|
||||||
print(f"查询参数为:\n{{group_id: {group_id}, apply_id: {apply_id}, user_id: {user_id}, created_at: {created_at}, limit: {limit}}}")
|
logger.debug(f"Query params: group_id={group_id}, apply_id={apply_id}, user_id={user_id}, created_at={created_at}, limit={limit}")
|
||||||
print(f"查询结果为:\n{statements}")
|
logger.debug(f"Search results: {len(statements)} statements found")
|
||||||
|
|
||||||
# 更新 Statement 节点的激活值
|
# 更新 Statement 节点的激活值
|
||||||
results = {"statements": statements}
|
results = {"statements": statements}
|
||||||
@@ -799,9 +815,9 @@ async def search_graph_g_valid_at(
|
|||||||
limit=limit,
|
limit=limit,
|
||||||
)
|
)
|
||||||
|
|
||||||
print(f"查询语句为:\n{SEARCH_STATEMENTS_G_VALID_AT}")
|
logger.debug(f"Search greater than valid_at query: {SEARCH_STATEMENTS_G_VALID_AT}")
|
||||||
print(f"查询参数为:\n{{group_id: {group_id}, apply_id: {apply_id}, user_id: {user_id}, valid_at: {valid_at}, limit: {limit}}}")
|
logger.debug(f"Query params: group_id={group_id}, apply_id={apply_id}, user_id={user_id}, valid_at={valid_at}, limit={limit}")
|
||||||
print(f"查询结果为:\n{statements}")
|
logger.debug(f"Search results: {len(statements)} statements found")
|
||||||
|
|
||||||
# 更新 Statement 节点的激活值
|
# 更新 Statement 节点的激活值
|
||||||
results = {"statements": statements}
|
results = {"statements": statements}
|
||||||
@@ -839,9 +855,9 @@ async def search_graph_l_created_at(
|
|||||||
limit=limit,
|
limit=limit,
|
||||||
)
|
)
|
||||||
|
|
||||||
print(f"查询语句为:\n{SEARCH_STATEMENTS_L_CREATED_AT}")
|
logger.debug(f"Search less than created_at query: {SEARCH_STATEMENTS_L_CREATED_AT}")
|
||||||
print(f"查询参数为:\n{{group_id: {group_id}, apply_id: {apply_id}, user_id: {user_id}, created_at: {created_at}, limit: {limit}}}")
|
logger.debug(f"Query params: group_id={group_id}, apply_id={apply_id}, user_id={user_id}, created_at={created_at}, limit={limit}")
|
||||||
print(f"查询结果为:\n{statements}")
|
logger.debug(f"Search results: {len(statements)} statements found")
|
||||||
|
|
||||||
# 更新 Statement 节点的激活值
|
# 更新 Statement 节点的激活值
|
||||||
results = {"statements": statements}
|
results = {"statements": statements}
|
||||||
@@ -879,9 +895,9 @@ async def search_graph_l_valid_at(
|
|||||||
limit=limit,
|
limit=limit,
|
||||||
)
|
)
|
||||||
|
|
||||||
print(f"查询语句为:\n{SEARCH_STATEMENTS_L_VALID_AT}")
|
logger.debug(f"Search less than valid_at query: {SEARCH_STATEMENTS_L_VALID_AT}")
|
||||||
print(f"查询参数为:\n{{group_id: {group_id}, apply_id: {apply_id}, user_id: {user_id}, valid_at: {valid_at}, limit: {limit}}}")
|
logger.debug(f"Query params: group_id={group_id}, apply_id={apply_id}, user_id={user_id}, valid_at={valid_at}, limit={limit}")
|
||||||
print(f"查询结果为:\n{statements}")
|
logger.debug(f"Search results: {len(statements)} statements found")
|
||||||
|
|
||||||
# 更新 Statement 节点的激活值
|
# 更新 Statement 节点的激活值
|
||||||
results = {"statements": statements}
|
results = {"statements": statements}
|
||||||
|
|||||||
@@ -10,11 +10,6 @@ import time
|
|||||||
import uuid
|
import uuid
|
||||||
from typing import Any, AsyncGenerator, Dict, List, Optional
|
from typing import Any, AsyncGenerator, Dict, List, Optional
|
||||||
|
|
||||||
from langchain.tools import tool
|
|
||||||
from pydantic import BaseModel, Field
|
|
||||||
from sqlalchemy import select
|
|
||||||
from sqlalchemy.orm import Session
|
|
||||||
|
|
||||||
from app.celery_app import celery_app
|
from app.celery_app import celery_app
|
||||||
from app.core.error_codes import BizCode
|
from app.core.error_codes import BizCode
|
||||||
from app.core.exceptions import BusinessException
|
from app.core.exceptions import BusinessException
|
||||||
@@ -28,6 +23,10 @@ from app.services.langchain_tool_server import Search
|
|||||||
from app.services.memory_agent_service import MemoryAgentService
|
from app.services.memory_agent_service import MemoryAgentService
|
||||||
from app.services.model_parameter_merger import ModelParameterMerger
|
from app.services.model_parameter_merger import ModelParameterMerger
|
||||||
from app.services.tool_service import ToolService
|
from app.services.tool_service import ToolService
|
||||||
|
from langchain.tools import tool
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
logger = get_business_logger()
|
logger = get_business_logger()
|
||||||
class KnowledgeRetrievalInput(BaseModel):
|
class KnowledgeRetrievalInput(BaseModel):
|
||||||
@@ -107,9 +106,9 @@ def create_long_term_memory_tool(memory_config: Dict[str, Any], end_user_id: str
|
|||||||
"app.core.memory.agent.read_message",
|
"app.core.memory.agent.read_message",
|
||||||
args=[end_user_id, question, [], "1", config_id, storage_type, user_rag_memory_id]
|
args=[end_user_id, question, [], "1", config_id, storage_type, user_rag_memory_id]
|
||||||
)
|
)
|
||||||
result = task_service.get_task_memory_read_result(task.id)
|
# result = task_service.get_task_memory_read_result(task.id)
|
||||||
status = result.get("status")
|
# status = result.get("status")
|
||||||
logger.info(f"读取任务状态:{status}")
|
# logger.info(f"读取任务状态:{status}")
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
db.close()
|
db.close()
|
||||||
|
|||||||
@@ -10,15 +10,17 @@ import re
|
|||||||
import time
|
import time
|
||||||
import uuid
|
import uuid
|
||||||
from typing import Any, AsyncGenerator, Dict, List, Optional
|
from typing import Any, AsyncGenerator, Dict, List, Optional
|
||||||
import redis
|
|
||||||
from langchain_core.messages import HumanMessage
|
|
||||||
|
|
||||||
|
import redis
|
||||||
from app.core.config import settings
|
from app.core.config import settings
|
||||||
from app.core.logging_config import get_config_logger, get_logger
|
from app.core.logging_config import get_config_logger, get_logger
|
||||||
from app.core.memory.agent.langgraph_graph.read_graph import make_read_graph
|
from app.core.memory.agent.langgraph_graph.read_graph import make_read_graph
|
||||||
from app.core.memory.agent.langgraph_graph.write_graph import make_write_graph
|
from app.core.memory.agent.langgraph_graph.write_graph import make_write_graph
|
||||||
from app.core.memory.agent.logger_file.log_streamer import LogStreamer
|
from app.core.memory.agent.logger_file.log_streamer import LogStreamer
|
||||||
from app.core.memory.agent.utils.messages_tools import merge_multiple_search_results, reorder_output_results
|
from app.core.memory.agent.utils.messages_tools import (
|
||||||
|
merge_multiple_search_results,
|
||||||
|
reorder_output_results,
|
||||||
|
)
|
||||||
from app.core.memory.agent.utils.type_classifier import status_typle
|
from app.core.memory.agent.utils.type_classifier import status_typle
|
||||||
from app.core.memory.agent.utils.write_tools import write # 新增:直接导入 write 函数
|
from app.core.memory.agent.utils.write_tools import write # 新增:直接导入 write 函数
|
||||||
from app.core.memory.analytics.hot_memory_tags import get_hot_memory_tags
|
from app.core.memory.analytics.hot_memory_tags import get_hot_memory_tags
|
||||||
@@ -33,6 +35,7 @@ from app.services.memory_config_service import MemoryConfigService
|
|||||||
from app.services.memory_konwledges_server import (
|
from app.services.memory_konwledges_server import (
|
||||||
write_rag,
|
write_rag,
|
||||||
)
|
)
|
||||||
|
from langchain_core.messages import HumanMessage
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
from sqlalchemy import func
|
from sqlalchemy import func
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
@@ -404,6 +407,7 @@ class MemoryAgentService:
|
|||||||
|
|
||||||
import time
|
import time
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
logger.info(f"[PERF] read_memory started for group_id={group_id}, search_switch={search_switch}")
|
||||||
|
|
||||||
# Resolve config_id if None using end_user's connected config
|
# Resolve config_id if None using end_user's connected config
|
||||||
if config_id is None:
|
if config_id is None:
|
||||||
@@ -427,13 +431,15 @@ class MemoryAgentService:
|
|||||||
audit_logger = None
|
audit_logger = None
|
||||||
|
|
||||||
|
|
||||||
|
config_load_start = time.time()
|
||||||
try:
|
try:
|
||||||
config_service = MemoryConfigService(db)
|
config_service = MemoryConfigService(db)
|
||||||
memory_config = config_service.load_memory_config(
|
memory_config = config_service.load_memory_config(
|
||||||
config_id=config_id,
|
config_id=config_id,
|
||||||
service_name="MemoryAgentService"
|
service_name="MemoryAgentService"
|
||||||
)
|
)
|
||||||
logger.info(f"Configuration loaded successfully: {memory_config.config_name}")
|
config_load_time = time.time() - config_load_start
|
||||||
|
logger.info(f"[PERF] Configuration loaded in {config_load_time:.4f}s: {memory_config.config_name}")
|
||||||
except ConfigurationError as e:
|
except ConfigurationError as e:
|
||||||
error_msg = f"Failed to load configuration for config_id: {config_id}: {e}"
|
error_msg = f"Failed to load configuration for config_id: {config_id}: {e}"
|
||||||
logger.error(error_msg)
|
logger.error(error_msg)
|
||||||
@@ -457,6 +463,7 @@ class MemoryAgentService:
|
|||||||
logger.debug(f"Group ID:{group_id}, Message:{message}, History:{history}, Config ID:{config_id}")
|
logger.debug(f"Group ID:{group_id}, Message:{message}, History:{history}, Config ID:{config_id}")
|
||||||
|
|
||||||
# Step 3: Initialize MCP client and execute read workflow
|
# Step 3: Initialize MCP client and execute read workflow
|
||||||
|
graph_exec_start = time.time()
|
||||||
try:
|
try:
|
||||||
async with make_read_graph() as graph:
|
async with make_read_graph() as graph:
|
||||||
config = {"configurable": {"thread_id": group_id}}
|
config = {"configurable": {"thread_id": group_id}}
|
||||||
@@ -513,6 +520,9 @@ class MemoryAgentService:
|
|||||||
if summary_n and summary_n != [] and summary_n != {}:
|
if summary_n and summary_n != [] and summary_n != {}:
|
||||||
_intermediate_outputs.append(summary_n)
|
_intermediate_outputs.append(summary_n)
|
||||||
|
|
||||||
|
graph_exec_time = time.time() - graph_exec_start
|
||||||
|
logger.info(f"[PERF] Graph execution completed in {graph_exec_time:.4f}s")
|
||||||
|
|
||||||
_intermediate_outputs = [item for item in _intermediate_outputs if item and item != [] and item != {}]
|
_intermediate_outputs = [item for item in _intermediate_outputs if item and item != [] and item != {}]
|
||||||
|
|
||||||
optimized_outputs = merge_multiple_search_results(_intermediate_outputs)
|
optimized_outputs = merge_multiple_search_results(_intermediate_outputs)
|
||||||
@@ -570,6 +580,8 @@ class MemoryAgentService:
|
|||||||
logger.error(f"保存短期记忆失败: {str(save_error)}", exc_info=True)
|
logger.error(f"保存短期记忆失败: {str(save_error)}", exc_info=True)
|
||||||
|
|
||||||
# Log successful operation
|
# Log successful operation
|
||||||
|
total_time = time.time() - start_time
|
||||||
|
logger.info(f"[PERF] read_memory completed successfully in {total_time:.4f}s (config: {config_load_time:.4f}s, graph: {graph_exec_time:.4f}s)")
|
||||||
if audit_logger:
|
if audit_logger:
|
||||||
duration = time.time() - start_time
|
duration = time.time() - start_time
|
||||||
audit_logger.log_operation(
|
audit_logger.log_operation(
|
||||||
@@ -587,7 +599,8 @@ class MemoryAgentService:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
# Ensure proper error handling and logging
|
# Ensure proper error handling and logging
|
||||||
error_msg = f"Read operation failed: {str(e)}"
|
error_msg = f"Read operation failed: {str(e)}"
|
||||||
logger.error(error_msg)
|
total_time = time.time() - start_time
|
||||||
|
logger.error(f"[PERF] read_memory failed after {total_time:.4f}s: {error_msg}")
|
||||||
if audit_logger:
|
if audit_logger:
|
||||||
duration = time.time() - start_time
|
duration = time.time() - start_time
|
||||||
audit_logger.log_operation(
|
audit_logger.log_operation(
|
||||||
|
|||||||
@@ -125,7 +125,11 @@ class MemoryConfigService:
|
|||||||
try:
|
try:
|
||||||
validated_config_id = _validate_config_id(config_id)
|
validated_config_id = _validate_config_id(config_id)
|
||||||
|
|
||||||
|
# Step 1: Get config and workspace
|
||||||
|
db_query_start = time.time()
|
||||||
result = DataConfigRepository.get_config_with_workspace(self.db, validated_config_id)
|
result = DataConfigRepository.get_config_with_workspace(self.db, validated_config_id)
|
||||||
|
db_query_time = time.time() - db_query_start
|
||||||
|
logger.info(f"[PERF] Config+Workspace query: {db_query_time:.4f}s")
|
||||||
if not result:
|
if not result:
|
||||||
elapsed_ms = (time.time() - start_time) * 1000
|
elapsed_ms = (time.time() - start_time) * 1000
|
||||||
config_logger.error(
|
config_logger.error(
|
||||||
@@ -144,16 +148,20 @@ class MemoryConfigService:
|
|||||||
|
|
||||||
memory_config, workspace = result
|
memory_config, workspace = result
|
||||||
|
|
||||||
# Validate embedding model
|
# Step 2: Validate embedding model (returns both UUID and name)
|
||||||
embedding_uuid = validate_embedding_model(
|
embed_start = time.time()
|
||||||
|
embedding_uuid, embedding_name = validate_embedding_model(
|
||||||
validated_config_id,
|
validated_config_id,
|
||||||
memory_config.embedding_id,
|
memory_config.embedding_id,
|
||||||
self.db,
|
self.db,
|
||||||
workspace.tenant_id,
|
workspace.tenant_id,
|
||||||
workspace.id,
|
workspace.id,
|
||||||
)
|
)
|
||||||
|
embed_time = time.time() - embed_start
|
||||||
|
logger.info(f"[PERF] Embedding validation: {embed_time:.4f}s")
|
||||||
|
|
||||||
# Resolve LLM model
|
# Step 3: Resolve LLM model
|
||||||
|
llm_start = time.time()
|
||||||
llm_uuid, llm_name = validate_and_resolve_model_id(
|
llm_uuid, llm_name = validate_and_resolve_model_id(
|
||||||
memory_config.llm_id,
|
memory_config.llm_id,
|
||||||
"llm",
|
"llm",
|
||||||
@@ -163,8 +171,11 @@ class MemoryConfigService:
|
|||||||
config_id=validated_config_id,
|
config_id=validated_config_id,
|
||||||
workspace_id=workspace.id,
|
workspace_id=workspace.id,
|
||||||
)
|
)
|
||||||
|
llm_time = time.time() - llm_start
|
||||||
|
logger.info(f"[PERF] LLM validation: {llm_time:.4f}s")
|
||||||
|
|
||||||
# Resolve optional rerank model
|
# Step 4: Resolve optional rerank model
|
||||||
|
rerank_start = time.time()
|
||||||
rerank_uuid = None
|
rerank_uuid = None
|
||||||
rerank_name = None
|
rerank_name = None
|
||||||
if memory_config.rerank_id:
|
if memory_config.rerank_id:
|
||||||
@@ -177,16 +188,12 @@ class MemoryConfigService:
|
|||||||
config_id=validated_config_id,
|
config_id=validated_config_id,
|
||||||
workspace_id=workspace.id,
|
workspace_id=workspace.id,
|
||||||
)
|
)
|
||||||
|
rerank_time = time.time() - rerank_start
|
||||||
|
if memory_config.rerank_id:
|
||||||
|
logger.info(f"[PERF] Rerank validation: {rerank_time:.4f}s")
|
||||||
|
|
||||||
# Get embedding model name
|
# Note: embedding_name is now returned from validate_embedding_model above
|
||||||
embedding_name, _ = validate_model_exists_and_active(
|
# No need for redundant query!
|
||||||
embedding_uuid,
|
|
||||||
"embedding",
|
|
||||||
self.db,
|
|
||||||
workspace.tenant_id,
|
|
||||||
config_id=validated_config_id,
|
|
||||||
workspace_id=workspace.id,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Create immutable MemoryConfig object
|
# Create immutable MemoryConfig object
|
||||||
config = MemoryConfig(
|
config = MemoryConfig(
|
||||||
|
|||||||
173
api/app/tasks.py
173
api/app/tasks.py
@@ -425,24 +425,7 @@ def read_message_task(self, group_id: str, message: str, history: List[Dict[str,
|
|||||||
db.close()
|
db.close()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 使用 nest_asyncio 来避免事件循环冲突
|
result = asyncio.run(_run())
|
||||||
try:
|
|
||||||
import nest_asyncio
|
|
||||||
nest_asyncio.apply()
|
|
||||||
except ImportError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
# 尝试获取现有事件循环,如果不存在则创建新的
|
|
||||||
try:
|
|
||||||
loop = asyncio.get_event_loop()
|
|
||||||
if loop.is_closed():
|
|
||||||
loop = asyncio.new_event_loop()
|
|
||||||
asyncio.set_event_loop(loop)
|
|
||||||
except RuntimeError:
|
|
||||||
loop = asyncio.new_event_loop()
|
|
||||||
asyncio.set_event_loop(loop)
|
|
||||||
|
|
||||||
result = loop.run_until_complete(_run())
|
|
||||||
elapsed_time = time.time() - start_time
|
elapsed_time = time.time() - start_time
|
||||||
|
|
||||||
return {
|
return {
|
||||||
@@ -455,7 +438,6 @@ def read_message_task(self, group_id: str, message: str, history: List[Dict[str,
|
|||||||
}
|
}
|
||||||
except BaseException as e:
|
except BaseException as e:
|
||||||
elapsed_time = time.time() - start_time
|
elapsed_time = time.time() - start_time
|
||||||
# Handle ExceptionGroup from TaskGroup
|
|
||||||
if hasattr(e, 'exceptions'):
|
if hasattr(e, 'exceptions'):
|
||||||
error_messages = [f"{type(sub_e).__name__}: {str(sub_e)}" for sub_e in e.exceptions]
|
error_messages = [f"{type(sub_e).__name__}: {str(sub_e)}" for sub_e in e.exceptions]
|
||||||
detailed_error = "; ".join(error_messages)
|
detailed_error = "; ".join(error_messages)
|
||||||
@@ -528,24 +510,7 @@ def write_message_task(self, group_id: str, message, config_id: str, storage_typ
|
|||||||
db.close()
|
db.close()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 使用 nest_asyncio 来避免事件循环冲突
|
result = asyncio.run(_run())
|
||||||
try:
|
|
||||||
import nest_asyncio
|
|
||||||
nest_asyncio.apply()
|
|
||||||
except ImportError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
# 尝试获取现有事件循环,如果不存在则创建新的
|
|
||||||
try:
|
|
||||||
loop = asyncio.get_event_loop()
|
|
||||||
if loop.is_closed():
|
|
||||||
loop = asyncio.new_event_loop()
|
|
||||||
asyncio.set_event_loop(loop)
|
|
||||||
except RuntimeError:
|
|
||||||
loop = asyncio.new_event_loop()
|
|
||||||
asyncio.set_event_loop(loop)
|
|
||||||
|
|
||||||
result = loop.run_until_complete(_run())
|
|
||||||
elapsed_time = time.time() - start_time
|
elapsed_time = time.time() - start_time
|
||||||
|
|
||||||
logger.info(f"[CELERY WRITE] Task completed successfully - elapsed_time={elapsed_time:.2f}s, task_id={self.request.id}")
|
logger.info(f"[CELERY WRITE] Task completed successfully - elapsed_time={elapsed_time:.2f}s, task_id={self.request.id}")
|
||||||
@@ -560,7 +525,6 @@ def write_message_task(self, group_id: str, message, config_id: str, storage_typ
|
|||||||
}
|
}
|
||||||
except BaseException as e:
|
except BaseException as e:
|
||||||
elapsed_time = time.time() - start_time
|
elapsed_time = time.time() - start_time
|
||||||
# Handle ExceptionGroup from TaskGroup
|
|
||||||
if hasattr(e, 'exceptions'):
|
if hasattr(e, 'exceptions'):
|
||||||
error_messages = [f"{type(sub_e).__name__}: {str(sub_e)}" for sub_e in e.exceptions]
|
error_messages = [f"{type(sub_e).__name__}: {str(sub_e)}" for sub_e in e.exceptions]
|
||||||
detailed_error = "; ".join(error_messages)
|
detailed_error = "; ".join(error_messages)
|
||||||
@@ -600,53 +564,53 @@ def reflection_timer_task() -> None:
|
|||||||
"""
|
"""
|
||||||
reflection_engine()
|
reflection_engine()
|
||||||
|
|
||||||
|
# unused task
|
||||||
@celery_app.task(name="app.core.memory.agent.health.check_read_service")
|
# @celery_app.task(name="app.core.memory.agent.health.check_read_service")
|
||||||
def check_read_service_task() -> Dict[str, str]:
|
# def check_read_service_task() -> Dict[str, str]:
|
||||||
"""Call read_service and write latest status to Redis.
|
# """Call read_service and write latest status to Redis.
|
||||||
|
|
||||||
Returns status data dict that gets written to Redis.
|
# Returns status data dict that gets written to Redis.
|
||||||
"""
|
# """
|
||||||
client = redis.Redis(
|
# client = redis.Redis(
|
||||||
host=settings.REDIS_HOST,
|
# host=settings.REDIS_HOST,
|
||||||
port=settings.REDIS_PORT,
|
# port=settings.REDIS_PORT,
|
||||||
db=settings.REDIS_DB,
|
# db=settings.REDIS_DB,
|
||||||
password=settings.REDIS_PASSWORD if settings.REDIS_PASSWORD else None
|
# password=settings.REDIS_PASSWORD if settings.REDIS_PASSWORD else None
|
||||||
)
|
# )
|
||||||
try:
|
# try:
|
||||||
api_url = f"http://{settings.SERVER_IP}:8000/api/memory/read_service"
|
# api_url = f"http://{settings.SERVER_IP}:8000/api/memory/read_service"
|
||||||
payload = {
|
# payload = {
|
||||||
"user_id": "健康检查",
|
# "user_id": "健康检查",
|
||||||
"apply_id": "健康检查",
|
# "apply_id": "健康检查",
|
||||||
"group_id": "健康检查",
|
# "group_id": "健康检查",
|
||||||
"message": "你好",
|
# "message": "你好",
|
||||||
"history": [],
|
# "history": [],
|
||||||
"search_switch": "2",
|
# "search_switch": "2",
|
||||||
}
|
# }
|
||||||
resp = requests.post(api_url, json=payload, timeout=15)
|
# resp = requests.post(api_url, json=payload, timeout=15)
|
||||||
ok = resp.status_code == 200
|
# ok = resp.status_code == 200
|
||||||
status = "Success" if ok else "Fail"
|
# status = "Success" if ok else "Fail"
|
||||||
msg = "接口请求成功" if ok else f"接口请求失败: {resp.status_code}"
|
# msg = "接口请求成功" if ok else f"接口请求失败: {resp.status_code}"
|
||||||
error = "" if ok else resp.text
|
# error = "" if ok else resp.text
|
||||||
code = 0 if ok else 500
|
# code = 0 if ok else 500
|
||||||
except Exception as e:
|
# except Exception as e:
|
||||||
status = "Fail"
|
# status = "Fail"
|
||||||
msg = "接口请求失败"
|
# msg = "接口请求失败"
|
||||||
error = str(e)
|
# error = str(e)
|
||||||
code = 500
|
# code = 500
|
||||||
|
|
||||||
data = {
|
# data = {
|
||||||
"status": status,
|
# "status": status,
|
||||||
"msg": msg,
|
# "msg": msg,
|
||||||
"error": error,
|
# "error": error,
|
||||||
"code": str(code),
|
# "code": str(code),
|
||||||
"time": str(int(time.time())),
|
# "time": str(int(time.time())),
|
||||||
}
|
# }
|
||||||
|
|
||||||
client.hset("memsci:health:read_service", mapping=data)
|
# client.hset("memsci:health:read_service", mapping=data)
|
||||||
client.expire("memsci:health:read_service", int(settings.HEALTH_CHECK_SECONDS))
|
# client.expire("memsci:health:read_service", int(settings.HEALTH_CHECK_SECONDS))
|
||||||
|
|
||||||
return data
|
# return data
|
||||||
|
|
||||||
|
|
||||||
@celery_app.task(name="app.controllers.memory_storage_controller.search_all")
|
@celery_app.task(name="app.controllers.memory_storage_controller.search_all")
|
||||||
@@ -911,24 +875,7 @@ def regenerate_memory_cache(self) -> Dict[str, Any]:
|
|||||||
}
|
}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 使用 nest_asyncio 来避免事件循环冲突
|
result = asyncio.run(_run())
|
||||||
try:
|
|
||||||
import nest_asyncio
|
|
||||||
nest_asyncio.apply()
|
|
||||||
except ImportError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
# 尝试获取现有事件循环,如果不存在则创建新的
|
|
||||||
try:
|
|
||||||
loop = asyncio.get_event_loop()
|
|
||||||
if loop.is_closed():
|
|
||||||
loop = asyncio.new_event_loop()
|
|
||||||
asyncio.set_event_loop(loop)
|
|
||||||
except RuntimeError:
|
|
||||||
loop = asyncio.new_event_loop()
|
|
||||||
asyncio.set_event_loop(loop)
|
|
||||||
|
|
||||||
result = loop.run_until_complete(_run())
|
|
||||||
elapsed_time = time.time() - start_time
|
elapsed_time = time.time() - start_time
|
||||||
result["elapsed_time"] = elapsed_time
|
result["elapsed_time"] = elapsed_time
|
||||||
result["task_id"] = self.request.id
|
result["task_id"] = self.request.id
|
||||||
@@ -1055,24 +1002,7 @@ def workspace_reflection_task(self) -> Dict[str, Any]:
|
|||||||
}
|
}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 使用 nest_asyncio 来避免事件循环冲突
|
result = asyncio.run(_run())
|
||||||
try:
|
|
||||||
import nest_asyncio
|
|
||||||
nest_asyncio.apply()
|
|
||||||
except ImportError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
# 尝试获取现有事件循环,如果不存在则创建新的
|
|
||||||
try:
|
|
||||||
loop = asyncio.get_event_loop()
|
|
||||||
if loop.is_closed():
|
|
||||||
loop = asyncio.new_event_loop()
|
|
||||||
asyncio.set_event_loop(loop)
|
|
||||||
except RuntimeError:
|
|
||||||
loop = asyncio.new_event_loop()
|
|
||||||
asyncio.set_event_loop(loop)
|
|
||||||
|
|
||||||
result = loop.run_until_complete(_run())
|
|
||||||
elapsed_time = time.time() - start_time
|
elapsed_time = time.time() - start_time
|
||||||
result["elapsed_time"] = elapsed_time
|
result["elapsed_time"] = elapsed_time
|
||||||
result["task_id"] = self.request.id
|
result["task_id"] = self.request.id
|
||||||
@@ -1148,11 +1078,4 @@ def run_forgetting_cycle_task(self, config_id: Optional[int] = None) -> Dict[str
|
|||||||
"duration_seconds": duration
|
"duration_seconds": duration
|
||||||
}
|
}
|
||||||
|
|
||||||
# 运行异步函数
|
return asyncio.run(_run())
|
||||||
loop = asyncio.new_event_loop()
|
|
||||||
asyncio.set_event_loop(loop)
|
|
||||||
try:
|
|
||||||
result = loop.run_until_complete(_run())
|
|
||||||
return result
|
|
||||||
finally:
|
|
||||||
loop.close()
|
|
||||||
|
|||||||
@@ -7,10 +7,6 @@ services:
|
|||||||
- "8002:8000"
|
- "8002:8000"
|
||||||
env_file:
|
env_file:
|
||||||
- .env
|
- .env
|
||||||
environment:
|
|
||||||
- SERVER_IP=0.0.0.0
|
|
||||||
# 如果代码里必须要 MCP_SERVER_URL,可以先注释或指向占位
|
|
||||||
# - MCP_SERVER_URL=
|
|
||||||
volumes:
|
volumes:
|
||||||
- ./files:/files
|
- ./files:/files
|
||||||
- /etc/localtime:/etc/localtime:ro
|
- /etc/localtime:/etc/localtime:ro
|
||||||
@@ -19,20 +15,53 @@ services:
|
|||||||
networks:
|
networks:
|
||||||
- default
|
- default
|
||||||
- celery
|
- celery
|
||||||
|
depends_on:
|
||||||
|
- worker-memory
|
||||||
|
- worker-document
|
||||||
|
|
||||||
# Celery worker
|
# Memory worker - Memory read/write tasks (threads pool for asyncio)
|
||||||
worker:
|
worker-memory:
|
||||||
image: redbear-mem-open:latest
|
image: redbear-mem-open:latest
|
||||||
container_name: worker
|
container_name: worker-memory
|
||||||
env_file:
|
env_file:
|
||||||
- .env
|
- .env
|
||||||
volumes:
|
volumes:
|
||||||
- ./files:/files
|
- ./files:/files
|
||||||
- /etc/localtime:/etc/localtime:ro
|
- /etc/localtime:/etc/localtime:ro
|
||||||
command: celery -A app.celery_worker.celery_app worker --loglevel=info
|
command: celery -A app.celery_worker.celery_app worker -E --loglevel=info --pool=threads --concurrency=100 --queues=memory_tasks -n memory_worker@%h
|
||||||
restart: unless-stopped
|
restart: unless-stopped
|
||||||
networks:
|
networks:
|
||||||
- celery
|
- celery
|
||||||
|
|
||||||
|
# Document worker - Document parsing tasks (prefork for CPU-bound)
|
||||||
|
worker-document:
|
||||||
|
image: redbear-mem-open:latest
|
||||||
|
container_name: worker-document
|
||||||
|
env_file:
|
||||||
|
- .env
|
||||||
|
volumes:
|
||||||
|
- ./files:/files
|
||||||
|
- /etc/localtime:/etc/localtime:ro
|
||||||
|
command: celery -A app.celery_worker.celery_app worker -E --loglevel=info --pool=prefork --concurrency=4 --queues=document_tasks --max-tasks-per-child=100 -n document_worker@%h
|
||||||
|
restart: unless-stopped
|
||||||
|
networks:
|
||||||
|
- celery
|
||||||
|
|
||||||
|
# Celery Beat - scheduler
|
||||||
|
beat:
|
||||||
|
image: redbear-mem-open:latest
|
||||||
|
container_name: celery-beat
|
||||||
|
env_file:
|
||||||
|
- .env
|
||||||
|
volumes:
|
||||||
|
- ./files:/files
|
||||||
|
- /etc/localtime:/etc/localtime:ro
|
||||||
|
command: celery -A app.celery_worker.celery_app beat --loglevel=info
|
||||||
|
restart: unless-stopped
|
||||||
|
networks:
|
||||||
|
- celery
|
||||||
|
depends_on:
|
||||||
|
- worker-memory
|
||||||
|
|
||||||
networks:
|
networks:
|
||||||
celery:
|
celery:
|
||||||
|
|||||||
@@ -13,6 +13,7 @@ dependencies = [
|
|||||||
"bcrypt==5.0.0",
|
"bcrypt==5.0.0",
|
||||||
"billiard==4.2.2",
|
"billiard==4.2.2",
|
||||||
"celery==5.5.3",
|
"celery==5.5.3",
|
||||||
|
"flower==2.0.1",
|
||||||
"cffi==2.0.0",
|
"cffi==2.0.0",
|
||||||
"click==8.3.0",
|
"click==8.3.0",
|
||||||
"click-didyoumean==0.3.1",
|
"click-didyoumean==0.3.1",
|
||||||
@@ -138,6 +139,7 @@ dependencies = [
|
|||||||
"python-calamine>=0.4.0",
|
"python-calamine>=0.4.0",
|
||||||
"xlrd==2.0.2",
|
"xlrd==2.0.2",
|
||||||
"deprecated>=1.3.1",
|
"deprecated>=1.3.1",
|
||||||
|
"flower>=2.0.1",
|
||||||
]
|
]
|
||||||
|
|
||||||
[tool.pytest.ini_options]
|
[tool.pytest.ini_options]
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ async-timeout==5.0.1
|
|||||||
bcrypt==5.0.0
|
bcrypt==5.0.0
|
||||||
billiard==4.2.2
|
billiard==4.2.2
|
||||||
celery==5.5.3
|
celery==5.5.3
|
||||||
|
flower==2.0.1
|
||||||
cffi==2.0.0
|
cffi==2.0.0
|
||||||
click==8.3.0
|
click==8.3.0
|
||||||
click-didyoumean==0.3.1
|
click-didyoumean==0.3.1
|
||||||
|
|||||||
Reference in New Issue
Block a user