feat(workflow):

1. add list operator node for filtering, sorting, limiting, and extracting list items;
2. Extend session variables to support the "file" type
This commit is contained in:
Timebomb2018
2026-04-03 18:57:28 +08:00
parent 32740e8159
commit 38f3455bab
27 changed files with 615 additions and 79 deletions

View File

@@ -1079,6 +1079,14 @@ async def update_workflow_config(
current_user: Annotated[User, Depends(get_current_user)]
):
workspace_id = current_user.current_workspace_id
if payload.variables:
from app.services.workflow_service import WorkflowService
resolved = await WorkflowService(db)._resolve_variables_file_defaults(
[v.model_dump() for v in payload.variables]
)
# Patch default values back into VariableDefinition objects
for var_def, resolved_def in zip(payload.variables, resolved):
var_def.default = resolved_def.get("default", var_def.default)
cfg = app_service.update_workflow_config(db, app_id=app_id, data=payload, workspace_id=workspace_id)
return success(data=WorkflowConfigSchema.model_validate(cfg))

View File

@@ -53,22 +53,24 @@ async def login_for_access_token(
user = auth_service.authenticate_user_or_raise(db, form_data.email, form_data.password)
auth_logger.info(f"用户认证成功: {user.email} (ID: {user.id})")
if form_data.invite:
auth_service.bind_workspace_with_invite(db=db,
user=user,
invite_token=form_data.invite,
workspace_id=invite_info.workspace_id)
auth_service.bind_workspace_with_invite(
db=db,
user=user,
invite_token=form_data.invite,
workspace_id=invite_info.workspace_id
)
except BusinessException as e:
# 用户不存在且有邀请码,尝试注册
if e.code == BizCode.USER_NOT_FOUND:
auth_logger.info(f"用户不存在,使用邀请码注册: {form_data.email}")
user = auth_service.register_user_with_invite(
db=db,
email=form_data.email,
username=form_data.username,
password=form_data.password,
invite_token=form_data.invite,
workspace_id=invite_info.workspace_id
)
db=db,
email=form_data.email,
username=form_data.username,
password=form_data.password,
invite_token=form_data.invite,
workspace_id=invite_info.workspace_id
)
elif e.code == BizCode.PASSWORD_ERROR:
# 用户存在但密码错误
auth_logger.warning(f"接受邀请失败,密码验证错误: {form_data.email}")

View File

@@ -475,7 +475,7 @@ class LangChainAgent:
history: Optional[List[Dict[str, str]]] = None,
context: Optional[str] = None,
files: Optional[List[Dict[str, Any]]] = None
) -> AsyncGenerator[str | int, None]:
) -> AsyncGenerator[str | int | dict[str, str], None]:
"""执行流式对话
Args:

View File

@@ -25,8 +25,34 @@ class RedBearEmbeddings(Embeddings):
def _create_model(self, config: RedBearModelConfig) -> Embeddings:
    """Create the LangChain embedding model for the given provider config.

    Embedding models only need connection params, never LLM-specific ones
    (e.g. enable_thinking, model_kwargs), so params are built per provider
    instead of reusing the chat-model factory params where possible.

    NOTE: removed a leftover early `return embedding_class(**model_params)`
    that made the whole provider dispatch below unreachable.
    """
    embedding_class = get_provider_embedding_class(config.provider)
    provider = config.provider.lower()
    if provider in [ModelProvider.OPENAI, ModelProvider.XINFERENCE, ModelProvider.GPUSTACK]:
        import httpx
        params = {
            "model": config.model_name,
            "base_url": config.base_url,
            "api_key": config.api_key,
            # Long read timeout from config, but cap connect at 60s
            "timeout": httpx.Timeout(timeout=config.timeout, connect=60.0),
            "max_retries": config.max_retries,
        }
    elif provider == ModelProvider.DASHSCOPE:
        params = {
            "model": config.model_name,
            "dashscope_api_key": config.api_key,
            "max_retries": config.max_retries,
        }
    elif provider == ModelProvider.OLLAMA:
        params = {
            "model": config.model_name,
            "base_url": config.base_url,
        }
    else:
        # BEDROCK and any unknown provider fall back to the generic factory params
        params = RedBearModelFactory.get_model_params(config)
    return embedding_class(**params)
def _create_volcano_client(self, config: RedBearModelConfig):
"""创建火山引擎客户端"""

View File

@@ -6,14 +6,28 @@ ChatOpenAI 在解析流式 SSE 时只取 delta.content会丢弃 delta.reasoni
"""
from __future__ import annotations
from typing import Any, Optional
from typing import Any, Optional, Union
from langchain_core.outputs import ChatGenerationChunk
from langchain_core.outputs import ChatGenerationChunk, ChatResult
from langchain_openai import ChatOpenAI
class VolcanoChatOpenAI(ChatOpenAI):
"""火山引擎 Chat 模型支持深度思考内容reasoning_content的流式透传。"""
"""火山引擎 Chat 模型支持深度思考内容reasoning_content的流式和非流式透传。"""
def _create_chat_result(self, response: Union[dict, Any], generation_info: Optional[dict] = None) -> ChatResult:
    """Build the ChatResult, then copy reasoning_content from a non-streaming
    response into the message's additional_kwargs (ChatOpenAI drops it)."""
    result = super()._create_chat_result(response, generation_info)
    # Copy reasoning_content from the non-streaming response into additional_kwargs
    choices = response.choices if hasattr(response, "choices") else response.get("choices", [])
    if choices:
        # The message may be an SDK object or a plain dict depending on the client
        message = choices[0].message if hasattr(choices[0], "message") else choices[0].get("message", {})
        reasoning = (
            getattr(message, "reasoning_content", None)
            or (message.get("reasoning_content") if isinstance(message, dict) else None)
        )
        if reasoning and result.generations:
            result.generations[0].message.additional_kwargs["reasoning_content"] = reasoning
    return result
