[MODIFY] MEM SEE OUTPUT

This commit is contained in:
Mark
2025-12-15 20:50:15 +08:00
parent 7bbef35b7d
commit 9b8db9a001
15 changed files with 863 additions and 144 deletions

View File

@@ -1,8 +1,9 @@
from typing import Optional
from typing import Optional, Union
import os
import uuid
from sqlalchemy.orm import Session
from fastapi import APIRouter, Depends, Query, UploadFile
from fastapi import APIRouter, Depends, UploadFile
from fastapi.responses import StreamingResponse
from app.db import get_db
@@ -322,36 +323,24 @@ def read_all_config(
return fail(BizCode.INTERNAL_ERROR, "查询所有配置失败", str(e))
@router.post("/pilot_run", response_model=ApiResponse) # 试运行:触发执行主管线,使用 POST 更为合理
@router.post("/pilot_run", response_model=None)
async def pilot_run(
payload: ConfigPilotRun,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
) -> dict:
) -> StreamingResponse:
api_logger.info(f"Pilot run requested: config_id={payload.config_id}, dialogue_text_length={len(payload.dialogue_text)}")
# 先尝试从数据库加载配置
try:
config_loaded = reload_configuration_from_database(str(payload.config_id))
if not config_loaded:
api_logger.error(f"Failed to load configuration for config_id: {payload.config_id}")
return fail(BizCode.INTERNAL_ERROR, "配置加载失败", f"无法加载 config_id={payload.config_id} 的配置")
api_logger.info(f"Configuration loaded successfully for config_id: {payload.config_id}")
except Exception as e:
api_logger.error(f"Exception while loading configuration: {str(e)}")
return fail(BizCode.INTERNAL_ERROR, "配置加载异常", str(e))
try:
svc = DataConfigService(db)
result = await svc.pilot_run(payload)
return success(data=result, msg="试运行完成")
except ValueError as e:
# 捕获参数验证错误
api_logger.error(f"Pilot run parameter validation failed: {str(e)}")
return fail(BizCode.INVALID_PARAMETER, "参数验证失败", str(e))
except Exception as e:
api_logger.error(f"Pilot run failed: {str(e)}")
return fail(BizCode.INTERNAL_ERROR, "试运行失败", str(e))
svc = DataConfigService(db)
return StreamingResponse(
svc.pilot_run_stream(payload),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no"
}
)
"""
以下为搜索与分析接口,直接挂载到同一 router统一响应为 ApiResponse。

View File

@@ -21,7 +21,7 @@ def generate_api_key(key_type: ApiKeyType) -> str:
# 前缀映射
prefix_map = {
ApiKeyType.AGENT: "sk-agent-",
ApiKeyType.CLUSTER: "sk-cluster-",
ApiKeyType.CLUSTER: "sk-multi_agent-",
ApiKeyType.WORKFLOW: "sk-workflow-",
ApiKeyType.SERVICE: "sk-service-"
}

View File

@@ -106,6 +106,8 @@ async def write(content: str, user_id: str, apply_id: str, group_id: str, ref_id
all_statement_chunk_edges,
all_statement_entity_edges,
all_entity_entity_edges,
all_dedup_details,
) = await orchestrator.run(chunked_dialogs, is_pilot_run=False)
log_time("Extraction Pipeline", time.time() - step_start, log_file)

View File

@@ -1,5 +1,5 @@
{
"selections": {
"config_id": "1"
"config_id": ""
}
}

View File

@@ -21,7 +21,7 @@ os.environ["LANGCHAIN_TRACING"] = "false"
import asyncio
import time
from datetime import datetime
from typing import Optional
from typing import Optional, Callable, Awaitable
from dotenv import load_dotenv
# 导入重构后的模块
@@ -50,7 +50,11 @@ logger = get_memory_logger(__name__)
async def main(dialogue_text: Optional[str] = None, is_pilot_run: bool = False):
async def main(
dialogue_text: Optional[str] = None,
is_pilot_run: bool = False,
progress_callback: Optional[Callable[[str, str, Optional[dict]], Awaitable[None]]] = None
):
"""
记忆系统主流程 - 重构版本
@@ -61,6 +65,12 @@ async def main(dialogue_text: Optional[str] = None, is_pilot_run: bool = False):
is_pilot_run: 是否为试运行模式
- True: 试运行模式,不保存到 Neo4j
- False: 正常运行模式,保存到 Neo4j
progress_callback: 可选的进度回调函数
- 类型: Callable[[str, str, Optional[dict]], Awaitable[None]]
- 参数1 (stage): 当前处理阶段标识符
- 参数2 (message): 人类可读的进度消息
- 参数3 (data): 可选的附加数据字典,包含详细的进度信息或结果
- 在管线关键点调用以报告进度和结果数据
工作流程:
1. 初始化客户端和配置
@@ -141,6 +151,10 @@ async def main(dialogue_text: Optional[str] = None, is_pilot_run: bool = False):
metadata={"source": "pilot_run", "input_type": "frontend_text"}
)
# 进度回调:开始预处理文本
if progress_callback:
await progress_callback("text_preprocessing", "开始预处理文本...")
# 对前端传入的对话进行分块处理
chunked_dialogs = await get_chunked_dialogs_from_preprocessed(
data=[dialog],
@@ -148,6 +162,27 @@ async def main(dialogue_text: Optional[str] = None, is_pilot_run: bool = False):
llm_client=llm_client,
)
logger.info(f"Processed frontend dialogue text: {len(messages)} messages")
# 进度回调:输出每个分块的结果
if progress_callback:
for dialog in chunked_dialogs:
for i, chunk in enumerate(dialog.chunks):
chunk_result = {
"chunk_index": i + 1,
"content": chunk.content[:200] + "..." if len(chunk.content) > 200 else chunk.content,
"full_length": len(chunk.content),
"dialog_id": dialog.id,
"chunker_strategy": config_defs.SELECTED_CHUNKER_STRATEGY
}
await progress_callback("text_preprocessing_result", f"分块 {i + 1} 处理完成", chunk_result)
# 进度回调:预处理文本完成
preprocessing_summary = {
"total_chunks": sum(len(dialog.chunks) for dialog in chunked_dialogs),
"total_dialogs": len(chunked_dialogs),
"chunker_strategy": config_defs.SELECTED_CHUNKER_STRATEGY
}
await progress_callback("text_preprocessing_complete", "预处理文本完成", preprocessing_summary)
else:
# 正常运行模式:从 testdata.json 文件加载
logger.warning("[MAIN] ✗ Falling back to testdata.json (dialogue_text not provided or empty)")
@@ -159,6 +194,10 @@ async def main(dialogue_text: Optional[str] = None, is_pilot_run: bool = False):
if not os.path.exists(test_data_path):
raise FileNotFoundError(f"Test data file not found: {test_data_path}")
# 进度回调:开始预处理文本
if progress_callback:
await progress_callback("text_preprocessing", "开始预处理文本...")
chunked_dialogs = await get_chunked_dialogs_with_preprocessing(
chunker_strategy=config_defs.SELECTED_CHUNKER_STRATEGY,
group_id=config_defs.SELECTED_GROUP_ID,
@@ -170,6 +209,27 @@ async def main(dialogue_text: Optional[str] = None, is_pilot_run: bool = False):
skip_cleaning=True,
)
logger.info(f"Loaded {len(chunked_dialogs)} dialogues from testdata.json")
# 进度回调:输出每个分块的结果
if progress_callback:
for dialog in chunked_dialogs:
for i, chunk in enumerate(dialog.chunks):
chunk_result = {
"chunk_index": i + 1,
"content": chunk.content[:200] + "..." if len(chunk.content) > 200 else chunk.content,
"full_length": len(chunk.content),
"dialog_id": dialog.id,
"chunker_strategy": config_defs.SELECTED_CHUNKER_STRATEGY
}
await progress_callback("text_preprocessing_result", f"分块 {i + 1} 处理完成", chunk_result)
# 进度回调:预处理文本完成
preprocessing_summary = {
"total_chunks": sum(len(dialog.chunks) for dialog in chunked_dialogs),
"total_dialogs": len(chunked_dialogs),
"chunker_strategy": config_defs.SELECTED_CHUNKER_STRATEGY
}
await progress_callback("text_preprocessing_complete", "预处理文本完成", preprocessing_summary)
log_time("Data Loading & Chunking", time.time() - step_start, log_file)
@@ -188,6 +248,7 @@ async def main(dialogue_text: Optional[str] = None, is_pilot_run: bool = False):
embedder_client=embedder_client,
connector=neo4j_connector,
config=config,
progress_callback=progress_callback, # 传递进度回调
)
log_time("Orchestrator Initialization", time.time() - step_start, log_file)
@@ -196,6 +257,11 @@ async def main(dialogue_text: Optional[str] = None, is_pilot_run: bool = False):
logger.info("Running extraction pipeline...")
step_start = time.time()
# 进度回调:正在知识抽取
if progress_callback:
await progress_callback("knowledge_extraction", "正在知识抽取...")
extraction_result = await orchestrator.run(
dialog_data_list=chunked_dialogs,
is_pilot_run=is_pilot_run, # 传递试运行模式标志
@@ -216,6 +282,11 @@ async def main(dialogue_text: Optional[str] = None, is_pilot_run: bool = False):
) = extraction_result
log_time("Extraction Pipeline", time.time() - step_start, log_file)
# 进度回调:生成结果
if progress_callback:
await progress_callback("generating_results", "正在生成结果...")
# 步骤 5: 保存结果或输出结果
if is_pilot_run:

