feat(workflow): add session context memory support to LLM nodes
This commit is contained in:
@@ -7,17 +7,18 @@ from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
|
||||
|
||||
class Settings:
|
||||
ENABLE_SINGLE_WORKSPACE: bool = os.getenv("ENABLE_SINGLE_WORKSPACE", "true").lower() == "true"
|
||||
# API Keys Configuration
|
||||
OPENAI_API_KEY: str = os.getenv("OPENAI_API_KEY", "")
|
||||
DASHSCOPE_API_KEY: str = os.getenv("DASHSCOPE_API_KEY", "")
|
||||
|
||||
|
||||
# Neo4j Configuration (记忆系统数据库)
|
||||
NEO4J_URI: str = os.getenv("NEO4J_URI", "bolt://1.94.111.67:7687")
|
||||
NEO4J_USERNAME: str = os.getenv("NEO4J_USERNAME", "neo4j")
|
||||
NEO4J_PASSWORD: str = os.getenv("NEO4J_PASSWORD", "")
|
||||
|
||||
|
||||
# Database configuration (Postgres)
|
||||
DB_HOST: str = os.getenv("DB_HOST", "127.0.0.1")
|
||||
DB_PORT: int = int(os.getenv("DB_PORT", "5432"))
|
||||
@@ -37,7 +38,7 @@ class Settings:
|
||||
REDIS_PORT: int = int(os.getenv("REDIS_PORT", "6379"))
|
||||
REDIS_DB: int = int(os.getenv("REDIS_DB", "1"))
|
||||
REDIS_PASSWORD: str = os.getenv("REDIS_PASSWORD", "")
|
||||
|
||||
|
||||
# ElasticSearch configuration
|
||||
ELASTICSEARCH_HOST: str = os.getenv("ELASTICSEARCH_HOST", "https://127.0.0.1")
|
||||
ELASTICSEARCH_PORT: int = int(os.getenv("ELASTICSEARCH_PORT", "9200"))
|
||||
@@ -48,7 +49,7 @@ class Settings:
|
||||
ELASTICSEARCH_REQUEST_TIMEOUT: int = int(os.getenv("ELASTICSEARCH_REQUEST_TIMEOUT", "100000"))
|
||||
ELASTICSEARCH_RETRY_ON_TIMEOUT: bool = os.getenv("ELASTICSEARCH_RETRY_ON_TIMEOUT", "True").lower() == "true"
|
||||
ELASTICSEARCH_MAX_RETRIES: int = int(os.getenv("ELASTICSEARCH_MAX_RETRIES", "10"))
|
||||
|
||||
|
||||
# Xinference configuration
|
||||
XINFERENCE_URL: str = os.getenv("XINFERENCE_URL", "http://127.0.0.1")
|
||||
|
||||
@@ -57,17 +58,17 @@ class Settings:
|
||||
LANGCHAIN_TRACING: bool = os.getenv("LANGCHAIN_TRACING", "false").lower() == "true"
|
||||
LANGCHAIN_API_KEY: str = os.getenv("LANGCHAIN_API_KEY", "")
|
||||
LANGCHAIN_ENDPOINT: str = os.getenv("LANGCHAIN_ENDPOINT", "")
|
||||
|
||||
|
||||
# LLM Request Configuration
|
||||
LLM_TIMEOUT: float = float(os.getenv("LLM_TIMEOUT", "120.0"))
|
||||
LLM_MAX_RETRIES: int = int(os.getenv("LLM_MAX_RETRIES", "2"))
|
||||
|
||||
|
||||
# JWT Token Configuration
|
||||
SECRET_KEY: str = os.getenv("SECRET_KEY", "a_default_secret_key_that_is_long_and_random")
|
||||
ALGORITHM: str = "HS256"
|
||||
ACCESS_TOKEN_EXPIRE_MINUTES: int = int(os.getenv("ACCESS_TOKEN_EXPIRE_MINUTES", "30"))
|
||||
REFRESH_TOKEN_EXPIRE_DAYS: int = int(os.getenv("REFRESH_TOKEN_EXPIRE_DAYS", "7"))
|
||||
|
||||
|
||||
# Single Sign-On configuration
|
||||
ENABLE_SINGLE_SESSION: bool = os.getenv("ENABLE_SINGLE_SESSION", "false").lower() == "true"
|
||||
|
||||
@@ -86,19 +87,19 @@ class Settings:
|
||||
LANGFUSE_PUBLIC_KEY: str = os.getenv("LANGFUSE_PUBLIC_KEY", "")
|
||||
LANGFUSE_SECRET_KEY: str = os.getenv("LANGFUSE_SECRET_KEY", "")
|
||||
LANGFUSE_HOST: str = os.getenv("LANGFUSE_HOST", "")
|
||||
|
||||
|
||||
# Server Configuration
|
||||
SERVER_IP: str = os.getenv("SERVER_IP", "127.0.0.1")
|
||||
|
||||
# ========================================================================
|
||||
# Internal Configuration (not in .env, used by application code)
|
||||
# ========================================================================
|
||||
|
||||
|
||||
# Superuser settings (internal defaults)
|
||||
FIRST_SUPERUSER_EMAIL: str = os.getenv("FIRST_SUPERUSER_EMAIL", "admin@example.com")
|
||||
FIRST_SUPERUSER_USERNAME: str = os.getenv("FIRST_SUPERUSER_USERNAME", "admin")
|
||||
FIRST_SUPERUSER_PASSWORD: str = os.getenv("FIRST_SUPERUSER_PASSWORD", "admin_password")
|
||||
|
||||
|
||||
# Generic File Upload (internal)
|
||||
GENERIC_FILE_PATH: str = os.getenv("GENERIC_FILE_PATH", "/uploads")
|
||||
ENABLE_FILE_COMPRESSION: bool = os.getenv("ENABLE_FILE_COMPRESSION", "false").lower() == "true"
|
||||
@@ -123,7 +124,7 @@ class Settings:
|
||||
LOG_BACKUP_COUNT: int = int(os.getenv("LOG_BACKUP_COUNT", "5"))
|
||||
LOG_TO_CONSOLE: bool = os.getenv("LOG_TO_CONSOLE", "true").lower() == "true"
|
||||
LOG_TO_FILE: bool = os.getenv("LOG_TO_FILE", "true").lower() == "true"
|
||||
|
||||
|
||||
# Sensitive Data Filtering
|
||||
ENABLE_SENSITIVE_DATA_FILTER: bool = os.getenv("ENABLE_SENSITIVE_DATA_FILTER", "true").lower() == "true"
|
||||
|
||||
@@ -142,7 +143,6 @@ class Settings:
|
||||
LOG_STREAM_BUFFER_SIZE: int = int(os.getenv("LOG_STREAM_BUFFER_SIZE", "8192")) # 8KB
|
||||
LOG_FILE_MAX_SIZE_MB: int = int(os.getenv("LOG_FILE_MAX_SIZE_MB", "10")) # 10MB
|
||||
|
||||
|
||||
# Celery configuration (internal)
|
||||
CELERY_BROKER: int = int(os.getenv("CELERY_BROKER", "1"))
|
||||
CELERY_BACKEND: int = int(os.getenv("CELERY_BACKEND", "2"))
|
||||
@@ -150,15 +150,15 @@ class Settings:
|
||||
HEALTH_CHECK_SECONDS: float = float(os.getenv("HEALTH_CHECK_SECONDS", "600"))
|
||||
MEMORY_INCREMENT_INTERVAL_HOURS: float = float(os.getenv("MEMORY_INCREMENT_INTERVAL_HOURS", "24"))
|
||||
DEFAULT_WORKSPACE_ID: Optional[str] = os.getenv("DEFAULT_WORKSPACE_ID", None)
|
||||
REFLECTION_INTERVAL_TIME:Optional[str] = int(os.getenv("REFLECTION_INTERVAL_TIME", 30))
|
||||
|
||||
REFLECTION_INTERVAL_TIME: Optional[str] = int(os.getenv("REFLECTION_INTERVAL_TIME", 30))
|
||||
|
||||
# Memory Cache Regeneration Configuration
|
||||
MEMORY_CACHE_REGENERATION_HOURS: int = int(os.getenv("MEMORY_CACHE_REGENERATION_HOURS", "24"))
|
||||
|
||||
# Memory Module Configuration (internal)
|
||||
MEMORY_OUTPUT_DIR: str = os.getenv("MEMORY_OUTPUT_DIR", "logs/memory-output")
|
||||
MEMORY_CONFIG_DIR: str = os.getenv("MEMORY_CONFIG_DIR", "app/core/memory")
|
||||
|
||||
|
||||
# Tool Management Configuration
|
||||
TOOL_CONFIG_DIR: str = os.getenv("TOOL_CONFIG_DIR", "app/core/tools")
|
||||
TOOL_EXECUTION_TIMEOUT: int = int(os.getenv("TOOL_EXECUTION_TIMEOUT", "60"))
|
||||
@@ -167,7 +167,10 @@ class Settings:
|
||||
|
||||
# official environment system version
|
||||
SYSTEM_VERSION: str = os.getenv("SYSTEM_VERSION", "v0.2.0")
|
||||
|
||||
|
||||
# workflow config
|
||||
WORKFLOW_NODE_TIMEOUT: int = os.getenv("WORKFLOW_NODE_TIMEOUT", 600)
|
||||
|
||||
def get_memory_output_path(self, filename: str = "") -> str:
|
||||
"""
|
||||
Get the full path for memory module output files.
|
||||
@@ -182,7 +185,7 @@ class Settings:
|
||||
if filename:
|
||||
return str(base_path / filename)
|
||||
return str(base_path)
|
||||
|
||||
|
||||
def ensure_memory_output_dir(self) -> None:
|
||||
"""
|
||||
Ensure the memory output directory exists.
|
||||
|
||||
@@ -74,6 +74,7 @@ class WorkflowExecutor:
|
||||
初始化的工作流状态
|
||||
"""
|
||||
user_message = input_data.get("message") or ""
|
||||
conversation_messages = input_data.get("conv_messages") or []
|
||||
|
||||
# 会话变量处理:从配置文件获取变量定义列表,转换为字典(name -> default value)
|
||||
config_variables_list = self.workflow_config.get("variables") or []
|
||||
@@ -114,7 +115,7 @@ class WorkflowExecutor:
|
||||
}
|
||||
|
||||
return {
|
||||
"messages": [('user', user_message)],
|
||||
"messages": conversation_messages,
|
||||
"variables": variables,
|
||||
"node_outputs": {},
|
||||
"runtime_vars": {}, # 运行时节点变量(简化版,供快速访问)
|
||||
|
||||
@@ -7,13 +7,13 @@
|
||||
import asyncio
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from operator import add
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.messages import AnyMessage, AIMessage
|
||||
from langchain_core.messages import AIMessage
|
||||
from langgraph.config import get_stream_writer
|
||||
from typing_extensions import TypedDict, Annotated
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.workflow.variable_pool import VariablePool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -25,7 +25,7 @@ class WorkflowState(TypedDict):
|
||||
The state object passed between nodes in a workflow, containing messages, variables, node outputs, etc.
|
||||
"""
|
||||
# List of messages (append mode)
|
||||
messages: Annotated[list[tuple[str, str]], add]
|
||||
messages: list[dict[str, str]]
|
||||
|
||||
# Set of loop node IDs, used for assigning values in loop nodes
|
||||
cycle_nodes: list
|
||||
@@ -154,7 +154,7 @@ class BaseNode(ABC):
|
||||
Returns:
|
||||
超时时间
|
||||
"""
|
||||
return 60
|
||||
return settings.WORKFLOW_NODE_TIMEOUT
|
||||
# return self.error_handling.get("timeout", 60)
|
||||
|
||||
async def run(self, state: WorkflowState) -> dict[str, Any]:
|
||||
@@ -203,6 +203,7 @@ class BaseNode(ABC):
|
||||
# 返回包装后的输出和运行时变量
|
||||
return {
|
||||
**wrapped_output,
|
||||
"messages": state["messages"],
|
||||
"variables": state["variables"],
|
||||
"runtime_vars": {
|
||||
self.node_id: runtime_var
|
||||
@@ -356,6 +357,7 @@ class BaseNode(ABC):
|
||||
# Build complete state update (including node_outputs, runtime_vars, and final streaming buffer)
|
||||
state_update = {
|
||||
**final_output,
|
||||
"messages": state["messages"],
|
||||
"variables": state["variables"],
|
||||
"runtime_vars": {
|
||||
self.node_id: runtime_var
|
||||
|
||||
@@ -6,7 +6,6 @@ End 节点实现
|
||||
|
||||
import logging
|
||||
import re
|
||||
import asyncio
|
||||
|
||||
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
|
||||
from app.core.workflow.nodes.enums import NodeType
|
||||
@@ -38,7 +37,23 @@ class EndNode(BaseNode):
|
||||
# 如果配置了输出模板,使用模板渲染;否则使用默认输出
|
||||
if output_template:
|
||||
output = self._render_template(output_template, state, strict=False)
|
||||
state['messages'].extend([
|
||||
{
|
||||
"role": "user",
|
||||
"content": self.get_variable("sys.message", state)
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": output
|
||||
}
|
||||
])
|
||||
else:
|
||||
state['messages'].extend([
|
||||
{
|
||||
"role": "user",
|
||||
"content": self.get_variable("sys.message", state)
|
||||
},
|
||||
])
|
||||
output = "工作流已完成"
|
||||
|
||||
# 统计信息(用于日志)
|
||||
@@ -166,6 +181,12 @@ class EndNode(BaseNode):
|
||||
"chunk_index": 1,
|
||||
"is_suffix": False
|
||||
})
|
||||
state['messages'].extend([
|
||||
{
|
||||
"role": "user",
|
||||
"content": self.get_variable("sys.message", state)
|
||||
}
|
||||
])
|
||||
yield {"__final__": True, "result": output}
|
||||
return
|
||||
|
||||
@@ -176,7 +197,6 @@ class EndNode(BaseNode):
|
||||
source_node_id = edge.get("source")
|
||||
# Check if the source node is an LLM node
|
||||
for node in self.workflow_config.get("nodes", []):
|
||||
print("="*50)
|
||||
logger.info(f"节点 {self.node_id} 的类型 {node.get("type")}")
|
||||
if node.get("id") == source_node_id and node.get("type") == NodeType.LLM:
|
||||
direct_upstream_llm_nodes.append(source_node_id)
|
||||
@@ -216,12 +236,24 @@ class EndNode(BaseNode):
|
||||
})
|
||||
logger.info(f"节点 {self.node_id} 已通过 writer 发送完整内容")
|
||||
|
||||
state['messages'].extend([
|
||||
{
|
||||
"role": "user",
|
||||
"content": self.get_variable("sys.message", state)
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": output
|
||||
}
|
||||
])
|
||||
|
||||
# yield completion marker
|
||||
yield {"__final__": True, "result": output}
|
||||
return
|
||||
|
||||
# Has reference to direct upstream LLM node, only output the part after that reference (suffix)
|
||||
logger.info(f"节点 {self.node_id} 检测到直接上游 LLM 节点引用,只输出后缀部分(从索引 {upstream_llm_ref_index + 1} 开始)")
|
||||
logger.info(
|
||||
f"节点 {self.node_id} 检测到直接上游 LLM 节点引用,只输出后缀部分(从索引 {upstream_llm_ref_index + 1} 开始)")
|
||||
|
||||
# Collect suffix parts
|
||||
suffix_parts = []
|
||||
@@ -258,6 +290,17 @@ class EndNode(BaseNode):
|
||||
# 构建完整输出(用于返回,包含前缀 + 动态内容 + 后缀)
|
||||
full_output = self._render_template(output_template, state, strict=False)
|
||||
|
||||
state['messages'].extend([
|
||||
{
|
||||
"role": "user",
|
||||
"content": self.get_variable("sys.message", state)
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": full_output
|
||||
}
|
||||
])
|
||||
|
||||
logger.info(f"[后缀调试] 节点 {self.node_id} 后缀部分数量: {len(suffix_parts)}")
|
||||
logger.info(f"[后缀调试] 后缀内容: '{suffix}'")
|
||||
logger.info(f"[后缀调试] 后缀长度: {len(suffix)}")
|
||||
@@ -280,7 +323,8 @@ class EndNode(BaseNode):
|
||||
})
|
||||
logger.info(f"节点 {self.node_id} 已通过 writer 发送后缀,full_content 长度: {len(full_output)}")
|
||||
else:
|
||||
logger.warning(f"[后缀调试] 节点 {self.node_id} 后缀为空,不发送!upstream_llm_ref_index={upstream_llm_ref_index}, parts数量={len(parts)}")
|
||||
logger.warning(f"[后缀调试] 节点 {self.node_id} 后缀为空,不发送!"
|
||||
f"upstream_llm_ref_index={upstream_llm_ref_index}, parts数量={len(parts)}")
|
||||
|
||||
# 统计信息
|
||||
node_outputs = state.get("node_outputs", {})
|
||||
|
||||
@@ -11,12 +11,12 @@ class MessageConfig(BaseModel):
|
||||
"""消息配置"""
|
||||
|
||||
role: str = Field(
|
||||
...,
|
||||
default='user',
|
||||
description="消息角色:system, user, assistant"
|
||||
)
|
||||
|
||||
content: str = Field(
|
||||
...,
|
||||
default="",
|
||||
description="消息内容,支持模板变量,如:{{ sys.message }}"
|
||||
)
|
||||
|
||||
@@ -30,6 +30,23 @@ class MessageConfig(BaseModel):
|
||||
return v.lower()
|
||||
|
||||
|
||||
class MemoryWindowSetting(BaseModel):
|
||||
enable: bool = Field(
|
||||
default=False,
|
||||
description="启用记忆"
|
||||
)
|
||||
|
||||
enable_window: bool = Field(
|
||||
default=False,
|
||||
description="启用记忆窗口"
|
||||
)
|
||||
|
||||
window_size: int = Field(
|
||||
default=20,
|
||||
description="记忆窗口大小"
|
||||
)
|
||||
|
||||
|
||||
class LLMNodeConfig(BaseNodeConfig):
|
||||
"""LLM 节点配置
|
||||
|
||||
@@ -48,6 +65,11 @@ class LLMNodeConfig(BaseNodeConfig):
|
||||
description="上下文"
|
||||
)
|
||||
|
||||
memory: MemoryWindowSetting = Field(
|
||||
...,
|
||||
description="对话上下文窗口"
|
||||
)
|
||||
|
||||
# 简单模式
|
||||
prompt: str | None = Field(
|
||||
default=None,
|
||||
|
||||
@@ -85,28 +85,31 @@ class LLMNode(BaseNode):
|
||||
"""
|
||||
|
||||
# 1. 处理消息格式(优先使用 messages)
|
||||
messages_config = self.config.get("messages")
|
||||
messages_config = self.typed_config.messages
|
||||
|
||||
if messages_config:
|
||||
# 使用 LangChain 消息格式
|
||||
messages = []
|
||||
for msg_config in messages_config:
|
||||
role = msg_config.get("role", "user").lower()
|
||||
content_template = msg_config.get("content", "")
|
||||
role = msg_config.role.lower()
|
||||
content_template = msg_config.content
|
||||
content_template = self._render_context(content_template, state)
|
||||
content = self._render_template(content_template, state)
|
||||
|
||||
# 根据角色创建对应的消息对象
|
||||
if role == "system":
|
||||
messages.append(SystemMessage(content=content))
|
||||
messages.append({"role": "system", "content": content})
|
||||
elif role in ["user", "human"]:
|
||||
messages.append(HumanMessage(content=content))
|
||||
messages.append({"role": "user", "content": content})
|
||||
elif role in ["ai", "assistant"]:
|
||||
messages.append(AIMessage(content=content))
|
||||
messages.append({"role": "user", "content": content})
|
||||
else:
|
||||
logger.warning(f"未知的消息角色: {role},默认使用 user")
|
||||
messages.append(HumanMessage(content=content))
|
||||
messages.append({"role": "user", "content": content})
|
||||
|
||||
if self.typed_config.memory.enable:
|
||||
# if self.typed_config.memory.enable_window:
|
||||
messages = messages[:-1] + state["messages"][-self.typed_config.memory.window_size:] + messages[-1:]
|
||||
prompt_or_messages = messages
|
||||
else:
|
||||
# 使用简单的 prompt 格式(向后兼容)
|
||||
@@ -189,7 +192,7 @@ class LLMNode(BaseNode):
|
||||
return {
|
||||
"prompt": prompt_or_messages if isinstance(prompt_or_messages, str) else None,
|
||||
"messages": [
|
||||
{"role": msg.__class__.__name__.replace("Message", "").lower(), "content": msg.content}
|
||||
{"role": msg.get("role"), "content": msg.get("content", "")}
|
||||
for msg in prompt_or_messages
|
||||
] if isinstance(prompt_or_messages, list) else None,
|
||||
"config": {
|
||||
|
||||
Reference in New Issue
Block a user