refactor(memory): consolidate write pipeline and rename statement extraction step

- Rename StatementExtractionStep → StatementTemporalExtractionStep and
  extract_statement.jinja2 → extract_statement_temporal.jinja2 to reflect
  merged temporal extraction logic
- Move extraction_pipeline_orchestrator.py out of steps/ to engine root
- Move dedup_step.py into steps/ directory
- Introduce WriteMemoryRequest schema to replace positional args in write_memory()
- Extract _resolve_and_load_config, _preprocess_files, _write_neo4j, and
  _invalidate_interest_cache as private helpers in MemoryAgentService
- Remove shadow pipeline and simplify NEW_PIPELINE_ENABLED branch
- Merge 类型归属/成员隶属/任职服务 relation types into single 归属身份关系 in triplet prompt
- Add alias merge logic (别名属于) in deduplication and MERGE_ALIAS_BELONGS_TO Cypher query
- Add StorageType, Language, MessageItem enums/models to memory_agent_schema
- Reduce AgentMemory_Long_Term.DEFAULT_SCOPE from 6 to 1
- Delete standalone extract_temporal.jinja2 (logic merged into statement step)
This commit is contained in:
lanceyq
2026-04-28 16:36:30 +08:00
parent 7747ed7ac1
commit 1f0c88a5f0
17 changed files with 390 additions and 326 deletions

View File

