Implement temporary file management for the OpenAI-compatible endpoint; add whole-file downloads

This commit is contained in:
remsky 2025-01-29 04:09:38 -07:00
parent 355ec54f78
commit 946e322242
10 changed files with 346 additions and 64 deletions

View file

@ -32,6 +32,11 @@ class Settings(BaseSettings):
cors_origins: list[str] = ["*"] # CORS origins for web player
cors_enabled: bool = True # Whether to enable CORS
# Temp File Settings
temp_file_dir: str = "api/temp_files" # Directory for temporary audio files (relative to project root)
max_temp_dir_size_mb: int = 2048 # Maximum size of temp directory (2GB)
temp_file_max_age_hours: int = 1 # Remove temp files older than 1 hour
class Config:
env_file = ".env"

View file

@ -337,4 +337,78 @@ async def get_content_type(path: str) -> str:
async def verify_model_path(model_path: str) -> bool:
"""Verify model file exists at path."""
return await aiofiles.os.path.exists(model_path)
return await aiofiles.os.path.exists(model_path)
async def cleanup_temp_files() -> None:
    """Clean up old temp files on startup.

    Deletes files in ``settings.temp_file_dir`` whose modification time is
    older than ``settings.temp_file_max_age_hours``. Creates the directory
    if it does not exist yet. All errors are logged and swallowed so that
    cleanup can never block application startup.
    """
    import time  # local import: only needed for the expiry computation

    try:
        if not await aiofiles.os.path.exists(settings.temp_file_dir):
            await aiofiles.os.makedirs(settings.temp_file_dir, exist_ok=True)
            return

        # A file has expired once its mtime falls before this cutoff.
        # BUG FIX: the previous check compared (mtime + max_age) against the
        # file's own mtime, which is never true, so nothing was ever deleted.
        expiry_cutoff = time.time() - (settings.temp_file_max_age_hours * 3600)

        entries = await aiofiles.os.scandir(settings.temp_file_dir)
        for entry in entries:
            if entry.is_file():
                stat = await aiofiles.os.stat(entry.path)
                if stat.st_mtime < expiry_cutoff:
                    try:
                        await aiofiles.os.remove(entry.path)
                        logger.info(f"Cleaned up old temp file: {entry.name}")
                    except Exception as e:
                        logger.warning(f"Failed to delete old temp file {entry.name}: {e}")
    except Exception as e:
        logger.warning(f"Error cleaning temp files: {e}")
async def get_temp_file_path(filename: str) -> str:
    """Get path to temporary audio file.

    Args:
        filename: Name of temp file. Any directory components are stripped
            so the returned path always stays inside the temp directory.

    Returns:
        Path to the temp file inside ``settings.temp_file_dir``.
    """
    # SECURITY: reduce to a bare file name so input like "../../etc/passwd"
    # cannot escape the temp directory.
    safe_name = os.path.basename(filename)
    temp_path = os.path.join(settings.temp_file_dir, safe_name)

    # Ensure temp directory exists before handing back a path inside it
    if not await aiofiles.os.path.exists(settings.temp_file_dir):
        await aiofiles.os.makedirs(settings.temp_file_dir, exist_ok=True)

    return temp_path
async def list_temp_files() -> List[str]:
    """Return the names of all regular files in the temp directory.

    Returns:
        List of temp file names; empty if the directory does not exist.
    """
    if not await aiofiles.os.path.exists(settings.temp_file_dir):
        return []

    names: List[str] = []
    for entry in await aiofiles.os.scandir(settings.temp_file_dir):
        if entry.is_file():
            names.append(entry.name)
    return names
async def get_temp_dir_size() -> int:
    """Get total size of temp directory in bytes.

    Returns:
        Combined size of every regular file in ``settings.temp_file_dir``,
        or 0 if the directory does not exist.
    """
    if not await aiofiles.os.path.exists(settings.temp_file_dir):
        return 0

    total_bytes = 0
    for entry in await aiofiles.os.scandir(settings.temp_file_dir):
        if not entry.is_file():
            continue
        info = await aiofiles.os.stat(entry.path)
        total_bytes += info.st_size
    return total_bytes

View file

