feat(memory): add WritePipeline and MemoryService facade

Introduce a layered pipeline architecture for the memory write flow:
- WritePipeline: orchestrates preprocess → extract → store → cluster → summarize
  with deadlock retry, resource cleanup, and pilot-run support
- MemoryService: facade that delegates to WritePipeline, placeholder methods
  for read/forget/reflect
- BearLogger: structured step-level logging with perf threshold alerts
- Shadow pipeline integration in MemoryAgentService (env-gated pilot run)

Also includes:
- Fix deprecated SQLAlchemy declarative_base import
- Extend Neo4j Entity fulltext index to cover description and aliases
- Migrate Pydantic schemas to v2 (ConfigDict, field_validator)
This commit is contained in:
lanceyq
2026-04-17 19:06:02 +08:00
parent feae2f2e1e
commit 41535c34e6
12 changed files with 1000 additions and 57 deletions

View File

@@ -1,58 +1,113 @@
from sqlalchemy.orm import Session """
MemoryService — 记忆模块统一入口Facade
from app.core.memory.enums import StorageType, SearchStrategy 所有外部调用方controllers、Celery tasks、API service只依赖此类。
from app.core.memory.models.service_models import MemoryContext, MemorySearchResult
from app.core.memory.pipelines.memory_read import ReadPipeLine 职责:
from app.db import get_db_context - 接收已加载的 MemoryConfig选择并调用对应的 Pipeline
from app.services.memory_config_service import MemoryConfigService - 不包含任何业务逻辑实现
- 不直接操作数据库或 LLM
依赖方向:外部调用方 → MemoryService → Pipeline → Engine → Repository
"""
from __future__ import annotations
import logging
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, List, Optional
if TYPE_CHECKING:
from app.core.memory.pipelines.write_pipeline import WriteResult
from app.schemas.memory_config_schema import MemoryConfig
logger = logging.getLogger(__name__)
class MemoryService: class MemoryService:
"""记忆模块统一入口
所有外部调用方controllers、Celery tasks、API service只依赖此类。
设计决策:
- __init__ 接收已加载的 MemoryConfig而非 config_id
配置加载的职责留在调用方MemoryAgentService
因为调用方需要 config 做其他事情(如感知记忆处理)。
- 未实现的方法抛出 NotImplementedError明确标记待实现状态。
"""
def __init__( def __init__(
self, self,
db: Session, memory_config: MemoryConfig,
config_id: str | None, end_user_id: str,
end_user_id: str,
workspace_id: str | None = None,
storage_type: str = "neo4j",
user_rag_memory_id: str | None = None,
language: str = "zh",
): ):
config_service = MemoryConfigService(db) """
memory_config = None Args:
if config_id is not None: memory_config: 已加载的不可变配置对象
memory_config = config_service.load_memory_config( end_user_id: 终端用户 ID
config_id=config_id, """
workspace_id=workspace_id, self.memory_config = memory_config
service_name="MemoryService", self.end_user_id = end_user_id
)
if memory_config is None and storage_type.lower() == "neo4j": async def write(
raise RuntimeError("Memory configuration for unspecified users") self,
self.ctx = MemoryContext( messages: List[dict],
end_user_id=end_user_id, language: str = "zh",
memory_config=memory_config, ref_id: str = "",
storage_type=StorageType(storage_type), is_pilot_run: bool = False,
user_rag_memory_id=user_rag_memory_id, progress_callback: Optional[
Callable[[str, str, Optional[Dict[str, Any]]], Awaitable[None]]
] = None,
) -> WriteResult:
"""写入记忆:对话 → 萃取 → 存储 → 聚类 → 摘要
Args:
messages: 结构化消息 [{"role": "user"/"assistant", "content": "..."}]
language: 语言 ("zh" | "en")
ref_id: 引用 ID为空则自动生成
is_pilot_run: 试运行模式(只萃取不写入)
progress_callback: 可选的进度回调
Returns:
WriteResult 包含状态和统计信息
"""
from app.core.memory.pipelines.write_pipeline import WritePipeline
pipeline = WritePipeline(
memory_config=self.memory_config,
end_user_id=self.end_user_id,
language=language, language=language,
progress_callback=progress_callback,
)
return await pipeline.run(
messages=messages,
ref_id=ref_id,
is_pilot_run=is_pilot_run,
) )
async def write(self, messages: list[dict]) -> str:
raise NotImplementedError
async def read( async def read(
self, self, query: str, history: list, search_switch: str
query: str, ) -> dict:
search_switch: SearchStrategy, """读取记忆:根据 search_switch 选择快速/深度路径"""
limit: int = 10, raise NotImplementedError("ReadPipeline 尚未实现")
) -> MemorySearchResult:
with get_db_context() as db:
return await ReadPipeLine(self.ctx, db).run(query, search_switch, limit)
async def forget(self, max_batch: int = 100, min_days: int = 30) -> dict: # async def search(
raise NotImplementedError # self,
# query: str,
# search_type: str = "hybrid",
# limit: int = 10,
# ) -> dict:
# """独立检索:不经过 LangGraph直接执行混合检索"""
# raise NotImplementedError("SearchPipeline 尚未实现")
async def forget(
self, max_batch: int = 100, min_days: int = 30
) -> dict:
"""遗忘:识别低激活节点并融合"""
raise NotImplementedError("ForgettingPipeline 尚未实现")
async def reflect(self) -> dict: async def reflect(self) -> dict:
raise NotImplementedError """反思:检测事实冲突并修正"""
raise NotImplementedError("ReflectionPipeline 尚未实现")
async def cluster(self, new_entity_ids: list[str] = None) -> None: # async def cluster(self, new_entity_ids: list[str] = None) -> None:
raise NotImplementedError # """聚类:全量初始化或增量更新社区"""
# raise NotImplementedError("ClusteringPipeline 尚未实现")

