Files
MemoryBear/api/app/services/pilot_run_service.py
Ke Sun 283c64a358 refactor(memory): restructure memory system and improve configuration management
- Remove deprecated main.py entry point from memory module
- Reorganize imports across controllers and services for consistency
- Update emotion controller to pass db session instead of config_id to services
- Enhance memory agent controller with db session parameter for status_type and user_profile endpoints
- Refactor memory agent service to accept db parameter in classify_message_type method
- Improve configuration handling in celery_app by removing automatic database reload
- Update all memory-related services to use centralized config management
- Standardize import ordering and remove unused imports across 50+ files
- Add pilot_run_service for new pilot execution workflow
- Refactor extraction engine, reflection engine, and search services for better modularity
- Update LLM utilities and embedder configuration for improved flexibility
- Enhance type classifier and verification tools with better error handling
- Improve memory evaluation modules (LOCOMO, LongMemEval, MemSciQA) with consistent patterns
2025-12-23 17:17:04 +08:00

220 lines
8.0 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
Pilot Run Service - 试运行服务
用于执行记忆系统的试运行流程,不保存到 Neo4j。
"""
import os
import re
import time
from datetime import datetime
from typing import Awaitable, Callable, Optional
from app.core.logging_config import get_memory_logger, log_time
from app.core.memory.models.message_models import (
ConversationContext,
ConversationMessage,
DialogData,
)
from app.core.memory.storage_services.extraction_engine.extraction_orchestrator import (
ExtractionOrchestrator,
get_chunked_dialogs_from_preprocessed,
)
from app.core.memory.utils.config.config_utils import (
get_pipeline_config,
)
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.schemas.memory_config_schema import MemoryConfig
from sqlalchemy.orm import Session
logger = get_memory_logger(__name__)
async def run_pilot_extraction(
memory_config: MemoryConfig,
dialogue_text: str,
db: Session,
progress_callback: Optional[Callable[[str, str, Optional[dict]], Awaitable[None]]] = None,
) -> None:
"""
执行试运行模式的知识提取流水线。
Args:
memory_config: 从数据库加载的内存配置对象
dialogue_text: 输入的对话文本
progress_callback: 可选的进度回调函数
- 参数1 (stage): 当前处理阶段标识符
- 参数2 (message): 人类可读的进度消息
- 参数3 (data): 可选的附加数据字典
"""
log_file = "logs/time.log"
os.makedirs(os.path.dirname(log_file), exist_ok=True)
timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
with open(log_file, "a", encoding="utf-8") as f:
f.write(f"\n=== Pilot Run Started: {timestamp} ===\n")
pipeline_start = time.time()
neo4j_connector = None
try:
# 步骤 1: 初始化客户端
logger.info("Initializing clients...")
step_start = time.time()
client_factory = MemoryClientFactory(db)
llm_client = client_factory.get_llm_client(str(memory_config.llm_model_id))
embedder_client = client_factory.get_embedder_client(str(memory_config.embedding_model_id))
neo4j_connector = Neo4jConnector()
log_time("Client Initialization", time.time() - step_start, log_file)
# 步骤 2: 解析对话文本
logger.info("Parsing dialogue text...")
step_start = time.time()
# 解析对话文本,支持 "用户:" 和 "AI:" 格式
pattern = r"(用户|AI)[:]\s*([^\n]+(?:\n(?!(?:用户|AI)[:])[^\n]*)*?)"
matches = re.findall(pattern, dialogue_text, re.MULTILINE | re.DOTALL)
messages = [
ConversationMessage(role=r, msg=c.strip())
for r, c in matches
if c.strip()
]
# 如果没有匹配到格式化的对话,将整个文本作为用户消息
if not messages:
messages = [ConversationMessage(role="用户", msg=dialogue_text.strip())]
context = ConversationContext(msgs=messages)
dialog = DialogData(
context=context,
ref_id="pilot_dialog_1",
group_id=str(memory_config.workspace_id),
user_id=str(memory_config.tenant_id),
apply_id=str(memory_config.config_id),
metadata={"source": "pilot_run", "input_type": "frontend_text"},
)
if progress_callback:
await progress_callback("text_preprocessing", "开始预处理文本...")
chunked_dialogs = await get_chunked_dialogs_from_preprocessed(
data=[dialog],
chunker_strategy=memory_config.chunker_strategy,
llm_client=llm_client,
)
logger.info(f"Processed dialogue text: {len(messages)} messages")
# 进度回调:输出每个分块的结果
if progress_callback:
for dlg in chunked_dialogs:
for i, chunk in enumerate(dlg.chunks):
chunk_result = {
"chunk_index": i + 1,
"content": chunk.content[:200] + "..." if len(chunk.content) > 200 else chunk.content,
"full_length": len(chunk.content),
"dialog_id": dlg.id,
"chunker_strategy": memory_config.chunker_strategy,
}
await progress_callback("text_preprocessing_result", f"分块 {i + 1} 处理完成", chunk_result)
preprocessing_summary = {
"total_chunks": sum(len(dlg.chunks) for dlg in chunked_dialogs),
"total_dialogs": len(chunked_dialogs),
"chunker_strategy": memory_config.chunker_strategy,
}
await progress_callback("text_preprocessing_complete", "预处理文本完成", preprocessing_summary)
log_time("Data Loading & Chunking", time.time() - step_start, log_file)
# 步骤 3: 初始化流水线编排器
logger.info("Initializing extraction orchestrator...")
step_start = time.time()
config = get_pipeline_config(memory_config)
logger.info(
f"Pipeline config loaded: enable_llm_dedup_blockwise={config.deduplication.enable_llm_dedup_blockwise}, "
f"enable_llm_disambiguation={config.deduplication.enable_llm_disambiguation}"
)
orchestrator = ExtractionOrchestrator(
llm_client=llm_client,
embedder_client=embedder_client,
connector=neo4j_connector,
config=config,
progress_callback=progress_callback,
embedding_id=str(memory_config.embedding_model_id),
)
log_time("Orchestrator Initialization", time.time() - step_start, log_file)
# 步骤 4: 执行知识提取流水线
logger.info("Running extraction pipeline...")
step_start = time.time()
if progress_callback:
await progress_callback("knowledge_extraction", "正在知识抽取...")
extraction_result = await orchestrator.run(
dialog_data_list=chunked_dialogs,
is_pilot_run=True,
)
# 解包 extraction_result tuple (与 main.py 保持一致)
(
dialogue_nodes,
chunk_nodes,
statement_nodes,
entity_nodes,
statement_chunk_edges,
statement_entity_edges,
entity_edges,
) = extraction_result
log_time("Extraction Pipeline", time.time() - step_start, log_file)
if progress_callback:
await progress_callback("generating_results", "正在生成结果...")
# 步骤 5: 生成记忆摘要(与 main.py 保持一致)
try:
logger.info("Generating memory summaries...")
step_start = time.time()
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.memory_summary import (
memory_summary_generation,
)
summaries = await memory_summary_generation(
chunked_dialogs,
llm_client=llm_client,
embedder_client=embedder_client,
)
log_time("Memory Summary Generation", time.time() - step_start, log_file)
except Exception as e:
logger.error(f"Memory summary step failed: {e}", exc_info=True)
logger.info("Pilot run completed: Skipping Neo4j save")
except Exception as e:
logger.error(f"Pilot run failed: {e}", exc_info=True)
raise
finally:
if neo4j_connector:
try:
await neo4j_connector.close()
except Exception:
pass
total_time = time.time() - pipeline_start
log_time("TOTAL PILOT RUN TIME", total_time, log_file)
timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
with open(log_file, "a", encoding="utf-8") as f:
f.write(f"=== Pilot Run Completed: {timestamp} ===\n\n")
logger.info(f"Pilot run complete. Total time: {total_time:.2f}s")