Merge #13 into develop from fix/stream-output
'fix/stream-output' * fix/stream-output: (17 commits squashed) - [fix]Fix the issue where the streaming output effect is not obvious. - [fix]Fix the issue where the streaming output effect is not obvious. - Merge branch 'fix/stream-output' of codeup.aliyun.com:redbearai/python/redbear-mem-open into fix/stream-output - [fix] - [fix]Skip time extraction - [fix] - [fix]Skip time extraction - Merge branch 'fix/stream-output' of codeup.aliyun.com:redbearai/python/redbear-mem-open into fix/stream-output - [fix]Remove human-induced delays - [fix]Fix the issue where the streaming output effect is not obvious. - [fix] - [fix]Skip time extraction - [fix]Fix the issue where the streaming output effect is not obvious. - [fix] - [fix]Skip time extraction - [fix]Remove human-induced delays - Merge branch 'fix/stream-output' of codeup.aliyun.com:redbearai/python/redbear-mem-open into fix/stream-output Signed-off-by: 乐力齐 <accounts_690c7b0af9007d7e338af636@mail.teambition.com> Reviewed-by: aliyun6762716068 <accounts_68cb7c6b61f5dcc4200d6251@mail.teambition.com> Merged-by: aliyun6762716068 <accounts_68cb7c6b61f5dcc4200d6251@mail.teambition.com> CR-link: https://codeup.aliyun.com/redbearai/python/redbear-mem-open/change/13
This commit is contained in:
@@ -35,6 +35,7 @@ from app.core.memory.models.graph_models import (
|
||||
from app.core.memory.utils.data.ontology import TemporalInfo
|
||||
from app.core.memory.models.variate_config import (
|
||||
ExtractionPipelineConfig,
|
||||
StatementExtractionConfig,
|
||||
)
|
||||
from app.core.memory.llm_tools.openai_client import LLMClient
|
||||
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
|
||||
@@ -52,6 +53,7 @@ from app.core.memory.storage_services.extraction_engine.knowledge_extraction.tem
|
||||
)
|
||||
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.embedding_generation import (
|
||||
embedding_generation,
|
||||
embedding_generation_all,
|
||||
generate_entity_embeddings_from_triplets,
|
||||
)
|
||||
from app.core.memory.storage_services.extraction_engine.deduplication.two_stage_dedup import (
|
||||
@@ -177,12 +179,24 @@ class ExtractionOrchestrator:
|
||||
all_statements_list.extend(chunk.statements)
|
||||
total_statements = len(all_statements_list)
|
||||
|
||||
# 步骤 2: 并行执行三元组提取、时间信息提取、情绪提取和基础嵌入生成
|
||||
logger.info("步骤 2/6: 并行执行三元组提取、时间信息提取、情绪提取和嵌入生成")
|
||||
# 🔥 陈述句提取完成后,立即发送知识抽取完成消息
|
||||
if self.progress_callback:
|
||||
extraction_stats = {
|
||||
"statements_count": total_statements,
|
||||
"entities_count": 0, # 暂时为0,后续会更新
|
||||
"triplets_count": 0, # 暂时为0,后续会更新
|
||||
"temporal_ranges_count": 0, # 暂时为0,后续会更新
|
||||
}
|
||||
await self.progress_callback("knowledge_extraction_complete", "知识抽取完成", extraction_stats)
|
||||
|
||||
# 🔥 立即发送下一阶段的开始消息,让前端知道进入了创建节点和边阶段
|
||||
await self.progress_callback("creating_nodes_edges", "正在创建节点和边...")
|
||||
|
||||
# 步骤 2: 并行执行三元组提取、时间信息提取和基础嵌入生成(后台静默执行)
|
||||
logger.info("步骤 2/6: 并行执行三元组提取、时间信息提取和嵌入生成(后台静默执行)")
|
||||
(
|
||||
triplet_maps,
|
||||
temporal_maps,
|
||||
emotion_maps,
|
||||
statement_embedding_maps,
|
||||
chunk_embedding_maps,
|
||||
dialog_embeddings,
|
||||
@@ -211,7 +225,6 @@ class ExtractionOrchestrator:
|
||||
dialog_data_list,
|
||||
temporal_maps,
|
||||
triplet_maps,
|
||||
emotion_maps,
|
||||
statement_embedding_maps,
|
||||
chunk_embedding_maps,
|
||||
dialog_embeddings,
|
||||
@@ -539,108 +552,9 @@ class ExtractionOrchestrator:
|
||||
|
||||
return temporal_maps
|
||||
|
||||
async def _extract_emotions(
|
||||
self, dialog_data_list: List[DialogData]
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
从对话中提取情绪信息(优化版:全局陈述句级并行)
|
||||
|
||||
Args:
|
||||
dialog_data_list: 对话数据列表
|
||||
|
||||
Returns:
|
||||
情绪信息映射列表,每个对话对应一个字典
|
||||
"""
|
||||
logger.info("开始情绪信息提取(全局陈述句级并行)")
|
||||
|
||||
# 收集所有陈述句及其配置
|
||||
all_statements = []
|
||||
statement_metadata = [] # (dialog_idx, statement_id)
|
||||
|
||||
# 获取第一个对话的config_id来加载配置
|
||||
config_id = None
|
||||
if dialog_data_list and hasattr(dialog_data_list[0], 'config_id'):
|
||||
config_id = dialog_data_list[0].config_id
|
||||
|
||||
# 加载DataConfig
|
||||
data_config = None
|
||||
if config_id:
|
||||
try:
|
||||
from app.db import SessionLocal
|
||||
from app.repositories.data_config_repository import DataConfigRepository
|
||||
|
||||
db = SessionLocal()
|
||||
try:
|
||||
data_config = DataConfigRepository.get_by_id(db, config_id)
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
if data_config and not data_config.emotion_enabled:
|
||||
logger.info("情绪提取已在配置中禁用,跳过情绪提取")
|
||||
return [{} for _ in dialog_data_list]
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"加载DataConfig失败: {e},将跳过情绪提取")
|
||||
return [{} for _ in dialog_data_list]
|
||||
else:
|
||||
logger.info("未找到config_id,跳过情绪提取")
|
||||
return [{} for _ in dialog_data_list]
|
||||
|
||||
# 如果配置未启用情绪提取,直接返回空映射
|
||||
if not data_config or not data_config.emotion_enabled:
|
||||
logger.info("情绪提取未启用,跳过")
|
||||
return [{} for _ in dialog_data_list]
|
||||
|
||||
# 收集所有陈述句
|
||||
for d_idx, dialog in enumerate(dialog_data_list):
|
||||
for chunk in dialog.chunks:
|
||||
for statement in chunk.statements:
|
||||
all_statements.append((statement, data_config))
|
||||
statement_metadata.append((d_idx, statement.id))
|
||||
|
||||
logger.info(f"收集到 {len(all_statements)} 个陈述句,开始全局并行提取情绪")
|
||||
|
||||
# 初始化情绪提取服务
|
||||
from app.services.emotion_extraction_service import EmotionExtractionService
|
||||
emotion_service = EmotionExtractionService(
|
||||
llm_id=data_config.emotion_model_id if data_config.emotion_model_id else None
|
||||
)
|
||||
|
||||
# 全局并行处理所有陈述句
|
||||
async def extract_for_statement(stmt_data):
|
||||
statement, config = stmt_data
|
||||
try:
|
||||
return await emotion_service.extract_emotion(statement.statement, config)
|
||||
except Exception as e:
|
||||
logger.error(f"陈述句 {statement.id} 情绪提取失败: {e}")
|
||||
return None
|
||||
|
||||
tasks = [extract_for_statement(stmt_data) for stmt_data in all_statements]
|
||||
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
# 将结果组织成对话级别的映射
|
||||
emotion_maps = [{} for _ in dialog_data_list]
|
||||
successful_extractions = 0
|
||||
|
||||
for i, result in enumerate(results):
|
||||
d_idx, stmt_id = statement_metadata[i]
|
||||
if isinstance(result, Exception):
|
||||
logger.error(f"陈述句处理异常: {result}")
|
||||
emotion_maps[d_idx][stmt_id] = None
|
||||
else:
|
||||
emotion_maps[d_idx][stmt_id] = result
|
||||
if result is not None:
|
||||
successful_extractions += 1
|
||||
|
||||
# 统计提取结果
|
||||
logger.info(f"情绪信息提取完成,共成功提取 {successful_extractions}/{len(all_statements)} 个情绪")
|
||||
|
||||
return emotion_maps
|
||||
|
||||
async def _parallel_extract_and_embed(
|
||||
self, dialog_data_list: List[DialogData]
|
||||
) -> Tuple[
|
||||
List[Dict[str, Any]],
|
||||
List[Dict[str, Any]],
|
||||
List[Dict[str, Any]],
|
||||
List[Dict[str, List[float]]],
|
||||
@@ -648,39 +562,35 @@ class ExtractionOrchestrator:
|
||||
List[List[float]],
|
||||
]:
|
||||
"""
|
||||
并行执行三元组提取、时间信息提取、情绪提取和基础嵌入生成
|
||||
并行执行三元组提取、时间信息提取和基础嵌入生成
|
||||
|
||||
这四个任务都依赖陈述句提取的结果,但彼此独立,可以并行执行:
|
||||
这三个任务都依赖陈述句提取的结果,但彼此独立,可以并行执行:
|
||||
- 三元组提取:从陈述句中提取实体和关系
|
||||
- 时间信息提取:从陈述句中提取时间范围
|
||||
- 情绪提取:从陈述句中提取情绪信息
|
||||
- 嵌入生成:为陈述句、分块和对话生成向量(不依赖三元组)
|
||||
|
||||
Args:
|
||||
dialog_data_list: 对话数据列表
|
||||
|
||||
Returns:
|
||||
六个列表的元组:
|
||||
五个列表的元组:
|
||||
- 三元组映射列表
|
||||
- 时间信息映射列表
|
||||
- 情绪映射列表
|
||||
- 陈述句嵌入映射列表
|
||||
- 分块嵌入映射列表
|
||||
- 对话嵌入列表
|
||||
"""
|
||||
logger.info("并行执行:三元组提取 + 时间信息提取 + 情绪提取 + 基础嵌入生成")
|
||||
logger.info("并行执行:三元组提取 + 时间信息提取 + 基础嵌入生成")
|
||||
|
||||
# 创建四个并行任务
|
||||
# 创建三个并行任务
|
||||
triplet_task = self._extract_triplets(dialog_data_list)
|
||||
temporal_task = self._extract_temporal(dialog_data_list)
|
||||
emotion_task = self._extract_emotions(dialog_data_list)
|
||||
embedding_task = self._generate_basic_embeddings(dialog_data_list)
|
||||
|
||||
# 并行执行
|
||||
results = await asyncio.gather(
|
||||
triplet_task,
|
||||
temporal_task,
|
||||
emotion_task,
|
||||
embedding_task,
|
||||
return_exceptions=True
|
||||
)
|
||||
@@ -688,21 +598,19 @@ class ExtractionOrchestrator:
|
||||
# 解包结果
|
||||
triplet_maps = results[0] if not isinstance(results[0], Exception) else [{} for _ in dialog_data_list]
|
||||
temporal_maps = results[1] if not isinstance(results[1], Exception) else [{} for _ in dialog_data_list]
|
||||
emotion_maps = results[2] if not isinstance(results[2], Exception) else [{} for _ in dialog_data_list]
|
||||
|
||||
if isinstance(results[3], Exception):
|
||||
logger.error(f"基础嵌入生成失败: {results[3]}")
|
||||
if isinstance(results[2], Exception):
|
||||
logger.error(f"基础嵌入生成失败: {results[2]}")
|
||||
statement_embedding_maps = [{} for _ in dialog_data_list]
|
||||
chunk_embedding_maps = [{} for _ in dialog_data_list]
|
||||
dialog_embeddings = [[] for _ in dialog_data_list]
|
||||
else:
|
||||
statement_embedding_maps, chunk_embedding_maps, dialog_embeddings = results[3]
|
||||
statement_embedding_maps, chunk_embedding_maps, dialog_embeddings = results[2]
|
||||
|
||||
logger.info("并行任务执行完成")
|
||||
return (
|
||||
triplet_maps,
|
||||
temporal_maps,
|
||||
emotion_maps,
|
||||
statement_embedding_maps,
|
||||
chunk_embedding_maps,
|
||||
dialog_embeddings,
|
||||
@@ -819,7 +727,6 @@ class ExtractionOrchestrator:
|
||||
dialog_data_list: List[DialogData],
|
||||
temporal_maps: List[Dict[str, Any]],
|
||||
triplet_maps: List[Dict[str, Any]],
|
||||
emotion_maps: List[Dict[str, Any]],
|
||||
statement_embedding_maps: List[Dict[str, List[float]]],
|
||||
chunk_embedding_maps: List[Dict[str, List[float]]],
|
||||
dialog_embeddings: List[List[float]],
|
||||
@@ -831,7 +738,6 @@ class ExtractionOrchestrator:
|
||||
dialog_data_list: 对话数据列表
|
||||
temporal_maps: 时间信息映射列表
|
||||
triplet_maps: 三元组映射列表
|
||||
emotion_maps: 情绪信息映射列表
|
||||
statement_embedding_maps: 陈述句嵌入映射列表
|
||||
chunk_embedding_maps: 分块嵌入映射列表
|
||||
dialog_embeddings: 对话嵌入列表
|
||||
@@ -846,7 +752,6 @@ class ExtractionOrchestrator:
|
||||
if (
|
||||
len(temporal_maps) != expected_length
|
||||
or len(triplet_maps) != expected_length
|
||||
or len(emotion_maps) != expected_length
|
||||
or len(statement_embedding_maps) != expected_length
|
||||
or len(chunk_embedding_maps) != expected_length
|
||||
or len(dialog_embeddings) != expected_length
|
||||
@@ -854,7 +759,6 @@ class ExtractionOrchestrator:
|
||||
logger.warning(
|
||||
f"数据大小不匹配 - 对话: {len(dialog_data_list)}, "
|
||||
f"时间映射: {len(temporal_maps)}, 三元组映射: {len(triplet_maps)}, "
|
||||
f"情绪映射: {len(emotion_maps)}, "
|
||||
f"陈述句嵌入: {len(statement_embedding_maps)}, "
|
||||
f"分块嵌入: {len(chunk_embedding_maps)}, "
|
||||
f"对话嵌入: {len(dialog_embeddings)}"
|
||||
@@ -863,7 +767,6 @@ class ExtractionOrchestrator:
|
||||
total_statements = 0
|
||||
assigned_temporal = 0
|
||||
assigned_triplets = 0
|
||||
assigned_emotions = 0
|
||||
assigned_statement_embeddings = 0
|
||||
assigned_chunk_embeddings = 0
|
||||
assigned_dialog_embeddings = 0
|
||||
@@ -871,13 +774,12 @@ class ExtractionOrchestrator:
|
||||
# 处理每个对话
|
||||
for i, dialog_data in enumerate(dialog_data_list):
|
||||
# 检查是否有缺失的数据
|
||||
if i >= len(temporal_maps) or i >= len(triplet_maps) or i >= len(emotion_maps):
|
||||
if i >= len(temporal_maps) or i >= len(triplet_maps):
|
||||
logger.warning(f"对话 {dialog_data.id} 缺少提取数据,跳过赋值")
|
||||
continue
|
||||
|
||||
temporal_map = temporal_maps[i]
|
||||
triplet_map = triplet_maps[i]
|
||||
emotion_map = emotion_maps[i]
|
||||
statement_embedding_map = statement_embedding_maps[i] if i < len(statement_embedding_maps) else {}
|
||||
chunk_embedding_map = chunk_embedding_maps[i] if i < len(chunk_embedding_maps) else {}
|
||||
dialog_embedding = dialog_embeddings[i] if i < len(dialog_embeddings) else []
|
||||
@@ -908,18 +810,6 @@ class ExtractionOrchestrator:
|
||||
statement.triplet_extraction_info = triplet_map[statement.id]
|
||||
assigned_triplets += 1
|
||||
|
||||
# 赋值情绪信息
|
||||
if statement.id in emotion_map:
|
||||
emotion_data = emotion_map[statement.id]
|
||||
if emotion_data is not None:
|
||||
# 将EmotionExtraction对象的字段赋值到Statement
|
||||
statement.emotion_type = emotion_data.emotion_type
|
||||
statement.emotion_intensity = emotion_data.emotion_intensity
|
||||
statement.emotion_keywords = emotion_data.emotion_keywords
|
||||
statement.emotion_subject = emotion_data.emotion_subject
|
||||
statement.emotion_target = emotion_data.emotion_target
|
||||
assigned_emotions += 1
|
||||
|
||||
# 赋值陈述句嵌入
|
||||
if statement.id in statement_embedding_map:
|
||||
statement.statement_embedding = statement_embedding_map[statement.id]
|
||||
@@ -928,7 +818,6 @@ class ExtractionOrchestrator:
|
||||
logger.info(
|
||||
f"数据赋值完成 - 总陈述句: {total_statements}, "
|
||||
f"时间信息: {assigned_temporal}, 三元组: {assigned_triplets}, "
|
||||
f"情绪信息: {assigned_emotions}, "
|
||||
f"陈述句嵌入: {assigned_statement_embeddings}, "
|
||||
f"分块嵌入: {assigned_chunk_embeddings}, "
|
||||
f"对话嵌入: {assigned_dialog_embeddings}"
|
||||
@@ -1038,12 +927,6 @@ class ExtractionOrchestrator:
|
||||
created_at=dialog_data.created_at,
|
||||
expired_at=dialog_data.expired_at,
|
||||
config_id=dialog_data.config_id if hasattr(dialog_data, 'config_id') else None,
|
||||
# Emotion fields
|
||||
emotion_type=getattr(statement, 'emotion_type', None),
|
||||
emotion_intensity=getattr(statement, 'emotion_intensity', None),
|
||||
emotion_keywords=getattr(statement, 'emotion_keywords', None),
|
||||
emotion_subject=getattr(statement, 'emotion_subject', None),
|
||||
emotion_target=getattr(statement, 'emotion_target', None),
|
||||
)
|
||||
statement_nodes.append(statement_node)
|
||||
|
||||
@@ -1450,7 +1333,7 @@ class ExtractionOrchestrator:
|
||||
if match:
|
||||
entity1_name = match.group(1).strip()
|
||||
entity1_type = match.group(2)
|
||||
match.group(3).strip()
|
||||
entity2_name = match.group(3).strip()
|
||||
entity2_type = match.group(4)
|
||||
|
||||
# 提取置信度和原因
|
||||
@@ -1763,6 +1646,7 @@ async def get_chunked_dialogs(
|
||||
"""
|
||||
import json
|
||||
import re
|
||||
import os
|
||||
|
||||
# 加载测试数据
|
||||
testdata_path = os.path.join(os.path.dirname(__file__), "../../data", "testdata.json")
|
||||
@@ -1938,6 +1822,7 @@ async def get_chunked_dialogs_with_preprocessing(
|
||||
Returns:
|
||||
带 chunks 的 DialogData 列表
|
||||
"""
|
||||
import os
|
||||
print("\n=== 完整数据处理流程(包含预处理)===")
|
||||
|
||||
if input_data_path is None:
|
||||
|
||||
Reference in New Issue
Block a user