Merge pull request #873 from SuanmoSuanyangTechnology/feature/agent-tool_xjn

feat(workflow and app)
This commit is contained in:
山程漫悟
2026-04-13 19:05:10 +08:00
committed by GitHub
8 changed files with 309 additions and 100 deletions

View File

@@ -1250,9 +1250,11 @@ async def export_app(
async def import_app(
file: UploadFile = File(...),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
current_user: User = Depends(get_current_user),
app_id: Optional[str] = Form(None),
):
"""从 YAML 文件导入 agent / multi_agent / workflow 应用。
传入 app_id 时覆盖该应用的配置(类型必须一致),否则创建新应用。
跨空间/跨租户导入时,模型/工具/知识库会按名称匹配,匹配不到则置空并返回 warnings。
"""
if not file.filename.lower().endswith((".yaml", ".yml")):
@@ -1263,13 +1265,15 @@ async def import_app(
if not dsl or "app" not in dsl:
return fail(msg="YAML 格式无效,缺少 app 字段", code=BizCode.BAD_REQUEST)
new_app, warnings = AppDslService(db).import_dsl(
target_app_id = uuid.UUID(app_id) if app_id else None
result_app, warnings = AppDslService(db).import_dsl(
dsl=dsl,
workspace_id=current_user.current_workspace_id,
tenant_id=current_user.tenant_id,
user_id=current_user.id,
app_id=target_app_id,
)
return success(
data={"app": app_schema.App.model_validate(new_app), "warnings": warnings},
data={"app": app_schema.App.model_validate(result_app), "warnings": warnings},
msg="应用导入成功" + (",但部分资源需手动配置" if warnings else "")
)

View File

@@ -28,86 +28,135 @@ class IterationRuntime:
def __init__(
self,
start_id: str,
stream: bool,
graph: CompiledStateGraph,
node_id: str,
config: dict[str, Any],
state: WorkflowState,
variable_pool: VariablePool,
child_variable_pool: VariablePool,
cycle_nodes: list,
cycle_edges: list,
):
"""
Initialize the iteration runtime.
Args:
graph: Compiled workflow graph capable of async invocation.
node_id: Unique identifier of the loop node.
config: Dictionary containing iteration node configuration.
state: Current workflow state at the point of iteration.
stream: Whether to run in streaming mode. When True, each iteration
uses graph.astream and emits cycle_item events in real time.
When False, graph.ainvoke is used instead.
node_id: The unique identifier of the iteration node in the workflow.
Also used as the variable namespace for item/index inside
the subgraph (e.g. {{ node_id.item }}).
config: Raw configuration dict for the iteration node, parsed into
IterationNodeConfig. Controls input/output variable selectors,
parallel execution settings, and output flattening.
state: The parent workflow state at the point the iteration node is
entered. Each task receives a copy of this state as its
starting point.
variable_pool: The parent VariablePool containing all variables available
at the time the iteration node executes, including sys.*,
conv.*, and outputs from upstream nodes. Used as the source
for deep-copying into each task's independent child pool.
cycle_nodes: List of node config dicts belonging to this iteration's
subgraph (i.e. nodes whose cycle field equals node_id).
Passed to GraphBuilder when constructing each task's subgraph.
cycle_edges: List of edge config dicts connecting nodes within the subgraph.
Passed to GraphBuilder alongside cycle_nodes.
"""
self.start_id = start_id
self.stream = stream
self.graph = graph
self.state = state
self.node_id = node_id
self.typed_config = IterationNodeConfig(**config)
self.looping = True
self.variable_pool = variable_pool
self.child_variable_pool = child_variable_pool
self.cycle_nodes = cycle_nodes
self.cycle_edges = cycle_edges
self.event_write = get_stream_writer()
self.checkpoint = RunnableConfig(
configurable={
"thread_id": uuid.uuid4()
}
)
self.output_value = None
self.result: list = []
async def _init_iteration_state(self, item, idx):
def _build_child_graph(self) -> tuple[CompiledStateGraph, VariablePool, str]:
"""
Initialize a per-iteration copy of the workflow state.
Build an independent compiled subgraph for a single iteration task.
Args:
item: Current element from the input array for this iteration.
idx: Index of the element in the input array.
Each call creates a brand-new VariablePool by deep-copying the parent pool,
then passes it to GraphBuilder. GraphBuilder binds this pool to every node's
execution closure at build time, so the pool and the subgraph always reference
the same object. This is the key design invariant: item/index written into the
pool after build will be visible to all nodes inside the subgraph.
Returns:
A copy of the workflow state with iteration-specific variables set.
graph: The compiled LangGraph subgraph ready for invocation.
child_pool: The VariablePool bound to this subgraph's node closures.
Callers must write item/index into this pool before invoking
the graph, and read output from it after invocation.
start_node_id: The ID of the CYCLE_START node inside the subgraph,
used to set the initial activation signal in workflow state.
"""
loopstate = WorkflowState(
**self.state
from app.core.workflow.engine.graph_builder import GraphBuilder
child_pool = VariablePool()
child_pool.copy(self.variable_pool)
builder = GraphBuilder(
{"nodes": self.cycle_nodes, "edges": self.cycle_edges},
stream=self.stream,
variable_pool=child_pool,
cycle=self.node_id,
)
self.child_variable_pool.copy(self.variable_pool)
await self.child_variable_pool.new(self.node_id, "item", item, VariableType.type_map(item), mut=True)
await self.child_variable_pool.new(self.node_id, "index", item, VariableType.type_map(item), mut=True)
loopstate["node_outputs"][self.node_id] = {
"item": item,
"index": idx,
}
graph = builder.build()
return graph, builder.variable_pool, builder.start_node_id
async def _init_iteration_state(self, item, idx, child_pool: VariablePool, start_id: str):
"""
Initialize the workflow state for a single iteration.
Writes the current item and its index into child_pool under the iteration
node's namespace (e.g. iteration_xxx.item, iteration_xxx.index), making them
accessible to downstream nodes inside the subgraph via variable selectors.
Also prepares a copy of the parent workflow state with:
- node_outputs[node_id] set to {item, index} so the state snapshot is consistent
with the pool values.
- looping flag set to 1 (active) to signal the subgraph is inside a cycle.
- activate[start_id] set to True to trigger the CYCLE_START node.
Args:
item: The current element from the input array.
idx: The zero-based index of this element in the input array.
child_pool: The VariablePool bound to this iteration's subgraph.
Must be the same object returned by _build_child_graph.
start_id: The ID of the CYCLE_START node inside the subgraph.
Returns:
A WorkflowState instance ready to be passed to graph.ainvoke or graph.astream.
"""
loopstate = WorkflowState(**self.state)
await child_pool.new(self.node_id, "item", item, VariableType.type_map(item), mut=True)
await child_pool.new(self.node_id, "index", idx, VariableType.type_map(idx), mut=True)
loopstate["node_outputs"][self.node_id] = {"item": item, "index": idx}
loopstate["looping"] = 1
loopstate["activate"][self.start_id] = True
loopstate["activate"][start_id] = True
return loopstate
def merge_conv_vars(self):
self.variable_pool.variables["conv"].update(
self.child_variable_pool.variables["conv"]
)
def _merge_conv_vars(self, child_pool: VariablePool):
self.variable_pool.variables["conv"].update(child_pool.variables["conv"])
async def run_task(self, item, idx):
"""
Execute a single iteration asynchronously.
Each task builds its own subgraph so the variable pool closure is independent.
Args:
item: The input element for this iteration.
idx: The index of this iteration.
Returns:
Tuple of (idx, output, result, child_pool, stopped)
"""
graph, child_pool, start_id = self._build_child_graph()
checkpoint = RunnableConfig(configurable={"thread_id": uuid.uuid4()})
init_state = await self._init_iteration_state(item, idx, child_pool, start_id)
if self.stream:
async for event in self.graph.astream(
await self._init_iteration_state(item, idx),
async for event in graph.astream(
init_state,
stream_mode=["debug"],
config=self.checkpoint
config=checkpoint
):
if isinstance(event, tuple) and len(event) == 2:
mode, data = event
@@ -117,7 +166,6 @@ class IterationRuntime:
event_type = data.get("type")
payload = data.get("payload", {})
node_name = payload.get("name")
if node_name and node_name.startswith("nop"):
continue
if event_type == "task_result":
@@ -140,17 +188,13 @@ class IterationRuntime:
"token_usage": result.get("node_outputs", {}).get(node_name, {}).get("token_usage")
}
})
result = self.graph.get_state(config=self.checkpoint).values
result = graph.get_state(config=checkpoint).values
else:
result = await self.graph.ainvoke(await self._init_iteration_state(item, idx))
output = self.child_variable_pool.get_value(self.output_value)
if isinstance(output, list) and self.typed_config.flatten:
self.result.extend(output)
else:
self.result.append(output)
if result["looping"] == 2:
self.looping = False
return result
result = await graph.ainvoke(init_state)
output = child_pool.get_value(self.output_value)
stopped = result["looping"] == 2
return idx, output, result, child_pool, stopped
def _create_iteration_tasks(self, array_obj, idx):
"""
@@ -196,16 +240,32 @@ class IterationRuntime:
tasks = self._create_iteration_tasks(array_obj, idx)
logger.info(f"Iteration node {self.node_id}: running, concurrency {len(tasks)}")
idx += self.typed_config.parallel_count
child_state.extend(await asyncio.gather(*tasks))
self.merge_conv_vars()
batch = await asyncio.gather(*tasks)
# Sort by idx to preserve order, then collect results
batch_sorted = sorted(batch, key=lambda x: x[0])
for _, output, result, child_pool, stopped in batch_sorted:
if isinstance(output, list) and self.typed_config.flatten:
self.result.extend(output)
else:
self.result.append(output)
child_state.append(result)
self._merge_conv_vars(child_pool)
if stopped:
self.looping = False
else:
# Execute iterations sequentially
while idx < len(array_obj) and self.looping:
logger.info(f"Iteration node {self.node_id}: running")
item = array_obj[idx]
result = await self.run_task(item, idx)
self.merge_conv_vars()
_, output, result, child_pool, stopped = await self.run_task(item, idx)
if isinstance(output, list) and self.typed_config.flatten:
self.result.extend(output)
else:
self.result.append(output)
self._merge_conv_vars(child_pool)
child_state.append(result)
if stopped:
self.looping = False
idx += 1
logger.info(f"Iteration node {self.node_id}: execution completed")
return {

View File

@@ -123,7 +123,7 @@ class CycleGraphNode(BaseNode):
return cycle_nodes, cycle_edges
def build_graph(self):
def build_graph(self, variable_pool: VariablePool):
"""
Build and compile the internal subgraph for this cycle node.
@@ -135,6 +135,7 @@ class CycleGraphNode(BaseNode):
from app.core.workflow.engine.graph_builder import GraphBuilder
self.child_variable_pool = VariablePool()
self.child_variable_pool.copy(variable_pool)
builder = GraphBuilder(
{
"nodes": self.cycle_nodes,
@@ -165,8 +166,8 @@ class CycleGraphNode(BaseNode):
Raises:
RuntimeError: If the node type is unsupported.
"""
self.build_graph()
if self.node_type == NodeType.LOOP:
self.build_graph(variable_pool)
return await LoopRuntime(
start_id=self.start_node_id,
stream=False,
@@ -179,20 +180,19 @@ class CycleGraphNode(BaseNode):
).run()
if self.node_type == NodeType.ITERATION:
return await IterationRuntime(
start_id=self.start_node_id,
stream=False,
graph=self.graph,
node_id=self.node_id,
config=self.config,
state=state,
variable_pool=variable_pool,
child_variable_pool=self.child_variable_pool
cycle_nodes=self.cycle_nodes,
cycle_edges=self.cycle_edges,
).run()
raise RuntimeError("Unknown cycle node type")
async def execute_stream(self, state: WorkflowState, variable_pool: VariablePool):
self.build_graph()
if self.node_type == NodeType.LOOP:
self.build_graph(variable_pool)
yield {
"__final__": True,
"result": await LoopRuntime(
@@ -211,14 +211,13 @@ class CycleGraphNode(BaseNode):
yield {
"__final__": True,
"result": await IterationRuntime(
start_id=self.start_node_id,
stream=True,
graph=self.graph,
node_id=self.node_id,
config=self.config,
state=state,
variable_pool=variable_pool,
child_variable_pool=self.child_variable_pool
cycle_nodes=self.cycle_nodes,
cycle_edges=self.cycle_edges,
).run()
}
return

View File

@@ -44,6 +44,8 @@ class FileInput(BaseModel):
upload_file_id: Optional[uuid.UUID] = Field(None, description="已上传文件IDlocal_file时必填")
url: Optional[str] = Field(None, description="远程URLremote_url时必填")
file_type: Optional[str] = Field(None, description="具体文件格式如image/jpg、audio/wav、document/docx、video/mp4")
name: Optional[str] = Field(None, description="文件名")
size: Optional[int] = Field(None, description="文件大小(字节)")
_content = None

View File

@@ -26,6 +26,7 @@ from app.services.model_service import ModelApiKeyService
from app.services.multi_agent_orchestrator import MultiAgentOrchestrator
from app.services.multimodal_service import MultimodalService
from app.services.workflow_service import WorkflowService
from app.models.file_metadata_model import FileMetadata
logger = get_business_logger()
@@ -218,11 +219,29 @@ class AppChatService:
"reasoning_content": result.get("reasoning_content")
}
if files:
local_ids = [f.upload_file_id for f in files
if f.transfer_method.value == "local_file" and f.upload_file_id
and (not f.name or not f.size)]
meta_map = {}
if local_ids:
rows = self.db.query(FileMetadata).filter(
FileMetadata.id.in_(local_ids),
FileMetadata.status == "completed"
).all()
meta_map = {str(r.id): r for r in rows}
for f in files:
# url = await MultimodalService(self.db).get_file_url(f)
name, size = f.name, f.size
if f.transfer_method.value == "local_file" and f.upload_file_id and (not name or not size):
meta = meta_map.get(str(f.upload_file_id))
if meta:
name = name or meta.file_name
size = size or meta.file_size
human_meta["files"].append({
"type": f.type,
"url": f.url
"url": f.url,
"name": name,
"size": size,
"file_type": f.file_type,
})
if processed_files:
@@ -509,10 +528,29 @@ class AppChatService:
}
if files:
local_ids = [f.upload_file_id for f in files
if f.transfer_method.value == "local_file" and f.upload_file_id
and (not f.name or not f.size)]
meta_map = {}
if local_ids:
rows = self.db.query(FileMetadata).filter(
FileMetadata.id.in_(local_ids),
FileMetadata.status == "completed"
).all()
meta_map = {str(r.id): r for r in rows}
for f in files:
name, size = f.name, f.size
if f.transfer_method.value == "local_file" and f.upload_file_id and (not name or not size):
meta = meta_map.get(str(f.upload_file_id))
if meta:
name = name or meta.file_name
size = size or meta.file_size
human_meta["files"].append({
"type": f.type,
"url": f.url
"url": f.url,
"name": name,
"size": size,
"file_type": f.file_type,
})
if processed_files:
human_meta["history_files"] = {

View File

@@ -229,8 +229,11 @@ class AppDslService:
workspace_id: uuid.UUID,
tenant_id: uuid.UUID,
user_id: uuid.UUID,
app_id: Optional[uuid.UUID] = None,
) -> tuple[App, list[str]]:
"""解析 DSL创建应用配置,返回 (new_app, warnings)"""
"""解析 DSL创建或覆盖应用配置,返回 (app, warnings)
app_id 不为空时:校验类型一致后覆盖配置;为空时创建新应用。
"""
app_meta = dsl.get("app", {})
app_type = app_meta.get("type")
if app_type not in (AppType.AGENT, AppType.MULTI_AGENT, AppType.WORKFLOW):
@@ -239,6 +242,9 @@ class AppDslService:
warnings: list[str] = []
now = datetime.datetime.now()
if app_id is not None:
return self._overwrite_dsl(dsl, app_id, app_type, workspace_id, tenant_id, warnings, now)
new_app = App(
id=uuid.uuid4(),
workspace_id=workspace_id,
@@ -258,11 +264,57 @@ class AppDslService:
self.db.add(new_app)
self.db.flush()
self._write_config(new_app.id, app_type, dsl, workspace_id, tenant_id, warnings, now, create=True)
self.db.commit()
self.db.refresh(new_app)
return new_app, warnings
def _overwrite_dsl(
self,
dsl: dict,
app_id: uuid.UUID,
app_type: str,
workspace_id: uuid.UUID,
tenant_id: uuid.UUID,
warnings: list,
now: datetime.datetime,
) -> tuple[App, list[str]]:
"""覆盖已有应用的配置,类型不一致时抛出异常"""
app = self.db.query(App).filter(
App.id == app_id,
App.workspace_id == workspace_id,
App.is_active.is_(True)
).first()
if not app:
raise ResourceNotFoundException("应用", str(app_id))
if app.type != app_type:
raise BusinessException(
f"YAML 类型 '{app_type}' 与应用类型 '{app.type}' 不一致,无法导入",
BizCode.BAD_REQUEST
)
self._write_config(app_id, app_type, dsl, workspace_id, tenant_id, warnings, now, create=False)
self.db.commit()
self.db.refresh(app)
return app, warnings
def _write_config(
self,
app_id: uuid.UUID,
app_type: str,
dsl: dict,
workspace_id: uuid.UUID,
tenant_id: uuid.UUID,
warnings: list,
now: datetime.datetime,
create: bool,
) -> None:
"""写入(新建或覆盖)应用配置"""
if app_type == AppType.AGENT:
cfg = dsl.get("agent_config") or {}
self.db.add(AgentConfig(
id=uuid.uuid4(),
app_id=new_app.id,
fields = dict(
system_prompt=cfg.get("system_prompt"),
model_parameters=cfg.get("model_parameters"),
default_model_config_id=self._resolve_model(cfg.get("default_model_config_ref"), tenant_id, warnings),
@@ -272,16 +324,21 @@ class AppDslService:
tools=self._resolve_tools(cfg.get("tools", []), tenant_id, warnings),
skills=self._resolve_skills(cfg.get("skills", {}), tenant_id, warnings),
features=cfg.get("features", {}),
is_active=True,
created_at=now,
updated_at=now,
))
)
if create:
self.db.add(AgentConfig(id=uuid.uuid4(), app_id=app_id, is_active=True, created_at=now, **fields))
else:
existing = self.db.query(AgentConfig).filter(AgentConfig.app_id == app_id).first()
if existing:
for k, v in fields.items():
setattr(existing, k, v)
else:
self.db.add(AgentConfig(id=uuid.uuid4(), app_id=app_id, is_active=True, created_at=now, **fields))
elif app_type == AppType.MULTI_AGENT:
cfg = dsl.get("multi_agent_config") or {}
self.db.add(MultiAgentConfig(
id=uuid.uuid4(),
app_id=new_app.id,
fields = dict(
orchestration_mode=cfg.get("orchestration_mode", "collaboration"),
master_agent_name=cfg.get("master_agent_name"),
model_parameters=cfg.get("model_parameters"),
@@ -291,10 +348,17 @@ class AppDslService:
routing_rules=self._resolve_routing_rules(cfg.get("routing_rules"), warnings),
execution_config=cfg.get("execution_config", {}),
aggregation_strategy=cfg.get("aggregation_strategy", "merge"),
is_active=True,
created_at=now,
updated_at=now,
))
)
if create:
self.db.add(MultiAgentConfig(id=uuid.uuid4(), app_id=app_id, is_active=True, created_at=now, **fields))
else:
existing = self.db.query(MultiAgentConfig).filter(MultiAgentConfig.app_id == app_id).first()
if existing:
for k, v in fields.items():
setattr(existing, k, v)
else:
self.db.add(MultiAgentConfig(id=uuid.uuid4(), app_id=app_id, is_active=True, created_at=now, **fields))
elif app_type == AppType.WORKFLOW:
adapter = MemoryBearAdapter(dsl)
@@ -306,20 +370,39 @@ class AppDslService:
for w in result.warnings:
warnings.append(f"[节点警告] {w.node_name or w.node_id}: {w.detail}")
wf = dsl.get("workflow") or {}
WorkflowService(self.db).create_workflow_config(
app_id=new_app.id,
nodes=[n.model_dump() for n in result.nodes],
edges=[e.model_dump() for e in result.edges],
variables=[v.model_dump() for v in result.variables],
execution_config=wf.get("execution_config", {}),
features=wf.get("features", {}),
triggers=wf.get("triggers", []),
validate=False,
)
self.db.commit()
self.db.refresh(new_app)
return new_app, warnings
wf_service = WorkflowService(self.db)
if create:
wf_service.create_workflow_config(
app_id=app_id,
nodes=[n.model_dump() for n in result.nodes],
edges=[e.model_dump() for e in result.edges],
variables=[v.model_dump() for v in result.variables],
execution_config=wf.get("execution_config", {}),
features=wf.get("features", {}),
triggers=wf.get("triggers", []),
validate=False,
)
else:
existing = self.db.query(WorkflowConfig).filter(WorkflowConfig.app_id == app_id).first()
if existing:
existing.nodes = [n.model_dump() for n in result.nodes]
existing.edges = [e.model_dump() for e in result.edges]
existing.variables = [v.model_dump() for v in result.variables]
existing.execution_config = wf.get("execution_config", {})
existing.features = wf.get("features", {})
existing.triggers = wf.get("triggers", [])
existing.updated_at = now
else:
wf_service.create_workflow_config(
app_id=app_id,
nodes=[n.model_dump() for n in result.nodes],
edges=[e.model_dump() for e in result.edges],
variables=[v.model_dump() for v in result.variables],
execution_config=wf.get("execution_config", {}),
features=wf.get("features", {}),
triggers=wf.get("triggers", []),
validate=False,
)
def _unique_app_name(self, name: str, workspace_id: uuid.UUID, app_type: AppType) -> str:
"""生成唯一应用名称,同时检查本空间自有应用和共享到本空间的应用"""

View File

@@ -1299,10 +1299,30 @@ class AgentRunService:
"history_files": {}
}
if files:
from app.models.file_metadata_model import FileMetadata
local_ids = [f.upload_file_id for f in files
if f.transfer_method.value == "local_file" and f.upload_file_id
and (not f.name or not f.size)]
meta_map = {}
if local_ids:
rows = self.db.query(FileMetadata).filter(
FileMetadata.id.in_(local_ids),
FileMetadata.status == "completed"
).all()
meta_map = {str(r.id): r for r in rows}
for f in files:
name, size = f.name, f.size
if f.transfer_method.value == "local_file" and f.upload_file_id and (not name or not size):
meta = meta_map.get(str(f.upload_file_id))
if meta:
name = name or meta.file_name
size = size or meta.file_size
human_meta["files"].append({
"type": f.type,
"url": f.url
"url": f.url,
"file_type": f.file_type,
"name": name,
"size": size
})
# 保存 history_files包含 provider 和 is_omni 信息

View File

@@ -957,7 +957,10 @@ class WorkflowService:
for file in message["content"]:
human_meta["files"].append({
"type": file.get("type"),
"url": file.get("url")
"url": file.get("url"),
"file_type": file.get("origin_file_type"),
"name": file.get("name"),
"size": file.get("size")
})
if message["role"] == "assistant":
assistant_message = message["content"]