mirror of
https://github.com/remsky/Kokoro-FastAPI.git
synced 2025-04-13 09:39:17 +00:00
Aculy fixed tests this time
This commit is contained in:
parent
c902b2ca0d
commit
c24aeefbb2
9 changed files with 114 additions and 68 deletions
|
@ -120,7 +120,7 @@ async def generate_from_phonemes(
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error in audio generation: {str(e)}")
|
logger.error(f"Error in audio generation: {str(e)}")
|
||||||
# Clean up writer on error
|
# Clean up writer on error
|
||||||
writer.write_chunk(finalize=True)
|
writer.close()
|
||||||
# Re-raise the original exception
|
# Re-raise the original exception
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
@ -236,6 +236,7 @@ async def create_captioned_speech(
|
||||||
# Ensure temp writer is closed
|
# Ensure temp writer is closed
|
||||||
if not temp_writer._finalized:
|
if not temp_writer._finalized:
|
||||||
await temp_writer.__aexit__(None, None, None)
|
await temp_writer.__aexit__(None, None, None)
|
||||||
|
writer.close()
|
||||||
|
|
||||||
# Stream with temp file writing
|
# Stream with temp file writing
|
||||||
return JSONStreamingResponse(
|
return JSONStreamingResponse(
|
||||||
|
@ -267,6 +268,7 @@ async def create_captioned_speech(
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error in single output streaming: {e}")
|
logger.error(f"Error in single output streaming: {e}")
|
||||||
|
writer.close()
|
||||||
raise
|
raise
|
||||||
|
|
||||||
# Standard streaming without download link
|
# Standard streaming without download link
|
||||||
|
@ -294,10 +296,8 @@ async def create_captioned_speech(
|
||||||
|
|
||||||
audio_data = await AudioService.convert_audio(
|
audio_data = await AudioService.convert_audio(
|
||||||
audio_data,
|
audio_data,
|
||||||
24000,
|
|
||||||
request.response_format,
|
request.response_format,
|
||||||
writer,
|
writer,
|
||||||
is_first_chunk=True,
|
|
||||||
is_last_chunk=False,
|
is_last_chunk=False,
|
||||||
trim_audio=False,
|
trim_audio=False,
|
||||||
)
|
)
|
||||||
|
@ -305,10 +305,8 @@ async def create_captioned_speech(
|
||||||
# Convert to requested format with proper finalization
|
# Convert to requested format with proper finalization
|
||||||
final = await AudioService.convert_audio(
|
final = await AudioService.convert_audio(
|
||||||
AudioChunk(np.array([], dtype=np.int16)),
|
AudioChunk(np.array([], dtype=np.int16)),
|
||||||
24000,
|
|
||||||
request.response_format,
|
request.response_format,
|
||||||
writer,
|
writer,
|
||||||
is_first_chunk=False,
|
|
||||||
is_last_chunk=True,
|
is_last_chunk=True,
|
||||||
)
|
)
|
||||||
output=audio_data.output + final.output
|
output=audio_data.output + final.output
|
||||||
|
@ -316,6 +314,9 @@ async def create_captioned_speech(
|
||||||
base64_output= base64.b64encode(output).decode("utf-8")
|
base64_output= base64.b64encode(output).decode("utf-8")
|
||||||
|
|
||||||
content=CaptionedSpeechResponse(audio=base64_output,audio_format=content_type,timestamps=audio_data.word_timestamps).model_dump()
|
content=CaptionedSpeechResponse(audio=base64_output,audio_format=content_type,timestamps=audio_data.word_timestamps).model_dump()
|
||||||
|
|
||||||
|
writer.close()
|
||||||
|
|
||||||
return JSONResponse(
|
return JSONResponse(
|
||||||
content=content,
|
content=content,
|
||||||
media_type="application/json",
|
media_type="application/json",
|
||||||
|
@ -328,6 +329,12 @@ async def create_captioned_speech(
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
# Handle validation errors
|
# Handle validation errors
|
||||||
logger.warning(f"Invalid request: {str(e)}")
|
logger.warning(f"Invalid request: {str(e)}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
writer.close()
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=400,
|
status_code=400,
|
||||||
detail={
|
detail={
|
||||||
|
@ -339,6 +346,12 @@ async def create_captioned_speech(
|
||||||
except RuntimeError as e:
|
except RuntimeError as e:
|
||||||
# Handle runtime/processing errors
|
# Handle runtime/processing errors
|
||||||
logger.error(f"Processing error: {str(e)}")
|
logger.error(f"Processing error: {str(e)}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
writer.close()
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=500,
|
status_code=500,
|
||||||
detail={
|
detail={
|
||||||
|
@ -350,6 +363,12 @@ async def create_captioned_speech(
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# Handle unexpected errors
|
# Handle unexpected errors
|
||||||
logger.error(f"Unexpected error in captioned speech generation: {str(e)}")
|
logger.error(f"Unexpected error in captioned speech generation: {str(e)}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
writer.close()
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=500,
|
status_code=500,
|
||||||
detail={
|
detail={
|
||||||
|
|
|
@ -240,6 +240,7 @@ async def create_speech(
|
||||||
# Ensure temp writer is closed
|
# Ensure temp writer is closed
|
||||||
if not temp_writer._finalized:
|
if not temp_writer._finalized:
|
||||||
await temp_writer.__aexit__(None, None, None)
|
await temp_writer.__aexit__(None, None, None)
|
||||||
|
writer.close()
|
||||||
|
|
||||||
# Stream with temp file writing
|
# Stream with temp file writing
|
||||||
return StreamingResponse(dual_output(), media_type=content_type, headers=headers)
|
return StreamingResponse(dual_output(), media_type=content_type, headers=headers)
|
||||||
|
@ -252,6 +253,7 @@ async def create_speech(
|
||||||
yield chunk_data.output
|
yield chunk_data.output
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error in single output streaming: {e}")
|
logger.error(f"Error in single output streaming: {e}")
|
||||||
|
writer.close()
|
||||||
raise
|
raise
|
||||||
|
|
||||||
# Standard streaming without download link
|
# Standard streaming without download link
|
||||||
|
@ -281,15 +283,13 @@ async def create_speech(
|
||||||
lang_code=request.lang_code,
|
lang_code=request.lang_code,
|
||||||
)
|
)
|
||||||
|
|
||||||
audio_data = await AudioService.convert_audio(audio_data, 24000, request.response_format, writer, is_first_chunk=True, is_last_chunk=False, trim_audio=False)
|
audio_data = await AudioService.convert_audio(audio_data, request.response_format, writer, is_last_chunk=False, trim_audio=False)
|
||||||
|
|
||||||
# Convert to requested format with proper finalization
|
# Convert to requested format with proper finalization
|
||||||
final = await AudioService.convert_audio(
|
final = await AudioService.convert_audio(
|
||||||
AudioChunk(np.array([], dtype=np.int16)),
|
AudioChunk(np.array([], dtype=np.int16)),
|
||||||
24000,
|
|
||||||
request.response_format,
|
request.response_format,
|
||||||
writer,
|
writer,
|
||||||
is_first_chunk=False,
|
|
||||||
is_last_chunk=True,
|
is_last_chunk=True,
|
||||||
)
|
)
|
||||||
output = audio_data.output + final.output
|
output = audio_data.output + final.output
|
||||||
|
@ -321,6 +321,7 @@ async def create_speech(
|
||||||
# Ensure temp writer is closed
|
# Ensure temp writer is closed
|
||||||
if not temp_writer._finalized:
|
if not temp_writer._finalized:
|
||||||
await temp_writer.__aexit__(None, None, None)
|
await temp_writer.__aexit__(None, None, None)
|
||||||
|
writer.close()
|
||||||
|
|
||||||
return Response(
|
return Response(
|
||||||
content=output,
|
content=output,
|
||||||
|
@ -331,6 +332,12 @@ async def create_speech(
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
# Handle validation errors
|
# Handle validation errors
|
||||||
logger.warning(f"Invalid request: {str(e)}")
|
logger.warning(f"Invalid request: {str(e)}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
writer.close()
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=400,
|
status_code=400,
|
||||||
detail={
|
detail={
|
||||||
|
@ -342,6 +349,12 @@ async def create_speech(
|
||||||
except RuntimeError as e:
|
except RuntimeError as e:
|
||||||
# Handle runtime/processing errors
|
# Handle runtime/processing errors
|
||||||
logger.error(f"Processing error: {str(e)}")
|
logger.error(f"Processing error: {str(e)}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
writer.close()
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=500,
|
status_code=500,
|
||||||
detail={
|
detail={
|
||||||
|
@ -353,6 +366,12 @@ async def create_speech(
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# Handle unexpected errors
|
# Handle unexpected errors
|
||||||
logger.error(f"Unexpected error in speech generation: {str(e)}")
|
logger.error(f"Unexpected error in speech generation: {str(e)}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
writer.close()
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=500,
|
status_code=500,
|
||||||
detail={
|
detail={
|
||||||
|
@ -363,6 +382,7 @@ async def create_speech(
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/download/{filename}")
|
@router.get("/download/{filename}")
|
||||||
async def download_audio_file(filename: str):
|
async def download_audio_file(filename: str):
|
||||||
"""Download a generated audio file from temp storage"""
|
"""Download a generated audio file from temp storage"""
|
||||||
|
|
|
@ -108,17 +108,13 @@ class AudioService:
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
_writers = {}
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def convert_audio(
|
async def convert_audio(
|
||||||
audio_chunk: AudioChunk,
|
audio_chunk: AudioChunk,
|
||||||
sample_rate: int,
|
|
||||||
output_format: str,
|
output_format: str,
|
||||||
writer: StreamingAudioWriter,
|
writer: StreamingAudioWriter,
|
||||||
speed: float = 1,
|
speed: float = 1,
|
||||||
chunk_text: str = "",
|
chunk_text: str = "",
|
||||||
is_first_chunk: bool = True,
|
|
||||||
is_last_chunk: bool = False,
|
is_last_chunk: bool = False,
|
||||||
trim_audio: bool = True,
|
trim_audio: bool = True,
|
||||||
normalizer: AudioNormalizer = None,
|
normalizer: AudioNormalizer = None,
|
||||||
|
@ -127,12 +123,12 @@ class AudioService:
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
audio_data: Numpy array of audio samples
|
audio_data: Numpy array of audio samples
|
||||||
sample_rate: Sample rate of the audio
|
|
||||||
output_format: Target format (wav, mp3, ogg, pcm)
|
output_format: Target format (wav, mp3, ogg, pcm)
|
||||||
|
writer: The StreamingAudioWriter to use
|
||||||
speed: The speaking speed of the voice
|
speed: The speaking speed of the voice
|
||||||
chunk_text: The text sent to the model to generate the resulting speech
|
chunk_text: The text sent to the model to generate the resulting speech
|
||||||
is_first_chunk: Whether this is the first chunk
|
|
||||||
is_last_chunk: Whether this is the last chunk
|
is_last_chunk: Whether this is the last chunk
|
||||||
|
trim_audio: Whether audio should be trimmed
|
||||||
normalizer: Optional AudioNormalizer instance for consistent normalization
|
normalizer: Optional AudioNormalizer instance for consistent normalization
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
|
@ -153,15 +149,6 @@ class AudioService:
|
||||||
if trim_audio == True:
|
if trim_audio == True:
|
||||||
audio_chunk = AudioService.trim_audio(audio_chunk,chunk_text,speed,is_last_chunk,normalizer)
|
audio_chunk = AudioService.trim_audio(audio_chunk,chunk_text,speed,is_last_chunk,normalizer)
|
||||||
|
|
||||||
# Get or create format-specific writer
|
|
||||||
"""writer_key = f"{output_format}_{sample_rate}"
|
|
||||||
if is_first_chunk or writer_key not in AudioService._writers:
|
|
||||||
AudioService._writers[writer_key] = StreamingAudioWriter(
|
|
||||||
output_format, sample_rate
|
|
||||||
)
|
|
||||||
|
|
||||||
writer = AudioService._writers[writer_key]"""
|
|
||||||
|
|
||||||
# Write audio data first
|
# Write audio data first
|
||||||
if len(audio_chunk.audio) > 0:
|
if len(audio_chunk.audio) > 0:
|
||||||
chunk_data = writer.write_chunk(audio_chunk.audio)
|
chunk_data = writer.write_chunk(audio_chunk.audio)
|
||||||
|
|
|
@ -22,7 +22,7 @@ class StreamingAudioWriter:
|
||||||
|
|
||||||
codec_map = {"wav":"pcm_s16le","mp3":"mp3","opus":"libopus","flac":"flac", "aac":"aac"}
|
codec_map = {"wav":"pcm_s16le","mp3":"mp3","opus":"libopus","flac":"flac", "aac":"aac"}
|
||||||
# Format-specific setup
|
# Format-specific setup
|
||||||
if self.format in ["wav", "opus","flac","mp3","aac","pcm"]:
|
if self.format in ["wav","flac","mp3","pcm","aac","opus"]:
|
||||||
if self.format != "pcm":
|
if self.format != "pcm":
|
||||||
self.output_buffer = BytesIO()
|
self.output_buffer = BytesIO()
|
||||||
self.container = av.open(self.output_buffer, mode="w", format=self.format)
|
self.container = av.open(self.output_buffer, mode="w", format=self.format)
|
||||||
|
@ -31,6 +31,13 @@ class StreamingAudioWriter:
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unsupported format: {format}")
|
raise ValueError(f"Unsupported format: {format}")
|
||||||
|
|
||||||
|
def close(self):
|
||||||
|
if hasattr(self, "container"):
|
||||||
|
self.container.close()
|
||||||
|
|
||||||
|
if hasattr(self, "output_buffer"):
|
||||||
|
self.output_buffer.close()
|
||||||
|
|
||||||
def write_chunk(
|
def write_chunk(
|
||||||
self, audio_data: Optional[np.ndarray] = None, finalize: bool = False
|
self, audio_data: Optional[np.ndarray] = None, finalize: bool = False
|
||||||
) -> bytes:
|
) -> bytes:
|
||||||
|
@ -48,7 +55,7 @@ class StreamingAudioWriter:
|
||||||
self.container.mux(packet)
|
self.container.mux(packet)
|
||||||
|
|
||||||
data=self.output_buffer.getvalue()
|
data=self.output_buffer.getvalue()
|
||||||
self.container.close()
|
self.close()
|
||||||
return data
|
return data
|
||||||
|
|
||||||
if audio_data is None or len(audio_data) == 0:
|
if audio_data is None or len(audio_data) == 0:
|
||||||
|
|
|
@ -70,12 +70,10 @@ class TTSService:
|
||||||
return
|
return
|
||||||
chunk_data = await AudioService.convert_audio(
|
chunk_data = await AudioService.convert_audio(
|
||||||
AudioChunk(np.array([], dtype=np.float32)), # Dummy data for type checking
|
AudioChunk(np.array([], dtype=np.float32)), # Dummy data for type checking
|
||||||
24000,
|
|
||||||
output_format,
|
output_format,
|
||||||
writer,
|
writer,
|
||||||
speed,
|
speed,
|
||||||
"",
|
"",
|
||||||
is_first_chunk=False,
|
|
||||||
normalizer=normalizer,
|
normalizer=normalizer,
|
||||||
is_last_chunk=True,
|
is_last_chunk=True,
|
||||||
)
|
)
|
||||||
|
@ -105,12 +103,10 @@ class TTSService:
|
||||||
try:
|
try:
|
||||||
chunk_data = await AudioService.convert_audio(
|
chunk_data = await AudioService.convert_audio(
|
||||||
chunk_data,
|
chunk_data,
|
||||||
24000,
|
|
||||||
output_format,
|
output_format,
|
||||||
writer,
|
writer,
|
||||||
speed,
|
speed,
|
||||||
chunk_text,
|
chunk_text,
|
||||||
is_first_chunk=is_first and chunk_index == 0,
|
|
||||||
is_last_chunk=is_last,
|
is_last_chunk=is_last,
|
||||||
normalizer=normalizer,
|
normalizer=normalizer,
|
||||||
)
|
)
|
||||||
|
@ -139,12 +135,10 @@ class TTSService:
|
||||||
try:
|
try:
|
||||||
chunk_data = await AudioService.convert_audio(
|
chunk_data = await AudioService.convert_audio(
|
||||||
chunk_data,
|
chunk_data,
|
||||||
24000,
|
|
||||||
output_format,
|
output_format,
|
||||||
writer,
|
writer,
|
||||||
speed,
|
speed,
|
||||||
chunk_text,
|
chunk_text,
|
||||||
is_first_chunk=is_first,
|
|
||||||
normalizer=normalizer,
|
normalizer=normalizer,
|
||||||
is_last_chunk=is_last,
|
is_last_chunk=is_last,
|
||||||
)
|
)
|
||||||
|
|
|
@ -70,16 +70,3 @@ def test_voice():
|
||||||
"""Return a test voice name."""
|
"""Return a test voice name."""
|
||||||
return "voice1"
|
return "voice1"
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
|
||||||
def event_loop():
|
|
||||||
"""Create an instance of the default event loop for the test session."""
|
|
||||||
import asyncio
|
|
||||||
|
|
||||||
try:
|
|
||||||
loop = asyncio.get_event_loop()
|
|
||||||
except RuntimeError:
|
|
||||||
loop = asyncio.new_event_loop()
|
|
||||||
asyncio.set_event_loop(loop)
|
|
||||||
yield loop
|
|
||||||
loop.close()
|
|
||||||
|
|
|
@ -34,8 +34,11 @@ async def test_convert_to_wav(sample_audio):
|
||||||
writer = StreamingAudioWriter("wav", sample_rate=24000)
|
writer = StreamingAudioWriter("wav", sample_rate=24000)
|
||||||
# Write and finalize in one step for WAV
|
# Write and finalize in one step for WAV
|
||||||
audio_chunk = await AudioService.convert_audio(
|
audio_chunk = await AudioService.convert_audio(
|
||||||
AudioChunk(audio_data), sample_rate, "wav", writer, is_first_chunk=True, is_last_chunk=False
|
AudioChunk(audio_data), "wav", writer, is_last_chunk=False
|
||||||
)
|
)
|
||||||
|
|
||||||
|
writer.close()
|
||||||
|
|
||||||
assert isinstance(audio_chunk.output, bytes)
|
assert isinstance(audio_chunk.output, bytes)
|
||||||
assert isinstance(audio_chunk, AudioChunk)
|
assert isinstance(audio_chunk, AudioChunk)
|
||||||
assert len(audio_chunk.output) > 0
|
assert len(audio_chunk.output) > 0
|
||||||
|
@ -52,8 +55,11 @@ async def test_convert_to_mp3(sample_audio):
|
||||||
writer = StreamingAudioWriter("mp3", sample_rate=24000)
|
writer = StreamingAudioWriter("mp3", sample_rate=24000)
|
||||||
|
|
||||||
audio_chunk = await AudioService.convert_audio(
|
audio_chunk = await AudioService.convert_audio(
|
||||||
AudioChunk(audio_data), sample_rate, "mp3", writer
|
AudioChunk(audio_data), "mp3", writer
|
||||||
)
|
)
|
||||||
|
|
||||||
|
writer.close()
|
||||||
|
|
||||||
assert isinstance(audio_chunk.output, bytes)
|
assert isinstance(audio_chunk.output, bytes)
|
||||||
assert isinstance(audio_chunk, AudioChunk)
|
assert isinstance(audio_chunk, AudioChunk)
|
||||||
assert len(audio_chunk.output) > 0
|
assert len(audio_chunk.output) > 0
|
||||||
|
@ -64,13 +70,17 @@ async def test_convert_to_mp3(sample_audio):
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_convert_to_opus(sample_audio):
|
async def test_convert_to_opus(sample_audio):
|
||||||
"""Test converting to Opus format"""
|
"""Test converting to Opus format"""
|
||||||
|
|
||||||
audio_data, sample_rate = sample_audio
|
audio_data, sample_rate = sample_audio
|
||||||
|
|
||||||
writer = StreamingAudioWriter("opus", sample_rate=24000)
|
writer = StreamingAudioWriter("opus", sample_rate=24000)
|
||||||
|
|
||||||
audio_chunk = await AudioService.convert_audio(
|
audio_chunk = await AudioService.convert_audio(
|
||||||
AudioChunk(audio_data), sample_rate, "opus",writer
|
AudioChunk(audio_data), "opus",writer
|
||||||
)
|
)
|
||||||
|
|
||||||
|
writer.close()
|
||||||
|
|
||||||
assert isinstance(audio_chunk.output, bytes)
|
assert isinstance(audio_chunk.output, bytes)
|
||||||
assert isinstance(audio_chunk, AudioChunk)
|
assert isinstance(audio_chunk, AudioChunk)
|
||||||
assert len(audio_chunk.output) > 0
|
assert len(audio_chunk.output) > 0
|
||||||
|
@ -86,8 +96,11 @@ async def test_convert_to_flac(sample_audio):
|
||||||
writer = StreamingAudioWriter("flac", sample_rate=24000)
|
writer = StreamingAudioWriter("flac", sample_rate=24000)
|
||||||
|
|
||||||
audio_chunk = await AudioService.convert_audio(
|
audio_chunk = await AudioService.convert_audio(
|
||||||
AudioChunk(audio_data), sample_rate, "flac", writer
|
AudioChunk(audio_data), "flac", writer
|
||||||
)
|
)
|
||||||
|
|
||||||
|
writer.close()
|
||||||
|
|
||||||
assert isinstance(audio_chunk.output, bytes)
|
assert isinstance(audio_chunk.output, bytes)
|
||||||
assert isinstance(audio_chunk, AudioChunk)
|
assert isinstance(audio_chunk, AudioChunk)
|
||||||
assert len(audio_chunk.output) > 0
|
assert len(audio_chunk.output) > 0
|
||||||
|
@ -103,8 +116,11 @@ async def test_convert_to_aac(sample_audio):
|
||||||
writer = StreamingAudioWriter("aac", sample_rate=24000)
|
writer = StreamingAudioWriter("aac", sample_rate=24000)
|
||||||
|
|
||||||
audio_chunk = await AudioService.convert_audio(
|
audio_chunk = await AudioService.convert_audio(
|
||||||
AudioChunk(audio_data), sample_rate, "aac", writer
|
AudioChunk(audio_data), "aac", writer
|
||||||
)
|
)
|
||||||
|
|
||||||
|
writer.close()
|
||||||
|
|
||||||
assert isinstance(audio_chunk.output, bytes)
|
assert isinstance(audio_chunk.output, bytes)
|
||||||
assert isinstance(audio_chunk, AudioChunk)
|
assert isinstance(audio_chunk, AudioChunk)
|
||||||
assert len(audio_chunk.output) > 0
|
assert len(audio_chunk.output) > 0
|
||||||
|
@ -120,8 +136,11 @@ async def test_convert_to_pcm(sample_audio):
|
||||||
writer = StreamingAudioWriter("pcm", sample_rate=24000)
|
writer = StreamingAudioWriter("pcm", sample_rate=24000)
|
||||||
|
|
||||||
audio_chunk = await AudioService.convert_audio(
|
audio_chunk = await AudioService.convert_audio(
|
||||||
AudioChunk(audio_data), sample_rate, "pcm", writer
|
AudioChunk(audio_data), "pcm", writer
|
||||||
)
|
)
|
||||||
|
|
||||||
|
writer.close()
|
||||||
|
|
||||||
assert isinstance(audio_chunk.output, bytes)
|
assert isinstance(audio_chunk.output, bytes)
|
||||||
assert isinstance(audio_chunk, AudioChunk)
|
assert isinstance(audio_chunk, AudioChunk)
|
||||||
assert len(audio_chunk.output) > 0
|
assert len(audio_chunk.output) > 0
|
||||||
|
@ -131,12 +150,9 @@ async def test_convert_to_pcm(sample_audio):
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_convert_to_invalid_format_raises_error(sample_audio):
|
async def test_convert_to_invalid_format_raises_error(sample_audio):
|
||||||
"""Test that converting to an invalid format raises an error"""
|
"""Test that converting to an invalid format raises an error"""
|
||||||
audio_data, sample_rate = sample_audio
|
#audio_data, sample_rate = sample_audio
|
||||||
with pytest.raises(ValueError, match="Unsupported format: invalid"):
|
with pytest.raises(ValueError, match="Unsupported format: invalid"):
|
||||||
StreamingAudioWriter("invalid", sample_rate=24000)
|
writer = StreamingAudioWriter("invalid", sample_rate=24000)
|
||||||
|
|
||||||
with pytest.raises(ValueError, match="Format invalid not supported"):
|
|
||||||
await AudioService.convert_audio(audio_data, sample_rate, "invalid", None)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
|
@ -150,8 +166,11 @@ async def test_normalization_wav(sample_audio):
|
||||||
large_audio = audio_data * 1e5
|
large_audio = audio_data * 1e5
|
||||||
# Write and finalize in one step for WAV
|
# Write and finalize in one step for WAV
|
||||||
audio_chunk = await AudioService.convert_audio(
|
audio_chunk = await AudioService.convert_audio(
|
||||||
AudioChunk(large_audio), sample_rate, "wav", writer, is_first_chunk=True
|
AudioChunk(large_audio), "wav", writer
|
||||||
)
|
)
|
||||||
|
|
||||||
|
writer.close()
|
||||||
|
|
||||||
assert isinstance(audio_chunk.output, bytes)
|
assert isinstance(audio_chunk.output, bytes)
|
||||||
assert isinstance(audio_chunk, AudioChunk)
|
assert isinstance(audio_chunk, AudioChunk)
|
||||||
assert len(audio_chunk.output) > 0
|
assert len(audio_chunk.output) > 0
|
||||||
|
@ -167,7 +186,7 @@ async def test_normalization_pcm(sample_audio):
|
||||||
# Create audio data outside int16 range
|
# Create audio data outside int16 range
|
||||||
large_audio = audio_data * 1e5
|
large_audio = audio_data * 1e5
|
||||||
audio_chunk = await AudioService.convert_audio(
|
audio_chunk = await AudioService.convert_audio(
|
||||||
AudioChunk(large_audio), sample_rate, "pcm", writer
|
AudioChunk(large_audio), "pcm", writer
|
||||||
)
|
)
|
||||||
assert isinstance(audio_chunk.output, bytes)
|
assert isinstance(audio_chunk.output, bytes)
|
||||||
assert isinstance(audio_chunk, AudioChunk)
|
assert isinstance(audio_chunk, AudioChunk)
|
||||||
|
@ -197,8 +216,11 @@ async def test_different_sample_rates(sample_audio):
|
||||||
writer = StreamingAudioWriter("wav", sample_rate=rate)
|
writer = StreamingAudioWriter("wav", sample_rate=rate)
|
||||||
|
|
||||||
audio_chunk = await AudioService.convert_audio(
|
audio_chunk = await AudioService.convert_audio(
|
||||||
AudioChunk(audio_data), rate, "wav", writer, is_first_chunk=True
|
AudioChunk(audio_data), "wav", writer
|
||||||
)
|
)
|
||||||
|
|
||||||
|
writer.close()
|
||||||
|
|
||||||
assert isinstance(audio_chunk.output, bytes)
|
assert isinstance(audio_chunk.output, bytes)
|
||||||
assert isinstance(audio_chunk, AudioChunk)
|
assert isinstance(audio_chunk, AudioChunk)
|
||||||
assert len(audio_chunk.output) > 0
|
assert len(audio_chunk.output) > 0
|
||||||
|
@ -213,7 +235,7 @@ async def test_buffer_position_after_conversion(sample_audio):
|
||||||
|
|
||||||
# Write and finalize in one step for first conversion
|
# Write and finalize in one step for first conversion
|
||||||
audio_chunk1 = await AudioService.convert_audio(
|
audio_chunk1 = await AudioService.convert_audio(
|
||||||
AudioChunk(audio_data), sample_rate, "wav", writer, is_first_chunk=True, is_last_chunk=True
|
AudioChunk(audio_data), "wav", writer, is_last_chunk=True
|
||||||
)
|
)
|
||||||
assert isinstance(audio_chunk1.output, bytes)
|
assert isinstance(audio_chunk1.output, bytes)
|
||||||
assert isinstance(audio_chunk1, AudioChunk)
|
assert isinstance(audio_chunk1, AudioChunk)
|
||||||
|
@ -222,7 +244,7 @@ async def test_buffer_position_after_conversion(sample_audio):
|
||||||
writer = StreamingAudioWriter("wav", sample_rate=24000)
|
writer = StreamingAudioWriter("wav", sample_rate=24000)
|
||||||
|
|
||||||
audio_chunk2 = await AudioService.convert_audio(
|
audio_chunk2 = await AudioService.convert_audio(
|
||||||
AudioChunk(audio_data), sample_rate, "wav", writer, is_first_chunk=True, is_last_chunk=True
|
AudioChunk(audio_data), "wav", writer, is_last_chunk=True
|
||||||
)
|
)
|
||||||
assert isinstance(audio_chunk2.output, bytes)
|
assert isinstance(audio_chunk2.output, bytes)
|
||||||
assert isinstance(audio_chunk2, AudioChunk)
|
assert isinstance(audio_chunk2, AudioChunk)
|
||||||
|
|
|
@ -4,6 +4,8 @@ import os
|
||||||
from typing import AsyncGenerator, Tuple
|
from typing import AsyncGenerator, Tuple
|
||||||
from unittest.mock import AsyncMock, MagicMock, patch
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
from api.src.services.streaming_audio_writer import StreamingAudioWriter
|
||||||
|
|
||||||
from api.src.inference.base import AudioChunk
|
from api.src.inference.base import AudioChunk
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pytest
|
import pytest
|
||||||
|
@ -159,10 +161,14 @@ async def test_stream_audio_chunks_client_disconnect():
|
||||||
speed=1.0,
|
speed=1.0,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
writer = StreamingAudioWriter("mp3", 24000)
|
||||||
|
|
||||||
chunks = []
|
chunks = []
|
||||||
async for chunk in stream_audio_chunks(mock_service, request, mock_request):
|
async for chunk in stream_audio_chunks(mock_service, request, mock_request, writer):
|
||||||
chunks.append(chunk)
|
chunks.append(chunk)
|
||||||
|
|
||||||
|
writer.close()
|
||||||
|
|
||||||
assert len(chunks) == 0 # Should stop immediately due to disconnect
|
assert len(chunks) == 0 # Should stop immediately due to disconnect
|
||||||
|
|
||||||
|
|
||||||
|
@ -483,7 +489,11 @@ async def test_streaming_initialization_error():
|
||||||
speed=1.0,
|
speed=1.0,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
writer = StreamingAudioWriter("mp3", 24000)
|
||||||
|
|
||||||
with pytest.raises(RuntimeError) as exc:
|
with pytest.raises(RuntimeError) as exc:
|
||||||
async for _ in stream_audio_chunks(mock_service, request, MagicMock()):
|
async for _ in stream_audio_chunks(mock_service, request, MagicMock(), writer):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
writer.close()
|
||||||
assert "Failed to initialize stream" in str(exc.value)
|
assert "Failed to initialize stream" in str(exc.value)
|
||||||
|
|
|
@ -37,7 +37,7 @@ dependencies = [
|
||||||
"en-core-web-sm @ https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.7.1/en_core_web_sm-3.7.1-py3-none-any.whl",
|
"en-core-web-sm @ https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.7.1/en_core_web_sm-3.7.1-py3-none-any.whl",
|
||||||
"inflect>=7.5.0",
|
"inflect>=7.5.0",
|
||||||
"phonemizer-fork>=3.3.2",
|
"phonemizer-fork>=3.3.2",
|
||||||
"av>=14.1.0",
|
"av>=14.2.0",
|
||||||
]
|
]
|
||||||
|
|
||||||
[project.optional-dependencies]
|
[project.optional-dependencies]
|
||||||
|
@ -48,10 +48,10 @@ cpu = [
|
||||||
"torch==2.6.0",
|
"torch==2.6.0",
|
||||||
]
|
]
|
||||||
test = [
|
test = [
|
||||||
"pytest==8.0.0",
|
"pytest==8.3.5",
|
||||||
"pytest-cov==4.1.0",
|
"pytest-cov==6.0.0",
|
||||||
"httpx==0.26.0",
|
"httpx==0.26.0",
|
||||||
"pytest-asyncio==0.23.5",
|
"pytest-asyncio==0.25.3",
|
||||||
"openai>=1.59.6",
|
"openai>=1.59.6",
|
||||||
"tomli>=2.0.1",
|
"tomli>=2.0.1",
|
||||||
]
|
]
|
||||||
|
@ -91,5 +91,5 @@ packages.find = {where = ["api/src"], namespaces = true}
|
||||||
[tool.pytest.ini_options]
|
[tool.pytest.ini_options]
|
||||||
testpaths = ["api/tests", "ui/tests"]
|
testpaths = ["api/tests", "ui/tests"]
|
||||||
python_files = ["test_*.py"]
|
python_files = ["test_*.py"]
|
||||||
addopts = "--cov=api --cov=ui --cov-report=term-missing --cov-config=.coveragerc"
|
addopts = "--cov=api --cov=ui --cov-report=term-missing --cov-config=.coveragerc --full-trace"
|
||||||
asyncio_mode = "strict"
|
asyncio_mode = "auto"
|
||||||
|
|
Loading…
Add table
Reference in a new issue