Merge branch 'refs/heads/develop' into fix/memory_bug_fix

This commit is contained in:
lixinyue
2026-01-21 11:53:52 +08:00
7 changed files with 296 additions and 182 deletions

View File

@@ -7,7 +7,7 @@ from sqlalchemy.orm import Session
from app.core.error_codes import BizCode
from app.core.logging_config import get_business_logger
from app.core.response_utils import success
from app.core.response_utils import success, fail
from app.db import get_db
from app.dependencies import get_current_user, cur_workspace_access_guard
from app.models import User
@@ -661,6 +661,11 @@ async def draft_run(
data=result,
msg="工作流任务执行成功"
)
else:
return fail(
msg="未知应用类型",
code=422
)
@router.post("/{app_id}/draft/run/compare", summary="多模型对比试运行")

View File

@@ -54,6 +54,8 @@ class WorkflowExecutor:
self.edges = workflow_config.get("edges", [])
self.execution_config = workflow_config.get("execution_config", {})
self.start_node_id = None
self.checkpoint_config = RunnableConfig(
configurable={
"thread_id": uuid.uuid4(),
@@ -131,77 +133,12 @@ class WorkflowExecutor:
for node in self.workflow_config.get("nodes")
if node.get("type") in [NodeType.LOOP, NodeType.ITERATION]
], # loop, iteration node id
"looping": False # loop runing flag, only use in loop node,not use in main loop
"looping": False, # loop runing flag, only use in loop node,not use in main loop
"activate": {
self.start_node_id: True
}
}
def _analyze_end_node_prefixes(self) -> tuple[dict[str, str], set[str]]:
"""分析 End 节点的前缀配置
检查每个 End 节点的模板,找到直接上游节点的引用,
提取该引用之前的前缀部分。
Returns:
元组:({上游节点ID: End节点前缀}, {与End相邻且被引用的节点ID集合})
"""
import re
prefixes = {}
adjacent_and_referenced = set() # 记录与 End 节点相邻且被引用的节点
# 找到所有 End 节点
end_nodes = [node for node in self.nodes if node.get("type") == "end"]
logger.info(f"[前缀分析] 找到 {len(end_nodes)} 个 End 节点")
for end_node in end_nodes:
end_node_id = end_node.get("id")
output_template = end_node.get("config", {}).get("output")
logger.info(f"[前缀分析] End 节点 {end_node_id} 模板: {output_template}")
if not output_template:
continue
# 找到所有直接连接到 End 节点的上游节点
direct_upstream_nodes = []
for edge in self.edges:
if edge.get("target") == end_node_id:
source_node_id = edge.get("source")
direct_upstream_nodes.append(source_node_id)
logger.info(f"[前缀分析] End 节点的直接上游节点: {direct_upstream_nodes}")
# 查找模板中引用了哪些节点
# 匹配 {{node_id.xxx}} 或 {{ node_id.xxx }} 格式(支持空格)
pattern = r'\{\{\s*([a-zA-Z0-9_]+)\.[a-zA-Z0-9_]+\s*\}\}'
matches = list(re.finditer(pattern, output_template))
logger.info(f"[前缀分析] 模板中找到 {len(matches)} 个节点引用")
# 找到第一个直接上游节点的引用
for match in matches:
referenced_node_id = match.group(1)
logger.info(f"[前缀分析] 检查引用: {referenced_node_id}")
if referenced_node_id in direct_upstream_nodes:
# 这是直接上游节点的引用,提取前缀
prefix = output_template[:match.start()]
logger.info(f"[前缀分析] ✅ 找到直接上游节点 {referenced_node_id} 的引用,前缀: '{prefix}'")
# 标记这个节点为"相邻且被引用"
adjacent_and_referenced.add(referenced_node_id)
if prefix:
prefixes[referenced_node_id] = prefix
logger.info(f"✅ [前缀分析] 为节点 {referenced_node_id} 配置前缀: '{prefix[:50]}...'")
# 只处理第一个直接上游节点的引用
break
logger.info(f"[前缀分析] 最终配置: {prefixes}")
logger.info(f"[前缀分析] 与 End 相邻且被引用的节点: {adjacent_and_referenced}")
return prefixes, adjacent_and_referenced
def _build_final_output(self, result, elapsed_time):
node_outputs = result.get("node_outputs", {})
final_output = self._extract_final_output(node_outputs)
@@ -231,10 +168,12 @@ class WorkflowExecutor:
编译后的状态图
"""
logger.info(f"开始构建工作流图: execution_id={self.execution_id}")
graph = GraphBuilder(
builder = GraphBuilder(
self.workflow_config,
stream=stream,
).build()
)
self.start_node_id = builder.start_node_id
graph = builder.build()
logger.info(f"工作流图构建完成: execution_id={self.execution_id}")
return graph
@@ -375,13 +314,15 @@ class WorkflowExecutor:
payload = data.get("payload", {})
node_name = payload.get("name")
if node_name and node_name.startswith("nop"):
continue
if event_type == "task":
# Node starts execution
inputv = payload.get("input", {})
variables = inputv.get("variables", {})
variables_sys = variables.get("sys", {})
if not inputv.get("activate", {}).get(node_name):
continue
conversation_id = input_data.get("conversation_id")
execution_id = variables_sys.get("execution_id")
logger.info(f"[NODE-START] Node starts execution: {node_name} "
f"- execution_id: {self.execution_id}")
@@ -390,18 +331,17 @@ class WorkflowExecutor:
"data": {
"node_id": node_name,
"conversation_id": conversation_id,
"execution_id": execution_id,
"timestamp": data.get("timestamp")
"execution_id": self.execution_id,
"timestamp": data.get("timestamp"),
}
}
elif event_type == "task_result":
# Node execution completed
result = payload.get("result", {})
inputv = result.get("input", {})
variables = inputv.get("variables", {})
variables_sys = variables.get("sys", {})
if not result.get("activate", {}).get(node_name):
continue
conversation_id = input_data.get("conversation_id")
execution_id = variables_sys.get("execution_id")
logger.info(f"[NODE-END] Node execution completed: {node_name} "
f"- execution_id: {self.execution_id}")
@@ -410,7 +350,7 @@ class WorkflowExecutor:
"data": {
"node_id": node_name,
"conversation_id": conversation_id,
"execution_id": execution_id,
"execution_id": self.execution_id,
"timestamp": data.get("timestamp"),
"state": result.get("node_outputs", {}).get(node_name),
}

View File

@@ -1,14 +1,16 @@
import logging
import uuid
from collections import defaultdict
from typing import Any
from langgraph.graph.state import CompiledStateGraph, StateGraph
from langgraph.graph import START, END
from langgraph.checkpoint.memory import InMemorySaver
from langgraph.graph import START, END
from langgraph.graph.state import CompiledStateGraph, StateGraph
from langgraph.types import Send
from app.core.workflow.expression_evaluator import evaluate_condition
from app.core.workflow.nodes import WorkflowState, NodeFactory
from app.core.workflow.nodes.enums import NodeType
from app.core.workflow.nodes.enums import NodeType, BRANCH_NODES
logger = logging.getLogger(__name__)
@@ -28,7 +30,10 @@ class GraphBuilder:
self.start_node_id = None
self.end_node_ids = []
self.graph: StateGraph | CompiledStateGraph | None = None
self.graph = StateGraph(WorkflowState)
self.add_nodes()
self.add_edges()
# EDGES MUST BE ADDED AFTER NODES ARE ADDED.
@property
def nodes(self) -> list[dict[str, Any]]:
@@ -39,74 +44,98 @@ class GraphBuilder:
return self.workflow_config.get("edges", [])
def _analyze_end_node_prefixes(self) -> tuple[dict[str, str], set[str]]:
"""分析 End 节点的前缀配置
"""
Analyze the prefix configuration for End nodes.
检查每个 End 节点的模板,找到直接上游节点的引用,
提取该引用之前的前缀部分。
This function scans each End node's output template, identifies
references to its direct upstream nodes, and extracts the prefix
string appearing before the first reference.
Returns:
元组:({上游节点ID: End节点前缀}, {与End相邻且被引用的节点ID集合})
tuple:
- dict[str, str]: Mapping from upstream node ID to its End node prefix
- set[str]: Set of node IDs that are directly adjacent to End nodes and referenced
"""
import re
prefixes = {}
adjacent_and_referenced = set() # 记录与 End 节点相邻且被引用的节点
adjacent_and_referenced = set() # Record nodes directly adjacent to End and referenced
# 找到所有 End 节点
end_nodes = [node for node in self.nodes if node.get("type") == "end"]
logger.info(f"[前缀分析] 找到 {len(end_nodes)} End 节点")
logger.info(f"[Prefix Analysis] Found {len(end_nodes)} End nodes")
for end_node in end_nodes:
end_node_id = end_node.get("id")
output_template = end_node.get("config", {}).get("output")
logger.info(f"[前缀分析] End 节点 {end_node_id} 模板: {output_template}")
logger.info(f"[Prefix Analysis] End node {end_node_id} template: {output_template}")
if not output_template:
continue
# 查找模板中引用了哪些节点
# 匹配 {{node_id.xxx}} {{ node_id.xxx }} 格式(支持空格)
# Find all node references in the template
# Matches {{node_id.xxx}} or {{ node_id.xxx }} format (allowing spaces)
pattern = r'\{\{\s*([a-zA-Z0-9_-]+)\.[a-zA-Z0-9_]+\s*\}\}'
matches = list(re.finditer(pattern, output_template))
logger.info(f"[前缀分析] 模板中找到 {len(matches)} 个节点引用")
logger.info(f"[Prefix Analysis] 模板中找到 {len(matches)} 个节点引用")
# 找到所有直接连接到 End 节点的上游节点
# Identify all direct upstream nodes connected to the End node
direct_upstream_nodes = []
for edge in self.edges:
if edge.get("target") == end_node_id:
source_node_id = edge.get("source")
direct_upstream_nodes.append(source_node_id)
logger.info(f"[前缀分析] End 节点的直接上游节点: {direct_upstream_nodes}")
logger.info(f"[Prefix Analysis] Direct upstream nodes of End node: {direct_upstream_nodes}")
# 找到第一个直接上游节点的引用
for match in matches:
referenced_node_id = match.group(1)
logger.info(f"[前缀分析] 检查引用: {referenced_node_id}")
logger.info(f"[Prefix Analysis] Checking reference: {referenced_node_id}")
if referenced_node_id in direct_upstream_nodes:
# 这是直接上游节点的引用,提取前缀
prefix = output_template[:match.start()]
logger.info(f"[前缀分析] ✅ 找到直接上游节点 {referenced_node_id} 的引用,前缀: '{prefix}'")
logger.info(f"[Prefix Analysis] "
f"✅ Found reference to direct upstream node {referenced_node_id}, prefix: '{prefix}'")
# 标记这个节点为"相邻且被引用"
adjacent_and_referenced.add(referenced_node_id)
if prefix:
prefixes[referenced_node_id] = prefix
logger.info(f"✅ [前缀分析] 为节点 {referenced_node_id} 配置前缀: '{prefix[:50]}...'")
logger.info(f"[Prefix Analysis] "
f"✅ Assign prefix for node {referenced_node_id}: '{prefix[:50]}...'")
# 只处理第一个直接上游节点的引用
break
logger.info(f"[前缀分析] 最终配置: {prefixes}")
logger.info(f"[前缀分析] 与 End 相邻且被引用的节点: {adjacent_and_referenced}")
logger.info(f"[Prefix Analysis] Final prefixes: {prefixes}")
logger.info(f"[Prefix Analysis] Nodes adjacent to End and referenced: {adjacent_and_referenced}")
return prefixes, adjacent_and_referenced
def add_nodes(self):
"""Add all nodes from the workflow configuration to the state graph.
This method handles:
- Creation of node instances using NodeFactory.
- Special handling for start, end, and cycle nodes.
- Injection of End node prefixes for streaming mode.
- Marking nodes as adjacent to End nodes if referenced.
- Wrapping node run methods as async functions or async generators
depending on streaming mode.
Notes:
Loop nodes (nodes with `cycle` property) are handled separately
via CycleGraphNode when building subgraphs.
Returns:
None
"""
# Analyze End node prefixes if in stream mode
end_prefixes, adjacent_and_referenced = self._analyze_end_node_prefixes() if self.stream else ({}, set())
for node in self.nodes:
@@ -114,21 +143,21 @@ class GraphBuilder:
node_id = node.get("id")
cycle_node = node.get("cycle")
if cycle_node:
# 处于循环子图中的节点由 CycleGraphNode 进行构建处理
# Nodes within a loop subgraph are constructed by CycleGraphNode
if not self.subgraph:
continue
# 记录 start end 节点 ID
# Record start and end node IDs
if node_type in [NodeType.START, NodeType.CYCLE_START]:
self.start_node_id = node_id
elif node_type == NodeType.END:
self.end_node_ids.append(node_id)
# 创建节点实例(现在 start end 也会被创建)
# Create node instance (start and end nodes are also created)
# NOTE:Loop node creation automatically removes the nodes and edges of the subgraph from the current graph
node_instance = NodeFactory.create_node(node, self.workflow_config)
if node_type in [NodeType.IF_ELSE, NodeType.HTTP_REQUEST, NodeType.QUESTION_CLASSIFIER]:
if node_type in BRANCH_NODES:
# Find all edges whose source is the current node
related_edge = [edge for edge in self.edges if edge.get("source") == node_id]
@@ -142,26 +171,23 @@ class GraphBuilder:
related_edge[idx]['condition'] = f"node.{node_id}.output == '{related_edge[idx]['label']}'"
if node_instance:
# 如果是流式模式,且节点有 End 前缀配置,注入配置
# Inject End node prefix configuration if in stream mode
if self.stream and node_id in end_prefixes:
# 将 End 前缀配置注入到节点实例
node_instance._end_node_prefix = end_prefixes[node_id]
logger.info(f"为节点 {node_id} 注入 End 前缀配置")
logger.info(f"Injected End prefix for node {node_id}")
# 如果是流式模式,标记节点是否与 End 相邻且被引用
# Mark nodes as adjacent and referenced to End node in stream mode
if self.stream:
node_instance._is_adjacent_to_end = node_id in adjacent_and_referenced
if node_id in adjacent_and_referenced:
logger.info(f"节点 {node_id} 标记为与 End 相邻且被引用")
logger.info(f"Node {node_id} marked as adjacent and referenced to End node")
# 包装节点的 run 方法
# 使用函数工厂避免闭包问题
# Wrap node's run method to avoid closure issues
if self.stream:
# 流式模式:创建 async generator 函数
# LangGraph 会收集所有 yield 的值,最后一个 yield 的字典会被合并到 state
# Stream mode: create an async generator function
# LangGraph collects all yielded values; the last yielded dictionary is merged into the state
def make_stream_func(inst):
async def node_func(state: WorkflowState):
# logger.debug(f"流式执行节点: {inst.node_id}, 支持流式: {inst.supports_streaming()}")
async for item in inst.run_stream(state):
yield item
@@ -169,7 +195,7 @@ class GraphBuilder:
self.graph.add_node(node_id, make_stream_func(node_instance))
else:
# 非流式模式:创建 async function
# Non-stream mode: create an async function
def make_func(inst):
async def node_func(state: WorkflowState):
return await inst.run(state)
@@ -178,45 +204,110 @@ class GraphBuilder:
self.graph.add_node(node_id, make_func(node_instance))
logger.debug(f"添加节点: {node_id} (type={node_type}, stream={self.stream})")
logger.debug(f"Added node: {node_id} (type={node_type}, stream={self.stream})")
def add_edges(self):
"""Add all edges (normal, waiting, and conditional) to the state graph.
This method handles:
- Connecting the START node to the workflow's start node.
- Collecting waiting edges for nodes with multiple sources.
- Collecting conditional edges for routing to NOP nodes.
- Adding NOP nodes for conditional branches to allow later merging.
- Wrapping routing logic in a router function that evaluates conditions.
- Connecting End nodes to the global END node.
Notes:
- NOP nodes are used to ensure that multiple branches can merge
correctly without modifying the workflow state.
- Waiting edges are automatically handled by LangGraph to schedule
nodes only after all sources are activated.
Returns:
None
"""
# Connect the START node to the workflow's start node
if self.start_node_id:
self.graph.add_edge(START, self.start_node_id)
logger.debug(f"添加边: START -> {self.start_node_id}")
logger.debug(f"Added edge: START -> {self.start_node_id}")
# Collect all sources for each target node for normal/waiting edges
waiting_edges = defaultdict(list)
# Collect all conditional edges for each source node to construct routing
conditional_edges = defaultdict(list)
for edge in self.edges:
source = edge.get("source")
target = edge.get("target")
edge_type = edge.get("type")
condition = edge.get("condition")
edge_type = edge.get("type")
# 跳过从 start 节点出发的边(因为已经从 START 连接到 start
if source == self.start_node_id:
# 但要连接 start 到下一个节点
self.graph.add_edge(source, target)
logger.debug(f"添加边: {source} -> {target}")
continue
# # 处理到 end 节点的边
# if target in end_node_ids:
# # 连接到 end 节点
# workflow.add_edge(source, target)
# logger.debug(f"添加边: {source} -> {target}")
# continue
# 跳过错误边(在节点内部处理)
# Skip error edges (handled within nodes)
if edge_type == "error":
continue
if condition:
# 条件边
def make_router(cond, tgt):
"""Dynamically generate a conditional router function to ensure each branch has a unique name."""
# Conditional edges: group by source node
conditional_edges[source].append({
"target": target,
"condition": condition,
"label": edge.get("label")
})
else:
# Normal edges: group by target node (used for waiting edges)
waiting_edges[target].append(source)
def router_fn(state: WorkflowState):
# Add conditional edges
for source_node, branches in conditional_edges.items():
def make_router(src, branch_list):
"""reate a router function for each source node that routes to a NOP node for later merging."""
def make_branch_node(node_name, targets):
def node(s):
# NOTE: NOP NODE MUST NOT MODIFY STATE
return {
"activate": {
node_id: s["activate"][node_name]
for node_id in targets
}
}
return node
unique_branch = {}
for branch in branch_list:
if branch.get("label") not in unique_branch.keys():
nop_node_name = f"nop_{uuid.uuid4().hex[:8]}"
logger.info(f"Binding NOP: {source_node} {branch.get('label')} -> {nop_node_name}")
unique_branch[branch["label"]] = {
"condition": branch["condition"],
"node": {
"name": nop_node_name,
},
"target": [branch["target"]]
}
else:
unique_branch[branch["label"]]["target"].append(branch["target"])
# Add NOP nodes and connect them to downstream nodes
for label, branch_info in unique_branch.items():
self.graph.add_node(
branch_info["node"]["name"],
make_branch_node(
branch_info["node"]["name"],
branch_info["target"]
)
)
for target in branch_info["target"]:
waiting_edges[target].append(branch_info["node"]["name"])
def router_fn(state: WorkflowState) -> list[Send]:
branch_activate = []
new_state = state.copy()
new_state["activate"] = dict(state.get("activate", {})) # deep copy of activate
for label, branch in unique_branch.items():
if evaluate_condition(
cond,
branch["condition"],
state.get("variables", {}),
state.get("runtime_vars", {}),
{
@@ -225,30 +316,45 @@ class GraphBuilder:
"user_id": state.get("user_id")
}
):
return tgt
return END
logger.debug(f"Conditional routing {src}: selected branch {label}")
new_state["activate"][branch["node"]["name"]] = True
continue
new_state["activate"][branch["node"]["name"]] = False
for label, branch in unique_branch.items():
branch_activate.append(
Send(
branch['node']['name'],
new_state
)
)
return branch_activate
# 动态修改函数名,避免重复
router_fn.__name__ = f"router_{uuid.uuid4().hex[:8]}_{tgt}"
return router_fn
# Dynamically set function name
router_fn.__name__ = f"router_{uuid.uuid4().hex[:8]}_{src}"
return router_fn
router_fn = make_router(condition, target)
self.graph.add_conditional_edges(source, router_fn)
logger.debug(f"添加条件边: {source} -> {target} (condition={condition})")
router_fn = make_router(source_node, branches)
self.graph.add_conditional_edges(source_node, router_fn)
logger.debug(f"Added conditional edges: {source_node} -> {[b['target'] for b in branches]}")
# Add normal/waiting edges
for target, sources in waiting_edges.items():
if len(sources) == 1:
# Single source: normal edge
self.graph.add_edge(sources[0], target)
logger.debug(f"Added edge: {sources[0]} -> {target}")
else:
# 普通边
self.graph.add_edge(source, target)
logger.debug(f"添加边: {source} -> {target}")
# Multiple sources: waiting edge
self.graph.add_edge(sources, target)
logger.debug(f"Added waiting edge: {sources} -> {target}")
# 从 end 节点连接到 END
# Connect End nodes to the global END node
for end_node_id in self.end_node_ids:
self.graph.add_edge(end_node_id, END)
logger.debug(f"添加边: {end_node_id} -> END")
logger.debug(f"Added edge: {end_node_id} -> END")
return
def build(self) -> CompiledStateGraph:
self.graph = StateGraph(WorkflowState)
self.add_nodes()
self.add_edges() # 添加边必须在添加节点之后
checkpointer = InMemorySaver()
return self.graph.compile(checkpointer=checkpointer)
self.graph = self.graph.compile(checkpointer=checkpointer)
return self.graph

