[MODIFY] MEM SEE OUTPUT
This commit is contained in:
@@ -1,8 +1,9 @@
|
|||||||
from typing import Optional
|
from typing import Optional, Union
|
||||||
import os
|
import os
|
||||||
import uuid
|
import uuid
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
from fastapi import APIRouter, Depends, Query, UploadFile
|
from fastapi import APIRouter, Depends, UploadFile
|
||||||
|
from fastapi.responses import StreamingResponse
|
||||||
|
|
||||||
|
|
||||||
from app.db import get_db
|
from app.db import get_db
|
||||||
@@ -322,36 +323,24 @@ def read_all_config(
|
|||||||
return fail(BizCode.INTERNAL_ERROR, "查询所有配置失败", str(e))
|
return fail(BizCode.INTERNAL_ERROR, "查询所有配置失败", str(e))
|
||||||
|
|
||||||
|
|
||||||
@router.post("/pilot_run", response_model=ApiResponse) # 试运行:触发执行主管线,使用 POST 更为合理
|
@router.post("/pilot_run", response_model=None)
|
||||||
async def pilot_run(
|
async def pilot_run(
|
||||||
payload: ConfigPilotRun,
|
payload: ConfigPilotRun,
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
) -> dict:
|
) -> StreamingResponse:
|
||||||
api_logger.info(f"Pilot run requested: config_id={payload.config_id}, dialogue_text_length={len(payload.dialogue_text)}")
|
api_logger.info(f"Pilot run requested: config_id={payload.config_id}, dialogue_text_length={len(payload.dialogue_text)}")
|
||||||
|
|
||||||
# 先尝试从数据库加载配置
|
svc = DataConfigService(db)
|
||||||
try:
|
return StreamingResponse(
|
||||||
config_loaded = reload_configuration_from_database(str(payload.config_id))
|
svc.pilot_run_stream(payload),
|
||||||
if not config_loaded:
|
media_type="text/event-stream",
|
||||||
api_logger.error(f"Failed to load configuration for config_id: {payload.config_id}")
|
headers={
|
||||||
return fail(BizCode.INTERNAL_ERROR, "配置加载失败", f"无法加载 config_id={payload.config_id} 的配置")
|
"Cache-Control": "no-cache",
|
||||||
api_logger.info(f"Configuration loaded successfully for config_id: {payload.config_id}")
|
"Connection": "keep-alive",
|
||||||
except Exception as e:
|
"X-Accel-Buffering": "no"
|
||||||
api_logger.error(f"Exception while loading configuration: {str(e)}")
|
}
|
||||||
return fail(BizCode.INTERNAL_ERROR, "配置加载异常", str(e))
|
)
|
||||||
|
|
||||||
try:
|
|
||||||
svc = DataConfigService(db)
|
|
||||||
result = await svc.pilot_run(payload)
|
|
||||||
return success(data=result, msg="试运行完成")
|
|
||||||
except ValueError as e:
|
|
||||||
# 捕获参数验证错误
|
|
||||||
api_logger.error(f"Pilot run parameter validation failed: {str(e)}")
|
|
||||||
return fail(BizCode.INVALID_PARAMETER, "参数验证失败", str(e))
|
|
||||||
except Exception as e:
|
|
||||||
api_logger.error(f"Pilot run failed: {str(e)}")
|
|
||||||
return fail(BizCode.INTERNAL_ERROR, "试运行失败", str(e))
|
|
||||||
|
|
||||||
"""
|
"""
|
||||||
以下为搜索与分析接口,直接挂载到同一 router,统一响应为 ApiResponse。
|
以下为搜索与分析接口,直接挂载到同一 router,统一响应为 ApiResponse。
|
||||||
|
|||||||
@@ -21,7 +21,7 @@ def generate_api_key(key_type: ApiKeyType) -> str:
|
|||||||
# 前缀映射
|
# 前缀映射
|
||||||
prefix_map = {
|
prefix_map = {
|
||||||
ApiKeyType.AGENT: "sk-agent-",
|
ApiKeyType.AGENT: "sk-agent-",
|
||||||
ApiKeyType.CLUSTER: "sk-cluster-",
|
ApiKeyType.CLUSTER: "sk-multi_agent-",
|
||||||
ApiKeyType.WORKFLOW: "sk-workflow-",
|
ApiKeyType.WORKFLOW: "sk-workflow-",
|
||||||
ApiKeyType.SERVICE: "sk-service-"
|
ApiKeyType.SERVICE: "sk-service-"
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -106,6 +106,8 @@ async def write(content: str, user_id: str, apply_id: str, group_id: str, ref_id
|
|||||||
all_statement_chunk_edges,
|
all_statement_chunk_edges,
|
||||||
all_statement_entity_edges,
|
all_statement_entity_edges,
|
||||||
all_entity_entity_edges,
|
all_entity_entity_edges,
|
||||||
|
all_dedup_details,
|
||||||
|
|
||||||
) = await orchestrator.run(chunked_dialogs, is_pilot_run=False)
|
) = await orchestrator.run(chunked_dialogs, is_pilot_run=False)
|
||||||
|
|
||||||
log_time("Extraction Pipeline", time.time() - step_start, log_file)
|
log_time("Extraction Pipeline", time.time() - step_start, log_file)
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
{
|
{
|
||||||
"selections": {
|
"selections": {
|
||||||
"config_id": "1"
|
"config_id": ""
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -21,7 +21,7 @@ os.environ["LANGCHAIN_TRACING"] = "false"
|
|||||||
import asyncio
|
import asyncio
|
||||||
import time
|
import time
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Optional
|
from typing import Optional, Callable, Awaitable
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
# 导入重构后的模块
|
# 导入重构后的模块
|
||||||
@@ -50,7 +50,11 @@ logger = get_memory_logger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
async def main(dialogue_text: Optional[str] = None, is_pilot_run: bool = False):
|
async def main(
|
||||||
|
dialogue_text: Optional[str] = None,
|
||||||
|
is_pilot_run: bool = False,
|
||||||
|
progress_callback: Optional[Callable[[str, str, Optional[dict]], Awaitable[None]]] = None
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
记忆系统主流程 - 重构版本
|
记忆系统主流程 - 重构版本
|
||||||
|
|
||||||
@@ -61,6 +65,12 @@ async def main(dialogue_text: Optional[str] = None, is_pilot_run: bool = False):
|
|||||||
is_pilot_run: 是否为试运行模式
|
is_pilot_run: 是否为试运行模式
|
||||||
- True: 试运行模式,不保存到 Neo4j
|
- True: 试运行模式,不保存到 Neo4j
|
||||||
- False: 正常运行模式,保存到 Neo4j
|
- False: 正常运行模式,保存到 Neo4j
|
||||||
|
progress_callback: 可选的进度回调函数
|
||||||
|
- 类型: Callable[[str, str, Optional[dict]], Awaitable[None]]
|
||||||
|
- 参数1 (stage): 当前处理阶段标识符
|
||||||
|
- 参数2 (message): 人类可读的进度消息
|
||||||
|
- 参数3 (data): 可选的附加数据字典,包含详细的进度信息或结果
|
||||||
|
- 在管线关键点调用以报告进度和结果数据
|
||||||
|
|
||||||
工作流程:
|
工作流程:
|
||||||
1. 初始化客户端和配置
|
1. 初始化客户端和配置
|
||||||
@@ -141,6 +151,10 @@ async def main(dialogue_text: Optional[str] = None, is_pilot_run: bool = False):
|
|||||||
metadata={"source": "pilot_run", "input_type": "frontend_text"}
|
metadata={"source": "pilot_run", "input_type": "frontend_text"}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# 进度回调:开始预处理文本
|
||||||
|
if progress_callback:
|
||||||
|
await progress_callback("text_preprocessing", "开始预处理文本...")
|
||||||
|
|
||||||
# 对前端传入的对话进行分块处理
|
# 对前端传入的对话进行分块处理
|
||||||
chunked_dialogs = await get_chunked_dialogs_from_preprocessed(
|
chunked_dialogs = await get_chunked_dialogs_from_preprocessed(
|
||||||
data=[dialog],
|
data=[dialog],
|
||||||
@@ -148,6 +162,27 @@ async def main(dialogue_text: Optional[str] = None, is_pilot_run: bool = False):
|
|||||||
llm_client=llm_client,
|
llm_client=llm_client,
|
||||||
)
|
)
|
||||||
logger.info(f"Processed frontend dialogue text: {len(messages)} messages")
|
logger.info(f"Processed frontend dialogue text: {len(messages)} messages")
|
||||||
|
|
||||||
|
# 进度回调:输出每个分块的结果
|
||||||
|
if progress_callback:
|
||||||
|
for dialog in chunked_dialogs:
|
||||||
|
for i, chunk in enumerate(dialog.chunks):
|
||||||
|
chunk_result = {
|
||||||
|
"chunk_index": i + 1,
|
||||||
|
"content": chunk.content[:200] + "..." if len(chunk.content) > 200 else chunk.content,
|
||||||
|
"full_length": len(chunk.content),
|
||||||
|
"dialog_id": dialog.id,
|
||||||
|
"chunker_strategy": config_defs.SELECTED_CHUNKER_STRATEGY
|
||||||
|
}
|
||||||
|
await progress_callback("text_preprocessing_result", f"分块 {i + 1} 处理完成", chunk_result)
|
||||||
|
|
||||||
|
# 进度回调:预处理文本完成
|
||||||
|
preprocessing_summary = {
|
||||||
|
"total_chunks": sum(len(dialog.chunks) for dialog in chunked_dialogs),
|
||||||
|
"total_dialogs": len(chunked_dialogs),
|
||||||
|
"chunker_strategy": config_defs.SELECTED_CHUNKER_STRATEGY
|
||||||
|
}
|
||||||
|
await progress_callback("text_preprocessing_complete", "预处理文本完成", preprocessing_summary)
|
||||||
else:
|
else:
|
||||||
# 正常运行模式:从 testdata.json 文件加载
|
# 正常运行模式:从 testdata.json 文件加载
|
||||||
logger.warning("[MAIN] ✗ Falling back to testdata.json (dialogue_text not provided or empty)")
|
logger.warning("[MAIN] ✗ Falling back to testdata.json (dialogue_text not provided or empty)")
|
||||||
@@ -159,6 +194,10 @@ async def main(dialogue_text: Optional[str] = None, is_pilot_run: bool = False):
|
|||||||
if not os.path.exists(test_data_path):
|
if not os.path.exists(test_data_path):
|
||||||
raise FileNotFoundError(f"Test data file not found: {test_data_path}")
|
raise FileNotFoundError(f"Test data file not found: {test_data_path}")
|
||||||
|
|
||||||
|
# 进度回调:开始预处理文本
|
||||||
|
if progress_callback:
|
||||||
|
await progress_callback("text_preprocessing", "开始预处理文本...")
|
||||||
|
|
||||||
chunked_dialogs = await get_chunked_dialogs_with_preprocessing(
|
chunked_dialogs = await get_chunked_dialogs_with_preprocessing(
|
||||||
chunker_strategy=config_defs.SELECTED_CHUNKER_STRATEGY,
|
chunker_strategy=config_defs.SELECTED_CHUNKER_STRATEGY,
|
||||||
group_id=config_defs.SELECTED_GROUP_ID,
|
group_id=config_defs.SELECTED_GROUP_ID,
|
||||||
@@ -171,6 +210,27 @@ async def main(dialogue_text: Optional[str] = None, is_pilot_run: bool = False):
|
|||||||
)
|
)
|
||||||
logger.info(f"Loaded {len(chunked_dialogs)} dialogues from testdata.json")
|
logger.info(f"Loaded {len(chunked_dialogs)} dialogues from testdata.json")
|
||||||
|
|
||||||
|
# 进度回调:输出每个分块的结果
|
||||||
|
if progress_callback:
|
||||||
|
for dialog in chunked_dialogs:
|
||||||
|
for i, chunk in enumerate(dialog.chunks):
|
||||||
|
chunk_result = {
|
||||||
|
"chunk_index": i + 1,
|
||||||
|
"content": chunk.content[:200] + "..." if len(chunk.content) > 200 else chunk.content,
|
||||||
|
"full_length": len(chunk.content),
|
||||||
|
"dialog_id": dialog.id,
|
||||||
|
"chunker_strategy": config_defs.SELECTED_CHUNKER_STRATEGY
|
||||||
|
}
|
||||||
|
await progress_callback("text_preprocessing_result", f"分块 {i + 1} 处理完成", chunk_result)
|
||||||
|
|
||||||
|
# 进度回调:预处理文本完成
|
||||||
|
preprocessing_summary = {
|
||||||
|
"total_chunks": sum(len(dialog.chunks) for dialog in chunked_dialogs),
|
||||||
|
"total_dialogs": len(chunked_dialogs),
|
||||||
|
"chunker_strategy": config_defs.SELECTED_CHUNKER_STRATEGY
|
||||||
|
}
|
||||||
|
await progress_callback("text_preprocessing_complete", "预处理文本完成", preprocessing_summary)
|
||||||
|
|
||||||
log_time("Data Loading & Chunking", time.time() - step_start, log_file)
|
log_time("Data Loading & Chunking", time.time() - step_start, log_file)
|
||||||
|
|
||||||
# 步骤 3: 初始化流水线编排器
|
# 步骤 3: 初始化流水线编排器
|
||||||
@@ -188,6 +248,7 @@ async def main(dialogue_text: Optional[str] = None, is_pilot_run: bool = False):
|
|||||||
embedder_client=embedder_client,
|
embedder_client=embedder_client,
|
||||||
connector=neo4j_connector,
|
connector=neo4j_connector,
|
||||||
config=config,
|
config=config,
|
||||||
|
progress_callback=progress_callback, # 传递进度回调
|
||||||
)
|
)
|
||||||
|
|
||||||
log_time("Orchestrator Initialization", time.time() - step_start, log_file)
|
log_time("Orchestrator Initialization", time.time() - step_start, log_file)
|
||||||
@@ -196,6 +257,11 @@ async def main(dialogue_text: Optional[str] = None, is_pilot_run: bool = False):
|
|||||||
logger.info("Running extraction pipeline...")
|
logger.info("Running extraction pipeline...")
|
||||||
step_start = time.time()
|
step_start = time.time()
|
||||||
|
|
||||||
|
|
||||||
|
# 进度回调:正在知识抽取
|
||||||
|
if progress_callback:
|
||||||
|
await progress_callback("knowledge_extraction", "正在知识抽取...")
|
||||||
|
|
||||||
extraction_result = await orchestrator.run(
|
extraction_result = await orchestrator.run(
|
||||||
dialog_data_list=chunked_dialogs,
|
dialog_data_list=chunked_dialogs,
|
||||||
is_pilot_run=is_pilot_run, # 传递试运行模式标志
|
is_pilot_run=is_pilot_run, # 传递试运行模式标志
|
||||||
@@ -217,6 +283,11 @@ async def main(dialogue_text: Optional[str] = None, is_pilot_run: bool = False):
|
|||||||
|
|
||||||
log_time("Extraction Pipeline", time.time() - step_start, log_file)
|
log_time("Extraction Pipeline", time.time() - step_start, log_file)
|
||||||
|
|
||||||
|
# 进度回调:生成结果
|
||||||
|
if progress_callback:
|
||||||
|
await progress_callback("generating_results", "正在生成结果...")
|
||||||
|
|
||||||
|
|
||||||
# 步骤 5: 保存结果或输出结果
|
# 步骤 5: 保存结果或输出结果
|
||||||
if is_pilot_run:
|
if is_pilot_run:
|
||||||
logger.info("Pilot run mode: Skipping Neo4j save")
|
logger.info("Pilot run mode: Skipping Neo4j save")
|
||||||
|
|||||||
@@ -2,7 +2,7 @@
|
|||||||
去重功能函数
|
去重功能函数
|
||||||
"""
|
"""
|
||||||
from app.core.memory.models.variate_config import DedupConfig
|
from app.core.memory.models.variate_config import DedupConfig
|
||||||
from typing import List, Dict, Tuple
|
from typing import List, Dict, Tuple, Any
|
||||||
from app.core.memory.models.graph_models import(
|
from app.core.memory.models.graph_models import(
|
||||||
StatementEntityEdge,
|
StatementEntityEdge,
|
||||||
EntityEntityEdge,
|
EntityEntityEdge,
|
||||||
@@ -895,7 +895,12 @@ async def deduplicate_entities_and_edges(
|
|||||||
report_append: bool = False,
|
report_append: bool = False,
|
||||||
report_stage_notes: List[str] | None = None,
|
report_stage_notes: List[str] | None = None,
|
||||||
dedup_config: DedupConfig | None = None,
|
dedup_config: DedupConfig | None = None,
|
||||||
) -> Tuple[List[ExtractedEntityNode], List[StatementEntityEdge], List[EntityEntityEdge]]:
|
) -> Tuple[
|
||||||
|
List[ExtractedEntityNode],
|
||||||
|
List[StatementEntityEdge],
|
||||||
|
List[EntityEntityEdge],
|
||||||
|
Dict[str, Any] # 新增:返回详细的去重消歧记录
|
||||||
|
]:
|
||||||
"""
|
"""
|
||||||
主流程:依次执行精确匹配、模糊匹配与(可选)LLM 决策融合,随后对边做重定向与去重。之后再处理边,是关系去重和消歧
|
主流程:依次执行精确匹配、模糊匹配与(可选)LLM 决策融合,随后对边做重定向与去重。之后再处理边,是关系去重和消歧
|
||||||
返回:去重后的实体、语句→实体边、实体↔实体边。
|
返回:去重后的实体、语句→实体边、实体↔实体边。
|
||||||
@@ -982,7 +987,17 @@ async def deduplicate_entities_and_edges(
|
|||||||
stage_notes=report_stage_notes,
|
stage_notes=report_stage_notes,
|
||||||
)
|
)
|
||||||
|
|
||||||
return deduped_entities, list(stmt_ent_map.values()), list(ent_ent_map.values())
|
# 构建详细的去重消歧记录(用于内存访问,避免解析日志文件)
|
||||||
|
dedup_details = {
|
||||||
|
"exact_merge_map": exact_merge_map,
|
||||||
|
"fuzzy_merge_records": fuzzy_merge_records,
|
||||||
|
"llm_decision_records": local_llm_records,
|
||||||
|
"disamb_records": disamb_records,
|
||||||
|
"id_redirect": id_redirect,
|
||||||
|
"blocked_pairs": blocked_pairs,
|
||||||
|
}
|
||||||
|
|
||||||
|
return deduped_entities, list(stmt_ent_map.values()), list(ent_ent_map.values()), dedup_details
|
||||||
|
|
||||||
# 独立模块:去重融合报告写入(与实体/边的计算解耦)
|
# 独立模块:去重融合报告写入(与实体/边的计算解耦)
|
||||||
def _write_dedup_fusion_report(
|
def _write_dedup_fusion_report(
|
||||||
|
|||||||
@@ -39,6 +39,7 @@ async def dedup_layers_and_merge_and_return(
|
|||||||
List[StatementChunkEdge],
|
List[StatementChunkEdge],
|
||||||
List[StatementEntityEdge],
|
List[StatementEntityEdge],
|
||||||
List[EntityEntityEdge],
|
List[EntityEntityEdge],
|
||||||
|
dict, # 新增:返回去重详情
|
||||||
]:
|
]:
|
||||||
"""
|
"""
|
||||||
执行两层实体去重与融合:
|
执行两层实体去重与融合:
|
||||||
@@ -62,7 +63,7 @@ async def dedup_layers_and_merge_and_return(
|
|||||||
break
|
break
|
||||||
|
|
||||||
# 第一层去重消歧
|
# 第一层去重消歧
|
||||||
dedup_entity_nodes, dedup_statement_entity_edges, dedup_entity_entity_edges = await deduplicate_entities_and_edges(
|
dedup_entity_nodes, dedup_statement_entity_edges, dedup_entity_entity_edges, dedup_details = await deduplicate_entities_and_edges(
|
||||||
entity_nodes,
|
entity_nodes,
|
||||||
statement_entity_edges,
|
statement_entity_edges,
|
||||||
entity_entity_edges,
|
entity_entity_edges,
|
||||||
@@ -103,4 +104,5 @@ async def dedup_layers_and_merge_and_return(
|
|||||||
statement_chunk_edges,
|
statement_chunk_edges,
|
||||||
fused_statement_entity_edges,
|
fused_statement_entity_edges,
|
||||||
fused_entity_entity_edges,
|
fused_entity_entity_edges,
|
||||||
|
dedup_details, # 返回去重详情
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -12,13 +12,14 @@
|
|||||||
5. 提供错误处理和日志记录
|
5. 提供错误处理和日志记录
|
||||||
6. 支持试运行模式(不写入数据库)
|
6. 支持试运行模式(不写入数据库)
|
||||||
|
|
||||||
作者:Memory Refactoring Team
|
作者:
|
||||||
日期:2025-11-21
|
日期:2025-11-21
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
from typing import List, Dict, Any, Tuple, Optional
|
import os
|
||||||
|
from typing import List, Dict, Any, Tuple, Optional, Callable, Awaitable
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
from app.core.memory.models.message_models import DialogData
|
from app.core.memory.models.message_models import DialogData
|
||||||
@@ -94,6 +95,7 @@ class ExtractionOrchestrator:
|
|||||||
embedder_client: OpenAIEmbedderClient,
|
embedder_client: OpenAIEmbedderClient,
|
||||||
connector: Neo4jConnector,
|
connector: Neo4jConnector,
|
||||||
config: Optional[ExtractionPipelineConfig] = None,
|
config: Optional[ExtractionPipelineConfig] = None,
|
||||||
|
progress_callback: Optional[Callable[[str, str, Optional[Dict[str, Any]]], Awaitable[None]]] = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
初始化流水线编排器
|
初始化流水线编排器
|
||||||
@@ -103,12 +105,21 @@ class ExtractionOrchestrator:
|
|||||||
embedder_client: 嵌入模型客户端
|
embedder_client: 嵌入模型客户端
|
||||||
connector: Neo4j 连接器
|
connector: Neo4j 连接器
|
||||||
config: 流水线配置,如果为 None 则使用默认配置
|
config: 流水线配置,如果为 None 则使用默认配置
|
||||||
|
progress_callback: 进度回调函数
|
||||||
|
- 接受 (stage: str, message: str, data: Optional[Dict[str, Any]]) 并返回 Awaitable[None]
|
||||||
|
- 在管线关键点调用以报告进度和结果数据
|
||||||
"""
|
"""
|
||||||
self.llm_client = llm_client
|
self.llm_client = llm_client
|
||||||
self.embedder_client = embedder_client
|
self.embedder_client = embedder_client
|
||||||
self.connector = connector
|
self.connector = connector
|
||||||
self.config = config or ExtractionPipelineConfig()
|
self.config = config or ExtractionPipelineConfig()
|
||||||
self.is_pilot_run = False # 默认非试运行模式
|
self.is_pilot_run = False # 默认非试运行模式
|
||||||
|
self.progress_callback = progress_callback # 保存进度回调函数
|
||||||
|
|
||||||
|
# 保存去重消歧的详细记录(内存中的数据结构)
|
||||||
|
self.dedup_merge_records: List[Dict[str, Any]] = [] # 实体合并记录
|
||||||
|
self.dedup_disamb_records: List[Dict[str, Any]] = [] # 实体消歧记录
|
||||||
|
self.id_redirect_map: Dict[str, str] = {} # ID重定向映射
|
||||||
|
|
||||||
# 初始化各个提取器
|
# 初始化各个提取器
|
||||||
self.statement_extractor = StatementExtractor(
|
self.statement_extractor = StatementExtractor(
|
||||||
@@ -161,6 +172,13 @@ class ExtractionOrchestrator:
|
|||||||
logger.info("步骤 1/6: 陈述句提取(全局分块级并行)")
|
logger.info("步骤 1/6: 陈述句提取(全局分块级并行)")
|
||||||
dialog_data_list = await self._extract_statements(dialog_data_list)
|
dialog_data_list = await self._extract_statements(dialog_data_list)
|
||||||
|
|
||||||
|
# 收集陈述句内容和统计数量
|
||||||
|
all_statements_list = []
|
||||||
|
for dialog in dialog_data_list:
|
||||||
|
for chunk in dialog.chunks:
|
||||||
|
all_statements_list.extend(chunk.statements)
|
||||||
|
total_statements = len(all_statements_list)
|
||||||
|
|
||||||
# 步骤 2: 并行执行三元组提取、时间信息提取和基础嵌入生成
|
# 步骤 2: 并行执行三元组提取、时间信息提取和基础嵌入生成
|
||||||
logger.info("步骤 2/6: 并行执行三元组提取、时间信息提取和嵌入生成")
|
logger.info("步骤 2/6: 并行执行三元组提取、时间信息提取和嵌入生成")
|
||||||
(
|
(
|
||||||
@@ -171,10 +189,89 @@ class ExtractionOrchestrator:
|
|||||||
dialog_embeddings,
|
dialog_embeddings,
|
||||||
) = await self._parallel_extract_and_embed(dialog_data_list)
|
) = await self._parallel_extract_and_embed(dialog_data_list)
|
||||||
|
|
||||||
|
# 收集实体和三元组内容,并统计数量
|
||||||
|
all_entities_list = []
|
||||||
|
all_triplets_list = []
|
||||||
|
for triplet_map in triplet_maps:
|
||||||
|
for triplet_info in triplet_map.values():
|
||||||
|
if triplet_info:
|
||||||
|
all_entities_list.extend(triplet_info.entities)
|
||||||
|
all_triplets_list.extend(triplet_info.triplets)
|
||||||
|
|
||||||
|
total_entities = len(all_entities_list)
|
||||||
|
total_triplets = len(all_triplets_list)
|
||||||
|
total_temporal = sum(len(temporal_map) for temporal_map in temporal_maps)
|
||||||
|
|
||||||
# 步骤 3: 生成实体嵌入(依赖三元组提取结果)
|
# 步骤 3: 生成实体嵌入(依赖三元组提取结果)
|
||||||
logger.info("步骤 3/6: 生成实体嵌入")
|
logger.info("步骤 3/6: 生成实体嵌入")
|
||||||
triplet_maps = await self._generate_entity_embeddings(triplet_maps)
|
triplet_maps = await self._generate_entity_embeddings(triplet_maps)
|
||||||
|
|
||||||
|
# 进度回调:按三个阶段分别输出知识抽取结果
|
||||||
|
if self.progress_callback:
|
||||||
|
# 第一阶段:陈述句提取结果
|
||||||
|
for i, stmt in enumerate(all_statements_list[:10]): # 只输出前10个陈述句
|
||||||
|
stmt_result = {
|
||||||
|
"extraction_type": "statement",
|
||||||
|
"statement_index": i + 1,
|
||||||
|
"statement": stmt.statement,
|
||||||
|
"statement_id": stmt.id
|
||||||
|
}
|
||||||
|
await self.progress_callback("knowledge_extraction_result", "陈述句提取完成", stmt_result)
|
||||||
|
|
||||||
|
# 第二阶段:三元组提取结果
|
||||||
|
for i, triplet in enumerate(all_triplets_list[:10]): # 只输出前10个三元组
|
||||||
|
triplet_result = {
|
||||||
|
"extraction_type": "triplet",
|
||||||
|
"triplet_index": i + 1,
|
||||||
|
"subject": triplet.subject_name,
|
||||||
|
"predicate": triplet.predicate,
|
||||||
|
"object": triplet.object_name
|
||||||
|
}
|
||||||
|
await self.progress_callback("knowledge_extraction_result", "三元组提取完成", triplet_result)
|
||||||
|
|
||||||
|
# 第三阶段:时间提取结果
|
||||||
|
if total_temporal > 0:
|
||||||
|
# 收集时间信息
|
||||||
|
temporal_results = []
|
||||||
|
for dialog in dialog_data_list:
|
||||||
|
for chunk in dialog.chunks:
|
||||||
|
for statement in chunk.statements:
|
||||||
|
if hasattr(statement, 'temporal_validity') and statement.temporal_validity:
|
||||||
|
temporal_results.append({
|
||||||
|
"statement_id": statement.id,
|
||||||
|
"statement": statement.statement,
|
||||||
|
"valid_at": statement.temporal_validity.valid_at,
|
||||||
|
"invalid_at": statement.temporal_validity.invalid_at
|
||||||
|
})
|
||||||
|
|
||||||
|
# 输出时间提取结果
|
||||||
|
for i, temporal_result in enumerate(temporal_results[:5]): # 只输出前5个时间提取结果
|
||||||
|
time_result = {
|
||||||
|
"extraction_type": "temporal",
|
||||||
|
"temporal_index": i + 1,
|
||||||
|
"statement": temporal_result["statement"],
|
||||||
|
"valid_at": temporal_result["valid_at"],
|
||||||
|
"invalid_at": temporal_result["invalid_at"]
|
||||||
|
}
|
||||||
|
await self.progress_callback("knowledge_extraction_result", "时间提取完成", time_result)
|
||||||
|
else:
|
||||||
|
# 如果没有时间信息,也发送一个时间提取完成的消息
|
||||||
|
time_result = {
|
||||||
|
"extraction_type": "temporal",
|
||||||
|
"temporal_index": 0,
|
||||||
|
"message": "未发现时间信息"
|
||||||
|
}
|
||||||
|
await self.progress_callback("knowledge_extraction_result", "时间提取完成", time_result)
|
||||||
|
|
||||||
|
# 进度回调:知识抽取完成,传递知识抽取的统计信息
|
||||||
|
extraction_stats = {
|
||||||
|
"statements_count": total_statements,
|
||||||
|
"entities_count": total_entities,
|
||||||
|
"triplets_count": total_triplets,
|
||||||
|
"temporal_ranges_count": total_temporal,
|
||||||
|
}
|
||||||
|
await self.progress_callback("knowledge_extraction_complete", "知识抽取完成", extraction_stats)
|
||||||
|
|
||||||
# 步骤 4: 将提取的数据赋值到语句
|
# 步骤 4: 将提取的数据赋值到语句
|
||||||
logger.info("步骤 4/6: 数据赋值")
|
logger.info("步骤 4/6: 数据赋值")
|
||||||
dialog_data_list = await self._assign_extracted_data(
|
dialog_data_list = await self._assign_extracted_data(
|
||||||
@@ -218,6 +315,8 @@ class ExtractionOrchestrator:
|
|||||||
dialog_data_list,
|
dialog_data_list,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
logger.info(f"知识提取流水线运行完成({mode_str})")
|
logger.info(f"知识提取流水线运行完成({mode_str})")
|
||||||
return result
|
return result
|
||||||
|
|
||||||
@@ -733,6 +832,10 @@ class ExtractionOrchestrator:
|
|||||||
"""
|
"""
|
||||||
logger.info("开始创建节点和边")
|
logger.info("开始创建节点和边")
|
||||||
|
|
||||||
|
# 进度回调:正在创建节点和边
|
||||||
|
if self.progress_callback:
|
||||||
|
await self.progress_callback("creating_nodes_edges", "正在创建节点和边...")
|
||||||
|
|
||||||
dialogue_nodes = []
|
dialogue_nodes = []
|
||||||
chunk_nodes = []
|
chunk_nodes = []
|
||||||
statement_nodes = []
|
statement_nodes = []
|
||||||
@@ -905,6 +1008,23 @@ class ExtractionOrchestrator:
|
|||||||
f"实体-实体边: {len(entity_entity_edges)}"
|
f"实体-实体边: {len(entity_entity_edges)}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# 进度回调:只输出关系创建结果
|
||||||
|
if self.progress_callback:
|
||||||
|
# 输出关系创建结果
|
||||||
|
await self._output_relationship_creation_results(entity_entity_edges, entity_nodes)
|
||||||
|
|
||||||
|
# 进度回调:创建节点和边完成,传递结果统计
|
||||||
|
nodes_edges_stats = {
|
||||||
|
"dialogue_nodes_count": len(dialogue_nodes),
|
||||||
|
"chunk_nodes_count": len(chunk_nodes),
|
||||||
|
"statement_nodes_count": len(statement_nodes),
|
||||||
|
"entity_nodes_count": len(entity_nodes),
|
||||||
|
"statement_chunk_edges_count": len(statement_chunk_edges),
|
||||||
|
"statement_entity_edges_count": len(statement_entity_edges),
|
||||||
|
"entity_entity_edges_count": len(entity_entity_edges),
|
||||||
|
}
|
||||||
|
await self.progress_callback("creating_nodes_edges_complete", "创建节点和边完成", nodes_edges_stats)
|
||||||
|
|
||||||
return (
|
return (
|
||||||
dialogue_nodes,
|
dialogue_nodes,
|
||||||
chunk_nodes,
|
chunk_nodes,
|
||||||
@@ -950,6 +1070,11 @@ class ExtractionOrchestrator:
|
|||||||
- 第三个元组:去重后的 (实体节点列表, 陈述句-实体边列表, 实体-实体边列表)
|
- 第三个元组:去重后的 (实体节点列表, 陈述句-实体边列表, 实体-实体边列表)
|
||||||
"""
|
"""
|
||||||
logger.info("开始两阶段实体去重和消歧")
|
logger.info("开始两阶段实体去重和消歧")
|
||||||
|
|
||||||
|
# 进度回调:正在去重消歧
|
||||||
|
if self.progress_callback:
|
||||||
|
await self.progress_callback("deduplication", "正在去重消歧...")
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"去重前: {len(entity_nodes)} 个实体节点, "
|
f"去重前: {len(entity_nodes)} 个实体节点, "
|
||||||
f"{len(statement_entity_edges)} 条陈述句-实体边, "
|
f"{len(statement_entity_edges)} 条陈述句-实体边, "
|
||||||
@@ -963,7 +1088,7 @@ class ExtractionOrchestrator:
|
|||||||
# 只执行第一层去重
|
# 只执行第一层去重
|
||||||
from app.core.memory.storage_services.extraction_engine.deduplication.deduped_and_disamb import deduplicate_entities_and_edges
|
from app.core.memory.storage_services.extraction_engine.deduplication.deduped_and_disamb import deduplicate_entities_and_edges
|
||||||
|
|
||||||
dedup_entity_nodes, dedup_statement_entity_edges, dedup_entity_entity_edges = await deduplicate_entities_and_edges(
|
dedup_entity_nodes, dedup_statement_entity_edges, dedup_entity_entity_edges, dedup_details = await deduplicate_entities_and_edges(
|
||||||
entity_nodes,
|
entity_nodes,
|
||||||
statement_entity_edges,
|
statement_entity_edges,
|
||||||
entity_entity_edges,
|
entity_entity_edges,
|
||||||
@@ -972,6 +1097,9 @@ class ExtractionOrchestrator:
|
|||||||
dedup_config=self.config.deduplication,
|
dedup_config=self.config.deduplication,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# 保存去重消歧的详细记录到实例变量
|
||||||
|
self._save_dedup_details(dedup_details, entity_nodes, dedup_entity_nodes)
|
||||||
|
|
||||||
result_tuple = (
|
result_tuple = (
|
||||||
dialogue_nodes,
|
dialogue_nodes,
|
||||||
chunk_nodes,
|
chunk_nodes,
|
||||||
@@ -1009,8 +1137,12 @@ class ExtractionOrchestrator:
|
|||||||
_,
|
_,
|
||||||
final_statement_entity_edges,
|
final_statement_entity_edges,
|
||||||
final_entity_entity_edges,
|
final_entity_entity_edges,
|
||||||
|
dedup_details,
|
||||||
) = result_tuple
|
) = result_tuple
|
||||||
|
|
||||||
|
# 保存去重消歧的详细记录到实例变量
|
||||||
|
self._save_dedup_details(dedup_details, entity_nodes, final_entity_nodes)
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"去重后: {len(final_entity_nodes)} 个实体节点, "
|
f"去重后: {len(final_entity_nodes)} 个实体节点, "
|
||||||
f"{len(final_statement_entity_edges)} 条陈述句-实体边, "
|
f"{len(final_statement_entity_edges)} 条陈述句-实体边, "
|
||||||
@@ -1022,6 +1154,46 @@ class ExtractionOrchestrator:
|
|||||||
f"实体-实体边减少 {len(entity_entity_edges) - len(final_entity_entity_edges)}"
|
f"实体-实体边减少 {len(entity_entity_edges) - len(final_entity_entity_edges)}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# 进度回调:输出去重消歧的具体结果
|
||||||
|
if self.progress_callback:
|
||||||
|
# 分析实体合并情况
|
||||||
|
merge_info = await self._analyze_entity_merges(entity_nodes, final_entity_nodes)
|
||||||
|
|
||||||
|
# 输出去重合并的实体示例
|
||||||
|
for i, merge_detail in enumerate(merge_info[:5]): # 输出前5个去重结果
|
||||||
|
dedup_result = {
|
||||||
|
"result_type": "entity_merge",
|
||||||
|
"merged_entity_name": merge_detail["main_entity_name"],
|
||||||
|
"merged_count": merge_detail["merged_count"],
|
||||||
|
"message": f"{merge_detail['main_entity_name']}合并{merge_detail['merged_count']}个:相似实体已合并"
|
||||||
|
}
|
||||||
|
await self.progress_callback("dedup_disambiguation_result", "实体去重完成", dedup_result)
|
||||||
|
|
||||||
|
# 分析实体消歧情况
|
||||||
|
disamb_info = await self._analyze_entity_disambiguation(entity_nodes, final_entity_nodes)
|
||||||
|
|
||||||
|
# 输出实体消歧的结果
|
||||||
|
for i, disamb_detail in enumerate(disamb_info[:5]): # 输出前5个消歧结果
|
||||||
|
disamb_result = {
|
||||||
|
"result_type": "entity_disambiguation",
|
||||||
|
"disambiguated_entity_name": disamb_detail["entity_name"],
|
||||||
|
"disambiguation_type": disamb_detail["disamb_type"],
|
||||||
|
"confidence": disamb_detail.get("confidence", "unknown"),
|
||||||
|
"reason": disamb_detail.get("reason", ""),
|
||||||
|
"message": f"{disamb_detail['entity_name']}消歧完成:{disamb_detail['disamb_type']}"
|
||||||
|
}
|
||||||
|
await self.progress_callback("dedup_disambiguation_result", "实体消歧完成", disamb_result)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
# 进度回调:去重消歧完成,传递去重和消歧的具体效果
|
||||||
|
await self._send_dedup_progress_callback(
|
||||||
|
len(entity_nodes), len(final_entity_nodes),
|
||||||
|
len(statement_entity_edges), len(final_statement_entity_edges),
|
||||||
|
len(entity_entity_edges), len(final_entity_entity_edges)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# 写入提取结果汇总(试运行和正式模式都需要生成)
|
# 写入提取结果汇总(试运行和正式模式都需要生成)
|
||||||
try:
|
try:
|
||||||
from app.core.config import settings
|
from app.core.config import settings
|
||||||
@@ -1041,6 +1213,378 @@ class ExtractionOrchestrator:
|
|||||||
logger.error(f"两阶段去重失败: {e}", exc_info=True)
|
logger.error(f"两阶段去重失败: {e}", exc_info=True)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
def _save_dedup_details(
|
||||||
|
self,
|
||||||
|
dedup_details: Dict[str, Any],
|
||||||
|
original_entities: List[ExtractedEntityNode],
|
||||||
|
final_entities: List[ExtractedEntityNode]
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
保存去重消歧的详细记录到实例变量(基于内存数据结构)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
dedup_details: 去重函数返回的详细记录
|
||||||
|
original_entities: 去重前的实体列表
|
||||||
|
final_entities: 去重后的实体列表
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 保存ID重定向映射
|
||||||
|
self.id_redirect_map = dedup_details.get("id_redirect", {})
|
||||||
|
|
||||||
|
# 处理精确匹配的合并记录
|
||||||
|
exact_merge_map = dedup_details.get("exact_merge_map", {})
|
||||||
|
for key, info in exact_merge_map.items():
|
||||||
|
merged_ids = info.get("merged_ids", set())
|
||||||
|
if merged_ids:
|
||||||
|
self.dedup_merge_records.append({
|
||||||
|
"type": "精确匹配",
|
||||||
|
"canonical_id": info.get("canonical_id"),
|
||||||
|
"entity_name": info.get("name"),
|
||||||
|
"entity_type": info.get("entity_type"),
|
||||||
|
"merged_count": len(merged_ids),
|
||||||
|
"merged_ids": list(merged_ids)
|
||||||
|
})
|
||||||
|
|
||||||
|
# 处理模糊匹配的合并记录
|
||||||
|
fuzzy_merge_records = dedup_details.get("fuzzy_merge_records", [])
|
||||||
|
for record in fuzzy_merge_records:
|
||||||
|
# 解析模糊匹配记录字符串
|
||||||
|
# 格式: "[模糊] 规范实体 id (group|name|type) <- 合并实体 id (group|name|type) | s_name=0.xxx, ..."
|
||||||
|
try:
|
||||||
|
import re
|
||||||
|
match = re.search(r"规范实体 (\S+) \(([^|]+)\|([^|]+)\|([^)]+)\) <- 合并实体 (\S+)", record)
|
||||||
|
if match:
|
||||||
|
self.dedup_merge_records.append({
|
||||||
|
"type": "模糊匹配",
|
||||||
|
"canonical_id": match.group(1),
|
||||||
|
"entity_name": match.group(3),
|
||||||
|
"entity_type": match.group(4),
|
||||||
|
"merged_count": 1,
|
||||||
|
"merged_ids": [match.group(5)]
|
||||||
|
})
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug(f"解析模糊匹配记录失败: {record}, 错误: {e}")
|
||||||
|
|
||||||
|
# 处理LLM去重的合并记录
|
||||||
|
llm_decision_records = dedup_details.get("llm_decision_records", [])
|
||||||
|
for record in llm_decision_records:
|
||||||
|
if "[LLM去重]" in str(record):
|
||||||
|
try:
|
||||||
|
import re
|
||||||
|
# 格式: "[LLM去重] 同名类型相似 name1(type1)|name2(type2) | conf=0.xx | reason=..."
|
||||||
|
match = re.search(r"同名类型相似 ([^(]+)(([^)]+))\|([^(]+)(([^)]+))", record)
|
||||||
|
if match:
|
||||||
|
self.dedup_merge_records.append({
|
||||||
|
"type": "LLM去重",
|
||||||
|
"entity_name": match.group(1),
|
||||||
|
"entity_type": f"{match.group(2)}|{match.group(4)}",
|
||||||
|
"merged_count": 1,
|
||||||
|
"merged_ids": []
|
||||||
|
})
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug(f"解析LLM去重记录失败: {record}, 错误: {e}")
|
||||||
|
|
||||||
|
# 处理消歧记录
|
||||||
|
disamb_records = dedup_details.get("disamb_records", [])
|
||||||
|
for record in disamb_records:
|
||||||
|
if "[DISAMB阻断]" in str(record):
|
||||||
|
try:
|
||||||
|
import re
|
||||||
|
# 格式: "[DISAMB阻断] name1(type1)|name2(type2) | conf=0.xx | reason=..."
|
||||||
|
content = str(record).replace("[DISAMB阻断]", "").strip()
|
||||||
|
match = re.search(r"([^(]+)(([^)]+))\|([^(]+)(([^)]+))", content)
|
||||||
|
if match:
|
||||||
|
entity1_name = match.group(1).strip()
|
||||||
|
entity1_type = match.group(2)
|
||||||
|
entity2_name = match.group(3).strip()
|
||||||
|
entity2_type = match.group(4)
|
||||||
|
|
||||||
|
# 提取置信度和原因
|
||||||
|
conf_match = re.search(r"conf=([0-9.]+)", str(record))
|
||||||
|
confidence = conf_match.group(1) if conf_match else "unknown"
|
||||||
|
|
||||||
|
reason_match = re.search(r"reason=([^|]+)", str(record))
|
||||||
|
reason = reason_match.group(1).strip() if reason_match else ""
|
||||||
|
|
||||||
|
self.dedup_disamb_records.append({
|
||||||
|
"entity_name": entity1_name,
|
||||||
|
"disamb_type": f"消歧阻断:{entity1_type} vs {entity2_type}",
|
||||||
|
"confidence": confidence,
|
||||||
|
"reason": reason[:100] + "..." if len(reason) > 100 else reason
|
||||||
|
})
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug(f"解析消歧记录失败: {record}, 错误: {e}")
|
||||||
|
|
||||||
|
logger.info(f"保存去重消歧记录:{len(self.dedup_merge_records)} 个合并记录,{len(self.dedup_disamb_records)} 个消歧记录")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"保存去重消歧详情失败: {e}", exc_info=True)
|
||||||
|
|
||||||
|
async def _analyze_entity_merges(
    self,
    original_entities: List[ExtractedEntityNode],
    final_entities: List[ExtractedEntityNode]
) -> List[Dict[str, Any]]:
    """
    Summarize entity merges from the in-memory merge records.

    The records stored on ``self.dedup_merge_records`` are ordered by their
    ``merged_count`` (largest first) and reduced to a name/count pair each.
    No log file is parsed.

    Args:
        original_entities: Entities before deduplication (not read here).
        final_entities: Entities after deduplication (not read here).

    Returns:
        A list of dicts with ``main_entity_name`` and ``merged_count``;
        an empty list when there are no records or when anything fails.
    """
    try:
        records = self.dedup_merge_records
        if not records:
            # Nothing was recorded, so report an empty result.
            logger.info("未找到实体合并记录")
            return []

        # Largest merges first; records without a count sort as 0.
        ranked = sorted(
            records,
            key=lambda rec: rec.get("merged_count", 0),
            reverse=True,
        )
        return [
            {
                "main_entity_name": rec.get("entity_name", "未知实体"),
                "merged_count": rec.get("merged_count", 1),
            }
            for rec in ranked
        ]

    except Exception as exc:
        logger.error(f"分析实体合并情况失败: {exc}", exc_info=True)
        return []
|
||||||
|
|
||||||
|
async def _analyze_entity_disambiguation(
    self,
    original_entities: List[ExtractedEntityNode],
    final_entities: List[ExtractedEntityNode]
) -> List[Dict[str, Any]]:
    """
    Return the disambiguation records kept in memory.

    The list stored on ``self.dedup_disamb_records`` is returned as-is
    (the same list object, not a copy). No log file is parsed.

    Args:
        original_entities: Entities before deduplication (not read here).
        final_entities: Entities after deduplication (not read here).

    Returns:
        The stored disambiguation records, or an empty list when there are
        none or when anything fails.
    """
    try:
        records = self.dedup_disamb_records
        if not records:
            # Nothing was recorded, so report an empty result.
            logger.info("未找到实体消歧记录")
            return []
        return records

    except Exception as exc:
        logger.error(f"分析实体消歧情况失败: {exc}", exc_info=True)
        return []
|
||||||
|
|
||||||
|
def _get_entity_type_display_name(self, entity_type: str) -> str:
|
||||||
|
"""
|
||||||
|
获取实体类型的中文显示名称
|
||||||
|
|
||||||
|
Args:
|
||||||
|
entity_type: 英文实体类型
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
中文显示名称
|
||||||
|
"""
|
||||||
|
type_mapping = {
|
||||||
|
"Person": "人物实体节点",
|
||||||
|
"Organization": "组织实体节点",
|
||||||
|
"ORG": "组织实体节点",
|
||||||
|
"Location": "地点实体节点",
|
||||||
|
"LOC": "地点实体节点",
|
||||||
|
"Event": "事件实体节点",
|
||||||
|
"Concept": "概念实体节点",
|
||||||
|
"Time": "时间实体节点",
|
||||||
|
"Position": "职位实体节点",
|
||||||
|
"WorkRole": "职业实体节点",
|
||||||
|
"System": "系统实体节点",
|
||||||
|
"Policy": "政策实体节点",
|
||||||
|
"HistoricalPeriod": "历史时期实体节点",
|
||||||
|
"HistoricalState": "历史国家实体节点",
|
||||||
|
"HistoricalEvent": "历史事件实体节点",
|
||||||
|
"EconomicFactor": "经济因素实体节点",
|
||||||
|
"Condition": "条件实体节点",
|
||||||
|
"Numeric": "数值实体节点"
|
||||||
|
}
|
||||||
|
return type_mapping.get(entity_type, f"{entity_type}实体节点")
|
||||||
|
|
||||||
|
async def _output_relationship_creation_results(
    self,
    entity_entity_edges: List[EntityEntityEdge],
    entity_nodes: List[ExtractedEntityNode]
):
    """
    Report created entity-to-entity relationships through the progress callback.

    At most the first 10 edges are reported, one callback invocation per edge.
    Edge endpoints are shown by entity name; an endpoint whose id is not among
    ``entity_nodes`` is shown as ``Entity_<id>``.

    Args:
        entity_entity_edges: Entity-to-entity edges to report.
        entity_nodes: Entity nodes used to resolve ids to names.
    """
    try:
        # Lookup table: entity id -> entity name.
        names_by_id = {}
        for node in entity_nodes:
            names_by_id[node.id] = node.name

        # Only the first 10 relationships are emitted; indices are 1-based.
        for position, edge in enumerate(entity_entity_edges[:10], start=1):
            source_name = names_by_id.get(edge.source, f"Entity_{edge.source}")
            target_name = names_by_id.get(edge.target, f"Entity_{edge.target}")
            relation = edge.relation_type

            await self.progress_callback(
                "creating_nodes_edges_result",
                "关系创建",
                {
                    "result_type": "relationship_creation",
                    "relationship_index": position,
                    "source_entity": source_name,
                    "relation_type": relation,
                    "target_entity": target_name,
                    "relationship_text": f"{source_name} -[{relation}]-> {target_name}",
                },
            )

    except Exception as exc:
        logger.error(f"输出关系创建结果失败: {exc}", exc_info=True)
|
||||||
|
|
||||||
|
async def _send_dedup_progress_callback(
|
||||||
|
self,
|
||||||
|
original_entities: int,
|
||||||
|
final_entities: int,
|
||||||
|
original_stmt_edges: int,
|
||||||
|
final_stmt_edges: int,
|
||||||
|
original_ent_edges: int,
|
||||||
|
final_ent_edges: int,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
发送去重消歧完成的进度回调,传递具体的去重和消歧效果
|
||||||
|
|
||||||
|
Args:
|
||||||
|
original_entities: 去重前实体数量
|
||||||
|
final_entities: 去重后实体数量
|
||||||
|
original_stmt_edges: 去重前陈述句-实体边数量
|
||||||
|
final_stmt_edges: 去重后陈述句-实体边数量
|
||||||
|
original_ent_edges: 去重前实体-实体边数量
|
||||||
|
final_ent_edges: 去重后实体-实体边数量
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 解析去重消歧报告文件,获取具体的去重和消歧效果
|
||||||
|
dedup_details = await self._parse_dedup_report()
|
||||||
|
|
||||||
|
# 计算去重效果统计
|
||||||
|
entities_reduced = original_entities - final_entities
|
||||||
|
stmt_edges_reduced = original_stmt_edges - final_stmt_edges
|
||||||
|
ent_edges_reduced = original_ent_edges - final_ent_edges
|
||||||
|
|
||||||
|
# 构建进度回调数据
|
||||||
|
dedup_stats = {
|
||||||
|
"entities": {
|
||||||
|
"original_count": original_entities,
|
||||||
|
"final_count": final_entities,
|
||||||
|
"reduced_count": entities_reduced,
|
||||||
|
"reduction_rate": round(entities_reduced / original_entities * 100, 1) if original_entities > 0 else 0,
|
||||||
|
},
|
||||||
|
"statement_entity_edges": {
|
||||||
|
"original_count": original_stmt_edges,
|
||||||
|
"final_count": final_stmt_edges,
|
||||||
|
"reduced_count": stmt_edges_reduced,
|
||||||
|
},
|
||||||
|
"entity_entity_edges": {
|
||||||
|
"original_count": original_ent_edges,
|
||||||
|
"final_count": final_ent_edges,
|
||||||
|
"reduced_count": ent_edges_reduced,
|
||||||
|
},
|
||||||
|
"dedup_examples": dedup_details.get("dedup_examples", []),
|
||||||
|
"disamb_examples": dedup_details.get("disamb_examples", []),
|
||||||
|
"summary": {
|
||||||
|
"total_merges": dedup_details.get("total_merges", 0),
|
||||||
|
"total_disambiguations": dedup_details.get("total_disambiguations", 0),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
await self.progress_callback("dedup_disambiguation_complete", "去重消歧完成", dedup_stats)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"发送去重消歧进度回调失败: {e}", exc_info=True)
|
||||||
|
# 即使解析失败,也发送基本的统计信息
|
||||||
|
try:
|
||||||
|
basic_stats = {
|
||||||
|
"entities": {
|
||||||
|
"original_count": original_entities,
|
||||||
|
"final_count": final_entities,
|
||||||
|
"reduced_count": original_entities - final_entities,
|
||||||
|
},
|
||||||
|
"summary": f"实体去重合并{original_entities - final_entities}个"
|
||||||
|
}
|
||||||
|
await self.progress_callback("dedup_disambiguation_complete", "去重消歧完成", basic_stats)
|
||||||
|
except Exception as e2:
|
||||||
|
logger.error(f"发送基本去重统计失败: {e2}", exc_info=True)
|
||||||
|
|
||||||
|
async def _parse_dedup_report(self) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
获取去重消歧报告,直接使用内存中的记录(不再解析日志文件)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
包含去重和消歧详细信息的字典
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 直接使用保存的记录构建报告
|
||||||
|
dedup_examples = []
|
||||||
|
disamb_examples = []
|
||||||
|
total_merges = 0
|
||||||
|
total_disambiguations = 0
|
||||||
|
|
||||||
|
# 处理合并记录
|
||||||
|
for record in self.dedup_merge_records:
|
||||||
|
merge_count = record.get("merged_count", 0)
|
||||||
|
total_merges += merge_count
|
||||||
|
|
||||||
|
dedup_examples.append({
|
||||||
|
"type": record.get("type", "未知"),
|
||||||
|
"entity_name": record.get("entity_name", "未知实体"),
|
||||||
|
"entity_type": record.get("entity_type", "未知类型"),
|
||||||
|
"merge_count": merge_count,
|
||||||
|
"description": f"{record.get('entity_name', '未知实体')}实体去重合并{merge_count}个"
|
||||||
|
})
|
||||||
|
|
||||||
|
# 处理消歧记录
|
||||||
|
for record in self.dedup_disamb_records:
|
||||||
|
total_disambiguations += 1
|
||||||
|
|
||||||
|
# 从消歧类型中提取实体类型信息
|
||||||
|
disamb_type = record.get("disamb_type", "")
|
||||||
|
entity_name = record.get("entity_name", "未知实体")
|
||||||
|
|
||||||
|
disamb_examples.append({
|
||||||
|
"entity1_name": entity_name,
|
||||||
|
"entity1_type": disamb_type.split("vs")[0].replace("消歧阻断:", "").strip() if "vs" in disamb_type else "未知",
|
||||||
|
"entity2_name": entity_name,
|
||||||
|
"entity2_type": disamb_type.split("vs")[1].strip() if "vs" in disamb_type else "未知",
|
||||||
|
"description": f"{entity_name},消歧区分成功"
|
||||||
|
})
|
||||||
|
|
||||||
|
return {
|
||||||
|
"dedup_examples": dedup_examples[:5], # 只返回前5个示例
|
||||||
|
"disamb_examples": disamb_examples[:5], # 只返回前5个示例
|
||||||
|
"total_merges": total_merges,
|
||||||
|
"total_disambiguations": total_disambiguations,
|
||||||
|
}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"获取去重报告失败: {e}", exc_info=True)
|
||||||
|
return {"dedup_examples": [], "disamb_examples": [], "total_merges": 0, "total_disambiguations": 0}
|
||||||
|
|
||||||
|
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
# 数据加载和预处理函数
|
# 数据加载和预处理函数
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ from app.db import Base
|
|||||||
class ApiKeyType(StrEnum):
|
class ApiKeyType(StrEnum):
|
||||||
"""API Key 类型"""
|
"""API Key 类型"""
|
||||||
AGENT = "agent" # 智能体
|
AGENT = "agent" # 智能体
|
||||||
CLUSTER = "cluster" # 集群
|
CLUSTER = "multi_agent" # 集群
|
||||||
WORKFLOW = "workflow" # 工作流
|
WORKFLOW = "workflow" # 工作流
|
||||||
SERVICE = "service" # 服务
|
SERVICE = "service" # 服务
|
||||||
|
|
||||||
|
|||||||
@@ -126,6 +126,7 @@ class ApiKeyRepository:
|
|||||||
"quota_used": api_key.quota_used,
|
"quota_used": api_key.quota_used,
|
||||||
"quota_limit": api_key.quota_limit,
|
"quota_limit": api_key.quota_limit,
|
||||||
"last_used_at": api_key.last_used_at,
|
"last_used_at": api_key.last_used_at,
|
||||||
|
"rate_limit": api_key.rate_limit,
|
||||||
"avg_response_time": float(avg_response_time) if avg_response_time else None
|
"avg_response_time": float(avg_response_time) if avg_response_time else None
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -46,6 +46,7 @@ class ConflictResultSchema(BaseModel):
|
|||||||
conflict_memory: Optional[BaseDataSchema] = Field(None, description="The conflict memory data.")
|
conflict_memory: Optional[BaseDataSchema] = Field(None, description="The conflict memory data.")
|
||||||
|
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
|
@classmethod
|
||||||
def _normalize_data(cls, v):
|
def _normalize_data(cls, v):
|
||||||
if isinstance(v, dict):
|
if isinstance(v, dict):
|
||||||
d = v.get("data")
|
d = v.get("data")
|
||||||
@@ -60,6 +61,7 @@ class ConflictSchema(BaseModel):
|
|||||||
conflict_memory: Optional[BaseDataSchema] = Field(None, description="The conflict memory data.")
|
conflict_memory: Optional[BaseDataSchema] = Field(None, description="The conflict memory data.")
|
||||||
|
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
|
@classmethod
|
||||||
def _normalize_data(cls, v):
|
def _normalize_data(cls, v):
|
||||||
if isinstance(v, dict):
|
if isinstance(v, dict):
|
||||||
d = v.get("data")
|
d = v.get("data")
|
||||||
@@ -88,6 +90,7 @@ class ReflexionResultSchema(BaseModel):
|
|||||||
resolved: Optional[ResolvedSchema] = Field(None, description="The resolved memory data.")
|
resolved: Optional[ResolvedSchema] = Field(None, description="The resolved memory data.")
|
||||||
|
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
|
@classmethod
|
||||||
def _normalize_resolved(cls, v):
|
def _normalize_resolved(cls, v):
|
||||||
if isinstance(v, dict):
|
if isinstance(v, dict):
|
||||||
conflict = v.get("conflict")
|
conflict = v.get("conflict")
|
||||||
@@ -311,7 +314,7 @@ class ApiResponse(BaseModel): # 通用API响应模型
|
|||||||
|
|
||||||
|
|
||||||
def _now_ms() -> int:
|
def _now_ms() -> int:
|
||||||
return int(round(time.time() * 1000))
|
return round(time.time() * 1000)
|
||||||
|
|
||||||
|
|
||||||
def ok(msg: str = "OK", data: Optional[Any] = None, time: Optional[int] = None) -> ApiResponse:
|
def ok(msg: str = "OK", data: Optional[Any] = None, time: Optional[int] = None) -> ApiResponse:
|
||||||
|
|||||||
@@ -43,12 +43,13 @@ class ApiKeyService:
|
|||||||
existing = db.scalar(
|
existing = db.scalar(
|
||||||
select(ApiKey).where(
|
select(ApiKey).where(
|
||||||
ApiKey.workspace_id == workspace_id,
|
ApiKey.workspace_id == workspace_id,
|
||||||
|
ApiKey.resource_id == data.resource_id,
|
||||||
ApiKey.name == data.name,
|
ApiKey.name == data.name,
|
||||||
ApiKey.is_active
|
ApiKey.is_active
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
if existing:
|
if existing:
|
||||||
raise BusinessException(f"API Key 名称 '{data.name}' 已存在", BizCode.API_KEY_DUPLICATE_NAME)
|
raise BusinessException(f"API Key 名称 {data.name} 已存在", BizCode.API_KEY_DUPLICATE_NAME)
|
||||||
|
|
||||||
# 生成 API Key
|
# 生成 API Key
|
||||||
api_key = generate_api_key(data.type)
|
api_key = generate_api_key(data.type)
|
||||||
@@ -137,21 +138,19 @@ class ApiKeyService:
|
|||||||
"""更新 API Key配置"""
|
"""更新 API Key配置"""
|
||||||
api_key = ApiKeyService.get_api_key(db, api_key_id, workspace_id)
|
api_key = ApiKeyService.get_api_key(db, api_key_id, workspace_id)
|
||||||
|
|
||||||
if not api_key:
|
|
||||||
raise BusinessException(f"API Key {api_key_id} 不存在", BizCode.API_KEY_NOT_FOUND)
|
|
||||||
|
|
||||||
# 检查名称重复
|
# 检查名称重复
|
||||||
if data.name and data.name != api_key.name:
|
if data.name and data.name != api_key.name:
|
||||||
existing = db.scalar(
|
existing = db.scalar(
|
||||||
select(ApiKey).where(
|
select(ApiKey).where(
|
||||||
ApiKey.workspace_id == workspace_id,
|
ApiKey.workspace_id == workspace_id,
|
||||||
|
ApiKey.resource_id == data.resource_id,
|
||||||
ApiKey.name == data.name,
|
ApiKey.name == data.name,
|
||||||
ApiKey.is_active,
|
ApiKey.is_active,
|
||||||
ApiKey.id != api_key_id
|
ApiKey.id != api_key_id
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
if existing:
|
if existing:
|
||||||
raise BusinessException(f"API Key 名称 '{data.name}' 已存在", BizCode.API_KEY_DUPLICATE_NAME)
|
raise BusinessException(f"API Key 名称 {data.name} 已存在", BizCode.API_KEY_DUPLICATE_NAME)
|
||||||
|
|
||||||
update_data = data.model_dump(exclude_unset=True)
|
update_data = data.model_dump(exclude_unset=True)
|
||||||
ApiKeyRepository.update(db, api_key_id, update_data)
|
ApiKeyRepository.update(db, api_key_id, update_data)
|
||||||
@@ -170,9 +169,6 @@ class ApiKeyService:
|
|||||||
"""删除 API Key"""
|
"""删除 API Key"""
|
||||||
api_key = ApiKeyService.get_api_key(db, api_key_id, workspace_id)
|
api_key = ApiKeyService.get_api_key(db, api_key_id, workspace_id)
|
||||||
|
|
||||||
if not api_key:
|
|
||||||
raise BusinessException(f"API Key {api_key_id} 不存在", BizCode.API_KEY_NOT_FOUND)
|
|
||||||
|
|
||||||
ApiKeyRepository.delete(db, api_key_id)
|
ApiKeyRepository.delete(db, api_key_id)
|
||||||
db.commit()
|
db.commit()
|
||||||
|
|
||||||
@@ -188,9 +184,6 @@ class ApiKeyService:
|
|||||||
"""重新生成 API Key"""
|
"""重新生成 API Key"""
|
||||||
api_key = ApiKeyService.get_api_key(db, api_key_id, workspace_id)
|
api_key = ApiKeyService.get_api_key(db, api_key_id, workspace_id)
|
||||||
|
|
||||||
if not api_key:
|
|
||||||
raise BusinessException(f"API Key {api_key_id} 不存在", BizCode.API_KEY_NOT_FOUND)
|
|
||||||
|
|
||||||
# 检查 API Key 是否激活
|
# 检查 API Key 是否激活
|
||||||
if not api_key.is_active:
|
if not api_key.is_active:
|
||||||
raise BusinessException("无法重新生成已停用的 API Key", BizCode.API_KEY_INACTIVE)
|
raise BusinessException("无法重新生成已停用的 API Key", BizCode.API_KEY_INACTIVE)
|
||||||
@@ -217,9 +210,6 @@ class ApiKeyService:
|
|||||||
"""获取使用统计"""
|
"""获取使用统计"""
|
||||||
api_key = ApiKeyService.get_api_key(db, api_key_id, workspace_id)
|
api_key = ApiKeyService.get_api_key(db, api_key_id, workspace_id)
|
||||||
|
|
||||||
if not api_key:
|
|
||||||
raise BusinessException(f"API Key {api_key_id} 不存在", BizCode.API_KEY_NOT_FOUND)
|
|
||||||
|
|
||||||
stats_data = ApiKeyRepository.get_stats(db, api_key_id)
|
stats_data = ApiKeyRepository.get_stats(db, api_key_id)
|
||||||
return api_key_schema.ApiKeyStats(**stats_data)
|
return api_key_schema.ApiKeyStats(**stats_data)
|
||||||
|
|
||||||
@@ -236,9 +226,6 @@ class ApiKeyService:
|
|||||||
# 验证 API Key 权限
|
# 验证 API Key 权限
|
||||||
api_key = ApiKeyService.get_api_key(db, api_key_id, workspace_id)
|
api_key = ApiKeyService.get_api_key(db, api_key_id, workspace_id)
|
||||||
|
|
||||||
if not api_key:
|
|
||||||
raise BusinessException(f"API Key {api_key_id} 不存在", BizCode.API_KEY_NOT_FOUND)
|
|
||||||
|
|
||||||
items, total = ApiKeyLogRepository.list_by_api_key(
|
items, total = ApiKeyLogRepository.list_by_api_key(
|
||||||
db, api_key_id, filters, page, pagesize
|
db, api_key_id, filters, page, pagesize
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -4,9 +4,12 @@ Memory Storage Service
|
|||||||
Handles business logic for memory storage operations.
|
Handles business logic for memory storage operations.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Dict, List, Optional, Any
|
from typing import Dict, List, Optional, Any, AsyncGenerator
|
||||||
import os
|
import os
|
||||||
import json
|
import json
|
||||||
|
import asyncio
|
||||||
|
import time
|
||||||
|
from datetime import datetime
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
@@ -14,6 +17,7 @@ from dotenv import load_dotenv
|
|||||||
from app.models.user_model import User
|
from app.models.user_model import User
|
||||||
from app.models.end_user_model import EndUser
|
from app.models.end_user_model import EndUser
|
||||||
from app.core.logging_config import get_logger
|
from app.core.logging_config import get_logger
|
||||||
|
from app.utils.sse_utils import format_sse_message
|
||||||
from app.schemas.memory_storage_schema import (
|
from app.schemas.memory_storage_schema import (
|
||||||
ConfigFilter,
|
ConfigFilter,
|
||||||
ConfigPilotRun,
|
ConfigPilotRun,
|
||||||
@@ -225,101 +229,175 @@ class DataConfigService: # 数据配置服务类(PostgreSQL)
|
|||||||
return self._convert_timestamps_to_format(data_list)
|
return self._convert_timestamps_to_format(data_list)
|
||||||
|
|
||||||
|
|
||||||
async def pilot_run_stream(self, payload: ConfigPilotRun) -> AsyncGenerator[str, None]:
    """
    Run a pilot (trial) extraction and stream its progress as SSE messages.

    The config id is taken from ``payload.config_id`` first, then from
    ``selections.config_id`` in ``dbrun.json``. The pipeline runs in a
    background task and reports progress through a queue that this
    generator drains.

    Args:
        payload: Pilot-run request carrying the config id and dialogue text.

    Yields:
        SSE-formatted strings with these event types:
          - ``starting`` and pipeline stage names: progress updates
          - ``result``: the extracted result read from the output file
          - ``error``: failure details (validation, pipeline or I/O errors)
          - ``done``: completion marker

    Notes:
        Failures are not raised to the caller; they are reported as a single
        ``error`` event. Only cancellation propagates. When the generator
        ends for any reason, a still-running pipeline task is cancelled.
    """
    project_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
    dbrun_path = os.path.join(project_root, "app", "core", "memory", "dbrun.json")

    # Kept outside the try so the finally clause can always inspect it.
    pipeline_task: Optional[asyncio.Task] = None

    try:
        # Initial progress event.
        yield format_sse_message("starting", {
            "message": "开始试运行...",
            "time": int(time.time() * 1000)
        })

        # Step 1: resolve and validate the configuration.
        payload_cid = str(getattr(payload, "config_id", "") or "").strip()
        cid: Optional[str] = payload_cid if payload_cid else None

        if not cid and os.path.isfile(dbrun_path):
            # Fall back to the selection stored in dbrun.json; an unreadable
            # or malformed file simply means "no fallback available".
            try:
                with open(dbrun_path, "r", encoding="utf-8") as f:
                    dbrun = json.load(f)
                if isinstance(dbrun, dict):
                    sel = dbrun.get("selections", {})
                    if isinstance(sel, dict):
                        fallback_cid = str(sel.get("config_id") or "").strip()
                        cid = fallback_cid or None
            except Exception:
                cid = None

        if not cid:
            raise ValueError("未提供 payload.config_id,且 dbrun.json 未设置 selections.config_id,禁止启动试运行")

        # dialogue_text is mandatory for a pilot run.
        dialogue_text = payload.dialogue_text.strip() if payload.dialogue_text else ""
        logger.info(f"[PILOT_RUN_STREAM] Received dialogue_text length: {len(dialogue_text)}, preview: {dialogue_text[:100]}")
        if not dialogue_text:
            raise ValueError("试运行模式必须提供 dialogue_text 参数")

        # Apply the in-memory configuration override and refresh constants.
        from app.core.memory.utils.config.definitions import reload_configuration_from_database

        ok_override = reload_configuration_from_database(cid)
        if not ok_override:
            raise RuntimeError("运行时覆写失败,config_id 无效或刷新常量失败")

        # Step 2: progress callback that hands events to this generator
        # through a queue.
        progress_queue: asyncio.Queue = asyncio.Queue()

        async def progress_callback(stage: str, message: str, data: Optional[Dict[str, Any]] = None) -> None:
            """Enqueue one progress event (stage id, message, optional result data)."""
            await progress_queue.put((stage, message, data))

        # Step 3: run the pipeline in a background task.
        async def run_pipeline():
            """Run the pipeline; signal completion or failure through the queue."""
            try:
                from app.core.memory.main import main as pipeline_main

                logger.info(f"[PILOT_RUN_STREAM] Calling pipeline_main with dialogue_text length: {len(dialogue_text)}, is_pilot_run=True")
                await pipeline_main(
                    dialogue_text=dialogue_text,
                    is_pilot_run=True,
                    progress_callback=progress_callback
                )
                logger.info("[PILOT_RUN_STREAM] pipeline_main completed")

                # Sentinel: pipeline finished.
                await progress_queue.put(("__PIPELINE_COMPLETE__", "", None))
            except Exception as e:
                # Sentinel: pipeline failed; the message travels in the tuple.
                await progress_queue.put(("__PIPELINE_ERROR__", str(e), None))

        pipeline_task = asyncio.create_task(run_pipeline())

        # Step 4: forward queued progress events until a sentinel arrives.
        while True:
            try:
                # Poll with a short timeout so this coroutine regularly
                # returns to the event loop while the pipeline is quiet.
                stage, message, data = await asyncio.wait_for(
                    progress_queue.get(),
                    timeout=0.5
                )
            except asyncio.TimeoutError:
                # asyncio.TimeoutError is an alias of the builtin TimeoutError
                # only since Python 3.11; naming it explicitly keeps the poll
                # loop working on older interpreters too.
                continue

            if stage == "__PIPELINE_COMPLETE__":
                break
            elif stage == "__PIPELINE_ERROR__":
                raise RuntimeError(message)

            progress_data = {
                "message": message,
                "time": int(time.time() * 1000)
            }
            # Attach result data only when the stage produced some.
            if data:
                progress_data["data"] = data

            # The stage id doubles as the SSE event type.
            yield format_sse_message(stage, progress_data)

        # The sentinel is enqueued last, so this returns promptly.
        await pipeline_task

        # Step 5: read the extraction result written by the pipeline.
        from app.core.config import settings

        result_path = settings.get_memory_output_path("extracted_result.json")
        if not os.path.isfile(result_path):
            raise FileNotFoundError(f"试运行完成,但未找到提取结果文件: {result_path}")

        with open(result_path, "r", encoding="utf-8") as rf:
            extracted_result = json.load(rf)

        # Step 6: result event.
        result_data = {
            "config_id": cid,
            "time_log": os.path.join(project_root, "logs", "time.log"),
            "extracted_result": extracted_result,
        }
        yield format_sse_message("result", result_data)

        # Step 7: completion event.
        yield format_sse_message("done", {
            "message": "试运行完成",
            "time": int(time.time() * 1000)
        })

    except asyncio.CancelledError:
        # The consumer went away (e.g. client disconnect).
        logger.info("[PILOT_RUN_STREAM] Client disconnected during streaming")
        raise
    except Exception as e:
        logger.error(f"[PILOT_RUN_STREAM] Error during streaming: {e}", exc_info=True)
        yield format_sse_message("error", {
            "code": 5000,
            "message": "试运行失败",
            "error": str(e),
            "time": int(time.time() * 1000)
        })
    finally:
        # Do not leave the pipeline running once nobody is reading its
        # progress (cancellation, generator close, or an error above).
        # cancel() is a no-op for a task that has already finished.
        if pipeline_task is not None and not pipeline_task.done():
            pipeline_task.cancel()
|
||||||
|
|
||||||
|
|
||||||
# -------------------- Neo4j Search & Analytics (fused from data_search_service.py) --------------------
|
# -------------------- Neo4j Search & Analytics (fused from data_search_service.py) --------------------
|
||||||
|
|||||||
27
api/app/utils/sse_utils.py
Normal file
27
api/app/utils/sse_utils.py
Normal file
@@ -0,0 +1,27 @@
|
|||||||
|
"""
|
||||||
|
Server-Sent Events (SSE) Utility Functions
|
||||||
|
|
||||||
|
Provides shared utilities for formatting and handling SSE messages.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
from typing import Dict, Any
|
||||||
|
|
||||||
|
|
||||||
|
def format_sse_message(event_type: str, data: Dict[str, Any]) -> str:
    r"""
    Format a message in Server-Sent Events (SSE) format.

    Args:
        event_type: Type of event (stage name, result, error, done). Any
            carriage return or line feed in it is replaced by a space so
            the name cannot break the message framing.
        data: Event data dictionary to be serialized as JSON.

    Returns:
        SSE formatted string: "event: <type>\ndata: <json>\n\n"

    Raises:
        TypeError: If ``data`` contains a value that ``json.dumps`` cannot
            serialize.

    Example:
        >>> format_sse_message("loading", {"message": "Loading..."})
        'event: loading\ndata: {"message": "Loading..."}\n\n'
    """
    # SSE fields are line-delimited: a CR or LF inside the event name would
    # end the "event:" field early and corrupt the rest of the message.
    safe_event_type = event_type.replace("\r", " ").replace("\n", " ")
    # Without `indent`, json.dumps escapes line breaks inside strings, so the
    # payload always fits on a single "data:" line.
    json_data = json.dumps(data, ensure_ascii=False)
    return f"event: {safe_event_type}\ndata: {json_data}\n\n"
|
||||||
Reference in New Issue
Block a user