View File

@@ -2,7 +2,7 @@
去重功能函数
"""
from app.core.memory.models.variate_config import DedupConfig
from typing import List, Dict, Tuple
from typing import List, Dict, Tuple, Any
from app.core.memory.models.graph_models import(
StatementEntityEdge,
EntityEntityEdge,
@@ -895,7 +895,12 @@ async def deduplicate_entities_and_edges(
report_append: bool = False,
report_stage_notes: List[str] | None = None,
dedup_config: DedupConfig | None = None,
) -> Tuple[List[ExtractedEntityNode], List[StatementEntityEdge], List[EntityEntityEdge]]:
) -> Tuple[
List[ExtractedEntityNode],
List[StatementEntityEdge],
List[EntityEntityEdge],
Dict[str, Any] # 新增:返回详细的去重消歧记录
]:
"""
主流程依次执行精确匹配、模糊匹配与可选LLM 决策融合,随后对边做重定向与去重。之后再处理边,是关系去重和消歧
返回:去重后的实体、语句→实体边、实体↔实体边。
@@ -981,8 +986,18 @@ async def deduplicate_entities_and_edges(
append=report_append,
stage_notes=report_stage_notes,
)
# 构建详细的去重消歧记录(用于内存访问,避免解析日志文件)
dedup_details = {
"exact_merge_map": exact_merge_map,
"fuzzy_merge_records": fuzzy_merge_records,
"llm_decision_records": local_llm_records,
"disamb_records": disamb_records,
"id_redirect": id_redirect,
"blocked_pairs": blocked_pairs,
}
return deduped_entities, list(stmt_ent_map.values()), list(ent_ent_map.values())
return deduped_entities, list(stmt_ent_map.values()), list(ent_ent_map.values()), dedup_details
# 独立模块:去重融合报告写入(与实体/边的计算解耦)
def _write_dedup_fusion_report(

View File

@@ -39,6 +39,7 @@ async def dedup_layers_and_merge_and_return(
List[StatementChunkEdge],
List[StatementEntityEdge],
List[EntityEntityEdge],
dict, # 新增:返回去重详情
]:
"""
执行两层实体去重与融合:
@@ -62,7 +63,7 @@ async def dedup_layers_and_merge_and_return(
break
# 第一层去重消歧
dedup_entity_nodes, dedup_statement_entity_edges, dedup_entity_entity_edges = await deduplicate_entities_and_edges(
dedup_entity_nodes, dedup_statement_entity_edges, dedup_entity_entity_edges, dedup_details = await deduplicate_entities_and_edges(
entity_nodes,
statement_entity_edges,
entity_entity_edges,
@@ -103,4 +104,5 @@ async def dedup_layers_and_merge_and_return(
statement_chunk_edges,
fused_statement_entity_edges,
fused_entity_entity_edges,
dedup_details, # 返回去重详情
)

View File

@@ -12,13 +12,14 @@
5. 提供错误处理和日志记录
6. 支持试运行模式(不写入数据库)
作者:Memory Refactoring Team
作者:
日期2025-11-21
"""
import asyncio
import logging
from typing import List, Dict, Any, Tuple, Optional
import os
from typing import List, Dict, Any, Tuple, Optional, Callable, Awaitable
from datetime import datetime
from app.core.memory.models.message_models import DialogData
@@ -94,6 +95,7 @@ class ExtractionOrchestrator:
embedder_client: OpenAIEmbedderClient,
connector: Neo4jConnector,
config: Optional[ExtractionPipelineConfig] = None,
progress_callback: Optional[Callable[[str, str, Optional[Dict[str, Any]]], Awaitable[None]]] = None,
):
"""
初始化流水线编排器
@@ -103,12 +105,21 @@ class ExtractionOrchestrator:
embedder_client: 嵌入模型客户端
connector: Neo4j 连接器
config: 流水线配置,如果为 None 则使用默认配置
progress_callback: 进度回调函数
- 接受 (stage: str, message: str, data: Optional[Dict[str, Any]]) 并返回 Awaitable[None]
- 在管线关键点调用以报告进度和结果数据
"""
self.llm_client = llm_client
self.embedder_client = embedder_client
self.connector = connector
self.config = config or ExtractionPipelineConfig()
self.is_pilot_run = False # 默认非试运行模式
self.progress_callback = progress_callback # 保存进度回调函数
# 保存去重消歧的详细记录(内存中的数据结构)
self.dedup_merge_records: List[Dict[str, Any]] = [] # 实体合并记录
self.dedup_disamb_records: List[Dict[str, Any]] = [] # 实体消歧记录
self.id_redirect_map: Dict[str, str] = {} # ID重定向映射
# 初始化各个提取器
self.statement_extractor = StatementExtractor(
@@ -160,6 +171,13 @@ class ExtractionOrchestrator:
# 步骤 1: 陈述句提取
logger.info("步骤 1/6: 陈述句提取(全局分块级并行)")
dialog_data_list = await self._extract_statements(dialog_data_list)
# 收集陈述句内容和统计数量
all_statements_list = []
for dialog in dialog_data_list:
for chunk in dialog.chunks:
all_statements_list.extend(chunk.statements)
total_statements = len(all_statements_list)
# 步骤 2: 并行执行三元组提取、时间信息提取和基础嵌入生成
logger.info("步骤 2/6: 并行执行三元组提取、时间信息提取和嵌入生成")
@@ -170,11 +188,90 @@ class ExtractionOrchestrator:
chunk_embedding_maps,
dialog_embeddings,
) = await self._parallel_extract_and_embed(dialog_data_list)
# 收集实体和三元组内容,并统计数量
all_entities_list = []
all_triplets_list = []
for triplet_map in triplet_maps:
for triplet_info in triplet_map.values():
if triplet_info:
all_entities_list.extend(triplet_info.entities)
all_triplets_list.extend(triplet_info.triplets)
total_entities = len(all_entities_list)
total_triplets = len(all_triplets_list)
total_temporal = sum(len(temporal_map) for temporal_map in temporal_maps)
# 步骤 3: 生成实体嵌入(依赖三元组提取结果)
logger.info("步骤 3/6: 生成实体嵌入")
triplet_maps = await self._generate_entity_embeddings(triplet_maps)
# 进度回调:按三个阶段分别输出知识抽取结果
if self.progress_callback:
# 第一阶段:陈述句提取结果
for i, stmt in enumerate(all_statements_list[:10]): # 只输出前10个陈述句
stmt_result = {
"extraction_type": "statement",
"statement_index": i + 1,
"statement": stmt.statement,
"statement_id": stmt.id
}
await self.progress_callback("knowledge_extraction_result", "陈述句提取完成", stmt_result)
# 第二阶段:三元组提取结果
for i, triplet in enumerate(all_triplets_list[:10]): # 只输出前10个三元组
triplet_result = {
"extraction_type": "triplet",
"triplet_index": i + 1,
"subject": triplet.subject_name,
"predicate": triplet.predicate,
"object": triplet.object_name
}
await self.progress_callback("knowledge_extraction_result", "三元组提取完成", triplet_result)
# 第三阶段:时间提取结果
if total_temporal > 0:
# 收集时间信息
temporal_results = []
for dialog in dialog_data_list:
for chunk in dialog.chunks:
for statement in chunk.statements:
if hasattr(statement, 'temporal_validity') and statement.temporal_validity:
temporal_results.append({
"statement_id": statement.id,
"statement": statement.statement,
"valid_at": statement.temporal_validity.valid_at,
"invalid_at": statement.temporal_validity.invalid_at
})
# 输出时间提取结果
for i, temporal_result in enumerate(temporal_results[:5]): # 只输出前5个时间提取结果
time_result = {
"extraction_type": "temporal",
"temporal_index": i + 1,
"statement": temporal_result["statement"],
"valid_at": temporal_result["valid_at"],
"invalid_at": temporal_result["invalid_at"]
}
await self.progress_callback("knowledge_extraction_result", "时间提取完成", time_result)
else:
# 如果没有时间信息,也发送一个时间提取完成的消息
time_result = {
"extraction_type": "temporal",
"temporal_index": 0,
"message": "未发现时间信息"
}
await self.progress_callback("knowledge_extraction_result", "时间提取完成", time_result)
# 进度回调:知识抽取完成,传递知识抽取的统计信息
extraction_stats = {
"statements_count": total_statements,
"entities_count": total_entities,
"triplets_count": total_triplets,
"temporal_ranges_count": total_temporal,
}
await self.progress_callback("knowledge_extraction_complete", "知识抽取完成", extraction_stats)
# 步骤 4: 将提取的数据赋值到语句
logger.info("步骤 4/6: 数据赋值")
dialog_data_list = await self._assign_extracted_data(
@@ -218,6 +315,8 @@ class ExtractionOrchestrator:
dialog_data_list,
)
logger.info(f"知识提取流水线运行完成({mode_str}")
return result
@@ -732,6 +831,10 @@ class ExtractionOrchestrator:
包含所有节点和边的元组
"""
logger.info("开始创建节点和边")
# 进度回调:正在创建节点和边
if self.progress_callback:
await self.progress_callback("creating_nodes_edges", "正在创建节点和边...")
dialogue_nodes = []
chunk_nodes = []
@@ -904,6 +1007,23 @@ class ExtractionOrchestrator:
f"陈述句-实体边: {len(statement_entity_edges)}, "
f"实体-实体边: {len(entity_entity_edges)}"
)
# 进度回调:只输出关系创建结果
if self.progress_callback:
# 输出关系创建结果
await self._output_relationship_creation_results(entity_entity_edges, entity_nodes)
# 进度回调:创建节点和边完成,传递结果统计
nodes_edges_stats = {
"dialogue_nodes_count": len(dialogue_nodes),
"chunk_nodes_count": len(chunk_nodes),
"statement_nodes_count": len(statement_nodes),
"entity_nodes_count": len(entity_nodes),
"statement_chunk_edges_count": len(statement_chunk_edges),
"statement_entity_edges_count": len(statement_entity_edges),
"entity_entity_edges_count": len(entity_entity_edges),
}
await self.progress_callback("creating_nodes_edges_complete", "创建节点和边完成", nodes_edges_stats)
return (
dialogue_nodes,
@@ -950,6 +1070,11 @@ class ExtractionOrchestrator:
- 第三个元组:去重后的 (实体节点列表, 陈述句-实体边列表, 实体-实体边列表)
"""
logger.info("开始两阶段实体去重和消歧")
# 进度回调:正在去重消歧
if self.progress_callback:
await self.progress_callback("deduplication", "正在去重消歧...")
logger.info(
f"去重前: {len(entity_nodes)} 个实体节点, "
f"{len(statement_entity_edges)} 条陈述句-实体边, "
@@ -963,7 +1088,7 @@ class ExtractionOrchestrator:
# 只执行第一层去重
from app.core.memory.storage_services.extraction_engine.deduplication.deduped_and_disamb import deduplicate_entities_and_edges
dedup_entity_nodes, dedup_statement_entity_edges, dedup_entity_entity_edges = await deduplicate_entities_and_edges(
dedup_entity_nodes, dedup_statement_entity_edges, dedup_entity_entity_edges, dedup_details = await deduplicate_entities_and_edges(
entity_nodes,
statement_entity_edges,
entity_entity_edges,
@@ -972,6 +1097,9 @@ class ExtractionOrchestrator:
dedup_config=self.config.deduplication,
)
# 保存去重消歧的详细记录到实例变量
self._save_dedup_details(dedup_details, entity_nodes, dedup_entity_nodes)
result_tuple = (
dialogue_nodes,
chunk_nodes,
@@ -1009,7 +1137,11 @@ class ExtractionOrchestrator:
_,
final_statement_entity_edges,
final_entity_entity_edges,
dedup_details,
) = result_tuple
# 保存去重消歧的详细记录到实例变量
self._save_dedup_details(dedup_details, entity_nodes, final_entity_nodes)
logger.info(
f"去重后: {len(final_entity_nodes)} 个实体节点, "
@@ -1021,6 +1153,46 @@ class ExtractionOrchestrator:
f"陈述句-实体边减少 {len(statement_entity_edges) - len(final_statement_entity_edges)}, "
f"实体-实体边减少 {len(entity_entity_edges) - len(final_entity_entity_edges)}"
)
# 进度回调:输出去重消歧的具体结果
if self.progress_callback:
# 分析实体合并情况
merge_info = await self._analyze_entity_merges(entity_nodes, final_entity_nodes)
# 输出去重合并的实体示例
for i, merge_detail in enumerate(merge_info[:5]): # 输出前5个去重结果
dedup_result = {
"result_type": "entity_merge",
"merged_entity_name": merge_detail["main_entity_name"],
"merged_count": merge_detail["merged_count"],
"message": f"{merge_detail['main_entity_name']}合并{merge_detail['merged_count']}个:相似实体已合并"
}
await self.progress_callback("dedup_disambiguation_result", "实体去重完成", dedup_result)
# 分析实体消歧情况
disamb_info = await self._analyze_entity_disambiguation(entity_nodes, final_entity_nodes)
# 输出实体消歧的结果
for i, disamb_detail in enumerate(disamb_info[:5]): # 输出前5个消歧结果
disamb_result = {
"result_type": "entity_disambiguation",
"disambiguated_entity_name": disamb_detail["entity_name"],
"disambiguation_type": disamb_detail["disamb_type"],
"confidence": disamb_detail.get("confidence", "unknown"),
"reason": disamb_detail.get("reason", ""),
"message": f"{disamb_detail['entity_name']}消歧完成:{disamb_detail['disamb_type']}"
}
await self.progress_callback("dedup_disambiguation_result", "实体消歧完成", disamb_result)
# 进度回调:去重消歧完成,传递去重和消歧的具体效果
await self._send_dedup_progress_callback(
len(entity_nodes), len(final_entity_nodes),
len(statement_entity_edges), len(final_statement_entity_edges),
len(entity_entity_edges), len(final_entity_entity_edges)
)
# 写入提取结果汇总(试运行和正式模式都需要生成)
try:
@@ -1041,6 +1213,378 @@ class ExtractionOrchestrator:
logger.error(f"两阶段去重失败: {e}", exc_info=True)
raise
def _save_dedup_details(
    self,
    dedup_details: Dict[str, Any],
    original_entities: List[ExtractedEntityNode],
    final_entities: List[ExtractedEntityNode]
):
    """
    Store the dedup/disambiguation details returned by the dedup step on the instance.

    The raw details are converted into three in-memory structures:
    ``self.id_redirect_map``, ``self.dedup_merge_records`` and
    ``self.dedup_disamb_records``. The two record lists are rebuilt from scratch on
    every call (just like ``id_redirect_map`` is overwritten), so calling this method
    more than once never duplicates records or leaks records from an earlier call.

    This is best-effort: any failure is logged and swallowed so that progress
    reporting can never break the extraction pipeline itself.

    Args:
        dedup_details: Detail dict returned by the dedup function. Keys read here:
            ``id_redirect``, ``exact_merge_map``, ``fuzzy_merge_records``,
            ``llm_decision_records`` and ``disamb_records``.
        original_entities: Entities before deduplication (currently unused).
        final_entities: Entities after deduplication (currently unused).
    """
    # Function-scope import, executed once per call instead of once per record.
    import re

    # "[模糊] 规范实体 <id> (group|name|type) <- 合并实体 <id> (group|name|type) | ..."
    fuzzy_pattern = re.compile(
        r"规范实体 (\S+) \(([^|]+)\|([^|]+)\|([^)]+)\) <- 合并实体 (\S+)"
    )
    # "name1（type1）|name2（type2）".
    # NOTE(review): the previous patterns had lost their parenthesis delimiters
    # ("[^]+)([^]+"), compiled to only two groups and therefore always failed on
    # group(4). The delimiters are assumed to be full-width parentheses; ASCII
    # parentheses are accepted as well — confirm against the strings the dedup
    # module actually emits.
    pair = r"([^（(]+)[（(]([^）)]+)[）)]\|([^（(]+)[（(]([^）)]+)[）)]"
    pair_pattern = re.compile(pair)
    llm_pattern = re.compile(r"同名类型相似 " + pair)
    conf_pattern = re.compile(r"conf=([0-9.]+)")
    reason_pattern = re.compile(r"reason=([^|]+)")

    try:
        # Rebuild the record lists instead of appending to whatever a previous
        # call left behind (new list objects, so earlier references stay intact).
        self.dedup_merge_records = []
        self.dedup_disamb_records = []

        # ID redirect map (merged id -> canonical id), taken over as-is.
        self.id_redirect_map = dedup_details.get("id_redirect", {})

        # Exact-match merges: already structured, one record per canonical entity.
        for info in (dedup_details.get("exact_merge_map") or {}).values():
            if not isinstance(info, dict):
                continue
            merged_ids = info.get("merged_ids", set())
            if merged_ids:
                self.dedup_merge_records.append({
                    "type": "精确匹配",
                    "canonical_id": info.get("canonical_id"),
                    "entity_name": info.get("name"),
                    "entity_type": info.get("entity_type"),
                    "merged_count": len(merged_ids),
                    "merged_ids": list(merged_ids)
                })

        # Fuzzy-match merges: parsed from the human-readable record strings.
        for record in dedup_details.get("fuzzy_merge_records") or []:
            match = fuzzy_pattern.search(str(record))
            if match:
                self.dedup_merge_records.append({
                    "type": "模糊匹配",
                    "canonical_id": match.group(1),
                    "entity_name": match.group(3),
                    "entity_type": match.group(4),
                    "merged_count": 1,
                    "merged_ids": [match.group(5)]
                })
            else:
                logger.debug(f"解析模糊匹配记录失败: {record}, 错误: 格式不匹配")

        # LLM-decided merges:
        # "[LLM去重] 同名类型相似 name1（type1）|name2（type2） | conf=0.xx | reason=..."
        for record in dedup_details.get("llm_decision_records") or []:
            text = str(record)
            if "[LLM去重]" not in text:
                continue
            match = llm_pattern.search(text)
            if match:
                self.dedup_merge_records.append({
                    "type": "LLM去重",
                    "entity_name": match.group(1),
                    "entity_type": f"{match.group(2)}|{match.group(4)}",
                    "merged_count": 1,
                    "merged_ids": []
                })
            else:
                logger.debug(f"解析LLM去重记录失败: {record}, 错误: 格式不匹配")

        # Disambiguation blocks:
        # "[DISAMB阻断] name1（type1）|name2（type2） | conf=0.xx | reason=..."
        for record in dedup_details.get("disamb_records") or []:
            text = str(record)
            if "[DISAMB阻断]" not in text:
                continue
            content = text.replace("[DISAMB阻断]", "").strip()
            match = pair_pattern.search(content)
            if not match:
                logger.debug(f"解析消歧记录失败: {record}, 错误: 格式不匹配")
                continue

            entity1_name = match.group(1).strip()
            entity1_type = match.group(2)
            entity2_type = match.group(4)

            # Confidence and reason are optional trailing "key=value" fields.
            conf_match = conf_pattern.search(text)
            confidence = conf_match.group(1) if conf_match else "unknown"
            reason_match = reason_pattern.search(text)
            reason = reason_match.group(1).strip() if reason_match else ""

            self.dedup_disamb_records.append({
                "entity_name": entity1_name,
                "disamb_type": f"消歧阻断:{entity1_type} vs {entity2_type}",
                "confidence": confidence,
                # Keep progress payloads small: cap the reason at 100 characters.
                "reason": (reason[:100] + "...") if len(reason) > 100 else reason
            })

        logger.info(f"保存去重消歧记录:{len(self.dedup_merge_records)} 个合并记录,{len(self.dedup_disamb_records)} 个消歧记录")

    except Exception as e:
        # Deliberately best-effort: never let reporting break the pipeline.
        logger.error(f"保存去重消歧详情失败: {e}", exc_info=True)
async def _analyze_entity_merges(
    self,
    original_entities: List[ExtractedEntityNode],
    final_entities: List[ExtractedEntityNode]
) -> List[Dict[str, Any]]:
    """
    Summarize entity merges from the in-memory merge records.

    Reads ``self.dedup_merge_records`` and orders the records by ``merged_count``
    in descending order.

    Args:
        original_entities: Entities before deduplication (not used by this method).
        final_entities: Entities after deduplication (not used by this method).

    Returns:
        One dict per merge record with the keys ``main_entity_name`` and
        ``merged_count``; an empty list when there are no records or when the
        analysis fails (the failure is logged).
    """
    try:
        saved_records = self.dedup_merge_records
        if not saved_records:
            # Nothing was recorded by the dedup step.
            logger.info("未找到实体合并记录")
            return []

        # Largest merges first.
        ranked = sorted(
            saved_records,
            key=lambda rec: rec.get("merged_count", 0),
            reverse=True
        )
        return [
            {
                "main_entity_name": rec.get("entity_name", "未知实体"),
                "merged_count": rec.get("merged_count", 1)
            }
            for rec in ranked
        ]

    except Exception as e:
        logger.error(f"分析实体合并情况失败: {e}", exc_info=True)
        return []
async def _analyze_entity_disambiguation(
    self,
    original_entities: List[ExtractedEntityNode],
    final_entities: List[ExtractedEntityNode]
) -> List[Dict[str, Any]]:
    """
    Return the disambiguation records kept in memory.

    Args:
        original_entities: Entities before deduplication (not used by this method).
        final_entities: Entities after deduplication (not used by this method).

    Returns:
        ``self.dedup_disamb_records`` itself (not a copy) when it is non-empty;
        otherwise an empty list. An empty list is also returned if reading the
        records fails (the failure is logged).
    """
    try:
        saved_records = self.dedup_disamb_records
        if not saved_records:
            # Nothing was recorded by the dedup step.
            logger.info("未找到实体消歧记录")
            return []
        return saved_records

    except Exception as e:
        logger.error(f"分析实体消歧情况失败: {e}", exc_info=True)
        return []
def _get_entity_type_display_name(self, entity_type: str) -> str:
"""
获取实体类型的中文显示名称
Args:
entity_type: 英文实体类型
Returns:
中文显示名称
"""
type_mapping = {
"Person": "人物实体节点",
"Organization": "组织实体节点",
"ORG": "组织实体节点",
"Location": "地点实体节点",
"LOC": "地点实体节点",
"Event": "事件实体节点",
"Concept": "概念实体节点",
"Time": "时间实体节点",
"Position": "职位实体节点",
"WorkRole": "职业实体节点",
"System": "系统实体节点",
"Policy": "政策实体节点",
"HistoricalPeriod": "历史时期实体节点",
"HistoricalState": "历史国家实体节点",
"HistoricalEvent": "历史事件实体节点",
"EconomicFactor": "经济因素实体节点",
"Condition": "条件实体节点",
"Numeric": "数值实体节点"
}
return type_mapping.get(entity_type, f"{entity_type}实体节点")
async def _output_relationship_creation_results(
    self,
    entity_entity_edges: List[EntityEntityEdge],
    entity_nodes: List[ExtractedEntityNode]
):
    """
    Report the first entity-to-entity relationships through the progress callback.

    At most the first 10 edges are reported, one callback invocation per edge.
    ``self.progress_callback`` is awaited without a None check here. Any error is
    logged and swallowed.

    Args:
        entity_entity_edges: Entity-to-entity edges to report.
        entity_nodes: Entity nodes, used to resolve edge endpoint ids to names.
    """
    try:
        # Lookup table: entity id -> entity name.
        name_by_id = {}
        for node in entity_nodes:
            name_by_id[node.id] = node.name

        # 1-based index over the first 10 edges only.
        for position, edge in enumerate(entity_entity_edges[:10], start=1):
            # Endpoints that are not among entity_nodes get a placeholder name.
            src_label = name_by_id.get(edge.source, f"Entity_{edge.source}")
            dst_label = name_by_id.get(edge.target, f"Entity_{edge.target}")
            relation = edge.relation_type

            await self.progress_callback(
                "creating_nodes_edges_result",
                "关系创建",
                {
                    "result_type": "relationship_creation",
                    "relationship_index": position,
                    "source_entity": src_label,
                    "relation_type": relation,
                    "target_entity": dst_label,
                    "relationship_text": f"{src_label} -[{relation}]-> {dst_label}"
                },
            )

    except Exception as e:
        logger.error(f"输出关系创建结果失败: {e}", exc_info=True)
async def _send_dedup_progress_callback(
self,
original_entities: int,
final_entities: int,
original_stmt_edges: int,
final_stmt_edges: int,
original_ent_edges: int,
final_ent_edges: int,
):
"""
发送去重消歧完成的进度回调,传递具体的去重和消歧效果
Args:
original_entities: 去重前实体数量
final_entities: 去重后实体数量
original_stmt_edges: 去重前陈述句-实体边数量
final_stmt_edges: 去重后陈述句-实体边数量
original_ent_edges: 去重前实体-实体边数量
final_ent_edges: 去重后实体-实体边数量
"""
try:
# 解析去重消歧报告文件,获取具体的去重和消歧效果
dedup_details = await self._parse_dedup_report()
# 计算去重效果统计
entities_reduced = original_entities - final_entities
stmt_edges_reduced = original_stmt_edges - final_stmt_edges
ent_edges_reduced = original_ent_edges - final_ent_edges
# 构建进度回调数据
dedup_stats = {
"entities": {
"original_count": original_entities,
"final_count": final_entities,
"reduced_count": entities_reduced,
"reduction_rate": round(entities_reduced / original_entities * 100, 1) if original_entities > 0 else 0,
},
"statement_entity_edges": {
"original_count": original_stmt_edges,
"final_count": final_stmt_edges,
"reduced_count": stmt_edges_reduced,
},
"entity_entity_edges": {
"original_count": original_ent_edges,
"final_count": final_ent_edges,
"reduced_count": ent_edges_reduced,
},
"dedup_examples": dedup_details.get("dedup_examples", []),
"disamb_examples": dedup_details.get("disamb_examples", []),
"summary": {
"total_merges": dedup_details.get("total_merges", 0),
"total_disambiguations": dedup_details.get("total_disambiguations", 0),
}
}
await self.progress_callback("dedup_disambiguation_complete", "去重消歧完成", dedup_stats)
except Exception as e:
logger.error(f"发送去重消歧进度回调失败: {e}", exc_info=True)
# 即使解析失败,也发送基本的统计信息
try:
basic_stats = {
"entities": {
"original_count": original_entities,
"final_count": final_entities,
"reduced_count": original_entities - final_entities,
},
"summary": f"实体去重合并{original_entities - final_entities}"
}
await self.progress_callback("dedup_disambiguation_complete", "去重消歧完成", basic_stats)
except Exception as e2:
logger.error(f"发送基本去重统计失败: {e2}", exc_info=True)
async def _parse_dedup_report(self) -> Dict[str, Any]:
"""
获取去重消歧报告,直接使用内存中的记录(不再解析日志文件)
Returns:
包含去重和消歧详细信息的字典
"""
try:
# 直接使用保存的记录构建报告
dedup_examples = []
disamb_examples = []
total_merges = 0
total_disambiguations = 0
# 处理合并记录
for record in self.dedup_merge_records:
merge_count = record.get("merged_count", 0)
total_merges += merge_count
dedup_examples.append({
"type": record.get("type", "未知"),
"entity_name": record.get("entity_name", "未知实体"),
"entity_type": record.get("entity_type", "未知类型"),
"merge_count": merge_count,
"description": f"{record.get('entity_name', '未知实体')}实体去重合并{merge_count}"
})
# 处理消歧记录
for record in self.dedup_disamb_records:
total_disambiguations += 1
# 从消歧类型中提取实体类型信息
disamb_type = record.get("disamb_type", "")
entity_name = record.get("entity_name", "未知实体")
disamb_examples.append({
"entity1_name": entity_name,
"entity1_type": disamb_type.split("vs")[0].replace("消歧阻断:", "").strip() if "vs" in disamb_type else "未知",
"entity2_name": entity_name,
"entity2_type": disamb_type.split("vs")[1].strip() if "vs" in disamb_type else "未知",
"description": f"{entity_name},消歧区分成功"
})
return {
"dedup_examples": dedup_examples[:5], # 只返回前5个示例
"disamb_examples": disamb_examples[:5], # 只返回前5个示例
"total_merges": total_merges,
"total_disambiguations": total_disambiguations,
}
except Exception as e:
logger.error(f"获取去重报告失败: {e}", exc_info=True)
return {"dedup_examples": [], "disamb_examples": [], "total_merges": 0, "total_disambiguations": 0}
# ============================================================================
# 数据加载和预处理函数

View File

@@ -13,7 +13,7 @@ from app.db import Base
class ApiKeyType(StrEnum):
"""API Key 类型"""
AGENT = "agent" # 智能体
CLUSTER = "cluster" # 集群
CLUSTER = "multi_agent" # 集群
WORKFLOW = "workflow" # 工作流
SERVICE = "service" # 服务

View File

@@ -61,7 +61,7 @@ class ModelConfig(Base):
# 时间戳
created_at = Column(DateTime, default=datetime.datetime.now, comment="创建时间")
updated_at = Column(DateTime, default=datetime.datetime.now, onupdate=datetime.datetime.now, comment="更新时间")
updated_at = Column(DateTime, default=datetime.datetime.now, onupdate=datetime.datetime.now, comment="更新时间")
# 关联关系
api_keys = relationship("ModelApiKey", back_populates="model_config", cascade="all, delete-orphan")

View File

@@ -126,6 +126,7 @@ class ApiKeyRepository:
"quota_used": api_key.quota_used,
"quota_limit": api_key.quota_limit,
"last_used_at": api_key.last_used_at,
"rate_limit": api_key.rate_limit,
"avg_response_time": float(avg_response_time) if avg_response_time else None
}

View File

@@ -46,6 +46,7 @@ class ConflictResultSchema(BaseModel):
conflict_memory: Optional[BaseDataSchema] = Field(None, description="The conflict memory data.")
@model_validator(mode="before")
@classmethod
def _normalize_data(cls, v):
if isinstance(v, dict):
d = v.get("data")
@@ -60,6 +61,7 @@ class ConflictSchema(BaseModel):
conflict_memory: Optional[BaseDataSchema] = Field(None, description="The conflict memory data.")
@model_validator(mode="before")
@classmethod
def _normalize_data(cls, v):
if isinstance(v, dict):
d = v.get("data")
@@ -88,6 +90,7 @@ class ReflexionResultSchema(BaseModel):
resolved: Optional[ResolvedSchema] = Field(None, description="The resolved memory data.")
@model_validator(mode="before")
@classmethod
def _normalize_resolved(cls, v):
if isinstance(v, dict):
conflict = v.get("conflict")
@@ -311,7 +314,7 @@ class ApiResponse(BaseModel): # 通用API响应模型
def _now_ms() -> int:
return int(round(time.time() * 1000))
return round(time.time() * 1000)
def ok(msg: str = "OK", data: Optional[Any] = None, time: Optional[int] = None) -> ApiResponse:

View File

@@ -43,12 +43,13 @@ class ApiKeyService:
existing = db.scalar(
select(ApiKey).where(
ApiKey.workspace_id == workspace_id,
ApiKey.resource_id == data.resource_id,
ApiKey.name == data.name,
ApiKey.is_active
)
)
if existing:
raise BusinessException(f"API Key 名称 '{data.name}' 已存在", BizCode.API_KEY_DUPLICATE_NAME)
raise BusinessException(f"API Key 名称 {data.name} 已存在", BizCode.API_KEY_DUPLICATE_NAME)
# 生成 API Key
api_key = generate_api_key(data.type)
@@ -137,21 +138,19 @@ class ApiKeyService:
"""更新 API Key配置"""
api_key = ApiKeyService.get_api_key(db, api_key_id, workspace_id)
if not api_key:
raise BusinessException(f"API Key {api_key_id} 不存在", BizCode.API_KEY_NOT_FOUND)
# 检查名称重复
if data.name and data.name != api_key.name:
existing = db.scalar(
select(ApiKey).where(
ApiKey.workspace_id == workspace_id,
ApiKey.resource_id == data.resource_id,
ApiKey.name == data.name,
ApiKey.is_active,
ApiKey.id != api_key_id
)
)
if existing:
raise BusinessException(f"API Key 名称 '{data.name}' 已存在", BizCode.API_KEY_DUPLICATE_NAME)
raise BusinessException(f"API Key 名称 {data.name} 已存在", BizCode.API_KEY_DUPLICATE_NAME)
update_data = data.model_dump(exclude_unset=True)
ApiKeyRepository.update(db, api_key_id, update_data)
@@ -170,9 +169,6 @@ class ApiKeyService:
"""删除 API Key"""
api_key = ApiKeyService.get_api_key(db, api_key_id, workspace_id)
if not api_key:
raise BusinessException(f"API Key {api_key_id} 不存在", BizCode.API_KEY_NOT_FOUND)
ApiKeyRepository.delete(db, api_key_id)
db.commit()
@@ -188,9 +184,6 @@ class ApiKeyService:
"""重新生成 API Key"""
api_key = ApiKeyService.get_api_key(db, api_key_id, workspace_id)
if not api_key:
raise BusinessException(f"API Key {api_key_id} 不存在", BizCode.API_KEY_NOT_FOUND)
# 检查 API Key 是否激活
if not api_key.is_active:
raise BusinessException("无法重新生成已停用的 API Key", BizCode.API_KEY_INACTIVE)
@@ -217,9 +210,6 @@ class ApiKeyService:
"""获取使用统计"""
api_key = ApiKeyService.get_api_key(db, api_key_id, workspace_id)
if not api_key:
raise BusinessException(f"API Key {api_key_id} 不存在", BizCode.API_KEY_NOT_FOUND)
stats_data = ApiKeyRepository.get_stats(db, api_key_id)
return api_key_schema.ApiKeyStats(**stats_data)
@@ -236,9 +226,6 @@ class ApiKeyService:
# 验证 API Key 权限
api_key = ApiKeyService.get_api_key(db, api_key_id, workspace_id)
if not api_key:
raise BusinessException(f"API Key {api_key_id} 不存在", BizCode.API_KEY_NOT_FOUND)
items, total = ApiKeyLogRepository.list_by_api_key(
db, api_key_id, filters, page, pagesize
)

View File

@@ -4,9 +4,12 @@ Memory Storage Service
Handles business logic for memory storage operations.
"""
from typing import Dict, List, Optional, Any
from typing import Dict, List, Optional, Any, AsyncGenerator
import os
import json
import asyncio
import time
from datetime import datetime
from sqlalchemy.orm import Session
from dotenv import load_dotenv
@@ -14,6 +17,7 @@ from dotenv import load_dotenv
from app.models.user_model import User
from app.models.end_user_model import EndUser
from app.core.logging_config import get_logger
from app.utils.sse_utils import format_sse_message
from app.schemas.memory_storage_schema import (
ConfigFilter,
ConfigPilotRun,
@@ -225,101 +229,175 @@ class DataConfigService: # 数据配置服务类PostgreSQL
return self._convert_timestamps_to_format(data_list)
async def pilot_run_stream(self, payload: ConfigPilotRun) -> AsyncGenerator[str, None]:
    """Run a pilot (trial) execution and stream its progress as SSE frames.

    The config id is taken from ``payload.config_id``; when that is empty the
    method falls back to ``selections.config_id`` in ``dbrun.json``. The
    configuration is then reloaded in memory for that id, the memory pipeline
    is started as a background task, and every progress callback it makes is
    forwarded to the caller as one SSE frame.

    Args:
        payload: Pilot-run request carrying ``config_id`` and ``dialogue_text``.

    Yields:
        SSE-formatted strings with these event types:
        - ``starting``: emitted once before any validation happens.
        - ``<stage name>``: one frame per progress callback from the pipeline,
          using the stage reported by the pipeline as the event type.
        - ``result``: the config id, time-log path and extracted result.
        - ``done``: emitted after ``result`` on success only.
        - ``error``: emitted instead of ``result``/``done`` when anything fails.

    Raises:
        asyncio.CancelledError: Re-raised when the streaming task is cancelled.
            All other failures (missing config id, empty ``dialogue_text``,
            configuration reload failure, pipeline failure, missing result
            file) are caught and reported as an ``error`` event rather than
            propagated.
    """
    project_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
    dbrun_path = os.path.join(project_root, "app", "core", "memory", "dbrun.json")
    # Bound before the try so the finally clause can always inspect it.
    pipeline_task: Optional[asyncio.Task] = None

    try:
        # Emit the initial progress event.
        yield format_sse_message("starting", {
            "message": "开始试运行...",
            "time": int(time.time() * 1000)
        })

        # Step 1: resolve and validate the configuration id.
        payload_cid = str(getattr(payload, "config_id", "") or "").strip()
        cid: Optional[str] = payload_cid if payload_cid else None

        if not cid and os.path.isfile(dbrun_path):
            try:
                with open(dbrun_path, "r", encoding="utf-8") as f:
                    dbrun = json.load(f)
                if isinstance(dbrun, dict):
                    sel = dbrun.get("selections", {})
                    if isinstance(sel, dict):
                        fallback_cid = str(sel.get("config_id") or "").strip()
                        cid = fallback_cid or None
            except Exception as exc:
                # Best-effort fallback: an unreadable or malformed dbrun.json
                # is treated as "no fallback id", but is no longer silent.
                logger.warning(f"[PILOT_RUN_STREAM] Failed to read fallback config_id from {dbrun_path}: {exc}")
                cid = None

        if not cid:
            raise ValueError("未提供 payload.config_id且 dbrun.json 未设置 selections.config_id禁止启动试运行")

        # dialogue_text is mandatory for a pilot run.
        dialogue_text = payload.dialogue_text.strip() if payload.dialogue_text else ""
        logger.info(f"[PILOT_RUN_STREAM] Received dialogue_text length: {len(dialogue_text)}, preview: {dialogue_text[:100]}")
        if not dialogue_text:
            raise ValueError("试运行模式必须提供 dialogue_text 参数")

        # Reload the configuration for this id before the pipeline is imported.
        from app.core.memory.utils.config.definitions import reload_configuration_from_database
        ok_override = reload_configuration_from_database(cid)
        if not ok_override:
            raise RuntimeError("运行时覆写失败config_id 无效或刷新常量失败")

        # Step 2: progress callback -> queue -> generator.
        progress_queue: asyncio.Queue = asyncio.Queue()

        async def progress_callback(stage: str, message: str, data: Optional[Dict[str, Any]] = None) -> None:
            """Queue one progress event reported by the pipeline.

            Args:
                stage: Stage identifier, later used as the SSE event type.
                message: Human-readable progress message.
                data: Optional result payload for the stage.
            """
            await progress_queue.put((stage, message, data))

        # Step 3: run the pipeline in a background task.
        async def run_pipeline() -> None:
            """Run the pipeline and report completion or failure via the queue."""
            try:
                from app.core.memory.main import main as pipeline_main

                logger.info(f"[PILOT_RUN_STREAM] Calling pipeline_main with dialogue_text length: {len(dialogue_text)}, is_pilot_run=True")
                await pipeline_main(
                    dialogue_text=dialogue_text,
                    is_pilot_run=True,
                    progress_callback=progress_callback
                )
                logger.info("[PILOT_RUN_STREAM] pipeline_main completed")

                # Sentinel: the pipeline finished normally.
                await progress_queue.put(("__PIPELINE_COMPLETE__", "", None))
            except Exception as e:
                # Sentinel: hand the failure over to the consuming loop.
                await progress_queue.put(("__PIPELINE_ERROR__", str(e), None))

        pipeline_task = asyncio.create_task(run_pipeline())

        # Step 4: forward queued progress events until a sentinel arrives.
        while True:
            try:
                stage, message, data = await asyncio.wait_for(
                    progress_queue.get(),
                    timeout=0.5
                )
            except asyncio.TimeoutError:
                # asyncio.wait_for raises asyncio.TimeoutError, which is only an
                # alias of the builtin TimeoutError from Python 3.11 on. Catching
                # the asyncio name keeps an idle half second from aborting the
                # stream on older interpreters.
                continue

            if stage == "__PIPELINE_COMPLETE__":
                break
            if stage == "__PIPELINE_ERROR__":
                raise RuntimeError(message)

            progress_data = {
                "message": message,
                "time": int(time.time() * 1000)
            }
            # Attach the stage result only when the pipeline supplied one.
            if data:
                progress_data["data"] = data

            # The stage name doubles as the SSE event type.
            yield format_sse_message(stage, progress_data)

        # The completion sentinel is the task's last step; join it.
        await pipeline_task

        # Step 5: read the extraction result written by the pipeline.
        from app.core.config import settings
        result_path = settings.get_memory_output_path("extracted_result.json")
        if not os.path.isfile(result_path):
            raise FileNotFoundError(f"试运行完成,但未找到提取结果文件: {result_path}")
        with open(result_path, "r", encoding="utf-8") as rf:
            extracted_result = json.load(rf)

        # Step 6: emit the result event.
        result_data = {
            "config_id": cid,
            "time_log": os.path.join(project_root, "logs", "time.log"),
            "extracted_result": extracted_result,
        }
        yield format_sse_message("result", result_data)

        # Step 7: emit the completion event.
        yield format_sse_message("done", {
            "message": "试运行完成",
            "time": int(time.time() * 1000)
        })

    except asyncio.CancelledError:
        # The streaming task was cancelled (e.g. the client went away).
        logger.info("[PILOT_RUN_STREAM] Client disconnected during streaming")
        raise
    except Exception as e:
        # Report the failure to the client as an SSE error event.
        logger.error(f"[PILOT_RUN_STREAM] Error during streaming: {e}", exc_info=True)
        yield format_sse_message("error", {
            "code": 5000,
            "message": "试运行失败",
            "error": str(e),
            "time": int(time.time() * 1000)
        })
    finally:
        # Runs on normal completion, on errors, on cancellation and when the
        # generator is closed early; stop the pipeline if it is still running
        # so it is not left orphaned once nobody is consuming its events.
        if pipeline_task is not None and not pipeline_task.done():
            pipeline_task.cancel()
# -------------------- Neo4j Search & Analytics (fused from data_search_service.py) --------------------

View File

@@ -0,0 +1,27 @@
"""
Server-Sent Events (SSE) Utility Functions
Provides shared utilities for formatting and handling SSE messages.
"""
import json
from typing import Dict, Any
def format_sse_message(event_type: str, data: Dict[str, Any]) -> str:
    r"""Format a message in Server-Sent Events (SSE) format.

    Args:
        event_type: Type of event (stage name, result, error, done). Any CR or
            LF it contains is replaced by a space so the event name cannot
            break the frame.
        data: Event data dictionary to be serialized as JSON.

    Returns:
        SSE formatted string: "event: <type>\ndata: <json>\n\n"

    Example:
        >>> format_sse_message("loading", {"message": "Loading..."})
        'event: loading\ndata: {"message": "Loading..."}\n\n'
    """
    # An SSE field ends at the first CR or LF, so a line break inside the event
    # name would end the "event:" line early and the remainder would be parsed
    # as a separate field. Flatten line breaks to keep the frame well-formed.
    safe_event_type = (
        str(event_type).replace("\r\n", " ").replace("\r", " ").replace("\n", " ")
    )
    # json.dumps escapes control characters, so the payload stays on one line.
    json_data = json.dumps(data, ensure_ascii=False)
    return f"event: {safe_event_type}\ndata: {json_data}\n\n"