[modify] llm & end logic
This commit is contained in:
@@ -93,18 +93,19 @@ class WorkflowExecutor:
|
||||
|
||||
|
||||
|
||||
def _analyze_end_node_prefixes(self) -> dict[str, str]:
|
||||
def _analyze_end_node_prefixes(self) -> tuple[dict[str, str], set[str]]:
|
||||
"""分析 End 节点的前缀配置
|
||||
|
||||
检查每个 End 节点的模板,找到直接上游节点的引用,
|
||||
提取该引用之前的前缀部分。
|
||||
|
||||
Returns:
|
||||
字典:{上游节点ID: End节点前缀}
|
||||
元组:({上游节点ID: End节点前缀}, {与End相邻且被引用的节点ID集合})
|
||||
"""
|
||||
import re
|
||||
|
||||
prefixes = {}
|
||||
adjacent_and_referenced = set() # 记录与 End 节点相邻且被引用的节点
|
||||
|
||||
# 找到所有 End 节点
|
||||
end_nodes = [node for node in self.nodes if node.get("type") == "end"]
|
||||
@@ -146,6 +147,9 @@ class WorkflowExecutor:
|
||||
|
||||
logger.info(f"[前缀分析] ✅ 找到直接上游节点 {referenced_node_id} 的引用,前缀: '{prefix}'")
|
||||
|
||||
# 标记这个节点为"相邻且被引用"
|
||||
adjacent_and_referenced.add(referenced_node_id)
|
||||
|
||||
if prefix:
|
||||
prefixes[referenced_node_id] = prefix
|
||||
logger.info(f"✅ [前缀分析] 为节点 {referenced_node_id} 配置前缀: '{prefix[:50]}...'")
|
||||
@@ -154,7 +158,8 @@ class WorkflowExecutor:
|
||||
break
|
||||
|
||||
logger.info(f"[前缀分析] 最终配置: {prefixes}")
|
||||
return prefixes
|
||||
logger.info(f"[前缀分析] 与 End 相邻且被引用的节点: {adjacent_and_referenced}")
|
||||
return prefixes, adjacent_and_referenced
|
||||
|
||||
def build_graph(self,stream=False) -> CompiledStateGraph:
|
||||
"""构建 LangGraph
|
||||
@@ -164,8 +169,8 @@ class WorkflowExecutor:
|
||||
"""
|
||||
logger.info(f"开始构建工作流图: execution_id={self.execution_id}")
|
||||
|
||||
# 分析 End 节点的前缀配置
|
||||
end_prefixes = self._analyze_end_node_prefixes() if stream else {}
|
||||
# 分析 End 节点的前缀配置和相邻且被引用的节点
|
||||
end_prefixes, adjacent_and_referenced = self._analyze_end_node_prefixes() if stream else ({}, set())
|
||||
|
||||
# 1. 创建状态图
|
||||
workflow = StateGraph(WorkflowState)
|
||||
@@ -193,6 +198,12 @@ class WorkflowExecutor:
|
||||
node_instance._end_node_prefix = end_prefixes[node_id]
|
||||
logger.info(f"为节点 {node_id} 注入 End 前缀配置")
|
||||
|
||||
# 如果是流式模式,标记节点是否与 End 相邻且被引用
|
||||
if stream:
|
||||
node_instance._is_adjacent_to_end = node_id in adjacent_and_referenced
|
||||
if node_id in adjacent_and_referenced:
|
||||
logger.info(f"节点 {node_id} 标记为与 End 相邻且被引用")
|
||||
|
||||
# 包装节点的 run 方法
|
||||
# 使用函数工厂避免闭包问题
|
||||
if stream:
|
||||
@@ -401,13 +412,16 @@ class WorkflowExecutor:
|
||||
if mode == "custom":
|
||||
# Handle custom streaming events (chunks from nodes via stream writer)
|
||||
chunk_count += 1
|
||||
logger.info(f"[CUSTOM] ✅ 收到 chunk #{chunk_count} from {data.get('node_id')}")
|
||||
event_type = data.get("type", "node_chunk") # 默认为 node_chunk
|
||||
logger.info(f"[CUSTOM] ✅ 收到 {event_type} #{chunk_count} from {data.get('node_id')}")
|
||||
yield {
|
||||
"type": "node_chunk",
|
||||
"type": event_type, # "message" or "node_chunk"
|
||||
"node_id": data.get("node_id"),
|
||||
"chunk": data.get("chunk"),
|
||||
"full_content": data.get("full_content"),
|
||||
"chunk_index": data.get("chunk_index")
|
||||
"chunk_index": data.get("chunk_index"),
|
||||
"is_prefix": data.get("is_prefix"),
|
||||
"is_suffix": data.get("is_suffix")
|
||||
}
|
||||
|
||||
elif mode == "debug":
|
||||
|
||||
@@ -240,6 +240,14 @@ class BaseNode(ABC):
|
||||
# End nodes CAN send chunks (for suffix), but only after LLM content
|
||||
is_end_node = self.node_type == "end"
|
||||
|
||||
# Check if this node is adjacent to End node (for message type)
|
||||
is_adjacent_to_end = getattr(self, '_is_adjacent_to_end', False)
|
||||
|
||||
# Determine chunk type: "message" for End and adjacent nodes, "node_chunk" for others
|
||||
chunk_type = "message" if (is_end_node or is_adjacent_to_end) else "node_chunk"
|
||||
|
||||
logger.debug(f"节点 {self.node_id} chunk 类型: {chunk_type} (is_end={is_end_node}, adjacent={is_adjacent_to_end})")
|
||||
|
||||
# Accumulate complete result (for final wrapping)
|
||||
chunks = []
|
||||
final_result = None
|
||||
@@ -267,6 +275,7 @@ class BaseNode(ABC):
|
||||
|
||||
# 1. Send via stream writer (for real-time client updates)
|
||||
writer({
|
||||
"type": chunk_type, # "message" or "node_chunk"
|
||||
"node_id": self.node_id,
|
||||
"chunk": item,
|
||||
"full_content": full_content,
|
||||
@@ -294,6 +303,7 @@ class BaseNode(ABC):
|
||||
|
||||
# Send chunks for all nodes
|
||||
writer({
|
||||
"type": chunk_type, # "message" or "node_chunk"
|
||||
"node_id": self.node_id,
|
||||
"chunk": chunk_str,
|
||||
"full_content": full_content,
|
||||
|
||||
@@ -236,6 +236,7 @@ class EndNode(BaseNode):
|
||||
from langgraph.config import get_stream_writer
|
||||
writer = get_stream_writer()
|
||||
writer({
|
||||
"type": "message", # End 节点的输出使用 message 类型
|
||||
"node_id": self.node_id,
|
||||
"chunk": suffix,
|
||||
"full_content": full_output, # full_content 是完整的渲染结果(前缀+LLM+后缀)
|
||||
|
||||
@@ -234,8 +234,9 @@ class LLMNode(BaseNode):
|
||||
rendered_prefix = self._render_template(end_prefix, state)
|
||||
logger.info(f"节点 {self.node_id} 提前发送 End 节点前缀: '{rendered_prefix[:50]}...'")
|
||||
|
||||
# 提前发送 End 节点的前缀
|
||||
# 提前发送 End 节点的前缀(使用 "message" 类型)
|
||||
writer({
|
||||
"type": "message", # End 相关的内容都是 message 类型
|
||||
"node_id": "end", # 标记为 end 节点的输出
|
||||
"chunk": rendered_prefix,
|
||||
"full_content": rendered_prefix,
|
||||
|
||||
Reference in New Issue
Block a user