def _convert_chunk_to_generation_chunk(
self,

View File

@@ -27,7 +27,7 @@ class DateTimeTool(BuiltinTool):
type=ParameterType.STRING,
description="操作类型",
required=True,
enum=["format", "convert_timezone", "timestamp_to_datetime", "now"]
enum=["format", "convert_timezone", "timestamp_to_datetime", "now", "datetime_to_timestamp"]
),
ToolParameter(
name="input_value",

View File

@@ -32,13 +32,16 @@ from app.core.workflow.nodes.configs import (
NoteNodeConfig,
ParameterExtractorNodeConfig,
QuestionClassifierNodeConfig,
VariableAggregatorNodeConfig
VariableAggregatorNodeConfig,
ListOperatorNodeConfig,
DocExtractorNodeConfig,
)
from app.core.workflow.nodes.cycle_graph.config import (
ConditionDetail as LoopConditionDetail,
ConditionsConfig,
CycleVariable
)
from app.core.workflow.nodes.list_operator.config import FilterCondition
from app.core.workflow.nodes.enums import (
ValueInputType,
ComparisonOperator,
@@ -90,6 +93,8 @@ class DifyConverter(BaseConverter):
NodeType.VAR_AGGREGATOR: self.convert_variable_aggregator_node_config,
NodeType.TOOL: self.convert_tool_node_config,
NodeType.NOTES: self.convert_notes_config,
NodeType.LIST_OPERATOR: self.convert_list_operator_node_config,
NodeType.DOCUMENT_EXTRACTOR: self.convert_document_extractor_node_config,
NodeType.CYCLE_START: lambda x: {},
NodeType.BREAK: lambda x: {},
}
@@ -213,7 +218,9 @@ class DifyConverter(BaseConverter):
"end with": ComparisonOperator.END_WITH,
"not contains": ComparisonOperator.NOT_CONTAINS,
"exists": ComparisonOperator.NOT_EMPTY,
"not exists": ComparisonOperator.EMPTY
"not exists": ComparisonOperator.EMPTY,
"in": ComparisonOperator.IN,
"not in": ComparisonOperator.NOT_IN,
}
return operator_map.get(operator, operator)
@@ -771,3 +778,46 @@ class DifyConverter(BaseConverter):
show_author=node_data.get("showAuthor", True)
).model_dump()
return result
def convert_list_operator_node_config(self, node: dict) -> dict:
    """Dify list-operator — convert variable path array to {{ }} selector format.

    Args:
        node: Raw Dify node dict with a "data" payload.

    Returns:
        ListOperatorNodeConfig-shaped dict (validated before return).
    """
    node_data = node["data"]
    variable_path = node_data.get("variable", [])
    input_list = self._process_list_variable_literal(variable_path) or ""
    filter_by = node_data.get("filter_by", {"enabled": False, "conditions": []})
    # Convert each condition's comparison_operator from Dify format to native
    if filter_by.get("conditions"):
        converted_conditions = []
        for cond in filter_by["conditions"]:
            converted_conditions.append({
                **cond,
                "comparison_operator": self.convert_compare_operator(
                    cond.get("comparison_operator", "")
                )
            })
        filter_by = {**filter_by, "conditions": converted_conditions}
    result = {
        "input_list": input_list,
        "filter_by": filter_by,
        # Fallback defaults mirror the native config models' defaults
        "order_by": node_data.get("order_by", {"enabled": False, "key": "", "value": "asc"}),
        "limit": node_data.get("limit", {"enabled": False, "size": -1}),
        "extract_by": node_data.get("extract_by", {"enabled": False, "serial": "1"}),
    }
    self.config_validate(node["id"], node["data"]["title"], ListOperatorNodeConfig, result)
    return result
def convert_document_extractor_node_config(self, node: dict) -> dict:
    """Convert Dify document-extractor node to MemoryBear DocExtractorNodeConfig.

    Dify document-extractor data fields:
        variable_selector: list[str] - file variable path
    """
    node_data = node["data"]
    file_selector = self._process_list_variable_literal(
        node_data.get("variable_selector", [])
    ) or ""
    # model_construct skips pydantic validation; config_validate below checks the final dict
    result = DocExtractorNodeConfig.model_construct(
        file_selector=file_selector,
    ).model_dump()
    self.config_validate(node["id"], node["data"]["title"], DocExtractorNodeConfig, result)
    return result

View File

@@ -45,6 +45,8 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter):
"question-classifier": NodeType.QUESTION_CLASSIFIER,
"variable-aggregator": NodeType.VAR_AGGREGATOR,
"tool": NodeType.TOOL,
"list-operator": NodeType.LIST_OPERATOR,
"document-extractor": NodeType.DOCUMENT_EXTRACTOR,
"": NodeType.NOTES
}

