mirror of
https://github.com/remsky/Kokoro-FastAPI.git
synced 2025-08-05 16:48:53 +00:00
wip: fix mp3
This commit is contained in:
parent
64ced408b7
commit
5067a47ff9
2 changed files with 97 additions and 36 deletions
|
@ -50,11 +50,15 @@ class StreamingAudioWriter:
|
||||||
|
|
||||||
if finalize:
|
if finalize:
|
||||||
if self.format != "pcm":
|
if self.format != "pcm":
|
||||||
|
# Properly finalize the container
|
||||||
|
if hasattr(self, 'stream') and self.stream:
|
||||||
packets = self.stream.encode(None)
|
packets = self.stream.encode(None)
|
||||||
for packet in packets:
|
for packet in packets:
|
||||||
self.container.mux(packet)
|
self.container.mux(packet)
|
||||||
|
# Important: Call close() on the container
|
||||||
|
self.container.close()
|
||||||
|
|
||||||
data=self.output_buffer.getvalue()
|
data = self.output_buffer.getvalue() if hasattr(self, 'output_buffer') else b''
|
||||||
self.close()
|
self.close()
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
@ -64,6 +68,20 @@ class StreamingAudioWriter:
|
||||||
if self.format == "pcm":
|
if self.format == "pcm":
|
||||||
# Write raw bytes
|
# Write raw bytes
|
||||||
return audio_data.tobytes()
|
return audio_data.tobytes()
|
||||||
|
elif self.format == "mp3":
|
||||||
|
# For MP3, we need to handle streaming differently
|
||||||
|
frame = av.AudioFrame.from_ndarray(audio_data.reshape(1, -1), format='s16', layout='mono' if self.channels == 1 else 'stereo')
|
||||||
|
frame.sample_rate = self.sample_rate
|
||||||
|
frame.pts = self.pts
|
||||||
|
self.pts += frame.samples
|
||||||
|
|
||||||
|
packets = self.stream.encode(frame)
|
||||||
|
for packet in packets:
|
||||||
|
self.container.mux(packet)
|
||||||
|
|
||||||
|
# For MP3, just return an empty byte string to indicate processing occurred
|
||||||
|
# The complete MP3 file will be returned during finalization
|
||||||
|
return b''
|
||||||
else:
|
else:
|
||||||
frame = av.AudioFrame.from_ndarray(audio_data.reshape(1, -1), format='s16', layout='mono' if self.channels == 1 else 'stereo')
|
frame = av.AudioFrame.from_ndarray(audio_data.reshape(1, -1), format='s16', layout='mono' if self.channels == 1 else 'stereo')
|
||||||
frame.sample_rate=self.sample_rate
|
frame.sample_rate=self.sample_rate
|
||||||
|
@ -80,4 +98,3 @@ class StreamingAudioWriter:
|
||||||
self.output_buffer.seek(0)
|
self.output_buffer.seek(0)
|
||||||
self.output_buffer.truncate(0)
|
self.output_buffer.truncate(0)
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
|
|
@ -245,6 +245,12 @@ class TTSService:
|
||||||
stream_normalizer = AudioNormalizer()
|
stream_normalizer = AudioNormalizer()
|
||||||
chunk_index = 0
|
chunk_index = 0
|
||||||
current_offset = 0.0
|
current_offset = 0.0
|
||||||
|
|
||||||
|
# For MP3 format, we'll collect all audio chunks and encode at the end
|
||||||
|
is_mp3_format = output_format.lower() == "mp3"
|
||||||
|
all_audio_chunks = []
|
||||||
|
all_word_timestamps = []
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Get backend
|
# Get backend
|
||||||
backend = self.model_manager.get_backend()
|
backend = self.model_manager.get_backend()
|
||||||
|
@ -268,13 +274,15 @@ class TTSService:
|
||||||
voice_path, # Pass voice path
|
voice_path, # Pass voice path
|
||||||
speed,
|
speed,
|
||||||
writer,
|
writer,
|
||||||
output_format,
|
# For MP3, we don't pass the output format during processing
|
||||||
|
None if is_mp3_format else output_format,
|
||||||
is_first=(chunk_index == 0),
|
is_first=(chunk_index == 0),
|
||||||
is_last=False, # We'll update the last chunk later
|
is_last=False, # We'll update the last chunk later
|
||||||
normalizer=stream_normalizer,
|
normalizer=stream_normalizer,
|
||||||
lang_code=pipeline_lang_code, # Pass lang_code
|
lang_code=pipeline_lang_code, # Pass lang_code
|
||||||
return_timestamps=return_timestamps,
|
return_timestamps=return_timestamps,
|
||||||
):
|
):
|
||||||
|
# Update timestamps with current offset
|
||||||
if chunk_data.word_timestamps is not None:
|
if chunk_data.word_timestamps is not None:
|
||||||
for timestamp in chunk_data.word_timestamps:
|
for timestamp in chunk_data.word_timestamps:
|
||||||
timestamp.start_time += current_offset
|
timestamp.start_time += current_offset
|
||||||
|
@ -282,18 +290,55 @@ class TTSService:
|
||||||
|
|
||||||
current_offset += len(chunk_data.audio) / 24000
|
current_offset += len(chunk_data.audio) / 24000
|
||||||
|
|
||||||
|
if is_mp3_format:
|
||||||
|
# Collect audio for MP3 final encoding
|
||||||
|
if len(chunk_data.audio) > 0:
|
||||||
|
all_audio_chunks.append(chunk_data.audio)
|
||||||
|
if chunk_data.word_timestamps:
|
||||||
|
all_word_timestamps.extend(chunk_data.word_timestamps)
|
||||||
|
else:
|
||||||
|
# For other formats, stream normally
|
||||||
if chunk_data.output is not None:
|
if chunk_data.output is not None:
|
||||||
yield chunk_data
|
yield chunk_data
|
||||||
|
|
||||||
else:
|
else:
|
||||||
logger.warning(f"No audio generated for chunk: '{chunk_text[:100]}...'")
|
logger.warning(f"No audio generated for chunk: '{chunk_text[:100]}...'")
|
||||||
|
|
||||||
chunk_index += 1
|
chunk_index += 1
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to process audio for chunk: '{chunk_text[:100]}...'. Error: {str(e)}")
|
logger.error(f"Failed to process audio for chunk: '{chunk_text[:100]}...'. Error: {str(e)}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Only finalize if we successfully processed at least one chunk
|
# Only finalize if we successfully processed at least one chunk
|
||||||
if chunk_index > 0:
|
if chunk_index > 0:
|
||||||
|
if is_mp3_format:
|
||||||
|
# For MP3, create a combined audio file
|
||||||
|
if all_audio_chunks:
|
||||||
|
try:
|
||||||
|
# Combine all audio chunks
|
||||||
|
combined_audio = np.concatenate(all_audio_chunks)
|
||||||
|
|
||||||
|
# Create a fresh MP3 writer for final encoding
|
||||||
|
mp3_writer = StreamingAudioWriter("mp3", sample_rate=settings.sample_rate)
|
||||||
|
|
||||||
|
# Write all audio at once and finalize
|
||||||
|
mp3_writer.write_chunk(combined_audio)
|
||||||
|
mp3_data = mp3_writer.write_chunk(finalize=True)
|
||||||
|
mp3_writer.close()
|
||||||
|
|
||||||
|
# Create the final chunk with all audio and timestamps
|
||||||
|
result_chunk = AudioChunk(
|
||||||
|
audio=combined_audio,
|
||||||
|
word_timestamps=all_word_timestamps,
|
||||||
|
output=mp3_data
|
||||||
|
)
|
||||||
|
yield result_chunk
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to create final MP3 file: {str(e)}")
|
||||||
|
raise RuntimeError(f"MP3 encoding failed: {str(e)}")
|
||||||
|
else:
|
||||||
|
# For other formats, just finalize the stream
|
||||||
try:
|
try:
|
||||||
# Empty tokens list to finalize audio
|
# Empty tokens list to finalize audio
|
||||||
async for chunk_data in self._process_chunk(
|
async for chunk_data in self._process_chunk(
|
||||||
|
@ -315,9 +360,8 @@ class TTSService:
|
||||||
logger.error(f"Failed to finalize audio stream: {str(e)}")
|
logger.error(f"Failed to finalize audio stream: {str(e)}")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error in phoneme audio generation: {str(e)}")
|
logger.error(f"Error in audio generation: {str(e)}")
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
async def generate_audio(
|
async def generate_audio(
|
||||||
self,
|
self,
|
||||||
text: str,
|
text: str,
|
||||||
|
|
Loading…
Add table
Reference in a new issue