[modify] llm & end logic

This commit is contained in:
Mark
2025-12-20 17:37:36 +08:00
parent 36b36b729b
commit 43a427bac7
4 changed files with 35 additions and 9 deletions

View File

@@ -93,18 +93,19 @@ class WorkflowExecutor:
def _analyze_end_node_prefixes(self) -> dict[str, str]:
def _analyze_end_node_prefixes(self) -> tuple[dict[str, str], set[str]]:
"""分析 End 节点的前缀配置
检查每个 End 节点的模板,找到直接上游节点的引用,
提取该引用之前的前缀部分。
Returns:
字典:{上游节点ID: End节点前缀}
元组:({上游节点ID: End节点前缀}, {与End相邻且被引用的节点ID集合})
"""
import re
prefixes = {}
adjacent_and_referenced = set() # 记录与 End 节点相邻且被引用的节点
# 找到所有 End 节点
end_nodes = [node for node in self.nodes if node.get("type") == "end"]
@@ -146,6 +147,9 @@ class WorkflowExecutor:
logger.info(f"[前缀分析] ✅ 找到直接上游节点 {referenced_node_id} 的引用,前缀: '{prefix}'")
# 标记这个节点为"相邻且被引用"
adjacent_and_referenced.add(referenced_node_id)
if prefix:
prefixes[referenced_node_id] = prefix
logger.info(f"✅ [前缀分析] 为节点 {referenced_node_id} 配置前缀: '{prefix[:50]}...'")
@@ -154,7 +158,8 @@ class WorkflowExecutor:
break
logger.info(f"[前缀分析] 最终配置: {prefixes}")
return prefixes
logger.info(f"[前缀分析] 与 End 相邻且被引用的节点: {adjacent_and_referenced}")
return prefixes, adjacent_and_referenced
def build_graph(self,stream=False) -> CompiledStateGraph:
"""构建 LangGraph
@@ -164,8 +169,8 @@ class WorkflowExecutor:
"""
logger.info(f"开始构建工作流图: execution_id={self.execution_id}")
# 分析 End 节点的前缀配置
end_prefixes = self._analyze_end_node_prefixes() if stream else {}
# 分析 End 节点的前缀配置和相邻且被引用的节点
end_prefixes, adjacent_and_referenced = self._analyze_end_node_prefixes() if stream else ({}, set())
# 1. 创建状态图
workflow = StateGraph(WorkflowState)
@@ -193,6 +198,12 @@ class WorkflowExecutor:
node_instance._end_node_prefix = end_prefixes[node_id]
logger.info(f"为节点 {node_id} 注入 End 前缀配置")
# 如果是流式模式,标记节点是否与 End 相邻且被引用
if stream:
node_instance._is_adjacent_to_end = node_id in adjacent_and_referenced
if node_id in adjacent_and_referenced:
logger.info(f"节点 {node_id} 标记为与 End 相邻且被引用")
# 包装节点的 run 方法
# 使用函数工厂避免闭包问题
if stream:
@@ -401,13 +412,16 @@ class WorkflowExecutor:
if mode == "custom":
# Handle custom streaming events (chunks from nodes via stream writer)
chunk_count += 1
logger.info(f"[CUSTOM] ✅ 收到 chunk #{chunk_count} from {data.get('node_id')}")
event_type = data.get("type", "node_chunk") # 默认为 node_chunk
logger.info(f"[CUSTOM] ✅ 收到 {event_type} #{chunk_count} from {data.get('node_id')}")
yield {
"type": "node_chunk",
"type": event_type, # "message" or "node_chunk"
"node_id": data.get("node_id"),
"chunk": data.get("chunk"),
"full_content": data.get("full_content"),
"chunk_index": data.get("chunk_index")
"chunk_index": data.get("chunk_index"),
"is_prefix": data.get("is_prefix"),
"is_suffix": data.get("is_suffix")
}
elif mode == "debug":

View File

@@ -240,6 +240,14 @@ class BaseNode(ABC):
# End nodes CAN send chunks (for suffix), but only after LLM content
is_end_node = self.node_type == "end"
# Check if this node is adjacent to End node (for message type)
is_adjacent_to_end = getattr(self, '_is_adjacent_to_end', False)
# Determine chunk type: "message" for End and adjacent nodes, "node_chunk" for others
chunk_type = "message" if (is_end_node or is_adjacent_to_end) else "node_chunk"
logger.debug(f"节点 {self.node_id} chunk 类型: {chunk_type} (is_end={is_end_node}, adjacent={is_adjacent_to_end})")
# Accumulate complete result (for final wrapping)
chunks = []
final_result = None
@@ -267,6 +275,7 @@ class BaseNode(ABC):
# 1. Send via stream writer (for real-time client updates)
writer({
"type": chunk_type, # "message" or "node_chunk"
"node_id": self.node_id,
"chunk": item,
"full_content": full_content,
@@ -294,6 +303,7 @@ class BaseNode(ABC):
# Send chunks for all nodes
writer({
"type": chunk_type, # "message" or "node_chunk"
"node_id": self.node_id,
"chunk": chunk_str,
"full_content": full_content,

View File

@@ -236,6 +236,7 @@ class EndNode(BaseNode):
from langgraph.config import get_stream_writer
writer = get_stream_writer()
writer({
"type": "message", # End 节点的输出使用 message 类型
"node_id": self.node_id,
"chunk": suffix,
"full_content": full_output, # full_content 是完整的渲染结果(前缀+LLM+后缀)

View File

@@ -234,8 +234,9 @@ class LLMNode(BaseNode):
rendered_prefix = self._render_template(end_prefix, state)
logger.info(f"节点 {self.node_id} 提前发送 End 节点前缀: '{rendered_prefix[:50]}...'")
# 提前发送 End 节点的前缀
# 提前发送 End 节点的前缀(使用 "message" 类型)
writer({
"type": "message", # End 相关的内容都是 message 类型
"node_id": "end", # 标记为 end 节点的输出
"chunk": rendered_prefix,
"full_content": rendered_prefix,