Fix/workflow (#92)

* fix(workflow): use loose rendering for end-node variables

* fix(workflow): use int type for memory node config id

* fix(workflow): handle missing environment variable defaults

* fix(workflow): render jinja variables with actual values in non-strict mode

* fix(workflow): support reordering without a rerank model in knowledge base

* fix(workflow): fix typo in key value
This commit is contained in:
Eternity
2026-01-13 15:42:00 +08:00
committed by GitHub
parent 2ba8bb58e0
commit 7a5792ba01
13 changed files with 179 additions and 115 deletions

View File

@@ -14,6 +14,7 @@ from langgraph.graph.state import CompiledStateGraph
from app.core.workflow.graph_builder import GraphBuilder from app.core.workflow.graph_builder import GraphBuilder
from app.core.workflow.nodes import WorkflowState from app.core.workflow.nodes import WorkflowState
from app.core.workflow.nodes.base_config import VariableType
from app.core.workflow.nodes.enums import NodeType from app.core.workflow.nodes.enums import NodeType
# from app.core.tools.registry import ToolRegistry # from app.core.tools.registry import ToolRegistry
@@ -78,9 +79,21 @@ class WorkflowExecutor:
var_name = var_def.get("name") var_name = var_def.get("name")
var_default = var_def.get("default") var_default = var_def.get("default")
if var_name: if var_name:
# TODO: 入参类型校验 if var_default:
conversation_vars[var_name] = var_default conversation_vars[var_name] = var_default
else:
var_type = var_def.get("type")
match var_type:
case VariableType.STRING:
conversation_vars[var_name] = ""
case VariableType.NUMBER:
conversation_vars[var_name] = 0
case VariableType.OBJECT:
conversation_vars[var_name] = {}
case VariableType.BOOLEAN:
conversation_vars[var_name] = False
case VariableType.ARRAY_NUMBER | VariableType.ARRAY_OBJECT | VariableType.ARRAY_BOOLEAN | VariableType.ARRAY_STRING:
conversation_vars[var_name] = []
input_variables = input_data.get("variables") or {} # Start 节点的自定义变量 input_variables = input_data.get("variables") or {} # Start 节点的自定义变量
# 构建分层的变量结构 # 构建分层的变量结构
@@ -362,7 +375,7 @@ class WorkflowExecutor:
inputv = payload.get("input", {}) inputv = payload.get("input", {})
variables = inputv.get("variables", {}) variables = inputv.get("variables", {})
variables_sys = variables.get("sys", {}) variables_sys = variables.get("sys", {})
conversation_id = variables_sys.get("conversation_id") conversation_id = input_data.get("conversation_id")
execution_id = variables_sys.get("execution_id") execution_id = variables_sys.get("execution_id")
logger.info(f"[DEBUG] Node starts execution: {node_name}") logger.info(f"[DEBUG] Node starts execution: {node_name}")
@@ -381,7 +394,7 @@ class WorkflowExecutor:
inputv = result.get("input", {}) inputv = result.get("input", {})
variables = inputv.get("variables", {}) variables = inputv.get("variables", {})
variables_sys = variables.get("sys", {}) variables_sys = variables.get("sys", {})
conversation_id = variables_sys.get("conversation_id") conversation_id = input_data.get("conversation_id")
execution_id = variables_sys.get("execution_id") execution_id = variables_sys.get("execution_id")
logger.info(f"[DEBUG] Node execution completed: {node_name}") logger.info(f"[DEBUG] Node execution completed: {node_name}")

View File

@@ -12,7 +12,6 @@ from app.core.workflow.nodes.enums import NodeType
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# TODO: 子图拆解支持
class GraphBuilder: class GraphBuilder:
def __init__( def __init__(
self, self,

View File

@@ -45,6 +45,7 @@ class AssignerNode(BaseNode):
# Get the value or expression to assign # Get the value or expression to assign
value = assignment.value value = assignment.value
logger.debug(f"left:{variable_selector}, right: {value}")
pattern = r"\{\{\s*(.*?)\s*\}\}" pattern = r"\{\{\s*(.*?)\s*\}\}"
if isinstance(value, str): if isinstance(value, str):
expression = re.match(pattern, value) expression = re.match(pattern, value)
@@ -85,4 +86,3 @@ class AssignerNode(BaseNode):
case _: case _:
raise ValueError(f"Invalid Operator: {assignment.operation}") raise ValueError(f"Invalid Operator: {assignment.operation}")
logger.info(f"Node {self.node_id}: execution completed") logger.info(f"Node {self.node_id}: execution completed")

View File

@@ -259,7 +259,8 @@ class BaseNode(ABC):
# Determine chunk type: "message" for End and adjacent nodes, "node_chunk" for others # Determine chunk type: "message" for End and adjacent nodes, "node_chunk" for others
chunk_type = "message" if (is_end_node or is_adjacent_to_end) else "node_chunk" chunk_type = "message" if (is_end_node or is_adjacent_to_end) else "node_chunk"
logger.debug(f"节点 {self.node_id} chunk 类型: {chunk_type} (is_end={is_end_node}, adjacent={is_adjacent_to_end})") logger.debug(
f"节点 {self.node_id} chunk 类型: {chunk_type} (is_end={is_end_node}, adjacent={is_adjacent_to_end})")
# Accumulate complete result (for final wrapping) # Accumulate complete result (for final wrapping)
chunks = [] chunks = []
@@ -386,10 +387,10 @@ class BaseNode(ABC):
yield error_output yield error_output
def _wrap_output( def _wrap_output(
self, self,
business_result: Any, business_result: Any,
elapsed_time: float, elapsed_time: float,
state: WorkflowState state: WorkflowState
) -> dict[str, Any]: ) -> dict[str, Any]:
"""将业务结果包装成标准输出格式 """将业务结果包装成标准输出格式
@@ -430,10 +431,10 @@ class BaseNode(ABC):
} }
def _wrap_error( def _wrap_error(
self, self,
error_message: str, error_message: str,
elapsed_time: float, elapsed_time: float,
state: WorkflowState state: WorkflowState
) -> dict[str, Any]: ) -> dict[str, Any]:
"""将错误包装成标准输出格式 """将错误包装成标准输出格式
@@ -534,7 +535,7 @@ class BaseNode(ABC):
return edge return edge
return None return None
def _render_template(self, template: str, state: WorkflowState | None, struct: bool = True) -> str: def _render_template(self, template: str, state: WorkflowState | None, strict: bool = True) -> str:
"""渲染模板 """渲染模板
支持的变量命名空间: 支持的变量命名空间:
@@ -569,7 +570,7 @@ class BaseNode(ABC):
variables=variables, variables=variables,
node_outputs=pool.get_all_node_outputs(), node_outputs=pool.get_all_node_outputs(),
system_vars=pool.get_all_system_vars(), system_vars=pool.get_all_system_vars(),
struct=struct strict=strict
) )
def _evaluate_condition(self, expression: str, state: WorkflowState | None) -> bool: def _evaluate_condition(self, expression: str, state: WorkflowState | None) -> bool:
@@ -628,10 +629,10 @@ class BaseNode(ABC):
return VariablePool(state) return VariablePool(state)
def get_variable( def get_variable(
self, self,
selector: list[str] | str, selector: list[str] | str,
state: WorkflowState, state: WorkflowState,
default: Any = None default: Any = None
) -> Any: ) -> Any:
"""获取变量值(便捷方法) """获取变量值(便捷方法)

