refactor(memory): consolidate write pipeline and rename statement extraction step

- Rename StatementExtractionStep → StatementTemporalExtractionStep and
  extract_statement.jinja2 → extract_statement_temporal.jinja2 to reflect
  merged temporal extraction logic
- Move extraction_pipeline_orchestrator.py out of steps/ to engine root
- Move dedup_step.py into steps/ directory
- Introduce WriteMemoryRequest schema to replace positional args in write_memory()
- Extract _resolve_and_load_config, _preprocess_files, _write_neo4j, and
  _invalidate_interest_cache as private helpers in MemoryAgentService
- Remove shadow pipeline and simplify NEW_PIPELINE_ENABLED branch
- Merge 类型归属/成员隶属/任职服务 relation types into single 归属身份关系 in triplet prompt
- Add alias merge logic (别名属于) in deduplication and MERGE_ALIAS_BELONGS_TO Cypher query
- Add StorageType, Language, MessageItem enums/models to memory_agent_schema
- Reduce AgentMemory_Long_Term.DEFAULT_SCOPE from 6 to 1
- Delete standalone extract_temporal.jinja2 (logic merged into statement step)
This commit is contained in:
lanceyq
2026-04-28 16:36:30 +08:00
parent 7747ed7ac1
commit 1f0c88a5f0
17 changed files with 390 additions and 326 deletions

View File

@@ -12,11 +12,11 @@ from typing import Any, Awaitable, Callable, Dict, List, Optional
from app.core.memory.models.message_models import DialogData from app.core.memory.models.message_models import DialogData
from app.core.memory.models.variate_config import ExtractionPipelineConfig from app.core.memory.models.variate_config import ExtractionPipelineConfig
from app.core.memory.storage_services.extraction_engine.dedup_step import ( from app.core.memory.storage_services.extraction_engine.steps.dedup_step import (
DedupResult, DedupResult,
run_dedup, run_dedup,
) )
from app.core.memory.storage_services.extraction_engine.steps.extraction_pipeline_orchestrator import ( from app.core.memory.storage_services.extraction_engine.extraction_pipeline_orchestrator import (
NewExtractionOrchestrator, NewExtractionOrchestrator,
) )
from app.core.memory.storage_services.extraction_engine.steps.graph_build_step import ( from app.core.memory.storage_services.extraction_engine.steps.graph_build_step import (

View File

@@ -227,6 +227,14 @@ class WritePipeline:
f"{time.time() - step_start:.2f}s" f"{time.time() - step_start:.2f}s"
) )
# Step 3.2: 别名归并 - 处理 predicate="别名属于" 的关系
step_start = time.time()
await self._merge_alias_belongs_to(extraction_result)
logger.info(
f"[WritePipeline] [3.2] 别名归并:处理别名属于关系 "
f"{time.time() - step_start:.2f}s"
)
# Step 3.5: 异步情绪提取fire-and-forget需在 _store 之后确保 Statement 节点已存在) # Step 3.5: 异步情绪提取fire-and-forget需在 _store 之后确保 Statement 节点已存在)
await self._extract_emotion(getattr(self, "_emotion_statements", [])) await self._extract_emotion(getattr(self, "_emotion_statements", []))
@@ -316,13 +324,13 @@ class WritePipeline:
2. build_graph_nodes_and_edges() → 从 DialogData 构建图节点和边 2. build_graph_nodes_and_edges() → 从 DialogData 构建图节点和边
3. run_dedup() → 两阶段去重消歧 3. run_dedup() → 两阶段去重消歧
""" """
from app.core.memory.storage_services.extraction_engine.dedup_step import ( from app.core.memory.storage_services.extraction_engine.steps.dedup_step import (
run_dedup, run_dedup,
) )
from app.core.memory.storage_services.extraction_engine.steps.graph_build_step import ( from app.core.memory.storage_services.extraction_engine.steps.graph_build_step import (
build_graph_nodes_and_edges, build_graph_nodes_and_edges,
) )
from app.core.memory.storage_services.extraction_engine.steps.extraction_pipeline_orchestrator import ( from app.core.memory.storage_services.extraction_engine.extraction_pipeline_orchestrator import (
NewExtractionOrchestrator, NewExtractionOrchestrator,
) )
@@ -554,6 +562,78 @@ class WritePipeline:
else: else:
raise raise
# ──────────────────────────────────────────────
# Step 3.2: 别名归并
# ──────────────────────────────────────────────
async def _merge_alias_belongs_to(self, result: ExtractionResult) -> None:
"""
所有去重合并都可以使用这个这种统一的处理方式(未实现)
别名归并:处理 predicate="别名属于" 的 EXTRACTED_RELATIONSHIP 边。
对每条 source -[EXTRACTED_RELATIONSHIP {predicate:"别名属于"}]-> target 边:
- 将 source.name 追加到 target.aliases去重
- 将 source.description append 进 target.description_listlist 形式)
同时在内存中同步更新 ExtractionResult.entity_nodes保持内存与 Neo4j 一致。
失败不中断主流程。
"""
from app.repositories.neo4j.cypher_queries import MERGE_ALIAS_BELONGS_TO
ALIAS_PREDICATE = "别名属于"
# 筛选出所有 predicate="别名属于" 的边
alias_edges = [
e for e in result.entity_entity_edges
if getattr(e, "relation_type", "") == ALIAS_PREDICATE
or getattr(e, "predicate", "") == ALIAS_PREDICATE
]
if not alias_edges:
logger.debug("[AliasMerge] 无 '别名属于' 关系,跳过")
return
try:
# ── 1. 在内存中同步更新 entity_nodes ──
entity_map = {e.id: e for e in result.entity_nodes}
for edge in alias_edges:
source_node = entity_map.get(edge.source)
target_node = entity_map.get(edge.target)
if not source_node or not target_node:
continue
# 将 source.name 追加到 target.aliases去重忽略大小写
source_name = (source_node.name or "").strip()
if source_name:
existing_lower = {a.lower() for a in (target_node.aliases or [])}
if source_name.lower() not in existing_lower:
target_node.aliases = list(target_node.aliases or []) + [source_name]
# 将 source.description append 进 target.description追加分号分隔
src_desc = (source_node.description or "").strip()
if src_desc:
tgt_desc = (target_node.description or "").strip()
if src_desc not in tgt_desc:
target_node.description = f"{tgt_desc}{src_desc}" if tgt_desc else src_desc
logger.info(
f"[AliasMerge] 内存同步完成,处理 {len(alias_edges)}'别名属于'"
)
# ── 2. 写入 Neo4j ──
records = await self._neo4j_connector.execute_query(
MERGE_ALIAS_BELONGS_TO,
end_user_id=self.end_user_id,
)
merged_count = len(records) if records else 0
logger.info(
f"[AliasMerge] Neo4j 别名归并完成,影响 {merged_count} 条记录"
)
except Exception as e:
logger.warning(f"[AliasMerge] 别名归并失败(不影响主流程): {e}", exc_info=True)
# ────────────────────────────────────────────── # ──────────────────────────────────────────────
# Step 4: 聚类 # Step 4: 聚类
# ────────────────────────────────────────────── # ──────────────────────────────────────────────

