Fix/workflow (#92)
* fix(workflow): use loose rendering for end-node variables * fix(workflow): use int type for memory node config id * fix(workflow): handle missing environment variable defaults * fix(workflow): render jinja variables with actual values in non-strict mode * fix(workflow): support reordering without a rerank model in knowledge base * fix(workflow): fix typo in key value
This commit is contained in:
@@ -14,6 +14,7 @@ from langgraph.graph.state import CompiledStateGraph
|
||||
|
||||
from app.core.workflow.graph_builder import GraphBuilder
|
||||
from app.core.workflow.nodes import WorkflowState
|
||||
from app.core.workflow.nodes.base_config import VariableType
|
||||
from app.core.workflow.nodes.enums import NodeType
|
||||
|
||||
# from app.core.tools.registry import ToolRegistry
|
||||
@@ -78,9 +79,21 @@ class WorkflowExecutor:
|
||||
var_name = var_def.get("name")
|
||||
var_default = var_def.get("default")
|
||||
if var_name:
|
||||
# TODO: 入参类型校验
|
||||
if var_default:
|
||||
conversation_vars[var_name] = var_default
|
||||
|
||||
else:
|
||||
var_type = var_def.get("type")
|
||||
match var_type:
|
||||
case VariableType.STRING:
|
||||
conversation_vars[var_name] = ""
|
||||
case VariableType.NUMBER:
|
||||
conversation_vars[var_name] = 0
|
||||
case VariableType.OBJECT:
|
||||
conversation_vars[var_name] = {}
|
||||
case VariableType.BOOLEAN:
|
||||
conversation_vars[var_name] = False
|
||||
case VariableType.ARRAY_NUMBER | VariableType.ARRAY_OBJECT | VariableType.ARRAY_BOOLEAN | VariableType.ARRAY_STRING:
|
||||
conversation_vars[var_name] = []
|
||||
input_variables = input_data.get("variables") or {} # Start 节点的自定义变量
|
||||
|
||||
# 构建分层的变量结构
|
||||
@@ -362,7 +375,7 @@ class WorkflowExecutor:
|
||||
inputv = payload.get("input", {})
|
||||
variables = inputv.get("variables", {})
|
||||
variables_sys = variables.get("sys", {})
|
||||
conversation_id = variables_sys.get("conversation_id")
|
||||
conversation_id = input_data.get("conversation_id")
|
||||
execution_id = variables_sys.get("execution_id")
|
||||
logger.info(f"[DEBUG] Node starts execution: {node_name}")
|
||||
|
||||
@@ -381,7 +394,7 @@ class WorkflowExecutor:
|
||||
inputv = result.get("input", {})
|
||||
variables = inputv.get("variables", {})
|
||||
variables_sys = variables.get("sys", {})
|
||||
conversation_id = variables_sys.get("conversation_id")
|
||||
conversation_id = input_data.get("conversation_id")
|
||||
execution_id = variables_sys.get("execution_id")
|
||||
logger.info(f"[DEBUG] Node execution completed: {node_name}")
|
||||
|
||||
|
||||
@@ -12,7 +12,6 @@ from app.core.workflow.nodes.enums import NodeType
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# TODO: 子图拆解支持
|
||||
class GraphBuilder:
|
||||
def __init__(
|
||||
self,
|
||||
|
||||
@@ -45,6 +45,7 @@ class AssignerNode(BaseNode):
|
||||
|
||||
# Get the value or expression to assign
|
||||
value = assignment.value
|
||||
logger.debug(f"left:{variable_selector}, right: {value}")
|
||||
pattern = r"\{\{\s*(.*?)\s*\}\}"
|
||||
if isinstance(value, str):
|
||||
expression = re.match(pattern, value)
|
||||
@@ -85,4 +86,3 @@ class AssignerNode(BaseNode):
|
||||
case _:
|
||||
raise ValueError(f"Invalid Operator: {assignment.operation}")
|
||||
logger.info(f"Node {self.node_id}: execution completed")
|
||||
|
||||
|
||||
@@ -259,7 +259,8 @@ class BaseNode(ABC):
|
||||
# Determine chunk type: "message" for End and adjacent nodes, "node_chunk" for others
|
||||
chunk_type = "message" if (is_end_node or is_adjacent_to_end) else "node_chunk"
|
||||
|
||||
logger.debug(f"节点 {self.node_id} chunk 类型: {chunk_type} (is_end={is_end_node}, adjacent={is_adjacent_to_end})")
|
||||
logger.debug(
|
||||
f"节点 {self.node_id} chunk 类型: {chunk_type} (is_end={is_end_node}, adjacent={is_adjacent_to_end})")
|
||||
|
||||
# Accumulate complete result (for final wrapping)
|
||||
chunks = []
|
||||
@@ -534,7 +535,7 @@ class BaseNode(ABC):
|
||||
return edge
|
||||
return None
|
||||
|
||||
def _render_template(self, template: str, state: WorkflowState | None, struct: bool = True) -> str:
|
||||
def _render_template(self, template: str, state: WorkflowState | None, strict: bool = True) -> str:
|
||||
"""渲染模板
|
||||
|
||||
支持的变量命名空间:
|
||||
@@ -569,7 +570,7 @@ class BaseNode(ABC):
|
||||
variables=variables,
|
||||
node_outputs=pool.get_all_node_outputs(),
|
||||
system_vars=pool.get_all_system_vars(),
|
||||
struct=struct
|
||||
strict=strict
|
||||
)
|
||||
|
||||
def _evaluate_condition(self, expression: str, state: WorkflowState | None) -> bool:
|
||||
|
||||
@@ -37,7 +37,7 @@ class EndNode(BaseNode):
|
||||
|
||||
# 如果配置了输出模板,使用模板渲染;否则使用默认输出
|
||||
if output_template:
|
||||
output = self._render_template(output_template, state, struct=False)
|
||||
output = self._render_template(output_template, state, strict=False)
|
||||
else:
|
||||
output = "工作流已完成"
|
||||
|
||||
@@ -156,6 +156,16 @@ class EndNode(BaseNode):
|
||||
|
||||
if not output_template:
|
||||
output = "工作流已完成"
|
||||
from langgraph.config import get_stream_writer
|
||||
writer = get_stream_writer()
|
||||
writer({
|
||||
"type": "message", # End node output uses message type
|
||||
"node_id": self.node_id,
|
||||
"chunk": "",
|
||||
"full_content": output,
|
||||
"chunk_index": 1,
|
||||
"is_suffix": False
|
||||
})
|
||||
yield {"__final__": True, "result": output}
|
||||
return
|
||||
|
||||
@@ -190,7 +200,7 @@ class EndNode(BaseNode):
|
||||
|
||||
if upstream_llm_ref_index is None:
|
||||
# No reference to direct upstream LLM node, output complete template content
|
||||
output = self._render_template(output_template, state)
|
||||
output = self._render_template(output_template, state, strict=False)
|
||||
logger.info(f"节点 {self.node_id} 没有引用直接上游 LLM 节点,输出完整内容: '{output[:50]}...'")
|
||||
|
||||
# Send complete content via writer (as a single message chunk)
|
||||
@@ -246,7 +256,7 @@ class EndNode(BaseNode):
|
||||
suffix = "".join(suffix_parts)
|
||||
|
||||
# 构建完整输出(用于返回,包含前缀 + 动态内容 + 后缀)
|
||||
full_output = self._render_template(output_template, state)
|
||||
full_output = self._render_template(output_template, state, strict=False)
|
||||
|
||||
logger.info(f"[后缀调试] 节点 {self.node_id} 后缀部分数量: {len(suffix_parts)}")
|
||||
logger.info(f"[后缀调试] 后缀内容: '{suffix}'")
|
||||
|
||||
@@ -38,7 +38,11 @@ class JinjaRenderNode(BaseNode):
|
||||
|
||||
context = {}
|
||||
for variable in self.typed_config.mapping:
|
||||
context[variable.name] = self._render_template(variable.value, state)
|
||||
try:
|
||||
context[variable.name] = self.get_variable(variable.value, state)
|
||||
except Exception:
|
||||
logger.info(f"variable not found, var: {variable.value}")
|
||||
continue
|
||||
|
||||
try:
|
||||
res = render.env.from_string(self.typed_config.template).render(**context)
|
||||
|
||||
@@ -45,7 +45,7 @@ class KnowledgeRetrievalNodeConfig(BaseNodeConfig):
|
||||
)
|
||||
|
||||
reranker_id: UUID = Field(
|
||||
...,
|
||||
default="",
|
||||
description="Reranker top k"
|
||||
)
|
||||
|
||||
|
||||
@@ -203,19 +203,34 @@ class KnowledgeRetrievalNode(BaseNode):
|
||||
rs2 = vector_service.search_by_full_text(query=query, top_k=kb_config.top_k,
|
||||
indices=indices,
|
||||
score_threshold=kb_config.similarity_threshold)
|
||||
|
||||
# Deduplicate hy brid retrieval results
|
||||
unique_rs = self._deduplicate_docs(rs1, rs2)
|
||||
if not unique_rs:
|
||||
continue
|
||||
if self.typed_config.reranker_id:
|
||||
vector_service.reranker = self.get_reranker_model()
|
||||
rs.extend(vector_service.rerank(query=query, docs=unique_rs, top_k=kb_config.top_k))
|
||||
else:
|
||||
rs.extend(sorted(
|
||||
unique_rs,
|
||||
key=lambda d: d.metadata.get("score", 0),
|
||||
reverse=True
|
||||
)[:kb_config.top_k])
|
||||
case _:
|
||||
raise RuntimeError("Unknown retrieval type")
|
||||
if not rs:
|
||||
return []
|
||||
if self.typed_config.reranker_id:
|
||||
vector_service.reranker = self.get_reranker_model()
|
||||
# TODO:其他重排序方式支持
|
||||
final_rs = vector_service.rerank(query=query, docs=rs, top_k=self.typed_config.reranker_top_k)
|
||||
else:
|
||||
final_rs = sorted(
|
||||
rs,
|
||||
key=lambda d: d.metadata.get("score", 0),
|
||||
reverse=True
|
||||
)[:self.typed_config.reranker_top_k]
|
||||
|
||||
logger.info(
|
||||
f"Node {self.node_id}: knowledge base retrieval completed, results count: {len(final_rs)}"
|
||||
)
|
||||
|
||||
@@ -11,7 +11,7 @@ class MemoryReadNodeConfig(BaseNodeConfig):
|
||||
...
|
||||
)
|
||||
|
||||
config_id: str = Field(
|
||||
config_id: int = Field(
|
||||
...
|
||||
)
|
||||
|
||||
@@ -26,6 +26,6 @@ class MemoryWriteNodeConfig(BaseNodeConfig):
|
||||
...
|
||||
)
|
||||
|
||||
config_id: str = Field(
|
||||
config_id: int = Field(
|
||||
...
|
||||
)
|
||||
|
||||
@@ -25,7 +25,7 @@ class MemoryReadNode(BaseNode):
|
||||
return await MemoryAgentService().read_memory(
|
||||
group_id=end_user_id,
|
||||
message=self._render_template(self.typed_config.message, state),
|
||||
config_id=self.typed_config.config_id,
|
||||
config_id=str(self.typed_config.config_id),
|
||||
search_switch=self.typed_config.search_switch,
|
||||
history=[],
|
||||
db=db,
|
||||
@@ -52,7 +52,7 @@ class MemoryWriteNode(BaseNode):
|
||||
return await MemoryAgentService().write_memory(
|
||||
group_id=end_user_id,
|
||||
message=self._render_template(self.typed_config.message, state),
|
||||
config_id=self.typed_config.config_id,
|
||||
config_id=str(self.typed_config.config_id),
|
||||
db=db,
|
||||
storage_type="neo4j",
|
||||
user_rag_memory_id=""
|
||||
|
||||
@@ -386,7 +386,10 @@ class ArrayComparisonOperator(ConditionBase):
|
||||
return self.right_value not in self.left_value
|
||||
|
||||
|
||||
class NoneObjectComparisonOperator(ConditionBase):
|
||||
class NoneObjectComparisonOperator:
|
||||
def __init__(self, *arg, **kwargs):
|
||||
pass
|
||||
|
||||
def __getattr__(self, name):
|
||||
return lambda *args, **kwargs: False
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
"""
|
||||
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
from typing import Any
|
||||
|
||||
from jinja2 import TemplateSyntaxError, UndefinedError, Environment, StrictUndefined, Undefined
|
||||
@@ -12,6 +13,18 @@ from jinja2 import TemplateSyntaxError, UndefinedError, Environment, StrictUndef
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SafeUndefined(Undefined):
|
||||
"""访问未定义属性不会报错,返回空字符串"""
|
||||
__slots__ = ()
|
||||
|
||||
def _fail_with_undefined_error(self, *args, **kwargs):
|
||||
return ""
|
||||
|
||||
__add__ = __radd__ = __mul__ = __rmul__ = __div__ = __rdiv__ = __truediv__ = __rtruediv__ = _fail_with_undefined_error
|
||||
__getitem__ = __getattr__ = _fail_with_undefined_error
|
||||
__str__ = __repr__ = lambda self: ""
|
||||
|
||||
|
||||
class TemplateRenderer:
|
||||
"""模板渲染器"""
|
||||
|
||||
@@ -21,8 +34,9 @@ class TemplateRenderer:
|
||||
Args:
|
||||
strict: 是否使用严格模式(未定义变量会抛出异常)
|
||||
"""
|
||||
self.strict = strict
|
||||
self.env = Environment(
|
||||
undefined=StrictUndefined if strict else Undefined,
|
||||
undefined=StrictUndefined if strict else SafeUndefined,
|
||||
autoescape=False # 不自动转义,因为我们处理的是文本而非 HTML
|
||||
)
|
||||
|
||||
@@ -69,7 +83,12 @@ class TemplateRenderer:
|
||||
# variables 的结构:{"sys": {...}, "conv": {...}}
|
||||
sys_vars = variables.get("sys", {}) if isinstance(variables, dict) else {}
|
||||
conv_vars = variables.get("conv", {}) if isinstance(variables, dict) else {}
|
||||
|
||||
if self.strict:
|
||||
context = defaultdict(dict)
|
||||
context["conv"] = conv_vars
|
||||
context["node"] = node_outputs
|
||||
context["sys"] = {**(system_vars or {}), **sys_vars}
|
||||
else:
|
||||
context = {
|
||||
"conv": conv_vars, # 会话变量:{{conv.user_name}}
|
||||
"node": node_outputs, # 节点输出:{{node.node_1.output}}
|
||||
@@ -141,12 +160,12 @@ def render_template(
|
||||
variables: dict[str, Any],
|
||||
node_outputs: dict[str, Any],
|
||||
system_vars: dict[str, Any] | None = None,
|
||||
struct: bool = True
|
||||
strict: bool = True
|
||||
) -> str:
|
||||
"""渲染模板(便捷函数)
|
||||
|
||||
Args:
|
||||
struct: 渲染模式
|
||||
strict: 严格模式
|
||||
template: 模板字符串
|
||||
variables: 用户变量
|
||||
node_outputs: 节点输出
|
||||
@@ -164,7 +183,7 @@ def render_template(
|
||||
... )
|
||||
'请分析: 这是一段文本'
|
||||
"""
|
||||
renderer = TemplateRenderer(strict=struct)
|
||||
renderer = TemplateRenderer(strict=strict)
|
||||
return renderer.render(template, variables, node_outputs, system_vars)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user