View File

@@ -0,0 +1,26 @@
"""
Memory Pipelines — 记忆模块流水线编排层
每条 Pipeline 定义一个完整的业务流程,按顺序编排多个 Engine 的调用。
Pipeline 不包含业务逻辑实现,只做步骤编排和数据传递。
"""
def __getattr__(name):
    """Resolve the package's exported names lazily (module-level ``__getattr__``).

    The write-pipeline module is imported only when one of its three
    exported names is requested; the original author notes this is done to
    avoid a circular import.

    Raises:
        AttributeError: If ``name`` is not one of the lazily exported names.
    """
    # Guard clause: anything that is not a lazy export is a plain miss.
    if name not in ("WritePipeline", "ExtractionResult", "WriteResult"):
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

    from app.core.memory.pipelines.write_pipeline import (
        ExtractionResult,
        WritePipeline,
        WriteResult,
    )

    lazily_exported = {
        "WritePipeline": WritePipeline,
        "ExtractionResult": ExtractionResult,
        "WriteResult": WriteResult,
    }
    return lazily_exported[name]
__all__ = ["WritePipeline", "ExtractionResult", "WriteResult"]

View File

@@ -0,0 +1,649 @@
"""
WritePipeline — 记忆写入流水线
编排完整的写入流程:预处理 → 萃取 → 存储 → 聚类 → 摘要。
不包含业务逻辑实现,只做步骤编排和数据传递。
设计原则:
- Pipeline 不直接操作数据库,通过 Engine / Repository 完成
- Pipeline 不包含 LLM 调用逻辑,通过 ExtractionOrchestrator 完成
- Pipeline 负责资源生命周期管理(客户端初始化 / 连接关闭)
- Pipeline 负责错误边界划分(哪些错误中断流程,哪些吞掉继续)
依赖方向：Facade → Pipeline → Engine → Repository（单向，不允许反向调用）
"""
from __future__ import annotations
import asyncio
import logging
import time
import uuid
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, List, Optional
if TYPE_CHECKING:
from app.core.memory.models.graph_models import ExtractedEntityNode
from app.core.memory.models.message_models import DialogData
from app.schemas.memory_config_schema import MemoryConfig
logger = logging.getLogger(__name__)
# ──────────────────────────────────────────────
# 数据结构
# ──────────────────────────────────────────────
@dataclass
class ExtractionResult:
    """Structured output of the extraction step.

    Stands in for the bare 10-element tuple returned by
    ``ExtractionOrchestrator.run()``. Fields are declared in the same order
    as the tuple positions:

        [0] dialogue_nodes       [5] stmt_chunk_edges
        [1] chunk_nodes          [6] stmt_entity_edges
        [2] statement_nodes      [7] entity_entity_edges
        [3] entity_nodes         [8] perceptual_edges
        [4] perceptual_nodes     [9] dialog_data_list

    Every field is typed ``List[Any]`` rather than a concrete graph-model
    type so that loading this module does not trigger a circular import; the
    pipeline only passes these values along and never inspects their types.
    """

    dialogue_nodes: List[Any]
    chunk_nodes: List[Any]
    statement_nodes: List[Any]
    entity_nodes: List[Any]
    perceptual_nodes: List[Any]
    stmt_chunk_edges: List[Any]
    stmt_entity_edges: List[Any]
    entity_entity_edges: List[Any]
    perceptual_edges: List[Any]
    dialog_data_list: List[Any]

    @property
    def stats(self) -> Dict[str, int]:
        """Return per-collection counts, used for WriteResult and logging."""
        # (key, collection) pairs; the key order is the order of the result.
        counted = (
            ("dialogue_count", self.dialogue_nodes),
            ("chunk_count", self.chunk_nodes),
            ("statement_count", self.statement_nodes),
            ("entity_count", self.entity_nodes),
            ("perceptual_count", self.perceptual_nodes),
            ("relation_count", self.entity_entity_edges),
        )
        return {key: len(items) for key, items in counted}
@dataclass
class WriteResult:
    """Final output of the write pipeline, returned to MemoryService / MemoryAgentService."""

    # One of "success", "pilot_complete" or "failed".
    status: str
    # Counts taken from ExtractionResult.stats.
    extraction: Optional[Dict[str, int]] = None
    # Error message when the run failed.
    error: Optional[str] = None
    # Total elapsed time, in seconds.
    elapsed_seconds: float = 0.0
# ──────────────────────────────────────────────
# WritePipeline
# ──────────────────────────────────────────────
class WritePipeline:
"""
记忆写入流水线
编排完整的写入流程:预处理 → 萃取 → 存储 → 聚类 → 摘要。
"""
def __init__(
    self,
    memory_config: MemoryConfig,
    end_user_id: str,
    language: str = "zh",
    progress_callback: Optional[
        Callable[[str, str, Optional[Dict[str, Any]]], Awaitable[None]]
    ] = None,
):
    """Store the run parameters; no client or connection is created here.

    Args:
        memory_config: Memory configuration object (described by the
            original author as immutable and loaded from the database).
        end_user_id: ID of the end user.
        language: Language code, ``"zh"`` or ``"en"``.
        progress_callback: Optional async progress callback with the
            signature ``(stage, message, data?) -> Awaitable[None]``.
    """
    # Lazily initialised: these stay unset until the init helpers run.
    self._llm_client = None
    self._embedder_client = None
    self._neo4j_connector = None

    self.progress_callback = progress_callback
    self.language = language
    self.end_user_id = end_user_id
    self.memory_config = memory_config
