From 5067a47ff9d510cbb4efc91f470914d7850b49dd Mon Sep 17 00:00:00 2001 From: Cong Nguyen <5560025+rampadc@users.noreply.github.com> Date: Tue, 1 Apr 2025 18:05:37 +1100 Subject: [PATCH] wip: fix mp3 --- api/src/services/streaming_audio_writer.py | 37 ++++++--- api/src/services/tts_service.py | 96 ++++++++++++++++------ 2 files changed, 97 insertions(+), 36 deletions(-) diff --git a/api/src/services/streaming_audio_writer.py b/api/src/services/streaming_audio_writer.py index 75d87b4..c605955 100644 --- a/api/src/services/streaming_audio_writer.py +++ b/api/src/services/streaming_audio_writer.py @@ -50,11 +50,15 @@ class StreamingAudioWriter: if finalize: if self.format != "pcm": - packets = self.stream.encode(None) - for packet in packets: - self.container.mux(packet) - - data=self.output_buffer.getvalue() + # Properly finalize the container + if hasattr(self, 'stream') and self.stream: + packets = self.stream.encode(None) + for packet in packets: + self.container.mux(packet) + # Important: Call close() on the container + self.container.close() + + data = self.output_buffer.getvalue() if hasattr(self, 'output_buffer') else b'' self.close() return data @@ -64,20 +68,33 @@ class StreamingAudioWriter: if self.format == "pcm": # Write raw bytes return audio_data.tobytes() + elif self.format == "mp3": + # For MP3, we need to handle streaming differently + frame = av.AudioFrame.from_ndarray(audio_data.reshape(1, -1), format='s16', layout='mono' if self.channels == 1 else 'stereo') + frame.sample_rate = self.sample_rate + frame.pts = self.pts + self.pts += frame.samples + + packets = self.stream.encode(frame) + for packet in packets: + self.container.mux(packet) + + # For MP3, just return an empty byte string to indicate processing occurred + # The complete MP3 file will be returned during finalization + return b'' else: frame = av.AudioFrame.from_ndarray(audio_data.reshape(1, -1), format='s16', layout='mono' if self.channels == 1 else 'stereo') frame.sample_rate=self.sample_rate - + frame.pts = self.pts self.pts += frame.samples - + packets = self.stream.encode(frame) for packet in packets: self.container.mux(packet) - + data = self.output_buffer.getvalue() self.output_buffer.seek(0) self.output_buffer.truncate(0) - return data - + return data diff --git a/api/src/services/tts_service.py b/api/src/services/tts_service.py index 8a6bb42..46e95ae 100644 --- a/api/src/services/tts_service.py +++ b/api/src/services/tts_service.py @@ -245,6 +245,12 @@ class TTSService: stream_normalizer = AudioNormalizer() chunk_index = 0 current_offset = 0.0 + + # For MP3 format, we'll collect all audio chunks and encode at the end + is_mp3_format = output_format.lower() == "mp3" + all_audio_chunks = [] + all_word_timestamps = [] + try: # Get backend backend = self.model_manager.get_backend() @@ -268,13 +274,15 @@ class TTSService: voice_path, # Pass voice path speed, writer, - output_format, + # For MP3, we don't pass the output format during processing + None if is_mp3_format else output_format, is_first=(chunk_index == 0), is_last=False, # We'll update the last chunk later normalizer=stream_normalizer, lang_code=pipeline_lang_code, # Pass lang_code return_timestamps=return_timestamps, ): + # Update timestamps with current offset if chunk_data.word_timestamps is not None: for timestamp in chunk_data.word_timestamps: timestamp.start_time += current_offset @@ -282,42 +290,78 @@ class TTSService: current_offset += len(chunk_data.audio) / 24000 - if chunk_data.output is not None: - yield chunk_data - + if is_mp3_format: + # Collect audio for MP3 final encoding + if len(chunk_data.audio) > 0: + all_audio_chunks.append(chunk_data.audio) + if chunk_data.word_timestamps: + all_word_timestamps.extend(chunk_data.word_timestamps) else: - logger.warning(f"No audio generated for chunk: '{chunk_text[:100]}...'") + # For other formats, stream normally + if chunk_data.output is not None: + yield chunk_data + else: + logger.warning(f"No audio generated for chunk: '{chunk_text[:100]}...'") + chunk_index += 1 + except Exception as e: logger.error(f"Failed to process audio for chunk: '{chunk_text[:100]}...'. Error: {str(e)}") continue # Only finalize if we successfully processed at least one chunk if chunk_index > 0: - try: - # Empty tokens list to finalize audio - async for chunk_data in self._process_chunk( - "", # Empty text - [], # Empty tokens - voice_name, - voice_path, - speed, - writer, - output_format, - is_first=False, - is_last=True, # Signal this is the last chunk - normalizer=stream_normalizer, - lang_code=pipeline_lang_code, # Pass lang_code - ): - if chunk_data.output is not None: - yield chunk_data - except Exception as e: - logger.error(f"Failed to finalize audio stream: {str(e)}") + if is_mp3_format: + # For MP3, create a combined audio file + if all_audio_chunks: + try: + # Combine all audio chunks + combined_audio = np.concatenate(all_audio_chunks) + + # Create a fresh MP3 writer for final encoding + mp3_writer = StreamingAudioWriter("mp3", sample_rate=settings.sample_rate) + + # Write all audio at once and finalize + mp3_writer.write_chunk(combined_audio) + mp3_data = mp3_writer.write_chunk(finalize=True) + mp3_writer.close() + + # Create the final chunk with all audio and timestamps + result_chunk = AudioChunk( + audio=combined_audio, + word_timestamps=all_word_timestamps, + output=mp3_data + ) + yield result_chunk + + except Exception as e: + logger.error(f"Failed to create final MP3 file: {str(e)}") + raise RuntimeError(f"MP3 encoding failed: {str(e)}") + else: + # For other formats, just finalize the stream + try: + # Empty tokens list to finalize audio + async for chunk_data in self._process_chunk( + "", # Empty text + [], # Empty tokens + voice_name, + voice_path, + speed, + writer, + output_format, + is_first=False, + is_last=True, # Signal this is the last chunk + normalizer=stream_normalizer, + lang_code=pipeline_lang_code, # Pass lang_code + ): + if chunk_data.output is not None: + yield chunk_data + except Exception as e: + logger.error(f"Failed to finalize audio stream: {str(e)}") except Exception as e: - logger.error(f"Error in phoneme audio generation: {str(e)}") + logger.error(f"Error in audio generation: {str(e)}") raise e - async def generate_audio( self, text: str,