Merge pull request #671 from SuanmoSuanyangTechnology/fix/log-community
【change】 1.Standardize log specifications;2.Cluster settings trigger …
This commit is contained in:
@@ -77,6 +77,7 @@ celery_app.conf.update(
|
||||
|
||||
# Worker 设置 (per-worker settings are in docker-compose command line)
|
||||
worker_prefetch_multiplier=1, # Don't hoard tasks, fairer distribution
|
||||
worker_redirect_stdouts_level='INFO', # stdout/print → INFO instead of WARNING
|
||||
|
||||
# 结果过期时间
|
||||
result_expires=3600, # 结果保存1小时
|
||||
|
||||
@@ -529,8 +529,9 @@ def log_time(step_name: str, duration: float, log_file: str = "logs/time.log") -
|
||||
# Fallback to console only if file write fails
|
||||
print(f"Warning: Could not write to timing log: {e}")
|
||||
|
||||
# Always print to console (backward compatible behavior)
|
||||
print(f"✓ {step_name}: {duration:.2f}s")
|
||||
# Always log at INFO level (avoids Celery treating stdout as WARNING)
|
||||
_timing_logger = logging.getLogger(__name__)
|
||||
_timing_logger.info(f"✓ {step_name}: {duration:.2f}s")
|
||||
|
||||
|
||||
def get_agent_logger(name: str = "agent_service",
|
||||
|
||||
@@ -19,7 +19,7 @@ from app.core.memory.utils.log.logging_utils import log_time
|
||||
from app.db import get_db_context
|
||||
from app.repositories.neo4j.add_edges import add_memory_summary_statement_edges
|
||||
from app.repositories.neo4j.add_nodes import add_memory_summary_nodes
|
||||
from app.repositories.neo4j.graph_saver import save_dialog_and_statements_to_neo4j, schedule_clustering_after_write
|
||||
from app.repositories.neo4j.graph_saver import save_dialog_and_statements_to_neo4j, _trigger_clustering_sync
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.schemas.memory_config_schema import MemoryConfig
|
||||
|
||||
@@ -169,8 +169,8 @@ async def write(
|
||||
)
|
||||
if success:
|
||||
logger.info("Successfully saved all data to Neo4j")
|
||||
# 写入成功后,异步触发聚类(不阻塞写入响应)
|
||||
schedule_clustering_after_write(
|
||||
# 写入成功后,同步等待聚类完成(避免与 Memory Summary 并发冲突)
|
||||
await _trigger_clustering_sync(
|
||||
all_entity_nodes,
|
||||
llm_model_id=str(memory_config.llm_model_id) if memory_config.llm_model_id else None,
|
||||
embedding_model_id=str(memory_config.embedding_model_id) if memory_config.embedding_model_id else None,
|
||||
|
||||
@@ -71,13 +71,11 @@ class LabelPropagationEngine:
|
||||
connector: Neo4jConnector,
|
||||
llm_model_id: Optional[str] = None,
|
||||
embedding_model_id: Optional[str] = None,
|
||||
embedding_model_id: Optional[str] = None,
|
||||
):
|
||||
self.connector = connector
|
||||
self.repo = CommunityRepository(connector)
|
||||
self.llm_model_id = llm_model_id
|
||||
self.embedding_model_id = embedding_model_id
|
||||
self.embedding_model_id = embedding_model_id
|
||||
|
||||
# ──────────────────────────────────────────────────────────────────────────
|
||||
# 公开接口
|
||||
@@ -239,6 +237,7 @@ class LabelPropagationEngine:
|
||||
await self.repo.upsert_community(new_cid, end_user_id, member_count=1)
|
||||
await self.repo.assign_entity_to_community(entity_id, new_cid, end_user_id)
|
||||
logger.debug(f"[Clustering] 孤立实体 {entity_id} → 新社区 {new_cid}")
|
||||
await self._generate_community_metadata([new_cid], end_user_id)
|
||||
return
|
||||
|
||||
# 统计邻居社区分布
|
||||
@@ -273,7 +272,8 @@ class LabelPropagationEngine:
|
||||
await self._evaluate_merge(
|
||||
list(community_ids_in_neighbors), end_user_id
|
||||
)
|
||||
await self._generate_community_metadata([target_cid], end_user_id)
|
||||
# 新实体加入后成员变化,强制重新生成元数据
|
||||
await self._generate_community_metadata([target_cid], end_user_id, force=True)
|
||||
|
||||
async def _evaluate_merge(
|
||||
self, community_ids: List[str], end_user_id: str
|
||||
@@ -453,7 +453,7 @@ class LabelPropagationEngine:
|
||||
return lines
|
||||
|
||||
async def _generate_community_metadata(
|
||||
self, community_ids: List[str], end_user_id: str
|
||||
self, community_ids: List[str], end_user_id: str, force: bool = False
|
||||
) -> None:
|
||||
"""
|
||||
为一个或多个社区生成并写入元数据。
|
||||
@@ -462,69 +462,82 @@ class LabelPropagationEngine:
|
||||
1. 逐个社区调 LLM 生成 name / summary(串行)
|
||||
2. 收集所有 summary,一次性批量 embed
|
||||
3. 单个社区用 update_community_metadata,多个用 batch_update_community_metadata
|
||||
"""
|
||||
if not community_ids:
|
||||
return
|
||||
|
||||
Args:
|
||||
force: 为 True 时跳过完整性检查,强制重新生成(用于增量更新成员变化后)
|
||||
"""
|
||||
from app.db import get_db_context
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
|
||||
# --- 阶段1:并发调 LLM 生成每个社区的 name / summary ---
|
||||
async def _build_one(cid: str):
|
||||
members = await self.repo.get_community_members(cid, end_user_id)
|
||||
if not members:
|
||||
async def _build_one(cid: str) -> Optional[Dict]:
|
||||
try:
|
||||
if not force:
|
||||
check_embedding = bool(self.embedding_model_id)
|
||||
if await self.repo.is_community_complete(cid, end_user_id, check_embedding=check_embedding):
|
||||
return None
|
||||
|
||||
members = await self.repo.get_community_members(cid, end_user_id)
|
||||
if not members:
|
||||
logger.warning(f"[Clustering] 社区 {cid} 无成员,跳过元数据生成")
|
||||
return None
|
||||
|
||||
sorted_members = sorted(
|
||||
members,
|
||||
key=lambda m: m.get("activation_value") or 0,
|
||||
reverse=True,
|
||||
)
|
||||
core_entities = [m["name"] for m in sorted_members[:CORE_ENTITY_LIMIT] if m.get("name")]
|
||||
all_names = [m["name"] for m in members if m.get("name")]
|
||||
|
||||
name = "、".join(core_entities[:3]) if core_entities else cid[:8]
|
||||
summary = f"包含实体:{', '.join(all_names)}"
|
||||
|
||||
if self.llm_model_id:
|
||||
try:
|
||||
entity_list_str = "\n".join(self._build_entity_lines(members))
|
||||
relationships = await self.repo.get_community_relationships(cid, end_user_id)
|
||||
rel_lines = [
|
||||
f"- {r['subject']} → {r['predicate']} → {r['object']}"
|
||||
for r in relationships
|
||||
if r.get("subject") and r.get("predicate") and r.get("object")
|
||||
]
|
||||
rel_section = (
|
||||
f"\n实体间关系:\n" + "\n".join(rel_lines)
|
||||
if rel_lines else ""
|
||||
)
|
||||
prompt = (
|
||||
f"以下是一组语义相关的实体:\n{entity_list_str}{rel_section}\n\n"
|
||||
f"请为这组实体所代表的主题:\n"
|
||||
f"1. 起一个简洁的中文名称(不超过10个字)\n"
|
||||
f"2. 写一句话摘要(不超过80个字)\n\n"
|
||||
f"严格按以下格式输出,不要有其他内容:\n"
|
||||
f"名称:<名称>\n摘要:<摘要>"
|
||||
)
|
||||
with get_db_context() as db:
|
||||
llm_client = MemoryClientFactory(db).get_llm_client(self.llm_model_id)
|
||||
response = await llm_client.chat([{"role": "user", "content": prompt}])
|
||||
text = response.content if hasattr(response, "content") else str(response)
|
||||
|
||||
for line in text.strip().splitlines():
|
||||
if line.startswith("名称:"):
|
||||
name = line[3:].strip()
|
||||
elif line.startswith("摘要:"):
|
||||
summary = line[3:].strip()
|
||||
except Exception as e:
|
||||
logger.warning(f"[Clustering] 社区 {cid} LLM 生成失败,使用兜底值: {e}")
|
||||
|
||||
return {
|
||||
"community_id": cid,
|
||||
"end_user_id": end_user_id,
|
||||
"name": name,
|
||||
"summary": summary,
|
||||
"core_entities": core_entities,
|
||||
"summary_embedding": None,
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"[Clustering] 社区 {cid} 元数据准备失败: {e}", exc_info=True)
|
||||
return None
|
||||
|
||||
sorted_members = sorted(
|
||||
members,
|
||||
key=lambda m: m.get("activation_value") or 0,
|
||||
reverse=True,
|
||||
)
|
||||
core_entities = [m["name"] for m in sorted_members[:CORE_ENTITY_LIMIT] if m.get("name")]
|
||||
|
||||
entity_list_str = "\n".join(self._build_entity_lines(members))
|
||||
|
||||
# 方案四:注入社区内实体间关系三元组
|
||||
relationships = await self.repo.get_community_relationships(cid, end_user_id)
|
||||
rel_lines = [
|
||||
f"- {r['subject']} → {r['predicate']} → {r['object']}"
|
||||
for r in relationships
|
||||
if r.get("subject") and r.get("predicate") and r.get("object")
|
||||
]
|
||||
rel_section = (
|
||||
f"\n实体间关系:\n" + "\n".join(rel_lines)
|
||||
if rel_lines else ""
|
||||
)
|
||||
|
||||
prompt = (
|
||||
f"以下是一组语义相关的实体:\n{entity_list_str}{rel_section}\n\n"
|
||||
f"请为这组实体所代表的主题:\n"
|
||||
f"1. 起一个简洁的中文名称(不超过10个字)\n"
|
||||
f"2. 写一句话摘要(不超过80个字)\n\n"
|
||||
f"严格按以下格式输出,不要有其他内容:\n"
|
||||
f"名称:<名称>\n摘要:<摘要>"
|
||||
)
|
||||
with get_db_context() as db:
|
||||
llm_client = MemoryClientFactory(db).get_llm_client(self.llm_model_id)
|
||||
response = await llm_client.chat([{"role": "user", "content": prompt}])
|
||||
text = response.content if hasattr(response, "content") else str(response)
|
||||
|
||||
name, summary = "", ""
|
||||
for line in text.strip().splitlines():
|
||||
if line.startswith("名称:"):
|
||||
name = line[3:].strip()
|
||||
elif line.startswith("摘要:"):
|
||||
summary = line[3:].strip()
|
||||
|
||||
return {
|
||||
"community_id": cid,
|
||||
"end_user_id": end_user_id,
|
||||
"name": name,
|
||||
"summary": summary,
|
||||
"core_entities": core_entities,
|
||||
"summary_embedding": None,
|
||||
}
|
||||
|
||||
results = await asyncio.gather(
|
||||
*[_build_one(cid) for cid in community_ids],
|
||||
return_exceptions=True,
|
||||
@@ -537,15 +550,20 @@ class LabelPropagationEngine:
|
||||
metadata_list.append(res)
|
||||
|
||||
if not metadata_list:
|
||||
logger.warning(f"[Clustering] 无有效元数据可写入,community_ids={community_ids}")
|
||||
return
|
||||
|
||||
# --- 阶段2:批量生成 summary_embedding ---
|
||||
summaries = [m["summary"] for m in metadata_list]
|
||||
with get_db_context() as db:
|
||||
embedder = MemoryClientFactory(db).get_embedder_client(self.embedding_model_id)
|
||||
embeddings = await embedder.response(summaries)
|
||||
for i, meta in enumerate(metadata_list):
|
||||
meta["summary_embedding"] = embeddings[i] if i < len(embeddings) else None
|
||||
if self.embedding_model_id:
|
||||
try:
|
||||
summaries = [m["summary"] for m in metadata_list]
|
||||
with get_db_context() as db:
|
||||
embedder = MemoryClientFactory(db).get_embedder_client(self.embedding_model_id)
|
||||
embeddings = await embedder.response(summaries)
|
||||
for i, meta in enumerate(metadata_list):
|
||||
meta["summary_embedding"] = embeddings[i] if i < len(embeddings) else None
|
||||
except Exception as e:
|
||||
logger.error(f"[Clustering] 批量生成 summary_embedding 失败: {e}", exc_info=True)
|
||||
|
||||
# --- 阶段3:写入(单个 or 批量)---
|
||||
if len(metadata_list) == 1:
|
||||
@@ -558,17 +576,13 @@ class LabelPropagationEngine:
|
||||
core_entities=m["core_entities"],
|
||||
summary_embedding=m["summary_embedding"],
|
||||
)
|
||||
if result:
|
||||
logger.info(f"[Clustering] 社区 {m['community_id']} 元数据写入成功: name={m['name']}, summary={m['summary'][:30]}...")
|
||||
else:
|
||||
logger.warning(f"[Clustering] 社区 {m['community_id']} 元数据写入返回 False")
|
||||
if not result:
|
||||
logger.error(f"[Clustering] 社区 {m['community_id']} 元数据写入失败")
|
||||
else:
|
||||
ok = await self.repo.batch_update_community_metadata(metadata_list)
|
||||
if ok:
|
||||
logger.info(f"[Clustering] 批量写入 {len(metadata_list)} 个社区元数据成功")
|
||||
else:
|
||||
logger.warning(f"[Clustering] 批量写入社区元数据失败")
|
||||
if not ok:
|
||||
logger.error(f"[Clustering] 批量写入 {len(metadata_list)} 个社区元数据失败")
|
||||
|
||||
@staticmethod
|
||||
def _new_community_id() -> str:
|
||||
return str(uuid.uuid4())
|
||||
return str(uuid.uuid4())
|
||||
@@ -9,6 +9,7 @@
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import hashlib
|
||||
import json
|
||||
@@ -26,6 +27,8 @@ from app.core.memory.storage_services.extraction_engine.data_preprocessing.scene
|
||||
ScenePatterns
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DialogExtractionResponse(BaseModel):
|
||||
"""对话级一次性抽取的结构化返回,用于加速剪枝。
|
||||
@@ -706,7 +709,7 @@ class SemanticPruner:
|
||||
# 阈值保护:最高0.9
|
||||
proportion = float(self.config.pruning_threshold)
|
||||
if proportion > 0.9:
|
||||
print(f"[剪枝-数据集] 阈值{proportion}超过上限0.9,已自动调整为0.9")
|
||||
logger.warning(f"[剪枝-数据集] 阈值{proportion}超过上限0.9,已自动调整为0.9")
|
||||
proportion = 0.9
|
||||
if proportion < 0.0:
|
||||
proportion = 0.0
|
||||
@@ -905,7 +908,7 @@ class SemanticPruner:
|
||||
|
||||
# Safety: avoid empty dataset
|
||||
if not result:
|
||||
print("警告: 语义剪枝后数据集为空,已回退为未剪枝数据以避免流程中断")
|
||||
logger.warning("语义剪枝后数据集为空,已回退为未剪枝数据以避免流程中断")
|
||||
return dialogs
|
||||
|
||||
return result
|
||||
@@ -915,8 +918,7 @@ class SemanticPruner:
|
||||
try:
|
||||
self.run_logs.append(msg)
|
||||
except Exception:
|
||||
# 任何异常都不影响打印
|
||||
pass
|
||||
print(msg)
|
||||
logger.debug(msg)
|
||||
|
||||
|
||||
|
||||
@@ -5,8 +5,11 @@
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Any, Dict, List, Tuple
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
|
||||
from app.core.memory.models.message_models import DialogData
|
||||
from app.core.models.base import RedBearModelConfig
|
||||
@@ -48,9 +51,9 @@ class EmbeddingGenerator:
|
||||
return await self.embedder_client.response(texts)
|
||||
|
||||
# 分批并行处理
|
||||
print(f"文本数量 {len(texts)} 超过批次大小 {batch_size},分批并行处理")
|
||||
logger.info(f"文本数量 {len(texts)} 超过批次大小 {batch_size},分批并行处理")
|
||||
batches = [texts[i:i+batch_size] for i in range(0, len(texts), batch_size)]
|
||||
print(f"分成 {len(batches)} 批,每批最多 {batch_size} 个文本")
|
||||
logger.info(f"分成 {len(batches)} 批,每批最多 {batch_size} 个文本")
|
||||
|
||||
# 并行发送所有批次
|
||||
batch_results = await asyncio.gather(*[
|
||||
@@ -62,7 +65,7 @@ class EmbeddingGenerator:
|
||||
for batch_result in batch_results:
|
||||
embeddings.extend(batch_result)
|
||||
|
||||
print(f"分批并行处理完成,共生成 {len(embeddings)} 个嵌入向量")
|
||||
logger.info(f"分批并行处理完成,共生成 {len(embeddings)} 个嵌入向量")
|
||||
return embeddings
|
||||
|
||||
async def generate_statement_embeddings(
|
||||
@@ -77,7 +80,7 @@ class EmbeddingGenerator:
|
||||
Returns:
|
||||
每个对话的陈述句嵌入向量映射列表
|
||||
"""
|
||||
print("\n=== 生成陈述句嵌入向量 ===")
|
||||
logger.debug("=== 生成陈述句嵌入向量 ===")
|
||||
|
||||
# 收集所有陈述句
|
||||
all_statements = []
|
||||
@@ -102,7 +105,7 @@ class EmbeddingGenerator:
|
||||
stmt_id = chunked_dialogs[d_idx].chunks[c_idx].statements[s_idx].id
|
||||
stmt_embedding_maps[d_idx][stmt_id] = embedding
|
||||
|
||||
print(f"为 {len(all_statements)} 个陈述句生成了嵌入向量")
|
||||
logger.info(f"为 {len(all_statements)} 个陈述句生成了嵌入向量")
|
||||
return stmt_embedding_maps
|
||||
|
||||
async def generate_chunk_embeddings(
|
||||
@@ -117,7 +120,7 @@ class EmbeddingGenerator:
|
||||
Returns:
|
||||
每个对话的分块嵌入向量映射列表
|
||||
"""
|
||||
print("\n=== 生成分块嵌入向量 ===")
|
||||
logger.debug("=== 生成分块嵌入向量 ===")
|
||||
|
||||
# 收集所有分块
|
||||
all_chunks = []
|
||||
@@ -138,7 +141,7 @@ class EmbeddingGenerator:
|
||||
chunk_id = chunked_dialogs[d_idx].chunks[c_idx].id
|
||||
chunk_embedding_maps[d_idx][chunk_id] = embedding
|
||||
|
||||
print(f"为 {len(all_chunks)} 个分块生成了嵌入向量")
|
||||
logger.info(f"为 {len(all_chunks)} 个分块生成了嵌入向量")
|
||||
return chunk_embedding_maps
|
||||
|
||||
async def generate_dialog_embeddings(
|
||||
@@ -172,7 +175,7 @@ class EmbeddingGenerator:
|
||||
Returns:
|
||||
(陈述句嵌入映射列表, 分块嵌入映射列表, 对话嵌入列表)
|
||||
"""
|
||||
print("\n=== 生成所有嵌入向量 ===")
|
||||
logger.debug("=== 生成所有嵌入向量 ===")
|
||||
|
||||
# 并发生成陈述句和分块嵌入向量
|
||||
stmt_embedding_maps, chunk_embedding_maps = await asyncio.gather(
|
||||
@@ -183,9 +186,7 @@ class EmbeddingGenerator:
|
||||
# 对话嵌入向量(当前跳过)
|
||||
dialog_embeddings = await self.generate_dialog_embeddings(chunked_dialogs)
|
||||
|
||||
print(
|
||||
f"生成完成:{len(chunked_dialogs)} 个对话的嵌入向量"
|
||||
)
|
||||
logger.info(f"生成完成:{len(chunked_dialogs)} 个对话的嵌入向量")
|
||||
|
||||
return stmt_embedding_maps, chunk_embedding_maps, dialog_embeddings
|
||||
|
||||
@@ -201,7 +202,7 @@ class EmbeddingGenerator:
|
||||
Returns:
|
||||
更新后的三元组映射列表(实体包含嵌入向量)
|
||||
"""
|
||||
print("\n=== 生成实体嵌入向量 ===")
|
||||
logger.debug("=== 生成实体嵌入向量 ===")
|
||||
|
||||
entity_texts: List[str] = []
|
||||
entity_refs: List[Any] = []
|
||||
@@ -219,7 +220,7 @@ class EmbeddingGenerator:
|
||||
entity_refs.append(ent)
|
||||
|
||||
if not entity_texts:
|
||||
print("没有找到需要生成嵌入向量的实体")
|
||||
logger.debug("没有找到需要生成嵌入向量的实体")
|
||||
return triplet_maps
|
||||
|
||||
# 批量生成嵌入向量
|
||||
@@ -227,13 +228,13 @@ class EmbeddingGenerator:
|
||||
|
||||
# 打印前几个嵌入向量的维度
|
||||
for i in range(min(5, len(embeddings))):
|
||||
print(f"实体 '{entity_texts[i]}' 嵌入向量维度: {len(embeddings[i])}")
|
||||
logger.debug(f"实体 '{entity_texts[i]}' 嵌入向量维度: {len(embeddings[i])}")
|
||||
|
||||
# 将嵌入向量赋值给实体
|
||||
for ent, emb in zip(entity_refs, embeddings):
|
||||
setattr(ent, "name_embedding", emb)
|
||||
|
||||
print(f"为 {len(entity_refs)} 个实体生成了嵌入向量")
|
||||
logger.info(f"为 {len(entity_refs)} 个实体生成了嵌入向量")
|
||||
return triplet_maps
|
||||
|
||||
|
||||
@@ -296,7 +297,7 @@ async def embedding_generation_all(
|
||||
Returns:
|
||||
(陈述句嵌入映射列表, 分块嵌入映射列表, 对话嵌入列表, 更新后的三元组映射列表)
|
||||
"""
|
||||
print("\n=== 综合嵌入向量生成(陈述句/分块/对话 + 实体)===")
|
||||
logger.debug("=== 综合嵌入向量生成(陈述句/分块/对话 + 实体)===")
|
||||
|
||||
generator = EmbeddingGenerator(embedding_id)
|
||||
|
||||
|
||||
@@ -1,10 +1,13 @@
|
||||
from typing import List, Optional
|
||||
import logging
|
||||
|
||||
from app.repositories.neo4j.cypher_queries import DIALOGUE_NODE_SAVE, STATEMENT_NODE_SAVE, CHUNK_NODE_SAVE,MEMORY_SUMMARY_NODE_SAVE
|
||||
from app.core.memory.models.graph_models import DialogueNode, StatementNode, ChunkNode, MemorySummaryNode
|
||||
# 使用新的仓储层
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def delete_all_nodes(end_user_id: str, connector: Neo4jConnector):
|
||||
"""Delete all nodes in the database."""
|
||||
@@ -217,10 +220,10 @@ async def add_memory_summary_nodes(summaries: List[MemorySummaryNode], connector
|
||||
summaries=flattened
|
||||
)
|
||||
created_ids = [record.get("uuid") for record in result]
|
||||
print(f"Successfully saved {len(created_ids)} MemorySummary nodes to Neo4j")
|
||||
logger.info(f"Successfully saved {len(created_ids)} MemorySummary nodes to Neo4j")
|
||||
return created_ids
|
||||
except Exception as e:
|
||||
print(f"Failed to save MemorySummary nodes to Neo4j: {e}")
|
||||
logger.error(f"Failed to save MemorySummary nodes to Neo4j: {e}")
|
||||
return None
|
||||
|
||||
|
||||
|
||||
@@ -300,7 +300,7 @@ class CommunityRepository:
|
||||
)
|
||||
return bool(result)
|
||||
except Exception as e:
|
||||
logger.error(f"update_community_metadata failed: {e}")
|
||||
logger.error(f"update_community_metadata failed: {e}", exc_info=True)
|
||||
return False
|
||||
|
||||
async def batch_update_community_metadata(
|
||||
|
||||
@@ -1069,6 +1069,7 @@ Graph_Node_query = """
|
||||
|
||||
COMMUNITY_NODE_UPSERT = """
|
||||
MERGE (c:Community {community_id: $community_id})
|
||||
ON CREATE SET c.id = $community_id
|
||||
SET c.end_user_id = $end_user_id,
|
||||
c.member_count = $member_count,
|
||||
c.updated_at = datetime()
|
||||
@@ -1175,7 +1176,8 @@ RETURN c.community_id AS community_id, cnt AS member_count
|
||||
|
||||
UPDATE_COMMUNITY_METADATA = """
|
||||
MATCH (c:Community {community_id: $community_id, end_user_id: $end_user_id})
|
||||
SET c.name = $name,
|
||||
SET c.id = coalesce(c.id, $community_id),
|
||||
c.name = $name,
|
||||
c.summary = $summary,
|
||||
c.core_entities = $core_entities,
|
||||
c.summary_embedding = $summary_embedding,
|
||||
@@ -1186,7 +1188,8 @@ RETURN c.community_id AS community_id
|
||||
BATCH_UPDATE_COMMUNITY_METADATA = """
|
||||
UNWIND $communities AS row
|
||||
MATCH (c:Community {community_id: row.community_id, end_user_id: row.end_user_id})
|
||||
SET c.name = row.name,
|
||||
SET c.id = coalesce(c.id, row.community_id),
|
||||
c.name = row.name,
|
||||
c.summary = row.summary,
|
||||
c.core_entities = row.core_entities,
|
||||
c.summary_embedding = row.summary_embedding,
|
||||
@@ -1270,6 +1273,40 @@ RETURN
|
||||
startNode(r) = e AS r_from_e
|
||||
"""
|
||||
|
||||
CHECK_COMMUNITY_IS_COMPLETE = """
|
||||
MATCH (c:Community {community_id: $community_id, end_user_id: $end_user_id})
|
||||
RETURN (
|
||||
c.name IS NOT NULL AND c.name <> '' AND
|
||||
c.summary IS NOT NULL AND c.summary <> '' AND
|
||||
c.core_entities IS NOT NULL
|
||||
) AS is_complete
|
||||
"""
|
||||
|
||||
CHECK_COMMUNITY_IS_COMPLETE_WITH_EMBEDDING = """
|
||||
MATCH (c:Community {community_id: $community_id, end_user_id: $end_user_id})
|
||||
RETURN (
|
||||
c.name IS NOT NULL AND c.name <> '' AND
|
||||
c.summary IS NOT NULL AND c.summary <> '' AND
|
||||
c.core_entities IS NOT NULL AND
|
||||
c.summary_embedding IS NOT NULL
|
||||
) AS is_complete
|
||||
"""
|
||||
|
||||
GET_INCOMPLETE_COMMUNITIES = """
|
||||
MATCH (c:Community {end_user_id: $end_user_id})
|
||||
WHERE c.name IS NULL OR c.summary IS NULL OR c.core_entities IS NULL
|
||||
OR c.name = '' OR c.summary = ''
|
||||
RETURN c.community_id AS community_id
|
||||
"""
|
||||
|
||||
GET_INCOMPLETE_COMMUNITIES_WITH_EMBEDDING = """
|
||||
MATCH (c:Community {end_user_id: $end_user_id})
|
||||
WHERE c.name IS NULL OR c.name = ''
|
||||
OR c.summary IS NULL OR c.summary = ''
|
||||
OR c.core_entities IS NULL
|
||||
OR (c.summary_embedding IS NULL AND c.summary IS NOT NULL AND c.summary <> '(empty)')
|
||||
RETURN c.community_id AS community_id
|
||||
"""
|
||||
|
||||
# Community keyword search: matches name or summary via fulltext index
|
||||
SEARCH_COMMUNITIES_BY_KEYWORD = """
|
||||
@@ -1325,39 +1362,4 @@ RETURN s.statement AS statement,
|
||||
c.name AS community_name
|
||||
ORDER BY COALESCE(s.activation_value, 0) DESC
|
||||
LIMIT $limit
|
||||
"""
|
||||
|
||||
CHECK_COMMUNITY_IS_COMPLETE = """
|
||||
MATCH (c:Community {community_id: $community_id, end_user_id: $end_user_id})
|
||||
RETURN (
|
||||
c.name IS NOT NULL AND c.name <> '' AND
|
||||
c.summary IS NOT NULL AND c.summary <> '' AND
|
||||
c.core_entities IS NOT NULL
|
||||
) AS is_complete
|
||||
"""
|
||||
|
||||
CHECK_COMMUNITY_IS_COMPLETE_WITH_EMBEDDING = """
|
||||
MATCH (c:Community {community_id: $community_id, end_user_id: $end_user_id})
|
||||
RETURN (
|
||||
c.name IS NOT NULL AND c.name <> '' AND
|
||||
c.summary IS NOT NULL AND c.summary <> '' AND
|
||||
c.core_entities IS NOT NULL AND
|
||||
c.summary_embedding IS NOT NULL
|
||||
) AS is_complete
|
||||
"""
|
||||
|
||||
GET_INCOMPLETE_COMMUNITIES = """
|
||||
MATCH (c:Community {end_user_id: $end_user_id})
|
||||
WHERE c.name IS NULL OR c.summary IS NULL OR c.core_entities IS NULL
|
||||
OR c.name = '' OR c.summary = ''
|
||||
RETURN c.community_id AS community_id
|
||||
"""
|
||||
|
||||
GET_INCOMPLETE_COMMUNITIES_WITH_EMBEDDING = """
|
||||
MATCH (c:Community {end_user_id: $end_user_id})
|
||||
WHERE c.name IS NULL OR c.name = ''
|
||||
OR c.summary IS NULL OR c.summary = ''
|
||||
OR c.core_entities IS NULL
|
||||
OR (c.summary_embedding IS NULL AND c.summary IS NOT NULL AND c.summary <> '(empty)')
|
||||
RETURN c.community_id AS community_id
|
||||
"""
|
||||
"""
|
||||
@@ -162,7 +162,7 @@ async def save_dialog_and_statements_to_neo4j(
|
||||
"""Save dialogue nodes, chunk nodes, statement nodes, entities, and all relationships to Neo4j using graph models.
|
||||
|
||||
只负责数据写入,不触发聚类。聚类由调用方在写入成功后通过
|
||||
schedule_clustering_after_write() 显式触发。
|
||||
_trigger_clustering_sync() 显式触发。
|
||||
|
||||
Args:
|
||||
dialogue_nodes: List of DialogueNode objects to save
|
||||
@@ -303,16 +303,13 @@ async def save_dialog_and_statements_to_neo4j(
|
||||
return False
|
||||
|
||||
|
||||
def schedule_clustering_after_write(
|
||||
async def _trigger_clustering_sync(
|
||||
entity_nodes: List,
|
||||
llm_model_id: Optional[str] = None,
|
||||
embedding_model_id: Optional[str] = None,
|
||||
) -> None:
|
||||
"""
|
||||
写入 Neo4j 成功后,调度后台聚类任务。
|
||||
|
||||
可通过环境变量 CLUSTERING_ENABLED=false 禁用(用于基准测试对比)。
|
||||
使用 asyncio.create_task 异步触发,不阻塞写入响应。
|
||||
同步等待聚类完成,避免与其他 LLM 任务并发冲突。
|
||||
"""
|
||||
if not entity_nodes:
|
||||
return
|
||||
@@ -324,8 +321,8 @@ def schedule_clustering_after_write(
|
||||
|
||||
end_user_id = entity_nodes[0].end_user_id
|
||||
new_entity_ids = [e.id for e in entity_nodes]
|
||||
logger.info(f"[Clustering] 准备触发聚类,实体数: {len(new_entity_ids)}, end_user_id: {end_user_id}")
|
||||
asyncio.create_task(_trigger_clustering(new_entity_ids, end_user_id, llm_model_id=llm_model_id, embedding_model_id=embedding_model_id))
|
||||
logger.info(f"[Clustering] 准备触发聚类(同步),实体数: {len(new_entity_ids)}, end_user_id: {end_user_id}")
|
||||
await _trigger_clustering(new_entity_ids, end_user_id, llm_model_id=llm_model_id, embedding_model_id=embedding_model_id)
|
||||
|
||||
|
||||
async def _trigger_clustering(
|
||||
|
||||
@@ -350,9 +350,6 @@ class MemoryAgentService:
|
||||
langchain_messages.append(HumanMessage(content=msg['content']))
|
||||
elif msg['role'] == 'assistant':
|
||||
langchain_messages.append(AIMessage(content=msg['content']))
|
||||
print(100 * '-')
|
||||
print(langchain_messages)
|
||||
print(100 * '-')
|
||||
# 初始状态 - 包含所有必要字段
|
||||
initial_state = {
|
||||
"messages": langchain_messages,
|
||||
|
||||
@@ -2760,7 +2760,7 @@ def init_community_clustering_for_users(self, end_user_ids: List[str], workspace
|
||||
patch_fail = 0
|
||||
for cid in incomplete_ids:
|
||||
try:
|
||||
await engine._generate_community_metadata(cid, end_user_id)
|
||||
await engine._generate_community_metadata([cid], end_user_id)
|
||||
patch_ok += 1
|
||||
except Exception as patch_err:
|
||||
patch_fail += 1
|
||||
|
||||
Reference in New Issue
Block a user