# ──────────────────────────────────────────────
# 公开接口
# ──────────────────────────────────────────────
async def run(
    self,
    messages: List[dict],
    ref_id: str = "",
    is_pilot_run: bool = False,
) -> WriteResult:
    """Run the write pipeline end to end.

    Steps, in order: preprocess -> extract -> store -> cluster -> summarize,
    followed by an activity-stats cache update. A pilot run returns right
    after extraction and skips the store, cluster, summarize and cache steps.

    Args:
        messages: Structured messages, e.g.
            ``[{"role": "user" | "assistant", "content": "..."}]``.
        ref_id: Reference ID; a random UUID hex string is used when empty.
        is_pilot_run: When true, stop after the extraction step.

    Returns:
        A WriteResult with status ``"pilot_complete"`` for a pilot run and
        ``"success"`` otherwise, carrying the extraction counts and the
        total elapsed time in seconds.

    Raises:
        Exception: Any error escaping a step is logged and re-raised.
            ``_cleanup()`` runs in every case.
    """
    if not ref_id:
        ref_id = uuid.uuid4().hex
    mode = "试运行" if is_pilot_run else "正式"
    # perf_counter() is monotonic; time.time() follows the wall clock and can
    # jump (NTP sync, manual change), which would corrupt the durations.
    pipeline_start = time.perf_counter()
    logger.info(
        f"[WritePipeline] 开始 ({mode}) "
        f"config={self.memory_config.config_name}, "
        f"end_user={self.end_user_id}"
    )
    try:
        # Build the clients and the connector inside the try block so the
        # finally clause releases whatever was created before a failure.
        self._init_clients()
        self._init_neo4j_connector()

        # Step 1: preprocess - split the messages into chunks.
        step_start = time.perf_counter()
        chunked_dialogs = await self._preprocess(messages, ref_id)
        chunks_count = sum(len(d.chunks) for d in chunked_dialogs)
        logger.info(
            f"[WritePipeline] [1/5] 预处理:消息分块 "
            f"{time.perf_counter() - step_start:.2f}s chunks={chunks_count}"
        )

        # Step 2: extract - knowledge extraction.
        step_start = time.perf_counter()
        extraction_result = await self._extract(
            chunked_dialogs, is_pilot_run
        )
        stats = extraction_result.stats
        logger.info(
            f"[WritePipeline] [2/5] 萃取:知识提取 "
            f"{time.perf_counter() - step_start:.2f}s "
            f"entities={stats['entity_count']}, "
            f"statements={stats['statement_count']}, "
            f"relations={stats['relation_count']}"
        )

        # A pilot run ends here.
        if is_pilot_run:
            elapsed = time.perf_counter() - pipeline_start
            logger.info(
                f"[WritePipeline] 完成(试运行) ✔ {elapsed:.2f}s"
            )
            return WriteResult(
                status="pilot_complete",
                extraction=stats,
                elapsed_seconds=elapsed,
            )

        # Step 3: store - write to Neo4j.
        step_start = time.perf_counter()
        await self._store(extraction_result)
        logger.info(
            f"[WritePipeline] [3/5] 存储:写入 Neo4j "
            f"{time.perf_counter() - step_start:.2f}s"
        )

        # Step 4: cluster - submit the incremental community update; only
        # the task submission is awaited here.
        step_start = time.perf_counter()
        await self._cluster(extraction_result)
        logger.info(
            f"[WritePipeline] [4/5] 聚类:增量更新社区 "
            f"{time.perf_counter() - step_start:.2f}s mode=async"
        )

        # Step 5: summarize - generate episodic memory summaries.
        step_start = time.perf_counter()
        await self._summarize(chunked_dialogs)
        logger.info(
            f"[WritePipeline] [5/5] 摘要:生成情景记忆 "
            f"{time.perf_counter() - step_start:.2f}s"
        )

        # Refresh the activity statistics cache.
        await self._update_stats_cache(extraction_result)

        elapsed = time.perf_counter() - pipeline_start
        logger.info(
            f"[WritePipeline] 完成 ✔ {elapsed:.2f}s"
        )
        return WriteResult(
            status="success",
            extraction=stats,
            elapsed_seconds=elapsed,
        )
    except Exception as e:
        elapsed = time.perf_counter() - pipeline_start
        logger.error(
            f"[WritePipeline] 失败 ✘ {elapsed:.2f}s error={e}",
            exc_info=True,
        )
        raise
    finally:
        await self._cleanup()
# ──────────────────────────────────────────────
# Step 1: 预处理
# ──────────────────────────────────────────────
async def _preprocess(
    self, messages: List[dict], ref_id: str
) -> List[DialogData]:
    """Turn raw messages into chunked dialogs.

    This is a thin delegation to ``get_chunked_dialogs()``. According to the
    original author's note (that helper is not visible here), it performs
    message validation, optional semantic pruning and dialogue chunking.

    Args:
        messages: Structured chat messages.
        ref_id: Reference ID forwarded to the chunking helper.

    Returns:
        Whatever ``get_chunked_dialogs()`` returns (annotated as a list of
        DialogData).
    """
    from app.core.memory.agent.utils.get_dialogs import get_chunked_dialogs

    cfg = self.memory_config
    chunking_kwargs = {
        "chunker_strategy": cfg.chunker_strategy,
        "end_user_id": self.end_user_id,
        "messages": messages,
        "ref_id": ref_id,
        "config_id": str(cfg.config_id),
    }
    return await get_chunked_dialogs(**chunking_kwargs)
