From a3428c2435271c0ffab64aa2d38f7465a92c9432 Mon Sep 17 00:00:00 2001 From: Timebomb2018 <18868801967@163.com> Date: Mon, 23 Mar 2026 17:04:30 +0800 Subject: [PATCH] feat(app): 1. Handling the storage of multimodal messages and adapting to the loading of historical messages for multi-round conversations; 2. Obtain the interface for retrieving the voice status of the reply; 3. File Information Retrieval Interface --- .../controllers/file_storage_controller.py | 187 +++++++++++++++++- api/app/services/app_chat_service.py | 69 ++++--- api/app/services/conversation_service.py | 64 ++++-- api/app/services/draft_run_service.py | 77 ++++++-- 4 files changed, 342 insertions(+), 55 deletions(-) diff --git a/api/app/controllers/file_storage_controller.py b/api/app/controllers/file_storage_controller.py index ff284f39..14962a72 100644 --- a/api/app/controllers/file_storage_controller.py +++ b/api/app/controllers/file_storage_controller.py @@ -14,6 +14,9 @@ Routes: import os import uuid from typing import Any +import httpx +import mimetypes +from urllib.parse import urlparse, unquote from fastapi import APIRouter, Depends, File, HTTPException, Request, UploadFile, status from fastapi.responses import FileResponse, RedirectResponse @@ -91,7 +94,7 @@ async def upload_file( if file_size > settings.MAX_FILE_SIZE: raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, + status_code=status.HTTP_413_CONTENT_TOO_LARGE, detail=f"The file size exceeds the {settings.MAX_FILE_SIZE} byte limit" ) @@ -172,7 +175,6 @@ async def upload_file_with_share_token( # Get share and release info from share_token service = ReleaseShareService(db) - share_info = service.get_shared_release_info(share_token=share_data.share_token) # Get share object to access app_id share = service.repo.get_by_share_token(share_data.share_token) @@ -291,6 +293,101 @@ async def upload_file_with_share_token( ) +@router.get("/files/info-by-url", response_model=ApiResponse) +async def get_file_info_by_url( + url: str, +): + """ + Get file information by network URL (no authentication required). + + Fetches file metadata from a remote URL via HTTP HEAD request. + Falls back to GET request if HEAD is not supported. + Returns file type, name, and size. + + Args: + url: The network URL of the file. + + Returns: + ApiResponse with file information. + """ + api_logger.info(f"File info by URL request: url={url}") + + try: + async with httpx.AsyncClient(timeout=10.0) as client: + # Try HEAD request first + response = await client.head(url, follow_redirects=True) + + # If HEAD fails, try GET request (some servers don't support HEAD) + if response.status_code != 200: + api_logger.info(f"HEAD request failed with {response.status_code}, trying GET request") + response = await client.get(url, follow_redirects=True) + + if response.status_code != 200: + api_logger.error(f"Failed to fetch file info: HTTP {response.status_code}") + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"Unable to access file: HTTP {response.status_code}" + ) + + # Get file size from Content-Length header or actual content + file_size = response.headers.get("Content-Length") + if file_size: + file_size = int(file_size) + elif hasattr(response, 'content'): + file_size = len(response.content) + else: + file_size = None + + # Get content type from Content-Type header + content_type = response.headers.get("Content-Type", "application/octet-stream") + # Remove charset and other parameters from content type + content_type = content_type.split(';')[0].strip() + + # Extract filename from Content-Disposition or URL + file_name = None + content_disposition = response.headers.get("Content-Disposition") + if content_disposition and "filename=" in content_disposition: + parts = content_disposition.split("filename=") + if len(parts) > 1: + file_name = parts[1].strip('"').strip("'") + + if not file_name: + parsed_url = urlparse(url) + file_name = unquote(os.path.basename(parsed_url.path)) or "unknown" + + # Extract file extension from filename + _, file_ext = os.path.splitext(file_name) + + # If no extension found, infer from content type + if not file_ext: + ext = mimetypes.guess_extension(content_type) + if ext: + file_ext = ext + file_name = f"{file_name}{file_ext}" + + api_logger.info(f"File info retrieved: name={file_name}, size={file_size}, type={content_type}") + + return success( + data={ + "url": url, + "file_name": file_name, + "file_ext": file_ext.lower() if file_ext else "", + "file_size": file_size, + "content_type": content_type, + }, + msg="File information retrieved successfully" + ) + + except HTTPException: + raise + except Exception as e: + api_logger.error(f"Unexpected error: {e}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Failed to retrieve file information: {str(e)}" + ) + + @router.get("/files/{file_id}", response_model=Any) async def download_file( request: Request, @@ -499,6 +596,51 @@ async def get_file_url( ) +@router.get("/files/{file_id}/public-url", response_model=ApiResponse) +async def get_permanent_file_url( + file_id: uuid.UUID, + db: Session = Depends(get_db), + storage_service: FileStorageService = Depends(get_file_storage_service), +): + """ + 获取文件的永久公开 URL(无过期时间)。 + + - 本地存储:返回 API 永久访问地址(基于 FILE_LOCAL_SERVER_URL 配置) + - 远程存储(OSS/S3):返回 bucket 公读地址(需 bucket 已配置公共读权限) + """ + file_metadata = db.query(FileMetadata).filter(FileMetadata.id == file_id).first() + if not file_metadata: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="The file does not exist") + + if file_metadata.status != "completed": + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, + detail=f"File upload not completed, status: {file_metadata.status}") + + file_key = file_metadata.file_key + storage = storage_service.storage + + try: + if isinstance(storage, LocalStorage): + url = f"{settings.FILE_LOCAL_SERVER_URL}/storage/permanent/{file_id}" + else: + url = await storage.get_permanent_url(file_key) + if not url: + raise HTTPException(status_code=status.HTTP_501_NOT_IMPLEMENTED, + detail="Permanent URL not supported for current storage backend") + + api_logger.info(f"Generated permanent URL: file_id={file_id}") + return success( + data={"url": url, "expires_in": None, "permanent": True, "file_name": file_metadata.file_name}, + msg="Permanent file URL generated successfully" + ) + except HTTPException: + raise + except Exception as e: + api_logger.error(f"Failed to generate permanent URL: {e}") + raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Failed to generate permanent URL: {str(e)}") + + @router.get("/public/{file_id}", response_model=Any) async def public_download_file( request: Request, @@ -653,3 +795,44 @@ async def permanent_download_file( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Failed to retrieve file: {str(e)}" ) + + +@router.get("/files/{file_id}/status", response_model=ApiResponse) +async def get_file_status( + file_id: uuid.UUID, + db: Session = Depends(get_db), +): + """ + Get file upload/processing status (no authentication required). + + This endpoint is used to check if a file (e.g., TTS audio) is ready. + Returns status: pending, completed, or failed. + + Args: + file_id: The UUID of the file. + db: Database session. + + Returns: + ApiResponse with file status and metadata. + """ + api_logger.info(f"File status request: file_id={file_id}") + + # Query file metadata from database + file_metadata = db.query(FileMetadata).filter(FileMetadata.id == file_id).first() + if not file_metadata: + api_logger.warning(f"File not found in database: file_id={file_id}") + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="The file does not exist" + ) + + return success( + data={ + "file_id": str(file_id), + "status": file_metadata.status, + "file_name": file_metadata.file_name, + "file_size": file_metadata.file_size, + "content_type": file_metadata.content_type, + }, + msg="File status retrieved successfully" + ) diff --git a/api/app/services/app_chat_service.py b/api/app/services/app_chat_service.py index 604514b4..94a6dddb 100644 --- a/api/app/services/app_chat_service.py +++ b/api/app/services/app_chat_service.py @@ -119,14 +119,12 @@ class AppChatService: ) # 加载历史消息 - messages = self.conversation_service.get_messages( + history = self.conversation_service.get_conversation_history( conversation_id=conversation_id, - limit=10 + max_history=10, + current_provider=api_key_obj.provider, + current_is_omni=api_key_obj.is_omni ) - history = [ - {"role": msg.role, "content": msg.content} - for msg in messages - ] # 处理多模态文件 processed_files = None @@ -180,7 +178,8 @@ class AppChatService: # 构建用户消息内容(含多模态文件) human_meta = { - "files": [] + "files": [], + "history_files": {} } assistant_meta = { "model": api_key_obj.model_name, @@ -195,6 +194,13 @@ class AppChatService: "url": f.url }) + if processed_files: + human_meta["history_files"] = { + "content": processed_files, + "provider": api_key_obj.provider, + "is_omni": api_key_obj.is_omni + } + # 保存消息 if audio_url: assistant_meta["audio_url"] = audio_url @@ -225,6 +231,7 @@ class AppChatService: "suggested_questions": suggested_questions, "citations": self.agent_service._filter_citations(features_config, result.get("citations", [])), "audio_url": audio_url, + "audio_status": "pending" } async def agnet_chat_stream( @@ -314,17 +321,12 @@ class AppChatService: ) # 加载历史消息 - history = [] - memory_config = {"enabled": True, 'max_history': 10} - if memory_config.get("enabled"): - messages = self.conversation_service.get_messages( - conversation_id=conversation_id, - limit=memory_config.get("max_history", 10) - ) - history = [ - {"role": msg.role, "content": msg.content} - for msg in messages - ] + history = self.conversation_service.get_conversation_history( + conversation_id=conversation_id, + max_history=10, + current_provider=api_key_obj.provider, + current_is_omni=api_key_obj.is_omni + ) # 处理多模态文件 processed_files = None @@ -347,8 +349,14 @@ class AppChatService: total_tokens = 0 text_queue: asyncio.Queue = asyncio.Queue() + api_key_config = { + "model_name": api_key_obj.model_name, + "api_key": api_key_obj.api_key, + "api_base": api_key_obj.api_base, + "provider": api_key_obj.provider, + } stream_audio_url, tts_task = await self.agent_service._generate_tts_streaming( - features_config, api_key_obj, + features_config, api_key_config, text_queue=text_queue, tenant_id=tenant_id, workspace_id=workspace_id ) @@ -378,7 +386,7 @@ class AppChatService: elapsed_time = time.time() - start_time ModelApiKeyService.record_api_key_usage(self.db, api_key_obj.id) - # 发送结束事件(包含 suggested_questions、tts、citations) + # 发送结束事件(包含 suggested_questions、tts、audio_status、citations) end_data: dict = {"elapsed_time": elapsed_time, "message_length": len(full_content), "error": None} sq_config = features_config.get("suggested_questions_after_answer", {}) if isinstance(sq_config, dict) and sq_config.get("enabled"): @@ -388,11 +396,23 @@ class AppChatService: "api_base": api_key_obj.api_base}, {} ) end_data["audio_url"] = stream_audio_url + # 检查TTS是否已完成(非阻塞,不取消任务) + audio_status = "pending" + if tts_task is not None and tts_task.done(): + # 任务已完成,检查是否有异常 + try: + tts_task.result() + audio_status = "completed" + except Exception as e: + logger.warning(f"TTS任务异常: {e}") + audio_status = "failed" + end_data["audio_status"] = audio_status if stream_audio_url else None end_data["citations"] = self.agent_service._filter_citations(features_config, []) # 保存消息 human_meta = { - "files":[] + "files":[], + "history_files": {} } assistant_meta = { "model": api_key_obj.model_name, @@ -402,11 +422,16 @@ class AppChatService: if files: for f in files: - # url = await MultimodalService(self.db).get_file_url(f) human_meta["files"].append({ "type": f.type, "url": f.url }) + if processed_files: + human_meta["history_files"] = { + "content": processed_files, + "provider": api_key_obj.provider, + "is_omni": api_key_obj.is_omni + } if stream_audio_url: assistant_meta["audio_url"] = stream_audio_url diff --git a/api/app/services/conversation_service.py b/api/app/services/conversation_service.py index aff5f533..d7bb3595 100644 --- a/api/app/services/conversation_service.py +++ b/api/app/services/conversation_service.py @@ -119,25 +119,27 @@ class ConversationService: def get_user_conversations( self, - user_id: uuid.UUID - ) -> list[Conversation]: + user_id: uuid.UUID, + page: int = 1, + page_size: int = 20 + ) -> tuple[list[Conversation], int]: """ - Retrieve recent conversations for a specific user - - This method delegates persistence logic to the repository layer and - applies service-level defaults (e.g. recent conversation limit). + Retrieve recent conversations for a specific user with pagination. Args: user_id (uuid.UUID): Unique identifier of the user. + page (int): Page number (1-based). Defaults to 1. + page_size (int): Number of items per page. Defaults to 20. Returns: - list[Conversation]: A list of recent conversation entities. + tuple[list[Conversation], int]: A list of recent conversation entities and total count. """ - conversations = self.conversation_repo.get_conversation_by_user_id( + conversations, total = self.conversation_repo.get_conversation_by_user_id( user_id, - limit=10 + page=page, + page_size=page_size ) - return conversations + return conversations, total def list_conversations( self, @@ -270,7 +272,9 @@ class ConversationService: def get_conversation_history( self, conversation_id: uuid.UUID, - max_history: Optional[int] = None + max_history: Optional[int] = None, + current_provider: Optional[str] = None, + current_is_omni: Optional[bool] = None ) -> List[dict]: """ Retrieve historical conversation messages formatted as dictionaries. @@ -278,6 +282,8 @@ class ConversationService: Args: conversation_id (uuid.UUID): Conversation UUID. max_history (Optional[int]): Maximum number of messages to retrieve. + current_provider (Optional[str]): Current provider for file handling. + current_is_omni (Optional[bool]): Current omni flag for file handling. Returns: List[dict]: List of message dictionaries with keys 'role' and 'content'. @@ -287,14 +293,30 @@ class ConversationService: limit=max_history ) - # 转换为字典格式 - history = [ - { + history = [] + for msg in messages: + msg_dict = { "role": msg.role, - "content": msg.content + "content": [{"type": "text", "text": msg.content}] } - for msg in messages - ] + + # 处理用户消息中的多模态文件 + if msg.role == "user" and msg.meta_data: + history_files = msg.meta_data.get("history_files", {}) + + if history_files and current_provider and current_is_omni is not None: + # 检查是否需要重新处理文件 + stored_provider = history_files.get("provider") + stored_is_omni = history_files.get("is_omni") + + # 如果provider或is_omni不匹配,需要重新处理 + if stored_provider != current_provider or stored_is_omni != current_is_omni: + continue + + # provider和is_omni匹配,直接使用存储的内容 + msg_dict["content"].extend(history_files.get("content")) + + history.append(msg_dict) return history @@ -510,6 +532,7 @@ class ConversationService: provider = api_config.provider api_key = api_config.api_key api_base = api_config.api_base + is_omni = api_config.is_omni model_type = config.type llm = RedBearLLM( @@ -517,14 +540,17 @@ class ConversationService: model_name=model_name, provider=provider, api_key=api_key, - base_url=api_base + base_url=api_base, + is_omni=is_omni ), type=ModelType(model_type) ) conversation_messages = self.get_conversation_history( conversation_id=conversation_id, - max_history=20 + max_history=20, + current_provider=provider, + current_is_omni=is_omni ) if len(conversation_messages) == 0: return ConversationOut( diff --git a/api/app/services/draft_run_service.py b/api/app/services/draft_run_service.py index ba41d323..13243a92 100644 --- a/api/app/services/draft_run_service.py +++ b/api/app/services/draft_run_service.py @@ -582,7 +582,9 @@ class AgentRunService: # 6. 加载历史消息 history = await self._load_conversation_history( conversation_id=conversation_id, - max_history=10 + max_history=10, + current_provider=api_key_config.get("provider"), + current_is_omni=api_key_config.get("is_omni", False) ) # 6. 处理多模态文件 @@ -659,7 +661,10 @@ class AgentRunService: }) }, files=files, - audio_url=audio_url + processed_files=processed_files, + audio_url=audio_url, + provider=api_key_config.get("provider"), + is_omni=api_key_config.get("is_omni", False) ) response = { @@ -676,6 +681,7 @@ class AgentRunService: ) if not sub_agent else [], "citations": self._filter_citations(features_config, result.get("citations", [])), "audio_url": audio_url, + "audio_status": "pending" } logger.info( @@ -818,7 +824,9 @@ class AgentRunService: # 6. 加载历史消息 history = await self._load_conversation_history( conversation_id=conversation_id, - max_history=memory_config.get("max_history", 10) + max_history=memory_config.get("max_history", 10), + current_provider=api_key_config.get("provider"), + current_is_omni=api_key_config.get("is_omni", False) ) # 6. 处理多模态文件 @@ -905,10 +913,13 @@ class AgentRunService: "usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": total_tokens} }, files=files, - audio_url=stream_audio_url + processed_files=processed_files, + audio_url=stream_audio_url, + provider=api_key_config.get("provider"), + is_omni=api_key_config.get("is_omni", False) ) - # 12. 发送结束事件(包含 suggested_questions 和 tts) + # 12. 发送结束事件(包含 suggested_questions、audio_url 和 audio_status) end_data: Dict[str, Any] = { "conversation_id": conversation_id, "elapsed_time": elapsed_time, @@ -919,6 +930,17 @@ class AgentRunService: features_config, full_content, api_key_config, effective_params ) end_data["audio_url"] = stream_audio_url + # 检查TTS是否已完成(非阻塞,不取消任务) + audio_status = "pending" + if tts_task is not None and tts_task.done(): + # 任务已完成,检查是否有异常 + try: + tts_task.result() + audio_status = "completed" + except Exception as e: + logger.warning(f"TTS任务异常: {e}") + audio_status = "failed" + end_data["audio_status"] = audio_status if stream_audio_url else None end_data["citations"] = self._filter_citations(features_config, []) yield self._format_sse_event("end", end_data) @@ -1115,13 +1137,17 @@ class AgentRunService: async def _load_conversation_history( self, conversation_id: str, - max_history: int = 10 + max_history: int = 10, + current_provider: Optional[str] = None, + current_is_omni: Optional[bool] = None ) -> List[Dict[str, str]]: - """加载会话历史消息 + """加载会话历史消息,并根据当前模型配置处理多模态文件 Args: conversation_id: 会话ID max_history: 最大历史消息数量 + current_provider: 当前模型的provider + current_is_omni: 当前模型的is_omni Returns: List[Dict]: 历史消息列表 @@ -1131,7 +1157,9 @@ class AgentRunService: conversation_service = ConversationService(self.db) history = conversation_service.get_conversation_history( conversation_id=uuid.UUID(conversation_id), - max_history=max_history + max_history=max_history, + current_provider=current_provider, + current_is_omni=current_is_omni ) logger.debug( @@ -1159,7 +1187,10 @@ class AgentRunService: app_id: Optional[uuid.UUID] = None, user_id: Optional[str] = None, files: Optional[List[FileInput]] = None, - audio_url: Optional[str] = None + processed_files: Optional[List[Dict[str, Any]]] = None, + audio_url: Optional[str] = None, + provider: Optional[str] = None, + is_omni: Optional[bool] = None ) -> None: """保存会话消息(会话已通过 _ensure_conversation 确保存在) @@ -1170,6 +1201,11 @@ class AgentRunService: app_id: 应用ID(未使用,保留用于兼容性) user_id: 用户ID(未使用,保留用于兼容性) meta_data: token消耗 + files: 原始文件输入 + processed_files: 处理后的文件 + audio_url: 音频URL + provider: 模型供应商 + is_omni: 是否为全模态模型 """ try: from app.services.conversation_service import ConversationService @@ -1179,15 +1215,24 @@ class AgentRunService: # 保存消息(会话已经存在) human_meta = { - "files": [] + "files": [], + "history_files": {} } if files: for f in files: - # url = await MultimodalService(self.db).get_file_url(f) human_meta["files"].append({ "type": f.type, "url": f.url }) + + # 保存 history_files,包含 provider 和 is_omni 信息 + if processed_files: + human_meta["history_files"] = { + "content": processed_files, + "provider": provider, + "is_omni": is_omni + } + # 保存用户消息 conversation_service.add_message( conversation_id=conv_uuid, @@ -1413,8 +1458,9 @@ class AgentRunService: workspace_id: Optional[uuid.UUID] = None, ) -> tuple[Optional[str], Optional[asyncio.Task]]: """文本流式输入并行合成音频。 - 返回 (audio_url, task),audio_url 立即可用,task 完成后文件内容就绪。 + 返回 (audio_url, task),audio_url 立即可用(pending状态),task 完成后文件内容就绪。 调用方向 text_queue put 文本 chunk,结束时 put None。 + 前端可通过 GET /storage/files/{file_id}/status 轮询检查音频是否就绪。 """ tts_config = features_config.get("text_to_speech", {}) if not isinstance(tts_config, dict) or not tts_config.get("enabled"): @@ -1801,6 +1847,7 @@ class AgentRunService: ), "cost_estimate": self._estimate_cost(usage, model_info["model_config"]), "audio_url": result.get("audio_url"), + "audio_status": result.get("audio_status"), "citations": result.get("citations", []), "suggested_questions": result.get("suggested_questions", []), "error": None @@ -1878,6 +1925,7 @@ class AgentRunService: "results": [{ **r, "audio_url": r.get("audio_url"), + "audio_status": r.get("audio_status"), "citations": r.get("citations", []), "suggested_questions": r.get("suggested_questions", []), } for r in results], @@ -2009,6 +2057,7 @@ class AgentRunService: full_content = "" returned_conversation_id = model_conversation_id audio_url = None + audio_status = None citations = [] suggested_questions = [] @@ -2067,6 +2116,7 @@ class AgentRunService: # 从 end 事件中提取 features 输出字段 if event_type == "end" and event_data: audio_url = event_data.get("audio_url") + audio_status = event_data.get("audio_status") citations = event_data.get("citations", []) suggested_questions = event_data.get("suggested_questions", []) @@ -2096,6 +2146,7 @@ class AgentRunService: "message": full_content, "elapsed_time": elapsed, "audio_url": audio_url, + "audio_status": audio_status, "citations": citations, "suggested_questions": suggested_questions, "error": None @@ -2110,6 +2161,7 @@ class AgentRunService: "elapsed_time": elapsed, "message_length": len(full_content), "audio_url": audio_url, + "audio_status": audio_status, "citations": citations, "suggested_questions": suggested_questions, "timestamp": time.time() @@ -2246,6 +2298,7 @@ class AgentRunService: "message": r.get("message"), "elapsed_time": r.get("elapsed_time", 0), "audio_url": r.get("audio_url"), + "audio_status": r.get("audio_status"), "citations": r.get("citations", []), "suggested_questions": r.get("suggested_questions", []), "error": r.get("error")