mirror of
https://github.com/remsky/Kokoro-FastAPI.git
synced 2025-08-21 05:44:06 +00:00
Ruff checks, ci fix
This commit is contained in:
parent
007b1a35e8
commit
22752900e5
24 changed files with 73 additions and 73 deletions
32
.github/workflows/ci.yml
vendored
32
.github/workflows/ci.yml
vendored
|
@ -2,16 +2,16 @@ name: CI
|
||||||
|
|
||||||
on:
|
on:
|
||||||
push:
|
push:
|
||||||
branches: [ "develop", "master" ]
|
branches: [ "core/uv-management" ]
|
||||||
pull_request:
|
pull_request:
|
||||||
branches: [ "develop", "master" ]
|
branches: [ "core/uv-management" ]
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
test:
|
test:
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
strategy:
|
strategy:
|
||||||
matrix:
|
matrix:
|
||||||
python-version: ["3.9", "3.10", "3.11"]
|
python-version: ["3.10"]
|
||||||
fail-fast: false
|
fail-fast: false
|
||||||
|
|
||||||
steps:
|
steps:
|
||||||
|
@ -22,30 +22,20 @@ jobs:
|
||||||
with:
|
with:
|
||||||
python-version: ${{ matrix.python-version }}
|
python-version: ${{ matrix.python-version }}
|
||||||
|
|
||||||
- name: Set up pip cache
|
- name: Install uv
|
||||||
uses: actions/cache@v3
|
uses: astral-sh/setup-uv@v5
|
||||||
with:
|
with:
|
||||||
path: ~/.cache/pip
|
enable-cache: true
|
||||||
key: ${{ runner.os }}-pip-${{ hashFiles('**/requirements*.txt') }}
|
cache-dependency-glob: "uv.lock"
|
||||||
restore-keys: |
|
|
||||||
${{ runner.os }}-pip-
|
|
||||||
|
|
||||||
- name: Install PyTorch CPU
|
|
||||||
run: |
|
|
||||||
python -m pip install --upgrade pip
|
|
||||||
pip install torch --index-url https://download.pytorch.org/whl/cpu
|
|
||||||
|
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
run: |
|
run: |
|
||||||
pip install ruff pytest-cov
|
uv pip install -e "docker/cpu[test]"
|
||||||
pip install -r requirements.txt
|
|
||||||
pip install -r requirements-test.txt
|
|
||||||
|
|
||||||
- name: Lint with ruff
|
- name: Lint with ruff
|
||||||
run: |
|
run: |
|
||||||
ruff check .
|
uv run ruff check .
|
||||||
|
|
||||||
|
- name: Test API and UI
|
||||||
- name: Test API
|
|
||||||
run: |
|
run: |
|
||||||
pytest api/tests/ --asyncio-mode=auto --cov=api --cov-report=term-missing
|
uv run pytest api/tests/ ui/tests/ --asyncio-mode=auto --cov=api --cov=ui --cov-report=term-missing
|
||||||
|
|
|
@ -1,11 +1,12 @@
|
||||||
# https://github.com/yl4579/StyleTTS2/blob/main/Modules/istftnet.py
|
# https://github.com/yl4579/StyleTTS2/blob/main/Modules/istftnet.py
|
||||||
from scipy.signal import get_window
|
|
||||||
from torch.nn import Conv1d, ConvTranspose1d
|
|
||||||
from torch.nn.utils import weight_norm, remove_weight_norm
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
from scipy.signal import get_window
|
||||||
|
from torch.nn import Conv1d, ConvTranspose1d
|
||||||
|
from torch.nn.utils import remove_weight_norm, weight_norm
|
||||||
|
|
||||||
|
|
||||||
# https://github.com/yl4579/StyleTTS2/blob/main/Modules/utils.py
|
# https://github.com/yl4579/StyleTTS2/blob/main/Modules/utils.py
|
||||||
def init_weights(m, mean=0.0, std=0.01):
|
def init_weights(m, mean=0.0, std=0.01):
|
||||||
|
|
|
@ -1,7 +1,9 @@
|
||||||
import phonemizer
|
|
||||||
import re
|
import re
|
||||||
|
|
||||||
|
import phonemizer
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
|
||||||
def split_num(num):
|
def split_num(num):
|
||||||
num = num.group()
|
num = num.group()
|
||||||
if '.' in num:
|
if '.' in num:
|
||||||
|
|
|
@ -1,16 +1,19 @@
|
||||||
# https://github.com/yl4579/StyleTTS2/blob/main/models.py
|
# https://github.com/yl4579/StyleTTS2/blob/main/models.py
|
||||||
from .istftnet import AdaIN1d, Decoder
|
|
||||||
from munch import Munch
|
|
||||||
from pathlib import Path
|
|
||||||
from .plbert import load_plbert
|
|
||||||
from torch.nn.utils import weight_norm, spectral_norm
|
|
||||||
import json
|
import json
|
||||||
import numpy as np
|
|
||||||
import os
|
import os
|
||||||
import os.path as osp
|
import os.path as osp
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
from munch import Munch
|
||||||
|
from torch.nn.utils import spectral_norm, weight_norm
|
||||||
|
|
||||||
|
from .istftnet import AdaIN1d, Decoder
|
||||||
|
from .plbert import load_plbert
|
||||||
|
|
||||||
|
|
||||||
class LinearNorm(torch.nn.Module):
|
class LinearNorm(torch.nn.Module):
|
||||||
def __init__(self, in_dim, out_dim, bias=True, w_init_gain='linear'):
|
def __init__(self, in_dim, out_dim, bias=True, w_init_gain='linear'):
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
# https://github.com/yl4579/StyleTTS2/blob/main/Utils/PLBERT/util.py
|
# https://github.com/yl4579/StyleTTS2/blob/main/Utils/PLBERT/util.py
|
||||||
from transformers import AlbertConfig, AlbertModel
|
from transformers import AlbertConfig, AlbertModel
|
||||||
|
|
||||||
|
|
||||||
class CustomAlbert(AlbertModel):
|
class CustomAlbert(AlbertModel):
|
||||||
def forward(self, *args, **kwargs):
|
def forward(self, *args, **kwargs):
|
||||||
# Call the original forward method
|
# Call the original forward method
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
import re
|
import re
|
||||||
|
|
||||||
import torch
|
|
||||||
import phonemizer
|
import phonemizer
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
def split_num(num):
|
def split_num(num):
|
||||||
|
|
|
@ -6,15 +6,15 @@ import sys
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
|
|
||||||
import uvicorn
|
import uvicorn
|
||||||
from loguru import logger
|
|
||||||
from fastapi import FastAPI
|
from fastapi import FastAPI
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
from .core.config import settings
|
from .core.config import settings
|
||||||
from .services.tts_model import TTSModel
|
|
||||||
from .routers.development import router as dev_router
|
from .routers.development import router as dev_router
|
||||||
from .services.tts_service import TTSService
|
|
||||||
from .routers.openai_compatible import router as openai_router
|
from .routers.openai_compatible import router as openai_router
|
||||||
|
from .services.tts_model import TTSModel
|
||||||
|
from .services.tts_service import TTSService
|
||||||
|
|
||||||
|
|
||||||
def setup_logger():
|
def setup_logger():
|
||||||
|
|
|
@ -1,18 +1,18 @@
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException, Response
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
from fastapi import Depends, Response, APIRouter, HTTPException
|
|
||||||
|
|
||||||
from ..services.audio import AudioService
|
from ..services.audio import AudioService
|
||||||
|
from ..services.text_processing import phonemize, tokenize
|
||||||
from ..services.tts_model import TTSModel
|
from ..services.tts_model import TTSModel
|
||||||
from ..services.tts_service import TTSService
|
from ..services.tts_service import TTSService
|
||||||
from ..structures.text_schemas import (
|
from ..structures.text_schemas import (
|
||||||
|
GenerateFromPhonemesRequest,
|
||||||
PhonemeRequest,
|
PhonemeRequest,
|
||||||
PhonemeResponse,
|
PhonemeResponse,
|
||||||
GenerateFromPhonemesRequest,
|
|
||||||
)
|
)
|
||||||
from ..services.text_processing import tokenize, phonemize
|
|
||||||
|
|
||||||
router = APIRouter(tags=["text processing"])
|
router = APIRouter(tags=["text processing"])
|
||||||
|
|
||||||
|
|
|
@ -1,12 +1,12 @@
|
||||||
from typing import List, Union, AsyncGenerator
|
from typing import AsyncGenerator, List, Union
|
||||||
|
|
||||||
from loguru import logger
|
from fastapi import APIRouter, Depends, Header, HTTPException, Response
|
||||||
from fastapi import Header, Depends, Response, APIRouter, HTTPException
|
|
||||||
from fastapi.responses import StreamingResponse
|
from fastapi.responses import StreamingResponse
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
from ..services.audio import AudioService
|
from ..services.audio import AudioService
|
||||||
from ..structures.schemas import OpenAISpeechRequest
|
|
||||||
from ..services.tts_service import TTSService
|
from ..services.tts_service import TTSService
|
||||||
|
from ..structures.schemas import OpenAISpeechRequest
|
||||||
|
|
||||||
router = APIRouter(
|
router = APIRouter(
|
||||||
tags=["OpenAI Compatible TTS"],
|
tags=["OpenAI Compatible TTS"],
|
||||||
|
|
|
@ -3,8 +3,8 @@
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import soundfile as sf
|
|
||||||
import scipy.io.wavfile as wavfile
|
import scipy.io.wavfile as wavfile
|
||||||
|
import soundfile as sf
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
from ..core.config import settings
|
from ..core.config import settings
|
||||||
|
@ -23,6 +23,9 @@ class AudioNormalizer:
|
||||||
self, audio_data: np.ndarray, is_last_chunk: bool = False
|
self, audio_data: np.ndarray, is_last_chunk: bool = False
|
||||||
) -> np.ndarray:
|
) -> np.ndarray:
|
||||||
"""Convert audio data to int16 range and trim chunk boundaries"""
|
"""Convert audio data to int16 range and trim chunk boundaries"""
|
||||||
|
if len(audio_data) == 0:
|
||||||
|
raise ValueError("Audio data cannot be empty")
|
||||||
|
|
||||||
# Simple float32 to int16 conversion
|
# Simple float32 to int16 conversion
|
||||||
audio_float = audio_data.astype(np.float32)
|
audio_float = audio_data.astype(np.float32)
|
||||||
|
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
from .normalizer import normalize_text
|
from .normalizer import normalize_text
|
||||||
from .phonemizer import EspeakBackend, PhonemizerBackend, phonemize
|
from .phonemizer import EspeakBackend, PhonemizerBackend, phonemize
|
||||||
from .vocabulary import VOCAB, tokenize, decode_tokens
|
from .vocabulary import VOCAB, decode_tokens, tokenize
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"normalize_text",
|
"normalize_text",
|
||||||
|
|
|
@ -5,14 +5,14 @@ import torch
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
from onnxruntime import (
|
from onnxruntime import (
|
||||||
ExecutionMode,
|
ExecutionMode,
|
||||||
SessionOptions,
|
|
||||||
InferenceSession,
|
|
||||||
GraphOptimizationLevel,
|
GraphOptimizationLevel,
|
||||||
|
InferenceSession,
|
||||||
|
SessionOptions,
|
||||||
)
|
)
|
||||||
|
|
||||||
from .tts_base import TTSBaseModel
|
|
||||||
from ..core.config import settings
|
from ..core.config import settings
|
||||||
from .text_processing import tokenize, phonemize
|
from .text_processing import phonemize, tokenize
|
||||||
|
from .tts_base import TTSBaseModel
|
||||||
|
|
||||||
|
|
||||||
class TTSCPUModel(TTSBaseModel):
|
class TTSCPUModel(TTSBaseModel):
|
||||||
|
|
|
@ -3,12 +3,12 @@ import time
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from loguru import logger
|
|
||||||
from builds.models import build_model
|
from builds.models import build_model
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
from .tts_base import TTSBaseModel
|
|
||||||
from ..core.config import settings
|
from ..core.config import settings
|
||||||
from .text_processing import tokenize, phonemize
|
from .text_processing import phonemize, tokenize
|
||||||
|
from .tts_base import TTSBaseModel
|
||||||
|
|
||||||
|
|
||||||
# @torch.no_grad()
|
# @torch.no_grad()
|
||||||
|
|
|
@ -2,19 +2,19 @@ import io
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
import time
|
import time
|
||||||
from typing import List, Tuple, Optional
|
|
||||||
from functools import lru_cache
|
from functools import lru_cache
|
||||||
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
import aiofiles.os
|
import aiofiles.os
|
||||||
|
import numpy as np
|
||||||
import scipy.io.wavfile as wavfile
|
import scipy.io.wavfile as wavfile
|
||||||
|
import torch
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
from .audio import AudioService, AudioNormalizer
|
|
||||||
from .tts_model import TTSModel
|
|
||||||
from ..core.config import settings
|
from ..core.config import settings
|
||||||
|
from .audio import AudioNormalizer, AudioService
|
||||||
from .text_processing import chunker, normalize_text
|
from .text_processing import chunker, normalize_text
|
||||||
|
from .tts_model import TTSModel
|
||||||
|
|
||||||
|
|
||||||
class TTSService:
|
class TTSService:
|
||||||
|
|
|
@ -4,9 +4,9 @@ from typing import List, Tuple
|
||||||
import torch
|
import torch
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
|
from ..core.config import settings
|
||||||
from .tts_model import TTSModel
|
from .tts_model import TTSModel
|
||||||
from .tts_service import TTSService
|
from .tts_service import TTSService
|
||||||
from ..core.config import settings
|
|
||||||
|
|
||||||
|
|
||||||
class WarmupService:
|
class WarmupService:
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import List, Union, Literal
|
from typing import List, Literal, Union
|
||||||
|
|
||||||
from pydantic import Field, BaseModel
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
||||||
class VoiceCombineRequest(BaseModel):
|
class VoiceCombineRequest(BaseModel):
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
from pydantic import Field, BaseModel
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
||||||
class PhonemeRequest(BaseModel):
|
class PhonemeRequest(BaseModel):
|
||||||
|
|
|
@ -1,11 +1,11 @@
|
||||||
import os
|
import os
|
||||||
import sys
|
|
||||||
import shutil
|
import shutil
|
||||||
from unittest.mock import Mock, MagicMock, patch
|
import sys
|
||||||
|
from unittest.mock import MagicMock, Mock, patch
|
||||||
|
|
||||||
|
import aiofiles.threadpool
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pytest
|
import pytest
|
||||||
import aiofiles.threadpool
|
|
||||||
|
|
||||||
|
|
||||||
def cleanup_mock_dirs():
|
def cleanup_mock_dirs():
|
||||||
|
|
|
@ -5,7 +5,7 @@ from unittest.mock import patch
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from api.src.services.audio import AudioService, AudioNormalizer
|
from api.src.services.audio import AudioNormalizer, AudioService
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(autouse=True)
|
@pytest.fixture(autouse=True)
|
||||||
|
|
|
@ -1,10 +1,10 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
from unittest.mock import Mock, AsyncMock
|
from unittest.mock import AsyncMock, Mock
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import pytest_asyncio
|
import pytest_asyncio
|
||||||
from httpx import AsyncClient
|
|
||||||
from fastapi.testclient import TestClient
|
from fastapi.testclient import TestClient
|
||||||
|
from httpx import AsyncClient
|
||||||
|
|
||||||
from ..src.main import app
|
from ..src.main import app
|
||||||
|
|
||||||
|
|
|
@ -7,8 +7,8 @@ import pytest
|
||||||
import pytest_asyncio
|
import pytest_asyncio
|
||||||
from httpx import AsyncClient
|
from httpx import AsyncClient
|
||||||
|
|
||||||
from .conftest import MockTTSModel
|
|
||||||
from ..src.main import app
|
from ..src.main import app
|
||||||
|
from .conftest import MockTTSModel
|
||||||
|
|
||||||
|
|
||||||
@pytest_asyncio.fixture
|
@pytest_asyncio.fixture
|
||||||
|
|
|
@ -1,15 +1,15 @@
|
||||||
"""Tests for TTS model implementations"""
|
"""Tests for TTS model implementations"""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
from unittest.mock import MagicMock, patch, AsyncMock
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
|
||||||
import pytest
|
import pytest
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from api.src.services.tts_base import TTSBaseModel
|
||||||
from api.src.services.tts_cpu import TTSCPUModel
|
from api.src.services.tts_cpu import TTSCPUModel
|
||||||
from api.src.services.tts_gpu import TTSGPUModel, length_to_mask
|
from api.src.services.tts_gpu import TTSGPUModel, length_to_mask
|
||||||
from api.src.services.tts_base import TTSBaseModel
|
|
||||||
|
|
||||||
|
|
||||||
# Base Model Tests
|
# Base Model Tests
|
||||||
|
|
|
@ -4,8 +4,8 @@ import os
|
||||||
from unittest.mock import MagicMock, call, patch
|
from unittest.mock import MagicMock, call, patch
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
|
||||||
import pytest
|
import pytest
|
||||||
|
import torch
|
||||||
from onnxruntime import InferenceSession
|
from onnxruntime import InferenceSession
|
||||||
|
|
||||||
from api.src.core.config import settings
|
from api.src.core.config import settings
|
||||||
|
|
|
@ -54,7 +54,7 @@ def test_model_column_default_values():
|
||||||
|
|
||||||
def test_model_column_no_voices():
|
def test_model_column_no_voices():
|
||||||
"""Test model column creation with no voice IDs"""
|
"""Test model column creation with no voice IDs"""
|
||||||
_, components = create_model_column()
|
_, components = create_model_column([])
|
||||||
|
|
||||||
assert components["voice"].choices == []
|
assert components["voice"].choices == []
|
||||||
assert components["voice"].value is None
|
assert components["voice"].value is None
|
||||||
|
@ -96,7 +96,7 @@ def test_output_column_configuration():
|
||||||
|
|
||||||
# Test output files dropdown
|
# Test output files dropdown
|
||||||
assert components["output_files"].label == "Previous Outputs"
|
assert components["output_files"].label == "Previous Outputs"
|
||||||
assert components["output_files"].allow_custom_value is False
|
assert components["output_files"].allow_custom_value is True
|
||||||
|
|
||||||
# Test play button
|
# Test play button
|
||||||
assert components["play_btn"].value == "▶️ Play Selected"
|
assert components["play_btn"].value == "▶️ Play Selected"
|
||||||
|
|
Loading…
Add table
Reference in a new issue