feat(workflow):
1. Add a list operator node for filtering, sorting, limiting, and extracting list items; 2. Add "file" type support for session variables
This commit is contained in:
@@ -32,13 +32,16 @@ from app.core.workflow.nodes.configs import (
|
||||
NoteNodeConfig,
|
||||
ParameterExtractorNodeConfig,
|
||||
QuestionClassifierNodeConfig,
|
||||
VariableAggregatorNodeConfig
|
||||
VariableAggregatorNodeConfig,
|
||||
ListOperatorNodeConfig,
|
||||
DocExtractorNodeConfig,
|
||||
)
|
||||
from app.core.workflow.nodes.cycle_graph.config import (
|
||||
ConditionDetail as LoopConditionDetail,
|
||||
ConditionsConfig,
|
||||
CycleVariable
|
||||
)
|
||||
from app.core.workflow.nodes.list_operator.config import FilterCondition
|
||||
from app.core.workflow.nodes.enums import (
|
||||
ValueInputType,
|
||||
ComparisonOperator,
|
||||
@@ -90,6 +93,8 @@ class DifyConverter(BaseConverter):
|
||||
NodeType.VAR_AGGREGATOR: self.convert_variable_aggregator_node_config,
|
||||
NodeType.TOOL: self.convert_tool_node_config,
|
||||
NodeType.NOTES: self.convert_notes_config,
|
||||
NodeType.LIST_OPERATOR: self.convert_list_operator_node_config,
|
||||
NodeType.DOCUMENT_EXTRACTOR: self.convert_document_extractor_node_config,
|
||||
NodeType.CYCLE_START: lambda x: {},
|
||||
NodeType.BREAK: lambda x: {},
|
||||
}
|
||||
@@ -213,7 +218,9 @@ class DifyConverter(BaseConverter):
|
||||
"end with": ComparisonOperator.END_WITH,
|
||||
"not contains": ComparisonOperator.NOT_CONTAINS,
|
||||
"exists": ComparisonOperator.NOT_EMPTY,
|
||||
"not exists": ComparisonOperator.EMPTY
|
||||
"not exists": ComparisonOperator.EMPTY,
|
||||
"in": ComparisonOperator.IN,
|
||||
"not in": ComparisonOperator.NOT_IN,
|
||||
}
|
||||
return operator_map.get(operator, operator)
|
||||
|
||||
@@ -771,3 +778,46 @@ class DifyConverter(BaseConverter):
|
||||
show_author=node_data.get("showAuthor", True)
|
||||
).model_dump()
|
||||
return result
|
||||
|
||||
def convert_list_operator_node_config(self, node: dict) -> dict:
    """Translate a Dify ``list-operator`` node into a native list-operator config.

    The Dify variable path array is turned into a selector string via
    ``_process_list_variable_literal`` and every filter condition's
    ``comparison_operator`` is mapped through ``convert_compare_operator``.
    The resulting dict is passed to ``config_validate`` against
    ``ListOperatorNodeConfig`` before being returned.

    Args:
        node: Raw Dify node dict; ``node["data"]`` holds the node settings.

    Returns:
        Dict with ``input_list``, ``filter_by``, ``order_by``, ``limit`` and
        ``extract_by`` entries.
    """
    data = node["data"]
    selector = self._process_list_variable_literal(data.get("variable", [])) or ""

    filter_cfg = data.get("filter_by", {"enabled": False, "conditions": []})
    # Only rewrite the conditions when there is at least one to translate.
    if filter_cfg.get("conditions"):
        native_conditions = [
            {
                **condition,
                "comparison_operator": self.convert_compare_operator(
                    condition.get("comparison_operator", "")
                ),
            }
            for condition in filter_cfg["conditions"]
        ]
        # Copy instead of mutating the imported node data in place.
        filter_cfg = {**filter_cfg, "conditions": native_conditions}

    order_cfg = data.get("order_by", {"enabled": False, "key": "", "value": "asc"})
    limit_cfg = data.get("limit", {"enabled": False, "size": -1})
    extract_cfg = data.get("extract_by", {"enabled": False, "serial": "1"})

    result = {
        "input_list": selector,
        "filter_by": filter_cfg,
        "order_by": order_cfg,
        "limit": limit_cfg,
        "extract_by": extract_cfg,
    }
    self.config_validate(node["id"], node["data"]["title"], ListOperatorNodeConfig, result)
    return result
