[MODIFY] MEM SEE OUTPUT
This commit is contained in:
@@ -1,8 +1,9 @@
|
||||
from typing import Optional
|
||||
from typing import Optional, Union
|
||||
import os
|
||||
import uuid
|
||||
from sqlalchemy.orm import Session
|
||||
from fastapi import APIRouter, Depends, Query, UploadFile
|
||||
from fastapi import APIRouter, Depends, UploadFile
|
||||
from fastapi.responses import StreamingResponse
|
||||
|
||||
|
||||
from app.db import get_db
|
||||
@@ -322,36 +323,24 @@ def read_all_config(
|
||||
return fail(BizCode.INTERNAL_ERROR, "查询所有配置失败", str(e))
|
||||
|
||||
|
||||
@router.post("/pilot_run", response_model=ApiResponse) # 试运行:触发执行主管线,使用 POST 更为合理
|
||||
@router.post("/pilot_run", response_model=None)
|
||||
async def pilot_run(
|
||||
payload: ConfigPilotRun,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
) -> StreamingResponse:
|
||||
api_logger.info(f"Pilot run requested: config_id={payload.config_id}, dialogue_text_length={len(payload.dialogue_text)}")
|
||||
|
||||
# 先尝试从数据库加载配置
|
||||
try:
|
||||
config_loaded = reload_configuration_from_database(str(payload.config_id))
|
||||
if not config_loaded:
|
||||
api_logger.error(f"Failed to load configuration for config_id: {payload.config_id}")
|
||||
return fail(BizCode.INTERNAL_ERROR, "配置加载失败", f"无法加载 config_id={payload.config_id} 的配置")
|
||||
api_logger.info(f"Configuration loaded successfully for config_id: {payload.config_id}")
|
||||
except Exception as e:
|
||||
api_logger.error(f"Exception while loading configuration: {str(e)}")
|
||||
return fail(BizCode.INTERNAL_ERROR, "配置加载异常", str(e))
|
||||
|
||||
try:
|
||||
svc = DataConfigService(db)
|
||||
result = await svc.pilot_run(payload)
|
||||
return success(data=result, msg="试运行完成")
|
||||
except ValueError as e:
|
||||
# 捕获参数验证错误
|
||||
api_logger.error(f"Pilot run parameter validation failed: {str(e)}")
|
||||
return fail(BizCode.INVALID_PARAMETER, "参数验证失败", str(e))
|
||||
except Exception as e:
|
||||
api_logger.error(f"Pilot run failed: {str(e)}")
|
||||
return fail(BizCode.INTERNAL_ERROR, "试运行失败", str(e))
|
||||
svc = DataConfigService(db)
|
||||
return StreamingResponse(
|
||||
svc.pilot_run_stream(payload),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"X-Accel-Buffering": "no"
|
||||
}
|
||||
)
|
||||
|
||||
"""
|
||||
以下为搜索与分析接口,直接挂载到同一 router,统一响应为 ApiResponse。
|
||||
|
||||
@@ -21,7 +21,7 @@ def generate_api_key(key_type: ApiKeyType) -> str:
|
||||
# 前缀映射
|
||||
prefix_map = {
|
||||
ApiKeyType.AGENT: "sk-agent-",
|
||||
ApiKeyType.CLUSTER: "sk-cluster-",
|
||||
ApiKeyType.CLUSTER: "sk-multi_agent-",
|
||||
ApiKeyType.WORKFLOW: "sk-workflow-",
|
||||
ApiKeyType.SERVICE: "sk-service-"
|
||||
}
|
||||
|
||||
@@ -106,6 +106,8 @@ async def write(content: str, user_id: str, apply_id: str, group_id: str, ref_id
|
||||
all_statement_chunk_edges,
|
||||
all_statement_entity_edges,
|
||||
all_entity_entity_edges,
|
||||
all_dedup_details,
|
||||
|
||||
) = await orchestrator.run(chunked_dialogs, is_pilot_run=False)
|
||||
|
||||
log_time("Extraction Pipeline", time.time() - step_start, log_file)
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
{
|
||||
"selections": {
|
||||
"config_id": "1"
|
||||
"config_id": ""
|
||||
}
|
||||
}
|
||||
@@ -21,7 +21,7 @@ os.environ["LANGCHAIN_TRACING"] = "false"
|
||||
import asyncio
|
||||
import time
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
from typing import Optional, Callable, Awaitable
|
||||
from dotenv import load_dotenv
|
||||
|
||||
# 导入重构后的模块
|
||||
@@ -50,7 +50,11 @@ logger = get_memory_logger(__name__)
|
||||
|
||||
|
||||
|
||||
async def main(dialogue_text: Optional[str] = None, is_pilot_run: bool = False):
|
||||
async def main(
|
||||
dialogue_text: Optional[str] = None,
|
||||
is_pilot_run: bool = False,
|
||||
progress_callback: Optional[Callable[[str, str, Optional[dict]], Awaitable[None]]] = None
|
||||
):
|
||||
"""
|
||||
记忆系统主流程 - 重构版本
|
||||
|
||||
@@ -61,6 +65,12 @@ async def main(dialogue_text: Optional[str] = None, is_pilot_run: bool = False):
|
||||
is_pilot_run: 是否为试运行模式
|
||||
- True: 试运行模式,不保存到 Neo4j
|
||||
- False: 正常运行模式,保存到 Neo4j
|
||||
progress_callback: 可选的进度回调函数
|
||||
- 类型: Callable[[str, str, Optional[dict]], Awaitable[None]]
|
||||
- 参数1 (stage): 当前处理阶段标识符
|
||||
- 参数2 (message): 人类可读的进度消息
|
||||
- 参数3 (data): 可选的附加数据字典,包含详细的进度信息或结果
|
||||
- 在管线关键点调用以报告进度和结果数据
|
||||
|
||||
工作流程:
|
||||
1. 初始化客户端和配置
|
||||
@@ -141,6 +151,10 @@ async def main(dialogue_text: Optional[str] = None, is_pilot_run: bool = False):
|
||||
metadata={"source": "pilot_run", "input_type": "frontend_text"}
|
||||
)
|
||||
|
||||
# 进度回调:开始预处理文本
|
||||
if progress_callback:
|
||||
await progress_callback("text_preprocessing", "开始预处理文本...")
|
||||
|
||||
# 对前端传入的对话进行分块处理
|
||||
chunked_dialogs = await get_chunked_dialogs_from_preprocessed(
|
||||
data=[dialog],
|
||||
@@ -148,6 +162,27 @@ async def main(dialogue_text: Optional[str] = None, is_pilot_run: bool = False):
|
||||
llm_client=llm_client,
|
||||
)
|
||||
logger.info(f"Processed frontend dialogue text: {len(messages)} messages")
|
||||
|
||||
# 进度回调:输出每个分块的结果
|
||||
if progress_callback:
|
||||
for dialog in chunked_dialogs:
|
||||
for i, chunk in enumerate(dialog.chunks):
|
||||
chunk_result = {
|
||||
"chunk_index": i + 1,
|
||||
"content": chunk.content[:200] + "..." if len(chunk.content) > 200 else chunk.content,
|
||||
"full_length": len(chunk.content),
|
||||
"dialog_id": dialog.id,
|
||||
"chunker_strategy": config_defs.SELECTED_CHUNKER_STRATEGY
|
||||
}
|
||||
await progress_callback("text_preprocessing_result", f"分块 {i + 1} 处理完成", chunk_result)
|
||||
|
||||
# 进度回调:预处理文本完成
|
||||
preprocessing_summary = {
|
||||
"total_chunks": sum(len(dialog.chunks) for dialog in chunked_dialogs),
|
||||
"total_dialogs": len(chunked_dialogs),
|
||||
"chunker_strategy": config_defs.SELECTED_CHUNKER_STRATEGY
|
||||
}
|
||||
await progress_callback("text_preprocessing_complete", "预处理文本完成", preprocessing_summary)
|
||||
else:
|
||||
# 正常运行模式:从 testdata.json 文件加载
|
||||
logger.warning("[MAIN] ✗ Falling back to testdata.json (dialogue_text not provided or empty)")
|
||||
@@ -159,6 +194,10 @@ async def main(dialogue_text: Optional[str] = None, is_pilot_run: bool = False):
|
||||
if not os.path.exists(test_data_path):
|
||||
raise FileNotFoundError(f"Test data file not found: {test_data_path}")
|
||||
|
||||
# 进度回调:开始预处理文本
|
||||
if progress_callback:
|
||||
await progress_callback("text_preprocessing", "开始预处理文本...")
|
||||
|
||||
chunked_dialogs = await get_chunked_dialogs_with_preprocessing(
|
||||
chunker_strategy=config_defs.SELECTED_CHUNKER_STRATEGY,
|
||||
group_id=config_defs.SELECTED_GROUP_ID,
|
||||
@@ -170,6 +209,27 @@ async def main(dialogue_text: Optional[str] = None, is_pilot_run: bool = False):
|
||||
skip_cleaning=True,
|
||||
)
|
||||
logger.info(f"Loaded {len(chunked_dialogs)} dialogues from testdata.json")
|
||||
|
||||
# 进度回调:输出每个分块的结果
|
||||
if progress_callback:
|
||||
for dialog in chunked_dialogs:
|
||||
for i, chunk in enumerate(dialog.chunks):
|
||||
chunk_result = {
|
||||
"chunk_index": i + 1,
|
||||
"content": chunk.content[:200] + "..." if len(chunk.content) > 200 else chunk.content,
|
||||
"full_length": len(chunk.content),
|
||||
"dialog_id": dialog.id,
|
||||
"chunker_strategy": config_defs.SELECTED_CHUNKER_STRATEGY
|
||||
}
|
||||
await progress_callback("text_preprocessing_result", f"分块 {i + 1} 处理完成", chunk_result)
|
||||
|
||||
# 进度回调:预处理文本完成
|
||||
preprocessing_summary = {
|
||||
"total_chunks": sum(len(dialog.chunks) for dialog in chunked_dialogs),
|
||||
"total_dialogs": len(chunked_dialogs),
|
||||
"chunker_strategy": config_defs.SELECTED_CHUNKER_STRATEGY
|
||||
}
|
||||
await progress_callback("text_preprocessing_complete", "预处理文本完成", preprocessing_summary)
|
||||
|
||||
log_time("Data Loading & Chunking", time.time() - step_start, log_file)
|
||||
|
||||
@@ -188,6 +248,7 @@ async def main(dialogue_text: Optional[str] = None, is_pilot_run: bool = False):
|
||||
embedder_client=embedder_client,
|
||||
connector=neo4j_connector,
|
||||
config=config,
|
||||
progress_callback=progress_callback, # 传递进度回调
|
||||
)
|
||||
|
||||
log_time("Orchestrator Initialization", time.time() - step_start, log_file)
|
||||
@@ -196,6 +257,11 @@ async def main(dialogue_text: Optional[str] = None, is_pilot_run: bool = False):
|
||||
logger.info("Running extraction pipeline...")
|
||||
step_start = time.time()
|
||||
|
||||
|
||||
# 进度回调:正在知识抽取
|
||||
if progress_callback:
|
||||
await progress_callback("knowledge_extraction", "正在知识抽取...")
|
||||
|
||||
extraction_result = await orchestrator.run(
|
||||
dialog_data_list=chunked_dialogs,
|
||||
is_pilot_run=is_pilot_run, # 传递试运行模式标志
|
||||
@@ -216,6 +282,11 @@ async def main(dialogue_text: Optional[str] = None, is_pilot_run: bool = False):
|
||||
) = extraction_result
|
||||
|
||||
log_time("Extraction Pipeline", time.time() - step_start, log_file)
|
||||
|
||||
# 进度回调:生成结果
|
||||
if progress_callback:
|
||||
await progress_callback("generating_results", "正在生成结果...")
|
||||
|
||||
|
||||
# 步骤 5: 保存结果或输出结果
|
||||
if is_pilot_run:
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
去重功能函数
|
||||
"""
|
||||
from app.core.memory.models.variate_config import DedupConfig
|
||||
from typing import List, Dict, Tuple
|
||||
from typing import List, Dict, Tuple, Any
|
||||
from app.core.memory.models.graph_models import(
|
||||
StatementEntityEdge,
|
||||
EntityEntityEdge,
|
||||
@@ -895,7 +895,12 @@ async def deduplicate_entities_and_edges(
|
||||
report_append: bool = False,
|
||||
report_stage_notes: List[str] | None = None,
|
||||
dedup_config: DedupConfig | None = None,
|
||||
) -> Tuple[List[ExtractedEntityNode], List[StatementEntityEdge], List[EntityEntityEdge]]:
|
||||
) -> Tuple[
|
||||
List[ExtractedEntityNode],
|
||||
List[StatementEntityEdge],
|
||||
List[EntityEntityEdge],
|
||||
Dict[str, Any] # 新增:返回详细的去重消歧记录
|
||||
]:
|
||||
"""
|
||||
主流程:依次执行精确匹配、模糊匹配与(可选)LLM 决策融合,随后对边做重定向与去重。之后再处理边,是关系去重和消歧
|
||||
返回:去重后的实体、语句→实体边、实体↔实体边。
|
||||
@@ -981,8 +986,18 @@ async def deduplicate_entities_and_edges(
|
||||
append=report_append,
|
||||
stage_notes=report_stage_notes,
|
||||
)
|
||||
|
||||
# 构建详细的去重消歧记录(用于内存访问,避免解析日志文件)
|
||||
dedup_details = {
|
||||
"exact_merge_map": exact_merge_map,
|
||||
"fuzzy_merge_records": fuzzy_merge_records,
|
||||
"llm_decision_records": local_llm_records,
|
||||
"disamb_records": disamb_records,
|
||||
"id_redirect": id_redirect,
|
||||
"blocked_pairs": blocked_pairs,
|
||||
}
|
||||
|
||||
return deduped_entities, list(stmt_ent_map.values()), list(ent_ent_map.values())
|
||||
return deduped_entities, list(stmt_ent_map.values()), list(ent_ent_map.values()), dedup_details
|
||||
|
||||
# 独立模块:去重融合报告写入(与实体/边的计算解耦)
|
||||
def _write_dedup_fusion_report(
|
||||
|
||||
@@ -39,6 +39,7 @@ async def dedup_layers_and_merge_and_return(
|
||||
List[StatementChunkEdge],
|
||||
List[StatementEntityEdge],
|
||||
List[EntityEntityEdge],
|
||||
dict, # 新增:返回去重详情
|
||||
]:
|
||||
"""
|
||||
执行两层实体去重与融合:
|
||||
@@ -62,7 +63,7 @@ async def dedup_layers_and_merge_and_return(
|
||||
break
|
||||
|
||||
# 第一层去重消歧
|
||||
dedup_entity_nodes, dedup_statement_entity_edges, dedup_entity_entity_edges = await deduplicate_entities_and_edges(
|
||||
dedup_entity_nodes, dedup_statement_entity_edges, dedup_entity_entity_edges, dedup_details = await deduplicate_entities_and_edges(
|
||||
entity_nodes,
|
||||
statement_entity_edges,
|
||||
entity_entity_edges,
|
||||
@@ -103,4 +104,5 @@ async def dedup_layers_and_merge_and_return(
|
||||
statement_chunk_edges,
|
||||
fused_statement_entity_edges,
|
||||
fused_entity_entity_edges,
|
||||
dedup_details, # 返回去重详情
|
||||
)
|
||||
|
||||
@@ -12,13 +12,14 @@
|
||||
5. 提供错误处理和日志记录
|
||||
6. 支持试运行模式(不写入数据库)
|
||||
|
||||
作者:Memory Refactoring Team
|
||||
作者:
|
||||
日期:2025-11-21
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import List, Dict, Any, Tuple, Optional
|
||||
import os
|
||||
from typing import List, Dict, Any, Tuple, Optional, Callable, Awaitable
|
||||
from datetime import datetime
|
||||
|
||||
from app.core.memory.models.message_models import DialogData
|
||||
@@ -94,6 +95,7 @@ class ExtractionOrchestrator:
|
||||
embedder_client: OpenAIEmbedderClient,
|
||||
connector: Neo4jConnector,
|
||||
config: Optional[ExtractionPipelineConfig] = None,
|
||||
progress_callback: Optional[Callable[[str, str, Optional[Dict[str, Any]]], Awaitable[None]]] = None,
|
||||
):
|
||||
"""
|
||||
初始化流水线编排器
|
||||
@@ -103,12 +105,21 @@ class ExtractionOrchestrator:
|
||||
embedder_client: 嵌入模型客户端
|
||||
connector: Neo4j 连接器
|
||||
config: 流水线配置,如果为 None 则使用默认配置
|
||||
progress_callback: 进度回调函数
|
||||
- 接受 (stage: str, message: str, data: Optional[Dict[str, Any]]) 并返回 Awaitable[None]
|
||||
- 在管线关键点调用以报告进度和结果数据
|
||||
"""
|
||||
self.llm_client = llm_client
|
||||
self.embedder_client = embedder_client
|
||||
self.connector = connector
|
||||
self.config = config or ExtractionPipelineConfig()
|
||||
self.is_pilot_run = False # 默认非试运行模式
|
||||
self.progress_callback = progress_callback # 保存进度回调函数
|
||||
|
||||
# 保存去重消歧的详细记录(内存中的数据结构)
|
||||
self.dedup_merge_records: List[Dict[str, Any]] = [] # 实体合并记录
|
||||
self.dedup_disamb_records: List[Dict[str, Any]] = [] # 实体消歧记录
|
||||
self.id_redirect_map: Dict[str, str] = {} # ID重定向映射
|
||||
|
||||
# 初始化各个提取器
|
||||
self.statement_extractor = StatementExtractor(
|
||||
@@ -160,6 +171,13 @@ class ExtractionOrchestrator:
|
||||
# 步骤 1: 陈述句提取
|
||||
logger.info("步骤 1/6: 陈述句提取(全局分块级并行)")
|
||||
dialog_data_list = await self._extract_statements(dialog_data_list)
|
||||
|
||||
# 收集陈述句内容和统计数量
|
||||
all_statements_list = []
|
||||
for dialog in dialog_data_list:
|
||||
for chunk in dialog.chunks:
|
||||
all_statements_list.extend(chunk.statements)
|
||||
total_statements = len(all_statements_list)
|
||||
|
||||
# 步骤 2: 并行执行三元组提取、时间信息提取和基础嵌入生成
|
||||
logger.info("步骤 2/6: 并行执行三元组提取、时间信息提取和嵌入生成")
|
||||
@@ -170,11 +188,90 @@ class ExtractionOrchestrator:
|
||||
chunk_embedding_maps,
|
||||
dialog_embeddings,
|
||||
) = await self._parallel_extract_and_embed(dialog_data_list)
|
||||
|
||||
# 收集实体和三元组内容,并统计数量
|
||||
all_entities_list = []
|
||||
all_triplets_list = []
|
||||
for triplet_map in triplet_maps:
|
||||
for triplet_info in triplet_map.values():
|
||||
if triplet_info:
|
||||
all_entities_list.extend(triplet_info.entities)
|
||||
all_triplets_list.extend(triplet_info.triplets)
|
||||
|
||||
total_entities = len(all_entities_list)
|
||||
total_triplets = len(all_triplets_list)
|
||||
total_temporal = sum(len(temporal_map) for temporal_map in temporal_maps)
|
||||
|
||||
# 步骤 3: 生成实体嵌入(依赖三元组提取结果)
|
||||
logger.info("步骤 3/6: 生成实体嵌入")
|
||||
triplet_maps = await self._generate_entity_embeddings(triplet_maps)
|
||||
|
||||
# 进度回调:按三个阶段分别输出知识抽取结果
|
||||
if self.progress_callback:
|
||||
# 第一阶段:陈述句提取结果
|
||||
for i, stmt in enumerate(all_statements_list[:10]): # 只输出前10个陈述句
|
||||
stmt_result = {
|
||||
"extraction_type": "statement",
|
||||
"statement_index": i + 1,
|
||||
"statement": stmt.statement,
|
||||
"statement_id": stmt.id
|
||||
}
|
||||
await self.progress_callback("knowledge_extraction_result", "陈述句提取完成", stmt_result)
|
||||
|
||||
# 第二阶段:三元组提取结果
|
||||
for i, triplet in enumerate(all_triplets_list[:10]): # 只输出前10个三元组
|
||||
triplet_result = {
|
||||
"extraction_type": "triplet",
|
||||
"triplet_index": i + 1,
|
||||
"subject": triplet.subject_name,
|
||||
"predicate": triplet.predicate,
|
||||
"object": triplet.object_name
|
||||
}
|
||||
await self.progress_callback("knowledge_extraction_result", "三元组提取完成", triplet_result)
|
||||
|
||||
# 第三阶段:时间提取结果
|
||||
if total_temporal > 0:
|
||||
# 收集时间信息
|
||||
temporal_results = []
|
||||
for dialog in dialog_data_list:
|
||||
for chunk in dialog.chunks:
|
||||
for statement in chunk.statements:
|
||||
if hasattr(statement, 'temporal_validity') and statement.temporal_validity:
|
||||
temporal_results.append({
|
||||
"statement_id": statement.id,
|
||||
"statement": statement.statement,
|
||||
"valid_at": statement.temporal_validity.valid_at,
|
||||
"invalid_at": statement.temporal_validity.invalid_at
|
||||
})
|
||||
|
||||
# 输出时间提取结果
|
||||
for i, temporal_result in enumerate(temporal_results[:5]): # 只输出前5个时间提取结果
|
||||
time_result = {
|
||||
"extraction_type": "temporal",
|
||||
"temporal_index": i + 1,
|
||||
"statement": temporal_result["statement"],
|
||||
"valid_at": temporal_result["valid_at"],
|
||||
"invalid_at": temporal_result["invalid_at"]
|
||||
}
|
||||
await self.progress_callback("knowledge_extraction_result", "时间提取完成", time_result)
|
||||
else:
|
||||
# 如果没有时间信息,也发送一个时间提取完成的消息
|
||||
time_result = {
|
||||
"extraction_type": "temporal",
|
||||
"temporal_index": 0,
|
||||
"message": "未发现时间信息"
|
||||
}
|
||||
await self.progress_callback("knowledge_extraction_result", "时间提取完成", time_result)
|
||||
|
||||
# 进度回调:知识抽取完成,传递知识抽取的统计信息
|
||||
extraction_stats = {
|
||||
"statements_count": total_statements,
|
||||
"entities_count": total_entities,
|
||||
"triplets_count": total_triplets,
|
||||
"temporal_ranges_count": total_temporal,
|
||||
}
|
||||
await self.progress_callback("knowledge_extraction_complete", "知识抽取完成", extraction_stats)
|
||||
|
||||
# 步骤 4: 将提取的数据赋值到语句
|
||||
logger.info("步骤 4/6: 数据赋值")
|
||||
dialog_data_list = await self._assign_extracted_data(
|
||||
@@ -218,6 +315,8 @@ class ExtractionOrchestrator:
|
||||
dialog_data_list,
|
||||
)
|
||||
|
||||
|
||||
|
||||
logger.info(f"知识提取流水线运行完成({mode_str})")
|
||||
return result
|
||||
|
||||
@@ -732,6 +831,10 @@ class ExtractionOrchestrator:
|
||||
包含所有节点和边的元组
|
||||
"""
|
||||
logger.info("开始创建节点和边")
|
||||
|
||||
# 进度回调:正在创建节点和边
|
||||
if self.progress_callback:
|
||||
await self.progress_callback("creating_nodes_edges", "正在创建节点和边...")
|
||||
|
||||
dialogue_nodes = []
|
||||
chunk_nodes = []
|
||||
@@ -904,6 +1007,23 @@ class ExtractionOrchestrator:
|
||||
f"陈述句-实体边: {len(statement_entity_edges)}, "
|
||||
f"实体-实体边: {len(entity_entity_edges)}"
|
||||
)
|
||||
|
||||
# 进度回调:只输出关系创建结果
|
||||
if self.progress_callback:
|
||||
# 输出关系创建结果
|
||||
await self._output_relationship_creation_results(entity_entity_edges, entity_nodes)
|
||||
|
||||
# 进度回调:创建节点和边完成,传递结果统计
|
||||
nodes_edges_stats = {
|
||||
"dialogue_nodes_count": len(dialogue_nodes),
|
||||
"chunk_nodes_count": len(chunk_nodes),
|
||||
"statement_nodes_count": len(statement_nodes),
|
||||
"entity_nodes_count": len(entity_nodes),
|
||||
"statement_chunk_edges_count": len(statement_chunk_edges),
|
||||
"statement_entity_edges_count": len(statement_entity_edges),
|
||||
"entity_entity_edges_count": len(entity_entity_edges),
|
||||
}
|
||||
await self.progress_callback("creating_nodes_edges_complete", "创建节点和边完成", nodes_edges_stats)
|
||||
|
||||
return (
|
||||
dialogue_nodes,
|
||||
@@ -950,6 +1070,11 @@ class ExtractionOrchestrator:
|
||||
- 第三个元组:去重后的 (实体节点列表, 陈述句-实体边列表, 实体-实体边列表)
|
||||
"""
|
||||
logger.info("开始两阶段实体去重和消歧")
|
||||
|
||||
# 进度回调:正在去重消歧
|
||||
if self.progress_callback:
|
||||
await self.progress_callback("deduplication", "正在去重消歧...")
|
||||
|
||||
logger.info(
|
||||
f"去重前: {len(entity_nodes)} 个实体节点, "
|
||||
f"{len(statement_entity_edges)} 条陈述句-实体边, "
|
||||
@@ -963,7 +1088,7 @@ class ExtractionOrchestrator:
|
||||
# 只执行第一层去重
|
||||
from app.core.memory.storage_services.extraction_engine.deduplication.deduped_and_disamb import deduplicate_entities_and_edges
|
||||
|
||||
dedup_entity_nodes, dedup_statement_entity_edges, dedup_entity_entity_edges = await deduplicate_entities_and_edges(
|
||||
dedup_entity_nodes, dedup_statement_entity_edges, dedup_entity_entity_edges, dedup_details = await deduplicate_entities_and_edges(
|
||||
entity_nodes,
|
||||
statement_entity_edges,
|
||||
entity_entity_edges,
|
||||
@@ -972,6 +1097,9 @@ class ExtractionOrchestrator:
|
||||
dedup_config=self.config.deduplication,
|
||||
)
|
||||
|
||||
# 保存去重消歧的详细记录到实例变量
|
||||
self._save_dedup_details(dedup_details, entity_nodes, dedup_entity_nodes)
|
||||
|
||||
result_tuple = (
|
||||
dialogue_nodes,
|
||||
chunk_nodes,
|
||||
@@ -1009,7 +1137,11 @@ class ExtractionOrchestrator:
|
||||
_,
|
||||
final_statement_entity_edges,
|
||||
final_entity_entity_edges,
|
||||
dedup_details,
|
||||
) = result_tuple
|
||||
|
||||
# 保存去重消歧的详细记录到实例变量
|
||||
self._save_dedup_details(dedup_details, entity_nodes, final_entity_nodes)
|
||||
|
||||
logger.info(
|
||||
f"去重后: {len(final_entity_nodes)} 个实体节点, "
|
||||
@@ -1021,6 +1153,46 @@ class ExtractionOrchestrator:
|
||||
f"陈述句-实体边减少 {len(statement_entity_edges) - len(final_statement_entity_edges)}, "
|
||||
f"实体-实体边减少 {len(entity_entity_edges) - len(final_entity_entity_edges)}"
|
||||
)
|
||||
|
||||
# 进度回调:输出去重消歧的具体结果
|
||||
if self.progress_callback:
|
||||
# 分析实体合并情况
|
||||
merge_info = await self._analyze_entity_merges(entity_nodes, final_entity_nodes)
|
||||
|
||||
# 输出去重合并的实体示例
|
||||
for i, merge_detail in enumerate(merge_info[:5]): # 输出前5个去重结果
|
||||
dedup_result = {
|
||||
"result_type": "entity_merge",
|
||||
"merged_entity_name": merge_detail["main_entity_name"],
|
||||
"merged_count": merge_detail["merged_count"],
|
||||
"message": f"{merge_detail['main_entity_name']}合并{merge_detail['merged_count']}个:相似实体已合并"
|
||||
}
|
||||
await self.progress_callback("dedup_disambiguation_result", "实体去重完成", dedup_result)
|
||||
|
||||
# 分析实体消歧情况
|
||||
disamb_info = await self._analyze_entity_disambiguation(entity_nodes, final_entity_nodes)
|
||||
|
||||
# 输出实体消歧的结果
|
||||
for i, disamb_detail in enumerate(disamb_info[:5]): # 输出前5个消歧结果
|
||||
disamb_result = {
|
||||
"result_type": "entity_disambiguation",
|
||||
"disambiguated_entity_name": disamb_detail["entity_name"],
|
||||
"disambiguation_type": disamb_detail["disamb_type"],
|
||||
"confidence": disamb_detail.get("confidence", "unknown"),
|
||||
"reason": disamb_detail.get("reason", ""),
|
||||
"message": f"{disamb_detail['entity_name']}消歧完成:{disamb_detail['disamb_type']}"
|
||||
}
|
||||
await self.progress_callback("dedup_disambiguation_result", "实体消歧完成", disamb_result)
|
||||
|
||||
|
||||
|
||||
# 进度回调:去重消歧完成,传递去重和消歧的具体效果
|
||||
await self._send_dedup_progress_callback(
|
||||
len(entity_nodes), len(final_entity_nodes),
|
||||
len(statement_entity_edges), len(final_statement_entity_edges),
|
||||
len(entity_entity_edges), len(final_entity_entity_edges)
|
||||
)
|
||||
|
||||
|
||||
# 写入提取结果汇总(试运行和正式模式都需要生成)
|
||||
try:
|
||||
@@ -1041,6 +1213,378 @@ class ExtractionOrchestrator:
|
||||
logger.error(f"两阶段去重失败: {e}", exc_info=True)
|
||||
raise
|
||||
|
||||
def _save_dedup_details(
|
||||
self,
|
||||
dedup_details: Dict[str, Any],
|
||||
original_entities: List[ExtractedEntityNode],
|
||||
final_entities: List[ExtractedEntityNode]
|
||||
):
|
||||
"""
|
||||
保存去重消歧的详细记录到实例变量(基于内存数据结构)
|
||||
|
||||
Args:
|
||||
dedup_details: 去重函数返回的详细记录
|
||||
original_entities: 去重前的实体列表
|
||||
final_entities: 去重后的实体列表
|
||||
"""
|
||||
try:
|
||||
# 保存ID重定向映射
|
||||
self.id_redirect_map = dedup_details.get("id_redirect", {})
|
||||
|
||||
# 处理精确匹配的合并记录
|
||||
exact_merge_map = dedup_details.get("exact_merge_map", {})
|
||||
for key, info in exact_merge_map.items():
|
||||
merged_ids = info.get("merged_ids", set())
|
||||
if merged_ids:
|
||||
self.dedup_merge_records.append({
|
||||
"type": "精确匹配",
|
||||
"canonical_id": info.get("canonical_id"),
|
||||
"entity_name": info.get("name"),
|
||||
"entity_type": info.get("entity_type"),
|
||||
"merged_count": len(merged_ids),
|
||||
"merged_ids": list(merged_ids)
|
||||
})
|
||||
|
||||
# 处理模糊匹配的合并记录
|
||||
fuzzy_merge_records = dedup_details.get("fuzzy_merge_records", [])
|
||||
for record in fuzzy_merge_records:
|
||||
# 解析模糊匹配记录字符串
|
||||
# 格式: "[模糊] 规范实体 id (group|name|type) <- 合并实体 id (group|name|type) | s_name=0.xxx, ..."
|
||||
try:
|
||||
import re
|
||||
match = re.search(r"规范实体 (\S+) \(([^|]+)\|([^|]+)\|([^)]+)\) <- 合并实体 (\S+)", record)
|
||||
if match:
|
||||
self.dedup_merge_records.append({
|
||||
"type": "模糊匹配",
|
||||
"canonical_id": match.group(1),
|
||||
"entity_name": match.group(3),
|
||||
"entity_type": match.group(4),
|
||||
"merged_count": 1,
|
||||
"merged_ids": [match.group(5)]
|
||||
})
|
||||
except Exception as e:
|
||||
logger.debug(f"解析模糊匹配记录失败: {record}, 错误: {e}")
|
||||
|
||||
# 处理LLM去重的合并记录
|
||||
llm_decision_records = dedup_details.get("llm_decision_records", [])
|
||||
for record in llm_decision_records:
|
||||
if "[LLM去重]" in str(record):
|
||||
try:
|
||||
import re
|
||||
# 格式: "[LLM去重] 同名类型相似 name1(type1)|name2(type2) | conf=0.xx | reason=..."
|
||||
match = re.search(r"同名类型相似 ([^(]+)(([^)]+))\|([^(]+)(([^)]+))", record)
|
||||
if match:
|
||||
self.dedup_merge_records.append({
|
||||
"type": "LLM去重",
|
||||
"entity_name": match.group(1),
|
||||
"entity_type": f"{match.group(2)}|{match.group(4)}",
|
||||
"merged_count": 1,
|
||||
"merged_ids": []
|
||||
})
|
||||
except Exception as e:
|
||||
logger.debug(f"解析LLM去重记录失败: {record}, 错误: {e}")
|
||||
|
||||
# 处理消歧记录
|
||||
disamb_records = dedup_details.get("disamb_records", [])
|
||||
for record in disamb_records:
|
||||
if "[DISAMB阻断]" in str(record):
|
||||
try:
|
||||
import re
|
||||
# 格式: "[DISAMB阻断] name1(type1)|name2(type2) | conf=0.xx | reason=..."
|
||||
content = str(record).replace("[DISAMB阻断]", "").strip()
|
||||
match = re.search(r"([^(]+)(([^)]+))\|([^(]+)(([^)]+))", content)
|
||||
if match:
|
||||
entity1_name = match.group(1).strip()
|
||||
entity1_type = match.group(2)
|
||||
entity2_name = match.group(3).strip()
|
||||
entity2_type = match.group(4)
|
||||
|
||||
# 提取置信度和原因
|
||||
conf_match = re.search(r"conf=([0-9.]+)", str(record))
|
||||
confidence = conf_match.group(1) if conf_match else "unknown"
|
||||
|
||||
reason_match = re.search(r"reason=([^|]+)", str(record))
|
||||
reason = reason_match.group(1).strip() if reason_match else ""
|
||||
|
||||
self.dedup_disamb_records.append({
|
||||
"entity_name": entity1_name,
|
||||
"disamb_type": f"消歧阻断:{entity1_type} vs {entity2_type}",
|
||||
"confidence": confidence,
|
||||
"reason": reason[:100] + "..." if len(reason) > 100 else reason
|
||||
})
|
||||
except Exception as e:
|
||||
logger.debug(f"解析消歧记录失败: {record}, 错误: {e}")
|
||||
|
||||
logger.info(f"保存去重消歧记录:{len(self.dedup_merge_records)} 个合并记录,{len(self.dedup_disamb_records)} 个消歧记录")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"保存去重消歧详情失败: {e}", exc_info=True)
|
||||
|
||||
async def _analyze_entity_merges(
|
||||
self,
|
||||
original_entities: List[ExtractedEntityNode],
|
||||
final_entities: List[ExtractedEntityNode]
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
分析实体合并情况,直接使用内存中的合并记录(不再解析日志文件)
|
||||
|
||||
Args:
|
||||
original_entities: 去重前的实体列表
|
||||
final_entities: 去重后的实体列表
|
||||
|
||||
Returns:
|
||||
合并详情列表,每个元素包含主实体名称和合并数量
|
||||
"""
|
||||
try:
|
||||
# 直接使用保存的合并记录
|
||||
if self.dedup_merge_records:
|
||||
# 按合并数量排序,返回前几个
|
||||
sorted_records = sorted(
|
||||
self.dedup_merge_records,
|
||||
key=lambda x: x.get("merged_count", 0),
|
||||
reverse=True
|
||||
)
|
||||
|
||||
merge_info = []
|
||||
for record in sorted_records:
|
||||
merge_info.append({
|
||||
"main_entity_name": record.get("entity_name", "未知实体"),
|
||||
"merged_count": record.get("merged_count", 1)
|
||||
})
|
||||
|
||||
return merge_info
|
||||
|
||||
# 如果没有保存的记录,返回空列表
|
||||
logger.info("未找到实体合并记录")
|
||||
return []
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"分析实体合并情况失败: {e}", exc_info=True)
|
||||
return []
|
||||
|
||||
async def _analyze_entity_disambiguation(
|
||||
self,
|
||||
original_entities: List[ExtractedEntityNode],
|
||||
final_entities: List[ExtractedEntityNode]
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
分析实体消歧情况,直接使用内存中的消歧记录(不再解析日志文件)
|
||||
|
||||
Args:
|
||||
original_entities: 去重前的实体列表
|
||||
final_entities: 去重后的实体列表
|
||||
|
||||
Returns:
|
||||
消歧详情列表,每个元素包含实体名称和消歧类型
|
||||
"""
|
||||
try:
|
||||
# 直接使用保存的消歧记录
|
||||
if self.dedup_disamb_records:
|
||||
return self.dedup_disamb_records
|
||||
|
||||
# 如果没有保存的记录,返回空列表
|
||||
logger.info("未找到实体消歧记录")
|
||||
return []
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"分析实体消歧情况失败: {e}", exc_info=True)
|
||||
return []
|
||||
|
||||
def _get_entity_type_display_name(self, entity_type: str) -> str:
|
||||
"""
|
||||
获取实体类型的中文显示名称
|
||||
|
||||
Args:
|
||||
entity_type: 英文实体类型
|
||||
|
||||
Returns:
|
||||
中文显示名称
|
||||
"""
|
||||
type_mapping = {
|
||||
"Person": "人物实体节点",
|
||||
"Organization": "组织实体节点",
|
||||
"ORG": "组织实体节点",
|
||||
"Location": "地点实体节点",
|
||||
"LOC": "地点实体节点",
|
||||
"Event": "事件实体节点",
|
||||
"Concept": "概念实体节点",
|
||||
"Time": "时间实体节点",
|
||||
"Position": "职位实体节点",
|
||||
"WorkRole": "职业实体节点",
|
||||
"System": "系统实体节点",
|
||||
"Policy": "政策实体节点",
|
||||
"HistoricalPeriod": "历史时期实体节点",
|
||||
"HistoricalState": "历史国家实体节点",
|
||||
"HistoricalEvent": "历史事件实体节点",
|
||||
"EconomicFactor": "经济因素实体节点",
|
||||
"Condition": "条件实体节点",
|
||||
"Numeric": "数值实体节点"
|
||||
}
|
||||
return type_mapping.get(entity_type, f"{entity_type}实体节点")
|
||||
|
||||
async def _output_relationship_creation_results(
|
||||
self,
|
||||
entity_entity_edges: List[EntityEntityEdge],
|
||||
entity_nodes: List[ExtractedEntityNode]
|
||||
):
|
||||
"""
|
||||
输出关系创建结果
|
||||
|
||||
Args:
|
||||
entity_entity_edges: 实体-实体边列表
|
||||
entity_nodes: 实体节点列表
|
||||
"""
|
||||
try:
|
||||
# 创建实体ID到名称的映射
|
||||
entity_id_to_name = {node.id: node.name for node in entity_nodes}
|
||||
|
||||
# 输出关系创建结果
|
||||
for i, edge in enumerate(entity_entity_edges[:10]): # 只输出前10个关系
|
||||
source_name = entity_id_to_name.get(edge.source, f"Entity_{edge.source}")
|
||||
target_name = entity_id_to_name.get(edge.target, f"Entity_{edge.target}")
|
||||
relation_type = edge.relation_type
|
||||
|
||||
relationship_result = {
|
||||
"result_type": "relationship_creation",
|
||||
"relationship_index": i + 1,
|
||||
"source_entity": source_name,
|
||||
"relation_type": relation_type,
|
||||
"target_entity": target_name,
|
||||
"relationship_text": f"{source_name} -[{relation_type}]-> {target_name}"
|
||||
}
|
||||
|
||||
await self.progress_callback("creating_nodes_edges_result", "关系创建", relationship_result)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"输出关系创建结果失败: {e}", exc_info=True)
|
||||
|
||||
async def _send_dedup_progress_callback(
|
||||
self,
|
||||
original_entities: int,
|
||||
final_entities: int,
|
||||
original_stmt_edges: int,
|
||||
final_stmt_edges: int,
|
||||
original_ent_edges: int,
|
||||
final_ent_edges: int,
|
||||
):
|
||||
"""
|
||||
发送去重消歧完成的进度回调,传递具体的去重和消歧效果
|
||||
|
||||
Args:
|
||||
original_entities: 去重前实体数量
|
||||
final_entities: 去重后实体数量
|
||||
original_stmt_edges: 去重前陈述句-实体边数量
|
||||
final_stmt_edges: 去重后陈述句-实体边数量
|
||||
original_ent_edges: 去重前实体-实体边数量
|
||||
final_ent_edges: 去重后实体-实体边数量
|
||||
"""
|
||||
try:
|
||||
# 解析去重消歧报告文件,获取具体的去重和消歧效果
|
||||
dedup_details = await self._parse_dedup_report()
|
||||
|
||||
# 计算去重效果统计
|
||||
entities_reduced = original_entities - final_entities
|
||||
stmt_edges_reduced = original_stmt_edges - final_stmt_edges
|
||||
ent_edges_reduced = original_ent_edges - final_ent_edges
|
||||
|
||||
# 构建进度回调数据
|
||||
dedup_stats = {
|
||||
"entities": {
|
||||
"original_count": original_entities,
|
||||
"final_count": final_entities,
|
||||
"reduced_count": entities_reduced,
|
||||
"reduction_rate": round(entities_reduced / original_entities * 100, 1) if original_entities > 0 else 0,
|
||||
},
|
||||
"statement_entity_edges": {
|
||||
"original_count": original_stmt_edges,
|
||||
"final_count": final_stmt_edges,
|
||||
"reduced_count": stmt_edges_reduced,
|
||||
},
|
||||
"entity_entity_edges": {
|
||||
"original_count": original_ent_edges,
|
||||
"final_count": final_ent_edges,
|
||||
"reduced_count": ent_edges_reduced,
|
||||
},
|
||||
"dedup_examples": dedup_details.get("dedup_examples", []),
|
||||
"disamb_examples": dedup_details.get("disamb_examples", []),
|
||||
"summary": {
|
||||
"total_merges": dedup_details.get("total_merges", 0),
|
||||
"total_disambiguations": dedup_details.get("total_disambiguations", 0),
|
||||
}
|
||||
}
|
||||
|
||||
await self.progress_callback("dedup_disambiguation_complete", "去重消歧完成", dedup_stats)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"发送去重消歧进度回调失败: {e}", exc_info=True)
|
||||
# 即使解析失败,也发送基本的统计信息
|
||||
try:
|
||||
basic_stats = {
|
||||
"entities": {
|
||||
"original_count": original_entities,
|
||||
"final_count": final_entities,
|
||||
"reduced_count": original_entities - final_entities,
|
||||
},
|
||||
"summary": f"实体去重合并{original_entities - final_entities}个"
|
||||
}
|
||||
await self.progress_callback("dedup_disambiguation_complete", "去重消歧完成", basic_stats)
|
||||
except Exception as e2:
|
||||
logger.error(f"发送基本去重统计失败: {e2}", exc_info=True)
|
||||
|
||||
async def _parse_dedup_report(self) -> Dict[str, Any]:
|
||||
"""
|
||||
获取去重消歧报告,直接使用内存中的记录(不再解析日志文件)
|
||||
|
||||
Returns:
|
||||
包含去重和消歧详细信息的字典
|
||||
"""
|
||||
try:
|
||||
# 直接使用保存的记录构建报告
|
||||
dedup_examples = []
|
||||
disamb_examples = []
|
||||
total_merges = 0
|
||||
total_disambiguations = 0
|
||||
|
||||
# 处理合并记录
|
||||
for record in self.dedup_merge_records:
|
||||
merge_count = record.get("merged_count", 0)
|
||||
total_merges += merge_count
|
||||
|
||||
dedup_examples.append({
|
||||
"type": record.get("type", "未知"),
|
||||
"entity_name": record.get("entity_name", "未知实体"),
|
||||
"entity_type": record.get("entity_type", "未知类型"),
|
||||
"merge_count": merge_count,
|
||||
"description": f"{record.get('entity_name', '未知实体')}实体去重合并{merge_count}个"
|
||||
})
|
||||
|
||||
# 处理消歧记录
|
||||
for record in self.dedup_disamb_records:
|
||||
total_disambiguations += 1
|
||||
|
||||
# 从消歧类型中提取实体类型信息
|
||||
disamb_type = record.get("disamb_type", "")
|
||||
entity_name = record.get("entity_name", "未知实体")
|
||||
|
||||
disamb_examples.append({
|
||||
"entity1_name": entity_name,
|
||||
"entity1_type": disamb_type.split("vs")[0].replace("消歧阻断:", "").strip() if "vs" in disamb_type else "未知",
|
||||
"entity2_name": entity_name,
|
||||
"entity2_type": disamb_type.split("vs")[1].strip() if "vs" in disamb_type else "未知",
|
||||
"description": f"{entity_name},消歧区分成功"
|
||||
})
|
||||
|
||||
return {
|
||||
"dedup_examples": dedup_examples[:5], # 只返回前5个示例
|
||||
"disamb_examples": disamb_examples[:5], # 只返回前5个示例
|
||||
"total_merges": total_merges,
|
||||
"total_disambiguations": total_disambiguations,
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取去重报告失败: {e}", exc_info=True)
|
||||
return {"dedup_examples": [], "disamb_examples": [], "total_merges": 0, "total_disambiguations": 0}
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# 数据加载和预处理函数
|
||||
|
||||
@@ -13,7 +13,7 @@ from app.db import Base
|
||||
class ApiKeyType(StrEnum):
|
||||
"""API Key 类型"""
|
||||
AGENT = "agent" # 智能体
|
||||
CLUSTER = "cluster" # 集群
|
||||
CLUSTER = "multi_agent" # 集群
|
||||
WORKFLOW = "workflow" # 工作流
|
||||
SERVICE = "service" # 服务
|
||||
|
||||
|
||||
@@ -61,7 +61,7 @@ class ModelConfig(Base):
|
||||
|
||||
# 时间戳
|
||||
created_at = Column(DateTime, default=datetime.datetime.now, comment="创建时间")
|
||||
updated_at = Column(DateTime, default=datetime.datetime.now, onupdate=datetime.datetime.now, comment="更新时间")
|
||||
updated_at = Column(DateTime, default=datetime.datetime.now, onupdate=datetime.datetime.now, comment="更新时间")
|
||||
|
||||
# 关联关系
|
||||
api_keys = relationship("ModelApiKey", back_populates="model_config", cascade="all, delete-orphan")
|
||||
|
||||
@@ -126,6 +126,7 @@ class ApiKeyRepository:
|
||||
"quota_used": api_key.quota_used,
|
||||
"quota_limit": api_key.quota_limit,
|
||||
"last_used_at": api_key.last_used_at,
|
||||
"rate_limit": api_key.rate_limit,
|
||||
"avg_response_time": float(avg_response_time) if avg_response_time else None
|
||||
}
|
||||
|
||||
|
||||
@@ -46,6 +46,7 @@ class ConflictResultSchema(BaseModel):
|
||||
conflict_memory: Optional[BaseDataSchema] = Field(None, description="The conflict memory data.")
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def _normalize_data(cls, v):
|
||||
if isinstance(v, dict):
|
||||
d = v.get("data")
|
||||
@@ -60,6 +61,7 @@ class ConflictSchema(BaseModel):
|
||||
conflict_memory: Optional[BaseDataSchema] = Field(None, description="The conflict memory data.")
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def _normalize_data(cls, v):
|
||||
if isinstance(v, dict):
|
||||
d = v.get("data")
|
||||
@@ -88,6 +90,7 @@ class ReflexionResultSchema(BaseModel):
|
||||
resolved: Optional[ResolvedSchema] = Field(None, description="The resolved memory data.")
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def _normalize_resolved(cls, v):
|
||||
if isinstance(v, dict):
|
||||
conflict = v.get("conflict")
|
||||
@@ -311,7 +314,7 @@ class ApiResponse(BaseModel): # 通用API响应模型
|
||||
|
||||
|
||||
def _now_ms() -> int:
|
||||
return int(round(time.time() * 1000))
|
||||
return round(time.time() * 1000)
|
||||
|
||||
|
||||
def ok(msg: str = "OK", data: Optional[Any] = None, time: Optional[int] = None) -> ApiResponse:
|
||||
|
||||
@@ -43,12 +43,13 @@ class ApiKeyService:
|
||||
existing = db.scalar(
|
||||
select(ApiKey).where(
|
||||
ApiKey.workspace_id == workspace_id,
|
||||
ApiKey.resource_id == data.resource_id,
|
||||
ApiKey.name == data.name,
|
||||
ApiKey.is_active
|
||||
)
|
||||
)
|
||||
if existing:
|
||||
raise BusinessException(f"API Key 名称 '{data.name}' 已存在", BizCode.API_KEY_DUPLICATE_NAME)
|
||||
raise BusinessException(f"API Key 名称 {data.name} 已存在", BizCode.API_KEY_DUPLICATE_NAME)
|
||||
|
||||
# 生成 API Key
|
||||
api_key = generate_api_key(data.type)
|
||||
@@ -137,21 +138,19 @@ class ApiKeyService:
|
||||
"""更新 API Key配置"""
|
||||
api_key = ApiKeyService.get_api_key(db, api_key_id, workspace_id)
|
||||
|
||||
if not api_key:
|
||||
raise BusinessException(f"API Key {api_key_id} 不存在", BizCode.API_KEY_NOT_FOUND)
|
||||
|
||||
# 检查名称重复
|
||||
if data.name and data.name != api_key.name:
|
||||
existing = db.scalar(
|
||||
select(ApiKey).where(
|
||||
ApiKey.workspace_id == workspace_id,
|
||||
ApiKey.resource_id == data.resource_id,
|
||||
ApiKey.name == data.name,
|
||||
ApiKey.is_active,
|
||||
ApiKey.id != api_key_id
|
||||
)
|
||||
)
|
||||
if existing:
|
||||
raise BusinessException(f"API Key 名称 '{data.name}' 已存在", BizCode.API_KEY_DUPLICATE_NAME)
|
||||
raise BusinessException(f"API Key 名称 {data.name} 已存在", BizCode.API_KEY_DUPLICATE_NAME)
|
||||
|
||||
update_data = data.model_dump(exclude_unset=True)
|
||||
ApiKeyRepository.update(db, api_key_id, update_data)
|
||||
@@ -170,9 +169,6 @@ class ApiKeyService:
|
||||
"""删除 API Key"""
|
||||
api_key = ApiKeyService.get_api_key(db, api_key_id, workspace_id)
|
||||
|
||||
if not api_key:
|
||||
raise BusinessException(f"API Key {api_key_id} 不存在", BizCode.API_KEY_NOT_FOUND)
|
||||
|
||||
ApiKeyRepository.delete(db, api_key_id)
|
||||
db.commit()
|
||||
|
||||
@@ -188,9 +184,6 @@ class ApiKeyService:
|
||||
"""重新生成 API Key"""
|
||||
api_key = ApiKeyService.get_api_key(db, api_key_id, workspace_id)
|
||||
|
||||
if not api_key:
|
||||
raise BusinessException(f"API Key {api_key_id} 不存在", BizCode.API_KEY_NOT_FOUND)
|
||||
|
||||
# 检查 API Key 是否激活
|
||||
if not api_key.is_active:
|
||||
raise BusinessException("无法重新生成已停用的 API Key", BizCode.API_KEY_INACTIVE)
|
||||
@@ -217,9 +210,6 @@ class ApiKeyService:
|
||||
"""获取使用统计"""
|
||||
api_key = ApiKeyService.get_api_key(db, api_key_id, workspace_id)
|
||||
|
||||
if not api_key:
|
||||
raise BusinessException(f"API Key {api_key_id} 不存在", BizCode.API_KEY_NOT_FOUND)
|
||||
|
||||
stats_data = ApiKeyRepository.get_stats(db, api_key_id)
|
||||
return api_key_schema.ApiKeyStats(**stats_data)
|
||||
|
||||
@@ -236,9 +226,6 @@ class ApiKeyService:
|
||||
# 验证 API Key 权限
|
||||
api_key = ApiKeyService.get_api_key(db, api_key_id, workspace_id)
|
||||
|
||||
if not api_key:
|
||||
raise BusinessException(f"API Key {api_key_id} 不存在", BizCode.API_KEY_NOT_FOUND)
|
||||
|
||||
items, total = ApiKeyLogRepository.list_by_api_key(
|
||||
db, api_key_id, filters, page, pagesize
|
||||
)
|
||||
|
||||
@@ -4,9 +4,12 @@ Memory Storage Service
|
||||
Handles business logic for memory storage operations.
|
||||
"""
|
||||
|
||||
from typing import Dict, List, Optional, Any
|
||||
from typing import Dict, List, Optional, Any, AsyncGenerator
|
||||
import os
|
||||
import json
|
||||
import asyncio
|
||||
import time
|
||||
from datetime import datetime
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from dotenv import load_dotenv
|
||||
@@ -14,6 +17,7 @@ from dotenv import load_dotenv
|
||||
from app.models.user_model import User
|
||||
from app.models.end_user_model import EndUser
|
||||
from app.core.logging_config import get_logger
|
||||
from app.utils.sse_utils import format_sse_message
|
||||
from app.schemas.memory_storage_schema import (
|
||||
ConfigFilter,
|
||||
ConfigPilotRun,
|
||||
@@ -225,101 +229,175 @@ class DataConfigService: # 数据配置服务类(PostgreSQL)
|
||||
return self._convert_timestamps_to_format(data_list)
|
||||
|
||||
|
||||
async def pilot_run(self, payload: ConfigPilotRun) -> Dict[str, Any]:
|
||||
async def pilot_run_stream(self, payload: ConfigPilotRun) -> AsyncGenerator[str, None]:
|
||||
"""
|
||||
选择策略与内存覆写与同步版保持一致:优先 payload.config_id,其次 dbrun.json;两者皆无时报错。
|
||||
支持 dialogue_text 参数用于试运行模式。
|
||||
流式执行试运行,产生 SSE 格式的进度事件
|
||||
|
||||
Args:
|
||||
payload: 试运行配置和对话文本
|
||||
|
||||
Yields:
|
||||
SSE 格式的字符串,包含以下事件类型:
|
||||
- 各种阶段名称: 进度更新 (如 starting, knowledge_extraction_complete 等)
|
||||
- result: 最终结果
|
||||
- error: 错误信息
|
||||
- done: 完成标记
|
||||
|
||||
Raises:
|
||||
ValueError: 当配置无效或参数缺失时
|
||||
RuntimeError: 当管线执行失败时
|
||||
"""
|
||||
project_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
dbrun_path = os.path.join(project_root, "app", "core", "memory", "dbrun.json")
|
||||
|
||||
try:
|
||||
# 发出初始进度事件
|
||||
yield format_sse_message("starting", {
|
||||
"message": "开始试运行...",
|
||||
"time": int(time.time() * 1000)
|
||||
})
|
||||
|
||||
# 步骤 1: 配置加载和验证(复用现有逻辑)
|
||||
payload_cid = str(getattr(payload, "config_id", "") or "").strip()
|
||||
cid: Optional[str] = payload_cid if payload_cid else None
|
||||
|
||||
payload_cid = str(getattr(payload, "config_id", "") or "").strip()
|
||||
cid: Optional[str] = payload_cid if payload_cid else None
|
||||
if not cid and os.path.isfile(dbrun_path):
|
||||
try:
|
||||
with open(dbrun_path, "r", encoding="utf-8") as f:
|
||||
dbrun = json.load(f)
|
||||
if isinstance(dbrun, dict):
|
||||
sel = dbrun.get("selections", {})
|
||||
if isinstance(sel, dict):
|
||||
fallback_cid = str(sel.get("config_id") or "").strip()
|
||||
cid = fallback_cid or None
|
||||
except Exception:
|
||||
cid = None
|
||||
|
||||
if not cid and os.path.isfile(dbrun_path):
|
||||
try:
|
||||
with open(dbrun_path, "r", encoding="utf-8") as f:
|
||||
dbrun = json.load(f)
|
||||
if isinstance(dbrun, dict):
|
||||
sel = dbrun.get("selections", {})
|
||||
if isinstance(sel, dict):
|
||||
fallback_cid = str(sel.get("config_id") or "").strip()
|
||||
cid = fallback_cid or None
|
||||
except Exception:
|
||||
cid = None
|
||||
if not cid:
|
||||
raise ValueError("未提供 payload.config_id,且 dbrun.json 未设置 selections.config_id,禁止启动试运行")
|
||||
|
||||
if not cid:
|
||||
raise ValueError("未提供 payload.config_id,且 dbrun.json 未设置 selections.config_id,禁止启动试运行")
|
||||
# 验证 dialogue_text 必须提供
|
||||
dialogue_text = payload.dialogue_text.strip() if payload.dialogue_text else ""
|
||||
logger.info(f"[PILOT_RUN_STREAM] Received dialogue_text length: {len(dialogue_text)}, preview: {dialogue_text[:100]}")
|
||||
if not dialogue_text:
|
||||
raise ValueError("试运行模式必须提供 dialogue_text 参数")
|
||||
|
||||
# 验证 dialogue_text 必须提供
|
||||
dialogue_text = payload.dialogue_text.strip() if payload.dialogue_text else ""
|
||||
logger.info(f"[PILOT_RUN] Received dialogue_text length: {len(dialogue_text)}, preview: {dialogue_text[:100]}")
|
||||
if not dialogue_text:
|
||||
raise ValueError("试运行模式必须提供 dialogue_text 参数")
|
||||
# 应用内存覆写并刷新常量
|
||||
from app.core.memory.utils.config.definitions import reload_configuration_from_database
|
||||
|
||||
ok_override = reload_configuration_from_database(cid)
|
||||
if not ok_override:
|
||||
raise RuntimeError("运行时覆写失败,config_id 无效或刷新常量失败")
|
||||
|
||||
# 应用内存覆写并刷新常量(在导入主管线前)
|
||||
# 注意:仅在内存中覆写配置,不修改 runtime.json 文件
|
||||
from app.core.memory.utils.config.definitions import reload_configuration_from_database
|
||||
|
||||
ok_override = reload_configuration_from_database(cid)
|
||||
if not ok_override:
|
||||
raise RuntimeError("运行时覆写失败,config_id 无效或刷新常量失败")
|
||||
|
||||
# 导入并 await 主管线(使用当前 ASGI 事件循环)
|
||||
from app.core.memory.main import main as pipeline_main
|
||||
from app.core.memory.utils.self_reflexion_utils import reflexion
|
||||
|
||||
logger.info(f"[PILOT_RUN] Calling pipeline_main with dialogue_text length: {len(dialogue_text)}, is_pilot_run=True")
|
||||
await pipeline_main(dialogue_text=dialogue_text, is_pilot_run=True)
|
||||
logger.info("[PILOT_RUN] pipeline_main completed")
|
||||
|
||||
# 调用自我反思
|
||||
# data = [
|
||||
# {
|
||||
# "data": {
|
||||
# "id": "1",
|
||||
# "statement": "张明现在在谷歌工作。",
|
||||
# "group_id": "1",
|
||||
# "chunk_id": "10",
|
||||
# "created_at": "2023-01-01",
|
||||
# "expired_at": "2023-01-02",
|
||||
# "valid_at": "2023-01-01",
|
||||
# "invalid_at": "2023-01-02",
|
||||
# "entity_ids": []
|
||||
# },
|
||||
# "conflict": True,
|
||||
# "conflict_memory": {
|
||||
# "id": "1",
|
||||
# "statement": "张明现在在清华大学当讲师。",
|
||||
# "group_id": "1",
|
||||
# "chunk_id": "1",
|
||||
# "created_at": "2019-12-01T19:15:05.213210",
|
||||
# "expired_at": None,
|
||||
# "valid_at": None,
|
||||
# "invalid_at": None,
|
||||
# "entity_ids": []
|
||||
# }
|
||||
# }
|
||||
# ]
|
||||
from app.core.memory.utils.config.get_example_data import get_example_data
|
||||
data = get_example_data()
|
||||
reflexion_result = await reflexion(data)
|
||||
|
||||
# 读取输出,使用全局配置路径
|
||||
from app.core.config import settings
|
||||
result_path = settings.get_memory_output_path("extracted_result.json")
|
||||
if not os.path.isfile(result_path):
|
||||
raise FileNotFoundError(f"试运行完成,但未找到提取结果文件: {result_path}")
|
||||
|
||||
with open(result_path, "r", encoding="utf-8") as rf:
|
||||
extracted_result = json.load(rf)
|
||||
|
||||
extracted_result["self_reflexion"] = reflexion_result if reflexion_result else None
|
||||
return {
|
||||
"config_id": cid,
|
||||
"time_log": os.path.join(project_root, "time.log"),
|
||||
"extracted_result": extracted_result,
|
||||
}
|
||||
# 步骤 2: 创建进度回调函数捕获管线进度
|
||||
# 使用队列在回调和生成器之间传递进度事件
|
||||
progress_queue: asyncio.Queue = asyncio.Queue()
|
||||
|
||||
async def progress_callback(stage: str, message: str, data: Optional[Dict[str, Any]] = None) -> None:
|
||||
"""
|
||||
进度回调函数,将进度事件放入队列
|
||||
|
||||
Args:
|
||||
stage: 阶段标识
|
||||
message: 进度消息
|
||||
data: 可选的结果数据(用于传递节点执行结果)
|
||||
"""
|
||||
await progress_queue.put((stage, message, data))
|
||||
|
||||
# 步骤 3: 在后台任务中执行管线
|
||||
async def run_pipeline():
|
||||
"""在后台执行管线并捕获异常"""
|
||||
try:
|
||||
from app.core.memory.main import main as pipeline_main
|
||||
|
||||
logger.info(f"[PILOT_RUN_STREAM] Calling pipeline_main with dialogue_text length: {len(dialogue_text)}, is_pilot_run=True")
|
||||
await pipeline_main(
|
||||
dialogue_text=dialogue_text,
|
||||
is_pilot_run=True,
|
||||
progress_callback=progress_callback
|
||||
)
|
||||
logger.info("[PILOT_RUN_STREAM] pipeline_main completed")
|
||||
|
||||
# 标记管线完成
|
||||
await progress_queue.put(("__PIPELINE_COMPLETE__", "", None))
|
||||
except Exception as e:
|
||||
# 将异常放入队列
|
||||
await progress_queue.put(("__PIPELINE_ERROR__", str(e), None))
|
||||
|
||||
# 启动后台任务
|
||||
pipeline_task = asyncio.create_task(run_pipeline())
|
||||
|
||||
# 步骤 4: 从队列中读取进度事件并发出
|
||||
while True:
|
||||
try:
|
||||
# 等待进度事件,设置超时以检测客户端断开
|
||||
stage, message, data = await asyncio.wait_for(
|
||||
progress_queue.get(),
|
||||
timeout=0.5
|
||||
)
|
||||
|
||||
# 检查特殊标记
|
||||
if stage == "__PIPELINE_COMPLETE__":
|
||||
break
|
||||
elif stage == "__PIPELINE_ERROR__":
|
||||
raise RuntimeError(message)
|
||||
|
||||
# 构建进度事件数据
|
||||
progress_data = {
|
||||
"message": message,
|
||||
"time": int(time.time() * 1000)
|
||||
}
|
||||
|
||||
# 如果有结果数据,添加到事件中
|
||||
if data:
|
||||
progress_data["data"] = data
|
||||
|
||||
# 发出进度事件,使用 stage 作为事件类型
|
||||
yield format_sse_message(stage, progress_data)
|
||||
|
||||
except TimeoutError:
|
||||
# 超时,继续等待(这允许检测客户端断开)
|
||||
continue
|
||||
|
||||
# 等待管线任务完成
|
||||
await pipeline_task
|
||||
|
||||
# 步骤 5: 读取提取结果
|
||||
from app.core.config import settings
|
||||
result_path = settings.get_memory_output_path("extracted_result.json")
|
||||
if not os.path.isfile(result_path):
|
||||
raise FileNotFoundError(f"试运行完成,但未找到提取结果文件: {result_path}")
|
||||
|
||||
with open(result_path, "r", encoding="utf-8") as rf:
|
||||
extracted_result = json.load(rf)
|
||||
|
||||
# 步骤 6: 发出结果事件
|
||||
result_data = {
|
||||
"config_id": cid,
|
||||
"time_log": os.path.join(project_root, "logs", "time.log"),
|
||||
"extracted_result": extracted_result,
|
||||
}
|
||||
yield format_sse_message("result", result_data)
|
||||
|
||||
# 步骤 7: 发出完成事件
|
||||
yield format_sse_message("done", {
|
||||
"message": "试运行完成",
|
||||
"time": int(time.time() * 1000)
|
||||
})
|
||||
|
||||
except asyncio.CancelledError:
|
||||
# 客户端断开连接
|
||||
logger.info("[PILOT_RUN_STREAM] Client disconnected during streaming")
|
||||
raise
|
||||
except Exception as e:
|
||||
# 发出错误事件
|
||||
logger.error(f"[PILOT_RUN_STREAM] Error during streaming: {e}", exc_info=True)
|
||||
yield format_sse_message("error", {
|
||||
"code": 5000,
|
||||
"message": "试运行失败",
|
||||
"error": str(e),
|
||||
"time": int(time.time() * 1000)
|
||||
})
|
||||
|
||||
|
||||
# -------------------- Neo4j Search & Analytics (fused from data_search_service.py) --------------------
|
||||
|
||||
27
api/app/utils/sse_utils.py
Normal file
27
api/app/utils/sse_utils.py
Normal file
@@ -0,0 +1,27 @@
|
||||
"""
|
||||
Server-Sent Events (SSE) Utility Functions
|
||||
|
||||
Provides shared utilities for formatting and handling SSE messages.
|
||||
"""
|
||||
|
||||
import json
|
||||
from typing import Dict, Any
|
||||
|
||||
|
||||
def format_sse_message(event_type: str, data: Dict[str, Any]) -> str:
|
||||
"""
|
||||
Format a message in Server-Sent Events (SSE) format.
|
||||
|
||||
Args:
|
||||
event_type: Type of event (stage name, result, error, done)
|
||||
data: Event data dictionary to be serialized as JSON
|
||||
|
||||
Returns:
|
||||
SSE formatted string: "event: <type>\\ndata: <json>\\n\\n"
|
||||
|
||||
Example:
|
||||
>>> format_sse_message("loading", {"message": "Loading..."})
|
||||
'event: loading\\ndata: {"message": "Loading..."}\\n\\n'
|
||||
"""
|
||||
json_data = json.dumps(data, ensure_ascii=False)
|
||||
return f"event: {event_type}\ndata: {json_data}\n\n"
|
||||
Reference in New Issue
Block a user