mirror of
https://github.com/remsky/Kokoro-FastAPI.git
synced 2025-04-13 09:39:17 +00:00
Simplify code so everything uses AudioChunks
This commit is contained in:
parent
9c0e328318
commit
e3dc959775
7 changed files with 103 additions and 97 deletions
|
@ -11,10 +11,12 @@ class AudioChunk:
|
|||
|
||||
def __init__(self,
|
||||
audio: np.ndarray,
|
||||
word_timestamps: Optional[List]=[]
|
||||
word_timestamps: Optional[List]=[],
|
||||
output: Optional[Union[bytes,np.ndarray]]=b""
|
||||
):
|
||||
self.audio=audio
|
||||
self.word_timestamps=word_timestamps
|
||||
self.output=output
|
||||
|
||||
@staticmethod
|
||||
def combine(audio_chunk_list: List):
|
||||
|
|
|
@ -209,10 +209,10 @@ async def create_captioned_speech(
|
|||
async def dual_output():
|
||||
try:
|
||||
# Write chunks to temp file and stream
|
||||
async for chunk,chunk_data in generator:
|
||||
if chunk: # Skip empty chunks
|
||||
await temp_writer.write(chunk)
|
||||
base64_chunk= base64.b64encode(chunk).decode("utf-8")
|
||||
async for chunk_data in generator:
|
||||
if chunk_data.output: # Skip empty chunks
|
||||
await temp_writer.write(chunk_data.output)
|
||||
base64_chunk= base64.b64encode(chunk_data.output).decode("utf-8")
|
||||
|
||||
yield CaptionedSpeechResponse(audio=base64_chunk,audio_format=content_type,timestamps=chunk_data.word_timestamps)
|
||||
|
||||
|
@ -235,10 +235,10 @@ async def create_captioned_speech(
|
|||
async def single_output():
|
||||
try:
|
||||
# Stream chunks
|
||||
async for chunk,chunk_data in generator:
|
||||
if chunk: # Skip empty chunks
|
||||
async for chunk_data in generator:
|
||||
if chunk_data.output: # Skip empty chunks
|
||||
# Encode the chunk bytes into base 64
|
||||
base64_chunk= base64.b64encode(chunk).decode("utf-8")
|
||||
base64_chunk= base64.b64encode(chunk_data.output).decode("utf-8")
|
||||
|
||||
yield CaptionedSpeechResponse(audio=base64_chunk,audio_format=content_type,timestamps=chunk_data.word_timestamps)
|
||||
except Exception as e:
|
||||
|
@ -258,7 +258,7 @@ async def create_captioned_speech(
|
|||
)
|
||||
else:
|
||||
# Generate complete audio using public interface
|
||||
_, audio_data = await tts_service.generate_audio(
|
||||
audio_data = await tts_service.generate_audio(
|
||||
text=request.input,
|
||||
voice=voice_name,
|
||||
speed=request.speed,
|
||||
|
@ -267,7 +267,7 @@ async def create_captioned_speech(
|
|||
lang_code=request.lang_code,
|
||||
)
|
||||
|
||||
content, audio_data = await AudioService.convert_audio(
|
||||
audio_data = await AudioService.convert_audio(
|
||||
audio_data,
|
||||
24000,
|
||||
request.response_format,
|
||||
|
@ -276,14 +276,14 @@ async def create_captioned_speech(
|
|||
)
|
||||
|
||||
# Convert to requested format with proper finalization
|
||||
final, _ = await AudioService.convert_audio(
|
||||
final = await AudioService.convert_audio(
|
||||
AudioChunk(np.array([], dtype=np.int16)),
|
||||
24000,
|
||||
request.response_format,
|
||||
is_first_chunk=False,
|
||||
is_last_chunk=True,
|
||||
)
|
||||
output=content+final
|
||||
output=content.output + final.output
|
||||
|
||||
base64_output= base64.b64encode(output).decode("utf-8")
|
||||
|
||||
|
|
|
@ -132,7 +132,7 @@ async def process_voices(
|
|||
|
||||
async def stream_audio_chunks(
|
||||
tts_service: TTSService, request: Union[OpenAISpeechRequest,CaptionedSpeechRequest], client_request: Request
|
||||
) -> AsyncGenerator[Tuple[Union[np.ndarray,bytes],AudioChunk], None]:
|
||||
) -> AsyncGenerator[AudioChunk, None]:
|
||||
"""Stream audio chunks as they're generated with client disconnect handling"""
|
||||
voice_name = await process_voices(request.voice, tts_service)
|
||||
|
||||
|
@ -144,7 +144,7 @@ async def stream_audio_chunks(
|
|||
|
||||
try:
|
||||
logger.info(f"Starting audio generation with lang_code: {request.lang_code}")
|
||||
async for chunk, chunk_data in tts_service.generate_audio_stream(
|
||||
async for chunk_data in tts_service.generate_audio_stream(
|
||||
text=request.input,
|
||||
voice=voice_name,
|
||||
speed=request.speed,
|
||||
|
@ -162,7 +162,7 @@ async def stream_audio_chunks(
|
|||
logger.info("Client disconnected, stopping audio generation")
|
||||
break
|
||||
|
||||
yield (chunk,chunk_data)
|
||||
yield chunk_data
|
||||
except Exception as e:
|
||||
logger.error(f"Error in audio streaming: {str(e)}")
|
||||
# Let the exception propagate to trigger cleanup
|
||||
|
@ -233,13 +233,13 @@ async def create_speech(
|
|||
async def dual_output():
|
||||
try:
|
||||
# Write chunks to temp file and stream
|
||||
async for chunk,chunk_data in generator:
|
||||
if chunk: # Skip empty chunks
|
||||
await temp_writer.write(chunk)
|
||||
async for chunk_data in generator:
|
||||
if chunk_data.output: # Skip empty chunks
|
||||
await temp_writer.write(chunk_data.output)
|
||||
#if return_json:
|
||||
# yield chunk, chunk_data
|
||||
#else:
|
||||
yield chunk
|
||||
yield chunk_data.output
|
||||
|
||||
# Finalize the temp file
|
||||
await temp_writer.finalize()
|
||||
|
@ -260,9 +260,9 @@ async def create_speech(
|
|||
async def single_output():
|
||||
try:
|
||||
# Stream chunks
|
||||
async for chunk,chunk_data in generator:
|
||||
if chunk: # Skip empty chunks
|
||||
yield chunk
|
||||
async for chunk_data in generator:
|
||||
if chunk_data.output: # Skip empty chunks
|
||||
yield chunk_data.output
|
||||
except Exception as e:
|
||||
logger.error(f"Error in single output streaming: {e}")
|
||||
raise
|
||||
|
@ -280,7 +280,7 @@ async def create_speech(
|
|||
)
|
||||
else:
|
||||
# Generate complete audio using public interface
|
||||
_, audio_data = await tts_service.generate_audio(
|
||||
audio_data = await tts_service.generate_audio(
|
||||
text=request.input,
|
||||
voice=voice_name,
|
||||
speed=request.speed,
|
||||
|
@ -288,7 +288,7 @@ async def create_speech(
|
|||
lang_code=request.lang_code,
|
||||
)
|
||||
|
||||
content, audio_data = await AudioService.convert_audio(
|
||||
audio_data = await AudioService.convert_audio(
|
||||
audio_data,
|
||||
24000,
|
||||
request.response_format,
|
||||
|
@ -297,14 +297,14 @@ async def create_speech(
|
|||
)
|
||||
|
||||
# Convert to requested format with proper finalization
|
||||
final, _ = await AudioService.convert_audio(
|
||||
final = await AudioService.convert_audio(
|
||||
AudioChunk(np.array([], dtype=np.int16)),
|
||||
24000,
|
||||
request.response_format,
|
||||
is_first_chunk=False,
|
||||
is_last_chunk=True,
|
||||
)
|
||||
output=content+final
|
||||
output=audio_data.output + final.output
|
||||
return Response(
|
||||
content=output,
|
||||
media_type=content_type,
|
||||
|
|
|
@ -120,7 +120,7 @@ class AudioService:
|
|||
is_first_chunk: bool = True,
|
||||
is_last_chunk: bool = False,
|
||||
normalizer: AudioNormalizer = None,
|
||||
) -> Tuple[bytes,AudioChunk]:
|
||||
) -> Tuple[AudioChunk]:
|
||||
"""Convert audio data to specified format with streaming support
|
||||
|
||||
Args:
|
||||
|
@ -165,9 +165,13 @@ class AudioService:
|
|||
if is_last_chunk:
|
||||
final_data = writer.write_chunk(finalize=True)
|
||||
del AudioService._writers[writer_key]
|
||||
return final_data if final_data else b"", audio_chunk
|
||||
|
||||
return chunk_data if chunk_data else b"", audio_chunk
|
||||
if final_data:
|
||||
audio_chunk.output=final_data
|
||||
return audio_chunk
|
||||
|
||||
if chunk_data:
|
||||
audio_chunk.output=chunk_data
|
||||
return audio_chunk
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error converting audio stream to {output_format}: {str(e)}")
|
||||
|
|
|
@ -54,7 +54,7 @@ class TTSService:
|
|||
normalizer: Optional[AudioNormalizer] = None,
|
||||
lang_code: Optional[str] = None,
|
||||
return_timestamps: Optional[bool] = False,
|
||||
) -> AsyncGenerator[Tuple[Union[np.ndarray, bytes],AudioChunk], Tuple[None,None]]:
|
||||
) -> AsyncGenerator[AudioChunk, None]:
|
||||
"""Process tokens into audio."""
|
||||
async with self._chunk_semaphore:
|
||||
try:
|
||||
|
@ -62,9 +62,9 @@ class TTSService:
|
|||
if is_last:
|
||||
# Skip format conversion for raw audio mode
|
||||
if not output_format:
|
||||
yield np.array([], dtype=np.int16), AudioChunk(np.array([], dtype=np.int16))
|
||||
yield AudioChunk(np.array([], dtype=np.int16),output=b'')
|
||||
return
|
||||
result, chunk_data = await AudioService.convert_audio(
|
||||
chunk_data = await AudioService.convert_audio(
|
||||
AudioChunk(np.array([0], dtype=np.float32)), # Dummy data for type checking
|
||||
24000,
|
||||
output_format,
|
||||
|
@ -74,7 +74,7 @@ class TTSService:
|
|||
normalizer=normalizer,
|
||||
is_last_chunk=True,
|
||||
)
|
||||
yield result, chunk_data
|
||||
yield chunk_data
|
||||
return
|
||||
|
||||
# Skip empty chunks
|
||||
|
@ -97,7 +97,7 @@ class TTSService:
|
|||
# For streaming, convert to bytes
|
||||
if output_format:
|
||||
try:
|
||||
converted, chunk_data = await AudioService.convert_audio(
|
||||
chunk_data = await AudioService.convert_audio(
|
||||
chunk_data,
|
||||
24000,
|
||||
output_format,
|
||||
|
@ -107,7 +107,7 @@ class TTSService:
|
|||
is_last_chunk=is_last,
|
||||
normalizer=normalizer,
|
||||
)
|
||||
yield converted, chunk_data
|
||||
yield chunk_data
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to convert audio: {str(e)}")
|
||||
else:
|
||||
|
@ -116,7 +116,7 @@ class TTSService:
|
|||
speed,
|
||||
is_last,
|
||||
normalizer)
|
||||
yield chunk_data.audio, chunk_data
|
||||
yield chunk_data
|
||||
else:
|
||||
|
||||
# For legacy backends, load voice tensor
|
||||
|
@ -138,7 +138,7 @@ class TTSService:
|
|||
# For streaming, convert to bytes
|
||||
if output_format:
|
||||
try:
|
||||
converted, chunk_data = await AudioService.convert_audio(
|
||||
chunk_data = await AudioService.convert_audio(
|
||||
chunk_data,
|
||||
24000,
|
||||
output_format,
|
||||
|
@ -148,7 +148,7 @@ class TTSService:
|
|||
normalizer=normalizer,
|
||||
is_last_chunk=is_last,
|
||||
)
|
||||
yield converted, chunk_data
|
||||
yield chunk_data
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to convert audio: {str(e)}")
|
||||
else:
|
||||
|
@ -157,7 +157,7 @@ class TTSService:
|
|||
speed,
|
||||
is_last,
|
||||
normalizer)
|
||||
yield trimmed.audio, trimmed
|
||||
yield trimmed
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to process tokens: {str(e)}")
|
||||
|
||||
|
@ -243,7 +243,7 @@ class TTSService:
|
|||
lang_code: Optional[str] = None,
|
||||
normalization_options: Optional[NormalizationOptions] = NormalizationOptions(),
|
||||
return_timestamps: Optional[bool] = False,
|
||||
) -> AsyncGenerator[Tuple[bytes,AudioChunk], None]:
|
||||
) -> AsyncGenerator[AudioChunk, None]:
|
||||
"""Generate and stream audio chunks."""
|
||||
stream_normalizer = AudioNormalizer()
|
||||
chunk_index = 0
|
||||
|
@ -267,7 +267,7 @@ class TTSService:
|
|||
async for chunk_text, tokens in smart_split(text,lang_code=lang_code,normalization_options=normalization_options):
|
||||
try:
|
||||
# Process audio for chunk
|
||||
async for result, chunk_data in self._process_chunk(
|
||||
async for chunk_data in self._process_chunk(
|
||||
chunk_text, # Pass text for Kokoro V1
|
||||
tokens, # Pass tokens for legacy backends
|
||||
voice_name, # Pass voice name
|
||||
|
@ -287,8 +287,8 @@ class TTSService:
|
|||
|
||||
current_offset+=len(chunk_data.audio) / 24000
|
||||
|
||||
if result is not None:
|
||||
yield result,chunk_data
|
||||
if chunk_data.output is not None:
|
||||
yield chunk_data
|
||||
chunk_index += 1
|
||||
else:
|
||||
logger.warning(
|
||||
|
@ -305,7 +305,7 @@ class TTSService:
|
|||
if chunk_index > 0:
|
||||
try:
|
||||
# Empty tokens list to finalize audio
|
||||
async for result,chunk_data in self._process_chunk(
|
||||
async for chunk_data in self._process_chunk(
|
||||
"", # Empty text
|
||||
[], # Empty tokens
|
||||
voice_name,
|
||||
|
@ -317,8 +317,8 @@ class TTSService:
|
|||
normalizer=stream_normalizer,
|
||||
lang_code=pipeline_lang_code, # Pass lang_code
|
||||
):
|
||||
if result is not None:
|
||||
yield result, chunk_data
|
||||
if chunk_data.output is not None:
|
||||
yield chunk_data
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to finalize audio stream: {str(e)}")
|
||||
|
||||
|
@ -335,17 +335,17 @@ class TTSService:
|
|||
return_timestamps: bool = False,
|
||||
normalization_options: Optional[NormalizationOptions] = NormalizationOptions(),
|
||||
lang_code: Optional[str] = None,
|
||||
) -> Tuple[Tuple[np.ndarray,AudioChunk]]:
|
||||
) -> AudioChunk:
|
||||
"""Generate complete audio for text using streaming internally."""
|
||||
audio_data_chunks=[]
|
||||
|
||||
try:
|
||||
async for _,audio_stream_data in self.generate_audio_stream(text,voice,speed=speed,normalization_options=normalization_options,return_timestamps=return_timestamps,lang_code=lang_code,output_format=None):
|
||||
async for audio_stream_data in self.generate_audio_stream(text,voice,speed=speed,normalization_options=normalization_options,return_timestamps=return_timestamps,lang_code=lang_code,output_format=None):
|
||||
|
||||
audio_data_chunks.append(audio_stream_data)
|
||||
|
||||
combined_audio_data=AudioChunk.combine(audio_data_chunks)
|
||||
return combined_audio_data.audio,combined_audio_data
|
||||
return combined_audio_data
|
||||
except Exception as e:
|
||||
logger.error(f"Error in audio generation: {str(e)}")
|
||||
raise
|
||||
|
|
|
@ -31,83 +31,83 @@ async def test_convert_to_wav(sample_audio):
|
|||
"""Test converting to WAV format"""
|
||||
audio_data, sample_rate = sample_audio
|
||||
# Write and finalize in one step for WAV
|
||||
result, audio_chunk = await AudioService.convert_audio(
|
||||
audio_chunk = await AudioService.convert_audio(
|
||||
AudioChunk(audio_data), sample_rate, "wav", is_first_chunk=True, is_last_chunk=True
|
||||
)
|
||||
assert isinstance(result, bytes)
|
||||
assert isinstance(audio_chunk.output, bytes)
|
||||
assert isinstance(audio_chunk, AudioChunk)
|
||||
assert len(result) > 0
|
||||
assert len(audio_chunk.output) > 0
|
||||
# Check WAV header
|
||||
assert result.startswith(b"RIFF")
|
||||
assert b"WAVE" in result[:12]
|
||||
assert audio_chunk.output.startswith(b"RIFF")
|
||||
assert b"WAVE" in audio_chunk.output[:12]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_convert_to_mp3(sample_audio):
|
||||
"""Test converting to MP3 format"""
|
||||
audio_data, sample_rate = sample_audio
|
||||
result, audio_chunk = await AudioService.convert_audio(
|
||||
audio_chunk = await AudioService.convert_audio(
|
||||
AudioChunk(audio_data), sample_rate, "mp3"
|
||||
)
|
||||
assert isinstance(result, bytes)
|
||||
assert isinstance(audio_chunk.output, bytes)
|
||||
assert isinstance(audio_chunk, AudioChunk)
|
||||
assert len(result) > 0
|
||||
assert len(audio_chunk.output) > 0
|
||||
# Check MP3 header (ID3 or MPEG frame sync)
|
||||
assert result.startswith(b"ID3") or result.startswith(b"\xff\xfb")
|
||||
assert audio_chunk.output.startswith(b"ID3") or audio_chunk.output.startswith(b"\xff\xfb")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_convert_to_opus(sample_audio):
|
||||
"""Test converting to Opus format"""
|
||||
audio_data, sample_rate = sample_audio
|
||||
result, audio_chunk = await AudioService.convert_audio(
|
||||
audio_chunk = await AudioService.convert_audio(
|
||||
AudioChunk(audio_data), sample_rate, "opus"
|
||||
)
|
||||
assert isinstance(result, bytes)
|
||||
assert isinstance(audio_chunk.output, bytes)
|
||||
assert isinstance(audio_chunk, AudioChunk)
|
||||
assert len(result) > 0
|
||||
assert len(audio_chunk.output) > 0
|
||||
# Check OGG header
|
||||
assert result.startswith(b"OggS")
|
||||
assert audio_chunk.output.startswith(b"OggS")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_convert_to_flac(sample_audio):
|
||||
"""Test converting to FLAC format"""
|
||||
audio_data, sample_rate = sample_audio
|
||||
result, audio_chunk = await AudioService.convert_audio(
|
||||
audio_chunk = await AudioService.convert_audio(
|
||||
AudioChunk(audio_data), sample_rate, "flac"
|
||||
)
|
||||
assert isinstance(result, bytes)
|
||||
assert isinstance(audio_chunk.output, bytes)
|
||||
assert isinstance(audio_chunk, AudioChunk)
|
||||
assert len(result) > 0
|
||||
assert len(audio_chunk.output) > 0
|
||||
# Check FLAC header
|
||||
assert result.startswith(b"fLaC")
|
||||
assert audio_chunk.output.startswith(b"fLaC")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_convert_to_aac(sample_audio):
|
||||
"""Test converting to AAC format"""
|
||||
audio_data, sample_rate = sample_audio
|
||||
result, audio_chunk = await AudioService.convert_audio(
|
||||
audio_chunk = await AudioService.convert_audio(
|
||||
AudioChunk(audio_data), sample_rate, "aac"
|
||||
)
|
||||
assert isinstance(result, bytes)
|
||||
assert isinstance(audio_chunk.output, bytes)
|
||||
assert isinstance(audio_chunk, AudioChunk)
|
||||
assert len(result) > 0
|
||||
assert len(audio_chunk.output) > 0
|
||||
# Check ADTS header (AAC)
|
||||
assert result.startswith(b"\xff\xf0") or result.startswith(b"\xff\xf1")
|
||||
assert audio_chunk.output.startswith(b"\xff\xf0") or audio_chunk.output.startswith(b"\xff\xf1")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_convert_to_pcm(sample_audio):
|
||||
"""Test converting to PCM format"""
|
||||
audio_data, sample_rate = sample_audio
|
||||
result, audio_chunk = await AudioService.convert_audio(
|
||||
audio_chunk = await AudioService.convert_audio(
|
||||
AudioChunk(audio_data), sample_rate, "pcm"
|
||||
)
|
||||
assert isinstance(result, bytes)
|
||||
assert isinstance(audio_chunk.output, bytes)
|
||||
assert isinstance(audio_chunk, AudioChunk)
|
||||
assert len(result) > 0
|
||||
assert len(audio_chunk.output) > 0
|
||||
# PCM is raw bytes, so no header to check
|
||||
|
||||
|
||||
|
@ -126,12 +126,12 @@ async def test_normalization_wav(sample_audio):
|
|||
# Create audio data outside int16 range
|
||||
large_audio = audio_data * 1e5
|
||||
# Write and finalize in one step for WAV
|
||||
result, audio_chunk = await AudioService.convert_audio(
|
||||
audio_chunk = await AudioService.convert_audio(
|
||||
AudioChunk(large_audio), sample_rate, "wav", is_first_chunk=True, is_last_chunk=True
|
||||
)
|
||||
assert isinstance(result, bytes)
|
||||
assert isinstance(audio_chunk.output, bytes)
|
||||
assert isinstance(audio_chunk, AudioChunk)
|
||||
assert len(result) > 0
|
||||
assert len(audio_chunk.output) > 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
@ -140,12 +140,12 @@ async def test_normalization_pcm(sample_audio):
|
|||
audio_data, sample_rate = sample_audio
|
||||
# Create audio data outside int16 range
|
||||
large_audio = audio_data * 1e5
|
||||
result, audio_chunk = await AudioService.convert_audio(
|
||||
audio_chunk = await AudioService.convert_audio(
|
||||
AudioChunk(large_audio), sample_rate, "pcm"
|
||||
)
|
||||
assert isinstance(result, bytes)
|
||||
assert isinstance(audio_chunk.output, bytes)
|
||||
assert isinstance(audio_chunk, AudioChunk)
|
||||
assert len(result) > 0
|
||||
assert len(audio_chunk.output) > 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
@ -164,12 +164,12 @@ async def test_different_sample_rates(sample_audio):
|
|||
sample_rates = [8000, 16000, 44100, 48000]
|
||||
|
||||
for rate in sample_rates:
|
||||
result, audio_chunk = await AudioService.convert_audio(
|
||||
audio_chunk = await AudioService.convert_audio(
|
||||
AudioChunk(audio_data), rate, "wav", is_first_chunk=True, is_last_chunk=True
|
||||
)
|
||||
assert isinstance(result, bytes)
|
||||
assert isinstance(audio_chunk.output, bytes)
|
||||
assert isinstance(audio_chunk, AudioChunk)
|
||||
assert len(result) > 0
|
||||
assert len(audio_chunk.output) > 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
@ -177,15 +177,15 @@ async def test_buffer_position_after_conversion(sample_audio):
|
|||
"""Test that buffer position is reset after writing"""
|
||||
audio_data, sample_rate = sample_audio
|
||||
# Write and finalize in one step for first conversion
|
||||
result1, audio_chunk = await AudioService.convert_audio(
|
||||
audio_chunk1 = await AudioService.convert_audio(
|
||||
AudioChunk(audio_data), sample_rate, "wav", is_first_chunk=True, is_last_chunk=True
|
||||
)
|
||||
assert isinstance(result1, bytes)
|
||||
assert isinstance(audio_chunk, AudioChunk)
|
||||
assert isinstance(audio_chunk1.output, bytes)
|
||||
assert isinstance(audio_chunk1, AudioChunk)
|
||||
# Convert again to ensure buffer was properly reset
|
||||
result2, audio_chunk = await AudioService.convert_audio(
|
||||
audio_chunk2 = await AudioService.convert_audio(
|
||||
AudioChunk(audio_data), sample_rate, "wav", is_first_chunk=True, is_last_chunk=True
|
||||
)
|
||||
assert isinstance(result2, bytes)
|
||||
assert isinstance(audio_chunk, AudioChunk)
|
||||
assert len(result1) == len(result2)
|
||||
assert isinstance(audio_chunk2.output, bytes)
|
||||
assert isinstance(audio_chunk2, AudioChunk)
|
||||
assert len(audio_chunk1.output) == len(audio_chunk2.output)
|
||||
|
|
|
@ -145,7 +145,7 @@ async def test_stream_audio_chunks_client_disconnect():
|
|||
|
||||
async def mock_stream(*args, **kwargs):
|
||||
for i in range(5):
|
||||
yield (b"chunk",AudioChunk(np.ndarray([],np.int16)))
|
||||
yield AudioChunk(np.ndarray([],np.int16),output=b"chunk")
|
||||
|
||||
mock_service.generate_audio_stream = mock_stream
|
||||
mock_service.list_voices.return_value = ["test_voice"]
|
||||
|
@ -160,7 +160,7 @@ async def test_stream_audio_chunks_client_disconnect():
|
|||
)
|
||||
|
||||
chunks = []
|
||||
async for chunk, _ in stream_audio_chunks(mock_service, request, mock_request):
|
||||
async for chunk in stream_audio_chunks(mock_service, request, mock_request):
|
||||
chunks.append(chunk)
|
||||
|
||||
assert len(chunks) == 0 # Should stop immediately due to disconnect
|
||||
|
@ -237,10 +237,10 @@ def mock_tts_service(mock_audio_bytes):
|
|||
"""Mock TTS service for testing."""
|
||||
with patch("api.src.routers.openai_compatible.get_tts_service") as mock_get:
|
||||
service = AsyncMock(spec=TTSService)
|
||||
service.generate_audio.return_value = (np.zeros(1000), AudioChunk(np.zeros(1000,np.int16)))
|
||||
service.generate_audio.return_value = AudioChunk(np.zeros(1000,np.int16))
|
||||
|
||||
async def mock_stream(*args, **kwargs) -> AsyncGenerator[Tuple[bytes,AudioChunk], None]:
|
||||
yield (mock_audio_bytes, AudioChunk(np.ndarray([],np.int16)))
|
||||
async def mock_stream(*args, **kwargs) -> AsyncGenerator[AudioChunk, None]:
|
||||
yield AudioChunk(np.ndarray([],np.int16),output=mock_audio_bytes)
|
||||
|
||||
service.generate_audio_stream = mock_stream
|
||||
service.list_voices.return_value = ["test_voice", "voice1", "voice2"]
|
||||
|
@ -257,8 +257,8 @@ def test_openai_speech_endpoint(
|
|||
):
|
||||
"""Test the OpenAI-compatible speech endpoint with basic MP3 generation"""
|
||||
# Configure mocks
|
||||
mock_tts_service.generate_audio.return_value = (np.zeros(1000), AudioChunk(np.zeros(1000,np.int16)))
|
||||
mock_convert.return_value = (mock_audio_bytes,AudioChunk(np.zeros(1000,np.int16)))
|
||||
mock_tts_service.generate_audio.return_value = AudioChunk(np.zeros(1000,np.int16))
|
||||
mock_convert.return_value = AudioChunk(np.zeros(1000,np.int16),output=mock_audio_bytes)
|
||||
|
||||
response = client.post(
|
||||
"/v1/audio/speech",
|
||||
|
|
Loading…
Add table
Reference in a new issue