feat(app):
1. Handling the storage of multimodal messages and adapting to the loading of historical messages for multi-round conversations; 2. Obtain the interface for retrieving the voice status of the reply; 3. File Information Retrieval Interface
This commit is contained in:
@@ -14,6 +14,9 @@ Routes:
|
|||||||
import os
|
import os
|
||||||
import uuid
|
import uuid
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
import httpx
|
||||||
|
import mimetypes
|
||||||
|
from urllib.parse import urlparse, unquote
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, File, HTTPException, Request, UploadFile, status
|
from fastapi import APIRouter, Depends, File, HTTPException, Request, UploadFile, status
|
||||||
from fastapi.responses import FileResponse, RedirectResponse
|
from fastapi.responses import FileResponse, RedirectResponse
|
||||||
@@ -91,7 +94,7 @@ async def upload_file(
|
|||||||
|
|
||||||
if file_size > settings.MAX_FILE_SIZE:
|
if file_size > settings.MAX_FILE_SIZE:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_400_BAD_REQUEST,
|
status_code=status.HTTP_413_CONTENT_TOO_LARGE,
|
||||||
detail=f"The file size exceeds the {settings.MAX_FILE_SIZE} byte limit"
|
detail=f"The file size exceeds the {settings.MAX_FILE_SIZE} byte limit"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -172,7 +175,6 @@ async def upload_file_with_share_token(
|
|||||||
|
|
||||||
# Get share and release info from share_token
|
# Get share and release info from share_token
|
||||||
service = ReleaseShareService(db)
|
service = ReleaseShareService(db)
|
||||||
share_info = service.get_shared_release_info(share_token=share_data.share_token)
|
|
||||||
|
|
||||||
# Get share object to access app_id
|
# Get share object to access app_id
|
||||||
share = service.repo.get_by_share_token(share_data.share_token)
|
share = service.repo.get_by_share_token(share_data.share_token)
|
||||||
@@ -291,6 +293,101 @@ async def upload_file_with_share_token(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/files/info-by-url", response_model=ApiResponse)
|
||||||
|
async def get_file_info_by_url(
|
||||||
|
url: str,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Get file information by network URL (no authentication required).
|
||||||
|
|
||||||
|
Fetches file metadata from a remote URL via HTTP HEAD request.
|
||||||
|
Falls back to GET request if HEAD is not supported.
|
||||||
|
Returns file type, name, and size.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
url: The network URL of the file.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ApiResponse with file information.
|
||||||
|
"""
|
||||||
|
api_logger.info(f"File info by URL request: url={url}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
async with httpx.AsyncClient(timeout=10.0) as client:
|
||||||
|
# Try HEAD request first
|
||||||
|
response = await client.head(url, follow_redirects=True)
|
||||||
|
|
||||||
|
# If HEAD fails, try GET request (some servers don't support HEAD)
|
||||||
|
if response.status_code != 200:
|
||||||
|
api_logger.info(f"HEAD request failed with {response.status_code}, trying GET request")
|
||||||
|
response = await client.get(url, follow_redirects=True)
|
||||||
|
|
||||||
|
if response.status_code != 200:
|
||||||
|
api_logger.error(f"Failed to fetch file info: HTTP {response.status_code}")
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail=f"Unable to access file: HTTP {response.status_code}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get file size from Content-Length header or actual content
|
||||||
|
file_size = response.headers.get("Content-Length")
|
||||||
|
if file_size:
|
||||||
|
file_size = int(file_size)
|
||||||
|
elif hasattr(response, 'content'):
|
||||||
|
file_size = len(response.content)
|
||||||
|
else:
|
||||||
|
file_size = None
|
||||||
|
|
||||||
|
# Get content type from Content-Type header
|
||||||
|
content_type = response.headers.get("Content-Type", "application/octet-stream")
|
||||||
|
# Remove charset and other parameters from content type
|
||||||
|
content_type = content_type.split(';')[0].strip()
|
||||||
|
|
||||||
|
# Extract filename from Content-Disposition or URL
|
||||||
|
file_name = None
|
||||||
|
content_disposition = response.headers.get("Content-Disposition")
|
||||||
|
if content_disposition and "filename=" in content_disposition:
|
||||||
|
parts = content_disposition.split("filename=")
|
||||||
|
if len(parts) > 1:
|
||||||
|
file_name = parts[1].strip('"').strip("'")
|
||||||
|
|
||||||
|
if not file_name:
|
||||||
|
parsed_url = urlparse(url)
|
||||||
|
file_name = unquote(os.path.basename(parsed_url.path)) or "unknown"
|
||||||
|
|
||||||
|
# Extract file extension from filename
|
||||||
|
_, file_ext = os.path.splitext(file_name)
|
||||||
|
|
||||||
|
# If no extension found, infer from content type
|
||||||
|
if not file_ext:
|
||||||
|
ext = mimetypes.guess_extension(content_type)
|
||||||
|
if ext:
|
||||||
|
file_ext = ext
|
||||||
|
file_name = f"{file_name}{file_ext}"
|
||||||
|
|
||||||
|
api_logger.info(f"File info retrieved: name={file_name}, size={file_size}, type={content_type}")
|
||||||
|
|
||||||
|
return success(
|
||||||
|
data={
|
||||||
|
"url": url,
|
||||||
|
"file_name": file_name,
|
||||||
|
"file_ext": file_ext.lower() if file_ext else "",
|
||||||
|
"file_size": file_size,
|
||||||
|
"content_type": content_type,
|
||||||
|
},
|
||||||
|
msg="File information retrieved successfully"
|
||||||
|
)
|
||||||
|
|
||||||
|
except HTTPException:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
api_logger.error(f"Unexpected error: {e}")
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
|
detail=f"Failed to retrieve file information: {str(e)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@router.get("/files/{file_id}", response_model=Any)
|
@router.get("/files/{file_id}", response_model=Any)
|
||||||
async def download_file(
|
async def download_file(
|
||||||
request: Request,
|
request: Request,
|
||||||
@@ -499,6 +596,51 @@ async def get_file_url(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/files/{file_id}/public-url", response_model=ApiResponse)
|
||||||
|
async def get_permanent_file_url(
|
||||||
|
file_id: uuid.UUID,
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
storage_service: FileStorageService = Depends(get_file_storage_service),
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
获取文件的永久公开 URL(无过期时间)。
|
||||||
|
|
||||||
|
- 本地存储:返回 API 永久访问地址(基于 FILE_LOCAL_SERVER_URL 配置)
|
||||||
|
- 远程存储(OSS/S3):返回 bucket 公读地址(需 bucket 已配置公共读权限)
|
||||||
|
"""
|
||||||
|
file_metadata = db.query(FileMetadata).filter(FileMetadata.id == file_id).first()
|
||||||
|
if not file_metadata:
|
||||||
|
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="The file does not exist")
|
||||||
|
|
||||||
|
if file_metadata.status != "completed":
|
||||||
|
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail=f"File upload not completed, status: {file_metadata.status}")
|
||||||
|
|
||||||
|
file_key = file_metadata.file_key
|
||||||
|
storage = storage_service.storage
|
||||||
|
|
||||||
|
try:
|
||||||
|
if isinstance(storage, LocalStorage):
|
||||||
|
url = f"{settings.FILE_LOCAL_SERVER_URL}/storage/permanent/{file_id}"
|
||||||
|
else:
|
||||||
|
url = await storage.get_permanent_url(file_key)
|
||||||
|
if not url:
|
||||||
|
raise HTTPException(status_code=status.HTTP_501_NOT_IMPLEMENTED,
|
||||||
|
detail="Permanent URL not supported for current storage backend")
|
||||||
|
|
||||||
|
api_logger.info(f"Generated permanent URL: file_id={file_id}")
|
||||||
|
return success(
|
||||||
|
data={"url": url, "expires_in": None, "permanent": True, "file_name": file_metadata.file_name},
|
||||||
|
msg="Permanent file URL generated successfully"
|
||||||
|
)
|
||||||
|
except HTTPException:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
api_logger.error(f"Failed to generate permanent URL: {e}")
|
||||||
|
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
|
detail=f"Failed to generate permanent URL: {str(e)}")
|
||||||
|
|
||||||
|
|
||||||
@router.get("/public/{file_id}", response_model=Any)
|
@router.get("/public/{file_id}", response_model=Any)
|
||||||
async def public_download_file(
|
async def public_download_file(
|
||||||
request: Request,
|
request: Request,
|
||||||
@@ -653,3 +795,44 @@ async def permanent_download_file(
|
|||||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
detail=f"Failed to retrieve file: {str(e)}"
|
detail=f"Failed to retrieve file: {str(e)}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/files/{file_id}/status", response_model=ApiResponse)
|
||||||
|
async def get_file_status(
|
||||||
|
file_id: uuid.UUID,
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Get file upload/processing status (no authentication required).
|
||||||
|
|
||||||
|
This endpoint is used to check if a file (e.g., TTS audio) is ready.
|
||||||
|
Returns status: pending, completed, or failed.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file_id: The UUID of the file.
|
||||||
|
db: Database session.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ApiResponse with file status and metadata.
|
||||||
|
"""
|
||||||
|
api_logger.info(f"File status request: file_id={file_id}")
|
||||||
|
|
||||||
|
# Query file metadata from database
|
||||||
|
file_metadata = db.query(FileMetadata).filter(FileMetadata.id == file_id).first()
|
||||||
|
if not file_metadata:
|
||||||
|
api_logger.warning(f"File not found in database: file_id={file_id}")
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
|
detail="The file does not exist"
|
||||||
|
)
|
||||||
|
|
||||||
|
return success(
|
||||||
|
data={
|
||||||
|
"file_id": str(file_id),
|
||||||
|
"status": file_metadata.status,
|
||||||
|
"file_name": file_metadata.file_name,
|
||||||
|
"file_size": file_metadata.file_size,
|
||||||
|
"content_type": file_metadata.content_type,
|
||||||
|
},
|
||||||
|
msg="File status retrieved successfully"
|
||||||
|
)
|
||||||
|
|||||||
@@ -119,14 +119,12 @@ class AppChatService:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# 加载历史消息
|
# 加载历史消息
|
||||||
messages = self.conversation_service.get_messages(
|
history = self.conversation_service.get_conversation_history(
|
||||||
conversation_id=conversation_id,
|
conversation_id=conversation_id,
|
||||||
limit=10
|
max_history=10,
|
||||||
|
current_provider=api_key_obj.provider,
|
||||||
|
current_is_omni=api_key_obj.is_omni
|
||||||
)
|
)
|
||||||
history = [
|
|
||||||
{"role": msg.role, "content": msg.content}
|
|
||||||
for msg in messages
|
|
||||||
]
|
|
||||||
|
|
||||||
# 处理多模态文件
|
# 处理多模态文件
|
||||||
processed_files = None
|
processed_files = None
|
||||||
@@ -180,7 +178,8 @@ class AppChatService:
|
|||||||
|
|
||||||
# 构建用户消息内容(含多模态文件)
|
# 构建用户消息内容(含多模态文件)
|
||||||
human_meta = {
|
human_meta = {
|
||||||
"files": []
|
"files": [],
|
||||||
|
"history_files": {}
|
||||||
}
|
}
|
||||||
assistant_meta = {
|
assistant_meta = {
|
||||||
"model": api_key_obj.model_name,
|
"model": api_key_obj.model_name,
|
||||||
@@ -195,6 +194,13 @@ class AppChatService:
|
|||||||
"url": f.url
|
"url": f.url
|
||||||
})
|
})
|
||||||
|
|
||||||
|
if processed_files:
|
||||||
|
human_meta["history_files"] = {
|
||||||
|
"content": processed_files,
|
||||||
|
"provider": api_key_obj.provider,
|
||||||
|
"is_omni": api_key_obj.is_omni
|
||||||
|
}
|
||||||
|
|
||||||
# 保存消息
|
# 保存消息
|
||||||
if audio_url:
|
if audio_url:
|
||||||
assistant_meta["audio_url"] = audio_url
|
assistant_meta["audio_url"] = audio_url
|
||||||
@@ -225,6 +231,7 @@ class AppChatService:
|
|||||||
"suggested_questions": suggested_questions,
|
"suggested_questions": suggested_questions,
|
||||||
"citations": self.agent_service._filter_citations(features_config, result.get("citations", [])),
|
"citations": self.agent_service._filter_citations(features_config, result.get("citations", [])),
|
||||||
"audio_url": audio_url,
|
"audio_url": audio_url,
|
||||||
|
"audio_status": "pending"
|
||||||
}
|
}
|
||||||
|
|
||||||
async def agnet_chat_stream(
|
async def agnet_chat_stream(
|
||||||
@@ -314,17 +321,12 @@ class AppChatService:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# 加载历史消息
|
# 加载历史消息
|
||||||
history = []
|
history = self.conversation_service.get_conversation_history(
|
||||||
memory_config = {"enabled": True, 'max_history': 10}
|
conversation_id=conversation_id,
|
||||||
if memory_config.get("enabled"):
|
max_history=10,
|
||||||
messages = self.conversation_service.get_messages(
|
current_provider=api_key_obj.provider,
|
||||||
conversation_id=conversation_id,
|
current_is_omni=api_key_obj.is_omni
|
||||||
limit=memory_config.get("max_history", 10)
|
)
|
||||||
)
|
|
||||||
history = [
|
|
||||||
{"role": msg.role, "content": msg.content}
|
|
||||||
for msg in messages
|
|
||||||
]
|
|
||||||
|
|
||||||
# 处理多模态文件
|
# 处理多模态文件
|
||||||
processed_files = None
|
processed_files = None
|
||||||
@@ -347,8 +349,14 @@ class AppChatService:
|
|||||||
total_tokens = 0
|
total_tokens = 0
|
||||||
|
|
||||||
text_queue: asyncio.Queue = asyncio.Queue()
|
text_queue: asyncio.Queue = asyncio.Queue()
|
||||||
|
api_key_config = {
|
||||||
|
"model_name": api_key_obj.model_name,
|
||||||
|
"api_key": api_key_obj.api_key,
|
||||||
|
"api_base": api_key_obj.api_base,
|
||||||
|
"provider": api_key_obj.provider,
|
||||||
|
}
|
||||||
stream_audio_url, tts_task = await self.agent_service._generate_tts_streaming(
|
stream_audio_url, tts_task = await self.agent_service._generate_tts_streaming(
|
||||||
features_config, api_key_obj,
|
features_config, api_key_config,
|
||||||
text_queue=text_queue,
|
text_queue=text_queue,
|
||||||
tenant_id=tenant_id, workspace_id=workspace_id
|
tenant_id=tenant_id, workspace_id=workspace_id
|
||||||
)
|
)
|
||||||
@@ -378,7 +386,7 @@ class AppChatService:
|
|||||||
elapsed_time = time.time() - start_time
|
elapsed_time = time.time() - start_time
|
||||||
ModelApiKeyService.record_api_key_usage(self.db, api_key_obj.id)
|
ModelApiKeyService.record_api_key_usage(self.db, api_key_obj.id)
|
||||||
|
|
||||||
# 发送结束事件(包含 suggested_questions、tts、citations)
|
# 发送结束事件(包含 suggested_questions、tts、audio_status、citations)
|
||||||
end_data: dict = {"elapsed_time": elapsed_time, "message_length": len(full_content), "error": None}
|
end_data: dict = {"elapsed_time": elapsed_time, "message_length": len(full_content), "error": None}
|
||||||
sq_config = features_config.get("suggested_questions_after_answer", {})
|
sq_config = features_config.get("suggested_questions_after_answer", {})
|
||||||
if isinstance(sq_config, dict) and sq_config.get("enabled"):
|
if isinstance(sq_config, dict) and sq_config.get("enabled"):
|
||||||
@@ -388,11 +396,23 @@ class AppChatService:
|
|||||||
"api_base": api_key_obj.api_base}, {}
|
"api_base": api_key_obj.api_base}, {}
|
||||||
)
|
)
|
||||||
end_data["audio_url"] = stream_audio_url
|
end_data["audio_url"] = stream_audio_url
|
||||||
|
# 检查TTS是否已完成(非阻塞,不取消任务)
|
||||||
|
audio_status = "pending"
|
||||||
|
if tts_task is not None and tts_task.done():
|
||||||
|
# 任务已完成,检查是否有异常
|
||||||
|
try:
|
||||||
|
tts_task.result()
|
||||||
|
audio_status = "completed"
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"TTS任务异常: {e}")
|
||||||
|
audio_status = "failed"
|
||||||
|
end_data["audio_status"] = audio_status if stream_audio_url else None
|
||||||
end_data["citations"] = self.agent_service._filter_citations(features_config, [])
|
end_data["citations"] = self.agent_service._filter_citations(features_config, [])
|
||||||
|
|
||||||
# 保存消息
|
# 保存消息
|
||||||
human_meta = {
|
human_meta = {
|
||||||
"files":[]
|
"files":[],
|
||||||
|
"history_files": {}
|
||||||
}
|
}
|
||||||
assistant_meta = {
|
assistant_meta = {
|
||||||
"model": api_key_obj.model_name,
|
"model": api_key_obj.model_name,
|
||||||
@@ -402,11 +422,16 @@ class AppChatService:
|
|||||||
|
|
||||||
if files:
|
if files:
|
||||||
for f in files:
|
for f in files:
|
||||||
# url = await MultimodalService(self.db).get_file_url(f)
|
|
||||||
human_meta["files"].append({
|
human_meta["files"].append({
|
||||||
"type": f.type,
|
"type": f.type,
|
||||||
"url": f.url
|
"url": f.url
|
||||||
})
|
})
|
||||||
|
if processed_files:
|
||||||
|
human_meta["history_files"] = {
|
||||||
|
"content": processed_files,
|
||||||
|
"provider": api_key_obj.provider,
|
||||||
|
"is_omni": api_key_obj.is_omni
|
||||||
|
}
|
||||||
|
|
||||||
if stream_audio_url:
|
if stream_audio_url:
|
||||||
assistant_meta["audio_url"] = stream_audio_url
|
assistant_meta["audio_url"] = stream_audio_url
|
||||||
|
|||||||
@@ -119,25 +119,27 @@ class ConversationService:
|
|||||||
|
|
||||||
def get_user_conversations(
|
def get_user_conversations(
|
||||||
self,
|
self,
|
||||||
user_id: uuid.UUID
|
user_id: uuid.UUID,
|
||||||
) -> list[Conversation]:
|
page: int = 1,
|
||||||
|
page_size: int = 20
|
||||||
|
) -> tuple[list[Conversation], int]:
|
||||||
"""
|
"""
|
||||||
Retrieve recent conversations for a specific user
|
Retrieve recent conversations for a specific user with pagination.
|
||||||
|
|
||||||
This method delegates persistence logic to the repository layer and
|
|
||||||
applies service-level defaults (e.g. recent conversation limit).
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
user_id (uuid.UUID): Unique identifier of the user.
|
user_id (uuid.UUID): Unique identifier of the user.
|
||||||
|
page (int): Page number (1-based). Defaults to 1.
|
||||||
|
page_size (int): Number of items per page. Defaults to 20.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
list[Conversation]: A list of recent conversation entities.
|
tuple[list[Conversation], int]: A list of recent conversation entities and total count.
|
||||||
"""
|
"""
|
||||||
conversations = self.conversation_repo.get_conversation_by_user_id(
|
conversations, total = self.conversation_repo.get_conversation_by_user_id(
|
||||||
user_id,
|
user_id,
|
||||||
limit=10
|
page=page,
|
||||||
|
page_size=page_size
|
||||||
)
|
)
|
||||||
return conversations
|
return conversations, total
|
||||||
|
|
||||||
def list_conversations(
|
def list_conversations(
|
||||||
self,
|
self,
|
||||||
@@ -270,7 +272,9 @@ class ConversationService:
|
|||||||
def get_conversation_history(
|
def get_conversation_history(
|
||||||
self,
|
self,
|
||||||
conversation_id: uuid.UUID,
|
conversation_id: uuid.UUID,
|
||||||
max_history: Optional[int] = None
|
max_history: Optional[int] = None,
|
||||||
|
current_provider: Optional[str] = None,
|
||||||
|
current_is_omni: Optional[bool] = None
|
||||||
) -> List[dict]:
|
) -> List[dict]:
|
||||||
"""
|
"""
|
||||||
Retrieve historical conversation messages formatted as dictionaries.
|
Retrieve historical conversation messages formatted as dictionaries.
|
||||||
@@ -278,6 +282,8 @@ class ConversationService:
|
|||||||
Args:
|
Args:
|
||||||
conversation_id (uuid.UUID): Conversation UUID.
|
conversation_id (uuid.UUID): Conversation UUID.
|
||||||
max_history (Optional[int]): Maximum number of messages to retrieve.
|
max_history (Optional[int]): Maximum number of messages to retrieve.
|
||||||
|
current_provider (Optional[str]): Current provider for file handling.
|
||||||
|
current_is_omni (Optional[bool]): Current omni flag for file handling.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List[dict]: List of message dictionaries with keys 'role' and 'content'.
|
List[dict]: List of message dictionaries with keys 'role' and 'content'.
|
||||||
@@ -287,14 +293,30 @@ class ConversationService:
|
|||||||
limit=max_history
|
limit=max_history
|
||||||
)
|
)
|
||||||
|
|
||||||
# 转换为字典格式
|
history = []
|
||||||
history = [
|
for msg in messages:
|
||||||
{
|
msg_dict = {
|
||||||
"role": msg.role,
|
"role": msg.role,
|
||||||
"content": msg.content
|
"content": [{"type": "text", "text": msg.content}]
|
||||||
}
|
}
|
||||||
for msg in messages
|
|
||||||
]
|
# 处理用户消息中的多模态文件
|
||||||
|
if msg.role == "user" and msg.meta_data:
|
||||||
|
history_files = msg.meta_data.get("history_files", {})
|
||||||
|
|
||||||
|
if history_files and current_provider and current_is_omni is not None:
|
||||||
|
# 检查是否需要重新处理文件
|
||||||
|
stored_provider = history_files.get("provider")
|
||||||
|
stored_is_omni = history_files.get("is_omni")
|
||||||
|
|
||||||
|
# 如果provider或is_omni不匹配,需要重新处理
|
||||||
|
if stored_provider != current_provider or stored_is_omni != current_is_omni:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# provider和is_omni匹配,直接使用存储的内容
|
||||||
|
msg_dict["content"].extend(history_files.get("content"))
|
||||||
|
|
||||||
|
history.append(msg_dict)
|
||||||
|
|
||||||
return history
|
return history
|
||||||
|
|
||||||
@@ -510,6 +532,7 @@ class ConversationService:
|
|||||||
provider = api_config.provider
|
provider = api_config.provider
|
||||||
api_key = api_config.api_key
|
api_key = api_config.api_key
|
||||||
api_base = api_config.api_base
|
api_base = api_config.api_base
|
||||||
|
is_omni = api_config.is_omni
|
||||||
model_type = config.type
|
model_type = config.type
|
||||||
|
|
||||||
llm = RedBearLLM(
|
llm = RedBearLLM(
|
||||||
@@ -517,14 +540,17 @@ class ConversationService:
|
|||||||
model_name=model_name,
|
model_name=model_name,
|
||||||
provider=provider,
|
provider=provider,
|
||||||
api_key=api_key,
|
api_key=api_key,
|
||||||
base_url=api_base
|
base_url=api_base,
|
||||||
|
is_omni=is_omni
|
||||||
),
|
),
|
||||||
type=ModelType(model_type)
|
type=ModelType(model_type)
|
||||||
)
|
)
|
||||||
|
|
||||||
conversation_messages = self.get_conversation_history(
|
conversation_messages = self.get_conversation_history(
|
||||||
conversation_id=conversation_id,
|
conversation_id=conversation_id,
|
||||||
max_history=20
|
max_history=20,
|
||||||
|
current_provider=provider,
|
||||||
|
current_is_omni=is_omni
|
||||||
)
|
)
|
||||||
if len(conversation_messages) == 0:
|
if len(conversation_messages) == 0:
|
||||||
return ConversationOut(
|
return ConversationOut(
|
||||||
|
|||||||
@@ -582,7 +582,9 @@ class AgentRunService:
|
|||||||
# 6. 加载历史消息
|
# 6. 加载历史消息
|
||||||
history = await self._load_conversation_history(
|
history = await self._load_conversation_history(
|
||||||
conversation_id=conversation_id,
|
conversation_id=conversation_id,
|
||||||
max_history=10
|
max_history=10,
|
||||||
|
current_provider=api_key_config.get("provider"),
|
||||||
|
current_is_omni=api_key_config.get("is_omni", False)
|
||||||
)
|
)
|
||||||
|
|
||||||
# 6. 处理多模态文件
|
# 6. 处理多模态文件
|
||||||
@@ -659,7 +661,10 @@ class AgentRunService:
|
|||||||
})
|
})
|
||||||
},
|
},
|
||||||
files=files,
|
files=files,
|
||||||
audio_url=audio_url
|
processed_files=processed_files,
|
||||||
|
audio_url=audio_url,
|
||||||
|
provider=api_key_config.get("provider"),
|
||||||
|
is_omni=api_key_config.get("is_omni", False)
|
||||||
)
|
)
|
||||||
|
|
||||||
response = {
|
response = {
|
||||||
@@ -676,6 +681,7 @@ class AgentRunService:
|
|||||||
) if not sub_agent else [],
|
) if not sub_agent else [],
|
||||||
"citations": self._filter_citations(features_config, result.get("citations", [])),
|
"citations": self._filter_citations(features_config, result.get("citations", [])),
|
||||||
"audio_url": audio_url,
|
"audio_url": audio_url,
|
||||||
|
"audio_status": "pending"
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
@@ -818,7 +824,9 @@ class AgentRunService:
|
|||||||
# 6. 加载历史消息
|
# 6. 加载历史消息
|
||||||
history = await self._load_conversation_history(
|
history = await self._load_conversation_history(
|
||||||
conversation_id=conversation_id,
|
conversation_id=conversation_id,
|
||||||
max_history=memory_config.get("max_history", 10)
|
max_history=memory_config.get("max_history", 10),
|
||||||
|
current_provider=api_key_config.get("provider"),
|
||||||
|
current_is_omni=api_key_config.get("is_omni", False)
|
||||||
)
|
)
|
||||||
|
|
||||||
# 6. 处理多模态文件
|
# 6. 处理多模态文件
|
||||||
@@ -905,10 +913,13 @@ class AgentRunService:
|
|||||||
"usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": total_tokens}
|
"usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": total_tokens}
|
||||||
},
|
},
|
||||||
files=files,
|
files=files,
|
||||||
audio_url=stream_audio_url
|
processed_files=processed_files,
|
||||||
|
audio_url=stream_audio_url,
|
||||||
|
provider=api_key_config.get("provider"),
|
||||||
|
is_omni=api_key_config.get("is_omni", False)
|
||||||
)
|
)
|
||||||
|
|
||||||
# 12. 发送结束事件(包含 suggested_questions 和 tts)
|
# 12. 发送结束事件(包含 suggested_questions、audio_url 和 audio_status)
|
||||||
end_data: Dict[str, Any] = {
|
end_data: Dict[str, Any] = {
|
||||||
"conversation_id": conversation_id,
|
"conversation_id": conversation_id,
|
||||||
"elapsed_time": elapsed_time,
|
"elapsed_time": elapsed_time,
|
||||||
@@ -919,6 +930,17 @@ class AgentRunService:
|
|||||||
features_config, full_content, api_key_config, effective_params
|
features_config, full_content, api_key_config, effective_params
|
||||||
)
|
)
|
||||||
end_data["audio_url"] = stream_audio_url
|
end_data["audio_url"] = stream_audio_url
|
||||||
|
# 检查TTS是否已完成(非阻塞,不取消任务)
|
||||||
|
audio_status = "pending"
|
||||||
|
if tts_task is not None and tts_task.done():
|
||||||
|
# 任务已完成,检查是否有异常
|
||||||
|
try:
|
||||||
|
tts_task.result()
|
||||||
|
audio_status = "completed"
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"TTS任务异常: {e}")
|
||||||
|
audio_status = "failed"
|
||||||
|
end_data["audio_status"] = audio_status if stream_audio_url else None
|
||||||
end_data["citations"] = self._filter_citations(features_config, [])
|
end_data["citations"] = self._filter_citations(features_config, [])
|
||||||
yield self._format_sse_event("end", end_data)
|
yield self._format_sse_event("end", end_data)
|
||||||
|
|
||||||
@@ -1115,13 +1137,17 @@ class AgentRunService:
|
|||||||
async def _load_conversation_history(
|
async def _load_conversation_history(
|
||||||
self,
|
self,
|
||||||
conversation_id: str,
|
conversation_id: str,
|
||||||
max_history: int = 10
|
max_history: int = 10,
|
||||||
|
current_provider: Optional[str] = None,
|
||||||
|
current_is_omni: Optional[bool] = None
|
||||||
) -> List[Dict[str, str]]:
|
) -> List[Dict[str, str]]:
|
||||||
"""加载会话历史消息
|
"""加载会话历史消息,并根据当前模型配置处理多模态文件
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
conversation_id: 会话ID
|
conversation_id: 会话ID
|
||||||
max_history: 最大历史消息数量
|
max_history: 最大历史消息数量
|
||||||
|
current_provider: 当前模型的provider
|
||||||
|
current_is_omni: 当前模型的is_omni
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List[Dict]: 历史消息列表
|
List[Dict]: 历史消息列表
|
||||||
@@ -1131,7 +1157,9 @@ class AgentRunService:
|
|||||||
conversation_service = ConversationService(self.db)
|
conversation_service = ConversationService(self.db)
|
||||||
history = conversation_service.get_conversation_history(
|
history = conversation_service.get_conversation_history(
|
||||||
conversation_id=uuid.UUID(conversation_id),
|
conversation_id=uuid.UUID(conversation_id),
|
||||||
max_history=max_history
|
max_history=max_history,
|
||||||
|
current_provider=current_provider,
|
||||||
|
current_is_omni=current_is_omni
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.debug(
|
logger.debug(
|
||||||
@@ -1159,7 +1187,10 @@ class AgentRunService:
|
|||||||
app_id: Optional[uuid.UUID] = None,
|
app_id: Optional[uuid.UUID] = None,
|
||||||
user_id: Optional[str] = None,
|
user_id: Optional[str] = None,
|
||||||
files: Optional[List[FileInput]] = None,
|
files: Optional[List[FileInput]] = None,
|
||||||
audio_url: Optional[str] = None
|
processed_files: Optional[List[Dict[str, Any]]] = None,
|
||||||
|
audio_url: Optional[str] = None,
|
||||||
|
provider: Optional[str] = None,
|
||||||
|
is_omni: Optional[bool] = None
|
||||||
) -> None:
|
) -> None:
|
||||||
"""保存会话消息(会话已通过 _ensure_conversation 确保存在)
|
"""保存会话消息(会话已通过 _ensure_conversation 确保存在)
|
||||||
|
|
||||||
@@ -1170,6 +1201,11 @@ class AgentRunService:
|
|||||||
app_id: 应用ID(未使用,保留用于兼容性)
|
app_id: 应用ID(未使用,保留用于兼容性)
|
||||||
user_id: 用户ID(未使用,保留用于兼容性)
|
user_id: 用户ID(未使用,保留用于兼容性)
|
||||||
meta_data: token消耗
|
meta_data: token消耗
|
||||||
|
files: 原始文件输入
|
||||||
|
processed_files: 处理后的文件
|
||||||
|
audio_url: 音频URL
|
||||||
|
provider: 模型供应商
|
||||||
|
is_omni: 是否为全模态模型
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
from app.services.conversation_service import ConversationService
|
from app.services.conversation_service import ConversationService
|
||||||
@@ -1179,15 +1215,24 @@ class AgentRunService:
|
|||||||
|
|
||||||
# 保存消息(会话已经存在)
|
# 保存消息(会话已经存在)
|
||||||
human_meta = {
|
human_meta = {
|
||||||
"files": []
|
"files": [],
|
||||||
|
"history_files": {}
|
||||||
}
|
}
|
||||||
if files:
|
if files:
|
||||||
for f in files:
|
for f in files:
|
||||||
# url = await MultimodalService(self.db).get_file_url(f)
|
|
||||||
human_meta["files"].append({
|
human_meta["files"].append({
|
||||||
"type": f.type,
|
"type": f.type,
|
||||||
"url": f.url
|
"url": f.url
|
||||||
})
|
})
|
||||||
|
|
||||||
|
# 保存 history_files,包含 provider 和 is_omni 信息
|
||||||
|
if processed_files:
|
||||||
|
human_meta["history_files"] = {
|
||||||
|
"content": processed_files,
|
||||||
|
"provider": provider,
|
||||||
|
"is_omni": is_omni
|
||||||
|
}
|
||||||
|
|
||||||
# 保存用户消息
|
# 保存用户消息
|
||||||
conversation_service.add_message(
|
conversation_service.add_message(
|
||||||
conversation_id=conv_uuid,
|
conversation_id=conv_uuid,
|
||||||
@@ -1413,8 +1458,9 @@ class AgentRunService:
|
|||||||
workspace_id: Optional[uuid.UUID] = None,
|
workspace_id: Optional[uuid.UUID] = None,
|
||||||
) -> tuple[Optional[str], Optional[asyncio.Task]]:
|
) -> tuple[Optional[str], Optional[asyncio.Task]]:
|
||||||
"""文本流式输入并行合成音频。
|
"""文本流式输入并行合成音频。
|
||||||
返回 (audio_url, task),audio_url 立即可用,task 完成后文件内容就绪。
|
返回 (audio_url, task),audio_url 立即可用(pending状态),task 完成后文件内容就绪。
|
||||||
调用方向 text_queue put 文本 chunk,结束时 put None。
|
调用方向 text_queue put 文本 chunk,结束时 put None。
|
||||||
|
前端可通过 GET /storage/files/{file_id}/status 轮询检查音频是否就绪。
|
||||||
"""
|
"""
|
||||||
tts_config = features_config.get("text_to_speech", {})
|
tts_config = features_config.get("text_to_speech", {})
|
||||||
if not isinstance(tts_config, dict) or not tts_config.get("enabled"):
|
if not isinstance(tts_config, dict) or not tts_config.get("enabled"):
|
||||||
@@ -1801,6 +1847,7 @@ class AgentRunService:
|
|||||||
),
|
),
|
||||||
"cost_estimate": self._estimate_cost(usage, model_info["model_config"]),
|
"cost_estimate": self._estimate_cost(usage, model_info["model_config"]),
|
||||||
"audio_url": result.get("audio_url"),
|
"audio_url": result.get("audio_url"),
|
||||||
|
"audio_status": result.get("audio_status"),
|
||||||
"citations": result.get("citations", []),
|
"citations": result.get("citations", []),
|
||||||
"suggested_questions": result.get("suggested_questions", []),
|
"suggested_questions": result.get("suggested_questions", []),
|
||||||
"error": None
|
"error": None
|
||||||
@@ -1878,6 +1925,7 @@ class AgentRunService:
|
|||||||
"results": [{
|
"results": [{
|
||||||
**r,
|
**r,
|
||||||
"audio_url": r.get("audio_url"),
|
"audio_url": r.get("audio_url"),
|
||||||
|
"audio_status": r.get("audio_status"),
|
||||||
"citations": r.get("citations", []),
|
"citations": r.get("citations", []),
|
||||||
"suggested_questions": r.get("suggested_questions", []),
|
"suggested_questions": r.get("suggested_questions", []),
|
||||||
} for r in results],
|
} for r in results],
|
||||||
@@ -2009,6 +2057,7 @@ class AgentRunService:
|
|||||||
full_content = ""
|
full_content = ""
|
||||||
returned_conversation_id = model_conversation_id
|
returned_conversation_id = model_conversation_id
|
||||||
audio_url = None
|
audio_url = None
|
||||||
|
audio_status = None
|
||||||
citations = []
|
citations = []
|
||||||
suggested_questions = []
|
suggested_questions = []
|
||||||
|
|
||||||
@@ -2067,6 +2116,7 @@ class AgentRunService:
|
|||||||
# 从 end 事件中提取 features 输出字段
|
# 从 end 事件中提取 features 输出字段
|
||||||
if event_type == "end" and event_data:
|
if event_type == "end" and event_data:
|
||||||
audio_url = event_data.get("audio_url")
|
audio_url = event_data.get("audio_url")
|
||||||
|
audio_status = event_data.get("audio_status")
|
||||||
citations = event_data.get("citations", [])
|
citations = event_data.get("citations", [])
|
||||||
suggested_questions = event_data.get("suggested_questions", [])
|
suggested_questions = event_data.get("suggested_questions", [])
|
||||||
|
|
||||||
@@ -2096,6 +2146,7 @@ class AgentRunService:
|
|||||||
"message": full_content,
|
"message": full_content,
|
||||||
"elapsed_time": elapsed,
|
"elapsed_time": elapsed,
|
||||||
"audio_url": audio_url,
|
"audio_url": audio_url,
|
||||||
|
"audio_status": audio_status,
|
||||||
"citations": citations,
|
"citations": citations,
|
||||||
"suggested_questions": suggested_questions,
|
"suggested_questions": suggested_questions,
|
||||||
"error": None
|
"error": None
|
||||||
@@ -2110,6 +2161,7 @@ class AgentRunService:
|
|||||||
"elapsed_time": elapsed,
|
"elapsed_time": elapsed,
|
||||||
"message_length": len(full_content),
|
"message_length": len(full_content),
|
||||||
"audio_url": audio_url,
|
"audio_url": audio_url,
|
||||||
|
"audio_status": audio_status,
|
||||||
"citations": citations,
|
"citations": citations,
|
||||||
"suggested_questions": suggested_questions,
|
"suggested_questions": suggested_questions,
|
||||||
"timestamp": time.time()
|
"timestamp": time.time()
|
||||||
@@ -2246,6 +2298,7 @@ class AgentRunService:
|
|||||||
"message": r.get("message"),
|
"message": r.get("message"),
|
||||||
"elapsed_time": r.get("elapsed_time", 0),
|
"elapsed_time": r.get("elapsed_time", 0),
|
||||||
"audio_url": r.get("audio_url"),
|
"audio_url": r.get("audio_url"),
|
||||||
|
"audio_status": r.get("audio_status"),
|
||||||
"citations": r.get("citations", []),
|
"citations": r.get("citations", []),
|
||||||
"suggested_questions": r.get("suggested_questions", []),
|
"suggested_questions": r.get("suggested_questions", []),
|
||||||
"error": r.get("error")
|
"error": r.get("error")
|
||||||
|
|||||||
Reference in New Issue
Block a user