Ruff format + fix

This commit is contained in:
remsky 2025-01-09 18:41:44 -07:00
parent f6e3afa14c
commit e8c1284032
31 changed files with 927 additions and 624 deletions

View file

@ -2,8 +2,8 @@
FastAPI OpenAI Compatible API FastAPI OpenAI Compatible API
""" """
from contextlib import asynccontextmanager
import sys import sys
from contextlib import asynccontextmanager
import uvicorn import uvicorn
from loguru import logger from loguru import logger
@ -12,9 +12,9 @@ from fastapi.middleware.cors import CORSMiddleware
from .core.config import settings from .core.config import settings
from .services.tts_model import TTSModel from .services.tts_model import TTSModel
from .routers.development import router as dev_router
from .services.tts_service import TTSService from .services.tts_service import TTSService
from .routers.openai_compatible import router as openai_router from .routers.openai_compatible import router as openai_router
from .routers.development import router as dev_router
def setup_logger(): def setup_logger():
@ -24,25 +24,21 @@ def setup_logger():
{ {
"sink": sys.stdout, "sink": sys.stdout,
"format": "<fg #2E8B57>{time:hh:mm:ss A}</fg #2E8B57> | " "format": "<fg #2E8B57>{time:hh:mm:ss A}</fg #2E8B57> | "
"{level: <8} | " "{level: <8} | "
"{message}", "{message}",
"colorize": True, "colorize": True,
"level": "INFO" "level": "INFO",
}, },
], ],
} }
# Remove default logger
logger.remove() logger.remove()
# Add our custom logger
logger.configure(**config) logger.configure(**config)
# Override error colors
logger.level("ERROR", color="<red>") logger.level("ERROR", color="<red>")
# Configure logger # Configure logger
setup_logger() setup_logger()
@asynccontextmanager @asynccontextmanager
async def lifespan(app: FastAPI): async def lifespan(app: FastAPI):
"""Lifespan context manager for model initialization""" """Lifespan context manager for model initialization"""
@ -52,7 +48,7 @@ async def lifespan(app: FastAPI):
voicepack_count = await TTSModel.setup() voicepack_count = await TTSModel.setup()
# boundary = "█████╗"*9 # boundary = "█████╗"*9
boundary = "" * 24 boundary = "" * 24
startup_msg =f""" startup_msg = f"""
{boundary} {boundary}

View file

