- Fix voice selection not matching language phonemes
- Add voice language override parameter
remsky 2025-02-08 01:29:15 -07:00
parent 68cc14896a
commit a0dc870f4a
12 changed files with 259 additions and 51 deletions

View file

@@ -3,8 +3,8 @@
 </p>
 # <sub><sub>_`FastKoko`_ </sub></sub>
-[![Tests](https://img.shields.io/badge/tests-63%20passed-darkgreen)]()
-[![Coverage](https://img.shields.io/badge/coverage-53%25-tan)]()
+[![Tests](https://img.shields.io/badge/tests-66%20passed-darkgreen)]()
+[![Coverage](https://img.shields.io/badge/coverage-54%25-tan)]()
 [![Try on Spaces](https://img.shields.io/badge/%F0%9F%A4%97%20Try%20on-Spaces-blue)](https://huggingface.co/spaces/Remsky/Kokoro-TTS-Zero)
 [![Tested at Model Commit](https://img.shields.io/badge/last--tested--model--commit-1.0::9901c2b-blue)](https://huggingface.co/hexgrad/Kokoro-82M/commit/9901c2b79161b6e898b7ea857ae5298f47b8b0d6)

View file

@@ -1,7 +1,7 @@
 """Clean Kokoro implementation with controlled resource management."""
 import os
-from typing import AsyncGenerator, Optional, Union, Tuple
+from typing import AsyncGenerator, Optional, Union, Tuple, Dict
 import numpy as np
 import torch
@@ -22,7 +22,7 @@ class KokoroV1(BaseModelBackend):
         # Strictly respect settings.use_gpu
         self._device = "cuda" if settings.use_gpu else "cpu"
         self._model: Optional[KModel] = None
-        self._pipeline: Optional[KPipeline] = None
+        self._pipelines: Dict[str, KPipeline] = {}  # Store pipelines by lang_code

     async def load_model(self, path: str) -> None:
         """Load pre-baked model.
@@ -54,22 +54,38 @@ class KokoroV1(BaseModelBackend):
             if self._device == "cuda":
                 self._model = self._model.cuda()
-            # Initialize pipeline with our model and device
-            self._pipeline = KPipeline(
-                lang_code='a',
-                model=self._model,  # Pass our model directly
-                device=self._device  # Match our device setting
-            )
         except FileNotFoundError as e:
             raise e
         except Exception as e:
             raise RuntimeError(f"Failed to load Kokoro model: {e}")

+    def _get_pipeline(self, lang_code: str) -> KPipeline:
+        """Get or create pipeline for language code.
+
+        Args:
+            lang_code: Language code to use
+
+        Returns:
+            KPipeline instance for the language
+        """
+        if not self._model:
+            raise RuntimeError("Model not loaded")
+        if lang_code not in self._pipelines:
+            logger.info(f"Creating new pipeline for language code: {lang_code}")
+            self._pipelines[lang_code] = KPipeline(
+                lang_code=lang_code,
+                model=self._model,
+                device=self._device
+            )
+        return self._pipelines[lang_code]
+
     async def generate_from_tokens(
         self,
         tokens: str,
         voice: Union[str, Tuple[str, Union[torch.Tensor, str]]],
-        speed: float = 1.0
+        speed: float = 1.0,
+        lang_code: Optional[str] = None
     ) -> AsyncGenerator[np.ndarray, None]:
         """Generate audio from phoneme tokens.
@@ -77,6 +93,7 @@ class KokoroV1(BaseModelBackend):
             tokens: Input phoneme tokens to synthesize
             voice: Either a voice path string or a tuple of (voice_name, voice_tensor/path)
             speed: Speed multiplier
+            lang_code: Optional language code override

         Yields:
             Generated audio chunks
@@ -95,6 +112,7 @@ class KokoroV1(BaseModelBackend):
         # Handle voice input
         voice_path: str
+        voice_name: str
         if isinstance(voice, tuple):
             voice_name, voice_data = voice
             if isinstance(voice_data, str):
@@ -108,6 +126,7 @@ class KokoroV1(BaseModelBackend):
                 torch.save(voice_data.cpu(), voice_path)
         else:
             voice_path = voice
+            voice_name = os.path.splitext(os.path.basename(voice_path))[0]

         # Load voice tensor with proper device mapping
         voice_tensor = await paths.load_voice_tensor(voice_path, device=self._device)
@@ -118,9 +137,12 @@ class KokoroV1(BaseModelBackend):
             await paths.save_voice_tensor(voice_tensor, temp_path)
             voice_path = temp_path

-            # Generate using pipeline's generate_from_tokens method
-            logger.debug(f"Generating audio from tokens: '{tokens[:100]}...'")
-            for result in self._pipeline.generate_from_tokens(
+            # Use provided lang_code or get from voice name
+            pipeline_lang_code = lang_code if lang_code else voice_name[0].lower()
+            pipeline = self._get_pipeline(pipeline_lang_code)
+            logger.debug(f"Generating audio from tokens with lang_code '{pipeline_lang_code}': '{tokens[:100]}...'")
+            for result in pipeline.generate_from_tokens(
                 tokens=tokens,
                 voice=voice_path,
                 speed=speed,
@@ -140,7 +162,7 @@ class KokoroV1(BaseModelBackend):
                 and "out of memory" in str(e).lower()
             ):
                 self._clear_memory()
-                async for chunk in self.generate_from_tokens(tokens, voice, speed):
+                async for chunk in self.generate_from_tokens(tokens, voice, speed, lang_code):
                     yield chunk
             raise
@@ -148,7 +170,8 @@ class KokoroV1(BaseModelBackend):
         self,
         text: str,
         voice: Union[str, Tuple[str, Union[torch.Tensor, str]]],
-        speed: float = 1.0
+        speed: float = 1.0,
+        lang_code: Optional[str] = None
     ) -> AsyncGenerator[np.ndarray, None]:
         """Generate audio using model.
@@ -156,6 +179,7 @@ class KokoroV1(BaseModelBackend):
             text: Input text to synthesize
             voice: Either a voice path string or a tuple of (voice_name, voice_tensor/path)
             speed: Speed multiplier
+            lang_code: Optional language code override

         Yields:
             Generated audio chunks
@@ -174,6 +198,7 @@ class KokoroV1(BaseModelBackend):
         # Handle voice input
         voice_path: str
+        voice_name: str
         if isinstance(voice, tuple):
             voice_name, voice_data = voice
             if isinstance(voice_data, str):
@@ -187,6 +212,7 @@ class KokoroV1(BaseModelBackend):
                 torch.save(voice_data.cpu(), voice_path)
         else:
             voice_path = voice
+            voice_name = os.path.splitext(os.path.basename(voice_path))[0]

         # Load voice tensor with proper device mapping
         voice_tensor = await paths.load_voice_tensor(voice_path, device=self._device)
@@ -197,9 +223,17 @@ class KokoroV1(BaseModelBackend):
             await paths.save_voice_tensor(voice_tensor, temp_path)
             voice_path = temp_path

-            # Generate using pipeline, force model to prevent downloads
-            logger.debug(f"Generating audio for text: '{text[:100]}...'")
-            for result in self._pipeline(text, voice=voice_path, speed=speed, model=self._model):
+            # Use provided lang_code or get from voice name
+            pipeline_lang_code = lang_code if lang_code else voice_name[0].lower()
+            pipeline = self._get_pipeline(pipeline_lang_code)
+            logger.debug(f"Generating audio for text with lang_code '{pipeline_lang_code}': '{text[:100]}...'")
+            for result in pipeline(
+                text,
+                voice=voice_path,
+                speed=speed,
+                model=self._model
+            ):
                 if result.audio is not None:
                     logger.debug(f"Got audio chunk with shape: {result.audio.shape}")
                     yield result.audio.numpy()
@@ -214,7 +248,7 @@ class KokoroV1(BaseModelBackend):
                 and "out of memory" in str(e).lower()
             ):
                 self._clear_memory()
-                async for chunk in self.generate(text, voice, speed):
+                async for chunk in self.generate(text, voice, speed, lang_code):
                     yield chunk
             raise
@@ -236,9 +270,9 @@ class KokoroV1(BaseModelBackend):
         if self._model is not None:
             del self._model
             self._model = None
-        if self._pipeline is not None:
-            del self._pipeline
-            self._pipeline = None
+        for pipeline in self._pipelines.values():
+            del pipeline
+        self._pipelines.clear()
         if torch.cuda.is_available():
             torch.cuda.empty_cache()
             torch.cuda.synchronize()
@@ -246,7 +280,7 @@ class KokoroV1(BaseModelBackend):
     @property
     def is_loaded(self) -> bool:
         """Check if model is loaded."""
-        return self._model is not None and self._pipeline is not None
+        return self._model is not None

     @property
     def device(self) -> str:
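The net effect of the backend change: when no lang_code is supplied, the first letter of the voice name selects the phonemizer. A minimal sketch of that resolution order, assuming standard Kokoro voice names like af_heart and ef_dora (the resolve_lang_code helper is hypothetical, written only to illustrate the fallback in the diff):

```python
from typing import Optional

def resolve_lang_code(voice_name: str, lang_code: Optional[str] = None) -> str:
    # Explicit override wins; otherwise the voice-name prefix decides,
    # exactly as in generate() and generate_from_tokens() above.
    return lang_code if lang_code else voice_name[0].lower()

# "ef_dora" is a Spanish-prefixed voice, so it selects the 'e' pipeline:
assert resolve_lang_code("ef_dora") == "e"
# An explicit override beats the voice-name heuristic:
assert resolve_lang_code("af_heart", lang_code="e") == "e"
```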

View file

@@ -72,7 +72,10 @@ class ModelManager:
             # Warm up with short text
             warmup_text = "Warmup text for initialization."
-            async for _ in self.generate(warmup_text, voice_path):
+            # Use default voice name for warmup
+            voice_name = settings.default_voice
+            logger.debug(f"Using default voice '{voice_name}' for warmup")
+            async for _ in self.generate(warmup_text, (voice_name, voice_path)):
                 pass
         except Exception as e:
             raise RuntimeError(f"Failed to get default voice: {e}")

View file

@@ -28,6 +28,7 @@ def setup_logger():
         "sink": sys.stdout,
         "format": "<fg #2E8B57>{time:hh:mm:ss A}</fg #2E8B57> | "
         "{level: <8} | "
+        "<fg #4169E1>{module}:{line}</fg #4169E1> | "
         "{message}",
         "colorize": True,
         "level": "DEBUG",
@@ -88,6 +89,7 @@ async def lifespan(app: FastAPI):
     # Add web player info if enabled
     if settings.enable_web_player:
         startup_msg += f"\n\nBeta Web Player: http://{settings.host}:{settings.port}/web/"
+        startup_msg += f"\nor http://localhost:{settings.port}/web/"
     else:
         startup_msg += "\n\nWeb Player: disabled"

View file

@@ -130,11 +130,13 @@ async def stream_audio_chunks(
     voice_name = await process_voices(request.voice, tts_service)

     try:
+        logger.info(f"Starting audio generation with lang_code: {request.lang_code}")
         async for chunk in tts_service.generate_audio_stream(
             text=request.input,
             voice=voice_name,
             speed=request.speed,
             output_format=request.response_format,
+            lang_code=request.lang_code,
         ):
             # Check if client is still connected
             is_disconnected = client_request.is_disconnected
@@ -250,7 +252,8 @@ async def create_speech(
         audio, _ = await tts_service.generate_audio(
             text=request.input,
             voice=voice_name,
-            speed=request.speed
+            speed=request.speed,
+            lang_code=request.lang_code
         )

         # Convert to requested format with proper finalization
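With the router now forwarding request.lang_code, clients can override the voice-name heuristic per request. A hedged usage sketch with Python's requests library; the /v1/audio/speech path and port 8880 are assumptions based on the project's OpenAI-compatible defaults:

```python
import requests

# Assumed default deployment: OpenAI-compatible speech route on port 8880.
response = requests.post(
    "http://localhost:8880/v1/audio/speech",
    json={
        "model": "kokoro",
        "input": "Hola, ¿qué tal?",
        "voice": "af_heart",       # English-prefixed voice...
        "response_format": "mp3",
        "lang_code": "e",          # ...but force Spanish phonemization
    },
)
response.raise_for_status()
with open("output.mp3", "wb") as f:
    f.write(response.content)
```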

View file

@@ -51,6 +51,7 @@ class TTSService:
         is_first: bool = False,
         is_last: bool = False,
         normalizer: Optional[AudioNormalizer] = None,
+        lang_code: Optional[str] = None,
     ) -> AsyncGenerator[Union[np.ndarray, bytes], None]:
         """Process tokens into audio."""
         async with self._chunk_semaphore:
@@ -82,11 +83,12 @@ class TTSService:
                 # Generate audio using pre-warmed model
                 if isinstance(backend, KokoroV1):
-                    # For Kokoro V1, pass text and voice info
+                    # For Kokoro V1, pass text and voice info with lang_code
                     async for chunk_audio in self.model_manager.generate(
                         chunk_text,
                         (voice_name, voice_path),
-                        speed=speed
+                        speed=speed,
+                        lang_code=lang_code
                     ):
                         # For streaming, convert to bytes
                         if output_format:
@@ -217,6 +219,7 @@ class TTSService:
         voice: str,
         speed: float = 1.0,
         output_format: str = "wav",
+        lang_code: Optional[str] = None,
     ) -> AsyncGenerator[bytes, None]:
         """Generate and stream audio chunks."""
         stream_normalizer = AudioNormalizer()
@@ -230,6 +233,10 @@ class TTSService:
             voice_name, voice_path = await self._get_voice_path(voice)
             logger.debug(f"Using voice path: {voice_path}")

+            # Use provided lang_code or determine from voice name
+            pipeline_lang_code = lang_code if lang_code else voice[:1].lower()
+            logger.info(f"Using lang_code '{pipeline_lang_code}' for voice '{voice_name}' in audio stream")
+
             # Process text in chunks with smart splitting
             async for chunk_text, tokens in smart_split(text):
                 try:
@@ -243,7 +250,8 @@ class TTSService:
                         output_format,
                         is_first=(chunk_index == 0),
                         is_last=False,  # We'll update the last chunk later
-                        normalizer=stream_normalizer
+                        normalizer=stream_normalizer,
+                        lang_code=pipeline_lang_code  # Pass lang_code
                     ):
                         if result is not None:
                             yield result
@@ -268,7 +276,8 @@ class TTSService:
                     output_format,
                     is_first=False,
                     is_last=True,  # Signal this is the last chunk
-                    normalizer=stream_normalizer
+                    normalizer=stream_normalizer,
+                    lang_code=pipeline_lang_code  # Pass lang_code
                 ):
                     if result is not None:
                         yield result
@@ -280,7 +289,8 @@ class TTSService:
             raise

     async def generate_audio(
-        self, text: str, voice: str, speed: float = 1.0, return_timestamps: bool = False
+        self, text: str, voice: str, speed: float = 1.0, return_timestamps: bool = False,
+        lang_code: Optional[str] = None
     ) -> Union[Tuple[np.ndarray, float], Tuple[np.ndarray, float, List[dict]]]:
         """Generate complete audio for text using streaming internally."""
         start_time = time.time()
@@ -293,8 +303,12 @@ class TTSService:
             voice_name, voice_path = await self._get_voice_path(voice)

             if isinstance(backend, KokoroV1):
+                # Use provided lang_code or determine from voice name
+                pipeline_lang_code = lang_code if lang_code else voice[:1].lower()
+                logger.info(f"Using lang_code '{pipeline_lang_code}' for voice '{voice_name}' in text chunking")
+
                 # Initialize quiet pipeline for text chunking
-                quiet_pipeline = KPipeline(lang_code='a', model=False)
+                quiet_pipeline = KPipeline(lang_code=pipeline_lang_code, model=False)

                 # Split text into chunks and get initial tokens
                 text_chunks = []
@@ -310,12 +324,13 @@ class TTSService:
                 for chunk_idx, (chunk_text, chunk_phonemes) in enumerate(text_chunks):
                     logger.debug(f"Processing chunk {chunk_idx + 1}/{len(text_chunks)}: '{chunk_text[:50]}...'")

-                    # Generate audio and timestamps for this chunk
-                    for result in backend._pipeline(
+                    # Create a new pipeline with the lang_code
+                    generation_pipeline = KPipeline(lang_code=pipeline_lang_code, model=backend._model)
+                    logger.info(f"Using lang_code '{pipeline_lang_code}' for voice '{voice_name}' in generation pipeline")
+                    for result in generation_pipeline(
                         chunk_text,
                         voice=voice_path,
-                        speed=speed,
-                        model=backend._model
+                        speed=speed
                     ):
                         # Collect audio chunks
                         if result.audio is not None:
@@ -470,7 +485,8 @@ class TTSService:
         self,
         phonemes: str,
         voice: str,
-        speed: float = 1.0
+        speed: float = 1.0,
+        lang_code: Optional[str] = None
     ) -> Tuple[np.ndarray, float]:
         """Generate audio directly from phonemes.
@@ -478,6 +494,7 @@ class TTSService:
             phonemes: Phonemes in Kokoro format
             voice: Voice name
             speed: Speed multiplier
+            lang_code: Optional language code override

         Returns:
             Tuple of (audio array, processing time)
@@ -491,11 +508,16 @@ class TTSService:
             if isinstance(backend, KokoroV1):
                 # For Kokoro V1, use generate_from_tokens with raw phonemes
                 result = None
-                for r in backend._pipeline.generate_from_tokens(
+                # Use provided lang_code or determine from voice name
+                pipeline_lang_code = lang_code if lang_code else voice[:1].lower()
+                logger.info(f"Using lang_code '{pipeline_lang_code}' for voice '{voice_name}' in phoneme pipeline")
+                # Create a new pipeline with the lang_code
+                phoneme_pipeline = KPipeline(lang_code=pipeline_lang_code, model=backend._model)
+                for r in phoneme_pipeline.generate_from_tokens(
                     tokens=phonemes,  # Pass raw phonemes string
                     voice=voice_path,
-                    speed=speed,
-                    model=backend._model
+                    speed=speed
                 ):
                     if r.audio is not None:
                         result = r

View file

@@ -1,5 +1,5 @@
 from enum import Enum
-from typing import List, Literal, Union
+from typing import List, Literal, Union, Optional

 from pydantic import BaseModel, Field
@@ -62,6 +62,10 @@ class OpenAISpeechRequest(BaseModel):
         default=False,
         description="If true, returns a download link in X-Download-Path header after streaming completes",
     )
+    lang_code: Optional[str] = Field(
+        default=None,
+        description="Optional language code to use for text processing. If not provided, will use first letter of voice name.",
+    )

 class CaptionedSpeechRequest(BaseModel):
     """Request schema for captioned speech endpoint"""
@@ -88,3 +92,7 @@ class CaptionedSpeechRequest(BaseModel):
         default=True,
         description="If true (default), returns word-level timestamps in the response",
     )
+    lang_code: Optional[str] = Field(
+        default=None,
+        description="Optional language code to use for text processing. If not provided, will use first letter of voice name.",
+    )
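Because lang_code defaults to None, existing payloads keep their current behavior. A short sketch of the schema semantics; the import path (and the assumption that input and voice are the only fields a minimal payload needs) may differ in the actual repo:

```python
# Import path is an assumption; adjust to wherever the schemas live.
from api.src.structures.schemas import OpenAISpeechRequest

payload = {"model": "kokoro", "input": "Hola", "voice": "ef_dora"}
req = OpenAISpeechRequest.model_validate(payload)
assert req.lang_code is None  # omitted -> service derives 'e' from the voice name

req = OpenAISpeechRequest.model_validate({**payload, "lang_code": "e"})
assert req.lang_code == "e"   # explicit override is passed through untouched
```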

View file

@@ -1,5 +1,5 @@
 import pytest
-from unittest.mock import patch, MagicMock
+from unittest.mock import patch, MagicMock, ANY
 import torch
 import numpy as np
 from api.src.inference.kokoro_v1 import KokoroV1
@@ -13,7 +13,7 @@ def test_initial_state(kokoro_backend):
     """Test initial state of KokoroV1."""
     assert not kokoro_backend.is_loaded
     assert kokoro_backend._model is None
-    assert kokoro_backend._pipeline is None
+    assert kokoro_backend._pipelines == {}  # Now using dict of pipelines
     # Device should be set based on settings
     assert kokoro_backend.device in ["cuda", "cpu"]
@@ -47,18 +47,23 @@ async def test_load_model_validation(kokoro_backend):
     with pytest.raises(RuntimeError, match="Failed to load Kokoro model"):
         await kokoro_backend.load_model("nonexistent_model.pth")

-def test_unload(kokoro_backend):
-    """Test model unloading."""
-    # Mock loaded state
+def test_unload_with_pipelines(kokoro_backend):
+    """Test model unloading with multiple pipelines."""
+    # Mock loaded state with multiple pipelines
     kokoro_backend._model = MagicMock()
-    kokoro_backend._pipeline = MagicMock()
+    pipeline_a = MagicMock()
+    pipeline_e = MagicMock()
+    kokoro_backend._pipelines = {
+        'a': pipeline_a,
+        'e': pipeline_e
+    }
     assert kokoro_backend.is_loaded

     # Test unload
     kokoro_backend.unload()
     assert not kokoro_backend.is_loaded
     assert kokoro_backend._model is None
-    assert kokoro_backend._pipeline is None
+    assert kokoro_backend._pipelines == {}  # All pipelines should be cleared

 @pytest.mark.asyncio
 async def test_generate_validation(kokoro_backend):
@@ -73,3 +78,77 @@ async def test_generate_from_tokens_validation(kokoro_backend):
     with pytest.raises(RuntimeError, match="Model not loaded"):
         async for _ in kokoro_backend.generate_from_tokens("test tokens", "voice"):
             pass
+
+def test_get_pipeline_creates_new(kokoro_backend):
+    """Test that _get_pipeline creates new pipeline for new language code."""
+    # Mock loaded state
+    kokoro_backend._model = MagicMock()
+
+    # Mock KPipeline
+    mock_pipeline = MagicMock()
+    with patch('api.src.inference.kokoro_v1.KPipeline', return_value=mock_pipeline) as mock_kpipeline:
+        # Get pipeline for Spanish
+        pipeline_e = kokoro_backend._get_pipeline('e')
+
+        # Should create new pipeline with correct params
+        mock_kpipeline.assert_called_once_with(
+            lang_code='e',
+            model=kokoro_backend._model,
+            device=kokoro_backend._device
+        )
+        assert pipeline_e == mock_pipeline
+        assert kokoro_backend._pipelines['e'] == mock_pipeline
+
+def test_get_pipeline_reuses_existing(kokoro_backend):
+    """Test that _get_pipeline reuses existing pipeline for same language code."""
+    # Mock loaded state
+    kokoro_backend._model = MagicMock()
+
+    # Mock KPipeline
+    mock_pipeline = MagicMock()
+    with patch('api.src.inference.kokoro_v1.KPipeline', return_value=mock_pipeline) as mock_kpipeline:
+        # Get pipeline twice for same language
+        pipeline1 = kokoro_backend._get_pipeline('e')
+        pipeline2 = kokoro_backend._get_pipeline('e')
+
+        # Should only create pipeline once
+        mock_kpipeline.assert_called_once()
+        assert pipeline1 == pipeline2
+        assert kokoro_backend._pipelines['e'] == mock_pipeline
+
+@pytest.mark.asyncio
+async def test_generate_uses_correct_pipeline(kokoro_backend):
+    """Test that generate uses correct pipeline for language code."""
+    # Mock loaded state
+    kokoro_backend._model = MagicMock()
+
+    # Mock voice path handling
+    with patch('api.src.core.paths.load_voice_tensor') as mock_load_voice, \
+         patch('api.src.core.paths.save_voice_tensor'), \
+         patch('tempfile.gettempdir') as mock_tempdir:
+        mock_load_voice.return_value = torch.ones(1)
+        mock_tempdir.return_value = "/tmp"
+
+        # Mock KPipeline
+        mock_pipeline = MagicMock()
+        mock_pipeline.return_value = iter([])  # Empty generator for testing
+        with patch('api.src.inference.kokoro_v1.KPipeline', return_value=mock_pipeline):
+            # Generate with Spanish voice and explicit lang_code
+            async for _ in kokoro_backend.generate(
+                "test", "ef_voice", lang_code='e'
+            ):
+                pass
+
+            # Should create pipeline with Spanish lang_code
+            assert 'e' in kokoro_backend._pipelines
+            # Use ANY to match the temp file path since it's dynamic
+            mock_pipeline.assert_called_with(
+                "test",
+                voice=ANY,  # Don't check exact path since it's dynamic
+                speed=1.0,
+                model=kokoro_backend._model
+            )
+            # Verify the voice path is a temp file path
+            call_args = mock_pipeline.call_args
+            assert isinstance(call_args[1]['voice'], str)
+            assert call_args[1]['voice'].startswith("/tmp/temp_voice_")
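The new tests pin down explicit overrides; a complementary hedged sketch, reusing the same fixtures and mocks, could also cover the voice-name fallback path the diff introduces:

```python
@pytest.mark.asyncio
async def test_generate_lang_code_fallback(kokoro_backend):
    """Sketch: with no explicit lang_code, the 'e' prefix of 'ef_voice'
    should select the pipeline, mirroring test_generate_uses_correct_pipeline."""
    kokoro_backend._model = MagicMock()
    with patch('api.src.core.paths.load_voice_tensor') as mock_load_voice, \
         patch('api.src.core.paths.save_voice_tensor'), \
         patch('tempfile.gettempdir', return_value="/tmp"):
        mock_load_voice.return_value = torch.ones(1)
        mock_pipeline = MagicMock(return_value=iter([]))
        with patch('api.src.inference.kokoro_v1.KPipeline', return_value=mock_pipeline):
            async for _ in kokoro_backend.generate("test", "ef_voice"):
                pass
        # Cached under the code derived from the voice name, not a default
        assert 'e' in kokoro_backend._pipelines
```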

View file

@@ -14,7 +14,7 @@ RUN apt-get update && apt-get install -y \
     ffmpeg \
     && apt-get clean && rm -rf /var/lib/apt/lists/* \
     && mkdir -p /usr/share/espeak-ng-data \
-    && ln -s /usr/lib/x86_64-linux-gnu/espeak-ng-data/* /usr/share/espeak-ng-data/
+    && ln -s /usr/lib/*/espeak-ng-data/* /usr/share/espeak-ng-data/

 # Install UV using the installer script
 RUN curl -LsSf https://astral.sh/uv/install.sh | sh && \

View file

@@ -94,6 +94,19 @@
                 <label for="speed-slider">Speed: <span id="speed-value">1.0</span>x</label>
                 <input type="range" id="speed-slider" min="0.1" max="4" step="0.1" value="1.0">
             </div>
+            <div class="lang-control">
+                <label for="lang-select">Language:</label>
+                <select id="lang-select" class="lang-select">
+                    <option value="">Auto</option>
+                    <option value="e">Spanish</option>
+                    <option value="a">English</option>
+                    <option value="f">French</option>
+                    <option value="i">Italian</option>
+                    <option value="p">Portuguese</option>
+                    <option value="j">Japanese</option>
+                    <option value="z">Chinese</option>
+                </select>
+            </div>
         </div>
         <div class="button-group">
             <button id="generate-btn">

View file

@@ -42,7 +42,8 @@ export class AudioService {
                     response_format: 'mp3',
                     stream: true,
                     speed: speed,
-                    return_download_link: true
+                    return_download_link: true,
+                    lang_code: document.getElementById('lang-select').value || undefined
                 }),
                 signal: this.controller.signal
             });

View file

@@ -252,6 +252,49 @@
     transform: scale(1.1);
 }

+/* Language Control */
+.lang-control {
+    display: flex;
+    flex-direction: column;
+    gap: 0.75rem;
+    padding: 0.75rem;
+    background: rgba(15, 23, 42, 0.3);
+    border: 1px solid var(--border);
+    border-radius: 0.5rem;
+}
+
+.lang-control label {
+    color: var(--text-light);
+    font-size: 0.875rem;
+}
+
+.lang-select {
+    background: rgba(15, 23, 42, 0.3);
+    color: var(--text);
+    border: 1px solid var(--border);
+    border-radius: 0.5rem;
+    padding: 0.5rem 1rem;
+    font-family: var(--font-family);
+    font-size: 0.875rem;
+    cursor: pointer;
+    transition: all 0.2s ease;
+}
+
+.lang-select:hover {
+    border-color: var(--fg-color);
+}
+
+.lang-select:focus {
+    outline: none;
+    border-color: var(--fg-color);
+    box-shadow: 0 0 0 3px rgba(99, 102, 241, 0.2);
+}
+
+.lang-select option {
+    background: var(--surface);
+    color: var(--text);
+}
+
 /* Generation Controls */
 .button-group {
     display: flex;