mirror of
https://github.com/remsky/Kokoro-FastAPI.git
synced 2025-04-13 09:39:17 +00:00
Implement temporary file management on openai endpoint, whole file downloads
This commit is contained in:
parent
355ec54f78
commit
946e322242
10 changed files with 346 additions and 64 deletions
|
@ -32,6 +32,11 @@ class Settings(BaseSettings):
|
|||
cors_origins: list[str] = ["*"] # CORS origins for web player
|
||||
cors_enabled: bool = True # Whether to enable CORS
|
||||
|
||||
# Temp File Settings
|
||||
temp_file_dir: str = "api/temp_files" # Directory for temporary audio files (relative to project root)
|
||||
max_temp_dir_size_mb: int = 2048 # Maximum size of temp directory (2GB)
|
||||
temp_file_max_age_hours: int = 1 # Remove temp files older than 1 hour
|
||||
|
||||
class Config:
|
||||
env_file = ".env"
|
||||
|
||||
|
|
|
@ -338,3 +338,77 @@ async def get_content_type(path: str) -> str:
|
|||
async def verify_model_path(model_path: str) -> bool:
|
||||
"""Verify model file exists at path."""
|
||||
return await aiofiles.os.path.exists(model_path)
|
||||
|
||||
|
||||
async def cleanup_temp_files() -> None:
|
||||
"""Clean up old temp files on startup"""
|
||||
try:
|
||||
if not await aiofiles.os.path.exists(settings.temp_file_dir):
|
||||
await aiofiles.os.makedirs(settings.temp_file_dir, exist_ok=True)
|
||||
return
|
||||
|
||||
entries = await aiofiles.os.scandir(settings.temp_file_dir)
|
||||
for entry in entries:
|
||||
if entry.is_file():
|
||||
stat = await aiofiles.os.stat(entry.path)
|
||||
max_age = stat.st_mtime + (settings.temp_file_max_age_hours * 3600)
|
||||
if max_age < stat.st_mtime:
|
||||
try:
|
||||
await aiofiles.os.remove(entry.path)
|
||||
logger.info(f"Cleaned up old temp file: {entry.name}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to delete old temp file {entry.name}: {e}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Error cleaning temp files: {e}")
|
||||
|
||||
|
||||
async def get_temp_file_path(filename: str) -> str:
|
||||
"""Get path to temporary audio file.
|
||||
|
||||
Args:
|
||||
filename: Name of temp file
|
||||
|
||||
Returns:
|
||||
Absolute path to temp file
|
||||
|
||||
Raises:
|
||||
RuntimeError: If temp directory does not exist
|
||||
"""
|
||||
temp_path = os.path.join(settings.temp_file_dir, filename)
|
||||
|
||||
# Ensure temp directory exists
|
||||
if not await aiofiles.os.path.exists(settings.temp_file_dir):
|
||||
await aiofiles.os.makedirs(settings.temp_file_dir, exist_ok=True)
|
||||
|
||||
return temp_path
|
||||
|
||||
|
||||
async def list_temp_files() -> List[str]:
|
||||
"""List temporary audio files.
|
||||
|
||||
Returns:
|
||||
List of temp file names
|
||||
"""
|
||||
if not await aiofiles.os.path.exists(settings.temp_file_dir):
|
||||
return []
|
||||
|
||||
entries = await aiofiles.os.scandir(settings.temp_file_dir)
|
||||
return [entry.name for entry in entries if entry.is_file()]
|
||||
|
||||
|
||||
async def get_temp_dir_size() -> int:
|
||||
"""Get total size of temp directory in bytes.
|
||||
|
||||
Returns:
|
||||
Size in bytes
|
||||
"""
|
||||
if not await aiofiles.os.path.exists(settings.temp_file_dir):
|
||||
return 0
|
||||
|
||||
total = 0
|
||||
entries = await aiofiles.os.scandir(settings.temp_file_dir)
|
||||
for entry in entries:
|
||||
if entry.is_file():
|
||||
stat = await aiofiles.os.stat(entry.path)
|
||||
total += stat.st_size
|
||||
return total
|
|
@ -48,6 +48,10 @@ async def lifespan(app: FastAPI):
|
|||
"""Lifespan context manager for model initialization"""
|
||||
from .inference.model_manager import get_manager
|
||||
from .inference.voice_manager import get_manager as get_voice_manager
|
||||
from .core.paths import cleanup_temp_files
|
||||
|
||||
# Clean old temp files on startup
|
||||
await cleanup_temp_files()
|
||||
|
||||
logger.info("Loading TTS model and voice packs...")
|
||||
|
||||
|
|
|
@ -5,7 +5,7 @@ import os
|
|||
from typing import AsyncGenerator, Dict, List, Union
|
||||
|
||||
from fastapi import APIRouter, Depends, Header, HTTPException, Request, Response
|
||||
from fastapi.responses import StreamingResponse
|
||||
from fastapi.responses import StreamingResponse, FileResponse
|
||||
from loguru import logger
|
||||
|
||||
from ..services.audio import AudioService
|
||||
|
@ -179,36 +179,59 @@ async def create_speech(
|
|||
# Create generator but don't start it yet
|
||||
generator = stream_audio_chunks(tts_service, request, client_request)
|
||||
|
||||
# Test the generator by attempting to get first chunk
|
||||
try:
|
||||
first_chunk = await anext(generator)
|
||||
except StopAsyncIteration:
|
||||
first_chunk = b"" # Empty audio case
|
||||
except Exception as e:
|
||||
# Re-raise any errors to be caught by the outer try-except
|
||||
raise RuntimeError(f"Failed to initialize audio stream: {str(e)}") from e
|
||||
# If download link requested, wrap generator with temp file writer
|
||||
if request.return_download_link:
|
||||
from ..services.temp_manager import TempFileWriter
|
||||
|
||||
# If we got here, streaming can begin
|
||||
async def safe_stream():
|
||||
yield first_chunk
|
||||
temp_writer = TempFileWriter(request.response_format)
|
||||
await temp_writer.__aenter__() # Initialize temp file
|
||||
|
||||
# Create response headers
|
||||
headers = {
|
||||
"Content-Disposition": f"attachment; filename=speech.{request.response_format}",
|
||||
"X-Accel-Buffering": "no",
|
||||
"Cache-Control": "no-cache",
|
||||
"Transfer-Encoding": "chunked"
|
||||
}
|
||||
|
||||
# Create async generator for streaming
|
||||
async def dual_output():
|
||||
try:
|
||||
# Write chunks to temp file and stream
|
||||
async for chunk in generator:
|
||||
if chunk: # Skip empty chunks
|
||||
await temp_writer.write(chunk)
|
||||
yield chunk
|
||||
except Exception as e:
|
||||
# Log the error but don't yield anything - the connection will close
|
||||
logger.error(f"Error during streaming: {str(e)}")
|
||||
raise
|
||||
|
||||
# Stream audio chunks as they're generated
|
||||
# Get download path and add to headers
|
||||
download_path = await temp_writer.finalize()
|
||||
headers["X-Download-Path"] = download_path
|
||||
except Exception as e:
|
||||
logger.error(f"Error in dual output streaming: {e}")
|
||||
await temp_writer.__aexit__(type(e), e, e.__traceback__)
|
||||
raise
|
||||
finally:
|
||||
# Ensure temp writer is closed
|
||||
if not temp_writer._finalized:
|
||||
await temp_writer.__aexit__(None, None, None)
|
||||
|
||||
# Stream with temp file writing
|
||||
return StreamingResponse(
|
||||
safe_stream(),
|
||||
dual_output(),
|
||||
media_type=content_type,
|
||||
headers=headers
|
||||
)
|
||||
|
||||
# Standard streaming without download link
|
||||
return StreamingResponse(
|
||||
generator,
|
||||
media_type=content_type,
|
||||
headers={
|
||||
"Content-Disposition": f"attachment; filename=speech.{request.response_format}",
|
||||
"X-Accel-Buffering": "no", # Disable proxy buffering
|
||||
"Cache-Control": "no-cache", # Prevent caching
|
||||
"Transfer-Encoding": "chunked", # Enable chunked transfer encoding
|
||||
},
|
||||
"X-Accel-Buffering": "no",
|
||||
"Cache-Control": "no-cache",
|
||||
"Transfer-Encoding": "chunked"
|
||||
}
|
||||
)
|
||||
else:
|
||||
# Generate complete audio using public interface
|
||||
|
@ -268,6 +291,43 @@ async def create_speech(
|
|||
)
|
||||
|
||||
|
||||
@router.get("/download/{filename}")
|
||||
async def download_audio_file(filename: str):
|
||||
"""Download a generated audio file from temp storage"""
|
||||
try:
|
||||
from ..core.paths import _find_file, get_content_type
|
||||
|
||||
# Search for file in temp directory
|
||||
file_path = await _find_file(
|
||||
filename=filename,
|
||||
search_paths=[settings.temp_file_dir]
|
||||
)
|
||||
|
||||
# Get content type from path helper
|
||||
content_type = await get_content_type(file_path)
|
||||
|
||||
return FileResponse(
|
||||
file_path,
|
||||
media_type=content_type,
|
||||
filename=filename,
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Content-Disposition": f"attachment; filename={filename}"
|
||||
}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error serving download file {filename}: {e}")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail={
|
||||
"error": "server_error",
|
||||
"message": "Failed to serve audio file",
|
||||
"type": "server_error"
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@router.get("/audio/voices")
|
||||
async def list_voices():
|
||||
"""List all available voices for text-to-speech"""
|
||||
|
|
|
@ -108,15 +108,27 @@ class StreamingAudioWriter:
|
|||
"aac": {"format": "adts", "codec": "aac"}
|
||||
}[self.format]
|
||||
|
||||
# On finalization, include proper headers and duration metadata
|
||||
parameters = [
|
||||
"-q:a", "2",
|
||||
"-write_xing", "1" if self.format == "mp3" else "0", # XING header for MP3 only
|
||||
"-metadata", f"duration={self.total_duration/1000}", # Duration in seconds
|
||||
"-write_id3v1", "1" if self.format == "mp3" else "0", # ID3v1 tag for MP3
|
||||
"-write_id3v2", "1" if self.format == "mp3" else "0" # ID3v2 tag for MP3
|
||||
]
|
||||
|
||||
if self.format == "mp3":
|
||||
# For MP3, ensure proper VBR headers
|
||||
parameters.extend([
|
||||
"-write_vbr", "1",
|
||||
"-vbr_quality", "2"
|
||||
])
|
||||
|
||||
self.encoder.export(
|
||||
output_buffer,
|
||||
**format_args,
|
||||
bitrate="192k",
|
||||
parameters=[
|
||||
"-q:a", "2",
|
||||
"-write_xing", "1" if self.format == "mp3" else "0", # XING header for MP3 only
|
||||
"-metadata", f"duration={self.total_duration/1000}" # Duration in seconds
|
||||
]
|
||||
parameters=parameters
|
||||
)
|
||||
self.encoder = None
|
||||
return output_buffer.getvalue()
|
||||
|
@ -163,10 +175,10 @@ class StreamingAudioWriter:
|
|||
"aac": {"format": "adts", "codec": "aac"}
|
||||
}[self.format]
|
||||
|
||||
# For chunks, export without duration metadata or XING headers
|
||||
self.encoder.export(output_buffer, **format_args, bitrate="192k", parameters=[
|
||||
"-q:a", "2",
|
||||
"-write_xing", "1" if self.format == "mp3" else "0", # XING header for MP3 only
|
||||
"-metadata", f"duration={self.total_duration/1000}" # Duration in seconds
|
||||
"-write_xing", "0" # No XING headers for chunks
|
||||
])
|
||||
|
||||
# Get the encoded data
|
||||
|
|
89
api/src/services/temp_manager.py
Normal file
89
api/src/services/temp_manager.py
Normal file
|
@ -0,0 +1,89 @@
|
|||
"""Temporary file writer for audio downloads"""
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
from typing import Optional
|
||||
|
||||
import aiofiles
|
||||
from fastapi import HTTPException
|
||||
from loguru import logger
|
||||
|
||||
from ..core.config import settings
|
||||
from ..core.paths import _scan_directories
|
||||
|
||||
|
||||
class TempFileWriter:
|
||||
"""Handles writing audio chunks to a temp file"""
|
||||
|
||||
def __init__(self, format: str):
|
||||
"""Initialize temp file writer
|
||||
|
||||
Args:
|
||||
format: Audio format extension (mp3, wav, etc)
|
||||
"""
|
||||
self.format = format
|
||||
self.temp_file = None
|
||||
self._finalized = False
|
||||
|
||||
async def __aenter__(self):
|
||||
"""Async context manager entry"""
|
||||
# Check temp dir size by scanning
|
||||
total_size = 0
|
||||
entries = await _scan_directories([settings.temp_file_dir])
|
||||
for entry in entries:
|
||||
stat = await aiofiles.os.stat(os.path.join(settings.temp_file_dir, entry))
|
||||
total_size += stat.st_size
|
||||
|
||||
if total_size >= settings.max_temp_dir_size_mb * 1024 * 1024:
|
||||
raise HTTPException(
|
||||
status_code=507,
|
||||
detail="Temporary storage full. Please try again later."
|
||||
)
|
||||
|
||||
# Create temp file with proper extension
|
||||
os.makedirs(settings.temp_file_dir, exist_ok=True)
|
||||
temp = tempfile.NamedTemporaryFile(
|
||||
dir=settings.temp_file_dir,
|
||||
delete=False,
|
||||
suffix=f".{self.format}",
|
||||
mode='wb'
|
||||
)
|
||||
self.temp_file = await aiofiles.open(temp.name, mode='wb')
|
||||
self.temp_path = temp.name
|
||||
temp.close() # Close sync file, we'll use async version
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
"""Async context manager exit"""
|
||||
try:
|
||||
if self.temp_file and not self._finalized:
|
||||
await self.temp_file.close()
|
||||
self._finalized = True
|
||||
except Exception as e:
|
||||
logger.error(f"Error closing temp file: {e}")
|
||||
|
||||
async def write(self, chunk: bytes) -> None:
|
||||
"""Write a chunk of audio data
|
||||
|
||||
Args:
|
||||
chunk: Audio data bytes to write
|
||||
"""
|
||||
if self._finalized:
|
||||
raise RuntimeError("Cannot write to finalized temp file")
|
||||
|
||||
await self.temp_file.write(chunk)
|
||||
await self.temp_file.flush()
|
||||
|
||||
async def finalize(self) -> str:
|
||||
"""Close temp file and return download path
|
||||
|
||||
Returns:
|
||||
Path to use for downloading the temp file
|
||||
"""
|
||||
if self._finalized:
|
||||
raise RuntimeError("Temp file already finalized")
|
||||
|
||||
await self.temp_file.close()
|
||||
self._finalized = True
|
||||
|
||||
return f"/download/{os.path.basename(self.temp_path)}"
|
|
@ -8,10 +8,10 @@ from .phonemizer import phonemize
|
|||
from .normalizer import normalize_text
|
||||
from .vocabulary import tokenize
|
||||
|
||||
# Constants for chunk size optimization
|
||||
TARGET_MIN_TOKENS = 300
|
||||
TARGET_MAX_TOKENS = 400
|
||||
ABSOLUTE_MAX_TOKENS = 500
|
||||
# Target token ranges
|
||||
TARGET_MIN = 200
|
||||
TARGET_MAX = 350
|
||||
ABSOLUTE_MAX = 500
|
||||
|
||||
def process_text_chunk(text: str, language: str = "a") -> List[int]:
|
||||
"""Process a chunk of text through normalization, phonemization, and tokenization.
|
||||
|
@ -48,10 +48,6 @@ def process_text_chunk(text: str, language: str = "a") -> List[int]:
|
|||
|
||||
return tokens
|
||||
|
||||
def is_chunk_size_optimal(token_count: int) -> bool:
|
||||
"""Check if chunk size is within optimal range."""
|
||||
return TARGET_MIN_TOKENS <= token_count <= TARGET_MAX_TOKENS
|
||||
|
||||
async def yield_chunk(text: str, tokens: List[int], chunk_count: int) -> Tuple[str, List[int]]:
|
||||
"""Yield a chunk with consistent logging."""
|
||||
logger.info(f"Yielding chunk {chunk_count}: '{text[:50]}...' ({len(tokens)} tokens)")
|
||||
|
@ -76,10 +72,6 @@ def process_text(text: str, language: str = "a") -> List[int]:
|
|||
|
||||
return process_text_chunk(text, language)
|
||||
|
||||
# Target token ranges
|
||||
TARGET_MIN = 300
|
||||
TARGET_MAX = 400
|
||||
ABSOLUTE_MAX = 500
|
||||
|
||||
def get_sentence_info(text: str) -> List[Tuple[str, List[int], int]]:
|
||||
"""Process all sentences and return info."""
|
||||
|
@ -166,13 +158,23 @@ async def smart_split(text: str, max_tokens: int = ABSOLUTE_MAX) -> AsyncGenerat
|
|||
yield chunk_text, clause_tokens
|
||||
|
||||
# Regular sentence handling
|
||||
elif current_count >= TARGET_MIN and current_count + count > TARGET_MAX:
|
||||
# If we have a good sized chunk and adding next sentence exceeds target,
|
||||
# yield current chunk and start new one
|
||||
chunk_text = " ".join(current_chunk)
|
||||
chunk_count += 1
|
||||
logger.info(f"Yielding chunk {chunk_count}: '{chunk_text[:50]}...' ({current_count} tokens)")
|
||||
yield chunk_text, current_tokens
|
||||
current_chunk = [sentence]
|
||||
current_tokens = tokens
|
||||
current_count = count
|
||||
elif current_count + count <= TARGET_MAX:
|
||||
# Keep building chunk while under target max
|
||||
current_chunk.append(sentence)
|
||||
current_tokens.extend(tokens)
|
||||
current_count += count
|
||||
elif current_count + count <= max_tokens:
|
||||
# Accept slightly larger chunk if needed
|
||||
elif current_count + count <= max_tokens and current_count < TARGET_MIN:
|
||||
# Only exceed target max if we haven't reached minimum size yet
|
||||
current_chunk.append(sentence)
|
||||
current_tokens.extend(tokens)
|
||||
current_count += count
|
||||
|
|
|
@ -46,3 +46,7 @@ class OpenAISpeechRequest(BaseModel):
|
|||
default=True, # Default to streaming for OpenAI compatibility
|
||||
description="If true (default), audio will be streamed as it's generated. Each chunk will be a complete sentence.",
|
||||
)
|
||||
return_download_link: bool = Field(
|
||||
default=False,
|
||||
description="If true, returns a download link in X-Download-Path header after streaming completes",
|
||||
)
|
||||
|
|
|
@ -77,9 +77,19 @@ export class App {
|
|||
// Handle completion
|
||||
this.audioService.addEventListener('complete', () => {
|
||||
this.setGenerating(false);
|
||||
this.showStatus('Preparing file...', 'info');
|
||||
});
|
||||
|
||||
// Handle download ready
|
||||
this.audioService.addEventListener('downloadReady', () => {
|
||||
this.showStatus('Generation complete', 'success');
|
||||
});
|
||||
|
||||
// Handle audio end
|
||||
this.audioService.addEventListener('ended', () => {
|
||||
this.playerState.setPlaying(false);
|
||||
});
|
||||
|
||||
// Handle errors
|
||||
this.audioService.addEventListener('error', (error) => {
|
||||
this.showStatus('Error: ' + error.message, 'error');
|
||||
|
|
|
@ -9,7 +9,8 @@ export class AudioService {
|
|||
this.minimumPlaybackSize = 50000; // 50KB minimum before playback
|
||||
this.textLength = 0;
|
||||
this.shouldAutoplay = false;
|
||||
this.CHARS_PER_CHUNK = 600; // Estimated chars per chunk
|
||||
this.CHARS_PER_CHUNK = 300; // Estimated chars per chunk
|
||||
this.serverDownloadPath = null; // Server-side download path
|
||||
}
|
||||
|
||||
async streamAudio(text, voice, speed, onProgress) {
|
||||
|
@ -45,7 +46,8 @@ export class AudioService {
|
|||
voice: voice,
|
||||
response_format: 'mp3',
|
||||
stream: true,
|
||||
speed: speed
|
||||
speed: speed,
|
||||
return_download_link: true
|
||||
}),
|
||||
signal: this.controller.signal
|
||||
});
|
||||
|
@ -58,7 +60,7 @@ export class AudioService {
|
|||
throw new Error(error.detail?.message || 'Failed to generate speech');
|
||||
}
|
||||
|
||||
await this.setupAudioStream(response, onProgress, estimatedChunks);
|
||||
await this.setupAudioStream(response.body, response, onProgress, estimatedChunks);
|
||||
return this.audio;
|
||||
} catch (error) {
|
||||
this.cleanup();
|
||||
|
@ -66,16 +68,21 @@ export class AudioService {
|
|||
}
|
||||
}
|
||||
|
||||
async setupAudioStream(response, onProgress, estimatedTotalSize) {
|
||||
async setupAudioStream(stream, response, onProgress, estimatedTotalSize) {
|
||||
this.audio = new Audio();
|
||||
this.mediaSource = new MediaSource();
|
||||
this.audio.src = URL.createObjectURL(this.mediaSource);
|
||||
|
||||
// Set up ended event handler
|
||||
this.audio.addEventListener('ended', () => {
|
||||
this.dispatchEvent('ended');
|
||||
});
|
||||
|
||||
return new Promise((resolve, reject) => {
|
||||
this.mediaSource.addEventListener('sourceopen', async () => {
|
||||
try {
|
||||
this.sourceBuffer = this.mediaSource.addSourceBuffer('audio/mpeg');
|
||||
await this.processStream(response.body, onProgress, estimatedTotalSize);
|
||||
await this.processStream(stream, response, onProgress, estimatedTotalSize);
|
||||
resolve();
|
||||
} catch (error) {
|
||||
reject(error);
|
||||
|
@ -84,11 +91,17 @@ export class AudioService {
|
|||
});
|
||||
}
|
||||
|
||||
async processStream(stream, onProgress, estimatedChunks) {
|
||||
async processStream(stream, response, onProgress, estimatedChunks) {
|
||||
const reader = stream.getReader();
|
||||
let hasStartedPlaying = false;
|
||||
let receivedChunks = 0;
|
||||
|
||||
// Check for download path in response headers
|
||||
const downloadPath = response.headers.get('X-Download-Path');
|
||||
if (downloadPath) {
|
||||
this.serverDownloadPath = downloadPath;
|
||||
}
|
||||
|
||||
try {
|
||||
while (true) {
|
||||
const {value, done} = await reader.read();
|
||||
|
@ -245,6 +258,7 @@ export class AudioService {
|
|||
this.sourceBuffer = null;
|
||||
this.chunks = [];
|
||||
this.textLength = 0;
|
||||
this.serverDownloadPath = null;
|
||||
|
||||
// Force a hard refresh of the page to ensure clean state
|
||||
window.location.reload();
|
||||
|
@ -277,9 +291,16 @@ export class AudioService {
|
|||
this.sourceBuffer = null;
|
||||
this.chunks = [];
|
||||
this.textLength = 0;
|
||||
this.serverDownloadPath = null;
|
||||
}
|
||||
getDownloadUrl() {
|
||||
// Check for server-side download link first
|
||||
const downloadPath = this.serverDownloadPath;
|
||||
if (downloadPath) {
|
||||
return downloadPath;
|
||||
}
|
||||
|
||||
getDownloadUrl() {
|
||||
// Fall back to client-side blob URL
|
||||
if (!this.audio || !this.sourceBuffer || this.chunks.length === 0) return null;
|
||||
|
||||
// Get the buffered data from MediaSource
|
||||
|
@ -288,6 +309,7 @@ export class AudioService {
|
|||
|
||||
// Create blob from the original chunks
|
||||
const blob = new Blob(this.chunks, { type: 'audio/mpeg' });
|
||||
return URL.createObjectURL(blob);
|
||||
return URL.createObjectURL(blob);
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Add table
Reference in a new issue