diff --git a/Test Threads.py b/Test Threads.py new file mode 100644 index 0000000..1b5f6c1 --- /dev/null +++ b/Test Threads.py @@ -0,0 +1,283 @@ +#!/usr/bin/env python3 +# Compatible with both Windows and Linux +""" +Kokoro TTS Race Condition Test + +This script creates multiple concurrent requests to a Kokoro TTS service +to reproduce a race condition where audio outputs don't match the requested text. +Each thread generates a simple numbered sentence, which should make mismatches +easy to identify through listening. + +To run: +python kokoro_race_condition_test.py --threads 8 --iterations 5 --url http://localhost:8880 +""" + +import argparse +import base64 +import concurrent.futures +import json +import os +import requests +import time +import wave +import sys +from pathlib import Path + + +def setup_args(): + """Parse command line arguments""" + parser = argparse.ArgumentParser(description="Test Kokoro TTS for race conditions") + parser.add_argument("--url", default="http://localhost:8880", + help="Base URL of the Kokoro TTS service") + parser.add_argument("--threads", type=int, default=8, + help="Number of concurrent threads to use") + parser.add_argument("--iterations", type=int, default=5, + help="Number of iterations per thread") + parser.add_argument("--voice", default="af_heart", + help="Voice to use for TTS") + parser.add_argument("--output-dir", default="./tts_test_output", + help="Directory to save output files") + parser.add_argument("--debug", action="store_true", + help="Enable debug logging") + return parser.parse_args() + + +def generate_test_sentence(thread_id, iteration): + """Generate a simple test sentence with numbers to make mismatches easily identifiable""" + return f"This is test sentence number {thread_id}-{iteration}. " \ + f"If you hear this sentence, you should hear the numbers {thread_id}-{iteration}." + + +def log_message(message, debug=False, is_error=False): + """Log messages with timestamps""" + timestamp = time.strftime("%Y-%m-%d %H:%M:%S") + prefix = "[ERROR]" if is_error else "[INFO]" + if is_error or debug: + print(f"{prefix} {timestamp} - {message}") + sys.stdout.flush() # Ensure logs are visible in Docker output + + +def request_tts(url, test_id, text, voice, output_dir, debug=False): + """Request TTS from the Kokoro API and save the WAV output""" + start_time = time.time() + output_file = os.path.join(output_dir, f"test_{test_id}.wav") + text_file = os.path.join(output_dir, f"test_{test_id}.txt") + + # Log output paths for debugging + log_message(f"Thread {test_id}: Text will be saved to: {text_file}", debug) + log_message(f"Thread {test_id}: Audio will be saved to: {output_file}", debug) + + # Save the text for later comparison + try: + with open(text_file, "w") as f: + f.write(text) + log_message(f"Thread {test_id}: Successfully saved text file", debug) + except Exception as e: + log_message(f"Thread {test_id}: Error saving text file: {str(e)}", debug, is_error=True) + + # Make the TTS request + try: + log_message(f"Thread {test_id}: Requesting TTS for: '{text}'", debug) + + response = requests.post( + f"{url}/v1/audio/speech", + json={ + "model": "kokoro", + "input": text, + "voice": voice, + "response_format": "wav" + }, + headers={"Accept": "audio/wav"}, + timeout=60 # Increase timeout to 60 seconds + ) + + log_message(f"Thread {test_id}: Response status code: {response.status_code}", debug) + log_message(f"Thread {test_id}: Response content type: {response.headers.get('Content-Type', 'None')}", debug) + log_message(f"Thread {test_id}: Response content length: {len(response.content)} bytes", debug) + + if response.status_code != 200: + log_message(f"Thread {test_id}: API error: {response.status_code} - {response.text}", debug, is_error=True) + return False + + # Check if we got valid audio data + if len(response.content) < 100: # Sanity check - WAV files should be larger than this + log_message(f"Thread {test_id}: Received suspiciously small audio data: {len(response.content)} bytes", debug, is_error=True) + log_message(f"Thread {test_id}: Content (base64): {base64.b64encode(response.content).decode('utf-8')}", debug, is_error=True) + return False + + # Save the audio output with explicit error handling + try: + with open(output_file, "wb") as f: + bytes_written = f.write(response.content) + log_message(f"Thread {test_id}: Wrote {bytes_written} bytes to {output_file}", debug) + + # Verify the WAV file exists and has content + if os.path.exists(output_file): + file_size = os.path.getsize(output_file) + log_message(f"Thread {test_id}: Verified file exists with size: {file_size} bytes", debug) + + # Validate WAV file by reading its headers + try: + with wave.open(output_file, 'rb') as wav_file: + channels = wav_file.getnchannels() + sample_width = wav_file.getsampwidth() + framerate = wav_file.getframerate() + frames = wav_file.getnframes() + log_message(f"Thread {test_id}: Valid WAV file - channels: {channels}, " + f"sample width: {sample_width}, framerate: {framerate}, frames: {frames}", debug) + except Exception as wav_error: + log_message(f"Thread {test_id}: Invalid WAV file: {str(wav_error)}", debug, is_error=True) + else: + log_message(f"Thread {test_id}: File was not created: {output_file}", debug, is_error=True) + except Exception as save_error: + log_message(f"Thread {test_id}: Error saving audio file: {str(save_error)}", debug, is_error=True) + return False + + end_time = time.time() + log_message(f"Thread {test_id}: Saved output to {output_file} (time: {end_time - start_time:.2f}s)", debug) + return True + + except requests.exceptions.Timeout: + log_message(f"Thread {test_id}: Request timed out", debug, is_error=True) + return False + except Exception as e: + log_message(f"Thread {test_id}: Exception: {str(e)}", debug, is_error=True) + return False + + +def worker_task(thread_id, args): + """Worker task for each thread""" + for i in range(args.iterations): + iteration = i + 1 + test_id = f"{thread_id:02d}_{iteration:02d}" + text = generate_test_sentence(thread_id, iteration) + success = request_tts(args.url, test_id, text, args.voice, args.output_dir, args.debug) + + if not success: + log_message(f"Thread {thread_id}: Iteration {iteration} failed", args.debug, is_error=True) + + # Small delay between iterations to avoid overwhelming the API + time.sleep(0.1) + + +def run_test(args): + """Run the test with the specified parameters""" + # Ensure output directory exists and check permissions + os.makedirs(args.output_dir, exist_ok=True) + + # Test write access to the output directory + test_file = os.path.join(args.output_dir, "write_test.txt") + try: + with open(test_file, "w") as f: + f.write("Testing write access\n") + os.remove(test_file) + log_message(f"Successfully verified write access to output directory: {args.output_dir}") + except Exception as e: + log_message(f"Warning: Cannot write to output directory {args.output_dir}: {str(e)}", is_error=True) + log_message(f"Current directory: {os.getcwd()}", is_error=True) + log_message(f"Directory contents: {os.listdir('.')}", is_error=True) + + # Test connection to Kokoro TTS service + try: + response = requests.get(f"{args.url}/health", timeout=5) + if response.status_code == 200: + log_message(f"Successfully connected to Kokoro TTS service at {args.url}") + else: + log_message(f"Warning: Kokoro TTS service health check returned status {response.status_code}", is_error=True) + except Exception as e: + log_message(f"Warning: Cannot connect to Kokoro TTS service at {args.url}: {str(e)}", is_error=True) + + # Record start time + start_time = time.time() + log_message(f"Starting test with {args.threads} threads, {args.iterations} iterations per thread") + + # Create and start worker threads + with concurrent.futures.ThreadPoolExecutor(max_workers=args.threads) as executor: + futures = [] + for thread_id in range(1, args.threads + 1): + futures.append(executor.submit(worker_task, thread_id, args)) + + # Wait for all tasks to complete + for future in concurrent.futures.as_completed(futures): + try: + future.result() + except Exception as e: + log_message(f"Thread execution failed: {str(e)}", args.debug, is_error=True) + + # Record end time and print summary + end_time = time.time() + total_time = end_time - start_time + total_requests = args.threads * args.iterations + log_message(f"Test completed in {total_time:.2f} seconds") + log_message(f"Total requests: {total_requests}") + log_message(f"Average time per request: {total_time / total_requests:.2f} seconds") + log_message(f"Requests per second: {total_requests / total_time:.2f}") + log_message(f"Output files saved to: {os.path.abspath(args.output_dir)}") + log_message("To verify, listen to the audio files and check if they match the text files") + log_message("If you hear audio describing a different test number than the filename, you've found a race condition") + + +def analyze_audio_files(output_dir): + """Provide summary of the generated audio files""" + # Look for both WAV and TXT files + wav_files = list(Path(output_dir).glob("*.wav")) + txt_files = list(Path(output_dir).glob("*.txt")) + + log_message(f"Found {len(wav_files)} WAV files and {len(txt_files)} TXT files") + + if len(wav_files) == 0: + log_message("No WAV files found! This indicates the TTS service requests may be failing.", is_error=True) + log_message("Check the connection to the TTS service and the response status codes above.", is_error=True) + + file_stats = [] + for wav_path in wav_files: + try: + with wave.open(str(wav_path), 'rb') as wav_file: + frames = wav_file.getnframes() + rate = wav_file.getframerate() + duration = frames / rate + + # Get corresponding text + text_path = wav_path.with_suffix('.txt') + if text_path.exists(): + with open(text_path, 'r') as text_file: + text = text_file.read().strip() + else: + text = "N/A" + + file_stats.append({ + 'filename': wav_path.name, + 'duration': duration, + 'text': text + }) + except Exception as e: + log_message(f"Error analyzing {wav_path}: {str(e)}", False, is_error=True) + + # Print summary table + if file_stats: + log_message("\nAudio File Summary:") + log_message(f"{'Filename':<20}{'Duration':<12}{'Text':<60}") + log_message("-" * 92) + for stat in file_stats: + log_message(f"{stat['filename']:<20}{stat['duration']:<12.2f}{stat['text'][:57]+'...' if len(stat['text']) > 60 else stat['text']:<60}") + + # List missing WAV files where text files exist + missing_wavs = set(p.stem for p in txt_files) - set(p.stem for p in wav_files) + if missing_wavs: + log_message(f"\nFound {len(missing_wavs)} text files without corresponding WAV files:", is_error=True) + for stem in sorted(list(missing_wavs))[:10]: # Limit to 10 for readability + log_message(f" - {stem}.txt (no WAV file)", is_error=True) + if len(missing_wavs) > 10: + log_message(f" ... and {len(missing_wavs) - 10} more", is_error=True) + + +if __name__ == "__main__": + args = setup_args() + run_test(args) + analyze_audio_files(args.output_dir) + + log_message("\nNext Steps:") + log_message("1. Listen to the generated audio files") + log_message("2. Verify if each audio correctly says its ID number") + log_message("3. Check for any mismatches between the audio content and the text files") + log_message("4. If mismatches are found, you've successfully reproduced the race condition") \ No newline at end of file diff --git a/api/src/routers/development.py b/api/src/routers/development.py index ec5596e..6a0de60 100644 --- a/api/src/routers/development.py +++ b/api/src/routers/development.py @@ -120,7 +120,7 @@ async def generate_from_phonemes( except Exception as e: logger.error(f"Error in audio generation: {str(e)}") # Clean up writer on error - writer.write_chunk(finalize=True) + writer.close() # Re-raise the original exception raise @@ -180,10 +180,11 @@ async def create_captioned_speech( "pcm": "audio/pcm", }.get(request.response_format, f"audio/{request.response_format}") + writer = StreamingAudioWriter(request.response_format, sample_rate=24000) # Check if streaming is requested (default for OpenAI client) if request.stream: # Create generator but don't start it yet - generator = stream_audio_chunks(tts_service, request, client_request) + generator = stream_audio_chunks(tts_service, request, client_request, writer) # If download link requested, wrap generator with temp file writer if request.return_download_link: @@ -235,6 +236,7 @@ async def create_captioned_speech( # Ensure temp writer is closed if not temp_writer._finalized: await temp_writer.__aexit__(None, None, None) + writer.close() # Stream with temp file writing return JSONStreamingResponse( @@ -266,6 +268,7 @@ async def create_captioned_speech( except Exception as e: logger.error(f"Error in single output streaming: {e}") + writer.close() raise # Standard streaming without download link @@ -284,6 +287,7 @@ async def create_captioned_speech( audio_data = await tts_service.generate_audio( text=request.input, voice=voice_name, + writer=writer, speed=request.speed, return_timestamps=request.return_timestamps, normalization_options=request.normalization_options, @@ -292,9 +296,8 @@ async def create_captioned_speech( audio_data = await AudioService.convert_audio( audio_data, - 24000, request.response_format, - is_first_chunk=True, + writer, is_last_chunk=False, trim_audio=False, ) @@ -302,9 +305,8 @@ async def create_captioned_speech( # Convert to requested format with proper finalization final = await AudioService.convert_audio( AudioChunk(np.array([], dtype=np.int16)), - 24000, request.response_format, - is_first_chunk=False, + writer, is_last_chunk=True, ) output=audio_data.output + final.output @@ -312,6 +314,9 @@ async def create_captioned_speech( base64_output= base64.b64encode(output).decode("utf-8") content=CaptionedSpeechResponse(audio=base64_output,audio_format=content_type,timestamps=audio_data.word_timestamps).model_dump() + + writer.close() + return JSONResponse( content=content, media_type="application/json", @@ -324,6 +329,12 @@ async def create_captioned_speech( except ValueError as e: # Handle validation errors logger.warning(f"Invalid request: {str(e)}") + + try: + writer.close() + except: + pass + raise HTTPException( status_code=400, detail={ @@ -335,6 +346,12 @@ async def create_captioned_speech( except RuntimeError as e: # Handle runtime/processing errors logger.error(f"Processing error: {str(e)}") + + try: + writer.close() + except: + pass + raise HTTPException( status_code=500, detail={ @@ -346,6 +363,12 @@ async def create_captioned_speech( except Exception as e: # Handle unexpected errors logger.error(f"Unexpected error in captioned speech generation: {str(e)}") + + try: + writer.close() + except: + pass + raise HTTPException( status_code=500, detail={ diff --git a/api/src/routers/openai_compatible.py b/api/src/routers/openai_compatible.py index 10e32bf..742c216 100644 --- a/api/src/routers/openai_compatible.py +++ b/api/src/routers/openai_compatible.py @@ -10,11 +10,12 @@ from urllib import response import aiofiles import numpy as np +from ..services.streaming_audio_writer import StreamingAudioWriter import torch from fastapi import APIRouter, Depends, Header, HTTPException, Request, Response from fastapi.responses import FileResponse, StreamingResponse from loguru import logger -from structures.schemas import CaptionedSpeechRequest +from ..structures.schemas import CaptionedSpeechRequest from ..core.config import settings from ..inference.base import AudioChunk @@ -79,9 +80,7 @@ def get_model_name(model: str) -> str: return base_name + ".pth" -async def process_and_validate_voices( - voice_input: Union[str, List[str]], tts_service: TTSService -) -> str: +async def process_and_validate_voices(voice_input: Union[str, List[str]], tts_service: TTSService) -> str: """Process voice input, handling both string and list formats Returns: @@ -90,73 +89,59 @@ async def process_and_validate_voices( voices = [] # Convert input to list of voices if isinstance(voice_input, str): - voice_input=voice_input.replace(" ","").strip() + voice_input = voice_input.replace(" ", "").strip() if voice_input[-1] in "+-" or voice_input[0] in "+-": - raise ValueError( - f"Voice combination contains empty combine items" - ) + raise ValueError(f"Voice combination contains empty combine items") if re.search(r"[+-]{2,}", voice_input) is not None: - raise ValueError( - f"Voice combination contains empty combine items" - ) + raise ValueError(f"Voice combination contains empty combine items") voices = re.split(r"([-+])", voice_input) else: - voices = [[item,"+"] for item in voice_input][:-1] - + voices = [[item, "+"] for item in voice_input][:-1] + available_voices = await tts_service.list_voices() - for voice_index in range(0,len(voices), 2): - + for voice_index in range(0, len(voices), 2): mapped_voice = voices[voice_index].split("(") mapped_voice = list(map(str.strip, mapped_voice)) if len(mapped_voice) > 2: - raise ValueError( - f"Voice '{voices[voice_index]}' contains too many weight items" - ) - + raise ValueError(f"Voice '{voices[voice_index]}' contains too many weight items") + if mapped_voice.count(")") > 1: - raise ValueError( - f"Voice '{voices[voice_index]}' contains too many weight items" - ) - + raise ValueError(f"Voice '{voices[voice_index]}' contains too many weight items") + mapped_voice[0] = _openai_mappings["voices"].get(mapped_voice[0], mapped_voice[0]) if mapped_voice[0] not in available_voices: - raise ValueError( - f"Voice '{mapped_voice[0]}' not found. Available voices: {', '.join(sorted(available_voices))}" - ) + raise ValueError(f"Voice '{mapped_voice[0]}' not found. Available voices: {', '.join(sorted(available_voices))}") voices[voice_index] = "(".join(mapped_voice) - + return "".join(voices) -async def stream_audio_chunks( - tts_service: TTSService, request: Union[OpenAISpeechRequest,CaptionedSpeechRequest], client_request: Request -) -> AsyncGenerator[AudioChunk, None]: + +async def stream_audio_chunks(tts_service: TTSService, request: Union[OpenAISpeechRequest, CaptionedSpeechRequest], client_request: Request, writer: StreamingAudioWriter) -> AsyncGenerator[AudioChunk, None]: """Stream audio chunks as they're generated with client disconnect handling""" voice_name = await process_and_validate_voices(request.voice, tts_service) - unique_properties={ - "return_timestamps":False - } + unique_properties = {"return_timestamps": False} if hasattr(request, "return_timestamps"): - unique_properties["return_timestamps"]=request.return_timestamps - + unique_properties["return_timestamps"] = request.return_timestamps + try: logger.info(f"Starting audio generation with lang_code: {request.lang_code}") async for chunk_data in tts_service.generate_audio_stream( text=request.input, voice=voice_name, + writer=writer, speed=request.speed, output_format=request.response_format, lang_code=request.lang_code or settings.default_voice_code or voice_name[0].lower(), normalization_options=request.normalization_options, return_timestamps=unique_properties["return_timestamps"], ): - # Check if client is still connected is_disconnected = client_request.is_disconnected if callable(is_disconnected): @@ -174,7 +159,6 @@ async def stream_audio_chunks( @router.post("/audio/speech") async def create_speech( - request: OpenAISpeechRequest, client_request: Request, x_raw_response: str = Header(None, alias="x-raw-response"), @@ -206,10 +190,12 @@ async def create_speech( "pcm": "audio/pcm", }.get(request.response_format, f"audio/{request.response_format}") + writer = StreamingAudioWriter(request.response_format, sample_rate=24000) + # Check if streaming is requested (default for OpenAI client) if request.stream: # Create generator but don't start it yet - generator = stream_audio_chunks(tts_service, request, client_request) + generator = stream_audio_chunks(tts_service, request, client_request, writer) # If download link requested, wrap generator with temp file writer if request.return_download_link: @@ -239,9 +225,9 @@ async def create_speech( async for chunk_data in generator: if chunk_data.output: # Skip empty chunks await temp_writer.write(chunk_data.output) - #if return_json: + # if return_json: # yield chunk, chunk_data - #else: + # else: yield chunk_data.output # Finalize the temp file @@ -254,11 +240,10 @@ async def create_speech( # Ensure temp writer is closed if not temp_writer._finalized: await temp_writer.__aexit__(None, None, None) + writer.close() # Stream with temp file writing - return StreamingResponse( - dual_output(), media_type=content_type, headers=headers - ) + return StreamingResponse(dual_output(), media_type=content_type, headers=headers) async def single_output(): try: @@ -268,8 +253,9 @@ async def create_speech( yield chunk_data.output except Exception as e: logger.error(f"Error in single output streaming: {e}") + writer.close() raise - + # Standard streaming without download link return StreamingResponse( single_output(), @@ -283,40 +269,34 @@ async def create_speech( ) else: headers = { - "Content-Disposition": f"attachment; filename=speech.{request.response_format}", - "Cache-Control": "no-cache", # Prevent caching - } + "Content-Disposition": f"attachment; filename=speech.{request.response_format}", + "Cache-Control": "no-cache", # Prevent caching + } # Generate complete audio using public interface audio_data = await tts_service.generate_audio( text=request.input, voice=voice_name, + writer=writer, speed=request.speed, normalization_options=request.normalization_options, lang_code=request.lang_code, ) - audio_data = await AudioService.convert_audio( - audio_data, - 24000, - request.response_format, - is_first_chunk=True, - is_last_chunk=False, - trim_audio=False - ) - + audio_data = await AudioService.convert_audio(audio_data, request.response_format, writer, is_last_chunk=False, trim_audio=False) + # Convert to requested format with proper finalization final = await AudioService.convert_audio( AudioChunk(np.array([], dtype=np.int16)), - 24000, request.response_format, - is_first_chunk=False, + writer, is_last_chunk=True, ) - output=audio_data.output + final.output + output = audio_data.output + final.output if request.return_download_link: from ..services.temp_manager import TempFileWriter + # Use download_format if specified, otherwise use response_format output_format = request.download_format or request.response_format temp_writer = TempFileWriter(output_format) @@ -341,6 +321,7 @@ async def create_speech( # Ensure temp writer is closed if not temp_writer._finalized: await temp_writer.__aexit__(None, None, None) + writer.close() return Response( content=output, @@ -351,6 +332,12 @@ async def create_speech( except ValueError as e: # Handle validation errors logger.warning(f"Invalid request: {str(e)}") + + try: + writer.close() + except: + pass + raise HTTPException( status_code=400, detail={ @@ -362,6 +349,12 @@ async def create_speech( except RuntimeError as e: # Handle runtime/processing errors logger.error(f"Processing error: {str(e)}") + + try: + writer.close() + except: + pass + raise HTTPException( status_code=500, detail={ @@ -373,6 +366,12 @@ async def create_speech( except Exception as e: # Handle unexpected errors logger.error(f"Unexpected error in speech generation: {str(e)}") + + try: + writer.close() + except: + pass + raise HTTPException( status_code=500, detail={ @@ -381,6 +380,7 @@ async def create_speech( "type": "server_error", }, ) + @router.get("/download/{filename}") @@ -390,9 +390,7 @@ async def download_audio_file(filename: str): from ..core.paths import _find_file, get_content_type # Search for file in temp directory - file_path = await _find_file( - filename=filename, search_paths=[settings.temp_file_dir] - ) + file_path = await _find_file(filename=filename, search_paths=[settings.temp_file_dir]) # Get content type from path helper content_type = await get_content_type(file_path) @@ -425,30 +423,12 @@ async def list_models(): try: # Create standard model list models = [ - { - "id": "tts-1", - "object": "model", - "created": 1686935002, - "owned_by": "kokoro" - }, - { - "id": "tts-1-hd", - "object": "model", - "created": 1686935002, - "owned_by": "kokoro" - }, - { - "id": "kokoro", - "object": "model", - "created": 1686935002, - "owned_by": "kokoro" - } + {"id": "tts-1", "object": "model", "created": 1686935002, "owned_by": "kokoro"}, + {"id": "tts-1-hd", "object": "model", "created": 1686935002, "owned_by": "kokoro"}, + {"id": "kokoro", "object": "model", "created": 1686935002, "owned_by": "kokoro"}, ] - - return { - "object": "list", - "data": models - } + + return {"object": "list", "data": models} except Exception as e: logger.error(f"Error listing models: {str(e)}") raise HTTPException( @@ -460,43 +440,22 @@ async def list_models(): }, ) + @router.get("/models/{model}") async def retrieve_model(model: str): """Retrieve a specific model""" try: # Define available models models = { - "tts-1": { - "id": "tts-1", - "object": "model", - "created": 1686935002, - "owned_by": "kokoro" - }, - "tts-1-hd": { - "id": "tts-1-hd", - "object": "model", - "created": 1686935002, - "owned_by": "kokoro" - }, - "kokoro": { - "id": "kokoro", - "object": "model", - "created": 1686935002, - "owned_by": "kokoro" - } + "tts-1": {"id": "tts-1", "object": "model", "created": 1686935002, "owned_by": "kokoro"}, + "tts-1-hd": {"id": "tts-1-hd", "object": "model", "created": 1686935002, "owned_by": "kokoro"}, + "kokoro": {"id": "kokoro", "object": "model", "created": 1686935002, "owned_by": "kokoro"}, } - + # Check if requested model exists if model not in models: - raise HTTPException( - status_code=404, - detail={ - "error": "model_not_found", - "message": f"Model '{model}' not found", - "type": "invalid_request_error" - } - ) - + raise HTTPException(status_code=404, detail={"error": "model_not_found", "message": f"Model '{model}' not found", "type": "invalid_request_error"}) + # Return the specific model return models[model] except HTTPException: @@ -512,6 +471,7 @@ async def retrieve_model(model: str): }, ) + @router.get("/audio/voices") async def list_voices(): """List all available voices for text-to-speech""" @@ -579,9 +539,7 @@ async def combine_voices(request: Union[str, List[str]]): available_voices = await tts_service.list_voices() for voice in voices: if voice not in available_voices: - raise ValueError( - f"Base voice '{voice}' not found. Available voices: {', '.join(sorted(available_voices))}" - ) + raise ValueError(f"Base voice '{voice}' not found. Available voices: {', '.join(sorted(available_voices))}") # Combine voices combined_tensor = await tts_service.combine_voices(voices=voices) diff --git a/api/src/services/audio.py b/api/src/services/audio.py index d1b412e..5e344ec 100644 --- a/api/src/services/audio.py +++ b/api/src/services/audio.py @@ -108,16 +108,13 @@ class AudioService: }, } - _writers = {} - @staticmethod async def convert_audio( audio_chunk: AudioChunk, - sample_rate: int, output_format: str, + writer: StreamingAudioWriter, speed: float = 1, chunk_text: str = "", - is_first_chunk: bool = True, is_last_chunk: bool = False, trim_audio: bool = True, normalizer: AudioNormalizer = None, @@ -126,12 +123,12 @@ class AudioService: Args: audio_data: Numpy array of audio samples - sample_rate: Sample rate of the audio output_format: Target format (wav, mp3, ogg, pcm) + writer: The StreamingAudioWriter to use speed: The speaking speed of the voice chunk_text: The text sent to the model to generate the resulting speech - is_first_chunk: Whether this is the first chunk is_last_chunk: Whether this is the last chunk + trim_audio: Whether audio should be trimmed normalizer: Optional AudioNormalizer instance for consistent normalization Returns: @@ -152,29 +149,22 @@ class AudioService: if trim_audio == True: audio_chunk = AudioService.trim_audio(audio_chunk,chunk_text,speed,is_last_chunk,normalizer) - # Get or create format-specific writer - writer_key = f"{output_format}_{sample_rate}" - if is_first_chunk or writer_key not in AudioService._writers: - AudioService._writers[writer_key] = StreamingAudioWriter( - output_format, sample_rate - ) - - writer = AudioService._writers[writer_key] - # Write audio data first if len(audio_chunk.audio) > 0: chunk_data = writer.write_chunk(audio_chunk.audio) + + # Then finalize if this is the last chunk if is_last_chunk: final_data = writer.write_chunk(finalize=True) - del AudioService._writers[writer_key] + if final_data: audio_chunk.output=final_data return audio_chunk - + if chunk_data: - audio_chunk.output=chunk_data + audio_chunk.output=chunk_data return audio_chunk except Exception as e: diff --git a/api/src/services/streaming_audio_writer.py b/api/src/services/streaming_audio_writer.py index 71dcd32..763c5eb 100644 --- a/api/src/services/streaming_audio_writer.py +++ b/api/src/services/streaming_audio_writer.py @@ -22,7 +22,7 @@ class StreamingAudioWriter: codec_map = {"wav":"pcm_s16le","mp3":"mp3","opus":"libopus","flac":"flac", "aac":"aac"} # Format-specific setup - if self.format in ["wav", "opus","flac","mp3","aac","pcm"]: + if self.format in ["wav","flac","mp3","pcm","aac","opus"]: if self.format != "pcm": self.output_buffer = BytesIO() self.container = av.open(self.output_buffer, mode="w", format=self.format) @@ -31,6 +31,13 @@ class StreamingAudioWriter: else: raise ValueError(f"Unsupported format: {format}") + def close(self): + if hasattr(self, "container"): + self.container.close() + + if hasattr(self, "output_buffer"): + self.output_buffer.close() + def write_chunk( self, audio_data: Optional[np.ndarray] = None, finalize: bool = False ) -> bytes: @@ -48,7 +55,7 @@ class StreamingAudioWriter: self.container.mux(packet) data=self.output_buffer.getvalue() - self.container.close() + self.close() return data if audio_data is None or len(audio_data) == 0: diff --git a/api/src/services/tts_service.py b/api/src/services/tts_service.py index cb9f655..f740a29 100644 --- a/api/src/services/tts_service.py +++ b/api/src/services/tts_service.py @@ -8,6 +8,7 @@ import time from typing import AsyncGenerator, List, Optional, Tuple, Union import numpy as np +from .streaming_audio_writer import StreamingAudioWriter import torch from kokoro import KPipeline from loguru import logger @@ -50,6 +51,7 @@ class TTSService: voice_name: str, voice_path: str, speed: float, + writer: StreamingAudioWriter, output_format: Optional[str] = None, is_first: bool = False, is_last: bool = False, @@ -64,15 +66,14 @@ class TTSService: if is_last: # Skip format conversion for raw audio mode if not output_format: - yield AudioChunk(np.array([], dtype=np.int16),output=b'') + yield AudioChunk(np.array([], dtype=np.int16), output=b"") return chunk_data = await AudioService.convert_audio( AudioChunk(np.array([], dtype=np.float32)), # Dummy data for type checking - 24000, output_format, + writer, speed, "", - is_first_chunk=False, normalizer=normalizer, is_last_chunk=True, ) @@ -88,7 +89,7 @@ class TTSService: # Generate audio using pre-warmed model if isinstance(backend, KokoroV1): - chunk_index=0 + chunk_index = 0 # For Kokoro V1, pass text and voice info with lang_code async for chunk_data in self.model_manager.generate( chunk_text, @@ -102,11 +103,10 @@ class TTSService: try: chunk_data = await AudioService.convert_audio( chunk_data, - 24000, output_format, + writer, speed, chunk_text, - is_first_chunk=is_first and chunk_index == 0, is_last_chunk=is_last, normalizer=normalizer, ) @@ -114,22 +114,13 @@ class TTSService: except Exception as e: logger.error(f"Failed to convert audio: {str(e)}") else: - chunk_data = AudioService.trim_audio(chunk_data, - chunk_text, - speed, - is_last, - normalizer) + chunk_data = AudioService.trim_audio(chunk_data, chunk_text, speed, is_last, normalizer) yield chunk_data - chunk_index+=1 + chunk_index += 1 else: - # For legacy backends, load voice tensor - voice_tensor = await self._voice_manager.load_voice( - voice_name, device=backend.device - ) - chunk_data = await self.model_manager.generate( - tokens, voice_tensor, speed=speed, return_timestamps=return_timestamps - ) + voice_tensor = await self._voice_manager.load_voice(voice_name, device=backend.device) + chunk_data = await self.model_manager.generate(tokens, voice_tensor, speed=speed, return_timestamps=return_timestamps) if chunk_data.audio is None: logger.error("Model generated None for audio chunk") @@ -144,11 +135,10 @@ class TTSService: try: chunk_data = await AudioService.convert_audio( chunk_data, - 24000, output_format, + writer, speed, chunk_text, - is_first_chunk=is_first, normalizer=normalizer, is_last_chunk=is_last, ) @@ -156,11 +146,7 @@ class TTSService: except Exception as e: logger.error(f"Failed to convert audio: {str(e)}") else: - trimmed = AudioService.trim_audio(chunk_data, - chunk_text, - speed, - is_last, - normalizer) + trimmed = AudioService.trim_audio(chunk_data, chunk_text, speed, is_last, normalizer) yield trimmed except Exception as e: logger.error(f"Failed to process tokens: {str(e)}") @@ -169,7 +155,7 @@ class TTSService: # Check if the path is None and raise a ValueError if it is not if not path: raise ValueError(f"Voice not found at path: {path}") - + logger.debug(f"Loading voice tensor from path: {path}") return torch.load(path, map_location="cpu") * weight @@ -191,10 +177,8 @@ class TTSService: # If it is only once voice there is no point in loading it up, doing nothing with it, then saving it if len(split_voice) == 1: - # Since its a single voice the only time that the weight would matter is if voice_weight_normalization is off if ("(" not in voice and ")" not in voice) or settings.voice_weight_normalization == True: - path = await self._voice_manager.get_voice_path(voice) if not path: raise RuntimeError(f"Voice not found: {voice}") @@ -202,8 +186,8 @@ class TTSService: return voice, path total_weight = 0 - - for voice_index in range(0,len(split_voice),2): + + for voice_index in range(0, len(split_voice), 2): voice_object = split_voice[voice_index] if "(" in voice_object and ")" in voice_object: @@ -215,7 +199,7 @@ class TTSService: total_weight += voice_weight split_voice[voice_index] = (voice_name, voice_weight) - + # If voice_weight_normalization is false prevent normalizing the weights by setting the total_weight to 1 so it divides each weight by 1 if settings.voice_weight_normalization == False: total_weight = 1 @@ -225,9 +209,9 @@ class TTSService: combined_tensor = await self._load_voice_from_path(path, split_voice[0][1] / total_weight) # Loop through each + or - in split_voice so they can be applied to combined voice - for operation_index in range(1,len(split_voice) - 1, 2): + for operation_index in range(1, len(split_voice) - 1, 2): # Get the voice path of the voice 1 index ahead of the operator - path = await self._voice_manager.get_voice_path(split_voice[operation_index+1][0]) + path = await self._voice_manager.get_voice_path(split_voice[operation_index + 1][0]) voice_tensor = await self._load_voice_from_path(path, split_voice[operation_index + 1][1] / total_weight) # Either add or subtract the voice from the current combined voice @@ -235,7 +219,7 @@ class TTSService: combined_tensor += voice_tensor else: combined_tensor -= voice_tensor - + # Save the new combined voice so it can be loaded latter temp_dir = tempfile.gettempdir() combined_path = os.path.join(temp_dir, f"{voice}.pt") @@ -250,6 +234,7 @@ class TTSService: self, text: str, voice: str, + writer: StreamingAudioWriter, speed: float = 1.0, output_format: str = "wav", lang_code: Optional[str] = None, @@ -259,7 +244,7 @@ class TTSService: """Generate and stream audio chunks.""" stream_normalizer = AudioNormalizer() chunk_index = 0 - current_offset=0.0 + current_offset = 0.0 try: # Get backend backend = self.model_manager.get_backend() @@ -270,13 +255,10 @@ class TTSService: # Use provided lang_code or determine from voice name pipeline_lang_code = lang_code if lang_code else voice[:1].lower() - logger.info( - f"Using lang_code '{pipeline_lang_code}' for voice '{voice_name}' in audio stream" - ) - - + logger.info(f"Using lang_code '{pipeline_lang_code}' for voice '{voice_name}' in audio stream") + # Process text in chunks with smart splitting - async for chunk_text, tokens in smart_split(text,lang_code=lang_code,normalization_options=normalization_options): + async for chunk_text, tokens in smart_split(text, lang_code=lang_code, normalization_options=normalization_options): try: # Process audio for chunk async for chunk_data in self._process_chunk( @@ -285,6 +267,7 @@ class TTSService: voice_name, # Pass voice name voice_path, # Pass voice path speed, + writer, output_format, is_first=(chunk_index == 0), is_last=False, # We'll update the last chunk later @@ -294,23 +277,19 @@ class TTSService: ): if chunk_data.word_timestamps is not None: for timestamp in chunk_data.word_timestamps: - timestamp.start_time+=current_offset - timestamp.end_time+=current_offset - - current_offset+=len(chunk_data.audio) / 24000 - + timestamp.start_time += current_offset + timestamp.end_time += current_offset + + current_offset += len(chunk_data.audio) / 24000 + if chunk_data.output is not None: yield chunk_data - + else: - logger.warning( - f"No audio generated for chunk: '{chunk_text[:100]}...'" - ) + logger.warning(f"No audio generated for chunk: '{chunk_text[:100]}...'") chunk_index += 1 except Exception as e: - logger.error( - f"Failed to process audio for chunk: '{chunk_text[:100]}...'. Error: {str(e)}" - ) + logger.error(f"Failed to process audio for chunk: '{chunk_text[:100]}...'. Error: {str(e)}") continue # Only finalize if we successfully processed at least one chunk @@ -323,6 +302,7 @@ class TTSService: voice_name, voice_path, speed, + writer, output_format, is_first=False, is_last=True, # Signal this is the last chunk @@ -337,32 +317,30 @@ class TTSService: except Exception as e: logger.error(f"Error in phoneme audio generation: {str(e)}") raise e - async def generate_audio( self, text: str, voice: str, + writer: StreamingAudioWriter, speed: float = 1.0, return_timestamps: bool = False, normalization_options: Optional[NormalizationOptions] = NormalizationOptions(), lang_code: Optional[str] = None, ) -> AudioChunk: """Generate complete audio for text using streaming internally.""" - audio_data_chunks=[] - + audio_data_chunks = [] + try: - async for audio_stream_data in self.generate_audio_stream(text,voice,speed=speed,normalization_options=normalization_options,return_timestamps=return_timestamps,lang_code=lang_code,output_format=None): + async for audio_stream_data in self.generate_audio_stream(text, voice, writer, speed=speed, normalization_options=normalization_options, return_timestamps=return_timestamps, lang_code=lang_code, output_format=None): if len(audio_stream_data.audio) > 0: audio_data_chunks.append(audio_stream_data) - - combined_audio_data=AudioChunk.combine(audio_data_chunks) + combined_audio_data = AudioChunk.combine(audio_data_chunks) return combined_audio_data except Exception as e: logger.error(f"Error in audio generation: {str(e)}") raise - async def combine_voices(self, voices: List[str]) -> torch.Tensor: """Combine multiple voices. @@ -406,15 +384,11 @@ class TTSService: result = None # Use provided lang_code or determine from voice name pipeline_lang_code = lang_code if lang_code else voice[:1].lower() - logger.info( - f"Using lang_code '{pipeline_lang_code}' for voice '{voice_name}' in phoneme pipeline" - ) + logger.info(f"Using lang_code '{pipeline_lang_code}' for voice '{voice_name}' in phoneme pipeline") try: # Use backend's pipeline management - for r in backend._get_pipeline( - pipeline_lang_code - ).generate_from_tokens( + for r in backend._get_pipeline(pipeline_lang_code).generate_from_tokens( tokens=phonemes, # Pass raw phonemes string voice=voice_path, speed=speed, @@ -432,9 +406,7 @@ class TTSService: processing_time = time.time() - start_time return result.audio.numpy(), processing_time else: - raise ValueError( - "Phoneme generation only supported with Kokoro V1 backend" - ) + raise ValueError("Phoneme generation only supported with Kokoro V1 backend") except Exception as e: logger.error(f"Error in phoneme audio generation: {str(e)}") diff --git a/api/tests/conftest.py b/api/tests/conftest.py index b8dd761..dee66f9 100644 --- a/api/tests/conftest.py +++ b/api/tests/conftest.py @@ -70,16 +70,3 @@ def test_voice(): """Return a test voice name.""" return "voice1" - -@pytest.fixture(scope="session") -def event_loop(): - """Create an instance of the default event loop for the test session.""" - import asyncio - - try: - loop = asyncio.get_event_loop() - except RuntimeError: - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - yield loop - loop.close() diff --git a/api/tests/test_audio_service.py b/api/tests/test_audio_service.py index ca2d25d..9351454 100644 --- a/api/tests/test_audio_service.py +++ b/api/tests/test_audio_service.py @@ -7,7 +7,7 @@ import pytest from api.src.services.audio import AudioNormalizer, AudioService from api.src.inference.base import AudioChunk - +from api.src.services.streaming_audio_writer import StreamingAudioWriter @pytest.fixture(autouse=True) def mock_settings(): """Mock settings for all tests""" @@ -30,10 +30,15 @@ def sample_audio(): async def test_convert_to_wav(sample_audio): """Test converting to WAV format""" audio_data, sample_rate = sample_audio + + writer = StreamingAudioWriter("wav", sample_rate=24000) # Write and finalize in one step for WAV audio_chunk = await AudioService.convert_audio( - AudioChunk(audio_data), sample_rate, "wav", is_first_chunk=True, is_last_chunk=False + AudioChunk(audio_data), "wav", writer, is_last_chunk=False ) + + writer.close() + assert isinstance(audio_chunk.output, bytes) assert isinstance(audio_chunk, AudioChunk) assert len(audio_chunk.output) > 0 @@ -46,9 +51,15 @@ async def test_convert_to_wav(sample_audio): async def test_convert_to_mp3(sample_audio): """Test converting to MP3 format""" audio_data, sample_rate = sample_audio + + writer = StreamingAudioWriter("mp3", sample_rate=24000) + audio_chunk = await AudioService.convert_audio( - AudioChunk(audio_data), sample_rate, "mp3" + AudioChunk(audio_data), "mp3", writer ) + + writer.close() + assert isinstance(audio_chunk.output, bytes) assert isinstance(audio_chunk, AudioChunk) assert len(audio_chunk.output) > 0 @@ -59,10 +70,17 @@ async def test_convert_to_mp3(sample_audio): @pytest.mark.asyncio async def test_convert_to_opus(sample_audio): """Test converting to Opus format""" + audio_data, sample_rate = sample_audio + + writer = StreamingAudioWriter("opus", sample_rate=24000) + audio_chunk = await AudioService.convert_audio( - AudioChunk(audio_data), sample_rate, "opus" + AudioChunk(audio_data), "opus",writer ) + + writer.close() + assert isinstance(audio_chunk.output, bytes) assert isinstance(audio_chunk, AudioChunk) assert len(audio_chunk.output) > 0 @@ -74,9 +92,15 @@ async def test_convert_to_opus(sample_audio): async def test_convert_to_flac(sample_audio): """Test converting to FLAC format""" audio_data, sample_rate = sample_audio + + writer = StreamingAudioWriter("flac", sample_rate=24000) + audio_chunk = await AudioService.convert_audio( - AudioChunk(audio_data), sample_rate, "flac" + AudioChunk(audio_data), "flac", writer ) + + writer.close() + assert isinstance(audio_chunk.output, bytes) assert isinstance(audio_chunk, AudioChunk) assert len(audio_chunk.output) > 0 @@ -88,9 +112,15 @@ async def test_convert_to_flac(sample_audio): async def test_convert_to_aac(sample_audio): """Test converting to M4A format""" audio_data, sample_rate = sample_audio + + writer = StreamingAudioWriter("aac", sample_rate=24000) + audio_chunk = await AudioService.convert_audio( - AudioChunk(audio_data), sample_rate, "aac" + AudioChunk(audio_data), "aac", writer ) + + writer.close() + assert isinstance(audio_chunk.output, bytes) assert isinstance(audio_chunk, AudioChunk) assert len(audio_chunk.output) > 0 @@ -102,9 +132,15 @@ async def test_convert_to_aac(sample_audio): async def test_convert_to_pcm(sample_audio): """Test converting to PCM format""" audio_data, sample_rate = sample_audio + + writer = StreamingAudioWriter("pcm", sample_rate=24000) + audio_chunk = await AudioService.convert_audio( - AudioChunk(audio_data), sample_rate, "pcm" + AudioChunk(audio_data), "pcm", writer ) + + writer.close() + assert isinstance(audio_chunk.output, bytes) assert isinstance(audio_chunk, AudioChunk) assert len(audio_chunk.output) > 0 @@ -114,21 +150,27 @@ async def test_convert_to_pcm(sample_audio): @pytest.mark.asyncio async def test_convert_to_invalid_format_raises_error(sample_audio): """Test that converting to an invalid format raises an error""" - audio_data, sample_rate = sample_audio - with pytest.raises(ValueError, match="Format invalid not supported"): - await AudioService.convert_audio(audio_data, sample_rate, "invalid") + #audio_data, sample_rate = sample_audio + with pytest.raises(ValueError, match="Unsupported format: invalid"): + writer = StreamingAudioWriter("invalid", sample_rate=24000) @pytest.mark.asyncio async def test_normalization_wav(sample_audio): """Test that WAV output is properly normalized to int16 range""" audio_data, sample_rate = sample_audio + + writer = StreamingAudioWriter("wav", sample_rate=24000) + # Create audio data outside int16 range large_audio = audio_data * 1e5 # Write and finalize in one step for WAV audio_chunk = await AudioService.convert_audio( - AudioChunk(large_audio), sample_rate, "wav", is_first_chunk=True + AudioChunk(large_audio), "wav", writer ) + + writer.close() + assert isinstance(audio_chunk.output, bytes) assert isinstance(audio_chunk, AudioChunk) assert len(audio_chunk.output) > 0 @@ -138,10 +180,13 @@ async def test_normalization_wav(sample_audio): async def test_normalization_pcm(sample_audio): """Test that PCM output is properly normalized to int16 range""" audio_data, sample_rate = sample_audio + + writer = StreamingAudioWriter("pcm", sample_rate=24000) + # Create audio data outside int16 range large_audio = audio_data * 1e5 audio_chunk = await AudioService.convert_audio( - AudioChunk(large_audio), sample_rate, "pcm" + AudioChunk(large_audio), "pcm", writer ) assert isinstance(audio_chunk.output, bytes) assert isinstance(audio_chunk, AudioChunk) @@ -153,8 +198,11 @@ async def test_invalid_audio_data(): """Test handling of invalid audio data""" invalid_audio = np.array([]) # Empty array sample_rate = 24000 + + writer = StreamingAudioWriter("wav", sample_rate=24000) + with pytest.raises(ValueError): - await AudioService.convert_audio(invalid_audio, sample_rate, "wav") + await AudioService.convert_audio(invalid_audio, sample_rate, "wav", writer) @pytest.mark.asyncio @@ -164,9 +212,15 @@ async def test_different_sample_rates(sample_audio): sample_rates = [8000, 16000, 44100, 48000] for rate in sample_rates: + + writer = StreamingAudioWriter("wav", sample_rate=rate) + audio_chunk = await AudioService.convert_audio( - AudioChunk(audio_data), rate, "wav", is_first_chunk=True + AudioChunk(audio_data), "wav", writer ) + + writer.close() + assert isinstance(audio_chunk.output, bytes) assert isinstance(audio_chunk, AudioChunk) assert len(audio_chunk.output) > 0 @@ -176,15 +230,21 @@ async def test_different_sample_rates(sample_audio): async def test_buffer_position_after_conversion(sample_audio): """Test that buffer position is reset after writing""" audio_data, sample_rate = sample_audio + + writer = StreamingAudioWriter("wav", sample_rate=24000) + # Write and finalize in one step for first conversion audio_chunk1 = await AudioService.convert_audio( - AudioChunk(audio_data), sample_rate, "wav", is_first_chunk=True, is_last_chunk=True + AudioChunk(audio_data), "wav", writer, is_last_chunk=True ) assert isinstance(audio_chunk1.output, bytes) assert isinstance(audio_chunk1, AudioChunk) # Convert again to ensure buffer was properly reset + + writer = StreamingAudioWriter("wav", sample_rate=24000) + audio_chunk2 = await AudioService.convert_audio( - AudioChunk(audio_data), sample_rate, "wav", is_first_chunk=True, is_last_chunk=True + AudioChunk(audio_data), "wav", writer, is_last_chunk=True ) assert isinstance(audio_chunk2.output, bytes) assert isinstance(audio_chunk2, AudioChunk) diff --git a/api/tests/test_openai_endpoints.py b/api/tests/test_openai_endpoints.py index 527cb1f..0f89a6c 100644 --- a/api/tests/test_openai_endpoints.py +++ b/api/tests/test_openai_endpoints.py @@ -4,6 +4,8 @@ import os from typing import AsyncGenerator, Tuple from unittest.mock import AsyncMock, MagicMock, patch +from api.src.services.streaming_audio_writer import StreamingAudioWriter + from api.src.inference.base import AudioChunk import numpy as np import pytest @@ -159,10 +161,14 @@ async def test_stream_audio_chunks_client_disconnect(): speed=1.0, ) + writer = StreamingAudioWriter("mp3", 24000) + chunks = [] - async for chunk in stream_audio_chunks(mock_service, request, mock_request): + async for chunk in stream_audio_chunks(mock_service, request, mock_request, writer): chunks.append(chunk) + writer.close() + assert len(chunks) == 0 # Should stop immediately due to disconnect @@ -483,7 +489,11 @@ async def test_streaming_initialization_error(): speed=1.0, ) + writer = StreamingAudioWriter("mp3", 24000) + with pytest.raises(RuntimeError) as exc: - async for _ in stream_audio_chunks(mock_service, request, MagicMock()): + async for _ in stream_audio_chunks(mock_service, request, MagicMock(), writer): pass + + writer.close() assert "Failed to initialize stream" in str(exc.value) diff --git a/pyproject.toml b/pyproject.toml index c6b9833..a2de3dc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,7 +37,7 @@ dependencies = [ "en-core-web-sm @ https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.7.1/en_core_web_sm-3.7.1-py3-none-any.whl", "inflect>=7.5.0", "phonemizer-fork>=3.3.2", - "av>=14.1.0", + "av>=14.2.0", ] [project.optional-dependencies] @@ -52,10 +52,10 @@ cpu = [ "torch==2.6.0", ] test = [ - "pytest==8.0.0", - "pytest-cov==4.1.0", + "pytest==8.3.5", + "pytest-cov==6.0.0", "httpx==0.26.0", - "pytest-asyncio==0.23.5", + "pytest-asyncio==0.25.3", "openai>=1.59.6", "tomli>=2.0.1", ] @@ -106,5 +106,5 @@ packages.find = {where = ["api/src"], namespaces = true} [tool.pytest.ini_options] testpaths = ["api/tests", "ui/tests"] python_files = ["test_*.py"] -addopts = "--cov=api --cov=ui --cov-report=term-missing --cov-config=.coveragerc" -asyncio_mode = "strict" +addopts = "--cov=api --cov=ui --cov-report=term-missing --cov-config=.coveragerc --full-trace" +asyncio_mode = "auto"