@ -48,6 +48,10 @@ async def lifespan(app: FastAPI):
"""Lifespan context manager for model initialization"""
from .inference.model_manager import get_manager
from .inference.voice_manager import get_manager as get_voice_manager
from .core.paths import cleanup_temp_files
# Clean old temp files on startup
await cleanup_temp_files()
logger.info("Loading TTS model and voice packs...")

View file

@ -5,7 +5,7 @@ import os
from typing import AsyncGenerator, Dict, List, Union
from fastapi import APIRouter, Depends, Header, HTTPException, Request, Response
from fastapi.responses import StreamingResponse
from fastapi.responses import StreamingResponse, FileResponse
from loguru import logger
from ..services.audio import AudioService
@ -179,36 +179,59 @@ async def create_speech(
# Create generator but don't start it yet
generator = stream_audio_chunks(tts_service, request, client_request)
# Test the generator by attempting to get first chunk
try:
first_chunk = await anext(generator)
except StopAsyncIteration:
first_chunk = b"" # Empty audio case
except Exception as e:
# Re-raise any errors to be caught by the outer try-except
raise RuntimeError(f"Failed to initialize audio stream: {str(e)}") from e
# If download link requested, wrap generator with temp file writer
if request.return_download_link:
from ..services.temp_manager import TempFileWriter
temp_writer = TempFileWriter(request.response_format)
await temp_writer.__aenter__() # Initialize temp file
# Create response headers
headers = {
"Content-Disposition": f"attachment; filename=speech.{request.response_format}",
"X-Accel-Buffering": "no",
"Cache-Control": "no-cache",
"Transfer-Encoding": "chunked"
}
# Create async generator for streaming
async def dual_output():
try:
# Write chunks to temp file and stream
async for chunk in generator:
if chunk: # Skip empty chunks
await temp_writer.write(chunk)
yield chunk
# Get download path and add to headers
download_path = await temp_writer.finalize()
headers["X-Download-Path"] = download_path
except Exception as e:
logger.error(f"Error in dual output streaming: {e}")
await temp_writer.__aexit__(type(e), e, e.__traceback__)
raise
finally:
# Ensure temp writer is closed
if not temp_writer._finalized:
await temp_writer.__aexit__(None, None, None)
# Stream with temp file writing
return StreamingResponse(
dual_output(),
media_type=content_type,
headers=headers
)
# If we got here, streaming can begin
async def safe_stream():
yield first_chunk
try:
async for chunk in generator:
yield chunk
except Exception as e:
# Log the error but don't yield anything - the connection will close
logger.error(f"Error during streaming: {str(e)}")
raise
# Stream audio chunks as they're generated
# Standard streaming without download link
return StreamingResponse(
safe_stream(),
generator,
media_type=content_type,
headers={
"Content-Disposition": f"attachment; filename=speech.{request.response_format}",
"X-Accel-Buffering": "no", # Disable proxy buffering
"Cache-Control": "no-cache", # Prevent caching
"Transfer-Encoding": "chunked", # Enable chunked transfer encoding
},
"X-Accel-Buffering": "no",
"Cache-Control": "no-cache",
"Transfer-Encoding": "chunked"
}
)
else:
# Generate complete audio using public interface
@ -268,6 +291,43 @@ async def create_speech(
)
@router.get("/download/{filename}")
async def download_audio_file(filename: str):
    """Download a generated audio file from temp storage.

    Args:
        filename: Bare name of the temp file to serve (path component of
            the route).

    Returns:
        FileResponse streaming the file with an attachment disposition.

    Raises:
        HTTPException: 404 if the file cannot be found, 500 on any other
            failure while locating or serving it.
    """
    try:
        from ..core.paths import _find_file, get_content_type

        # Search for file in temp directory
        file_path = await _find_file(
            filename=filename,
            search_paths=[settings.temp_file_dir]
        )

        # Get content type from path helper
        content_type = await get_content_type(file_path)

        return FileResponse(
            file_path,
            media_type=content_type,
            filename=filename,
            headers={
                "Cache-Control": "no-cache",
                "Content-Disposition": f"attachment; filename={filename}"
            }
        )
    except HTTPException:
        # Let deliberate HTTP errors (e.g. from _find_file) pass through
        # instead of masking them as 500s.
        raise
    except FileNotFoundError:
        # NOTE(review): assumes a missing file surfaces as FileNotFoundError
        # from _find_file — confirm against its implementation.
        raise HTTPException(
            status_code=404,
            detail={
                "error": "not_found",
                "message": f"Audio file {filename} not found",
                "type": "not_found"
            }
        )
    except Exception as e:
        logger.error(f"Error serving download file {filename}: {e}")
        raise HTTPException(
            status_code=500,
            detail={
                "error": "server_error",
                "message": "Failed to serve audio file",
                "type": "server_error"
            }
        )
