[MODIFY] MEM SEE OUTPUT
This commit is contained in:
@@ -4,9 +4,12 @@ Memory Storage Service
|
||||
Handles business logic for memory storage operations.
|
||||
"""
|
||||
|
||||
from typing import Dict, List, Optional, Any
|
||||
from typing import Dict, List, Optional, Any, AsyncGenerator
|
||||
import os
|
||||
import json
|
||||
import asyncio
|
||||
import time
|
||||
from datetime import datetime
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from dotenv import load_dotenv
|
||||
@@ -14,6 +17,7 @@ from dotenv import load_dotenv
|
||||
from app.models.user_model import User
|
||||
from app.models.end_user_model import EndUser
|
||||
from app.core.logging_config import get_logger
|
||||
from app.utils.sse_utils import format_sse_message
|
||||
from app.schemas.memory_storage_schema import (
|
||||
ConfigFilter,
|
||||
ConfigPilotRun,
|
||||
@@ -225,101 +229,175 @@ class DataConfigService: # 数据配置服务类(PostgreSQL)
|
||||
return self._convert_timestamps_to_format(data_list)
|
||||
|
||||
|
||||
async def pilot_run(self, payload: ConfigPilotRun) -> Dict[str, Any]:
|
||||
async def pilot_run_stream(self, payload: ConfigPilotRun) -> AsyncGenerator[str, None]:
|
||||
"""
|
||||
选择策略与内存覆写与同步版保持一致:优先 payload.config_id,其次 dbrun.json;两者皆无时报错。
|
||||
支持 dialogue_text 参数用于试运行模式。
|
||||
流式执行试运行,产生 SSE 格式的进度事件
|
||||
|
||||
Args:
|
||||
payload: 试运行配置和对话文本
|
||||
|
||||
Yields:
|
||||
SSE 格式的字符串,包含以下事件类型:
|
||||
- 各种阶段名称: 进度更新 (如 starting, knowledge_extraction_complete 等)
|
||||
- result: 最终结果
|
||||
- error: 错误信息
|
||||
- done: 完成标记
|
||||
|
||||
Raises:
|
||||
ValueError: 当配置无效或参数缺失时
|
||||
RuntimeError: 当管线执行失败时
|
||||
"""
|
||||
project_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
dbrun_path = os.path.join(project_root, "app", "core", "memory", "dbrun.json")
|
||||
|
||||
try:
|
||||
# 发出初始进度事件
|
||||
yield format_sse_message("starting", {
|
||||
"message": "开始试运行...",
|
||||
"time": int(time.time() * 1000)
|
||||
})
|
||||
|
||||
# 步骤 1: 配置加载和验证(复用现有逻辑)
|
||||
payload_cid = str(getattr(payload, "config_id", "") or "").strip()
|
||||
cid: Optional[str] = payload_cid if payload_cid else None
|
||||
|
||||
payload_cid = str(getattr(payload, "config_id", "") or "").strip()
|
||||
cid: Optional[str] = payload_cid if payload_cid else None
|
||||
if not cid and os.path.isfile(dbrun_path):
|
||||
try:
|
||||
with open(dbrun_path, "r", encoding="utf-8") as f:
|
||||
dbrun = json.load(f)
|
||||
if isinstance(dbrun, dict):
|
||||
sel = dbrun.get("selections", {})
|
||||
if isinstance(sel, dict):
|
||||
fallback_cid = str(sel.get("config_id") or "").strip()
|
||||
cid = fallback_cid or None
|
||||
except Exception:
|
||||
cid = None
|
||||
|
||||
if not cid and os.path.isfile(dbrun_path):
|
||||
try:
|
||||
with open(dbrun_path, "r", encoding="utf-8") as f:
|
||||
dbrun = json.load(f)
|
||||
if isinstance(dbrun, dict):
|
||||
sel = dbrun.get("selections", {})
|
||||
if isinstance(sel, dict):
|
||||
fallback_cid = str(sel.get("config_id") or "").strip()
|
||||
cid = fallback_cid or None
|
||||
except Exception:
|
||||
cid = None
|
||||
if not cid:
|
||||
raise ValueError("未提供 payload.config_id,且 dbrun.json 未设置 selections.config_id,禁止启动试运行")
|
||||
|
||||
if not cid:
|
||||
raise ValueError("未提供 payload.config_id,且 dbrun.json 未设置 selections.config_id,禁止启动试运行")
|
||||
# 验证 dialogue_text 必须提供
|
||||
dialogue_text = payload.dialogue_text.strip() if payload.dialogue_text else ""
|
||||
logger.info(f"[PILOT_RUN_STREAM] Received dialogue_text length: {len(dialogue_text)}, preview: {dialogue_text[:100]}")
|
||||
if not dialogue_text:
|
||||
raise ValueError("试运行模式必须提供 dialogue_text 参数")
|
||||
|
||||
# 验证 dialogue_text 必须提供
|
||||
dialogue_text = payload.dialogue_text.strip() if payload.dialogue_text else ""
|
||||
logger.info(f"[PILOT_RUN] Received dialogue_text length: {len(dialogue_text)}, preview: {dialogue_text[:100]}")
|
||||
if not dialogue_text:
|
||||
raise ValueError("试运行模式必须提供 dialogue_text 参数")
|
||||
# 应用内存覆写并刷新常量
|
||||
from app.core.memory.utils.config.definitions import reload_configuration_from_database
|
||||
|
||||
ok_override = reload_configuration_from_database(cid)
|
||||
if not ok_override:
|
||||
raise RuntimeError("运行时覆写失败,config_id 无效或刷新常量失败")
|
||||
|
||||
# 应用内存覆写并刷新常量(在导入主管线前)
|
||||
# 注意:仅在内存中覆写配置,不修改 runtime.json 文件
|
||||
from app.core.memory.utils.config.definitions import reload_configuration_from_database
|
||||
|
||||
ok_override = reload_configuration_from_database(cid)
|
||||
if not ok_override:
|
||||
raise RuntimeError("运行时覆写失败,config_id 无效或刷新常量失败")
|
||||
|
||||
# 导入并 await 主管线(使用当前 ASGI 事件循环)
|
||||
from app.core.memory.main import main as pipeline_main
|
||||
from app.core.memory.utils.self_reflexion_utils import reflexion
|
||||
|
||||
logger.info(f"[PILOT_RUN] Calling pipeline_main with dialogue_text length: {len(dialogue_text)}, is_pilot_run=True")
|
||||
await pipeline_main(dialogue_text=dialogue_text, is_pilot_run=True)
|
||||
logger.info("[PILOT_RUN] pipeline_main completed")
|
||||
|
||||
# 调用自我反思
|
||||
# data = [
|
||||
# {
|
||||
# "data": {
|
||||
# "id": "1",
|
||||
# "statement": "张明现在在谷歌工作。",
|
||||
# "group_id": "1",
|
||||
# "chunk_id": "10",
|
||||
# "created_at": "2023-01-01",
|
||||
# "expired_at": "2023-01-02",
|
||||
# "valid_at": "2023-01-01",
|
||||
# "invalid_at": "2023-01-02",
|
||||
# "entity_ids": []
|
||||
# },
|
||||
# "conflict": True,
|
||||
# "conflict_memory": {
|
||||
# "id": "1",
|
||||
# "statement": "张明现在在清华大学当讲师。",
|
||||
# "group_id": "1",
|
||||
# "chunk_id": "1",
|
||||
# "created_at": "2019-12-01T19:15:05.213210",
|
||||
# "expired_at": None,
|
||||
# "valid_at": None,
|
||||
# "invalid_at": None,
|
||||
# "entity_ids": []
|
||||
# }
|
||||
# }
|
||||
# ]
|
||||
from app.core.memory.utils.config.get_example_data import get_example_data
|
||||
data = get_example_data()
|
||||
reflexion_result = await reflexion(data)
|
||||
|
||||
# 读取输出,使用全局配置路径
|
||||
from app.core.config import settings
|
||||
result_path = settings.get_memory_output_path("extracted_result.json")
|
||||
if not os.path.isfile(result_path):
|
||||
raise FileNotFoundError(f"试运行完成,但未找到提取结果文件: {result_path}")
|
||||
|
||||
with open(result_path, "r", encoding="utf-8") as rf:
|
||||
extracted_result = json.load(rf)
|
||||
|
||||
extracted_result["self_reflexion"] = reflexion_result if reflexion_result else None
|
||||
return {
|
||||
"config_id": cid,
|
||||
"time_log": os.path.join(project_root, "time.log"),
|
||||
"extracted_result": extracted_result,
|
||||
}
|
||||
# 步骤 2: 创建进度回调函数捕获管线进度
|
||||
# 使用队列在回调和生成器之间传递进度事件
|
||||
progress_queue: asyncio.Queue = asyncio.Queue()
|
||||
|
||||
async def progress_callback(stage: str, message: str, data: Optional[Dict[str, Any]] = None) -> None:
|
||||
"""
|
||||
进度回调函数,将进度事件放入队列
|
||||
|
||||
Args:
|
||||
stage: 阶段标识
|
||||
message: 进度消息
|
||||
data: 可选的结果数据(用于传递节点执行结果)
|
||||
"""
|
||||
await progress_queue.put((stage, message, data))
|
||||
|
||||
# 步骤 3: 在后台任务中执行管线
|
||||
async def run_pipeline():
|
||||
"""在后台执行管线并捕获异常"""
|
||||
try:
|
||||
from app.core.memory.main import main as pipeline_main
|
||||
|
||||
logger.info(f"[PILOT_RUN_STREAM] Calling pipeline_main with dialogue_text length: {len(dialogue_text)}, is_pilot_run=True")
|
||||
await pipeline_main(
|
||||
dialogue_text=dialogue_text,
|
||||
is_pilot_run=True,
|
||||
progress_callback=progress_callback
|
||||
)
|
||||
logger.info("[PILOT_RUN_STREAM] pipeline_main completed")
|
||||
|
||||
# 标记管线完成
|
||||
await progress_queue.put(("__PIPELINE_COMPLETE__", "", None))
|
||||
except Exception as e:
|
||||
# 将异常放入队列
|
||||
await progress_queue.put(("__PIPELINE_ERROR__", str(e), None))
|
||||
|
||||
# 启动后台任务
|
||||
pipeline_task = asyncio.create_task(run_pipeline())
|
||||
|
||||
# 步骤 4: 从队列中读取进度事件并发出
|
||||
while True:
|
||||
try:
|
||||
# 等待进度事件,设置超时以检测客户端断开
|
||||
stage, message, data = await asyncio.wait_for(
|
||||
progress_queue.get(),
|
||||
timeout=0.5
|
||||
)
|
||||
|
||||
# 检查特殊标记
|
||||
if stage == "__PIPELINE_COMPLETE__":
|
||||
break
|
||||
elif stage == "__PIPELINE_ERROR__":
|
||||
raise RuntimeError(message)
|
||||
|
||||
# 构建进度事件数据
|
||||
progress_data = {
|
||||
"message": message,
|
||||
"time": int(time.time() * 1000)
|
||||
}
|
||||
|
||||
# 如果有结果数据,添加到事件中
|
||||
if data:
|
||||
progress_data["data"] = data
|
||||
|
||||
# 发出进度事件,使用 stage 作为事件类型
|
||||
yield format_sse_message(stage, progress_data)
|
||||
|
||||
except TimeoutError:
|
||||
# 超时,继续等待(这允许检测客户端断开)
|
||||
continue
|
||||
|
||||
# 等待管线任务完成
|
||||
await pipeline_task
|
||||
|
||||
# 步骤 5: 读取提取结果
|
||||
from app.core.config import settings
|
||||
result_path = settings.get_memory_output_path("extracted_result.json")
|
||||
if not os.path.isfile(result_path):
|
||||
raise FileNotFoundError(f"试运行完成,但未找到提取结果文件: {result_path}")
|
||||
|
||||
with open(result_path, "r", encoding="utf-8") as rf:
|
||||
extracted_result = json.load(rf)
|
||||
|
||||
# 步骤 6: 发出结果事件
|
||||
result_data = {
|
||||
"config_id": cid,
|
||||
"time_log": os.path.join(project_root, "logs", "time.log"),
|
||||
"extracted_result": extracted_result,
|
||||
}
|
||||
yield format_sse_message("result", result_data)
|
||||
|
||||
# 步骤 7: 发出完成事件
|
||||
yield format_sse_message("done", {
|
||||
"message": "试运行完成",
|
||||
"time": int(time.time() * 1000)
|
||||
})
|
||||
|
||||
except asyncio.CancelledError:
|
||||
# 客户端断开连接
|
||||
logger.info("[PILOT_RUN_STREAM] Client disconnected during streaming")
|
||||
raise
|
||||
except Exception as e:
|
||||
# 发出错误事件
|
||||
logger.error(f"[PILOT_RUN_STREAM] Error during streaming: {e}", exc_info=True)
|
||||
yield format_sse_message("error", {
|
||||
"code": 5000,
|
||||
"message": "试运行失败",
|
||||
"error": str(e),
|
||||
"time": int(time.time() * 1000)
|
||||
})
|
||||
|
||||
|
||||
# -------------------- Neo4j Search & Analytics (fused from data_search_service.py) --------------------
|
||||
|
||||
Reference in New Issue
Block a user