View File

@@ -37,7 +37,7 @@ class EndNode(BaseNode):
# 如果配置了输出模板,使用模板渲染;否则使用默认输出 # 如果配置了输出模板,使用模板渲染;否则使用默认输出
if output_template: if output_template:
output = self._render_template(output_template, state, struct=False) output = self._render_template(output_template, state, strict=False)
else: else:
output = "工作流已完成" output = "工作流已完成"
@@ -156,6 +156,16 @@ class EndNode(BaseNode):
if not output_template: if not output_template:
output = "工作流已完成" output = "工作流已完成"
from langgraph.config import get_stream_writer
writer = get_stream_writer()
writer({
"type": "message", # End node output uses message type
"node_id": self.node_id,
"chunk": "",
"full_content": output,
"chunk_index": 1,
"is_suffix": False
})
yield {"__final__": True, "result": output} yield {"__final__": True, "result": output}
return return
@@ -190,7 +200,7 @@ class EndNode(BaseNode):
if upstream_llm_ref_index is None: if upstream_llm_ref_index is None:
# No reference to direct upstream LLM node, output complete template content # No reference to direct upstream LLM node, output complete template content
output = self._render_template(output_template, state) output = self._render_template(output_template, state, strict=False)
logger.info(f"节点 {self.node_id} 没有引用直接上游 LLM 节点,输出完整内容: '{output[:50]}...'") logger.info(f"节点 {self.node_id} 没有引用直接上游 LLM 节点,输出完整内容: '{output[:50]}...'")
# Send complete content via writer (as a single message chunk) # Send complete content via writer (as a single message chunk)
@@ -246,7 +256,7 @@ class EndNode(BaseNode):
suffix = "".join(suffix_parts) suffix = "".join(suffix_parts)
# 构建完整输出(用于返回,包含前缀 + 动态内容 + 后缀) # 构建完整输出(用于返回,包含前缀 + 动态内容 + 后缀)
full_output = self._render_template(output_template, state) full_output = self._render_template(output_template, state, strict=False)
logger.info(f"[后缀调试] 节点 {self.node_id} 后缀部分数量: {len(suffix_parts)}") logger.info(f"[后缀调试] 节点 {self.node_id} 后缀部分数量: {len(suffix_parts)}")
logger.info(f"[后缀调试] 后缀内容: '{suffix}'") logger.info(f"[后缀调试] 后缀内容: '{suffix}'")