@router.get("/audio/voices")
async def list_voices():
"""List all available voices for text-to-speech"""

View file

@ -108,15 +108,27 @@ class StreamingAudioWriter:
"aac": {"format": "adts", "codec": "aac"}
}[self.format]
# On finalization, include proper headers and duration metadata
parameters = [
"-q:a", "2",
"-write_xing", "1" if self.format == "mp3" else "0", # XING header for MP3 only
"-metadata", f"duration={self.total_duration/1000}", # Duration in seconds
"-write_id3v1", "1" if self.format == "mp3" else "0", # ID3v1 tag for MP3
"-write_id3v2", "1" if self.format == "mp3" else "0" # ID3v2 tag for MP3
]
if self.format == "mp3":
# For MP3, ensure proper VBR headers
parameters.extend([
"-write_vbr", "1",
"-vbr_quality", "2"
])
self.encoder.export(
output_buffer,
**format_args,
bitrate="192k",
parameters=[
"-q:a", "2",
"-write_xing", "1" if self.format == "mp3" else "0", # XING header for MP3 only
"-metadata", f"duration={self.total_duration/1000}" # Duration in seconds
]
parameters=parameters
)
self.encoder = None
return output_buffer.getvalue()
@ -163,10 +175,10 @@ class StreamingAudioWriter:
"aac": {"format": "adts", "codec": "aac"}
}[self.format]
# For chunks, export without duration metadata or XING headers
self.encoder.export(output_buffer, **format_args, bitrate="192k", parameters=[
"-q:a", "2",
"-write_xing", "1" if self.format == "mp3" else "0", # XING header for MP3 only
"-metadata", f"duration={self.total_duration/1000}" # Duration in seconds
"-write_xing", "0" # No XING headers for chunks
])
# Get the encoded data

View file

