mirror of https://github.com/remsky/Kokoro-FastAPI.git
synced 2025-08-05 16:48:53 +00:00

Ruff format + fix

This commit is contained in:
parent f6e3afa14c
commit e8c1284032
31 changed files with 927 additions and 624 deletions
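The diffs below are mechanical style changes of the kind a Ruff pass produces: double-quoted strings, trailing commas, wrapped call arguments, and reordered imports. As a minimal sketch of reproducing such a pass, assuming Ruff is installed and using a hypothetical "api" target directory (neither the invocation nor the path is recorded in this commit):

    import subprocess

    # Hypothetical "format + fix" pass; the "api" path is an assumption for illustration.
    subprocess.run(["ruff", "check", "--fix", "api"], check=True)  # auto-fix lint findings (e.g. import order)
    subprocess.run(["ruff", "format", "api"], check=True)  # reformat: quotes, wrapping, trailing commas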
@@ -2,8 +2,8 @@
FastAPI OpenAI Compatible API
"""

from contextlib import asynccontextmanager
import sys
from contextlib import asynccontextmanager

import uvicorn
from loguru import logger
@@ -12,9 +12,9 @@ from fastapi.middleware.cors import CORSMiddleware

from .core.config import settings
from .services.tts_model import TTSModel
from .routers.development import router as dev_router
from .services.tts_service import TTSService
from .routers.openai_compatible import router as openai_router
from .routers.development import router as dev_router


def setup_logger():
@@ -27,22 +27,18 @@ def setup_logger():
                "{level: <8} | "
                "{message}",
                "colorize": True,
                "level": "INFO"
                "level": "INFO",
            },
        ],
    }
    # Remove default logger
    logger.remove()
    # Add our custom logger
    logger.configure(**config)
    # Override error colors
    logger.level("ERROR", color="<red>")


# Configure logger
setup_logger()


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Lifespan context manager for model initialization"""
@@ -52,7 +48,7 @@ async def lifespan(app: FastAPI):
    voicepack_count = await TTSModel.setup()
    # boundary = "█████╗"*9
    boundary = "░" * 24
    startup_msg =f"""
    startup_msg = f"""

{boundary}
@@ -1,24 +1,30 @@
from typing import List
from loguru import logger
from fastapi import APIRouter, HTTPException, Depends, Response
from ..structures.text_schemas import PhonemeRequest, PhonemeResponse, GenerateFromPhonemesRequest
from ..services.text_processing import phonemize, tokenize
from ..services.audio import AudioService
from ..services.tts_service import TTSService
from ..services.tts_model import TTSModel

import numpy as np
from loguru import logger
from fastapi import Depends, Response, APIRouter, HTTPException

from ..services.audio import AudioService
from ..services.tts_model import TTSModel
from ..services.tts_service import TTSService
from ..structures.text_schemas import (
    PhonemeRequest,
    PhonemeResponse,
    GenerateFromPhonemesRequest,
)
from ..services.text_processing import tokenize, phonemize

router = APIRouter(tags=["text processing"])


def get_tts_service() -> TTSService:
    """Dependency to get TTSService instance"""
    return TTSService()


@router.post("/text/phonemize", response_model=PhonemeResponse, tags=["deprecated"])
@router.post("/dev/phonemize", response_model=PhonemeResponse)
async def phonemize_text(
    request: PhonemeRequest
) -> PhonemeResponse:
async def phonemize_text(request: PhonemeRequest) -> PhonemeResponse:
    """Convert text to phonemes and tokens

    Args:
@@ -41,28 +47,24 @@ async def phonemize_text(
        tokens = tokenize(phonemes)
        tokens = [0] + tokens + [0]  # Add start/end tokens

        return PhonemeResponse(
            phonemes=phonemes,
            tokens=tokens
        )
        return PhonemeResponse(phonemes=phonemes, tokens=tokens)
    except ValueError as e:
        logger.error(f"Error in phoneme generation: {str(e)}")
        raise HTTPException(
            status_code=500,
            detail={"error": "Server error", "message": str(e)}
            status_code=500, detail={"error": "Server error", "message": str(e)}
        )
    except Exception as e:
        logger.error(f"Error in phoneme generation: {str(e)}")
        raise HTTPException(
            status_code=500,
            detail={"error": "Server error", "message": str(e)}
            status_code=500, detail={"error": "Server error", "message": str(e)}
        )


@router.post("/text/generate_from_phonemes", tags=["deprecated"])
@router.post("/dev/generate_from_phonemes")
async def generate_from_phonemes(
    request: GenerateFromPhonemesRequest,
    tts_service: TTSService = Depends(get_tts_service)
    tts_service: TTSService = Depends(get_tts_service),
) -> Response:
    """Generate audio directly from phonemes

@@ -77,7 +79,7 @@ async def generate_from_phonemes(
    if not request.phonemes:
        raise HTTPException(
            status_code=400,
            detail={"error": "Invalid request", "message": "Phonemes cannot be empty"}
            detail={"error": "Invalid request", "message": "Phonemes cannot be empty"},
        )

    # Validate voice exists
@@ -85,7 +87,10 @@ async def generate_from_phonemes(
    if not voice_path:
        raise HTTPException(
            status_code=400,
            detail={"error": "Invalid request", "message": f"Voice not found: {request.voice}"}
            detail={
                "error": "Invalid request",
                "message": f"Voice not found: {request.voice}",
            },
        )

    try:
@@ -101,12 +106,7 @@ async def generate_from_phonemes(

        # Convert to WAV bytes
        wav_bytes = AudioService.convert_audio(
            audio,
            24000,
            "wav",
            is_first_chunk=True,
            is_last_chunk=True,
            stream=False
            audio, 24000, "wav", is_first_chunk=True, is_last_chunk=True, stream=False
        )

        return Response(
@@ -115,18 +115,16 @@ async def generate_from_phonemes(
            headers={
                "Content-Disposition": "attachment; filename=speech.wav",
                "Cache-Control": "no-cache",
            }
            },
        )

    except ValueError as e:
        logger.error(f"Invalid request: {str(e)}")
        raise HTTPException(
            status_code=400,
            detail={"error": "Invalid request", "message": str(e)}
            status_code=400, detail={"error": "Invalid request", "message": str(e)}
        )
    except Exception as e:
        logger.error(f"Error generating audio: {str(e)}")
        raise HTTPException(
            status_code=500,
            detail={"error": "Server error", "message": str(e)}
            status_code=500, detail={"error": "Server error", "message": str(e)}
        )
@@ -1,13 +1,12 @@
from typing import List, Union
from typing import List, Union, AsyncGenerator

from loguru import logger
from fastapi import Depends, Response, APIRouter, HTTPException
from fastapi import Header
from fastapi import Header, Depends, Response, APIRouter, HTTPException
from fastapi.responses import StreamingResponse
from ..services.tts_service import TTSService

from ..services.audio import AudioService
from ..structures.schemas import OpenAISpeechRequest
from typing import AsyncGenerator
from ..services.tts_service import TTSService

router = APIRouter(
    tags=["OpenAI Compatible TTS"],
@@ -20,7 +19,9 @@ def get_tts_service() -> TTSService:
    return TTSService()  # Initialize TTSService with default settings


async def process_voices(voice_input: Union[str, List[str]], tts_service: TTSService) -> str:
async def process_voices(
    voice_input: Union[str, List[str]], tts_service: TTSService
) -> str:
    """Process voice input into a combined voice, handling both string and list formats"""
    # Convert input to list of voices
    if isinstance(voice_input, str):
@@ -35,7 +36,9 @@ async def process_voices(voice_input: Union[str, List[str]], tts_service: TTSSer
    available_voices = await tts_service.list_voices()
    for voice in voices:
        if voice not in available_voices:
            raise ValueError(f"Voice '{voice}' not found. Available voices: {', '.join(sorted(available_voices))}")
            raise ValueError(
                f"Voice '{voice}' not found. Available voices: {', '.join(sorted(available_voices))}"
            )

    # If single voice, return it directly
    if len(voices) == 1:
@@ -45,14 +48,16 @@ async def process_voices(voice_input: Union[str, List[str]], tts_service: TTSSer
    return await tts_service.combine_voices(voices=voices)


async def stream_audio_chunks(tts_service: TTSService, request: OpenAISpeechRequest) -> AsyncGenerator[bytes, None]:
async def stream_audio_chunks(
    tts_service: TTSService, request: OpenAISpeechRequest
) -> AsyncGenerator[bytes, None]:
    """Stream audio chunks as they're generated"""
    voice_to_use = await process_voices(request.voice, tts_service)
    async for chunk in tts_service.generate_audio_stream(
        text=request.input,
        voice=voice_to_use,
        speed=request.speed,
        output_format=request.response_format
        output_format=request.response_format,
    ):
        yield chunk

@@ -101,11 +106,8 @@ async def create_speech(

        # Convert to requested format
        content = AudioService.convert_audio(
            audio,
            24000,
            request.response_format,
            is_first_chunk=True,
            stream=False)
            audio, 24000, request.response_format, is_first_chunk=True, stream=False
        )

        return Response(
            content=content,
@@ -6,17 +6,22 @@ import numpy as np
import soundfile as sf
import scipy.io.wavfile as wavfile
from loguru import logger

from ..core.config import settings


class AudioNormalizer:
    """Handles audio normalization state for a single stream"""

    def __init__(self):
        self.int16_max = np.iinfo(np.int16).max
        self.chunk_trim_ms = settings.gap_trim_ms
        self.sample_rate = 24000  # Sample rate of the audio
        self.samples_to_trim = int(self.chunk_trim_ms * self.sample_rate / 1000)

    def normalize(self, audio_data: np.ndarray, is_last_chunk: bool = False) -> np.ndarray:
    def normalize(
        self, audio_data: np.ndarray, is_last_chunk: bool = False
    ) -> np.ndarray:
        """Normalize audio data to int16 range and trim chunk boundaries"""
        # Convert to float32 if not already
        audio_float = audio_data.astype(np.float32)
@@ -27,11 +32,12 @@ class AudioNormalizer:

        # Trim end of non-final chunks to reduce gaps
        if not is_last_chunk and len(audio_float) > self.samples_to_trim:
            audio_float = audio_float[:-self.samples_to_trim]
            audio_float = audio_float[: -self.samples_to_trim]

        # Scale to int16 range
        return (audio_float * self.int16_max).astype(np.int16)


class AudioService:
    """Service for audio format conversions"""

@@ -46,7 +52,7 @@ class AudioService:
        },
        "flac": {
            "compression_level": 0.0,  # Light compression, still fast
        }
        },
    }

    @staticmethod
@@ -58,7 +64,7 @@ class AudioService:
        is_last_chunk: bool = False,
        normalizer: AudioNormalizer = None,
        format_settings: dict = None,
        stream: bool = True
        stream: bool = True,
    ) -> bytes:
        """Convert audio data to specified format

@@ -90,37 +96,55 @@ class AudioService:
        # Always normalize audio to ensure proper amplitude scaling
        if normalizer is None:
            normalizer = AudioNormalizer()
        normalized_audio = normalizer.normalize(audio_data, is_last_chunk=is_last_chunk)
        normalized_audio = normalizer.normalize(
            audio_data, is_last_chunk=is_last_chunk
        )

        if output_format == "pcm":
            # Raw 16-bit PCM samples, no header
            buffer.write(normalized_audio.tobytes())
        elif output_format == "wav":
            # Always use soundfile for WAV to ensure proper headers and normalization
            sf.write(buffer, normalized_audio, sample_rate, format="WAV", subtype='PCM_16')
            sf.write(
                buffer,
                normalized_audio,
                sample_rate,
                format="WAV",
                subtype="PCM_16",
            )
        elif output_format == "mp3":
            # Use format settings or defaults
            settings = format_settings.get("mp3", {}) if format_settings else {}
            settings = {**AudioService.DEFAULT_SETTINGS["mp3"], **settings}
            sf.write(
                buffer, normalized_audio,
                sample_rate, format="MP3",
                **settings
                buffer, normalized_audio, sample_rate, format="MP3", **settings
            )

        elif output_format == "opus":
            settings = format_settings.get("opus", {}) if format_settings else {}
            settings = {**AudioService.DEFAULT_SETTINGS["opus"], **settings}
            sf.write(buffer, normalized_audio, sample_rate, format="OGG",
                     subtype="OPUS", **settings)
            sf.write(
                buffer,
                normalized_audio,
                sample_rate,
                format="OGG",
                subtype="OPUS",
                **settings,
            )

        elif output_format == "flac":
            if is_first_chunk:
                logger.info("Starting FLAC stream...")
            settings = format_settings.get("flac", {}) if format_settings else {}
            settings = {**AudioService.DEFAULT_SETTINGS["flac"], **settings}
            sf.write(buffer, normalized_audio, sample_rate, format="FLAC",
                     subtype='PCM_16', **settings)
            sf.write(
                buffer,
                normalized_audio,
                sample_rate,
                format="FLAC",
                subtype="PCM_16",
                **settings,
            )
        else:
            if output_format == "aac":
                raise ValueError(
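The AudioNormalizer hunks above only reflow the code; the trim arithmetic is unchanged. A self-contained sketch of that calculation, using gap_trim_ms = 250 (the value the test fixture later in this commit assumes) at the class's fixed 24 kHz sample rate:

    import numpy as np

    sample_rate = 24000  # fixed in AudioNormalizer.__init__
    chunk_trim_ms = 250  # stands in for settings.gap_trim_ms; value taken from the test fixture
    samples_to_trim = int(chunk_trim_ms * sample_rate / 1000)  # 6000 samples
    chunk = np.zeros(sample_rate, dtype=np.float32)  # one second of audio
    trimmed = chunk[: -samples_to_trim]  # non-final chunks lose 0.25 s at the end
    assert len(trimmed) == 18000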
@@ -1,13 +1,13 @@
from .normalizer import normalize_text
from .phonemizer import phonemize, PhonemizerBackend, EspeakBackend
from .vocabulary import tokenize, decode_tokens, VOCAB
from .phonemizer import EspeakBackend, PhonemizerBackend, phonemize
from .vocabulary import VOCAB, tokenize, decode_tokens

__all__ = [
    'normalize_text',
    'phonemize',
    'tokenize',
    'decode_tokens',
    'VOCAB',
    'PhonemizerBackend',
    'EspeakBackend'
    "normalize_text",
    "phonemize",
    "tokenize",
    "decode_tokens",
    "VOCAB",
    "PhonemizerBackend",
    "EspeakBackend",
]
@@ -1,6 +1,7 @@
"""Text chunking service"""

import re

from ...core.config import settings
@@ -9,19 +9,58 @@ from functools import lru_cache

# Constants
VALID_TLDS = [
    "com", "org", "net", "edu", "gov", "mil", "int", "biz", "info", "name",
    "pro", "coop", "museum", "travel", "jobs", "mobi", "tel", "asia", "cat",
    "xxx", "aero", "arpa", "bg", "br", "ca", "cn", "de", "es", "eu", "fr",
    "in", "it", "jp", "mx", "nl", "ru", "uk", "us", "io"
    "com",
    "org",
    "net",
    "edu",
    "gov",
    "mil",
    "int",
    "biz",
    "info",
    "name",
    "pro",
    "coop",
    "museum",
    "travel",
    "jobs",
    "mobi",
    "tel",
    "asia",
    "cat",
    "xxx",
    "aero",
    "arpa",
    "bg",
    "br",
    "ca",
    "cn",
    "de",
    "es",
    "eu",
    "fr",
    "in",
    "it",
    "jp",
    "mx",
    "nl",
    "ru",
    "uk",
    "us",
    "io",
]

# Pre-compiled regex patterns for performance
EMAIL_PATTERN = re.compile(r"\b[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-z]{2,}\b", re.IGNORECASE)
URL_PATTERN = re.compile(
    r"(https?://|www\.|)+(localhost|[a-zA-Z0-9.-]+(\.(?:" +
    "|".join(VALID_TLDS) + "))+|[0-9]{1,3}\.[0-9]{1,3}\.[0-9]{1,3}\.[0-9]{1,3})(:[0-9]+)?([/?][^\s]*)?",
    re.IGNORECASE
EMAIL_PATTERN = re.compile(
    r"\b[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-z]{2,}\b", re.IGNORECASE
)
URL_PATTERN = re.compile(
    r"(https?://|www\.|)+(localhost|[a-zA-Z0-9.-]+(\.(?:"
    + "|".join(VALID_TLDS)
    + "))+|[0-9]{1,3}\.[0-9]{1,3}\.[0-9]{1,3}\.[0-9]{1,3})(:[0-9]+)?([/?][^\s]*)?",
    re.IGNORECASE,
)


def split_num(num: re.Match[str]) -> str:
    """Handle number splitting for various formats"""
@@ -47,6 +86,7 @@ def split_num(num: re.Match[str]) -> str:
        return f"{left} oh {right}{s}"
    return f"{left} {right}{s}"


def handle_money(m: re.Match[str]) -> str:
    """Convert money expressions to spoken form"""
    m = m.group()
@@ -66,21 +106,24 @@ def handle_money(m: re.Match[str]) -> str:
    )
    return f"{b} {bill}{s} and {c} {coins}"


def handle_decimal(num: re.Match[str]) -> str:
    """Convert decimal numbers to spoken form"""
    a, b = num.group().split(".")
    return " point ".join([a, " ".join(b)])


def handle_email(m: re.Match[str]) -> str:
    """Convert email addresses into speakable format"""
    email = m.group(0)
    parts = email.split('@')
    parts = email.split("@")
    if len(parts) == 2:
        user, domain = parts
        domain = domain.replace('.', ' dot ')
        domain = domain.replace(".", " dot ")
        return f"{user} at {domain}"
    return email


def handle_url(u: re.Match[str]) -> str:
    """Make URLs speakable by converting special characters to spoken words"""
    if not u:
@@ -89,19 +132,24 @@ def handle_url(u: re.Match[str]) -> str:
    url = u.group(0).strip()

    # Handle protocol first
    url = re.sub(r'^https?://', lambda a: 'https ' if 'https' in a.group() else 'http ', url, flags=re.IGNORECASE)
    url = re.sub(r'^www\.', 'www ', url, flags=re.IGNORECASE)
    url = re.sub(
        r"^https?://",
        lambda a: "https " if "https" in a.group() else "http ",
        url,
        flags=re.IGNORECASE,
    )
    url = re.sub(r"^www\.", "www ", url, flags=re.IGNORECASE)

    # Handle port numbers before other replacements
    url = re.sub(r':(\d+)(?=/|$)', lambda m: f" colon {m.group(1)}", url)
    url = re.sub(r":(\d+)(?=/|$)", lambda m: f" colon {m.group(1)}", url)

    # Split into domain and path
    parts = url.split('/', 1)
    parts = url.split("/", 1)
    domain = parts[0]
    path = parts[1] if len(parts) > 1 else ''
    path = parts[1] if len(parts) > 1 else ""

    # Handle dots in domain
    domain = domain.replace('.', ' dot ')
    domain = domain.replace(".", " dot ")

    # Reconstruct URL
    if path:
@@ -120,7 +168,7 @@ def handle_url(u: re.Match[str]) -> str:
    url = url.replace("/", " slash ")  # Handle any remaining slashes

    # Clean up extra spaces
    return re.sub(r'\s+', ' ', url).strip()
    return re.sub(r"\s+", " ", url).strip()


def normalize_urls(text: str) -> str:
@@ -133,6 +181,7 @@ def normalize_urls(text: str) -> str:

    return text


def normalize_text(text: str) -> str:
    """Normalize text for TTS processing"""
    # Pre-process URLs first
@@ -165,9 +214,7 @@ def normalize_text(text: str) -> str:

    # Handle numbers and money
    text = re.sub(
        r"\d*\.\d+|\b\d{4}s?\b|(?<!:)\b(?:[1-9]|1[0-2]):[0-5]\d\b(?!:)",
        split_num,
        text
        r"\d*\.\d+|\b\d{4}s?\b|(?<!:)\b(?:[1-9]|1[0-2]):[0-5]\d\b(?!:)", split_num, text
    )
    text = re.sub(r"(?<=\d),(?=\d)", "", text)
    text = re.sub(
@@ -183,9 +230,7 @@ def normalize_text(text: str) -> str:
    text = re.sub(r"(?<=[BCDFGHJ-NP-TV-Z])'?s\b", "'S", text)
    text = re.sub(r"(?<=X')S\b", "s", text)
    text = re.sub(
        r"(?:[A-Za-z]\.){2,} [a-z]",
        lambda m: m.group().replace(".", "-"),
        text
        r"(?:[A-Za-z]\.){2,} [a-z]", lambda m: m.group().replace(".", "-"), text
    )
    text = re.sub(r"(?i)(?<=[A-Z])\.(?=[A-Z])", "-", text)
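The EMAIL_PATTERN/URL_PATTERN and handler changes above are likewise purely presentational. A minimal sketch of the re.sub-with-callable pattern these handlers rely on (the exact wiring inside normalize_urls is outside the hunks shown, so treat this as illustrative):

    import re

    EMAIL = re.compile(r"\b[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-z]{2,}\b", re.IGNORECASE)

    def speak_email(m: re.Match) -> str:
        # Same transformation as handle_email above, on a toy pattern name
        user, domain = m.group(0).split("@", 1)
        return f"{user} at {domain.replace('.', ' dot ')}"

    print(EMAIL.sub(speak_email, "Mail me at test@example.com"))
    # -> Mail me at test at example dot com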
@@ -1,8 +1,11 @@
import re
from abc import ABC, abstractmethod

import phonemizer

from .normalizer import normalize_text


class PhonemizerBackend(ABC):
    """Abstract base class for phonemization backends"""

@@ -18,6 +21,7 @@ class PhonemizerBackend(ABC):
        """
        pass


class EspeakBackend(PhonemizerBackend):
    """Espeak-based phonemizer implementation"""

@@ -28,9 +32,7 @@ class EspeakBackend(PhonemizerBackend):
            language: Language code ('en-us' or 'en-gb')
        """
        self.backend = phonemizer.backend.EspeakBackend(
            language=language,
            preserve_punctuation=True,
            with_stress=True
            language=language, preserve_punctuation=True, with_stress=True
        )
        self.language = language

@@ -59,6 +61,7 @@ class EspeakBackend(PhonemizerBackend):

        return ps.strip()


def create_phonemizer(language: str = "a") -> PhonemizerBackend:
    """Factory function to create phonemizer backend

@@ -69,16 +72,14 @@ def create_phonemizer(language: str = "a") -> PhonemizerBackend:
        Phonemizer backend instance
    """
    # Map language codes to espeak language codes
    lang_map = {
        "a": "en-us",
        "b": "en-gb"
    }
    lang_map = {"a": "en-us", "b": "en-gb"}

    if language not in lang_map:
        raise ValueError(f"Unsupported language code: {language}")

    return EspeakBackend(lang_map[language])


def phonemize(text: str, language: str = "a", normalize: bool = True) -> str:
    """Convert text to phonemes

@@ -9,9 +9,11 @@ def get_vocab():
    symbols = [_pad] + list(_punctuation) + list(_letters) + list(_letters_ipa)
    return {symbol: i for i, symbol in enumerate(symbols)}


# Initialize vocabulary
VOCAB = get_vocab()


def tokenize(phonemes: str) -> list[int]:
    """Convert phonemes string to token IDs

@@ -23,6 +25,7 @@ def tokenize(phonemes: str) -> list[int]:
    """
    return [i for i in map(VOCAB.get, phonemes) if i is not None]


def decode_tokens(tokens: list[int]) -> str:
    """Convert token IDs back to phonemes string

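The tokenize/decode_tokens pair above is a straight vocabulary lookup; unknown symbols map to None through dict.get and are dropped. A self-contained round-trip sketch with a toy vocabulary (the real symbol set is built by get_vocab() from the module's symbol lists, which are outside this hunk):

    TOY_VOCAB = {symbol: i for i, symbol in enumerate("$abc")}  # stand-in for VOCAB
    REVERSE = {i: s for s, i in TOY_VOCAB.items()}

    def toy_tokenize(phonemes: str) -> list[int]:
        return [i for i in map(TOY_VOCAB.get, phonemes) if i is not None]

    def toy_decode(tokens: list[int]) -> str:
        return "".join(REVERSE[t] for t in tokens)

    assert toy_decode(toy_tokenize("abc?")) == "abc"  # "?" is not in the toy vocab and is dropped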
@@ -2,12 +2,14 @@ import os
import threading
from abc import ABC, abstractmethod
from typing import List, Tuple
import torch

import numpy as np
import torch
from loguru import logger

from ..core.config import settings


class TTSBaseModel(ABC):
    _instance = None
    _lock = threading.Lock()
@@ -26,7 +28,9 @@ class TTSBaseModel(ABC):
                # Test CUDA device
                test_tensor = torch.zeros(1).cuda()
                logger.info("CUDA test successful")
                model_path = os.path.join(settings.model_dir, settings.pytorch_model_path)
                model_path = os.path.join(
                    settings.model_dir, settings.pytorch_model_path
                )
                cls._device = "cuda"
            except Exception as e:
                logger.error(f"CUDA test failed: {e}")
@@ -54,19 +58,35 @@ class TTSBaseModel(ABC):
                voice_path = os.path.join(cls.VOICES_DIR, file)
                if not os.path.exists(voice_path):
                    try:
                        logger.info(f"Copying base voice {voice_name} to voices directory")
                        logger.info(
                            f"Copying base voice {voice_name} to voices directory"
                        )
                        base_path = os.path.join(base_voices_dir, file)
                        voicepack = torch.load(base_path, map_location=cls._device, weights_only=True)
                        voicepack = torch.load(
                            base_path,
                            map_location=cls._device,
                            weights_only=True,
                        )
                        torch.save(voicepack, voice_path)
                    except Exception as e:
                        logger.error(f"Error copying voice {voice_name}: {str(e)}")
                        logger.error(
                            f"Error copying voice {voice_name}: {str(e)}"
                        )

        # Count voices in directory
        voice_count = len([f for f in os.listdir(cls.VOICES_DIR) if f.endswith(".pt")])
        voice_count = len(
            [f for f in os.listdir(cls.VOICES_DIR) if f.endswith(".pt")]
        )

        # Now that model and voices are ready, do warmup
        try:
            with open(os.path.join(os.path.dirname(os.path.dirname(__file__)), "core", "don_quixote.txt")) as f:
            with open(
                os.path.join(
                    os.path.dirname(os.path.dirname(__file__)),
                    "core",
                    "don_quixote.txt",
                )
            ) as f:
                warmup_text = f.read()
        except Exception as e:
            logger.warning(f"Failed to load warmup text: {e}")
@@ -74,6 +94,7 @@ class TTSBaseModel(ABC):

        # Use warmup service after model is fully initialized
        from .warmup import WarmupService

        warmup = WarmupService()

        # Load and warm up voices
@@ -83,7 +104,9 @@ class TTSBaseModel(ABC):
        logger.info("Model warm-up complete")

        # Count voices in directory
        voice_count = len([f for f in os.listdir(cls.VOICES_DIR) if f.endswith(".pt")])
        voice_count = len(
            [f for f in os.listdir(cls.VOICES_DIR) if f.endswith(".pt")]
        )
        return voice_count

    @classmethod
@@ -108,7 +131,9 @@ class TTSBaseModel(ABC):

    @classmethod
    @abstractmethod
    def generate_from_text(cls, text: str, voicepack: torch.Tensor, language: str, speed: float) -> Tuple[np.ndarray, str]:
    def generate_from_text(
        cls, text: str, voicepack: torch.Tensor, language: str, speed: float
    ) -> Tuple[np.ndarray, str]:
        """Generate audio from text

        Args:
@@ -124,7 +149,9 @@ class TTSBaseModel(ABC):

    @classmethod
    @abstractmethod
    def generate_from_tokens(cls, tokens: List[int], voicepack: torch.Tensor, speed: float) -> np.ndarray:
    def generate_from_tokens(
        cls, tokens: List[int], voicepack: torch.Tensor, speed: float
    ) -> np.ndarray:
        """Generate audio from tokens

        Args:
@@ -1,12 +1,19 @@
import os

import numpy as np
import torch
from onnxruntime import InferenceSession, SessionOptions, GraphOptimizationLevel, ExecutionMode
from loguru import logger
from onnxruntime import (
    ExecutionMode,
    SessionOptions,
    InferenceSession,
    GraphOptimizationLevel,
)

from .tts_base import TTSBaseModel
from .text_processing import phonemize, tokenize
from ..core.config import settings
from .text_processing import tokenize, phonemize


class TTSCPUModel(TTSBaseModel):
    _instance = None
@@ -41,11 +48,17 @@ class TTSCPUModel(TTSBaseModel):

            # Set optimization level
            if settings.onnx_optimization_level == "all":
                session_options.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL
                session_options.graph_optimization_level = (
                    GraphOptimizationLevel.ORT_ENABLE_ALL
                )
            elif settings.onnx_optimization_level == "basic":
                session_options.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_BASIC
                session_options.graph_optimization_level = (
                    GraphOptimizationLevel.ORT_ENABLE_BASIC
                )
            else:
                session_options.graph_optimization_level = GraphOptimizationLevel.ORT_DISABLE_ALL
                session_options.graph_optimization_level = (
                    GraphOptimizationLevel.ORT_DISABLE_ALL
                )

            # Configure threading
            session_options.intra_op_num_threads = settings.onnx_num_threads
@@ -63,17 +76,17 @@ class TTSCPUModel(TTSBaseModel):

            # Configure CPU provider options
            provider_options = {
                'CPUExecutionProvider': {
                    'arena_extend_strategy': settings.onnx_arena_extend_strategy,
                    'cpu_memory_arena_cfg': 'cpu:0'
                "CPUExecutionProvider": {
                    "arena_extend_strategy": settings.onnx_arena_extend_strategy,
                    "cpu_memory_arena_cfg": "cpu:0",
                }
            }

            session = InferenceSession(
                onnx_path,
                sess_options=session_options,
                providers=['CPUExecutionProvider'],
                provider_options=[provider_options]
                providers=["CPUExecutionProvider"],
                provider_options=[provider_options],
            )
            cls._onnx_session = session
            return session
@@ -96,7 +109,9 @@ class TTSCPUModel(TTSBaseModel):
        return phonemes, tokens

    @classmethod
    def generate_from_text(cls, text: str, voicepack: torch.Tensor, language: str, speed: float) -> tuple[np.ndarray, str]:
    def generate_from_text(
        cls, text: str, voicepack: torch.Tensor, language: str, speed: float
    ) -> tuple[np.ndarray, str]:
        """Generate audio from text

        Args:
@@ -120,7 +135,9 @@ class TTSCPUModel(TTSBaseModel):
        return audio, phonemes

    @classmethod
    def generate_from_tokens(cls, tokens: list[int], voicepack: torch.Tensor, speed: float) -> np.ndarray:
    def generate_from_tokens(
        cls, tokens: list[int], voicepack: torch.Tensor, speed: float
    ) -> np.ndarray:
        """Generate audio from tokens

        Args:
@@ -136,16 +153,15 @@ class TTSCPUModel(TTSBaseModel):

        # Pre-allocate and prepare inputs
        tokens_input = np.array([tokens], dtype=np.int64)
        style_input = voicepack[len(tokens)-2].numpy()  # Already has correct dimensions
        speed_input = np.full(1, speed, dtype=np.float32)  # More efficient than ones * speed
        style_input = voicepack[
            len(tokens) - 2
        ].numpy()  # Already has correct dimensions
        speed_input = np.full(
            1, speed, dtype=np.float32
        )  # More efficient than ones * speed

        # Run inference with optimized inputs
        result = cls._onnx_session.run(
            None,
            {
                'tokens': tokens_input,
                'style': style_input,
                'speed': speed_input
            }
            None, {"tokens": tokens_input, "style": style_input, "speed": speed_input}
        )
        return result[0]
@@ -1,13 +1,15 @@
import os
import time

import numpy as np
import torch
import time
from loguru import logger
from models import build_model
from .text_processing import phonemize, tokenize

from .tts_base import TTSBaseModel
from ..core.config import settings
from .text_processing import tokenize, phonemize


# @torch.no_grad()
# def forward(model, tokens, ref_s, speed):
@@ -65,7 +67,7 @@ def forward(model, tokens, ref_s, speed):
    pred_aln_trg = torch.zeros(input_lengths, pred_dur.sum().item(), device=device)
    c_frame = 0
    for i in range(pred_aln_trg.size(0)):
        pred_aln_trg[i, c_frame:c_frame + pred_dur[0, i].item()] = 1
        pred_aln_trg[i, c_frame : c_frame + pred_dur[0, i].item()] = 1
        c_frame += pred_dur[0, i].item()

    # Matrix multiplications - reuse unsqueezed tensor
@@ -79,6 +81,7 @@ def forward(model, tokens, ref_s, speed):

    return model.decoder(asr, F0_pred, N_pred, s_ref).squeeze().cpu().numpy()


# def length_to_mask(lengths):
#     """Create attention mask from lengths"""
#     mask = (
@@ -90,17 +93,21 @@ def forward(model, tokens, ref_s, speed):
#     mask = torch.gt(mask + 1, lengths.unsqueeze(1))
#     return mask


def length_to_mask(lengths):
    """Create attention mask from lengths - possibly optimized version"""
    max_len = lengths.max()
    # Create mask directly on the same device as lengths
    mask = torch.arange(max_len, device=lengths.device)[None, :].expand(lengths.shape[0], -1)
    mask = torch.arange(max_len, device=lengths.device)[None, :].expand(
        lengths.shape[0], -1
    )
    # Avoid type_as by using the correct dtype from the start
    if lengths.dtype != mask.dtype:
        mask = mask.to(dtype=lengths.dtype)
    # Fuse operations using broadcasting
    return mask + 1 > lengths[:, None]


class TTSGPUModel(TTSBaseModel):
    _instance = None
    _device = "cuda"
@@ -143,7 +150,9 @@ class TTSGPUModel(TTSBaseModel):
        return phonemes, tokens

    @classmethod
    def generate_from_text(cls, text: str, voicepack: torch.Tensor, language: str, speed: float) -> tuple[np.ndarray, str]:
    def generate_from_text(
        cls, text: str, voicepack: torch.Tensor, language: str, speed: float
    ) -> tuple[np.ndarray, str]:
        """Generate audio from text

        Args:
@@ -167,7 +176,9 @@ class TTSGPUModel(TTSBaseModel):
        return audio, phonemes

    @classmethod
    def generate_from_tokens(cls, tokens: list[int], voicepack: torch.Tensor, speed: float) -> np.ndarray:
    def generate_from_tokens(
        cls, tokens: list[int], voicepack: torch.Tensor, speed: float
    ) -> np.ndarray:
        """Generate audio from tokens

        Args:
@@ -1,5 +1,4 @@
import io
import aiofiles.os
import os
import re
import time
@@ -8,13 +7,14 @@ from functools import lru_cache

import numpy as np
import torch
import aiofiles.os
import scipy.io.wavfile as wavfile
from .text_processing import normalize_text, chunker
from loguru import logger

from ..core.config import settings
from .tts_model import TTSModel
from .audio import AudioService, AudioNormalizer
from .tts_model import TTSModel
from ..core.config import settings
from .text_processing import chunker, normalize_text


class TTSService:
@@ -26,7 +26,9 @@ class TTSService:
    @lru_cache(maxsize=3)  # Cache up to 3 most recently used voices
    def _load_voice(voice_path: str) -> torch.Tensor:
        """Load and cache a voice model"""
        return torch.load(voice_path, map_location=TTSModel.get_device(), weights_only=True)
        return torch.load(
            voice_path, map_location=TTSModel.get_device(), weights_only=True
        )

    def _get_voice_path(self, voice_name: str) -> Optional[str]:
        """Get the path to a voice file"""
@@ -37,7 +39,9 @@ class TTSService:
        self, text: str, voice: str, speed: float, stitch_long_output: bool = True
    ) -> Tuple[torch.Tensor, float]:
        """Generate complete audio and return with processing time"""
        audio, processing_time = self._generate_audio_internal(text, voice, speed, stitch_long_output)
        audio, processing_time = self._generate_audio_internal(
            text, voice, speed, stitch_long_output
        )
        return audio, processing_time

    def _generate_audio_internal(
@@ -72,7 +76,9 @@ class TTSService:
                    phonemes, tokens = TTSModel.process_text(chunk, voice[0])
                    chunks_data.append((chunk, tokens))
                except Exception as e:
                    logger.error(f"Failed to process chunk: '{chunk}'. Error: {str(e)}")
                    logger.error(
                        f"Failed to process chunk: '{chunk}'. Error: {str(e)}"
                    )
                    continue

            if not chunks_data:
@@ -82,20 +88,28 @@ class TTSService:
            audio_chunks = []
            for chunk, tokens in chunks_data:
                try:
                    chunk_audio = TTSModel.generate_from_tokens(tokens, voicepack, speed)
                    chunk_audio = TTSModel.generate_from_tokens(
                        tokens, voicepack, speed
                    )
                    if chunk_audio is not None:
                        audio_chunks.append(chunk_audio)
                    else:
                        logger.error(f"No audio generated for chunk: '{chunk}'")
                except Exception as e:
                    logger.error(f"Failed to generate audio for chunk: '{chunk}'. Error: {str(e)}")
                    logger.error(
                        f"Failed to generate audio for chunk: '{chunk}'. Error: {str(e)}"
                    )
                    continue

            if not audio_chunks:
                raise ValueError("No audio chunks were generated successfully")

            # Concatenate all chunks
            audio = np.concatenate(audio_chunks) if len(audio_chunks) > 1 else audio_chunks[0]
            audio = (
                np.concatenate(audio_chunks)
                if len(audio_chunks) > 1
                else audio_chunks[0]
            )
        else:
            # Process single chunk
            phonemes, tokens = TTSModel.process_text(text, voice[0])
@@ -109,7 +123,12 @@ class TTSService:
            raise

    async def generate_audio_stream(
        self, text: str, voice: str, speed: float, output_format: str = "wav", silent=False
        self,
        text: str,
        voice: str,
        speed: float,
        output_format: str = "wav",
        silent=False,
    ):
        """Generate and yield audio chunks as they're generated for real-time streaming"""
        try:
@@ -125,7 +144,9 @@ class TTSService:
            if not normalized:
                raise ValueError("Text is empty after preprocessing")
            text = str(normalized)
            logger.debug(f"Text preprocessing took: {(time.time() - preprocess_start)*1000:.1f}ms")
            logger.debug(
                f"Text preprocessing took: {(time.time() - preprocess_start)*1000:.1f}ms"
            )

            # Voice validation and loading
            voice_start = time.time()
@@ -133,7 +154,9 @@ class TTSService:
            if not voice_path:
                raise ValueError(f"Voice not found: {voice}")
            voicepack = self._load_voice(voice_path)
            logger.debug(f"Voice loading took: {(time.time() - voice_start)*1000:.1f}ms")
            logger.debug(
                f"Voice loading took: {(time.time() - voice_start)*1000:.1f}ms"
            )

            # Process chunks as they're generated
            is_first = True
@@ -149,7 +172,9 @@ class TTSService:
                try:
                    # Process text and generate audio
                    phonemes, tokens = TTSModel.process_text(current_chunk, voice[0])
                    chunk_audio = TTSModel.generate_from_tokens(tokens, voicepack, speed)
                    chunk_audio = TTSModel.generate_from_tokens(
                        tokens, voicepack, speed
                    )

                    if chunk_audio is not None:
                        # Convert chunk with proper header handling
@@ -159,7 +184,7 @@ class TTSService:
                            output_format,
                            is_first_chunk=is_first,
                            normalizer=stream_normalizer,
                            is_last_chunk=(next_chunk is None)  # Last if no next chunk
                            is_last_chunk=(next_chunk is None),  # Last if no next chunk
                        )

                        yield chunk_bytes
@@ -168,7 +193,9 @@ class TTSService:
                        logger.error(f"No audio generated for chunk: '{current_chunk}'")

                except Exception as e:
                    logger.error(f"Failed to generate audio for chunk: '{current_chunk}'. Error: {str(e)}")
                    logger.error(
                        f"Failed to generate audio for chunk: '{current_chunk}'. Error: {str(e)}"
                    )

                current_chunk = next_chunk  # Move to next chunk
@@ -1,10 +1,11 @@
import os
from typing import List, Tuple

import torch
from loguru import logger

from .tts_service import TTSService
from .tts_model import TTSModel
from .tts_service import TTSService
from ..core.config import settings

@@ -22,18 +23,19 @@ class WarmupService:
        """Load and cache voices up to LRU limit"""
        # Get all voices sorted by filename length (shorter names first, usually base voices)
        voice_files = sorted(
            [f for f in os.listdir(TTSModel.VOICES_DIR) if f.endswith(".pt")],
            key=len
            [f for f in os.listdir(TTSModel.VOICES_DIR) if f.endswith(".pt")], key=len
        )

        n_voices_cache=1
        n_voices_cache = 1
        loaded_voices = []
        for voice_file in voice_files[:n_voices_cache]:
            try:
                voice_path = os.path.join(TTSModel.VOICES_DIR, voice_file)
                # load using service, lru cache
                voicepack = self.tts_service._load_voice(voice_path)
                loaded_voices.append((voice_file[:-3], voicepack))  # Store name and tensor
                loaded_voices.append(
                    (voice_file[:-3], voicepack)
                )  # Store name and tensor
                # voicepack = torch.load(voice_path, map_location=TTSModel.get_device(), weights_only=True)
                # logger.info(f"Loaded voice {voice_file[:-3]} into cache")
            except Exception as e:
@@ -41,17 +43,16 @@ class WarmupService:
        logger.info(f"Pre-loaded {len(loaded_voices)} voices into cache")
        return loaded_voices

    async def warmup_voices(self, warmup_text: str, loaded_voices: List[Tuple[str, torch.Tensor]]):
    async def warmup_voices(
        self, warmup_text: str, loaded_voices: List[Tuple[str, torch.Tensor]]
    ):
        """Warm up voice inference and streaming"""
        n_warmups = 1
        for voice_name, _ in loaded_voices[:n_warmups]:
            try:
                logger.info(f"Running warmup inference on voice {voice_name}")
                async for _ in self.tts_service.generate_audio_stream(
                    warmup_text,
                    voice_name,
                    1.0,
                    "pcm"
                    warmup_text, voice_name, 1.0, "pcm"
                ):
                    pass  # Process all chunks to properly warm up
                logger.info(f"Completed warmup for voice {voice_name}")
@@ -1,14 +1,15 @@
from enum import Enum
from typing import Literal, Union, List
from typing import List, Union, Literal

from pydantic import Field, BaseModel


class VoiceCombineRequest(BaseModel):
    """Request schema for voice combination endpoint that accepts either a string with + or a list"""

    voices: Union[str, List[str]] = Field(
        ...,
        description="Either a string with voices separated by + (e.g. 'voice1+voice2') or a list of voice names to combine"
        description="Either a string with voices separated by + (e.g. 'voice1+voice2') or a list of voice names to combine",
    )
@@ -1,14 +1,19 @@
from pydantic import BaseModel, Field
from pydantic import Field, BaseModel


class PhonemeRequest(BaseModel):
    text: str
    language: str = "a"  # Default to American English


class PhonemeResponse(BaseModel):
    phonemes: str
    tokens: list[int]


class GenerateFromPhonemesRequest(BaseModel):
    phonemes: str
    voice: str = Field(..., description="Voice ID to use for generation")
    speed: float = Field(default=1.0, ge=0.1, le=5.0, description="Speed factor for generation")
    speed: float = Field(
        default=1.0, ge=0.1, le=5.0, description="Speed factor for generation"
    )
@@ -1,9 +1,9 @@
import os
import sys
import shutil
from unittest.mock import Mock, patch, MagicMock
import numpy as np
from unittest.mock import Mock, MagicMock, patch

import numpy as np
import pytest
import aiofiles.threadpool

@@ -37,6 +37,7 @@ mock_torch = Mock()
mock_torch.cuda = Mock()
mock_torch.cuda.is_available = Mock(return_value=False)


# Create a mock tensor class that supports basic operations
class MockTensor:
    def __init__(self, data):
@@ -46,7 +47,7 @@ class MockTensor:
        elif isinstance(data, MockTensor):
            self.shape = data.shape
        else:
            self.shape = getattr(data, 'shape', [1])
            self.shape = getattr(data, "shape", [1])

    def __getitem__(self, idx):
        if isinstance(self.data, (list, tuple)):
@@ -91,9 +92,12 @@ class MockTensor:
    def type_as(self, other):
        return self


# Add tensor operations to mock torch
mock_torch.tensor = lambda x: MockTensor(x)
mock_torch.zeros = lambda *args: MockTensor([0] * (args[0] if isinstance(args[0], int) else args[0][0]))
mock_torch.zeros = lambda *args: MockTensor(
    [0] * (args[0] if isinstance(args[0], int) else args[0][0])
)
mock_torch.arange = lambda x: MockTensor(list(range(x)))
mock_torch.gt = lambda x, y: MockTensor([False] * x.shape[0])

@@ -176,7 +180,9 @@ def mock_tts_service(monkeypatch):

    # Mock TTSModel.generate_from_tokens since we call it directly
    mock_generate = Mock(return_value=np.zeros(48000))
    monkeypatch.setattr("api.src.routers.text_processing.TTSModel.generate_from_tokens", mock_generate)
    monkeypatch.setattr(
        "api.src.routers.text_processing.TTSModel.generate_from_tokens", mock_generate
    )

    return mock_service
@@ -1,8 +1,9 @@
"""Tests for AudioService"""

from unittest.mock import patch

import numpy as np
import pytest
from unittest.mock import patch

from api.src.services.audio import AudioService, AudioNormalizer

@@ -10,10 +11,11 @@ from api.src.services.audio import AudioService, AudioNormalizer
@pytest.fixture(autouse=True)
def mock_settings():
    """Mock settings for all tests"""
    with patch('api.src.services.audio.settings') as mock_settings:
    with patch("api.src.services.audio.settings") as mock_settings:
        mock_settings.gap_trim_ms = 250
        yield mock_settings


@pytest.fixture
def sample_audio():
    """Generate a simple sine wave for testing"""
@@ -1,14 +1,16 @@
"""Tests for text chunking service"""

import pytest
from unittest.mock import patch

import pytest

from api.src.services.text_processing import chunker


@pytest.fixture(autouse=True)
def mock_settings():
    """Mock settings for all tests"""
    with patch('api.src.services.text_processing.chunker.settings') as mock_settings:
    with patch("api.src.services.text_processing.chunker.settings") as mock_settings:
        mock_settings.max_chunk_size = 300
        yield mock_settings
@ -1,16 +1,17 @@
|
|||
import asyncio
|
||||
from unittest.mock import Mock, AsyncMock
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
import asyncio
|
||||
from fastapi.testclient import TestClient
|
||||
from httpx import AsyncClient
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from ..src.main import app
|
||||
|
||||
# Create test client
|
||||
client = TestClient(app)
|
||||
|
||||
|
||||
# Create async client fixture
|
||||
@pytest_asyncio.fixture
|
||||
async def async_client():
|
||||
|
@ -28,10 +29,12 @@ def mock_tts_service(monkeypatch):
|
|||
async def mock_stream(*args, **kwargs):
|
||||
for chunk in [b"chunk1", b"chunk2"]:
|
||||
yield chunk
|
||||
|
||||
mock_service.generate_audio_stream = mock_stream
|
||||
|
||||
# Create async mocks
|
||||
mock_service.list_voices = AsyncMock(return_value=[
|
||||
mock_service.list_voices = AsyncMock(
|
||||
return_value=[
|
||||
"af",
|
||||
"bm_lewis",
|
||||
"bf_isabella",
|
||||
|
@ -41,7 +44,8 @@ def mock_tts_service(monkeypatch):
|
|||
"am_adam",
|
||||
"am_michael",
|
||||
"bm_george",
|
||||
])
|
||||
]
|
||||
)
|
||||
mock_service.combine_voices = AsyncMock()
|
||||
monkeypatch.setattr(
|
||||
"api.src.routers.openai_compatible.TTSService",
|
||||
|
@ -54,9 +58,7 @@ def mock_tts_service(monkeypatch):
|
|||
def mock_audio_service(monkeypatch):
|
||||
mock_service = Mock()
|
||||
mock_service.convert_audio.return_value = b"converted mock audio data"
|
||||
monkeypatch.setattr(
|
||||
"api.src.routers.openai_compatible.AudioService", mock_service
|
||||
)
|
||||
monkeypatch.setattr("api.src.routers.openai_compatible.AudioService", mock_service)
|
||||
return mock_service
|
||||
|
||||
|
||||
|
@ -68,7 +70,9 @@ def test_health_check():
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_openai_speech_endpoint(mock_tts_service, mock_audio_service, async_client):
|
||||
async def test_openai_speech_endpoint(
|
||||
mock_tts_service, mock_audio_service, async_client
|
||||
):
|
||||
"""Test the OpenAI-compatible speech endpoint"""
|
||||
test_request = {
|
||||
"model": "kokoro",
|
||||
|
@ -76,7 +80,7 @@ async def test_openai_speech_endpoint(mock_tts_service, mock_audio_service, asyn
|
|||
"voice": "bm_lewis",
|
||||
"response_format": "wav",
|
||||
"speed": 1.0,
|
||||
"stream": False # Explicitly disable streaming
|
||||
"stream": False, # Explicitly disable streaming
|
||||
}
|
||||
response = await async_client.post("/v1/audio/speech", json=test_request)
|
||||
assert response.status_code == 200
|
||||
|
@ -97,7 +101,7 @@ async def test_openai_speech_invalid_voice(mock_tts_service, async_client):
|
|||
"voice": "invalid_voice",
|
||||
"response_format": "wav",
|
||||
"speed": 1.0,
|
||||
"stream": False # Explicitly disable streaming
|
||||
"stream": False, # Explicitly disable streaming
|
||||
}
|
||||
response = await async_client.post("/v1/audio/speech", json=test_request)
|
||||
assert response.status_code == 400 # Bad request
|
||||
|
@ -113,7 +117,7 @@ async def test_openai_speech_invalid_speed(mock_tts_service, async_client):
|
|||
"voice": "af",
|
||||
"response_format": "wav",
|
||||
"speed": -1.0, # Invalid speed
|
||||
"stream": False # Explicitly disable streaming
|
||||
"stream": False, # Explicitly disable streaming
|
||||
}
|
||||
response = await async_client.post("/v1/audio/speech", json=test_request)
|
||||
assert response.status_code == 422 # Validation error
|
||||
|
@ -129,7 +133,7 @@ async def test_openai_speech_generation_error(mock_tts_service, async_client):
|
|||
"voice": "af",
|
||||
"response_format": "wav",
|
||||
"speed": 1.0,
|
||||
"stream": False # Explicitly disable streaming
|
||||
"stream": False, # Explicitly disable streaming
|
||||
}
|
||||
response = await async_client.post("/v1/audio/speech", json=test_request)
|
||||
assert response.status_code == 500
|
||||
|
@ -159,7 +163,9 @@ async def test_combine_voices_string_success(mock_tts_service, async_client):
|
|||
|
||||
assert response.status_code == 200
|
||||
assert response.json()["voice"] == "af_bella_af_sarah"
|
||||
mock_tts_service.combine_voices.assert_called_once_with(voices=["af_bella", "af_sarah"])
|
||||
mock_tts_service.combine_voices.assert_called_once_with(
|
||||
voices=["af_bella", "af_sarah"]
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
@ -184,7 +190,9 @@ async def test_combine_voices_empty_list(mock_tts_service, async_client):
|
|||
async def test_combine_voices_error(mock_tts_service, async_client):
|
||||
"""Test error handling in voice combination"""
|
||||
test_voices = ["af_bella", "af_sarah"]
|
||||
mock_tts_service.combine_voices = AsyncMock(side_effect=Exception("Combination failed"))
|
||||
mock_tts_service.combine_voices = AsyncMock(
|
||||
side_effect=Exception("Combination failed")
|
||||
)
|
||||
|
||||
response = await async_client.post("/v1/audio/voices/combine", json=test_voices)
|
||||
assert response.status_code == 500
|
||||
|
@ -192,7 +200,9 @@ async def test_combine_voices_error(mock_tts_service, async_client):
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_speech_with_combined_voice(mock_tts_service, mock_audio_service, async_client):
|
||||
async def test_speech_with_combined_voice(
|
||||
mock_tts_service, mock_audio_service, async_client
|
||||
):
|
||||
"""Test speech generation with combined voice using + syntax"""
|
||||
mock_tts_service.combine_voices = AsyncMock(return_value="af_bella_af_sarah")
|
||||
|
||||
|
@ -202,7 +212,7 @@ async def test_speech_with_combined_voice(mock_tts_service, mock_audio_service,
|
|||
"voice": "af_bella+af_sarah",
|
||||
"response_format": "wav",
|
||||
"speed": 1.0,
|
||||
"stream": False
|
||||
"stream": False,
|
||||
}
|
||||
|
||||
response = await async_client.post("/v1/audio/speech", json=test_request)
|
||||
|
@ -213,12 +223,14 @@ async def test_speech_with_combined_voice(mock_tts_service, mock_audio_service,
|
|||
text="Hello world",
|
||||
voice="af_bella_af_sarah",
|
||||
speed=1.0,
|
||||
stitch_long_output=True
|
||||
stitch_long_output=True,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_speech_with_whitespace_in_voice(mock_tts_service, mock_audio_service, async_client):
|
||||
async def test_speech_with_whitespace_in_voice(
|
||||
mock_tts_service, mock_audio_service, async_client
|
||||
):
|
||||
"""Test speech generation with whitespace in voice combination"""
|
||||
mock_tts_service.combine_voices = AsyncMock(return_value="af_bella_af_sarah")
|
||||
|
||||
|
@ -228,14 +240,16 @@ async def test_speech_with_whitespace_in_voice(mock_tts_service, mock_audio_serv
|
|||
"voice": " af_bella + af_sarah ",
|
||||
"response_format": "wav",
|
||||
"speed": 1.0,
|
||||
"stream": False
|
||||
"stream": False,
|
||||
}
|
||||
|
||||
response = await async_client.post("/v1/audio/speech", json=test_request)
|
||||
|
     assert response.status_code == 200
     assert response.headers["content-type"] == "audio/wav"
-    mock_tts_service.combine_voices.assert_called_once_with(voices=["af_bella", "af_sarah"])
+    mock_tts_service.combine_voices.assert_called_once_with(
+        voices=["af_bella", "af_sarah"]
+    )
 
 
 @pytest.mark.asyncio
@@ -247,7 +261,7 @@ async def test_speech_with_empty_voice_combination(mock_tts_service, async_clien
         "voice": "+",
         "response_format": "wav",
         "speed": 1.0,
-        "stream": False
+        "stream": False,
     }
 
     response = await async_client.post("/v1/audio/speech", json=test_request)
@@ -264,7 +278,7 @@ async def test_speech_with_invalid_combined_voice(mock_tts_service, async_client
         "voice": "invalid+combination",
         "response_format": "wav",
         "speed": 1.0,
-        "stream": False
+        "stream": False,
     }
 
     response = await async_client.post("/v1/audio/speech", json=test_request)
@@ -282,18 +296,21 @@ async def test_speech_streaming_with_combined_voice(mock_tts_service, async_clie
         "input": "Hello world",
         "voice": "af_bella+af_sarah",
         "response_format": "mp3",
-        "stream": True
+        "stream": True,
     }
 
     # Create streaming mock
     async def mock_stream(*args, **kwargs):
         for chunk in [b"mp3header", b"mp3data"]:
             yield chunk
 
     mock_tts_service.generate_audio_stream = mock_stream
 
     # Add streaming header
     headers = {"x-raw-response": "stream"}
-    response = await async_client.post("/v1/audio/speech", json=test_request, headers=headers)
+    response = await async_client.post(
+        "/v1/audio/speech", json=test_request, headers=headers
+    )
 
     assert response.status_code == 200
     assert response.headers["content-type"] == "audio/mpeg"
@@ -308,18 +325,21 @@ async def test_openai_speech_pcm_streaming(mock_tts_service, async_client):
         "input": "Hello world",
         "voice": "af",
         "response_format": "pcm",
-        "stream": True
+        "stream": True,
     }
 
     # Create streaming mock for this test
     async def mock_stream(*args, **kwargs):
         for chunk in [b"chunk1", b"chunk2"]:
             yield chunk
 
     mock_tts_service.generate_audio_stream = mock_stream
 
     # Add streaming header
     headers = {"x-raw-response": "stream"}
-    response = await async_client.post("/v1/audio/speech", json=test_request, headers=headers)
+    response = await async_client.post(
+        "/v1/audio/speech", json=test_request, headers=headers
+    )
 
     assert response.status_code == 200
     assert response.headers["content-type"] == "audio/pcm"
@@ -333,18 +353,21 @@ async def test_openai_speech_streaming_mp3(mock_tts_service, async_client):
         "input": "Hello world",
         "voice": "af",
         "response_format": "mp3",
-        "stream": True
+        "stream": True,
     }
 
     # Create streaming mock for this test
     async def mock_stream(*args, **kwargs):
         for chunk in [b"mp3header", b"mp3data"]:
             yield chunk
 
     mock_tts_service.generate_audio_stream = mock_stream
 
     # Add streaming header
     headers = {"x-raw-response": "stream"}
-    response = await async_client.post("/v1/audio/speech", json=test_request, headers=headers)
+    response = await async_client.post(
+        "/v1/audio/speech", json=test_request, headers=headers
+    )
 
     assert response.status_code == 200
     assert response.headers["content-type"] == "audio/mpeg"
@@ -359,18 +382,21 @@ async def test_openai_speech_streaming_generator(mock_tts_service, async_client)
         "input": "Hello world",
         "voice": "af",
         "response_format": "pcm",
-        "stream": True
+        "stream": True,
     }
 
     # Create streaming mock for this test
     async def mock_stream(*args, **kwargs):
         for chunk in [b"chunk1", b"chunk2"]:
             yield chunk
 
     mock_tts_service.generate_audio_stream = mock_stream
 
     # Add streaming header
     headers = {"x-raw-response": "stream"}
-    response = await async_client.post("/v1/audio/speech", json=test_request, headers=headers)
+    response = await async_client.post(
+        "/v1/audio/speech", json=test_request, headers=headers
+    )
 
     assert response.status_code == 200
     assert response.headers["content-type"] == "audio/pcm"
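All of the streaming tests above lean on the same two-part pattern: swap generate_audio_stream for a plain async generator, and opt into the raw byte stream with the x-raw-response header. A minimal self-contained sketch of that pattern, assuming the async_client and mock_tts_service fixtures from this suite (the payload mirrors the requests in the tests above):

import pytest


@pytest.mark.asyncio
async def test_streaming_sketch(mock_tts_service, async_client):
    async def mock_stream(*args, **kwargs):
        # Any async generator works here; the endpoint just iterates it.
        for chunk in [b"chunk1", b"chunk2"]:
            yield chunk

    mock_tts_service.generate_audio_stream = mock_stream

    response = await async_client.post(
        "/v1/audio/speech",
        json={
            "input": "Hello world",
            "voice": "af",
            "response_format": "pcm",
            "stream": True,
        },
        headers={"x-raw-response": "stream"},  # request the raw byte stream
    )
    assert response.status_code == 200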
@@ -1,6 +1,6 @@
 """Tests for FastAPI application"""
 
-from unittest.mock import MagicMock, patch, call
+from unittest.mock import MagicMock, call, patch
 
 import pytest
 from fastapi.testclient import TestClient
@@ -32,6 +32,7 @@ async def test_lifespan_successful_warmup(mock_logger, mock_tts_model):
     # Create async mock
     async def async_setup():
         return 3
+
     mock_tts_model.setup = MagicMock()
     mock_tts_model.setup.side_effect = async_setup
     mock_tts_model.get_device.return_value = "cuda"
@@ -90,6 +91,7 @@ async def test_lifespan_cuda_warmup(mock_tts_model):
     # Create async mock
     async def async_setup():
         return 2
+
     mock_tts_model.setup = MagicMock()
     mock_tts_model.setup.side_effect = async_setup
     mock_tts_model.get_device.return_value = "cuda"
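The lifespan tests above make a MagicMock awaitable by assigning an async function as its side_effect: calling the mock invokes that function and returns the resulting coroutine, which the caller can then await. A self-contained sketch of the trick:

import asyncio
from unittest.mock import MagicMock


async def async_setup():
    return 3


mock_setup = MagicMock()
mock_setup.side_effect = async_setup

# mock_setup() returns the coroutine produced by async_setup(),
# so awaiting it behaves like awaiting the real TTSModel.setup().
assert asyncio.run(mock_setup()) == 3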
@@ -1,43 +1,88 @@
 """Tests for text normalization service"""
 
 import pytest
 
 from api.src.services.text_processing.normalizer import normalize_text
 
 
 def test_url_protocols():
     """Test URL protocol handling"""
-    assert normalize_text("Check out https://example.com") == "Check out https example dot com"
+    assert (
+        normalize_text("Check out https://example.com")
+        == "Check out https example dot com"
+    )
     assert normalize_text("Visit http://site.com") == "Visit http site dot com"
-    assert normalize_text("Go to https://test.org/path") == "Go to https test dot org slash path"
+    assert (
+        normalize_text("Go to https://test.org/path")
+        == "Go to https test dot org slash path"
+    )
 
 
 def test_url_www():
     """Test www prefix handling"""
     assert normalize_text("Go to www.example.com") == "Go to www example dot com"
-    assert normalize_text("Visit www.test.org/docs") == "Visit www test dot org slash docs"
-    assert normalize_text("Check www.site.com?q=test") == "Check www site dot com question-mark q equals test"
+    assert (
+        normalize_text("Visit www.test.org/docs") == "Visit www test dot org slash docs"
+    )
+    assert (
+        normalize_text("Check www.site.com?q=test")
+        == "Check www site dot com question-mark q equals test"
+    )
 
 
 def test_url_localhost():
     """Test localhost URL handling"""
-    assert normalize_text("Running on localhost:7860") == "Running on localhost colon 78 60"
-    assert normalize_text("Server at localhost:8080/api") == "Server at localhost colon 80 80 slash api"
-    assert normalize_text("Test localhost:3000/test?v=1") == "Test localhost colon 3000 slash test question-mark v equals 1"
+    assert (
+        normalize_text("Running on localhost:7860")
+        == "Running on localhost colon 78 60"
+    )
+    assert (
+        normalize_text("Server at localhost:8080/api")
+        == "Server at localhost colon 80 80 slash api"
+    )
+    assert (
+        normalize_text("Test localhost:3000/test?v=1")
+        == "Test localhost colon 3000 slash test question-mark v equals 1"
+    )
 
 
 def test_url_ip_addresses():
     """Test IP address URL handling"""
-    assert normalize_text("Access 0.0.0.0:9090/test") == "Access 0 dot 0 dot 0 dot 0 colon 90 90 slash test"
-    assert normalize_text("API at 192.168.1.1:8000") == "API at 192 dot 168 dot 1 dot 1 colon 8000"
+    assert (
+        normalize_text("Access 0.0.0.0:9090/test")
+        == "Access 0 dot 0 dot 0 dot 0 colon 90 90 slash test"
+    )
+    assert (
+        normalize_text("API at 192.168.1.1:8000")
+        == "API at 192 dot 168 dot 1 dot 1 colon 8000"
+    )
     assert normalize_text("Server 127.0.0.1") == "Server 127 dot 0 dot 0 dot 1"
 
 
 def test_url_raw_domains():
     """Test raw domain handling"""
-    assert normalize_text("Visit google.com/search") == "Visit google dot com slash search"
-    assert normalize_text("Go to example.com/path?q=test") == "Go to example dot com slash path question-mark q equals test"
+    assert (
+        normalize_text("Visit google.com/search") == "Visit google dot com slash search"
+    )
+    assert (
+        normalize_text("Go to example.com/path?q=test")
+        == "Go to example dot com slash path question-mark q equals test"
+    )
     assert normalize_text("Check docs.test.com") == "Check docs dot test dot com"
 
 
 def test_url_email_addresses():
     """Test email address handling"""
-    assert normalize_text("Email me at user@example.com") == "Email me at user at example dot com"
+    assert (
+        normalize_text("Email me at user@example.com")
+        == "Email me at user at example dot com"
+    )
     assert normalize_text("Contact admin@test.org") == "Contact admin at test dot org"
-    assert normalize_text("Send to test.user@site.com") == "Send to test dot user at site dot com"
+    assert (
+        normalize_text("Send to test.user@site.com")
+        == "Send to test dot user at site dot com"
+    )
 
 
 def test_non_url_text():
     """Test that non-URL text is unaffected"""
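Taken together, these assertions pin down the spoken-URL contract of normalize_text: protocols and www are read literally, dots become "dot", ports are read after "colon", and paths and query strings are spelled out. A quick check of two of the cases above, assuming the repo root is on sys.path:

from api.src.services.text_processing.normalizer import normalize_text

# Both expected strings come from the assertions above.
print(normalize_text("Visit http://site.com"))   # Visit http site dot com
print(normalize_text("Contact admin@test.org"))  # Contact admin at test dot org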
@@ -1,33 +1,36 @@
 """Tests for text processing endpoints"""
 
 from unittest.mock import Mock, patch
 
+import numpy as np
 import pytest
 import pytest_asyncio
 from httpx import AsyncClient
-import numpy as np
 
-from ..src.main import app
 from .conftest import MockTTSModel
+from ..src.main import app
 
 
 @pytest_asyncio.fixture
 async def async_client():
     async with AsyncClient(app=app, base_url="http://test") as ac:
         yield ac
 
 
 @pytest.mark.asyncio
 async def test_phonemize_endpoint(async_client):
     """Test phoneme generation endpoint"""
-    with patch('api.src.routers.text_processing.phonemize') as mock_phonemize, \
-         patch('api.src.routers.text_processing.tokenize') as mock_tokenize:
-
+    with patch("api.src.routers.text_processing.phonemize") as mock_phonemize, patch(
+        "api.src.routers.text_processing.tokenize"
+    ) as mock_tokenize:
         # Setup mocks
         mock_phonemize.return_value = "həlˈoʊ"
         mock_tokenize.return_value = [1, 2, 3]
 
         # Test request
-        response = await async_client.post("/text/phonemize", json={
-            "text": "hello",
-            "language": "a"
-        })
+        response = await async_client.post(
+            "/text/phonemize", json={"text": "hello", "language": "a"}
+        )
 
         # Verify response
         assert response.status_code == 200
@@ -35,46 +38,55 @@ async def test_phonemize_endpoint(async_client):
     assert result["phonemes"] == "həlˈoʊ"
     assert result["tokens"] == [0, 1, 2, 3, 0]  # Should add start/end tokens
 
 
 @pytest.mark.asyncio
 async def test_phonemize_empty_text(async_client):
     """Test phoneme generation with empty text"""
-    response = await async_client.post("/text/phonemize", json={
-        "text": "",
-        "language": "a"
-    })
+    response = await async_client.post(
+        "/text/phonemize", json={"text": "", "language": "a"}
+    )
 
     assert response.status_code == 500
     assert "error" in response.json()["detail"]
 
 
 @pytest.mark.asyncio
-async def test_generate_from_phonemes(async_client, mock_tts_service, mock_audio_service):
+async def test_generate_from_phonemes(
+    async_client, mock_tts_service, mock_audio_service
+):
     """Test audio generation from phonemes"""
-    with patch('api.src.routers.text_processing.TTSService', return_value=mock_tts_service):
-        response = await async_client.post("/text/generate_from_phonemes", json={
-            "phonemes": "həlˈoʊ",
-            "voice": "af_bella",
-            "speed": 1.0
-        })
+    with patch(
+        "api.src.routers.text_processing.TTSService", return_value=mock_tts_service
+    ):
+        response = await async_client.post(
+            "/text/generate_from_phonemes",
+            json={"phonemes": "həlˈoʊ", "voice": "af_bella", "speed": 1.0},
+        )
 
     assert response.status_code == 200
     assert response.headers["content-type"] == "audio/wav"
-    assert response.headers["content-disposition"] == "attachment; filename=speech.wav"
+    assert (
+        response.headers["content-disposition"] == "attachment; filename=speech.wav"
+    )
    assert response.content == b"mock audio data"
 
 
 @pytest.mark.asyncio
 async def test_generate_from_phonemes_invalid_voice(async_client, mock_tts_service):
     """Test audio generation with invalid voice"""
     mock_tts_service._get_voice_path.return_value = None
-    with patch('api.src.routers.text_processing.TTSService', return_value=mock_tts_service):
-        response = await async_client.post("/text/generate_from_phonemes", json={
-            "phonemes": "həlˈoʊ",
-            "voice": "invalid_voice",
-            "speed": 1.0
-        })
+    with patch(
+        "api.src.routers.text_processing.TTSService", return_value=mock_tts_service
+    ):
+        response = await async_client.post(
+            "/text/generate_from_phonemes",
+            json={"phonemes": "həlˈoʊ", "voice": "invalid_voice", "speed": 1.0},
+        )
 
     assert response.status_code == 400
     assert "Voice not found" in response.json()["detail"]["message"]
 
 
 @pytest.mark.asyncio
 async def test_generate_from_phonemes_invalid_speed(async_client, monkeypatch):
     """Test audio generation with invalid speed"""
@@ -82,25 +94,29 @@ async def test_generate_from_phonemes_invalid_speed(async_client, monkeypatch):
     mock_model = Mock()
     mock_model.generate_from_tokens = Mock(return_value=np.zeros(48000))
     monkeypatch.setattr("api.src.services.tts_model.TTSModel._instance", mock_model)
-    monkeypatch.setattr("api.src.services.tts_model.TTSModel.get_instance", Mock(return_value=mock_model))
+    monkeypatch.setattr(
+        "api.src.services.tts_model.TTSModel.get_instance",
+        Mock(return_value=mock_model),
+    )
 
-    response = await async_client.post("/text/generate_from_phonemes", json={
-        "phonemes": "həlˈoʊ",
-        "voice": "af_bella",
-        "speed": -1.0
-    })
+    response = await async_client.post(
+        "/text/generate_from_phonemes",
+        json={"phonemes": "həlˈoʊ", "voice": "af_bella", "speed": -1.0},
+    )
 
     assert response.status_code == 422  # Validation error
 
 
 @pytest.mark.asyncio
 async def test_generate_from_phonemes_empty_phonemes(async_client, mock_tts_service):
     """Test audio generation with empty phonemes"""
-    with patch('api.src.routers.text_processing.TTSService', return_value=mock_tts_service):
-        response = await async_client.post("/text/generate_from_phonemes", json={
-            "phonemes": "",
-            "voice": "af_bella",
-            "speed": 1.0
-        })
+    with patch(
+        "api.src.routers.text_processing.TTSService", return_value=mock_tts_service
+    ):
+        response = await async_client.post(
+            "/text/generate_from_phonemes",
+            json={"phonemes": "", "voice": "af_bella", "speed": 1.0},
+        )
 
     assert response.status_code == 400
     assert "Invalid request" in response.json()["detail"]["error"]
@@ -1,13 +1,16 @@
 """Tests for TTS model implementations"""
 
 import os
+from unittest.mock import MagicMock, patch
 
+import numpy as np
 import torch
 import pytest
-import numpy as np
-from unittest.mock import patch, MagicMock
 
+from api.src.services.tts_base import TTSBaseModel
 from api.src.services.tts_cpu import TTSCPUModel
 from api.src.services.tts_gpu import TTSGPUModel, length_to_mask
-from api.src.services.tts_base import TTSBaseModel
 
 
 # Base Model Tests
 def test_get_device_error():
@@ -16,14 +19,17 @@ def test_get_device_error():
     with pytest.raises(RuntimeError, match="Model not initialized"):
         TTSBaseModel.get_device()
 
 
 @pytest.mark.asyncio
-@patch('torch.cuda.is_available')
-@patch('os.path.exists')
-@patch('os.path.join')
-@patch('os.listdir')
-@patch('torch.load')
-@patch('torch.save')
-async def test_setup_cuda_available(mock_save, mock_load, mock_listdir, mock_join, mock_exists, mock_cuda_available):
+@patch("torch.cuda.is_available")
+@patch("os.path.exists")
+@patch("os.path.join")
+@patch("os.listdir")
+@patch("torch.load")
+@patch("torch.save")
+async def test_setup_cuda_available(
+    mock_save, mock_load, mock_listdir, mock_join, mock_exists, mock_cuda_available
+):
     """Test setup with CUDA available"""
     TTSBaseModel._device = None
     mock_cuda_available.return_value = True
@@ -35,7 +41,7 @@ async def test_setup_cuda_available(mock_save, mock_load, mock_listdir, mock_joi
     # Create mock model
     mock_model = MagicMock()
     mock_model.bert = MagicMock()
-    mock_model.process_text = MagicMock(return_value=("dummy", [1,2,3]))
+    mock_model.process_text = MagicMock(return_value=("dummy", [1, 2, 3]))
     mock_model.generate_from_tokens = MagicMock(return_value=np.zeros(1000))
 
     # Mock initialize to return our mock model
@@ -46,14 +52,17 @@ async def test_setup_cuda_available(mock_save, mock_load, mock_listdir, mock_joi
     assert TTSBaseModel._device == "cuda"
     assert voice_count == 2
 
 
 @pytest.mark.asyncio
-@patch('torch.cuda.is_available')
-@patch('os.path.exists')
-@patch('os.path.join')
-@patch('os.listdir')
-@patch('torch.load')
-@patch('torch.save')
-async def test_setup_cuda_unavailable(mock_save, mock_load, mock_listdir, mock_join, mock_exists, mock_cuda_available):
+@patch("torch.cuda.is_available")
+@patch("os.path.exists")
+@patch("os.path.join")
+@patch("os.listdir")
+@patch("torch.load")
+@patch("torch.save")
+async def test_setup_cuda_unavailable(
+    mock_save, mock_load, mock_listdir, mock_join, mock_exists, mock_cuda_available
+):
     """Test setup with CUDA unavailable"""
     TTSBaseModel._device = None
     mock_cuda_available.return_value = False
@@ -65,7 +74,7 @@ async def test_setup_cuda_unavailable(mock_save, mock_load, mock_listdir, mock_j
     # Create mock model
     mock_model = MagicMock()
     mock_model.bert = MagicMock()
-    mock_model.process_text = MagicMock(return_value=("dummy", [1,2,3]))
+    mock_model.process_text = MagicMock(return_value=("dummy", [1, 2, 3]))
     mock_model.generate_from_tokens = MagicMock(return_value=np.zeros(1000))
 
     # Mock initialize to return our mock model
@@ -76,15 +85,18 @@ async def test_setup_cuda_unavailable(mock_save, mock_load, mock_listdir, mock_j
     assert TTSBaseModel._device == "cpu"
     assert voice_count == 2
 
+
 # CPU Model Tests
 def test_cpu_initialize_missing_model():
     """Test CPU initialize with missing model"""
     TTSCPUModel._onnx_session = None  # Reset the session
-    with patch('os.path.exists', return_value=False), \
-         patch('onnxruntime.InferenceSession', return_value=None):
+    with patch("os.path.exists", return_value=False), patch(
+        "onnxruntime.InferenceSession", return_value=None
+    ):
         result = TTSCPUModel.initialize("dummy_dir")
         assert result is None
 
 
 def test_cpu_generate_uninitialized():
     """Test CPU generate methods with uninitialized model"""
     TTSCPUModel._onnx_session = None
@@ -93,13 +105,14 @@ def test_cpu_generate_uninitialized():
         TTSCPUModel.generate_from_text("test", torch.zeros(1), "en", 1.0)
 
     with pytest.raises(RuntimeError, match="ONNX model not initialized"):
-        TTSCPUModel.generate_from_tokens([1,2,3], torch.zeros(1), 1.0)
+        TTSCPUModel.generate_from_tokens([1, 2, 3], torch.zeros(1), 1.0)
 
 
 def test_cpu_process_text():
     """Test CPU process_text functionality"""
-    with patch('api.src.services.tts_cpu.phonemize') as mock_phonemize, \
-         patch('api.src.services.tts_cpu.tokenize') as mock_tokenize:
-
+    with patch("api.src.services.tts_cpu.phonemize") as mock_phonemize, patch(
+        "api.src.services.tts_cpu.tokenize"
+    ) as mock_tokenize:
         mock_phonemize.return_value = "test phonemes"
         mock_tokenize.return_value = [1, 2, 3]
 
@@ -107,8 +120,9 @@ def test_cpu_process_text():
     assert phonemes == "test phonemes"
     assert tokens == [0, 1, 2, 3, 0]  # Should add start/end tokens
 
 
 # GPU Model Tests
-@patch('torch.cuda.is_available')
+@patch("torch.cuda.is_available")
 def test_gpu_initialize_cuda_unavailable(mock_cuda_available):
     """Test GPU initialize with CUDA unavailable"""
     mock_cuda_available.return_value = False
@@ -117,14 +131,14 @@ def test_gpu_initialize_cuda_unavailable(mock_cuda_available):
     result = TTSGPUModel.initialize("dummy_dir", "dummy_path")
     assert result is None
 
-@patch('api.src.services.tts_gpu.length_to_mask')
+
+@patch("api.src.services.tts_gpu.length_to_mask")
 def test_gpu_length_to_mask(mock_length_to_mask):
     """Test length_to_mask function"""
     # Setup mock return value
-    expected_mask = torch.tensor([
-        [False, False, False, True, True],
-        [False, False, False, False, False]
-    ])
+    expected_mask = torch.tensor(
+        [[False, False, False, True, True], [False, False, False, False, False]]
+    )
     mock_length_to_mask.return_value = expected_mask
 
     # Call function with test input
@@ -135,6 +149,7 @@ def test_gpu_length_to_mask(mock_length_to_mask):
     mock_length_to_mask.assert_called_once()
     assert torch.equal(mask, expected_mask)
 
+
 def test_gpu_generate_uninitialized():
     """Test GPU generate methods with uninitialized model"""
     TTSGPUModel._instance = None
@@ -143,13 +158,14 @@ def test_gpu_generate_uninitialized():
         TTSGPUModel.generate_from_text("test", torch.zeros(1), "en", 1.0)
 
     with pytest.raises(RuntimeError, match="GPU model not initialized"):
-        TTSGPUModel.generate_from_tokens([1,2,3], torch.zeros(1), 1.0)
+        TTSGPUModel.generate_from_tokens([1, 2, 3], torch.zeros(1), 1.0)
 
 
 def test_gpu_process_text():
     """Test GPU process_text functionality"""
-    with patch('api.src.services.tts_gpu.phonemize') as mock_phonemize, \
-         patch('api.src.services.tts_gpu.tokenize') as mock_tokenize:
-
+    with patch("api.src.services.tts_gpu.phonemize") as mock_phonemize, patch(
+        "api.src.services.tts_gpu.tokenize"
+    ) as mock_tokenize:
         mock_phonemize.return_value = "test phonemes"
         mock_tokenize.return_value = [1, 2, 3]
 
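The expected mask in test_gpu_length_to_mask corresponds to sequence lengths [3, 5] padded to length 5: positions at or past each sequence's length are True. A sketch of a length_to_mask with those semantics (the real implementation lives in api.src.services.tts_gpu and may differ in detail; this version is only inferred from the expected tensor above):

import torch


def length_to_mask(lengths: torch.Tensor) -> torch.Tensor:
    # True marks padding positions beyond each sequence's length.
    positions = torch.arange(lengths.max(), device=lengths.device)
    return positions.unsqueeze(0) >= lengths.unsqueeze(1)


print(length_to_mask(torch.tensor([3, 5])))
# tensor([[False, False, False,  True,  True],
#         [False, False, False, False, False]])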
@@ -9,10 +9,10 @@ import pytest
 from onnxruntime import InferenceSession
 
 from api.src.core.config import settings
-from api.src.services.tts_model import TTSModel
-from api.src.services.tts_service import TTSService
 from api.src.services.tts_cpu import TTSCPUModel
 from api.src.services.tts_gpu import TTSGPUModel
+from api.src.services.tts_model import TTSModel
+from api.src.services.tts_service import TTSService
 
 
 @pytest.fixture
@@ -25,8 +25,13 @@ def tts_service(monkeypatch):
 
     # Set up model instance
     monkeypatch.setattr("api.src.services.tts_model.TTSModel._instance", mock_model)
-    monkeypatch.setattr("api.src.services.tts_model.TTSModel.get_instance", MagicMock(return_value=mock_model))
-    monkeypatch.setattr("api.src.services.tts_model.TTSModel.get_device", MagicMock(return_value="cpu"))
+    monkeypatch.setattr(
+        "api.src.services.tts_model.TTSModel.get_instance",
+        MagicMock(return_value=mock_model),
+    )
+    monkeypatch.setattr(
+        "api.src.services.tts_model.TTSModel.get_device", MagicMock(return_value="cpu")
+    )
 
     return TTSService()
 
@@ -51,6 +56,7 @@ def test_audio_to_bytes(tts_service, sample_audio):
 @pytest.mark.asyncio
 async def test_list_voices(tts_service):
     """Test listing available voices"""
+
     # Override list_voices for testing
     # # TODO:
     # Whatever aiofiles does here pathing aiofiles vs aiofiles.os
@@ -58,6 +64,7 @@ async def test_list_voices(tts_service):
     # Cheating the test as it seems to work in the real world (for now)
     async def mock_list_voices():
         return ["voice1", "voice2"]
+
     tts_service.list_voices = mock_list_voices
 
     voices = await tts_service.list_voices()
@@ -69,10 +76,12 @@ async def test_list_voices(tts_service):
 @pytest.mark.asyncio
 async def test_list_voices_error(tts_service):
     """Test error handling in list_voices"""
+
     # Override list_voices for testing
     # TODO: See above.
     async def mock_list_voices():
         return []
+
     tts_service.list_voices = mock_list_voices
 
     voices = await tts_service.list_voices()
@@ -124,10 +133,11 @@ def test_generate_audio_empty_text(tts_service):
 @pytest.fixture(autouse=True)
 def mock_settings():
     """Mock settings for all tests"""
-    with patch('api.src.services.text_processing.chunker.settings') as mock_settings:
+    with patch("api.src.services.text_processing.chunker.settings") as mock_settings:
         mock_settings.max_chunk_size = 300
         yield mock_settings
 
+
 @patch("api.src.services.tts_model.TTSModel.get_instance")
 @patch("api.src.services.tts_model.TTSModel.get_device")
 @patch("os.path.exists")
@@ -150,7 +160,10 @@ def test_generate_audio_phonemize_error(
     """Test handling phonemization error"""
     mock_normalize.return_value = "Test text"
     mock_phonemize.side_effect = Exception("Phonemization failed")
-    mock_instance.return_value = (mock_generate, "cpu")  # Use the same mock for consistency
+    mock_instance.return_value = (
+        mock_generate,
+        "cpu",
+    )  # Use the same mock for consistency
     mock_get_device.return_value = "cpu"
     mock_exists.return_value = True
     mock_torch_load.return_value = torch.zeros((10, 24000))
@@ -184,7 +197,10 @@ def test_generate_audio_error(
     mock_phonemize.return_value = "Test text"
     mock_tokenize.return_value = [1, 2]  # Return integers instead of strings
     mock_generate.side_effect = Exception("Generation failed")
-    mock_instance.return_value = (mock_generate, "cpu")  # Use the same mock for consistency
+    mock_instance.return_value = (
+        mock_generate,
+        "cpu",
+    )  # Use the same mock for consistency
     mock_get_device.return_value = "cpu"
     mock_exists.return_value = True
     mock_torch_load.return_value = torch.zeros((10, 24000))
@@ -205,12 +221,11 @@ def test_save_audio(tts_service, sample_audio, tmp_path):
 async def test_combine_voices(tts_service):
     """Test combining multiple voices"""
     # Setup mocks for torch operations
-    with patch('torch.load', return_value=torch.tensor([1.0, 2.0])), \
-         patch('torch.stack', return_value=torch.tensor([[1.0, 2.0], [3.0, 4.0]])), \
-         patch('torch.mean', return_value=torch.tensor([2.0, 3.0])), \
-         patch('torch.save'), \
-         patch('os.path.exists', return_value=True):
-
+    with patch("torch.load", return_value=torch.tensor([1.0, 2.0])), patch(
+        "torch.stack", return_value=torch.tensor([[1.0, 2.0], [3.0, 4.0]])
+    ), patch("torch.mean", return_value=torch.tensor([2.0, 3.0])), patch(
+        "torch.save"
+    ), patch("os.path.exists", return_value=True):
        # Test combining two voices
        result = await tts_service.combine_voices(["voice1", "voice2"])
 
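Note how test_combine_voices ends up after formatting: the backslash-continued stack of patch(...) context managers is folded into a single logical with line, with the formatter breaking inside the call parentheses rather than between managers. The same stacked-patch pattern in isolation, with targets chosen purely for illustration:

from unittest.mock import patch

import torch


def test_stacked_patches_sketch():
    # Several patches active at once, in the folded layout seen above.
    with patch("torch.load", return_value=torch.tensor([1.0, 2.0])), patch(
        "torch.save"
    ), patch("os.path.exists", return_value=True):
        # While the patches are active, torch.load returns the stub tensor.
        assert torch.load("anything.pt").tolist() == [1.0, 2.0]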
@@ -166,7 +166,7 @@ def measure_first_token_openai(
 
 def main():
     script_dir = os.path.dirname(os.path.abspath(__file__))
-    prefix='cpu'
+    prefix = "cpu"
     # Run requests benchmark
     print("\n=== Running Direct Requests Benchmark ===")
     run_benchmark(
@@ -176,7 +176,7 @@ def main():
         output_plots_dir=os.path.join(script_dir, "output_plots"),
         suffix="_stream",
         plot_title_suffix="(Streaming)",
-        prefix=prefix
+        prefix=prefix,
     )
     # Run OpenAI benchmark
     print("\n=== Running OpenAI Library Benchmark ===")
@@ -187,7 +187,7 @@ def main():
         output_plots_dir=os.path.join(script_dir, "output_plots"),
         suffix="_stream_openai",
         plot_title_suffix="(OpenAI Streaming)",
-        prefix=prefix
+        prefix=prefix,
     )
 
 
@@ -149,19 +149,19 @@ def run_benchmark(
         result["run_number"] = i + 1
 
         # Handle time to first audio
-        first_chunk = result.get('time_to_first_chunk')
+        first_chunk = result.get("time_to_first_chunk")
         print(
             f"Time to First Audio: {f'{first_chunk:.3f}s' if first_chunk is not None else 'N/A'}"
         )
 
         # Handle total time
-        total_time = result.get('total_time')
+        total_time = result.get("total_time")
         print(
             f"Time to Save Complete: {f'{total_time:.3f}s' if total_time is not None else 'N/A'}"
         )
 
         # Handle audio length
-        audio_length = result.get('audio_length')
+        audio_length = result.get("audio_length")
         print(
             f"Audio length: {f'{audio_length:.3f}s' if audio_length is not None else 'N/A'}"
         )
@@ -191,10 +191,18 @@ def run_benchmark(
 
     # Print paths
     print("\nResults and plots saved to:")
-    print(f"- {os.path.join(output_data_dir, f'{prefix}first_token_benchmark{suffix}.json')}")
-    print(f"- {os.path.join(output_plots_dir, f'{prefix}first_token_latency{suffix}.png')}")
-    print(f"- {os.path.join(output_plots_dir, f'{prefix}total_time_latency{suffix}.png')}")
-    print(f"- {os.path.join(output_plots_dir, f'{prefix}first_token_timeline{suffix}.png')}")
+    print(
+        f"- {os.path.join(output_data_dir, f'{prefix}first_token_benchmark{suffix}.json')}"
+    )
+    print(
+        f"- {os.path.join(output_plots_dir, f'{prefix}first_token_latency{suffix}.png')}"
+    )
+    print(
+        f"- {os.path.join(output_plots_dir, f'{prefix}total_time_latency{suffix}.png')}"
+    )
+    print(
+        f"- {os.path.join(output_plots_dir, f'{prefix}first_token_timeline{suffix}.png')}"
+    )
 
     # Print silence check summary
     if silent_files:
@@ -1,42 +1,39 @@
 #!/usr/bin/env rye run python
 
 # %%
 import time
 from pathlib import Path
 
 from openai import OpenAI
 
 # gets OPENAI_API_KEY from your environment variables
-openai = OpenAI(base_url="http://localhost:8880/v1", api_key="not-needed-for-local")
+openai = OpenAI(base_url="http://localhost:8000/v1", api_key="not-needed-for-local")
 
 speech_file_path = Path(__file__).parent / "speech.mp3"
 
-
-
-
-
 
 def main() -> None:
     stream_to_speakers()
 
     # Create text-to-speech audio file
     with openai.audio.speech.with_streaming_response.create(
         model="kokoro",
         voice="af",
         input="the quick brown fox jumped over the lazy dogs",
     ) as response:
         response.stream_to_file(speech_file_path)
 
-
 
 def stream_to_speakers() -> None:
     import pyaudio
 
-    player_stream = pyaudio.PyAudio().open(format=pyaudio.paInt16, channels=1, rate=24000, output=True)
+    player_stream = pyaudio.PyAudio().open(
+        format=pyaudio.paInt16, channels=1, rate=24000, output=True
+    )
 
     start_time = time.time()
 
     with openai.audio.speech.with_streaming_response.create(
         model="kokoro",
-        voice="af_0p0_n2p0",
-        response_format="pcm",  # similar to WAV, but without a header chunk at the start.
-        input="""My dear sir, that is just where you are wrong. That is just where the whole world has gone wrong. We are always getting away from the present moment. Our mental existences, which are immaterial and have no dimensions, are passing along the Time-Dimension with a uniform velocity from the cradle to the grave. Just as we should travel down if we began our existence fifty miles above the earth’s surface""",
+        voice=VOICE,
+        response_format="mp3",  # similar to WAV, but without a header chunk at the start.
+        input="""My dear sir, that is just where you are wrong. That is just where the whole world has gone wrong. We are always getting away from the present moment. Our mental existences, which are immaterial and have no dimensions, are passing along the Time-Dimension""",
     ) as response:
         print(f"Time to first byte: {int((time.time() - start_time) * 1000)}ms")
         for chunk in response.iter_bytes(chunk_size=1024):
@@ -47,3 +44,5 @@ def stream_to_speakers() -> None:
 
 if __name__ == "__main__":
     main()
+
+# %%
@@ -1,11 +1,13 @@
-import requests
 import json
+from pathlib import Path
 from typing import Tuple, Optional
-from pathlib import Path
+
+import requests
 
 # Get the directory this script is in
 SCRIPT_DIR = Path(__file__).parent.absolute()
 
 
 def get_phonemes(text: str, language: str = "a") -> Tuple[str, list[int]]:
     """Get phonemes and tokens for input text.
 
@@ -17,16 +19,10 @@ def get_phonemes(text: str, language: str = "a") -> Tuple[str, list[int]]:
         Tuple of (phonemes string, token list)
     """
     # Create the request payload
-    payload = {
-        "text": text,
-        "language": language
-    }
+    payload = {"text": text, "language": language}
 
     # Make POST request to the phonemize endpoint
-    response = requests.post(
-        "http://localhost:8880/text/phonemize",
-        json=payload
-    )
+    response = requests.post("http://localhost:8880/text/phonemize", json=payload)
 
     # Raise exception for error status codes
     response.raise_for_status()
@@ -35,7 +31,10 @@ def get_phonemes(text: str, language: str = "a") -> Tuple[str, list[int]]:
     result = response.json()
     return result["phonemes"], result["tokens"]
 
-def generate_audio_from_phonemes(phonemes: str, voice: str = "af_bella", speed: float = 1.0) -> Optional[bytes]:
+
+def generate_audio_from_phonemes(
+    phonemes: str, voice: str = "af_bella", speed: float = 1.0
+) -> Optional[bytes]:
     """Generate audio from phonemes.
 
     Args:
@@ -47,16 +46,11 @@ def generate_audio_from_phonemes(phonemes: str, voice: str = "af_bella", speed:
         WAV audio bytes if successful, None if failed
     """
     # Create the request payload
-    payload = {
-        "phonemes": phonemes,
-        "voice": voice,
-        "speed": speed
-    }
+    payload = {"phonemes": phonemes, "voice": voice, "speed": speed}
 
     # Make POST request to generate audio
     response = requests.post(
-        "http://localhost:8880/text/generate_from_phonemes",
-        json=payload
+        "http://localhost:8880/text/generate_from_phonemes", json=payload
     )
 
     # Raise exception for error status codes
@@ -64,6 +58,7 @@ def generate_audio_from_phonemes(phonemes: str, voice: str = "af_bella", speed:
 
     return response.content
 
+
 def main():
     # Example texts to convert
     examples = [
@@ -71,7 +66,7 @@ def main():
         "How are you today? I am doing reasonably well, thank you for asking",
         """This is a test of the phoneme generation system. Do not be alarmed.
         This is only a test. If this were a real phoneme emergency, '
-        you would be instructed to a phoneme shelter in your area."""
+        you would be instructed to a phoneme shelter in your area.""",
     ]
 
     print("Generating phonemes and audio for example texts...\n")
@@ -104,5 +99,6 @@ def main():
         except requests.RequestException as e:
             print(f"Error: {e}\n")
 
+
 if __name__ == "__main__":
     main()
@@ -1,11 +1,13 @@
 #!/usr/bin/env python3
-import requests
-import numpy as np
-import sounddevice as sd
-import time
 import os
+import time
 import wave
 
+import numpy as np
+import requests
+import sounddevice as sd
 
 
 def play_streaming_tts(text: str, output_file: str = None, voice: str = "af"):
     """Stream TTS audio and play it back in real-time"""
 
@@ -26,7 +28,7 @@ def play_streaming_tts(text: str, output_file: str = None, voice: str = "af"):
         channels=1,
         dtype=np.int16,
         blocksize=1024,  # Buffer size in samples
-        latency='low'  # Request low latency
+        latency="low",  # Request low latency
     )
     stream.start()
 
@@ -39,16 +41,18 @@ def play_streaming_tts(text: str, output_file: str = None, voice: str = "af"):
             "input": text,
             "voice": voice,
             "response_format": "pcm",
-            "stream": True
+            "stream": True,
         },
         stream=True,
-        timeout=1800
+        timeout=1800,
     )
     response.raise_for_status()
     print(f"Request started successfully after {time.time() - start_time:.2f}s")
 
     # Process streaming response with smaller chunks for lower latency
-    for chunk in response.iter_content(chunk_size=512):  # 512 bytes = 256 samples at 16-bit
+    for chunk in response.iter_content(
+        chunk_size=512
+    ):  # 512 bytes = 256 samples at 16-bit
         if chunk:
             chunk_count += 1
             total_bytes += len(chunk)
@@ -56,7 +60,9 @@ def play_streaming_tts(text: str, output_file: str = None, voice: str = "af"):
             # Handle first chunk
             if not audio_started:
                 first_chunk_time = time.time()
-                print(f"\nReceived first chunk after {first_chunk_time - start_time:.2f}s")
+                print(
+                    f"\nReceived first chunk after {first_chunk_time - start_time:.2f}s"
+                )
                 print(f"First chunk size: {len(chunk)} bytes")
                 audio_started = True
 
@@ -70,7 +76,9 @@ def play_streaming_tts(text: str, output_file: str = None, voice: str = "af"):
             # Log progress every 10 chunks
             if chunk_count % 10 == 0:
                 elapsed = time.time() - start_time
-                print(f"Progress: {chunk_count} chunks, {total_bytes/1024:.1f}KB received, {elapsed:.1f}s elapsed")
+                print(
+                    f"Progress: {chunk_count} chunks, {total_bytes/1024:.1f}KB received, {elapsed:.1f}s elapsed"
+                )
 
     # Final stats
     total_time = time.time() - start_time
@@ -83,7 +91,7 @@ def play_streaming_tts(text: str, output_file: str = None, voice: str = "af"):
     # Save as WAV file
     if output_file:
         print(f"\nWriting audio to {output_file}")
-        with wave.open(output_file, 'wb') as wav_file:
+        with wave.open(output_file, "wb") as wav_file:
             wav_file.setnchannels(1)  # Mono
             wav_file.setsampwidth(2)  # 2 bytes per sample (16-bit)
             wav_file.setframerate(sample_rate)
@@ -103,10 +111,13 @@ def play_streaming_tts(text: str, output_file: str = None, voice: str = "af"):
     stream.stop()
     stream.close()
 
 
 def main():
     # Load sample text from HG Wells
     script_dir = os.path.dirname(os.path.abspath(__file__))
-    wells_path = os.path.join(script_dir, "assorted_checks/benchmarks/the_time_machine_hg_wells.txt")
+    wells_path = os.path.join(
+        script_dir, "assorted_checks/benchmarks/the_time_machine_hg_wells.txt"
+    )
     output_path = os.path.join(script_dir, "output.wav")
 
     with open(wells_path, "r", encoding="utf-8") as f:
@@ -121,5 +132,6 @@ def main():
 
     play_streaming_tts(text, output_file=output_path)
 
+
 if __name__ == "__main__":
     main()
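Every hunk in this commit is mechanical: it is the output of Ruff's Black-compatible formatter plus its autofixable lint rules (import sorting, trailing commas, quote style). Assuming a default configuration, the same result should be reproducible locally by running `ruff check --fix .` followed by `ruff format .` from the repository root.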