-fix voice selection not matching language phonemes

-add voice language override parameter
remsky 2025-02-08 01:29:15 -07:00
parent 68cc14896a
commit a0dc870f4a
12 changed files with 259 additions and 51 deletions
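
The new `lang_code` request field rides through the OpenAI-compatible schema (see the schema and router changes below). A minimal client sketch exercising the override, assuming a locally running instance on port 8880, a default model name of "kokoro", and a Spanish voice named `ef_dora` (all assumptions, not part of this commit):

import requests

# Force Spanish ('e') phoneme processing regardless of the voice name's
# first letter. Host, port, model, and voice name are assumed for illustration.
response = requests.post(
    "http://localhost:8880/v1/audio/speech",
    json={
        "model": "kokoro",             # assumed default model name
        "input": "Hola, ¿cómo estás?",
        "voice": "ef_dora",            # assumed Spanish voice
        "response_format": "mp3",
        "lang_code": "e",              # new optional override
    },
)
with open("output.mp3", "wb") as f:
    f.write(response.content)

Omitting `lang_code` (or sending null) falls back to deriving the language from the first letter of the voice name.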

View file

@@ -3,8 +3,8 @@
</p>
# <sub><sub>_`FastKoko`_ </sub></sub>
[![Tests](https://img.shields.io/badge/tests-63%20passed-darkgreen)]()
[![Coverage](https://img.shields.io/badge/coverage-53%25-tan)]()
[![Tests](https://img.shields.io/badge/tests-66%20passed-darkgreen)]()
[![Coverage](https://img.shields.io/badge/coverage-54%25-tan)]()
[![Try on Spaces](https://img.shields.io/badge/%F0%9F%A4%97%20Try%20on-Spaces-blue)](https://huggingface.co/spaces/Remsky/Kokoro-TTS-Zero)
[![Tested at Model Commit](https://img.shields.io/badge/last--tested--model--commit-1.0::9901c2b-blue)](https://huggingface.co/hexgrad/Kokoro-82M/commit/9901c2b79161b6e898b7ea857ae5298f47b8b0d6)

View file

@@ -1,7 +1,7 @@
"""Clean Kokoro implementation with controlled resource management."""
import os
from typing import AsyncGenerator, Optional, Union, Tuple
from typing import AsyncGenerator, Optional, Union, Tuple, Dict
import numpy as np
import torch
@@ -22,7 +22,7 @@ class KokoroV1(BaseModelBackend):
# Strictly respect settings.use_gpu
self._device = "cuda" if settings.use_gpu else "cpu"
self._model: Optional[KModel] = None
self._pipeline: Optional[KPipeline] = None
self._pipelines: Dict[str, KPipeline] = {} # Store pipelines by lang_code
async def load_model(self, path: str) -> None:
"""Load pre-baked model.
@@ -54,22 +54,38 @@ class KokoroV1(BaseModelBackend):
if self._device == "cuda":
self._model = self._model.cuda()
# Initialize pipeline with our model and device
self._pipeline = KPipeline(
lang_code='a',
model=self._model, # Pass our model directly
device=self._device # Match our device setting
)
except FileNotFoundError as e:
raise e
except Exception as e:
raise RuntimeError(f"Failed to load Kokoro model: {e}")
def _get_pipeline(self, lang_code: str) -> KPipeline:
"""Get or create pipeline for language code.
Args:
lang_code: Language code to use
Returns:
KPipeline instance for the language
"""
if not self._model:
raise RuntimeError("Model not loaded")
if lang_code not in self._pipelines:
logger.info(f"Creating new pipeline for language code: {lang_code}")
self._pipelines[lang_code] = KPipeline(
lang_code=lang_code,
model=self._model,
device=self._device
)
return self._pipelines[lang_code]
async def generate_from_tokens(
self,
tokens: str,
voice: Union[str, Tuple[str, Union[torch.Tensor, str]]],
speed: float = 1.0
speed: float = 1.0,
lang_code: Optional[str] = None
) -> AsyncGenerator[np.ndarray, None]:
"""Generate audio from phoneme tokens.
@@ -77,6 +93,7 @@ class KokoroV1(BaseModelBackend):
tokens: Input phoneme tokens to synthesize
voice: Either a voice path string or a tuple of (voice_name, voice_tensor/path)
speed: Speed multiplier
lang_code: Optional language code override
Yields:
Generated audio chunks
@@ -95,6 +112,7 @@ class KokoroV1(BaseModelBackend):
# Handle voice input
voice_path: str
voice_name: str
if isinstance(voice, tuple):
voice_name, voice_data = voice
if isinstance(voice_data, str):
@@ -108,6 +126,7 @@ class KokoroV1(BaseModelBackend):
torch.save(voice_data.cpu(), voice_path)
else:
voice_path = voice
voice_name = os.path.splitext(os.path.basename(voice_path))[0]
# Load voice tensor with proper device mapping
voice_tensor = await paths.load_voice_tensor(voice_path, device=self._device)
@@ -118,9 +137,12 @@ class KokoroV1(BaseModelBackend):
await paths.save_voice_tensor(voice_tensor, temp_path)
voice_path = temp_path
# Generate using pipeline's generate_from_tokens method
logger.debug(f"Generating audio from tokens: '{tokens[:100]}...'")
for result in self._pipeline.generate_from_tokens(
# Use provided lang_code or get from voice name
pipeline_lang_code = lang_code if lang_code else voice_name[0].lower()
pipeline = self._get_pipeline(pipeline_lang_code)
logger.debug(f"Generating audio from tokens with lang_code '{pipeline_lang_code}': '{tokens[:100]}...'")
for result in pipeline.generate_from_tokens(
tokens=tokens,
voice=voice_path,
speed=speed,
@@ -140,7 +162,7 @@ class KokoroV1(BaseModelBackend):
and "out of memory" in str(e).lower()
):
self._clear_memory()
async for chunk in self.generate_from_tokens(tokens, voice, speed):
async for chunk in self.generate_from_tokens(tokens, voice, speed, lang_code):
yield chunk
raise
@@ -148,7 +170,8 @@ class KokoroV1(BaseModelBackend):
self,
text: str,
voice: Union[str, Tuple[str, Union[torch.Tensor, str]]],
speed: float = 1.0
speed: float = 1.0,
lang_code: Optional[str] = None
) -> AsyncGenerator[np.ndarray, None]:
"""Generate audio using model.
@@ -156,6 +179,7 @@ class KokoroV1(BaseModelBackend):
text: Input text to synthesize
voice: Either a voice path string or a tuple of (voice_name, voice_tensor/path)
speed: Speed multiplier
lang_code: Optional language code override
Yields:
Generated audio chunks
@@ -174,6 +198,7 @@ class KokoroV1(BaseModelBackend):
# Handle voice input
voice_path: str
voice_name: str
if isinstance(voice, tuple):
voice_name, voice_data = voice
if isinstance(voice_data, str):
@@ -187,6 +212,7 @@ class KokoroV1(BaseModelBackend):
torch.save(voice_data.cpu(), voice_path)
else:
voice_path = voice
voice_name = os.path.splitext(os.path.basename(voice_path))[0]
# Load voice tensor with proper device mapping
voice_tensor = await paths.load_voice_tensor(voice_path, device=self._device)
@@ -197,9 +223,17 @@ class KokoroV1(BaseModelBackend):
await paths.save_voice_tensor(voice_tensor, temp_path)
voice_path = temp_path
# Generate using pipeline, force model to prevent downloads
logger.debug(f"Generating audio for text: '{text[:100]}...'")
for result in self._pipeline(text, voice=voice_path, speed=speed, model=self._model):
# Use provided lang_code or get from voice name
pipeline_lang_code = lang_code if lang_code else voice_name[0].lower()
pipeline = self._get_pipeline(pipeline_lang_code)
logger.debug(f"Generating audio for text with lang_code '{pipeline_lang_code}': '{text[:100]}...'")
for result in pipeline(
text,
voice=voice_path,
speed=speed,
model=self._model
):
if result.audio is not None:
logger.debug(f"Got audio chunk with shape: {result.audio.shape}")
yield result.audio.numpy()
@@ -214,7 +248,7 @@ class KokoroV1(BaseModelBackend):
and "out of memory" in str(e).lower()
):
self._clear_memory()
async for chunk in self.generate(text, voice, speed):
async for chunk in self.generate(text, voice, speed, lang_code):
yield chunk
raise
@@ -236,9 +270,9 @@ class KokoroV1(BaseModelBackend):
if self._model is not None:
del self._model
self._model = None
if self._pipeline is not None:
del self._pipeline
self._pipeline = None
for pipeline in self._pipelines.values():
del pipeline
self._pipelines.clear()
if torch.cuda.is_available():
torch.cuda.empty_cache()
torch.cuda.synchronize()
@@ -246,7 +280,7 @@ class KokoroV1(BaseModelBackend):
@property
def is_loaded(self) -> bool:
"""Check if model is loaded."""
return self._model is not None and self._pipeline is not None
return self._model is not None
@property
def device(self) -> str:
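
Taken together, the backend now resolves one language per request and lazily caches a KPipeline per language code. A standalone sketch of the resolution rule used above (the voice names are assumed examples; the letter-to-language mapping matches the web player options added further down):

def resolve_lang_code(voice_name: str, lang_code: str | None = None) -> str:
    # An explicit override wins; otherwise the first letter of the voice
    # name selects the phoneme pipeline, e.g. 'a' -> English, 'e' -> Spanish.
    return lang_code if lang_code else voice_name[0].lower()

assert resolve_lang_code("af_heart") == "a"       # English voice (assumed name)
assert resolve_lang_code("ef_dora") == "e"        # Spanish voice (assumed name)
assert resolve_lang_code("af_heart", "j") == "j"  # forced Japanese phonemes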

View file

@@ -72,7 +72,10 @@ class ModelManager:
# Warm up with short text
warmup_text = "Warmup text for initialization."
async for _ in self.generate(warmup_text, voice_path):
# Use default voice name for warmup
voice_name = settings.default_voice
logger.debug(f"Using default voice '{voice_name}' for warmup")
async for _ in self.generate(warmup_text, (voice_name, voice_path)):
pass
except Exception as e:
raise RuntimeError(f"Failed to get default voice: {e}")

View file

@@ -28,6 +28,7 @@ def setup_logger():
"sink": sys.stdout,
"format": "<fg #2E8B57>{time:hh:mm:ss A}</fg #2E8B57> | "
"{level: <8} | "
"<fg #4169E1>{module}:{line}</fg #4169E1> | "
"{message}",
"colorize": True,
"level": "DEBUG",
@@ -88,6 +89,7 @@ async def lifespan(app: FastAPI):
# Add web player info if enabled
if settings.enable_web_player:
startup_msg += f"\n\nBeta Web Player: http://{settings.host}:{settings.port}/web/"
startup_msg += f"\nor http://localhost:{settings.port}/web/"
else:
startup_msg += "\n\nWeb Player: disabled"

View file

@@ -130,11 +130,13 @@ async def stream_audio_chunks(
voice_name = await process_voices(request.voice, tts_service)
try:
logger.info(f"Starting audio generation with lang_code: {request.lang_code}")
async for chunk in tts_service.generate_audio_stream(
text=request.input,
voice=voice_name,
speed=request.speed,
output_format=request.response_format,
lang_code=request.lang_code,
):
# Check if client is still connected
is_disconnected = client_request.is_disconnected
@@ -250,7 +252,8 @@ async def create_speech(
audio, _ = await tts_service.generate_audio(
text=request.input,
voice=voice_name,
speed=request.speed
speed=request.speed,
lang_code=request.lang_code
)
# Convert to requested format with proper finalization
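
Since `lang_code` is a plain extra field on the request body, stock OpenAI clients can pass it through the SDK's `extra_body` escape hatch. A sketch under the same assumptions as above (base URL, API key, model, and voice name are illustrative):

from openai import OpenAI

client = OpenAI(base_url="http://localhost:8880/v1", api_key="not-needed")
with client.audio.speech.with_streaming_response.create(
    model="kokoro",                 # assumed model name
    voice="ef_dora",                # assumed Spanish voice
    input="Hola mundo",
    extra_body={"lang_code": "e"},  # the override added by this commit
) as response:
    response.stream_to_file("hola.mp3")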

View file

@@ -51,6 +51,7 @@ class TTSService:
is_first: bool = False,
is_last: bool = False,
normalizer: Optional[AudioNormalizer] = None,
lang_code: Optional[str] = None,
) -> AsyncGenerator[Union[np.ndarray, bytes], None]:
"""Process tokens into audio."""
async with self._chunk_semaphore:
@@ -82,11 +83,12 @@ class TTSService:
# Generate audio using pre-warmed model
if isinstance(backend, KokoroV1):
# For Kokoro V1, pass text and voice info
# For Kokoro V1, pass text and voice info with lang_code
async for chunk_audio in self.model_manager.generate(
chunk_text,
(voice_name, voice_path),
speed=speed
speed=speed,
lang_code=lang_code
):
# For streaming, convert to bytes
if output_format:
@@ -217,6 +219,7 @@ class TTSService:
voice: str,
speed: float = 1.0,
output_format: str = "wav",
lang_code: Optional[str] = None,
) -> AsyncGenerator[bytes, None]:
"""Generate and stream audio chunks."""
stream_normalizer = AudioNormalizer()
@@ -230,6 +233,10 @@ class TTSService:
voice_name, voice_path = await self._get_voice_path(voice)
logger.debug(f"Using voice path: {voice_path}")
# Use provided lang_code or determine from voice name
pipeline_lang_code = lang_code if lang_code else voice[:1].lower()
logger.info(f"Using lang_code '{pipeline_lang_code}' for voice '{voice_name}' in audio stream")
# Process text in chunks with smart splitting
async for chunk_text, tokens in smart_split(text):
try:
@@ -243,7 +250,8 @@ class TTSService:
output_format,
is_first=(chunk_index == 0),
is_last=False, # We'll update the last chunk later
normalizer=stream_normalizer
normalizer=stream_normalizer,
lang_code=pipeline_lang_code # Pass lang_code
):
if result is not None:
yield result
@@ -268,7 +276,8 @@ class TTSService:
output_format,
is_first=False,
is_last=True, # Signal this is the last chunk
normalizer=stream_normalizer
normalizer=stream_normalizer,
lang_code=pipeline_lang_code # Pass lang_code
):
if result is not None:
yield result
@@ -280,7 +289,8 @@ class TTSService:
raise
async def generate_audio(
self, text: str, voice: str, speed: float = 1.0, return_timestamps: bool = False
self, text: str, voice: str, speed: float = 1.0, return_timestamps: bool = False,
lang_code: Optional[str] = None
) -> Union[Tuple[np.ndarray, float], Tuple[np.ndarray, float, List[dict]]]:
"""Generate complete audio for text using streaming internally."""
start_time = time.time()
@@ -293,8 +303,12 @@ class TTSService:
voice_name, voice_path = await self._get_voice_path(voice)
if isinstance(backend, KokoroV1):
# Use provided lang_code or determine from voice name
pipeline_lang_code = lang_code if lang_code else voice[:1].lower()
logger.info(f"Using lang_code '{pipeline_lang_code}' for voice '{voice_name}' in text chunking")
# Initialize quiet pipeline for text chunking
quiet_pipeline = KPipeline(lang_code='a', model=False)
quiet_pipeline = KPipeline(lang_code=pipeline_lang_code, model=False)
# Split text into chunks and get initial tokens
text_chunks = []
@@ -310,12 +324,13 @@ class TTSService:
for chunk_idx, (chunk_text, chunk_phonemes) in enumerate(text_chunks):
logger.debug(f"Processing chunk {chunk_idx + 1}/{len(text_chunks)}: '{chunk_text[:50]}...'")
# Generate audio and timestamps for this chunk
for result in backend._pipeline(
# Create a new pipeline with the lang_code
generation_pipeline = KPipeline(lang_code=pipeline_lang_code, model=backend._model)
logger.info(f"Using lang_code '{pipeline_lang_code}' for voice '{voice_name}' in generation pipeline")
for result in generation_pipeline(
chunk_text,
voice=voice_path,
speed=speed,
model=backend._model
speed=speed
):
# Collect audio chunks
if result.audio is not None:
@@ -470,7 +485,8 @@ class TTSService:
self,
phonemes: str,
voice: str,
speed: float = 1.0
speed: float = 1.0,
lang_code: Optional[str] = None
) -> Tuple[np.ndarray, float]:
"""Generate audio directly from phonemes.
@@ -478,6 +494,7 @@ class TTSService:
phonemes: Phonemes in Kokoro format
voice: Voice name
speed: Speed multiplier
lang_code: Optional language code override
Returns:
Tuple of (audio array, processing time)
@@ -491,11 +508,16 @@ class TTSService:
if isinstance(backend, KokoroV1):
# For Kokoro V1, use generate_from_tokens with raw phonemes
result = None
for r in backend._pipeline.generate_from_tokens(
# Use provided lang_code or determine from voice name
pipeline_lang_code = lang_code if lang_code else voice[:1].lower()
logger.info(f"Using lang_code '{pipeline_lang_code}' for voice '{voice_name}' in phoneme pipeline")
# Create a new pipeline with the lang_code
phoneme_pipeline = KPipeline(lang_code=pipeline_lang_code, model=backend._model)
for r in phoneme_pipeline.generate_from_tokens(
tokens=phonemes, # Pass raw phonemes string
voice=voice_path,
speed=speed,
model=backend._model
speed=speed
):
if r.audio is not None:
result = r
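
A sketch of driving the new phoneme path with the override; the service setup is elided, and the phoneme string and voice name are assumed examples:

# Assumes an initialized TTSService in an async context.
audio, processing_time = await tts_service.generate_from_phonemes(
    phonemes="hˈɛloʊ wˈɜɹld",  # Kokoro-format phonemes (assumed example)
    voice="af_heart",          # assumed voice name
    speed=1.0,
    lang_code="a",             # optional override, new in this commit
)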

View file

@@ -1,5 +1,5 @@
from enum import Enum
from typing import List, Literal, Union
from typing import List, Literal, Union, Optional
from pydantic import BaseModel, Field
@@ -62,6 +62,10 @@ class OpenAISpeechRequest(BaseModel):
default=False,
description="If true, returns a download link in X-Download-Path header after streaming completes",
)
lang_code: Optional[str] = Field(
default=None,
description="Optional language code to use for text processing. If not provided, will use first letter of voice name.",
)
class CaptionedSpeechRequest(BaseModel):
"""Request schema for captioned speech endpoint"""
@@ -88,3 +92,7 @@ class CaptionedSpeechRequest(BaseModel):
default=True,
description="If true (default), returns word-level timestamps in the response",
)
lang_code: Optional[str] = Field(
default=None,
description="Optional language code to use for text processing. If not provided, will use first letter of voice name.",
)

View file

@@ -1,5 +1,5 @@
import pytest
from unittest.mock import patch, MagicMock
from unittest.mock import patch, MagicMock, ANY
import torch
import numpy as np
from api.src.inference.kokoro_v1 import KokoroV1
@@ -13,7 +13,7 @@ def test_initial_state(kokoro_backend):
"""Test initial state of KokoroV1."""
assert not kokoro_backend.is_loaded
assert kokoro_backend._model is None
assert kokoro_backend._pipeline is None
assert kokoro_backend._pipelines == {} # Now using dict of pipelines
# Device should be set based on settings
assert kokoro_backend.device in ["cuda", "cpu"]
@@ -47,18 +47,23 @@ async def test_load_model_validation(kokoro_backend):
with pytest.raises(RuntimeError, match="Failed to load Kokoro model"):
await kokoro_backend.load_model("nonexistent_model.pth")
def test_unload(kokoro_backend):
"""Test model unloading."""
# Mock loaded state
def test_unload_with_pipelines(kokoro_backend):
"""Test model unloading with multiple pipelines."""
# Mock loaded state with multiple pipelines
kokoro_backend._model = MagicMock()
kokoro_backend._pipeline = MagicMock()
pipeline_a = MagicMock()
pipeline_e = MagicMock()
kokoro_backend._pipelines = {
'a': pipeline_a,
'e': pipeline_e
}
assert kokoro_backend.is_loaded
# Test unload
kokoro_backend.unload()
assert not kokoro_backend.is_loaded
assert kokoro_backend._model is None
assert kokoro_backend._pipeline is None
assert kokoro_backend._pipelines == {} # All pipelines should be cleared
@pytest.mark.asyncio
async def test_generate_validation(kokoro_backend):
@@ -72,4 +77,78 @@ async def test_generate_from_tokens_validation(kokoro_backend):
"""Test token generation validation."""
with pytest.raises(RuntimeError, match="Model not loaded"):
async for _ in kokoro_backend.generate_from_tokens("test tokens", "voice"):
pass
pass
def test_get_pipeline_creates_new(kokoro_backend):
"""Test that _get_pipeline creates new pipeline for new language code."""
# Mock loaded state
kokoro_backend._model = MagicMock()
# Mock KPipeline
mock_pipeline = MagicMock()
with patch('api.src.inference.kokoro_v1.KPipeline', return_value=mock_pipeline) as mock_kpipeline:
# Get pipeline for Spanish
pipeline_e = kokoro_backend._get_pipeline('e')
# Should create new pipeline with correct params
mock_kpipeline.assert_called_once_with(
lang_code='e',
model=kokoro_backend._model,
device=kokoro_backend._device
)
assert pipeline_e == mock_pipeline
assert kokoro_backend._pipelines['e'] == mock_pipeline
def test_get_pipeline_reuses_existing(kokoro_backend):
"""Test that _get_pipeline reuses existing pipeline for same language code."""
# Mock loaded state
kokoro_backend._model = MagicMock()
# Mock KPipeline
mock_pipeline = MagicMock()
with patch('api.src.inference.kokoro_v1.KPipeline', return_value=mock_pipeline) as mock_kpipeline:
# Get pipeline twice for same language
pipeline1 = kokoro_backend._get_pipeline('e')
pipeline2 = kokoro_backend._get_pipeline('e')
# Should only create pipeline once
mock_kpipeline.assert_called_once()
assert pipeline1 == pipeline2
assert kokoro_backend._pipelines['e'] == mock_pipeline
@pytest.mark.asyncio
async def test_generate_uses_correct_pipeline(kokoro_backend):
"""Test that generate uses correct pipeline for language code."""
# Mock loaded state
kokoro_backend._model = MagicMock()
# Mock voice path handling
with patch('api.src.core.paths.load_voice_tensor') as mock_load_voice, \
patch('api.src.core.paths.save_voice_tensor'), \
patch('tempfile.gettempdir') as mock_tempdir:
mock_load_voice.return_value = torch.ones(1)
mock_tempdir.return_value = "/tmp"
# Mock KPipeline
mock_pipeline = MagicMock()
mock_pipeline.return_value = iter([]) # Empty generator for testing
with patch('api.src.inference.kokoro_v1.KPipeline', return_value=mock_pipeline):
# Generate with Spanish voice and explicit lang_code
async for _ in kokoro_backend.generate(
"test", "ef_voice", lang_code='e'
):
pass
# Should create pipeline with Spanish lang_code
assert 'e' in kokoro_backend._pipelines
# Use ANY to match the temp file path since it's dynamic
mock_pipeline.assert_called_with(
"test",
voice=ANY, # Don't check exact path since it's dynamic
speed=1.0,
model=kokoro_backend._model
)
# Verify the voice path is a temp file path
call_args = mock_pipeline.call_args
assert isinstance(call_args[1]['voice'], str)
assert call_args[1]['voice'].startswith("/tmp/temp_voice_")

View file

@@ -14,7 +14,7 @@ RUN apt-get update && apt-get install -y \
ffmpeg \
&& apt-get clean && rm -rf /var/lib/apt/lists/* \
&& mkdir -p /usr/share/espeak-ng-data \
&& ln -s /usr/lib/x86_64-linux-gnu/espeak-ng-data/* /usr/share/espeak-ng-data/
&& ln -s /usr/lib/*/espeak-ng-data/* /usr/share/espeak-ng-data/
# Install UV using the installer script
RUN curl -LsSf https://astral.sh/uv/install.sh | sh && \

View file

@@ -94,6 +94,19 @@
<label for="speed-slider">Speed: <span id="speed-value">1.0</span>x</label>
<input type="range" id="speed-slider" min="0.1" max="4" step="0.1" value="1.0">
</div>
<div class="lang-control">
<label for="lang-select">Language:</label>
<select id="lang-select" class="lang-select">
<option value="">Auto</option>
<option value="e">Spanish</option>
<option value="a">English</option>
<option value="f">French</option>
<option value="i">Italian</option>
<option value="p">Portuguese</option>
<option value="j">Japanese</option>
<option value="z">Chinese</option>
</select>
</div>
</div>
<div class="button-group">
<button id="generate-btn">

View file

@@ -42,7 +42,8 @@ export class AudioService {
response_format: 'mp3',
stream: true,
speed: speed,
return_download_link: true
return_download_link: true,
lang_code: document.getElementById('lang-select').value || undefined
}),
signal: this.controller.signal
});

View file

@@ -252,6 +252,49 @@
transform: scale(1.1);
}
/* Language Control */
.lang-control {
display: flex;
flex-direction: column;
gap: 0.75rem;
padding: 0.75rem;
background: rgba(15, 23, 42, 0.3);
border: 1px solid var(--border);
border-radius: 0.5rem;
}
.lang-control label {
color: var(--text-light);
font-size: 0.875rem;
}
.lang-select {
background: rgba(15, 23, 42, 0.3);
color: var(--text);
border: 1px solid var(--border);
border-radius: 0.5rem;
padding: 0.5rem 1rem;
font-family: var(--font-family);
font-size: 0.875rem;
cursor: pointer;
transition: all 0.2s ease;
}
.lang-select:hover {
border-color: var(--fg-color);
}
.lang-select:focus {
outline: none;
border-color: var(--fg-color);
box-shadow: 0 0 0 3px rgba(99, 102, 241, 0.2);
}
.lang-select option {
background: var(--surface);
color: var(--text);
}
/* Generation Controls */
.button-group {
display: flex;