feat(workflow):
1. Add a list operator node for filtering, sorting, limiting, and extracting list items; 2. Add "file" type support for session variables
This commit is contained in:
@@ -1079,6 +1079,14 @@ async def update_workflow_config(
|
||||
current_user: Annotated[User, Depends(get_current_user)]
|
||||
):
|
||||
workspace_id = current_user.current_workspace_id
|
||||
if payload.variables:
|
||||
from app.services.workflow_service import WorkflowService
|
||||
resolved = await WorkflowService(db)._resolve_variables_file_defaults(
|
||||
[v.model_dump() for v in payload.variables]
|
||||
)
|
||||
# Patch default values back into VariableDefinition objects
|
||||
for var_def, resolved_def in zip(payload.variables, resolved):
|
||||
var_def.default = resolved_def.get("default", var_def.default)
|
||||
cfg = app_service.update_workflow_config(db, app_id=app_id, data=payload, workspace_id=workspace_id)
|
||||
return success(data=WorkflowConfigSchema.model_validate(cfg))
|
||||
|
||||
|
||||
@@ -53,22 +53,24 @@ async def login_for_access_token(
|
||||
user = auth_service.authenticate_user_or_raise(db, form_data.email, form_data.password)
|
||||
auth_logger.info(f"用户认证成功: {user.email} (ID: {user.id})")
|
||||
if form_data.invite:
|
||||
auth_service.bind_workspace_with_invite(db=db,
|
||||
user=user,
|
||||
invite_token=form_data.invite,
|
||||
workspace_id=invite_info.workspace_id)
|
||||
auth_service.bind_workspace_with_invite(
|
||||
db=db,
|
||||
user=user,
|
||||
invite_token=form_data.invite,
|
||||
workspace_id=invite_info.workspace_id
|
||||
)
|
||||
except BusinessException as e:
|
||||
# 用户不存在且有邀请码,尝试注册
|
||||
if e.code == BizCode.USER_NOT_FOUND:
|
||||
auth_logger.info(f"用户不存在,使用邀请码注册: {form_data.email}")
|
||||
user = auth_service.register_user_with_invite(
|
||||
db=db,
|
||||
email=form_data.email,
|
||||
username=form_data.username,
|
||||
password=form_data.password,
|
||||
invite_token=form_data.invite,
|
||||
workspace_id=invite_info.workspace_id
|
||||
)
|
||||
db=db,
|
||||
email=form_data.email,
|
||||
username=form_data.username,
|
||||
password=form_data.password,
|
||||
invite_token=form_data.invite,
|
||||
workspace_id=invite_info.workspace_id
|
||||
)
|
||||
elif e.code == BizCode.PASSWORD_ERROR:
|
||||
# 用户存在但密码错误
|
||||
auth_logger.warning(f"接受邀请失败,密码验证错误: {form_data.email}")
|
||||
|
||||
@@ -475,7 +475,7 @@ class LangChainAgent:
|
||||
history: Optional[List[Dict[str, str]]] = None,
|
||||
context: Optional[str] = None,
|
||||
files: Optional[List[Dict[str, Any]]] = None
|
||||
) -> AsyncGenerator[str | int, None]:
|
||||
) -> AsyncGenerator[str | int | dict[str, str], None]:
|
||||
"""执行流式对话
|
||||
|
||||
Args:
|
||||
|
||||
@@ -25,8 +25,34 @@ class RedBearEmbeddings(Embeddings):
|
||||
def _create_model(self, config: RedBearModelConfig) -> Embeddings:
    """Create the LangChain embedding model described by ``config``.

    Embedding models only need connection parameters, never LLM-specific
    ones (e.g. ``enable_thinking``, ``model_kwargs``), so the keyword
    arguments are built directly for the providers handled here. Every
    other provider (Bedrock included) uses the parameters produced by
    ``RedBearModelFactory.get_model_params``.

    Args:
        config: Model configuration (provider, model name, credentials, ...).

    Returns:
        An instance of the embedding class registered for ``config.provider``.
    """
    embedding_class = get_provider_embedding_class(config.provider)
    provider = config.provider.lower()

    if provider in (ModelProvider.OPENAI, ModelProvider.XINFERENCE, ModelProvider.GPUSTACK):
        # Imported lazily: only the OpenAI-compatible providers need httpx here.
        import httpx

        params = {
            "model": config.model_name,
            "base_url": config.base_url,
            "api_key": config.api_key,
            "timeout": httpx.Timeout(timeout=config.timeout, connect=60.0),
            "max_retries": config.max_retries,
        }
    elif provider == ModelProvider.DASHSCOPE:
        params = {
            "model": config.model_name,
            "dashscope_api_key": config.api_key,
            "max_retries": config.max_retries,
        }
    elif provider == ModelProvider.OLLAMA:
        params = {
            "model": config.model_name,
            "base_url": config.base_url,
        }
    else:
        # Bedrock and any other provider: reuse the factory-built parameters.
        params = RedBearModelFactory.get_model_params(config)

    return embedding_class(**params)
|
||||
|
||||
def _create_volcano_client(self, config: RedBearModelConfig):
|
||||
"""创建火山引擎客户端"""
|
||||
|
||||
@@ -6,14 +6,28 @@ ChatOpenAI 在解析流式 SSE 时只取 delta.content,会丢弃 delta.reasoni
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Optional
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
from langchain_core.outputs import ChatGenerationChunk
|
||||
from langchain_core.outputs import ChatGenerationChunk, ChatResult
|
||||
from langchain_openai import ChatOpenAI
|
||||
|
||||
|
||||
class VolcanoChatOpenAI(ChatOpenAI):
|
||||
"""火山引擎 Chat 模型,支持深度思考内容(reasoning_content)的流式透传。"""
|
||||
"""火山引擎 Chat 模型,支持深度思考内容(reasoning_content)的流式和非流式透传。"""
|
||||
|
||||
def _create_chat_result(self, response: Union[dict, Any], generation_info: Optional[dict] = None) -> ChatResult:
    """Build the chat result and carry over ``reasoning_content`` for non-streaming calls.

    The parent implementation builds the result; this override then copies
    the first choice's ``reasoning_content`` (when present and non-empty)
    into ``additional_kwargs`` of the first generation's message.

    Args:
        response: Raw completion response, either a dict or an object
            exposing ``choices``.
        generation_info: Passed through unchanged to the parent method.

    Returns:
        The ``ChatResult`` produced by the parent, possibly enriched.
    """
    chat_result = super()._create_chat_result(response, generation_info)

    # The response may be an SDK object or a plain dict.
    if hasattr(response, "choices"):
        choice_list = response.choices
    else:
        choice_list = response.get("choices", [])
    if not choice_list:
        return chat_result

    first_choice = choice_list[0]
    if hasattr(first_choice, "message"):
        msg = first_choice.message
    else:
        msg = first_choice.get("message", {})

    # Attribute access first, dict lookup as the fallback.
    reasoning_text = getattr(msg, "reasoning_content", None)
    if not reasoning_text and isinstance(msg, dict):
        reasoning_text = msg.get("reasoning_content")

    if reasoning_text and chat_result.generations:
        chat_result.generations[0].message.additional_kwargs["reasoning_content"] = reasoning_text
    return chat_result
|
||||
|
||||
def _convert_chunk_to_generation_chunk(
|
||||
self,
|
||||
|
||||
@@ -27,7 +27,7 @@ class DateTimeTool(BuiltinTool):
|
||||
type=ParameterType.STRING,
|
||||
description="操作类型",
|
||||
required=True,
|
||||
enum=["format", "convert_timezone", "timestamp_to_datetime", "now"]
|
||||
enum=["format", "convert_timezone", "timestamp_to_datetime", "now", "datetime_to_timestamp"]
|
||||
),
|
||||
ToolParameter(
|
||||
name="input_value",
|
||||
|
||||
@@ -32,13 +32,16 @@ from app.core.workflow.nodes.configs import (
|
||||
NoteNodeConfig,
|
||||
ParameterExtractorNodeConfig,
|
||||
QuestionClassifierNodeConfig,
|
||||
VariableAggregatorNodeConfig
|
||||
VariableAggregatorNodeConfig,
|
||||
ListOperatorNodeConfig,
|
||||
DocExtractorNodeConfig,
|
||||
)
|
||||
from app.core.workflow.nodes.cycle_graph.config import (
|
||||
ConditionDetail as LoopConditionDetail,
|
||||
ConditionsConfig,
|
||||
CycleVariable
|
||||
)
|
||||
from app.core.workflow.nodes.list_operator.config import FilterCondition
|
||||
from app.core.workflow.nodes.enums import (
|
||||
ValueInputType,
|
||||
ComparisonOperator,
|
||||
@@ -90,6 +93,8 @@ class DifyConverter(BaseConverter):
|
||||
NodeType.VAR_AGGREGATOR: self.convert_variable_aggregator_node_config,
|
||||
NodeType.TOOL: self.convert_tool_node_config,
|
||||
NodeType.NOTES: self.convert_notes_config,
|
||||
NodeType.LIST_OPERATOR: self.convert_list_operator_node_config,
|
||||
NodeType.DOCUMENT_EXTRACTOR: self.convert_document_extractor_node_config,
|
||||
NodeType.CYCLE_START: lambda x: {},
|
||||
NodeType.BREAK: lambda x: {},
|
||||
}
|
||||
@@ -213,7 +218,9 @@ class DifyConverter(BaseConverter):
|
||||
"end with": ComparisonOperator.END_WITH,
|
||||
"not contains": ComparisonOperator.NOT_CONTAINS,
|
||||
"exists": ComparisonOperator.NOT_EMPTY,
|
||||
"not exists": ComparisonOperator.EMPTY
|
||||
"not exists": ComparisonOperator.EMPTY,
|
||||
"in": ComparisonOperator.IN,
|
||||
"not in": ComparisonOperator.NOT_IN,
|
||||
}
|
||||
return operator_map.get(operator, operator)
|
||||
|
||||
@@ -771,3 +778,46 @@ class DifyConverter(BaseConverter):
|
||||
show_author=node_data.get("showAuthor", True)
|
||||
).model_dump()
|
||||
return result
|
||||
|
||||
def convert_list_operator_node_config(self, node: dict) -> dict:
    """Convert a Dify list-operator node into a ``ListOperatorNodeConfig`` dict.

    The Dify variable path array becomes a ``{{ }}`` selector string, and
    each filter condition's ``comparison_operator`` is translated through
    ``convert_compare_operator``. Ordering, limit and extraction settings
    are copied as-is, with disabled defaults when absent.

    Args:
        node: Dify node dict; ``id``, ``data`` and ``data["title"]`` are read.

    Returns:
        The converted config dict, after it was passed to ``config_validate``.
    """
    data = node["data"]
    selector = self._process_list_variable_literal(data.get("variable", [])) or ""

    filters = data.get("filter_by", {"enabled": False, "conditions": []})
    if filters.get("conditions"):
        # Translate every condition's operator from Dify naming to the native enum.
        filters = {
            **filters,
            "conditions": [
                {
                    **condition,
                    "comparison_operator": self.convert_compare_operator(
                        condition.get("comparison_operator", "")
                    ),
                }
                for condition in filters["conditions"]
            ],
        }

    converted = {
        "input_list": selector,
        "filter_by": filters,
        "order_by": data.get("order_by", {"enabled": False, "key": "", "value": "asc"}),
        "limit": data.get("limit", {"enabled": False, "size": -1}),
        "extract_by": data.get("extract_by", {"enabled": False, "serial": "1"}),
    }
    self.config_validate(node["id"], data["title"], ListOperatorNodeConfig, converted)
    return converted
|
||||
|
||||
def convert_document_extractor_node_config(self, node: dict) -> dict:
    """Convert a Dify document-extractor node into a ``DocExtractorNodeConfig`` dict.

    Dify document-extractor data fields used here:
        variable_selector: list[str] - file variable path

    Args:
        node: Dify node dict; ``id``, ``data`` and ``data["title"]`` are read.

    Returns:
        The dumped config dict, after it was passed to ``config_validate``.
    """
    data = node["data"]
    selector_path = data.get("variable_selector", [])
    file_selector = self._process_list_variable_literal(selector_path) or ""

    # model_construct builds the model without validation; config_validate checks it below.
    config_model = DocExtractorNodeConfig.model_construct(file_selector=file_selector)
    converted = config_model.model_dump()

    self.config_validate(node["id"], data["title"], DocExtractorNodeConfig, converted)
    return converted
|
||||
|
||||
@@ -45,6 +45,8 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter):
|
||||
"question-classifier": NodeType.QUESTION_CLASSIFIER,
|
||||
"variable-aggregator": NodeType.VAR_AGGREGATOR,
|
||||
"tool": NodeType.TOOL,
|
||||
"list-operator": NodeType.LIST_OPERATOR,
|
||||
"document-extractor": NodeType.DOCUMENT_EXTRACTOR,
|
||||
"": NodeType.NOTES
|
||||
}
|
||||
|
||||
|
||||
@@ -22,6 +22,8 @@ from app.core.workflow.nodes.configs import (
|
||||
MemoryReadNodeConfig,
|
||||
MemoryWriteNodeConfig,
|
||||
NoteNodeConfig,
|
||||
ListOperatorNodeConfig,
|
||||
DocExtractorNodeConfig,
|
||||
)
|
||||
from app.core.workflow.nodes.enums import NodeType
|
||||
|
||||
@@ -51,6 +53,8 @@ class MemoryBearConverter(BaseConverter):
|
||||
NodeType.MEMORY_READ: MemoryReadNodeConfig,
|
||||
NodeType.MEMORY_WRITE: MemoryWriteNodeConfig,
|
||||
NodeType.NOTES: NoteNodeConfig,
|
||||
NodeType.LIST_OPERATOR: ListOperatorNodeConfig,
|
||||
NodeType.DOCUMENT_EXTRACTOR: DocExtractorNodeConfig,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
|
||||
@@ -318,7 +318,7 @@ class VariablePool:
|
||||
namespace: str,
|
||||
key: str,
|
||||
value: Any,
|
||||
var_type: VariableType,
|
||||
var_type: VariableType | None,
|
||||
mut: bool
|
||||
):
|
||||
if self.has(f"{namespace}.{key}"):
|
||||
@@ -493,6 +493,23 @@ class VariablePoolInitializer:
|
||||
var_value = var_default
|
||||
else:
|
||||
var_value = DEFAULT_VALUE(var_type)
|
||||
# Convert FileInput-format dicts to full FileObject dicts
|
||||
if var_type == VariableType.FILE:
|
||||
if not var_value:
|
||||
continue
|
||||
var_value = await self._resolve_file_default(var_value)
|
||||
if not var_value:
|
||||
continue
|
||||
elif var_type == VariableType.ARRAY_FILE:
|
||||
if not var_value:
|
||||
var_value = []
|
||||
else:
|
||||
resolved = []
|
||||
for item in var_value:
|
||||
f = await self._resolve_file_default(item)
|
||||
if f:
|
||||
resolved.append(f)
|
||||
var_value = resolved
|
||||
await variable_pool.new(
|
||||
namespace="conv",
|
||||
key=var_name,
|
||||
@@ -501,6 +518,17 @@ class VariablePoolInitializer:
|
||||
mut=True
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def _resolve_file_default(file_def: dict) -> dict | None:
|
||||
"""Accept only already-resolved FileObject dicts (is_file=True).
|
||||
FileInput-format dicts are converted at save time by WorkflowService._resolve_variables_file_defaults.
|
||||
"""
|
||||
if not isinstance(file_def, dict):
|
||||
return None
|
||||
if file_def.get("is_file"):
|
||||
return file_def
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
async def _init_system_vars(
|
||||
variable_pool: VariablePool,
|
||||
|
||||
@@ -24,6 +24,8 @@ from app.core.workflow.nodes.start.config import StartNodeConfig
|
||||
from app.core.workflow.nodes.tool.config import ToolNodeConfig
|
||||
from app.core.workflow.nodes.variable_aggregator.config import VariableAggregatorNodeConfig
|
||||
from app.core.workflow.nodes.notes.config import NoteNodeConfig
|
||||
from app.core.workflow.nodes.list_operator.config import ListOperatorNodeConfig
|
||||
from app.core.workflow.nodes.document_extractor.config import DocExtractorNodeConfig
|
||||
|
||||
__all__ = [
|
||||
# 基础类
|
||||
@@ -49,5 +51,7 @@ __all__ = [
|
||||
"MemoryReadNodeConfig",
|
||||
"MemoryWriteNodeConfig",
|
||||
"CodeNodeConfig",
|
||||
"NoteNodeConfig"
|
||||
"NoteNodeConfig",
|
||||
"ListOperatorNodeConfig",
|
||||
"DocExtractorNodeConfig",
|
||||
]
|
||||
|
||||
@@ -14,12 +14,22 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
def _file_object_to_file_input(f: FileObject) -> FileInput:
    """Convert a workflow ``FileObject`` into a multimodal ``FileInput``.

    Args:
        f: Workflow file object; ``f.type`` may be a string or a ``FileType``.

    Returns:
        A ``FileInput`` describing the same document.

    Raises:
        ValueError: If the file's type does not resolve to ``FileType.DOCUMENT``.
    """
    # origin_file_type wins; mime_type is only the fallback when it is empty.
    file_type = f.origin_file_type or f.mime_type or ""

    resolved_type = FileType.trans(f.type) if isinstance(f.type, str) else f.type
    if resolved_type != FileType.DOCUMENT:
        raise ValueError(
            f"Document extractor only supports document files, got type '{f.type}' "
            f"(name={f.name or f.file_id or f.url})"
        )

    return FileInput(
        type=resolved_type,
        transfer_method=TransferMethod(f.transfer_method),
        url=f.url or None,
        upload_file_id=f.file_id or None,
        file_type=file_type,
    )
|
||||
|
||||
|
||||
@@ -81,6 +91,7 @@ class DocExtractorNode(BaseNode):
|
||||
from app.services.multimodal_service import MultimodalService
|
||||
svc = MultimodalService(db)
|
||||
for f in files:
|
||||
label = f.name or f.url or f.file_id
|
||||
try:
|
||||
file_input = _file_object_to_file_input(f)
|
||||
# Ensure URL is populated for local files
|
||||
@@ -93,7 +104,7 @@ class DocExtractorNode(BaseNode):
|
||||
chunks.append(text)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Node {self.node_id}: failed to extract file url={f.url} file_id={f.file_id}: {e}",
|
||||
f"Node {self.node_id}: failed to extract file '{label}': {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
chunks.append("")
|
||||
|
||||
@@ -24,6 +24,7 @@ class NodeType(StrEnum):
|
||||
MEMORY_READ = "memory-read"
|
||||
MEMORY_WRITE = "memory-write"
|
||||
DOCUMENT_EXTRACTOR = "document-extractor"
|
||||
LIST_OPERATOR = "list-operator"
|
||||
|
||||
UNKNOWN = "unknown"
|
||||
NOTES = "notes"
|
||||
@@ -45,6 +46,8 @@ class ComparisonOperator(StrEnum):
|
||||
LE = "le"
|
||||
GT = "gt"
|
||||
GE = "ge"
|
||||
IN = "in"
|
||||
NOT_IN = "not_in"
|
||||
|
||||
|
||||
class LogicOperator(StrEnum):
|
||||
|
||||
3
api/app/core/workflow/nodes/list_operator/__init__.py
Normal file
3
api/app/core/workflow/nodes/list_operator/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .node import ListOperatorNode
|
||||
|
||||
__all__ = ["ListOperatorNode"]
|
||||
44
api/app/core/workflow/nodes/list_operator/config.py
Normal file
44
api/app/core/workflow/nodes/list_operator/config.py
Normal file
@@ -0,0 +1,44 @@
|
||||
from typing import Any
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.core.workflow.nodes.base_config import BaseNodeConfig
|
||||
from app.core.workflow.nodes.enums import ComparisonOperator
|
||||
|
||||
|
||||
class FilterCondition(BaseModel):
    """A single predicate evaluated against every item during the filter step."""

    # Field looked up on dict items; non-dict items are compared directly.
    key: str = Field(default="")
    # Comparison applied between the item's value and ``value``.
    comparison_operator: ComparisonOperator = Field(default=ComparisonOperator.CONTAINS)
    # Right-hand operand: a literal, a list of literals, or a boolean.
    value: str | list[str] | bool = Field(default="")
|
||||
|
||||
|
||||
class FilterBy(BaseModel):
    """Filter step settings: an on/off switch plus the conditions to apply."""

    # The filter step is skipped unless this is true.
    enabled: bool = Field(default=False)
    # Conditions applied one after another; an item must pass all of them.
    conditions: list[FilterCondition] = Field(default_factory=list)
|
||||
|
||||
|
||||
class OrderByConfig(BaseModel):
    """Sort step settings."""

    # The sort step is skipped unless this is true.
    enabled: bool = Field(default=False)
    # Item field to sort by.
    key: str = Field(default="")
    # Sort direction: "asc" | "desc"
    value: str = Field(default="asc")
|
||||
|
||||
|
||||
class Limit(BaseModel):
    """Limit step settings (keep only the first ``size`` items)."""

    # The limit step is skipped unless this is true.
    enabled: bool = Field(default=False)
    # Maximum number of items to keep; defaults to -1.
    size: int = Field(default=-1)
|
||||
|
||||
|
||||
class ExtractConfig(BaseModel):
    """Extract step settings (pick a single item by position)."""

    # The extract step is skipped unless this is true.
    enabled: bool = Field(default=False)
    # 1-based index string, e.g. "1" = first
    serial: str = Field(default="1")
|
||||
|
||||
|
||||
class ListOperatorNodeConfig(BaseNodeConfig):
    """Configuration of the List Operator node.

    The steps are applied in this order: filter -> extract -> order -> limit.
    """

    input_list: str = Field(
        ...,
        description="Variable selector, e.g. {{ sys.files }} or {{ conv.uploaded_files }}",
    )
    filter_by: FilterBy = Field(default_factory=FilterBy)
    order_by: OrderByConfig = Field(default_factory=OrderByConfig)
    limit: Limit = Field(default_factory=Limit)
    extract_by: ExtractConfig = Field(default_factory=ExtractConfig)
|
||||
143
api/app/core/workflow/nodes/list_operator/node.py
Normal file
143
api/app/core/workflow/nodes/list_operator/node.py
Normal file
@@ -0,0 +1,143 @@
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from app.core.workflow.engine.state_manager import WorkflowState
|
||||
from app.core.workflow.engine.variable_pool import VariablePool
|
||||
from app.core.workflow.nodes.base_node import BaseNode
|
||||
from app.core.workflow.nodes.enums import ComparisonOperator
|
||||
from app.core.workflow.nodes.list_operator.config import ListOperatorNodeConfig, FilterCondition
|
||||
from app.core.workflow.variable.base_variable import VariableType
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# File object fields that hold string values
|
||||
_FILE_STRING_KEYS = {"name", "extension", "mime_type", "url", "transfer_method", "origin_file_type", "file_id"}
|
||||
_FILE_NUMBER_KEYS = {"size"}
|
||||
|
||||
|
||||
class ListOperatorNode(BaseNode):
    """Workflow node that filters, extracts, sorts and truncates a list variable.

    The steps run in a fixed order: filter -> extract -> order -> limit.
    The node outputs the resulting list as ``result`` plus its first and
    last element (``None`` when the list is empty).
    """

    def __init__(self, node_config: dict, workflow_config: dict, down_stream_nodes: list[str]):
        super().__init__(node_config, workflow_config, down_stream_nodes)
        # Populated on every execute() call from the raw node config.
        self.typed_config: ListOperatorNodeConfig | None = None

    def _output_types(self) -> dict[str, VariableType]:
        """Declare the output variables; all are untyped (ANY)."""
        return {
            "result": VariableType.ANY,
            "first_record": VariableType.ANY,
            "last_record": VariableType.ANY,
        }

    async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any:
        """Run the configured list operations on the input variable.

        Args:
            state: Current workflow state (not read by this node).
            variable_pool: Pool used to resolve the input list and any
                ``{{ namespace.key }}`` references in the configuration.

        Returns:
            A dict with ``result``, ``first_record`` and ``last_record``.

        Raises:
            TypeError: If the input variable is not a list.
            ValueError: If ``extract_by.serial`` is not an integer or is out
                of range, or if a condition uses an unsupported operator.
        """
        self.typed_config = ListOperatorNodeConfig(**self.config)
        cfg = self.typed_config

        # Resolve input variable from path selector
        items: list = self.get_variable(cfg.input_list, variable_pool)
        if not isinstance(items, list):
            raise TypeError(f"Variable '{cfg.input_list}' must be an array, got {type(items)}")

        # Work on a copy so the pooled variable is never mutated.
        result = list(items)

        # 1. Filter: an item must satisfy every condition.
        if cfg.filter_by.enabled and cfg.filter_by.conditions:
            for condition in cfg.filter_by.conditions:
                result = [item for item in result if self._match_condition(item, condition, variable_pool)]

        # 2. Extract (take single item by 1-based serial index)
        if cfg.extract_by.enabled:
            serial = self._resolve_value(cfg.extract_by.serial, variable_pool)
            try:
                idx = int(serial) - 1
            except (TypeError, ValueError) as err:
                raise ValueError(f"extract_by.serial={serial!r} is not a valid integer") from err
            if idx < 0 or idx >= len(result):
                raise ValueError(f"extract_by.serial={cfg.extract_by.serial} out of range (list length={len(result)})")
            result = [result[idx]]

        # 3. Order
        if cfg.order_by.enabled and cfg.order_by.key:
            reverse = cfg.order_by.value == "desc"
            result = sorted(result, key=self._make_sort_key(cfg.order_by.key), reverse=reverse)

        # 4. Limit (take first N)
        if cfg.limit.enabled and cfg.limit.size > 0:
            result = result[:cfg.limit.size]

        return {
            "result": result,
            "first_record": result[0] if result else None,
            "last_record": result[-1] if result else None,
        }

    @staticmethod
    def _resolve_value(value: str, variable_pool: VariablePool) -> Any:
        """Resolve ``value`` from the pool when it is a ``{{ namespace.key }}`` selector.

        Any other string is returned unchanged. When the selector cannot be
        resolved, the pool is asked to fall back to the raw string.
        """
        import re

        if re.fullmatch(r"\{\{\s*(\w+\.\w+)\s*}}", value.strip()):
            return variable_pool.get_value(value, default=value, strict=False)
        return value

    @staticmethod
    def _make_sort_key(key: str):
        """Build a sort key function that tolerates missing and mixed-type values.

        Keys are ``(rank, number, other)`` tuples so that values of different
        kinds are never compared with each other: numbers sort first, then
        missing/empty values and strings, then anything else.
        """
        def key_fn(item):
            is_record = isinstance(item, dict)
            raw = item.get(key) if is_record else item
            if isinstance(raw, (int, float)):
                return 0, raw, ""
            # Missing or empty record fields sort like the empty string.
            if raw is None or (is_record and not raw):
                return 1, 0, ""
            if isinstance(raw, str):
                return 1, 0, raw
            return 2, 0, raw
        return key_fn

    def _match_condition(self, item: Any, cond: FilterCondition, variable_pool: VariablePool) -> bool:
        """Return whether ``item`` satisfies the filter condition ``cond``.

        Raises:
            ValueError: If the comparison operator is not supported.
        """
        op = cond.comparison_operator
        value = cond.value

        # Resolve value if it's a variable reference {{ namespace.key }}
        if isinstance(value, str):
            value = self._resolve_value(value, variable_pool)

        # Left operand: the configured field of a record, or the primitive itself.
        left = item.get(cond.key) if isinstance(item, dict) else item

        # Equality: numeric when both sides are numbers, textual otherwise.
        if op == ComparisonOperator.EQ:
            return self._values_equal(left, value)
        if op == ComparisonOperator.NE:
            return not self._values_equal(left, value)

        # Ordering operators compare numerically; non-numbers count as 0.0.
        if op == ComparisonOperator.LT:
            return self._safe_num(left) < self._safe_num(value)
        if op == ComparisonOperator.LE:
            return self._safe_num(left) <= self._safe_num(value)
        if op == ComparisonOperator.GT:
            return self._safe_num(left) > self._safe_num(value)
        if op == ComparisonOperator.GE:
            return self._safe_num(left) >= self._safe_num(value)

        # String / sequence operators
        left_str = str(left) if left is not None else ""
        if op == ComparisonOperator.CONTAINS:
            return str(value) in left_str
        if op == ComparisonOperator.NOT_CONTAINS:
            return str(value) not in left_str
        if op == ComparisonOperator.START_WITH:
            return left_str.startswith(str(value))
        if op == ComparisonOperator.END_WITH:
            return left_str.endswith(str(value))
        if op in (ComparisonOperator.IN, ComparisonOperator.NOT_IN):
            # Compare string forms so resolved lists of non-strings still match.
            candidates = [str(v) for v in value] if isinstance(value, list) else [str(value)]
            found = left_str in candidates
            return found if op == ComparisonOperator.IN else not found
        if op == ComparisonOperator.EMPTY:
            return not left
        if op == ComparisonOperator.NOT_EMPTY:
            return bool(left)

        raise ValueError(f"Unsupported operator: {op}")

    @staticmethod
    def _values_equal(left: Any, right: Any) -> bool:
        """Compare numerically when both sides parse as numbers, else as strings.

        ``None`` is treated as the empty string in the textual comparison.
        """
        try:
            return float(left) == float(right)
        except (TypeError, ValueError):
            left_text = "" if left is None else str(left)
            right_text = "" if right is None else str(right)
            return left_text == right_text

    @staticmethod
    def _safe_num(v) -> float:
        """Convert ``v`` to float, returning 0.0 when it is not numeric."""
        try:
            return float(v)
        except (TypeError, ValueError):
            return 0.0
|
||||
@@ -135,8 +135,7 @@ class LLMNode(BaseNode):
|
||||
api_key=model_info.api_key,
|
||||
base_url=model_info.api_base,
|
||||
extra_params=extra_params,
|
||||
is_omni=model_info.is_omni,
|
||||
support_thinking="thinking" in (model_info.capability or []),
|
||||
is_omni=model_info.is_omni
|
||||
),
|
||||
type=model_info.model_type
|
||||
)
|
||||
@@ -214,9 +213,10 @@ class LLMNode(BaseNode):
|
||||
messages = messages[:-1] + history_message + messages[-1:]
|
||||
self.messages = messages
|
||||
else:
|
||||
# 使用简单的 prompt 格式(向后兼容)
|
||||
# 使用简单的 prompt 格式(向后兼容)——包装为标准消息列表以兼容所有 provider
|
||||
prompt_template = self.config.get("prompt", "")
|
||||
self.messages = self._render_template(prompt_template, variable_pool)
|
||||
rendered = self._render_template(prompt_template, variable_pool)
|
||||
self.messages = [{"role": "user", "content": rendered}]
|
||||
|
||||
return llm
|
||||
|
||||
|
||||
@@ -27,6 +27,7 @@ from app.core.workflow.nodes.question_classifier import QuestionClassifierNode
|
||||
from app.core.workflow.nodes.breaker import BreakNode
|
||||
from app.core.workflow.nodes.tool import ToolNode
|
||||
from app.core.workflow.nodes.document_extractor import DocExtractorNode
|
||||
from app.core.workflow.nodes.list_operator import ListOperatorNode
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -51,7 +52,8 @@ WorkflowNode = Union[
|
||||
MemoryReadNode,
|
||||
MemoryWriteNode,
|
||||
CodeNode,
|
||||
DocExtractorNode
|
||||
DocExtractorNode,
|
||||
ListOperatorNode
|
||||
]
|
||||
|
||||
|
||||
@@ -83,7 +85,8 @@ class NodeFactory:
|
||||
NodeType.MEMORY_READ: MemoryReadNode,
|
||||
NodeType.MEMORY_WRITE: MemoryWriteNode,
|
||||
NodeType.CODE: CodeNode,
|
||||
NodeType.DOCUMENT_EXTRACTOR: DocExtractorNode
|
||||
NodeType.DOCUMENT_EXTRACTOR: DocExtractorNode,
|
||||
NodeType.LIST_OPERATOR: ListOperatorNode
|
||||
}
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -118,8 +118,7 @@ class ParameterExtractorNode(BaseNode):
|
||||
provider=provider,
|
||||
api_key=api_key,
|
||||
base_url=api_base,
|
||||
is_omni=is_omni,
|
||||
support_thinking="thinking" in (capability or []),
|
||||
is_omni=is_omni
|
||||
),
|
||||
type=ModelType(model_type)
|
||||
)
|
||||
|
||||
@@ -71,8 +71,7 @@ class QuestionClassifierNode(BaseNode):
|
||||
provider=provider,
|
||||
api_key=api_key,
|
||||
base_url=base_url,
|
||||
is_omni=is_omni,
|
||||
support_thinking="thinking" in (capability or []),
|
||||
is_omni=is_omni
|
||||
),
|
||||
type=ModelType(model_type)
|
||||
)
|
||||
|
||||
@@ -1,7 +1,10 @@
|
||||
# -*- coding: UTF-8 -*-
|
||||
# Author: Eternity
|
||||
# @Email: 1533512157@qq.com
|
||||
# @Time : 2026/3/10 13:36
|
||||
import mimetypes
|
||||
import os
|
||||
import uuid
|
||||
from typing import Any
|
||||
from urllib.parse import urlparse, unquote
|
||||
|
||||
TRANSFORM_FILE_TYPE = {
|
||||
'text/plain': 'document/text',
|
||||
'text/markdown': 'document/markdown',
|
||||
@@ -52,5 +55,143 @@ ALLOWED_FILE_TYPES = [
|
||||
def mime_to_file_type(mime_type):
    """Map an allowed MIME type to its internal file type.

    Returns ``None`` for MIME types outside ``ALLOWED_FILE_TYPES``. Allowed
    types without an entry in ``TRANSFORM_FILE_TYPE`` are returned unchanged.
    """
    if mime_type in ALLOWED_FILE_TYPES:
        return TRANSFORM_FILE_TYPE.get(mime_type, mime_type)
    return None
|
||||
|
||||
|
||||
def build_file_object_dict_from_url(url: str, file_type: str, origin_file_type: str) -> dict[str, Any]:
    """Build a FileObject dict for a remote_url file using only URL parsing (no HTTP request).

    Used as fallback when HTTP request fails.

    Args:
        url: Remote file URL.
        file_type: Value stored under ``type``.
        origin_file_type: Original file type; also used as ``mime_type`` when
            no MIME type can be guessed from the URL.

    Returns:
        A FileObject-shaped dict with ``is_file`` set to ``True``; ``size``
        and ``file_id`` are always ``None``.
    """
    parsed = urlparse(url)
    name = unquote(os.path.basename(parsed.path)) or None
    _, ext = os.path.splitext(name or "")
    extension = (ext.lstrip(".").lower() or None) if ext else None

    # Guess from the path only: a query string or fragment after the file
    # name would otherwise hide the extension. data: URLs carry their MIME
    # type in the URL itself, so they are passed through whole.
    guess_source = url if parsed.scheme == "data" else parsed.path
    guessed_mime = mimetypes.guess_type(guess_source)[0]

    return {
        "type": file_type,
        "url": url,
        "transfer_method": "remote_url",
        "origin_file_type": origin_file_type,
        "file_id": None,
        "name": name,
        "size": None,
        "extension": extension,
        "mime_type": guessed_mime or origin_file_type,
        "is_file": True,
    }
|
||||
|
||||
|
||||
async def fetch_remote_file_meta(
        url: str,
        file_type: str,
        origin_file_type: str,
) -> dict[str, Any]:
    """Fetch remote file metadata via HEAD (fallback GET) and build a FileObject dict.

    Falls back to URL-only parsing if the HTTP request fails, if the GET
    fallback returns an error status, or if the headers cannot be parsed.

    Args:
        url: Remote file URL.
        file_type: Value stored under ``type``.
        origin_file_type: Original file type, used when the server sends no
            Content-Type.

    Returns:
        A FileObject-shaped dict with ``is_file`` set to ``True``.
    """
    import httpx

    name = size = mime_type = extension = None
    try:
        async with httpx.AsyncClient(timeout=10.0) as client:
            resp = await client.head(url, follow_redirects=True)
            headers = resp.headers
            if resp.status_code != 200:
                # Retry with GET, streamed so that only the headers are read
                # and the body is never downloaded.
                async with client.stream("GET", url, follow_redirects=True) as get_resp:
                    # Headers of an error page do not describe the file.
                    get_resp.raise_for_status()
                    headers = get_resp.headers

        cl = headers.get("Content-Length")
        size = int(cl) if cl else None

        ct = headers.get("Content-Type", "").split(";")[0].strip()
        mime_type = ct or origin_file_type

        name = _filename_from_content_disposition(headers.get("Content-Disposition", ""))
        if not name:
            name = unquote(os.path.basename(urlparse(url).path)) or None

        if name:
            _, ext = os.path.splitext(name)
            extension = ext.lstrip(".").lower() if ext else None
        if not extension and mime_type:
            ext = mimetypes.guess_extension(mime_type)
            extension = ext.lstrip(".").lower() if ext else None
    except Exception:
        # Best effort: any network or parsing problem degrades to URL parsing.
        return build_file_object_dict_from_url(url, file_type, origin_file_type)

    return build_file_object_dict_from_meta(
        file_type=file_type,
        transfer_method="remote_url",
        origin_file_type=origin_file_type,
        file_id=None,
        url=url,
        file_name=name,
        file_size=size,
        file_ext=extension,
        content_type=mime_type,
    )


def _filename_from_content_disposition(header_value: str) -> str | None:
    """Return the filename carried by a Content-Disposition header value, if any."""
    from email.message import Message

    if not header_value:
        return None
    # The stdlib header parser handles quoting, trailing parameters and
    # RFC 2231 encoded names (``filename*=``).
    message = Message()
    message["Content-Disposition"] = header_value
    return message.get_filename() or None
|
||||
|
||||
|
||||
def build_file_object_dict_from_meta(
|
||||
file_type: str,
|
||||
transfer_method: str,
|
||||
origin_file_type: str,
|
||||
file_id: str,
|
||||
url: str,
|
||||
file_name: str | None,
|
||||
file_size: int | None,
|
||||
file_ext: str | None,
|
||||
content_type: str | None,
|
||||
) -> dict[str, Any]:
|
||||
"""Build a FileObject dict from already-fetched FileMetadata fields."""
|
||||
ext = (file_ext or "").lstrip(".")
|
||||
return {
|
||||
"type": file_type,
|
||||
"url": url,
|
||||
"transfer_method": transfer_method,
|
||||
"origin_file_type": content_type or origin_file_type,
|
||||
"file_id": file_id,
|
||||
"name": file_name,
|
||||
"size": file_size,
|
||||
"extension": ext.lower() if ext else None,
|
||||
"mime_type": content_type,
|
||||
"is_file": True,
|
||||
}
|
||||
|
||||
|
||||
def resolve_local_file_object_dict(
|
||||
db,
|
||||
upload_file_id: str | uuid.UUID,
|
||||
file_type: str,
|
||||
origin_file_type: str,
|
||||
) -> dict[str, Any] | None:
|
||||
"""Query FileMetadata and build a FileObject dict for a local_file.
|
||||
Returns None if the file is not found or not completed.
|
||||
"""
|
||||
from app.models.file_metadata_model import FileMetadata
|
||||
from app.core.config import settings
|
||||
|
||||
try:
|
||||
fid = uuid.UUID(str(upload_file_id))
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
meta = db.query(FileMetadata).filter(
|
||||
FileMetadata.id == fid,
|
||||
FileMetadata.status == "completed"
|
||||
).first()
|
||||
if not meta:
|
||||
return None
|
||||
|
||||
url = f"{settings.FILE_LOCAL_SERVER_URL}/storage/permanent/{fid}"
|
||||
return build_file_object_dict_from_meta(
|
||||
file_type=file_type,
|
||||
transfer_method="local_file",
|
||||
origin_file_type=origin_file_type,
|
||||
file_id=str(fid),
|
||||
url=url,
|
||||
file_name=meta.file_name,
|
||||
file_size=meta.file_size,
|
||||
file_ext=meta.file_ext,
|
||||
content_type=meta.content_type,
|
||||
)
|
||||
|
||||
@@ -301,7 +301,7 @@ class WorkflowValidator:
|
||||
for node in nodes:
|
||||
if node.get("type") not in [NodeType.START, NodeType.CYCLE_START, NodeType.END] and not node.get("name"):
|
||||
errors.append(
|
||||
f"节点 {node.get('id')} 缺少名称(发布时必须提供)"
|
||||
f"节点 {node.get('name')} 缺少名称(发布时必须提供)"
|
||||
)
|
||||
|
||||
# 2. 验证所有非 start/end 节点都有配置
|
||||
@@ -311,7 +311,7 @@ class WorkflowValidator:
|
||||
config = node.get("config")
|
||||
if not config or not isinstance(config, dict):
|
||||
errors.append(
|
||||
f"节点 {node.get('id')} 缺少配置(发布时必须提供)"
|
||||
f"节点 {node.get('name')} 缺少配置(发布时必须提供)"
|
||||
)
|
||||
|
||||
# 3. 验证必填变量
|
||||
|
||||
@@ -91,7 +91,7 @@ def DEFAULT_VALUE(var_type: VariableType) -> Any:
|
||||
case VariableType.OBJECT:
|
||||
return {}
|
||||
case VariableType.FILE:
|
||||
return None
|
||||
return {}
|
||||
case VariableType.ARRAY_STRING:
|
||||
return []
|
||||
case VariableType.ARRAY_NUMBER:
|
||||
@@ -113,6 +113,12 @@ class FileObject(BaseModel):
|
||||
origin_file_type: str
|
||||
file_id: str | None
|
||||
|
||||
# Extended file metadata
|
||||
name: str | None = None
|
||||
size: int | None = None
|
||||
extension: str | None = None
|
||||
mime_type: str | None = None
|
||||
|
||||
content_cache: dict = Field(default_factory=dict)
|
||||
is_file: bool
|
||||
|
||||
|
||||
@@ -66,20 +66,10 @@ class FileVariable(BaseVariable):
|
||||
type = 'file'
|
||||
|
||||
def valid_value(self, value) -> FileObject:
    """Validate ``value`` and return it as a ``FileObject``.

    Args:
        value: Either a ``FileObject`` instance or a dict describing one.
            A dict must carry a truthy ``"is_file"`` flag.

    Returns:
        The ``FileObject`` itself, or one built from the dict.

    Raises:
        TypeError: If ``value`` is a dict without a truthy ``"is_file"``
            flag, or is neither a dict nor a ``FileObject``.
    """
    if isinstance(value, dict):
        if not value.get("is_file"):
            raise TypeError(f"Value must be a FileObject - {type(value)}:{value}")
        # Pass the whole dict through so the extended metadata fields
        # (name, size, extension, mime_type) are kept instead of being
        # dropped by a hand-picked subset of keys.
        return FileObject(**value)
    if isinstance(value, FileObject):
        return value
    raise TypeError(f"Value must be a FileObject - {type(value)}:{value}")
|
||||
@@ -88,7 +78,7 @@ class FileVariable(BaseVariable):
|
||||
return f'{"!"if self.value.type == FileType.IMAGE else ""}[file]({self.value.url})'
|
||||
|
||||
def get_value(self) -> Any:
    """Return the wrapped file object as a plain dict.

    The ``content_cache`` field is excluded from the dump.
    # NOTE(review): presumably because it is an internal cache rather than
    # part of the variable's value - confirm against FileObject's usage.
    """
    return self.value.model_dump(exclude={"content_cache"})
|
||||
|
||||
async def get_content(self):
|
||||
total_bytes = 0
|
||||
@@ -186,6 +176,8 @@ def create_variable_instance(var_type: VariableType, value: Any) -> T:
|
||||
return BooleanVariable(value)
|
||||
case VariableType.OBJECT:
|
||||
return DictVariable(value)
|
||||
case VariableType.FILE:
|
||||
return FileVariable(value)
|
||||
case VariableType.ARRAY_STRING:
|
||||
return make_array(StringVariable, value)
|
||||
case VariableType.ARRAY_NUMBER:
|
||||
|
||||
@@ -640,6 +640,7 @@ class CitationSource(BaseModel):
|
||||
class DraftRunResponse(BaseModel):
|
||||
"""试运行响应(非流式)"""
|
||||
message: str = Field(..., description="AI 回复消息")
|
||||
reasoning_content: Optional[str] = Field(default=None, description="深度思考内容")
|
||||
conversation_id: Optional[str] = Field(default=None, description="会话ID(用于多轮对话)")
|
||||
usage: Optional[Dict[str, Any]] = Field(default=None, description="Token 使用情况")
|
||||
elapsed_time: Optional[float] = Field(default=None, description="耗时(秒)")
|
||||
@@ -647,6 +648,12 @@ class DraftRunResponse(BaseModel):
|
||||
citations: List[CitationSource] = Field(default_factory=list, description="引用来源")
|
||||
audio_url: Optional[str] = Field(default=None, description="TTS 语音URL")
|
||||
|
||||
def model_dump(self, **kwargs):
    """Dump the model, omitting ``reasoning_content`` when it is empty.

    All keyword arguments are forwarded unchanged to the parent
    ``model_dump``. If the resulting dict has no ``reasoning_content`` key,
    or its value is falsy (``None`` or an empty string), the key is removed
    from the returned dict.
    """
    payload = super().model_dump(**kwargs)
    if payload.get("reasoning_content"):
        return payload
    # Nothing meaningful to report: drop the key rather than emit a null.
    payload.pop("reasoning_content", None)
    return payload
|
||||
|
||||
|
||||
class OpeningResponse(BaseModel):
|
||||
"""应用开场白响应"""
|
||||
|
||||
@@ -16,7 +16,6 @@ from app.core.workflow.adapters.registry import PlatformAdapterRegistry
|
||||
from app.core.workflow.executor import execute_workflow, execute_workflow_stream
|
||||
from app.core.workflow.nodes.enums import NodeType
|
||||
from app.core.workflow.validator import validate_workflow_config
|
||||
from app.core.workflow.variable.base_variable import FileObject
|
||||
from app.db import get_db
|
||||
from app.models import App
|
||||
from app.models.workflow_model import WorkflowConfig, WorkflowExecution
|
||||
@@ -453,22 +452,70 @@ class WorkflowService:
|
||||
"success_rate": completed / total if total > 0 else 0
|
||||
}
|
||||
|
||||
async def _resolve_variables_file_defaults(
        self,
        variables: list[dict[str, Any]]
) -> list[dict[str, Any]]:
    """Expand FileInput-style defaults of workflow variables into FileObject dicts.

    Only variables whose ``type`` is ``"file"`` or ``"array[file]"`` are
    touched; every other definition is passed through as-is. Input dicts are
    never mutated: a changed definition is a shallow copy with a new
    ``default``.

    Args:
        variables: Variable definitions as plain dicts.

    Returns:
        A new list, in the same order as ``variables``. For ``"file"`` the
        default may become ``None`` when it cannot be resolved; for
        ``"array[file]"`` entries that resolve to ``None`` are dropped.
    """
    from app.core.workflow.utils.file_processor import (
        fetch_remote_file_meta,
        resolve_local_file_object_dict,
    )

    async def _to_file_object(raw: dict) -> dict | None:
        # Non-dict values and dicts already flagged as files pass through untouched.
        if not isinstance(raw, dict) or raw.get("is_file"):
            return raw
        kind = raw.get("type", "document")
        origin_kind = raw.get("file_type") or kind
        if raw.get("transfer_method", "remote_url") != "remote_url":
            return resolve_local_file_object_dict(
                self.db, raw.get("upload_file_id"), kind, origin_kind
            )
        remote_url = raw.get("url", "")
        if not remote_url:
            # A remote file without a URL cannot be resolved.
            return None
        return await fetch_remote_file_meta(remote_url, kind, origin_kind)

    resolved_defs = []
    for definition in variables:
        declared_type = definition.get("type", "")
        default_value = definition.get("default")
        if declared_type == "file":
            if isinstance(default_value, dict) and not default_value.get("is_file"):
                definition = {**definition, "default": await _to_file_object(default_value)}
        elif declared_type == "array[file]" and isinstance(default_value, list):
            kept = []
            for entry in default_value:
                candidate = await _to_file_object(entry)
                if candidate is not None:
                    kept.append(candidate)
            definition = {**definition, "default": kept}
        resolved_defs.append(definition)
    return resolved_defs
|
||||
|
||||
async def _handle_file_input(self, files: list[FileInput]):
|
||||
if not files:
|
||||
return []
|
||||
|
||||
from app.core.workflow.utils.file_processor import (
|
||||
resolve_local_file_object_dict,
|
||||
build_file_object_dict_from_meta,
|
||||
fetch_remote_file_meta,
|
||||
)
|
||||
|
||||
files_struct = []
|
||||
for file in files:
|
||||
files_struct.append(
|
||||
FileObject(
|
||||
type=file.type,
|
||||
url=await self.multimodal_service.get_file_url(file),
|
||||
transfer_method=file.transfer_method,
|
||||
file_id=str(file.upload_file_id) if file.upload_file_id else None,
|
||||
origin_file_type=file.file_type,
|
||||
is_file=True
|
||||
).model_dump()
|
||||
)
|
||||
url = await self.multimodal_service.get_file_url(file)
|
||||
file_type = str(file.type)
|
||||
origin_file_type = file.file_type or file_type
|
||||
|
||||
if file.transfer_method.value == "local_file" and file.upload_file_id:
|
||||
fo = resolve_local_file_object_dict(self.db, file.upload_file_id, file_type, origin_file_type)
|
||||
files_struct.append(fo or build_file_object_dict_from_meta(
|
||||
file_type=file_type, transfer_method="local_file",
|
||||
origin_file_type=origin_file_type,
|
||||
file_id=str(file.upload_file_id), url=url,
|
||||
file_name=None, file_size=None, file_ext=None, content_type=None,
|
||||
))
|
||||
else:
|
||||
files_struct.append(await fetch_remote_file_meta(url, file_type, origin_file_type))
|
||||
return files_struct
|
||||
|
||||
@staticmethod
|
||||
|
||||
@@ -480,21 +480,21 @@ def create_workspace_invite(
|
||||
try:
|
||||
# 检查权限
|
||||
_check_workspace_admin_permission(db, workspace_id, user)
|
||||
if settings.ENABLE_SINGLE_WORKSPACE:
|
||||
# 检查被邀请用户是否已经在工作空间中
|
||||
from app.repositories import user_repository
|
||||
invited_user = user_repository.get_user_by_email(db, invite_data.email)
|
||||
# if settings.ENABLE_SINGLE_WORKSPACE:
|
||||
# 检查被邀请用户是否已经在工作空间中
|
||||
from app.repositories import user_repository
|
||||
invited_user = user_repository.get_user_by_email(db, invite_data.email)
|
||||
|
||||
if invited_user:
|
||||
# 用户存在,检查是否已经是工作空间成员
|
||||
existing_member = workspace_repository.get_member_in_workspace(
|
||||
db=db,
|
||||
user_id=invited_user.id,
|
||||
workspace_id=workspace_id
|
||||
)
|
||||
if existing_member:
|
||||
business_logger.warning(f"用户 {invite_data.email} 已经是工作空间成员")
|
||||
raise BusinessException("该用户已经是工作空间成员", BizCode.RESOURCE_ALREADY_EXISTS)
|
||||
if invited_user:
|
||||
# 用户存在,检查是否已经是工作空间成员
|
||||
existing_member = workspace_repository.get_member_in_workspace(
|
||||
db=db,
|
||||
user_id=invited_user.id,
|
||||
workspace_id=workspace_id
|
||||
)
|
||||
if existing_member:
|
||||
business_logger.warning(f"用户 {invite_data.email} 已经是工作空间成员")
|
||||
raise BusinessException("该用户已经是工作空间成员", BizCode.RESOURCE_ALREADY_EXISTS)
|
||||
|
||||
# 检查是否已有待处理的邀请
|
||||
invite_repo = WorkspaceInviteRepository(db)
|
||||
|
||||
Reference in New Issue
Block a user