diff --git a/api/app/core/agent/langchain_agent.py b/api/app/core/agent/langchain_agent.py index dc0d6922..3c33ad6e 100644 --- a/api/app/core/agent/langchain_agent.py +++ b/api/app/core/agent/langchain_agent.py @@ -9,18 +9,15 @@ LangChain Agent 封装 """ import os import time -import asyncio from typing import Dict, Any, List, Optional, AsyncGenerator, Sequence from langchain_core.messages import SystemMessage, HumanMessage, AIMessage, BaseMessage from langchain_core.tools import BaseTool from langchain.agents import create_agent -from app.core.memory.agent.mcp_server.services import session_service from app.core.memory.agent.utils.redis_tool import store from app.core.models import RedBearLLM, RedBearModelConfig from app.models.models_model import ModelType from app.core.logging_config import get_business_logger -from app.services.memory_agent_service import MemoryAgentService from app.services.memory_konwledges_server import write_rag from app.services.task_service import get_task_memory_write_result from app.tasks import write_message_task diff --git a/api/app/core/workflow/executor.py b/api/app/core/workflow/executor.py index 80d5316a..75a9cb0b 100644 --- a/api/app/core/workflow/executor.py +++ b/api/app/core/workflow/executor.py @@ -92,7 +92,7 @@ class WorkflowExecutor: - def build_graph(self) -> CompiledStateGraph: + def build_graph(self,stream=False) -> CompiledStateGraph: """构建 LangGraph Returns: @@ -122,12 +122,19 @@ class WorkflowExecutor: if node_instance: # 包装节点的 run 方法 # 使用函数工厂避免闭包问题 - def make_node_func(inst): - async def node_func(state: WorkflowState): + if stream: + # 流式模式:创建 async generator 函数 + # LangGraph 会收集所有 yield 的值,最后一个 yield 的字典会被合并到 state + async def node_func(state: WorkflowState, inst=node_instance): + async for item in inst.run_stream(state): + yield item + workflow.add_node(node_id, node_func) + else: + # 非流式模式:创建 async function + async def node_func(state: WorkflowState, inst=node_instance): return await inst.run(state) - return node_func + workflow.add_node(node_id, node_func) - workflow.add_node(node_id, make_node_func(node_instance)) logger.debug(f"添加节点: {node_id} (type={node_type})") # 3. 添加边 @@ -276,12 +283,13 @@ class WorkflowExecutor: ): """执行工作流(流式) - 手动执行节点以支持细粒度的流式输出: - - workflow_start: 工作流开始 - - node_start: 节点开始执行 - - node_chunk: LLM 节点的流式输出片段(逐 token) - - node_complete: 节点执行完成 - - workflow_complete: 工作流完成 + 使用 stream_mode="updates" 来获取每个节点的 state 更新。 + 节点的 generator 会 yield 多个值: + - 中间的 chunk 事件(带 type="chunk") + - 最后的 state 更新(纯字典,包含 node_outputs 等) + + LangGraph 会将所有 yield 的值收集起来,并将它们合并到 state 中。 + 我们需要过滤出 chunk 事件并转发,同时确保 state 更新被正确处理。 Args: input_data: 输入数据 @@ -289,27 +297,47 @@ class WorkflowExecutor: Yields: 流式事件 """ - # logger.info(f"开始执行工作流: execution_id={self.execution_id}") # 记录开始时间 start_time = datetime.datetime.now() # 1. 构建图 - graph = self.build_graph() + graph = self.build_graph(True) # 2. 初始化状态(自动注入系统变量) initial_state = self._prepare_initial_state(input_data) # 3. 执行工作流 try: - async for chunk in graph.astream( + async for mode, event in graph.astream( initial_state, - # subgraphs=True, - stream_mode="updates", + stream_mode=["updates","messages"], ): - # print(chunk) - yield chunk + # print("刚才跑的节点:", event[0]) + # # 通过图结构就能算出“接下来是谁” + # print("接下来可能跑:", graph.get_next(event[0])) + # print("="*50) + # # print("mode",mode) + # print("event",event) + # print("="*50) + # event 是一个字典,key 是节点 ID,value 是 state 更新或 chunk + for node_id, update in event.items(): + print("="*50) + print("node_id",node_id) + print("update",update) + + print("="*50) + if isinstance(update, dict) and update.get("type") == "chunk": + # 这是流式 chunk,转发给客户端 + yield { + "type": "node_chunk", + "node_id": update.get("node_id"), + "chunk": update.get("content") + } + # 其他情况(state 更新)会被 LangGraph 自动合并到 state,不需要我们处理 + print(event) + yield event except Exception as e: # 计算耗时(即使失败也记录) diff --git a/api/app/core/workflow/nodes/base_node.py b/api/app/core/workflow/nodes/base_node.py index d17cc1fd..5674655a 100644 --- a/api/app/core/workflow/nodes/base_node.py +++ b/api/app/core/workflow/nodes/base_node.py @@ -209,11 +209,15 @@ class BaseNode(ABC): 3. 将业务数据包装成标准输出格式 4. 错误处理 + 注意:在流式模式下,我们需要: + - yield 中间的 chunk 事件(用于实时显示) + - 最后 yield 一个包含 state 更新的字典(LangGraph 会合并到 state) + Args: state: 工作流状态 Yields: - 标准化的流式事件 + 标准化的流式事件和最终的 state 更新 """ import time @@ -263,27 +267,39 @@ class BaseNode(ABC): elapsed_time = time.time() - start_time + # 提取处理后的输出(调用子类的 _extract_output) + extracted_output = self._extract_output(final_result) + # 包装最终结果 final_output = self._wrap_output(final_result, elapsed_time, state) - yield { - "type": "complete", - **final_output + + # 将提取后的输出存储到运行时变量中(供后续节点快速访问) + if isinstance(extracted_output, dict): + runtime_var = extracted_output + else: + runtime_var = {"output": extracted_output} + + # 构建完整的 state 更新(包含 node_outputs 和 runtime_vars) + state_update = { + **final_output, + "runtime_vars": { + self.node_id: runtime_var + } } + + # 最后 yield 纯粹的 state 更新(LangGraph 会合并到 state 中) + yield state_update except TimeoutError: elapsed_time = time.time() - start_time logger.error(f"节点 {self.node_id} 执行超时({timeout}秒)") - yield { - "type": "error", - **self._wrap_error(f"节点执行超时({timeout}秒)", elapsed_time, state) - } + error_output = self._wrap_error(f"节点执行超时({timeout}秒)", elapsed_time, state) + yield error_output except Exception as e: elapsed_time = time.time() - start_time logger.error(f"节点 {self.node_id} 执行失败: {e}", exc_info=True) - yield { - "type": "error", - **self._wrap_error(str(e), elapsed_time, state) - } + error_output = self._wrap_error(str(e), elapsed_time, state) + yield error_output def _wrap_output( self, diff --git a/api/app/core/workflow/nodes/end/node.py b/api/app/core/workflow/nodes/end/node.py index ad028f31..6ee56dde 100644 --- a/api/app/core/workflow/nodes/end/node.py +++ b/api/app/core/workflow/nodes/end/node.py @@ -30,11 +30,11 @@ class EndNode(BaseNode): # 获取配置的输出模板 output_template = self.config.get("output") - pool = self.get_variable_pool(state) + # pool = self.get_variable_pool(state) - print("="*20) - print( pool.get("start.test")) - print("="*20) + # print("="*20) + # print( pool.get("start.test")) + # print("="*20) # 如果配置了输出模板,使用模板渲染;否则使用默认输出 if output_template: output = self._render_template(output_template, state) diff --git a/api/app/core/workflow/nodes/llm/node.py b/api/app/core/workflow/nodes/llm/node.py index cf665ff1..295ae583 100644 --- a/api/app/core/workflow/nodes/llm/node.py +++ b/api/app/core/workflow/nodes/llm/node.py @@ -63,7 +63,7 @@ class LLMNode(BaseNode): - ai/assistant: AI 消息(AIMessage) """ - def _prepare_llm(self, state: WorkflowState) -> tuple[RedBearLLM, list | str]: + def _prepare_llm(self, state: WorkflowState,stream:bool = False) -> tuple[RedBearLLM, list | str]: """准备 LLM 实例(公共逻辑) Args: @@ -125,16 +125,19 @@ class LLMNode(BaseNode): model_type = config.type # 4. 创建 LLM 实例(使用已提取的数据) + print("="*50) + print("stream",stream) + print("="*50) llm = RedBearLLM( RedBearModelConfig( model_name=model_name, provider=provider, api_key=api_key, - base_url=api_base + base_url=api_base, + extra_params={"streaming": stream} ), type=model_type ) - return llm, prompt_or_messages async def execute(self, state: WorkflowState) -> AIMessage: @@ -146,13 +149,12 @@ class LLMNode(BaseNode): Returns: LLM 响应消息 """ - llm, prompt_or_messages = self._prepare_llm(state) + llm, prompt_or_messages = self._prepare_llm(state,True) logger.info(f"节点 {self.node_id} 开始执行 LLM 调用(非流式)") # 调用 LLM(支持字符串或消息列表) response = await llm.ainvoke(prompt_or_messages) - # 提取内容 if hasattr(response, 'content'): content = response.content @@ -199,47 +201,47 @@ class LLMNode(BaseNode): } return None - async def execute_stream(self, state: WorkflowState): - """流式执行 LLM 调用 + # async def execute_stream(self, state: WorkflowState): + # """流式执行 LLM 调用 - Args: - state: 工作流状态 + # Args: + # state: 工作流状态 - Yields: - 文本片段(chunk)或完成标记 - """ - llm, prompt_or_messages = self._prepare_llm(state) + # Yields: + # 文本片段(chunk)或完成标记 + # """ + # llm, prompt_or_messages = self._prepare_llm(state,True) - logger.info(f"节点 {self.node_id} 开始执行 LLM 调用(流式)") + # logger.info(f"节点 {self.node_id} 开始执行 LLM 调用(流式)") - # 累积完整响应 - full_response = "" - last_chunk = None + # # 累积完整响应 + # full_response = "" + # last_chunk = None - # 调用 LLM(流式,支持字符串或消息列表) - async for chunk in llm.astream(prompt_or_messages): - # 提取内容 - if hasattr(chunk, 'content'): - content = chunk.content - else: - content = str(chunk) + # # 调用 LLM(流式,支持字符串或消息列表) + # async for chunk in llm.astream(prompt_or_messages): + # # 提取内容 + # if hasattr(chunk, 'content'): + # content = chunk.content + # else: + # content = str(chunk) - full_response += content - last_chunk = chunk - - # 流式返回每个文本片段 - yield content + # full_response += content + # last_chunk = chunk + # logger.info(f"节点 {self.node_id} LLM : {content}") + # # 流式返回每个文本片段 + # yield content - logger.info(f"节点 {self.node_id} LLM 调用完成,输出长度: {len(full_response)}") + # logger.info(f"节点 {self.node_id} LLM 调用完成,输出长度: {len(full_response)}") - # 构建完整的 AIMessage(包含元数据) - if isinstance(last_chunk, AIMessage): - final_message = AIMessage( - content=full_response, - response_metadata=last_chunk.response_metadata if hasattr(last_chunk, 'response_metadata') else {} - ) - else: - final_message = AIMessage(content=full_response) + # # 构建完整的 AIMessage(包含元数据) + # if isinstance(last_chunk, AIMessage): + # final_message = AIMessage( + # content=full_response, + # response_metadata=last_chunk.response_metadata if hasattr(last_chunk, 'response_metadata') else {} + # ) + # else: + # final_message = AIMessage(content=full_response) - # yield 完成标记 - yield {"__final__": True, "result": final_message} + # # yield 完成标记 + # yield {"__final__": True, "result": final_message} diff --git a/api/app/services/llm_router.py b/api/app/services/llm_router.py index 089f2c07..9ef9dbb1 100644 --- a/api/app/services/llm_router.py +++ b/api/app/services/llm_router.py @@ -385,7 +385,7 @@ class LLMRouter: # 获取 API Key 配置 api_key_config = self.db.query(ModelApiKey).filter( ModelApiKey.model_config_id == self.routing_model_config.id, - ModelApiKey.is_active == True + ModelApiKey.is_active ).first() if not api_key_config: diff --git a/api/app/services/workflow_service.py b/api/app/services/workflow_service.py index f0b71824..b48edfdd 100644 --- a/api/app/services/workflow_service.py +++ b/api/app/services/workflow_service.py @@ -1,29 +1,27 @@ """ 工作流服务层 """ +import datetime import json import logging import uuid -import datetime from typing import Any, Annotated -from sqlalchemy.orm import Session from fastapi import Depends +from sqlalchemy.orm import Session +from app.core.error_codes import BizCode +from app.core.exceptions import BusinessException +from app.core.workflow.validator import validate_workflow_config +from app.db import get_db from app.models.workflow_model import WorkflowConfig, WorkflowExecution from app.repositories.workflow_repository import ( WorkflowConfigRepository, WorkflowExecutionRepository, - WorkflowNodeExecutionRepository, - get_workflow_config_repository, - get_workflow_execution_repository, - get_workflow_node_execution_repository + WorkflowNodeExecutionRepository ) -from app.core.workflow.validator import validate_workflow_config -from app.core.exceptions import BusinessException -from app.core.error_codes import BizCode -from app.db import get_db from app.schemas import DraftRunRequest +from app.utils.sse_utils import format_sse_message logger = logging.getLogger(__name__) @@ -81,7 +79,7 @@ class WorkflowService: if not is_valid: logger.warning(f"工作流配置验证失败: {errors}") raise BusinessException( - error_code=BizCode.INVALID_PARAMETER, + code=BizCode.INVALID_PARAMETER, message=f"工作流配置无效: {'; '.join(errors)}" ) @@ -140,7 +138,7 @@ class WorkflowService: config = self.get_workflow_config(app_id) if not config: raise BusinessException( - error_code=BizCode.RESOURCE_NOT_FOUND, + code=BizCode.NOT_FOUND, message=f"工作流配置不存在: app_id={app_id}" ) @@ -166,7 +164,7 @@ class WorkflowService: if not is_valid: logger.warning(f"工作流配置验证失败: {errors}") raise BusinessException( - error_code=BizCode.INVALID_PARAMETER, + code=BizCode.INVALID_PARAMETER, message=f"工作流配置无效: {'; '.join(errors)}" ) @@ -195,8 +193,7 @@ class WorkflowService: config = self.get_workflow_config(app_id) if not config: return False - - self.config_repo.delete(config.id) + config.is_active = False logger.info(f"删除工作流配置成功: app_id={app_id}, config_id={config.id}") return True @@ -245,7 +242,7 @@ class WorkflowService: config = self.get_workflow_config(app_id) if not config: raise BusinessException( - error_code=BizCode.RESOURCE_NOT_FOUND, + code=BizCode.NOT_FOUND, message=f"工作流配置不存在: app_id={app_id}" ) @@ -359,7 +356,7 @@ class WorkflowService: execution = self.get_execution(execution_id) if not execution: raise BusinessException( - error_code=BizCode.RESOURCE_NOT_FOUND, + code=BizCode.NOT_FOUND, message=f"执行记录不存在: execution_id={execution_id}" ) @@ -474,11 +471,9 @@ class WorkflowService: } # 4. 获取工作空间 ID(从 app 获取) - from app.models import App - # 5. 执行工作流 - from app.core.workflow.executor import execute_workflow, execute_workflow_stream + from app.core.workflow.executor import execute_workflow try: # 更新状态为运行中 @@ -595,17 +590,18 @@ class WorkflowService: } # 4. 获取工作空间 ID(从 app 获取) - from app.models import App # 5. 流式执行工作流 - from app.core.workflow.executor import execute_workflow, execute_workflow_stream try: # 更新状态为运行中 self.update_execution_status(execution.execution_id, "running") # 发送开始事件 - yield f"data: {json.dumps({'type': 'workflow_start', 'execution_id': execution.execution_id})}\n\n" + yield format_sse_message("workflow_start", { + "execution_id": execution.execution_id, + "conversation_id_uuid": str(conversation_id_uuid), + }) # 调用流式执行 async for event in self._run_workflow_stream( @@ -621,7 +617,10 @@ class WorkflowService: yield f"data: {json.dumps(cleaned_event)}\n\n" # 发送完成事件 - yield f"data: {json.dumps({'type': 'workflow_end', 'execution_id': execution.execution_id})}\n\n" + yield format_sse_message("workflow_end", { + "execution_id": execution.execution_id, + "conversation_id_uuid": str(conversation_id_uuid), + }) except Exception as e: logger.error(f"工作流流式执行失败: execution_id={execution.execution_id}, error={e}", exc_info=True) @@ -660,7 +659,7 @@ class WorkflowService: config = self.get_workflow_config(app_id) if not config: raise BusinessException( - error_code=BizCode.RESOURCE_NOT_FOUND, + code=BizCode.NOT_FOUND, message=f"工作流配置不存在: app_id={app_id}" ) @@ -687,12 +686,12 @@ class WorkflowService: app = self.db.query(App).filter(App.id == app_id).first() if not app: raise BusinessException( - error_code=BizCode.RESOURCE_NOT_FOUND, + code=BizCode.NOT_FOUND, message=f"应用不存在: app_id={app_id}" ) # 5. 执行工作流 - from app.core.workflow.executor import execute_workflow, execute_workflow_stream + from app.core.workflow.executor import execute_workflow try: # 更新状态为运行中 @@ -750,7 +749,7 @@ class WorkflowService: error_message=str(e) ) raise BusinessException( - error_code=BizCode.INTERNAL_ERROR, + code=BizCode.INTERNAL_ERROR, message=f"工作流执行失败: {str(e)}" ) @@ -820,26 +819,26 @@ class WorkflowService: yield event # 收集输出数据 - if event.get("type") == "node_complete": - node_data = event.get("data", {}) - node_outputs = node_data.get("node_outputs", {}) - output_data.update(node_outputs) - - # 处理完成事件 - if event.get("type") == "workflow_complete": - self.update_execution_status( - execution_id, - "completed", - output_data=output_data - ) - - # 处理错误事件 - if event.get("type") == "workflow_error": - self.update_execution_status( - execution_id, - "failed", - error_message=event.get("error") - ) + # if event.get("type") == "node_complete": + # node_data = event.get("data", {}) + # node_outputs = node_data.get("node_outputs", {}) + # output_data.update(node_outputs) + # + # # 处理完成事件 + # if event.get("type") == "workflow_complete": + # self.update_execution_status( + # execution_id, + # "completed", + # output_data=output_data + # ) + # + # # 处理错误事件 + # if event.get("type") == "workflow_error": + # self.update_execution_status( + # execution_id, + # "failed", + # error_message=event.get("error") + # ) except Exception as e: logger.error(f"工作流流式执行失败: execution_id={execution_id}, error={e}", exc_info=True)