Merge branch 'develop' into refactor/memory_search

# Conflicts:
#	api/app/core/memory/storage_services/search/__init__.py
This commit is contained in:
Eternity
2026-04-20 17:49:29 +08:00
202 changed files with 6621 additions and 1690 deletions

View File

@@ -201,12 +201,15 @@ class VariablePool:
@staticmethod
def _extract_field(struct: "VariableStruct", field: str | None) -> Any:
"""If field is given, drill into a dict/object variable's value."""
"""If field is given, drill into a dict/object/array[file] variable's value."""
if field is None:
return struct.instance.get_value()
value = struct.instance.get_value()
# array[file]: extract the field from every element, return a list
if isinstance(value, list):
return [item.get(field) if isinstance(item, dict) else getattr(item, field, None) for item in value]
if not isinstance(value, dict):
raise KeyError(f"Variable is not an object, cannot access field '{field}'")
raise KeyError(f"Variable is not an object or array, cannot access field '{field}'")
return value.get(field)
def get_instance(

View File

@@ -28,86 +28,135 @@ class IterationRuntime:
def __init__(
self,
start_id: str,
stream: bool,
graph: CompiledStateGraph,
node_id: str,
config: dict[str, Any],
state: WorkflowState,
variable_pool: VariablePool,
child_variable_pool: VariablePool,
cycle_nodes: list,
cycle_edges: list,
):
"""
Initialize the iteration runtime.
Args:
graph: Compiled workflow graph capable of async invocation.
node_id: Unique identifier of the loop node.
config: Dictionary containing iteration node configuration.
state: Current workflow state at the point of iteration.
stream: Whether to run in streaming mode. When True, each iteration
uses graph.astream and emits cycle_item events in real time.
When False, graph.ainvoke is used instead.
node_id: The unique identifier of the iteration node in the workflow.
Also used as the variable namespace for item/index inside
the subgraph (e.g. {{ node_id.item }}).
config: Raw configuration dict for the iteration node, parsed into
IterationNodeConfig. Controls input/output variable selectors,
parallel execution settings, and output flattening.
state: The parent workflow state at the point the iteration node is
entered. Each task receives a copy of this state as its
starting point.
variable_pool: The parent VariablePool containing all variables available
at the time the iteration node executes, including sys.*,
conv.*, and outputs from upstream nodes. Used as the source
for deep-copying into each task's independent child pool.
cycle_nodes: List of node config dicts belonging to this iteration's
subgraph (i.e. nodes whose cycle field equals node_id).
Passed to GraphBuilder when constructing each task's subgraph.
cycle_edges: List of edge config dicts connecting nodes within the subgraph.
Passed to GraphBuilder alongside cycle_nodes.
"""
self.start_id = start_id
self.stream = stream
self.graph = graph
self.state = state
self.node_id = node_id
self.typed_config = IterationNodeConfig(**config)
self.looping = True
self.variable_pool = variable_pool
self.child_variable_pool = child_variable_pool
self.cycle_nodes = cycle_nodes
self.cycle_edges = cycle_edges
self.event_write = get_stream_writer()
self.checkpoint = RunnableConfig(
configurable={
"thread_id": uuid.uuid4()
}
)
self.output_value = None
self.result: list = []
async def _init_iteration_state(self, item, idx):
def _build_child_graph(self) -> tuple[CompiledStateGraph, VariablePool, str]:
"""
Initialize a per-iteration copy of the workflow state.
Build an independent compiled subgraph for a single iteration task.
Args:
item: Current element from the input array for this iteration.
idx: Index of the element in the input array.
Each call creates a brand-new VariablePool by deep-copying the parent pool,
then passes it to GraphBuilder. GraphBuilder binds this pool to every node's
execution closure at build time, so the pool and the subgraph always reference
the same object. This is the key design invariant: item/index written into the
pool after build will be visible to all nodes inside the subgraph.
Returns:
A copy of the workflow state with iteration-specific variables set.
graph: The compiled LangGraph subgraph ready for invocation.
child_pool: The VariablePool bound to this subgraph's node closures.
Callers must write item/index into this pool before invoking
the graph, and read output from it after invocation.
start_node_id: The ID of the CYCLE_START node inside the subgraph,
used to set the initial activation signal in workflow state.
"""
loopstate = WorkflowState(
**self.state
from app.core.workflow.engine.graph_builder import GraphBuilder
child_pool = VariablePool()
child_pool.copy(self.variable_pool)
builder = GraphBuilder(
{"nodes": self.cycle_nodes, "edges": self.cycle_edges},
stream=self.stream,
variable_pool=child_pool,
cycle=self.node_id,
)
self.child_variable_pool.copy(self.variable_pool)
await self.child_variable_pool.new(self.node_id, "item", item, VariableType.type_map(item), mut=True)
await self.child_variable_pool.new(self.node_id, "index", item, VariableType.type_map(item), mut=True)
loopstate["node_outputs"][self.node_id] = {
"item": item,
"index": idx,
}
graph = builder.build()
return graph, builder.variable_pool, builder.start_node_id
async def _init_iteration_state(self, item, idx, child_pool: VariablePool, start_id: str):
"""
Initialize the workflow state for a single iteration.
Writes the current item and its index into child_pool under the iteration
node's namespace (e.g. iteration_xxx.item, iteration_xxx.index), making them
accessible to downstream nodes inside the subgraph via variable selectors.
Also prepares a copy of the parent workflow state with:
- node_outputs[node_id] set to {item, index} so the state snapshot is consistent
with the pool values.
- looping flag set to 1 (active) to signal the subgraph is inside a cycle.
- activate[start_id] set to True to trigger the CYCLE_START node.
Args:
item: The current element from the input array.
idx: The zero-based index of this element in the input array.
child_pool: The VariablePool bound to this iteration's subgraph.
Must be the same object returned by _build_child_graph.
start_id: The ID of the CYCLE_START node inside the subgraph.
Returns:
A WorkflowState instance ready to be passed to graph.ainvoke or graph.astream.
"""
loopstate = WorkflowState(**self.state)
await child_pool.new(self.node_id, "item", item, VariableType.type_map(item), mut=True)
await child_pool.new(self.node_id, "index", idx, VariableType.type_map(idx), mut=True)
loopstate["node_outputs"][self.node_id] = {"item": item, "index": idx}
loopstate["looping"] = 1
loopstate["activate"][self.start_id] = True
loopstate["activate"][start_id] = True
return loopstate
def merge_conv_vars(self):
self.variable_pool.variables["conv"].update(
self.child_variable_pool.variables["conv"]
)
def _merge_conv_vars(self, child_pool: VariablePool):
self.variable_pool.variables["conv"].update(child_pool.variables["conv"])
async def run_task(self, item, idx):
"""
Execute a single iteration asynchronously.
Each task builds its own subgraph so the variable pool closure is independent.
Args:
item: The input element for this iteration.
idx: The index of this iteration.
Returns:
Tuple of (idx, output, result, child_pool, stopped)
"""
graph, child_pool, start_id = self._build_child_graph()
checkpoint = RunnableConfig(configurable={"thread_id": uuid.uuid4()})
init_state = await self._init_iteration_state(item, idx, child_pool, start_id)
if self.stream:
async for event in self.graph.astream(
await self._init_iteration_state(item, idx),
async for event in graph.astream(
init_state,
stream_mode=["debug"],
config=self.checkpoint
config=checkpoint
):
if isinstance(event, tuple) and len(event) == 2:
mode, data = event
@@ -117,7 +166,6 @@ class IterationRuntime:
event_type = data.get("type")
payload = data.get("payload", {})
node_name = payload.get("name")
if node_name and node_name.startswith("nop"):
continue
if event_type == "task_result":
@@ -140,17 +188,13 @@ class IterationRuntime:
"token_usage": result.get("node_outputs", {}).get(node_name, {}).get("token_usage")
}
})
result = self.graph.get_state(config=self.checkpoint).values
result = graph.get_state(config=checkpoint).values
else:
result = await self.graph.ainvoke(await self._init_iteration_state(item, idx))
output = self.child_variable_pool.get_value(self.output_value)
if isinstance(output, list) and self.typed_config.flatten:
self.result.extend(output)
else:
self.result.append(output)
if result["looping"] == 2:
self.looping = False
return result
result = await graph.ainvoke(init_state)
output = child_pool.get_value(self.output_value)
stopped = result["looping"] == 2
return idx, output, result, child_pool, stopped
def _create_iteration_tasks(self, array_obj, idx):
"""
@@ -196,16 +240,32 @@ class IterationRuntime:
tasks = self._create_iteration_tasks(array_obj, idx)
logger.info(f"Iteration node {self.node_id}: running, concurrency {len(tasks)}")
idx += self.typed_config.parallel_count
child_state.extend(await asyncio.gather(*tasks))
self.merge_conv_vars()
batch = await asyncio.gather(*tasks)
# Sort by idx to preserve order, then collect results
batch_sorted = sorted(batch, key=lambda x: x[0])
for _, output, result, child_pool, stopped in batch_sorted:
if isinstance(output, list) and self.typed_config.flatten:
self.result.extend(output)
else:
self.result.append(output)
child_state.append(result)
self._merge_conv_vars(child_pool)
if stopped:
self.looping = False
else:
# Execute iterations sequentially
while idx < len(array_obj) and self.looping:
logger.info(f"Iteration node {self.node_id}: running")
item = array_obj[idx]
result = await self.run_task(item, idx)
self.merge_conv_vars()
_, output, result, child_pool, stopped = await self.run_task(item, idx)
if isinstance(output, list) and self.typed_config.flatten:
self.result.extend(output)
else:
self.result.append(output)
self._merge_conv_vars(child_pool)
child_state.append(result)
if stopped:
self.looping = False
idx += 1
logger.info(f"Iteration node {self.node_id}: execution completed")
return {

View File

@@ -123,7 +123,7 @@ class CycleGraphNode(BaseNode):
return cycle_nodes, cycle_edges
def build_graph(self):
def build_graph(self, variable_pool: VariablePool):
"""
Build and compile the internal subgraph for this cycle node.
@@ -135,6 +135,7 @@ class CycleGraphNode(BaseNode):
from app.core.workflow.engine.graph_builder import GraphBuilder
self.child_variable_pool = VariablePool()
self.child_variable_pool.copy(variable_pool)
builder = GraphBuilder(
{
"nodes": self.cycle_nodes,
@@ -165,8 +166,8 @@ class CycleGraphNode(BaseNode):
Raises:
RuntimeError: If the node type is unsupported.
"""
self.build_graph()
if self.node_type == NodeType.LOOP:
self.build_graph(variable_pool)
return await LoopRuntime(
start_id=self.start_node_id,
stream=False,
@@ -179,20 +180,19 @@ class CycleGraphNode(BaseNode):
).run()
if self.node_type == NodeType.ITERATION:
return await IterationRuntime(
start_id=self.start_node_id,
stream=False,
graph=self.graph,
node_id=self.node_id,
config=self.config,
state=state,
variable_pool=variable_pool,
child_variable_pool=self.child_variable_pool
cycle_nodes=self.cycle_nodes,
cycle_edges=self.cycle_edges,
).run()
raise RuntimeError("Unknown cycle node type")
async def execute_stream(self, state: WorkflowState, variable_pool: VariablePool):
self.build_graph()
if self.node_type == NodeType.LOOP:
self.build_graph(variable_pool)
yield {
"__final__": True,
"result": await LoopRuntime(
@@ -211,14 +211,13 @@ class CycleGraphNode(BaseNode):
yield {
"__final__": True,
"result": await IterationRuntime(
start_id=self.start_node_id,
stream=True,
graph=self.graph,
node_id=self.node_id,
config=self.config,
state=state,
variable_pool=variable_pool,
child_variable_pool=self.child_variable_pool
cycle_nodes=self.cycle_nodes,
cycle_edges=self.cycle_edges,
).run()
}
return

View File

@@ -72,8 +72,9 @@ class HttpContentTypeConfig(BaseModel):
@classmethod
def validate_data(cls, v, info):
content_type = info.data.get("content_type")
if content_type == HttpContentType.FROM_DATA and not isinstance(v, HttpFormData):
raise ValueError("When content_type is 'form-data', data must be of type HttpFormData")
if content_type == HttpContentType.FROM_DATA and (
not isinstance(v, list) or not all(isinstance(item, HttpFormData) for item in v)):
raise ValueError("When content_type is 'form-data', data must be a list of HttpFormData")
elif content_type in [HttpContentType.JSON] and not isinstance(v, str):
raise ValueError("When content_type is JSON, data must be of type str")
elif content_type in [HttpContentType.WWW_FORM] and not isinstance(v, dict):

View File

@@ -260,17 +260,22 @@ class HttpRequestNode(BaseNode):
))
case HttpContentType.FROM_DATA:
data = {}
content["files"] = {}
files = []
for item in self.typed_config.body.data:
key = self._render_template(item.key, variable_pool)
if item.type == "text":
data[self._render_template(item.key, variable_pool)] = self._render_template(item.value,
variable_pool)
data[key] = self._render_template(item.value, variable_pool)
elif item.type == "file":
content["files"][self._render_template(item.key, variable_pool)] = (
uuid.uuid4().hex,
await variable_pool.get_instance(item.value).get_content()
)
file_instance = variable_pool.get_instance(item.value)
if isinstance(file_instance, ArrayVariable):
for v in file_instance.value:
if isinstance(v, FileVariable):
files.append((key, (uuid.uuid4().hex, await v.get_content())))
elif isinstance(file_instance, FileVariable):
files.append((key, (uuid.uuid4().hex, await file_instance.get_content())))
content["data"] = data
if files:
content["files"] = files
case HttpContentType.BINARY:
content["files"] = []
file_instence = variable_pool.get_instance(self.typed_config.body.data)

View File

@@ -6,6 +6,30 @@ from app.core.workflow.nodes.base_config import BaseNodeConfig
from app.core.workflow.nodes.enums import ComparisonOperator, LogicOperator, ValueInputType
class SubVariableConditionItem(BaseModel):
"""A single condition on a file object's field, used inside sub_variable_condition."""
key: str = Field(..., description="Field name of the file object, e.g. type, size, name")
operator: ComparisonOperator = Field(..., description="Comparison operator")
value: Any = Field(default=None, description="Value to compare with, or variable selector when input_type=variable")
input_type: ValueInputType = Field(default=ValueInputType.CONSTANT, description="constant or variable")
@field_validator("input_type", mode="before")
@classmethod
def lower_input_type(cls, v):
if isinstance(v, str):
try:
return ValueInputType(v.lower())
except ValueError:
raise ValueError(f"Invalid input_type: {v}")
return v
class SubVariableCondition(BaseModel):
"""Sub-conditions applied to each file element in an array[file] variable."""
logical_operator: LogicOperator = Field(default=LogicOperator.AND)
conditions: list[SubVariableConditionItem] = Field(default_factory=list)
class ConditionDetail(BaseModel):
operator: ComparisonOperator = Field(
...,
@@ -14,12 +38,12 @@ class ConditionDetail(BaseModel):
left: str = Field(
...,
description="Value to compare against"
description="Variable selector, e.g. {{sys.files}}"
)
right: Any = Field(
default=None,
description="Value to compare with"
description="Value to compare with (unused when sub_variable_condition is set)"
)
input_type: ValueInputType = Field(
@@ -27,6 +51,11 @@ class ConditionDetail(BaseModel):
description="Value input type for comparison"
)
sub_variable_condition: SubVariableCondition | None = Field(
default=None,
description="Sub-conditions for array[file] fields. When set, operator must be contains/not_contains."
)
@field_validator("input_type", mode="before")
@classmethod
def lower_input_type(cls, v):
@@ -39,16 +68,19 @@ class ConditionDetail(BaseModel):
class ConditionBranchConfig(BaseModel):
"""Configuration for a conditional branch"""
"""Configuration for a conditional branch.
logical_operator controls how all expressions are combined (AND/OR).
"""
logical_operator: LogicOperator = Field(
default=LogicOperator.AND,
description="Logical operator used to combine multiple condition expressions"
description="Logical operator used to combine all conditions"
)
expressions: list[ConditionDetail] = Field(
...,
description="List of condition expressions within this branch"
default_factory=list,
description="List of conditions within this branch"
)

View File

@@ -7,7 +7,7 @@ from app.core.workflow.engine.variable_pool import VariablePool
from app.core.workflow.nodes.base_node import BaseNode
from app.core.workflow.nodes.enums import ComparisonOperator, LogicOperator, ValueInputType
from app.core.workflow.nodes.if_else import IfElseNodeConfig
from app.core.workflow.nodes.operators import ConditionExpressionResolver, CompareOperatorInstance
from app.core.workflow.nodes.operators import ConditionExpressionResolver, CompareOperatorInstance, ArrayFileContainsOperator
from app.core.workflow.variable.base_variable import VariableType
logger = logging.getLogger(__name__)
@@ -90,11 +90,9 @@ class IfElseNode(BaseNode):
list[str]: A list of Python boolean expression strings,
ordered by branch priority.
"""
branch_index = 0
conditions = []
for case_branch in self.typed_config.cases:
branch_index += 1
branch_result = []
for expression in case_branch.expressions:
pattern = r"\{\{\s*(.*?)\s*\}\}"
@@ -103,13 +101,18 @@ class IfElseNode(BaseNode):
left_value = self.get_variable(left_string, variable_pool)
except KeyError:
left_value = None
evaluator = ConditionExpressionResolver.resolve_by_value(left_value)(
variable_pool,
expression.left,
expression.right,
expression.input_type
)
if expression.sub_variable_condition is not None and isinstance(left_value, list):
evaluator = ArrayFileContainsOperator(left_value, expression.sub_variable_condition, variable_pool)
else:
evaluator = ConditionExpressionResolver.resolve_by_value(left_value)(
variable_pool,
expression.left,
expression.right,
expression.input_type
)
branch_result.append(self._evaluate(expression.operator, evaluator))
if case_branch.logical_operator == LogicOperator.AND:
conditions.append(all(branch_result))
else:

View File

@@ -116,6 +116,11 @@ class LLMNodeConfig(BaseNodeConfig):
description="Top-p 采样参数"
)
json_output: bool = Field(
default=False,
description="是否以 JSON 格式输出"
)
frequency_penalty: float | None = Field(
default=None,
ge=-2.0,

View File

@@ -22,6 +22,7 @@ from app.db import get_db_context
from app.models import ModelType
from app.schemas.model_schema import ModelInfo
from app.services.model_service import ModelConfigService
from app.models.models_model import ModelProvider
logger = logging.getLogger(__name__)
@@ -126,7 +127,11 @@ class LLMNode(BaseNode):
# 4. 创建 LLM 实例(使用已提取的数据)
# 注意:对于流式输出,需要在模型初始化时设置 streaming=True
extra_params = {"streaming": stream} if stream else {}
extra_params: dict[str, Any] = {"streaming": stream} if stream else {}
if self.typed_config.temperature is not None:
extra_params["temperature"] = self.typed_config.temperature
if self.typed_config.max_tokens is not None:
extra_params["max_tokens"] = self.typed_config.max_tokens
llm = RedBearLLM(
RedBearModelConfig(
@@ -135,7 +140,9 @@ class LLMNode(BaseNode):
api_key=model_info.api_key,
base_url=model_info.api_base,
extra_params=extra_params,
is_omni=model_info.is_omni
is_omni=model_info.is_omni,
capability=model_info.capability,
json_output=self.typed_config.json_output,
),
type=model_info.model_type
)
@@ -218,6 +225,19 @@ class LLMNode(BaseNode):
rendered = self._render_template(prompt_template, variable_pool)
self.messages = [{"role": "user", "content": rendered}]
# ChatTongyi 要求 messages 含 'json' 字样才能使用 response_format在 system prompt 中注入
# VOLCANO 模型不支持 response_format同样需要 system prompt 注入
need_json_prompt = self.typed_config.json_output and (
(model_info.provider.lower() == ModelProvider.DASHSCOPE and not model_info.is_omni)
or model_info.provider.lower() == ModelProvider.VOLCANO
)
if need_json_prompt:
system_msg = next((m for m in self.messages if m["role"] == "system"), None)
if system_msg:
system_msg["content"] += "\n请以JSON格式输出。"
else:
self.messages.insert(0, {"role": "system", "content": "请以JSON格式输出。"})
return llm
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> AIMessage:

View File

@@ -395,11 +395,73 @@ class NoneObjectComparisonOperator:
return lambda *args, **kwargs: False
class ArrayFileContainsOperator:
"""Handles contains/not_contains on array[file] with sub_variable_condition."""
def __init__(self, left_value: list[dict], sub_variable_condition: Any, pool: VariablePool | None = None):
self.left_value = left_value
self.sub_variable_condition = sub_variable_condition
self.pool = pool
def _resolve_value(self, cond: Any) -> Any:
if cond.input_type == ValueInputType.VARIABLE and self.pool is not None:
pattern = r"\{\{\s*(.*?)\s*\}\}"
selector = re.sub(pattern, r"\1", str(cond.value)).strip()
return self.pool.get_value(selector, default=None, strict=False)
return cond.value
def _match_item(self, file_item: dict) -> bool:
results = []
for cond in self.sub_variable_condition.conditions:
field_val = file_item.get(cond.key)
expected = self._resolve_value(cond)
result = self._eval_sub(field_val, cond.operator.value, expected)
results.append(result)
if self.sub_variable_condition.logical_operator.value == "and":
return all(results)
return any(results)
@staticmethod
def _eval_sub(field_val: Any, op: str, expected: Any) -> bool:
if field_val is None:
return op == "empty"
match op:
case "eq": return str(field_val) == str(expected)
case "ne": return str(field_val) != str(expected)
case "contains": return isinstance(field_val, str) and str(expected) in field_val
case "not_contains": return isinstance(field_val, str) and str(expected) not in field_val
case "in": return field_val in (expected if isinstance(expected, list) else [expected])
case "not_in": return field_val not in (expected if isinstance(expected, list) else [expected])
case "gt": return isinstance(field_val, (int, float)) and field_val > float(expected)
case "ge": return isinstance(field_val, (int, float)) and field_val >= float(expected)
case "lt": return isinstance(field_val, (int, float)) and field_val < float(expected)
case "le": return isinstance(field_val, (int, float)) and field_val <= float(expected)
case "empty": return field_val in (None, "", 0)
case "not_empty": return field_val not in (None, "", 0)
case _: return False
def contains(self) -> bool:
return any(self._match_item(f) for f in self.left_value if isinstance(f, dict))
def not_contains(self) -> bool:
return not self.contains()
def empty(self) -> bool:
return not self.left_value
def not_empty(self) -> bool:
return bool(self.left_value)
def __getattr__(self, name):
return lambda *args, **kwargs: False
CompareOperatorInstance = Union[
StringComparisonOperator,
NumberComparisonOperator,
BooleanComparisonOperator,
ArrayComparisonOperator,
ArrayFileContainsOperator,
ObjectComparisonOperator
]
CompareOperatorType = Type[CompareOperatorInstance]

View File

@@ -15,6 +15,7 @@ from app.services.tool_service import ToolService
logger = logging.getLogger(__name__)
TEMPLATE_PATTERN = re.compile(r"\{\{.*?}}")
PURE_VARIABLE_PATTERN = re.compile(r"^\{\{\s*([\w.]+)\s*}}$")
class ToolNode(BaseNode):
@@ -52,13 +53,21 @@ class ToolNode(BaseNode):
# 渲染工具参数
rendered_parameters = {}
for param_name, param_template in self.typed_config.tool_parameters.items():
if isinstance(param_template, str) and TEMPLATE_PATTERN.search(param_template):
try:
rendered_value = self._render_template(param_template, variable_pool)
except Exception as e:
raise ValueError(f"模板渲染失败:参数 {param_name} 的模板 {param_template} 解析错误") from e
if isinstance(param_template, str):
pure_match = PURE_VARIABLE_PATTERN.match(param_template)
if pure_match:
# 纯单变量引用直接取原始值,保留 int/bool/float 等类型
rendered_value = self.get_variable(pure_match.group(1), variable_pool, strict=False)
if rendered_value is None:
rendered_value = self._render_template(param_template, variable_pool)
elif TEMPLATE_PATTERN.search(param_template):
try:
rendered_value = self._render_template(param_template, variable_pool)
except Exception as e:
raise ValueError(f"模板渲染失败:参数 {param_name} 的模板 {param_template} 解析错误") from e
else:
rendered_value = param_template
else:
# 非模板参数(数字/布尔/普通字符串)直接保留原值
rendered_value = param_template
rendered_parameters[param_name] = rendered_value

View File

@@ -84,7 +84,7 @@ class FileVariable(BaseVariable):
total_bytes = 0
chunks = []
async with httpx.AsyncClient() as client:
async with httpx.AsyncClient(follow_redirects=True) as client:
async with client.stream("GET", self.value.url) as resp:
resp.raise_for_status()
async for chunk in resp.aiter_bytes(8192):