Mirror of https://github.com/remsky/Kokoro-FastAPI.git (synced 2025-04-13 09:39:17 +00:00)
Initial test commit of segfault fixes
commit 8f23bf53a4, parent 0d7570ab50
6 changed files with 442 additions and 198 deletions
Test Threads.py (new file, 283 lines)
@@ -0,0 +1,283 @@
#!/usr/bin/env python3
# Compatible with both Windows and Linux
"""
Kokoro TTS Race Condition Test

This script creates multiple concurrent requests to a Kokoro TTS service
to reproduce a race condition where audio outputs don't match the requested text.
Each thread generates a simple numbered sentence, which should make mismatches
easy to identify through listening.

To run:
python kokoro_race_condition_test.py --threads 8 --iterations 5 --url http://localhost:8880
"""

import argparse
import base64
import concurrent.futures
import json
import os
import requests
import time
import wave
import sys
from pathlib import Path


def setup_args():
    """Parse command line arguments"""
    parser = argparse.ArgumentParser(description="Test Kokoro TTS for race conditions")
    parser.add_argument("--url", default="http://localhost:8880",
                        help="Base URL of the Kokoro TTS service")
    parser.add_argument("--threads", type=int, default=8,
                        help="Number of concurrent threads to use")
    parser.add_argument("--iterations", type=int, default=5,
                        help="Number of iterations per thread")
    parser.add_argument("--voice", default="af_heart",
                        help="Voice to use for TTS")
    parser.add_argument("--output-dir", default="./tts_test_output",
                        help="Directory to save output files")
    parser.add_argument("--debug", action="store_true",
                        help="Enable debug logging")
    return parser.parse_args()


def generate_test_sentence(thread_id, iteration):
    """Generate a simple test sentence with numbers to make mismatches easily identifiable"""
    return f"This is test sentence number {thread_id}-{iteration}. " \
           f"If you hear this sentence, you should hear the numbers {thread_id}-{iteration}."


def log_message(message, debug=False, is_error=False):
    """Log messages with timestamps"""
    timestamp = time.strftime("%Y-%m-%d %H:%M:%S")
    prefix = "[ERROR]" if is_error else "[INFO]"
    if is_error or debug:
        print(f"{prefix} {timestamp} - {message}")
        sys.stdout.flush()  # Ensure logs are visible in Docker output


def request_tts(url, test_id, text, voice, output_dir, debug=False):
    """Request TTS from the Kokoro API and save the WAV output"""
    start_time = time.time()
    output_file = os.path.join(output_dir, f"test_{test_id}.wav")
    text_file = os.path.join(output_dir, f"test_{test_id}.txt")

    # Log output paths for debugging
    log_message(f"Thread {test_id}: Text will be saved to: {text_file}", debug)
    log_message(f"Thread {test_id}: Audio will be saved to: {output_file}", debug)

    # Save the text for later comparison
    try:
        with open(text_file, "w") as f:
            f.write(text)
        log_message(f"Thread {test_id}: Successfully saved text file", debug)
    except Exception as e:
        log_message(f"Thread {test_id}: Error saving text file: {str(e)}", debug, is_error=True)

    # Make the TTS request
    try:
        log_message(f"Thread {test_id}: Requesting TTS for: '{text}'", debug)

        response = requests.post(
            f"{url}/v1/audio/speech",
            json={
                "model": "kokoro",
                "input": text,
                "voice": voice,
                "response_format": "wav"
            },
            headers={"Accept": "audio/wav"},
            timeout=60  # Increase timeout to 60 seconds
        )

        log_message(f"Thread {test_id}: Response status code: {response.status_code}", debug)
        log_message(f"Thread {test_id}: Response content type: {response.headers.get('Content-Type', 'None')}", debug)
        log_message(f"Thread {test_id}: Response content length: {len(response.content)} bytes", debug)

        if response.status_code != 200:
            log_message(f"Thread {test_id}: API error: {response.status_code} - {response.text}", debug, is_error=True)
            return False

        # Check if we got valid audio data
        if len(response.content) < 100:  # Sanity check - WAV files should be larger than this
            log_message(f"Thread {test_id}: Received suspiciously small audio data: {len(response.content)} bytes", debug, is_error=True)
            log_message(f"Thread {test_id}: Content (base64): {base64.b64encode(response.content).decode('utf-8')}", debug, is_error=True)
            return False

        # Save the audio output with explicit error handling
        try:
            with open(output_file, "wb") as f:
                bytes_written = f.write(response.content)
            log_message(f"Thread {test_id}: Wrote {bytes_written} bytes to {output_file}", debug)

            # Verify the WAV file exists and has content
            if os.path.exists(output_file):
                file_size = os.path.getsize(output_file)
                log_message(f"Thread {test_id}: Verified file exists with size: {file_size} bytes", debug)

                # Validate WAV file by reading its headers
                try:
                    with wave.open(output_file, 'rb') as wav_file:
                        channels = wav_file.getnchannels()
                        sample_width = wav_file.getsampwidth()
                        framerate = wav_file.getframerate()
                        frames = wav_file.getnframes()
                        log_message(f"Thread {test_id}: Valid WAV file - channels: {channels}, "
                                    f"sample width: {sample_width}, framerate: {framerate}, frames: {frames}", debug)
                except Exception as wav_error:
                    log_message(f"Thread {test_id}: Invalid WAV file: {str(wav_error)}", debug, is_error=True)
            else:
                log_message(f"Thread {test_id}: File was not created: {output_file}", debug, is_error=True)
        except Exception as save_error:
            log_message(f"Thread {test_id}: Error saving audio file: {str(save_error)}", debug, is_error=True)
            return False

        end_time = time.time()
        log_message(f"Thread {test_id}: Saved output to {output_file} (time: {end_time - start_time:.2f}s)", debug)
        return True

    except requests.exceptions.Timeout:
        log_message(f"Thread {test_id}: Request timed out", debug, is_error=True)
        return False
    except Exception as e:
        log_message(f"Thread {test_id}: Exception: {str(e)}", debug, is_error=True)
        return False


def worker_task(thread_id, args):
    """Worker task for each thread"""
    for i in range(args.iterations):
        iteration = i + 1
        test_id = f"{thread_id:02d}_{iteration:02d}"
        text = generate_test_sentence(thread_id, iteration)
        success = request_tts(args.url, test_id, text, args.voice, args.output_dir, args.debug)

        if not success:
            log_message(f"Thread {thread_id}: Iteration {iteration} failed", args.debug, is_error=True)

        # Small delay between iterations to avoid overwhelming the API
        time.sleep(0.1)


def run_test(args):
    """Run the test with the specified parameters"""
    # Ensure output directory exists and check permissions
    os.makedirs(args.output_dir, exist_ok=True)

    # Test write access to the output directory
    test_file = os.path.join(args.output_dir, "write_test.txt")
    try:
        with open(test_file, "w") as f:
            f.write("Testing write access\n")
        os.remove(test_file)
        log_message(f"Successfully verified write access to output directory: {args.output_dir}")
    except Exception as e:
        log_message(f"Warning: Cannot write to output directory {args.output_dir}: {str(e)}", is_error=True)
        log_message(f"Current directory: {os.getcwd()}", is_error=True)
        log_message(f"Directory contents: {os.listdir('.')}", is_error=True)

    # Test connection to Kokoro TTS service
    try:
        response = requests.get(f"{args.url}/health", timeout=5)
        if response.status_code == 200:
            log_message(f"Successfully connected to Kokoro TTS service at {args.url}")
        else:
            log_message(f"Warning: Kokoro TTS service health check returned status {response.status_code}", is_error=True)
    except Exception as e:
        log_message(f"Warning: Cannot connect to Kokoro TTS service at {args.url}: {str(e)}", is_error=True)

    # Record start time
    start_time = time.time()
    log_message(f"Starting test with {args.threads} threads, {args.iterations} iterations per thread")

    # Create and start worker threads
    with concurrent.futures.ThreadPoolExecutor(max_workers=args.threads) as executor:
        futures = []
        for thread_id in range(1, args.threads + 1):
            futures.append(executor.submit(worker_task, thread_id, args))

        # Wait for all tasks to complete
        for future in concurrent.futures.as_completed(futures):
            try:
                future.result()
            except Exception as e:
                log_message(f"Thread execution failed: {str(e)}", args.debug, is_error=True)

    # Record end time and print summary
    end_time = time.time()
    total_time = end_time - start_time
    total_requests = args.threads * args.iterations
    log_message(f"Test completed in {total_time:.2f} seconds")
    log_message(f"Total requests: {total_requests}")
    log_message(f"Average time per request: {total_time / total_requests:.2f} seconds")
    log_message(f"Requests per second: {total_requests / total_time:.2f}")
    log_message(f"Output files saved to: {os.path.abspath(args.output_dir)}")
    log_message("To verify, listen to the audio files and check if they match the text files")
    log_message("If you hear audio describing a different test number than the filename, you've found a race condition")


def analyze_audio_files(output_dir):
    """Provide summary of the generated audio files"""
    # Look for both WAV and TXT files
    wav_files = list(Path(output_dir).glob("*.wav"))
    txt_files = list(Path(output_dir).glob("*.txt"))

    log_message(f"Found {len(wav_files)} WAV files and {len(txt_files)} TXT files")

    if len(wav_files) == 0:
        log_message("No WAV files found! This indicates the TTS service requests may be failing.", is_error=True)
        log_message("Check the connection to the TTS service and the response status codes above.", is_error=True)

    file_stats = []
    for wav_path in wav_files:
        try:
            with wave.open(str(wav_path), 'rb') as wav_file:
                frames = wav_file.getnframes()
                rate = wav_file.getframerate()
                duration = frames / rate

            # Get corresponding text
            text_path = wav_path.with_suffix('.txt')
            if text_path.exists():
                with open(text_path, 'r') as text_file:
                    text = text_file.read().strip()
            else:
                text = "N/A"

            file_stats.append({
                'filename': wav_path.name,
                'duration': duration,
                'text': text
            })
        except Exception as e:
            log_message(f"Error analyzing {wav_path}: {str(e)}", False, is_error=True)

    # Print summary table
    if file_stats:
        log_message("\nAudio File Summary:")
        log_message(f"{'Filename':<20}{'Duration':<12}{'Text':<60}")
        log_message("-" * 92)
        for stat in file_stats:
            log_message(f"{stat['filename']:<20}{stat['duration']:<12.2f}{stat['text'][:57]+'...' if len(stat['text']) > 60 else stat['text']:<60}")

    # List missing WAV files where text files exist
    missing_wavs = set(p.stem for p in txt_files) - set(p.stem for p in wav_files)
    if missing_wavs:
        log_message(f"\nFound {len(missing_wavs)} text files without corresponding WAV files:", is_error=True)
        for stem in sorted(list(missing_wavs))[:10]:  # Limit to 10 for readability
            log_message(f" - {stem}.txt (no WAV file)", is_error=True)
        if len(missing_wavs) > 10:
            log_message(f" ... and {len(missing_wavs) - 10} more", is_error=True)


if __name__ == "__main__":
    args = setup_args()
    run_test(args)
    analyze_audio_files(args.output_dir)

    log_message("\nNext Steps:")
    log_message("1. Listen to the generated audio files")
    log_message("2. Verify if each audio correctly says its ID number")
    log_message("3. Check for any mismatches between the audio content and the text files")
    log_message("4. If mismatches are found, you've successfully reproduced the race condition")
@@ -180,10 +180,11 @@ async def create_captioned_speech(
             "pcm": "audio/pcm",
         }.get(request.response_format, f"audio/{request.response_format}")

+        writer = StreamingAudioWriter(request.response_format, sample_rate=24000)
         # Check if streaming is requested (default for OpenAI client)
         if request.stream:
             # Create generator but don't start it yet
-            generator = stream_audio_chunks(tts_service, request, client_request)
+            generator = stream_audio_chunks(tts_service, request, client_request, writer)

             # If download link requested, wrap generator with temp file writer
             if request.return_download_link:

@@ -284,6 +285,7 @@ async def create_captioned_speech(
             audio_data = await tts_service.generate_audio(
                 text=request.input,
                 voice=voice_name,
+                writer=writer,
                 speed=request.speed,
                 return_timestamps=request.return_timestamps,
                 normalization_options=request.normalization_options,

@@ -294,6 +296,7 @@ async def create_captioned_speech(
                 audio_data,
                 24000,
                 request.response_format,
+                writer,
                 is_first_chunk=True,
                 is_last_chunk=False,
                 trim_audio=False,

@@ -304,6 +307,7 @@ async def create_captioned_speech(
                 AudioChunk(np.array([], dtype=np.int16)),
                 24000,
                 request.response_format,
+                writer,
                 is_first_chunk=False,
                 is_last_chunk=True,
             )

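These hunks thread one StreamingAudioWriter per request through the whole pipeline: every chunk is encoded through that writer, and a final call with an empty chunk flushes the format trailer. A self-contained toy of that two-phase write/finalize protocol (StubWriter and its byte strings are invented stand-ins for illustration; the real StreamingAudioWriter encodes WAV/MP3/etc.):

# Toy model of the write/finalize protocol used above.
# StubWriter is an invented stand-in for StreamingAudioWriter.
import numpy as np

class StubWriter:
    def __init__(self):
        self.encoded = b""
    def write_chunk(self, audio=None, finalize=False):
        if finalize:
            return b"<format trailer>"  # e.g. final RIFF sizes for WAV
        data = audio.astype(np.int16).tobytes()
        self.encoded += data
        return data

writer = StubWriter()
body = writer.write_chunk(np.zeros(240, dtype=np.int16))  # first/middle chunks
trailer = writer.write_chunk(finalize=True)                # empty last call
print(len(body), trailer)

Because the writer is created next to the request and never cached, its internal encoder state cannot be touched by any other request.
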
@@ -10,11 +10,12 @@ from urllib import response

 import aiofiles
 import numpy as np
+from ..services.streaming_audio_writer import StreamingAudioWriter
 import torch
 from fastapi import APIRouter, Depends, Header, HTTPException, Request, Response
 from fastapi.responses import FileResponse, StreamingResponse
 from loguru import logger
-from structures.schemas import CaptionedSpeechRequest
+from ..structures.schemas import CaptionedSpeechRequest

 from ..core.config import settings
 from ..inference.base import AudioChunk

@@ -79,9 +80,7 @@ def get_model_name(model: str) -> str:
     return base_name + ".pth"


-async def process_and_validate_voices(
-    voice_input: Union[str, List[str]], tts_service: TTSService
-) -> str:
+async def process_and_validate_voices(voice_input: Union[str, List[str]], tts_service: TTSService) -> str:
    """Process voice input, handling both string and list formats

    Returns:
@@ -90,73 +89,59 @@ async def process_and_validate_voices(
     voices = []
     # Convert input to list of voices
     if isinstance(voice_input, str):
-        voice_input=voice_input.replace(" ","").strip()
+        voice_input = voice_input.replace(" ", "").strip()

         if voice_input[-1] in "+-" or voice_input[0] in "+-":
-            raise ValueError(
-                f"Voice combination contains empty combine items"
-            )
+            raise ValueError(f"Voice combination contains empty combine items")

         if re.search(r"[+-]{2,}", voice_input) is not None:
-            raise ValueError(
-                f"Voice combination contains empty combine items"
-            )
+            raise ValueError(f"Voice combination contains empty combine items")
         voices = re.split(r"([-+])", voice_input)
     else:
-        voices = [[item,"+"] for item in voice_input][:-1]
+        voices = [[item, "+"] for item in voice_input][:-1]

     available_voices = await tts_service.list_voices()

-    for voice_index in range(0,len(voices), 2):
-
+    for voice_index in range(0, len(voices), 2):
         mapped_voice = voices[voice_index].split("(")
         mapped_voice = list(map(str.strip, mapped_voice))

         if len(mapped_voice) > 2:
-            raise ValueError(
-                f"Voice '{voices[voice_index]}' contains too many weight items"
-            )
+            raise ValueError(f"Voice '{voices[voice_index]}' contains too many weight items")

         if mapped_voice.count(")") > 1:
-            raise ValueError(
-                f"Voice '{voices[voice_index]}' contains too many weight items"
-            )
+            raise ValueError(f"Voice '{voices[voice_index]}' contains too many weight items")

         mapped_voice[0] = _openai_mappings["voices"].get(mapped_voice[0], mapped_voice[0])

         if mapped_voice[0] not in available_voices:
-            raise ValueError(
-                f"Voice '{mapped_voice[0]}' not found. Available voices: {', '.join(sorted(available_voices))}"
-            )
+            raise ValueError(f"Voice '{mapped_voice[0]}' not found. Available voices: {', '.join(sorted(available_voices))}")

         voices[voice_index] = "(".join(mapped_voice)

-
     return "".join(voices)

-async def stream_audio_chunks(
-    tts_service: TTSService, request: Union[OpenAISpeechRequest,CaptionedSpeechRequest], client_request: Request
-) -> AsyncGenerator[AudioChunk, None]:
-
+
+async def stream_audio_chunks(tts_service: TTSService, request: Union[OpenAISpeechRequest, CaptionedSpeechRequest], client_request: Request, writer: StreamingAudioWriter) -> AsyncGenerator[AudioChunk, None]:
     """Stream audio chunks as they're generated with client disconnect handling"""
     voice_name = await process_and_validate_voices(request.voice, tts_service)

-    unique_properties={
-        "return_timestamps":False
-    }
+    unique_properties = {"return_timestamps": False}
     if hasattr(request, "return_timestamps"):
-        unique_properties["return_timestamps"]=request.return_timestamps
+        unique_properties["return_timestamps"] = request.return_timestamps

     try:
         logger.info(f"Starting audio generation with lang_code: {request.lang_code}")
         async for chunk_data in tts_service.generate_audio_stream(
             text=request.input,
             voice=voice_name,
+            writer=writer,
             speed=request.speed,
             output_format=request.response_format,
             lang_code=request.lang_code or settings.default_voice_code or voice_name[0].lower(),
             normalization_options=request.normalization_options,
             return_timestamps=unique_properties["return_timestamps"],
         ):

             # Check if client is still connected
             is_disconnected = client_request.is_disconnected
             if callable(is_disconnected):

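For reference, the voice string this function accepts combines voice names with + and - and optional parenthesized weights, e.g. "af_heart(2)+af_sky". A condensed, self-contained re-implementation of just the string validation shown above (validate_voice_string is a hypothetical helper, with no TTSService lookup or OpenAI name mapping):

# Condensed illustration of the validation rules in process_and_validate_voices.
import re

def validate_voice_string(voice_input):
    voice_input = voice_input.replace(" ", "").strip()
    if voice_input[-1] in "+-" or voice_input[0] in "+-":
        raise ValueError("Voice combination contains empty combine items")
    if re.search(r"[+-]{2,}", voice_input) is not None:
        raise ValueError("Voice combination contains empty combine items")
    return re.split(r"([-+])", voice_input)

print(validate_voice_string("af_heart(2)+af_sky"))  # ['af_heart(2)', '+', 'af_sky']
print(validate_voice_string("af_heart-af_sky"))     # ['af_heart', '-', 'af_sky']
# validate_voice_string("af_heart++af_sky") raises ValueError (empty combine item)
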
@@ -174,7 +159,6 @@ async def stream_audio_chunks(

 @router.post("/audio/speech")
 async def create_speech(
-
     request: OpenAISpeechRequest,
     client_request: Request,
     x_raw_response: str = Header(None, alias="x-raw-response"),

@@ -206,10 +190,12 @@ async def create_speech(
             "pcm": "audio/pcm",
         }.get(request.response_format, f"audio/{request.response_format}")

+        writer = StreamingAudioWriter(request.response_format, sample_rate=24000)
+
         # Check if streaming is requested (default for OpenAI client)
         if request.stream:
             # Create generator but don't start it yet
-            generator = stream_audio_chunks(tts_service, request, client_request)
+            generator = stream_audio_chunks(tts_service, request, client_request, writer)

             # If download link requested, wrap generator with temp file writer
             if request.return_download_link:

@@ -239,9 +225,9 @@ async def create_speech(
                 async for chunk_data in generator:
                     if chunk_data.output:  # Skip empty chunks
                         await temp_writer.write(chunk_data.output)
-                        #if return_json:
+                        # if return_json:
                         #     yield chunk, chunk_data
-                        #else:
+                        # else:
                         yield chunk_data.output

                 # Finalize the temp file

@@ -256,9 +242,7 @@ async def create_speech(
                 await temp_writer.__aexit__(None, None, None)

             # Stream with temp file writing
-            return StreamingResponse(
-                dual_output(), media_type=content_type, headers=headers
-            )
+            return StreamingResponse(dual_output(), media_type=content_type, headers=headers)

         async def single_output():
             try:

@@ -269,7 +253,7 @@ async def create_speech(
             except Exception as e:
                 logger.error(f"Error in single output streaming: {e}")
                 raise
-
+
         # Standard streaming without download link
         return StreamingResponse(
             single_output(),

@@ -283,40 +267,36 @@ async def create_speech(
         )
     else:
         headers = {
-                "Content-Disposition": f"attachment; filename=speech.{request.response_format}",
-                "Cache-Control": "no-cache",  # Prevent caching
-            }
+            "Content-Disposition": f"attachment; filename=speech.{request.response_format}",
+            "Cache-Control": "no-cache",  # Prevent caching
+        }

         # Generate complete audio using public interface
         audio_data = await tts_service.generate_audio(
             text=request.input,
             voice=voice_name,
+            writer=writer,
             speed=request.speed,
             normalization_options=request.normalization_options,
             lang_code=request.lang_code,
         )

-        audio_data = await AudioService.convert_audio(
-            audio_data,
-            24000,
-            request.response_format,
-            is_first_chunk=True,
-            is_last_chunk=False,
-            trim_audio=False
-        )
-
+        audio_data = await AudioService.convert_audio(audio_data, 24000, request.response_format, writer, is_first_chunk=True, is_last_chunk=False, trim_audio=False)

         # Convert to requested format with proper finalization
         final = await AudioService.convert_audio(
             AudioChunk(np.array([], dtype=np.int16)),
             24000,
             request.response_format,
+            writer,
             is_first_chunk=False,
             is_last_chunk=True,
         )
-        output=audio_data.output + final.output
+        output = audio_data.output + final.output

         if request.return_download_link:
             from ..services.temp_manager import TempFileWriter

             # Use download_format if specified, otherwise use response_format
             output_format = request.download_format or request.response_format
             temp_writer = TempFileWriter(output_format)

@@ -390,9 +370,7 @@ async def download_audio_file(filename: str):
         from ..core.paths import _find_file, get_content_type

         # Search for file in temp directory
-        file_path = await _find_file(
-            filename=filename, search_paths=[settings.temp_file_dir]
-        )
+        file_path = await _find_file(filename=filename, search_paths=[settings.temp_file_dir])

         # Get content type from path helper
         content_type = await get_content_type(file_path)

@@ -425,30 +403,12 @@ async def list_models():
     try:
         # Create standard model list
         models = [
-            {
-                "id": "tts-1",
-                "object": "model",
-                "created": 1686935002,
-                "owned_by": "kokoro"
-            },
-            {
-                "id": "tts-1-hd",
-                "object": "model",
-                "created": 1686935002,
-                "owned_by": "kokoro"
-            },
-            {
-                "id": "kokoro",
-                "object": "model",
-                "created": 1686935002,
-                "owned_by": "kokoro"
-            }
+            {"id": "tts-1", "object": "model", "created": 1686935002, "owned_by": "kokoro"},
+            {"id": "tts-1-hd", "object": "model", "created": 1686935002, "owned_by": "kokoro"},
+            {"id": "kokoro", "object": "model", "created": 1686935002, "owned_by": "kokoro"},
         ]

-        return {
-            "object": "list",
-            "data": models
-        }
+        return {"object": "list", "data": models}
     except Exception as e:
         logger.error(f"Error listing models: {str(e)}")
         raise HTTPException(

@@ -460,43 +420,22 @@ async def list_models():
         },
     )


 @router.get("/models/{model}")
 async def retrieve_model(model: str):
     """Retrieve a specific model"""
     try:
         # Define available models
         models = {
-            "tts-1": {
-                "id": "tts-1",
-                "object": "model",
-                "created": 1686935002,
-                "owned_by": "kokoro"
-            },
-            "tts-1-hd": {
-                "id": "tts-1-hd",
-                "object": "model",
-                "created": 1686935002,
-                "owned_by": "kokoro"
-            },
-            "kokoro": {
-                "id": "kokoro",
-                "object": "model",
-                "created": 1686935002,
-                "owned_by": "kokoro"
-            }
+            "tts-1": {"id": "tts-1", "object": "model", "created": 1686935002, "owned_by": "kokoro"},
+            "tts-1-hd": {"id": "tts-1-hd", "object": "model", "created": 1686935002, "owned_by": "kokoro"},
+            "kokoro": {"id": "kokoro", "object": "model", "created": 1686935002, "owned_by": "kokoro"},
         }

         # Check if requested model exists
         if model not in models:
-            raise HTTPException(
-                status_code=404,
-                detail={
-                    "error": "model_not_found",
-                    "message": f"Model '{model}' not found",
-                    "type": "invalid_request_error"
-                }
-            )
-
+            raise HTTPException(status_code=404, detail={"error": "model_not_found", "message": f"Model '{model}' not found", "type": "invalid_request_error"})

         # Return the specific model
         return models[model]
     except HTTPException:

@@ -512,6 +451,7 @@ async def retrieve_model(model: str):
         },
     )

+
 @router.get("/audio/voices")
 async def list_voices():
     """List all available voices for text-to-speech"""

@@ -579,9 +519,7 @@ async def combine_voices(request: Union[str, List[str]]):
         available_voices = await tts_service.list_voices()
         for voice in voices:
             if voice not in available_voices:
-                raise ValueError(
-                    f"Base voice '{voice}' not found. Available voices: {', '.join(sorted(available_voices))}"
-                )
+                raise ValueError(f"Base voice '{voice}' not found. Available voices: {', '.join(sorted(available_voices))}")

         # Combine voices
         combined_tensor = await tts_service.combine_voices(voices=voices)

@@ -115,6 +115,7 @@ class AudioService:
         audio_chunk: AudioChunk,
         sample_rate: int,
         output_format: str,
+        writer: StreamingAudioWriter,
         speed: float = 1,
         chunk_text: str = "",
         is_first_chunk: bool = True,

@@ -153,28 +154,30 @@ class AudioService:
                 audio_chunk = AudioService.trim_audio(audio_chunk,chunk_text,speed,is_last_chunk,normalizer)

-            # Get or create format-specific writer
-            writer_key = f"{output_format}_{sample_rate}"
+            """writer_key = f"{output_format}_{sample_rate}"
             if is_first_chunk or writer_key not in AudioService._writers:
                 AudioService._writers[writer_key] = StreamingAudioWriter(
                     output_format, sample_rate
                 )
-            writer = AudioService._writers[writer_key]
+            writer = AudioService._writers[writer_key]"""

             # Write audio data first
             if len(audio_chunk.audio) > 0:
                 chunk_data = writer.write_chunk(audio_chunk.audio)
+
+
             # Then finalize if this is the last chunk
             if is_last_chunk:
                 final_data = writer.write_chunk(finalize=True)
-                del AudioService._writers[writer_key]
+
                 if final_data:
-                    audio_chunk.output=final_data
+                    audio_chunk.output = final_data
                 return audio_chunk
+
             if chunk_data:
-                audio_chunk.output=chunk_data
+                audio_chunk.output = chunk_data
             return audio_chunk

         except Exception as e:

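This hunk is the heart of the fix: previously, convert_audio pulled its StreamingAudioWriter from a class-level cache (AudioService._writers) keyed only by format and sample rate, so concurrent requests for the same format shared one encoder. The toy below (ToyWriter and old_convert are invented stand-ins; the real encoder is C-backed) shows how that sharing interleaves two requests' audio, which is the mismatch Test Threads.py listens for; with a native encoder, concurrent use can also crash outright:

# Toy reproduction of the shared-writer race. All names here are invented;
# the real cache was AudioService._writers keyed by f"{output_format}_{sample_rate}".
import threading

shared_writers = {}  # old pattern: one writer per format, shared app-wide

class ToyWriter:
    def __init__(self):
        self.chunks = []
    def write_chunk(self, request_id):
        self.chunks.append(request_id)

def old_convert(request_id):
    # Both threads resolve to the same "wav_24000" writer.
    writer = shared_writers.setdefault("wav_24000", ToyWriter())
    for _ in range(1000):
        writer.write_chunk(request_id)

threads = [threading.Thread(target=old_convert, args=(i,)) for i in range(2)]
for t in threads:
    t.start()
for t in threads:
    t.join()

# The single writer now holds an interleaving of request 0 and request 1.
print(shared_writers["wav_24000"].chunks[:12])

Passing a writer per request, as the diff now does, makes encoder state request-local and removes both the mismatched-audio symptom and the crash window.
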
@@ -8,6 +8,7 @@ import time
 from typing import AsyncGenerator, List, Optional, Tuple, Union

 import numpy as np
+from .streaming_audio_writer import StreamingAudioWriter
 import torch
 from kokoro import KPipeline
 from loguru import logger

@@ -50,6 +51,7 @@ class TTSService:
         voice_name: str,
         voice_path: str,
         speed: float,
+        writer: StreamingAudioWriter,
         output_format: Optional[str] = None,
         is_first: bool = False,
         is_last: bool = False,

@@ -64,12 +66,13 @@ class TTSService:
             if is_last:
                 # Skip format conversion for raw audio mode
                 if not output_format:
-                    yield AudioChunk(np.array([], dtype=np.int16),output=b'')
+                    yield AudioChunk(np.array([], dtype=np.int16), output=b"")
                     return
                 chunk_data = await AudioService.convert_audio(
                     AudioChunk(np.array([], dtype=np.float32)),  # Dummy data for type checking
                     24000,
                     output_format,
+                    writer,
                     speed,
                     "",
                     is_first_chunk=False,

@@ -88,7 +91,7 @@ class TTSService:

             # Generate audio using pre-warmed model
             if isinstance(backend, KokoroV1):
-                chunk_index=0
+                chunk_index = 0
                 # For Kokoro V1, pass text and voice info with lang_code
                 async for chunk_data in self.model_manager.generate(
                     chunk_text,

@@ -104,6 +107,7 @@ class TTSService:
                             chunk_data,
                             24000,
                             output_format,
+                            writer,
                             speed,
                             chunk_text,
                             is_first_chunk=is_first and chunk_index == 0,

@@ -114,22 +118,13 @@ class TTSService:
                     except Exception as e:
                         logger.error(f"Failed to convert audio: {str(e)}")
                     else:
-                        chunk_data = AudioService.trim_audio(chunk_data,
-                            chunk_text,
-                            speed,
-                            is_last,
-                            normalizer)
+                        chunk_data = AudioService.trim_audio(chunk_data, chunk_text, speed, is_last, normalizer)
                         yield chunk_data
-                    chunk_index+=1
+                    chunk_index += 1
             else:
-
                 # For legacy backends, load voice tensor
-                voice_tensor = await self._voice_manager.load_voice(
-                    voice_name, device=backend.device
-                )
-                chunk_data = await self.model_manager.generate(
-                    tokens, voice_tensor, speed=speed, return_timestamps=return_timestamps
-                )
+                voice_tensor = await self._voice_manager.load_voice(voice_name, device=backend.device)
+                chunk_data = await self.model_manager.generate(tokens, voice_tensor, speed=speed, return_timestamps=return_timestamps)

                 if chunk_data.audio is None:
                     logger.error("Model generated None for audio chunk")

@@ -146,6 +141,7 @@ class TTSService:
                     chunk_data,
                     24000,
                     output_format,
+                    writer,
                     speed,
                     chunk_text,
                     is_first_chunk=is_first,

@@ -156,11 +152,7 @@ class TTSService:
             except Exception as e:
                 logger.error(f"Failed to convert audio: {str(e)}")
             else:
-                trimmed = AudioService.trim_audio(chunk_data,
-                    chunk_text,
-                    speed,
-                    is_last,
-                    normalizer)
+                trimmed = AudioService.trim_audio(chunk_data, chunk_text, speed, is_last, normalizer)
                 yield trimmed
         except Exception as e:
             logger.error(f"Failed to process tokens: {str(e)}")

@@ -169,7 +161,7 @@ class TTSService:
         # Check if the path is None and raise a ValueError if it is
         if not path:
             raise ValueError(f"Voice not found at path: {path}")
-
+
         logger.debug(f"Loading voice tensor from path: {path}")
         return torch.load(path, map_location="cpu") * weight

@@ -191,10 +183,8 @@ class TTSService:

         # If it is only one voice there is no point in loading it up, doing nothing with it, then saving it
         if len(split_voice) == 1:
-
             # Since it's a single voice the only time that the weight would matter is if voice_weight_normalization is off
             if ("(" not in voice and ")" not in voice) or settings.voice_weight_normalization == True:
-
                 path = await self._voice_manager.get_voice_path(voice)
                 if not path:
                     raise RuntimeError(f"Voice not found: {voice}")

@@ -202,8 +192,8 @@ class TTSService:
                 return voice, path

         total_weight = 0
-
-        for voice_index in range(0,len(split_voice),2):
+
+        for voice_index in range(0, len(split_voice), 2):
             voice_object = split_voice[voice_index]

             if "(" in voice_object and ")" in voice_object:

@@ -215,7 +205,7 @@ class TTSService:

             total_weight += voice_weight
             split_voice[voice_index] = (voice_name, voice_weight)
-
+
         # If voice_weight_normalization is false prevent normalizing the weights by setting the total_weight to 1 so it divides each weight by 1
         if settings.voice_weight_normalization == False:
             total_weight = 1

@@ -225,9 +215,9 @@ class TTSService:
         combined_tensor = await self._load_voice_from_path(path, split_voice[0][1] / total_weight)

         # Loop through each + or - in split_voice so they can be applied to combined voice
-        for operation_index in range(1,len(split_voice) - 1, 2):
+        for operation_index in range(1, len(split_voice) - 1, 2):
             # Get the voice path of the voice 1 index ahead of the operator
-            path = await self._voice_manager.get_voice_path(split_voice[operation_index+1][0])
+            path = await self._voice_manager.get_voice_path(split_voice[operation_index + 1][0])
             voice_tensor = await self._load_voice_from_path(path, split_voice[operation_index + 1][1] / total_weight)

             # Either add or subtract the voice from the current combined voice

@@ -235,7 +225,7 @@ class TTSService:
                 combined_tensor += voice_tensor
             else:
                 combined_tensor -= voice_tensor
-
+
         # Save the new combined voice so it can be loaded later
         temp_dir = tempfile.gettempdir()
         combined_path = os.path.join(temp_dir, f"{voice}.pt")

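The arithmetic behind these hunks: each component voice tensor is scaled by weight / total_weight and the results are summed (or subtracted for the - operator); when voice_weight_normalization is off, total_weight is pinned to 1 so the raw weights apply. A NumPy illustration with made-up voice vectors (the real code operates on torch tensors loaded via _load_voice_from_path):

# Illustration of the weight normalization above; voice names and values are invented.
import numpy as np

voices = {"af_heart": np.ones(4), "af_sky": np.full(4, 3.0)}
weights = {"af_heart": 2.0, "af_sky": 1.0}   # e.g. the string "af_heart(2)+af_sky(1)"

total_weight = sum(weights.values())          # 3.0; forced to 1 when normalization is off
combined = sum(v * (weights[n] / total_weight) for n, v in voices.items())
print(combined)                               # each element: 1*(2/3) + 3*(1/3) = 1.666...
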
@@ -250,6 +240,7 @@ class TTSService:
         self,
         text: str,
         voice: str,
+        writer: StreamingAudioWriter,
         speed: float = 1.0,
         output_format: str = "wav",
         lang_code: Optional[str] = None,

@@ -259,7 +250,7 @@ class TTSService:
         """Generate and stream audio chunks."""
         stream_normalizer = AudioNormalizer()
         chunk_index = 0
-        current_offset=0.0
+        current_offset = 0.0
         try:
             # Get backend
             backend = self.model_manager.get_backend()

@@ -270,13 +261,10 @@ class TTSService:

             # Use provided lang_code or determine from voice name
             pipeline_lang_code = lang_code if lang_code else voice[:1].lower()
-            logger.info(
-                f"Using lang_code '{pipeline_lang_code}' for voice '{voice_name}' in audio stream"
-            )
-
+            logger.info(f"Using lang_code '{pipeline_lang_code}' for voice '{voice_name}' in audio stream")

             # Process text in chunks with smart splitting
-            async for chunk_text, tokens in smart_split(text,lang_code=lang_code,normalization_options=normalization_options):
+            async for chunk_text, tokens in smart_split(text, lang_code=lang_code, normalization_options=normalization_options):
                 try:
                     # Process audio for chunk
                     async for chunk_data in self._process_chunk(

@@ -285,6 +273,7 @@ class TTSService:
                         voice_name,  # Pass voice name
                         voice_path,  # Pass voice path
                         speed,
+                        writer,
                         output_format,
                         is_first=(chunk_index == 0),
                         is_last=False,  # We'll update the last chunk later

@@ -294,23 +283,19 @@ class TTSService:
                     ):
                         if chunk_data.word_timestamps is not None:
                             for timestamp in chunk_data.word_timestamps:
-                                timestamp.start_time+=current_offset
-                                timestamp.end_time+=current_offset
-
-                        current_offset+=len(chunk_data.audio) / 24000
-
+                                timestamp.start_time += current_offset
+                                timestamp.end_time += current_offset
+
+                        current_offset += len(chunk_data.audio) / 24000
+
                         if chunk_data.output is not None:
                             yield chunk_data
                         else:
-                            logger.warning(
-                                f"No audio generated for chunk: '{chunk_text[:100]}...'"
-                            )
+                            logger.warning(f"No audio generated for chunk: '{chunk_text[:100]}...'")
                         chunk_index += 1
                 except Exception as e:
-                    logger.error(
-                        f"Failed to process audio for chunk: '{chunk_text[:100]}...'. Error: {str(e)}"
-                    )
+                    logger.error(f"Failed to process audio for chunk: '{chunk_text[:100]}...'. Error: {str(e)}")
                     continue

             # Only finalize if we successfully processed at least one chunk

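One detail worth spelling out from this hunk: word timestamps arrive chunk-relative, so the loop shifts each chunk's timestamps by the running offset and then advances the offset by that chunk's duration (samples / 24000 Hz). A small illustration with plain numbers (chunk lengths are made up):

# Illustration of the running timestamp offset in the hunk above.
SAMPLE_RATE = 24000
chunk_lengths = [2 * SAMPLE_RATE, 1 * SAMPLE_RATE]  # a 2 s chunk, then a 1 s chunk

current_offset = 0.0
for n_samples in chunk_lengths:
    # each chunk-relative word timestamp would get += current_offset here
    start = current_offset
    current_offset += n_samples / SAMPLE_RATE
    print(f"chunk covers {start:.1f}s to {current_offset:.1f}s of the final audio")
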
@@ -323,6 +308,7 @@ class TTSService:
                     voice_name,
                     voice_path,
                     speed,
+                    writer,
                     output_format,
                     is_first=False,
                     is_last=True,  # Signal this is the last chunk

@@ -337,32 +323,30 @@ class TTSService:
         except Exception as e:
             logger.error(f"Error in phoneme audio generation: {str(e)}")
             raise e

     async def generate_audio(
         self,
         text: str,
         voice: str,
+        writer: StreamingAudioWriter,
         speed: float = 1.0,
         return_timestamps: bool = False,
         normalization_options: Optional[NormalizationOptions] = NormalizationOptions(),
         lang_code: Optional[str] = None,
     ) -> AudioChunk:
         """Generate complete audio for text using streaming internally."""
-        audio_data_chunks=[]
-
+        audio_data_chunks = []

         try:
-            async for audio_stream_data in self.generate_audio_stream(text,voice,speed=speed,normalization_options=normalization_options,return_timestamps=return_timestamps,lang_code=lang_code,output_format=None):
+            async for audio_stream_data in self.generate_audio_stream(text, voice, writer, speed=speed, normalization_options=normalization_options, return_timestamps=return_timestamps, lang_code=lang_code, output_format=None):
                 if len(audio_stream_data.audio) > 0:
                     audio_data_chunks.append(audio_stream_data)
-
-            combined_audio_data=AudioChunk.combine(audio_data_chunks)
+            combined_audio_data = AudioChunk.combine(audio_data_chunks)
             return combined_audio_data
         except Exception as e:
             logger.error(f"Error in audio generation: {str(e)}")
             raise

     async def combine_voices(self, voices: List[str]) -> torch.Tensor:
         """Combine multiple voices.

@@ -406,15 +390,11 @@ class TTSService:
             result = None
             # Use provided lang_code or determine from voice name
             pipeline_lang_code = lang_code if lang_code else voice[:1].lower()
-            logger.info(
-                f"Using lang_code '{pipeline_lang_code}' for voice '{voice_name}' in phoneme pipeline"
-            )
+            logger.info(f"Using lang_code '{pipeline_lang_code}' for voice '{voice_name}' in phoneme pipeline")

             try:
                 # Use backend's pipeline management
-                for r in backend._get_pipeline(
-                    pipeline_lang_code
-                ).generate_from_tokens(
+                for r in backend._get_pipeline(pipeline_lang_code).generate_from_tokens(
                     tokens=phonemes,  # Pass raw phonemes string
                     voice=voice_path,
                     speed=speed,

@@ -432,9 +412,7 @@ class TTSService:
                 processing_time = time.time() - start_time
                 return result.audio.numpy(), processing_time
             else:
-                raise ValueError(
-                    "Phoneme generation only supported with Kokoro V1 backend"
-                )
+                raise ValueError("Phoneme generation only supported with Kokoro V1 backend")

         except Exception as e:
             logger.error(f"Error in phoneme audio generation: {str(e)}")

@@ -7,7 +7,7 @@ import pytest

 from api.src.services.audio import AudioNormalizer, AudioService
 from api.src.inference.base import AudioChunk
-
+from api.src.services.streaming_audio_writer import StreamingAudioWriter
 @pytest.fixture(autouse=True)
 def mock_settings():
     """Mock settings for all tests"""

@@ -30,9 +30,11 @@ def sample_audio():
 async def test_convert_to_wav(sample_audio):
     """Test converting to WAV format"""
     audio_data, sample_rate = sample_audio
+
+    writer = StreamingAudioWriter("wav", sample_rate=24000)
     # Write and finalize in one step for WAV
     audio_chunk = await AudioService.convert_audio(
-        AudioChunk(audio_data), sample_rate, "wav", is_first_chunk=True, is_last_chunk=False
+        AudioChunk(audio_data), sample_rate, "wav", writer, is_first_chunk=True, is_last_chunk=False
     )
     assert isinstance(audio_chunk.output, bytes)
     assert isinstance(audio_chunk, AudioChunk)

@@ -46,8 +48,11 @@ async def test_convert_to_wav(sample_audio):
 async def test_convert_to_mp3(sample_audio):
     """Test converting to MP3 format"""
     audio_data, sample_rate = sample_audio
+
+    writer = StreamingAudioWriter("mp3", sample_rate=24000)
+
     audio_chunk = await AudioService.convert_audio(
-        AudioChunk(audio_data), sample_rate, "mp3"
+        AudioChunk(audio_data), sample_rate, "mp3", writer
     )
     assert isinstance(audio_chunk.output, bytes)
     assert isinstance(audio_chunk, AudioChunk)

@@ -60,8 +65,11 @@ async def test_convert_to_mp3(sample_audio):
 async def test_convert_to_opus(sample_audio):
     """Test converting to Opus format"""
     audio_data, sample_rate = sample_audio
+
+    writer = StreamingAudioWriter("opus", sample_rate=24000)
+
     audio_chunk = await AudioService.convert_audio(
-        AudioChunk(audio_data), sample_rate, "opus"
+        AudioChunk(audio_data), sample_rate, "opus", writer
     )
     assert isinstance(audio_chunk.output, bytes)
     assert isinstance(audio_chunk, AudioChunk)

@@ -74,8 +82,11 @@ async def test_convert_to_opus(sample_audio):
 async def test_convert_to_flac(sample_audio):
     """Test converting to FLAC format"""
     audio_data, sample_rate = sample_audio
+
+    writer = StreamingAudioWriter("flac", sample_rate=24000)
+
     audio_chunk = await AudioService.convert_audio(
-        AudioChunk(audio_data), sample_rate, "flac"
+        AudioChunk(audio_data), sample_rate, "flac", writer
     )
     assert isinstance(audio_chunk.output, bytes)
     assert isinstance(audio_chunk, AudioChunk)

@@ -88,8 +99,11 @@ async def test_convert_to_flac(sample_audio):
 async def test_convert_to_aac(sample_audio):
     """Test converting to M4A format"""
     audio_data, sample_rate = sample_audio
+
+    writer = StreamingAudioWriter("aac", sample_rate=24000)
+
     audio_chunk = await AudioService.convert_audio(
-        AudioChunk(audio_data), sample_rate, "aac"
+        AudioChunk(audio_data), sample_rate, "aac", writer
     )
     assert isinstance(audio_chunk.output, bytes)
     assert isinstance(audio_chunk, AudioChunk)

@@ -102,8 +116,11 @@ async def test_convert_to_aac(sample_audio):
 async def test_convert_to_pcm(sample_audio):
     """Test converting to PCM format"""
     audio_data, sample_rate = sample_audio
+
+    writer = StreamingAudioWriter("pcm", sample_rate=24000)
+
     audio_chunk = await AudioService.convert_audio(
-        AudioChunk(audio_data), sample_rate, "pcm"
+        AudioChunk(audio_data), sample_rate, "pcm", writer
     )
     assert isinstance(audio_chunk.output, bytes)
     assert isinstance(audio_chunk, AudioChunk)

@@ -115,19 +132,25 @@ async def test_convert_to_pcm(sample_audio):
 async def test_convert_to_invalid_format_raises_error(sample_audio):
     """Test that converting to an invalid format raises an error"""
     audio_data, sample_rate = sample_audio
+    with pytest.raises(ValueError, match="Unsupported format: invalid"):
+        StreamingAudioWriter("invalid", sample_rate=24000)
+
     with pytest.raises(ValueError, match="Format invalid not supported"):
-        await AudioService.convert_audio(audio_data, sample_rate, "invalid")
+        await AudioService.convert_audio(audio_data, sample_rate, "invalid", None)


 @pytest.mark.asyncio
 async def test_normalization_wav(sample_audio):
     """Test that WAV output is properly normalized to int16 range"""
     audio_data, sample_rate = sample_audio
+
+    writer = StreamingAudioWriter("wav", sample_rate=24000)
+
     # Create audio data outside int16 range
     large_audio = audio_data * 1e5
     # Write and finalize in one step for WAV
     audio_chunk = await AudioService.convert_audio(
-        AudioChunk(large_audio), sample_rate, "wav", is_first_chunk=True
+        AudioChunk(large_audio), sample_rate, "wav", writer, is_first_chunk=True
     )
     assert isinstance(audio_chunk.output, bytes)
     assert isinstance(audio_chunk, AudioChunk)

@@ -138,10 +161,13 @@ async def test_normalization_wav(sample_audio):
 async def test_normalization_pcm(sample_audio):
     """Test that PCM output is properly normalized to int16 range"""
     audio_data, sample_rate = sample_audio
+
+    writer = StreamingAudioWriter("pcm", sample_rate=24000)
+
     # Create audio data outside int16 range
     large_audio = audio_data * 1e5
     audio_chunk = await AudioService.convert_audio(
-        AudioChunk(large_audio), sample_rate, "pcm"
+        AudioChunk(large_audio), sample_rate, "pcm", writer
     )
     assert isinstance(audio_chunk.output, bytes)
     assert isinstance(audio_chunk, AudioChunk)

@@ -153,8 +179,11 @@ async def test_invalid_audio_data():
     """Test handling of invalid audio data"""
     invalid_audio = np.array([])  # Empty array
     sample_rate = 24000
+
+    writer = StreamingAudioWriter("wav", sample_rate=24000)
+
     with pytest.raises(ValueError):
-        await AudioService.convert_audio(invalid_audio, sample_rate, "wav")
+        await AudioService.convert_audio(invalid_audio, sample_rate, "wav", writer)


 @pytest.mark.asyncio

@@ -164,8 +193,11 @@ async def test_different_sample_rates(sample_audio):
     sample_rates = [8000, 16000, 44100, 48000]

     for rate in sample_rates:
+
+        writer = StreamingAudioWriter("wav", sample_rate=rate)
+
         audio_chunk = await AudioService.convert_audio(
-            AudioChunk(audio_data), rate, "wav", is_first_chunk=True
+            AudioChunk(audio_data), rate, "wav", writer, is_first_chunk=True
         )
         assert isinstance(audio_chunk.output, bytes)
         assert isinstance(audio_chunk, AudioChunk)

@@ -176,15 +208,21 @@ async def test_different_sample_rates(sample_audio):
 async def test_buffer_position_after_conversion(sample_audio):
     """Test that buffer position is reset after writing"""
     audio_data, sample_rate = sample_audio
+
+    writer = StreamingAudioWriter("wav", sample_rate=24000)
+
     # Write and finalize in one step for first conversion
     audio_chunk1 = await AudioService.convert_audio(
-        AudioChunk(audio_data), sample_rate, "wav", is_first_chunk=True, is_last_chunk=True
+        AudioChunk(audio_data), sample_rate, "wav", writer, is_first_chunk=True, is_last_chunk=True
     )
     assert isinstance(audio_chunk1.output, bytes)
     assert isinstance(audio_chunk1, AudioChunk)
     # Convert again to ensure buffer was properly reset
+
+    writer = StreamingAudioWriter("wav", sample_rate=24000)
+
     audio_chunk2 = await AudioService.convert_audio(
-        AudioChunk(audio_data), sample_rate, "wav", is_first_chunk=True, is_last_chunk=True
+        AudioChunk(audio_data), sample_rate, "wav", writer, is_first_chunk=True, is_last_chunk=True
     )
     assert isinstance(audio_chunk2.output, bytes)
     assert isinstance(audio_chunk2, AudioChunk)