refactor(memory): consolidate write pipeline and rename statement extraction step
- Rename StatementExtractionStep → StatementTemporalExtractionStep and extract_statement.jinja2 → extract_statement_temporal.jinja2 to reflect merged temporal extraction logic - Move extraction_pipeline_orchestrator.py out of steps/ to engine root - Move dedup_step.py into steps/ directory - Introduce WriteMemoryRequest schema to replace positional args in write_memory() - Extract _resolve_and_load_config, _preprocess_files, _write_neo4j, and _invalidate_interest_cache as private helpers in MemoryAgentService - Remove shadow pipeline and simplify NEW_PIPELINE_ENABLED branch - Merge 类型归属/成员隶属/任职服务 relation types into single 归属身份关系 in triplet prompt - Add alias merge logic (别名属于) in deduplication and MERGE_ALIAS_BELONGS_TO Cypher query - Add StorageType, Language, MessageItem enums/models to memory_agent_schema - Reduce AgentMemory_Long_Term.DEFAULT_SCOPE from 6 to 1 - Delete standalone extract_temporal.jinja2 (logic merged into statement step)
This commit is contained in:
@@ -12,11 +12,11 @@ from typing import Any, Awaitable, Callable, Dict, List, Optional
|
|||||||
|
|
||||||
from app.core.memory.models.message_models import DialogData
|
from app.core.memory.models.message_models import DialogData
|
||||||
from app.core.memory.models.variate_config import ExtractionPipelineConfig
|
from app.core.memory.models.variate_config import ExtractionPipelineConfig
|
||||||
from app.core.memory.storage_services.extraction_engine.dedup_step import (
|
from app.core.memory.storage_services.extraction_engine.steps.dedup_step import (
|
||||||
DedupResult,
|
DedupResult,
|
||||||
run_dedup,
|
run_dedup,
|
||||||
)
|
)
|
||||||
from app.core.memory.storage_services.extraction_engine.steps.extraction_pipeline_orchestrator import (
|
from app.core.memory.storage_services.extraction_engine.extraction_pipeline_orchestrator import (
|
||||||
NewExtractionOrchestrator,
|
NewExtractionOrchestrator,
|
||||||
)
|
)
|
||||||
from app.core.memory.storage_services.extraction_engine.steps.graph_build_step import (
|
from app.core.memory.storage_services.extraction_engine.steps.graph_build_step import (
|
||||||
|
|||||||
@@ -227,6 +227,14 @@ class WritePipeline:
|
|||||||
f"✔ {time.time() - step_start:.2f}s"
|
f"✔ {time.time() - step_start:.2f}s"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Step 3.2: 别名归并 - 处理 predicate="别名属于" 的关系
|
||||||
|
step_start = time.time()
|
||||||
|
await self._merge_alias_belongs_to(extraction_result)
|
||||||
|
logger.info(
|
||||||
|
f"[WritePipeline] [3.2] 别名归并:处理别名属于关系 "
|
||||||
|
f"✔ {time.time() - step_start:.2f}s"
|
||||||
|
)
|
||||||
|
|
||||||
# Step 3.5: 异步情绪提取(fire-and-forget,需在 _store 之后确保 Statement 节点已存在)
|
# Step 3.5: 异步情绪提取(fire-and-forget,需在 _store 之后确保 Statement 节点已存在)
|
||||||
await self._extract_emotion(getattr(self, "_emotion_statements", []))
|
await self._extract_emotion(getattr(self, "_emotion_statements", []))
|
||||||
|
|
||||||
@@ -316,13 +324,13 @@ class WritePipeline:
|
|||||||
2. build_graph_nodes_and_edges() → 从 DialogData 构建图节点和边
|
2. build_graph_nodes_and_edges() → 从 DialogData 构建图节点和边
|
||||||
3. run_dedup() → 两阶段去重消歧
|
3. run_dedup() → 两阶段去重消歧
|
||||||
"""
|
"""
|
||||||
from app.core.memory.storage_services.extraction_engine.dedup_step import (
|
from app.core.memory.storage_services.extraction_engine.steps.dedup_step import (
|
||||||
run_dedup,
|
run_dedup,
|
||||||
)
|
)
|
||||||
from app.core.memory.storage_services.extraction_engine.steps.graph_build_step import (
|
from app.core.memory.storage_services.extraction_engine.steps.graph_build_step import (
|
||||||
build_graph_nodes_and_edges,
|
build_graph_nodes_and_edges,
|
||||||
)
|
)
|
||||||
from app.core.memory.storage_services.extraction_engine.steps.extraction_pipeline_orchestrator import (
|
from app.core.memory.storage_services.extraction_engine.extraction_pipeline_orchestrator import (
|
||||||
NewExtractionOrchestrator,
|
NewExtractionOrchestrator,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -554,6 +562,78 @@ class WritePipeline:
|
|||||||
else:
|
else:
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
# ──────────────────────────────────────────────
|
||||||
|
# Step 3.2: 别名归并
|
||||||
|
# ──────────────────────────────────────────────
|
||||||
|
|
||||||
|
async def _merge_alias_belongs_to(self, result: ExtractionResult) -> None:
|
||||||
|
"""
|
||||||
|
所有去重合并都可以使用这个这种统一的处理方式(未实现)
|
||||||
|
别名归并:处理 predicate="别名属于" 的 EXTRACTED_RELATIONSHIP 边。
|
||||||
|
|
||||||
|
对每条 source -[EXTRACTED_RELATIONSHIP {predicate:"别名属于"}]-> target 边:
|
||||||
|
- 将 source.name 追加到 target.aliases(去重)
|
||||||
|
- 将 source.description append 进 target.description_list(list 形式)
|
||||||
|
|
||||||
|
同时在内存中同步更新 ExtractionResult.entity_nodes,保持内存与 Neo4j 一致。
|
||||||
|
失败不中断主流程。
|
||||||
|
"""
|
||||||
|
from app.repositories.neo4j.cypher_queries import MERGE_ALIAS_BELONGS_TO
|
||||||
|
|
||||||
|
ALIAS_PREDICATE = "别名属于"
|
||||||
|
|
||||||
|
# 筛选出所有 predicate="别名属于" 的边
|
||||||
|
alias_edges = [
|
||||||
|
e for e in result.entity_entity_edges
|
||||||
|
if getattr(e, "relation_type", "") == ALIAS_PREDICATE
|
||||||
|
or getattr(e, "predicate", "") == ALIAS_PREDICATE
|
||||||
|
]
|
||||||
|
|
||||||
|
if not alias_edges:
|
||||||
|
logger.debug("[AliasMerge] 无 '别名属于' 关系,跳过")
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
# ── 1. 在内存中同步更新 entity_nodes ──
|
||||||
|
entity_map = {e.id: e for e in result.entity_nodes}
|
||||||
|
|
||||||
|
for edge in alias_edges:
|
||||||
|
source_node = entity_map.get(edge.source)
|
||||||
|
target_node = entity_map.get(edge.target)
|
||||||
|
if not source_node or not target_node:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 将 source.name 追加到 target.aliases(去重,忽略大小写)
|
||||||
|
source_name = (source_node.name or "").strip()
|
||||||
|
if source_name:
|
||||||
|
existing_lower = {a.lower() for a in (target_node.aliases or [])}
|
||||||
|
if source_name.lower() not in existing_lower:
|
||||||
|
target_node.aliases = list(target_node.aliases or []) + [source_name]
|
||||||
|
|
||||||
|
# 将 source.description append 进 target.description(追加,分号分隔)
|
||||||
|
src_desc = (source_node.description or "").strip()
|
||||||
|
if src_desc:
|
||||||
|
tgt_desc = (target_node.description or "").strip()
|
||||||
|
if src_desc not in tgt_desc:
|
||||||
|
target_node.description = f"{tgt_desc};{src_desc}" if tgt_desc else src_desc
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"[AliasMerge] 内存同步完成,处理 {len(alias_edges)} 条 '别名属于' 边"
|
||||||
|
)
|
||||||
|
|
||||||
|
# ── 2. 写入 Neo4j ──
|
||||||
|
records = await self._neo4j_connector.execute_query(
|
||||||
|
MERGE_ALIAS_BELONGS_TO,
|
||||||
|
end_user_id=self.end_user_id,
|
||||||
|
)
|
||||||
|
merged_count = len(records) if records else 0
|
||||||
|
logger.info(
|
||||||
|
f"[AliasMerge] Neo4j 别名归并完成,影响 {merged_count} 条记录"
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"[AliasMerge] 别名归并失败(不影响主流程): {e}", exc_info=True)
|
||||||
|
|
||||||
# ──────────────────────────────────────────────
|
# ──────────────────────────────────────────────
|
||||||
# Step 4: 聚类
|
# Step 4: 聚类
|
||||||
# ──────────────────────────────────────────────
|
# ──────────────────────────────────────────────
|
||||||
|
|||||||
@@ -1112,6 +1112,39 @@ async def deduplicate_entities_and_edges(
|
|||||||
# 在主流程这里 这里是之后关系去重和消歧的地方,方法可以写在其他地方
|
# 在主流程这里 这里是之后关系去重和消歧的地方,方法可以写在其他地方
|
||||||
# 此处统一对边进行处理,使用累积的 id_redirect 把边的 source/target 改成规范ID
|
# 此处统一对边进行处理,使用累积的 id_redirect 把边的 source/target 改成规范ID
|
||||||
# 4) 边重定向与去重
|
# 4) 边重定向与去重
|
||||||
|
# 4.0 预处理:将 "别名属于" 关系的 source.name/description 归并到 target 节点
|
||||||
|
# 必须在边重定向之前执行,此时 id_redirect 已包含精确/模糊/LLM 的合并结果
|
||||||
|
try:
|
||||||
|
entity_by_id: Dict[str, ExtractedEntityNode] = {e.id: e for e in deduped_entities}
|
||||||
|
for edge in entity_entity_edges:
|
||||||
|
if getattr(edge, "relation_type", "") != "别名属于":
|
||||||
|
continue
|
||||||
|
# 通过 id_redirect 找到合并后的规范节点
|
||||||
|
source_id = id_redirect.get(edge.source, edge.source)
|
||||||
|
target_id = id_redirect.get(edge.target, edge.target)
|
||||||
|
if source_id == target_id:
|
||||||
|
continue
|
||||||
|
source_node = entity_by_id.get(source_id)
|
||||||
|
target_node = entity_by_id.get(target_id)
|
||||||
|
if not source_node or not target_node:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 将 source.name 追加到 target.aliases(去重,忽略大小写)
|
||||||
|
source_name = (source_node.name or "").strip()
|
||||||
|
if source_name:
|
||||||
|
existing_lower = {a.lower() for a in (target_node.aliases or [])}
|
||||||
|
if source_name.lower() not in existing_lower and source_name.lower() != (target_node.name or "").lower():
|
||||||
|
target_node.aliases = list(target_node.aliases or []) + [source_name]
|
||||||
|
|
||||||
|
# 将 source.description 追加到 target.description(分号分隔,去重)
|
||||||
|
src_desc = (source_node.description or "").strip()
|
||||||
|
if src_desc:
|
||||||
|
tgt_desc = (target_node.description or "").strip()
|
||||||
|
if src_desc not in tgt_desc:
|
||||||
|
target_node.description = f"{tgt_desc};{src_desc}" if tgt_desc else src_desc
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
# 4.1 语句→实体边:重复时优先保留 strong
|
# 4.1 语句→实体边:重复时优先保留 strong
|
||||||
stmt_ent_map: Dict[str, StatementEntityEdge] = {}
|
stmt_ent_map: Dict[str, StatementEntityEdge] = {}
|
||||||
for edge in statement_entity_edges:
|
for edge in statement_entity_edges:
|
||||||
|
|||||||
@@ -23,12 +23,12 @@ from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple
|
|||||||
from app.core.memory.models.message_models import DialogData
|
from app.core.memory.models.message_models import DialogData
|
||||||
from app.core.memory.models.variate_config import ExtractionPipelineConfig
|
from app.core.memory.models.variate_config import ExtractionPipelineConfig
|
||||||
|
|
||||||
from .base import ExtractionStep, StepContext
|
from .steps.base import ExtractionStep, StepContext
|
||||||
from .embedding_step import EmbeddingStep
|
from .steps.embedding_step import EmbeddingStep
|
||||||
from .sidecar_factory import SidecarStepFactory, SidecarTiming
|
from .steps.sidecar_factory import SidecarStepFactory, SidecarTiming
|
||||||
from .statement_step import StatementExtractionStep
|
from .steps.statement_temporal_step import StatementTemporalExtractionStep
|
||||||
from .triplet_step import TripletExtractionStep
|
from .steps.triplet_step import TripletExtractionStep
|
||||||
from .schema import (
|
from .steps.schema import (
|
||||||
EmbeddingStepInput,
|
EmbeddingStepInput,
|
||||||
EmbeddingStepOutput,
|
EmbeddingStepOutput,
|
||||||
EmotionStepInput,
|
EmotionStepInput,
|
||||||
@@ -85,7 +85,7 @@ class NewExtractionOrchestrator:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# ── Critical (main-line) steps ──
|
# ── Critical (main-line) steps ──
|
||||||
self.statement_step = StatementExtractionStep(self.context)
|
self.statement_temporal_step = StatementTemporalExtractionStep(self.context)
|
||||||
self.triplet_step = TripletExtractionStep(
|
self.triplet_step = TripletExtractionStep(
|
||||||
self.context, ontology_types=ontology_types
|
self.context, ontology_types=ontology_types
|
||||||
)
|
)
|
||||||
@@ -508,7 +508,7 @@ class NewExtractionOrchestrator:
|
|||||||
),
|
),
|
||||||
supporting_context=ctx,
|
supporting_context=ctx,
|
||||||
)
|
)
|
||||||
tasks.append(self.statement_step.run(inp))
|
tasks.append(self.statement_temporal_step.run(inp))
|
||||||
task_meta.append(
|
task_meta.append(
|
||||||
(dialog.id, chunk.id, chunk_speaker, ctx)
|
(dialog.id, chunk.id, chunk_speaker, ctx)
|
||||||
)
|
)
|
||||||
@@ -7,10 +7,10 @@ for all sidecar (non-critical) steps via SidecarStepFactory.
|
|||||||
from .sidecar_factory import SidecarStepFactory, SidecarTiming # noqa: F401
|
from .sidecar_factory import SidecarStepFactory, SidecarTiming # noqa: F401
|
||||||
|
|
||||||
# Step implementations — importing triggers @register self-registration.
|
# Step implementations — importing triggers @register self-registration.
|
||||||
from .statement_step import StatementExtractionStep # noqa: F401
|
from .statement_temporal_step import StatementTemporalExtractionStep # noqa: F401
|
||||||
from .triplet_step import TripletExtractionStep # noqa: F401
|
from .triplet_step import TripletExtractionStep # noqa: F401
|
||||||
from .emotion_step import EmotionExtractionStep # noqa: F401
|
from .emotion_step import EmotionExtractionStep # noqa: F401
|
||||||
from .embedding_step import EmbeddingStep # noqa: F401
|
from .embedding_step import EmbeddingStep # noqa: F401
|
||||||
|
|
||||||
# Refactored orchestrator
|
# Refactored orchestrator
|
||||||
from .extraction_pipeline_orchestrator import NewExtractionOrchestrator # noqa: F401
|
from app.core.memory.storage_services.extraction_engine.extraction_pipeline_orchestrator import NewExtractionOrchestrator # noqa: F401
|
||||||
|
|||||||
@@ -28,7 +28,7 @@ class SupportingContext(BaseModel):
|
|||||||
|
|
||||||
# ── Statement extraction ──
|
# ── Statement extraction ──
|
||||||
class StatementStepInput(BaseModel):
|
class StatementStepInput(BaseModel):
|
||||||
"""Input for StatementExtractionStep."""
|
"""Input for StatementTemporalExtractionStep."""
|
||||||
|
|
||||||
chunk_id: str
|
chunk_id: str
|
||||||
end_user_id: str
|
end_user_id: str
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
"""StatementExtractionStep — critical step for extracting statements from chunks.
|
"""StatementTemporalExtractionStep — critical step for extracting statements and temporal info from chunks.
|
||||||
|
|
||||||
Replaces the legacy ``StatementExtractor`` with the unified ExtractionStep paradigm.
|
Replaces the legacy ``StatementExtractor`` with the unified ExtractionStep paradigm.
|
||||||
Temporal extraction logic (valid_at / invalid_at) is merged into this step,
|
Temporal extraction logic (valid_at / invalid_at) is merged into this step,
|
||||||
@@ -62,8 +62,8 @@ class _StatementExtractionResponse(BaseModel):
|
|||||||
return v
|
return v
|
||||||
|
|
||||||
|
|
||||||
class StatementExtractionStep(ExtractionStep[StatementStepInput, List[StatementStepOutput]]):
|
class StatementTemporalExtractionStep(ExtractionStep[StatementStepInput, List[StatementStepOutput]]):
|
||||||
"""Extract atomic statements (with temporal info) from a dialogue chunk.
|
"""Extract atomic statements with temporal info (valid_at / invalid_at) from a dialogue chunk.
|
||||||
|
|
||||||
This is a **critical** step — failure aborts the pipeline after retries.
|
This is a **critical** step — failure aborts the pipeline after retries.
|
||||||
|
|
||||||
@@ -65,7 +65,7 @@ async def render_statement_extraction_prompt(
|
|||||||
Returns:
|
Returns:
|
||||||
Rendered prompt content as string
|
Rendered prompt content as string
|
||||||
"""
|
"""
|
||||||
template = prompt_env.get_template("extract_statement.jinja2")
|
template = prompt_env.get_template("extract_statement_temporal.jinja2")
|
||||||
# Optional clipping of dialogue context
|
# Optional clipping of dialogue context
|
||||||
ctx = None
|
ctx = None
|
||||||
if include_dialogue_context and dialogue_content:
|
if include_dialogue_context and dialogue_content:
|
||||||
@@ -90,7 +90,7 @@ async def render_statement_extraction_prompt(
|
|||||||
# 记录渲染结果到提示日志(与示例日志结构一致)
|
# 记录渲染结果到提示日志(与示例日志结构一致)
|
||||||
log_prompt_rendering('statement extraction', rendered_prompt)
|
log_prompt_rendering('statement extraction', rendered_prompt)
|
||||||
# 可选:记录模板渲染信息
|
# 可选:记录模板渲染信息
|
||||||
log_template_rendering('extract_statement.jinja2', {
|
log_template_rendering('extract_statement_temporal.jinja2', {
|
||||||
'inputs': 'chunk',
|
'inputs': 'chunk',
|
||||||
'definitions': 'LABEL_DEFINITIONS',
|
'definitions': 'LABEL_DEFINITIONS',
|
||||||
'json_schema': 'StatementExtractionResponse.schema',
|
'json_schema': 'StatementExtractionResponse.schema',
|
||||||
|
|||||||
@@ -1,126 +0,0 @@
|
|||||||
|
|
||||||
{% macro tidy(name) -%}
|
|
||||||
{{ name.replace('_', ' ')}}
|
|
||||||
{%- endmacro %}
|
|
||||||
{#
|
|
||||||
This prompt (template) is adapted from [getzep/graphiti]
|
|
||||||
Licensed under the Apache License, Version 2.0
|
|
||||||
|
|
||||||
Original work:
|
|
||||||
https://github.com/getzep/graphiti/blob/main/graphiti_core/prompts/extract_edge_dates.py
|
|
||||||
|
|
||||||
Modifications made by Ke Sun on 2025-09-01
|
|
||||||
See the LICENSE file for the full Apache 2.0 license text.
|
|
||||||
#}
|
|
||||||
# Task
|
|
||||||
|
|
||||||
{% if language == "zh" %}
|
|
||||||
从提供的陈述句中提取时间信息(日期和时间范围)。确定所描述的关系或事件何时生效以及何时结束(如果适用)。
|
|
||||||
{% else %}
|
|
||||||
Extract temporal information (dates and time ranges) from the provided statement. Determine when the relationship or event described became valid and when it ended (if applicable).
|
|
||||||
{% endif %}
|
|
||||||
|
|
||||||
# {% if language == "zh" %}输入数据{% else %}Input Data{% endif %}
|
|
||||||
{% if inputs %}
|
|
||||||
{% for key, val in inputs.items() %}
|
|
||||||
- {{ key }}: {{val}}
|
|
||||||
{% endfor %}
|
|
||||||
{% endif %}
|
|
||||||
|
|
||||||
# {% if language == "zh" %}时间字段{% else %}Temporal Fields{% endif %}
|
|
||||||
|
|
||||||
{% if language == "zh" %}
|
|
||||||
- **valid_at**: 关系/事件开始或成为真实的时间(ISO 8601 格式)
|
|
||||||
- **invalid_at**: 关系/事件结束或停止为真的时间(ISO 8601 格式,如果正在进行则为 null)
|
|
||||||
{% else %}
|
|
||||||
- **valid_at**: When the relationship/event started or became true (ISO 8601 format)
|
|
||||||
- **invalid_at**: When the relationship/event ended or stopped being true (ISO 8601 format, or null if ongoing)
|
|
||||||
{% endif %}
|
|
||||||
|
|
||||||
# {% if language == "zh" %}提取规则{% else %}Extraction Rules{% endif %}
|
|
||||||
|
|
||||||
## {% if language == "zh" %}核心原则{% else %}Core Principles{% endif %}
|
|
||||||
{% if language == "zh" %}
|
|
||||||
1. **仅使用明确陈述的时间信息** - 不要从外部知识推断日期
|
|
||||||
2. **使用参考/发布日期作为"现在"** 解释相对时间时
|
|
||||||
3. **仅在日期与关系的有效性相关时设置日期** - 忽略偶然的时间提及
|
|
||||||
4. **对于时间点事件**,仅设置 `valid_at`
|
|
||||||
{% else %}
|
|
||||||
1. **Only use explicitly stated temporal information** - do not infer dates from external knowledge
|
|
||||||
2. **Use the reference/publication date as "now"** when interpreting relative times
|
|
||||||
3. **Set dates only if they relate to the validity of the relationship** - ignore incidental time mentions
|
|
||||||
4. **For point-in-time events**, set only `valid_at`
|
|
||||||
{% endif %}
|
|
||||||
|
|
||||||
## {% if language == "zh" %}日期格式要求{% else %}Date Format Requirements{% endif %}
|
|
||||||
{% if language == "zh" %}
|
|
||||||
- 使用 ISO 8601: `YYYY-MM-DDTHH:MM:SS.SSSSSSZ`
|
|
||||||
- 如果未指定时间,使用 `00:00:00`(午夜)
|
|
||||||
- 如果仅提及年份,根据情况使用 `YYYY-01-01`(开始)或 `YYYY-12-31`(结束)
|
|
||||||
- 如果仅提及月份,使用月份的第一天或最后一天
|
|
||||||
- 始终包含时区(如果未指定,使用 `Z` 表示 UTC)
|
|
||||||
- 根据参考日期将相对时间("两周前"、"去年")转换为绝对日期
|
|
||||||
{% else %}
|
|
||||||
- Use ISO 8601: `YYYY-MM-DDTHH:MM:SS.SSSSSSZ`
|
|
||||||
- If no time specified, use `00:00:00` (midnight)
|
|
||||||
- If only year mentioned, use `YYYY-01-01` (start) or `YYYY-12-31` (end) as appropriate
|
|
||||||
- If only month mentioned, use first or last day of month
|
|
||||||
- Always include timezone (use `Z` for UTC if unspecified)
|
|
||||||
- Convert relative times ("two weeks ago", "last year") to absolute dates based on reference date
|
|
||||||
{% endif %}
|
|
||||||
|
|
||||||
## {% if language == "zh" %}陈述句类型规则{% else %}Statement Type Rules{% endif %}
|
|
||||||
|
|
||||||
{{ inputs.get("statement_type") | upper }} {% if language == "zh" %}陈述句指导{% else %}Statement Guidance{% endif %}:
|
|
||||||
{%for key, guide in statement_guide.items() %}
|
|
||||||
- {{ tidy(key) | capitalize }}: {{ guide }}
|
|
||||||
{% endfor %}
|
|
||||||
|
|
||||||
**{% if language == "zh" %}特殊情况{% else %}Special Cases{% endif %}:**
|
|
||||||
{% if language == "zh" %}
|
|
||||||
- **意见陈述句**: 仅设置 `valid_at`(意见表达的时间)
|
|
||||||
- **预测陈述句**: 如果明确提及,将 `invalid_at` 设置为预测窗口的结束
|
|
||||||
{% else %}
|
|
||||||
- **Opinion statements**: Set only `valid_at` (when opinion was expressed)
|
|
||||||
- **Prediction statements**: Set `invalid_at` to the end of the prediction window if explicitly mentioned
|
|
||||||
{% endif %}
|
|
||||||
|
|
||||||
## {% if language == "zh" %}时间类型规则{% else %}Temporal Type Rules{% endif %}
|
|
||||||
|
|
||||||
{{ inputs.get("temporal_type") | upper }} {% if language == "zh" %}时间类型指导{% else %}Temporal Type Guidance{% endif %}:
|
|
||||||
{% for key, guide in temporal_guide.items() %}
|
|
||||||
- {{ tidy(key) | capitalize }}: {{ guide }}
|
|
||||||
{% endfor %}
|
|
||||||
|
|
||||||
{% if inputs.get('quarter') and inputs.get('publication_date') %}
|
|
||||||
## {% if language == "zh" %}季度参考{% else %}Quarter Reference{% endif %}
|
|
||||||
{% if language == "zh" %}
|
|
||||||
假设 {{ inputs.quarter }} 在 {{ inputs.publication_date }} 结束。从此基线计算任何季度引用(Q1、Q2 等)的日期。
|
|
||||||
{% else %}
|
|
||||||
Assume {{ inputs.quarter }} ends on {{ inputs.publication_date }}. Calculate dates for any quarter references (Q1, Q2, etc.) from this baseline.
|
|
||||||
{% endif %}
|
|
||||||
{% endif %}
|
|
||||||
|
|
||||||
# {% if language == "zh" %}输出要求{% else %}Output Requirements{% endif %}
|
|
||||||
|
|
||||||
## {% if language == "zh" %}JSON 格式化(关键){% else %}JSON Formatting (CRITICAL){% endif %}
|
|
||||||
{% if language == "zh" %}
|
|
||||||
1. 使用**仅标准 ASCII 双引号** (") - 永远不要使用中文引号("")或其他 Unicode 变体
|
|
||||||
2. 使用反斜杠转义内部引号: `\"`
|
|
||||||
3. JSON 字符串值中不要有换行符
|
|
||||||
4. 正确关闭并用逗号分隔所有字段
|
|
||||||
{% else %}
|
|
||||||
1. Use **only standard ASCII double quotes** (") - never use Chinese quotes ("") or other Unicode variants
|
|
||||||
2. Escape internal quotes with backslash: `\"`
|
|
||||||
3. No line breaks within JSON string values
|
|
||||||
4. Properly close and comma-separate all fields
|
|
||||||
{% endif %}
|
|
||||||
|
|
||||||
## {% if language == "zh" %}语言{% else %}Language{% endif %}
|
|
||||||
{% if language == "zh" %}
|
|
||||||
输出语言必须与输入语言匹配。
|
|
||||||
{% else %}
|
|
||||||
Output language must match input language.
|
|
||||||
{% endif %}
|
|
||||||
|
|
||||||
{{ json_schema }}
|
|
||||||
@@ -250,28 +250,12 @@ Primary statement to analyze:
|
|||||||
- notes: 只处理名字性表达,不处理角色、职业、评价词。
|
- notes: 只处理名字性表达,不处理角色、职业、评价词。
|
||||||
- status: `enabled`
|
- status: `enabled`
|
||||||
|
|
||||||
- `类型归属关系`
|
- `归属身份关系`
|
||||||
- definition: 表达实体属于某种类别,或主体承担某种角色/职业身份的关系。
|
- definition: 表达主体所属的类别、身份、职业、角色,或其与组织、群体、集合之间的归属关系。
|
||||||
- covered_predicates: `属于类型`、`担任角色`、`从事职业`
|
- covered_predicates: `属于类型`、`担任角色`、`从事职业`、`成员属于`、`任职于`
|
||||||
- positive_examples: `王教授 -> 担任角色 -> 导师`、`张三 -> 从事职业 -> 程序员`
|
- positive_examples: `王教授 -> 担任角色 -> 导师`、`张三 -> 从事职业 -> 程序员`、`张三 -> 成员属于 -> 实验室成员`、`张明 -> 任职于 -> 腾讯`
|
||||||
- negative_examples: `张三 -> 担任角色 -> 山哥`、`用户 -> 从事职业 -> 紧张`
|
- negative_examples: `张三 -> 担任角色 -> 山哥`、`他们 -> 成员属于 -> 学校`、`用户 -> 任职于 -> 明天的面试`、`用户 -> 从事职业 -> 紧张`
|
||||||
- notes: 用于“是什么”,不用于“叫什么”。
|
- notes: 这是一个上位父类,用于统一承接“是什么身份”与“归属哪里”两类关系。第一层不再强行区分“身份类归属”和“组织类归属”,真正的区分在子类 predicate 层完成。
|
||||||
- status: `enabled`
|
|
||||||
|
|
||||||
- `成员隶属关系`
|
|
||||||
- definition: 表达主体属于某个组织、群体或集合的成员归属关系。
|
|
||||||
- covered_predicates: `成员属于`
|
|
||||||
- positive_examples: `张三 -> 成员属于 -> 实验室成员`、`用户 -> 成员属于 -> 社群`
|
|
||||||
- negative_examples: `他们 -> 成员属于 -> 学校`、`一个朋友 -> 成员属于 -> 班级`
|
|
||||||
- notes: 前提是主体和归属对象都足够稳定;边界不稳的人群不要硬抽。
|
|
||||||
- status: `enabled`
|
|
||||||
|
|
||||||
- `任职服务关系`
|
|
||||||
- definition: 表达人物或主体在组织中的工作、任职或服务关系。
|
|
||||||
- covered_predicates: `任职于`
|
|
||||||
- positive_examples: `张明 -> 任职于 -> 腾讯`、`王教授 -> 任职于 -> 清华大学`
|
|
||||||
- negative_examples: `张明 -> 任职于 -> 导师`、`用户 -> 任职于 -> 明天的面试`
|
|
||||||
- notes: 优先用于人物到组织的稳定供职关系。
|
|
||||||
- status: `enabled`
|
- status: `enabled`
|
||||||
|
|
||||||
- `空间位置关系`
|
- `空间位置关系`
|
||||||
|
|||||||
@@ -1107,6 +1107,35 @@ RETURN (
|
|||||||
) AS is_complete
|
) AS is_complete
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# 别名归并:将 predicate="别名属于" 的 EXTRACTED_RELATIONSHIP 边的 source.name
|
||||||
|
# 合并进 target.aliases(去重),并将 source.description 追加到 target.description(分号分隔)
|
||||||
|
MERGE_ALIAS_BELONGS_TO = """
|
||||||
|
MATCH (source:ExtractedEntity {end_user_id: $end_user_id})-[r:EXTRACTED_RELATIONSHIP]->(target:ExtractedEntity {end_user_id: $end_user_id})
|
||||||
|
WHERE r.predicate = '别名属于'
|
||||||
|
WITH source, target,
|
||||||
|
coalesce(target.aliases, []) AS existing_aliases,
|
||||||
|
source.name AS source_name,
|
||||||
|
coalesce(source.description, '') AS src_desc,
|
||||||
|
coalesce(target.description, '') AS tgt_desc
|
||||||
|
|
||||||
|
// 1. 合并 aliases:将 source.name 追加到 target.aliases(去重)
|
||||||
|
WITH source, target, src_desc, tgt_desc,
|
||||||
|
CASE
|
||||||
|
WHEN source_name IS NOT NULL AND source_name <> '' AND NOT source_name IN existing_aliases
|
||||||
|
THEN existing_aliases + source_name
|
||||||
|
ELSE existing_aliases
|
||||||
|
END AS new_aliases
|
||||||
|
|
||||||
|
SET target.aliases = new_aliases,
|
||||||
|
target.description = CASE
|
||||||
|
WHEN src_desc <> '' AND NOT src_desc IN tgt_desc
|
||||||
|
THEN CASE WHEN tgt_desc = '' THEN src_desc ELSE tgt_desc + ';' + src_desc END
|
||||||
|
ELSE tgt_desc
|
||||||
|
END
|
||||||
|
|
||||||
|
RETURN source.name AS merged_alias, target.name AS target_name, new_aliases AS updated_aliases
|
||||||
|
"""
|
||||||
|
|
||||||
CHECK_COMMUNITY_IS_COMPLETE_WITH_EMBEDDING = """
|
CHECK_COMMUNITY_IS_COMPLETE_WITH_EMBEDDING = """
|
||||||
MATCH (c:Community {community_id: $community_id, end_user_id: $end_user_id})
|
MATCH (c:Community {community_id: $community_id, end_user_id: $end_user_id})
|
||||||
RETURN (
|
RETURN (
|
||||||
|
|||||||
@@ -1,9 +1,32 @@
|
|||||||
from abc import ABC
|
from abc import ABC
|
||||||
from typing import Optional
|
from enum import Enum
|
||||||
|
from typing import Any, Optional
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
|
class StorageType(str, Enum):
|
||||||
|
"""记忆存储后端类型"""
|
||||||
|
NEO4J = "neo4j"
|
||||||
|
RAG = "rag"
|
||||||
|
|
||||||
|
|
||||||
|
class Language(str, Enum): # 没有传递到聚类的celery任务中去,任务会回退失败用默认值,考虑统一语言问题
|
||||||
|
"""支持的语言"""
|
||||||
|
ZH = "zh"
|
||||||
|
EN = "en"
|
||||||
|
|
||||||
|
|
||||||
|
class MessageItem(BaseModel):
|
||||||
|
"""单条消息结构"""
|
||||||
|
role: str
|
||||||
|
content: str
|
||||||
|
files: Optional[list[dict]] = None
|
||||||
|
file_content: Optional[list[Any]] = None
|
||||||
|
|
||||||
|
model_config = {"extra": "allow"}
|
||||||
|
|
||||||
|
|
||||||
class UserInput(BaseModel):
|
class UserInput(BaseModel):
|
||||||
message: str
|
message: str
|
||||||
history: list[dict]
|
history: list[dict]
|
||||||
@@ -18,6 +41,16 @@ class Write_UserInput(BaseModel):
|
|||||||
config_id: Optional[str] = None
|
config_id: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
|
class WriteMemoryRequest(BaseModel):
|
||||||
|
"""write_memory() 的参数封装"""
|
||||||
|
end_user_id: str
|
||||||
|
messages: list[MessageItem]
|
||||||
|
config_id: Optional[Any] = None
|
||||||
|
storage_type: StorageType = StorageType.NEO4J
|
||||||
|
user_rag_memory_id: str = ""
|
||||||
|
language: Language = Language.ZH
|
||||||
|
|
||||||
|
|
||||||
class AgentMemory_Long_Term(ABC):
|
class AgentMemory_Long_Term(ABC):
|
||||||
"""长期记忆配置常量"""
|
"""长期记忆配置常量"""
|
||||||
STORAGE_NEO4J = "neo4j"
|
STORAGE_NEO4J = "neo4j"
|
||||||
@@ -25,7 +58,7 @@ class AgentMemory_Long_Term(ABC):
|
|||||||
STRATEGY_AGGREGATE = "aggregate"
|
STRATEGY_AGGREGATE = "aggregate"
|
||||||
STRATEGY_CHUNK = "chunk"
|
STRATEGY_CHUNK = "chunk"
|
||||||
STRATEGY_TIME = "time"
|
STRATEGY_TIME = "time"
|
||||||
DEFAULT_SCOPE = 6
|
DEFAULT_SCOPE = 1
|
||||||
TIME_SCOPE = 5
|
TIME_SCOPE = 5
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -36,13 +36,14 @@ from app.core.memory.agent.utils.messages_tools import (
|
|||||||
from app.core.memory.agent.utils.type_classifier import status_typle
|
from app.core.memory.agent.utils.type_classifier import status_typle
|
||||||
from app.core.memory.agent.utils.write_tools import write as write_neo4j
|
from app.core.memory.agent.utils.write_tools import write as write_neo4j
|
||||||
from app.core.memory.analytics.hot_memory_tags import get_interest_distribution
|
from app.core.memory.analytics.hot_memory_tags import get_interest_distribution
|
||||||
|
from app.core.memory.memory_service import MemoryService
|
||||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||||
from app.core.memory.utils.log.audit_logger import audit_logger
|
from app.core.memory.utils.log.audit_logger import audit_logger
|
||||||
from app.db import get_db_context
|
from app.db import get_db_context
|
||||||
from app.models.knowledge_model import Knowledge, KnowledgeType
|
from app.models.knowledge_model import Knowledge, KnowledgeType
|
||||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||||
from app.schemas import FileInput
|
from app.schemas import FileInput
|
||||||
from app.schemas.memory_agent_schema import Write_UserInput
|
from app.schemas.memory_agent_schema import Language, MessageItem, StorageType, Write_UserInput, WriteMemoryRequest
|
||||||
from app.schemas.memory_config_schema import ConfigurationError
|
from app.schemas.memory_config_schema import ConfigurationError
|
||||||
from app.services.memory_config_service import MemoryConfigService
|
from app.services.memory_config_service import MemoryConfigService
|
||||||
from app.services.memory_konwledges_server import (
|
from app.services.memory_konwledges_server import (
|
||||||
@@ -267,25 +268,15 @@ class MemoryAgentService:
|
|||||||
|
|
||||||
async def write_memory(
|
async def write_memory(
|
||||||
self,
|
self,
|
||||||
end_user_id: str,
|
request: WriteMemoryRequest,
|
||||||
messages: list[dict],
|
|
||||||
config_id: Optional[uuid.UUID] | int,
|
|
||||||
db: Session,
|
db: Session,
|
||||||
storage_type: str,
|
|
||||||
user_rag_memory_id: str,
|
|
||||||
language: str = "zh"
|
|
||||||
) -> str:
|
) -> str:
|
||||||
"""
|
"""
|
||||||
Process write operation with config_id
|
长期记忆写入
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
end_user_id: Group identifier (also used as end_user_id)
|
request: 写入请求参数(end_user_id、messages、config_id、storage_type、language 等)
|
||||||
messages: Message to write
|
|
||||||
config_id: Configuration ID from database
|
|
||||||
db: SQLAlchemy database session
|
db: SQLAlchemy database session
|
||||||
storage_type: Storage type (neo4j or rag)
|
|
||||||
user_rag_memory_id: User RAG memory ID
|
|
||||||
language: 语言类型 ("zh" 中文, "en" 英文)
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Write operation result status
|
Write operation result status
|
||||||
@@ -293,147 +284,50 @@ class MemoryAgentService:
|
|||||||
Raises:
|
Raises:
|
||||||
ValueError: If config loading fails or write operation fails
|
ValueError: If config loading fails or write operation fails
|
||||||
"""
|
"""
|
||||||
# Resolve config_id and workspace_id
|
end_user_id = request.end_user_id
|
||||||
# Always get workspace_id from end_user for fallback, even if config_id is provided
|
messages = request.messages
|
||||||
workspace_id = None
|
config_id = request.config_id
|
||||||
try:
|
storage_type = request.storage_type
|
||||||
connected_config = get_end_user_connected_config(end_user_id, db)
|
user_rag_memory_id = request.user_rag_memory_id
|
||||||
workspace_id = connected_config.get("workspace_id")
|
language = request.language
|
||||||
if config_id is None:
|
|
||||||
config_id = connected_config.get("memory_config_id")
|
|
||||||
logger.info(f"Resolved config from end_user: config_id={config_id}, workspace_id={workspace_id}")
|
|
||||||
if config_id is None and workspace_id is None:
|
|
||||||
raise ValueError(f"No memory configuration found for end_user {end_user_id}. "
|
|
||||||
f"Please ensure the user has a connected memory configuration.")
|
|
||||||
except Exception as e:
|
|
||||||
if "No memory configuration found" in str(e):
|
|
||||||
raise # Re-raise our specific error
|
|
||||||
logger.error(f"Failed to get connected config for end_user {end_user_id}: {e}")
|
|
||||||
if config_id is None:
|
|
||||||
raise ValueError(f"Unable to determine memory configuration for end_user {end_user_id}: {e}")
|
|
||||||
# If config_id was provided, continue without workspace_id fallback
|
|
||||||
|
|
||||||
import time
|
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
|
||||||
# Load configuration from database with workspace fallback
|
# ── Step 1: 解析配置 ── 通过 end_user_id 查找关联的 config_id / workspace_id,并从数据库加载完整 memory_config
|
||||||
# Use a separate database session to avoid transaction failures
|
memory_config = await self._resolve_and_load_config(
|
||||||
|
end_user_id, config_id, db, start_time
|
||||||
|
)
|
||||||
|
|
||||||
|
# ── Step 2: 文件预处理 ── 将消息中附带的文件转换为感知记忆对象,挂载到 message["file_content"]
|
||||||
|
messages = await self._preprocess_files(messages, end_user_id, memory_config, db)
|
||||||
|
message_text = "\n".join([
|
||||||
|
f"{(msg['role'] if isinstance(msg, dict) else msg.role)}: {(msg['content'] if isinstance(msg, dict) else msg.content)}"
|
||||||
|
for msg in messages
|
||||||
|
])
|
||||||
|
|
||||||
|
# ── Step 3: 写入存储 ── 根据 storage_type 分流到 RAG 或 Neo4j 流水线
|
||||||
try:
|
try:
|
||||||
from app.db import get_db_context
|
if storage_type == StorageType.RAG:
|
||||||
with get_db_context() as config_db:
|
|
||||||
config_service = MemoryConfigService(config_db)
|
|
||||||
memory_config = config_service.load_memory_config(
|
|
||||||
config_id=config_id,
|
|
||||||
workspace_id=workspace_id,
|
|
||||||
service_name="MemoryAgentService"
|
|
||||||
)
|
|
||||||
logger.info(f"Configuration loaded successfully: {memory_config.config_name}")
|
|
||||||
except ConfigurationError as e:
|
|
||||||
error_msg = f"Failed to load configuration for config_id: {config_id}: {e}"
|
|
||||||
logger.error(error_msg)
|
|
||||||
|
|
||||||
# Log failed operation
|
|
||||||
duration = time.time() - start_time
|
|
||||||
audit_logger.log_operation(operation="WRITE", config_id=config_id, end_user_id=end_user_id,
|
|
||||||
success=False, duration=duration, error=error_msg)
|
|
||||||
|
|
||||||
raise ValueError(error_msg)
|
|
||||||
|
|
||||||
perceptual_serivce = MemoryPerceptualService(db)
|
|
||||||
for message in messages:
|
|
||||||
message["file_content"] = []
|
|
||||||
for file in (message.get("files") or []):
|
|
||||||
file_object = await perceptual_serivce.generate_perceptual_memory(
|
|
||||||
end_user_id=end_user_id,
|
|
||||||
memory_config=memory_config,
|
|
||||||
file=FileInput(**file)
|
|
||||||
)
|
|
||||||
if file_object is None:
|
|
||||||
continue
|
|
||||||
message["file_content"].append((file_object, file["type"]))
|
|
||||||
logger.info(messages)
|
|
||||||
|
|
||||||
message_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages])
|
|
||||||
try:
|
|
||||||
if storage_type == "rag":
|
|
||||||
# For RAG storage, convert messages to single string
|
|
||||||
await write_rag(end_user_id, message_text, user_rag_memory_id)
|
await write_rag(end_user_id, message_text, user_rag_memory_id)
|
||||||
return "success"
|
return "success"
|
||||||
else:
|
else:
|
||||||
# TODO 乐力齐 重构流水线切换至生产环境后,更改如下代码
|
await self._write_neo4j(end_user_id, messages, memory_config, language)
|
||||||
import os
|
|
||||||
use_new_pipeline = os.getenv("NEW_PIPELINE_ENABLED", "false").lower() == "true"
|
|
||||||
|
|
||||||
if use_new_pipeline:
|
# ── Step 4: 后处理 ── 失效缓存、序列化文件路径、记录审计日志并返回结果
|
||||||
# ── 新流水线:WritePipeline + NewExtractionOrchestrator ──
|
await self._invalidate_interest_cache(end_user_id)
|
||||||
from app.core.memory.memory_service import MemoryService
|
|
||||||
|
|
||||||
service = MemoryService(
|
|
||||||
memory_config=memory_config,
|
|
||||||
end_user_id=end_user_id,
|
|
||||||
)
|
|
||||||
result = await service.write(
|
|
||||||
messages=messages,
|
|
||||||
language=language,
|
|
||||||
ref_id='',
|
|
||||||
is_pilot_run=False,
|
|
||||||
)
|
|
||||||
logger.info(
|
|
||||||
f"[NewPipeline] 完成: status={result.status}, "
|
|
||||||
f"elapsed={result.elapsed_seconds:.2f}s, "
|
|
||||||
f"extraction={result.extraction}"
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
# ── 旧流水线:write_tools.write() + ExtractionOrchestrator ──
|
|
||||||
await write_neo4j(
|
|
||||||
end_user_id=end_user_id,
|
|
||||||
messages=messages,
|
|
||||||
memory_config=memory_config,
|
|
||||||
ref_id='',
|
|
||||||
language=language
|
|
||||||
)
|
|
||||||
|
|
||||||
# ── 影子运行:新流水线静默执行,只记录日志不影响主流程 ──
|
|
||||||
if os.getenv("SHADOW_PIPELINE_ENABLED", "false").lower() == "true":
|
|
||||||
try:
|
|
||||||
from app.core.memory.memory_service import MemoryService
|
|
||||||
import copy
|
|
||||||
|
|
||||||
shadow_messages = copy.deepcopy(messages)
|
|
||||||
shadow_service = MemoryService(
|
|
||||||
memory_config=memory_config,
|
|
||||||
end_user_id=end_user_id,
|
|
||||||
)
|
|
||||||
shadow_result = await shadow_service.write(
|
|
||||||
messages=shadow_messages,
|
|
||||||
language=language,
|
|
||||||
ref_id='',
|
|
||||||
is_pilot_run=True,
|
|
||||||
)
|
|
||||||
logger.info(
|
|
||||||
f"[Shadow] 新流水线影子运行完成: status={shadow_result.status}, "
|
|
||||||
f"elapsed={shadow_result.elapsed_seconds:.2f}s, "
|
|
||||||
f"extraction={shadow_result.extraction}"
|
|
||||||
)
|
|
||||||
except Exception as shadow_err:
|
|
||||||
logger.warning(f"[Shadow] 新流水线影子运行失败(不影响主流程): {shadow_err}")
|
|
||||||
# ── 影子运行结束 ──
|
|
||||||
for lang in ["zh", "en"]:
|
|
||||||
deleted = await InterestMemoryCache.delete_interest_distribution(
|
|
||||||
end_user_id, lang
|
|
||||||
)
|
|
||||||
if deleted:
|
|
||||||
logger.info(
|
|
||||||
f"Invalidated interest distribution cache: end_user_id={end_user_id}, language={lang}")
|
|
||||||
for message in messages:
|
for message in messages:
|
||||||
message["file_content"] = [
|
if isinstance(message, dict):
|
||||||
perceptual[0].file_path for perceptual in message["file_content"]
|
message["file_content"] = [
|
||||||
]
|
perceptual[0].file_path for perceptual in (message["file_content"] or [])
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
message.file_content = [
|
||||||
|
perceptual[0].file_path for perceptual in (message.file_content or [])
|
||||||
|
]
|
||||||
return self.writer_messages_deal(
|
return self.writer_messages_deal(
|
||||||
"success",
|
"success",
|
||||||
start_time,
|
start_time,
|
||||||
end_user_id,
|
end_user_id,
|
||||||
config_id,
|
memory_config.config_id,
|
||||||
message_text,
|
message_text,
|
||||||
{
|
{
|
||||||
"status": "success",
|
"status": "success",
|
||||||
@@ -443,15 +337,139 @@ class MemoryAgentService:
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# Ensure proper error handling and logging
|
|
||||||
error_msg = f"Write operation failed: {str(e)}"
|
error_msg = f"Write operation failed: {str(e)}"
|
||||||
logger.error(error_msg)
|
logger.error(error_msg)
|
||||||
|
audit_logger.log_operation(
|
||||||
duration = time.time() - start_time
|
operation="WRITE",
|
||||||
audit_logger.log_operation(operation="WRITE", config_id=config_id, end_user_id=end_user_id,
|
config_id=memory_config.config_id,
|
||||||
success=False, duration=duration, error=error_msg)
|
end_user_id=end_user_id,
|
||||||
|
success=False,
|
||||||
|
duration=time.time() - start_time,
|
||||||
|
error=error_msg,
|
||||||
|
)
|
||||||
raise ValueError(error_msg)
|
raise ValueError(error_msg)
|
||||||
|
|
||||||
|
async def _resolve_and_load_config(
|
||||||
|
self,
|
||||||
|
end_user_id: str,
|
||||||
|
config_id: Optional[uuid.UUID] | int,
|
||||||
|
db: Session,
|
||||||
|
start_time: float,
|
||||||
|
):
|
||||||
|
"""解析 end_user 关联配置并从数据库加载完整 memory_config。"""
|
||||||
|
workspace_id = None
|
||||||
|
try:
|
||||||
|
connected_config = get_end_user_connected_config(end_user_id, db)
|
||||||
|
workspace_id = connected_config.get("workspace_id")
|
||||||
|
if config_id is None:
|
||||||
|
config_id = connected_config.get("memory_config_id")
|
||||||
|
logger.info(f"Resolved config from end_user: config_id={config_id}, workspace_id={workspace_id}")
|
||||||
|
if config_id is None and workspace_id is None:
|
||||||
|
raise ValueError(
|
||||||
|
f"No memory configuration found for end_user {end_user_id}. "
|
||||||
|
f"Please ensure the user has a connected memory configuration."
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
if "No memory configuration found" in str(e):
|
||||||
|
raise
|
||||||
|
logger.error(f"Failed to get connected config for end_user {end_user_id}: {e}")
|
||||||
|
if config_id is None:
|
||||||
|
raise ValueError(f"Unable to determine memory configuration for end_user {end_user_id}: {e}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
with get_db_context() as config_db:
|
||||||
|
memory_config = MemoryConfigService(config_db).load_memory_config(
|
||||||
|
config_id=config_id,
|
||||||
|
workspace_id=workspace_id,
|
||||||
|
service_name="MemoryAgentService",
|
||||||
|
)
|
||||||
|
logger.info(f"Configuration loaded successfully: {memory_config.config_name}")
|
||||||
|
return memory_config
|
||||||
|
except ConfigurationError as e:
|
||||||
|
error_msg = f"Failed to load configuration for config_id: {config_id}: {e}"
|
||||||
|
logger.error(error_msg)
|
||||||
|
audit_logger.log_operation(
|
||||||
|
operation="WRITE",
|
||||||
|
config_id=config_id,
|
||||||
|
end_user_id=end_user_id,
|
||||||
|
success=False,
|
||||||
|
duration=time.time() - start_time,
|
||||||
|
error=error_msg,
|
||||||
|
)
|
||||||
|
raise ValueError(error_msg)
|
||||||
|
|
||||||
|
async def _preprocess_files(
|
||||||
|
self,
|
||||||
|
messages: list[MessageItem] | list[dict],
|
||||||
|
end_user_id: str,
|
||||||
|
memory_config,
|
||||||
|
db: Session,
|
||||||
|
) -> list[dict]:
|
||||||
|
"""处理消息中附带的文件,生成感知记忆对象并挂载到 message['file_content']。"""
|
||||||
|
perceptual_service = MemoryPerceptualService(db)
|
||||||
|
for message in messages:
|
||||||
|
if isinstance(message, dict):
|
||||||
|
message["file_content"] = []
|
||||||
|
files = message.get("files") or []
|
||||||
|
else:
|
||||||
|
message.file_content = []
|
||||||
|
files = message.files or []
|
||||||
|
for file in files:
|
||||||
|
file_object = await perceptual_service.generate_perceptual_memory(
|
||||||
|
end_user_id=end_user_id,
|
||||||
|
memory_config=memory_config,
|
||||||
|
file=FileInput(**file),
|
||||||
|
)
|
||||||
|
if file_object is None:
|
||||||
|
continue
|
||||||
|
if isinstance(message, dict):
|
||||||
|
message["file_content"].append((file_object, file["type"]))
|
||||||
|
else:
|
||||||
|
message.file_content.append((file_object, file["type"]))
|
||||||
|
logger.info(messages)
|
||||||
|
return messages
|
||||||
|
|
||||||
|
async def _write_neo4j(
|
||||||
|
self,
|
||||||
|
end_user_id: str,
|
||||||
|
messages: list[MessageItem] | list[dict],
|
||||||
|
memory_config,
|
||||||
|
language: Language | str,
|
||||||
|
) -> None:
|
||||||
|
"""根据 NEW_PIPELINE_ENABLED 选择新旧流水线写入 Neo4j。"""
|
||||||
|
# 统一转换为 dict,下游流水线期望 list[dict]
|
||||||
|
messages_dict = [
|
||||||
|
msg if isinstance(msg, dict) else msg.model_dump(exclude_none=True)
|
||||||
|
for msg in messages
|
||||||
|
]
|
||||||
|
use_new_pipeline = os.getenv("NEW_PIPELINE_ENABLED", "false").lower() == "true"
|
||||||
|
|
||||||
|
if use_new_pipeline:
|
||||||
|
service = MemoryService(memory_config=memory_config, end_user_id=end_user_id)
|
||||||
|
result = await service.write(messages=messages_dict, language=language, ref_id='')
|
||||||
|
logger.info(
|
||||||
|
f"[NewPipeline] 完成: status={result.status}, "
|
||||||
|
f"elapsed={result.elapsed_seconds:.2f}s, "
|
||||||
|
f"extraction={result.extraction}"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
await write_neo4j(
|
||||||
|
end_user_id=end_user_id,
|
||||||
|
messages=messages_dict,
|
||||||
|
memory_config=memory_config,
|
||||||
|
ref_id='',
|
||||||
|
language=language,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _invalidate_interest_cache(self, end_user_id: str) -> None:
|
||||||
|
"""写入完成后失效兴趣分布缓存。"""
|
||||||
|
for lang in ["zh", "en"]:
|
||||||
|
deleted = await InterestMemoryCache.delete_interest_distribution(end_user_id, lang)
|
||||||
|
if deleted:
|
||||||
|
logger.info(
|
||||||
|
f"Invalidated interest distribution cache: end_user_id={end_user_id}, language={lang}"
|
||||||
|
)
|
||||||
|
|
||||||
async def read_memory(
|
async def read_memory(
|
||||||
self,
|
self,
|
||||||
end_user_id: str,
|
end_user_id: str,
|
||||||
|
|||||||
@@ -17,6 +17,7 @@ from app.core.logging_config import get_logger
|
|||||||
from app.models.app_model import App
|
from app.models.app_model import App
|
||||||
from app.models.end_user_model import EndUser
|
from app.models.end_user_model import EndUser
|
||||||
from app.schemas.memory_config_schema import ConfigurationError
|
from app.schemas.memory_config_schema import ConfigurationError
|
||||||
|
from app.schemas.memory_agent_schema import WriteMemoryRequest
|
||||||
from app.services.memory_agent_service import MemoryAgentService
|
from app.services.memory_agent_service import MemoryAgentService
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
@@ -291,12 +292,14 @@ class MemoryAPIService:
|
|||||||
try:
|
try:
|
||||||
messages = message if isinstance(message, list) else [{"role": "user", "content": message}]
|
messages = message if isinstance(message, list) else [{"role": "user", "content": message}]
|
||||||
result = await MemoryAgentService().write_memory(
|
result = await MemoryAgentService().write_memory(
|
||||||
end_user_id=end_user_id,
|
WriteMemoryRequest(
|
||||||
messages=messages,
|
end_user_id=end_user_id,
|
||||||
config_id=config_id,
|
messages=messages,
|
||||||
db=self.db,
|
config_id=config_id,
|
||||||
storage_type=storage_type,
|
storage_type=storage_type,
|
||||||
user_rag_memory_id=user_rag_memory_id or "",
|
user_rag_memory_id=user_rag_memory_id or "",
|
||||||
|
),
|
||||||
|
self.db,
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info(f"Memory write (sync) successful for end_user: {end_user_id}")
|
logger.info(f"Memory write (sync) successful for end_user: {end_user_id}")
|
||||||
|
|||||||
@@ -39,6 +39,7 @@ from app.models import Document, File, Knowledge
|
|||||||
from app.models.end_user_model import EndUser
|
from app.models.end_user_model import EndUser
|
||||||
from app.schemas import document_schema, file_schema
|
from app.schemas import document_schema, file_schema
|
||||||
from app.services.memory_agent_service import MemoryAgentService, get_end_user_connected_config
|
from app.services.memory_agent_service import MemoryAgentService, get_end_user_connected_config
|
||||||
|
from app.schemas.memory_agent_schema import WriteMemoryRequest
|
||||||
from app.services.memory_forget_service import MemoryForgetService
|
from app.services.memory_forget_service import MemoryForgetService
|
||||||
from app.utils.config_utils import resolve_config_id
|
from app.utils.config_utils import resolve_config_id
|
||||||
from app.utils.redis_lock import RedisFairLock
|
from app.utils.redis_lock import RedisFairLock
|
||||||
@@ -1280,8 +1281,17 @@ def write_message_task(
|
|||||||
f"[CELERY WRITE] Executing MemoryAgentService.write_memory "
|
f"[CELERY WRITE] Executing MemoryAgentService.write_memory "
|
||||||
f"with config_id={actual_config_id} (type: {type(actual_config_id).__name__}), language={language}")
|
f"with config_id={actual_config_id} (type: {type(actual_config_id).__name__}), language={language}")
|
||||||
service = MemoryAgentService()
|
service = MemoryAgentService()
|
||||||
result = await service.write_memory(end_user_id, message, actual_config_id, db, storage_type,
|
result = await service.write_memory(
|
||||||
user_rag_memory_id, language)
|
WriteMemoryRequest(
|
||||||
|
end_user_id=end_user_id,
|
||||||
|
messages=message,
|
||||||
|
config_id=actual_config_id,
|
||||||
|
storage_type=storage_type,
|
||||||
|
user_rag_memory_id=user_rag_memory_id,
|
||||||
|
language=language,
|
||||||
|
),
|
||||||
|
db,
|
||||||
|
)
|
||||||
logger.info(f"[CELERY WRITE] Write completed successfully: {result}")
|
logger.info(f"[CELERY WRITE] Write completed successfully: {result}")
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user