mirror of
https://github.com/remsky/Kokoro-FastAPI.git
synced 2025-04-13 09:39:17 +00:00
Aculy fixed tests this time
This commit is contained in:
parent
c902b2ca0d
commit
c24aeefbb2
9 changed files with 114 additions and 68 deletions
|
@ -120,7 +120,7 @@ async def generate_from_phonemes(
|
|||
except Exception as e:
|
||||
logger.error(f"Error in audio generation: {str(e)}")
|
||||
# Clean up writer on error
|
||||
writer.write_chunk(finalize=True)
|
||||
writer.close()
|
||||
# Re-raise the original exception
|
||||
raise
|
||||
|
||||
|
@ -236,6 +236,7 @@ async def create_captioned_speech(
|
|||
# Ensure temp writer is closed
|
||||
if not temp_writer._finalized:
|
||||
await temp_writer.__aexit__(None, None, None)
|
||||
writer.close()
|
||||
|
||||
# Stream with temp file writing
|
||||
return JSONStreamingResponse(
|
||||
|
@ -267,6 +268,7 @@ async def create_captioned_speech(
|
|||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in single output streaming: {e}")
|
||||
writer.close()
|
||||
raise
|
||||
|
||||
# Standard streaming without download link
|
||||
|
@ -294,10 +296,8 @@ async def create_captioned_speech(
|
|||
|
||||
audio_data = await AudioService.convert_audio(
|
||||
audio_data,
|
||||
24000,
|
||||
request.response_format,
|
||||
writer,
|
||||
is_first_chunk=True,
|
||||
is_last_chunk=False,
|
||||
trim_audio=False,
|
||||
)
|
||||
|
@ -305,10 +305,8 @@ async def create_captioned_speech(
|
|||
# Convert to requested format with proper finalization
|
||||
final = await AudioService.convert_audio(
|
||||
AudioChunk(np.array([], dtype=np.int16)),
|
||||
24000,
|
||||
request.response_format,
|
||||
writer,
|
||||
is_first_chunk=False,
|
||||
is_last_chunk=True,
|
||||
)
|
||||
output=audio_data.output + final.output
|
||||
|
@ -316,6 +314,9 @@ async def create_captioned_speech(
|
|||
base64_output= base64.b64encode(output).decode("utf-8")
|
||||
|
||||
content=CaptionedSpeechResponse(audio=base64_output,audio_format=content_type,timestamps=audio_data.word_timestamps).model_dump()
|
||||
|
||||
writer.close()
|
||||
|
||||
return JSONResponse(
|
||||
content=content,
|
||||
media_type="application/json",
|
||||
|
@ -328,6 +329,12 @@ async def create_captioned_speech(
|
|||
except ValueError as e:
|
||||
# Handle validation errors
|
||||
logger.warning(f"Invalid request: {str(e)}")
|
||||
|
||||
try:
|
||||
writer.close()
|
||||
except:
|
||||
pass
|
||||
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail={
|
||||
|
@ -339,6 +346,12 @@ async def create_captioned_speech(
|
|||
except RuntimeError as e:
|
||||
# Handle runtime/processing errors
|
||||
logger.error(f"Processing error: {str(e)}")
|
||||
|
||||
try:
|
||||
writer.close()
|
||||
except:
|
||||
pass
|
||||
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail={
|
||||
|
@ -350,6 +363,12 @@ async def create_captioned_speech(
|
|||
except Exception as e:
|
||||
# Handle unexpected errors
|
||||
logger.error(f"Unexpected error in captioned speech generation: {str(e)}")
|
||||
|
||||
try:
|
||||
writer.close()
|
||||
except:
|
||||
pass
|
||||
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail={
|
||||
|
|
|
@ -240,6 +240,7 @@ async def create_speech(
|
|||
# Ensure temp writer is closed
|
||||
if not temp_writer._finalized:
|
||||
await temp_writer.__aexit__(None, None, None)
|
||||
writer.close()
|
||||
|
||||
# Stream with temp file writing
|
||||
return StreamingResponse(dual_output(), media_type=content_type, headers=headers)
|
||||
|
@ -252,6 +253,7 @@ async def create_speech(
|
|||
yield chunk_data.output
|
||||
except Exception as e:
|
||||
logger.error(f"Error in single output streaming: {e}")
|
||||
writer.close()
|
||||
raise
|
||||
|
||||
# Standard streaming without download link
|
||||
|
@ -281,15 +283,13 @@ async def create_speech(
|
|||
lang_code=request.lang_code,
|
||||
)
|
||||
|
||||
audio_data = await AudioService.convert_audio(audio_data, 24000, request.response_format, writer, is_first_chunk=True, is_last_chunk=False, trim_audio=False)
|
||||
audio_data = await AudioService.convert_audio(audio_data, request.response_format, writer, is_last_chunk=False, trim_audio=False)
|
||||
|
||||
# Convert to requested format with proper finalization
|
||||
final = await AudioService.convert_audio(
|
||||
AudioChunk(np.array([], dtype=np.int16)),
|
||||
24000,
|
||||
request.response_format,
|
||||
writer,
|
||||
is_first_chunk=False,
|
||||
is_last_chunk=True,
|
||||
)
|
||||
output = audio_data.output + final.output
|
||||
|
@ -321,6 +321,7 @@ async def create_speech(
|
|||
# Ensure temp writer is closed
|
||||
if not temp_writer._finalized:
|
||||
await temp_writer.__aexit__(None, None, None)
|
||||
writer.close()
|
||||
|
||||
return Response(
|
||||
content=output,
|
||||
|
@ -331,6 +332,12 @@ async def create_speech(
|
|||
except ValueError as e:
|
||||
# Handle validation errors
|
||||
logger.warning(f"Invalid request: {str(e)}")
|
||||
|
||||
try:
|
||||
writer.close()
|
||||
except:
|
||||
pass
|
||||
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail={
|
||||
|
@ -342,6 +349,12 @@ async def create_speech(
|
|||
except RuntimeError as e:
|
||||
# Handle runtime/processing errors
|
||||
logger.error(f"Processing error: {str(e)}")
|
||||
|
||||
try:
|
||||
writer.close()
|
||||
except:
|
||||
pass
|
||||
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail={
|
||||
|
@ -353,6 +366,12 @@ async def create_speech(
|
|||
except Exception as e:
|
||||
# Handle unexpected errors
|
||||
logger.error(f"Unexpected error in speech generation: {str(e)}")
|
||||
|
||||
try:
|
||||
writer.close()
|
||||
except:
|
||||
pass
|
||||
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail={
|
||||
|
@ -361,6 +380,7 @@ async def create_speech(
|
|||
"type": "server_error",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
|
||||
@router.get("/download/{filename}")
|
||||
|
|
|
@ -108,17 +108,13 @@ class AudioService:
|
|||
},
|
||||
}
|
||||
|
||||
_writers = {}
|
||||
|
||||
@staticmethod
|
||||
async def convert_audio(
|
||||
audio_chunk: AudioChunk,
|
||||
sample_rate: int,
|
||||
output_format: str,
|
||||
writer: StreamingAudioWriter,
|
||||
speed: float = 1,
|
||||
chunk_text: str = "",
|
||||
is_first_chunk: bool = True,
|
||||
is_last_chunk: bool = False,
|
||||
trim_audio: bool = True,
|
||||
normalizer: AudioNormalizer = None,
|
||||
|
@ -127,12 +123,12 @@ class AudioService:
|
|||
|
||||
Args:
|
||||
audio_data: Numpy array of audio samples
|
||||
sample_rate: Sample rate of the audio
|
||||
output_format: Target format (wav, mp3, ogg, pcm)
|
||||
writer: The StreamingAudioWriter to use
|
||||
speed: The speaking speed of the voice
|
||||
chunk_text: The text sent to the model to generate the resulting speech
|
||||
is_first_chunk: Whether this is the first chunk
|
||||
is_last_chunk: Whether this is the last chunk
|
||||
trim_audio: Whether audio should be trimmed
|
||||
normalizer: Optional AudioNormalizer instance for consistent normalization
|
||||
|
||||
Returns:
|
||||
|
@ -153,15 +149,6 @@ class AudioService:
|
|||
if trim_audio == True:
|
||||
audio_chunk = AudioService.trim_audio(audio_chunk,chunk_text,speed,is_last_chunk,normalizer)
|
||||
|
||||
# Get or create format-specific writer
|
||||
"""writer_key = f"{output_format}_{sample_rate}"
|
||||
if is_first_chunk or writer_key not in AudioService._writers:
|
||||
AudioService._writers[writer_key] = StreamingAudioWriter(
|
||||
output_format, sample_rate
|
||||
)
|
||||
|
||||
writer = AudioService._writers[writer_key]"""
|
||||
|
||||
# Write audio data first
|
||||
if len(audio_chunk.audio) > 0:
|
||||
chunk_data = writer.write_chunk(audio_chunk.audio)
|
||||
|
|
|
@ -22,7 +22,7 @@ class StreamingAudioWriter:
|
|||
|
||||
codec_map = {"wav":"pcm_s16le","mp3":"mp3","opus":"libopus","flac":"flac", "aac":"aac"}
|
||||
# Format-specific setup
|
||||
if self.format in ["wav", "opus","flac","mp3","aac","pcm"]:
|
||||
if self.format in ["wav","flac","mp3","pcm","aac","opus"]:
|
||||
if self.format != "pcm":
|
||||
self.output_buffer = BytesIO()
|
||||
self.container = av.open(self.output_buffer, mode="w", format=self.format)
|
||||
|
@ -31,6 +31,13 @@ class StreamingAudioWriter:
|
|||
else:
|
||||
raise ValueError(f"Unsupported format: {format}")
|
||||
|
||||
def close(self):
|
||||
if hasattr(self, "container"):
|
||||
self.container.close()
|
||||
|
||||
if hasattr(self, "output_buffer"):
|
||||
self.output_buffer.close()
|
||||
|
||||
def write_chunk(
|
||||
self, audio_data: Optional[np.ndarray] = None, finalize: bool = False
|
||||
) -> bytes:
|
||||
|
@ -48,7 +55,7 @@ class StreamingAudioWriter:
|
|||
self.container.mux(packet)
|
||||
|
||||
data=self.output_buffer.getvalue()
|
||||
self.container.close()
|
||||
self.close()
|
||||
return data
|
||||
|
||||
if audio_data is None or len(audio_data) == 0:
|
||||
|
|
|
@ -70,12 +70,10 @@ class TTSService:
|
|||
return
|
||||
chunk_data = await AudioService.convert_audio(
|
||||
AudioChunk(np.array([], dtype=np.float32)), # Dummy data for type checking
|
||||
24000,
|
||||
output_format,
|
||||
writer,
|
||||
speed,
|
||||
"",
|
||||
is_first_chunk=False,
|
||||
normalizer=normalizer,
|
||||
is_last_chunk=True,
|
||||
)
|
||||
|
@ -105,12 +103,10 @@ class TTSService:
|
|||
try:
|
||||
chunk_data = await AudioService.convert_audio(
|
||||
chunk_data,
|
||||
24000,
|
||||
output_format,
|
||||
writer,
|
||||
speed,
|
||||
chunk_text,
|
||||
is_first_chunk=is_first and chunk_index == 0,
|
||||
is_last_chunk=is_last,
|
||||
normalizer=normalizer,
|
||||
)
|
||||
|
@ -139,12 +135,10 @@ class TTSService:
|
|||
try:
|
||||
chunk_data = await AudioService.convert_audio(
|
||||
chunk_data,
|
||||
24000,
|
||||
output_format,
|
||||
writer,
|
||||
speed,
|
||||
chunk_text,
|
||||
is_first_chunk=is_first,
|
||||
normalizer=normalizer,
|
||||
is_last_chunk=is_last,
|
||||
)
|
||||
|
|
|
@ -70,16 +70,3 @@ def test_voice():
|
|||
"""Return a test voice name."""
|
||||
return "voice1"
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def event_loop():
|
||||
"""Create an instance of the default event loop for the test session."""
|
||||
import asyncio
|
||||
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
except RuntimeError:
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
yield loop
|
||||
loop.close()
|
||||
|
|
|
@ -34,8 +34,11 @@ async def test_convert_to_wav(sample_audio):
|
|||
writer = StreamingAudioWriter("wav", sample_rate=24000)
|
||||
# Write and finalize in one step for WAV
|
||||
audio_chunk = await AudioService.convert_audio(
|
||||
AudioChunk(audio_data), sample_rate, "wav", writer, is_first_chunk=True, is_last_chunk=False
|
||||
AudioChunk(audio_data), "wav", writer, is_last_chunk=False
|
||||
)
|
||||
|
||||
writer.close()
|
||||
|
||||
assert isinstance(audio_chunk.output, bytes)
|
||||
assert isinstance(audio_chunk, AudioChunk)
|
||||
assert len(audio_chunk.output) > 0
|
||||
|
@ -52,8 +55,11 @@ async def test_convert_to_mp3(sample_audio):
|
|||
writer = StreamingAudioWriter("mp3", sample_rate=24000)
|
||||
|
||||
audio_chunk = await AudioService.convert_audio(
|
||||
AudioChunk(audio_data), sample_rate, "mp3", writer
|
||||
AudioChunk(audio_data), "mp3", writer
|
||||
)
|
||||
|
||||
writer.close()
|
||||
|
||||
assert isinstance(audio_chunk.output, bytes)
|
||||
assert isinstance(audio_chunk, AudioChunk)
|
||||
assert len(audio_chunk.output) > 0
|
||||
|
@ -64,13 +70,17 @@ async def test_convert_to_mp3(sample_audio):
|
|||
@pytest.mark.asyncio
|
||||
async def test_convert_to_opus(sample_audio):
|
||||
"""Test converting to Opus format"""
|
||||
|
||||
audio_data, sample_rate = sample_audio
|
||||
|
||||
writer = StreamingAudioWriter("opus", sample_rate=24000)
|
||||
|
||||
audio_chunk = await AudioService.convert_audio(
|
||||
AudioChunk(audio_data), sample_rate, "opus",writer
|
||||
AudioChunk(audio_data), "opus",writer
|
||||
)
|
||||
|
||||
writer.close()
|
||||
|
||||
assert isinstance(audio_chunk.output, bytes)
|
||||
assert isinstance(audio_chunk, AudioChunk)
|
||||
assert len(audio_chunk.output) > 0
|
||||
|
@ -86,8 +96,11 @@ async def test_convert_to_flac(sample_audio):
|
|||
writer = StreamingAudioWriter("flac", sample_rate=24000)
|
||||
|
||||
audio_chunk = await AudioService.convert_audio(
|
||||
AudioChunk(audio_data), sample_rate, "flac", writer
|
||||
AudioChunk(audio_data), "flac", writer
|
||||
)
|
||||
|
||||
writer.close()
|
||||
|
||||
assert isinstance(audio_chunk.output, bytes)
|
||||
assert isinstance(audio_chunk, AudioChunk)
|
||||
assert len(audio_chunk.output) > 0
|
||||
|
@ -103,8 +116,11 @@ async def test_convert_to_aac(sample_audio):
|
|||
writer = StreamingAudioWriter("aac", sample_rate=24000)
|
||||
|
||||
audio_chunk = await AudioService.convert_audio(
|
||||
AudioChunk(audio_data), sample_rate, "aac", writer
|
||||
AudioChunk(audio_data), "aac", writer
|
||||
)
|
||||
|
||||
writer.close()
|
||||
|
||||
assert isinstance(audio_chunk.output, bytes)
|
||||
assert isinstance(audio_chunk, AudioChunk)
|
||||
assert len(audio_chunk.output) > 0
|
||||
|
@ -120,8 +136,11 @@ async def test_convert_to_pcm(sample_audio):
|
|||
writer = StreamingAudioWriter("pcm", sample_rate=24000)
|
||||
|
||||
audio_chunk = await AudioService.convert_audio(
|
||||
AudioChunk(audio_data), sample_rate, "pcm", writer
|
||||
AudioChunk(audio_data), "pcm", writer
|
||||
)
|
||||
|
||||
writer.close()
|
||||
|
||||
assert isinstance(audio_chunk.output, bytes)
|
||||
assert isinstance(audio_chunk, AudioChunk)
|
||||
assert len(audio_chunk.output) > 0
|
||||
|
@ -131,12 +150,9 @@ async def test_convert_to_pcm(sample_audio):
|
|||
@pytest.mark.asyncio
|
||||
async def test_convert_to_invalid_format_raises_error(sample_audio):
|
||||
"""Test that converting to an invalid format raises an error"""
|
||||
audio_data, sample_rate = sample_audio
|
||||
#audio_data, sample_rate = sample_audio
|
||||
with pytest.raises(ValueError, match="Unsupported format: invalid"):
|
||||
StreamingAudioWriter("invalid", sample_rate=24000)
|
||||
|
||||
with pytest.raises(ValueError, match="Format invalid not supported"):
|
||||
await AudioService.convert_audio(audio_data, sample_rate, "invalid", None)
|
||||
writer = StreamingAudioWriter("invalid", sample_rate=24000)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
@ -150,8 +166,11 @@ async def test_normalization_wav(sample_audio):
|
|||
large_audio = audio_data * 1e5
|
||||
# Write and finalize in one step for WAV
|
||||
audio_chunk = await AudioService.convert_audio(
|
||||
AudioChunk(large_audio), sample_rate, "wav", writer, is_first_chunk=True
|
||||
AudioChunk(large_audio), "wav", writer
|
||||
)
|
||||
|
||||
writer.close()
|
||||
|
||||
assert isinstance(audio_chunk.output, bytes)
|
||||
assert isinstance(audio_chunk, AudioChunk)
|
||||
assert len(audio_chunk.output) > 0
|
||||
|
@ -167,7 +186,7 @@ async def test_normalization_pcm(sample_audio):
|
|||
# Create audio data outside int16 range
|
||||
large_audio = audio_data * 1e5
|
||||
audio_chunk = await AudioService.convert_audio(
|
||||
AudioChunk(large_audio), sample_rate, "pcm", writer
|
||||
AudioChunk(large_audio), "pcm", writer
|
||||
)
|
||||
assert isinstance(audio_chunk.output, bytes)
|
||||
assert isinstance(audio_chunk, AudioChunk)
|
||||
|
@ -197,8 +216,11 @@ async def test_different_sample_rates(sample_audio):
|
|||
writer = StreamingAudioWriter("wav", sample_rate=rate)
|
||||
|
||||
audio_chunk = await AudioService.convert_audio(
|
||||
AudioChunk(audio_data), rate, "wav", writer, is_first_chunk=True
|
||||
AudioChunk(audio_data), "wav", writer
|
||||
)
|
||||
|
||||
writer.close()
|
||||
|
||||
assert isinstance(audio_chunk.output, bytes)
|
||||
assert isinstance(audio_chunk, AudioChunk)
|
||||
assert len(audio_chunk.output) > 0
|
||||
|
@ -213,7 +235,7 @@ async def test_buffer_position_after_conversion(sample_audio):
|
|||
|
||||
# Write and finalize in one step for first conversion
|
||||
audio_chunk1 = await AudioService.convert_audio(
|
||||
AudioChunk(audio_data), sample_rate, "wav", writer, is_first_chunk=True, is_last_chunk=True
|
||||
AudioChunk(audio_data), "wav", writer, is_last_chunk=True
|
||||
)
|
||||
assert isinstance(audio_chunk1.output, bytes)
|
||||
assert isinstance(audio_chunk1, AudioChunk)
|
||||
|
@ -222,7 +244,7 @@ async def test_buffer_position_after_conversion(sample_audio):
|
|||
writer = StreamingAudioWriter("wav", sample_rate=24000)
|
||||
|
||||
audio_chunk2 = await AudioService.convert_audio(
|
||||
AudioChunk(audio_data), sample_rate, "wav", writer, is_first_chunk=True, is_last_chunk=True
|
||||
AudioChunk(audio_data), "wav", writer, is_last_chunk=True
|
||||
)
|
||||
assert isinstance(audio_chunk2.output, bytes)
|
||||
assert isinstance(audio_chunk2, AudioChunk)
|
||||
|
|
|
@ -4,6 +4,8 @@ import os
|
|||
from typing import AsyncGenerator, Tuple
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
from api.src.services.streaming_audio_writer import StreamingAudioWriter
|
||||
|
||||
from api.src.inference.base import AudioChunk
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
@ -159,10 +161,14 @@ async def test_stream_audio_chunks_client_disconnect():
|
|||
speed=1.0,
|
||||
)
|
||||
|
||||
writer = StreamingAudioWriter("mp3", 24000)
|
||||
|
||||
chunks = []
|
||||
async for chunk in stream_audio_chunks(mock_service, request, mock_request):
|
||||
async for chunk in stream_audio_chunks(mock_service, request, mock_request, writer):
|
||||
chunks.append(chunk)
|
||||
|
||||
writer.close()
|
||||
|
||||
assert len(chunks) == 0 # Should stop immediately due to disconnect
|
||||
|
||||
|
||||
|
@ -483,7 +489,11 @@ async def test_streaming_initialization_error():
|
|||
speed=1.0,
|
||||
)
|
||||
|
||||
writer = StreamingAudioWriter("mp3", 24000)
|
||||
|
||||
with pytest.raises(RuntimeError) as exc:
|
||||
async for _ in stream_audio_chunks(mock_service, request, MagicMock()):
|
||||
async for _ in stream_audio_chunks(mock_service, request, MagicMock(), writer):
|
||||
pass
|
||||
|
||||
writer.close()
|
||||
assert "Failed to initialize stream" in str(exc.value)
|
||||
|
|
|
@ -37,7 +37,7 @@ dependencies = [
|
|||
"en-core-web-sm @ https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.7.1/en_core_web_sm-3.7.1-py3-none-any.whl",
|
||||
"inflect>=7.5.0",
|
||||
"phonemizer-fork>=3.3.2",
|
||||
"av>=14.1.0",
|
||||
"av>=14.2.0",
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
|
@ -48,10 +48,10 @@ cpu = [
|
|||
"torch==2.6.0",
|
||||
]
|
||||
test = [
|
||||
"pytest==8.0.0",
|
||||
"pytest-cov==4.1.0",
|
||||
"pytest==8.3.5",
|
||||
"pytest-cov==6.0.0",
|
||||
"httpx==0.26.0",
|
||||
"pytest-asyncio==0.23.5",
|
||||
"pytest-asyncio==0.25.3",
|
||||
"openai>=1.59.6",
|
||||
"tomli>=2.0.1",
|
||||
]
|
||||
|
@ -91,5 +91,5 @@ packages.find = {where = ["api/src"], namespaces = true}
|
|||
[tool.pytest.ini_options]
|
||||
testpaths = ["api/tests", "ui/tests"]
|
||||
python_files = ["test_*.py"]
|
||||
addopts = "--cov=api --cov=ui --cov-report=term-missing --cov-config=.coveragerc"
|
||||
asyncio_mode = "strict"
|
||||
addopts = "--cov=api --cov=ui --cov-report=term-missing --cov-config=.coveragerc --full-trace"
|
||||
asyncio_mode = "auto"
|
||||
|
|
Loading…
Add table
Reference in a new issue