@ -1,24 +1,30 @@
from typing import List from typing import List
from loguru import logger
from fastapi import APIRouter, HTTPException, Depends, Response
from ..structures.text_schemas import PhonemeRequest, PhonemeResponse, GenerateFromPhonemesRequest
from ..services.text_processing import phonemize, tokenize
from ..services.audio import AudioService
from ..services.tts_service import TTSService
from ..services.tts_model import TTSModel
import numpy as np import numpy as np
from loguru import logger
from fastapi import Depends, Response, APIRouter, HTTPException
from ..services.audio import AudioService
from ..services.tts_model import TTSModel
from ..services.tts_service import TTSService
from ..structures.text_schemas import (
PhonemeRequest,
PhonemeResponse,
GenerateFromPhonemesRequest,
)
from ..services.text_processing import tokenize, phonemize
router = APIRouter(tags=["text processing"]) router = APIRouter(tags=["text processing"])
def get_tts_service() -> TTSService: def get_tts_service() -> TTSService:
"""Dependency to get TTSService instance""" """Dependency to get TTSService instance"""
return TTSService() return TTSService()
@router.post("/text/phonemize", response_model=PhonemeResponse, tags=["deprecated"]) @router.post("/text/phonemize", response_model=PhonemeResponse, tags=["deprecated"])
@router.post("/dev/phonemize", response_model=PhonemeResponse) @router.post("/dev/phonemize", response_model=PhonemeResponse)
async def phonemize_text( async def phonemize_text(request: PhonemeRequest) -> PhonemeResponse:
request: PhonemeRequest
) -> PhonemeResponse:
"""Convert text to phonemes and tokens """Convert text to phonemes and tokens
Args: Args:
@ -41,28 +47,24 @@ async def phonemize_text(
tokens = tokenize(phonemes) tokens = tokenize(phonemes)
tokens = [0] + tokens + [0] # Add start/end tokens tokens = [0] + tokens + [0] # Add start/end tokens
return PhonemeResponse( return PhonemeResponse(phonemes=phonemes, tokens=tokens)
phonemes=phonemes,
tokens=tokens
)
except ValueError as e: except ValueError as e:
logger.error(f"Error in phoneme generation: {str(e)}") logger.error(f"Error in phoneme generation: {str(e)}")
raise HTTPException( raise HTTPException(
status_code=500, status_code=500, detail={"error": "Server error", "message": str(e)}
detail={"error": "Server error", "message": str(e)}
) )
except Exception as e: except Exception as e:
logger.error(f"Error in phoneme generation: {str(e)}") logger.error(f"Error in phoneme generation: {str(e)}")
raise HTTPException( raise HTTPException(
status_code=500, status_code=500, detail={"error": "Server error", "message": str(e)}
detail={"error": "Server error", "message": str(e)}
) )
@router.post("/text/generate_from_phonemes", tags=["deprecated"]) @router.post("/text/generate_from_phonemes", tags=["deprecated"])
@router.post("/dev/generate_from_phonemes") @router.post("/dev/generate_from_phonemes")
async def generate_from_phonemes( async def generate_from_phonemes(
request: GenerateFromPhonemesRequest, request: GenerateFromPhonemesRequest,
tts_service: TTSService = Depends(get_tts_service) tts_service: TTSService = Depends(get_tts_service),
) -> Response: ) -> Response:
"""Generate audio directly from phonemes """Generate audio directly from phonemes
@ -77,7 +79,7 @@ async def generate_from_phonemes(
if not request.phonemes: if not request.phonemes:
raise HTTPException( raise HTTPException(
status_code=400, status_code=400,
detail={"error": "Invalid request", "message": "Phonemes cannot be empty"} detail={"error": "Invalid request", "message": "Phonemes cannot be empty"},
) )
# Validate voice exists # Validate voice exists
@ -85,7 +87,10 @@ async def generate_from_phonemes(
if not voice_path: if not voice_path:
raise HTTPException( raise HTTPException(
status_code=400, status_code=400,
detail={"error": "Invalid request", "message": f"Voice not found: {request.voice}"} detail={
"error": "Invalid request",
"message": f"Voice not found: {request.voice}",
},
) )
try: try:
@ -101,12 +106,7 @@ async def generate_from_phonemes(
# Convert to WAV bytes # Convert to WAV bytes
wav_bytes = AudioService.convert_audio( wav_bytes = AudioService.convert_audio(
audio, audio, 24000, "wav", is_first_chunk=True, is_last_chunk=True, stream=False
24000,
"wav",
is_first_chunk=True,
is_last_chunk=True,
stream=False
) )
return Response( return Response(
@ -115,18 +115,16 @@ async def generate_from_phonemes(
headers={ headers={
"Content-Disposition": "attachment; filename=speech.wav", "Content-Disposition": "attachment; filename=speech.wav",
"Cache-Control": "no-cache", "Cache-Control": "no-cache",
} },
) )
except ValueError as e: except ValueError as e:
logger.error(f"Invalid request: {str(e)}") logger.error(f"Invalid request: {str(e)}")
raise HTTPException( raise HTTPException(
status_code=400, status_code=400, detail={"error": "Invalid request", "message": str(e)}
detail={"error": "Invalid request", "message": str(e)}
) )
except Exception as e: except Exception as e:
logger.error(f"Error generating audio: {str(e)}") logger.error(f"Error generating audio: {str(e)}")
raise HTTPException( raise HTTPException(
status_code=500, status_code=500, detail={"error": "Server error", "message": str(e)}
detail={"error": "Server error", "message": str(e)}
) )

View file

@ -1,13 +1,12 @@
from typing import List, Union from typing import List, Union, AsyncGenerator
from loguru import logger from loguru import logger
from fastapi import Depends, Response, APIRouter, HTTPException from fastapi import Header, Depends, Response, APIRouter, HTTPException
from fastapi import Header
from fastapi.responses import StreamingResponse from fastapi.responses import StreamingResponse
from ..services.tts_service import TTSService
from ..services.audio import AudioService from ..services.audio import AudioService
from ..structures.schemas import OpenAISpeechRequest from ..structures.schemas import OpenAISpeechRequest
from typing import AsyncGenerator from ..services.tts_service import TTSService
router = APIRouter( router = APIRouter(
tags=["OpenAI Compatible TTS"], tags=["OpenAI Compatible TTS"],
@ -20,7 +19,9 @@ def get_tts_service() -> TTSService:
return TTSService() # Initialize TTSService with default settings return TTSService() # Initialize TTSService with default settings
async def process_voices(voice_input: Union[str, List[str]], tts_service: TTSService) -> str: async def process_voices(
voice_input: Union[str, List[str]], tts_service: TTSService
) -> str:
"""Process voice input into a combined voice, handling both string and list formats""" """Process voice input into a combined voice, handling both string and list formats"""
# Convert input to list of voices # Convert input to list of voices
if isinstance(voice_input, str): if isinstance(voice_input, str):
@ -35,7 +36,9 @@ async def process_voices(voice_input: Union[str, List[str]], tts_service: TTSSer
available_voices = await tts_service.list_voices() available_voices = await tts_service.list_voices()
for voice in voices: for voice in voices:
if voice not in available_voices: if voice not in available_voices:
raise ValueError(f"Voice '{voice}' not found. Available voices: {', '.join(sorted(available_voices))}") raise ValueError(
f"Voice '{voice}' not found. Available voices: {', '.join(sorted(available_voices))}"
)
# If single voice, return it directly # If single voice, return it directly
if len(voices) == 1: if len(voices) == 1:
@ -45,14 +48,16 @@ async def process_voices(voice_input: Union[str, List[str]], tts_service: TTSSer
return await tts_service.combine_voices(voices=voices) return await tts_service.combine_voices(voices=voices)
async def stream_audio_chunks(tts_service: TTSService, request: OpenAISpeechRequest) -> AsyncGenerator[bytes, None]: async def stream_audio_chunks(
tts_service: TTSService, request: OpenAISpeechRequest
) -> AsyncGenerator[bytes, None]:
"""Stream audio chunks as they're generated""" """Stream audio chunks as they're generated"""
voice_to_use = await process_voices(request.voice, tts_service) voice_to_use = await process_voices(request.voice, tts_service)
async for chunk in tts_service.generate_audio_stream( async for chunk in tts_service.generate_audio_stream(
text=request.input, text=request.input,
voice=voice_to_use, voice=voice_to_use,
speed=request.speed, speed=request.speed,
output_format=request.response_format output_format=request.response_format,
): ):
yield chunk yield chunk
@ -101,11 +106,8 @@ async def create_speech(
# Convert to requested format # Convert to requested format
content = AudioService.convert_audio( content = AudioService.convert_audio(
audio, audio, 24000, request.response_format, is_first_chunk=True, stream=False
24000, )
request.response_format,
is_first_chunk=True,
stream=False)
return Response( return Response(
content=content, content=content,

View file

@ -6,17 +6,22 @@ import numpy as np
import soundfile as sf import soundfile as sf
import scipy.io.wavfile as wavfile import scipy.io.wavfile as wavfile
from loguru import logger from loguru import logger
from ..core.config import settings from ..core.config import settings
class AudioNormalizer: class AudioNormalizer:
"""Handles audio normalization state for a single stream""" """Handles audio normalization state for a single stream"""
def __init__(self): def __init__(self):
self.int16_max = np.iinfo(np.int16).max self.int16_max = np.iinfo(np.int16).max
self.chunk_trim_ms = settings.gap_trim_ms self.chunk_trim_ms = settings.gap_trim_ms
self.sample_rate = 24000 # Sample rate of the audio self.sample_rate = 24000 # Sample rate of the audio
self.samples_to_trim = int(self.chunk_trim_ms * self.sample_rate / 1000) self.samples_to_trim = int(self.chunk_trim_ms * self.sample_rate / 1000)
def normalize(self, audio_data: np.ndarray, is_last_chunk: bool = False) -> np.ndarray: def normalize(
self, audio_data: np.ndarray, is_last_chunk: bool = False
) -> np.ndarray:
"""Normalize audio data to int16 range and trim chunk boundaries""" """Normalize audio data to int16 range and trim chunk boundaries"""
# Convert to float32 if not already # Convert to float32 if not already
audio_float = audio_data.astype(np.float32) audio_float = audio_data.astype(np.float32)
@ -27,11 +32,12 @@ class AudioNormalizer:
# Trim end of non-final chunks to reduce gaps # Trim end of non-final chunks to reduce gaps
if not is_last_chunk and len(audio_float) > self.samples_to_trim: if not is_last_chunk and len(audio_float) > self.samples_to_trim:
audio_float = audio_float[:-self.samples_to_trim] audio_float = audio_float[: -self.samples_to_trim]
# Scale to int16 range # Scale to int16 range
return (audio_float * self.int16_max).astype(np.int16) return (audio_float * self.int16_max).astype(np.int16)
class AudioService: class AudioService:
"""Service for audio format conversions""" """Service for audio format conversions"""
@ -46,7 +52,7 @@ class AudioService:
}, },
"flac": { "flac": {
"compression_level": 0.0, # Light compression, still fast "compression_level": 0.0, # Light compression, still fast
} },
} }
@staticmethod @staticmethod
@ -58,7 +64,7 @@ class AudioService:
is_last_chunk: bool = False, is_last_chunk: bool = False,
normalizer: AudioNormalizer = None, normalizer: AudioNormalizer = None,
format_settings: dict = None, format_settings: dict = None,
stream: bool = True stream: bool = True,
) -> bytes: ) -> bytes:
"""Convert audio data to specified format """Convert audio data to specified format
@ -90,37 +96,55 @@ class AudioService:
# Always normalize audio to ensure proper amplitude scaling # Always normalize audio to ensure proper amplitude scaling
if normalizer is None: if normalizer is None:
normalizer = AudioNormalizer() normalizer = AudioNormalizer()
normalized_audio = normalizer.normalize(audio_data, is_last_chunk=is_last_chunk) normalized_audio = normalizer.normalize(
audio_data, is_last_chunk=is_last_chunk
)
if output_format == "pcm": if output_format == "pcm":
# Raw 16-bit PCM samples, no header # Raw 16-bit PCM samples, no header
buffer.write(normalized_audio.tobytes()) buffer.write(normalized_audio.tobytes())
elif output_format == "wav": elif output_format == "wav":
# Always use soundfile for WAV to ensure proper headers and normalization # Always use soundfile for WAV to ensure proper headers and normalization
sf.write(buffer, normalized_audio, sample_rate, format="WAV", subtype='PCM_16') sf.write(
buffer,
normalized_audio,
sample_rate,
format="WAV",
subtype="PCM_16",
)
elif output_format == "mp3": elif output_format == "mp3":
# Use format settings or defaults # Use format settings or defaults
settings = format_settings.get("mp3", {}) if format_settings else {} settings = format_settings.get("mp3", {}) if format_settings else {}
settings = {**AudioService.DEFAULT_SETTINGS["mp3"], **settings} settings = {**AudioService.DEFAULT_SETTINGS["mp3"], **settings}
sf.write( sf.write(
buffer, normalized_audio, buffer, normalized_audio, sample_rate, format="MP3", **settings
sample_rate, format="MP3", )
**settings
)
elif output_format == "opus": elif output_format == "opus":
settings = format_settings.get("opus", {}) if format_settings else {} settings = format_settings.get("opus", {}) if format_settings else {}
settings = {**AudioService.DEFAULT_SETTINGS["opus"], **settings} settings = {**AudioService.DEFAULT_SETTINGS["opus"], **settings}
sf.write(buffer, normalized_audio, sample_rate, format="OGG", sf.write(
subtype="OPUS", **settings) buffer,
normalized_audio,
sample_rate,
format="OGG",
subtype="OPUS",
**settings,
)
elif output_format == "flac": elif output_format == "flac":
if is_first_chunk: if is_first_chunk:
logger.info("Starting FLAC stream...") logger.info("Starting FLAC stream...")
settings = format_settings.get("flac", {}) if format_settings else {} settings = format_settings.get("flac", {}) if format_settings else {}
settings = {**AudioService.DEFAULT_SETTINGS["flac"], **settings} settings = {**AudioService.DEFAULT_SETTINGS["flac"], **settings}
sf.write(buffer, normalized_audio, sample_rate, format="FLAC", sf.write(
subtype='PCM_16', **settings) buffer,
normalized_audio,
sample_rate,
format="FLAC",
subtype="PCM_16",
**settings,
)
else: else:
if output_format == "aac": if output_format == "aac":
raise ValueError( raise ValueError(

View file

@ -1,13 +1,13 @@
from .normalizer import normalize_text from .normalizer import normalize_text
from .phonemizer import phonemize, PhonemizerBackend, EspeakBackend from .phonemizer import EspeakBackend, PhonemizerBackend, phonemize
from .vocabulary import tokenize, decode_tokens, VOCAB from .vocabulary import VOCAB, tokenize, decode_tokens
__all__ = [ __all__ = [
'normalize_text', "normalize_text",
'phonemize', "phonemize",
'tokenize', "tokenize",
'decode_tokens', "decode_tokens",
'VOCAB', "VOCAB",
'PhonemizerBackend', "PhonemizerBackend",
'EspeakBackend' "EspeakBackend",
] ]

View file

@ -1,6 +1,7 @@
"""Text chunking service""" """Text chunking service"""
import re import re
from ...core.config import settings from ...core.config import settings

View file

@ -9,19 +9,58 @@ from functools import lru_cache
# Constants # Constants
VALID_TLDS = [ VALID_TLDS = [
"com", "org", "net", "edu", "gov", "mil", "int", "biz", "info", "name", "com",
"pro", "coop", "museum", "travel", "jobs", "mobi", "tel", "asia", "cat", "org",
"xxx", "aero", "arpa", "bg", "br", "ca", "cn", "de", "es", "eu", "fr", "net",
"in", "it", "jp", "mx", "nl", "ru", "uk", "us", "io" "edu",
"gov",
"mil",
"int",
"biz",
"info",
"name",
"pro",
"coop",
"museum",
"travel",
"jobs",
"mobi",
"tel",
"asia",
"cat",
"xxx",
"aero",
"arpa",
"bg",
"br",
"ca",
"cn",
"de",
"es",
"eu",
"fr",
"in",
"it",
"jp",
"mx",
"nl",
"ru",
"uk",
"us",
"io",
] ]
# Pre-compiled regex patterns for performance # Pre-compiled regex patterns for performance
EMAIL_PATTERN = re.compile(r"\b[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-z]{2,}\b", re.IGNORECASE) EMAIL_PATTERN = re.compile(
URL_PATTERN = re.compile( r"\b[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-z]{2,}\b", re.IGNORECASE
r"(https?://|www\.|)+(localhost|[a-zA-Z0-9.-]+(\.(?:" +
"|".join(VALID_TLDS) + "))+|[0-9]{1,3}\.[0-9]{1,3}\.[0-9]{1,3}\.[0-9]{1,3})(:[0-9]+)?([/?][^\s]*)?",
re.IGNORECASE
) )
URL_PATTERN = re.compile(
r"(https?://|www\.|)+(localhost|[a-zA-Z0-9.-]+(\.(?:"
+ "|".join(VALID_TLDS)
+ "))+|[0-9]{1,3}\.[0-9]{1,3}\.[0-9]{1,3}\.[0-9]{1,3})(:[0-9]+)?([/?][^\s]*)?",
re.IGNORECASE,
)
def split_num(num: re.Match[str]) -> str: def split_num(num: re.Match[str]) -> str:
"""Handle number splitting for various formats""" """Handle number splitting for various formats"""
@ -47,6 +86,7 @@ def split_num(num: re.Match[str]) -> str:
return f"{left} oh {right}{s}" return f"{left} oh {right}{s}"
return f"{left} {right}{s}" return f"{left} {right}{s}"
def handle_money(m: re.Match[str]) -> str: def handle_money(m: re.Match[str]) -> str:
"""Convert money expressions to spoken form""" """Convert money expressions to spoken form"""
m = m.group() m = m.group()
@ -66,21 +106,24 @@ def handle_money(m: re.Match[str]) -> str:
) )
return f"{b} {bill}{s} and {c} {coins}" return f"{b} {bill}{s} and {c} {coins}"
def handle_decimal(num: re.Match[str]) -> str: def handle_decimal(num: re.Match[str]) -> str:
"""Convert decimal numbers to spoken form""" """Convert decimal numbers to spoken form"""
a, b = num.group().split(".") a, b = num.group().split(".")
return " point ".join([a, " ".join(b)]) return " point ".join([a, " ".join(b)])
def handle_email(m: re.Match[str]) -> str: def handle_email(m: re.Match[str]) -> str:
"""Convert email addresses into speakable format""" """Convert email addresses into speakable format"""
email = m.group(0) email = m.group(0)
parts = email.split('@') parts = email.split("@")
if len(parts) == 2: if len(parts) == 2:
user, domain = parts user, domain = parts
domain = domain.replace('.', ' dot ') domain = domain.replace(".", " dot ")
return f"{user} at {domain}" return f"{user} at {domain}"
return email return email
def handle_url(u: re.Match[str]) -> str: def handle_url(u: re.Match[str]) -> str:
"""Make URLs speakable by converting special characters to spoken words""" """Make URLs speakable by converting special characters to spoken words"""
if not u: if not u:
@ -89,19 +132,24 @@ def handle_url(u: re.Match[str]) -> str:
url = u.group(0).strip() url = u.group(0).strip()
# Handle protocol first # Handle protocol first
url = re.sub(r'^https?://', lambda a: 'https ' if 'https' in a.group() else 'http ', url, flags=re.IGNORECASE) url = re.sub(
url = re.sub(r'^www\.', 'www ', url, flags=re.IGNORECASE) r"^https?://",
lambda a: "https " if "https" in a.group() else "http ",
url,
flags=re.IGNORECASE,
)
url = re.sub(r"^www\.", "www ", url, flags=re.IGNORECASE)
# Handle port numbers before other replacements # Handle port numbers before other replacements
url = re.sub(r':(\d+)(?=/|$)', lambda m: f" colon {m.group(1)}", url) url = re.sub(r":(\d+)(?=/|$)", lambda m: f" colon {m.group(1)}", url)
# Split into domain and path # Split into domain and path
parts = url.split('/', 1) parts = url.split("/", 1)
domain = parts[0] domain = parts[0]
path = parts[1] if len(parts) > 1 else '' path = parts[1] if len(parts) > 1 else ""
# Handle dots in domain # Handle dots in domain
domain = domain.replace('.', ' dot ') domain = domain.replace(".", " dot ")
# Reconstruct URL # Reconstruct URL
if path: if path:
@ -120,7 +168,7 @@ def handle_url(u: re.Match[str]) -> str:
url = url.replace("/", " slash ") # Handle any remaining slashes url = url.replace("/", " slash ") # Handle any remaining slashes
# Clean up extra spaces # Clean up extra spaces
return re.sub(r'\s+', ' ', url).strip() return re.sub(r"\s+", " ", url).strip()
def normalize_urls(text: str) -> str: def normalize_urls(text: str) -> str:
@ -133,6 +181,7 @@ def normalize_urls(text: str) -> str:
return text return text
def normalize_text(text: str) -> str: def normalize_text(text: str) -> str:
"""Normalize text for TTS processing""" """Normalize text for TTS processing"""
# Pre-process URLs first # Pre-process URLs first
@ -165,9 +214,7 @@ def normalize_text(text: str) -> str:
# Handle numbers and money # Handle numbers and money
text = re.sub( text = re.sub(
r"\d*\.\d+|\b\d{4}s?\b|(?<!:)\b(?:[1-9]|1[0-2]):[0-5]\d\b(?!:)", r"\d*\.\d+|\b\d{4}s?\b|(?<!:)\b(?:[1-9]|1[0-2]):[0-5]\d\b(?!:)", split_num, text
split_num,
text
) )
text = re.sub(r"(?<=\d),(?=\d)", "", text) text = re.sub(r"(?<=\d),(?=\d)", "", text)
text = re.sub( text = re.sub(
@ -183,9 +230,7 @@ def normalize_text(text: str) -> str:
text = re.sub(r"(?<=[BCDFGHJ-NP-TV-Z])'?s\b", "'S", text) text = re.sub(r"(?<=[BCDFGHJ-NP-TV-Z])'?s\b", "'S", text)
text = re.sub(r"(?<=X')S\b", "s", text) text = re.sub(r"(?<=X')S\b", "s", text)
text = re.sub( text = re.sub(
r"(?:[A-Za-z]\.){2,} [a-z]", r"(?:[A-Za-z]\.){2,} [a-z]", lambda m: m.group().replace(".", "-"), text
lambda m: m.group().replace(".", "-"),
text
) )
text = re.sub(r"(?i)(?<=[A-Z])\.(?=[A-Z])", "-", text) text = re.sub(r"(?i)(?<=[A-Z])\.(?=[A-Z])", "-", text)

View file

@ -1,8 +1,11 @@
import re import re
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
import phonemizer import phonemizer
from .normalizer import normalize_text from .normalizer import normalize_text
class PhonemizerBackend(ABC): class PhonemizerBackend(ABC):
"""Abstract base class for phonemization backends""" """Abstract base class for phonemization backends"""
@ -18,6 +21,7 @@ class PhonemizerBackend(ABC):
""" """
pass pass
class EspeakBackend(PhonemizerBackend): class EspeakBackend(PhonemizerBackend):
"""Espeak-based phonemizer implementation""" """Espeak-based phonemizer implementation"""
@ -28,9 +32,7 @@ class EspeakBackend(PhonemizerBackend):
language: Language code ('en-us' or 'en-gb') language: Language code ('en-us' or 'en-gb')
""" """
self.backend = phonemizer.backend.EspeakBackend( self.backend = phonemizer.backend.EspeakBackend(
language=language, language=language, preserve_punctuation=True, with_stress=True
preserve_punctuation=True,
with_stress=True
) )
self.language = language self.language = language
@ -59,6 +61,7 @@ class EspeakBackend(PhonemizerBackend):
return ps.strip() return ps.strip()
def create_phonemizer(language: str = "a") -> PhonemizerBackend: def create_phonemizer(language: str = "a") -> PhonemizerBackend:
"""Factory function to create phonemizer backend """Factory function to create phonemizer backend
@ -69,16 +72,14 @@ def create_phonemizer(language: str = "a") -> PhonemizerBackend:
Phonemizer backend instance Phonemizer backend instance
""" """
# Map language codes to espeak language codes # Map language codes to espeak language codes
lang_map = { lang_map = {"a": "en-us", "b": "en-gb"}
"a": "en-us",
"b": "en-gb"
}
if language not in lang_map: if language not in lang_map:
raise ValueError(f"Unsupported language code: {language}") raise ValueError(f"Unsupported language code: {language}")
return EspeakBackend(lang_map[language]) return EspeakBackend(lang_map[language])
def phonemize(text: str, language: str = "a", normalize: bool = True) -> str: def phonemize(text: str, language: str = "a", normalize: bool = True) -> str:
"""Convert text to phonemes """Convert text to phonemes

View file

@ -9,9 +9,11 @@ def get_vocab():
symbols = [_pad] + list(_punctuation) + list(_letters) + list(_letters_ipa) symbols = [_pad] + list(_punctuation) + list(_letters) + list(_letters_ipa)
return {symbol: i for i, symbol in enumerate(symbols)} return {symbol: i for i, symbol in enumerate(symbols)}
# Initialize vocabulary # Initialize vocabulary
VOCAB = get_vocab() VOCAB = get_vocab()
def tokenize(phonemes: str) -> list[int]: def tokenize(phonemes: str) -> list[int]:
"""Convert phonemes string to token IDs """Convert phonemes string to token IDs
@ -23,6 +25,7 @@ def tokenize(phonemes: str) -> list[int]:
""" """
return [i for i in map(VOCAB.get, phonemes) if i is not None] return [i for i in map(VOCAB.get, phonemes) if i is not None]
def decode_tokens(tokens: list[int]) -> str: def decode_tokens(tokens: list[int]) -> str:
"""Convert token IDs back to phonemes string """Convert token IDs back to phonemes string

View file

@ -2,12 +2,14 @@ import os
import threading import threading
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import List, Tuple from typing import List, Tuple
import torch
import numpy as np import numpy as np
import torch
from loguru import logger from loguru import logger
from ..core.config import settings from ..core.config import settings
class TTSBaseModel(ABC): class TTSBaseModel(ABC):
_instance = None _instance = None
_lock = threading.Lock() _lock = threading.Lock()
@ -26,7 +28,9 @@ class TTSBaseModel(ABC):
# Test CUDA device # Test CUDA device
test_tensor = torch.zeros(1).cuda() test_tensor = torch.zeros(1).cuda()
logger.info("CUDA test successful") logger.info("CUDA test successful")
model_path = os.path.join(settings.model_dir, settings.pytorch_model_path) model_path = os.path.join(
settings.model_dir, settings.pytorch_model_path
)
cls._device = "cuda" cls._device = "cuda"
except Exception as e: except Exception as e:
logger.error(f"CUDA test failed: {e}") logger.error(f"CUDA test failed: {e}")
@ -54,19 +58,35 @@ class TTSBaseModel(ABC):
voice_path = os.path.join(cls.VOICES_DIR, file) voice_path = os.path.join(cls.VOICES_DIR, file)
if not os.path.exists(voice_path): if not os.path.exists(voice_path):
try: try:
logger.info(f"Copying base voice {voice_name} to voices directory") logger.info(
f"Copying base voice {voice_name} to voices directory"
)
base_path = os.path.join(base_voices_dir, file) base_path = os.path.join(base_voices_dir, file)
voicepack = torch.load(base_path, map_location=cls._device, weights_only=True) voicepack = torch.load(
base_path,
map_location=cls._device,
weights_only=True,
)
torch.save(voicepack, voice_path) torch.save(voicepack, voice_path)
except Exception as e: except Exception as e:
logger.error(f"Error copying voice {voice_name}: {str(e)}") logger.error(
f"Error copying voice {voice_name}: {str(e)}"
)
# Count voices in directory # Count voices in directory
voice_count = len([f for f in os.listdir(cls.VOICES_DIR) if f.endswith(".pt")]) voice_count = len(
[f for f in os.listdir(cls.VOICES_DIR) if f.endswith(".pt")]
)
# Now that model and voices are ready, do warmup # Now that model and voices are ready, do warmup
try: try:
with open(os.path.join(os.path.dirname(os.path.dirname(__file__)), "core", "don_quixote.txt")) as f: with open(
os.path.join(
os.path.dirname(os.path.dirname(__file__)),
"core",
"don_quixote.txt",
)
) as f:
warmup_text = f.read() warmup_text = f.read()
except Exception as e: except Exception as e:
logger.warning(f"Failed to load warmup text: {e}") logger.warning(f"Failed to load warmup text: {e}")
@ -74,6 +94,7 @@ class TTSBaseModel(ABC):
# Use warmup service after model is fully initialized # Use warmup service after model is fully initialized
from .warmup import WarmupService from .warmup import WarmupService
warmup = WarmupService() warmup = WarmupService()
# Load and warm up voices # Load and warm up voices
@ -83,7 +104,9 @@ class TTSBaseModel(ABC):
logger.info("Model warm-up complete") logger.info("Model warm-up complete")
# Count voices in directory # Count voices in directory
voice_count = len([f for f in os.listdir(cls.VOICES_DIR) if f.endswith(".pt")]) voice_count = len(
[f for f in os.listdir(cls.VOICES_DIR) if f.endswith(".pt")]
)
return voice_count return voice_count
@classmethod @classmethod
@ -108,7 +131,9 @@ class TTSBaseModel(ABC):
@classmethod @classmethod
@abstractmethod @abstractmethod
def generate_from_text(cls, text: str, voicepack: torch.Tensor, language: str, speed: float) -> Tuple[np.ndarray, str]: def generate_from_text(
cls, text: str, voicepack: torch.Tensor, language: str, speed: float
) -> Tuple[np.ndarray, str]:
"""Generate audio from text """Generate audio from text
Args: Args:
@ -124,7 +149,9 @@ class TTSBaseModel(ABC):
@classmethod @classmethod
@abstractmethod @abstractmethod
def generate_from_tokens(cls, tokens: List[int], voicepack: torch.Tensor, speed: float) -> np.ndarray: def generate_from_tokens(
cls, tokens: List[int], voicepack: torch.Tensor, speed: float
) -> np.ndarray:
"""Generate audio from tokens """Generate audio from tokens
Args: Args:

View file

@ -1,12 +1,19 @@
import os import os
import numpy as np import numpy as np
import torch import torch
from onnxruntime import InferenceSession, SessionOptions, GraphOptimizationLevel, ExecutionMode
from loguru import logger from loguru import logger
from onnxruntime import (
ExecutionMode,
SessionOptions,
InferenceSession,
GraphOptimizationLevel,
)
from .tts_base import TTSBaseModel from .tts_base import TTSBaseModel
from .text_processing import phonemize, tokenize
from ..core.config import settings from ..core.config import settings
from .text_processing import tokenize, phonemize
class TTSCPUModel(TTSBaseModel): class TTSCPUModel(TTSBaseModel):
_instance = None _instance = None
@ -41,11 +48,17 @@ class TTSCPUModel(TTSBaseModel):
# Set optimization level # Set optimization level
if settings.onnx_optimization_level == "all": if settings.onnx_optimization_level == "all":
session_options.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL session_options.graph_optimization_level = (
GraphOptimizationLevel.ORT_ENABLE_ALL
)
elif settings.onnx_optimization_level == "basic": elif settings.onnx_optimization_level == "basic":
session_options.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_BASIC session_options.graph_optimization_level = (
GraphOptimizationLevel.ORT_ENABLE_BASIC
)
else: else:
session_options.graph_optimization_level = GraphOptimizationLevel.ORT_DISABLE_ALL session_options.graph_optimization_level = (
GraphOptimizationLevel.ORT_DISABLE_ALL
)
# Configure threading # Configure threading
session_options.intra_op_num_threads = settings.onnx_num_threads session_options.intra_op_num_threads = settings.onnx_num_threads
@ -63,17 +76,17 @@ class TTSCPUModel(TTSBaseModel):
# Configure CPU provider options # Configure CPU provider options
provider_options = { provider_options = {
'CPUExecutionProvider': { "CPUExecutionProvider": {
'arena_extend_strategy': settings.onnx_arena_extend_strategy, "arena_extend_strategy": settings.onnx_arena_extend_strategy,
'cpu_memory_arena_cfg': 'cpu:0' "cpu_memory_arena_cfg": "cpu:0",
} }
} }
session = InferenceSession( session = InferenceSession(
onnx_path, onnx_path,
sess_options=session_options, sess_options=session_options,
providers=['CPUExecutionProvider'], providers=["CPUExecutionProvider"],
provider_options=[provider_options] provider_options=[provider_options],
) )
cls._onnx_session = session cls._onnx_session = session
return session return session
@ -96,7 +109,9 @@ class TTSCPUModel(TTSBaseModel):
return phonemes, tokens return phonemes, tokens
@classmethod @classmethod
def generate_from_text(cls, text: str, voicepack: torch.Tensor, language: str, speed: float) -> tuple[np.ndarray, str]: def generate_from_text(
cls, text: str, voicepack: torch.Tensor, language: str, speed: float
) -> tuple[np.ndarray, str]:
"""Generate audio from text """Generate audio from text
Args: Args:
@ -120,7 +135,9 @@ class TTSCPUModel(TTSBaseModel):
return audio, phonemes return audio, phonemes
@classmethod @classmethod
def generate_from_tokens(cls, tokens: list[int], voicepack: torch.Tensor, speed: float) -> np.ndarray: def generate_from_tokens(
cls, tokens: list[int], voicepack: torch.Tensor, speed: float
) -> np.ndarray:
"""Generate audio from tokens """Generate audio from tokens
Args: Args:
@ -136,16 +153,15 @@ class TTSCPUModel(TTSBaseModel):
# Pre-allocate and prepare inputs # Pre-allocate and prepare inputs
tokens_input = np.array([tokens], dtype=np.int64) tokens_input = np.array([tokens], dtype=np.int64)
style_input = voicepack[len(tokens)-2].numpy() # Already has correct dimensions style_input = voicepack[
speed_input = np.full(1, speed, dtype=np.float32) # More efficient than ones * speed len(tokens) - 2
].numpy() # Already has correct dimensions
speed_input = np.full(
1, speed, dtype=np.float32
) # More efficient than ones * speed
# Run inference with optimized inputs # Run inference with optimized inputs
result = cls._onnx_session.run( result = cls._onnx_session.run(
None, None, {"tokens": tokens_input, "style": style_input, "speed": speed_input}
{
'tokens': tokens_input,
'style': style_input,
'speed': speed_input
}
) )
return result[0] return result[0]

View file

@ -1,13 +1,15 @@
import os import os
import time
import numpy as np import numpy as np
import torch import torch
import time
from loguru import logger from loguru import logger
from models import build_model from models import build_model
from .text_processing import phonemize, tokenize
from .tts_base import TTSBaseModel from .tts_base import TTSBaseModel
from ..core.config import settings from ..core.config import settings
from .text_processing import tokenize, phonemize
# @torch.no_grad() # @torch.no_grad()
# def forward(model, tokens, ref_s, speed): # def forward(model, tokens, ref_s, speed):
@ -65,7 +67,7 @@ def forward(model, tokens, ref_s, speed):
pred_aln_trg = torch.zeros(input_lengths, pred_dur.sum().item(), device=device) pred_aln_trg = torch.zeros(input_lengths, pred_dur.sum().item(), device=device)
c_frame = 0 c_frame = 0
for i in range(pred_aln_trg.size(0)): for i in range(pred_aln_trg.size(0)):
pred_aln_trg[i, c_frame:c_frame + pred_dur[0, i].item()] = 1 pred_aln_trg[i, c_frame : c_frame + pred_dur[0, i].item()] = 1
c_frame += pred_dur[0, i].item() c_frame += pred_dur[0, i].item()
# Matrix multiplications - reuse unsqueezed tensor # Matrix multiplications - reuse unsqueezed tensor
@ -79,6 +81,7 @@ def forward(model, tokens, ref_s, speed):
return model.decoder(asr, F0_pred, N_pred, s_ref).squeeze().cpu().numpy() return model.decoder(asr, F0_pred, N_pred, s_ref).squeeze().cpu().numpy()
# def length_to_mask(lengths): # def length_to_mask(lengths):
# """Create attention mask from lengths""" # """Create attention mask from lengths"""
# mask = ( # mask = (
@ -90,17 +93,21 @@ def forward(model, tokens, ref_s, speed):
# mask = torch.gt(mask + 1, lengths.unsqueeze(1)) # mask = torch.gt(mask + 1, lengths.unsqueeze(1))
# return mask # return mask
def length_to_mask(lengths):
    """Create a boolean padding mask from a tensor of sequence lengths.

    Args:
        lengths: Tensor of sequence lengths. The indexing below
            (``lengths.shape[0]`` and ``lengths[:, None]``) expects it to be
            1-D, one entry per sequence.

    Returns:
        Bool tensor of shape ``(lengths.shape[0], lengths.max())`` on the same
        device as ``lengths``. Entry ``[i, p]`` is ``True`` when position ``p``
        (0-based) lies at or beyond ``lengths[i]``, i.e. it is padding.
    """
    max_len = lengths.max()
    # Build the position index directly on the device that holds `lengths`.
    positions = torch.arange(max_len, device=lengths.device)[None, :].expand(
        lengths.shape[0], -1
    )
    # Match dtypes up front so the comparison needs no implicit promotion.
    if positions.dtype != lengths.dtype:
        positions = positions.to(dtype=lengths.dtype)
    # 1-based position > length  <=>  0-based position >= length (padding).
    return positions + 1 > lengths[:, None]
class TTSGPUModel(TTSBaseModel): class TTSGPUModel(TTSBaseModel):
_instance = None _instance = None
_device = "cuda" _device = "cuda"
@ -143,7 +150,9 @@ class TTSGPUModel(TTSBaseModel):
return phonemes, tokens return phonemes, tokens
@classmethod @classmethod
def generate_from_text(cls, text: str, voicepack: torch.Tensor, language: str, speed: float) -> tuple[np.ndarray, str]: def generate_from_text(
cls, text: str, voicepack: torch.Tensor, language: str, speed: float
) -> tuple[np.ndarray, str]:
"""Generate audio from text """Generate audio from text
Args: Args:
@ -167,7 +176,9 @@ class TTSGPUModel(TTSBaseModel):
return audio, phonemes return audio, phonemes
@classmethod @classmethod
def generate_from_tokens(cls, tokens: list[int], voicepack: torch.Tensor, speed: float) -> np.ndarray: def generate_from_tokens(
cls, tokens: list[int], voicepack: torch.Tensor, speed: float
) -> np.ndarray:
"""Generate audio from tokens """Generate audio from tokens
Args: Args:

View file

@ -1,5 +1,4 @@
import io import io
import aiofiles.os
import os import os
import re import re
import time import time
@ -8,13 +7,14 @@ from functools import lru_cache
import numpy as np import numpy as np
import torch import torch
import aiofiles.os
import scipy.io.wavfile as wavfile import scipy.io.wavfile as wavfile
from .text_processing import normalize_text, chunker
from loguru import logger from loguru import logger
from ..core.config import settings
from .tts_model import TTSModel
from .audio import AudioService, AudioNormalizer from .audio import AudioService, AudioNormalizer
from .tts_model import TTSModel
from ..core.config import settings
from .text_processing import chunker, normalize_text
class TTSService: class TTSService:
@ -26,7 +26,9 @@ class TTSService:
@lru_cache(maxsize=3) # Cache up to 3 most recently used voices @lru_cache(maxsize=3) # Cache up to 3 most recently used voices
def _load_voice(voice_path: str) -> torch.Tensor: def _load_voice(voice_path: str) -> torch.Tensor:
"""Load and cache a voice model""" """Load and cache a voice model"""
return torch.load(voice_path, map_location=TTSModel.get_device(), weights_only=True) return torch.load(
voice_path, map_location=TTSModel.get_device(), weights_only=True
)
def _get_voice_path(self, voice_name: str) -> Optional[str]: def _get_voice_path(self, voice_name: str) -> Optional[str]:
"""Get the path to a voice file""" """Get the path to a voice file"""
@ -37,7 +39,9 @@ class TTSService:
self, text: str, voice: str, speed: float, stitch_long_output: bool = True self, text: str, voice: str, speed: float, stitch_long_output: bool = True
) -> Tuple[torch.Tensor, float]: ) -> Tuple[torch.Tensor, float]:
"""Generate complete audio and return with processing time""" """Generate complete audio and return with processing time"""
audio, processing_time = self._generate_audio_internal(text, voice, speed, stitch_long_output) audio, processing_time = self._generate_audio_internal(
text, voice, speed, stitch_long_output
)
return audio, processing_time return audio, processing_time
def _generate_audio_internal( def _generate_audio_internal(
@ -72,7 +76,9 @@ class TTSService:
phonemes, tokens = TTSModel.process_text(chunk, voice[0]) phonemes, tokens = TTSModel.process_text(chunk, voice[0])
chunks_data.append((chunk, tokens)) chunks_data.append((chunk, tokens))
except Exception as e: except Exception as e:
logger.error(f"Failed to process chunk: '{chunk}'. Error: {str(e)}") logger.error(
f"Failed to process chunk: '{chunk}'. Error: {str(e)}"
)
continue continue
if not chunks_data: if not chunks_data:
@ -82,20 +88,28 @@ class TTSService:
audio_chunks = [] audio_chunks = []
for chunk, tokens in chunks_data: for chunk, tokens in chunks_data:
try: try:
chunk_audio = TTSModel.generate_from_tokens(tokens, voicepack, speed) chunk_audio = TTSModel.generate_from_tokens(
tokens, voicepack, speed
)
if chunk_audio is not None: if chunk_audio is not None:
audio_chunks.append(chunk_audio) audio_chunks.append(chunk_audio)
else: else:
logger.error(f"No audio generated for chunk: '{chunk}'") logger.error(f"No audio generated for chunk: '{chunk}'")
except Exception as e: except Exception as e:
logger.error(f"Failed to generate audio for chunk: '{chunk}'. Error: {str(e)}") logger.error(
f"Failed to generate audio for chunk: '{chunk}'. Error: {str(e)}"
)
continue continue
if not audio_chunks: if not audio_chunks:
raise ValueError("No audio chunks were generated successfully") raise ValueError("No audio chunks were generated successfully")
# Concatenate all chunks # Concatenate all chunks
audio = np.concatenate(audio_chunks) if len(audio_chunks) > 1 else audio_chunks[0] audio = (
np.concatenate(audio_chunks)
if len(audio_chunks) > 1
else audio_chunks[0]
)
else: else:
# Process single chunk # Process single chunk
phonemes, tokens = TTSModel.process_text(text, voice[0]) phonemes, tokens = TTSModel.process_text(text, voice[0])
@ -109,7 +123,12 @@ class TTSService:
raise raise
async def generate_audio_stream( async def generate_audio_stream(
self, text: str, voice: str, speed: float, output_format: str = "wav", silent=False self,
text: str,
voice: str,
speed: float,
output_format: str = "wav",
silent=False,
): ):
"""Generate and yield audio chunks as they're generated for real-time streaming""" """Generate and yield audio chunks as they're generated for real-time streaming"""
try: try:
@ -125,7 +144,9 @@ class TTSService:
if not normalized: if not normalized:
raise ValueError("Text is empty after preprocessing") raise ValueError("Text is empty after preprocessing")
text = str(normalized) text = str(normalized)
logger.debug(f"Text preprocessing took: {(time.time() - preprocess_start)*1000:.1f}ms") logger.debug(
f"Text preprocessing took: {(time.time() - preprocess_start)*1000:.1f}ms"
)
# Voice validation and loading # Voice validation and loading
voice_start = time.time() voice_start = time.time()
@ -133,7 +154,9 @@ class TTSService:
if not voice_path: if not voice_path:
raise ValueError(f"Voice not found: {voice}") raise ValueError(f"Voice not found: {voice}")
voicepack = self._load_voice(voice_path) voicepack = self._load_voice(voice_path)
logger.debug(f"Voice loading took: {(time.time() - voice_start)*1000:.1f}ms") logger.debug(
f"Voice loading took: {(time.time() - voice_start)*1000:.1f}ms"
)
# Process chunks as they're generated # Process chunks as they're generated
is_first = True is_first = True
@ -149,7 +172,9 @@ class TTSService:
try: try:
# Process text and generate audio # Process text and generate audio
phonemes, tokens = TTSModel.process_text(current_chunk, voice[0]) phonemes, tokens = TTSModel.process_text(current_chunk, voice[0])
chunk_audio = TTSModel.generate_from_tokens(tokens, voicepack, speed) chunk_audio = TTSModel.generate_from_tokens(
tokens, voicepack, speed
)
if chunk_audio is not None: if chunk_audio is not None:
# Convert chunk with proper header handling # Convert chunk with proper header handling
@ -159,7 +184,7 @@ class TTSService:
output_format, output_format,
is_first_chunk=is_first, is_first_chunk=is_first,
normalizer=stream_normalizer, normalizer=stream_normalizer,
is_last_chunk=(next_chunk is None) # Last if no next chunk is_last_chunk=(next_chunk is None), # Last if no next chunk
) )
yield chunk_bytes yield chunk_bytes
@ -168,7 +193,9 @@ class TTSService:
logger.error(f"No audio generated for chunk: '{current_chunk}'") logger.error(f"No audio generated for chunk: '{current_chunk}'")
except Exception as e: except Exception as e:
logger.error(f"Failed to generate audio for chunk: '{current_chunk}'. Error: {str(e)}") logger.error(
f"Failed to generate audio for chunk: '{current_chunk}'. Error: {str(e)}"
)
current_chunk = next_chunk # Move to next chunk current_chunk = next_chunk # Move to next chunk

View file

@ -1,10 +1,11 @@
import os import os
from typing import List, Tuple from typing import List, Tuple
import torch import torch
from loguru import logger from loguru import logger
from .tts_service import TTSService
from .tts_model import TTSModel from .tts_model import TTSModel
from .tts_service import TTSService
from ..core.config import settings from ..core.config import settings
@ -22,18 +23,19 @@ class WarmupService:
"""Load and cache voices up to LRU limit""" """Load and cache voices up to LRU limit"""
# Get all voices sorted by filename length (shorter names first, usually base voices) # Get all voices sorted by filename length (shorter names first, usually base voices)
voice_files = sorted( voice_files = sorted(
[f for f in os.listdir(TTSModel.VOICES_DIR) if f.endswith(".pt")], [f for f in os.listdir(TTSModel.VOICES_DIR) if f.endswith(".pt")], key=len
key=len
) )
n_voices_cache=1 n_voices_cache = 1
loaded_voices = [] loaded_voices = []
for voice_file in voice_files[:n_voices_cache]: for voice_file in voice_files[:n_voices_cache]:
try: try:
voice_path = os.path.join(TTSModel.VOICES_DIR, voice_file) voice_path = os.path.join(TTSModel.VOICES_DIR, voice_file)
# load using service, lru cache # load using service, lru cache
voicepack = self.tts_service._load_voice(voice_path) voicepack = self.tts_service._load_voice(voice_path)
loaded_voices.append((voice_file[:-3], voicepack)) # Store name and tensor loaded_voices.append(
(voice_file[:-3], voicepack)
) # Store name and tensor
# voicepack = torch.load(voice_path, map_location=TTSModel.get_device(), weights_only=True) # voicepack = torch.load(voice_path, map_location=TTSModel.get_device(), weights_only=True)
# logger.info(f"Loaded voice {voice_file[:-3]} into cache") # logger.info(f"Loaded voice {voice_file[:-3]} into cache")
except Exception as e: except Exception as e:
@ -41,17 +43,16 @@ class WarmupService:
logger.info(f"Pre-loaded {len(loaded_voices)} voices into cache") logger.info(f"Pre-loaded {len(loaded_voices)} voices into cache")
return loaded_voices return loaded_voices
async def warmup_voices(self, warmup_text: str, loaded_voices: List[Tuple[str, torch.Tensor]]): async def warmup_voices(
self, warmup_text: str, loaded_voices: List[Tuple[str, torch.Tensor]]
):
"""Warm up voice inference and streaming""" """Warm up voice inference and streaming"""
n_warmups = 1 n_warmups = 1
for voice_name, _ in loaded_voices[:n_warmups]: for voice_name, _ in loaded_voices[:n_warmups]:
try: try:
logger.info(f"Running warmup inference on voice {voice_name}") logger.info(f"Running warmup inference on voice {voice_name}")
async for _ in self.tts_service.generate_audio_stream( async for _ in self.tts_service.generate_audio_stream(
warmup_text, warmup_text, voice_name, 1.0, "pcm"
voice_name,
1.0,
"pcm"
): ):
pass # Process all chunks to properly warm up pass # Process all chunks to properly warm up
logger.info(f"Completed warmup for voice {voice_name}") logger.info(f"Completed warmup for voice {voice_name}")

View file

@ -1,14 +1,15 @@
from enum import Enum from enum import Enum
from typing import Literal, Union, List from typing import List, Union, Literal
from pydantic import Field, BaseModel from pydantic import Field, BaseModel
class VoiceCombineRequest(BaseModel):
    """Request schema for voice combination endpoint that accepts either a string with + or a list"""

    # Required field (`...` = no default): one "+"-joined string or a list of
    # voice names, as the description below spells out.
    voices: Union[str, List[str]] = Field(
        ...,
        description="Either a string with voices separated by + (e.g. 'voice1+voice2') or a list of voice names to combine",
    )

View file

@ -1,14 +1,19 @@
from pydantic import BaseModel, Field from pydantic import Field, BaseModel
class PhonemeRequest(BaseModel):
    # Required: no default is declared for `text`.
    text: str
    # Optional language code; falls back to "a" when omitted.
    language: str = "a"  # Default to American English
class PhonemeResponse(BaseModel):
    # Both fields are required; no defaults are declared.
    phonemes: str
    # NOTE(review): presumably the token ids that correspond to `phonemes` —
    # confirm against the endpoint that builds this response.
    tokens: list[int]
class GenerateFromPhonemesRequest(BaseModel):
    # Required; no validation constraints are declared on the phoneme string.
    phonemes: str
    # Required (`...` = no default).
    voice: str = Field(..., description="Voice ID to use for generation")
    # Optional; defaults to 1.0 and is bounded to 0.1 <= speed <= 5.0 by ge/le.
    speed: float = Field(
        default=1.0, ge=0.1, le=5.0, description="Speed factor for generation"
    )

View file

@ -1,9 +1,9 @@
import os import os
import sys import sys
import shutil import shutil
from unittest.mock import Mock, patch, MagicMock from unittest.mock import Mock, MagicMock, patch
import numpy as np
import numpy as np
import pytest import pytest
import aiofiles.threadpool import aiofiles.threadpool
@ -37,6 +37,7 @@ mock_torch = Mock()
mock_torch.cuda = Mock() mock_torch.cuda = Mock()
mock_torch.cuda.is_available = Mock(return_value=False) mock_torch.cuda.is_available = Mock(return_value=False)
# Create a mock tensor class that supports basic operations # Create a mock tensor class that supports basic operations
class MockTensor: class MockTensor:
def __init__(self, data): def __init__(self, data):
@ -46,7 +47,7 @@ class MockTensor:
elif isinstance(data, MockTensor): elif isinstance(data, MockTensor):
self.shape = data.shape self.shape = data.shape
else: else:
self.shape = getattr(data, 'shape', [1]) self.shape = getattr(data, "shape", [1])
def __getitem__(self, idx): def __getitem__(self, idx):
if isinstance(self.data, (list, tuple)): if isinstance(self.data, (list, tuple)):
@ -91,9 +92,12 @@ class MockTensor:
def type_as(self, other): def type_as(self, other):
return self return self
# Add tensor operations to mock torch # Add tensor operations to mock torch
mock_torch.tensor = lambda x: MockTensor(x) mock_torch.tensor = lambda x: MockTensor(x)
mock_torch.zeros = lambda *args: MockTensor([0] * (args[0] if isinstance(args[0], int) else args[0][0])) mock_torch.zeros = lambda *args: MockTensor(
[0] * (args[0] if isinstance(args[0], int) else args[0][0])
)
mock_torch.arange = lambda x: MockTensor(list(range(x))) mock_torch.arange = lambda x: MockTensor(list(range(x)))
mock_torch.gt = lambda x, y: MockTensor([False] * x.shape[0]) mock_torch.gt = lambda x, y: MockTensor([False] * x.shape[0])
@ -176,7 +180,9 @@ def mock_tts_service(monkeypatch):
# Mock TTSModel.generate_from_tokens since we call it directly # Mock TTSModel.generate_from_tokens since we call it directly
mock_generate = Mock(return_value=np.zeros(48000)) mock_generate = Mock(return_value=np.zeros(48000))
monkeypatch.setattr("api.src.routers.text_processing.TTSModel.generate_from_tokens", mock_generate) monkeypatch.setattr(
"api.src.routers.text_processing.TTSModel.generate_from_tokens", mock_generate
)
return mock_service return mock_service

View file

@ -1,8 +1,9 @@
"""Tests for AudioService""" """Tests for AudioService"""
from unittest.mock import patch
import numpy as np import numpy as np
import pytest import pytest
from unittest.mock import patch
from api.src.services.audio import AudioService, AudioNormalizer from api.src.services.audio import AudioService, AudioNormalizer
@ -10,10 +11,11 @@ from api.src.services.audio import AudioService, AudioNormalizer
@pytest.fixture(autouse=True)
def mock_settings():
    """Patch the audio service's ``settings`` object for every test.

    ``autouse=True`` makes pytest apply this fixture without tests requesting
    it. While a test runs, ``api.src.services.audio.settings`` is replaced by
    a mock whose ``gap_trim_ms`` is 250; the patch is undone when the ``with``
    block exits after the test.

    Yields:
        The mock object that stands in for ``settings``.
    """
    # Use a distinct local name so the fixture function is not shadowed.
    with patch("api.src.services.audio.settings") as patched_settings:
        patched_settings.gap_trim_ms = 250
        yield patched_settings
@pytest.fixture @pytest.fixture
def sample_audio(): def sample_audio():
"""Generate a simple sine wave for testing""" """Generate a simple sine wave for testing"""

View file

@ -1,14 +1,16 @@
"""Tests for text chunking service""" """Tests for text chunking service"""
import pytest
from unittest.mock import patch from unittest.mock import patch
import pytest
from api.src.services.text_processing import chunker from api.src.services.text_processing import chunker
@pytest.fixture(autouse=True)
def mock_settings():
    """Patch the chunker module's ``settings`` object for every test.

    ``autouse=True`` makes pytest apply this fixture without tests requesting
    it. While a test runs,
    ``api.src.services.text_processing.chunker.settings`` is replaced by a
    mock whose ``max_chunk_size`` is 300; the patch is undone when the
    ``with`` block exits after the test.

    Yields:
        The mock object that stands in for ``settings``.
    """
    # Use a distinct local name so the fixture function is not shadowed.
    with patch(
        "api.src.services.text_processing.chunker.settings"
    ) as patched_settings:
        patched_settings.max_chunk_size = 300
        yield patched_settings

View file

@ -1,16 +1,17 @@
import asyncio
from unittest.mock import Mock, AsyncMock from unittest.mock import Mock, AsyncMock
import pytest import pytest
import pytest_asyncio import pytest_asyncio
import asyncio
from fastapi.testclient import TestClient
from httpx import AsyncClient from httpx import AsyncClient
from fastapi.testclient import TestClient
from ..src.main import app from ..src.main import app
# Create test client # Create test client
client = TestClient(app) client = TestClient(app)
# Create async client fixture # Create async client fixture
@pytest_asyncio.fixture @pytest_asyncio.fixture
async def async_client(): async def async_client():
@ -28,20 +29,23 @@ def mock_tts_service(monkeypatch):
async def mock_stream(*args, **kwargs): async def mock_stream(*args, **kwargs):
for chunk in [b"chunk1", b"chunk2"]: for chunk in [b"chunk1", b"chunk2"]:
yield chunk yield chunk
mock_service.generate_audio_stream = mock_stream mock_service.generate_audio_stream = mock_stream
# Create async mocks # Create async mocks
mock_service.list_voices = AsyncMock(return_value=[ mock_service.list_voices = AsyncMock(
"af", return_value=[
"bm_lewis", "af",
"bf_isabella", "bm_lewis",
"bf_emma", "bf_isabella",
"af_sarah", "bf_emma",
"af_bella", "af_sarah",
"am_adam", "af_bella",
"am_michael", "am_adam",
"bm_george", "am_michael",
]) "bm_george",
]
)
mock_service.combine_voices = AsyncMock() mock_service.combine_voices = AsyncMock()
monkeypatch.setattr( monkeypatch.setattr(
"api.src.routers.openai_compatible.TTSService", "api.src.routers.openai_compatible.TTSService",
@ -54,9 +58,7 @@ def mock_tts_service(monkeypatch):
def mock_audio_service(monkeypatch):
    """Swap the router's ``AudioService`` for a canned-response mock.

    Args:
        monkeypatch: Object exposing ``setattr(target, value)``; the tests
            pass pytest's ``monkeypatch`` fixture.

    Returns:
        The ``Mock`` installed at
        ``api.src.routers.openai_compatible.AudioService``. Its
        ``convert_audio`` always returns ``b"converted mock audio data"``.
    """
    audio_stub = Mock()
    audio_stub.convert_audio.return_value = b"converted mock audio data"
    monkeypatch.setattr("api.src.routers.openai_compatible.AudioService", audio_stub)
    return audio_stub
@ -68,7 +70,9 @@ def test_health_check():
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_openai_speech_endpoint(mock_tts_service, mock_audio_service, async_client): async def test_openai_speech_endpoint(
mock_tts_service, mock_audio_service, async_client
):
"""Test the OpenAI-compatible speech endpoint""" """Test the OpenAI-compatible speech endpoint"""
test_request = { test_request = {
"model": "kokoro", "model": "kokoro",
@ -76,7 +80,7 @@ async def test_openai_speech_endpoint(mock_tts_service, mock_audio_service, asyn
"voice": "bm_lewis", "voice": "bm_lewis",
"response_format": "wav", "response_format": "wav",
"speed": 1.0, "speed": 1.0,
"stream": False # Explicitly disable streaming "stream": False, # Explicitly disable streaming
} }
response = await async_client.post("/v1/audio/speech", json=test_request) response = await async_client.post("/v1/audio/speech", json=test_request)
assert response.status_code == 200 assert response.status_code == 200
@ -97,7 +101,7 @@ async def test_openai_speech_invalid_voice(mock_tts_service, async_client):
"voice": "invalid_voice", "voice": "invalid_voice",
"response_format": "wav", "response_format": "wav",
"speed": 1.0, "speed": 1.0,
"stream": False # Explicitly disable streaming "stream": False, # Explicitly disable streaming
} }
response = await async_client.post("/v1/audio/speech", json=test_request) response = await async_client.post("/v1/audio/speech", json=test_request)
assert response.status_code == 400 # Bad request assert response.status_code == 400 # Bad request
@ -113,7 +117,7 @@ async def test_openai_speech_invalid_speed(mock_tts_service, async_client):
"voice": "af", "voice": "af",
"response_format": "wav", "response_format": "wav",
"speed": -1.0, # Invalid speed "speed": -1.0, # Invalid speed
"stream": False # Explicitly disable streaming "stream": False, # Explicitly disable streaming
} }
response = await async_client.post("/v1/audio/speech", json=test_request) response = await async_client.post("/v1/audio/speech", json=test_request)
assert response.status_code == 422 # Validation error assert response.status_code == 422 # Validation error
@ -129,7 +133,7 @@ async def test_openai_speech_generation_error(mock_tts_service, async_client):
"voice": "af", "voice": "af",
"response_format": "wav", "response_format": "wav",
"speed": 1.0, "speed": 1.0,
"stream": False # Explicitly disable streaming "stream": False, # Explicitly disable streaming
} }
response = await async_client.post("/v1/audio/speech", json=test_request) response = await async_client.post("/v1/audio/speech", json=test_request)
assert response.status_code == 500 assert response.status_code == 500
@ -159,7 +163,9 @@ async def test_combine_voices_string_success(mock_tts_service, async_client):
assert response.status_code == 200 assert response.status_code == 200
assert response.json()["voice"] == "af_bella_af_sarah" assert response.json()["voice"] == "af_bella_af_sarah"
mock_tts_service.combine_voices.assert_called_once_with(voices=["af_bella", "af_sarah"]) mock_tts_service.combine_voices.assert_called_once_with(
voices=["af_bella", "af_sarah"]
)
@pytest.mark.asyncio @pytest.mark.asyncio
@ -184,7 +190,9 @@ async def test_combine_voices_empty_list(mock_tts_service, async_client):
async def test_combine_voices_error(mock_tts_service, async_client): async def test_combine_voices_error(mock_tts_service, async_client):
"""Test error handling in voice combination""" """Test error handling in voice combination"""
test_voices = ["af_bella", "af_sarah"] test_voices = ["af_bella", "af_sarah"]
mock_tts_service.combine_voices = AsyncMock(side_effect=Exception("Combination failed")) mock_tts_service.combine_voices = AsyncMock(
side_effect=Exception("Combination failed")
)
response = await async_client.post("/v1/audio/voices/combine", json=test_voices) response = await async_client.post("/v1/audio/voices/combine", json=test_voices)
assert response.status_code == 500 assert response.status_code == 500
@ -192,7 +200,9 @@ async def test_combine_voices_error(mock_tts_service, async_client):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_speech_with_combined_voice(mock_tts_service, mock_audio_service, async_client): async def test_speech_with_combined_voice(
mock_tts_service, mock_audio_service, async_client
):
"""Test speech generation with combined voice using + syntax""" """Test speech generation with combined voice using + syntax"""
mock_tts_service.combine_voices = AsyncMock(return_value="af_bella_af_sarah") mock_tts_service.combine_voices = AsyncMock(return_value="af_bella_af_sarah")
@ -202,7 +212,7 @@ async def test_speech_with_combined_voice(mock_tts_service, mock_audio_service,
"voice": "af_bella+af_sarah", "voice": "af_bella+af_sarah",
"response_format": "wav", "response_format": "wav",
"speed": 1.0, "speed": 1.0,
"stream": False "stream": False,
} }
response = await async_client.post("/v1/audio/speech", json=test_request) response = await async_client.post("/v1/audio/speech", json=test_request)
@ -213,12 +223,14 @@ async def test_speech_with_combined_voice(mock_tts_service, mock_audio_service,
text="Hello world", text="Hello world",
voice="af_bella_af_sarah", voice="af_bella_af_sarah",
speed=1.0, speed=1.0,
stitch_long_output=True stitch_long_output=True,
) )
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_speech_with_whitespace_in_voice(mock_tts_service, mock_audio_service, async_client): async def test_speech_with_whitespace_in_voice(
mock_tts_service, mock_audio_service, async_client
):
"""Test speech generation with whitespace in voice combination""" """Test speech generation with whitespace in voice combination"""
mock_tts_service.combine_voices = AsyncMock(return_value="af_bella_af_sarah") mock_tts_service.combine_voices = AsyncMock(return_value="af_bella_af_sarah")
@ -228,14 +240,16 @@ async def test_speech_with_whitespace_in_voice(mock_tts_service, mock_audio_serv
"voice": " af_bella + af_sarah ", "voice": " af_bella + af_sarah ",
"response_format": "wav", "response_format": "wav",
"speed": 1.0, "speed": 1.0,
"stream": False "stream": False,
} }
response = await async_client.post("/v1/audio/speech", json=test_request) response = await async_client.post("/v1/audio/speech", json=test_request)
assert response.status_code == 200 assert response.status_code == 200
assert response.headers["content-type"] == "audio/wav" assert response.headers["content-type"] == "audio/wav"
mock_tts_service.combine_voices.assert_called_once_with(voices=["af_bella", "af_sarah"]) mock_tts_service.combine_voices.assert_called_once_with(
voices=["af_bella", "af_sarah"]
)
@pytest.mark.asyncio @pytest.mark.asyncio
@ -247,7 +261,7 @@ async def test_speech_with_empty_voice_combination(mock_tts_service, async_clien
"voice": "+", "voice": "+",
"response_format": "wav", "response_format": "wav",
"speed": 1.0, "speed": 1.0,
"stream": False "stream": False,
} }
response = await async_client.post("/v1/audio/speech", json=test_request) response = await async_client.post("/v1/audio/speech", json=test_request)
@ -264,7 +278,7 @@ async def test_speech_with_invalid_combined_voice(mock_tts_service, async_client
"voice": "invalid+combination", "voice": "invalid+combination",
"response_format": "wav", "response_format": "wav",
"speed": 1.0, "speed": 1.0,
"stream": False "stream": False,
} }
response = await async_client.post("/v1/audio/speech", json=test_request) response = await async_client.post("/v1/audio/speech", json=test_request)
@ -282,18 +296,21 @@ async def test_speech_streaming_with_combined_voice(mock_tts_service, async_clie
"input": "Hello world", "input": "Hello world",
"voice": "af_bella+af_sarah", "voice": "af_bella+af_sarah",
"response_format": "mp3", "response_format": "mp3",
"stream": True "stream": True,
} }
# Create streaming mock # Create streaming mock
async def mock_stream(*args, **kwargs): async def mock_stream(*args, **kwargs):
for chunk in [b"mp3header", b"mp3data"]: for chunk in [b"mp3header", b"mp3data"]:
yield chunk yield chunk
mock_tts_service.generate_audio_stream = mock_stream mock_tts_service.generate_audio_stream = mock_stream
# Add streaming header # Add streaming header
headers = {"x-raw-response": "stream"} headers = {"x-raw-response": "stream"}
response = await async_client.post("/v1/audio/speech", json=test_request, headers=headers) response = await async_client.post(
"/v1/audio/speech", json=test_request, headers=headers
)
assert response.status_code == 200 assert response.status_code == 200
assert response.headers["content-type"] == "audio/mpeg" assert response.headers["content-type"] == "audio/mpeg"
@ -308,18 +325,21 @@ async def test_openai_speech_pcm_streaming(mock_tts_service, async_client):
"input": "Hello world", "input": "Hello world",
"voice": "af", "voice": "af",
"response_format": "pcm", "response_format": "pcm",
"stream": True "stream": True,
} }
# Create streaming mock for this test # Create streaming mock for this test
async def mock_stream(*args, **kwargs): async def mock_stream(*args, **kwargs):
for chunk in [b"chunk1", b"chunk2"]: for chunk in [b"chunk1", b"chunk2"]:
yield chunk yield chunk
mock_tts_service.generate_audio_stream = mock_stream mock_tts_service.generate_audio_stream = mock_stream
# Add streaming header # Add streaming header
headers = {"x-raw-response": "stream"} headers = {"x-raw-response": "stream"}
response = await async_client.post("/v1/audio/speech", json=test_request, headers=headers) response = await async_client.post(
"/v1/audio/speech", json=test_request, headers=headers
)
assert response.status_code == 200 assert response.status_code == 200
assert response.headers["content-type"] == "audio/pcm" assert response.headers["content-type"] == "audio/pcm"
@ -333,18 +353,21 @@ async def test_openai_speech_streaming_mp3(mock_tts_service, async_client):
"input": "Hello world", "input": "Hello world",
"voice": "af", "voice": "af",
"response_format": "mp3", "response_format": "mp3",
"stream": True "stream": True,
} }
# Create streaming mock for this test # Create streaming mock for this test
async def mock_stream(*args, **kwargs): async def mock_stream(*args, **kwargs):
for chunk in [b"mp3header", b"mp3data"]: for chunk in [b"mp3header", b"mp3data"]:
yield chunk yield chunk
mock_tts_service.generate_audio_stream = mock_stream mock_tts_service.generate_audio_stream = mock_stream
# Add streaming header # Add streaming header
headers = {"x-raw-response": "stream"} headers = {"x-raw-response": "stream"}
response = await async_client.post("/v1/audio/speech", json=test_request, headers=headers) response = await async_client.post(
"/v1/audio/speech", json=test_request, headers=headers
)
assert response.status_code == 200 assert response.status_code == 200
assert response.headers["content-type"] == "audio/mpeg" assert response.headers["content-type"] == "audio/mpeg"
@ -359,18 +382,21 @@ async def test_openai_speech_streaming_generator(mock_tts_service, async_client)
"input": "Hello world", "input": "Hello world",
"voice": "af", "voice": "af",
"response_format": "pcm", "response_format": "pcm",
"stream": True "stream": True,
} }
# Create streaming mock for this test # Create streaming mock for this test
async def mock_stream(*args, **kwargs): async def mock_stream(*args, **kwargs):
for chunk in [b"chunk1", b"chunk2"]: for chunk in [b"chunk1", b"chunk2"]:
yield chunk yield chunk
mock_tts_service.generate_audio_stream = mock_stream mock_tts_service.generate_audio_stream = mock_stream
# Add streaming header # Add streaming header
headers = {"x-raw-response": "stream"} headers = {"x-raw-response": "stream"}
response = await async_client.post("/v1/audio/speech", json=test_request, headers=headers) response = await async_client.post(
"/v1/audio/speech", json=test_request, headers=headers
)
assert response.status_code == 200 assert response.status_code == 200
assert response.headers["content-type"] == "audio/pcm" assert response.headers["content-type"] == "audio/pcm"

