Ruff checks, ci fix

This commit is contained in:
remsky 2025-01-13 20:15:46 -07:00
parent 007b1a35e8
commit 22752900e5
24 changed files with 73 additions and 73 deletions

View file

@ -2,16 +2,16 @@ name: CI
on:
push:
branches: [ "develop", "master" ]
branches: [ "core/uv-management" ]
pull_request:
branches: [ "develop", "master" ]
branches: [ "core/uv-management" ]
jobs:
test:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.9", "3.10", "3.11"]
python-version: ["3.10"]
fail-fast: false
steps:
@ -22,30 +22,20 @@ jobs:
with:
python-version: ${{ matrix.python-version }}
- name: Set up pip cache
uses: actions/cache@v3
- name: Install uv
uses: astral-sh/setup-uv@v5
with:
path: ~/.cache/pip
key: ${{ runner.os }}-pip-${{ hashFiles('**/requirements*.txt') }}
restore-keys: |
${{ runner.os }}-pip-
- name: Install PyTorch CPU
run: |
python -m pip install --upgrade pip
pip install torch --index-url https://download.pytorch.org/whl/cpu
enable-cache: true
cache-dependency-glob: "uv.lock"
- name: Install dependencies
run: |
pip install ruff pytest-cov
pip install -r requirements.txt
pip install -r requirements-test.txt
uv pip install -e "docker/cpu[test]"
- name: Lint with ruff
run: |
ruff check .
uv run ruff check .
- name: Test API
- name: Test API and UI
run: |
pytest api/tests/ --asyncio-mode=auto --cov=api --cov-report=term-missing
uv run pytest api/tests/ ui/tests/ --asyncio-mode=auto --cov=api --cov=ui --cov-report=term-missing

View file

@ -1,11 +1,12 @@
# https://github.com/yl4579/StyleTTS2/blob/main/Modules/istftnet.py
from scipy.signal import get_window
from torch.nn import Conv1d, ConvTranspose1d
from torch.nn.utils import weight_norm, remove_weight_norm
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from scipy.signal import get_window
from torch.nn import Conv1d, ConvTranspose1d
from torch.nn.utils import remove_weight_norm, weight_norm
# https://github.com/yl4579/StyleTTS2/blob/main/Modules/utils.py
def init_weights(m, mean=0.0, std=0.01):

View file

@ -1,7 +1,9 @@
import phonemizer
import re
import phonemizer
import torch
def split_num(num):
num = num.group()
if '.' in num:

View file

@ -1,16 +1,19 @@
# https://github.com/yl4579/StyleTTS2/blob/main/models.py
from .istftnet import AdaIN1d, Decoder
from munch import Munch
from pathlib import Path
from .plbert import load_plbert
from torch.nn.utils import weight_norm, spectral_norm
import json
import numpy as np
import os
import os.path as osp
from pathlib import Path
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from munch import Munch
from torch.nn.utils import spectral_norm, weight_norm
from .istftnet import AdaIN1d, Decoder
from .plbert import load_plbert
class LinearNorm(torch.nn.Module):
def __init__(self, in_dim, out_dim, bias=True, w_init_gain='linear'):

View file

@ -1,6 +1,7 @@
# https://github.com/yl4579/StyleTTS2/blob/main/Utils/PLBERT/util.py
from transformers import AlbertConfig, AlbertModel
class CustomAlbert(AlbertModel):
def forward(self, *args, **kwargs):
# Call the original forward method

View file

@ -1,7 +1,7 @@
import re
import torch
import phonemizer
import torch
def split_num(num):

View file

@ -6,15 +6,15 @@ import sys
from contextlib import asynccontextmanager
import uvicorn
from loguru import logger
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from loguru import logger
from .core.config import settings
from .services.tts_model import TTSModel
from .routers.development import router as dev_router
from .services.tts_service import TTSService
from .routers.openai_compatible import router as openai_router
from .services.tts_model import TTSModel
from .services.tts_service import TTSService
def setup_logger():

View file

@ -1,18 +1,18 @@
from typing import List
import numpy as np
from fastapi import APIRouter, Depends, HTTPException, Response
from loguru import logger
from fastapi import Depends, Response, APIRouter, HTTPException
from ..services.audio import AudioService
from ..services.text_processing import phonemize, tokenize
from ..services.tts_model import TTSModel
from ..services.tts_service import TTSService
from ..structures.text_schemas import (
GenerateFromPhonemesRequest,
PhonemeRequest,
PhonemeResponse,
GenerateFromPhonemesRequest,
)
from ..services.text_processing import tokenize, phonemize
router = APIRouter(tags=["text processing"])

View file

@ -1,12 +1,12 @@
from typing import List, Union, AsyncGenerator
from typing import AsyncGenerator, List, Union
from loguru import logger
from fastapi import Header, Depends, Response, APIRouter, HTTPException
from fastapi import APIRouter, Depends, Header, HTTPException, Response
from fastapi.responses import StreamingResponse
from loguru import logger
from ..services.audio import AudioService
from ..structures.schemas import OpenAISpeechRequest
from ..services.tts_service import TTSService
from ..structures.schemas import OpenAISpeechRequest
router = APIRouter(
tags=["OpenAI Compatible TTS"],

View file

@ -3,8 +3,8 @@
from io import BytesIO
import numpy as np
import soundfile as sf
import scipy.io.wavfile as wavfile
import soundfile as sf
from loguru import logger
from ..core.config import settings
@ -23,6 +23,9 @@ class AudioNormalizer:
self, audio_data: np.ndarray, is_last_chunk: bool = False
) -> np.ndarray:
"""Convert audio data to int16 range and trim chunk boundaries"""
if len(audio_data) == 0:
raise ValueError("Audio data cannot be empty")
# Simple float32 to int16 conversion
audio_float = audio_data.astype(np.float32)

