From 99e94b3567d27f1532da73cf9f40755c4516c979 Mon Sep 17 00:00:00 2001 From: Eternity <61316157+myhMARS@users.noreply.github.com> Date: Tue, 10 Mar 2026 18:28:16 +0800 Subject: [PATCH] feat(workflow,app): add MIME-based file handling and HTTP response files --- api/Dockerfile | 3 +- api/app/core/workflow/nodes/base_node.py | 20 +- .../workflow/nodes/http_request/config.py | 6 + .../core/workflow/nodes/http_request/node.py | 131 ++++++++- api/app/core/workflow/utils/file_processer.py | 56 ++++ .../core/workflow/variable/base_variable.py | 9 +- api/app/schemas/app_schema.py | 8 + api/app/services/multimodal_service.py | 252 +++++++++--------- api/app/services/workflow_service.py | 2 +- api/pyproject.toml | 2 + 10 files changed, 347 insertions(+), 142 deletions(-) create mode 100644 api/app/core/workflow/utils/file_processer.py diff --git a/api/Dockerfile b/api/Dockerfile index f6c082d2..a38739ca 100644 --- a/api/Dockerfile +++ b/api/Dockerfile @@ -45,7 +45,8 @@ RUN --mount=type=cache,id=mem_apt,target=/var/cache/apt,sharing=locked \ apt install -y libpython3-dev libgtk-4-1 libnss3 xdg-utils libgbm-dev && \ apt install -y libjemalloc-dev && \ apt install -y python3-pip pipx nginx unzip curl wget git vim less && \ - apt install -y ghostscript + apt install -y ghostscript && \ + apt install -y libmagic1 RUN if [ "$NEED_MIRROR" == "1" ]; then \ pip3 config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple && \ diff --git a/api/app/core/workflow/nodes/base_node.py b/api/app/core/workflow/nodes/base_node.py index 496454ba..39c7887b 100644 --- a/api/app/core/workflow/nodes/base_node.py +++ b/api/app/core/workflow/nodes/base_node.py @@ -1,5 +1,6 @@ import asyncio import logging +import uuid from abc import ABC, abstractmethod from datetime import datetime from functools import cached_property @@ -643,15 +644,18 @@ class BaseNode(ABC): return content.content_cache[provider] with get_db_read() as db: multimodel_service = MultimodalService(db, provider, is_omni=is_omni) - message = await multimodel_service.process_files( - [FileInput.model_construct( - type=content.type, - url=content.url, - transfer_method=content.transfer_method, - file_type=content.origin_file_type, - upload_file_id=content.file_id - )] + file_obj = FileInput( + type=content.type, + url=content.url, + transfer_method=content.transfer_method, + origin_file_type=content.origin_file_type, + upload_file_id=uuid.UUID(content.file_id) if content.file_id else None, ) + file_obj.set_content(content.get_content()) + message = await multimodel_service.process_files( + [file_obj] + ) + content.set_content(file_obj.get_content()) if message: content.content_cache[provider] = message return message diff --git a/api/app/core/workflow/nodes/http_request/config.py b/api/app/core/workflow/nodes/http_request/config.py index 9b41d9f2..fe38fafb 100644 --- a/api/app/core/workflow/nodes/http_request/config.py +++ b/api/app/core/workflow/nodes/http_request/config.py @@ -4,6 +4,7 @@ from pydantic import Field, BaseModel, field_validator from app.core.workflow.nodes.base_config import BaseNodeConfig from app.core.workflow.nodes.enums import HttpRequestMethod, HttpAuthType, HttpContentType, HttpErrorHandle +from app.core.workflow.variable.base_variable import FileObject class HttpAuthConfig(BaseModel): @@ -260,6 +261,11 @@ class HttpRequestNodeOutput(BaseModel): description="Http response headers" ) + files: list[FileObject] = Field( + default_factory=list, + description="List of files", + ) + output: str = Field( default="SUCCESS", description="HTTP response body", diff --git a/api/app/core/workflow/nodes/http_request/node.py b/api/app/core/workflow/nodes/http_request/node.py index e6c00eff..23378c83 100644 --- a/api/app/core/workflow/nodes/http_request/node.py +++ b/api/app/core/workflow/nodes/http_request/node.py @@ -1,24 +1,146 @@ import asyncio import json import logging +import mimetypes import uuid +import imghdr +from email.message import Message from typing import Any, Callable, Coroutine import httpx -# import filetypes # TODO: File support (Feature) from httpx import AsyncClient, Response, Timeout +import magic from app.core.workflow.engine.state_manager import WorkflowState from app.core.workflow.engine.variable_pool import VariablePool from app.core.workflow.nodes.base_node import BaseNode from app.core.workflow.nodes.enums import HttpRequestMethod, HttpErrorHandle, HttpAuthType, HttpContentType from app.core.workflow.nodes.http_request.config import HttpRequestNodeConfig, HttpRequestNodeOutput -from app.core.workflow.variable.base_variable import VariableType +from app.core.workflow.utils.file_processer import mime_to_file_type +from app.core.workflow.variable.base_variable import VariableType, FileObject from app.core.workflow.variable.variable_objects import FileVariable, ArrayVariable +from app.schemas import FileType, TransferMethod logger = logging.getLogger(__file__) +class HttpResponse: + def __init__(self, response: httpx.Response): + self.response = response + self.headers = dict(response.headers) + + self._is_file: bool | None = None + + @property + def content_type(self) -> str: + return self.headers.get("content-type", "") + + @property + def content_disposition(self) -> Message | None: + content_disposition = self.headers.get("content-disposition", "") + if content_disposition: + msg = Message() + msg["content-disposition"] = content_disposition + return msg + return None + + @property + def is_file(self) -> bool: + if self._is_file is not None: + return self._is_file + content_type = self.content_type.split(";")[0].strip().lower() + + parsed_content_disposition = self.content_disposition + if parsed_content_disposition: + disp_type = parsed_content_disposition.get_content_disposition() + filename = parsed_content_disposition.get_filename() + if disp_type == "attachment" or filename: + self._is_file = True + return True + + if content_type.startswith("text/") and "csv" not in content_type: + return False + + if content_type.startswith("application/"): + if any( + text_type in content_type + for text_type in {"json", "xml", "javascript", "x-www-form-urlencoded", "yaml", "graphql"} + ): + self._is_file = False + return False + try: + content_sample = self.response.content[:1024] + content_sample.decode("utf-8") + text_markers = (b"{", b"[", b"<", b"function", b"var ", b"const ", b"let ") + if any(marker in content_sample for marker in text_markers): + return False + except UnicodeDecodeError: + self._is_file = True + return True + + main_type, _ = mimetypes.guess_type("dummy" + (mimetypes.guess_extension(content_type) or "")) + if main_type: + self._is_file = main_type.split("/")[0] in ("application", "image", "audio", "video") + return self._is_file + self._is_file = any(media_type in content_type for media_type in ("image/", "audio/", "video/")) + return self._is_file + + @property + def is_image(self): + if self.is_file: + kind = imghdr.what(None, h=self.response.content) + return kind is not None + return False + + @property + def url(self) -> str: + return str(self.response.url) + + @property + def body(self) -> str: + if self.is_file: + return f"{'!' if self.is_image else ''}[file]({self.url})" + return self.response.text + + @staticmethod + def get_file_type(file_bytes) -> tuple[FileType | None, str | None]: + mime = magic.from_buffer(file_bytes, mime=True) + + if mime.startswith("image"): + return FileType.IMAGE, mime + elif mime.startswith("video"): + return FileType.VIDEO, mime + elif mime.startswith("audio"): + return FileType.AUDIO, mime + elif mime in ["application/pdf", + "application/msword", + "application/vnd.openxmlformats-officedocument.wordprocessingml.document", + "application/vnd.ms-excel", + "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet", + "text/plain"]: + return FileType.DOCUMENT, mime + return None, None + + @property + def files(self) -> list[FileObject]: + file_type, mime_type = self.get_file_type(self.response.content) + origin_file_type = mime_to_file_type(mime_type) + if self.is_file and file_type and origin_file_type: + file_obj = FileObject( + type=file_type, + url=self.url, + transfer_method=TransferMethod.REMOTE_URL.value, + origin_file_type=origin_file_type, + file_id=None, + is_file=True + ) + file_obj.set_content(self.response.content) + return [ + file_obj + ] + return [] + + class HttpRequestNode(BaseNode): """ HTTP Request Workflow Node. @@ -44,6 +166,7 @@ class HttpRequestNode(BaseNode): "body": VariableType.STRING, "status_code": VariableType.NUMBER, "headers": VariableType.OBJECT, + "files": VariableType.ARRAY_FILE, "output": VariableType.STRING } @@ -232,10 +355,12 @@ class HttpRequestNode(BaseNode): ) resp.raise_for_status() logger.info(f"Node {self.node_id}: HTTP request succeeded") + response = HttpResponse(resp) return HttpRequestNodeOutput( - body=resp.text, + body=response.body, status_code=resp.status_code, headers=resp.headers, + files=response.files ).model_dump() except (httpx.HTTPStatusError, httpx.RequestError) as e: logger.error(f"HTTP request node exception: {e}") diff --git a/api/app/core/workflow/utils/file_processer.py b/api/app/core/workflow/utils/file_processer.py new file mode 100644 index 00000000..ae406ab0 --- /dev/null +++ b/api/app/core/workflow/utils/file_processer.py @@ -0,0 +1,56 @@ +# -*- coding: UTF-8 -*- +# Author: Eternity +# @Email: 1533512157@qq.com +# @Time : 2026/3/10 13:36 +TRANSFORM_FILE_TYPE = { + 'text/plain': 'document/text', + 'text/markdown': 'document/markdown', + 'text/x-markdown': 'document/x-markdown', + + 'application/pdf': 'document/pdf', + + 'application/msword': 'document/doc', + 'application/vnd.openxmlformats-officedocument.wordprocessingml.document': 'document/docx', + + 'application/vnd.ms-powerpoint': 'document/ppt', + 'application/vnd.openxmlformats-officedocument.presentationml.presentation': 'document/pptx', +} +ALLOWED_FILE_TYPES = [ + 'text/plain', + 'text/markdown', + 'text/x-markdown', + 'application/pdf', + 'application/msword', + 'application/vnd.openxmlformats-officedocument.wordprocessingml.document', + 'application/vnd.ms-powerpoint', + 'application/vnd.openxmlformats-officedocument.presentationml.presentation', + 'image/jpg', + 'image/jpeg', + 'image/png', + 'image/gif', + 'image/bmp', + 'image/webp', + 'image/svg+xml', + 'video/mp4', + 'video/quicktime', + 'video/x-msvideo', + 'video/x-matroska', + 'video/webm', + 'video/x-flv', + 'video/x-ms-wmv', + 'audio/mpeg', + 'audio/wav', + 'audio/ogg', + 'audio/aac', + 'audio/flac', + 'audio/mp4', + 'audio/x-ms-wma', + 'audio/x-m4a', +] + + +def mime_to_file_type(mime_type): + if mime_type not in ALLOWED_FILE_TYPES: + return None + + return TRANSFORM_FILE_TYPE.get(mime_type, mime_type) diff --git a/api/app/core/workflow/variable/base_variable.py b/api/app/core/workflow/variable/base_variable.py index dd821ea7..aea40cf6 100644 --- a/api/app/core/workflow/variable/base_variable.py +++ b/api/app/core/workflow/variable/base_variable.py @@ -114,9 +114,16 @@ class FileObject(BaseModel): file_id: str | None content_cache: dict = Field(default_factory=dict) - is_file: bool + _byte_content: bytes | None = None + + def get_content(self): + return self._byte_content + + def set_content(self, byte_content): + self._byte_content = byte_content + class BaseVariable(ABC): """Abstract base class for all workflow variables. diff --git a/api/app/schemas/app_schema.py b/api/app/schemas/app_schema.py index f073a200..c0482ec3 100644 --- a/api/app/schemas/app_schema.py +++ b/api/app/schemas/app_schema.py @@ -45,11 +45,19 @@ class FileInput(BaseModel): url: Optional[str] = Field(None, description="远程URL(remote_url时必填)") file_type: Optional[str] = Field(None, description="具体文件格式(如image/jpg、audio/wav、document/docx、video/mp4)") + _content = None + def __init__(self, **data): if "type" in data: data['file_type'] = data['type'] super().__init__(**data) + def set_content(self, content: bytes): + self._content = content + + def get_content(self) -> bytes | None: + return self._content + @field_validator("type", mode="before") @classmethod def validate_type(cls, v): diff --git a/api/app/services/multimodal_service.py b/api/app/services/multimodal_service.py index 9b06c287..fffca2e5 100644 --- a/api/app/services/multimodal_service.py +++ b/api/app/services/multimodal_service.py @@ -8,32 +8,42 @@ - Bedrock/Anthropic: 仅支持 base64 格式 - OpenAI: 支持 URL 和 base64 格式 """ -import uuid -import httpx import base64 -from typing import List, Dict, Any, Optional -from abc import ABC, abstractmethod -from sqlalchemy.orm import Session -from docx import Document import io -import PyPDF2 +from abc import ABC, abstractmethod +from typing import List, Dict, Any, Optional + +import PyPDF2 +import httpx +import magic +from docx import Document +from sqlalchemy.orm import Session -from app.core.logging_config import get_business_logger -from app.core.exceptions import BusinessException -from app.core.error_codes import BizCode -from app.schemas.app_schema import FileInput, FileType, TransferMethod -from app.models.file_metadata_model import FileMetadata from app.core.config import settings +from app.core.error_codes import BizCode +from app.core.exceptions import BusinessException +from app.core.logging_config import get_business_logger +from app.models.file_metadata_model import FileMetadata +from app.schemas.app_schema import FileInput, FileType, TransferMethod from app.services.audio_transcription_service import AudioTranscriptionService logger = get_business_logger() +TEXT_MIME = ['text/plain', 'text/x-markdown'] +PDF_MIME = ['application/pdf'] +DOC_MIME = [ + 'application/msword', + 'application/vnd.openxmlformats-officedocument.wordprocessingml.document' +] + class MultimodalFormatStrategy(ABC): """多模态格式策略基类""" + def __init__(self, file: FileInput): + self.file = file @abstractmethod - async def format_image(self, url: str) -> Dict[str, Any]: + async def format_image(self, url: str, content: bytes | None = None) -> Dict[str, Any]: """格式化图片""" pass @@ -43,7 +53,7 @@ class MultimodalFormatStrategy(ABC): pass @abstractmethod - async def format_audio(self, file_type: str, url: str) -> Dict[str, Any]: + async def format_audio(self, file_type: str, url: str, content: bytes | None = None) -> Dict[str, Any]: """格式化音频""" pass @@ -56,7 +66,7 @@ class MultimodalFormatStrategy(ABC): class DashScopeFormatStrategy(MultimodalFormatStrategy): """通义千问策略""" - async def format_image(self, url: str) -> Dict[str, Any]: + async def format_image(self, url: str, content: bytes | None = None) -> Dict[str, Any]: """通义千问图片格式:{"type": "image", "image": "url"}""" return { "type": "image", @@ -70,7 +80,13 @@ class DashScopeFormatStrategy(MultimodalFormatStrategy): "text": f"\n{text}\n" } - async def format_audio(self, file_type: str, url: str, transcription: Optional[str] = None) -> Dict[str, Any]: + async def format_audio( + self, + file_type: str, + url: str, + content: bytes | None = None, + transcription: Optional[str] = None + ) -> Dict[str, Any]: """ 通义千问音频格式 - 原生支持: qwen-audio 系列 @@ -98,44 +114,37 @@ class DashScopeFormatStrategy(MultimodalFormatStrategy): class BedrockFormatStrategy(MultimodalFormatStrategy): """Bedrock/Anthropic 策略""" - async def format_image(self, url: str) -> Dict[str, Any]: + async def format_image(self, url: str, content: bytes | None = None) -> Dict[str, Any]: """ Bedrock/Anthropic 格式: base64 编码 {"type": "image", "source": {"type": "base64", "media_type": "...", "data": "..."}} """ - from mimetypes import guess_type logger.info(f"下载并编码图片: {url}") # 下载图片 - async with httpx.AsyncClient(timeout=30.0) as client: - response = await client.get(url) - response.raise_for_status() - - # 获取图片数据 - image_data = response.content + if content is None: + async with httpx.AsyncClient(timeout=30.0) as client: + response = await client.get(url) + response.raise_for_status() + content = response.content + self.file.set_content(content) # 确定 media type - content_type = response.headers.get("content-type") - if content_type and content_type.startswith("image/"): - media_type = content_type - else: - guessed_type, _ = guess_type(url) - media_type = guessed_type if guessed_type and guessed_type.startswith("image/") else "image/jpeg" + content_type = magic.from_buffer(content, mime=True) + media_type = content_type if content_type.startswith("image/") else "image/jpeg" + base64_data = base64.b64encode(content).decode("utf-8") - # 转换为 base64 - base64_data = base64.b64encode(image_data).decode("utf-8") + logger.info(f"图片编码完成: media_type={media_type}, size={len(base64_data)}") - logger.info(f"图片编码完成: media_type={media_type}, size={len(base64_data)}") - - return { - "type": "image", - "source": { - "type": "base64", - "media_type": media_type, - "data": base64_data - } + return { + "type": "image", + "source": { + "type": "base64", + "media_type": media_type, + "data": base64_data } + } async def format_document(self, file_name: str, text: str) -> Dict[str, Any]: """Bedrock/Anthropic 文档格式(需要 base64 编码)""" @@ -152,7 +161,12 @@ class BedrockFormatStrategy(MultimodalFormatStrategy): } } - async def format_audio(self, file_type: str, url: str, transcription: Optional[str] = None) -> Dict[str, Any]: + async def format_audio( + self, file_type: str, + url: str, + content: bytes | None = None, + transcription: Optional[str] = None + ) -> Dict[str, Any]: """ Bedrock/Anthropic 音频格式 不支持原生音频,必须转录为文本 @@ -178,7 +192,7 @@ class BedrockFormatStrategy(MultimodalFormatStrategy): class OpenAIFormatStrategy(MultimodalFormatStrategy): """OpenAI 策略""" - async def format_image(self, url: str) -> Dict[str, Any]: + async def format_image(self, url: str, content: bytes | None = None) -> Dict[str, Any]: """OpenAI 格式: {"type": "image_url", "image_url": {"url": "..."}}""" return { "type": "image_url", @@ -194,7 +208,13 @@ class OpenAIFormatStrategy(MultimodalFormatStrategy): "text": f"\n{text}\n" } - async def format_audio(self, file_type: str, url: str, transcription: Optional[str] = None) -> Dict[str, Any]: + async def format_audio( + self, + file_type: str, + url: str, + content: bytes | None = None, + transcription: Optional[str] = None + ) -> Dict[str, Any]: """ OpenAI 音频格式 - gpt-4o-audio 系列支持原生音频(需要 base64 编码) @@ -208,31 +228,35 @@ class OpenAIFormatStrategy(MultimodalFormatStrategy): # OpenAI 音频需要 base64 编码 try: - async with httpx.AsyncClient(timeout=30.0) as client: - response = await client.get(url) - response.raise_for_status() - audio_data = response.content - base64_audio = base64.b64encode(audio_data).decode('utf-8') - # 1. 优先从 file_type (MIME) 取扩展名 - file_ext = file_type.split('/')[-1] if file_type and '/' in file_type else None - # 2. 从响应头 content-type 取 - if not file_ext: - ct = response.headers.get("content-type", "") - file_ext = ct.split('/')[-1].split(';')[0].strip() if '/' in ct else None - # 3. 从 URL 路径取扩展名 - if not file_ext: - file_ext = url.split('?')[0].rsplit('.', 1)[-1].lower() or None - # 4. 默认 wav - # supported_ext = {"wav", "mp3", "mp4", "ogg", "flac", "webm", "m4a", "wave", "x-m4a"} - file_ext = "wav" if not file_ext else file_ext + audio_data = content + if content is None: + async with httpx.AsyncClient(timeout=30.0) as client: + response = await client.get(url) + response.raise_for_status() + audio_data = response.content + self.file.set_content(audio_data) + base64_audio = base64.b64encode(audio_data).decode('utf-8') - return { - "type": "input_audio", - "input_audio": { - "data": f"data:;base64,{base64_audio}", - "format": file_ext - } + # 1. 优先从 file_type (MIME) 取扩展名 + file_ext = file_type.split('/')[-1] if file_type and '/' in file_type else None + # 2. 从响应头 content-type 取 + if not file_ext: + content_type = magic.from_buffer(audio_data, mime=True) + file_ext = content_type.split('/')[-1].split(';')[0].strip() if '/' in content_type else None + # 3. 从 URL 路径取扩展名 + if not file_ext: + file_ext = url.split('?')[0].rsplit('.', 1)[-1].lower() or None + # 4. 默认 wav + # supported_ext = {"wav", "mp3", "mp4", "ogg", "flac", "webm", "m4a", "wave", "x-m4a"} + file_ext = "wav" if not file_ext else file_ext + + return { + "type": "input_audio", + "input_audio": { + "data": f"data:;base64,{base64_audio}", + "format": file_ext } + } except Exception as e: logger.error(f"下载音频失败: {e}") return { @@ -262,7 +286,8 @@ PROVIDER_STRATEGIES = { class MultimodalService: """多模态文件处理服务""" - def __init__(self, db: Session, provider: str = "dashscope", api_key: Optional[str] = None, enable_audio_transcription: bool = False, is_omni: bool = False): + def __init__(self, db: Session, provider: str = "dashscope", api_key: Optional[str] = None, + enable_audio_transcription: bool = False, is_omni: bool = False): """ 初始化多模态服务 @@ -305,10 +330,9 @@ class MultimodalService: logger.warning(f"未找到 provider '{self.provider}' 的策略,使用默认策略") strategy_class = DashScopeFormatStrategy - strategy = strategy_class() - result = [] for idx, file in enumerate(files): + strategy = strategy_class(file) try: if file.type == FileType.IMAGE: content = await self._process_image(file, strategy) @@ -355,7 +379,7 @@ class MultimodalService: """ try: url = await self.get_file_url(file) - return await strategy.format_image(url) + return await strategy.format_image(url, content=file.get_content()) except Exception as e: logger.error(f"处理图片失败: {e}", exc_info=True) return { @@ -415,11 +439,13 @@ class MultimodalService: # 远程文档暂不支持提取 return { "type": "text", - "text": f"\n[远程文档,暂不支持内容提取]\n" + "text": f"\n{await self._extract_document_text(file)}\n" } else: # 本地文件,提取文本内容 - text = await self._extract_document_text(file.upload_file_id) + server_url = settings.FILE_LOCAL_SERVER_URL + file.url = f"{server_url}/storage/permanent/{file.upload_file_id}" + text = await self._extract_document_text(file) file_metadata = self.db.query(FileMetadata).filter( FileMetadata.id == file.upload_file_id ).first() @@ -454,7 +480,7 @@ class MultimodalService: else: logger.warning(f"Provider {self.provider} 不支持音频转文本") - return await strategy.format_audio(file.file_type, url, transcription) + return await strategy.format_audio(file.file_type, url, file.get_content(), transcription) except Exception as e: logger.error(f"处理音频失败: {e}", exc_info=True) return { @@ -500,8 +526,6 @@ class MultimodalService: return file.url else: file_id = file.upload_file_id - print("="*50) - print("file_id",file_id) # 查询 FileMetadata file_metadata = self.db.query(FileMetadata).filter( @@ -519,66 +543,44 @@ class MultimodalService: server_url = settings.FILE_LOCAL_SERVER_URL return f"{server_url}/storage/permanent/{file_id}" - async def _extract_document_text(self, file_id: uuid.UUID) -> str: + async def _extract_document_text(self, file: FileInput) -> str: """ 提取文档文本内容 Args: - file_id: 文件ID + file: 文件输入 Returns: str: 提取的文本内容 """ - file_metadata = self.db.query(FileMetadata).filter( - FileMetadata.id == file_id, - FileMetadata.status == "completed" - ).first() - - if not file_metadata: - raise BusinessException( - f"文件不存在或已删除: {file_id}", - BizCode.NOT_FOUND - ) - - file_ext = file_metadata.file_ext.lower() - server_url = settings.FILE_LOCAL_SERVER_URL - file_url = f"{server_url}/storage/permanent/{file_id}" - - if file_ext in ['.txt', '.md', '.markdown']: - return await self._read_text_file(file_url) - elif file_ext == '.pdf': - return await self._extract_pdf_text(file_url) - elif file_ext in ['.doc', '.docx']: - return await self._extract_word_text(file_url) - else: - return f"[不支持的文档格式: {file_ext}]" - - @staticmethod - async def _read_text_file(file_url: str) -> str: - """读取纯文本文件""" try: - # 下载文件 - async with httpx.AsyncClient(timeout=30.0) as client: - response = await client.get(file_url) - response.raise_for_status() - return response.text + file_content = file.get_content() + if not file_content: + async with httpx.AsyncClient(timeout=30.0) as client: + response = await client.get(file.url) + response.raise_for_status() + file_content = response.content + file.set_content(file_content) + file_mime_type = magic.from_buffer(file_content, mime=True) + if file_mime_type in TEXT_MIME: + return file_content.decode("utf-8") + elif file_mime_type in PDF_MIME: + return await self._extract_pdf_text(file_content) + elif file_mime_type in DOC_MIME: + return await self._extract_word_text(file_content) + else: + return f"[Unsupported file type: {file_mime_type}]" except Exception as e: - logger.error(f"读取文本文件失败: {e}") - return f"[文件读取失败: {str(e)}]" + logger.error(f"Failed to load file. - {e}") + return "[Failed to load file.]" @staticmethod - async def _extract_pdf_text(file_url: str) -> str: + async def _extract_pdf_text(file_content: bytes) -> str: """提取 PDF 文本""" try: - # 下载 PDF 文件 - async with httpx.AsyncClient(timeout=30.0) as client: - response = await client.get(file_url) - response.raise_for_status() - pdf_data = response.content - # 使用 BytesIO 读取 PDF text_parts = [] - pdf_file = io.BytesIO(pdf_data) + pdf_file = io.BytesIO(file_content) pdf_reader = PyPDF2.PdfReader(pdf_file) for page in pdf_reader.pages: text_parts.append(page.extract_text()) @@ -588,17 +590,11 @@ class MultimodalService: return f"[PDF 提取失败: {str(e)}]" @staticmethod - async def _extract_word_text(file_url: str) -> str: + async def _extract_word_text(file_content: bytes) -> str: """提取 Word 文档文本""" try: - # 下载 Word 文件 - async with httpx.AsyncClient(timeout=30.0) as client: - response = await client.get(file_url) - response.raise_for_status() - word_data = response.content - # 使用 BytesIO 读取 Word 文档 - word_file = io.BytesIO(word_data) + word_file = io.BytesIO(file_content) doc = Document(word_file) text_parts = [paragraph.text for paragraph in doc.paragraphs] return '\n'.join(text_parts) diff --git a/api/app/services/workflow_service.py b/api/app/services/workflow_service.py index eaf78b90..4e7268d3 100644 --- a/api/app/services/workflow_service.py +++ b/api/app/services/workflow_service.py @@ -458,7 +458,7 @@ class WorkflowService: type=file.type, url=await self.multimodal_service.get_file_url(file), transfer_method=file.transfer_method, - file_id=str(file.upload_file_id), + file_id=str(file.upload_file_id) if file.upload_file_id else None, origin_file_type=file.file_type, is_file=True ).model_dump() diff --git a/api/pyproject.toml b/api/pyproject.toml index 0bb232c3..e6fddea8 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -145,6 +145,8 @@ dependencies = [ "lxml>=4.9.0", "httpx>=0.28.0", "modelscope>=1.34.0", + "python-magic>=0.4.14; sys_platform == 'linux' or sys_platform == 'darwin'", + "python-magic-bin>=0.4.14; sys_platform=='win32'", ] [tool.pytest.ini_options]