wip: fix mp3

Cong Nguyen 2025-04-01 18:05:37 +11:00
parent 64ced408b7
commit 5067a47ff9
2 changed files with 97 additions and 36 deletions

@@ -50,11 +50,15 @@ class StreamingAudioWriter:
         if finalize:
             if self.format != "pcm":
-                packets = self.stream.encode(None)
-                for packet in packets:
-                    self.container.mux(packet)
-                data=self.output_buffer.getvalue()
+                # Properly finalize the container
+                if hasattr(self, 'stream') and self.stream:
+                    packets = self.stream.encode(None)
+                    for packet in packets:
+                        self.container.mux(packet)
+                # Important: Call close() on the container
+                self.container.close()
+                data = self.output_buffer.getvalue() if hasattr(self, 'output_buffer') else b''
                 self.close()
                 return data
@@ -64,20 +68,33 @@
         if self.format == "pcm":
             # Write raw bytes
             return audio_data.tobytes()
+        elif self.format == "mp3":
+            # For MP3, we need to handle streaming differently
+            frame = av.AudioFrame.from_ndarray(audio_data.reshape(1, -1), format='s16', layout='mono' if self.channels == 1 else 'stereo')
+            frame.sample_rate = self.sample_rate
+            frame.pts = self.pts
+            self.pts += frame.samples
+            packets = self.stream.encode(frame)
+            for packet in packets:
+                self.container.mux(packet)
+            # For MP3, just return an empty byte string to indicate processing occurred
+            # The complete MP3 file will be returned during finalization
+            return b''
         else:
             frame = av.AudioFrame.from_ndarray(audio_data.reshape(1, -1), format='s16', layout='mono' if self.channels == 1 else 'stereo')
             frame.sample_rate=self.sample_rate
             frame.pts = self.pts
             self.pts += frame.samples
             packets = self.stream.encode(frame)
             for packet in packets:
                 self.container.mux(packet)
             data = self.output_buffer.getvalue()
             self.output_buffer.seek(0)
             self.output_buffer.truncate(0)
             return data
         return data
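
A minimal sketch, not part of this commit, of how the reworked write_chunk is meant to be driven for MP3 output: per-chunk calls feed the encoder and return b'', and the complete MP3 file only comes back from the finalize=True call. The import path and the 24000 Hz sample rate are assumptions for illustration (24000 matches the divisor used in the TTSService changes below).

import numpy as np

from streaming_audio_writer import StreamingAudioWriter  # import path assumed

# Stand-in for model output: three 0.1 s chunks of int16 silence at 24000 Hz.
audio_chunks = [np.zeros(2400, dtype=np.int16) for _ in range(3)]

writer = StreamingAudioWriter("mp3", sample_rate=24000)
for chunk in audio_chunks:
    partial = writer.write_chunk(chunk)        # returns b'' for mp3 under this change
    assert partial == b''
mp3_bytes = writer.write_chunk(finalize=True)  # complete MP3 file, produced once at the end
writer.close()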

@@ -245,6 +245,12 @@ class TTSService:
         stream_normalizer = AudioNormalizer()
         chunk_index = 0
         current_offset = 0.0
+        # For MP3 format, we'll collect all audio chunks and encode at the end
+        is_mp3_format = output_format.lower() == "mp3"
+        all_audio_chunks = []
+        all_word_timestamps = []
         try:
             # Get backend
             backend = self.model_manager.get_backend()
@@ -268,13 +274,15 @@
                         voice_path,  # Pass voice path
                         speed,
                         writer,
-                        output_format,
+                        # For MP3, we don't pass the output format during processing
+                        None if is_mp3_format else output_format,
                         is_first=(chunk_index == 0),
                         is_last=False,  # We'll update the last chunk later
                         normalizer=stream_normalizer,
                         lang_code=pipeline_lang_code,  # Pass lang_code
                         return_timestamps=return_timestamps,
                     ):
+                        # Update timestamps with current offset
                         if chunk_data.word_timestamps is not None:
                             for timestamp in chunk_data.word_timestamps:
                                 timestamp.start_time += current_offset
@@ -282,42 +290,78 @@
                         current_offset += len(chunk_data.audio) / 24000
-                        if chunk_data.output is not None:
-                            yield chunk_data
+                        if is_mp3_format:
+                            # Collect audio for MP3 final encoding
+                            if len(chunk_data.audio) > 0:
+                                all_audio_chunks.append(chunk_data.audio)
+                                if chunk_data.word_timestamps:
+                                    all_word_timestamps.extend(chunk_data.word_timestamps)
+                            else:
+                                logger.warning(f"No audio generated for chunk: '{chunk_text[:100]}...'")
+                        else:
+                            # For other formats, stream normally
+                            if chunk_data.output is not None:
+                                yield chunk_data
+                            else:
+                                logger.warning(f"No audio generated for chunk: '{chunk_text[:100]}...'")
                     chunk_index += 1
                 except Exception as e:
                     logger.error(f"Failed to process audio for chunk: '{chunk_text[:100]}...'. Error: {str(e)}")
                     continue
 
             # Only finalize if we successfully processed at least one chunk
             if chunk_index > 0:
-                try:
-                    # Empty tokens list to finalize audio
-                    async for chunk_data in self._process_chunk(
-                        "",  # Empty text
-                        [],  # Empty tokens
-                        voice_name,
-                        voice_path,
-                        speed,
-                        writer,
-                        output_format,
-                        is_first=False,
-                        is_last=True,  # Signal this is the last chunk
-                        normalizer=stream_normalizer,
-                        lang_code=pipeline_lang_code,  # Pass lang_code
-                    ):
-                        if chunk_data.output is not None:
-                            yield chunk_data
-                except Exception as e:
-                    logger.error(f"Failed to finalize audio stream: {str(e)}")
+                if is_mp3_format:
+                    # For MP3, create a combined audio file
+                    if all_audio_chunks:
+                        try:
+                            # Combine all audio chunks
+                            combined_audio = np.concatenate(all_audio_chunks)
+                            # Create a fresh MP3 writer for final encoding
+                            mp3_writer = StreamingAudioWriter("mp3", sample_rate=settings.sample_rate)
+                            # Write all audio at once and finalize
+                            mp3_writer.write_chunk(combined_audio)
+                            mp3_data = mp3_writer.write_chunk(finalize=True)
+                            mp3_writer.close()
+                            # Create the final chunk with all audio and timestamps
+                            result_chunk = AudioChunk(
+                                audio=combined_audio,
+                                word_timestamps=all_word_timestamps,
+                                output=mp3_data
+                            )
+                            yield result_chunk
+                        except Exception as e:
+                            logger.error(f"Failed to create final MP3 file: {str(e)}")
+                            raise RuntimeError(f"MP3 encoding failed: {str(e)}")
+                else:
+                    # For other formats, just finalize the stream
+                    try:
+                        # Empty tokens list to finalize audio
+                        async for chunk_data in self._process_chunk(
+                            "",  # Empty text
+                            [],  # Empty tokens
+                            voice_name,
+                            voice_path,
+                            speed,
+                            writer,
+                            output_format,
+                            is_first=False,
+                            is_last=True,  # Signal this is the last chunk
+                            normalizer=stream_normalizer,
+                            lang_code=pipeline_lang_code,  # Pass lang_code
+                        ):
+                            if chunk_data.output is not None:
+                                yield chunk_data
+                    except Exception as e:
+                        logger.error(f"Failed to finalize audio stream: {str(e)}")
 
         except Exception as e:
-            logger.error(f"Error in phoneme audio generation: {str(e)}")
+            logger.error(f"Error in audio generation: {str(e)}")
             raise e
 
     async def generate_audio(
         self,
         text: str,
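
For context, a rough consumer-side sketch of what this change implies for callers of generate_audio_stream with MP3 output: intermediate chunks are collected inside the service rather than streamed, and a single final AudioChunk carries the complete encoded file together with the accumulated word timestamps. Names beyond those visible in the diff (generate_audio_stream, output_format, return_timestamps, writer, chunk.output) are assumptions for illustration.

# Hypothetical caller; parameter names other than output_format,
# return_timestamps and writer are assumptions, not the service's actual API.
async def collect_mp3(service, writer, text: str, voice: str) -> bytes:
    mp3_bytes = b""
    async for chunk in service.generate_audio_stream(
        text=text,
        voice=voice,
        writer=writer,
        output_format="mp3",
        return_timestamps=True,
    ):
        # Under this change only the final combined chunk is expected to carry
        # encoded output for mp3; earlier chunks are buffered inside the service.
        if chunk.output:
            mp3_bytes = chunk.output
    return mp3_bytes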