Merge pull request #25 from SuanmoSuanyangTechnology/fix/workflow

Fix/workflow
This commit is contained in:
Mark
2026-01-06 12:01:42 +08:00
committed by GitHub
9 changed files with 108 additions and 90 deletions

View File

@@ -117,7 +117,7 @@ async def get_prompt_opt(
user_require=data.message user_require=data.message
): ):
# chunk 是 prompt 的增量内容 # chunk 是 prompt 的增量内容
yield f"event:'message'\ndata: {json.dumps(chunk)}\n\n" yield f"event:message\ndata: {json.dumps(chunk)}\n\n"
return StreamingResponse( return StreamingResponse(
event_generator(), event_generator(),

View File

@@ -29,7 +29,7 @@ class WorkflowState(TypedDict):
# Set of loop node IDs, used for assigning values in loop nodes # Set of loop node IDs, used for assigning values in loop nodes
cycle_nodes: list cycle_nodes: list
looping: bool looping: Annotated[bool, lambda x, y: x and y]
# Input variables (passed from configured variables) # Input variables (passed from configured variables)
# Uses a deep merge function, supporting nested dict updates (e.g., conv.xxx) # Uses a deep merge function, supporting nested dict updates (e.g., conv.xxx)

View File

@@ -208,17 +208,12 @@ class HttpRequestNode(BaseNode):
retries -= 1 retries -= 1
if retries > 0: if retries > 0:
await asyncio.sleep(self.typed_config.retry.retry_interval / 1000) await asyncio.sleep(self.typed_config.retry.retry_interval / 1000)
elif self.typed_config.error_handle.method == HttpErrorHandle.NONE:
raise e
except Exception as e:
raise RuntimeError(f"HTTP request node exception: {e}")
else: else:
match self.typed_config.error_handle.method: match self.typed_config.error_handle.method:
case HttpErrorHandle.NONE:
logger.warning(
f"Node {self.node_id}: HTTP request failed, returning error response"
)
return HttpRequestNodeOutput(
body="",
status_code=resp.status_code,
headers=resp.headers,
).model_dump()
case HttpErrorHandle.DEFAULT: case HttpErrorHandle.DEFAULT:
logger.warning( logger.warning(
f"Node {self.node_id}: HTTP request failed, returning default result" f"Node {self.node_id}: HTTP request failed, returning default result"
@@ -229,3 +224,4 @@ class HttpRequestNode(BaseNode):
f"Node {self.node_id}: HTTP request failed, switching to error handling branch" f"Node {self.node_id}: HTTP request failed, switching to error handling branch"
) )
return "ERROR" return "ERROR"
raise RuntimeError("http request failed")

View File

@@ -203,15 +203,16 @@ class KnowledgeRetrievalNode(BaseNode):
rs2 = vector_service.search_by_full_text(query=query, top_k=kb_config.top_k, rs2 = vector_service.search_by_full_text(query=query, top_k=kb_config.top_k,
indices=indices, indices=indices,
score_threshold=kb_config.similarity_threshold) score_threshold=kb_config.similarity_threshold)
# Deduplicate hybrid retrieval results # Deduplicate hy brid retrieval results
unique_rs = self._deduplicate_docs(rs1, rs2) unique_rs = self._deduplicate_docs(rs1, rs2)
vector_service.reranker = self.get_reranker_model() vector_service.reranker = self.get_reranker_model()
rs.extend(vector_service.rerank(query=query, docs=unique_rs, top_k=kb_config.top_k)) rs.extend(vector_service.rerank(query=query, docs=unique_rs, top_k=kb_config.top_k))
case _: case _:
raise RuntimeError("Unknown retrieval type") raise RuntimeError("Unknown retrieval type")
vector_service.reranker = self.get_reranker_model() vector_service.reranker = self.get_reranker_model()
# TODO其他重排序方式支持
final_rs = vector_service.rerank(query=query, docs=rs, top_k=self.typed_config.reranker_top_k) final_rs = vector_service.rerank(query=query, docs=rs, top_k=self.typed_config.reranker_top_k)
logger.info( logger.info(
f"Node {self.node_id}: knowledge base retrieval completed, results count: {len(final_rs)}" f"Node {self.node_id}: knowledge base retrieval completed, results count: {len(final_rs)}"
) )
return [chunk.model_dump() for chunk in final_rs] return [chunk.page_content for chunk in final_rs]

View File

@@ -1,5 +1,7 @@
"""LLM 节点配置""" """LLM 节点配置"""
from typing import Any
from pydantic import BaseModel, Field, field_validator from pydantic import BaseModel, Field, field_validator
from app.core.workflow.nodes.base_config import BaseNodeConfig, VariableDefinition, VariableType from app.core.workflow.nodes.base_config import BaseNodeConfig, VariableDefinition, VariableType
@@ -7,17 +9,17 @@ from app.core.workflow.nodes.base_config import BaseNodeConfig, VariableDefiniti
class MessageConfig(BaseModel): class MessageConfig(BaseModel):
"""消息配置""" """消息配置"""
role: str = Field( role: str = Field(
..., ...,
description="消息角色system, user, assistant" description="消息角色system, user, assistant"
) )
content: str = Field( content: str = Field(
..., ...,
description="消息内容,支持模板变量,如:{{ sys.message }}" description="消息内容,支持模板变量,如:{{ sys.message }}"
) )
@field_validator("role") @field_validator("role")
@classmethod @classmethod
def validate_role(cls, v: str) -> str: def validate_role(cls, v: str) -> str:
@@ -35,24 +37,29 @@ class LLMNodeConfig(BaseNodeConfig):
1. 简单模式:使用 prompt 字段 1. 简单模式:使用 prompt 字段
2. 消息模式:使用 messages 字段(推荐) 2. 消息模式:使用 messages 字段(推荐)
""" """
model_id: str = Field( model_id: str = Field(
..., ...,
description="模型配置 ID" description="模型配置 ID"
) )
context: Any = Field(
default="",
description="上下文"
)
# 简单模式 # 简单模式
prompt: str | None = Field( prompt: str | None = Field(
default=None, default=None,
description="提示词模板(简单模式),支持变量引用" description="提示词模板(简单模式),支持变量引用"
) )
# 消息模式(推荐) # 消息模式(推荐)
messages: list[MessageConfig] | None = Field( messages: list[MessageConfig] | None = Field(
default=None, default=None,
description="消息列表(消息模式),支持多轮对话" description="消息列表(消息模式),支持多轮对话"
) )
# 模型参数 # 模型参数
temperature: float | None = Field( temperature: float | None = Field(
default=0.7, default=0.7,
@@ -60,35 +67,35 @@ class LLMNodeConfig(BaseNodeConfig):
le=2.0, le=2.0,
description="温度参数,控制输出的随机性" description="温度参数,控制输出的随机性"
) )
max_tokens: int | None = Field( max_tokens: int | None = Field(
default=1000, default=1000,
ge=1, ge=1,
le=32000, le=32000,
description="最大生成 token 数" description="最大生成 token 数"
) )
top_p: float | None = Field( top_p: float | None = Field(
default=None, default=None,
ge=0.0, ge=0.0,
le=1.0, le=1.0,
description="Top-p 采样参数" description="Top-p 采样参数"
) )
frequency_penalty: float | None = Field( frequency_penalty: float | None = Field(
default=None, default=None,
ge=-2.0, ge=-2.0,
le=2.0, le=2.0,
description="频率惩罚" description="频率惩罚"
) )
presence_penalty: float | None = Field( presence_penalty: float | None = Field(
default=None, default=None,
ge=-2.0, ge=-2.0,
le=2.0, le=2.0,
description="存在惩罚" description="存在惩罚"
) )
# 输出变量定义 # 输出变量定义
output_variables: list[VariableDefinition] = Field( output_variables: list[VariableDefinition] = Field(
default_factory=lambda: [ default_factory=lambda: [
@@ -105,14 +112,14 @@ class LLMNodeConfig(BaseNodeConfig):
], ],
description="输出变量定义(自动生成,通常不需要修改)" description="输出变量定义(自动生成,通常不需要修改)"
) )
@field_validator("messages", "prompt") @field_validator("messages", "prompt")
@classmethod @classmethod
def validate_input_mode(cls, v, info): def validate_input_mode(cls, v, info):
"""验证输入模式prompt 和 messages 至少有一个""" """验证输入模式prompt 和 messages 至少有一个"""
# 这个验证在 model_validator 中更合适 # 这个验证在 model_validator 中更合适
return v return v
class Config: class Config:
json_schema_extra = { json_schema_extra = {
"examples": [ "examples": [

View File

@@ -5,15 +5,17 @@ LLM 节点实现
""" """
import logging import logging
import re
from typing import Any from typing import Any
from langchain_core.messages import AIMessage, SystemMessage, HumanMessage from langchain_core.messages import AIMessage, SystemMessage, HumanMessage
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
from app.core.models import RedBearLLM, RedBearModelConfig from app.core.models import RedBearLLM, RedBearModelConfig
from app.core.workflow.nodes.llm.config import LLMNodeConfig
from app.db import get_db_context from app.db import get_db_context
from app.models import ModelType from app.models import ModelType
from app.services.model_service import ModelConfigService from app.services.model_service import ModelConfigService
from app.core.exceptions import BusinessException from app.core.exceptions import BusinessException
from app.core.error_codes import BizCode from app.core.error_codes import BizCode
@@ -63,8 +65,15 @@ class LLMNode(BaseNode):
- user/human: 用户消息HumanMessage - user/human: 用户消息HumanMessage
- ai/assistant: AI 消息AIMessage - ai/assistant: AI 消息AIMessage
""" """
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
def _prepare_llm(self, state: WorkflowState,stream:bool = False) -> tuple[RedBearLLM, list | str]: super().__init__(node_config, workflow_config)
self.typed_config = LLMNodeConfig(**self.config)
def _render_context(self, message,state):
context = f"<context>{self._render_template(self.typed_config.context, state)}</context>"
return re.sub(r"{{context}}", context, message)
def _prepare_llm(self, state: WorkflowState, stream: bool = False) -> tuple[RedBearLLM, list | str]:
"""准备 LLM 实例(公共逻辑) """准备 LLM 实例(公共逻辑)
Args: Args:
@@ -76,15 +85,16 @@ class LLMNode(BaseNode):
# 1. 处理消息格式(优先使用 messages # 1. 处理消息格式(优先使用 messages
messages_config = self.config.get("messages") messages_config = self.config.get("messages")
if messages_config: if messages_config:
# 使用 LangChain 消息格式 # 使用 LangChain 消息格式
messages = [] messages = []
for msg_config in messages_config: for msg_config in messages_config:
role = msg_config.get("role", "user").lower() role = msg_config.get("role", "user").lower()
content_template = msg_config.get("content", "") content_template = msg_config.get("content", "")
content_template = self._render_context(content_template, state)
content = self._render_template(content_template, state) content = self._render_template(content_template, state)
# 根据角色创建对应的消息对象 # 根据角色创建对应的消息对象
if role == "system": if role == "system":
messages.append(SystemMessage(content=content)) messages.append(SystemMessage(content=content))
@@ -95,7 +105,7 @@ class LLMNode(BaseNode):
else: else:
logger.warning(f"未知的消息角色: {role},默认使用 user") logger.warning(f"未知的消息角色: {role},默认使用 user")
messages.append(HumanMessage(content=content)) messages.append(HumanMessage(content=content))
prompt_or_messages = messages prompt_or_messages = messages
else: else:
# 使用简单的 prompt 格式(向后兼容) # 使用简单的 prompt 格式(向后兼容)
@@ -106,17 +116,17 @@ class LLMNode(BaseNode):
model_id = self.config.get("model_id") model_id = self.config.get("model_id")
if not model_id: if not model_id:
raise ValueError(f"节点 {self.node_id} 缺少 model_id 配置") raise ValueError(f"节点 {self.node_id} 缺少 model_id 配置")
# 3. 在 with 块内完成所有数据库操作和数据提取 # 3. 在 with 块内完成所有数据库操作和数据提取
with get_db_context() as db: with get_db_context() as db:
config = ModelConfigService.get_model_by_id(db=db, model_id=model_id) config = ModelConfigService.get_model_by_id(db=db, model_id=model_id)
if not config: if not config:
raise BusinessException("配置的模型不存在", BizCode.NOT_FOUND) raise BusinessException("配置的模型不存在", BizCode.NOT_FOUND)
if not config.api_keys or len(config.api_keys) == 0: if not config.api_keys or len(config.api_keys) == 0:
raise BusinessException("模型配置缺少 API Key", BizCode.INVALID_PARAMETER) raise BusinessException("模型配置缺少 API Key", BizCode.INVALID_PARAMETER)
# 在 Session 关闭前提取所有需要的数据 # 在 Session 关闭前提取所有需要的数据
api_config = config.api_keys[0] api_config = config.api_keys[0]
model_name = api_config.model_name model_name = api_config.model_name
@@ -124,26 +134,26 @@ class LLMNode(BaseNode):
api_key = api_config.api_key api_key = api_config.api_key
api_base = api_config.api_base api_base = api_config.api_base
model_type = config.type model_type = config.type
# 4. 创建 LLM 实例(使用已提取的数据) # 4. 创建 LLM 实例(使用已提取的数据)
# 注意:对于流式输出,需要在模型初始化时设置 streaming=True # 注意:对于流式输出,需要在模型初始化时设置 streaming=True
extra_params = {"streaming": stream} if stream else {} extra_params = {"streaming": stream} if stream else {}
llm = RedBearLLM( llm = RedBearLLM(
RedBearModelConfig( RedBearModelConfig(
model_name=model_name, model_name=model_name,
provider=provider, provider=provider,
api_key=api_key, api_key=api_key,
base_url=api_base, base_url=api_base,
extra_params=extra_params extra_params=extra_params
), ),
type=ModelType(model_type) type=ModelType(model_type)
) )
logger.debug(f"创建 LLM 实例: provider={provider}, model={model_name}, streaming={stream}") logger.debug(f"创建 LLM 实例: provider={provider}, model={model_name}, streaming={stream}")
return llm, prompt_or_messages return llm, prompt_or_messages
async def execute(self, state: WorkflowState) -> AIMessage: async def execute(self, state: WorkflowState) -> AIMessage:
"""非流式执行 LLM 调用 """非流式执行 LLM 调用
@@ -153,10 +163,10 @@ class LLMNode(BaseNode):
Returns: Returns:
LLM 响应消息 LLM 响应消息
""" """
llm, prompt_or_messages = self._prepare_llm(state,True) llm, prompt_or_messages = self._prepare_llm(state, True)
logger.info(f"节点 {self.node_id} 开始执行 LLM 调用(非流式)") logger.info(f"节点 {self.node_id} 开始执行 LLM 调用(非流式)")
# 调用 LLM支持字符串或消息列表 # 调用 LLM支持字符串或消息列表
response = await llm.ainvoke(prompt_or_messages) response = await llm.ainvoke(prompt_or_messages)
# 提取内容 # 提取内容
@@ -164,16 +174,16 @@ class LLMNode(BaseNode):
content = response.content content = response.content
else: else:
content = str(response) content = str(response)
logger.info(f"节点 {self.node_id} LLM 调用完成,输出长度: {len(content)}") logger.info(f"节点 {self.node_id} LLM 调用完成,输出长度: {len(content)}")
# 返回 AIMessage包含响应元数据 # 返回 AIMessage包含响应元数据
return response if isinstance(response, AIMessage) else AIMessage(content=content) return response if isinstance(response, AIMessage) else AIMessage(content=content)
def _extract_input(self, state: WorkflowState) -> dict[str, Any]: def _extract_input(self, state: WorkflowState) -> dict[str, Any]:
"""提取输入数据(用于记录)""" """提取输入数据(用于记录)"""
_, prompt_or_messages = self._prepare_llm(state) _, prompt_or_messages = self._prepare_llm(state)
return { return {
"prompt": prompt_or_messages if isinstance(prompt_or_messages, str) else None, "prompt": prompt_or_messages if isinstance(prompt_or_messages, str) else None,
"messages": [ "messages": [
@@ -186,13 +196,13 @@ class LLMNode(BaseNode):
"max_tokens": self.config.get("max_tokens") "max_tokens": self.config.get("max_tokens")
} }
} }
def _extract_output(self, business_result: Any) -> str: def _extract_output(self, business_result: Any) -> str:
"""从 AIMessage 中提取文本内容""" """从 AIMessage 中提取文本内容"""
if isinstance(business_result, AIMessage): if isinstance(business_result, AIMessage):
return business_result.content return business_result.content
return str(business_result) return str(business_result)
def _extract_token_usage(self, business_result: Any) -> dict[str, int] | None: def _extract_token_usage(self, business_result: Any) -> dict[str, int] | None:
"""从 AIMessage 中提取 token 使用情况""" """从 AIMessage 中提取 token 使用情况"""
if isinstance(business_result, AIMessage) and hasattr(business_result, 'response_metadata'): if isinstance(business_result, AIMessage) and hasattr(business_result, 'response_metadata'):
@@ -204,7 +214,7 @@ class LLMNode(BaseNode):
"total_tokens": usage.get('total_tokens', 0) "total_tokens": usage.get('total_tokens', 0)
} }
return None return None
async def execute_stream(self, state: WorkflowState): async def execute_stream(self, state: WorkflowState):
"""流式执行 LLM 调用 """流式执行 LLM 调用
@@ -215,26 +225,26 @@ class LLMNode(BaseNode):
文本片段chunk或完成标记 文本片段chunk或完成标记
""" """
from langgraph.config import get_stream_writer from langgraph.config import get_stream_writer
llm, prompt_or_messages = self._prepare_llm(state, True) llm, prompt_or_messages = self._prepare_llm(state, True)
logger.info(f"节点 {self.node_id} 开始执行 LLM 调用(流式)") logger.info(f"节点 {self.node_id} 开始执行 LLM 调用(流式)")
logger.debug(f"LLM 配置: streaming={getattr(llm._model, 'streaming', 'unknown')}") logger.debug(f"LLM 配置: streaming={getattr(llm._model, 'streaming', 'unknown')}")
# 检查是否有注入的 End 节点前缀配置 # 检查是否有注入的 End 节点前缀配置
writer = get_stream_writer() writer = get_stream_writer()
end_prefix = getattr(self, '_end_node_prefix', None) end_prefix = getattr(self, '_end_node_prefix', None)
logger.info(f"[LLM前缀] 节点 {self.node_id} 检查前缀配置: {end_prefix is not None}") logger.info(f"[LLM前缀] 节点 {self.node_id} 检查前缀配置: {end_prefix is not None}")
if end_prefix: if end_prefix:
logger.info(f"[LLM前缀] 前缀内容: '{end_prefix}'") logger.info(f"[LLM前缀] 前缀内容: '{end_prefix}'")
if end_prefix: if end_prefix:
# 渲染前缀(可能包含其他变量) # 渲染前缀(可能包含其他变量)
try: try:
rendered_prefix = self._render_template(end_prefix, state) rendered_prefix = self._render_template(end_prefix, state)
logger.info(f"节点 {self.node_id} 提前发送 End 节点前缀: '{rendered_prefix[:50]}...'") logger.info(f"节点 {self.node_id} 提前发送 End 节点前缀: '{rendered_prefix[:50]}...'")
# 提前发送 End 节点的前缀(使用 "message" 类型) # 提前发送 End 节点的前缀(使用 "message" 类型)
writer({ writer({
"type": "message", # End 相关的内容都是 message 类型 "type": "message", # End 相关的内容都是 message 类型
@@ -246,12 +256,12 @@ class LLMNode(BaseNode):
}) })
except Exception as e: except Exception as e:
logger.warning(f"渲染/发送 End 节点前缀失败: {e}") logger.warning(f"渲染/发送 End 节点前缀失败: {e}")
# 累积完整响应 # 累积完整响应
full_response = "" full_response = ""
last_chunk = None last_chunk = None
chunk_count = 0 chunk_count = 0
# 调用 LLM流式支持字符串或消息列表 # 调用 LLM流式支持字符串或消息列表
async for chunk in llm.astream(prompt_or_messages): async for chunk in llm.astream(prompt_or_messages):
# 提取内容 # 提取内容
@@ -259,18 +269,18 @@ class LLMNode(BaseNode):
content = chunk.content content = chunk.content
else: else:
content = str(chunk) content = str(chunk)
# 只有当内容不为空时才处理 # 只有当内容不为空时才处理
if content: if content:
full_response += content full_response += content
last_chunk = chunk last_chunk = chunk
chunk_count += 1 chunk_count += 1
# 流式返回每个文本片段 # 流式返回每个文本片段
yield content yield content
logger.info(f"节点 {self.node_id} LLM 调用完成,输出长度: {len(full_response)}, 总 chunks: {chunk_count}") logger.info(f"节点 {self.node_id} LLM 调用完成,输出长度: {len(full_response)}, 总 chunks: {chunk_count}")
# 构建完整的 AIMessage包含元数据 # 构建完整的 AIMessage包含元数据
if isinstance(last_chunk, AIMessage): if isinstance(last_chunk, AIMessage):
final_message = AIMessage( final_message = AIMessage(
@@ -279,6 +289,6 @@ class LLMNode(BaseNode):
) )
else: else:
final_message = AIMessage(content=full_response) final_message = AIMessage(content=full_response)
# yield 完成标记 # yield 完成标记
yield {"__final__": True, "result": final_message} yield {"__final__": True, "result": final_message}

View File

@@ -24,7 +24,7 @@ class MemoryReadNode(BaseNode):
return await MemoryAgentService().read_memory( return await MemoryAgentService().read_memory(
group_id=end_user_id, group_id=end_user_id,
message=self.typed_config.message, message=self._render_template(self.typed_config.message, state),
config_id=self.typed_config.config_id, config_id=self.typed_config.config_id,
search_switch=self.typed_config.search_switch, search_switch=self.typed_config.search_switch,
history=[], history=[],
@@ -51,7 +51,7 @@ class MemoryWriteNode(BaseNode):
return await MemoryAgentService().write_memory( return await MemoryAgentService().write_memory(
group_id=end_user_id, group_id=end_user_id,
message=self.typed_config.message, message=self._render_template(self.typed_config.message, state),
config_id=self.typed_config.config_id, config_id=self.typed_config.config_id,
db=db, db=db,
storage_type="neo4j", storage_type="neo4j",

View File

@@ -87,10 +87,11 @@ class WorkflowValidator:
return graphs return graphs
@classmethod @classmethod
def validate(cls, workflow_config: Union[dict[str, Any], Any]) -> tuple[bool, list[str]]: def validate(cls, workflow_config: Union[dict[str, Any], Any], publish=False) -> tuple[bool, list[str]]:
"""验证工作流配置 """验证工作流配置
Args: Args:
publish: 发布验证标识
workflow_config: 工作流配置字典或 WorkflowConfig Pydantic 模型 workflow_config: 工作流配置字典或 WorkflowConfig Pydantic 模型
Returns: Returns:
@@ -114,7 +115,7 @@ class WorkflowValidator:
graphs = cls.get_subgraph(workflow_config) graphs = cls.get_subgraph(workflow_config)
logger.info(graphs) logger.info(graphs)
for graph in graphs: for index, graph in enumerate(graphs):
nodes = graph.get("nodes", []) nodes = graph.get("nodes", [])
edges = graph.get("edges", []) edges = graph.get("edges", [])
variables = graph.get("variables", []) variables = graph.get("variables", [])
@@ -125,10 +126,11 @@ class WorkflowValidator:
elif len(start_nodes) > 1: elif len(start_nodes) > 1:
errors.append(f"工作流只能有一个 start 节点,当前有 {len(start_nodes)}") errors.append(f"工作流只能有一个 start 节点,当前有 {len(start_nodes)}")
# 2. 验证 end 节点(至少一个) if index == len(graphs) - 1:
end_nodes = [n for n in nodes if n.get("type") == NodeType.END] # 2. 验证 主图end 节点(至少一个)
if len(end_nodes) == 0: end_nodes = [n for n in nodes if n.get("type") == NodeType.END]
errors.append("工作流必须至少有一个 end 节点") if len(end_nodes) == 0:
errors.append("工作流必须至少有一个 end 节点")
# 3. 验证节点 ID 唯一性 # 3. 验证节点 ID 唯一性
node_ids = [n.get("id") for n in nodes] node_ids = [n.get("id") for n in nodes]
@@ -159,15 +161,17 @@ class WorkflowValidator:
elif target not in node_id_set: elif target not in node_id_set:
errors.append(f"边 #{i} 的 target 节点不存在: {target}") errors.append(f"边 #{i} 的 target 节点不存在: {target}")
# 6. 验证所有节点可达(从 start 节点出发) if publish:
if start_nodes and not errors: # 只有在前面验证通过时才检查可达 # 仅在发布时验证所有节点可达
reachable = WorkflowValidator._get_reachable_nodes( # 6. 验证所有节点可达(从 start 节点出发)
start_nodes[0]["id"], if start_nodes and not errors: # 只有在前面验证通过时才检查可达性
edges reachable = WorkflowValidator._get_reachable_nodes(
) start_nodes[0]["id"],
unreachable = node_id_set - reachable edges
if unreachable: )
errors.append(f"以下节点无法从 start 节点到达: {unreachable}") unreachable = node_id_set - reachable
if unreachable:
errors.append(f"以下节点无法从 start 节点到达: {unreachable}")
# 7. 检测循环依赖(非 loop 节点) # 7. 检测循环依赖(非 loop 节点)
if not errors: # 只有在前面验证通过时才检查循环 if not errors: # 只有在前面验证通过时才检查循环
@@ -288,7 +292,7 @@ class WorkflowValidator:
(is_valid, errors): 是否有效和错误列表 (is_valid, errors): 是否有效和错误列表
""" """
# 先执行基础验证 # 先执行基础验证
is_valid, errors = WorkflowValidator.validate(workflow_config) is_valid, errors = WorkflowValidator.validate(workflow_config, publish=True)
if not is_valid: if not is_valid:
return False, errors return False, errors

View File

@@ -231,9 +231,9 @@ class PromptOptimizerService:
if m: if m:
prompt_index = m.start() prompt_index = m.start()
prompt_finished = True prompt_finished = True
yield {"type": "delta", "content": buffer[idx:prompt_index]} yield {"content": buffer[idx:prompt_index]}
else: else:
yield {"type": "delta", "content": cache[idx:]} yield {"content": cache[idx:]}
if len(cache) != 0: if len(cache) != 0:
idx = len(cache) idx = len(cache)
@@ -249,8 +249,8 @@ class PromptOptimizerService:
role=RoleType.ASSISTANT, role=RoleType.ASSISTANT,
content=desc content=desc
) )
variables = self.parser_prompt_variables(optim_result.get("prompt"))
yield {"type": "done", "desc": optim_result.get("desc")} yield {"desc": optim_result.get("desc"), "variables": variables}
@staticmethod @staticmethod
def parser_prompt_variables(prompt: str): def parser_prompt_variables(prompt: str):