[fix] model support stream
This commit is contained in:
@@ -125,17 +125,22 @@ class WorkflowExecutor:
|
||||
if stream:
|
||||
# 流式模式:创建 async generator 函数
|
||||
# LangGraph 会收集所有 yield 的值,最后一个 yield 的字典会被合并到 state
|
||||
async def node_func(state: WorkflowState, inst=node_instance):
|
||||
async for item in inst.run_stream(state):
|
||||
yield item
|
||||
workflow.add_node(node_id, node_func)
|
||||
def make_stream_func(inst):
    """Build the streaming node callable bound to ``inst``.

    Args:
        inst: Node instance; only its ``run_stream(state)`` method is used.

    Returns:
        An async generator function taking the workflow state and
        re-yielding every item produced by ``inst.run_stream(state)``.
    """

    async def node_func(state: WorkflowState):
        # Relay the node's own stream unchanged, one item at a time.
        stream = inst.run_stream(state)
        async for event in stream:
            yield event

    return node_func
|
||||
workflow.add_node(node_id, make_stream_func(node_instance))
|
||||
else:
|
||||
# 非流式模式:创建 async function
|
||||
async def node_func(state: WorkflowState, inst=node_instance):
|
||||
return await inst.run(state)
|
||||
workflow.add_node(node_id, node_func)
|
||||
def make_func(inst):
    """Build the non-streaming node callable bound to ``inst``.

    Args:
        inst: Node instance; only its ``run(state)`` coroutine is used.

    Returns:
        An async function taking the workflow state and returning the
        awaited result of ``inst.run(state)``.
    """

    async def node_func(state: WorkflowState):
        result = await inst.run(state)
        return result

    return node_func
