Ruff checks, ci fix

This commit is contained in:
remsky 2025-01-13 20:15:46 -07:00
parent 007b1a35e8
commit 22752900e5
24 changed files with 73 additions and 73 deletions

View file

@ -2,16 +2,16 @@ name: CI
on: on:
push: push:
branches: [ "develop", "master" ] branches: [ "core/uv-management" ]
pull_request: pull_request:
branches: [ "develop", "master" ] branches: [ "core/uv-management" ]
jobs: jobs:
test: test:
runs-on: ubuntu-latest runs-on: ubuntu-latest
strategy: strategy:
matrix: matrix:
python-version: ["3.9", "3.10", "3.11"] python-version: ["3.10"]
fail-fast: false fail-fast: false
steps: steps:
@ -22,30 +22,20 @@ jobs:
with: with:
python-version: ${{ matrix.python-version }} python-version: ${{ matrix.python-version }}
- name: Set up pip cache - name: Install uv
uses: actions/cache@v3 uses: astral-sh/setup-uv@v5
with: with:
path: ~/.cache/pip enable-cache: true
key: ${{ runner.os }}-pip-${{ hashFiles('**/requirements*.txt') }} cache-dependency-glob: "uv.lock"
restore-keys: |
${{ runner.os }}-pip-
- name: Install PyTorch CPU
run: |
python -m pip install --upgrade pip
pip install torch --index-url https://download.pytorch.org/whl/cpu
- name: Install dependencies - name: Install dependencies
run: | run: |
pip install ruff pytest-cov uv pip install -e "docker/cpu[test]"
pip install -r requirements.txt
pip install -r requirements-test.txt
- name: Lint with ruff - name: Lint with ruff
run: | run: |
ruff check . uv run ruff check .
- name: Test API and UI
- name: Test API
run: | run: |
pytest api/tests/ --asyncio-mode=auto --cov=api --cov-report=term-missing uv run pytest api/tests/ ui/tests/ --asyncio-mode=auto --cov=api --cov=ui --cov-report=term-missing

View file

@ -1,11 +1,12 @@
# https://github.com/yl4579/StyleTTS2/blob/main/Modules/istftnet.py # https://github.com/yl4579/StyleTTS2/blob/main/Modules/istftnet.py
from scipy.signal import get_window
from torch.nn import Conv1d, ConvTranspose1d
from torch.nn.utils import weight_norm, remove_weight_norm
import numpy as np import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from scipy.signal import get_window
from torch.nn import Conv1d, ConvTranspose1d
from torch.nn.utils import remove_weight_norm, weight_norm
# https://github.com/yl4579/StyleTTS2/blob/main/Modules/utils.py # https://github.com/yl4579/StyleTTS2/blob/main/Modules/utils.py
def init_weights(m, mean=0.0, std=0.01): def init_weights(m, mean=0.0, std=0.01):

View file

@ -1,7 +1,9 @@
import phonemizer
import re import re
import phonemizer
import torch import torch
def split_num(num): def split_num(num):
num = num.group() num = num.group()
if '.' in num: if '.' in num:

View file

@ -1,16 +1,19 @@
# https://github.com/yl4579/StyleTTS2/blob/main/models.py # https://github.com/yl4579/StyleTTS2/blob/main/models.py
from .istftnet import AdaIN1d, Decoder
from munch import Munch
from pathlib import Path
from .plbert import load_plbert
from torch.nn.utils import weight_norm, spectral_norm
import json import json
import numpy as np
import os import os
import os.path as osp import os.path as osp
from pathlib import Path
import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from munch import Munch
from torch.nn.utils import spectral_norm, weight_norm
from .istftnet import AdaIN1d, Decoder
from .plbert import load_plbert
class LinearNorm(torch.nn.Module): class LinearNorm(torch.nn.Module):
def __init__(self, in_dim, out_dim, bias=True, w_init_gain='linear'): def __init__(self, in_dim, out_dim, bias=True, w_init_gain='linear'):

View file

@ -1,6 +1,7 @@
# https://github.com/yl4579/StyleTTS2/blob/main/Utils/PLBERT/util.py # https://github.com/yl4579/StyleTTS2/blob/main/Utils/PLBERT/util.py
from transformers import AlbertConfig, AlbertModel from transformers import AlbertConfig, AlbertModel
class CustomAlbert(AlbertModel): class CustomAlbert(AlbertModel):
def forward(self, *args, **kwargs): def forward(self, *args, **kwargs):
# Call the original forward method # Call the original forward method

View file

@ -1,7 +1,7 @@
import re import re
import torch
import phonemizer import phonemizer
import torch
def split_num(num): def split_num(num):

View file

