Merge pull request #142 from SuanmoSuanyangTechnology/feature/workflow-release

Fix workflow release issues and enhance token metrics & loop node outputs
This commit is contained in:
Mark
2026-01-19 15:46:12 +08:00
committed by GitHub
15 changed files with 402 additions and 339 deletions

View File

@@ -8,6 +8,7 @@ import logging
import uuid
from typing import Any
from langchain_core.runnables import RunnableConfig
from langgraph.graph.state import CompiledStateGraph
from app.core.workflow.graph_builder import GraphBuilder
@@ -53,11 +54,11 @@ class WorkflowExecutor:
self.edges = workflow_config.get("edges", [])
self.execution_config = workflow_config.get("execution_config", {})
self.checkpoint_config = {
"configurable": {
self.checkpoint_config = RunnableConfig(
configurable={
"thread_id": uuid.uuid4(),
}
}
)
def _prepare_initial_state(self, input_data: dict[str, Any]) -> WorkflowState:
"""准备初始状态(注入系统变量和会话变量)
@@ -214,13 +215,13 @@ class WorkflowExecutor:
return {
"status": "completed",
"output": final_output,
"variables": result.get("variables", {}),
"node_outputs": node_outputs,
"messages": result.get("messages", []),
"conversation_id": conversation_id,
"elapsed_time": elapsed_time,
"token_usage": token_usage,
"error": result.get("error"),
"variables": result.get("variables", {}),
}
def build_graph(self, stream=False) -> CompiledStateGraph:
@@ -326,11 +327,10 @@ class WorkflowExecutor:
}
# 1. 构建图
graph = self.build_graph(True)
graph = self.build_graph(stream=True)
# 2. 初始化状态(自动注入系统变量)
initial_state = self._prepare_initial_state(input_data)
# 3. Execute workflow
try:
chunk_count = 0
@@ -346,14 +346,16 @@ class WorkflowExecutor:
mode, data = event
else:
# Unexpected format, log and skip
logger.warning(f"[STREAM] Unexpected event format: {type(event)}, value: {event}")
logger.warning(f"[STREAM] Unexpected event format: {type(event)}, value: {event}"
f"- execution_id: {self.execution_id}")
continue
if mode == "custom":
# Handle custom streaming events (chunks from nodes via stream writer)
chunk_count += 1
event_type = data.get("type", "node_chunk") # "message" or "node_chunk"
logger.info(f"[CUSTOM] ✅ 收到 {event_type} #{chunk_count} from {data.get('node_id')}")
logger.info(f"[CUSTOM] ✅ 收到 {event_type} #{chunk_count} from {data.get('node_id')}"
f"- execution_id: {self.execution_id}")
yield {
"event": event_type, # "message" or "node_chunk"
"data": {
@@ -380,7 +382,8 @@ class WorkflowExecutor:
variables_sys = variables.get("sys", {})
conversation_id = input_data.get("conversation_id")
execution_id = variables_sys.get("execution_id")
logger.info(f"[DEBUG] Node starts execution: {node_name}")
logger.info(f"[NODE-START] Node starts execution: {node_name} "
f"- execution_id: {self.execution_id}")
yield {
"event": "node_start",
@@ -399,7 +402,8 @@ class WorkflowExecutor:
variables_sys = variables.get("sys", {})
conversation_id = input_data.get("conversation_id")
execution_id = variables_sys.get("execution_id")
logger.info(f"[DEBUG] Node execution completed: {node_name}")
logger.info(f"[NODE-END] Node execution completed: {node_name} "
f"- execution_id: {self.execution_id}")
yield {
"event": "node_end",
@@ -407,13 +411,15 @@ class WorkflowExecutor:
"node_id": node_name,
"conversation_id": conversation_id,
"execution_id": execution_id,
"timestamp": data.get("timestamp")
"timestamp": data.get("timestamp"),
"state": result.get("node_outputs", {}).get(node_name),
}
}
elif mode == "updates":
# Handle state updates - store final state
logger.debug(f"[UPDATES] 收到 state 更新 from {list(data.keys())}")
logger.debug(f"[UPDATES] 收到 state 更新 from {list(data.keys())} "
f"- execution_id: {self.execution_id}")
# 计算耗时
end_time = datetime.datetime.now()
@@ -421,7 +427,7 @@ class WorkflowExecutor:
result = graph.get_state(self.checkpoint_config).values
logger.info(
f"Workflow execution completed (streaming), "
f"total chunks: {chunk_count}, elapsed: {elapsed_time:.2f}s"
f"total chunks: {chunk_count}, elapsed: {elapsed_time:.2f}s, execution_id: {self.execution_id}"
)
# 发送 workflow_end 事件
@@ -449,7 +455,8 @@ class WorkflowExecutor:
}
}
def _extract_final_output(self, node_outputs: dict[str, Any]) -> str | None:
@staticmethod
def _extract_final_output(node_outputs: dict[str, Any]) -> str | None:
"""从节点输出中提取最终输出
优先级:
@@ -473,7 +480,8 @@ class WorkflowExecutor:
return None
def _aggregate_token_usage(self, node_outputs: dict[str, Any]) -> dict[str, int] | None:
@staticmethod
def _aggregate_token_usage(node_outputs: dict[str, Any]) -> dict[str, int] | None:
"""聚合所有节点的 token 使用情况
Args:

View File

@@ -25,7 +25,7 @@ class WorkflowState(TypedDict):
The state object passed between nodes in a workflow, containing messages, variables, node outputs, etc.
"""
# List of messages (append mode)
messages: list[dict[str, str]]
messages: Annotated[list[dict[str, str]], lambda x, y: y]
# Set of loop node IDs, used for assigning values in loop nodes
cycle_nodes: list

View File

@@ -21,6 +21,7 @@ class IterationRuntime:
optional parallel execution, flattening of output, and loop control via
the workflow state.
"""
def __init__(
self,
graph: CompiledStateGraph,
@@ -87,6 +88,7 @@ class IterationRuntime:
self.result.append(output)
if not result["looping"]:
self.looping = False
return result
def _create_iteration_tasks(self, array_obj, idx):
"""
@@ -124,7 +126,7 @@ class IterationRuntime:
array_obj = VariablePool(self.state).get(input_expression)
if not isinstance(array_obj, list):
raise RuntimeError("Cannot iterate over a non-list variable")
child_state = []
idx = 0
if self.typed_config.parallel:
# Execute iterations in parallel batches
@@ -132,15 +134,14 @@ class IterationRuntime:
tasks = self._create_iteration_tasks(array_obj, idx)
logger.info(f"Iteration node {self.node_id}: running, concurrency {len(tasks)}")
idx += self.typed_config.parallel_count
await asyncio.gather(*tasks)
logger.info(f"Iteration node {self.node_id}: execution completed")
return self.result
child_state.extend(await asyncio.gather(*tasks))
else:
# Execute iterations sequentially
while idx < len(array_obj) and self.looping:
logger.info(f"Iteration node {self.node_id}: running")
item = array_obj[idx]
result = await self.graph.ainvoke(self._init_iteration_state(item, idx))
child_state.append(result)
output = VariablePool(result).get(self.output_value)
if isinstance(output, list) and self.typed_config.flatten:
self.result.extend(output)
@@ -150,5 +151,8 @@ class IterationRuntime:
self.looping = False
idx += 1
logger.info(f"Iteration node {self.node_id}: execution completed")
return self.result
logger.info(f"Iteration node {self.node_id}: execution completed")
return {
"output": self.result,
"__child_state": child_state
}

View File

@@ -67,7 +67,9 @@ class LoopRuntime:
variables=pool.get_all_conversation_vars(),
node_outputs=pool.get_all_node_outputs(),
system_vars=pool.get_all_system_vars(),
) if variable.input_type == ValueInputType.VARIABLE else TypeTransformer.transform(variable.value, variable.type)
)
if variable.input_type == ValueInputType.VARIABLE
else TypeTransformer.transform(variable.value, variable.type)
for variable in self.typed_config.cycle_vars
}
self.state["node_outputs"][self.node_id] = {
@@ -76,7 +78,9 @@ class LoopRuntime:
variables=pool.get_all_conversation_vars(),
node_outputs=pool.get_all_node_outputs(),
system_vars=pool.get_all_system_vars(),
) if variable.input_type == ValueInputType.VARIABLE else TypeTransformer.transform(variable.value, variable.type)
)
if variable.input_type == ValueInputType.VARIABLE
else TypeTransformer.transform(variable.value, variable.type)
for variable in self.typed_config.cycle_vars
}
loopstate = WorkflowState(
@@ -171,10 +175,11 @@ class LoopRuntime:
"""
loopstate = self._init_loop_state()
loop_time = self.typed_config.max_loop
child_state = []
while self.evaluate_conditional(loopstate) and loopstate["looping"] and loop_time > 0:
logger.info(f"loop node {self.node_id}: running")
await self.graph.ainvoke(loopstate)
child_state.append(await self.graph.ainvoke(loopstate))
loop_time -= 1
logger.info(f"loop node {self.node_id}: execution completed")
return loopstate["runtime_vars"][self.node_id]
return loopstate["runtime_vars"][self.node_id] | {"__child_state": child_state}

View File

@@ -10,9 +10,8 @@ from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
from app.core.workflow.nodes.knowledge import KnowledgeRetrievalNodeConfig
from app.db import get_db_read
from app.models import knowledge_model, knowledgeshare_model, ModelType
from app.repositories import knowledge_repository
from app.repositories import knowledge_repository, knowledgeshare_repository
from app.schemas.chunk_schema import RetrieveType
from app.services import knowledge_service, knowledgeshare_service
from app.services.model_service import ModelConfigService
logger = logging.getLogger(__name__)
@@ -96,7 +95,7 @@ class KnowledgeRetrievalNode(BaseNode):
filters = self._build_kb_filter(kb_ids, knowledge_model.PermissionType.Share)
share_ids = knowledge_service.knowledge_repository.get_chunked_knowledgeids(
share_ids = knowledge_repository.get_chunked_knowledgeids(
db=db,
filters=filters
)
@@ -105,7 +104,7 @@ class KnowledgeRetrievalNode(BaseNode):
filters = [
knowledgeshare_model.KnowledgeShare.target_kb_id.in_(kb_ids)
]
items = knowledgeshare_service.knowledgeshare_repository.get_source_kb_ids_by_target_kb_id(
items = knowledgeshare_repository.get_source_kb_ids_by_target_kb_id(
db=db,
filters=filters
)

View File

@@ -66,7 +66,7 @@ class LLMNodeConfig(BaseNodeConfig):
)
memory: MemoryWindowSetting = Field(
...,
default_factory=MemoryWindowSetting,
description="对话上下文窗口"
)

View File

@@ -85,6 +85,7 @@ class LLMNode(BaseNode):
"""
# 1. 处理消息格式(优先使用 messages
self.typed_config = LLMNodeConfig(**self.config)
messages_config = self.typed_config.messages
if messages_config:
@@ -167,7 +168,7 @@ class LLMNode(BaseNode):
Returns:
LLM 响应消息
"""
self.typed_config = LLMNodeConfig(**self.config)
# self.typed_config = LLMNodeConfig(**self.config)
llm, prompt_or_messages = self._prepare_llm(state, True)
logger.info(f"节点 {self.node_id} 开始执行 LLM 调用(非流式)")
@@ -269,12 +270,16 @@ class LLMNode(BaseNode):
chunk_count = 0
# 调用 LLM流式支持字符串或消息列表
async for chunk in llm.astream(prompt_or_messages):
last_meta_data = {}
async for chunk in llm.astream(prompt_or_messages, stream_usage=True):
# 提取内容
if hasattr(chunk, 'content'):
content = chunk.content
else:
content = str(chunk)
if hasattr(chunk, 'response_metadata'):
if chunk.response_metadata:
last_meta_data = chunk.response_metadata
# 只有当内容不为空时才处理
if content:
@@ -288,13 +293,10 @@ class LLMNode(BaseNode):
logger.info(f"节点 {self.node_id} LLM 调用完成,输出长度: {len(full_response)}, 总 chunks: {chunk_count}")
# 构建完整的 AIMessage包含元数据
if isinstance(last_chunk, AIMessage):
final_message = AIMessage(
content=full_response,
response_metadata=last_chunk.response_metadata if hasattr(last_chunk, 'response_metadata') else {}
)
else:
final_message = AIMessage(content=full_response)
final_message = AIMessage(
content=full_response,
response_metadata=last_meta_data
)
# yield 完成标记
yield {"__final__": True, "result": final_message}