From 3547d95ee6c8559fa2862dd0a65702c49ca6c538 Mon Sep 17 00:00:00 2001 From: remsky Date: Sat, 25 Jan 2025 05:25:13 -0700 Subject: [PATCH] -unified streaming implementation --- api/src/services/tts_service.py | 284 +++++++----------- .../benchmark_unified_streaming.py | 148 +++++++++ .../test_unified_streaming.py | 69 +++++ 3 files changed, 328 insertions(+), 173 deletions(-) create mode 100644 examples/streaming_refactor/benchmark_unified_streaming.py create mode 100644 examples/streaming_refactor/test_unified_streaming.py diff --git a/api/src/services/tts_service.py b/api/src/services/tts_service.py index a261fc6..4bae75c 100644 --- a/api/src/services/tts_service.py +++ b/api/src/services/tts_service.py @@ -2,7 +2,7 @@ import io import time -from typing import List, Tuple, Optional +from typing import List, Tuple, Optional, AsyncGenerator, Union import numpy as np import scipy.io.wavfile as wavfile @@ -25,112 +25,63 @@ class TTSService: _chunk_semaphore = asyncio.Semaphore(4) def __init__(self, output_dir: str = None): - """Initialize service. - - Args: - output_dir: Optional output directory for saving audio - """ + """Initialize service.""" self.output_dir = output_dir self.model_manager = None self._voice_manager = None @classmethod async def create(cls, output_dir: str = None) -> 'TTSService': - """Create and initialize TTSService instance. - - Args: - output_dir: Optional output directory for saving audio - - Returns: - Initialized TTSService instance - """ + """Create and initialize TTSService instance.""" service = cls(output_dir) - # Initialize managers service.model_manager = await get_model_manager() service._voice_manager = await get_voice_manager() return service - async def generate_audio( - self, text: str, voice: str, speed: float = 1.0, stitch_long_output: bool = True - ) -> Tuple[np.ndarray, float]: - """Generate audio for text. 
- - Args: - text: Input text - voice: Voice name - speed: Speed multiplier - stitch_long_output: Whether to stitch together long outputs - - Returns: - Audio samples and processing time - - Raises: - ValueError: If text is empty after preprocessing or no chunks generated - RuntimeError: If audio generation fails - """ - start_time = time.time() - voice_tensor = None - - try: - # Normalize text - normalized = normalize_text(text) - if not normalized: - raise ValueError("Text is empty after preprocessing") - text = str(normalized) - - # Get backend and load voice - backend = self.model_manager.get_backend() - voice_tensor = await self._voice_manager.load_voice(voice, device=backend.device) - - # Get chunks using async generator - chunks = [] - async for chunk in chunker.split_text(text): - chunks.append(chunk) + async def _process_chunk( + self, + chunk: str, + voice_tensor: torch.Tensor, + speed: float, + output_format: Optional[str] = None, + is_first: bool = False, + is_last: bool = False, + normalizer: Optional[AudioNormalizer] = None, + ) -> Optional[Union[np.ndarray, bytes]]: + """Process a single text chunk into audio.""" + async with self._chunk_semaphore: + try: + tokens = process_text(chunk) + if not tokens: + return None - if not chunks: - raise ValueError("No text chunks to process") - - # Process chunk with concurrency control - async def process_chunk(chunk: str) -> Optional[np.ndarray]: - async with self._chunk_semaphore: - try: - tokens = process_text(chunk) - if not tokens: - return None - - # Generate audio - return await self.model_manager.generate( - tokens, - voice_tensor, - speed=speed - ) - except Exception as e: - logger.error(f"Failed to process chunk: '{chunk}'. Error: {str(e)}") - return None - - # Process all chunks concurrently - chunk_results = await asyncio.gather(*[ - process_chunk(chunk) for chunk in chunks - ]) - - # Filter out None results and combine - audio_chunks = [chunk for chunk in chunk_results if chunk is not None] - if not audio_chunks: - raise ValueError("No audio chunks were generated successfully") - - # Combine chunks - audio = np.concatenate(audio_chunks) if len(audio_chunks) > 1 else audio_chunks[0] - processing_time = time.time() - start_time - return audio, processing_time - - except Exception as e: - logger.error(f"Error in audio generation: {str(e)}") - raise - finally: - # Always clean up voice tensor - if voice_tensor is not None: - del voice_tensor - torch.cuda.empty_cache() + # Generate audio using pre-warmed model + chunk_audio = await self.model_manager.generate( + tokens, + voice_tensor, + speed=speed + ) + + if chunk_audio is None: + return None + + # For streaming, convert to bytes + if output_format: + return await AudioService.convert_audio( + chunk_audio, + 24000, + output_format, + is_first_chunk=is_first, + normalizer=normalizer, + is_last_chunk=is_last, + stream=True + ) + + return chunk_audio + + except Exception as e: + logger.error(f"Failed to process chunk: '{chunk}'. Error: {str(e)}") + return None async def generate_audio_stream( self, @@ -138,21 +89,12 @@ class TTSService: voice: str, speed: float = 1.0, output_format: str = "wav", - ): - """Generate and stream audio chunks. 
- - Args: - text: Input text - voice: Voice name - speed: Speed multiplier - output_format: Output audio format - - Yields: - Audio chunks as bytes - """ - # Setup audio processing + ) -> AsyncGenerator[bytes, None]: + """Generate and stream audio chunks.""" stream_normalizer = AudioNormalizer() voice_tensor = None + pending_results = {} + next_index = 0 try: # Normalize text @@ -161,11 +103,11 @@ class TTSService: raise ValueError("Text is empty after preprocessing") text = str(normalized) - # Get backend and load voice + # Get backend and load voice (should be fast if cached) backend = self.model_manager.get_backend() voice_tensor = await self._voice_manager.load_voice(voice, device=backend.device) - # Get chunks using async generator + # Process chunks with semaphore limiting concurrency chunks = [] async for chunk in chunker.split_text(text): chunks.append(chunk) @@ -173,84 +115,80 @@ class TTSService: if not chunks: raise ValueError("No text chunks to process") - # Process chunk with concurrency control - async def process_chunk(chunk: str, is_first: bool, is_last: bool) -> Optional[bytes]: - async with self._chunk_semaphore: - try: - tokens = process_text(chunk) - if not tokens: - return None - - # Generate audio - chunk_audio = await self.model_manager.generate( - tokens, - voice_tensor, - speed=speed - ) - - if chunk_audio is not None: - # Convert to bytes - return await AudioService.convert_audio( - chunk_audio, - 24000, - output_format, - is_first_chunk=is_first, - normalizer=stream_normalizer, - is_last_chunk=is_last, - stream=True - ) - except Exception as e: - logger.error(f"Failed to generate audio for chunk: '{chunk}'. Error: {str(e)}") - return None - # Create tasks for all chunks tasks = [ - process_chunk(chunk, i==0, i==len(chunks)-1) + asyncio.create_task( + self._process_chunk( + chunk, + voice_tensor, + speed, + output_format, + is_first=(i == 0), + is_last=(i == len(chunks) - 1), + normalizer=stream_normalizer + ) + ) for i, chunk in enumerate(chunks) ] - # Process chunks concurrently and yield results in order - for chunk_bytes in await asyncio.gather(*tasks): - if chunk_bytes is not None: - yield chunk_bytes + # Process chunks and maintain order + for i, task in enumerate(tasks): + result = await task + + if i == next_index and result is not None: + # If this is the next chunk we need, yield it + yield result + next_index += 1 + + # Check if we have any subsequent chunks ready + while next_index in pending_results: + result = pending_results.pop(next_index) + if result is not None: + yield result + next_index += 1 + else: + # Store out-of-order result + pending_results[i] = result except Exception as e: logger.error(f"Error in audio generation stream: {str(e)}") raise finally: - # Always clean up voice tensor if voice_tensor is not None: del voice_tensor torch.cuda.empty_cache() - async def combine_voices(self, voices: List[str]) -> str: - """Combine multiple voices. 
+ async def generate_audio( + self, text: str, voice: str, speed: float = 1.0 + ) -> Tuple[np.ndarray, float]: + """Generate complete audio for text using streaming internally.""" + start_time = time.time() + chunks = [] - Args: - voices: List of voice names - - Returns: - Name of combined voice - """ + try: + # Use streaming generator but collect all chunks + async for chunk in self.generate_audio_stream( + text, voice, speed, output_format=None + ): + if chunk is not None: + chunks.append(chunk) + + if not chunks: + raise ValueError("No audio chunks were generated successfully") + + # Combine chunks + audio = np.concatenate(chunks) if len(chunks) > 1 else chunks[0] + processing_time = time.time() - start_time + return audio, processing_time + + except Exception as e: + logger.error(f"Error in audio generation: {str(e)}") + raise + + async def combine_voices(self, voices: List[str]) -> str: + """Combine multiple voices.""" return await self._voice_manager.combine_voices(voices) async def list_voices(self) -> List[str]: - """List available voices. - - Returns: - List of voice names - """ - return await self._voice_manager.list_voices() - - def _audio_to_bytes(self, audio: np.ndarray) -> bytes: - """Convert audio to WAV bytes. - - Args: - audio: Audio samples - - Returns: - WAV bytes - """ - buffer = io.BytesIO() - wavfile.write(buffer, 24000, audio) - return buffer.getvalue() \ No newline at end of file + """List available voices.""" + return await self._voice_manager.list_voices() \ No newline at end of file diff --git a/examples/streaming_refactor/benchmark_unified_streaming.py b/examples/streaming_refactor/benchmark_unified_streaming.py new file mode 100644 index 0000000..369fb66 --- /dev/null +++ b/examples/streaming_refactor/benchmark_unified_streaming.py @@ -0,0 +1,148 @@ +#!/usr/bin/env python3 +"""Benchmark script for unified streaming implementation""" + +import asyncio +import time +from pathlib import Path +from typing import List, Tuple + +from openai import OpenAI +import numpy as np +import matplotlib.pyplot as plt + +# Initialize OpenAI client +client = OpenAI(base_url="http://localhost:8880/v1", api_key="not-needed") + +TEST_TEXTS = { + "short": "The quick brown fox jumps over the lazy dog.", + "medium": """In a bustling city, life moves at a rapid pace. + People hurry along the sidewalks, while cars navigate + through the busy streets. The air is filled with the + sounds of urban activity.""", + "long": """The technological revolution has transformed how we live and work. + From artificial intelligence to renewable energy, innovations continue + to shape our future. As we face global challenges, scientific advances + offer new solutions. 
The intersection of technology and human creativity + drives progress forward, opening new possibilities for tomorrow.""" +} + +async def benchmark_streaming(text_name: str, text: str) -> Tuple[float, float, int]: + """Benchmark streaming performance + + Returns: + Tuple of (time to first byte, total time, total bytes) + """ + start_time = time.time() + total_bytes = 0 + first_byte_time = None + + with client.audio.speech.with_streaming_response.create( + model="kokoro", + voice="af_bella", + response_format="pcm", + input=text, + ) as response: + for chunk in response.iter_bytes(chunk_size=1024): + if first_byte_time is None: + first_byte_time = time.time() - start_time + total_bytes += len(chunk) + + total_time = time.time() - start_time + return first_byte_time, total_time, total_bytes + +async def benchmark_non_streaming(text_name: str, text: str) -> Tuple[float, int]: + """Benchmark non-streaming performance + + Returns: + Tuple of (total time, total bytes) + """ + start_time = time.time() + speech_file = Path(__file__).parent / f"non_stream_{text_name}.mp3" + + with client.audio.speech.with_streaming_response.create( + model="kokoro", + voice="af_bella", + input=text, + ) as response: + response.stream_to_file(speech_file) + + total_time = time.time() - start_time + total_bytes = speech_file.stat().st_size + return total_time, total_bytes + +def plot_results(results: dict): + """Plot benchmark results""" + plt.figure(figsize=(12, 6)) + + # Prepare data + text_lengths = [len(text) for text in TEST_TEXTS.values()] + streaming_times = [r["streaming"]["total_time"] for r in results.values()] + non_streaming_times = [r["non_streaming"]["total_time"] for r in results.values()] + first_byte_times = [r["streaming"]["first_byte_time"] for r in results.values()] + + # Plot times + x = np.arange(len(TEST_TEXTS)) + width = 0.25 + + plt.bar(x - width, streaming_times, width, label='Streaming Total Time') + plt.bar(x, non_streaming_times, width, label='Non-Streaming Total Time') + plt.bar(x + width, first_byte_times, width, label='Time to First Byte') + + plt.xlabel('Text Length (characters)') + plt.ylabel('Time (seconds)') + plt.title('Unified Streaming Performance Comparison') + plt.xticks(x, text_lengths) + plt.legend() + + # Save plot + plt.savefig(Path(__file__).parent / 'benchmark_results.png') + plt.close() + +async def main(): + """Run benchmarks""" + print("Starting unified streaming benchmarks...") + + results = {} + + for name, text in TEST_TEXTS.items(): + print(f"\nTesting {name} text ({len(text)} chars)...") + + # Test streaming + print("Running streaming test...") + first_byte_time, stream_total_time, stream_bytes = await benchmark_streaming(name, text) + + # Test non-streaming + print("Running non-streaming test...") + non_stream_total_time, non_stream_bytes = await benchmark_non_streaming(name, text) + + results[name] = { + "text_length": len(text), + "streaming": { + "first_byte_time": first_byte_time, + "total_time": stream_total_time, + "total_bytes": stream_bytes, + "throughput": stream_bytes / stream_total_time / 1024 # KB/s + }, + "non_streaming": { + "total_time": non_stream_total_time, + "total_bytes": non_stream_bytes, + "throughput": non_stream_bytes / non_stream_total_time / 1024 # KB/s + } + } + + # Print results for this test + print(f"\nResults for {name} text:") + print(f"Streaming:") + print(f" Time to first byte: {first_byte_time:.3f}s") + print(f" Total time: {stream_total_time:.3f}s") + print(f" Throughput: {stream_bytes/stream_total_time/1024:.1f} KB/s") + 
print(f"Non-streaming:") + print(f" Total time: {non_stream_total_time:.3f}s") + print(f" Throughput: {non_stream_bytes/non_stream_total_time/1024:.1f} KB/s") + + # Plot results + plot_results(results) + print("\nBenchmark results have been plotted to benchmark_results.png") + +if __name__ == "__main__": + asyncio.run(main()) \ No newline at end of file diff --git a/examples/streaming_refactor/test_unified_streaming.py b/examples/streaming_refactor/test_unified_streaming.py new file mode 100644 index 0000000..f9bc5e5 --- /dev/null +++ b/examples/streaming_refactor/test_unified_streaming.py @@ -0,0 +1,69 @@ +#!/usr/bin/env python3 +"""Test script for unified streaming implementation""" + +import asyncio +import time +from pathlib import Path + +from openai import OpenAI + +# Initialize OpenAI client +client = OpenAI(base_url="http://localhost:8880/v1", api_key="not-needed") + +async def test_streaming_to_file(): + """Test streaming to file""" + print("\nTesting streaming to file...") + speech_file = Path(__file__).parent / "stream_output.mp3" + + start_time = time.time() + with client.audio.speech.with_streaming_response.create( + model="kokoro", + voice="af_bella", + input="Testing unified streaming implementation with a short phrase.", + ) as response: + response.stream_to_file(speech_file) + + print(f"Streaming to file completed in {(time.time() - start_time):.2f}s") + print(f"Output saved to: {speech_file}") + +async def test_streaming_chunks(): + """Test streaming chunks for real-time playback""" + print("\nTesting chunk streaming...") + + start_time = time.time() + chunk_count = 0 + total_bytes = 0 + + with client.audio.speech.with_streaming_response.create( + model="kokoro", + voice="af_bella", + response_format="pcm", + input="""This is a longer text to test chunk streaming. + We want to verify that the unified streaming implementation + works efficiently for both small and large inputs.""", + ) as response: + print(f"Time to first byte: {(time.time() - start_time):.3f}s") + + for chunk in response.iter_bytes(chunk_size=1024): + chunk_count += 1 + total_bytes += len(chunk) + # In real usage, this would go to audio playback + # For testing, we just count chunks and bytes + + total_time = time.time() - start_time + print(f"Received {chunk_count} chunks, {total_bytes} bytes") + print(f"Total streaming time: {total_time:.2f}s") + print(f"Average throughput: {total_bytes/total_time/1024:.1f} KB/s") + +async def main(): + """Run all tests""" + print("Starting unified streaming tests...") + + # Test both streaming modes + await test_streaming_to_file() + await test_streaming_chunks() + + print("\nAll tests completed!") + +if __name__ == "__main__": + asyncio.run(main()) \ No newline at end of file