View file

@ -1,6 +1,6 @@
from .normalizer import normalize_text
from .phonemizer import EspeakBackend, PhonemizerBackend, phonemize
from .vocabulary import VOCAB, tokenize, decode_tokens
from .vocabulary import VOCAB, decode_tokens, tokenize
__all__ = [
"normalize_text",

View file

@ -5,14 +5,14 @@ import torch
from loguru import logger
from onnxruntime import (
ExecutionMode,
SessionOptions,
InferenceSession,
GraphOptimizationLevel,
InferenceSession,
SessionOptions,
)
from .tts_base import TTSBaseModel
from ..core.config import settings
from .text_processing import tokenize, phonemize
from .text_processing import phonemize, tokenize
from .tts_base import TTSBaseModel
class TTSCPUModel(TTSBaseModel):

View file

@ -3,12 +3,12 @@ import time
import numpy as np
import torch
from loguru import logger
from builds.models import build_model
from loguru import logger
from .tts_base import TTSBaseModel
from ..core.config import settings
from .text_processing import tokenize, phonemize
from .text_processing import phonemize, tokenize
from .tts_base import TTSBaseModel
# @torch.no_grad()

View file

@ -2,19 +2,19 @@ import io
import os
import re
import time
from typing import List, Tuple, Optional
from functools import lru_cache
from typing import List, Optional, Tuple
import numpy as np
import torch
import aiofiles.os
import numpy as np
import scipy.io.wavfile as wavfile
import torch
from loguru import logger
from .audio import AudioService, AudioNormalizer
from .tts_model import TTSModel
from ..core.config import settings
from .audio import AudioNormalizer, AudioService
from .text_processing import chunker, normalize_text
from .tts_model import TTSModel
class TTSService:

View file

@ -4,9 +4,9 @@ from typing import List, Tuple
import torch
from loguru import logger
from ..core.config import settings
from .tts_model import TTSModel
from .tts_service import TTSService
from ..core.config import settings
class WarmupService:

View file

@ -1,7 +1,7 @@
from enum import Enum
from typing import List, Union, Literal
from typing import List, Literal, Union
from pydantic import Field, BaseModel
from pydantic import BaseModel, Field
class VoiceCombineRequest(BaseModel):

View file

@ -1,4 +1,4 @@
from pydantic import Field, BaseModel
from pydantic import BaseModel, Field
class PhonemeRequest(BaseModel):

View file

@ -1,11 +1,11 @@
import os
import sys
import shutil
from unittest.mock import Mock, MagicMock, patch
import sys
from unittest.mock import MagicMock, Mock, patch
import aiofiles.threadpool
import numpy as np
import pytest
import aiofiles.threadpool
def cleanup_mock_dirs():

View file

@ -5,7 +5,7 @@ from unittest.mock import patch
import numpy as np
import pytest
from api.src.services.audio import AudioService, AudioNormalizer
from api.src.services.audio import AudioNormalizer, AudioService
@pytest.fixture(autouse=True)

View file

@ -1,10 +1,10 @@
import asyncio
from unittest.mock import Mock, AsyncMock
from unittest.mock import AsyncMock, Mock
import pytest
import pytest_asyncio
from httpx import AsyncClient
from fastapi.testclient import TestClient
from httpx import AsyncClient
from ..src.main import app

View file

@ -7,8 +7,8 @@ import pytest
import pytest_asyncio
from httpx import AsyncClient
from .conftest import MockTTSModel
from ..src.main import app
from .conftest import MockTTSModel
@pytest_asyncio.fixture

View file

@ -1,15 +1,15 @@
"""Tests for TTS model implementations"""
import os
from unittest.mock import MagicMock, patch, AsyncMock
from unittest.mock import AsyncMock, MagicMock, patch
import numpy as np
import torch
import pytest
import torch
from api.src.services.tts_base import TTSBaseModel
from api.src.services.tts_cpu import TTSCPUModel
from api.src.services.tts_gpu import TTSGPUModel, length_to_mask
from api.src.services.tts_base import TTSBaseModel
# Base Model Tests

View file

@ -4,8 +4,8 @@ import os
from unittest.mock import MagicMock, call, patch
import numpy as np
import torch
import pytest
import torch
from onnxruntime import InferenceSession
from api.src.core.config import settings

View file

@ -54,7 +54,7 @@ def test_model_column_default_values():
def test_model_column_no_voices():
"""Test model column creation with no voice IDs"""
_, components = create_model_column()
_, components = create_model_column([])
assert components["voice"].choices == []
assert components["voice"].value is None
@ -96,7 +96,7 @@ def test_output_column_configuration():
# Test output files dropdown
assert components["output_files"].label == "Previous Outputs"
assert components["output_files"].allow_custom_value is False
assert components["output_files"].allow_custom_value is True
# Test play button
assert components["play_btn"].value == "▶️ Play Selected"