View File

@@ -38,7 +38,11 @@ class JinjaRenderNode(BaseNode):
context = {} context = {}
for variable in self.typed_config.mapping: for variable in self.typed_config.mapping:
context[variable.name] = self._render_template(variable.value, state) try:
context[variable.name] = self.get_variable(variable.value, state)
except Exception:
logger.info(f"variable not found, var: {variable.value}")
continue
try: try:
res = render.env.from_string(self.typed_config.template).render(**context) res = render.env.from_string(self.typed_config.template).render(**context)

View File

@@ -45,7 +45,7 @@ class KnowledgeRetrievalNodeConfig(BaseNodeConfig):
) )
reranker_id: UUID = Field( reranker_id: UUID = Field(
..., default="",
description="Reranker top k" description="Reranker top k"
) )

View File

@@ -203,19 +203,34 @@ class KnowledgeRetrievalNode(BaseNode):
rs2 = vector_service.search_by_full_text(query=query, top_k=kb_config.top_k, rs2 = vector_service.search_by_full_text(query=query, top_k=kb_config.top_k,
indices=indices, indices=indices,
score_threshold=kb_config.similarity_threshold) score_threshold=kb_config.similarity_threshold)
# Deduplicate hybrid retrieval results # Deduplicate hybrid retrieval results
unique_rs = self._deduplicate_docs(rs1, rs2) unique_rs = self._deduplicate_docs(rs1, rs2)
if not unique_rs: if not unique_rs:
continue continue
vector_service.reranker = self.get_reranker_model() if self.typed_config.reranker_id:
rs.extend(vector_service.rerank(query=query, docs=unique_rs, top_k=kb_config.top_k)) vector_service.reranker = self.get_reranker_model()
rs.extend(vector_service.rerank(query=query, docs=unique_rs, top_k=kb_config.top_k))
else:
rs.extend(sorted(
unique_rs,
key=lambda d: d.metadata.get("score", 0),
reverse=True
)[:kb_config.top_k])
case _: case _:
raise RuntimeError("Unknown retrieval type") raise RuntimeError("Unknown retrieval type")
if not rs: if not rs:
return [] return []
vector_service.reranker = self.get_reranker_model() if self.typed_config.reranker_id:
# TODO其他重排序方式支持 vector_service.reranker = self.get_reranker_model()
final_rs = vector_service.rerank(query=query, docs=rs, top_k=self.typed_config.reranker_top_k) final_rs = vector_service.rerank(query=query, docs=rs, top_k=self.typed_config.reranker_top_k)
else:
final_rs = sorted(
rs,
key=lambda d: d.metadata.get("score", 0),
reverse=True
)[:self.typed_config.reranker_top_k]
logger.info( logger.info(
f"Node {self.node_id}: knowledge base retrieval completed, results count: {len(final_rs)}" f"Node {self.node_id}: knowledge base retrieval completed, results count: {len(final_rs)}"
) )