View File

@@ -1112,6 +1112,39 @@ async def deduplicate_entities_and_edges(
# 在主流程这里 这里是之后关系去重和消歧的地方,方法可以写在其他地方 # 在主流程这里 这里是之后关系去重和消歧的地方,方法可以写在其他地方
# 此处统一对边进行处理,使用累积的 id_redirect 把边的 source/target 改成规范ID # 此处统一对边进行处理,使用累积的 id_redirect 把边的 source/target 改成规范ID
# 4) 边重定向与去重 # 4) 边重定向与去重
# 4.0 预处理:将 "别名属于" 关系的 source.name/description 归并到 target 节点
# 必须在边重定向之前执行,此时 id_redirect 已包含精确/模糊/LLM 的合并结果
try:
entity_by_id: Dict[str, ExtractedEntityNode] = {e.id: e for e in deduped_entities}
for edge in entity_entity_edges:
if getattr(edge, "relation_type", "") != "别名属于":
continue
# 通过 id_redirect 找到合并后的规范节点
source_id = id_redirect.get(edge.source, edge.source)
target_id = id_redirect.get(edge.target, edge.target)
if source_id == target_id:
continue
source_node = entity_by_id.get(source_id)
target_node = entity_by_id.get(target_id)
if not source_node or not target_node:
continue
# 将 source.name 追加到 target.aliases去重忽略大小写
source_name = (source_node.name or "").strip()
if source_name:
existing_lower = {a.lower() for a in (target_node.aliases or [])}
if source_name.lower() not in existing_lower and source_name.lower() != (target_node.name or "").lower():
target_node.aliases = list(target_node.aliases or []) + [source_name]
# 将 source.description 追加到 target.description分号分隔去重
src_desc = (source_node.description or "").strip()
if src_desc:
tgt_desc = (target_node.description or "").strip()
if src_desc not in tgt_desc:
target_node.description = f"{tgt_desc}{src_desc}" if tgt_desc else src_desc
except Exception:
pass
# 4.1 语句→实体边:重复时优先保留 strong # 4.1 语句→实体边:重复时优先保留 strong
stmt_ent_map: Dict[str, StatementEntityEdge] = {} stmt_ent_map: Dict[str, StatementEntityEdge] = {}
for edge in statement_entity_edges: for edge in statement_entity_edges:

View File

@@ -23,12 +23,12 @@ from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple
from app.core.memory.models.message_models import DialogData from app.core.memory.models.message_models import DialogData
from app.core.memory.models.variate_config import ExtractionPipelineConfig from app.core.memory.models.variate_config import ExtractionPipelineConfig
from .base import ExtractionStep, StepContext from .steps.base import ExtractionStep, StepContext
from .embedding_step import EmbeddingStep from .steps.embedding_step import EmbeddingStep
from .sidecar_factory import SidecarStepFactory, SidecarTiming from .steps.sidecar_factory import SidecarStepFactory, SidecarTiming
from .statement_step import StatementExtractionStep from .steps.statement_temporal_step import StatementTemporalExtractionStep
from .triplet_step import TripletExtractionStep from .steps.triplet_step import TripletExtractionStep
from .schema import ( from .steps.schema import (
EmbeddingStepInput, EmbeddingStepInput,
EmbeddingStepOutput, EmbeddingStepOutput,
EmotionStepInput, EmotionStepInput,
@@ -85,7 +85,7 @@ class NewExtractionOrchestrator:
) )
# ── Critical (main-line) steps ── # ── Critical (main-line) steps ──
self.statement_step = StatementExtractionStep(self.context) self.statement_temporal_step = StatementTemporalExtractionStep(self.context)
self.triplet_step = TripletExtractionStep( self.triplet_step = TripletExtractionStep(
self.context, ontology_types=ontology_types self.context, ontology_types=ontology_types
) )
@@ -508,7 +508,7 @@ class NewExtractionOrchestrator:
), ),
supporting_context=ctx, supporting_context=ctx,
) )
tasks.append(self.statement_step.run(inp)) tasks.append(self.statement_temporal_step.run(inp))
task_meta.append( task_meta.append(
(dialog.id, chunk.id, chunk_speaker, ctx) (dialog.id, chunk.id, chunk_speaker, ctx)
) )

