Merge pull request #210 from SuanmoSuanyangTechnology/fix/workflow-stream
fix(workflow): fix activation and branch control issues in streaming output
This commit is contained in:
@@ -11,16 +11,12 @@ from typing import Any
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.graph.state import CompiledStateGraph
|
||||
|
||||
from app.core.workflow.graph_builder import GraphBuilder
|
||||
from app.core.workflow.expression_evaluator import evaluate_expression
|
||||
from app.core.workflow.graph_builder import GraphBuilder, StreamOutputConfig
|
||||
from app.core.workflow.nodes import WorkflowState
|
||||
from app.core.workflow.nodes.base_config import VariableType
|
||||
from app.core.workflow.nodes.enums import NodeType
|
||||
|
||||
# from app.core.tools.registry import ToolRegistry
|
||||
# from app.core.tools.executor import ToolExecutor
|
||||
# from app.core.tools.langchain_adapter import LangchainAdapter
|
||||
# TOOL_MANAGEMENT_AVAILABLE = True
|
||||
# from app.db import get_db
|
||||
from app.core.workflow.template_renderer import render_template
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -55,6 +51,8 @@ class WorkflowExecutor:
|
||||
self.execution_config = workflow_config.get("execution_config", {})
|
||||
|
||||
self.start_node_id = None
|
||||
self.end_outputs: dict[str, StreamOutputConfig] = {}
|
||||
self.activate_end: str | None = None
|
||||
|
||||
self.checkpoint_config = RunnableConfig(
|
||||
configurable={
|
||||
@@ -127,7 +125,6 @@ class WorkflowExecutor:
|
||||
"user_id": self.user_id,
|
||||
"error": None,
|
||||
"error_node": None,
|
||||
"streaming_buffer": {}, # 流式缓冲区
|
||||
"cycle_nodes": [
|
||||
node.get("id")
|
||||
for node in self.workflow_config.get("nodes")
|
||||
@@ -139,9 +136,8 @@ class WorkflowExecutor:
|
||||
}
|
||||
}
|
||||
|
||||
def _build_final_output(self, result, elapsed_time):
|
||||
def _build_final_output(self, result, elapsed_time, final_output):
|
||||
node_outputs = result.get("node_outputs", {})
|
||||
final_output = self._extract_final_output(node_outputs)
|
||||
token_usage = self._aggregate_token_usage(node_outputs)
|
||||
conversation_id = None
|
||||
for node_id, node_output in node_outputs.items():
|
||||
@@ -161,6 +157,21 @@ class WorkflowExecutor:
|
||||
"error": result.get("error"),
|
||||
}
|
||||
|
||||
def _update_end_activate(self, node_id):
|
||||
for node in self.end_outputs.keys():
|
||||
self.end_outputs[node].update_activate(node_id)
|
||||
if self.end_outputs[node].activate and self.activate_end is None:
|
||||
self.activate_end = node
|
||||
|
||||
@staticmethod
|
||||
def _trans_output_string(content):
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
elif isinstance(content, list):
|
||||
return "\n".join(content)
|
||||
else:
|
||||
return str(content)
|
||||
|
||||
def build_graph(self, stream=False) -> CompiledStateGraph:
|
||||
"""构建 LangGraph
|
||||
|
||||
@@ -173,6 +184,7 @@ class WorkflowExecutor:
|
||||
stream=stream,
|
||||
)
|
||||
self.start_node_id = builder.start_node_id
|
||||
self.end_outputs = builder.end_node_map
|
||||
graph = builder.build()
|
||||
logger.info(f"工作流图构建完成: execution_id={self.execution_id}")
|
||||
|
||||
@@ -205,14 +217,34 @@ class WorkflowExecutor:
|
||||
try:
|
||||
|
||||
result = await graph.ainvoke(initial_state, config=self.checkpoint_config)
|
||||
|
||||
full_content = ''
|
||||
for end_info in self.end_outputs.values():
|
||||
output_template = "".join([output.literal for output in end_info.outputs])
|
||||
full_content += render_template(
|
||||
output_template,
|
||||
result.get("variables", {}),
|
||||
result.get("runtime_vars", {}),
|
||||
strict=False
|
||||
)
|
||||
result["messages"].extend(
|
||||
[
|
||||
{
|
||||
"role": "user",
|
||||
"content": input_data.get("message", '')
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": full_content
|
||||
}
|
||||
]
|
||||
)
|
||||
# 计算耗时
|
||||
end_time = datetime.datetime.now()
|
||||
elapsed_time = (end_time - start_time).total_seconds()
|
||||
|
||||
logger.info(f"工作流执行完成: execution_id={self.execution_id}, elapsed_time={elapsed_time:.2f}s")
|
||||
|
||||
return self._build_final_output(result, elapsed_time)
|
||||
return self._build_final_output(result, elapsed_time, full_content)
|
||||
|
||||
except Exception as e:
|
||||
# 计算耗时(即使失败也记录)
|
||||
@@ -273,6 +305,7 @@ class WorkflowExecutor:
|
||||
# 3. Execute workflow
|
||||
try:
|
||||
chunk_count = 0
|
||||
full_content = ''
|
||||
|
||||
async for event in graph.astream(
|
||||
initial_state,
|
||||
@@ -293,21 +326,27 @@ class WorkflowExecutor:
|
||||
# Handle custom streaming events (chunks from nodes via stream writer)
|
||||
chunk_count += 1
|
||||
event_type = data.get("type", "node_chunk") # "message" or "node_chunk"
|
||||
if event_type in ("message", "node_chunk"):
|
||||
if event_type == "node_chunk":
|
||||
node_id = data.get("node_id")
|
||||
if self.activate_end:
|
||||
end_info = self.end_outputs.get(self.activate_end)
|
||||
if not end_info or end_info.cursor >= len(end_info.outputs):
|
||||
continue
|
||||
current_output = end_info.outputs[end_info.cursor]
|
||||
if current_output.is_variable and current_output.depends_on_node(node_id):
|
||||
if data.get("done"):
|
||||
end_info.cursor += 1
|
||||
else:
|
||||
full_content += data.get("chunk")
|
||||
yield {
|
||||
"event": "message",
|
||||
"data": {
|
||||
"chunk": data.get("chunk")
|
||||
}
|
||||
}
|
||||
logger.info(f"[CUSTOM] ✅ 收到 {event_type} #{chunk_count} from {data.get('node_id')}"
|
||||
f"- execution_id: {self.execution_id}")
|
||||
yield {
|
||||
"event": event_type, # "message" or "node_chunk"
|
||||
"data": {
|
||||
"node_id": data.get("node_id"),
|
||||
"chunk": data.get("chunk"),
|
||||
"full_content": data.get("full_content"),
|
||||
"chunk_index": data.get("chunk_index"),
|
||||
"is_prefix": data.get("is_prefix"),
|
||||
"is_suffix": data.get("is_suffix"),
|
||||
"conversation_id": input_data.get("conversation_id"),
|
||||
}
|
||||
}
|
||||
|
||||
elif event_type == "node_error":
|
||||
yield {
|
||||
"event": event_type, # "message" or "node_chunk"
|
||||
@@ -376,14 +415,109 @@ class WorkflowExecutor:
|
||||
|
||||
elif mode == "updates":
|
||||
# Handle state updates - store final state
|
||||
# TODO:流式输出点
|
||||
for node_id in data.keys():
|
||||
self._update_end_activate(node_id)
|
||||
wait = False
|
||||
state = graph.get_state(config=self.checkpoint_config)
|
||||
node_outputs = state.values.get("runtime_vars", {})
|
||||
for _ in data.keys():
|
||||
node_outputs = node_outputs | data.get(_).get("runtime_vars", {})
|
||||
|
||||
while self.activate_end and not wait:
|
||||
message = ''
|
||||
logger.info(self.activate_end)
|
||||
end_info = self.end_outputs[self.activate_end]
|
||||
content = end_info.outputs[end_info.cursor]
|
||||
while content.activate:
|
||||
if not content.is_variable:
|
||||
full_content += content.literal
|
||||
message += content.literal
|
||||
else:
|
||||
try:
|
||||
chunk = evaluate_expression(
|
||||
content.literal,
|
||||
variables={},
|
||||
node_outputs=node_outputs
|
||||
)
|
||||
chunk = self._trans_output_string(chunk)
|
||||
message += chunk
|
||||
full_content += chunk
|
||||
except ValueError:
|
||||
pass
|
||||
end_info.cursor += 1
|
||||
if end_info.cursor == len(end_info.outputs):
|
||||
break
|
||||
content = end_info.outputs[end_info.cursor]
|
||||
if end_info.cursor != len(end_info.outputs):
|
||||
wait = True
|
||||
else:
|
||||
self.end_outputs.pop(self.activate_end)
|
||||
self.activate_end = None
|
||||
for node_id in data.keys():
|
||||
self._update_end_activate(node_id)
|
||||
if message:
|
||||
yield {
|
||||
"event": "message",
|
||||
"data": {
|
||||
"chunk": message
|
||||
}
|
||||
}
|
||||
|
||||
logger.debug(f"[UPDATES] 收到 state 更新 from {list(data.keys())} "
|
||||
f"- execution_id: {self.execution_id}")
|
||||
|
||||
result = graph.get_state(self.checkpoint_config).values
|
||||
while self.activate_end:
|
||||
message = ''
|
||||
end_info = self.end_outputs[self.activate_end]
|
||||
content = end_info.outputs[end_info.cursor]
|
||||
if not content.is_variable:
|
||||
message += content.literal
|
||||
else:
|
||||
node_outputs = result.get("runtime_vars", {})
|
||||
variables = result.get("variables", {})
|
||||
try:
|
||||
chunk = evaluate_expression(
|
||||
content.literal,
|
||||
variables=variables,
|
||||
node_outputs=node_outputs
|
||||
)
|
||||
chunk = self._trans_output_string(chunk)
|
||||
message += chunk
|
||||
full_content += chunk
|
||||
except ValueError:
|
||||
pass
|
||||
end_info.cursor += 1
|
||||
if end_info.cursor == len(end_info.outputs):
|
||||
self.end_outputs.pop(self.activate_end)
|
||||
self.activate_end = None
|
||||
if self.end_outputs:
|
||||
self.activate_end = list(self.end_outputs.keys())[0]
|
||||
if message:
|
||||
yield {
|
||||
"event": "message",
|
||||
"data": {
|
||||
"chunk": message
|
||||
}
|
||||
}
|
||||
|
||||
# 计算耗时
|
||||
end_time = datetime.datetime.now()
|
||||
elapsed_time = (end_time - start_time).total_seconds()
|
||||
result = graph.get_state(self.checkpoint_config).values
|
||||
logger.info(result)
|
||||
result["messages"].extend(
|
||||
[
|
||||
{
|
||||
"role": "user",
|
||||
"content": input_data.get("message", '')
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": full_content
|
||||
}
|
||||
]
|
||||
)
|
||||
logger.info(
|
||||
f"Workflow execution completed (streaming), "
|
||||
f"total chunks: {chunk_count}, elapsed: {elapsed_time:.2f}s, execution_id: {self.execution_id}"
|
||||
@@ -392,7 +526,7 @@ class WorkflowExecutor:
|
||||
# 发送 workflow_end 事件
|
||||
yield {
|
||||
"event": "workflow_end",
|
||||
"data": self._build_final_output(result, elapsed_time)
|
||||
"data": self._build_final_output(result, elapsed_time, full_content)
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
@@ -414,31 +548,6 @@ class WorkflowExecutor:
|
||||
}
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _extract_final_output(node_outputs: dict[str, Any]) -> str | None:
|
||||
"""从节点输出中提取最终输出
|
||||
|
||||
优先级:
|
||||
1. 最后一个执行的非 start/end 节点的 output
|
||||
2. 如果没有节点输出,返回 None
|
||||
|
||||
Args:
|
||||
node_outputs: 所有节点的输出
|
||||
|
||||
Returns:
|
||||
最终输出字符串或 None
|
||||
"""
|
||||
if not node_outputs:
|
||||
return None
|
||||
|
||||
# 获取最后一个节点的输出
|
||||
last_node_output = list(node_outputs.values())[-1] if node_outputs else None
|
||||
|
||||
if last_node_output and isinstance(last_node_output, dict):
|
||||
return last_node_output.get("output")
|
||||
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _aggregate_token_usage(node_outputs: dict[str, Any]) -> dict[str, int] | None:
|
||||
"""聚合所有节点的 token 使用情况
|
||||
@@ -529,178 +638,3 @@ async def execute_workflow_stream(
|
||||
)
|
||||
async for event in executor.execute_stream(input_data):
|
||||
yield event
|
||||
|
||||
# ==================== 工具管理系统集成 ====================
|
||||
|
||||
# def get_workflow_tools(workspace_id: str, user_id: str) -> list:
|
||||
# """获取工作流可用的工具列表
|
||||
#
|
||||
# Args:
|
||||
# workspace_id: 工作空间ID
|
||||
# user_id: 用户ID
|
||||
#
|
||||
# Returns:
|
||||
# 可用工具列表
|
||||
# """
|
||||
# if not TOOL_MANAGEMENT_AVAILABLE:
|
||||
# logger.warning("工具管理系统不可用")
|
||||
# return []
|
||||
#
|
||||
# try:
|
||||
# db = next(get_db())
|
||||
#
|
||||
# # 创建工具注册表
|
||||
# registry = ToolRegistry(db)
|
||||
#
|
||||
# # 注册内置工具类
|
||||
# from app.core.tools.builtin import (
|
||||
# DateTimeTool, JsonTool, BaiduSearchTool, MinerUTool, TextInTool
|
||||
# )
|
||||
# registry.register_tool_class(DateTimeTool)
|
||||
# registry.register_tool_class(JsonTool)
|
||||
# registry.register_tool_class(BaiduSearchTool)
|
||||
# registry.register_tool_class(MinerUTool)
|
||||
# registry.register_tool_class(TextInTool)
|
||||
#
|
||||
# # 获取活跃的工具
|
||||
# import uuid
|
||||
# tools = registry.list_tools(workspace_id=uuid.UUID(workspace_id))
|
||||
# active_tools = [tool for tool in tools if tool.status.value == "active"]
|
||||
#
|
||||
# # 转换为Langchain工具
|
||||
# langchain_tools = []
|
||||
# for tool_info in active_tools:
|
||||
# try:
|
||||
# tool_instance = registry.get_tool(tool_info.id)
|
||||
# if tool_instance:
|
||||
# langchain_tool = LangchainAdapter.convert_tool(tool_instance)
|
||||
# langchain_tools.append(langchain_tool)
|
||||
# except Exception as e:
|
||||
# logger.error(f"转换工具失败: {tool_info.name}, 错误: {e}")
|
||||
#
|
||||
# logger.info(f"为工作流获取了 {len(langchain_tools)} 个工具")
|
||||
# return langchain_tools
|
||||
#
|
||||
# except Exception as e:
|
||||
# logger.error(f"获取工作流工具失败: {e}")
|
||||
# return []
|
||||
#
|
||||
#
|
||||
# class ToolWorkflowNode:
|
||||
# """工具工作流节点 - 在工作流中执行工具"""
|
||||
#
|
||||
# def __init__(self, node_config: dict, workflow_config: dict):
|
||||
# """初始化工具节点
|
||||
#
|
||||
# Args:
|
||||
# node_config: 节点配置
|
||||
# workflow_config: 工作流配置
|
||||
# """
|
||||
# self.node_config = node_config
|
||||
# self.workflow_config = workflow_config
|
||||
# self.tool_id = node_config.get("tool_id")
|
||||
# self.tool_parameters = node_config.get("parameters", {})
|
||||
#
|
||||
# async def run(self, state: WorkflowState) -> WorkflowState:
|
||||
# """执行工具节点"""
|
||||
# if not TOOL_MANAGEMENT_AVAILABLE:
|
||||
# logger.error("工具管理系统不可用")
|
||||
# state["error"] = "工具管理系统不可用"
|
||||
# return state
|
||||
#
|
||||
# try:
|
||||
# from sqlalchemy.orm import Session
|
||||
# db = next(get_db())
|
||||
#
|
||||
# # 创建工具执行器
|
||||
# registry = ToolRegistry(db)
|
||||
# executor = ToolExecutor(db, registry)
|
||||
#
|
||||
# # 准备参数(支持变量替换)
|
||||
# parameters = self._prepare_parameters(state)
|
||||
#
|
||||
# # 执行工具
|
||||
# result = await executor.execute_tool(
|
||||
# tool_id=self.tool_id,
|
||||
# parameters=parameters,
|
||||
# user_id=uuid.UUID(state["user_id"]),
|
||||
# workspace_id=uuid.UUID(state["workspace_id"])
|
||||
# )
|
||||
#
|
||||
# # 更新状态
|
||||
# node_id = self.node_config.get("id")
|
||||
# if result.success:
|
||||
# state["node_outputs"][node_id] = {
|
||||
# "type": "tool",
|
||||
# "tool_id": self.tool_id,
|
||||
# "output": result.data,
|
||||
# "execution_time": result.execution_time,
|
||||
# "token_usage": result.token_usage
|
||||
# }
|
||||
#
|
||||
# # 更新运行时变量
|
||||
# if isinstance(result.data, dict):
|
||||
# for key, value in result.data.items():
|
||||
# state["runtime_vars"][f"{node_id}.{key}"] = value
|
||||
# else:
|
||||
# state["runtime_vars"][f"{node_id}.result"] = result.data
|
||||
# else:
|
||||
# state["error"] = result.error
|
||||
# state["error_node"] = node_id
|
||||
# state["node_outputs"][node_id] = {
|
||||
# "type": "tool",
|
||||
# "tool_id": self.tool_id,
|
||||
# "error": result.error,
|
||||
# "execution_time": result.execution_time
|
||||
# }
|
||||
#
|
||||
# return state
|
||||
#
|
||||
# except Exception as e:
|
||||
# logger.error(f"工具节点执行失败: {e}")
|
||||
# state["error"] = str(e)
|
||||
# state["error_node"] = self.node_config.get("id")
|
||||
# return state
|
||||
#
|
||||
# def _prepare_parameters(self, state: WorkflowState) -> dict:
|
||||
# """准备工具参数(支持变量替换)"""
|
||||
# parameters = {}
|
||||
#
|
||||
# for key, value in self.tool_parameters.items():
|
||||
# if isinstance(value, str) and value.startswith("${") and value.endswith("}"):
|
||||
# # 变量替换
|
||||
# var_path = value[2:-1]
|
||||
#
|
||||
# # 支持多层级变量访问,如 ${sys.message} 或 ${node1.result}
|
||||
# if "." in var_path:
|
||||
# parts = var_path.split(".")
|
||||
# current = state.get("variables", {})
|
||||
#
|
||||
# for part in parts:
|
||||
# if isinstance(current, dict) and part in current:
|
||||
# current = current[part]
|
||||
# else:
|
||||
# # 尝试从运行时变量获取
|
||||
# runtime_key = ".".join(parts)
|
||||
# current = state.get("runtime_vars", {}).get(runtime_key, value)
|
||||
# break
|
||||
#
|
||||
# parameters[key] = current
|
||||
# else:
|
||||
# # 简单变量
|
||||
# variables = state.get("variables", {})
|
||||
# parameters[key] = variables.get(var_path, value)
|
||||
# else:
|
||||
# parameters[key] = value
|
||||
#
|
||||
# return parameters
|
||||
#
|
||||
#
|
||||
# # 注册工具节点到NodeFactory(如果存在)
|
||||
# try:
|
||||
# from app.core.workflow.nodes import NodeFactory
|
||||
# if hasattr(NodeFactory, 'register_node_type'):
|
||||
# NodeFactory.register_node_type("tool", ToolWorkflowNode)
|
||||
# logger.info("工具节点已注册到工作流系统")
|
||||
# except Exception as e:
|
||||
# logger.warning(f"注册工具节点失败: {e}")
|
||||
|
||||
@@ -1,12 +1,15 @@
|
||||
import logging
|
||||
import re
|
||||
import uuid
|
||||
from collections import defaultdict
|
||||
from functools import lru_cache
|
||||
from typing import Any
|
||||
|
||||
from langgraph.checkpoint.memory import InMemorySaver
|
||||
from langgraph.graph import START, END
|
||||
from langgraph.graph.state import CompiledStateGraph, StateGraph
|
||||
from langgraph.types import Send
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.core.workflow.expression_evaluator import evaluate_condition
|
||||
from app.core.workflow.nodes import WorkflowState, NodeFactory
|
||||
@@ -15,6 +18,153 @@ from app.core.workflow.nodes.enums import NodeType, BRANCH_NODES
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OutputContent(BaseModel):
    """
    One ordered piece of an End node's output template.

    A segment holds either static text or a variable placeholder such as
    {{ node.field }}, together with a flag recording whether it may be
    emitted yet.
    """

    literal: str = Field(
        ...,
        description="Raw output content. Can be literal text or a variable placeholder."
    )

    activate: bool = Field(
        ...,
        description=(
            "Whether this output segment is currently active.\n"
            "- True: allowed to be emitted/output\n"
            "- False: blocked until activated by branch control"
        )
    )

    is_variable: bool = Field(
        ...,
        description=(
            "Whether this segment represents a variable placeholder.\n"
            "True -> variable (e.g. {{ node.field }})\n"
            "False -> literal text"
        )
    )

    def depends_on_node(self, node_id: str) -> bool:
        """
        Tell whether this segment's text references a field of ``node_id``.

        The literal is searched for a placeholder of the form
        {{ node_id.field_name }} (optional whitespace inside the braces).
        The node ID is regex-escaped and must be followed directly by a
        dot, so "node1" does not match a placeholder for "node10".

        Args:
            node_id (str): ID of the node to look for.

        Returns:
            bool: True when such a placeholder occurs in ``literal``,
            False otherwise.

        Example:
            literal = "{{node1.name}}"

            depends_on_node("node1") -> True
            depends_on_node("node2") -> False
        """
        placeholder = rf"\{{\{{\s*{re.escape(node_id)}\.[a-zA-Z0-9_]+\s*\}}\}}"
        return re.search(placeholder, self.literal) is not None