View File

@@ -14,6 +14,7 @@ logger = logging.getLogger(__name__)
class AssignerNode(BaseNode):
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
super().__init__(node_config, workflow_config)
self.variable_updater = True
self.typed_config: AssignerNodeConfig | None = None
async def execute(self, state: WorkflowState) -> Any:

View File

@@ -7,18 +7,26 @@
import asyncio
import logging
from abc import ABC, abstractmethod
from typing import Any
from typing import Any, AsyncGenerator
from langchain_core.messages import AIMessage
from langgraph.config import get_stream_writer
from typing_extensions import TypedDict, Annotated
from app.core.config import settings
from app.core.workflow.nodes.enums import BRANCH_NODES
from app.core.workflow.variable_pool import VariablePool
logger = logging.getLogger(__name__)
def merget_activate_state(x, y):
return {
k: x.get(k, False) or y.get(k, False)
for k in set(x) | set(y)
}
class WorkflowState(TypedDict):
"""Workflow state
@@ -60,6 +68,9 @@ class WorkflowState(TypedDict):
# Format: {node_id: {"chunks": [...], "full_content": "..."}}
streaming_buffer: Annotated[dict[str, Any], lambda x, y: {**x, **y}]
# node activate status
activate: Annotated[dict[str, bool], merget_activate_state]
class BaseNode(ABC):
"""节点基类
@@ -84,6 +95,47 @@ class BaseNode(ABC):
self.config = node_config.get("config") or {}
self.error_handling = node_config.get("error_handling") or {}
self.variable_updater = False
def check_activate(self, state: WorkflowState):
"""Check if the current node is activated in the workflow state.
Args:
state (WorkflowState): The current workflow state containing the 'activate' dict.
Returns:
bool: True if the node is activated, False otherwise.
"""
return state["activate"][self.node_id]
def trans_activate(self, state: WorkflowState):
"""Transform the activation state for downstream nodes.
This method collects all downstream nodes (excluding branch nodes)
connected to the current node and returns a dict indicating whether
each of these nodes should be activated based on the current node's state.
The current node itself is also included in the returned activation dict.
Args:
state (WorkflowState): The current workflow state.
Returns:
dict: A dict with a single key 'activate', mapping node IDs to
their activation status (True/False).
"""
edges = self.workflow_config.get("edges")
under_stream_nodes = [
edge.get("target")
for edge in edges
if edge.get("source") == self.node_id and self.node_type not in BRANCH_NODES
]
return {
"activate": {
node_id: self.check_activate(state)
for node_id in under_stream_nodes
} | {self.node_id: self.check_activate(state)}
}
@abstractmethod
async def execute(self, state: WorkflowState) -> Any:
"""执行节点业务逻辑(非流式)
@@ -99,13 +151,13 @@ class BaseNode(ABC):
Examples:
>>> # LLM 节点
>>> return "这是 AI 的回复"
>>> "这是 AI 的回复"
>>> # Transform 节点
>>> return {"processed_data": [...]}
>>> {"processed_data": [...]}
>>> # Start/End 节点
>>> return {"message": "开始", "conversation_id": "xxx"}
>>> {"message": "开始", "conversation_id": "xxx"}
"""
pass
@@ -126,14 +178,14 @@ class BaseNode(ABC):
业务数据chunk或完成标记
Examples:
>>> # 流式 LLM 节点
>>> full_response = ""
>>> async for chunk in llm.astream(prompt):
... full_response += chunk
... yield chunk # yield 文本片段
>>>
>>> # 最后 yield 完成标记
>>> yield {"__final__": True, "result": AIMessage(content=full_response)}
# 流式 LLM 节点
full_response = ""
async for chunk in llm.astream(prompt):
full_response += chunk
yield chunk # yield 文本片段
# 最后 yield 完成标记
yield {"__final__": True, "result": AIMessage(content=full_response)}
"""
result = await self.execute(state)
# 默认实现:直接 yield 完成标记
@@ -146,7 +198,7 @@ class BaseNode(ABC):
是否支持流式输出
"""
# 检查子类是否重写了 execute_stream 方法
return self.execute_stream.__func__ != BaseNode.execute_stream.__func__
return self.__class__.execute_stream is not BaseNode.execute_stream
def get_timeout(self) -> int:
"""获取超时时间(秒)
@@ -172,6 +224,9 @@ class BaseNode(ABC):
Returns:
标准化的状态更新字典
"""
if not self.check_activate(state):
return self.trans_activate(state)
import time
start_time = time.time()
@@ -204,12 +259,11 @@ class BaseNode(ABC):
return {
**wrapped_output,
"messages": state["messages"],
"variables": state["variables"],
"runtime_vars": {
self.node_id: runtime_var
},
"looping": state["looping"]
}
} | self.trans_activate(state)
except TimeoutError:
elapsed_time = time.time() - start_time
@@ -220,7 +274,7 @@ class BaseNode(ABC):
logger.error(f"节点 {self.node_id} 执行失败: {e}", exc_info=True)
return self._wrap_error(str(e), elapsed_time, state)
async def run_stream(self, state: WorkflowState):
async def run_stream(self, state: WorkflowState) -> AsyncGenerator[dict[str, Any], Any]:
"""Execute node with error handling and output wrapping (streaming)
This method is called by the Executor and is responsible for:
@@ -241,6 +295,11 @@ class BaseNode(ABC):
Yields:
State updates with streaming buffer and final result
"""
if not self.check_activate(state):
yield self.trans_activate(state)
logger.info(f"跳过节点{self.node_id}")
return
import time
start_time = time.time()
@@ -358,7 +417,6 @@ class BaseNode(ABC):
state_update = {
**final_output,
"messages": state["messages"],
"variables": state["variables"],
"runtime_vars": {
self.node_id: runtime_var
},
@@ -377,7 +435,7 @@ class BaseNode(ABC):
# Finally yield state update
# LangGraph will merge this into state
yield state_update
yield state_update | self.trans_activate(state)
except TimeoutError:
elapsed_time = time.time() - start_time
@@ -427,12 +485,13 @@ class BaseNode(ABC):
"token_usage": token_usage,
"error": None
}
return {
"node_outputs": {
self.node_id: node_output
}
final_output = {
"node_outputs": {self.node_id: node_output},
}
if self.variable_updater:
final_output = final_output | {"variables": state["variables"]}
return final_output
def _wrap_error(
self,

View File

@@ -26,6 +26,9 @@ class NodeType(StrEnum):
MEMORY_WRITE = "memory-write"
BRANCH_NODES = [NodeType.IF_ELSE, NodeType.HTTP_REQUEST, NodeType.QUESTION_CLASSIFIER]
class ComparisonOperator(StrEnum):
EMPTY = "empty"
NOT_EMPTY = "not_empty"

View File

@@ -1445,7 +1445,7 @@ class AppService:
target_workspace_ids: List[uuid.UUID],
user_id: uuid.UUID,
workspace_id: Optional[uuid.UUID] = None
) -> AppShare:
) -> list[AppShare]:
"""分享应用到其他工作空间
Args: