fix(workflow): fix activation and branch control issues in streaming output
This commit is contained in:
@@ -11,16 +11,12 @@ from typing import Any
|
|||||||
from langchain_core.runnables import RunnableConfig
|
from langchain_core.runnables import RunnableConfig
|
||||||
from langgraph.graph.state import CompiledStateGraph
|
from langgraph.graph.state import CompiledStateGraph
|
||||||
|
|
||||||
from app.core.workflow.graph_builder import GraphBuilder
|
from app.core.workflow.expression_evaluator import evaluate_expression
|
||||||
|
from app.core.workflow.graph_builder import GraphBuilder, StreamOutputConfig
|
||||||
from app.core.workflow.nodes import WorkflowState
|
from app.core.workflow.nodes import WorkflowState
|
||||||
from app.core.workflow.nodes.base_config import VariableType
|
from app.core.workflow.nodes.base_config import VariableType
|
||||||
from app.core.workflow.nodes.enums import NodeType
|
from app.core.workflow.nodes.enums import NodeType
|
||||||
|
from app.core.workflow.template_renderer import render_template
|
||||||
# from app.core.tools.registry import ToolRegistry
|
|
||||||
# from app.core.tools.executor import ToolExecutor
|
|
||||||
# from app.core.tools.langchain_adapter import LangchainAdapter
|
|
||||||
# TOOL_MANAGEMENT_AVAILABLE = True
|
|
||||||
# from app.db import get_db
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -55,6 +51,8 @@ class WorkflowExecutor:
|
|||||||
self.execution_config = workflow_config.get("execution_config", {})
|
self.execution_config = workflow_config.get("execution_config", {})
|
||||||
|
|
||||||
self.start_node_id = None
|
self.start_node_id = None
|
||||||
|
self.end_outputs: dict[str, StreamOutputConfig] = {}
|
||||||
|
self.activate_end: str | None = None
|
||||||
|
|
||||||
self.checkpoint_config = RunnableConfig(
|
self.checkpoint_config = RunnableConfig(
|
||||||
configurable={
|
configurable={
|
||||||
@@ -127,7 +125,6 @@ class WorkflowExecutor:
|
|||||||
"user_id": self.user_id,
|
"user_id": self.user_id,
|
||||||
"error": None,
|
"error": None,
|
||||||
"error_node": None,
|
"error_node": None,
|
||||||
"streaming_buffer": {}, # 流式缓冲区
|
|
||||||
"cycle_nodes": [
|
"cycle_nodes": [
|
||||||
node.get("id")
|
node.get("id")
|
||||||
for node in self.workflow_config.get("nodes")
|
for node in self.workflow_config.get("nodes")
|
||||||
@@ -139,9 +136,8 @@ class WorkflowExecutor:
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
def _build_final_output(self, result, elapsed_time):
|
def _build_final_output(self, result, elapsed_time, final_output):
|
||||||
node_outputs = result.get("node_outputs", {})
|
node_outputs = result.get("node_outputs", {})
|
||||||
final_output = self._extract_final_output(node_outputs)
|
|
||||||
token_usage = self._aggregate_token_usage(node_outputs)
|
token_usage = self._aggregate_token_usage(node_outputs)
|
||||||
conversation_id = None
|
conversation_id = None
|
||||||
for node_id, node_output in node_outputs.items():
|
for node_id, node_output in node_outputs.items():
|
||||||
@@ -161,6 +157,12 @@ class WorkflowExecutor:
|
|||||||
"error": result.get("error"),
|
"error": result.get("error"),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def _update_end_activate(self, node_id):
|
||||||
|
for node in self.end_outputs.keys():
|
||||||
|
self.end_outputs[node].update_activate(node_id)
|
||||||
|
if self.end_outputs[node].activate and self.activate_end is None:
|
||||||
|
self.activate_end = node
|
||||||
|
|
||||||
def build_graph(self, stream=False) -> CompiledStateGraph:
|
def build_graph(self, stream=False) -> CompiledStateGraph:
|
||||||
"""构建 LangGraph
|
"""构建 LangGraph
|
||||||
|
|
||||||
@@ -173,6 +175,7 @@ class WorkflowExecutor:
|
|||||||
stream=stream,
|
stream=stream,
|
||||||
)
|
)
|
||||||
self.start_node_id = builder.start_node_id
|
self.start_node_id = builder.start_node_id
|
||||||
|
self.end_outputs = builder.end_node_map
|
||||||
graph = builder.build()
|
graph = builder.build()
|
||||||
logger.info(f"工作流图构建完成: execution_id={self.execution_id}")
|
logger.info(f"工作流图构建完成: execution_id={self.execution_id}")
|
||||||
|
|
||||||
@@ -205,14 +208,34 @@ class WorkflowExecutor:
|
|||||||
try:
|
try:
|
||||||
|
|
||||||
result = await graph.ainvoke(initial_state, config=self.checkpoint_config)
|
result = await graph.ainvoke(initial_state, config=self.checkpoint_config)
|
||||||
|
full_content = ''
|
||||||
|
for end_info in self.end_outputs.values():
|
||||||
|
output_template = "".join([output.literal for output in end_info.outputs])
|
||||||
|
full_content += render_template(
|
||||||
|
output_template,
|
||||||
|
result.get("variables", {}),
|
||||||
|
result.get("runtime_vars", {}),
|
||||||
|
strict=False
|
||||||
|
)
|
||||||
|
result["messages"].extend(
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": input_data.get("message", '')
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"role": "assistant",
|
||||||
|
"content": full_content
|
||||||
|
}
|
||||||
|
]
|
||||||
|
)
|
||||||
# 计算耗时
|
# 计算耗时
|
||||||
end_time = datetime.datetime.now()
|
end_time = datetime.datetime.now()
|
||||||
elapsed_time = (end_time - start_time).total_seconds()
|
elapsed_time = (end_time - start_time).total_seconds()
|
||||||
|
|
||||||
logger.info(f"工作流执行完成: execution_id={self.execution_id}, elapsed_time={elapsed_time:.2f}s")
|
logger.info(f"工作流执行完成: execution_id={self.execution_id}, elapsed_time={elapsed_time:.2f}s")
|
||||||
|
|
||||||
return self._build_final_output(result, elapsed_time)
|
return self._build_final_output(result, elapsed_time, full_content)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# 计算耗时(即使失败也记录)
|
# 计算耗时(即使失败也记录)
|
||||||
@@ -273,6 +296,7 @@ class WorkflowExecutor:
|
|||||||
# 3. Execute workflow
|
# 3. Execute workflow
|
||||||
try:
|
try:
|
||||||
chunk_count = 0
|
chunk_count = 0
|
||||||
|
full_content = ''
|
||||||
|
|
||||||
async for event in graph.astream(
|
async for event in graph.astream(
|
||||||
initial_state,
|
initial_state,
|
||||||
@@ -293,21 +317,25 @@ class WorkflowExecutor:
|
|||||||
# Handle custom streaming events (chunks from nodes via stream writer)
|
# Handle custom streaming events (chunks from nodes via stream writer)
|
||||||
chunk_count += 1
|
chunk_count += 1
|
||||||
event_type = data.get("type", "node_chunk") # "message" or "node_chunk"
|
event_type = data.get("type", "node_chunk") # "message" or "node_chunk"
|
||||||
if event_type in ("message", "node_chunk"):
|
if event_type == "node_chunk":
|
||||||
|
node_id = data.get("node_id")
|
||||||
|
if self.activate_end:
|
||||||
|
end_info = self.end_outputs[self.activate_end]
|
||||||
|
current_output = end_info.outputs[end_info.cursor]
|
||||||
|
if current_output.is_variable and node_id in current_output.literal:
|
||||||
|
if data.get("done"):
|
||||||
|
end_info.cursor += 1
|
||||||
|
else:
|
||||||
|
full_content += data.get("chunk")
|
||||||
|
yield {
|
||||||
|
"event": "message",
|
||||||
|
"data": {
|
||||||
|
"chunk": data.get("chunk")
|
||||||
|
}
|
||||||
|
}
|
||||||
logger.info(f"[CUSTOM] ✅ 收到 {event_type} #{chunk_count} from {data.get('node_id')}"
|
logger.info(f"[CUSTOM] ✅ 收到 {event_type} #{chunk_count} from {data.get('node_id')}"
|
||||||
f"- execution_id: {self.execution_id}")
|
f"- execution_id: {self.execution_id}")
|
||||||
yield {
|
|
||||||
"event": event_type, # "message" or "node_chunk"
|
|
||||||
"data": {
|
|
||||||
"node_id": data.get("node_id"),
|
|
||||||
"chunk": data.get("chunk"),
|
|
||||||
"full_content": data.get("full_content"),
|
|
||||||
"chunk_index": data.get("chunk_index"),
|
|
||||||
"is_prefix": data.get("is_prefix"),
|
|
||||||
"is_suffix": data.get("is_suffix"),
|
|
||||||
"conversation_id": input_data.get("conversation_id"),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
elif event_type == "node_error":
|
elif event_type == "node_error":
|
||||||
yield {
|
yield {
|
||||||
"event": event_type, # "message" or "node_chunk"
|
"event": event_type, # "message" or "node_chunk"
|
||||||
@@ -376,14 +404,107 @@ class WorkflowExecutor:
|
|||||||
|
|
||||||
elif mode == "updates":
|
elif mode == "updates":
|
||||||
# Handle state updates - store final state
|
# Handle state updates - store final state
|
||||||
# TODO:流式输出点
|
for node_id in data.keys():
|
||||||
|
self._update_end_activate(node_id)
|
||||||
|
wait = False
|
||||||
|
state = graph.get_state(config=self.checkpoint_config)
|
||||||
|
node_outputs = state.values.get("runtime_vars", {})
|
||||||
|
for _ in data.keys():
|
||||||
|
node_outputs = node_outputs | data.get(_).get("runtime_vars", {})
|
||||||
|
|
||||||
|
while self.activate_end and not wait:
|
||||||
|
message = ''
|
||||||
|
logger.info(self.activate_end)
|
||||||
|
end_info = self.end_outputs[self.activate_end]
|
||||||
|
content = end_info.outputs[end_info.cursor]
|
||||||
|
while content.activate:
|
||||||
|
if not content.is_variable:
|
||||||
|
full_content += content.literal
|
||||||
|
message += content.literal
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
chunk = evaluate_expression(
|
||||||
|
content.literal,
|
||||||
|
variables={},
|
||||||
|
node_outputs=node_outputs
|
||||||
|
)
|
||||||
|
message += chunk
|
||||||
|
full_content += chunk
|
||||||
|
except ValueError:
|
||||||
|
pass
|
||||||
|
end_info.cursor += 1
|
||||||
|
if end_info.cursor == len(end_info.outputs):
|
||||||
|
break
|
||||||
|
content = end_info.outputs[end_info.cursor]
|
||||||
|
if end_info.cursor != len(end_info.outputs):
|
||||||
|
wait = True
|
||||||
|
else:
|
||||||
|
self.end_outputs.pop(self.activate_end)
|
||||||
|
self.activate_end = None
|
||||||
|
for node_id in data.keys():
|
||||||
|
self._update_end_activate(node_id)
|
||||||
|
if message:
|
||||||
|
yield {
|
||||||
|
"event": "message",
|
||||||
|
"data": {
|
||||||
|
"chunk": message
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
logger.debug(f"[UPDATES] 收到 state 更新 from {list(data.keys())} "
|
logger.debug(f"[UPDATES] 收到 state 更新 from {list(data.keys())} "
|
||||||
f"- execution_id: {self.execution_id}")
|
f"- execution_id: {self.execution_id}")
|
||||||
|
|
||||||
|
result = graph.get_state(self.checkpoint_config).values
|
||||||
|
while self.activate_end:
|
||||||
|
message = ''
|
||||||
|
end_info = self.end_outputs[self.activate_end]
|
||||||
|
content = end_info.outputs[end_info.cursor]
|
||||||
|
if not content.is_variable:
|
||||||
|
message += content.literal
|
||||||
|
else:
|
||||||
|
node_outputs = result.get("runtime_vars", {})
|
||||||
|
variables = result.get("variables", {})
|
||||||
|
try:
|
||||||
|
chunk = evaluate_expression(
|
||||||
|
content.literal,
|
||||||
|
variables=variables,
|
||||||
|
node_outputs=node_outputs
|
||||||
|
)
|
||||||
|
message += chunk
|
||||||
|
full_content += chunk
|
||||||
|
except ValueError:
|
||||||
|
pass
|
||||||
|
end_info.cursor += 1
|
||||||
|
if end_info.cursor == len(end_info.outputs):
|
||||||
|
self.end_outputs.pop(self.activate_end)
|
||||||
|
self.activate_end = None
|
||||||
|
if self.end_outputs:
|
||||||
|
self.activate_end = list(self.end_outputs.keys())[0]
|
||||||
|
if message:
|
||||||
|
yield {
|
||||||
|
"event": "message",
|
||||||
|
"data": {
|
||||||
|
"chunk": message
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
# 计算耗时
|
# 计算耗时
|
||||||
end_time = datetime.datetime.now()
|
end_time = datetime.datetime.now()
|
||||||
elapsed_time = (end_time - start_time).total_seconds()
|
elapsed_time = (end_time - start_time).total_seconds()
|
||||||
result = graph.get_state(self.checkpoint_config).values
|
result = graph.get_state(self.checkpoint_config).values
|
||||||
|
logger.info(result)
|
||||||
|
result["messages"].extend(
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": input_data.get("message", '')
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"role": "assistant",
|
||||||
|
"content": full_content
|
||||||
|
}
|
||||||
|
]
|
||||||
|
)
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Workflow execution completed (streaming), "
|
f"Workflow execution completed (streaming), "
|
||||||
f"total chunks: {chunk_count}, elapsed: {elapsed_time:.2f}s, execution_id: {self.execution_id}"
|
f"total chunks: {chunk_count}, elapsed: {elapsed_time:.2f}s, execution_id: {self.execution_id}"
|
||||||
@@ -392,7 +513,7 @@ class WorkflowExecutor:
|
|||||||
# 发送 workflow_end 事件
|
# 发送 workflow_end 事件
|
||||||
yield {
|
yield {
|
||||||
"event": "workflow_end",
|
"event": "workflow_end",
|
||||||
"data": self._build_final_output(result, elapsed_time)
|
"data": self._build_final_output(result, elapsed_time, full_content)
|
||||||
}
|
}
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -414,31 +535,6 @@ class WorkflowExecutor:
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _extract_final_output(node_outputs: dict[str, Any]) -> str | None:
|
|
||||||
"""从节点输出中提取最终输出
|
|
||||||
|
|
||||||
优先级:
|
|
||||||
1. 最后一个执行的非 start/end 节点的 output
|
|
||||||
2. 如果没有节点输出,返回 None
|
|
||||||
|
|
||||||
Args:
|
|
||||||
node_outputs: 所有节点的输出
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
最终输出字符串或 None
|
|
||||||
"""
|
|
||||||
if not node_outputs:
|
|
||||||
return None
|
|
||||||
|
|
||||||
# 获取最后一个节点的输出
|
|
||||||
last_node_output = list(node_outputs.values())[-1] if node_outputs else None
|
|
||||||
|
|
||||||
if last_node_output and isinstance(last_node_output, dict):
|
|
||||||
return last_node_output.get("output")
|
|
||||||
|
|
||||||
return None
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _aggregate_token_usage(node_outputs: dict[str, Any]) -> dict[str, int] | None:
|
def _aggregate_token_usage(node_outputs: dict[str, Any]) -> dict[str, int] | None:
|
||||||
"""聚合所有节点的 token 使用情况
|
"""聚合所有节点的 token 使用情况
|
||||||
@@ -529,178 +625,3 @@ async def execute_workflow_stream(
|
|||||||
)
|
)
|
||||||
async for event in executor.execute_stream(input_data):
|
async for event in executor.execute_stream(input_data):
|
||||||
yield event
|
yield event
|
||||||
|
|
||||||
# ==================== 工具管理系统集成 ====================
|
|
||||||
|
|
||||||
# def get_workflow_tools(workspace_id: str, user_id: str) -> list:
|
|
||||||
# """获取工作流可用的工具列表
|
|
||||||
#
|
|
||||||
# Args:
|
|
||||||
# workspace_id: 工作空间ID
|
|
||||||
# user_id: 用户ID
|
|
||||||
#
|
|
||||||
# Returns:
|
|
||||||
# 可用工具列表
|
|
||||||
# """
|
|
||||||
# if not TOOL_MANAGEMENT_AVAILABLE:
|
|
||||||
# logger.warning("工具管理系统不可用")
|
|
||||||
# return []
|
|
||||||
#
|
|
||||||
# try:
|
|
||||||
# db = next(get_db())
|
|
||||||
#
|
|
||||||
# # 创建工具注册表
|
|
||||||
# registry = ToolRegistry(db)
|
|
||||||
#
|
|
||||||
# # 注册内置工具类
|
|
||||||
# from app.core.tools.builtin import (
|
|
||||||
# DateTimeTool, JsonTool, BaiduSearchTool, MinerUTool, TextInTool
|
|
||||||
# )
|
|
||||||
# registry.register_tool_class(DateTimeTool)
|
|
||||||
# registry.register_tool_class(JsonTool)
|
|
||||||
# registry.register_tool_class(BaiduSearchTool)
|
|
||||||
# registry.register_tool_class(MinerUTool)
|
|
||||||
# registry.register_tool_class(TextInTool)
|
|
||||||
#
|
|
||||||
# # 获取活跃的工具
|
|
||||||
# import uuid
|
|
||||||
# tools = registry.list_tools(workspace_id=uuid.UUID(workspace_id))
|
|
||||||
# active_tools = [tool for tool in tools if tool.status.value == "active"]
|
|
||||||
#
|
|
||||||
# # 转换为Langchain工具
|
|
||||||
# langchain_tools = []
|
|
||||||
# for tool_info in active_tools:
|
|
||||||
# try:
|
|
||||||
# tool_instance = registry.get_tool(tool_info.id)
|
|
||||||
# if tool_instance:
|
|
||||||
# langchain_tool = LangchainAdapter.convert_tool(tool_instance)
|
|
||||||
# langchain_tools.append(langchain_tool)
|
|
||||||
# except Exception as e:
|
|
||||||
# logger.error(f"转换工具失败: {tool_info.name}, 错误: {e}")
|
|
||||||
#
|
|
||||||
# logger.info(f"为工作流获取了 {len(langchain_tools)} 个工具")
|
|
||||||
# return langchain_tools
|
|
||||||
#
|
|
||||||
# except Exception as e:
|
|
||||||
# logger.error(f"获取工作流工具失败: {e}")
|
|
||||||
# return []
|
|
||||||
#
|
|
||||||
#
|
|
||||||
# class ToolWorkflowNode:
|
|
||||||
# """工具工作流节点 - 在工作流中执行工具"""
|
|
||||||
#
|
|
||||||
# def __init__(self, node_config: dict, workflow_config: dict):
|
|
||||||
# """初始化工具节点
|
|
||||||
#
|
|
||||||
# Args:
|
|
||||||
# node_config: 节点配置
|
|
||||||
# workflow_config: 工作流配置
|
|
||||||
# """
|
|
||||||
# self.node_config = node_config
|
|
||||||
# self.workflow_config = workflow_config
|
|
||||||
# self.tool_id = node_config.get("tool_id")
|
|
||||||
# self.tool_parameters = node_config.get("parameters", {})
|
|
||||||
#
|
|
||||||
# async def run(self, state: WorkflowState) -> WorkflowState:
|
|
||||||
# """执行工具节点"""
|
|
||||||
# if not TOOL_MANAGEMENT_AVAILABLE:
|
|
||||||
# logger.error("工具管理系统不可用")
|
|
||||||
# state["error"] = "工具管理系统不可用"
|
|
||||||
# return state
|
|
||||||
#
|
|
||||||
# try:
|
|
||||||
# from sqlalchemy.orm import Session
|
|
||||||
# db = next(get_db())
|
|
||||||
#
|
|
||||||
# # 创建工具执行器
|
|
||||||
# registry = ToolRegistry(db)
|
|
||||||
# executor = ToolExecutor(db, registry)
|
|
||||||
#
|
|
||||||
# # 准备参数(支持变量替换)
|
|
||||||
# parameters = self._prepare_parameters(state)
|
|
||||||
#
|
|
||||||
# # 执行工具
|
|
||||||
# result = await executor.execute_tool(
|
|
||||||
# tool_id=self.tool_id,
|
|
||||||
# parameters=parameters,
|
|
||||||
# user_id=uuid.UUID(state["user_id"]),
|
|
||||||
# workspace_id=uuid.UUID(state["workspace_id"])
|
|
||||||
# )
|
|
||||||
#
|
|
||||||
# # 更新状态
|
|
||||||
# node_id = self.node_config.get("id")
|
|
||||||
# if result.success:
|
|
||||||
# state["node_outputs"][node_id] = {
|
|
||||||
# "type": "tool",
|
|
||||||
# "tool_id": self.tool_id,
|
|
||||||
# "output": result.data,
|
|
||||||
# "execution_time": result.execution_time,
|
|
||||||
# "token_usage": result.token_usage
|
|
||||||
# }
|
|
||||||
#
|
|
||||||
# # 更新运行时变量
|
|
||||||
# if isinstance(result.data, dict):
|
|
||||||
# for key, value in result.data.items():
|
|
||||||
# state["runtime_vars"][f"{node_id}.{key}"] = value
|
|
||||||
# else:
|
|
||||||
# state["runtime_vars"][f"{node_id}.result"] = result.data
|
|
||||||
# else:
|
|
||||||
# state["error"] = result.error
|
|
||||||
# state["error_node"] = node_id
|
|
||||||
# state["node_outputs"][node_id] = {
|
|
||||||
# "type": "tool",
|
|
||||||
# "tool_id": self.tool_id,
|
|
||||||
# "error": result.error,
|
|
||||||
# "execution_time": result.execution_time
|
|
||||||
# }
|
|
||||||
#
|
|
||||||
# return state
|
|
||||||
#
|
|
||||||
# except Exception as e:
|
|
||||||
# logger.error(f"工具节点执行失败: {e}")
|
|
||||||
# state["error"] = str(e)
|
|
||||||
# state["error_node"] = self.node_config.get("id")
|
|
||||||
# return state
|
|
||||||
#
|
|
||||||
# def _prepare_parameters(self, state: WorkflowState) -> dict:
|
|
||||||
# """准备工具参数(支持变量替换)"""
|
|
||||||
# parameters = {}
|
|
||||||
#
|
|
||||||
# for key, value in self.tool_parameters.items():
|
|
||||||
# if isinstance(value, str) and value.startswith("${") and value.endswith("}"):
|
|
||||||
# # 变量替换
|
|
||||||
# var_path = value[2:-1]
|
|
||||||
#
|
|
||||||
# # 支持多层级变量访问,如 ${sys.message} 或 ${node1.result}
|
|
||||||
# if "." in var_path:
|
|
||||||
# parts = var_path.split(".")
|
|
||||||
# current = state.get("variables", {})
|
|
||||||
#
|
|
||||||
# for part in parts:
|
|
||||||
# if isinstance(current, dict) and part in current:
|
|
||||||
# current = current[part]
|
|
||||||
# else:
|
|
||||||
# # 尝试从运行时变量获取
|
|
||||||
# runtime_key = ".".join(parts)
|
|
||||||
# current = state.get("runtime_vars", {}).get(runtime_key, value)
|
|
||||||
# break
|
|
||||||
#
|
|
||||||
# parameters[key] = current
|
|
||||||
# else:
|
|
||||||
# # 简单变量
|
|
||||||
# variables = state.get("variables", {})
|
|
||||||
# parameters[key] = variables.get(var_path, value)
|
|
||||||
# else:
|
|
||||||
# parameters[key] = value
|
|
||||||
#
|
|
||||||
# return parameters
|
|
||||||
#
|
|
||||||
#
|
|
||||||
# # 注册工具节点到NodeFactory(如果存在)
|
|
||||||
# try:
|
|
||||||
# from app.core.workflow.nodes import NodeFactory
|
|
||||||
# if hasattr(NodeFactory, 'register_node_type'):
|
|
||||||
# NodeFactory.register_node_type("tool", ToolWorkflowNode)
|
|
||||||
# logger.info("工具节点已注册到工作流系统")
|
|
||||||
# except Exception as e:
|
|
||||||
# logger.warning(f"注册工具节点失败: {e}")
|
|
||||||
|
|||||||
@@ -1,12 +1,15 @@
|
|||||||
import logging
|
import logging
|
||||||
|
import re
|
||||||
import uuid
|
import uuid
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
|
from functools import lru_cache
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from langgraph.checkpoint.memory import InMemorySaver
|
from langgraph.checkpoint.memory import InMemorySaver
|
||||||
from langgraph.graph import START, END
|
from langgraph.graph import START, END
|
||||||
from langgraph.graph.state import CompiledStateGraph, StateGraph
|
from langgraph.graph.state import CompiledStateGraph, StateGraph
|
||||||
from langgraph.types import Send
|
from langgraph.types import Send
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from app.core.workflow.expression_evaluator import evaluate_condition
|
from app.core.workflow.expression_evaluator import evaluate_condition
|
||||||
from app.core.workflow.nodes import WorkflowState, NodeFactory
|
from app.core.workflow.nodes import WorkflowState, NodeFactory
|
||||||
@@ -15,6 +18,115 @@ from app.core.workflow.nodes.enums import NodeType, BRANCH_NODES
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class OutputContent(BaseModel):
|
||||||
|
"""
|
||||||
|
Represents a single output segment of an End node.
|
||||||
|
|
||||||
|
An output segment can be either:
|
||||||
|
- literal text (static string)
|
||||||
|
- a variable placeholder (e.g. {{ node.field }})
|
||||||
|
|
||||||
|
Each segment has its own activation state, which is especially
|
||||||
|
important in stream mode.
|
||||||
|
"""
|
||||||
|
|
||||||
|
literal: str = Field(
|
||||||
|
...,
|
||||||
|
description="Raw output content. Can be literal text or a variable placeholder."
|
||||||
|
)
|
||||||
|
|
||||||
|
activate: bool = Field(
|
||||||
|
...,
|
||||||
|
description=(
|
||||||
|
"Whether this output segment is currently active.\n"
|
||||||
|
"- True: allowed to be emitted/output\n"
|
||||||
|
"- False: blocked until activated by branch control"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
is_variable: bool = Field(
|
||||||
|
...,
|
||||||
|
description=(
|
||||||
|
"Whether this segment represents a variable placeholder.\n"
|
||||||
|
"True -> variable (e.g. {{ node.field }})\n"
|
||||||
|
"False -> literal text"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class StreamOutputConfig(BaseModel):
|
||||||
|
"""
|
||||||
|
Streaming output configuration for an End node.
|
||||||
|
|
||||||
|
This structure controls:
|
||||||
|
- whether the End node output is globally active
|
||||||
|
- which upstream branch nodes are responsible for activation
|
||||||
|
- how each output segment behaves in streaming mode
|
||||||
|
"""
|
||||||
|
|
||||||
|
activate: bool = Field(
|
||||||
|
...,
|
||||||
|
description=(
|
||||||
|
"Global activation state of the End node output.\n"
|
||||||
|
"If False, no output should be emitted until all control nodes are resolved."
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
control_nodes: list[str] = Field(
|
||||||
|
...,
|
||||||
|
description=(
|
||||||
|
"List of upstream branch node IDs that control this End node.\n"
|
||||||
|
"Each node must signal completion before output becomes active."
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
outputs: list[OutputContent] = Field(
|
||||||
|
...,
|
||||||
|
description="Ordered list of output segments parsed from the output template."
|
||||||
|
)
|
||||||
|
|
||||||
|
cursor: int = Field(
|
||||||
|
...,
|
||||||
|
description=(
|
||||||
|
"Streaming cursor index.\n"
|
||||||
|
"Indicates how many output segments have already been emitted."
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
def update_activate(self, node_id):
|
||||||
|
"""
|
||||||
|
Update activation state based on an upstream node completion.
|
||||||
|
|
||||||
|
This method is typically called when a branch/control node finishes execution.
|
||||||
|
|
||||||
|
Behavior:
|
||||||
|
1. If the node is a control node:
|
||||||
|
- Remove it from `control_nodes`
|
||||||
|
- If all control nodes are resolved, activate the entire output
|
||||||
|
|
||||||
|
2. Activate variable output segments that depend on this node:
|
||||||
|
- If an output segment is a variable
|
||||||
|
- And its literal references the completed node_id
|
||||||
|
- Mark that segment as active
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Case 1: resolve control branch dependency
|
||||||
|
if node_id in self.control_nodes:
|
||||||
|
self.control_nodes.remove(node_id)
|
||||||
|
|
||||||
|
# All branch constraints resolved → enable output
|
||||||
|
if not self.control_nodes:
|
||||||
|
self.activate = True
|
||||||
|
|
||||||
|
# Case 2: activate variable segments related to this node
|
||||||
|
for i in range(len(self.outputs)):
|
||||||
|
if (
|
||||||
|
self.outputs[i].is_variable
|
||||||
|
and node_id in self.outputs[i].literal
|
||||||
|
):
|
||||||
|
self.outputs[i].activate = True
|
||||||
|
|
||||||
|
|
||||||
class GraphBuilder:
|
class GraphBuilder:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@@ -29,6 +141,12 @@ class GraphBuilder:
|
|||||||
|
|
||||||
self.start_node_id = None
|
self.start_node_id = None
|
||||||
self.end_node_ids = []
|
self.end_node_ids = []
|
||||||
|
self.node_map = {node["id"]: node for node in self.nodes}
|
||||||
|
self.end_node_map: dict[str, StreamOutputConfig] = {}
|
||||||
|
self._find_upstream_branch_node = lru_cache(
|
||||||
|
maxsize=len(self.nodes) * 2
|
||||||
|
)(self._find_upstream_branche_node)
|
||||||
|
self._analyze_end_node_output()
|
||||||
|
|
||||||
self.graph = StateGraph(WorkflowState)
|
self.graph = StateGraph(WorkflowState)
|
||||||
self.add_nodes()
|
self.add_nodes()
|
||||||
@@ -43,79 +161,182 @@ class GraphBuilder:
|
|||||||
def edges(self) -> list[dict[str, Any]]:
|
def edges(self) -> list[dict[str, Any]]:
|
||||||
return self.workflow_config.get("edges", [])
|
return self.workflow_config.get("edges", [])
|
||||||
|
|
||||||
def _analyze_end_node_prefixes(self) -> tuple[dict[str, str], set[str]]:
|
def get_node_type(self, node_id: str) -> str:
|
||||||
"""
|
"""Retrieve the type of node given its ID.
|
||||||
Analyze the prefix configuration for End nodes.
|
|
||||||
|
|
||||||
This function scans each End node's output template, identifies
|
Args:
|
||||||
references to its direct upstream nodes, and extracts the prefix
|
node_id (str): The unique identifier of the node.
|
||||||
string appearing before the first reference.
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
tuple:
|
str: The type of the node.
|
||||||
- dict[str, str]: Mapping from upstream node ID to its End node prefix
|
|
||||||
- set[str]: Set of node IDs that are directly adjacent to End nodes and referenced
|
Raises:
|
||||||
|
RuntimeError: If no node with the given `node_id` exists.
|
||||||
"""
|
"""
|
||||||
import re
|
try:
|
||||||
|
return self.node_map[node_id]["type"]
|
||||||
|
except KeyError:
|
||||||
|
raise RuntimeError(f"Node not found: Id={node_id}")
|
||||||
|
|
||||||
prefixes = {}
|
def _find_upstream_branche_node(self, target_node: str) -> tuple[bool, tuple[str]]:
|
||||||
adjacent_and_referenced = set() # Record nodes directly adjacent to End and referenced
|
"""Find upstream branch nodes for a given target node in the workflow graph.
|
||||||
|
|
||||||
# 找到所有 End 节点
|
This method identifies all upstream control (branch) nodes that can affect
|
||||||
|
the execution of `target_node`. If `target_node` is reachable from a start
|
||||||
|
node (i.e., a node with no upstream nodes), the method returns an empty tuple.
|
||||||
|
|
||||||
|
The function distinguishes between branch nodes (defined in `BRANCH_NODES`)
|
||||||
|
and non-branch nodes, recursively traversing upstream through non-branch
|
||||||
|
nodes. If any non-branch upstream path does not lead to a branch node,
|
||||||
|
the result will indicate that no valid upstream branch node exists.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
target_node (str): The identifier of the target node.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
tuple[bool, tuple[str]]:
|
||||||
|
- has_branch (bool): True if all upstream non-branch paths lead to at least
|
||||||
|
one branch node; False if any path reaches a start node without a branch.
|
||||||
|
- branch_nodes (tuple[str]): A deduplicated tuple of upstream branch node IDs
|
||||||
|
affecting `target_node`. Returns an empty tuple if `has_branch` is False.
|
||||||
|
"""
|
||||||
|
source_nodes = [
|
||||||
|
edge.get("source")
|
||||||
|
for edge in self.edges
|
||||||
|
if edge.get("target") == target_node
|
||||||
|
]
|
||||||
|
if not source_nodes and self.get_node_type(target_node) in [NodeType.START, NodeType.CYCLE_START]:
|
||||||
|
return False, tuple()
|
||||||
|
|
||||||
|
branch_nodes = []
|
||||||
|
non_branch_nodes = []
|
||||||
|
|
||||||
|
for node_id in source_nodes:
|
||||||
|
if self.get_node_type(node_id) in BRANCH_NODES:
|
||||||
|
branch_nodes.append(node_id)
|
||||||
|
else:
|
||||||
|
non_branch_nodes.append(node_id)
|
||||||
|
|
||||||
|
has_branch = True
|
||||||
|
for node_id in non_branch_nodes:
|
||||||
|
node_has_branch, nodes = self._find_upstream_branche_node(node_id)
|
||||||
|
has_branch = has_branch and node_has_branch
|
||||||
|
if not has_branch:
|
||||||
|
break
|
||||||
|
branch_nodes.extend(nodes)
|
||||||
|
if not has_branch:
|
||||||
|
branch_nodes = []
|
||||||
|
|
||||||
|
return has_branch, tuple(set(branch_nodes))
|
||||||
|
|
||||||
|
def _analyze_end_node_output(self):
|
||||||
|
"""
|
||||||
|
Analyze output templates of all End nodes and generate StreamOutputConfig.
|
||||||
|
|
||||||
|
This method is responsible for parsing the `output` field of End nodes,
|
||||||
|
splitting literal text and variable placeholders (e.g. {{ node.field }}),
|
||||||
|
and determining whether each output segment should be activated immediately
|
||||||
|
or controlled by upstream branch nodes.
|
||||||
|
|
||||||
|
In stream mode:
|
||||||
|
- If the End node is controlled by any upstream branch node, the output
|
||||||
|
will be initially inactive and controlled by those branch nodes.
|
||||||
|
- Otherwise, the output is activated immediately.
|
||||||
|
|
||||||
|
In non-stream mode:
|
||||||
|
- All outputs are activated by default.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Collect all End nodes in the workflow
|
||||||
end_nodes = [node for node in self.nodes if node.get("type") == "end"]
|
end_nodes = [node for node in self.nodes if node.get("type") == "end"]
|
||||||
logger.info(f"[Prefix Analysis] Found {len(end_nodes)} End nodes")
|
logger.info(f"[Prefix Analysis] Found {len(end_nodes)} End nodes")
|
||||||
|
|
||||||
|
# Iterate through each End node to analyze its output
|
||||||
for end_node in end_nodes:
|
for end_node in end_nodes:
|
||||||
end_node_id = end_node.get("id")
|
end_node_id = end_node.get("id")
|
||||||
output_template = end_node.get("config", {}).get("output")
|
config = end_node.get("config", {})
|
||||||
|
output = config.get("output")
|
||||||
|
|
||||||
logger.info(f"[Prefix Analysis] End node {end_node_id} template: {output_template}")
|
# Skip End nodes without output configuration
|
||||||
|
if not output:
|
||||||
if not output_template:
|
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Find all node references in the template
|
# Regex to split output into:
|
||||||
# Matches {{node_id.xxx}} or {{ node_id.xxx }} format (allowing spaces)
|
# - variable placeholders: {{ ... }}
|
||||||
pattern = r'\{\{\s*([a-zA-Z0-9_-]+)\.[a-zA-Z0-9_]+\s*\}\}'
|
# - normal literal text
|
||||||
matches = list(re.finditer(pattern, output_template))
|
#
|
||||||
|
# Example:
|
||||||
|
# "Hello {{user.name}}!" ->
|
||||||
|
# ["Hello ", "{{user.name}}", "!"]
|
||||||
|
pattern = r'\{\{.*?\}\}|[^{}]+'
|
||||||
|
|
||||||
logger.info(f"[Prefix Analysis] 模板中找到 {len(matches)} 个节点引用")
|
# Strict variable format: {{ node_id.field_name }}
|
||||||
|
variable_pattern_string = r'\{\{\s*[a-zA-Z0-9_]+\.[a-zA-Z0-9_]+\s*\}\}'
|
||||||
|
variable_pattern = re.compile(variable_pattern_string)
|
||||||
|
|
||||||
# Identify all direct upstream nodes connected to the End node
|
# Split output into ordered segments
|
||||||
direct_upstream_nodes = []
|
output_template = list(re.findall(pattern, output))
|
||||||
for edge in self.edges:
|
|
||||||
if edge.get("target") == end_node_id:
|
|
||||||
source_node_id = edge.get("source")
|
|
||||||
direct_upstream_nodes.append(source_node_id)
|
|
||||||
|
|
||||||
logger.info(f"[Prefix Analysis] Direct upstream nodes of End node: {direct_upstream_nodes}")
|
# Determine whether each segment is literal text
|
||||||
|
# True -> literal (can be directly output)
|
||||||
|
# False -> variable placeholder (needs runtime value)
|
||||||
|
output_flag = [
|
||||||
|
not bool(variable_pattern.match(item))
|
||||||
|
for item in output_template
|
||||||
|
]
|
||||||
|
|
||||||
# 找到第一个直接上游节点的引用
|
# Stream mode: output activation depends on upstream branch nodes
|
||||||
for match in matches:
|
if self.stream:
|
||||||
referenced_node_id = match.group(1)
|
# Find upstream branch nodes that can control this End node
|
||||||
logger.info(f"[Prefix Analysis] Checking reference: {referenced_node_id}")
|
has_branch, control_nodes = self._find_upstream_branche_node(end_node_id)
|
||||||
|
|
||||||
if referenced_node_id in direct_upstream_nodes:
|
# Build StreamOutputConfig for this End node
|
||||||
# 这是直接上游节点的引用,提取前缀
|
self.end_node_map[end_node_id] = StreamOutputConfig(
|
||||||
prefix = output_template[:match.start()]
|
# If there is no upstream branch, output is active immediately
|
||||||
|
activate=not has_branch,
|
||||||
|
|
||||||
logger.info(f"[Prefix Analysis] "
|
# Branch nodes that control activation of this End node
|
||||||
f"✅ Found reference to direct upstream node {referenced_node_id}, prefix: '{prefix}'")
|
control_nodes=list(control_nodes),
|
||||||
|
|
||||||
# 标记这个节点为"相邻且被引用"
|
# Convert output segments into OutputContent objects
|
||||||
adjacent_and_referenced.add(referenced_node_id)
|
outputs=list(
|
||||||
|
[
|
||||||
|
OutputContent(
|
||||||
|
literal=output_string,
|
||||||
|
# Literal text can be activated immediately unless blocked by branch
|
||||||
|
activate=activate,
|
||||||
|
# Variable segments are marked explicitly
|
||||||
|
is_variable=not activate
|
||||||
|
)
|
||||||
|
for output_string, activate in zip(output_template, output_flag)
|
||||||
|
]
|
||||||
|
),
|
||||||
|
# Cursor for streaming output (initially 0)
|
||||||
|
cursor=0
|
||||||
|
)
|
||||||
|
logger.info(f"[Stream Analysis] end_id: {end_node_id}, "
|
||||||
|
f"activate: {not has_branch}, "
|
||||||
|
f"control_nodes: {control_nodes},"
|
||||||
|
f"output: {output_template},"
|
||||||
|
f"output_activate: {output_flag}")
|
||||||
|
|
||||||
if prefix:
|
# Non-stream mode: all outputs are activated by default
|
||||||
prefixes[referenced_node_id] = prefix
|
else:
|
||||||
logger.info(f"[Prefix Analysis] "
|
self.end_node_map[end_node_id] = StreamOutputConfig(
|
||||||
f"✅ Assign prefix for node {referenced_node_id}: '{prefix[:50]}...'")
|
activate=True,
|
||||||
|
control_nodes=[],
|
||||||
# 只处理第一个直接上游节点的引用
|
outputs=list(
|
||||||
break
|
[
|
||||||
|
OutputContent(
|
||||||
logger.info(f"[Prefix Analysis] Final prefixes: {prefixes}")
|
literal=output_string,
|
||||||
logger.info(f"[Prefix Analysis] Nodes adjacent to End and referenced: {adjacent_and_referenced}")
|
activate=True,
|
||||||
return prefixes, adjacent_and_referenced
|
is_variable=not activate
|
||||||
|
)
|
||||||
|
for output_string, activate in zip(output_template, output_flag)
|
||||||
|
]
|
||||||
|
),
|
||||||
|
cursor=0
|
||||||
|
)
|
||||||
|
|
||||||
def add_nodes(self):
|
def add_nodes(self):
|
||||||
"""Add all nodes from the workflow configuration to the state graph.
|
"""Add all nodes from the workflow configuration to the state graph.
|
||||||
@@ -135,9 +356,6 @@ class GraphBuilder:
|
|||||||
Returns:
|
Returns:
|
||||||
None
|
None
|
||||||
"""
|
"""
|
||||||
# Analyze End node prefixes if in stream mode
|
|
||||||
end_prefixes, adjacent_and_referenced = self._analyze_end_node_prefixes() if self.stream else ({}, set())
|
|
||||||
|
|
||||||
for node in self.nodes:
|
for node in self.nodes:
|
||||||
node_type = node.get("type")
|
node_type = node.get("type")
|
||||||
node_id = node.get("id")
|
node_id = node.get("id")
|
||||||
@@ -171,17 +389,6 @@ class GraphBuilder:
|
|||||||
related_edge[idx]['condition'] = f"node.{node_id}.output == '{related_edge[idx]['label']}'"
|
related_edge[idx]['condition'] = f"node.{node_id}.output == '{related_edge[idx]['label']}'"
|
||||||
|
|
||||||
if node_instance:
|
if node_instance:
|
||||||
# Inject End node prefix configuration if in stream mode
|
|
||||||
if self.stream and node_id in end_prefixes:
|
|
||||||
node_instance._end_node_prefix = end_prefixes[node_id]
|
|
||||||
logger.info(f"Injected End prefix for node {node_id}")
|
|
||||||
|
|
||||||
# Mark nodes as adjacent and referenced to End node in stream mode
|
|
||||||
if self.stream:
|
|
||||||
node_instance._is_adjacent_to_end = node_id in adjacent_and_referenced
|
|
||||||
if node_id in adjacent_and_referenced:
|
|
||||||
logger.info(f"Node {node_id} marked as adjacent and referenced to End node")
|
|
||||||
|
|
||||||
# Wrap node's run method to avoid closure issues
|
# Wrap node's run method to avoid closure issues
|
||||||
if self.stream:
|
if self.stream:
|
||||||
# Stream mode: create an async generator function
|
# Stream mode: create an async generator function
|
||||||
@@ -261,6 +468,7 @@ class GraphBuilder:
|
|||||||
for source_node, branches in conditional_edges.items():
|
for source_node, branches in conditional_edges.items():
|
||||||
def make_router(src, branch_list):
|
def make_router(src, branch_list):
|
||||||
"""reate a router function for each source node that routes to a NOP node for later merging."""
|
"""reate a router function for each source node that routes to a NOP node for later merging."""
|
||||||
|
|
||||||
def make_branch_node(node_name, targets):
|
def make_branch_node(node_name, targets):
|
||||||
def node(s):
|
def node(s):
|
||||||
# NOTE: NOP NODE MUST NOT MODIFY STATE
|
# NOTE: NOP NODE MUST NOT MODIFY STATE
|
||||||
|
|||||||
@@ -67,10 +67,6 @@ class WorkflowState(TypedDict):
|
|||||||
error: str | None
|
error: str | None
|
||||||
error_node: str | None
|
error_node: str | None
|
||||||
|
|
||||||
# Streaming buffer (stores real-time streaming output of nodes)
|
|
||||||
# Format: {node_id: {"chunks": [...], "full_content": "..."}}
|
|
||||||
streaming_buffer: Annotated[dict[str, Any], lambda x, y: {**x, **y}]
|
|
||||||
|
|
||||||
# node activate status
|
# node activate status
|
||||||
activate: Annotated[dict[str, bool], merge_activate_state]
|
activate: Annotated[dict[str, bool], merge_activate_state]
|
||||||
|
|
||||||
@@ -300,7 +296,7 @@ class BaseNode(ABC):
|
|||||||
"""
|
"""
|
||||||
if not self.check_activate(state):
|
if not self.check_activate(state):
|
||||||
yield self.trans_activate(state)
|
yield self.trans_activate(state)
|
||||||
logger.info(f"跳过节点{self.node_id}")
|
logger.info(f"jump node: {self.node_id}")
|
||||||
return
|
return
|
||||||
|
|
||||||
import time
|
import time
|
||||||
@@ -313,19 +309,6 @@ class BaseNode(ABC):
|
|||||||
# Get LangGraph's stream writer for sending custom data
|
# Get LangGraph's stream writer for sending custom data
|
||||||
writer = get_stream_writer()
|
writer = get_stream_writer()
|
||||||
|
|
||||||
# Check if this is an End node
|
|
||||||
# End nodes CAN send chunks (for suffix), but only after LLM content
|
|
||||||
is_end_node = self.node_type == "end"
|
|
||||||
|
|
||||||
# Check if this node is adjacent to End node (for message type)
|
|
||||||
is_adjacent_to_end = getattr(self, '_is_adjacent_to_end', False)
|
|
||||||
|
|
||||||
# Determine chunk type: "message" for End and adjacent nodes, "node_chunk" for others
|
|
||||||
chunk_type = "message" if (is_end_node or is_adjacent_to_end) else "node_chunk"
|
|
||||||
|
|
||||||
logger.debug(
|
|
||||||
f"节点 {self.node_id} chunk 类型: {chunk_type} (is_end={is_end_node}, adjacent={is_adjacent_to_end})")
|
|
||||||
|
|
||||||
# Accumulate complete result (for final wrapping)
|
# Accumulate complete result (for final wrapping)
|
||||||
chunks = []
|
chunks = []
|
||||||
final_result = None
|
final_result = None
|
||||||
@@ -340,66 +323,25 @@ class BaseNode(ABC):
|
|||||||
raise TimeoutError()
|
raise TimeoutError()
|
||||||
|
|
||||||
# Check if it's a completion marker
|
# Check if it's a completion marker
|
||||||
if isinstance(item, dict) and item.get("__final__"):
|
if item.get("__final__"):
|
||||||
final_result = item["result"]
|
final_result = item["result"]
|
||||||
elif isinstance(item, str):
|
else:
|
||||||
# String is a chunk
|
|
||||||
chunk_count += 1
|
chunk_count += 1
|
||||||
chunks.append(item)
|
content = str(item.get("chunk"))
|
||||||
full_content = "".join(chunks)
|
done = item.get("done", False)
|
||||||
|
chunks.append(content)
|
||||||
|
|
||||||
# Send chunks for all nodes (including End nodes for suffix)
|
# Send chunks for all nodes (including End nodes for suffix)
|
||||||
logger.debug(f"节点 {self.node_id} 发送 chunk #{chunk_count}: {item[:50]}...")
|
logger.debug(f"节点 {self.node_id} 发送 chunk #{chunk_count}: {content[:50]}...")
|
||||||
|
|
||||||
# 1. Send via stream writer (for real-time client updates)
|
# 1. Send via stream writer (for real-time client updates)
|
||||||
writer({
|
writer({
|
||||||
"type": chunk_type, # "message" or "node_chunk"
|
"type": "node_chunk",
|
||||||
"node_id": self.node_id,
|
"node_id": self.node_id,
|
||||||
"chunk": item,
|
"chunk": content,
|
||||||
"full_content": full_content,
|
"done": done
|
||||||
"chunk_index": chunk_count
|
|
||||||
})
|
})
|
||||||
|
|
||||||
# 2. Update streaming buffer in state (for downstream nodes)
|
|
||||||
# Only non-End nodes need streaming buffer
|
|
||||||
if not is_end_node:
|
|
||||||
yield {
|
|
||||||
"streaming_buffer": {
|
|
||||||
self.node_id: {
|
|
||||||
"full_content": full_content,
|
|
||||||
"chunk_count": chunk_count,
|
|
||||||
"is_complete": False
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
else:
|
|
||||||
# Other types are also treated as chunks
|
|
||||||
chunk_count += 1
|
|
||||||
chunk_str = str(item)
|
|
||||||
chunks.append(chunk_str)
|
|
||||||
full_content = "".join(chunks)
|
|
||||||
|
|
||||||
# Send chunks for all nodes
|
|
||||||
writer({
|
|
||||||
"type": chunk_type, # "message" or "node_chunk"
|
|
||||||
"node_id": self.node_id,
|
|
||||||
"chunk": chunk_str,
|
|
||||||
"full_content": full_content,
|
|
||||||
"chunk_index": chunk_count
|
|
||||||
})
|
|
||||||
|
|
||||||
# Only non-End nodes need streaming buffer
|
|
||||||
if not is_end_node:
|
|
||||||
yield {
|
|
||||||
"streaming_buffer": {
|
|
||||||
self.node_id: {
|
|
||||||
"full_content": full_content,
|
|
||||||
"chunk_count": chunk_count,
|
|
||||||
"is_complete": False
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
elapsed_time = time.time() - start_time
|
elapsed_time = time.time() - start_time
|
||||||
|
|
||||||
logger.info(f"节点 {self.node_id} 流式执行完成,耗时: {elapsed_time:.2f}s, chunks: {chunk_count}")
|
logger.info(f"节点 {self.node_id} 流式执行完成,耗时: {elapsed_time:.2f}s, chunks: {chunk_count}")
|
||||||
@@ -426,16 +368,6 @@ class BaseNode(ABC):
|
|||||||
"looping": state["looping"]
|
"looping": state["looping"]
|
||||||
}
|
}
|
||||||
|
|
||||||
# Add streaming buffer for non-End nodes
|
|
||||||
if not is_end_node:
|
|
||||||
state_update["streaming_buffer"] = {
|
|
||||||
self.node_id: {
|
|
||||||
"full_content": "".join(chunks),
|
|
||||||
"chunk_count": chunk_count,
|
|
||||||
"is_complete": True # Mark as complete
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
# Finally yield state update
|
# Finally yield state update
|
||||||
# LangGraph will merge this into state
|
# LangGraph will merge this into state
|
||||||
yield state_update | self.trans_activate(state)
|
yield state_update | self.trans_activate(state)
|
||||||
|
|||||||
@@ -5,10 +5,8 @@ End 节点实现
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import re
|
|
||||||
|
|
||||||
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
|
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
|
||||||
from app.core.workflow.nodes.enums import NodeType
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -37,24 +35,8 @@ class EndNode(BaseNode):
|
|||||||
# 如果配置了输出模板,使用模板渲染;否则使用默认输出
|
# 如果配置了输出模板,使用模板渲染;否则使用默认输出
|
||||||
if output_template:
|
if output_template:
|
||||||
output = self._render_template(output_template, state, strict=False)
|
output = self._render_template(output_template, state, strict=False)
|
||||||
state['messages'].extend([
|
|
||||||
{
|
|
||||||
"role": "user",
|
|
||||||
"content": self.get_variable("sys.message", state)
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"role": "assistant",
|
|
||||||
"content": output
|
|
||||||
}
|
|
||||||
])
|
|
||||||
else:
|
else:
|
||||||
state['messages'].extend([
|
output = ""
|
||||||
{
|
|
||||||
"role": "user",
|
|
||||||
"content": self.get_variable("sys.message", state)
|
|
||||||
},
|
|
||||||
])
|
|
||||||
output = "工作流已完成"
|
|
||||||
|
|
||||||
# 统计信息(用于日志)
|
# 统计信息(用于日志)
|
||||||
node_outputs = state.get("node_outputs", {})
|
node_outputs = state.get("node_outputs", {})
|
||||||
@@ -63,274 +45,3 @@ class EndNode(BaseNode):
|
|||||||
logger.info(f"节点 {self.node_id} (End) 执行完成,共执行 {total_nodes} 个节点")
|
logger.info(f"节点 {self.node_id} (End) 执行完成,共执行 {total_nodes} 个节点")
|
||||||
|
|
||||||
return output
|
return output
|
||||||
|
|
||||||
def _extract_referenced_nodes(self, template: str) -> list[str]:
|
|
||||||
"""从模板中提取引用的节点 ID
|
|
||||||
|
|
||||||
例如:'结果:{{llm_qa.output}}' -> ['llm_qa']
|
|
||||||
|
|
||||||
Args:
|
|
||||||
template: 模板字符串
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
引用的节点 ID 列表
|
|
||||||
"""
|
|
||||||
# 匹配 {{node_id.xxx}} 格式
|
|
||||||
pattern = r'\{\{([a-zA-Z0-9_]+)\.[a-zA-Z0-9_]+\}\}'
|
|
||||||
matches = re.findall(pattern, template)
|
|
||||||
return list(set(matches)) # 去重
|
|
||||||
|
|
||||||
def _parse_template_parts(self, template: str, state: WorkflowState) -> list[dict]:
|
|
||||||
"""解析模板,分离静态文本和动态引用
|
|
||||||
|
|
||||||
例如:'你好 {{llm.output}}, 这是后缀'
|
|
||||||
返回:[
|
|
||||||
{"type": "static", "content": "你好 "},
|
|
||||||
{"type": "dynamic", "node_id": "llm", "field": "output"},
|
|
||||||
{"type": "static", "content": ", 这是后缀"}
|
|
||||||
]
|
|
||||||
|
|
||||||
Args:
|
|
||||||
template: 模板字符串
|
|
||||||
state: 工作流状态
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
模板部分列表
|
|
||||||
"""
|
|
||||||
import re
|
|
||||||
|
|
||||||
parts = []
|
|
||||||
last_end = 0
|
|
||||||
|
|
||||||
# 匹配 {{xxx}} 或 {{ xxx }} 格式(支持空格)
|
|
||||||
pattern = r'\{\{\s*([^}]+?)\s*\}\}'
|
|
||||||
|
|
||||||
for match in re.finditer(pattern, template):
|
|
||||||
start, end = match.span()
|
|
||||||
|
|
||||||
# 添加前面的静态文本
|
|
||||||
if start > last_end:
|
|
||||||
static_text = template[last_end:start]
|
|
||||||
if static_text:
|
|
||||||
parts.append({"type": "static", "content": static_text})
|
|
||||||
|
|
||||||
# 解析动态引用
|
|
||||||
ref = match.group(1).strip()
|
|
||||||
|
|
||||||
# 检查是否是节点引用(如 llm.output 或 llm_qa.output)
|
|
||||||
if '.' in ref:
|
|
||||||
node_id, field = ref.split('.', 1)
|
|
||||||
parts.append({
|
|
||||||
"type": "dynamic",
|
|
||||||
"node_id": node_id,
|
|
||||||
"field": field,
|
|
||||||
"raw": ref
|
|
||||||
})
|
|
||||||
else:
|
|
||||||
# 其他引用(如 {{var.xxx}}),当作静态处理
|
|
||||||
# 直接渲染这部分
|
|
||||||
rendered = self._render_template(f"{{{{{ref}}}}}", state)
|
|
||||||
parts.append({"type": "static", "content": rendered})
|
|
||||||
|
|
||||||
last_end = end
|
|
||||||
|
|
||||||
# 添加最后的静态文本
|
|
||||||
if last_end < len(template):
|
|
||||||
static_text = template[last_end:]
|
|
||||||
if static_text:
|
|
||||||
parts.append({"type": "static", "content": static_text})
|
|
||||||
|
|
||||||
return parts
|
|
||||||
|
|
||||||
async def execute_stream(self, state: WorkflowState):
|
|
||||||
"""Execute End node business logic (streaming)
|
|
||||||
|
|
||||||
Smart output strategy:
|
|
||||||
1. Check if template references a direct upstream LLM node
|
|
||||||
2. If yes, only output the part AFTER that reference (suffix)
|
|
||||||
3. Prefix and LLM content have already been sent during LLM node streaming
|
|
||||||
|
|
||||||
Note: Only LLM nodes get this special treatment. Other node types output normally.
|
|
||||||
|
|
||||||
Example: '{{start.test}}hahaha {{ llm_qa.output }} lalalalala a'
|
|
||||||
- Direct upstream LLM node is llm_qa
|
|
||||||
- Prefix '{{start.test}}hahaha ' was sent before LLM node streaming
|
|
||||||
- LLM content was streamed during LLM node execution
|
|
||||||
- End node only outputs ' lalalalala a' (suffix, sent as one chunk)
|
|
||||||
|
|
||||||
Args:
|
|
||||||
state: Workflow state
|
|
||||||
|
|
||||||
Yields:
|
|
||||||
Completion marker
|
|
||||||
"""
|
|
||||||
logger.info(f"节点 {self.node_id} (End) 开始执行(流式)")
|
|
||||||
|
|
||||||
# 获取配置的输出模板
|
|
||||||
output_template = self.config.get("output")
|
|
||||||
|
|
||||||
if not output_template:
|
|
||||||
output = "工作流已完成"
|
|
||||||
from langgraph.config import get_stream_writer
|
|
||||||
writer = get_stream_writer()
|
|
||||||
writer({
|
|
||||||
"type": "message", # End node output uses message type
|
|
||||||
"node_id": self.node_id,
|
|
||||||
"chunk": "",
|
|
||||||
"full_content": output,
|
|
||||||
"chunk_index": 1,
|
|
||||||
"is_suffix": False
|
|
||||||
})
|
|
||||||
state['messages'].extend([
|
|
||||||
{
|
|
||||||
"role": "user",
|
|
||||||
"content": self.get_variable("sys.message", state)
|
|
||||||
}
|
|
||||||
])
|
|
||||||
yield {"__final__": True, "result": output}
|
|
||||||
return
|
|
||||||
|
|
||||||
# Find direct upstream LLM nodes
|
|
||||||
direct_upstream_llm_nodes = []
|
|
||||||
for edge in self.workflow_config.get("edges", []):
|
|
||||||
if edge.get("target") == self.node_id:
|
|
||||||
source_node_id = edge.get("source")
|
|
||||||
# Check if the source node is an LLM node
|
|
||||||
for node in self.workflow_config.get("nodes", []):
|
|
||||||
logger.info(f"节点 {self.node_id} 的类型 {node.get("type")}")
|
|
||||||
if node.get("id") == source_node_id and node.get("type") == NodeType.LLM:
|
|
||||||
direct_upstream_llm_nodes.append(source_node_id)
|
|
||||||
break
|
|
||||||
|
|
||||||
logger.info(f"节点 {self.node_id} 的直接上游 LLM 节点: {direct_upstream_llm_nodes}")
|
|
||||||
|
|
||||||
# Parse template parts
|
|
||||||
parts = self._parse_template_parts(output_template, state)
|
|
||||||
logger.info(f"节点 {self.node_id} 解析模板,共 {len(parts)} 个部分")
|
|
||||||
for i, part in enumerate(parts):
|
|
||||||
logger.info(f"[模板解析] part[{i}]: {part}")
|
|
||||||
|
|
||||||
# Find the first reference to a direct upstream LLM node
|
|
||||||
upstream_llm_ref_index = None
|
|
||||||
for i, part in enumerate(parts):
|
|
||||||
if part["type"] == "dynamic" and part["node_id"] in direct_upstream_llm_nodes:
|
|
||||||
upstream_llm_ref_index = i
|
|
||||||
logger.info(f"节点 {self.node_id} 找到直接上游 LLM 节点 {part['node_id']} 的引用,索引: {i}")
|
|
||||||
break
|
|
||||||
|
|
||||||
if upstream_llm_ref_index is None:
|
|
||||||
# No reference to direct upstream LLM node, output complete template content
|
|
||||||
output = self._render_template(output_template, state, strict=False)
|
|
||||||
logger.info(f"节点 {self.node_id} 没有引用直接上游 LLM 节点,输出完整内容: '{output[:50]}...'")
|
|
||||||
|
|
||||||
# Send complete content via writer (as a single message chunk)
|
|
||||||
from langgraph.config import get_stream_writer
|
|
||||||
writer = get_stream_writer()
|
|
||||||
writer({
|
|
||||||
"type": "message", # End node output uses message type
|
|
||||||
"node_id": self.node_id,
|
|
||||||
"chunk": output,
|
|
||||||
"full_content": output,
|
|
||||||
"chunk_index": 1,
|
|
||||||
"is_suffix": False
|
|
||||||
})
|
|
||||||
logger.info(f"节点 {self.node_id} 已通过 writer 发送完整内容")
|
|
||||||
|
|
||||||
state['messages'].extend([
|
|
||||||
{
|
|
||||||
"role": "user",
|
|
||||||
"content": self.get_variable("sys.message", state)
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"role": "assistant",
|
|
||||||
"content": output
|
|
||||||
}
|
|
||||||
])
|
|
||||||
|
|
||||||
# yield completion marker
|
|
||||||
yield {"__final__": True, "result": output}
|
|
||||||
return
|
|
||||||
|
|
||||||
# Has reference to direct upstream LLM node, only output the part after that reference (suffix)
|
|
||||||
logger.info(
|
|
||||||
f"节点 {self.node_id} 检测到直接上游 LLM 节点引用,只输出后缀部分(从索引 {upstream_llm_ref_index + 1} 开始)")
|
|
||||||
|
|
||||||
# Collect suffix parts
|
|
||||||
suffix_parts = []
|
|
||||||
logger.info(f"[后缀调试] 开始收集后缀,从索引 {upstream_llm_ref_index + 1} 到 {len(parts) - 1}")
|
|
||||||
for i in range(upstream_llm_ref_index + 1, len(parts)):
|
|
||||||
part = parts[i]
|
|
||||||
logger.info(f"[后缀调试] 处理 part[{i}]: {part}")
|
|
||||||
if part["type"] == "static":
|
|
||||||
# 静态文本
|
|
||||||
logger.info(f"[后缀调试] 添加静态文本: '{part['content']}'")
|
|
||||||
suffix_parts.append(part["content"])
|
|
||||||
|
|
||||||
elif part["type"] == "dynamic":
|
|
||||||
# Other dynamic references (if there are multiple references)
|
|
||||||
node_id = part["node_id"]
|
|
||||||
field = part["field"]
|
|
||||||
|
|
||||||
# Use VariablePool to get variable value
|
|
||||||
pool = self.get_variable_pool(state)
|
|
||||||
try:
|
|
||||||
# Try to get variable value with default empty string
|
|
||||||
content = pool.get([node_id, field], default="")
|
|
||||||
logger.info(f"[后缀调试] 获取变量 {node_id}.{field} 成功: '{content}'")
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"[后缀调试] 获取变量 {node_id}.{field} 失败: {e}")
|
|
||||||
content = ""
|
|
||||||
|
|
||||||
# Convert to string if not None
|
|
||||||
suffix_parts.append(str(content) if content is not None else "")
|
|
||||||
|
|
||||||
# 拼接后缀
|
|
||||||
suffix = "".join(suffix_parts)
|
|
||||||
|
|
||||||
# 构建完整输出(用于返回,包含前缀 + 动态内容 + 后缀)
|
|
||||||
full_output = self._render_template(output_template, state, strict=False)
|
|
||||||
|
|
||||||
state['messages'].extend([
|
|
||||||
{
|
|
||||||
"role": "user",
|
|
||||||
"content": self.get_variable("sys.message", state)
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"role": "assistant",
|
|
||||||
"content": full_output
|
|
||||||
}
|
|
||||||
])
|
|
||||||
|
|
||||||
logger.info(f"[后缀调试] 节点 {self.node_id} 后缀部分数量: {len(suffix_parts)}")
|
|
||||||
logger.info(f"[后缀调试] 后缀内容: '{suffix}'")
|
|
||||||
logger.info(f"[后缀调试] 后缀长度: {len(suffix)}")
|
|
||||||
logger.info(f"[后缀调试] 后缀是否为空: {not suffix}")
|
|
||||||
|
|
||||||
if suffix:
|
|
||||||
logger.info(f"节点 {self.node_id} 输出后缀: '{suffix}...' (长度: {len(suffix)})")
|
|
||||||
# 一次性输出后缀(作为单个 chunk)
|
|
||||||
# 注意:不要直接 yield 字符串,因为 base_node 会逐字符处理
|
|
||||||
# 而是通过 writer 直接发送
|
|
||||||
from langgraph.config import get_stream_writer
|
|
||||||
writer = get_stream_writer()
|
|
||||||
writer({
|
|
||||||
"type": "message", # End 节点的输出使用 message 类型
|
|
||||||
"node_id": self.node_id,
|
|
||||||
"chunk": suffix,
|
|
||||||
"full_content": full_output, # full_content 是完整的渲染结果(前缀+LLM+后缀)
|
|
||||||
"chunk_index": 1,
|
|
||||||
"is_suffix": True
|
|
||||||
})
|
|
||||||
logger.info(f"节点 {self.node_id} 已通过 writer 发送后缀,full_content 长度: {len(full_output)}")
|
|
||||||
else:
|
|
||||||
logger.warning(f"[后缀调试] 节点 {self.node_id} 后缀为空,不发送!"
|
|
||||||
f"upstream_llm_ref_index={upstream_llm_ref_index}, parts数量={len(parts)}")
|
|
||||||
|
|
||||||
# 统计信息
|
|
||||||
node_outputs = state.get("node_outputs", {})
|
|
||||||
total_nodes = len(node_outputs)
|
|
||||||
|
|
||||||
logger.info(f"节点 {self.node_id} (End) 执行完成(流式),共执行了 {total_nodes} 个节点")
|
|
||||||
|
|
||||||
# yield 完成标记(包含完整输出)
|
|
||||||
yield {"__final__": True, "result": full_output}
|
|
||||||
|
|||||||
@@ -7,18 +7,18 @@ LLM 节点实现
|
|||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
from typing import Any
|
from typing import Any
|
||||||
from langchain_core.messages import AIMessage, SystemMessage, HumanMessage
|
|
||||||
|
|
||||||
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
|
from langchain_core.messages import AIMessage
|
||||||
|
|
||||||
|
from app.core.error_codes import BizCode
|
||||||
|
from app.core.exceptions import BusinessException
|
||||||
from app.core.models import RedBearLLM, RedBearModelConfig
|
from app.core.models import RedBearLLM, RedBearModelConfig
|
||||||
|
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
|
||||||
from app.core.workflow.nodes.llm.config import LLMNodeConfig
|
from app.core.workflow.nodes.llm.config import LLMNodeConfig
|
||||||
from app.db import get_db_context
|
from app.db import get_db_context
|
||||||
from app.models import ModelType
|
from app.models import ModelType
|
||||||
from app.services.model_service import ModelConfigService
|
from app.services.model_service import ModelConfigService
|
||||||
|
|
||||||
from app.core.exceptions import BusinessException
|
|
||||||
from app.core.error_codes import BizCode
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@@ -231,42 +231,14 @@ class LLMNode(BaseNode):
|
|||||||
文本片段(chunk)或完成标记
|
文本片段(chunk)或完成标记
|
||||||
"""
|
"""
|
||||||
self.typed_config = LLMNodeConfig(**self.config)
|
self.typed_config = LLMNodeConfig(**self.config)
|
||||||
from langgraph.config import get_stream_writer
|
|
||||||
|
|
||||||
llm, prompt_or_messages = self._prepare_llm(state, True)
|
llm, prompt_or_messages = self._prepare_llm(state, True)
|
||||||
|
|
||||||
logger.info(f"节点 {self.node_id} 开始执行 LLM 调用(流式)")
|
logger.info(f"节点 {self.node_id} 开始执行 LLM 调用(流式)")
|
||||||
logger.debug(f"LLM 配置: streaming={getattr(llm._model, 'streaming', 'unknown')}")
|
logger.debug(f"LLM 配置: streaming={getattr(llm._model, 'streaming', 'unknown')}")
|
||||||
|
|
||||||
# 检查是否有注入的 End 节点前缀配置
|
|
||||||
writer = get_stream_writer()
|
|
||||||
end_prefix = getattr(self, '_end_node_prefix', None)
|
|
||||||
|
|
||||||
logger.info(f"[LLM前缀] 节点 {self.node_id} 检查前缀配置: {end_prefix is not None}")
|
|
||||||
if end_prefix:
|
|
||||||
logger.info(f"[LLM前缀] 前缀内容: '{end_prefix}'")
|
|
||||||
|
|
||||||
if end_prefix:
|
|
||||||
# 渲染前缀(可能包含其他变量)
|
|
||||||
try:
|
|
||||||
rendered_prefix = self._render_template(end_prefix, state)
|
|
||||||
logger.info(f"节点 {self.node_id} 提前发送 End 节点前缀: '{rendered_prefix[:50]}...'")
|
|
||||||
|
|
||||||
# 提前发送 End 节点的前缀(使用 "message" 类型)
|
|
||||||
writer({
|
|
||||||
"type": "message", # End 相关的内容都是 message 类型
|
|
||||||
"node_id": "end", # 标记为 end 节点的输出
|
|
||||||
"chunk": rendered_prefix,
|
|
||||||
"full_content": rendered_prefix,
|
|
||||||
"chunk_index": 0,
|
|
||||||
"is_prefix": True # 标记这是前缀
|
|
||||||
})
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"渲染/发送 End 节点前缀失败: {e}")
|
|
||||||
|
|
||||||
# 累积完整响应
|
# 累积完整响应
|
||||||
full_response = ""
|
full_response = ""
|
||||||
last_chunk = None
|
|
||||||
chunk_count = 0
|
chunk_count = 0
|
||||||
|
|
||||||
# 调用 LLM(流式,支持字符串或消息列表)
|
# 调用 LLM(流式,支持字符串或消息列表)
|
||||||
@@ -284,12 +256,19 @@ class LLMNode(BaseNode):
|
|||||||
# 只有当内容不为空时才处理
|
# 只有当内容不为空时才处理
|
||||||
if content:
|
if content:
|
||||||
full_response += content
|
full_response += content
|
||||||
last_chunk = chunk
|
|
||||||
chunk_count += 1
|
chunk_count += 1
|
||||||
|
|
||||||
# 流式返回每个文本片段
|
# 流式返回每个文本片段
|
||||||
yield content
|
yield {
|
||||||
|
"__final__": False,
|
||||||
|
"chunk": content
|
||||||
|
}
|
||||||
|
|
||||||
|
yield {
|
||||||
|
"__final__": False,
|
||||||
|
"chunk": "",
|
||||||
|
"done": True
|
||||||
|
}
|
||||||
logger.info(f"节点 {self.node_id} LLM 调用完成,输出长度: {len(full_response)}, 总 chunks: {chunk_count}")
|
logger.info(f"节点 {self.node_id} LLM 调用完成,输出长度: {len(full_response)}, 总 chunks: {chunk_count}")
|
||||||
|
|
||||||
# 构建完整的 AIMessage(包含元数据)
|
# 构建完整的 AIMessage(包含元数据)
|
||||||
|
|||||||
@@ -15,6 +15,7 @@ services:
|
|||||||
networks:
|
networks:
|
||||||
- default
|
- default
|
||||||
- celery
|
- celery
|
||||||
|
- sandbox
|
||||||
depends_on:
|
depends_on:
|
||||||
- worker-memory
|
- worker-memory
|
||||||
- worker-document
|
- worker-document
|
||||||
@@ -63,5 +64,16 @@ services:
|
|||||||
depends_on:
|
depends_on:
|
||||||
- worker-memory
|
- worker-memory
|
||||||
|
|
||||||
|
sandbox:
|
||||||
|
image: redbear_sandbox:latest
|
||||||
|
container_name: sandbox
|
||||||
|
ports:
|
||||||
|
- "8194"
|
||||||
|
command: /code/.venv/bin/python main.py
|
||||||
|
restart: unless-stopped
|
||||||
|
networks:
|
||||||
|
- sandbox
|
||||||
|
|
||||||
networks:
|
networks:
|
||||||
celery:
|
celery:
|
||||||
|
sandbox:
|
||||||
|
|||||||
Reference in New Issue
Block a user