feat(workflow): add session context memory support to LLM nodes

This commit is contained in:
Eternity
2026-01-14 16:35:46 +08:00
parent b5a366ef5e
commit 567624c323
11 changed files with 249 additions and 228 deletions

View File

@@ -39,11 +39,11 @@ router = APIRouter(prefix="/apps", tags=["workflow"])
@router.post("/{app_id}/workflow") @router.post("/{app_id}/workflow")
@cur_workspace_access_guard() @cur_workspace_access_guard()
async def create_workflow_config( async def create_workflow_config(
app_id: Annotated[uuid.UUID, Path(description="应用 ID")], app_id: Annotated[uuid.UUID, Path(description="应用 ID")],
config: WorkflowConfigCreate, config: WorkflowConfigCreate,
db: Annotated[Session, Depends(get_db)], db: Annotated[Session, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_user)], current_user: Annotated[User, Depends(get_current_user)],
service: Annotated[WorkflowService, Depends(get_workflow_service)] service: Annotated[WorkflowService, Depends(get_workflow_service)]
): ):
"""创建工作流配置 """创建工作流配置
@@ -96,6 +96,7 @@ async def create_workflow_config(
msg=f"创建工作流配置失败: {str(e)}" msg=f"创建工作流配置失败: {str(e)}"
) )
# #
# @router.get("/{app_id}/workflow") # @router.get("/{app_id}/workflow")
# async def get_workflow_config( # async def get_workflow_config(
@@ -199,10 +200,10 @@ async def create_workflow_config(
@router.delete("/{app_id}/workflow") @router.delete("/{app_id}/workflow")
async def delete_workflow_config( async def delete_workflow_config(
app_id: Annotated[uuid.UUID, Path(description="应用 ID")], app_id: Annotated[uuid.UUID, Path(description="应用 ID")],
db: Annotated[Session, Depends(get_db)], db: Annotated[Session, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_user)], current_user: Annotated[User, Depends(get_current_user)],
service: Annotated[WorkflowService, Depends(get_workflow_service)] service: Annotated[WorkflowService, Depends(get_workflow_service)]
): ):
"""删除工作流配置 """删除工作流配置
@@ -243,11 +244,11 @@ async def delete_workflow_config(
@router.post("/{app_id}/workflow/validate") @router.post("/{app_id}/workflow/validate")
async def validate_workflow_config( async def validate_workflow_config(
app_id: Annotated[uuid.UUID, Path(description="应用 ID")], app_id: Annotated[uuid.UUID, Path(description="应用 ID")],
db: Annotated[Session, Depends(get_db)], db: Annotated[Session, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_user)], current_user: Annotated[User, Depends(get_current_user)],
service: Annotated[WorkflowService, Depends(get_workflow_service)], service: Annotated[WorkflowService, Depends(get_workflow_service)],
for_publish: Annotated[bool, Query(description="是否为发布验证")] = False for_publish: Annotated[bool, Query(description="是否为发布验证")] = False
): ):
"""验证工作流配置 """验证工作流配置
@@ -312,12 +313,12 @@ async def validate_workflow_config(
@router.get("/{app_id}/workflow/executions") @router.get("/{app_id}/workflow/executions")
async def get_workflow_executions( async def get_workflow_executions(
app_id: Annotated[uuid.UUID, Path(description="应用 ID")], app_id: Annotated[uuid.UUID, Path(description="应用 ID")],
db: Annotated[Session, Depends(get_db)], db: Annotated[Session, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_user)], current_user: Annotated[User, Depends(get_current_user)],
service: Annotated[WorkflowService, Depends(get_workflow_service)], service: Annotated[WorkflowService, Depends(get_workflow_service)],
limit: Annotated[int, Query(ge=1, le=100)] = 50, limit: Annotated[int, Query(ge=1, le=100)] = 50,
offset: Annotated[int, Query(ge=0)] = 0 offset: Annotated[int, Query(ge=0)] = 0
): ):
"""获取工作流执行记录列表 """获取工作流执行记录列表
@@ -365,10 +366,10 @@ async def get_workflow_executions(
@router.get("/workflow/executions/{execution_id}") @router.get("/workflow/executions/{execution_id}")
async def get_workflow_execution( async def get_workflow_execution(
execution_id: Annotated[str, Path(description="执行 ID")], execution_id: Annotated[str, Path(description="执行 ID")],
db: Annotated[Session, Depends(get_db)], db: Annotated[Session, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_user)], current_user: Annotated[User, Depends(get_current_user)],
service: Annotated[WorkflowService, Depends(get_workflow_service)] service: Annotated[WorkflowService, Depends(get_workflow_service)]
): ):
"""获取工作流执行详情 """获取工作流执行详情
@@ -417,16 +418,14 @@ async def get_workflow_execution(
) )
# ==================== 工作流执行 ==================== # ==================== 工作流执行 ====================
@router.post("/{app_id}/workflow/run") @router.post("/{app_id}/workflow/run")
async def run_workflow( async def run_workflow(
app_id: Annotated[uuid.UUID, Path(description="应用 ID")], app_id: Annotated[uuid.UUID, Path(description="应用 ID")],
request: WorkflowExecutionRequest, request: WorkflowExecutionRequest,
db: Annotated[Session, Depends(get_db)], db: Annotated[Session, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_user)], current_user: Annotated[User, Depends(get_current_user)],
service: Annotated[WorkflowService, Depends(get_workflow_service)] service: Annotated[WorkflowService, Depends(get_workflow_service)]
): ):
"""执行工作流 """执行工作流
@@ -487,11 +486,11 @@ async def run_workflow(
""" """
try: try:
async for event in await service.run_workflow( async for event in await service.run_workflow(
app_id=app_id, app_id=app_id,
input_data=input_data, input_data=input_data,
triggered_by=current_user.id, triggered_by=current_user.id,
conversation_id=uuid.UUID(request.conversation_id) if request.conversation_id else None, conversation_id=uuid.UUID(request.conversation_id) if request.conversation_id else None,
stream=True stream=True
): ):
# 提取事件类型和数据 # 提取事件类型和数据
event_type = event.get("event", "message") event_type = event.get("event", "message")
@@ -554,10 +553,10 @@ async def run_workflow(
@router.post("/workflow/executions/{execution_id}/cancel") @router.post("/workflow/executions/{execution_id}/cancel")
async def cancel_workflow_execution( async def cancel_workflow_execution(
execution_id: Annotated[str, Path(description="执行 ID")], execution_id: Annotated[str, Path(description="执行 ID")],
db: Annotated[Session, Depends(get_db)], db: Annotated[Session, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_user)], current_user: Annotated[User, Depends(get_current_user)],
service: Annotated[WorkflowService, Depends(get_workflow_service)] service: Annotated[WorkflowService, Depends(get_workflow_service)]
): ):
"""取消工作流执行 """取消工作流执行
@@ -602,7 +601,7 @@ async def cancel_workflow_execution(
except BusinessException as e: except BusinessException as e:
logger.warning(f"取消工作流执行失败: {e.message}") logger.warning(f"取消工作流执行失败: {e.message}")
return fail(code=e.error_code, msg=e.message) return fail(code=e.code, msg=e.message)
except Exception as e: except Exception as e:
logger.error(f"取消工作流执行异常: {e}", exc_info=True) logger.error(f"取消工作流执行异常: {e}", exc_info=True)
return fail( return fail(

View File

@@ -7,6 +7,7 @@ from dotenv import load_dotenv
load_dotenv() load_dotenv()
class Settings: class Settings:
ENABLE_SINGLE_WORKSPACE: bool = os.getenv("ENABLE_SINGLE_WORKSPACE", "true").lower() == "true" ENABLE_SINGLE_WORKSPACE: bool = os.getenv("ENABLE_SINGLE_WORKSPACE", "true").lower() == "true"
# API Keys Configuration # API Keys Configuration
@@ -142,7 +143,6 @@ class Settings:
LOG_STREAM_BUFFER_SIZE: int = int(os.getenv("LOG_STREAM_BUFFER_SIZE", "8192")) # 8KB LOG_STREAM_BUFFER_SIZE: int = int(os.getenv("LOG_STREAM_BUFFER_SIZE", "8192")) # 8KB
LOG_FILE_MAX_SIZE_MB: int = int(os.getenv("LOG_FILE_MAX_SIZE_MB", "10")) # 10MB LOG_FILE_MAX_SIZE_MB: int = int(os.getenv("LOG_FILE_MAX_SIZE_MB", "10")) # 10MB
# Celery configuration (internal) # Celery configuration (internal)
CELERY_BROKER: int = int(os.getenv("CELERY_BROKER", "1")) CELERY_BROKER: int = int(os.getenv("CELERY_BROKER", "1"))
CELERY_BACKEND: int = int(os.getenv("CELERY_BACKEND", "2")) CELERY_BACKEND: int = int(os.getenv("CELERY_BACKEND", "2"))
@@ -150,7 +150,7 @@ class Settings:
HEALTH_CHECK_SECONDS: float = float(os.getenv("HEALTH_CHECK_SECONDS", "600")) HEALTH_CHECK_SECONDS: float = float(os.getenv("HEALTH_CHECK_SECONDS", "600"))
MEMORY_INCREMENT_INTERVAL_HOURS: float = float(os.getenv("MEMORY_INCREMENT_INTERVAL_HOURS", "24")) MEMORY_INCREMENT_INTERVAL_HOURS: float = float(os.getenv("MEMORY_INCREMENT_INTERVAL_HOURS", "24"))
DEFAULT_WORKSPACE_ID: Optional[str] = os.getenv("DEFAULT_WORKSPACE_ID", None) DEFAULT_WORKSPACE_ID: Optional[str] = os.getenv("DEFAULT_WORKSPACE_ID", None)
REFLECTION_INTERVAL_TIME:Optional[str] = int(os.getenv("REFLECTION_INTERVAL_TIME", 30)) REFLECTION_INTERVAL_TIME: Optional[str] = int(os.getenv("REFLECTION_INTERVAL_TIME", 30))
# Memory Cache Regeneration Configuration # Memory Cache Regeneration Configuration
MEMORY_CACHE_REGENERATION_HOURS: int = int(os.getenv("MEMORY_CACHE_REGENERATION_HOURS", "24")) MEMORY_CACHE_REGENERATION_HOURS: int = int(os.getenv("MEMORY_CACHE_REGENERATION_HOURS", "24"))
@@ -168,6 +168,9 @@ class Settings:
# official environment system version # official environment system version
SYSTEM_VERSION: str = os.getenv("SYSTEM_VERSION", "v0.2.0") SYSTEM_VERSION: str = os.getenv("SYSTEM_VERSION", "v0.2.0")
# workflow config
WORKFLOW_NODE_TIMEOUT: int = os.getenv("WORKFLOW_NODE_TIMEOUT", 600)
def get_memory_output_path(self, filename: str = "") -> str: def get_memory_output_path(self, filename: str = "") -> str:
""" """
Get the full path for memory module output files. Get the full path for memory module output files.

View File

@@ -74,6 +74,7 @@ class WorkflowExecutor:
初始化的工作流状态 初始化的工作流状态
""" """
user_message = input_data.get("message") or "" user_message = input_data.get("message") or ""
conversation_messages = input_data.get("conv_messages") or []
# 会话变量处理从配置文件获取变量定义列表转换为字典name -> default value # 会话变量处理从配置文件获取变量定义列表转换为字典name -> default value
config_variables_list = self.workflow_config.get("variables") or [] config_variables_list = self.workflow_config.get("variables") or []
@@ -114,7 +115,7 @@ class WorkflowExecutor:
} }
return { return {
"messages": [('user', user_message)], "messages": conversation_messages,
"variables": variables, "variables": variables,
"node_outputs": {}, "node_outputs": {},
"runtime_vars": {}, # 运行时节点变量(简化版,供快速访问) "runtime_vars": {}, # 运行时节点变量(简化版,供快速访问)

View File

@@ -7,13 +7,13 @@
import asyncio import asyncio
import logging import logging
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from operator import add
from typing import Any from typing import Any
from langchain_core.messages import AnyMessage, AIMessage from langchain_core.messages import AIMessage
from langgraph.config import get_stream_writer from langgraph.config import get_stream_writer
from typing_extensions import TypedDict, Annotated from typing_extensions import TypedDict, Annotated
from app.core.config import settings
from app.core.workflow.variable_pool import VariablePool from app.core.workflow.variable_pool import VariablePool
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -25,7 +25,7 @@ class WorkflowState(TypedDict):
The state object passed between nodes in a workflow, containing messages, variables, node outputs, etc. The state object passed between nodes in a workflow, containing messages, variables, node outputs, etc.
""" """
# List of messages (append mode) # List of messages (append mode)
messages: Annotated[list[tuple[str, str]], add] messages: list[dict[str, str]]
# Set of loop node IDs, used for assigning values in loop nodes # Set of loop node IDs, used for assigning values in loop nodes
cycle_nodes: list cycle_nodes: list
@@ -154,7 +154,7 @@ class BaseNode(ABC):
Returns: Returns:
超时时间 超时时间
""" """
return 60 return settings.WORKFLOW_NODE_TIMEOUT
# return self.error_handling.get("timeout", 60) # return self.error_handling.get("timeout", 60)
async def run(self, state: WorkflowState) -> dict[str, Any]: async def run(self, state: WorkflowState) -> dict[str, Any]:
@@ -203,6 +203,7 @@ class BaseNode(ABC):
# 返回包装后的输出和运行时变量 # 返回包装后的输出和运行时变量
return { return {
**wrapped_output, **wrapped_output,
"messages": state["messages"],
"variables": state["variables"], "variables": state["variables"],
"runtime_vars": { "runtime_vars": {
self.node_id: runtime_var self.node_id: runtime_var
@@ -356,6 +357,7 @@ class BaseNode(ABC):
# Build complete state update (including node_outputs, runtime_vars, and final streaming buffer) # Build complete state update (including node_outputs, runtime_vars, and final streaming buffer)
state_update = { state_update = {
**final_output, **final_output,
"messages": state["messages"],
"variables": state["variables"], "variables": state["variables"],
"runtime_vars": { "runtime_vars": {
self.node_id: runtime_var self.node_id: runtime_var

View File

@@ -6,7 +6,6 @@ End 节点实现
import logging import logging
import re import re
import asyncio
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
from app.core.workflow.nodes.enums import NodeType from app.core.workflow.nodes.enums import NodeType
@@ -38,7 +37,23 @@ class EndNode(BaseNode):
# 如果配置了输出模板,使用模板渲染;否则使用默认输出 # 如果配置了输出模板,使用模板渲染;否则使用默认输出
if output_template: if output_template:
output = self._render_template(output_template, state, strict=False) output = self._render_template(output_template, state, strict=False)
state['messages'].extend([
{
"role": "user",
"content": self.get_variable("sys.message", state)
},
{
"role": "assistant",
"content": output
}
])
else: else:
state['messages'].extend([
{
"role": "user",
"content": self.get_variable("sys.message", state)
},
])
output = "工作流已完成" output = "工作流已完成"
# 统计信息(用于日志) # 统计信息(用于日志)
@@ -166,6 +181,12 @@ class EndNode(BaseNode):
"chunk_index": 1, "chunk_index": 1,
"is_suffix": False "is_suffix": False
}) })
state['messages'].extend([
{
"role": "user",
"content": self.get_variable("sys.message", state)
}
])
yield {"__final__": True, "result": output} yield {"__final__": True, "result": output}
return return
@@ -176,7 +197,6 @@ class EndNode(BaseNode):
source_node_id = edge.get("source") source_node_id = edge.get("source")
# Check if the source node is an LLM node # Check if the source node is an LLM node
for node in self.workflow_config.get("nodes", []): for node in self.workflow_config.get("nodes", []):
print("="*50)
logger.info(f"节点 {self.node_id} 的类型 {node.get("type")}") logger.info(f"节点 {self.node_id} 的类型 {node.get("type")}")
if node.get("id") == source_node_id and node.get("type") == NodeType.LLM: if node.get("id") == source_node_id and node.get("type") == NodeType.LLM:
direct_upstream_llm_nodes.append(source_node_id) direct_upstream_llm_nodes.append(source_node_id)
@@ -216,12 +236,24 @@ class EndNode(BaseNode):
}) })
logger.info(f"节点 {self.node_id} 已通过 writer 发送完整内容") logger.info(f"节点 {self.node_id} 已通过 writer 发送完整内容")
state['messages'].extend([
{
"role": "user",
"content": self.get_variable("sys.message", state)
},
{
"role": "assistant",
"content": output
}
])
# yield completion marker # yield completion marker
yield {"__final__": True, "result": output} yield {"__final__": True, "result": output}
return return
# Has reference to direct upstream LLM node, only output the part after that reference (suffix) # Has reference to direct upstream LLM node, only output the part after that reference (suffix)
logger.info(f"节点 {self.node_id} 检测到直接上游 LLM 节点引用,只输出后缀部分(从索引 {upstream_llm_ref_index + 1} 开始)") logger.info(
f"节点 {self.node_id} 检测到直接上游 LLM 节点引用,只输出后缀部分(从索引 {upstream_llm_ref_index + 1} 开始)")
# Collect suffix parts # Collect suffix parts
suffix_parts = [] suffix_parts = []
@@ -258,6 +290,17 @@ class EndNode(BaseNode):
# 构建完整输出(用于返回,包含前缀 + 动态内容 + 后缀) # 构建完整输出(用于返回,包含前缀 + 动态内容 + 后缀)
full_output = self._render_template(output_template, state, strict=False) full_output = self._render_template(output_template, state, strict=False)
state['messages'].extend([
{
"role": "user",
"content": self.get_variable("sys.message", state)
},
{
"role": "assistant",
"content": full_output
}
])
logger.info(f"[后缀调试] 节点 {self.node_id} 后缀部分数量: {len(suffix_parts)}") logger.info(f"[后缀调试] 节点 {self.node_id} 后缀部分数量: {len(suffix_parts)}")
logger.info(f"[后缀调试] 后缀内容: '{suffix}'") logger.info(f"[后缀调试] 后缀内容: '{suffix}'")
logger.info(f"[后缀调试] 后缀长度: {len(suffix)}") logger.info(f"[后缀调试] 后缀长度: {len(suffix)}")
@@ -280,7 +323,8 @@ class EndNode(BaseNode):
}) })
logger.info(f"节点 {self.node_id} 已通过 writer 发送后缀full_content 长度: {len(full_output)}") logger.info(f"节点 {self.node_id} 已通过 writer 发送后缀full_content 长度: {len(full_output)}")
else: else:
logger.warning(f"[后缀调试] 节点 {self.node_id} 后缀为空,不发送!upstream_llm_ref_index={upstream_llm_ref_index}, parts数量={len(parts)}") logger.warning(f"[后缀调试] 节点 {self.node_id} 后缀为空,不发送!"
f"upstream_llm_ref_index={upstream_llm_ref_index}, parts数量={len(parts)}")
# 统计信息 # 统计信息
node_outputs = state.get("node_outputs", {}) node_outputs = state.get("node_outputs", {})

