From d8fcea856460c11b39e3cbb874d586018e850aac Mon Sep 17 00:00:00 2001 From: Mark Date: Sat, 20 Dec 2025 16:03:41 +0800 Subject: [PATCH] [fix] model support stream --- api/app/core/models/llm.py | 252 ++++++++++++++++++----- api/app/core/workflow/executor.py | 118 +++++++---- api/app/core/workflow/nodes/base_node.py | 3 + api/app/core/workflow/nodes/end/node.py | 50 ++++- api/app/core/workflow/nodes/llm/node.py | 88 ++++---- 5 files changed, 377 insertions(+), 134 deletions(-) diff --git a/api/app/core/models/llm.py b/api/app/core/models/llm.py index 5808d31a..7cd12faa 100644 --- a/api/app/core/models/llm.py +++ b/api/app/core/models/llm.py @@ -1,8 +1,8 @@ from __future__ import annotations -from typing import Any, Dict, List, Optional +from typing import Any, Iterator, AsyncIterator, List, Optional from langchain_core.callbacks import CallbackManagerForLLMRun, AsyncCallbackManagerForLLMRun from langchain_core.language_models import BaseLLM -from langchain_core.outputs import LLMResult +from langchain_core.outputs import LLMResult, GenerationChunk from app.core.models import RedBearModelConfig, RedBearModelFactory, get_provider_llm_class from app.models.models_model import ModelType @@ -10,21 +10,36 @@ from app.models.models_model import ModelType class RedBearLLM(BaseLLM): """ - RedBear LLM 模型包装器 - 完全动态代理实现 + RedBear LLM Model Wrapper - 这个包装器自动将所有方法调用委托给内部模型, - 同时提供优雅的回退机制和错误处理。 + This wrapper provides a unified interface to access different LLM providers, + while maintaining all LangChain functionality, including streaming output. + + Features: + - Support for multiple LLM providers (OpenAI, Qwen, Ollama, etc.) + - Full streaming output support + - Elegant error handling and fallback mechanism + - Automatic proxying of all underlying model methods and attributes """ - def __init__(self, config: RedBearModelConfig, type: ModelType=ModelType.LLM): - self._model = self._create_model(config, type) + def __init__(self, config: RedBearModelConfig, type: ModelType = ModelType.LLM): + """Initialize RedBear LLM wrapper + + Args: + config: Model configuration + type: Model type (LLM or CHAT) + """ + super().__init__() self._config = config + self._model = self._create_model(config, type) @property def _llm_type(self) -> str: - """返回LLM类型标识符""" - return self._model._llm_type + """Return LLM type identifier""" + return getattr(self._model, '_llm_type', 'redbear_llm') + # ==================== Core Methods (Required by BaseLLM) ==================== + def _generate( self, prompts: List[str], @@ -32,7 +47,7 @@ class RedBearLLM(BaseLLM): run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any ) -> LLMResult: - """同步生成文本""" + """Synchronous text generation (required by BaseLLM)""" return self._model._generate(prompts, stop=stop, run_manager=run_manager, **kwargs) async def _agenerate( @@ -42,92 +57,233 @@ class RedBearLLM(BaseLLM): run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, **kwargs: Any ) -> LLMResult: - """异步生成文本""" + """Asynchronous text generation (required by BaseLLM)""" return await self._model._agenerate(prompts, stop=stop, run_manager=run_manager, **kwargs) - # 关键:覆盖 invoke/ainvoke,直接委托到底层模型,避免 BaseLLM 的字符串化行为 + # ==================== Advanced Methods (Support Message Lists) ==================== + def invoke(self, input: Any, config: Optional[dict] = None, **kwargs: Any) -> Any: - """直接调用底层模型以支持 ChatPrompt 和消息列表。""" + """Synchronous model invocation + + Supports various input formats including strings and message lists. + Directly delegates to the underlying model to avoid BaseLLM's string conversion. + + Args: + input: Input (string, message list, etc.) + config: Runtime configuration + **kwargs: Additional arguments + + Returns: + Model response + """ try: return self._model.invoke(input, config=config, **kwargs) except AttributeError as e: - # 只在属性错误时回退(说明底层模型不支持该方法) if 'invoke' in str(e): + # Underlying model doesn't support invoke, fallback to parent implementation return super().invoke(input, config=config, **kwargs) - # 其他 AttributeError 直接抛出 raise except Exception: - # 其他所有异常(包括 ValidationException)直接抛出,不回退 + # Other exceptions are raised directly raise async def ainvoke(self, input: Any, config: Optional[dict] = None, **kwargs: Any) -> Any: - """异步直接调用底层模型以支持 ChatPrompt 和消息列表。""" + """Asynchronous model invocation + + Supports various input formats including strings and message lists. + Directly delegates to the underlying model to avoid BaseLLM's string conversion. + + Args: + input: Input (string, message list, etc.) + config: Runtime configuration + **kwargs: Additional arguments + + Returns: + Model response + """ try: return await self._model.ainvoke(input, config=config, **kwargs) except AttributeError as e: - # 只在属性错误时回退(说明底层模型不支持该方法) if 'ainvoke' in str(e): + # Underlying model doesn't support ainvoke, fallback to parent implementation return await super().ainvoke(input, config=config, **kwargs) - # 其他 AttributeError 直接抛出 raise except Exception: - # 其他所有异常(包括 ValidationException)直接抛出,不回退 + # Other exceptions are raised directly raise - def __getattr__(self, name): - """ - 动态代理:将所有未定义的属性和方法调用委托给内部模型 + # ==================== Streaming Methods (Critical) ==================== + + def stream( + self, + input: Any, + config: Optional[dict] = None, + *, + stop: Optional[List[str]] = None, + **kwargs: Any + ) -> Iterator[GenerationChunk]: + """Synchronous streaming model invocation - 这是最优雅的包装器实现方式,完全避免了方法重复定义 - """ - # 处理特殊属性以避免递归 - if name in ('__isabstractmethod__', '__dict__', '__class__'): - raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'") + Args: + input: Input (string, message list, etc.) + config: Runtime configuration + stop: List of stop words + **kwargs: Additional arguments - # 检查内部模型是否有该属性(使用安全的方式避免递归) + Yields: + GenerationChunk: Generated text chunks + """ + try: + yield from self._model.stream(input, config=config, stop=stop, **kwargs) + except AttributeError as e: + if 'stream' in str(e): + # Underlying model doesn't support stream, fallback to parent implementation + yield from super().stream(input, config=config, stop=stop, **kwargs) + else: + raise + except Exception: + raise + + async def astream( + self, + input: Any, + config: Optional[dict] = None, + *, + stop: Optional[List[str]] = None, + **kwargs: Any + ) -> AsyncIterator[GenerationChunk]: + """Asynchronous streaming model invocation + + This is the core method for streaming output. It directly proxies to the + underlying model's astream method, maintaining generator characteristics + to ensure each chunk is delivered in real-time. + + Args: + input: Input (string, message list, etc.) + config: Runtime configuration + stop: List of stop words + **kwargs: Additional arguments + + Yields: + GenerationChunk: Generated text chunks + """ + try: + async for chunk in self._model.astream(input, config=config, stop=stop, **kwargs): + yield chunk + except AttributeError as e: + if 'astream' in str(e): + # Underlying model doesn't support astream, fallback to parent implementation + async for chunk in super().astream(input, config=config, stop=stop, **kwargs): + yield chunk + else: + raise + except Exception: + raise + + # ==================== Dynamic Proxy ==================== + + def __getattr__(self, name: str) -> Any: + """Dynamic proxy: delegate undefined attributes and method calls to internal model + + This method allows RedBearLLM to transparently access all attributes and methods + of the underlying model without explicitly defining each one. + + Args: + name: Attribute or method name + + Returns: + Attribute value or method + + Raises: + AttributeError: If attribute doesn't exist + """ + # Avoid recursion: raise error directly for special attributes + if name in ('__isabstractmethod__', '__dict__', '__class__', '_model', '_config'): + raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'") + + # Try to get attribute from internal model try: - # 使用 object.__getattribute__ 来安全地检查内部模型的属性 attr = object.__getattribute__(self._model, name) - # 如果是方法,返回一个包装器来处理调用 + # If it's callable (a method) if callable(attr): - # 流式方法直接返回,不包装(保持生成器特性) - if name in ('_stream', '_astream', 'stream', 'astream'): + # Streaming methods are returned directly to maintain generator characteristics + # Note: Although we've explicitly implemented stream/astream, + # this is kept to handle internal methods like _stream/_astream + if name in ('_stream', '_astream'): return attr - # 非流式方法使用包装器处理异常 + # Wrap other methods for easier debugging and error handling def method_wrapper(*args, **kwargs): - return attr(*args, **kwargs) + try: + return attr(*args, **kwargs) + except Exception: + # Can add logging or error handling here + raise - # 保持方法的元信息 + # Preserve method metadata method_wrapper.__name__ = name method_wrapper.__doc__ = getattr(attr, '__doc__', f"Delegated method: {name}") return method_wrapper - # 如果是普通属性,直接返回 + # If it's a regular attribute, return directly return attr except AttributeError: - # 内部模型没有该属性,尝试回退实现 + # Internal model doesn't have this attribute either pass - # 检查是否有回退方法(使用安全的方式避免递归) + # Check if there's a fallback method fallback_name = f'_fallback_{name}' try: - fallback_method = object.__getattribute__(self, fallback_name) - return fallback_method + return object.__getattribute__(self, fallback_name) except AttributeError: - # 没有回退方法,抛出适当的错误 pass - # 如果都没有,抛出适当的错误 - raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'") + # Nothing found, raise error + raise AttributeError( + f"'{type(self).__name__}' object has no attribute '{name}'. " + f"The underlying model '{type(self._model).__name__}' also doesn't have this attribute." + ) + # ==================== Helper Methods ==================== + def _create_model(self, config: RedBearModelConfig, type: ModelType) -> BaseLLM: - """创建内部模型实例""" + """Create internal model instance + + Args: + config: Model configuration + type: Model type + + Returns: + Created model instance + """ llm_class = get_provider_llm_class(config, type) model_params = RedBearModelFactory.get_model_params(config) return llm_class(**model_params) - - - \ No newline at end of file + + def get_config(self) -> RedBearModelConfig: + """Get model configuration + + Returns: + Model configuration object + """ + return self._config + + def get_underlying_model(self) -> BaseLLM: + """Get underlying model instance + + Returns: + Underlying model instance + """ + return self._model + + def __repr__(self) -> str: + """Return string representation of the object""" + return ( + f"RedBearLLM(" + f"provider={self._config.provider}, " + f"model={self._config.model_name}, " + f"type={type(self._model).__name__}" + f")" + ) \ No newline at end of file diff --git a/api/app/core/workflow/executor.py b/api/app/core/workflow/executor.py index 75a9cb0b..8d67dd1e 100644 --- a/api/app/core/workflow/executor.py +++ b/api/app/core/workflow/executor.py @@ -125,17 +125,22 @@ class WorkflowExecutor: if stream: # 流式模式:创建 async generator 函数 # LangGraph 会收集所有 yield 的值,最后一个 yield 的字典会被合并到 state - async def node_func(state: WorkflowState, inst=node_instance): - async for item in inst.run_stream(state): - yield item - workflow.add_node(node_id, node_func) + def make_stream_func(inst): + async def node_func(state: WorkflowState): + # logger.debug(f"流式执行节点: {inst.node_id}, 支持流式: {inst.supports_streaming()}") + async for item in inst.run_stream(state): + yield item + return node_func + workflow.add_node(node_id, make_stream_func(node_instance)) else: # 非流式模式:创建 async function - async def node_func(state: WorkflowState, inst=node_instance): - return await inst.run(state) - workflow.add_node(node_id, node_func) + def make_func(inst): + async def node_func(state: WorkflowState): + return await inst.run(state) + return node_func + workflow.add_node(node_id, make_func(node_instance)) - logger.debug(f"添加节点: {node_id} (type={node_type})") + logger.debug(f"添加节点: {node_id} (type={node_type}, stream={stream})") # 3. 添加边 # 从 START 连接到 start 节点 @@ -283,13 +288,9 @@ class WorkflowExecutor: ): """执行工作流(流式) - 使用 stream_mode="updates" 来获取每个节点的 state 更新。 - 节点的 generator 会 yield 多个值: - - 中间的 chunk 事件(带 type="chunk") - - 最后的 state 更新(纯字典,包含 node_outputs 等) - - LangGraph 会将所有 yield 的值收集起来,并将它们合并到 state 中。 - 我们需要过滤出 chunk 事件并转发,同时确保 state 更新被正确处理。 + 使用多个 stream_mode 来获取: + 1. "updates" - 节点的 state 更新和流式 chunk + 2. "debug" - 节点执行的详细信息(开始/完成时间) Args: input_data: 输入数据 @@ -297,7 +298,7 @@ class WorkflowExecutor: Yields: 流式事件 """ - logger.info(f"开始执行工作流: execution_id={self.execution_id}") + logger.info(f"开始执行工作流(流式): execution_id={self.execution_id}") # 记录开始时间 start_time = datetime.datetime.now() @@ -310,34 +311,73 @@ class WorkflowExecutor: # 3. 执行工作流 try: - async for mode, event in graph.astream( + chunk_count = 0 + async for event in graph.astream( initial_state, - stream_mode=["updates","messages"], + stream_mode=["updates", "debug"], ): - # print("刚才跑的节点:", event[0]) - # # 通过图结构就能算出“接下来是谁” - # print("接下来可能跑:", graph.get_next(event[0])) - # print("="*50) - # # print("mode",mode) - # print("event",event) - # print("="*50) - # event 是一个字典,key 是节点 ID,value 是 state 更新或 chunk - for node_id, update in event.items(): - print("="*50) - print("node_id",node_id) - print("update",update) + mode, data = event - print("="*50) - if isinstance(update, dict) and update.get("type") == "chunk": - # 这是流式 chunk,转发给客户端 + if mode == "debug": + # 处理调试信息(节点执行状态) + event_type = data.get("type") + payload = data.get("payload", {}) + node_name = payload.get("name") + + if event_type == "task": + # 节点开始执行 + inputv = payload.get("input", {}) + variables = inputv.get("variables", {}) + variables_sys = variables.get("sys", {}) + conversation_id = variables_sys.get("conversation_id") + execution_id = variables_sys.get("execution_id") + logger.info(f"[DEBUG] 节点开始执行: {node_name}") yield { - "type": "node_chunk", - "node_id": update.get("node_id"), - "chunk": update.get("content") + "type": "node_start", + "node_id": node_name, + "conversation_id": conversation_id, + "execution_id": execution_id, + "timestamp": data.get("timestamp") } - # 其他情况(state 更新)会被 LangGraph 自动合并到 state,不需要我们处理 - print(event) - yield event + elif event_type == "task_result": + # 节点执行完成 + result = payload.get("result", {}) + inputv = result.get("input", {}) + variables = inputv.get("variables", {}) + variables_sys = variables.get("sys", {}) + conversation_id = variables_sys.get("conversation_id") + execution_id = variables_sys.get("execution_id") + logger.info(f"[DEBUG] 节点执行完成: {node_name}") + yield { + "type": "node_end", + "node_id": node_name, + "conversation_id": conversation_id, + "execution_id": execution_id, + "timestamp": data.get("timestamp") + } + + elif mode == "updates": + # 处理 state 更新 + # data 是一个字典,key 是节点 ID,value 是 state 更新或 chunk + print("="*50) + print(data) + print("-"*50) + for node_id, update in data.items(): + if isinstance(update, dict) and update.get("type") == "chunk": + # 这是流式 chunk,转发给客户端 + chunk_count += 1 + logger.debug(f"[UPDATE] 收到 chunk #{chunk_count} from {node_id}: {update.get('content')[:50]}...") + yield { + "type": "node_chunk", + "node_id": update.get("node_id"), + "chunk": update.get("content"), + "full_content": update.get("full_content") + } + else: + logger.debug(f"[UPDATE] 收到 state 更新 from {node_id}") + # 其他情况(state 更新)会被 LangGraph 自动合并到 state + + logger.info(f"工作流执行完成(流式),总 chunks: {chunk_count}") except Exception as e: # 计算耗时(即使失败也记录) diff --git a/api/app/core/workflow/nodes/base_node.py b/api/app/core/workflow/nodes/base_node.py index 5674655a..1d6f1c15 100644 --- a/api/app/core/workflow/nodes/base_node.py +++ b/api/app/core/workflow/nodes/base_node.py @@ -245,6 +245,9 @@ class BaseNode(ABC): final_result = item["result"] elif isinstance(item, str): # 字符串是 chunk + # print("="*50) + # print(item) + # print("-"*50) chunks.append(item) yield { "type": "chunk", diff --git a/api/app/core/workflow/nodes/end/node.py b/api/app/core/workflow/nodes/end/node.py index 6ee56dde..cba0d649 100644 --- a/api/app/core/workflow/nodes/end/node.py +++ b/api/app/core/workflow/nodes/end/node.py @@ -30,11 +30,7 @@ class EndNode(BaseNode): # 获取配置的输出模板 output_template = self.config.get("output") - # pool = self.get_variable_pool(state) - - # print("="*20) - # print( pool.get("start.test")) - # print("="*20) + # 如果配置了输出模板,使用模板渲染;否则使用默认输出 if output_template: output = self._render_template(output_template, state) @@ -46,7 +42,45 @@ class EndNode(BaseNode): total_nodes = len(node_outputs) logger.info(f"节点 {self.node_id} (End) 执行完成,共执行 {total_nodes} 个节点") - print("="*20) - print(output) - print("="*20) + return output + + async def execute_stream(self, state: WorkflowState): + """流式执行 end 节点业务逻辑 + + 当 end 节点前面是 LLM 节点时,流式输出其内容。 + + Args: + state: 工作流状态 + + Yields: + 文本片段(chunk)或完成标记 + """ + logger.info(f"节点 {self.node_id} (End) 开始执行(流式)") + + # 获取配置的输出模板 + output_template = self.config.get("output") + + # 如果配置了输出模板,使用模板渲染 + if output_template: + output = self._render_template(output_template, state) + + # 检查输出中是否包含节点引用(如 {{llm_node.output}}) + # 如果包含,则逐字符流式输出 + if output: + # 逐字符流式输出 + for char in output: + yield char + else: + output = "工作流已完成" + for char in output: + yield char + + # 统计信息(用于日志) + node_outputs = state.get("node_outputs", {}) + total_nodes = len(node_outputs) + + logger.info(f"节点 {self.node_id} (End) 执行完成(流式),共执行 {total_nodes} 个节点") + + # yield 完成标记 + yield {"__final__": True, "result": output} diff --git a/api/app/core/workflow/nodes/llm/node.py b/api/app/core/workflow/nodes/llm/node.py index 295ae583..bac707d7 100644 --- a/api/app/core/workflow/nodes/llm/node.py +++ b/api/app/core/workflow/nodes/llm/node.py @@ -125,19 +125,22 @@ class LLMNode(BaseNode): model_type = config.type # 4. 创建 LLM 实例(使用已提取的数据) - print("="*50) - print("stream",stream) - print("="*50) + # 注意:对于流式输出,需要在模型初始化时设置 streaming=True + extra_params = {"streaming": stream} if stream else {} + llm = RedBearLLM( RedBearModelConfig( model_name=model_name, provider=provider, api_key=api_key, base_url=api_base, - extra_params={"streaming": stream} + extra_params=extra_params ), type=model_type ) + + logger.debug(f"创建 LLM 实例: provider={provider}, model={model_name}, streaming={stream}") + return llm, prompt_or_messages async def execute(self, state: WorkflowState) -> AIMessage: @@ -201,47 +204,54 @@ class LLMNode(BaseNode): } return None - # async def execute_stream(self, state: WorkflowState): - # """流式执行 LLM 调用 + async def execute_stream(self, state: WorkflowState): + """流式执行 LLM 调用 - # Args: - # state: 工作流状态 + Args: + state: 工作流状态 - # Yields: - # 文本片段(chunk)或完成标记 - # """ - # llm, prompt_or_messages = self._prepare_llm(state,True) + Yields: + 文本片段(chunk)或完成标记 + """ + llm, prompt_or_messages = self._prepare_llm(state, True) - # logger.info(f"节点 {self.node_id} 开始执行 LLM 调用(流式)") + logger.info(f"节点 {self.node_id} 开始执行 LLM 调用(流式)") + logger.debug(f"LLM 配置: streaming={getattr(llm._model, 'streaming', 'unknown')}") - # # 累积完整响应 - # full_response = "" - # last_chunk = None + # 累积完整响应 + full_response = "" + last_chunk = None + chunk_count = 0 - # # 调用 LLM(流式,支持字符串或消息列表) - # async for chunk in llm.astream(prompt_or_messages): - # # 提取内容 - # if hasattr(chunk, 'content'): - # content = chunk.content - # else: - # content = str(chunk) + # 调用 LLM(流式,支持字符串或消息列表) + # 注意:astream 方法本身就是流式的,不需要额外配置 + async for chunk in llm.astream(prompt_or_messages): + # 提取内容 + if hasattr(chunk, 'content'): + content = chunk.content + else: + content = str(chunk) - # full_response += content - # last_chunk = chunk - # logger.info(f"节点 {self.node_id} LLM : {content}") - # # 流式返回每个文本片段 - # yield content + # 只有当内容不为空时才处理 + if content: + full_response += content + last_chunk = chunk + chunk_count += 1 + + # logger.debug(f"节点 {self.node_id} LLM chunk #{chunk_count}: {content[:50]}...") + # 流式返回每个文本片段 + yield content #AIMessage(content=content) - # logger.info(f"节点 {self.node_id} LLM 调用完成,输出长度: {len(full_response)}") + logger.info(f"节点 {self.node_id} LLM 调用完成,输出长度: {len(full_response)}, 总 chunks: {chunk_count}") - # # 构建完整的 AIMessage(包含元数据) - # if isinstance(last_chunk, AIMessage): - # final_message = AIMessage( - # content=full_response, - # response_metadata=last_chunk.response_metadata if hasattr(last_chunk, 'response_metadata') else {} - # ) - # else: - # final_message = AIMessage(content=full_response) + # 构建完整的 AIMessage(包含元数据) + if isinstance(last_chunk, AIMessage): + final_message = AIMessage( + content=full_response, + response_metadata=last_chunk.response_metadata if hasattr(last_chunk, 'response_metadata') else {} + ) + else: + final_message = AIMessage(content=full_response) - # # yield 完成标记 - # yield {"__final__": True, "result": final_message} + # yield 完成标记 + yield {"__final__": True, "result": final_message}