@ -0,0 +1,89 @@
"""Temporary file writer for audio downloads"""
import os
import tempfile
from typing import Optional
import aiofiles
from fastapi import HTTPException
from loguru import logger
from ..core.config import settings
from ..core.paths import _scan_directories
class TempFileWriter:
    """Handles writing audio chunks to a temp file for later download."""

    def __init__(self, format: str):
        """Initialize temp file writer

        Args:
            format: Audio format extension (mp3, wav, etc)
        """
        self.format = format
        self.temp_file = None
        # Set once the temp file is created in __aenter__; initialized here
        # so attribute access never raises before the context is entered.
        self.temp_path = None
        self._finalized = False

    async def __aenter__(self):
        """Async context manager entry.

        Raises:
            HTTPException: 507 when the temp directory has reached
                ``settings.max_temp_dir_size_mb``.
        """
        # BUG FIX: create the directory *before* scanning it — on a fresh
        # install the size check used to run against a missing directory
        # and fail the very first download request.
        os.makedirs(settings.temp_file_dir, exist_ok=True)

        # Enforce the configured cap on total temp storage
        total_size = 0
        entries = await _scan_directories([settings.temp_file_dir])
        for entry in entries:
            stat = await aiofiles.os.stat(os.path.join(settings.temp_file_dir, entry))
            total_size += stat.st_size
        if total_size >= settings.max_temp_dir_size_mb * 1024 * 1024:
            raise HTTPException(
                status_code=507,
                detail="Temporary storage full. Please try again later."
            )

        # Create temp file with proper extension; delete=False because the
        # file must outlive this writer to be downloadable later.
        temp = tempfile.NamedTemporaryFile(
            dir=settings.temp_file_dir,
            delete=False,
            suffix=f".{self.format}",
            mode='wb'
        )
        self.temp_path = temp.name
        temp.close()  # Close sync handle; all writes go through aiofiles
        self.temp_file = await aiofiles.open(self.temp_path, mode='wb')

        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Async context manager exit — close the file if still open."""
        try:
            if self.temp_file and not self._finalized:
                await self.temp_file.close()
                self._finalized = True
        except Exception as e:
            logger.error(f"Error closing temp file: {e}")

    async def write(self, chunk: bytes) -> None:
        """Write a chunk of audio data

        Args:
            chunk: Audio data bytes to write

        Raises:
            RuntimeError: If the writer has already been finalized
        """
        if self._finalized:
            raise RuntimeError("Cannot write to finalized temp file")

        await self.temp_file.write(chunk)
        # Flush per chunk so the file is complete even if the stream aborts
        await self.temp_file.flush()

    async def finalize(self) -> str:
        """Close temp file and return download path

        Returns:
            Path to use for downloading the temp file

        Raises:
            RuntimeError: If already finalized
        """
        if self._finalized:
            raise RuntimeError("Temp file already finalized")

        await self.temp_file.close()
        self._finalized = True

        return f"/download/{os.path.basename(self.temp_path)}"

View file

@ -8,10 +8,10 @@ from .phonemizer import phonemize
from .normalizer import normalize_text
from .vocabulary import tokenize
# Constants for chunk size optimization
TARGET_MIN_TOKENS = 300
TARGET_MAX_TOKENS = 400
ABSOLUTE_MAX_TOKENS = 500
# Target token ranges
TARGET_MIN = 200
TARGET_MAX = 350
ABSOLUTE_MAX = 500
def process_text_chunk(text: str, language: str = "a") -> List[int]:
"""Process a chunk of text through normalization, phonemization, and tokenization.
@ -48,10 +48,6 @@ def process_text_chunk(text: str, language: str = "a") -> List[int]:
return tokens
def is_chunk_size_optimal(token_count: int) -> bool:
"""Check if chunk size is within optimal range."""
return TARGET_MIN_TOKENS <= token_count <= TARGET_MAX_TOKENS
async def yield_chunk(text: str, tokens: List[int], chunk_count: int) -> Tuple[str, List[int]]:
"""Yield a chunk with consistent logging."""
logger.info(f"Yielding chunk {chunk_count}: '{text[:50]}...' ({len(tokens)} tokens)")
@ -76,10 +72,6 @@ def process_text(text: str, language: str = "a") -> List[int]:
return process_text_chunk(text, language)
# Target token ranges
TARGET_MIN = 300
TARGET_MAX = 400
ABSOLUTE_MAX = 500
def get_sentence_info(text: str) -> List[Tuple[str, List[int], int]]:
"""Process all sentences and return info."""
@ -166,13 +158,23 @@ async def smart_split(text: str, max_tokens: int = ABSOLUTE_MAX) -> AsyncGenerat
yield chunk_text, clause_tokens
# Regular sentence handling
elif current_count >= TARGET_MIN and current_count + count > TARGET_MAX:
# If we have a good sized chunk and adding next sentence exceeds target,
# yield current chunk and start new one
chunk_text = " ".join(current_chunk)
chunk_count += 1
logger.info(f"Yielding chunk {chunk_count}: '{chunk_text[:50]}...' ({current_count} tokens)")
yield chunk_text, current_tokens
current_chunk = [sentence]
current_tokens = tokens
current_count = count
elif current_count + count <= TARGET_MAX:
# Keep building chunk while under target max
current_chunk.append(sentence)
current_tokens.extend(tokens)
current_count += count
elif current_count + count <= max_tokens:
# Accept slightly larger chunk if needed
elif current_count + count <= max_tokens and current_count < TARGET_MIN:
# Only exceed target max if we haven't reached minimum size yet
current_chunk.append(sentence)
current_tokens.extend(tokens)
current_count += count

View file

@ -46,3 +46,7 @@ class OpenAISpeechRequest(BaseModel):
default=True, # Default to streaming for OpenAI compatibility
description="If true (default), audio will be streamed as it's generated. Each chunk will be a complete sentence.",
)
return_download_link: bool = Field(
default=False,
description="If true, returns a download link in X-Download-Path header after streaming completes",
)

View file