# ──────────────────────────────────────────────
# Step 2: 萃取
# ──────────────────────────────────────────────
async def _extract(
    self,
    chunked_dialogs: List[DialogData],
    is_pilot_run: bool,
) -> ExtractionResult:
    """Run knowledge extraction and wrap its output in an ExtractionResult.

    An ExtractionOrchestrator is built from the pipeline's clients,
    connector and configuration and then run on the chunked dialogs. Its
    return value is unpacked as exactly ten elements, in ExtractionResult
    field order.

    Args:
        chunked_dialogs: Output of the preprocessing step.
        is_pilot_run: Forwarded unchanged to the orchestrator.
    """
    from app.core.memory.storage_services.extraction_engine.extraction_orchestrator import (
        ExtractionOrchestrator,
    )
    from app.core.memory.utils.config.config_utils import get_pipeline_config

    cfg = self.memory_config
    pipeline_config = get_pipeline_config(cfg)
    ontology_types = self._load_ontology_types()

    engine = ExtractionOrchestrator(
        llm_client=self._llm_client,
        embedder_client=self._embedder_client,
        connector=self._neo4j_connector,
        config=pipeline_config,
        embedding_id=str(cfg.embedding_model_id),
        language=self.language,
        ontology_types=ontology_types,
        progress_callback=self.progress_callback,
    )

    # Unpacking fixes the arity at ten; any other length raises here.
    (
        dialogues,
        chunks,
        statements,
        entities,
        perceptuals,
        stmt_to_chunk,
        stmt_to_entity,
        entity_to_entity,
        perceptual_links,
        dialogs,
    ) = await engine.run(chunked_dialogs, is_pilot_run=is_pilot_run)

    # Positional arguments follow ExtractionResult's declared field order.
    return ExtractionResult(
        dialogues,
        chunks,
        statements,
        entities,
        perceptuals,
        stmt_to_chunk,
        stmt_to_entity,
        entity_to_entity,
        perceptual_links,
        dialogs,
    )
# ──────────────────────────────────────────────
# Step 3: 存储
# ──────────────────────────────────────────────
async def _store(self, result: ExtractionResult) -> None:
    """Clean entity aliases, then write the extraction result to Neo4j.

    Error policy:
    - Alias cleaning is delegated to ``_clean_cross_role_aliases``, which
      handles its own failures.
    - A falsy return value from the save call, or an exception recognised
      by ``_is_deadlock``, is retried after a linearly growing delay
      (1s, then 2s), for at most three attempts in total.
    - Any other exception, or a deadlock on the last attempt, is re-raised.
    - If the save still reports failure after the last attempt, a
      RuntimeError is raised. Previously this case was only logged, so the
      pipeline carried on and reported success for data that was not fully
      written.

    Raises:
        RuntimeError: When every attempt returned a falsy result.
        Exception: Whatever the save call raised, per the policy above.
    """
    from app.repositories.neo4j.graph_saver import (
        save_dialog_and_statements_to_neo4j,
    )

    # 1. Alias cleaning before the write (its failures do not abort).
    await self._clean_cross_role_aliases(result.entity_nodes)

    # 2. Neo4j write with retries.
    max_retries = 3
    for attempt in range(max_retries):
        is_last_attempt = attempt == max_retries - 1
        try:
            success = await save_dialog_and_statements_to_neo4j(
                dialogue_nodes=result.dialogue_nodes,
                chunk_nodes=result.chunk_nodes,
                statement_nodes=result.statement_nodes,
                entity_nodes=result.entity_nodes,
                perceptual_nodes=result.perceptual_nodes,
                statement_chunk_edges=result.stmt_chunk_edges,
                statement_entity_edges=result.stmt_entity_edges,
                entity_edges=result.entity_entity_edges,
                perceptual_edges=result.perceptual_edges,
                connector=self._neo4j_connector,
            )
        except Exception as e:
            if is_last_attempt or not self._is_deadlock(e):
                raise
            logger.warning(
                f"Neo4j 死锁,重试 ({attempt + 2}/{max_retries})"
            )
        else:
            if success:
                logger.info("Successfully saved all data to Neo4j")
                return
            if is_last_attempt:
                logger.error(
                    f"Neo4j 写入在 {max_retries} 次尝试后仍部分失败"
                )
                # Surface the failure instead of letting the caller report
                # a successful write.
                raise RuntimeError(
                    f"Neo4j write still incomplete after {max_retries} attempts"
                )
            logger.warning(
                f"Neo4j 写入部分失败,重试 ({attempt + 2}/{max_retries})"
            )
        # Linear backoff before the next attempt: 1s, then 2s.
        await asyncio.sleep(1 * (attempt + 1))
