diff --git a/api/src/inference/base.py b/api/src/inference/base.py
index 905b206..0d18d22 100644
--- a/api/src/inference/base.py
+++ b/api/src/inference/base.py
@@ -11,10 +11,12 @@ class AudioChunk:
     def __init__(self,
                  audio: np.ndarray,
-                 word_timestamps: Optional[List]=[]
+                 word_timestamps: Optional[List]=[],
+                 output: Optional[Union[bytes,np.ndarray]]=b""
                  ):
         self.audio=audio
         self.word_timestamps=word_timestamps
+        self.output=output
 
     @staticmethod
     def combine(audio_chunk_list: List):
diff --git a/api/src/routers/development.py b/api/src/routers/development.py
index 3b7d38b..6e09b73 100644
--- a/api/src/routers/development.py
+++ b/api/src/routers/development.py
@@ -209,10 +209,10 @@ async def create_captioned_speech(
             async def dual_output():
                 try:
                     # Write chunks to temp file and stream
-                    async for chunk,chunk_data in generator:
-                        if chunk:  # Skip empty chunks
-                            await temp_writer.write(chunk)
-                            base64_chunk= base64.b64encode(chunk).decode("utf-8")
+                    async for chunk_data in generator:
+                        if chunk_data.output:  # Skip empty chunks
+                            await temp_writer.write(chunk_data.output)
+                            base64_chunk= base64.b64encode(chunk_data.output).decode("utf-8")
 
                             yield CaptionedSpeechResponse(audio=base64_chunk,audio_format=content_type,timestamps=chunk_data.word_timestamps)
 
@@ -235,10 +235,10 @@ async def create_captioned_speech(
             async def single_output():
                 try:
                     # Stream chunks
-                    async for chunk,chunk_data in generator:
-                        if chunk:  # Skip empty chunks
+                    async for chunk_data in generator:
+                        if chunk_data.output:  # Skip empty chunks
                             # Encode the chunk bytes into base 64
-                            base64_chunk= base64.b64encode(chunk).decode("utf-8")
+                            base64_chunk= base64.b64encode(chunk_data.output).decode("utf-8")
 
                             yield CaptionedSpeechResponse(audio=base64_chunk,audio_format=content_type,timestamps=chunk_data.word_timestamps)
                 except Exception as e:
@@ -258,7 +258,7 @@ async def create_captioned_speech(
             )
         else:
             # Generate complete audio using public interface
-            _, audio_data = await tts_service.generate_audio(
+            audio_data = await tts_service.generate_audio(
                 text=request.input,
                 voice=voice_name,
                 speed=request.speed,
@@ -267,7 +267,7 @@ async def create_captioned_speech(
                 lang_code=request.lang_code,
             )
 
-            content, audio_data = await AudioService.convert_audio(
+            audio_data = await AudioService.convert_audio(
                 audio_data,
                 24000,
                 request.response_format,
@@ -276,14 +276,14 @@ async def create_captioned_speech(
             )
 
             # Convert to requested format with proper finalization
-            final, _ = await AudioService.convert_audio(
+            final = await AudioService.convert_audio(
                 AudioChunk(np.array([], dtype=np.int16)),
                 24000,
                 request.response_format,
                 is_first_chunk=False,
                 is_last_chunk=True,
             )
-            output=content+final
+            output=audio_data.output + final.output
 
             base64_output= base64.b64encode(output).decode("utf-8")
diff --git a/api/src/routers/openai_compatible.py b/api/src/routers/openai_compatible.py
index bb4083d..73b3923 100644
--- a/api/src/routers/openai_compatible.py
+++ b/api/src/routers/openai_compatible.py
@@ -132,7 +132,7 @@ async def process_voices(
 
 async def stream_audio_chunks(
     tts_service: TTSService, request: Union[OpenAISpeechRequest,CaptionedSpeechRequest], client_request: Request
-) -> AsyncGenerator[Tuple[Union[np.ndarray,bytes],AudioChunk], None]:
+) -> AsyncGenerator[AudioChunk, None]:
     """Stream audio chunks as they're generated with client disconnect handling"""
     voice_name = await process_voices(request.voice, tts_service)
 
@@ -144,7 +144,7 @@ async def stream_audio_chunks(
     try:
         logger.info(f"Starting audio generation with lang_code: {request.lang_code}")
 
-        async for chunk, chunk_data in tts_service.generate_audio_stream(
+        async for chunk_data in tts_service.generate_audio_stream(
             text=request.input,
             voice=voice_name,
             speed=request.speed,
@@ -162,7 +162,7 @@ async def stream_audio_chunks(
                 logger.info("Client disconnected, stopping audio generation")
                 break
 
-            yield (chunk,chunk_data)
+            yield chunk_data
     except Exception as e:
         logger.error(f"Error in audio streaming: {str(e)}")
         # Let the exception propagate to trigger cleanup
@@ -233,13 +233,13 @@ async def create_speech(
             async def dual_output():
                 try:
                     # Write chunks to temp file and stream
-                    async for chunk,chunk_data in generator:
-                        if chunk:  # Skip empty chunks
-                            await temp_writer.write(chunk)
+                    async for chunk_data in generator:
+                        if chunk_data.output:  # Skip empty chunks
+                            await temp_writer.write(chunk_data.output)
                             #if return_json:
                             #    yield chunk, chunk_data
                             #else:
-                            yield chunk
+                            yield chunk_data.output
 
                     # Finalize the temp file
                     await temp_writer.finalize()
@@ -260,9 +260,9 @@ async def create_speech(
             async def single_output():
                 try:
                     # Stream chunks
-                    async for chunk,chunk_data in generator:
-                        if chunk:  # Skip empty chunks
-                            yield chunk
+                    async for chunk_data in generator:
+                        if chunk_data.output:  # Skip empty chunks
+                            yield chunk_data.output
                 except Exception as e:
                     logger.error(f"Error in single output streaming: {e}")
                     raise
@@ -280,7 +280,7 @@ async def create_speech(
             )
         else:
             # Generate complete audio using public interface
-            _, audio_data = await tts_service.generate_audio(
+            audio_data = await tts_service.generate_audio(
                 text=request.input,
                 voice=voice_name,
                 speed=request.speed,
@@ -288,7 +288,7 @@ async def create_speech(
                 lang_code=request.lang_code,
             )
 
-            content, audio_data = await AudioService.convert_audio(
+            audio_data = await AudioService.convert_audio(
                 audio_data,
                 24000,
                 request.response_format,
@@ -297,14 +297,14 @@ async def create_speech(
             )
             # Convert to requested format with proper finalization
-            final, _ = await AudioService.convert_audio(
+            final = await AudioService.convert_audio(
                 AudioChunk(np.array([], dtype=np.int16)),
                 24000,
                 request.response_format,
                 is_first_chunk=False,
                 is_last_chunk=True,
             )
-            output=content+final
+            output=audio_data.output + final.output
 
             return Response(
                 content=output,
                 media_type=content_type,
diff --git a/api/src/services/audio.py b/api/src/services/audio.py
index 7fdb49f..222b5b5 100644
--- a/api/src/services/audio.py
+++ b/api/src/services/audio.py
@@ -120,7 +120,7 @@ class AudioService:
         is_first_chunk: bool = True,
         is_last_chunk: bool = False,
         normalizer: AudioNormalizer = None,
-    ) -> Tuple[bytes,AudioChunk]:
+    ) -> AudioChunk:
         """Convert audio data to specified format with streaming support
 
         Args:
@@ -165,9 +165,13 @@ class AudioService:
             if is_last_chunk:
                 final_data = writer.write_chunk(finalize=True)
                 del AudioService._writers[writer_key]
-                return final_data if final_data else b"", audio_chunk
-
-            return chunk_data if chunk_data else b"", audio_chunk
+                if final_data:
+                    audio_chunk.output=final_data
+                return audio_chunk
+
+            if chunk_data:
+                audio_chunk.output=chunk_data
+            return audio_chunk
 
         except Exception as e:
             logger.error(f"Error converting audio stream to {output_format}: {str(e)}")
diff --git a/api/src/services/tts_service.py b/api/src/services/tts_service.py
index f1eb627..d2aaaf2 100644
--- a/api/src/services/tts_service.py
+++ b/api/src/services/tts_service.py
@@ -54,7 +54,7 @@ class TTSService:
         normalizer: Optional[AudioNormalizer] = None,
         lang_code: Optional[str] = None,
         return_timestamps: Optional[bool] = False,
-    ) -> AsyncGenerator[Tuple[Union[np.ndarray, bytes],AudioChunk], Tuple[None,None]]:
+    ) -> AsyncGenerator[AudioChunk, None]:
         """Process tokens into audio."""
         async with self._chunk_semaphore:
             try:
@@ -62,9 +62,9 @@ class TTSService:
                 if is_last:
                     # Skip format conversion for raw audio mode
                     if not output_format:
-                        yield np.array([], dtype=np.int16), AudioChunk(np.array([], dtype=np.int16))
+                        yield AudioChunk(np.array([], dtype=np.int16),output=b'')
                         return
-                    result, chunk_data = await AudioService.convert_audio(
+                    chunk_data = await AudioService.convert_audio(
                         AudioChunk(np.array([0], dtype=np.float32)),  # Dummy data for type checking
                         24000,
                         output_format,
@@ -74,7 +74,7 @@ class TTSService:
                         normalizer=normalizer,
                         is_last_chunk=True,
                     )
-                    yield result, chunk_data
+                    yield chunk_data
                     return
 
                 # Skip empty chunks
@@ -97,7 +97,7 @@ class TTSService:
                     # For streaming, convert to bytes
                     if output_format:
                         try:
-                            converted, chunk_data = await AudioService.convert_audio(
+                            chunk_data = await AudioService.convert_audio(
                                 chunk_data,
                                 24000,
                                 output_format,
@@ -107,7 +107,7 @@ class TTSService:
                                 is_last_chunk=is_last,
                                 normalizer=normalizer,
                             )
-                            yield converted, chunk_data
+                            yield chunk_data
                         except Exception as e:
                             logger.error(f"Failed to convert audio: {str(e)}")
                     else:
@@ -116,7 +116,7 @@ class TTSService:
                                 speed,
                                 is_last,
                                 normalizer)
-                        yield chunk_data.audio, chunk_data
+                        yield chunk_data
 
                 else:
                     # For legacy backends, load voice tensor
@@ -138,7 +138,7 @@ class TTSService:
                     # For streaming, convert to bytes
                     if output_format:
                         try:
-                            converted, chunk_data = await AudioService.convert_audio(
+                            chunk_data = await AudioService.convert_audio(
                                 chunk_data,
                                 24000,
                                 output_format,
@@ -148,7 +148,7 @@ class TTSService:
                                 normalizer=normalizer,
                                 is_last_chunk=is_last,
                             )
-                            yield converted, chunk_data
+                            yield chunk_data
                         except Exception as e:
                             logger.error(f"Failed to convert audio: {str(e)}")
                     else:
@@ -157,7 +157,7 @@ class TTSService:
                             speed,
                             is_last,
                             normalizer)
-                        yield trimmed.audio, trimmed
+                        yield trimmed
 
             except Exception as e:
                 logger.error(f"Failed to process tokens: {str(e)}")
@@ -243,7 +243,7 @@ class TTSService:
         lang_code: Optional[str] = None,
         normalization_options: Optional[NormalizationOptions] = NormalizationOptions(),
         return_timestamps: Optional[bool] = False,
-    ) -> AsyncGenerator[Tuple[bytes,AudioChunk], None]:
+    ) -> AsyncGenerator[AudioChunk, None]:
         """Generate and stream audio chunks."""
         stream_normalizer = AudioNormalizer()
         chunk_index = 0
@@ -267,7 +267,7 @@ class TTSService:
             async for chunk_text, tokens in smart_split(text,lang_code=lang_code,normalization_options=normalization_options):
                 try:
                     # Process audio for chunk
-                    async for result, chunk_data in self._process_chunk(
+                    async for chunk_data in self._process_chunk(
                         chunk_text,  # Pass text for Kokoro V1
                         tokens,  # Pass tokens for legacy backends
                         voice_name,  # Pass voice name
@@ -287,8 +287,8 @@ class TTSService:
 
                         current_offset+=len(chunk_data.audio) / 24000
 
-                        if result is not None:
-                            yield result,chunk_data
+                        if chunk_data.output is not None:
+                            yield chunk_data
                             chunk_index += 1
                         else:
                             logger.warning(
@@ -305,7 +305,7 @@ class TTSService:
             if chunk_index > 0:
                 try:
                     # Empty tokens list to finalize audio
-                    async for result,chunk_data in self._process_chunk(
+                    async for chunk_data in self._process_chunk(
                         "",  # Empty text
                         [],  # Empty tokens
                         voice_name,
@@ -317,8 +317,8 @@ class TTSService:
                         normalizer=stream_normalizer,
                         lang_code=pipeline_lang_code,  # Pass lang_code
                     ):
-                        if result is not None:
-                            yield result, chunk_data
+                        if chunk_data.output is not None:
+                            yield chunk_data
                 except Exception as e:
                     logger.error(f"Failed to finalize audio stream: {str(e)}")
@@ -335,17 +335,17 @@ class TTSService:
         return_timestamps: bool = False,
         normalization_options: Optional[NormalizationOptions] = NormalizationOptions(),
         lang_code: Optional[str] = None,
-    ) -> Tuple[Tuple[np.ndarray,AudioChunk]]:
+    ) -> AudioChunk:
         """Generate complete audio for text using streaming internally."""
         audio_data_chunks=[]
 
         try:
-            async for _,audio_stream_data in self.generate_audio_stream(text,voice,speed=speed,normalization_options=normalization_options,return_timestamps=return_timestamps,lang_code=lang_code,output_format=None):
+            async for audio_stream_data in self.generate_audio_stream(text,voice,speed=speed,normalization_options=normalization_options,return_timestamps=return_timestamps,lang_code=lang_code,output_format=None):
                 audio_data_chunks.append(audio_stream_data)
 
             combined_audio_data=AudioChunk.combine(audio_data_chunks)
-            return combined_audio_data.audio,combined_audio_data
+            return combined_audio_data
         except Exception as e:
             logger.error(f"Error in audio generation: {str(e)}")
             raise
diff --git a/api/tests/test_audio_service.py b/api/tests/test_audio_service.py
index 14c0e2a..6a15a62 100644
--- a/api/tests/test_audio_service.py
+++ b/api/tests/test_audio_service.py
@@ -31,83 +31,83 @@ async def test_convert_to_wav(sample_audio):
     """Test converting to WAV format"""
     audio_data, sample_rate = sample_audio
     # Write and finalize in one step for WAV
-    result, audio_chunk = await AudioService.convert_audio(
+    audio_chunk = await AudioService.convert_audio(
         AudioChunk(audio_data), sample_rate, "wav", is_first_chunk=True, is_last_chunk=True
     )
 
-    assert isinstance(result, bytes)
+    assert isinstance(audio_chunk.output, bytes)
     assert isinstance(audio_chunk, AudioChunk)
-    assert len(result) > 0
+    assert len(audio_chunk.output) > 0
 
     # Check WAV header
-    assert result.startswith(b"RIFF")
-    assert b"WAVE" in result[:12]
+    assert audio_chunk.output.startswith(b"RIFF")
+    assert b"WAVE" in audio_chunk.output[:12]
 
 
 @pytest.mark.asyncio
 async def test_convert_to_mp3(sample_audio):
     """Test converting to MP3 format"""
     audio_data, sample_rate = sample_audio
-    result, audio_chunk = await AudioService.convert_audio(
+    audio_chunk = await AudioService.convert_audio(
         AudioChunk(audio_data), sample_rate, "mp3"
     )
 
-    assert isinstance(result, bytes)
+    assert isinstance(audio_chunk.output, bytes)
     assert isinstance(audio_chunk, AudioChunk)
-    assert len(result) > 0
+    assert len(audio_chunk.output) > 0
     # Check MP3 header (ID3 or MPEG frame sync)
-    assert result.startswith(b"ID3") or result.startswith(b"\xff\xfb")
+    assert audio_chunk.output.startswith(b"ID3") or audio_chunk.output.startswith(b"\xff\xfb")
 
 
 @pytest.mark.asyncio
 async def test_convert_to_opus(sample_audio):
     """Test converting to Opus format"""
     audio_data, sample_rate = sample_audio
-    result, audio_chunk = await AudioService.convert_audio(
+    audio_chunk = await AudioService.convert_audio(
         AudioChunk(audio_data), sample_rate, "opus"
     )
 
-    assert isinstance(result, bytes)
+    assert isinstance(audio_chunk.output, bytes)
     assert isinstance(audio_chunk, AudioChunk)
-    assert len(result) > 0
+    assert len(audio_chunk.output) > 0
 
     # Check OGG header
-    assert result.startswith(b"OggS")
+    assert audio_chunk.output.startswith(b"OggS")
 
 
 @pytest.mark.asyncio
 async def test_convert_to_flac(sample_audio):
     """Test converting to FLAC format"""
     audio_data, sample_rate = sample_audio
-    result, audio_chunk = await AudioService.convert_audio(
+    audio_chunk = await AudioService.convert_audio(
         AudioChunk(audio_data), sample_rate, "flac"
     )
 
-    assert isinstance(result, bytes)
+    assert isinstance(audio_chunk.output, bytes)
     assert isinstance(audio_chunk, AudioChunk)
-    assert len(result) > 0
+    assert len(audio_chunk.output) > 0
 
     # Check FLAC header
-    assert result.startswith(b"fLaC")
+    assert audio_chunk.output.startswith(b"fLaC")
 
 
 @pytest.mark.asyncio
 async def test_convert_to_aac(sample_audio):
     """Test converting to AAC format"""
     audio_data, sample_rate = sample_audio
-    result, audio_chunk = await AudioService.convert_audio(
+    audio_chunk = await AudioService.convert_audio(
         AudioChunk(audio_data), sample_rate, "aac"
     )
 
-    assert isinstance(result, bytes)
+    assert isinstance(audio_chunk.output, bytes)
     assert isinstance(audio_chunk, AudioChunk)
-    assert len(result) > 0
+    assert len(audio_chunk.output) > 0
 
     # Check ADTS header (AAC)
-    assert result.startswith(b"\xff\xf0") or result.startswith(b"\xff\xf1")
+    assert audio_chunk.output.startswith(b"\xff\xf0") or audio_chunk.output.startswith(b"\xff\xf1")
 
 
 @pytest.mark.asyncio
 async def test_convert_to_pcm(sample_audio):
     """Test converting to PCM format"""
     audio_data, sample_rate = sample_audio
-    result, audio_chunk = await AudioService.convert_audio(
+    audio_chunk = await AudioService.convert_audio(
         AudioChunk(audio_data), sample_rate, "pcm"
     )
 
-    assert isinstance(result, bytes)
+    assert isinstance(audio_chunk.output, bytes)
     assert isinstance(audio_chunk, AudioChunk)
-    assert len(result) > 0
+    assert len(audio_chunk.output) > 0
 
     # PCM is raw bytes, so no header to check
 
@@ -126,12 +126,12 @@ async def test_normalization_wav(sample_audio):
     # Create audio data outside int16 range
     large_audio = audio_data * 1e5
     # Write and finalize in one step for WAV
-    result, audio_chunk = await AudioService.convert_audio(
+    audio_chunk = await AudioService.convert_audio(
         AudioChunk(large_audio), sample_rate, "wav", is_first_chunk=True, is_last_chunk=True
     )
-    assert isinstance(result, bytes)
+    assert isinstance(audio_chunk.output, bytes)
     assert isinstance(audio_chunk, AudioChunk)
-    assert len(result) > 0
+    assert len(audio_chunk.output) > 0
 
 
 @pytest.mark.asyncio
@@ -140,12 +140,12 @@ async def test_normalization_pcm(sample_audio):
     audio_data, sample_rate = sample_audio
     # Create audio data outside int16 range
     large_audio = audio_data * 1e5
-    result, audio_chunk = await AudioService.convert_audio(
+    audio_chunk = await AudioService.convert_audio(
         AudioChunk(large_audio), sample_rate, "pcm"
     )
-    assert isinstance(result, bytes)
+    assert isinstance(audio_chunk.output, bytes)
     assert isinstance(audio_chunk, AudioChunk)
-    assert len(result) > 0
+    assert len(audio_chunk.output) > 0
 
 
 @pytest.mark.asyncio
@@ -164,12 +164,12 @@ async def test_different_sample_rates(sample_audio):
     sample_rates = [8000, 16000, 44100, 48000]
 
     for rate in sample_rates:
-        result, audio_chunk = await AudioService.convert_audio(
+        audio_chunk = await AudioService.convert_audio(
             AudioChunk(audio_data), rate, "wav", is_first_chunk=True, is_last_chunk=True
         )
-        assert isinstance(result, bytes)
+        assert isinstance(audio_chunk.output, bytes)
         assert isinstance(audio_chunk, AudioChunk)
-        assert len(result) > 0
+        assert len(audio_chunk.output) > 0
 
 
 @pytest.mark.asyncio
@@ -177,15 +177,15 @@ async def test_buffer_position_after_conversion(sample_audio):
     """Test that buffer position is reset after writing"""
     audio_data, sample_rate = sample_audio
     # Write and finalize in one step for first conversion
-    result1, audio_chunk = await AudioService.convert_audio(
+    audio_chunk1 = await AudioService.convert_audio(
         AudioChunk(audio_data), sample_rate, "wav", is_first_chunk=True, is_last_chunk=True
     )
-    assert isinstance(result1, bytes)
-    assert isinstance(audio_chunk, AudioChunk)
+    assert isinstance(audio_chunk1.output, bytes)
+    assert isinstance(audio_chunk1, AudioChunk)
 
     # Convert again to ensure buffer was properly reset
-    result2, audio_chunk = await AudioService.convert_audio(
+    audio_chunk2 = await AudioService.convert_audio(
         AudioChunk(audio_data), sample_rate, "wav", is_first_chunk=True, is_last_chunk=True
     )
-    assert isinstance(result2, bytes)
-    assert isinstance(audio_chunk, AudioChunk)
-    assert len(result1) == len(result2)
+    assert isinstance(audio_chunk2.output, bytes)
+    assert isinstance(audio_chunk2, AudioChunk)
+    assert len(audio_chunk1.output) == len(audio_chunk2.output)
diff --git a/api/tests/test_openai_endpoints.py b/api/tests/test_openai_endpoints.py
index 82873e1..527cb1f 100644
--- a/api/tests/test_openai_endpoints.py
+++ b/api/tests/test_openai_endpoints.py
@@ -145,7 +145,7 @@ async def test_stream_audio_chunks_client_disconnect():
 
     async def mock_stream(*args, **kwargs):
         for i in range(5):
-            yield (b"chunk",AudioChunk(np.ndarray([],np.int16)))
+            yield AudioChunk(np.ndarray([],np.int16),output=b"chunk")
 
     mock_service.generate_audio_stream = mock_stream
     mock_service.list_voices.return_value = ["test_voice"]
@@ -160,7 +160,7 @@ async def test_stream_audio_chunks_client_disconnect():
     )
 
     chunks = []
-    async for chunk, _ in stream_audio_chunks(mock_service, request, mock_request):
+    async for chunk in stream_audio_chunks(mock_service, request, mock_request):
         chunks.append(chunk)
 
     assert len(chunks) == 0  # Should stop immediately due to disconnect
@@ -237,10 +237,10 @@ def mock_tts_service(mock_audio_bytes):
     """Mock TTS service for testing."""
     with patch("api.src.routers.openai_compatible.get_tts_service") as mock_get:
         service = AsyncMock(spec=TTSService)
-        service.generate_audio.return_value = (np.zeros(1000), AudioChunk(np.zeros(1000,np.int16)))
+        service.generate_audio.return_value = AudioChunk(np.zeros(1000,np.int16))
 
-        async def mock_stream(*args, **kwargs) -> AsyncGenerator[Tuple[bytes,AudioChunk], None]:
-            yield (mock_audio_bytes, AudioChunk(np.ndarray([],np.int16)))
+        async def mock_stream(*args, **kwargs) -> AsyncGenerator[AudioChunk, None]:
+            yield AudioChunk(np.ndarray([],np.int16),output=mock_audio_bytes)
 
         service.generate_audio_stream = mock_stream
         service.list_voices.return_value = ["test_voice", "voice1", "voice2"]
@@ -257,8 +257,8 @@ def test_openai_speech_endpoint(
 ):
     """Test the OpenAI-compatible speech endpoint with basic MP3 generation"""
     # Configure mocks
-    mock_tts_service.generate_audio.return_value = (np.zeros(1000), AudioChunk(np.zeros(1000,np.int16)))
-    mock_convert.return_value = (mock_audio_bytes,AudioChunk(np.zeros(1000,np.int16)))
+    mock_tts_service.generate_audio.return_value = AudioChunk(np.zeros(1000,np.int16))
+    mock_convert.return_value = AudioChunk(np.zeros(1000,np.int16),output=mock_audio_bytes)
 
     response = client.post(
         "/v1/audio/speech",
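
Taken together, these hunks replace every (bytes, AudioChunk) tuple with a
single AudioChunk: raw samples stay in .audio, word timings in
.word_timestamps, and the encoded bytes move to the new .output field. A
minimal consumer sketch of the refactored interface follows. It is
illustrative only: collect_encoded_audio is a hypothetical helper, and the
generate_audio_stream signature and import path are assumed from the hunks
above.

    from api.src.services.tts_service import TTSService

    async def collect_encoded_audio(tts_service: TTSService, text: str, voice: str) -> bytes:
        """Join the encoded bytes carried by each streamed AudioChunk."""
        parts = []
        # generate_audio_stream now yields AudioChunk objects directly;
        # skipping chunks with empty .output mirrors single_output() in the routers.
        async for chunk in tts_service.generate_audio_stream(
            text, voice, output_format="mp3"
        ):
            if chunk.output:
                parts.append(chunk.output)
        return b"".join(parts)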