220 lines
8.0 KiB
Python
220 lines
8.0 KiB
Python
"""
|
||
Pilot Run Service - 试运行服务
|
||
|
||
用于执行记忆系统的试运行流程,不保存到 Neo4j。
|
||
"""
|
||
|
||
import os
|
||
import re
|
||
import time
|
||
from datetime import datetime
|
||
from typing import Awaitable, Callable, Optional
|
||
|
||
from app.core.logging_config import get_memory_logger, log_time
|
||
from app.core.memory.models.message_models import (
|
||
ConversationContext,
|
||
ConversationMessage,
|
||
DialogData,
|
||
)
|
||
from app.core.memory.storage_services.extraction_engine.extraction_orchestrator import (
|
||
ExtractionOrchestrator,
|
||
get_chunked_dialogs_from_preprocessed,
|
||
)
|
||
from app.core.memory.utils.config.config_utils import (
|
||
get_pipeline_config,
|
||
)
|
||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||
from app.schemas.memory_config_schema import MemoryConfig
|
||
from sqlalchemy.orm import Session
|
||
|
||
# Module-level logger, obtained from the project's logging configuration helper
# (app.core.logging_config.get_memory_logger) and keyed by this module's name.
logger = get_memory_logger(__name__)
|
||
|
||
|
||
# One speaker turn: a "用户" / "AI" label, a colon (ASCII or full-width), then the
# message text. The message runs to the end of the line and then across every
# following line that does not itself start with a speaker label.
#
# The continuation group must be greedy: a lazy quantifier with nothing after it
# always matches zero repetitions, which silently truncates every multi-line
# message to its first line. The negative lookahead is what stops the match at
# the next speaker, so greediness is safe here.
_DIALOGUE_TURN_PATTERN = re.compile(
    r"(用户|AI)[::]\s*([^\n]+(?:\n(?!(?:用户|AI)[::])[^\n]*)*)"
)


def _split_dialogue_turns(dialogue_text: str) -> list:
    """Split raw dialogue text into ``(role, content)`` pairs.

    Args:
        dialogue_text: Raw text in which turns are introduced by ``用户:`` or
            ``AI:`` (ASCII or full-width colon).

    Returns:
        A list of ``(role, content)`` tuples in order of appearance. ``content``
        is stripped of surrounding whitespace and turns that are empty after
        stripping are dropped. The list is empty when no speaker label is found.
    """
    turns = []
    for role, content in _DIALOGUE_TURN_PATTERN.findall(dialogue_text):
        content = content.strip()
        if content:
            turns.append((role, content))
    return turns


