Kokoro-FastAPI (mirror of https://github.com/remsky/Kokoro-FastAPI.git)

Commit a0dc870f4a (parent 68cc14896a): 12 changed files with 259 additions and 51 deletions

- fix voice selection not matching language phonemes
- added voice language override parameter
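For context before the diff: after this change, API clients can pin the grapheme-to-phoneme language explicitly instead of relying on the first letter of the voice name. A minimal sketch of a client call (the base URL, port, model name, and voice name are assumptions about a local deployment, not part of this commit):

import requests

response = requests.post(
    "http://localhost:8880/v1/audio/speech",  # assumed local instance
    json={
        "model": "kokoro",
        "input": "Hola, ¿cómo estás?",
        "voice": "af_heart",       # 'a' prefix would normally imply English
        "response_format": "mp3",
        "lang_code": "e",          # new optional field: force Spanish phonemes
    },
)
with open("hola.mp3", "wb") as f:
    f.write(response.content)

Omitting "lang_code" (or sending null) keeps the old behavior: the language is derived from the first letter of the voice name.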
README.md (path inferred):

@@ -3,8 +3,8 @@
[Badge header block: the "FastKoko" title plus project badges, including links to the Kokoro-TTS-Zero Hugging Face Space and the hexgrad/Kokoro-82M model commit 9901c2b7; the changed image markup was lost in extraction.]
api/src/inference/kokoro_v1.py (path taken from the test imports in this commit):

@@ -1,7 +1,7 @@
 """Clean Kokoro implementation with controlled resource management."""

 import os
-from typing import AsyncGenerator, Optional, Union, Tuple
+from typing import AsyncGenerator, Optional, Union, Tuple, Dict

 import numpy as np
 import torch
@@ -22,7 +22,7 @@ class KokoroV1(BaseModelBackend):
         # Strictly respect settings.use_gpu
         self._device = "cuda" if settings.use_gpu else "cpu"
         self._model: Optional[KModel] = None
-        self._pipeline: Optional[KPipeline] = None
+        self._pipelines: Dict[str, KPipeline] = {}  # Store pipelines by lang_code

     async def load_model(self, path: str) -> None:
         """Load pre-baked model.
@@ -54,22 +54,38 @@ class KokoroV1(BaseModelBackend):
             if self._device == "cuda":
                 self._model = self._model.cuda()

-            # Initialize pipeline with our model and device
-            self._pipeline = KPipeline(
-                lang_code='a',
-                model=self._model,  # Pass our model directly
-                device=self._device  # Match our device setting
-            )
         except FileNotFoundError as e:
             raise e
         except Exception as e:
             raise RuntimeError(f"Failed to load Kokoro model: {e}")

+    def _get_pipeline(self, lang_code: str) -> KPipeline:
+        """Get or create pipeline for language code.
+
+        Args:
+            lang_code: Language code to use
+
+        Returns:
+            KPipeline instance for the language
+        """
+        if not self._model:
+            raise RuntimeError("Model not loaded")
+
+        if lang_code not in self._pipelines:
+            logger.info(f"Creating new pipeline for language code: {lang_code}")
+            self._pipelines[lang_code] = KPipeline(
+                lang_code=lang_code,
+                model=self._model,
+                device=self._device
+            )
+        return self._pipelines[lang_code]
+
     async def generate_from_tokens(
         self,
         tokens: str,
         voice: Union[str, Tuple[str, Union[torch.Tensor, str]]],
-        speed: float = 1.0
+        speed: float = 1.0,
+        lang_code: Optional[str] = None
     ) -> AsyncGenerator[np.ndarray, None]:
         """Generate audio from phoneme tokens.
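The new _get_pipeline replaces the single pipeline hard-wired to lang_code='a' with a lazy per-language cache. A minimal sketch of the caching contract, patterned on the tests later in this diff (assumes the repo is importable; the model is mocked since a real load is not needed to show the contract):

from unittest.mock import MagicMock, patch
from api.src.inference.kokoro_v1 import KokoroV1

backend = KokoroV1()
backend._model = MagicMock()  # stand-in for a loaded model
with patch('api.src.inference.kokoro_v1.KPipeline') as mock_kpipeline:
    pipe_es = backend._get_pipeline('e')           # first call builds a pipeline
    assert backend._get_pipeline('e') is pipe_es   # later calls reuse the cache
    mock_kpipeline.assert_called_once()            # constructed once per language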
@@ -77,6 +93,7 @@ class KokoroV1(BaseModelBackend):
             tokens: Input phoneme tokens to synthesize
             voice: Either a voice path string or a tuple of (voice_name, voice_tensor/path)
             speed: Speed multiplier
+            lang_code: Optional language code override

         Yields:
             Generated audio chunks
@@ -95,6 +112,7 @@ class KokoroV1(BaseModelBackend):

             # Handle voice input
             voice_path: str
+            voice_name: str
             if isinstance(voice, tuple):
                 voice_name, voice_data = voice
                 if isinstance(voice_data, str):
@@ -108,6 +126,7 @@ class KokoroV1(BaseModelBackend):
                     torch.save(voice_data.cpu(), voice_path)
             else:
                 voice_path = voice
+                voice_name = os.path.splitext(os.path.basename(voice_path))[0]

             # Load voice tensor with proper device mapping
             voice_tensor = await paths.load_voice_tensor(voice_path, device=self._device)
@@ -118,9 +137,12 @@ class KokoroV1(BaseModelBackend):
                 await paths.save_voice_tensor(voice_tensor, temp_path)
                 voice_path = temp_path

-            # Generate using pipeline's generate_from_tokens method
-            logger.debug(f"Generating audio from tokens: '{tokens[:100]}...'")
-            for result in self._pipeline.generate_from_tokens(
+            # Use provided lang_code or get from voice name
+            pipeline_lang_code = lang_code if lang_code else voice_name[0].lower()
+            pipeline = self._get_pipeline(pipeline_lang_code)
+
+            logger.debug(f"Generating audio from tokens with lang_code '{pipeline_lang_code}': '{tokens[:100]}...'")
+            for result in pipeline.generate_from_tokens(
                 tokens=tokens,
                 voice=voice_path,
                 speed=speed,
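Both generation paths now resolve the pipeline language the same way: an explicit lang_code wins, otherwise the first character of the voice name is used. A small illustration of that fallback rule (the voice names are examples; the prefix-to-language mapping matches the web UI options added later in this diff):

from typing import Optional

def resolve_lang_code(voice_name: str, lang_code: Optional[str] = None) -> str:
    """Mirror of the fallback in the diff: explicit override, else voice prefix."""
    return lang_code if lang_code else voice_name[0].lower()

assert resolve_lang_code("af_heart") == "a"                 # English
assert resolve_lang_code("ef_dora") == "e"                  # Spanish
assert resolve_lang_code("af_heart", lang_code="e") == "e"  # override wins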
@@ -140,7 +162,7 @@ class KokoroV1(BaseModelBackend):
                 and "out of memory" in str(e).lower()
             ):
                 self._clear_memory()
-                async for chunk in self.generate_from_tokens(tokens, voice, speed):
+                async for chunk in self.generate_from_tokens(tokens, voice, speed, lang_code):
                     yield chunk
             raise

@@ -148,7 +170,8 @@ class KokoroV1(BaseModelBackend):
         self,
         text: str,
         voice: Union[str, Tuple[str, Union[torch.Tensor, str]]],
-        speed: float = 1.0
+        speed: float = 1.0,
+        lang_code: Optional[str] = None
     ) -> AsyncGenerator[np.ndarray, None]:
         """Generate audio using model.

@@ -156,6 +179,7 @@ class KokoroV1(BaseModelBackend):
             text: Input text to synthesize
             voice: Either a voice path string or a tuple of (voice_name, voice_tensor/path)
             speed: Speed multiplier
+            lang_code: Optional language code override

         Yields:
             Generated audio chunks
@@ -174,6 +198,7 @@ class KokoroV1(BaseModelBackend):

             # Handle voice input
             voice_path: str
+            voice_name: str
             if isinstance(voice, tuple):
                 voice_name, voice_data = voice
                 if isinstance(voice_data, str):
@@ -187,6 +212,7 @@ class KokoroV1(BaseModelBackend):
                     torch.save(voice_data.cpu(), voice_path)
             else:
                 voice_path = voice
+                voice_name = os.path.splitext(os.path.basename(voice_path))[0]

             # Load voice tensor with proper device mapping
             voice_tensor = await paths.load_voice_tensor(voice_path, device=self._device)
@@ -197,9 +223,17 @@ class KokoroV1(BaseModelBackend):
                 await paths.save_voice_tensor(voice_tensor, temp_path)
                 voice_path = temp_path

-            # Generate using pipeline, force model to prevent downloads
-            logger.debug(f"Generating audio for text: '{text[:100]}...'")
-            for result in self._pipeline(text, voice=voice_path, speed=speed, model=self._model):
+            # Use provided lang_code or get from voice name
+            pipeline_lang_code = lang_code if lang_code else voice_name[0].lower()
+            pipeline = self._get_pipeline(pipeline_lang_code)
+
+            logger.debug(f"Generating audio for text with lang_code '{pipeline_lang_code}': '{text[:100]}...'")
+            for result in pipeline(
+                text,
+                voice=voice_path,
+                speed=speed,
+                model=self._model
+            ):
                 if result.audio is not None:
                     logger.debug(f"Got audio chunk with shape: {result.audio.shape}")
                     yield result.audio.numpy()
@@ -214,7 +248,7 @@ class KokoroV1(BaseModelBackend):
                 and "out of memory" in str(e).lower()
             ):
                 self._clear_memory()
-                async for chunk in self.generate(text, voice, speed):
+                async for chunk in self.generate(text, voice, speed, lang_code):
                     yield chunk
             raise

@@ -236,9 +270,9 @@ class KokoroV1(BaseModelBackend):
         if self._model is not None:
             del self._model
             self._model = None
-        if self._pipeline is not None:
-            del self._pipeline
-            self._pipeline = None
+        for pipeline in self._pipelines.values():
+            del pipeline
+        self._pipelines.clear()
         if torch.cuda.is_available():
             torch.cuda.empty_cache()
             torch.cuda.synchronize()
@@ -246,7 +280,7 @@ class KokoroV1(BaseModelBackend):
     @property
     def is_loaded(self) -> bool:
         """Check if model is loaded."""
-        return self._model is not None and self._pipeline is not None
+        return self._model is not None

     @property
     def device(self) -> str:
api/src/inference/model_manager.py (path inferred):

@@ -72,7 +72,10 @@ class ModelManager:

             # Warm up with short text
             warmup_text = "Warmup text for initialization."
-            async for _ in self.generate(warmup_text, voice_path):
+            # Use default voice name for warmup
+            voice_name = settings.default_voice
+            logger.debug(f"Using default voice '{voice_name}' for warmup")
+            async for _ in self.generate(warmup_text, (voice_name, voice_path)):
                 pass
         except Exception as e:
             raise RuntimeError(f"Failed to get default voice: {e}")
api/src/main.py (path inferred):

@@ -28,6 +28,7 @@ def setup_logger():
         "sink": sys.stdout,
         "format": "<fg #2E8B57>{time:hh:mm:ss A}</fg #2E8B57> | "
         "{level: <8} | "
+        "<fg #4169E1>{module}:{line}</fg #4169E1> | "
         "{message}",
         "colorize": True,
         "level": "DEBUG",
@@ -88,6 +89,7 @@ async def lifespan(app: FastAPI):
     # Add web player info if enabled
     if settings.enable_web_player:
         startup_msg += f"\n\nBeta Web Player: http://{settings.host}:{settings.port}/web/"
+        startup_msg += f"\nor http://localhost:{settings.port}/web/"
     else:
         startup_msg += "\n\nWeb Player: disabled"
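The loguru format string gains a {module}:{line} segment, so every record now shows where it was emitted. An illustrative rendering of the new layout (the module name, line number, and message are invented for the example; terminal colors omitted):

02:14:07 PM | INFO     | tts_service:236 | Using lang_code 'e' for voice 'ef_dora' in audio stream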
api/src/routers/openai_compatible.py (path inferred):

@@ -130,11 +130,13 @@ async def stream_audio_chunks(
     voice_name = await process_voices(request.voice, tts_service)

     try:
+        logger.info(f"Starting audio generation with lang_code: {request.lang_code}")
         async for chunk in tts_service.generate_audio_stream(
             text=request.input,
             voice=voice_name,
             speed=request.speed,
             output_format=request.response_format,
+            lang_code=request.lang_code,
         ):
             # Check if client is still connected
             is_disconnected = client_request.is_disconnected
@@ -250,7 +252,8 @@ async def create_speech(
         audio, _ = await tts_service.generate_audio(
             text=request.input,
             voice=voice_name,
-            speed=request.speed
+            speed=request.speed,
+            lang_code=request.lang_code
         )

         # Convert to requested format with proper finalization
api/src/services/tts_service.py (path inferred):

@@ -51,6 +51,7 @@ class TTSService:
         is_first: bool = False,
         is_last: bool = False,
         normalizer: Optional[AudioNormalizer] = None,
+        lang_code: Optional[str] = None,
     ) -> AsyncGenerator[Union[np.ndarray, bytes], None]:
         """Process tokens into audio."""
         async with self._chunk_semaphore:
@@ -82,11 +83,12 @@ class TTSService:

                 # Generate audio using pre-warmed model
                 if isinstance(backend, KokoroV1):
-                    # For Kokoro V1, pass text and voice info
+                    # For Kokoro V1, pass text and voice info with lang_code
                     async for chunk_audio in self.model_manager.generate(
                         chunk_text,
                         (voice_name, voice_path),
-                        speed=speed
+                        speed=speed,
+                        lang_code=lang_code
                     ):
                         # For streaming, convert to bytes
                         if output_format:
@@ -217,6 +219,7 @@ class TTSService:
         voice: str,
         speed: float = 1.0,
         output_format: str = "wav",
+        lang_code: Optional[str] = None,
     ) -> AsyncGenerator[bytes, None]:
         """Generate and stream audio chunks."""
         stream_normalizer = AudioNormalizer()
@@ -230,6 +233,10 @@ class TTSService:
             voice_name, voice_path = await self._get_voice_path(voice)
             logger.debug(f"Using voice path: {voice_path}")

+            # Use provided lang_code or determine from voice name
+            pipeline_lang_code = lang_code if lang_code else voice[:1].lower()
+            logger.info(f"Using lang_code '{pipeline_lang_code}' for voice '{voice_name}' in audio stream")
+
             # Process text in chunks with smart splitting
             async for chunk_text, tokens in smart_split(text):
                 try:
@@ -243,7 +250,8 @@ class TTSService:
                         output_format,
                         is_first=(chunk_index == 0),
                         is_last=False,  # We'll update the last chunk later
-                        normalizer=stream_normalizer
+                        normalizer=stream_normalizer,
+                        lang_code=pipeline_lang_code  # Pass lang_code
                     ):
                         if result is not None:
                             yield result
@@ -268,7 +276,8 @@ class TTSService:
                     output_format,
                     is_first=False,
                     is_last=True,  # Signal this is the last chunk
-                    normalizer=stream_normalizer
+                    normalizer=stream_normalizer,
+                    lang_code=pipeline_lang_code  # Pass lang_code
                 ):
                     if result is not None:
                         yield result
@@ -280,7 +289,8 @@ class TTSService:
             raise

     async def generate_audio(
-        self, text: str, voice: str, speed: float = 1.0, return_timestamps: bool = False
+        self, text: str, voice: str, speed: float = 1.0, return_timestamps: bool = False,
+        lang_code: Optional[str] = None
     ) -> Union[Tuple[np.ndarray, float], Tuple[np.ndarray, float, List[dict]]]:
         """Generate complete audio for text using streaming internally."""
         start_time = time.time()
@@ -293,8 +303,12 @@ class TTSService:
             voice_name, voice_path = await self._get_voice_path(voice)

             if isinstance(backend, KokoroV1):
+                # Use provided lang_code or determine from voice name
+                pipeline_lang_code = lang_code if lang_code else voice[:1].lower()
+                logger.info(f"Using lang_code '{pipeline_lang_code}' for voice '{voice_name}' in text chunking")
+
                 # Initialize quiet pipeline for text chunking
-                quiet_pipeline = KPipeline(lang_code='a', model=False)
+                quiet_pipeline = KPipeline(lang_code=pipeline_lang_code, model=False)

                 # Split text into chunks and get initial tokens
                 text_chunks = []
@@ -310,12 +324,13 @@ class TTSService:
                 for chunk_idx, (chunk_text, chunk_phonemes) in enumerate(text_chunks):
                     logger.debug(f"Processing chunk {chunk_idx + 1}/{len(text_chunks)}: '{chunk_text[:50]}...'")

-                    # Generate audio and timestamps for this chunk
-                    for result in backend._pipeline(
+                    # Create a new pipeline with the lang_code
+                    generation_pipeline = KPipeline(lang_code=pipeline_lang_code, model=backend._model)
+                    logger.info(f"Using lang_code '{pipeline_lang_code}' for voice '{voice_name}' in generation pipeline")
+                    for result in generation_pipeline(
                         chunk_text,
                         voice=voice_path,
-                        speed=speed,
-                        model=backend._model
+                        speed=speed
                     ):
                         # Collect audio chunks
                         if result.audio is not None:
@@ -470,7 +485,8 @@ class TTSService:
         self,
         phonemes: str,
         voice: str,
-        speed: float = 1.0
+        speed: float = 1.0,
+        lang_code: Optional[str] = None
     ) -> Tuple[np.ndarray, float]:
         """Generate audio directly from phonemes.

@@ -478,6 +494,7 @@ class TTSService:
             phonemes: Phonemes in Kokoro format
             voice: Voice name
             speed: Speed multiplier
+            lang_code: Optional language code override

         Returns:
             Tuple of (audio array, processing time)
@@ -491,11 +508,16 @@ class TTSService:
             if isinstance(backend, KokoroV1):
                 # For Kokoro V1, use generate_from_tokens with raw phonemes
                 result = None
-                for r in backend._pipeline.generate_from_tokens(
+                # Use provided lang_code or determine from voice name
+                pipeline_lang_code = lang_code if lang_code else voice[:1].lower()
+                logger.info(f"Using lang_code '{pipeline_lang_code}' for voice '{voice_name}' in phoneme pipeline")
+
+                # Create a new pipeline with the lang_code
+                phoneme_pipeline = KPipeline(lang_code=pipeline_lang_code, model=backend._model)
+                for r in phoneme_pipeline.generate_from_tokens(
                     tokens=phonemes,  # Pass raw phonemes string
                     voice=voice_path,
-                    speed=speed,
-                    model=backend._model
+                    speed=speed
                 ):
                     if r.audio is not None:
                         result = r
api/src/structures/schemas.py (path inferred):

@@ -1,5 +1,5 @@
 from enum import Enum
-from typing import List, Literal, Union
+from typing import List, Literal, Union, Optional

 from pydantic import BaseModel, Field

@@ -62,6 +62,10 @@ class OpenAISpeechRequest(BaseModel):
         default=False,
         description="If true, returns a download link in X-Download-Path header after streaming completes",
     )
+    lang_code: Optional[str] = Field(
+        default=None,
+        description="Optional language code to use for text processing. If not provided, will use first letter of voice name.",
+    )

 class CaptionedSpeechRequest(BaseModel):
     """Request schema for captioned speech endpoint"""
@@ -88,3 +92,7 @@ class CaptionedSpeechRequest(BaseModel):
         default=True,
         description="If true (default), returns word-level timestamps in the response",
     )
+    lang_code: Optional[str] = Field(
+        default=None,
+        description="Optional language code to use for text processing. If not provided, will use first letter of voice name.",
+    )
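Because the field is Optional with default=None, existing clients that never send lang_code are unaffected. A short sketch of the validation behavior; a standalone Pydantic model is used here so the snippet runs without the repo (it is a stand-in for OpenAISpeechRequest, and the voice names are illustrative):

from typing import Optional
from pydantic import BaseModel, Field

class SpeechRequestSketch(BaseModel):  # stand-in for OpenAISpeechRequest
    input: str
    voice: str
    lang_code: Optional[str] = Field(default=None)

req = SpeechRequestSketch(input="Ciao a tutti", voice="if_sara", lang_code="i")
assert req.lang_code == "i"
req2 = SpeechRequestSketch(input="Hello", voice="af_heart")
assert req2.lang_code is None  # server then falls back to voice[:1].lower()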
api/tests/test_kokoro_v1.py (path inferred):

@@ -1,5 +1,5 @@
 import pytest
-from unittest.mock import patch, MagicMock
+from unittest.mock import patch, MagicMock, ANY
 import torch
 import numpy as np
 from api.src.inference.kokoro_v1 import KokoroV1
@@ -13,7 +13,7 @@ def test_initial_state(kokoro_backend):
     """Test initial state of KokoroV1."""
     assert not kokoro_backend.is_loaded
     assert kokoro_backend._model is None
-    assert kokoro_backend._pipeline is None
+    assert kokoro_backend._pipelines == {}  # Now using dict of pipelines
     # Device should be set based on settings
     assert kokoro_backend.device in ["cuda", "cpu"]

@@ -47,18 +47,23 @@ async def test_load_model_validation(kokoro_backend):
     with pytest.raises(RuntimeError, match="Failed to load Kokoro model"):
         await kokoro_backend.load_model("nonexistent_model.pth")

-def test_unload(kokoro_backend):
-    """Test model unloading."""
-    # Mock loaded state
+def test_unload_with_pipelines(kokoro_backend):
+    """Test model unloading with multiple pipelines."""
+    # Mock loaded state with multiple pipelines
     kokoro_backend._model = MagicMock()
-    kokoro_backend._pipeline = MagicMock()
+    pipeline_a = MagicMock()
+    pipeline_e = MagicMock()
+    kokoro_backend._pipelines = {
+        'a': pipeline_a,
+        'e': pipeline_e
+    }
     assert kokoro_backend.is_loaded

     # Test unload
     kokoro_backend.unload()
     assert not kokoro_backend.is_loaded
     assert kokoro_backend._model is None
-    assert kokoro_backend._pipeline is None
+    assert kokoro_backend._pipelines == {}  # All pipelines should be cleared

 @pytest.mark.asyncio
 async def test_generate_validation(kokoro_backend):
@@ -73,3 +78,77 @@ async def test_generate_from_tokens_validation(kokoro_backend):
     with pytest.raises(RuntimeError, match="Model not loaded"):
         async for _ in kokoro_backend.generate_from_tokens("test tokens", "voice"):
             pass
+
+def test_get_pipeline_creates_new(kokoro_backend):
+    """Test that _get_pipeline creates new pipeline for new language code."""
+    # Mock loaded state
+    kokoro_backend._model = MagicMock()
+
+    # Mock KPipeline
+    mock_pipeline = MagicMock()
+    with patch('api.src.inference.kokoro_v1.KPipeline', return_value=mock_pipeline) as mock_kpipeline:
+        # Get pipeline for Spanish
+        pipeline_e = kokoro_backend._get_pipeline('e')
+
+        # Should create new pipeline with correct params
+        mock_kpipeline.assert_called_once_with(
+            lang_code='e',
+            model=kokoro_backend._model,
+            device=kokoro_backend._device
+        )
+        assert pipeline_e == mock_pipeline
+        assert kokoro_backend._pipelines['e'] == mock_pipeline
+
+def test_get_pipeline_reuses_existing(kokoro_backend):
+    """Test that _get_pipeline reuses existing pipeline for same language code."""
+    # Mock loaded state
+    kokoro_backend._model = MagicMock()
+
+    # Mock KPipeline
+    mock_pipeline = MagicMock()
+    with patch('api.src.inference.kokoro_v1.KPipeline', return_value=mock_pipeline) as mock_kpipeline:
+        # Get pipeline twice for same language
+        pipeline1 = kokoro_backend._get_pipeline('e')
+        pipeline2 = kokoro_backend._get_pipeline('e')
+
+        # Should only create pipeline once
+        mock_kpipeline.assert_called_once()
+        assert pipeline1 == pipeline2
+        assert kokoro_backend._pipelines['e'] == mock_pipeline
+
+@pytest.mark.asyncio
+async def test_generate_uses_correct_pipeline(kokoro_backend):
+    """Test that generate uses correct pipeline for language code."""
+    # Mock loaded state
+    kokoro_backend._model = MagicMock()
+
+    # Mock voice path handling
+    with patch('api.src.core.paths.load_voice_tensor') as mock_load_voice, \
+         patch('api.src.core.paths.save_voice_tensor'), \
+         patch('tempfile.gettempdir') as mock_tempdir:
+        mock_load_voice.return_value = torch.ones(1)
+        mock_tempdir.return_value = "/tmp"
+
+        # Mock KPipeline
+        mock_pipeline = MagicMock()
+        mock_pipeline.return_value = iter([])  # Empty generator for testing
+        with patch('api.src.inference.kokoro_v1.KPipeline', return_value=mock_pipeline):
+            # Generate with Spanish voice and explicit lang_code
+            async for _ in kokoro_backend.generate(
+                "test", "ef_voice", lang_code='e'
+            ):
+                pass
+
+            # Should create pipeline with Spanish lang_code
+            assert 'e' in kokoro_backend._pipelines
+            # Use ANY to match the temp file path since it's dynamic
+            mock_pipeline.assert_called_with(
+                "test",
+                voice=ANY,  # Don't check exact path since it's dynamic
+                speed=1.0,
+                model=kokoro_backend._model
+            )
+            # Verify the voice path is a temp file path
+            call_args = mock_pipeline.call_args
+            assert isinstance(call_args[1]['voice'], str)
+            assert call_args[1]['voice'].startswith("/tmp/temp_voice_")
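The new tests can be run on their own; a sketch assuming the inferred test file path and a standard pytest setup:

import pytest

# Select just the tests added or renamed in this commit.
pytest.main([
    "api/tests/test_kokoro_v1.py",
    "-k", "pipeline or unload_with_pipelines",
    "-v",
])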
Dockerfile (path inferred):

@@ -14,7 +14,7 @@ RUN apt-get update && apt-get install -y \
     ffmpeg \
     && apt-get clean && rm -rf /var/lib/apt/lists/* \
     && mkdir -p /usr/share/espeak-ng-data \
-    && ln -s /usr/lib/x86_64-linux-gnu/espeak-ng-data/* /usr/share/espeak-ng-data/
+    && ln -s /usr/lib/*/espeak-ng-data/* /usr/share/espeak-ng-data/

 # Install UV using the installer script
 RUN curl -LsSf https://astral.sh/uv/install.sh | sh && \
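The symlink change swaps the hard-coded x86_64 multiarch triplet for a glob, so the same Dockerfile works on arm64 base images too. A sketch of what the glob resolves to (the paths shown are illustrative of Debian/Ubuntu multiarch layouts):

import glob

# On amd64:  ['/usr/lib/x86_64-linux-gnu/espeak-ng-data']
# On arm64:  ['/usr/lib/aarch64-linux-gnu/espeak-ng-data']
print(glob.glob("/usr/lib/*/espeak-ng-data"))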
web/index.html (path inferred):

@@ -94,6 +94,19 @@
                 <label for="speed-slider">Speed: <span id="speed-value">1.0</span>x</label>
                 <input type="range" id="speed-slider" min="0.1" max="4" step="0.1" value="1.0">
             </div>
+            <div class="lang-control">
+                <label for="lang-select">Language:</label>
+                <select id="lang-select" class="lang-select">
+                    <option value="">Auto</option>
+                    <option value="e">Spanish</option>
+                    <option value="a">English</option>
+                    <option value="f">French</option>
+                    <option value="i">Italian</option>
+                    <option value="p">Portuguese</option>
+                    <option value="j">Japanese</option>
+                    <option value="z">Chinese</option>
+                </select>
+            </div>
         </div>
         <div class="button-group">
             <button id="generate-btn">
web/src/services/AudioService.js (path inferred):

@@ -42,7 +42,8 @@ export class AudioService {
                     response_format: 'mp3',
                     stream: true,
                     speed: speed,
-                    return_download_link: true
+                    return_download_link: true,
+                    lang_code: document.getElementById('lang-select').value || undefined
                 }),
                 signal: this.controller.signal
             });
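Note the `value || undefined` on the client: when "Auto" (the empty-string option) is selected, lang_code becomes undefined and JSON.stringify drops the key entirely, so the server-side voice-prefix fallback applies. The behavior the server ends up seeing, sketched in Python (payload fields are illustrative):

import json

def build_payload(lang_select_value: str) -> str:
    payload = {"input": "Hello", "voice": "af_heart", "speed": 1.0}
    if lang_select_value:  # mirrors the client's `value || undefined`
        payload["lang_code"] = lang_select_value
    return json.dumps(payload)

assert "lang_code" not in build_payload("")      # Auto: field omitted entirely
assert '"lang_code": "j"' in build_payload("j")  # explicit Japanese override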
Web stylesheet (path inferred):

@@ -252,6 +252,49 @@
     transform: scale(1.1);
 }

+/* Language Control */
+.lang-control {
+    display: flex;
+    flex-direction: column;
+    gap: 0.75rem;
+    padding: 0.75rem;
+    background: rgba(15, 23, 42, 0.3);
+    border: 1px solid var(--border);
+    border-radius: 0.5rem;
+}
+
+.lang-control label {
+    color: var(--text-light);
+    font-size: 0.875rem;
+}
+
+.lang-select {
+    background: rgba(15, 23, 42, 0.3);
+    color: var(--text);
+    border: 1px solid var(--border);
+    border-radius: 0.5rem;
+    padding: 0.5rem 1rem;
+    font-family: var(--font-family);
+    font-size: 0.875rem;
+    cursor: pointer;
+    transition: all 0.2s ease;
+}
+
+.lang-select:hover {
+    border-color: var(--fg-color);
+}
+
+.lang-select:focus {
+    outline: none;
+    border-color: var(--fg-color);
+    box-shadow: 0 0 0 3px rgba(99, 102, 241, 0.2);
+}
+
+.lang-select option {
+    background: var(--surface);
+    color: var(--text);
+}
+
 /* Generation Controls */
 .button-group {
     display: flex;