View File

@@ -22,6 +22,8 @@ from app.core.workflow.nodes.configs import (
MemoryReadNodeConfig,
MemoryWriteNodeConfig,
NoteNodeConfig,
ListOperatorNodeConfig,
DocExtractorNodeConfig,
)
from app.core.workflow.nodes.enums import NodeType
@@ -51,6 +53,8 @@ class MemoryBearConverter(BaseConverter):
NodeType.MEMORY_READ: MemoryReadNodeConfig,
NodeType.MEMORY_WRITE: MemoryWriteNodeConfig,
NodeType.NOTES: NoteNodeConfig,
NodeType.LIST_OPERATOR: ListOperatorNodeConfig,
NodeType.DOCUMENT_EXTRACTOR: DocExtractorNodeConfig,
}
@staticmethod

View File

@@ -318,7 +318,7 @@ class VariablePool:
namespace: str,
key: str,
value: Any,
var_type: VariableType,
var_type: VariableType | None,
mut: bool
):
if self.has(f"{namespace}.{key}"):
@@ -493,6 +493,23 @@ class VariablePoolInitializer:
var_value = var_default
else:
var_value = DEFAULT_VALUE(var_type)
# Convert FileInput-format dicts to full FileObject dicts
if var_type == VariableType.FILE:
if not var_value:
continue
var_value = await self._resolve_file_default(var_value)
if not var_value:
continue
elif var_type == VariableType.ARRAY_FILE:
if not var_value:
var_value = []
else:
resolved = []
for item in var_value:
f = await self._resolve_file_default(item)
if f:
resolved.append(f)
var_value = resolved
await variable_pool.new(
namespace="conv",
key=var_name,
@@ -501,6 +518,17 @@ class VariablePoolInitializer:
mut=True
)
@staticmethod
async def _resolve_file_default(file_def: dict) -> dict | None:
    """Accept only already-resolved FileObject dicts (marked is_file=True).

    FileInput-format dicts are converted at save time by
    WorkflowService._resolve_variables_file_defaults, so anything without
    the is_file marker (or any non-dict value) is rejected here.
    """
    if isinstance(file_def, dict) and file_def.get("is_file"):
        return file_def
    return None
