mirror of
https://github.com/remsky/Kokoro-FastAPI.git
synced 2025-04-13 09:39:17 +00:00
refactor: streamline audio normalization process and update tests
This commit is contained in:
parent
d2522bcb92
commit
387653050b
9 changed files with 44 additions and 103 deletions
|
@ -6,6 +6,5 @@ select = ["I"]
|
||||||
[lint.isort]
|
[lint.isort]
|
||||||
combine-as-imports = true
|
combine-as-imports = true
|
||||||
force-wrap-aliases = true
|
force-wrap-aliases = true
|
||||||
length-sort = true
|
|
||||||
split-on-trailing-comma = true
|
split-on-trailing-comma = true
|
||||||
section-order = ["future", "standard-library", "third-party", "first-party", "local-folder"]
|
section-order = ["future", "standard-library", "third-party", "first-party", "local-folder"]
|
||||||
|
|
|
@ -22,20 +22,16 @@ class AudioNormalizer:
|
||||||
def normalize(
|
def normalize(
|
||||||
self, audio_data: np.ndarray, is_last_chunk: bool = False
|
self, audio_data: np.ndarray, is_last_chunk: bool = False
|
||||||
) -> np.ndarray:
|
) -> np.ndarray:
|
||||||
"""Normalize audio data to int16 range and trim chunk boundaries"""
|
"""Convert audio data to int16 range and trim chunk boundaries"""
|
||||||
# Convert to float32 if not already
|
# Simple float32 to int16 conversion
|
||||||
audio_float = audio_data.astype(np.float32)
|
audio_float = audio_data.astype(np.float32)
|
||||||
|
|
||||||
# Normalize to [-1, 1] range first
|
# Trim for non-final chunks
|
||||||
if np.max(np.abs(audio_float)) > 0:
|
|
||||||
audio_float = audio_float / np.max(np.abs(audio_float))
|
|
||||||
|
|
||||||
# Trim end of non-final chunks to reduce gaps
|
|
||||||
if not is_last_chunk and len(audio_float) > self.samples_to_trim:
|
if not is_last_chunk and len(audio_float) > self.samples_to_trim:
|
||||||
audio_float = audio_float[: -self.samples_to_trim]
|
audio_float = audio_float[:-self.samples_to_trim]
|
||||||
|
|
||||||
# Scale to int16 range
|
# Direct scaling like the non-streaming version
|
||||||
return (audio_float * self.int16_max).astype(np.int16)
|
return (audio_float * 32767).astype(np.int16)
|
||||||
|
|
||||||
|
|
||||||
class AudioService:
|
class AudioService:
|
||||||
|
|
BIN
api/src/voices/af_irulan.pt
Normal file
BIN
api/src/voices/af_irulan.pt
Normal file
Binary file not shown.
|
@ -32,77 +32,7 @@ def cleanup():
|
||||||
cleanup_mock_dirs()
|
cleanup_mock_dirs()
|
||||||
|
|
||||||
|
|
||||||
# Create mock torch module
|
|
||||||
mock_torch = Mock()
|
|
||||||
mock_torch.cuda = Mock()
|
|
||||||
mock_torch.cuda.is_available = Mock(return_value=False)
|
|
||||||
|
|
||||||
|
|
||||||
# Create a mock tensor class that supports basic operations
|
|
||||||
class MockTensor:
|
|
||||||
def __init__(self, data):
|
|
||||||
self.data = data
|
|
||||||
if isinstance(data, (list, tuple)):
|
|
||||||
self.shape = [len(data)]
|
|
||||||
elif isinstance(data, MockTensor):
|
|
||||||
self.shape = data.shape
|
|
||||||
else:
|
|
||||||
self.shape = getattr(data, "shape", [1])
|
|
||||||
|
|
||||||
def __getitem__(self, idx):
|
|
||||||
if isinstance(self.data, (list, tuple)):
|
|
||||||
if isinstance(idx, slice):
|
|
||||||
return MockTensor(self.data[idx])
|
|
||||||
return self.data[idx]
|
|
||||||
return self
|
|
||||||
|
|
||||||
def max(self):
|
|
||||||
if isinstance(self.data, (list, tuple)):
|
|
||||||
max_val = max(self.data)
|
|
||||||
return MockTensor(max_val)
|
|
||||||
return 5 # Default for testing
|
|
||||||
|
|
||||||
def item(self):
|
|
||||||
if isinstance(self.data, (list, tuple)):
|
|
||||||
return max(self.data)
|
|
||||||
if isinstance(self.data, (int, float)):
|
|
||||||
return self.data
|
|
||||||
return 5 # Default for testing
|
|
||||||
|
|
||||||
def cuda(self):
|
|
||||||
"""Support cuda conversion"""
|
|
||||||
return self
|
|
||||||
|
|
||||||
def any(self):
|
|
||||||
if isinstance(self.data, (list, tuple)):
|
|
||||||
return any(self.data)
|
|
||||||
return False
|
|
||||||
|
|
||||||
def all(self):
|
|
||||||
if isinstance(self.data, (list, tuple)):
|
|
||||||
return all(self.data)
|
|
||||||
return True
|
|
||||||
|
|
||||||
def unsqueeze(self, dim):
|
|
||||||
return self
|
|
||||||
|
|
||||||
def expand(self, *args):
|
|
||||||
return self
|
|
||||||
|
|
||||||
def type_as(self, other):
|
|
||||||
return self
|
|
||||||
|
|
||||||
|
|
||||||
# Add tensor operations to mock torch
|
|
||||||
mock_torch.tensor = lambda x: MockTensor(x)
|
|
||||||
mock_torch.zeros = lambda *args: MockTensor(
|
|
||||||
[0] * (args[0] if isinstance(args[0], int) else args[0][0])
|
|
||||||
)
|
|
||||||
mock_torch.arange = lambda x: MockTensor(list(range(x)))
|
|
||||||
mock_torch.gt = lambda x, y: MockTensor([False] * x.shape[0])
|
|
||||||
|
|
||||||
# Mock modules before they're imported
|
# Mock modules before they're imported
|
||||||
sys.modules["torch"] = mock_torch
|
|
||||||
sys.modules["transformers"] = Mock()
|
sys.modules["transformers"] = Mock()
|
||||||
sys.modules["phonemizer"] = Mock()
|
sys.modules["phonemizer"] = Mock()
|
||||||
sys.modules["models"] = Mock()
|
sys.modules["models"] = Mock()
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
"""Tests for TTS model implementations"""
|
"""Tests for TTS model implementations"""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
from unittest.mock import MagicMock, patch
|
from unittest.mock import MagicMock, patch, AsyncMock
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
@ -27,16 +27,30 @@ def test_get_device_error():
|
||||||
@patch("os.listdir")
|
@patch("os.listdir")
|
||||||
@patch("torch.load")
|
@patch("torch.load")
|
||||||
@patch("torch.save")
|
@patch("torch.save")
|
||||||
|
@patch("api.src.services.tts_base.settings")
|
||||||
|
@patch("api.src.services.warmup.WarmupService")
|
||||||
async def test_setup_cuda_available(
|
async def test_setup_cuda_available(
|
||||||
mock_save, mock_load, mock_listdir, mock_join, mock_exists, mock_cuda_available
|
mock_warmup_class, mock_settings, mock_save, mock_load, mock_listdir, mock_join, mock_exists, mock_cuda_available
|
||||||
):
|
):
|
||||||
"""Test setup with CUDA available"""
|
"""Test setup with CUDA available"""
|
||||||
TTSBaseModel._device = None
|
TTSBaseModel._device = None
|
||||||
mock_cuda_available.return_value = True
|
# Mock CUDA as unavailable since we're using CPU PyTorch
|
||||||
|
mock_cuda_available.return_value = False
|
||||||
mock_exists.return_value = True
|
mock_exists.return_value = True
|
||||||
mock_load.return_value = torch.zeros(1)
|
mock_load.return_value = torch.zeros(1)
|
||||||
mock_listdir.return_value = ["voice1.pt", "voice2.pt"]
|
mock_listdir.return_value = ["voice1.pt", "voice2.pt"]
|
||||||
mock_join.return_value = "/mocked/path"
|
mock_join.return_value = "/mocked/path"
|
||||||
|
|
||||||
|
# Configure mock settings
|
||||||
|
mock_settings.model_dir = "/mock/model/dir"
|
||||||
|
mock_settings.onnx_model_path = "model.onnx"
|
||||||
|
mock_settings.voices_dir = "voices"
|
||||||
|
|
||||||
|
# Configure mock warmup service
|
||||||
|
mock_warmup = MagicMock()
|
||||||
|
mock_warmup.load_voices.return_value = [torch.zeros(1)]
|
||||||
|
mock_warmup.warmup_voices = AsyncMock()
|
||||||
|
mock_warmup_class.return_value = mock_warmup
|
||||||
|
|
||||||
# Create mock model
|
# Create mock model
|
||||||
mock_model = MagicMock()
|
mock_model = MagicMock()
|
||||||
|
@ -49,7 +63,7 @@ async def test_setup_cuda_available(
|
||||||
TTSBaseModel._instance = mock_model
|
TTSBaseModel._instance = mock_model
|
||||||
|
|
||||||
voice_count = await TTSBaseModel.setup()
|
voice_count = await TTSBaseModel.setup()
|
||||||
assert TTSBaseModel._device == "cuda"
|
assert TTSBaseModel._device == "cpu"
|
||||||
assert voice_count == 2
|
assert voice_count == 2
|
||||||
|
|
||||||
|
|
||||||
|
@ -60,8 +74,10 @@ async def test_setup_cuda_available(
|
||||||
@patch("os.listdir")
|
@patch("os.listdir")
|
||||||
@patch("torch.load")
|
@patch("torch.load")
|
||||||
@patch("torch.save")
|
@patch("torch.save")
|
||||||
|
@patch("api.src.services.tts_base.settings")
|
||||||
|
@patch("api.src.services.warmup.WarmupService")
|
||||||
async def test_setup_cuda_unavailable(
|
async def test_setup_cuda_unavailable(
|
||||||
mock_save, mock_load, mock_listdir, mock_join, mock_exists, mock_cuda_available
|
mock_warmup_class, mock_settings, mock_save, mock_load, mock_listdir, mock_join, mock_exists, mock_cuda_available
|
||||||
):
|
):
|
||||||
"""Test setup with CUDA unavailable"""
|
"""Test setup with CUDA unavailable"""
|
||||||
TTSBaseModel._device = None
|
TTSBaseModel._device = None
|
||||||
|
@ -70,6 +86,17 @@ async def test_setup_cuda_unavailable(
|
||||||
mock_load.return_value = torch.zeros(1)
|
mock_load.return_value = torch.zeros(1)
|
||||||
mock_listdir.return_value = ["voice1.pt", "voice2.pt"]
|
mock_listdir.return_value = ["voice1.pt", "voice2.pt"]
|
||||||
mock_join.return_value = "/mocked/path"
|
mock_join.return_value = "/mocked/path"
|
||||||
|
|
||||||
|
# Configure mock settings
|
||||||
|
mock_settings.model_dir = "/mock/model/dir"
|
||||||
|
mock_settings.onnx_model_path = "model.onnx"
|
||||||
|
mock_settings.voices_dir = "voices"
|
||||||
|
|
||||||
|
# Configure mock warmup service
|
||||||
|
mock_warmup = MagicMock()
|
||||||
|
mock_warmup.load_voices.return_value = [torch.zeros(1)]
|
||||||
|
mock_warmup.warmup_voices = AsyncMock()
|
||||||
|
mock_warmup_class.return_value = mock_warmup
|
||||||
|
|
||||||
# Create mock model
|
# Create mock model
|
||||||
mock_model = MagicMock()
|
mock_model = MagicMock()
|
||||||
|
|
|
@ -8,7 +8,7 @@ import requests
|
||||||
import sounddevice as sd
|
import sounddevice as sd
|
||||||
|
|
||||||
|
|
||||||
def play_streaming_tts(text: str, output_file: str = None, voice: str = "af"):
|
def play_streaming_tts(text: str, output_file: str = None, voice: str = "af_sky"):
|
||||||
"""Stream TTS audio and play it back in real-time"""
|
"""Stream TTS audio and play it back in real-time"""
|
||||||
|
|
||||||
print("\nStarting TTS stream request...")
|
print("\nStarting TTS stream request...")
|
||||||
|
|
|
@ -46,6 +46,7 @@ test = [
|
||||||
"httpx==0.26.0",
|
"httpx==0.26.0",
|
||||||
"pytest-asyncio==0.23.5",
|
"pytest-asyncio==0.23.5",
|
||||||
"gradio>=5",
|
"gradio>=5",
|
||||||
|
"openai>=1.59.6",
|
||||||
]
|
]
|
||||||
|
|
||||||
[tool.uv]
|
[tool.uv]
|
||||||
|
|
|
@ -1,14 +0,0 @@
|
||||||
# Core dependencies for testing
|
|
||||||
fastapi==0.115.6
|
|
||||||
uvicorn==0.34.0
|
|
||||||
pydantic==2.10.4
|
|
||||||
pydantic-settings==2.7.0
|
|
||||||
python-dotenv==1.0.1
|
|
||||||
sqlalchemy==2.0.27
|
|
||||||
|
|
||||||
# Testing
|
|
||||||
pytest==8.0.0
|
|
||||||
httpx==0.26.0
|
|
||||||
pytest-asyncio==0.23.5
|
|
||||||
pytest-cov==6.0.0
|
|
||||||
gradio==4.19.2
|
|
2
uv.lock
generated
2
uv.lock
generated
|
@ -802,6 +802,7 @@ gpu = [
|
||||||
test = [
|
test = [
|
||||||
{ name = "gradio" },
|
{ name = "gradio" },
|
||||||
{ name = "httpx" },
|
{ name = "httpx" },
|
||||||
|
{ name = "openai" },
|
||||||
{ name = "pytest" },
|
{ name = "pytest" },
|
||||||
{ name = "pytest-asyncio" },
|
{ name = "pytest-asyncio" },
|
||||||
{ name = "pytest-cov" },
|
{ name = "pytest-cov" },
|
||||||
|
@ -819,6 +820,7 @@ requires-dist = [
|
||||||
{ name = "numpy", specifier = ">=1.26.0" },
|
{ name = "numpy", specifier = ">=1.26.0" },
|
||||||
{ name = "onnxruntime", specifier = "==1.20.1" },
|
{ name = "onnxruntime", specifier = "==1.20.1" },
|
||||||
{ name = "openai", specifier = ">=1.59.6" },
|
{ name = "openai", specifier = ">=1.59.6" },
|
||||||
|
{ name = "openai", marker = "extra == 'test'", specifier = ">=1.59.6" },
|
||||||
{ name = "phonemizer", specifier = "==3.3.0" },
|
{ name = "phonemizer", specifier = "==3.3.0" },
|
||||||
{ name = "pydantic", specifier = "==2.10.4" },
|
{ name = "pydantic", specifier = "==2.10.4" },
|
||||||
{ name = "pydantic-settings", specifier = "==2.7.0" },
|
{ name = "pydantic-settings", specifier = "==2.7.0" },
|
||||||
|
|
Loading…
Add table
Reference in a new issue