WIP: OpenAI-compatible streaming

This commit is contained in:
remsky 2025-01-04 17:55:36 -07:00
parent f1eb1d9590
commit 0e9f77fc79
10 changed files with 137 additions and 102 deletions

BIN
.coverage

Binary file not shown.


@@ -23,16 +23,25 @@ async def lifespan(app: FastAPI):
     # Initialize the main model with warm-up
     voicepack_count = TTSModel.setup()
-    logger.info("""
-    """)
-    logger.info(f"Model loaded and warmed up on {TTSModel.get_device()}")
-    logger.info(f"{voicepack_count} voice packs loaded successfully")
-    logger.info("#" * 80)
+    # boundary = "█████╗"*9
+    boundary = "" * 54
+    startup_msg = f"""
+{boundary}
+{boundary}
+"""
+    startup_msg += f"\nModel loaded and warmed up on {TTSModel.get_device()}"
+    startup_msg += f"\n{voicepack_count} voice packs loaded successfully\n"
+    startup_msg += f"\n{boundary}\n"
+    logger.info(startup_msg)
     yield


@@ -2,8 +2,8 @@ from typing import List
 from loguru import logger
 from fastapi import Depends, Response, APIRouter, HTTPException
+from fastapi import Header
 from fastapi.responses import StreamingResponse
 from ..services.tts_service import TTSService
 from ..services.audio import AudioService
 from ..structures.schemas import OpenAISpeechRequest
@@ -30,9 +30,13 @@ async def stream_audio_chunks(tts_service: TTSService, request: OpenAISpeechRequest):
     ):
         yield chunk

 @router.post("/audio/speech")
 async def create_speech(
-    request: OpenAISpeechRequest, tts_service: TTSService = Depends(get_tts_service)
+    request: OpenAISpeechRequest,
+    tts_service: TTSService = Depends(get_tts_service),
+    x_raw_response: str = Header(None, alias="x-raw-response"),
 ):
     """OpenAI-compatible endpoint for text-to-speech"""
     try:
@@ -53,7 +57,10 @@ async def create_speech(
         "pcm": "audio/pcm",
     }.get(request.response_format, f"audio/{request.response_format}")

-    if request.stream:
+    # Check if streaming is requested via header
+    is_streaming = x_raw_response == "stream"
+    if is_streaming:
         # Stream audio chunks as they're generated
         return StreamingResponse(
             stream_audio_chunks(tts_service, request),
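
With this change a client opts into chunked streaming by sending an x-raw-response: stream header instead of a body field. A minimal sketch of such a request, assuming the server from this repo is listening on localhost:8880 as in the examples below (payload fields mirror OpenAISpeechRequest):

import requests

# Sketch only: endpoint and port are taken from the examples in this commit.
response = requests.post(
    "http://localhost:8880/v1/audio/speech",
    json={"model": "kokoro", "input": "Hello world", "voice": "af", "response_format": "pcm"},
    headers={"x-raw-response": "stream"},  # opt in to raw chunked streaming
    stream=True,
)
response.raise_for_status()
for chunk in response.iter_content(chunk_size=1024):
    if chunk:
        pass  # hand each raw PCM chunk to a player as it arrives

Without the header, the endpoint falls through to the non-streaming path and returns the complete encoded file.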


@@ -14,17 +14,15 @@ class AudioNormalizer:
     def normalize(self, audio_data: np.ndarray) -> np.ndarray:
         """Normalize audio data to int16 range"""
-        # Convert to float64 for accurate scaling
-        audio_float = audio_data.astype(np.float64)
-        # Scale to int16 range while preserving relative amplitudes
-        max_val = np.abs(audio_float).max()
-        if max_val > 0:
-            scaling = self.int16_max / max_val
-            audio_float *= scaling
-        # Clip to int16 range and convert
-        return np.clip(audio_float, -self.int16_max, self.int16_max).astype(np.int16)
+        # Convert to float32 if not already
+        audio_float = audio_data.astype(np.float32)
+        # Normalize to [-1, 1] range first
+        if np.max(np.abs(audio_float)) > 0:
+            audio_float = audio_float / np.max(np.abs(audio_float))
+        # Scale to int16 range
+        return (audio_float * self.int16_max).astype(np.int16)

 class AudioService:
     """Service for audio format conversions"""
@@ -51,11 +49,10 @@ class AudioService:
         buffer = BytesIO()
         try:
-            # Normalize audio if normalizer provided, otherwise just convert to int16
-            if normalizer is not None:
-                normalized_audio = normalizer.normalize(audio_data)
-            else:
-                normalized_audio = audio_data.astype(np.int16)
+            # Always normalize audio to ensure proper amplitude scaling
+            if normalizer is None:
+                normalizer = AudioNormalizer()
+            normalized_audio = normalizer.normalize(audio_data)

             if output_format == "pcm":
                 logger.info("Writing PCM data...")
@@ -68,8 +65,7 @@ class AudioService:
             elif output_format in ["mp3", "aac"]:
                 logger.info(f"Converting to {output_format.upper()} format...")
                 # Use lower bitrate for streaming
-                sf.write(buffer, normalized_audio, sample_rate, format=output_format.upper(),
-                        subtype='COMPRESSED')
+                sf.write(buffer, normalized_audio, sample_rate, format=output_format.upper())

             elif output_format == "opus":
                 logger.info("Converting to Opus format...")
                 # Use lower bitrate and smaller frame size for streaming
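
The normalizer now always runs and peak-normalizes to [-1, 1] before scaling, so output lands at full int16 scale. A small standalone illustration of the new behavior (plain numpy, not the service code):

import numpy as np

int16_max = np.iinfo(np.int16).max  # 32767

quiet = np.array([0.0, 0.125, -0.25], dtype=np.float32)  # peak amplitude 0.25
peak = np.max(np.abs(quiet))
if peak > 0:
    quiet = quiet / peak  # rescale so the peak sits at 1.0
print((quiet * int16_max).astype(np.int16))  # [     0  16383 -32767]

One side effect worth noting: since each conversion call is scaled by its own peak, quiet chunks get amplified to full scale, which can shift relative loudness across a stream.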


@@ -132,24 +132,8 @@ class TTSService:
             raise ValueError(f"Voice not found: {voice}")
         voicepack = self._load_voice(voice_path)

-        # Split text into smaller chunks for faster streaming
-        # Use shorter chunks for real-time delivery
-        chunks = []
-        sentences = self._split_text(text)
-        current_chunk = []
-        current_length = 0
-        target_length = 100  # Target ~100 characters per chunk for faster processing
-        for sentence in sentences:
-            current_chunk.append(sentence)
-            current_length += len(sentence)
-            if current_length >= target_length:
-                chunks.append(" ".join(current_chunk))
-                current_chunk = []
-                current_length = 0
-        if current_chunk:
-            chunks.append(" ".join(current_chunk))
+        # Split text into sentences for natural boundaries
+        chunks = self._split_text(text)

         # Process and stream chunks
         for i, chunk in enumerate(chunks):
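
Chunking now defers entirely to _split_text, whose internals this diff doesn't show. For intuition only, a sentence splitter along these lines would produce the natural boundaries the new comment describes (an illustrative stand-in, not the repo's implementation):

import re

def split_sentences(text: str) -> list[str]:
    # Split after terminal punctuation followed by whitespace
    return [s.strip() for s in re.split(r"(?<=[.!?])\s+", text) if s.strip()]

print(split_sentences("I see skies of blue. And clouds of white!"))
# -> ['I see skies of blue.', 'And clouds of white!']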


@ -1,20 +1,26 @@
services: services:
model-fetcher: model-fetcher:
image: datamachines/git-lfs:latest image: datamachines/git-lfs:latest
environment:
- SKIP_MODEL_FETCH=${SKIP_MODEL_FETCH:-true}
volumes: volumes:
- ./Kokoro-82M:/app/Kokoro-82M - ./Kokoro-82M:/app/Kokoro-82M
working_dir: /app/Kokoro-82M working_dir: /app/Kokoro-82M
command: > command: >
sh -c " sh -c "
rm -f .git/index.lock; if [ \"$$SKIP_MODEL_FETCH\" = \"true\" ]; then
if [ -z \"$(ls -A .)\" ]; then echo 'Skipping model fetch...' && touch .cloned;
git clone https://huggingface.co/hexgrad/Kokoro-82M .
touch .cloned;
else else
rm -f .git/index.lock && \ rm -f .git/index.lock;
git checkout main && \ if [ -z \"$(ls -A .)\" ]; then
git pull origin main && \ git clone https://huggingface.co/hexgrad/Kokoro-82M .
touch .cloned; touch .cloned;
else
rm -f .git/index.lock && \
git checkout main && \
git pull origin main && \
touch .cloned;
fi;
fi; fi;
tail -f /dev/null tail -f /dev/null
" "


@@ -0,0 +1,52 @@
#!/usr/bin/env rye run python
import time
from pathlib import Path

from openai import OpenAI

# gets OPENAI_API_KEY from your environment variables
openai = OpenAI(base_url="http://localhost:8880/v1", api_key="not-needed-for-local")

speech_file_path = Path(__file__).parent / "speech.mp3"


def main() -> None:
    stream_to_speakers()

    # Create text-to-speech audio file
    with openai.audio.speech.with_streaming_response.create(
        model="kokoro",
        voice="af",
        input="the quick brown fox jumped over the lazy dogs",
    ) as response:
        response.stream_to_file(speech_file_path)


def stream_to_speakers() -> None:
    import pyaudio

    player_stream = pyaudio.PyAudio().open(format=pyaudio.paInt16, channels=1, rate=24000, output=True)

    start_time = time.time()

    with openai.audio.speech.with_streaming_response.create(
        model="kokoro",
        voice="af",
        response_format="pcm",  # similar to WAV, but without a header chunk at the start
        input="""I see skies of blue and clouds of white
            The bright blessed days, the dark sacred nights
            And I think to myself
            What a wonderful world""",
    ) as response:
        print(f"Time to first byte: {int((time.time() - start_time) * 1000)}ms")
        for chunk in response.iter_bytes(chunk_size=1024):
            player_stream.write(chunk)

    print(f"Done in {int((time.time() - start_time) * 1000)}ms.")


if __name__ == "__main__":
    main()
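
Both calls go through the stock openai-python client pointed at the local server; only base_url and api_key differ from hitting OpenAI's hosted API. The PCM branch writes 24 kHz, mono, 16-bit samples straight to the PyAudio stream, which is why it can start playback before the response finishes.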


BIN
examples/speech.mp3 Normal file

Binary file not shown.


@@ -1,7 +1,7 @@
 #!/usr/bin/env python3
 import requests
-import sounddevice as sd
 import numpy as np
+import sounddevice as sd
 import time
 import os
 import wave
@@ -15,12 +15,21 @@ def play_streaming_tts(text: str, output_file: str = None, voice: str = "af"):
     # Initialize variables
     sample_rate = 24000  # Known sample rate for Kokoro
     audio_started = False
-    stream = None
     chunk_count = 0
     total_bytes = 0
     first_chunk_time = None
     all_audio_data = bytearray()  # Raw PCM audio data

+    # Start sounddevice stream with buffer
+    stream = sd.OutputStream(
+        samplerate=sample_rate,
+        channels=1,
+        dtype=np.int16,
+        blocksize=1024,  # Buffer size in samples
+        latency='low'  # Request low latency
+    )
+    stream.start()
+
     # Make streaming request to API
     try:
         response = requests.post(
@@ -38,8 +47,8 @@ def play_streaming_tts(text: str, output_file: str = None, voice: str = "af"):
         response.raise_for_status()
         print(f"Request started successfully after {time.time() - start_time:.2f}s")

-        # Process streaming response
-        for chunk in response.iter_content(chunk_size=1024):
+        # Process streaming response with smaller chunks for lower latency
+        for chunk in response.iter_content(chunk_size=512):  # 512 bytes = 256 samples at 16-bit
             if chunk:
                 chunk_count += 1
                 total_bytes += len(chunk)
@@ -49,40 +58,15 @@ def play_streaming_tts(text: str, output_file: str = None, voice: str = "af"):
                     first_chunk_time = time.time()
                     print(f"\nReceived first chunk after {first_chunk_time - start_time:.2f}s")
                     print(f"First chunk size: {len(chunk)} bytes")
-
-                    # Accumulate raw audio data
-                    all_audio_data.extend(chunk)
-
-                    # Convert PCM to float32 for playback
-                    audio_data = np.frombuffer(chunk, dtype=np.int16).astype(np.float32)
-                    # Scale to [-1, 1] range for sounddevice
-                    audio_data = audio_data / 32768.0
-
-                    # Start audio stream
-                    stream = sd.OutputStream(
-                        samplerate=sample_rate,
-                        channels=1,
-                        dtype=np.float32
-                    )
-                    stream.start()
                     audio_started = True
-                    print("Audio playback started")
-
-                    # Play first chunk
-                    if len(audio_data) > 0:
-                        stream.write(audio_data)
-
-                # Handle subsequent chunks
-                else:
-                    # Accumulate raw audio data
-                    all_audio_data.extend(chunk)
-                    # Convert PCM to float32 for playback
-                    audio_data = np.frombuffer(chunk, dtype=np.int16).astype(np.float32)
-                    audio_data = audio_data / 32768.0
-                    if len(audio_data) > 0:
-                        stream.write(audio_data)
+
+                # Convert bytes to numpy array and play
+                audio_chunk = np.frombuffer(chunk, dtype=np.int16)
+                stream.write(audio_chunk)
+
+                # Accumulate raw audio data
+                all_audio_data.extend(chunk)

                 # Log progress every 10 chunks
                 if chunk_count % 10 == 0:
                     elapsed = time.time() - start_time
@@ -107,20 +91,17 @@ def play_streaming_tts(text: str, output_file: str = None, voice: str = "af"):
             print(f"Saved {len(all_audio_data)} bytes of audio data")

         # Clean up
-        if stream is not None:
-            stream.stop()
-            stream.close()
+        stream.stop()
+        stream.close()

     except requests.exceptions.ConnectionError as e:
         print(f"Connection error - Is the server running? Error: {str(e)}")
-        if stream is not None:
-            stream.stop()
-            stream.close()
+        stream.stop()
+        stream.close()
     except Exception as e:
         print(f"Error during streaming: {str(e)}")
-        if stream is not None:
-            stream.stop()
-            stream.close()
+        stream.stop()
+        stream.close()

 def main():
     # Load sample text from HG Wells