@staticmethod
async def _init_system_vars(
variable_pool: VariablePool,

View File

@@ -24,6 +24,8 @@ from app.core.workflow.nodes.start.config import StartNodeConfig
from app.core.workflow.nodes.tool.config import ToolNodeConfig
from app.core.workflow.nodes.variable_aggregator.config import VariableAggregatorNodeConfig
from app.core.workflow.nodes.notes.config import NoteNodeConfig
from app.core.workflow.nodes.list_operator.config import ListOperatorNodeConfig
from app.core.workflow.nodes.document_extractor.config import DocExtractorNodeConfig
__all__ = [
# 基础类
@@ -49,5 +51,7 @@ __all__ = [
"MemoryReadNodeConfig",
"MemoryWriteNodeConfig",
"CodeNodeConfig",
"NoteNodeConfig"
"NoteNodeConfig",
"ListOperatorNodeConfig",
"DocExtractorNodeConfig",
]

View File

@@ -14,12 +14,22 @@ logger = logging.getLogger(__name__)
def _file_object_to_file_input(f: FileObject) -> FileInput:
    """Convert workflow FileObject to multimodal FileInput.

    NOTE: removed duplicated `type=`/`file_type=` keyword arguments left over
    from a merge — duplicate keyword args are a SyntaxError in Python.

    Raises:
        ValueError: when the file is not a document-type file; the document
            extractor only handles documents.
    """
    file_type = f.origin_file_type or ""
    # Prefer mime_type for more accurate type detection
    if not file_type and f.mime_type:
        file_type = f.mime_type
    resolved_type = FileType.trans(f.type) if isinstance(f.type, str) else f.type
    if resolved_type != FileType.DOCUMENT:
        raise ValueError(
            f"Document extractor only supports document files, got type '{f.type}' "
            f"(name={f.name or f.file_id or f.url})"
        )
    return FileInput(
        type=resolved_type,
        transfer_method=TransferMethod(f.transfer_method),
        url=f.url or None,
        upload_file_id=f.file_id or None,
        file_type=file_type,
    )
@@ -81,6 +91,7 @@ class DocExtractorNode(BaseNode):
from app.services.multimodal_service import MultimodalService
svc = MultimodalService(db)
for f in files:
label = f.name or f.url or f.file_id
try:
file_input = _file_object_to_file_input(f)
# Ensure URL is populated for local files
@@ -93,7 +104,7 @@ class DocExtractorNode(BaseNode):
chunks.append(text)
except Exception as e:
logger.error(
f"Node {self.node_id}: failed to extract file url={f.url} file_id={f.file_id}: {e}",
f"Node {self.node_id}: failed to extract file '{label}': {e}",
exc_info=True,
)
chunks.append("")

View File

@@ -24,6 +24,7 @@ class NodeType(StrEnum):
MEMORY_READ = "memory-read"
MEMORY_WRITE = "memory-write"
DOCUMENT_EXTRACTOR = "document-extractor"
LIST_OPERATOR = "list-operator"
UNKNOWN = "unknown"
NOTES = "notes"
@@ -45,6 +46,8 @@ class ComparisonOperator(StrEnum):
LE = "le"
GT = "gt"
GE = "ge"
IN = "in"
NOT_IN = "not_in"
class LogicOperator(StrEnum):

View File

@@ -0,0 +1,3 @@
from .node import ListOperatorNode
__all__ = ["ListOperatorNode"]

View File

@@ -0,0 +1,44 @@
from typing import Any
from pydantic import BaseModel, Field
from app.core.workflow.nodes.base_config import BaseNodeConfig
from app.core.workflow.nodes.enums import ComparisonOperator
class FilterCondition(BaseModel):
    """One filter predicate applied to each list item."""
    # Item field to compare for dict items; ignored for primitive list items
    key: str = ""
    comparison_operator: ComparisonOperator = ComparisonOperator.CONTAINS
    # Right-hand value; a string may also be a "{{ ns.key }}" variable selector
    value: str | list[str] | bool = ""
class FilterBy(BaseModel):
    """Filter step config: when enabled, every condition must match (AND)."""
    enabled: bool = False
    conditions: list[FilterCondition] = Field(default_factory=list)
class OrderByConfig(BaseModel):
    """Sort step config."""
    enabled: bool = False
    # Sort key for dict items; primitive items sort by their own value
    key: str = ""
    value: str = "asc"  # "asc" | "desc"
class Limit(BaseModel):
    """Limit step config: keep only the first `size` items when enabled."""
    enabled: bool = False
    size: int = -1  # non-positive means "no limit" even when enabled
class ExtractConfig(BaseModel):
    """Extract step config: pick a single item by 1-based index."""
    enabled: bool = False
    serial: str = "1"  # 1-based index string, e.g. "1" = first; may be a variable selector
class ListOperatorNodeConfig(BaseNodeConfig):
    """
    List Operator node config.

    Operation order: filter -> extract -> order -> limit
    """
    input_list: str = Field(..., description="Variable selector, e.g. {{ sys.files }} or {{ conv.uploaded_files }}")
    filter_by: FilterBy = Field(default_factory=FilterBy)
    order_by: OrderByConfig = Field(default_factory=OrderByConfig)
    limit: Limit = Field(default_factory=Limit)
    extract_by: ExtractConfig = Field(default_factory=ExtractConfig)

View File

@@ -0,0 +1,143 @@
import logging
from typing import Any
from app.core.workflow.engine.state_manager import WorkflowState
from app.core.workflow.engine.variable_pool import VariablePool
from app.core.workflow.nodes.base_node import BaseNode
from app.core.workflow.nodes.enums import ComparisonOperator
from app.core.workflow.nodes.list_operator.config import ListOperatorNodeConfig, FilterCondition
from app.core.workflow.variable.base_variable import VariableType
logger = logging.getLogger(__name__)
# File object fields that hold string values
_FILE_STRING_KEYS = {"name", "extension", "mime_type", "url", "transfer_method", "origin_file_type", "file_id"}
_FILE_NUMBER_KEYS = {"size"}
class ListOperatorNode(BaseNode):
    """Workflow node that filters, extracts, orders and limits a list variable.

    Processing order (matching ListOperatorNodeConfig): filter -> extract -> order -> limit.

    Outputs:
        result: the processed list
        first_record / last_record: first/last item of ``result``, or None when empty
    """

    def __init__(self, node_config: dict, workflow_config: dict, down_stream_nodes: list[str]):
        super().__init__(node_config, workflow_config, down_stream_nodes)
        # Parsed from self.config on every execute() call
        self.typed_config: ListOperatorNodeConfig | None = None

    def _output_types(self) -> dict[str, VariableType]:
        return {
            "result": VariableType.ANY,
            "first_record": VariableType.ANY,
            "last_record": VariableType.ANY,
        }

    async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any:
        self.typed_config = ListOperatorNodeConfig(**self.config)
        cfg = self.typed_config
        # Resolve input variable from path selector
        items: list = self.get_variable(cfg.input_list, variable_pool)
        if not isinstance(items, list):
            raise TypeError(f"Variable '{cfg.input_list}' must be an array, got {type(items)}")
        result = list(items)
        # 1. Filter: conditions are AND-ed (each pass narrows the list)
        if cfg.filter_by.enabled and cfg.filter_by.conditions:
            for condition in cfg.filter_by.conditions:
                result = [item for item in result if self._match_condition(item, condition, variable_pool)]
        # 2. Extract (take single item by 1-based serial index)
        if cfg.extract_by.enabled:
            serial_str = self._resolve_value(cfg.extract_by.serial, variable_pool)
            idx = int(serial_str) - 1
            if idx < 0 or idx >= len(result):
                raise ValueError(f"extract_by.serial={cfg.extract_by.serial} out of range (list length={len(result)})")
            result = [result[idx]]
        # 3. Order
        if cfg.order_by.enabled and cfg.order_by.key:
            reverse = cfg.order_by.value == "desc"
            key_fn = self._make_sort_key(cfg.order_by.key)
            result = sorted(result, key=key_fn, reverse=reverse)
        # 4. Limit (take first N; size <= 0 means unlimited)
        if cfg.limit.enabled and cfg.limit.size > 0:
            result = result[:cfg.limit.size]
        return {
            "result": result,
            "first_record": result[0] if result else None,
            "last_record": result[-1] if result else None,
        }

    @staticmethod
    def _resolve_value(value: str, variable_pool: VariablePool) -> Any:
        """If value is a {{ namespace.key }} variable selector, resolve it from the pool.
        Otherwise return the raw string."""
        import re
        if re.fullmatch(r"\{\{\s*(\w+\.\w+)\s*}}", value.strip()):
            return variable_pool.get_value(value, default=value, strict=False)
        return value

    @staticmethod
    def _make_sort_key(key: str):
        """Sort key: dict items sort by item[key] (missing/None -> ""), primitives by themselves."""
        def key_fn(item):
            if isinstance(item, dict):
                return item.get(key) or ""
            return item
        return key_fn

    def _match_condition(self, item: Any, cond: FilterCondition, variable_pool: VariablePool) -> bool:
        """Return True when *item* satisfies the filter condition.

        Fix: EQ/NE now compare numerically only when BOTH sides parse as numbers,
        otherwise as strings. Previously both sides were coerced with a 0.0
        fallback, so any two non-numeric values compared equal ("abc" == "xyz").
        """
        op = cond.comparison_operator
        value = cond.value
        # Resolve value if it's a variable reference {{ namespace.key }}
        if isinstance(value, str):
            value = self._resolve_value(value, variable_pool)
        # Left side: dict items compare a field, primitives compare the element itself
        if isinstance(item, dict):
            left = item.get(cond.key)
        else:
            left = item
        # Equality operators: numeric when possible, string otherwise
        if op in (ComparisonOperator.EQ, ComparisonOperator.NE):
            ln, rn = self._try_num(left), self._try_num(value)
            if ln is not None and rn is not None:
                equal = ln == rn
            else:
                equal = (str(left) if left is not None else "") == str(value)
            return equal if op == ComparisonOperator.EQ else not equal
        # Ordering operators: keep the lenient numeric coercion (non-numeric -> 0.0)
        if op == ComparisonOperator.LT:
            return self._safe_num(left) < self._safe_num(value)
        if op == ComparisonOperator.LE:
            return self._safe_num(left) <= self._safe_num(value)
        if op == ComparisonOperator.GT:
            return self._safe_num(left) > self._safe_num(value)
        if op == ComparisonOperator.GE:
            return self._safe_num(left) >= self._safe_num(value)
        # String / sequence operators
        left_str = str(left) if left is not None else ""
        if op == ComparisonOperator.CONTAINS:
            return str(value) in left_str
        if op == ComparisonOperator.NOT_CONTAINS:
            return str(value) not in left_str
        if op == ComparisonOperator.START_WITH:
            return left_str.startswith(str(value))
        if op == ComparisonOperator.END_WITH:
            return left_str.endswith(str(value))
        if op == ComparisonOperator.IN:
            return left_str in (value if isinstance(value, list) else [str(value)])
        if op == ComparisonOperator.NOT_IN:
            return left_str not in (value if isinstance(value, list) else [str(value)])
        if op == ComparisonOperator.EMPTY:
            return not left
        if op == ComparisonOperator.NOT_EMPTY:
            return bool(left)
        raise ValueError(f"Unsupported operator: {op}")

    @staticmethod
    def _try_num(v) -> float | None:
        """float(v), or None when v is not numeric."""
        try:
            return float(v)
        except (TypeError, ValueError):
            return None

    @staticmethod
    def _safe_num(v) -> float:
        """float(v) with a 0.0 fallback for non-numeric values."""
        try:
            return float(v)
        except (TypeError, ValueError):
            return 0.0

View File

@@ -135,8 +135,7 @@ class LLMNode(BaseNode):
api_key=model_info.api_key,
base_url=model_info.api_base,
extra_params=extra_params,
is_omni=model_info.is_omni,
support_thinking="thinking" in (model_info.capability or []),
is_omni=model_info.is_omni
),
type=model_info.model_type
)
@@ -214,9 +213,10 @@ class LLMNode(BaseNode):
messages = messages[:-1] + history_message + messages[-1:]
self.messages = messages
else:
# 使用简单的 prompt 格式(向后兼容)
# 使用简单的 prompt 格式(向后兼容)——包装为标准消息列表以兼容所有 provider
prompt_template = self.config.get("prompt", "")
self.messages = self._render_template(prompt_template, variable_pool)
rendered = self._render_template(prompt_template, variable_pool)
self.messages = [{"role": "user", "content": rendered}]
return llm

