【change】 1.Standardize log specifications;2.Cluster settings trigger explicitly
This commit is contained in:
@@ -77,6 +77,7 @@ celery_app.conf.update(
|
|||||||
|
|
||||||
# Worker 设置 (per-worker settings are in docker-compose command line)
|
# Worker 设置 (per-worker settings are in docker-compose command line)
|
||||||
worker_prefetch_multiplier=1, # Don't hoard tasks, fairer distribution
|
worker_prefetch_multiplier=1, # Don't hoard tasks, fairer distribution
|
||||||
|
worker_redirect_stdouts_level='INFO', # stdout/print → INFO instead of WARNING
|
||||||
|
|
||||||
# 结果过期时间
|
# 结果过期时间
|
||||||
result_expires=3600, # 结果保存1小时
|
result_expires=3600, # 结果保存1小时
|
||||||
|
|||||||
@@ -529,8 +529,9 @@ def log_time(step_name: str, duration: float, log_file: str = "logs/time.log") -
|
|||||||
# Fallback to console only if file write fails
|
# Fallback to console only if file write fails
|
||||||
print(f"Warning: Could not write to timing log: {e}")
|
print(f"Warning: Could not write to timing log: {e}")
|
||||||
|
|
||||||
# Always print to console (backward compatible behavior)
|
# Always log at INFO level (avoids Celery treating stdout as WARNING)
|
||||||
print(f"✓ {step_name}: {duration:.2f}s")
|
_timing_logger = logging.getLogger(__name__)
|
||||||
|
_timing_logger.info(f"✓ {step_name}: {duration:.2f}s")
|
||||||
|
|
||||||
|
|
||||||
def get_agent_logger(name: str = "agent_service",
|
def get_agent_logger(name: str = "agent_service",
|
||||||
|
|||||||
@@ -19,7 +19,7 @@ from app.core.memory.utils.log.logging_utils import log_time
|
|||||||
from app.db import get_db_context
|
from app.db import get_db_context
|
||||||
from app.repositories.neo4j.add_edges import add_memory_summary_statement_edges
|
from app.repositories.neo4j.add_edges import add_memory_summary_statement_edges
|
||||||
from app.repositories.neo4j.add_nodes import add_memory_summary_nodes
|
from app.repositories.neo4j.add_nodes import add_memory_summary_nodes
|
||||||
from app.repositories.neo4j.graph_saver import save_dialog_and_statements_to_neo4j, schedule_clustering_after_write
|
from app.repositories.neo4j.graph_saver import save_dialog_and_statements_to_neo4j, _trigger_clustering_sync
|
||||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||||
from app.schemas.memory_config_schema import MemoryConfig
|
from app.schemas.memory_config_schema import MemoryConfig
|
||||||
|
|
||||||
@@ -169,8 +169,8 @@ async def write(
|
|||||||
)
|
)
|
||||||
if success:
|
if success:
|
||||||
logger.info("Successfully saved all data to Neo4j")
|
logger.info("Successfully saved all data to Neo4j")
|
||||||
# 写入成功后,异步触发聚类(不阻塞写入响应)
|
# 写入成功后,同步等待聚类完成(避免与 Memory Summary 并发冲突)
|
||||||
schedule_clustering_after_write(
|
await _trigger_clustering_sync(
|
||||||
all_entity_nodes,
|
all_entity_nodes,
|
||||||
llm_model_id=str(memory_config.llm_model_id) if memory_config.llm_model_id else None,
|
llm_model_id=str(memory_config.llm_model_id) if memory_config.llm_model_id else None,
|
||||||
embedding_model_id=str(memory_config.embedding_model_id) if memory_config.embedding_model_id else None,
|
embedding_model_id=str(memory_config.embedding_model_id) if memory_config.embedding_model_id else None,
|
||||||
|
|||||||
@@ -82,16 +82,26 @@ class OpenAIClient(LLMClient):
|
|||||||
LLMClientException: LLM 调用失败
|
LLMClientException: LLM 调用失败
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
template = """{messages}"""
|
from langchain_core.messages import HumanMessage, SystemMessage, AIMessage
|
||||||
prompt = ChatPromptTemplate.from_template(template)
|
|
||||||
chain = prompt | self.client
|
# 将 dict 消息列表转换为 LangChain 消息对象
|
||||||
|
lc_messages = []
|
||||||
|
for m in messages:
|
||||||
|
role = m.get("role", "user")
|
||||||
|
content = m.get("content", "")
|
||||||
|
if role == "system":
|
||||||
|
lc_messages.append(SystemMessage(content=content))
|
||||||
|
elif role == "assistant":
|
||||||
|
lc_messages.append(AIMessage(content=content))
|
||||||
|
else:
|
||||||
|
lc_messages.append(HumanMessage(content=content))
|
||||||
|
|
||||||
# 添加 Langfuse 回调(如果可用)
|
# 添加 Langfuse 回调(如果可用)
|
||||||
config = {}
|
config = {}
|
||||||
if self.langfuse_handler:
|
if self.langfuse_handler:
|
||||||
config["callbacks"] = [self.langfuse_handler]
|
config["callbacks"] = [self.langfuse_handler]
|
||||||
|
|
||||||
response = await chain.ainvoke({"messages": messages}, config=config)
|
response = await self.client.ainvoke(lc_messages, config=config)
|
||||||
return response
|
return response
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|||||||
@@ -71,13 +71,11 @@ class LabelPropagationEngine:
|
|||||||
connector: Neo4jConnector,
|
connector: Neo4jConnector,
|
||||||
llm_model_id: Optional[str] = None,
|
llm_model_id: Optional[str] = None,
|
||||||
embedding_model_id: Optional[str] = None,
|
embedding_model_id: Optional[str] = None,
|
||||||
embedding_model_id: Optional[str] = None,
|
|
||||||
):
|
):
|
||||||
self.connector = connector
|
self.connector = connector
|
||||||
self.repo = CommunityRepository(connector)
|
self.repo = CommunityRepository(connector)
|
||||||
self.llm_model_id = llm_model_id
|
self.llm_model_id = llm_model_id
|
||||||
self.embedding_model_id = embedding_model_id
|
self.embedding_model_id = embedding_model_id
|
||||||
self.embedding_model_id = embedding_model_id
|
|
||||||
|
|
||||||
# ──────────────────────────────────────────────────────────────────────────
|
# ──────────────────────────────────────────────────────────────────────────
|
||||||
# 公开接口
|
# 公开接口
|
||||||
@@ -239,6 +237,7 @@ class LabelPropagationEngine:
|
|||||||
await self.repo.upsert_community(new_cid, end_user_id, member_count=1)
|
await self.repo.upsert_community(new_cid, end_user_id, member_count=1)
|
||||||
await self.repo.assign_entity_to_community(entity_id, new_cid, end_user_id)
|
await self.repo.assign_entity_to_community(entity_id, new_cid, end_user_id)
|
||||||
logger.debug(f"[Clustering] 孤立实体 {entity_id} → 新社区 {new_cid}")
|
logger.debug(f"[Clustering] 孤立实体 {entity_id} → 新社区 {new_cid}")
|
||||||
|
await self._generate_community_metadata([new_cid], end_user_id)
|
||||||
return
|
return
|
||||||
|
|
||||||
# 统计邻居社区分布
|
# 统计邻居社区分布
|
||||||
@@ -273,7 +272,8 @@ class LabelPropagationEngine:
|
|||||||
await self._evaluate_merge(
|
await self._evaluate_merge(
|
||||||
list(community_ids_in_neighbors), end_user_id
|
list(community_ids_in_neighbors), end_user_id
|
||||||
)
|
)
|
||||||
await self._generate_community_metadata([target_cid], end_user_id)
|
# 新实体加入后成员变化,强制重新生成元数据
|
||||||
|
await self._generate_community_metadata([target_cid], end_user_id, force=True)
|
||||||
|
|
||||||
async def _evaluate_merge(
|
async def _evaluate_merge(
|
||||||
self, community_ids: List[str], end_user_id: str
|
self, community_ids: List[str], end_user_id: str
|
||||||
@@ -453,7 +453,7 @@ class LabelPropagationEngine:
|
|||||||
return lines
|
return lines
|
||||||
|
|
||||||
async def _generate_community_metadata(
|
async def _generate_community_metadata(
|
||||||
self, community_ids: List[str], end_user_id: str
|
self, community_ids: List[str], end_user_id: str, force: bool = False
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
为一个或多个社区生成并写入元数据。
|
为一个或多个社区生成并写入元数据。
|
||||||
@@ -462,69 +462,82 @@ class LabelPropagationEngine:
|
|||||||
1. 逐个社区调 LLM 生成 name / summary(串行)
|
1. 逐个社区调 LLM 生成 name / summary(串行)
|
||||||
2. 收集所有 summary,一次性批量 embed
|
2. 收集所有 summary,一次性批量 embed
|
||||||
3. 单个社区用 update_community_metadata,多个用 batch_update_community_metadata
|
3. 单个社区用 update_community_metadata,多个用 batch_update_community_metadata
|
||||||
"""
|
|
||||||
if not community_ids:
|
|
||||||
return
|
|
||||||
|
|
||||||
|
Args:
|
||||||
|
force: 为 True 时跳过完整性检查,强制重新生成(用于增量更新成员变化后)
|
||||||
|
"""
|
||||||
from app.db import get_db_context
|
from app.db import get_db_context
|
||||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||||
|
|
||||||
# --- 阶段1:并发调 LLM 生成每个社区的 name / summary ---
|
async def _build_one(cid: str) -> Optional[Dict]:
|
||||||
async def _build_one(cid: str):
|
try:
|
||||||
members = await self.repo.get_community_members(cid, end_user_id)
|
if not force:
|
||||||
if not members:
|
check_embedding = bool(self.embedding_model_id)
|
||||||
|
if await self.repo.is_community_complete(cid, end_user_id, check_embedding=check_embedding):
|
||||||
|
return None
|
||||||
|
|
||||||
|
members = await self.repo.get_community_members(cid, end_user_id)
|
||||||
|
if not members:
|
||||||
|
logger.warning(f"[Clustering] 社区 {cid} 无成员,跳过元数据生成")
|
||||||
|
return None
|
||||||
|
|
||||||
|
sorted_members = sorted(
|
||||||
|
members,
|
||||||
|
key=lambda m: m.get("activation_value") or 0,
|
||||||
|
reverse=True,
|
||||||
|
)
|
||||||
|
core_entities = [m["name"] for m in sorted_members[:CORE_ENTITY_LIMIT] if m.get("name")]
|
||||||
|
all_names = [m["name"] for m in members if m.get("name")]
|
||||||
|
|
||||||
|
name = "、".join(core_entities[:3]) if core_entities else cid[:8]
|
||||||
|
summary = f"包含实体:{', '.join(all_names)}"
|
||||||
|
|
||||||
|
if self.llm_model_id:
|
||||||
|
try:
|
||||||
|
entity_list_str = "\n".join(self._build_entity_lines(members))
|
||||||
|
relationships = await self.repo.get_community_relationships(cid, end_user_id)
|
||||||
|
rel_lines = [
|
||||||
|
f"- {r['subject']} → {r['predicate']} → {r['object']}"
|
||||||
|
for r in relationships
|
||||||
|
if r.get("subject") and r.get("predicate") and r.get("object")
|
||||||
|
]
|
||||||
|
rel_section = (
|
||||||
|
f"\n实体间关系:\n" + "\n".join(rel_lines)
|
||||||
|
if rel_lines else ""
|
||||||
|
)
|
||||||
|
prompt = (
|
||||||
|
f"以下是一组语义相关的实体:\n{entity_list_str}{rel_section}\n\n"
|
||||||
|
f"请为这组实体所代表的主题:\n"
|
||||||
|
f"1. 起一个简洁的中文名称(不超过10个字)\n"
|
||||||
|
f"2. 写一句话摘要(不超过80个字)\n\n"
|
||||||
|
f"严格按以下格式输出,不要有其他内容:\n"
|
||||||
|
f"名称:<名称>\n摘要:<摘要>"
|
||||||
|
)
|
||||||
|
with get_db_context() as db:
|
||||||
|
llm_client = MemoryClientFactory(db).get_llm_client(self.llm_model_id)
|
||||||
|
response = await llm_client.chat([{"role": "user", "content": prompt}])
|
||||||
|
text = response.content if hasattr(response, "content") else str(response)
|
||||||
|
|
||||||
|
for line in text.strip().splitlines():
|
||||||
|
if line.startswith("名称:"):
|
||||||
|
name = line[3:].strip()
|
||||||
|
elif line.startswith("摘要:"):
|
||||||
|
summary = line[3:].strip()
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"[Clustering] 社区 {cid} LLM 生成失败,使用兜底值: {e}")
|
||||||
|
|
||||||
|
return {
|
||||||
|
"community_id": cid,
|
||||||
|
"end_user_id": end_user_id,
|
||||||
|
"name": name,
|
||||||
|
"summary": summary,
|
||||||
|
"core_entities": core_entities,
|
||||||
|
"summary_embedding": None,
|
||||||
|
}
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"[Clustering] 社区 {cid} 元数据准备失败: {e}", exc_info=True)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
sorted_members = sorted(
|
|
||||||
members,
|
|
||||||
key=lambda m: m.get("activation_value") or 0,
|
|
||||||
reverse=True,
|
|
||||||
)
|
|
||||||
core_entities = [m["name"] for m in sorted_members[:CORE_ENTITY_LIMIT] if m.get("name")]
|
|
||||||
|
|
||||||
entity_list_str = "\n".join(self._build_entity_lines(members))
|
|
||||||
|
|
||||||
# 方案四:注入社区内实体间关系三元组
|
|
||||||
relationships = await self.repo.get_community_relationships(cid, end_user_id)
|
|
||||||
rel_lines = [
|
|
||||||
f"- {r['subject']} → {r['predicate']} → {r['object']}"
|
|
||||||
for r in relationships
|
|
||||||
if r.get("subject") and r.get("predicate") and r.get("object")
|
|
||||||
]
|
|
||||||
rel_section = (
|
|
||||||
f"\n实体间关系:\n" + "\n".join(rel_lines)
|
|
||||||
if rel_lines else ""
|
|
||||||
)
|
|
||||||
|
|
||||||
prompt = (
|
|
||||||
f"以下是一组语义相关的实体:\n{entity_list_str}{rel_section}\n\n"
|
|
||||||
f"请为这组实体所代表的主题:\n"
|
|
||||||
f"1. 起一个简洁的中文名称(不超过10个字)\n"
|
|
||||||
f"2. 写一句话摘要(不超过80个字)\n\n"
|
|
||||||
f"严格按以下格式输出,不要有其他内容:\n"
|
|
||||||
f"名称:<名称>\n摘要:<摘要>"
|
|
||||||
)
|
|
||||||
with get_db_context() as db:
|
|
||||||
llm_client = MemoryClientFactory(db).get_llm_client(self.llm_model_id)
|
|
||||||
response = await llm_client.chat([{"role": "user", "content": prompt}])
|
|
||||||
text = response.content if hasattr(response, "content") else str(response)
|
|
||||||
|
|
||||||
name, summary = "", ""
|
|
||||||
for line in text.strip().splitlines():
|
|
||||||
if line.startswith("名称:"):
|
|
||||||
name = line[3:].strip()
|
|
||||||
elif line.startswith("摘要:"):
|
|
||||||
summary = line[3:].strip()
|
|
||||||
|
|
||||||
return {
|
|
||||||
"community_id": cid,
|
|
||||||
"end_user_id": end_user_id,
|
|
||||||
"name": name,
|
|
||||||
"summary": summary,
|
|
||||||
"core_entities": core_entities,
|
|
||||||
"summary_embedding": None,
|
|
||||||
}
|
|
||||||
|
|
||||||
results = await asyncio.gather(
|
results = await asyncio.gather(
|
||||||
*[_build_one(cid) for cid in community_ids],
|
*[_build_one(cid) for cid in community_ids],
|
||||||
return_exceptions=True,
|
return_exceptions=True,
|
||||||
@@ -537,15 +550,20 @@ class LabelPropagationEngine:
|
|||||||
metadata_list.append(res)
|
metadata_list.append(res)
|
||||||
|
|
||||||
if not metadata_list:
|
if not metadata_list:
|
||||||
|
logger.warning(f"[Clustering] 无有效元数据可写入,community_ids={community_ids}")
|
||||||
return
|
return
|
||||||
|
|
||||||
# --- 阶段2:批量生成 summary_embedding ---
|
# --- 阶段2:批量生成 summary_embedding ---
|
||||||
summaries = [m["summary"] for m in metadata_list]
|
if self.embedding_model_id:
|
||||||
with get_db_context() as db:
|
try:
|
||||||
embedder = MemoryClientFactory(db).get_embedder_client(self.embedding_model_id)
|
summaries = [m["summary"] for m in metadata_list]
|
||||||
embeddings = await embedder.response(summaries)
|
with get_db_context() as db:
|
||||||
for i, meta in enumerate(metadata_list):
|
embedder = MemoryClientFactory(db).get_embedder_client(self.embedding_model_id)
|
||||||
meta["summary_embedding"] = embeddings[i] if i < len(embeddings) else None
|
embeddings = await embedder.response(summaries)
|
||||||
|
for i, meta in enumerate(metadata_list):
|
||||||
|
meta["summary_embedding"] = embeddings[i] if i < len(embeddings) else None
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"[Clustering] 批量生成 summary_embedding 失败: {e}", exc_info=True)
|
||||||
|
|
||||||
# --- 阶段3:写入(单个 or 批量)---
|
# --- 阶段3:写入(单个 or 批量)---
|
||||||
if len(metadata_list) == 1:
|
if len(metadata_list) == 1:
|
||||||
@@ -558,17 +576,13 @@ class LabelPropagationEngine:
|
|||||||
core_entities=m["core_entities"],
|
core_entities=m["core_entities"],
|
||||||
summary_embedding=m["summary_embedding"],
|
summary_embedding=m["summary_embedding"],
|
||||||
)
|
)
|
||||||
if result:
|
if not result:
|
||||||
logger.info(f"[Clustering] 社区 {m['community_id']} 元数据写入成功: name={m['name']}, summary={m['summary'][:30]}...")
|
logger.error(f"[Clustering] 社区 {m['community_id']} 元数据写入失败")
|
||||||
else:
|
|
||||||
logger.warning(f"[Clustering] 社区 {m['community_id']} 元数据写入返回 False")
|
|
||||||
else:
|
else:
|
||||||
ok = await self.repo.batch_update_community_metadata(metadata_list)
|
ok = await self.repo.batch_update_community_metadata(metadata_list)
|
||||||
if ok:
|
if not ok:
|
||||||
logger.info(f"[Clustering] 批量写入 {len(metadata_list)} 个社区元数据成功")
|
logger.error(f"[Clustering] 批量写入 {len(metadata_list)} 个社区元数据失败")
|
||||||
else:
|
|
||||||
logger.warning(f"[Clustering] 批量写入社区元数据失败")
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _new_community_id() -> str:
|
def _new_community_id() -> str:
|
||||||
return str(uuid.uuid4())
|
return str(uuid.uuid4())
|
||||||
@@ -9,6 +9,7 @@
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import logging
|
||||||
import os
|
import os
|
||||||
import hashlib
|
import hashlib
|
||||||
import json
|
import json
|
||||||
@@ -26,6 +27,8 @@ from app.core.memory.storage_services.extraction_engine.data_preprocessing.scene
|
|||||||
ScenePatterns
|
ScenePatterns
|
||||||
)
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class DialogExtractionResponse(BaseModel):
|
class DialogExtractionResponse(BaseModel):
|
||||||
"""对话级一次性抽取的结构化返回,用于加速剪枝。
|
"""对话级一次性抽取的结构化返回,用于加速剪枝。
|
||||||
@@ -706,7 +709,7 @@ class SemanticPruner:
|
|||||||
# 阈值保护:最高0.9
|
# 阈值保护:最高0.9
|
||||||
proportion = float(self.config.pruning_threshold)
|
proportion = float(self.config.pruning_threshold)
|
||||||
if proportion > 0.9:
|
if proportion > 0.9:
|
||||||
print(f"[剪枝-数据集] 阈值{proportion}超过上限0.9,已自动调整为0.9")
|
logger.warning(f"[剪枝-数据集] 阈值{proportion}超过上限0.9,已自动调整为0.9")
|
||||||
proportion = 0.9
|
proportion = 0.9
|
||||||
if proportion < 0.0:
|
if proportion < 0.0:
|
||||||
proportion = 0.0
|
proportion = 0.0
|
||||||
@@ -905,7 +908,7 @@ class SemanticPruner:
|
|||||||
|
|
||||||
# Safety: avoid empty dataset
|
# Safety: avoid empty dataset
|
||||||
if not result:
|
if not result:
|
||||||
print("警告: 语义剪枝后数据集为空,已回退为未剪枝数据以避免流程中断")
|
logger.warning("语义剪枝后数据集为空,已回退为未剪枝数据以避免流程中断")
|
||||||
return dialogs
|
return dialogs
|
||||||
|
|
||||||
return result
|
return result
|
||||||
@@ -915,8 +918,7 @@ class SemanticPruner:
|
|||||||
try:
|
try:
|
||||||
self.run_logs.append(msg)
|
self.run_logs.append(msg)
|
||||||
except Exception:
|
except Exception:
|
||||||
# 任何异常都不影响打印
|
|
||||||
pass
|
pass
|
||||||
print(msg)
|
logger.debug(msg)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -5,8 +5,11 @@
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import logging
|
||||||
from typing import Any, Dict, List, Tuple
|
from typing import Any, Dict, List, Tuple
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
|
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
|
||||||
from app.core.memory.models.message_models import DialogData
|
from app.core.memory.models.message_models import DialogData
|
||||||
from app.core.models.base import RedBearModelConfig
|
from app.core.models.base import RedBearModelConfig
|
||||||
@@ -48,9 +51,9 @@ class EmbeddingGenerator:
|
|||||||
return await self.embedder_client.response(texts)
|
return await self.embedder_client.response(texts)
|
||||||
|
|
||||||
# 分批并行处理
|
# 分批并行处理
|
||||||
print(f"文本数量 {len(texts)} 超过批次大小 {batch_size},分批并行处理")
|
logger.info(f"文本数量 {len(texts)} 超过批次大小 {batch_size},分批并行处理")
|
||||||
batches = [texts[i:i+batch_size] for i in range(0, len(texts), batch_size)]
|
batches = [texts[i:i+batch_size] for i in range(0, len(texts), batch_size)]
|
||||||
print(f"分成 {len(batches)} 批,每批最多 {batch_size} 个文本")
|
logger.info(f"分成 {len(batches)} 批,每批最多 {batch_size} 个文本")
|
||||||
|
|
||||||
# 并行发送所有批次
|
# 并行发送所有批次
|
||||||
batch_results = await asyncio.gather(*[
|
batch_results = await asyncio.gather(*[
|
||||||
@@ -62,7 +65,7 @@ class EmbeddingGenerator:
|
|||||||
for batch_result in batch_results:
|
for batch_result in batch_results:
|
||||||
embeddings.extend(batch_result)
|
embeddings.extend(batch_result)
|
||||||
|
|
||||||
print(f"分批并行处理完成,共生成 {len(embeddings)} 个嵌入向量")
|
logger.info(f"分批并行处理完成,共生成 {len(embeddings)} 个嵌入向量")
|
||||||
return embeddings
|
return embeddings
|
||||||
|
|
||||||
async def generate_statement_embeddings(
|
async def generate_statement_embeddings(
|
||||||
@@ -77,7 +80,7 @@ class EmbeddingGenerator:
|
|||||||
Returns:
|
Returns:
|
||||||
每个对话的陈述句嵌入向量映射列表
|
每个对话的陈述句嵌入向量映射列表
|
||||||
"""
|
"""
|
||||||
print("\n=== 生成陈述句嵌入向量 ===")
|
logger.debug("=== 生成陈述句嵌入向量 ===")
|
||||||
|
|
||||||
# 收集所有陈述句
|
# 收集所有陈述句
|
||||||
all_statements = []
|
all_statements = []
|
||||||
@@ -102,7 +105,7 @@ class EmbeddingGenerator:
|
|||||||
stmt_id = chunked_dialogs[d_idx].chunks[c_idx].statements[s_idx].id
|
stmt_id = chunked_dialogs[d_idx].chunks[c_idx].statements[s_idx].id
|
||||||
stmt_embedding_maps[d_idx][stmt_id] = embedding
|
stmt_embedding_maps[d_idx][stmt_id] = embedding
|
||||||
|
|
||||||
print(f"为 {len(all_statements)} 个陈述句生成了嵌入向量")
|
logger.info(f"为 {len(all_statements)} 个陈述句生成了嵌入向量")
|
||||||
return stmt_embedding_maps
|
return stmt_embedding_maps
|
||||||
|
|
||||||
async def generate_chunk_embeddings(
|
async def generate_chunk_embeddings(
|
||||||
@@ -117,7 +120,7 @@ class EmbeddingGenerator:
|
|||||||
Returns:
|
Returns:
|
||||||
每个对话的分块嵌入向量映射列表
|
每个对话的分块嵌入向量映射列表
|
||||||
"""
|
"""
|
||||||
print("\n=== 生成分块嵌入向量 ===")
|
logger.debug("=== 生成分块嵌入向量 ===")
|
||||||
|
|
||||||
# 收集所有分块
|
# 收集所有分块
|
||||||
all_chunks = []
|
all_chunks = []
|
||||||
@@ -138,7 +141,7 @@ class EmbeddingGenerator:
|
|||||||
chunk_id = chunked_dialogs[d_idx].chunks[c_idx].id
|
chunk_id = chunked_dialogs[d_idx].chunks[c_idx].id
|
||||||
chunk_embedding_maps[d_idx][chunk_id] = embedding
|
chunk_embedding_maps[d_idx][chunk_id] = embedding
|
||||||
|
|
||||||
print(f"为 {len(all_chunks)} 个分块生成了嵌入向量")
|
logger.info(f"为 {len(all_chunks)} 个分块生成了嵌入向量")
|
||||||
return chunk_embedding_maps
|
return chunk_embedding_maps
|
||||||
|
|
||||||
async def generate_dialog_embeddings(
|
async def generate_dialog_embeddings(
|
||||||
@@ -172,7 +175,7 @@ class EmbeddingGenerator:
|
|||||||
Returns:
|
Returns:
|
||||||
(陈述句嵌入映射列表, 分块嵌入映射列表, 对话嵌入列表)
|
(陈述句嵌入映射列表, 分块嵌入映射列表, 对话嵌入列表)
|
||||||
"""
|
"""
|
||||||
print("\n=== 生成所有嵌入向量 ===")
|
logger.debug("=== 生成所有嵌入向量 ===")
|
||||||
|
|
||||||
# 并发生成陈述句和分块嵌入向量
|
# 并发生成陈述句和分块嵌入向量
|
||||||
stmt_embedding_maps, chunk_embedding_maps = await asyncio.gather(
|
stmt_embedding_maps, chunk_embedding_maps = await asyncio.gather(
|
||||||
@@ -183,9 +186,7 @@ class EmbeddingGenerator:
|
|||||||
# 对话嵌入向量(当前跳过)
|
# 对话嵌入向量(当前跳过)
|
||||||
dialog_embeddings = await self.generate_dialog_embeddings(chunked_dialogs)
|
dialog_embeddings = await self.generate_dialog_embeddings(chunked_dialogs)
|
||||||
|
|
||||||
print(
|
logger.info(f"生成完成:{len(chunked_dialogs)} 个对话的嵌入向量")
|
||||||
f"生成完成:{len(chunked_dialogs)} 个对话的嵌入向量"
|
|
||||||
)
|
|
||||||
|
|
||||||
return stmt_embedding_maps, chunk_embedding_maps, dialog_embeddings
|
return stmt_embedding_maps, chunk_embedding_maps, dialog_embeddings
|
||||||
|
|
||||||
@@ -201,7 +202,7 @@ class EmbeddingGenerator:
|
|||||||
Returns:
|
Returns:
|
||||||
更新后的三元组映射列表(实体包含嵌入向量)
|
更新后的三元组映射列表(实体包含嵌入向量)
|
||||||
"""
|
"""
|
||||||
print("\n=== 生成实体嵌入向量 ===")
|
logger.debug("=== 生成实体嵌入向量 ===")
|
||||||
|
|
||||||
entity_texts: List[str] = []
|
entity_texts: List[str] = []
|
||||||
entity_refs: List[Any] = []
|
entity_refs: List[Any] = []
|
||||||
@@ -219,7 +220,7 @@ class EmbeddingGenerator:
|
|||||||
entity_refs.append(ent)
|
entity_refs.append(ent)
|
||||||
|
|
||||||
if not entity_texts:
|
if not entity_texts:
|
||||||
print("没有找到需要生成嵌入向量的实体")
|
logger.debug("没有找到需要生成嵌入向量的实体")
|
||||||
return triplet_maps
|
return triplet_maps
|
||||||
|
|
||||||
# 批量生成嵌入向量
|
# 批量生成嵌入向量
|
||||||
@@ -227,13 +228,13 @@ class EmbeddingGenerator:
|
|||||||
|
|
||||||
# 打印前几个嵌入向量的维度
|
# 打印前几个嵌入向量的维度
|
||||||
for i in range(min(5, len(embeddings))):
|
for i in range(min(5, len(embeddings))):
|
||||||
print(f"实体 '{entity_texts[i]}' 嵌入向量维度: {len(embeddings[i])}")
|
logger.debug(f"实体 '{entity_texts[i]}' 嵌入向量维度: {len(embeddings[i])}")
|
||||||
|
|
||||||
# 将嵌入向量赋值给实体
|
# 将嵌入向量赋值给实体
|
||||||
for ent, emb in zip(entity_refs, embeddings):
|
for ent, emb in zip(entity_refs, embeddings):
|
||||||
setattr(ent, "name_embedding", emb)
|
setattr(ent, "name_embedding", emb)
|
||||||
|
|
||||||
print(f"为 {len(entity_refs)} 个实体生成了嵌入向量")
|
logger.info(f"为 {len(entity_refs)} 个实体生成了嵌入向量")
|
||||||
return triplet_maps
|
return triplet_maps
|
||||||
|
|
||||||
|
|
||||||
@@ -296,7 +297,7 @@ async def embedding_generation_all(
|
|||||||
Returns:
|
Returns:
|
||||||
(陈述句嵌入映射列表, 分块嵌入映射列表, 对话嵌入列表, 更新后的三元组映射列表)
|
(陈述句嵌入映射列表, 分块嵌入映射列表, 对话嵌入列表, 更新后的三元组映射列表)
|
||||||
"""
|
"""
|
||||||
print("\n=== 综合嵌入向量生成(陈述句/分块/对话 + 实体)===")
|
logger.debug("=== 综合嵌入向量生成(陈述句/分块/对话 + 实体)===")
|
||||||
|
|
||||||
generator = EmbeddingGenerator(embedding_id)
|
generator = EmbeddingGenerator(embedding_id)
|
||||||
|
|
||||||
|
|||||||
@@ -1,10 +1,13 @@
|
|||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
import logging
|
||||||
|
|
||||||
from app.repositories.neo4j.cypher_queries import DIALOGUE_NODE_SAVE, STATEMENT_NODE_SAVE, CHUNK_NODE_SAVE,MEMORY_SUMMARY_NODE_SAVE
|
from app.repositories.neo4j.cypher_queries import DIALOGUE_NODE_SAVE, STATEMENT_NODE_SAVE, CHUNK_NODE_SAVE,MEMORY_SUMMARY_NODE_SAVE
|
||||||
from app.core.memory.models.graph_models import DialogueNode, StatementNode, ChunkNode, MemorySummaryNode
|
from app.core.memory.models.graph_models import DialogueNode, StatementNode, ChunkNode, MemorySummaryNode
|
||||||
# 使用新的仓储层
|
# 使用新的仓储层
|
||||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
async def delete_all_nodes(end_user_id: str, connector: Neo4jConnector):
|
async def delete_all_nodes(end_user_id: str, connector: Neo4jConnector):
|
||||||
"""Delete all nodes in the database."""
|
"""Delete all nodes in the database."""
|
||||||
@@ -217,10 +220,10 @@ async def add_memory_summary_nodes(summaries: List[MemorySummaryNode], connector
|
|||||||
summaries=flattened
|
summaries=flattened
|
||||||
)
|
)
|
||||||
created_ids = [record.get("uuid") for record in result]
|
created_ids = [record.get("uuid") for record in result]
|
||||||
print(f"Successfully saved {len(created_ids)} MemorySummary nodes to Neo4j")
|
logger.info(f"Successfully saved {len(created_ids)} MemorySummary nodes to Neo4j")
|
||||||
return created_ids
|
return created_ids
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Failed to save MemorySummary nodes to Neo4j: {e}")
|
logger.error(f"Failed to save MemorySummary nodes to Neo4j: {e}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -300,7 +300,7 @@ class CommunityRepository:
|
|||||||
)
|
)
|
||||||
return bool(result)
|
return bool(result)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"update_community_metadata failed: {e}")
|
logger.error(f"update_community_metadata failed: {e}", exc_info=True)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
async def batch_update_community_metadata(
|
async def batch_update_community_metadata(
|
||||||
|
|||||||
@@ -1069,6 +1069,7 @@ Graph_Node_query = """
|
|||||||
|
|
||||||
COMMUNITY_NODE_UPSERT = """
|
COMMUNITY_NODE_UPSERT = """
|
||||||
MERGE (c:Community {community_id: $community_id})
|
MERGE (c:Community {community_id: $community_id})
|
||||||
|
ON CREATE SET c.id = $community_id
|
||||||
SET c.end_user_id = $end_user_id,
|
SET c.end_user_id = $end_user_id,
|
||||||
c.member_count = $member_count,
|
c.member_count = $member_count,
|
||||||
c.updated_at = datetime()
|
c.updated_at = datetime()
|
||||||
@@ -1175,7 +1176,8 @@ RETURN c.community_id AS community_id, cnt AS member_count
|
|||||||
|
|
||||||
UPDATE_COMMUNITY_METADATA = """
|
UPDATE_COMMUNITY_METADATA = """
|
||||||
MATCH (c:Community {community_id: $community_id, end_user_id: $end_user_id})
|
MATCH (c:Community {community_id: $community_id, end_user_id: $end_user_id})
|
||||||
SET c.name = $name,
|
SET c.id = coalesce(c.id, $community_id),
|
||||||
|
c.name = $name,
|
||||||
c.summary = $summary,
|
c.summary = $summary,
|
||||||
c.core_entities = $core_entities,
|
c.core_entities = $core_entities,
|
||||||
c.summary_embedding = $summary_embedding,
|
c.summary_embedding = $summary_embedding,
|
||||||
@@ -1186,7 +1188,8 @@ RETURN c.community_id AS community_id
|
|||||||
BATCH_UPDATE_COMMUNITY_METADATA = """
|
BATCH_UPDATE_COMMUNITY_METADATA = """
|
||||||
UNWIND $communities AS row
|
UNWIND $communities AS row
|
||||||
MATCH (c:Community {community_id: row.community_id, end_user_id: row.end_user_id})
|
MATCH (c:Community {community_id: row.community_id, end_user_id: row.end_user_id})
|
||||||
SET c.name = row.name,
|
SET c.id = coalesce(c.id, row.community_id),
|
||||||
|
c.name = row.name,
|
||||||
c.summary = row.summary,
|
c.summary = row.summary,
|
||||||
c.core_entities = row.core_entities,
|
c.core_entities = row.core_entities,
|
||||||
c.summary_embedding = row.summary_embedding,
|
c.summary_embedding = row.summary_embedding,
|
||||||
@@ -1270,6 +1273,40 @@ RETURN
|
|||||||
startNode(r) = e AS r_from_e
|
startNode(r) = e AS r_from_e
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
CHECK_COMMUNITY_IS_COMPLETE = """
|
||||||
|
MATCH (c:Community {community_id: $community_id, end_user_id: $end_user_id})
|
||||||
|
RETURN (
|
||||||
|
c.name IS NOT NULL AND c.name <> '' AND
|
||||||
|
c.summary IS NOT NULL AND c.summary <> '' AND
|
||||||
|
c.core_entities IS NOT NULL
|
||||||
|
) AS is_complete
|
||||||
|
"""
|
||||||
|
|
||||||
|
CHECK_COMMUNITY_IS_COMPLETE_WITH_EMBEDDING = """
|
||||||
|
MATCH (c:Community {community_id: $community_id, end_user_id: $end_user_id})
|
||||||
|
RETURN (
|
||||||
|
c.name IS NOT NULL AND c.name <> '' AND
|
||||||
|
c.summary IS NOT NULL AND c.summary <> '' AND
|
||||||
|
c.core_entities IS NOT NULL AND
|
||||||
|
c.summary_embedding IS NOT NULL
|
||||||
|
) AS is_complete
|
||||||
|
"""
|
||||||
|
|
||||||
|
GET_INCOMPLETE_COMMUNITIES = """
|
||||||
|
MATCH (c:Community {end_user_id: $end_user_id})
|
||||||
|
WHERE c.name IS NULL OR c.summary IS NULL OR c.core_entities IS NULL
|
||||||
|
OR c.name = '' OR c.summary = ''
|
||||||
|
RETURN c.community_id AS community_id
|
||||||
|
"""
|
||||||
|
|
||||||
|
GET_INCOMPLETE_COMMUNITIES_WITH_EMBEDDING = """
|
||||||
|
MATCH (c:Community {end_user_id: $end_user_id})
|
||||||
|
WHERE c.name IS NULL OR c.name = ''
|
||||||
|
OR c.summary IS NULL OR c.summary = ''
|
||||||
|
OR c.core_entities IS NULL
|
||||||
|
OR (c.summary_embedding IS NULL AND c.summary IS NOT NULL AND c.summary <> '(empty)')
|
||||||
|
RETURN c.community_id AS community_id
|
||||||
|
"""
|
||||||
|
|
||||||
# Community keyword search: matches name or summary via fulltext index
|
# Community keyword search: matches name or summary via fulltext index
|
||||||
SEARCH_COMMUNITIES_BY_KEYWORD = """
|
SEARCH_COMMUNITIES_BY_KEYWORD = """
|
||||||
@@ -1325,39 +1362,4 @@ RETURN s.statement AS statement,
|
|||||||
c.name AS community_name
|
c.name AS community_name
|
||||||
ORDER BY COALESCE(s.activation_value, 0) DESC
|
ORDER BY COALESCE(s.activation_value, 0) DESC
|
||||||
LIMIT $limit
|
LIMIT $limit
|
||||||
"""
|
"""
|
||||||
|
|
||||||
CHECK_COMMUNITY_IS_COMPLETE = """
|
|
||||||
MATCH (c:Community {community_id: $community_id, end_user_id: $end_user_id})
|
|
||||||
RETURN (
|
|
||||||
c.name IS NOT NULL AND c.name <> '' AND
|
|
||||||
c.summary IS NOT NULL AND c.summary <> '' AND
|
|
||||||
c.core_entities IS NOT NULL
|
|
||||||
) AS is_complete
|
|
||||||
"""
|
|
||||||
|
|
||||||
CHECK_COMMUNITY_IS_COMPLETE_WITH_EMBEDDING = """
|
|
||||||
MATCH (c:Community {community_id: $community_id, end_user_id: $end_user_id})
|
|
||||||
RETURN (
|
|
||||||
c.name IS NOT NULL AND c.name <> '' AND
|
|
||||||
c.summary IS NOT NULL AND c.summary <> '' AND
|
|
||||||
c.core_entities IS NOT NULL AND
|
|
||||||
c.summary_embedding IS NOT NULL
|
|
||||||
) AS is_complete
|
|
||||||
"""
|
|
||||||
|
|
||||||
GET_INCOMPLETE_COMMUNITIES = """
|
|
||||||
MATCH (c:Community {end_user_id: $end_user_id})
|
|
||||||
WHERE c.name IS NULL OR c.summary IS NULL OR c.core_entities IS NULL
|
|
||||||
OR c.name = '' OR c.summary = ''
|
|
||||||
RETURN c.community_id AS community_id
|
|
||||||
"""
|
|
||||||
|
|
||||||
GET_INCOMPLETE_COMMUNITIES_WITH_EMBEDDING = """
|
|
||||||
MATCH (c:Community {end_user_id: $end_user_id})
|
|
||||||
WHERE c.name IS NULL OR c.name = ''
|
|
||||||
OR c.summary IS NULL OR c.summary = ''
|
|
||||||
OR c.core_entities IS NULL
|
|
||||||
OR (c.summary_embedding IS NULL AND c.summary IS NOT NULL AND c.summary <> '(empty)')
|
|
||||||
RETURN c.community_id AS community_id
|
|
||||||
"""
|
|
||||||
@@ -162,7 +162,7 @@ async def save_dialog_and_statements_to_neo4j(
|
|||||||
"""Save dialogue nodes, chunk nodes, statement nodes, entities, and all relationships to Neo4j using graph models.
|
"""Save dialogue nodes, chunk nodes, statement nodes, entities, and all relationships to Neo4j using graph models.
|
||||||
|
|
||||||
只负责数据写入,不触发聚类。聚类由调用方在写入成功后通过
|
只负责数据写入,不触发聚类。聚类由调用方在写入成功后通过
|
||||||
schedule_clustering_after_write() 显式触发。
|
_trigger_clustering_sync() 显式触发。
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
dialogue_nodes: List of DialogueNode objects to save
|
dialogue_nodes: List of DialogueNode objects to save
|
||||||
@@ -303,16 +303,13 @@ async def save_dialog_and_statements_to_neo4j(
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
def schedule_clustering_after_write(
|
async def _trigger_clustering_sync(
|
||||||
entity_nodes: List,
|
entity_nodes: List,
|
||||||
llm_model_id: Optional[str] = None,
|
llm_model_id: Optional[str] = None,
|
||||||
embedding_model_id: Optional[str] = None,
|
embedding_model_id: Optional[str] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
写入 Neo4j 成功后,调度后台聚类任务。
|
同步等待聚类完成,避免与其他 LLM 任务并发冲突。
|
||||||
|
|
||||||
可通过环境变量 CLUSTERING_ENABLED=false 禁用(用于基准测试对比)。
|
|
||||||
使用 asyncio.create_task 异步触发,不阻塞写入响应。
|
|
||||||
"""
|
"""
|
||||||
if not entity_nodes:
|
if not entity_nodes:
|
||||||
return
|
return
|
||||||
@@ -324,8 +321,8 @@ def schedule_clustering_after_write(
|
|||||||
|
|
||||||
end_user_id = entity_nodes[0].end_user_id
|
end_user_id = entity_nodes[0].end_user_id
|
||||||
new_entity_ids = [e.id for e in entity_nodes]
|
new_entity_ids = [e.id for e in entity_nodes]
|
||||||
logger.info(f"[Clustering] 准备触发聚类,实体数: {len(new_entity_ids)}, end_user_id: {end_user_id}")
|
logger.info(f"[Clustering] 准备触发聚类(同步),实体数: {len(new_entity_ids)}, end_user_id: {end_user_id}")
|
||||||
asyncio.create_task(_trigger_clustering(new_entity_ids, end_user_id, llm_model_id=llm_model_id, embedding_model_id=embedding_model_id))
|
await _trigger_clustering(new_entity_ids, end_user_id, llm_model_id=llm_model_id, embedding_model_id=embedding_model_id)
|
||||||
|
|
||||||
|
|
||||||
async def _trigger_clustering(
|
async def _trigger_clustering(
|
||||||
|
|||||||
@@ -350,9 +350,6 @@ class MemoryAgentService:
|
|||||||
langchain_messages.append(HumanMessage(content=msg['content']))
|
langchain_messages.append(HumanMessage(content=msg['content']))
|
||||||
elif msg['role'] == 'assistant':
|
elif msg['role'] == 'assistant':
|
||||||
langchain_messages.append(AIMessage(content=msg['content']))
|
langchain_messages.append(AIMessage(content=msg['content']))
|
||||||
print(100 * '-')
|
|
||||||
print(langchain_messages)
|
|
||||||
print(100 * '-')
|
|
||||||
# 初始状态 - 包含所有必要字段
|
# 初始状态 - 包含所有必要字段
|
||||||
initial_state = {
|
initial_state = {
|
||||||
"messages": langchain_messages,
|
"messages": langchain_messages,
|
||||||
|
|||||||
@@ -2760,7 +2760,7 @@ def init_community_clustering_for_users(self, end_user_ids: List[str], workspace
|
|||||||
patch_fail = 0
|
patch_fail = 0
|
||||||
for cid in incomplete_ids:
|
for cid in incomplete_ids:
|
||||||
try:
|
try:
|
||||||
await engine._generate_community_metadata(cid, end_user_id)
|
await engine._generate_community_metadata([cid], end_user_id)
|
||||||
patch_ok += 1
|
patch_ok += 1
|
||||||
except Exception as patch_err:
|
except Exception as patch_err:
|
||||||
patch_fail += 1
|
patch_fail += 1
|
||||||
|
|||||||
Reference in New Issue
Block a user