Mirror of https://github.com/remsky/Kokoro-FastAPI.git (synced 2025-04-13 09:39:17 +00:00)

Commit 3547d95ee6 (parent 90c8f11111): unified streaming implementation

3 changed files with 328 additions and 173 deletions
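The commit collapses the two generation paths into one: a shared `_process_chunk` helper produces either raw audio arrays or encoded bytes, `generate_audio_stream` becomes the only place chunks are scheduled and yielded, and the non-streaming `generate_audio` is reduced to collecting the streamed chunks and concatenating them. A minimal sketch of that collect-from-the-stream pattern, assuming an already initialized `TTSService`; the helper name `synthesize_to_array` is illustrative and not part of the commit:

import numpy as np

async def synthesize_to_array(service, text: str, voice: str) -> np.ndarray:
    """Collect raw chunks from the unified streaming generator and join them."""
    chunks = []
    # output_format=None keeps each chunk as a numpy array instead of encoded bytes,
    # which is how the new generate_audio uses the stream internally.
    async for chunk in service.generate_audio_stream(text, voice, speed=1.0, output_format=None):
        if chunk is not None:
            chunks.append(chunk)
    if not chunks:
        raise ValueError("No audio chunks were generated")
    return np.concatenate(chunks) if len(chunks) > 1 else chunks[0]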
@@ -2,7 +2,7 @@
import io
import time
from typing import List, Tuple, Optional
from typing import List, Tuple, Optional, AsyncGenerator, Union

import numpy as np
import scipy.io.wavfile as wavfile
@@ -25,134 +25,76 @@ class TTSService:
    _chunk_semaphore = asyncio.Semaphore(4)

    def __init__(self, output_dir: str = None):
        """Initialize service.

        Args:
            output_dir: Optional output directory for saving audio
        """
        """Initialize service."""
        self.output_dir = output_dir
        self.model_manager = None
        self._voice_manager = None

    @classmethod
    async def create(cls, output_dir: str = None) -> 'TTSService':
        """Create and initialize TTSService instance.

        Args:
            output_dir: Optional output directory for saving audio

        Returns:
            Initialized TTSService instance
        """
        """Create and initialize TTSService instance."""
        service = cls(output_dir)
        # Initialize managers
        service.model_manager = await get_model_manager()
        service._voice_manager = await get_voice_manager()
        return service

    async def generate_audio(
        self, text: str, voice: str, speed: float = 1.0, stitch_long_output: bool = True
    ) -> Tuple[np.ndarray, float]:
        """Generate audio for text.

        Args:
            text: Input text
            voice: Voice name
            speed: Speed multiplier
            stitch_long_output: Whether to stitch together long outputs

        Returns:
            Audio samples and processing time

        Raises:
            ValueError: If text is empty after preprocessing or no chunks generated
            RuntimeError: If audio generation fails
        """
        start_time = time.time()
        voice_tensor = None

        try:
            # Normalize text
            normalized = normalize_text(text)
            if not normalized:
                raise ValueError("Text is empty after preprocessing")
            text = str(normalized)

            # Get backend and load voice
            backend = self.model_manager.get_backend()
            voice_tensor = await self._voice_manager.load_voice(voice, device=backend.device)

            # Get chunks using async generator
            chunks = []
            async for chunk in chunker.split_text(text):
                chunks.append(chunk)

            if not chunks:
                raise ValueError("No text chunks to process")

            # Process chunk with concurrency control
            async def process_chunk(chunk: str) -> Optional[np.ndarray]:
    async def _process_chunk(
        self,
        chunk: str,
        voice_tensor: torch.Tensor,
        speed: float,
        output_format: Optional[str] = None,
        is_first: bool = False,
        is_last: bool = False,
        normalizer: Optional[AudioNormalizer] = None,
    ) -> Optional[Union[np.ndarray, bytes]]:
        """Process a single text chunk into audio."""
        async with self._chunk_semaphore:
            try:
                tokens = process_text(chunk)
                if not tokens:
                    return None

                # Generate audio
                return await self.model_manager.generate(
                # Generate audio using pre-warmed model
                chunk_audio = await self.model_manager.generate(
                    tokens,
                    voice_tensor,
                    speed=speed
                )

                if chunk_audio is None:
                    return None

                # For streaming, convert to bytes
                if output_format:
                    return await AudioService.convert_audio(
                        chunk_audio,
                        24000,
                        output_format,
                        is_first_chunk=is_first,
                        normalizer=normalizer,
                        is_last_chunk=is_last,
                        stream=True
                    )

                return chunk_audio

            except Exception as e:
                logger.error(f"Failed to process chunk: '{chunk}'. Error: {str(e)}")
                return None

            # Process all chunks concurrently
            chunk_results = await asyncio.gather(*[
                process_chunk(chunk) for chunk in chunks
            ])

            # Filter out None results and combine
            audio_chunks = [chunk for chunk in chunk_results if chunk is not None]
            if not audio_chunks:
                raise ValueError("No audio chunks were generated successfully")

            # Combine chunks
            audio = np.concatenate(audio_chunks) if len(audio_chunks) > 1 else audio_chunks[0]
            processing_time = time.time() - start_time
            return audio, processing_time

        except Exception as e:
            logger.error(f"Error in audio generation: {str(e)}")
            raise
        finally:
            # Always clean up voice tensor
            if voice_tensor is not None:
                del voice_tensor
                torch.cuda.empty_cache()

    async def generate_audio_stream(
        self,
        text: str,
        voice: str,
        speed: float = 1.0,
        output_format: str = "wav",
    ):
        """Generate and stream audio chunks.

        Args:
            text: Input text
            voice: Voice name
            speed: Speed multiplier
            output_format: Output audio format

        Yields:
            Audio chunks as bytes
        """
        # Setup audio processing
    ) -> AsyncGenerator[bytes, None]:
        """Generate and stream audio chunks."""
        stream_normalizer = AudioNormalizer()
        voice_tensor = None
        pending_results = {}
        next_index = 0

        try:
            # Normalize text
@@ -161,11 +103,11 @@
                raise ValueError("Text is empty after preprocessing")
            text = str(normalized)

            # Get backend and load voice
            # Get backend and load voice (should be fast if cached)
            backend = self.model_manager.get_backend()
            voice_tensor = await self._voice_manager.load_voice(voice, device=backend.device)

            # Get chunks using async generator
            # Process chunks with semaphore limiting concurrency
            chunks = []
            async for chunk in chunker.split_text(text):
                chunks.append(chunk)
@@ -173,84 +115,80 @@
            if not chunks:
                raise ValueError("No text chunks to process")

            # Process chunk with concurrency control
            async def process_chunk(chunk: str, is_first: bool, is_last: bool) -> Optional[bytes]:
                async with self._chunk_semaphore:
                    try:
                        tokens = process_text(chunk)
                        if not tokens:
                            return None

                        # Generate audio
                        chunk_audio = await self.model_manager.generate(
                            tokens,
                            voice_tensor,
                            speed=speed
                        )

                        if chunk_audio is not None:
                            # Convert to bytes
                            return await AudioService.convert_audio(
                                chunk_audio,
                                24000,
                                output_format,
                                is_first_chunk=is_first,
                                normalizer=stream_normalizer,
                                is_last_chunk=is_last,
                                stream=True
                            )
                    except Exception as e:
                        logger.error(f"Failed to generate audio for chunk: '{chunk}'. Error: {str(e)}")
                        return None

            # Create tasks for all chunks
            tasks = [
                process_chunk(chunk, i==0, i==len(chunks)-1)
                asyncio.create_task(
                    self._process_chunk(
                        chunk,
                        voice_tensor,
                        speed,
                        output_format,
                        is_first=(i == 0),
                        is_last=(i == len(chunks) - 1),
                        normalizer=stream_normalizer
                    )
                )
                for i, chunk in enumerate(chunks)
            ]

            # Process chunks concurrently and yield results in order
            for chunk_bytes in await asyncio.gather(*tasks):
                if chunk_bytes is not None:
                    yield chunk_bytes
            # Process chunks and maintain order
            for i, task in enumerate(tasks):
                result = await task

                if i == next_index and result is not None:
                    # If this is the next chunk we need, yield it
                    yield result
                    next_index += 1

                    # Check if we have any subsequent chunks ready
                    while next_index in pending_results:
                        result = pending_results.pop(next_index)
                        if result is not None:
                            yield result
                        next_index += 1
                else:
                    # Store out-of-order result
                    pending_results[i] = result

        except Exception as e:
            logger.error(f"Error in audio generation stream: {str(e)}")
            raise
        finally:
            # Always clean up voice tensor
            if voice_tensor is not None:
                del voice_tensor
                torch.cuda.empty_cache()

    async def generate_audio(
        self, text: str, voice: str, speed: float = 1.0
    ) -> Tuple[np.ndarray, float]:
        """Generate complete audio for text using streaming internally."""
        start_time = time.time()
        chunks = []

        try:
            # Use streaming generator but collect all chunks
            async for chunk in self.generate_audio_stream(
                text, voice, speed, output_format=None
            ):
                if chunk is not None:
                    chunks.append(chunk)

            if not chunks:
                raise ValueError("No audio chunks were generated successfully")

            # Combine chunks
            audio = np.concatenate(chunks) if len(chunks) > 1 else chunks[0]
            processing_time = time.time() - start_time
            return audio, processing_time

        except Exception as e:
            logger.error(f"Error in audio generation: {str(e)}")
            raise

    async def combine_voices(self, voices: List[str]) -> str:
        """Combine multiple voices.

        Args:
            voices: List of voice names

        Returns:
            Name of combined voice
        """
        """Combine multiple voices."""
        return await self._voice_manager.combine_voices(voices)

    async def list_voices(self) -> List[str]:
        """List available voices.

        Returns:
            List of voice names
        """
        """List available voices."""
        return await self._voice_manager.list_voices()

    def _audio_to_bytes(self, audio: np.ndarray) -> bytes:
        """Convert audio to WAV bytes.

        Args:
            audio: Audio samples

        Returns:
            WAV bytes
        """
        buffer = io.BytesIO()
        wavfile.write(buffer, 24000, audio)
        return buffer.getvalue()
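In the new `generate_audio_stream`, every chunk task is created up front with `asyncio.create_task`, so later chunks are synthesized in the background (bounded by the class-level semaphore) while earlier ones are awaited and yielded in text order; `pending_results` and `next_index` buffer anything observed out of order. A self-contained sketch of that create-eagerly, yield-in-order idea, with placeholder names and a sleep standing in for model inference (nothing here is Kokoro-specific):

import asyncio
import time

async def make_chunk(i: int, sem: asyncio.Semaphore) -> bytes:
    async with sem:
        await asyncio.sleep(0.5)  # stand-in for model inference on one text chunk
        return f"chunk-{i}".encode()

async def stream_chunks(n: int):
    sem = asyncio.Semaphore(4)  # mirrors the class-level chunk semaphore
    # Start every chunk task immediately ...
    tasks = [asyncio.create_task(make_chunk(i, sem)) for i in range(n)]
    # ... then yield results in submission order, so later chunks keep
    # computing while earlier ones are already being consumed.
    for task in tasks:
        yield await task

async def main():
    start = time.time()
    async for chunk in stream_chunks(8):
        print(f"{time.time() - start:.2f}s {chunk!r}")

asyncio.run(main())

With a semaphore of 4 and eight chunks, the first four results arrive together at roughly 0.5 s and the rest at roughly 1.0 s, rather than one chunk every 0.5 s, which is the latency benefit the unified stream is after.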
examples/streaming_refactor/benchmark_unified_streaming.py (new file, 148 lines)
@@ -0,0 +1,148 @@
#!/usr/bin/env python3
"""Benchmark script for unified streaming implementation"""

import asyncio
import time
from pathlib import Path
from typing import List, Tuple

from openai import OpenAI
import numpy as np
import matplotlib.pyplot as plt

# Initialize OpenAI client
client = OpenAI(base_url="http://localhost:8880/v1", api_key="not-needed")

TEST_TEXTS = {
    "short": "The quick brown fox jumps over the lazy dog.",
    "medium": """In a bustling city, life moves at a rapid pace.
    People hurry along the sidewalks, while cars navigate
    through the busy streets. The air is filled with the
    sounds of urban activity.""",
    "long": """The technological revolution has transformed how we live and work.
    From artificial intelligence to renewable energy, innovations continue
    to shape our future. As we face global challenges, scientific advances
    offer new solutions. The intersection of technology and human creativity
    drives progress forward, opening new possibilities for tomorrow."""
}

async def benchmark_streaming(text_name: str, text: str) -> Tuple[float, float, int]:
    """Benchmark streaming performance

    Returns:
        Tuple of (time to first byte, total time, total bytes)
    """
    start_time = time.time()
    total_bytes = 0
    first_byte_time = None

    with client.audio.speech.with_streaming_response.create(
        model="kokoro",
        voice="af_bella",
        response_format="pcm",
        input=text,
    ) as response:
        for chunk in response.iter_bytes(chunk_size=1024):
            if first_byte_time is None:
                first_byte_time = time.time() - start_time
            total_bytes += len(chunk)

    total_time = time.time() - start_time
    return first_byte_time, total_time, total_bytes

async def benchmark_non_streaming(text_name: str, text: str) -> Tuple[float, int]:
    """Benchmark non-streaming performance

    Returns:
        Tuple of (total time, total bytes)
    """
    start_time = time.time()
    speech_file = Path(__file__).parent / f"non_stream_{text_name}.mp3"

    with client.audio.speech.with_streaming_response.create(
        model="kokoro",
        voice="af_bella",
        input=text,
    ) as response:
        response.stream_to_file(speech_file)

    total_time = time.time() - start_time
    total_bytes = speech_file.stat().st_size
    return total_time, total_bytes

def plot_results(results: dict):
    """Plot benchmark results"""
    plt.figure(figsize=(12, 6))

    # Prepare data
    text_lengths = [len(text) for text in TEST_TEXTS.values()]
    streaming_times = [r["streaming"]["total_time"] for r in results.values()]
    non_streaming_times = [r["non_streaming"]["total_time"] for r in results.values()]
    first_byte_times = [r["streaming"]["first_byte_time"] for r in results.values()]

    # Plot times
    x = np.arange(len(TEST_TEXTS))
    width = 0.25

    plt.bar(x - width, streaming_times, width, label='Streaming Total Time')
    plt.bar(x, non_streaming_times, width, label='Non-Streaming Total Time')
    plt.bar(x + width, first_byte_times, width, label='Time to First Byte')

    plt.xlabel('Text Length (characters)')
    plt.ylabel('Time (seconds)')
    plt.title('Unified Streaming Performance Comparison')
    plt.xticks(x, text_lengths)
    plt.legend()

    # Save plot
    plt.savefig(Path(__file__).parent / 'benchmark_results.png')
    plt.close()

async def main():
    """Run benchmarks"""
    print("Starting unified streaming benchmarks...")

    results = {}

    for name, text in TEST_TEXTS.items():
        print(f"\nTesting {name} text ({len(text)} chars)...")

        # Test streaming
        print("Running streaming test...")
        first_byte_time, stream_total_time, stream_bytes = await benchmark_streaming(name, text)

        # Test non-streaming
        print("Running non-streaming test...")
        non_stream_total_time, non_stream_bytes = await benchmark_non_streaming(name, text)

        results[name] = {
            "text_length": len(text),
            "streaming": {
                "first_byte_time": first_byte_time,
                "total_time": stream_total_time,
                "total_bytes": stream_bytes,
                "throughput": stream_bytes / stream_total_time / 1024  # KB/s
            },
            "non_streaming": {
                "total_time": non_stream_total_time,
                "total_bytes": non_stream_bytes,
                "throughput": non_stream_bytes / non_stream_total_time / 1024  # KB/s
            }
        }

        # Print results for this test
        print(f"\nResults for {name} text:")
        print(f"Streaming:")
        print(f"  Time to first byte: {first_byte_time:.3f}s")
        print(f"  Total time: {stream_total_time:.3f}s")
        print(f"  Throughput: {stream_bytes/stream_total_time/1024:.1f} KB/s")
        print(f"Non-streaming:")
        print(f"  Total time: {non_stream_total_time:.3f}s")
        print(f"  Throughput: {non_stream_bytes/non_stream_total_time/1024:.1f} KB/s")

    # Plot results
    plot_results(results)
    print("\nBenchmark results have been plotted to benchmark_results.png")

if __name__ == "__main__":
    asyncio.run(main())
examples/streaming_refactor/test_unified_streaming.py (new file, 69 lines)
@@ -0,0 +1,69 @@
#!/usr/bin/env python3
"""Test script for unified streaming implementation"""

import asyncio
import time
from pathlib import Path

from openai import OpenAI

# Initialize OpenAI client
client = OpenAI(base_url="http://localhost:8880/v1", api_key="not-needed")

async def test_streaming_to_file():
    """Test streaming to file"""
    print("\nTesting streaming to file...")
    speech_file = Path(__file__).parent / "stream_output.mp3"

    start_time = time.time()
    with client.audio.speech.with_streaming_response.create(
        model="kokoro",
        voice="af_bella",
        input="Testing unified streaming implementation with a short phrase.",
    ) as response:
        response.stream_to_file(speech_file)

    print(f"Streaming to file completed in {(time.time() - start_time):.2f}s")
    print(f"Output saved to: {speech_file}")

async def test_streaming_chunks():
    """Test streaming chunks for real-time playback"""
    print("\nTesting chunk streaming...")

    start_time = time.time()
    chunk_count = 0
    total_bytes = 0

    with client.audio.speech.with_streaming_response.create(
        model="kokoro",
        voice="af_bella",
        response_format="pcm",
        input="""This is a longer text to test chunk streaming.
        We want to verify that the unified streaming implementation
        works efficiently for both small and large inputs.""",
    ) as response:
        print(f"Time to first byte: {(time.time() - start_time):.3f}s")

        for chunk in response.iter_bytes(chunk_size=1024):
            chunk_count += 1
            total_bytes += len(chunk)
            # In real usage, this would go to audio playback
            # For testing, we just count chunks and bytes

    total_time = time.time() - start_time
    print(f"Received {chunk_count} chunks, {total_bytes} bytes")
    print(f"Total streaming time: {total_time:.2f}s")
    print(f"Average throughput: {total_bytes/total_time/1024:.1f} KB/s")

async def main():
    """Run all tests"""
    print("Starting unified streaming tests...")

    # Test both streaming modes
    await test_streaming_to_file()
    await test_streaming_chunks()

    print("\nAll tests completed!")

if __name__ == "__main__":
    asyncio.run(main())
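The chunk-streaming test only counts bytes; for actual real-time playback the PCM chunks can be written straight to an audio output device. A rough sketch using the third-party `sounddevice` package, assuming the server returns 16-bit mono PCM at 24 kHz (24000 Hz is the sample rate used elsewhere in this commit); this snippet is illustrative and not part of the repository:

import sounddevice as sd
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8880/v1", api_key="not-needed")

def play_streaming(text: str):
    # Raw output stream matching the assumed PCM format: 24 kHz, mono, 16-bit
    with sd.RawOutputStream(samplerate=24000, channels=1, dtype="int16") as speaker:
        with client.audio.speech.with_streaming_response.create(
            model="kokoro",
            voice="af_bella",
            response_format="pcm",
            input=text,
        ) as response:
            for chunk in response.iter_bytes(chunk_size=1024):
                speaker.write(chunk)  # play each chunk as soon as it arrives

play_streaming("Streaming playback of the unified implementation.")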