-unified streaming implementation

remsky 2025-01-25 05:25:13 -07:00
parent 90c8f11111
commit 3547d95ee6
3 changed files with 328 additions and 173 deletions

View file

@@ -2,7 +2,7 @@
 import io
 import time
-from typing import List, Tuple, Optional
+from typing import List, Tuple, Optional, AsyncGenerator, Union
 import numpy as np
 import scipy.io.wavfile as wavfile
@@ -25,134 +25,76 @@ class TTSService:
     _chunk_semaphore = asyncio.Semaphore(4)
 
     def __init__(self, output_dir: str = None):
-        """Initialize service.
-
-        Args:
-            output_dir: Optional output directory for saving audio
-        """
+        """Initialize service."""
         self.output_dir = output_dir
         self.model_manager = None
         self._voice_manager = None
 
     @classmethod
     async def create(cls, output_dir: str = None) -> 'TTSService':
-        """Create and initialize TTSService instance.
-
-        Args:
-            output_dir: Optional output directory for saving audio
-
-        Returns:
-            Initialized TTSService instance
-        """
+        """Create and initialize TTSService instance."""
         service = cls(output_dir)
-        # Initialize managers
         service.model_manager = await get_model_manager()
         service._voice_manager = await get_voice_manager()
         return service
 
-    async def generate_audio(
-        self, text: str, voice: str, speed: float = 1.0, stitch_long_output: bool = True
-    ) -> Tuple[np.ndarray, float]:
-        """Generate audio for text.
-
-        Args:
-            text: Input text
-            voice: Voice name
-            speed: Speed multiplier
-            stitch_long_output: Whether to stitch together long outputs
-
-        Returns:
-            Audio samples and processing time
-
-        Raises:
-            ValueError: If text is empty after preprocessing or no chunks generated
-            RuntimeError: If audio generation fails
-        """
-        start_time = time.time()
-        voice_tensor = None
-        try:
-            # Normalize text
-            normalized = normalize_text(text)
-            if not normalized:
-                raise ValueError("Text is empty after preprocessing")
-            text = str(normalized)
-
-            # Get backend and load voice
-            backend = self.model_manager.get_backend()
-            voice_tensor = await self._voice_manager.load_voice(voice, device=backend.device)
-
-            # Get chunks using async generator
-            chunks = []
-            async for chunk in chunker.split_text(text):
-                chunks.append(chunk)
-
-            if not chunks:
-                raise ValueError("No text chunks to process")
-
-            # Process chunk with concurrency control
-            async def process_chunk(chunk: str) -> Optional[np.ndarray]:
-                async with self._chunk_semaphore:
-                    try:
-                        tokens = process_text(chunk)
-                        if not tokens:
-                            return None
-
-                        # Generate audio
-                        return await self.model_manager.generate(
-                            tokens,
-                            voice_tensor,
-                            speed=speed
-                        )
-                    except Exception as e:
-                        logger.error(f"Failed to process chunk: '{chunk}'. Error: {str(e)}")
-                        return None
-
-            # Process all chunks concurrently
-            chunk_results = await asyncio.gather(*[
-                process_chunk(chunk) for chunk in chunks
-            ])
-
-            # Filter out None results and combine
-            audio_chunks = [chunk for chunk in chunk_results if chunk is not None]
-            if not audio_chunks:
-                raise ValueError("No audio chunks were generated successfully")
-
-            # Combine chunks
-            audio = np.concatenate(audio_chunks) if len(audio_chunks) > 1 else audio_chunks[0]
-            processing_time = time.time() - start_time
-            return audio, processing_time
-
-        except Exception as e:
-            logger.error(f"Error in audio generation: {str(e)}")
-            raise
-        finally:
-            # Always clean up voice tensor
-            if voice_tensor is not None:
-                del voice_tensor
-                torch.cuda.empty_cache()
+    async def _process_chunk(
+        self,
+        chunk: str,
+        voice_tensor: torch.Tensor,
+        speed: float,
+        output_format: Optional[str] = None,
+        is_first: bool = False,
+        is_last: bool = False,
+        normalizer: Optional[AudioNormalizer] = None,
+    ) -> Optional[Union[np.ndarray, bytes]]:
+        """Process a single text chunk into audio."""
+        async with self._chunk_semaphore:
+            try:
+                tokens = process_text(chunk)
+                if not tokens:
+                    return None
+
+                # Generate audio using pre-warmed model
+                chunk_audio = await self.model_manager.generate(
+                    tokens,
+                    voice_tensor,
+                    speed=speed
+                )
+                if chunk_audio is None:
+                    return None
+
+                # For streaming, convert to bytes
+                if output_format:
+                    return await AudioService.convert_audio(
+                        chunk_audio,
+                        24000,
+                        output_format,
+                        is_first_chunk=is_first,
+                        normalizer=normalizer,
+                        is_last_chunk=is_last,
+                        stream=True
+                    )
+                return chunk_audio
+            except Exception as e:
+                logger.error(f"Failed to process chunk: '{chunk}'. Error: {str(e)}")
+                return None
 
     async def generate_audio_stream(
         self,
         text: str,
         voice: str,
         speed: float = 1.0,
         output_format: str = "wav",
-    ):
-        """Generate and stream audio chunks.
-
-        Args:
-            text: Input text
-            voice: Voice name
-            speed: Speed multiplier
-            output_format: Output audio format
-
-        Yields:
-            Audio chunks as bytes
-        """
-        # Setup audio processing
+    ) -> AsyncGenerator[bytes, None]:
+        """Generate and stream audio chunks."""
         stream_normalizer = AudioNormalizer()
         voice_tensor = None
+        pending_results = {}
+        next_index = 0
 
         try:
             # Normalize text
@@ -161,11 +103,11 @@ class TTSService:
                 raise ValueError("Text is empty after preprocessing")
             text = str(normalized)
 
-            # Get backend and load voice
+            # Get backend and load voice (should be fast if cached)
             backend = self.model_manager.get_backend()
             voice_tensor = await self._voice_manager.load_voice(voice, device=backend.device)
 
-            # Get chunks using async generator
+            # Process chunks with semaphore limiting concurrency
             chunks = []
             async for chunk in chunker.split_text(text):
                 chunks.append(chunk)
@@ -173,84 +115,80 @@ class TTSService:
             if not chunks:
                 raise ValueError("No text chunks to process")
 
-            # Process chunk with concurrency control
-            async def process_chunk(chunk: str, is_first: bool, is_last: bool) -> Optional[bytes]:
-                async with self._chunk_semaphore:
-                    try:
-                        tokens = process_text(chunk)
-                        if not tokens:
-                            return None
-
-                        # Generate audio
-                        chunk_audio = await self.model_manager.generate(
-                            tokens,
-                            voice_tensor,
-                            speed=speed
-                        )
-
-                        if chunk_audio is not None:
-                            # Convert to bytes
-                            return await AudioService.convert_audio(
-                                chunk_audio,
-                                24000,
-                                output_format,
-                                is_first_chunk=is_first,
-                                normalizer=stream_normalizer,
-                                is_last_chunk=is_last,
-                                stream=True
-                            )
-                    except Exception as e:
-                        logger.error(f"Failed to generate audio for chunk: '{chunk}'. Error: {str(e)}")
-                        return None
-
             # Create tasks for all chunks
             tasks = [
-                process_chunk(chunk, i==0, i==len(chunks)-1)
+                asyncio.create_task(
+                    self._process_chunk(
+                        chunk,
+                        voice_tensor,
+                        speed,
+                        output_format,
+                        is_first=(i == 0),
+                        is_last=(i == len(chunks) - 1),
+                        normalizer=stream_normalizer
+                    )
+                )
                 for i, chunk in enumerate(chunks)
             ]
 
-            # Process chunks concurrently and yield results in order
-            for chunk_bytes in await asyncio.gather(*tasks):
-                if chunk_bytes is not None:
-                    yield chunk_bytes
+            # Process chunks and maintain order
+            for i, task in enumerate(tasks):
+                result = await task
+
+                if i == next_index and result is not None:
+                    # If this is the next chunk we need, yield it
+                    yield result
+                    next_index += 1
+
+                    # Check if we have any subsequent chunks ready
+                    while next_index in pending_results:
+                        result = pending_results.pop(next_index)
+                        if result is not None:
+                            yield result
+                        next_index += 1
+                else:
+                    # Store out-of-order result
+                    pending_results[i] = result
 
         except Exception as e:
             logger.error(f"Error in audio generation stream: {str(e)}")
             raise
         finally:
-            # Always clean up voice tensor
             if voice_tensor is not None:
                 del voice_tensor
                 torch.cuda.empty_cache()
 
+    async def generate_audio(
+        self, text: str, voice: str, speed: float = 1.0
+    ) -> Tuple[np.ndarray, float]:
+        """Generate complete audio for text using streaming internally."""
+        start_time = time.time()
+        chunks = []
+
+        try:
+            # Use streaming generator but collect all chunks
+            async for chunk in self.generate_audio_stream(
+                text, voice, speed, output_format=None
+            ):
+                if chunk is not None:
+                    chunks.append(chunk)
+
+            if not chunks:
+                raise ValueError("No audio chunks were generated successfully")
+
+            # Combine chunks
+            audio = np.concatenate(chunks) if len(chunks) > 1 else chunks[0]
+            processing_time = time.time() - start_time
+            return audio, processing_time
+
+        except Exception as e:
+            logger.error(f"Error in audio generation: {str(e)}")
+            raise
+
     async def combine_voices(self, voices: List[str]) -> str:
-        """Combine multiple voices.
-
-        Args:
-            voices: List of voice names
-
-        Returns:
-            Name of combined voice
-        """
+        """Combine multiple voices."""
         return await self._voice_manager.combine_voices(voices)
 
     async def list_voices(self) -> List[str]:
-        """List available voices.
-
-        Returns:
-            List of voice names
-        """
+        """List available voices."""
         return await self._voice_manager.list_voices()
-
-    def _audio_to_bytes(self, audio: np.ndarray) -> bytes:
-        """Convert audio to WAV bytes.
-
-        Args:
-            audio: Audio samples
-
-        Returns:
-            WAV bytes
-        """
-        buffer = io.BytesIO()
-        wavfile.write(buffer, 24000, audio)
-        return buffer.getvalue()
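
Note on the streaming path above: generate_audio_stream fans out one asyncio task per text chunk and then yields results strictly in submission order, so the audio stays contiguous even when later chunks finish first. A simplified, self-contained sketch of that pattern (make_chunk below is a hypothetical stand-in for _process_chunk, and the pending_results bookkeeping is omitted):

import asyncio
from typing import AsyncGenerator, List

async def make_chunk(index: int, text: str) -> bytes:
    # Hypothetical stand-in for TTSService._process_chunk: simulate chunks
    # finishing out of order by sleeping longer for earlier chunks.
    await asyncio.sleep(0.3 - 0.1 * (index % 3))
    return f"<audio for: {text}>".encode()

async def stream_in_order(chunks: List[str]) -> AsyncGenerator[bytes, None]:
    # Fan out: every chunk starts processing immediately.
    tasks = [asyncio.create_task(make_chunk(i, c)) for i, c in enumerate(chunks)]
    # Yield in submission order: a chunk is only emitted after all earlier
    # chunks have been emitted, regardless of completion order.
    for task in tasks:
        result = await task
        if result is not None:
            yield result

async def main():
    async for piece in stream_in_order(["chunk one", "chunk two", "chunk three"]):
        print(piece)

if __name__ == "__main__":
    asyncio.run(main())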

View file

@@ -0,0 +1,148 @@
#!/usr/bin/env python3
"""Benchmark script for unified streaming implementation"""

import asyncio
import time
from pathlib import Path
from typing import List, Tuple

from openai import OpenAI
import numpy as np
import matplotlib.pyplot as plt

# Initialize OpenAI client
client = OpenAI(base_url="http://localhost:8880/v1", api_key="not-needed")

TEST_TEXTS = {
    "short": "The quick brown fox jumps over the lazy dog.",
    "medium": """In a bustling city, life moves at a rapid pace.
    People hurry along the sidewalks, while cars navigate
    through the busy streets. The air is filled with the
    sounds of urban activity.""",
    "long": """The technological revolution has transformed how we live and work.
    From artificial intelligence to renewable energy, innovations continue
    to shape our future. As we face global challenges, scientific advances
    offer new solutions. The intersection of technology and human creativity
    drives progress forward, opening new possibilities for tomorrow."""
}

async def benchmark_streaming(text_name: str, text: str) -> Tuple[float, float, int]:
    """Benchmark streaming performance

    Returns:
        Tuple of (time to first byte, total time, total bytes)
    """
    start_time = time.time()
    total_bytes = 0
    first_byte_time = None

    with client.audio.speech.with_streaming_response.create(
        model="kokoro",
        voice="af_bella",
        response_format="pcm",
        input=text,
    ) as response:
        for chunk in response.iter_bytes(chunk_size=1024):
            if first_byte_time is None:
                first_byte_time = time.time() - start_time
            total_bytes += len(chunk)

    total_time = time.time() - start_time
    return first_byte_time, total_time, total_bytes

async def benchmark_non_streaming(text_name: str, text: str) -> Tuple[float, int]:
    """Benchmark non-streaming performance

    Returns:
        Tuple of (total time, total bytes)
    """
    start_time = time.time()
    speech_file = Path(__file__).parent / f"non_stream_{text_name}.mp3"

    with client.audio.speech.with_streaming_response.create(
        model="kokoro",
        voice="af_bella",
        input=text,
    ) as response:
        response.stream_to_file(speech_file)

    total_time = time.time() - start_time
    total_bytes = speech_file.stat().st_size
    return total_time, total_bytes

def plot_results(results: dict):
    """Plot benchmark results"""
    plt.figure(figsize=(12, 6))

    # Prepare data
    text_lengths = [len(text) for text in TEST_TEXTS.values()]
    streaming_times = [r["streaming"]["total_time"] for r in results.values()]
    non_streaming_times = [r["non_streaming"]["total_time"] for r in results.values()]
    first_byte_times = [r["streaming"]["first_byte_time"] for r in results.values()]

    # Plot times
    x = np.arange(len(TEST_TEXTS))
    width = 0.25

    plt.bar(x - width, streaming_times, width, label='Streaming Total Time')
    plt.bar(x, non_streaming_times, width, label='Non-Streaming Total Time')
    plt.bar(x + width, first_byte_times, width, label='Time to First Byte')

    plt.xlabel('Text Length (characters)')
    plt.ylabel('Time (seconds)')
    plt.title('Unified Streaming Performance Comparison')
    plt.xticks(x, text_lengths)
    plt.legend()

    # Save plot
    plt.savefig(Path(__file__).parent / 'benchmark_results.png')
    plt.close()

async def main():
    """Run benchmarks"""
    print("Starting unified streaming benchmarks...")
    results = {}

    for name, text in TEST_TEXTS.items():
        print(f"\nTesting {name} text ({len(text)} chars)...")

        # Test streaming
        print("Running streaming test...")
        first_byte_time, stream_total_time, stream_bytes = await benchmark_streaming(name, text)

        # Test non-streaming
        print("Running non-streaming test...")
        non_stream_total_time, non_stream_bytes = await benchmark_non_streaming(name, text)

        results[name] = {
            "text_length": len(text),
            "streaming": {
                "first_byte_time": first_byte_time,
                "total_time": stream_total_time,
                "total_bytes": stream_bytes,
                "throughput": stream_bytes / stream_total_time / 1024  # KB/s
            },
            "non_streaming": {
                "total_time": non_stream_total_time,
                "total_bytes": non_stream_bytes,
                "throughput": non_stream_bytes / non_stream_total_time / 1024  # KB/s
            }
        }

        # Print results for this test
        print(f"\nResults for {name} text:")
        print(f"Streaming:")
        print(f"  Time to first byte: {first_byte_time:.3f}s")
        print(f"  Total time: {stream_total_time:.3f}s")
        print(f"  Throughput: {stream_bytes/stream_total_time/1024:.1f} KB/s")
        print(f"Non-streaming:")
        print(f"  Total time: {non_stream_total_time:.3f}s")
        print(f"  Throughput: {non_stream_bytes/non_stream_total_time/1024:.1f} KB/s")

    # Plot results
    plot_results(results)
    print("\nBenchmark results have been plotted to benchmark_results.png")

if __name__ == "__main__":
    asyncio.run(main())
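
The streaming benchmark above only counts raw PCM bytes. A quick way to sanity-check that the streamed audio is actually playable is to wrap the same stream in a WAV header with the standard-library wave module. A minimal sketch, assuming 16-bit mono PCM at the service's 24000 Hz sample rate (channel count and sample width are assumptions, not confirmed by this commit):

import wave
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8880/v1", api_key="not-needed")

def save_pcm_stream_to_wav(text: str, path: str = "stream_check.wav") -> None:
    # Write streamed PCM chunks into a WAV container so the result can be
    # opened in any audio player.
    with wave.open(path, "wb") as wav_file:
        wav_file.setnchannels(1)      # mono (assumed)
        wav_file.setsampwidth(2)      # 16-bit samples (assumed)
        wav_file.setframerate(24000)  # sample rate used by the service above
        with client.audio.speech.with_streaming_response.create(
            model="kokoro",
            voice="af_bella",
            response_format="pcm",
            input=text,
        ) as response:
            for chunk in response.iter_bytes(chunk_size=1024):
                wav_file.writeframes(chunk)

if __name__ == "__main__":
    save_pcm_stream_to_wav("Quick check that the streamed PCM is playable.")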

View file

@@ -0,0 +1,69 @@
#!/usr/bin/env python3
"""Test script for unified streaming implementation"""

import asyncio
import time
from pathlib import Path

from openai import OpenAI

# Initialize OpenAI client
client = OpenAI(base_url="http://localhost:8880/v1", api_key="not-needed")

async def test_streaming_to_file():
    """Test streaming to file"""
    print("\nTesting streaming to file...")
    speech_file = Path(__file__).parent / "stream_output.mp3"

    start_time = time.time()
    with client.audio.speech.with_streaming_response.create(
        model="kokoro",
        voice="af_bella",
        input="Testing unified streaming implementation with a short phrase.",
    ) as response:
        response.stream_to_file(speech_file)

    print(f"Streaming to file completed in {(time.time() - start_time):.2f}s")
    print(f"Output saved to: {speech_file}")

async def test_streaming_chunks():
    """Test streaming chunks for real-time playback"""
    print("\nTesting chunk streaming...")

    start_time = time.time()
    chunk_count = 0
    total_bytes = 0

    with client.audio.speech.with_streaming_response.create(
        model="kokoro",
        voice="af_bella",
        response_format="pcm",
        input="""This is a longer text to test chunk streaming.
        We want to verify that the unified streaming implementation
        works efficiently for both small and large inputs.""",
    ) as response:
        print(f"Time to first byte: {(time.time() - start_time):.3f}s")
        for chunk in response.iter_bytes(chunk_size=1024):
            chunk_count += 1
            total_bytes += len(chunk)
            # In real usage, this would go to audio playback
            # For testing, we just count chunks and bytes

    total_time = time.time() - start_time
    print(f"Received {chunk_count} chunks, {total_bytes} bytes")
    print(f"Total streaming time: {total_time:.2f}s")
    print(f"Average throughput: {total_bytes/total_time/1024:.1f} KB/s")

async def main():
    """Run all tests"""
    print("Starting unified streaming tests...")

    # Test both streaming modes
    await test_streaming_to_file()
    await test_streaming_chunks()

    print("\nAll tests completed!")

if __name__ == "__main__":
    asyncio.run(main())
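
As the comment in test_streaming_chunks notes, real usage would send each chunk to audio playback rather than just counting bytes. A possible extension sketch, assuming the third-party sounddevice package and 16-bit mono PCM at 24 kHz (both assumptions, not confirmed by this commit):

import sounddevice as sd
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8880/v1", api_key="not-needed")

def play_streaming_audio(text: str) -> None:
    # RawOutputStream accepts raw bytes, so each network chunk can be written
    # to the sound card as soon as it arrives.
    with sd.RawOutputStream(samplerate=24000, channels=1, dtype="int16") as speaker:
        with client.audio.speech.with_streaming_response.create(
            model="kokoro",
            voice="af_bella",
            response_format="pcm",
            input=text,
        ) as response:
            for chunk in response.iter_bytes(chunk_size=1024):
                speaker.write(chunk)

if __name__ == "__main__":
    play_streaming_audio("Testing real-time playback of the unified stream.")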