# ──────────────────────────────────────────────
# Step 4: 聚类
# ──────────────────────────────────────────────
async def _cluster(self, result: ExtractionResult) -> None:
    """Submit the incremental-clustering Celery task for the new entities.

    Nothing is submitted when the extraction produced no entities. Only the
    task submission happens here; any error while submitting is logged and
    swallowed so that it does not affect the write result.
    """
    entities = result.entity_nodes
    if not entities:
        return

    try:
        from app.tasks import run_incremental_clustering

        cfg = self.memory_config
        entity_ids = [node.id for node in entities]
        # Model IDs are passed as strings, or None when not configured.
        task_kwargs = {
            "end_user_id": self.end_user_id,
            "new_entity_ids": entity_ids,
            "llm_model_id": (
                str(cfg.llm_model_id) if cfg.llm_model_id else None
            ),
            "embedding_model_id": (
                str(cfg.embedding_model_id)
                if cfg.embedding_model_id
                else None
            ),
        }
        submitted = run_incremental_clustering.apply_async(
            kwargs=task_kwargs,
            priority=3,
        )
        logger.info(
            f"[Clustering] 增量聚类任务已提交 - "
            f"task_id={submitted.id}, entity_count={len(entity_ids)}"
        )
    except Exception as e:
        logger.error(
            f"[Clustering] 提交聚类任务失败(不影响主流程): {e}",
            exc_info=True,
        )
# ──────────────────────────────────────────────
# Step 5: 摘要
# + entity_description
# ──────────────────────────────────────────────
async def _summarize(self, chunked_dialogs: List[DialogData]) -> None:
    """Generate memory summaries and write them to Neo4j (best effort).

    Summaries are produced by ``memory_summary_generation`` and written
    through a dedicated ``Neo4jConnector`` created for this step (the
    original author notes this avoids transaction conflicts with the main
    connector). Any failure is logged and swallowed so that it does not
    affect the main flow.

    Args:
        chunked_dialogs: Output of the preprocessing step.
    """
    from app.core.memory.storage_services.extraction_engine.knowledge_extraction.memory_summary import (
        memory_summary_generation,
    )
    from app.repositories.neo4j.add_edges import (
        add_memory_summary_statement_edges,
    )
    from app.repositories.neo4j.add_nodes import add_memory_summary_nodes
    from app.repositories.neo4j.neo4j_connector import Neo4jConnector

    try:
        summaries = await memory_summary_generation(
            chunked_dialogs,
            llm_client=self._llm_client,
            embedder_client=self._embedder_client,
            language=self.language,
        )
        ms_connector = Neo4jConnector()
        try:
            await add_memory_summary_nodes(summaries, ms_connector)
            await add_memory_summary_statement_edges(
                summaries, ms_connector
            )
        finally:
            try:
                await ms_connector.close()
            except Exception as close_err:
                # Still non-fatal, but no longer silent: a connector that
                # fails to close is worth seeing in the logs.
                logger.warning(
                    f"Error closing memory summary Neo4j connector: {close_err}"
                )
    except Exception as e:
        logger.error(f"Memory summary step failed: {e}", exc_info=True)
# ──────────────────────────────────────────────
# 辅助方法
# ──────────────────────────────────────────────
def _init_clients(self) -> None:
    """Build the LLM and embedding clients from the memory configuration.

    A ``MemoryClientFactory`` is created on a database session obtained
    from ``get_db_context()``; the session is only held for the duration of
    the two factory calls.
    """
    from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
    from app.db import get_db_context

    cfg = self.memory_config
    with get_db_context() as session:
        client_factory = MemoryClientFactory(session)
        # Assigned one at a time so the LLM client is already stored on the
        # instance if building the embedder fails.
        self._llm_client = client_factory.get_llm_client_from_config(cfg)
        self._embedder_client = (
            client_factory.get_embedder_client_from_config(cfg)
        )
    logger.info("LLM and embedding clients constructed")
def _init_neo4j_connector(self) -> None:
    """Create the Neo4j connector used by this pipeline instance."""
    from app.repositories.neo4j.neo4j_connector import Neo4jConnector

    connector = Neo4jConnector()
    self._neo4j_connector = connector
def _load_ontology_types(self):
    """Load the ontology types linked to the configured scene, if any.

    Returns:
        ``None`` when the configuration has no ``scene_id`` or when loading
        fails (the failure is logged as a warning); otherwise the value
        returned by ``load_ontology_types_for_scene``.
    """
    scene_id = self.memory_config.scene_id
    if not scene_id:
        return None

    try:
        from app.core.memory.ontology_services.ontology_type_loader import (
            load_ontology_types_for_scene,
        )
        from app.db import get_db_context

        with get_db_context() as db:
            loaded = load_ontology_types_for_scene(
                scene_id=scene_id,
                workspace_id=self.memory_config.workspace_id,
                db=db,
            )
            if loaded:
                logger.info(
                    f"Loaded {len(loaded.types)} ontology types "
                    f"for scene_id: {scene_id}"
                )
            return loaded
    except Exception as e:
        logger.warning(
            f"Failed to load ontology types for scene_id "
            f"{scene_id}: {e}",
            exc_info=True,
        )
        return None
async def _clean_cross_role_aliases(
    self, entity_nodes: List[ExtractedEntityNode]
) -> None:
    """Remove alias cross-contamination between user and assistant entities.

    Existing assistant aliases are fetched from Neo4j for the end user of
    the first entity (when there is one) and passed to
    ``clean_cross_role_aliases`` as an external exclusion set. Any failure
    is logged as a warning and does not interrupt the caller.

    Args:
        entity_nodes: Entities extracted in this round; handed to
            ``clean_cross_role_aliases`` (presumably modified in place —
            verify against that helper).
    """
    try:
        from app.core.memory.storage_services.extraction_engine.deduplication.deduped_and_disamb import (
            clean_cross_role_aliases,
            fetch_neo4j_assistant_aliases,
        )

        assistant_aliases = set()
        # The end user is taken from the first entity only.
        owner_id = entity_nodes[0].end_user_id if entity_nodes else None
        if owner_id:
            assistant_aliases = await fetch_neo4j_assistant_aliases(
                self._neo4j_connector, owner_id
            )

        clean_cross_role_aliases(
            entity_nodes,
            external_assistant_aliases=assistant_aliases,
        )
        logger.info(
            f"别名清洗完成AI助手别名排除集大小: "
            f"{len(assistant_aliases)}"
        )
    except Exception as e:
        logger.warning(f"别名清洗失败(不影响主流程): {e}")