|
||||
|
||||
|
||||
class StreamOutputConfig(BaseModel):
    """
    Streaming state for the output of a single End node.

    Holds the overall activation flag, the branch nodes that still have to
    report before the output may start, the ordered output segments, and a
    cursor into those segments.
    """

    activate: bool = Field(
        ...,
        description=(
            "Global activation state of the End node output.\n"
            "If False, no output should be emitted until all control nodes are resolved."
        )
    )

    control_nodes: list[str] = Field(
        ...,
        description=(
            "List of upstream branch node IDs that control this End node.\n"
            "Each node must signal completion before output becomes active."
        )
    )

    outputs: list[OutputContent] = Field(
        ...,
        description="Ordered list of output segments parsed from the output template."
    )

    cursor: int = Field(
        ...,
        description=(
            "Streaming cursor index.\n"
            "Indicates how many output segments have already been emitted."
        )
    )

    def update_activate(self, node_id):
        """
        Record that ``node_id`` has completed and refresh activation flags.

        Two independent effects:

        1. When ``node_id`` is one of ``control_nodes`` it is removed from
           that list; once the list becomes empty the whole output is
           marked active.
        2. Every variable segment whose literal references ``node_id`` is
           marked active.
        """
        # Effect 1: a controlling branch node has reported in.
        is_controller = node_id in self.control_nodes
        if is_controller:
            self.control_nodes.remove(node_id)
        if is_controller and not self.control_nodes:
            # Last outstanding controller resolved -> output may start.
            self.activate = True

        # Effect 2: unlock variable segments fed by this node.
        for segment in self.outputs:
            if not segment.is_variable:
                continue
            if segment.depends_on_node(node_id):
                segment.activate = True
