diff --git a/api/app/controllers/public_share_controller.py b/api/app/controllers/public_share_controller.py index 464e602b..17ad70a7 100644 --- a/api/app/controllers/public_share_controller.py +++ b/api/app/controllers/public_share_controller.py @@ -567,6 +567,7 @@ async def chat( with get_db_read() as db: source_config = WorkflowConfigRepository(db).get_by_app_id(release.app_id) config.id = source_config.id + config.id = uuid.UUID(config.id) if payload.stream: async def event_generator(): async for event in app_chat_service.workflow_chat_stream( diff --git a/api/app/core/workflow/executor.py b/api/app/core/workflow/executor.py index c048f447..ad03fec1 100644 --- a/api/app/core/workflow/executor.py +++ b/api/app/core/workflow/executor.py @@ -8,6 +8,7 @@ import logging import uuid from typing import Any +from langchain_core.runnables import RunnableConfig from langgraph.graph.state import CompiledStateGraph from app.core.workflow.graph_builder import GraphBuilder @@ -53,11 +54,11 @@ class WorkflowExecutor: self.edges = workflow_config.get("edges", []) self.execution_config = workflow_config.get("execution_config", {}) - self.checkpoint_config = { - "configurable": { + self.checkpoint_config = RunnableConfig( + configurable={ "thread_id": uuid.uuid4(), } - } + ) def _prepare_initial_state(self, input_data: dict[str, Any]) -> WorkflowState: """准备初始状态(注入系统变量和会话变量) @@ -214,13 +215,13 @@ class WorkflowExecutor: return { "status": "completed", "output": final_output, + "variables": result.get("variables", {}), "node_outputs": node_outputs, "messages": result.get("messages", []), "conversation_id": conversation_id, "elapsed_time": elapsed_time, "token_usage": token_usage, "error": result.get("error"), - "variables": result.get("variables", {}), } def build_graph(self, stream=False) -> CompiledStateGraph: @@ -326,11 +327,10 @@ class WorkflowExecutor: } # 1. 构建图 - graph = self.build_graph(True) + graph = self.build_graph(stream=True) # 2. 初始化状态(自动注入系统变量) initial_state = self._prepare_initial_state(input_data) - # 3. Execute workflow try: chunk_count = 0 @@ -346,14 +346,16 @@ class WorkflowExecutor: mode, data = event else: # Unexpected format, log and skip - logger.warning(f"[STREAM] Unexpected event format: {type(event)}, value: {event}") + logger.warning(f"[STREAM] Unexpected event format: {type(event)}, value: {event}" + f"- execution_id: {self.execution_id}") continue if mode == "custom": # Handle custom streaming events (chunks from nodes via stream writer) chunk_count += 1 event_type = data.get("type", "node_chunk") # "message" or "node_chunk" - logger.info(f"[CUSTOM] ✅ 收到 {event_type} #{chunk_count} from {data.get('node_id')}") + logger.info(f"[CUSTOM] ✅ 收到 {event_type} #{chunk_count} from {data.get('node_id')}" + f"- execution_id: {self.execution_id}") yield { "event": event_type, # "message" or "node_chunk" "data": { @@ -380,7 +382,8 @@ class WorkflowExecutor: variables_sys = variables.get("sys", {}) conversation_id = input_data.get("conversation_id") execution_id = variables_sys.get("execution_id") - logger.info(f"[DEBUG] Node starts execution: {node_name}") + logger.info(f"[NODE-START] Node starts execution: {node_name} " + f"- execution_id: {self.execution_id}") yield { "event": "node_start", @@ -399,7 +402,8 @@ class WorkflowExecutor: variables_sys = variables.get("sys", {}) conversation_id = input_data.get("conversation_id") execution_id = variables_sys.get("execution_id") - logger.info(f"[DEBUG] Node execution completed: {node_name}") + logger.info(f"[NODE-END] Node execution completed: {node_name} " + f"- execution_id: {self.execution_id}") yield { "event": "node_end", @@ -407,13 +411,15 @@ class WorkflowExecutor: "node_id": node_name, "conversation_id": conversation_id, "execution_id": execution_id, - "timestamp": data.get("timestamp") + "timestamp": data.get("timestamp"), + "state": result.get("node_outputs", {}).get(node_name), } } elif mode == "updates": # Handle state updates - store final state - logger.debug(f"[UPDATES] 收到 state 更新 from {list(data.keys())}") + logger.debug(f"[UPDATES] 收到 state 更新 from {list(data.keys())} " + f"- execution_id: {self.execution_id}") # 计算耗时 end_time = datetime.datetime.now() @@ -421,7 +427,7 @@ class WorkflowExecutor: result = graph.get_state(self.checkpoint_config).values logger.info( f"Workflow execution completed (streaming), " - f"total chunks: {chunk_count}, elapsed: {elapsed_time:.2f}s" + f"total chunks: {chunk_count}, elapsed: {elapsed_time:.2f}s, execution_id: {self.execution_id}" ) # 发送 workflow_end 事件 @@ -449,7 +455,8 @@ class WorkflowExecutor: } } - def _extract_final_output(self, node_outputs: dict[str, Any]) -> str | None: + @staticmethod + def _extract_final_output(node_outputs: dict[str, Any]) -> str | None: """从节点输出中提取最终输出 优先级: @@ -473,7 +480,8 @@ class WorkflowExecutor: return None - def _aggregate_token_usage(self, node_outputs: dict[str, Any]) -> dict[str, int] | None: + @staticmethod + def _aggregate_token_usage(node_outputs: dict[str, Any]) -> dict[str, int] | None: """聚合所有节点的 token 使用情况 Args: diff --git a/api/app/core/workflow/nodes/cycle_graph/iteration.py b/api/app/core/workflow/nodes/cycle_graph/iteration.py index 4ae8e118..66c3a700 100644 --- a/api/app/core/workflow/nodes/cycle_graph/iteration.py +++ b/api/app/core/workflow/nodes/cycle_graph/iteration.py @@ -21,6 +21,7 @@ class IterationRuntime: optional parallel execution, flattening of output, and loop control via the workflow state. """ + def __init__( self, graph: CompiledStateGraph, @@ -87,6 +88,7 @@ class IterationRuntime: self.result.append(output) if not result["looping"]: self.looping = False + return result def _create_iteration_tasks(self, array_obj, idx): """ @@ -124,7 +126,7 @@ class IterationRuntime: array_obj = VariablePool(self.state).get(input_expression) if not isinstance(array_obj, list): raise RuntimeError("Cannot iterate over a non-list variable") - + child_state = [] idx = 0 if self.typed_config.parallel: # Execute iterations in parallel batches @@ -132,15 +134,14 @@ class IterationRuntime: tasks = self._create_iteration_tasks(array_obj, idx) logger.info(f"Iteration node {self.node_id}: running, concurrency {len(tasks)}") idx += self.typed_config.parallel_count - await asyncio.gather(*tasks) - logger.info(f"Iteration node {self.node_id}: execution completed") - return self.result + child_state.extend(await asyncio.gather(*tasks)) else: # Execute iterations sequentially while idx < len(array_obj) and self.looping: logger.info(f"Iteration node {self.node_id}: running") item = array_obj[idx] result = await self.graph.ainvoke(self._init_iteration_state(item, idx)) + child_state.append(result) output = VariablePool(result).get(self.output_value) if isinstance(output, list) and self.typed_config.flatten: self.result.extend(output) @@ -150,5 +151,8 @@ class IterationRuntime: self.looping = False idx += 1 - logger.info(f"Iteration node {self.node_id}: execution completed") - return self.result + logger.info(f"Iteration node {self.node_id}: execution completed") + return { + "output": self.result, + "__child_state": child_state + } diff --git a/api/app/core/workflow/nodes/cycle_graph/loop.py b/api/app/core/workflow/nodes/cycle_graph/loop.py index 2e2ab4fb..38d4b21c 100644 --- a/api/app/core/workflow/nodes/cycle_graph/loop.py +++ b/api/app/core/workflow/nodes/cycle_graph/loop.py @@ -67,7 +67,9 @@ class LoopRuntime: variables=pool.get_all_conversation_vars(), node_outputs=pool.get_all_node_outputs(), system_vars=pool.get_all_system_vars(), - ) if variable.input_type == ValueInputType.VARIABLE else TypeTransformer.transform(variable.value, variable.type) + ) + if variable.input_type == ValueInputType.VARIABLE + else TypeTransformer.transform(variable.value, variable.type) for variable in self.typed_config.cycle_vars } self.state["node_outputs"][self.node_id] = { @@ -76,7 +78,9 @@ class LoopRuntime: variables=pool.get_all_conversation_vars(), node_outputs=pool.get_all_node_outputs(), system_vars=pool.get_all_system_vars(), - ) if variable.input_type == ValueInputType.VARIABLE else TypeTransformer.transform(variable.value, variable.type) + ) + if variable.input_type == ValueInputType.VARIABLE + else TypeTransformer.transform(variable.value, variable.type) for variable in self.typed_config.cycle_vars } loopstate = WorkflowState( @@ -171,10 +175,11 @@ class LoopRuntime: """ loopstate = self._init_loop_state() loop_time = self.typed_config.max_loop + child_state = [] while self.evaluate_conditional(loopstate) and loopstate["looping"] and loop_time > 0: logger.info(f"loop node {self.node_id}: running") - await self.graph.ainvoke(loopstate) + child_state.append(await self.graph.ainvoke(loopstate)) loop_time -= 1 logger.info(f"loop node {self.node_id}: execution completed") - return loopstate["runtime_vars"][self.node_id] + return loopstate["runtime_vars"][self.node_id] | {"__child_state": child_state} diff --git a/api/app/core/workflow/nodes/llm/node.py b/api/app/core/workflow/nodes/llm/node.py index e25bd35d..061a0f6a 100644 --- a/api/app/core/workflow/nodes/llm/node.py +++ b/api/app/core/workflow/nodes/llm/node.py @@ -269,12 +269,16 @@ class LLMNode(BaseNode): chunk_count = 0 # 调用 LLM(流式,支持字符串或消息列表) - async for chunk in llm.astream(prompt_or_messages): + last_meta_data = {} + async for chunk in llm.astream(prompt_or_messages, stream_usage=True): # 提取内容 if hasattr(chunk, 'content'): content = chunk.content else: content = str(chunk) + if hasattr(chunk, 'response_metadata'): + if chunk.response_metadata: + last_meta_data = chunk.response_metadata # 只有当内容不为空时才处理 if content: @@ -288,13 +292,10 @@ class LLMNode(BaseNode): logger.info(f"节点 {self.node_id} LLM 调用完成,输出长度: {len(full_response)}, 总 chunks: {chunk_count}") # 构建完整的 AIMessage(包含元数据) - if isinstance(last_chunk, AIMessage): - final_message = AIMessage( - content=full_response, - response_metadata=last_chunk.response_metadata if hasattr(last_chunk, 'response_metadata') else {} - ) - else: - final_message = AIMessage(content=full_response) + final_message = AIMessage( + content=full_response, + response_metadata=last_meta_data + ) # yield 完成标记 yield {"__final__": True, "result": final_message} diff --git a/api/app/services/workflow_service.py b/api/app/services/workflow_service.py index 6cff6844..b7d5df02 100644 --- a/api/app/services/workflow_service.py +++ b/api/app/services/workflow_service.py @@ -14,15 +14,14 @@ from app.core.error_codes import BizCode from app.core.exceptions import BusinessException from app.core.workflow.validator import validate_workflow_config from app.db import get_db -from app.models.conversation_model import Message from app.models.workflow_model import WorkflowConfig, WorkflowExecution -from app.repositories.conversation_repository import MessageRepository from app.repositories.workflow_repository import ( WorkflowConfigRepository, WorkflowExecutionRepository, WorkflowNodeExecutionRepository ) from app.schemas import DraftRunRequest +from app.services.conversation_service import ConversationService from app.services.multi_agent_service import convert_uuids_to_str logger = logging.getLogger(__name__) @@ -36,7 +35,7 @@ class WorkflowService: self.config_repo = WorkflowConfigRepository(db) self.execution_repo = WorkflowExecutionRepository(db) self.node_execution_repo = WorkflowNodeExecutionRepository(db) - self.message_repo = MessageRepository(db) + self.conversation_service = ConversationService(db) # ==================== 配置管理 ==================== @@ -340,6 +339,7 @@ class WorkflowService: self, execution_id: str, status: str, + token_usage: int | None = None, output_data: dict[str, Any] | None = None, error_message: str | None = None, error_node_id: str | None = None @@ -349,6 +349,7 @@ class WorkflowService: Args: execution_id: 执行 ID status: 状态 + token_usage: token消耗 output_data: 输出数据 error_message: 错误信息 error_node_id: 出错节点 ID @@ -367,6 +368,8 @@ class WorkflowService: ) execution.status = status + if token_usage is not None: + execution.token_usage = token_usage if output_data is not None: execution.output_data = convert_uuids_to_str(output_data) if error_message is not None: @@ -513,20 +516,20 @@ class WorkflowService: # 更新执行结果 if result.get("status") == "completed": + token_usage = result.get("token_usage", {}) or {} self.update_execution_status( execution.execution_id, "completed", - output_data=result + output_data=result, + token_usage=token_usage.get("total_tokens", None) ) final_messages = result.get("messages", [])[init_message_length:] for message in final_messages: - message_obj = Message( + self.conversation_service.add_message( conversation_id=conversation_id_uuid, role=message["role"], - content=message["content"], + content=message["content"] ) - self.message_repo.add_message(message_obj) - self.db.commit() logger.info(f"Workflow Run Success, " f"execution_id: {execution.execution_id}, message count: {len(final_messages)}") else: @@ -662,21 +665,21 @@ class WorkflowService: if event.get("event") == "workflow_end": status = event.get("data", {}).get("status") + token_usage = event.get("data", {}).get("token_usage", {}) or {} if status == "completed": self.update_execution_status( execution.execution_id, "completed", - output_data=event.get("data") + output_data=event.get("data"), + token_usage=token_usage.get("total_tokens", None) ) final_messages = event.get("data", {}).get("messages", [])[init_message_length:] for message in final_messages: - message_obj = Message( + self.conversation_service.add_message( conversation_id=conversation_id_uuid, role=message["role"], - content=message["content"], + content=message["content"] ) - self.message_repo.add_message(message_obj) - self.db.commit() logger.info(f"Workflow Run Success, " f"execution_id: {execution.execution_id}, message count: {len(final_messages)}") elif status == "failed": @@ -793,10 +796,12 @@ class WorkflowService: # 更新执行结果 if result.get("status") == "completed": + token_usage = result.get("data").get("token_usage", {}) or {} self.update_execution_status( execution.execution_id, "completed", - output_data=result.get("node_outputs", {}) + output_data=result.get("node_outputs", {}), + token_usage=token_usage.get("total_tokens", None) ) else: self.update_execution_status( @@ -891,13 +896,14 @@ class WorkflowService: ): # 直接转发事件(executor 已经返回正确格式) if event.get("event") == "workflow_end": - + token_usage = event.get("data").get("token_usage", {}) or {} status = event.get("data", {}).get("status") if status == "completed": self.update_execution_status( execution_id, "completed", - output_data=event.get("data") + output_data=event.get("data"), + token_usage=token_usage.get("total_tokens", None) ) elif status == "failed": self.update_execution_status(