View File

@@ -11,12 +11,12 @@ class MessageConfig(BaseModel):
"""消息配置""" """消息配置"""
role: str = Field( role: str = Field(
..., default='user',
description="消息角色system, user, assistant" description="消息角色system, user, assistant"
) )
content: str = Field( content: str = Field(
..., default="",
description="消息内容,支持模板变量,如:{{ sys.message }}" description="消息内容,支持模板变量,如:{{ sys.message }}"
) )
@@ -30,6 +30,23 @@ class MessageConfig(BaseModel):
return v.lower() return v.lower()
class MemoryWindowSetting(BaseModel):
enable: bool = Field(
default=False,
description="启用记忆"
)
enable_window: bool = Field(
default=False,
description="启用记忆窗口"
)
window_size: int = Field(
default=20,
description="记忆窗口大小"
)
class LLMNodeConfig(BaseNodeConfig): class LLMNodeConfig(BaseNodeConfig):
"""LLM 节点配置 """LLM 节点配置
@@ -48,6 +65,11 @@ class LLMNodeConfig(BaseNodeConfig):
description="上下文" description="上下文"
) )
memory: MemoryWindowSetting = Field(
...,
description="对话上下文窗口"
)
# 简单模式 # 简单模式
prompt: str | None = Field( prompt: str | None = Field(
default=None, default=None,

View File

@@ -85,28 +85,31 @@ class LLMNode(BaseNode):
""" """
# 1. 处理消息格式(优先使用 messages # 1. 处理消息格式(优先使用 messages
messages_config = self.config.get("messages") messages_config = self.typed_config.messages
if messages_config: if messages_config:
# 使用 LangChain 消息格式 # 使用 LangChain 消息格式
messages = [] messages = []
for msg_config in messages_config: for msg_config in messages_config:
role = msg_config.get("role", "user").lower() role = msg_config.role.lower()
content_template = msg_config.get("content", "") content_template = msg_config.content
content_template = self._render_context(content_template, state) content_template = self._render_context(content_template, state)
content = self._render_template(content_template, state) content = self._render_template(content_template, state)
# 根据角色创建对应的消息对象 # 根据角色创建对应的消息对象
if role == "system": if role == "system":
messages.append(SystemMessage(content=content)) messages.append({"role": "system", "content": content})
elif role in ["user", "human"]: elif role in ["user", "human"]:
messages.append(HumanMessage(content=content)) messages.append({"role": "user", "content": content})
elif role in ["ai", "assistant"]: elif role in ["ai", "assistant"]:
messages.append(AIMessage(content=content)) messages.append({"role": "user", "content": content})
else: else:
logger.warning(f"未知的消息角色: {role},默认使用 user") logger.warning(f"未知的消息角色: {role},默认使用 user")
messages.append(HumanMessage(content=content)) messages.append({"role": "user", "content": content})
if self.typed_config.memory.enable:
# if self.typed_config.memory.enable_window:
messages = messages[:-1] + state["messages"][-self.typed_config.memory.window_size:] + messages[-1:]
prompt_or_messages = messages prompt_or_messages = messages
else: else:
# 使用简单的 prompt 格式(向后兼容) # 使用简单的 prompt 格式(向后兼容)
@@ -189,7 +192,7 @@ class LLMNode(BaseNode):
return { return {
"prompt": prompt_or_messages if isinstance(prompt_or_messages, str) else None, "prompt": prompt_or_messages if isinstance(prompt_or_messages, str) else None,
"messages": [ "messages": [
{"role": msg.__class__.__name__.replace("Message", "").lower(), "content": msg.content} {"role": msg.get("role"), "content": msg.get("content", "")}
for msg in prompt_or_messages for msg in prompt_or_messages
] if isinstance(prompt_or_messages, list) else None, ] if isinstance(prompt_or_messages, list) else None,
"config": { "config": {

View File

@@ -41,6 +41,7 @@ class ToolConfig(BaseModel):
tool_id: Optional[str] = Field(default=None, description="工具ID") tool_id: Optional[str] = Field(default=None, description="工具ID")
operation: Optional[str] = Field(default=None, description="工具特定配置") operation: Optional[str] = Field(default=None, description="工具特定配置")
class ToolOldConfig(BaseModel): class ToolOldConfig(BaseModel):
"""工具配置""" """工具配置"""
enabled: bool = Field(default=False, description="是否启用该工具") enabled: bool = Field(default=False, description="是否启用该工具")
@@ -348,6 +349,7 @@ class AppChatRequest(BaseModel):
variables: Optional[Dict[str, Any]] = Field(default=None, description="自定义变量参数值") variables: Optional[Dict[str, Any]] = Field(default=None, description="自定义变量参数值")
stream: bool = Field(default=False, description="是否流式返回") stream: bool = Field(default=False, description="是否流式返回")
class DraftRunRequest(BaseModel): class DraftRunRequest(BaseModel):
"""试运行请求""" """试运行请求"""
message: str = Field(..., description="用户消息") message: str = Field(..., description="用户消息")

View File

@@ -14,6 +14,7 @@ from app.core.exceptions import BusinessException
from app.core.logging_config import get_business_logger from app.core.logging_config import get_business_logger
from app.db import get_db, get_db_context from app.db import get_db, get_db_context
from app.models import MultiAgentConfig, AgentConfig, WorkflowConfig from app.models import MultiAgentConfig, AgentConfig, WorkflowConfig
from app.schemas import DraftRunRequest
from app.services.tool_service import ToolService from app.services.tool_service import ToolService
from app.repositories.tool_repository import ToolRepository from app.repositories.tool_repository import ToolRepository
from app.db import get_db from app.db import get_db
@@ -59,7 +60,7 @@ class AppChatService:
# 获取模型配置ID # 获取模型配置ID
model_config_id = config.default_model_config_id model_config_id = config.default_model_config_id
api_key_obj = ModelApiKeyService.get_a_api_key(self.db ,model_config_id) api_key_obj = ModelApiKeyService.get_a_api_key(self.db, model_config_id)
# 处理系统提示词(支持变量替换) # 处理系统提示词(支持变量替换)
system_prompt = config.system_prompt system_prompt = config.system_prompt
if variables: if variables:
@@ -210,7 +211,7 @@ class AppChatService:
# 获取模型配置ID # 获取模型配置ID
model_config_id = config.default_model_config_id model_config_id = config.default_model_config_id
api_key_obj = ModelApiKeyService.get_a_api_key(self.db ,model_config_id) api_key_obj = ModelApiKeyService.get_a_api_key(self.db, model_config_id)
# 处理系统提示词(支持变量替换) # 处理系统提示词(支持变量替换)
system_prompt = config.system_prompt system_prompt = config.system_prompt
if variables: if variables:
@@ -511,7 +512,6 @@ class AppChatService:
} }
) )
except (GeneratorExit, asyncio.CancelledError): except (GeneratorExit, asyncio.CancelledError):
# 生成器被关闭或任务被取消,正常退出 # 生成器被关闭或任务被取消,正常退出
logger.debug("多 Agent 流式聊天被中断") logger.debug("多 Agent 流式聊天被中断")
@@ -537,83 +537,19 @@ class AppChatService:
) -> Dict[str, Any]: ) -> Dict[str, Any]:
"""聊天(非流式)""" """聊天(非流式)"""
workflow_service = WorkflowService(self.db) workflow_service = WorkflowService(self.db)
payload = DraftRunRequest(
input_data = {"message":message, "variables": variables, message=message,
"conversation_id": str(conversation_id)} variables=variables,
inconfig = workflow_service.get_workflow_config(app_id) conversation_id=str(conversation_id),
stream=True,
# 2. 创建执行记录 user_id=user_id
execution = workflow_service.create_execution( )
workflow_config_id=inconfig.id, return await workflow_service.run(
app_id=app_id, app_id=app_id,
trigger_type="manual", payload=payload,
triggered_by=None, config=config,
conversation_id=conversation_id, workspace_id=workspace_id,
input_data=input_data
) )
# 3. 构建工作流配置字典
workflow_config_dict = {
"nodes": config.nodes,
"edges": config.edges,
"variables": config.variables,
"execution_config": config.execution_config
}
# 4. 获取工作空间 ID从 app 获取)
# 5. 执行工作流
from app.core.workflow.executor import execute_workflow
try:
# 更新状态为运行中
workflow_service.update_execution_status(execution.execution_id, "running")
result = await execute_workflow(
workflow_config=workflow_config_dict,
input_data=input_data,
execution_id=execution.execution_id,
workspace_id=str(workspace_id),
user_id=user_id
)
# 更新执行结果
if result.get("status") == "completed":
workflow_service.update_execution_status(
execution.execution_id,
"completed",
output_data=result.get("node_outputs", {})
)
else:
workflow_service.update_execution_status(
execution.execution_id,
"failed",
error_message=result.get("error")
)
# 返回增强的响应结构
return {
"execution_id": execution.execution_id,
"status": result.get("status"),
"output": result.get("output"), # 最终输出(字符串)
"output_data": result.get("node_outputs", {}), # 所有节点输出(详细数据)
"conversation_id": result.get("conversation_id"), # 所有节点输出详细数据payload., # 会话 ID
"error_message": result.get("error"),
"elapsed_time": result.get("elapsed_time"),
"token_usage": result.get("token_usage")
}
except Exception as e:
logger.error(f"工作流执行失败: execution_id={execution.execution_id}, error={e}", exc_info=True)
workflow_service.update_execution_status(
execution.execution_id,
"failed",
error_message=str(e)
)
raise BusinessException(
code=BizCode.INTERNAL_ERROR,
message=f"工作流执行失败: {str(e)}"
)
async def workflow_chat_stream( async def workflow_chat_stream(
self, self,
@@ -632,62 +568,21 @@ class AppChatService:
) -> AsyncGenerator[str, None]: ) -> AsyncGenerator[str, None]:
"""聊天(流式)""" """聊天(流式)"""
workflow_service = WorkflowService(self.db) workflow_service = WorkflowService(self.db)
input_data = {"message": message, "variables": variables, payload = DraftRunRequest(
"conversation_id": str(conversation_id)} message=message,
inconfig = workflow_service.get_workflow_config(app_id) variables=variables,
# 2. 创建执行记录 conversation_id=str(conversation_id),
execution = workflow_service.create_execution( stream=True,
workflow_config_id=inconfig.id, user_id=user_id
app_id=app_id,
trigger_type="manual",
triggered_by=None,
conversation_id=conversation_id,
input_data=input_data
) )
async for event in workflow_service.run_stream(
app_id=app_id,
payload=payload,
config=config,
workspace_id=workspace_id,
):
yield event
# 3. 构建工作流配置字典
workflow_config_dict = {
"nodes": config.nodes,
"edges": config.edges,
"variables": config.variables,
"execution_config": config.execution_config
}
# 4. 获取工作空间 ID从 app 获取)
# 5. 流式执行工作流
try:
# 更新状态为运行中
workflow_service.update_execution_status(execution.execution_id, "running")
# 调用流式执行executor 会发送 workflow_start 和 workflow_end 事件)
async for event in workflow_service._run_workflow_stream(
workflow_config=workflow_config_dict,
input_data=input_data,
execution_id=execution.execution_id,
workspace_id=str(workspace_id),
user_id=user_id
):
# 直接转发 executor 的事件(已经是正确的格式)
yield event
except Exception as e:
logger.error(f"工作流流式执行失败: execution_id={execution.execution_id}, error={e}", exc_info=True)
workflow_service.update_execution_status(
execution.execution_id,
"failed",
error_message=str(e)
)
# 发送错误事件
yield {
"event": "error",
"data": {
"execution_id": execution.execution_id,
"error": str(e)
}
}
# ==================== 依赖注入函数 ==================== # ==================== 依赖注入函数 ====================