@ -77,9 +77,19 @@ export class App {
// Handle completion
this.audioService.addEventListener('complete', () => {
this.setGenerating(false);
this.showStatus('Preparing file...', 'info');
});
// Handle download ready
this.audioService.addEventListener('downloadReady', () => {
this.showStatus('Generation complete', 'success');
});
// Handle audio end
this.audioService.addEventListener('ended', () => {
this.playerState.setPlaying(false);
});
// Handle errors
this.audioService.addEventListener('error', (error) => {
this.showStatus('Error: ' + error.message, 'error');

View file

@ -9,7 +9,8 @@ export class AudioService {
this.minimumPlaybackSize = 50000; // 50KB minimum before playback
this.textLength = 0;
this.shouldAutoplay = false;
this.CHARS_PER_CHUNK = 600; // Estimated chars per chunk
this.CHARS_PER_CHUNK = 300; // Estimated chars per chunk
this.serverDownloadPath = null; // Server-side download path
}
async streamAudio(text, voice, speed, onProgress) {
@ -45,7 +46,8 @@ export class AudioService {
voice: voice,
response_format: 'mp3',
stream: true,
speed: speed
speed: speed,
return_download_link: true
}),
signal: this.controller.signal
});
@ -58,7 +60,7 @@ export class AudioService {
throw new Error(error.detail?.message || 'Failed to generate speech');
}
await this.setupAudioStream(response, onProgress, estimatedChunks);
await this.setupAudioStream(response.body, response, onProgress, estimatedChunks);
return this.audio;
} catch (error) {
this.cleanup();
@ -66,16 +68,21 @@ export class AudioService {
}
}
async setupAudioStream(response, onProgress, estimatedTotalSize) {
async setupAudioStream(stream, response, onProgress, estimatedTotalSize) {
this.audio = new Audio();
this.mediaSource = new MediaSource();
this.audio.src = URL.createObjectURL(this.mediaSource);
// Set up ended event handler
this.audio.addEventListener('ended', () => {
this.dispatchEvent('ended');
});
return new Promise((resolve, reject) => {
this.mediaSource.addEventListener('sourceopen', async () => {
try {
this.sourceBuffer = this.mediaSource.addSourceBuffer('audio/mpeg');
await this.processStream(response.body, onProgress, estimatedTotalSize);
await this.processStream(stream, response, onProgress, estimatedTotalSize);
resolve();
} catch (error) {
reject(error);
@ -84,11 +91,17 @@ export class AudioService {
});
}
async processStream(stream, onProgress, estimatedChunks) {
async processStream(stream, response, onProgress, estimatedChunks) {
const reader = stream.getReader();
let hasStartedPlaying = false;
let receivedChunks = 0;
// Check for download path in response headers
const downloadPath = response.headers.get('X-Download-Path');
if (downloadPath) {
this.serverDownloadPath = downloadPath;
}
try {
while (true) {
const {value, done} = await reader.read();
@ -245,6 +258,7 @@ export class AudioService {
this.sourceBuffer = null;
this.chunks = [];
this.textLength = 0;
this.serverDownloadPath = null;
// Force a hard refresh of the page to ensure clean state
window.location.reload();
@ -277,17 +291,25 @@ export class AudioService {
this.sourceBuffer = null;
this.chunks = [];
this.textLength = 0;
this.serverDownloadPath = null;
}
getDownloadUrl() {
if (!this.audio || !this.sourceBuffer || this.chunks.length === 0) return null;
// Get the buffered data from MediaSource
const buffered = this.sourceBuffer.buffered;
if (buffered.length === 0) return null;
// Create blob from the original chunks
const blob = new Blob(this.chunks, { type: 'audio/mpeg' });
getDownloadUrl() {
// Check for server-side download link first
const downloadPath = this.serverDownloadPath;
if (downloadPath) {
return downloadPath;
}
// Fall back to client-side blob URL
if (!this.audio || !this.sourceBuffer || this.chunks.length === 0) return null;
// Get the buffered data from MediaSource
const buffered = this.sourceBuffer.buffered;
if (buffered.length === 0) return null;
// Create blob from the original chunks
const blob = new Blob(this.chunks, { type: 'audio/mpeg' });
return URL.createObjectURL(blob);
return URL.createObjectURL(blob);
}
}