Inital test commit of segfault fixes

This commit is contained in:
Fireblade2534 2025-03-20 16:20:28 +00:00
parent 0d7570ab50
commit 8f23bf53a4
6 changed files with 442 additions and 198 deletions

283
Test Threads.py Normal file
View file

@ -0,0 +1,283 @@
#!/usr/bin/env python3
# Compatible with both Windows and Linux
"""
Kokoro TTS Race Condition Test
This script creates multiple concurrent requests to a Kokoro TTS service
to reproduce a race condition where audio outputs don't match the requested text.
Each thread generates a simple numbered sentence, which should make mismatches
easy to identify through listening.
To run:
python kokoro_race_condition_test.py --threads 8 --iterations 5 --url http://localhost:8880
"""
import argparse
import base64
import concurrent.futures
import json
import os
import requests
import time
import wave
import sys
from pathlib import Path
def setup_args():
"""Parse command line arguments"""
parser = argparse.ArgumentParser(description="Test Kokoro TTS for race conditions")
parser.add_argument("--url", default="http://localhost:8880",
help="Base URL of the Kokoro TTS service")
parser.add_argument("--threads", type=int, default=8,
help="Number of concurrent threads to use")
parser.add_argument("--iterations", type=int, default=5,
help="Number of iterations per thread")
parser.add_argument("--voice", default="af_heart",
help="Voice to use for TTS")
parser.add_argument("--output-dir", default="./tts_test_output",
help="Directory to save output files")
parser.add_argument("--debug", action="store_true",
help="Enable debug logging")
return parser.parse_args()
def generate_test_sentence(thread_id, iteration):
"""Generate a simple test sentence with numbers to make mismatches easily identifiable"""
return f"This is test sentence number {thread_id}-{iteration}. " \
f"If you hear this sentence, you should hear the numbers {thread_id}-{iteration}."
def log_message(message, debug=False, is_error=False):
"""Log messages with timestamps"""
timestamp = time.strftime("%Y-%m-%d %H:%M:%S")
prefix = "[ERROR]" if is_error else "[INFO]"
if is_error or debug:
print(f"{prefix} {timestamp} - {message}")
sys.stdout.flush() # Ensure logs are visible in Docker output
def request_tts(url, test_id, text, voice, output_dir, debug=False):
"""Request TTS from the Kokoro API and save the WAV output"""
start_time = time.time()
output_file = os.path.join(output_dir, f"test_{test_id}.wav")
text_file = os.path.join(output_dir, f"test_{test_id}.txt")
# Log output paths for debugging
log_message(f"Thread {test_id}: Text will be saved to: {text_file}", debug)
log_message(f"Thread {test_id}: Audio will be saved to: {output_file}", debug)
# Save the text for later comparison
try:
with open(text_file, "w") as f:
f.write(text)
log_message(f"Thread {test_id}: Successfully saved text file", debug)
except Exception as e:
log_message(f"Thread {test_id}: Error saving text file: {str(e)}", debug, is_error=True)
# Make the TTS request
try:
log_message(f"Thread {test_id}: Requesting TTS for: '{text}'", debug)
response = requests.post(
f"{url}/v1/audio/speech",
json={
"model": "kokoro",
"input": text,
"voice": voice,
"response_format": "wav"
},
headers={"Accept": "audio/wav"},
timeout=60 # Increase timeout to 60 seconds
)
log_message(f"Thread {test_id}: Response status code: {response.status_code}", debug)
log_message(f"Thread {test_id}: Response content type: {response.headers.get('Content-Type', 'None')}", debug)
log_message(f"Thread {test_id}: Response content length: {len(response.content)} bytes", debug)
if response.status_code != 200:
log_message(f"Thread {test_id}: API error: {response.status_code} - {response.text}", debug, is_error=True)
return False
# Check if we got valid audio data
if len(response.content) < 100: # Sanity check - WAV files should be larger than this
log_message(f"Thread {test_id}: Received suspiciously small audio data: {len(response.content)} bytes", debug, is_error=True)
log_message(f"Thread {test_id}: Content (base64): {base64.b64encode(response.content).decode('utf-8')}", debug, is_error=True)
return False
# Save the audio output with explicit error handling
try:
with open(output_file, "wb") as f:
bytes_written = f.write(response.content)
log_message(f"Thread {test_id}: Wrote {bytes_written} bytes to {output_file}", debug)
# Verify the WAV file exists and has content
if os.path.exists(output_file):
file_size = os.path.getsize(output_file)
log_message(f"Thread {test_id}: Verified file exists with size: {file_size} bytes", debug)
# Validate WAV file by reading its headers
try:
with wave.open(output_file, 'rb') as wav_file:
channels = wav_file.getnchannels()
sample_width = wav_file.getsampwidth()
framerate = wav_file.getframerate()
frames = wav_file.getnframes()
log_message(f"Thread {test_id}: Valid WAV file - channels: {channels}, "
f"sample width: {sample_width}, framerate: {framerate}, frames: {frames}", debug)
except Exception as wav_error:
log_message(f"Thread {test_id}: Invalid WAV file: {str(wav_error)}", debug, is_error=True)
else:
log_message(f"Thread {test_id}: File was not created: {output_file}", debug, is_error=True)
except Exception as save_error:
log_message(f"Thread {test_id}: Error saving audio file: {str(save_error)}", debug, is_error=True)
return False
end_time = time.time()
log_message(f"Thread {test_id}: Saved output to {output_file} (time: {end_time - start_time:.2f}s)", debug)
return True
except requests.exceptions.Timeout:
log_message(f"Thread {test_id}: Request timed out", debug, is_error=True)
return False
except Exception as e:
log_message(f"Thread {test_id}: Exception: {str(e)}", debug, is_error=True)
return False
def worker_task(thread_id, args):
"""Worker task for each thread"""
for i in range(args.iterations):
iteration = i + 1
test_id = f"{thread_id:02d}_{iteration:02d}"
text = generate_test_sentence(thread_id, iteration)
success = request_tts(args.url, test_id, text, args.voice, args.output_dir, args.debug)
if not success:
log_message(f"Thread {thread_id}: Iteration {iteration} failed", args.debug, is_error=True)
# Small delay between iterations to avoid overwhelming the API
time.sleep(0.1)
def run_test(args):
"""Run the test with the specified parameters"""
# Ensure output directory exists and check permissions
os.makedirs(args.output_dir, exist_ok=True)
# Test write access to the output directory
test_file = os.path.join(args.output_dir, "write_test.txt")
try:
with open(test_file, "w") as f:
f.write("Testing write access\n")
os.remove(test_file)
log_message(f"Successfully verified write access to output directory: {args.output_dir}")
except Exception as e:
log_message(f"Warning: Cannot write to output directory {args.output_dir}: {str(e)}", is_error=True)
log_message(f"Current directory: {os.getcwd()}", is_error=True)
log_message(f"Directory contents: {os.listdir('.')}", is_error=True)
# Test connection to Kokoro TTS service
try:
response = requests.get(f"{args.url}/health", timeout=5)
if response.status_code == 200:
log_message(f"Successfully connected to Kokoro TTS service at {args.url}")
else:
log_message(f"Warning: Kokoro TTS service health check returned status {response.status_code}", is_error=True)
except Exception as e:
log_message(f"Warning: Cannot connect to Kokoro TTS service at {args.url}: {str(e)}", is_error=True)
# Record start time
start_time = time.time()
log_message(f"Starting test with {args.threads} threads, {args.iterations} iterations per thread")
# Create and start worker threads
with concurrent.futures.ThreadPoolExecutor(max_workers=args.threads) as executor:
futures = []
for thread_id in range(1, args.threads + 1):
futures.append(executor.submit(worker_task, thread_id, args))
# Wait for all tasks to complete
for future in concurrent.futures.as_completed(futures):
try:
future.result()
except Exception as e:
log_message(f"Thread execution failed: {str(e)}", args.debug, is_error=True)
# Record end time and print summary
end_time = time.time()
total_time = end_time - start_time
total_requests = args.threads * args.iterations
log_message(f"Test completed in {total_time:.2f} seconds")
log_message(f"Total requests: {total_requests}")
log_message(f"Average time per request: {total_time / total_requests:.2f} seconds")
log_message(f"Requests per second: {total_requests / total_time:.2f}")
log_message(f"Output files saved to: {os.path.abspath(args.output_dir)}")
log_message("To verify, listen to the audio files and check if they match the text files")
log_message("If you hear audio describing a different test number than the filename, you've found a race condition")
def analyze_audio_files(output_dir):
"""Provide summary of the generated audio files"""
# Look for both WAV and TXT files
wav_files = list(Path(output_dir).glob("*.wav"))
txt_files = list(Path(output_dir).glob("*.txt"))
log_message(f"Found {len(wav_files)} WAV files and {len(txt_files)} TXT files")
if len(wav_files) == 0:
log_message("No WAV files found! This indicates the TTS service requests may be failing.", is_error=True)
log_message("Check the connection to the TTS service and the response status codes above.", is_error=True)
file_stats = []
for wav_path in wav_files:
try:
with wave.open(str(wav_path), 'rb') as wav_file:
frames = wav_file.getnframes()
rate = wav_file.getframerate()
duration = frames / rate
# Get corresponding text
text_path = wav_path.with_suffix('.txt')
if text_path.exists():
with open(text_path, 'r') as text_file:
text = text_file.read().strip()
else:
text = "N/A"
file_stats.append({
'filename': wav_path.name,
'duration': duration,
'text': text
})
except Exception as e:
log_message(f"Error analyzing {wav_path}: {str(e)}", False, is_error=True)
# Print summary table
if file_stats:
log_message("\nAudio File Summary:")
log_message(f"{'Filename':<20}{'Duration':<12}{'Text':<60}")
log_message("-" * 92)
for stat in file_stats:
log_message(f"{stat['filename']:<20}{stat['duration']:<12.2f}{stat['text'][:57]+'...' if len(stat['text']) > 60 else stat['text']:<60}")
# List missing WAV files where text files exist
missing_wavs = set(p.stem for p in txt_files) - set(p.stem for p in wav_files)
if missing_wavs:
log_message(f"\nFound {len(missing_wavs)} text files without corresponding WAV files:", is_error=True)
for stem in sorted(list(missing_wavs))[:10]: # Limit to 10 for readability
log_message(f" - {stem}.txt (no WAV file)", is_error=True)
if len(missing_wavs) > 10:
log_message(f" ... and {len(missing_wavs) - 10} more", is_error=True)
if __name__ == "__main__":
args = setup_args()
run_test(args)
analyze_audio_files(args.output_dir)
log_message("\nNext Steps:")
log_message("1. Listen to the generated audio files")
log_message("2. Verify if each audio correctly says its ID number")
log_message("3. Check for any mismatches between the audio content and the text files")
log_message("4. If mismatches are found, you've successfully reproduced the race condition")

