Implement temporary file management on the OpenAI endpoint, plus whole-file downloads

remsky 2025-01-29 04:09:38 -07:00
parent 355ec54f78
commit 946e322242
10 changed files with 346 additions and 64 deletions

View file

@@ -32,6 +32,11 @@ class Settings(BaseSettings):
    cors_origins: list[str] = ["*"]  # CORS origins for web player
    cors_enabled: bool = True  # Whether to enable CORS

    # Temp File Settings
    temp_file_dir: str = "api/temp_files"  # Directory for temporary audio files (relative to project root)
    max_temp_dir_size_mb: int = 2048  # Maximum size of temp directory (2GB)
    temp_file_max_age_hours: int = 1  # Remove temp files older than 1 hour

    class Config:
        env_file = ".env"
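Since these fields live on a pydantic BaseSettings class with env_file = ".env", they should be overridable through environment variables like the settings above them; a minimal sketch (import path and values assumed, not part of this commit):

    import os

    # Hypothetical overrides; pydantic BaseSettings matches env vars
    # to field names case-insensitively by default.
    os.environ["MAX_TEMP_DIR_SIZE_MB"] = "512"
    os.environ["TEMP_FILE_MAX_AGE_HOURS"] = "4"

    from api.src.core.config import settings  # import path assumed
    assert settings.max_temp_dir_size_mb == 512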

View file

@@ -337,4 +337,78 @@ async def get_content_type(path: str) -> str:
async def verify_model_path(model_path: str) -> bool:
    """Verify model file exists at path."""
    return await aiofiles.os.path.exists(model_path)


async def cleanup_temp_files() -> None:
    """Clean up old temp files on startup"""
    try:
        if not await aiofiles.os.path.exists(settings.temp_file_dir):
            await aiofiles.os.makedirs(settings.temp_file_dir, exist_ok=True)
            return

        entries = await aiofiles.os.scandir(settings.temp_file_dir)
        for entry in entries:
            if entry.is_file():
                stat = await aiofiles.os.stat(entry.path)
                # Remove files older than the configured max age; the comparison
                # must be against the current time (requires `import time` at the
                # top of the module), not against the file's own mtime
                max_age_seconds = settings.temp_file_max_age_hours * 3600
                if time.time() - stat.st_mtime > max_age_seconds:
                    try:
                        await aiofiles.os.remove(entry.path)
                        logger.info(f"Cleaned up old temp file: {entry.name}")
                    except Exception as e:
                        logger.warning(f"Failed to delete old temp file {entry.name}: {e}")
    except Exception as e:
        logger.warning(f"Error cleaning temp files: {e}")


async def get_temp_file_path(filename: str) -> str:
    """Get path to temporary audio file.

    Args:
        filename: Name of temp file

    Returns:
        Path to temp file (the temp directory is created if missing)
    """
    temp_path = os.path.join(settings.temp_file_dir, filename)

    # Ensure temp directory exists
    if not await aiofiles.os.path.exists(settings.temp_file_dir):
        await aiofiles.os.makedirs(settings.temp_file_dir, exist_ok=True)

    return temp_path


async def list_temp_files() -> List[str]:
    """List temporary audio files.

    Returns:
        List of temp file names
    """
    if not await aiofiles.os.path.exists(settings.temp_file_dir):
        return []

    entries = await aiofiles.os.scandir(settings.temp_file_dir)
    return [entry.name for entry in entries if entry.is_file()]


async def get_temp_dir_size() -> int:
    """Get total size of temp directory in bytes.

    Returns:
        Size in bytes
    """
    if not await aiofiles.os.path.exists(settings.temp_file_dir):
        return 0

    total = 0
    entries = await aiofiles.os.scandir(settings.temp_file_dir)
    for entry in entries:
        if entry.is_file():
            stat = await aiofiles.os.stat(entry.path)
            total += stat.st_size
    return total
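Taken together, these helpers cover everything the lifespan hook and the download route need; a minimal usage sketch, assuming the module path api.src.core.paths:

    import asyncio

    from api.src.core.paths import (  # module path assumed
        cleanup_temp_files,
        get_temp_dir_size,
        get_temp_file_path,
        list_temp_files,
    )

    async def main() -> None:
        await cleanup_temp_files()                     # prune stale files
        path = await get_temp_file_path("speech.mp3")  # resolve a temp path
        print(path, await list_temp_files(), await get_temp_dir_size())

    asyncio.run(main())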

View file

@@ -48,6 +48,10 @@ async def lifespan(app: FastAPI):
    """Lifespan context manager for model initialization"""
    from .inference.model_manager import get_manager
    from .inference.voice_manager import get_manager as get_voice_manager
    from .core.paths import cleanup_temp_files

    # Clean old temp files on startup
    await cleanup_temp_files()

    logger.info("Loading TTS model and voice packs...")

View file

@@ -5,7 +5,7 @@ import os
from typing import AsyncGenerator, Dict, List, Union

from fastapi import APIRouter, Depends, Header, HTTPException, Request, Response
-from fastapi.responses import StreamingResponse
+from fastapi.responses import StreamingResponse, FileResponse
from loguru import logger

from ..services.audio import AudioService
@@ -179,36 +179,59 @@ async def create_speech(
            # Create generator but don't start it yet
            generator = stream_audio_chunks(tts_service, request, client_request)

-           # Test the generator by attempting to get first chunk
-           try:
-               first_chunk = await anext(generator)
-           except StopAsyncIteration:
-               first_chunk = b""  # Empty audio case
-           except Exception as e:
-               # Re-raise any errors to be caught by the outer try-except
-               raise RuntimeError(f"Failed to initialize audio stream: {str(e)}") from e
-
-           # If we got here, streaming can begin
-           async def safe_stream():
-               yield first_chunk
-               try:
-                   async for chunk in generator:
-                       yield chunk
-               except Exception as e:
-                   # Log the error but don't yield anything - the connection will close
-                   logger.error(f"Error during streaming: {str(e)}")
-                   raise
-
-           # Stream audio chunks as they're generated
-           return StreamingResponse(
-               safe_stream(),
-               media_type=content_type,
-               headers={
-                   "Content-Disposition": f"attachment; filename=speech.{request.response_format}",
-                   "X-Accel-Buffering": "no",  # Disable proxy buffering
-                   "Cache-Control": "no-cache",  # Prevent caching
-                   "Transfer-Encoding": "chunked",  # Enable chunked transfer encoding
-               },
-           )
+           # If download link requested, wrap generator with temp file writer
+           if request.return_download_link:
+               from ..services.temp_manager import TempFileWriter
+
+               temp_writer = TempFileWriter(request.response_format)
+               await temp_writer.__aenter__()  # Initialize temp file
+
+               # Create response headers. The download path must be set here,
+               # before streaming begins: headers cannot be modified once the
+               # response body has started.
+               headers = {
+                   "Content-Disposition": f"attachment; filename=speech.{request.response_format}",
+                   "X-Accel-Buffering": "no",
+                   "Cache-Control": "no-cache",
+                   "Transfer-Encoding": "chunked",
+                   "X-Download-Path": f"/download/{os.path.basename(temp_writer.temp_path)}",
+               }
+
+               # Create async generator for streaming
+               async def dual_output():
+                   try:
+                       # Write chunks to temp file and stream
+                       async for chunk in generator:
+                           if chunk:  # Skip empty chunks
+                               await temp_writer.write(chunk)
+                               yield chunk
+
+                       # Flush and close the temp file
+                       await temp_writer.finalize()
+                   except Exception as e:
+                       logger.error(f"Error in dual output streaming: {e}")
+                       await temp_writer.__aexit__(type(e), e, e.__traceback__)
+                       raise
+                   finally:
+                       # Ensure temp writer is closed
+                       if not temp_writer._finalized:
+                           await temp_writer.__aexit__(None, None, None)
+
+               # Stream with temp file writing
+               return StreamingResponse(
+                   dual_output(),
+                   media_type=content_type,
+                   headers=headers
+               )
+
+           # Standard streaming without download link
+           return StreamingResponse(
+               generator,
+               media_type=content_type,
+               headers={
+                   "Content-Disposition": f"attachment; filename=speech.{request.response_format}",
+                   "X-Accel-Buffering": "no",
+                   "Cache-Control": "no-cache",
+                   "Transfer-Encoding": "chunked"
+               }
+           )
        else:
            # Generate complete audio using public interface
@@ -268,6 +291,43 @@ async def create_speech(
        )


@router.get("/download/{filename}")
async def download_audio_file(filename: str):
    """Download a generated audio file from temp storage"""
    try:
        from ..core.paths import _find_file, get_content_type

        # Search for file in temp directory
        file_path = await _find_file(
            filename=filename,
            search_paths=[settings.temp_file_dir]
        )

        # Get content type from path helper
        content_type = await get_content_type(file_path)

        return FileResponse(
            file_path,
            media_type=content_type,
            filename=filename,
            headers={
                "Cache-Control": "no-cache",
                "Content-Disposition": f"attachment; filename={filename}"
            }
        )
    except Exception as e:
        logger.error(f"Error serving download file {filename}: {e}")
        raise HTTPException(
            status_code=500,
            detail={
                "error": "server_error",
                "message": "Failed to serve audio file",
                "type": "server_error"
            }
        )


@router.get("/audio/voices")
async def list_voices():
    """List all available voices for text-to-speech"""

View file

@@ -108,15 +108,27 @@ class StreamingAudioWriter:
                "aac": {"format": "adts", "codec": "aac"}
            }[self.format]

+           # On finalization, include proper headers and duration metadata
+           parameters = [
+               "-q:a", "2",
+               "-write_xing", "1" if self.format == "mp3" else "0",   # XING header for MP3 only
+               "-metadata", f"duration={self.total_duration/1000}",   # Duration in seconds
+               "-write_id3v1", "1" if self.format == "mp3" else "0",  # ID3v1 tag for MP3
+               "-write_id3v2", "1" if self.format == "mp3" else "0"   # ID3v2 tag for MP3
+           ]
+
+           if self.format == "mp3":
+               # For MP3, ensure proper VBR headers
+               parameters.extend([
+                   "-write_vbr", "1",
+                   "-vbr_quality", "2"
+               ])
+
            self.encoder.export(
                output_buffer,
                **format_args,
                bitrate="192k",
-               parameters=[
-                   "-q:a", "2",
-                   "-write_xing", "1" if self.format == "mp3" else "0",  # XING header for MP3 only
-                   "-metadata", f"duration={self.total_duration/1000}"  # Duration in seconds
-               ]
+               parameters=parameters
            )
            self.encoder = None
            return output_buffer.getvalue()

@@ -163,10 +175,10 @@ class StreamingAudioWriter:
                "aac": {"format": "adts", "codec": "aac"}
            }[self.format]

+           # For chunks, export without duration metadata or XING headers
            self.encoder.export(output_buffer, **format_args, bitrate="192k", parameters=[
                "-q:a", "2",
-               "-write_xing", "1" if self.format == "mp3" else "0",  # XING header for MP3 only
-               "-metadata", f"duration={self.total_duration/1000}"  # Duration in seconds
+               "-write_xing", "0"  # No XING headers for chunks
            ])

            # Get the encoded data

View file

@@ -0,0 +1,89 @@
"""Temporary file writer for audio downloads"""

import os
import tempfile
from typing import Optional

import aiofiles
from fastapi import HTTPException
from loguru import logger

from ..core.config import settings
from ..core.paths import _scan_directories


class TempFileWriter:
    """Handles writing audio chunks to a temp file"""

    def __init__(self, format: str):
        """Initialize temp file writer

        Args:
            format: Audio format extension (mp3, wav, etc)
        """
        self.format = format
        self.temp_file = None
        self.temp_path: Optional[str] = None
        self._finalized = False

    async def __aenter__(self):
        """Async context manager entry"""
        # Check temp dir size by scanning
        total_size = 0
        entries = await _scan_directories([settings.temp_file_dir])
        for entry in entries:
            stat = await aiofiles.os.stat(os.path.join(settings.temp_file_dir, entry))
            total_size += stat.st_size
        if total_size >= settings.max_temp_dir_size_mb * 1024 * 1024:
            raise HTTPException(
                status_code=507,
                detail="Temporary storage full. Please try again later."
            )

        # Create temp file with proper extension
        os.makedirs(settings.temp_file_dir, exist_ok=True)
        temp = tempfile.NamedTemporaryFile(
            dir=settings.temp_file_dir,
            delete=False,
            suffix=f".{self.format}",
            mode='wb'
        )
        self.temp_path = temp.name
        temp.close()  # Close the sync handle first; all writes go through the async one
        self.temp_file = await aiofiles.open(self.temp_path, mode='wb')

        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Async context manager exit"""
        try:
            if self.temp_file and not self._finalized:
                await self.temp_file.close()
                self._finalized = True
        except Exception as e:
            logger.error(f"Error closing temp file: {e}")

    async def write(self, chunk: bytes) -> None:
        """Write a chunk of audio data

        Args:
            chunk: Audio data bytes to write
        """
        if self._finalized:
            raise RuntimeError("Cannot write to finalized temp file")

        await self.temp_file.write(chunk)
        await self.temp_file.flush()

    async def finalize(self) -> str:
        """Close temp file and return download path

        Returns:
            Path to use for downloading the temp file
        """
        if self._finalized:
            raise RuntimeError("Temp file already finalized")

        await self.temp_file.close()
        self._finalized = True

        return f"/download/{os.path.basename(self.temp_path)}"

View file

@@ -8,10 +8,10 @@ from .phonemizer import phonemize
from .normalizer import normalize_text
from .vocabulary import tokenize

-# Constants for chunk size optimization
-TARGET_MIN_TOKENS = 300
-TARGET_MAX_TOKENS = 400
-ABSOLUTE_MAX_TOKENS = 500
+# Target token ranges
+TARGET_MIN = 200
+TARGET_MAX = 350
+ABSOLUTE_MAX = 500

def process_text_chunk(text: str, language: str = "a") -> List[int]:
    """Process a chunk of text through normalization, phonemization, and tokenization.

@@ -48,10 +48,6 @@ def process_text_chunk(text: str, language: str = "a") -> List[int]:
    return tokens

-def is_chunk_size_optimal(token_count: int) -> bool:
-    """Check if chunk size is within optimal range."""
-    return TARGET_MIN_TOKENS <= token_count <= TARGET_MAX_TOKENS
-
async def yield_chunk(text: str, tokens: List[int], chunk_count: int) -> Tuple[str, List[int]]:
    """Yield a chunk with consistent logging."""
    logger.info(f"Yielding chunk {chunk_count}: '{text[:50]}...' ({len(tokens)} tokens)")

@@ -76,10 +72,6 @@ def process_text(text: str, language: str = "a") -> List[int]:
    return process_text_chunk(text, language)

-# Target token ranges
-TARGET_MIN = 300
-TARGET_MAX = 400
-ABSOLUTE_MAX = 500
-
def get_sentence_info(text: str) -> List[Tuple[str, List[int], int]]:
    """Process all sentences and return info."""

@@ -166,13 +158,23 @@ async def smart_split(text: str, max_tokens: int = ABSOLUTE_MAX) -> AsyncGenerat
                yield chunk_text, clause_tokens

        # Regular sentence handling
+       elif current_count >= TARGET_MIN and current_count + count > TARGET_MAX:
+           # If we have a good sized chunk and adding the next sentence would
+           # exceed the target, yield the current chunk and start a new one
+           chunk_text = " ".join(current_chunk)
+           chunk_count += 1
+           logger.info(f"Yielding chunk {chunk_count}: '{chunk_text[:50]}...' ({current_count} tokens)")
+           yield chunk_text, current_tokens
+           current_chunk = [sentence]
+           current_tokens = tokens
+           current_count = count
        elif current_count + count <= TARGET_MAX:
            # Keep building chunk while under target max
            current_chunk.append(sentence)
            current_tokens.extend(tokens)
            current_count += count
-       elif current_count + count <= max_tokens:
-           # Accept slightly larger chunk if needed
+       elif current_count + count <= max_tokens and current_count < TARGET_MIN:
+           # Only exceed target max if we haven't reached minimum size yet
            current_chunk.append(sentence)
            current_tokens.extend(tokens)
            current_count += count
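The net effect is a three-way policy per sentence: close out a chunk once it is above TARGET_MIN and the next sentence would pass TARGET_MAX, keep accumulating while under TARGET_MAX, and only run past TARGET_MAX when the chunk is still under TARGET_MIN. A standalone sketch of just that decision rule, with contrived token counts:

    TARGET_MIN, TARGET_MAX, ABSOLUTE_MAX = 200, 350, 500

    def split_counts(sentence_counts: list[int]) -> list[int]:
        """Group per-sentence token counts into chunk totals using the same rule."""
        chunks, current = [], 0
        for count in sentence_counts:
            if current >= TARGET_MIN and current + count > TARGET_MAX:
                chunks.append(current)  # good-sized chunk: yield and restart
                current = count
            elif current + count <= TARGET_MAX:
                current += count        # keep building under the target max
            elif current + count <= ABSOLUTE_MAX and current < TARGET_MIN:
                current += count        # exceed target only for undersized chunks
            else:
                chunks.append(current)
                current = count
        if current:
            chunks.append(current)
        return chunks

    print(split_counts([120, 150, 90, 60, 210]))  # -> [270, 360]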

View file

@@ -46,3 +46,7 @@ class OpenAISpeechRequest(BaseModel):
        default=True,  # Default to streaming for OpenAI compatibility
        description="If true (default), audio will be streamed as it's generated. Each chunk will be a complete sentence.",
    )
    return_download_link: bool = Field(
        default=False,
        description="If true, returns a download link in X-Download-Path header after streaming completes",
    )

View file

@@ -77,9 +77,19 @@ export class App {
        // Handle completion
        this.audioService.addEventListener('complete', () => {
            this.setGenerating(false);
+           this.showStatus('Preparing file...', 'info');
+       });
+
+       // Handle download ready
+       this.audioService.addEventListener('downloadReady', () => {
            this.showStatus('Generation complete', 'success');
        });

+       // Handle audio end
+       this.audioService.addEventListener('ended', () => {
+           this.playerState.setPlaying(false);
+       });
+
        // Handle errors
        this.audioService.addEventListener('error', (error) => {
            this.showStatus('Error: ' + error.message, 'error');
View file

@@ -9,7 +9,8 @@ export class AudioService {
        this.minimumPlaybackSize = 50000;  // 50KB minimum before playback
        this.textLength = 0;
        this.shouldAutoplay = false;
-       this.CHARS_PER_CHUNK = 600;  // Estimated chars per chunk
+       this.CHARS_PER_CHUNK = 300;  // Estimated chars per chunk
+       this.serverDownloadPath = null;  // Server-side download path
    }

    async streamAudio(text, voice, speed, onProgress) {

@@ -45,7 +46,8 @@ export class AudioService {
                    voice: voice,
                    response_format: 'mp3',
                    stream: true,
-                   speed: speed
+                   speed: speed,
+                   return_download_link: true
                }),
                signal: this.controller.signal
            });

@@ -58,7 +60,7 @@ export class AudioService {
                throw new Error(error.detail?.message || 'Failed to generate speech');
            }

-           await this.setupAudioStream(response, onProgress, estimatedChunks);
+           await this.setupAudioStream(response.body, response, onProgress, estimatedChunks);
            return this.audio;
        } catch (error) {
            this.cleanup();

@@ -66,16 +68,21 @@ export class AudioService {
        }
    }

-   async setupAudioStream(response, onProgress, estimatedTotalSize) {
+   async setupAudioStream(stream, response, onProgress, estimatedTotalSize) {
        this.audio = new Audio();
        this.mediaSource = new MediaSource();
        this.audio.src = URL.createObjectURL(this.mediaSource);

+       // Set up ended event handler
+       this.audio.addEventListener('ended', () => {
+           this.dispatchEvent('ended');
+       });
+
        return new Promise((resolve, reject) => {
            this.mediaSource.addEventListener('sourceopen', async () => {
                try {
                    this.sourceBuffer = this.mediaSource.addSourceBuffer('audio/mpeg');
-                   await this.processStream(response.body, onProgress, estimatedTotalSize);
+                   await this.processStream(stream, response, onProgress, estimatedTotalSize);
                    resolve();
                } catch (error) {
                    reject(error);

@@ -84,11 +91,17 @@ export class AudioService {
        });
    }

-   async processStream(stream, onProgress, estimatedChunks) {
+   async processStream(stream, response, onProgress, estimatedChunks) {
        const reader = stream.getReader();
        let hasStartedPlaying = false;
        let receivedChunks = 0;

+       // Check for download path in response headers
+       const downloadPath = response.headers.get('X-Download-Path');
+       if (downloadPath) {
+           this.serverDownloadPath = downloadPath;
+       }
+
        try {
            while (true) {
                const {value, done} = await reader.read();

@@ -245,6 +258,7 @@ export class AudioService {
        this.sourceBuffer = null;
        this.chunks = [];
        this.textLength = 0;
+       this.serverDownloadPath = null;

        // Force a hard refresh of the page to ensure clean state
        window.location.reload();

@@ -277,17 +291,25 @@ export class AudioService {
        this.sourceBuffer = null;
        this.chunks = [];
        this.textLength = 0;
+       this.serverDownloadPath = null;
    }

    getDownloadUrl() {
+       // Check for server-side download link first
+       const downloadPath = this.serverDownloadPath;
+       if (downloadPath) {
+           return downloadPath;
+       }
+
+       // Fall back to client-side blob URL
        if (!this.audio || !this.sourceBuffer || this.chunks.length === 0) return null;

        // Get the buffered data from MediaSource
        const buffered = this.sourceBuffer.buffered;
        if (buffered.length === 0) return null;

        // Create blob from the original chunks
        const blob = new Blob(this.chunks, { type: 'audio/mpeg' });
        return URL.createObjectURL(blob);
    }
}