@@ -12,11 +12,11 @@ from typing import Any, Awaitable, Callable, Dict, List, Optional
from app.core.memory.models.message_models import DialogData
from app.core.memory.models.variate_config import ExtractionPipelineConfig
from app.core.memory.storage_services.extraction_engine.dedup_step import (
from app.core.memory.storage_services.extraction_engine.steps.dedup_step import (
DedupResult,
run_dedup,
)
from app.core.memory.storage_services.extraction_engine.steps.extraction_pipeline_orchestrator import (
from app.core.memory.storage_services.extraction_engine.extraction_pipeline_orchestrator import (
NewExtractionOrchestrator,
)
from app.core.memory.storage_services.extraction_engine.steps.graph_build_step import (

View File

@@ -227,6 +227,14 @@ class WritePipeline:
f"{time.time() - step_start:.2f}s"
)
# Step 3.2: 别名归并 - 处理 predicate="别名属于" 的关系
step_start = time.time()
await self._merge_alias_belongs_to(extraction_result)
logger.info(
f"[WritePipeline] [3.2] 别名归并:处理别名属于关系 "
f"{time.time() - step_start:.2f}s"
)
# Step 3.5: 异步情绪提取fire-and-forget需在 _store 之后确保 Statement 节点已存在)
await self._extract_emotion(getattr(self, "_emotion_statements", []))
@@ -316,13 +324,13 @@ class WritePipeline:
2. build_graph_nodes_and_edges() → 从 DialogData 构建图节点和边
3. run_dedup() → 两阶段去重消歧
"""
from app.core.memory.storage_services.extraction_engine.dedup_step import (
from app.core.memory.storage_services.extraction_engine.steps.dedup_step import (
run_dedup,
)
from app.core.memory.storage_services.extraction_engine.steps.graph_build_step import (
build_graph_nodes_and_edges,
)
from app.core.memory.storage_services.extraction_engine.steps.extraction_pipeline_orchestrator import (
from app.core.memory.storage_services.extraction_engine.extraction_pipeline_orchestrator import (
NewExtractionOrchestrator,
)
@@ -554,6 +562,78 @@ class WritePipeline:
else:
raise
# ──────────────────────────────────────────────
# Step 3.2: 别名归并
# ──────────────────────────────────────────────
async def _merge_alias_belongs_to(self, result: ExtractionResult) -> None:
"""
所有去重合并都可以使用这个这种统一的处理方式(未实现)
别名归并:处理 predicate="别名属于" 的 EXTRACTED_RELATIONSHIP 边。
对每条 source -[EXTRACTED_RELATIONSHIP {predicate:"别名属于"}]-> target 边:
- 将 source.name 追加到 target.aliases去重
- 将 source.description append 进 target.description_listlist 形式)
同时在内存中同步更新 ExtractionResult.entity_nodes保持内存与 Neo4j 一致。
失败不中断主流程。
"""
from app.repositories.neo4j.cypher_queries import MERGE_ALIAS_BELONGS_TO
ALIAS_PREDICATE = "别名属于"
# 筛选出所有 predicate="别名属于" 的边
alias_edges = [
e for e in result.entity_entity_edges
if getattr(e, "relation_type", "") == ALIAS_PREDICATE
or getattr(e, "predicate", "") == ALIAS_PREDICATE
]
if not alias_edges:
logger.debug("[AliasMerge] 无 '别名属于' 关系,跳过")
return
try:
# ── 1. 在内存中同步更新 entity_nodes ──
entity_map = {e.id: e for e in result.entity_nodes}
for edge in alias_edges:
source_node = entity_map.get(edge.source)
target_node = entity_map.get(edge.target)
if not source_node or not target_node:
continue
# 将 source.name 追加到 target.aliases去重忽略大小写
source_name = (source_node.name or "").strip()
if source_name:
existing_lower = {a.lower() for a in (target_node.aliases or [])}
if source_name.lower() not in existing_lower:
target_node.aliases = list(target_node.aliases or []) + [source_name]
# 将 source.description append 进 target.description追加分号分隔
src_desc = (source_node.description or "").strip()
if src_desc:
tgt_desc = (target_node.description or "").strip()
if src_desc not in tgt_desc:
target_node.description = f"{tgt_desc}{src_desc}" if tgt_desc else src_desc
logger.info(
f"[AliasMerge] 内存同步完成,处理 {len(alias_edges)}'别名属于'"
)
# ── 2. 写入 Neo4j ──
records = await self._neo4j_connector.execute_query(
MERGE_ALIAS_BELONGS_TO,
end_user_id=self.end_user_id,
)
merged_count = len(records) if records else 0
logger.info(
f"[AliasMerge] Neo4j 别名归并完成,影响 {merged_count} 条记录"
)
except Exception as e:
logger.warning(f"[AliasMerge] 别名归并失败(不影响主流程): {e}", exc_info=True)
# ──────────────────────────────────────────────
# Step 4: 聚类
# ──────────────────────────────────────────────

View File

@@ -1112,6 +1112,39 @@ async def deduplicate_entities_and_edges(
# 在主流程这里 这里是之后关系去重和消歧的地方,方法可以写在其他地方
# 此处统一对边进行处理,使用累积的 id_redirect 把边的 source/target 改成规范ID
# 4) 边重定向与去重
# 4.0 预处理:将 "别名属于" 关系的 source.name/description 归并到 target 节点
# 必须在边重定向之前执行,此时 id_redirect 已包含精确/模糊/LLM 的合并结果
try:
entity_by_id: Dict[str, ExtractedEntityNode] = {e.id: e for e in deduped_entities}
for edge in entity_entity_edges:
if getattr(edge, "relation_type", "") != "别名属于":
continue
# 通过 id_redirect 找到合并后的规范节点
source_id = id_redirect.get(edge.source, edge.source)
target_id = id_redirect.get(edge.target, edge.target)
if source_id == target_id:
continue
source_node = entity_by_id.get(source_id)
target_node = entity_by_id.get(target_id)
if not source_node or not target_node:
continue
# 将 source.name 追加到 target.aliases去重忽略大小写
source_name = (source_node.name or "").strip()
if source_name:
existing_lower = {a.lower() for a in (target_node.aliases or [])}
if source_name.lower() not in existing_lower and source_name.lower() != (target_node.name or "").lower():
target_node.aliases = list(target_node.aliases or []) + [source_name]
# 将 source.description 追加到 target.description分号分隔去重
src_desc = (source_node.description or "").strip()
if src_desc:
tgt_desc = (target_node.description or "").strip()
if src_desc not in tgt_desc:
target_node.description = f"{tgt_desc}{src_desc}" if tgt_desc else src_desc
except Exception:
pass
# 4.1 语句→实体边:重复时优先保留 strong
stmt_ent_map: Dict[str, StatementEntityEdge] = {}
for edge in statement_entity_edges:

View File

@@ -23,12 +23,12 @@ from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple
from app.core.memory.models.message_models import DialogData
from app.core.memory.models.variate_config import ExtractionPipelineConfig
from .base import ExtractionStep, StepContext
from .embedding_step import EmbeddingStep
from .sidecar_factory import SidecarStepFactory, SidecarTiming
from .statement_step import StatementExtractionStep
from .triplet_step import TripletExtractionStep
from .schema import (
from .steps.base import ExtractionStep, StepContext
from .steps.embedding_step import EmbeddingStep
from .steps.sidecar_factory import SidecarStepFactory, SidecarTiming
from .steps.statement_temporal_step import StatementTemporalExtractionStep
from .steps.triplet_step import TripletExtractionStep
from .steps.schema import (
EmbeddingStepInput,
EmbeddingStepOutput,
EmotionStepInput,
@@ -85,7 +85,7 @@ class NewExtractionOrchestrator:
)
# ── Critical (main-line) steps ──
self.statement_step = StatementExtractionStep(self.context)
self.statement_temporal_step = StatementTemporalExtractionStep(self.context)
self.triplet_step = TripletExtractionStep(
self.context, ontology_types=ontology_types
)
@@ -508,7 +508,7 @@ class NewExtractionOrchestrator:
),
supporting_context=ctx,
)
tasks.append(self.statement_step.run(inp))
tasks.append(self.statement_temporal_step.run(inp))
task_meta.append(
(dialog.id, chunk.id, chunk_speaker, ctx)
)

View File

@@ -7,10 +7,10 @@ for all sidecar (non-critical) steps via SidecarStepFactory.
from .sidecar_factory import SidecarStepFactory, SidecarTiming # noqa: F401
# Step implementations — importing triggers @register self-registration.
from .statement_step import StatementExtractionStep # noqa: F401
from .statement_temporal_step import StatementTemporalExtractionStep # noqa: F401
from .triplet_step import TripletExtractionStep # noqa: F401
from .emotion_step import EmotionExtractionStep # noqa: F401
from .embedding_step import EmbeddingStep # noqa: F401
# Refactored orchestrator
from .extraction_pipeline_orchestrator import NewExtractionOrchestrator # noqa: F401
from app.core.memory.storage_services.extraction_engine.extraction_pipeline_orchestrator import NewExtractionOrchestrator # noqa: F401

View File

@@ -28,7 +28,7 @@ class SupportingContext(BaseModel):
# ── Statement extraction ──
class StatementStepInput(BaseModel):
"""Input for StatementExtractionStep."""
"""Input for StatementTemporalExtractionStep."""
chunk_id: str
end_user_id: str

View File

@@ -1,4 +1,4 @@
"""StatementExtractionStep — critical step for extracting statements from chunks.
"""StatementTemporalExtractionStep — critical step for extracting statements and temporal info from chunks.
Replaces the legacy ``StatementExtractor`` with the unified ExtractionStep paradigm.
Temporal extraction logic (valid_at / invalid_at) is merged into this step,
@@ -62,8 +62,8 @@ class _StatementExtractionResponse(BaseModel):
return v
class StatementExtractionStep(ExtractionStep[StatementStepInput, List[StatementStepOutput]]):
"""Extract atomic statements (with temporal info) from a dialogue chunk.
class StatementTemporalExtractionStep(ExtractionStep[StatementStepInput, List[StatementStepOutput]]):
"""Extract atomic statements with temporal info (valid_at / invalid_at) from a dialogue chunk.
This is a **critical** step failure aborts the pipeline after retries.

View File

@@ -65,7 +65,7 @@ async def render_statement_extraction_prompt(
Returns:
Rendered prompt content as string
"""
template = prompt_env.get_template("extract_statement.jinja2")
template = prompt_env.get_template("extract_statement_temporal.jinja2")
# Optional clipping of dialogue context
ctx = None
if include_dialogue_context and dialogue_content:
@@ -90,7 +90,7 @@ async def render_statement_extraction_prompt(
# 记录渲染结果到提示日志(与示例日志结构一致)
log_prompt_rendering('statement extraction', rendered_prompt)
# 可选:记录模板渲染信息
log_template_rendering('extract_statement.jinja2', {
log_template_rendering('extract_statement_temporal.jinja2', {
'inputs': 'chunk',
'definitions': 'LABEL_DEFINITIONS',
'json_schema': 'StatementExtractionResponse.schema',

View File

@@ -1,126 +0,0 @@
{% macro tidy(name) -%}
{{ name.replace('_', ' ')}}
{%- endmacro %}
{#
This prompt (template) is adapted from [getzep/graphiti]
Licensed under the Apache License, Version 2.0
Original work:
https://github.com/getzep/graphiti/blob/main/graphiti_core/prompts/extract_edge_dates.py
Modifications made by Ke Sun on 2025-09-01
See the LICENSE file for the full Apache 2.0 license text.
#}
# Task
{% if language == "zh" %}
从提供的陈述句中提取时间信息(日期和时间范围)。确定所描述的关系或事件何时生效以及何时结束(如果适用)。
{% else %}
Extract temporal information (dates and time ranges) from the provided statement. Determine when the relationship or event described became valid and when it ended (if applicable).
{% endif %}
# {% if language == "zh" %}输入数据{% else %}Input Data{% endif %}
{% if inputs %}
{% for key, val in inputs.items() %}
- {{ key }}: {{val}}
{% endfor %}
{% endif %}
# {% if language == "zh" %}时间字段{% else %}Temporal Fields{% endif %}
{% if language == "zh" %}
- **valid_at**: 关系/事件开始或成为真实的时间ISO 8601 格式)
- **invalid_at**: 关系/事件结束或停止为真的时间ISO 8601 格式,如果正在进行则为 null
{% else %}
- **valid_at**: When the relationship/event started or became true (ISO 8601 format)
- **invalid_at**: When the relationship/event ended or stopped being true (ISO 8601 format, or null if ongoing)
{% endif %}
# {% if language == "zh" %}提取规则{% else %}Extraction Rules{% endif %}
## {% if language == "zh" %}核心原则{% else %}Core Principles{% endif %}
{% if language == "zh" %}
1. **仅使用明确陈述的时间信息** - 不要从外部知识推断日期
2. **使用参考/发布日期作为"现在"** 解释相对时间时
3. **仅在日期与关系的有效性相关时设置日期** - 忽略偶然的时间提及
4. **对于时间点事件**,仅设置 `valid_at`
{% else %}
1. **Only use explicitly stated temporal information** - do not infer dates from external knowledge
2. **Use the reference/publication date as "now"** when interpreting relative times
3. **Set dates only if they relate to the validity of the relationship** - ignore incidental time mentions
4. **For point-in-time events**, set only `valid_at`
{% endif %}
## {% if language == "zh" %}日期格式要求{% else %}Date Format Requirements{% endif %}
{% if language == "zh" %}
- 使用 ISO 8601: `YYYY-MM-DDTHH:MM:SS.SSSSSSZ`
- 如果未指定时间,使用 `00:00:00`(午夜)
- 如果仅提及年份,根据情况使用 `YYYY-01-01`(开始)或 `YYYY-12-31`(结束)
- 如果仅提及月份,使用月份的第一天或最后一天
- 始终包含时区(如果未指定,使用 `Z` 表示 UTC
- 根据参考日期将相对时间("两周前"、"去年")转换为绝对日期
{% else %}
- Use ISO 8601: `YYYY-MM-DDTHH:MM:SS.SSSSSSZ`
- If no time specified, use `00:00:00` (midnight)
- If only year mentioned, use `YYYY-01-01` (start) or `YYYY-12-31` (end) as appropriate
- If only month mentioned, use first or last day of month
- Always include timezone (use `Z` for UTC if unspecified)
- Convert relative times ("two weeks ago", "last year") to absolute dates based on reference date
{% endif %}
## {% if language == "zh" %}陈述句类型规则{% else %}Statement Type Rules{% endif %}
{{ inputs.get("statement_type") | upper }} {% if language == "zh" %}陈述句指导{% else %}Statement Guidance{% endif %}:
{%for key, guide in statement_guide.items() %}
- {{ tidy(key) | capitalize }}: {{ guide }}
{% endfor %}
**{% if language == "zh" %}特殊情况{% else %}Special Cases{% endif %}:**
{% if language == "zh" %}
- **意见陈述句**: 仅设置 `valid_at`(意见表达的时间)
- **预测陈述句**: 如果明确提及,将 `invalid_at` 设置为预测窗口的结束
{% else %}
- **Opinion statements**: Set only `valid_at` (when opinion was expressed)
- **Prediction statements**: Set `invalid_at` to the end of the prediction window if explicitly mentioned
{% endif %}
## {% if language == "zh" %}时间类型规则{% else %}Temporal Type Rules{% endif %}
{{ inputs.get("temporal_type") | upper }} {% if language == "zh" %}时间类型指导{% else %}Temporal Type Guidance{% endif %}:
{% for key, guide in temporal_guide.items() %}
- {{ tidy(key) | capitalize }}: {{ guide }}
{% endfor %}
{% if inputs.get('quarter') and inputs.get('publication_date') %}
## {% if language == "zh" %}季度参考{% else %}Quarter Reference{% endif %}
{% if language == "zh" %}
假设 {{ inputs.quarter }} 在 {{ inputs.publication_date }} 结束。从此基线计算任何季度引用Q1、Q2 等)的日期。
{% else %}
Assume {{ inputs.quarter }} ends on {{ inputs.publication_date }}. Calculate dates for any quarter references (Q1, Q2, etc.) from this baseline.
{% endif %}
{% endif %}
# {% if language == "zh" %}输出要求{% else %}Output Requirements{% endif %}
## {% if language == "zh" %}JSON 格式化(关键){% else %}JSON Formatting (CRITICAL){% endif %}
{% if language == "zh" %}
1. 使用**仅标准 ASCII 双引号** (") - 永远不要使用中文引号("")或其他 Unicode 变体
2. 使用反斜杠转义内部引号: `\"`
3. JSON 字符串值中不要有换行符
4. 正确关闭并用逗号分隔所有字段
{% else %}
1. Use **only standard ASCII double quotes** (") - never use Chinese quotes ("") or other Unicode variants
2. Escape internal quotes with backslash: `\"`
3. No line breaks within JSON string values
4. Properly close and comma-separate all fields
{% endif %}
## {% if language == "zh" %}语言{% else %}Language{% endif %}
{% if language == "zh" %}
输出语言必须与输入语言匹配。
{% else %}
Output language must match input language.
{% endif %}
{{ json_schema }}

View File

@@ -250,28 +250,12 @@ Primary statement to analyze:
- notes: 只处理名字性表达,不处理角色、职业、评价词。
- status: `enabled`
- `类型归属关系`
- definition: 表达实体属于某种类别,或主体承担某种角色/职业身份的关系。
- covered_predicates: `属于类型`、`担任角色`、`从事职业`
- positive_examples: `王教授 -> 担任角色 -> 导师`、`张三 -> 从事职业 -> 程序员`
- negative_examples: `张三 -> 担任角色 -> 山哥`、`用户 -> 从事职业 -> 紧张`
- notes: 用于“是什么”,不用于“叫什么”
- status: `enabled`
- `成员隶属关系`
- definition: 表达主体属于某个组织、群体或集合的成员归属关系。
- covered_predicates: `成员属于`
- positive_examples: `张三 -> 成员属于 -> 实验室成员`、`用户 -> 成员属于 -> 社群`
- negative_examples: `他们 -> 成员属于 -> 学校`、`一个朋友 -> 成员属于 -> 班级`
- notes: 前提是主体和归属对象都足够稳定;边界不稳的人群不要硬抽。
- status: `enabled`
- `任职服务关系`
- definition: 表达人物或主体在组织中的工作、任职或服务关系。
- covered_predicates: `任职于`
- positive_examples: `张明 -> 任职于 -> 腾讯`、`王教授 -> 任职于 -> 清华大学`
- negative_examples: `张明 -> 任职于 -> 导师`、`用户 -> 任职于 -> 明天的面试`
- notes: 优先用于人物到组织的稳定供职关系。
- `归属身份关系`
- definition: 表达主体所属的类别、身份、职业、角色,或其与组织、群体、集合之间的归属关系。
- covered_predicates: `属于类型`、`担任角色`、`从事职业`、`成员属于`、`任职于`
- positive_examples: `王教授 -> 担任角色 -> 导师`、`张三 -> 从事职业 -> 程序员`、`张三 -> 成员属于 -> 实验室成员`、`张明 -> 任职于 -> 腾讯`
- negative_examples: `张三 -> 担任角色 -> 山哥`、`他们 -> 成员属于 -> 学校`、`用户 -> 任职于 -> 明天的面试`、`用户 -> 从事职业 -> 紧张`
- notes: 这是一个上位父类,用于统一承接“是什么身份”与“归属哪里”两类关系。第一层不再强行区分“身份类归属”和“组织类归属”,真正的区分在子类 predicate 层完成
- status: `enabled`
- `空间位置关系`

View File

@@ -1107,6 +1107,35 @@ RETURN (
) AS is_complete
"""
# Alias merge: for every EXTRACTED_RELATIONSHIP edge with predicate "别名属于",
# fold source.name into target.aliases (exact-match dedup; the Python-side
# in-memory merge dedups case-insensitively — slight asymmetry, kept as-is)
# and append source.description to target.description, semicolon-separated.
MERGE_ALIAS_BELONGS_TO = """
MATCH (source:ExtractedEntity {end_user_id: $end_user_id})-[r:EXTRACTED_RELATIONSHIP]->(target:ExtractedEntity {end_user_id: $end_user_id})
WHERE r.predicate = '别名属于'
WITH source, target,
     coalesce(target.aliases, []) AS existing_aliases,
     source.name AS source_name,
     coalesce(source.description, '') AS src_desc,
     coalesce(target.description, '') AS tgt_desc
// 1. 合并 aliases将 source.name 追加到 target.aliases去重
WITH source, target, src_desc, tgt_desc,
     CASE
       WHEN source_name IS NOT NULL AND source_name <> '' AND NOT source_name IN existing_aliases
       THEN existing_aliases + source_name
       ELSE existing_aliases
     END AS new_aliases
SET target.aliases = new_aliases,
    target.description = CASE
      WHEN src_desc <> '' AND NOT tgt_desc CONTAINS src_desc
      THEN CASE WHEN tgt_desc = '' THEN src_desc ELSE tgt_desc + ';' + src_desc END
      ELSE tgt_desc
    END
RETURN source.name AS merged_alias, target.name AS target_name, new_aliases AS updated_aliases
"""
CHECK_COMMUNITY_IS_COMPLETE_WITH_EMBEDDING = """
MATCH (c:Community {community_id: $community_id, end_user_id: $end_user_id})
RETURN (

View File

@@ -1,9 +1,32 @@
from abc import ABC
from typing import Optional
from enum import Enum
from typing import Any, Optional
from pydantic import BaseModel
class StorageType(str, Enum):
    """Memory storage backend type (Neo4j graph store or RAG store)."""
    NEO4J = "neo4j"
    RAG = "rag"
# NOTE(review): this value is not forwarded into the clustering celery task;
# on failure the task falls back to its default — consider unifying language
# handling end-to-end (translated from the original inline note).
class Language(str, Enum):
    """Supported languages."""
    ZH = "zh"
    EN = "en"
class MessageItem(BaseModel):
    """A single message in a write request."""
    role: str  # speaker role; joined as "role: content" when building message_text
    content: str  # message text
    files: Optional[list[dict]] = None  # raw attachment descriptors (fed to FileInput(**file))
    file_content: Optional[list[Any]] = None  # (perceptual_object, file_type) tuples, filled by file preprocessing
    model_config = {"extra": "allow"}  # tolerate extra fields from upstream callers
class UserInput(BaseModel):
message: str
history: list[dict]
@@ -18,6 +41,16 @@ class Write_UserInput(BaseModel):
config_id: Optional[str] = None
class WriteMemoryRequest(BaseModel):
    """Parameter envelope for MemoryAgentService.write_memory()."""
    end_user_id: str  # end-user / group identifier
    messages: list[MessageItem]  # messages to persist
    # Presumably uuid.UUID | int (matches write_memory's former signature);
    # typed Any to avoid importing uuid here — TODO confirm and tighten.
    config_id: Optional[Any] = None
    storage_type: StorageType = StorageType.NEO4J  # backend selection (neo4j | rag)
    user_rag_memory_id: str = ""  # only used when storage_type == RAG
    language: Language = Language.ZH
class AgentMemory_Long_Term(ABC):
"""长期记忆配置常量"""
STORAGE_NEO4J = "neo4j"
@@ -25,7 +58,7 @@ class AgentMemory_Long_Term(ABC):
STRATEGY_AGGREGATE = "aggregate"
STRATEGY_CHUNK = "chunk"
STRATEGY_TIME = "time"
DEFAULT_SCOPE = 6
DEFAULT_SCOPE = 1
TIME_SCOPE = 5

View File

@@ -36,13 +36,14 @@ from app.core.memory.agent.utils.messages_tools import (
from app.core.memory.agent.utils.type_classifier import status_typle
from app.core.memory.agent.utils.write_tools import write as write_neo4j
from app.core.memory.analytics.hot_memory_tags import get_interest_distribution
from app.core.memory.memory_service import MemoryService
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.core.memory.utils.log.audit_logger import audit_logger
from app.db import get_db_context
from app.models.knowledge_model import Knowledge, KnowledgeType
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.schemas import FileInput
from app.schemas.memory_agent_schema import Write_UserInput
from app.schemas.memory_agent_schema import Language, MessageItem, StorageType, Write_UserInput, WriteMemoryRequest
from app.schemas.memory_config_schema import ConfigurationError
from app.services.memory_config_service import MemoryConfigService
from app.services.memory_konwledges_server import (
@@ -267,25 +268,15 @@ class MemoryAgentService:
async def write_memory(
self,
end_user_id: str,
messages: list[dict],
config_id: Optional[uuid.UUID] | int,
request: WriteMemoryRequest,
db: Session,
storage_type: str,
user_rag_memory_id: str,
language: str = "zh"
) -> str:
"""
Process write operation with config_id
长期记忆写入
Args:
end_user_id: Group identifier (also used as end_user_id)
messages: Message to write
config_id: Configuration ID from database
request: 写入请求参数end_user_id、messages、config_id、storage_type、language 等)
db: SQLAlchemy database session
storage_type: Storage type (neo4j or rag)
user_rag_memory_id: User RAG memory ID
language: 语言类型 ("zh" 中文, "en" 英文)
Returns:
Write operation result status
@@ -293,147 +284,50 @@ class MemoryAgentService:
Raises:
ValueError: If config loading fails or write operation fails
"""
# Resolve config_id and workspace_id
# Always get workspace_id from end_user for fallback, even if config_id is provided
workspace_id = None
try:
connected_config = get_end_user_connected_config(end_user_id, db)
workspace_id = connected_config.get("workspace_id")
if config_id is None:
config_id = connected_config.get("memory_config_id")
logger.info(f"Resolved config from end_user: config_id={config_id}, workspace_id={workspace_id}")
if config_id is None and workspace_id is None:
raise ValueError(f"No memory configuration found for end_user {end_user_id}. "
f"Please ensure the user has a connected memory configuration.")
except Exception as e:
if "No memory configuration found" in str(e):
raise # Re-raise our specific error
logger.error(f"Failed to get connected config for end_user {end_user_id}: {e}")
if config_id is None:
raise ValueError(f"Unable to determine memory configuration for end_user {end_user_id}: {e}")
# If config_id was provided, continue without workspace_id fallback
import time
end_user_id = request.end_user_id
messages = request.messages
config_id = request.config_id
storage_type = request.storage_type
user_rag_memory_id = request.user_rag_memory_id
language = request.language
start_time = time.time()
# Load configuration from database with workspace fallback
# Use a separate database session to avoid transaction failures
# ── Step 1: 解析配置 ── 通过 end_user_id 查找关联的 config_id / workspace_id并从数据库加载完整 memory_config
memory_config = await self._resolve_and_load_config(
end_user_id, config_id, db, start_time
)
# ── Step 2: 文件预处理 ── 将消息中附带的文件转换为感知记忆对象,挂载到 message["file_content"]
messages = await self._preprocess_files(messages, end_user_id, memory_config, db)
message_text = "\n".join([
f"{(msg['role'] if isinstance(msg, dict) else msg.role)}: {(msg['content'] if isinstance(msg, dict) else msg.content)}"
for msg in messages
])
# ── Step 3: 写入存储 ── 根据 storage_type 分流到 RAG 或 Neo4j 流水线
try:
from app.db import get_db_context
with get_db_context() as config_db:
config_service = MemoryConfigService(config_db)
memory_config = config_service.load_memory_config(
config_id=config_id,
workspace_id=workspace_id,
service_name="MemoryAgentService"
)
logger.info(f"Configuration loaded successfully: {memory_config.config_name}")
except ConfigurationError as e:
error_msg = f"Failed to load configuration for config_id: {config_id}: {e}"
logger.error(error_msg)
# Log failed operation
duration = time.time() - start_time
audit_logger.log_operation(operation="WRITE", config_id=config_id, end_user_id=end_user_id,
success=False, duration=duration, error=error_msg)
raise ValueError(error_msg)
perceptual_serivce = MemoryPerceptualService(db)
for message in messages:
message["file_content"] = []
for file in (message.get("files") or []):
file_object = await perceptual_serivce.generate_perceptual_memory(
end_user_id=end_user_id,
memory_config=memory_config,
file=FileInput(**file)
)
if file_object is None:
continue
message["file_content"].append((file_object, file["type"]))
logger.info(messages)
message_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages])
try:
if storage_type == "rag":
# For RAG storage, convert messages to single string
if storage_type == StorageType.RAG:
await write_rag(end_user_id, message_text, user_rag_memory_id)
return "success"
else:
# TODO 乐力齐 重构流水线切换至生产环境后,更改如下代码
import os
use_new_pipeline = os.getenv("NEW_PIPELINE_ENABLED", "false").lower() == "true"
await self._write_neo4j(end_user_id, messages, memory_config, language)
if use_new_pipeline:
# ── 新流水线WritePipeline + NewExtractionOrchestrator ──
from app.core.memory.memory_service import MemoryService
service = MemoryService(
memory_config=memory_config,
end_user_id=end_user_id,
)
result = await service.write(
messages=messages,
language=language,
ref_id='',
is_pilot_run=False,
)
logger.info(
f"[NewPipeline] 完成: status={result.status}, "
f"elapsed={result.elapsed_seconds:.2f}s, "
f"extraction={result.extraction}"
)
else:
# ── 旧流水线write_tools.write() + ExtractionOrchestrator ──
await write_neo4j(
end_user_id=end_user_id,
messages=messages,
memory_config=memory_config,
ref_id='',
language=language
)
# ── 影子运行:新流水线静默执行,只记录日志不影响主流程 ──
if os.getenv("SHADOW_PIPELINE_ENABLED", "false").lower() == "true":
try:
from app.core.memory.memory_service import MemoryService
import copy
shadow_messages = copy.deepcopy(messages)
shadow_service = MemoryService(
memory_config=memory_config,
end_user_id=end_user_id,
)
shadow_result = await shadow_service.write(
messages=shadow_messages,
language=language,
ref_id='',
is_pilot_run=True,
)
logger.info(
f"[Shadow] 新流水线影子运行完成: status={shadow_result.status}, "
f"elapsed={shadow_result.elapsed_seconds:.2f}s, "
f"extraction={shadow_result.extraction}"
)
except Exception as shadow_err:
logger.warning(f"[Shadow] 新流水线影子运行失败(不影响主流程): {shadow_err}")
# ── 影子运行结束 ──
for lang in ["zh", "en"]:
deleted = await InterestMemoryCache.delete_interest_distribution(
end_user_id, lang
)
if deleted:
logger.info(
f"Invalidated interest distribution cache: end_user_id={end_user_id}, language={lang}")
# ── Step 4: 后处理 ── 失效缓存、序列化文件路径、记录审计日志并返回结果
await self._invalidate_interest_cache(end_user_id)
for message in messages:
message["file_content"] = [
perceptual[0].file_path for perceptual in message["file_content"]
]
if isinstance(message, dict):
message["file_content"] = [
perceptual[0].file_path for perceptual in (message["file_content"] or [])
]
else:
message.file_content = [
perceptual[0].file_path for perceptual in (message.file_content or [])
]
return self.writer_messages_deal(
"success",
start_time,
end_user_id,
config_id,
memory_config.config_id,
message_text,
{
"status": "success",
@@ -443,15 +337,139 @@ class MemoryAgentService:
}
)
except Exception as e:
# Ensure proper error handling and logging
error_msg = f"Write operation failed: {str(e)}"
logger.error(error_msg)
duration = time.time() - start_time
audit_logger.log_operation(operation="WRITE", config_id=config_id, end_user_id=end_user_id,
success=False, duration=duration, error=error_msg)
audit_logger.log_operation(
operation="WRITE",
config_id=memory_config.config_id,
end_user_id=end_user_id,
success=False,
duration=time.time() - start_time,
error=error_msg,
)
raise ValueError(error_msg)
    async def _resolve_and_load_config(
        self,
        end_user_id: str,
        config_id: Optional[uuid.UUID] | int,
        db: Session,
        start_time: float,
    ):
        """Resolve the end_user's connected config and load the full memory_config.

        Args:
            end_user_id: End-user identifier used to look up the connected config.
            config_id: Explicit config id; when None it is resolved from the
                end_user's connected configuration.
            db: SQLAlchemy session for the connected-config lookup.
            start_time: Wall-clock start of the write, used for audit duration
                when loading fails.

        Returns:
            The loaded memory_config object.

        Raises:
            ValueError: If no configuration can be resolved or loading fails.
        """
        workspace_id = None
        try:
            connected_config = get_end_user_connected_config(end_user_id, db)
            workspace_id = connected_config.get("workspace_id")
            if config_id is None:
                config_id = connected_config.get("memory_config_id")
            logger.info(f"Resolved config from end_user: config_id={config_id}, workspace_id={workspace_id}")
            if config_id is None and workspace_id is None:
                raise ValueError(
                    f"No memory configuration found for end_user {end_user_id}. "
                    f"Please ensure the user has a connected memory configuration."
                )
        except Exception as e:
            # Re-raise our own "not found" ValueError unchanged; matched on the
            # message text so the error surface stays identical.
            if "No memory configuration found" in str(e):
                raise
            logger.error(f"Failed to get connected config for end_user {end_user_id}: {e}")
            if config_id is None:
                raise ValueError(f"Unable to determine memory configuration for end_user {end_user_id}: {e}")
            # If config_id was provided explicitly, continue without a workspace fallback.
        try:
            # Separate session so a failure here cannot poison the caller's transaction.
            with get_db_context() as config_db:
                memory_config = MemoryConfigService(config_db).load_memory_config(
                    config_id=config_id,
                    workspace_id=workspace_id,
                    service_name="MemoryAgentService",
                )
                logger.info(f"Configuration loaded successfully: {memory_config.config_name}")
                return memory_config
        except ConfigurationError as e:
            error_msg = f"Failed to load configuration for config_id: {config_id}: {e}"
            logger.error(error_msg)
            # Record the failed WRITE before surfacing it to the caller.
            audit_logger.log_operation(
                operation="WRITE",
                config_id=config_id,
                end_user_id=end_user_id,
                success=False,
                duration=time.time() - start_time,
                error=error_msg,
            )
            raise ValueError(error_msg)
    async def _preprocess_files(
        self,
        messages: list[MessageItem] | list[dict],
        end_user_id: str,
        memory_config,
        db: Session,
    ) -> list[MessageItem] | list[dict]:
        """Convert message file attachments into perceptual-memory objects.

        For every message, resets ``file_content`` and fills it with
        ``(perceptual_object, file_type)`` tuples generated from the message's
        ``files``. Supports both dict messages and MessageItem models, and
        mutates the messages in place.

        Args:
            messages: Incoming messages (dicts or MessageItem instances).
            end_user_id: Owner of the generated perceptual memories.
            memory_config: Loaded memory configuration passed to the perceptual service.
            db: SQLAlchemy session for MemoryPerceptualService.

        Returns:
            The same (mutated) messages list.
        """
        perceptual_service = MemoryPerceptualService(db)
        for message in messages:
            if isinstance(message, dict):
                message["file_content"] = []
                files = message.get("files") or []
            else:
                message.file_content = []
                files = message.files or []
            for file in files:
                file_object = await perceptual_service.generate_perceptual_memory(
                    end_user_id=end_user_id,
                    memory_config=memory_config,
                    file=FileInput(**file),
                )
                # Files the perceptual service cannot process are skipped silently.
                if file_object is None:
                    continue
                if isinstance(message, dict):
                    message["file_content"].append((file_object, file["type"]))
                else:
                    message.file_content.append((file_object, file["type"]))
        logger.info(messages)
        return messages
    async def _write_neo4j(
        self,
        end_user_id: str,
        messages: list[MessageItem] | list[dict],
        memory_config,
        language: Language | str,
    ) -> None:
        """Write messages to Neo4j via either the new or the legacy pipeline.

        Selection is driven by the NEW_PIPELINE_ENABLED env var: "true" routes
        through MemoryService (WritePipeline + NewExtractionOrchestrator),
        anything else uses the legacy write_tools.write() path.
        """
        # Downstream pipelines expect list[dict]; normalize MessageItem models.
        messages_dict = [
            msg if isinstance(msg, dict) else msg.model_dump(exclude_none=True)
            for msg in messages
        ]
        use_new_pipeline = os.getenv("NEW_PIPELINE_ENABLED", "false").lower() == "true"
        if use_new_pipeline:
            service = MemoryService(memory_config=memory_config, end_user_id=end_user_id)
            result = await service.write(messages=messages_dict, language=language, ref_id='')
            logger.info(
                f"[NewPipeline] 完成: status={result.status}, "
                f"elapsed={result.elapsed_seconds:.2f}s, "
                f"extraction={result.extraction}"
            )
        else:
            await write_neo4j(
                end_user_id=end_user_id,
                messages=messages_dict,
                memory_config=memory_config,
                ref_id='',
                language=language,
            )
async def _invalidate_interest_cache(self, end_user_id: str) -> None:
"""写入完成后失效兴趣分布缓存。"""
for lang in ["zh", "en"]:
deleted = await InterestMemoryCache.delete_interest_distribution(end_user_id, lang)
if deleted:
logger.info(
f"Invalidated interest distribution cache: end_user_id={end_user_id}, language={lang}"
)
async def read_memory(
self,
end_user_id: str,

View File

@@ -17,6 +17,7 @@ from app.core.logging_config import get_logger
from app.models.app_model import App
from app.models.end_user_model import EndUser
from app.schemas.memory_config_schema import ConfigurationError
from app.schemas.memory_agent_schema import WriteMemoryRequest
from app.services.memory_agent_service import MemoryAgentService
logger = get_logger(__name__)
@@ -291,12 +292,14 @@ class MemoryAPIService:
try:
messages = message if isinstance(message, list) else [{"role": "user", "content": message}]
result = await MemoryAgentService().write_memory(
end_user_id=end_user_id,
messages=messages,
config_id=config_id,
db=self.db,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id or "",
WriteMemoryRequest(
end_user_id=end_user_id,
messages=messages,
config_id=config_id,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id or "",
),
self.db,
)
logger.info(f"Memory write (sync) successful for end_user: {end_user_id}")

View File

@@ -39,6 +39,7 @@ from app.models import Document, File, Knowledge
from app.models.end_user_model import EndUser
from app.schemas import document_schema, file_schema
from app.services.memory_agent_service import MemoryAgentService, get_end_user_connected_config
from app.schemas.memory_agent_schema import WriteMemoryRequest
from app.services.memory_forget_service import MemoryForgetService
from app.utils.config_utils import resolve_config_id
from app.utils.redis_lock import RedisFairLock
@@ -1280,8 +1281,17 @@ def write_message_task(
f"[CELERY WRITE] Executing MemoryAgentService.write_memory "
f"with config_id={actual_config_id} (type: {type(actual_config_id).__name__}), language={language}")
service = MemoryAgentService()
result = await service.write_memory(end_user_id, message, actual_config_id, db, storage_type,
user_rag_memory_id, language)
result = await service.write_memory(
WriteMemoryRequest(
end_user_id=end_user_id,
messages=message,
config_id=actual_config_id,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id,
language=language,
),
db,
)
logger.info(f"[CELERY WRITE] Write completed successfully: {result}")
return result