diff --git a/api/src/routers/development.py b/api/src/routers/development.py index 96f94ea..6a0de60 100644 --- a/api/src/routers/development.py +++ b/api/src/routers/development.py @@ -120,7 +120,7 @@ async def generate_from_phonemes( except Exception as e: logger.error(f"Error in audio generation: {str(e)}") # Clean up writer on error - writer.write_chunk(finalize=True) + writer.close() # Re-raise the original exception raise @@ -236,6 +236,7 @@ async def create_captioned_speech( # Ensure temp writer is closed if not temp_writer._finalized: await temp_writer.__aexit__(None, None, None) + writer.close() # Stream with temp file writing return JSONStreamingResponse( @@ -267,6 +268,7 @@ async def create_captioned_speech( except Exception as e: logger.error(f"Error in single output streaming: {e}") + writer.close() raise # Standard streaming without download link @@ -294,10 +296,8 @@ async def create_captioned_speech( audio_data = await AudioService.convert_audio( audio_data, - 24000, request.response_format, writer, - is_first_chunk=True, is_last_chunk=False, trim_audio=False, ) @@ -305,10 +305,8 @@ async def create_captioned_speech( # Convert to requested format with proper finalization final = await AudioService.convert_audio( AudioChunk(np.array([], dtype=np.int16)), - 24000, request.response_format, writer, - is_first_chunk=False, is_last_chunk=True, ) output=audio_data.output + final.output @@ -316,6 +314,9 @@ async def create_captioned_speech( base64_output= base64.b64encode(output).decode("utf-8") content=CaptionedSpeechResponse(audio=base64_output,audio_format=content_type,timestamps=audio_data.word_timestamps).model_dump() + + writer.close() + return JSONResponse( content=content, media_type="application/json", @@ -328,6 +329,12 @@ async def create_captioned_speech( except ValueError as e: # Handle validation errors logger.warning(f"Invalid request: {str(e)}") + + try: + writer.close() + except: + pass + raise HTTPException( status_code=400, detail={ @@ -339,6 +346,12 @@ async def create_captioned_speech( except RuntimeError as e: # Handle runtime/processing errors logger.error(f"Processing error: {str(e)}") + + try: + writer.close() + except: + pass + raise HTTPException( status_code=500, detail={ @@ -350,6 +363,12 @@ async def create_captioned_speech( except Exception as e: # Handle unexpected errors logger.error(f"Unexpected error in captioned speech generation: {str(e)}") + + try: + writer.close() + except: + pass + raise HTTPException( status_code=500, detail={ diff --git a/api/src/routers/openai_compatible.py b/api/src/routers/openai_compatible.py index 6d37e71..742c216 100644 --- a/api/src/routers/openai_compatible.py +++ b/api/src/routers/openai_compatible.py @@ -240,6 +240,7 @@ async def create_speech( # Ensure temp writer is closed if not temp_writer._finalized: await temp_writer.__aexit__(None, None, None) + writer.close() # Stream with temp file writing return StreamingResponse(dual_output(), media_type=content_type, headers=headers) @@ -252,6 +253,7 @@ async def create_speech( yield chunk_data.output except Exception as e: logger.error(f"Error in single output streaming: {e}") + writer.close() raise # Standard streaming without download link @@ -281,15 +283,13 @@ async def create_speech( lang_code=request.lang_code, ) - audio_data = await AudioService.convert_audio(audio_data, 24000, request.response_format, writer, is_first_chunk=True, is_last_chunk=False, trim_audio=False) + audio_data = await AudioService.convert_audio(audio_data, request.response_format, writer, is_last_chunk=False, trim_audio=False) # Convert to requested format with proper finalization final = await AudioService.convert_audio( AudioChunk(np.array([], dtype=np.int16)), - 24000, request.response_format, writer, - is_first_chunk=False, is_last_chunk=True, ) output = audio_data.output + final.output @@ -321,6 +321,7 @@ async def create_speech( # Ensure temp writer is closed if not temp_writer._finalized: await temp_writer.__aexit__(None, None, None) + writer.close() return Response( content=output, @@ -331,6 +332,12 @@ async def create_speech( except ValueError as e: # Handle validation errors logger.warning(f"Invalid request: {str(e)}") + + try: + writer.close() + except: + pass + raise HTTPException( status_code=400, detail={ @@ -342,6 +349,12 @@ async def create_speech( except RuntimeError as e: # Handle runtime/processing errors logger.error(f"Processing error: {str(e)}") + + try: + writer.close() + except: + pass + raise HTTPException( status_code=500, detail={ @@ -353,6 +366,12 @@ async def create_speech( except Exception as e: # Handle unexpected errors logger.error(f"Unexpected error in speech generation: {str(e)}") + + try: + writer.close() + except: + pass + raise HTTPException( status_code=500, detail={ @@ -361,6 +380,7 @@ async def create_speech( "type": "server_error", }, ) + @router.get("/download/{filename}") diff --git a/api/src/services/audio.py b/api/src/services/audio.py index d0aed80..5e344ec 100644 --- a/api/src/services/audio.py +++ b/api/src/services/audio.py @@ -108,17 +108,13 @@ class AudioService: }, } - _writers = {} - @staticmethod async def convert_audio( audio_chunk: AudioChunk, - sample_rate: int, output_format: str, writer: StreamingAudioWriter, speed: float = 1, chunk_text: str = "", - is_first_chunk: bool = True, is_last_chunk: bool = False, trim_audio: bool = True, normalizer: AudioNormalizer = None, @@ -127,12 +123,12 @@ class AudioService: Args: audio_data: Numpy array of audio samples - sample_rate: Sample rate of the audio output_format: Target format (wav, mp3, ogg, pcm) + writer: The StreamingAudioWriter to use speed: The speaking speed of the voice chunk_text: The text sent to the model to generate the resulting speech - is_first_chunk: Whether this is the first chunk is_last_chunk: Whether this is the last chunk + trim_audio: Whether audio should be trimmed normalizer: Optional AudioNormalizer instance for consistent normalization Returns: @@ -153,15 +149,6 @@ class AudioService: if trim_audio == True: audio_chunk = AudioService.trim_audio(audio_chunk,chunk_text,speed,is_last_chunk,normalizer) - # Get or create format-specific writer - """writer_key = f"{output_format}_{sample_rate}" - if is_first_chunk or writer_key not in AudioService._writers: - AudioService._writers[writer_key] = StreamingAudioWriter( - output_format, sample_rate - ) - - writer = AudioService._writers[writer_key]""" - # Write audio data first if len(audio_chunk.audio) > 0: chunk_data = writer.write_chunk(audio_chunk.audio) diff --git a/api/src/services/streaming_audio_writer.py b/api/src/services/streaming_audio_writer.py index 71dcd32..763c5eb 100644 --- a/api/src/services/streaming_audio_writer.py +++ b/api/src/services/streaming_audio_writer.py @@ -22,7 +22,7 @@ class StreamingAudioWriter: codec_map = {"wav":"pcm_s16le","mp3":"mp3","opus":"libopus","flac":"flac", "aac":"aac"} # Format-specific setup - if self.format in ["wav", "opus","flac","mp3","aac","pcm"]: + if self.format in ["wav","flac","mp3","pcm","aac","opus"]: if self.format != "pcm": self.output_buffer = BytesIO() self.container = av.open(self.output_buffer, mode="w", format=self.format) @@ -31,6 +31,13 @@ class StreamingAudioWriter: else: raise ValueError(f"Unsupported format: {format}") + def close(self): + if hasattr(self, "container"): + self.container.close() + + if hasattr(self, "output_buffer"): + self.output_buffer.close() + def write_chunk( self, audio_data: Optional[np.ndarray] = None, finalize: bool = False ) -> bytes: @@ -48,7 +55,7 @@ class StreamingAudioWriter: self.container.mux(packet) data=self.output_buffer.getvalue() - self.container.close() + self.close() return data if audio_data is None or len(audio_data) == 0: diff --git a/api/src/services/tts_service.py b/api/src/services/tts_service.py index 1b31eab..f740a29 100644 --- a/api/src/services/tts_service.py +++ b/api/src/services/tts_service.py @@ -70,12 +70,10 @@ class TTSService: return chunk_data = await AudioService.convert_audio( AudioChunk(np.array([], dtype=np.float32)), # Dummy data for type checking - 24000, output_format, writer, speed, "", - is_first_chunk=False, normalizer=normalizer, is_last_chunk=True, ) @@ -105,12 +103,10 @@ class TTSService: try: chunk_data = await AudioService.convert_audio( chunk_data, - 24000, output_format, writer, speed, chunk_text, - is_first_chunk=is_first and chunk_index == 0, is_last_chunk=is_last, normalizer=normalizer, ) @@ -139,12 +135,10 @@ class TTSService: try: chunk_data = await AudioService.convert_audio( chunk_data, - 24000, output_format, writer, speed, chunk_text, - is_first_chunk=is_first, normalizer=normalizer, is_last_chunk=is_last, ) diff --git a/api/tests/conftest.py b/api/tests/conftest.py index b8dd761..dee66f9 100644 --- a/api/tests/conftest.py +++ b/api/tests/conftest.py @@ -70,16 +70,3 @@ def test_voice(): """Return a test voice name.""" return "voice1" - -@pytest.fixture(scope="session") -def event_loop(): - """Create an instance of the default event loop for the test session.""" - import asyncio - - try: - loop = asyncio.get_event_loop() - except RuntimeError: - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - yield loop - loop.close() diff --git a/api/tests/test_audio_service.py b/api/tests/test_audio_service.py index 3d53322..9351454 100644 --- a/api/tests/test_audio_service.py +++ b/api/tests/test_audio_service.py @@ -34,8 +34,11 @@ async def test_convert_to_wav(sample_audio): writer = StreamingAudioWriter("wav", sample_rate=24000) # Write and finalize in one step for WAV audio_chunk = await AudioService.convert_audio( - AudioChunk(audio_data), sample_rate, "wav", writer, is_first_chunk=True, is_last_chunk=False + AudioChunk(audio_data), "wav", writer, is_last_chunk=False ) + + writer.close() + assert isinstance(audio_chunk.output, bytes) assert isinstance(audio_chunk, AudioChunk) assert len(audio_chunk.output) > 0 @@ -52,8 +55,11 @@ async def test_convert_to_mp3(sample_audio): writer = StreamingAudioWriter("mp3", sample_rate=24000) audio_chunk = await AudioService.convert_audio( - AudioChunk(audio_data), sample_rate, "mp3", writer + AudioChunk(audio_data), "mp3", writer ) + + writer.close() + assert isinstance(audio_chunk.output, bytes) assert isinstance(audio_chunk, AudioChunk) assert len(audio_chunk.output) > 0 @@ -64,13 +70,17 @@ async def test_convert_to_mp3(sample_audio): @pytest.mark.asyncio async def test_convert_to_opus(sample_audio): """Test converting to Opus format""" + audio_data, sample_rate = sample_audio writer = StreamingAudioWriter("opus", sample_rate=24000) audio_chunk = await AudioService.convert_audio( - AudioChunk(audio_data), sample_rate, "opus",writer + AudioChunk(audio_data), "opus",writer ) + + writer.close() + assert isinstance(audio_chunk.output, bytes) assert isinstance(audio_chunk, AudioChunk) assert len(audio_chunk.output) > 0 @@ -86,8 +96,11 @@ async def test_convert_to_flac(sample_audio): writer = StreamingAudioWriter("flac", sample_rate=24000) audio_chunk = await AudioService.convert_audio( - AudioChunk(audio_data), sample_rate, "flac", writer + AudioChunk(audio_data), "flac", writer ) + + writer.close() + assert isinstance(audio_chunk.output, bytes) assert isinstance(audio_chunk, AudioChunk) assert len(audio_chunk.output) > 0 @@ -103,8 +116,11 @@ async def test_convert_to_aac(sample_audio): writer = StreamingAudioWriter("aac", sample_rate=24000) audio_chunk = await AudioService.convert_audio( - AudioChunk(audio_data), sample_rate, "aac", writer + AudioChunk(audio_data), "aac", writer ) + + writer.close() + assert isinstance(audio_chunk.output, bytes) assert isinstance(audio_chunk, AudioChunk) assert len(audio_chunk.output) > 0 @@ -120,8 +136,11 @@ async def test_convert_to_pcm(sample_audio): writer = StreamingAudioWriter("pcm", sample_rate=24000) audio_chunk = await AudioService.convert_audio( - AudioChunk(audio_data), sample_rate, "pcm", writer + AudioChunk(audio_data), "pcm", writer ) + + writer.close() + assert isinstance(audio_chunk.output, bytes) assert isinstance(audio_chunk, AudioChunk) assert len(audio_chunk.output) > 0 @@ -131,12 +150,9 @@ async def test_convert_to_pcm(sample_audio): @pytest.mark.asyncio async def test_convert_to_invalid_format_raises_error(sample_audio): """Test that converting to an invalid format raises an error""" - audio_data, sample_rate = sample_audio + #audio_data, sample_rate = sample_audio with pytest.raises(ValueError, match="Unsupported format: invalid"): - StreamingAudioWriter("invalid", sample_rate=24000) - - with pytest.raises(ValueError, match="Format invalid not supported"): - await AudioService.convert_audio(audio_data, sample_rate, "invalid", None) + writer = StreamingAudioWriter("invalid", sample_rate=24000) @pytest.mark.asyncio @@ -150,8 +166,11 @@ async def test_normalization_wav(sample_audio): large_audio = audio_data * 1e5 # Write and finalize in one step for WAV audio_chunk = await AudioService.convert_audio( - AudioChunk(large_audio), sample_rate, "wav", writer, is_first_chunk=True + AudioChunk(large_audio), "wav", writer ) + + writer.close() + assert isinstance(audio_chunk.output, bytes) assert isinstance(audio_chunk, AudioChunk) assert len(audio_chunk.output) > 0 @@ -167,7 +186,7 @@ async def test_normalization_pcm(sample_audio): # Create audio data outside int16 range large_audio = audio_data * 1e5 audio_chunk = await AudioService.convert_audio( - AudioChunk(large_audio), sample_rate, "pcm", writer + AudioChunk(large_audio), "pcm", writer ) assert isinstance(audio_chunk.output, bytes) assert isinstance(audio_chunk, AudioChunk) @@ -197,8 +216,11 @@ async def test_different_sample_rates(sample_audio): writer = StreamingAudioWriter("wav", sample_rate=rate) audio_chunk = await AudioService.convert_audio( - AudioChunk(audio_data), rate, "wav", writer, is_first_chunk=True + AudioChunk(audio_data), "wav", writer ) + + writer.close() + assert isinstance(audio_chunk.output, bytes) assert isinstance(audio_chunk, AudioChunk) assert len(audio_chunk.output) > 0 @@ -213,7 +235,7 @@ async def test_buffer_position_after_conversion(sample_audio): # Write and finalize in one step for first conversion audio_chunk1 = await AudioService.convert_audio( - AudioChunk(audio_data), sample_rate, "wav", writer, is_first_chunk=True, is_last_chunk=True + AudioChunk(audio_data), "wav", writer, is_last_chunk=True ) assert isinstance(audio_chunk1.output, bytes) assert isinstance(audio_chunk1, AudioChunk) @@ -222,7 +244,7 @@ async def test_buffer_position_after_conversion(sample_audio): writer = StreamingAudioWriter("wav", sample_rate=24000) audio_chunk2 = await AudioService.convert_audio( - AudioChunk(audio_data), sample_rate, "wav", writer, is_first_chunk=True, is_last_chunk=True + AudioChunk(audio_data), "wav", writer, is_last_chunk=True ) assert isinstance(audio_chunk2.output, bytes) assert isinstance(audio_chunk2, AudioChunk) diff --git a/api/tests/test_openai_endpoints.py b/api/tests/test_openai_endpoints.py index 527cb1f..0f89a6c 100644 --- a/api/tests/test_openai_endpoints.py +++ b/api/tests/test_openai_endpoints.py @@ -4,6 +4,8 @@ import os from typing import AsyncGenerator, Tuple from unittest.mock import AsyncMock, MagicMock, patch +from api.src.services.streaming_audio_writer import StreamingAudioWriter + from api.src.inference.base import AudioChunk import numpy as np import pytest @@ -159,10 +161,14 @@ async def test_stream_audio_chunks_client_disconnect(): speed=1.0, ) + writer = StreamingAudioWriter("mp3", 24000) + chunks = [] - async for chunk in stream_audio_chunks(mock_service, request, mock_request): + async for chunk in stream_audio_chunks(mock_service, request, mock_request, writer): chunks.append(chunk) + writer.close() + assert len(chunks) == 0 # Should stop immediately due to disconnect @@ -483,7 +489,11 @@ async def test_streaming_initialization_error(): speed=1.0, ) + writer = StreamingAudioWriter("mp3", 24000) + with pytest.raises(RuntimeError) as exc: - async for _ in stream_audio_chunks(mock_service, request, MagicMock()): + async for _ in stream_audio_chunks(mock_service, request, MagicMock(), writer): pass + + writer.close() assert "Failed to initialize stream" in str(exc.value) diff --git a/pyproject.toml b/pyproject.toml index d0ff675..3b9e486 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,7 +37,7 @@ dependencies = [ "en-core-web-sm @ https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.7.1/en_core_web_sm-3.7.1-py3-none-any.whl", "inflect>=7.5.0", "phonemizer-fork>=3.3.2", - "av>=14.1.0", + "av>=14.2.0", ] [project.optional-dependencies] @@ -48,10 +48,10 @@ cpu = [ "torch==2.6.0", ] test = [ - "pytest==8.0.0", - "pytest-cov==4.1.0", + "pytest==8.3.5", + "pytest-cov==6.0.0", "httpx==0.26.0", - "pytest-asyncio==0.23.5", + "pytest-asyncio==0.25.3", "openai>=1.59.6", "tomli>=2.0.1", ] @@ -91,5 +91,5 @@ packages.find = {where = ["api/src"], namespaces = true} [tool.pytest.ini_options] testpaths = ["api/tests", "ui/tests"] python_files = ["test_*.py"] -addopts = "--cov=api --cov=ui --cov-report=term-missing --cov-config=.coveragerc" -asyncio_mode = "strict" +addopts = "--cov=api --cov=ui --cov-report=term-missing --cov-config=.coveragerc --full-trace" +asyncio_mode = "auto"