View File

@@ -27,6 +27,7 @@ from app.core.workflow.nodes.question_classifier import QuestionClassifierNode
from app.core.workflow.nodes.breaker import BreakNode
from app.core.workflow.nodes.tool import ToolNode
from app.core.workflow.nodes.document_extractor import DocExtractorNode
from app.core.workflow.nodes.list_operator import ListOperatorNode
logger = logging.getLogger(__name__)
@@ -51,7 +52,8 @@ WorkflowNode = Union[
MemoryReadNode,
MemoryWriteNode,
CodeNode,
DocExtractorNode
DocExtractorNode,
ListOperatorNode
]
@@ -83,7 +85,8 @@ class NodeFactory:
NodeType.MEMORY_READ: MemoryReadNode,
NodeType.MEMORY_WRITE: MemoryWriteNode,
NodeType.CODE: CodeNode,
NodeType.DOCUMENT_EXTRACTOR: DocExtractorNode
NodeType.DOCUMENT_EXTRACTOR: DocExtractorNode,
NodeType.LIST_OPERATOR: ListOperatorNode
}
@classmethod

View File

@@ -118,8 +118,7 @@ class ParameterExtractorNode(BaseNode):
provider=provider,
api_key=api_key,
base_url=api_base,
is_omni=is_omni,
support_thinking="thinking" in (capability or []),
is_omni=is_omni
),
type=ModelType(model_type)
)

