Merge remote-tracking branch 'origin/develop' into develop
This commit is contained in:
13
README.md
13
README.md
@@ -334,7 +334,12 @@ step6: Log In to the Frontend Interface.
|
||||
## License
|
||||
This project is licensed under the Apache License 2.0. For details, see the LICENSE file.
|
||||
|
||||
## Acknowledgements & Community
|
||||
- Feedback & Issues: Please submit an Issue in the repository for bug reports or discussions.
|
||||
- Contributions Welcome: When submitting a Pull Request, please create a feature branch and follow conventional commit message guidelines.
|
||||
- Contact: If you are interested in contributing or collaborating, feel free to reach out at tianyou_hubm@redbearai.com
|
||||
## Community & Support
|
||||
|
||||
Join our community to ask questions, share your work, and connect with fellow developers.
|
||||
|
||||
- **GitHub Issues**: Report bugs, request features, or track known issues via [GitHub Issues](https://github.com/redbear-ai/memorybear/issues).
|
||||
- **GitHub Pull Requests**: Contribute code improvements or fixes through [Pull Requests](https://github.com/redbear-ai/memorybear/pulls).
|
||||
- **GitHub Discussions**: Ask questions, share ideas, and engage with the community in [GitHub Discussions](https://github.com/redbear-ai/memorybear/discussions).
|
||||
- **WeChat**: Scan the QR code below to join our WeChat community group.
|
||||
- **Contact**: If you are interested in contributing or collaborating, feel free to reach out at tianyou_hubm@redbearai.com
|
||||
@@ -1,4 +1,5 @@
|
||||
import os
|
||||
import platform
|
||||
from datetime import timedelta
|
||||
from urllib.parse import quote
|
||||
|
||||
@@ -14,27 +15,12 @@ celery_app = Celery(
|
||||
backend=f"redis://:{quote(settings.REDIS_PASSWORD)}@{settings.REDIS_HOST}:{settings.REDIS_PORT}/{settings.CELERY_BACKEND}",
|
||||
)
|
||||
|
||||
# 配置使用本地队列,避免与远程 worker 冲突
|
||||
celery_app.conf.task_default_queue = 'localhost_test_wyl'
|
||||
celery_app.conf.task_default_exchange = 'localhost_test_wyl'
|
||||
celery_app.conf.task_default_routing_key = 'localhost_test_wyl'
|
||||
# Default queue for unrouted tasks
|
||||
celery_app.conf.task_default_queue = 'memory_tasks'
|
||||
|
||||
# macOS 兼容性配置
|
||||
import platform
|
||||
|
||||
if platform.system() == 'Darwin': # macOS
|
||||
# 设置环境变量解决 fork 问题
|
||||
if platform.system() == 'Darwin':
|
||||
os.environ.setdefault('OBJC_DISABLE_INITIALIZE_FORK_SAFETY', 'YES')
|
||||
|
||||
# 使用 solo 池避免多进程问题
|
||||
celery_app.conf.worker_pool = 'solo'
|
||||
|
||||
# 设置唯一的节点名称
|
||||
import socket
|
||||
import time
|
||||
hostname = socket.gethostname()
|
||||
timestamp = int(time.time())
|
||||
celery_app.conf.worker_name = f"celery@{hostname}-{timestamp}"
|
||||
|
||||
# Celery 配置
|
||||
celery_app.conf.update(
|
||||
@@ -52,36 +38,47 @@ celery_app.conf.update(
|
||||
task_ignore_result=False,
|
||||
|
||||
# 超时设置
|
||||
task_time_limit=30 * 60, # 30 分钟硬超时
|
||||
task_soft_time_limit=25 * 60, # 25 分钟软超时
|
||||
task_time_limit=1800, # 30分钟硬超时
|
||||
task_soft_time_limit=1500, # 25分钟软超时
|
||||
|
||||
# Worker 设置 - 针对 macOS 优化
|
||||
worker_prefetch_multiplier=1, # 减少预取任务数,避免内存堆积
|
||||
worker_max_tasks_per_child=10, # 大幅减少每个 worker 执行的任务数,频繁重启防止内存泄漏
|
||||
worker_max_memory_per_child=200000, # 200MB 内存限制,超过后重启 worker
|
||||
# Worker 设置 (per-worker settings are in docker-compose command line)
|
||||
worker_prefetch_multiplier=1, # Don't hoard tasks, fairer distribution
|
||||
|
||||
# 结果过期时间
|
||||
result_expires=3600, # 结果保存 1 小时
|
||||
result_expires=3600, # 结果保存1小时
|
||||
|
||||
# 任务确认设置
|
||||
task_acks_late=True, # 任务完成后才确认,避免任务丢失
|
||||
worker_disable_rate_limits=True, # 禁用速率限制
|
||||
task_acks_late=True,
|
||||
task_reject_on_worker_lost=True,
|
||||
worker_disable_rate_limits=True,
|
||||
|
||||
# 任务路由(可选,用于不同队列)
|
||||
# task_routes={
|
||||
# 'app.core.rag.tasks.parse_document': {'queue': 'document_processing'},
|
||||
# 'app.core.memory.agent.read_message': {'queue': 'memory_processing'},
|
||||
# 'app.core.memory.agent.write_message': {'queue': 'memory_processing'},
|
||||
# 'tasks.process_item': {'queue': 'default'},
|
||||
# },
|
||||
# FLower setting
|
||||
worker_send_task_events=True,
|
||||
task_send_sent_event=True,
|
||||
|
||||
# task routing
|
||||
task_routes={
|
||||
# Memory tasks → memory_tasks queue (threads worker)
|
||||
'app.core.memory.agent.read_message_priority': {'queue': 'memory_tasks'},
|
||||
'app.core.memory.agent.read_message': {'queue': 'memory_tasks'},
|
||||
'app.core.memory.agent.write_message': {'queue': 'memory_tasks'},
|
||||
|
||||
# Document tasks → document_tasks queue (prefork worker)
|
||||
'app.core.rag.tasks.parse_document': {'queue': 'document_tasks'},
|
||||
'app.core.rag.tasks.build_graphrag_for_kb': {'queue': 'document_tasks'},
|
||||
|
||||
# Beat/periodic tasks → document_tasks queue (prefork worker)
|
||||
'app.tasks.workspace_reflection_task': {'queue': 'document_tasks'},
|
||||
'app.tasks.regenerate_memory_cache': {'queue': 'document_tasks'},
|
||||
'app.tasks.run_forgetting_cycle_task': {'queue': 'document_tasks'},
|
||||
'app.controllers.memory_storage_controller.search_all': {'queue': 'document_tasks'},
|
||||
},
|
||||
)
|
||||
|
||||
# 自动发现任务模块
|
||||
celery_app.autodiscover_tasks(['app'])
|
||||
|
||||
# Celery Beat schedule for periodic tasks
|
||||
reflection_schedule = timedelta(seconds=settings.REFLECTION_INTERVAL_SECONDS)
|
||||
health_schedule = timedelta(seconds=settings.HEALTH_CHECK_SECONDS)
|
||||
memory_increment_schedule = timedelta(hours=settings.MEMORY_INCREMENT_INTERVAL_HOURS)
|
||||
memory_cache_regeneration_schedule = timedelta(hours=settings.MEMORY_CACHE_REGENERATION_HOURS)
|
||||
workspace_reflection_schedule = timedelta(seconds=30) # 每30秒运行一次settings.REFLECTION_INTERVAL_TIME
|
||||
@@ -89,12 +86,6 @@ forgetting_cycle_schedule = timedelta(hours=24) # 每24小时运行一次遗忘
|
||||
|
||||
# 构建定时任务配置
|
||||
beat_schedule_config = {
|
||||
|
||||
# "check-read-service": {
|
||||
# "task": "app.core.memory.agent.health.check_read_service",
|
||||
# "schedule": health_schedule,
|
||||
# "args": (),
|
||||
# },
|
||||
"run-workspace-reflection": {
|
||||
"task": "app.tasks.workspace_reflection_task",
|
||||
"schedule": workspace_reflection_schedule,
|
||||
|
||||
@@ -4,12 +4,11 @@ import os
|
||||
import time
|
||||
|
||||
from app.core.logging_config import get_agent_logger, log_time
|
||||
from app.db import get_db
|
||||
|
||||
from app.core.memory.agent.models.summary_models import (
|
||||
RetrieveSummaryResponse,
|
||||
SummaryResponse,
|
||||
)
|
||||
from app.core.memory.agent.services.optimized_llm_service import LLMServiceMixin
|
||||
from app.core.memory.agent.services.search_service import SearchService
|
||||
from app.core.memory.agent.utils.llm_tools import (
|
||||
PROJECT_ROOT_,
|
||||
@@ -18,7 +17,7 @@ from app.core.memory.agent.utils.llm_tools import (
|
||||
from app.core.memory.agent.utils.redis_tool import store
|
||||
from app.core.memory.agent.utils.session_tools import SessionService
|
||||
from app.core.memory.agent.utils.template_tools import TemplateService
|
||||
from app.core.memory.agent.services.optimized_llm_service import LLMServiceMixin
|
||||
from app.db import get_db
|
||||
|
||||
template_root = os.path.join(PROJECT_ROOT_, 'agent', 'utils', 'prompt')
|
||||
logger = get_agent_logger(__name__)
|
||||
@@ -182,7 +181,8 @@ async def Input_Summary(state: ReadState) -> ReadState:
|
||||
search_params = {
|
||||
"group_id": group_id,
|
||||
"question": data,
|
||||
"return_raw_results": True
|
||||
"return_raw_results": True,
|
||||
"include": ["summaries"] # Only search summary nodes for faster performance
|
||||
}
|
||||
|
||||
try:
|
||||
|
||||
@@ -89,14 +89,15 @@ def validate_model_exists_and_active(
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
# First check if model exists at all (without tenant filtering)
|
||||
model_without_tenant = ModelConfigRepository.get_by_id(db, model_id, tenant_id=None)
|
||||
|
||||
# Then check with tenant filtering
|
||||
# OPTIMIZED: Single query with tenant filter
|
||||
# We'll check tenant mismatch in the error handling
|
||||
model = ModelConfigRepository.get_by_id(db, model_id, tenant_id)
|
||||
elapsed_ms = (time.time() - start_time) * 1000
|
||||
|
||||
if not model:
|
||||
# Model not found with tenant filter - check if it exists without filter
|
||||
model_without_tenant = ModelConfigRepository.get_by_id(db, model_id, tenant_id=None)
|
||||
|
||||
if model_without_tenant:
|
||||
# Model exists but belongs to different tenant
|
||||
logger.warning(
|
||||
@@ -208,8 +209,11 @@ def validate_embedding_model(
|
||||
db: Session,
|
||||
tenant_id: Optional[UUID] = None,
|
||||
workspace_id: Optional[UUID] = None
|
||||
) -> UUID:
|
||||
"""Validate that embedding model is available and return its UUID.
|
||||
) -> tuple[UUID, str]:
|
||||
"""Validate that embedding model is available and return its UUID and name.
|
||||
|
||||
Returns:
|
||||
Tuple of (embedding_uuid, embedding_name)
|
||||
|
||||
Raises:
|
||||
InvalidConfigError: If embedding_id is not provided or invalid
|
||||
@@ -225,14 +229,19 @@ def validate_embedding_model(
|
||||
workspace_id=workspace_id
|
||||
)
|
||||
|
||||
embedding_uuid, _ = validate_and_resolve_model_id(
|
||||
embedding_uuid, embedding_name = validate_and_resolve_model_id(
|
||||
embedding_id, "embedding", db, tenant_id, required=True,
|
||||
config_id=config_id, workspace_id=workspace_id
|
||||
)
|
||||
print(100*'-')
|
||||
print(embedding_uuid)
|
||||
print(_)
|
||||
print(100*'-')
|
||||
|
||||
logger.debug(
|
||||
"Embedding model validated",
|
||||
extra={
|
||||
"embedding_uuid": str(embedding_uuid),
|
||||
"embedding_name": embedding_name,
|
||||
"config_id": config_id
|
||||
}
|
||||
)
|
||||
|
||||
if embedding_uuid is None:
|
||||
raise InvalidConfigError(
|
||||
@@ -243,7 +252,7 @@ def validate_embedding_model(
|
||||
workspace_id=workspace_id
|
||||
)
|
||||
|
||||
return embedding_uuid
|
||||
return embedding_uuid, embedding_name
|
||||
|
||||
|
||||
def validate_llm_model(
|
||||
|
||||
@@ -305,12 +305,19 @@ async def search_graph(
|
||||
results[key] = _deduplicate_results(results[key])
|
||||
|
||||
# 更新知识节点的激活值(Statement, ExtractedEntity, MemorySummary)
|
||||
results = await _update_search_results_activation(
|
||||
connector=connector,
|
||||
results=results,
|
||||
group_id=group_id
|
||||
# Skip activation updates if only searching summaries (optimization)
|
||||
needs_activation_update = any(
|
||||
key in include and key in results and results[key]
|
||||
for key in ['statements', 'entities', 'chunks']
|
||||
)
|
||||
|
||||
if needs_activation_update:
|
||||
results = await _update_search_results_activation(
|
||||
connector=connector,
|
||||
results=results,
|
||||
group_id=group_id
|
||||
)
|
||||
|
||||
return results
|
||||
|
||||
|
||||
@@ -339,7 +346,7 @@ async def search_graph_by_embedding(
|
||||
embed_start = time.time()
|
||||
embeddings = await embedder_client.response([query_text])
|
||||
embed_time = time.time() - embed_start
|
||||
print(f"[PERF] Embedding generation took: {embed_time:.4f}s")
|
||||
logger.info(f"[PERF] Embedding generation took: {embed_time:.4f}s")
|
||||
|
||||
if not embeddings or not embeddings[0]:
|
||||
return {"statements": [], "chunks": [], "entities": [], "summaries": []}
|
||||
@@ -393,7 +400,7 @@ async def search_graph_by_embedding(
|
||||
query_start = time.time()
|
||||
task_results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
query_time = time.time() - query_start
|
||||
print(f"[PERF] Neo4j queries (parallel) took: {query_time:.4f}s")
|
||||
logger.info(f"[PERF] Neo4j queries (parallel) took: {query_time:.4f}s")
|
||||
|
||||
# Build results dictionary
|
||||
results: Dict[str, List[Dict[str, Any]]] = {
|
||||
@@ -417,14 +424,23 @@ async def search_graph_by_embedding(
|
||||
results[key] = _deduplicate_results(results[key])
|
||||
|
||||
# 更新知识节点的激活值(Statement, ExtractedEntity, MemorySummary)
|
||||
update_start = time.time()
|
||||
results = await _update_search_results_activation(
|
||||
connector=connector,
|
||||
results=results,
|
||||
group_id=group_id
|
||||
# Skip activation updates if only searching summaries (optimization)
|
||||
needs_activation_update = any(
|
||||
key in include and key in results and results[key]
|
||||
for key in ['statements', 'entities', 'chunks']
|
||||
)
|
||||
update_time = time.time() - update_start
|
||||
print(f"[PERF] Activation value updates took: {update_time:.4f}s")
|
||||
|
||||
if needs_activation_update:
|
||||
update_start = time.time()
|
||||
results = await _update_search_results_activation(
|
||||
connector=connector,
|
||||
results=results,
|
||||
group_id=group_id
|
||||
)
|
||||
update_time = time.time() - update_start
|
||||
logger.info(f"[PERF] Activation value updates took: {update_time:.4f}s")
|
||||
else:
|
||||
logger.info(f"[PERF] Skipping activation updates (only summaries)")
|
||||
|
||||
return results
|
||||
async def get_dedup_candidates_for_entities( # 适配新版查询:使用全文索引按名称检索候选实体
|
||||
@@ -535,7 +551,7 @@ async def search_graph_by_keyword_temporal(
|
||||
- Returns up to 'limit' statements
|
||||
"""
|
||||
if not query_text:
|
||||
print(f"query_text不能为空")
|
||||
logger.warning(f"query_text cannot be empty")
|
||||
return {"statements": []}
|
||||
statements = await connector.execute_query(
|
||||
SEARCH_STATEMENTS_BY_KEYWORD_TEMPORAL,
|
||||
@@ -549,7 +565,7 @@ async def search_graph_by_keyword_temporal(
|
||||
invalid_date=invalid_date,
|
||||
limit=limit,
|
||||
)
|
||||
print(f"查询结果为:\n{statements}")
|
||||
logger.debug(f"Temporal keyword search results: {len(statements)} statements found")
|
||||
|
||||
# 更新 Statement 节点的激活值
|
||||
results = {"statements": statements}
|
||||
@@ -594,9 +610,9 @@ async def search_graph_by_temporal(
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
print(f"查询语句为:\n{SEARCH_STATEMENTS_BY_TEMPORAL}")
|
||||
print(f"查询参数为:\n{{group_id: {group_id}, apply_id: {apply_id}, user_id: {user_id}, start_date: {start_date}, end_date: {end_date}, valid_date: {valid_date}, invalid_date: {invalid_date}, limit: {limit}}}")
|
||||
print(f"查询结果为:\n{statements}")
|
||||
logger.debug(f"Temporal search query: {SEARCH_STATEMENTS_BY_TEMPORAL}")
|
||||
logger.debug(f"Query params: group_id={group_id}, apply_id={apply_id}, user_id={user_id}, start_date={start_date}, end_date={end_date}, valid_date={valid_date}, invalid_date={invalid_date}, limit={limit}")
|
||||
logger.debug(f"Temporal search results: {len(statements)} statements found")
|
||||
|
||||
# 更新 Statement 节点的激活值
|
||||
results = {"statements": statements}
|
||||
@@ -623,7 +639,7 @@ async def search_graph_by_dialog_id(
|
||||
- Returns up to 'limit' dialogues
|
||||
"""
|
||||
if not dialog_id:
|
||||
print(f"dialog_id不能为空")
|
||||
logger.warning(f"dialog_id cannot be empty")
|
||||
return {"dialogues": []}
|
||||
|
||||
dialogues = await connector.execute_query(
|
||||
@@ -642,7 +658,7 @@ async def search_graph_by_chunk_id(
|
||||
limit: int = 1,
|
||||
) -> Dict[str, List[Dict[str, Any]]]:
|
||||
if not chunk_id:
|
||||
print(f"chunk_id不能为空")
|
||||
logger.warning(f"chunk_id cannot be empty")
|
||||
return {"chunks": []}
|
||||
chunks = await connector.execute_query(
|
||||
SEARCH_CHUNK_BY_CHUNK_ID,
|
||||
@@ -679,9 +695,9 @@ async def search_graph_by_created_at(
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
print(f"查询语句为:\n{SEARCH_STATEMENTS_BY_CREATED_AT}")
|
||||
print(f"查询参数为:\n{{group_id: {group_id}, apply_id: {apply_id}, user_id: {user_id}, created_at: {created_at}, limit: {limit}}}")
|
||||
print(f"查询结果为:\n{statements}")
|
||||
logger.debug(f"Search by created_at query: {SEARCH_STATEMENTS_BY_CREATED_AT}")
|
||||
logger.debug(f"Query params: group_id={group_id}, apply_id={apply_id}, user_id={user_id}, created_at={created_at}, limit={limit}")
|
||||
logger.debug(f"Search results: {len(statements)} statements found")
|
||||
|
||||
# 更新 Statement 节点的激活值
|
||||
results = {"statements": statements}
|
||||
@@ -719,9 +735,9 @@ async def search_graph_by_valid_at(
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
print(f"查询语句为:\n{SEARCH_STATEMENTS_BY_VALID_AT}")
|
||||
print(f"查询参数为:\n{{group_id: {group_id}, apply_id: {apply_id}, user_id: {user_id}, valid_at: {valid_at}, limit: {limit}}}")
|
||||
print(f"查询结果为:\n{statements}")
|
||||
logger.debug(f"Search by valid_at query: {SEARCH_STATEMENTS_BY_VALID_AT}")
|
||||
logger.debug(f"Query params: group_id={group_id}, apply_id={apply_id}, user_id={user_id}, valid_at={valid_at}, limit={limit}")
|
||||
logger.debug(f"Search results: {len(statements)} statements found")
|
||||
|
||||
# 更新 Statement 节点的激活值
|
||||
results = {"statements": statements}
|
||||
@@ -759,9 +775,9 @@ async def search_graph_g_created_at(
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
print(f"查询语句为:\n{SEARCH_STATEMENTS_G_CREATED_AT}")
|
||||
print(f"查询参数为:\n{{group_id: {group_id}, apply_id: {apply_id}, user_id: {user_id}, created_at: {created_at}, limit: {limit}}}")
|
||||
print(f"查询结果为:\n{statements}")
|
||||
logger.debug(f"Search greater than created_at query: {SEARCH_STATEMENTS_G_CREATED_AT}")
|
||||
logger.debug(f"Query params: group_id={group_id}, apply_id={apply_id}, user_id={user_id}, created_at={created_at}, limit={limit}")
|
||||
logger.debug(f"Search results: {len(statements)} statements found")
|
||||
|
||||
# 更新 Statement 节点的激活值
|
||||
results = {"statements": statements}
|
||||
@@ -799,9 +815,9 @@ async def search_graph_g_valid_at(
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
print(f"查询语句为:\n{SEARCH_STATEMENTS_G_VALID_AT}")
|
||||
print(f"查询参数为:\n{{group_id: {group_id}, apply_id: {apply_id}, user_id: {user_id}, valid_at: {valid_at}, limit: {limit}}}")
|
||||
print(f"查询结果为:\n{statements}")
|
||||
logger.debug(f"Search greater than valid_at query: {SEARCH_STATEMENTS_G_VALID_AT}")
|
||||
logger.debug(f"Query params: group_id={group_id}, apply_id={apply_id}, user_id={user_id}, valid_at={valid_at}, limit={limit}")
|
||||
logger.debug(f"Search results: {len(statements)} statements found")
|
||||
|
||||
# 更新 Statement 节点的激活值
|
||||
results = {"statements": statements}
|
||||
@@ -839,9 +855,9 @@ async def search_graph_l_created_at(
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
print(f"查询语句为:\n{SEARCH_STATEMENTS_L_CREATED_AT}")
|
||||
print(f"查询参数为:\n{{group_id: {group_id}, apply_id: {apply_id}, user_id: {user_id}, created_at: {created_at}, limit: {limit}}}")
|
||||
print(f"查询结果为:\n{statements}")
|
||||
logger.debug(f"Search less than created_at query: {SEARCH_STATEMENTS_L_CREATED_AT}")
|
||||
logger.debug(f"Query params: group_id={group_id}, apply_id={apply_id}, user_id={user_id}, created_at={created_at}, limit={limit}")
|
||||
logger.debug(f"Search results: {len(statements)} statements found")
|
||||
|
||||
# 更新 Statement 节点的激活值
|
||||
results = {"statements": statements}
|
||||
@@ -879,9 +895,9 @@ async def search_graph_l_valid_at(
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
print(f"查询语句为:\n{SEARCH_STATEMENTS_L_VALID_AT}")
|
||||
print(f"查询参数为:\n{{group_id: {group_id}, apply_id: {apply_id}, user_id: {user_id}, valid_at: {valid_at}, limit: {limit}}}")
|
||||
print(f"查询结果为:\n{statements}")
|
||||
logger.debug(f"Search less than valid_at query: {SEARCH_STATEMENTS_L_VALID_AT}")
|
||||
logger.debug(f"Query params: group_id={group_id}, apply_id={apply_id}, user_id={user_id}, valid_at={valid_at}, limit={limit}")
|
||||
logger.debug(f"Search results: {len(statements)} statements found")
|
||||
|
||||
# 更新 Statement 节点的激活值
|
||||
results = {"statements": statements}
|
||||
|
||||
@@ -10,11 +10,6 @@ import time
|
||||
import uuid
|
||||
from typing import Any, AsyncGenerator, Dict, List, Optional
|
||||
|
||||
from langchain.tools import tool
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.celery_app import celery_app
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.exceptions import BusinessException
|
||||
@@ -28,6 +23,10 @@ from app.services.langchain_tool_server import Search
|
||||
from app.services.memory_agent_service import MemoryAgentService
|
||||
from app.services.model_parameter_merger import ModelParameterMerger
|
||||
from app.services.tool_service import ToolService
|
||||
from langchain.tools import tool
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
logger = get_business_logger()
|
||||
class KnowledgeRetrievalInput(BaseModel):
|
||||
@@ -107,9 +106,9 @@ def create_long_term_memory_tool(memory_config: Dict[str, Any], end_user_id: str
|
||||
"app.core.memory.agent.read_message",
|
||||
args=[end_user_id, question, [], "1", config_id, storage_type, user_rag_memory_id]
|
||||
)
|
||||
result = task_service.get_task_memory_read_result(task.id)
|
||||
status = result.get("status")
|
||||
logger.info(f"读取任务状态:{status}")
|
||||
# result = task_service.get_task_memory_read_result(task.id)
|
||||
# status = result.get("status")
|
||||
# logger.info(f"读取任务状态:{status}")
|
||||
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
@@ -10,15 +10,17 @@ import re
|
||||
import time
|
||||
import uuid
|
||||
from typing import Any, AsyncGenerator, Dict, List, Optional
|
||||
import redis
|
||||
from langchain_core.messages import HumanMessage
|
||||
|
||||
import redis
|
||||
from app.core.config import settings
|
||||
from app.core.logging_config import get_config_logger, get_logger
|
||||
from app.core.memory.agent.langgraph_graph.read_graph import make_read_graph
|
||||
from app.core.memory.agent.langgraph_graph.write_graph import make_write_graph
|
||||
from app.core.memory.agent.logger_file.log_streamer import LogStreamer
|
||||
from app.core.memory.agent.utils.messages_tools import merge_multiple_search_results, reorder_output_results
|
||||
from app.core.memory.agent.utils.messages_tools import (
|
||||
merge_multiple_search_results,
|
||||
reorder_output_results,
|
||||
)
|
||||
from app.core.memory.agent.utils.type_classifier import status_typle
|
||||
from app.core.memory.agent.utils.write_tools import write # 新增:直接导入 write 函数
|
||||
from app.core.memory.analytics.hot_memory_tags import get_hot_memory_tags
|
||||
@@ -33,6 +35,7 @@ from app.services.memory_config_service import MemoryConfigService
|
||||
from app.services.memory_konwledges_server import (
|
||||
write_rag,
|
||||
)
|
||||
from langchain_core.messages import HumanMessage
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy.orm import Session
|
||||
@@ -404,6 +407,7 @@ class MemoryAgentService:
|
||||
|
||||
import time
|
||||
start_time = time.time()
|
||||
logger.info(f"[PERF] read_memory started for group_id={group_id}, search_switch={search_switch}")
|
||||
|
||||
# Resolve config_id if None using end_user's connected config
|
||||
if config_id is None:
|
||||
@@ -427,13 +431,15 @@ class MemoryAgentService:
|
||||
audit_logger = None
|
||||
|
||||
|
||||
config_load_start = time.time()
|
||||
try:
|
||||
config_service = MemoryConfigService(db)
|
||||
memory_config = config_service.load_memory_config(
|
||||
config_id=config_id,
|
||||
service_name="MemoryAgentService"
|
||||
)
|
||||
logger.info(f"Configuration loaded successfully: {memory_config.config_name}")
|
||||
config_load_time = time.time() - config_load_start
|
||||
logger.info(f"[PERF] Configuration loaded in {config_load_time:.4f}s: {memory_config.config_name}")
|
||||
except ConfigurationError as e:
|
||||
error_msg = f"Failed to load configuration for config_id: {config_id}: {e}"
|
||||
logger.error(error_msg)
|
||||
@@ -457,6 +463,7 @@ class MemoryAgentService:
|
||||
logger.debug(f"Group ID:{group_id}, Message:{message}, History:{history}, Config ID:{config_id}")
|
||||
|
||||
# Step 3: Initialize MCP client and execute read workflow
|
||||
graph_exec_start = time.time()
|
||||
try:
|
||||
async with make_read_graph() as graph:
|
||||
config = {"configurable": {"thread_id": group_id}}
|
||||
@@ -513,6 +520,9 @@ class MemoryAgentService:
|
||||
if summary_n and summary_n != [] and summary_n != {}:
|
||||
_intermediate_outputs.append(summary_n)
|
||||
|
||||
graph_exec_time = time.time() - graph_exec_start
|
||||
logger.info(f"[PERF] Graph execution completed in {graph_exec_time:.4f}s")
|
||||
|
||||
_intermediate_outputs = [item for item in _intermediate_outputs if item and item != [] and item != {}]
|
||||
|
||||
optimized_outputs = merge_multiple_search_results(_intermediate_outputs)
|
||||
@@ -570,6 +580,8 @@ class MemoryAgentService:
|
||||
logger.error(f"保存短期记忆失败: {str(save_error)}", exc_info=True)
|
||||
|
||||
# Log successful operation
|
||||
total_time = time.time() - start_time
|
||||
logger.info(f"[PERF] read_memory completed successfully in {total_time:.4f}s (config: {config_load_time:.4f}s, graph: {graph_exec_time:.4f}s)")
|
||||
if audit_logger:
|
||||
duration = time.time() - start_time
|
||||
audit_logger.log_operation(
|
||||
@@ -587,7 +599,8 @@ class MemoryAgentService:
|
||||
except Exception as e:
|
||||
# Ensure proper error handling and logging
|
||||
error_msg = f"Read operation failed: {str(e)}"
|
||||
logger.error(error_msg)
|
||||
total_time = time.time() - start_time
|
||||
logger.error(f"[PERF] read_memory failed after {total_time:.4f}s: {error_msg}")
|
||||
if audit_logger:
|
||||
duration = time.time() - start_time
|
||||
audit_logger.log_operation(
|
||||
|
||||
@@ -125,7 +125,11 @@ class MemoryConfigService:
|
||||
try:
|
||||
validated_config_id = _validate_config_id(config_id)
|
||||
|
||||
# Step 1: Get config and workspace
|
||||
db_query_start = time.time()
|
||||
result = DataConfigRepository.get_config_with_workspace(self.db, validated_config_id)
|
||||
db_query_time = time.time() - db_query_start
|
||||
logger.info(f"[PERF] Config+Workspace query: {db_query_time:.4f}s")
|
||||
if not result:
|
||||
elapsed_ms = (time.time() - start_time) * 1000
|
||||
config_logger.error(
|
||||
@@ -144,16 +148,20 @@ class MemoryConfigService:
|
||||
|
||||
memory_config, workspace = result
|
||||
|
||||
# Validate embedding model
|
||||
embedding_uuid = validate_embedding_model(
|
||||
# Step 2: Validate embedding model (returns both UUID and name)
|
||||
embed_start = time.time()
|
||||
embedding_uuid, embedding_name = validate_embedding_model(
|
||||
validated_config_id,
|
||||
memory_config.embedding_id,
|
||||
self.db,
|
||||
workspace.tenant_id,
|
||||
workspace.id,
|
||||
)
|
||||
embed_time = time.time() - embed_start
|
||||
logger.info(f"[PERF] Embedding validation: {embed_time:.4f}s")
|
||||
|
||||
# Resolve LLM model
|
||||
# Step 3: Resolve LLM model
|
||||
llm_start = time.time()
|
||||
llm_uuid, llm_name = validate_and_resolve_model_id(
|
||||
memory_config.llm_id,
|
||||
"llm",
|
||||
@@ -163,8 +171,11 @@ class MemoryConfigService:
|
||||
config_id=validated_config_id,
|
||||
workspace_id=workspace.id,
|
||||
)
|
||||
llm_time = time.time() - llm_start
|
||||
logger.info(f"[PERF] LLM validation: {llm_time:.4f}s")
|
||||
|
||||
# Resolve optional rerank model
|
||||
# Step 4: Resolve optional rerank model
|
||||
rerank_start = time.time()
|
||||
rerank_uuid = None
|
||||
rerank_name = None
|
||||
if memory_config.rerank_id:
|
||||
@@ -177,16 +188,12 @@ class MemoryConfigService:
|
||||
config_id=validated_config_id,
|
||||
workspace_id=workspace.id,
|
||||
)
|
||||
rerank_time = time.time() - rerank_start
|
||||
if memory_config.rerank_id:
|
||||
logger.info(f"[PERF] Rerank validation: {rerank_time:.4f}s")
|
||||
|
||||
# Get embedding model name
|
||||
embedding_name, _ = validate_model_exists_and_active(
|
||||
embedding_uuid,
|
||||
"embedding",
|
||||
self.db,
|
||||
workspace.tenant_id,
|
||||
config_id=validated_config_id,
|
||||
workspace_id=workspace.id,
|
||||
)
|
||||
# Note: embedding_name is now returned from validate_embedding_model above
|
||||
# No need for redundant query!
|
||||
|
||||
# Create immutable MemoryConfig object
|
||||
config = MemoryConfig(
|
||||
|
||||
173
api/app/tasks.py
173
api/app/tasks.py
@@ -425,24 +425,7 @@ def read_message_task(self, group_id: str, message: str, history: List[Dict[str,
|
||||
db.close()
|
||||
|
||||
try:
|
||||
# 使用 nest_asyncio 来避免事件循环冲突
|
||||
try:
|
||||
import nest_asyncio
|
||||
nest_asyncio.apply()
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
# 尝试获取现有事件循环,如果不存在则创建新的
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
if loop.is_closed():
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
except RuntimeError:
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
result = loop.run_until_complete(_run())
|
||||
result = asyncio.run(_run())
|
||||
elapsed_time = time.time() - start_time
|
||||
|
||||
return {
|
||||
@@ -455,7 +438,6 @@ def read_message_task(self, group_id: str, message: str, history: List[Dict[str,
|
||||
}
|
||||
except BaseException as e:
|
||||
elapsed_time = time.time() - start_time
|
||||
# Handle ExceptionGroup from TaskGroup
|
||||
if hasattr(e, 'exceptions'):
|
||||
error_messages = [f"{type(sub_e).__name__}: {str(sub_e)}" for sub_e in e.exceptions]
|
||||
detailed_error = "; ".join(error_messages)
|
||||
@@ -528,24 +510,7 @@ def write_message_task(self, group_id: str, message, config_id: str, storage_typ
|
||||
db.close()
|
||||
|
||||
try:
|
||||
# 使用 nest_asyncio 来避免事件循环冲突
|
||||
try:
|
||||
import nest_asyncio
|
||||
nest_asyncio.apply()
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
# 尝试获取现有事件循环,如果不存在则创建新的
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
if loop.is_closed():
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
except RuntimeError:
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
result = loop.run_until_complete(_run())
|
||||
result = asyncio.run(_run())
|
||||
elapsed_time = time.time() - start_time
|
||||
|
||||
logger.info(f"[CELERY WRITE] Task completed successfully - elapsed_time={elapsed_time:.2f}s, task_id={self.request.id}")
|
||||
@@ -560,7 +525,6 @@ def write_message_task(self, group_id: str, message, config_id: str, storage_typ
|
||||
}
|
||||
except BaseException as e:
|
||||
elapsed_time = time.time() - start_time
|
||||
# Handle ExceptionGroup from TaskGroup
|
||||
if hasattr(e, 'exceptions'):
|
||||
error_messages = [f"{type(sub_e).__name__}: {str(sub_e)}" for sub_e in e.exceptions]
|
||||
detailed_error = "; ".join(error_messages)
|
||||
@@ -600,53 +564,53 @@ def reflection_timer_task() -> None:
|
||||
"""
|
||||
reflection_engine()
|
||||
|
||||
|
||||
@celery_app.task(name="app.core.memory.agent.health.check_read_service")
|
||||
def check_read_service_task() -> Dict[str, str]:
|
||||
"""Call read_service and write latest status to Redis.
|
||||
# unused task
|
||||
# @celery_app.task(name="app.core.memory.agent.health.check_read_service")
|
||||
# def check_read_service_task() -> Dict[str, str]:
|
||||
# """Call read_service and write latest status to Redis.
|
||||
|
||||
Returns status data dict that gets written to Redis.
|
||||
"""
|
||||
client = redis.Redis(
|
||||
host=settings.REDIS_HOST,
|
||||
port=settings.REDIS_PORT,
|
||||
db=settings.REDIS_DB,
|
||||
password=settings.REDIS_PASSWORD if settings.REDIS_PASSWORD else None
|
||||
)
|
||||
try:
|
||||
api_url = f"http://{settings.SERVER_IP}:8000/api/memory/read_service"
|
||||
payload = {
|
||||
"user_id": "健康检查",
|
||||
"apply_id": "健康检查",
|
||||
"group_id": "健康检查",
|
||||
"message": "你好",
|
||||
"history": [],
|
||||
"search_switch": "2",
|
||||
}
|
||||
resp = requests.post(api_url, json=payload, timeout=15)
|
||||
ok = resp.status_code == 200
|
||||
status = "Success" if ok else "Fail"
|
||||
msg = "接口请求成功" if ok else f"接口请求失败: {resp.status_code}"
|
||||
error = "" if ok else resp.text
|
||||
code = 0 if ok else 500
|
||||
except Exception as e:
|
||||
status = "Fail"
|
||||
msg = "接口请求失败"
|
||||
error = str(e)
|
||||
code = 500
|
||||
# Returns status data dict that gets written to Redis.
|
||||
# """
|
||||
# client = redis.Redis(
|
||||
# host=settings.REDIS_HOST,
|
||||
# port=settings.REDIS_PORT,
|
||||
# db=settings.REDIS_DB,
|
||||
# password=settings.REDIS_PASSWORD if settings.REDIS_PASSWORD else None
|
||||
# )
|
||||
# try:
|
||||
# api_url = f"http://{settings.SERVER_IP}:8000/api/memory/read_service"
|
||||
# payload = {
|
||||
# "user_id": "健康检查",
|
||||
# "apply_id": "健康检查",
|
||||
# "group_id": "健康检查",
|
||||
# "message": "你好",
|
||||
# "history": [],
|
||||
# "search_switch": "2",
|
||||
# }
|
||||
# resp = requests.post(api_url, json=payload, timeout=15)
|
||||
# ok = resp.status_code == 200
|
||||
# status = "Success" if ok else "Fail"
|
||||
# msg = "接口请求成功" if ok else f"接口请求失败: {resp.status_code}"
|
||||
# error = "" if ok else resp.text
|
||||
# code = 0 if ok else 500
|
||||
# except Exception as e:
|
||||
# status = "Fail"
|
||||
# msg = "接口请求失败"
|
||||
# error = str(e)
|
||||
# code = 500
|
||||
|
||||
data = {
|
||||
"status": status,
|
||||
"msg": msg,
|
||||
"error": error,
|
||||
"code": str(code),
|
||||
"time": str(int(time.time())),
|
||||
}
|
||||
# data = {
|
||||
# "status": status,
|
||||
# "msg": msg,
|
||||
# "error": error,
|
||||
# "code": str(code),
|
||||
# "time": str(int(time.time())),
|
||||
# }
|
||||
|
||||
client.hset("memsci:health:read_service", mapping=data)
|
||||
client.expire("memsci:health:read_service", int(settings.HEALTH_CHECK_SECONDS))
|
||||
# client.hset("memsci:health:read_service", mapping=data)
|
||||
# client.expire("memsci:health:read_service", int(settings.HEALTH_CHECK_SECONDS))
|
||||
|
||||
return data
|
||||
# return data
|
||||
|
||||
|
||||
@celery_app.task(name="app.controllers.memory_storage_controller.search_all")
|
||||
@@ -911,24 +875,7 @@ def regenerate_memory_cache(self) -> Dict[str, Any]:
|
||||
}
|
||||
|
||||
try:
|
||||
# 使用 nest_asyncio 来避免事件循环冲突
|
||||
try:
|
||||
import nest_asyncio
|
||||
nest_asyncio.apply()
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
# 尝试获取现有事件循环,如果不存在则创建新的
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
if loop.is_closed():
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
except RuntimeError:
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
result = loop.run_until_complete(_run())
|
||||
result = asyncio.run(_run())
|
||||
elapsed_time = time.time() - start_time
|
||||
result["elapsed_time"] = elapsed_time
|
||||
result["task_id"] = self.request.id
|
||||
@@ -1055,24 +1002,7 @@ def workspace_reflection_task(self) -> Dict[str, Any]:
|
||||
}
|
||||
|
||||
try:
|
||||
# 使用 nest_asyncio 来避免事件循环冲突
|
||||
try:
|
||||
import nest_asyncio
|
||||
nest_asyncio.apply()
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
# 尝试获取现有事件循环,如果不存在则创建新的
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
if loop.is_closed():
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
except RuntimeError:
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
result = loop.run_until_complete(_run())
|
||||
result = asyncio.run(_run())
|
||||
elapsed_time = time.time() - start_time
|
||||
result["elapsed_time"] = elapsed_time
|
||||
result["task_id"] = self.request.id
|
||||
@@ -1148,11 +1078,4 @@ def run_forgetting_cycle_task(self, config_id: Optional[int] = None) -> Dict[str
|
||||
"duration_seconds": duration
|
||||
}
|
||||
|
||||
# 运行异步函数
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
try:
|
||||
result = loop.run_until_complete(_run())
|
||||
return result
|
||||
finally:
|
||||
loop.close()
|
||||
return asyncio.run(_run())
|
||||
|
||||
@@ -7,10 +7,6 @@ services:
|
||||
- "8002:8000"
|
||||
env_file:
|
||||
- .env
|
||||
environment:
|
||||
- SERVER_IP=0.0.0.0
|
||||
# 如果代码里必须要 MCP_SERVER_URL,可以先注释或指向占位
|
||||
# - MCP_SERVER_URL=
|
||||
volumes:
|
||||
- ./files:/files
|
||||
- /etc/localtime:/etc/localtime:ro
|
||||
@@ -19,20 +15,53 @@ services:
|
||||
networks:
|
||||
- default
|
||||
- celery
|
||||
depends_on:
|
||||
- worker-memory
|
||||
- worker-document
|
||||
|
||||
# Celery worker
|
||||
worker:
|
||||
# Memory worker - Memory read/write tasks (threads pool for asyncio)
|
||||
worker-memory:
|
||||
image: redbear-mem-open:latest
|
||||
container_name: worker
|
||||
container_name: worker-memory
|
||||
env_file:
|
||||
- .env
|
||||
volumes:
|
||||
- ./files:/files
|
||||
- /etc/localtime:/etc/localtime:ro
|
||||
command: celery -A app.celery_worker.celery_app worker --loglevel=info
|
||||
command: celery -A app.celery_worker.celery_app worker -E --loglevel=info --pool=threads --concurrency=100 --queues=memory_tasks -n memory_worker@%h
|
||||
restart: unless-stopped
|
||||
networks:
|
||||
- celery
|
||||
|
||||
# Document worker - Document parsing tasks (prefork for CPU-bound)
|
||||
worker-document:
|
||||
image: redbear-mem-open:latest
|
||||
container_name: worker-document
|
||||
env_file:
|
||||
- .env
|
||||
volumes:
|
||||
- ./files:/files
|
||||
- /etc/localtime:/etc/localtime:ro
|
||||
command: celery -A app.celery_worker.celery_app worker -E --loglevel=info --pool=prefork --concurrency=4 --queues=document_tasks --max-tasks-per-child=100 -n document_worker@%h
|
||||
restart: unless-stopped
|
||||
networks:
|
||||
- celery
|
||||
|
||||
# Celery Beat - scheduler
|
||||
beat:
|
||||
image: redbear-mem-open:latest
|
||||
container_name: celery-beat
|
||||
env_file:
|
||||
- .env
|
||||
volumes:
|
||||
- ./files:/files
|
||||
- /etc/localtime:/etc/localtime:ro
|
||||
command: celery -A app.celery_worker.celery_app beat --loglevel=info
|
||||
restart: unless-stopped
|
||||
networks:
|
||||
- celery
|
||||
depends_on:
|
||||
- worker-memory
|
||||
|
||||
networks:
|
||||
celery:
|
||||
|
||||
@@ -13,6 +13,7 @@ dependencies = [
|
||||
"bcrypt==5.0.0",
|
||||
"billiard==4.2.2",
|
||||
"celery==5.5.3",
|
||||
"flower==2.0.1",
|
||||
"cffi==2.0.0",
|
||||
"click==8.3.0",
|
||||
"click-didyoumean==0.3.1",
|
||||
@@ -138,6 +139,7 @@ dependencies = [
|
||||
"python-calamine>=0.4.0",
|
||||
"xlrd==2.0.2",
|
||||
"deprecated>=1.3.1",
|
||||
"flower>=2.0.1",
|
||||
]
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
|
||||
@@ -6,6 +6,7 @@ async-timeout==5.0.1
|
||||
bcrypt==5.0.0
|
||||
billiard==4.2.2
|
||||
celery==5.5.3
|
||||
flower==2.0.1
|
||||
cffi==2.0.0
|
||||
click==8.3.0
|
||||
click-didyoumean==0.3.1
|
||||
|
||||
Reference in New Issue
Block a user