Simplify code so everything uses AudioChunks

This commit is contained in:
Fireblade 2025-02-16 15:37:01 -05:00
parent 9c0e328318
commit e3dc959775
7 changed files with 103 additions and 97 deletions

View file

@ -11,10 +11,12 @@ class AudioChunk:
def __init__(self,
audio: np.ndarray,
word_timestamps: Optional[List]=[]
word_timestamps: Optional[List]=[],
output: Optional[Union[bytes,np.ndarray]]=b""
):
self.audio=audio
self.word_timestamps=word_timestamps
self.output=output
@staticmethod
def combine(audio_chunk_list: List):

View file

@ -209,10 +209,10 @@ async def create_captioned_speech(
async def dual_output():
try:
# Write chunks to temp file and stream
async for chunk,chunk_data in generator:
if chunk: # Skip empty chunks
await temp_writer.write(chunk)
base64_chunk= base64.b64encode(chunk).decode("utf-8")
async for chunk_data in generator:
if chunk_data.output: # Skip empty chunks
await temp_writer.write(chunk_data.output)
base64_chunk= base64.b64encode(chunk_data.output).decode("utf-8")
yield CaptionedSpeechResponse(audio=base64_chunk,audio_format=content_type,timestamps=chunk_data.word_timestamps)
@ -235,10 +235,10 @@ async def create_captioned_speech(
async def single_output():
try:
# Stream chunks
async for chunk,chunk_data in generator:
if chunk: # Skip empty chunks
async for chunk_data in generator:
if chunk_data.output: # Skip empty chunks
# Encode the chunk bytes into base 64
base64_chunk= base64.b64encode(chunk).decode("utf-8")
base64_chunk= base64.b64encode(chunk_data.output).decode("utf-8")
yield CaptionedSpeechResponse(audio=base64_chunk,audio_format=content_type,timestamps=chunk_data.word_timestamps)
except Exception as e:
@ -258,7 +258,7 @@ async def create_captioned_speech(
)
else:
# Generate complete audio using public interface
_, audio_data = await tts_service.generate_audio(
audio_data = await tts_service.generate_audio(
text=request.input,
voice=voice_name,
speed=request.speed,
@ -267,7 +267,7 @@ async def create_captioned_speech(
lang_code=request.lang_code,
)
content, audio_data = await AudioService.convert_audio(
audio_data = await AudioService.convert_audio(
audio_data,
24000,
request.response_format,
@ -276,14 +276,14 @@ async def create_captioned_speech(
)
# Convert to requested format with proper finalization
final, _ = await AudioService.convert_audio(
final = await AudioService.convert_audio(
AudioChunk(np.array([], dtype=np.int16)),
24000,
request.response_format,
is_first_chunk=False,
is_last_chunk=True,
)
output=content+final
output=content.output + final.output
base64_output= base64.b64encode(output).decode("utf-8")

View file

@ -132,7 +132,7 @@ async def process_voices(
async def stream_audio_chunks(
tts_service: TTSService, request: Union[OpenAISpeechRequest,CaptionedSpeechRequest], client_request: Request
) -> AsyncGenerator[Tuple[Union[np.ndarray,bytes],AudioChunk], None]:
) -> AsyncGenerator[AudioChunk, None]:
"""Stream audio chunks as they're generated with client disconnect handling"""
voice_name = await process_voices(request.voice, tts_service)
@ -144,7 +144,7 @@ async def stream_audio_chunks(
try:
logger.info(f"Starting audio generation with lang_code: {request.lang_code}")
async for chunk, chunk_data in tts_service.generate_audio_stream(
async for chunk_data in tts_service.generate_audio_stream(
text=request.input,
voice=voice_name,
speed=request.speed,
@ -162,7 +162,7 @@ async def stream_audio_chunks(
logger.info("Client disconnected, stopping audio generation")
break
yield (chunk,chunk_data)
yield chunk_data
except Exception as e:
logger.error(f"Error in audio streaming: {str(e)}")
# Let the exception propagate to trigger cleanup
@ -233,13 +233,13 @@ async def create_speech(
async def dual_output():
try:
# Write chunks to temp file and stream
async for chunk,chunk_data in generator:
if chunk: # Skip empty chunks
await temp_writer.write(chunk)
async for chunk_data in generator:
if chunk_data.output: # Skip empty chunks
await temp_writer.write(chunk_data.output)
#if return_json:
# yield chunk, chunk_data
#else:
yield chunk
yield chunk_data.output
# Finalize the temp file
await temp_writer.finalize()
@ -260,9 +260,9 @@ async def create_speech(
async def single_output():
try:
# Stream chunks
async for chunk,chunk_data in generator:
if chunk: # Skip empty chunks
yield chunk
async for chunk_data in generator:
if chunk_data.output: # Skip empty chunks
yield chunk_data.output
except Exception as e:
logger.error(f"Error in single output streaming: {e}")
raise
@ -280,7 +280,7 @@ async def create_speech(
)
else:
# Generate complete audio using public interface
_, audio_data = await tts_service.generate_audio(
audio_data = await tts_service.generate_audio(
text=request.input,
voice=voice_name,
speed=request.speed,
@ -288,7 +288,7 @@ async def create_speech(
lang_code=request.lang_code,
)
content, audio_data = await AudioService.convert_audio(
audio_data = await AudioService.convert_audio(
audio_data,
24000,
request.response_format,
@ -297,14 +297,14 @@ async def create_speech(
)
# Convert to requested format with proper finalization
final, _ = await AudioService.convert_audio(
final = await AudioService.convert_audio(
AudioChunk(np.array([], dtype=np.int16)),
24000,
request.response_format,
is_first_chunk=False,
is_last_chunk=True,
)
output=content+final
output=audio_data.output + final.output
return Response(
content=output,
media_type=content_type,

View file

@ -120,7 +120,7 @@ class AudioService:
is_first_chunk: bool = True,
is_last_chunk: bool = False,
normalizer: AudioNormalizer = None,
) -> Tuple[bytes,AudioChunk]:
) -> Tuple[AudioChunk]:
"""Convert audio data to specified format with streaming support
Args:
@ -165,9 +165,13 @@ class AudioService:
if is_last_chunk:
final_data = writer.write_chunk(finalize=True)
del AudioService._writers[writer_key]
return final_data if final_data else b"", audio_chunk
if final_data:
audio_chunk.output=final_data
return audio_chunk
return chunk_data if chunk_data else b"", audio_chunk
if chunk_data:
audio_chunk.output=chunk_data
return audio_chunk
except Exception as e:
logger.error(f"Error converting audio stream to {output_format}: {str(e)}")

View file

@ -54,7 +54,7 @@ class TTSService:
normalizer: Optional[AudioNormalizer] = None,
lang_code: Optional[str] = None,
return_timestamps: Optional[bool] = False,
) -> AsyncGenerator[Tuple[Union[np.ndarray, bytes],AudioChunk], Tuple[None,None]]:
) -> AsyncGenerator[AudioChunk, None]:
"""Process tokens into audio."""
async with self._chunk_semaphore:
try:
@ -62,9 +62,9 @@ class TTSService:
if is_last:
# Skip format conversion for raw audio mode
if not output_format:
yield np.array([], dtype=np.int16), AudioChunk(np.array([], dtype=np.int16))
yield AudioChunk(np.array([], dtype=np.int16),output=b'')
return
result, chunk_data = await AudioService.convert_audio(
chunk_data = await AudioService.convert_audio(
AudioChunk(np.array([0], dtype=np.float32)), # Dummy data for type checking
24000,
output_format,
@ -74,7 +74,7 @@ class TTSService:
normalizer=normalizer,
is_last_chunk=True,
)
yield result, chunk_data
yield chunk_data
return
# Skip empty chunks
@ -97,7 +97,7 @@ class TTSService:
# For streaming, convert to bytes
if output_format:
try:
converted, chunk_data = await AudioService.convert_audio(
chunk_data = await AudioService.convert_audio(
chunk_data,
24000,
output_format,
@ -107,7 +107,7 @@ class TTSService:
is_last_chunk=is_last,
normalizer=normalizer,
)
yield converted, chunk_data
yield chunk_data
except Exception as e:
logger.error(f"Failed to convert audio: {str(e)}")
else:
@ -116,7 +116,7 @@ class TTSService:
speed,
is_last,
normalizer)
yield chunk_data.audio, chunk_data
yield chunk_data
else:
# For legacy backends, load voice tensor
@ -138,7 +138,7 @@ class TTSService:
# For streaming, convert to bytes
if output_format:
try:
converted, chunk_data = await AudioService.convert_audio(
chunk_data = await AudioService.convert_audio(
chunk_data,
24000,
output_format,
@ -148,7 +148,7 @@ class TTSService:
normalizer=normalizer,
is_last_chunk=is_last,
)
yield converted, chunk_data
yield chunk_data
except Exception as e:
logger.error(f"Failed to convert audio: {str(e)}")
else:
@ -157,7 +157,7 @@ class TTSService:
speed,
is_last,
normalizer)
yield trimmed.audio, trimmed
yield trimmed
except Exception as e:
logger.error(f"Failed to process tokens: {str(e)}")
@ -243,7 +243,7 @@ class TTSService:
lang_code: Optional[str] = None,
normalization_options: Optional[NormalizationOptions] = NormalizationOptions(),
return_timestamps: Optional[bool] = False,
) -> AsyncGenerator[Tuple[bytes,AudioChunk], None]:
) -> AsyncGenerator[AudioChunk, None]:
"""Generate and stream audio chunks."""
stream_normalizer = AudioNormalizer()
chunk_index = 0
@ -267,7 +267,7 @@ class TTSService:
async for chunk_text, tokens in smart_split(text,lang_code=lang_code,normalization_options=normalization_options):
try:
# Process audio for chunk
async for result, chunk_data in self._process_chunk(
async for chunk_data in self._process_chunk(
chunk_text, # Pass text for Kokoro V1
tokens, # Pass tokens for legacy backends
voice_name, # Pass voice name
@ -287,8 +287,8 @@ class TTSService:
current_offset+=len(chunk_data.audio) / 24000
if result is not None:
yield result,chunk_data
if chunk_data.output is not None:
yield chunk_data
chunk_index += 1
else:
logger.warning(
@ -305,7 +305,7 @@ class TTSService:
if chunk_index > 0:
try:
# Empty tokens list to finalize audio
async for result,chunk_data in self._process_chunk(
async for chunk_data in self._process_chunk(
"", # Empty text
[], # Empty tokens
voice_name,
@ -317,8 +317,8 @@ class TTSService:
normalizer=stream_normalizer,
lang_code=pipeline_lang_code, # Pass lang_code
):
if result is not None:
yield result, chunk_data
if chunk_data.output is not None:
yield chunk_data
except Exception as e:
logger.error(f"Failed to finalize audio stream: {str(e)}")
@ -335,17 +335,17 @@ class TTSService:
return_timestamps: bool = False,
normalization_options: Optional[NormalizationOptions] = NormalizationOptions(),
lang_code: Optional[str] = None,
) -> Tuple[Tuple[np.ndarray,AudioChunk]]:
) -> AudioChunk:
"""Generate complete audio for text using streaming internally."""
audio_data_chunks=[]
try:
async for _,audio_stream_data in self.generate_audio_stream(text,voice,speed=speed,normalization_options=normalization_options,return_timestamps=return_timestamps,lang_code=lang_code,output_format=None):
async for audio_stream_data in self.generate_audio_stream(text,voice,speed=speed,normalization_options=normalization_options,return_timestamps=return_timestamps,lang_code=lang_code,output_format=None):
audio_data_chunks.append(audio_stream_data)
combined_audio_data=AudioChunk.combine(audio_data_chunks)
return combined_audio_data.audio,combined_audio_data
return combined_audio_data
except Exception as e:
logger.error(f"Error in audio generation: {str(e)}")
raise

View file

@ -31,83 +31,83 @@ async def test_convert_to_wav(sample_audio):
"""Test converting to WAV format"""
audio_data, sample_rate = sample_audio
# Write and finalize in one step for WAV
result, audio_chunk = await AudioService.convert_audio(
audio_chunk = await AudioService.convert_audio(
AudioChunk(audio_data), sample_rate, "wav", is_first_chunk=True, is_last_chunk=True
)
assert isinstance(result, bytes)
assert isinstance(audio_chunk.output, bytes)
assert isinstance(audio_chunk, AudioChunk)
assert len(result) > 0
assert len(audio_chunk.output) > 0
# Check WAV header
assert result.startswith(b"RIFF")
assert b"WAVE" in result[:12]
assert audio_chunk.output.startswith(b"RIFF")
assert b"WAVE" in audio_chunk.output[:12]
@pytest.mark.asyncio
async def test_convert_to_mp3(sample_audio):
"""Test converting to MP3 format"""
audio_data, sample_rate = sample_audio
result, audio_chunk = await AudioService.convert_audio(
audio_chunk = await AudioService.convert_audio(
AudioChunk(audio_data), sample_rate, "mp3"
)
assert isinstance(result, bytes)
assert isinstance(audio_chunk.output, bytes)
assert isinstance(audio_chunk, AudioChunk)
assert len(result) > 0
assert len(audio_chunk.output) > 0
# Check MP3 header (ID3 or MPEG frame sync)
assert result.startswith(b"ID3") or result.startswith(b"\xff\xfb")
assert audio_chunk.output.startswith(b"ID3") or audio_chunk.output.startswith(b"\xff\xfb")
@pytest.mark.asyncio
async def test_convert_to_opus(sample_audio):
"""Test converting to Opus format"""
audio_data, sample_rate = sample_audio
result, audio_chunk = await AudioService.convert_audio(
audio_chunk = await AudioService.convert_audio(
AudioChunk(audio_data), sample_rate, "opus"
)
assert isinstance(result, bytes)
assert isinstance(audio_chunk.output, bytes)
assert isinstance(audio_chunk, AudioChunk)
assert len(result) > 0
assert len(audio_chunk.output) > 0
# Check OGG header
assert result.startswith(b"OggS")
assert audio_chunk.output.startswith(b"OggS")
@pytest.mark.asyncio
async def test_convert_to_flac(sample_audio):
"""Test converting to FLAC format"""
audio_data, sample_rate = sample_audio
result, audio_chunk = await AudioService.convert_audio(
audio_chunk = await AudioService.convert_audio(
AudioChunk(audio_data), sample_rate, "flac"
)
assert isinstance(result, bytes)
assert isinstance(audio_chunk.output, bytes)
assert isinstance(audio_chunk, AudioChunk)
assert len(result) > 0
assert len(audio_chunk.output) > 0
# Check FLAC header
assert result.startswith(b"fLaC")
assert audio_chunk.output.startswith(b"fLaC")
@pytest.mark.asyncio
async def test_convert_to_aac(sample_audio):
"""Test converting to AAC format"""
audio_data, sample_rate = sample_audio
result, audio_chunk = await AudioService.convert_audio(
audio_chunk = await AudioService.convert_audio(
AudioChunk(audio_data), sample_rate, "aac"
)
assert isinstance(result, bytes)
assert isinstance(audio_chunk.output, bytes)
assert isinstance(audio_chunk, AudioChunk)
assert len(result) > 0
assert len(audio_chunk.output) > 0
# Check ADTS header (AAC)
assert result.startswith(b"\xff\xf0") or result.startswith(b"\xff\xf1")
assert audio_chunk.output.startswith(b"\xff\xf0") or audio_chunk.output.startswith(b"\xff\xf1")
@pytest.mark.asyncio
async def test_convert_to_pcm(sample_audio):
"""Test converting to PCM format"""
audio_data, sample_rate = sample_audio
result, audio_chunk = await AudioService.convert_audio(
audio_chunk = await AudioService.convert_audio(
AudioChunk(audio_data), sample_rate, "pcm"
)
assert isinstance(result, bytes)
assert isinstance(audio_chunk.output, bytes)
assert isinstance(audio_chunk, AudioChunk)
assert len(result) > 0
assert len(audio_chunk.output) > 0
# PCM is raw bytes, so no header to check
@ -126,12 +126,12 @@ async def test_normalization_wav(sample_audio):
# Create audio data outside int16 range
large_audio = audio_data * 1e5
# Write and finalize in one step for WAV
result, audio_chunk = await AudioService.convert_audio(
audio_chunk = await AudioService.convert_audio(
AudioChunk(large_audio), sample_rate, "wav", is_first_chunk=True, is_last_chunk=True
)
assert isinstance(result, bytes)
assert isinstance(audio_chunk.output, bytes)
assert isinstance(audio_chunk, AudioChunk)
assert len(result) > 0
assert len(audio_chunk.output) > 0
@pytest.mark.asyncio
@ -140,12 +140,12 @@ async def test_normalization_pcm(sample_audio):
audio_data, sample_rate = sample_audio
# Create audio data outside int16 range
large_audio = audio_data * 1e5
result, audio_chunk = await AudioService.convert_audio(
audio_chunk = await AudioService.convert_audio(
AudioChunk(large_audio), sample_rate, "pcm"
)
assert isinstance(result, bytes)
assert isinstance(audio_chunk.output, bytes)
assert isinstance(audio_chunk, AudioChunk)
assert len(result) > 0
assert len(audio_chunk.output) > 0
@pytest.mark.asyncio
@ -164,12 +164,12 @@ async def test_different_sample_rates(sample_audio):
sample_rates = [8000, 16000, 44100, 48000]
for rate in sample_rates:
result, audio_chunk = await AudioService.convert_audio(
audio_chunk = await AudioService.convert_audio(
AudioChunk(audio_data), rate, "wav", is_first_chunk=True, is_last_chunk=True
)
assert isinstance(result, bytes)
assert isinstance(audio_chunk.output, bytes)
assert isinstance(audio_chunk, AudioChunk)
assert len(result) > 0
assert len(audio_chunk.output) > 0
@pytest.mark.asyncio
@ -177,15 +177,15 @@ async def test_buffer_position_after_conversion(sample_audio):
"""Test that buffer position is reset after writing"""
audio_data, sample_rate = sample_audio
# Write and finalize in one step for first conversion
result1, audio_chunk = await AudioService.convert_audio(
audio_chunk1 = await AudioService.convert_audio(
AudioChunk(audio_data), sample_rate, "wav", is_first_chunk=True, is_last_chunk=True
)
assert isinstance(result1, bytes)
assert isinstance(audio_chunk, AudioChunk)
assert isinstance(audio_chunk1.output, bytes)
assert isinstance(audio_chunk1, AudioChunk)
# Convert again to ensure buffer was properly reset
result2, audio_chunk = await AudioService.convert_audio(
audio_chunk2 = await AudioService.convert_audio(
AudioChunk(audio_data), sample_rate, "wav", is_first_chunk=True, is_last_chunk=True
)
assert isinstance(result2, bytes)
assert isinstance(audio_chunk, AudioChunk)
assert len(result1) == len(result2)
assert isinstance(audio_chunk2.output, bytes)
assert isinstance(audio_chunk2, AudioChunk)
assert len(audio_chunk1.output) == len(audio_chunk2.output)

View file

@ -145,7 +145,7 @@ async def test_stream_audio_chunks_client_disconnect():
async def mock_stream(*args, **kwargs):
for i in range(5):
yield (b"chunk",AudioChunk(np.ndarray([],np.int16)))
yield AudioChunk(np.ndarray([],np.int16),output=b"chunk")
mock_service.generate_audio_stream = mock_stream
mock_service.list_voices.return_value = ["test_voice"]
@ -160,7 +160,7 @@ async def test_stream_audio_chunks_client_disconnect():
)
chunks = []
async for chunk, _ in stream_audio_chunks(mock_service, request, mock_request):
async for chunk in stream_audio_chunks(mock_service, request, mock_request):
chunks.append(chunk)
assert len(chunks) == 0 # Should stop immediately due to disconnect
@ -237,10 +237,10 @@ def mock_tts_service(mock_audio_bytes):
"""Mock TTS service for testing."""
with patch("api.src.routers.openai_compatible.get_tts_service") as mock_get:
service = AsyncMock(spec=TTSService)
service.generate_audio.return_value = (np.zeros(1000), AudioChunk(np.zeros(1000,np.int16)))
service.generate_audio.return_value = AudioChunk(np.zeros(1000,np.int16))
async def mock_stream(*args, **kwargs) -> AsyncGenerator[Tuple[bytes,AudioChunk], None]:
yield (mock_audio_bytes, AudioChunk(np.ndarray([],np.int16)))
async def mock_stream(*args, **kwargs) -> AsyncGenerator[AudioChunk, None]:
yield AudioChunk(np.ndarray([],np.int16),output=mock_audio_bytes)
service.generate_audio_stream = mock_stream
service.list_voices.return_value = ["test_voice", "voice1", "voice2"]
@ -257,8 +257,8 @@ def test_openai_speech_endpoint(
):
"""Test the OpenAI-compatible speech endpoint with basic MP3 generation"""
# Configure mocks
mock_tts_service.generate_audio.return_value = (np.zeros(1000), AudioChunk(np.zeros(1000,np.int16)))
mock_convert.return_value = (mock_audio_bytes,AudioChunk(np.zeros(1000,np.int16)))
mock_tts_service.generate_audio.return_value = AudioChunk(np.zeros(1000,np.int16))
mock_convert.return_value = AudioChunk(np.zeros(1000,np.int16),output=mock_audio_bytes)
response = client.post(
"/v1/audio/speech",