@staticmethod
def _is_deadlock(e: Exception) -> bool:
    """Return True when the exception text mentions a deadlock (case-insensitive)."""
    # "deadlockdetected" contains "deadlock", so one substring test covers
    # both spellings.
    return "deadlock" in str(e).lower()
async def _update_stats_cache(
    self, result: ExtractionResult
) -> None:
    """Publish the extraction counts to the activity-stats cache.

    The counts are stored under the configuration's ``workspace_id`` via
    ``ActivityStatsCache.set_activity_stats``. Any failure is logged as a
    warning and does not interrupt the caller.
    """
    try:
        from app.cache.memory.activity_stats_cache import (
            ActivityStatsCache,
        )

        counts = result.stats
        workspace_id = self.memory_config.workspace_id
        payload = {
            "chunk_count": counts["chunk_count"],
            "statements_count": counts["statement_count"],
            "triplet_entities_count": counts["entity_count"],
            "triplet_relations_count": counts["relation_count"],
            # Always written as 0 by this pipeline.
            "temporal_count": 0,
        }
        await ActivityStatsCache.set_activity_stats(
            workspace_id=str(workspace_id),
            stats=payload,
        )
        logger.info(
            f"活动统计已写入 Redis: "
            f"workspace_id={workspace_id}"
        )
    except Exception as e:
        logger.warning(
            f"写入活动统计缓存失败(不影响主流程): {e}"
        )
async def _cleanup(self) -> None:
    """Release the Neo4j connector and the clients' HTTP sessions.

    Called from the ``finally`` clause of ``run()``. Every close is best
    effort: failures are logged and never raised. The instance's references
    are reset to ``None`` first, so a repeated call is a no-op and a closed
    connector or client is not reused by mistake.
    """
    # Close the Neo4j connector.
    connector, self._neo4j_connector = self._neo4j_connector, None
    if connector:
        try:
            await connector.close()
        except Exception as e:
            logger.error(f"Error closing Neo4j connector: {e}")

    # Close the HTTP clients underneath the LLM / embedder wrappers. The
    # original author notes this prevents "RuntimeError: Event loop is
    # closed" from being raised during garbage collection.
    clients = (self._llm_client, self._embedder_client)
    self._llm_client = None
    self._embedder_client = None
    for client_obj in clients:
        try:
            # getattr on None falls through to the defaults, so unset
            # clients are skipped by the check below.
            underlying = getattr(
                client_obj, "client", None
            ) or getattr(client_obj, "model", None)
            if underlying is None:
                continue
            inner = getattr(underlying, "_model", underlying)
            http_client = getattr(inner, "async_client", None)
            if http_client is not None and hasattr(
                http_client, "aclose"
            ):
                await http_client.aclose()
        except Exception as e:
            # Non-fatal, but kept visible for debugging instead of being
            # silently discarded.
            logger.debug(f"Error closing HTTP client: {e}")

View File

@@ -0,0 +1,184 @@
"""
BearLogger — 结构化任务日志工具
在大量中间模块日志中提供醒目的 Pipeline 步骤进度标记。
基于标准 logging.Logger不修改现有日志配置。
设计要点:
- 每个 step 只输出一行完成日志(不输出"开始"行,减少噪音)
- Pipeline 开始/结束用 ═══ 粗分隔线,在终端中一眼可辨
- step 完成行用 ▶ 图标 + 固定宽度对齐,紧凑且整齐
- 性能告警用 ⚡ 标记,超过阈值自动触发
"""
from __future__ import annotations
import logging
import time
import uuid
from contextlib import asynccontextmanager
from contextvars import ContextVar
from typing import Any, Dict, Optional
# ── Context variable holding the current pipeline's trace id ──
# It is set on entry to BearLogger.pipeline() and reset on exit.
_trace_id: ContextVar[str] = ContextVar("bear_trace_id", default="")
# ── Default performance thresholds, in seconds ──
# Keys are step categories; BearLogger.step() looks its `category` up here,
# and a step running longer than its threshold is flagged in the log line.
DEFAULT_PERF_THRESHOLDS: Dict[str, float] = {
"预处理": 10,
"萃取": 60,
"存储": 30,
"聚类": 5,
"摘要": 30,
}
class _StepScope:
    """State and metadata for a single pipeline step.

    The scope records when the step started and emits exactly one log line
    when the step ends: an info line on success or an error line on failure.
    """

    def __init__(
        self,
        logger: logging.Logger,
        index: int,
        total: int,
        category: str,
        description: str,
        threshold: Optional[float] = None,
    ):
        self._logger = logger
        self._index, self._total = index, total
        self._category, self._description = category, description
        self._threshold = threshold
        self._start_time = 0.0
        self._kv: Dict[str, Any] = {}

    def metadata(self, **kv: Any) -> None:
        """Attach key/value pairs shown at the end of the completion line."""
        self._kv.update(kv)

    def _start(self) -> None:
        self._start_time = time.time()

    def _label(self) -> str:
        """Return the "[index/total] categorydescription " part shared by both log lines."""
        return (
            f"[{self._index}/{self._total}] "
            f"{self._category}{self._description} "
        )

    def _succeed(self) -> None:
        elapsed = time.time() - self._start_time

        # Performance alert: tag the line when a truthy threshold is exceeded.
        status = f"{elapsed:.2f}s"
        if self._threshold and elapsed > self._threshold:
            status += " [SLOW]"

        # Metadata suffix, empty when nothing was attached.
        pairs = ", ".join(f"{k}={v}" for k, v in self._kv.items())
        kv_str = " " + pairs if self._kv else ""

        self._logger.info(f" ▶ {self._label()}── {status}{kv_str}")

    def _fail(self, error: Exception) -> None:
        elapsed = time.time() - self._start_time
        self._logger.error(
            f" ✘ {self._label()}── FAILED {elapsed:.2f}s error={error}"
        )
