feat(workflow): add session context memory support to LLM nodes

This commit is contained in:
Eternity
2026-01-14 16:35:46 +08:00
parent b5a366ef5e
commit 567624c323
11 changed files with 249 additions and 228 deletions

View File

@@ -74,6 +74,7 @@ class WorkflowExecutor:
初始化的工作流状态
"""
user_message = input_data.get("message") or ""
conversation_messages = input_data.get("conv_messages") or []
# 会话变量处理从配置文件获取变量定义列表转换为字典name -> default value
config_variables_list = self.workflow_config.get("variables") or []
@@ -114,7 +115,7 @@ class WorkflowExecutor:
}
return {
"messages": [('user', user_message)],
"messages": conversation_messages,
"variables": variables,
"node_outputs": {},
"runtime_vars": {}, # 运行时节点变量(简化版,供快速访问)

View File

@@ -7,13 +7,13 @@
import asyncio
import logging
from abc import ABC, abstractmethod
from operator import add
from typing import Any
from langchain_core.messages import AnyMessage, AIMessage
from langchain_core.messages import AIMessage
from langgraph.config import get_stream_writer
from typing_extensions import TypedDict, Annotated
from app.core.config import settings
from app.core.workflow.variable_pool import VariablePool
logger = logging.getLogger(__name__)
@@ -25,7 +25,7 @@ class WorkflowState(TypedDict):
The state object passed between nodes in a workflow, containing messages, variables, node outputs, etc.
"""
# List of messages (append mode)
messages: Annotated[list[tuple[str, str]], add]
messages: list[dict[str, str]]
# Set of loop node IDs, used for assigning values in loop nodes
cycle_nodes: list
@@ -154,7 +154,7 @@ class BaseNode(ABC):
Returns:
超时时间
"""
return 60
return settings.WORKFLOW_NODE_TIMEOUT
# return self.error_handling.get("timeout", 60)
async def run(self, state: WorkflowState) -> dict[str, Any]:
@@ -203,6 +203,7 @@ class BaseNode(ABC):
# 返回包装后的输出和运行时变量
return {
**wrapped_output,
"messages": state["messages"],
"variables": state["variables"],
"runtime_vars": {
self.node_id: runtime_var
@@ -356,6 +357,7 @@ class BaseNode(ABC):
# Build complete state update (including node_outputs, runtime_vars, and final streaming buffer)
state_update = {
**final_output,
"messages": state["messages"],
"variables": state["variables"],
"runtime_vars": {
self.node_id: runtime_var

View File

@@ -6,7 +6,6 @@ End 节点实现
import logging
import re
import asyncio
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
from app.core.workflow.nodes.enums import NodeType
@@ -38,7 +37,23 @@ class EndNode(BaseNode):
# 如果配置了输出模板,使用模板渲染;否则使用默认输出
if output_template:
output = self._render_template(output_template, state, strict=False)
state['messages'].extend([
{
"role": "user",
"content": self.get_variable("sys.message", state)
},
{
"role": "assistant",
"content": output
}
])
else:
state['messages'].extend([
{
"role": "user",
"content": self.get_variable("sys.message", state)
},
])
output = "工作流已完成"
# 统计信息(用于日志)
@@ -166,6 +181,12 @@ class EndNode(BaseNode):
"chunk_index": 1,
"is_suffix": False
})
state['messages'].extend([
{
"role": "user",
"content": self.get_variable("sys.message", state)
}
])
yield {"__final__": True, "result": output}
return
@@ -176,7 +197,6 @@ class EndNode(BaseNode):
source_node_id = edge.get("source")
# Check if the source node is an LLM node
for node in self.workflow_config.get("nodes", []):
print("="*50)
logger.info(f"节点 {self.node_id} 的类型 {node.get("type")}")
if node.get("id") == source_node_id and node.get("type") == NodeType.LLM:
direct_upstream_llm_nodes.append(source_node_id)
@@ -216,12 +236,24 @@ class EndNode(BaseNode):
})
logger.info(f"节点 {self.node_id} 已通过 writer 发送完整内容")
state['messages'].extend([
{
"role": "user",
"content": self.get_variable("sys.message", state)
},
{
"role": "assistant",
"content": output
}
])
# yield completion marker
yield {"__final__": True, "result": output}
return
# Has reference to direct upstream LLM node, only output the part after that reference (suffix)
logger.info(f"节点 {self.node_id} 检测到直接上游 LLM 节点引用,只输出后缀部分(从索引 {upstream_llm_ref_index + 1} 开始)")
logger.info(
f"节点 {self.node_id} 检测到直接上游 LLM 节点引用,只输出后缀部分(从索引 {upstream_llm_ref_index + 1} 开始)")
# Collect suffix parts
suffix_parts = []
@@ -258,6 +290,17 @@ class EndNode(BaseNode):
# 构建完整输出(用于返回,包含前缀 + 动态内容 + 后缀)
full_output = self._render_template(output_template, state, strict=False)
state['messages'].extend([
{
"role": "user",
"content": self.get_variable("sys.message", state)
},
{
"role": "assistant",
"content": full_output
}
])
logger.info(f"[后缀调试] 节点 {self.node_id} 后缀部分数量: {len(suffix_parts)}")
logger.info(f"[后缀调试] 后缀内容: '{suffix}'")
logger.info(f"[后缀调试] 后缀长度: {len(suffix)}")
@@ -280,7 +323,8 @@ class EndNode(BaseNode):
})
logger.info(f"节点 {self.node_id} 已通过 writer 发送后缀full_content 长度: {len(full_output)}")
else:
logger.warning(f"[后缀调试] 节点 {self.node_id} 后缀为空,不发送!upstream_llm_ref_index={upstream_llm_ref_index}, parts数量={len(parts)}")
logger.warning(f"[后缀调试] 节点 {self.node_id} 后缀为空,不发送!"
f"upstream_llm_ref_index={upstream_llm_ref_index}, parts数量={len(parts)}")
# 统计信息
node_outputs = state.get("node_outputs", {})

View File

@@ -11,12 +11,12 @@ class MessageConfig(BaseModel):
"""消息配置"""
role: str = Field(
...,
default='user',
description="消息角色system, user, assistant"
)
content: str = Field(
...,
default="",
description="消息内容,支持模板变量,如:{{ sys.message }}"
)
@@ -30,6 +30,23 @@ class MessageConfig(BaseModel):
return v.lower()
class MemoryWindowSetting(BaseModel):
enable: bool = Field(
default=False,
description="启用记忆"
)
enable_window: bool = Field(
default=False,
description="启用记忆窗口"
)
window_size: int = Field(
default=20,
description="记忆窗口大小"
)
class LLMNodeConfig(BaseNodeConfig):
"""LLM 节点配置
@@ -48,6 +65,11 @@ class LLMNodeConfig(BaseNodeConfig):
description="上下文"
)
memory: MemoryWindowSetting = Field(
...,
description="对话上下文窗口"
)
# 简单模式
prompt: str | None = Field(
default=None,

View File

@@ -85,28 +85,31 @@ class LLMNode(BaseNode):
"""
# 1. 处理消息格式(优先使用 messages
messages_config = self.config.get("messages")
messages_config = self.typed_config.messages
if messages_config:
# 使用 LangChain 消息格式
messages = []
for msg_config in messages_config:
role = msg_config.get("role", "user").lower()
content_template = msg_config.get("content", "")
role = msg_config.role.lower()
content_template = msg_config.content
content_template = self._render_context(content_template, state)
content = self._render_template(content_template, state)
# 根据角色创建对应的消息对象
if role == "system":
messages.append(SystemMessage(content=content))
messages.append({"role": "system", "content": content})
elif role in ["user", "human"]:
messages.append(HumanMessage(content=content))
messages.append({"role": "user", "content": content})
elif role in ["ai", "assistant"]:
messages.append(AIMessage(content=content))
messages.append({"role": "user", "content": content})
else:
logger.warning(f"未知的消息角色: {role},默认使用 user")
messages.append(HumanMessage(content=content))
messages.append({"role": "user", "content": content})
if self.typed_config.memory.enable:
# if self.typed_config.memory.enable_window:
messages = messages[:-1] + state["messages"][-self.typed_config.memory.window_size:] + messages[-1:]
prompt_or_messages = messages
else:
# 使用简单的 prompt 格式(向后兼容)
@@ -189,7 +192,7 @@ class LLMNode(BaseNode):
return {
"prompt": prompt_or_messages if isinstance(prompt_or_messages, str) else None,
"messages": [
{"role": msg.__class__.__name__.replace("Message", "").lower(), "content": msg.content}
{"role": msg.get("role"), "content": msg.get("content", "")}
for msg in prompt_or_messages
] if isinstance(prompt_or_messages, list) else None,
"config": {