mirror of
https://github.com/remsky/Kokoro-FastAPI.git
synced 2025-08-05 16:48:53 +00:00
WIP: open ai compatible streaming
This commit is contained in:
parent
f1eb1d9590
commit
0e9f77fc79
10 changed files with 137 additions and 102 deletions
BIN
.coverage
BIN
.coverage
Binary file not shown.
|
@ -23,16 +23,25 @@ async def lifespan(app: FastAPI):
|
|||
|
||||
# Initialize the main model with warm-up
|
||||
voicepack_count = TTSModel.setup()
|
||||
logger.info("""
|
||||
███████╗█████╗█████████████████╗ ██╗██████╗██╗ ██╗██████╗
|
||||
██╔════██╔══████╔════╚══██╔══██║ ██╔██╔═══████║ ██╔██╔═══██╗
|
||||
█████╗ ██████████████╗ ██║ █████╔╝██║ ███████╔╝██║ ██║
|
||||
██╔══╝ ██╔══██╚════██║ ██║ ██╔═██╗██║ ████╔═██╗██║ ██║
|
||||
██║ ██║ █████████║ ██║ ██║ ██╚██████╔██║ ██╚██████╔╝
|
||||
╚═╝ ╚═╝ ╚═╚══════╝ ╚═╝ ╚═╝ ╚═╝╚═════╝╚═╝ ╚═╝╚═════╝ """)
|
||||
logger.info(f"Model loaded and warmed up on {TTSModel.get_device()}")
|
||||
logger.info(f"{voicepack_count} voice packs loaded successfully")
|
||||
logger.info("#" * 80)
|
||||
# boundary = "█████╗"*9
|
||||
boundary = "░" * 54
|
||||
startup_msg =f"""
|
||||
{boundary}
|
||||
|
||||
╔═╗┌─┐┌─┐┌┬┐
|
||||
╠╣ ├─┤└─┐ │
|
||||
╚ ┴ ┴└─┘ ┴
|
||||
╦╔═┌─┐┬┌─┌─┐
|
||||
╠╩╗│ │├┴┐│ │
|
||||
╩ ╩└─┘┴ ┴└─┘
|
||||
|
||||
{boundary}
|
||||
"""
|
||||
startup_msg += f"\nModel loaded and warmed up on {TTSModel.get_device()}"
|
||||
startup_msg += f"\n{voicepack_count} voice packs loaded successfully\n"
|
||||
startup_msg += f"\n{boundary}\n"
|
||||
logger.info(startup_msg)
|
||||
|
||||
yield
|
||||
|
||||
|
||||
|
|
|
@ -2,8 +2,8 @@ from typing import List
|
|||
|
||||
from loguru import logger
|
||||
from fastapi import Depends, Response, APIRouter, HTTPException
|
||||
from fastapi import Header
|
||||
from fastapi.responses import StreamingResponse
|
||||
|
||||
from ..services.tts_service import TTSService
|
||||
from ..services.audio import AudioService
|
||||
from ..structures.schemas import OpenAISpeechRequest
|
||||
|
@ -30,9 +30,13 @@ async def stream_audio_chunks(tts_service: TTSService, request: OpenAISpeechRequ
|
|||
):
|
||||
yield chunk
|
||||
|
||||
|
||||
|
||||
@router.post("/audio/speech")
|
||||
async def create_speech(
|
||||
request: OpenAISpeechRequest, tts_service: TTSService = Depends(get_tts_service)
|
||||
request: OpenAISpeechRequest,
|
||||
tts_service: TTSService = Depends(get_tts_service),
|
||||
x_raw_response: str = Header(None, alias="x-raw-response"),
|
||||
):
|
||||
"""OpenAI-compatible endpoint for text-to-speech"""
|
||||
try:
|
||||
|
@ -53,7 +57,10 @@ async def create_speech(
|
|||
"pcm": "audio/pcm",
|
||||
}.get(request.response_format, f"audio/{request.response_format}")
|
||||
|
||||
if request.stream:
|
||||
# Check if streaming is requested via header
|
||||
is_streaming = x_raw_response == "stream"
|
||||
|
||||
if is_streaming:
|
||||
# Stream audio chunks as they're generated
|
||||
return StreamingResponse(
|
||||
stream_audio_chunks(tts_service, request),
|
||||
|
|
|
@ -14,17 +14,15 @@ class AudioNormalizer:
|
|||
|
||||
def normalize(self, audio_data: np.ndarray) -> np.ndarray:
|
||||
"""Normalize audio data to int16 range"""
|
||||
# Convert to float64 for accurate scaling
|
||||
audio_float = audio_data.astype(np.float64)
|
||||
# Convert to float32 if not already
|
||||
audio_float = audio_data.astype(np.float32)
|
||||
|
||||
# Scale to int16 range while preserving relative amplitudes
|
||||
max_val = np.abs(audio_float).max()
|
||||
if max_val > 0:
|
||||
scaling = self.int16_max / max_val
|
||||
audio_float *= scaling
|
||||
|
||||
# Clip to int16 range and convert
|
||||
return np.clip(audio_float, -self.int16_max, self.int16_max).astype(np.int16)
|
||||
# Normalize to [-1, 1] range first
|
||||
if np.max(np.abs(audio_float)) > 0:
|
||||
audio_float = audio_float / np.max(np.abs(audio_float))
|
||||
|
||||
# Scale to int16 range
|
||||
return (audio_float * self.int16_max).astype(np.int16)
|
||||
|
||||
class AudioService:
|
||||
"""Service for audio format conversions"""
|
||||
|
@ -51,11 +49,10 @@ class AudioService:
|
|||
buffer = BytesIO()
|
||||
|
||||
try:
|
||||
# Normalize audio if normalizer provided, otherwise just convert to int16
|
||||
if normalizer is not None:
|
||||
normalized_audio = normalizer.normalize(audio_data)
|
||||
else:
|
||||
normalized_audio = audio_data.astype(np.int16)
|
||||
# Always normalize audio to ensure proper amplitude scaling
|
||||
if normalizer is None:
|
||||
normalizer = AudioNormalizer()
|
||||
normalized_audio = normalizer.normalize(audio_data)
|
||||
|
||||
if output_format == "pcm":
|
||||
logger.info("Writing PCM data...")
|
||||
|
@ -68,8 +65,7 @@ class AudioService:
|
|||
elif output_format in ["mp3", "aac"]:
|
||||
logger.info(f"Converting to {output_format.upper()} format...")
|
||||
# Use lower bitrate for streaming
|
||||
sf.write(buffer, normalized_audio, sample_rate, format=output_format.upper(),
|
||||
subtype='COMPRESSED')
|
||||
sf.write(buffer, normalized_audio, sample_rate, format=output_format.upper())
|
||||
elif output_format == "opus":
|
||||
logger.info("Converting to Opus format...")
|
||||
# Use lower bitrate and smaller frame size for streaming
|
||||
|
|
|
@ -132,24 +132,8 @@ class TTSService:
|
|||
raise ValueError(f"Voice not found: {voice}")
|
||||
voicepack = self._load_voice(voice_path)
|
||||
|
||||
# Split text into smaller chunks for faster streaming
|
||||
# Use shorter chunks for real-time delivery
|
||||
chunks = []
|
||||
sentences = self._split_text(text)
|
||||
current_chunk = []
|
||||
current_length = 0
|
||||
target_length = 100 # Target ~100 characters per chunk for faster processing
|
||||
|
||||
for sentence in sentences:
|
||||
current_chunk.append(sentence)
|
||||
current_length += len(sentence)
|
||||
if current_length >= target_length:
|
||||
chunks.append(" ".join(current_chunk))
|
||||
current_chunk = []
|
||||
current_length = 0
|
||||
|
||||
if current_chunk:
|
||||
chunks.append(" ".join(current_chunk))
|
||||
# Split text into sentences for natural boundaries
|
||||
chunks = self._split_text(text)
|
||||
|
||||
# Process and stream chunks
|
||||
for i, chunk in enumerate(chunks):
|
||||
|
|
|
@ -1,20 +1,26 @@
|
|||
services:
|
||||
model-fetcher:
|
||||
image: datamachines/git-lfs:latest
|
||||
environment:
|
||||
- SKIP_MODEL_FETCH=${SKIP_MODEL_FETCH:-true}
|
||||
volumes:
|
||||
- ./Kokoro-82M:/app/Kokoro-82M
|
||||
working_dir: /app/Kokoro-82M
|
||||
command: >
|
||||
sh -c "
|
||||
rm -f .git/index.lock;
|
||||
if [ -z \"$(ls -A .)\" ]; then
|
||||
git clone https://huggingface.co/hexgrad/Kokoro-82M .
|
||||
touch .cloned;
|
||||
if [ \"$$SKIP_MODEL_FETCH\" = \"true\" ]; then
|
||||
echo 'Skipping model fetch...' && touch .cloned;
|
||||
else
|
||||
rm -f .git/index.lock && \
|
||||
git checkout main && \
|
||||
git pull origin main && \
|
||||
touch .cloned;
|
||||
rm -f .git/index.lock;
|
||||
if [ -z \"$(ls -A .)\" ]; then
|
||||
git clone https://huggingface.co/hexgrad/Kokoro-82M .
|
||||
touch .cloned;
|
||||
else
|
||||
rm -f .git/index.lock && \
|
||||
git checkout main && \
|
||||
git pull origin main && \
|
||||
touch .cloned;
|
||||
fi;
|
||||
fi;
|
||||
tail -f /dev/null
|
||||
"
|
||||
|
|
52
examples/openai_streaming_audio.py
Normal file
52
examples/openai_streaming_audio.py
Normal file
|
@ -0,0 +1,52 @@
|
|||
|
||||
#!/usr/bin/env rye run python
|
||||
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
from openai import OpenAI
|
||||
|
||||
# gets OPENAI_API_KEY from your environment variables
|
||||
openai = OpenAI(base_url="http://localhost:8880/v1", api_key="not-needed-for-local")
|
||||
|
||||
speech_file_path = Path(__file__).parent / "speech.mp3"
|
||||
|
||||
|
||||
def main() -> None:
|
||||
stream_to_speakers()
|
||||
|
||||
# Create text-to-speech audio file
|
||||
with openai.audio.speech.with_streaming_response.create(
|
||||
model="kokoro",
|
||||
voice="af",
|
||||
input="the quick brown fox jumped over the lazy dogs",
|
||||
) as response:
|
||||
response.stream_to_file(speech_file_path)
|
||||
|
||||
|
||||
|
||||
def stream_to_speakers() -> None:
|
||||
import pyaudio
|
||||
|
||||
player_stream = pyaudio.PyAudio().open(format=pyaudio.paInt16, channels=1, rate=24000, output=True)
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
with openai.audio.speech.with_streaming_response.create(
|
||||
model="kokoro",
|
||||
voice="af",
|
||||
response_format="pcm", # similar to WAV, but without a header chunk at the start.
|
||||
input="""I see skies of blue and clouds of white
|
||||
The bright blessed days, the dark sacred nights
|
||||
And I think to myself
|
||||
What a wonderful world""",
|
||||
) as response:
|
||||
print(f"Time to first byte: {int((time.time() - start_time) * 1000)}ms")
|
||||
for chunk in response.iter_bytes(chunk_size=1024):
|
||||
player_stream.write(chunk)
|
||||
|
||||
print(f"Done in {int((time.time() - start_time) * 1000)}ms.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
Binary file not shown.
BIN
examples/speech.mp3
Normal file
BIN
examples/speech.mp3
Normal file
Binary file not shown.
|
@ -1,7 +1,7 @@
|
|||
#!/usr/bin/env python3
|
||||
import requests
|
||||
import sounddevice as sd
|
||||
import numpy as np
|
||||
import sounddevice as sd
|
||||
import time
|
||||
import os
|
||||
import wave
|
||||
|
@ -15,12 +15,21 @@ def play_streaming_tts(text: str, output_file: str = None, voice: str = "af"):
|
|||
# Initialize variables
|
||||
sample_rate = 24000 # Known sample rate for Kokoro
|
||||
audio_started = False
|
||||
stream = None
|
||||
chunk_count = 0
|
||||
total_bytes = 0
|
||||
first_chunk_time = None
|
||||
all_audio_data = bytearray() # Raw PCM audio data
|
||||
|
||||
# Start sounddevice stream with buffer
|
||||
stream = sd.OutputStream(
|
||||
samplerate=sample_rate,
|
||||
channels=1,
|
||||
dtype=np.int16,
|
||||
blocksize=1024, # Buffer size in samples
|
||||
latency='low' # Request low latency
|
||||
)
|
||||
stream.start()
|
||||
|
||||
# Make streaming request to API
|
||||
try:
|
||||
response = requests.post(
|
||||
|
@ -38,8 +47,8 @@ def play_streaming_tts(text: str, output_file: str = None, voice: str = "af"):
|
|||
response.raise_for_status()
|
||||
print(f"Request started successfully after {time.time() - start_time:.2f}s")
|
||||
|
||||
# Process streaming response
|
||||
for chunk in response.iter_content(chunk_size=1024):
|
||||
# Process streaming response with smaller chunks for lower latency
|
||||
for chunk in response.iter_content(chunk_size=512): # 512 bytes = 256 samples at 16-bit
|
||||
if chunk:
|
||||
chunk_count += 1
|
||||
total_bytes += len(chunk)
|
||||
|
@ -49,40 +58,15 @@ def play_streaming_tts(text: str, output_file: str = None, voice: str = "af"):
|
|||
first_chunk_time = time.time()
|
||||
print(f"\nReceived first chunk after {first_chunk_time - start_time:.2f}s")
|
||||
print(f"First chunk size: {len(chunk)} bytes")
|
||||
|
||||
# Accumulate raw audio data
|
||||
all_audio_data.extend(chunk)
|
||||
|
||||
# Convert PCM to float32 for playback
|
||||
audio_data = np.frombuffer(chunk, dtype=np.int16).astype(np.float32)
|
||||
# Scale to [-1, 1] range for sounddevice
|
||||
audio_data = audio_data / 32768.0
|
||||
|
||||
# Start audio stream
|
||||
stream = sd.OutputStream(
|
||||
samplerate=sample_rate,
|
||||
channels=1,
|
||||
dtype=np.float32
|
||||
)
|
||||
stream.start()
|
||||
audio_started = True
|
||||
print("Audio playback started")
|
||||
|
||||
# Play first chunk
|
||||
if len(audio_data) > 0:
|
||||
stream.write(audio_data)
|
||||
|
||||
# Handle subsequent chunks
|
||||
else:
|
||||
# Accumulate raw audio data
|
||||
all_audio_data.extend(chunk)
|
||||
|
||||
# Convert PCM to float32 for playback
|
||||
audio_data = np.frombuffer(chunk, dtype=np.int16).astype(np.float32)
|
||||
audio_data = audio_data / 32768.0
|
||||
if len(audio_data) > 0:
|
||||
stream.write(audio_data)
|
||||
|
||||
|
||||
# Convert bytes to numpy array and play
|
||||
audio_chunk = np.frombuffer(chunk, dtype=np.int16)
|
||||
stream.write(audio_chunk)
|
||||
|
||||
# Accumulate raw audio data
|
||||
all_audio_data.extend(chunk)
|
||||
|
||||
# Log progress every 10 chunks
|
||||
if chunk_count % 10 == 0:
|
||||
elapsed = time.time() - start_time
|
||||
|
@ -107,20 +91,17 @@ def play_streaming_tts(text: str, output_file: str = None, voice: str = "af"):
|
|||
print(f"Saved {len(all_audio_data)} bytes of audio data")
|
||||
|
||||
# Clean up
|
||||
if stream is not None:
|
||||
stream.stop()
|
||||
stream.close()
|
||||
stream.stop()
|
||||
stream.close()
|
||||
|
||||
except requests.exceptions.ConnectionError as e:
|
||||
print(f"Connection error - Is the server running? Error: {str(e)}")
|
||||
if stream is not None:
|
||||
stream.stop()
|
||||
stream.close()
|
||||
stream.stop()
|
||||
stream.close()
|
||||
except Exception as e:
|
||||
print(f"Error during streaming: {str(e)}")
|
||||
if stream is not None:
|
||||
stream.stop()
|
||||
stream.close()
|
||||
stream.stop()
|
||||
stream.close()
|
||||
|
||||
def main():
|
||||
# Load sample text from HG Wells
|
||||
|
|
Loading…
Add table
Reference in a new issue