class BearLogger:
    """Structured task logger for pipeline progress.

    Usage::

        bear = BearLogger("memory.pipeline")
        async with bear.pipeline("WritePipeline", mode="正式"):
            async with bear.step(1, 5, "预处理", "消息分块") as s:
                result = await preprocess()
                s.metadata(chunks=3)
    """

    def __init__(
        self,
        name: str = "memory.pipeline",
        perf_thresholds: Optional[Dict[str, float]] = None,
    ):
        """
        Args:
            name: Name of the underlying ``logging`` logger.
            perf_thresholds: Per-category thresholds in seconds. ``None``
                selects ``DEFAULT_PERF_THRESHOLDS``; an explicit mapping,
                including an empty one (no alerts), is used as given.
        """
        self._logger = logging.getLogger(name)
        # Test for None explicitly: `perf_thresholds or DEFAULT...` would
        # silently replace an intentionally empty mapping with the defaults.
        self._thresholds = (
            DEFAULT_PERF_THRESHOLDS
            if perf_thresholds is None
            else perf_thresholds
        )

    @asynccontextmanager
    async def pipeline(self, name: str, **context_kv: Any):
        """Pipeline-level scope that frames start and end with heavy rules.

        A fresh 8-character trace id is stored in the ``_trace_id`` context
        variable for the duration of the scope. An exception raised inside
        the scope is logged as a failure and re-raised.
        """
        # The rule character was missing, which made `'' * 60` an empty
        # string; the module docstring specifies "═══" separator lines.
        rule = "═" * 60
        trace_id = uuid.uuid4().hex[:8]
        token = _trace_id.set(trace_id)
        start = time.time()
        ctx_str = ", ".join(f"{k}={v}" for k, v in context_kv.items())
        self._logger.info(
            f"{rule}\n"
            f" 🚀 {name} 开始 {ctx_str}\n"
            f"{rule}"
        )
        error = None
        try:
            yield self
        except Exception as e:
            error = e
            raise
        finally:
            elapsed = time.time() - start
            if error:
                self._logger.error(
                    f"{rule}\n"
                    f"{name} 失败 ({elapsed:.2f}s) error={error}\n"
                    f"{rule}"
                )
            else:
                self._logger.info(
                    f"{rule}\n"
                    f"{name} 完成 ({elapsed:.2f}s)\n"
                    f"{rule}"
                )
            _trace_id.reset(token)

    @asynccontextmanager
    async def step(
        self,
        index: int,
        total: int,
        category: str,
        description: str,
    ):
        """Step-level scope that logs a single line when the step ends.

        Yields a ``_StepScope`` on which the caller can attach metadata. The
        performance threshold is looked up by ``category``.
        """
        scope = _StepScope(
            logger=self._logger,
            index=index,
            total=total,
            category=category,
            description=description,
            threshold=self._thresholds.get(category),
        )
        scope._start()
        try:
            yield scope
        except Exception as e:
            scope._fail(e)
            raise
        else:
            scope._succeed()

    def info(self, message: str, **kv: Any) -> None:
        """Log ``message`` at info level, followed by any ``key=value`` pairs."""
        suffix = ""
        if kv:
            suffix = " " + ", ".join(f"{k}={v}" for k, v in kv.items())
        self._logger.info(f"{message}{suffix}")

View File

