feat(workflow): add session context memory support to LLM nodes
This commit is contained in:
@@ -39,11 +39,11 @@ router = APIRouter(prefix="/apps", tags=["workflow"])
|
|||||||
@router.post("/{app_id}/workflow")
|
@router.post("/{app_id}/workflow")
|
||||||
@cur_workspace_access_guard()
|
@cur_workspace_access_guard()
|
||||||
async def create_workflow_config(
|
async def create_workflow_config(
|
||||||
app_id: Annotated[uuid.UUID, Path(description="应用 ID")],
|
app_id: Annotated[uuid.UUID, Path(description="应用 ID")],
|
||||||
config: WorkflowConfigCreate,
|
config: WorkflowConfigCreate,
|
||||||
db: Annotated[Session, Depends(get_db)],
|
db: Annotated[Session, Depends(get_db)],
|
||||||
current_user: Annotated[User, Depends(get_current_user)],
|
current_user: Annotated[User, Depends(get_current_user)],
|
||||||
service: Annotated[WorkflowService, Depends(get_workflow_service)]
|
service: Annotated[WorkflowService, Depends(get_workflow_service)]
|
||||||
):
|
):
|
||||||
"""创建工作流配置
|
"""创建工作流配置
|
||||||
|
|
||||||
@@ -96,6 +96,7 @@ async def create_workflow_config(
|
|||||||
msg=f"创建工作流配置失败: {str(e)}"
|
msg=f"创建工作流配置失败: {str(e)}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
#
|
#
|
||||||
# @router.get("/{app_id}/workflow")
|
# @router.get("/{app_id}/workflow")
|
||||||
# async def get_workflow_config(
|
# async def get_workflow_config(
|
||||||
@@ -199,10 +200,10 @@ async def create_workflow_config(
|
|||||||
|
|
||||||
@router.delete("/{app_id}/workflow")
|
@router.delete("/{app_id}/workflow")
|
||||||
async def delete_workflow_config(
|
async def delete_workflow_config(
|
||||||
app_id: Annotated[uuid.UUID, Path(description="应用 ID")],
|
app_id: Annotated[uuid.UUID, Path(description="应用 ID")],
|
||||||
db: Annotated[Session, Depends(get_db)],
|
db: Annotated[Session, Depends(get_db)],
|
||||||
current_user: Annotated[User, Depends(get_current_user)],
|
current_user: Annotated[User, Depends(get_current_user)],
|
||||||
service: Annotated[WorkflowService, Depends(get_workflow_service)]
|
service: Annotated[WorkflowService, Depends(get_workflow_service)]
|
||||||
):
|
):
|
||||||
"""删除工作流配置
|
"""删除工作流配置
|
||||||
|
|
||||||
@@ -243,11 +244,11 @@ async def delete_workflow_config(
|
|||||||
|
|
||||||
@router.post("/{app_id}/workflow/validate")
|
@router.post("/{app_id}/workflow/validate")
|
||||||
async def validate_workflow_config(
|
async def validate_workflow_config(
|
||||||
app_id: Annotated[uuid.UUID, Path(description="应用 ID")],
|
app_id: Annotated[uuid.UUID, Path(description="应用 ID")],
|
||||||
db: Annotated[Session, Depends(get_db)],
|
db: Annotated[Session, Depends(get_db)],
|
||||||
current_user: Annotated[User, Depends(get_current_user)],
|
current_user: Annotated[User, Depends(get_current_user)],
|
||||||
service: Annotated[WorkflowService, Depends(get_workflow_service)],
|
service: Annotated[WorkflowService, Depends(get_workflow_service)],
|
||||||
for_publish: Annotated[bool, Query(description="是否为发布验证")] = False
|
for_publish: Annotated[bool, Query(description="是否为发布验证")] = False
|
||||||
):
|
):
|
||||||
"""验证工作流配置
|
"""验证工作流配置
|
||||||
|
|
||||||
@@ -312,12 +313,12 @@ async def validate_workflow_config(
|
|||||||
|
|
||||||
@router.get("/{app_id}/workflow/executions")
|
@router.get("/{app_id}/workflow/executions")
|
||||||
async def get_workflow_executions(
|
async def get_workflow_executions(
|
||||||
app_id: Annotated[uuid.UUID, Path(description="应用 ID")],
|
app_id: Annotated[uuid.UUID, Path(description="应用 ID")],
|
||||||
db: Annotated[Session, Depends(get_db)],
|
db: Annotated[Session, Depends(get_db)],
|
||||||
current_user: Annotated[User, Depends(get_current_user)],
|
current_user: Annotated[User, Depends(get_current_user)],
|
||||||
service: Annotated[WorkflowService, Depends(get_workflow_service)],
|
service: Annotated[WorkflowService, Depends(get_workflow_service)],
|
||||||
limit: Annotated[int, Query(ge=1, le=100)] = 50,
|
limit: Annotated[int, Query(ge=1, le=100)] = 50,
|
||||||
offset: Annotated[int, Query(ge=0)] = 0
|
offset: Annotated[int, Query(ge=0)] = 0
|
||||||
):
|
):
|
||||||
"""获取工作流执行记录列表
|
"""获取工作流执行记录列表
|
||||||
|
|
||||||
@@ -365,10 +366,10 @@ async def get_workflow_executions(
|
|||||||
|
|
||||||
@router.get("/workflow/executions/{execution_id}")
|
@router.get("/workflow/executions/{execution_id}")
|
||||||
async def get_workflow_execution(
|
async def get_workflow_execution(
|
||||||
execution_id: Annotated[str, Path(description="执行 ID")],
|
execution_id: Annotated[str, Path(description="执行 ID")],
|
||||||
db: Annotated[Session, Depends(get_db)],
|
db: Annotated[Session, Depends(get_db)],
|
||||||
current_user: Annotated[User, Depends(get_current_user)],
|
current_user: Annotated[User, Depends(get_current_user)],
|
||||||
service: Annotated[WorkflowService, Depends(get_workflow_service)]
|
service: Annotated[WorkflowService, Depends(get_workflow_service)]
|
||||||
):
|
):
|
||||||
"""获取工作流执行详情
|
"""获取工作流执行详情
|
||||||
|
|
||||||
@@ -417,16 +418,14 @@ async def get_workflow_execution(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# ==================== 工作流执行 ====================
|
# ==================== 工作流执行 ====================
|
||||||
|
|
||||||
@router.post("/{app_id}/workflow/run")
|
@router.post("/{app_id}/workflow/run")
|
||||||
async def run_workflow(
|
async def run_workflow(
|
||||||
app_id: Annotated[uuid.UUID, Path(description="应用 ID")],
|
app_id: Annotated[uuid.UUID, Path(description="应用 ID")],
|
||||||
request: WorkflowExecutionRequest,
|
request: WorkflowExecutionRequest,
|
||||||
db: Annotated[Session, Depends(get_db)],
|
db: Annotated[Session, Depends(get_db)],
|
||||||
current_user: Annotated[User, Depends(get_current_user)],
|
current_user: Annotated[User, Depends(get_current_user)],
|
||||||
service: Annotated[WorkflowService, Depends(get_workflow_service)]
|
service: Annotated[WorkflowService, Depends(get_workflow_service)]
|
||||||
):
|
):
|
||||||
"""执行工作流
|
"""执行工作流
|
||||||
|
|
||||||
@@ -487,11 +486,11 @@ async def run_workflow(
|
|||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
async for event in await service.run_workflow(
|
async for event in await service.run_workflow(
|
||||||
app_id=app_id,
|
app_id=app_id,
|
||||||
input_data=input_data,
|
input_data=input_data,
|
||||||
triggered_by=current_user.id,
|
triggered_by=current_user.id,
|
||||||
conversation_id=uuid.UUID(request.conversation_id) if request.conversation_id else None,
|
conversation_id=uuid.UUID(request.conversation_id) if request.conversation_id else None,
|
||||||
stream=True
|
stream=True
|
||||||
):
|
):
|
||||||
# 提取事件类型和数据
|
# 提取事件类型和数据
|
||||||
event_type = event.get("event", "message")
|
event_type = event.get("event", "message")
|
||||||
@@ -554,10 +553,10 @@ async def run_workflow(
|
|||||||
|
|
||||||
@router.post("/workflow/executions/{execution_id}/cancel")
|
@router.post("/workflow/executions/{execution_id}/cancel")
|
||||||
async def cancel_workflow_execution(
|
async def cancel_workflow_execution(
|
||||||
execution_id: Annotated[str, Path(description="执行 ID")],
|
execution_id: Annotated[str, Path(description="执行 ID")],
|
||||||
db: Annotated[Session, Depends(get_db)],
|
db: Annotated[Session, Depends(get_db)],
|
||||||
current_user: Annotated[User, Depends(get_current_user)],
|
current_user: Annotated[User, Depends(get_current_user)],
|
||||||
service: Annotated[WorkflowService, Depends(get_workflow_service)]
|
service: Annotated[WorkflowService, Depends(get_workflow_service)]
|
||||||
):
|
):
|
||||||
"""取消工作流执行
|
"""取消工作流执行
|
||||||
|
|
||||||
@@ -602,7 +601,7 @@ async def cancel_workflow_execution(
|
|||||||
|
|
||||||
except BusinessException as e:
|
except BusinessException as e:
|
||||||
logger.warning(f"取消工作流执行失败: {e.message}")
|
logger.warning(f"取消工作流执行失败: {e.message}")
|
||||||
return fail(code=e.error_code, msg=e.message)
|
return fail(code=e.code, msg=e.message)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"取消工作流执行异常: {e}", exc_info=True)
|
logger.error(f"取消工作流执行异常: {e}", exc_info=True)
|
||||||
return fail(
|
return fail(
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ from dotenv import load_dotenv
|
|||||||
|
|
||||||
load_dotenv()
|
load_dotenv()
|
||||||
|
|
||||||
|
|
||||||
class Settings:
|
class Settings:
|
||||||
ENABLE_SINGLE_WORKSPACE: bool = os.getenv("ENABLE_SINGLE_WORKSPACE", "true").lower() == "true"
|
ENABLE_SINGLE_WORKSPACE: bool = os.getenv("ENABLE_SINGLE_WORKSPACE", "true").lower() == "true"
|
||||||
# API Keys Configuration
|
# API Keys Configuration
|
||||||
@@ -142,7 +143,6 @@ class Settings:
|
|||||||
LOG_STREAM_BUFFER_SIZE: int = int(os.getenv("LOG_STREAM_BUFFER_SIZE", "8192")) # 8KB
|
LOG_STREAM_BUFFER_SIZE: int = int(os.getenv("LOG_STREAM_BUFFER_SIZE", "8192")) # 8KB
|
||||||
LOG_FILE_MAX_SIZE_MB: int = int(os.getenv("LOG_FILE_MAX_SIZE_MB", "10")) # 10MB
|
LOG_FILE_MAX_SIZE_MB: int = int(os.getenv("LOG_FILE_MAX_SIZE_MB", "10")) # 10MB
|
||||||
|
|
||||||
|
|
||||||
# Celery configuration (internal)
|
# Celery configuration (internal)
|
||||||
CELERY_BROKER: int = int(os.getenv("CELERY_BROKER", "1"))
|
CELERY_BROKER: int = int(os.getenv("CELERY_BROKER", "1"))
|
||||||
CELERY_BACKEND: int = int(os.getenv("CELERY_BACKEND", "2"))
|
CELERY_BACKEND: int = int(os.getenv("CELERY_BACKEND", "2"))
|
||||||
@@ -150,7 +150,7 @@ class Settings:
|
|||||||
HEALTH_CHECK_SECONDS: float = float(os.getenv("HEALTH_CHECK_SECONDS", "600"))
|
HEALTH_CHECK_SECONDS: float = float(os.getenv("HEALTH_CHECK_SECONDS", "600"))
|
||||||
MEMORY_INCREMENT_INTERVAL_HOURS: float = float(os.getenv("MEMORY_INCREMENT_INTERVAL_HOURS", "24"))
|
MEMORY_INCREMENT_INTERVAL_HOURS: float = float(os.getenv("MEMORY_INCREMENT_INTERVAL_HOURS", "24"))
|
||||||
DEFAULT_WORKSPACE_ID: Optional[str] = os.getenv("DEFAULT_WORKSPACE_ID", None)
|
DEFAULT_WORKSPACE_ID: Optional[str] = os.getenv("DEFAULT_WORKSPACE_ID", None)
|
||||||
REFLECTION_INTERVAL_TIME:Optional[str] = int(os.getenv("REFLECTION_INTERVAL_TIME", 30))
|
REFLECTION_INTERVAL_TIME: Optional[str] = int(os.getenv("REFLECTION_INTERVAL_TIME", 30))
|
||||||
|
|
||||||
# Memory Cache Regeneration Configuration
|
# Memory Cache Regeneration Configuration
|
||||||
MEMORY_CACHE_REGENERATION_HOURS: int = int(os.getenv("MEMORY_CACHE_REGENERATION_HOURS", "24"))
|
MEMORY_CACHE_REGENERATION_HOURS: int = int(os.getenv("MEMORY_CACHE_REGENERATION_HOURS", "24"))
|
||||||
@@ -168,6 +168,9 @@ class Settings:
|
|||||||
# official environment system version
|
# official environment system version
|
||||||
SYSTEM_VERSION: str = os.getenv("SYSTEM_VERSION", "v0.2.0")
|
SYSTEM_VERSION: str = os.getenv("SYSTEM_VERSION", "v0.2.0")
|
||||||
|
|
||||||
|
# workflow config
|
||||||
|
WORKFLOW_NODE_TIMEOUT: int = os.getenv("WORKFLOW_NODE_TIMEOUT", 600)
|
||||||
|
|
||||||
def get_memory_output_path(self, filename: str = "") -> str:
|
def get_memory_output_path(self, filename: str = "") -> str:
|
||||||
"""
|
"""
|
||||||
Get the full path for memory module output files.
|
Get the full path for memory module output files.
|
||||||
|
|||||||
@@ -74,6 +74,7 @@ class WorkflowExecutor:
|
|||||||
初始化的工作流状态
|
初始化的工作流状态
|
||||||
"""
|
"""
|
||||||
user_message = input_data.get("message") or ""
|
user_message = input_data.get("message") or ""
|
||||||
|
conversation_messages = input_data.get("conv_messages") or []
|
||||||
|
|
||||||
# 会话变量处理:从配置文件获取变量定义列表,转换为字典(name -> default value)
|
# 会话变量处理:从配置文件获取变量定义列表,转换为字典(name -> default value)
|
||||||
config_variables_list = self.workflow_config.get("variables") or []
|
config_variables_list = self.workflow_config.get("variables") or []
|
||||||
@@ -114,7 +115,7 @@ class WorkflowExecutor:
|
|||||||
}
|
}
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"messages": [('user', user_message)],
|
"messages": conversation_messages,
|
||||||
"variables": variables,
|
"variables": variables,
|
||||||
"node_outputs": {},
|
"node_outputs": {},
|
||||||
"runtime_vars": {}, # 运行时节点变量(简化版,供快速访问)
|
"runtime_vars": {}, # 运行时节点变量(简化版,供快速访问)
|
||||||
|
|||||||
@@ -7,13 +7,13 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from operator import add
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from langchain_core.messages import AnyMessage, AIMessage
|
from langchain_core.messages import AIMessage
|
||||||
from langgraph.config import get_stream_writer
|
from langgraph.config import get_stream_writer
|
||||||
from typing_extensions import TypedDict, Annotated
|
from typing_extensions import TypedDict, Annotated
|
||||||
|
|
||||||
|
from app.core.config import settings
|
||||||
from app.core.workflow.variable_pool import VariablePool
|
from app.core.workflow.variable_pool import VariablePool
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -25,7 +25,7 @@ class WorkflowState(TypedDict):
|
|||||||
The state object passed between nodes in a workflow, containing messages, variables, node outputs, etc.
|
The state object passed between nodes in a workflow, containing messages, variables, node outputs, etc.
|
||||||
"""
|
"""
|
||||||
# List of messages (append mode)
|
# List of messages (append mode)
|
||||||
messages: Annotated[list[tuple[str, str]], add]
|
messages: list[dict[str, str]]
|
||||||
|
|
||||||
# Set of loop node IDs, used for assigning values in loop nodes
|
# Set of loop node IDs, used for assigning values in loop nodes
|
||||||
cycle_nodes: list
|
cycle_nodes: list
|
||||||
@@ -154,7 +154,7 @@ class BaseNode(ABC):
|
|||||||
Returns:
|
Returns:
|
||||||
超时时间
|
超时时间
|
||||||
"""
|
"""
|
||||||
return 60
|
return settings.WORKFLOW_NODE_TIMEOUT
|
||||||
# return self.error_handling.get("timeout", 60)
|
# return self.error_handling.get("timeout", 60)
|
||||||
|
|
||||||
async def run(self, state: WorkflowState) -> dict[str, Any]:
|
async def run(self, state: WorkflowState) -> dict[str, Any]:
|
||||||
@@ -203,6 +203,7 @@ class BaseNode(ABC):
|
|||||||
# 返回包装后的输出和运行时变量
|
# 返回包装后的输出和运行时变量
|
||||||
return {
|
return {
|
||||||
**wrapped_output,
|
**wrapped_output,
|
||||||
|
"messages": state["messages"],
|
||||||
"variables": state["variables"],
|
"variables": state["variables"],
|
||||||
"runtime_vars": {
|
"runtime_vars": {
|
||||||
self.node_id: runtime_var
|
self.node_id: runtime_var
|
||||||
@@ -356,6 +357,7 @@ class BaseNode(ABC):
|
|||||||
# Build complete state update (including node_outputs, runtime_vars, and final streaming buffer)
|
# Build complete state update (including node_outputs, runtime_vars, and final streaming buffer)
|
||||||
state_update = {
|
state_update = {
|
||||||
**final_output,
|
**final_output,
|
||||||
|
"messages": state["messages"],
|
||||||
"variables": state["variables"],
|
"variables": state["variables"],
|
||||||
"runtime_vars": {
|
"runtime_vars": {
|
||||||
self.node_id: runtime_var
|
self.node_id: runtime_var
|
||||||
|
|||||||
@@ -6,7 +6,6 @@ End 节点实现
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
import asyncio
|
|
||||||
|
|
||||||
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
|
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
|
||||||
from app.core.workflow.nodes.enums import NodeType
|
from app.core.workflow.nodes.enums import NodeType
|
||||||
@@ -38,7 +37,23 @@ class EndNode(BaseNode):
|
|||||||
# 如果配置了输出模板,使用模板渲染;否则使用默认输出
|
# 如果配置了输出模板,使用模板渲染;否则使用默认输出
|
||||||
if output_template:
|
if output_template:
|
||||||
output = self._render_template(output_template, state, strict=False)
|
output = self._render_template(output_template, state, strict=False)
|
||||||
|
state['messages'].extend([
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": self.get_variable("sys.message", state)
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"role": "assistant",
|
||||||
|
"content": output
|
||||||
|
}
|
||||||
|
])
|
||||||
else:
|
else:
|
||||||
|
state['messages'].extend([
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": self.get_variable("sys.message", state)
|
||||||
|
},
|
||||||
|
])
|
||||||
output = "工作流已完成"
|
output = "工作流已完成"
|
||||||
|
|
||||||
# 统计信息(用于日志)
|
# 统计信息(用于日志)
|
||||||
@@ -166,6 +181,12 @@ class EndNode(BaseNode):
|
|||||||
"chunk_index": 1,
|
"chunk_index": 1,
|
||||||
"is_suffix": False
|
"is_suffix": False
|
||||||
})
|
})
|
||||||
|
state['messages'].extend([
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": self.get_variable("sys.message", state)
|
||||||
|
}
|
||||||
|
])
|
||||||
yield {"__final__": True, "result": output}
|
yield {"__final__": True, "result": output}
|
||||||
return
|
return
|
||||||
|
|
||||||
@@ -176,7 +197,6 @@ class EndNode(BaseNode):
|
|||||||
source_node_id = edge.get("source")
|
source_node_id = edge.get("source")
|
||||||
# Check if the source node is an LLM node
|
# Check if the source node is an LLM node
|
||||||
for node in self.workflow_config.get("nodes", []):
|
for node in self.workflow_config.get("nodes", []):
|
||||||
print("="*50)
|
|
||||||
logger.info(f"节点 {self.node_id} 的类型 {node.get("type")}")
|
logger.info(f"节点 {self.node_id} 的类型 {node.get("type")}")
|
||||||
if node.get("id") == source_node_id and node.get("type") == NodeType.LLM:
|
if node.get("id") == source_node_id and node.get("type") == NodeType.LLM:
|
||||||
direct_upstream_llm_nodes.append(source_node_id)
|
direct_upstream_llm_nodes.append(source_node_id)
|
||||||
@@ -216,12 +236,24 @@ class EndNode(BaseNode):
|
|||||||
})
|
})
|
||||||
logger.info(f"节点 {self.node_id} 已通过 writer 发送完整内容")
|
logger.info(f"节点 {self.node_id} 已通过 writer 发送完整内容")
|
||||||
|
|
||||||
|
state['messages'].extend([
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": self.get_variable("sys.message", state)
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"role": "assistant",
|
||||||
|
"content": output
|
||||||
|
}
|
||||||
|
])
|
||||||
|
|
||||||
# yield completion marker
|
# yield completion marker
|
||||||
yield {"__final__": True, "result": output}
|
yield {"__final__": True, "result": output}
|
||||||
return
|
return
|
||||||
|
|
||||||
# Has reference to direct upstream LLM node, only output the part after that reference (suffix)
|
# Has reference to direct upstream LLM node, only output the part after that reference (suffix)
|
||||||
logger.info(f"节点 {self.node_id} 检测到直接上游 LLM 节点引用,只输出后缀部分(从索引 {upstream_llm_ref_index + 1} 开始)")
|
logger.info(
|
||||||
|
f"节点 {self.node_id} 检测到直接上游 LLM 节点引用,只输出后缀部分(从索引 {upstream_llm_ref_index + 1} 开始)")
|
||||||
|
|
||||||
# Collect suffix parts
|
# Collect suffix parts
|
||||||
suffix_parts = []
|
suffix_parts = []
|
||||||
@@ -258,6 +290,17 @@ class EndNode(BaseNode):
|
|||||||
# 构建完整输出(用于返回,包含前缀 + 动态内容 + 后缀)
|
# 构建完整输出(用于返回,包含前缀 + 动态内容 + 后缀)
|
||||||
full_output = self._render_template(output_template, state, strict=False)
|
full_output = self._render_template(output_template, state, strict=False)
|
||||||
|
|
||||||
|
state['messages'].extend([
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": self.get_variable("sys.message", state)
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"role": "assistant",
|
||||||
|
"content": full_output
|
||||||
|
}
|
||||||
|
])
|
||||||
|
|
||||||
logger.info(f"[后缀调试] 节点 {self.node_id} 后缀部分数量: {len(suffix_parts)}")
|
logger.info(f"[后缀调试] 节点 {self.node_id} 后缀部分数量: {len(suffix_parts)}")
|
||||||
logger.info(f"[后缀调试] 后缀内容: '{suffix}'")
|
logger.info(f"[后缀调试] 后缀内容: '{suffix}'")
|
||||||
logger.info(f"[后缀调试] 后缀长度: {len(suffix)}")
|
logger.info(f"[后缀调试] 后缀长度: {len(suffix)}")
|
||||||
@@ -280,7 +323,8 @@ class EndNode(BaseNode):
|
|||||||
})
|
})
|
||||||
logger.info(f"节点 {self.node_id} 已通过 writer 发送后缀,full_content 长度: {len(full_output)}")
|
logger.info(f"节点 {self.node_id} 已通过 writer 发送后缀,full_content 长度: {len(full_output)}")
|
||||||
else:
|
else:
|
||||||
logger.warning(f"[后缀调试] 节点 {self.node_id} 后缀为空,不发送!upstream_llm_ref_index={upstream_llm_ref_index}, parts数量={len(parts)}")
|
logger.warning(f"[后缀调试] 节点 {self.node_id} 后缀为空,不发送!"
|
||||||
|
f"upstream_llm_ref_index={upstream_llm_ref_index}, parts数量={len(parts)}")
|
||||||
|
|
||||||
# 统计信息
|
# 统计信息
|
||||||
node_outputs = state.get("node_outputs", {})
|
node_outputs = state.get("node_outputs", {})
|
||||||
|
|||||||
@@ -11,12 +11,12 @@ class MessageConfig(BaseModel):
|
|||||||
"""消息配置"""
|
"""消息配置"""
|
||||||
|
|
||||||
role: str = Field(
|
role: str = Field(
|
||||||
...,
|
default='user',
|
||||||
description="消息角色:system, user, assistant"
|
description="消息角色:system, user, assistant"
|
||||||
)
|
)
|
||||||
|
|
||||||
content: str = Field(
|
content: str = Field(
|
||||||
...,
|
default="",
|
||||||
description="消息内容,支持模板变量,如:{{ sys.message }}"
|
description="消息内容,支持模板变量,如:{{ sys.message }}"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -30,6 +30,23 @@ class MessageConfig(BaseModel):
|
|||||||
return v.lower()
|
return v.lower()
|
||||||
|
|
||||||
|
|
||||||
|
class MemoryWindowSetting(BaseModel):
|
||||||
|
enable: bool = Field(
|
||||||
|
default=False,
|
||||||
|
description="启用记忆"
|
||||||
|
)
|
||||||
|
|
||||||
|
enable_window: bool = Field(
|
||||||
|
default=False,
|
||||||
|
description="启用记忆窗口"
|
||||||
|
)
|
||||||
|
|
||||||
|
window_size: int = Field(
|
||||||
|
default=20,
|
||||||
|
description="记忆窗口大小"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class LLMNodeConfig(BaseNodeConfig):
|
class LLMNodeConfig(BaseNodeConfig):
|
||||||
"""LLM 节点配置
|
"""LLM 节点配置
|
||||||
|
|
||||||
@@ -48,6 +65,11 @@ class LLMNodeConfig(BaseNodeConfig):
|
|||||||
description="上下文"
|
description="上下文"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
memory: MemoryWindowSetting = Field(
|
||||||
|
...,
|
||||||
|
description="对话上下文窗口"
|
||||||
|
)
|
||||||
|
|
||||||
# 简单模式
|
# 简单模式
|
||||||
prompt: str | None = Field(
|
prompt: str | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
|
|||||||
@@ -85,28 +85,31 @@ class LLMNode(BaseNode):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
# 1. 处理消息格式(优先使用 messages)
|
# 1. 处理消息格式(优先使用 messages)
|
||||||
messages_config = self.config.get("messages")
|
messages_config = self.typed_config.messages
|
||||||
|
|
||||||
if messages_config:
|
if messages_config:
|
||||||
# 使用 LangChain 消息格式
|
# 使用 LangChain 消息格式
|
||||||
messages = []
|
messages = []
|
||||||
for msg_config in messages_config:
|
for msg_config in messages_config:
|
||||||
role = msg_config.get("role", "user").lower()
|
role = msg_config.role.lower()
|
||||||
content_template = msg_config.get("content", "")
|
content_template = msg_config.content
|
||||||
content_template = self._render_context(content_template, state)
|
content_template = self._render_context(content_template, state)
|
||||||
content = self._render_template(content_template, state)
|
content = self._render_template(content_template, state)
|
||||||
|
|
||||||
# 根据角色创建对应的消息对象
|
# 根据角色创建对应的消息对象
|
||||||
if role == "system":
|
if role == "system":
|
||||||
messages.append(SystemMessage(content=content))
|
messages.append({"role": "system", "content": content})
|
||||||
elif role in ["user", "human"]:
|
elif role in ["user", "human"]:
|
||||||
messages.append(HumanMessage(content=content))
|
messages.append({"role": "user", "content": content})
|
||||||
elif role in ["ai", "assistant"]:
|
elif role in ["ai", "assistant"]:
|
||||||
messages.append(AIMessage(content=content))
|
messages.append({"role": "user", "content": content})
|
||||||
else:
|
else:
|
||||||
logger.warning(f"未知的消息角色: {role},默认使用 user")
|
logger.warning(f"未知的消息角色: {role},默认使用 user")
|
||||||
messages.append(HumanMessage(content=content))
|
messages.append({"role": "user", "content": content})
|
||||||
|
|
||||||
|
if self.typed_config.memory.enable:
|
||||||
|
# if self.typed_config.memory.enable_window:
|
||||||
|
messages = messages[:-1] + state["messages"][-self.typed_config.memory.window_size:] + messages[-1:]
|
||||||
prompt_or_messages = messages
|
prompt_or_messages = messages
|
||||||
else:
|
else:
|
||||||
# 使用简单的 prompt 格式(向后兼容)
|
# 使用简单的 prompt 格式(向后兼容)
|
||||||
@@ -189,7 +192,7 @@ class LLMNode(BaseNode):
|
|||||||
return {
|
return {
|
||||||
"prompt": prompt_or_messages if isinstance(prompt_or_messages, str) else None,
|
"prompt": prompt_or_messages if isinstance(prompt_or_messages, str) else None,
|
||||||
"messages": [
|
"messages": [
|
||||||
{"role": msg.__class__.__name__.replace("Message", "").lower(), "content": msg.content}
|
{"role": msg.get("role"), "content": msg.get("content", "")}
|
||||||
for msg in prompt_or_messages
|
for msg in prompt_or_messages
|
||||||
] if isinstance(prompt_or_messages, list) else None,
|
] if isinstance(prompt_or_messages, list) else None,
|
||||||
"config": {
|
"config": {
|
||||||
|
|||||||
@@ -41,6 +41,7 @@ class ToolConfig(BaseModel):
|
|||||||
tool_id: Optional[str] = Field(default=None, description="工具ID")
|
tool_id: Optional[str] = Field(default=None, description="工具ID")
|
||||||
operation: Optional[str] = Field(default=None, description="工具特定配置")
|
operation: Optional[str] = Field(default=None, description="工具特定配置")
|
||||||
|
|
||||||
|
|
||||||
class ToolOldConfig(BaseModel):
|
class ToolOldConfig(BaseModel):
|
||||||
"""工具配置"""
|
"""工具配置"""
|
||||||
enabled: bool = Field(default=False, description="是否启用该工具")
|
enabled: bool = Field(default=False, description="是否启用该工具")
|
||||||
@@ -348,6 +349,7 @@ class AppChatRequest(BaseModel):
|
|||||||
variables: Optional[Dict[str, Any]] = Field(default=None, description="自定义变量参数值")
|
variables: Optional[Dict[str, Any]] = Field(default=None, description="自定义变量参数值")
|
||||||
stream: bool = Field(default=False, description="是否流式返回")
|
stream: bool = Field(default=False, description="是否流式返回")
|
||||||
|
|
||||||
|
|
||||||
class DraftRunRequest(BaseModel):
|
class DraftRunRequest(BaseModel):
|
||||||
"""试运行请求"""
|
"""试运行请求"""
|
||||||
message: str = Field(..., description="用户消息")
|
message: str = Field(..., description="用户消息")
|
||||||
|
|||||||
@@ -14,6 +14,7 @@ from app.core.exceptions import BusinessException
|
|||||||
from app.core.logging_config import get_business_logger
|
from app.core.logging_config import get_business_logger
|
||||||
from app.db import get_db, get_db_context
|
from app.db import get_db, get_db_context
|
||||||
from app.models import MultiAgentConfig, AgentConfig, WorkflowConfig
|
from app.models import MultiAgentConfig, AgentConfig, WorkflowConfig
|
||||||
|
from app.schemas import DraftRunRequest
|
||||||
from app.services.tool_service import ToolService
|
from app.services.tool_service import ToolService
|
||||||
from app.repositories.tool_repository import ToolRepository
|
from app.repositories.tool_repository import ToolRepository
|
||||||
from app.db import get_db
|
from app.db import get_db
|
||||||
@@ -59,7 +60,7 @@ class AppChatService:
|
|||||||
|
|
||||||
# 获取模型配置ID
|
# 获取模型配置ID
|
||||||
model_config_id = config.default_model_config_id
|
model_config_id = config.default_model_config_id
|
||||||
api_key_obj = ModelApiKeyService.get_a_api_key(self.db ,model_config_id)
|
api_key_obj = ModelApiKeyService.get_a_api_key(self.db, model_config_id)
|
||||||
# 处理系统提示词(支持变量替换)
|
# 处理系统提示词(支持变量替换)
|
||||||
system_prompt = config.system_prompt
|
system_prompt = config.system_prompt
|
||||||
if variables:
|
if variables:
|
||||||
@@ -210,7 +211,7 @@ class AppChatService:
|
|||||||
|
|
||||||
# 获取模型配置ID
|
# 获取模型配置ID
|
||||||
model_config_id = config.default_model_config_id
|
model_config_id = config.default_model_config_id
|
||||||
api_key_obj = ModelApiKeyService.get_a_api_key(self.db ,model_config_id)
|
api_key_obj = ModelApiKeyService.get_a_api_key(self.db, model_config_id)
|
||||||
# 处理系统提示词(支持变量替换)
|
# 处理系统提示词(支持变量替换)
|
||||||
system_prompt = config.system_prompt
|
system_prompt = config.system_prompt
|
||||||
if variables:
|
if variables:
|
||||||
@@ -511,7 +512,6 @@ class AppChatService:
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
except (GeneratorExit, asyncio.CancelledError):
|
except (GeneratorExit, asyncio.CancelledError):
|
||||||
# 生成器被关闭或任务被取消,正常退出
|
# 生成器被关闭或任务被取消,正常退出
|
||||||
logger.debug("多 Agent 流式聊天被中断")
|
logger.debug("多 Agent 流式聊天被中断")
|
||||||
@@ -537,83 +537,19 @@ class AppChatService:
|
|||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
"""聊天(非流式)"""
|
"""聊天(非流式)"""
|
||||||
workflow_service = WorkflowService(self.db)
|
workflow_service = WorkflowService(self.db)
|
||||||
|
payload = DraftRunRequest(
|
||||||
input_data = {"message":message, "variables": variables,
|
message=message,
|
||||||
"conversation_id": str(conversation_id)}
|
variables=variables,
|
||||||
inconfig = workflow_service.get_workflow_config(app_id)
|
conversation_id=str(conversation_id),
|
||||||
|
stream=True,
|
||||||
# 2. 创建执行记录
|
user_id=user_id
|
||||||
execution = workflow_service.create_execution(
|
)
|
||||||
workflow_config_id=inconfig.id,
|
return await workflow_service.run(
|
||||||
app_id=app_id,
|
app_id=app_id,
|
||||||
trigger_type="manual",
|
payload=payload,
|
||||||
triggered_by=None,
|
config=config,
|
||||||
conversation_id=conversation_id,
|
workspace_id=workspace_id,
|
||||||
input_data=input_data
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# 3. 构建工作流配置字典
|
|
||||||
workflow_config_dict = {
|
|
||||||
"nodes": config.nodes,
|
|
||||||
"edges": config.edges,
|
|
||||||
"variables": config.variables,
|
|
||||||
"execution_config": config.execution_config
|
|
||||||
}
|
|
||||||
|
|
||||||
# 4. 获取工作空间 ID(从 app 获取)
|
|
||||||
|
|
||||||
# 5. 执行工作流
|
|
||||||
from app.core.workflow.executor import execute_workflow
|
|
||||||
|
|
||||||
try:
|
|
||||||
# 更新状态为运行中
|
|
||||||
workflow_service.update_execution_status(execution.execution_id, "running")
|
|
||||||
|
|
||||||
result = await execute_workflow(
|
|
||||||
workflow_config=workflow_config_dict,
|
|
||||||
input_data=input_data,
|
|
||||||
execution_id=execution.execution_id,
|
|
||||||
workspace_id=str(workspace_id),
|
|
||||||
user_id=user_id
|
|
||||||
)
|
|
||||||
|
|
||||||
# 更新执行结果
|
|
||||||
if result.get("status") == "completed":
|
|
||||||
workflow_service.update_execution_status(
|
|
||||||
execution.execution_id,
|
|
||||||
"completed",
|
|
||||||
output_data=result.get("node_outputs", {})
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
workflow_service.update_execution_status(
|
|
||||||
execution.execution_id,
|
|
||||||
"failed",
|
|
||||||
error_message=result.get("error")
|
|
||||||
)
|
|
||||||
|
|
||||||
# 返回增强的响应结构
|
|
||||||
return {
|
|
||||||
"execution_id": execution.execution_id,
|
|
||||||
"status": result.get("status"),
|
|
||||||
"output": result.get("output"), # 最终输出(字符串)
|
|
||||||
"output_data": result.get("node_outputs", {}), # 所有节点输出(详细数据)
|
|
||||||
"conversation_id": result.get("conversation_id"), # 所有节点输出(详细数据)payload., # 会话 ID
|
|
||||||
"error_message": result.get("error"),
|
|
||||||
"elapsed_time": result.get("elapsed_time"),
|
|
||||||
"token_usage": result.get("token_usage")
|
|
||||||
}
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"工作流执行失败: execution_id={execution.execution_id}, error={e}", exc_info=True)
|
|
||||||
workflow_service.update_execution_status(
|
|
||||||
execution.execution_id,
|
|
||||||
"failed",
|
|
||||||
error_message=str(e)
|
|
||||||
)
|
|
||||||
raise BusinessException(
|
|
||||||
code=BizCode.INTERNAL_ERROR,
|
|
||||||
message=f"工作流执行失败: {str(e)}"
|
|
||||||
)
|
|
||||||
|
|
||||||
async def workflow_chat_stream(
|
async def workflow_chat_stream(
|
||||||
self,
|
self,
|
||||||
@@ -632,62 +568,21 @@ class AppChatService:
|
|||||||
) -> AsyncGenerator[str, None]:
|
) -> AsyncGenerator[str, None]:
|
||||||
"""聊天(流式)"""
|
"""聊天(流式)"""
|
||||||
workflow_service = WorkflowService(self.db)
|
workflow_service = WorkflowService(self.db)
|
||||||
input_data = {"message": message, "variables": variables,
|
payload = DraftRunRequest(
|
||||||
"conversation_id": str(conversation_id)}
|
message=message,
|
||||||
inconfig = workflow_service.get_workflow_config(app_id)
|
variables=variables,
|
||||||
# 2. 创建执行记录
|
conversation_id=str(conversation_id),
|
||||||
execution = workflow_service.create_execution(
|
stream=True,
|
||||||
workflow_config_id=inconfig.id,
|
user_id=user_id
|
||||||
app_id=app_id,
|
|
||||||
trigger_type="manual",
|
|
||||||
triggered_by=None,
|
|
||||||
conversation_id=conversation_id,
|
|
||||||
input_data=input_data
|
|
||||||
)
|
)
|
||||||
|
async for event in workflow_service.run_stream(
|
||||||
|
app_id=app_id,
|
||||||
|
payload=payload,
|
||||||
|
config=config,
|
||||||
|
workspace_id=workspace_id,
|
||||||
|
):
|
||||||
|
yield event
|
||||||
|
|
||||||
# 3. 构建工作流配置字典
|
|
||||||
workflow_config_dict = {
|
|
||||||
"nodes": config.nodes,
|
|
||||||
"edges": config.edges,
|
|
||||||
"variables": config.variables,
|
|
||||||
"execution_config": config.execution_config
|
|
||||||
}
|
|
||||||
|
|
||||||
# 4. 获取工作空间 ID(从 app 获取)
|
|
||||||
|
|
||||||
# 5. 流式执行工作流
|
|
||||||
|
|
||||||
try:
|
|
||||||
# 更新状态为运行中
|
|
||||||
workflow_service.update_execution_status(execution.execution_id, "running")
|
|
||||||
|
|
||||||
|
|
||||||
# 调用流式执行(executor 会发送 workflow_start 和 workflow_end 事件)
|
|
||||||
async for event in workflow_service._run_workflow_stream(
|
|
||||||
workflow_config=workflow_config_dict,
|
|
||||||
input_data=input_data,
|
|
||||||
execution_id=execution.execution_id,
|
|
||||||
workspace_id=str(workspace_id),
|
|
||||||
user_id=user_id
|
|
||||||
):
|
|
||||||
# 直接转发 executor 的事件(已经是正确的格式)
|
|
||||||
yield event
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"工作流流式执行失败: execution_id={execution.execution_id}, error={e}", exc_info=True)
|
|
||||||
workflow_service.update_execution_status(
|
|
||||||
execution.execution_id,
|
|
||||||
"failed",
|
|
||||||
error_message=str(e)
|
|
||||||
)
|
|
||||||
# 发送错误事件
|
|
||||||
yield {
|
|
||||||
"event": "error",
|
|
||||||
"data": {
|
|
||||||
"execution_id": execution.execution_id,
|
|
||||||
"error": str(e)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
# ==================== 依赖注入函数 ====================
|
# ==================== 依赖注入函数 ====================
|
||||||
|
|
||||||
|
|||||||
@@ -2,12 +2,11 @@
|
|||||||
工作流服务层
|
工作流服务层
|
||||||
"""
|
"""
|
||||||
import datetime
|
import datetime
|
||||||
import json
|
|
||||||
import logging
|
import logging
|
||||||
import uuid
|
import uuid
|
||||||
import datetime
|
|
||||||
from typing import Any, Annotated, AsyncGenerator
|
from typing import Any, Annotated, AsyncGenerator
|
||||||
|
|
||||||
|
from deprecated import deprecated
|
||||||
from fastapi import Depends
|
from fastapi import Depends
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
@@ -16,15 +15,16 @@ from app.core.exceptions import BusinessException
|
|||||||
from app.core.workflow.validator import validate_workflow_config
|
from app.core.workflow.validator import validate_workflow_config
|
||||||
from app.db import get_db, get_db_context
|
from app.db import get_db, get_db_context
|
||||||
from app.models.workflow_model import WorkflowConfig, WorkflowExecution
|
from app.models.workflow_model import WorkflowConfig, WorkflowExecution
|
||||||
|
from app.repositories.conversation_repository import MessageRepository
|
||||||
|
from app.models.conversation_model import Message
|
||||||
from app.repositories.end_user_repository import EndUserRepository
|
from app.repositories.end_user_repository import EndUserRepository
|
||||||
from app.services.multi_agent_service import convert_uuids_to_str
|
|
||||||
from app.repositories.workflow_repository import (
|
from app.repositories.workflow_repository import (
|
||||||
WorkflowConfigRepository,
|
WorkflowConfigRepository,
|
||||||
WorkflowExecutionRepository,
|
WorkflowExecutionRepository,
|
||||||
WorkflowNodeExecutionRepository
|
WorkflowNodeExecutionRepository
|
||||||
)
|
)
|
||||||
from app.schemas import DraftRunRequest
|
from app.schemas import DraftRunRequest
|
||||||
from app.utils.sse_utils import format_sse_message
|
from app.services.multi_agent_service import convert_uuids_to_str
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -37,6 +37,7 @@ class WorkflowService:
|
|||||||
self.config_repo = WorkflowConfigRepository(db)
|
self.config_repo = WorkflowConfigRepository(db)
|
||||||
self.execution_repo = WorkflowExecutionRepository(db)
|
self.execution_repo = WorkflowExecutionRepository(db)
|
||||||
self.node_execution_repo = WorkflowNodeExecutionRepository(db)
|
self.node_execution_repo = WorkflowNodeExecutionRepository(db)
|
||||||
|
self.message_repo = MessageRepository(db)
|
||||||
|
|
||||||
# ==================== 配置管理 ====================
|
# ==================== 配置管理 ====================
|
||||||
|
|
||||||
@@ -418,14 +419,13 @@ class WorkflowService:
|
|||||||
"""运行工作流
|
"""运行工作流
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
workspace_id:
|
||||||
|
config:
|
||||||
|
payload:
|
||||||
app_id: 应用 ID
|
app_id: 应用 ID
|
||||||
input_data: 输入数据(包含 message 和 variables)
|
|
||||||
triggered_by: 触发用户 ID
|
|
||||||
conversation_id: 会话 ID(可选)
|
|
||||||
stream: 是否流式返回
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
执行结果(非流式)或生成器(流式)
|
执行结果(非流式)
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
BusinessException: 配置不存在或执行失败时抛出
|
BusinessException: 配置不存在或执行失败时抛出
|
||||||
@@ -438,7 +438,8 @@ class WorkflowService:
|
|||||||
code=BizCode.CONFIG_MISSING,
|
code=BizCode.CONFIG_MISSING,
|
||||||
message=f"工作流配置不存在: app_id={app_id}"
|
message=f"工作流配置不存在: app_id={app_id}"
|
||||||
)
|
)
|
||||||
input_data = {"message": payload.message, "variables": payload.variables, "conversation_id": payload.conversation_id}
|
input_data = {"message": payload.message, "variables": payload.variables,
|
||||||
|
"conversation_id": payload.conversation_id}
|
||||||
|
|
||||||
# 转换 user_id 为 UUID
|
# 转换 user_id 为 UUID
|
||||||
triggered_by_uuid = None
|
triggered_by_uuid = None
|
||||||
@@ -461,7 +462,7 @@ class WorkflowService:
|
|||||||
workflow_config_id=config.id,
|
workflow_config_id=config.id,
|
||||||
app_id=app_id,
|
app_id=app_id,
|
||||||
trigger_type="manual",
|
trigger_type="manual",
|
||||||
triggered_by=triggered_by_uuid,
|
triggered_by=None,
|
||||||
conversation_id=conversation_id_uuid,
|
conversation_id=conversation_id_uuid,
|
||||||
input_data=input_data
|
input_data=input_data
|
||||||
)
|
)
|
||||||
@@ -500,8 +501,11 @@ class WorkflowService:
|
|||||||
variables = last_state.get("variables", {})
|
variables = last_state.get("variables", {})
|
||||||
conv_vars = variables.get("conv", {})
|
conv_vars = variables.get("conv", {})
|
||||||
input_data["conv"] = conv_vars
|
input_data["conv"] = conv_vars
|
||||||
|
input_data["conv_messages"] = last_state.get("messages") or []
|
||||||
break
|
break
|
||||||
|
|
||||||
|
init_message_length = len(input_data.get("conv_messages", []))
|
||||||
|
|
||||||
result = await execute_workflow(
|
result = await execute_workflow(
|
||||||
workflow_config=workflow_config_dict,
|
workflow_config=workflow_config_dict,
|
||||||
input_data=input_data,
|
input_data=input_data,
|
||||||
@@ -517,6 +521,17 @@ class WorkflowService:
|
|||||||
"completed",
|
"completed",
|
||||||
output_data=result
|
output_data=result
|
||||||
)
|
)
|
||||||
|
final_messages = result.get("messages", [])[init_message_length:]
|
||||||
|
for message in final_messages:
|
||||||
|
message_obj = Message(
|
||||||
|
conversation_id=conversation_id_uuid,
|
||||||
|
role=message["role"],
|
||||||
|
content=message["content"],
|
||||||
|
)
|
||||||
|
self.message_repo.add_message(message_obj)
|
||||||
|
self.db.commit()
|
||||||
|
logger.info(f"Workflow Run Success, "
|
||||||
|
f"execution_id: {execution.execution_id}, message count: {len(final_messages)}")
|
||||||
else:
|
else:
|
||||||
self.update_execution_status(
|
self.update_execution_status(
|
||||||
execution.execution_id,
|
execution.execution_id,
|
||||||
@@ -529,6 +544,7 @@ class WorkflowService:
|
|||||||
"execution_id": execution.execution_id,
|
"execution_id": execution.execution_id,
|
||||||
"status": result.get("status"),
|
"status": result.get("status"),
|
||||||
"variables": result.get("variables"),
|
"variables": result.get("variables"),
|
||||||
|
"messages": result.get("messages"),
|
||||||
"output": result.get("output"), # 最终输出(字符串)
|
"output": result.get("output"), # 最终输出(字符串)
|
||||||
"output_data": result.get("node_outputs", {}), # 所有节点输出(详细数据)
|
"output_data": result.get("node_outputs", {}), # 所有节点输出(详细数据)
|
||||||
"conversation_id": result.get("conversation_id"), # 所有节点输出(详细数据)payload., # 会话 ID
|
"conversation_id": result.get("conversation_id"), # 所有节点输出(详细数据)payload., # 会话 ID
|
||||||
@@ -559,6 +575,7 @@ class WorkflowService:
|
|||||||
"""运行工作流(流式)
|
"""运行工作流(流式)
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
workspace_id:
|
||||||
app_id: 应用 ID
|
app_id: 应用 ID
|
||||||
payload: 请求对象(包含 message, variables, conversation_id 等)
|
payload: 请求对象(包含 message, variables, conversation_id 等)
|
||||||
config: 存储类型(可选)
|
config: 存储类型(可选)
|
||||||
@@ -601,7 +618,7 @@ class WorkflowService:
|
|||||||
workflow_config_id=config.id,
|
workflow_config_id=config.id,
|
||||||
app_id=app_id,
|
app_id=app_id,
|
||||||
trigger_type="manual",
|
trigger_type="manual",
|
||||||
triggered_by=triggered_by_uuid,
|
triggered_by=None,
|
||||||
conversation_id=conversation_id_uuid,
|
conversation_id=conversation_id_uuid,
|
||||||
input_data=input_data
|
input_data=input_data
|
||||||
)
|
)
|
||||||
@@ -638,17 +655,46 @@ class WorkflowService:
|
|||||||
variables = last_state.get("variables", {})
|
variables = last_state.get("variables", {})
|
||||||
conv_vars = variables.get("conv", {})
|
conv_vars = variables.get("conv", {})
|
||||||
input_data["conv"] = conv_vars
|
input_data["conv"] = conv_vars
|
||||||
|
input_data["conv_messages"] = last_state.get("messages") or []
|
||||||
break
|
break
|
||||||
|
init_message_length = len(input_data.get("conv_messages", []))
|
||||||
|
from app.core.workflow.executor import execute_workflow_stream
|
||||||
|
|
||||||
# 调用流式执行(executor 会发送 workflow_start 和 workflow_end 事件)
|
async for event in execute_workflow_stream(
|
||||||
async for event in self._run_workflow_stream(
|
|
||||||
workflow_config=workflow_config_dict,
|
workflow_config=workflow_config_dict,
|
||||||
input_data=input_data,
|
input_data=input_data,
|
||||||
execution_id=execution.execution_id,
|
execution_id=execution.execution_id,
|
||||||
workspace_id=str(workspace_id),
|
workspace_id=str(workspace_id),
|
||||||
user_id=end_user_id
|
user_id=end_user_id
|
||||||
):
|
):
|
||||||
# 直接转发 executor 的事件(已经是正确的格式)
|
if event.get("event") == "workflow_end":
|
||||||
|
|
||||||
|
status = event.get("data", {}).get("status")
|
||||||
|
if status == "completed":
|
||||||
|
self.update_execution_status(
|
||||||
|
execution.execution_id,
|
||||||
|
"completed",
|
||||||
|
output_data=event.get("data")
|
||||||
|
)
|
||||||
|
final_messages = event.get("data", {}).get("messages", [])[init_message_length:]
|
||||||
|
for message in final_messages:
|
||||||
|
message_obj = Message(
|
||||||
|
conversation_id=conversation_id_uuid,
|
||||||
|
role=message["role"],
|
||||||
|
content=message["content"],
|
||||||
|
)
|
||||||
|
self.message_repo.add_message(message_obj)
|
||||||
|
self.db.commit()
|
||||||
|
logger.info(f"Workflow Run Success, "
|
||||||
|
f"execution_id: {execution.execution_id}, message count: {len(final_messages)}")
|
||||||
|
elif status == "failed":
|
||||||
|
self.update_execution_status(
|
||||||
|
execution.execution_id,
|
||||||
|
"failed",
|
||||||
|
output_data=event.get("data")
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.error(f"unexpect workflow run status, status: {status}")
|
||||||
yield event
|
yield event
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -667,6 +713,8 @@ class WorkflowService:
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@deprecated(reason="This method is deprecated. "
|
||||||
|
"Please use WorkflowService.run / run_stream instead.")
|
||||||
async def run_workflow(
|
async def run_workflow(
|
||||||
self,
|
self,
|
||||||
app_id: uuid.UUID,
|
app_id: uuid.UUID,
|
||||||
@@ -819,6 +867,7 @@ class WorkflowService:
|
|||||||
|
|
||||||
return clean_value(event)
|
return clean_value(event)
|
||||||
|
|
||||||
|
@deprecated(reason="This method is deprecated. Please use WorkflowService.run_stream instead.")
|
||||||
async def _run_workflow_stream(
|
async def _run_workflow_stream(
|
||||||
self,
|
self,
|
||||||
workflow_config: dict[str, Any],
|
workflow_config: dict[str, Any],
|
||||||
|
|||||||
@@ -136,7 +136,8 @@ dependencies = [
|
|||||||
"markdown-to-json==2.1.1",
|
"markdown-to-json==2.1.1",
|
||||||
"valkey==6.0.2",
|
"valkey==6.0.2",
|
||||||
"python-calamine>=0.4.0",
|
"python-calamine>=0.4.0",
|
||||||
"xlrd==2.0.2"
|
"xlrd==2.0.2",
|
||||||
|
"deprecated>=1.3.1",
|
||||||
]
|
]
|
||||||
|
|
||||||
[tool.pytest.ini_options]
|
[tool.pytest.ini_options]
|
||||||
|
|||||||
Reference in New Issue
Block a user