View file

@ -1,6 +1,6 @@
"""Tests for FastAPI application""" """Tests for FastAPI application"""
from unittest.mock import MagicMock, patch, call from unittest.mock import MagicMock, call, patch
import pytest import pytest
from fastapi.testclient import TestClient from fastapi.testclient import TestClient
@ -32,6 +32,7 @@ async def test_lifespan_successful_warmup(mock_logger, mock_tts_model):
# Create async mock # Create async mock
async def async_setup(): async def async_setup():
return 3 return 3
mock_tts_model.setup = MagicMock() mock_tts_model.setup = MagicMock()
mock_tts_model.setup.side_effect = async_setup mock_tts_model.setup.side_effect = async_setup
mock_tts_model.get_device.return_value = "cuda" mock_tts_model.get_device.return_value = "cuda"
@ -90,6 +91,7 @@ async def test_lifespan_cuda_warmup(mock_tts_model):
# Create async mock # Create async mock
async def async_setup(): async def async_setup():
return 2 return 2
mock_tts_model.setup = MagicMock() mock_tts_model.setup = MagicMock()
mock_tts_model.setup.side_effect = async_setup mock_tts_model.setup.side_effect = async_setup
mock_tts_model.get_device.return_value = "cuda" mock_tts_model.get_device.return_value = "cuda"