|
||||
|
||||
|
||||
class GraphBuilder:
|
||||
def __init__(
|
||||
self,
|
||||
@@ -29,6 +179,12 @@ class GraphBuilder:
|
||||
|
||||
self.start_node_id = None
|
||||
self.end_node_ids = []
|
||||
self.node_map = {node["id"]: node for node in self.nodes}
|
||||
self.end_node_map: dict[str, StreamOutputConfig] = {}
|
||||
self._find_upstream_branch_node = lru_cache(
|
||||
maxsize=len(self.nodes) * 2
|
||||
)(self._find_upstream_branch_node)
|
||||
self._analyze_end_node_output()
|
||||
|
||||
self.graph = StateGraph(WorkflowState)
|
||||
self.add_nodes()
|
||||
@@ -43,79 +199,182 @@ class GraphBuilder:
|
||||
def edges(self) -> list[dict[str, Any]]:
|
||||
return self.workflow_config.get("edges", [])
|
||||
|
||||
def _analyze_end_node_prefixes(self) -> tuple[dict[str, str], set[str]]:
|
||||
"""
|
||||
Analyze the prefix configuration for End nodes.
|
||||
def get_node_type(self, node_id: str) -> str:
|
||||
"""Retrieve the type of node given its ID.
|
||||
|
||||
This function scans each End node's output template, identifies
|
||||
references to its direct upstream nodes, and extracts the prefix
|
||||
string appearing before the first reference.
|
||||
Args:
|
||||
node_id (str): The unique identifier of the node.
|
||||
|
||||
Returns:
|
||||
tuple:
|
||||
- dict[str, str]: Mapping from upstream node ID to its End node prefix
|
||||
- set[str]: Set of node IDs that are directly adjacent to End nodes and referenced
|
||||
str: The type of the node.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If no node with the given `node_id` exists.
|
||||
"""
|
||||
import re
|
||||
try:
|
||||
return self.node_map[node_id]["type"]
|
||||
except KeyError:
|
||||
raise RuntimeError(f"Node not found: Id={node_id}")
|
||||
|
||||
prefixes = {}
|
||||
adjacent_and_referenced = set() # Record nodes directly adjacent to End and referenced
|
||||
def _find_upstream_branch_node(self, target_node: str) -> tuple[bool, tuple[str, ...]]:
    """Find the upstream branch nodes that gate ``target_node``.

    Direct predecessors of ``target_node`` are taken from the workflow
    edges. Predecessors whose type is in ``BRANCH_NODES`` are collected
    as-is; every other predecessor is resolved recursively.

    Args:
        target_node (str): ID of the node to analyze.

    Returns:
        tuple[bool, tuple[str, ...]]:
            - has_branch: False when ``target_node`` is a START or
              CYCLE_START node without incoming edges, or when any
              non-branch predecessor resolves to False; True otherwise.
              A node with no incoming edges that is not a start node
              therefore yields ``(True, ())``.
            - branch_nodes: the collected branch node IDs, de-duplicated
              in first-seen order; always empty when has_branch is False.

    Raises:
        RuntimeError: Propagated from ``get_node_type`` when an edge
            references an unknown node ID.
    """
    source_nodes = [
        edge.get("source")
        for edge in self.edges
        if edge.get("target") == target_node
    ]
    if not source_nodes and self.get_node_type(target_node) in [NodeType.START, NodeType.CYCLE_START]:
        return False, tuple()

    branch_nodes = []
    non_branch_nodes = []
    for node_id in source_nodes:
        if self.get_node_type(node_id) in BRANCH_NODES:
            branch_nodes.append(node_id)
        else:
            non_branch_nodes.append(node_id)

    for node_id in non_branch_nodes:
        node_has_branch, nodes = self._find_upstream_branch_node(node_id)
        if not node_has_branch:
            # One unguarded path is enough: the target is not
            # branch-controlled, so any branch nodes found are discarded.
            return False, tuple()
        branch_nodes.extend(nodes)

    # dict.fromkeys de-duplicates while keeping first-seen order, so the
    # result is deterministic (a set would give an arbitrary order).
    return True, tuple(dict.fromkeys(branch_nodes))
|
||||
|
||||
def _analyze_end_node_output(self):
|
||||
"""
|
||||
Analyze output templates of all End nodes and generate StreamOutputConfig.
|
||||
|
||||
This method is responsible for parsing the `output` field of End nodes,
|
||||
splitting literal text and variable placeholders (e.g. {{ node.field }}),
|
||||
and determining whether each output segment should be activated immediately
|
||||
or controlled by upstream branch nodes.
|
||||
|
||||
In stream mode:
|
||||
- If the End node is controlled by any upstream branch node, the output
|
||||
will be initially inactive and controlled by those branch nodes.
|
||||
- Otherwise, the output is activated immediately.
|
||||
|
||||
In non-stream mode:
|
||||
- All outputs are activated by default.
|
||||
"""
|
||||
|
||||
# Collect all End nodes in the workflow
|
||||
end_nodes = [node for node in self.nodes if node.get("type") == "end"]
|
||||
logger.info(f"[Prefix Analysis] Found {len(end_nodes)} End nodes")
|
||||
|
||||
# Iterate through each End node to analyze its output
|
||||
for end_node in end_nodes:
|
||||
end_node_id = end_node.get("id")
|
||||
output_template = end_node.get("config", {}).get("output")
|
||||
config = end_node.get("config", {})
|
||||
output = config.get("output")
|
||||
|
||||
logger.info(f"[Prefix Analysis] End node {end_node_id} template: {output_template}")
|
||||
|
||||
if not output_template:
|
||||
# Skip End nodes without output configuration
|
||||
if not output:
|
||||
continue
|
||||
|
||||
# Find all node references in the template
|
||||
# Matches {{node_id.xxx}} or {{ node_id.xxx }} format (allowing spaces)
|
||||
pattern = r'\{\{\s*([a-zA-Z0-9_-]+)\.[a-zA-Z0-9_]+\s*\}\}'
|
||||
matches = list(re.finditer(pattern, output_template))
|
||||
# Regex to split output into:
|
||||
# - variable placeholders: {{ ... }}
|
||||
# - normal literal text
|
||||
#
|
||||
# Example:
|
||||
# "Hello {{user.name}}!" ->
|
||||
# ["Hello ", "{{user.name}}", "!"]
|
||||
pattern = r'\{\{.*?\}\}|[^{}]+'
|
||||
|
||||
logger.info(f"[Prefix Analysis] 模板中找到 {len(matches)} 个节点引用")
|
||||
# Strict variable format: {{ node_id.field_name }}
|
||||
variable_pattern_string = r'\{\{\s*[a-zA-Z0-9_]+\.[a-zA-Z0-9_]+\s*\}\}'
|
||||
variable_pattern = re.compile(variable_pattern_string)
|
||||
|
||||
# Identify all direct upstream nodes connected to the End node
|
||||
direct_upstream_nodes = []
|
||||
for edge in self.edges:
|
||||
if edge.get("target") == end_node_id:
|
||||
source_node_id = edge.get("source")
|
||||
direct_upstream_nodes.append(source_node_id)
|
||||
# Split output into ordered segments
|
||||
output_template = list(re.findall(pattern, output))
|
||||
|
||||
logger.info(f"[Prefix Analysis] Direct upstream nodes of End node: {direct_upstream_nodes}")
|
||||
# Determine whether each segment is literal text
|
||||
# True -> literal (can be directly output)
|
||||
# False -> variable placeholder (needs runtime value)
|
||||
output_flag = [
|
||||
not bool(variable_pattern.match(item))
|
||||
for item in output_template
|
||||
]
|
||||
|
||||
# 找到第一个直接上游节点的引用
|
||||
for match in matches:
|
||||
referenced_node_id = match.group(1)
|
||||
logger.info(f"[Prefix Analysis] Checking reference: {referenced_node_id}")
|
||||
# Stream mode: output activation depends on upstream branch nodes
|
||||
if self.stream:
|
||||
# Find upstream branch nodes that can control this End node
|
||||
has_branch, control_nodes = self._find_upstream_branch_node(end_node_id)
|
||||
|
||||
if referenced_node_id in direct_upstream_nodes:
|
||||
# 这是直接上游节点的引用,提取前缀
|
||||
prefix = output_template[:match.start()]
|
||||
# Build StreamOutputConfig for this End node
|
||||
self.end_node_map[end_node_id] = StreamOutputConfig(
|
||||
# If there is no upstream branch, output is active immediately
|
||||
activate=not has_branch,
|
||||
|
||||
logger.info(f"[Prefix Analysis] "
|
||||
f"✅ Found reference to direct upstream node {referenced_node_id}, prefix: '{prefix}'")
|
||||
# Branch nodes that control activation of this End node
|
||||
control_nodes=list(control_nodes),
|
||||
|
||||
# 标记这个节点为"相邻且被引用"
|
||||
adjacent_and_referenced.add(referenced_node_id)
|
||||
# Convert output segments into OutputContent objects
|
||||
outputs=list(
|
||||
[
|
||||
OutputContent(
|
||||
literal=output_string,
|
||||
# Literal text can be activated immediately unless blocked by branch
|
||||
activate=activate,
|
||||
# Variable segments are marked explicitly
|
||||
is_variable=not activate
|
||||
)
|
||||
for output_string, activate in zip(output_template, output_flag)
|
||||
]
|
||||
),
|
||||
# Cursor for streaming output (initially 0)
|
||||
cursor=0
|
||||
)
|
||||
logger.info(f"[Stream Analysis] end_id: {end_node_id}, "
|
||||
f"activate: {not has_branch}, "
|
||||
f"control_nodes: {control_nodes},"
|
||||
f"output: {output_template},"
|
||||
f"output_activate: {output_flag}")
|
||||
|
||||
if prefix:
|
||||
prefixes[referenced_node_id] = prefix
|
||||
logger.info(f"[Prefix Analysis] "
|
||||
f"✅ Assign prefix for node {referenced_node_id}: '{prefix[:50]}...'")
|
||||
|
||||
# 只处理第一个直接上游节点的引用
|
||||
break
|
||||
|
||||
logger.info(f"[Prefix Analysis] Final prefixes: {prefixes}")
|
||||
logger.info(f"[Prefix Analysis] Nodes adjacent to End and referenced: {adjacent_and_referenced}")
|
||||
return prefixes, adjacent_and_referenced
|
||||
# Non-stream mode: all outputs are activated by default
|
||||
else:
|
||||
self.end_node_map[end_node_id] = StreamOutputConfig(
|
||||
activate=True,
|
||||
control_nodes=[],
|
||||
outputs=list(
|
||||
[
|
||||
OutputContent(
|
||||
literal=output_string,
|
||||
activate=True,
|
||||
is_variable=not activate
|
||||
)
|
||||
for output_string, activate in zip(output_template, output_flag)
|
||||
]
|
||||
),
|
||||
cursor=0
|
||||
)
|
||||
|
||||
def add_nodes(self):
|
||||
"""Add all nodes from the workflow configuration to the state graph.
|
||||
@@ -135,9 +394,6 @@ class GraphBuilder:
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
# Analyze End node prefixes if in stream mode
|
||||
end_prefixes, adjacent_and_referenced = self._analyze_end_node_prefixes() if self.stream else ({}, set())
|
||||
|
||||
for node in self.nodes:
|
||||
node_type = node.get("type")
|
||||
node_id = node.get("id")
|
||||
@@ -171,17 +427,6 @@ class GraphBuilder:
|
||||
related_edge[idx]['condition'] = f"node.{node_id}.output == '{related_edge[idx]['label']}'"
|
||||
|
||||
if node_instance:
|
||||
# Inject End node prefix configuration if in stream mode
|
||||
if self.stream and node_id in end_prefixes:
|
||||
node_instance._end_node_prefix = end_prefixes[node_id]
|
||||
logger.info(f"Injected End prefix for node {node_id}")
|
||||
|
||||
# Mark nodes as adjacent and referenced to End node in stream mode
|
||||
if self.stream:
|
||||
node_instance._is_adjacent_to_end = node_id in adjacent_and_referenced
|
||||
if node_id in adjacent_and_referenced:
|
||||
logger.info(f"Node {node_id} marked as adjacent and referenced to End node")
|
||||
|
||||
# Wrap node's run method to avoid closure issues
|
||||
if self.stream:
|
||||
# Stream mode: create an async generator function
|
||||
@@ -261,6 +506,7 @@ class GraphBuilder:
|
||||
for source_node, branches in conditional_edges.items():
|
||||
def make_router(src, branch_list):
|
||||
"""Create a router function for each source node that routes to a NOP node for later merging."""
|
||||
|
||||
def make_branch_node(node_name, targets):
|
||||
def node(s):
|
||||
# NOTE: NOP NODE MUST NOT MODIFY STATE
|
||||
|
||||
@@ -67,10 +67,6 @@ class WorkflowState(TypedDict):
|
||||
error: str | None
|
||||
error_node: str | None
|
||||
|
||||
# Streaming buffer (stores real-time streaming output of nodes)
|
||||
# Format: {node_id: {"chunks": [...], "full_content": "..."}}
|
||||
streaming_buffer: Annotated[dict[str, Any], lambda x, y: {**x, **y}]
|
||||
|
||||
# node activate status
|
||||
activate: Annotated[dict[str, bool], merge_activate_state]
|
||||
|
||||
@@ -300,7 +296,7 @@ class BaseNode(ABC):
|
||||
"""
|
||||
if not self.check_activate(state):
|
||||
yield self.trans_activate(state)
|
||||
logger.info(f"跳过节点{self.node_id}")
|
||||
logger.info(f"jump node: {self.node_id}")
|
||||
return
|
||||
|
||||
import time
|
||||
@@ -313,19 +309,6 @@ class BaseNode(ABC):
|
||||
# Get LangGraph's stream writer for sending custom data
|
||||
writer = get_stream_writer()
|
||||
|
||||
# Check if this is an End node
|
||||
# End nodes CAN send chunks (for suffix), but only after LLM content
|
||||
is_end_node = self.node_type == "end"
|
||||
|
||||
# Check if this node is adjacent to End node (for message type)
|
||||
is_adjacent_to_end = getattr(self, '_is_adjacent_to_end', False)
|
||||
|
||||
# Determine chunk type: "message" for End and adjacent nodes, "node_chunk" for others
|
||||
chunk_type = "message" if (is_end_node or is_adjacent_to_end) else "node_chunk"
|
||||
|
||||
logger.debug(
|
||||
f"节点 {self.node_id} chunk 类型: {chunk_type} (is_end={is_end_node}, adjacent={is_adjacent_to_end})")
|
||||
|
||||
# Accumulate complete result (for final wrapping)
|
||||
chunks = []
|
||||
final_result = None
|
||||
@@ -340,66 +323,25 @@ class BaseNode(ABC):
|
||||
raise TimeoutError()
|
||||
|
||||
# Check if it's a completion marker
|
||||
if isinstance(item, dict) and item.get("__final__"):
|
||||
if item.get("__final__"):
|
||||
final_result = item["result"]
|
||||
elif isinstance(item, str):
|
||||
# String is a chunk
|
||||
else:
|
||||
chunk_count += 1
|
||||
chunks.append(item)
|
||||
full_content = "".join(chunks)
|
||||
content = str(item.get("chunk"))
|
||||
done = item.get("done", False)
|
||||
chunks.append(content)
|
||||
|
||||
# Send chunks for all nodes (including End nodes for suffix)
|
||||
logger.debug(f"节点 {self.node_id} 发送 chunk #{chunk_count}: {item[:50]}...")
|
||||
logger.debug(f"节点 {self.node_id} 发送 chunk #{chunk_count}: {content[:50]}...")
|
||||
|
||||
# 1. Send via stream writer (for real-time client updates)
|
||||
writer({
|
||||
"type": chunk_type, # "message" or "node_chunk"
|
||||
"type": "node_chunk",
|
||||
"node_id": self.node_id,
|
||||
"chunk": item,
|
||||
"full_content": full_content,
|
||||
"chunk_index": chunk_count
|
||||
"chunk": content,
|
||||
"done": done
|
||||
})
|
||||
|
||||
# 2. Update streaming buffer in state (for downstream nodes)
|
||||
# Only non-End nodes need streaming buffer
|
||||
if not is_end_node:
|
||||
yield {
|
||||
"streaming_buffer": {
|
||||
self.node_id: {
|
||||
"full_content": full_content,
|
||||
"chunk_count": chunk_count,
|
||||
"is_complete": False
|
||||
}
|
||||
}
|
||||
}
|
||||
else:
|
||||
# Other types are also treated as chunks
|
||||
chunk_count += 1
|
||||
chunk_str = str(item)
|
||||
chunks.append(chunk_str)
|
||||
full_content = "".join(chunks)
|
||||
|
||||
# Send chunks for all nodes
|
||||
writer({
|
||||
"type": chunk_type, # "message" or "node_chunk"
|
||||
"node_id": self.node_id,
|
||||
"chunk": chunk_str,
|
||||
"full_content": full_content,
|
||||
"chunk_index": chunk_count
|
||||
})
|
||||
|
||||
# Only non-End nodes need streaming buffer
|
||||
if not is_end_node:
|
||||
yield {
|
||||
"streaming_buffer": {
|
||||
self.node_id: {
|
||||
"full_content": full_content,
|
||||
"chunk_count": chunk_count,
|
||||
"is_complete": False
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
elapsed_time = time.time() - start_time
|
||||
|
||||
logger.info(f"节点 {self.node_id} 流式执行完成,耗时: {elapsed_time:.2f}s, chunks: {chunk_count}")
|
||||
@@ -426,16 +368,6 @@ class BaseNode(ABC):
|
||||
"looping": state["looping"]
|
||||
}
|
||||
|
||||
# Add streaming buffer for non-End nodes
|
||||
if not is_end_node:
|
||||
state_update["streaming_buffer"] = {
|
||||
self.node_id: {
|
||||
"full_content": "".join(chunks),
|
||||
"chunk_count": chunk_count,
|
||||
"is_complete": True # Mark as complete
|
||||
}
|
||||
}
|
||||
|
||||
# Finally yield state update
|
||||
# LangGraph will merge this into state
|
||||
yield state_update | self.trans_activate(state)
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
from app.core.workflow.nodes.code.node import CodeNode
|
||||
|
||||
__all__ = ["CodeNode"]
|
||||
__all__ = ["CodeNode"]
|
||||
|
||||
@@ -7,7 +7,6 @@ from textwrap import dedent
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
from sympy.physics.vector import vlatex
|
||||
|
||||
from app.core.workflow.nodes import BaseNode, WorkflowState
|
||||
from app.core.workflow.nodes.base_config import VariableType
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import asyncio
|
||||
import copy
|
||||
import logging
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
@@ -6,7 +6,6 @@ from langgraph.graph.state import CompiledStateGraph
|
||||
|
||||
from app.core.workflow.nodes import WorkflowState
|
||||
from app.core.workflow.nodes.base_node import BaseNode
|
||||
from app.core.workflow.nodes.cycle_graph.config import LoopNodeConfig, IterationNodeConfig
|
||||
from app.core.workflow.nodes.cycle_graph.iteration import IterationRuntime
|
||||
from app.core.workflow.nodes.cycle_graph.loop import LoopRuntime
|
||||
from app.core.workflow.nodes.enums import NodeType
|
||||
|
||||
@@ -5,10 +5,8 @@ End 节点实现
|
||||
"""
|
||||
|
||||
import logging
|
||||
import re
|
||||
|
||||
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
|
||||
from app.core.workflow.nodes.enums import NodeType
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -37,24 +35,8 @@ class EndNode(BaseNode):
|
||||
# 如果配置了输出模板,使用模板渲染;否则使用默认输出
|
||||
if output_template:
|
||||
output = self._render_template(output_template, state, strict=False)
|
||||
state['messages'].extend([
|
||||
{
|
||||
"role": "user",
|
||||
"content": self.get_variable("sys.message", state)
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": output
|
||||
}
|
||||
])
|
||||
else:
|
||||
state['messages'].extend([
|
||||
{
|
||||
"role": "user",
|
||||
"content": self.get_variable("sys.message", state)
|
||||
},
|
||||
])
|
||||
output = "工作流已完成"
|
||||
output = ""
|
||||
|
||||
# 统计信息(用于日志)
|
||||
node_outputs = state.get("node_outputs", {})
|
||||
@@ -63,274 +45,3 @@ class EndNode(BaseNode):
|
||||
logger.info(f"节点 {self.node_id} (End) 执行完成,共执行 {total_nodes} 个节点")
|
||||
|
||||
return output
|
||||
|
||||
def _extract_referenced_nodes(self, template: str) -> list[str]:
|
||||
"""从模板中提取引用的节点 ID
|
||||
|
||||
例如:'结果:{{llm_qa.output}}' -> ['llm_qa']
|
||||
|
||||
Args:
|
||||
template: 模板字符串
|
||||
|
||||
Returns:
|
||||
引用的节点 ID 列表
|
||||
"""
|
||||
# 匹配 {{node_id.xxx}} 格式
|
||||
pattern = r'\{\{([a-zA-Z0-9_]+)\.[a-zA-Z0-9_]+\}\}'
|
||||
matches = re.findall(pattern, template)
|
||||
return list(set(matches)) # 去重
|
||||
|
||||
def _parse_template_parts(self, template: str, state: WorkflowState) -> list[dict]:
|
||||
"""解析模板,分离静态文本和动态引用
|
||||
|
||||
例如:'你好 {{llm.output}}, 这是后缀'
|
||||
返回:[
|
||||
{"type": "static", "content": "你好 "},
|
||||
{"type": "dynamic", "node_id": "llm", "field": "output"},
|
||||
{"type": "static", "content": ", 这是后缀"}
|
||||
]
|
||||
|
||||
Args:
|
||||
template: 模板字符串
|
||||
state: 工作流状态
|
||||
|
||||
Returns:
|
||||
模板部分列表
|
||||
"""
|
||||
import re
|
||||
|
||||
parts = []
|
||||
last_end = 0
|
||||
|
||||
# 匹配 {{xxx}} 或 {{ xxx }} 格式(支持空格)
|
||||
pattern = r'\{\{\s*([^}]+?)\s*\}\}'
|
||||
|
||||
for match in re.finditer(pattern, template):
|
||||
start, end = match.span()
|
||||
|
||||
# 添加前面的静态文本
|
||||
if start > last_end:
|
||||
static_text = template[last_end:start]
|
||||
if static_text:
|
||||
parts.append({"type": "static", "content": static_text})
|
||||
|
||||
# 解析动态引用
|
||||
ref = match.group(1).strip()
|
||||
|
||||
# 检查是否是节点引用(如 llm.output 或 llm_qa.output)
|
||||
if '.' in ref:
|
||||
node_id, field = ref.split('.', 1)
|
||||
parts.append({
|
||||
"type": "dynamic",
|
||||
"node_id": node_id,
|
||||
"field": field,
|
||||
"raw": ref
|
||||
})
|
||||
else:
|
||||
# 其他引用(如 {{var.xxx}}),当作静态处理
|
||||
# 直接渲染这部分
|
||||
rendered = self._render_template(f"{{{{{ref}}}}}", state)
|
||||
parts.append({"type": "static", "content": rendered})
|
||||
|
||||
last_end = end
|
||||
|
||||
# 添加最后的静态文本
|
||||
if last_end < len(template):
|
||||
static_text = template[last_end:]
|
||||
if static_text:
|
||||
parts.append({"type": "static", "content": static_text})
|
||||
|
||||
return parts
|
||||
|
||||
async def execute_stream(self, state: WorkflowState):
    """Execute the End node business logic in streaming mode.

    Output strategy:
    1. No output template configured: send one "message" event whose
       ``full_content`` is the default completion text and finish.
    2. The template does not reference a direct upstream LLM node: render
       the whole template and send it as a single "message" chunk.
    3. The template references a direct upstream LLM node: send only the
       part of the template AFTER the first such reference (the suffix) as
       a single chunk. The final result still carries the fully rendered
       template.

    Example: '{{start.test}}hahaha {{ llm_qa.output }} lalalalala a'
    - llm_qa is a direct upstream LLM node
    - this method sends only ' lalalalala a' (the suffix)

    Side effects: appends the user message (and, when a template is
    configured, the assistant output) to ``state['messages']`` in place.

    Args:
        state: Workflow state.

    Yields:
        A single completion marker ``{"__final__": True, "result": <str>}``.
    """
    from langgraph.config import get_stream_writer

    logger.info(f"节点 {self.node_id} (End) 开始执行(流式)")

    # Output template configured on this End node.
    output_template = self.config.get("output")

    if not output_template:
        output = "工作流已完成"
        writer = get_stream_writer()
        writer({
            "type": "message",  # End node output uses message type
            "node_id": self.node_id,
            "chunk": "",
            "full_content": output,
            "chunk_index": 1,
            "is_suffix": False
        })
        state['messages'].extend([
            {
                "role": "user",
                "content": self.get_variable("sys.message", state)
            }
        ])
        yield {"__final__": True, "result": output}
        return

    # Find direct upstream LLM nodes. The IDs of all LLM nodes are collected
    # once so that each incoming edge is a set lookup instead of a rescan of
    # the whole node list.
    llm_node_ids = {
        node.get("id")
        for node in self.workflow_config.get("nodes", [])
        if node.get("type") == NodeType.LLM
    }
    direct_upstream_llm_nodes = [
        edge.get("source")
        for edge in self.workflow_config.get("edges", [])
        if edge.get("target") == self.node_id and edge.get("source") in llm_node_ids
    ]

    logger.info(f"节点 {self.node_id} 的直接上游 LLM 节点: {direct_upstream_llm_nodes}")

    # Parse template parts
    parts = self._parse_template_parts(output_template, state)
    logger.info(f"节点 {self.node_id} 解析模板,共 {len(parts)} 个部分")
    for i, part in enumerate(parts):
        logger.info(f"[模板解析] part[{i}]: {part}")

    # Find the first reference to a direct upstream LLM node
    upstream_llm_ref_index = None
    for i, part in enumerate(parts):
        if part["type"] == "dynamic" and part["node_id"] in direct_upstream_llm_nodes:
            upstream_llm_ref_index = i
            logger.info(f"节点 {self.node_id} 找到直接上游 LLM 节点 {part['node_id']} 的引用,索引: {i}")
            break

    if upstream_llm_ref_index is None:
        # No reference to a direct upstream LLM node: output the complete
        # rendered template.
        output = self._render_template(output_template, state, strict=False)
        logger.info(f"节点 {self.node_id} 没有引用直接上游 LLM 节点,输出完整内容: '{output[:50]}...'")

        # Send complete content via writer (as a single message chunk)
        writer = get_stream_writer()
        writer({
            "type": "message",  # End node output uses message type
            "node_id": self.node_id,
            "chunk": output,
            "full_content": output,
            "chunk_index": 1,
            "is_suffix": False
        })
        logger.info(f"节点 {self.node_id} 已通过 writer 发送完整内容")

        state['messages'].extend([
            {
                "role": "user",
                "content": self.get_variable("sys.message", state)
            },
            {
                "role": "assistant",
                "content": output
            }
        ])

        # Yield completion marker
        yield {"__final__": True, "result": output}
        return

    # The template references a direct upstream LLM node: only the part
    # after that reference (the suffix) is sent from here.
    logger.info(
        f"节点 {self.node_id} 检测到直接上游 LLM 节点引用,只输出后缀部分(从索引 {upstream_llm_ref_index + 1} 开始)")

    # Collect suffix parts
    suffix_start = upstream_llm_ref_index + 1
    suffix_parts = []
    logger.info(f"[后缀调试] 开始收集后缀,从索引 {upstream_llm_ref_index + 1} 到 {len(parts) - 1}")
    for i, part in enumerate(parts[suffix_start:], start=suffix_start):
        logger.info(f"[后缀调试] 处理 part[{i}]: {part}")
        if part["type"] == "static":
            # Literal text
            logger.info(f"[后缀调试] 添加静态文本: '{part['content']}'")
            suffix_parts.append(part["content"])

        elif part["type"] == "dynamic":
            # Further dynamic references after the LLM reference
            node_id = part["node_id"]
            field = part["field"]

            # Resolve the value through the variable pool
            pool = self.get_variable_pool(state)
            try:
                # Fall back to an empty string when the variable is missing
                content = pool.get([node_id, field], default="")
                logger.info(f"[后缀调试] 获取变量 {node_id}.{field} 成功: '{content}'")
            except Exception as e:
                # Best effort: a failing lookup must not abort the End node
                logger.warning(f"[后缀调试] 获取变量 {node_id}.{field} 失败: {e}")
                content = ""

            # Convert to string if not None
            suffix_parts.append(str(content) if content is not None else "")

    # Join the suffix
    suffix = "".join(suffix_parts)

    # Full output used as the returned result (prefix + dynamic content + suffix)
    full_output = self._render_template(output_template, state, strict=False)

    state['messages'].extend([
        {
            "role": "user",
            "content": self.get_variable("sys.message", state)
        },
        {
            "role": "assistant",
            "content": full_output
        }
    ])

    logger.info(f"[后缀调试] 节点 {self.node_id} 后缀部分数量: {len(suffix_parts)}")
    logger.info(f"[后缀调试] 后缀内容: '{suffix}'")
    logger.info(f"[后缀调试] 后缀长度: {len(suffix)}")
    logger.info(f"[后缀调试] 后缀是否为空: {not suffix}")

    if suffix:
        logger.info(f"节点 {self.node_id} 输出后缀: '{suffix}...' (长度: {len(suffix)})")
        # The suffix is sent in one piece directly through the writer rather
        # than yielded as a chunk to the caller.
        writer = get_stream_writer()
        writer({
            "type": "message",  # End node output uses message type
            "node_id": self.node_id,
            "chunk": suffix,
            "full_content": full_output,  # complete rendered result (prefix + LLM + suffix)
            "chunk_index": 1,
            "is_suffix": True
        })
        logger.info(f"节点 {self.node_id} 已通过 writer 发送后缀,full_content 长度: {len(full_output)}")
    else:
        logger.warning(f"[后缀调试] 节点 {self.node_id} 后缀为空,不发送!"
                       f"upstream_llm_ref_index={upstream_llm_ref_index}, parts数量={len(parts)}")

    # Statistics (for logging only)
    node_outputs = state.get("node_outputs", {})
    total_nodes = len(node_outputs)

    logger.info(f"节点 {self.node_id} (End) 执行完成(流式),共执行了 {total_nodes} 个节点")

    # Yield completion marker (carries the complete output)
    yield {"__final__": True, "result": full_output}
|
||||
|
||||
@@ -13,7 +13,7 @@ logger = logging.getLogger(__name__)
|
||||
class IfElseNode(BaseNode):
|
||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
||||
super().__init__(node_config, workflow_config)
|
||||
self.typed_config: IfElseNodeConfig | None= None
|
||||
self.typed_config: IfElseNodeConfig | None = None
|
||||
|
||||
@staticmethod
|
||||
def _evaluate(operator, instance: CompareOperatorInstance) -> Any:
|
||||
|
||||
@@ -7,18 +7,18 @@ LLM 节点实现
|
||||
import logging
|
||||
import re
|
||||
from typing import Any
|
||||
from langchain_core.messages import AIMessage, SystemMessage, HumanMessage
|
||||
|
||||
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
|
||||
from langchain_core.messages import AIMessage
|
||||
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.exceptions import BusinessException
|
||||
from app.core.models import RedBearLLM, RedBearModelConfig
|
||||
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
|
||||
from app.core.workflow.nodes.llm.config import LLMNodeConfig
|
||||
from app.db import get_db_context
|
||||
from app.models import ModelType
|
||||
from app.services.model_service import ModelConfigService
|
||||
|
||||
from app.core.exceptions import BusinessException
|
||||
from app.core.error_codes import BizCode
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -231,42 +231,14 @@ class LLMNode(BaseNode):
|
||||
文本片段(chunk)或完成标记
|
||||
"""
|
||||
self.typed_config = LLMNodeConfig(**self.config)
|
||||
from langgraph.config import get_stream_writer
|
||||
|
||||
llm, prompt_or_messages = self._prepare_llm(state, True)
|
||||
|
||||
logger.info(f"节点 {self.node_id} 开始执行 LLM 调用(流式)")
|
||||
logger.debug(f"LLM 配置: streaming={getattr(llm._model, 'streaming', 'unknown')}")
|
||||
|
||||
# 检查是否有注入的 End 节点前缀配置
|
||||
writer = get_stream_writer()
|
||||
end_prefix = getattr(self, '_end_node_prefix', None)
|
||||
|
||||
logger.info(f"[LLM前缀] 节点 {self.node_id} 检查前缀配置: {end_prefix is not None}")
|
||||
if end_prefix:
|
||||
logger.info(f"[LLM前缀] 前缀内容: '{end_prefix}'")
|
||||
|
||||
if end_prefix:
|
||||
# 渲染前缀(可能包含其他变量)
|
||||
try:
|
||||
rendered_prefix = self._render_template(end_prefix, state)
|
||||
logger.info(f"节点 {self.node_id} 提前发送 End 节点前缀: '{rendered_prefix[:50]}...'")
|
||||
|
||||
# 提前发送 End 节点的前缀(使用 "message" 类型)
|
||||
writer({
|
||||
"type": "message", # End 相关的内容都是 message 类型
|
||||
"node_id": "end", # 标记为 end 节点的输出
|
||||
"chunk": rendered_prefix,
|
||||
"full_content": rendered_prefix,
|
||||
"chunk_index": 0,
|
||||
"is_prefix": True # 标记这是前缀
|
||||
})
|
||||
except Exception as e:
|
||||
logger.warning(f"渲染/发送 End 节点前缀失败: {e}")
|
||||
|
||||
# 累积完整响应
|
||||
full_response = ""
|
||||
last_chunk = None
|
||||
chunk_count = 0
|
||||
|
||||
# 调用 LLM(流式,支持字符串或消息列表)
|
||||
@@ -284,12 +256,19 @@ class LLMNode(BaseNode):
|
||||
# 只有当内容不为空时才处理
|
||||
if content:
|
||||
full_response += content
|
||||
last_chunk = chunk
|
||||
chunk_count += 1
|
||||
|
||||
# 流式返回每个文本片段
|
||||
yield content
|
||||
yield {
|
||||
"__final__": False,
|
||||
"chunk": content
|
||||
}
|
||||
|
||||
yield {
|
||||
"__final__": False,
|
||||
"chunk": "",
|
||||
"done": True
|
||||
}
|
||||
logger.info(f"节点 {self.node_id} LLM 调用完成,输出长度: {len(full_response)}, 总 chunks: {chunk_count}")
|
||||
|
||||
# 构建完整的 AIMessage(包含元数据)
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
import uuid
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import Field
|
||||
from typing import Literal
|
||||
|
||||
from app.core.workflow.nodes.base_config import BaseNodeConfig
|
||||
|
||||
@@ -12,7 +10,7 @@ class MemoryReadNodeConfig(BaseNodeConfig):
|
||||
...
|
||||
)
|
||||
|
||||
config_id: UUID = Field(
|
||||
config_id: UUID | int = Field(
|
||||
...
|
||||
)
|
||||
|
||||
|
||||
@@ -24,7 +24,7 @@ class MemoryReadNode(BaseNode):
|
||||
return await MemoryAgentService().read_memory(
|
||||
end_user_id=end_user_id,
|
||||
message=self._render_template(self.typed_config.message, state),
|
||||
config_id=str(self.typed_config.config_id),
|
||||
config_id=self.typed_config.config_id,
|
||||
search_switch=self.typed_config.search_switch,
|
||||
history=[],
|
||||
db=db,
|
||||
|
||||
@@ -5,6 +5,7 @@ from pydantic import Field, BaseModel
|
||||
|
||||
from app.core.workflow.nodes.base_config import BaseNodeConfig
|
||||
|
||||
|
||||
class ClassifierConfig(BaseModel):
|
||||
"""分类器节点配置"""
|
||||
|
||||
@@ -13,7 +14,7 @@ class ClassifierConfig(BaseModel):
|
||||
|
||||
class QuestionClassifierNodeConfig(BaseNodeConfig):
|
||||
"""问题分类器节点配置"""
|
||||
|
||||
|
||||
model_id: uuid.UUID = Field(..., description="LLM模型ID")
|
||||
input_variable: str = Field(default="{{sys.message}}", description="输入变量选择器(用户问题)")
|
||||
user_supplement_prompt: Optional[str] = Field(default=None, description="用户补充提示词,额外分类指令")
|
||||
|
||||
@@ -18,30 +18,30 @@ DEFAULT_EMPTY_QUESTION_CASE = f"{DEFAULT_CASE_PREFIX}1"
|
||||
|
||||
class QuestionClassifierNode(BaseNode):
|
||||
"""问题分类器节点"""
|
||||
|
||||
|
||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
||||
super().__init__(node_config, workflow_config)
|
||||
self.typed_config: QuestionClassifierNodeConfig | None = None
|
||||
self.category_to_case_map = {}
|
||||
|
||||
|
||||
def _get_llm_instance(self) -> RedBearLLM:
|
||||
"""获取LLM实例"""
|
||||
with get_db_read() as db:
|
||||
config = ModelConfigService.get_model_by_id(db=db, model_id=self.typed_config.model_id)
|
||||
|
||||
|
||||
if not config:
|
||||
raise BusinessException("配置的模型不存在", BizCode.NOT_FOUND)
|
||||
|
||||
|
||||
if not config.api_keys or len(config.api_keys) == 0:
|
||||
raise BusinessException("模型配置缺少 API Key", BizCode.INVALID_PARAMETER)
|
||||
|
||||
|
||||
api_config = config.api_keys[0]
|
||||
model_name = api_config.model_name
|
||||
provider = api_config.provider
|
||||
api_key = api_config.api_key
|
||||
base_url = api_config.api_base
|
||||
model_type = config.type
|
||||
|
||||
|
||||
return RedBearLLM(
|
||||
RedBearModelConfig(
|
||||
model_name=model_name,
|
||||
@@ -64,7 +64,7 @@ class QuestionClassifierNode(BaseNode):
|
||||
case_tag = f"{DEFAULT_CASE_PREFIX}{idx}"
|
||||
category_map[category_name] = case_tag
|
||||
return category_map
|
||||
|
||||
|
||||
async def execute(self, state: WorkflowState) -> dict:
|
||||
"""执行问题分类"""
|
||||
self.typed_config = QuestionClassifierNodeConfig(**self.config)
|
||||
@@ -74,11 +74,12 @@ class QuestionClassifierNode(BaseNode):
|
||||
categories = self.typed_config.categories or []
|
||||
category_names = [class_item.class_name.strip() for class_item in categories]
|
||||
category_count = len(category_names)
|
||||
|
||||
|
||||
if not question:
|
||||
logger.warning(
|
||||
f"节点 {self.node_id} 未获取到输入问题,使用默认分支"
|
||||
f"(默认分支:{DEFAULT_EMPTY_QUESTION_CASE},分类总数:{category_count})"
|
||||
f"(默认分支:{DEFAULT_EMPTY_QUESTION_CASE}"
|
||||
f"分类总数: {category_count})"
|
||||
)
|
||||
# 若分类列表为空,返回默认unknown分支,否则返回CASE1
|
||||
if category_count > 0:
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from app.core.workflow.nodes.tool.config import ToolNodeConfig
|
||||
from app.core.workflow.nodes.tool.node import ToolNode
|
||||
|
||||
__all__ = ["ToolNode", "ToolNodeConfig"]
|
||||
__all__ = ["ToolNode", "ToolNodeConfig"]
|
||||
|
||||
@@ -16,11 +16,11 @@ TEMPLATE_PATTERN = re.compile(r"\{\{.*?\}\}")
|
||||
|
||||
class ToolNode(BaseNode):
|
||||
"""工具节点"""
|
||||
|
||||
|
||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
||||
super().__init__(node_config, workflow_config)
|
||||
self.typed_config: ToolNodeConfig | None = None
|
||||
|
||||
|
||||
async def execute(self, state: WorkflowState) -> dict[str, Any]:
|
||||
"""执行工具"""
|
||||
self.typed_config = ToolNodeConfig(**self.config)
|
||||
@@ -28,21 +28,21 @@ class ToolNode(BaseNode):
|
||||
tenant_id = self.get_variable("sys.tenant_id", state)
|
||||
user_id = self.get_variable("sys.user_id", state)
|
||||
workspace_id = self.get_variable("sys.workspace_id", state)
|
||||
|
||||
|
||||
# 如果没有租户ID,尝试从工作流ID获取
|
||||
if not tenant_id:
|
||||
if workspace_id:
|
||||
from app.repositories.tool_repository import ToolRepository
|
||||
with get_db_read() as db:
|
||||
tenant_id = ToolRepository.get_tenant_id_by_workspace_id(db, workspace_id)
|
||||
|
||||
|
||||
if not tenant_id:
|
||||
logger.error(f"节点 {self.node_id} 缺少租户ID")
|
||||
return {
|
||||
"success": False,
|
||||
"data": "缺少租户ID"
|
||||
}
|
||||
|
||||
|
||||
# 渲染工具参数
|
||||
rendered_parameters = {}
|
||||
for param_name, param_template in self.typed_config.tool_parameters.items():
|
||||
@@ -55,9 +55,9 @@ class ToolNode(BaseNode):
|
||||
# 非模板参数(数字/布尔/普通字符串)直接保留原值
|
||||
rendered_value = param_template
|
||||
rendered_parameters[param_name] = rendered_value
|
||||
|
||||
|
||||
logger.info(f"节点 {self.node_id} 执行工具 {self.typed_config.tool_id},参数: {rendered_parameters}")
|
||||
|
||||
|
||||
# 执行工具
|
||||
with get_db_read() as db:
|
||||
tool_service = ToolService(db)
|
||||
@@ -79,7 +79,7 @@ class ToolNode(BaseNode):
|
||||
else:
|
||||
logger.error(f"节点 {self.node_id} 工具执行失败: {result.error}")
|
||||
return {
|
||||
"data": result.error if isinstance(result.error, str) else json.dumps(result.error, ensure_ascii=False),
|
||||
"data": result.error if isinstance(result.error, str) else json.dumps(result.error, ensure_ascii=False),
|
||||
"error_code": result.error_code,
|
||||
"execution_time": result.execution_time
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user