async def run_pilot_extraction(
    memory_config: MemoryConfig,
    dialogue_text: str,
    db: Session,
    progress_callback: Optional[Callable[[str, str, Optional[dict]], Awaitable[None]]] = None,
) -> None:
    """Run the knowledge-extraction pipeline in pilot-run (trial) mode.

    The function parses the dialogue text, chunks it, runs the extraction
    orchestrator with ``is_pilot_run=True`` and then attempts memory-summary
    generation. It performs no Neo4j save itself. Step timings are appended to
    ``logs/time.log``.

    Args:
        memory_config: Memory configuration object; its model ids, chunker
            strategy and workspace/tenant/config ids are read here.
        dialogue_text: Input dialogue text. Turns introduced by ``用户:`` or
            ``AI:`` are split into separate messages; if no such turn is found
            the whole text is treated as a single user message.
        db: Database session handed to ``MemoryClientFactory``.
        progress_callback: Optional async progress callback, awaited with
            - argument 1 (stage): identifier of the current processing stage
            - argument 2 (message): human-readable progress message
            - argument 3 (data): optional dict with extra data, or ``None``

    Returns:
        None. Results are reported only through ``progress_callback`` and logs.

    Raises:
        Exception: Any error from client initialization, parsing, chunking or
            extraction is logged and re-raised. A failure in the memory-summary
            step is logged but not re-raised.
    """
    log_file = "logs/time.log"
    os.makedirs(os.path.dirname(log_file), exist_ok=True)
    timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    with open(log_file, "a", encoding="utf-8") as f:
        f.write(f"\n=== Pilot Run Started: {timestamp} ===\n")

    pipeline_start = time.time()
    # Kept outside the try block so the finally clause can tell whether a
    # connector was actually created before something failed.
    neo4j_connector = None

    try:
        # Step 1: initialize clients
        logger.info("Initializing clients...")
        step_start = time.time()

        client_factory = MemoryClientFactory(db)
        llm_client = client_factory.get_llm_client(str(memory_config.llm_model_id))
        embedder_client = client_factory.get_embedder_client(str(memory_config.embedding_model_id))

        neo4j_connector = Neo4jConnector()

        log_time("Client Initialization", time.time() - step_start, log_file)

        # Step 2: parse the dialogue text
        logger.info("Parsing dialogue text...")
        step_start = time.time()

        messages = [
            ConversationMessage(role=role, msg=content)
            for role, content in _split_dialogue_turns(dialogue_text)
        ]

        # No formatted turns found: treat the whole text as one user message.
        if not messages:
            messages = [ConversationMessage(role="用户", msg=dialogue_text.strip())]

        context = ConversationContext(msgs=messages)
        dialog = DialogData(
            context=context,
            ref_id="pilot_dialog_1",
            end_user_id=str(memory_config.workspace_id),
            user_id=str(memory_config.tenant_id),
            apply_id=str(memory_config.config_id),
            metadata={"source": "pilot_run", "input_type": "frontend_text"},
        )

        if progress_callback:
            # The callback type declares three positional parameters, so the
            # data argument is passed explicitly even when there is no payload.
            await progress_callback("text_preprocessing", "开始预处理文本...", None)

        chunked_dialogs = await get_chunked_dialogs_from_preprocessed(
            data=[dialog],
            chunker_strategy=memory_config.chunker_strategy,
            llm_client=llm_client,
        )
        logger.info(f"Processed dialogue text: {len(messages)} messages")

        # Progress callback: report every chunk, then a summary.
        if progress_callback:
            for dlg in chunked_dialogs:
                for i, chunk in enumerate(dlg.chunks):
                    content = chunk.content
                    # Send at most a 200-character preview; the true length is
                    # reported separately in "full_length".
                    preview = (content[:200] + "...") if len(content) > 200 else content
                    chunk_result = {
                        "chunk_index": i + 1,
                        "content": preview,
                        "full_length": len(content),
                        "dialog_id": dlg.id,
                        "chunker_strategy": memory_config.chunker_strategy,
                    }
                    await progress_callback(
                        "text_preprocessing_result", f"分块 {i + 1} 处理完成", chunk_result
                    )

            preprocessing_summary = {
                "total_chunks": sum(len(dlg.chunks) for dlg in chunked_dialogs),
                "total_dialogs": len(chunked_dialogs),
                "chunker_strategy": memory_config.chunker_strategy,
            }
            await progress_callback(
                "text_preprocessing_complete", "预处理文本完成", preprocessing_summary
            )

        log_time("Data Loading & Chunking", time.time() - step_start, log_file)

        # Step 3: initialize the pipeline orchestrator
        logger.info("Initializing extraction orchestrator...")
        step_start = time.time()

        config = get_pipeline_config(memory_config)
        logger.info(
            f"Pipeline config loaded: enable_llm_dedup_blockwise={config.deduplication.enable_llm_dedup_blockwise}, "
            f"enable_llm_disambiguation={config.deduplication.enable_llm_disambiguation}"
        )

        orchestrator = ExtractionOrchestrator(
            llm_client=llm_client,
            embedder_client=embedder_client,
            connector=neo4j_connector,
            config=config,
            progress_callback=progress_callback,
            embedding_id=str(memory_config.embedding_model_id),
        )

        log_time("Orchestrator Initialization", time.time() - step_start, log_file)

        # Step 4: run the knowledge-extraction pipeline
        logger.info("Running extraction pipeline...")
        step_start = time.time()

        if progress_callback:
            await progress_callback("knowledge_extraction", "正在知识抽取...", None)

        extraction_result = await orchestrator.run(
            dialog_data_list=chunked_dialogs,
            is_pilot_run=True,
        )

        # Unpack the result tuple (kept in line with main.py). The values are
        # not used further here, but the unpack still fails loudly if the
        # orchestrator returns something other than the expected 7-tuple.
        (
            dialogue_nodes,
            chunk_nodes,
            statement_nodes,
            entity_nodes,
            statement_chunk_edges,
            statement_entity_edges,
            entity_edges,
        ) = extraction_result

        log_time("Extraction Pipeline", time.time() - step_start, log_file)

        if progress_callback:
            await progress_callback("generating_results", "正在生成结果...", None)

        # Step 5: generate memory summaries (kept in line with main.py).
        # A failure here is logged only; it does not fail the pilot run.
        try:
            logger.info("Generating memory summaries...")
            step_start = time.time()

            from app.core.memory.storage_services.extraction_engine.knowledge_extraction.memory_summary import (
                memory_summary_generation,
            )

            # The returned summaries are not used by the pilot run.
            await memory_summary_generation(
                chunked_dialogs,
                llm_client=llm_client,
                embedder_client=embedder_client,
            )

            log_time("Memory Summary Generation", time.time() - step_start, log_file)
        except Exception as e:
            logger.error(f"Memory summary step failed: {e}", exc_info=True)

        logger.info("Pilot run completed: Skipping Neo4j save")

    except Exception as e:
        logger.error(f"Pilot run failed: {e}", exc_info=True)
        raise
    finally:
        if neo4j_connector:
            try:
                await neo4j_connector.close()
            except Exception as close_err:
                # Closing stays best-effort so it never masks the pipeline's
                # own outcome, but the failure is no longer silent.
                logger.warning(f"Failed to close Neo4j connector: {close_err}")

        # Timing footer is written on success and on failure alike.
        total_time = time.time() - pipeline_start
        log_time("TOTAL PILOT RUN TIME", total_time, log_file)

        timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        with open(log_file, "a", encoding="utf-8") as f:
            f.write(f"=== Pilot Run Completed: {timestamp} ===\n\n")

        logger.info(f"Pilot run complete. Total time: {total_time:.2f}s")
|