fix(app):
1.The end users are still bound to the app. 2. Multi-modal file support includes xlsx, csv, and json. 3. The file routing protocol is consistent with the page routing.
This commit is contained in:
@@ -537,6 +537,7 @@ async def draft_run(
|
|||||||
# 先获取 app 的 workspace_id
|
# 先获取 app 的 workspace_id
|
||||||
end_user_repo = EndUserRepository(db)
|
end_user_repo = EndUserRepository(db)
|
||||||
new_end_user = end_user_repo.get_or_create_end_user(
|
new_end_user = end_user_repo.get_or_create_end_user(
|
||||||
|
app_id=app_id,
|
||||||
workspace_id=app.workspace_id,
|
workspace_id=app.workspace_id,
|
||||||
other_id=str(current_user.id),
|
other_id=str(current_user.id),
|
||||||
)
|
)
|
||||||
@@ -869,6 +870,7 @@ async def draft_run_compare(
|
|||||||
# 先获取 app 的 workspace_id
|
# 先获取 app 的 workspace_id
|
||||||
end_user_repo = EndUserRepository(db)
|
end_user_repo = EndUserRepository(db)
|
||||||
new_end_user = end_user_repo.get_or_create_end_user(
|
new_end_user = end_user_repo.get_or_create_end_user(
|
||||||
|
app_id=app_id,
|
||||||
workspace_id=app.workspace_id,
|
workspace_id=app.workspace_id,
|
||||||
other_id=str(current_user.id),
|
other_id=str(current_user.id),
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -15,7 +15,7 @@ import os
|
|||||||
import uuid
|
import uuid
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, File, HTTPException, UploadFile, status
|
from fastapi import APIRouter, Depends, File, HTTPException, Request, UploadFile, status
|
||||||
from fastapi.responses import FileResponse, RedirectResponse
|
from fastapi.responses import FileResponse, RedirectResponse
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
@@ -47,6 +47,19 @@ router = APIRouter(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _match_scheme(request: Request, url: str) -> str:
|
||||||
|
"""
|
||||||
|
将 presigned URL 的协议替换为与当前请求一致的协议(http/https)。
|
||||||
|
解决反向代理场景下 presigned URL 协议与请求协议不匹配的问题。
|
||||||
|
"""
|
||||||
|
incoming_scheme = request.headers.get("x-forwarded-proto") or request.url.scheme
|
||||||
|
if url.startswith("http://") and incoming_scheme == "https":
|
||||||
|
return "https://" + url[7:]
|
||||||
|
if url.startswith("https://") and incoming_scheme == "http":
|
||||||
|
return "http://" + url[8:]
|
||||||
|
return url
|
||||||
|
|
||||||
|
|
||||||
@router.post("/files", response_model=ApiResponse)
|
@router.post("/files", response_model=ApiResponse)
|
||||||
async def upload_file(
|
async def upload_file(
|
||||||
file: UploadFile = File(...),
|
file: UploadFile = File(...),
|
||||||
@@ -280,6 +293,7 @@ async def upload_file_with_share_token(
|
|||||||
|
|
||||||
@router.get("/files/{file_id}", response_model=Any)
|
@router.get("/files/{file_id}", response_model=Any)
|
||||||
async def download_file(
|
async def download_file(
|
||||||
|
request: Request,
|
||||||
file_id: uuid.UUID,
|
file_id: uuid.UUID,
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
@@ -327,6 +341,7 @@ async def download_file(
|
|||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
presigned_url = await storage_service.get_file_url(file_key, expires=3600)
|
presigned_url = await storage_service.get_file_url(file_key, expires=3600)
|
||||||
|
presigned_url = _match_scheme(request, presigned_url)
|
||||||
api_logger.info(f"Redirecting to presigned URL: file_key={file_key}")
|
api_logger.info(f"Redirecting to presigned URL: file_key={file_key}")
|
||||||
return RedirectResponse(url=presigned_url, status_code=status.HTTP_302_FOUND)
|
return RedirectResponse(url=presigned_url, status_code=status.HTTP_302_FOUND)
|
||||||
except FileNotFoundError:
|
except FileNotFoundError:
|
||||||
@@ -400,6 +415,7 @@ async def delete_file(
|
|||||||
|
|
||||||
@router.get("/files/{file_id}/url", response_model=ApiResponse)
|
@router.get("/files/{file_id}/url", response_model=ApiResponse)
|
||||||
async def get_file_url(
|
async def get_file_url(
|
||||||
|
request: Request,
|
||||||
file_id: uuid.UUID,
|
file_id: uuid.UUID,
|
||||||
expires: int = None,
|
expires: int = None,
|
||||||
permanent: bool = False,
|
permanent: bool = False,
|
||||||
@@ -463,6 +479,7 @@ async def get_file_url(
|
|||||||
else:
|
else:
|
||||||
# For remote storage (OSS/S3), get presigned URL
|
# For remote storage (OSS/S3), get presigned URL
|
||||||
url = await storage_service.get_file_url(file_key, expires=expires)
|
url = await storage_service.get_file_url(file_key, expires=expires)
|
||||||
|
url = _match_scheme(request, url)
|
||||||
|
|
||||||
api_logger.info(f"Generated file URL: file_id={file_id}")
|
api_logger.info(f"Generated file URL: file_id={file_id}")
|
||||||
return success(
|
return success(
|
||||||
@@ -484,6 +501,7 @@ async def get_file_url(
|
|||||||
|
|
||||||
@router.get("/public/{file_id}", response_model=Any)
|
@router.get("/public/{file_id}", response_model=Any)
|
||||||
async def public_download_file(
|
async def public_download_file(
|
||||||
|
request: Request,
|
||||||
file_id: uuid.UUID,
|
file_id: uuid.UUID,
|
||||||
expires: int = 0,
|
expires: int = 0,
|
||||||
signature: str = "",
|
signature: str = "",
|
||||||
@@ -555,6 +573,7 @@ async def public_download_file(
|
|||||||
# For remote storage, redirect to presigned URL
|
# For remote storage, redirect to presigned URL
|
||||||
try:
|
try:
|
||||||
presigned_url = await storage_service.get_file_url(file_key, expires=3600)
|
presigned_url = await storage_service.get_file_url(file_key, expires=3600)
|
||||||
|
presigned_url = _match_scheme(request, presigned_url)
|
||||||
return RedirectResponse(url=presigned_url, status_code=status.HTTP_302_FOUND)
|
return RedirectResponse(url=presigned_url, status_code=status.HTTP_302_FOUND)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
api_logger.error(f"Failed to get presigned URL: {e}")
|
api_logger.error(f"Failed to get presigned URL: {e}")
|
||||||
@@ -566,6 +585,7 @@ async def public_download_file(
|
|||||||
|
|
||||||
@router.get("/permanent/{file_id}", response_model=Any)
|
@router.get("/permanent/{file_id}", response_model=Any)
|
||||||
async def permanent_download_file(
|
async def permanent_download_file(
|
||||||
|
request: Request,
|
||||||
file_id: uuid.UUID,
|
file_id: uuid.UUID,
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
storage_service: FileStorageService = Depends(get_file_storage_service),
|
storage_service: FileStorageService = Depends(get_file_storage_service),
|
||||||
@@ -625,6 +645,7 @@ async def permanent_download_file(
|
|||||||
try:
|
try:
|
||||||
# Use a very long expiration (7 days max for most cloud providers)
|
# Use a very long expiration (7 days max for most cloud providers)
|
||||||
presigned_url = await storage_service.get_file_url(file_key, expires=604800)
|
presigned_url = await storage_service.get_file_url(file_key, expires=604800)
|
||||||
|
presigned_url = _match_scheme(request, presigned_url)
|
||||||
return RedirectResponse(url=presigned_url, status_code=status.HTTP_302_FOUND)
|
return RedirectResponse(url=presigned_url, status_code=status.HTTP_302_FOUND)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
api_logger.error(f"Failed to get presigned URL: {e}")
|
api_logger.error(f"Failed to get presigned URL: {e}")
|
||||||
|
|||||||
@@ -219,6 +219,7 @@ def list_conversations(
|
|||||||
app_service = AppService(db)
|
app_service = AppService(db)
|
||||||
app = app_service._get_app_or_404(share.app_id)
|
app = app_service._get_app_or_404(share.app_id)
|
||||||
new_end_user = end_user_repo.get_or_create_end_user(
|
new_end_user = end_user_repo.get_or_create_end_user(
|
||||||
|
app_id=share.app_id,
|
||||||
workspace_id=app.workspace_id,
|
workspace_id=app.workspace_id,
|
||||||
other_id=other_id
|
other_id=other_id
|
||||||
)
|
)
|
||||||
@@ -315,6 +316,7 @@ async def chat(
|
|||||||
app = app_service._get_app_or_404(share.app_id)
|
app = app_service._get_app_or_404(share.app_id)
|
||||||
workspace_id = app.workspace_id
|
workspace_id = app.workspace_id
|
||||||
new_end_user = end_user_repo.get_or_create_end_user(
|
new_end_user = end_user_repo.get_or_create_end_user(
|
||||||
|
app_id=share.app_id,
|
||||||
workspace_id=workspace_id,
|
workspace_id=workspace_id,
|
||||||
other_id=other_id,
|
other_id=other_id,
|
||||||
original_user_id=user_id
|
original_user_id=user_id
|
||||||
|
|||||||
@@ -94,6 +94,7 @@ async def chat(
|
|||||||
workspace_id = app.workspace_id
|
workspace_id = app.workspace_id
|
||||||
end_user_repo = EndUserRepository(db)
|
end_user_repo = EndUserRepository(db)
|
||||||
new_end_user = end_user_repo.get_or_create_end_user(
|
new_end_user = end_user_repo.get_or_create_end_user(
|
||||||
|
app_id=app.id,
|
||||||
workspace_id=workspace_id,
|
workspace_id=workspace_id,
|
||||||
other_id=other_id,
|
other_id=other_id,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -67,6 +67,7 @@ class EndUserRepository:
|
|||||||
|
|
||||||
def get_or_create_end_user(
|
def get_or_create_end_user(
|
||||||
self,
|
self,
|
||||||
|
app_id: uuid.UUID,
|
||||||
workspace_id: uuid.UUID,
|
workspace_id: uuid.UUID,
|
||||||
other_id: str,
|
other_id: str,
|
||||||
original_user_id: Optional[str] = None
|
original_user_id: Optional[str] = None
|
||||||
@@ -74,6 +75,7 @@ class EndUserRepository:
|
|||||||
"""获取或创建终端用户
|
"""获取或创建终端用户
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
app_id: 应用ID
|
||||||
workspace_id: 工作空间ID
|
workspace_id: 工作空间ID
|
||||||
other_id: 第三方ID
|
other_id: 第三方ID
|
||||||
original_user_id: 原始用户ID (存储到 other_id)
|
original_user_id: 原始用户ID (存储到 other_id)
|
||||||
@@ -92,10 +94,14 @@ class EndUserRepository:
|
|||||||
|
|
||||||
if end_user:
|
if end_user:
|
||||||
db_logger.debug(f"找到现有终端用户: 应用ID {workspace_id}、第三方ID {other_id}")
|
db_logger.debug(f"找到现有终端用户: 应用ID {workspace_id}、第三方ID {other_id}")
|
||||||
|
end_user.app_id=app_id
|
||||||
|
self.db.commit()
|
||||||
|
self.db.refresh(end_user)
|
||||||
return end_user
|
return end_user
|
||||||
|
|
||||||
# 创建新用户
|
# 创建新用户
|
||||||
end_user = EndUser(
|
end_user = EndUser(
|
||||||
|
app_id=app_id,
|
||||||
workspace_id=workspace_id,
|
workspace_id=workspace_id,
|
||||||
other_id=other_id
|
other_id=other_id
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -14,9 +14,13 @@ import uuid
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import List, Dict, Any, Optional
|
from typing import List, Dict, Any, Optional
|
||||||
|
|
||||||
|
import csv
|
||||||
|
import json
|
||||||
|
|
||||||
import PyPDF2
|
import PyPDF2
|
||||||
import httpx
|
import httpx
|
||||||
import magic
|
import magic
|
||||||
|
import openpyxl
|
||||||
from docx import Document
|
from docx import Document
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
@@ -39,6 +43,13 @@ DOC_MIME = [
|
|||||||
'application/msword',
|
'application/msword',
|
||||||
'application/vnd.openxmlformats-officedocument.wordprocessingml.document'
|
'application/vnd.openxmlformats-officedocument.wordprocessingml.document'
|
||||||
]
|
]
|
||||||
|
XLSX_MIME = [
|
||||||
|
'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet',
|
||||||
|
'application/vnd.ms-excel',
|
||||||
|
'application/zip'
|
||||||
|
]
|
||||||
|
CSV_MIME = ['text/csv', 'application/csv']
|
||||||
|
JSON_MIME = ['application/json']
|
||||||
|
|
||||||
|
|
||||||
class MultimodalFormatStrategy(ABC):
|
class MultimodalFormatStrategy(ABC):
|
||||||
@@ -577,6 +588,12 @@ class MultimodalService:
|
|||||||
return await self._extract_pdf_text(file_content)
|
return await self._extract_pdf_text(file_content)
|
||||||
elif file_mime_type in DOC_MIME:
|
elif file_mime_type in DOC_MIME:
|
||||||
return await self._extract_word_text(file_content)
|
return await self._extract_word_text(file_content)
|
||||||
|
elif file_mime_type in XLSX_MIME:
|
||||||
|
return await self._extract_xlsx_text(file_content)
|
||||||
|
elif file_mime_type in CSV_MIME:
|
||||||
|
return await self._extract_csv_text(file_content)
|
||||||
|
elif file_mime_type in JSON_MIME:
|
||||||
|
return await self._extract_json_text(file_content)
|
||||||
else:
|
else:
|
||||||
return f"[Unsupported file type: {file_mime_type}]"
|
return f"[Unsupported file type: {file_mime_type}]"
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -602,7 +619,6 @@ class MultimodalService:
|
|||||||
async def _extract_word_text(file_content: bytes) -> str:
|
async def _extract_word_text(file_content: bytes) -> str:
|
||||||
"""提取 Word 文档文本"""
|
"""提取 Word 文档文本"""
|
||||||
try:
|
try:
|
||||||
# 使用 BytesIO 读取 Word 文档
|
|
||||||
word_file = io.BytesIO(file_content)
|
word_file = io.BytesIO(file_content)
|
||||||
doc = Document(word_file)
|
doc = Document(word_file)
|
||||||
text_parts = [paragraph.text for paragraph in doc.paragraphs]
|
text_parts = [paragraph.text for paragraph in doc.paragraphs]
|
||||||
@@ -611,6 +627,42 @@ class MultimodalService:
|
|||||||
logger.error(f"提取 Word 文本失败: {e}")
|
logger.error(f"提取 Word 文本失败: {e}")
|
||||||
return f"[Word 提取失败: {str(e)}]"
|
return f"[Word 提取失败: {str(e)}]"
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def _extract_xlsx_text(file_content: bytes) -> str:
|
||||||
|
"""提取 Excel 文本"""
|
||||||
|
try:
|
||||||
|
wb = openpyxl.load_workbook(io.BytesIO(file_content), read_only=True, data_only=True)
|
||||||
|
parts = []
|
||||||
|
for sheet in wb.worksheets:
|
||||||
|
parts.append(f"[Sheet: {sheet.title}]")
|
||||||
|
for row in sheet.iter_rows(values_only=True):
|
||||||
|
parts.append('\t'.join('' if v is None else str(v) for v in row))
|
||||||
|
return '\n'.join(parts)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"提取 Excel 文本失败: {e}")
|
||||||
|
return f"[Excel 提取失败: {str(e)}]"
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def _extract_csv_text(file_content: bytes) -> str:
|
||||||
|
"""提取 CSV 文本"""
|
||||||
|
try:
|
||||||
|
text = file_content.decode('utf-8-sig')
|
||||||
|
reader = csv.reader(io.StringIO(text))
|
||||||
|
return '\n'.join('\t'.join(row) for row in reader)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"提取 CSV 文本失败: {e}")
|
||||||
|
return f"[CSV 提取失败: {str(e)}]"
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def _extract_json_text(file_content: bytes) -> str:
|
||||||
|
"""提取 JSON 文本"""
|
||||||
|
try:
|
||||||
|
data = json.loads(file_content.decode('utf-8'))
|
||||||
|
return json.dumps(data, ensure_ascii=False, indent=2)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"提取 JSON 文本失败: {e}")
|
||||||
|
return f"[JSON 提取失败: {str(e)}]"
|
||||||
|
|
||||||
|
|
||||||
def get_multimodal_service(db: Session) -> MultimodalService:
|
def get_multimodal_service(db: Session) -> MultimodalService:
|
||||||
"""获取多模态服务实例(依赖注入)"""
|
"""获取多模态服务实例(依赖注入)"""
|
||||||
|
|||||||
Reference in New Issue
Block a user