Simplify code so everything uses AudioChunks

This commit is contained in:
Fireblade 2025-02-16 15:37:01 -05:00
parent 9c0e328318
commit e3dc959775
7 changed files with 103 additions and 97 deletions

View file

@ -11,10 +11,12 @@ class AudioChunk:
def __init__(self, def __init__(self,
audio: np.ndarray, audio: np.ndarray,
word_timestamps: Optional[List]=[] word_timestamps: Optional[List]=[],
output: Optional[Union[bytes,np.ndarray]]=b""
): ):
self.audio=audio self.audio=audio
self.word_timestamps=word_timestamps self.word_timestamps=word_timestamps
self.output=output
@staticmethod @staticmethod
def combine(audio_chunk_list: List): def combine(audio_chunk_list: List):

View file

@ -209,10 +209,10 @@ async def create_captioned_speech(
async def dual_output(): async def dual_output():
try: try:
# Write chunks to temp file and stream # Write chunks to temp file and stream
async for chunk,chunk_data in generator: async for chunk_data in generator:
if chunk: # Skip empty chunks if chunk_data.output: # Skip empty chunks
await temp_writer.write(chunk) await temp_writer.write(chunk_data.output)
base64_chunk= base64.b64encode(chunk).decode("utf-8") base64_chunk= base64.b64encode(chunk_data.output).decode("utf-8")
yield CaptionedSpeechResponse(audio=base64_chunk,audio_format=content_type,timestamps=chunk_data.word_timestamps) yield CaptionedSpeechResponse(audio=base64_chunk,audio_format=content_type,timestamps=chunk_data.word_timestamps)
@ -235,10 +235,10 @@ async def create_captioned_speech(
async def single_output(): async def single_output():
try: try:
# Stream chunks # Stream chunks
async for chunk,chunk_data in generator: async for chunk_data in generator:
if chunk: # Skip empty chunks if chunk_data.output: # Skip empty chunks
# Encode the chunk bytes into base 64 # Encode the chunk bytes into base 64
base64_chunk= base64.b64encode(chunk).decode("utf-8") base64_chunk= base64.b64encode(chunk_data.output).decode("utf-8")
yield CaptionedSpeechResponse(audio=base64_chunk,audio_format=content_type,timestamps=chunk_data.word_timestamps) yield CaptionedSpeechResponse(audio=base64_chunk,audio_format=content_type,timestamps=chunk_data.word_timestamps)
except Exception as e: except Exception as e:
@ -258,7 +258,7 @@ async def create_captioned_speech(
) )
else: else:
# Generate complete audio using public interface # Generate complete audio using public interface
_, audio_data = await tts_service.generate_audio( audio_data = await tts_service.generate_audio(
text=request.input, text=request.input,
voice=voice_name, voice=voice_name,
speed=request.speed, speed=request.speed,
@ -267,7 +267,7 @@ async def create_captioned_speech(
lang_code=request.lang_code, lang_code=request.lang_code,
) )
content, audio_data = await AudioService.convert_audio( audio_data = await AudioService.convert_audio(
audio_data, audio_data,
24000, 24000,
request.response_format, request.response_format,
@ -276,14 +276,14 @@ async def create_captioned_speech(
) )
# Convert to requested format with proper finalization # Convert to requested format with proper finalization
final, _ = await AudioService.convert_audio( final = await AudioService.convert_audio(
AudioChunk(np.array([], dtype=np.int16)), AudioChunk(np.array([], dtype=np.int16)),
24000, 24000,
request.response_format, request.response_format,
is_first_chunk=False, is_first_chunk=False,
is_last_chunk=True, is_last_chunk=True,
) )
output=content+final output=content.output + final.output
base64_output= base64.b64encode(output).decode("utf-8") base64_output= base64.b64encode(output).decode("utf-8")

View file

@ -132,7 +132,7 @@ async def process_voices(
async def stream_audio_chunks( async def stream_audio_chunks(
tts_service: TTSService, request: Union[OpenAISpeechRequest,CaptionedSpeechRequest], client_request: Request tts_service: TTSService, request: Union[OpenAISpeechRequest,CaptionedSpeechRequest], client_request: Request
) -> AsyncGenerator[Tuple[Union[np.ndarray,bytes],AudioChunk], None]: ) -> AsyncGenerator[AudioChunk, None]:
"""Stream audio chunks as they're generated with client disconnect handling""" """Stream audio chunks as they're generated with client disconnect handling"""
voice_name = await process_voices(request.voice, tts_service) voice_name = await process_voices(request.voice, tts_service)
@ -144,7 +144,7 @@ async def stream_audio_chunks(
try: try:
logger.info(f"Starting audio generation with lang_code: {request.lang_code}") logger.info(f"Starting audio generation with lang_code: {request.lang_code}")
async for chunk, chunk_data in tts_service.generate_audio_stream( async for chunk_data in tts_service.generate_audio_stream(
text=request.input, text=request.input,
voice=voice_name, voice=voice_name,
speed=request.speed, speed=request.speed,
@ -162,7 +162,7 @@ async def stream_audio_chunks(
logger.info("Client disconnected, stopping audio generation") logger.info("Client disconnected, stopping audio generation")
break break
yield (chunk,chunk_data) yield chunk_data
except Exception as e: except Exception as e:
logger.error(f"Error in audio streaming: {str(e)}") logger.error(f"Error in audio streaming: {str(e)}")
# Let the exception propagate to trigger cleanup # Let the exception propagate to trigger cleanup
@ -233,13 +233,13 @@ async def create_speech(
async def dual_output(): async def dual_output():
try: try:
# Write chunks to temp file and stream # Write chunks to temp file and stream
async for chunk,chunk_data in generator: async for chunk_data in generator:
if chunk: # Skip empty chunks if chunk_data.output: # Skip empty chunks
await temp_writer.write(chunk) await temp_writer.write(chunk_data.output)
#if return_json: #if return_json:
# yield chunk, chunk_data # yield chunk, chunk_data
#else: #else:
yield chunk yield chunk_data.output
# Finalize the temp file # Finalize the temp file
await temp_writer.finalize() await temp_writer.finalize()
@ -260,9 +260,9 @@ async def create_speech(
async def single_output(): async def single_output():
try: try:
# Stream chunks # Stream chunks
async for chunk,chunk_data in generator: async for chunk_data in generator:
if chunk: # Skip empty chunks if chunk_data.output: # Skip empty chunks
yield chunk yield chunk_data.output
except Exception as e: except Exception as e:
logger.error(f"Error in single output streaming: {e}") logger.error(f"Error in single output streaming: {e}")
raise raise
@ -280,7 +280,7 @@ async def create_speech(
) )
else: else:
# Generate complete audio using public interface # Generate complete audio using public interface
_, audio_data = await tts_service.generate_audio( audio_data = await tts_service.generate_audio(
text=request.input, text=request.input,
voice=voice_name, voice=voice_name,
speed=request.speed, speed=request.speed,
@ -288,7 +288,7 @@ async def create_speech(
lang_code=request.lang_code, lang_code=request.lang_code,
) )
content, audio_data = await AudioService.convert_audio( audio_data = await AudioService.convert_audio(
audio_data, audio_data,
24000, 24000,
request.response_format, request.response_format,
@ -297,14 +297,14 @@ async def create_speech(
) )
# Convert to requested format with proper finalization # Convert to requested format with proper finalization
final, _ = await AudioService.convert_audio( final = await AudioService.convert_audio(
AudioChunk(np.array([], dtype=np.int16)), AudioChunk(np.array([], dtype=np.int16)),
24000, 24000,
request.response_format, request.response_format,
is_first_chunk=False, is_first_chunk=False,
is_last_chunk=True, is_last_chunk=True,
) )
output=content+final output=audio_data.output + final.output
return Response( return Response(
content=output, content=output,
media_type=content_type, media_type=content_type,

View file

@ -120,7 +120,7 @@ class AudioService:
is_first_chunk: bool = True, is_first_chunk: bool = True,
is_last_chunk: bool = False, is_last_chunk: bool = False,
normalizer: AudioNormalizer = None, normalizer: AudioNormalizer = None,
) -> Tuple[bytes,AudioChunk]: ) -> Tuple[AudioChunk]:
"""Convert audio data to specified format with streaming support """Convert audio data to specified format with streaming support
Args: Args:
@ -165,9 +165,13 @@ class AudioService:
if is_last_chunk: if is_last_chunk:
final_data = writer.write_chunk(finalize=True) final_data = writer.write_chunk(finalize=True)
del AudioService._writers[writer_key] del AudioService._writers[writer_key]
return final_data if final_data else b"", audio_chunk if final_data:
audio_chunk.output=final_data
return audio_chunk
return chunk_data if chunk_data else b"", audio_chunk if chunk_data:
audio_chunk.output=chunk_data
return audio_chunk
except Exception as e: except Exception as e:
logger.error(f"Error converting audio stream to {output_format}: {str(e)}") logger.error(f"Error converting audio stream to {output_format}: {str(e)}")

View file

@ -54,7 +54,7 @@ class TTSService:
normalizer: Optional[AudioNormalizer] = None, normalizer: Optional[AudioNormalizer] = None,
lang_code: Optional[str] = None, lang_code: Optional[str] = None,
return_timestamps: Optional[bool] = False, return_timestamps: Optional[bool] = False,
) -> AsyncGenerator[Tuple[Union[np.ndarray, bytes],AudioChunk], Tuple[None,None]]: ) -> AsyncGenerator[AudioChunk, None]:
"""Process tokens into audio.""" """Process tokens into audio."""
async with self._chunk_semaphore: async with self._chunk_semaphore:
try: try:
@ -62,9 +62,9 @@ class TTSService:
if is_last: if is_last:
# Skip format conversion for raw audio mode # Skip format conversion for raw audio mode
if not output_format: if not output_format:
yield np.array([], dtype=np.int16), AudioChunk(np.array([], dtype=np.int16)) yield AudioChunk(np.array([], dtype=np.int16),output=b'')
return return
result, chunk_data = await AudioService.convert_audio( chunk_data = await AudioService.convert_audio(
AudioChunk(np.array([0], dtype=np.float32)), # Dummy data for type checking AudioChunk(np.array([0], dtype=np.float32)), # Dummy data for type checking
24000, 24000,
output_format, output_format,
@ -74,7 +74,7 @@ class TTSService:
normalizer=normalizer, normalizer=normalizer,
is_last_chunk=True, is_last_chunk=True,
) )
yield result, chunk_data yield chunk_data
return return
# Skip empty chunks # Skip empty chunks
@ -97,7 +97,7 @@ class TTSService:
# For streaming, convert to bytes # For streaming, convert to bytes
if output_format: if output_format:
try: try:
converted, chunk_data = await AudioService.convert_audio( chunk_data = await AudioService.convert_audio(
chunk_data, chunk_data,
24000, 24000,
output_format, output_format,
@ -107,7 +107,7 @@ class TTSService:
is_last_chunk=is_last, is_last_chunk=is_last,
normalizer=normalizer, normalizer=normalizer,
) )
yield converted, chunk_data yield chunk_data
except Exception as e: except Exception as e:
logger.error(f"Failed to convert audio: {str(e)}") logger.error(f"Failed to convert audio: {str(e)}")
else: else:
@ -116,7 +116,7 @@ class TTSService:
speed, speed,
is_last, is_last,
normalizer) normalizer)
yield chunk_data.audio, chunk_data yield chunk_data
else: else:
# For legacy backends, load voice tensor # For legacy backends, load voice tensor
@ -138,7 +138,7 @@ class TTSService:
# For streaming, convert to bytes # For streaming, convert to bytes
if output_format: if output_format:
try: try:
converted, chunk_data = await AudioService.convert_audio( chunk_data = await AudioService.convert_audio(
chunk_data, chunk_data,
24000, 24000,
output_format, output_format,
@ -148,7 +148,7 @@ class TTSService:
normalizer=normalizer, normalizer=normalizer,
is_last_chunk=is_last, is_last_chunk=is_last,
) )
yield converted, chunk_data yield chunk_data
except Exception as e: except Exception as e:
logger.error(f"Failed to convert audio: {str(e)}") logger.error(f"Failed to convert audio: {str(e)}")
else: else:
@ -157,7 +157,7 @@ class TTSService:
speed, speed,
is_last, is_last,
normalizer) normalizer)
yield trimmed.audio, trimmed yield trimmed
except Exception as e: except Exception as e:
logger.error(f"Failed to process tokens: {str(e)}") logger.error(f"Failed to process tokens: {str(e)}")
@ -243,7 +243,7 @@ class TTSService:
lang_code: Optional[str] = None, lang_code: Optional[str] = None,
normalization_options: Optional[NormalizationOptions] = NormalizationOptions(), normalization_options: Optional[NormalizationOptions] = NormalizationOptions(),
return_timestamps: Optional[bool] = False, return_timestamps: Optional[bool] = False,
) -> AsyncGenerator[Tuple[bytes,AudioChunk], None]: ) -> AsyncGenerator[AudioChunk, None]:
"""Generate and stream audio chunks.""" """Generate and stream audio chunks."""
stream_normalizer = AudioNormalizer() stream_normalizer = AudioNormalizer()
chunk_index = 0 chunk_index = 0
@ -267,7 +267,7 @@ class TTSService:
async for chunk_text, tokens in smart_split(text,lang_code=lang_code,normalization_options=normalization_options): async for chunk_text, tokens in smart_split(text,lang_code=lang_code,normalization_options=normalization_options):
try: try:
# Process audio for chunk # Process audio for chunk
async for result, chunk_data in self._process_chunk( async for chunk_data in self._process_chunk(
chunk_text, # Pass text for Kokoro V1 chunk_text, # Pass text for Kokoro V1
tokens, # Pass tokens for legacy backends tokens, # Pass tokens for legacy backends
voice_name, # Pass voice name voice_name, # Pass voice name
@ -287,8 +287,8 @@ class TTSService:
current_offset+=len(chunk_data.audio) / 24000 current_offset+=len(chunk_data.audio) / 24000
if result is not None: if chunk_data.output is not None:
yield result,chunk_data yield chunk_data
chunk_index += 1 chunk_index += 1
else: else:
logger.warning( logger.warning(
@ -305,7 +305,7 @@ class TTSService:
if chunk_index > 0: if chunk_index > 0:
try: try:
# Empty tokens list to finalize audio # Empty tokens list to finalize audio
async for result,chunk_data in self._process_chunk( async for chunk_data in self._process_chunk(
"", # Empty text "", # Empty text
[], # Empty tokens [], # Empty tokens
voice_name, voice_name,
@ -317,8 +317,8 @@ class TTSService:
normalizer=stream_normalizer, normalizer=stream_normalizer,
lang_code=pipeline_lang_code, # Pass lang_code lang_code=pipeline_lang_code, # Pass lang_code
): ):
if result is not None: if chunk_data.output is not None:
yield result, chunk_data yield chunk_data
except Exception as e: except Exception as e:
logger.error(f"Failed to finalize audio stream: {str(e)}") logger.error(f"Failed to finalize audio stream: {str(e)}")
@ -335,17 +335,17 @@ class TTSService:
return_timestamps: bool = False, return_timestamps: bool = False,
normalization_options: Optional[NormalizationOptions] = NormalizationOptions(), normalization_options: Optional[NormalizationOptions] = NormalizationOptions(),
lang_code: Optional[str] = None, lang_code: Optional[str] = None,
) -> Tuple[Tuple[np.ndarray,AudioChunk]]: ) -> AudioChunk:
"""Generate complete audio for text using streaming internally.""" """Generate complete audio for text using streaming internally."""
audio_data_chunks=[] audio_data_chunks=[]
try: try:
async for _,audio_stream_data in self.generate_audio_stream(text,voice,speed=speed,normalization_options=normalization_options,return_timestamps=return_timestamps,lang_code=lang_code,output_format=None): async for audio_stream_data in self.generate_audio_stream(text,voice,speed=speed,normalization_options=normalization_options,return_timestamps=return_timestamps,lang_code=lang_code,output_format=None):
audio_data_chunks.append(audio_stream_data) audio_data_chunks.append(audio_stream_data)
combined_audio_data=AudioChunk.combine(audio_data_chunks) combined_audio_data=AudioChunk.combine(audio_data_chunks)
return combined_audio_data.audio,combined_audio_data return combined_audio_data
except Exception as e: except Exception as e:
logger.error(f"Error in audio generation: {str(e)}") logger.error(f"Error in audio generation: {str(e)}")
raise raise

View file

@ -31,83 +31,83 @@ async def test_convert_to_wav(sample_audio):
"""Test converting to WAV format""" """Test converting to WAV format"""
audio_data, sample_rate = sample_audio audio_data, sample_rate = sample_audio
# Write and finalize in one step for WAV # Write and finalize in one step for WAV
result, audio_chunk = await AudioService.convert_audio( audio_chunk = await AudioService.convert_audio(
AudioChunk(audio_data), sample_rate, "wav", is_first_chunk=True, is_last_chunk=True AudioChunk(audio_data), sample_rate, "wav", is_first_chunk=True, is_last_chunk=True
) )
assert isinstance(result, bytes) assert isinstance(audio_chunk.output, bytes)
assert isinstance(audio_chunk, AudioChunk) assert isinstance(audio_chunk, AudioChunk)
assert len(result) > 0 assert len(audio_chunk.output) > 0
# Check WAV header # Check WAV header
assert result.startswith(b"RIFF") assert audio_chunk.output.startswith(b"RIFF")
assert b"WAVE" in result[:12] assert b"WAVE" in audio_chunk.output[:12]
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_convert_to_mp3(sample_audio): async def test_convert_to_mp3(sample_audio):
"""Test converting to MP3 format""" """Test converting to MP3 format"""
audio_data, sample_rate = sample_audio audio_data, sample_rate = sample_audio
result, audio_chunk = await AudioService.convert_audio( audio_chunk = await AudioService.convert_audio(
AudioChunk(audio_data), sample_rate, "mp3" AudioChunk(audio_data), sample_rate, "mp3"
) )
assert isinstance(result, bytes) assert isinstance(audio_chunk.output, bytes)
assert isinstance(audio_chunk, AudioChunk) assert isinstance(audio_chunk, AudioChunk)
assert len(result) > 0 assert len(audio_chunk.output) > 0
# Check MP3 header (ID3 or MPEG frame sync) # Check MP3 header (ID3 or MPEG frame sync)
assert result.startswith(b"ID3") or result.startswith(b"\xff\xfb") assert audio_chunk.output.startswith(b"ID3") or audio_chunk.output.startswith(b"\xff\xfb")
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_convert_to_opus(sample_audio): async def test_convert_to_opus(sample_audio):
"""Test converting to Opus format""" """Test converting to Opus format"""
audio_data, sample_rate = sample_audio audio_data, sample_rate = sample_audio
result, audio_chunk = await AudioService.convert_audio( audio_chunk = await AudioService.convert_audio(
AudioChunk(audio_data), sample_rate, "opus" AudioChunk(audio_data), sample_rate, "opus"
) )
assert isinstance(result, bytes) assert isinstance(audio_chunk.output, bytes)
assert isinstance(audio_chunk, AudioChunk) assert isinstance(audio_chunk, AudioChunk)
assert len(result) > 0 assert len(audio_chunk.output) > 0
# Check OGG header # Check OGG header
assert result.startswith(b"OggS") assert audio_chunk.output.startswith(b"OggS")
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_convert_to_flac(sample_audio): async def test_convert_to_flac(sample_audio):
"""Test converting to FLAC format""" """Test converting to FLAC format"""
audio_data, sample_rate = sample_audio audio_data, sample_rate = sample_audio
result, audio_chunk = await AudioService.convert_audio( audio_chunk = await AudioService.convert_audio(
AudioChunk(audio_data), sample_rate, "flac" AudioChunk(audio_data), sample_rate, "flac"
) )
assert isinstance(result, bytes) assert isinstance(audio_chunk.output, bytes)
assert isinstance(audio_chunk, AudioChunk) assert isinstance(audio_chunk, AudioChunk)
assert len(result) > 0 assert len(audio_chunk.output) > 0
# Check FLAC header # Check FLAC header
assert result.startswith(b"fLaC") assert audio_chunk.output.startswith(b"fLaC")
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_convert_to_aac(sample_audio): async def test_convert_to_aac(sample_audio):
"""Test converting to AAC format""" """Test converting to AAC format"""
audio_data, sample_rate = sample_audio audio_data, sample_rate = sample_audio
result, audio_chunk = await AudioService.convert_audio( audio_chunk = await AudioService.convert_audio(
AudioChunk(audio_data), sample_rate, "aac" AudioChunk(audio_data), sample_rate, "aac"
) )
assert isinstance(result, bytes) assert isinstance(audio_chunk.output, bytes)
assert isinstance(audio_chunk, AudioChunk) assert isinstance(audio_chunk, AudioChunk)
assert len(result) > 0 assert len(audio_chunk.output) > 0
# Check ADTS header (AAC) # Check ADTS header (AAC)
assert result.startswith(b"\xff\xf0") or result.startswith(b"\xff\xf1") assert audio_chunk.output.startswith(b"\xff\xf0") or audio_chunk.output.startswith(b"\xff\xf1")
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_convert_to_pcm(sample_audio): async def test_convert_to_pcm(sample_audio):
"""Test converting to PCM format""" """Test converting to PCM format"""
audio_data, sample_rate = sample_audio audio_data, sample_rate = sample_audio
result, audio_chunk = await AudioService.convert_audio( audio_chunk = await AudioService.convert_audio(
AudioChunk(audio_data), sample_rate, "pcm" AudioChunk(audio_data), sample_rate, "pcm"
) )
assert isinstance(result, bytes) assert isinstance(audio_chunk.output, bytes)
assert isinstance(audio_chunk, AudioChunk) assert isinstance(audio_chunk, AudioChunk)
assert len(result) > 0 assert len(audio_chunk.output) > 0
# PCM is raw bytes, so no header to check # PCM is raw bytes, so no header to check
@ -126,12 +126,12 @@ async def test_normalization_wav(sample_audio):
# Create audio data outside int16 range # Create audio data outside int16 range
large_audio = audio_data * 1e5 large_audio = audio_data * 1e5
# Write and finalize in one step for WAV # Write and finalize in one step for WAV
result, audio_chunk = await AudioService.convert_audio( audio_chunk = await AudioService.convert_audio(
AudioChunk(large_audio), sample_rate, "wav", is_first_chunk=True, is_last_chunk=True AudioChunk(large_audio), sample_rate, "wav", is_first_chunk=True, is_last_chunk=True
) )
assert isinstance(result, bytes) assert isinstance(audio_chunk.output, bytes)
assert isinstance(audio_chunk, AudioChunk) assert isinstance(audio_chunk, AudioChunk)
assert len(result) > 0 assert len(audio_chunk.output) > 0
@pytest.mark.asyncio @pytest.mark.asyncio
@ -140,12 +140,12 @@ async def test_normalization_pcm(sample_audio):
audio_data, sample_rate = sample_audio audio_data, sample_rate = sample_audio
# Create audio data outside int16 range # Create audio data outside int16 range
large_audio = audio_data * 1e5 large_audio = audio_data * 1e5
result, audio_chunk = await AudioService.convert_audio( audio_chunk = await AudioService.convert_audio(
AudioChunk(large_audio), sample_rate, "pcm" AudioChunk(large_audio), sample_rate, "pcm"
) )
assert isinstance(result, bytes) assert isinstance(audio_chunk.output, bytes)
assert isinstance(audio_chunk, AudioChunk) assert isinstance(audio_chunk, AudioChunk)
assert len(result) > 0 assert len(audio_chunk.output) > 0
@pytest.mark.asyncio @pytest.mark.asyncio
@ -164,12 +164,12 @@ async def test_different_sample_rates(sample_audio):
sample_rates = [8000, 16000, 44100, 48000] sample_rates = [8000, 16000, 44100, 48000]
for rate in sample_rates: for rate in sample_rates:
result, audio_chunk = await AudioService.convert_audio( audio_chunk = await AudioService.convert_audio(
AudioChunk(audio_data), rate, "wav", is_first_chunk=True, is_last_chunk=True AudioChunk(audio_data), rate, "wav", is_first_chunk=True, is_last_chunk=True
) )
assert isinstance(result, bytes) assert isinstance(audio_chunk.output, bytes)
assert isinstance(audio_chunk, AudioChunk) assert isinstance(audio_chunk, AudioChunk)
assert len(result) > 0 assert len(audio_chunk.output) > 0
@pytest.mark.asyncio @pytest.mark.asyncio
@ -177,15 +177,15 @@ async def test_buffer_position_after_conversion(sample_audio):
"""Test that buffer position is reset after writing""" """Test that buffer position is reset after writing"""
audio_data, sample_rate = sample_audio audio_data, sample_rate = sample_audio
# Write and finalize in one step for first conversion # Write and finalize in one step for first conversion
result1, audio_chunk = await AudioService.convert_audio( audio_chunk1 = await AudioService.convert_audio(
AudioChunk(audio_data), sample_rate, "wav", is_first_chunk=True, is_last_chunk=True AudioChunk(audio_data), sample_rate, "wav", is_first_chunk=True, is_last_chunk=True
) )
assert isinstance(result1, bytes) assert isinstance(audio_chunk1.output, bytes)
assert isinstance(audio_chunk, AudioChunk) assert isinstance(audio_chunk1, AudioChunk)
# Convert again to ensure buffer was properly reset # Convert again to ensure buffer was properly reset
result2, audio_chunk = await AudioService.convert_audio( audio_chunk2 = await AudioService.convert_audio(
AudioChunk(audio_data), sample_rate, "wav", is_first_chunk=True, is_last_chunk=True AudioChunk(audio_data), sample_rate, "wav", is_first_chunk=True, is_last_chunk=True
) )
assert isinstance(result2, bytes) assert isinstance(audio_chunk2.output, bytes)
assert isinstance(audio_chunk, AudioChunk) assert isinstance(audio_chunk2, AudioChunk)
assert len(result1) == len(result2) assert len(audio_chunk1.output) == len(audio_chunk2.output)

View file

@ -145,7 +145,7 @@ async def test_stream_audio_chunks_client_disconnect():
async def mock_stream(*args, **kwargs): async def mock_stream(*args, **kwargs):
for i in range(5): for i in range(5):
yield (b"chunk",AudioChunk(np.ndarray([],np.int16))) yield AudioChunk(np.ndarray([],np.int16),output=b"chunk")
mock_service.generate_audio_stream = mock_stream mock_service.generate_audio_stream = mock_stream
mock_service.list_voices.return_value = ["test_voice"] mock_service.list_voices.return_value = ["test_voice"]
@ -160,7 +160,7 @@ async def test_stream_audio_chunks_client_disconnect():
) )
chunks = [] chunks = []
async for chunk, _ in stream_audio_chunks(mock_service, request, mock_request): async for chunk in stream_audio_chunks(mock_service, request, mock_request):
chunks.append(chunk) chunks.append(chunk)
assert len(chunks) == 0 # Should stop immediately due to disconnect assert len(chunks) == 0 # Should stop immediately due to disconnect
@ -237,10 +237,10 @@ def mock_tts_service(mock_audio_bytes):
"""Mock TTS service for testing.""" """Mock TTS service for testing."""
with patch("api.src.routers.openai_compatible.get_tts_service") as mock_get: with patch("api.src.routers.openai_compatible.get_tts_service") as mock_get:
service = AsyncMock(spec=TTSService) service = AsyncMock(spec=TTSService)
service.generate_audio.return_value = (np.zeros(1000), AudioChunk(np.zeros(1000,np.int16))) service.generate_audio.return_value = AudioChunk(np.zeros(1000,np.int16))
async def mock_stream(*args, **kwargs) -> AsyncGenerator[Tuple[bytes,AudioChunk], None]: async def mock_stream(*args, **kwargs) -> AsyncGenerator[AudioChunk, None]:
yield (mock_audio_bytes, AudioChunk(np.ndarray([],np.int16))) yield AudioChunk(np.ndarray([],np.int16),output=mock_audio_bytes)
service.generate_audio_stream = mock_stream service.generate_audio_stream = mock_stream
service.list_voices.return_value = ["test_voice", "voice1", "voice2"] service.list_voices.return_value = ["test_voice", "voice1", "voice2"]
@ -257,8 +257,8 @@ def test_openai_speech_endpoint(
): ):
"""Test the OpenAI-compatible speech endpoint with basic MP3 generation""" """Test the OpenAI-compatible speech endpoint with basic MP3 generation"""
# Configure mocks # Configure mocks
mock_tts_service.generate_audio.return_value = (np.zeros(1000), AudioChunk(np.zeros(1000,np.int16))) mock_tts_service.generate_audio.return_value = AudioChunk(np.zeros(1000,np.int16))
mock_convert.return_value = (mock_audio_bytes,AudioChunk(np.zeros(1000,np.int16))) mock_convert.return_value = AudioChunk(np.zeros(1000,np.int16),output=mock_audio_bytes)
response = client.post( response = client.post(
"/v1/audio/speech", "/v1/audio/speech",