fix(workflow): use loose rendering for end-node variables
This commit is contained in:
@@ -35,7 +35,7 @@ class WorkflowState(TypedDict):
|
|||||||
# Uses a deep merge function, supporting nested dict updates (e.g., conv.xxx)
|
# Uses a deep merge function, supporting nested dict updates (e.g., conv.xxx)
|
||||||
variables: Annotated[dict[str, Any], lambda x, y: {
|
variables: Annotated[dict[str, Any], lambda x, y: {
|
||||||
**x,
|
**x,
|
||||||
**{k: {**x.get(k, {}), **v} if isinstance(v, dict) and isinstance(x.get(k), dict) else v
|
**{k: {**x.get(k, {}), **v} if isinstance(v, dict) and isinstance(x.get(k), dict) else v
|
||||||
for k, v in y.items()}
|
for k, v in y.items()}
|
||||||
}]
|
}]
|
||||||
|
|
||||||
@@ -46,12 +46,12 @@ class WorkflowState(TypedDict):
|
|||||||
# Runtime node variables (simplified version, stores business data for fast access between nodes)
|
# Runtime node variables (simplified version, stores business data for fast access between nodes)
|
||||||
# Format: {node_id: business_result}
|
# Format: {node_id: business_result}
|
||||||
runtime_vars: Annotated[dict[str, Any], lambda x, y: {**x, **y}]
|
runtime_vars: Annotated[dict[str, Any], lambda x, y: {**x, **y}]
|
||||||
|
|
||||||
# Execution context
|
# Execution context
|
||||||
execution_id: str
|
execution_id: str
|
||||||
workspace_id: str
|
workspace_id: str
|
||||||
user_id: str
|
user_id: str
|
||||||
|
|
||||||
# Error information (for error edges)
|
# Error information (for error edges)
|
||||||
error: str | None
|
error: str | None
|
||||||
error_node: str | None
|
error_node: str | None
|
||||||
@@ -66,7 +66,7 @@ class BaseNode(ABC):
|
|||||||
|
|
||||||
所有节点类型都应该继承此基类,实现 execute 方法。
|
所有节点类型都应该继承此基类,实现 execute 方法。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
||||||
"""初始化节点
|
"""初始化节点
|
||||||
|
|
||||||
@@ -83,7 +83,7 @@ class BaseNode(ABC):
|
|||||||
# 使用 or 运算符处理 None 值
|
# 使用 or 运算符处理 None 值
|
||||||
self.config = node_config.get("config") or {}
|
self.config = node_config.get("config") or {}
|
||||||
self.error_handling = node_config.get("error_handling") or {}
|
self.error_handling = node_config.get("error_handling") or {}
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def execute(self, state: WorkflowState) -> Any:
|
async def execute(self, state: WorkflowState) -> Any:
|
||||||
"""执行节点业务逻辑(非流式)
|
"""执行节点业务逻辑(非流式)
|
||||||
@@ -108,7 +108,7 @@ class BaseNode(ABC):
|
|||||||
>>> return {"message": "开始", "conversation_id": "xxx"}
|
>>> return {"message": "开始", "conversation_id": "xxx"}
|
||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
async def execute_stream(self, state: WorkflowState):
|
async def execute_stream(self, state: WorkflowState):
|
||||||
"""执行节点业务逻辑(流式)
|
"""执行节点业务逻辑(流式)
|
||||||
|
|
||||||
@@ -138,7 +138,7 @@ class BaseNode(ABC):
|
|||||||
result = await self.execute(state)
|
result = await self.execute(state)
|
||||||
# 默认实现:直接 yield 完成标记
|
# 默认实现:直接 yield 完成标记
|
||||||
yield {"__final__": True, "result": result}
|
yield {"__final__": True, "result": result}
|
||||||
|
|
||||||
def supports_streaming(self) -> bool:
|
def supports_streaming(self) -> bool:
|
||||||
"""节点是否支持流式输出
|
"""节点是否支持流式输出
|
||||||
|
|
||||||
@@ -147,7 +147,7 @@ class BaseNode(ABC):
|
|||||||
"""
|
"""
|
||||||
# 检查子类是否重写了 execute_stream 方法
|
# 检查子类是否重写了 execute_stream 方法
|
||||||
return self.execute_stream.__func__ != BaseNode.execute_stream.__func__
|
return self.execute_stream.__func__ != BaseNode.execute_stream.__func__
|
||||||
|
|
||||||
def get_timeout(self) -> int:
|
def get_timeout(self) -> int:
|
||||||
"""获取超时时间(秒)
|
"""获取超时时间(秒)
|
||||||
|
|
||||||
@@ -156,7 +156,7 @@ class BaseNode(ABC):
|
|||||||
"""
|
"""
|
||||||
return 60
|
return 60
|
||||||
# return self.error_handling.get("timeout", 60)
|
# return self.error_handling.get("timeout", 60)
|
||||||
|
|
||||||
async def run(self, state: WorkflowState) -> dict[str, Any]:
|
async def run(self, state: WorkflowState) -> dict[str, Any]:
|
||||||
"""执行节点(带错误处理和输出包装,非流式)
|
"""执行节点(带错误处理和输出包装,非流式)
|
||||||
|
|
||||||
@@ -173,33 +173,33 @@ class BaseNode(ABC):
|
|||||||
标准化的状态更新字典
|
标准化的状态更新字典
|
||||||
"""
|
"""
|
||||||
import time
|
import time
|
||||||
|
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
|
||||||
timeout = self.get_timeout()
|
timeout = self.get_timeout()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 调用节点的业务逻辑
|
# 调用节点的业务逻辑
|
||||||
business_result = await asyncio.wait_for(
|
business_result = await asyncio.wait_for(
|
||||||
self.execute(state),
|
self.execute(state),
|
||||||
timeout=timeout
|
timeout=timeout
|
||||||
)
|
)
|
||||||
|
|
||||||
elapsed_time = time.time() - start_time
|
elapsed_time = time.time() - start_time
|
||||||
|
|
||||||
# 提取处理后的输出(调用子类的 _extract_output)
|
# 提取处理后的输出(调用子类的 _extract_output)
|
||||||
extracted_output = self._extract_output(business_result)
|
extracted_output = self._extract_output(business_result)
|
||||||
|
|
||||||
# 包装成标准输出格式
|
# 包装成标准输出格式
|
||||||
wrapped_output = self._wrap_output(business_result, elapsed_time, state)
|
wrapped_output = self._wrap_output(business_result, elapsed_time, state)
|
||||||
|
|
||||||
# 将提取后的输出存储到运行时变量中(供后续节点快速访问)
|
# 将提取后的输出存储到运行时变量中(供后续节点快速访问)
|
||||||
# 如果提取后的输出是字典,拆包存储;否则存储为 output 字段
|
# 如果提取后的输出是字典,拆包存储;否则存储为 output 字段
|
||||||
if isinstance(extracted_output, dict):
|
if isinstance(extracted_output, dict):
|
||||||
runtime_var = extracted_output
|
runtime_var = extracted_output
|
||||||
else:
|
else:
|
||||||
runtime_var = {"output": extracted_output}
|
runtime_var = {"output": extracted_output}
|
||||||
|
|
||||||
# 返回包装后的输出和运行时变量
|
# 返回包装后的输出和运行时变量
|
||||||
return {
|
return {
|
||||||
**wrapped_output,
|
**wrapped_output,
|
||||||
@@ -208,7 +208,7 @@ class BaseNode(ABC):
|
|||||||
},
|
},
|
||||||
"looping": state["looping"]
|
"looping": state["looping"]
|
||||||
}
|
}
|
||||||
|
|
||||||
except TimeoutError:
|
except TimeoutError:
|
||||||
elapsed_time = time.time() - start_time
|
elapsed_time = time.time() - start_time
|
||||||
logger.error(f"节点 {self.node_id} 执行超时({timeout}秒)")
|
logger.error(f"节点 {self.node_id} 执行超时({timeout}秒)")
|
||||||
@@ -217,7 +217,7 @@ class BaseNode(ABC):
|
|||||||
elapsed_time = time.time() - start_time
|
elapsed_time = time.time() - start_time
|
||||||
logger.error(f"节点 {self.node_id} 执行失败: {e}", exc_info=True)
|
logger.error(f"节点 {self.node_id} 执行失败: {e}", exc_info=True)
|
||||||
return self._wrap_error(str(e), elapsed_time, state)
|
return self._wrap_error(str(e), elapsed_time, state)
|
||||||
|
|
||||||
async def run_stream(self, state: WorkflowState):
|
async def run_stream(self, state: WorkflowState):
|
||||||
"""Execute node with error handling and output wrapping (streaming)
|
"""Execute node with error handling and output wrapping (streaming)
|
||||||
|
|
||||||
@@ -240,40 +240,41 @@ class BaseNode(ABC):
|
|||||||
State updates with streaming buffer and final result
|
State updates with streaming buffer and final result
|
||||||
"""
|
"""
|
||||||
import time
|
import time
|
||||||
|
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
|
||||||
timeout = self.get_timeout()
|
timeout = self.get_timeout()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Get LangGraph's stream writer for sending custom data
|
# Get LangGraph's stream writer for sending custom data
|
||||||
writer = get_stream_writer()
|
writer = get_stream_writer()
|
||||||
|
|
||||||
# Check if this is an End node
|
# Check if this is an End node
|
||||||
# End nodes CAN send chunks (for suffix), but only after LLM content
|
# End nodes CAN send chunks (for suffix), but only after LLM content
|
||||||
is_end_node = self.node_type == "end"
|
is_end_node = self.node_type == "end"
|
||||||
|
|
||||||
# Check if this node is adjacent to End node (for message type)
|
# Check if this node is adjacent to End node (for message type)
|
||||||
is_adjacent_to_end = getattr(self, '_is_adjacent_to_end', False)
|
is_adjacent_to_end = getattr(self, '_is_adjacent_to_end', False)
|
||||||
|
|
||||||
# Determine chunk type: "message" for End and adjacent nodes, "node_chunk" for others
|
# Determine chunk type: "message" for End and adjacent nodes, "node_chunk" for others
|
||||||
chunk_type = "message" if (is_end_node or is_adjacent_to_end) else "node_chunk"
|
chunk_type = "message" if (is_end_node or is_adjacent_to_end) else "node_chunk"
|
||||||
|
|
||||||
logger.debug(f"节点 {self.node_id} chunk 类型: {chunk_type} (is_end={is_end_node}, adjacent={is_adjacent_to_end})")
|
logger.debug(
|
||||||
|
f"节点 {self.node_id} chunk 类型: {chunk_type} (is_end={is_end_node}, adjacent={is_adjacent_to_end})")
|
||||||
|
|
||||||
# Accumulate complete result (for final wrapping)
|
# Accumulate complete result (for final wrapping)
|
||||||
chunks = []
|
chunks = []
|
||||||
final_result = None
|
final_result = None
|
||||||
chunk_count = 0
|
chunk_count = 0
|
||||||
|
|
||||||
# Stream chunks in real-time
|
# Stream chunks in real-time
|
||||||
loop_start = asyncio.get_event_loop().time()
|
loop_start = asyncio.get_event_loop().time()
|
||||||
|
|
||||||
async for item in self.execute_stream(state):
|
async for item in self.execute_stream(state):
|
||||||
# Check timeout
|
# Check timeout
|
||||||
if asyncio.get_event_loop().time() - loop_start > timeout:
|
if asyncio.get_event_loop().time() - loop_start > timeout:
|
||||||
raise TimeoutError()
|
raise TimeoutError()
|
||||||
|
|
||||||
# Check if it's a completion marker
|
# Check if it's a completion marker
|
||||||
if isinstance(item, dict) and item.get("__final__"):
|
if isinstance(item, dict) and item.get("__final__"):
|
||||||
final_result = item["result"]
|
final_result = item["result"]
|
||||||
@@ -282,10 +283,10 @@ class BaseNode(ABC):
|
|||||||
chunk_count += 1
|
chunk_count += 1
|
||||||
chunks.append(item)
|
chunks.append(item)
|
||||||
full_content = "".join(chunks)
|
full_content = "".join(chunks)
|
||||||
|
|
||||||
# Send chunks for all nodes (including End nodes for suffix)
|
# Send chunks for all nodes (including End nodes for suffix)
|
||||||
logger.debug(f"节点 {self.node_id} 发送 chunk #{chunk_count}: {item[:50]}...")
|
logger.debug(f"节点 {self.node_id} 发送 chunk #{chunk_count}: {item[:50]}...")
|
||||||
|
|
||||||
# 1. Send via stream writer (for real-time client updates)
|
# 1. Send via stream writer (for real-time client updates)
|
||||||
writer({
|
writer({
|
||||||
"type": chunk_type, # "message" or "node_chunk"
|
"type": chunk_type, # "message" or "node_chunk"
|
||||||
@@ -294,7 +295,7 @@ class BaseNode(ABC):
|
|||||||
"full_content": full_content,
|
"full_content": full_content,
|
||||||
"chunk_index": chunk_count
|
"chunk_index": chunk_count
|
||||||
})
|
})
|
||||||
|
|
||||||
# 2. Update streaming buffer in state (for downstream nodes)
|
# 2. Update streaming buffer in state (for downstream nodes)
|
||||||
# Only non-End nodes need streaming buffer
|
# Only non-End nodes need streaming buffer
|
||||||
if not is_end_node:
|
if not is_end_node:
|
||||||
@@ -313,7 +314,7 @@ class BaseNode(ABC):
|
|||||||
chunk_str = str(item)
|
chunk_str = str(item)
|
||||||
chunks.append(chunk_str)
|
chunks.append(chunk_str)
|
||||||
full_content = "".join(chunks)
|
full_content = "".join(chunks)
|
||||||
|
|
||||||
# Send chunks for all nodes
|
# Send chunks for all nodes
|
||||||
writer({
|
writer({
|
||||||
"type": chunk_type, # "message" or "node_chunk"
|
"type": chunk_type, # "message" or "node_chunk"
|
||||||
@@ -322,7 +323,7 @@ class BaseNode(ABC):
|
|||||||
"full_content": full_content,
|
"full_content": full_content,
|
||||||
"chunk_index": chunk_count
|
"chunk_index": chunk_count
|
||||||
})
|
})
|
||||||
|
|
||||||
# Only non-End nodes need streaming buffer
|
# Only non-End nodes need streaming buffer
|
||||||
if not is_end_node:
|
if not is_end_node:
|
||||||
yield {
|
yield {
|
||||||
@@ -334,23 +335,23 @@ class BaseNode(ABC):
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
elapsed_time = time.time() - start_time
|
elapsed_time = time.time() - start_time
|
||||||
|
|
||||||
logger.info(f"节点 {self.node_id} 流式执行完成,耗时: {elapsed_time:.2f}s, chunks: {chunk_count}")
|
logger.info(f"节点 {self.node_id} 流式执行完成,耗时: {elapsed_time:.2f}s, chunks: {chunk_count}")
|
||||||
|
|
||||||
# Extract processed output (call subclass's _extract_output)
|
# Extract processed output (call subclass's _extract_output)
|
||||||
extracted_output = self._extract_output(final_result)
|
extracted_output = self._extract_output(final_result)
|
||||||
|
|
||||||
# Wrap final result
|
# Wrap final result
|
||||||
final_output = self._wrap_output(final_result, elapsed_time, state)
|
final_output = self._wrap_output(final_result, elapsed_time, state)
|
||||||
|
|
||||||
# Store extracted output in runtime variables (for quick access by subsequent nodes)
|
# Store extracted output in runtime variables (for quick access by subsequent nodes)
|
||||||
if isinstance(extracted_output, dict):
|
if isinstance(extracted_output, dict):
|
||||||
runtime_var = extracted_output
|
runtime_var = extracted_output
|
||||||
else:
|
else:
|
||||||
runtime_var = {"output": extracted_output}
|
runtime_var = {"output": extracted_output}
|
||||||
|
|
||||||
# Build complete state update (including node_outputs, runtime_vars, and final streaming buffer)
|
# Build complete state update (including node_outputs, runtime_vars, and final streaming buffer)
|
||||||
state_update = {
|
state_update = {
|
||||||
**final_output,
|
**final_output,
|
||||||
@@ -359,7 +360,7 @@ class BaseNode(ABC):
|
|||||||
},
|
},
|
||||||
"looping": state["looping"]
|
"looping": state["looping"]
|
||||||
}
|
}
|
||||||
|
|
||||||
# Add streaming buffer for non-End nodes
|
# Add streaming buffer for non-End nodes
|
||||||
if not is_end_node:
|
if not is_end_node:
|
||||||
state_update["streaming_buffer"] = {
|
state_update["streaming_buffer"] = {
|
||||||
@@ -369,11 +370,11 @@ class BaseNode(ABC):
|
|||||||
"is_complete": True # Mark as complete
|
"is_complete": True # Mark as complete
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
# Finally yield state update
|
# Finally yield state update
|
||||||
# LangGraph will merge this into state
|
# LangGraph will merge this into state
|
||||||
yield state_update
|
yield state_update
|
||||||
|
|
||||||
except TimeoutError:
|
except TimeoutError:
|
||||||
elapsed_time = time.time() - start_time
|
elapsed_time = time.time() - start_time
|
||||||
logger.error(f"节点 {self.node_id} 执行超时 ({timeout}s)")
|
logger.error(f"节点 {self.node_id} 执行超时 ({timeout}s)")
|
||||||
@@ -384,12 +385,12 @@ class BaseNode(ABC):
|
|||||||
logger.error(f"节点 {self.node_id} 执行失败: {e}", exc_info=True)
|
logger.error(f"节点 {self.node_id} 执行失败: {e}", exc_info=True)
|
||||||
error_output = self._wrap_error(str(e), elapsed_time, state)
|
error_output = self._wrap_error(str(e), elapsed_time, state)
|
||||||
yield error_output
|
yield error_output
|
||||||
|
|
||||||
def _wrap_output(
|
def _wrap_output(
|
||||||
self,
|
self,
|
||||||
business_result: Any,
|
business_result: Any,
|
||||||
elapsed_time: float,
|
elapsed_time: float,
|
||||||
state: WorkflowState
|
state: WorkflowState
|
||||||
) -> dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
"""将业务结果包装成标准输出格式
|
"""将业务结果包装成标准输出格式
|
||||||
|
|
||||||
@@ -403,13 +404,13 @@ class BaseNode(ABC):
|
|||||||
"""
|
"""
|
||||||
# 提取输入数据(用于记录)
|
# 提取输入数据(用于记录)
|
||||||
input_data = self._extract_input(state)
|
input_data = self._extract_input(state)
|
||||||
|
|
||||||
# 提取 token 使用情况(如果有)
|
# 提取 token 使用情况(如果有)
|
||||||
token_usage = self._extract_token_usage(business_result)
|
token_usage = self._extract_token_usage(business_result)
|
||||||
|
|
||||||
# 提取实际输出(去除元数据)
|
# 提取实际输出(去除元数据)
|
||||||
output = self._extract_output(business_result)
|
output = self._extract_output(business_result)
|
||||||
|
|
||||||
# 构建标准节点输出
|
# 构建标准节点输出
|
||||||
node_output = {
|
node_output = {
|
||||||
"node_id": self.node_id,
|
"node_id": self.node_id,
|
||||||
@@ -422,18 +423,18 @@ class BaseNode(ABC):
|
|||||||
"token_usage": token_usage,
|
"token_usage": token_usage,
|
||||||
"error": None
|
"error": None
|
||||||
}
|
}
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"node_outputs": {
|
"node_outputs": {
|
||||||
self.node_id: node_output
|
self.node_id: node_output
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
def _wrap_error(
|
def _wrap_error(
|
||||||
self,
|
self,
|
||||||
error_message: str,
|
error_message: str,
|
||||||
elapsed_time: float,
|
elapsed_time: float,
|
||||||
state: WorkflowState
|
state: WorkflowState
|
||||||
) -> dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
"""将错误包装成标准输出格式
|
"""将错误包装成标准输出格式
|
||||||
|
|
||||||
@@ -447,10 +448,10 @@ class BaseNode(ABC):
|
|||||||
"""
|
"""
|
||||||
# 查找错误边
|
# 查找错误边
|
||||||
error_edge = self._find_error_edge()
|
error_edge = self._find_error_edge()
|
||||||
|
|
||||||
# 提取输入数据
|
# 提取输入数据
|
||||||
input_data = self._extract_input(state)
|
input_data = self._extract_input(state)
|
||||||
|
|
||||||
# 构建错误输出
|
# 构建错误输出
|
||||||
node_output = {
|
node_output = {
|
||||||
"node_id": self.node_id,
|
"node_id": self.node_id,
|
||||||
@@ -463,7 +464,7 @@ class BaseNode(ABC):
|
|||||||
"token_usage": None,
|
"token_usage": None,
|
||||||
"error": error_message
|
"error": error_message
|
||||||
}
|
}
|
||||||
|
|
||||||
if error_edge:
|
if error_edge:
|
||||||
# 有错误边:记录错误并继续
|
# 有错误边:记录错误并继续
|
||||||
logger.warning(
|
logger.warning(
|
||||||
@@ -480,7 +481,7 @@ class BaseNode(ABC):
|
|||||||
# 无错误边:抛出异常停止工作流
|
# 无错误边:抛出异常停止工作流
|
||||||
logger.error(f"节点 {self.node_id} 执行失败,停止工作流: {error_message}")
|
logger.error(f"节点 {self.node_id} 执行失败,停止工作流: {error_message}")
|
||||||
raise Exception(f"节点 {self.node_id} 执行失败: {error_message}")
|
raise Exception(f"节点 {self.node_id} 执行失败: {error_message}")
|
||||||
|
|
||||||
def _extract_input(self, state: WorkflowState) -> dict[str, Any]:
|
def _extract_input(self, state: WorkflowState) -> dict[str, Any]:
|
||||||
"""提取节点输入数据(用于记录)
|
"""提取节点输入数据(用于记录)
|
||||||
|
|
||||||
@@ -494,7 +495,7 @@ class BaseNode(ABC):
|
|||||||
"""
|
"""
|
||||||
# 默认返回配置
|
# 默认返回配置
|
||||||
return {"config": self.config}
|
return {"config": self.config}
|
||||||
|
|
||||||
def _extract_output(self, business_result: Any) -> Any:
|
def _extract_output(self, business_result: Any) -> Any:
|
||||||
"""从业务结果中提取实际输出
|
"""从业务结果中提取实际输出
|
||||||
|
|
||||||
@@ -508,7 +509,7 @@ class BaseNode(ABC):
|
|||||||
"""
|
"""
|
||||||
# 默认直接返回业务结果
|
# 默认直接返回业务结果
|
||||||
return business_result
|
return business_result
|
||||||
|
|
||||||
def _extract_token_usage(self, business_result: Any) -> dict[str, int] | None:
|
def _extract_token_usage(self, business_result: Any) -> dict[str, int] | None:
|
||||||
"""从业务结果中提取 token 使用情况
|
"""从业务结果中提取 token 使用情况
|
||||||
|
|
||||||
@@ -522,7 +523,7 @@ class BaseNode(ABC):
|
|||||||
"""
|
"""
|
||||||
# 默认返回 None
|
# 默认返回 None
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def _find_error_edge(self) -> dict[str, Any] | None:
|
def _find_error_edge(self) -> dict[str, Any] | None:
|
||||||
"""查找错误边
|
"""查找错误边
|
||||||
|
|
||||||
@@ -533,8 +534,8 @@ class BaseNode(ABC):
|
|||||||
if edge.get("source") == self.node_id and edge.get("type") == "error":
|
if edge.get("source") == self.node_id and edge.get("type") == "error":
|
||||||
return edge
|
return edge
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def _render_template(self, template: str, state: WorkflowState | None, struct: bool = True) -> str:
|
def _render_template(self, template: str, state: WorkflowState | None, strict: bool = True) -> str:
|
||||||
"""渲染模板
|
"""渲染模板
|
||||||
|
|
||||||
支持的变量命名空间:
|
支持的变量命名空间:
|
||||||
@@ -550,28 +551,28 @@ class BaseNode(ABC):
|
|||||||
渲染后的字符串
|
渲染后的字符串
|
||||||
"""
|
"""
|
||||||
from app.core.workflow.template_renderer import render_template
|
from app.core.workflow.template_renderer import render_template
|
||||||
|
|
||||||
# 处理 state 为 None 的情况
|
# 处理 state 为 None 的情况
|
||||||
if state is None:
|
if state is None:
|
||||||
state = {}
|
state = {}
|
||||||
|
|
||||||
# 使用变量池获取变量
|
# 使用变量池获取变量
|
||||||
pool = VariablePool(state)
|
pool = VariablePool(state)
|
||||||
|
|
||||||
# 构建完整的 variables 结构
|
# 构建完整的 variables 结构
|
||||||
variables = {
|
variables = {
|
||||||
"sys": pool.get_all_system_vars(),
|
"sys": pool.get_all_system_vars(),
|
||||||
"conv": pool.get_all_conversation_vars()
|
"conv": pool.get_all_conversation_vars()
|
||||||
}
|
}
|
||||||
|
|
||||||
return render_template(
|
return render_template(
|
||||||
template=template,
|
template=template,
|
||||||
variables=variables,
|
variables=variables,
|
||||||
node_outputs=pool.get_all_node_outputs(),
|
node_outputs=pool.get_all_node_outputs(),
|
||||||
system_vars=pool.get_all_system_vars(),
|
system_vars=pool.get_all_system_vars(),
|
||||||
struct=struct
|
strict=strict
|
||||||
)
|
)
|
||||||
|
|
||||||
def _evaluate_condition(self, expression: str, state: WorkflowState | None) -> bool:
|
def _evaluate_condition(self, expression: str, state: WorkflowState | None) -> bool:
|
||||||
"""评估条件表达式
|
"""评估条件表达式
|
||||||
|
|
||||||
@@ -588,20 +589,20 @@ class BaseNode(ABC):
|
|||||||
布尔值结果
|
布尔值结果
|
||||||
"""
|
"""
|
||||||
from app.core.workflow.expression_evaluator import evaluate_condition
|
from app.core.workflow.expression_evaluator import evaluate_condition
|
||||||
|
|
||||||
# 处理 state 为 None 的情况
|
# 处理 state 为 None 的情况
|
||||||
if state is None:
|
if state is None:
|
||||||
state = {}
|
state = {}
|
||||||
|
|
||||||
# 使用变量池获取变量
|
# 使用变量池获取变量
|
||||||
pool = VariablePool(state)
|
pool = VariablePool(state)
|
||||||
|
|
||||||
# 构建完整的 variables 结构(包含 sys 和 conv)
|
# 构建完整的 variables 结构(包含 sys 和 conv)
|
||||||
variables = {
|
variables = {
|
||||||
"sys": pool.get_all_system_vars(),
|
"sys": pool.get_all_system_vars(),
|
||||||
"conv": pool.get_all_conversation_vars()
|
"conv": pool.get_all_conversation_vars()
|
||||||
}
|
}
|
||||||
|
|
||||||
return evaluate_condition(
|
return evaluate_condition(
|
||||||
expression=expression,
|
expression=expression,
|
||||||
variables=variables,
|
variables=variables,
|
||||||
@@ -626,12 +627,12 @@ class BaseNode(ABC):
|
|||||||
>>> llm_output = pool.get("llm_qa.output")
|
>>> llm_output = pool.get("llm_qa.output")
|
||||||
"""
|
"""
|
||||||
return VariablePool(state)
|
return VariablePool(state)
|
||||||
|
|
||||||
def get_variable(
|
def get_variable(
|
||||||
self,
|
self,
|
||||||
selector: list[str] | str,
|
selector: list[str] | str,
|
||||||
state: WorkflowState,
|
state: WorkflowState,
|
||||||
default: Any = None
|
default: Any = None
|
||||||
) -> Any:
|
) -> Any:
|
||||||
"""获取变量值(便捷方法)
|
"""获取变量值(便捷方法)
|
||||||
|
|
||||||
@@ -650,7 +651,7 @@ class BaseNode(ABC):
|
|||||||
"""
|
"""
|
||||||
pool = VariablePool(state)
|
pool = VariablePool(state)
|
||||||
return pool.get(selector, default=default)
|
return pool.get(selector, default=default)
|
||||||
|
|
||||||
def has_variable(self, selector: list[str] | str, state: WorkflowState) -> bool:
|
def has_variable(self, selector: list[str] | str, state: WorkflowState) -> bool:
|
||||||
"""检查变量是否存在(便捷方法)
|
"""检查变量是否存在(便捷方法)
|
||||||
|
|
||||||
|
|||||||
@@ -37,7 +37,7 @@ class EndNode(BaseNode):
|
|||||||
|
|
||||||
# 如果配置了输出模板,使用模板渲染;否则使用默认输出
|
# 如果配置了输出模板,使用模板渲染;否则使用默认输出
|
||||||
if output_template:
|
if output_template:
|
||||||
output = self._render_template(output_template, state, struct=False)
|
output = self._render_template(output_template, state, strict=False)
|
||||||
else:
|
else:
|
||||||
output = "工作流已完成"
|
output = "工作流已完成"
|
||||||
|
|
||||||
@@ -156,6 +156,16 @@ class EndNode(BaseNode):
|
|||||||
|
|
||||||
if not output_template:
|
if not output_template:
|
||||||
output = "工作流已完成"
|
output = "工作流已完成"
|
||||||
|
from langgraph.config import get_stream_writer
|
||||||
|
writer = get_stream_writer()
|
||||||
|
writer({
|
||||||
|
"type": "message", # End node output uses message type
|
||||||
|
"node_id": self.node_id,
|
||||||
|
"chunk": "",
|
||||||
|
"full_content": output,
|
||||||
|
"chunk_index": 1,
|
||||||
|
"is_suffix": False
|
||||||
|
})
|
||||||
yield {"__final__": True, "result": output}
|
yield {"__final__": True, "result": output}
|
||||||
return
|
return
|
||||||
|
|
||||||
@@ -190,7 +200,7 @@ class EndNode(BaseNode):
|
|||||||
|
|
||||||
if upstream_llm_ref_index is None:
|
if upstream_llm_ref_index is None:
|
||||||
# No reference to direct upstream LLM node, output complete template content
|
# No reference to direct upstream LLM node, output complete template content
|
||||||
output = self._render_template(output_template, state)
|
output = self._render_template(output_template, state, strict=False)
|
||||||
logger.info(f"节点 {self.node_id} 没有引用直接上游 LLM 节点,输出完整内容: '{output[:50]}...'")
|
logger.info(f"节点 {self.node_id} 没有引用直接上游 LLM 节点,输出完整内容: '{output[:50]}...'")
|
||||||
|
|
||||||
# Send complete content via writer (as a single message chunk)
|
# Send complete content via writer (as a single message chunk)
|
||||||
@@ -246,7 +256,7 @@ class EndNode(BaseNode):
|
|||||||
suffix = "".join(suffix_parts)
|
suffix = "".join(suffix_parts)
|
||||||
|
|
||||||
# 构建完整输出(用于返回,包含前缀 + 动态内容 + 后缀)
|
# 构建完整输出(用于返回,包含前缀 + 动态内容 + 后缀)
|
||||||
full_output = self._render_template(output_template, state)
|
full_output = self._render_template(output_template, state, strict=False)
|
||||||
|
|
||||||
logger.info(f"[后缀调试] 节点 {self.node_id} 后缀部分数量: {len(suffix_parts)}")
|
logger.info(f"[后缀调试] 节点 {self.node_id} 后缀部分数量: {len(suffix_parts)}")
|
||||||
logger.info(f"[后缀调试] 后缀内容: '{suffix}'")
|
logger.info(f"[后缀调试] 后缀内容: '{suffix}'")
|
||||||
|
|||||||
@@ -5,6 +5,7 @@
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
from collections import defaultdict
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from jinja2 import TemplateSyntaxError, UndefinedError, Environment, StrictUndefined, Undefined
|
from jinja2 import TemplateSyntaxError, UndefinedError, Environment, StrictUndefined, Undefined
|
||||||
@@ -12,6 +13,18 @@ from jinja2 import TemplateSyntaxError, UndefinedError, Environment, StrictUndef
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class SafeUndefined(Undefined):
|
||||||
|
"""访问未定义属性不会报错,返回空字符串"""
|
||||||
|
__slots__ = ()
|
||||||
|
|
||||||
|
def _fail_with_undefined_error(self, *args, **kwargs):
|
||||||
|
return ""
|
||||||
|
|
||||||
|
__add__ = __radd__ = __mul__ = __rmul__ = __div__ = __rdiv__ = __truediv__ = __rtruediv__ = _fail_with_undefined_error
|
||||||
|
__getitem__ = __getattr__ = _fail_with_undefined_error
|
||||||
|
__str__ = __repr__ = lambda self: ""
|
||||||
|
|
||||||
|
|
||||||
class TemplateRenderer:
|
class TemplateRenderer:
|
||||||
"""模板渲染器"""
|
"""模板渲染器"""
|
||||||
|
|
||||||
@@ -21,8 +34,9 @@ class TemplateRenderer:
|
|||||||
Args:
|
Args:
|
||||||
strict: 是否使用严格模式(未定义变量会抛出异常)
|
strict: 是否使用严格模式(未定义变量会抛出异常)
|
||||||
"""
|
"""
|
||||||
|
self.strict = strict
|
||||||
self.env = Environment(
|
self.env = Environment(
|
||||||
undefined=StrictUndefined if strict else Undefined,
|
undefined=StrictUndefined if strict else SafeUndefined,
|
||||||
autoescape=False # 不自动转义,因为我们处理的是文本而非 HTML
|
autoescape=False # 不自动转义,因为我们处理的是文本而非 HTML
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -69,12 +83,17 @@ class TemplateRenderer:
|
|||||||
# variables 的结构:{"sys": {...}, "conv": {...}}
|
# variables 的结构:{"sys": {...}, "conv": {...}}
|
||||||
sys_vars = variables.get("sys", {}) if isinstance(variables, dict) else {}
|
sys_vars = variables.get("sys", {}) if isinstance(variables, dict) else {}
|
||||||
conv_vars = variables.get("conv", {}) if isinstance(variables, dict) else {}
|
conv_vars = variables.get("conv", {}) if isinstance(variables, dict) else {}
|
||||||
|
if self.strict:
|
||||||
context = {
|
context = defaultdict(dict)
|
||||||
"conv": conv_vars, # 会话变量:{{conv.user_name}}
|
context["conv"] = conv_vars
|
||||||
"node": node_outputs, # 节点输出:{{node.node_1.output}}
|
context["nodes"] = node_outputs
|
||||||
"sys": {**(system_vars or {}), **sys_vars}, # 系统变量:{{sys.execution_id}}(合并两个来源)
|
context["sys"] = {**(system_vars or {}), **sys_vars}
|
||||||
}
|
else:
|
||||||
|
context = {
|
||||||
|
"conv": conv_vars, # 会话变量:{{conv.user_name}}
|
||||||
|
"node": node_outputs, # 节点输出:{{node.node_1.output}}
|
||||||
|
"sys": {**(system_vars or {}), **sys_vars}, # 系统变量:{{sys.execution_id}}(合并两个来源)
|
||||||
|
}
|
||||||
|
|
||||||
# 支持直接通过节点ID访问节点输出:{{llm_qa.output}}
|
# 支持直接通过节点ID访问节点输出:{{llm_qa.output}}
|
||||||
# 将所有节点输出添加到顶层上下文
|
# 将所有节点输出添加到顶层上下文
|
||||||
@@ -141,12 +160,12 @@ def render_template(
|
|||||||
variables: dict[str, Any],
|
variables: dict[str, Any],
|
||||||
node_outputs: dict[str, Any],
|
node_outputs: dict[str, Any],
|
||||||
system_vars: dict[str, Any] | None = None,
|
system_vars: dict[str, Any] | None = None,
|
||||||
struct: bool = True
|
strict: bool = True
|
||||||
) -> str:
|
) -> str:
|
||||||
"""渲染模板(便捷函数)
|
"""渲染模板(便捷函数)
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
struct: 渲染模式
|
strict: 严格模式
|
||||||
template: 模板字符串
|
template: 模板字符串
|
||||||
variables: 用户变量
|
variables: 用户变量
|
||||||
node_outputs: 节点输出
|
node_outputs: 节点输出
|
||||||
@@ -164,7 +183,7 @@ def render_template(
|
|||||||
... )
|
... )
|
||||||
'请分析: 这是一段文本'
|
'请分析: 这是一段文本'
|
||||||
"""
|
"""
|
||||||
renderer = TemplateRenderer(strict=struct)
|
renderer = TemplateRenderer(strict=strict)
|
||||||
return renderer.render(template, variables, node_outputs, system_vars)
|
return renderer.render(template, variables, node_outputs, system_vars)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -53,7 +53,7 @@ nodes:
|
|||||||
type: end
|
type: end
|
||||||
name: 结束
|
name: 结束
|
||||||
config:
|
config:
|
||||||
output: "{{llm_qa.output}}"
|
output: "{{ llm_qa.output }}"
|
||||||
position:
|
position:
|
||||||
x: 900
|
x: 900
|
||||||
y: 100
|
y: 100
|
||||||
|
|||||||
Reference in New Issue
Block a user