@ -6,15 +6,15 @@ import sys
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
import uvicorn import uvicorn
from loguru import logger
from fastapi import FastAPI from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from loguru import logger
from .core.config import settings from .core.config import settings
from .services.tts_model import TTSModel
from .routers.development import router as dev_router from .routers.development import router as dev_router
from .services.tts_service import TTSService
from .routers.openai_compatible import router as openai_router from .routers.openai_compatible import router as openai_router
from .services.tts_model import TTSModel
from .services.tts_service import TTSService
def setup_logger(): def setup_logger():

View file

@ -1,18 +1,18 @@
from typing import List from typing import List
import numpy as np import numpy as np
from fastapi import APIRouter, Depends, HTTPException, Response
from loguru import logger from loguru import logger
from fastapi import Depends, Response, APIRouter, HTTPException
from ..services.audio import AudioService from ..services.audio import AudioService
from ..services.text_processing import phonemize, tokenize
from ..services.tts_model import TTSModel from ..services.tts_model import TTSModel
from ..services.tts_service import TTSService from ..services.tts_service import TTSService
from ..structures.text_schemas import ( from ..structures.text_schemas import (
GenerateFromPhonemesRequest,
PhonemeRequest, PhonemeRequest,
PhonemeResponse, PhonemeResponse,
GenerateFromPhonemesRequest,
) )
from ..services.text_processing import tokenize, phonemize
router = APIRouter(tags=["text processing"]) router = APIRouter(tags=["text processing"])

View file

@ -1,12 +1,12 @@
from typing import List, Union, AsyncGenerator from typing import AsyncGenerator, List, Union
from loguru import logger from fastapi import APIRouter, Depends, Header, HTTPException, Response
from fastapi import Header, Depends, Response, APIRouter, HTTPException
from fastapi.responses import StreamingResponse from fastapi.responses import StreamingResponse
from loguru import logger
from ..services.audio import AudioService from ..services.audio import AudioService
from ..structures.schemas import OpenAISpeechRequest
from ..services.tts_service import TTSService from ..services.tts_service import TTSService
from ..structures.schemas import OpenAISpeechRequest
router = APIRouter( router = APIRouter(
tags=["OpenAI Compatible TTS"], tags=["OpenAI Compatible TTS"],

View file

@ -3,8 +3,8 @@
from io import BytesIO from io import BytesIO
import numpy as np import numpy as np
import soundfile as sf
import scipy.io.wavfile as wavfile import scipy.io.wavfile as wavfile
import soundfile as sf
from loguru import logger from loguru import logger
from ..core.config import settings from ..core.config import settings
@ -23,6 +23,9 @@ class AudioNormalizer:
self, audio_data: np.ndarray, is_last_chunk: bool = False self, audio_data: np.ndarray, is_last_chunk: bool = False
) -> np.ndarray: ) -> np.ndarray:
"""Convert audio data to int16 range and trim chunk boundaries""" """Convert audio data to int16 range and trim chunk boundaries"""
if len(audio_data) == 0:
raise ValueError("Audio data cannot be empty")
# Simple float32 to int16 conversion # Simple float32 to int16 conversion
audio_float = audio_data.astype(np.float32) audio_float = audio_data.astype(np.float32)

View file