View file

@ -180,10 +180,11 @@ async def create_captioned_speech(
"pcm": "audio/pcm",
}.get(request.response_format, f"audio/{request.response_format}")
writer = StreamingAudioWriter(request.response_format, sample_rate=24000)
# Check if streaming is requested (default for OpenAI client)
if request.stream:
# Create generator but don't start it yet
generator = stream_audio_chunks(tts_service, request, client_request)
generator = stream_audio_chunks(tts_service, request, client_request, writer)
# If download link requested, wrap generator with temp file writer
if request.return_download_link:
@ -284,6 +285,7 @@ async def create_captioned_speech(
audio_data = await tts_service.generate_audio(
text=request.input,
voice=voice_name,
writer=writer,
speed=request.speed,
return_timestamps=request.return_timestamps,
normalization_options=request.normalization_options,
@ -294,6 +296,7 @@ async def create_captioned_speech(
audio_data,
24000,
request.response_format,
writer,
is_first_chunk=True,
is_last_chunk=False,
trim_audio=False,
@ -304,6 +307,7 @@ async def create_captioned_speech(
AudioChunk(np.array([], dtype=np.int16)),
24000,
request.response_format,
writer,
is_first_chunk=False,
is_last_chunk=True,
)

View file

@ -10,11 +10,12 @@ from urllib import response
import aiofiles
import numpy as np
from ..services.streaming_audio_writer import StreamingAudioWriter
import torch
from fastapi import APIRouter, Depends, Header, HTTPException, Request, Response
from fastapi.responses import FileResponse, StreamingResponse
from loguru import logger
from structures.schemas import CaptionedSpeechRequest
from ..structures.schemas import CaptionedSpeechRequest
from ..core.config import settings
from ..inference.base import AudioChunk
@ -79,9 +80,7 @@ def get_model_name(model: str) -> str:
return base_name + ".pth"
async def process_and_validate_voices(
voice_input: Union[str, List[str]], tts_service: TTSService
) -> str:
async def process_and_validate_voices(voice_input: Union[str, List[str]], tts_service: TTSService) -> str:
"""Process voice input, handling both string and list formats
Returns:
@ -90,73 +89,59 @@ async def process_and_validate_voices(
voices = []
# Convert input to list of voices
if isinstance(voice_input, str):
voice_input=voice_input.replace(" ","").strip()
voice_input = voice_input.replace(" ", "").strip()
if voice_input[-1] in "+-" or voice_input[0] in "+-":
raise ValueError(
f"Voice combination contains empty combine items"
)
raise ValueError(f"Voice combination contains empty combine items")
if re.search(r"[+-]{2,}", voice_input) is not None:
raise ValueError(
f"Voice combination contains empty combine items"
)
raise ValueError(f"Voice combination contains empty combine items")
voices = re.split(r"([-+])", voice_input)
else:
voices = [[item,"+"] for item in voice_input][:-1]
voices = [[item, "+"] for item in voice_input][:-1]
available_voices = await tts_service.list_voices()
for voice_index in range(0,len(voices), 2):
for voice_index in range(0, len(voices), 2):
mapped_voice = voices[voice_index].split("(")
mapped_voice = list(map(str.strip, mapped_voice))
if len(mapped_voice) > 2:
raise ValueError(
f"Voice '{voices[voice_index]}' contains too many weight items"
)
raise ValueError(f"Voice '{voices[voice_index]}' contains too many weight items")
if mapped_voice.count(")") > 1:
raise ValueError(
f"Voice '{voices[voice_index]}' contains too many weight items"
)
raise ValueError(f"Voice '{voices[voice_index]}' contains too many weight items")
mapped_voice[0] = _openai_mappings["voices"].get(mapped_voice[0], mapped_voice[0])
if mapped_voice[0] not in available_voices:
raise ValueError(
f"Voice '{mapped_voice[0]}' not found. Available voices: {', '.join(sorted(available_voices))}"
)
raise ValueError(f"Voice '{mapped_voice[0]}' not found. Available voices: {', '.join(sorted(available_voices))}")
voices[voice_index] = "(".join(mapped_voice)
return "".join(voices)
async def stream_audio_chunks(
tts_service: TTSService, request: Union[OpenAISpeechRequest,CaptionedSpeechRequest], client_request: Request
) -> AsyncGenerator[AudioChunk, None]:
async def stream_audio_chunks(tts_service: TTSService, request: Union[OpenAISpeechRequest, CaptionedSpeechRequest], client_request: Request, writer: StreamingAudioWriter) -> AsyncGenerator[AudioChunk, None]:
"""Stream audio chunks as they're generated with client disconnect handling"""
voice_name = await process_and_validate_voices(request.voice, tts_service)
unique_properties={
"return_timestamps":False
}
unique_properties = {"return_timestamps": False}
if hasattr(request, "return_timestamps"):
unique_properties["return_timestamps"]=request.return_timestamps
unique_properties["return_timestamps"] = request.return_timestamps
try:
logger.info(f"Starting audio generation with lang_code: {request.lang_code}")
async for chunk_data in tts_service.generate_audio_stream(
text=request.input,
voice=voice_name,
writer=writer,
speed=request.speed,
output_format=request.response_format,
lang_code=request.lang_code or settings.default_voice_code or voice_name[0].lower(),
normalization_options=request.normalization_options,
return_timestamps=unique_properties["return_timestamps"],
):
# Check if client is still connected
is_disconnected = client_request.is_disconnected
if callable(is_disconnected):
@ -174,7 +159,6 @@ async def stream_audio_chunks(
@router.post("/audio/speech")
async def create_speech(
request: OpenAISpeechRequest,
client_request: Request,
x_raw_response: str = Header(None, alias="x-raw-response"),
@ -206,10 +190,12 @@ async def create_speech(
"pcm": "audio/pcm",
}.get(request.response_format, f"audio/{request.response_format}")
writer = StreamingAudioWriter(request.response_format, sample_rate=24000)
# Check if streaming is requested (default for OpenAI client)
if request.stream:
# Create generator but don't start it yet
generator = stream_audio_chunks(tts_service, request, client_request)
generator = stream_audio_chunks(tts_service, request, client_request, writer)
# If download link requested, wrap generator with temp file writer
if request.return_download_link:
@ -239,9 +225,9 @@ async def create_speech(
async for chunk_data in generator:
if chunk_data.output: # Skip empty chunks
await temp_writer.write(chunk_data.output)
#if return_json:
# if return_json:
# yield chunk, chunk_data
#else:
# else:
yield chunk_data.output
# Finalize the temp file
@ -256,9 +242,7 @@ async def create_speech(
await temp_writer.__aexit__(None, None, None)
# Stream with temp file writing
return StreamingResponse(
dual_output(), media_type=content_type, headers=headers
)
return StreamingResponse(dual_output(), media_type=content_type, headers=headers)
async def single_output():
try:
@ -269,7 +253,7 @@ async def create_speech(
except Exception as e:
logger.error(f"Error in single output streaming: {e}")
raise
# Standard streaming without download link
return StreamingResponse(
single_output(),
@ -283,40 +267,36 @@ async def create_speech(
)
else:
headers = {
"Content-Disposition": f"attachment; filename=speech.{request.response_format}",
"Cache-Control": "no-cache", # Prevent caching
}
"Content-Disposition": f"attachment; filename=speech.{request.response_format}",
"Cache-Control": "no-cache", # Prevent caching
}
# Generate complete audio using public interface
audio_data = await tts_service.generate_audio(
text=request.input,
voice=voice_name,
writer=writer,
speed=request.speed,
normalization_options=request.normalization_options,
lang_code=request.lang_code,
)
audio_data = await AudioService.convert_audio(
audio_data,
24000,
request.response_format,
is_first_chunk=True,
is_last_chunk=False,
trim_audio=False
)
audio_data = await AudioService.convert_audio(audio_data, 24000, request.response_format, writer, is_first_chunk=True, is_last_chunk=False, trim_audio=False)
# Convert to requested format with proper finalization
final = await AudioService.convert_audio(
AudioChunk(np.array([], dtype=np.int16)),
24000,
request.response_format,
writer,
is_first_chunk=False,
is_last_chunk=True,
)
output=audio_data.output + final.output
output = audio_data.output + final.output
if request.return_download_link:
from ..services.temp_manager import TempFileWriter
# Use download_format if specified, otherwise use response_format
output_format = request.download_format or request.response_format
temp_writer = TempFileWriter(output_format)
@ -390,9 +370,7 @@ async def download_audio_file(filename: str):
from ..core.paths import _find_file, get_content_type
# Search for file in temp directory
file_path = await _find_file(
filename=filename, search_paths=[settings.temp_file_dir]
)
file_path = await _find_file(filename=filename, search_paths=[settings.temp_file_dir])
# Get content type from path helper
content_type = await get_content_type(file_path)
@ -425,30 +403,12 @@ async def list_models():
try:
# Create standard model list
models = [
{
"id": "tts-1",
"object": "model",
"created": 1686935002,
"owned_by": "kokoro"
},
{
"id": "tts-1-hd",
"object": "model",
"created": 1686935002,
"owned_by": "kokoro"
},
{
"id": "kokoro",
"object": "model",
"created": 1686935002,
"owned_by": "kokoro"
}
{"id": "tts-1", "object": "model", "created": 1686935002, "owned_by": "kokoro"},
{"id": "tts-1-hd", "object": "model", "created": 1686935002, "owned_by": "kokoro"},
{"id": "kokoro", "object": "model", "created": 1686935002, "owned_by": "kokoro"},
]
return {
"object": "list",
"data": models
}
return {"object": "list", "data": models}
except Exception as e:
logger.error(f"Error listing models: {str(e)}")
raise HTTPException(
@ -460,43 +420,22 @@ async def list_models():
},
)
@router.get("/models/{model}")
async def retrieve_model(model: str):
"""Retrieve a specific model"""
try:
# Define available models
models = {
"tts-1": {
"id": "tts-1",
"object": "model",
"created": 1686935002,
"owned_by": "kokoro"
},
"tts-1-hd": {
"id": "tts-1-hd",
"object": "model",
"created": 1686935002,
"owned_by": "kokoro"
},
"kokoro": {
"id": "kokoro",
"object": "model",
"created": 1686935002,
"owned_by": "kokoro"
}
"tts-1": {"id": "tts-1", "object": "model", "created": 1686935002, "owned_by": "kokoro"},
"tts-1-hd": {"id": "tts-1-hd", "object": "model", "created": 1686935002, "owned_by": "kokoro"},
"kokoro": {"id": "kokoro", "object": "model", "created": 1686935002, "owned_by": "kokoro"},
}
# Check if requested model exists
if model not in models:
raise HTTPException(
status_code=404,
detail={
"error": "model_not_found",
"message": f"Model '{model}' not found",
"type": "invalid_request_error"
}
)
raise HTTPException(status_code=404, detail={"error": "model_not_found", "message": f"Model '{model}' not found", "type": "invalid_request_error"})
# Return the specific model
return models[model]
except HTTPException:
@ -512,6 +451,7 @@ async def retrieve_model(model: str):
},
)
@router.get("/audio/voices")
async def list_voices():
"""List all available voices for text-to-speech"""
@ -579,9 +519,7 @@ async def combine_voices(request: Union[str, List[str]]):
available_voices = await tts_service.list_voices()
for voice in voices:
if voice not in available_voices:
raise ValueError(
f"Base voice '{voice}' not found. Available voices: {', '.join(sorted(available_voices))}"
)
raise ValueError(f"Base voice '{voice}' not found. Available voices: {', '.join(sorted(available_voices))}")
# Combine voices
combined_tensor = await tts_service.combine_voices(voices=voices)

View file

@ -115,6 +115,7 @@ class AudioService:
audio_chunk: AudioChunk,
sample_rate: int,
output_format: str,
writer: StreamingAudioWriter,
speed: float = 1,
chunk_text: str = "",
is_first_chunk: bool = True,
@ -153,28 +154,30 @@ class AudioService:
audio_chunk = AudioService.trim_audio(audio_chunk,chunk_text,speed,is_last_chunk,normalizer)
# Get or create format-specific writer
writer_key = f"{output_format}_{sample_rate}"
"""writer_key = f"{output_format}_{sample_rate}"
if is_first_chunk or writer_key not in AudioService._writers:
AudioService._writers[writer_key] = StreamingAudioWriter(
output_format, sample_rate
)
writer = AudioService._writers[writer_key]
writer = AudioService._writers[writer_key]"""
# Write audio data first
if len(audio_chunk.audio) > 0:
chunk_data = writer.write_chunk(audio_chunk.audio)
# Then finalize if this is the last chunk
if is_last_chunk:
final_data = writer.write_chunk(finalize=True)
del AudioService._writers[writer_key]
if final_data:
audio_chunk.output=final_data
return audio_chunk
if chunk_data:
audio_chunk.output=chunk_data
audio_chunk.output=chunk_data
return audio_chunk
except Exception as e:

View file

@ -8,6 +8,7 @@ import time
from typing import AsyncGenerator, List, Optional, Tuple, Union
import numpy as np
from .streaming_audio_writer import StreamingAudioWriter
import torch
from kokoro import KPipeline
from loguru import logger
@ -50,6 +51,7 @@ class TTSService:
voice_name: str,
voice_path: str,
speed: float,
writer: StreamingAudioWriter,
output_format: Optional[str] = None,
is_first: bool = False,
is_last: bool = False,
@ -64,12 +66,13 @@ class TTSService:
if is_last:
# Skip format conversion for raw audio mode
if not output_format:
yield AudioChunk(np.array([], dtype=np.int16),output=b'')
yield AudioChunk(np.array([], dtype=np.int16), output=b"")
return
chunk_data = await AudioService.convert_audio(
AudioChunk(np.array([], dtype=np.float32)), # Dummy data for type checking
24000,
output_format,
writer,
speed,
"",
is_first_chunk=False,
@ -88,7 +91,7 @@ class TTSService:
# Generate audio using pre-warmed model
if isinstance(backend, KokoroV1):
chunk_index=0
chunk_index = 0
# For Kokoro V1, pass text and voice info with lang_code
async for chunk_data in self.model_manager.generate(
chunk_text,
@ -104,6 +107,7 @@ class TTSService:
chunk_data,
24000,
output_format,
writer,
speed,
chunk_text,
is_first_chunk=is_first and chunk_index == 0,
@ -114,22 +118,13 @@ class TTSService:
except Exception as e:
logger.error(f"Failed to convert audio: {str(e)}")
else:
chunk_data = AudioService.trim_audio(chunk_data,
chunk_text,
speed,
is_last,
normalizer)
chunk_data = AudioService.trim_audio(chunk_data, chunk_text, speed, is_last, normalizer)
yield chunk_data
chunk_index+=1
chunk_index += 1
else:
# For legacy backends, load voice tensor
voice_tensor = await self._voice_manager.load_voice(
voice_name, device=backend.device
)
chunk_data = await self.model_manager.generate(
tokens, voice_tensor, speed=speed, return_timestamps=return_timestamps
)
voice_tensor = await self._voice_manager.load_voice(voice_name, device=backend.device)
chunk_data = await self.model_manager.generate(tokens, voice_tensor, speed=speed, return_timestamps=return_timestamps)
if chunk_data.audio is None:
logger.error("Model generated None for audio chunk")
@ -146,6 +141,7 @@ class TTSService:
chunk_data,
24000,
output_format,
writer,
speed,
chunk_text,
is_first_chunk=is_first,
@ -156,11 +152,7 @@ class TTSService:
except Exception as e:
logger.error(f"Failed to convert audio: {str(e)}")
else:
trimmed = AudioService.trim_audio(chunk_data,
chunk_text,
speed,
is_last,
normalizer)
trimmed = AudioService.trim_audio(chunk_data, chunk_text, speed, is_last, normalizer)
yield trimmed
except Exception as e:
logger.error(f"Failed to process tokens: {str(e)}")
@ -169,7 +161,7 @@ class TTSService:
# Check if the path is None and raise a ValueError if it is not
if not path:
raise ValueError(f"Voice not found at path: {path}")
logger.debug(f"Loading voice tensor from path: {path}")
return torch.load(path, map_location="cpu") * weight
@ -191,10 +183,8 @@ class TTSService:
# If it is only once voice there is no point in loading it up, doing nothing with it, then saving it
if len(split_voice) == 1:
# Since its a single voice the only time that the weight would matter is if voice_weight_normalization is off
if ("(" not in voice and ")" not in voice) or settings.voice_weight_normalization == True:
path = await self._voice_manager.get_voice_path(voice)
if not path:
raise RuntimeError(f"Voice not found: {voice}")
@ -202,8 +192,8 @@ class TTSService:
return voice, path
total_weight = 0
for voice_index in range(0,len(split_voice),2):
for voice_index in range(0, len(split_voice), 2):
voice_object = split_voice[voice_index]
if "(" in voice_object and ")" in voice_object:
@ -215,7 +205,7 @@ class TTSService:
total_weight += voice_weight
split_voice[voice_index] = (voice_name, voice_weight)
# If voice_weight_normalization is false prevent normalizing the weights by setting the total_weight to 1 so it divides each weight by 1
if settings.voice_weight_normalization == False:
total_weight = 1
@ -225,9 +215,9 @@ class TTSService:
combined_tensor = await self._load_voice_from_path(path, split_voice[0][1] / total_weight)
# Loop through each + or - in split_voice so they can be applied to combined voice
for operation_index in range(1,len(split_voice) - 1, 2):
for operation_index in range(1, len(split_voice) - 1, 2):
# Get the voice path of the voice 1 index ahead of the operator
path = await self._voice_manager.get_voice_path(split_voice[operation_index+1][0])
path = await self._voice_manager.get_voice_path(split_voice[operation_index + 1][0])
voice_tensor = await self._load_voice_from_path(path, split_voice[operation_index + 1][1] / total_weight)
# Either add or subtract the voice from the current combined voice
@ -235,7 +225,7 @@ class TTSService:
combined_tensor += voice_tensor
else:
combined_tensor -= voice_tensor
# Save the new combined voice so it can be loaded latter
temp_dir = tempfile.gettempdir()
combined_path = os.path.join(temp_dir, f"{voice}.pt")
@ -250,6 +240,7 @@ class TTSService:
self,
text: str,
voice: str,
writer: StreamingAudioWriter,
speed: float = 1.0,
output_format: str = "wav",
lang_code: Optional[str] = None,
@ -259,7 +250,7 @@ class TTSService:
"""Generate and stream audio chunks."""
stream_normalizer = AudioNormalizer()
chunk_index = 0
current_offset=0.0
current_offset = 0.0
try:
# Get backend
backend = self.model_manager.get_backend()
@ -270,13 +261,10 @@ class TTSService:
# Use provided lang_code or determine from voice name
pipeline_lang_code = lang_code if lang_code else voice[:1].lower()
logger.info(
f"Using lang_code '{pipeline_lang_code}' for voice '{voice_name}' in audio stream"
)
logger.info(f"Using lang_code '{pipeline_lang_code}' for voice '{voice_name}' in audio stream")
# Process text in chunks with smart splitting
async for chunk_text, tokens in smart_split(text,lang_code=lang_code,normalization_options=normalization_options):
async for chunk_text, tokens in smart_split(text, lang_code=lang_code, normalization_options=normalization_options):
try:
# Process audio for chunk
async for chunk_data in self._process_chunk(
@ -285,6 +273,7 @@ class TTSService:
voice_name, # Pass voice name
voice_path, # Pass voice path
speed,
writer,
output_format,
is_first=(chunk_index == 0),
is_last=False, # We'll update the last chunk later
@ -294,23 +283,19 @@ class TTSService:
):
if chunk_data.word_timestamps is not None:
for timestamp in chunk_data.word_timestamps:
timestamp.start_time+=current_offset
timestamp.end_time+=current_offset
current_offset+=len(chunk_data.audio) / 24000
timestamp.start_time += current_offset
timestamp.end_time += current_offset
current_offset += len(chunk_data.audio) / 24000
if chunk_data.output is not None:
yield chunk_data
else:
logger.warning(
f"No audio generated for chunk: '{chunk_text[:100]}...'"
)
logger.warning(f"No audio generated for chunk: '{chunk_text[:100]}...'")
chunk_index += 1
except Exception as e:
logger.error(
f"Failed to process audio for chunk: '{chunk_text[:100]}...'. Error: {str(e)}"
)
logger.error(f"Failed to process audio for chunk: '{chunk_text[:100]}...'. Error: {str(e)}")
continue
# Only finalize if we successfully processed at least one chunk
@ -323,6 +308,7 @@ class TTSService:
voice_name,
voice_path,
speed,
writer,
output_format,
is_first=False,
is_last=True, # Signal this is the last chunk
@ -337,32 +323,30 @@ class TTSService:
except Exception as e:
logger.error(f"Error in phoneme audio generation: {str(e)}")
raise e
async def generate_audio(
self,
text: str,
voice: str,
writer: StreamingAudioWriter,
speed: float = 1.0,
return_timestamps: bool = False,
normalization_options: Optional[NormalizationOptions] = NormalizationOptions(),
lang_code: Optional[str] = None,
) -> AudioChunk:
"""Generate complete audio for text using streaming internally."""
audio_data_chunks=[]
audio_data_chunks = []
try:
async for audio_stream_data in self.generate_audio_stream(text,voice,speed=speed,normalization_options=normalization_options,return_timestamps=return_timestamps,lang_code=lang_code,output_format=None):
async for audio_stream_data in self.generate_audio_stream(text, voice, writer, speed=speed, normalization_options=normalization_options, return_timestamps=return_timestamps, lang_code=lang_code, output_format=None):
if len(audio_stream_data.audio) > 0:
audio_data_chunks.append(audio_stream_data)
combined_audio_data=AudioChunk.combine(audio_data_chunks)
combined_audio_data = AudioChunk.combine(audio_data_chunks)
return combined_audio_data
except Exception as e:
logger.error(f"Error in audio generation: {str(e)}")
raise
async def combine_voices(self, voices: List[str]) -> torch.Tensor:
"""Combine multiple voices.
@ -406,15 +390,11 @@ class TTSService:
result = None
# Use provided lang_code or determine from voice name
pipeline_lang_code = lang_code if lang_code else voice[:1].lower()
logger.info(
f"Using lang_code '{pipeline_lang_code}' for voice '{voice_name}' in phoneme pipeline"
)
logger.info(f"Using lang_code '{pipeline_lang_code}' for voice '{voice_name}' in phoneme pipeline")
try:
# Use backend's pipeline management
for r in backend._get_pipeline(
pipeline_lang_code
).generate_from_tokens(
for r in backend._get_pipeline(pipeline_lang_code).generate_from_tokens(
tokens=phonemes, # Pass raw phonemes string
voice=voice_path,
speed=speed,
@ -432,9 +412,7 @@ class TTSService:
processing_time = time.time() - start_time
return result.audio.numpy(), processing_time
else:
raise ValueError(
"Phoneme generation only supported with Kokoro V1 backend"
)
raise ValueError("Phoneme generation only supported with Kokoro V1 backend")
except Exception as e:
logger.error(f"Error in phoneme audio generation: {str(e)}")

View file

@ -7,7 +7,7 @@ import pytest
from api.src.services.audio import AudioNormalizer, AudioService
from api.src.inference.base import AudioChunk
from api.src.services.streaming_audio_writer import StreamingAudioWriter
@pytest.fixture(autouse=True)
def mock_settings():
"""Mock settings for all tests"""
@ -30,9 +30,11 @@ def sample_audio():
async def test_convert_to_wav(sample_audio):
"""Test converting to WAV format"""
audio_data, sample_rate = sample_audio
writer = StreamingAudioWriter("wav", sample_rate=24000)
# Write and finalize in one step for WAV
audio_chunk = await AudioService.convert_audio(
AudioChunk(audio_data), sample_rate, "wav", is_first_chunk=True, is_last_chunk=False
AudioChunk(audio_data), sample_rate, "wav", writer, is_first_chunk=True, is_last_chunk=False
)
assert isinstance(audio_chunk.output, bytes)
assert isinstance(audio_chunk, AudioChunk)
@ -46,8 +48,11 @@ async def test_convert_to_wav(sample_audio):
async def test_convert_to_mp3(sample_audio):
"""Test converting to MP3 format"""
audio_data, sample_rate = sample_audio
writer = StreamingAudioWriter("writer", sample_rate=24000)
audio_chunk = await AudioService.convert_audio(
AudioChunk(audio_data), sample_rate, "mp3"
AudioChunk(audio_data), sample_rate, "mp3", writer
)
assert isinstance(audio_chunk.output, bytes)
assert isinstance(audio_chunk, AudioChunk)
@ -60,8 +65,11 @@ async def test_convert_to_mp3(sample_audio):
async def test_convert_to_opus(sample_audio):
"""Test converting to Opus format"""
audio_data, sample_rate = sample_audio
writer = StreamingAudioWriter("opus", sample_rate=24000)
audio_chunk = await AudioService.convert_audio(
AudioChunk(audio_data), sample_rate, "opus"
AudioChunk(audio_data), sample_rate, "opus",writer
)
assert isinstance(audio_chunk.output, bytes)
assert isinstance(audio_chunk, AudioChunk)
@ -74,8 +82,11 @@ async def test_convert_to_opus(sample_audio):
async def test_convert_to_flac(sample_audio):
"""Test converting to FLAC format"""
audio_data, sample_rate = sample_audio
writer = StreamingAudioWriter("flac", sample_rate=24000)
audio_chunk = await AudioService.convert_audio(
AudioChunk(audio_data), sample_rate, "flac"
AudioChunk(audio_data), sample_rate, "flac", writer
)
assert isinstance(audio_chunk.output, bytes)
assert isinstance(audio_chunk, AudioChunk)
@ -88,8 +99,11 @@ async def test_convert_to_flac(sample_audio):
async def test_convert_to_aac(sample_audio):
"""Test converting to M4A format"""
audio_data, sample_rate = sample_audio
writer = StreamingAudioWriter("aac", sample_rate=24000)
audio_chunk = await AudioService.convert_audio(
AudioChunk(audio_data), sample_rate, "aac"
AudioChunk(audio_data), sample_rate, "aac", writer
)
assert isinstance(audio_chunk.output, bytes)
assert isinstance(audio_chunk, AudioChunk)
@ -102,8 +116,11 @@ async def test_convert_to_aac(sample_audio):
async def test_convert_to_pcm(sample_audio):
"""Test converting to PCM format"""
audio_data, sample_rate = sample_audio
writer = StreamingAudioWriter("pcm", sample_rate=24000)
audio_chunk = await AudioService.convert_audio(
AudioChunk(audio_data), sample_rate, "pcm"
AudioChunk(audio_data), sample_rate, "pcm", writer
)
assert isinstance(audio_chunk.output, bytes)
assert isinstance(audio_chunk, AudioChunk)
@ -115,19 +132,25 @@ async def test_convert_to_pcm(sample_audio):
async def test_convert_to_invalid_format_raises_error(sample_audio):
"""Test that converting to an invalid format raises an error"""
audio_data, sample_rate = sample_audio
with pytest.raises(ValueError, match="Unsupported format: invalid"):
StreamingAudioWriter("invalid", sample_rate=24000)
with pytest.raises(ValueError, match="Format invalid not supported"):
await AudioService.convert_audio(audio_data, sample_rate, "invalid")
await AudioService.convert_audio(audio_data, sample_rate, "invalid", None)
@pytest.mark.asyncio
async def test_normalization_wav(sample_audio):
"""Test that WAV output is properly normalized to int16 range"""
audio_data, sample_rate = sample_audio
writer = StreamingAudioWriter("wav", sample_rate=24000)
# Create audio data outside int16 range
large_audio = audio_data * 1e5
# Write and finalize in one step for WAV
audio_chunk = await AudioService.convert_audio(
AudioChunk(large_audio), sample_rate, "wav", is_first_chunk=True
AudioChunk(large_audio), sample_rate, "wav", writer, is_first_chunk=True
)
assert isinstance(audio_chunk.output, bytes)
assert isinstance(audio_chunk, AudioChunk)
@ -138,10 +161,13 @@ async def test_normalization_wav(sample_audio):
async def test_normalization_pcm(sample_audio):
"""Test that PCM output is properly normalized to int16 range"""
audio_data, sample_rate = sample_audio
writer = StreamingAudioWriter("pcm", sample_rate=24000)
# Create audio data outside int16 range
large_audio = audio_data * 1e5
audio_chunk = await AudioService.convert_audio(
AudioChunk(large_audio), sample_rate, "pcm"
AudioChunk(large_audio), sample_rate, "pcm", writer
)
assert isinstance(audio_chunk.output, bytes)
assert isinstance(audio_chunk, AudioChunk)
@ -153,8 +179,11 @@ async def test_invalid_audio_data():
"""Test handling of invalid audio data"""
invalid_audio = np.array([]) # Empty array
sample_rate = 24000
writer = StreamingAudioWriter("wav", sample_rate=24000)
with pytest.raises(ValueError):
await AudioService.convert_audio(invalid_audio, sample_rate, "wav")
await AudioService.convert_audio(invalid_audio, sample_rate, "wav", writer)
@pytest.mark.asyncio
@ -164,8 +193,11 @@ async def test_different_sample_rates(sample_audio):
sample_rates = [8000, 16000, 44100, 48000]
for rate in sample_rates:
writer = StreamingAudioWriter("wav", sample_rate=rate)
audio_chunk = await AudioService.convert_audio(
AudioChunk(audio_data), rate, "wav", is_first_chunk=True
AudioChunk(audio_data), rate, "wav", writer, is_first_chunk=True
)
assert isinstance(audio_chunk.output, bytes)
assert isinstance(audio_chunk, AudioChunk)
@ -176,15 +208,21 @@ async def test_different_sample_rates(sample_audio):
async def test_buffer_position_after_conversion(sample_audio):
"""Test that buffer position is reset after writing"""
audio_data, sample_rate = sample_audio
writer = StreamingAudioWriter("wav", sample_rate=24000)
# Write and finalize in one step for first conversion
audio_chunk1 = await AudioService.convert_audio(
AudioChunk(audio_data), sample_rate, "wav", is_first_chunk=True, is_last_chunk=True
AudioChunk(audio_data), sample_rate, "wav", writer, is_first_chunk=True, is_last_chunk=True
)
assert isinstance(audio_chunk1.output, bytes)
assert isinstance(audio_chunk1, AudioChunk)
# Convert again to ensure buffer was properly reset
writer = StreamingAudioWriter("wav", sample_rate=24000)
audio_chunk2 = await AudioService.convert_audio(
AudioChunk(audio_data), sample_rate, "wav", is_first_chunk=True, is_last_chunk=True
AudioChunk(audio_data), sample_rate, "wav", writer, is_first_chunk=True, is_last_chunk=True
)
assert isinstance(audio_chunk2.output, bytes)
assert isinstance(audio_chunk2, AudioChunk)