View file

@ -1,43 +1,88 @@
"""Tests for text normalization service""" """Tests for text normalization service"""
import pytest import pytest
from api.src.services.text_processing.normalizer import normalize_text from api.src.services.text_processing.normalizer import normalize_text
def test_url_protocols(): def test_url_protocols():
"""Test URL protocol handling""" """Test URL protocol handling"""
assert normalize_text("Check out https://example.com") == "Check out https example dot com" assert (
normalize_text("Check out https://example.com")
== "Check out https example dot com"
)
assert normalize_text("Visit http://site.com") == "Visit http site dot com" assert normalize_text("Visit http://site.com") == "Visit http site dot com"
assert normalize_text("Go to https://test.org/path") == "Go to https test dot org slash path" assert (
normalize_text("Go to https://test.org/path")
== "Go to https test dot org slash path"
)
def test_url_www(): def test_url_www():
"""Test www prefix handling""" """Test www prefix handling"""
assert normalize_text("Go to www.example.com") == "Go to www example dot com" assert normalize_text("Go to www.example.com") == "Go to www example dot com"
assert normalize_text("Visit www.test.org/docs") == "Visit www test dot org slash docs" assert (
assert normalize_text("Check www.site.com?q=test") == "Check www site dot com question-mark q equals test" normalize_text("Visit www.test.org/docs") == "Visit www test dot org slash docs"
)
assert (
normalize_text("Check www.site.com?q=test")
== "Check www site dot com question-mark q equals test"
)
def test_url_localhost(): def test_url_localhost():
"""Test localhost URL handling""" """Test localhost URL handling"""
assert normalize_text("Running on localhost:7860") == "Running on localhost colon 78 60" assert (
assert normalize_text("Server at localhost:8080/api") == "Server at localhost colon 80 80 slash api" normalize_text("Running on localhost:7860")
assert normalize_text("Test localhost:3000/test?v=1") == "Test localhost colon 3000 slash test question-mark v equals 1" == "Running on localhost colon 78 60"
)
assert (
normalize_text("Server at localhost:8080/api")
== "Server at localhost colon 80 80 slash api"
)
assert (
normalize_text("Test localhost:3000/test?v=1")
== "Test localhost colon 3000 slash test question-mark v equals 1"
)
def test_url_ip_addresses(): def test_url_ip_addresses():
"""Test IP address URL handling""" """Test IP address URL handling"""
assert normalize_text("Access 0.0.0.0:9090/test") == "Access 0 dot 0 dot 0 dot 0 colon 90 90 slash test" assert (
assert normalize_text("API at 192.168.1.1:8000") == "API at 192 dot 168 dot 1 dot 1 colon 8000" normalize_text("Access 0.0.0.0:9090/test")
== "Access 0 dot 0 dot 0 dot 0 colon 90 90 slash test"
)
assert (
normalize_text("API at 192.168.1.1:8000")
== "API at 192 dot 168 dot 1 dot 1 colon 8000"
)
assert normalize_text("Server 127.0.0.1") == "Server 127 dot 0 dot 0 dot 1" assert normalize_text("Server 127.0.0.1") == "Server 127 dot 0 dot 0 dot 1"
def test_url_raw_domains(): def test_url_raw_domains():
"""Test raw domain handling""" """Test raw domain handling"""
assert normalize_text("Visit google.com/search") == "Visit google dot com slash search" assert (
assert normalize_text("Go to example.com/path?q=test") == "Go to example dot com slash path question-mark q equals test" normalize_text("Visit google.com/search") == "Visit google dot com slash search"
)
assert (
normalize_text("Go to example.com/path?q=test")
== "Go to example dot com slash path question-mark q equals test"
)
assert normalize_text("Check docs.test.com") == "Check docs dot test dot com" assert normalize_text("Check docs.test.com") == "Check docs dot test dot com"
def test_url_email_addresses(): def test_url_email_addresses():
"""Test email address handling""" """Test email address handling"""
assert normalize_text("Email me at user@example.com") == "Email me at user at example dot com" assert (
normalize_text("Email me at user@example.com")
== "Email me at user at example dot com"
)
assert normalize_text("Contact admin@test.org") == "Contact admin at test dot org" assert normalize_text("Contact admin@test.org") == "Contact admin at test dot org"
assert normalize_text("Send to test.user@site.com") == "Send to test dot user at site dot com" assert (
normalize_text("Send to test.user@site.com")
== "Send to test dot user at site dot com"
)
def test_non_url_text(): def test_non_url_text():
"""Test that non-URL text is unaffected""" """Test that non-URL text is unaffected"""