@ -1,6 +1,6 @@
from .normalizer import normalize_text from .normalizer import normalize_text
from .phonemizer import EspeakBackend, PhonemizerBackend, phonemize from .phonemizer import EspeakBackend, PhonemizerBackend, phonemize
from .vocabulary import VOCAB, tokenize, decode_tokens from .vocabulary import VOCAB, decode_tokens, tokenize
__all__ = [ __all__ = [
"normalize_text", "normalize_text",

View file

@ -5,14 +5,14 @@ import torch
from loguru import logger from loguru import logger
from onnxruntime import ( from onnxruntime import (
ExecutionMode, ExecutionMode,
SessionOptions,
InferenceSession,
GraphOptimizationLevel, GraphOptimizationLevel,
InferenceSession,
SessionOptions,
) )
from .tts_base import TTSBaseModel
from ..core.config import settings from ..core.config import settings
from .text_processing import tokenize, phonemize from .text_processing import phonemize, tokenize
from .tts_base import TTSBaseModel
class TTSCPUModel(TTSBaseModel): class TTSCPUModel(TTSBaseModel):

View file

@ -3,12 +3,12 @@ import time
import numpy as np import numpy as np
import torch import torch
from loguru import logger
from builds.models import build_model from builds.models import build_model
from loguru import logger
from .tts_base import TTSBaseModel
from ..core.config import settings from ..core.config import settings
from .text_processing import tokenize, phonemize from .text_processing import phonemize, tokenize
from .tts_base import TTSBaseModel
# @torch.no_grad() # @torch.no_grad()

View file

@ -2,19 +2,19 @@ import io
import os import os
import re import re
import time import time
from typing import List, Tuple, Optional
from functools import lru_cache from functools import lru_cache
from typing import List, Optional, Tuple
import numpy as np
import torch
import aiofiles.os import aiofiles.os
import numpy as np
import scipy.io.wavfile as wavfile import scipy.io.wavfile as wavfile
import torch
from loguru import logger from loguru import logger
from .audio import AudioService, AudioNormalizer
from .tts_model import TTSModel
from ..core.config import settings from ..core.config import settings
from .audio import AudioNormalizer, AudioService
from .text_processing import chunker, normalize_text from .text_processing import chunker, normalize_text
from .tts_model import TTSModel
class TTSService: class TTSService:

View file

@ -4,9 +4,9 @@ from typing import List, Tuple
import torch import torch
from loguru import logger from loguru import logger
from ..core.config import settings
from .tts_model import TTSModel from .tts_model import TTSModel
from .tts_service import TTSService from .tts_service import TTSService
from ..core.config import settings
class WarmupService: class WarmupService:

View file

@ -1,7 +1,7 @@
from enum import Enum from enum import Enum
from typing import List, Union, Literal from typing import List, Literal, Union
from pydantic import Field, BaseModel from pydantic import BaseModel, Field
class VoiceCombineRequest(BaseModel): class VoiceCombineRequest(BaseModel):

View file

@ -1,4 +1,4 @@
from pydantic import Field, BaseModel from pydantic import BaseModel, Field
class PhonemeRequest(BaseModel): class PhonemeRequest(BaseModel):

View file

@ -1,11 +1,11 @@
import os import os
import sys
import shutil import shutil
from unittest.mock import Mock, MagicMock, patch import sys
from unittest.mock import MagicMock, Mock, patch
import aiofiles.threadpool
import numpy as np import numpy as np
import pytest import pytest
import aiofiles.threadpool
def cleanup_mock_dirs(): def cleanup_mock_dirs():

View file

@ -5,7 +5,7 @@ from unittest.mock import patch
import numpy as np import numpy as np
import pytest import pytest
from api.src.services.audio import AudioService, AudioNormalizer from api.src.services.audio import AudioNormalizer, AudioService
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)

View file

@ -1,10 +1,10 @@
import asyncio import asyncio
from unittest.mock import Mock, AsyncMock from unittest.mock import AsyncMock, Mock
import pytest import pytest
import pytest_asyncio import pytest_asyncio
from httpx import AsyncClient
from fastapi.testclient import TestClient from fastapi.testclient import TestClient
from httpx import AsyncClient
from ..src.main import app from ..src.main import app

View file

@ -7,8 +7,8 @@ import pytest
import pytest_asyncio import pytest_asyncio
from httpx import AsyncClient from httpx import AsyncClient
from .conftest import MockTTSModel
from ..src.main import app from ..src.main import app
from .conftest import MockTTSModel
@pytest_asyncio.fixture @pytest_asyncio.fixture

View file

@ -1,15 +1,15 @@
"""Tests for TTS model implementations""" """Tests for TTS model implementations"""
import os import os
from unittest.mock import MagicMock, patch, AsyncMock from unittest.mock import AsyncMock, MagicMock, patch
import numpy as np import numpy as np
import torch
import pytest import pytest
import torch
from api.src.services.tts_base import TTSBaseModel
from api.src.services.tts_cpu import TTSCPUModel from api.src.services.tts_cpu import TTSCPUModel
from api.src.services.tts_gpu import TTSGPUModel, length_to_mask from api.src.services.tts_gpu import TTSGPUModel, length_to_mask
from api.src.services.tts_base import TTSBaseModel
# Base Model Tests # Base Model Tests

View file

@ -4,8 +4,8 @@ import os
from unittest.mock import MagicMock, call, patch from unittest.mock import MagicMock, call, patch
import numpy as np import numpy as np
import torch
import pytest import pytest
import torch
from onnxruntime import InferenceSession from onnxruntime import InferenceSession
from api.src.core.config import settings from api.src.core.config import settings

View file

@ -54,7 +54,7 @@ def test_model_column_default_values():
def test_model_column_no_voices(): def test_model_column_no_voices():
"""Test model column creation with no voice IDs""" """Test model column creation with no voice IDs"""
_, components = create_model_column() _, components = create_model_column([])
assert components["voice"].choices == [] assert components["voice"].choices == []
assert components["voice"].value is None assert components["voice"].value is None
@ -96,7 +96,7 @@ def test_output_column_configuration():
# Test output files dropdown # Test output files dropdown
assert components["output_files"].label == "Previous Outputs" assert components["output_files"].label == "Previous Outputs"
assert components["output_files"].allow_custom_value is False assert components["output_files"].allow_custom_value is True
# Test play button # Test play button
assert components["play_btn"].value == "▶️ Play Selected" assert components["play_btn"].value == "▶️ Play Selected"