Mirror of https://github.com/remsky/Kokoro-FastAPI.git (synced 2025-04-13 09:39:17 +00:00)
- fix voice selection not matching language phonemes
- added voice language override parameter
This commit is contained in:
parent 68cc14896a
commit a0dc870f4a

12 changed files with 259 additions and 51 deletions
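The override rides along on the existing speech request. A minimal client sketch, assuming the project's OpenAI-compatible /v1/audio/speech route, the default local port, and an example voice name (all of which may differ in your deployment):

# Hypothetical client call exercising the new lang_code override.
import requests

response = requests.post(
    "http://localhost:8880/v1/audio/speech",  # host/port assumed
    json={
        "model": "kokoro",            # model name assumed
        "input": "Hola, ¿cómo estás?",
        "voice": "ef_dora",           # example Spanish voice name
        "response_format": "mp3",
        "lang_code": "e",             # force Spanish phonemes regardless of voice
    },
)
with open("output.mp3", "wb") as f:
    f.write(response.content)

Omitting "lang_code" (or sending null) keeps the old behavior: the language is derived from the first letter of the voice name.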
@@ -3,8 +3,8 @@
</p>

# <sub><sub>_`FastKoko`_ </sub></sub>
[]()
[]()
[]()
[]()
[](https://huggingface.co/spaces/Remsky/Kokoro-TTS-Zero)

[](https://huggingface.co/hexgrad/Kokoro-82M/commit/9901c2b79161b6e898b7ea857ae5298f47b8b0d6)

@@ -1,7 +1,7 @@
"""Clean Kokoro implementation with controlled resource management."""

import os
-from typing import AsyncGenerator, Optional, Union, Tuple
+from typing import AsyncGenerator, Optional, Union, Tuple, Dict

import numpy as np
import torch

@@ -22,7 +22,7 @@ class KokoroV1(BaseModelBackend):
        # Strictly respect settings.use_gpu
        self._device = "cuda" if settings.use_gpu else "cpu"
        self._model: Optional[KModel] = None
-        self._pipeline: Optional[KPipeline] = None
+        self._pipelines: Dict[str, KPipeline] = {}  # Store pipelines by lang_code

    async def load_model(self, path: str) -> None:
        """Load pre-baked model.

@@ -54,22 +54,38 @@ class KokoroV1(BaseModelBackend):
            if self._device == "cuda":
                self._model = self._model.cuda()
-
-            # Initialize pipeline with our model and device
-            self._pipeline = KPipeline(
-                lang_code='a',
-                model=self._model,  # Pass our model directly
-                device=self._device  # Match our device setting
-            )
        except FileNotFoundError as e:
            raise e
        except Exception as e:
            raise RuntimeError(f"Failed to load Kokoro model: {e}")

+    def _get_pipeline(self, lang_code: str) -> KPipeline:
+        """Get or create pipeline for language code.
+
+        Args:
+            lang_code: Language code to use
+
+        Returns:
+            KPipeline instance for the language
+        """
+        if not self._model:
+            raise RuntimeError("Model not loaded")
+
+        if lang_code not in self._pipelines:
+            logger.info(f"Creating new pipeline for language code: {lang_code}")
+            self._pipelines[lang_code] = KPipeline(
+                lang_code=lang_code,
+                model=self._model,
+                device=self._device
+            )
+        return self._pipelines[lang_code]
+
    async def generate_from_tokens(
        self,
        tokens: str,
        voice: Union[str, Tuple[str, Union[torch.Tensor, str]]],
-        speed: float = 1.0
+        speed: float = 1.0,
+        lang_code: Optional[str] = None
    ) -> AsyncGenerator[np.ndarray, None]:
        """Generate audio from phoneme tokens.

@@ -77,6 +93,7 @@ class KokoroV1(BaseModelBackend):
            tokens: Input phoneme tokens to synthesize
            voice: Either a voice path string or a tuple of (voice_name, voice_tensor/path)
            speed: Speed multiplier
+            lang_code: Optional language code override

        Yields:
            Generated audio chunks

@@ -95,6 +112,7 @@ class KokoroV1(BaseModelBackend):
        # Handle voice input
        voice_path: str
+        voice_name: str
        if isinstance(voice, tuple):
            voice_name, voice_data = voice
            if isinstance(voice_data, str):

@@ -108,6 +126,7 @@ class KokoroV1(BaseModelBackend):
                torch.save(voice_data.cpu(), voice_path)
        else:
            voice_path = voice
+            voice_name = os.path.splitext(os.path.basename(voice_path))[0]

        # Load voice tensor with proper device mapping
        voice_tensor = await paths.load_voice_tensor(voice_path, device=self._device)

@@ -118,9 +137,12 @@ class KokoroV1(BaseModelBackend):
            await paths.save_voice_tensor(voice_tensor, temp_path)
            voice_path = temp_path

-        # Generate using pipeline's generate_from_tokens method
-        logger.debug(f"Generating audio from tokens: '{tokens[:100]}...'")
-        for result in self._pipeline.generate_from_tokens(
+        # Use provided lang_code or get from voice name
+        pipeline_lang_code = lang_code if lang_code else voice_name[0].lower()
+        pipeline = self._get_pipeline(pipeline_lang_code)
+
+        logger.debug(f"Generating audio from tokens with lang_code '{pipeline_lang_code}': '{tokens[:100]}...'")
+        for result in pipeline.generate_from_tokens(
            tokens=tokens,
            voice=voice_path,
            speed=speed,

@@ -140,7 +162,7 @@ class KokoroV1(BaseModelBackend):
                and "out of memory" in str(e).lower()
            ):
                self._clear_memory()
-                async for chunk in self.generate_from_tokens(tokens, voice, speed):
+                async for chunk in self.generate_from_tokens(tokens, voice, speed, lang_code):
                    yield chunk
            raise

@@ -148,7 +170,8 @@ class KokoroV1(BaseModelBackend):
        self,
        text: str,
        voice: Union[str, Tuple[str, Union[torch.Tensor, str]]],
-        speed: float = 1.0
+        speed: float = 1.0,
+        lang_code: Optional[str] = None
    ) -> AsyncGenerator[np.ndarray, None]:
        """Generate audio using model.

@@ -156,6 +179,7 @@ class KokoroV1(BaseModelBackend):
            text: Input text to synthesize
            voice: Either a voice path string or a tuple of (voice_name, voice_tensor/path)
            speed: Speed multiplier
+            lang_code: Optional language code override

        Yields:
            Generated audio chunks

@@ -174,6 +198,7 @@ class KokoroV1(BaseModelBackend):
        # Handle voice input
        voice_path: str
+        voice_name: str
        if isinstance(voice, tuple):
            voice_name, voice_data = voice
            if isinstance(voice_data, str):

@@ -187,6 +212,7 @@ class KokoroV1(BaseModelBackend):
                torch.save(voice_data.cpu(), voice_path)
        else:
            voice_path = voice
+            voice_name = os.path.splitext(os.path.basename(voice_path))[0]

        # Load voice tensor with proper device mapping
        voice_tensor = await paths.load_voice_tensor(voice_path, device=self._device)

@@ -197,9 +223,17 @@ class KokoroV1(BaseModelBackend):
            await paths.save_voice_tensor(voice_tensor, temp_path)
            voice_path = temp_path

-        # Generate using pipeline, force model to prevent downloads
-        logger.debug(f"Generating audio for text: '{text[:100]}...'")
-        for result in self._pipeline(text, voice=voice_path, speed=speed, model=self._model):
+        # Use provided lang_code or get from voice name
+        pipeline_lang_code = lang_code if lang_code else voice_name[0].lower()
+        pipeline = self._get_pipeline(pipeline_lang_code)
+
+        logger.debug(f"Generating audio for text with lang_code '{pipeline_lang_code}': '{text[:100]}...'")
+        for result in pipeline(
+            text,
+            voice=voice_path,
+            speed=speed,
+            model=self._model
+        ):
            if result.audio is not None:
                logger.debug(f"Got audio chunk with shape: {result.audio.shape}")
                yield result.audio.numpy()

@@ -214,7 +248,7 @@ class KokoroV1(BaseModelBackend):
                and "out of memory" in str(e).lower()
            ):
                self._clear_memory()
-                async for chunk in self.generate(text, voice, speed):
+                async for chunk in self.generate(text, voice, speed, lang_code):
                    yield chunk
            raise

@@ -236,9 +270,9 @@ class KokoroV1(BaseModelBackend):
        if self._model is not None:
            del self._model
            self._model = None
-        if self._pipeline is not None:
-            del self._pipeline
-            self._pipeline = None
+        for pipeline in self._pipelines.values():
+            del pipeline
+        self._pipelines.clear()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.synchronize()

@@ -246,7 +280,7 @@ class KokoroV1(BaseModelBackend):
    @property
    def is_loaded(self) -> bool:
        """Check if model is loaded."""
-        return self._model is not None and self._pipeline is not None
+        return self._model is not None

    @property
    def device(self) -> str:

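Throughout this backend, the fallback `voice_name[0].lower()` leans on the voice naming convention: the first letter of a voice name encodes its language ('a' English, 'e' Spanish, and so on, matching the web player's option list further down). A standalone sketch of that resolution order, with a hypothetical helper name:

from typing import Optional

def resolve_lang_code(voice_name: str, lang_code: Optional[str] = None) -> str:
    # Explicit override wins; otherwise the voice name's first letter
    # is taken as the language code.
    return lang_code if lang_code else voice_name[0].lower()

assert resolve_lang_code("af_bella") == "a"                  # English-prefixed voice
assert resolve_lang_code("ef_voice") == "e"                  # Spanish-prefixed voice
assert resolve_lang_code("af_bella", lang_code="e") == "e"   # override wins
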
@@ -72,7 +72,10 @@ class ModelManager:

            # Warm up with short text
            warmup_text = "Warmup text for initialization."
-            async for _ in self.generate(warmup_text, voice_path):
+            # Use default voice name for warmup
+            voice_name = settings.default_voice
+            logger.debug(f"Using default voice '{voice_name}' for warmup")
+            async for _ in self.generate(warmup_text, (voice_name, voice_path)):
                pass
        except Exception as e:
            raise RuntimeError(f"Failed to get default voice: {e}")

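The warmup change reflects that `generate` accepts the voice in two shapes (see the tuple handling in `KokoroV1.generate` above); illustrative values only:

# Both forms are accepted by generate(); paths below are illustrative.
voice_as_path = "voices/af_bella.pt"                # bare path: name derived from filename
voice_as_pair = ("af_bella", "voices/af_bella.pt")  # explicit (name, path-or-tensor) pair
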
@@ -28,6 +28,7 @@ def setup_logger():
        "sink": sys.stdout,
        "format": "<fg #2E8B57>{time:hh:mm:ss A}</fg #2E8B57> | "
        "{level: <8} | "
+        "<fg #4169E1>{module}:{line}</fg #4169E1> | "
        "{message}",
        "colorize": True,
        "level": "DEBUG",

@@ -88,6 +89,7 @@ async def lifespan(app: FastAPI):
    # Add web player info if enabled
    if settings.enable_web_player:
        startup_msg += f"\n\nBeta Web Player: http://{settings.host}:{settings.port}/web/"
+        startup_msg += f"\nor http://localhost:{settings.port}/web/"
    else:
        startup_msg += "\n\nWeb Player: disabled"

@@ -130,11 +130,13 @@ async def stream_audio_chunks(
    voice_name = await process_voices(request.voice, tts_service)

    try:
+        logger.info(f"Starting audio generation with lang_code: {request.lang_code}")
        async for chunk in tts_service.generate_audio_stream(
            text=request.input,
            voice=voice_name,
            speed=request.speed,
            output_format=request.response_format,
+            lang_code=request.lang_code,
        ):
            # Check if client is still connected
            is_disconnected = client_request.is_disconnected

@@ -250,7 +252,8 @@ async def create_speech(
        audio, _ = await tts_service.generate_audio(
            text=request.input,
            voice=voice_name,
-            speed=request.speed
+            speed=request.speed,
+            lang_code=request.lang_code
        )

        # Convert to requested format with proper finalization

@@ -51,6 +51,7 @@ class TTSService:
        is_first: bool = False,
        is_last: bool = False,
        normalizer: Optional[AudioNormalizer] = None,
+        lang_code: Optional[str] = None,
    ) -> AsyncGenerator[Union[np.ndarray, bytes], None]:
        """Process tokens into audio."""
        async with self._chunk_semaphore:

@@ -82,11 +83,12 @@ class TTSService:
            # Generate audio using pre-warmed model
            if isinstance(backend, KokoroV1):
-                # For Kokoro V1, pass text and voice info
+                # For Kokoro V1, pass text and voice info with lang_code
                async for chunk_audio in self.model_manager.generate(
                    chunk_text,
                    (voice_name, voice_path),
-                    speed=speed
+                    speed=speed,
+                    lang_code=lang_code
                ):
                    # For streaming, convert to bytes
                    if output_format:

@@ -217,6 +219,7 @@ class TTSService:
        voice: str,
        speed: float = 1.0,
        output_format: str = "wav",
+        lang_code: Optional[str] = None,
    ) -> AsyncGenerator[bytes, None]:
        """Generate and stream audio chunks."""
        stream_normalizer = AudioNormalizer()

@@ -230,6 +233,10 @@ class TTSService:
            voice_name, voice_path = await self._get_voice_path(voice)
            logger.debug(f"Using voice path: {voice_path}")

+            # Use provided lang_code or determine from voice name
+            pipeline_lang_code = lang_code if lang_code else voice[:1].lower()
+            logger.info(f"Using lang_code '{pipeline_lang_code}' for voice '{voice_name}' in audio stream")
+
            # Process text in chunks with smart splitting
            async for chunk_text, tokens in smart_split(text):
                try:

@@ -243,7 +250,8 @@ class TTSService:
                        output_format,
                        is_first=(chunk_index == 0),
                        is_last=False,  # We'll update the last chunk later
-                        normalizer=stream_normalizer
+                        normalizer=stream_normalizer,
+                        lang_code=pipeline_lang_code  # Pass lang_code
                    ):
                        if result is not None:
                            yield result

@@ -268,7 +276,8 @@ class TTSService:
                        output_format,
                        is_first=False,
                        is_last=True,  # Signal this is the last chunk
-                        normalizer=stream_normalizer
+                        normalizer=stream_normalizer,
+                        lang_code=pipeline_lang_code  # Pass lang_code
                    ):
                        if result is not None:
                            yield result

@@ -280,7 +289,8 @@ class TTSService:
            raise

    async def generate_audio(
-        self, text: str, voice: str, speed: float = 1.0, return_timestamps: bool = False
+        self, text: str, voice: str, speed: float = 1.0, return_timestamps: bool = False,
+        lang_code: Optional[str] = None
    ) -> Union[Tuple[np.ndarray, float], Tuple[np.ndarray, float, List[dict]]]:
        """Generate complete audio for text using streaming internally."""
        start_time = time.time()

@@ -293,8 +303,12 @@ class TTSService:
            voice_name, voice_path = await self._get_voice_path(voice)

            if isinstance(backend, KokoroV1):
+                # Use provided lang_code or determine from voice name
+                pipeline_lang_code = lang_code if lang_code else voice[:1].lower()
+                logger.info(f"Using lang_code '{pipeline_lang_code}' for voice '{voice_name}' in text chunking")
+
                # Initialize quiet pipeline for text chunking
-                quiet_pipeline = KPipeline(lang_code='a', model=False)
+                quiet_pipeline = KPipeline(lang_code=pipeline_lang_code, model=False)

                # Split text into chunks and get initial tokens
                text_chunks = []

@@ -310,12 +324,13 @@ class TTSService:
                for chunk_idx, (chunk_text, chunk_phonemes) in enumerate(text_chunks):
                    logger.debug(f"Processing chunk {chunk_idx + 1}/{len(text_chunks)}: '{chunk_text[:50]}...'")

-                    # Generate audio and timestamps for this chunk
-                    for result in backend._pipeline(
+                    # Create a new pipeline with the lang_code
+                    generation_pipeline = KPipeline(lang_code=pipeline_lang_code, model=backend._model)
+                    logger.info(f"Using lang_code '{pipeline_lang_code}' for voice '{voice_name}' in generation pipeline")
+                    for result in generation_pipeline(
                        chunk_text,
                        voice=voice_path,
-                        speed=speed,
-                        model=backend._model
+                        speed=speed
                    ):
                        # Collect audio chunks
                        if result.audio is not None:

@@ -470,7 +485,8 @@ class TTSService:
        self,
        phonemes: str,
        voice: str,
-        speed: float = 1.0
+        speed: float = 1.0,
+        lang_code: Optional[str] = None
    ) -> Tuple[np.ndarray, float]:
        """Generate audio directly from phonemes.

@@ -478,6 +494,7 @@ class TTSService:
            phonemes: Phonemes in Kokoro format
            voice: Voice name
            speed: Speed multiplier
+            lang_code: Optional language code override

        Returns:
            Tuple of (audio array, processing time)

@@ -491,11 +508,16 @@ class TTSService:
            if isinstance(backend, KokoroV1):
                # For Kokoro V1, use generate_from_tokens with raw phonemes
                result = None
-                for r in backend._pipeline.generate_from_tokens(
+                # Use provided lang_code or determine from voice name
+                pipeline_lang_code = lang_code if lang_code else voice[:1].lower()
+                logger.info(f"Using lang_code '{pipeline_lang_code}' for voice '{voice_name}' in phoneme pipeline")
+
+                # Create a new pipeline with the lang_code
+                phoneme_pipeline = KPipeline(lang_code=pipeline_lang_code, model=backend._model)
+                for r in phoneme_pipeline.generate_from_tokens(
                    tokens=phonemes,  # Pass raw phonemes string
                    voice=voice_path,
-                    speed=speed,
-                    model=backend._model
+                    speed=speed
                ):
                    if r.audio is not None:
                        result = r

@@ -1,5 +1,5 @@
from enum import Enum
-from typing import List, Literal, Union
+from typing import List, Literal, Union, Optional

from pydantic import BaseModel, Field

@@ -62,6 +62,10 @@ class OpenAISpeechRequest(BaseModel):
        default=False,
        description="If true, returns a download link in X-Download-Path header after streaming completes",
    )
+    lang_code: Optional[str] = Field(
+        default=None,
+        description="Optional language code to use for text processing. If not provided, will use first letter of voice name.",
+    )

class CaptionedSpeechRequest(BaseModel):
    """Request schema for captioned speech endpoint"""

@@ -88,3 +92,7 @@ class CaptionedSpeechRequest(BaseModel):
        default=True,
        description="If true (default), returns word-level timestamps in the response",
    )
+    lang_code: Optional[str] = Field(
+        default=None,
+        description="Optional language code to use for text processing. If not provided, will use first letter of voice name.",
+    )

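Because the field defaults to None, existing clients keep working unchanged. A quick sketch of the behavior using a reduced stand-in for the request model (field names taken from the diff above):

from typing import Optional
from pydantic import BaseModel, Field

class SpeechRequestSubset(BaseModel):
    # Reduced stand-in for OpenAISpeechRequest: only the fields relevant here.
    input: str
    voice: str
    lang_code: Optional[str] = Field(default=None)

req = SpeechRequestSubset(input="Bonjour", voice="ff_voice")
assert req.lang_code is None   # downstream code falls back to 'f' from the voice name
req = SpeechRequestSubset(input="Hello", voice="af_bella", lang_code="f")
assert req.lang_code == "f"    # explicit override is carried through to the pipeline
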
@@ -1,5 +1,5 @@
import pytest
-from unittest.mock import patch, MagicMock
+from unittest.mock import patch, MagicMock, ANY
import torch
import numpy as np
from api.src.inference.kokoro_v1 import KokoroV1

@@ -13,7 +13,7 @@ def test_initial_state(kokoro_backend):
    """Test initial state of KokoroV1."""
    assert not kokoro_backend.is_loaded
    assert kokoro_backend._model is None
-    assert kokoro_backend._pipeline is None
+    assert kokoro_backend._pipelines == {}  # Now using dict of pipelines
    # Device should be set based on settings
    assert kokoro_backend.device in ["cuda", "cpu"]

@@ -47,18 +47,23 @@ async def test_load_model_validation(kokoro_backend):
    with pytest.raises(RuntimeError, match="Failed to load Kokoro model"):
        await kokoro_backend.load_model("nonexistent_model.pth")

-def test_unload(kokoro_backend):
-    """Test model unloading."""
-    # Mock loaded state
+def test_unload_with_pipelines(kokoro_backend):
+    """Test model unloading with multiple pipelines."""
+    # Mock loaded state with multiple pipelines
    kokoro_backend._model = MagicMock()
-    kokoro_backend._pipeline = MagicMock()
+    pipeline_a = MagicMock()
+    pipeline_e = MagicMock()
+    kokoro_backend._pipelines = {
+        'a': pipeline_a,
+        'e': pipeline_e
+    }
    assert kokoro_backend.is_loaded

    # Test unload
    kokoro_backend.unload()
    assert not kokoro_backend.is_loaded
    assert kokoro_backend._model is None
-    assert kokoro_backend._pipeline is None
+    assert kokoro_backend._pipelines == {}  # All pipelines should be cleared

@pytest.mark.asyncio
async def test_generate_validation(kokoro_backend):

@@ -72,4 +77,78 @@ async def test_generate_from_tokens_validation(kokoro_backend):
    """Test token generation validation."""
    with pytest.raises(RuntimeError, match="Model not loaded"):
        async for _ in kokoro_backend.generate_from_tokens("test tokens", "voice"):
-            pass
+            pass
+
+def test_get_pipeline_creates_new(kokoro_backend):
+    """Test that _get_pipeline creates new pipeline for new language code."""
+    # Mock loaded state
+    kokoro_backend._model = MagicMock()
+
+    # Mock KPipeline
+    mock_pipeline = MagicMock()
+    with patch('api.src.inference.kokoro_v1.KPipeline', return_value=mock_pipeline) as mock_kpipeline:
+        # Get pipeline for Spanish
+        pipeline_e = kokoro_backend._get_pipeline('e')
+
+        # Should create new pipeline with correct params
+        mock_kpipeline.assert_called_once_with(
+            lang_code='e',
+            model=kokoro_backend._model,
+            device=kokoro_backend._device
+        )
+        assert pipeline_e == mock_pipeline
+        assert kokoro_backend._pipelines['e'] == mock_pipeline
+
+def test_get_pipeline_reuses_existing(kokoro_backend):
+    """Test that _get_pipeline reuses existing pipeline for same language code."""
+    # Mock loaded state
+    kokoro_backend._model = MagicMock()
+
+    # Mock KPipeline
+    mock_pipeline = MagicMock()
+    with patch('api.src.inference.kokoro_v1.KPipeline', return_value=mock_pipeline) as mock_kpipeline:
+        # Get pipeline twice for same language
+        pipeline1 = kokoro_backend._get_pipeline('e')
+        pipeline2 = kokoro_backend._get_pipeline('e')
+
+        # Should only create pipeline once
+        mock_kpipeline.assert_called_once()
+        assert pipeline1 == pipeline2
+        assert kokoro_backend._pipelines['e'] == mock_pipeline
+
+@pytest.mark.asyncio
+async def test_generate_uses_correct_pipeline(kokoro_backend):
+    """Test that generate uses correct pipeline for language code."""
+    # Mock loaded state
+    kokoro_backend._model = MagicMock()
+
+    # Mock voice path handling
+    with patch('api.src.core.paths.load_voice_tensor') as mock_load_voice, \
+         patch('api.src.core.paths.save_voice_tensor'), \
+         patch('tempfile.gettempdir') as mock_tempdir:
+        mock_load_voice.return_value = torch.ones(1)
+        mock_tempdir.return_value = "/tmp"
+
+        # Mock KPipeline
+        mock_pipeline = MagicMock()
+        mock_pipeline.return_value = iter([])  # Empty generator for testing
+        with patch('api.src.inference.kokoro_v1.KPipeline', return_value=mock_pipeline):
+            # Generate with Spanish voice and explicit lang_code
+            async for _ in kokoro_backend.generate(
+                "test", "ef_voice", lang_code='e'
+            ):
+                pass
+
+            # Should create pipeline with Spanish lang_code
+            assert 'e' in kokoro_backend._pipelines
+            # Use ANY to match the temp file path since it's dynamic
+            mock_pipeline.assert_called_with(
+                "test",
+                voice=ANY,  # Don't check exact path since it's dynamic
+                speed=1.0,
+                model=kokoro_backend._model
+            )
+            # Verify the voice path is a temp file path
+            call_args = mock_pipeline.call_args
+            assert isinstance(call_args[1]['voice'], str)
+            assert call_args[1]['voice'].startswith("/tmp/temp_voice_")

@@ -14,7 +14,7 @@ RUN apt-get update && apt-get install -y \
    ffmpeg \
    && apt-get clean && rm -rf /var/lib/apt/lists/* \
    && mkdir -p /usr/share/espeak-ng-data \
-    && ln -s /usr/lib/x86_64-linux-gnu/espeak-ng-data/* /usr/share/espeak-ng-data/
+    && ln -s /usr/lib/*/espeak-ng-data/* /usr/share/espeak-ng-data/

# Install UV using the installer script
RUN curl -LsSf https://astral.sh/uv/install.sh | sh && \

@@ -94,6 +94,19 @@
                <label for="speed-slider">Speed: <span id="speed-value">1.0</span>x</label>
                <input type="range" id="speed-slider" min="0.1" max="4" step="0.1" value="1.0">
            </div>
+            <div class="lang-control">
+                <label for="lang-select">Language:</label>
+                <select id="lang-select" class="lang-select">
+                    <option value="">Auto</option>
+                    <option value="e">Spanish</option>
+                    <option value="a">English</option>
+                    <option value="f">French</option>
+                    <option value="i">Italian</option>
+                    <option value="p">Portuguese</option>
+                    <option value="j">Japanese</option>
+                    <option value="z">Chinese</option>
+                </select>
+            </div>
        </div>
        <div class="button-group">
            <button id="generate-btn">

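The option values map one-to-one onto the single-letter codes the backend expects; collected here as a reference sketch (the "Auto" option sends an empty value, which the client below turns into an omitted field):

# Language codes exposed by the web player's new selector.
LANG_CODES = {
    "a": "English",
    "e": "Spanish",
    "f": "French",
    "i": "Italian",
    "p": "Portuguese",
    "j": "Japanese",
    "z": "Chinese",
}
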
@@ -42,7 +42,8 @@ export class AudioService {
                    response_format: 'mp3',
                    stream: true,
                    speed: speed,
-                    return_download_link: true
+                    return_download_link: true,
+                    lang_code: document.getElementById('lang-select').value || undefined
                }),
                signal: this.controller.signal
            });

@@ -252,6 +252,49 @@
    transform: scale(1.1);
}

+/* Language Control */
+.lang-control {
+    display: flex;
+    flex-direction: column;
+    gap: 0.75rem;
+    padding: 0.75rem;
+    background: rgba(15, 23, 42, 0.3);
+    border: 1px solid var(--border);
+    border-radius: 0.5rem;
+}
+
+.lang-control label {
+    color: var(--text-light);
+    font-size: 0.875rem;
+}
+
+.lang-select {
+    background: rgba(15, 23, 42, 0.3);
+    color: var(--text);
+    border: 1px solid var(--border);
+    border-radius: 0.5rem;
+    padding: 0.5rem 1rem;
+    font-family: var(--font-family);
+    font-size: 0.875rem;
+    cursor: pointer;
+    transition: all 0.2s ease;
+}
+
+.lang-select:hover {
+    border-color: var(--fg-color);
+}
+
+.lang-select:focus {
+    outline: none;
+    border-color: var(--fg-color);
+    box-shadow: 0 0 0 3px rgba(99, 102, 241, 0.2);
+}
+
+.lang-select option {
+    background: var(--surface);
+    color: var(--text);
+}
+
/* Generation Controls */
.button-group {
    display: flex;