|
||||
|
||||
def convert_document_extractor_node_config(self, node: dict) -> dict:
    """Convert a Dify document-extractor node to a DocExtractorNodeConfig dict.

    Dify document-extractor data fields:
        variable_selector: list[str] - file variable path

    Args:
        node: Raw Dify node dict; ``node["data"]`` holds the node settings.

    Returns:
        The dumped ``DocExtractorNodeConfig`` after ``config_validate`` has
        been called on it.
    """
    variable_path = node["data"].get("variable_selector", [])
    selector = self._process_list_variable_literal(variable_path)
    if not selector:
        # Fall back to an empty selector string when no path could be produced.
        selector = ""

    config = DocExtractorNodeConfig.model_construct(file_selector=selector)
    result = config.model_dump()
    self.config_validate(node["id"], node["data"]["title"], DocExtractorNodeConfig, result)
    return result
|
||||
|
||||
@@ -45,6 +45,8 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter):
|
||||
"question-classifier": NodeType.QUESTION_CLASSIFIER,
|
||||
"variable-aggregator": NodeType.VAR_AGGREGATOR,
|
||||
"tool": NodeType.TOOL,
|
||||
"list-operator": NodeType.LIST_OPERATOR,
|
||||
"document-extractor": NodeType.DOCUMENT_EXTRACTOR,
|
||||
"": NodeType.NOTES
|
||||
}
|
||||
|
||||
|
||||
@@ -22,6 +22,8 @@ from app.core.workflow.nodes.configs import (
|
||||
MemoryReadNodeConfig,
|
||||
MemoryWriteNodeConfig,
|
||||
NoteNodeConfig,
|
||||
ListOperatorNodeConfig,
|
||||
DocExtractorNodeConfig,
|
||||
)
|
||||
from app.core.workflow.nodes.enums import NodeType
|
||||
|
||||
@@ -51,6 +53,8 @@ class MemoryBearConverter(BaseConverter):
|
||||
NodeType.MEMORY_READ: MemoryReadNodeConfig,
|
||||
NodeType.MEMORY_WRITE: MemoryWriteNodeConfig,
|
||||
NodeType.NOTES: NoteNodeConfig,
|
||||
NodeType.LIST_OPERATOR: ListOperatorNodeConfig,
|
||||
NodeType.DOCUMENT_EXTRACTOR: DocExtractorNodeConfig,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
|
||||
@@ -318,7 +318,7 @@ class VariablePool:
|
||||
namespace: str,
|
||||
key: str,
|
||||
value: Any,
|
||||
var_type: VariableType,
|
||||
var_type: VariableType | None,
|
||||
mut: bool
|
||||
):
|
||||
if self.has(f"{namespace}.{key}"):
|
||||
@@ -493,6 +493,23 @@ class VariablePoolInitializer:
|
||||
var_value = var_default
|
||||
else:
|
||||
var_value = DEFAULT_VALUE(var_type)
|
||||
# Convert FileInput-format dicts to full FileObject dicts
|
||||
if var_type == VariableType.FILE:
|
||||
if not var_value:
|
||||
continue
|
||||
var_value = await self._resolve_file_default(var_value)
|
||||
if not var_value:
|
||||
continue
|
||||
elif var_type == VariableType.ARRAY_FILE:
|
||||
if not var_value:
|
||||
var_value = []
|
||||
else:
|
||||
resolved = []
|
||||
for item in var_value:
|
||||
f = await self._resolve_file_default(item)
|
||||
if f:
|
||||
resolved.append(f)
|
||||
var_value = resolved
|
||||
await variable_pool.new(
|
||||
namespace="conv",
|
||||
key=var_name,
|
||||
@@ -501,6 +518,17 @@ class VariablePoolInitializer:
|
||||
mut=True
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def _resolve_file_default(file_def: dict) -> dict | None:
|
||||
"""Accept only already-resolved FileObject dicts (is_file=True).
|
||||
FileInput-format dicts are converted at save time by WorkflowService._resolve_variables_file_defaults.
|
||||
"""
|
||||
if not isinstance(file_def, dict):
|
||||
return None
|
||||
if file_def.get("is_file"):
|
||||
return file_def
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
async def _init_system_vars(
|
||||
variable_pool: VariablePool,
|
||||
|
||||
@@ -24,6 +24,8 @@ from app.core.workflow.nodes.start.config import StartNodeConfig
|
||||
from app.core.workflow.nodes.tool.config import ToolNodeConfig
|
||||
from app.core.workflow.nodes.variable_aggregator.config import VariableAggregatorNodeConfig
|
||||
from app.core.workflow.nodes.notes.config import NoteNodeConfig
|
||||
from app.core.workflow.nodes.list_operator.config import ListOperatorNodeConfig
|
||||
from app.core.workflow.nodes.document_extractor.config import DocExtractorNodeConfig
|
||||
|
||||
__all__ = [
|
||||
# 基础类
|
||||
@@ -49,5 +51,7 @@ __all__ = [
|
||||
"MemoryReadNodeConfig",
|
||||
"MemoryWriteNodeConfig",
|
||||
"CodeNodeConfig",
|
||||
"NoteNodeConfig"
|
||||
"NoteNodeConfig",
|
||||
"ListOperatorNodeConfig",
|
||||
"DocExtractorNodeConfig",
|
||||
]
|
||||
|
||||
@@ -14,12 +14,22 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
def _file_object_to_file_input(f: FileObject) -> FileInput:
    """Convert a workflow FileObject into a multimodal FileInput.

    Args:
        f: File object taken from the workflow variable pool.

    Returns:
        A ``FileInput`` carrying the file's type, transfer method, URL,
        upload id and detected file type string.

    Raises:
        ValueError: If the file's resolved type is not ``FileType.DOCUMENT``.
    """
    # origin_file_type wins; mime_type is only a fallback when it is empty.
    file_type = f.origin_file_type or f.mime_type or ""

    # ``f.type`` may arrive as a plain string; normalise it before comparing.
    resolved_type = FileType.trans(f.type) if isinstance(f.type, str) else f.type
    if resolved_type != FileType.DOCUMENT:
        raise ValueError(
            f"Document extractor only supports document files, got type '{f.type}' "
            f"(name={f.name or f.file_id or f.url})"
        )

    # Each keyword is passed exactly once (a repeated keyword is a SyntaxError).
    return FileInput(
        type=resolved_type,
        transfer_method=TransferMethod(f.transfer_method),
        url=f.url or None,
        upload_file_id=f.file_id or None,
        file_type=file_type,
    )
|
||||
|
||||
|
||||
@@ -81,6 +91,7 @@ class DocExtractorNode(BaseNode):
|
||||
from app.services.multimodal_service import MultimodalService
|
||||
svc = MultimodalService(db)
|
||||
for f in files:
|
||||
label = f.name or f.url or f.file_id
|
||||
try:
|
||||
file_input = _file_object_to_file_input(f)
|
||||
# Ensure URL is populated for local files
|
||||
@@ -93,7 +104,7 @@ class DocExtractorNode(BaseNode):
|
||||
chunks.append(text)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Node {self.node_id}: failed to extract file url={f.url} file_id={f.file_id}: {e}",
|
||||
f"Node {self.node_id}: failed to extract file '{label}': {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
chunks.append("")
|
||||
|
||||
@@ -24,6 +24,7 @@ class NodeType(StrEnum):
|
||||
MEMORY_READ = "memory-read"
|
||||
MEMORY_WRITE = "memory-write"
|
||||
DOCUMENT_EXTRACTOR = "document-extractor"
|
||||
LIST_OPERATOR = "list-operator"
|
||||
|
||||
UNKNOWN = "unknown"
|
||||
NOTES = "notes"
|
||||
@@ -45,6 +46,8 @@ class ComparisonOperator(StrEnum):
|
||||
LE = "le"
|
||||
GT = "gt"
|
||||
GE = "ge"
|
||||
IN = "in"
|
||||
NOT_IN = "not_in"
|
||||
|
||||
|
||||
class LogicOperator(StrEnum):
|
||||
|
||||
3
api/app/core/workflow/nodes/list_operator/__init__.py
Normal file
3
api/app/core/workflow/nodes/list_operator/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .node import ListOperatorNode
|
||||
|
||||
__all__ = ["ListOperatorNode"]
|
||||
44
api/app/core/workflow/nodes/list_operator/config.py
Normal file
44
api/app/core/workflow/nodes/list_operator/config.py
Normal file
@@ -0,0 +1,44 @@
|
||||
from typing import Any
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.core.workflow.nodes.base_config import BaseNodeConfig
|
||||
from app.core.workflow.nodes.enums import ComparisonOperator
|
||||
|
||||
|
||||
class FilterCondition(BaseModel):
    """A single filter predicate: ``<key> <comparison_operator> <value>``."""

    # Name of the item field to test.
    key: str = Field(default="")
    # Comparison to apply; defaults to a substring ("contains") check.
    comparison_operator: ComparisonOperator = Field(default=ComparisonOperator.CONTAINS)
    # Right-hand operand of the comparison.
    value: str | list[str] | bool = Field(default="")
|
||||
|
||||
|
||||
class FilterBy(BaseModel):
    """Filter stage settings: an on/off switch plus its list of conditions."""

    enabled: bool = Field(default=False)
    # default_factory gives every instance its own empty list.
    conditions: list[FilterCondition] = Field(default_factory=list)
|
||||
|
||||
|
||||
class OrderByConfig(BaseModel):
    """Sort stage settings."""

    enabled: bool = Field(default=False)
    # Name of the item field to sort by.
    key: str = Field(default="")
    # Sort direction: "asc" | "desc"
    value: str = Field(default="asc")
|
||||
|
||||
|
||||
class Limit(BaseModel):
    """Limit stage settings."""

    enabled: bool = Field(default=False)
    # Maximum number of items to keep; the default is -1.
    size: int = Field(default=-1)
|
||||
|
||||
|
||||
class ExtractConfig(BaseModel):
    """Extract stage settings: pick one item by position."""

    enabled: bool = Field(default=False)
    # 1-based index string, e.g. "1" = first
    serial: str = Field(default="1")
|
||||
|
||||
|
||||
class ListOperatorNodeConfig(BaseNodeConfig):
    """Configuration for the List Operator node.

    Operation order: filter -> extract -> order -> limit
    """

    # Required: selector of the list variable the node operates on.
    input_list: str = Field(
        ...,
        description="Variable selector, e.g. {{ sys.files }} or {{ conv.uploaded_files }}",
    )
    # Each stage defaults to its own freshly built settings model.
    filter_by: FilterBy = Field(default_factory=FilterBy)
    order_by: OrderByConfig = Field(default_factory=OrderByConfig)
    limit: Limit = Field(default_factory=Limit)
    extract_by: ExtractConfig = Field(default_factory=ExtractConfig)
|
||||
143
api/app/core/workflow/nodes/list_operator/node.py
Normal file
143
api/app/core/workflow/nodes/list_operator/node.py
Normal file
@@ -0,0 +1,143 @@
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from app.core.workflow.engine.state_manager import WorkflowState
|
||||
from app.core.workflow.engine.variable_pool import VariablePool
|
||||
from app.core.workflow.nodes.base_node import BaseNode
|
||||
from app.core.workflow.nodes.enums import ComparisonOperator
|
||||
from app.core.workflow.nodes.list_operator.config import ListOperatorNodeConfig, FilterCondition
|
||||
from app.core.workflow.variable.base_variable import VariableType
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# File object fields that hold string values
|
||||
_FILE_STRING_KEYS = {"name", "extension", "mime_type", "url", "transfer_method", "origin_file_type", "file_id"}
|
||||
_FILE_NUMBER_KEYS = {"size"}
|
||||
|
||||
|
||||
class ListOperatorNode(BaseNode):
    """Workflow node that filters, extracts, sorts and truncates a list variable.

    Stages run in the order filter -> extract -> order -> limit. The node
    outputs ``result`` (the processed list) plus ``first_record`` and
    ``last_record`` (``None`` when the result is empty).
    """

    def __init__(self, node_config: dict, workflow_config: dict, down_stream_nodes: list[str]):
        super().__init__(node_config, workflow_config, down_stream_nodes)
        # Rebuilt from ``self.config`` on every ``execute`` call.
        self.typed_config: ListOperatorNodeConfig | None = None

    def _output_types(self) -> dict[str, VariableType]:
        """Declare the node outputs; all are typed ANY."""
        return {
            "result": VariableType.ANY,
            "first_record": VariableType.ANY,
            "last_record": VariableType.ANY,
        }

    async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any:
        """Run the configured list pipeline on the input variable.

        Returns:
            Dict with ``result``, ``first_record`` and ``last_record``.

        Raises:
            TypeError: If the input variable is not a list.
            ValueError: If ``extract_by.serial`` is not an integer or is out
                of range for the (filtered) list.
        """
        self.typed_config = ListOperatorNodeConfig(**self.config)
        cfg = self.typed_config

        # Resolve input variable from path selector
        items: list = self.get_variable(cfg.input_list, variable_pool)
        if not isinstance(items, list):
            raise TypeError(f"Variable '{cfg.input_list}' must be an array, got {type(items)}")

        # Work on a copy so the pooled list is never mutated.
        result = list(items)

        # 1. Filter: an item must satisfy every condition to be kept.
        if cfg.filter_by.enabled and cfg.filter_by.conditions:
            for condition in cfg.filter_by.conditions:
                result = [item for item in result if self._match_condition(item, condition, variable_pool)]

        # 2. Extract (take single item by 1-based serial index)
        if cfg.extract_by.enabled:
            serial = self._resolve_value(cfg.extract_by.serial, variable_pool)
            try:
                idx = int(serial) - 1
            except (TypeError, ValueError) as exc:
                # Name the offending setting instead of surfacing a bare int() error.
                raise ValueError(
                    f"extract_by.serial={cfg.extract_by.serial} is not a valid integer"
                ) from exc
            if idx < 0 or idx >= len(result):
                raise ValueError(f"extract_by.serial={cfg.extract_by.serial} out of range (list length={len(result)})")
            result = [result[idx]]

        # 3. Order
        if cfg.order_by.enabled and cfg.order_by.key:
            reverse = cfg.order_by.value == "desc"
            result = sorted(result, key=self._make_sort_key(cfg.order_by.key), reverse=reverse)

        # 4. Limit (take first N)
        if cfg.limit.enabled and cfg.limit.size > 0:
            result = result[:cfg.limit.size]

        return {
            "result": result,
            "first_record": result[0] if result else None,
            "last_record": result[-1] if result else None,
        }

    @staticmethod
    def _resolve_value(value: str, variable_pool: VariablePool) -> Any:
        """Resolve ``value`` from the pool when it is a ``{{ namespace.key }}`` selector.

        Any other string is returned unchanged. If the pool lookup yields
        nothing, the original string is used as the default.
        """
        import re

        is_selector = re.fullmatch(r"\{\{\s*(\w+\.\w+)\s*}}", value.strip())
        if is_selector:
            return variable_pool.get_value(value, default=value, strict=False)
        return value

    @staticmethod
    def _make_sort_key(key: str):
        """Build a sort key function that never compares incompatible types.

        Dict items are sorted by ``item[key]``; other items by themselves.
        The key is a ``(rank, number, text)`` tuple: missing values sort
        first, then numbers by value, then everything else by its string
        form. This keeps numeric fields (e.g. ``size``) sortable even when
        some items hold ``None`` or ``0``.
        """

        def key_fn(item):
            value = item.get(key) if isinstance(item, dict) else item
            if value is None:
                return (0, 0, "")
            if isinstance(value, (int, float)):
                return (1, value, "")
            return (2, 0, str(value))

        return key_fn

    @staticmethod
    def _values_equal(left: Any, right: Any) -> bool:
        """Compare numerically when both sides are numeric, else as strings.

        Coercing non-numeric strings to a common fallback number would make
        every pair of strings "equal", so text is compared as text.
        ``None`` is treated as the empty string.
        """
        try:
            return float(left) == float(right)
        except (TypeError, ValueError):
            pass
        left_text = "" if left is None else str(left)
        right_text = "" if right is None else str(right)
        return left_text == right_text

    def _match_condition(self, item: Any, cond: FilterCondition, variable_pool: VariablePool) -> bool:
        """Return whether ``item`` satisfies the filter condition ``cond``.

        Raises:
            ValueError: If the comparison operator is not supported.
        """
        op = cond.comparison_operator
        value = cond.value

        # Resolve value if it's a variable reference {{ namespace.key }}
        if isinstance(value, str):
            value = self._resolve_value(value, variable_pool)

        # Dict items are tested on one field; primitive items are tested directly.
        left = item.get(cond.key) if isinstance(item, dict) else item

        # Equality works for both numbers and text.
        if op == ComparisonOperator.EQ:
            return self._values_equal(left, value)
        if op == ComparisonOperator.NE:
            return not self._values_equal(left, value)

        # Ordering operators are numeric; non-numeric operands count as 0.0.
        if op == ComparisonOperator.LT:
            return self._safe_num(left) < self._safe_num(value)
        if op == ComparisonOperator.LE:
            return self._safe_num(left) <= self._safe_num(value)
        if op == ComparisonOperator.GT:
            return self._safe_num(left) > self._safe_num(value)
        if op == ComparisonOperator.GE:
            return self._safe_num(left) >= self._safe_num(value)

        # String / sequence operators
        left_str = str(left) if left is not None else ""
        if op == ComparisonOperator.CONTAINS:
            return str(value) in left_str
        if op == ComparisonOperator.NOT_CONTAINS:
            return str(value) not in left_str
        if op == ComparisonOperator.START_WITH:
            return left_str.startswith(str(value))
        if op == ComparisonOperator.END_WITH:
            return left_str.endswith(str(value))
        if op in (ComparisonOperator.IN, ComparisonOperator.NOT_IN):
            # Stringify candidates so a resolved list of numbers still matches.
            candidates = [str(v) for v in value] if isinstance(value, list) else [str(value)]
            found = left_str in candidates
            return found if op == ComparisonOperator.IN else not found
        if op == ComparisonOperator.EMPTY:
            return not left
        if op == ComparisonOperator.NOT_EMPTY:
            return bool(left)

        raise ValueError(f"Unsupported operator: {op}")

    @staticmethod
    def _safe_num(v) -> float:
        """Convert ``v`` to float, returning 0.0 when it cannot be converted."""
        try:
            return float(v)
        except (TypeError, ValueError):
            return 0.0
|
||||
@@ -135,8 +135,7 @@ class LLMNode(BaseNode):
|
||||
api_key=model_info.api_key,
|
||||
base_url=model_info.api_base,
|
||||
extra_params=extra_params,
|
||||
is_omni=model_info.is_omni,
|
||||
support_thinking="thinking" in (model_info.capability or []),
|
||||
is_omni=model_info.is_omni
|
||||
),
|
||||
type=model_info.model_type
|
||||
)
|
||||
@@ -214,9 +213,10 @@ class LLMNode(BaseNode):
|
||||
messages = messages[:-1] + history_message + messages[-1:]
|
||||
self.messages = messages
|
||||
else:
|
||||
# 使用简单的 prompt 格式(向后兼容)
|
||||
# 使用简单的 prompt 格式(向后兼容)——包装为标准消息列表以兼容所有 provider
|
||||
prompt_template = self.config.get("prompt", "")
|
||||
self.messages = self._render_template(prompt_template, variable_pool)
|
||||
rendered = self._render_template(prompt_template, variable_pool)
|
||||
self.messages = [{"role": "user", "content": rendered}]
|
||||
|
||||
return llm
|
||||
|
||||
|
||||
@@ -27,6 +27,7 @@ from app.core.workflow.nodes.question_classifier import QuestionClassifierNode
|
||||
from app.core.workflow.nodes.breaker import BreakNode
|
||||
from app.core.workflow.nodes.tool import ToolNode
|
||||
from app.core.workflow.nodes.document_extractor import DocExtractorNode
|
||||
from app.core.workflow.nodes.list_operator import ListOperatorNode
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -51,7 +52,8 @@ WorkflowNode = Union[
|
||||
MemoryReadNode,
|
||||
MemoryWriteNode,
|
||||
CodeNode,
|
||||
DocExtractorNode
|
||||
DocExtractorNode,
|
||||
ListOperatorNode
|
||||
]
|
||||
|
||||
|
||||
@@ -83,7 +85,8 @@ class NodeFactory:
|
||||
NodeType.MEMORY_READ: MemoryReadNode,
|
||||
NodeType.MEMORY_WRITE: MemoryWriteNode,
|
||||
NodeType.CODE: CodeNode,
|
||||
NodeType.DOCUMENT_EXTRACTOR: DocExtractorNode
|
||||
NodeType.DOCUMENT_EXTRACTOR: DocExtractorNode,
|
||||
NodeType.LIST_OPERATOR: ListOperatorNode
|
||||
}
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -118,8 +118,7 @@ class ParameterExtractorNode(BaseNode):
|
||||
provider=provider,
|
||||
api_key=api_key,
|
||||
base_url=api_base,
|
||||
is_omni=is_omni,
|
||||
support_thinking="thinking" in (capability or []),
|
||||
is_omni=is_omni
|
||||
),
|
||||
type=ModelType(model_type)
|
||||
)
|
||||
|
||||
@@ -71,8 +71,7 @@ class QuestionClassifierNode(BaseNode):
|
||||
provider=provider,
|
||||
api_key=api_key,
|
||||
base_url=base_url,
|
||||
is_omni=is_omni,
|
||||
support_thinking="thinking" in (capability or []),
|
||||
is_omni=is_omni
|
||||
),
|
||||
type=ModelType(model_type)
|
||||
)
|
||||
|
||||
@@ -1,7 +1,10 @@
|
||||
# -*- coding: UTF-8 -*-
|
||||
# Author: Eternity
|
||||
# @Email: 1533512157@qq.com
|
||||
# @Time : 2026/3/10 13:36
|
||||
import mimetypes
|
||||
import os
|
||||
import uuid
|
||||
from typing import Any
|
||||
from urllib.parse import urlparse, unquote
|
||||
|
||||
TRANSFORM_FILE_TYPE = {
|
||||
'text/plain': 'document/text',
|
||||
'text/markdown': 'document/markdown',
|
||||
@@ -52,5 +55,143 @@ ALLOWED_FILE_TYPES = [
|
||||
def mime_to_file_type(mime_type):
    """Map a MIME type to the internal file type string.

    Returns ``None`` when ``mime_type`` is not listed in ``ALLOWED_FILE_TYPES``.
    Otherwise returns its ``TRANSFORM_FILE_TYPE`` entry, or the MIME type
    itself when no entry exists.
    """
    is_allowed = mime_type in ALLOWED_FILE_TYPES
    return TRANSFORM_FILE_TYPE.get(mime_type, mime_type) if is_allowed else None
|
||||
|
||||
|
||||
def build_file_object_dict_from_url(url: str, file_type: str, origin_file_type: str) -> dict[str, Any]:
    """Build a FileObject dict for a remote_url file using only URL parsing (no HTTP request).

    Used as fallback when HTTP request fails.

    Args:
        url: Remote file URL; stored unchanged in the result.
        file_type: Value stored under ``"type"``.
        origin_file_type: Stored under ``"origin_file_type"`` and used as the
            ``"mime_type"`` when no MIME type can be guessed from the URL path.

    Returns:
        A dict with ``is_file=True``. ``name`` and ``extension`` are ``None``
        when the URL path has no file name / extension; ``size`` and
        ``file_id`` are always ``None``.
    """
    # Parse once: the path component excludes the query string and fragment,
    # so neither leaks into the file name, extension or MIME guess.
    path = urlparse(url).path
    name = unquote(os.path.basename(path)) or None
    ext = os.path.splitext(name or "")[1]
    extension = ext.lstrip(".").lower() or None
    guessed_mime = mimetypes.guess_type(path)[0]
    return {
        "type": file_type,
        "url": url,
        "transfer_method": "remote_url",
        "origin_file_type": origin_file_type,
        "file_id": None,
        "name": name,
        "size": None,
        "extension": extension,
        "mime_type": guessed_mime or origin_file_type,
        "is_file": True,
    }
|
||||
|
||||
|
||||
async def fetch_remote_file_meta(
    url: str,
    file_type: str,
    origin_file_type: str,
) -> dict[str, Any]:
    """Fetch remote file metadata via HEAD (fallback GET) and build a FileObject dict.

    Falls back to URL-only parsing if the HTTP request fails, including when
    the final response is an HTTP error (status >= 400), so that an error
    page's headers are never recorded as the file's metadata.

    Args:
        url: Remote file URL.
        file_type: Value stored under ``"type"`` in the result.
        origin_file_type: MIME type to use when the response has no Content-Type.

    Returns:
        A FileObject dict built from the response headers, or from the URL
        alone when the request fails.
    """
    from email.message import Message

    import httpx

    name = size = mime_type = extension = None
    try:
        async with httpx.AsyncClient(timeout=10.0) as client:
            resp = await client.head(url, follow_redirects=True)
            if resp.status_code != 200:
                # Some servers reject HEAD; retry with GET.
                resp = await client.get(url, follow_redirects=True)

        if resp.status_code >= 400:
            return build_file_object_dict_from_url(url, file_type, origin_file_type)

        cl = resp.headers.get("Content-Length")
        size = int(cl) if cl else None

        # Drop parameters such as "; charset=utf-8".
        ct = resp.headers.get("Content-Type", "").split(";")[0].strip()
        mime_type = ct or origin_file_type

        cd = resp.headers.get("Content-Disposition", "")
        if cd:
            # Use the stdlib header parser so quoting, trailing parameters and
            # RFC 2231 ``filename*=`` values are handled.
            header = Message()
            header["Content-Disposition"] = cd
            name = header.get_filename() or None
        if not name:
            name = unquote(os.path.basename(urlparse(url).path)) or None

        if name:
            _, ext = os.path.splitext(name)
            extension = ext.lstrip(".").lower() if ext else None
        if not extension and mime_type:
            ext = mimetypes.guess_extension(mime_type)
            extension = ext.lstrip(".").lower() if ext else None
    except Exception:
        # Deliberate best-effort: any network or parsing failure degrades to
        # metadata derived from the URL alone.
        return build_file_object_dict_from_url(url, file_type, origin_file_type)

    return build_file_object_dict_from_meta(
        file_type=file_type,
        transfer_method="remote_url",
        origin_file_type=origin_file_type,
        file_id=None,
        url=url,
        file_name=name,
        file_size=size,
        file_ext=extension,
        content_type=mime_type,
    )
|
||||
|
||||
|
||||
def build_file_object_dict_from_meta(
|
||||
file_type: str,
|
||||
transfer_method: str,
|
||||
origin_file_type: str,
|
||||
file_id: str,
|
||||
url: str,
|
||||
file_name: str | None,
|
||||
file_size: int | None,
|
||||
file_ext: str | None,
|
||||
content_type: str | None,
|
||||
) -> dict[str, Any]:
|
||||
"""Build a FileObject dict from already-fetched FileMetadata fields."""
|
||||
ext = (file_ext or "").lstrip(".")
|
||||
return {
|
||||
"type": file_type,
|
||||
"url": url,
|
||||
"transfer_method": transfer_method,
|
||||
"origin_file_type": content_type or origin_file_type,
|
||||
"file_id": file_id,
|
||||
"name": file_name,
|
||||
"size": file_size,
|
||||
"extension": ext.lower() if ext else None,
|
||||
"mime_type": content_type,
|
||||
"is_file": True,
|
||||
}
|
||||
|
||||
|
||||
def resolve_local_file_object_dict(
    db,
    upload_file_id: str | uuid.UUID,
    file_type: str,
    origin_file_type: str,
) -> dict[str, Any] | None:
    """Look up a local file's FileMetadata row and build its FileObject dict.

    Returns None if ``upload_file_id`` is not a valid UUID, or if no
    FileMetadata row with that id and status "completed" exists.
    """
    # Imported lazily inside the function, as in the rest of this module.
    from app.models.file_metadata_model import FileMetadata
    from app.core.config import settings

    try:
        file_uuid = uuid.UUID(str(upload_file_id))
    except ValueError:
        return None

    completed_rows = db.query(FileMetadata).filter(
        FileMetadata.id == file_uuid,
        FileMetadata.status == "completed"
    )
    record = completed_rows.first()
    if not record:
        return None

    download_url = f"{settings.FILE_LOCAL_SERVER_URL}/storage/permanent/{file_uuid}"
    return build_file_object_dict_from_meta(
        file_type=file_type,
        transfer_method="local_file",
        origin_file_type=origin_file_type,
        file_id=str(file_uuid),
        url=download_url,
        file_name=record.file_name,
        file_size=record.file_size,
        file_ext=record.file_ext,
        content_type=record.content_type,
    )
|
||||
|
||||
@@ -301,7 +301,7 @@ class WorkflowValidator:
|
||||
for node in nodes:
|
||||
if node.get("type") not in [NodeType.START, NodeType.CYCLE_START, NodeType.END] and not node.get("name"):
|
||||
errors.append(
|
||||
f"节点 {node.get('id')} 缺少名称(发布时必须提供)"
|
||||
f"节点 {node.get('name')} 缺少名称(发布时必须提供)"
|
||||
)
|
||||
|
||||
# 2. 验证所有非 start/end 节点都有配置
|
||||
@@ -311,7 +311,7 @@ class WorkflowValidator:
|
||||
config = node.get("config")
|
||||
if not config or not isinstance(config, dict):
|
||||
errors.append(
|
||||
f"节点 {node.get('id')} 缺少配置(发布时必须提供)"
|
||||
f"节点 {node.get('name')} 缺少配置(发布时必须提供)"
|
||||
)
|
||||
|
||||
# 3. 验证必填变量
|
||||
|
||||
@@ -91,7 +91,7 @@ def DEFAULT_VALUE(var_type: VariableType) -> Any:
|
||||
case VariableType.OBJECT:
|
||||
return {}
|
||||
case VariableType.FILE:
|
||||
return None
|
||||
return {}
|
||||
case VariableType.ARRAY_STRING:
|
||||
return []
|
||||
case VariableType.ARRAY_NUMBER:
|
||||
@@ -113,6 +113,12 @@ class FileObject(BaseModel):
|
||||
origin_file_type: str
|
||||
file_id: str | None
|
||||
|
||||
# Extended file metadata
|
||||
name: str | None = None
|
||||
size: int | None = None
|
||||
extension: str | None = None
|
||||
mime_type: str | None = None
|
||||
|
||||
content_cache: dict = Field(default_factory=dict)
|
||||
is_file: bool
|
||||
|
||||
|
||||
@@ -66,20 +66,10 @@ class FileVariable(BaseVariable):
|
||||
type = 'file'
|
||||
|
||||
def valid_value(self, value) -> FileObject:
    """Validate ``value`` and return it as a FileObject.

    A dict must carry a truthy ``is_file`` flag; it is then expanded into
    ``FileObject`` in full, so every field present in the dict (including the
    extended ``name`` / ``size`` / ``extension`` / ``mime_type`` metadata) is
    passed on rather than a fixed subset. A ``FileObject`` instance is
    returned unchanged.

    Raises:
        TypeError: If ``value`` is neither a FileObject nor a dict flagged
            with ``is_file``.
    """
    if isinstance(value, dict):
        if not value.get("is_file"):
            raise TypeError(f"Value must be a FileObject - {type(value)}:{value}")
        return FileObject(**value)
    if isinstance(value, FileObject):
        return value
    raise TypeError(f"Value must be a FileObject - {type(value)}:{value}")
|
||||
@@ -88,7 +78,7 @@ class FileVariable(BaseVariable):
|
||||
return f'{"!"if self.value.type == FileType.IMAGE else ""}[file]({self.value.url})'
|
||||
|
||||
def get_value(self) -> Any:
    """Return the file as a plain dict, omitting the ``content_cache`` field."""
    return self.value.model_dump(exclude={"content_cache"})
|
||||
|
||||
async def get_content(self):
|
||||
total_bytes = 0
|
||||
@@ -186,6 +176,8 @@ def create_variable_instance(var_type: VariableType, value: Any) -> T:
|
||||
return BooleanVariable(value)
|
||||
case VariableType.OBJECT:
|
||||
return DictVariable(value)
|
||||
case VariableType.FILE:
|
||||
return FileVariable(value)
|
||||
case VariableType.ARRAY_STRING:
|
||||
return make_array(StringVariable, value)
|
||||
case VariableType.ARRAY_NUMBER:
|
||||
|
||||
Reference in New Issue
Block a user