unified streaming implementation

remsky 2025-01-25 05:25:13 -07:00
parent 90c8f11111
commit 3547d95ee6
3 changed files with 328 additions and 173 deletions


@@ -2,7 +2,7 @@
 import io
 import time
-from typing import List, Tuple, Optional
+from typing import List, Tuple, Optional, AsyncGenerator, Union
 import numpy as np
 import scipy.io.wavfile as wavfile
@@ -25,112 +25,63 @@ class TTSService:
     _chunk_semaphore = asyncio.Semaphore(4)
     def __init__(self, output_dir: str = None):
-        """Initialize service.
-        Args:
-            output_dir: Optional output directory for saving audio
-        """
+        """Initialize service."""
         self.output_dir = output_dir
         self.model_manager = None
         self._voice_manager = None
     @classmethod
     async def create(cls, output_dir: str = None) -> 'TTSService':
-        """Create and initialize TTSService instance.
-        Args:
-            output_dir: Optional output directory for saving audio
-        Returns:
-            Initialized TTSService instance
-        """
+        """Create and initialize TTSService instance."""
         service = cls(output_dir)
         # Initialize managers
         service.model_manager = await get_model_manager()
         service._voice_manager = await get_voice_manager()
         return service
-    async def generate_audio(
-        self, text: str, voice: str, speed: float = 1.0, stitch_long_output: bool = True
-    ) -> Tuple[np.ndarray, float]:
-        """Generate audio for text.
-        Args:
-            text: Input text
-            voice: Voice name
-            speed: Speed multiplier
-            stitch_long_output: Whether to stitch together long outputs
-        Returns:
-            Audio samples and processing time
-        Raises:
-            ValueError: If text is empty after preprocessing or no chunks generated
-            RuntimeError: If audio generation fails
-        """
-        start_time = time.time()
-        voice_tensor = None
-        try:
-            # Normalize text
-            normalized = normalize_text(text)
-            if not normalized:
-                raise ValueError("Text is empty after preprocessing")
-            text = str(normalized)
-            # Get backend and load voice
-            backend = self.model_manager.get_backend()
-            voice_tensor = await self._voice_manager.load_voice(voice, device=backend.device)
-            # Get chunks using async generator
-            chunks = []
-            async for chunk in chunker.split_text(text):
-                chunks.append(chunk)
+    async def _process_chunk(
+        self,
+        chunk: str,
+        voice_tensor: torch.Tensor,
+        speed: float,
+        output_format: Optional[str] = None,
+        is_first: bool = False,
+        is_last: bool = False,
+        normalizer: Optional[AudioNormalizer] = None,
+    ) -> Optional[Union[np.ndarray, bytes]]:
+        """Process a single text chunk into audio."""
+        async with self._chunk_semaphore:
+            try:
+                tokens = process_text(chunk)
+                if not tokens:
+                    return None
-            if not chunks:
-                raise ValueError("No text chunks to process")
-            # Process chunk with concurrency control
-            async def process_chunk(chunk: str) -> Optional[np.ndarray]:
-                async with self._chunk_semaphore:
-                    try:
-                        tokens = process_text(chunk)
-                        if not tokens:
-                            return None
-                        # Generate audio
-                        return await self.model_manager.generate(
-                            tokens,
-                            voice_tensor,
-                            speed=speed
-                        )
-                    except Exception as e:
-                        logger.error(f"Failed to process chunk: '{chunk}'. Error: {str(e)}")
-                        return None
-            # Process all chunks concurrently
-            chunk_results = await asyncio.gather(*[
-                process_chunk(chunk) for chunk in chunks
-            ])
-            # Filter out None results and combine
-            audio_chunks = [chunk for chunk in chunk_results if chunk is not None]
-            if not audio_chunks:
-                raise ValueError("No audio chunks were generated successfully")
-            # Combine chunks
-            audio = np.concatenate(audio_chunks) if len(audio_chunks) > 1 else audio_chunks[0]
-            processing_time = time.time() - start_time
-            return audio, processing_time
-        except Exception as e:
-            logger.error(f"Error in audio generation: {str(e)}")
-            raise
-        finally:
-            # Always clean up voice tensor
-            if voice_tensor is not None:
-                del voice_tensor
-                torch.cuda.empty_cache()
+                # Generate audio using pre-warmed model
+                chunk_audio = await self.model_manager.generate(
+                    tokens,
+                    voice_tensor,
+                    speed=speed
+                )
+                if chunk_audio is None:
+                    return None
+                # For streaming, convert to bytes
+                if output_format:
+                    return await AudioService.convert_audio(
+                        chunk_audio,
+                        24000,
+                        output_format,
+                        is_first_chunk=is_first,
+                        normalizer=normalizer,
+                        is_last_chunk=is_last,
+                        stream=True
+                    )
+                return chunk_audio
+            except Exception as e:
+                logger.error(f"Failed to process chunk: '{chunk}'. Error: {str(e)}")
+                return None
     async def generate_audio_stream(
         self,
@@ -138,21 +89,12 @@ class TTSService:
         voice: str,
         speed: float = 1.0,
         output_format: str = "wav",
-    ):
-        """Generate and stream audio chunks.
-        Args:
-            text: Input text
-            voice: Voice name
-            speed: Speed multiplier
-            output_format: Output audio format
-        Yields:
-            Audio chunks as bytes
-        """
-        # Setup audio processing
+    ) -> AsyncGenerator[bytes, None]:
+        """Generate and stream audio chunks."""
         stream_normalizer = AudioNormalizer()
         voice_tensor = None
+        pending_results = {}
+        next_index = 0
         try:
             # Normalize text
@@ -161,11 +103,11 @@ class TTSService:
                 raise ValueError("Text is empty after preprocessing")
             text = str(normalized)
-            # Get backend and load voice
+            # Get backend and load voice (should be fast if cached)
             backend = self.model_manager.get_backend()
             voice_tensor = await self._voice_manager.load_voice(voice, device=backend.device)
-            # Get chunks using async generator
+            # Process chunks with semaphore limiting concurrency
             chunks = []
             async for chunk in chunker.split_text(text):
                 chunks.append(chunk)
@@ -173,84 +115,80 @@ class TTSService:
             if not chunks:
                 raise ValueError("No text chunks to process")
-            # Process chunk with concurrency control
-            async def process_chunk(chunk: str, is_first: bool, is_last: bool) -> Optional[bytes]:
-                async with self._chunk_semaphore:
-                    try:
-                        tokens = process_text(chunk)
-                        if not tokens:
-                            return None
-                        # Generate audio
-                        chunk_audio = await self.model_manager.generate(
-                            tokens,
-                            voice_tensor,
-                            speed=speed
-                        )
-                        if chunk_audio is not None:
-                            # Convert to bytes
-                            return await AudioService.convert_audio(
-                                chunk_audio,
-                                24000,
-                                output_format,
-                                is_first_chunk=is_first,
-                                normalizer=stream_normalizer,
-                                is_last_chunk=is_last,
-                                stream=True
-                            )
-                    except Exception as e:
-                        logger.error(f"Failed to generate audio for chunk: '{chunk}'. Error: {str(e)}")
-                        return None
             # Create tasks for all chunks
             tasks = [
-                process_chunk(chunk, i==0, i==len(chunks)-1)
+                asyncio.create_task(
+                    self._process_chunk(
+                        chunk,
+                        voice_tensor,
+                        speed,
+                        output_format,
+                        is_first=(i == 0),
+                        is_last=(i == len(chunks) - 1),
+                        normalizer=stream_normalizer
+                    )
+                )
                 for i, chunk in enumerate(chunks)
             ]
-            # Process chunks concurrently and yield results in order
-            for chunk_bytes in await asyncio.gather(*tasks):
-                if chunk_bytes is not None:
-                    yield chunk_bytes
+            # Process chunks and maintain order
+            for i, task in enumerate(tasks):
+                result = await task
+                if i == next_index and result is not None:
+                    # If this is the next chunk we need, yield it
+                    yield result
+                    next_index += 1
+                    # Check if we have any subsequent chunks ready
+                    while next_index in pending_results:
+                        result = pending_results.pop(next_index)
+                        if result is not None:
+                            yield result
+                            next_index += 1
+                else:
+                    # Store out-of-order result
+                    pending_results[i] = result
         except Exception as e:
             logger.error(f"Error in audio generation stream: {str(e)}")
             raise
         finally:
             # Always clean up voice tensor
             if voice_tensor is not None:
                 del voice_tensor
                 torch.cuda.empty_cache()
-    async def combine_voices(self, voices: List[str]) -> str:
-        """Combine multiple voices.
+    async def generate_audio(
+        self, text: str, voice: str, speed: float = 1.0
+    ) -> Tuple[np.ndarray, float]:
+        """Generate complete audio for text using streaming internally."""
+        start_time = time.time()
+        chunks = []
-        Args:
-            voices: List of voice names
-        Returns:
-            Name of combined voice
-        """
+        try:
+            # Use streaming generator but collect all chunks
+            async for chunk in self.generate_audio_stream(
+                text, voice, speed, output_format=None
+            ):
+                if chunk is not None:
+                    chunks.append(chunk)
+            if not chunks:
+                raise ValueError("No audio chunks were generated successfully")
+            # Combine chunks
+            audio = np.concatenate(chunks) if len(chunks) > 1 else chunks[0]
+            processing_time = time.time() - start_time
+            return audio, processing_time
+        except Exception as e:
+            logger.error(f"Error in audio generation: {str(e)}")
+            raise
+    async def combine_voices(self, voices: List[str]) -> str:
+        """Combine multiple voices."""
         return await self._voice_manager.combine_voices(voices)
     async def list_voices(self) -> List[str]:
-        """List available voices.
-        Returns:
-            List of voice names
-        """
-        return await self._voice_manager.list_voices()
-    def _audio_to_bytes(self, audio: np.ndarray) -> bytes:
-        """Convert audio to WAV bytes.
-        Args:
-            audio: Audio samples
-        Returns:
-            WAV bytes
-        """
-        buffer = io.BytesIO()
-        wavfile.write(buffer, 24000, audio)
-        return buffer.getvalue()
+        """List available voices."""
+        return await self._voice_manager.list_voices()
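A note on the ordered-yield loop above: next_index only advances when the head-of-line chunk produced audio, so a chunk for which _process_chunk returns None can leave all later results parked in pending_results. Below is a minimal standalone sketch of the same concurrency pattern that also advances past failed chunks; ordered_stream, tag, and buffered are illustrative names, not part of this commit.

import asyncio
from typing import AsyncGenerator, Awaitable, Iterable, Optional

async def ordered_stream(
    coros: Iterable[Awaitable[Optional[bytes]]],
) -> AsyncGenerator[bytes, None]:
    """Run chunk coroutines concurrently, yield their results in input order."""
    async def tag(i, coro):
        return i, await coro
    tasks = [asyncio.create_task(tag(i, c)) for i, c in enumerate(coros)]
    buffered = {}
    next_index = 0
    for done in asyncio.as_completed(tasks):
        i, result = await done
        buffered[i] = result
        # Flush everything now in order; advance past None results so a
        # single failed chunk cannot stall the rest of the stream.
        while next_index in buffered:
            ready = buffered.pop(next_index)
            if ready is not None:
                yield ready
            next_index += 1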


@@ -0,0 +1,148 @@
+#!/usr/bin/env python3
+"""Benchmark script for unified streaming implementation"""
+import asyncio
+import time
+from pathlib import Path
+from typing import List, Tuple
+from openai import OpenAI
+import numpy as np
+import matplotlib.pyplot as plt
+# Initialize OpenAI client
+client = OpenAI(base_url="http://localhost:8880/v1", api_key="not-needed")
+TEST_TEXTS = {
+    "short": "The quick brown fox jumps over the lazy dog.",
+    "medium": """In a bustling city, life moves at a rapid pace.
+    People hurry along the sidewalks, while cars navigate
+    through the busy streets. The air is filled with the
+    sounds of urban activity.""",
+    "long": """The technological revolution has transformed how we live and work.
+    From artificial intelligence to renewable energy, innovations continue
+    to shape our future. As we face global challenges, scientific advances
+    offer new solutions. The intersection of technology and human creativity
+    drives progress forward, opening new possibilities for tomorrow."""
+}
+async def benchmark_streaming(text_name: str, text: str) -> Tuple[float, float, int]:
+    """Benchmark streaming performance
+    Returns:
+        Tuple of (time to first byte, total time, total bytes)
+    """
+    start_time = time.time()
+    total_bytes = 0
+    first_byte_time = None
+    with client.audio.speech.with_streaming_response.create(
+        model="kokoro",
+        voice="af_bella",
+        response_format="pcm",
+        input=text,
+    ) as response:
+        for chunk in response.iter_bytes(chunk_size=1024):
+            if first_byte_time is None:
+                first_byte_time = time.time() - start_time
+            total_bytes += len(chunk)
+    total_time = time.time() - start_time
+    return first_byte_time, total_time, total_bytes
+async def benchmark_non_streaming(text_name: str, text: str) -> Tuple[float, int]:
+    """Benchmark non-streaming performance
+    Returns:
+        Tuple of (total time, total bytes)
+    """
+    start_time = time.time()
+    speech_file = Path(__file__).parent / f"non_stream_{text_name}.mp3"
+    with client.audio.speech.with_streaming_response.create(
+        model="kokoro",
+        voice="af_bella",
+        input=text,
+    ) as response:
+        response.stream_to_file(speech_file)
+    total_time = time.time() - start_time
+    total_bytes = speech_file.stat().st_size
+    return total_time, total_bytes
+def plot_results(results: dict):
+    """Plot benchmark results"""
+    plt.figure(figsize=(12, 6))
+    # Prepare data
+    text_lengths = [len(text) for text in TEST_TEXTS.values()]
+    streaming_times = [r["streaming"]["total_time"] for r in results.values()]
+    non_streaming_times = [r["non_streaming"]["total_time"] for r in results.values()]
+    first_byte_times = [r["streaming"]["first_byte_time"] for r in results.values()]
+    # Plot times
+    x = np.arange(len(TEST_TEXTS))
+    width = 0.25
+    plt.bar(x - width, streaming_times, width, label='Streaming Total Time')
+    plt.bar(x, non_streaming_times, width, label='Non-Streaming Total Time')
+    plt.bar(x + width, first_byte_times, width, label='Time to First Byte')
+    plt.xlabel('Text Length (characters)')
+    plt.ylabel('Time (seconds)')
+    plt.title('Unified Streaming Performance Comparison')
+    plt.xticks(x, text_lengths)
+    plt.legend()
+    # Save plot
+    plt.savefig(Path(__file__).parent / 'benchmark_results.png')
+    plt.close()
+async def main():
+    """Run benchmarks"""
+    print("Starting unified streaming benchmarks...")
+    results = {}
+    for name, text in TEST_TEXTS.items():
+        print(f"\nTesting {name} text ({len(text)} chars)...")
+        # Test streaming
+        print("Running streaming test...")
+        first_byte_time, stream_total_time, stream_bytes = await benchmark_streaming(name, text)
+        # Test non-streaming
+        print("Running non-streaming test...")
+        non_stream_total_time, non_stream_bytes = await benchmark_non_streaming(name, text)
+        results[name] = {
+            "text_length": len(text),
+            "streaming": {
+                "first_byte_time": first_byte_time,
+                "total_time": stream_total_time,
+                "total_bytes": stream_bytes,
+                "throughput": stream_bytes / stream_total_time / 1024  # KB/s
+            },
+            "non_streaming": {
+                "total_time": non_stream_total_time,
+                "total_bytes": non_stream_bytes,
+                "throughput": non_stream_bytes / non_stream_total_time / 1024  # KB/s
+            }
+        }
+        # Print results for this test
+        print(f"\nResults for {name} text:")
+        print("Streaming:")
+        print(f"  Time to first byte: {first_byte_time:.3f}s")
+        print(f"  Total time: {stream_total_time:.3f}s")
+        print(f"  Throughput: {stream_bytes/stream_total_time/1024:.1f} KB/s")
+        print("Non-streaming:")
+        print(f"  Total time: {non_stream_total_time:.3f}s")
+        print(f"  Throughput: {non_stream_bytes/non_stream_total_time/1024:.1f} KB/s")
+    # Plot results
+    plot_results(results)
+    print("\nBenchmark results have been plotted to benchmark_results.png")
+if __name__ == "__main__":
+    asyncio.run(main())
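The KB/s throughput these benchmarks print is easier to compare across runs as a real-time factor, i.e. seconds of audio generated per second of wall time. A small helper along those lines, assuming the PCM stream is 16-bit mono at the service's 24 kHz output rate (an assumption; the script itself does not pin the sample format down):

def realtime_factor(total_bytes: int, total_time: float,
                    sample_rate: int = 24000, sample_width: int = 2) -> float:
    """Audio seconds generated per wall-clock second for a raw PCM stream."""
    audio_seconds = total_bytes / (sample_rate * sample_width)
    return audio_seconds / total_time

# Example: 480,000 PCM bytes in 2.5 s is 10 s of audio, i.e. an RTF of 4.0.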


@@ -0,0 +1,69 @@
+#!/usr/bin/env python3
+"""Test script for unified streaming implementation"""
+import asyncio
+import time
+from pathlib import Path
+from openai import OpenAI
+# Initialize OpenAI client
+client = OpenAI(base_url="http://localhost:8880/v1", api_key="not-needed")
+async def test_streaming_to_file():
+    """Test streaming to file"""
+    print("\nTesting streaming to file...")
+    speech_file = Path(__file__).parent / "stream_output.mp3"
+    start_time = time.time()
+    with client.audio.speech.with_streaming_response.create(
+        model="kokoro",
+        voice="af_bella",
+        input="Testing unified streaming implementation with a short phrase.",
+    ) as response:
+        response.stream_to_file(speech_file)
+    print(f"Streaming to file completed in {(time.time() - start_time):.2f}s")
+    print(f"Output saved to: {speech_file}")
+async def test_streaming_chunks():
+    """Test streaming chunks for real-time playback"""
+    print("\nTesting chunk streaming...")
+    start_time = time.time()
+    chunk_count = 0
+    total_bytes = 0
+    with client.audio.speech.with_streaming_response.create(
+        model="kokoro",
+        voice="af_bella",
+        response_format="pcm",
+        input="""This is a longer text to test chunk streaming.
+        We want to verify that the unified streaming implementation
+        works efficiently for both small and large inputs.""",
+    ) as response:
+        print(f"Time to first byte: {(time.time() - start_time):.3f}s")
+        for chunk in response.iter_bytes(chunk_size=1024):
+            chunk_count += 1
+            total_bytes += len(chunk)
+            # In real usage, this would go to audio playback
+            # For testing, we just count chunks and bytes
+    total_time = time.time() - start_time
+    print(f"Received {chunk_count} chunks, {total_bytes} bytes")
+    print(f"Total streaming time: {total_time:.2f}s")
+    print(f"Average throughput: {total_bytes/total_time/1024:.1f} KB/s")
+async def main():
+    """Run all tests"""
+    print("Starting unified streaming tests...")
+    # Test both streaming modes
+    await test_streaming_to_file()
+    await test_streaming_chunks()
+    print("\nAll tests completed!")
+if __name__ == "__main__":
+    asyncio.run(main())
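For the real-time playback the chunk loop above stands in for, the same iterator can feed a raw PCM output device directly. A sketch using PyAudio, which is not a dependency of these scripts; the 16-bit mono, 24 kHz format is assumed to match the service's output:

import pyaudio
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8880/v1", api_key="not-needed")

p = pyaudio.PyAudio()
# Assumed output format: 16-bit mono PCM at 24 kHz
stream = p.open(format=pyaudio.paInt16, channels=1, rate=24000, output=True)
try:
    with client.audio.speech.with_streaming_response.create(
        model="kokoro",
        voice="af_bella",
        response_format="pcm",
        input="Streaming synthesized speech straight to the speakers.",
    ) as response:
        for chunk in response.iter_bytes(chunk_size=1024):
            stream.write(chunk)  # blocks until the chunk is queued for playback
finally:
    stream.stop_stream()
    stream.close()
    p.terminate()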