From 9b8db9a0012b81aa8aec70c79f89f7a52c9f86aa Mon Sep 17 00:00:00 2001 From: Mark Date: Mon, 15 Dec 2025 20:50:15 +0800 Subject: [PATCH] [MODIFY] MEM SEE OUTPUT --- .../controllers/memory_storage_controller.py | 41 +- api/app/core/api_key_utils.py | 2 +- .../core/memory/agent/utils/write_tools.py | 2 + api/app/core/memory/dbrun.json | 2 +- api/app/core/memory/main.py | 75 ++- .../deduplication/deduped_and_disamb.py | 21 +- .../deduplication/two_stage_dedup.py | 4 +- .../extraction_orchestrator.py | 550 +++++++++++++++++- api/app/models/api_key_model.py | 2 +- api/app/models/models_model.py | 2 +- api/app/repositories/api_key_repository.py | 1 + api/app/schemas/memory_storage_schema.py | 5 +- api/app/services/api_key_service.py | 21 +- api/app/services/memory_storage_service.py | 252 +++++--- api/app/utils/sse_utils.py | 27 + 15 files changed, 863 insertions(+), 144 deletions(-) create mode 100644 api/app/utils/sse_utils.py diff --git a/api/app/controllers/memory_storage_controller.py b/api/app/controllers/memory_storage_controller.py index c1fe573e..89daf9ce 100644 --- a/api/app/controllers/memory_storage_controller.py +++ b/api/app/controllers/memory_storage_controller.py @@ -1,8 +1,9 @@ -from typing import Optional +from typing import Optional, Union import os import uuid from sqlalchemy.orm import Session -from fastapi import APIRouter, Depends, Query, UploadFile +from fastapi import APIRouter, Depends, UploadFile +from fastapi.responses import StreamingResponse from app.db import get_db @@ -322,36 +323,24 @@ def read_all_config( return fail(BizCode.INTERNAL_ERROR, "查询所有配置失败", str(e)) -@router.post("/pilot_run", response_model=ApiResponse) # 试运行:触发执行主管线,使用 POST 更为合理 +@router.post("/pilot_run", response_model=None) async def pilot_run( payload: ConfigPilotRun, current_user: User = Depends(get_current_user), db: Session = Depends(get_db), - ) -> dict: +) -> StreamingResponse: api_logger.info(f"Pilot run requested: config_id={payload.config_id}, dialogue_text_length={len(payload.dialogue_text)}") - # 先尝试从数据库加载配置 - try: - config_loaded = reload_configuration_from_database(str(payload.config_id)) - if not config_loaded: - api_logger.error(f"Failed to load configuration for config_id: {payload.config_id}") - return fail(BizCode.INTERNAL_ERROR, "配置加载失败", f"无法加载 config_id={payload.config_id} 的配置") - api_logger.info(f"Configuration loaded successfully for config_id: {payload.config_id}") - except Exception as e: - api_logger.error(f"Exception while loading configuration: {str(e)}") - return fail(BizCode.INTERNAL_ERROR, "配置加载异常", str(e)) - - try: - svc = DataConfigService(db) - result = await svc.pilot_run(payload) - return success(data=result, msg="试运行完成") - except ValueError as e: - # 捕获参数验证错误 - api_logger.error(f"Pilot run parameter validation failed: {str(e)}") - return fail(BizCode.INVALID_PARAMETER, "参数验证失败", str(e)) - except Exception as e: - api_logger.error(f"Pilot run failed: {str(e)}") - return fail(BizCode.INTERNAL_ERROR, "试运行失败", str(e)) + svc = DataConfigService(db) + return StreamingResponse( + svc.pilot_run_stream(payload), + media_type="text/event-stream", + headers={ + "Cache-Control": "no-cache", + "Connection": "keep-alive", + "X-Accel-Buffering": "no" + } + ) """ 以下为搜索与分析接口,直接挂载到同一 router,统一响应为 ApiResponse。 diff --git a/api/app/core/api_key_utils.py b/api/app/core/api_key_utils.py index 98ae0b10..5258f53e 100644 --- a/api/app/core/api_key_utils.py +++ b/api/app/core/api_key_utils.py @@ -21,7 +21,7 @@ def generate_api_key(key_type: ApiKeyType) -> str: # 前缀映射 prefix_map = { ApiKeyType.AGENT: "sk-agent-", - ApiKeyType.CLUSTER: "sk-cluster-", + ApiKeyType.CLUSTER: "sk-multi_agent-", ApiKeyType.WORKFLOW: "sk-workflow-", ApiKeyType.SERVICE: "sk-service-" } diff --git a/api/app/core/memory/agent/utils/write_tools.py b/api/app/core/memory/agent/utils/write_tools.py index fcb1b8a4..ebfbcc6c 100644 --- a/api/app/core/memory/agent/utils/write_tools.py +++ b/api/app/core/memory/agent/utils/write_tools.py @@ -106,6 +106,8 @@ async def write(content: str, user_id: str, apply_id: str, group_id: str, ref_id all_statement_chunk_edges, all_statement_entity_edges, all_entity_entity_edges, + all_dedup_details, + ) = await orchestrator.run(chunked_dialogs, is_pilot_run=False) log_time("Extraction Pipeline", time.time() - step_start, log_file) diff --git a/api/app/core/memory/dbrun.json b/api/app/core/memory/dbrun.json index fdf21963..c4220a55 100644 --- a/api/app/core/memory/dbrun.json +++ b/api/app/core/memory/dbrun.json @@ -1,5 +1,5 @@ { "selections": { - "config_id": "1" + "config_id": "" } } \ No newline at end of file diff --git a/api/app/core/memory/main.py b/api/app/core/memory/main.py index ed61e584..063dfaeb 100644 --- a/api/app/core/memory/main.py +++ b/api/app/core/memory/main.py @@ -21,7 +21,7 @@ os.environ["LANGCHAIN_TRACING"] = "false" import asyncio import time from datetime import datetime -from typing import Optional +from typing import Optional, Callable, Awaitable from dotenv import load_dotenv # 导入重构后的模块 @@ -50,7 +50,11 @@ logger = get_memory_logger(__name__) -async def main(dialogue_text: Optional[str] = None, is_pilot_run: bool = False): +async def main( + dialogue_text: Optional[str] = None, + is_pilot_run: bool = False, + progress_callback: Optional[Callable[[str, str, Optional[dict]], Awaitable[None]]] = None +): """ 记忆系统主流程 - 重构版本 @@ -61,6 +65,12 @@ async def main(dialogue_text: Optional[str] = None, is_pilot_run: bool = False): is_pilot_run: 是否为试运行模式 - True: 试运行模式,不保存到 Neo4j - False: 正常运行模式,保存到 Neo4j + progress_callback: 可选的进度回调函数 + - 类型: Callable[[str, str, Optional[dict]], Awaitable[None]] + - 参数1 (stage): 当前处理阶段标识符 + - 参数2 (message): 人类可读的进度消息 + - 参数3 (data): 可选的附加数据字典,包含详细的进度信息或结果 + - 在管线关键点调用以报告进度和结果数据 工作流程: 1. 初始化客户端和配置 @@ -141,6 +151,10 @@ async def main(dialogue_text: Optional[str] = None, is_pilot_run: bool = False): metadata={"source": "pilot_run", "input_type": "frontend_text"} ) + # 进度回调:开始预处理文本 + if progress_callback: + await progress_callback("text_preprocessing", "开始预处理文本...") + # 对前端传入的对话进行分块处理 chunked_dialogs = await get_chunked_dialogs_from_preprocessed( data=[dialog], @@ -148,6 +162,27 @@ async def main(dialogue_text: Optional[str] = None, is_pilot_run: bool = False): llm_client=llm_client, ) logger.info(f"Processed frontend dialogue text: {len(messages)} messages") + + # 进度回调:输出每个分块的结果 + if progress_callback: + for dialog in chunked_dialogs: + for i, chunk in enumerate(dialog.chunks): + chunk_result = { + "chunk_index": i + 1, + "content": chunk.content[:200] + "..." if len(chunk.content) > 200 else chunk.content, + "full_length": len(chunk.content), + "dialog_id": dialog.id, + "chunker_strategy": config_defs.SELECTED_CHUNKER_STRATEGY + } + await progress_callback("text_preprocessing_result", f"分块 {i + 1} 处理完成", chunk_result) + + # 进度回调:预处理文本完成 + preprocessing_summary = { + "total_chunks": sum(len(dialog.chunks) for dialog in chunked_dialogs), + "total_dialogs": len(chunked_dialogs), + "chunker_strategy": config_defs.SELECTED_CHUNKER_STRATEGY + } + await progress_callback("text_preprocessing_complete", "预处理文本完成", preprocessing_summary) else: # 正常运行模式:从 testdata.json 文件加载 logger.warning("[MAIN] ✗ Falling back to testdata.json (dialogue_text not provided or empty)") @@ -159,6 +194,10 @@ async def main(dialogue_text: Optional[str] = None, is_pilot_run: bool = False): if not os.path.exists(test_data_path): raise FileNotFoundError(f"Test data file not found: {test_data_path}") + # 进度回调:开始预处理文本 + if progress_callback: + await progress_callback("text_preprocessing", "开始预处理文本...") + chunked_dialogs = await get_chunked_dialogs_with_preprocessing( chunker_strategy=config_defs.SELECTED_CHUNKER_STRATEGY, group_id=config_defs.SELECTED_GROUP_ID, @@ -170,6 +209,27 @@ async def main(dialogue_text: Optional[str] = None, is_pilot_run: bool = False): skip_cleaning=True, ) logger.info(f"Loaded {len(chunked_dialogs)} dialogues from testdata.json") + + # 进度回调:输出每个分块的结果 + if progress_callback: + for dialog in chunked_dialogs: + for i, chunk in enumerate(dialog.chunks): + chunk_result = { + "chunk_index": i + 1, + "content": chunk.content[:200] + "..." if len(chunk.content) > 200 else chunk.content, + "full_length": len(chunk.content), + "dialog_id": dialog.id, + "chunker_strategy": config_defs.SELECTED_CHUNKER_STRATEGY + } + await progress_callback("text_preprocessing_result", f"分块 {i + 1} 处理完成", chunk_result) + + # 进度回调:预处理文本完成 + preprocessing_summary = { + "total_chunks": sum(len(dialog.chunks) for dialog in chunked_dialogs), + "total_dialogs": len(chunked_dialogs), + "chunker_strategy": config_defs.SELECTED_CHUNKER_STRATEGY + } + await progress_callback("text_preprocessing_complete", "预处理文本完成", preprocessing_summary) log_time("Data Loading & Chunking", time.time() - step_start, log_file) @@ -188,6 +248,7 @@ async def main(dialogue_text: Optional[str] = None, is_pilot_run: bool = False): embedder_client=embedder_client, connector=neo4j_connector, config=config, + progress_callback=progress_callback, # 传递进度回调 ) log_time("Orchestrator Initialization", time.time() - step_start, log_file) @@ -196,6 +257,11 @@ async def main(dialogue_text: Optional[str] = None, is_pilot_run: bool = False): logger.info("Running extraction pipeline...") step_start = time.time() + + # 进度回调:正在知识抽取 + if progress_callback: + await progress_callback("knowledge_extraction", "正在知识抽取...") + extraction_result = await orchestrator.run( dialog_data_list=chunked_dialogs, is_pilot_run=is_pilot_run, # 传递试运行模式标志 @@ -216,6 +282,11 @@ async def main(dialogue_text: Optional[str] = None, is_pilot_run: bool = False): ) = extraction_result log_time("Extraction Pipeline", time.time() - step_start, log_file) + + # 进度回调:生成结果 + if progress_callback: + await progress_callback("generating_results", "正在生成结果...") + # 步骤 5: 保存结果或输出结果 if is_pilot_run: diff --git a/api/app/core/memory/storage_services/extraction_engine/deduplication/deduped_and_disamb.py b/api/app/core/memory/storage_services/extraction_engine/deduplication/deduped_and_disamb.py index 64a28590..9088a300 100644 --- a/api/app/core/memory/storage_services/extraction_engine/deduplication/deduped_and_disamb.py +++ b/api/app/core/memory/storage_services/extraction_engine/deduplication/deduped_and_disamb.py @@ -2,7 +2,7 @@ 去重功能函数 """ from app.core.memory.models.variate_config import DedupConfig -from typing import List, Dict, Tuple +from typing import List, Dict, Tuple, Any from app.core.memory.models.graph_models import( StatementEntityEdge, EntityEntityEdge, @@ -895,7 +895,12 @@ async def deduplicate_entities_and_edges( report_append: bool = False, report_stage_notes: List[str] | None = None, dedup_config: DedupConfig | None = None, -) -> Tuple[List[ExtractedEntityNode], List[StatementEntityEdge], List[EntityEntityEdge]]: +) -> Tuple[ + List[ExtractedEntityNode], + List[StatementEntityEdge], + List[EntityEntityEdge], + Dict[str, Any] # 新增:返回详细的去重消歧记录 +]: """ 主流程:依次执行精确匹配、模糊匹配与(可选)LLM 决策融合,随后对边做重定向与去重。之后再处理边,是关系去重和消歧 返回:去重后的实体、语句→实体边、实体↔实体边。 @@ -981,8 +986,18 @@ async def deduplicate_entities_and_edges( append=report_append, stage_notes=report_stage_notes, ) + + # 构建详细的去重消歧记录(用于内存访问,避免解析日志文件) + dedup_details = { + "exact_merge_map": exact_merge_map, + "fuzzy_merge_records": fuzzy_merge_records, + "llm_decision_records": local_llm_records, + "disamb_records": disamb_records, + "id_redirect": id_redirect, + "blocked_pairs": blocked_pairs, + } - return deduped_entities, list(stmt_ent_map.values()), list(ent_ent_map.values()) + return deduped_entities, list(stmt_ent_map.values()), list(ent_ent_map.values()), dedup_details # 独立模块:去重融合报告写入(与实体/边的计算解耦) def _write_dedup_fusion_report( diff --git a/api/app/core/memory/storage_services/extraction_engine/deduplication/two_stage_dedup.py b/api/app/core/memory/storage_services/extraction_engine/deduplication/two_stage_dedup.py index a5f600b4..e4857ff3 100644 --- a/api/app/core/memory/storage_services/extraction_engine/deduplication/two_stage_dedup.py +++ b/api/app/core/memory/storage_services/extraction_engine/deduplication/two_stage_dedup.py @@ -39,6 +39,7 @@ async def dedup_layers_and_merge_and_return( List[StatementChunkEdge], List[StatementEntityEdge], List[EntityEntityEdge], + dict, # 新增:返回去重详情 ]: """ 执行两层实体去重与融合: @@ -62,7 +63,7 @@ async def dedup_layers_and_merge_and_return( break # 第一层去重消歧 - dedup_entity_nodes, dedup_statement_entity_edges, dedup_entity_entity_edges = await deduplicate_entities_and_edges( + dedup_entity_nodes, dedup_statement_entity_edges, dedup_entity_entity_edges, dedup_details = await deduplicate_entities_and_edges( entity_nodes, statement_entity_edges, entity_entity_edges, @@ -103,4 +104,5 @@ async def dedup_layers_and_merge_and_return( statement_chunk_edges, fused_statement_entity_edges, fused_entity_entity_edges, + dedup_details, # 返回去重详情 ) diff --git a/api/app/core/memory/storage_services/extraction_engine/extraction_orchestrator.py b/api/app/core/memory/storage_services/extraction_engine/extraction_orchestrator.py index 024a812b..7eec1189 100644 --- a/api/app/core/memory/storage_services/extraction_engine/extraction_orchestrator.py +++ b/api/app/core/memory/storage_services/extraction_engine/extraction_orchestrator.py @@ -12,13 +12,14 @@ 5. 提供错误处理和日志记录 6. 支持试运行模式(不写入数据库) -作者:Memory Refactoring Team +作者: 日期:2025-11-21 """ import asyncio import logging -from typing import List, Dict, Any, Tuple, Optional +import os +from typing import List, Dict, Any, Tuple, Optional, Callable, Awaitable from datetime import datetime from app.core.memory.models.message_models import DialogData @@ -94,6 +95,7 @@ class ExtractionOrchestrator: embedder_client: OpenAIEmbedderClient, connector: Neo4jConnector, config: Optional[ExtractionPipelineConfig] = None, + progress_callback: Optional[Callable[[str, str, Optional[Dict[str, Any]]], Awaitable[None]]] = None, ): """ 初始化流水线编排器 @@ -103,12 +105,21 @@ class ExtractionOrchestrator: embedder_client: 嵌入模型客户端 connector: Neo4j 连接器 config: 流水线配置,如果为 None 则使用默认配置 + progress_callback: 进度回调函数 + - 接受 (stage: str, message: str, data: Optional[Dict[str, Any]]) 并返回 Awaitable[None] + - 在管线关键点调用以报告进度和结果数据 """ self.llm_client = llm_client self.embedder_client = embedder_client self.connector = connector self.config = config or ExtractionPipelineConfig() self.is_pilot_run = False # 默认非试运行模式 + self.progress_callback = progress_callback # 保存进度回调函数 + + # 保存去重消歧的详细记录(内存中的数据结构) + self.dedup_merge_records: List[Dict[str, Any]] = [] # 实体合并记录 + self.dedup_disamb_records: List[Dict[str, Any]] = [] # 实体消歧记录 + self.id_redirect_map: Dict[str, str] = {} # ID重定向映射 # 初始化各个提取器 self.statement_extractor = StatementExtractor( @@ -160,6 +171,13 @@ class ExtractionOrchestrator: # 步骤 1: 陈述句提取 logger.info("步骤 1/6: 陈述句提取(全局分块级并行)") dialog_data_list = await self._extract_statements(dialog_data_list) + + # 收集陈述句内容和统计数量 + all_statements_list = [] + for dialog in dialog_data_list: + for chunk in dialog.chunks: + all_statements_list.extend(chunk.statements) + total_statements = len(all_statements_list) # 步骤 2: 并行执行三元组提取、时间信息提取和基础嵌入生成 logger.info("步骤 2/6: 并行执行三元组提取、时间信息提取和嵌入生成") @@ -170,11 +188,90 @@ class ExtractionOrchestrator: chunk_embedding_maps, dialog_embeddings, ) = await self._parallel_extract_and_embed(dialog_data_list) + + # 收集实体和三元组内容,并统计数量 + all_entities_list = [] + all_triplets_list = [] + for triplet_map in triplet_maps: + for triplet_info in triplet_map.values(): + if triplet_info: + all_entities_list.extend(triplet_info.entities) + all_triplets_list.extend(triplet_info.triplets) + + total_entities = len(all_entities_list) + total_triplets = len(all_triplets_list) + total_temporal = sum(len(temporal_map) for temporal_map in temporal_maps) # 步骤 3: 生成实体嵌入(依赖三元组提取结果) logger.info("步骤 3/6: 生成实体嵌入") triplet_maps = await self._generate_entity_embeddings(triplet_maps) + # 进度回调:按三个阶段分别输出知识抽取结果 + if self.progress_callback: + # 第一阶段:陈述句提取结果 + for i, stmt in enumerate(all_statements_list[:10]): # 只输出前10个陈述句 + stmt_result = { + "extraction_type": "statement", + "statement_index": i + 1, + "statement": stmt.statement, + "statement_id": stmt.id + } + await self.progress_callback("knowledge_extraction_result", "陈述句提取完成", stmt_result) + + # 第二阶段:三元组提取结果 + for i, triplet in enumerate(all_triplets_list[:10]): # 只输出前10个三元组 + triplet_result = { + "extraction_type": "triplet", + "triplet_index": i + 1, + "subject": triplet.subject_name, + "predicate": triplet.predicate, + "object": triplet.object_name + } + await self.progress_callback("knowledge_extraction_result", "三元组提取完成", triplet_result) + + # 第三阶段:时间提取结果 + if total_temporal > 0: + # 收集时间信息 + temporal_results = [] + for dialog in dialog_data_list: + for chunk in dialog.chunks: + for statement in chunk.statements: + if hasattr(statement, 'temporal_validity') and statement.temporal_validity: + temporal_results.append({ + "statement_id": statement.id, + "statement": statement.statement, + "valid_at": statement.temporal_validity.valid_at, + "invalid_at": statement.temporal_validity.invalid_at + }) + + # 输出时间提取结果 + for i, temporal_result in enumerate(temporal_results[:5]): # 只输出前5个时间提取结果 + time_result = { + "extraction_type": "temporal", + "temporal_index": i + 1, + "statement": temporal_result["statement"], + "valid_at": temporal_result["valid_at"], + "invalid_at": temporal_result["invalid_at"] + } + await self.progress_callback("knowledge_extraction_result", "时间提取完成", time_result) + else: + # 如果没有时间信息,也发送一个时间提取完成的消息 + time_result = { + "extraction_type": "temporal", + "temporal_index": 0, + "message": "未发现时间信息" + } + await self.progress_callback("knowledge_extraction_result", "时间提取完成", time_result) + + # 进度回调:知识抽取完成,传递知识抽取的统计信息 + extraction_stats = { + "statements_count": total_statements, + "entities_count": total_entities, + "triplets_count": total_triplets, + "temporal_ranges_count": total_temporal, + } + await self.progress_callback("knowledge_extraction_complete", "知识抽取完成", extraction_stats) + # 步骤 4: 将提取的数据赋值到语句 logger.info("步骤 4/6: 数据赋值") dialog_data_list = await self._assign_extracted_data( @@ -218,6 +315,8 @@ class ExtractionOrchestrator: dialog_data_list, ) + + logger.info(f"知识提取流水线运行完成({mode_str})") return result @@ -732,6 +831,10 @@ class ExtractionOrchestrator: 包含所有节点和边的元组 """ logger.info("开始创建节点和边") + + # 进度回调:正在创建节点和边 + if self.progress_callback: + await self.progress_callback("creating_nodes_edges", "正在创建节点和边...") dialogue_nodes = [] chunk_nodes = [] @@ -904,6 +1007,23 @@ class ExtractionOrchestrator: f"陈述句-实体边: {len(statement_entity_edges)}, " f"实体-实体边: {len(entity_entity_edges)}" ) + + # 进度回调:只输出关系创建结果 + if self.progress_callback: + # 输出关系创建结果 + await self._output_relationship_creation_results(entity_entity_edges, entity_nodes) + + # 进度回调:创建节点和边完成,传递结果统计 + nodes_edges_stats = { + "dialogue_nodes_count": len(dialogue_nodes), + "chunk_nodes_count": len(chunk_nodes), + "statement_nodes_count": len(statement_nodes), + "entity_nodes_count": len(entity_nodes), + "statement_chunk_edges_count": len(statement_chunk_edges), + "statement_entity_edges_count": len(statement_entity_edges), + "entity_entity_edges_count": len(entity_entity_edges), + } + await self.progress_callback("creating_nodes_edges_complete", "创建节点和边完成", nodes_edges_stats) return ( dialogue_nodes, @@ -950,6 +1070,11 @@ class ExtractionOrchestrator: - 第三个元组:去重后的 (实体节点列表, 陈述句-实体边列表, 实体-实体边列表) """ logger.info("开始两阶段实体去重和消歧") + + # 进度回调:正在去重消歧 + if self.progress_callback: + await self.progress_callback("deduplication", "正在去重消歧...") + logger.info( f"去重前: {len(entity_nodes)} 个实体节点, " f"{len(statement_entity_edges)} 条陈述句-实体边, " @@ -963,7 +1088,7 @@ class ExtractionOrchestrator: # 只执行第一层去重 from app.core.memory.storage_services.extraction_engine.deduplication.deduped_and_disamb import deduplicate_entities_and_edges - dedup_entity_nodes, dedup_statement_entity_edges, dedup_entity_entity_edges = await deduplicate_entities_and_edges( + dedup_entity_nodes, dedup_statement_entity_edges, dedup_entity_entity_edges, dedup_details = await deduplicate_entities_and_edges( entity_nodes, statement_entity_edges, entity_entity_edges, @@ -972,6 +1097,9 @@ class ExtractionOrchestrator: dedup_config=self.config.deduplication, ) + # 保存去重消歧的详细记录到实例变量 + self._save_dedup_details(dedup_details, entity_nodes, dedup_entity_nodes) + result_tuple = ( dialogue_nodes, chunk_nodes, @@ -1009,7 +1137,11 @@ class ExtractionOrchestrator: _, final_statement_entity_edges, final_entity_entity_edges, + dedup_details, ) = result_tuple + + # 保存去重消歧的详细记录到实例变量 + self._save_dedup_details(dedup_details, entity_nodes, final_entity_nodes) logger.info( f"去重后: {len(final_entity_nodes)} 个实体节点, " @@ -1021,6 +1153,46 @@ class ExtractionOrchestrator: f"陈述句-实体边减少 {len(statement_entity_edges) - len(final_statement_entity_edges)}, " f"实体-实体边减少 {len(entity_entity_edges) - len(final_entity_entity_edges)}" ) + + # 进度回调:输出去重消歧的具体结果 + if self.progress_callback: + # 分析实体合并情况 + merge_info = await self._analyze_entity_merges(entity_nodes, final_entity_nodes) + + # 输出去重合并的实体示例 + for i, merge_detail in enumerate(merge_info[:5]): # 输出前5个去重结果 + dedup_result = { + "result_type": "entity_merge", + "merged_entity_name": merge_detail["main_entity_name"], + "merged_count": merge_detail["merged_count"], + "message": f"{merge_detail['main_entity_name']}合并{merge_detail['merged_count']}个:相似实体已合并" + } + await self.progress_callback("dedup_disambiguation_result", "实体去重完成", dedup_result) + + # 分析实体消歧情况 + disamb_info = await self._analyze_entity_disambiguation(entity_nodes, final_entity_nodes) + + # 输出实体消歧的结果 + for i, disamb_detail in enumerate(disamb_info[:5]): # 输出前5个消歧结果 + disamb_result = { + "result_type": "entity_disambiguation", + "disambiguated_entity_name": disamb_detail["entity_name"], + "disambiguation_type": disamb_detail["disamb_type"], + "confidence": disamb_detail.get("confidence", "unknown"), + "reason": disamb_detail.get("reason", ""), + "message": f"{disamb_detail['entity_name']}消歧完成:{disamb_detail['disamb_type']}" + } + await self.progress_callback("dedup_disambiguation_result", "实体消歧完成", disamb_result) + + + + # 进度回调:去重消歧完成,传递去重和消歧的具体效果 + await self._send_dedup_progress_callback( + len(entity_nodes), len(final_entity_nodes), + len(statement_entity_edges), len(final_statement_entity_edges), + len(entity_entity_edges), len(final_entity_entity_edges) + ) + # 写入提取结果汇总(试运行和正式模式都需要生成) try: @@ -1041,6 +1213,378 @@ class ExtractionOrchestrator: logger.error(f"两阶段去重失败: {e}", exc_info=True) raise + def _save_dedup_details( + self, + dedup_details: Dict[str, Any], + original_entities: List[ExtractedEntityNode], + final_entities: List[ExtractedEntityNode] + ): + """ + 保存去重消歧的详细记录到实例变量(基于内存数据结构) + + Args: + dedup_details: 去重函数返回的详细记录 + original_entities: 去重前的实体列表 + final_entities: 去重后的实体列表 + """ + try: + # 保存ID重定向映射 + self.id_redirect_map = dedup_details.get("id_redirect", {}) + + # 处理精确匹配的合并记录 + exact_merge_map = dedup_details.get("exact_merge_map", {}) + for key, info in exact_merge_map.items(): + merged_ids = info.get("merged_ids", set()) + if merged_ids: + self.dedup_merge_records.append({ + "type": "精确匹配", + "canonical_id": info.get("canonical_id"), + "entity_name": info.get("name"), + "entity_type": info.get("entity_type"), + "merged_count": len(merged_ids), + "merged_ids": list(merged_ids) + }) + + # 处理模糊匹配的合并记录 + fuzzy_merge_records = dedup_details.get("fuzzy_merge_records", []) + for record in fuzzy_merge_records: + # 解析模糊匹配记录字符串 + # 格式: "[模糊] 规范实体 id (group|name|type) <- 合并实体 id (group|name|type) | s_name=0.xxx, ..." + try: + import re + match = re.search(r"规范实体 (\S+) \(([^|]+)\|([^|]+)\|([^)]+)\) <- 合并实体 (\S+)", record) + if match: + self.dedup_merge_records.append({ + "type": "模糊匹配", + "canonical_id": match.group(1), + "entity_name": match.group(3), + "entity_type": match.group(4), + "merged_count": 1, + "merged_ids": [match.group(5)] + }) + except Exception as e: + logger.debug(f"解析模糊匹配记录失败: {record}, 错误: {e}") + + # 处理LLM去重的合并记录 + llm_decision_records = dedup_details.get("llm_decision_records", []) + for record in llm_decision_records: + if "[LLM去重]" in str(record): + try: + import re + # 格式: "[LLM去重] 同名类型相似 name1(type1)|name2(type2) | conf=0.xx | reason=..." + match = re.search(r"同名类型相似 ([^(]+)(([^)]+))\|([^(]+)(([^)]+))", record) + if match: + self.dedup_merge_records.append({ + "type": "LLM去重", + "entity_name": match.group(1), + "entity_type": f"{match.group(2)}|{match.group(4)}", + "merged_count": 1, + "merged_ids": [] + }) + except Exception as e: + logger.debug(f"解析LLM去重记录失败: {record}, 错误: {e}") + + # 处理消歧记录 + disamb_records = dedup_details.get("disamb_records", []) + for record in disamb_records: + if "[DISAMB阻断]" in str(record): + try: + import re + # 格式: "[DISAMB阻断] name1(type1)|name2(type2) | conf=0.xx | reason=..." + content = str(record).replace("[DISAMB阻断]", "").strip() + match = re.search(r"([^(]+)(([^)]+))\|([^(]+)(([^)]+))", content) + if match: + entity1_name = match.group(1).strip() + entity1_type = match.group(2) + entity2_name = match.group(3).strip() + entity2_type = match.group(4) + + # 提取置信度和原因 + conf_match = re.search(r"conf=([0-9.]+)", str(record)) + confidence = conf_match.group(1) if conf_match else "unknown" + + reason_match = re.search(r"reason=([^|]+)", str(record)) + reason = reason_match.group(1).strip() if reason_match else "" + + self.dedup_disamb_records.append({ + "entity_name": entity1_name, + "disamb_type": f"消歧阻断:{entity1_type} vs {entity2_type}", + "confidence": confidence, + "reason": reason[:100] + "..." if len(reason) > 100 else reason + }) + except Exception as e: + logger.debug(f"解析消歧记录失败: {record}, 错误: {e}") + + logger.info(f"保存去重消歧记录:{len(self.dedup_merge_records)} 个合并记录,{len(self.dedup_disamb_records)} 个消歧记录") + + except Exception as e: + logger.error(f"保存去重消歧详情失败: {e}", exc_info=True) + + async def _analyze_entity_merges( + self, + original_entities: List[ExtractedEntityNode], + final_entities: List[ExtractedEntityNode] + ) -> List[Dict[str, Any]]: + """ + 分析实体合并情况,直接使用内存中的合并记录(不再解析日志文件) + + Args: + original_entities: 去重前的实体列表 + final_entities: 去重后的实体列表 + + Returns: + 合并详情列表,每个元素包含主实体名称和合并数量 + """ + try: + # 直接使用保存的合并记录 + if self.dedup_merge_records: + # 按合并数量排序,返回前几个 + sorted_records = sorted( + self.dedup_merge_records, + key=lambda x: x.get("merged_count", 0), + reverse=True + ) + + merge_info = [] + for record in sorted_records: + merge_info.append({ + "main_entity_name": record.get("entity_name", "未知实体"), + "merged_count": record.get("merged_count", 1) + }) + + return merge_info + + # 如果没有保存的记录,返回空列表 + logger.info("未找到实体合并记录") + return [] + + except Exception as e: + logger.error(f"分析实体合并情况失败: {e}", exc_info=True) + return [] + + async def _analyze_entity_disambiguation( + self, + original_entities: List[ExtractedEntityNode], + final_entities: List[ExtractedEntityNode] + ) -> List[Dict[str, Any]]: + """ + 分析实体消歧情况,直接使用内存中的消歧记录(不再解析日志文件) + + Args: + original_entities: 去重前的实体列表 + final_entities: 去重后的实体列表 + + Returns: + 消歧详情列表,每个元素包含实体名称和消歧类型 + """ + try: + # 直接使用保存的消歧记录 + if self.dedup_disamb_records: + return self.dedup_disamb_records + + # 如果没有保存的记录,返回空列表 + logger.info("未找到实体消歧记录") + return [] + + except Exception as e: + logger.error(f"分析实体消歧情况失败: {e}", exc_info=True) + return [] + + def _get_entity_type_display_name(self, entity_type: str) -> str: + """ + 获取实体类型的中文显示名称 + + Args: + entity_type: 英文实体类型 + + Returns: + 中文显示名称 + """ + type_mapping = { + "Person": "人物实体节点", + "Organization": "组织实体节点", + "ORG": "组织实体节点", + "Location": "地点实体节点", + "LOC": "地点实体节点", + "Event": "事件实体节点", + "Concept": "概念实体节点", + "Time": "时间实体节点", + "Position": "职位实体节点", + "WorkRole": "职业实体节点", + "System": "系统实体节点", + "Policy": "政策实体节点", + "HistoricalPeriod": "历史时期实体节点", + "HistoricalState": "历史国家实体节点", + "HistoricalEvent": "历史事件实体节点", + "EconomicFactor": "经济因素实体节点", + "Condition": "条件实体节点", + "Numeric": "数值实体节点" + } + return type_mapping.get(entity_type, f"{entity_type}实体节点") + + async def _output_relationship_creation_results( + self, + entity_entity_edges: List[EntityEntityEdge], + entity_nodes: List[ExtractedEntityNode] + ): + """ + 输出关系创建结果 + + Args: + entity_entity_edges: 实体-实体边列表 + entity_nodes: 实体节点列表 + """ + try: + # 创建实体ID到名称的映射 + entity_id_to_name = {node.id: node.name for node in entity_nodes} + + # 输出关系创建结果 + for i, edge in enumerate(entity_entity_edges[:10]): # 只输出前10个关系 + source_name = entity_id_to_name.get(edge.source, f"Entity_{edge.source}") + target_name = entity_id_to_name.get(edge.target, f"Entity_{edge.target}") + relation_type = edge.relation_type + + relationship_result = { + "result_type": "relationship_creation", + "relationship_index": i + 1, + "source_entity": source_name, + "relation_type": relation_type, + "target_entity": target_name, + "relationship_text": f"{source_name} -[{relation_type}]-> {target_name}" + } + + await self.progress_callback("creating_nodes_edges_result", "关系创建", relationship_result) + + except Exception as e: + logger.error(f"输出关系创建结果失败: {e}", exc_info=True) + + async def _send_dedup_progress_callback( + self, + original_entities: int, + final_entities: int, + original_stmt_edges: int, + final_stmt_edges: int, + original_ent_edges: int, + final_ent_edges: int, + ): + """ + 发送去重消歧完成的进度回调,传递具体的去重和消歧效果 + + Args: + original_entities: 去重前实体数量 + final_entities: 去重后实体数量 + original_stmt_edges: 去重前陈述句-实体边数量 + final_stmt_edges: 去重后陈述句-实体边数量 + original_ent_edges: 去重前实体-实体边数量 + final_ent_edges: 去重后实体-实体边数量 + """ + try: + # 解析去重消歧报告文件,获取具体的去重和消歧效果 + dedup_details = await self._parse_dedup_report() + + # 计算去重效果统计 + entities_reduced = original_entities - final_entities + stmt_edges_reduced = original_stmt_edges - final_stmt_edges + ent_edges_reduced = original_ent_edges - final_ent_edges + + # 构建进度回调数据 + dedup_stats = { + "entities": { + "original_count": original_entities, + "final_count": final_entities, + "reduced_count": entities_reduced, + "reduction_rate": round(entities_reduced / original_entities * 100, 1) if original_entities > 0 else 0, + }, + "statement_entity_edges": { + "original_count": original_stmt_edges, + "final_count": final_stmt_edges, + "reduced_count": stmt_edges_reduced, + }, + "entity_entity_edges": { + "original_count": original_ent_edges, + "final_count": final_ent_edges, + "reduced_count": ent_edges_reduced, + }, + "dedup_examples": dedup_details.get("dedup_examples", []), + "disamb_examples": dedup_details.get("disamb_examples", []), + "summary": { + "total_merges": dedup_details.get("total_merges", 0), + "total_disambiguations": dedup_details.get("total_disambiguations", 0), + } + } + + await self.progress_callback("dedup_disambiguation_complete", "去重消歧完成", dedup_stats) + + except Exception as e: + logger.error(f"发送去重消歧进度回调失败: {e}", exc_info=True) + # 即使解析失败,也发送基本的统计信息 + try: + basic_stats = { + "entities": { + "original_count": original_entities, + "final_count": final_entities, + "reduced_count": original_entities - final_entities, + }, + "summary": f"实体去重合并{original_entities - final_entities}个" + } + await self.progress_callback("dedup_disambiguation_complete", "去重消歧完成", basic_stats) + except Exception as e2: + logger.error(f"发送基本去重统计失败: {e2}", exc_info=True) + + async def _parse_dedup_report(self) -> Dict[str, Any]: + """ + 获取去重消歧报告,直接使用内存中的记录(不再解析日志文件) + + Returns: + 包含去重和消歧详细信息的字典 + """ + try: + # 直接使用保存的记录构建报告 + dedup_examples = [] + disamb_examples = [] + total_merges = 0 + total_disambiguations = 0 + + # 处理合并记录 + for record in self.dedup_merge_records: + merge_count = record.get("merged_count", 0) + total_merges += merge_count + + dedup_examples.append({ + "type": record.get("type", "未知"), + "entity_name": record.get("entity_name", "未知实体"), + "entity_type": record.get("entity_type", "未知类型"), + "merge_count": merge_count, + "description": f"{record.get('entity_name', '未知实体')}实体去重合并{merge_count}个" + }) + + # 处理消歧记录 + for record in self.dedup_disamb_records: + total_disambiguations += 1 + + # 从消歧类型中提取实体类型信息 + disamb_type = record.get("disamb_type", "") + entity_name = record.get("entity_name", "未知实体") + + disamb_examples.append({ + "entity1_name": entity_name, + "entity1_type": disamb_type.split("vs")[0].replace("消歧阻断:", "").strip() if "vs" in disamb_type else "未知", + "entity2_name": entity_name, + "entity2_type": disamb_type.split("vs")[1].strip() if "vs" in disamb_type else "未知", + "description": f"{entity_name},消歧区分成功" + }) + + return { + "dedup_examples": dedup_examples[:5], # 只返回前5个示例 + "disamb_examples": disamb_examples[:5], # 只返回前5个示例 + "total_merges": total_merges, + "total_disambiguations": total_disambiguations, + } + + except Exception as e: + logger.error(f"获取去重报告失败: {e}", exc_info=True) + return {"dedup_examples": [], "disamb_examples": [], "total_merges": 0, "total_disambiguations": 0} + # ============================================================================ # 数据加载和预处理函数 diff --git a/api/app/models/api_key_model.py b/api/app/models/api_key_model.py index f7cea634..791b99a0 100644 --- a/api/app/models/api_key_model.py +++ b/api/app/models/api_key_model.py @@ -13,7 +13,7 @@ from app.db import Base class ApiKeyType(StrEnum): """API Key 类型""" AGENT = "agent" # 智能体 - CLUSTER = "cluster" # 集群 + CLUSTER = "multi_agent" # 集群 WORKFLOW = "workflow" # 工作流 SERVICE = "service" # 服务 diff --git a/api/app/models/models_model.py b/api/app/models/models_model.py index 3b0c1221..2e60ef1c 100644 --- a/api/app/models/models_model.py +++ b/api/app/models/models_model.py @@ -61,7 +61,7 @@ class ModelConfig(Base): # 时间戳 created_at = Column(DateTime, default=datetime.datetime.now, comment="创建时间") - updated_at = Column(DateTime, default=datetime.datetime.now, onupdate=datetime.datetime.now, comment="更新时间") + updated_at = Column(DateTime, default=datetime.datetime.now, onupdate=datetime.datetime.now, comment="更新时间") # 关联关系 api_keys = relationship("ModelApiKey", back_populates="model_config", cascade="all, delete-orphan") diff --git a/api/app/repositories/api_key_repository.py b/api/app/repositories/api_key_repository.py index ad94fccf..757a840a 100644 --- a/api/app/repositories/api_key_repository.py +++ b/api/app/repositories/api_key_repository.py @@ -126,6 +126,7 @@ class ApiKeyRepository: "quota_used": api_key.quota_used, "quota_limit": api_key.quota_limit, "last_used_at": api_key.last_used_at, + "rate_limit": api_key.rate_limit, "avg_response_time": float(avg_response_time) if avg_response_time else None } diff --git a/api/app/schemas/memory_storage_schema.py b/api/app/schemas/memory_storage_schema.py index 2ff773f3..66b2e45f 100644 --- a/api/app/schemas/memory_storage_schema.py +++ b/api/app/schemas/memory_storage_schema.py @@ -46,6 +46,7 @@ class ConflictResultSchema(BaseModel): conflict_memory: Optional[BaseDataSchema] = Field(None, description="The conflict memory data.") @model_validator(mode="before") + @classmethod def _normalize_data(cls, v): if isinstance(v, dict): d = v.get("data") @@ -60,6 +61,7 @@ class ConflictSchema(BaseModel): conflict_memory: Optional[BaseDataSchema] = Field(None, description="The conflict memory data.") @model_validator(mode="before") + @classmethod def _normalize_data(cls, v): if isinstance(v, dict): d = v.get("data") @@ -88,6 +90,7 @@ class ReflexionResultSchema(BaseModel): resolved: Optional[ResolvedSchema] = Field(None, description="The resolved memory data.") @model_validator(mode="before") + @classmethod def _normalize_resolved(cls, v): if isinstance(v, dict): conflict = v.get("conflict") @@ -311,7 +314,7 @@ class ApiResponse(BaseModel): # 通用API响应模型 def _now_ms() -> int: - return int(round(time.time() * 1000)) + return round(time.time() * 1000) def ok(msg: str = "OK", data: Optional[Any] = None, time: Optional[int] = None) -> ApiResponse: diff --git a/api/app/services/api_key_service.py b/api/app/services/api_key_service.py index 09ba5ca1..2d7393e3 100644 --- a/api/app/services/api_key_service.py +++ b/api/app/services/api_key_service.py @@ -43,12 +43,13 @@ class ApiKeyService: existing = db.scalar( select(ApiKey).where( ApiKey.workspace_id == workspace_id, + ApiKey.resource_id == data.resource_id, ApiKey.name == data.name, ApiKey.is_active ) ) if existing: - raise BusinessException(f"API Key 名称 '{data.name}' 已存在", BizCode.API_KEY_DUPLICATE_NAME) + raise BusinessException(f"API Key 名称 {data.name} 已存在", BizCode.API_KEY_DUPLICATE_NAME) # 生成 API Key api_key = generate_api_key(data.type) @@ -137,21 +138,19 @@ class ApiKeyService: """更新 API Key配置""" api_key = ApiKeyService.get_api_key(db, api_key_id, workspace_id) - if not api_key: - raise BusinessException(f"API Key {api_key_id} 不存在", BizCode.API_KEY_NOT_FOUND) - # 检查名称重复 if data.name and data.name != api_key.name: existing = db.scalar( select(ApiKey).where( ApiKey.workspace_id == workspace_id, + ApiKey.resource_id == data.resource_id, ApiKey.name == data.name, ApiKey.is_active, ApiKey.id != api_key_id ) ) if existing: - raise BusinessException(f"API Key 名称 '{data.name}' 已存在", BizCode.API_KEY_DUPLICATE_NAME) + raise BusinessException(f"API Key 名称 {data.name} 已存在", BizCode.API_KEY_DUPLICATE_NAME) update_data = data.model_dump(exclude_unset=True) ApiKeyRepository.update(db, api_key_id, update_data) @@ -170,9 +169,6 @@ class ApiKeyService: """删除 API Key""" api_key = ApiKeyService.get_api_key(db, api_key_id, workspace_id) - if not api_key: - raise BusinessException(f"API Key {api_key_id} 不存在", BizCode.API_KEY_NOT_FOUND) - ApiKeyRepository.delete(db, api_key_id) db.commit() @@ -188,9 +184,6 @@ class ApiKeyService: """重新生成 API Key""" api_key = ApiKeyService.get_api_key(db, api_key_id, workspace_id) - if not api_key: - raise BusinessException(f"API Key {api_key_id} 不存在", BizCode.API_KEY_NOT_FOUND) - # 检查 API Key 是否激活 if not api_key.is_active: raise BusinessException("无法重新生成已停用的 API Key", BizCode.API_KEY_INACTIVE) @@ -217,9 +210,6 @@ class ApiKeyService: """获取使用统计""" api_key = ApiKeyService.get_api_key(db, api_key_id, workspace_id) - if not api_key: - raise BusinessException(f"API Key {api_key_id} 不存在", BizCode.API_KEY_NOT_FOUND) - stats_data = ApiKeyRepository.get_stats(db, api_key_id) return api_key_schema.ApiKeyStats(**stats_data) @@ -236,9 +226,6 @@ class ApiKeyService: # 验证 API Key 权限 api_key = ApiKeyService.get_api_key(db, api_key_id, workspace_id) - if not api_key: - raise BusinessException(f"API Key {api_key_id} 不存在", BizCode.API_KEY_NOT_FOUND) - items, total = ApiKeyLogRepository.list_by_api_key( db, api_key_id, filters, page, pagesize ) diff --git a/api/app/services/memory_storage_service.py b/api/app/services/memory_storage_service.py index 51ca9619..0548b704 100644 --- a/api/app/services/memory_storage_service.py +++ b/api/app/services/memory_storage_service.py @@ -4,9 +4,12 @@ Memory Storage Service Handles business logic for memory storage operations. """ -from typing import Dict, List, Optional, Any +from typing import Dict, List, Optional, Any, AsyncGenerator import os import json +import asyncio +import time +from datetime import datetime from sqlalchemy.orm import Session from dotenv import load_dotenv @@ -14,6 +17,7 @@ from dotenv import load_dotenv from app.models.user_model import User from app.models.end_user_model import EndUser from app.core.logging_config import get_logger +from app.utils.sse_utils import format_sse_message from app.schemas.memory_storage_schema import ( ConfigFilter, ConfigPilotRun, @@ -225,101 +229,175 @@ class DataConfigService: # 数据配置服务类(PostgreSQL) return self._convert_timestamps_to_format(data_list) - async def pilot_run(self, payload: ConfigPilotRun) -> Dict[str, Any]: + async def pilot_run_stream(self, payload: ConfigPilotRun) -> AsyncGenerator[str, None]: """ - 选择策略与内存覆写与同步版保持一致:优先 payload.config_id,其次 dbrun.json;两者皆无时报错。 - 支持 dialogue_text 参数用于试运行模式。 + 流式执行试运行,产生 SSE 格式的进度事件 + + Args: + payload: 试运行配置和对话文本 + + Yields: + SSE 格式的字符串,包含以下事件类型: + - 各种阶段名称: 进度更新 (如 starting, knowledge_extraction_complete 等) + - result: 最终结果 + - error: 错误信息 + - done: 完成标记 + + Raises: + ValueError: 当配置无效或参数缺失时 + RuntimeError: 当管线执行失败时 """ project_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) dbrun_path = os.path.join(project_root, "app", "core", "memory", "dbrun.json") + + try: + # 发出初始进度事件 + yield format_sse_message("starting", { + "message": "开始试运行...", + "time": int(time.time() * 1000) + }) + + # 步骤 1: 配置加载和验证(复用现有逻辑) + payload_cid = str(getattr(payload, "config_id", "") or "").strip() + cid: Optional[str] = payload_cid if payload_cid else None - payload_cid = str(getattr(payload, "config_id", "") or "").strip() - cid: Optional[str] = payload_cid if payload_cid else None + if not cid and os.path.isfile(dbrun_path): + try: + with open(dbrun_path, "r", encoding="utf-8") as f: + dbrun = json.load(f) + if isinstance(dbrun, dict): + sel = dbrun.get("selections", {}) + if isinstance(sel, dict): + fallback_cid = str(sel.get("config_id") or "").strip() + cid = fallback_cid or None + except Exception: + cid = None - if not cid and os.path.isfile(dbrun_path): - try: - with open(dbrun_path, "r", encoding="utf-8") as f: - dbrun = json.load(f) - if isinstance(dbrun, dict): - sel = dbrun.get("selections", {}) - if isinstance(sel, dict): - fallback_cid = str(sel.get("config_id") or "").strip() - cid = fallback_cid or None - except Exception: - cid = None + if not cid: + raise ValueError("未提供 payload.config_id,且 dbrun.json 未设置 selections.config_id,禁止启动试运行") - if not cid: - raise ValueError("未提供 payload.config_id,且 dbrun.json 未设置 selections.config_id,禁止启动试运行") + # 验证 dialogue_text 必须提供 + dialogue_text = payload.dialogue_text.strip() if payload.dialogue_text else "" + logger.info(f"[PILOT_RUN_STREAM] Received dialogue_text length: {len(dialogue_text)}, preview: {dialogue_text[:100]}") + if not dialogue_text: + raise ValueError("试运行模式必须提供 dialogue_text 参数") - # 验证 dialogue_text 必须提供 - dialogue_text = payload.dialogue_text.strip() if payload.dialogue_text else "" - logger.info(f"[PILOT_RUN] Received dialogue_text length: {len(dialogue_text)}, preview: {dialogue_text[:100]}") - if not dialogue_text: - raise ValueError("试运行模式必须提供 dialogue_text 参数") + # 应用内存覆写并刷新常量 + from app.core.memory.utils.config.definitions import reload_configuration_from_database + + ok_override = reload_configuration_from_database(cid) + if not ok_override: + raise RuntimeError("运行时覆写失败,config_id 无效或刷新常量失败") - # 应用内存覆写并刷新常量(在导入主管线前) - # 注意:仅在内存中覆写配置,不修改 runtime.json 文件 - from app.core.memory.utils.config.definitions import reload_configuration_from_database - - ok_override = reload_configuration_from_database(cid) - if not ok_override: - raise RuntimeError("运行时覆写失败,config_id 无效或刷新常量失败") - - # 导入并 await 主管线(使用当前 ASGI 事件循环) - from app.core.memory.main import main as pipeline_main - from app.core.memory.utils.self_reflexion_utils import reflexion - - logger.info(f"[PILOT_RUN] Calling pipeline_main with dialogue_text length: {len(dialogue_text)}, is_pilot_run=True") - await pipeline_main(dialogue_text=dialogue_text, is_pilot_run=True) - logger.info("[PILOT_RUN] pipeline_main completed") - - # 调用自我反思 - # data = [ - # { - # "data": { - # "id": "1", - # "statement": "张明现在在谷歌工作。", - # "group_id": "1", - # "chunk_id": "10", - # "created_at": "2023-01-01", - # "expired_at": "2023-01-02", - # "valid_at": "2023-01-01", - # "invalid_at": "2023-01-02", - # "entity_ids": [] - # }, - # "conflict": True, - # "conflict_memory": { - # "id": "1", - # "statement": "张明现在在清华大学当讲师。", - # "group_id": "1", - # "chunk_id": "1", - # "created_at": "2019-12-01T19:15:05.213210", - # "expired_at": None, - # "valid_at": None, - # "invalid_at": None, - # "entity_ids": [] - # } - # } - # ] - from app.core.memory.utils.config.get_example_data import get_example_data - data = get_example_data() - reflexion_result = await reflexion(data) - - # 读取输出,使用全局配置路径 - from app.core.config import settings - result_path = settings.get_memory_output_path("extracted_result.json") - if not os.path.isfile(result_path): - raise FileNotFoundError(f"试运行完成,但未找到提取结果文件: {result_path}") - - with open(result_path, "r", encoding="utf-8") as rf: - extracted_result = json.load(rf) - - extracted_result["self_reflexion"] = reflexion_result if reflexion_result else None - return { - "config_id": cid, - "time_log": os.path.join(project_root, "time.log"), - "extracted_result": extracted_result, - } + # 步骤 2: 创建进度回调函数捕获管线进度 + # 使用队列在回调和生成器之间传递进度事件 + progress_queue: asyncio.Queue = asyncio.Queue() + + async def progress_callback(stage: str, message: str, data: Optional[Dict[str, Any]] = None) -> None: + """ + 进度回调函数,将进度事件放入队列 + + Args: + stage: 阶段标识 + message: 进度消息 + data: 可选的结果数据(用于传递节点执行结果) + """ + await progress_queue.put((stage, message, data)) + + # 步骤 3: 在后台任务中执行管线 + async def run_pipeline(): + """在后台执行管线并捕获异常""" + try: + from app.core.memory.main import main as pipeline_main + + logger.info(f"[PILOT_RUN_STREAM] Calling pipeline_main with dialogue_text length: {len(dialogue_text)}, is_pilot_run=True") + await pipeline_main( + dialogue_text=dialogue_text, + is_pilot_run=True, + progress_callback=progress_callback + ) + logger.info("[PILOT_RUN_STREAM] pipeline_main completed") + + # 标记管线完成 + await progress_queue.put(("__PIPELINE_COMPLETE__", "", None)) + except Exception as e: + # 将异常放入队列 + await progress_queue.put(("__PIPELINE_ERROR__", str(e), None)) + + # 启动后台任务 + pipeline_task = asyncio.create_task(run_pipeline()) + + # 步骤 4: 从队列中读取进度事件并发出 + while True: + try: + # 等待进度事件,设置超时以检测客户端断开 + stage, message, data = await asyncio.wait_for( + progress_queue.get(), + timeout=0.5 + ) + + # 检查特殊标记 + if stage == "__PIPELINE_COMPLETE__": + break + elif stage == "__PIPELINE_ERROR__": + raise RuntimeError(message) + + # 构建进度事件数据 + progress_data = { + "message": message, + "time": int(time.time() * 1000) + } + + # 如果有结果数据,添加到事件中 + if data: + progress_data["data"] = data + + # 发出进度事件,使用 stage 作为事件类型 + yield format_sse_message(stage, progress_data) + + except TimeoutError: + # 超时,继续等待(这允许检测客户端断开) + continue + + # 等待管线任务完成 + await pipeline_task + + # 步骤 5: 读取提取结果 + from app.core.config import settings + result_path = settings.get_memory_output_path("extracted_result.json") + if not os.path.isfile(result_path): + raise FileNotFoundError(f"试运行完成,但未找到提取结果文件: {result_path}") + + with open(result_path, "r", encoding="utf-8") as rf: + extracted_result = json.load(rf) + + # 步骤 6: 发出结果事件 + result_data = { + "config_id": cid, + "time_log": os.path.join(project_root, "logs", "time.log"), + "extracted_result": extracted_result, + } + yield format_sse_message("result", result_data) + + # 步骤 7: 发出完成事件 + yield format_sse_message("done", { + "message": "试运行完成", + "time": int(time.time() * 1000) + }) + + except asyncio.CancelledError: + # 客户端断开连接 + logger.info("[PILOT_RUN_STREAM] Client disconnected during streaming") + raise + except Exception as e: + # 发出错误事件 + logger.error(f"[PILOT_RUN_STREAM] Error during streaming: {e}", exc_info=True) + yield format_sse_message("error", { + "code": 5000, + "message": "试运行失败", + "error": str(e), + "time": int(time.time() * 1000) + }) # -------------------- Neo4j Search & Analytics (fused from data_search_service.py) -------------------- diff --git a/api/app/utils/sse_utils.py b/api/app/utils/sse_utils.py new file mode 100644 index 00000000..43444f27 --- /dev/null +++ b/api/app/utils/sse_utils.py @@ -0,0 +1,27 @@ +""" +Server-Sent Events (SSE) Utility Functions + +Provides shared utilities for formatting and handling SSE messages. +""" + +import json +from typing import Dict, Any + + +def format_sse_message(event_type: str, data: Dict[str, Any]) -> str: + """ + Format a message in Server-Sent Events (SSE) format. + + Args: + event_type: Type of event (stage name, result, error, done) + data: Event data dictionary to be serialized as JSON + + Returns: + SSE formatted string: "event: \\ndata: \\n\\n" + + Example: + >>> format_sse_message("loading", {"message": "Loading..."}) + 'event: loading\\ndata: {"message": "Loading..."}\\n\\n' + """ + json_data = json.dumps(data, ensure_ascii=False) + return f"event: {event_type}\ndata: {json_data}\n\n"