View File

@@ -7,10 +7,10 @@ for all sidecar (non-critical) steps via SidecarStepFactory.
from .sidecar_factory import SidecarStepFactory, SidecarTiming # noqa: F401 from .sidecar_factory import SidecarStepFactory, SidecarTiming # noqa: F401
# Step implementations — importing triggers @register self-registration. # Step implementations — importing triggers @register self-registration.
from .statement_step import StatementExtractionStep # noqa: F401 from .statement_temporal_step import StatementTemporalExtractionStep # noqa: F401
from .triplet_step import TripletExtractionStep # noqa: F401 from .triplet_step import TripletExtractionStep # noqa: F401
from .emotion_step import EmotionExtractionStep # noqa: F401 from .emotion_step import EmotionExtractionStep # noqa: F401
from .embedding_step import EmbeddingStep # noqa: F401 from .embedding_step import EmbeddingStep # noqa: F401
# Refactored orchestrator # Refactored orchestrator
from .extraction_pipeline_orchestrator import NewExtractionOrchestrator # noqa: F401 from app.core.memory.storage_services.extraction_engine.extraction_pipeline_orchestrator import NewExtractionOrchestrator # noqa: F401

View File

@@ -28,7 +28,7 @@ class SupportingContext(BaseModel):
# ── Statement extraction ── # ── Statement extraction ──
class StatementStepInput(BaseModel): class StatementStepInput(BaseModel):
"""Input for StatementExtractionStep.""" """Input for StatementTemporalExtractionStep."""
chunk_id: str chunk_id: str
end_user_id: str end_user_id: str

View File

@@ -1,4 +1,4 @@
"""StatementExtractionStep — critical step for extracting statements from chunks. """StatementTemporalExtractionStep — critical step for extracting statements and temporal info from chunks.
Replaces the legacy ``StatementExtractor`` with the unified ExtractionStep paradigm. Replaces the legacy ``StatementExtractor`` with the unified ExtractionStep paradigm.
Temporal extraction logic (valid_at / invalid_at) is merged into this step, Temporal extraction logic (valid_at / invalid_at) is merged into this step,
@@ -62,8 +62,8 @@ class _StatementExtractionResponse(BaseModel):
return v return v
class StatementExtractionStep(ExtractionStep[StatementStepInput, List[StatementStepOutput]]): class StatementTemporalExtractionStep(ExtractionStep[StatementStepInput, List[StatementStepOutput]]):
"""Extract atomic statements (with temporal info) from a dialogue chunk. """Extract atomic statements with temporal info (valid_at / invalid_at) from a dialogue chunk.
This is a **critical** step failure aborts the pipeline after retries. This is a **critical** step failure aborts the pipeline after retries.

View File

