Merge remote-tracking branch 'origin/develop' into develop

This commit is contained in:
lixinyue
2026-01-21 18:10:46 +08:00
12 changed files with 254 additions and 259 deletions

View File

@@ -334,7 +334,12 @@ step6: Log In to the Frontend Interface.
## License ## License
This project is licensed under the Apache License 2.0. For details, see the LICENSE file. This project is licensed under the Apache License 2.0. For details, see the LICENSE file.
## Acknowledgements & Community ## Community & Support
- Feedback & Issues: Please submit an Issue in the repository for bug reports or discussions.
- Contributions Welcome: When submitting a Pull Request, please create a feature branch and follow conventional commit message guidelines. Join our community to ask questions, share your work, and connect with fellow developers.
- Contact: If you are interested in contributing or collaborating, feel free to reach out at tianyou_hubm@redbearai.com
- **GitHub Issues**: Report bugs, request features, or track known issues via [GitHub Issues](https://github.com/redbear-ai/memorybear/issues).
- **GitHub Pull Requests**: Contribute code improvements or fixes through [Pull Requests](https://github.com/redbear-ai/memorybear/pulls).
- **GitHub Discussions**: Ask questions, share ideas, and engage with the community in [GitHub Discussions](https://github.com/redbear-ai/memorybear/discussions).
- **WeChat**: Scan the QR code below to join our WeChat community group.
- **Contact**: If you are interested in contributing or collaborating, feel free to reach out at tianyou_hubm@redbearai.com

View File

@@ -1,4 +1,5 @@
import os import os
import platform
from datetime import timedelta from datetime import timedelta
from urllib.parse import quote from urllib.parse import quote
@@ -14,27 +15,12 @@ celery_app = Celery(
backend=f"redis://:{quote(settings.REDIS_PASSWORD)}@{settings.REDIS_HOST}:{settings.REDIS_PORT}/{settings.CELERY_BACKEND}", backend=f"redis://:{quote(settings.REDIS_PASSWORD)}@{settings.REDIS_HOST}:{settings.REDIS_PORT}/{settings.CELERY_BACKEND}",
) )
# 配置使用本地队列,避免与远程 worker 冲突 # Default queue for unrouted tasks
celery_app.conf.task_default_queue = 'localhost_test_wyl' celery_app.conf.task_default_queue = 'memory_tasks'
celery_app.conf.task_default_exchange = 'localhost_test_wyl'
celery_app.conf.task_default_routing_key = 'localhost_test_wyl'
# macOS 兼容性配置 # macOS 兼容性配置
import platform if platform.system() == 'Darwin':
if platform.system() == 'Darwin': # macOS
# 设置环境变量解决 fork 问题
os.environ.setdefault('OBJC_DISABLE_INITIALIZE_FORK_SAFETY', 'YES') os.environ.setdefault('OBJC_DISABLE_INITIALIZE_FORK_SAFETY', 'YES')
# 使用 solo 池避免多进程问题
celery_app.conf.worker_pool = 'solo'
# 设置唯一的节点名称
import socket
import time
hostname = socket.gethostname()
timestamp = int(time.time())
celery_app.conf.worker_name = f"celery@{hostname}-{timestamp}"
# Celery 配置 # Celery 配置
celery_app.conf.update( celery_app.conf.update(
@@ -52,36 +38,47 @@ celery_app.conf.update(
task_ignore_result=False, task_ignore_result=False,
# 超时设置 # 超时设置
task_time_limit=30 * 60, # 30 分钟硬超时 task_time_limit=1800, # 30分钟硬超时
task_soft_time_limit=25 * 60, # 25 分钟软超时 task_soft_time_limit=1500, # 25分钟软超时
# Worker 设置 - 针对 macOS 优化 # Worker 设置 (per-worker settings are in docker-compose command line)
worker_prefetch_multiplier=1, # 减少预取任务数,避免内存堆积 worker_prefetch_multiplier=1, # Don't hoard tasks, fairer distribution
worker_max_tasks_per_child=10, # 大幅减少每个 worker 执行的任务数,频繁重启防止内存泄漏
worker_max_memory_per_child=200000, # 200MB 内存限制,超过后重启 worker
# 结果过期时间 # 结果过期时间
result_expires=3600, # 结果保存 1 小时 result_expires=3600, # 结果保存1小时
# 任务确认设置 # 任务确认设置
task_acks_late=True, # 任务完成后才确认,避免任务丢失 task_acks_late=True,
worker_disable_rate_limits=True, # 禁用速率限制 task_reject_on_worker_lost=True,
worker_disable_rate_limits=True,
# 任务路由(可选,用于不同队列) # FLower setting
# task_routes={ worker_send_task_events=True,
# 'app.core.rag.tasks.parse_document': {'queue': 'document_processing'}, task_send_sent_event=True,
# 'app.core.memory.agent.read_message': {'queue': 'memory_processing'},
# 'app.core.memory.agent.write_message': {'queue': 'memory_processing'}, # task routing
# 'tasks.process_item': {'queue': 'default'}, task_routes={
# }, # Memory tasks → memory_tasks queue (threads worker)
'app.core.memory.agent.read_message_priority': {'queue': 'memory_tasks'},
'app.core.memory.agent.read_message': {'queue': 'memory_tasks'},
'app.core.memory.agent.write_message': {'queue': 'memory_tasks'},
# Document tasks → document_tasks queue (prefork worker)
'app.core.rag.tasks.parse_document': {'queue': 'document_tasks'},
'app.core.rag.tasks.build_graphrag_for_kb': {'queue': 'document_tasks'},
# Beat/periodic tasks → document_tasks queue (prefork worker)
'app.tasks.workspace_reflection_task': {'queue': 'document_tasks'},
'app.tasks.regenerate_memory_cache': {'queue': 'document_tasks'},
'app.tasks.run_forgetting_cycle_task': {'queue': 'document_tasks'},
'app.controllers.memory_storage_controller.search_all': {'queue': 'document_tasks'},
},
) )
# 自动发现任务模块 # 自动发现任务模块
celery_app.autodiscover_tasks(['app']) celery_app.autodiscover_tasks(['app'])
# Celery Beat schedule for periodic tasks # Celery Beat schedule for periodic tasks
reflection_schedule = timedelta(seconds=settings.REFLECTION_INTERVAL_SECONDS)
health_schedule = timedelta(seconds=settings.HEALTH_CHECK_SECONDS)
memory_increment_schedule = timedelta(hours=settings.MEMORY_INCREMENT_INTERVAL_HOURS) memory_increment_schedule = timedelta(hours=settings.MEMORY_INCREMENT_INTERVAL_HOURS)
memory_cache_regeneration_schedule = timedelta(hours=settings.MEMORY_CACHE_REGENERATION_HOURS) memory_cache_regeneration_schedule = timedelta(hours=settings.MEMORY_CACHE_REGENERATION_HOURS)
workspace_reflection_schedule = timedelta(seconds=30) # 每30秒运行一次settings.REFLECTION_INTERVAL_TIME workspace_reflection_schedule = timedelta(seconds=30) # 每30秒运行一次settings.REFLECTION_INTERVAL_TIME
@@ -89,12 +86,6 @@ forgetting_cycle_schedule = timedelta(hours=24) # 每24小时运行一次遗忘
# 构建定时任务配置 # 构建定时任务配置
beat_schedule_config = { beat_schedule_config = {
# "check-read-service": {
# "task": "app.core.memory.agent.health.check_read_service",
# "schedule": health_schedule,
# "args": (),
# },
"run-workspace-reflection": { "run-workspace-reflection": {
"task": "app.tasks.workspace_reflection_task", "task": "app.tasks.workspace_reflection_task",
"schedule": workspace_reflection_schedule, "schedule": workspace_reflection_schedule,

View File

@@ -4,12 +4,11 @@ import os
import time import time
from app.core.logging_config import get_agent_logger, log_time from app.core.logging_config import get_agent_logger, log_time
from app.db import get_db
from app.core.memory.agent.models.summary_models import ( from app.core.memory.agent.models.summary_models import (
RetrieveSummaryResponse, RetrieveSummaryResponse,
SummaryResponse, SummaryResponse,
) )
from app.core.memory.agent.services.optimized_llm_service import LLMServiceMixin
from app.core.memory.agent.services.search_service import SearchService from app.core.memory.agent.services.search_service import SearchService
from app.core.memory.agent.utils.llm_tools import ( from app.core.memory.agent.utils.llm_tools import (
PROJECT_ROOT_, PROJECT_ROOT_,
@@ -18,7 +17,7 @@ from app.core.memory.agent.utils.llm_tools import (
from app.core.memory.agent.utils.redis_tool import store from app.core.memory.agent.utils.redis_tool import store
from app.core.memory.agent.utils.session_tools import SessionService from app.core.memory.agent.utils.session_tools import SessionService
from app.core.memory.agent.utils.template_tools import TemplateService from app.core.memory.agent.utils.template_tools import TemplateService
from app.core.memory.agent.services.optimized_llm_service import LLMServiceMixin from app.db import get_db
template_root = os.path.join(PROJECT_ROOT_, 'agent', 'utils', 'prompt') template_root = os.path.join(PROJECT_ROOT_, 'agent', 'utils', 'prompt')
logger = get_agent_logger(__name__) logger = get_agent_logger(__name__)
@@ -182,7 +181,8 @@ async def Input_Summary(state: ReadState) -> ReadState:
search_params = { search_params = {
"group_id": group_id, "group_id": group_id,
"question": data, "question": data,
"return_raw_results": True "return_raw_results": True,
"include": ["summaries"] # Only search summary nodes for faster performance
} }
try: try:

View File

@@ -89,14 +89,15 @@ def validate_model_exists_and_active(
start_time = time.time() start_time = time.time()
try: try:
# First check if model exists at all (without tenant filtering) # OPTIMIZED: Single query with tenant filter
model_without_tenant = ModelConfigRepository.get_by_id(db, model_id, tenant_id=None) # We'll check tenant mismatch in the error handling
# Then check with tenant filtering
model = ModelConfigRepository.get_by_id(db, model_id, tenant_id) model = ModelConfigRepository.get_by_id(db, model_id, tenant_id)
elapsed_ms = (time.time() - start_time) * 1000 elapsed_ms = (time.time() - start_time) * 1000
if not model: if not model:
# Model not found with tenant filter - check if it exists without filter
model_without_tenant = ModelConfigRepository.get_by_id(db, model_id, tenant_id=None)
if model_without_tenant: if model_without_tenant:
# Model exists but belongs to different tenant # Model exists but belongs to different tenant
logger.warning( logger.warning(
@@ -208,8 +209,11 @@ def validate_embedding_model(
db: Session, db: Session,
tenant_id: Optional[UUID] = None, tenant_id: Optional[UUID] = None,
workspace_id: Optional[UUID] = None workspace_id: Optional[UUID] = None
) -> UUID: ) -> tuple[UUID, str]:
"""Validate that embedding model is available and return its UUID. """Validate that embedding model is available and return its UUID and name.
Returns:
Tuple of (embedding_uuid, embedding_name)
Raises: Raises:
InvalidConfigError: If embedding_id is not provided or invalid InvalidConfigError: If embedding_id is not provided or invalid
@@ -225,14 +229,19 @@ def validate_embedding_model(
workspace_id=workspace_id workspace_id=workspace_id
) )
embedding_uuid, _ = validate_and_resolve_model_id( embedding_uuid, embedding_name = validate_and_resolve_model_id(
embedding_id, "embedding", db, tenant_id, required=True, embedding_id, "embedding", db, tenant_id, required=True,
config_id=config_id, workspace_id=workspace_id config_id=config_id, workspace_id=workspace_id
) )
print(100*'-')
print(embedding_uuid) logger.debug(
print(_) "Embedding model validated",
print(100*'-') extra={
"embedding_uuid": str(embedding_uuid),
"embedding_name": embedding_name,
"config_id": config_id
}
)
if embedding_uuid is None: if embedding_uuid is None:
raise InvalidConfigError( raise InvalidConfigError(
@@ -243,7 +252,7 @@ def validate_embedding_model(
workspace_id=workspace_id workspace_id=workspace_id
) )
return embedding_uuid return embedding_uuid, embedding_name
def validate_llm_model( def validate_llm_model(

View File

@@ -305,12 +305,19 @@ async def search_graph(
results[key] = _deduplicate_results(results[key]) results[key] = _deduplicate_results(results[key])
# 更新知识节点的激活值Statement, ExtractedEntity, MemorySummary # 更新知识节点的激活值Statement, ExtractedEntity, MemorySummary
results = await _update_search_results_activation( # Skip activation updates if only searching summaries (optimization)
connector=connector, needs_activation_update = any(
results=results, key in include and key in results and results[key]
group_id=group_id for key in ['statements', 'entities', 'chunks']
) )
if needs_activation_update:
results = await _update_search_results_activation(
connector=connector,
results=results,
group_id=group_id
)
return results return results
@@ -339,7 +346,7 @@ async def search_graph_by_embedding(
embed_start = time.time() embed_start = time.time()
embeddings = await embedder_client.response([query_text]) embeddings = await embedder_client.response([query_text])
embed_time = time.time() - embed_start embed_time = time.time() - embed_start
print(f"[PERF] Embedding generation took: {embed_time:.4f}s") logger.info(f"[PERF] Embedding generation took: {embed_time:.4f}s")
if not embeddings or not embeddings[0]: if not embeddings or not embeddings[0]:
return {"statements": [], "chunks": [], "entities": [], "summaries": []} return {"statements": [], "chunks": [], "entities": [], "summaries": []}
@@ -393,7 +400,7 @@ async def search_graph_by_embedding(
query_start = time.time() query_start = time.time()
task_results = await asyncio.gather(*tasks, return_exceptions=True) task_results = await asyncio.gather(*tasks, return_exceptions=True)
query_time = time.time() - query_start query_time = time.time() - query_start
print(f"[PERF] Neo4j queries (parallel) took: {query_time:.4f}s") logger.info(f"[PERF] Neo4j queries (parallel) took: {query_time:.4f}s")
# Build results dictionary # Build results dictionary
results: Dict[str, List[Dict[str, Any]]] = { results: Dict[str, List[Dict[str, Any]]] = {
@@ -417,14 +424,23 @@ async def search_graph_by_embedding(
results[key] = _deduplicate_results(results[key]) results[key] = _deduplicate_results(results[key])
# 更新知识节点的激活值Statement, ExtractedEntity, MemorySummary # 更新知识节点的激活值Statement, ExtractedEntity, MemorySummary
update_start = time.time() # Skip activation updates if only searching summaries (optimization)
results = await _update_search_results_activation( needs_activation_update = any(
connector=connector, key in include and key in results and results[key]
results=results, for key in ['statements', 'entities', 'chunks']
group_id=group_id
) )
update_time = time.time() - update_start
print(f"[PERF] Activation value updates took: {update_time:.4f}s") if needs_activation_update:
update_start = time.time()
results = await _update_search_results_activation(
connector=connector,
results=results,
group_id=group_id
)
update_time = time.time() - update_start
logger.info(f"[PERF] Activation value updates took: {update_time:.4f}s")
else:
logger.info(f"[PERF] Skipping activation updates (only summaries)")
return results return results
async def get_dedup_candidates_for_entities( # 适配新版查询:使用全文索引按名称检索候选实体 async def get_dedup_candidates_for_entities( # 适配新版查询:使用全文索引按名称检索候选实体
@@ -535,7 +551,7 @@ async def search_graph_by_keyword_temporal(
- Returns up to 'limit' statements - Returns up to 'limit' statements
""" """
if not query_text: if not query_text:
print(f"query_text不能为空") logger.warning(f"query_text cannot be empty")
return {"statements": []} return {"statements": []}
statements = await connector.execute_query( statements = await connector.execute_query(
SEARCH_STATEMENTS_BY_KEYWORD_TEMPORAL, SEARCH_STATEMENTS_BY_KEYWORD_TEMPORAL,
@@ -549,7 +565,7 @@ async def search_graph_by_keyword_temporal(
invalid_date=invalid_date, invalid_date=invalid_date,
limit=limit, limit=limit,
) )
print(f"查询结果为:\n{statements}") logger.debug(f"Temporal keyword search results: {len(statements)} statements found")
# 更新 Statement 节点的激活值 # 更新 Statement 节点的激活值
results = {"statements": statements} results = {"statements": statements}
@@ -594,9 +610,9 @@ async def search_graph_by_temporal(
limit=limit, limit=limit,
) )
print(f"查询语句为:\n{SEARCH_STATEMENTS_BY_TEMPORAL}") logger.debug(f"Temporal search query: {SEARCH_STATEMENTS_BY_TEMPORAL}")
print(f"查询参数为:\n{{group_id: {group_id}, apply_id: {apply_id}, user_id: {user_id}, start_date: {start_date}, end_date: {end_date}, valid_date: {valid_date}, invalid_date: {invalid_date}, limit: {limit}}}") logger.debug(f"Query params: group_id={group_id}, apply_id={apply_id}, user_id={user_id}, start_date={start_date}, end_date={end_date}, valid_date={valid_date}, invalid_date={invalid_date}, limit={limit}")
print(f"查询结果为:\n{statements}") logger.debug(f"Temporal search results: {len(statements)} statements found")
# 更新 Statement 节点的激活值 # 更新 Statement 节点的激活值
results = {"statements": statements} results = {"statements": statements}
@@ -623,7 +639,7 @@ async def search_graph_by_dialog_id(
- Returns up to 'limit' dialogues - Returns up to 'limit' dialogues
""" """
if not dialog_id: if not dialog_id:
print(f"dialog_id不能为空") logger.warning(f"dialog_id cannot be empty")
return {"dialogues": []} return {"dialogues": []}
dialogues = await connector.execute_query( dialogues = await connector.execute_query(
@@ -642,7 +658,7 @@ async def search_graph_by_chunk_id(
limit: int = 1, limit: int = 1,
) -> Dict[str, List[Dict[str, Any]]]: ) -> Dict[str, List[Dict[str, Any]]]:
if not chunk_id: if not chunk_id:
print(f"chunk_id不能为空") logger.warning(f"chunk_id cannot be empty")
return {"chunks": []} return {"chunks": []}
chunks = await connector.execute_query( chunks = await connector.execute_query(
SEARCH_CHUNK_BY_CHUNK_ID, SEARCH_CHUNK_BY_CHUNK_ID,
@@ -679,9 +695,9 @@ async def search_graph_by_created_at(
limit=limit, limit=limit,
) )
print(f"查询语句为:\n{SEARCH_STATEMENTS_BY_CREATED_AT}") logger.debug(f"Search by created_at query: {SEARCH_STATEMENTS_BY_CREATED_AT}")
print(f"查询参数为:\n{{group_id: {group_id}, apply_id: {apply_id}, user_id: {user_id}, created_at: {created_at}, limit: {limit}}}") logger.debug(f"Query params: group_id={group_id}, apply_id={apply_id}, user_id={user_id}, created_at={created_at}, limit={limit}")
print(f"查询结果为:\n{statements}") logger.debug(f"Search results: {len(statements)} statements found")
# 更新 Statement 节点的激活值 # 更新 Statement 节点的激活值
results = {"statements": statements} results = {"statements": statements}
@@ -719,9 +735,9 @@ async def search_graph_by_valid_at(
limit=limit, limit=limit,
) )
print(f"查询语句为:\n{SEARCH_STATEMENTS_BY_VALID_AT}") logger.debug(f"Search by valid_at query: {SEARCH_STATEMENTS_BY_VALID_AT}")
print(f"查询参数为:\n{{group_id: {group_id}, apply_id: {apply_id}, user_id: {user_id}, valid_at: {valid_at}, limit: {limit}}}") logger.debug(f"Query params: group_id={group_id}, apply_id={apply_id}, user_id={user_id}, valid_at={valid_at}, limit={limit}")
print(f"查询结果为:\n{statements}") logger.debug(f"Search results: {len(statements)} statements found")
# 更新 Statement 节点的激活值 # 更新 Statement 节点的激活值
results = {"statements": statements} results = {"statements": statements}
@@ -759,9 +775,9 @@ async def search_graph_g_created_at(
limit=limit, limit=limit,
) )
print(f"查询语句为:\n{SEARCH_STATEMENTS_G_CREATED_AT}") logger.debug(f"Search greater than created_at query: {SEARCH_STATEMENTS_G_CREATED_AT}")
print(f"查询参数为:\n{{group_id: {group_id}, apply_id: {apply_id}, user_id: {user_id}, created_at: {created_at}, limit: {limit}}}") logger.debug(f"Query params: group_id={group_id}, apply_id={apply_id}, user_id={user_id}, created_at={created_at}, limit={limit}")
print(f"查询结果为:\n{statements}") logger.debug(f"Search results: {len(statements)} statements found")
# 更新 Statement 节点的激活值 # 更新 Statement 节点的激活值
results = {"statements": statements} results = {"statements": statements}
@@ -799,9 +815,9 @@ async def search_graph_g_valid_at(
limit=limit, limit=limit,
) )
print(f"查询语句为:\n{SEARCH_STATEMENTS_G_VALID_AT}") logger.debug(f"Search greater than valid_at query: {SEARCH_STATEMENTS_G_VALID_AT}")
print(f"查询参数为:\n{{group_id: {group_id}, apply_id: {apply_id}, user_id: {user_id}, valid_at: {valid_at}, limit: {limit}}}") logger.debug(f"Query params: group_id={group_id}, apply_id={apply_id}, user_id={user_id}, valid_at={valid_at}, limit={limit}")
print(f"查询结果为:\n{statements}") logger.debug(f"Search results: {len(statements)} statements found")
# 更新 Statement 节点的激活值 # 更新 Statement 节点的激活值
results = {"statements": statements} results = {"statements": statements}
@@ -839,9 +855,9 @@ async def search_graph_l_created_at(
limit=limit, limit=limit,
) )
print(f"查询语句为:\n{SEARCH_STATEMENTS_L_CREATED_AT}") logger.debug(f"Search less than created_at query: {SEARCH_STATEMENTS_L_CREATED_AT}")
print(f"查询参数为:\n{{group_id: {group_id}, apply_id: {apply_id}, user_id: {user_id}, created_at: {created_at}, limit: {limit}}}") logger.debug(f"Query params: group_id={group_id}, apply_id={apply_id}, user_id={user_id}, created_at={created_at}, limit={limit}")
print(f"查询结果为:\n{statements}") logger.debug(f"Search results: {len(statements)} statements found")
# 更新 Statement 节点的激活值 # 更新 Statement 节点的激活值
results = {"statements": statements} results = {"statements": statements}
@@ -879,9 +895,9 @@ async def search_graph_l_valid_at(
limit=limit, limit=limit,
) )
print(f"查询语句为:\n{SEARCH_STATEMENTS_L_VALID_AT}") logger.debug(f"Search less than valid_at query: {SEARCH_STATEMENTS_L_VALID_AT}")
print(f"查询参数为:\n{{group_id: {group_id}, apply_id: {apply_id}, user_id: {user_id}, valid_at: {valid_at}, limit: {limit}}}") logger.debug(f"Query params: group_id={group_id}, apply_id={apply_id}, user_id={user_id}, valid_at={valid_at}, limit={limit}")
print(f"查询结果为:\n{statements}") logger.debug(f"Search results: {len(statements)} statements found")
# 更新 Statement 节点的激活值 # 更新 Statement 节点的激活值
results = {"statements": statements} results = {"statements": statements}

View File

@@ -10,11 +10,6 @@ import time
import uuid import uuid
from typing import Any, AsyncGenerator, Dict, List, Optional from typing import Any, AsyncGenerator, Dict, List, Optional
from langchain.tools import tool
from pydantic import BaseModel, Field
from sqlalchemy import select
from sqlalchemy.orm import Session
from app.celery_app import celery_app from app.celery_app import celery_app
from app.core.error_codes import BizCode from app.core.error_codes import BizCode
from app.core.exceptions import BusinessException from app.core.exceptions import BusinessException
@@ -28,6 +23,10 @@ from app.services.langchain_tool_server import Search
from app.services.memory_agent_service import MemoryAgentService from app.services.memory_agent_service import MemoryAgentService
from app.services.model_parameter_merger import ModelParameterMerger from app.services.model_parameter_merger import ModelParameterMerger
from app.services.tool_service import ToolService from app.services.tool_service import ToolService
from langchain.tools import tool
from pydantic import BaseModel, Field
from sqlalchemy import select
from sqlalchemy.orm import Session
logger = get_business_logger() logger = get_business_logger()
class KnowledgeRetrievalInput(BaseModel): class KnowledgeRetrievalInput(BaseModel):
@@ -107,9 +106,9 @@ def create_long_term_memory_tool(memory_config: Dict[str, Any], end_user_id: str
"app.core.memory.agent.read_message", "app.core.memory.agent.read_message",
args=[end_user_id, question, [], "1", config_id, storage_type, user_rag_memory_id] args=[end_user_id, question, [], "1", config_id, storage_type, user_rag_memory_id]
) )
result = task_service.get_task_memory_read_result(task.id) # result = task_service.get_task_memory_read_result(task.id)
status = result.get("status") # status = result.get("status")
logger.info(f"读取任务状态:{status}") # logger.info(f"读取任务状态:{status}")
finally: finally:
db.close() db.close()

View File

@@ -10,15 +10,17 @@ import re
import time import time
import uuid import uuid
from typing import Any, AsyncGenerator, Dict, List, Optional from typing import Any, AsyncGenerator, Dict, List, Optional
import redis
from langchain_core.messages import HumanMessage
import redis
from app.core.config import settings from app.core.config import settings
from app.core.logging_config import get_config_logger, get_logger from app.core.logging_config import get_config_logger, get_logger
from app.core.memory.agent.langgraph_graph.read_graph import make_read_graph from app.core.memory.agent.langgraph_graph.read_graph import make_read_graph
from app.core.memory.agent.langgraph_graph.write_graph import make_write_graph from app.core.memory.agent.langgraph_graph.write_graph import make_write_graph
from app.core.memory.agent.logger_file.log_streamer import LogStreamer from app.core.memory.agent.logger_file.log_streamer import LogStreamer
from app.core.memory.agent.utils.messages_tools import merge_multiple_search_results, reorder_output_results from app.core.memory.agent.utils.messages_tools import (
merge_multiple_search_results,
reorder_output_results,
)
from app.core.memory.agent.utils.type_classifier import status_typle from app.core.memory.agent.utils.type_classifier import status_typle
from app.core.memory.agent.utils.write_tools import write # 新增:直接导入 write 函数 from app.core.memory.agent.utils.write_tools import write # 新增:直接导入 write 函数
from app.core.memory.analytics.hot_memory_tags import get_hot_memory_tags from app.core.memory.analytics.hot_memory_tags import get_hot_memory_tags
@@ -33,6 +35,7 @@ from app.services.memory_config_service import MemoryConfigService
from app.services.memory_konwledges_server import ( from app.services.memory_konwledges_server import (
write_rag, write_rag,
) )
from langchain_core.messages import HumanMessage
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from sqlalchemy import func from sqlalchemy import func
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
@@ -404,6 +407,7 @@ class MemoryAgentService:
import time import time
start_time = time.time() start_time = time.time()
logger.info(f"[PERF] read_memory started for group_id={group_id}, search_switch={search_switch}")
# Resolve config_id if None using end_user's connected config # Resolve config_id if None using end_user's connected config
if config_id is None: if config_id is None:
@@ -427,13 +431,15 @@ class MemoryAgentService:
audit_logger = None audit_logger = None
config_load_start = time.time()
try: try:
config_service = MemoryConfigService(db) config_service = MemoryConfigService(db)
memory_config = config_service.load_memory_config( memory_config = config_service.load_memory_config(
config_id=config_id, config_id=config_id,
service_name="MemoryAgentService" service_name="MemoryAgentService"
) )
logger.info(f"Configuration loaded successfully: {memory_config.config_name}") config_load_time = time.time() - config_load_start
logger.info(f"[PERF] Configuration loaded in {config_load_time:.4f}s: {memory_config.config_name}")
except ConfigurationError as e: except ConfigurationError as e:
error_msg = f"Failed to load configuration for config_id: {config_id}: {e}" error_msg = f"Failed to load configuration for config_id: {config_id}: {e}"
logger.error(error_msg) logger.error(error_msg)
@@ -457,6 +463,7 @@ class MemoryAgentService:
logger.debug(f"Group ID:{group_id}, Message:{message}, History:{history}, Config ID:{config_id}") logger.debug(f"Group ID:{group_id}, Message:{message}, History:{history}, Config ID:{config_id}")
# Step 3: Initialize MCP client and execute read workflow # Step 3: Initialize MCP client and execute read workflow
graph_exec_start = time.time()
try: try:
async with make_read_graph() as graph: async with make_read_graph() as graph:
config = {"configurable": {"thread_id": group_id}} config = {"configurable": {"thread_id": group_id}}
@@ -513,6 +520,9 @@ class MemoryAgentService:
if summary_n and summary_n != [] and summary_n != {}: if summary_n and summary_n != [] and summary_n != {}:
_intermediate_outputs.append(summary_n) _intermediate_outputs.append(summary_n)
graph_exec_time = time.time() - graph_exec_start
logger.info(f"[PERF] Graph execution completed in {graph_exec_time:.4f}s")
_intermediate_outputs = [item for item in _intermediate_outputs if item and item != [] and item != {}] _intermediate_outputs = [item for item in _intermediate_outputs if item and item != [] and item != {}]
optimized_outputs = merge_multiple_search_results(_intermediate_outputs) optimized_outputs = merge_multiple_search_results(_intermediate_outputs)
@@ -570,6 +580,8 @@ class MemoryAgentService:
logger.error(f"保存短期记忆失败: {str(save_error)}", exc_info=True) logger.error(f"保存短期记忆失败: {str(save_error)}", exc_info=True)
# Log successful operation # Log successful operation
total_time = time.time() - start_time
logger.info(f"[PERF] read_memory completed successfully in {total_time:.4f}s (config: {config_load_time:.4f}s, graph: {graph_exec_time:.4f}s)")
if audit_logger: if audit_logger:
duration = time.time() - start_time duration = time.time() - start_time
audit_logger.log_operation( audit_logger.log_operation(
@@ -587,7 +599,8 @@ class MemoryAgentService:
except Exception as e: except Exception as e:
# Ensure proper error handling and logging # Ensure proper error handling and logging
error_msg = f"Read operation failed: {str(e)}" error_msg = f"Read operation failed: {str(e)}"
logger.error(error_msg) total_time = time.time() - start_time
logger.error(f"[PERF] read_memory failed after {total_time:.4f}s: {error_msg}")
if audit_logger: if audit_logger:
duration = time.time() - start_time duration = time.time() - start_time
audit_logger.log_operation( audit_logger.log_operation(

View File

@@ -125,7 +125,11 @@ class MemoryConfigService:
try: try:
validated_config_id = _validate_config_id(config_id) validated_config_id = _validate_config_id(config_id)
# Step 1: Get config and workspace
db_query_start = time.time()
result = DataConfigRepository.get_config_with_workspace(self.db, validated_config_id) result = DataConfigRepository.get_config_with_workspace(self.db, validated_config_id)
db_query_time = time.time() - db_query_start
logger.info(f"[PERF] Config+Workspace query: {db_query_time:.4f}s")
if not result: if not result:
elapsed_ms = (time.time() - start_time) * 1000 elapsed_ms = (time.time() - start_time) * 1000
config_logger.error( config_logger.error(
@@ -144,16 +148,20 @@ class MemoryConfigService:
memory_config, workspace = result memory_config, workspace = result
# Validate embedding model # Step 2: Validate embedding model (returns both UUID and name)
embedding_uuid = validate_embedding_model( embed_start = time.time()
embedding_uuid, embedding_name = validate_embedding_model(
validated_config_id, validated_config_id,
memory_config.embedding_id, memory_config.embedding_id,
self.db, self.db,
workspace.tenant_id, workspace.tenant_id,
workspace.id, workspace.id,
) )
embed_time = time.time() - embed_start
logger.info(f"[PERF] Embedding validation: {embed_time:.4f}s")
# Resolve LLM model # Step 3: Resolve LLM model
llm_start = time.time()
llm_uuid, llm_name = validate_and_resolve_model_id( llm_uuid, llm_name = validate_and_resolve_model_id(
memory_config.llm_id, memory_config.llm_id,
"llm", "llm",
@@ -163,8 +171,11 @@ class MemoryConfigService:
config_id=validated_config_id, config_id=validated_config_id,
workspace_id=workspace.id, workspace_id=workspace.id,
) )
llm_time = time.time() - llm_start
logger.info(f"[PERF] LLM validation: {llm_time:.4f}s")
# Resolve optional rerank model # Step 4: Resolve optional rerank model
rerank_start = time.time()
rerank_uuid = None rerank_uuid = None
rerank_name = None rerank_name = None
if memory_config.rerank_id: if memory_config.rerank_id:
@@ -177,16 +188,12 @@ class MemoryConfigService:
config_id=validated_config_id, config_id=validated_config_id,
workspace_id=workspace.id, workspace_id=workspace.id,
) )
rerank_time = time.time() - rerank_start
if memory_config.rerank_id:
logger.info(f"[PERF] Rerank validation: {rerank_time:.4f}s")
# Get embedding model name # Note: embedding_name is now returned from validate_embedding_model above
embedding_name, _ = validate_model_exists_and_active( # No need for redundant query!
embedding_uuid,
"embedding",
self.db,
workspace.tenant_id,
config_id=validated_config_id,
workspace_id=workspace.id,
)
# Create immutable MemoryConfig object # Create immutable MemoryConfig object
config = MemoryConfig( config = MemoryConfig(

View File

@@ -425,24 +425,7 @@ def read_message_task(self, group_id: str, message: str, history: List[Dict[str,
db.close() db.close()
try: try:
# 使用 nest_asyncio 来避免事件循环冲突 result = asyncio.run(_run())
try:
import nest_asyncio
nest_asyncio.apply()
except ImportError:
pass
# 尝试获取现有事件循环,如果不存在则创建新的
try:
loop = asyncio.get_event_loop()
if loop.is_closed():
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
except RuntimeError:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
result = loop.run_until_complete(_run())
elapsed_time = time.time() - start_time elapsed_time = time.time() - start_time
return { return {
@@ -455,7 +438,6 @@ def read_message_task(self, group_id: str, message: str, history: List[Dict[str,
} }
except BaseException as e: except BaseException as e:
elapsed_time = time.time() - start_time elapsed_time = time.time() - start_time
# Handle ExceptionGroup from TaskGroup
if hasattr(e, 'exceptions'): if hasattr(e, 'exceptions'):
error_messages = [f"{type(sub_e).__name__}: {str(sub_e)}" for sub_e in e.exceptions] error_messages = [f"{type(sub_e).__name__}: {str(sub_e)}" for sub_e in e.exceptions]
detailed_error = "; ".join(error_messages) detailed_error = "; ".join(error_messages)
@@ -528,24 +510,7 @@ def write_message_task(self, group_id: str, message, config_id: str, storage_typ
db.close() db.close()
try: try:
# 使用 nest_asyncio 来避免事件循环冲突 result = asyncio.run(_run())
try:
import nest_asyncio
nest_asyncio.apply()
except ImportError:
pass
# 尝试获取现有事件循环,如果不存在则创建新的
try:
loop = asyncio.get_event_loop()
if loop.is_closed():
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
except RuntimeError:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
result = loop.run_until_complete(_run())
elapsed_time = time.time() - start_time elapsed_time = time.time() - start_time
logger.info(f"[CELERY WRITE] Task completed successfully - elapsed_time={elapsed_time:.2f}s, task_id={self.request.id}") logger.info(f"[CELERY WRITE] Task completed successfully - elapsed_time={elapsed_time:.2f}s, task_id={self.request.id}")
@@ -560,7 +525,6 @@ def write_message_task(self, group_id: str, message, config_id: str, storage_typ
} }
except BaseException as e: except BaseException as e:
elapsed_time = time.time() - start_time elapsed_time = time.time() - start_time
# Handle ExceptionGroup from TaskGroup
if hasattr(e, 'exceptions'): if hasattr(e, 'exceptions'):
error_messages = [f"{type(sub_e).__name__}: {str(sub_e)}" for sub_e in e.exceptions] error_messages = [f"{type(sub_e).__name__}: {str(sub_e)}" for sub_e in e.exceptions]
detailed_error = "; ".join(error_messages) detailed_error = "; ".join(error_messages)
@@ -600,53 +564,53 @@ def reflection_timer_task() -> None:
""" """
reflection_engine() reflection_engine()
# unused task
@celery_app.task(name="app.core.memory.agent.health.check_read_service") # @celery_app.task(name="app.core.memory.agent.health.check_read_service")
def check_read_service_task() -> Dict[str, str]: # def check_read_service_task() -> Dict[str, str]:
"""Call read_service and write latest status to Redis. # """Call read_service and write latest status to Redis.
Returns status data dict that gets written to Redis. # Returns status data dict that gets written to Redis.
""" # """
client = redis.Redis( # client = redis.Redis(
host=settings.REDIS_HOST, # host=settings.REDIS_HOST,
port=settings.REDIS_PORT, # port=settings.REDIS_PORT,
db=settings.REDIS_DB, # db=settings.REDIS_DB,
password=settings.REDIS_PASSWORD if settings.REDIS_PASSWORD else None # password=settings.REDIS_PASSWORD if settings.REDIS_PASSWORD else None
) # )
try: # try:
api_url = f"http://{settings.SERVER_IP}:8000/api/memory/read_service" # api_url = f"http://{settings.SERVER_IP}:8000/api/memory/read_service"
payload = { # payload = {
"user_id": "健康检查", # "user_id": "健康检查",
"apply_id": "健康检查", # "apply_id": "健康检查",
"group_id": "健康检查", # "group_id": "健康检查",
"message": "你好", # "message": "你好",
"history": [], # "history": [],
"search_switch": "2", # "search_switch": "2",
} # }
resp = requests.post(api_url, json=payload, timeout=15) # resp = requests.post(api_url, json=payload, timeout=15)
ok = resp.status_code == 200 # ok = resp.status_code == 200
status = "Success" if ok else "Fail" # status = "Success" if ok else "Fail"
msg = "接口请求成功" if ok else f"接口请求失败: {resp.status_code}" # msg = "接口请求成功" if ok else f"接口请求失败: {resp.status_code}"
error = "" if ok else resp.text # error = "" if ok else resp.text
code = 0 if ok else 500 # code = 0 if ok else 500
except Exception as e: # except Exception as e:
status = "Fail" # status = "Fail"
msg = "接口请求失败" # msg = "接口请求失败"
error = str(e) # error = str(e)
code = 500 # code = 500
data = { # data = {
"status": status, # "status": status,
"msg": msg, # "msg": msg,
"error": error, # "error": error,
"code": str(code), # "code": str(code),
"time": str(int(time.time())), # "time": str(int(time.time())),
} # }
client.hset("memsci:health:read_service", mapping=data) # client.hset("memsci:health:read_service", mapping=data)
client.expire("memsci:health:read_service", int(settings.HEALTH_CHECK_SECONDS)) # client.expire("memsci:health:read_service", int(settings.HEALTH_CHECK_SECONDS))
return data # return data
@celery_app.task(name="app.controllers.memory_storage_controller.search_all") @celery_app.task(name="app.controllers.memory_storage_controller.search_all")
@@ -911,24 +875,7 @@ def regenerate_memory_cache(self) -> Dict[str, Any]:
} }
try: try:
# 使用 nest_asyncio 来避免事件循环冲突 result = asyncio.run(_run())
try:
import nest_asyncio
nest_asyncio.apply()
except ImportError:
pass
# 尝试获取现有事件循环,如果不存在则创建新的
try:
loop = asyncio.get_event_loop()
if loop.is_closed():
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
except RuntimeError:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
result = loop.run_until_complete(_run())
elapsed_time = time.time() - start_time elapsed_time = time.time() - start_time
result["elapsed_time"] = elapsed_time result["elapsed_time"] = elapsed_time
result["task_id"] = self.request.id result["task_id"] = self.request.id
@@ -1055,24 +1002,7 @@ def workspace_reflection_task(self) -> Dict[str, Any]:
} }
try: try:
# 使用 nest_asyncio 来避免事件循环冲突 result = asyncio.run(_run())
try:
import nest_asyncio
nest_asyncio.apply()
except ImportError:
pass
# 尝试获取现有事件循环,如果不存在则创建新的
try:
loop = asyncio.get_event_loop()
if loop.is_closed():
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
except RuntimeError:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
result = loop.run_until_complete(_run())
elapsed_time = time.time() - start_time elapsed_time = time.time() - start_time
result["elapsed_time"] = elapsed_time result["elapsed_time"] = elapsed_time
result["task_id"] = self.request.id result["task_id"] = self.request.id
@@ -1148,11 +1078,4 @@ def run_forgetting_cycle_task(self, config_id: Optional[int] = None) -> Dict[str
"duration_seconds": duration "duration_seconds": duration
} }
# 运行异步函数 return asyncio.run(_run())
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
result = loop.run_until_complete(_run())
return result
finally:
loop.close()

View File

@@ -7,10 +7,6 @@ services:
- "8002:8000" - "8002:8000"
env_file: env_file:
- .env - .env
environment:
- SERVER_IP=0.0.0.0
# 如果代码里必须要 MCP_SERVER_URL可以先注释或指向占位
# - MCP_SERVER_URL=
volumes: volumes:
- ./files:/files - ./files:/files
- /etc/localtime:/etc/localtime:ro - /etc/localtime:/etc/localtime:ro
@@ -19,20 +15,53 @@ services:
networks: networks:
- default - default
- celery - celery
depends_on:
- worker-memory
- worker-document
# Celery worker # Memory worker - Memory read/write tasks (threads pool for asyncio)
worker: worker-memory:
image: redbear-mem-open:latest image: redbear-mem-open:latest
container_name: worker container_name: worker-memory
env_file: env_file:
- .env - .env
volumes: volumes:
- ./files:/files - ./files:/files
- /etc/localtime:/etc/localtime:ro - /etc/localtime:/etc/localtime:ro
command: celery -A app.celery_worker.celery_app worker --loglevel=info command: celery -A app.celery_worker.celery_app worker -E --loglevel=info --pool=threads --concurrency=100 --queues=memory_tasks -n memory_worker@%h
restart: unless-stopped restart: unless-stopped
networks: networks:
- celery - celery
# Document worker - Document parsing tasks (prefork for CPU-bound)
worker-document:
image: redbear-mem-open:latest
container_name: worker-document
env_file:
- .env
volumes:
- ./files:/files
- /etc/localtime:/etc/localtime:ro
command: celery -A app.celery_worker.celery_app worker -E --loglevel=info --pool=prefork --concurrency=4 --queues=document_tasks --max-tasks-per-child=100 -n document_worker@%h
restart: unless-stopped
networks:
- celery
# Celery Beat - scheduler
beat:
image: redbear-mem-open:latest
container_name: celery-beat
env_file:
- .env
volumes:
- ./files:/files
- /etc/localtime:/etc/localtime:ro
command: celery -A app.celery_worker.celery_app beat --loglevel=info
restart: unless-stopped
networks:
- celery
depends_on:
- worker-memory
networks: networks:
celery: celery:

View File

@@ -13,6 +13,7 @@ dependencies = [
"bcrypt==5.0.0", "bcrypt==5.0.0",
"billiard==4.2.2", "billiard==4.2.2",
"celery==5.5.3", "celery==5.5.3",
"flower==2.0.1",
"cffi==2.0.0", "cffi==2.0.0",
"click==8.3.0", "click==8.3.0",
"click-didyoumean==0.3.1", "click-didyoumean==0.3.1",
@@ -138,6 +139,7 @@ dependencies = [
"python-calamine>=0.4.0", "python-calamine>=0.4.0",
"xlrd==2.0.2", "xlrd==2.0.2",
"deprecated>=1.3.1", "deprecated>=1.3.1",
"flower>=2.0.1",
] ]
[tool.pytest.ini_options] [tool.pytest.ini_options]

View File

@@ -6,6 +6,7 @@ async-timeout==5.0.1
bcrypt==5.0.0 bcrypt==5.0.0
billiard==4.2.2 billiard==4.2.2
celery==5.5.3 celery==5.5.3
flower==2.0.1
cffi==2.0.0 cffi==2.0.0
click==8.3.0 click==8.3.0
click-didyoumean==0.3.1 click-didyoumean==0.3.1