View file

@ -1,33 +1,36 @@
"""Tests for text processing endpoints""" """Tests for text processing endpoints"""
from unittest.mock import Mock, patch from unittest.mock import Mock, patch
import numpy as np
import pytest import pytest
import pytest_asyncio import pytest_asyncio
from httpx import AsyncClient from httpx import AsyncClient
import numpy as np
from ..src.main import app
from .conftest import MockTTSModel from .conftest import MockTTSModel
from ..src.main import app
@pytest_asyncio.fixture @pytest_asyncio.fixture
async def async_client(): async def async_client():
async with AsyncClient(app=app, base_url="http://test") as ac: async with AsyncClient(app=app, base_url="http://test") as ac:
yield ac yield ac
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_phonemize_endpoint(async_client): async def test_phonemize_endpoint(async_client):
"""Test phoneme generation endpoint""" """Test phoneme generation endpoint"""
with patch('api.src.routers.text_processing.phonemize') as mock_phonemize, \ with patch("api.src.routers.text_processing.phonemize") as mock_phonemize, patch(
patch('api.src.routers.text_processing.tokenize') as mock_tokenize: "api.src.routers.text_processing.tokenize"
) as mock_tokenize:
# Setup mocks # Setup mocks
mock_phonemize.return_value = "həlˈ" mock_phonemize.return_value = "həlˈ"
mock_tokenize.return_value = [1, 2, 3] mock_tokenize.return_value = [1, 2, 3]
# Test request # Test request
response = await async_client.post("/text/phonemize", json={ response = await async_client.post(
"text": "hello", "/text/phonemize", json={"text": "hello", "language": "a"}
"language": "a" )
})
# Verify response # Verify response
assert response.status_code == 200 assert response.status_code == 200
@ -35,46 +38,55 @@ async def test_phonemize_endpoint(async_client):
assert result["phonemes"] == "həlˈ" assert result["phonemes"] == "həlˈ"
assert result["tokens"] == [0, 1, 2, 3, 0] # Should add start/end tokens assert result["tokens"] == [0, 1, 2, 3, 0] # Should add start/end tokens
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_phonemize_empty_text(async_client): async def test_phonemize_empty_text(async_client):
"""Test phoneme generation with empty text""" """Test phoneme generation with empty text"""
response = await async_client.post("/text/phonemize", json={ response = await async_client.post(
"text": "", "/text/phonemize", json={"text": "", "language": "a"}
"language": "a" )
})
assert response.status_code == 500 assert response.status_code == 500
assert "error" in response.json()["detail"] assert "error" in response.json()["detail"]
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_generate_from_phonemes(async_client, mock_tts_service, mock_audio_service): async def test_generate_from_phonemes(
async_client, mock_tts_service, mock_audio_service
):
"""Test audio generation from phonemes""" """Test audio generation from phonemes"""
with patch('api.src.routers.text_processing.TTSService', return_value=mock_tts_service): with patch(
response = await async_client.post("/text/generate_from_phonemes", json={ "api.src.routers.text_processing.TTSService", return_value=mock_tts_service
"phonemes": "həlˈ", ):
"voice": "af_bella", response = await async_client.post(
"speed": 1.0 "/text/generate_from_phonemes",
}) json={"phonemes": "həlˈ", "voice": "af_bella", "speed": 1.0},
)
assert response.status_code == 200 assert response.status_code == 200
assert response.headers["content-type"] == "audio/wav" assert response.headers["content-type"] == "audio/wav"
assert response.headers["content-disposition"] == "attachment; filename=speech.wav" assert (
response.headers["content-disposition"] == "attachment; filename=speech.wav"
)
assert response.content == b"mock audio data" assert response.content == b"mock audio data"
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_generate_from_phonemes_invalid_voice(async_client, mock_tts_service): async def test_generate_from_phonemes_invalid_voice(async_client, mock_tts_service):
"""Test audio generation with invalid voice""" """Test audio generation with invalid voice"""
mock_tts_service._get_voice_path.return_value = None mock_tts_service._get_voice_path.return_value = None
with patch('api.src.routers.text_processing.TTSService', return_value=mock_tts_service): with patch(
response = await async_client.post("/text/generate_from_phonemes", json={ "api.src.routers.text_processing.TTSService", return_value=mock_tts_service
"phonemes": "həlˈ", ):
"voice": "invalid_voice", response = await async_client.post(
"speed": 1.0 "/text/generate_from_phonemes",
}) json={"phonemes": "həlˈ", "voice": "invalid_voice", "speed": 1.0},
)
assert response.status_code == 400 assert response.status_code == 400
assert "Voice not found" in response.json()["detail"]["message"] assert "Voice not found" in response.json()["detail"]["message"]
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_generate_from_phonemes_invalid_speed(async_client, monkeypatch): async def test_generate_from_phonemes_invalid_speed(async_client, monkeypatch):
"""Test audio generation with invalid speed""" """Test audio generation with invalid speed"""
@ -82,25 +94,29 @@ async def test_generate_from_phonemes_invalid_speed(async_client, monkeypatch):
mock_model = Mock() mock_model = Mock()
mock_model.generate_from_tokens = Mock(return_value=np.zeros(48000)) mock_model.generate_from_tokens = Mock(return_value=np.zeros(48000))
monkeypatch.setattr("api.src.services.tts_model.TTSModel._instance", mock_model) monkeypatch.setattr("api.src.services.tts_model.TTSModel._instance", mock_model)
monkeypatch.setattr("api.src.services.tts_model.TTSModel.get_instance", Mock(return_value=mock_model)) monkeypatch.setattr(
"api.src.services.tts_model.TTSModel.get_instance",
Mock(return_value=mock_model),
)
response = await async_client.post("/text/generate_from_phonemes", json={ response = await async_client.post(
"phonemes": "həlˈ", "/text/generate_from_phonemes",
"voice": "af_bella", json={"phonemes": "həlˈ", "voice": "af_bella", "speed": -1.0},
"speed": -1.0 )
})
assert response.status_code == 422 # Validation error assert response.status_code == 422 # Validation error
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_generate_from_phonemes_empty_phonemes(async_client, mock_tts_service): async def test_generate_from_phonemes_empty_phonemes(async_client, mock_tts_service):
"""Test audio generation with empty phonemes""" """Test audio generation with empty phonemes"""
with patch('api.src.routers.text_processing.TTSService', return_value=mock_tts_service): with patch(
response = await async_client.post("/text/generate_from_phonemes", json={ "api.src.routers.text_processing.TTSService", return_value=mock_tts_service
"phonemes": "", ):
"voice": "af_bella", response = await async_client.post(
"speed": 1.0 "/text/generate_from_phonemes",
}) json={"phonemes": "", "voice": "af_bella", "speed": 1.0},
)
assert response.status_code == 400 assert response.status_code == 400
assert "Invalid request" in response.json()["detail"]["error"] assert "Invalid request" in response.json()["detail"]["error"]

