fix(agent features):
1.Voice output is generated in a streaming manner. 2.Multimodal file storage type repair; 3.Adding features to the configuration of the sub-agents in the multi-agent system
This commit is contained in:
@@ -7,7 +7,7 @@ file operations across different storage backends.
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional
|
||||
from typing import AsyncIterator, Optional
|
||||
|
||||
|
||||
class StorageBackend(ABC):
|
||||
@@ -42,6 +42,26 @@ class StorageBackend(ABC):
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def upload_stream(
|
||||
self,
|
||||
file_key: str,
|
||||
stream: AsyncIterator[bytes],
|
||||
content_type: Optional[str] = None,
|
||||
) -> int:
|
||||
"""
|
||||
Upload a file from an async byte stream.
|
||||
|
||||
Args:
|
||||
file_key: Unique identifier for the file.
|
||||
stream: Async iterator yielding bytes chunks.
|
||||
content_type: Optional MIME type of the file.
|
||||
|
||||
Returns:
|
||||
Total bytes written.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def download(self, file_key: str) -> bytes:
|
||||
"""
|
||||
|
||||
@@ -11,6 +11,7 @@ from typing import Optional
|
||||
|
||||
import aiofiles
|
||||
import aiofiles.os
|
||||
from typing import AsyncIterator
|
||||
|
||||
from app.core.storage.base import StorageBackend
|
||||
from app.core.storage_exceptions import (
|
||||
@@ -179,6 +180,36 @@ class LocalStorage(StorageBackend):
|
||||
full_path = self._get_full_path(file_key)
|
||||
return full_path.exists()
|
||||
|
||||
async def upload_stream(
|
||||
self,
|
||||
file_key: str,
|
||||
stream: AsyncIterator[bytes],
|
||||
content_type: Optional[str] = None,
|
||||
) -> int:
|
||||
"""
|
||||
Upload a file from an async byte stream to the local file system.
|
||||
|
||||
Returns:
|
||||
Total bytes written.
|
||||
"""
|
||||
full_path = self._get_full_path(file_key)
|
||||
try:
|
||||
full_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
total = 0
|
||||
async with aiofiles.open(full_path, "wb") as f:
|
||||
async for chunk in stream:
|
||||
await f.write(chunk)
|
||||
total += len(chunk)
|
||||
logger.info(f"File stream uploaded successfully: {file_key}")
|
||||
return total
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to stream upload file {file_key}: {e}")
|
||||
raise StorageUploadError(
|
||||
message=f"Failed to stream upload file: {e}",
|
||||
file_key=file_key,
|
||||
cause=e,
|
||||
)
|
||||
|
||||
async def get_url(self, file_key: str, expires: int = 3600) -> str:
|
||||
"""
|
||||
Get an access URL for the file.
|
||||
|
||||
@@ -5,8 +5,9 @@ This module provides a storage backend that stores files on Aliyun Object
|
||||
Storage Service (OSS) using the oss2 SDK.
|
||||
"""
|
||||
|
||||
import io
|
||||
import logging
|
||||
from typing import Optional
|
||||
from typing import AsyncIterator, Optional
|
||||
|
||||
import oss2
|
||||
from oss2.exceptions import NoSuchKey, OssError
|
||||
@@ -125,10 +126,39 @@ class OSSStorage(StorageBackend):
|
||||
cause=e,
|
||||
)
|
||||
|
||||
async def upload_stream(
|
||||
self,
|
||||
file_key: str,
|
||||
stream: AsyncIterator[bytes],
|
||||
content_type: Optional[str] = None,
|
||||
) -> int:
|
||||
"""Upload from async stream to OSS. Returns total bytes written."""
|
||||
buf = io.BytesIO()
|
||||
try:
|
||||
async for chunk in stream:
|
||||
buf.write(chunk)
|
||||
content = buf.getvalue()
|
||||
headers = {"Content-Type": content_type} if content_type else None
|
||||
self.bucket.put_object(file_key, content, headers=headers)
|
||||
logger.info(f"File stream uploaded to OSS successfully: {file_key}")
|
||||
return len(content)
|
||||
except OssError as e:
|
||||
logger.error(f"OSS error stream uploading file {file_key}: {e}")
|
||||
raise StorageUploadError(
|
||||
message=f"Failed to stream upload file to OSS: {e.message}",
|
||||
file_key=file_key,
|
||||
cause=e,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to stream upload file to OSS {file_key}: {e}")
|
||||
raise StorageUploadError(
|
||||
message=f"Failed to stream upload file to OSS: {e}",
|
||||
file_key=file_key,
|
||||
cause=e,
|
||||
)
|
||||
|
||||
async def download(self, file_key: str) -> bytes:
|
||||
"""
|
||||
Download a file from OSS.
|
||||
|
||||
Args:
|
||||
file_key: Unique identifier for the file in the storage system.
|
||||
|
||||
|
||||
@@ -5,8 +5,9 @@ This module provides a storage backend that stores files on AWS S3
|
||||
using the boto3 SDK.
|
||||
"""
|
||||
|
||||
import io
|
||||
import logging
|
||||
from typing import Optional
|
||||
from typing import AsyncIterator, Optional
|
||||
|
||||
import boto3
|
||||
from botocore.exceptions import ClientError, NoCredentialsError, BotoCoreError
|
||||
@@ -174,6 +175,62 @@ class S3Storage(StorageBackend):
|
||||
cause=e,
|
||||
)
|
||||
|
||||
async def upload_stream(
|
||||
self,
|
||||
file_key: str,
|
||||
stream: AsyncIterator[bytes],
|
||||
content_type: Optional[str] = None,
|
||||
) -> int:
|
||||
"""Upload from async stream to S3 via multipart upload. Returns total bytes written."""
|
||||
extra_args = {"ContentType": content_type} if content_type else {}
|
||||
mpu = self.client.create_multipart_upload(
|
||||
Bucket=self.bucket_name, Key=file_key, **extra_args
|
||||
)
|
||||
upload_id = mpu["UploadId"]
|
||||
parts = []
|
||||
part_number = 1
|
||||
buf = io.BytesIO()
|
||||
total = 0
|
||||
min_part_size = 5 * 1024 * 1024 # S3 最小分片 5MB
|
||||
try:
|
||||
async for chunk in stream:
|
||||
buf.write(chunk)
|
||||
total += len(chunk)
|
||||
if buf.tell() >= min_part_size:
|
||||
buf.seek(0)
|
||||
resp = self.client.upload_part(
|
||||
Bucket=self.bucket_name, Key=file_key,
|
||||
UploadId=upload_id, PartNumber=part_number, Body=buf.read()
|
||||
)
|
||||
parts.append({"PartNumber": part_number, "ETag": resp["ETag"]})
|
||||
part_number += 1
|
||||
buf = io.BytesIO()
|
||||
# 上传剩余数据(最后一片可小于 5MB)
|
||||
remaining = buf.getvalue()
|
||||
if remaining:
|
||||
resp = self.client.upload_part(
|
||||
Bucket=self.bucket_name, Key=file_key,
|
||||
UploadId=upload_id, PartNumber=part_number, Body=remaining
|
||||
)
|
||||
parts.append({"PartNumber": part_number, "ETag": resp["ETag"]})
|
||||
self.client.complete_multipart_upload(
|
||||
Bucket=self.bucket_name, Key=file_key,
|
||||
UploadId=upload_id,
|
||||
MultipartUpload={"Parts": parts}
|
||||
)
|
||||
logger.info(f"File stream uploaded to S3 successfully: {file_key}")
|
||||
return total
|
||||
except Exception as e:
|
||||
self.client.abort_multipart_upload(
|
||||
Bucket=self.bucket_name, Key=file_key, UploadId=upload_id
|
||||
)
|
||||
logger.error(f"Failed to stream upload file to S3 {file_key}: {e}")
|
||||
raise StorageUploadError(
|
||||
message=f"Failed to stream upload file to S3: {e}",
|
||||
file_key=file_key,
|
||||
cause=e,
|
||||
)
|
||||
|
||||
async def download(self, file_key: str) -> bytes:
|
||||
"""
|
||||
Download a file from S3.
|
||||
|
||||
@@ -139,25 +139,25 @@ class FileUploadConfig(BaseModel):
|
||||
image_enabled: bool = Field(default=False)
|
||||
image_max_size_mb: int = Field(default=20)
|
||||
image_allowed_extensions: List[str] = Field(
|
||||
default=["png", "jpg", "jpeg", "gif", "webp"]
|
||||
default=["png", "jpg", "jpeg"]
|
||||
)
|
||||
# 语音文件:MP3/WAV/M4A/OGG/FLAC,最大 50MB
|
||||
audio_enabled: bool = Field(default=False)
|
||||
audio_max_size_mb: int = Field(default=50)
|
||||
audio_allowed_extensions: List[str] = Field(
|
||||
default=["mp3", "wav", "m4a", "ogg", "flac"]
|
||||
default=["mp3", "wav", "m4a"]
|
||||
)
|
||||
# 通用文件:PDF/DOCX/XLSX/TXT/CSV/JSON,最大 100MB
|
||||
document_enabled: bool = Field(default=False)
|
||||
document_max_size_mb: int = Field(default=100)
|
||||
document_allowed_extensions: List[str] = Field(
|
||||
default=["pdf", "docx", "xlsx", "txt", "csv", "json"]
|
||||
default=["pdf", "docx", "xlsx", "txt", "csv", "json", "md"]
|
||||
)
|
||||
# 视频文件:MP4/MOV/AVI/WebM,最大 500MB
|
||||
video_enabled: bool = Field(default=False)
|
||||
video_max_size_mb: int = Field(default=500)
|
||||
video_allowed_extensions: List[str] = Field(
|
||||
default=["mp4", "mov", "avi", "webm"]
|
||||
default=["mp4", "mov"]
|
||||
)
|
||||
# 最大文件数量
|
||||
max_file_count: int = Field(default=5, ge=1, le=20)
|
||||
|
||||
@@ -191,7 +191,7 @@ class AppChatService:
|
||||
for f in files:
|
||||
# url = await MultimodalService(self.db).get_file_url(f)
|
||||
human_meta["files"].append({
|
||||
"type": FileType.IMAGE,
|
||||
"type": f.type,
|
||||
"url": f.url
|
||||
})
|
||||
|
||||
@@ -342,9 +342,17 @@ class AppChatService:
|
||||
processed_files = await multimodal_service.process_files(user_id, files)
|
||||
logger.info(f"处理了 {len(processed_files)} 个文件")
|
||||
|
||||
# 流式调用 Agent(支持多模态)
|
||||
# 流式调用 Agent(支持多模态),同时并行启动 TTS
|
||||
full_content = ""
|
||||
total_tokens = 0
|
||||
|
||||
text_queue: asyncio.Queue = asyncio.Queue()
|
||||
stream_audio_url, tts_task = await self.agent_service._generate_tts_streaming(
|
||||
features_config, api_key_obj,
|
||||
text_queue=text_queue,
|
||||
tenant_id=tenant_id, workspace_id=workspace_id
|
||||
)
|
||||
|
||||
async for chunk in agent.chat_stream(
|
||||
message=message,
|
||||
history=history,
|
||||
@@ -354,17 +362,20 @@ class AppChatService:
|
||||
user_rag_memory_id=user_rag_memory_id,
|
||||
config_id=config_id,
|
||||
memory_flag=memory_flag,
|
||||
files=processed_files # 传递处理后的文件
|
||||
files=processed_files
|
||||
):
|
||||
if isinstance(chunk, int):
|
||||
total_tokens = chunk
|
||||
else:
|
||||
full_content += chunk
|
||||
# 发送消息块事件
|
||||
yield f"event: message\ndata: {json.dumps({'content': chunk}, ensure_ascii=False)}\n\n"
|
||||
if tts_task is not None:
|
||||
await text_queue.put(chunk)
|
||||
|
||||
if tts_task is not None:
|
||||
await text_queue.put(None)
|
||||
|
||||
elapsed_time = time.time() - start_time
|
||||
|
||||
ModelApiKeyService.record_api_key_usage(self.db, api_key_obj.id)
|
||||
|
||||
# 发送结束事件(包含 suggested_questions、tts、citations)
|
||||
@@ -376,12 +387,6 @@ class AppChatService:
|
||||
{"model_name": api_key_obj.model_name, "api_key": api_key_obj.api_key,
|
||||
"api_base": api_key_obj.api_base}, {}
|
||||
)
|
||||
stream_audio_url = await self.agent_service._generate_tts(
|
||||
features_config, full_content,
|
||||
{"model_name": api_key_obj.model_name, "api_key": api_key_obj.api_key,
|
||||
"api_base": api_key_obj.api_base, "provider": api_key_obj.provider},
|
||||
tenant_id=tenant_id, workspace_id=workspace_id
|
||||
)
|
||||
end_data["audio_url"] = stream_audio_url
|
||||
end_data["citations"] = self.agent_service._filter_citations(features_config, [])
|
||||
|
||||
@@ -399,7 +404,7 @@ class AppChatService:
|
||||
for f in files:
|
||||
# url = await MultimodalService(self.db).get_file_url(f)
|
||||
human_meta["files"].append({
|
||||
"type": FileType.IMAGE,
|
||||
"type": f.type,
|
||||
"url": f.url
|
||||
})
|
||||
|
||||
|
||||
@@ -852,9 +852,18 @@ class AgentRunService:
|
||||
# 兼容新旧字段名:优先使用 memory_config_id,回退到 memory_content
|
||||
config_id = memory_config_.get("memory_config_id") or memory_config_.get("memory_content", None)
|
||||
|
||||
# 9. 流式调用 Agent(支持多模态)
|
||||
# 9. 流式调用 Agent(支持多模态),同时并行启动 TTS
|
||||
full_content = ""
|
||||
total_tokens = 0
|
||||
|
||||
# 启动流式 TTS(文本边输出边合成)
|
||||
text_queue: asyncio.Queue = asyncio.Queue()
|
||||
stream_audio_url, tts_task = await self._generate_tts_streaming(
|
||||
features_config, api_key_config,
|
||||
text_queue=text_queue,
|
||||
tenant_id=tenant_id, workspace_id=workspace_id
|
||||
) if not sub_agent else (None, None)
|
||||
|
||||
async for chunk in agent.chat_stream(
|
||||
message=message,
|
||||
history=history,
|
||||
@@ -864,31 +873,25 @@ class AgentRunService:
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id,
|
||||
memory_flag=memory_flag,
|
||||
files=processed_files # 传递处理后的文件
|
||||
files=processed_files
|
||||
):
|
||||
if isinstance(chunk, int):
|
||||
total_tokens = chunk
|
||||
else:
|
||||
full_content += chunk
|
||||
# 发送消息块事件
|
||||
yield self._format_sse_event("message", {
|
||||
"content": chunk
|
||||
})
|
||||
yield self._format_sse_event("message", {"content": chunk})
|
||||
if tts_task is not None:
|
||||
await text_queue.put(chunk)
|
||||
|
||||
# 文本结束,通知 TTS
|
||||
if tts_task is not None:
|
||||
await text_queue.put(None)
|
||||
|
||||
elapsed_time = time.time() - start_time
|
||||
|
||||
ModelApiKeyService.record_api_key_usage(self.db, api_key_config.get("api_key_id"))
|
||||
|
||||
if sub_agent:
|
||||
yield self._format_sse_event("sub_usage", {
|
||||
"total_tokens": total_tokens
|
||||
})
|
||||
|
||||
# 10. 生成 audio_url(在保存消息前生成,以便一并存入 meta_data)
|
||||
stream_audio_url = await self._generate_tts(
|
||||
features_config, full_content, api_key_config,
|
||||
tenant_id=tenant_id, workspace_id=workspace_id
|
||||
) if not sub_agent else None
|
||||
yield self._format_sse_event("sub_usage", {"total_tokens": total_tokens})
|
||||
|
||||
# 11. 保存会话消息
|
||||
if not sub_agent:
|
||||
@@ -1182,7 +1185,7 @@ class AgentRunService:
|
||||
for f in files:
|
||||
# url = await MultimodalService(self.db).get_file_url(f)
|
||||
human_meta["files"].append({
|
||||
"type": FileType.IMAGE,
|
||||
"type": f.type,
|
||||
"url": f.url
|
||||
})
|
||||
# 保存用户消息
|
||||
@@ -1317,125 +1320,345 @@ class AgentRunService:
|
||||
tenant_id: Optional[uuid.UUID] = None,
|
||||
workspace_id: Optional[uuid.UUID] = None,
|
||||
) -> Optional[str]:
|
||||
"""根据 text_to_speech 配置生成语音,上传到存储并返回 URL"""
|
||||
"""先注册文件元数据并返回 audio_url,再后台流式写入音频内容"""
|
||||
tts_config = features_config.get("text_to_speech", {})
|
||||
if not isinstance(tts_config, dict) or not tts_config.get("enabled"):
|
||||
return None
|
||||
if not text or not text.strip():
|
||||
return None
|
||||
|
||||
try:
|
||||
from app.services.file_storage_service import FileStorageService
|
||||
from app.models.file_metadata_model import FileMetadata
|
||||
from app.services.file_storage_service import FileStorageService, generate_file_key
|
||||
|
||||
provider = api_key_config.get("provider", "openai")
|
||||
api_key = api_key_config.get("api_key")
|
||||
api_base = api_key_config.get("api_base")
|
||||
voice = tts_config.get("voice")
|
||||
provider = api_key_config.get("provider", "openai")
|
||||
api_key = api_key_config.get("api_key")
|
||||
api_base = api_key_config.get("api_base")
|
||||
voice = tts_config.get("voice")
|
||||
file_ext, content_type = ".mp3", "audio/mpeg"
|
||||
|
||||
if provider == "dashscope":
|
||||
audio_bytes, file_ext, content_type = await self._tts_dashscope(
|
||||
api_key=api_key,
|
||||
text=text,
|
||||
voice=voice or "longxiaochun", # 会根据 model 版本自动修正后缀
|
||||
tts_config=tts_config,
|
||||
)
|
||||
else:
|
||||
# OpenAI 兼容接口(openai / xinference / gpustack 等)
|
||||
audio_bytes, file_ext, content_type = await self._tts_openai(
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
text=text,
|
||||
voice=voice or "alloy",
|
||||
file_id = uuid.uuid4()
|
||||
file_key = generate_file_key(tenant_id, workspace_id, file_id, file_ext)
|
||||
|
||||
# 先写入 pending 状态的元数据,立即返回 URL
|
||||
db_file = FileMetadata(
|
||||
id=file_id,
|
||||
tenant_id=tenant_id,
|
||||
workspace_id=workspace_id,
|
||||
file_key=file_key,
|
||||
file_name=f"tts_{file_id}{file_ext}",
|
||||
file_ext=file_ext,
|
||||
file_size=0,
|
||||
content_type=content_type,
|
||||
status="pending",
|
||||
)
|
||||
self.db.add(db_file)
|
||||
self.db.commit()
|
||||
|
||||
server_url = settings.FILE_LOCAL_SERVER_URL
|
||||
audio_url = f"{server_url}/storage/permanent/{file_id}"
|
||||
|
||||
# 后台任务:流式生成并写入存储,完成后更新状态
|
||||
async def _stream_to_storage():
|
||||
try:
|
||||
storage_service = FileStorageService()
|
||||
if provider == "dashscope":
|
||||
stream = self._tts_dashscope_stream(
|
||||
api_key=api_key,
|
||||
text=text,
|
||||
voice=voice or "longxiaochun",
|
||||
tts_config=tts_config,
|
||||
)
|
||||
else:
|
||||
stream = self._tts_openai_stream(
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
text=text,
|
||||
voice=voice or "alloy",
|
||||
)
|
||||
|
||||
total_size = await storage_service.upload_stream(
|
||||
tenant_id=tenant_id,
|
||||
workspace_id=workspace_id,
|
||||
file_id=file_id,
|
||||
file_ext=file_ext,
|
||||
stream=stream,
|
||||
content_type=content_type,
|
||||
)
|
||||
|
||||
storage_service = FileStorageService()
|
||||
file_id = uuid.uuid4()
|
||||
file_key = await storage_service.upload_file(
|
||||
tenant_id=tenant_id,
|
||||
workspace_id=workspace_id,
|
||||
file_id=file_id,
|
||||
file_ext=file_ext,
|
||||
content=audio_bytes,
|
||||
content_type=content_type,
|
||||
)
|
||||
# 更新元数据状态
|
||||
with get_db_context() as bg_db:
|
||||
record = bg_db.get(FileMetadata, file_id)
|
||||
if record:
|
||||
record.status = "completed"
|
||||
record.file_size = total_size
|
||||
bg_db.commit()
|
||||
logger.debug(f"TTS 流式写入完成,provider={provider}, file_key={file_key}")
|
||||
except Exception as e:
|
||||
logger.warning(f"TTS 流式写入失败: {e}")
|
||||
with get_db_context() as bg_db:
|
||||
record = bg_db.get(FileMetadata, file_id)
|
||||
if record:
|
||||
record.status = "failed"
|
||||
bg_db.commit()
|
||||
|
||||
# 保存文件元数据到数据库
|
||||
from app.models.file_metadata_model import FileMetadata
|
||||
db_file = FileMetadata(
|
||||
id=file_id,
|
||||
tenant_id=tenant_id,
|
||||
workspace_id=workspace_id,
|
||||
file_key=file_key,
|
||||
file_name=f"tts_{file_id}{file_ext}",
|
||||
file_ext=file_ext,
|
||||
file_size=len(audio_bytes),
|
||||
content_type=content_type,
|
||||
status="completed",
|
||||
)
|
||||
self.db.add(db_file)
|
||||
self.db.commit()
|
||||
asyncio.create_task(_stream_to_storage())
|
||||
return audio_url
|
||||
|
||||
server_url = settings.FILE_LOCAL_SERVER_URL
|
||||
audio_url = f"{server_url}/storage/permanent/{file_id}"
|
||||
logger.debug(f"TTS 生成成功,provider={provider}, file_key={file_key}")
|
||||
return audio_url
|
||||
async def _generate_tts_streaming(
|
||||
self,
|
||||
features_config: Dict[str, Any],
|
||||
api_key_config: Dict[str, Any],
|
||||
text_queue: asyncio.Queue,
|
||||
tenant_id: Optional[uuid.UUID] = None,
|
||||
workspace_id: Optional[uuid.UUID] = None,
|
||||
) -> tuple[Optional[str], Optional[asyncio.Task]]:
|
||||
"""文本流式输入并行合成音频。
|
||||
返回 (audio_url, task),audio_url 立即可用,task 完成后文件内容就绪。
|
||||
调用方向 text_queue put 文本 chunk,结束时 put None。
|
||||
"""
|
||||
tts_config = features_config.get("text_to_speech", {})
|
||||
if not isinstance(tts_config, dict) or not tts_config.get("enabled"):
|
||||
return None, None
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"TTS 生成失败: {e}")
|
||||
return None
|
||||
from app.models.file_metadata_model import FileMetadata
|
||||
from app.services.file_storage_service import FileStorageService, generate_file_key
|
||||
|
||||
provider = api_key_config.get("provider", "openai")
|
||||
api_key = api_key_config.get("api_key")
|
||||
api_base = api_key_config.get("api_base")
|
||||
voice = tts_config.get("voice")
|
||||
file_ext, content_type = ".mp3", "audio/mpeg"
|
||||
|
||||
file_id = uuid.uuid4()
|
||||
file_key = generate_file_key(tenant_id, workspace_id, file_id, file_ext)
|
||||
|
||||
db_file = FileMetadata(
|
||||
id=file_id,
|
||||
tenant_id=tenant_id,
|
||||
workspace_id=workspace_id,
|
||||
file_key=file_key,
|
||||
file_name=f"tts_{file_id}{file_ext}",
|
||||
file_ext=file_ext,
|
||||
file_size=0,
|
||||
content_type=content_type,
|
||||
status="pending",
|
||||
)
|
||||
self.db.add(db_file)
|
||||
self.db.commit()
|
||||
|
||||
server_url = settings.FILE_LOCAL_SERVER_URL
|
||||
audio_url = f"{server_url}/storage/permanent/{file_id}"
|
||||
|
||||
async def _run():
|
||||
try:
|
||||
storage_service = FileStorageService()
|
||||
if provider == "dashscope":
|
||||
audio_stream = self._tts_dashscope_stream_from_queue(
|
||||
api_key=api_key,
|
||||
voice=voice or "longxiaochun",
|
||||
tts_config=tts_config,
|
||||
text_queue=text_queue,
|
||||
)
|
||||
else:
|
||||
audio_stream = self._tts_openai_stream_from_queue(
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
voice=voice or "alloy",
|
||||
text_queue=text_queue,
|
||||
)
|
||||
total_size = await storage_service.upload_stream(
|
||||
tenant_id=tenant_id,
|
||||
workspace_id=workspace_id,
|
||||
file_id=file_id,
|
||||
file_ext=file_ext,
|
||||
stream=audio_stream,
|
||||
content_type=content_type,
|
||||
)
|
||||
with get_db_context() as bg_db:
|
||||
record = bg_db.get(FileMetadata, file_id)
|
||||
if record:
|
||||
record.status = "completed"
|
||||
record.file_size = total_size
|
||||
bg_db.commit()
|
||||
logger.debug(f"TTS 流式合成完成,provider={provider}, file_key={file_key}")
|
||||
except Exception as e:
|
||||
logger.warning(f"TTS 流式合成失败: {e}")
|
||||
with get_db_context() as bg_db:
|
||||
record = bg_db.get(FileMetadata, file_id)
|
||||
if record:
|
||||
record.status = "failed"
|
||||
bg_db.commit()
|
||||
|
||||
task = asyncio.create_task(_run())
|
||||
return audio_url, task
|
||||
|
||||
@staticmethod
|
||||
async def _tts_openai(
|
||||
async def _tts_openai_stream_from_queue(
|
||||
api_key: str,
|
||||
api_base: Optional[str],
|
||||
voice: str,
|
||||
text_queue: asyncio.Queue,
|
||||
):
|
||||
"""OpenAI TTS:收集全部文本后流式合成(OpenAI 不支持增量输入)"""
|
||||
from openai import AsyncOpenAI
|
||||
# 收集全部文本(此时文本流已并行输出,等待时间短)
|
||||
parts = []
|
||||
while True:
|
||||
chunk = await text_queue.get()
|
||||
if chunk is None:
|
||||
break
|
||||
parts.append(chunk)
|
||||
full_text = "".join(parts)
|
||||
if not full_text.strip():
|
||||
return
|
||||
client = AsyncOpenAI(api_key=api_key, base_url=api_base)
|
||||
async with client.audio.speech.with_streaming_response.create(
|
||||
model="tts-1",
|
||||
voice=voice,
|
||||
input=full_text[:4096],
|
||||
) as response:
|
||||
async for chunk in response.iter_bytes(chunk_size=4096):
|
||||
yield chunk
|
||||
|
||||
@staticmethod
|
||||
async def _tts_dashscope_stream_from_queue(
|
||||
api_key: str,
|
||||
voice: str,
|
||||
tts_config: Dict[str, Any],
|
||||
text_queue: asyncio.Queue,
|
||||
):
|
||||
"""DashScope TTS:文本流式输入,实现真正并行合成"""
|
||||
import dashscope
|
||||
from dashscope.audio.tts_v2 import SpeechSynthesizer, AudioFormat, ResultCallback
|
||||
|
||||
model = tts_config.get("model") or "cosyvoice-v2"
|
||||
is_v2 = model.endswith("-v2")
|
||||
if is_v2 and not voice.endswith("_v2"):
|
||||
voice = voice + "_v2"
|
||||
elif not is_v2 and voice.endswith("_v2"):
|
||||
voice = voice[:-3]
|
||||
|
||||
audio_queue: asyncio.Queue = asyncio.Queue()
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
class _Callback(ResultCallback):
|
||||
def on_data(self, data: bytes):
|
||||
if data:
|
||||
loop.call_soon_threadsafe(audio_queue.put_nowait, data)
|
||||
def on_complete(self):
|
||||
loop.call_soon_threadsafe(audio_queue.put_nowait, None)
|
||||
def on_error(self, message):
|
||||
loop.call_soon_threadsafe(audio_queue.put_nowait, RuntimeError(str(message)))
|
||||
def on_open(self): pass
|
||||
def on_close(self): pass
|
||||
|
||||
dashscope.api_key = api_key
|
||||
synthesizer = SpeechSynthesizer(
|
||||
model=model,
|
||||
voice=voice,
|
||||
format=AudioFormat.MP3_22050HZ_MONO_256KBPS,
|
||||
callback=_Callback(),
|
||||
)
|
||||
|
||||
async def _feed_text():
|
||||
"""从 text_queue 取文本按句子切分后喂给 synthesizer"""
|
||||
import re
|
||||
buf = ""
|
||||
sentence_end = re.compile(r'[\u3002\uff01\uff1f\.!?\n]')
|
||||
while True:
|
||||
chunk = await text_queue.get()
|
||||
if chunk is None:
|
||||
if buf.strip():
|
||||
await asyncio.to_thread(synthesizer.streaming_call, buf)
|
||||
await asyncio.to_thread(synthesizer.streaming_complete)
|
||||
break
|
||||
buf += chunk
|
||||
# 按句子切分喂入
|
||||
while sentence_end.search(buf):
|
||||
m = sentence_end.search(buf)
|
||||
sentence = buf[:m.end()]
|
||||
buf = buf[m.end():]
|
||||
await asyncio.to_thread(synthesizer.streaming_call, sentence)
|
||||
|
||||
asyncio.create_task(_feed_text())
|
||||
|
||||
while True:
|
||||
item = await audio_queue.get()
|
||||
if item is None:
|
||||
break
|
||||
if isinstance(item, Exception):
|
||||
raise item
|
||||
yield item
|
||||
|
||||
@staticmethod
|
||||
async def _tts_openai_stream(
|
||||
api_key: str,
|
||||
api_base: Optional[str],
|
||||
text: str,
|
||||
voice: str,
|
||||
) -> tuple:
|
||||
"""OpenAI 兼容 TTS,返回 (audio_bytes, file_ext, content_type)"""
|
||||
):
|
||||
"""OpenAI 兼容 TTS 流式生成,yield bytes chunks"""
|
||||
from openai import AsyncOpenAI
|
||||
client = AsyncOpenAI(api_key=api_key, base_url=api_base)
|
||||
response = await client.audio.speech.create(
|
||||
async with client.audio.speech.with_streaming_response.create(
|
||||
model="tts-1",
|
||||
voice=voice,
|
||||
input=text[:4096],
|
||||
)
|
||||
return response.content, ".mp3", "audio/mpeg"
|
||||
) as response:
|
||||
async for chunk in response.iter_bytes(chunk_size=4096):
|
||||
yield chunk
|
||||
|
||||
@staticmethod
|
||||
async def _tts_dashscope(
|
||||
async def _tts_dashscope_stream(
|
||||
api_key: str,
|
||||
text: str,
|
||||
voice: str,
|
||||
tts_config: Dict[str, Any],
|
||||
) -> tuple:
|
||||
"""DashScope CosyVoice TTS,返回 (audio_bytes, file_ext, content_type)"""
|
||||
):
|
||||
"""DashScope TTS 流式生成,yield bytes chunks"""
|
||||
import dashscope
|
||||
from dashscope.audio.tts_v2 import SpeechSynthesizer, AudioFormat
|
||||
from dashscope.audio.tts_v2 import SpeechSynthesizer, AudioFormat, ResultCallback
|
||||
|
||||
model = tts_config.get("model") or "cosyvoice-v2"
|
||||
is_v2 = model.endswith("-v2")
|
||||
|
||||
# cosyvoice-v2 音色名带 _v2 后缀,v1 不带
|
||||
# 如果用户传入的 voice 不匹配当前模型版本,自动修正
|
||||
if is_v2 and not voice.endswith("_v2"):
|
||||
voice = voice + "_v2"
|
||||
elif not is_v2 and voice.endswith("_v2"):
|
||||
voice = voice[:-3] # 去掉 _v2
|
||||
voice = voice[:-3]
|
||||
|
||||
def _sync_call() -> bytes:
|
||||
queue: asyncio.Queue = asyncio.Queue()
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
class _Callback(ResultCallback):
|
||||
def on_data(self, data: bytes):
|
||||
if data:
|
||||
loop.call_soon_threadsafe(queue.put_nowait, data)
|
||||
def on_complete(self):
|
||||
loop.call_soon_threadsafe(queue.put_nowait, None)
|
||||
def on_error(self, message):
|
||||
loop.call_soon_threadsafe(queue.put_nowait, RuntimeError(str(message)))
|
||||
def on_open(self): pass
|
||||
def on_close(self): pass
|
||||
|
||||
def _sync_stream():
|
||||
dashscope.api_key = api_key
|
||||
synthesizer = SpeechSynthesizer(
|
||||
model=model,
|
||||
voice=voice,
|
||||
format=AudioFormat.MP3_22050HZ_MONO_256KBPS,
|
||||
callback=_Callback(),
|
||||
)
|
||||
audio = synthesizer.call(text[:4096])
|
||||
if audio is None:
|
||||
raise RuntimeError("DashScope TTS 返回空音频")
|
||||
return audio
|
||||
synthesizer.streaming_call(text[:4096])
|
||||
synthesizer.streaming_complete()
|
||||
|
||||
audio_bytes = await asyncio.to_thread(_sync_call)
|
||||
return audio_bytes, ".mp3", "audio/mpeg"
|
||||
asyncio.create_task(asyncio.to_thread(_sync_stream))
|
||||
while True:
|
||||
item = await queue.get()
|
||||
if item is None:
|
||||
break
|
||||
if isinstance(item, Exception):
|
||||
raise item
|
||||
yield item
|
||||
|
||||
def _replace_variables(
|
||||
self,
|
||||
|
||||
@@ -9,7 +9,7 @@ and error handling.
|
||||
import logging
|
||||
import time
|
||||
import uuid
|
||||
from typing import Optional
|
||||
from typing import AsyncIterator, Optional
|
||||
|
||||
from app.core.storage import StorageFactory, StorageBackend
|
||||
from app.core.storage_exceptions import (
|
||||
@@ -162,6 +162,31 @@ class FileStorageService:
|
||||
cause=e,
|
||||
)
|
||||
|
||||
async def upload_stream(
|
||||
self,
|
||||
tenant_id: uuid.UUID,
|
||||
workspace_id: uuid.UUID | None,
|
||||
file_id: uuid.UUID,
|
||||
file_ext: str,
|
||||
stream: AsyncIterator[bytes],
|
||||
content_type: Optional[str] = None,
|
||||
) -> int:
|
||||
"""
|
||||
Upload a file from an async byte stream.
|
||||
|
||||
Returns:
|
||||
Total bytes written.
|
||||
"""
|
||||
file_key = generate_file_key(tenant_id, workspace_id, file_id, file_ext)
|
||||
logger.info(f"Starting stream upload: file_key={file_key}, content_type={content_type}")
|
||||
try:
|
||||
total = await self.storage.upload_stream(file_key, stream, content_type)
|
||||
logger.info(f"Stream upload successful: file_key={file_key}, size={total} bytes")
|
||||
return total
|
||||
except Exception as e:
|
||||
logger.error(f"Stream upload failed: file_key={file_key}, error={str(e)}")
|
||||
raise
|
||||
|
||||
async def download_file(self, file_key: str) -> bytes:
|
||||
"""
|
||||
Download a file from storage.
|
||||
|
||||
@@ -1638,6 +1638,7 @@ class MultiAgentOrchestrator:
|
||||
self.variables = config_data.get("variables", [])
|
||||
self.tools = config_data.get("tools", {})
|
||||
self.skills = config_data.get("skills", {})
|
||||
self.features = config_data.get("features", {})
|
||||
self.default_model_config_id = release.default_model_config_id
|
||||
|
||||
return AgentConfigProxy(release, app, config_data)
|
||||
|
||||
Reference in New Issue
Block a user