Files
MemoryBear/app/core/memory/main.py
2025-11-30 18:22:17 +08:00

333 lines
14 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
MemSci 记忆系统主入口 - 重构版本
该模块是重构后的记忆系统主入口,使用新的模块化架构。
旧版本入口app/core/memory/src/main.py已删除。
主要功能:
1. 协调整个知识提取流水线
2. 支持试运行模式和正常运行模式
3. 使用重构后的 storage_services 模块
4. 提供统一的配置管理和日志记录
作者Lance77
日期2025-11-22
"""
# 必须在最开始禁用 LangSmith 追踪,避免速率限制错误
import os
os.environ["LANGCHAIN_TRACING_V2"] = "false"
os.environ["LANGCHAIN_TRACING"] = "false"
import asyncio
import time
from datetime import datetime
from typing import Optional
from dotenv import load_dotenv
# 导入重构后的模块
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.core.memory.storage_services.extraction_engine.extraction_orchestrator import ExtractionOrchestrator
from app.core.memory.utils.llm.llm_utils import get_llm_client
from app.core.memory.utils.config.config_utils import get_embedder_config
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
from app.core.memory.models.message_models import ConversationMessage, ConversationContext, DialogData
from app.core.memory.models.variate_config import ExtractionPipelineConfig
# 导入数据加载函数
from app.core.memory.storage_services.extraction_engine.extraction_orchestrator import (
get_chunked_dialogs_with_preprocessing,
get_chunked_dialogs_from_preprocessed,
)
# 导入配置模块(而不是直接导入变量)
from app.core.memory.utils.config import definitions as config_defs
from app.core.logging_config import get_memory_logger, log_time
load_dotenv()
logger = get_memory_logger(__name__)
async def main(dialogue_text: Optional[str] = None, is_pilot_run: bool = False):
"""
记忆系统主流程 - 重构版本
该函数是重构后的主入口,使用新的模块化架构。
Args:
dialogue_text: 输入的对话文本(可选,用于试运行模式)
is_pilot_run: 是否为试运行模式
- True: 试运行模式,不保存到 Neo4j
- False: 正常运行模式,保存到 Neo4j
工作流程:
1. 初始化客户端和配置
2. 加载或准备数据
3. 执行知识提取流水线
4. 保存结果(正常模式)或输出结果(试运行模式)
"""
print("=" * 60)
print("MemSci 知识提取流水线 - 重构版本")
print("=" * 60)
print(f"运行模式: {'试运行不保存到Neo4j' if is_pilot_run else '正常运行保存到Neo4j'}")
print("Using chunker strategy:", config_defs.SELECTED_CHUNKER_STRATEGY)
print("Using group ID:", config_defs.SELECTED_GROUP_ID)
print("Using model ID:", config_defs.SELECTED_LLM_ID)
print("Using embedding model ID:", config_defs.SELECTED_EMBEDDING_ID)
print("LANGFUSE_ENABLED:", config_defs.LANGFUSE_ENABLED)
print("AGENTA_ENABLED:", config_defs.AGENTA_ENABLED)
print("=" * 60)
# 初始化日志
log_file = "logs/time.log"
os.makedirs(os.path.dirname(log_file), exist_ok=True)
timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
with open(log_file, "a", encoding="utf-8") as f:
f.write(f"\n=== Pipeline Run Started: {timestamp} ({'Pilot Run' if is_pilot_run else 'Normal Run'}) ===\n")
pipeline_start = time.time()
try:
# 步骤 1: 初始化客户端
logger.info("Initializing clients...")
step_start = time.time()
llm_client = get_llm_client(config_defs.SELECTED_LLM_ID)
# 获取 embedder 配置并转换为 RedBearModelConfig 对象
from app.core.models.base import RedBearModelConfig
embedder_config_dict = get_embedder_config(config_defs.SELECTED_EMBEDDING_ID)
embedder_config = RedBearModelConfig(**embedder_config_dict)
embedder_client = OpenAIEmbedderClient(embedder_config)
neo4j_connector = Neo4jConnector()
log_time("Client Initialization", time.time() - step_start, log_file)
# 步骤 2: 加载或准备数据
logger.info("Loading data...")
logger.info(f"[MAIN] dialogue_text type={type(dialogue_text)}, length={len(dialogue_text) if dialogue_text else 0}, is_pilot_run={is_pilot_run}")
logger.info(f"[MAIN] dialogue_text preview: {repr(dialogue_text)[:200] if dialogue_text else 'None'}")
logger.info(f"[MAIN] Condition check: dialogue_text={bool(dialogue_text)}, isinstance={isinstance(dialogue_text, str) if dialogue_text else False}, strip={bool(dialogue_text.strip()) if dialogue_text and isinstance(dialogue_text, str) else False}")
step_start = time.time()
if dialogue_text and isinstance(dialogue_text, str) and dialogue_text.strip():
# 试运行模式:处理前端传入的对话文本
logger.info("[MAIN] ✓ Using frontend dialogue text (pilot run mode)")
import re
# 解析对话文本,支持 "用户:" 和 "AI:" 格式
pattern = r"(用户|AI)[:]\s*([^\n]+(?:\n(?!(?:用户|AI)[:])[^\n]*)*?)"
matches = re.findall(pattern, dialogue_text, re.MULTILINE | re.DOTALL)
messages = [
ConversationMessage(role=r, msg=c.strip())
for r, c in matches if c.strip()
]
# 如果没有匹配到格式化的对话,将整个文本作为用户消息
if not messages:
messages = [ConversationMessage(role="用户", msg=dialogue_text.strip())]
# 创建对话上下文和对话数据
context = ConversationContext(msgs=messages)
dialog = DialogData(
context=context,
ref_id="pilot_dialog_1",
group_id=config_defs.SELECTED_GROUP_ID,
user_id=config_defs.SELECTED_USER_ID,
apply_id=config_defs.SELECTED_APPLY_ID,
metadata={"source": "pilot_run", "input_type": "frontend_text"}
)
# 对前端传入的对话进行分块处理
chunked_dialogs = await get_chunked_dialogs_from_preprocessed(
data=[dialog],
chunker_strategy=config_defs.SELECTED_CHUNKER_STRATEGY,
llm_client=llm_client,
)
logger.info(f"Processed frontend dialogue text: {len(messages)} messages")
else:
# 正常运行模式:从 testdata.json 文件加载
logger.warning("[MAIN] ✗ Falling back to testdata.json (dialogue_text not provided or empty)")
logger.info("Loading data from testdata.json...")
test_data_path = os.path.join(
os.path.dirname(__file__), "data", "testdata.json"
)
if not os.path.exists(test_data_path):
raise FileNotFoundError(f"Test data file not found: {test_data_path}")
chunked_dialogs = await get_chunked_dialogs_with_preprocessing(
chunker_strategy=config_defs.SELECTED_CHUNKER_STRATEGY,
group_id=config_defs.SELECTED_GROUP_ID,
user_id=config_defs.SELECTED_USER_ID,
apply_id=config_defs.SELECTED_APPLY_ID,
indices=config_defs.SELECTED_TEST_DATA_INDICES,
input_data_path=test_data_path,
llm_client=llm_client,
skip_cleaning=True,
)
logger.info(f"Loaded {len(chunked_dialogs)} dialogues from testdata.json")
log_time("Data Loading & Chunking", time.time() - step_start, log_file)
# 步骤 3: 初始化流水线编排器
logger.info("Initializing extraction orchestrator...")
step_start = time.time()
# 从 runtime.json 加载配置(已经过数据库覆写)
from app.core.memory.utils.config.config_utils import get_pipeline_config
config = get_pipeline_config()
logger.info(f"Pipeline config loaded: enable_llm_dedup_blockwise={config.deduplication.enable_llm_dedup_blockwise}, enable_llm_disambiguation={config.deduplication.enable_llm_disambiguation}")
orchestrator = ExtractionOrchestrator(
llm_client=llm_client,
embedder_client=embedder_client,
connector=neo4j_connector,
config=config,
)
log_time("Orchestrator Initialization", time.time() - step_start, log_file)
# 步骤 4: 执行知识提取流水线
logger.info("Running extraction pipeline...")
step_start = time.time()
extraction_result = await orchestrator.run(
dialog_data_list=chunked_dialogs,
is_pilot_run=is_pilot_run, # 传递试运行模式标志
)
# 解包 extraction_result tuple
# extraction_result 是一个包含 7 个元素的 tuple:
# (dialogue_nodes, chunk_nodes, statement_nodes, entity_nodes,
# statement_chunk_edges, statement_entity_edges, entity_edges)
(
dialogue_nodes,
chunk_nodes,
statement_nodes,
entity_nodes,
statement_chunk_edges,
statement_entity_edges,
entity_edges,
) = extraction_result
log_time("Extraction Pipeline", time.time() - step_start, log_file)
# 步骤 5: 保存结果或输出结果
if is_pilot_run:
logger.info("Pilot run mode: Skipping Neo4j save")
print("\n试运行模式:跳过 Neo4j 保存,流水线处理完成。")
print("提取结果已生成,可在相关输出中查看。")
else:
logger.info("Normal mode: Saving to Neo4j...")
step_start = time.time()
# 创建索引和约束
try:
from app.repositories.neo4j.create_indexes import (
create_fulltext_indexes,
create_unique_constraints,
)
await create_fulltext_indexes()
await create_unique_constraints()
logger.info("Successfully created indexes and constraints")
except Exception as e:
logger.error(f"Error creating indexes/constraints: {e}")
# 保存数据到 Neo4j
try:
from app.repositories.neo4j.graph_saver import (
save_dialog_and_statements_to_neo4j,
)
success = await save_dialog_and_statements_to_neo4j(
dialogue_nodes=dialogue_nodes,
chunk_nodes=chunk_nodes,
statement_nodes=statement_nodes,
entity_nodes=entity_nodes,
statement_chunk_edges=statement_chunk_edges,
statement_entity_edges=statement_entity_edges,
entity_edges=entity_edges,
connector=neo4j_connector,
)
if success:
logger.info("Successfully saved all data to Neo4j")
print("\n✓ 成功保存所有数据到 Neo4j")
else:
logger.warning("Failed to save some data to Neo4j")
print("\n⚠ 部分数据保存到 Neo4j 失败")
except Exception as e:
logger.error(f"Error saving to Neo4j: {e}", exc_info=True)
print(f"\n✗ 保存到 Neo4j 失败: {e}")
log_time("Neo4j Database Save", time.time() - step_start, log_file)
# 步骤 6: 生成记忆摘要(可选)
try:
logger.info("Generating memory summaries...")
step_start = time.time()
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.memory_summary import (
Memory_summary_generation,
)
from app.repositories.neo4j.add_nodes import add_memory_summary_nodes
from app.repositories.neo4j.add_edges import (
add_memory_summary_statement_edges,
)
summaries = await Memory_summary_generation(
chunked_dialogs, llm_client=llm_client, embedding_id=config_defs.SELECTED_EMBEDDING_ID
)
if not is_pilot_run:
# 保存记忆摘要到 Neo4j
ms_connector = Neo4jConnector()
try:
await add_memory_summary_nodes(summaries, ms_connector)
await add_memory_summary_statement_edges(summaries, ms_connector)
finally:
await ms_connector.close()
log_time("Memory Summary Generation", time.time() - step_start, log_file)
except Exception as e:
logger.error(f"Memory summary step failed: {e}", exc_info=True)
except Exception as e:
logger.error(f"Pipeline execution failed: {e}", exc_info=True)
print(f"\n✗ 流水线执行失败: {e}")
raise
finally:
# 清理资源
try:
await neo4j_connector.close()
except Exception:
pass
# 记录总时间
total_time = time.time() - pipeline_start
log_time("TOTAL PIPELINE TIME", total_time, log_file)
# 添加完成标记
timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
with open(log_file, "a", encoding="utf-8") as f:
f.write(f"=== Pipeline Run Completed: {timestamp} ===\n\n")
logger.info("=== Pipeline Complete ===")
logger.info(f"Total execution time: {total_time:.2f} seconds")
logger.info(f"Timing details saved to: {log_file}")
print("\n" + "=" * 60)
print(f"✓ 流水线执行完成")
print(f"✓ 总耗时: {total_time:.2f}")
print(f"✓ 详细日志: {log_file}")
print("=" * 60)
if __name__ == "__main__":
asyncio.run(main())