View File

@@ -2,12 +2,11 @@
工作流服务层 工作流服务层
""" """
import datetime import datetime
import json
import logging import logging
import uuid import uuid
import datetime
from typing import Any, Annotated, AsyncGenerator from typing import Any, Annotated, AsyncGenerator
from deprecated import deprecated
from fastapi import Depends from fastapi import Depends
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
@@ -16,15 +15,16 @@ from app.core.exceptions import BusinessException
from app.core.workflow.validator import validate_workflow_config from app.core.workflow.validator import validate_workflow_config
from app.db import get_db, get_db_context from app.db import get_db, get_db_context
from app.models.workflow_model import WorkflowConfig, WorkflowExecution from app.models.workflow_model import WorkflowConfig, WorkflowExecution
from app.repositories.conversation_repository import MessageRepository
from app.models.conversation_model import Message
from app.repositories.end_user_repository import EndUserRepository from app.repositories.end_user_repository import EndUserRepository
from app.services.multi_agent_service import convert_uuids_to_str
from app.repositories.workflow_repository import ( from app.repositories.workflow_repository import (
WorkflowConfigRepository, WorkflowConfigRepository,
WorkflowExecutionRepository, WorkflowExecutionRepository,
WorkflowNodeExecutionRepository WorkflowNodeExecutionRepository
) )
from app.schemas import DraftRunRequest from app.schemas import DraftRunRequest
from app.utils.sse_utils import format_sse_message from app.services.multi_agent_service import convert_uuids_to_str
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -37,6 +37,7 @@ class WorkflowService:
self.config_repo = WorkflowConfigRepository(db) self.config_repo = WorkflowConfigRepository(db)
self.execution_repo = WorkflowExecutionRepository(db) self.execution_repo = WorkflowExecutionRepository(db)
self.node_execution_repo = WorkflowNodeExecutionRepository(db) self.node_execution_repo = WorkflowNodeExecutionRepository(db)
self.message_repo = MessageRepository(db)
# ==================== 配置管理 ==================== # ==================== 配置管理 ====================
@@ -418,14 +419,13 @@ class WorkflowService:
"""运行工作流 """运行工作流
Args: Args:
workspace_id:
config:
payload:
app_id: 应用 ID app_id: 应用 ID
input_data: 输入数据(包含 message 和 variables
triggered_by: 触发用户 ID
conversation_id: 会话 ID可选
stream: 是否流式返回
Returns: Returns:
执行结果(非流式)或生成器(流式) 执行结果(非流式)
Raises: Raises:
BusinessException: 配置不存在或执行失败时抛出 BusinessException: 配置不存在或执行失败时抛出
@@ -438,7 +438,8 @@ class WorkflowService:
code=BizCode.CONFIG_MISSING, code=BizCode.CONFIG_MISSING,
message=f"工作流配置不存在: app_id={app_id}" message=f"工作流配置不存在: app_id={app_id}"
) )
input_data = {"message": payload.message, "variables": payload.variables, "conversation_id": payload.conversation_id} input_data = {"message": payload.message, "variables": payload.variables,
"conversation_id": payload.conversation_id}
# 转换 user_id 为 UUID # 转换 user_id 为 UUID
triggered_by_uuid = None triggered_by_uuid = None
@@ -461,7 +462,7 @@ class WorkflowService:
workflow_config_id=config.id, workflow_config_id=config.id,
app_id=app_id, app_id=app_id,
trigger_type="manual", trigger_type="manual",
triggered_by=triggered_by_uuid, triggered_by=None,
conversation_id=conversation_id_uuid, conversation_id=conversation_id_uuid,
input_data=input_data input_data=input_data
) )
@@ -500,8 +501,11 @@ class WorkflowService:
variables = last_state.get("variables", {}) variables = last_state.get("variables", {})
conv_vars = variables.get("conv", {}) conv_vars = variables.get("conv", {})
input_data["conv"] = conv_vars input_data["conv"] = conv_vars
input_data["conv_messages"] = last_state.get("messages") or []
break break
init_message_length = len(input_data.get("conv_messages", []))
result = await execute_workflow( result = await execute_workflow(
workflow_config=workflow_config_dict, workflow_config=workflow_config_dict,
input_data=input_data, input_data=input_data,
@@ -517,6 +521,17 @@ class WorkflowService:
"completed", "completed",
output_data=result output_data=result
) )
final_messages = result.get("messages", [])[init_message_length:]
for message in final_messages:
message_obj = Message(
conversation_id=conversation_id_uuid,
role=message["role"],
content=message["content"],
)
self.message_repo.add_message(message_obj)
self.db.commit()
logger.info(f"Workflow Run Success, "
f"execution_id: {execution.execution_id}, message count: {len(final_messages)}")
else: else:
self.update_execution_status( self.update_execution_status(
execution.execution_id, execution.execution_id,
@@ -529,6 +544,7 @@ class WorkflowService:
"execution_id": execution.execution_id, "execution_id": execution.execution_id,
"status": result.get("status"), "status": result.get("status"),
"variables": result.get("variables"), "variables": result.get("variables"),
"messages": result.get("messages"),
"output": result.get("output"), # 最终输出(字符串) "output": result.get("output"), # 最终输出(字符串)
"output_data": result.get("node_outputs", {}), # 所有节点输出(详细数据) "output_data": result.get("node_outputs", {}), # 所有节点输出(详细数据)
"conversation_id": result.get("conversation_id"), # 所有节点输出详细数据payload., # 会话 ID "conversation_id": result.get("conversation_id"), # 所有节点输出详细数据payload., # 会话 ID
@@ -559,6 +575,7 @@ class WorkflowService:
"""运行工作流(流式) """运行工作流(流式)
Args: Args:
workspace_id:
app_id: 应用 ID app_id: 应用 ID
payload: 请求对象(包含 message, variables, conversation_id 等) payload: 请求对象(包含 message, variables, conversation_id 等)
config: 存储类型(可选) config: 存储类型(可选)
@@ -601,7 +618,7 @@ class WorkflowService:
workflow_config_id=config.id, workflow_config_id=config.id,
app_id=app_id, app_id=app_id,
trigger_type="manual", trigger_type="manual",
triggered_by=triggered_by_uuid, triggered_by=None,
conversation_id=conversation_id_uuid, conversation_id=conversation_id_uuid,
input_data=input_data input_data=input_data
) )
@@ -638,17 +655,46 @@ class WorkflowService:
variables = last_state.get("variables", {}) variables = last_state.get("variables", {})
conv_vars = variables.get("conv", {}) conv_vars = variables.get("conv", {})
input_data["conv"] = conv_vars input_data["conv"] = conv_vars
input_data["conv_messages"] = last_state.get("messages") or []
break break
init_message_length = len(input_data.get("conv_messages", []))
from app.core.workflow.executor import execute_workflow_stream
# 调用流式执行executor 会发送 workflow_start 和 workflow_end 事件) async for event in execute_workflow_stream(
async for event in self._run_workflow_stream(
workflow_config=workflow_config_dict, workflow_config=workflow_config_dict,
input_data=input_data, input_data=input_data,
execution_id=execution.execution_id, execution_id=execution.execution_id,
workspace_id=str(workspace_id), workspace_id=str(workspace_id),
user_id=end_user_id user_id=end_user_id
): ):
# 直接转发 executor 的事件(已经是正确的格式) if event.get("event") == "workflow_end":
status = event.get("data", {}).get("status")
if status == "completed":
self.update_execution_status(
execution.execution_id,
"completed",
output_data=event.get("data")
)
final_messages = event.get("data", {}).get("messages", [])[init_message_length:]
for message in final_messages:
message_obj = Message(
conversation_id=conversation_id_uuid,
role=message["role"],
content=message["content"],
)
self.message_repo.add_message(message_obj)
self.db.commit()
logger.info(f"Workflow Run Success, "
f"execution_id: {execution.execution_id}, message count: {len(final_messages)}")
elif status == "failed":
self.update_execution_status(
execution.execution_id,
"failed",
output_data=event.get("data")
)
else:
logger.error(f"unexpect workflow run status, status: {status}")
yield event yield event
except Exception as e: except Exception as e:
@@ -667,6 +713,8 @@ class WorkflowService:
} }
} }
@deprecated(reason="This method is deprecated. "
"Please use WorkflowService.run / run_stream instead.")
async def run_workflow( async def run_workflow(
self, self,
app_id: uuid.UUID, app_id: uuid.UUID,
@@ -819,6 +867,7 @@ class WorkflowService:
return clean_value(event) return clean_value(event)
@deprecated(reason="This method is deprecated. Please use WorkflowService.run_stream instead.")
async def _run_workflow_stream( async def _run_workflow_stream(
self, self,
workflow_config: dict[str, Any], workflow_config: dict[str, Any],

View File

@@ -136,7 +136,8 @@ dependencies = [
"markdown-to-json==2.1.1", "markdown-to-json==2.1.1",
"valkey==6.0.2", "valkey==6.0.2",
"python-calamine>=0.4.0", "python-calamine>=0.4.0",
"xlrd==2.0.2" "xlrd==2.0.2",
"deprecated>=1.3.1",
] ]
[tool.pytest.ini_options] [tool.pytest.ini_options]