feat(memory): add async user metadata extraction pipeline
- Add MetadataExtractor to collect user-related statements post-dedup and extract profile/behavioral metadata via independent LLM call - Add Celery task (extract_user_metadata) routed to memory_tasks queue - Add metadata models (UserMetadata, UserMetadataProfile, etc.) - Add metadata utility functions (clean, validate, merge with _op support) - Add Jinja2 prompt template for metadata extraction (zh/en) - Fix Lucene query parameter naming: rename `q` to `query` across all Cypher queries, graph_search functions, and callers - Escape `/` in Lucene queries to prevent TokenMgrError - Add `speaker` field to ChunkNode and persist it in Neo4j - Remove unused imports (argparse, os, UUID) in search.py - Fix unnecessary db context nesting in interest distribution task
This commit is contained in:
@@ -111,6 +111,9 @@ celery_app.conf.update(
|
||||
# Clustering tasks → memory_tasks queue (使用相同的 worker,避免 macOS fork 问题)
|
||||
'app.tasks.run_incremental_clustering': {'queue': 'memory_tasks'},
|
||||
|
||||
# Metadata extraction → memory_tasks queue
|
||||
'app.tasks.extract_user_metadata': {'queue': 'memory_tasks'},
|
||||
|
||||
# Document tasks → document_tasks queue (prefork worker)
|
||||
'app.core.rag.tasks.parse_document': {'queue': 'document_tasks'},
|
||||
'app.core.rag.tasks.build_graphrag_for_kb': {'queue': 'document_tasks'},
|
||||
|
||||
@@ -153,7 +153,7 @@ class PerceptualSearchService:
|
||||
return []
|
||||
try:
|
||||
r = await search_perceptual(
|
||||
connector=connector, q=escaped,
|
||||
connector=connector, query=escaped,
|
||||
end_user_id=self.end_user_id,
|
||||
limit=limit * 5, # 多查一些以提高命中率
|
||||
)
|
||||
@@ -178,7 +178,7 @@ class PerceptualSearchService:
|
||||
if not escaped.strip():
|
||||
return []
|
||||
r = await search_perceptual(
|
||||
connector=connector, q=escaped,
|
||||
connector=connector, query=escaped,
|
||||
end_user_id=self.end_user_id, limit=limit,
|
||||
)
|
||||
return r.get("perceptuals", [])
|
||||
|
||||
@@ -58,6 +58,14 @@ from app.core.memory.models.triplet_models import (
|
||||
TripletExtractionResponse,
|
||||
)
|
||||
|
||||
# User metadata models
|
||||
from app.core.memory.models.metadata_models import (
|
||||
UserMetadata,
|
||||
UserMetadataBehavioralHints,
|
||||
UserMetadataProfile,
|
||||
MetadataExtractionResponse,
|
||||
)
|
||||
|
||||
# Ontology scenario models (LLM extracted from scenarios)
|
||||
from app.core.memory.models.ontology_scenario_models import (
|
||||
OntologyClass,
|
||||
@@ -124,6 +132,10 @@ __all__ = [
|
||||
"Entity",
|
||||
"Triplet",
|
||||
"TripletExtractionResponse",
|
||||
"UserMetadata",
|
||||
"UserMetadataBehavioralHints",
|
||||
"UserMetadataProfile",
|
||||
"MetadataExtractionResponse",
|
||||
# Ontology models
|
||||
"OntologyClass",
|
||||
"OntologyExtractionResponse",
|
||||
|
||||
@@ -364,12 +364,14 @@ class ChunkNode(Node):
|
||||
Attributes:
|
||||
dialog_id: ID of the parent dialog
|
||||
content: The text content of the chunk
|
||||
speaker: Speaker identifier ('user' or 'assistant')
|
||||
chunk_embedding: Optional embedding vector for the chunk
|
||||
sequence_number: Order of this chunk within the dialog
|
||||
metadata: Additional chunk metadata as key-value pairs
|
||||
"""
|
||||
dialog_id: str = Field(..., description="ID of the parent dialog")
|
||||
content: str = Field(..., description="The text content of the chunk")
|
||||
speaker: Optional[str] = Field(None, description="Speaker identifier: 'user' for user messages, 'assistant' for AI responses")
|
||||
chunk_embedding: Optional[List[float]] = Field(None, description="Chunk embedding vector")
|
||||
sequence_number: int = Field(..., description="Order of this chunk within the dialog")
|
||||
metadata: dict = Field(default_factory=dict, description="Additional chunk metadata")
|
||||
|
||||
40
api/app/core/memory/models/metadata_models.py
Normal file
40
api/app/core/memory/models/metadata_models.py
Normal file
@@ -0,0 +1,40 @@
|
||||
"""Models for user metadata extraction.
|
||||
|
||||
Independent from triplet_models.py - these models are used by the
|
||||
standalone metadata extraction pipeline (post-dedup async Celery task).
|
||||
"""
|
||||
|
||||
from typing import List
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
|
||||
class UserMetadataProfile(BaseModel):
|
||||
"""用户画像信息"""
|
||||
model_config = ConfigDict(extra='ignore')
|
||||
role: str = Field(default="", description="用户职业或角色,如 teacher, doctor, software_engineer")
|
||||
domain: str = Field(default="", description="用户所在领域,如 education, healthcare, software_development")
|
||||
expertise: List[str] = Field(default_factory=list, description="用户擅长的技能或工具")
|
||||
interests: List[str] = Field(default_factory=list, description="用户关注的话题或领域标签")
|
||||
|
||||
|
||||
class UserMetadataBehavioralHints(BaseModel):
|
||||
"""行为偏好"""
|
||||
model_config = ConfigDict(extra='ignore')
|
||||
learning_stage: str = Field(default="", description="学习阶段")
|
||||
preferred_depth: str = Field(default="", description="偏好深度")
|
||||
tone_preference: str = Field(default="", description="语气偏好")
|
||||
|
||||
|
||||
class UserMetadata(BaseModel):
|
||||
"""用户元数据顶层结构"""
|
||||
model_config = ConfigDict(extra='ignore')
|
||||
profile: UserMetadataProfile = Field(default_factory=UserMetadataProfile)
|
||||
behavioral_hints: UserMetadataBehavioralHints = Field(default_factory=UserMetadataBehavioralHints)
|
||||
knowledge_tags: List[str] = Field(default_factory=list, description="知识标签")
|
||||
|
||||
|
||||
class MetadataExtractionResponse(BaseModel):
|
||||
"""元数据提取 LLM 响应结构"""
|
||||
model_config = ConfigDict(extra='ignore')
|
||||
user_metadata: UserMetadata = Field(default_factory=UserMetadata)
|
||||
@@ -1,4 +1,3 @@
|
||||
import argparse
|
||||
import asyncio
|
||||
import json
|
||||
import math
|
||||
@@ -6,7 +5,6 @@ import os
|
||||
import time
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional
|
||||
from uuid import UUID
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.schemas.memory_config_schema import MemoryConfig
|
||||
@@ -23,7 +21,7 @@ from app.core.memory.utils.config.config_utils import (
|
||||
)
|
||||
from app.core.memory.utils.data.text_utils import extract_plain_query
|
||||
from app.core.memory.utils.data.time_utils import normalize_date_safe
|
||||
from app.core.memory.utils.llm.llm_utils import get_reranker_client
|
||||
# from app.core.memory.utils.llm.llm_utils import get_reranker_client
|
||||
from app.core.models.base import RedBearModelConfig
|
||||
from app.db import get_db_context
|
||||
from app.repositories.neo4j.graph_search import (
|
||||
@@ -748,11 +746,10 @@ async def run_hybrid_search(
|
||||
if search_type in ["keyword", "hybrid"]:
|
||||
# Keyword-based search
|
||||
logger.info("[PERF] Starting keyword search...")
|
||||
keyword_start = time.time()
|
||||
keyword_task = asyncio.create_task(
|
||||
search_graph(
|
||||
connector=connector,
|
||||
q=query_text,
|
||||
query=query_text,
|
||||
end_user_id=end_user_id,
|
||||
limit=limit,
|
||||
include=include
|
||||
@@ -762,7 +759,6 @@ async def run_hybrid_search(
|
||||
if search_type in ["embedding", "hybrid"]:
|
||||
# Embedding-based search
|
||||
logger.info("[PERF] Starting embedding search...")
|
||||
embedding_start = time.time()
|
||||
|
||||
# 从数据库读取嵌入器配置(按 ID)并构建 RedBearModelConfig
|
||||
config_load_start = time.time()
|
||||
@@ -904,10 +900,10 @@ async def run_hybrid_search(
|
||||
else:
|
||||
results["latency_metrics"] = latency_metrics
|
||||
|
||||
logger.info(f"[PERF] ===== SEARCH PERFORMANCE SUMMARY =====")
|
||||
logger.info("[PERF] ===== SEARCH PERFORMANCE SUMMARY =====")
|
||||
logger.info(f"[PERF] Total search completed in {total_latency:.4f}s")
|
||||
logger.info(f"[PERF] Latency breakdown: {json.dumps(latency_metrics, indent=2)}")
|
||||
logger.info(f"[PERF] =========================================")
|
||||
logger.info("[PERF] =========================================")
|
||||
|
||||
# Sanitize results: drop large/unused fields
|
||||
_remove_keys_recursive(results, ["name_embedding"]) # drop entity name embeddings from outputs
|
||||
|
||||
@@ -311,8 +311,35 @@ class ExtractionOrchestrator:
|
||||
dialog_data_list,
|
||||
)
|
||||
|
||||
# 步骤 7: 同步用户别名到数据库表(仅正式模式)
|
||||
# 步骤 7: 同步用户别名到数据库表 + 触发异步元数据提取(仅正式模式)
|
||||
if not is_pilot_run:
|
||||
# 收集用户相关 statement 并触发异步元数据提取
|
||||
try:
|
||||
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.metadata_extractor import MetadataExtractor
|
||||
metadata_extractor = MetadataExtractor(llm_client=self.llm_client, language=self.language)
|
||||
user_statements = metadata_extractor.collect_user_related_statements(
|
||||
entity_nodes, statement_nodes,
|
||||
statement_entity_edges
|
||||
)
|
||||
if user_statements:
|
||||
# 获取 end_user_id 和 config_id
|
||||
end_user_id = dialog_data_list[0].end_user_id if dialog_data_list else None
|
||||
config_id = dialog_data_list[0].config_id if dialog_data_list and hasattr(dialog_data_list[0], 'config_id') else None
|
||||
if end_user_id:
|
||||
from app.tasks import extract_user_metadata_task
|
||||
extract_user_metadata_task.delay(
|
||||
end_user_id=str(end_user_id),
|
||||
statements=user_statements,
|
||||
config_id=str(config_id) if config_id else None,
|
||||
language=self.language,
|
||||
)
|
||||
logger.info(f"已触发异步元数据提取任务,共 {len(user_statements)} 条用户相关 statement")
|
||||
else:
|
||||
logger.info("未找到用户相关 statement,跳过元数据提取")
|
||||
except Exception as e:
|
||||
logger.error(f"触发元数据提取任务失败(不影响主流程): {e}", exc_info=True)
|
||||
|
||||
# 同步用户别名到数据库表
|
||||
logger.info("步骤 7: 同步用户别名到 end_user 和 end_user_info 表")
|
||||
await self._update_end_user_other_name(entity_nodes, dialog_data_list)
|
||||
|
||||
@@ -1107,6 +1134,7 @@ class ExtractionOrchestrator:
|
||||
end_user_id=dialog_data.end_user_id,
|
||||
run_id=dialog_data.run_id, # 使用 dialog_data 的 run_id
|
||||
content=chunk.content,
|
||||
speaker=getattr(chunk, 'speaker', None),
|
||||
chunk_embedding=chunk.chunk_embedding,
|
||||
sequence_number=chunk_idx, # 添加必需的 sequence_number 字段
|
||||
created_at=dialog_data.created_at,
|
||||
@@ -1342,7 +1370,7 @@ class ExtractionOrchestrator:
|
||||
async def _update_end_user_other_name(
|
||||
self,
|
||||
entity_nodes: List[ExtractedEntityNode],
|
||||
dialog_data_list: List[DialogData]
|
||||
dialog_data_list: List[DialogData],
|
||||
) -> None:
|
||||
"""
|
||||
将本轮提取的用户别名同步到 end_user 和 end_user_info 表。
|
||||
@@ -1470,7 +1498,6 @@ class ExtractionOrchestrator:
|
||||
end_user_id=end_user_uuid,
|
||||
other_name=first_alias,
|
||||
aliases=merged_aliases,
|
||||
meta_data={}
|
||||
))
|
||||
logger.info(f"创建 end_user_info 记录,other_name={first_alias}, aliases={merged_aliases}")
|
||||
|
||||
|
||||
@@ -0,0 +1,152 @@
|
||||
"""
|
||||
Metadata extractor module.
|
||||
|
||||
Collects user-related statements from post-dedup graph data and
|
||||
extracts user metadata via an independent LLM call.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import List, Optional
|
||||
|
||||
from app.core.memory.models.graph_models import (
|
||||
ExtractedEntityNode,
|
||||
StatementEntityEdge,
|
||||
StatementNode,
|
||||
)
|
||||
from app.core.memory.models.metadata_models import (
|
||||
MetadataExtractionResponse,
|
||||
UserMetadata,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Reuse the same user-entity detection logic from dedup module
|
||||
_USER_NAMES = {"用户", "我", "user", "i"}
|
||||
_CANONICAL_USER_TYPE = "用户"
|
||||
|
||||
|
||||
def _is_user_entity(ent: ExtractedEntityNode) -> bool:
|
||||
"""判断实体是否为用户实体"""
|
||||
name = (getattr(ent, "name", "") or "").strip().lower()
|
||||
etype = (getattr(ent, "entity_type", "") or "").strip()
|
||||
return name in _USER_NAMES or etype == _CANONICAL_USER_TYPE
|
||||
|
||||
|
||||
class MetadataExtractor:
|
||||
"""Extracts user metadata from post-dedup graph data via independent LLM call."""
|
||||
|
||||
def __init__(self, llm_client, language: str = "zh"):
|
||||
self.llm_client = llm_client
|
||||
self.language = language
|
||||
|
||||
@staticmethod
|
||||
def detect_language(statements: List[str]) -> str:
|
||||
"""根据 statement 文本内容检测语言。
|
||||
如果文本中包含中文字符则返回 "zh",否则返回 "en"。
|
||||
"""
|
||||
import re
|
||||
combined = " ".join(statements)
|
||||
if re.search(r'[\u4e00-\u9fff]', combined):
|
||||
return "zh"
|
||||
return "en"
|
||||
|
||||
def collect_user_related_statements(
|
||||
self,
|
||||
entity_nodes: List[ExtractedEntityNode],
|
||||
statement_nodes: List[StatementNode],
|
||||
statement_entity_edges: List[StatementEntityEdge],
|
||||
) -> List[str]:
|
||||
"""
|
||||
从去重后的数据中筛选与用户直接相关且由用户发言的 statement 文本。
|
||||
|
||||
筛选逻辑:
|
||||
1. 用户实体 → StatementEntityEdge → statement(直接关联)
|
||||
2. 只保留 speaker="user" 的 statement(过滤 assistant 回复的噪声)
|
||||
|
||||
Returns:
|
||||
用户发言的 statement 文本列表
|
||||
"""
|
||||
# Find user entity IDs
|
||||
user_entity_ids = set()
|
||||
for ent in entity_nodes:
|
||||
if _is_user_entity(ent):
|
||||
user_entity_ids.add(ent.id)
|
||||
|
||||
if not user_entity_ids:
|
||||
logger.debug("未找到用户实体节点,跳过 statement 收集")
|
||||
return []
|
||||
|
||||
# 用户实体 → StatementEntityEdge → statement
|
||||
target_stmt_ids = set()
|
||||
for edge in statement_entity_edges:
|
||||
if edge.target in user_entity_ids:
|
||||
target_stmt_ids.add(edge.source)
|
||||
|
||||
# Collect: only speaker="user" statements, preserving order
|
||||
result = []
|
||||
seen = set()
|
||||
total_associated = 0
|
||||
skipped_non_user = 0
|
||||
for stmt_node in statement_nodes:
|
||||
if stmt_node.id in target_stmt_ids and stmt_node.id not in seen:
|
||||
total_associated += 1
|
||||
speaker = getattr(stmt_node, 'speaker', None) or 'unknown'
|
||||
if speaker == "user":
|
||||
text = (stmt_node.statement or "").strip()
|
||||
if text:
|
||||
result.append(text)
|
||||
else:
|
||||
skipped_non_user += 1
|
||||
seen.add(stmt_node.id)
|
||||
|
||||
logger.info(
|
||||
f"收集到 {len(result)} 条用户发言 statement "
|
||||
f"(直接关联: {total_associated}, speaker=user: {len(result)}, "
|
||||
f"跳过非user: {skipped_non_user})"
|
||||
)
|
||||
if total_associated > 0 and len(result) == 0:
|
||||
logger.warning(
|
||||
f"有 {total_associated} 条直接关联 statement 但全部被 speaker 过滤,"
|
||||
f"可能本次写入不包含 user 消息"
|
||||
)
|
||||
return result
|
||||
|
||||
async def extract_metadata(self, statements: List[str]) -> Optional[UserMetadata]:
|
||||
"""
|
||||
对筛选后的 statement 列表调用 LLM 提取元数据。
|
||||
语言根据 statement 内容自动检测,不依赖系统界面语言。
|
||||
|
||||
Returns:
|
||||
UserMetadata on success, None on failure
|
||||
"""
|
||||
if not statements:
|
||||
return None
|
||||
|
||||
try:
|
||||
from app.core.memory.utils.prompt.prompt_utils import prompt_env
|
||||
|
||||
# 根据写入内容的语言自动检测,而非使用系统界面语言
|
||||
detected_language = self.detect_language(statements)
|
||||
logger.info(f"元数据提取语言检测结果: {detected_language}")
|
||||
|
||||
template = prompt_env.get_template("extract_user_metadata.jinja2")
|
||||
prompt = template.render(
|
||||
statements=statements,
|
||||
language=detected_language,
|
||||
json_schema="",
|
||||
)
|
||||
|
||||
response = await self.llm_client.response_structured(
|
||||
messages=[{"role": "user", "content": prompt}],
|
||||
response_model=MetadataExtractionResponse,
|
||||
)
|
||||
|
||||
if response and response.user_metadata:
|
||||
return response.user_metadata
|
||||
|
||||
logger.warning("LLM 返回的元数据为空")
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"元数据提取 LLM 调用失败: {e}", exc_info=True)
|
||||
return None
|
||||
@@ -1,4 +1,3 @@
|
||||
import os
|
||||
import asyncio
|
||||
from typing import List, Dict, Optional
|
||||
|
||||
|
||||
@@ -5,7 +5,7 @@
|
||||
使用Neo4j的全文索引进行高效的文本匹配。
|
||||
"""
|
||||
|
||||
from typing import List, Dict, Any, Optional
|
||||
from typing import List, Optional
|
||||
from app.core.logging_config import get_memory_logger
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.core.memory.storage_services.search.search_strategy import SearchStrategy, SearchResult
|
||||
@@ -74,7 +74,7 @@ class KeywordSearchStrategy(SearchStrategy):
|
||||
# 调用底层的关键词搜索函数
|
||||
results_dict = await search_graph(
|
||||
connector=self.connector,
|
||||
q=query_text,
|
||||
query=query_text,
|
||||
end_user_id=end_user_id,
|
||||
limit=limit,
|
||||
include=include_list
|
||||
|
||||
@@ -22,7 +22,9 @@ def escape_lucene_query(query: str) -> str:
|
||||
s = s.replace("\r", " ").replace("\n", " ").strip()
|
||||
|
||||
# Lucene reserved tokens/special characters
|
||||
specials = ['&&', '||', '\\', '+', '-', '!', '(', ')', '{', '}', '[', ']', '^', '"', '~', '*', '?', ':']
|
||||
# NOTE: '/' is the regex delimiter in Lucene — must be escaped to prevent
|
||||
# TokenMgrError when the query contains unmatched slashes.
|
||||
specials = ['&&', '||', '\\', '+', '-', '!', '(', ')', '{', '}', '[', ']', '^', '"', '~', '*', '?', ':', '/']
|
||||
# Replace longer tokens first to avoid partial double-escaping
|
||||
for token in sorted(specials, key=len, reverse=True):
|
||||
s = s.replace(token, f"\\{token}")
|
||||
|
||||
179
api/app/core/memory/utils/metadata_utils.py
Normal file
179
api/app/core/memory/utils/metadata_utils.py
Normal file
@@ -0,0 +1,179 @@
|
||||
"""
|
||||
Metadata utility functions for cleaning, validating, aggregating, and merging
|
||||
user metadata extracted from conversations.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
from typing import Optional
|
||||
|
||||
from app.core.memory.models.metadata_models import UserMetadata
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def clean_metadata(raw: dict) -> dict:
|
||||
"""
|
||||
Clean metadata by removing empty string values and empty array fields recursively.
|
||||
Only keeps fields with actual content. If a nested dict becomes empty after cleaning,
|
||||
it is removed too.
|
||||
"""
|
||||
cleaned = {}
|
||||
for key, value in raw.items():
|
||||
if isinstance(value, dict):
|
||||
nested = clean_metadata(value)
|
||||
if nested:
|
||||
cleaned[key] = nested
|
||||
elif isinstance(value, list):
|
||||
if len(value) > 0:
|
||||
cleaned[key] = value
|
||||
elif isinstance(value, str):
|
||||
if value != "":
|
||||
cleaned[key] = value
|
||||
else:
|
||||
cleaned[key] = value
|
||||
return cleaned
|
||||
|
||||
# TODO 这个函数没有调用的地方
|
||||
def validate_metadata(raw: dict) -> Optional[UserMetadata]:
|
||||
"""
|
||||
Validate metadata structure using the Pydantic UserMetadata model.
|
||||
Returns None and logs a WARNING on validation failure.
|
||||
"""
|
||||
try:
|
||||
return UserMetadata.model_validate(raw)
|
||||
except Exception as e:
|
||||
logger.warning("Metadata validation failed: %s", e)
|
||||
return None
|
||||
|
||||
|
||||
def merge_metadata(existing: dict, new: dict) -> dict:
|
||||
"""
|
||||
Merge new extracted metadata with existing database metadata.
|
||||
- Scalar fields: new value overwrites old value
|
||||
- Array fields: support _op marker (append/replace/remove)
|
||||
- Missing top-level keys in new: preserve existing data
|
||||
- Auto-update _updated_at timestamp dict with field paths and ISO timestamps
|
||||
- When existing is None or {}: directly write new + _updated_at (no merge logic)
|
||||
"""
|
||||
now = datetime.now(timezone.utc).isoformat()
|
||||
|
||||
if not existing:
|
||||
# Direct write: new + _updated_at for all fields
|
||||
result = dict(new)
|
||||
updated_at = {}
|
||||
_collect_field_paths(result, "", updated_at, now)
|
||||
if updated_at:
|
||||
result["_updated_at"] = updated_at
|
||||
return result
|
||||
|
||||
result = dict(existing)
|
||||
updated_at: dict = dict(result.get("_updated_at", {}))
|
||||
|
||||
for key, new_value in new.items():
|
||||
if key == "_updated_at":
|
||||
continue
|
||||
|
||||
old_value = result.get(key)
|
||||
|
||||
if isinstance(new_value, dict) and isinstance(old_value, dict):
|
||||
# Nested dict merge (e.g. profile, behavioral_hints)
|
||||
_merge_nested(result, key, old_value, new_value, updated_at, now)
|
||||
elif isinstance(new_value, list) or (isinstance(new_value, dict) and "_op" in new_value):
|
||||
# Array field with possible _op
|
||||
_merge_array_field(result, key, old_value, new_value, updated_at, now)
|
||||
else:
|
||||
# Scalar top-level field
|
||||
if old_value != new_value:
|
||||
result[key] = new_value
|
||||
updated_at[key] = now
|
||||
# If equal, no change needed
|
||||
|
||||
result["_updated_at"] = updated_at
|
||||
return result
|
||||
|
||||
# TODO 考虑大函数包含小函数,因为只服务于大函数,实现代码文件的结构清楚
|
||||
def _collect_field_paths(data: dict, prefix: str, updated_at: dict, now: str) -> None:
|
||||
"""Collect all leaf field paths for _updated_at on direct write."""
|
||||
for key, value in data.items():
|
||||
if key == "_updated_at":
|
||||
continue
|
||||
path = f"{prefix}{key}" if not prefix else f"{prefix}.{key}"
|
||||
if isinstance(value, dict):
|
||||
_collect_field_paths(value, path, updated_at, now)
|
||||
else:
|
||||
updated_at[path] = now
|
||||
|
||||
|
||||
def _merge_nested(
|
||||
result: dict, key: str, old_dict: dict, new_dict: dict,
|
||||
updated_at: dict, now: str
|
||||
) -> None:
|
||||
"""Merge a nested dict (e.g. profile, behavioral_hints)."""
|
||||
merged = dict(old_dict)
|
||||
for field, new_val in new_dict.items():
|
||||
old_val = merged.get(field)
|
||||
path = f"{key}.{field}"
|
||||
|
||||
if isinstance(new_val, list) or (isinstance(new_val, dict) and "_op" in new_val):
|
||||
_merge_array_field_inner(merged, field, old_val, new_val, updated_at, path, now)
|
||||
else:
|
||||
# Scalar field
|
||||
if old_val != new_val:
|
||||
merged[field] = new_val
|
||||
updated_at[path] = now
|
||||
result[key] = merged
|
||||
|
||||
|
||||
def _merge_array_field(
|
||||
result: dict, key: str, old_value, new_value,
|
||||
updated_at: dict, now: str
|
||||
) -> None:
|
||||
"""Merge a top-level array field with _op support."""
|
||||
_merge_array_field_inner(result, key, old_value, new_value, updated_at, key, now)
|
||||
|
||||
|
||||
def _merge_array_field_inner(
|
||||
container: dict, field: str, old_value, new_value,
|
||||
updated_at: dict, path: str, now: str
|
||||
) -> None:
|
||||
"""Core array merge logic with _op support."""
|
||||
# Determine op and items
|
||||
if isinstance(new_value, dict) and "_op" in new_value:
|
||||
op = new_value.get("_op", "append")
|
||||
items = new_value.get(field, new_value.get("items", []))
|
||||
# If the dict has a key matching the field name, use it; otherwise look for list values
|
||||
if not isinstance(items, list):
|
||||
# Try to find the list value in the dict (excluding _op)
|
||||
for k, v in new_value.items():
|
||||
if k != "_op" and isinstance(v, list):
|
||||
items = v
|
||||
break
|
||||
else:
|
||||
items = []
|
||||
elif isinstance(new_value, list):
|
||||
op = "append"
|
||||
items = new_value
|
||||
else:
|
||||
op = "append"
|
||||
items = []
|
||||
|
||||
old_arr = old_value if isinstance(old_value, list) else []
|
||||
|
||||
if op == "replace":
|
||||
new_arr = items
|
||||
elif op == "remove":
|
||||
new_arr = [x for x in old_arr if x not in items]
|
||||
else:
|
||||
# append (default): merge and deduplicate
|
||||
seen = list(old_arr)
|
||||
for item in items:
|
||||
if item not in seen:
|
||||
seen.append(item)
|
||||
new_arr = seen
|
||||
|
||||
if old_arr != new_arr:
|
||||
container[field] = new_arr
|
||||
updated_at[path] = now
|
||||
else:
|
||||
container[field] = new_arr
|
||||
@@ -406,4 +406,12 @@ Output:
|
||||
- **⚠️ ALIASES ORDER: preserve temporal order of appearance**
|
||||
- **🚨 MANDATORY FIELD: EVERY entity MUST include "aliases" field, even if empty array []**
|
||||
|
||||
**Output JSON structure:**
|
||||
```json
|
||||
{
|
||||
"triplets": [...],
|
||||
"entities": [...]
|
||||
}
|
||||
```
|
||||
|
||||
{{ json_schema }}
|
||||
|
||||
@@ -0,0 +1,74 @@
|
||||
===Task===
|
||||
Extract user metadata from the following conversation statements spoken by the user.
|
||||
|
||||
{% if language == "zh" %}
|
||||
**"三度原则"判断标准:**
|
||||
- 复用度:该信息是否会被多个功能模块使用?
|
||||
- 约束度:该信息是否会影响系统行为?
|
||||
- 时效性:该信息是长期稳定的还是临时的?仅提取长期稳定信息。
|
||||
|
||||
**提取规则:**
|
||||
- **只提取关于"用户本人"的画像信息**,忽略用户提到的第三方人物(如朋友、同事、家人)的信息
|
||||
- 仅提取文本中明确提到的信息,不要推测
|
||||
- 如果文本中没有可提取的用户画像信息,返回空的 user_metadata 对象
|
||||
- **输出语言必须与输入文本的语言一致**(输入中文则输出中文值,输入英文则输出英文值)
|
||||
|
||||
**字段说明:**
|
||||
- profile.role:用户的职业或角色,如 教师、医生、后端工程师
|
||||
- profile.domain:用户所在领域,如 教育、医疗、软件开发
|
||||
- profile.expertise:用户擅长的技能或工具(通用,不限于编程),如 Python、心理咨询、高中物理
|
||||
- profile.interests:用户主动表达兴趣的话题或领域标签
|
||||
- behavioral_hints.learning_stage:学习阶段(初学者/中级/高级)
|
||||
- behavioral_hints.preferred_depth:偏好深度(概览/技术细节/深入探讨)
|
||||
- behavioral_hints.tone_preference:语气偏好(轻松随意/专业简洁/学术严谨)
|
||||
- knowledge_tags:用户涉及的知识领域标签
|
||||
{% else %}
|
||||
**"Three-Degree Principle" criteria:**
|
||||
- Reusability: Will this information be used by multiple functional modules?
|
||||
- Constraint: Will this information affect system behavior?
|
||||
- Timeliness: Is this information long-term stable or temporary? Only extract long-term stable information.
|
||||
|
||||
**Extraction rules:**
|
||||
- **Only extract profile information about the user themselves**, ignore information about third parties (friends, colleagues, family) mentioned by the user
|
||||
- Only extract information explicitly mentioned in the text, do not speculate
|
||||
- If no user profile information can be extracted, return an empty user_metadata object
|
||||
- **Output language must match the input text language**
|
||||
|
||||
**Field descriptions:**
|
||||
- profile.role: User's occupation or role, e.g. teacher, doctor, software engineer
|
||||
- profile.domain: User's domain, e.g. education, healthcare, software development
|
||||
- profile.expertise: User's skills or tools (general, not limited to programming)
|
||||
- profile.interests: Topics or domain tags the user actively expressed interest in
|
||||
- behavioral_hints.learning_stage: Learning stage (beginner/intermediate/advanced)
|
||||
- behavioral_hints.preferred_depth: Preferred depth (overview/detailed/deep dive)
|
||||
- behavioral_hints.tone_preference: Tone preference (casual/professional/academic)
|
||||
- knowledge_tags: Knowledge domain tags related to the user
|
||||
{% endif %}
|
||||
|
||||
===User Statements===
|
||||
{% for stmt in statements %}
|
||||
- {{ stmt }}
|
||||
{% endfor %}
|
||||
|
||||
===Output Format===
|
||||
Return a JSON object with the following structure:
|
||||
```json
|
||||
{
|
||||
"user_metadata": {
|
||||
"profile": {
|
||||
"role": "",
|
||||
"domain": "",
|
||||
"expertise": [],
|
||||
"interests": []
|
||||
},
|
||||
"behavioral_hints": {
|
||||
"learning_stage": "",
|
||||
"preferred_depth": "",
|
||||
"tone_preference": ""
|
||||
},
|
||||
"knowledge_tags": []
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
{{ json_schema }}
|
||||
@@ -23,6 +23,7 @@ SET s += {
|
||||
end_user_id: statement.end_user_id,
|
||||
stmt_type: statement.stmt_type,
|
||||
statement: statement.statement,
|
||||
speaker: statement.speaker,
|
||||
emotion_intensity: statement.emotion_intensity,
|
||||
emotion_target: statement.emotion_target,
|
||||
emotion_subject: statement.emotion_subject,
|
||||
@@ -56,6 +57,7 @@ SET c += {
|
||||
expired_at: chunk.expired_at,
|
||||
dialog_id: chunk.dialog_id,
|
||||
content: chunk.content,
|
||||
speaker: chunk.speaker,
|
||||
chunk_embedding: chunk.chunk_embedding,
|
||||
sequence_number: chunk.sequence_number,
|
||||
start_index: chunk.start_index,
|
||||
@@ -283,7 +285,7 @@ LIMIT $limit
|
||||
"""
|
||||
|
||||
SEARCH_STATEMENTS_BY_KEYWORD = """
|
||||
CALL db.index.fulltext.queryNodes("statementsFulltext", $q) YIELD node AS s, score
|
||||
CALL db.index.fulltext.queryNodes("statementsFulltext", $query) YIELD node AS s, score
|
||||
WHERE ($end_user_id IS NULL OR s.end_user_id = $end_user_id)
|
||||
OPTIONAL MATCH (c:Chunk)-[:CONTAINS]->(s)
|
||||
OPTIONAL MATCH (s)-[:REFERENCES_ENTITY]->(e:ExtractedEntity)
|
||||
@@ -307,7 +309,7 @@ LIMIT $limit
|
||||
"""
|
||||
# 查询实体名称包含指定字符串的实体
|
||||
SEARCH_ENTITIES_BY_NAME = """
|
||||
CALL db.index.fulltext.queryNodes("entitiesFulltext", $q) YIELD node AS e, score
|
||||
CALL db.index.fulltext.queryNodes("entitiesFulltext", $query) YIELD node AS e, score
|
||||
WHERE ($end_user_id IS NULL OR e.end_user_id = $end_user_id)
|
||||
OPTIONAL MATCH (s:Statement)-[:REFERENCES_ENTITY]->(e)
|
||||
OPTIONAL MATCH (c:Chunk)-[:CONTAINS]->(s)
|
||||
@@ -337,21 +339,21 @@ LIMIT $limit
|
||||
"""
|
||||
|
||||
SEARCH_ENTITIES_BY_NAME_OR_ALIAS = """
|
||||
CALL db.index.fulltext.queryNodes("entitiesFulltext", $q) YIELD node AS e, score
|
||||
CALL db.index.fulltext.queryNodes("entitiesFulltext", $query) YIELD node AS e, score
|
||||
WHERE ($end_user_id IS NULL OR e.end_user_id = $end_user_id)
|
||||
WITH e, score
|
||||
WITH collect({entity: e, score: score}) AS fulltextResults
|
||||
With collect({entity: e, score: score}) AS fulltextResults
|
||||
|
||||
OPTIONAL MATCH (ae:ExtractedEntity)
|
||||
WHERE ($end_user_id IS NULL OR ae.end_user_id = $end_user_id)
|
||||
AND ae.aliases IS NOT NULL
|
||||
AND ANY(alias IN ae.aliases WHERE toLower(alias) CONTAINS toLower($q))
|
||||
AND ANY(alias IN ae.aliases WHERE toLower(alias) CONTAINS toLower($query))
|
||||
WITH fulltextResults, collect(ae) AS aliasEntities
|
||||
|
||||
UNWIND (fulltextResults + [x IN aliasEntities | {entity: x, score:
|
||||
CASE
|
||||
WHEN ANY(alias IN x.aliases WHERE toLower(alias) = toLower($q)) THEN 1.0
|
||||
WHEN ANY(alias IN x.aliases WHERE toLower(alias) STARTS WITH toLower($q)) THEN 0.9
|
||||
WHEN ANY(alias IN x.aliases WHERE toLower(alias) = toLower($query)) THEN 1.0
|
||||
WHEN ANY(alias IN x.aliases WHERE toLower(alias) STARTS WITH toLower($query)) THEN 0.9
|
||||
ELSE 0.8
|
||||
END
|
||||
}]) AS row
|
||||
@@ -384,7 +386,7 @@ LIMIT $limit
|
||||
|
||||
|
||||
SEARCH_CHUNKS_BY_CONTENT = """
|
||||
CALL db.index.fulltext.queryNodes("chunksFulltext", $q) YIELD node AS c, score
|
||||
CALL db.index.fulltext.queryNodes("chunksFulltext", $query) YIELD node AS c, score
|
||||
WHERE ($end_user_id IS NULL OR c.end_user_id = $end_user_id)
|
||||
OPTIONAL MATCH (c)-[:CONTAINS]->(s:Statement)
|
||||
OPTIONAL MATCH (s)-[:REFERENCES_ENTITY]->(e:ExtractedEntity)
|
||||
@@ -501,7 +503,7 @@ LIMIT $limit
|
||||
"""
|
||||
|
||||
SEARCH_STATEMENTS_BY_KEYWORD_TEMPORAL = """
|
||||
CALL db.index.fulltext.queryNodes("statementsFulltext", $q) YIELD node AS s, score
|
||||
CALL db.index.fulltext.queryNodes("statementsFulltext", $query) YIELD node AS s, score
|
||||
WHERE ($end_user_id IS NULL OR s.end_user_id = $end_user_id)
|
||||
AND ((($start_date IS NULL OR (s.created_at IS NOT NULL AND datetime(s.created_at) >= datetime($start_date)))
|
||||
AND ($end_date IS NULL OR (s.created_at IS NOT NULL AND datetime(s.created_at) <= datetime($end_date))))
|
||||
@@ -677,7 +679,7 @@ SET n.invalid_at = $new_invalid_at
|
||||
|
||||
# MemorySummary keyword search using fulltext index
|
||||
SEARCH_MEMORY_SUMMARIES_BY_KEYWORD = """
|
||||
CALL db.index.fulltext.queryNodes("summariesFulltext", $q) YIELD node AS m, score
|
||||
CALL db.index.fulltext.queryNodes("summariesFulltext", $query) YIELD node AS m, score
|
||||
WHERE ($end_user_id IS NULL OR m.end_user_id = $end_user_id)
|
||||
OPTIONAL MATCH (m)-[:DERIVED_FROM_STATEMENT]->(s:Statement)
|
||||
RETURN m.id AS id,
|
||||
@@ -1363,7 +1365,7 @@ RETURN c.community_id AS community_id
|
||||
|
||||
# Community keyword search: matches name or summary via fulltext index
|
||||
SEARCH_COMMUNITIES_BY_KEYWORD = """
|
||||
CALL db.index.fulltext.queryNodes("communitiesFulltext", $q) YIELD node AS c, score
|
||||
CALL db.index.fulltext.queryNodes("communitiesFulltext", $query) YIELD node AS c, score
|
||||
WHERE ($end_user_id IS NULL OR c.end_user_id = $end_user_id)
|
||||
RETURN c.community_id AS id,
|
||||
c.name AS name,
|
||||
@@ -1451,7 +1453,7 @@ RETURN elementId(r) AS uuid
|
||||
"""
|
||||
|
||||
SEARCH_PERCEPTUAL_BY_KEYWORD = """
|
||||
CALL db.index.fulltext.queryNodes("perceptualFulltext", $q) YIELD node AS p, score
|
||||
CALL db.index.fulltext.queryNodes("perceptualFulltext", $query) YIELD node AS p, score
|
||||
WHERE p.end_user_id = $end_user_id
|
||||
RETURN p.id AS id,
|
||||
p.end_user_id AS end_user_id,
|
||||
|
||||
@@ -2,6 +2,7 @@ import asyncio
|
||||
import logging
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from app.core.memory.utils.data.text_utils import escape_lucene_query
|
||||
from app.repositories.neo4j.cypher_queries import (
|
||||
CHUNK_EMBEDDING_SEARCH,
|
||||
COMMUNITY_EMBEDDING_SEARCH,
|
||||
@@ -87,7 +88,7 @@ async def _update_activation_values_batch(
|
||||
unique_node_ids.append(node_id)
|
||||
|
||||
if not unique_node_ids:
|
||||
logger.warning(f"批量更新激活值:没有有效的节点ID")
|
||||
logger.warning("批量更新激活值:没有有效的节点ID")
|
||||
return nodes
|
||||
|
||||
# 记录去重信息(仅针对具有有效 ID 的节点)
|
||||
@@ -223,7 +224,7 @@ async def _update_search_results_activation(
|
||||
|
||||
async def search_graph(
|
||||
connector: Neo4jConnector,
|
||||
q: str,
|
||||
query: str,
|
||||
end_user_id: Optional[str] = None,
|
||||
limit: int = 50,
|
||||
include: List[str] = None,
|
||||
@@ -234,14 +235,14 @@ async def search_graph(
|
||||
OPTIMIZED: Runs all queries in parallel using asyncio.gather()
|
||||
INTEGRATED: Updates activation values for knowledge nodes before returning results
|
||||
|
||||
- Statements: matches s.statement CONTAINS q
|
||||
- Entities: matches e.name CONTAINS q
|
||||
- Chunks: matches s.content CONTAINS q (from Statement nodes)
|
||||
- Summaries: matches ms.content CONTAINS q
|
||||
- Statements: matches s.statement CONTAINS query
|
||||
- Entities: matches e.name CONTAINS query
|
||||
- Chunks: matches s.content CONTAINS query (from Statement nodes)
|
||||
- Summaries: matches ms.content CONTAINS query
|
||||
|
||||
Args:
|
||||
connector: Neo4j connector
|
||||
q: Query text
|
||||
query: Query text for full-text search
|
||||
end_user_id: Optional group filter
|
||||
limit: Max results per category
|
||||
include: List of categories to search (default: all)
|
||||
@@ -252,6 +253,9 @@ async def search_graph(
|
||||
if include is None:
|
||||
include = ["statements", "chunks", "entities", "summaries"]
|
||||
|
||||
# Escape Lucene special characters to prevent query parse errors
|
||||
escaped_query = escape_lucene_query(query)
|
||||
|
||||
# Prepare tasks for parallel execution
|
||||
tasks = []
|
||||
task_keys = []
|
||||
@@ -260,7 +264,7 @@ async def search_graph(
|
||||
tasks.append(connector.execute_query(
|
||||
SEARCH_STATEMENTS_BY_KEYWORD,
|
||||
json_format=True,
|
||||
q=q,
|
||||
query=escaped_query,
|
||||
end_user_id=end_user_id,
|
||||
limit=limit,
|
||||
))
|
||||
@@ -270,7 +274,7 @@ async def search_graph(
|
||||
tasks.append(connector.execute_query(
|
||||
SEARCH_ENTITIES_BY_NAME_OR_ALIAS,
|
||||
json_format=True,
|
||||
q=q,
|
||||
query=escaped_query,
|
||||
end_user_id=end_user_id,
|
||||
limit=limit,
|
||||
))
|
||||
@@ -280,7 +284,7 @@ async def search_graph(
|
||||
tasks.append(connector.execute_query(
|
||||
SEARCH_CHUNKS_BY_CONTENT,
|
||||
json_format=True,
|
||||
q=q,
|
||||
query=escaped_query,
|
||||
end_user_id=end_user_id,
|
||||
limit=limit,
|
||||
))
|
||||
@@ -290,7 +294,7 @@ async def search_graph(
|
||||
tasks.append(connector.execute_query(
|
||||
SEARCH_MEMORY_SUMMARIES_BY_KEYWORD,
|
||||
json_format=True,
|
||||
q=q,
|
||||
query=escaped_query,
|
||||
end_user_id=end_user_id,
|
||||
limit=limit,
|
||||
))
|
||||
@@ -300,7 +304,7 @@ async def search_graph(
|
||||
tasks.append(connector.execute_query(
|
||||
SEARCH_COMMUNITIES_BY_KEYWORD,
|
||||
json_format=True,
|
||||
q=q,
|
||||
query=escaped_query,
|
||||
end_user_id=end_user_id,
|
||||
limit=limit,
|
||||
))
|
||||
@@ -482,7 +486,7 @@ async def search_graph_by_embedding(
|
||||
update_time = time.time() - update_start
|
||||
logger.info(f"[PERF] Activation value updates took: {update_time:.4f}s")
|
||||
else:
|
||||
logger.info(f"[PERF] Skipping activation updates (only summaries)")
|
||||
logger.info("[PERF] Skipping activation updates (only summaries)")
|
||||
|
||||
return results
|
||||
|
||||
@@ -520,7 +524,7 @@ async def get_dedup_candidates_for_entities( # 适配新版查询:使用全
|
||||
# 全文索引按名称检索(包含 CONTAINS 语义)
|
||||
rows = await connector.execute_query(
|
||||
SEARCH_ENTITIES_BY_NAME,
|
||||
q=name,
|
||||
query=escape_lucene_query(name),
|
||||
end_user_id=end_user_id,
|
||||
limit=100,
|
||||
)
|
||||
@@ -544,7 +548,7 @@ async def get_dedup_candidates_for_entities( # 适配新版查询:使用全
|
||||
try:
|
||||
rows = await connector.execute_query(
|
||||
SEARCH_ENTITIES_BY_NAME,
|
||||
q=name.lower(),
|
||||
query=escape_lucene_query(name.lower()),
|
||||
end_user_id=end_user_id,
|
||||
limit=100,
|
||||
)
|
||||
@@ -593,11 +597,12 @@ async def search_graph_by_keyword_temporal(
|
||||
- Returns up to 'limit' statements
|
||||
"""
|
||||
if not query_text:
|
||||
logger.warning(f"query_text不能为空")
|
||||
logger.warning("query_text不能为空")
|
||||
return {"statements": []}
|
||||
escaped_query = escape_lucene_query(query_text)
|
||||
statements = await connector.execute_query(
|
||||
SEARCH_STATEMENTS_BY_KEYWORD_TEMPORAL,
|
||||
q=query_text,
|
||||
query=escaped_query,
|
||||
end_user_id=end_user_id,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
@@ -671,7 +676,7 @@ async def search_graph_by_dialog_id(
|
||||
- Returns up to 'limit' dialogues
|
||||
"""
|
||||
if not dialog_id:
|
||||
logger.warning(f"dialog_id不能为空")
|
||||
logger.warning("dialog_id不能为空")
|
||||
return {"dialogues": []}
|
||||
|
||||
dialogues = await connector.execute_query(
|
||||
@@ -690,7 +695,7 @@ async def search_graph_by_chunk_id(
|
||||
limit: int = 1,
|
||||
) -> Dict[str, List[Dict[str, Any]]]:
|
||||
if not chunk_id:
|
||||
logger.warning(f"chunk_id不能为空")
|
||||
logger.warning("chunk_id不能为空")
|
||||
return {"chunks": []}
|
||||
chunks = await connector.execute_query(
|
||||
SEARCH_CHUNK_BY_CHUNK_ID,
|
||||
@@ -968,7 +973,7 @@ async def search_graph_l_valid_at(
|
||||
|
||||
async def search_perceptual(
|
||||
connector: Neo4jConnector,
|
||||
q: str,
|
||||
query: str,
|
||||
end_user_id: Optional[str] = None,
|
||||
limit: int = 10,
|
||||
) -> Dict[str, List[Dict[str, Any]]]:
|
||||
@@ -979,7 +984,7 @@ async def search_perceptual(
|
||||
|
||||
Args:
|
||||
connector: Neo4j connector
|
||||
q: Query text
|
||||
query: Query text for full-text search
|
||||
end_user_id: Optional user filter
|
||||
limit: Max results
|
||||
|
||||
@@ -989,7 +994,7 @@ async def search_perceptual(
|
||||
try:
|
||||
perceptuals = await connector.execute_query(
|
||||
SEARCH_PERCEPTUAL_BY_KEYWORD,
|
||||
q=q,
|
||||
query=escape_lucene_query(query),
|
||||
end_user_id=end_user_id,
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
189
api/app/tasks.py
189
api/app/tasks.py
@@ -1001,7 +1001,7 @@ def sync_knowledge_for_kb(kb_id: uuid.UUID):
|
||||
except Exception as e:
|
||||
print(f"\n\nError during fetch feishu: {e}")
|
||||
case _: # General
|
||||
print(f"General: No synchronization needed\n")
|
||||
print("General: No synchronization needed\n")
|
||||
|
||||
result = f"sync knowledge '{db_knowledge.name}' processed successfully."
|
||||
return result
|
||||
@@ -1510,6 +1510,7 @@ def write_all_workspaces_memory_task(self) -> Dict[str, Any]:
|
||||
"status": "SUCCESS",
|
||||
"total_num": total_num,
|
||||
"end_user_count": len(end_users),
|
||||
"end_user_details": end_user_details,
|
||||
"memory_increment_id": str(memory_increment.id),
|
||||
"created_at": memory_increment.created_at.isoformat(),
|
||||
})
|
||||
@@ -2602,35 +2603,34 @@ def init_interest_distribution_for_users(self, end_user_ids: List[str]) -> Dict[
|
||||
|
||||
service = MemoryAgentService()
|
||||
|
||||
with get_db_context() as db:
|
||||
for end_user_id in end_user_ids:
|
||||
# 存在性检查:缓存有数据则跳过
|
||||
cached = await InterestMemoryCache.get_interest_distribution(
|
||||
for end_user_id in end_user_ids:
|
||||
# 存在性检查:缓存有数据则跳过
|
||||
cached = await InterestMemoryCache.get_interest_distribution(
|
||||
end_user_id=end_user_id,
|
||||
language=language,
|
||||
)
|
||||
if cached is not None:
|
||||
skipped += 1
|
||||
continue
|
||||
|
||||
logger.info(f"用户 {end_user_id} 无兴趣分布缓存,开始生成")
|
||||
try:
|
||||
result = await service.get_interest_distribution_by_user(
|
||||
end_user_id=end_user_id,
|
||||
limit=5,
|
||||
language=language,
|
||||
)
|
||||
if cached is not None:
|
||||
skipped += 1
|
||||
continue
|
||||
|
||||
logger.info(f"用户 {end_user_id} 无兴趣分布缓存,开始生成")
|
||||
try:
|
||||
result = await service.get_interest_distribution_by_user(
|
||||
end_user_id=end_user_id,
|
||||
limit=5,
|
||||
language=language,
|
||||
)
|
||||
await InterestMemoryCache.set_interest_distribution(
|
||||
end_user_id=end_user_id,
|
||||
language=language,
|
||||
data=result,
|
||||
expire=INTEREST_CACHE_EXPIRE,
|
||||
)
|
||||
initialized += 1
|
||||
logger.info(f"用户 {end_user_id} 兴趣分布缓存生成成功")
|
||||
except Exception as e:
|
||||
failed += 1
|
||||
logger.error(f"用户 {end_user_id} 兴趣分布缓存生成失败: {e}")
|
||||
await InterestMemoryCache.set_interest_distribution(
|
||||
end_user_id=end_user_id,
|
||||
language=language,
|
||||
data=result,
|
||||
expire=INTEREST_CACHE_EXPIRE,
|
||||
)
|
||||
initialized += 1
|
||||
logger.info(f"用户 {end_user_id} 兴趣分布缓存生成成功")
|
||||
except Exception as e:
|
||||
failed += 1
|
||||
logger.error(f"用户 {end_user_id} 兴趣分布缓存生成失败: {e}")
|
||||
|
||||
logger.info(f"兴趣分布按需初始化完成: 初始化={initialized}, 跳过={skipped}, 失败={failed}")
|
||||
return {
|
||||
@@ -2914,4 +2914,139 @@ def init_community_clustering_for_users(self, end_user_ids: List[str], workspace
|
||||
}
|
||||
|
||||
|
||||
# ─── User Metadata Extraction Task ───────────────────────────────────────────
|
||||
|
||||
@celery_app.task(
|
||||
bind=True,
|
||||
name='app.tasks.extract_user_metadata',
|
||||
ignore_result=False,
|
||||
max_retries=0,
|
||||
acks_late=True,
|
||||
time_limit=300,
|
||||
soft_time_limit=240,
|
||||
)
|
||||
def extract_user_metadata_task(
|
||||
self,
|
||||
end_user_id: str,
|
||||
statements: List[str],
|
||||
config_id: Optional[str] = None,
|
||||
language: str = "zh",
|
||||
) -> Dict[str, Any]:
|
||||
"""异步提取用户元数据并写入数据库。
|
||||
|
||||
在去重消歧完成后由编排器触发,使用独立 LLM 调用提取元数据。
|
||||
LLM 配置优先使用 config_id 对应的应用配置,失败时回退到工作空间默认配置。
|
||||
|
||||
Args:
|
||||
end_user_id: 终端用户 ID
|
||||
statements: 用户相关的 statement 文本列表
|
||||
config_id: 应用配置 ID(可选)
|
||||
language: 语言类型 ("zh" 中文, "en" 英文)
|
||||
|
||||
Returns:
|
||||
包含任务执行结果的字典
|
||||
"""
|
||||
start_time = time.time()
|
||||
logger.info(
|
||||
f"[CELERY METADATA] Starting metadata extraction - end_user_id={end_user_id}, "
|
||||
f"statements_count={len(statements)}, config_id={config_id}, language={language}"
|
||||
)
|
||||
|
||||
async def _run() -> Dict[str, Any]:
|
||||
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.metadata_extractor import MetadataExtractor
|
||||
from app.core.memory.utils.metadata_utils import clean_metadata, merge_metadata, validate_metadata
|
||||
from app.repositories.end_user_info_repository import EndUserInfoRepository
|
||||
from app.repositories.end_user_repository import EndUserRepository
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
|
||||
# 1. 获取 LLM 配置(应用配置 → 工作空间配置兜底)并创建 LLM client
|
||||
with get_db_context() as db:
|
||||
end_user_uuid = uuid.UUID(end_user_id)
|
||||
|
||||
# 获取 workspace_id from end_user
|
||||
end_user = EndUserRepository(db).get_by_id(end_user_uuid)
|
||||
if not end_user:
|
||||
return {"status": "FAILURE", "error": f"End user not found: {end_user_id}"}
|
||||
|
||||
workspace_id = end_user.workspace_id
|
||||
|
||||
config_service = MemoryConfigService(db)
|
||||
memory_config = config_service.get_config_with_fallback(
|
||||
memory_config_id=uuid.UUID(config_id) if config_id else None,
|
||||
workspace_id=workspace_id,
|
||||
)
|
||||
if not memory_config:
|
||||
return {"status": "FAILURE", "error": "No LLM config available (app + workspace fallback failed)"}
|
||||
|
||||
# 2. 创建 LLM client
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
factory = MemoryClientFactory(db)
|
||||
if not memory_config.llm_id:
|
||||
return {"status": "FAILURE", "error": "Memory config has no LLM model configured"}
|
||||
llm_client = factory.get_llm_client(memory_config.llm_id)
|
||||
|
||||
# 3. 提取元数据
|
||||
extractor = MetadataExtractor(llm_client=llm_client, language=language)
|
||||
user_metadata = await extractor.extract_metadata(statements)
|
||||
|
||||
if not user_metadata:
|
||||
logger.info(f"[CELERY METADATA] No metadata extracted for end_user_id={end_user_id}")
|
||||
return {"status": "SUCCESS", "result": "no_metadata_extracted"}
|
||||
|
||||
# 4. 清洗、校验、合并、写入
|
||||
raw_dict = user_metadata.model_dump()
|
||||
cleaned = clean_metadata(raw_dict)
|
||||
if not cleaned:
|
||||
logger.info(f"[CELERY METADATA] Cleaned metadata is empty for end_user_id={end_user_id}")
|
||||
return {"status": "SUCCESS", "result": "empty_after_cleaning"}
|
||||
|
||||
validated = validate_metadata(cleaned)
|
||||
if not validated:
|
||||
return {"status": "FAILURE", "error": "Metadata validation failed after cleaning"}
|
||||
|
||||
with get_db_context() as db:
|
||||
end_user_uuid = uuid.UUID(end_user_id)
|
||||
info = EndUserInfoRepository(db).get_by_end_user_id(end_user_uuid)
|
||||
|
||||
if info:
|
||||
existing_meta = info.meta_data if info.meta_data else {}
|
||||
info.meta_data = merge_metadata(existing_meta, cleaned)
|
||||
logger.info(f"[CELERY METADATA] Updated metadata for end_user_id={end_user_id}")
|
||||
else:
|
||||
# No end_user_info record yet - metadata will be written when alias sync creates it,
|
||||
# or we create a minimal record here
|
||||
logger.info(
|
||||
f"[CELERY METADATA] No end_user_info record for end_user_id={end_user_id}, "
|
||||
f"skipping metadata write (will be created by alias sync)"
|
||||
)
|
||||
return {"status": "SUCCESS", "result": "no_info_record"}
|
||||
|
||||
db.commit()
|
||||
|
||||
return {"status": "SUCCESS", "result": "metadata_written"}
|
||||
|
||||
loop = None
|
||||
try:
|
||||
loop = set_asyncio_event_loop()
|
||||
result = loop.run_until_complete(_run())
|
||||
elapsed = time.time() - start_time
|
||||
result["elapsed_time"] = elapsed
|
||||
result["task_id"] = self.request.id
|
||||
logger.info(f"[CELERY METADATA] Task completed - elapsed={elapsed:.2f}s, result={result.get('result')}")
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
elapsed = time.time() - start_time
|
||||
logger.error(f"[CELERY METADATA] Task failed - elapsed={elapsed:.2f}s, error={e}", exc_info=True)
|
||||
return {
|
||||
"status": "FAILURE",
|
||||
"error": str(e),
|
||||
"elapsed_time": elapsed,
|
||||
"task_id": self.request.id,
|
||||
}
|
||||
finally:
|
||||
if loop:
|
||||
_shutdown_loop_gracefully(loop)
|
||||
|
||||
|
||||
# unused task
|
||||
Reference in New Issue
Block a user