View file

@ -1,13 +1,16 @@
"""Tests for TTS model implementations""" """Tests for TTS model implementations"""
import os import os
from unittest.mock import MagicMock, patch
import numpy as np
import torch import torch
import pytest import pytest
import numpy as np
from unittest.mock import patch, MagicMock
from api.src.services.tts_base import TTSBaseModel
from api.src.services.tts_cpu import TTSCPUModel from api.src.services.tts_cpu import TTSCPUModel
from api.src.services.tts_gpu import TTSGPUModel, length_to_mask from api.src.services.tts_gpu import TTSGPUModel, length_to_mask
from api.src.services.tts_base import TTSBaseModel
# Base Model Tests # Base Model Tests
def test_get_device_error(): def test_get_device_error():
@ -16,14 +19,17 @@ def test_get_device_error():
with pytest.raises(RuntimeError, match="Model not initialized"): with pytest.raises(RuntimeError, match="Model not initialized"):
TTSBaseModel.get_device() TTSBaseModel.get_device()
@pytest.mark.asyncio @pytest.mark.asyncio
@patch('torch.cuda.is_available') @patch("torch.cuda.is_available")
@patch('os.path.exists') @patch("os.path.exists")
@patch('os.path.join') @patch("os.path.join")
@patch('os.listdir') @patch("os.listdir")
@patch('torch.load') @patch("torch.load")
@patch('torch.save') @patch("torch.save")
async def test_setup_cuda_available(mock_save, mock_load, mock_listdir, mock_join, mock_exists, mock_cuda_available): async def test_setup_cuda_available(
mock_save, mock_load, mock_listdir, mock_join, mock_exists, mock_cuda_available
):
"""Test setup with CUDA available""" """Test setup with CUDA available"""
TTSBaseModel._device = None TTSBaseModel._device = None
mock_cuda_available.return_value = True mock_cuda_available.return_value = True
@ -35,7 +41,7 @@ async def test_setup_cuda_available(mock_save, mock_load, mock_listdir, mock_joi
# Create mock model # Create mock model
mock_model = MagicMock() mock_model = MagicMock()
mock_model.bert = MagicMock() mock_model.bert = MagicMock()
mock_model.process_text = MagicMock(return_value=("dummy", [1,2,3])) mock_model.process_text = MagicMock(return_value=("dummy", [1, 2, 3]))
mock_model.generate_from_tokens = MagicMock(return_value=np.zeros(1000)) mock_model.generate_from_tokens = MagicMock(return_value=np.zeros(1000))
# Mock initialize to return our mock model # Mock initialize to return our mock model
@ -46,14 +52,17 @@ async def test_setup_cuda_available(mock_save, mock_load, mock_listdir, mock_joi
assert TTSBaseModel._device == "cuda" assert TTSBaseModel._device == "cuda"
assert voice_count == 2 assert voice_count == 2
@pytest.mark.asyncio @pytest.mark.asyncio
@patch('torch.cuda.is_available') @patch("torch.cuda.is_available")
@patch('os.path.exists') @patch("os.path.exists")
@patch('os.path.join') @patch("os.path.join")
@patch('os.listdir') @patch("os.listdir")
@patch('torch.load') @patch("torch.load")
@patch('torch.save') @patch("torch.save")
async def test_setup_cuda_unavailable(mock_save, mock_load, mock_listdir, mock_join, mock_exists, mock_cuda_available): async def test_setup_cuda_unavailable(
mock_save, mock_load, mock_listdir, mock_join, mock_exists, mock_cuda_available
):
"""Test setup with CUDA unavailable""" """Test setup with CUDA unavailable"""
TTSBaseModel._device = None TTSBaseModel._device = None
mock_cuda_available.return_value = False mock_cuda_available.return_value = False
@ -65,7 +74,7 @@ async def test_setup_cuda_unavailable(mock_save, mock_load, mock_listdir, mock_j
# Create mock model # Create mock model
mock_model = MagicMock() mock_model = MagicMock()
mock_model.bert = MagicMock() mock_model.bert = MagicMock()
mock_model.process_text = MagicMock(return_value=("dummy", [1,2,3])) mock_model.process_text = MagicMock(return_value=("dummy", [1, 2, 3]))
mock_model.generate_from_tokens = MagicMock(return_value=np.zeros(1000)) mock_model.generate_from_tokens = MagicMock(return_value=np.zeros(1000))
# Mock initialize to return our mock model # Mock initialize to return our mock model
@ -76,15 +85,18 @@ async def test_setup_cuda_unavailable(mock_save, mock_load, mock_listdir, mock_j
assert TTSBaseModel._device == "cpu" assert TTSBaseModel._device == "cpu"
assert voice_count == 2 assert voice_count == 2
# CPU Model Tests # CPU Model Tests
def test_cpu_initialize_missing_model(): def test_cpu_initialize_missing_model():
"""Test CPU initialize with missing model""" """Test CPU initialize with missing model"""
TTSCPUModel._onnx_session = None # Reset the session TTSCPUModel._onnx_session = None # Reset the session
with patch('os.path.exists', return_value=False), \ with patch("os.path.exists", return_value=False), patch(
patch('onnxruntime.InferenceSession', return_value=None): "onnxruntime.InferenceSession", return_value=None
):
result = TTSCPUModel.initialize("dummy_dir") result = TTSCPUModel.initialize("dummy_dir")
assert result is None assert result is None
def test_cpu_generate_uninitialized(): def test_cpu_generate_uninitialized():
"""Test CPU generate methods with uninitialized model""" """Test CPU generate methods with uninitialized model"""
TTSCPUModel._onnx_session = None TTSCPUModel._onnx_session = None
@ -93,13 +105,14 @@ def test_cpu_generate_uninitialized():
TTSCPUModel.generate_from_text("test", torch.zeros(1), "en", 1.0) TTSCPUModel.generate_from_text("test", torch.zeros(1), "en", 1.0)
with pytest.raises(RuntimeError, match="ONNX model not initialized"): with pytest.raises(RuntimeError, match="ONNX model not initialized"):
TTSCPUModel.generate_from_tokens([1,2,3], torch.zeros(1), 1.0) TTSCPUModel.generate_from_tokens([1, 2, 3], torch.zeros(1), 1.0)
def test_cpu_process_text(): def test_cpu_process_text():
"""Test CPU process_text functionality""" """Test CPU process_text functionality"""
with patch('api.src.services.tts_cpu.phonemize') as mock_phonemize, \ with patch("api.src.services.tts_cpu.phonemize") as mock_phonemize, patch(
patch('api.src.services.tts_cpu.tokenize') as mock_tokenize: "api.src.services.tts_cpu.tokenize"
) as mock_tokenize:
mock_phonemize.return_value = "test phonemes" mock_phonemize.return_value = "test phonemes"
mock_tokenize.return_value = [1, 2, 3] mock_tokenize.return_value = [1, 2, 3]
@ -107,8 +120,9 @@ def test_cpu_process_text():
assert phonemes == "test phonemes" assert phonemes == "test phonemes"
assert tokens == [0, 1, 2, 3, 0] # Should add start/end tokens assert tokens == [0, 1, 2, 3, 0] # Should add start/end tokens
# GPU Model Tests # GPU Model Tests
@patch('torch.cuda.is_available') @patch("torch.cuda.is_available")
def test_gpu_initialize_cuda_unavailable(mock_cuda_available): def test_gpu_initialize_cuda_unavailable(mock_cuda_available):
"""Test GPU initialize with CUDA unavailable""" """Test GPU initialize with CUDA unavailable"""
mock_cuda_available.return_value = False mock_cuda_available.return_value = False
@ -117,14 +131,14 @@ def test_gpu_initialize_cuda_unavailable(mock_cuda_available):
result = TTSGPUModel.initialize("dummy_dir", "dummy_path") result = TTSGPUModel.initialize("dummy_dir", "dummy_path")
assert result is None assert result is None
@patch('api.src.services.tts_gpu.length_to_mask')
@patch("api.src.services.tts_gpu.length_to_mask")
def test_gpu_length_to_mask(mock_length_to_mask): def test_gpu_length_to_mask(mock_length_to_mask):
"""Test length_to_mask function""" """Test length_to_mask function"""
# Setup mock return value # Setup mock return value
expected_mask = torch.tensor([ expected_mask = torch.tensor(
[False, False, False, True, True], [[False, False, False, True, True], [False, False, False, False, False]]
[False, False, False, False, False] )
])
mock_length_to_mask.return_value = expected_mask mock_length_to_mask.return_value = expected_mask
# Call function with test input # Call function with test input
@ -135,6 +149,7 @@ def test_gpu_length_to_mask(mock_length_to_mask):
mock_length_to_mask.assert_called_once() mock_length_to_mask.assert_called_once()
assert torch.equal(mask, expected_mask) assert torch.equal(mask, expected_mask)
def test_gpu_generate_uninitialized(): def test_gpu_generate_uninitialized():
"""Test GPU generate methods with uninitialized model""" """Test GPU generate methods with uninitialized model"""
TTSGPUModel._instance = None TTSGPUModel._instance = None
@ -143,13 +158,14 @@ def test_gpu_generate_uninitialized():
TTSGPUModel.generate_from_text("test", torch.zeros(1), "en", 1.0) TTSGPUModel.generate_from_text("test", torch.zeros(1), "en", 1.0)
with pytest.raises(RuntimeError, match="GPU model not initialized"): with pytest.raises(RuntimeError, match="GPU model not initialized"):
TTSGPUModel.generate_from_tokens([1,2,3], torch.zeros(1), 1.0) TTSGPUModel.generate_from_tokens([1, 2, 3], torch.zeros(1), 1.0)
def test_gpu_process_text(): def test_gpu_process_text():
"""Test GPU process_text functionality""" """Test GPU process_text functionality"""
with patch('api.src.services.tts_gpu.phonemize') as mock_phonemize, \ with patch("api.src.services.tts_gpu.phonemize") as mock_phonemize, patch(
patch('api.src.services.tts_gpu.tokenize') as mock_tokenize: "api.src.services.tts_gpu.tokenize"
) as mock_tokenize:
mock_phonemize.return_value = "test phonemes" mock_phonemize.return_value = "test phonemes"
mock_tokenize.return_value = [1, 2, 3] mock_tokenize.return_value = [1, 2, 3]

View file

@ -9,10 +9,10 @@ import pytest
from onnxruntime import InferenceSession from onnxruntime import InferenceSession
from api.src.core.config import settings from api.src.core.config import settings
from api.src.services.tts_model import TTSModel
from api.src.services.tts_service import TTSService
from api.src.services.tts_cpu import TTSCPUModel from api.src.services.tts_cpu import TTSCPUModel
from api.src.services.tts_gpu import TTSGPUModel from api.src.services.tts_gpu import TTSGPUModel
from api.src.services.tts_model import TTSModel
from api.src.services.tts_service import TTSService
@pytest.fixture @pytest.fixture
@ -25,8 +25,13 @@ def tts_service(monkeypatch):
# Set up model instance # Set up model instance
monkeypatch.setattr("api.src.services.tts_model.TTSModel._instance", mock_model) monkeypatch.setattr("api.src.services.tts_model.TTSModel._instance", mock_model)
monkeypatch.setattr("api.src.services.tts_model.TTSModel.get_instance", MagicMock(return_value=mock_model)) monkeypatch.setattr(
monkeypatch.setattr("api.src.services.tts_model.TTSModel.get_device", MagicMock(return_value="cpu")) "api.src.services.tts_model.TTSModel.get_instance",
MagicMock(return_value=mock_model),
)
monkeypatch.setattr(
"api.src.services.tts_model.TTSModel.get_device", MagicMock(return_value="cpu")
)
return TTSService() return TTSService()
@ -51,6 +56,7 @@ def test_audio_to_bytes(tts_service, sample_audio):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_list_voices(tts_service): async def test_list_voices(tts_service):
"""Test listing available voices""" """Test listing available voices"""
# Override list_voices for testing # Override list_voices for testing
# # TODO: # # TODO:
# Whatever aiofiles does here pathing aiofiles vs aiofiles.os # Whatever aiofiles does here pathing aiofiles vs aiofiles.os
@ -58,6 +64,7 @@ async def test_list_voices(tts_service):
# Cheating the test as it seems to work in the real world (for now) # Cheating the test as it seems to work in the real world (for now)
async def mock_list_voices(): async def mock_list_voices():
return ["voice1", "voice2"] return ["voice1", "voice2"]
tts_service.list_voices = mock_list_voices tts_service.list_voices = mock_list_voices
voices = await tts_service.list_voices() voices = await tts_service.list_voices()
@ -69,10 +76,12 @@ async def test_list_voices(tts_service):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_list_voices_error(tts_service): async def test_list_voices_error(tts_service):
"""Test error handling in list_voices""" """Test error handling in list_voices"""
# Override list_voices for testing # Override list_voices for testing
# TODO: See above. # TODO: See above.
async def mock_list_voices(): async def mock_list_voices():
return [] return []
tts_service.list_voices = mock_list_voices tts_service.list_voices = mock_list_voices
voices = await tts_service.list_voices() voices = await tts_service.list_voices()
@ -124,10 +133,11 @@ def test_generate_audio_empty_text(tts_service):
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
def mock_settings(): def mock_settings():
"""Mock settings for all tests""" """Mock settings for all tests"""
with patch('api.src.services.text_processing.chunker.settings') as mock_settings: with patch("api.src.services.text_processing.chunker.settings") as mock_settings:
mock_settings.max_chunk_size = 300 mock_settings.max_chunk_size = 300
yield mock_settings yield mock_settings
@patch("api.src.services.tts_model.TTSModel.get_instance") @patch("api.src.services.tts_model.TTSModel.get_instance")
@patch("api.src.services.tts_model.TTSModel.get_device") @patch("api.src.services.tts_model.TTSModel.get_device")
@patch("os.path.exists") @patch("os.path.exists")
@ -150,7 +160,10 @@ def test_generate_audio_phonemize_error(
"""Test handling phonemization error""" """Test handling phonemization error"""
mock_normalize.return_value = "Test text" mock_normalize.return_value = "Test text"
mock_phonemize.side_effect = Exception("Phonemization failed") mock_phonemize.side_effect = Exception("Phonemization failed")
mock_instance.return_value = (mock_generate, "cpu") # Use the same mock for consistency mock_instance.return_value = (
mock_generate,
"cpu",
) # Use the same mock for consistency
mock_get_device.return_value = "cpu" mock_get_device.return_value = "cpu"
mock_exists.return_value = True mock_exists.return_value = True
mock_torch_load.return_value = torch.zeros((10, 24000)) mock_torch_load.return_value = torch.zeros((10, 24000))
@ -184,7 +197,10 @@ def test_generate_audio_error(
mock_phonemize.return_value = "Test text" mock_phonemize.return_value = "Test text"
mock_tokenize.return_value = [1, 2] # Return integers instead of strings mock_tokenize.return_value = [1, 2] # Return integers instead of strings
mock_generate.side_effect = Exception("Generation failed") mock_generate.side_effect = Exception("Generation failed")
mock_instance.return_value = (mock_generate, "cpu") # Use the same mock for consistency mock_instance.return_value = (
mock_generate,
"cpu",
) # Use the same mock for consistency
mock_get_device.return_value = "cpu" mock_get_device.return_value = "cpu"
mock_exists.return_value = True mock_exists.return_value = True
mock_torch_load.return_value = torch.zeros((10, 24000)) mock_torch_load.return_value = torch.zeros((10, 24000))
@ -205,12 +221,11 @@ def test_save_audio(tts_service, sample_audio, tmp_path):
async def test_combine_voices(tts_service): async def test_combine_voices(tts_service):
"""Test combining multiple voices""" """Test combining multiple voices"""
# Setup mocks for torch operations # Setup mocks for torch operations
with patch('torch.load', return_value=torch.tensor([1.0, 2.0])), \ with patch("torch.load", return_value=torch.tensor([1.0, 2.0])), patch(
patch('torch.stack', return_value=torch.tensor([[1.0, 2.0], [3.0, 4.0]])), \ "torch.stack", return_value=torch.tensor([[1.0, 2.0], [3.0, 4.0]])
patch('torch.mean', return_value=torch.tensor([2.0, 3.0])), \ ), patch("torch.mean", return_value=torch.tensor([2.0, 3.0])), patch(
patch('torch.save'), \ "torch.save"
patch('os.path.exists', return_value=True): ), patch("os.path.exists", return_value=True):
# Test combining two voices # Test combining two voices
result = await tts_service.combine_voices(["voice1", "voice2"]) result = await tts_service.combine_voices(["voice1", "voice2"])

View file

@ -166,7 +166,7 @@ def measure_first_token_openai(
def main(): def main():
script_dir = os.path.dirname(os.path.abspath(__file__)) script_dir = os.path.dirname(os.path.abspath(__file__))
prefix='cpu' prefix = "cpu"
# Run requests benchmark # Run requests benchmark
print("\n=== Running Direct Requests Benchmark ===") print("\n=== Running Direct Requests Benchmark ===")
run_benchmark( run_benchmark(
@ -176,7 +176,7 @@ def main():
output_plots_dir=os.path.join(script_dir, "output_plots"), output_plots_dir=os.path.join(script_dir, "output_plots"),
suffix="_stream", suffix="_stream",
plot_title_suffix="(Streaming)", plot_title_suffix="(Streaming)",
prefix=prefix prefix=prefix,
) )
# Run OpenAI benchmark # Run OpenAI benchmark
print("\n=== Running OpenAI Library Benchmark ===") print("\n=== Running OpenAI Library Benchmark ===")
@ -187,7 +187,7 @@ def main():
output_plots_dir=os.path.join(script_dir, "output_plots"), output_plots_dir=os.path.join(script_dir, "output_plots"),
suffix="_stream_openai", suffix="_stream_openai",
plot_title_suffix="(OpenAI Streaming)", plot_title_suffix="(OpenAI Streaming)",
prefix=prefix prefix=prefix,
) )

View file

@ -149,19 +149,19 @@ def run_benchmark(
result["run_number"] = i + 1 result["run_number"] = i + 1
# Handle time to first audio # Handle time to first audio
first_chunk = result.get('time_to_first_chunk') first_chunk = result.get("time_to_first_chunk")
print( print(
f"Time to First Audio: {f'{first_chunk:.3f}s' if first_chunk is not None else 'N/A'}" f"Time to First Audio: {f'{first_chunk:.3f}s' if first_chunk is not None else 'N/A'}"
) )
# Handle total time # Handle total time
total_time = result.get('total_time') total_time = result.get("total_time")
print( print(
f"Time to Save Complete: {f'{total_time:.3f}s' if total_time is not None else 'N/A'}" f"Time to Save Complete: {f'{total_time:.3f}s' if total_time is not None else 'N/A'}"
) )
# Handle audio length # Handle audio length
audio_length = result.get('audio_length') audio_length = result.get("audio_length")
print( print(
f"Audio length: {f'{audio_length:.3f}s' if audio_length is not None else 'N/A'}" f"Audio length: {f'{audio_length:.3f}s' if audio_length is not None else 'N/A'}"
) )
@ -191,10 +191,18 @@ def run_benchmark(
# Print paths # Print paths
print("\nResults and plots saved to:") print("\nResults and plots saved to:")
print(f"- {os.path.join(output_data_dir, f'{prefix}first_token_benchmark{suffix}.json')}") print(
print(f"- {os.path.join(output_plots_dir, f'{prefix}first_token_latency{suffix}.png')}") f"- {os.path.join(output_data_dir, f'{prefix}first_token_benchmark{suffix}.json')}"
print(f"- {os.path.join(output_plots_dir, f'{prefix}total_time_latency{suffix}.png')}") )
print(f"- {os.path.join(output_plots_dir, f'{prefix}first_token_timeline{suffix}.png')}") print(
f"- {os.path.join(output_plots_dir, f'{prefix}first_token_latency{suffix}.png')}"
)
print(
f"- {os.path.join(output_plots_dir, f'{prefix}total_time_latency{suffix}.png')}"
)
print(
f"- {os.path.join(output_plots_dir, f'{prefix}first_token_timeline{suffix}.png')}"
)
# Print silence check summary # Print silence check summary
if silent_files: if silent_files:

View file

@ -1,42 +1,39 @@
#!/usr/bin/env rye run python #!/usr/bin/env rye run python
# %%
import time import time
from pathlib import Path from pathlib import Path
from openai import OpenAI from openai import OpenAI
# gets OPENAI_API_KEY from your environment variables # gets OPENAI_API_KEY from your environment variables
openai = OpenAI(base_url="http://localhost:8880/v1", api_key="not-needed-for-local") openai = OpenAI(base_url="http://localhost:8000/v1", api_key="not-needed-for-local")
speech_file_path = Path(__file__).parent / "speech.mp3" speech_file_path = Path(__file__).parent / "speech.mp3"
def main() -> None: def main() -> None:
stream_to_speakers() stream_to_speakers()
# Create text-to-speech audio file
with openai.audio.speech.with_streaming_response.create(
model="kokoro",
voice="af",
input="the quick brown fox jumped over the lazy dogs",
) as response:
response.stream_to_file(speech_file_path)
def stream_to_speakers() -> None: def stream_to_speakers() -> None:
import pyaudio import pyaudio
player_stream = pyaudio.PyAudio().open(format=pyaudio.paInt16, channels=1, rate=24000, output=True) player_stream = pyaudio.PyAudio().open(
format=pyaudio.paInt16, channels=1, rate=24000, output=True
)
start_time = time.time() start_time = time.time()
with openai.audio.speech.with_streaming_response.create( with openai.audio.speech.with_streaming_response.create(
model="kokoro", model="kokoro",
voice="af_0p0_n2p0", voice=VOICE,
response_format="pcm", # similar to WAV, but without a header chunk at the start. response_format="mp3", # similar to WAV, but without a header chunk at the start.
input="""My dear sir, that is just where you are wrong. That is just where the whole world has gone wrong. We are always getting away from the present moment. Our mental existences, which are immaterial and have no dimensions, are passing along the Time-Dimension with a uniform velocity from the cradle to the grave. Just as we should travel down if we began our existence fifty miles above the earths surface""", input="""My dear sir, that is just where you are wrong. That is just where the whole world has gone wrong. We are always getting away from the present moment. Our mental existences, which are immaterial and have no dimensions, are passing along the Time-Dimension""",
) as response: ) as response:
print(f"Time to first byte: {int((time.time() - start_time) * 1000)}ms") print(f"Time to first byte: {int((time.time() - start_time) * 1000)}ms")
for chunk in response.iter_bytes(chunk_size=1024): for chunk in response.iter_bytes(chunk_size=1024):
@ -47,3 +44,5 @@ def stream_to_speakers() -> None:
if __name__ == "__main__": if __name__ == "__main__":
main() main()
# %%

View file

@ -1,11 +1,13 @@
import requests
import json import json
from pathlib import Path
from typing import Tuple, Optional from typing import Tuple, Optional
from pathlib import Path
import requests
# Get the directory this script is in # Get the directory this script is in
SCRIPT_DIR = Path(__file__).parent.absolute() SCRIPT_DIR = Path(__file__).parent.absolute()
def get_phonemes(text: str, language: str = "a") -> Tuple[str, list[int]]: def get_phonemes(text: str, language: str = "a") -> Tuple[str, list[int]]:
"""Get phonemes and tokens for input text. """Get phonemes and tokens for input text.
@ -17,16 +19,10 @@ def get_phonemes(text: str, language: str = "a") -> Tuple[str, list[int]]:
Tuple of (phonemes string, token list) Tuple of (phonemes string, token list)
""" """
# Create the request payload # Create the request payload
payload = { payload = {"text": text, "language": language}
"text": text,
"language": language
}
# Make POST request to the phonemize endpoint # Make POST request to the phonemize endpoint
response = requests.post( response = requests.post("http://localhost:8880/text/phonemize", json=payload)
"http://localhost:8880/text/phonemize",
json=payload
)
# Raise exception for error status codes # Raise exception for error status codes
response.raise_for_status() response.raise_for_status()
@ -35,7 +31,10 @@ def get_phonemes(text: str, language: str = "a") -> Tuple[str, list[int]]:
result = response.json() result = response.json()
return result["phonemes"], result["tokens"] return result["phonemes"], result["tokens"]
def generate_audio_from_phonemes(phonemes: str, voice: str = "af_bella", speed: float = 1.0) -> Optional[bytes]:
def generate_audio_from_phonemes(
phonemes: str, voice: str = "af_bella", speed: float = 1.0
) -> Optional[bytes]:
"""Generate audio from phonemes. """Generate audio from phonemes.
Args: Args:
@ -47,16 +46,11 @@ def generate_audio_from_phonemes(phonemes: str, voice: str = "af_bella", speed:
WAV audio bytes if successful, None if failed WAV audio bytes if successful, None if failed
""" """
# Create the request payload # Create the request payload
payload = { payload = {"phonemes": phonemes, "voice": voice, "speed": speed}
"phonemes": phonemes,
"voice": voice,
"speed": speed
}
# Make POST request to generate audio # Make POST request to generate audio
response = requests.post( response = requests.post(
"http://localhost:8880/text/generate_from_phonemes", "http://localhost:8880/text/generate_from_phonemes", json=payload
json=payload
) )
# Raise exception for error status codes # Raise exception for error status codes
@ -64,6 +58,7 @@ def generate_audio_from_phonemes(phonemes: str, voice: str = "af_bella", speed:
return response.content return response.content
def main(): def main():
# Example texts to convert # Example texts to convert
examples = [ examples = [
@ -71,7 +66,7 @@ def main():
"How are you today? I am doing reasonably well, thank you for asking", "How are you today? I am doing reasonably well, thank you for asking",
"""This is a test of the phoneme generation system. Do not be alarmed. """This is a test of the phoneme generation system. Do not be alarmed.
This is only a test. If this were a real phoneme emergency, ' This is only a test. If this were a real phoneme emergency, '
you would be instructed to a phoneme shelter in your area.""" you would be instructed to a phoneme shelter in your area.""",
] ]
print("Generating phonemes and audio for example texts...\n") print("Generating phonemes and audio for example texts...\n")
@ -104,5 +99,6 @@ def main():
except requests.RequestException as e: except requests.RequestException as e:
print(f"Error: {e}\n") print(f"Error: {e}\n")
if __name__ == "__main__": if __name__ == "__main__":
main() main()

View file

@ -1,11 +1,13 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
import requests
import numpy as np
import sounddevice as sd
import time
import os import os
import time
import wave import wave
import numpy as np
import requests
import sounddevice as sd
def play_streaming_tts(text: str, output_file: str = None, voice: str = "af"): def play_streaming_tts(text: str, output_file: str = None, voice: str = "af"):
"""Stream TTS audio and play it back in real-time""" """Stream TTS audio and play it back in real-time"""
@ -26,7 +28,7 @@ def play_streaming_tts(text: str, output_file: str = None, voice: str = "af"):
channels=1, channels=1,
dtype=np.int16, dtype=np.int16,
blocksize=1024, # Buffer size in samples blocksize=1024, # Buffer size in samples
latency='low' # Request low latency latency="low", # Request low latency
) )
stream.start() stream.start()
@ -39,16 +41,18 @@ def play_streaming_tts(text: str, output_file: str = None, voice: str = "af"):
"input": text, "input": text,
"voice": voice, "voice": voice,
"response_format": "pcm", "response_format": "pcm",
"stream": True "stream": True,
}, },
stream=True, stream=True,
timeout=1800 timeout=1800,
) )
response.raise_for_status() response.raise_for_status()
print(f"Request started successfully after {time.time() - start_time:.2f}s") print(f"Request started successfully after {time.time() - start_time:.2f}s")
# Process streaming response with smaller chunks for lower latency # Process streaming response with smaller chunks for lower latency
for chunk in response.iter_content(chunk_size=512): # 512 bytes = 256 samples at 16-bit for chunk in response.iter_content(
chunk_size=512
): # 512 bytes = 256 samples at 16-bit
if chunk: if chunk:
chunk_count += 1 chunk_count += 1
total_bytes += len(chunk) total_bytes += len(chunk)
@ -56,7 +60,9 @@ def play_streaming_tts(text: str, output_file: str = None, voice: str = "af"):
# Handle first chunk # Handle first chunk
if not audio_started: if not audio_started:
first_chunk_time = time.time() first_chunk_time = time.time()
print(f"\nReceived first chunk after {first_chunk_time - start_time:.2f}s") print(
f"\nReceived first chunk after {first_chunk_time - start_time:.2f}s"
)
print(f"First chunk size: {len(chunk)} bytes") print(f"First chunk size: {len(chunk)} bytes")
audio_started = True audio_started = True
@ -70,7 +76,9 @@ def play_streaming_tts(text: str, output_file: str = None, voice: str = "af"):
# Log progress every 10 chunks # Log progress every 10 chunks
if chunk_count % 10 == 0: if chunk_count % 10 == 0:
elapsed = time.time() - start_time elapsed = time.time() - start_time
print(f"Progress: {chunk_count} chunks, {total_bytes/1024:.1f}KB received, {elapsed:.1f}s elapsed") print(
f"Progress: {chunk_count} chunks, {total_bytes/1024:.1f}KB received, {elapsed:.1f}s elapsed"
)
# Final stats # Final stats
total_time = time.time() - start_time total_time = time.time() - start_time
@ -83,7 +91,7 @@ def play_streaming_tts(text: str, output_file: str = None, voice: str = "af"):
# Save as WAV file # Save as WAV file
if output_file: if output_file:
print(f"\nWriting audio to {output_file}") print(f"\nWriting audio to {output_file}")
with wave.open(output_file, 'wb') as wav_file: with wave.open(output_file, "wb") as wav_file:
wav_file.setnchannels(1) # Mono wav_file.setnchannels(1) # Mono
wav_file.setsampwidth(2) # 2 bytes per sample (16-bit) wav_file.setsampwidth(2) # 2 bytes per sample (16-bit)
wav_file.setframerate(sample_rate) wav_file.setframerate(sample_rate)
@ -103,10 +111,13 @@ def play_streaming_tts(text: str, output_file: str = None, voice: str = "af"):
stream.stop() stream.stop()
stream.close() stream.close()
def main(): def main():
# Load sample text from HG Wells # Load sample text from HG Wells
script_dir = os.path.dirname(os.path.abspath(__file__)) script_dir = os.path.dirname(os.path.abspath(__file__))
wells_path = os.path.join(script_dir, "assorted_checks/benchmarks/the_time_machine_hg_wells.txt") wells_path = os.path.join(
script_dir, "assorted_checks/benchmarks/the_time_machine_hg_wells.txt"
)
output_path = os.path.join(script_dir, "output.wav") output_path = os.path.join(script_dir, "output.wav")
with open(wells_path, "r", encoding="utf-8") as f: with open(wells_path, "r", encoding="utf-8") as f:
@ -121,5 +132,6 @@ def main():
play_streaming_tts(text, output_file=output_path) play_streaming_tts(text, output_file=output_path)
if __name__ == "__main__": if __name__ == "__main__":
main() main()