|
||||
workflow.add_node(node_id, make_func(node_instance))
|
||||
|
||||
logger.debug(f"添加节点: {node_id} (type={node_type})")
|
||||
logger.debug(f"添加节点: {node_id} (type={node_type}, stream={stream})")
|
||||
|
||||
# 3. 添加边
|
||||
# 从 START 连接到 start 节点
|
||||
@@ -283,13 +288,9 @@ class WorkflowExecutor:
|
||||
):
|
||||
"""执行工作流(流式)
|
||||
|
||||
使用 stream_mode="updates" 来获取每个节点的 state 更新。
|
||||
节点的 generator 会 yield 多个值:
|
||||
- 中间的 chunk 事件(带 type="chunk")
|
||||
- 最后的 state 更新(纯字典,包含 node_outputs 等)
|
||||
|
||||
LangGraph 会将所有 yield 的值收集起来,并将它们合并到 state 中。
|
||||
我们需要过滤出 chunk 事件并转发,同时确保 state 更新被正确处理。
|
||||
使用多个 stream_mode 来获取:
|
||||
1. "updates" - 节点的 state 更新和流式 chunk
|
||||
2. "debug" - 节点执行的详细信息(开始/完成时间)
|
||||
|
||||
Args:
|
||||
input_data: 输入数据
|
||||
@@ -297,7 +298,7 @@ class WorkflowExecutor:
|
||||
Yields:
|
||||
流式事件
|
||||
"""
|
||||
logger.info(f"开始执行工作流: execution_id={self.execution_id}")
|
||||
logger.info(f"开始执行工作流(流式): execution_id={self.execution_id}")
|
||||
|
||||
# 记录开始时间
|
||||
start_time = datetime.datetime.now()
|
||||
@@ -310,34 +311,73 @@ class WorkflowExecutor:
|
||||
|
||||
# 3. 执行工作流
|
||||
try:
|
||||
async for mode, event in graph.astream(
|
||||
chunk_count = 0
|
||||
async for event in graph.astream(
|
||||
initial_state,
|
||||
stream_mode=["updates","messages"],
|
||||
stream_mode=["updates", "debug"],
|
||||
):
|
||||
# print("刚才跑的节点:", event[0])
|
||||
# # 通过图结构就能算出“接下来是谁”
|
||||
# print("接下来可能跑:", graph.get_next(event[0]))
|
||||
# print("="*50)
|
||||
# # print("mode",mode)
|
||||
# print("event",event)
|
||||
# print("="*50)
|
||||
# event 是一个字典,key 是节点 ID,value 是 state 更新或 chunk
|
||||
for node_id, update in event.items():
|
||||
print("="*50)
|
||||
print("node_id",node_id)
|
||||
print("update",update)
|
||||
mode, data = event
|
||||
|
||||
print("="*50)
|
||||
if isinstance(update, dict) and update.get("type") == "chunk":
|
||||
# 这是流式 chunk,转发给客户端
|
||||
if mode == "debug":
|
||||
# 处理调试信息(节点执行状态)
|
||||
event_type = data.get("type")
|
||||
payload = data.get("payload", {})
|
||||
node_name = payload.get("name")
|
||||
|
||||
if event_type == "task":
|
||||
# 节点开始执行
|
||||
inputv = payload.get("input", {})
|
||||
variables = inputv.get("variables", {})
|
||||
variables_sys = variables.get("sys", {})
|
||||
conversation_id = variables_sys.get("conversation_id")
|
||||
execution_id = variables_sys.get("execution_id")
|
||||
logger.info(f"[DEBUG] 节点开始执行: {node_name}")
|
||||
yield {
|
||||
"type": "node_chunk",
|
||||
"node_id": update.get("node_id"),
|
||||
"chunk": update.get("content")
|
||||
"type": "node_start",
|
||||
"node_id": node_name,
|
||||
"conversation_id": conversation_id,
|
||||
"execution_id": execution_id,
|
||||
"timestamp": data.get("timestamp")
|
||||
}
|
||||
# 其他情况(state 更新)会被 LangGraph 自动合并到 state,不需要我们处理
|
||||
print(event)
|
||||
yield event
|
||||
elif event_type == "task_result":
|
||||
# 节点执行完成
|
||||
result = payload.get("result", {})
|
||||
inputv = result.get("input", {})
|
||||
variables = inputv.get("variables", {})
|
||||
variables_sys = variables.get("sys", {})
|
||||
conversation_id = variables_sys.get("conversation_id")
|
||||
execution_id = variables_sys.get("execution_id")
|
||||
logger.info(f"[DEBUG] 节点执行完成: {node_name}")
|
||||
yield {
|
||||
"type": "node_end",
|
||||
"node_id": node_name,
|
||||
"conversation_id": conversation_id,
|
||||
"execution_id": execution_id,
|
||||
"timestamp": data.get("timestamp")
|
||||
}
|
||||
|
||||
elif mode == "updates":
|
||||
# 处理 state 更新
|
||||
# data 是一个字典,key 是节点 ID,value 是 state 更新或 chunk
|
||||
print("="*50)
|
||||
print(data)
|
||||
print("-"*50)
|
||||
for node_id, update in data.items():
|
||||
if isinstance(update, dict) and update.get("type") == "chunk":
|
||||
# 这是流式 chunk,转发给客户端
|
||||
chunk_count += 1
|
||||
logger.debug(f"[UPDATE] 收到 chunk #{chunk_count} from {node_id}: {update.get('content')[:50]}...")
|
||||
yield {
|
||||
"type": "node_chunk",
|
||||
"node_id": update.get("node_id"),
|
||||
"chunk": update.get("content"),
|
||||
"full_content": update.get("full_content")
|
||||
}
|
||||
else:
|
||||
logger.debug(f"[UPDATE] 收到 state 更新 from {node_id}")
|
||||
# 其他情况(state 更新)会被 LangGraph 自动合并到 state
|
||||
|
||||
logger.info(f"工作流执行完成(流式),总 chunks: {chunk_count}")
|
||||
|
||||
except Exception as e:
|
||||
# 计算耗时(即使失败也记录)
|
||||
|
||||
@@ -245,6 +245,9 @@ class BaseNode(ABC):
|
||||
final_result = item["result"]
|
||||
elif isinstance(item, str):
|
||||
# 字符串是 chunk
|
||||
# print("="*50)
|
||||
# print(item)
|
||||
# print("-"*50)
|
||||
chunks.append(item)
|
||||
yield {
|
||||
"type": "chunk",
|
||||
|
||||
@@ -30,11 +30,7 @@ class EndNode(BaseNode):
|
||||
|
||||
# 获取配置的输出模板
|
||||
output_template = self.config.get("output")
|
||||
# pool = self.get_variable_pool(state)
|
||||
|
||||
# print("="*20)
|
||||
# print( pool.get("start.test"))
|
||||
# print("="*20)
|
||||
|
||||
# 如果配置了输出模板,使用模板渲染;否则使用默认输出
|
||||
if output_template:
|
||||
output = self._render_template(output_template, state)
|
||||
@@ -46,7 +42,45 @@ class EndNode(BaseNode):
|
||||
total_nodes = len(node_outputs)
|
||||
|
||||
logger.info(f"节点 {self.node_id} (End) 执行完成,共执行 {total_nodes} 个节点")
|
||||
print("="*20)
|
||||
print(output)
|
||||
print("="*20)
|
||||
|
||||
return output
|
||||
|
||||
async def execute_stream(self, state: WorkflowState):
    """Stream the end node's output one character at a time.

    The text is the rendered ``output`` template from the node config when
    one is configured; otherwise a fixed completion message is used.

    Args:
        state: Workflow state, used for template rendering and for the
            ``node_outputs`` count written to the log.

    Yields:
        Each character of the output text, followed by a final
        ``{"__final__": True, "result": output}`` marker.
    """
    logger.info(f"节点 {self.node_id} (End) 开始执行(流式)")

    output_template = self.config.get("output")
    if output_template:
        output = self._render_template(output_template, state)
    else:
        output = "工作流已完成"

    # A falsy render result (e.g. an empty string) streams nothing, but is
    # still reported as-is in the final marker below.
    for char in output or "":
        yield char

    # Statistics are only used for logging.
    total_nodes = len(state.get("node_outputs", {}))
    logger.info(f"节点 {self.node_id} (End) 执行完成(流式),共执行 {total_nodes} 个节点")

    yield {"__final__": True, "result": output}
|
||||
|
||||
@@ -125,19 +125,22 @@ class LLMNode(BaseNode):
|
||||
model_type = config.type
|
||||
|
||||
# 4. 创建 LLM 实例(使用已提取的数据)
|
||||
print("="*50)
|
||||
print("stream",stream)
|
||||
print("="*50)
|
||||
# 注意:对于流式输出,需要在模型初始化时设置 streaming=True
|
||||
extra_params = {"streaming": stream} if stream else {}
|
||||
|
||||
llm = RedBearLLM(
|
||||
RedBearModelConfig(
|
||||
model_name=model_name,
|
||||
provider=provider,
|
||||
api_key=api_key,
|
||||
base_url=api_base,
|
||||
extra_params={"streaming": stream}
|
||||
extra_params=extra_params
|
||||
),
|
||||
type=model_type
|
||||
)
|
||||
|
||||
logger.debug(f"创建 LLM 实例: provider={provider}, model={model_name}, streaming={stream}")
|
||||
|
||||
return llm, prompt_or_messages
|
||||
|
||||
async def execute(self, state: WorkflowState) -> AIMessage:
|
||||
@@ -201,47 +204,54 @@ class LLMNode(BaseNode):
|
||||
}
|
||||
return None
|
||||
|
||||
# async def execute_stream(self, state: WorkflowState):
|
||||
# """流式执行 LLM 调用
|
||||
async def execute_stream(self, state: WorkflowState):
    """Stream the LLM response for this node, fragment by fragment.

    Args:
        state: Workflow state, forwarded to ``_prepare_llm``.

    Yields:
        Each non-empty content fragment produced by ``llm.astream``,
        followed by a final ``{"__final__": True, "result": AIMessage}``
        marker whose message holds the concatenated text.
    """
    llm, prompt_or_messages = self._prepare_llm(state, True)

    logger.info(f"节点 {self.node_id} 开始执行 LLM 调用(流式)")
    # Debug-only introspection: a wrapper without the private ``_model``
    # attribute must not abort the call because of a log line.
    model = getattr(llm, "_model", None)
    logger.debug(f"LLM 配置: streaming={getattr(model, 'streaming', 'unknown')}")

    full_response = ""
    response_metadata = {}
    chunk_count = 0

    # ``astream`` accepts either a prompt string or a list of messages.
    async for chunk in llm.astream(prompt_or_messages):
        # Collect metadata from every message chunk, including ones with
        # empty content: metadata such as the finish reason may arrive on
        # a trailing chunk that carries no text. Later chunks win on
        # duplicate keys.
        if isinstance(chunk, AIMessage):
            response_metadata.update(getattr(chunk, "response_metadata", None) or {})

        content = chunk.content if hasattr(chunk, "content") else str(chunk)

        # Only forward fragments that actually carry text.
        # NOTE(review): assumes content is a str; list-style content blocks
        # would fail on concatenation — confirm against the providers used.
        if content:
            full_response += content
            chunk_count += 1
            yield content

    logger.info(f"节点 {self.node_id} LLM 调用完成,输出长度: {len(full_response)}, 总 chunks: {chunk_count}")

    # Build the complete message, including the merged metadata.
    final_message = AIMessage(
        content=full_response,
        response_metadata=response_metadata,
    )

    # Completion marker.
    yield {"__final__": True, "result": final_message}
|
||||
|
||||
Reference in New Issue
Block a user