Merge #104 into develop from feature/20251219_myh
feat(workflow): add support for question classifier in graph construction
* feature/20251219_myh: (11 commits)
feat(workflow): support variable types(TODO)
fix(workflow): fix passing of loop variable termination condition
feat(workflow): add support for passing workspace ID
feat(workflow): support retrieving variables wrapped in {{}} from variable pool
feat(prompt_opt): support streaming output for prompt optimization API
feat(workflow): update workflow conditional logic
feat(workflow): enable front-end to cover pre-rendered non-variable values
fix(workflow): ensure default values are properly retrieved in HTTP nodes
refactor(workflow): refactor graph construction to support subgraph building
Merge branch 'develop' into feature/20251219_myh
feat(workflow): add support for question classifier in graph construction
Signed-off-by: Eternity <1533512157@qq.com>
Reviewed-by: zhuwenhui5566@163.com <zhuwenhui5566@163.com>
Merged-by: zhuwenhui5566@163.com <zhuwenhui5566@163.com>
CR-link: https://codeup.aliyun.com/redbearai/python/redbear-mem-open/change/104
This commit is contained in:
@@ -597,7 +597,8 @@ async def draft_run(
|
|||||||
async for event in workflow_service.run_stream(
|
async for event in workflow_service.run_stream(
|
||||||
app_id=app_id,
|
app_id=app_id,
|
||||||
payload=payload,
|
payload=payload,
|
||||||
config=config
|
config=config,
|
||||||
|
workspace_id=current_user.current_workspace_id
|
||||||
):
|
):
|
||||||
# 提取事件类型和数据
|
# 提取事件类型和数据
|
||||||
event_type = event.get("event", "message")
|
event_type = event.get("event", "message")
|
||||||
@@ -627,7 +628,7 @@ async def draft_run(
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
result = await workflow_service.run(app_id, payload,config)
|
result = await workflow_service.run(app_id, payload, config, current_user.current_workspace_id)
|
||||||
|
|
||||||
logger.debug(
|
logger.debug(
|
||||||
"工作流试运行返回结果",
|
"工作流试运行返回结果",
|
||||||
|
|||||||
@@ -1,7 +1,9 @@
|
|||||||
import uuid
|
import uuid
|
||||||
|
import json
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, Path
|
from fastapi import APIRouter, Depends, Path
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
from starlette.responses import StreamingResponse
|
||||||
|
|
||||||
from app.core.logging_config import get_api_logger
|
from app.core.logging_config import get_api_logger
|
||||||
from app.core.response_utils import success
|
from app.core.response_utils import success
|
||||||
@@ -104,35 +106,25 @@ async def get_prompt_opt(
|
|||||||
ApiResponse: Contains the optimized prompt, description, and a list of variables.
|
ApiResponse: Contains the optimized prompt, description, and a list of variables.
|
||||||
"""
|
"""
|
||||||
service = PromptOptimizerService(db)
|
service = PromptOptimizerService(db)
|
||||||
service.create_message(
|
|
||||||
tenant_id=current_user.tenant_id,
|
|
||||||
session_id=session_id,
|
|
||||||
user_id=current_user.id,
|
|
||||||
role=RoleType.USER,
|
|
||||||
content=data.message
|
|
||||||
)
|
|
||||||
opt_result = await service.optimize_prompt(
|
|
||||||
tenant_id=current_user.tenant_id,
|
|
||||||
model_id=data.model_id,
|
|
||||||
session_id=session_id,
|
|
||||||
user_id=current_user.id,
|
|
||||||
current_prompt=data.current_prompt,
|
|
||||||
user_require=data.message
|
|
||||||
)
|
|
||||||
service.create_message(
|
|
||||||
tenant_id=current_user.tenant_id,
|
|
||||||
session_id=session_id,
|
|
||||||
user_id=current_user.id,
|
|
||||||
role=RoleType.ASSISTANT,
|
|
||||||
content=opt_result.desc
|
|
||||||
)
|
|
||||||
variables = service.parser_prompt_variables(opt_result.prompt)
|
|
||||||
result = {
|
|
||||||
"prompt": opt_result.prompt,
|
|
||||||
"desc": opt_result.desc,
|
|
||||||
"variables": variables
|
|
||||||
}
|
|
||||||
result_schema = OptimizePromptResponse.model_validate(result)
|
|
||||||
return success(data=result_schema)
|
|
||||||
|
|
||||||
|
async def event_generator():
|
||||||
|
async for chunk in service.optimize_prompt(
|
||||||
|
tenant_id=current_user.tenant_id,
|
||||||
|
model_id=data.model_id,
|
||||||
|
session_id=session_id,
|
||||||
|
user_id=current_user.id,
|
||||||
|
current_prompt=data.current_prompt,
|
||||||
|
user_require=data.message
|
||||||
|
):
|
||||||
|
# chunk 是 prompt 的增量内容
|
||||||
|
yield f"event:'message'\ndata: {json.dumps(chunk)}\n\n"
|
||||||
|
|
||||||
|
return StreamingResponse(
|
||||||
|
event_generator(),
|
||||||
|
media_type="text/event-stream",
|
||||||
|
headers={
|
||||||
|
"Cache-Control": "no-cache",
|
||||||
|
"Connection": "keep-alive",
|
||||||
|
"X-Accel-Buffering": "no"
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|||||||
@@ -10,11 +10,10 @@ import logging
|
|||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from langchain_core.messages import HumanMessage
|
from langchain_core.messages import HumanMessage
|
||||||
from langgraph.graph import StateGraph, START, END
|
|
||||||
from langgraph.graph.state import CompiledStateGraph
|
from langgraph.graph.state import CompiledStateGraph
|
||||||
|
|
||||||
from app.core.workflow.expression_evaluator import evaluate_condition
|
from app.core.workflow.graph_builder import GraphBuilder
|
||||||
from app.core.workflow.nodes import WorkflowState, NodeFactory
|
from app.core.workflow.nodes import WorkflowState
|
||||||
from app.core.workflow.nodes.enums import NodeType
|
from app.core.workflow.nodes.enums import NodeType
|
||||||
|
|
||||||
# from app.core.tools.registry import ToolRegistry
|
# from app.core.tools.registry import ToolRegistry
|
||||||
@@ -191,155 +190,10 @@ class WorkflowExecutor:
|
|||||||
编译后的状态图
|
编译后的状态图
|
||||||
"""
|
"""
|
||||||
logger.info(f"开始构建工作流图: execution_id={self.execution_id}")
|
logger.info(f"开始构建工作流图: execution_id={self.execution_id}")
|
||||||
|
graph = GraphBuilder(
|
||||||
# 分析 End 节点的前缀配置和相邻且被引用的节点
|
self.workflow_config,
|
||||||
end_prefixes, adjacent_and_referenced = self._analyze_end_node_prefixes() if stream else ({}, set())
|
stream=stream,
|
||||||
|
).build()
|
||||||
# 1. 创建状态图
|
|
||||||
workflow = StateGraph(WorkflowState)
|
|
||||||
|
|
||||||
# 2. 添加所有节点(包括 start 和 end)
|
|
||||||
start_node_id = None
|
|
||||||
end_node_ids = []
|
|
||||||
|
|
||||||
for node in self.nodes:
|
|
||||||
node_type = node.get("type")
|
|
||||||
node_id = node.get("id")
|
|
||||||
cycle_node = node.get("cycle")
|
|
||||||
if cycle_node:
|
|
||||||
# 处于循环子图中的节点由 CycleGraphNode 进行构建处理
|
|
||||||
continue
|
|
||||||
|
|
||||||
# 记录 start 和 end 节点 ID
|
|
||||||
if node_type == NodeType.START:
|
|
||||||
start_node_id = node_id
|
|
||||||
elif node_type == NodeType.END:
|
|
||||||
end_node_ids.append(node_id)
|
|
||||||
|
|
||||||
# 创建节点实例(现在 start 和 end 也会被创建)
|
|
||||||
node_instance = NodeFactory.create_node(node, self.workflow_config)
|
|
||||||
|
|
||||||
if node_type in [NodeType.IF_ELSE, NodeType.HTTP_REQUEST, NodeType.QUESTION_CLASSIFIER]:
|
|
||||||
|
|
||||||
# Find all edges whose source is the current node
|
|
||||||
related_edge = [edge for edge in self.edges if edge.get("source") == node_id]
|
|
||||||
|
|
||||||
# Iterate over each branch
|
|
||||||
for idx in range(len(related_edge)):
|
|
||||||
# Generate a condition expression for each edge
|
|
||||||
# Used later to determine which branch to take based on the node's output
|
|
||||||
# Assumes node output `node.<node_id>.output` matches the edge's label
|
|
||||||
# For example, if node.123.output == 'CASE1', take the branch labeled 'CASE1'
|
|
||||||
related_edge[idx]['condition'] = f"node.{node_id}.output == '{related_edge[idx]['label']}'"
|
|
||||||
|
|
||||||
if node_instance:
|
|
||||||
# 如果是流式模式,且节点有 End 前缀配置,注入配置
|
|
||||||
if stream and node_id in end_prefixes:
|
|
||||||
# 将 End 前缀配置注入到节点实例
|
|
||||||
node_instance._end_node_prefix = end_prefixes[node_id]
|
|
||||||
logger.info(f"为节点 {node_id} 注入 End 前缀配置")
|
|
||||||
|
|
||||||
# 如果是流式模式,标记节点是否与 End 相邻且被引用
|
|
||||||
if stream:
|
|
||||||
node_instance._is_adjacent_to_end = node_id in adjacent_and_referenced
|
|
||||||
if node_id in adjacent_and_referenced:
|
|
||||||
logger.info(f"节点 {node_id} 标记为与 End 相邻且被引用")
|
|
||||||
|
|
||||||
# 包装节点的 run 方法
|
|
||||||
# 使用函数工厂避免闭包问题
|
|
||||||
if stream:
|
|
||||||
# 流式模式:创建 async generator 函数
|
|
||||||
# LangGraph 会收集所有 yield 的值,最后一个 yield 的字典会被合并到 state
|
|
||||||
def make_stream_func(inst):
|
|
||||||
async def node_func(state: WorkflowState):
|
|
||||||
# logger.debug(f"流式执行节点: {inst.node_id}, 支持流式: {inst.supports_streaming()}")
|
|
||||||
async for item in inst.run_stream(state):
|
|
||||||
yield item
|
|
||||||
|
|
||||||
return node_func
|
|
||||||
|
|
||||||
workflow.add_node(node_id, make_stream_func(node_instance))
|
|
||||||
else:
|
|
||||||
# 非流式模式:创建 async function
|
|
||||||
def make_func(inst):
|
|
||||||
async def node_func(state: WorkflowState):
|
|
||||||
return await inst.run(state)
|
|
||||||
|
|
||||||
return node_func
|
|
||||||
|
|
||||||
workflow.add_node(node_id, make_func(node_instance))
|
|
||||||
|
|
||||||
logger.debug(f"添加节点: {node_id} (type={node_type}, stream={stream})")
|
|
||||||
|
|
||||||
# 3. 添加边
|
|
||||||
# 从 START 连接到 start 节点
|
|
||||||
if start_node_id:
|
|
||||||
workflow.add_edge(START, start_node_id)
|
|
||||||
logger.debug(f"添加边: START -> {start_node_id}")
|
|
||||||
|
|
||||||
for edge in self.workflow_config.get("edges", []):
|
|
||||||
source = edge.get("source")
|
|
||||||
target = edge.get("target")
|
|
||||||
edge_type = edge.get("type")
|
|
||||||
condition = edge.get("condition")
|
|
||||||
|
|
||||||
# 跳过从 start 节点出发的边(因为已经从 START 连接到 start)
|
|
||||||
if source == start_node_id:
|
|
||||||
# 但要连接 start 到下一个节点
|
|
||||||
workflow.add_edge(source, target)
|
|
||||||
logger.debug(f"添加边: {source} -> {target}")
|
|
||||||
continue
|
|
||||||
|
|
||||||
# # 处理到 end 节点的边
|
|
||||||
# if target in end_node_ids:
|
|
||||||
# # 连接到 end 节点
|
|
||||||
# workflow.add_edge(source, target)
|
|
||||||
# logger.debug(f"添加边: {source} -> {target}")
|
|
||||||
# continue
|
|
||||||
|
|
||||||
# 跳过错误边(在节点内部处理)
|
|
||||||
if edge_type == "error":
|
|
||||||
continue
|
|
||||||
|
|
||||||
if condition:
|
|
||||||
# 条件边
|
|
||||||
def make_router(cond, tgt):
|
|
||||||
"""Dynamically generate a conditional router function to ensure each branch has a unique name."""
|
|
||||||
|
|
||||||
|
|
||||||
def router_fn(state: WorkflowState):
|
|
||||||
if evaluate_condition(
|
|
||||||
cond,
|
|
||||||
state.get("variables", {}),
|
|
||||||
state.get("node_outputs", {}),
|
|
||||||
{
|
|
||||||
"execution_id": state.get("execution_id"),
|
|
||||||
"workspace_id": state.get("workspace_id"),
|
|
||||||
"user_id": state.get("user_id")
|
|
||||||
}
|
|
||||||
):
|
|
||||||
return tgt
|
|
||||||
return END
|
|
||||||
|
|
||||||
# 动态修改函数名,避免重复
|
|
||||||
router_fn.__name__ = f"router_{tgt}"
|
|
||||||
return router_fn
|
|
||||||
|
|
||||||
router_fn = make_router(condition, target)
|
|
||||||
workflow.add_conditional_edges(source, router_fn)
|
|
||||||
logger.debug(f"添加条件边: {source} -> {target} (condition={condition})")
|
|
||||||
else:
|
|
||||||
# 普通边
|
|
||||||
workflow.add_edge(source, target)
|
|
||||||
logger.debug(f"添加边: {source} -> {target}")
|
|
||||||
|
|
||||||
# 从 end 节点连接到 END
|
|
||||||
for end_node_id in end_node_ids:
|
|
||||||
workflow.add_edge(end_node_id, END)
|
|
||||||
logger.debug(f"添加边: {end_node_id} -> END")
|
|
||||||
|
|
||||||
# 4. 编译图
|
|
||||||
graph = workflow.compile()
|
|
||||||
logger.info(f"工作流图构建完成: execution_id={self.execution_id}")
|
logger.info(f"工作流图构建完成: execution_id={self.execution_id}")
|
||||||
|
|
||||||
return graph
|
return graph
|
||||||
|
|||||||
253
api/app/core/workflow/graph_builder.py
Normal file
253
api/app/core/workflow/graph_builder.py
Normal file
@@ -0,0 +1,253 @@
|
|||||||
|
import logging
|
||||||
|
import uuid
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from langgraph.graph.state import CompiledStateGraph, StateGraph
|
||||||
|
from langgraph.graph import START, END
|
||||||
|
|
||||||
|
from app.core.workflow.expression_evaluator import evaluate_condition
|
||||||
|
from app.core.workflow.nodes import WorkflowState, NodeFactory
|
||||||
|
from app.core.workflow.nodes.enums import NodeType
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
# TODO: 子图拆解支持
|
||||||
|
class GraphBuilder:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
workflow_config: dict[str, Any],
|
||||||
|
stream: bool = False,
|
||||||
|
subgraph: bool = False,
|
||||||
|
):
|
||||||
|
self.workflow_config = workflow_config
|
||||||
|
|
||||||
|
self.stream = stream
|
||||||
|
self.subgraph = subgraph
|
||||||
|
|
||||||
|
self.start_node_id = None
|
||||||
|
self.end_node_ids = []
|
||||||
|
|
||||||
|
self.graph: StateGraph | CompiledStateGraph | None = None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def nodes(self) -> list[dict[str, Any]]:
|
||||||
|
return self.workflow_config.get("nodes", [])
|
||||||
|
|
||||||
|
@property
|
||||||
|
def edges(self) -> list[dict[str, Any]]:
|
||||||
|
return self.workflow_config.get("edges", [])
|
||||||
|
|
||||||
|
def _analyze_end_node_prefixes(self) -> tuple[dict[str, str], set[str]]:
|
||||||
|
"""分析 End 节点的前缀配置
|
||||||
|
|
||||||
|
检查每个 End 节点的模板,找到直接上游节点的引用,
|
||||||
|
提取该引用之前的前缀部分。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
元组:({上游节点ID: End节点前缀}, {与End相邻且被引用的节点ID集合})
|
||||||
|
"""
|
||||||
|
import re
|
||||||
|
|
||||||
|
prefixes = {}
|
||||||
|
adjacent_and_referenced = set() # 记录与 End 节点相邻且被引用的节点
|
||||||
|
|
||||||
|
# 找到所有 End 节点
|
||||||
|
end_nodes = [node for node in self.nodes if node.get("type") == "end"]
|
||||||
|
logger.info(f"[前缀分析] 找到 {len(end_nodes)} 个 End 节点")
|
||||||
|
|
||||||
|
for end_node in end_nodes:
|
||||||
|
end_node_id = end_node.get("id")
|
||||||
|
output_template = end_node.get("config", {}).get("output")
|
||||||
|
|
||||||
|
logger.info(f"[前缀分析] End 节点 {end_node_id} 模板: {output_template}")
|
||||||
|
|
||||||
|
if not output_template:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 查找模板中引用了哪些节点
|
||||||
|
# 匹配 {{node_id.xxx}} 或 {{ node_id.xxx }} 格式(支持空格)
|
||||||
|
pattern = r'\{\{\s*([a-zA-Z0-9_-]+)\.[a-zA-Z0-9_]+\s*\}\}'
|
||||||
|
matches = list(re.finditer(pattern, output_template))
|
||||||
|
|
||||||
|
logger.info(f"[前缀分析] 模板中找到 {len(matches)} 个节点引用")
|
||||||
|
|
||||||
|
# 找到所有直接连接到 End 节点的上游节点
|
||||||
|
direct_upstream_nodes = []
|
||||||
|
for edge in self.edges:
|
||||||
|
if edge.get("target") == end_node_id:
|
||||||
|
source_node_id = edge.get("source")
|
||||||
|
direct_upstream_nodes.append(source_node_id)
|
||||||
|
|
||||||
|
logger.info(f"[前缀分析] End 节点的直接上游节点: {direct_upstream_nodes}")
|
||||||
|
|
||||||
|
# 找到第一个直接上游节点的引用
|
||||||
|
for match in matches:
|
||||||
|
referenced_node_id = match.group(1)
|
||||||
|
logger.info(f"[前缀分析] 检查引用: {referenced_node_id}")
|
||||||
|
|
||||||
|
if referenced_node_id in direct_upstream_nodes:
|
||||||
|
# 这是直接上游节点的引用,提取前缀
|
||||||
|
prefix = output_template[:match.start()]
|
||||||
|
|
||||||
|
logger.info(f"[前缀分析] ✅ 找到直接上游节点 {referenced_node_id} 的引用,前缀: '{prefix}'")
|
||||||
|
|
||||||
|
# 标记这个节点为"相邻且被引用"
|
||||||
|
adjacent_and_referenced.add(referenced_node_id)
|
||||||
|
|
||||||
|
if prefix:
|
||||||
|
prefixes[referenced_node_id] = prefix
|
||||||
|
logger.info(f"✅ [前缀分析] 为节点 {referenced_node_id} 配置前缀: '{prefix[:50]}...'")
|
||||||
|
|
||||||
|
# 只处理第一个直接上游节点的引用
|
||||||
|
break
|
||||||
|
|
||||||
|
logger.info(f"[前缀分析] 最终配置: {prefixes}")
|
||||||
|
logger.info(f"[前缀分析] 与 End 相邻且被引用的节点: {adjacent_and_referenced}")
|
||||||
|
return prefixes, adjacent_and_referenced
|
||||||
|
|
||||||
|
def add_nodes(self):
|
||||||
|
end_prefixes, adjacent_and_referenced = self._analyze_end_node_prefixes() if self.stream else ({}, set())
|
||||||
|
|
||||||
|
for node in self.nodes:
|
||||||
|
node_type = node.get("type")
|
||||||
|
node_id = node.get("id")
|
||||||
|
cycle_node = node.get("cycle")
|
||||||
|
if cycle_node:
|
||||||
|
# 处于循环子图中的节点由 CycleGraphNode 进行构建处理
|
||||||
|
if not self.subgraph:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 记录 start 和 end 节点 ID
|
||||||
|
if node_type in [NodeType.START, NodeType.CYCLE_START]:
|
||||||
|
self.start_node_id = node_id
|
||||||
|
elif node_type == NodeType.END:
|
||||||
|
self.end_node_ids.append(node_id)
|
||||||
|
|
||||||
|
# 创建节点实例(现在 start 和 end 也会被创建)
|
||||||
|
# NOTE:Loop node creation automatically removes the nodes and edges of the subgraph from the current graph
|
||||||
|
node_instance = NodeFactory.create_node(node, self.workflow_config)
|
||||||
|
|
||||||
|
if node_type in [NodeType.IF_ELSE, NodeType.HTTP_REQUEST, NodeType.QUESTION_CLASSIFIER]:
|
||||||
|
|
||||||
|
# Find all edges whose source is the current node
|
||||||
|
related_edge = [edge for edge in self.edges if edge.get("source") == node_id]
|
||||||
|
|
||||||
|
# Iterate over each branch
|
||||||
|
for idx in range(len(related_edge)):
|
||||||
|
# Generate a condition expression for each edge
|
||||||
|
# Used later to determine which branch to take based on the node's output
|
||||||
|
# Assumes node output `node.<node_id>.output` matches the edge's label
|
||||||
|
# For example, if node.123.output == 'CASE1', take the branch labeled 'CASE1'
|
||||||
|
related_edge[idx]['condition'] = f"node.{node_id}.output == '{related_edge[idx]['label']}'"
|
||||||
|
|
||||||
|
if node_instance:
|
||||||
|
# 如果是流式模式,且节点有 End 前缀配置,注入配置
|
||||||
|
if self.stream and node_id in end_prefixes:
|
||||||
|
# 将 End 前缀配置注入到节点实例
|
||||||
|
node_instance._end_node_prefix = end_prefixes[node_id]
|
||||||
|
logger.info(f"为节点 {node_id} 注入 End 前缀配置")
|
||||||
|
|
||||||
|
# 如果是流式模式,标记节点是否与 End 相邻且被引用
|
||||||
|
if self.stream:
|
||||||
|
node_instance._is_adjacent_to_end = node_id in adjacent_and_referenced
|
||||||
|
if node_id in adjacent_and_referenced:
|
||||||
|
logger.info(f"节点 {node_id} 标记为与 End 相邻且被引用")
|
||||||
|
|
||||||
|
# 包装节点的 run 方法
|
||||||
|
# 使用函数工厂避免闭包问题
|
||||||
|
if self.stream:
|
||||||
|
# 流式模式:创建 async generator 函数
|
||||||
|
# LangGraph 会收集所有 yield 的值,最后一个 yield 的字典会被合并到 state
|
||||||
|
def make_stream_func(inst):
|
||||||
|
async def node_func(state: WorkflowState):
|
||||||
|
# logger.debug(f"流式执行节点: {inst.node_id}, 支持流式: {inst.supports_streaming()}")
|
||||||
|
async for item in inst.run_stream(state):
|
||||||
|
yield item
|
||||||
|
|
||||||
|
return node_func
|
||||||
|
|
||||||
|
self.graph.add_node(node_id, make_stream_func(node_instance))
|
||||||
|
else:
|
||||||
|
# 非流式模式:创建 async function
|
||||||
|
def make_func(inst):
|
||||||
|
async def node_func(state: WorkflowState):
|
||||||
|
return await inst.run(state)
|
||||||
|
|
||||||
|
return node_func
|
||||||
|
|
||||||
|
self.graph.add_node(node_id, make_func(node_instance))
|
||||||
|
|
||||||
|
logger.debug(f"添加节点: {node_id} (type={node_type}, stream={self.stream})")
|
||||||
|
|
||||||
|
def add_edges(self):
|
||||||
|
if self.start_node_id:
|
||||||
|
self.graph.add_edge(START, self.start_node_id)
|
||||||
|
logger.debug(f"添加边: START -> {self.start_node_id}")
|
||||||
|
|
||||||
|
for edge in self.edges:
|
||||||
|
source = edge.get("source")
|
||||||
|
target = edge.get("target")
|
||||||
|
edge_type = edge.get("type")
|
||||||
|
condition = edge.get("condition")
|
||||||
|
|
||||||
|
# 跳过从 start 节点出发的边(因为已经从 START 连接到 start)
|
||||||
|
if source == self.start_node_id:
|
||||||
|
# 但要连接 start 到下一个节点
|
||||||
|
self.graph.add_edge(source, target)
|
||||||
|
logger.debug(f"添加边: {source} -> {target}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
# # 处理到 end 节点的边
|
||||||
|
# if target in end_node_ids:
|
||||||
|
# # 连接到 end 节点
|
||||||
|
# workflow.add_edge(source, target)
|
||||||
|
# logger.debug(f"添加边: {source} -> {target}")
|
||||||
|
# continue
|
||||||
|
|
||||||
|
# 跳过错误边(在节点内部处理)
|
||||||
|
if edge_type == "error":
|
||||||
|
continue
|
||||||
|
|
||||||
|
if condition:
|
||||||
|
# 条件边
|
||||||
|
def make_router(cond, tgt):
|
||||||
|
"""Dynamically generate a conditional router function to ensure each branch has a unique name."""
|
||||||
|
|
||||||
|
def router_fn(state: WorkflowState):
|
||||||
|
if evaluate_condition(
|
||||||
|
cond,
|
||||||
|
state.get("variables", {}),
|
||||||
|
state.get("runtime_vars", {}),
|
||||||
|
{
|
||||||
|
"execution_id": state.get("execution_id"),
|
||||||
|
"workspace_id": state.get("workspace_id"),
|
||||||
|
"user_id": state.get("user_id")
|
||||||
|
}
|
||||||
|
):
|
||||||
|
return tgt
|
||||||
|
return END
|
||||||
|
|
||||||
|
# 动态修改函数名,避免重复
|
||||||
|
router_fn.__name__ = f"router_{uuid.uuid4().hex[:8]}_{tgt}"
|
||||||
|
return router_fn
|
||||||
|
|
||||||
|
router_fn = make_router(condition, target)
|
||||||
|
self.graph.add_conditional_edges(source, router_fn)
|
||||||
|
logger.debug(f"添加条件边: {source} -> {target} (condition={condition})")
|
||||||
|
else:
|
||||||
|
# 普通边
|
||||||
|
self.graph.add_edge(source, target)
|
||||||
|
logger.debug(f"添加边: {source} -> {target}")
|
||||||
|
|
||||||
|
# 从 end 节点连接到 END
|
||||||
|
for end_node_id in self.end_node_ids:
|
||||||
|
self.graph.add_edge(end_node_id, END)
|
||||||
|
logger.debug(f"添加边: {end_node_id} -> END")
|
||||||
|
return
|
||||||
|
|
||||||
|
def build(self) -> CompiledStateGraph:
|
||||||
|
self.graph = StateGraph(WorkflowState)
|
||||||
|
self.add_nodes()
|
||||||
|
self.add_edges() # 添加边必须在添加节点之后
|
||||||
|
return self.graph.compile()
|
||||||
@@ -1,3 +1,5 @@
|
|||||||
|
from typing import Any
|
||||||
|
|
||||||
from pydantic import Field, BaseModel
|
from pydantic import Field, BaseModel
|
||||||
|
|
||||||
from app.core.workflow.nodes.base_config import BaseNodeConfig
|
from app.core.workflow.nodes.base_config import BaseNodeConfig
|
||||||
@@ -19,7 +21,7 @@ class AssignmentItem(BaseModel):
|
|||||||
description="Assignment operator",
|
description="Assignment operator",
|
||||||
)
|
)
|
||||||
|
|
||||||
value: str | list[str] = Field(
|
value: Any = Field(
|
||||||
...,
|
...,
|
||||||
description="Value(s) to assign to the variable(s)",
|
description="Value(s) to assign to the variable(s)",
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -2,7 +2,6 @@ import logging
|
|||||||
import re
|
import re
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from app.core.workflow.expression_evaluator import ExpressionEvaluator
|
|
||||||
from app.core.workflow.nodes.assigner.config import AssignerNodeConfig
|
from app.core.workflow.nodes.assigner.config import AssignerNodeConfig
|
||||||
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
|
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
|
||||||
from app.core.workflow.nodes.enums import AssignmentOperator
|
from app.core.workflow.nodes.enums import AssignmentOperator
|
||||||
@@ -29,6 +28,7 @@ class AssignerNode(BaseNode):
|
|||||||
None or the result of the assignment operation.
|
None or the result of the assignment operation.
|
||||||
"""
|
"""
|
||||||
# Initialize a variable pool for accessing conversation, node, and system variables
|
# Initialize a variable pool for accessing conversation, node, and system variables
|
||||||
|
logger.info(f"节点 {self.node_id} 开始执行")
|
||||||
pool = VariablePool(state)
|
pool = VariablePool(state)
|
||||||
for assignment in self.typed_config.assignments:
|
for assignment in self.typed_config.assignments:
|
||||||
# Get the target variable selector (e.g., "conv.test")
|
# Get the target variable selector (e.g., "conv.test")
|
||||||
@@ -45,14 +45,13 @@ class AssignerNode(BaseNode):
|
|||||||
|
|
||||||
# Get the value or expression to assign
|
# Get the value or expression to assign
|
||||||
value = assignment.value
|
value = assignment.value
|
||||||
if isinstance(value, list):
|
pattern = r"\{\{\s*(.*?)\s*\}\}"
|
||||||
value = '.'.join(value)
|
if isinstance(value, str):
|
||||||
value = ExpressionEvaluator.evaluate(
|
expression = re.match(pattern, value)
|
||||||
expression=value,
|
if expression:
|
||||||
variables=pool.get_all_conversation_vars(),
|
expression = expression.group(1)
|
||||||
node_outputs=pool.get_all_node_outputs(),
|
expression = re.sub(pattern, r"\1", expression).strip()
|
||||||
system_vars=pool.get_all_system_vars(),
|
value = self.get_variable(expression, state)
|
||||||
)
|
|
||||||
|
|
||||||
# Select the appropriate assignment operator instance based on the target variable type
|
# Select the appropriate assignment operator instance based on the target variable type
|
||||||
operator: AssignmentOperatorInstance = AssignmentOperatorResolver.resolve_by_value(
|
operator: AssignmentOperatorInstance = AssignmentOperatorResolver.resolve_by_value(
|
||||||
@@ -63,6 +62,8 @@ class AssignerNode(BaseNode):
|
|||||||
|
|
||||||
# Execute the configured assignment operation
|
# Execute the configured assignment operation
|
||||||
match assignment.operation:
|
match assignment.operation:
|
||||||
|
case AssignmentOperator.COVER:
|
||||||
|
operator.assign()
|
||||||
case AssignmentOperator.ASSIGN:
|
case AssignmentOperator.ASSIGN:
|
||||||
operator.assign()
|
operator.assign()
|
||||||
case AssignmentOperator.CLEAR:
|
case AssignmentOperator.CLEAR:
|
||||||
|
|||||||
@@ -4,8 +4,11 @@
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
from enum import StrEnum
|
from enum import StrEnum
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field, ConfigDict
|
||||||
|
|
||||||
|
VARIABLE_PATTERN = r"\{\{\s*(.*?)\s*\}\}"
|
||||||
|
|
||||||
|
|
||||||
class VariableType(StrEnum):
|
class VariableType(StrEnum):
|
||||||
@@ -22,6 +25,57 @@ class VariableType(StrEnum):
|
|||||||
ARRAY_OBJECT = "array[object]"
|
ARRAY_OBJECT = "array[object]"
|
||||||
|
|
||||||
|
|
||||||
|
class TypedVariable(BaseModel):
|
||||||
|
"""
|
||||||
|
TODO: 强类型限制
|
||||||
|
Strongly typed variable that validates value on assignment.
|
||||||
|
"""
|
||||||
|
|
||||||
|
value: Any = Field(..., description="Variable value")
|
||||||
|
type: VariableType = Field(..., description="Declared type of the variable")
|
||||||
|
|
||||||
|
model_config = ConfigDict(
|
||||||
|
validate_assignment=True
|
||||||
|
)
|
||||||
|
|
||||||
|
def __setattr__(self, name, value):
|
||||||
|
if name == "value":
|
||||||
|
self._validate_value(value)
|
||||||
|
if name == "type":
|
||||||
|
raise RuntimeError("Cannot modify variable type at runtime")
|
||||||
|
super().__setattr__(name, value)
|
||||||
|
|
||||||
|
def _validate_value(self, v: Any):
|
||||||
|
t = self.type
|
||||||
|
match t:
|
||||||
|
case VariableType.STRING:
|
||||||
|
if not isinstance(v, str):
|
||||||
|
raise TypeError("Variable value does not match type STRING")
|
||||||
|
case VariableType.BOOLEAN:
|
||||||
|
if not isinstance(v, bool):
|
||||||
|
raise TypeError("Variable value does not match type BOOLEAN")
|
||||||
|
case VariableType.NUMBER:
|
||||||
|
if not isinstance(v, (int, float)):
|
||||||
|
raise TypeError("Variable value does not match type NUMBER")
|
||||||
|
case VariableType.OBJECT:
|
||||||
|
if not isinstance(v, dict):
|
||||||
|
raise TypeError("Variable value does not match type OBJECT")
|
||||||
|
case VariableType.ARRAY_STRING:
|
||||||
|
if not isinstance(v, list) or not all(isinstance(i, str) for i in v):
|
||||||
|
raise TypeError("Variable value does not match type ARRAY_STRING")
|
||||||
|
case VariableType.ARRAY_NUMBER:
|
||||||
|
if not isinstance(v, list) or not all(isinstance(i, (int, float)) for i in v):
|
||||||
|
raise TypeError("Variable value does not match type ARRAY_NUMBER")
|
||||||
|
case VariableType.ARRAY_BOOLEAN:
|
||||||
|
if not isinstance(v, list) or not all(isinstance(i, bool) for i in v):
|
||||||
|
raise TypeError("Variable value does not match type ARRAY_BOOLEAN")
|
||||||
|
case VariableType.ARRAY_OBJECT:
|
||||||
|
if not isinstance(v, list) or not all(isinstance(i, dict) for i in v):
|
||||||
|
raise TypeError("Variable value does not match type ARRAY_OBJECT")
|
||||||
|
case _:
|
||||||
|
raise TypeError(f"Unknown variable type: {t}")
|
||||||
|
|
||||||
|
|
||||||
class VariableDefinition(BaseModel):
|
class VariableDefinition(BaseModel):
|
||||||
"""变量定义
|
"""变量定义
|
||||||
|
|
||||||
|
|||||||
@@ -356,7 +356,8 @@ class BaseNode(ABC):
|
|||||||
**final_output,
|
**final_output,
|
||||||
"runtime_vars": {
|
"runtime_vars": {
|
||||||
self.node_id: runtime_var
|
self.node_id: runtime_var
|
||||||
}
|
},
|
||||||
|
"looping": state["looping"]
|
||||||
}
|
}
|
||||||
|
|
||||||
# Add streaming buffer for non-End nodes
|
# Add streaming buffer for non-End nodes
|
||||||
|
|||||||
@@ -1,7 +1,9 @@
|
|||||||
|
from typing import Any
|
||||||
|
|
||||||
from pydantic import Field, BaseModel
|
from pydantic import Field, BaseModel
|
||||||
|
|
||||||
from app.core.workflow.nodes.base_config import BaseNodeConfig, VariableType
|
from app.core.workflow.nodes.base_config import BaseNodeConfig, VariableType
|
||||||
from app.core.workflow.nodes.enums import ComparisonOperator, LogicOperator
|
from app.core.workflow.nodes.enums import ComparisonOperator, LogicOperator, ValueInputType
|
||||||
|
|
||||||
|
|
||||||
class CycleVariable(BaseNodeConfig):
|
class CycleVariable(BaseNodeConfig):
|
||||||
@@ -9,18 +11,25 @@ class CycleVariable(BaseNodeConfig):
|
|||||||
...,
|
...,
|
||||||
description="Name of the loop variable"
|
description="Name of the loop variable"
|
||||||
)
|
)
|
||||||
|
|
||||||
type: VariableType = Field(
|
type: VariableType = Field(
|
||||||
...,
|
...,
|
||||||
description="Data type of the loop variable"
|
description="Data type of the loop variable"
|
||||||
)
|
)
|
||||||
value: str = Field(
|
|
||||||
|
input_type: ValueInputType = Field(
|
||||||
|
...,
|
||||||
|
description="Input type of the loop variable"
|
||||||
|
)
|
||||||
|
|
||||||
|
value: Any = Field(
|
||||||
...,
|
...,
|
||||||
description="Initial or current value of the loop variable"
|
description="Initial or current value of the loop variable"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class ConditionDetail(BaseModel):
|
class ConditionDetail(BaseModel):
|
||||||
comparison_operator: ComparisonOperator = Field(
|
operator: ComparisonOperator = Field(
|
||||||
...,
|
...,
|
||||||
description="Operator used to compare the left and right operands"
|
description="Operator used to compare the left and right operands"
|
||||||
)
|
)
|
||||||
@@ -30,11 +39,16 @@ class ConditionDetail(BaseModel):
|
|||||||
description="Left-hand operand of the comparison expression"
|
description="Left-hand operand of the comparison expression"
|
||||||
)
|
)
|
||||||
|
|
||||||
right: str = Field(
|
right: Any = Field(
|
||||||
...,
|
...,
|
||||||
description="Right-hand operand of the comparison expression"
|
description="Right-hand operand of the comparison expression"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
input_type: ValueInputType = Field(
|
||||||
|
...,
|
||||||
|
description="Input type of the loop variable"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class ConditionsConfig(BaseModel):
|
class ConditionsConfig(BaseModel):
|
||||||
"""Configuration for loop condition evaluation"""
|
"""Configuration for loop condition evaluation"""
|
||||||
|
|||||||
@@ -3,10 +3,11 @@ from typing import Any
|
|||||||
|
|
||||||
from langgraph.graph.state import CompiledStateGraph
|
from langgraph.graph.state import CompiledStateGraph
|
||||||
|
|
||||||
from app.core.workflow.expression_evaluator import evaluate_condition, evaluate_expression
|
from app.core.workflow.expression_evaluator import evaluate_expression
|
||||||
from app.core.workflow.nodes import WorkflowState
|
from app.core.workflow.nodes import WorkflowState
|
||||||
from app.core.workflow.nodes.cycle_graph import LoopNodeConfig
|
from app.core.workflow.nodes.cycle_graph import LoopNodeConfig
|
||||||
from app.core.workflow.nodes.operators import ConditionExpressionBuilder
|
from app.core.workflow.nodes.enums import ValueInputType, ComparisonOperator, LogicOperator
|
||||||
|
from app.core.workflow.nodes.operators import TypeTransformer, ConditionExpressionResolver, CompareOperatorInstance
|
||||||
from app.core.workflow.variable_pool import VariablePool
|
from app.core.workflow.variable_pool import VariablePool
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -14,11 +15,13 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
class LoopRuntime:
|
class LoopRuntime:
|
||||||
"""
|
"""
|
||||||
Runtime executor for loop nodes in a workflow.
|
Runtime executor for a loop node in a workflow graph.
|
||||||
|
|
||||||
Handles iterative execution of a loop node according to defined loop variables
|
This class is responsible for executing a loop node at runtime:
|
||||||
and conditional expressions. Supports maximum loop count and loop control
|
- Initializing loop-scoped variables
|
||||||
through the workflow state.
|
- Evaluating loop continuation conditions
|
||||||
|
- Repeatedly invoking a compiled sub-graph
|
||||||
|
- Enforcing maximum loop count and external stop signals
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -29,13 +32,13 @@ class LoopRuntime:
|
|||||||
state: WorkflowState,
|
state: WorkflowState,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Initialize the loop runtime.
|
Initialize the loop runtime executor.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
graph: Compiled workflow graph capable of async invocation.
|
graph: A compiled LangGraph state graph representing the loop body.
|
||||||
node_id: Unique identifier of the loop node.
|
node_id: The unique identifier of the loop node in the workflow.
|
||||||
config: Dictionary containing loop node configuration.
|
config: Raw configuration dictionary for the loop node.
|
||||||
state: Current workflow state at the point of loop execution.
|
state: The current workflow state before entering the loop.
|
||||||
"""
|
"""
|
||||||
self.graph = graph
|
self.graph = graph
|
||||||
self.state = state
|
self.state = state
|
||||||
@@ -46,12 +49,15 @@ class LoopRuntime:
|
|||||||
"""
|
"""
|
||||||
Initialize workflow state for loop execution.
|
Initialize workflow state for loop execution.
|
||||||
|
|
||||||
- Evaluates initial values of loop variables.
|
This method:
|
||||||
- Stores loop variables in runtime_vars and node_outputs.
|
- Evaluates initial values of loop variables
|
||||||
- Marks the loop as active by setting 'looping' to True.
|
- Stores loop variables into both `runtime_vars` and `node_outputs`
|
||||||
|
under the current loop node's scope
|
||||||
|
- Creates a shallow copy of the workflow state
|
||||||
|
- Marks the loop as active by setting `looping = True`
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A copy of the workflow state prepared for the loop execution.
|
WorkflowState: A prepared workflow state used for loop execution.
|
||||||
"""
|
"""
|
||||||
pool = VariablePool(self.state)
|
pool = VariablePool(self.state)
|
||||||
# 循环变量
|
# 循环变量
|
||||||
@@ -61,7 +67,7 @@ class LoopRuntime:
|
|||||||
variables=pool.get_all_conversation_vars(),
|
variables=pool.get_all_conversation_vars(),
|
||||||
node_outputs=pool.get_all_node_outputs(),
|
node_outputs=pool.get_all_node_outputs(),
|
||||||
system_vars=pool.get_all_system_vars(),
|
system_vars=pool.get_all_system_vars(),
|
||||||
)
|
) if variable.input_type == ValueInputType.VARIABLE else TypeTransformer.transform(variable.value, variable.type)
|
||||||
for variable in self.typed_config.cycle_vars
|
for variable in self.typed_config.cycle_vars
|
||||||
}
|
}
|
||||||
self.state["node_outputs"][self.node_id] = {
|
self.state["node_outputs"][self.node_id] = {
|
||||||
@@ -70,7 +76,7 @@ class LoopRuntime:
|
|||||||
variables=pool.get_all_conversation_vars(),
|
variables=pool.get_all_conversation_vars(),
|
||||||
node_outputs=pool.get_all_node_outputs(),
|
node_outputs=pool.get_all_node_outputs(),
|
||||||
system_vars=pool.get_all_system_vars(),
|
system_vars=pool.get_all_system_vars(),
|
||||||
)
|
) if variable.input_type == ValueInputType.VARIABLE else TypeTransformer.transform(variable.value, variable.type)
|
||||||
for variable in self.typed_config.cycle_vars
|
for variable in self.typed_config.cycle_vars
|
||||||
}
|
}
|
||||||
loopstate = WorkflowState(
|
loopstate = WorkflowState(
|
||||||
@@ -79,49 +85,93 @@ class LoopRuntime:
|
|||||||
loopstate["looping"] = True
|
loopstate["looping"] = True
|
||||||
return loopstate
|
return loopstate
|
||||||
|
|
||||||
def _get_loop_expression(self):
|
@staticmethod
|
||||||
|
def _evaluate(operator, instance: CompareOperatorInstance) -> Any:
|
||||||
"""
|
"""
|
||||||
Build the Python boolean expression for evaluating the loop condition.
|
Dispatch and execute a comparison operator against a resolved
|
||||||
|
CompareOperatorInstance.
|
||||||
|
|
||||||
- Converts each condition in the loop configuration into a Python expression string.
|
Args:
|
||||||
- Combines multiple conditions with the configured logical operator (AND/OR).
|
operator: A ComparisonOperator enum value.
|
||||||
|
instance: A CompareOperatorInstance bound to concrete operands.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A string representing the combined loop condition expression.
|
Any: The evaluation result, typically a boolean.
|
||||||
"""
|
"""
|
||||||
branch_conditions = [
|
match operator:
|
||||||
ConditionExpressionBuilder(
|
case ComparisonOperator.EMPTY:
|
||||||
left=condition.left,
|
return instance.empty()
|
||||||
operator=condition.comparison_operator,
|
case ComparisonOperator.NOT_EMPTY:
|
||||||
right=condition.right
|
return instance.not_empty()
|
||||||
).build()
|
case ComparisonOperator.CONTAINS:
|
||||||
for condition in self.typed_config.condition.expressions
|
return instance.contains()
|
||||||
]
|
case ComparisonOperator.NOT_CONTAINS:
|
||||||
if len(branch_conditions) > 1:
|
return instance.not_contains()
|
||||||
combined_condition = f' {self.typed_config.condition.logical_operator} '.join(branch_conditions)
|
case ComparisonOperator.START_WITH:
|
||||||
else:
|
return instance.startswith()
|
||||||
combined_condition = branch_conditions[0]
|
case ComparisonOperator.END_WITH:
|
||||||
|
return instance.endswith()
|
||||||
|
case ComparisonOperator.EQ:
|
||||||
|
return instance.eq()
|
||||||
|
case ComparisonOperator.NE:
|
||||||
|
return instance.ne()
|
||||||
|
case ComparisonOperator.LT:
|
||||||
|
return instance.lt()
|
||||||
|
case ComparisonOperator.LE:
|
||||||
|
return instance.le()
|
||||||
|
case ComparisonOperator.GT:
|
||||||
|
return instance.gt()
|
||||||
|
case ComparisonOperator.GE:
|
||||||
|
return instance.ge()
|
||||||
|
case _:
|
||||||
|
raise ValueError(f"Invalid condition: {operator}")
|
||||||
|
|
||||||
return combined_condition
|
def evaluate_conditional(self, state) -> bool:
|
||||||
|
"""
|
||||||
|
Evaluate the loop continuation condition at runtime.
|
||||||
|
|
||||||
|
This method:
|
||||||
|
- Resolves all condition expressions against the current workflow state
|
||||||
|
- Evaluates each comparison expression immediately
|
||||||
|
- Combines results using the configured logical operator (AND / OR)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
state: The current workflow state during loop execution.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: True if the loop should continue, False otherwise.
|
||||||
|
"""
|
||||||
|
conditions = []
|
||||||
|
|
||||||
|
for expression in self.typed_config.condition.expressions:
|
||||||
|
left_value = VariablePool(state).get(expression.left)
|
||||||
|
evaluator = ConditionExpressionResolver.resolve_by_value(left_value)(
|
||||||
|
VariablePool(state),
|
||||||
|
expression.left,
|
||||||
|
expression.right,
|
||||||
|
expression.input_type
|
||||||
|
)
|
||||||
|
conditions.append(self._evaluate(expression.operator, evaluator))
|
||||||
|
if self.typed_config.condition.logical_operator == LogicOperator.AND:
|
||||||
|
return all(conditions)
|
||||||
|
else:
|
||||||
|
return any(conditions)
|
||||||
|
|
||||||
async def run(self):
|
async def run(self):
|
||||||
"""
|
"""
|
||||||
Execute the loop node until the condition is no longer met, the loop is
|
Execute the loop node until termination conditions are met.
|
||||||
manually stopped, or the maximum loop count is reached.
|
|
||||||
|
The loop terminates when any of the following occurs:
|
||||||
|
- The loop condition evaluates to False
|
||||||
|
- The `looping` flag in the workflow state is set to False
|
||||||
|
- The maximum loop count is reached
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The final runtime variables of this loop node after completion.
|
dict[str, Any]: The final runtime variables of this loop node.
|
||||||
"""
|
"""
|
||||||
loopstate = self._init_loop_state()
|
loopstate = self._init_loop_state()
|
||||||
expression = self._get_loop_expression()
|
|
||||||
loop_variable_pool = VariablePool(loopstate)
|
|
||||||
loop_time = self.typed_config.max_loop
|
loop_time = self.typed_config.max_loop
|
||||||
while evaluate_condition(
|
while self.evaluate_conditional(loopstate) and loopstate["looping"] and loop_time > 0:
|
||||||
expression=expression,
|
|
||||||
variables=loop_variable_pool.get_all_conversation_vars(),
|
|
||||||
node_outputs=loop_variable_pool.get_all_node_outputs(),
|
|
||||||
system_vars=loop_variable_pool.get_all_system_vars(),
|
|
||||||
) and loopstate["looping"] and loop_time > 0:
|
|
||||||
logger.info(f"loop node {self.node_id}: running")
|
logger.info(f"loop node {self.node_id}: running")
|
||||||
await self.graph.ainvoke(loopstate)
|
await self.graph.ainvoke(loopstate)
|
||||||
loop_time -= 1
|
loop_time -= 1
|
||||||
|
|||||||
@@ -1,10 +1,9 @@
|
|||||||
import logging
|
import logging
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from langgraph.graph import StateGraph, START, END
|
from langgraph.graph import StateGraph
|
||||||
from langgraph.graph.state import CompiledStateGraph
|
from langgraph.graph.state import CompiledStateGraph
|
||||||
|
|
||||||
from app.core.workflow.expression_evaluator import evaluate_condition
|
|
||||||
from app.core.workflow.nodes import WorkflowState
|
from app.core.workflow.nodes import WorkflowState
|
||||||
from app.core.workflow.nodes.base_node import BaseNode
|
from app.core.workflow.nodes.base_node import BaseNode
|
||||||
from app.core.workflow.nodes.cycle_graph.config import LoopNodeConfig, IterationNodeConfig
|
from app.core.workflow.nodes.cycle_graph.config import LoopNodeConfig, IterationNodeConfig
|
||||||
@@ -17,12 +16,18 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
class CycleGraphNode(BaseNode):
|
class CycleGraphNode(BaseNode):
|
||||||
"""
|
"""
|
||||||
Node representing a cycle (loop) subgraph within the workflow.
|
Node representing a cyclic (loop or iteration) subgraph within the workflow.
|
||||||
|
|
||||||
This node manages internal loop/iteration nodes, builds a subgraph
|
A CycleGraphNode is a structural node that:
|
||||||
for execution, handles conditional routing, and executes loop
|
- Extracts a group of nodes marked as belonging to the same cycle
|
||||||
or iteration logic based on node type.
|
- Builds an isolated internal StateGraph (subgraph)
|
||||||
|
- Delegates runtime execution to LoopRuntime or IterationRuntime
|
||||||
|
depending on the node type
|
||||||
|
|
||||||
|
This node itself does NOT execute business logic directly.
|
||||||
|
It acts as a container and execution controller for a subgraph.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
||||||
super().__init__(node_config, workflow_config)
|
super().__init__(node_config, workflow_config)
|
||||||
self.typed_config: LoopNodeConfig | IterationNodeConfig | None = None
|
self.typed_config: LoopNodeConfig | IterationNodeConfig | None = None
|
||||||
@@ -38,16 +43,23 @@ class CycleGraphNode(BaseNode):
|
|||||||
|
|
||||||
def pure_cycle_graph(self) -> tuple[list, list]:
|
def pure_cycle_graph(self) -> tuple[list, list]:
|
||||||
"""
|
"""
|
||||||
Extract cycle nodes and internal edges from the workflow configuration,
|
Extract cycle-scoped nodes and internal edges from the workflow configuration.
|
||||||
removing them from the global workflow.
|
|
||||||
|
|
||||||
Raises:
|
This method:
|
||||||
ValueError: If cycle nodes are connected to external nodes improperly.
|
- Identifies all nodes marked with `cycle == self.node_id`
|
||||||
|
- Collects edges that fully connect cycle nodes
|
||||||
|
- Removes extracted nodes and edges from the global workflow configuration
|
||||||
|
|
||||||
|
Safety check:
|
||||||
|
- Raises an error if a cycle node is connected to an external node
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Tuple containing:
|
tuple[list, list]:
|
||||||
- cycle_nodes: List of removed nodes
|
- cycle_nodes: Nodes belonging to this cycle
|
||||||
- cycle_edges: List of removed edges
|
- cycle_edges: Edges connecting nodes within the cycle
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If a cycle node is improperly connected to an external node.
|
||||||
"""
|
"""
|
||||||
nodes = self.workflow_config.get("nodes", [])
|
nodes = self.workflow_config.get("nodes", [])
|
||||||
edges = self.workflow_config.get("edges", [])
|
edges = self.workflow_config.get("edges", [])
|
||||||
@@ -83,131 +95,41 @@ class CycleGraphNode(BaseNode):
|
|||||||
|
|
||||||
return cycle_nodes, cycle_edges
|
return cycle_nodes, cycle_edges
|
||||||
|
|
||||||
def create_node(self):
|
|
||||||
"""
|
|
||||||
Instantiate node objects for each node in the cycle subgraph and add them to the graph.
|
|
||||||
|
|
||||||
Special handling is applied for conditional nodes to generate
|
|
||||||
edge conditions based on node outputs.
|
|
||||||
"""
|
|
||||||
from app.core.workflow.nodes import NodeFactory
|
|
||||||
for node in self.cycle_nodes:
|
|
||||||
node_type = node.get("type")
|
|
||||||
node_id = node.get("id")
|
|
||||||
|
|
||||||
if node_type == NodeType.CYCLE_START:
|
|
||||||
self.start_node_id = node_id
|
|
||||||
continue
|
|
||||||
elif node_type == NodeType.END:
|
|
||||||
self.end_node_ids.append(node_id)
|
|
||||||
|
|
||||||
node_instance = NodeFactory.create_node(node, self.workflow_config)
|
|
||||||
|
|
||||||
if node_type in [NodeType.IF_ELSE, NodeType.HTTP_REQUEST]:
|
|
||||||
expressions = node_instance.build_conditional_edge_expressions()
|
|
||||||
|
|
||||||
# Number of branches, usually matches the number of conditional expressions
|
|
||||||
branch_number = len(expressions)
|
|
||||||
|
|
||||||
# Find all edges whose source is the current node
|
|
||||||
related_edge = [edge for edge in self.cycle_edges if edge.get("source") == node_id]
|
|
||||||
|
|
||||||
# Iterate over each branch
|
|
||||||
for idx in range(branch_number):
|
|
||||||
# Generate a condition expression for each edge
|
|
||||||
# Used later to determine which branch to take based on the node's output
|
|
||||||
# Assumes node output `node.<node_id>.output` matches the edge's label
|
|
||||||
# For example, if node.123.output == 'CASE1', take the branch labeled 'CASE1'
|
|
||||||
related_edge[idx]['condition'] = f"node.{node_id}.output == '{related_edge[idx]['label']}'"
|
|
||||||
|
|
||||||
def make_func(inst):
|
|
||||||
async def node_func(state: WorkflowState):
|
|
||||||
return await inst.run(state)
|
|
||||||
|
|
||||||
return node_func
|
|
||||||
|
|
||||||
self.graph.add_node(node_id, make_func(node_instance))
|
|
||||||
|
|
||||||
def create_edge(self):
|
|
||||||
"""
|
|
||||||
Connect nodes within the cycle subgraph by adding edges to the internal graph.
|
|
||||||
|
|
||||||
Conditional edges are routed based on evaluated expressions.
|
|
||||||
Start and end nodes are connected to global START and END nodes.
|
|
||||||
"""
|
|
||||||
for edge in self.cycle_edges:
|
|
||||||
source = edge.get("source")
|
|
||||||
target = edge.get("target")
|
|
||||||
edge_type = edge.get("type")
|
|
||||||
condition = edge.get("condition")
|
|
||||||
|
|
||||||
# 跳过从 start 节点出发的边(因为已经从 START 连接到 start)
|
|
||||||
if source == self.start_node_id:
|
|
||||||
# 但要连接 start 到下一个节点
|
|
||||||
self.graph.add_edge(START, target)
|
|
||||||
logger.debug(f"添加边: {source} -> {target}")
|
|
||||||
continue
|
|
||||||
|
|
||||||
if condition:
|
|
||||||
# 条件边
|
|
||||||
def router(state: WorkflowState, cond=condition, tgt=target):
|
|
||||||
"""条件路由函数"""
|
|
||||||
if evaluate_condition(
|
|
||||||
cond,
|
|
||||||
state.get("variables", {}),
|
|
||||||
state.get("node_outputs", {}),
|
|
||||||
{
|
|
||||||
"execution_id": state.get("execution_id"),
|
|
||||||
"workspace_id": state.get("workspace_id"),
|
|
||||||
"user_id": state.get("user_id")
|
|
||||||
}
|
|
||||||
):
|
|
||||||
return tgt
|
|
||||||
return END # 条件不满足,结束
|
|
||||||
|
|
||||||
self.graph.add_conditional_edges(source, router)
|
|
||||||
logger.debug(f"添加条件边: {source} -> {target} (condition={condition})")
|
|
||||||
else:
|
|
||||||
# 普通边
|
|
||||||
self.graph.add_edge(source, target)
|
|
||||||
logger.debug(f"添加边: {source} -> {target}")
|
|
||||||
|
|
||||||
# 从 end 节点连接到 END
|
|
||||||
for end_node_id in self.end_node_ids:
|
|
||||||
self.graph.add_edge(end_node_id, END)
|
|
||||||
logger.debug(f"添加边: {end_node_id} -> END")
|
|
||||||
|
|
||||||
def build_graph(self):
|
def build_graph(self):
|
||||||
"""
|
"""
|
||||||
Build the internal subgraph for the cycle node.
|
Build and compile the internal subgraph for this cycle node.
|
||||||
|
|
||||||
Steps:
|
Steps:
|
||||||
1. Extract cycle nodes and edges.
|
1. Extract cycle nodes and internal edges from the workflow
|
||||||
2. Create node instances and add them to the graph.
|
2. Construct a StateGraph using GraphBuilder in subgraph mode
|
||||||
3. Connect edges and conditional routes.
|
3. Compile the graph for runtime execution
|
||||||
4. Compile the graph for execution.
|
|
||||||
"""
|
"""
|
||||||
self.graph = StateGraph(WorkflowState)
|
from app.core.workflow.graph_builder import GraphBuilder
|
||||||
self.cycle_nodes, self.cycle_edges = self.pure_cycle_graph()
|
self.cycle_nodes, self.cycle_edges = self.pure_cycle_graph()
|
||||||
self.create_node()
|
self.graph = GraphBuilder(
|
||||||
self.create_edge()
|
{
|
||||||
self.graph = self.graph.compile()
|
"nodes": self.cycle_nodes,
|
||||||
|
"edges": self.cycle_edges,
|
||||||
|
},
|
||||||
|
subgraph=True
|
||||||
|
).build()
|
||||||
|
|
||||||
async def execute(self, state: WorkflowState) -> Any:
|
async def execute(self, state: WorkflowState) -> Any:
|
||||||
"""
|
"""
|
||||||
Execute the cycle node at runtime.
|
Execute the cycle node at runtime.
|
||||||
|
|
||||||
Depending on the node type, runs either a loop (LoopRuntime)
|
Based on the node type:
|
||||||
or an iteration (IterationRuntime) over the internal subgraph.
|
- LOOP: Executes LoopRuntime, repeatedly invoking the subgraph
|
||||||
|
- ITERATION: Executes IterationRuntime, iterating over a collection
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
state: Current workflow state.
|
state: The current workflow state when entering the cycle node.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Runtime result of the cycle, typically the final loop/iteration variables.
|
Any: The runtime result produced by the loop or iteration executor.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
RuntimeError: If node type is unrecognized.
|
RuntimeError: If the node type is unsupported.
|
||||||
"""
|
"""
|
||||||
if self.node_type == NodeType.LOOP:
|
if self.node_type == NodeType.LOOP:
|
||||||
return await LoopRuntime(
|
return await LoopRuntime(
|
||||||
|
|||||||
@@ -61,7 +61,7 @@ class EndNode(BaseNode):
|
|||||||
引用的节点 ID 列表
|
引用的节点 ID 列表
|
||||||
"""
|
"""
|
||||||
# 匹配 {{node_id.xxx}} 格式
|
# 匹配 {{node_id.xxx}} 格式
|
||||||
pattern = r'\{\{([a-zA-Z0-9_-]+)\.[a-zA-Z0-9_]+\}\}'
|
pattern = r'\{\{([a-zA-Z0-9_]+)\.[a-zA-Z0-9_]+\}\}'
|
||||||
matches = re.findall(pattern, template)
|
matches = re.findall(pattern, template)
|
||||||
return list(set(matches)) # 去重
|
return list(set(matches)) # 去重
|
||||||
|
|
||||||
|
|||||||
@@ -45,7 +45,8 @@ class LogicOperator(StrEnum):
|
|||||||
|
|
||||||
|
|
||||||
class AssignmentOperator(StrEnum):
|
class AssignmentOperator(StrEnum):
|
||||||
ASSIGN = "assign"
|
COVER = "cover" # 覆盖
|
||||||
|
ASSIGN = "assign" # 设置
|
||||||
CLEAR = "clear"
|
CLEAR = "clear"
|
||||||
|
|
||||||
ADD = "add" # +=
|
ADD = "add" # +=
|
||||||
@@ -87,3 +88,8 @@ class HttpErrorHandle(StrEnum):
|
|||||||
NONE = "none"
|
NONE = "none"
|
||||||
DEFAULT = "default"
|
DEFAULT = "default"
|
||||||
BRANCH = "branch"
|
BRANCH = "branch"
|
||||||
|
|
||||||
|
|
||||||
|
class ValueInputType(StrEnum):
|
||||||
|
VARIABLE = "Variable"
|
||||||
|
CONSTANT = "Constant"
|
||||||
|
|||||||
@@ -63,7 +63,7 @@ class HttpContentTypeConfig(BaseModel):
|
|||||||
)
|
)
|
||||||
|
|
||||||
data: list[HttpFormData] | dict | str = Field(
|
data: list[HttpFormData] | dict | str = Field(
|
||||||
...,
|
default="",
|
||||||
description="Data of the HTTP request body; type depends on content_type",
|
description="Data of the HTTP request body; type depends on content_type",
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -98,6 +98,10 @@ class HttpTimeOutConfig(BaseModel):
|
|||||||
|
|
||||||
|
|
||||||
class HttpRetryConfig(BaseModel):
|
class HttpRetryConfig(BaseModel):
|
||||||
|
enable: bool = Field(
|
||||||
|
...,
|
||||||
|
description="Enable/disable retry logic",
|
||||||
|
)
|
||||||
max_attempts: int = Field(
|
max_attempts: int = Field(
|
||||||
default=1,
|
default=1,
|
||||||
description="Maximum number of retry attempts for failed requests",
|
description="Maximum number of retry attempts for failed requests",
|
||||||
@@ -124,6 +128,11 @@ class HttpErrorDefaultTamplete(BaseModel):
|
|||||||
description="Default HTTP headers returned on error",
|
description="Default HTTP headers returned on error",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
output: str = Field(
|
||||||
|
default="SUCCESS",
|
||||||
|
description="HTTP response body",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class HttpErrorHandleConfig(BaseModel):
|
class HttpErrorHandleConfig(BaseModel):
|
||||||
method: HttpErrorHandle = Field(
|
method: HttpErrorHandle = Field(
|
||||||
@@ -131,8 +140,8 @@ class HttpErrorHandleConfig(BaseModel):
|
|||||||
description="Error handling strategy: 'none', 'default', or 'branch'",
|
description="Error handling strategy: 'none', 'default', or 'branch'",
|
||||||
)
|
)
|
||||||
|
|
||||||
default: HttpErrorDefaultTamplete = Field(
|
default: HttpErrorDefaultTamplete | None = Field(
|
||||||
...,
|
default=None,
|
||||||
description="Default response template for error handling",
|
description="Default response template for error handling",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -165,24 +165,6 @@ class HttpRequestNode(BaseNode):
|
|||||||
case _:
|
case _:
|
||||||
raise RuntimeError(f"HttpRequest method not supported: {self.typed_config.method}")
|
raise RuntimeError(f"HttpRequest method not supported: {self.typed_config.method}")
|
||||||
|
|
||||||
def build_conditional_edge_expressions(self):
|
|
||||||
"""
|
|
||||||
Build conditional edge expressions for workflow branching.
|
|
||||||
|
|
||||||
When the HTTP error handling strategy is set to `BRANCH`,
|
|
||||||
this node exposes a single conditional output labeled "ERROR".
|
|
||||||
The workflow engine uses this output to create an explicit
|
|
||||||
error-handling branch for downstream nodes.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
list[str]:
|
|
||||||
- ["ERROR"] if error handling strategy is BRANCH
|
|
||||||
- An empty list if no conditional branching is required
|
|
||||||
"""
|
|
||||||
if self.typed_config.error_handle.method == HttpErrorHandle.BRANCH:
|
|
||||||
return ["ERROR"]
|
|
||||||
return []
|
|
||||||
|
|
||||||
async def execute(self, state: WorkflowState) -> dict | str:
|
async def execute(self, state: WorkflowState) -> dict | str:
|
||||||
"""
|
"""
|
||||||
Execute the HTTP request node.
|
Execute the HTTP request node.
|
||||||
|
|||||||
@@ -1,12 +1,13 @@
|
|||||||
"""Condition Configuration"""
|
"""Condition Configuration"""
|
||||||
|
from typing import Any
|
||||||
from pydantic import Field, BaseModel, field_validator
|
from pydantic import Field, BaseModel, field_validator
|
||||||
|
|
||||||
from app.core.workflow.nodes.base_config import BaseNodeConfig
|
from app.core.workflow.nodes.base_config import BaseNodeConfig
|
||||||
from app.core.workflow.nodes.enums import ComparisonOperator, LogicOperator
|
from app.core.workflow.nodes.enums import ComparisonOperator, LogicOperator, ValueInputType
|
||||||
|
|
||||||
|
|
||||||
class ConditionDetail(BaseModel):
|
class ConditionDetail(BaseModel):
|
||||||
comparison_operator: ComparisonOperator = Field(
|
operator: ComparisonOperator = Field(
|
||||||
...,
|
...,
|
||||||
description="Comparison operator used to evaluate the condition"
|
description="Comparison operator used to evaluate the condition"
|
||||||
)
|
)
|
||||||
@@ -16,17 +17,22 @@ class ConditionDetail(BaseModel):
|
|||||||
description="Value to compare against"
|
description="Value to compare against"
|
||||||
)
|
)
|
||||||
|
|
||||||
right: str = Field(
|
right: Any = Field(
|
||||||
...,
|
...,
|
||||||
description="Value to compare with"
|
description="Value to compare with"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
input_type: ValueInputType = Field(
|
||||||
|
...,
|
||||||
|
description="Value input type for comparison"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class ConditionBranchConfig(BaseModel):
|
class ConditionBranchConfig(BaseModel):
|
||||||
"""Configuration for a conditional branch"""
|
"""Configuration for a conditional branch"""
|
||||||
|
|
||||||
logical_operator: LogicOperator = Field(
|
logical_operator: LogicOperator = Field(
|
||||||
default=LogicOperator.AND.value,
|
default=LogicOperator.AND,
|
||||||
description="Logical operator used to combine multiple condition expressions"
|
description="Logical operator used to combine multiple condition expressions"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -1,10 +1,11 @@
|
|||||||
import logging
|
import logging
|
||||||
|
import re
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
|
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
|
||||||
|
from app.core.workflow.nodes.enums import ComparisonOperator, LogicOperator
|
||||||
from app.core.workflow.nodes.if_else import IfElseNodeConfig
|
from app.core.workflow.nodes.if_else import IfElseNodeConfig
|
||||||
from app.core.workflow.nodes.if_else.config import ConditionDetail
|
from app.core.workflow.nodes.operators import ConditionExpressionResolver, CompareOperatorInstance
|
||||||
from app.core.workflow.nodes.operators import ConditionExpressionBuilder
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -15,30 +16,36 @@ class IfElseNode(BaseNode):
|
|||||||
self.typed_config = IfElseNodeConfig(**self.config)
|
self.typed_config = IfElseNodeConfig(**self.config)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _build_condition_expression(
|
def _evaluate(operator, instance: CompareOperatorInstance) -> Any:
|
||||||
condition: ConditionDetail,
|
match operator:
|
||||||
) -> str:
|
case ComparisonOperator.EMPTY:
|
||||||
"""
|
return instance.empty()
|
||||||
Build a single boolean condition expression string.
|
case ComparisonOperator.NOT_EMPTY:
|
||||||
|
return instance.not_empty()
|
||||||
|
case ComparisonOperator.CONTAINS:
|
||||||
|
return instance.contains()
|
||||||
|
case ComparisonOperator.NOT_CONTAINS:
|
||||||
|
return instance.not_contains()
|
||||||
|
case ComparisonOperator.START_WITH:
|
||||||
|
return instance.startswith()
|
||||||
|
case ComparisonOperator.END_WITH:
|
||||||
|
return instance.endswith()
|
||||||
|
case ComparisonOperator.EQ:
|
||||||
|
return instance.eq()
|
||||||
|
case ComparisonOperator.NE:
|
||||||
|
return instance.ne()
|
||||||
|
case ComparisonOperator.LT:
|
||||||
|
return instance.lt()
|
||||||
|
case ComparisonOperator.LE:
|
||||||
|
return instance.le()
|
||||||
|
case ComparisonOperator.GT:
|
||||||
|
return instance.gt()
|
||||||
|
case ComparisonOperator.GE:
|
||||||
|
return instance.ge()
|
||||||
|
case _:
|
||||||
|
raise ValueError(f"Invalid condition: {operator}")
|
||||||
|
|
||||||
This method does NOT evaluate the condition.
|
def evaluate_conditional_edge_expressions(self, state) -> list[bool]:
|
||||||
It only generates a valid Python boolean expression string
|
|
||||||
(e.g. "x > 10", "'a' in name") that can later be used
|
|
||||||
in a conditional edge or evaluated by the workflow engine.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
condition (ConditionDetail): Definition of a single comparison condition.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
str: A Python boolean expression string.
|
|
||||||
"""
|
|
||||||
return ConditionExpressionBuilder(
|
|
||||||
left=condition.left,
|
|
||||||
operator=condition.comparison_operator,
|
|
||||||
right=condition.right
|
|
||||||
).build()
|
|
||||||
|
|
||||||
def build_conditional_edge_expressions(self) -> list[str]:
|
|
||||||
"""
|
"""
|
||||||
Build conditional edge expressions for the If-Else node.
|
Build conditional edge expressions for the If-Else node.
|
||||||
|
|
||||||
@@ -60,19 +67,28 @@ class IfElseNode(BaseNode):
|
|||||||
|
|
||||||
for case_branch in self.typed_config.cases:
|
for case_branch in self.typed_config.cases:
|
||||||
branch_index += 1
|
branch_index += 1
|
||||||
|
branch_result = []
|
||||||
branch_conditions = [
|
for expression in case_branch.expressions:
|
||||||
self._build_condition_expression(condition)
|
pattern = r"\{\{\s*(.*?)\s*\}\}"
|
||||||
for condition in case_branch.expressions
|
left_string = re.sub(pattern, r"\1", expression.left).strip()
|
||||||
]
|
left_value = self.get_variable(left_string, state)
|
||||||
if len(branch_conditions) > 1:
|
evaluator = ConditionExpressionResolver.resolve_by_value(left_value)(
|
||||||
combined_condition = f' {case_branch.logical_operator} '.join(branch_conditions)
|
self.get_variable_pool(state),
|
||||||
|
expression.left,
|
||||||
|
expression.right,
|
||||||
|
expression.input_type
|
||||||
|
)
|
||||||
|
branch_result.append(self._evaluate(expression.operator, evaluator))
|
||||||
|
if case_branch.logical_operator == LogicOperator.AND:
|
||||||
|
conditions.append(all(branch_result))
|
||||||
else:
|
else:
|
||||||
combined_condition = branch_conditions[0]
|
condition_res = any(branch_result)
|
||||||
conditions.append(combined_condition)
|
conditions.append(condition_res)
|
||||||
|
if condition_res:
|
||||||
|
return conditions
|
||||||
|
|
||||||
# Default fallback branch
|
# Default fallback branch
|
||||||
conditions.append("True")
|
conditions.append(True)
|
||||||
|
|
||||||
return conditions
|
return conditions
|
||||||
|
|
||||||
@@ -90,10 +106,10 @@ class IfElseNode(BaseNode):
|
|||||||
Returns:
|
Returns:
|
||||||
str: The matched branch identifier, e.g., 'CASE1', 'CASE2', ..., used for node transitions.
|
str: The matched branch identifier, e.g., 'CASE1', 'CASE2', ..., used for node transitions.
|
||||||
"""
|
"""
|
||||||
expressions = self.build_conditional_edge_expressions()
|
expressions = self.evaluate_conditional_edge_expressions(state)
|
||||||
|
# TODO: 变量类型及文本类型解析
|
||||||
for i in range(len(expressions)):
|
for i in range(len(expressions)):
|
||||||
logger.info(expressions[i])
|
if expressions[i]:
|
||||||
if self._evaluate_condition(expressions[i], state):
|
|
||||||
logger.info(f"Node {self.node_id}: switched to branch CASE {i + 1}")
|
logger.info(f"Node {self.node_id}: switched to branch CASE {i + 1}")
|
||||||
return f'CASE{i + 1}'
|
return f'CASE{i + 1}'
|
||||||
return f'CASE{len(expressions)}'
|
return f'CASE{len(expressions)}'
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ from app.core.workflow.template_renderer import TemplateRenderer
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class JinjaRenderNode(BaseNode):
|
class JinjaRenderNode(BaseNode):
|
||||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
||||||
super().__init__(node_config, workflow_config)
|
super().__init__(node_config, workflow_config)
|
||||||
|
|||||||
@@ -74,6 +74,7 @@ class NodeFactory:
|
|||||||
NodeType.LOOP: CycleGraphNode,
|
NodeType.LOOP: CycleGraphNode,
|
||||||
NodeType.ITERATION: CycleGraphNode,
|
NodeType.ITERATION: CycleGraphNode,
|
||||||
NodeType.BREAK: BreakNode,
|
NodeType.BREAK: BreakNode,
|
||||||
|
NodeType.CYCLE_START: StartNode,
|
||||||
NodeType.TOOL: ToolNode,
|
NodeType.TOOL: ToolNode,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,10 +1,73 @@
|
|||||||
|
import json
|
||||||
|
import re
|
||||||
from abc import ABC
|
from abc import ABC
|
||||||
from typing import Union, Type
|
from typing import Union, Type, NoReturn
|
||||||
|
|
||||||
from app.core.workflow.nodes.enums import ComparisonOperator
|
from app.core.workflow.nodes.base_config import VariableType
|
||||||
|
from app.core.workflow.nodes.enums import ValueInputType
|
||||||
from app.core.workflow.variable_pool import VariablePool
|
from app.core.workflow.variable_pool import VariablePool
|
||||||
|
|
||||||
|
|
||||||
|
class TypeTransformer:
|
||||||
|
@classmethod
|
||||||
|
def _fail(cls, value, target) -> NoReturn:
|
||||||
|
raise TypeError(f"Cannot convert {value!r} to {target} type")
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _json_load(cls, value, target):
|
||||||
|
try:
|
||||||
|
return json.loads(value)
|
||||||
|
except Exception:
|
||||||
|
cls._fail(value, target)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def transform(cls, variable_literal: str | bool, target_type: VariableType):
|
||||||
|
match target_type:
|
||||||
|
case VariableType.STRING:
|
||||||
|
return str(variable_literal)
|
||||||
|
|
||||||
|
case VariableType.NUMBER:
|
||||||
|
for caster in (int, float):
|
||||||
|
try:
|
||||||
|
return caster(variable_literal)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
cls._fail(variable_literal, target_type)
|
||||||
|
|
||||||
|
case VariableType.BOOLEAN:
|
||||||
|
if isinstance(variable_literal, bool):
|
||||||
|
return variable_literal
|
||||||
|
cls._fail(variable_literal, target_type)
|
||||||
|
|
||||||
|
case VariableType.OBJECT:
|
||||||
|
obj = cls._json_load(variable_literal, target_type)
|
||||||
|
if isinstance(obj, dict):
|
||||||
|
return obj
|
||||||
|
cls._fail(variable_literal, target_type)
|
||||||
|
|
||||||
|
case VariableType.ARRAY_BOOLEAN:
|
||||||
|
return cls._parse_list(variable_literal, bool, target_type)
|
||||||
|
|
||||||
|
case VariableType.ARRAY_NUMBER:
|
||||||
|
return cls._parse_list(variable_literal, (int, float), target_type)
|
||||||
|
|
||||||
|
case VariableType.ARRAY_STRING:
|
||||||
|
return cls._parse_list(variable_literal, str, target_type)
|
||||||
|
|
||||||
|
case VariableType.ARRAY_OBJECT:
|
||||||
|
return cls._parse_list(variable_literal, dict, target_type)
|
||||||
|
|
||||||
|
case _:
|
||||||
|
raise TypeError("Invalid type")
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _parse_list(cls, value, item_type, target):
|
||||||
|
arr = cls._json_load(value, target)
|
||||||
|
if isinstance(arr, list) and all(isinstance(i, item_type) for i in arr):
|
||||||
|
return arr
|
||||||
|
cls._fail(value, target)
|
||||||
|
|
||||||
|
|
||||||
class OperatorBase(ABC):
|
class OperatorBase(ABC):
|
||||||
def __init__(self, pool: VariablePool, left_selector, right):
|
def __init__(self, pool: VariablePool, left_selector, right):
|
||||||
self.pool = pool
|
self.pool = pool
|
||||||
@@ -19,7 +82,9 @@ class OperatorBase(ABC):
|
|||||||
raise TypeError(f"The variable to be operated on must be of {self.type_limit} type")
|
raise TypeError(f"The variable to be operated on must be of {self.type_limit} type")
|
||||||
|
|
||||||
if not no_right and not isinstance(self.right, self.type_limit):
|
if not no_right and not isinstance(self.right, self.type_limit):
|
||||||
raise TypeError(f"The value assigned to the string variable must also be of {self.type_limit} type")
|
raise TypeError(
|
||||||
|
f"The value assigned must be of {self.type_limit} type"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class StringOperator(OperatorBase):
|
class StringOperator(OperatorBase):
|
||||||
@@ -126,7 +191,7 @@ class ArrayOperator(OperatorBase):
|
|||||||
class ObjectOperator(OperatorBase):
|
class ObjectOperator(OperatorBase):
|
||||||
def __init__(self, pool: VariablePool, left_selector, right):
|
def __init__(self, pool: VariablePool, left_selector, right):
|
||||||
super().__init__(pool, left_selector, right)
|
super().__init__(pool, left_selector, right)
|
||||||
self.type_limit = object
|
self.type_limit = dict
|
||||||
|
|
||||||
def assign(self) -> None:
|
def assign(self) -> None:
|
||||||
self.check()
|
self.check()
|
||||||
@@ -138,20 +203,21 @@ class ObjectOperator(OperatorBase):
|
|||||||
|
|
||||||
|
|
||||||
class AssignmentOperatorResolver:
|
class AssignmentOperatorResolver:
|
||||||
|
OPERATOR_MAP = {
|
||||||
|
str: StringOperator,
|
||||||
|
bool: BooleanOperator,
|
||||||
|
int: NumberOperator,
|
||||||
|
float: NumberOperator,
|
||||||
|
list: ArrayOperator,
|
||||||
|
dict: ObjectOperator,
|
||||||
|
}
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def resolve_by_value(cls, value):
|
def resolve_by_value(cls, value):
|
||||||
if isinstance(value, str):
|
for t, op in cls.OPERATOR_MAP.items():
|
||||||
return StringOperator
|
if isinstance(value, t):
|
||||||
elif isinstance(value, bool):
|
return op
|
||||||
return BooleanOperator
|
raise TypeError(f"Unsupported variable type: {type(value)}")
|
||||||
elif isinstance(value, (int, float)):
|
|
||||||
return NumberOperator
|
|
||||||
elif isinstance(value, list):
|
|
||||||
return ArrayOperator
|
|
||||||
elif isinstance(value, dict):
|
|
||||||
return ObjectOperator
|
|
||||||
else:
|
|
||||||
raise TypeError(f"Unsupported variable type: {type(value)}")
|
|
||||||
|
|
||||||
|
|
||||||
AssignmentOperatorInstance = Union[
|
AssignmentOperatorInstance = Union[
|
||||||
@@ -164,81 +230,186 @@ AssignmentOperatorInstance = Union[
|
|||||||
AssignmentOperatorType = Type[AssignmentOperatorInstance]
|
AssignmentOperatorType = Type[AssignmentOperatorInstance]
|
||||||
|
|
||||||
|
|
||||||
class ConditionExpressionBuilder:
|
class ConditionBase(ABC):
|
||||||
"""
|
type_limit: type[str, int, dict, list] = None
|
||||||
Build a Python boolean expression string based on a comparison operator.
|
|
||||||
|
|
||||||
This class does not evaluate the expression.
|
def __init__(
|
||||||
It only generates a valid Python expression string
|
self,
|
||||||
that can be evaluated later in a workflow context.
|
pool: VariablePool,
|
||||||
"""
|
left_selector,
|
||||||
|
right_selector: str,
|
||||||
|
input_type: ValueInputType
|
||||||
|
):
|
||||||
|
self.pool = pool
|
||||||
|
self.left_selector = left_selector
|
||||||
|
self.right_selector = right_selector
|
||||||
|
self.input_type = input_type
|
||||||
|
|
||||||
def __init__(self, left: str, operator: ComparisonOperator, right: str):
|
self.left_value = self.pool.get(self.left_selector)
|
||||||
self.left = left
|
self.right_value = self.resolve_right_literal_value()
|
||||||
self.operator = operator
|
|
||||||
self.right = right
|
|
||||||
|
|
||||||
def _empty(self):
|
self.type_limit = getattr(self, "type_limit", None)
|
||||||
return f"{self.left} == ''"
|
|
||||||
|
|
||||||
def _not_empty(self):
|
def resolve_right_literal_value(self):
|
||||||
return f"{self.left} != ''"
|
if self.input_type == ValueInputType.VARIABLE:
|
||||||
|
pattern = r"\{\{\s*(.*?)\s*\}\}"
|
||||||
|
right_expression = re.sub(pattern, r"\1", self.right_selector).strip()
|
||||||
|
return self.pool.get(right_expression)
|
||||||
|
elif self.input_type == ValueInputType.CONSTANT:
|
||||||
|
return self.right_selector
|
||||||
|
raise RuntimeError("Unsupported variable type")
|
||||||
|
|
||||||
def _contains(self):
|
def check(self, no_right=False):
|
||||||
return f"{self.right} in {self.left}"
|
left = self.pool.get(self.left_selector.variable_selector)
|
||||||
|
if not isinstance(left, self.type_limit):
|
||||||
|
raise TypeError(f"The variable to be compared on must be of {self.type_limit} type")
|
||||||
|
if not no_right:
|
||||||
|
right = self.resolve_right_literal_value()
|
||||||
|
if not isinstance(right, self.type_limit):
|
||||||
|
raise TypeError(
|
||||||
|
f"The compared variable must be of {self.type_limit} type"
|
||||||
|
)
|
||||||
|
|
||||||
def _not_contains(self):
|
|
||||||
return f"{self.right} not in {self.left}"
|
|
||||||
|
|
||||||
def _startswith(self):
|
class StringComparisonOperator(ConditionBase):
|
||||||
return f'{self.left}.startswith({self.right})'
|
type_limit = str
|
||||||
|
|
||||||
def _endswith(self):
|
def __init__(self, pool: VariablePool, left_selector, right_selector, input_type):
|
||||||
return f'{self.left}.endswith({self.right})'
|
super().__init__(pool, left_selector, right_selector, input_type)
|
||||||
|
|
||||||
def _eq(self):
|
def empty(self):
|
||||||
return f"{self.left} == {self.right}"
|
self.check(no_right=True)
|
||||||
|
return self.left_value == ""
|
||||||
|
|
||||||
def _ne(self):
|
def not_empty(self):
|
||||||
return f"{self.left} != {self.right}"
|
return not self.empty()
|
||||||
|
|
||||||
def _lt(self):
|
def contains(self):
|
||||||
return f"{self.left} < {self.right}"
|
self.check()
|
||||||
|
return self.right_value in self.left_value
|
||||||
|
|
||||||
def _le(self):
|
def not_contains(self):
|
||||||
return f"{self.left} <= {self.right}"
|
return self.right_value not in self.left_value
|
||||||
|
|
||||||
def _gt(self):
|
def startswith(self):
|
||||||
return f"{self.left} > {self.right}"
|
self.check()
|
||||||
|
return self.left_value.startswith(self.right_value)
|
||||||
|
|
||||||
def _ge(self):
|
def endswith(self):
|
||||||
return f"{self.left} >= {self.right}"
|
return self.left_value.endswith(self.right_value)
|
||||||
|
|
||||||
def build(self):
|
def eq(self):
|
||||||
match self.operator:
|
return self.left_value == self.right_value
|
||||||
case ComparisonOperator.EMPTY:
|
|
||||||
return self._empty()
|
def ne(self):
|
||||||
case ComparisonOperator.NOT_EMPTY:
|
return self.left_value != self.right_value
|
||||||
return self._not_empty()
|
|
||||||
case ComparisonOperator.CONTAINS:
|
|
||||||
return self._contains()
|
class NumberComparisonOperator(ConditionBase):
|
||||||
case ComparisonOperator.NOT_CONTAINS:
|
type_limit = (int, float)
|
||||||
return self._not_contains()
|
|
||||||
case ComparisonOperator.START_WITH:
|
def __init__(self, pool: VariablePool, left_selector, right_selector, input_type):
|
||||||
return self._startswith()
|
super().__init__(pool, left_selector, right_selector, input_type)
|
||||||
case ComparisonOperator.END_WITH:
|
|
||||||
return self._endswith()
|
def empty(self):
|
||||||
case ComparisonOperator.EQ:
|
return self.left_value == 0
|
||||||
return self._eq()
|
|
||||||
case ComparisonOperator.NE:
|
def not_empty(self):
|
||||||
return self._ne()
|
return self.left_value != 0
|
||||||
case ComparisonOperator.LT:
|
|
||||||
return self._lt()
|
def eq(self):
|
||||||
case ComparisonOperator.LE:
|
return self.left_value == self.right_value
|
||||||
return self._le()
|
|
||||||
case ComparisonOperator.GT:
|
def ne(self):
|
||||||
return self._gt()
|
return self.left_value != self.right_value
|
||||||
case ComparisonOperator.GE:
|
|
||||||
return self._ge()
|
def lt(self):
|
||||||
case _:
|
return self.left_value < self.right_value
|
||||||
raise ValueError(f"Invalid condition: {self.operator}")
|
|
||||||
|
def le(self):
|
||||||
|
return self.left_value <= self.right_value
|
||||||
|
|
||||||
|
def gt(self):
|
||||||
|
return self.left_value > self.right_value
|
||||||
|
|
||||||
|
def ge(self):
|
||||||
|
return self.left_value >= self.right_value
|
||||||
|
|
||||||
|
|
||||||
|
class BooleanComparisonOperator(ConditionBase):
|
||||||
|
type_limit = bool
|
||||||
|
|
||||||
|
def __init__(self, pool: VariablePool, left_selector, right_selector, input_type):
|
||||||
|
super().__init__(pool, left_selector, right_selector, input_type)
|
||||||
|
|
||||||
|
def eq(self):
|
||||||
|
return self.left_value == self.right_value
|
||||||
|
|
||||||
|
def ne(self):
|
||||||
|
return self.left_value != self.right_value
|
||||||
|
|
||||||
|
|
||||||
|
class ObjectComparisonOperator(ConditionBase):
|
||||||
|
type_limit = dict
|
||||||
|
|
||||||
|
def __init__(self, pool: VariablePool, left_selector, right_selector, input_type):
|
||||||
|
super().__init__(pool, left_selector, right_selector, input_type)
|
||||||
|
|
||||||
|
def eq(self):
|
||||||
|
return self.left_value == self.right_value
|
||||||
|
|
||||||
|
def ne(self):
|
||||||
|
return self.left_value != self.right_value
|
||||||
|
|
||||||
|
def empty(self):
|
||||||
|
return not self.left_value
|
||||||
|
|
||||||
|
def not_empty(self):
|
||||||
|
return bool(self.left_value)
|
||||||
|
|
||||||
|
|
||||||
|
class ArrayComparisonOperator(ConditionBase):
|
||||||
|
type_limit = list
|
||||||
|
|
||||||
|
def __init__(self, pool: VariablePool, left_selector, right_selector, input_type):
|
||||||
|
super().__init__(pool, left_selector, right_selector, input_type)
|
||||||
|
|
||||||
|
def empty(self):
|
||||||
|
return not self.left_value
|
||||||
|
|
||||||
|
def not_empty(self):
|
||||||
|
return bool(self.left_value)
|
||||||
|
|
||||||
|
def contains(self):
|
||||||
|
return self.right_value in self.left_value
|
||||||
|
|
||||||
|
def not_contains(self):
|
||||||
|
return self.right_value not in self.left_value
|
||||||
|
|
||||||
|
|
||||||
|
CompareOperatorInstance = Union[
|
||||||
|
StringComparisonOperator,
|
||||||
|
NumberComparisonOperator,
|
||||||
|
BooleanComparisonOperator,
|
||||||
|
ArrayComparisonOperator,
|
||||||
|
ObjectComparisonOperator
|
||||||
|
]
|
||||||
|
CompareOperatorType = Type[CompareOperatorInstance]
|
||||||
|
|
||||||
|
|
||||||
|
class ConditionExpressionResolver:
|
||||||
|
CONDITION_OPERATOR_MAP = {
|
||||||
|
str: StringComparisonOperator,
|
||||||
|
bool: BooleanComparisonOperator,
|
||||||
|
int: NumberComparisonOperator,
|
||||||
|
float: NumberComparisonOperator,
|
||||||
|
list: ArrayComparisonOperator,
|
||||||
|
dict: ObjectComparisonOperator,
|
||||||
|
}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def resolve_by_value(cls, value) -> CompareOperatorType:
|
||||||
|
for t, op in cls.CONDITION_OPERATOR_MAP.items():
|
||||||
|
if isinstance(value, t):
|
||||||
|
return op
|
||||||
|
raise TypeError(f"Unsupported variable type: {type(value)}")
|
||||||
|
|||||||
@@ -4,11 +4,11 @@
|
|||||||
从文件系统加载预定义的工作流模板
|
从文件系统加载预定义的工作流模板
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
|
||||||
import yaml
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
|
import yaml
|
||||||
|
|
||||||
|
|
||||||
class TemplateLoader:
|
class TemplateLoader:
|
||||||
"""工作流模板加载器"""
|
"""工作流模板加载器"""
|
||||||
|
|||||||
@@ -10,6 +10,7 @@
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
import re
|
||||||
from typing import Any, TYPE_CHECKING
|
from typing import Any, TYPE_CHECKING
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@@ -114,7 +115,9 @@ class VariablePool:
|
|||||||
"""
|
"""
|
||||||
# 转换为 VariableSelector
|
# 转换为 VariableSelector
|
||||||
if isinstance(selector, str):
|
if isinstance(selector, str):
|
||||||
selector = VariableSelector.from_string(selector).path
|
pattern = r"\{\{\s*(.*?)\s*\}\}"
|
||||||
|
variable_literal = re.sub(pattern, r"\1", selector).strip()
|
||||||
|
selector = VariableSelector.from_string(variable_literal).path
|
||||||
|
|
||||||
if not selector or len(selector) < 1:
|
if not selector or len(selector) < 1:
|
||||||
raise ValueError("变量选择器不能为空")
|
raise ValueError("变量选择器不能为空")
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
import re
|
import re
|
||||||
import uuid
|
import uuid
|
||||||
|
from typing import Any, AsyncGenerator
|
||||||
|
|
||||||
import json_repair
|
import json_repair
|
||||||
from langchain_core.prompts import ChatPromptTemplate
|
from langchain_core.prompts import ChatPromptTemplate
|
||||||
@@ -123,7 +124,7 @@ class PromptOptimizerService:
|
|||||||
user_id: uuid.UUID,
|
user_id: uuid.UUID,
|
||||||
current_prompt: str,
|
current_prompt: str,
|
||||||
user_require: str
|
user_require: str
|
||||||
) -> OptimizePromptResult:
|
) -> AsyncGenerator[dict[str, str | Any], Any]:
|
||||||
"""
|
"""
|
||||||
Optimize a user-provided prompt using a configured prompt optimizer LLM.
|
Optimize a user-provided prompt using a configured prompt optimizer LLM.
|
||||||
|
|
||||||
@@ -161,6 +162,7 @@ class PromptOptimizerService:
|
|||||||
BusinessException: If the LLM response cannot be parsed as valid JSON
|
BusinessException: If the LLM response cannot be parsed as valid JSON
|
||||||
or does not conform to the expected output format.
|
or does not conform to the expected output format.
|
||||||
"""
|
"""
|
||||||
|
self.create_message(tenant_id, session_id, user_id, role=RoleType.USER, content=user_require)
|
||||||
model_config = self.get_model_config(tenant_id, model_id)
|
model_config = self.get_model_config(tenant_id, model_id)
|
||||||
session_history = self.get_session_message_history(session_id=session_id, user_id=user_id)
|
session_history = self.get_session_message_history(session_id=session_id, user_id=user_id)
|
||||||
|
|
||||||
@@ -202,17 +204,54 @@ class PromptOptimizerService:
|
|||||||
messages.extend(session_history[:-1]) # last message is current message
|
messages.extend(session_history[:-1]) # last message is current message
|
||||||
messages.extend([(RoleType.USER.value, rendered_user_message)])
|
messages.extend([(RoleType.USER.value, rendered_user_message)])
|
||||||
logger.info(f"Prompt optimization message: {messages}")
|
logger.info(f"Prompt optimization message: {messages}")
|
||||||
optim_resp = await llm.ainvoke(messages)
|
buffer = ""
|
||||||
logger.info(optim_resp.content)
|
prompt_started = False
|
||||||
optim_result = json_repair.repair_json(optim_resp.content, return_objects=True)
|
prompt_finished = False
|
||||||
prompt = optim_result.get("prompt")
|
idx = 0
|
||||||
desc = optim_result.get("desc")
|
|
||||||
|
|
||||||
return OptimizePromptResult(
|
async for chunk in llm.astream(messages):
|
||||||
prompt=prompt,
|
content = getattr(chunk, "content", chunk)
|
||||||
desc=desc
|
if not content:
|
||||||
|
continue
|
||||||
|
buffer += content
|
||||||
|
cache = buffer[:-20]
|
||||||
|
|
||||||
|
# 尝试找到 "prompt": " 开始位置
|
||||||
|
if prompt_finished:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if not prompt_started:
|
||||||
|
m = re.search(r'"prompt"\s*:\s*"', cache)
|
||||||
|
if m:
|
||||||
|
prompt_started = True
|
||||||
|
prompt_index = m.end()
|
||||||
|
idx = prompt_index
|
||||||
|
else:
|
||||||
|
m = re.search(r'"\s*,\s*\\?n?\s*"desc"\s*:\s*"', buffer)
|
||||||
|
if m:
|
||||||
|
prompt_index = m.start()
|
||||||
|
prompt_finished = True
|
||||||
|
yield {"type": "delta", "content": buffer[idx:prompt_index]}
|
||||||
|
else:
|
||||||
|
yield {"type": "delta", "content": cache[idx:]}
|
||||||
|
if len(cache) != 0:
|
||||||
|
idx = len(cache)
|
||||||
|
|
||||||
|
# optim_resp = await llm.astream(messages)
|
||||||
|
logger.info(buffer)
|
||||||
|
optim_result = json_repair.repair_json(buffer, return_objects=True)
|
||||||
|
# prompt = optim_result.get("prompt")
|
||||||
|
desc = optim_result.get("desc")
|
||||||
|
self.create_message(
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
session_id=session_id,
|
||||||
|
user_id=user_id,
|
||||||
|
role=RoleType.ASSISTANT,
|
||||||
|
content=desc
|
||||||
)
|
)
|
||||||
|
|
||||||
|
yield {"type": "done", "desc": optim_result.get("desc")}
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def parser_prompt_variables(prompt: str):
|
def parser_prompt_variables(prompt: str):
|
||||||
try:
|
try:
|
||||||
|
|||||||
@@ -410,7 +410,8 @@ class WorkflowService:
|
|||||||
self,
|
self,
|
||||||
app_id: uuid.UUID,
|
app_id: uuid.UUID,
|
||||||
payload: DraftRunRequest,
|
payload: DraftRunRequest,
|
||||||
config: WorkflowConfig
|
config: WorkflowConfig,
|
||||||
|
workspace_id: uuid.UUID,
|
||||||
):
|
):
|
||||||
"""运行工作流
|
"""运行工作流
|
||||||
|
|
||||||
@@ -484,7 +485,7 @@ class WorkflowService:
|
|||||||
workflow_config=workflow_config_dict,
|
workflow_config=workflow_config_dict,
|
||||||
input_data=input_data,
|
input_data=input_data,
|
||||||
execution_id=execution.execution_id,
|
execution_id=execution.execution_id,
|
||||||
workspace_id="",
|
workspace_id=str(workspace_id),
|
||||||
user_id=payload.user_id
|
user_id=payload.user_id
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -530,7 +531,8 @@ class WorkflowService:
|
|||||||
self,
|
self,
|
||||||
app_id: uuid.UUID,
|
app_id: uuid.UUID,
|
||||||
payload: DraftRunRequest,
|
payload: DraftRunRequest,
|
||||||
config: WorkflowConfig
|
config: WorkflowConfig,
|
||||||
|
workspace_id: uuid.UUID,
|
||||||
):
|
):
|
||||||
"""运行工作流(流式)
|
"""运行工作流(流式)
|
||||||
|
|
||||||
@@ -603,7 +605,7 @@ class WorkflowService:
|
|||||||
workflow_config=workflow_config_dict,
|
workflow_config=workflow_config_dict,
|
||||||
input_data=input_data,
|
input_data=input_data,
|
||||||
execution_id=execution.execution_id,
|
execution_id=execution.execution_id,
|
||||||
workspace_id="",
|
workspace_id=str(workspace_id),
|
||||||
user_id=payload.user_id
|
user_id=payload.user_id
|
||||||
):
|
):
|
||||||
# 直接转发 executor 的事件(已经是正确的格式)
|
# 直接转发 executor 的事件(已经是正确的格式)
|
||||||
|
|||||||
@@ -25,7 +25,7 @@ Rules
|
|||||||
Basic Principles
|
Basic Principles
|
||||||
Priority Rule: When historical requirements conflict with current requirements, unconditionally prioritize current requirements.
|
Priority Rule: When historical requirements conflict with current requirements, unconditionally prioritize current requirements.
|
||||||
Completeness Rule: If the original prompt is empty, generate a complete prompt based on the current requirements.
|
Completeness Rule: If the original prompt is empty, generate a complete prompt based on the current requirements.
|
||||||
Structure Rule: Use a clear block structure including [Role], [Task], [Requirements], [Input], [Output], [Constraints] labels.
|
Structure Rule: Use a clear block structure, and the contents of each block are roles, tasks, requirements, inputs, outputs, and constraints
|
||||||
Language Rule: All label languages must fully match the user input language.
|
Language Rule: All label languages must fully match the user input language.
|
||||||
|
|
||||||
Behavior Guidelines
|
Behavior Guidelines
|
||||||
|
|||||||
Reference in New Issue
Block a user