mirror of
https://github.com/remsky/Kokoro-FastAPI.git
synced 2025-04-13 09:39:17 +00:00
wip: fix mp3
This commit is contained in:
parent
64ced408b7
commit
5067a47ff9
2 changed files with 97 additions and 36 deletions
|
@ -50,11 +50,15 @@ class StreamingAudioWriter:
|
|||
|
||||
if finalize:
|
||||
if self.format != "pcm":
|
||||
packets = self.stream.encode(None)
|
||||
for packet in packets:
|
||||
self.container.mux(packet)
|
||||
# Properly finalize the container
|
||||
if hasattr(self, 'stream') and self.stream:
|
||||
packets = self.stream.encode(None)
|
||||
for packet in packets:
|
||||
self.container.mux(packet)
|
||||
# Important: Call close() on the container
|
||||
self.container.close()
|
||||
|
||||
data=self.output_buffer.getvalue()
|
||||
data = self.output_buffer.getvalue() if hasattr(self, 'output_buffer') else b''
|
||||
self.close()
|
||||
return data
|
||||
|
||||
|
@ -64,6 +68,20 @@ class StreamingAudioWriter:
|
|||
if self.format == "pcm":
|
||||
# Write raw bytes
|
||||
return audio_data.tobytes()
|
||||
elif self.format == "mp3":
|
||||
# For MP3, we need to handle streaming differently
|
||||
frame = av.AudioFrame.from_ndarray(audio_data.reshape(1, -1), format='s16', layout='mono' if self.channels == 1 else 'stereo')
|
||||
frame.sample_rate = self.sample_rate
|
||||
frame.pts = self.pts
|
||||
self.pts += frame.samples
|
||||
|
||||
packets = self.stream.encode(frame)
|
||||
for packet in packets:
|
||||
self.container.mux(packet)
|
||||
|
||||
# For MP3, just return an empty byte string to indicate processing occurred
|
||||
# The complete MP3 file will be returned during finalization
|
||||
return b''
|
||||
else:
|
||||
frame = av.AudioFrame.from_ndarray(audio_data.reshape(1, -1), format='s16', layout='mono' if self.channels == 1 else 'stereo')
|
||||
frame.sample_rate=self.sample_rate
|
||||
|
@ -80,4 +98,3 @@ class StreamingAudioWriter:
|
|||
self.output_buffer.seek(0)
|
||||
self.output_buffer.truncate(0)
|
||||
return data
|
||||
|
||||
|
|
|
@ -245,6 +245,12 @@ class TTSService:
|
|||
stream_normalizer = AudioNormalizer()
|
||||
chunk_index = 0
|
||||
current_offset = 0.0
|
||||
|
||||
# For MP3 format, we'll collect all audio chunks and encode at the end
|
||||
is_mp3_format = output_format.lower() == "mp3"
|
||||
all_audio_chunks = []
|
||||
all_word_timestamps = []
|
||||
|
||||
try:
|
||||
# Get backend
|
||||
backend = self.model_manager.get_backend()
|
||||
|
@ -268,13 +274,15 @@ class TTSService:
|
|||
voice_path, # Pass voice path
|
||||
speed,
|
||||
writer,
|
||||
output_format,
|
||||
# For MP3, we don't pass the output format during processing
|
||||
None if is_mp3_format else output_format,
|
||||
is_first=(chunk_index == 0),
|
||||
is_last=False, # We'll update the last chunk later
|
||||
normalizer=stream_normalizer,
|
||||
lang_code=pipeline_lang_code, # Pass lang_code
|
||||
return_timestamps=return_timestamps,
|
||||
):
|
||||
# Update timestamps with current offset
|
||||
if chunk_data.word_timestamps is not None:
|
||||
for timestamp in chunk_data.word_timestamps:
|
||||
timestamp.start_time += current_offset
|
||||
|
@ -282,42 +290,78 @@ class TTSService:
|
|||
|
||||
current_offset += len(chunk_data.audio) / 24000
|
||||
|
||||
if chunk_data.output is not None:
|
||||
yield chunk_data
|
||||
|
||||
if is_mp3_format:
|
||||
# Collect audio for MP3 final encoding
|
||||
if len(chunk_data.audio) > 0:
|
||||
all_audio_chunks.append(chunk_data.audio)
|
||||
if chunk_data.word_timestamps:
|
||||
all_word_timestamps.extend(chunk_data.word_timestamps)
|
||||
else:
|
||||
logger.warning(f"No audio generated for chunk: '{chunk_text[:100]}...'")
|
||||
# For other formats, stream normally
|
||||
if chunk_data.output is not None:
|
||||
yield chunk_data
|
||||
else:
|
||||
logger.warning(f"No audio generated for chunk: '{chunk_text[:100]}...'")
|
||||
|
||||
chunk_index += 1
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to process audio for chunk: '{chunk_text[:100]}...'. Error: {str(e)}")
|
||||
continue
|
||||
|
||||
# Only finalize if we successfully processed at least one chunk
|
||||
if chunk_index > 0:
|
||||
try:
|
||||
# Empty tokens list to finalize audio
|
||||
async for chunk_data in self._process_chunk(
|
||||
"", # Empty text
|
||||
[], # Empty tokens
|
||||
voice_name,
|
||||
voice_path,
|
||||
speed,
|
||||
writer,
|
||||
output_format,
|
||||
is_first=False,
|
||||
is_last=True, # Signal this is the last chunk
|
||||
normalizer=stream_normalizer,
|
||||
lang_code=pipeline_lang_code, # Pass lang_code
|
||||
):
|
||||
if chunk_data.output is not None:
|
||||
yield chunk_data
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to finalize audio stream: {str(e)}")
|
||||
if is_mp3_format:
|
||||
# For MP3, create a combined audio file
|
||||
if all_audio_chunks:
|
||||
try:
|
||||
# Combine all audio chunks
|
||||
combined_audio = np.concatenate(all_audio_chunks)
|
||||
|
||||
# Create a fresh MP3 writer for final encoding
|
||||
mp3_writer = StreamingAudioWriter("mp3", sample_rate=settings.sample_rate)
|
||||
|
||||
# Write all audio at once and finalize
|
||||
mp3_writer.write_chunk(combined_audio)
|
||||
mp3_data = mp3_writer.write_chunk(finalize=True)
|
||||
mp3_writer.close()
|
||||
|
||||
# Create the final chunk with all audio and timestamps
|
||||
result_chunk = AudioChunk(
|
||||
audio=combined_audio,
|
||||
word_timestamps=all_word_timestamps,
|
||||
output=mp3_data
|
||||
)
|
||||
yield result_chunk
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create final MP3 file: {str(e)}")
|
||||
raise RuntimeError(f"MP3 encoding failed: {str(e)}")
|
||||
else:
|
||||
# For other formats, just finalize the stream
|
||||
try:
|
||||
# Empty tokens list to finalize audio
|
||||
async for chunk_data in self._process_chunk(
|
||||
"", # Empty text
|
||||
[], # Empty tokens
|
||||
voice_name,
|
||||
voice_path,
|
||||
speed,
|
||||
writer,
|
||||
output_format,
|
||||
is_first=False,
|
||||
is_last=True, # Signal this is the last chunk
|
||||
normalizer=stream_normalizer,
|
||||
lang_code=pipeline_lang_code, # Pass lang_code
|
||||
):
|
||||
if chunk_data.output is not None:
|
||||
yield chunk_data
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to finalize audio stream: {str(e)}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in phoneme audio generation: {str(e)}")
|
||||
logger.error(f"Error in audio generation: {str(e)}")
|
||||
raise e
|
||||
|
||||
async def generate_audio(
|
||||
self,
|
||||
text: str,
|
||||
|
|
Loading…
Add table
Reference in a new issue