@@ -97,7 +97,7 @@ async def render_statement_extraction_prompt(
}) })
return rendered_prompt return rendered_prompt
# TODO temporal与statement prompt合并在一起以下代码不需要
async def render_temporal_extraction_prompt( async def render_temporal_extraction_prompt(
ref_dates: dict, ref_dates: dict,
statement: dict, statement: dict,
@@ -198,6 +198,7 @@ def render_entity_dedup_prompt(
# Args: # Args:
# entity_a: Dict of entity A attributes # entity_a: Dict of entity A attributes
async def render_triplet_extraction_prompt( async def render_triplet_extraction_prompt(
statement: str, statement: str,
chunk_content: str, chunk_content: str,

View File

@@ -2,8 +2,7 @@ import os
from contextlib import contextmanager from contextlib import contextmanager
from typing import Generator from typing import Generator
from sqlalchemy import create_engine from sqlalchemy import create_engine
from sqlalchemy.orm import Session, sessionmaker from sqlalchemy.orm import Session, sessionmaker, declarative_base
from sqlalchemy.ext.declarative import declarative_base
from app.core.config import settings from app.core.config import settings
SQLALCHEMY_DATABASE_URL = f"postgresql://{settings.DB_USER}:{settings.DB_PASSWORD}@{settings.DB_HOST}:{settings.DB_PORT}/{settings.DB_NAME}" SQLALCHEMY_DATABASE_URL = f"postgresql://{settings.DB_USER}:{settings.DB_PASSWORD}@{settings.DB_HOST}:{settings.DB_PORT}/{settings.DB_NAME}"

View File

@@ -17,10 +17,9 @@ async def create_fulltext_indexes():
# CREATE FULLTEXT INDEX dialoguesFulltext IF NOT EXISTS FOR (d:Dialogue) ON EACH [d.content] # CREATE FULLTEXT INDEX dialoguesFulltext IF NOT EXISTS FOR (d:Dialogue) ON EACH [d.content]
# OPTIONS { indexConfig: { `fulltext.analyzer`: 'cjk' } } # OPTIONS { indexConfig: { `fulltext.analyzer`: 'cjk' } }
# """) # """)
# 创建 Entities 索引 # 创建 Entities 索引 (name + description + aliases)
await connector.execute_query(""" await connector.execute_query("""
CREATE FULLTEXT INDEX entitiesFulltext IF NOT EXISTS CREATE FULLTEXT INDEX entitiesFulltext IF NOT EXISTS FOR (e:ExtractedEntity) ON EACH [e.name, e.description, e.aliases]
FOR (e:ExtractedEntity) ON EACH [e.name, e.description, e.aliases]
OPTIONS { indexConfig: { `fulltext.analyzer`: 'cjk' } } OPTIONS { indexConfig: { `fulltext.analyzer`: 'cjk' } }
""") """)

View File

@@ -4,7 +4,7 @@ Order Schema
Defines request and response models for order operations. Defines request and response models for order operations.
""" """
from pydantic import BaseModel, Field from pydantic import BaseModel, ConfigDict, Field
from typing import Any, Optional from typing import Any, Optional
@@ -17,8 +17,8 @@ class CreateOrderRequest(BaseModel):
customer_email: Optional[str] = Field(None, description="Customer email") customer_email: Optional[str] = Field(None, description="Customer email")
notes: Optional[str] = Field(None, description="Order notes") notes: Optional[str] = Field(None, description="Order notes")
class Config: model_config = ConfigDict(
json_schema_extra = { json_schema_extra={
"example": { "example": {
"product_id": "PROD-001", "product_id": "PROD-001",
"quantity": 2, "quantity": 2,
@@ -27,6 +27,7 @@ class CreateOrderRequest(BaseModel):
"notes": "Please deliver before 5pm" "notes": "Please deliver before 5pm"
} }
} }
)
class OrderResponse(BaseModel): class OrderResponse(BaseModel):
@@ -40,8 +41,8 @@ class OrderResponse(BaseModel):
created_at: Optional[str] = Field(None, description="Creation timestamp") created_at: Optional[str] = Field(None, description="Creation timestamp")
message: Optional[str] = Field(None, description="Response message") message: Optional[str] = Field(None, description="Response message")
class Config: model_config = ConfigDict(
json_schema_extra = { json_schema_extra={
"example": { "example": {
"order_id": "ORD-20231224-001", "order_id": "ORD-20231224-001",
"status": "pending", "status": "pending",
@@ -52,6 +53,7 @@ class OrderResponse(BaseModel):
"message": "Order created successfully" "message": "Order created successfully"
} }
} }
)
class ExternalOrderResponse(BaseModel): class ExternalOrderResponse(BaseModel):

View File

@@ -1,5 +1,5 @@
from dataclasses import field from dataclasses import field
from pydantic import BaseModel, EmailStr, Field, field_validator, validator, ConfigDict from pydantic import BaseModel, EmailStr, Field, field_validator, ConfigDict
from typing import Optional, List from typing import Optional, List
import datetime import datetime
import uuid import uuid
@@ -90,7 +90,8 @@ class User(UserBase):
permissions: Optional[List[str]] = None # 用户权限列表,由 external_source 的 permissions 控制 permissions: Optional[List[str]] = None # 用户权限列表,由 external_source 的 permissions 控制
# 将 datetime 转换为毫秒时间戳 # 将 datetime 转换为毫秒时间戳
@validator("created_at", pre=True) @field_validator("created_at", mode="before")
@classmethod
def _created_at_to_ms(cls, v): def _created_at_to_ms(cls, v):
if isinstance(v, datetime.datetime): if isinstance(v, datetime.datetime):
return int(v.timestamp() * 1000) return int(v.timestamp() * 1000)

View File

@@ -367,6 +367,33 @@ class MemoryAgentService:
ref_id='', ref_id='',
language=language language=language
) )
# ── 影子运行:新流水线静默执行,只记录日志不影响主流程 ──
import os
if os.getenv("SHADOW_PIPELINE_ENABLED", "false").lower() == "true":
try:
from app.core.memory.memory_service import MemoryService
import copy
shadow_messages = copy.deepcopy(messages)
shadow_service = MemoryService(
memory_config=memory_config,
end_user_id=end_user_id,
)
shadow_result = await shadow_service.write(
messages=shadow_messages,
language=language,
ref_id='',
is_pilot_run=True, # 试运行模式:只萃取不写入,避免重复写入 Neo4j
)
logger.info(
f"[Shadow] 新流水线影子运行完成: status={shadow_result.status}, "
f"elapsed={shadow_result.elapsed_seconds:.2f}s, "
f"extraction={shadow_result.extraction}"
)
except Exception as shadow_err:
logger.warning(f"[Shadow] 新流水线影子运行失败(不影响主流程): {shadow_err}")
# ── 影子运行结束 ──
for lang in ["zh", "en"]: for lang in ["zh", "en"]:
deleted = await InterestMemoryCache.delete_interest_distribution( deleted = await InterestMemoryCache.delete_interest_distribution(
end_user_id, lang end_user_id, lang