refactor: streamline audio normalization process and update tests

This commit is contained in:
remsky 2025-01-13 18:56:49 -07:00
parent d2522bcb92
commit 387653050b
9 changed files with 44 additions and 103 deletions

View file

@ -6,6 +6,5 @@ select = ["I"]
[lint.isort]
combine-as-imports = true
force-wrap-aliases = true
length-sort = true
split-on-trailing-comma = true
section-order = ["future", "standard-library", "third-party", "first-party", "local-folder"]

View file

@ -22,20 +22,16 @@ class AudioNormalizer:
def normalize(
self, audio_data: np.ndarray, is_last_chunk: bool = False
) -> np.ndarray:
"""Normalize audio data to int16 range and trim chunk boundaries"""
# Convert to float32 if not already
"""Convert audio data to int16 range and trim chunk boundaries"""
# Simple float32 to int16 conversion
audio_float = audio_data.astype(np.float32)
# Normalize to [-1, 1] range first
if np.max(np.abs(audio_float)) > 0:
audio_float = audio_float / np.max(np.abs(audio_float))
# Trim end of non-final chunks to reduce gaps
# Trim for non-final chunks
if not is_last_chunk and len(audio_float) > self.samples_to_trim:
audio_float = audio_float[: -self.samples_to_trim]
# Scale to int16 range
return (audio_float * self.int16_max).astype(np.int16)
audio_float = audio_float[:-self.samples_to_trim]
# Direct scaling like the non-streaming version
return (audio_float * 32767).astype(np.int16)
class AudioService:

BIN
api/src/voices/af_irulan.pt Normal file

Binary file not shown.

View file

@ -32,77 +32,7 @@ def cleanup():
cleanup_mock_dirs()
# Create mock torch module
mock_torch = Mock()
mock_torch.cuda = Mock()
mock_torch.cuda.is_available = Mock(return_value=False)
# Create a mock tensor class that supports basic operations
class MockTensor:
def __init__(self, data):
self.data = data
if isinstance(data, (list, tuple)):
self.shape = [len(data)]
elif isinstance(data, MockTensor):
self.shape = data.shape
else:
self.shape = getattr(data, "shape", [1])
def __getitem__(self, idx):
if isinstance(self.data, (list, tuple)):
if isinstance(idx, slice):
return MockTensor(self.data[idx])
return self.data[idx]
return self
def max(self):
if isinstance(self.data, (list, tuple)):
max_val = max(self.data)
return MockTensor(max_val)
return 5 # Default for testing
def item(self):
if isinstance(self.data, (list, tuple)):
return max(self.data)
if isinstance(self.data, (int, float)):
return self.data
return 5 # Default for testing
def cuda(self):
"""Support cuda conversion"""
return self
def any(self):
if isinstance(self.data, (list, tuple)):
return any(self.data)
return False
def all(self):
if isinstance(self.data, (list, tuple)):
return all(self.data)
return True
def unsqueeze(self, dim):
return self
def expand(self, *args):
return self
def type_as(self, other):
return self
# Add tensor operations to mock torch
mock_torch.tensor = lambda x: MockTensor(x)
mock_torch.zeros = lambda *args: MockTensor(
[0] * (args[0] if isinstance(args[0], int) else args[0][0])
)
mock_torch.arange = lambda x: MockTensor(list(range(x)))
mock_torch.gt = lambda x, y: MockTensor([False] * x.shape[0])
# Mock modules before they're imported
sys.modules["torch"] = mock_torch
sys.modules["transformers"] = Mock()
sys.modules["phonemizer"] = Mock()
sys.modules["models"] = Mock()

View file

