diff --git a/Test Threads.py b/Test Threads.py
new file mode 100644
index 0000000..1b5f6c1
--- /dev/null
+++ b/Test Threads.py
@@ -0,0 +1,283 @@
+#!/usr/bin/env python3
+# Compatible with both Windows and Linux
+"""
+Kokoro TTS Race Condition Test
+
+This script creates multiple concurrent requests to a Kokoro TTS service
+to reproduce a race condition where audio outputs don't match the requested text.
+Each thread generates a simple numbered sentence, which should make mismatches
+easy to identify through listening.
+
+To run:
+python "Test Threads.py" --threads 8 --iterations 5 --url http://localhost:8880
+"""
+
+import argparse
+import base64
+import concurrent.futures
+import os
+import requests
+import time
+import wave
+import sys
+from pathlib import Path
+
+
+def setup_args():
+    """Parse command line arguments"""
+    parser = argparse.ArgumentParser(description="Test Kokoro TTS for race conditions")
+    parser.add_argument("--url", default="http://localhost:8880",
+                        help="Base URL of the Kokoro TTS service")
+    parser.add_argument("--threads", type=int, default=8,
+                        help="Number of concurrent threads to use")
+    parser.add_argument("--iterations", type=int, default=5,
+                        help="Number of iterations per thread")
+    parser.add_argument("--voice", default="af_heart",
+                        help="Voice to use for TTS")
+    parser.add_argument("--output-dir", default="./tts_test_output",
+                        help="Directory to save output files")
+    parser.add_argument("--debug", action="store_true",
+                        help="Enable debug logging")
+    return parser.parse_args()
+
+
+def generate_test_sentence(thread_id, iteration):
+    """Generate a simple test sentence with numbers to make mismatches easily identifiable"""
+    return f"This is test sentence number {thread_id}-{iteration}. " \
+           f"If you hear this sentence, you should hear the numbers {thread_id}-{iteration}."
+
+
+def log_message(message, debug=True, is_error=False):
+    """Log messages with timestamps.
+
+    Errors always print. Info messages print by default; verbose per-request
+    messages pass the --debug flag here so they stay quiet unless enabled.
+    """
+    timestamp = time.strftime("%Y-%m-%d %H:%M:%S")
+    prefix = "[ERROR]" if is_error else "[INFO]"
+    if is_error or debug:
+        print(f"{prefix} {timestamp} - {message}")
+        sys.stdout.flush()  # Ensure logs are visible in Docker output
+
+
+def request_tts(url, test_id, text, voice, output_dir, debug=False):
+    """Request TTS from the Kokoro API and save the WAV output"""
+    start_time = time.time()
+    output_file = os.path.join(output_dir, f"test_{test_id}.wav")
+    text_file = os.path.join(output_dir, f"test_{test_id}.txt")
+
+    # Log output paths for debugging
+    log_message(f"Thread {test_id}: Text will be saved to: {text_file}", debug)
+    log_message(f"Thread {test_id}: Audio will be saved to: {output_file}", debug)
+
+    # Save the text for later comparison
+    try:
+        with open(text_file, "w") as f:
+            f.write(text)
+        log_message(f"Thread {test_id}: Successfully saved text file", debug)
+    except Exception as e:
+        log_message(f"Thread {test_id}: Error saving text file: {str(e)}", debug, is_error=True)
+
+    # Make the TTS request
+    try:
+        log_message(f"Thread {test_id}: Requesting TTS for: '{text}'", debug)
+
+        response = requests.post(
+            f"{url}/v1/audio/speech",
+            json={
+                "model": "kokoro",
+                "input": text,
+                "voice": voice,
+                "response_format": "wav"
+            },
+            headers={"Accept": "audio/wav"},
+            timeout=60  # Increase timeout to 60 seconds
+        )
+
+        log_message(f"Thread {test_id}: Response status code: {response.status_code}", debug)
+        log_message(f"Thread {test_id}: Response content type: {response.headers.get('Content-Type', 'None')}", debug)
+        log_message(f"Thread {test_id}: Response content length: {len(response.content)} bytes", debug)
+
+        if response.status_code != 200:
+            log_message(f"Thread {test_id}: API error: {response.status_code} - {response.text}", debug, is_error=True)
+            return False
+
+        # Check if we got valid audio data
+        if len(response.content) < 100:  # Sanity check - WAV files should be larger than this
+            log_message(f"Thread {test_id}: Received suspiciously small audio data: {len(response.content)} bytes", debug, is_error=True)
+            log_message(f"Thread {test_id}: Content (base64): {base64.b64encode(response.content).decode('utf-8')}", debug, is_error=True)
+            return False
+
+        # Save the audio output with explicit error handling
+        try:
+            with open(output_file, "wb") as f:
+                bytes_written = f.write(response.content)
+            log_message(f"Thread {test_id}: Wrote {bytes_written} bytes to {output_file}", debug)
+
+            # Verify the WAV file exists and has content
+            if os.path.exists(output_file):
+                file_size = os.path.getsize(output_file)
+                log_message(f"Thread {test_id}: Verified file exists with size: {file_size} bytes", debug)
+
+                # Validate WAV file by reading its headers
+                try:
+                    with wave.open(output_file, 'rb') as wav_file:
+                        channels = wav_file.getnchannels()
+                        sample_width = wav_file.getsampwidth()
+                        framerate = wav_file.getframerate()
+                        frames = wav_file.getnframes()
+                        log_message(f"Thread {test_id}: Valid WAV file - channels: {channels}, "
+                                    f"sample width: {sample_width}, framerate: {framerate}, frames: {frames}", debug)
+                except Exception as wav_error:
+                    log_message(f"Thread {test_id}: Invalid WAV file: {str(wav_error)}", debug, is_error=True)
+            else:
+                log_message(f"Thread {test_id}: File was not created: {output_file}", debug, is_error=True)
+        except Exception as save_error:
+            log_message(f"Thread {test_id}: Error saving audio file: {str(save_error)}", debug, is_error=True)
+            return False
+
+        end_time = time.time()
+        log_message(f"Thread {test_id}: Saved output to {output_file} (time: {end_time - start_time:.2f}s)", debug)
+        return True
+
+    except requests.exceptions.Timeout:
+        log_message(f"Thread {test_id}: Request timed out", debug, is_error=True)
+        return False
+    except Exception as e:
+        log_message(f"Thread {test_id}: Exception: {str(e)}", debug, is_error=True)
+        return False
+
+
+def worker_task(thread_id, args):
+    """Worker task for each thread"""
+    for i in range(args.iterations):
+        iteration = i + 1
+        test_id = f"{thread_id:02d}_{iteration:02d}"
+        text = generate_test_sentence(thread_id, iteration)
+        success = request_tts(args.url, test_id, text, args.voice, args.output_dir, args.debug)
+
+        if not success:
+            log_message(f"Thread {thread_id}: Iteration {iteration} failed", args.debug, is_error=True)
+
+        # Small delay between iterations to avoid overwhelming the API
+        time.sleep(0.1)
+
+
+def run_test(args):
+    """Run the test with the specified parameters"""
+    # Ensure output directory exists and check permissions
+    os.makedirs(args.output_dir, exist_ok=True)
+
+    # Test write access to the output directory
+    test_file = os.path.join(args.output_dir, "write_test.txt")
+    try:
+        with open(test_file, "w") as f:
+            f.write("Testing write access\n")
+        os.remove(test_file)
+        log_message(f"Successfully verified write access to output directory: {args.output_dir}")
+    except Exception as e:
+        log_message(f"Warning: Cannot write to output directory {args.output_dir}: {str(e)}", is_error=True)
+        log_message(f"Current directory: {os.getcwd()}", is_error=True)
+        log_message(f"Directory contents: {os.listdir('.')}", is_error=True)
+
+    # Test connection to Kokoro TTS service
+    try:
+        response = requests.get(f"{args.url}/health", timeout=5)
+        if response.status_code == 200:
+            log_message(f"Successfully connected to Kokoro TTS service at {args.url}")
+        else:
+            log_message(f"Warning: Kokoro TTS service health check returned status {response.status_code}", is_error=True)
+    except Exception as e:
+        log_message(f"Warning: Cannot connect to Kokoro TTS service at {args.url}: {str(e)}", is_error=True)
+
+    # Record start time
+    start_time = time.time()
+    log_message(f"Starting test with {args.threads} threads, {args.iterations} iterations per thread")
+
+    # Create and start worker threads
+    with concurrent.futures.ThreadPoolExecutor(max_workers=args.threads) as executor:
+        futures = []
+        for thread_id in range(1, args.threads + 1):
+            futures.append(executor.submit(worker_task, thread_id, args))
+
+        # Wait for all tasks to complete
+        for future in concurrent.futures.as_completed(futures):
+            try:
+                future.result()
+            except Exception as e:
+                log_message(f"Thread execution failed: {str(e)}", args.debug, is_error=True)
+
+    # Record end time and print summary
+    end_time = time.time()
+    total_time = end_time - start_time
+    total_requests = args.threads * args.iterations
+    log_message(f"Test completed in {total_time:.2f} seconds")
+    log_message(f"Total requests: {total_requests}")
+    log_message(f"Average time per request: {total_time / total_requests:.2f} seconds")
+    log_message(f"Requests per second: {total_requests / total_time:.2f}")
+    log_message(f"Output files saved to: {os.path.abspath(args.output_dir)}")
+    log_message("To verify, listen to the audio files and check if they match the text files")
+    log_message("If you hear audio describing a different test number than the filename, you've found a race condition")
+
+
+def analyze_audio_files(output_dir):
+    """Provide summary of the generated audio files"""
files""" + # Look for both WAV and TXT files + wav_files = list(Path(output_dir).glob("*.wav")) + txt_files = list(Path(output_dir).glob("*.txt")) + + log_message(f"Found {len(wav_files)} WAV files and {len(txt_files)} TXT files") + + if len(wav_files) == 0: + log_message("No WAV files found! This indicates the TTS service requests may be failing.", is_error=True) + log_message("Check the connection to the TTS service and the response status codes above.", is_error=True) + + file_stats = [] + for wav_path in wav_files: + try: + with wave.open(str(wav_path), 'rb') as wav_file: + frames = wav_file.getnframes() + rate = wav_file.getframerate() + duration = frames / rate + + # Get corresponding text + text_path = wav_path.with_suffix('.txt') + if text_path.exists(): + with open(text_path, 'r') as text_file: + text = text_file.read().strip() + else: + text = "N/A" + + file_stats.append({ + 'filename': wav_path.name, + 'duration': duration, + 'text': text + }) + except Exception as e: + log_message(f"Error analyzing {wav_path}: {str(e)}", False, is_error=True) + + # Print summary table + if file_stats: + log_message("\nAudio File Summary:") + log_message(f"{'Filename':<20}{'Duration':<12}{'Text':<60}") + log_message("-" * 92) + for stat in file_stats: + log_message(f"{stat['filename']:<20}{stat['duration']:<12.2f}{stat['text'][:57]+'...' if len(stat['text']) > 60 else stat['text']:<60}") + + # List missing WAV files where text files exist + missing_wavs = set(p.stem for p in txt_files) - set(p.stem for p in wav_files) + if missing_wavs: + log_message(f"\nFound {len(missing_wavs)} text files without corresponding WAV files:", is_error=True) + for stem in sorted(list(missing_wavs))[:10]: # Limit to 10 for readability + log_message(f" - {stem}.txt (no WAV file)", is_error=True) + if len(missing_wavs) > 10: + log_message(f" ... and {len(missing_wavs) - 10} more", is_error=True) + + +if __name__ == "__main__": + args = setup_args() + run_test(args) + analyze_audio_files(args.output_dir) + + log_message("\nNext Steps:") + log_message("1. Listen to the generated audio files") + log_message("2. Verify if each audio correctly says its ID number") + log_message("3. Check for any mismatches between the audio content and the text files") + log_message("4. 
\ No newline at end of file
diff --git a/api/src/routers/development.py b/api/src/routers/development.py
index ec5596e..96f94ea 100644
--- a/api/src/routers/development.py
+++ b/api/src/routers/development.py
@@ -180,10 +180,11 @@ async def create_captioned_speech(
         "pcm": "audio/pcm",
     }.get(request.response_format, f"audio/{request.response_format}")
 
+    writer = StreamingAudioWriter(request.response_format, sample_rate=24000)
     # Check if streaming is requested (default for OpenAI client)
     if request.stream:
         # Create generator but don't start it yet
-        generator = stream_audio_chunks(tts_service, request, client_request)
+        generator = stream_audio_chunks(tts_service, request, client_request, writer)
 
         # If download link requested, wrap generator with temp file writer
         if request.return_download_link:
@@ -284,6 +285,7 @@ async def create_captioned_speech(
             audio_data = await tts_service.generate_audio(
                 text=request.input,
                 voice=voice_name,
+                writer=writer,
                 speed=request.speed,
                 return_timestamps=request.return_timestamps,
                 normalization_options=request.normalization_options,
@@ -294,6 +296,7 @@ async def create_captioned_speech(
                 audio_data,
                 24000,
                 request.response_format,
+                writer,
                 is_first_chunk=True,
                 is_last_chunk=False,
                 trim_audio=False,
@@ -304,6 +307,7 @@ async def create_captioned_speech(
                 AudioChunk(np.array([], dtype=np.int16)),
                 24000,
                 request.response_format,
+                writer,
                 is_first_chunk=False,
                 is_last_chunk=True,
             )
diff --git a/api/src/routers/openai_compatible.py b/api/src/routers/openai_compatible.py
index 10e32bf..6d37e71 100644
--- a/api/src/routers/openai_compatible.py
+++ b/api/src/routers/openai_compatible.py
@@ -10,11 +10,12 @@ from urllib import response
 
 import aiofiles
 import numpy as np
+from ..services.streaming_audio_writer import StreamingAudioWriter
 import torch
 from fastapi import APIRouter, Depends, Header, HTTPException, Request, Response
 from fastapi.responses import FileResponse, StreamingResponse
 from loguru import logger
-from structures.schemas import CaptionedSpeechRequest
+from ..structures.schemas import CaptionedSpeechRequest
 
 from ..core.config import settings
 from ..inference.base import AudioChunk
@@ -79,9 +80,7 @@ def get_model_name(model: str) -> str:
     return base_name + ".pth"
 
 
-async def process_and_validate_voices(
-    voice_input: Union[str, List[str]], tts_service: TTSService
-) -> str:
+async def process_and_validate_voices(voice_input: Union[str, List[str]], tts_service: TTSService) -> str:
     """Process voice input, handling both string and list formats
 
     Returns:
@@ -90,73 +89,59 @@ async def process_and_validate_voices(
     voices = []
     # Convert input to list of voices
     if isinstance(voice_input, str):
-        voice_input=voice_input.replace(" ","").strip()
+        voice_input = voice_input.replace(" ", "").strip()
 
         if voice_input[-1] in "+-" or voice_input[0] in "+-":
-            raise ValueError(
-                f"Voice combination contains empty combine items"
-            )
+            raise ValueError(f"Voice combination contains empty combine items")
 
         if re.search(r"[+-]{2,}", voice_input) is not None:
-            raise ValueError(
-                f"Voice combination contains empty combine items"
-            )
+            raise ValueError(f"Voice combination contains empty combine items")
 
         voices = re.split(r"([-+])", voice_input)
     else:
-        voices = [[item,"+"] for item in voice_input][:-1]
-
+        voices = [[item, "+"] for item in voice_input][:-1]
+
     available_voices = await tts_service.list_voices()
 
-    for voice_index in range(0,len(voices), 2):
-
+    for voice_index in range(0, len(voices), 2):
         mapped_voice = voices[voice_index].split("(")
         mapped_voice = list(map(str.strip, mapped_voice))
 
         if len(mapped_voice) > 2:
-            raise ValueError(
-                f"Voice '{voices[voice_index]}' contains too many weight items"
-            )
-
+            raise ValueError(f"Voice '{voices[voice_index]}' contains too many weight items")
+
         if mapped_voice.count(")") > 1:
-            raise ValueError(
-                f"Voice '{voices[voice_index]}' contains too many weight items"
-            )
-
+            raise ValueError(f"Voice '{voices[voice_index]}' contains too many weight items")
+
         mapped_voice[0] = _openai_mappings["voices"].get(mapped_voice[0], mapped_voice[0])
 
         if mapped_voice[0] not in available_voices:
-            raise ValueError(
-                f"Voice '{mapped_voice[0]}' not found. Available voices: {', '.join(sorted(available_voices))}"
-            )
+            raise ValueError(f"Voice '{mapped_voice[0]}' not found. Available voices: {', '.join(sorted(available_voices))}")
 
         voices[voice_index] = "(".join(mapped_voice)
-
+
     return "".join(voices)
 
 
-async def stream_audio_chunks(
-    tts_service: TTSService, request: Union[OpenAISpeechRequest,CaptionedSpeechRequest], client_request: Request
-) -> AsyncGenerator[AudioChunk, None]:
+
+async def stream_audio_chunks(tts_service: TTSService, request: Union[OpenAISpeechRequest, CaptionedSpeechRequest], client_request: Request, writer: StreamingAudioWriter) -> AsyncGenerator[AudioChunk, None]:
     """Stream audio chunks as they're generated with client disconnect handling"""
     voice_name = await process_and_validate_voices(request.voice, tts_service)
-    unique_properties={
-        "return_timestamps":False
-    }
+    unique_properties = {"return_timestamps": False}
+
     if hasattr(request, "return_timestamps"):
-        unique_properties["return_timestamps"]=request.return_timestamps
-
+        unique_properties["return_timestamps"] = request.return_timestamps
+
     try:
         logger.info(f"Starting audio generation with lang_code: {request.lang_code}")
         async for chunk_data in tts_service.generate_audio_stream(
             text=request.input,
             voice=voice_name,
+            writer=writer,
             speed=request.speed,
             output_format=request.response_format,
             lang_code=request.lang_code or settings.default_voice_code or voice_name[0].lower(),
             normalization_options=request.normalization_options,
             return_timestamps=unique_properties["return_timestamps"],
         ):
-
             # Check if client is still connected
             is_disconnected = client_request.is_disconnected
             if callable(is_disconnected):
@@ -174,7 +159,6 @@ async def stream_audio_chunks(
 
 @router.post("/audio/speech")
 async def create_speech(
-
     request: OpenAISpeechRequest,
     client_request: Request,
     x_raw_response: str = Header(None, alias="x-raw-response"),
@@ -206,10 +190,12 @@ async def create_speech(
         "pcm": "audio/pcm",
     }.get(request.response_format, f"audio/{request.response_format}")
 
+    writer = StreamingAudioWriter(request.response_format, sample_rate=24000)
+
     # Check if streaming is requested (default for OpenAI client)
     if request.stream:
         # Create generator but don't start it yet
-        generator = stream_audio_chunks(tts_service, request, client_request)
+        generator = stream_audio_chunks(tts_service, request, client_request, writer)
 
         # If download link requested, wrap generator with temp file writer
         if request.return_download_link:
@@ -239,9 +225,9 @@ async def create_speech(
                     async for chunk_data in generator:
                         if chunk_data.output:  # Skip empty chunks
                             await temp_writer.write(chunk_data.output)
-                            #if return_json:
-                            #    yield chunk, chunk_data
-                            #else:
+                            # if return_json:
+                            #     yield chunk, chunk_data
+                            # else:
                             yield chunk_data.output
 
                     # Finalize the temp file
@@ -256,9 +242,7 @@ async def create_speech(
                    await temp_writer.__aexit__(None, None, None)
                    # Stream with temp file writing
-                    return StreamingResponse(
-                        dual_output(), media_type=content_type, headers=headers
-                    )
+                    return StreamingResponse(dual_output(), media_type=content_type, headers=headers)
 
             async def single_output():
                 try:
                     # Stream chunks
                     async for chunk_data in generator:
                         if chunk_data.output:  # Skip empty chunks
                             yield chunk_data.output
                 except Exception as e:
                     logger.error(f"Error in single output streaming: {e}")
                     raise
-
+
             # Standard streaming without download link
             return StreamingResponse(
                 single_output(),
@@ -283,40 +267,36 @@ async def create_speech(
             )
     else:
         headers = {
-                "Content-Disposition": f"attachment; filename=speech.{request.response_format}",
-                "Cache-Control": "no-cache",  # Prevent caching
-            }
+            "Content-Disposition": f"attachment; filename=speech.{request.response_format}",
+            "Cache-Control": "no-cache",  # Prevent caching
+        }
 
         # Generate complete audio using public interface
         audio_data = await tts_service.generate_audio(
             text=request.input,
             voice=voice_name,
+            writer=writer,
             speed=request.speed,
             normalization_options=request.normalization_options,
             lang_code=request.lang_code,
         )
 
-        audio_data = await AudioService.convert_audio(
-            audio_data,
-            24000,
-            request.response_format,
-            is_first_chunk=True,
-            is_last_chunk=False,
-            trim_audio=False
-        )
-
+        audio_data = await AudioService.convert_audio(audio_data, 24000, request.response_format, writer, is_first_chunk=True, is_last_chunk=False, trim_audio=False)
+
         # Convert to requested format with proper finalization
         final = await AudioService.convert_audio(
             AudioChunk(np.array([], dtype=np.int16)),
             24000,
             request.response_format,
+            writer,
            is_first_chunk=False,
             is_last_chunk=True,
         )
-        output=audio_data.output + final.output
+        output = audio_data.output + final.output
 
         if request.return_download_link:
             from ..services.temp_manager import TempFileWriter
+
             # Use download_format if specified, otherwise use response_format
             output_format = request.download_format or request.response_format
             temp_writer = TempFileWriter(output_format)
@@ -390,9 +370,7 @@ async def download_audio_file(filename: str):
     from ..core.paths import _find_file, get_content_type
 
     # Search for file in temp directory
-    file_path = await _find_file(
-        filename=filename, search_paths=[settings.temp_file_dir]
-    )
+    file_path = await _find_file(filename=filename, search_paths=[settings.temp_file_dir])
 
     # Get content type from path helper
     content_type = await get_content_type(file_path)
@@ -425,30 +403,12 @@ async def list_models():
     try:
         # Create standard model list
         models = [
-            {
-                "id": "tts-1",
-                "object": "model",
-                "created": 1686935002,
-                "owned_by": "kokoro"
-            },
-            {
-                "id": "tts-1-hd",
-                "object": "model",
-                "created": 1686935002,
-                "owned_by": "kokoro"
-            },
-            {
-                "id": "kokoro",
-                "object": "model",
-                "created": 1686935002,
-                "owned_by": "kokoro"
-            }
+            {"id": "tts-1", "object": "model", "created": 1686935002, "owned_by": "kokoro"},
+            {"id": "tts-1-hd", "object": "model", "created": 1686935002, "owned_by": "kokoro"},
+            {"id": "kokoro", "object": "model", "created": 1686935002, "owned_by": "kokoro"},
         ]
-
-        return {
-            "object": "list",
-            "data": models
-        }
+
+        return {"object": "list", "data": models}
     except Exception as e:
         logger.error(f"Error listing models: {str(e)}")
         raise HTTPException(
@@ -460,43 +420,22 @@ async def list_models():
         },
     )
 
+
 @router.get("/models/{model}")
 async def retrieve_model(model: str):
     """Retrieve a specific model"""
     try:
         # Define available models
         models = {
-            "tts-1": {
-                "id": "tts-1",
-                "object": "model",
-                "created": 1686935002,
-                "owned_by": "kokoro"
-            },
-            "tts-1-hd": {
-                "id": "tts-1-hd",
-                "object": "model",
-                "created": 1686935002,
-                "owned_by": "kokoro"
-            },
-            "kokoro": {
-                "id": "kokoro",
-                "object": "model",
-                "created": 1686935002,
-                "owned_by": "kokoro"
-            }
+            "tts-1": {"id": "tts-1", "object": "model", "created": 1686935002, "owned_by": "kokoro"},
+            "tts-1-hd": {"id": "tts-1-hd", "object": "model", "created": 1686935002, "owned_by": "kokoro"},
+            "kokoro": {"id": "kokoro", "object": "model", "created": 1686935002, "owned_by": "kokoro"},
         }
"object": "model", - "created": 1686935002, - "owned_by": "kokoro" - }, - "kokoro": { - "id": "kokoro", - "object": "model", - "created": 1686935002, - "owned_by": "kokoro" - } + "tts-1": {"id": "tts-1", "object": "model", "created": 1686935002, "owned_by": "kokoro"}, + "tts-1-hd": {"id": "tts-1-hd", "object": "model", "created": 1686935002, "owned_by": "kokoro"}, + "kokoro": {"id": "kokoro", "object": "model", "created": 1686935002, "owned_by": "kokoro"}, } - + # Check if requested model exists if model not in models: - raise HTTPException( - status_code=404, - detail={ - "error": "model_not_found", - "message": f"Model '{model}' not found", - "type": "invalid_request_error" - } - ) - + raise HTTPException(status_code=404, detail={"error": "model_not_found", "message": f"Model '{model}' not found", "type": "invalid_request_error"}) + # Return the specific model return models[model] except HTTPException: @@ -512,6 +451,7 @@ async def retrieve_model(model: str): }, ) + @router.get("/audio/voices") async def list_voices(): """List all available voices for text-to-speech""" @@ -579,9 +519,7 @@ async def combine_voices(request: Union[str, List[str]]): available_voices = await tts_service.list_voices() for voice in voices: if voice not in available_voices: - raise ValueError( - f"Base voice '{voice}' not found. Available voices: {', '.join(sorted(available_voices))}" - ) + raise ValueError(f"Base voice '{voice}' not found. Available voices: {', '.join(sorted(available_voices))}") # Combine voices combined_tensor = await tts_service.combine_voices(voices=voices) diff --git a/api/src/services/audio.py b/api/src/services/audio.py index d1b412e..d0aed80 100644 --- a/api/src/services/audio.py +++ b/api/src/services/audio.py @@ -115,6 +115,7 @@ class AudioService: audio_chunk: AudioChunk, sample_rate: int, output_format: str, + writer: StreamingAudioWriter, speed: float = 1, chunk_text: str = "", is_first_chunk: bool = True, @@ -153,28 +154,30 @@ class AudioService: audio_chunk = AudioService.trim_audio(audio_chunk,chunk_text,speed,is_last_chunk,normalizer) # Get or create format-specific writer - writer_key = f"{output_format}_{sample_rate}" + """writer_key = f"{output_format}_{sample_rate}" if is_first_chunk or writer_key not in AudioService._writers: AudioService._writers[writer_key] = StreamingAudioWriter( output_format, sample_rate ) - writer = AudioService._writers[writer_key] + writer = AudioService._writers[writer_key]""" # Write audio data first if len(audio_chunk.audio) > 0: chunk_data = writer.write_chunk(audio_chunk.audio) + + # Then finalize if this is the last chunk if is_last_chunk: final_data = writer.write_chunk(finalize=True) - del AudioService._writers[writer_key] + if final_data: audio_chunk.output=final_data return audio_chunk - + if chunk_data: - audio_chunk.output=chunk_data + audio_chunk.output=chunk_data return audio_chunk except Exception as e: diff --git a/api/src/services/tts_service.py b/api/src/services/tts_service.py index cb9f655..1b31eab 100644 --- a/api/src/services/tts_service.py +++ b/api/src/services/tts_service.py @@ -8,6 +8,7 @@ import time from typing import AsyncGenerator, List, Optional, Tuple, Union import numpy as np +from .streaming_audio_writer import StreamingAudioWriter import torch from kokoro import KPipeline from loguru import logger @@ -50,6 +51,7 @@ class TTSService: voice_name: str, voice_path: str, speed: float, + writer: StreamingAudioWriter, output_format: Optional[str] = None, is_first: bool = False, is_last: bool = False, @@ -64,12 +66,13 @@ 
diff --git a/api/src/services/tts_service.py b/api/src/services/tts_service.py
index cb9f655..1b31eab 100644
--- a/api/src/services/tts_service.py
+++ b/api/src/services/tts_service.py
@@ -8,6 +8,7 @@ import time
 from typing import AsyncGenerator, List, Optional, Tuple, Union
 
 import numpy as np
+from .streaming_audio_writer import StreamingAudioWriter
 import torch
 from kokoro import KPipeline
 from loguru import logger
@@ -50,6 +51,7 @@ class TTSService:
         voice_name: str,
         voice_path: str,
         speed: float,
+        writer: StreamingAudioWriter,
         output_format: Optional[str] = None,
         is_first: bool = False,
         is_last: bool = False,
@@ -64,12 +66,13 @@ class TTSService:
         if is_last:
             # Skip format conversion for raw audio mode
             if not output_format:
-                yield AudioChunk(np.array([], dtype=np.int16),output=b'')
+                yield AudioChunk(np.array([], dtype=np.int16), output=b"")
                 return
             chunk_data = await AudioService.convert_audio(
                 AudioChunk(np.array([], dtype=np.float32)),  # Dummy data for type checking
                 24000,
                 output_format,
+                writer,
                 speed,
                 "",
                 is_first_chunk=False,
@@ -88,7 +91,7 @@ class TTSService:
         # Generate audio using pre-warmed model
         if isinstance(backend, KokoroV1):
-            chunk_index=0
+            chunk_index = 0
             # For Kokoro V1, pass text and voice info with lang_code
             async for chunk_data in self.model_manager.generate(
                 chunk_text,
@@ -104,6 +107,7 @@
                         chunk_data,
                         24000,
                         output_format,
+                        writer,
                         speed,
                         chunk_text,
                         is_first_chunk=is_first and chunk_index == 0,
@@ -114,22 +118,13 @@ class TTSService:
                     except Exception as e:
                         logger.error(f"Failed to convert audio: {str(e)}")
                 else:
-                    chunk_data = AudioService.trim_audio(chunk_data,
-                        chunk_text,
-                        speed,
-                        is_last,
-                        normalizer)
+                    chunk_data = AudioService.trim_audio(chunk_data, chunk_text, speed, is_last, normalizer)
                     yield chunk_data
-                chunk_index+=1
+                chunk_index += 1
         else:
-            # For legacy backends, load voice tensor
-            voice_tensor = await self._voice_manager.load_voice(
-                voice_name, device=backend.device
-            )
-            chunk_data = await self.model_manager.generate(
-                tokens, voice_tensor, speed=speed, return_timestamps=return_timestamps
-            )
+            voice_tensor = await self._voice_manager.load_voice(voice_name, device=backend.device)
+            chunk_data = await self.model_manager.generate(tokens, voice_tensor, speed=speed, return_timestamps=return_timestamps)
 
             if chunk_data.audio is None:
                 logger.error("Model generated None for audio chunk")
@@ -146,6 +141,7 @@ class TTSService:
                     chunk_data,
                     24000,
                     output_format,
+                    writer,
                     speed,
                     chunk_text,
                     is_first_chunk=is_first,
@@ -156,11 +152,7 @@ class TTSService:
                 except Exception as e:
                     logger.error(f"Failed to convert audio: {str(e)}")
             else:
-                trimmed = AudioService.trim_audio(chunk_data,
-                    chunk_text,
-                    speed,
-                    is_last,
-                    normalizer)
+                trimmed = AudioService.trim_audio(chunk_data, chunk_text, speed, is_last, normalizer)
                 yield trimmed
     except Exception as e:
         logger.error(f"Failed to process tokens: {str(e)}")
@@ -169,7 +161,7 @@ class TTSService:
         # Check if the path is None and raise a ValueError if it is not
         if not path:
             raise ValueError(f"Voice not found at path: {path}")
-
+
         logger.debug(f"Loading voice tensor from path: {path}")
         return torch.load(path, map_location="cpu") * weight
 
@@ -191,10 +183,8 @@ class TTSService:
         # If it is only once voice there is no point in loading it up, doing nothing with it, then saving it
         if len(split_voice) == 1:
-
             # Since its a single voice the only time that the weight would matter is if voice_weight_normalization is off
             if ("(" not in voice and ")" not in voice) or settings.voice_weight_normalization == True:
-
                 path = await self._voice_manager.get_voice_path(voice)
                 if not path:
                     raise RuntimeError(f"Voice not found: {voice}")
@@ -202,8 +192,8 @@ class TTSService:
                 return voice, path
 
         total_weight = 0
-
-        for voice_index in range(0,len(split_voice),2):
+
+        for voice_index in range(0, len(split_voice), 2):
             voice_object = split_voice[voice_index]
 
             if "(" in voice_object and ")" in voice_object:
@@ -215,7 +205,7 @@ class TTSService:
             total_weight += voice_weight
             split_voice[voice_index] = (voice_name, voice_weight)
-
+
         # If voice_weight_normalization is false prevent normalizing the weights by setting the total_weight to 1 so it divides each weight by 1
         if settings.voice_weight_normalization == False:
             total_weight = 1
@@ -225,9 +215,9 @@ class TTSService:
         combined_tensor = await self._load_voice_from_path(path, split_voice[0][1] / total_weight)
 
         # Loop through each + or - in split_voice so they can be applied to combined voice
-        for operation_index in range(1,len(split_voice) - 1, 2):
+        for operation_index in range(1, len(split_voice) - 1, 2):
             # Get the voice path of the voice 1 index ahead of the operator
-            path = await self._voice_manager.get_voice_path(split_voice[operation_index+1][0])
+            path = await self._voice_manager.get_voice_path(split_voice[operation_index + 1][0])
             voice_tensor = await self._load_voice_from_path(path, split_voice[operation_index + 1][1] / total_weight)
 
             # Either add or subtract the voice from the current combined voice
@@ -235,7 +225,7 @@ class TTSService:
                 combined_tensor += voice_tensor
             else:
                 combined_tensor -= voice_tensor
-
+
         # Save the new combined voice so it can be loaded latter
         temp_dir = tempfile.gettempdir()
         combined_path = os.path.join(temp_dir, f"{voice}.pt")
@@ -250,6 +240,7 @@ class TTSService:
         self,
         text: str,
         voice: str,
+        writer: StreamingAudioWriter,
         speed: float = 1.0,
         output_format: str = "wav",
         lang_code: Optional[str] = None,
@@ -259,7 +250,7 @@ class TTSService:
         """Generate and stream audio chunks."""
         stream_normalizer = AudioNormalizer()
         chunk_index = 0
-        current_offset=0.0
+        current_offset = 0.0
         try:
             # Get backend
             backend = self.model_manager.get_backend()
@@ -270,13 +261,10 @@ class TTSService:
             # Use provided lang_code or determine from voice name
             pipeline_lang_code = lang_code if lang_code else voice[:1].lower()
-            logger.info(
-                f"Using lang_code '{pipeline_lang_code}' for voice '{voice_name}' in audio stream"
-            )
-
-
+            logger.info(f"Using lang_code '{pipeline_lang_code}' for voice '{voice_name}' in audio stream")
+
             # Process text in chunks with smart splitting
-            async for chunk_text, tokens in smart_split(text,lang_code=lang_code,normalization_options=normalization_options):
+            async for chunk_text, tokens in smart_split(text, lang_code=lang_code, normalization_options=normalization_options):
                 try:
                     # Process audio for chunk
                     async for chunk_data in self._process_chunk(
@@ -285,6 +273,7 @@ class TTSService:
                         voice_name,  # Pass voice name
                         voice_path,  # Pass voice path
                         speed,
+                        writer,
                         output_format,
                         is_first=(chunk_index == 0),
                         is_last=False,  # We'll update the last chunk later
@@ -294,23 +283,19 @@ class TTSService:
                     ):
                         if chunk_data.word_timestamps is not None:
                             for timestamp in chunk_data.word_timestamps:
-                                timestamp.start_time+=current_offset
-                                timestamp.end_time+=current_offset
-
-                        current_offset+=len(chunk_data.audio) / 24000
-
+                                timestamp.start_time += current_offset
+                                timestamp.end_time += current_offset
+
+                        current_offset += len(chunk_data.audio) / 24000
+
                         if chunk_data.output is not None:
                             yield chunk_data
-
+
                         else:
-                            logger.warning(
-                                f"No audio generated for chunk: '{chunk_text[:100]}...'"
-                            )
+                            logger.warning(f"No audio generated for chunk: '{chunk_text[:100]}...'")
                         chunk_index += 1
                 except Exception as e:
-                    logger.error(
-                        f"Failed to process audio for chunk: '{chunk_text[:100]}...'. Error: {str(e)}"
-                    )
+                    logger.error(f"Failed to process audio for chunk: '{chunk_text[:100]}...'. Error: {str(e)}")
                     continue
 
             # Only finalize if we successfully processed at least one chunk
@@ -323,6 +308,7 @@ class TTSService:
                     voice_name,
                     voice_path,
                     speed,
+                    writer,
                     output_format,
                     is_first=False,
                     is_last=True,  # Signal this is the last chunk
@@ -337,32 +323,30 @@ class TTSService:
         except Exception as e:
             logger.error(f"Error in phoneme audio generation: {str(e)}")
             raise e
-
     async def generate_audio(
         self,
         text: str,
         voice: str,
+        writer: StreamingAudioWriter,
         speed: float = 1.0,
         return_timestamps: bool = False,
         normalization_options: Optional[NormalizationOptions] = NormalizationOptions(),
         lang_code: Optional[str] = None,
     ) -> AudioChunk:
         """Generate complete audio for text using streaming internally."""
-        audio_data_chunks=[]
-
+        audio_data_chunks = []
+
         try:
-            async for audio_stream_data in self.generate_audio_stream(text,voice,speed=speed,normalization_options=normalization_options,return_timestamps=return_timestamps,lang_code=lang_code,output_format=None):
+            async for audio_stream_data in self.generate_audio_stream(text, voice, writer, speed=speed, normalization_options=normalization_options, return_timestamps=return_timestamps, lang_code=lang_code, output_format=None):
                 if len(audio_stream_data.audio) > 0:
                     audio_data_chunks.append(audio_stream_data)
-
-            combined_audio_data=AudioChunk.combine(audio_data_chunks)
+            combined_audio_data = AudioChunk.combine(audio_data_chunks)
             return combined_audio_data
         except Exception as e:
             logger.error(f"Error in audio generation: {str(e)}")
             raise
-
     async def combine_voices(self, voices: List[str]) -> torch.Tensor:
         """Combine multiple voices.
 
@@ -406,15 +390,11 @@ class TTSService:
         result = None
         # Use provided lang_code or determine from voice name
         pipeline_lang_code = lang_code if lang_code else voice[:1].lower()
-        logger.info(
-            f"Using lang_code '{pipeline_lang_code}' for voice '{voice_name}' in phoneme pipeline"
-        )
+        logger.info(f"Using lang_code '{pipeline_lang_code}' for voice '{voice_name}' in phoneme pipeline")
 
         try:
             # Use backend's pipeline management
-            for r in backend._get_pipeline(
-                pipeline_lang_code
-            ).generate_from_tokens(
+            for r in backend._get_pipeline(pipeline_lang_code).generate_from_tokens(
                 tokens=phonemes,  # Pass raw phonemes string
                 voice=voice_path,
                 speed=speed,
@@ -432,9 +412,7 @@ class TTSService:
             processing_time = time.time() - start_time
             return result.audio.numpy(), processing_time
         else:
-            raise ValueError(
-                "Phoneme generation only supported with Kokoro V1 backend"
-            )
+            raise ValueError("Phoneme generation only supported with Kokoro V1 backend")
 
     except Exception as e:
         logger.error(f"Error in phoneme audio generation: {str(e)}")
diff --git a/api/tests/test_audio_service.py b/api/tests/test_audio_service.py
index ca2d25d..7ddc1b5 100644
--- a/api/tests/test_audio_service.py
+++ b/api/tests/test_audio_service.py
@@ -7,7 +7,7 @@ import pytest
 
 from api.src.services.audio import AudioNormalizer, AudioService
 from api.src.inference.base import AudioChunk
-
+from api.src.services.streaming_audio_writer import StreamingAudioWriter
 @pytest.fixture(autouse=True)
 def mock_settings():
     """Mock settings for all tests"""
@@ -30,9 +30,11 @@ def sample_audio():
 async def test_convert_to_wav(sample_audio):
     """Test converting to WAV format"""
     audio_data, sample_rate = sample_audio
+
+    writer = StreamingAudioWriter("wav", sample_rate=24000)
 
     # Write and finalize in one step for WAV
     audio_chunk = await AudioService.convert_audio(
-        AudioChunk(audio_data), sample_rate, "wav", is_first_chunk=True, is_last_chunk=False
+        AudioChunk(audio_data), sample_rate, "wav", writer, is_first_chunk=True, is_last_chunk=False
     )
sample_rate, "wav", writer, is_first_chunk=True, is_last_chunk=False ) assert isinstance(audio_chunk.output, bytes) assert isinstance(audio_chunk, AudioChunk) @@ -46,8 +48,11 @@ async def test_convert_to_wav(sample_audio): async def test_convert_to_mp3(sample_audio): """Test converting to MP3 format""" audio_data, sample_rate = sample_audio + + writer = StreamingAudioWriter("writer", sample_rate=24000) + audio_chunk = await AudioService.convert_audio( - AudioChunk(audio_data), sample_rate, "mp3" + AudioChunk(audio_data), sample_rate, "mp3", writer ) assert isinstance(audio_chunk.output, bytes) assert isinstance(audio_chunk, AudioChunk) @@ -60,8 +65,11 @@ async def test_convert_to_mp3(sample_audio): async def test_convert_to_opus(sample_audio): """Test converting to Opus format""" audio_data, sample_rate = sample_audio + + writer = StreamingAudioWriter("opus", sample_rate=24000) + audio_chunk = await AudioService.convert_audio( - AudioChunk(audio_data), sample_rate, "opus" + AudioChunk(audio_data), sample_rate, "opus",writer ) assert isinstance(audio_chunk.output, bytes) assert isinstance(audio_chunk, AudioChunk) @@ -74,8 +82,11 @@ async def test_convert_to_opus(sample_audio): async def test_convert_to_flac(sample_audio): """Test converting to FLAC format""" audio_data, sample_rate = sample_audio + + writer = StreamingAudioWriter("flac", sample_rate=24000) + audio_chunk = await AudioService.convert_audio( - AudioChunk(audio_data), sample_rate, "flac" + AudioChunk(audio_data), sample_rate, "flac", writer ) assert isinstance(audio_chunk.output, bytes) assert isinstance(audio_chunk, AudioChunk) @@ -88,8 +99,11 @@ async def test_convert_to_flac(sample_audio): async def test_convert_to_aac(sample_audio): """Test converting to M4A format""" audio_data, sample_rate = sample_audio + + writer = StreamingAudioWriter("aac", sample_rate=24000) + audio_chunk = await AudioService.convert_audio( - AudioChunk(audio_data), sample_rate, "aac" + AudioChunk(audio_data), sample_rate, "aac", writer ) assert isinstance(audio_chunk.output, bytes) assert isinstance(audio_chunk, AudioChunk) @@ -102,8 +116,11 @@ async def test_convert_to_aac(sample_audio): async def test_convert_to_pcm(sample_audio): """Test converting to PCM format""" audio_data, sample_rate = sample_audio + + writer = StreamingAudioWriter("pcm", sample_rate=24000) + audio_chunk = await AudioService.convert_audio( - AudioChunk(audio_data), sample_rate, "pcm" + AudioChunk(audio_data), sample_rate, "pcm", writer ) assert isinstance(audio_chunk.output, bytes) assert isinstance(audio_chunk, AudioChunk) @@ -115,19 +132,25 @@ async def test_convert_to_pcm(sample_audio): async def test_convert_to_invalid_format_raises_error(sample_audio): """Test that converting to an invalid format raises an error""" audio_data, sample_rate = sample_audio + with pytest.raises(ValueError, match="Unsupported format: invalid"): + StreamingAudioWriter("invalid", sample_rate=24000) + with pytest.raises(ValueError, match="Format invalid not supported"): - await AudioService.convert_audio(audio_data, sample_rate, "invalid") + await AudioService.convert_audio(audio_data, sample_rate, "invalid", None) @pytest.mark.asyncio async def test_normalization_wav(sample_audio): """Test that WAV output is properly normalized to int16 range""" audio_data, sample_rate = sample_audio + + writer = StreamingAudioWriter("wav", sample_rate=24000) + # Create audio data outside int16 range large_audio = audio_data * 1e5 # Write and finalize in one step for WAV audio_chunk = await 
-        AudioChunk(large_audio), sample_rate, "wav", is_first_chunk=True
+        AudioChunk(large_audio), sample_rate, "wav", writer, is_first_chunk=True
     )
     assert isinstance(audio_chunk.output, bytes)
     assert isinstance(audio_chunk, AudioChunk)
@@ -138,10 +161,13 @@ async def test_normalization_pcm(sample_audio):
     """Test that PCM output is properly normalized to int16 range"""
     audio_data, sample_rate = sample_audio
+
+    writer = StreamingAudioWriter("pcm", sample_rate=24000)
+
     # Create audio data outside int16 range
     large_audio = audio_data * 1e5
     audio_chunk = await AudioService.convert_audio(
-        AudioChunk(large_audio), sample_rate, "pcm"
+        AudioChunk(large_audio), sample_rate, "pcm", writer
     )
     assert isinstance(audio_chunk.output, bytes)
     assert isinstance(audio_chunk, AudioChunk)
@@ -153,8 +179,11 @@ async def test_invalid_audio_data():
     """Test handling of invalid audio data"""
     invalid_audio = np.array([])  # Empty array
     sample_rate = 24000
+
+    writer = StreamingAudioWriter("wav", sample_rate=24000)
+
     with pytest.raises(ValueError):
-        await AudioService.convert_audio(invalid_audio, sample_rate, "wav")
+        await AudioService.convert_audio(invalid_audio, sample_rate, "wav", writer)
 
 
 @pytest.mark.asyncio
@@ -164,8 +193,11 @@ async def test_different_sample_rates(sample_audio):
     sample_rates = [8000, 16000, 44100, 48000]
 
     for rate in sample_rates:
+
+        writer = StreamingAudioWriter("wav", sample_rate=rate)
+
         audio_chunk = await AudioService.convert_audio(
-            AudioChunk(audio_data), rate, "wav", is_first_chunk=True
+            AudioChunk(audio_data), rate, "wav", writer, is_first_chunk=True
         )
         assert isinstance(audio_chunk.output, bytes)
         assert isinstance(audio_chunk, AudioChunk)
@@ -176,15 +208,21 @@ async def test_buffer_position_after_conversion(sample_audio):
     """Test that buffer position is reset after writing"""
     audio_data, sample_rate = sample_audio
+
+    writer = StreamingAudioWriter("wav", sample_rate=24000)
+
     # Write and finalize in one step for first conversion
     audio_chunk1 = await AudioService.convert_audio(
-        AudioChunk(audio_data), sample_rate, "wav", is_first_chunk=True, is_last_chunk=True
+        AudioChunk(audio_data), sample_rate, "wav", writer, is_first_chunk=True, is_last_chunk=True
    )
     assert isinstance(audio_chunk1.output, bytes)
     assert isinstance(audio_chunk1, AudioChunk)
 
     # Convert again to ensure buffer was properly reset
+
+    writer = StreamingAudioWriter("wav", sample_rate=24000)
+
     audio_chunk2 = await AudioService.convert_audio(
-        AudioChunk(audio_data), sample_rate, "wav", is_first_chunk=True, is_last_chunk=True
+        AudioChunk(audio_data), sample_rate, "wav", writer, is_first_chunk=True, is_last_chunk=True
     )
     assert isinstance(audio_chunk2.output, bytes)
     assert isinstance(audio_chunk2, AudioChunk)
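
Taken together, the diffs establish one contract: whoever initiates synthesis constructs the StreamingAudioWriter and threads it through generate_audio / generate_audio_stream and every convert_audio call, finalizing with an empty chunk. A condensed sketch of the non-streaming call pattern, mirroring create_speech above — the imports follow the test file's paths, and the `synthesize` wrapper name is hypothetical:

import numpy as np

from api.src.inference.base import AudioChunk
from api.src.services.audio import AudioService
from api.src.services.streaming_audio_writer import StreamingAudioWriter

async def synthesize(tts_service, text: str, voice: str, fmt: str) -> bytes:
    # One writer per request: its encoder state is owned by this call only.
    writer = StreamingAudioWriter(fmt, sample_rate=24000)

    audio_data = await tts_service.generate_audio(
        text=text, voice=voice, writer=writer, speed=1.0
    )
    body = await AudioService.convert_audio(
        audio_data, 24000, fmt, writer,
        is_first_chunk=True, is_last_chunk=False, trim_audio=False,
    )
    # Finalize with an empty chunk so the writer can flush format trailers.
    tail = await AudioService.convert_audio(
        AudioChunk(np.array([], dtype=np.int16)), 24000, fmt, writer,
        is_first_chunk=False, is_last_chunk=True,
    )
    return body.output + tail.output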