Mirror of https://github.com/remsky/Kokoro-FastAPI.git (synced 2025-08-21 05:44:06 +00:00)

Commit 22752900e5 (parent 007b1a35e8): Ruff checks, ci fix

24 changed files with 73 additions and 73 deletions
.github/workflows/ci.yml (vendored): 32 changed lines

@@ -2,16 +2,16 @@ name: CI
 on:
   push:
-    branches: [ "develop", "master" ]
+    branches: [ "core/uv-management" ]
   pull_request:
-    branches: [ "develop", "master" ]
+    branches: [ "core/uv-management" ]
 
 jobs:
   test:
     runs-on: ubuntu-latest
     strategy:
       matrix:
-        python-version: ["3.9", "3.10", "3.11"]
+        python-version: ["3.10"]
       fail-fast: false
 
     steps:
@@ -22,30 +22,20 @@ jobs:
         with:
           python-version: ${{ matrix.python-version }}
 
-      - name: Set up pip cache
-        uses: actions/cache@v3
+      - name: Install uv
+        uses: astral-sh/setup-uv@v5
         with:
-          path: ~/.cache/pip
-          key: ${{ runner.os }}-pip-${{ hashFiles('**/requirements*.txt') }}
-          restore-keys: |
-            ${{ runner.os }}-pip-
-
-      - name: Install PyTorch CPU
-        run: |
-          python -m pip install --upgrade pip
-          pip install torch --index-url https://download.pytorch.org/whl/cpu
+          enable-cache: true
+          cache-dependency-glob: "uv.lock"
 
       - name: Install dependencies
         run: |
-          pip install ruff pytest-cov
-          pip install -r requirements.txt
-          pip install -r requirements-test.txt
+          uv pip install -e "docker/cpu[test]"
 
       - name: Lint with ruff
         run: |
-          ruff check .
+          uv run ruff check .
 
-      - name: Test API
+      - name: Test API and UI
         run: |
-          pytest api/tests/ --asyncio-mode=auto --cov=api --cov-report=term-missing
+          uv run pytest api/tests/ ui/tests/ --asyncio-mode=auto --cov=api --cov=ui --cov-report=term-missing

@@ -1,11 +1,12 @@
 # https://github.com/yl4579/StyleTTS2/blob/main/Modules/istftnet.py
-from scipy.signal import get_window
-from torch.nn import Conv1d, ConvTranspose1d
-from torch.nn.utils import weight_norm, remove_weight_norm
 import numpy as np
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from scipy.signal import get_window
+from torch.nn import Conv1d, ConvTranspose1d
+from torch.nn.utils import remove_weight_norm, weight_norm
 
+
 # https://github.com/yl4579/StyleTTS2/blob/main/Modules/utils.py
 def init_weights(m, mean=0.0, std=0.01):

@@ -1,7 +1,9 @@
-import phonemizer
 import re
+
+import phonemizer
 import torch
 
+
 def split_num(num):
     num = num.group()
     if '.' in num:

@@ -1,16 +1,19 @@
 # https://github.com/yl4579/StyleTTS2/blob/main/models.py
-from .istftnet import AdaIN1d, Decoder
-from munch import Munch
-from pathlib import Path
-from .plbert import load_plbert
-from torch.nn.utils import weight_norm, spectral_norm
 import json
-import numpy as np
 import os
 import os.path as osp
+from pathlib import Path
+
+import numpy as np
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from munch import Munch
+from torch.nn.utils import spectral_norm, weight_norm
+
+from .istftnet import AdaIN1d, Decoder
+from .plbert import load_plbert
 
+
 class LinearNorm(torch.nn.Module):
     def __init__(self, in_dim, out_dim, bias=True, w_init_gain='linear'):

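Note: this and most of the remaining hunks are the same mechanical change. Ruff's import-sorting rules group imports into standard-library, third-party, and local blocks, alphabetized within each block and separated by a blank line. A minimal sketch of the resulting layout, reusing the imports from the hunk above (runnable only where numpy, torch, and munch are installed):

    # 1) standard library
    import json
    import os
    from pathlib import Path

    # 2) third-party packages
    import numpy as np
    import torch
    from munch import Munch

    # 3) first-party / local modules come last, e.g.
    #    from .istftnet import AdaIN1d, Decoder
    #    from .plbert import load_plbert
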
@@ -1,6 +1,7 @@
 # https://github.com/yl4579/StyleTTS2/blob/main/Utils/PLBERT/util.py
 from transformers import AlbertConfig, AlbertModel
 
+
 class CustomAlbert(AlbertModel):
     def forward(self, *args, **kwargs):
         # Call the original forward method

@@ -1,7 +1,7 @@
 import re
 
-import torch
 import phonemizer
+import torch
 
 
 def split_num(num):

@@ -6,15 +6,15 @@ import sys
 from contextlib import asynccontextmanager
 
 import uvicorn
-from loguru import logger
 from fastapi import FastAPI
 from fastapi.middleware.cors import CORSMiddleware
+from loguru import logger
 
 from .core.config import settings
-from .services.tts_model import TTSModel
 from .routers.development import router as dev_router
-from .services.tts_service import TTSService
 from .routers.openai_compatible import router as openai_router
+from .services.tts_model import TTSModel
+from .services.tts_service import TTSService
 
 
 def setup_logger():

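Note: the reordered main.py imports above are the pieces of a fairly standard FastAPI entrypoint. A hedged sketch of how they typically fit together; the lifespan body, CORS settings, and the uvicorn arguments here are assumptions for illustration, not this repository's actual code:

    from contextlib import asynccontextmanager

    import uvicorn
    from fastapi import FastAPI
    from fastapi.middleware.cors import CORSMiddleware
    from loguru import logger


    @asynccontextmanager
    async def lifespan(app: FastAPI):
        # Hypothetical startup hook; the real app warms up TTSModel/TTSService here.
        logger.info("Starting up")
        yield
        logger.info("Shutting down")


    app = FastAPI(lifespan=lifespan)
    app.add_middleware(
        CORSMiddleware,
        allow_origins=["*"],  # assumed; the real origins come from settings
        allow_methods=["*"],
        allow_headers=["*"],
    )

    if __name__ == "__main__":
        uvicorn.run(app, host="0.0.0.0", port=8880)  # port is an assumption
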
@@ -1,18 +1,18 @@
 from typing import List
 
 import numpy as np
+from fastapi import APIRouter, Depends, HTTPException, Response
 from loguru import logger
-from fastapi import Depends, Response, APIRouter, HTTPException
 
 from ..services.audio import AudioService
+from ..services.text_processing import phonemize, tokenize
 from ..services.tts_model import TTSModel
 from ..services.tts_service import TTSService
 from ..structures.text_schemas import (
+    GenerateFromPhonemesRequest,
     PhonemeRequest,
     PhonemeResponse,
-    GenerateFromPhonemesRequest,
 )
-from ..services.text_processing import tokenize, phonemize
 
 router = APIRouter(tags=["text processing"])
 

@@ -1,12 +1,12 @@
-from typing import List, Union, AsyncGenerator
+from typing import AsyncGenerator, List, Union
 
-from loguru import logger
-from fastapi import Header, Depends, Response, APIRouter, HTTPException
+from fastapi import APIRouter, Depends, Header, HTTPException, Response
 from fastapi.responses import StreamingResponse
+from loguru import logger
 
 from ..services.audio import AudioService
-from ..structures.schemas import OpenAISpeechRequest
 from ..services.tts_service import TTSService
+from ..structures.schemas import OpenAISpeechRequest
 
 router = APIRouter(
     tags=["OpenAI Compatible TTS"],

@@ -3,8 +3,8 @@
 from io import BytesIO
 
 import numpy as np
-import soundfile as sf
 import scipy.io.wavfile as wavfile
+import soundfile as sf
 from loguru import logger
 
 from ..core.config import settings
@@ -23,6 +23,9 @@ class AudioNormalizer:
         self, audio_data: np.ndarray, is_last_chunk: bool = False
     ) -> np.ndarray:
         """Convert audio data to int16 range and trim chunk boundaries"""
+        if len(audio_data) == 0:
+            raise ValueError("Audio data cannot be empty")
+
         # Simple float32 to int16 conversion
         audio_float = audio_data.astype(np.float32)
 

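Note: the hunk above adds an empty-input guard ahead of the float32-to-int16 conversion. A standalone sketch of that behavior; the clipping and the 32767 scale factor are assumptions, not taken from this commit:

    import numpy as np


    def normalize_to_int16(audio_data: np.ndarray) -> np.ndarray:
        """Convert float audio to int16, rejecting empty input (sketch)."""
        if len(audio_data) == 0:
            raise ValueError("Audio data cannot be empty")
        audio_float = audio_data.astype(np.float32)
        # Assume the signal is roughly in [-1.0, 1.0]: clip, then scale to int16 range.
        return (np.clip(audio_float, -1.0, 1.0) * 32767).astype(np.int16)
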
@@ -1,6 +1,6 @@
 from .normalizer import normalize_text
 from .phonemizer import EspeakBackend, PhonemizerBackend, phonemize
-from .vocabulary import VOCAB, tokenize, decode_tokens
+from .vocabulary import VOCAB, decode_tokens, tokenize
 
 __all__ = [
     "normalize_text",

@@ -5,14 +5,14 @@ import torch
 from loguru import logger
 from onnxruntime import (
     ExecutionMode,
-    SessionOptions,
-    InferenceSession,
     GraphOptimizationLevel,
+    InferenceSession,
+    SessionOptions,
 )
 
-from .tts_base import TTSBaseModel
 from ..core.config import settings
-from .text_processing import tokenize, phonemize
+from .text_processing import phonemize, tokenize
+from .tts_base import TTSBaseModel
 
 
 class TTSCPUModel(TTSBaseModel):

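Note: the onnxruntime symbols reordered above are the usual ingredients of a CPU-only inference session. A sketch of how they commonly fit together; the option values and the model filename are assumptions, not read from this commit:

    from onnxruntime import (
        ExecutionMode,
        GraphOptimizationLevel,
        InferenceSession,
        SessionOptions,
    )


    def create_cpu_session(model_path: str = "model.onnx") -> InferenceSession:
        # Hypothetical defaults; the service's real options may differ.
        opts = SessionOptions()
        opts.execution_mode = ExecutionMode.ORT_SEQUENTIAL
        opts.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL
        return InferenceSession(
            model_path, sess_options=opts, providers=["CPUExecutionProvider"]
        )
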
@@ -3,12 +3,12 @@ import time
 
 import numpy as np
 import torch
-from loguru import logger
 from builds.models import build_model
+from loguru import logger
 
-from .tts_base import TTSBaseModel
 from ..core.config import settings
-from .text_processing import tokenize, phonemize
+from .text_processing import phonemize, tokenize
+from .tts_base import TTSBaseModel
 
 
 # @torch.no_grad()

@@ -2,19 +2,19 @@ import io
 import os
 import re
 import time
-from typing import List, Tuple, Optional
 from functools import lru_cache
+from typing import List, Optional, Tuple
 
-import numpy as np
-import torch
 import aiofiles.os
+import numpy as np
 import scipy.io.wavfile as wavfile
+import torch
 from loguru import logger
 
-from .audio import AudioService, AudioNormalizer
-from .tts_model import TTSModel
 from ..core.config import settings
+from .audio import AudioNormalizer, AudioService
 from .text_processing import chunker, normalize_text
+from .tts_model import TTSModel
 
 
 class TTSService:

@@ -4,9 +4,9 @@ from typing import List, Tuple
 import torch
 from loguru import logger
 
+from ..core.config import settings
 from .tts_model import TTSModel
 from .tts_service import TTSService
-from ..core.config import settings
 
 
 class WarmupService:

@@ -1,7 +1,7 @@
 from enum import Enum
-from typing import List, Union, Literal
+from typing import List, Literal, Union
 
-from pydantic import Field, BaseModel
+from pydantic import BaseModel, Field
 
 
 class VoiceCombineRequest(BaseModel):

@@ -1,4 +1,4 @@
-from pydantic import Field, BaseModel
+from pydantic import BaseModel, Field
 
 
 class PhonemeRequest(BaseModel):

@@ -1,11 +1,11 @@
 import os
-import sys
 import shutil
-from unittest.mock import Mock, MagicMock, patch
+import sys
+from unittest.mock import MagicMock, Mock, patch
 
+import aiofiles.threadpool
 import numpy as np
 import pytest
-import aiofiles.threadpool
 
 
 def cleanup_mock_dirs():

@@ -5,7 +5,7 @@ from unittest.mock import patch
 import numpy as np
 import pytest
 
-from api.src.services.audio import AudioService, AudioNormalizer
+from api.src.services.audio import AudioNormalizer, AudioService
 
 
 @pytest.fixture(autouse=True)

@@ -1,10 +1,10 @@
 import asyncio
-from unittest.mock import Mock, AsyncMock
+from unittest.mock import AsyncMock, Mock
 
 import pytest
 import pytest_asyncio
-from httpx import AsyncClient
 from fastapi.testclient import TestClient
+from httpx import AsyncClient
 
 from ..src.main import app
 

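Note: the test imports above (pytest_asyncio, httpx's AsyncClient, and the FastAPI app) are usually wired together as an async client fixture. A hedged sketch of such a fixture, not this repository's actual test code; ASGITransport is used here for newer httpx releases:

    import pytest_asyncio
    from httpx import ASGITransport, AsyncClient

    from ..src.main import app


    @pytest_asyncio.fixture
    async def async_client():
        # Route requests directly to the ASGI app instead of a live server.
        async with AsyncClient(
            transport=ASGITransport(app=app), base_url="http://test"
        ) as client:
            yield client
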
@@ -7,8 +7,8 @@ import pytest
 import pytest_asyncio
 from httpx import AsyncClient
 
-from .conftest import MockTTSModel
 from ..src.main import app
+from .conftest import MockTTSModel
 
 
 @pytest_asyncio.fixture

@@ -1,15 +1,15 @@
 """Tests for TTS model implementations"""
 
 import os
-from unittest.mock import MagicMock, patch, AsyncMock
+from unittest.mock import AsyncMock, MagicMock, patch
 
 import numpy as np
-import torch
 import pytest
+import torch
 
+from api.src.services.tts_base import TTSBaseModel
 from api.src.services.tts_cpu import TTSCPUModel
 from api.src.services.tts_gpu import TTSGPUModel, length_to_mask
-from api.src.services.tts_base import TTSBaseModel
 
 
 # Base Model Tests

@@ -4,8 +4,8 @@ import os
 from unittest.mock import MagicMock, call, patch
 
 import numpy as np
-import torch
 import pytest
+import torch
 from onnxruntime import InferenceSession
 
 from api.src.core.config import settings

@@ -54,7 +54,7 @@ def test_model_column_default_values():
 
 def test_model_column_no_voices():
     """Test model column creation with no voice IDs"""
-    _, components = create_model_column()
+    _, components = create_model_column([])
 
     assert components["voice"].choices == []
     assert components["voice"].value is None

@@ -96,7 +96,7 @@ def test_output_column_configuration():
 
     # Test output files dropdown
     assert components["output_files"].label == "Previous Outputs"
-    assert components["output_files"].allow_custom_value is False
+    assert components["output_files"].allow_custom_value is True
 
     # Test play button
     assert components["play_btn"].value == "▶️ Play Selected"