@ -1,7 +1,7 @@
"""Tests for TTS model implementations"""
import os
from unittest.mock import MagicMock, patch
from unittest.mock import MagicMock, patch, AsyncMock
import numpy as np
import torch
@ -27,16 +27,30 @@ def test_get_device_error():
@patch("os.listdir")
@patch("torch.load")
@patch("torch.save")
@patch("api.src.services.tts_base.settings")
@patch("api.src.services.warmup.WarmupService")
async def test_setup_cuda_available(
mock_save, mock_load, mock_listdir, mock_join, mock_exists, mock_cuda_available
mock_warmup_class, mock_settings, mock_save, mock_load, mock_listdir, mock_join, mock_exists, mock_cuda_available
):
"""Test setup with CUDA available"""
TTSBaseModel._device = None
mock_cuda_available.return_value = True
# Mock CUDA as unavailable since we're using CPU PyTorch
mock_cuda_available.return_value = False
mock_exists.return_value = True
mock_load.return_value = torch.zeros(1)
mock_listdir.return_value = ["voice1.pt", "voice2.pt"]
mock_join.return_value = "/mocked/path"
# Configure mock settings
mock_settings.model_dir = "/mock/model/dir"
mock_settings.onnx_model_path = "model.onnx"
mock_settings.voices_dir = "voices"
# Configure mock warmup service
mock_warmup = MagicMock()
mock_warmup.load_voices.return_value = [torch.zeros(1)]
mock_warmup.warmup_voices = AsyncMock()
mock_warmup_class.return_value = mock_warmup
# Create mock model
mock_model = MagicMock()
@ -49,7 +63,7 @@ async def test_setup_cuda_available(
TTSBaseModel._instance = mock_model
voice_count = await TTSBaseModel.setup()
assert TTSBaseModel._device == "cuda"
assert TTSBaseModel._device == "cpu"
assert voice_count == 2
@ -60,8 +74,10 @@ async def test_setup_cuda_available(
@patch("os.listdir")
@patch("torch.load")
@patch("torch.save")
@patch("api.src.services.tts_base.settings")
@patch("api.src.services.warmup.WarmupService")
async def test_setup_cuda_unavailable(
mock_save, mock_load, mock_listdir, mock_join, mock_exists, mock_cuda_available
mock_warmup_class, mock_settings, mock_save, mock_load, mock_listdir, mock_join, mock_exists, mock_cuda_available
):
"""Test setup with CUDA unavailable"""
TTSBaseModel._device = None
@ -70,6 +86,17 @@ async def test_setup_cuda_unavailable(
mock_load.return_value = torch.zeros(1)
mock_listdir.return_value = ["voice1.pt", "voice2.pt"]
mock_join.return_value = "/mocked/path"
# Configure mock settings
mock_settings.model_dir = "/mock/model/dir"
mock_settings.onnx_model_path = "model.onnx"
mock_settings.voices_dir = "voices"
# Configure mock warmup service
mock_warmup = MagicMock()
mock_warmup.load_voices.return_value = [torch.zeros(1)]
mock_warmup.warmup_voices = AsyncMock()
mock_warmup_class.return_value = mock_warmup
# Create mock model
mock_model = MagicMock()

View file

@ -8,7 +8,7 @@ import requests
import sounddevice as sd
def play_streaming_tts(text: str, output_file: str = None, voice: str = "af"):
def play_streaming_tts(text: str, output_file: str = None, voice: str = "af_sky"):
"""Stream TTS audio and play it back in real-time"""
print("\nStarting TTS stream request...")

View file

@ -46,6 +46,7 @@ test = [
"httpx==0.26.0",
"pytest-asyncio==0.23.5",
"gradio>=5",
"openai>=1.59.6",
]
[tool.uv]

View file

@ -1,14 +0,0 @@
# Core dependencies for testing
fastapi==0.115.6
uvicorn==0.34.0
pydantic==2.10.4
pydantic-settings==2.7.0
python-dotenv==1.0.1
sqlalchemy==2.0.27
# Testing
pytest==8.0.0
httpx==0.26.0
pytest-asyncio==0.23.5
pytest-cov==6.0.0
gradio==4.19.2

2
uv.lock generated
View file

@ -802,6 +802,7 @@ gpu = [
test = [
{ name = "gradio" },
{ name = "httpx" },
{ name = "openai" },
{ name = "pytest" },
{ name = "pytest-asyncio" },
{ name = "pytest-cov" },
@ -819,6 +820,7 @@ requires-dist = [
{ name = "numpy", specifier = ">=1.26.0" },
{ name = "onnxruntime", specifier = "==1.20.1" },
{ name = "openai", specifier = ">=1.59.6" },
{ name = "openai", marker = "extra == 'test'", specifier = ">=1.59.6" },
{ name = "phonemizer", specifier = "==3.3.0" },
{ name = "pydantic", specifier = "==2.10.4" },
{ name = "pydantic-settings", specifier = "==2.7.0" },