[modify] llm & end logic
This commit is contained in:
@@ -93,18 +93,19 @@ class WorkflowExecutor:
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
def _analyze_end_node_prefixes(self) -> dict[str, str]:
|
def _analyze_end_node_prefixes(self) -> tuple[dict[str, str], set[str]]:
|
||||||
"""分析 End 节点的前缀配置
|
"""分析 End 节点的前缀配置
|
||||||
|
|
||||||
检查每个 End 节点的模板,找到直接上游节点的引用,
|
检查每个 End 节点的模板,找到直接上游节点的引用,
|
||||||
提取该引用之前的前缀部分。
|
提取该引用之前的前缀部分。
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
字典:{上游节点ID: End节点前缀}
|
元组:({上游节点ID: End节点前缀}, {与End相邻且被引用的节点ID集合})
|
||||||
"""
|
"""
|
||||||
import re
|
import re
|
||||||
|
|
||||||
prefixes = {}
|
prefixes = {}
|
||||||
|
adjacent_and_referenced = set() # 记录与 End 节点相邻且被引用的节点
|
||||||
|
|
||||||
# 找到所有 End 节点
|
# 找到所有 End 节点
|
||||||
end_nodes = [node for node in self.nodes if node.get("type") == "end"]
|
end_nodes = [node for node in self.nodes if node.get("type") == "end"]
|
||||||
@@ -146,6 +147,9 @@ class WorkflowExecutor:
|
|||||||
|
|
||||||
logger.info(f"[前缀分析] ✅ 找到直接上游节点 {referenced_node_id} 的引用,前缀: '{prefix}'")
|
logger.info(f"[前缀分析] ✅ 找到直接上游节点 {referenced_node_id} 的引用,前缀: '{prefix}'")
|
||||||
|
|
||||||
|
# 标记这个节点为"相邻且被引用"
|
||||||
|
adjacent_and_referenced.add(referenced_node_id)
|
||||||
|
|
||||||
if prefix:
|
if prefix:
|
||||||
prefixes[referenced_node_id] = prefix
|
prefixes[referenced_node_id] = prefix
|
||||||
logger.info(f"✅ [前缀分析] 为节点 {referenced_node_id} 配置前缀: '{prefix[:50]}...'")
|
logger.info(f"✅ [前缀分析] 为节点 {referenced_node_id} 配置前缀: '{prefix[:50]}...'")
|
||||||
@@ -154,7 +158,8 @@ class WorkflowExecutor:
|
|||||||
break
|
break
|
||||||
|
|
||||||
logger.info(f"[前缀分析] 最终配置: {prefixes}")
|
logger.info(f"[前缀分析] 最终配置: {prefixes}")
|
||||||
return prefixes
|
logger.info(f"[前缀分析] 与 End 相邻且被引用的节点: {adjacent_and_referenced}")
|
||||||
|
return prefixes, adjacent_and_referenced
|
||||||
|
|
||||||
def build_graph(self,stream=False) -> CompiledStateGraph:
|
def build_graph(self,stream=False) -> CompiledStateGraph:
|
||||||
"""构建 LangGraph
|
"""构建 LangGraph
|
||||||
@@ -164,8 +169,8 @@ class WorkflowExecutor:
|
|||||||
"""
|
"""
|
||||||
logger.info(f"开始构建工作流图: execution_id={self.execution_id}")
|
logger.info(f"开始构建工作流图: execution_id={self.execution_id}")
|
||||||
|
|
||||||
# 分析 End 节点的前缀配置
|
# 分析 End 节点的前缀配置和相邻且被引用的节点
|
||||||
end_prefixes = self._analyze_end_node_prefixes() if stream else {}
|
end_prefixes, adjacent_and_referenced = self._analyze_end_node_prefixes() if stream else ({}, set())
|
||||||
|
|
||||||
# 1. 创建状态图
|
# 1. 创建状态图
|
||||||
workflow = StateGraph(WorkflowState)
|
workflow = StateGraph(WorkflowState)
|
||||||
@@ -193,6 +198,12 @@ class WorkflowExecutor:
|
|||||||
node_instance._end_node_prefix = end_prefixes[node_id]
|
node_instance._end_node_prefix = end_prefixes[node_id]
|
||||||
logger.info(f"为节点 {node_id} 注入 End 前缀配置")
|
logger.info(f"为节点 {node_id} 注入 End 前缀配置")
|
||||||
|
|
||||||
|
# 如果是流式模式,标记节点是否与 End 相邻且被引用
|
||||||
|
if stream:
|
||||||
|
node_instance._is_adjacent_to_end = node_id in adjacent_and_referenced
|
||||||
|
if node_id in adjacent_and_referenced:
|
||||||
|
logger.info(f"节点 {node_id} 标记为与 End 相邻且被引用")
|
||||||
|
|
||||||
# 包装节点的 run 方法
|
# 包装节点的 run 方法
|
||||||
# 使用函数工厂避免闭包问题
|
# 使用函数工厂避免闭包问题
|
||||||
if stream:
|
if stream:
|
||||||
@@ -401,13 +412,16 @@ class WorkflowExecutor:
|
|||||||
if mode == "custom":
|
if mode == "custom":
|
||||||
# Handle custom streaming events (chunks from nodes via stream writer)
|
# Handle custom streaming events (chunks from nodes via stream writer)
|
||||||
chunk_count += 1
|
chunk_count += 1
|
||||||
logger.info(f"[CUSTOM] ✅ 收到 chunk #{chunk_count} from {data.get('node_id')}")
|
event_type = data.get("type", "node_chunk") # 默认为 node_chunk
|
||||||
|
logger.info(f"[CUSTOM] ✅ 收到 {event_type} #{chunk_count} from {data.get('node_id')}")
|
||||||
yield {
|
yield {
|
||||||
"type": "node_chunk",
|
"type": event_type, # "message" or "node_chunk"
|
||||||
"node_id": data.get("node_id"),
|
"node_id": data.get("node_id"),
|
||||||
"chunk": data.get("chunk"),
|
"chunk": data.get("chunk"),
|
||||||
"full_content": data.get("full_content"),
|
"full_content": data.get("full_content"),
|
||||||
"chunk_index": data.get("chunk_index")
|
"chunk_index": data.get("chunk_index"),
|
||||||
|
"is_prefix": data.get("is_prefix"),
|
||||||
|
"is_suffix": data.get("is_suffix")
|
||||||
}
|
}
|
||||||
|
|
||||||
elif mode == "debug":
|
elif mode == "debug":
|
||||||
|
|||||||
@@ -240,6 +240,14 @@ class BaseNode(ABC):
|
|||||||
# End nodes CAN send chunks (for suffix), but only after LLM content
|
# End nodes CAN send chunks (for suffix), but only after LLM content
|
||||||
is_end_node = self.node_type == "end"
|
is_end_node = self.node_type == "end"
|
||||||
|
|
||||||
|
# Check if this node is adjacent to End node (for message type)
|
||||||
|
is_adjacent_to_end = getattr(self, '_is_adjacent_to_end', False)
|
||||||
|
|
||||||
|
# Determine chunk type: "message" for End and adjacent nodes, "node_chunk" for others
|
||||||
|
chunk_type = "message" if (is_end_node or is_adjacent_to_end) else "node_chunk"
|
||||||
|
|
||||||
|
logger.debug(f"节点 {self.node_id} chunk 类型: {chunk_type} (is_end={is_end_node}, adjacent={is_adjacent_to_end})")
|
||||||
|
|
||||||
# Accumulate complete result (for final wrapping)
|
# Accumulate complete result (for final wrapping)
|
||||||
chunks = []
|
chunks = []
|
||||||
final_result = None
|
final_result = None
|
||||||
@@ -267,6 +275,7 @@ class BaseNode(ABC):
|
|||||||
|
|
||||||
# 1. Send via stream writer (for real-time client updates)
|
# 1. Send via stream writer (for real-time client updates)
|
||||||
writer({
|
writer({
|
||||||
|
"type": chunk_type, # "message" or "node_chunk"
|
||||||
"node_id": self.node_id,
|
"node_id": self.node_id,
|
||||||
"chunk": item,
|
"chunk": item,
|
||||||
"full_content": full_content,
|
"full_content": full_content,
|
||||||
@@ -294,6 +303,7 @@ class BaseNode(ABC):
|
|||||||
|
|
||||||
# Send chunks for all nodes
|
# Send chunks for all nodes
|
||||||
writer({
|
writer({
|
||||||
|
"type": chunk_type, # "message" or "node_chunk"
|
||||||
"node_id": self.node_id,
|
"node_id": self.node_id,
|
||||||
"chunk": chunk_str,
|
"chunk": chunk_str,
|
||||||
"full_content": full_content,
|
"full_content": full_content,
|
||||||
|
|||||||
@@ -236,6 +236,7 @@ class EndNode(BaseNode):
|
|||||||
from langgraph.config import get_stream_writer
|
from langgraph.config import get_stream_writer
|
||||||
writer = get_stream_writer()
|
writer = get_stream_writer()
|
||||||
writer({
|
writer({
|
||||||
|
"type": "message", # End 节点的输出使用 message 类型
|
||||||
"node_id": self.node_id,
|
"node_id": self.node_id,
|
||||||
"chunk": suffix,
|
"chunk": suffix,
|
||||||
"full_content": full_output, # full_content 是完整的渲染结果(前缀+LLM+后缀)
|
"full_content": full_output, # full_content 是完整的渲染结果(前缀+LLM+后缀)
|
||||||
|
|||||||
@@ -234,8 +234,9 @@ class LLMNode(BaseNode):
|
|||||||
rendered_prefix = self._render_template(end_prefix, state)
|
rendered_prefix = self._render_template(end_prefix, state)
|
||||||
logger.info(f"节点 {self.node_id} 提前发送 End 节点前缀: '{rendered_prefix[:50]}...'")
|
logger.info(f"节点 {self.node_id} 提前发送 End 节点前缀: '{rendered_prefix[:50]}...'")
|
||||||
|
|
||||||
# 提前发送 End 节点的前缀
|
# 提前发送 End 节点的前缀(使用 "message" 类型)
|
||||||
writer({
|
writer({
|
||||||
|
"type": "message", # End 相关的内容都是 message 类型
|
||||||
"node_id": "end", # 标记为 end 节点的输出
|
"node_id": "end", # 标记为 end 节点的输出
|
||||||
"chunk": rendered_prefix,
|
"chunk": rendered_prefix,
|
||||||
"full_content": rendered_prefix,
|
"full_content": rendered_prefix,
|
||||||
|
|||||||
Reference in New Issue
Block a user