#!/usr/bin/env python3
# Kokoro-FastAPI/examples/stream_tts_playback.py
import os
import time
import wave
from typing import Optional

import numpy as np
import requests
import sounddevice as sd


def play_streaming_tts(text: str, output_file: Optional[str] = None, voice: str = "af"):
    """Stream TTS audio from the speech endpoint and play it back in real time.

    Posts ``text`` to ``http://localhost:8880/v1/audio/speech`` requesting raw
    PCM, writes each received chunk to a mono 16-bit sounddevice output stream
    as it arrives, prints progress/timing statistics, and optionally saves the
    complete audio as a WAV file.

    Args:
        text: The text to synthesize.
        output_file: Path of a WAV file to write once the stream has finished.
            When ``None`` (or empty) nothing is saved.
        voice: Voice identifier sent in the request payload.

    Returns:
        None. Connection and streaming errors are reported on stdout rather
        than raised; the audio output stream is always stopped and closed.
    """
    print("\nStarting TTS stream request...")
    start_time = time.time()

    # Initialize variables
    sample_rate = 24000  # Known sample rate for Kokoro
    sample_width = 2  # Bytes per 16-bit sample
    audio_started = False
    chunk_count = 0
    total_bytes = 0
    all_audio_data = bytearray()  # Raw PCM audio data (whole samples only)
    pending = b""  # Trailing byte of a sample split across two network chunks

    # Start sounddevice stream with buffer
    stream = sd.OutputStream(
        samplerate=sample_rate,
        channels=1,
        dtype=np.int16,
        blocksize=1024,  # Buffer size in samples
        latency="low",  # Request low latency
    )
    stream.start()

    try:
        # Make streaming request to API; the context manager releases the
        # connection even if playback fails part-way through.
        with requests.post(
            "http://localhost:8880/v1/audio/speech",
            json={
                "model": "kokoro",
                "input": text,
                "voice": voice,
                "response_format": "pcm",
                "stream": True,
            },
            stream=True,
            timeout=1800,
        ) as response:
            response.raise_for_status()
            print(f"Request started successfully after {time.time() - start_time:.2f}s")

            # Process streaming response with smaller chunks for lower latency
            # (512 bytes = 256 samples at 16-bit).
            for chunk in response.iter_content(chunk_size=512):
                if not chunk:
                    continue

                chunk_count += 1
                total_bytes += len(chunk)

                # Handle first chunk
                if not audio_started:
                    first_chunk_time = time.time()
                    print(
                        f"\nReceived first chunk after {first_chunk_time - start_time:.2f}s"
                    )
                    print(f"First chunk size: {len(chunk)} bytes")
                    audio_started = True

                # Chunks are not guaranteed to end on a sample boundary, and
                # np.frombuffer rejects buffers whose size is not a multiple of
                # the element size. Carry any odd trailing byte into the next
                # chunk so only whole samples are played and stored.
                data = pending + chunk
                usable = len(data) - (len(data) % sample_width)
                pending = data[usable:]

                if usable:
                    frames = data[:usable]
                    # Convert bytes to numpy array and play
                    stream.write(np.frombuffer(frames, dtype=np.int16))
                    # Accumulate raw audio data
                    all_audio_data.extend(frames)

                # Log progress every 100 chunks
                if chunk_count % 100 == 0:
                    elapsed = time.time() - start_time
                    print(
                        f"Progress: {chunk_count} chunks, {total_bytes/1024:.1f}KB received, {elapsed:.1f}s elapsed"
                    )

        # Final stats
        total_time = time.time() - start_time
        print(f"\nStream complete:")
        print(f"Total chunks: {chunk_count}")
        print(f"Total data: {total_bytes/1024:.1f}KB")
        print(f"Total time: {total_time:.2f}s")
        print(f"Average speed: {(total_bytes/1024)/total_time:.1f}KB/s")

        # Save as WAV file
        if output_file:
            print(f"\nWriting audio to {output_file}")
            with wave.open(output_file, "wb") as wav_file:
                wav_file.setnchannels(1)  # Mono
                wav_file.setsampwidth(sample_width)  # 2 bytes per sample (16-bit)
                wav_file.setframerate(sample_rate)
                wav_file.writeframes(all_audio_data)
            print(f"Saved {len(all_audio_data)} bytes of audio data")

    except requests.exceptions.ConnectionError as e:
        print(f"Connection error - Is the server running? Error: {str(e)}")
    except Exception as e:
        print(f"Error during streaming: {str(e)}")
    finally:
        # Clean up on every exit path, including Ctrl+C during playback.
        stream.stop()
        stream.close()
2025-01-04 17:54:54 -07:00
2025-01-09 18:41:44 -07:00
2025-01-04 17:54:54 -07:00
def main():
    """Load the bundled H. G. Wells sample text and stream its opening via TTS.

    Reads the text file located relative to this script, keeps the first two
    blank-line-separated paragraphs, prints a short preview, and hands the
    text to ``play_streaming_tts`` with ``output.wav`` (next to this script)
    as the output file.
    """
    here = os.path.dirname(os.path.abspath(__file__))
    sample_path = os.path.join(
        here, "assorted_checks/benchmarks/the_time_machine_hg_wells.txt"
    )
    wav_path = os.path.join(here, "output.wav")

    with open(sample_path, "r", encoding="utf-8") as handle:
        paragraphs = handle.read().split("\n\n")

    # Only the first two paragraphs are synthesized.
    text = " ".join(paragraphs[:2])
    preview = text[:100] + "..."

    print("\nStarting TTS stream playback...")
    print(f"Text length: {len(text)} characters")
    print("\nFirst 100 characters:")
    print(preview)

    play_streaming_tts(text, output_file=wav_path)
2025-01-09 18:41:44 -07:00
2025-01-04 17:54:54 -07:00
if __name__ == "__main__":
    # Run the demo only when executed as a script, not when imported.
    main()