[fix] model support stream

This commit is contained in:
Mark
2025-12-20 16:03:41 +08:00
parent 6c04c99073
commit d8fcea8564
5 changed files with 377 additions and 134 deletions

View File

@@ -125,17 +125,22 @@ class WorkflowExecutor:
if stream:
# 流式模式:创建 async generator 函数
# LangGraph 会收集所有 yield 的值,最后一个 yield 的字典会被合并到 state
async def node_func(state: WorkflowState, inst=node_instance):
async for item in inst.run_stream(state):
yield item
workflow.add_node(node_id, node_func)
def make_stream_func(inst):
async def node_func(state: WorkflowState):
# logger.debug(f"流式执行节点: {inst.node_id}, 支持流式: {inst.supports_streaming()}")
async for item in inst.run_stream(state):
yield item
return node_func
workflow.add_node(node_id, make_stream_func(node_instance))
else:
# 非流式模式:创建 async function
async def node_func(state: WorkflowState, inst=node_instance):
return await inst.run(state)
workflow.add_node(node_id, node_func)
def make_func(inst):
async def node_func(state: WorkflowState):
return await inst.run(state)
return node_func
workflow.add_node(node_id, make_func(node_instance))
logger.debug(f"添加节点: {node_id} (type={node_type})")
logger.debug(f"添加节点: {node_id} (type={node_type}, stream={stream})")
# 3. 添加边
# 从 START 连接到 start 节点
@@ -283,13 +288,9 @@ class WorkflowExecutor:
):
"""执行工作流(流式)
使用 stream_mode="updates" 来获取每个节点的 state 更新。
节点的 generator 会 yield 多个值:
- 中间的 chunk 事件(带 type="chunk"
- 最后的 state 更新(纯字典,包含 node_outputs 等)
LangGraph 会将所有 yield 的值收集起来,并将它们合并到 state 中。
我们需要过滤出 chunk 事件并转发,同时确保 state 更新被正确处理。
使用多个 stream_mode 来获取:
1. "updates" - 节点的 state 更新和流式 chunk
2. "debug" - 节点执行的详细信息(开始/完成时间
Args:
input_data: 输入数据
@@ -297,7 +298,7 @@ class WorkflowExecutor:
Yields:
流式事件
"""
logger.info(f"开始执行工作流: execution_id={self.execution_id}")
logger.info(f"开始执行工作流(流式): execution_id={self.execution_id}")
# 记录开始时间
start_time = datetime.datetime.now()
@@ -310,34 +311,73 @@ class WorkflowExecutor:
# 3. 执行工作流
try:
async for mode, event in graph.astream(
chunk_count = 0
async for event in graph.astream(
initial_state,
stream_mode=["updates","messages"],
stream_mode=["updates", "debug"],
):
# print("刚才跑的节点:", event[0])
# # 通过图结构就能算出“接下来是谁”
# print("接下来可能跑:", graph.get_next(event[0]))
# print("="*50)
# # print("mode",mode)
# print("event",event)
# print("="*50)
# event 是一个字典key 是节点 IDvalue 是 state 更新或 chunk
for node_id, update in event.items():
print("="*50)
print("node_id",node_id)
print("update",update)
mode, data = event
print("="*50)
if isinstance(update, dict) and update.get("type") == "chunk":
# 这是流式 chunk转发给客户端
if mode == "debug":
# 处理调试信息(节点执行状态)
event_type = data.get("type")
payload = data.get("payload", {})
node_name = payload.get("name")
if event_type == "task":
# 节点开始执行
inputv = payload.get("input", {})
variables = inputv.get("variables", {})
variables_sys = variables.get("sys", {})
conversation_id = variables_sys.get("conversation_id")
execution_id = variables_sys.get("execution_id")
logger.info(f"[DEBUG] 节点开始执行: {node_name}")
yield {
"type": "node_chunk",
"node_id": update.get("node_id"),
"chunk": update.get("content")
"type": "node_start",
"node_id": node_name,
"conversation_id": conversation_id,
"execution_id": execution_id,
"timestamp": data.get("timestamp")
}
# 其他情况state 更新)会被 LangGraph 自动合并到 state不需要我们处理
print(event)
yield event
elif event_type == "task_result":
# 节点执行完成
result = payload.get("result", {})
inputv = result.get("input", {})
variables = inputv.get("variables", {})
variables_sys = variables.get("sys", {})
conversation_id = variables_sys.get("conversation_id")
execution_id = variables_sys.get("execution_id")
logger.info(f"[DEBUG] 节点执行完成: {node_name}")
yield {
"type": "node_end",
"node_id": node_name,
"conversation_id": conversation_id,
"execution_id": execution_id,
"timestamp": data.get("timestamp")
}
elif mode == "updates":
# 处理 state 更新
# data 是一个字典key 是节点 IDvalue 是 state 更新或 chunk
print("="*50)
print(data)
print("-"*50)
for node_id, update in data.items():
if isinstance(update, dict) and update.get("type") == "chunk":
# 这是流式 chunk转发给客户端
chunk_count += 1
logger.debug(f"[UPDATE] 收到 chunk #{chunk_count} from {node_id}: {update.get('content')[:50]}...")
yield {
"type": "node_chunk",
"node_id": update.get("node_id"),
"chunk": update.get("content"),
"full_content": update.get("full_content")
}
else:
logger.debug(f"[UPDATE] 收到 state 更新 from {node_id}")
# 其他情况state 更新)会被 LangGraph 自动合并到 state
logger.info(f"工作流执行完成(流式),总 chunks: {chunk_count}")
except Exception as e:
# 计算耗时(即使失败也记录)

View File

@@ -245,6 +245,9 @@ class BaseNode(ABC):
final_result = item["result"]
elif isinstance(item, str):
# 字符串是 chunk
# print("="*50)
# print(item)
# print("-"*50)
chunks.append(item)
yield {
"type": "chunk",

View File

@@ -30,11 +30,7 @@ class EndNode(BaseNode):
# 获取配置的输出模板
output_template = self.config.get("output")
# pool = self.get_variable_pool(state)
# print("="*20)
# print( pool.get("start.test"))
# print("="*20)
# 如果配置了输出模板,使用模板渲染;否则使用默认输出
if output_template:
output = self._render_template(output_template, state)
@@ -46,7 +42,45 @@ class EndNode(BaseNode):
total_nodes = len(node_outputs)
logger.info(f"节点 {self.node_id} (End) 执行完成,共执行 {total_nodes} 个节点")
print("="*20)
print(output)
print("="*20)
return output
async def execute_stream(self, state: WorkflowState):
"""流式执行 end 节点业务逻辑
当 end 节点前面是 LLM 节点时,流式输出其内容。
Args:
state: 工作流状态
Yields:
文本片段chunk或完成标记
"""
logger.info(f"节点 {self.node_id} (End) 开始执行(流式)")
# 获取配置的输出模板
output_template = self.config.get("output")
# 如果配置了输出模板,使用模板渲染
if output_template:
output = self._render_template(output_template, state)
# 检查输出中是否包含节点引用(如 {{llm_node.output}}
# 如果包含,则逐字符流式输出
if output:
# 逐字符流式输出
for char in output:
yield char
else:
output = "工作流已完成"
for char in output:
yield char
# 统计信息(用于日志)
node_outputs = state.get("node_outputs", {})
total_nodes = len(node_outputs)
logger.info(f"节点 {self.node_id} (End) 执行完成(流式),共执行 {total_nodes} 个节点")
# yield 完成标记
yield {"__final__": True, "result": output}

View File

@@ -125,19 +125,22 @@ class LLMNode(BaseNode):
model_type = config.type
# 4. 创建 LLM 实例(使用已提取的数据)
print("="*50)
print("stream",stream)
print("="*50)
# 注意:对于流式输出,需要在模型初始化时设置 streaming=True
extra_params = {"streaming": stream} if stream else {}
llm = RedBearLLM(
RedBearModelConfig(
model_name=model_name,
provider=provider,
api_key=api_key,
base_url=api_base,
extra_params={"streaming": stream}
extra_params=extra_params
),
type=model_type
)
logger.debug(f"创建 LLM 实例: provider={provider}, model={model_name}, streaming={stream}")
return llm, prompt_or_messages
async def execute(self, state: WorkflowState) -> AIMessage:
@@ -201,47 +204,54 @@ class LLMNode(BaseNode):
}
return None
# async def execute_stream(self, state: WorkflowState):
# """流式执行 LLM 调用
async def execute_stream(self, state: WorkflowState):
"""流式执行 LLM 调用
# Args:
# state: 工作流状态
Args:
state: 工作流状态
# Yields:
# 文本片段chunk或完成标记
# """
# llm, prompt_or_messages = self._prepare_llm(state,True)
Yields:
文本片段chunk或完成标记
"""
llm, prompt_or_messages = self._prepare_llm(state, True)
# logger.info(f"节点 {self.node_id} 开始执行 LLM 调用(流式)")
logger.info(f"节点 {self.node_id} 开始执行 LLM 调用(流式)")
logger.debug(f"LLM 配置: streaming={getattr(llm._model, 'streaming', 'unknown')}")
# # 累积完整响应
# full_response = ""
# last_chunk = None
# 累积完整响应
full_response = ""
last_chunk = None
chunk_count = 0
# # 调用 LLM流式支持字符串或消息列表
# async for chunk in llm.astream(prompt_or_messages):
# # 提取内容
# if hasattr(chunk, 'content'):
# content = chunk.content
# else:
# content = str(chunk)
# 调用 LLM流式支持字符串或消息列表
# 注意astream 方法本身就是流式的,不需要额外配置
async for chunk in llm.astream(prompt_or_messages):
# 提取内容
if hasattr(chunk, 'content'):
content = chunk.content
else:
content = str(chunk)
# full_response += content
# last_chunk = chunk
# logger.info(f"节点 {self.node_id} LLM : {content}")
# # 流式返回每个文本片段
# yield content
# 只有当内容不为空时才处理
if content:
full_response += content
last_chunk = chunk
chunk_count += 1
# logger.debug(f"节点 {self.node_id} LLM chunk #{chunk_count}: {content[:50]}...")
# 流式返回每个文本片段
yield content #AIMessage(content=content)
# logger.info(f"节点 {self.node_id} LLM 调用完成,输出长度: {len(full_response)}")
logger.info(f"节点 {self.node_id} LLM 调用完成,输出长度: {len(full_response)}, 总 chunks: {chunk_count}")
# # 构建完整的 AIMessage包含元数据
# if isinstance(last_chunk, AIMessage):
# final_message = AIMessage(
# content=full_response,
# response_metadata=last_chunk.response_metadata if hasattr(last_chunk, 'response_metadata') else {}
# )
# else:
# final_message = AIMessage(content=full_response)
# 构建完整的 AIMessage包含元数据
if isinstance(last_chunk, AIMessage):
final_message = AIMessage(
content=full_response,
response_metadata=last_chunk.response_metadata if hasattr(last_chunk, 'response_metadata') else {}
)
else:
final_message = AIMessage(content=full_response)
# # yield 完成标记
# yield {"__final__": True, "result": final_message}
# yield 完成标记
yield {"__final__": True, "result": final_message}