feat(workflow,app): add MIME-based file handling and HTTP response files
This commit is contained in:
@@ -45,7 +45,8 @@ RUN --mount=type=cache,id=mem_apt,target=/var/cache/apt,sharing=locked \
|
||||
apt install -y libpython3-dev libgtk-4-1 libnss3 xdg-utils libgbm-dev && \
|
||||
apt install -y libjemalloc-dev && \
|
||||
apt install -y python3-pip pipx nginx unzip curl wget git vim less && \
|
||||
apt install -y ghostscript
|
||||
apt install -y ghostscript && \
|
||||
apt install -y libmagic1
|
||||
|
||||
RUN if [ "$NEED_MIRROR" == "1" ]; then \
|
||||
pip3 config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple && \
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import uuid
|
||||
from abc import ABC, abstractmethod
|
||||
from datetime import datetime
|
||||
from functools import cached_property
|
||||
@@ -643,15 +644,18 @@ class BaseNode(ABC):
|
||||
return content.content_cache[provider]
|
||||
with get_db_read() as db:
|
||||
multimodel_service = MultimodalService(db, provider, is_omni=is_omni)
|
||||
message = await multimodel_service.process_files(
|
||||
[FileInput.model_construct(
|
||||
type=content.type,
|
||||
url=content.url,
|
||||
transfer_method=content.transfer_method,
|
||||
file_type=content.origin_file_type,
|
||||
upload_file_id=content.file_id
|
||||
)]
|
||||
file_obj = FileInput(
|
||||
type=content.type,
|
||||
url=content.url,
|
||||
transfer_method=content.transfer_method,
|
||||
origin_file_type=content.origin_file_type,
|
||||
upload_file_id=uuid.UUID(content.file_id) if content.file_id else None,
|
||||
)
|
||||
file_obj.set_content(content.get_content())
|
||||
message = await multimodel_service.process_files(
|
||||
[file_obj]
|
||||
)
|
||||
content.set_content(file_obj.get_content())
|
||||
if message:
|
||||
content.content_cache[provider] = message
|
||||
return message
|
||||
|
||||
@@ -4,6 +4,7 @@ from pydantic import Field, BaseModel, field_validator
|
||||
|
||||
from app.core.workflow.nodes.base_config import BaseNodeConfig
|
||||
from app.core.workflow.nodes.enums import HttpRequestMethod, HttpAuthType, HttpContentType, HttpErrorHandle
|
||||
from app.core.workflow.variable.base_variable import FileObject
|
||||
|
||||
|
||||
class HttpAuthConfig(BaseModel):
|
||||
@@ -260,6 +261,11 @@ class HttpRequestNodeOutput(BaseModel):
|
||||
description="Http response headers"
|
||||
)
|
||||
|
||||
files: list[FileObject] = Field(
|
||||
default_factory=list,
|
||||
description="List of files",
|
||||
)
|
||||
|
||||
output: str = Field(
|
||||
default="SUCCESS",
|
||||
description="HTTP response body",
|
||||
|
||||
@@ -1,24 +1,146 @@
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import mimetypes
|
||||
import uuid
|
||||
import imghdr
|
||||
from email.message import Message
|
||||
from typing import Any, Callable, Coroutine
|
||||
|
||||
import httpx
|
||||
# import filetypes # TODO: File support (Feature)
|
||||
from httpx import AsyncClient, Response, Timeout
|
||||
import magic
|
||||
|
||||
from app.core.workflow.engine.state_manager import WorkflowState
|
||||
from app.core.workflow.engine.variable_pool import VariablePool
|
||||
from app.core.workflow.nodes.base_node import BaseNode
|
||||
from app.core.workflow.nodes.enums import HttpRequestMethod, HttpErrorHandle, HttpAuthType, HttpContentType
|
||||
from app.core.workflow.nodes.http_request.config import HttpRequestNodeConfig, HttpRequestNodeOutput
|
||||
from app.core.workflow.variable.base_variable import VariableType
|
||||
from app.core.workflow.utils.file_processer import mime_to_file_type
|
||||
from app.core.workflow.variable.base_variable import VariableType, FileObject
|
||||
from app.core.workflow.variable.variable_objects import FileVariable, ArrayVariable
|
||||
from app.schemas import FileType, TransferMethod
|
||||
|
||||
logger = logging.getLogger(__file__)
|
||||
|
||||
|
||||
class HttpResponse:
    """Wrap an ``httpx.Response`` and classify its payload as text or as a file.

    The classification is heuristic: it looks at the ``Content-Disposition``
    and ``Content-Type`` headers first and only then at the first bytes of the
    body. Results that require inspecting the body are cached per instance so
    that ``body``, ``is_image`` and ``files`` do not repeat the work.
    """

    # Substrings that mark an ``application/*`` subtype as textual.
    _TEXTUAL_APP_HINTS = ("json", "xml", "javascript", "x-www-form-urlencoded", "yaml", "graphql")
    # Byte sequences that make a UTF-8 decodable sample look like markup/script.
    _TEXT_MARKERS = (b"{", b"[", b"<", b"function", b"var ", b"const ", b"let ")
    # MIME types that ``get_file_type`` reports as documents.
    _DOCUMENT_MIMES = (
        "application/pdf",
        "application/msword",
        "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
        "application/vnd.ms-excel",
        "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
        "text/plain",
    )
    # Number of leading body bytes probed for UTF-8 validity.
    _SAMPLE_SIZE = 1024

    def __init__(self, response: httpx.Response):
        """Store the response and a plain-dict copy of its headers."""
        self.response = response
        self.headers = dict(response.headers)

        # Lazily computed caches (None means "not computed yet").
        self._is_file: bool | None = None
        self._sniffed: tuple | None = None

    @property
    def content_type(self) -> str:
        """Raw ``content-type`` header value, or ``""`` when absent."""
        return self.headers.get("content-type", "")

    @property
    def content_disposition(self) -> Message | None:
        """``content-disposition`` header wrapped in an ``email.message.Message``.

        Returns ``None`` when the header is missing or empty. The ``Message``
        wrapper is used only for its header-parameter parsing helpers.
        """
        raw_value = self.headers.get("content-disposition", "")
        if not raw_value:
            return None
        msg = Message()
        msg["content-disposition"] = raw_value
        return msg

    @property
    def is_file(self) -> bool:
        """Whether the response looks like a downloadable file (cached)."""
        if self._is_file is None:
            # Every outcome is cached, including the early ``False`` results.
            self._is_file = self._detect_is_file()
        return self._is_file

    def _detect_is_file(self) -> bool:
        """Run the header/body heuristics behind ``is_file`` once."""
        import codecs  # function-scope import: only needed for the UTF-8 probe

        content_type = self.content_type.split(";")[0].strip().lower()

        # 1. An explicit attachment or a filename always means "file".
        disposition = self.content_disposition
        if disposition is not None:
            if disposition.get_content_disposition() == "attachment" or disposition.get_filename():
                return True

        # 2. text/* is never a file.
        # NOTE(review): text/csv passes this guard, but the mimetypes fallback
        # below appears to resolve it back to a "text" main type, so CSV is only
        # reported as a file when Content-Disposition says so - confirm intent.
        if content_type.startswith("text/") and "csv" not in content_type:
            return False

        # 3. application/*: textual subtypes are not files; otherwise probe the body.
        if content_type.startswith("application/"):
            if any(hint in content_type for hint in self._TEXTUAL_APP_HINTS):
                return False

            body = self.response.content
            sample = body[: self._SAMPLE_SIZE]
            try:
                # An incremental decoder tolerates a multi-byte character that
                # was cut in half by the slice; a plain ``decode`` would raise
                # and misclassify valid UTF-8 text as binary.
                codecs.getincrementaldecoder("utf-8")().decode(
                    sample, final=len(body) <= self._SAMPLE_SIZE
                )
            except UnicodeDecodeError:
                return True
            if any(marker in sample for marker in self._TEXT_MARKERS):
                return False

        # 4. Fall back to the main type that the mimetypes database maps back to.
        guessed, _ = mimetypes.guess_type("dummy" + (mimetypes.guess_extension(content_type) or ""))
        if guessed:
            return guessed.split("/")[0] in ("application", "image", "audio", "video")
        return any(media_type in content_type for media_type in ("image/", "audio/", "video/"))

    def _sniff(self) -> tuple:
        """Return ``get_file_type`` for the body, computing it at most once."""
        if self._sniffed is None:
            self._sniffed = self.get_file_type(self.response.content)
        return self._sniffed

    @property
    def is_image(self) -> bool:
        """Whether the response is a file whose sniffed MIME type is an image.

        Uses the same libmagic sniff as ``files`` instead of ``imghdr``, which
        is deprecated and removed from the standard library in Python 3.13.
        """
        if not self.is_file:
            return False
        _, mime = self._sniff()
        return bool(mime) and mime.startswith("image")

    @property
    def url(self) -> str:
        """Final request URL of the response, as a string."""
        return str(self.response.url)

    @property
    def body(self) -> str:
        """Response text, or a Markdown link/image placeholder for files."""
        if self.is_file:
            prefix = "!" if self.is_image else ""
            return f"{prefix}[file]({self.url})"
        return self.response.text

    @staticmethod
    def get_file_type(file_bytes) -> tuple[FileType | None, str | None]:
        """Sniff *file_bytes* with libmagic and map the MIME type to a ``FileType``.

        Returns:
            ``(file_type, mime)`` for image, video, audio and the document MIME
            types listed in ``_DOCUMENT_MIMES``; ``(None, None)`` otherwise.
        """
        mime = magic.from_buffer(file_bytes, mime=True)

        if mime.startswith("image"):
            return FileType.IMAGE, mime
        if mime.startswith("video"):
            return FileType.VIDEO, mime
        if mime.startswith("audio"):
            return FileType.AUDIO, mime
        # NOTE(review): the Excel MIME types are accepted here but are not in
        # file_processer.ALLOWED_FILE_TYPES, so ``files`` drops them - confirm.
        if mime in HttpResponse._DOCUMENT_MIMES:
            return FileType.DOCUMENT, mime
        return None, None

    @property
    def files(self) -> list[FileObject]:
        """File objects carried by the response (zero or one element).

        A ``FileObject`` is produced only when the response is classified as a
        file, its sniffed type is recognised, and ``mime_to_file_type`` allows
        the MIME type. The downloaded bytes are attached via ``set_content``.
        """
        # Check the cheap header heuristic first so libmagic is skipped for text.
        if not self.is_file:
            return []

        file_type, mime_type = self._sniff()
        origin_file_type = mime_to_file_type(mime_type)
        if not (file_type and origin_file_type):
            return []

        file_obj = FileObject(
            type=file_type,
            url=self.url,
            transfer_method=TransferMethod.REMOTE_URL.value,
            origin_file_type=origin_file_type,
            file_id=None,
            is_file=True
        )
        file_obj.set_content(self.response.content)
        return [file_obj]
|
||||
|
||||
|
||||
class HttpRequestNode(BaseNode):
|
||||
"""
|
||||
HTTP Request Workflow Node.
|
||||
@@ -44,6 +166,7 @@ class HttpRequestNode(BaseNode):
|
||||
"body": VariableType.STRING,
|
||||
"status_code": VariableType.NUMBER,
|
||||
"headers": VariableType.OBJECT,
|
||||
"files": VariableType.ARRAY_FILE,
|
||||
"output": VariableType.STRING
|
||||
}
|
||||
|
||||
@@ -232,10 +355,12 @@ class HttpRequestNode(BaseNode):
|
||||
)
|
||||
resp.raise_for_status()
|
||||
logger.info(f"Node {self.node_id}: HTTP request succeeded")
|
||||
response = HttpResponse(resp)
|
||||
return HttpRequestNodeOutput(
|
||||
body=resp.text,
|
||||
body=response.body,
|
||||
status_code=resp.status_code,
|
||||
headers=resp.headers,
|
||||
files=response.files
|
||||
).model_dump()
|
||||
except (httpx.HTTPStatusError, httpx.RequestError) as e:
|
||||
logger.error(f"HTTP request node exception: {e}")
|
||||
|
||||
56
api/app/core/workflow/utils/file_processer.py
Normal file
56
api/app/core/workflow/utils/file_processer.py
Normal file
@@ -0,0 +1,56 @@
|
||||
# -*- coding: UTF-8 -*-
|
||||
# Author: Eternity
|
||||
# @Email: 1533512157@qq.com
|
||||
# @Time : 2026/3/10 13:36
|
||||
# Maps a standard MIME type to the project's "document/..." file-type label.
# MIME types that are allowed but absent here keep their own name.
TRANSFORM_FILE_TYPE = {
    'text/plain': 'document/text',
    'text/markdown': 'document/markdown',
    'text/x-markdown': 'document/x-markdown',

    'application/pdf': 'document/pdf',

    'application/msword': 'document/doc',
    'application/vnd.openxmlformats-officedocument.wordprocessingml.document': 'document/docx',

    'application/vnd.ms-powerpoint': 'document/ppt',
    'application/vnd.openxmlformats-officedocument.presentationml.presentation': 'document/pptx',
}

# MIME types accepted by mime_to_file_type; anything else maps to None.
ALLOWED_FILE_TYPES = [
    'text/plain',
    'text/markdown',
    'text/x-markdown',
    'application/pdf',
    'application/msword',
    'application/vnd.openxmlformats-officedocument.wordprocessingml.document',
    'application/vnd.ms-powerpoint',
    'application/vnd.openxmlformats-officedocument.presentationml.presentation',
    'image/jpg',
    'image/jpeg',
    'image/png',
    'image/gif',
    'image/bmp',
    'image/webp',
    'image/svg+xml',
    'video/mp4',
    'video/quicktime',
    'video/x-msvideo',
    'video/x-matroska',
    'video/webm',
    'video/x-flv',
    'video/x-ms-wmv',
    'audio/mpeg',
    'audio/wav',
    'audio/ogg',
    'audio/aac',
    'audio/flac',
    'audio/mp4',
    'audio/x-ms-wma',
    'audio/x-m4a',
]


def mime_to_file_type(mime_type):
    """Translate a MIME type into the project's file-type label.

    The lookup ignores letter case, surrounding whitespace and any parameters
    after ``;`` (for example ``"Text/Plain; charset=utf-8"``), because MIME
    type names are case-insensitive and header values often carry parameters.

    Args:
        mime_type: MIME type string; any non-string value (including ``None``)
            is treated as "not allowed".

    Returns:
        The mapped label from ``TRANSFORM_FILE_TYPE``, the normalized MIME type
        itself when it is allowed but has no mapping, or ``None`` when the type
        is not in ``ALLOWED_FILE_TYPES``.
    """
    if not isinstance(mime_type, str):
        return None

    normalized = mime_type.split(";", 1)[0].strip().lower()
    if normalized not in ALLOWED_FILE_TYPES:
        return None

    return TRANSFORM_FILE_TYPE.get(normalized, normalized)
|
||||
@@ -114,9 +114,16 @@ class FileObject(BaseModel):
|
||||
file_id: str | None
|
||||
|
||||
content_cache: dict = Field(default_factory=dict)
|
||||
|
||||
is_file: bool
|
||||
|
||||
_byte_content: bytes | None = None
|
||||
|
||||
def get_content(self):
|
||||
return self._byte_content
|
||||
|
||||
def set_content(self, byte_content):
|
||||
self._byte_content = byte_content
|
||||
|
||||
|
||||
class BaseVariable(ABC):
|
||||
"""Abstract base class for all workflow variables.
|
||||
|
||||
@@ -45,11 +45,19 @@ class FileInput(BaseModel):
|
||||
url: Optional[str] = Field(None, description="远程URL(remote_url时必填)")
|
||||
file_type: Optional[str] = Field(None, description="具体文件格式(如image/jpg、audio/wav、document/docx、video/mp4)")
|
||||
|
||||
_content = None
|
||||
|
||||
def __init__(self, **data):
    """Build the model, mirroring the ``type`` argument into ``file_type``.

    NOTE(review): when ``type`` is supplied, any ``file_type`` passed by the
    caller is overwritten with the ``type`` value - confirm this is intended,
    since ``file_type`` is described as the concrete file format.
    """
    try:
        data['file_type'] = data['type']
    except KeyError:
        # No ``type`` given: leave ``file_type`` exactly as the caller passed it.
        pass
    super().__init__(**data)
|
||||
|
||||
def set_content(self, content: bytes):
    """Remember already-downloaded bytes for this input file."""
    setattr(self, "_content", content)
|
||||
|
||||
def get_content(self) -> bytes | None:
|
||||
return self._content
|
||||
|
||||
@field_validator("type", mode="before")
|
||||
@classmethod
|
||||
def validate_type(cls, v):
|
||||
|
||||
@@ -8,32 +8,42 @@
|
||||
- Bedrock/Anthropic: 仅支持 base64 格式
|
||||
- OpenAI: 支持 URL 和 base64 格式
|
||||
"""
|
||||
import uuid
|
||||
import httpx
|
||||
import base64
|
||||
from typing import List, Dict, Any, Optional
|
||||
from abc import ABC, abstractmethod
|
||||
from sqlalchemy.orm import Session
|
||||
from docx import Document
|
||||
import io
|
||||
import PyPDF2
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Dict, Any, Optional
|
||||
|
||||
import PyPDF2
|
||||
import httpx
|
||||
import magic
|
||||
from docx import Document
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.logging_config import get_business_logger
|
||||
from app.core.exceptions import BusinessException
|
||||
from app.core.error_codes import BizCode
|
||||
from app.schemas.app_schema import FileInput, FileType, TransferMethod
|
||||
from app.models.file_metadata_model import FileMetadata
|
||||
from app.core.config import settings
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.exceptions import BusinessException
|
||||
from app.core.logging_config import get_business_logger
|
||||
from app.models.file_metadata_model import FileMetadata
|
||||
from app.schemas.app_schema import FileInput, FileType, TransferMethod
|
||||
from app.services.audio_transcription_service import AudioTranscriptionService
|
||||
|
||||
logger = get_business_logger()
|
||||
|
||||
TEXT_MIME = ['text/plain', 'text/x-markdown']
|
||||
PDF_MIME = ['application/pdf']
|
||||
DOC_MIME = [
|
||||
'application/msword',
|
||||
'application/vnd.openxmlformats-officedocument.wordprocessingml.document'
|
||||
]
|
||||
|
||||
|
||||
class MultimodalFormatStrategy(ABC):
|
||||
"""多模态格式策略基类"""
|
||||
def __init__(self, file: FileInput):
|
||||
self.file = file
|
||||
|
||||
@abstractmethod
|
||||
async def format_image(self, url: str) -> Dict[str, Any]:
|
||||
async def format_image(self, url: str, content: bytes | None = None) -> Dict[str, Any]:
|
||||
"""格式化图片"""
|
||||
pass
|
||||
|
||||
@@ -43,7 +53,7 @@ class MultimodalFormatStrategy(ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def format_audio(self, file_type: str, url: str) -> Dict[str, Any]:
|
||||
async def format_audio(self, file_type: str, url: str, content: bytes | None = None) -> Dict[str, Any]:
|
||||
"""格式化音频"""
|
||||
pass
|
||||
|
||||
@@ -56,7 +66,7 @@ class MultimodalFormatStrategy(ABC):
|
||||
class DashScopeFormatStrategy(MultimodalFormatStrategy):
|
||||
"""通义千问策略"""
|
||||
|
||||
async def format_image(self, url: str) -> Dict[str, Any]:
|
||||
async def format_image(self, url: str, content: bytes | None = None) -> Dict[str, Any]:
|
||||
"""通义千问图片格式:{"type": "image", "image": "url"}"""
|
||||
return {
|
||||
"type": "image",
|
||||
@@ -70,7 +80,13 @@ class DashScopeFormatStrategy(MultimodalFormatStrategy):
|
||||
"text": f"<document name=\"{file_name}\">\n{text}\n</document>"
|
||||
}
|
||||
|
||||
async def format_audio(self, file_type: str, url: str, transcription: Optional[str] = None) -> Dict[str, Any]:
|
||||
async def format_audio(
|
||||
self,
|
||||
file_type: str,
|
||||
url: str,
|
||||
content: bytes | None = None,
|
||||
transcription: Optional[str] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
通义千问音频格式
|
||||
- 原生支持: qwen-audio 系列
|
||||
@@ -98,44 +114,37 @@ class DashScopeFormatStrategy(MultimodalFormatStrategy):
|
||||
class BedrockFormatStrategy(MultimodalFormatStrategy):
|
||||
"""Bedrock/Anthropic 策略"""
|
||||
|
||||
async def format_image(self, url: str) -> Dict[str, Any]:
|
||||
async def format_image(self, url: str, content: bytes | None = None) -> Dict[str, Any]:
|
||||
"""
|
||||
Bedrock/Anthropic 格式: base64 编码
|
||||
{"type": "image", "source": {"type": "base64", "media_type": "...", "data": "..."}}
|
||||
"""
|
||||
from mimetypes import guess_type
|
||||
|
||||
logger.info(f"下载并编码图片: {url}")
|
||||
|
||||
# 下载图片
|
||||
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||
response = await client.get(url)
|
||||
response.raise_for_status()
|
||||
|
||||
# 获取图片数据
|
||||
image_data = response.content
|
||||
if content is None:
|
||||
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||
response = await client.get(url)
|
||||
response.raise_for_status()
|
||||
content = response.content
|
||||
self.file.set_content(content)
|
||||
|
||||
# 确定 media type
|
||||
content_type = response.headers.get("content-type")
|
||||
if content_type and content_type.startswith("image/"):
|
||||
media_type = content_type
|
||||
else:
|
||||
guessed_type, _ = guess_type(url)
|
||||
media_type = guessed_type if guessed_type and guessed_type.startswith("image/") else "image/jpeg"
|
||||
content_type = magic.from_buffer(content, mime=True)
|
||||
media_type = content_type if content_type.startswith("image/") else "image/jpeg"
|
||||
base64_data = base64.b64encode(content).decode("utf-8")
|
||||
|
||||
# 转换为 base64
|
||||
base64_data = base64.b64encode(image_data).decode("utf-8")
|
||||
logger.info(f"图片编码完成: media_type={media_type}, size={len(base64_data)}")
|
||||
|
||||
logger.info(f"图片编码完成: media_type={media_type}, size={len(base64_data)}")
|
||||
|
||||
return {
|
||||
"type": "image",
|
||||
"source": {
|
||||
"type": "base64",
|
||||
"media_type": media_type,
|
||||
"data": base64_data
|
||||
}
|
||||
return {
|
||||
"type": "image",
|
||||
"source": {
|
||||
"type": "base64",
|
||||
"media_type": media_type,
|
||||
"data": base64_data
|
||||
}
|
||||
}
|
||||
|
||||
async def format_document(self, file_name: str, text: str) -> Dict[str, Any]:
|
||||
"""Bedrock/Anthropic 文档格式(需要 base64 编码)"""
|
||||
@@ -152,7 +161,12 @@ class BedrockFormatStrategy(MultimodalFormatStrategy):
|
||||
}
|
||||
}
|
||||
|
||||
async def format_audio(self, file_type: str, url: str, transcription: Optional[str] = None) -> Dict[str, Any]:
|
||||
async def format_audio(
|
||||
self, file_type: str,
|
||||
url: str,
|
||||
content: bytes | None = None,
|
||||
transcription: Optional[str] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Bedrock/Anthropic 音频格式
|
||||
不支持原生音频,必须转录为文本
|
||||
@@ -178,7 +192,7 @@ class BedrockFormatStrategy(MultimodalFormatStrategy):
|
||||
class OpenAIFormatStrategy(MultimodalFormatStrategy):
|
||||
"""OpenAI 策略"""
|
||||
|
||||
async def format_image(self, url: str) -> Dict[str, Any]:
|
||||
async def format_image(self, url: str, content: bytes | None = None) -> Dict[str, Any]:
|
||||
"""OpenAI 格式: {"type": "image_url", "image_url": {"url": "..."}}"""
|
||||
return {
|
||||
"type": "image_url",
|
||||
@@ -194,7 +208,13 @@ class OpenAIFormatStrategy(MultimodalFormatStrategy):
|
||||
"text": f"<document name=\"{file_name}\">\n{text}\n</document>"
|
||||
}
|
||||
|
||||
async def format_audio(self, file_type: str, url: str, transcription: Optional[str] = None) -> Dict[str, Any]:
|
||||
async def format_audio(
|
||||
self,
|
||||
file_type: str,
|
||||
url: str,
|
||||
content: bytes | None = None,
|
||||
transcription: Optional[str] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
OpenAI 音频格式
|
||||
- gpt-4o-audio 系列支持原生音频(需要 base64 编码)
|
||||
@@ -208,31 +228,35 @@ class OpenAIFormatStrategy(MultimodalFormatStrategy):
|
||||
|
||||
# OpenAI 音频需要 base64 编码
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||
response = await client.get(url)
|
||||
response.raise_for_status()
|
||||
audio_data = response.content
|
||||
base64_audio = base64.b64encode(audio_data).decode('utf-8')
|
||||
# 1. 优先从 file_type (MIME) 取扩展名
|
||||
file_ext = file_type.split('/')[-1] if file_type and '/' in file_type else None
|
||||
# 2. 从响应头 content-type 取
|
||||
if not file_ext:
|
||||
ct = response.headers.get("content-type", "")
|
||||
file_ext = ct.split('/')[-1].split(';')[0].strip() if '/' in ct else None
|
||||
# 3. 从 URL 路径取扩展名
|
||||
if not file_ext:
|
||||
file_ext = url.split('?')[0].rsplit('.', 1)[-1].lower() or None
|
||||
# 4. 默认 wav
|
||||
# supported_ext = {"wav", "mp3", "mp4", "ogg", "flac", "webm", "m4a", "wave", "x-m4a"}
|
||||
file_ext = "wav" if not file_ext else file_ext
|
||||
audio_data = content
|
||||
if content is None:
|
||||
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||
response = await client.get(url)
|
||||
response.raise_for_status()
|
||||
audio_data = response.content
|
||||
self.file.set_content(audio_data)
|
||||
base64_audio = base64.b64encode(audio_data).decode('utf-8')
|
||||
|
||||
return {
|
||||
"type": "input_audio",
|
||||
"input_audio": {
|
||||
"data": f"data:;base64,{base64_audio}",
|
||||
"format": file_ext
|
||||
}
|
||||
# 1. 优先从 file_type (MIME) 取扩展名
|
||||
file_ext = file_type.split('/')[-1] if file_type and '/' in file_type else None
|
||||
# 2. 从响应头 content-type 取
|
||||
if not file_ext:
|
||||
content_type = magic.from_buffer(audio_data, mime=True)
|
||||
file_ext = content_type.split('/')[-1].split(';')[0].strip() if '/' in content_type else None
|
||||
# 3. 从 URL 路径取扩展名
|
||||
if not file_ext:
|
||||
file_ext = url.split('?')[0].rsplit('.', 1)[-1].lower() or None
|
||||
# 4. 默认 wav
|
||||
# supported_ext = {"wav", "mp3", "mp4", "ogg", "flac", "webm", "m4a", "wave", "x-m4a"}
|
||||
file_ext = "wav" if not file_ext else file_ext
|
||||
|
||||
return {
|
||||
"type": "input_audio",
|
||||
"input_audio": {
|
||||
"data": f"data:;base64,{base64_audio}",
|
||||
"format": file_ext
|
||||
}
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"下载音频失败: {e}")
|
||||
return {
|
||||
@@ -262,7 +286,8 @@ PROVIDER_STRATEGIES = {
|
||||
class MultimodalService:
|
||||
"""多模态文件处理服务"""
|
||||
|
||||
def __init__(self, db: Session, provider: str = "dashscope", api_key: Optional[str] = None, enable_audio_transcription: bool = False, is_omni: bool = False):
|
||||
def __init__(self, db: Session, provider: str = "dashscope", api_key: Optional[str] = None,
|
||||
enable_audio_transcription: bool = False, is_omni: bool = False):
|
||||
"""
|
||||
初始化多模态服务
|
||||
|
||||
@@ -305,10 +330,9 @@ class MultimodalService:
|
||||
logger.warning(f"未找到 provider '{self.provider}' 的策略,使用默认策略")
|
||||
strategy_class = DashScopeFormatStrategy
|
||||
|
||||
strategy = strategy_class()
|
||||
|
||||
result = []
|
||||
for idx, file in enumerate(files):
|
||||
strategy = strategy_class(file)
|
||||
try:
|
||||
if file.type == FileType.IMAGE:
|
||||
content = await self._process_image(file, strategy)
|
||||
@@ -355,7 +379,7 @@ class MultimodalService:
|
||||
"""
|
||||
try:
|
||||
url = await self.get_file_url(file)
|
||||
return await strategy.format_image(url)
|
||||
return await strategy.format_image(url, content=file.get_content())
|
||||
except Exception as e:
|
||||
logger.error(f"处理图片失败: {e}", exc_info=True)
|
||||
return {
|
||||
@@ -415,11 +439,13 @@ class MultimodalService:
|
||||
# 远程文档暂不支持提取
|
||||
return {
|
||||
"type": "text",
|
||||
"text": f"<document url=\"{file.url}\">\n[远程文档,暂不支持内容提取]\n</document>"
|
||||
"text": f"<document url=\"{file.url}\">\n{await self._extract_document_text(file)}\n</document>"
|
||||
}
|
||||
else:
|
||||
# 本地文件,提取文本内容
|
||||
text = await self._extract_document_text(file.upload_file_id)
|
||||
server_url = settings.FILE_LOCAL_SERVER_URL
|
||||
file.url = f"{server_url}/storage/permanent/{file.upload_file_id}"
|
||||
text = await self._extract_document_text(file)
|
||||
file_metadata = self.db.query(FileMetadata).filter(
|
||||
FileMetadata.id == file.upload_file_id
|
||||
).first()
|
||||
@@ -454,7 +480,7 @@ class MultimodalService:
|
||||
else:
|
||||
logger.warning(f"Provider {self.provider} 不支持音频转文本")
|
||||
|
||||
return await strategy.format_audio(file.file_type, url, transcription)
|
||||
return await strategy.format_audio(file.file_type, url, file.get_content(), transcription)
|
||||
except Exception as e:
|
||||
logger.error(f"处理音频失败: {e}", exc_info=True)
|
||||
return {
|
||||
@@ -500,8 +526,6 @@ class MultimodalService:
|
||||
return file.url
|
||||
else:
|
||||
file_id = file.upload_file_id
|
||||
print("="*50)
|
||||
print("file_id",file_id)
|
||||
|
||||
# 查询 FileMetadata
|
||||
file_metadata = self.db.query(FileMetadata).filter(
|
||||
@@ -519,66 +543,44 @@ class MultimodalService:
|
||||
server_url = settings.FILE_LOCAL_SERVER_URL
|
||||
return f"{server_url}/storage/permanent/{file_id}"
|
||||
|
||||
async def _extract_document_text(self, file_id: uuid.UUID) -> str:
|
||||
async def _extract_document_text(self, file: FileInput) -> str:
|
||||
"""
|
||||
提取文档文本内容
|
||||
|
||||
Args:
|
||||
file_id: 文件ID
|
||||
file: 文件输入
|
||||
|
||||
Returns:
|
||||
str: 提取的文本内容
|
||||
"""
|
||||
file_metadata = self.db.query(FileMetadata).filter(
|
||||
FileMetadata.id == file_id,
|
||||
FileMetadata.status == "completed"
|
||||
).first()
|
||||
|
||||
if not file_metadata:
|
||||
raise BusinessException(
|
||||
f"文件不存在或已删除: {file_id}",
|
||||
BizCode.NOT_FOUND
|
||||
)
|
||||
|
||||
file_ext = file_metadata.file_ext.lower()
|
||||
server_url = settings.FILE_LOCAL_SERVER_URL
|
||||
file_url = f"{server_url}/storage/permanent/{file_id}"
|
||||
|
||||
if file_ext in ['.txt', '.md', '.markdown']:
|
||||
return await self._read_text_file(file_url)
|
||||
elif file_ext == '.pdf':
|
||||
return await self._extract_pdf_text(file_url)
|
||||
elif file_ext in ['.doc', '.docx']:
|
||||
return await self._extract_word_text(file_url)
|
||||
else:
|
||||
return f"[不支持的文档格式: {file_ext}]"
|
||||
|
||||
@staticmethod
|
||||
async def _read_text_file(file_url: str) -> str:
|
||||
"""读取纯文本文件"""
|
||||
try:
|
||||
# 下载文件
|
||||
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||
response = await client.get(file_url)
|
||||
response.raise_for_status()
|
||||
return response.text
|
||||
file_content = file.get_content()
|
||||
if not file_content:
|
||||
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||
response = await client.get(file.url)
|
||||
response.raise_for_status()
|
||||
file_content = response.content
|
||||
file.set_content(file_content)
|
||||
file_mime_type = magic.from_buffer(file_content, mime=True)
|
||||
if file_mime_type in TEXT_MIME:
|
||||
return file_content.decode("utf-8")
|
||||
elif file_mime_type in PDF_MIME:
|
||||
return await self._extract_pdf_text(file_content)
|
||||
elif file_mime_type in DOC_MIME:
|
||||
return await self._extract_word_text(file_content)
|
||||
else:
|
||||
return f"[Unsupported file type: {file_mime_type}]"
|
||||
except Exception as e:
|
||||
logger.error(f"读取文本文件失败: {e}")
|
||||
return f"[文件读取失败: {str(e)}]"
|
||||
logger.error(f"Failed to load file. - {e}")
|
||||
return "[Failed to load file.]"
|
||||
|
||||
@staticmethod
|
||||
async def _extract_pdf_text(file_url: str) -> str:
|
||||
async def _extract_pdf_text(file_content: bytes) -> str:
|
||||
"""提取 PDF 文本"""
|
||||
try:
|
||||
# 下载 PDF 文件
|
||||
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||
response = await client.get(file_url)
|
||||
response.raise_for_status()
|
||||
pdf_data = response.content
|
||||
|
||||
# 使用 BytesIO 读取 PDF
|
||||
text_parts = []
|
||||
pdf_file = io.BytesIO(pdf_data)
|
||||
pdf_file = io.BytesIO(file_content)
|
||||
pdf_reader = PyPDF2.PdfReader(pdf_file)
|
||||
for page in pdf_reader.pages:
|
||||
text_parts.append(page.extract_text())
|
||||
@@ -588,17 +590,11 @@ class MultimodalService:
|
||||
return f"[PDF 提取失败: {str(e)}]"
|
||||
|
||||
@staticmethod
|
||||
async def _extract_word_text(file_url: str) -> str:
|
||||
async def _extract_word_text(file_content: bytes) -> str:
|
||||
"""提取 Word 文档文本"""
|
||||
try:
|
||||
# 下载 Word 文件
|
||||
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||
response = await client.get(file_url)
|
||||
response.raise_for_status()
|
||||
word_data = response.content
|
||||
|
||||
# 使用 BytesIO 读取 Word 文档
|
||||
word_file = io.BytesIO(word_data)
|
||||
word_file = io.BytesIO(file_content)
|
||||
doc = Document(word_file)
|
||||
text_parts = [paragraph.text for paragraph in doc.paragraphs]
|
||||
return '\n'.join(text_parts)
|
||||
|
||||
@@ -458,7 +458,7 @@ class WorkflowService:
|
||||
type=file.type,
|
||||
url=await self.multimodal_service.get_file_url(file),
|
||||
transfer_method=file.transfer_method,
|
||||
file_id=str(file.upload_file_id),
|
||||
file_id=str(file.upload_file_id) if file.upload_file_id else None,
|
||||
origin_file_type=file.file_type,
|
||||
is_file=True
|
||||
).model_dump()
|
||||
|
||||
@@ -145,6 +145,8 @@ dependencies = [
|
||||
"lxml>=4.9.0",
|
||||
"httpx>=0.28.0",
|
||||
"modelscope>=1.34.0",
|
||||
"python-magic>=0.4.14; sys_platform == 'linux' or sys_platform == 'darwin'",
|
||||
"python-magic-bin>=0.4.14; sys_platform=='win32'",
|
||||
]
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
|
||||
Reference in New Issue
Block a user