View File

@@ -71,8 +71,7 @@ class QuestionClassifierNode(BaseNode):
provider=provider,
api_key=api_key,
base_url=base_url,
is_omni=is_omni,
support_thinking="thinking" in (capability or []),
is_omni=is_omni
),
type=ModelType(model_type)
)

View File

@@ -1,7 +1,10 @@
# -*- coding: UTF-8 -*-
# Author: Eternity
# @Email: 1533512157@qq.com
# @Time : 2026/3/10 13:36
import mimetypes
import os
import uuid
from typing import Any
from urllib.parse import urlparse, unquote
TRANSFORM_FILE_TYPE = {
'text/plain': 'document/text',
'text/markdown': 'document/markdown',
@@ -52,5 +55,143 @@ ALLOWED_FILE_TYPES = [
def mime_to_file_type(mime_type):
    """Map an allowed MIME type to its internal file-type string; None when disallowed."""
    if mime_type in ALLOWED_FILE_TYPES:
        return TRANSFORM_FILE_TYPE.get(mime_type, mime_type)
    return None
def build_file_object_dict_from_url(url: str, file_type: str, origin_file_type: str) -> dict[str, Any]:
    """Build a FileObject dict for a remote_url file using only URL parsing (no HTTP request).

    Used as fallback when the HTTP metadata request fails.

    Args:
        url: Remote file URL (may carry a query string).
        file_type: Workflow-level file type (e.g. "document").
        origin_file_type: Declared MIME-ish type; used when MIME guessing fails.

    Returns:
        FileObject-shaped dict with is_file=True; size and file_id are unknown.
    """
    # Display name from the URL path (urlparse already excludes the query string)
    name = unquote(os.path.basename(urlparse(url).path)) or None
    _, ext = os.path.splitext(name or "")
    extension = ext.lstrip(".").lower() if ext else None
    # Guess from the bare file name when available: guessing from the full URL
    # fails whenever a query string follows the extension (e.g. "...file.pdf?x=1").
    guessed_mime = mimetypes.guess_type(name or url)[0]
    return {
        "type": file_type,
        "url": url,
        "transfer_method": "remote_url",
        "origin_file_type": origin_file_type,
        "file_id": None,
        "name": name,
        "size": None,
        "extension": extension,
        "mime_type": guessed_mime or origin_file_type,
        "is_file": True,
    }
async def fetch_remote_file_meta(
    url: str,
    file_type: str,
    origin_file_type: str,
) -> dict[str, Any]:
    """Fetch remote file metadata via HEAD (fallback GET) and build a FileObject dict.

    Falls back to URL-only parsing if the HTTP request fails.

    Args:
        url: Remote file URL.
        file_type: Workflow-level file type (e.g. "document").
        origin_file_type: Declared MIME-ish type used as Content-Type fallback.
    """
    import httpx
    name = size = mime_type = extension = None
    try:
        async with httpx.AsyncClient(timeout=10.0) as client:
            resp = await client.head(url, follow_redirects=True)
            if resp.status_code != 200:
                # Some servers reject HEAD (e.g. 405); retry with GET
                resp = await client.get(url, follow_redirects=True)
            cl = resp.headers.get("Content-Length")
            size = int(cl) if cl else None
            ct = resp.headers.get("Content-Type", "").split(";")[0].strip()
            mime_type = ct or origin_file_type
            cd = resp.headers.get("Content-Disposition", "")
            if "filename=" in cd:
                # Take only the filename token and strip quotes; the previous parse
                # kept trailing parameters (e.g. '"a.pdf"; size=3' -> 'a.pdf"; size=3')
                name = cd.split("filename=")[-1].split(";")[0].strip().strip('"').strip("'")
            if not name:
                name = unquote(os.path.basename(urlparse(url).path)) or None
            if name:
                _, ext = os.path.splitext(name)
                extension = ext.lstrip(".").lower() if ext else None
            if not extension and mime_type:
                ext = mimetypes.guess_extension(mime_type)
                extension = ext.lstrip(".").lower() if ext else None
    except Exception:
        # Network/HTTP failure: degrade to URL-only metadata
        return build_file_object_dict_from_url(url, file_type, origin_file_type)
    return build_file_object_dict_from_meta(
        file_type=file_type,
        transfer_method="remote_url",
        origin_file_type=origin_file_type,
        file_id=None,
        url=url,
        file_name=name,
        file_size=size,
        file_ext=extension,
        content_type=mime_type,
    )
def build_file_object_dict_from_meta(
file_type: str,
transfer_method: str,
origin_file_type: str,
file_id: str,
url: str,
file_name: str | None,
file_size: int | None,
file_ext: str | None,
content_type: str | None,
) -> dict[str, Any]:
"""Build a FileObject dict from already-fetched FileMetadata fields."""
ext = (file_ext or "").lstrip(".")
return {
"type": file_type,
"url": url,
"transfer_method": transfer_method,
"origin_file_type": content_type or origin_file_type,
"file_id": file_id,
"name": file_name,
"size": file_size,
"extension": ext.lower() if ext else None,
"mime_type": content_type,
"is_file": True,
}
def resolve_local_file_object_dict(
    db,
    upload_file_id: str | uuid.UUID,
    file_type: str,
    origin_file_type: str,
) -> dict[str, Any] | None:
    """Look up FileMetadata for a local_file and build its FileObject dict.

    Returns None when the id is not a valid UUID, or when the file record
    is missing or not yet in "completed" status.
    """
    from app.models.file_metadata_model import FileMetadata
    from app.core.config import settings

    try:
        file_uuid = uuid.UUID(str(upload_file_id))
    except ValueError:
        # Not a parseable UUID -- treat the same as "file not found"
        return None

    record = (
        db.query(FileMetadata)
        .filter(FileMetadata.id == file_uuid, FileMetadata.status == "completed")
        .first()
    )
    if record is None:
        return None

    return build_file_object_dict_from_meta(
        file_type=file_type,
        transfer_method="local_file",
        origin_file_type=origin_file_type,
        file_id=str(file_uuid),
        url=f"{settings.FILE_LOCAL_SERVER_URL}/storage/permanent/{file_uuid}",
        file_name=record.file_name,
        file_size=record.file_size,
        file_ext=record.file_ext,
        content_type=record.content_type,
    )

View File

@@ -301,7 +301,7 @@ class WorkflowValidator:
for node in nodes:
if node.get("type") not in [NodeType.START, NodeType.CYCLE_START, NodeType.END] and not node.get("name"):
errors.append(
f"节点 {node.get('id')} 缺少名称(发布时必须提供)"
f"节点 {node.get('name')} 缺少名称(发布时必须提供)"
)
# 2. 验证所有非 start/end 节点都有配置
@@ -311,7 +311,7 @@ class WorkflowValidator:
config = node.get("config")
if not config or not isinstance(config, dict):
errors.append(
f"节点 {node.get('id')} 缺少配置(发布时必须提供)"
f"节点 {node.get('name')} 缺少配置(发布时必须提供)"
)
# 3. 验证必填变量

View File

@@ -91,7 +91,7 @@ def DEFAULT_VALUE(var_type: VariableType) -> Any:
case VariableType.OBJECT:
return {}
case VariableType.FILE:
return None
return {}
case VariableType.ARRAY_STRING:
return []
case VariableType.ARRAY_NUMBER:
@@ -113,6 +113,12 @@ class FileObject(BaseModel):
origin_file_type: str
file_id: str | None
# Extended file metadata
name: str | None = None
size: int | None = None
extension: str | None = None
mime_type: str | None = None
content_cache: dict = Field(default_factory=dict)
is_file: bool

View File

@@ -66,20 +66,10 @@ class FileVariable(BaseVariable):
type = 'file'
def valid_value(self, value) -> FileObject:
    """Coerce *value* into a FileObject.

    Accepts a FileObject instance, or a dict of FileObject fields
    (validation is delegated to the FileObject model itself).

    NOTE: removed a leftover pre-refactor body (strict is_file check +
    field-subset construction) that made the new `FileObject(**value)`
    path unreachable.

    Raises:
        TypeError: for any other value type.
    """
    if isinstance(value, dict):
        return FileObject(**value)
    if isinstance(value, FileObject):
        return value
    raise TypeError(f"Value must be a FileObject - {type(value)}:{value}")
@@ -88,7 +78,7 @@ class FileVariable(BaseVariable):
return f'{"!"if self.value.type == FileType.IMAGE else ""}[file]({self.value.url})'
def get_value(self) -> Any:
    """Return the file as a plain dict, excluding the runtime-only content_cache.

    NOTE: removed a duplicated leftover `return self.value.model_dump()` that
    made the exclude-variant unreachable (merge artifact).
    """
    return self.value.model_dump(exclude={"content_cache"})
async def get_content(self):
total_bytes = 0
@@ -186,6 +176,8 @@ def create_variable_instance(var_type: VariableType, value: Any) -> T:
return BooleanVariable(value)
case VariableType.OBJECT:
return DictVariable(value)
case VariableType.FILE:
return FileVariable(value)
case VariableType.ARRAY_STRING:
return make_array(StringVariable, value)
case VariableType.ARRAY_NUMBER:

View File

@@ -640,6 +640,7 @@ class CitationSource(BaseModel):
class DraftRunResponse(BaseModel):
"""试运行响应(非流式)"""
message: str = Field(..., description="AI 回复消息")
reasoning_content: Optional[str] = Field(default=None, description="深度思考内容")
conversation_id: Optional[str] = Field(default=None, description="会话ID用于多轮对话")
usage: Optional[Dict[str, Any]] = Field(default=None, description="Token 使用情况")
elapsed_time: Optional[float] = Field(default=None, description="耗时(秒)")
@@ -647,6 +648,12 @@ class DraftRunResponse(BaseModel):
citations: List[CitationSource] = Field(default_factory=list, description="引用来源")
audio_url: Optional[str] = Field(default=None, description="TTS 语音URL")
def model_dump(self, **kwargs):
    """Serialize the response, dropping reasoning_content when empty or None."""
    data = super().model_dump(**kwargs)
    # Omit the key entirely rather than sending null/"" to API clients
    if not data.get("reasoning_content"):
        data.pop("reasoning_content", None)
    return data
class OpeningResponse(BaseModel):
"""应用开场白响应"""

View File

@@ -16,7 +16,6 @@ from app.core.workflow.adapters.registry import PlatformAdapterRegistry
from app.core.workflow.executor import execute_workflow, execute_workflow_stream
from app.core.workflow.nodes.enums import NodeType
from app.core.workflow.validator import validate_workflow_config
from app.core.workflow.variable.base_variable import FileObject
from app.db import get_db
from app.models import App
from app.models.workflow_model import WorkflowConfig, WorkflowExecution
@@ -453,22 +452,70 @@ class WorkflowService:
"success_rate": completed / total if total > 0 else 0
}
async def _resolve_variables_file_defaults(
    self,
    variables: list[dict[str, Any]]
) -> list[dict[str, Any]]:
    """Convert FileInput-format defaults in workflow variables to full FileObject dicts."""
    from app.core.workflow.utils.file_processor import (
        resolve_local_file_object_dict,
        fetch_remote_file_meta,
    )

    async def _resolve_one(item: dict) -> dict | None:
        # Already-resolved FileObject dicts (and non-dict values) pass through untouched
        if not isinstance(item, dict) or item.get("is_file"):
            return item
        transfer_method = item.get("transfer_method", "remote_url")
        file_type = item.get("type", "document")
        origin_file_type = item.get("file_type") or file_type
        if transfer_method != "remote_url":
            return resolve_local_file_object_dict(
                self.db, item.get("upload_file_id"), file_type, origin_file_type
            )
        url = item.get("url", "")
        if not url:
            return None
        return await fetch_remote_file_meta(url, file_type, origin_file_type)

    resolved_vars = []
    for var_def in variables:
        var_type = var_def.get("type", "")
        default = var_def.get("default")
        if var_type == "file" and isinstance(default, dict) and not default.get("is_file"):
            var_def = {**var_def, "default": await _resolve_one(default)}
        elif var_type == "array[file]" and isinstance(default, list):
            converted = [await _resolve_one(item) for item in default]
            var_def = {**var_def, "default": [c for c in converted if c is not None]}
        resolved_vars.append(var_def)
    return resolved_vars
async def _handle_file_input(self, files: list[FileInput]):
    """Normalize incoming FileInput items into FileObject dicts.

    local_file inputs are resolved from the FileMetadata table (falling back
    to a minimal record when the DB lookup fails); remote files have their
    metadata fetched over HTTP.

    NOTE: removed a leftover pre-refactor append block (merge artifact) that
    appended a second, partial entry per file and referenced `FileObject`,
    whose import this change removed.
    """
    if not files:
        return []
    from app.core.workflow.utils.file_processor import (
        resolve_local_file_object_dict,
        build_file_object_dict_from_meta,
        fetch_remote_file_meta,
    )
    files_struct = []
    for file in files:
        url = await self.multimodal_service.get_file_url(file)
        file_type = str(file.type)
        origin_file_type = file.file_type or file_type
        if file.transfer_method.value == "local_file" and file.upload_file_id:
            # Prefer full DB metadata; fall back to a minimal record so the
            # file is never silently dropped
            fo = resolve_local_file_object_dict(self.db, file.upload_file_id, file_type, origin_file_type)
            files_struct.append(fo or build_file_object_dict_from_meta(
                file_type=file_type, transfer_method="local_file",
                origin_file_type=origin_file_type,
                file_id=str(file.upload_file_id), url=url,
                file_name=None, file_size=None, file_ext=None, content_type=None,
            ))
        else:
            files_struct.append(await fetch_remote_file_meta(url, file_type, origin_file_type))
    return files_struct
@staticmethod

View File

@@ -480,21 +480,21 @@ def create_workspace_invite(
try:
# 检查权限
_check_workspace_admin_permission(db, workspace_id, user)
if settings.ENABLE_SINGLE_WORKSPACE:
# 检查被邀请用户是否已经在工作空间中
from app.repositories import user_repository
invited_user = user_repository.get_user_by_email(db, invite_data.email)
# if settings.ENABLE_SINGLE_WORKSPACE:
# 检查被邀请用户是否已经在工作空间中
from app.repositories import user_repository
invited_user = user_repository.get_user_by_email(db, invite_data.email)
if invited_user:
# 用户存在,检查是否已经是工作空间成员
existing_member = workspace_repository.get_member_in_workspace(
db=db,
user_id=invited_user.id,
workspace_id=workspace_id
)
if existing_member:
business_logger.warning(f"用户 {invite_data.email} 已经是工作空间成员")
raise BusinessException("该用户已经是工作空间成员", BizCode.RESOURCE_ALREADY_EXISTS)
if invited_user:
# 用户存在,检查是否已经是工作空间成员
existing_member = workspace_repository.get_member_in_workspace(
db=db,
user_id=invited_user.id,
workspace_id=workspace_id
)
if existing_member:
business_logger.warning(f"用户 {invite_data.email} 已经是工作空间成员")
raise BusinessException("该用户已经是工作空间成员", BizCode.RESOURCE_ALREADY_EXISTS)
# 检查是否已有待处理的邀请
invite_repo = WorkspaceInviteRepository(db)