mirror of
https://github.com/remsky/Kokoro-FastAPI.git
synced 2025-04-13 09:39:17 +00:00
refactor: streamline audio normalization process and update tests
This commit is contained in:
parent
d2522bcb92
commit
387653050b
9 changed files with 44 additions and 103 deletions
|
@ -6,6 +6,5 @@ select = ["I"]
|
|||
[lint.isort]
|
||||
combine-as-imports = true
|
||||
force-wrap-aliases = true
|
||||
length-sort = true
|
||||
split-on-trailing-comma = true
|
||||
section-order = ["future", "standard-library", "third-party", "first-party", "local-folder"]
|
||||
|
|
|
@ -22,20 +22,16 @@ class AudioNormalizer:
|
|||
def normalize(
|
||||
self, audio_data: np.ndarray, is_last_chunk: bool = False
|
||||
) -> np.ndarray:
|
||||
"""Normalize audio data to int16 range and trim chunk boundaries"""
|
||||
# Convert to float32 if not already
|
||||
"""Convert audio data to int16 range and trim chunk boundaries"""
|
||||
# Simple float32 to int16 conversion
|
||||
audio_float = audio_data.astype(np.float32)
|
||||
|
||||
# Normalize to [-1, 1] range first
|
||||
if np.max(np.abs(audio_float)) > 0:
|
||||
audio_float = audio_float / np.max(np.abs(audio_float))
|
||||
|
||||
# Trim end of non-final chunks to reduce gaps
|
||||
|
||||
# Trim for non-final chunks
|
||||
if not is_last_chunk and len(audio_float) > self.samples_to_trim:
|
||||
audio_float = audio_float[: -self.samples_to_trim]
|
||||
|
||||
# Scale to int16 range
|
||||
return (audio_float * self.int16_max).astype(np.int16)
|
||||
audio_float = audio_float[:-self.samples_to_trim]
|
||||
|
||||
# Direct scaling like the non-streaming version
|
||||
return (audio_float * 32767).astype(np.int16)
|
||||
|
||||
|
||||
class AudioService:
|
||||
|
|
BIN
api/src/voices/af_irulan.pt
Normal file
BIN
api/src/voices/af_irulan.pt
Normal file
Binary file not shown.
|
@ -32,77 +32,7 @@ def cleanup():
|
|||
cleanup_mock_dirs()
|
||||
|
||||
|
||||
# Create mock torch module
|
||||
mock_torch = Mock()
|
||||
mock_torch.cuda = Mock()
|
||||
mock_torch.cuda.is_available = Mock(return_value=False)
|
||||
|
||||
|
||||
# Create a mock tensor class that supports basic operations
|
||||
class MockTensor:
|
||||
def __init__(self, data):
|
||||
self.data = data
|
||||
if isinstance(data, (list, tuple)):
|
||||
self.shape = [len(data)]
|
||||
elif isinstance(data, MockTensor):
|
||||
self.shape = data.shape
|
||||
else:
|
||||
self.shape = getattr(data, "shape", [1])
|
||||
|
||||
def __getitem__(self, idx):
|
||||
if isinstance(self.data, (list, tuple)):
|
||||
if isinstance(idx, slice):
|
||||
return MockTensor(self.data[idx])
|
||||
return self.data[idx]
|
||||
return self
|
||||
|
||||
def max(self):
|
||||
if isinstance(self.data, (list, tuple)):
|
||||
max_val = max(self.data)
|
||||
return MockTensor(max_val)
|
||||
return 5 # Default for testing
|
||||
|
||||
def item(self):
|
||||
if isinstance(self.data, (list, tuple)):
|
||||
return max(self.data)
|
||||
if isinstance(self.data, (int, float)):
|
||||
return self.data
|
||||
return 5 # Default for testing
|
||||
|
||||
def cuda(self):
|
||||
"""Support cuda conversion"""
|
||||
return self
|
||||
|
||||
def any(self):
|
||||
if isinstance(self.data, (list, tuple)):
|
||||
return any(self.data)
|
||||
return False
|
||||
|
||||
def all(self):
|
||||
if isinstance(self.data, (list, tuple)):
|
||||
return all(self.data)
|
||||
return True
|
||||
|
||||
def unsqueeze(self, dim):
|
||||
return self
|
||||
|
||||
def expand(self, *args):
|
||||
return self
|
||||
|
||||
def type_as(self, other):
|
||||
return self
|
||||
|
||||
|
||||
# Add tensor operations to mock torch
|
||||
mock_torch.tensor = lambda x: MockTensor(x)
|
||||
mock_torch.zeros = lambda *args: MockTensor(
|
||||
[0] * (args[0] if isinstance(args[0], int) else args[0][0])
|
||||
)
|
||||
mock_torch.arange = lambda x: MockTensor(list(range(x)))
|
||||
mock_torch.gt = lambda x, y: MockTensor([False] * x.shape[0])
|
||||
|
||||
# Mock modules before they're imported
|
||||
sys.modules["torch"] = mock_torch
|
||||
sys.modules["transformers"] = Mock()
|
||||
sys.modules["phonemizer"] = Mock()
|
||||
sys.modules["models"] = Mock()
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
"""Tests for TTS model implementations"""
|
||||
|
||||
import os
|
||||
from unittest.mock import MagicMock, patch
|
||||
from unittest.mock import MagicMock, patch, AsyncMock
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
@ -27,16 +27,30 @@ def test_get_device_error():
|
|||
@patch("os.listdir")
|
||||
@patch("torch.load")
|
||||
@patch("torch.save")
|
||||
@patch("api.src.services.tts_base.settings")
|
||||
@patch("api.src.services.warmup.WarmupService")
|
||||
async def test_setup_cuda_available(
|
||||
mock_save, mock_load, mock_listdir, mock_join, mock_exists, mock_cuda_available
|
||||
mock_warmup_class, mock_settings, mock_save, mock_load, mock_listdir, mock_join, mock_exists, mock_cuda_available
|
||||
):
|
||||
"""Test setup with CUDA available"""
|
||||
TTSBaseModel._device = None
|
||||
mock_cuda_available.return_value = True
|
||||
# Mock CUDA as unavailable since we're using CPU PyTorch
|
||||
mock_cuda_available.return_value = False
|
||||
mock_exists.return_value = True
|
||||
mock_load.return_value = torch.zeros(1)
|
||||
mock_listdir.return_value = ["voice1.pt", "voice2.pt"]
|
||||
mock_join.return_value = "/mocked/path"
|
||||
|
||||
# Configure mock settings
|
||||
mock_settings.model_dir = "/mock/model/dir"
|
||||
mock_settings.onnx_model_path = "model.onnx"
|
||||
mock_settings.voices_dir = "voices"
|
||||
|
||||
# Configure mock warmup service
|
||||
mock_warmup = MagicMock()
|
||||
mock_warmup.load_voices.return_value = [torch.zeros(1)]
|
||||
mock_warmup.warmup_voices = AsyncMock()
|
||||
mock_warmup_class.return_value = mock_warmup
|
||||
|
||||
# Create mock model
|
||||
mock_model = MagicMock()
|
||||
|
@ -49,7 +63,7 @@ async def test_setup_cuda_available(
|
|||
TTSBaseModel._instance = mock_model
|
||||
|
||||
voice_count = await TTSBaseModel.setup()
|
||||
assert TTSBaseModel._device == "cuda"
|
||||
assert TTSBaseModel._device == "cpu"
|
||||
assert voice_count == 2
|
||||
|
||||
|
||||
|
@ -60,8 +74,10 @@ async def test_setup_cuda_available(
|
|||
@patch("os.listdir")
|
||||
@patch("torch.load")
|
||||
@patch("torch.save")
|
||||
@patch("api.src.services.tts_base.settings")
|
||||
@patch("api.src.services.warmup.WarmupService")
|
||||
async def test_setup_cuda_unavailable(
|
||||
mock_save, mock_load, mock_listdir, mock_join, mock_exists, mock_cuda_available
|
||||
mock_warmup_class, mock_settings, mock_save, mock_load, mock_listdir, mock_join, mock_exists, mock_cuda_available
|
||||
):
|
||||
"""Test setup with CUDA unavailable"""
|
||||
TTSBaseModel._device = None
|
||||
|
@ -70,6 +86,17 @@ async def test_setup_cuda_unavailable(
|
|||
mock_load.return_value = torch.zeros(1)
|
||||
mock_listdir.return_value = ["voice1.pt", "voice2.pt"]
|
||||
mock_join.return_value = "/mocked/path"
|
||||
|
||||
# Configure mock settings
|
||||
mock_settings.model_dir = "/mock/model/dir"
|
||||
mock_settings.onnx_model_path = "model.onnx"
|
||||
mock_settings.voices_dir = "voices"
|
||||
|
||||
# Configure mock warmup service
|
||||
mock_warmup = MagicMock()
|
||||
mock_warmup.load_voices.return_value = [torch.zeros(1)]
|
||||
mock_warmup.warmup_voices = AsyncMock()
|
||||
mock_warmup_class.return_value = mock_warmup
|
||||
|
||||
# Create mock model
|
||||
mock_model = MagicMock()
|
||||
|
|
|
@ -8,7 +8,7 @@ import requests
|
|||
import sounddevice as sd
|
||||
|
||||
|
||||
def play_streaming_tts(text: str, output_file: str = None, voice: str = "af"):
|
||||
def play_streaming_tts(text: str, output_file: str = None, voice: str = "af_sky"):
|
||||
"""Stream TTS audio and play it back in real-time"""
|
||||
|
||||
print("\nStarting TTS stream request...")
|
||||
|
|
|
@ -46,6 +46,7 @@ test = [
|
|||
"httpx==0.26.0",
|
||||
"pytest-asyncio==0.23.5",
|
||||
"gradio>=5",
|
||||
"openai>=1.59.6",
|
||||
]
|
||||
|
||||
[tool.uv]
|
||||
|
|
|
@ -1,14 +0,0 @@
|
|||
# Core dependencies for testing
|
||||
fastapi==0.115.6
|
||||
uvicorn==0.34.0
|
||||
pydantic==2.10.4
|
||||
pydantic-settings==2.7.0
|
||||
python-dotenv==1.0.1
|
||||
sqlalchemy==2.0.27
|
||||
|
||||
# Testing
|
||||
pytest==8.0.0
|
||||
httpx==0.26.0
|
||||
pytest-asyncio==0.23.5
|
||||
pytest-cov==6.0.0
|
||||
gradio==4.19.2
|
2
uv.lock
generated
2
uv.lock
generated
|
@ -802,6 +802,7 @@ gpu = [
|
|||
test = [
|
||||
{ name = "gradio" },
|
||||
{ name = "httpx" },
|
||||
{ name = "openai" },
|
||||
{ name = "pytest" },
|
||||
{ name = "pytest-asyncio" },
|
||||
{ name = "pytest-cov" },
|
||||
|
@ -819,6 +820,7 @@ requires-dist = [
|
|||
{ name = "numpy", specifier = ">=1.26.0" },
|
||||
{ name = "onnxruntime", specifier = "==1.20.1" },
|
||||
{ name = "openai", specifier = ">=1.59.6" },
|
||||
{ name = "openai", marker = "extra == 'test'", specifier = ">=1.59.6" },
|
||||
{ name = "phonemizer", specifier = "==3.3.0" },
|
||||
{ name = "pydantic", specifier = "==2.10.4" },
|
||||
{ name = "pydantic-settings", specifier = "==2.7.0" },
|
||||
|
|
Loading…
Add table
Reference in a new issue