From 1f0c88a5f0e9fd1267c38128e89feb83948af754 Mon Sep 17 00:00:00 2001 From: lanceyq <1982376970@qq.com> Date: Tue, 28 Apr 2026 16:36:30 +0800 Subject: [PATCH] refactor(memory): consolidate write pipeline and rename statement extraction step MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Rename StatementExtractionStep → StatementTemporalExtractionStep and extract_statement.jinja2 → extract_statement_temporal.jinja2 to reflect merged temporal extraction logic - Move extraction_pipeline_orchestrator.py out of steps/ to engine root - Move dedup_step.py into steps/ directory - Introduce WriteMemoryRequest schema to replace positional args in write_memory() - Extract _resolve_and_load_config, _preprocess_files, _write_neo4j, and _invalidate_interest_cache as private helpers in MemoryAgentService - Remove shadow pipeline and simplify NEW_PIPELINE_ENABLED branch - Merge 类型归属/成员隶属/任职服务 relation types into single 归属身份关系 in triplet prompt - Add alias merge logic (别名属于) in deduplication and MERGE_ALIAS_BELONGS_TO Cypher query - Add StorageType, Language, MessageItem enums/models to memory_agent_schema - Reduce AgentMemory_Long_Term.DEFAULT_SCOPE from 6 to 1 - Delete standalone extract_temporal.jinja2 (logic merged into statement step) --- .../memory/pipelines/pilot_write_pipeline.py | 4 +- .../core/memory/pipelines/write_pipeline.py | 84 ++++- .../deduplication/deduped_and_disamb.py | 33 ++ .../extraction_pipeline_orchestrator.py | 16 +- .../extraction_engine/steps/__init__.py | 4 +- .../{ => steps}/dedup_step.py | 0 .../steps/schema/extraction_step_schema.py | 2 +- ...ent_step.py => statement_temporal_step.py} | 6 +- .../core/memory/utils/prompt/prompt_utils.py | 4 +- ...nja2 => extract_statement_temporal.jinja2} | 0 .../prompt/prompts/extract_temporal.jinja2 | 126 ------- .../prompt/prompts/extract_triplet.jinja2 | 28 +- api/app/repositories/neo4j/cypher_queries.py | 29 ++ api/app/schemas/memory_agent_schema.py | 37 ++- api/app/services/memory_agent_service.py | 314 +++++++++--------- api/app/services/memory_api_service.py | 15 +- api/app/tasks.py | 14 +- 17 files changed, 390 insertions(+), 326 deletions(-) rename api/app/core/memory/storage_services/extraction_engine/{steps => }/extraction_pipeline_orchestrator.py (98%) rename api/app/core/memory/storage_services/extraction_engine/{ => steps}/dedup_step.py (100%) rename api/app/core/memory/storage_services/extraction_engine/steps/{statement_step.py => statement_temporal_step.py} (95%) rename api/app/core/memory/utils/prompt/prompts/{extract_statement.jinja2 => extract_statement_temporal.jinja2} (100%) delete mode 100644 api/app/core/memory/utils/prompt/prompts/extract_temporal.jinja2 diff --git a/api/app/core/memory/pipelines/pilot_write_pipeline.py b/api/app/core/memory/pipelines/pilot_write_pipeline.py index 0465b66e..4c9e1750 100644 --- a/api/app/core/memory/pipelines/pilot_write_pipeline.py +++ b/api/app/core/memory/pipelines/pilot_write_pipeline.py @@ -12,11 +12,11 @@ from typing import Any, Awaitable, Callable, Dict, List, Optional from app.core.memory.models.message_models import DialogData from app.core.memory.models.variate_config import ExtractionPipelineConfig -from app.core.memory.storage_services.extraction_engine.dedup_step import ( +from app.core.memory.storage_services.extraction_engine.steps.dedup_step import ( DedupResult, run_dedup, ) -from app.core.memory.storage_services.extraction_engine.steps.extraction_pipeline_orchestrator import ( +from app.core.memory.storage_services.extraction_engine.extraction_pipeline_orchestrator import ( NewExtractionOrchestrator, ) from app.core.memory.storage_services.extraction_engine.steps.graph_build_step import ( diff --git a/api/app/core/memory/pipelines/write_pipeline.py b/api/app/core/memory/pipelines/write_pipeline.py index 9883f42a..a6927847 100644 --- a/api/app/core/memory/pipelines/write_pipeline.py +++ b/api/app/core/memory/pipelines/write_pipeline.py @@ -227,6 +227,14 @@ class WritePipeline: f"✔ {time.time() - step_start:.2f}s" ) + # Step 3.2: 别名归并 - 处理 predicate="别名属于" 的关系 + step_start = time.time() + await self._merge_alias_belongs_to(extraction_result) + logger.info( + f"[WritePipeline] [3.2] 别名归并:处理别名属于关系 " + f"✔ {time.time() - step_start:.2f}s" + ) + # Step 3.5: 异步情绪提取(fire-and-forget,需在 _store 之后确保 Statement 节点已存在) await self._extract_emotion(getattr(self, "_emotion_statements", [])) @@ -316,13 +324,13 @@ class WritePipeline: 2. build_graph_nodes_and_edges() → 从 DialogData 构建图节点和边 3. run_dedup() → 两阶段去重消歧 """ - from app.core.memory.storage_services.extraction_engine.dedup_step import ( + from app.core.memory.storage_services.extraction_engine.steps.dedup_step import ( run_dedup, ) from app.core.memory.storage_services.extraction_engine.steps.graph_build_step import ( build_graph_nodes_and_edges, ) - from app.core.memory.storage_services.extraction_engine.steps.extraction_pipeline_orchestrator import ( + from app.core.memory.storage_services.extraction_engine.extraction_pipeline_orchestrator import ( NewExtractionOrchestrator, ) @@ -554,6 +562,78 @@ class WritePipeline: else: raise + # ────────────────────────────────────────────── + # Step 3.2: 别名归并 + # ────────────────────────────────────────────── + + async def _merge_alias_belongs_to(self, result: ExtractionResult) -> None: + """ + 所有去重合并都可以使用这个这种统一的处理方式(未实现) + 别名归并:处理 predicate="别名属于" 的 EXTRACTED_RELATIONSHIP 边。 + + 对每条 source -[EXTRACTED_RELATIONSHIP {predicate:"别名属于"}]-> target 边: + - 将 source.name 追加到 target.aliases(去重) + - 将 source.description append 进 target.description_list(list 形式) + + 同时在内存中同步更新 ExtractionResult.entity_nodes,保持内存与 Neo4j 一致。 + 失败不中断主流程。 + """ + from app.repositories.neo4j.cypher_queries import MERGE_ALIAS_BELONGS_TO + + ALIAS_PREDICATE = "别名属于" + + # 筛选出所有 predicate="别名属于" 的边 + alias_edges = [ + e for e in result.entity_entity_edges + if getattr(e, "relation_type", "") == ALIAS_PREDICATE + or getattr(e, "predicate", "") == ALIAS_PREDICATE + ] + + if not alias_edges: + logger.debug("[AliasMerge] 无 '别名属于' 关系,跳过") + return + + try: + # ── 1. 在内存中同步更新 entity_nodes ── + entity_map = {e.id: e for e in result.entity_nodes} + + for edge in alias_edges: + source_node = entity_map.get(edge.source) + target_node = entity_map.get(edge.target) + if not source_node or not target_node: + continue + + # 将 source.name 追加到 target.aliases(去重,忽略大小写) + source_name = (source_node.name or "").strip() + if source_name: + existing_lower = {a.lower() for a in (target_node.aliases or [])} + if source_name.lower() not in existing_lower: + target_node.aliases = list(target_node.aliases or []) + [source_name] + + # 将 source.description append 进 target.description(追加,分号分隔) + src_desc = (source_node.description or "").strip() + if src_desc: + tgt_desc = (target_node.description or "").strip() + if src_desc not in tgt_desc: + target_node.description = f"{tgt_desc};{src_desc}" if tgt_desc else src_desc + + logger.info( + f"[AliasMerge] 内存同步完成,处理 {len(alias_edges)} 条 '别名属于' 边" + ) + + # ── 2. 写入 Neo4j ── + records = await self._neo4j_connector.execute_query( + MERGE_ALIAS_BELONGS_TO, + end_user_id=self.end_user_id, + ) + merged_count = len(records) if records else 0 + logger.info( + f"[AliasMerge] Neo4j 别名归并完成,影响 {merged_count} 条记录" + ) + + except Exception as e: + logger.warning(f"[AliasMerge] 别名归并失败(不影响主流程): {e}", exc_info=True) + # ────────────────────────────────────────────── # Step 4: 聚类 # ────────────────────────────────────────────── diff --git a/api/app/core/memory/storage_services/extraction_engine/deduplication/deduped_and_disamb.py b/api/app/core/memory/storage_services/extraction_engine/deduplication/deduped_and_disamb.py index 715f190c..980f5130 100644 --- a/api/app/core/memory/storage_services/extraction_engine/deduplication/deduped_and_disamb.py +++ b/api/app/core/memory/storage_services/extraction_engine/deduplication/deduped_and_disamb.py @@ -1112,6 +1112,39 @@ async def deduplicate_entities_and_edges( # 在主流程这里 这里是之后关系去重和消歧的地方,方法可以写在其他地方 # 此处统一对边进行处理,使用累积的 id_redirect 把边的 source/target 改成规范ID # 4) 边重定向与去重 + # 4.0 预处理:将 "别名属于" 关系的 source.name/description 归并到 target 节点 + # 必须在边重定向之前执行,此时 id_redirect 已包含精确/模糊/LLM 的合并结果 + try: + entity_by_id: Dict[str, ExtractedEntityNode] = {e.id: e for e in deduped_entities} + for edge in entity_entity_edges: + if getattr(edge, "relation_type", "") != "别名属于": + continue + # 通过 id_redirect 找到合并后的规范节点 + source_id = id_redirect.get(edge.source, edge.source) + target_id = id_redirect.get(edge.target, edge.target) + if source_id == target_id: + continue + source_node = entity_by_id.get(source_id) + target_node = entity_by_id.get(target_id) + if not source_node or not target_node: + continue + + # 将 source.name 追加到 target.aliases(去重,忽略大小写) + source_name = (source_node.name or "").strip() + if source_name: + existing_lower = {a.lower() for a in (target_node.aliases or [])} + if source_name.lower() not in existing_lower and source_name.lower() != (target_node.name or "").lower(): + target_node.aliases = list(target_node.aliases or []) + [source_name] + + # 将 source.description 追加到 target.description(分号分隔,去重) + src_desc = (source_node.description or "").strip() + if src_desc: + tgt_desc = (target_node.description or "").strip() + if src_desc not in tgt_desc: + target_node.description = f"{tgt_desc};{src_desc}" if tgt_desc else src_desc + except Exception: + pass + # 4.1 语句→实体边:重复时优先保留 strong stmt_ent_map: Dict[str, StatementEntityEdge] = {} for edge in statement_entity_edges: diff --git a/api/app/core/memory/storage_services/extraction_engine/steps/extraction_pipeline_orchestrator.py b/api/app/core/memory/storage_services/extraction_engine/extraction_pipeline_orchestrator.py similarity index 98% rename from api/app/core/memory/storage_services/extraction_engine/steps/extraction_pipeline_orchestrator.py rename to api/app/core/memory/storage_services/extraction_engine/extraction_pipeline_orchestrator.py index 72d7901f..5c158083 100644 --- a/api/app/core/memory/storage_services/extraction_engine/steps/extraction_pipeline_orchestrator.py +++ b/api/app/core/memory/storage_services/extraction_engine/extraction_pipeline_orchestrator.py @@ -23,12 +23,12 @@ from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple from app.core.memory.models.message_models import DialogData from app.core.memory.models.variate_config import ExtractionPipelineConfig -from .base import ExtractionStep, StepContext -from .embedding_step import EmbeddingStep -from .sidecar_factory import SidecarStepFactory, SidecarTiming -from .statement_step import StatementExtractionStep -from .triplet_step import TripletExtractionStep -from .schema import ( +from .steps.base import ExtractionStep, StepContext +from .steps.embedding_step import EmbeddingStep +from .steps.sidecar_factory import SidecarStepFactory, SidecarTiming +from .steps.statement_temporal_step import StatementTemporalExtractionStep +from .steps.triplet_step import TripletExtractionStep +from .steps.schema import ( EmbeddingStepInput, EmbeddingStepOutput, EmotionStepInput, @@ -85,7 +85,7 @@ class NewExtractionOrchestrator: ) # ── Critical (main-line) steps ── - self.statement_step = StatementExtractionStep(self.context) + self.statement_temporal_step = StatementTemporalExtractionStep(self.context) self.triplet_step = TripletExtractionStep( self.context, ontology_types=ontology_types ) @@ -508,7 +508,7 @@ class NewExtractionOrchestrator: ), supporting_context=ctx, ) - tasks.append(self.statement_step.run(inp)) + tasks.append(self.statement_temporal_step.run(inp)) task_meta.append( (dialog.id, chunk.id, chunk_speaker, ctx) ) diff --git a/api/app/core/memory/storage_services/extraction_engine/steps/__init__.py b/api/app/core/memory/storage_services/extraction_engine/steps/__init__.py index 63a8ec77..cbd7b742 100644 --- a/api/app/core/memory/storage_services/extraction_engine/steps/__init__.py +++ b/api/app/core/memory/storage_services/extraction_engine/steps/__init__.py @@ -7,10 +7,10 @@ for all sidecar (non-critical) steps via SidecarStepFactory. from .sidecar_factory import SidecarStepFactory, SidecarTiming # noqa: F401 # Step implementations — importing triggers @register self-registration. -from .statement_step import StatementExtractionStep # noqa: F401 +from .statement_temporal_step import StatementTemporalExtractionStep # noqa: F401 from .triplet_step import TripletExtractionStep # noqa: F401 from .emotion_step import EmotionExtractionStep # noqa: F401 from .embedding_step import EmbeddingStep # noqa: F401 # Refactored orchestrator -from .extraction_pipeline_orchestrator import NewExtractionOrchestrator # noqa: F401 +from app.core.memory.storage_services.extraction_engine.extraction_pipeline_orchestrator import NewExtractionOrchestrator # noqa: F401 diff --git a/api/app/core/memory/storage_services/extraction_engine/dedup_step.py b/api/app/core/memory/storage_services/extraction_engine/steps/dedup_step.py similarity index 100% rename from api/app/core/memory/storage_services/extraction_engine/dedup_step.py rename to api/app/core/memory/storage_services/extraction_engine/steps/dedup_step.py diff --git a/api/app/core/memory/storage_services/extraction_engine/steps/schema/extraction_step_schema.py b/api/app/core/memory/storage_services/extraction_engine/steps/schema/extraction_step_schema.py index eacab0b6..498dec54 100644 --- a/api/app/core/memory/storage_services/extraction_engine/steps/schema/extraction_step_schema.py +++ b/api/app/core/memory/storage_services/extraction_engine/steps/schema/extraction_step_schema.py @@ -28,7 +28,7 @@ class SupportingContext(BaseModel): # ── Statement extraction ── class StatementStepInput(BaseModel): - """Input for StatementExtractionStep.""" + """Input for StatementTemporalExtractionStep.""" chunk_id: str end_user_id: str diff --git a/api/app/core/memory/storage_services/extraction_engine/steps/statement_step.py b/api/app/core/memory/storage_services/extraction_engine/steps/statement_temporal_step.py similarity index 95% rename from api/app/core/memory/storage_services/extraction_engine/steps/statement_step.py rename to api/app/core/memory/storage_services/extraction_engine/steps/statement_temporal_step.py index 25f13e24..7c0e3a48 100644 --- a/api/app/core/memory/storage_services/extraction_engine/steps/statement_step.py +++ b/api/app/core/memory/storage_services/extraction_engine/steps/statement_temporal_step.py @@ -1,4 +1,4 @@ -"""StatementExtractionStep — critical step for extracting statements from chunks. +"""StatementTemporalExtractionStep — critical step for extracting statements and temporal info from chunks. Replaces the legacy ``StatementExtractor`` with the unified ExtractionStep paradigm. Temporal extraction logic (valid_at / invalid_at) is merged into this step, @@ -62,8 +62,8 @@ class _StatementExtractionResponse(BaseModel): return v -class StatementExtractionStep(ExtractionStep[StatementStepInput, List[StatementStepOutput]]): - """Extract atomic statements (with temporal info) from a dialogue chunk. +class StatementTemporalExtractionStep(ExtractionStep[StatementStepInput, List[StatementStepOutput]]): + """Extract atomic statements with temporal info (valid_at / invalid_at) from a dialogue chunk. This is a **critical** step — failure aborts the pipeline after retries. diff --git a/api/app/core/memory/utils/prompt/prompt_utils.py b/api/app/core/memory/utils/prompt/prompt_utils.py index bcc11c0d..43926d83 100644 --- a/api/app/core/memory/utils/prompt/prompt_utils.py +++ b/api/app/core/memory/utils/prompt/prompt_utils.py @@ -65,7 +65,7 @@ async def render_statement_extraction_prompt( Returns: Rendered prompt content as string """ - template = prompt_env.get_template("extract_statement.jinja2") + template = prompt_env.get_template("extract_statement_temporal.jinja2") # Optional clipping of dialogue context ctx = None if include_dialogue_context and dialogue_content: @@ -90,7 +90,7 @@ async def render_statement_extraction_prompt( # 记录渲染结果到提示日志(与示例日志结构一致) log_prompt_rendering('statement extraction', rendered_prompt) # 可选:记录模板渲染信息 - log_template_rendering('extract_statement.jinja2', { + log_template_rendering('extract_statement_temporal.jinja2', { 'inputs': 'chunk', 'definitions': 'LABEL_DEFINITIONS', 'json_schema': 'StatementExtractionResponse.schema', diff --git a/api/app/core/memory/utils/prompt/prompts/extract_statement.jinja2 b/api/app/core/memory/utils/prompt/prompts/extract_statement_temporal.jinja2 similarity index 100% rename from api/app/core/memory/utils/prompt/prompts/extract_statement.jinja2 rename to api/app/core/memory/utils/prompt/prompts/extract_statement_temporal.jinja2 diff --git a/api/app/core/memory/utils/prompt/prompts/extract_temporal.jinja2 b/api/app/core/memory/utils/prompt/prompts/extract_temporal.jinja2 deleted file mode 100644 index 00a0374d..00000000 --- a/api/app/core/memory/utils/prompt/prompts/extract_temporal.jinja2 +++ /dev/null @@ -1,126 +0,0 @@ - -{% macro tidy(name) -%} - {{ name.replace('_', ' ')}} -{%- endmacro %} -{# - This prompt (template) is adapted from [getzep/graphiti] - Licensed under the Apache License, Version 2.0 - - Original work: - https://github.com/getzep/graphiti/blob/main/graphiti_core/prompts/extract_edge_dates.py - - Modifications made by Ke Sun on 2025-09-01 - See the LICENSE file for the full Apache 2.0 license text. -#} -# Task - -{% if language == "zh" %} -从提供的陈述句中提取时间信息(日期和时间范围)。确定所描述的关系或事件何时生效以及何时结束(如果适用)。 -{% else %} -Extract temporal information (dates and time ranges) from the provided statement. Determine when the relationship or event described became valid and when it ended (if applicable). -{% endif %} - -# {% if language == "zh" %}输入数据{% else %}Input Data{% endif %} -{% if inputs %} -{% for key, val in inputs.items() %} -- {{ key }}: {{val}} -{% endfor %} -{% endif %} - -# {% if language == "zh" %}时间字段{% else %}Temporal Fields{% endif %} - -{% if language == "zh" %} -- **valid_at**: 关系/事件开始或成为真实的时间(ISO 8601 格式) -- **invalid_at**: 关系/事件结束或停止为真的时间(ISO 8601 格式,如果正在进行则为 null) -{% else %} -- **valid_at**: When the relationship/event started or became true (ISO 8601 format) -- **invalid_at**: When the relationship/event ended or stopped being true (ISO 8601 format, or null if ongoing) -{% endif %} - -# {% if language == "zh" %}提取规则{% else %}Extraction Rules{% endif %} - -## {% if language == "zh" %}核心原则{% else %}Core Principles{% endif %} -{% if language == "zh" %} -1. **仅使用明确陈述的时间信息** - 不要从外部知识推断日期 -2. **使用参考/发布日期作为"现在"** 解释相对时间时 -3. **仅在日期与关系的有效性相关时设置日期** - 忽略偶然的时间提及 -4. **对于时间点事件**,仅设置 `valid_at` -{% else %} -1. **Only use explicitly stated temporal information** - do not infer dates from external knowledge -2. **Use the reference/publication date as "now"** when interpreting relative times -3. **Set dates only if they relate to the validity of the relationship** - ignore incidental time mentions -4. **For point-in-time events**, set only `valid_at` -{% endif %} - -## {% if language == "zh" %}日期格式要求{% else %}Date Format Requirements{% endif %} -{% if language == "zh" %} -- 使用 ISO 8601: `YYYY-MM-DDTHH:MM:SS.SSSSSSZ` -- 如果未指定时间,使用 `00:00:00`(午夜) -- 如果仅提及年份,根据情况使用 `YYYY-01-01`(开始)或 `YYYY-12-31`(结束) -- 如果仅提及月份,使用月份的第一天或最后一天 -- 始终包含时区(如果未指定,使用 `Z` 表示 UTC) -- 根据参考日期将相对时间("两周前"、"去年")转换为绝对日期 -{% else %} -- Use ISO 8601: `YYYY-MM-DDTHH:MM:SS.SSSSSSZ` -- If no time specified, use `00:00:00` (midnight) -- If only year mentioned, use `YYYY-01-01` (start) or `YYYY-12-31` (end) as appropriate -- If only month mentioned, use first or last day of month -- Always include timezone (use `Z` for UTC if unspecified) -- Convert relative times ("two weeks ago", "last year") to absolute dates based on reference date -{% endif %} - -## {% if language == "zh" %}陈述句类型规则{% else %}Statement Type Rules{% endif %} - -{{ inputs.get("statement_type") | upper }} {% if language == "zh" %}陈述句指导{% else %}Statement Guidance{% endif %}: -{%for key, guide in statement_guide.items() %} -- {{ tidy(key) | capitalize }}: {{ guide }} -{% endfor %} - -**{% if language == "zh" %}特殊情况{% else %}Special Cases{% endif %}:** -{% if language == "zh" %} -- **意见陈述句**: 仅设置 `valid_at`(意见表达的时间) -- **预测陈述句**: 如果明确提及,将 `invalid_at` 设置为预测窗口的结束 -{% else %} -- **Opinion statements**: Set only `valid_at` (when opinion was expressed) -- **Prediction statements**: Set `invalid_at` to the end of the prediction window if explicitly mentioned -{% endif %} - -## {% if language == "zh" %}时间类型规则{% else %}Temporal Type Rules{% endif %} - -{{ inputs.get("temporal_type") | upper }} {% if language == "zh" %}时间类型指导{% else %}Temporal Type Guidance{% endif %}: -{% for key, guide in temporal_guide.items() %} -- {{ tidy(key) | capitalize }}: {{ guide }} -{% endfor %} - -{% if inputs.get('quarter') and inputs.get('publication_date') %} -## {% if language == "zh" %}季度参考{% else %}Quarter Reference{% endif %} -{% if language == "zh" %} -假设 {{ inputs.quarter }} 在 {{ inputs.publication_date }} 结束。从此基线计算任何季度引用(Q1、Q2 等)的日期。 -{% else %} -Assume {{ inputs.quarter }} ends on {{ inputs.publication_date }}. Calculate dates for any quarter references (Q1, Q2, etc.) from this baseline. -{% endif %} -{% endif %} - -# {% if language == "zh" %}输出要求{% else %}Output Requirements{% endif %} - -## {% if language == "zh" %}JSON 格式化(关键){% else %}JSON Formatting (CRITICAL){% endif %} -{% if language == "zh" %} -1. 使用**仅标准 ASCII 双引号** (") - 永远不要使用中文引号("")或其他 Unicode 变体 -2. 使用反斜杠转义内部引号: `\"` -3. JSON 字符串值中不要有换行符 -4. 正确关闭并用逗号分隔所有字段 -{% else %} -1. Use **only standard ASCII double quotes** (") - never use Chinese quotes ("") or other Unicode variants -2. Escape internal quotes with backslash: `\"` -3. No line breaks within JSON string values -4. Properly close and comma-separate all fields -{% endif %} - -## {% if language == "zh" %}语言{% else %}Language{% endif %} -{% if language == "zh" %} -输出语言必须与输入语言匹配。 -{% else %} -Output language must match input language. -{% endif %} - -{{ json_schema }} diff --git a/api/app/core/memory/utils/prompt/prompts/extract_triplet.jinja2 b/api/app/core/memory/utils/prompt/prompts/extract_triplet.jinja2 index bc1cf7ac..fa868104 100644 --- a/api/app/core/memory/utils/prompt/prompts/extract_triplet.jinja2 +++ b/api/app/core/memory/utils/prompt/prompts/extract_triplet.jinja2 @@ -250,28 +250,12 @@ Primary statement to analyze: - notes: 只处理名字性表达,不处理角色、职业、评价词。 - status: `enabled` -- `类型归属关系` - - definition: 表达实体属于某种类别,或主体承担某种角色/职业身份的关系。 - - covered_predicates: `属于类型`、`担任角色`、`从事职业` - - positive_examples: `王教授 -> 担任角色 -> 导师`、`张三 -> 从事职业 -> 程序员` - - negative_examples: `张三 -> 担任角色 -> 山哥`、`用户 -> 从事职业 -> 紧张` - - notes: 用于“是什么”,不用于“叫什么”。 - - status: `enabled` - -- `成员隶属关系` - - definition: 表达主体属于某个组织、群体或集合的成员归属关系。 - - covered_predicates: `成员属于` - - positive_examples: `张三 -> 成员属于 -> 实验室成员`、`用户 -> 成员属于 -> 社群` - - negative_examples: `他们 -> 成员属于 -> 学校`、`一个朋友 -> 成员属于 -> 班级` - - notes: 前提是主体和归属对象都足够稳定;边界不稳的人群不要硬抽。 - - status: `enabled` - -- `任职服务关系` - - definition: 表达人物或主体在组织中的工作、任职或服务关系。 - - covered_predicates: `任职于` - - positive_examples: `张明 -> 任职于 -> 腾讯`、`王教授 -> 任职于 -> 清华大学` - - negative_examples: `张明 -> 任职于 -> 导师`、`用户 -> 任职于 -> 明天的面试` - - notes: 优先用于人物到组织的稳定供职关系。 +- `归属身份关系` + - definition: 表达主体所属的类别、身份、职业、角色,或其与组织、群体、集合之间的归属关系。 + - covered_predicates: `属于类型`、`担任角色`、`从事职业`、`成员属于`、`任职于` + - positive_examples: `王教授 -> 担任角色 -> 导师`、`张三 -> 从事职业 -> 程序员`、`张三 -> 成员属于 -> 实验室成员`、`张明 -> 任职于 -> 腾讯` + - negative_examples: `张三 -> 担任角色 -> 山哥`、`他们 -> 成员属于 -> 学校`、`用户 -> 任职于 -> 明天的面试`、`用户 -> 从事职业 -> 紧张` + - notes: 这是一个上位父类,用于统一承接“是什么身份”与“归属哪里”两类关系。第一层不再强行区分“身份类归属”和“组织类归属”,真正的区分在子类 predicate 层完成。 - status: `enabled` - `空间位置关系` diff --git a/api/app/repositories/neo4j/cypher_queries.py b/api/app/repositories/neo4j/cypher_queries.py index b0d18482..03d7ed7a 100644 --- a/api/app/repositories/neo4j/cypher_queries.py +++ b/api/app/repositories/neo4j/cypher_queries.py @@ -1107,6 +1107,35 @@ RETURN ( ) AS is_complete """ +# 别名归并:将 predicate="别名属于" 的 EXTRACTED_RELATIONSHIP 边的 source.name +# 合并进 target.aliases(去重),并将 source.description 追加到 target.description(分号分隔) +MERGE_ALIAS_BELONGS_TO = """ +MATCH (source:ExtractedEntity {end_user_id: $end_user_id})-[r:EXTRACTED_RELATIONSHIP]->(target:ExtractedEntity {end_user_id: $end_user_id}) +WHERE r.predicate = '别名属于' +WITH source, target, + coalesce(target.aliases, []) AS existing_aliases, + source.name AS source_name, + coalesce(source.description, '') AS src_desc, + coalesce(target.description, '') AS tgt_desc + +// 1. 合并 aliases:将 source.name 追加到 target.aliases(去重) +WITH source, target, src_desc, tgt_desc, + CASE + WHEN source_name IS NOT NULL AND source_name <> '' AND NOT source_name IN existing_aliases + THEN existing_aliases + source_name + ELSE existing_aliases + END AS new_aliases + +SET target.aliases = new_aliases, + target.description = CASE + WHEN src_desc <> '' AND NOT src_desc IN tgt_desc + THEN CASE WHEN tgt_desc = '' THEN src_desc ELSE tgt_desc + ';' + src_desc END + ELSE tgt_desc + END + +RETURN source.name AS merged_alias, target.name AS target_name, new_aliases AS updated_aliases +""" + CHECK_COMMUNITY_IS_COMPLETE_WITH_EMBEDDING = """ MATCH (c:Community {community_id: $community_id, end_user_id: $end_user_id}) RETURN ( diff --git a/api/app/schemas/memory_agent_schema.py b/api/app/schemas/memory_agent_schema.py index 97aa5bb5..3f77f219 100644 --- a/api/app/schemas/memory_agent_schema.py +++ b/api/app/schemas/memory_agent_schema.py @@ -1,9 +1,32 @@ from abc import ABC -from typing import Optional +from enum import Enum +from typing import Any, Optional from pydantic import BaseModel +class StorageType(str, Enum): + """记忆存储后端类型""" + NEO4J = "neo4j" + RAG = "rag" + + +class Language(str, Enum): # 没有传递到聚类的celery任务中去,任务会回退失败用默认值,考虑统一语言问题 + """支持的语言""" + ZH = "zh" + EN = "en" + + +class MessageItem(BaseModel): + """单条消息结构""" + role: str + content: str + files: Optional[list[dict]] = None + file_content: Optional[list[Any]] = None + + model_config = {"extra": "allow"} + + class UserInput(BaseModel): message: str history: list[dict] @@ -18,6 +41,16 @@ class Write_UserInput(BaseModel): config_id: Optional[str] = None +class WriteMemoryRequest(BaseModel): + """write_memory() 的参数封装""" + end_user_id: str + messages: list[MessageItem] + config_id: Optional[Any] = None + storage_type: StorageType = StorageType.NEO4J + user_rag_memory_id: str = "" + language: Language = Language.ZH + + class AgentMemory_Long_Term(ABC): """长期记忆配置常量""" STORAGE_NEO4J = "neo4j" @@ -25,7 +58,7 @@ class AgentMemory_Long_Term(ABC): STRATEGY_AGGREGATE = "aggregate" STRATEGY_CHUNK = "chunk" STRATEGY_TIME = "time" - DEFAULT_SCOPE = 6 + DEFAULT_SCOPE = 1 TIME_SCOPE = 5 diff --git a/api/app/services/memory_agent_service.py b/api/app/services/memory_agent_service.py index 9f4875ed..ca207933 100644 --- a/api/app/services/memory_agent_service.py +++ b/api/app/services/memory_agent_service.py @@ -36,13 +36,14 @@ from app.core.memory.agent.utils.messages_tools import ( from app.core.memory.agent.utils.type_classifier import status_typle from app.core.memory.agent.utils.write_tools import write as write_neo4j from app.core.memory.analytics.hot_memory_tags import get_interest_distribution +from app.core.memory.memory_service import MemoryService from app.core.memory.utils.llm.llm_utils import MemoryClientFactory from app.core.memory.utils.log.audit_logger import audit_logger from app.db import get_db_context from app.models.knowledge_model import Knowledge, KnowledgeType from app.repositories.neo4j.neo4j_connector import Neo4jConnector from app.schemas import FileInput -from app.schemas.memory_agent_schema import Write_UserInput +from app.schemas.memory_agent_schema import Language, MessageItem, StorageType, Write_UserInput, WriteMemoryRequest from app.schemas.memory_config_schema import ConfigurationError from app.services.memory_config_service import MemoryConfigService from app.services.memory_konwledges_server import ( @@ -267,25 +268,15 @@ class MemoryAgentService: async def write_memory( self, - end_user_id: str, - messages: list[dict], - config_id: Optional[uuid.UUID] | int, + request: WriteMemoryRequest, db: Session, - storage_type: str, - user_rag_memory_id: str, - language: str = "zh" ) -> str: """ - Process write operation with config_id + 长期记忆写入 Args: - end_user_id: Group identifier (also used as end_user_id) - messages: Message to write - config_id: Configuration ID from database + request: 写入请求参数(end_user_id、messages、config_id、storage_type、language 等) db: SQLAlchemy database session - storage_type: Storage type (neo4j or rag) - user_rag_memory_id: User RAG memory ID - language: 语言类型 ("zh" 中文, "en" 英文) Returns: Write operation result status @@ -293,147 +284,50 @@ class MemoryAgentService: Raises: ValueError: If config loading fails or write operation fails """ - # Resolve config_id and workspace_id - # Always get workspace_id from end_user for fallback, even if config_id is provided - workspace_id = None - try: - connected_config = get_end_user_connected_config(end_user_id, db) - workspace_id = connected_config.get("workspace_id") - if config_id is None: - config_id = connected_config.get("memory_config_id") - logger.info(f"Resolved config from end_user: config_id={config_id}, workspace_id={workspace_id}") - if config_id is None and workspace_id is None: - raise ValueError(f"No memory configuration found for end_user {end_user_id}. " - f"Please ensure the user has a connected memory configuration.") - except Exception as e: - if "No memory configuration found" in str(e): - raise # Re-raise our specific error - logger.error(f"Failed to get connected config for end_user {end_user_id}: {e}") - if config_id is None: - raise ValueError(f"Unable to determine memory configuration for end_user {end_user_id}: {e}") - # If config_id was provided, continue without workspace_id fallback - - import time + end_user_id = request.end_user_id + messages = request.messages + config_id = request.config_id + storage_type = request.storage_type + user_rag_memory_id = request.user_rag_memory_id + language = request.language start_time = time.time() - # Load configuration from database with workspace fallback - # Use a separate database session to avoid transaction failures + # ── Step 1: 解析配置 ── 通过 end_user_id 查找关联的 config_id / workspace_id,并从数据库加载完整 memory_config + memory_config = await self._resolve_and_load_config( + end_user_id, config_id, db, start_time + ) + + # ── Step 2: 文件预处理 ── 将消息中附带的文件转换为感知记忆对象,挂载到 message["file_content"] + messages = await self._preprocess_files(messages, end_user_id, memory_config, db) + message_text = "\n".join([ + f"{(msg['role'] if isinstance(msg, dict) else msg.role)}: {(msg['content'] if isinstance(msg, dict) else msg.content)}" + for msg in messages + ]) + + # ── Step 3: 写入存储 ── 根据 storage_type 分流到 RAG 或 Neo4j 流水线 try: - from app.db import get_db_context - with get_db_context() as config_db: - config_service = MemoryConfigService(config_db) - memory_config = config_service.load_memory_config( - config_id=config_id, - workspace_id=workspace_id, - service_name="MemoryAgentService" - ) - logger.info(f"Configuration loaded successfully: {memory_config.config_name}") - except ConfigurationError as e: - error_msg = f"Failed to load configuration for config_id: {config_id}: {e}" - logger.error(error_msg) - - # Log failed operation - duration = time.time() - start_time - audit_logger.log_operation(operation="WRITE", config_id=config_id, end_user_id=end_user_id, - success=False, duration=duration, error=error_msg) - - raise ValueError(error_msg) - - perceptual_serivce = MemoryPerceptualService(db) - for message in messages: - message["file_content"] = [] - for file in (message.get("files") or []): - file_object = await perceptual_serivce.generate_perceptual_memory( - end_user_id=end_user_id, - memory_config=memory_config, - file=FileInput(**file) - ) - if file_object is None: - continue - message["file_content"].append((file_object, file["type"])) - logger.info(messages) - - message_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages]) - try: - if storage_type == "rag": - # For RAG storage, convert messages to single string + if storage_type == StorageType.RAG: await write_rag(end_user_id, message_text, user_rag_memory_id) return "success" else: - # TODO 乐力齐 重构流水线切换至生产环境后,更改如下代码 - import os - use_new_pipeline = os.getenv("NEW_PIPELINE_ENABLED", "false").lower() == "true" + await self._write_neo4j(end_user_id, messages, memory_config, language) - if use_new_pipeline: - # ── 新流水线:WritePipeline + NewExtractionOrchestrator ── - from app.core.memory.memory_service import MemoryService - - service = MemoryService( - memory_config=memory_config, - end_user_id=end_user_id, - ) - result = await service.write( - messages=messages, - language=language, - ref_id='', - is_pilot_run=False, - ) - logger.info( - f"[NewPipeline] 完成: status={result.status}, " - f"elapsed={result.elapsed_seconds:.2f}s, " - f"extraction={result.extraction}" - ) - else: - # ── 旧流水线:write_tools.write() + ExtractionOrchestrator ── - await write_neo4j( - end_user_id=end_user_id, - messages=messages, - memory_config=memory_config, - ref_id='', - language=language - ) - - # ── 影子运行:新流水线静默执行,只记录日志不影响主流程 ── - if os.getenv("SHADOW_PIPELINE_ENABLED", "false").lower() == "true": - try: - from app.core.memory.memory_service import MemoryService - import copy - - shadow_messages = copy.deepcopy(messages) - shadow_service = MemoryService( - memory_config=memory_config, - end_user_id=end_user_id, - ) - shadow_result = await shadow_service.write( - messages=shadow_messages, - language=language, - ref_id='', - is_pilot_run=True, - ) - logger.info( - f"[Shadow] 新流水线影子运行完成: status={shadow_result.status}, " - f"elapsed={shadow_result.elapsed_seconds:.2f}s, " - f"extraction={shadow_result.extraction}" - ) - except Exception as shadow_err: - logger.warning(f"[Shadow] 新流水线影子运行失败(不影响主流程): {shadow_err}") - # ── 影子运行结束 ── - for lang in ["zh", "en"]: - deleted = await InterestMemoryCache.delete_interest_distribution( - end_user_id, lang - ) - if deleted: - logger.info( - f"Invalidated interest distribution cache: end_user_id={end_user_id}, language={lang}") + # ── Step 4: 后处理 ── 失效缓存、序列化文件路径、记录审计日志并返回结果 + await self._invalidate_interest_cache(end_user_id) for message in messages: - message["file_content"] = [ - perceptual[0].file_path for perceptual in message["file_content"] - ] + if isinstance(message, dict): + message["file_content"] = [ + perceptual[0].file_path for perceptual in (message["file_content"] or []) + ] + else: + message.file_content = [ + perceptual[0].file_path for perceptual in (message.file_content or []) + ] return self.writer_messages_deal( "success", start_time, end_user_id, - config_id, + memory_config.config_id, message_text, { "status": "success", @@ -443,15 +337,139 @@ class MemoryAgentService: } ) except Exception as e: - # Ensure proper error handling and logging error_msg = f"Write operation failed: {str(e)}" logger.error(error_msg) - - duration = time.time() - start_time - audit_logger.log_operation(operation="WRITE", config_id=config_id, end_user_id=end_user_id, - success=False, duration=duration, error=error_msg) + audit_logger.log_operation( + operation="WRITE", + config_id=memory_config.config_id, + end_user_id=end_user_id, + success=False, + duration=time.time() - start_time, + error=error_msg, + ) raise ValueError(error_msg) + async def _resolve_and_load_config( + self, + end_user_id: str, + config_id: Optional[uuid.UUID] | int, + db: Session, + start_time: float, + ): + """解析 end_user 关联配置并从数据库加载完整 memory_config。""" + workspace_id = None + try: + connected_config = get_end_user_connected_config(end_user_id, db) + workspace_id = connected_config.get("workspace_id") + if config_id is None: + config_id = connected_config.get("memory_config_id") + logger.info(f"Resolved config from end_user: config_id={config_id}, workspace_id={workspace_id}") + if config_id is None and workspace_id is None: + raise ValueError( + f"No memory configuration found for end_user {end_user_id}. " + f"Please ensure the user has a connected memory configuration." + ) + except Exception as e: + if "No memory configuration found" in str(e): + raise + logger.error(f"Failed to get connected config for end_user {end_user_id}: {e}") + if config_id is None: + raise ValueError(f"Unable to determine memory configuration for end_user {end_user_id}: {e}") + + try: + with get_db_context() as config_db: + memory_config = MemoryConfigService(config_db).load_memory_config( + config_id=config_id, + workspace_id=workspace_id, + service_name="MemoryAgentService", + ) + logger.info(f"Configuration loaded successfully: {memory_config.config_name}") + return memory_config + except ConfigurationError as e: + error_msg = f"Failed to load configuration for config_id: {config_id}: {e}" + logger.error(error_msg) + audit_logger.log_operation( + operation="WRITE", + config_id=config_id, + end_user_id=end_user_id, + success=False, + duration=time.time() - start_time, + error=error_msg, + ) + raise ValueError(error_msg) + + async def _preprocess_files( + self, + messages: list[MessageItem] | list[dict], + end_user_id: str, + memory_config, + db: Session, + ) -> list[dict]: + """处理消息中附带的文件,生成感知记忆对象并挂载到 message['file_content']。""" + perceptual_service = MemoryPerceptualService(db) + for message in messages: + if isinstance(message, dict): + message["file_content"] = [] + files = message.get("files") or [] + else: + message.file_content = [] + files = message.files or [] + for file in files: + file_object = await perceptual_service.generate_perceptual_memory( + end_user_id=end_user_id, + memory_config=memory_config, + file=FileInput(**file), + ) + if file_object is None: + continue + if isinstance(message, dict): + message["file_content"].append((file_object, file["type"])) + else: + message.file_content.append((file_object, file["type"])) + logger.info(messages) + return messages + + async def _write_neo4j( + self, + end_user_id: str, + messages: list[MessageItem] | list[dict], + memory_config, + language: Language | str, + ) -> None: + """根据 NEW_PIPELINE_ENABLED 选择新旧流水线写入 Neo4j。""" + # 统一转换为 dict,下游流水线期望 list[dict] + messages_dict = [ + msg if isinstance(msg, dict) else msg.model_dump(exclude_none=True) + for msg in messages + ] + use_new_pipeline = os.getenv("NEW_PIPELINE_ENABLED", "false").lower() == "true" + + if use_new_pipeline: + service = MemoryService(memory_config=memory_config, end_user_id=end_user_id) + result = await service.write(messages=messages_dict, language=language, ref_id='') + logger.info( + f"[NewPipeline] 完成: status={result.status}, " + f"elapsed={result.elapsed_seconds:.2f}s, " + f"extraction={result.extraction}" + ) + else: + await write_neo4j( + end_user_id=end_user_id, + messages=messages_dict, + memory_config=memory_config, + ref_id='', + language=language, + ) + + async def _invalidate_interest_cache(self, end_user_id: str) -> None: + """写入完成后失效兴趣分布缓存。""" + for lang in ["zh", "en"]: + deleted = await InterestMemoryCache.delete_interest_distribution(end_user_id, lang) + if deleted: + logger.info( + f"Invalidated interest distribution cache: end_user_id={end_user_id}, language={lang}" + ) + async def read_memory( self, end_user_id: str, diff --git a/api/app/services/memory_api_service.py b/api/app/services/memory_api_service.py index 82d1c463..e078d400 100644 --- a/api/app/services/memory_api_service.py +++ b/api/app/services/memory_api_service.py @@ -17,6 +17,7 @@ from app.core.logging_config import get_logger from app.models.app_model import App from app.models.end_user_model import EndUser from app.schemas.memory_config_schema import ConfigurationError +from app.schemas.memory_agent_schema import WriteMemoryRequest from app.services.memory_agent_service import MemoryAgentService logger = get_logger(__name__) @@ -291,12 +292,14 @@ class MemoryAPIService: try: messages = message if isinstance(message, list) else [{"role": "user", "content": message}] result = await MemoryAgentService().write_memory( - end_user_id=end_user_id, - messages=messages, - config_id=config_id, - db=self.db, - storage_type=storage_type, - user_rag_memory_id=user_rag_memory_id or "", + WriteMemoryRequest( + end_user_id=end_user_id, + messages=messages, + config_id=config_id, + storage_type=storage_type, + user_rag_memory_id=user_rag_memory_id or "", + ), + self.db, ) logger.info(f"Memory write (sync) successful for end_user: {end_user_id}") diff --git a/api/app/tasks.py b/api/app/tasks.py index 54ebe80f..f697b5f3 100644 --- a/api/app/tasks.py +++ b/api/app/tasks.py @@ -39,6 +39,7 @@ from app.models import Document, File, Knowledge from app.models.end_user_model import EndUser from app.schemas import document_schema, file_schema from app.services.memory_agent_service import MemoryAgentService, get_end_user_connected_config +from app.schemas.memory_agent_schema import WriteMemoryRequest from app.services.memory_forget_service import MemoryForgetService from app.utils.config_utils import resolve_config_id from app.utils.redis_lock import RedisFairLock @@ -1280,8 +1281,17 @@ def write_message_task( f"[CELERY WRITE] Executing MemoryAgentService.write_memory " f"with config_id={actual_config_id} (type: {type(actual_config_id).__name__}), language={language}") service = MemoryAgentService() - result = await service.write_memory(end_user_id, message, actual_config_id, db, storage_type, - user_rag_memory_id, language) + result = await service.write_memory( + WriteMemoryRequest( + end_user_id=end_user_id, + messages=message, + config_id=actual_config_id, + storage_type=storage_type, + user_rag_memory_id=user_rag_memory_id, + language=language, + ), + db, + ) logger.info(f"[CELERY WRITE] Write completed successfully: {result}") return result