mirror of
https://github.com/remsky/Kokoro-FastAPI.git
synced 2025-04-13 09:39:17 +00:00
Implement temporary file management on openai endpoint, whole file downloads
This commit is contained in:
parent
355ec54f78
commit
946e322242
10 changed files with 346 additions and 64 deletions
|
@ -32,6 +32,11 @@ class Settings(BaseSettings):
|
||||||
cors_origins: list[str] = ["*"] # CORS origins for web player
|
cors_origins: list[str] = ["*"] # CORS origins for web player
|
||||||
cors_enabled: bool = True # Whether to enable CORS
|
cors_enabled: bool = True # Whether to enable CORS
|
||||||
|
|
||||||
|
# Temp File Settings
|
||||||
|
temp_file_dir: str = "api/temp_files" # Directory for temporary audio files (relative to project root)
|
||||||
|
max_temp_dir_size_mb: int = 2048 # Maximum size of temp directory (2GB)
|
||||||
|
temp_file_max_age_hours: int = 1 # Remove temp files older than 1 hour
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
env_file = ".env"
|
env_file = ".env"
|
||||||
|
|
||||||
|
|
|
@ -337,4 +337,78 @@ async def get_content_type(path: str) -> str:
|
||||||
|
|
||||||
async def verify_model_path(model_path: str) -> bool:
|
async def verify_model_path(model_path: str) -> bool:
|
||||||
"""Verify model file exists at path."""
|
"""Verify model file exists at path."""
|
||||||
return await aiofiles.os.path.exists(model_path)
|
return await aiofiles.os.path.exists(model_path)
|
||||||
|
|
||||||
|
|
||||||
|
async def cleanup_temp_files() -> None:
    """Clean up old temp files on startup.

    If ``settings.temp_file_dir`` does not exist it is created and nothing
    else is done. Otherwise every regular file directly inside it whose
    modification time is more than ``settings.temp_file_max_age_hours``
    hours in the past is removed.

    The whole operation is best-effort: failures are logged as warnings and
    never propagated to the caller.
    """
    import time  # local import: only this function needs the wall clock

    try:
        if not await aiofiles.os.path.exists(settings.temp_file_dir):
            await aiofiles.os.makedirs(settings.temp_file_dir, exist_ok=True)
            return

        max_age_seconds = settings.temp_file_max_age_hours * 3600
        now = time.time()

        entries = await aiofiles.os.scandir(settings.temp_file_dir)
        for entry in entries:
            if not entry.is_file():
                continue

            stat = await aiofiles.os.stat(entry.path)
            # A file has expired when its age (now - mtime) exceeds the limit.
            # Comparing the mtime against itself plus the limit can never be
            # true, which would leave every file in place forever.
            if now - stat.st_mtime <= max_age_seconds:
                continue

            try:
                await aiofiles.os.remove(entry.path)
                logger.info(f"Cleaned up old temp file: {entry.name}")
            except Exception as e:
                logger.warning(f"Failed to delete old temp file {entry.name}: {e}")
    except Exception as e:
        logger.warning(f"Error cleaning temp files: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
async def get_temp_file_path(filename: str) -> str:
    """Get path to temporary audio file.

    The temp directory is created when it does not exist yet; the file
    itself is neither created nor checked for existence.

    Args:
        filename: Name of temp file

    Returns:
        ``settings.temp_file_dir`` joined with ``filename``. The result is
        only an absolute path when ``settings.temp_file_dir`` is absolute.
    """
    temp_dir = settings.temp_file_dir

    # Ensure temp directory exists (created on demand rather than raising)
    if not await aiofiles.os.path.exists(temp_dir):
        await aiofiles.os.makedirs(temp_dir, exist_ok=True)

    return os.path.join(temp_dir, filename)
|
||||||
|
|
||||||
|
|
||||||
|
async def list_temp_files() -> List[str]:
    """Return the names of the regular files in the temp directory.

    Returns:
        File names (not paths) found directly inside
        ``settings.temp_file_dir``; an empty list when that directory
        does not exist.
    """
    temp_dir = settings.temp_file_dir
    if not await aiofiles.os.path.exists(temp_dir):
        return []

    names: List[str] = []
    for item in await aiofiles.os.scandir(temp_dir):
        # Sub-directories and other non-file entries are not reported
        if item.is_file():
            names.append(item.name)
    return names
|
||||||
|
|
||||||
|
|
||||||
|
async def get_temp_dir_size() -> int:
    """Sum the sizes of the regular files in the temp directory.

    Only entries directly inside ``settings.temp_file_dir`` are counted;
    nothing is traversed recursively.

    Returns:
        Combined size in bytes, or 0 when the directory does not exist.
    """
    temp_dir = settings.temp_file_dir
    if not await aiofiles.os.path.exists(temp_dir):
        return 0

    sizes = []
    for item in await aiofiles.os.scandir(temp_dir):
        if not item.is_file():
            continue
        info = await aiofiles.os.stat(item.path)
        sizes.append(info.st_size)
    return sum(sizes)
|
|
@ -48,6 +48,10 @@ async def lifespan(app: FastAPI):
|
||||||
"""Lifespan context manager for model initialization"""
|
"""Lifespan context manager for model initialization"""
|
||||||
from .inference.model_manager import get_manager
|
from .inference.model_manager import get_manager
|
||||||
from .inference.voice_manager import get_manager as get_voice_manager
|
from .inference.voice_manager import get_manager as get_voice_manager
|
||||||
|
from .core.paths import cleanup_temp_files
|
||||||
|
|
||||||
|
# Clean old temp files on startup
|
||||||
|
await cleanup_temp_files()
|
||||||
|
|
||||||
logger.info("Loading TTS model and voice packs...")
|
logger.info("Loading TTS model and voice packs...")
|
||||||
|
|
||||||
|
|
|
@ -5,7 +5,7 @@ import os
|
||||||
from typing import AsyncGenerator, Dict, List, Union
|
from typing import AsyncGenerator, Dict, List, Union
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, Header, HTTPException, Request, Response
|
from fastapi import APIRouter, Depends, Header, HTTPException, Request, Response
|
||||||
from fastapi.responses import StreamingResponse
|
from fastapi.responses import StreamingResponse, FileResponse
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
from ..services.audio import AudioService
|
from ..services.audio import AudioService
|
||||||
|
@ -179,36 +179,59 @@ async def create_speech(
|
||||||
# Create generator but don't start it yet
|
# Create generator but don't start it yet
|
||||||
generator = stream_audio_chunks(tts_service, request, client_request)
|
generator = stream_audio_chunks(tts_service, request, client_request)
|
||||||
|
|
||||||
# Test the generator by attempting to get first chunk
|
# If download link requested, wrap generator with temp file writer
|
||||||
try:
|
if request.return_download_link:
|
||||||
first_chunk = await anext(generator)
|
from ..services.temp_manager import TempFileWriter
|
||||||
except StopAsyncIteration:
|
|
||||||
first_chunk = b"" # Empty audio case
|
temp_writer = TempFileWriter(request.response_format)
|
||||||
except Exception as e:
|
await temp_writer.__aenter__() # Initialize temp file
|
||||||
# Re-raise any errors to be caught by the outer try-except
|
|
||||||
raise RuntimeError(f"Failed to initialize audio stream: {str(e)}") from e
|
# Create response headers
|
||||||
|
headers = {
|
||||||
|
"Content-Disposition": f"attachment; filename=speech.{request.response_format}",
|
||||||
|
"X-Accel-Buffering": "no",
|
||||||
|
"Cache-Control": "no-cache",
|
||||||
|
"Transfer-Encoding": "chunked"
|
||||||
|
}
|
||||||
|
|
||||||
|
# Create async generator for streaming
|
||||||
|
async def dual_output():
|
||||||
|
try:
|
||||||
|
# Write chunks to temp file and stream
|
||||||
|
async for chunk in generator:
|
||||||
|
if chunk: # Skip empty chunks
|
||||||
|
await temp_writer.write(chunk)
|
||||||
|
yield chunk
|
||||||
|
|
||||||
|
# Get download path and add to headers
|
||||||
|
download_path = await temp_writer.finalize()
|
||||||
|
headers["X-Download-Path"] = download_path
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error in dual output streaming: {e}")
|
||||||
|
await temp_writer.__aexit__(type(e), e, e.__traceback__)
|
||||||
|
raise
|
||||||
|
finally:
|
||||||
|
# Ensure temp writer is closed
|
||||||
|
if not temp_writer._finalized:
|
||||||
|
await temp_writer.__aexit__(None, None, None)
|
||||||
|
|
||||||
|
# Stream with temp file writing
|
||||||
|
return StreamingResponse(
|
||||||
|
dual_output(),
|
||||||
|
media_type=content_type,
|
||||||
|
headers=headers
|
||||||
|
)
|
||||||
|
|
||||||
# If we got here, streaming can begin
|
# Standard streaming without download link
|
||||||
async def safe_stream():
|
|
||||||
yield first_chunk
|
|
||||||
try:
|
|
||||||
async for chunk in generator:
|
|
||||||
yield chunk
|
|
||||||
except Exception as e:
|
|
||||||
# Log the error but don't yield anything - the connection will close
|
|
||||||
logger.error(f"Error during streaming: {str(e)}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
# Stream audio chunks as they're generated
|
|
||||||
return StreamingResponse(
|
return StreamingResponse(
|
||||||
safe_stream(),
|
generator,
|
||||||
media_type=content_type,
|
media_type=content_type,
|
||||||
headers={
|
headers={
|
||||||
"Content-Disposition": f"attachment; filename=speech.{request.response_format}",
|
"Content-Disposition": f"attachment; filename=speech.{request.response_format}",
|
||||||
"X-Accel-Buffering": "no", # Disable proxy buffering
|
"X-Accel-Buffering": "no",
|
||||||
"Cache-Control": "no-cache", # Prevent caching
|
"Cache-Control": "no-cache",
|
||||||
"Transfer-Encoding": "chunked", # Enable chunked transfer encoding
|
"Transfer-Encoding": "chunked"
|
||||||
},
|
}
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# Generate complete audio using public interface
|
# Generate complete audio using public interface
|
||||||
|
@ -268,6 +291,43 @@ async def create_speech(
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/download/{filename}")
async def download_audio_file(filename: str):
    """Serve a previously generated audio file out of temp storage.

    The file is looked up by name in ``settings.temp_file_dir`` and returned
    as an attachment. Any failure while locating or serving it is logged and
    reported to the client as an HTTP 500 ``server_error``.
    """
    try:
        # Imported inside the try block, so an import failure is also
        # turned into the 500 response below.
        from ..core.paths import _find_file, get_content_type

        # Locate the file in the temp directory
        located = await _find_file(
            filename=filename,
            search_paths=[settings.temp_file_dir],
        )

        # Media type is derived from the resolved path
        media_type = await get_content_type(located)

        response_headers = {
            "Cache-Control": "no-cache",
            "Content-Disposition": f"attachment; filename={filename}",
        }
        return FileResponse(
            located,
            media_type=media_type,
            filename=filename,
            headers=response_headers,
        )

    except Exception as e:
        logger.error(f"Error serving download file {filename}: {e}")
        error_detail = {
            "error": "server_error",
            "message": "Failed to serve audio file",
            "type": "server_error",
        }
        raise HTTPException(status_code=500, detail=error_detail)
|
||||||
|
|
||||||
|
|
||||||
@router.get("/audio/voices")
|
@router.get("/audio/voices")
|
||||||
async def list_voices():
|
async def list_voices():
|
||||||
"""List all available voices for text-to-speech"""
|
"""List all available voices for text-to-speech"""
|
||||||
|
|
|
@ -108,15 +108,27 @@ class StreamingAudioWriter:
|
||||||
"aac": {"format": "adts", "codec": "aac"}
|
"aac": {"format": "adts", "codec": "aac"}
|
||||||
}[self.format]
|
}[self.format]
|
||||||
|
|
||||||
|
# On finalization, include proper headers and duration metadata
|
||||||
|
parameters = [
|
||||||
|
"-q:a", "2",
|
||||||
|
"-write_xing", "1" if self.format == "mp3" else "0", # XING header for MP3 only
|
||||||
|
"-metadata", f"duration={self.total_duration/1000}", # Duration in seconds
|
||||||
|
"-write_id3v1", "1" if self.format == "mp3" else "0", # ID3v1 tag for MP3
|
||||||
|
"-write_id3v2", "1" if self.format == "mp3" else "0" # ID3v2 tag for MP3
|
||||||
|
]
|
||||||
|
|
||||||
|
if self.format == "mp3":
|
||||||
|
# For MP3, ensure proper VBR headers
|
||||||
|
parameters.extend([
|
||||||
|
"-write_vbr", "1",
|
||||||
|
"-vbr_quality", "2"
|
||||||
|
])
|
||||||
|
|
||||||
self.encoder.export(
|
self.encoder.export(
|
||||||
output_buffer,
|
output_buffer,
|
||||||
**format_args,
|
**format_args,
|
||||||
bitrate="192k",
|
bitrate="192k",
|
||||||
parameters=[
|
parameters=parameters
|
||||||
"-q:a", "2",
|
|
||||||
"-write_xing", "1" if self.format == "mp3" else "0", # XING header for MP3 only
|
|
||||||
"-metadata", f"duration={self.total_duration/1000}" # Duration in seconds
|
|
||||||
]
|
|
||||||
)
|
)
|
||||||
self.encoder = None
|
self.encoder = None
|
||||||
return output_buffer.getvalue()
|
return output_buffer.getvalue()
|
||||||
|
@ -163,10 +175,10 @@ class StreamingAudioWriter:
|
||||||
"aac": {"format": "adts", "codec": "aac"}
|
"aac": {"format": "adts", "codec": "aac"}
|
||||||
}[self.format]
|
}[self.format]
|
||||||
|
|
||||||
|
# For chunks, export without duration metadata or XING headers
|
||||||
self.encoder.export(output_buffer, **format_args, bitrate="192k", parameters=[
|
self.encoder.export(output_buffer, **format_args, bitrate="192k", parameters=[
|
||||||
"-q:a", "2",
|
"-q:a", "2",
|
||||||
"-write_xing", "1" if self.format == "mp3" else "0", # XING header for MP3 only
|
"-write_xing", "0" # No XING headers for chunks
|
||||||
"-metadata", f"duration={self.total_duration/1000}" # Duration in seconds
|
|
||||||
])
|
])
|
||||||
|
|
||||||
# Get the encoded data
|
# Get the encoded data
|
||||||
|
|
89
api/src/services/temp_manager.py
Normal file
89
api/src/services/temp_manager.py
Normal file
|
@ -0,0 +1,89 @@
|
||||||
|
"""Temporary file writer for audio downloads"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import tempfile
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import aiofiles
|
||||||
|
from fastapi import HTTPException
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
|
from ..core.config import settings
|
||||||
|
from ..core.paths import _scan_directories
|
||||||
|
|
||||||
|
|
||||||
|
class TempFileWriter:
    """Handles writing audio chunks to a temp file.

    Intended lifecycle: enter the writer (``__aenter__``) to create and open
    the temp file, call ``write`` for each chunk, then ``finalize`` to close
    the file and obtain its download path. ``__aexit__`` closes the file if
    it has not been finalized yet.
    """

    def __init__(self, format: str):
        """Initialize temp file writer

        Args:
            format: Audio format extension (mp3, wav, etc)
        """
        self.format = format
        self.temp_file = None  # async file handle, set by __aenter__
        self.temp_path = None  # path of the temp file, set by __aenter__
        self._finalized = False

    async def __aenter__(self):
        """Async context manager entry

        Creates the temp directory if needed, refuses to proceed when the
        directory already holds ``settings.max_temp_dir_size_mb`` or more,
        then creates and opens a uniquely named temp file.

        Raises:
            HTTPException: 507 when the temp directory is at or over its
                configured size limit.
        """
        # `aiofiles.os` is a submodule; import it explicitly instead of
        # relying on some other module having loaded it already.
        import aiofiles.os

        # Create the directory first so the scan below never runs against a
        # missing directory, and use the async wrapper to avoid blocking.
        await aiofiles.os.makedirs(settings.temp_file_dir, exist_ok=True)

        # Check temp dir size by scanning
        total_size = 0
        entries = await _scan_directories([settings.temp_file_dir])
        for entry in entries:
            stat = await aiofiles.os.stat(os.path.join(settings.temp_file_dir, entry))
            total_size += stat.st_size

        if total_size >= settings.max_temp_dir_size_mb * 1024 * 1024:
            raise HTTPException(
                status_code=507,
                detail="Temporary storage full. Please try again later."
            )

        # Reserve a unique file name with the proper extension. The sync
        # handle is closed before the async one is opened, so it cannot leak
        # if the async open fails and the file is never held open twice.
        with tempfile.NamedTemporaryFile(
            dir=settings.temp_file_dir,
            delete=False,
            suffix=f".{self.format}",
            mode='wb'
        ) as temp:
            self.temp_path = temp.name

        self.temp_file = await aiofiles.open(self.temp_path, mode='wb')
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Async context manager exit

        Closes the temp file if it is open and not yet finalized. Errors
        while closing are logged, not raised; exceptions from the managed
        block are never suppressed.
        """
        try:
            if self.temp_file and not self._finalized:
                await self.temp_file.close()
                self._finalized = True
        except Exception as e:
            logger.error(f"Error closing temp file: {e}")

    async def write(self, chunk: bytes) -> None:
        """Write a chunk of audio data

        Args:
            chunk: Audio data bytes to write

        Raises:
            RuntimeError: If the writer is already finalized or was never
                opened.
        """
        if self._finalized:
            raise RuntimeError("Cannot write to finalized temp file")
        if self.temp_file is None:
            raise RuntimeError("Temp file is not open")

        await self.temp_file.write(chunk)
        # Flush after every chunk so the data reaches the file as it streams
        await self.temp_file.flush()

    async def finalize(self) -> str:
        """Close temp file and return download path

        Returns:
            Path to use for downloading the temp file

        Raises:
            RuntimeError: If the writer is already finalized or was never
                opened.
        """
        if self._finalized:
            raise RuntimeError("Temp file already finalized")
        if self.temp_file is None:
            raise RuntimeError("Temp file is not open")

        await self.temp_file.close()
        self._finalized = True

        return f"/download/{os.path.basename(self.temp_path)}"
|
|
@ -8,10 +8,10 @@ from .phonemizer import phonemize
|
||||||
from .normalizer import normalize_text
|
from .normalizer import normalize_text
|
||||||
from .vocabulary import tokenize
|
from .vocabulary import tokenize
|
||||||
|
|
||||||
# Constants for chunk size optimization
|
# Target token ranges
|
||||||
TARGET_MIN_TOKENS = 300
|
TARGET_MIN = 200
|
||||||
TARGET_MAX_TOKENS = 400
|
TARGET_MAX = 350
|
||||||
ABSOLUTE_MAX_TOKENS = 500
|
ABSOLUTE_MAX = 500
|
||||||
|
|
||||||
def process_text_chunk(text: str, language: str = "a") -> List[int]:
|
def process_text_chunk(text: str, language: str = "a") -> List[int]:
|
||||||
"""Process a chunk of text through normalization, phonemization, and tokenization.
|
"""Process a chunk of text through normalization, phonemization, and tokenization.
|
||||||
|
@ -48,10 +48,6 @@ def process_text_chunk(text: str, language: str = "a") -> List[int]:
|
||||||
|
|
||||||
return tokens
|
return tokens
|
||||||
|
|
||||||
def is_chunk_size_optimal(token_count: int) -> bool:
|
|
||||||
"""Check if chunk size is within optimal range."""
|
|
||||||
return TARGET_MIN_TOKENS <= token_count <= TARGET_MAX_TOKENS
|
|
||||||
|
|
||||||
async def yield_chunk(text: str, tokens: List[int], chunk_count: int) -> Tuple[str, List[int]]:
|
async def yield_chunk(text: str, tokens: List[int], chunk_count: int) -> Tuple[str, List[int]]:
|
||||||
"""Yield a chunk with consistent logging."""
|
"""Yield a chunk with consistent logging."""
|
||||||
logger.info(f"Yielding chunk {chunk_count}: '{text[:50]}...' ({len(tokens)} tokens)")
|
logger.info(f"Yielding chunk {chunk_count}: '{text[:50]}...' ({len(tokens)} tokens)")
|
||||||
|
@ -76,10 +72,6 @@ def process_text(text: str, language: str = "a") -> List[int]:
|
||||||
|
|
||||||
return process_text_chunk(text, language)
|
return process_text_chunk(text, language)
|
||||||
|
|
||||||
# Target token ranges
|
|
||||||
TARGET_MIN = 300
|
|
||||||
TARGET_MAX = 400
|
|
||||||
ABSOLUTE_MAX = 500
|
|
||||||
|
|
||||||
def get_sentence_info(text: str) -> List[Tuple[str, List[int], int]]:
|
def get_sentence_info(text: str) -> List[Tuple[str, List[int], int]]:
|
||||||
"""Process all sentences and return info."""
|
"""Process all sentences and return info."""
|
||||||
|
@ -166,13 +158,23 @@ async def smart_split(text: str, max_tokens: int = ABSOLUTE_MAX) -> AsyncGenerat
|
||||||
yield chunk_text, clause_tokens
|
yield chunk_text, clause_tokens
|
||||||
|
|
||||||
# Regular sentence handling
|
# Regular sentence handling
|
||||||
|
elif current_count >= TARGET_MIN and current_count + count > TARGET_MAX:
|
||||||
|
# If we have a good sized chunk and adding next sentence exceeds target,
|
||||||
|
# yield current chunk and start new one
|
||||||
|
chunk_text = " ".join(current_chunk)
|
||||||
|
chunk_count += 1
|
||||||
|
logger.info(f"Yielding chunk {chunk_count}: '{chunk_text[:50]}...' ({current_count} tokens)")
|
||||||
|
yield chunk_text, current_tokens
|
||||||
|
current_chunk = [sentence]
|
||||||
|
current_tokens = tokens
|
||||||
|
current_count = count
|
||||||
elif current_count + count <= TARGET_MAX:
|
elif current_count + count <= TARGET_MAX:
|
||||||
# Keep building chunk while under target max
|
# Keep building chunk while under target max
|
||||||
current_chunk.append(sentence)
|
current_chunk.append(sentence)
|
||||||
current_tokens.extend(tokens)
|
current_tokens.extend(tokens)
|
||||||
current_count += count
|
current_count += count
|
||||||
elif current_count + count <= max_tokens:
|
elif current_count + count <= max_tokens and current_count < TARGET_MIN:
|
||||||
# Accept slightly larger chunk if needed
|
# Only exceed target max if we haven't reached minimum size yet
|
||||||
current_chunk.append(sentence)
|
current_chunk.append(sentence)
|
||||||
current_tokens.extend(tokens)
|
current_tokens.extend(tokens)
|
||||||
current_count += count
|
current_count += count
|
||||||
|
|
|
@ -46,3 +46,7 @@ class OpenAISpeechRequest(BaseModel):
|
||||||
default=True, # Default to streaming for OpenAI compatibility
|
default=True, # Default to streaming for OpenAI compatibility
|
||||||
description="If true (default), audio will be streamed as it's generated. Each chunk will be a complete sentence.",
|
description="If true (default), audio will be streamed as it's generated. Each chunk will be a complete sentence.",
|
||||||
)
|
)
|
||||||
|
return_download_link: bool = Field(
|
||||||
|
default=False,
|
||||||
|
description="If true, returns a download link in X-Download-Path header after streaming completes",
|
||||||
|
)
|
||||||
|
|
|
@ -77,9 +77,19 @@ export class App {
|
||||||
// Handle completion
|
// Handle completion
|
||||||
this.audioService.addEventListener('complete', () => {
|
this.audioService.addEventListener('complete', () => {
|
||||||
this.setGenerating(false);
|
this.setGenerating(false);
|
||||||
|
this.showStatus('Preparing file...', 'info');
|
||||||
|
});
|
||||||
|
|
||||||
|
// Handle download ready
|
||||||
|
this.audioService.addEventListener('downloadReady', () => {
|
||||||
this.showStatus('Generation complete', 'success');
|
this.showStatus('Generation complete', 'success');
|
||||||
});
|
});
|
||||||
|
|
||||||
|
// Handle audio end
|
||||||
|
this.audioService.addEventListener('ended', () => {
|
||||||
|
this.playerState.setPlaying(false);
|
||||||
|
});
|
||||||
|
|
||||||
// Handle errors
|
// Handle errors
|
||||||
this.audioService.addEventListener('error', (error) => {
|
this.audioService.addEventListener('error', (error) => {
|
||||||
this.showStatus('Error: ' + error.message, 'error');
|
this.showStatus('Error: ' + error.message, 'error');
|
||||||
|
|
|
@ -9,7 +9,8 @@ export class AudioService {
|
||||||
this.minimumPlaybackSize = 50000; // 50KB minimum before playback
|
this.minimumPlaybackSize = 50000; // 50KB minimum before playback
|
||||||
this.textLength = 0;
|
this.textLength = 0;
|
||||||
this.shouldAutoplay = false;
|
this.shouldAutoplay = false;
|
||||||
this.CHARS_PER_CHUNK = 600; // Estimated chars per chunk
|
this.CHARS_PER_CHUNK = 300; // Estimated chars per chunk
|
||||||
|
this.serverDownloadPath = null; // Server-side download path
|
||||||
}
|
}
|
||||||
|
|
||||||
async streamAudio(text, voice, speed, onProgress) {
|
async streamAudio(text, voice, speed, onProgress) {
|
||||||
|
@ -45,7 +46,8 @@ export class AudioService {
|
||||||
voice: voice,
|
voice: voice,
|
||||||
response_format: 'mp3',
|
response_format: 'mp3',
|
||||||
stream: true,
|
stream: true,
|
||||||
speed: speed
|
speed: speed,
|
||||||
|
return_download_link: true
|
||||||
}),
|
}),
|
||||||
signal: this.controller.signal
|
signal: this.controller.signal
|
||||||
});
|
});
|
||||||
|
@ -58,7 +60,7 @@ export class AudioService {
|
||||||
throw new Error(error.detail?.message || 'Failed to generate speech');
|
throw new Error(error.detail?.message || 'Failed to generate speech');
|
||||||
}
|
}
|
||||||
|
|
||||||
await this.setupAudioStream(response, onProgress, estimatedChunks);
|
await this.setupAudioStream(response.body, response, onProgress, estimatedChunks);
|
||||||
return this.audio;
|
return this.audio;
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
this.cleanup();
|
this.cleanup();
|
||||||
|
@ -66,16 +68,21 @@ export class AudioService {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
async setupAudioStream(response, onProgress, estimatedTotalSize) {
|
async setupAudioStream(stream, response, onProgress, estimatedTotalSize) {
|
||||||
this.audio = new Audio();
|
this.audio = new Audio();
|
||||||
this.mediaSource = new MediaSource();
|
this.mediaSource = new MediaSource();
|
||||||
this.audio.src = URL.createObjectURL(this.mediaSource);
|
this.audio.src = URL.createObjectURL(this.mediaSource);
|
||||||
|
|
||||||
|
// Set up ended event handler
|
||||||
|
this.audio.addEventListener('ended', () => {
|
||||||
|
this.dispatchEvent('ended');
|
||||||
|
});
|
||||||
|
|
||||||
return new Promise((resolve, reject) => {
|
return new Promise((resolve, reject) => {
|
||||||
this.mediaSource.addEventListener('sourceopen', async () => {
|
this.mediaSource.addEventListener('sourceopen', async () => {
|
||||||
try {
|
try {
|
||||||
this.sourceBuffer = this.mediaSource.addSourceBuffer('audio/mpeg');
|
this.sourceBuffer = this.mediaSource.addSourceBuffer('audio/mpeg');
|
||||||
await this.processStream(response.body, onProgress, estimatedTotalSize);
|
await this.processStream(stream, response, onProgress, estimatedTotalSize);
|
||||||
resolve();
|
resolve();
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
reject(error);
|
reject(error);
|
||||||
|
@ -84,11 +91,17 @@ export class AudioService {
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
async processStream(stream, onProgress, estimatedChunks) {
|
async processStream(stream, response, onProgress, estimatedChunks) {
|
||||||
const reader = stream.getReader();
|
const reader = stream.getReader();
|
||||||
let hasStartedPlaying = false;
|
let hasStartedPlaying = false;
|
||||||
let receivedChunks = 0;
|
let receivedChunks = 0;
|
||||||
|
|
||||||
|
// Check for download path in response headers
|
||||||
|
const downloadPath = response.headers.get('X-Download-Path');
|
||||||
|
if (downloadPath) {
|
||||||
|
this.serverDownloadPath = downloadPath;
|
||||||
|
}
|
||||||
|
|
||||||
try {
|
try {
|
||||||
while (true) {
|
while (true) {
|
||||||
const {value, done} = await reader.read();
|
const {value, done} = await reader.read();
|
||||||
|
@ -245,6 +258,7 @@ export class AudioService {
|
||||||
this.sourceBuffer = null;
|
this.sourceBuffer = null;
|
||||||
this.chunks = [];
|
this.chunks = [];
|
||||||
this.textLength = 0;
|
this.textLength = 0;
|
||||||
|
this.serverDownloadPath = null;
|
||||||
|
|
||||||
// Force a hard refresh of the page to ensure clean state
|
// Force a hard refresh of the page to ensure clean state
|
||||||
window.location.reload();
|
window.location.reload();
|
||||||
|
@ -277,17 +291,25 @@ export class AudioService {
|
||||||
this.sourceBuffer = null;
|
this.sourceBuffer = null;
|
||||||
this.chunks = [];
|
this.chunks = [];
|
||||||
this.textLength = 0;
|
this.textLength = 0;
|
||||||
|
this.serverDownloadPath = null;
|
||||||
}
|
}
|
||||||
|
getDownloadUrl() {
|
||||||
getDownloadUrl() {
|
// Check for server-side download link first
|
||||||
if (!this.audio || !this.sourceBuffer || this.chunks.length === 0) return null;
|
const downloadPath = this.serverDownloadPath;
|
||||||
|
if (downloadPath) {
|
||||||
// Get the buffered data from MediaSource
|
return downloadPath;
|
||||||
const buffered = this.sourceBuffer.buffered;
|
}
|
||||||
if (buffered.length === 0) return null;
|
|
||||||
|
// Fall back to client-side blob URL
|
||||||
// Create blob from the original chunks
|
if (!this.audio || !this.sourceBuffer || this.chunks.length === 0) return null;
|
||||||
const blob = new Blob(this.chunks, { type: 'audio/mpeg' });
|
|
||||||
|
// Get the buffered data from MediaSource
|
||||||
|
const buffered = this.sourceBuffer.buffered;
|
||||||
|
if (buffered.length === 0) return null;
|
||||||
|
|
||||||
|
// Create blob from the original chunks
|
||||||
|
const blob = new Blob(this.chunks, { type: 'audio/mpeg' });
|
||||||
|
return URL.createObjectURL(blob);
|
||||||
return URL.createObjectURL(blob);
|
return URL.createObjectURL(blob);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Reference in a new issue