@@ -65,7 +65,7 @@ async def render_statement_extraction_prompt(
Returns: Returns:
Rendered prompt content as string Rendered prompt content as string
""" """
template = prompt_env.get_template("extract_statement.jinja2") template = prompt_env.get_template("extract_statement_temporal.jinja2")
# Optional clipping of dialogue context # Optional clipping of dialogue context
ctx = None ctx = None
if include_dialogue_context and dialogue_content: if include_dialogue_context and dialogue_content:
@@ -90,7 +90,7 @@ async def render_statement_extraction_prompt(
# 记录渲染结果到提示日志(与示例日志结构一致) # 记录渲染结果到提示日志(与示例日志结构一致)
log_prompt_rendering('statement extraction', rendered_prompt) log_prompt_rendering('statement extraction', rendered_prompt)
# 可选:记录模板渲染信息 # 可选:记录模板渲染信息
log_template_rendering('extract_statement.jinja2', { log_template_rendering('extract_statement_temporal.jinja2', {
'inputs': 'chunk', 'inputs': 'chunk',
'definitions': 'LABEL_DEFINITIONS', 'definitions': 'LABEL_DEFINITIONS',
'json_schema': 'StatementExtractionResponse.schema', 'json_schema': 'StatementExtractionResponse.schema',

View File

@@ -1,126 +0,0 @@
{% macro tidy(name) -%}
{{ name.replace('_', ' ')}}
{%- endmacro %}
{#
This prompt (template) is adapted from [getzep/graphiti]
Licensed under the Apache License, Version 2.0
Original work:
https://github.com/getzep/graphiti/blob/main/graphiti_core/prompts/extract_edge_dates.py
Modifications made by Ke Sun on 2025-09-01
See the LICENSE file for the full Apache 2.0 license text.
#}
# Task
{% if language == "zh" %}
从提供的陈述句中提取时间信息(日期和时间范围)。确定所描述的关系或事件何时生效以及何时结束(如果适用)。
{% else %}
Extract temporal information (dates and time ranges) from the provided statement. Determine when the relationship or event described became valid and when it ended (if applicable).
{% endif %}
# {% if language == "zh" %}输入数据{% else %}Input Data{% endif %}
{% if inputs %}
{% for key, val in inputs.items() %}
- {{ key }}: {{val}}
{% endfor %}
{% endif %}
# {% if language == "zh" %}时间字段{% else %}Temporal Fields{% endif %}
{% if language == "zh" %}
- **valid_at**: 关系/事件开始或成为真实的时间ISO 8601 格式)
- **invalid_at**: 关系/事件结束或停止为真的时间ISO 8601 格式,如果正在进行则为 null
{% else %}
- **valid_at**: When the relationship/event started or became true (ISO 8601 format)
- **invalid_at**: When the relationship/event ended or stopped being true (ISO 8601 format, or null if ongoing)
{% endif %}
# {% if language == "zh" %}提取规则{% else %}Extraction Rules{% endif %}
## {% if language == "zh" %}核心原则{% else %}Core Principles{% endif %}
{% if language == "zh" %}
1. **仅使用明确陈述的时间信息** - 不要从外部知识推断日期
2. **使用参考/发布日期作为"现在"** 解释相对时间时
3. **仅在日期与关系的有效性相关时设置日期** - 忽略偶然的时间提及
4. **对于时间点事件**,仅设置 `valid_at`
{% else %}
1. **Only use explicitly stated temporal information** - do not infer dates from external knowledge
2. **Use the reference/publication date as "now"** when interpreting relative times
3. **Set dates only if they relate to the validity of the relationship** - ignore incidental time mentions
4. **For point-in-time events**, set only `valid_at`
{% endif %}
## {% if language == "zh" %}日期格式要求{% else %}Date Format Requirements{% endif %}
{% if language == "zh" %}
- 使用 ISO 8601: `YYYY-MM-DDTHH:MM:SS.SSSSSSZ`
- 如果未指定时间,使用 `00:00:00`(午夜)
- 如果仅提及年份,根据情况使用 `YYYY-01-01`(开始)或 `YYYY-12-31`(结束)
- 如果仅提及月份,使用月份的第一天或最后一天
- 始终包含时区(如果未指定,使用 `Z` 表示 UTC
- 根据参考日期将相对时间("两周前"、"去年")转换为绝对日期
{% else %}
- Use ISO 8601: `YYYY-MM-DDTHH:MM:SS.SSSSSSZ`
- If no time specified, use `00:00:00` (midnight)
- If only year mentioned, use `YYYY-01-01` (start) or `YYYY-12-31` (end) as appropriate
- If only month mentioned, use first or last day of month
- Always include timezone (use `Z` for UTC if unspecified)
- Convert relative times ("two weeks ago", "last year") to absolute dates based on reference date
{% endif %}
## {% if language == "zh" %}陈述句类型规则{% else %}Statement Type Rules{% endif %}
{{ inputs.get("statement_type") | upper }} {% if language == "zh" %}陈述句指导{% else %}Statement Guidance{% endif %}:
{%for key, guide in statement_guide.items() %}
- {{ tidy(key) | capitalize }}: {{ guide }}
{% endfor %}
**{% if language == "zh" %}特殊情况{% else %}Special Cases{% endif %}:**
{% if language == "zh" %}
- **意见陈述句**: 仅设置 `valid_at`(意见表达的时间)
- **预测陈述句**: 如果明确提及,将 `invalid_at` 设置为预测窗口的结束
{% else %}
- **Opinion statements**: Set only `valid_at` (when opinion was expressed)
- **Prediction statements**: Set `invalid_at` to the end of the prediction window if explicitly mentioned
{% endif %}
## {% if language == "zh" %}时间类型规则{% else %}Temporal Type Rules{% endif %}
{{ inputs.get("temporal_type") | upper }} {% if language == "zh" %}时间类型指导{% else %}Temporal Type Guidance{% endif %}:
{% for key, guide in temporal_guide.items() %}
- {{ tidy(key) | capitalize }}: {{ guide }}
{% endfor %}
{% if inputs.get('quarter') and inputs.get('publication_date') %}
## {% if language == "zh" %}季度参考{% else %}Quarter Reference{% endif %}
{% if language == "zh" %}
假设 {{ inputs.quarter }} 在 {{ inputs.publication_date }} 结束。从此基线计算任何季度引用Q1、Q2 等)的日期。
{% else %}
Assume {{ inputs.quarter }} ends on {{ inputs.publication_date }}. Calculate dates for any quarter references (Q1, Q2, etc.) from this baseline.
{% endif %}
{% endif %}
# {% if language == "zh" %}输出要求{% else %}Output Requirements{% endif %}
## {% if language == "zh" %}JSON 格式化(关键){% else %}JSON Formatting (CRITICAL){% endif %}
{% if language == "zh" %}
1. 使用**仅标准 ASCII 双引号** (") - 永远不要使用中文引号("")或其他 Unicode 变体
2. 使用反斜杠转义内部引号: `\"`
3. JSON 字符串值中不要有换行符
4. 正确关闭并用逗号分隔所有字段
{% else %}
1. Use **only standard ASCII double quotes** (") - never use Chinese quotes ("") or other Unicode variants
2. Escape internal quotes with backslash: `\"`
3. No line breaks within JSON string values
4. Properly close and comma-separate all fields
{% endif %}
## {% if language == "zh" %}语言{% else %}Language{% endif %}
{% if language == "zh" %}
输出语言必须与输入语言匹配。
{% else %}
Output language must match input language.
{% endif %}
{{ json_schema }}

View File

@@ -250,28 +250,12 @@ Primary statement to analyze:
- notes: 只处理名字性表达,不处理角色、职业、评价词。 - notes: 只处理名字性表达,不处理角色、职业、评价词。
- status: `enabled` - status: `enabled`
- `类型归属关系` - `归属身份关系`
- definition: 表达实体属于某种类别,或主体承担某种角色/职业身份的关系。 - definition: 表达主体所属的类别、身份、职业、角色,或其与组织、群体、集合之间的归属关系。
- covered_predicates: `属于类型`、`担任角色`、`从事职业` - covered_predicates: `属于类型`、`担任角色`、`从事职业`、`成员属于`、`任职于`
- positive_examples: `王教授 -> 担任角色 -> 导师`、`张三 -> 从事职业 -> 程序员` - positive_examples: `王教授 -> 担任角色 -> 导师`、`张三 -> 从事职业 -> 程序员`、`张三 -> 成员属于 -> 实验室成员`、`张明 -> 任职于 -> 腾讯`
- negative_examples: `张三 -> 担任角色 -> 山哥`、`用户 -> 从事职业 -> 紧张` - negative_examples: `张三 -> 担任角色 -> 山哥`、`他们 -> 成员属于 -> 学校`、`用户 -> 任职于 -> 明天的面试`、`用户 -> 从事职业 -> 紧张`
- notes: 用于“是什么”,不用于“叫什么” - notes: 这是一个上位父类,用于统一承接“是什么身份”与“归属哪里”两类关系。第一层不再强行区分“身份类归属”和“组织类归属”,真正的区分在子类 predicate 层完成
- status: `enabled`
- `成员隶属关系`
- definition: 表达主体属于某个组织、群体或集合的成员归属关系。
- covered_predicates: `成员属于`
- positive_examples: `张三 -> 成员属于 -> 实验室成员`、`用户 -> 成员属于 -> 社群`
- negative_examples: `他们 -> 成员属于 -> 学校`、`一个朋友 -> 成员属于 -> 班级`
- notes: 前提是主体和归属对象都足够稳定;边界不稳的人群不要硬抽。
- status: `enabled`
- `任职服务关系`
- definition: 表达人物或主体在组织中的工作、任职或服务关系。
- covered_predicates: `任职于`
- positive_examples: `张明 -> 任职于 -> 腾讯`、`王教授 -> 任职于 -> 清华大学`
- negative_examples: `张明 -> 任职于 -> 导师`、`用户 -> 任职于 -> 明天的面试`
- notes: 优先用于人物到组织的稳定供职关系。
- status: `enabled` - status: `enabled`
- `空间位置关系` - `空间位置关系`

View File

@@ -1107,6 +1107,35 @@ RETURN (
) AS is_complete ) AS is_complete
""" """
# 别名归并:将 predicate="别名属于" 的 EXTRACTED_RELATIONSHIP 边的 source.name
# 合并进 target.aliases去重并将 source.description 追加到 target.description分号分隔
MERGE_ALIAS_BELONGS_TO = """
MATCH (source:ExtractedEntity {end_user_id: $end_user_id})-[r:EXTRACTED_RELATIONSHIP]->(target:ExtractedEntity {end_user_id: $end_user_id})
WHERE r.predicate = '别名属于'
WITH source, target,
coalesce(target.aliases, []) AS existing_aliases,
source.name AS source_name,
coalesce(source.description, '') AS src_desc,
coalesce(target.description, '') AS tgt_desc
// 1. 合并 aliases将 source.name 追加到 target.aliases去重
WITH source, target, src_desc, tgt_desc,
CASE
WHEN source_name IS NOT NULL AND source_name <> '' AND NOT source_name IN existing_aliases
THEN existing_aliases + source_name
ELSE existing_aliases
END AS new_aliases
SET target.aliases = new_aliases,
target.description = CASE
WHEN src_desc <> '' AND NOT src_desc IN tgt_desc
THEN CASE WHEN tgt_desc = '' THEN src_desc ELSE tgt_desc + '' + src_desc END
ELSE tgt_desc
END
RETURN source.name AS merged_alias, target.name AS target_name, new_aliases AS updated_aliases
"""
CHECK_COMMUNITY_IS_COMPLETE_WITH_EMBEDDING = """ CHECK_COMMUNITY_IS_COMPLETE_WITH_EMBEDDING = """
MATCH (c:Community {community_id: $community_id, end_user_id: $end_user_id}) MATCH (c:Community {community_id: $community_id, end_user_id: $end_user_id})
RETURN ( RETURN (

View File

@@ -1,9 +1,32 @@
from abc import ABC from abc import ABC
from typing import Optional from enum import Enum
from typing import Any, Optional
from pydantic import BaseModel from pydantic import BaseModel
class StorageType(str, Enum):
"""记忆存储后端类型"""
NEO4J = "neo4j"
RAG = "rag"
class Language(str, Enum): # 没有传递到聚类的celery任务中去任务会回退失败用默认值考虑统一语言问题
"""支持的语言"""
ZH = "zh"
EN = "en"
class MessageItem(BaseModel):
"""单条消息结构"""
role: str
content: str
files: Optional[list[dict]] = None
file_content: Optional[list[Any]] = None
model_config = {"extra": "allow"}
class UserInput(BaseModel): class UserInput(BaseModel):
message: str message: str
history: list[dict] history: list[dict]
@@ -18,6 +41,16 @@ class Write_UserInput(BaseModel):
config_id: Optional[str] = None config_id: Optional[str] = None
class WriteMemoryRequest(BaseModel):
"""write_memory() 的参数封装"""
end_user_id: str
messages: list[MessageItem]
config_id: Optional[Any] = None
storage_type: StorageType = StorageType.NEO4J
user_rag_memory_id: str = ""
language: Language = Language.ZH
class AgentMemory_Long_Term(ABC): class AgentMemory_Long_Term(ABC):
"""长期记忆配置常量""" """长期记忆配置常量"""
STORAGE_NEO4J = "neo4j" STORAGE_NEO4J = "neo4j"
@@ -25,7 +58,7 @@ class AgentMemory_Long_Term(ABC):
STRATEGY_AGGREGATE = "aggregate" STRATEGY_AGGREGATE = "aggregate"
STRATEGY_CHUNK = "chunk" STRATEGY_CHUNK = "chunk"
STRATEGY_TIME = "time" STRATEGY_TIME = "time"
DEFAULT_SCOPE = 6 DEFAULT_SCOPE = 1
TIME_SCOPE = 5 TIME_SCOPE = 5

View File

@@ -36,13 +36,14 @@ from app.core.memory.agent.utils.messages_tools import (
from app.core.memory.agent.utils.type_classifier import status_typle from app.core.memory.agent.utils.type_classifier import status_typle
from app.core.memory.agent.utils.write_tools import write as write_neo4j from app.core.memory.agent.utils.write_tools import write as write_neo4j
from app.core.memory.analytics.hot_memory_tags import get_interest_distribution from app.core.memory.analytics.hot_memory_tags import get_interest_distribution
from app.core.memory.memory_service import MemoryService
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.core.memory.utils.log.audit_logger import audit_logger from app.core.memory.utils.log.audit_logger import audit_logger
from app.db import get_db_context from app.db import get_db_context
from app.models.knowledge_model import Knowledge, KnowledgeType from app.models.knowledge_model import Knowledge, KnowledgeType
from app.repositories.neo4j.neo4j_connector import Neo4jConnector from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.schemas import FileInput from app.schemas import FileInput
from app.schemas.memory_agent_schema import Write_UserInput from app.schemas.memory_agent_schema import Language, MessageItem, StorageType, Write_UserInput, WriteMemoryRequest
from app.schemas.memory_config_schema import ConfigurationError from app.schemas.memory_config_schema import ConfigurationError
from app.services.memory_config_service import MemoryConfigService from app.services.memory_config_service import MemoryConfigService
from app.services.memory_konwledges_server import ( from app.services.memory_konwledges_server import (
@@ -267,25 +268,15 @@ class MemoryAgentService:
async def write_memory( async def write_memory(
self, self,
end_user_id: str, request: WriteMemoryRequest,
messages: list[dict],
config_id: Optional[uuid.UUID] | int,
db: Session, db: Session,
storage_type: str,
user_rag_memory_id: str,
language: str = "zh"
) -> str: ) -> str:
""" """
Process write operation with config_id 长期记忆写入
Args: Args:
end_user_id: Group identifier (also used as end_user_id) request: 写入请求参数end_user_id、messages、config_id、storage_type、language 等)
messages: Message to write
config_id: Configuration ID from database
db: SQLAlchemy database session db: SQLAlchemy database session
storage_type: Storage type (neo4j or rag)
user_rag_memory_id: User RAG memory ID
language: 语言类型 ("zh" 中文, "en" 英文)
Returns: Returns:
Write operation result status Write operation result status
@@ -293,147 +284,50 @@ class MemoryAgentService:
Raises: Raises:
ValueError: If config loading fails or write operation fails ValueError: If config loading fails or write operation fails
""" """
# Resolve config_id and workspace_id end_user_id = request.end_user_id
# Always get workspace_id from end_user for fallback, even if config_id is provided messages = request.messages
workspace_id = None config_id = request.config_id
try: storage_type = request.storage_type
connected_config = get_end_user_connected_config(end_user_id, db) user_rag_memory_id = request.user_rag_memory_id
workspace_id = connected_config.get("workspace_id") language = request.language
if config_id is None:
config_id = connected_config.get("memory_config_id")
logger.info(f"Resolved config from end_user: config_id={config_id}, workspace_id={workspace_id}")
if config_id is None and workspace_id is None:
raise ValueError(f"No memory configuration found for end_user {end_user_id}. "
f"Please ensure the user has a connected memory configuration.")
except Exception as e:
if "No memory configuration found" in str(e):
raise # Re-raise our specific error
logger.error(f"Failed to get connected config for end_user {end_user_id}: {e}")
if config_id is None:
raise ValueError(f"Unable to determine memory configuration for end_user {end_user_id}: {e}")
# If config_id was provided, continue without workspace_id fallback
import time
start_time = time.time() start_time = time.time()
# Load configuration from database with workspace fallback # ── Step 1: 解析配置 ── 通过 end_user_id 查找关联的 config_id / workspace_id并从数据库加载完整 memory_config
# Use a separate database session to avoid transaction failures memory_config = await self._resolve_and_load_config(
end_user_id, config_id, db, start_time
)
# ── Step 2: 文件预处理 ── 将消息中附带的文件转换为感知记忆对象,挂载到 message["file_content"]
messages = await self._preprocess_files(messages, end_user_id, memory_config, db)
message_text = "\n".join([
f"{(msg['role'] if isinstance(msg, dict) else msg.role)}: {(msg['content'] if isinstance(msg, dict) else msg.content)}"
for msg in messages
])
# ── Step 3: 写入存储 ── 根据 storage_type 分流到 RAG 或 Neo4j 流水线
try: try:
from app.db import get_db_context if storage_type == StorageType.RAG:
with get_db_context() as config_db:
config_service = MemoryConfigService(config_db)
memory_config = config_service.load_memory_config(
config_id=config_id,
workspace_id=workspace_id,
service_name="MemoryAgentService"
)
logger.info(f"Configuration loaded successfully: {memory_config.config_name}")
except ConfigurationError as e:
error_msg = f"Failed to load configuration for config_id: {config_id}: {e}"
logger.error(error_msg)
# Log failed operation
duration = time.time() - start_time
audit_logger.log_operation(operation="WRITE", config_id=config_id, end_user_id=end_user_id,
success=False, duration=duration, error=error_msg)
raise ValueError(error_msg)
perceptual_serivce = MemoryPerceptualService(db)
for message in messages:
message["file_content"] = []
for file in (message.get("files") or []):
file_object = await perceptual_serivce.generate_perceptual_memory(
end_user_id=end_user_id,
memory_config=memory_config,
file=FileInput(**file)
)
if file_object is None:
continue
message["file_content"].append((file_object, file["type"]))
logger.info(messages)
message_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages])
try:
if storage_type == "rag":
# For RAG storage, convert messages to single string
await write_rag(end_user_id, message_text, user_rag_memory_id) await write_rag(end_user_id, message_text, user_rag_memory_id)
return "success" return "success"
else: else:
# TODO 乐力齐 重构流水线切换至生产环境后,更改如下代码 await self._write_neo4j(end_user_id, messages, memory_config, language)
import os
use_new_pipeline = os.getenv("NEW_PIPELINE_ENABLED", "false").lower() == "true"
if use_new_pipeline: # ── Step 4: 后处理 ── 失效缓存、序列化文件路径、记录审计日志并返回结果
# ── 新流水线WritePipeline + NewExtractionOrchestrator ── await self._invalidate_interest_cache(end_user_id)
from app.core.memory.memory_service import MemoryService
service = MemoryService(
memory_config=memory_config,
end_user_id=end_user_id,
)
result = await service.write(
messages=messages,
language=language,
ref_id='',
is_pilot_run=False,
)
logger.info(
f"[NewPipeline] 完成: status={result.status}, "
f"elapsed={result.elapsed_seconds:.2f}s, "
f"extraction={result.extraction}"
)
else:
# ── 旧流水线write_tools.write() + ExtractionOrchestrator ──
await write_neo4j(
end_user_id=end_user_id,
messages=messages,
memory_config=memory_config,
ref_id='',
language=language
)
# ── 影子运行:新流水线静默执行,只记录日志不影响主流程 ──
if os.getenv("SHADOW_PIPELINE_ENABLED", "false").lower() == "true":
try:
from app.core.memory.memory_service import MemoryService
import copy
shadow_messages = copy.deepcopy(messages)
shadow_service = MemoryService(
memory_config=memory_config,
end_user_id=end_user_id,
)
shadow_result = await shadow_service.write(
messages=shadow_messages,
language=language,
ref_id='',
is_pilot_run=True,
)
logger.info(
f"[Shadow] 新流水线影子运行完成: status={shadow_result.status}, "
f"elapsed={shadow_result.elapsed_seconds:.2f}s, "
f"extraction={shadow_result.extraction}"
)
except Exception as shadow_err:
logger.warning(f"[Shadow] 新流水线影子运行失败(不影响主流程): {shadow_err}")
# ── 影子运行结束 ──
for lang in ["zh", "en"]:
deleted = await InterestMemoryCache.delete_interest_distribution(
end_user_id, lang
)
if deleted:
logger.info(
f"Invalidated interest distribution cache: end_user_id={end_user_id}, language={lang}")
for message in messages: for message in messages:
message["file_content"] = [ if isinstance(message, dict):
perceptual[0].file_path for perceptual in message["file_content"] message["file_content"] = [
] perceptual[0].file_path for perceptual in (message["file_content"] or [])
]
else:
message.file_content = [
perceptual[0].file_path for perceptual in (message.file_content or [])
]
return self.writer_messages_deal( return self.writer_messages_deal(
"success", "success",
start_time, start_time,
end_user_id, end_user_id,
config_id, memory_config.config_id,
message_text, message_text,
{ {
"status": "success", "status": "success",
@@ -443,15 +337,139 @@ class MemoryAgentService:
} }
) )
except Exception as e: except Exception as e:
# Ensure proper error handling and logging
error_msg = f"Write operation failed: {str(e)}" error_msg = f"Write operation failed: {str(e)}"
logger.error(error_msg) logger.error(error_msg)
audit_logger.log_operation(
duration = time.time() - start_time operation="WRITE",
audit_logger.log_operation(operation="WRITE", config_id=config_id, end_user_id=end_user_id, config_id=memory_config.config_id,
success=False, duration=duration, error=error_msg) end_user_id=end_user_id,
success=False,
duration=time.time() - start_time,
error=error_msg,
)
raise ValueError(error_msg) raise ValueError(error_msg)
async def _resolve_and_load_config(
self,
end_user_id: str,
config_id: Optional[uuid.UUID] | int,
db: Session,
start_time: float,
):
"""解析 end_user 关联配置并从数据库加载完整 memory_config。"""
workspace_id = None
try:
connected_config = get_end_user_connected_config(end_user_id, db)
workspace_id = connected_config.get("workspace_id")
if config_id is None:
config_id = connected_config.get("memory_config_id")
logger.info(f"Resolved config from end_user: config_id={config_id}, workspace_id={workspace_id}")
if config_id is None and workspace_id is None:
raise ValueError(
f"No memory configuration found for end_user {end_user_id}. "
f"Please ensure the user has a connected memory configuration."
)
except Exception as e:
if "No memory configuration found" in str(e):
raise
logger.error(f"Failed to get connected config for end_user {end_user_id}: {e}")
if config_id is None:
raise ValueError(f"Unable to determine memory configuration for end_user {end_user_id}: {e}")
try:
with get_db_context() as config_db:
memory_config = MemoryConfigService(config_db).load_memory_config(
config_id=config_id,
workspace_id=workspace_id,
service_name="MemoryAgentService",
)
logger.info(f"Configuration loaded successfully: {memory_config.config_name}")
return memory_config
except ConfigurationError as e:
error_msg = f"Failed to load configuration for config_id: {config_id}: {e}"
logger.error(error_msg)
audit_logger.log_operation(
operation="WRITE",
config_id=config_id,
end_user_id=end_user_id,
success=False,
duration=time.time() - start_time,
error=error_msg,
)
raise ValueError(error_msg)
async def _preprocess_files(
self,
messages: list[MessageItem] | list[dict],
end_user_id: str,
memory_config,
db: Session,
) -> list[dict]:
"""处理消息中附带的文件,生成感知记忆对象并挂载到 message['file_content']。"""
perceptual_service = MemoryPerceptualService(db)
for message in messages:
if isinstance(message, dict):
message["file_content"] = []
files = message.get("files") or []
else:
message.file_content = []
files = message.files or []
for file in files:
file_object = await perceptual_service.generate_perceptual_memory(
end_user_id=end_user_id,
memory_config=memory_config,
file=FileInput(**file),
)
if file_object is None:
continue
if isinstance(message, dict):
message["file_content"].append((file_object, file["type"]))
else:
message.file_content.append((file_object, file["type"]))
logger.info(messages)
return messages
async def _write_neo4j(
self,
end_user_id: str,
messages: list[MessageItem] | list[dict],
memory_config,
language: Language | str,
) -> None:
"""根据 NEW_PIPELINE_ENABLED 选择新旧流水线写入 Neo4j。"""
# 统一转换为 dict下游流水线期望 list[dict]
messages_dict = [
msg if isinstance(msg, dict) else msg.model_dump(exclude_none=True)
for msg in messages
]
use_new_pipeline = os.getenv("NEW_PIPELINE_ENABLED", "false").lower() == "true"
if use_new_pipeline:
service = MemoryService(memory_config=memory_config, end_user_id=end_user_id)
result = await service.write(messages=messages_dict, language=language, ref_id='')
logger.info(
f"[NewPipeline] 完成: status={result.status}, "
f"elapsed={result.elapsed_seconds:.2f}s, "
f"extraction={result.extraction}"
)
else:
await write_neo4j(
end_user_id=end_user_id,
messages=messages_dict,
memory_config=memory_config,
ref_id='',
language=language,
)
async def _invalidate_interest_cache(self, end_user_id: str) -> None:
"""写入完成后失效兴趣分布缓存。"""
for lang in ["zh", "en"]:
deleted = await InterestMemoryCache.delete_interest_distribution(end_user_id, lang)
if deleted:
logger.info(
f"Invalidated interest distribution cache: end_user_id={end_user_id}, language={lang}"
)
async def read_memory( async def read_memory(
self, self,
end_user_id: str, end_user_id: str,

View File

@@ -17,6 +17,7 @@ from app.core.logging_config import get_logger
from app.models.app_model import App from app.models.app_model import App
from app.models.end_user_model import EndUser from app.models.end_user_model import EndUser
from app.schemas.memory_config_schema import ConfigurationError from app.schemas.memory_config_schema import ConfigurationError
from app.schemas.memory_agent_schema import WriteMemoryRequest
from app.services.memory_agent_service import MemoryAgentService from app.services.memory_agent_service import MemoryAgentService
logger = get_logger(__name__) logger = get_logger(__name__)
@@ -291,12 +292,14 @@ class MemoryAPIService:
try: try:
messages = message if isinstance(message, list) else [{"role": "user", "content": message}] messages = message if isinstance(message, list) else [{"role": "user", "content": message}]
result = await MemoryAgentService().write_memory( result = await MemoryAgentService().write_memory(
end_user_id=end_user_id, WriteMemoryRequest(
messages=messages, end_user_id=end_user_id,
config_id=config_id, messages=messages,
db=self.db, config_id=config_id,
storage_type=storage_type, storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id or "", user_rag_memory_id=user_rag_memory_id or "",
),
self.db,
) )
logger.info(f"Memory write (sync) successful for end_user: {end_user_id}") logger.info(f"Memory write (sync) successful for end_user: {end_user_id}")

View File

@@ -39,6 +39,7 @@ from app.models import Document, File, Knowledge
from app.models.end_user_model import EndUser from app.models.end_user_model import EndUser
from app.schemas import document_schema, file_schema from app.schemas import document_schema, file_schema
from app.services.memory_agent_service import MemoryAgentService, get_end_user_connected_config from app.services.memory_agent_service import MemoryAgentService, get_end_user_connected_config
from app.schemas.memory_agent_schema import WriteMemoryRequest
from app.services.memory_forget_service import MemoryForgetService from app.services.memory_forget_service import MemoryForgetService
from app.utils.config_utils import resolve_config_id from app.utils.config_utils import resolve_config_id
from app.utils.redis_lock import RedisFairLock from app.utils.redis_lock import RedisFairLock
@@ -1280,8 +1281,17 @@ def write_message_task(
f"[CELERY WRITE] Executing MemoryAgentService.write_memory " f"[CELERY WRITE] Executing MemoryAgentService.write_memory "
f"with config_id={actual_config_id} (type: {type(actual_config_id).__name__}), language={language}") f"with config_id={actual_config_id} (type: {type(actual_config_id).__name__}), language={language}")
service = MemoryAgentService() service = MemoryAgentService()
result = await service.write_memory(end_user_id, message, actual_config_id, db, storage_type, result = await service.write_memory(
user_rag_memory_id, language) WriteMemoryRequest(
end_user_id=end_user_id,
messages=message,
config_id=actual_config_id,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id,
language=language,
),
db,
)
logger.info(f"[CELERY WRITE] Write completed successfully: {result}") logger.info(f"[CELERY WRITE] Write completed successfully: {result}")
return result return result