View File

@@ -11,7 +11,7 @@ class MemoryReadNodeConfig(BaseNodeConfig):
... ...
) )
config_id: str = Field( config_id: int = Field(
... ...
) )
@@ -26,6 +26,6 @@ class MemoryWriteNodeConfig(BaseNodeConfig):
... ...
) )
config_id: str = Field( config_id: int = Field(
... ...
) )

View File

@@ -25,7 +25,7 @@ class MemoryReadNode(BaseNode):
return await MemoryAgentService().read_memory( return await MemoryAgentService().read_memory(
group_id=end_user_id, group_id=end_user_id,
message=self._render_template(self.typed_config.message, state), message=self._render_template(self.typed_config.message, state),
config_id=self.typed_config.config_id, config_id=str(self.typed_config.config_id),
search_switch=self.typed_config.search_switch, search_switch=self.typed_config.search_switch,
history=[], history=[],
db=db, db=db,
@@ -52,7 +52,7 @@ class MemoryWriteNode(BaseNode):
return await MemoryAgentService().write_memory( return await MemoryAgentService().write_memory(
group_id=end_user_id, group_id=end_user_id,
message=self._render_template(self.typed_config.message, state), message=self._render_template(self.typed_config.message, state),
config_id=self.typed_config.config_id, config_id=str(self.typed_config.config_id),
db=db, db=db,
storage_type="neo4j", storage_type="neo4j",
user_rag_memory_id="" user_rag_memory_id=""

View File

@@ -386,7 +386,10 @@ class ArrayComparisonOperator(ConditionBase):
return self.right_value not in self.left_value return self.right_value not in self.left_value
class NoneObjectComparisonOperator(ConditionBase): class NoneObjectComparisonOperator:
def __init__(self, *arg, **kwargs):
pass
def __getattr__(self, name): def __getattr__(self, name):
return lambda *args, **kwargs: False return lambda *args, **kwargs: False

View File

@@ -5,6 +5,7 @@
""" """
import logging import logging
from collections import defaultdict
from typing import Any from typing import Any
from jinja2 import TemplateSyntaxError, UndefinedError, Environment, StrictUndefined, Undefined from jinja2 import TemplateSyntaxError, UndefinedError, Environment, StrictUndefined, Undefined
@@ -12,6 +13,18 @@ from jinja2 import TemplateSyntaxError, UndefinedError, Environment, StrictUndef
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class SafeUndefined(Undefined):
    """Undefined type used for non-strict template rendering.

    Instead of raising, the operations overridden here evaluate to an
    empty string: attribute access, item access, the arithmetic operators
    listed below, and ``str()`` / ``repr()``.

    Dunder attribute lookups (names starting with ``__``) still raise
    ``AttributeError``. Without that guard, protocol probes such as
    ``hasattr(obj, "__html__")`` would succeed and hand back a
    non-callable ``""``, which makes callers that then invoke the
    attribute (for example the ``escape`` filter) fail with ``TypeError``.
    """
    __slots__ = ()

    def _fail_with_undefined_error(self, *args, **kwargs):
        """Return an empty string instead of raising an undefined error."""
        return ""

    def __getattr__(self, name):
        """Return ``""`` for ordinary attributes; reject dunder lookups.

        Raises:
            AttributeError: if ``name`` starts with ``__``, so that
                special-method probes see the attribute as absent.
        """
        if name[:2] == "__":
            raise AttributeError(name)
        return ""

    # Operators are rebound explicitly: they must point at the override
    # defined in this class body to pick up the empty-string behaviour.
    # ``__div__`` / ``__rdiv__`` are Python 2 names, retained from the
    # original definition.
    __add__ = __radd__ = _fail_with_undefined_error
    __mul__ = __rmul__ = _fail_with_undefined_error
    __div__ = __rdiv__ = _fail_with_undefined_error
    __truediv__ = __rtruediv__ = _fail_with_undefined_error
    __getitem__ = _fail_with_undefined_error
    __str__ = __repr__ = lambda self: ""
class TemplateRenderer: class TemplateRenderer:
"""模板渲染器""" """模板渲染器"""
@@ -21,8 +34,9 @@ class TemplateRenderer:
Args: Args:
strict: 是否使用严格模式(未定义变量会抛出异常) strict: 是否使用严格模式(未定义变量会抛出异常)
""" """
self.strict = strict
self.env = Environment( self.env = Environment(
undefined=StrictUndefined if strict else Undefined, undefined=StrictUndefined if strict else SafeUndefined,
autoescape=False # 不自动转义,因为我们处理的是文本而非 HTML autoescape=False # 不自动转义,因为我们处理的是文本而非 HTML
) )
@@ -69,12 +83,17 @@ class TemplateRenderer:
# variables 的结构:{"sys": {...}, "conv": {...}} # variables 的结构:{"sys": {...}, "conv": {...}}
sys_vars = variables.get("sys", {}) if isinstance(variables, dict) else {} sys_vars = variables.get("sys", {}) if isinstance(variables, dict) else {}
conv_vars = variables.get("conv", {}) if isinstance(variables, dict) else {} conv_vars = variables.get("conv", {}) if isinstance(variables, dict) else {}
if self.strict:
context = { context = defaultdict(dict)
"conv": conv_vars, # 会话变量:{{conv.user_name}} context["conv"] = conv_vars
"node": node_outputs, # 节点输出:{{node.node_1.output}} context["node"] = node_outputs
"sys": {**(system_vars or {}), **sys_vars}, # 系统变量:{{sys.execution_id}}(合并两个来源) context["sys"] = {**(system_vars or {}), **sys_vars}
} else:
context = {
"conv": conv_vars, # 会话变量:{{conv.user_name}}
"node": node_outputs, # 节点输出:{{node.node_1.output}}
"sys": {**(system_vars or {}), **sys_vars}, # 系统变量:{{sys.execution_id}}(合并两个来源)
}
# 支持直接通过节点ID访问节点输出{{llm_qa.output}} # 支持直接通过节点ID访问节点输出{{llm_qa.output}}
# 将所有节点输出添加到顶层上下文 # 将所有节点输出添加到顶层上下文
@@ -141,12 +160,12 @@ def render_template(
variables: dict[str, Any], variables: dict[str, Any],
node_outputs: dict[str, Any], node_outputs: dict[str, Any],
system_vars: dict[str, Any] | None = None, system_vars: dict[str, Any] | None = None,
struct: bool = True strict: bool = True
) -> str: ) -> str:
"""渲染模板(便捷函数) """渲染模板(便捷函数)
Args: Args:
struct: 渲染模式 strict: 严格模式
template: 模板字符串 template: 模板字符串
variables: 用户变量 variables: 用户变量
node_outputs: 节点输出 node_outputs: 节点输出
@@ -164,7 +183,7 @@ def render_template(
... ) ... )
'请分析: 这是一段文本' '请分析: 这是一段文本'
""" """
renderer = TemplateRenderer(strict=struct) renderer = TemplateRenderer(strict=strict)
return renderer.render(template, variables, node_outputs, system_vars) return renderer.render(template, variables, node_outputs, system_vars)

View File

@@ -53,7 +53,7 @@ nodes:
type: end type: end
name: 结束 name: 结束
config: config:
output: "{{llm_qa.output}}" output: "{{ llm_qa.output }}"
position: position:
x: 900 x: 900
y: 100 y: 100