Mirror of https://github.com/remsky/Kokoro-FastAPI.git (synced 2025-08-05 16:48:53 +00:00)

Commit afa879546c (parent 447f9d360c): CONTRIBUTING + Ruff format
37 changed files with 1253 additions and 588 deletions

CONTRIBUTING.md (new file, 74 lines)
@@ -0,0 +1,74 @@
# Contributing to Kokoro-FastAPI

We always appreciate community involvement in making this project better.

## Development Setup

We use `uv` for managing Python environments and dependencies, and `ruff` for linting and formatting.

1. **Clone the repository:**

   ```bash
   git clone https://github.com/remsky/Kokoro-FastAPI.git
   cd Kokoro-FastAPI
   ```

2. **Install `uv`:**

   Follow the instructions in the [official `uv` documentation](https://docs.astral.sh/uv/install/).

3. **Create a virtual environment and install dependencies:**

   It's recommended to use a virtual environment; `uv` can create one for you. Install the base dependencies along with the `test` and `cpu` extras (needed for running tests locally).

   ```bash
   # Create and activate a virtual environment (e.g., named .venv)
   uv venv
   source .venv/bin/activate  # On Linux/macOS
   # .venv\Scripts\activate   # On Windows

   # Install dependencies including test requirements
   uv pip install -e ".[test,cpu]"
   ```

   *Note: If you have an NVIDIA GPU and want to test GPU-specific features locally, you can install `.[test,gpu]` instead, ensuring you have the correct CUDA toolkit installed.*

   *Note: If running via uv locally, you will have to install espeak and handle any pathing issues that arise. The Docker images handle this automatically.*

4. **Install `ruff` (if not already installed globally):**

   While `ruff` might be included via dependencies, installing it explicitly ensures you have it available.

   ```bash
   uv pip install ruff
   ```

## Running Tests

Before submitting changes, please ensure all tests pass, as this is an automated requirement. The tests are run using `pytest`.

```bash
# Make sure your virtual environment is activated
uv run pytest
```

*Note: The CI workflow runs tests using `uv run pytest api/tests/ --asyncio-mode=auto --cov=api --cov-report=term-missing`. Running `uv run pytest` locally should cover the essential checks.*

## Code Formatting and Linting

We use `ruff` to maintain code quality and consistency. Please format and lint your code before committing.

1. **Format the code:**

   ```bash
   # Make sure your virtual environment is activated
   ruff format .
   ```

2. **Lint the code (and apply automatic fixes):**

   ```bash
   # Make sure your virtual environment is activated
   ruff check . --fix
   ```

   Review any changes made by `--fix` and address any remaining linting errors manually.

## Submitting Changes

0. Clone the repo.
1. Create a new branch for your feature or bug fix.
2. Make your changes, following the setup, testing, and formatting guidelines above.
3. Please try to keep your changes in line with the current design, and keep them modular. Large-scale changes take longer to review and integrate, and have less chance of being approved outright.
4. Push your branch to your fork.
5. Open a Pull Request against the `master` branch of the main repository.

Thank you for contributing!
@@ -14,9 +14,13 @@ class Settings(BaseSettings):
     output_dir: str = "output"
     output_dir_size_limit_mb: float = 500.0  # Maximum size of output directory in MB
     default_voice: str = "af_heart"
-    default_voice_code: str | None = None  # If set, overrides the first letter of voice name, though api call param still takes precedence
+    default_voice_code: str | None = (
+        None  # If set, overrides the first letter of voice name, though api call param still takes precedence
+    )
     use_gpu: bool = True  # Whether to use GPU acceleration if available
-    device_type: str | None = None  # Will be auto-detected if None, can be "cuda", "mps", or "cpu"
+    device_type: str | None = (
+        None  # Will be auto-detected if None, can be "cuda", "mps", or "cpu"
+    )
     allow_local_voice_saving: bool = (
         False  # Whether to allow saving combined voices locally
     )
@@ -32,11 +36,20 @@ class Settings(BaseSettings):
     target_max_tokens: int = 250  # Target maximum tokens per chunk
     absolute_max_tokens: int = 450  # Absolute maximum tokens per chunk
     advanced_text_normalization: bool = True  # Preproesses the text before misiki
-    voice_weight_normalization: bool = True  # Normalize the voice weights so they add up to 1
+    voice_weight_normalization: bool = (
+        True  # Normalize the voice weights so they add up to 1
+    )

-    gap_trim_ms: int = 1  # Base amount to trim from streaming chunk ends in milliseconds
+    gap_trim_ms: int = (
+        1  # Base amount to trim from streaming chunk ends in milliseconds
+    )
     dynamic_gap_trim_padding_ms: int = 410  # Padding to add to dynamic gap trim
-    dynamic_gap_trim_padding_char_multiplier: dict[str,float] = {".":1,"!":0.9,"?":1,",":0.8}
+    dynamic_gap_trim_padding_char_multiplier: dict[str, float] = {
+        ".": 1,
+        "!": 0.9,
+        "?": 1,
+        ",": 0.8,
+    }

     # Web Player Settings
     enable_web_player: bool = True  # Whether to serve the web player UI
@@ -69,5 +82,4 @@ class Settings(BaseSettings):
         return "cpu"


-
 settings = Settings()
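Because `Settings` extends pydantic's `BaseSettings`, fields like these are typically overridable through environment variables. The snippet below is a minimal, self-contained sketch of that pattern; it re-declares two fields locally, assumes the `pydantic_settings` package, and the project's actual env prefix or `.env` handling may differ.

```python
import os

from pydantic_settings import BaseSettings  # assumption: pydantic v2 style package


class DemoSettings(BaseSettings):
    # Two fields mirrored from the Settings class above, for illustration only.
    default_voice: str = "af_heart"
    use_gpu: bool = True


# Field names map to environment variables (case-insensitive by default).
os.environ["DEFAULT_VOICE"] = "af_sky"
os.environ["USE_GPU"] = "false"

print(DemoSettings())  # default_voice='af_sky' use_gpu=False
```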
@@ -10,10 +10,11 @@ import torch
 class AudioChunk:
     """Class for audio chunks returned by model backends"""

-    def __init__(self,
+    def __init__(
+        self,
         audio: np.ndarray,
         word_timestamps: Optional[List] = [],
-        output: Optional[Union[bytes,np.ndarray]]=b""
+        output: Optional[Union[bytes, np.ndarray]] = b"",
     ):
         self.audio = audio
         self.word_timestamps = word_timestamps
@@ -21,15 +22,20 @@ class AudioChunk:

     @staticmethod
     def combine(audio_chunk_list: List):
-        output=AudioChunk(audio_chunk_list[0].audio,audio_chunk_list[0].word_timestamps)
+        output = AudioChunk(
+            audio_chunk_list[0].audio, audio_chunk_list[0].word_timestamps
+        )

         for audio_chunk in audio_chunk_list[1:]:
-            output.audio=np.concatenate((output.audio,audio_chunk.audio),dtype=np.int16)
+            output.audio = np.concatenate(
+                (output.audio, audio_chunk.audio), dtype=np.int16
+            )
             if output.word_timestamps is not None:
                 output.word_timestamps += audio_chunk.word_timestamps

         return output


 class ModelBackend(ABC):
     """Abstract base class for model inference backend."""
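For intuition, here is a small self-contained sketch (not the project's module; a stand-in class is declared locally) of how chunks like these are concatenated the way `AudioChunk.combine` does: int16 audio arrays joined end to end and word timestamp lists appended.

```python
import numpy as np


class Chunk:
    """Minimal stand-in for AudioChunk, for illustration only."""

    def __init__(self, audio, word_timestamps=None):
        self.audio = audio
        self.word_timestamps = word_timestamps or []


def combine(chunks):
    out = Chunk(chunks[0].audio, list(chunks[0].word_timestamps))
    for c in chunks[1:]:
        out.audio = np.concatenate((out.audio, c.audio), dtype=np.int16)
        out.word_timestamps += c.word_timestamps
    return out


a = Chunk(np.zeros(24000, dtype=np.int16), [{"word": "hello", "start_time": 0.0}])
b = Chunk(np.ones(12000, dtype=np.int16), [{"word": "world", "start_time": 1.0}])
print(combine([a, b]).audio.shape)  # (36000,)
```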
@@ -51,7 +51,9 @@ class KokoroV1(BaseModelBackend):
         self._model = KModel(config=config_path, model=model_path).eval()
         # For MPS, manually move ISTFT layers to CPU while keeping rest on MPS
         if self._device == "mps":
-            logger.info("Moving model to MPS device with CPU fallback for unsupported operations")
+            logger.info(
+                "Moving model to MPS device with CPU fallback for unsupported operations"
+            )
             self._model = self._model.to(torch.device("mps"))
         elif self._device == "cuda":
             self._model = self._model.cuda()
@@ -245,7 +247,15 @@ class KokoroV1(BaseModelBackend):
             voice_path = temp_path

         # Use provided lang_code, settings voice code override, or first letter of voice name
-        pipeline_lang_code = lang_code if lang_code else (settings.default_voice_code if settings.default_voice_code else voice_name[0].lower())
+        pipeline_lang_code = (
+            lang_code
+            if lang_code
+            else (
+                settings.default_voice_code
+                if settings.default_voice_code
+                else voice_name[0].lower()
+            )
+        )
         pipeline = self._get_pipeline(pipeline_lang_code)

         logger.debug(
@@ -257,7 +267,11 @@ class KokoroV1(BaseModelBackend):
             if result.audio is not None:
                 logger.debug(f"Got audio chunk with shape: {result.audio.shape}")
                 word_timestamps = None
-                if return_timestamps and hasattr(result, "tokens") and result.tokens:
+                if (
+                    return_timestamps
+                    and hasattr(result, "tokens")
+                    and result.tokens
+                ):
                     word_timestamps = []
                     current_offset = 0.0
                     logger.debug(
@@ -265,7 +279,6 @@ class KokoroV1(BaseModelBackend):
                     )
                     if result.pred_dur is not None:
                         try:
-
                             # Add timestamps with offset
                             for token in result.tokens:
                                 if not all(
@@ -286,7 +299,7 @@ class KokoroV1(BaseModelBackend):
                                     WordTimestamp(
                                         word=str(token.text).strip(),
                                         start_time=start_time,
-                                        end_time=end_time
+                                        end_time=end_time,
                                     )
                                 )
                             logger.debug(
@@ -298,8 +311,9 @@ class KokoroV1(BaseModelBackend):
                                 f"Failed to process timestamps for chunk: {e}"
                             )

-                yield AudioChunk(result.audio.numpy(),word_timestamps=word_timestamps)
+                yield AudioChunk(
+                    result.audio.numpy(), word_timestamps=word_timestamps
+                )
             else:
                 logger.warning("No audio in chunk")
@@ -330,7 +344,7 @@ class KokoroV1(BaseModelBackend):
                 torch.cuda.synchronize()
             elif self._device == "mps":
                 # Empty cache if available (future-proofing)
-                if hasattr(torch.mps, 'empty_cache'):
+                if hasattr(torch.mps, "empty_cache"):
                     torch.mps.empty_cache()

     def unload(self) -> None:
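The `pipeline_lang_code` expression above is just a chained fallback: the request parameter wins, then the configured override, then the first letter of the voice name. A rough standalone equivalent (simplified names, not the project's function) looks like this.

```python
def resolve_lang_code(lang_code, default_voice_code, voice_name):
    """Request parameter wins, then the configured override,
    then the voice name's first letter (e.g. 'af_heart' -> 'a')."""
    if lang_code:
        return lang_code
    if default_voice_code:
        return default_voice_code
    return voice_name[0].lower()


assert resolve_lang_code(None, None, "af_heart") == "a"
assert resolve_lang_code("b", None, "af_heart") == "b"
assert resolve_lang_code(None, "e", "af_heart") == "e"
```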
@@ -119,7 +119,7 @@ async def get_system_info():
                 "type": "MPS",
                 "available": True,
                 "device": "Apple Silicon",
-                "backend": "Metal"
+                "backend": "Metal",
             }
         elif GPU_AVAILABLE:
             try:
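For reference, device availability of the kind reported here (MPS vs. CUDA vs. CPU) is usually probed with torch's built-in checks; the following is a minimal sketch independent of this codebase.

```python
import torch


def detect_device() -> str:
    # Prefer CUDA, then Apple's Metal backend (MPS), then CPU.
    if torch.cuda.is_available():
        return "cuda"
    if getattr(torch.backends, "mps", None) and torch.backends.mps.is_available():
        return "mps"
    return "cpu"


print(detect_device())
```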
@@ -156,6 +156,7 @@ async def generate_from_phonemes(
         },
     )

+
 @router.post("/dev/captioned_speech")
 async def create_captioned_speech(
     request: CaptionedSpeechRequest,
@@ -184,7 +185,9 @@ async def create_captioned_speech(
     # Check if streaming is requested (default for OpenAI client)
     if request.stream:
         # Create generator but don't start it yet
-        generator = stream_audio_chunks(tts_service, request, client_request, writer)
+        generator = stream_audio_chunks(
+            tts_service, request, client_request, writer
+        )

         # If download link requested, wrap generator with temp file writer
         if request.return_download_link:
@@ -215,15 +218,26 @@ async def create_captioned_speech(
                     if chunk_data.output:  # Skip empty chunks
                         await temp_writer.write(chunk_data.output)
-                        base64_chunk= base64.b64encode(chunk_data.output).decode("utf-8")
+                        base64_chunk = base64.b64encode(
+                            chunk_data.output
+                        ).decode("utf-8")

                         # Add any chunks that may be in the acumulator into the return word_timestamps
-                        chunk_data.word_timestamps=timestamp_acumulator + chunk_data.word_timestamps
+                        chunk_data.word_timestamps = (
+                            timestamp_acumulator + chunk_data.word_timestamps
+                        )
                         timestamp_acumulator = []

-                        yield CaptionedSpeechResponse(audio=base64_chunk,audio_format=content_type,timestamps=chunk_data.word_timestamps)
+                        yield CaptionedSpeechResponse(
+                            audio=base64_chunk,
+                            audio_format=content_type,
+                            timestamps=chunk_data.word_timestamps,
+                        )
                     else:
-                        if chunk_data.word_timestamps is not None and len(chunk_data.word_timestamps) > 0:
+                        if (
+                            chunk_data.word_timestamps is not None
+                            and len(chunk_data.word_timestamps) > 0
+                        ):
                             timestamp_acumulator += chunk_data.word_timestamps

                 # Finalize the temp file
@@ -252,18 +266,29 @@ async def create_captioned_speech(
                 async for chunk_data in generator:
                     if chunk_data.output:  # Skip empty chunks
                         # Encode the chunk bytes into base 64
-                        base64_chunk= base64.b64encode(chunk_data.output).decode("utf-8")
+                        base64_chunk = base64.b64encode(chunk_data.output).decode(
+                            "utf-8"
+                        )

                         # Add any chunks that may be in the acumulator into the return word_timestamps
                         if chunk_data.word_timestamps != None:
-                            chunk_data.word_timestamps = timestamp_acumulator + chunk_data.word_timestamps
+                            chunk_data.word_timestamps = (
+                                timestamp_acumulator + chunk_data.word_timestamps
+                            )
                         else:
                             chunk_data.word_timestamps = []
                         timestamp_acumulator = []

-                        yield CaptionedSpeechResponse(audio=base64_chunk,audio_format=content_type,timestamps=chunk_data.word_timestamps)
+                        yield CaptionedSpeechResponse(
+                            audio=base64_chunk,
+                            audio_format=content_type,
+                            timestamps=chunk_data.word_timestamps,
+                        )
                     else:
-                        if chunk_data.word_timestamps is not None and len(chunk_data.word_timestamps) > 0:
+                        if (
+                            chunk_data.word_timestamps is not None
+                            and len(chunk_data.word_timestamps) > 0
+                        ):
                             timestamp_acumulator += chunk_data.word_timestamps

             except Exception as e:
@@ -313,7 +338,11 @@ async def create_captioned_speech(

         base64_output = base64.b64encode(output).decode("utf-8")

-        content=CaptionedSpeechResponse(audio=base64_output,audio_format=content_type,timestamps=audio_data.word_timestamps).model_dump()
+        content = CaptionedSpeechResponse(
+            audio=base64_output,
+            audio_format=content_type,
+            timestamps=audio_data.word_timestamps,
+        ).model_dump()

         writer.close()
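Each streamed captioned-speech item therefore carries base64-encoded audio plus word timestamps. A hedged client-side sketch for decoding one such object follows; the field names are taken from the response model above, but the exact wire framing (one JSON object per line is assumed here) is an assumption.

```python
import base64
import json


def handle_stream_item(line: str) -> bytes:
    """Decode one CaptionedSpeechResponse-style JSON object from a stream.

    Assumes keys 'audio' (base64 string), 'audio_format' and 'timestamps',
    matching the response model above; framing of the stream is assumed.
    """
    msg = json.loads(line)
    audio_bytes = base64.b64decode(msg["audio"])
    for ts in msg.get("timestamps") or []:
        print(ts.get("word"), ts.get("start_time"), ts.get("end_time"))
    return audio_bytes
```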
@@ -80,7 +80,9 @@ def get_model_name(model: str) -> str:
     return base_name + ".pth"


-async def process_and_validate_voices(voice_input: Union[str, List[str]], tts_service: TTSService) -> str:
+async def process_and_validate_voices(
+    voice_input: Union[str, List[str]], tts_service: TTSService
+) -> str:
     """Process voice input, handling both string and list formats

     Returns:
@@ -107,22 +109,35 @@ async def process_and_validate_voices(
             mapped_voice = list(map(str.strip, mapped_voice))

             if len(mapped_voice) > 2:
-                raise ValueError(f"Voice '{voices[voice_index]}' contains too many weight items")
+                raise ValueError(
+                    f"Voice '{voices[voice_index]}' contains too many weight items"
+                )

             if mapped_voice.count(")") > 1:
-                raise ValueError(f"Voice '{voices[voice_index]}' contains too many weight items")
+                raise ValueError(
+                    f"Voice '{voices[voice_index]}' contains too many weight items"
+                )

-            mapped_voice[0] = _openai_mappings["voices"].get(mapped_voice[0], mapped_voice[0])
+            mapped_voice[0] = _openai_mappings["voices"].get(
+                mapped_voice[0], mapped_voice[0]
+            )

             if mapped_voice[0] not in available_voices:
-                raise ValueError(f"Voice '{mapped_voice[0]}' not found. Available voices: {', '.join(sorted(available_voices))}")
+                raise ValueError(
+                    f"Voice '{mapped_voice[0]}' not found. Available voices: {', '.join(sorted(available_voices))}"
+                )

             voices[voice_index] = "(".join(mapped_voice)

     return "".join(voices)


-async def stream_audio_chunks(tts_service: TTSService, request: Union[OpenAISpeechRequest, CaptionedSpeechRequest], client_request: Request, writer: StreamingAudioWriter) -> AsyncGenerator[AudioChunk, None]:
+async def stream_audio_chunks(
+    tts_service: TTSService,
+    request: Union[OpenAISpeechRequest, CaptionedSpeechRequest],
+    client_request: Request,
+    writer: StreamingAudioWriter,
+) -> AsyncGenerator[AudioChunk, None]:
     """Stream audio chunks as they're generated with client disconnect handling"""
     voice_name = await process_and_validate_voices(request.voice, tts_service)
     unique_properties = {"return_timestamps": False}
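The weighted-voice syntax validated above ("voice(weight)" tokens joined with "+" or "-") can be pictured with a small standalone parser. This is only an illustrative sketch, not the project's parsing code.

```python
import re


def parse_weighted_voices(spec: str) -> list[tuple[str, str, float]]:
    """Parse strings like 'af_heart(1.2)+af_sky(0.8)' into (op, voice, weight)."""
    result = []
    for op, token in re.findall(r"([+-]?)([^+-]+)", spec):
        match = re.fullmatch(r"(\w+)(?:\((\d+(?:\.\d+)?)\))?", token.strip())
        voice, weight = match.group(1), float(match.group(2) or 1.0)
        result.append((op or "+", voice, weight))
    return result


print(parse_weighted_voices("af_heart(1.2)+af_sky(0.8)"))
# [('+', 'af_heart', 1.2), ('+', 'af_sky', 0.8)]
```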
@@ -193,7 +208,9 @@ async def create_speech(
     # Check if streaming is requested (default for OpenAI client)
     if request.stream:
         # Create generator but don't start it yet
-        generator = stream_audio_chunks(tts_service, request, client_request, writer)
+        generator = stream_audio_chunks(
+            tts_service, request, client_request, writer
+        )

         # If download link requested, wrap generator with temp file writer
         if request.return_download_link:
@@ -245,7 +262,9 @@ async def create_speech(
                     writer.close()

             # Stream with temp file writing
-            return StreamingResponse(dual_output(), media_type=content_type, headers=headers)
+            return StreamingResponse(
+                dual_output(), media_type=content_type, headers=headers
+            )

         async def single_output():
             try:
@@ -285,7 +304,13 @@ async def create_speech(
             lang_code=request.lang_code,
         )

-        audio_data = await AudioService.convert_audio(audio_data, request.response_format, writer, is_last_chunk=False, trim_audio=False)
+        audio_data = await AudioService.convert_audio(
+            audio_data,
+            request.response_format,
+            writer,
+            is_last_chunk=False,
+            trim_audio=False,
+        )

         # Convert to requested format with proper finalization
         final = await AudioService.convert_audio(
@@ -384,7 +409,6 @@ async def create_speech(
     )


-
 @router.get("/download/{filename}")
 async def download_audio_file(filename: str):
     """Download a generated audio file from temp storage"""
@@ -392,7 +416,9 @@ async def download_audio_file(filename: str):
     from ..core.paths import _find_file, get_content_type

     # Search for file in temp directory
-    file_path = await _find_file(filename=filename, search_paths=[settings.temp_file_dir])
+    file_path = await _find_file(
+        filename=filename, search_paths=[settings.temp_file_dir]
+    )

     # Get content type from path helper
     content_type = await get_content_type(file_path)
@@ -425,9 +451,24 @@ async def list_models():
     try:
         # Create standard model list
         models = [
-            {"id": "tts-1", "object": "model", "created": 1686935002, "owned_by": "kokoro"},
-            {"id": "tts-1-hd", "object": "model", "created": 1686935002, "owned_by": "kokoro"},
-            {"id": "kokoro", "object": "model", "created": 1686935002, "owned_by": "kokoro"},
+            {
+                "id": "tts-1",
+                "object": "model",
+                "created": 1686935002,
+                "owned_by": "kokoro",
+            },
+            {
+                "id": "tts-1-hd",
+                "object": "model",
+                "created": 1686935002,
+                "owned_by": "kokoro",
+            },
+            {
+                "id": "kokoro",
+                "object": "model",
+                "created": 1686935002,
+                "owned_by": "kokoro",
+            },
         ]

         return {"object": "list", "data": models}
@@ -449,14 +490,36 @@ async def retrieve_model(model: str):
     try:
         # Define available models
         models = {
-            "tts-1": {"id": "tts-1", "object": "model", "created": 1686935002, "owned_by": "kokoro"},
-            "tts-1-hd": {"id": "tts-1-hd", "object": "model", "created": 1686935002, "owned_by": "kokoro"},
-            "kokoro": {"id": "kokoro", "object": "model", "created": 1686935002, "owned_by": "kokoro"},
+            "tts-1": {
+                "id": "tts-1",
+                "object": "model",
+                "created": 1686935002,
+                "owned_by": "kokoro",
+            },
+            "tts-1-hd": {
+                "id": "tts-1-hd",
+                "object": "model",
+                "created": 1686935002,
+                "owned_by": "kokoro",
+            },
+            "kokoro": {
+                "id": "kokoro",
+                "object": "model",
+                "created": 1686935002,
+                "owned_by": "kokoro",
+            },
         }

         # Check if requested model exists
         if model not in models:
-            raise HTTPException(status_code=404, detail={"error": "model_not_found", "message": f"Model '{model}' not found", "type": "invalid_request_error"})
+            raise HTTPException(
+                status_code=404,
+                detail={
+                    "error": "model_not_found",
+                    "message": f"Model '{model}' not found",
+                    "type": "invalid_request_error",
+                },
+            )

         # Return the specific model
         return models[model]
@@ -541,7 +604,9 @@ async def combine_voices(request: Union[str, List[str]]):
         available_voices = await tts_service.list_voices()
         for voice in voices:
             if voice not in available_voices:
-                raise ValueError(f"Base voice '{voice}' not found. Available voices: {', '.join(sorted(available_voices))}")
+                raise ValueError(
+                    f"Base voice '{voice}' not found. Available voices: {', '.join(sorted(available_voices))}"
+                )

         # Combine voices
         combined_tensor = await tts_service.combine_voices(voices=voices)
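Since this router mirrors OpenAI's speech endpoints, a client-side sketch using the official `openai` Python package might look like the following; the base URL, port and API key are assumptions (point them at wherever this server actually runs).

```python
from openai import OpenAI

# Assumed local server address; the API key value is typically ignored locally.
client = OpenAI(base_url="http://localhost:8880/v1", api_key="not-needed")

with client.audio.speech.with_streaming_response.create(
    model="kokoro",
    voice="af_heart",
    input="Hello from Kokoro-FastAPI!",
    response_format="mp3",
) as response:
    response.stream_to_file("hello.mp3")
```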
@@ -27,7 +27,14 @@ class AudioNormalizer:
         self.samples_to_trim = int(self.chunk_trim_ms * self.sample_rate / 1000)
         self.samples_to_pad_start = int(50 * self.sample_rate / 1000)

-    def find_first_last_non_silent(self,audio_data: np.ndarray, chunk_text: str, speed: float, silence_threshold_db: int = -45, is_last_chunk: bool = False) -> tuple[int, int]:
+    def find_first_last_non_silent(
+        self,
+        audio_data: np.ndarray,
+        chunk_text: str,
+        speed: float,
+        silence_threshold_db: int = -45,
+        is_last_chunk: bool = False,
+    ) -> tuple[int, int]:
         """Finds the indices of the first and last non-silent samples in audio data.

         Args:
@@ -46,14 +53,29 @@ class AudioNormalizer:
         if len(split_character) > 0:
             split_character = split_character[-1]
             if split_character in settings.dynamic_gap_trim_padding_char_multiplier:
-                pad_multiplier=settings.dynamic_gap_trim_padding_char_multiplier[split_character]
+                pad_multiplier = settings.dynamic_gap_trim_padding_char_multiplier[
+                    split_character
+                ]

         if not is_last_chunk:
-            samples_to_pad_end= max(int((settings.dynamic_gap_trim_padding_ms * self.sample_rate * pad_multiplier) / 1000) - self.samples_to_pad_start, 0)
+            samples_to_pad_end = max(
+                int(
+                    (
+                        settings.dynamic_gap_trim_padding_ms
+                        * self.sample_rate
+                        * pad_multiplier
+                    )
+                    / 1000
+                )
+                - self.samples_to_pad_start,
+                0,
+            )
         else:
             samples_to_pad_end = self.samples_to_pad_start
         # Convert dBFS threshold to amplitude
-        amplitude_threshold = np.iinfo(audio_data.dtype).max * (10 ** (silence_threshold_db / 20))
+        amplitude_threshold = np.iinfo(audio_data.dtype).max * (
+            10 ** (silence_threshold_db / 20)
+        )
         # Find the first samples above the silence threshold at the start and end of the audio
         non_silent_index_start, non_silent_index_end = None, None
@@ -71,7 +93,10 @@ class AudioNormalizer:
         if non_silent_index_start == None or non_silent_index_end == None:
             return 0, len(audio_data)

-        return max(non_silent_index_start - self.samples_to_pad_start,0), min(non_silent_index_end + math.ceil(samples_to_pad_end / speed),len(audio_data))
+        return max(non_silent_index_start - self.samples_to_pad_start, 0), min(
+            non_silent_index_end + math.ceil(samples_to_pad_end / speed),
+            len(audio_data),
+        )

     def normalize(self, audio_data: np.ndarray) -> np.ndarray:
         """Convert audio data to int16 range
@@ -86,6 +111,7 @@ class AudioNormalizer:
             return np.clip(audio_data * 32767, -32768, 32767).astype(np.int16)
         return audio_data

+
 class AudioService:
     """Service for audio format conversions with streaming support"""
@@ -148,14 +174,14 @@ class AudioService:
                 audio_chunk.audio = normalizer.normalize(audio_chunk.audio)

             if trim_audio == True:
-                audio_chunk = AudioService.trim_audio(audio_chunk,chunk_text,speed,is_last_chunk,normalizer)
+                audio_chunk = AudioService.trim_audio(
+                    audio_chunk, chunk_text, speed, is_last_chunk, normalizer
+                )

             # Write audio data first
             if len(audio_chunk.audio) > 0:
                 chunk_data = writer.write_chunk(audio_chunk.audio)

-
-
             # Then finalize if this is the last chunk
             if is_last_chunk:
                 final_data = writer.write_chunk(finalize=True)
@@ -173,8 +199,15 @@ class AudioService:
                 raise ValueError(
                     f"Failed to convert audio stream to {output_format}: {str(e)}"
                 )

     @staticmethod
-    def trim_audio(audio_chunk: AudioChunk, chunk_text: str = "", speed: float = 1, is_last_chunk: bool = False, normalizer: AudioNormalizer = None) -> AudioChunk:
+    def trim_audio(
+        audio_chunk: AudioChunk,
+        chunk_text: str = "",
+        speed: float = 1,
+        is_last_chunk: bool = False,
+        normalizer: AudioNormalizer = None,
+    ) -> AudioChunk:
         """Trim silence from start and end

         Args:
@@ -195,11 +228,15 @@ class AudioService:
         trimed_samples = 0
         # Trim start and end if enough samples
         if len(audio_chunk.audio) > (2 * normalizer.samples_to_trim):
-            audio_chunk.audio = audio_chunk.audio[normalizer.samples_to_trim : -normalizer.samples_to_trim]
+            audio_chunk.audio = audio_chunk.audio[
+                normalizer.samples_to_trim : -normalizer.samples_to_trim
+            ]
             trimed_samples += normalizer.samples_to_trim

         # Find non silent portion and trim
-        start_index,end_index=normalizer.find_first_last_non_silent(audio_chunk.audio,chunk_text,speed,is_last_chunk=is_last_chunk)
+        start_index, end_index = normalizer.find_first_last_non_silent(
+            audio_chunk.audio, chunk_text, speed, is_last_chunk=is_last_chunk
+        )

         audio_chunk.audio = audio_chunk.audio[start_index:end_index]
         trimed_samples += start_index
@@ -209,4 +246,3 @@ class AudioService:
             timestamp.start_time -= trimed_samples / 24000
             timestamp.end_time -= trimed_samples / 24000
         return audio_chunk
-
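The silence detection above converts a dBFS threshold into a linear amplitude with amplitude = full_scale * 10^(dB/20). A quick standalone check of that arithmetic for the default -45 dB threshold:

```python
import numpy as np

silence_threshold_db = -45
full_scale = np.iinfo(np.int16).max  # 32767 for the int16 audio used here
amplitude_threshold = full_scale * (10 ** (silence_threshold_db / 20))
print(round(amplitude_threshold, 1))  # ~184.3; samples below this count as silence
```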
@@ -21,13 +21,27 @@ class StreamingAudioWriter:
         self.bytes_written = 0
         self.pts = 0

-        codec_map = {"wav":"pcm_s16le","mp3":"mp3","opus":"libopus","flac":"flac", "aac":"aac"}
+        codec_map = {
+            "wav": "pcm_s16le",
+            "mp3": "mp3",
+            "opus": "libopus",
+            "flac": "flac",
+            "aac": "aac",
+        }
         # Format-specific setup
         if self.format in ["wav", "flac", "mp3", "pcm", "aac", "opus"]:
             if self.format != "pcm":
                 self.output_buffer = BytesIO()
-                self.container = av.open(self.output_buffer, mode="w", format=self.format if self.format != "aac" else "adts")
-                self.stream = self.container.add_stream(codec_map[self.format],sample_rate=self.sample_rate,layout='mono' if self.channels == 1 else 'stereo')
+                self.container = av.open(
+                    self.output_buffer,
+                    mode="w",
+                    format=self.format if self.format != "aac" else "adts",
+                )
+                self.stream = self.container.add_stream(
+                    codec_map[self.format],
+                    sample_rate=self.sample_rate,
+                    layout="mono" if self.channels == 1 else "stereo",
+                )
                 self.stream.bit_rate = 128000
         else:
             raise ValueError(f"Unsupported format: {format}")
@@ -66,10 +80,13 @@ class StreamingAudioWriter:
             # Write raw bytes
             return audio_data.tobytes()
         else:
-            frame = av.AudioFrame.from_ndarray(audio_data.reshape(1, -1), format='s16', layout='mono' if self.channels == 1 else 'stereo')
+            frame = av.AudioFrame.from_ndarray(
+                audio_data.reshape(1, -1),
+                format="s16",
+                layout="mono" if self.channels == 1 else "stereo",
+            )
             frame.sample_rate = self.sample_rate

             frame.pts = self.pts
             self.pts += frame.samples
@@ -81,4 +98,3 @@ class StreamingAudioWriter:
             self.output_buffer.seek(0)
             self.output_buffer.truncate(0)
             return data
-
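For orientation, encoding raw int16 PCM through PyAV the way this writer does boils down to a frame/encode/mux loop. The following is a compact, self-contained sketch (not the project's class) that writes one second of a sine tone as MP3 into an in-memory buffer; the call style mirrors the code above, and the flush step is the standard PyAV pattern.

```python
from io import BytesIO

import av
import numpy as np

buffer = BytesIO()
container = av.open(buffer, mode="w", format="mp3")
stream = container.add_stream("mp3", sample_rate=24000, layout="mono")

# One second of a 440 Hz tone as int16 samples, shaped (channels, samples).
samples = (np.sin(2 * np.pi * 440 * np.arange(24000) / 24000) * 32767).astype(np.int16)
frame = av.AudioFrame.from_ndarray(samples.reshape(1, -1), format="s16", layout="mono")
frame.sample_rate = 24000

for packet in stream.encode(frame):  # encode the frame
    container.mux(packet)
for packet in stream.encode(None):   # flush the encoder
    container.mux(packet)
container.close()

print(len(buffer.getvalue()), "bytes of mp3")
```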
@@ -58,23 +58,77 @@ VALID_TLDS = [
 ]

 VALID_UNITS = {
-    "m":"meter", "cm":"centimeter", "mm":"millimeter", "km":"kilometer", "in":"inch", "ft":"foot", "yd":"yard", "mi":"mile", # Length
-    "g":"gram", "kg":"kilogram", "mg":"miligram", # Mass
-    "s":"second", "ms":"milisecond", "min":"minutes", "h":"hour", # Time
-    "l":"liter", "ml":"mililiter", "cl":"centiliter", "dl":"deciliter", # Volume
-    "kph":"kilometer per hour", "mph":"mile per hour","mi/h":"mile per hour", "m/s":"meter per second", "km/h":"kilometer per hour", "mm/s":"milimeter per second","cm/s":"centimeter per second", "ft/s":"feet per second","cm/h":"centimeter per day", # Speed
-    "°c":"degree celsius","c":"degree celsius", "°f":"degree fahrenheit","f":"degree fahrenheit", "k":"kelvin", # Temperature
-    "pa":"pascal", "kpa":"kilopascal", "mpa":"megapascal", "atm":"atmosphere", # Pressure
-    "hz":"hertz", "khz":"kilohertz", "mhz":"megahertz", "ghz":"gigahertz", # Frequency
-    "v":"volt", "kv":"kilovolt", "mv":"mergavolt", # Voltage
-    "a":"amp", "ma":"megaamp", "ka":"kiloamp", # Current
-    "w":"watt", "kw":"kilowatt", "mw":"megawatt", # Power
-    "j":"joule", "kj":"kilojoule", "mj":"megajoule", # Energy
-    "Ω":"ohm", "kΩ":"kiloohm", "mΩ":"megaohm", # Resistance (Ohm)
-    "f":"farad", "µf":"microfarad", "nf":"nanofarad", "pf":"picofarad", # Capacitance
-    "b":"bit", "kb":"kilobit", "mb":"megabit", "gb":"gigabit", "tb":"terabit", "pb":"petabit", # Data size
-    "kbps":"kilobit per second","mbps":"megabit per second","gbps":"gigabit per second","tbps":"terabit per second",
-    "px":"pixel" # CSS units
+    "m": "meter",
+    "cm": "centimeter",
+    "mm": "millimeter",
+    "km": "kilometer",
+    "in": "inch",
+    "ft": "foot",
+    "yd": "yard",
+    "mi": "mile",  # Length
+    "g": "gram",
+    "kg": "kilogram",
+    "mg": "miligram",  # Mass
+    "s": "second",
+    "ms": "milisecond",
+    "min": "minutes",
+    "h": "hour",  # Time
+    "l": "liter",
+    "ml": "mililiter",
+    "cl": "centiliter",
+    "dl": "deciliter",  # Volume
+    "kph": "kilometer per hour",
+    "mph": "mile per hour",
+    "mi/h": "mile per hour",
+    "m/s": "meter per second",
+    "km/h": "kilometer per hour",
+    "mm/s": "milimeter per second",
+    "cm/s": "centimeter per second",
+    "ft/s": "feet per second",
+    "cm/h": "centimeter per day",  # Speed
+    "°c": "degree celsius",
+    "c": "degree celsius",
+    "°f": "degree fahrenheit",
+    "f": "degree fahrenheit",
+    "k": "kelvin",  # Temperature
+    "pa": "pascal",
+    "kpa": "kilopascal",
+    "mpa": "megapascal",
+    "atm": "atmosphere",  # Pressure
+    "hz": "hertz",
+    "khz": "kilohertz",
+    "mhz": "megahertz",
+    "ghz": "gigahertz",  # Frequency
+    "v": "volt",
+    "kv": "kilovolt",
+    "mv": "mergavolt",  # Voltage
+    "a": "amp",
+    "ma": "megaamp",
+    "ka": "kiloamp",  # Current
+    "w": "watt",
+    "kw": "kilowatt",
+    "mw": "megawatt",  # Power
+    "j": "joule",
+    "kj": "kilojoule",
+    "mj": "megajoule",  # Energy
+    "Ω": "ohm",
+    "kΩ": "kiloohm",
+    "mΩ": "megaohm",  # Resistance (Ohm)
+    "f": "farad",
+    "µf": "microfarad",
+    "nf": "nanofarad",
+    "pf": "picofarad",  # Capacitance
+    "b": "bit",
+    "kb": "kilobit",
+    "mb": "megabit",
+    "gb": "gigabit",
+    "tb": "terabit",
+    "pb": "petabit",  # Data size
+    "kbps": "kilobit per second",
+    "mbps": "megabit per second",
+    "gbps": "gigabit per second",
+    "tbps": "terabit per second",
+    "px": "pixel",  # CSS units
 }
@@ -89,12 +143,20 @@ URL_PATTERN = re.compile(
     re.IGNORECASE,
 )

-UNIT_PATTERN = re.compile(r"((?<!\w)([+-]?)(\d{1,3}(,\d{3})*|\d+)(\.\d+)?)\s*(" + "|".join(sorted(list(VALID_UNITS.keys()),reverse=True)) + r"""){1}(?=[^\w\d]{1}|\b)""",re.IGNORECASE)
+UNIT_PATTERN = re.compile(
+    r"((?<!\w)([+-]?)(\d{1,3}(,\d{3})*|\d+)(\.\d+)?)\s*("
+    + "|".join(sorted(list(VALID_UNITS.keys()), reverse=True))
+    + r"""){1}(?=[^\w\d]{1}|\b)""",
+    re.IGNORECASE,
+)

-TIME_PATTERN = re.compile(r"([0-9]{2} ?: ?[0-9]{2}( ?: ?[0-9]{2})?)( ?(pm|am)\b)?", re.IGNORECASE)
+TIME_PATTERN = re.compile(
+    r"([0-9]{2} ?: ?[0-9]{2}( ?: ?[0-9]{2})?)( ?(pm|am)\b)?", re.IGNORECASE
+)

 INFLECT_ENGINE = inflect.engine()


 def split_num(num: re.Match[str]) -> str:
     """Handle number splitting for various formats"""
     num = num.group()
@@ -119,6 +181,7 @@ def split_num(num: re.Match[str]) -> str:
         return f"{left} oh {right}{s}"
     return f"{left} {right}{s}"

+
 def handle_units(u: re.Match[str]) -> str:
     """Converts units to their full form"""
     unit_string = u.group(6).strip()
@@ -137,11 +200,13 @@ def handle_units(u: re.Match[str]) -> str:
         unit[0] = INFLECT_ENGINE.no(unit[0], number)
     return " ".join(unit)

+
 def conditional_int(number: float, threshold: float = 0.00001):
     if abs(round(number) - number) < threshold:
         return int(round(number))
     return number

+
 def handle_money(m: re.Match[str]) -> str:
     """Convert money expressions to spoken form"""
@@ -167,6 +232,7 @@ def handle_money(m: re.Match[str]) -> str:

     return text_number

+
 def handle_decimal(num: re.Match[str]) -> str:
     """Convert decimal numbers to spoken form"""
     a, b = num.group().split(".")
@@ -230,6 +296,7 @@ def handle_url(u: re.Match[str]) -> str:
     # Clean up extra spaces
     return re.sub(r"\s+", " ", url).strip()

+
 def handle_phone_number(p: re.Match[str]) -> str:
     p = list(p.groups())
@@ -238,7 +305,9 @@ def handle_phone_number(p: re.Match[str]) -> str:
         p[0] = p[0].replace("+", "")
         country_code += INFLECT_ENGINE.number_to_words(p[0])

-    area_code=INFLECT_ENGINE.number_to_words(p[2].replace("(","").replace(")",""),group=1,comma="")
+    area_code = INFLECT_ENGINE.number_to_words(
+        p[2].replace("(", "").replace(")", ""), group=1, comma=""
+    )

     telephone_prefix = INFLECT_ENGINE.number_to_words(p[3], group=1, comma="")
@@ -246,10 +315,13 @@ def handle_phone_number(p: re.Match[str]) -> str:

     return ",".join([country_code, area_code, telephone_prefix, line_number])

+
 def handle_time(t: re.Match[str]) -> str:
     t = t.groups()

-    numbers = " ".join([INFLECT_ENGINE.number_to_words(X.strip()) for X in t[0].split(":")])
+    numbers = " ".join(
+        [INFLECT_ENGINE.number_to_words(X.strip()) for X in t[0].split(":")]
+    )

     half = ""
     if t[2] is not None:
@@ -257,6 +329,7 @@ def handle_time(t: re.Match[str]) -> str:

     return numbers + half

+
 def normalize_text(text: str, normalization_options: NormalizationOptions) -> str:
     """Normalize text for TTS processing"""
     # Handle email addresses first if enabled
@@ -277,7 +350,11 @@ def normalize_text(text: str, normalization_options: NormalizationOptions) -> str:

     # Replace phone numbers:
     if normalization_options.phone_normalization:
-        text = re.sub(r"(\+?\d{1,2})?([ .-]?)(\(?\d{3}\)?)[\s.-](\d{3})[\s.-](\d{4})",handle_phone_number,text)
+        text = re.sub(
+            r"(\+?\d{1,2})?([ .-]?)(\(?\d{3}\)?)[\s.-](\d{3})[\s.-](\d{4})",
+            handle_phone_number,
+            text,
+        )

     # Replace quotes and brackets
     text = text.replace(chr(8216), "'").replace(chr(8217), "'")
@@ -289,7 +366,10 @@ def normalize_text(text: str, normalization_options: NormalizationOptions) -> str:
         text = text.replace(a, b + " ")

     # Handle simple time in the format of HH:MM:SS
-    text = TIME_PATTERN.sub(handle_time, text, )
+    text = TIME_PATTERN.sub(
+        handle_time,
+        text,
+    )

     # Clean up whitespace
     text = re.sub(r"[^\S \n]", " ", text)
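To make the intent of UNIT_PATTERN and handle_units concrete, here is a deliberately simplified stand-in with a tiny unit table and plain pluralization instead of inflect; it shows the number-plus-unit expansion but is not the module's actual pattern.

```python
import re

UNITS = {"km": "kilometer", "mb": "megabit", "khz": "kilohertz"}
PATTERN = re.compile(
    r"([+-]?\d+(?:\.\d+)?)\s*(" + "|".join(UNITS) + r")\b", re.IGNORECASE
)


def expand_units(text: str) -> str:
    def repl(m: re.Match) -> str:
        number, unit = m.group(1), UNITS[m.group(2).lower()]
        plural = unit if float(number) == 1 else unit + "s"  # inflect handles this better
        return f"{number} {plural}"

    return PATTERN.sub(repl, text)


print(expand_units("The trail is 5km long and the link runs at 100mb."))
# The trail is 5 kilometers long and the link runs at 100 megabits.
```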
@@ -15,6 +15,7 @@ from .vocabulary import tokenize
 # Pre-compiled regex patterns for performance
 CUSTOM_PHONEMES = re.compile(r"(\[([^\]]|\n)*?\])(\(\/([^\/)]|\n)*?\/\))")

+
 def process_text_chunk(
     text: str, language: str = "a", skip_phonemize: bool = False
 ) -> List[int]:
@@ -41,9 +42,7 @@ def process_text_chunk(
         t1 = time.time()

         t0 = time.time()
-        phonemes = phonemize(
-            text, language, normalize=False
-        )  # Already normalized
+        phonemes = phonemize(text, language, normalize=False)  # Already normalized
         t1 = time.time()

         t0 = time.time()
@@ -88,7 +87,9 @@ def process_text(text: str, language: str = "a") -> List[int]:
     return process_text_chunk(text, language)


-def get_sentence_info(text: str, custom_phenomes_list: Dict[str, str]) -> List[Tuple[str, List[int], int]]:
+def get_sentence_info(
+    text: str, custom_phenomes_list: Dict[str, str]
+) -> List[Tuple[str, List[int], int]]:
     """Process all sentences and return info."""
     sentences = re.split(r"([.!?;:])(?=\s|$)", text)
     phoneme_length, min_value = len(custom_phenomes_list), 0
@@ -99,10 +100,11 @@ def get_sentence_info(
         for replaced in range(min_value, phoneme_length):
             current_id = f"</|custom_phonemes_{replaced}|/>"
             if current_id in sentence:
-                sentence = sentence.replace(current_id, custom_phenomes_list.pop(current_id))
+                sentence = sentence.replace(
+                    current_id, custom_phenomes_list.pop(current_id)
+                )
                 min_value += 1

         punct = sentences[i + 1] if i + 1 < len(sentences) else ""

         if not sentence:
@@ -114,16 +116,18 @@ def get_sentence_info(

     return results


 def handle_custom_phonemes(s: re.Match[str], phenomes_list: Dict[str, str]) -> str:
     latest_id = f"</|custom_phonemes_{len(phenomes_list)}|/>"
     phenomes_list[latest_id] = s.group(0).strip()
     return latest_id


 async def smart_split(
     text: str,
     max_tokens: int = settings.absolute_max_tokens,
     lang_code: str = "a",
-    normalization_options: NormalizationOptions = NormalizationOptions()
+    normalization_options: NormalizationOptions = NormalizationOptions(),
 ) -> AsyncGenerator[Tuple[str, List[int]], None]:
     """Build optimal chunks targeting 300-400 tokens, never exceeding max_tokens."""
     start_time = time.time()
@@ -136,10 +140,14 @@ async def smart_split(
     if settings.advanced_text_normalization and normalization_options.normalize:
         print(lang_code)
         if lang_code in ["a", "b", "en-us", "en-gb"]:
-            text = CUSTOM_PHONEMES.sub(lambda s: handle_custom_phonemes(s, custom_phoneme_list), text)
+            text = CUSTOM_PHONEMES.sub(
+                lambda s: handle_custom_phonemes(s, custom_phoneme_list), text
+            )
             text = normalize_text(text, normalization_options)
         else:
-            logger.info("Skipping text normalization as it is only supported for english")
+            logger.info(
+                "Skipping text normalization as it is only supported for english"
+            )

     # Process all sentences
     sentences = get_sentence_info(text, custom_phoneme_list)
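The sentence pass in get_sentence_info hinges on re.split with a capturing group, so the punctuation survives as its own list element and can be re-attached to the sentence before it. A quick standalone illustration (not the full function):

```python
import re

text = "Hello there! How are you? Fine."
parts = re.split(r"([.!?;:])(?=\s|$)", text)
# Pair each sentence with the punctuation captured right after it.
sentences = [
    (parts[i].strip(), parts[i + 1] if i + 1 < len(parts) else "")
    for i in range(0, len(parts), 2)
    if parts[i].strip()
]
print(sentences)  # [('Hello there', '!'), ('How are you', '?'), ('Fine', '.')]
```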
@@ -69,7 +69,9 @@ class TTSService:
                     yield AudioChunk(np.array([], dtype=np.int16), output=b"")
                     return
                 chunk_data = await AudioService.convert_audio(
-                    AudioChunk(np.array([], dtype=np.float32)),  # Dummy data for type checking
+                    AudioChunk(
+                        np.array([], dtype=np.float32)
+                    ),  # Dummy data for type checking
                     output_format,
                     writer,
                     speed,
@@ -114,13 +116,22 @@ class TTSService:
                     except Exception as e:
                         logger.error(f"Failed to convert audio: {str(e)}")
                 else:
-                    chunk_data = AudioService.trim_audio(chunk_data, chunk_text, speed, is_last, normalizer)
+                    chunk_data = AudioService.trim_audio(
+                        chunk_data, chunk_text, speed, is_last, normalizer
+                    )
                     yield chunk_data
                 chunk_index += 1
             else:
                 # For legacy backends, load voice tensor
-                voice_tensor = await self._voice_manager.load_voice(voice_name, device=backend.device)
-                chunk_data = await self.model_manager.generate(tokens, voice_tensor, speed=speed, return_timestamps=return_timestamps)
+                voice_tensor = await self._voice_manager.load_voice(
+                    voice_name, device=backend.device
+                )
+                chunk_data = await self.model_manager.generate(
+                    tokens,
+                    voice_tensor,
+                    speed=speed,
+                    return_timestamps=return_timestamps,
+                )

                 if chunk_data.audio is None:
                     logger.error("Model generated None for audio chunk")
@@ -146,7 +157,9 @@ class TTSService:
                     except Exception as e:
                         logger.error(f"Failed to convert audio: {str(e)}")
                 else:
-                    trimmed = AudioService.trim_audio(chunk_data, chunk_text, speed, is_last, normalizer)
+                    trimmed = AudioService.trim_audio(
+                        chunk_data, chunk_text, speed, is_last, normalizer
+                    )
                     yield trimmed
         except Exception as e:
             logger.error(f"Failed to process tokens: {str(e)}")
@@ -178,7 +191,9 @@ class TTSService:
         # If it is only once voice there is no point in loading it up, doing nothing with it, then saving it
         if len(split_voice) == 1:
             # Since its a single voice the only time that the weight would matter is if voice_weight_normalization is off
-            if ("(" not in voice and ")" not in voice) or settings.voice_weight_normalization == True:
+            if (
+                "(" not in voice and ")" not in voice
+            ) or settings.voice_weight_normalization == True:
                 path = await self._voice_manager.get_voice_path(voice)
                 if not path:
                     raise RuntimeError(f"Voice not found: {voice}")
@@ -206,13 +221,19 @@ class TTSService:

             # Load the first voice as the starting point for voices to be combined onto
             path = await self._voice_manager.get_voice_path(split_voice[0][0])
-            combined_tensor = await self._load_voice_from_path(path, split_voice[0][1] / total_weight)
+            combined_tensor = await self._load_voice_from_path(
+                path, split_voice[0][1] / total_weight
+            )

             # Loop through each + or - in split_voice so they can be applied to combined voice
             for operation_index in range(1, len(split_voice) - 1, 2):
                 # Get the voice path of the voice 1 index ahead of the operator
-                path = await self._voice_manager.get_voice_path(split_voice[operation_index + 1][0])
-                voice_tensor = await self._load_voice_from_path(path, split_voice[operation_index + 1][1] / total_weight)
+                path = await self._voice_manager.get_voice_path(
+                    split_voice[operation_index + 1][0]
+                )
+                voice_tensor = await self._load_voice_from_path(
+                    path, split_voice[operation_index + 1][1] / total_weight
+                )

                 # Either add or subtract the voice from the current combined voice
                 if split_voice[operation_index] == "+":
@ -255,10 +276,16 @@ class TTSService:
|
||||||
|
|
||||||
# Use provided lang_code or determine from voice name
|
# Use provided lang_code or determine from voice name
|
||||||
pipeline_lang_code = lang_code if lang_code else voice[:1].lower()
|
pipeline_lang_code = lang_code if lang_code else voice[:1].lower()
|
||||||
logger.info(f"Using lang_code '{pipeline_lang_code}' for voice '{voice_name}' in audio stream")
|
logger.info(
|
||||||
|
f"Using lang_code '{pipeline_lang_code}' for voice '{voice_name}' in audio stream"
|
||||||
|
)
|
||||||
|
|
||||||
# Process text in chunks with smart splitting
|
# Process text in chunks with smart splitting
|
||||||
async for chunk_text, tokens in smart_split(text, lang_code=pipeline_lang_code, normalization_options=normalization_options):
|
async for chunk_text, tokens in smart_split(
|
||||||
|
text,
|
||||||
|
lang_code=pipeline_lang_code,
|
||||||
|
normalization_options=normalization_options,
|
||||||
|
):
|
||||||
try:
|
try:
|
||||||
# Process audio for chunk
|
# Process audio for chunk
|
||||||
async for chunk_data in self._process_chunk(
|
async for chunk_data in self._process_chunk(
|
||||||
|
@ -286,10 +313,14 @@ class TTSService:
|
||||||
yield chunk_data
|
yield chunk_data
|
||||||
|
|
||||||
else:
|
else:
|
||||||
logger.warning(f"No audio generated for chunk: '{chunk_text[:100]}...'")
|
logger.warning(
|
||||||
|
f"No audio generated for chunk: '{chunk_text[:100]}...'"
|
||||||
|
)
|
||||||
chunk_index += 1
|
chunk_index += 1
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to process audio for chunk: '{chunk_text[:100]}...'. Error: {str(e)}")
|
logger.error(
|
||||||
|
f"Failed to process audio for chunk: '{chunk_text[:100]}...'. Error: {str(e)}"
|
||||||
|
)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Only finalize if we successfully processed at least one chunk
|
# Only finalize if we successfully processed at least one chunk
|
||||||
|
@ -332,7 +363,16 @@ class TTSService:
|
||||||
audio_data_chunks = []
|
audio_data_chunks = []
|
||||||
|
|
||||||
try:
|
try:
|
||||||
async for audio_stream_data in self.generate_audio_stream(text, voice, writer, speed=speed, normalization_options=normalization_options, return_timestamps=return_timestamps, lang_code=lang_code, output_format=None):
|
async for audio_stream_data in self.generate_audio_stream(
|
||||||
|
text,
|
||||||
|
voice,
|
||||||
|
writer,
|
||||||
|
speed=speed,
|
||||||
|
normalization_options=normalization_options,
|
||||||
|
return_timestamps=return_timestamps,
|
||||||
|
lang_code=lang_code,
|
||||||
|
output_format=None,
|
||||||
|
):
|
||||||
if len(audio_stream_data.audio) > 0:
|
if len(audio_stream_data.audio) > 0:
|
||||||
audio_data_chunks.append(audio_stream_data)
|
audio_data_chunks.append(audio_stream_data)
|
||||||
|
|
||||||
|
@ -384,11 +424,15 @@ class TTSService:
|
||||||
result = None
|
result = None
|
||||||
# Use provided lang_code or determine from voice name
|
# Use provided lang_code or determine from voice name
|
||||||
pipeline_lang_code = lang_code if lang_code else voice[:1].lower()
|
pipeline_lang_code = lang_code if lang_code else voice[:1].lower()
|
||||||
logger.info(f"Using lang_code '{pipeline_lang_code}' for voice '{voice_name}' in phoneme pipeline")
|
logger.info(
|
||||||
|
f"Using lang_code '{pipeline_lang_code}' for voice '{voice_name}' in phoneme pipeline"
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Use backend's pipeline management
|
# Use backend's pipeline management
|
||||||
for r in backend._get_pipeline(pipeline_lang_code).generate_from_tokens(
|
for r in backend._get_pipeline(
|
||||||
|
pipeline_lang_code
|
||||||
|
).generate_from_tokens(
|
||||||
tokens=phonemes, # Pass raw phonemes string
|
tokens=phonemes, # Pass raw phonemes string
|
||||||
voice=voice_path,
|
voice=voice_path,
|
||||||
speed=speed,
|
speed=speed,
|
||||||
|
@ -406,7 +450,9 @@ class TTSService:
|
||||||
processing_time = time.time() - start_time
|
processing_time = time.time() - start_time
|
||||||
return result.audio.numpy(), processing_time
|
return result.audio.numpy(), processing_time
|
||||||
else:
|
else:
|
||||||
raise ValueError("Phoneme generation only supported with Kokoro V1 backend")
|
raise ValueError(
|
||||||
|
"Phoneme generation only supported with Kokoro V1 backend"
|
||||||
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error in phoneme audio generation: {str(e)}")
|
logger.error(f"Error in phoneme audio generation: {str(e)}")
|
||||||
|
|
|
@@ -24,16 +24,12 @@ class JSONStreamingResponse(StreamingResponse, JSONResponse):
        else:
            self._content_iterable = iterate_in_threadpool(content)

        async def body_iterator() -> AsyncIterable[bytes]:
            async for content_ in self._content_iterable:
                if isinstance(content_, BaseModel):
                    content_ = content_.model_dump()
                yield self.render(content_)

        self.body_iterator = body_iterator()
        self.status_code = status_code
        if media_type is not None:

@@ -42,10 +38,13 @@ class JSONStreamingResponse(StreamingResponse, JSONResponse):
        self.init_headers(headers)

    def render(self, content: typing.Any) -> bytes:
        return (
            json.dumps(
                content,
                ensure_ascii=False,
                allow_nan=False,
                indent=None,
                separators=(",", ":"),
            )
            + "\n"
        ).encode("utf-8")
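For reference, a minimal sketch (illustrative, not part of this diff) of what the reformatted render() still produces: newline-delimited JSON, one object per line. The payload fields below are made up for the example.

import json

payload = {"word": "test", "start_time": 0.0}
line = (
    json.dumps(
        payload, ensure_ascii=False, allow_nan=False, indent=None, separators=(",", ":")
    )
    + "\n"
).encode("utf-8")
# line == b'{"word":"test","start_time":0.0}\n'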
@@ -35,16 +35,38 @@ class CaptionedSpeechResponse(BaseModel):

    audio: str = Field(..., description="The generated audio data encoded in base 64")
    audio_format: str = Field(..., description="The format of the output audio")
    timestamps: Optional[List[WordTimestamp]] = Field(
        ..., description="Word-level timestamps"
    )


class NormalizationOptions(BaseModel):
    """Options for the normalization system"""

    normalize: bool = Field(
        default=True,
        description="Normalizes input text to make it easier for the model to say",
    )
    unit_normalization: bool = Field(
        default=False, description="Transforms units like 10KB to 10 kilobytes"
    )
    url_normalization: bool = Field(
        default=True,
        description="Changes urls so they can be properly pronouced by kokoro",
    )
    email_normalization: bool = Field(
        default=True,
        description="Changes emails so they can be properly pronouced by kokoro",
    )
    optional_pluralization_normalization: bool = Field(
        default=True,
        description="Replaces (s) with s so some words get pronounced correctly",
    )
    phone_normalization: bool = Field(
        default=True,
        description="Changes phone numbers so they can be properly pronouced by kokoro",
    )


class OpenAISpeechRequest(BaseModel):
    """Request schema for OpenAI-compatible speech endpoint"""

@@ -62,10 +84,12 @@ class OpenAISpeechRequest(BaseModel):
        default="mp3",
        description="The format to return audio in. Supported formats: mp3, opus, flac, wav, pcm. PCM format returns raw 16-bit samples without headers. AAC is not currently supported.",
    )
    download_format: Optional[Literal["mp3", "opus", "aac", "flac", "wav", "pcm"]] = (
        Field(
            default=None,
            description="Optional different format for the final download. If not provided, uses response_format.",
        )
    )
    speed: float = Field(
        default=1.0,
        ge=0.25,

@@ -86,7 +110,7 @@ class OpenAISpeechRequest(BaseModel):
    )
    normalization_options: Optional[NormalizationOptions] = Field(
        default=NormalizationOptions(),
        description="Options for the normalization system",
    )

@@ -130,5 +154,5 @@ class CaptionedSpeechRequest(BaseModel):
    )
    normalization_options: Optional[NormalizationOptions] = Field(
        default=NormalizationOptions(),
        description="Options for the normalization system",
    )
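As a usage sketch of the schema fields above (the endpoint, model, voice, and base URL are taken from other files in this commit; the specific option values are illustrative):

import requests

response = requests.post(
    "http://localhost:8880/v1/audio/speech",
    json={
        "model": "kokoro",
        "input": "Visit www.example.com or email user@example.com",
        "voice": "af_heart",
        "response_format": "wav",
        # Field names match NormalizationOptions above; which ones to toggle is illustrative.
        "normalization_options": {
            "normalize": True,
            "url_normalization": True,
            "email_normalization": True,
            "phone_normalization": False,
        },
    },
)
with open("output.wav", "wb") as f:
    f.write(response.content)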
@@ -69,4 +69,3 @@ async def tts_service(mock_model_manager, mock_voice_manager):
def test_voice():
    """Return a test voice name."""
    return "voice1"
@@ -66,7 +66,9 @@ async def test_convert_to_mp3(sample_audio):
    assert isinstance(audio_chunk, AudioChunk)
    assert len(audio_chunk.output) > 0
    # Check MP3 header (ID3 or MPEG frame sync)
    assert audio_chunk.output.startswith(b"ID3") or audio_chunk.output.startswith(
        b"\xff\xfb"
    )


@pytest.mark.asyncio

@@ -127,7 +129,9 @@ async def test_convert_to_aac(sample_audio):
    assert isinstance(audio_chunk, AudioChunk)
    assert len(audio_chunk.output) > 0
    # Check ADTS header (AAC)
    assert audio_chunk.output.startswith(b"\xff\xf0") or audio_chunk.output.startswith(
        b"\xff\xf1"
    )


@pytest.mark.asyncio

@@ -214,7 +218,6 @@ async def test_different_sample_rates(sample_audio):
    sample_rates = [8000, 16000, 44100, 48000]

    for rate in sample_rates:
        writer = StreamingAudioWriter("wav", sample_rate=rate)

        audio_chunk = await AudioService.convert_audio(
@@ -14,14 +14,15 @@ def test_generate_captioned_speech():

    mock_timestamps_response = MagicMock()
    mock_timestamps_response.status_code = 200
    mock_timestamps_response.content = json.dumps(
        {
            "audio": base64.b64encode(b"mock audio data").decode("utf-8"),
            "timestamps": [{"word": "test", "start_time": 0.0, "end_time": 1.0}],
        }
    )

    # Patch the HTTP requests
    with patch("requests.post", return_value=mock_timestamps_response):
        # Import here to avoid module-level import issues
        from examples.captioned_speech_example import generate_captioned_speech
@@ -9,24 +9,44 @@ from api.src.structures.schemas import NormalizationOptions
def test_url_protocols():
    """Test URL protocol handling"""
    assert (
        normalize_text(
            "Check out https://example.com",
            normalization_options=NormalizationOptions(),
        )
        == "Check out https example dot com"
    )
    assert (
        normalize_text(
            "Visit http://site.com", normalization_options=NormalizationOptions()
        )
        == "Visit http site dot com"
    )
    assert (
        normalize_text(
            "Go to https://test.org/path", normalization_options=NormalizationOptions()
        )
        == "Go to https test dot org slash path"
    )


def test_url_www():
    """Test www prefix handling"""
    assert (
        normalize_text(
            "Go to www.example.com", normalization_options=NormalizationOptions()
        )
        == "Go to www example dot com"
    )
    assert (
        normalize_text(
            "Visit www.test.org/docs", normalization_options=NormalizationOptions()
        )
        == "Visit www test dot org slash docs"
    )
    assert (
        normalize_text(
            "Check www.site.com?q=test", normalization_options=NormalizationOptions()
        )
        == "Check www site dot com question-mark q equals test"
    )

@@ -34,15 +54,21 @@ def test_url_www():
def test_url_localhost():
    """Test localhost URL handling"""
    assert (
        normalize_text(
            "Running on localhost:7860", normalization_options=NormalizationOptions()
        )
        == "Running on localhost colon 78 60"
    )
    assert (
        normalize_text(
            "Server at localhost:8080/api", normalization_options=NormalizationOptions()
        )
        == "Server at localhost colon 80 80 slash api"
    )
    assert (
        normalize_text(
            "Test localhost:3000/test?v=1", normalization_options=NormalizationOptions()
        )
        == "Test localhost colon 3000 slash test question-mark v equals 1"
    )

@@ -50,48 +76,104 @@ def test_url_localhost():
def test_url_ip_addresses():
    """Test IP address URL handling"""
    assert (
        normalize_text(
            "Access 0.0.0.0:9090/test", normalization_options=NormalizationOptions()
        )
        == "Access 0 dot 0 dot 0 dot 0 colon 90 90 slash test"
    )
    assert (
        normalize_text(
            "API at 192.168.1.1:8000", normalization_options=NormalizationOptions()
        )
        == "API at 192 dot 168 dot 1 dot 1 colon 8000"
    )
    assert (
        normalize_text("Server 127.0.0.1", normalization_options=NormalizationOptions())
        == "Server 127 dot 0 dot 0 dot 1"
    )


def test_url_raw_domains():
    """Test raw domain handling"""
    assert (
        normalize_text(
            "Visit google.com/search", normalization_options=NormalizationOptions()
        )
        == "Visit google dot com slash search"
    )
    assert (
        normalize_text(
            "Go to example.com/path?q=test",
            normalization_options=NormalizationOptions(),
        )
        == "Go to example dot com slash path question-mark q equals test"
    )
    assert (
        normalize_text(
            "Check docs.test.com", normalization_options=NormalizationOptions()
        )
        == "Check docs dot test dot com"
    )


def test_url_email_addresses():
    """Test email address handling"""
    assert (
        normalize_text(
            "Email me at user@example.com", normalization_options=NormalizationOptions()
        )
        == "Email me at user at example dot com"
    )
    assert (
        normalize_text(
            "Contact admin@test.org", normalization_options=NormalizationOptions()
        )
        == "Contact admin at test dot org"
    )
    assert (
        normalize_text(
            "Send to test.user@site.com", normalization_options=NormalizationOptions()
        )
        == "Send to test dot user at site dot com"
    )


def test_money():
    """Test that money text is normalized correctly"""
    assert (
        normalize_text(
            "He lost $5.3 thousand.", normalization_options=NormalizationOptions()
        )
        == "He lost five point three thousand dollars."
    )
    assert (
        normalize_text(
            "To put it weirdly -$6.9 million",
            normalization_options=NormalizationOptions(),
        )
        == "To put it weirdly minus six point nine million dollars"
    )
    assert (
        normalize_text("It costs $50.3.", normalization_options=NormalizationOptions())
        == "It costs fifty dollars and thirty cents."
    )


def test_non_url_text():
    """Test that non-URL text is unaffected"""
    assert (
        normalize_text(
            "This is not.a.url text", normalization_options=NormalizationOptions()
        )
        == "This is not-a-url text"
    )
    assert (
        normalize_text(
            "Hello, how are you today?", normalization_options=NormalizationOptions()
        )
        == "Hello, how are you today?"
    )
    assert (
        normalize_text("It costs $50.", normalization_options=NormalizationOptions())
        == "It costs fifty dollars."
    )
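A small sketch of calling the normalizer directly with a non-default option, in the same style as the tests above. The import path for normalize_text is assumed (the tests only show the schemas import), and the expected expansion follows the unit_normalization field description rather than an asserted result.

# Import paths are assumptions for illustration.
from api.src.services.text_processing.normalizer import normalize_text
from api.src.structures.schemas import NormalizationOptions

# Per the field description, unit_normalization should expand "10KB" to "10 kilobytes".
print(
    normalize_text(
        "The file is 10KB",
        normalization_options=NormalizationOptions(unit_normalization=True),
    )
)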
@@ -113,7 +113,6 @@ def test_retrieve_model(mock_openai_mappings):
    assert error["detail"]["type"] == "invalid_request_error"


@pytest.mark.asyncio
async def test_get_tts_service_initialization():
    """Test TTSService initialization"""

@@ -263,7 +262,9 @@ def test_openai_speech_endpoint(
    """Test the OpenAI-compatible speech endpoint with basic MP3 generation"""
    # Configure mocks
    mock_tts_service.generate_audio.return_value = AudioChunk(np.zeros(1000, np.int16))
    mock_convert.return_value = AudioChunk(
        np.zeros(1000, np.int16), output=mock_audio_bytes
    )

    response = client.post(
        "/v1/audio/speech",
@@ -44,9 +44,12 @@ def test_get_sentence_info():
    assert count == len(tokens)
    assert count > 0


def test_get_sentence_info_phenomoes():
    """Test sentence splitting and info extraction."""
    text = (
        "This is sentence one. This is </|custom_phonemes_0|/> two! What about three?"
    )
    results = get_sentence_info(text, {"</|custom_phonemes_0|/>": r"sˈɛntᵊns"})

    assert len(results) == 3

@@ -58,6 +61,7 @@ def test_get_sentence_info_phenomoes():
    assert count == len(tokens)
    assert count > 0


@pytest.mark.asyncio
async def test_smart_split_short_text():
    """Test smart splitting with text under max tokens."""
@@ -28,25 +28,33 @@ import requests
def setup_args():
    """Parse command line arguments"""
    parser = argparse.ArgumentParser(description="Test Kokoro TTS for race conditions")
    parser.add_argument(
        "--url",
        default="http://localhost:8880",
        help="Base URL of the Kokoro TTS service",
    )
    parser.add_argument(
        "--threads", type=int, default=8, help="Number of concurrent threads to use"
    )
    parser.add_argument(
        "--iterations", type=int, default=5, help="Number of iterations per thread"
    )
    parser.add_argument("--voice", default="af_heart", help="Voice to use for TTS")
    parser.add_argument(
        "--output-dir",
        default="./tts_test_output",
        help="Directory to save output files",
    )
    parser.add_argument("--debug", action="store_true", help="Enable debug logging")
    return parser.parse_args()


def generate_test_sentence(thread_id, iteration):
    """Generate a simple test sentence with numbers to make mismatches easily identifiable"""
    return (
        f"This is test sentence number {thread_id}-{iteration}. "
        f"If you hear this sentence, you should hear the numbers {thread_id}-{iteration}."
    )


def log_message(message, debug=False, is_error=False):
@@ -74,7 +82,9 @@ def request_tts(url, test_id, text, voice, output_dir, debug=False):
            f.write(text)
        log_message(f"Thread {test_id}: Successfully saved text file", debug)
    except Exception as e:
        log_message(
            f"Thread {test_id}: Error saving text file: {str(e)}", debug, is_error=True
        )

    # Make the TTS request
    try:

@@ -86,56 +96,102 @@ def request_tts(url, test_id, text, voice, output_dir, debug=False):
                "model": "kokoro",
                "input": text,
                "voice": voice,
                "response_format": "wav",
            },
            headers={"Accept": "audio/wav"},
            timeout=60,  # Increase timeout to 60 seconds
        )

        log_message(
            f"Thread {test_id}: Response status code: {response.status_code}", debug
        )
        log_message(
            f"Thread {test_id}: Response content type: {response.headers.get('Content-Type', 'None')}",
            debug,
        )
        log_message(
            f"Thread {test_id}: Response content length: {len(response.content)} bytes",
            debug,
        )

        if response.status_code != 200:
            log_message(
                f"Thread {test_id}: API error: {response.status_code} - {response.text}",
                debug,
                is_error=True,
            )
            return False

        # Check if we got valid audio data
        if (
            len(response.content) < 100
        ):  # Sanity check - WAV files should be larger than this
            log_message(
                f"Thread {test_id}: Received suspiciously small audio data: {len(response.content)} bytes",
                debug,
                is_error=True,
            )
            log_message(
                f"Thread {test_id}: Content (base64): {base64.b64encode(response.content).decode('utf-8')}",
                debug,
                is_error=True,
            )
            return False

        # Save the audio output with explicit error handling
        try:
            with open(output_file, "wb") as f:
                bytes_written = f.write(response.content)
                log_message(
                    f"Thread {test_id}: Wrote {bytes_written} bytes to {output_file}",
                    debug,
                )

            # Verify the WAV file exists and has content
            if os.path.exists(output_file):
                file_size = os.path.getsize(output_file)
                log_message(
                    f"Thread {test_id}: Verified file exists with size: {file_size} bytes",
                    debug,
                )

                # Validate WAV file by reading its headers
                try:
                    with wave.open(output_file, "rb") as wav_file:
                        channels = wav_file.getnchannels()
                        sample_width = wav_file.getsampwidth()
                        framerate = wav_file.getframerate()
                        frames = wav_file.getnframes()
                        log_message(
                            f"Thread {test_id}: Valid WAV file - channels: {channels}, "
                            f"sample width: {sample_width}, framerate: {framerate}, frames: {frames}",
                            debug,
                        )
                except Exception as wav_error:
                    log_message(
                        f"Thread {test_id}: Invalid WAV file: {str(wav_error)}",
                        debug,
                        is_error=True,
                    )
            else:
                log_message(
                    f"Thread {test_id}: File was not created: {output_file}",
                    debug,
                    is_error=True,
                )
        except Exception as save_error:
            log_message(
                f"Thread {test_id}: Error saving audio file: {str(save_error)}",
                debug,
                is_error=True,
            )
            return False

        end_time = time.time()
        log_message(
            f"Thread {test_id}: Saved output to {output_file} (time: {end_time - start_time:.2f}s)",
            debug,
        )
        return True

    except requests.exceptions.Timeout:
@@ -152,10 +208,16 @@ def worker_task(thread_id, args):
        iteration = i + 1
        test_id = f"{thread_id:02d}_{iteration:02d}"
        text = generate_test_sentence(thread_id, iteration)
        success = request_tts(
            args.url, test_id, text, args.voice, args.output_dir, args.debug
        )

        if not success:
            log_message(
                f"Thread {thread_id}: Iteration {iteration} failed",
                args.debug,
                is_error=True,
            )

        # Small delay between iterations to avoid overwhelming the API
        time.sleep(0.1)

@@ -172,9 +234,14 @@ def run_test(args):
        with open(test_file, "w") as f:
            f.write("Testing write access\n")
        os.remove(test_file)
        log_message(
            f"Successfully verified write access to output directory: {args.output_dir}"
        )
    except Exception as e:
        log_message(
            f"Warning: Cannot write to output directory {args.output_dir}: {str(e)}",
            is_error=True,
        )
        log_message(f"Current directory: {os.getcwd()}", is_error=True)
        log_message(f"Directory contents: {os.listdir('.')}", is_error=True)

@@ -184,13 +251,21 @@ def run_test(args):
        if response.status_code == 200:
            log_message(f"Successfully connected to Kokoro TTS service at {args.url}")
        else:
            log_message(
                f"Warning: Kokoro TTS service health check returned status {response.status_code}",
                is_error=True,
            )
    except Exception as e:
        log_message(
            f"Warning: Cannot connect to Kokoro TTS service at {args.url}: {str(e)}",
            is_error=True,
        )

    # Record start time
    start_time = time.time()
    log_message(
        f"Starting test with {args.threads} threads, {args.iterations} iterations per thread"
    )

    # Create and start worker threads
    with concurrent.futures.ThreadPoolExecutor(max_workers=args.threads) as executor:

@@ -203,7 +278,9 @@ def run_test(args):
            try:
                future.result()
            except Exception as e:
                log_message(
                    f"Thread execution failed: {str(e)}", args.debug, is_error=True
                )

    # Record end time and print summary
    end_time = time.time()

@@ -214,8 +291,12 @@ def run_test(args):
    log_message(f"Average time per request: {total_time / total_requests:.2f} seconds")
    log_message(f"Requests per second: {total_requests / total_time:.2f}")
    log_message(f"Output files saved to: {os.path.abspath(args.output_dir)}")
    log_message(
        "To verify, listen to the audio files and check if they match the text files"
    )
    log_message(
        "If you hear audio describing a different test number than the filename, you've found a race condition"
    )


def analyze_audio_files(output_dir):
@@ -227,30 +308,34 @@ def analyze_audio_files(output_dir):
    log_message(f"Found {len(wav_files)} WAV files and {len(txt_files)} TXT files")

    if len(wav_files) == 0:
        log_message(
            "No WAV files found! This indicates the TTS service requests may be failing.",
            is_error=True,
        )
        log_message(
            "Check the connection to the TTS service and the response status codes above.",
            is_error=True,
        )

    file_stats = []
    for wav_path in wav_files:
        try:
            with wave.open(str(wav_path), "rb") as wav_file:
                frames = wav_file.getnframes()
                rate = wav_file.getframerate()
                duration = frames / rate

            # Get corresponding text
            text_path = wav_path.with_suffix(".txt")
            if text_path.exists():
                with open(text_path, "r") as text_file:
                    text = text_file.read().strip()
            else:
                text = "N/A"

            file_stats.append(
                {"filename": wav_path.name, "duration": duration, "text": text}
            )
        except Exception as e:
            log_message(f"Error analyzing {wav_path}: {str(e)}", False, is_error=True)

@@ -260,12 +345,17 @@ def analyze_audio_files(output_dir):
    log_message(f"{'Filename':<20}{'Duration':<12}{'Text':<60}")
    log_message("-" * 92)
    for stat in file_stats:
        log_message(
            f"{stat['filename']:<20}{stat['duration']:<12.2f}{stat['text'][:57] + '...' if len(stat['text']) > 60 else stat['text']:<60}"
        )

    # List missing WAV files where text files exist
    missing_wavs = set(p.stem for p in txt_files) - set(p.stem for p in wav_files)
    if missing_wavs:
        log_message(
            f"\nFound {len(missing_wavs)} text files without corresponding WAV files:",
            is_error=True,
        )
        for stem in sorted(list(missing_wavs))[:10]:  # Limit to 10 for readability
            log_message(f"  - {stem}.txt (no WAV file)", is_error=True)
        if len(missing_wavs) > 10:

@@ -280,5 +370,9 @@ if __name__ == "__main__":
    log_message("\nNext Steps:")
    log_message("1. Listen to the generated audio files")
    log_message("2. Verify if each audio correctly says its ID number")
    log_message(
        "3. Check for any mismatches between the audio content and the text files"
    )
    log_message(
        "4. If mismatches are found, you've successfully reproduced the race condition"
    )
@@ -36,7 +36,7 @@ response = requests.post(
        "response_format": Type,
        "stream": True,
    },
    stream=True,
)

f = open(f"outputstream.{Type}", "wb")

@@ -53,7 +53,11 @@ for chunk in response.iter_lines(decode_unicode=True):
        f.write(chunk_audio)

        # Print word level timestamps
        last_chunks = {
            "start_time": chunk_json["timestamps"][-10]["start_time"],
            "end_time": chunk_json["timestamps"][-3]["end_time"],
            "word": " ".join([X["word"] for X in chunk_json["timestamps"][-10:-3]]),
        }

        print(f"CUTTING TO {last_chunks['word']}")
@@ -20,7 +20,7 @@ response = requests.post(
        "response_format": Type,
        "stream": False,
    },
    stream=True,
)

with open(f"outputnostreammoney.{Type}", "wb") as f:
@@ -12,6 +12,7 @@ def conditional_int(number: float, threshold: float = 0.00001):
        return int(round(number))
    return number


def handle_money(m: re.Match[str]) -> str:
    """Convert money expressions to spoken form"""
@@ -37,7 +37,7 @@ response = requests.post(
        "response_format": Type,
        "stream": True,
    },
    stream=True,
)

@@ -57,7 +57,7 @@ response = requests.post(
        "response_format": Type,
        "stream": False,
    },
    stream=True,
)

with open(f"outputnostream.{Type}", "wb") as f:
@@ -91,9 +91,7 @@ def main():

    parser = argparse.ArgumentParser(description="Download Kokoro v1.0 model")
    parser.add_argument(
        "--output", required=True, help="Output directory for model files"
    )

    args = parser.parse_args()
@@ -9,6 +9,7 @@ import sys
# Find the misaki package
try:
    import misaki

    misaki_path = os.path.dirname(misaki.__file__)
    print(f"Found misaki package at: {misaki_path}")
except ImportError:

@@ -23,7 +24,7 @@ if not os.path.exists(espeak_file):
    sys.exit(1)

# Read the current content
with open(espeak_file, "r") as f:
    content = f.read()

# Check if the problematic line exists

@@ -32,11 +33,11 @@ if "EspeakWrapper.set_data_path(espeakng_loader.get_data_path())" in content:
    new_content = content.replace(
        "EspeakWrapper.set_data_path(espeakng_loader.get_data_path())",
        "# Fixed line to use data_path attribute instead of set_data_path method\n"
        "EspeakWrapper.data_path = espeakng_loader.get_data_path()",
    )

    # Write the modified content back
    with open(espeak_file, "w") as f:
        f.write(new_content)

    print(f"Successfully patched {espeak_file}")
@ -39,15 +39,13 @@ def extract_dependency_info():
|
||||||
|
|
||||||
return info
|
return info
|
||||||
|
|
||||||
|
|
||||||
def run_pytest_with_coverage():
|
def run_pytest_with_coverage():
|
||||||
"""Run pytest with coverage and return the results"""
|
"""Run pytest with coverage and return the results"""
|
||||||
try:
|
try:
|
||||||
# Run pytest with coverage
|
# Run pytest with coverage
|
||||||
result = subprocess.run(
|
result = subprocess.run(
|
||||||
["pytest", "--cov=api", "-v"],
|
["pytest", "--cov=api", "-v"], capture_output=True, text=True, check=True
|
||||||
capture_output=True,
|
|
||||||
text=True,
|
|
||||||
check=True
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Extract test results
|
# Extract test results
|
||||||
|
@ -56,10 +54,7 @@ def run_pytest_with_coverage():
|
||||||
|
|
||||||
# Extract coverage from .coverage file
|
# Extract coverage from .coverage file
|
||||||
coverage_output = subprocess.run(
|
coverage_output = subprocess.run(
|
||||||
["coverage", "report"],
|
["coverage", "report"], capture_output=True, text=True, check=True
|
||||||
capture_output=True,
|
|
||||||
text=True,
|
|
||||||
check=True
|
|
||||||
).stdout
|
).stdout
|
||||||
|
|
||||||
# Extract total coverage percentage
|
# Extract total coverage percentage
|
||||||
|
@ -72,6 +67,7 @@ def run_pytest_with_coverage():
|
||||||
print(f"Output: {e.output}")
|
print(f"Output: {e.output}")
|
||||||
return 0, "0"
|
return 0, "0"
|
||||||
|
|
||||||
|
|
||||||
def update_readme_badges(passed_tests, coverage_percentage, dep_info):
|
def update_readme_badges(passed_tests, coverage_percentage, dep_info):
|
||||||
"""Update the badges in the README file"""
|
"""Update the badges in the README file"""
|
||||||
readme_path = Path("README.md")
|
readme_path = Path("README.md")
|
||||||
|
@ -83,16 +79,16 @@ def update_readme_badges(passed_tests, coverage_percentage, dep_info):
|
||||||
|
|
||||||
# Update tests badge
|
# Update tests badge
|
||||||
content = re.sub(
|
content = re.sub(
|
||||||
r'!\[Tests\]\(https://img\.shields\.io/badge/tests-\d+%20passed-[a-zA-Z]+\)',
|
r"!\[Tests\]\(https://img\.shields\.io/badge/tests-\d+%20passed-[a-zA-Z]+\)",
|
||||||
f'',
|
f"",
|
||||||
content
|
content,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Update coverage badge
|
# Update coverage badge
|
||||||
content = re.sub(
|
content = re.sub(
|
||||||
r'!\[Coverage\]\(https://img\.shields\.io/badge/coverage-\d+%25-[a-zA-Z]+\)',
|
r"!\[Coverage\]\(https://img\.shields\.io/badge/coverage-\d+%25-[a-zA-Z]+\)",
|
||||||
f'',
|
f"",
|
||||||
content
|
content,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Update kokoro badge
|
# Update kokoro badge
|
||||||
|
@ -100,9 +96,9 @@ def update_readme_badges(passed_tests, coverage_percentage, dep_info):
|
||||||
# Find badge like kokoro-v0.9.2::abcdefg-BB5420 or kokoro-v0.9.2-BB5420
|
# Find badge like kokoro-v0.9.2::abcdefg-BB5420 or kokoro-v0.9.2-BB5420
|
||||||
kokoro_version = dep_info["kokoro"]["version"]
|
kokoro_version = dep_info["kokoro"]["version"]
|
||||||
content = re.sub(
|
content = re.sub(
|
||||||
r'(!\[Kokoro\]\(https://img\.shields\.io/badge/kokoro-)[^)-]+(-BB5420\))',
|
r"(!\[Kokoro\]\(https://img\.shields\.io/badge/kokoro-)[^)-]+(-BB5420\))",
|
||||||
lambda m: f"{m.group(1)}{kokoro_version}{m.group(2)}",
|
lambda m: f"{m.group(1)}{kokoro_version}{m.group(2)}",
|
||||||
content
|
content,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Update misaki badge
|
# Update misaki badge
|
||||||
|
@ -110,14 +106,15 @@ def update_readme_badges(passed_tests, coverage_percentage, dep_info):
|
||||||
# Find badge like misaki-v0.9.3::abcdefg-B8860B or misaki-v0.9.3-B8860B
|
# Find badge like misaki-v0.9.3::abcdefg-B8860B or misaki-v0.9.3-B8860B
|
||||||
misaki_version = dep_info["misaki"]["version"]
|
misaki_version = dep_info["misaki"]["version"]
|
||||||
content = re.sub(
|
content = re.sub(
|
||||||
r'(!\[Misaki\]\(https://img\.shields\.io/badge/misaki-)[^)-]+(-B8860B\))',
|
r"(!\[Misaki\]\(https://img\.shields\.io/badge/misaki-)[^)-]+(-B8860B\))",
|
||||||
lambda m: f"{m.group(1)}{misaki_version}{m.group(2)}",
|
lambda m: f"{m.group(1)}{misaki_version}{m.group(2)}",
|
||||||
content
|
content,
|
||||||
)
|
)
|
||||||
|
|
||||||
readme_path.write_text(content)
|
readme_path.write_text(content)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
# Get dependency info
|
# Get dependency info
|
||||||
dep_info = extract_dependency_info()
|
dep_info = extract_dependency_info()
|
||||||
|
@ -137,5 +134,6 @@ def main():
|
||||||
else:
|
else:
|
||||||
print("Failed to update badges")
|
print("Failed to update badges")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
|
|
|
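The badge updates above all follow the same shape: a regex captures the fixed parts of the shields.io badge and the replacement splices in the freshly extracted version. A minimal standalone sketch of that pattern (the README line, badge text, and version value here are illustrative, not taken from this repository):

```python
import re

# Hypothetical README excerpt containing a shields.io badge.
readme = "![Kokoro](https://img.shields.io/badge/kokoro-v0.1.0-BB5420)"

new_version = "v0.9.4"  # illustrative value, not the project's real version

# Same idea as the script above: capture the badge prefix and suffix,
# then splice the new version in between via a lambda replacement.
updated = re.sub(
    r"(!\[Kokoro\]\(https://img\.shields\.io/badge/kokoro-)[^)-]+(-BB5420\))",
    lambda m: f"{m.group(1)}{new_version}{m.group(2)}",
    readme,
)

print(updated)
# ![Kokoro](https://img.shields.io/badge/kokoro-v0.9.4-BB5420)
```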

@@ -21,6 +21,7 @@ HELM_CHART_FILE = ROOT_DIR / "charts" / "kokoro-fastapi" / "Chart.yaml"
 README_FILE = ROOT_DIR / "README.md"
 # --- End Configuration ---

+
 def update_pyproject(version: str):
     """Updates the version in pyproject.toml"""
     if not PYPROJECT_FILE.exists():
@@ -42,13 +43,16 @@ def update_pyproject(version: str):
             print(f"Already up-to-date: {PYPROJECT_FILE} (version {version})")
         else:
             # Perform replacement
-            new_content = re.sub(pattern, rf'\1"{version}"', content, count=1, flags=re.MULTILINE)
+            new_content = re.sub(
+                pattern, rf'\1"{version}"', content, count=1, flags=re.MULTILINE
+            )
             PYPROJECT_FILE.write_text(new_content)
             print(f"Updated {PYPROJECT_FILE} from {current_version} to {version}")

     except Exception as e:
         print(f"Error processing {PYPROJECT_FILE}: {e}")

+
 def update_helm_chart(version: str):
     """Updates the version and appVersion in the Helm chart"""
     if not HELM_CHART_FILE.exists():
@@ -65,31 +69,48 @@ def update_helm_chart(version: str):
         version_pattern = r"^(version:\s*)(\S+)"
         current_version_match = re.search(version_pattern, content, flags=re.MULTILINE)
         if current_version_match and current_version_match.group(2) != version:
-            content = re.sub(version_pattern, rf"\g<1>{version}", content, count=1, flags=re.MULTILINE)
-            print(f"Updating 'version' in {HELM_CHART_FILE} from {current_version_match.group(2)} to {version}")
+            content = re.sub(
+                version_pattern,
+                rf"\g<1>{version}",
+                content,
+                count=1,
+                flags=re.MULTILINE,
+            )
+            print(
+                f"Updating 'version' in {HELM_CHART_FILE} from {current_version_match.group(2)} to {version}"
+            )
             updated_count += 1
         elif current_version_match:
             print(f"Already up-to-date: 'version' in {HELM_CHART_FILE} is {version}")
         else:
             print(f"Warning: 'version:' pattern not found in {HELM_CHART_FILE}")

         # Update 'appVersion:' line (quoted or unquoted)
         # Looks for 'appVersion:' followed by optional whitespace, optional quote, the version, optional quote
         app_version_pattern = r"^(appVersion:\s*)(\"?)([^\"\s]+)(\"?)"
-        current_app_version_match = re.search(app_version_pattern, content, flags=re.MULTILINE)
+        current_app_version_match = re.search(
+            app_version_pattern, content, flags=re.MULTILINE
+        )

         if current_app_version_match:
-            leading_whitespace = current_app_version_match.group(1)  # e.g., "appVersion: "
+            leading_whitespace = current_app_version_match.group(
+                1
+            )  # e.g., "appVersion: "
             opening_quote = current_app_version_match.group(2)  # e.g., '"' or ''
             current_app_ver = current_app_version_match.group(3)  # e.g., '0.2.0'
             closing_quote = current_app_version_match.group(4)  # e.g., '"' or ''

             # Check if quotes were consistent (both present or both absent)
             if opening_quote != closing_quote:
-                print(f"Warning: Inconsistent quotes found for appVersion in {HELM_CHART_FILE}. Skipping update for this line.")
-            elif current_app_ver == version and opening_quote == '"':  # Check if already correct *and* quoted
-                print(f"Already up-to-date: 'appVersion' in {HELM_CHART_FILE} is \"{version}\"")
+                print(
+                    f"Warning: Inconsistent quotes found for appVersion in {HELM_CHART_FILE}. Skipping update for this line."
+                )
+            elif (
+                current_app_ver == version and opening_quote == '"'
+            ):  # Check if already correct *and* quoted
+                print(
+                    f"Already up-to-date: 'appVersion' in {HELM_CHART_FILE} is \"{version}\""
+                )
             else:
                 # Always replace with the quoted version
                 replacement = f'{leading_whitespace}"{version}"'  # Ensure quotes
@@ -98,20 +119,30 @@ def update_helm_chart(version: str):

                 # Only report update if the displayed value actually changes
                 if original_display != target_display:
-                    content = re.sub(app_version_pattern, replacement, content, count=1, flags=re.MULTILINE)
-                    print(f"Updating 'appVersion' in {HELM_CHART_FILE} from {original_display} to {target_display}")
+                    content = re.sub(
+                        app_version_pattern,
+                        replacement,
+                        content,
+                        count=1,
+                        flags=re.MULTILINE,
+                    )
+                    print(
+                        f"Updating 'appVersion' in {HELM_CHART_FILE} from {original_display} to {target_display}"
+                    )
                     updated_count += 1
                 else:
                     # It matches the target version but might need quoting fixed silently if we didn't update
                     # Or it was already correct. Check if content changed. If not, report up-to-date.
-                    if not (content != original_content and updated_count > 0):  # Avoid double message if version also changed
-                        print(f"Already up-to-date: 'appVersion' in {HELM_CHART_FILE} is {target_display}")
+                    if not (
+                        content != original_content and updated_count > 0
+                    ):  # Avoid double message if version also changed
+                        print(
+                            f"Already up-to-date: 'appVersion' in {HELM_CHART_FILE} is {target_display}"
+                        )

         else:
             print(f"Warning: 'appVersion:' pattern not found in {HELM_CHART_FILE}")

         # Write back only if changes were made
         if content != original_content:
             HELM_CHART_FILE.write_text(content)
@@ -120,10 +151,10 @@ def update_helm_chart(version: str):
             # If no updates were made but patterns were found, confirm it's up-to-date overall
             print(f"Already up-to-date: {HELM_CHART_FILE} (version {version})")

     except Exception as e:
         print(f"Error processing {HELM_CHART_FILE}: {e}")


 def update_readme(version_with_v: str):
     """Updates Docker image tags in README.md"""
     if not README_FILE.exists():
@@ -133,7 +164,7 @@ def update_readme(version_with_v: str):
     try:
         content = README_FILE.read_text()
         # Regex to find and capture current ghcr.io/.../kokoro-fastapi-(cpu|gpu):vX.Y.Z
-        pattern = r'(ghcr\.io/remsky/kokoro-fastapi-(?:cpu|gpu)):(v\d+\.\d+\.\d+)'
+        pattern = r"(ghcr\.io/remsky/kokoro-fastapi-(?:cpu|gpu)):(v\d+\.\d+\.\d+)"
         matches = list(re.finditer(pattern, content))  # Find all occurrences

         if not matches:
@@ -148,15 +179,19 @@ def update_readme(version_with_v: str):

         if updated_needed:
             # Perform replacement on all occurrences
-            new_content = re.sub(pattern, rf'\1:{version_with_v}', content)
+            new_content = re.sub(pattern, rf"\1:{version_with_v}", content)
             README_FILE.write_text(new_content)
             print(f"Updated Docker image tags in {README_FILE} to {version_with_v}")
         else:
-            print(f"Already up-to-date: Docker image tags in {README_FILE} (version {version_with_v})")
+            print(
+                f"Already up-to-date: Docker image tags in {README_FILE} (version {version_with_v})"
+            )

         # Check for ':latest' tag usage remains the same
-        if ':latest' in content:
-            print(f"Warning: Found ':latest' tag in {README_FILE}. Consider updating manually if needed.")
+        if ":latest" in content:
+            print(
+                f"Warning: Found ':latest' tag in {README_FILE}. Consider updating manually if needed."
+            )

     except Exception as e:
         print(f"Error processing {README_FILE}: {e}")
@@ -171,7 +206,9 @@ def main():
     try:
         version = VERSION_FILE.read_text().strip()
         if not re.match(r"^\d+\.\d+\.\d+$", version):
-            print(f"Error: Invalid version format '{version}' in {VERSION_FILE}. Expected X.Y.Z")
+            print(
+                f"Error: Invalid version format '{version}' in {VERSION_FILE}. Expected X.Y.Z"
+            )
             return
     except Exception as e:
         print(f"Error reading {VERSION_FILE}: {e}")
@@ -192,5 +229,6 @@ def main():
     print("-" * 20)
     print("Version update script finished.")

+
 if __name__ == "__main__":
     main()
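The appVersion handling in the hunks above rests on a single pattern that accepts a quoted or unquoted value and always writes the new value back quoted. A small sketch of that behaviour on an in-memory string (the Chart.yaml contents and target version are made up for illustration; the real file lives at charts/kokoro-fastapi/Chart.yaml):

```python
import re

# Illustrative Chart.yaml content, not the repository's actual chart.
chart = "apiVersion: v2\nname: kokoro-fastapi\nversion: 0.2.0\nappVersion: 0.2.0\n"

version = "0.3.0"  # illustrative target version

# Same pattern as the script: optional quotes around the current value.
app_version_pattern = r"^(appVersion:\s*)(\"?)([^\"\s]+)(\"?)"
match = re.search(app_version_pattern, chart, flags=re.MULTILINE)

if match and match.group(3) != version:
    # Always write the value back quoted, mirroring the script above.
    replacement = f'{match.group(1)}"{version}"'
    chart = re.sub(app_version_pattern, replacement, chart, count=1, flags=re.MULTILINE)

print(chart)
# ...the appVersion line now reads: appVersion: "0.3.0"
```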

@@ -31,8 +31,11 @@ async def mock_tts_service(mock_model_manager, mock_voice_manager):


 @pytest.fixture(autouse=True)
-async def setup_mocks(monkeypatch, mock_model_manager, mock_voice_manager, mock_tts_service):
+async def setup_mocks(
+    monkeypatch, mock_model_manager, mock_voice_manager, mock_tts_service
+):
     """Setup global mocks for UI tests"""

     async def mock_get_model():
         return mock_model_manager

@@ -44,4 +47,6 @@ async def setup_mocks(monkeypatch, mock_model_manager, mock_voice_manager, mock_tts_service):

     monkeypatch.setattr("api.src.inference.model_manager.get_manager", mock_get_model)
     monkeypatch.setattr("api.src.inference.voice_manager.get_manager", mock_get_voice)
-    monkeypatch.setattr("api.src.services.tts_service.TTSService.create", mock_create_service)
+    monkeypatch.setattr(
+        "api.src.services.tts_service.TTSService.create", mock_create_service
+    )

@@ -59,9 +59,11 @@ def test_check_api_status_connection_error():

 def test_text_to_speech_success(mock_response, tmp_path):
     """Test successful speech generation"""
-    with patch("requests.post", return_value=mock_response({})), patch(
-        "ui.lib.api.OUTPUTS_DIR", str(tmp_path)
-    ), patch("builtins.open", mock_open()) as mock_file:
+    with (
+        patch("requests.post", return_value=mock_response({})),
+        patch("ui.lib.api.OUTPUTS_DIR", str(tmp_path)),
+        patch("builtins.open", mock_open()) as mock_file,
+    ):
         result = api.text_to_speech("test text", "voice1", "mp3", 1.0)

         assert result is not None
@@ -116,9 +118,11 @@ def test_text_to_speech_api_params(mock_response, tmp_path):
     ]

     for input_voice, expected_voice in test_cases:
-        with patch("requests.post") as mock_post, patch(
-            "ui.lib.api.OUTPUTS_DIR", str(tmp_path)
-        ), patch("builtins.open", mock_open()):
+        with (
+            patch("requests.post") as mock_post,
+            patch("ui.lib.api.OUTPUTS_DIR", str(tmp_path)),
+            patch("builtins.open", mock_open()),
+        ):
             mock_post.return_value = mock_response({})
             api.text_to_speech("test text", input_voice, "mp3", 1.5)

@@ -149,11 +153,15 @@ def test_text_to_speech_output_filename(mock_response, tmp_path):
     ]

     for input_voice, filename_check in test_cases:
-        with patch("requests.post", return_value=mock_response({})), patch(
-            "ui.lib.api.OUTPUTS_DIR", str(tmp_path)
-        ), patch("builtins.open", mock_open()) as mock_file:
+        with (
+            patch("requests.post", return_value=mock_response({})),
+            patch("ui.lib.api.OUTPUTS_DIR", str(tmp_path)),
+            patch("builtins.open", mock_open()) as mock_file,
+        ):
             result = api.text_to_speech("test text", input_voice, "mp3", 1.0)

             assert result is not None
-            assert filename_check(result), f"Expected voice pattern not found in filename: {result}"
+            assert filename_check(result), (
+                f"Expected voice pattern not found in filename: {result}"
+            )
             mock_file.assert_called_once()

@@ -15,8 +15,9 @@ def mock_dirs(tmp_path):
     inputs_dir.mkdir()
     outputs_dir.mkdir()

-    with patch("ui.lib.files.INPUTS_DIR", str(inputs_dir)), patch(
-        "ui.lib.files.OUTPUTS_DIR", str(outputs_dir)
+    with (
+        patch("ui.lib.files.INPUTS_DIR", str(inputs_dir)),
+        patch("ui.lib.files.OUTPUTS_DIR", str(outputs_dir)),
     ):
         yield inputs_dir, outputs_dir


@@ -62,8 +62,9 @@ def test_interface_html_links():
 def test_update_status_available(mock_timer):
     """Test status update when service is available"""
     voices = ["voice1", "voice2"]
-    with patch("ui.lib.api.check_api_status", return_value=(True, voices)), patch(
-        "gradio.Timer", return_value=mock_timer
+    with (
+        patch("ui.lib.api.check_api_status", return_value=(True, voices)),
+        patch("gradio.Timer", return_value=mock_timer),
     ):
         demo = create_interface()

@@ -81,8 +82,9 @@ def test_update_status_available(mock_timer):

 def test_update_status_unavailable(mock_timer):
     """Test status update when service is unavailable"""
-    with patch("ui.lib.api.check_api_status", return_value=(False, [])), patch(
-        "gradio.Timer", return_value=mock_timer
+    with (
+        patch("ui.lib.api.check_api_status", return_value=(False, [])),
+        patch("gradio.Timer", return_value=mock_timer),
     ):
         demo = create_interface()
         update_fn = mock_timer.events[0].fn
@@ -97,9 +99,10 @@ def test_update_status_unavailable(mock_timer):

 def test_update_status_error(mock_timer):
     """Test status update when an error occurs"""
-    with patch(
-        "ui.lib.api.check_api_status", side_effect=Exception("Test error")
-    ), patch("gradio.Timer", return_value=mock_timer):
+    with (
+        patch("ui.lib.api.check_api_status", side_effect=Exception("Test error")),
+        patch("gradio.Timer", return_value=mock_timer),
+    ):
         demo = create_interface()
         update_fn = mock_timer.events[0].fn

@@ -113,8 +116,9 @@ def test_update_status_error(mock_timer):

 def test_timer_configuration(mock_timer):
     """Test timer configuration"""
-    with patch("ui.lib.api.check_api_status", return_value=(False, [])), patch(
-        "gradio.Timer", return_value=mock_timer
+    with (
+        patch("ui.lib.api.check_api_status", return_value=(False, [])),
+        patch("gradio.Timer", return_value=mock_timer),
     ):
         demo = create_interface()

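Several of the test hunks above swap the old `with patch(...), patch(...)` continuation style for a parenthesized group of context managers, which the formatter then lays out one per line. The parenthesized form needs Python 3.10 or newer; a minimal sketch of both styles with stand-in patch targets (not this project's real modules):

```python
from unittest.mock import mock_open, patch

# Old style: chained context managers, wrapped mid-call once the line gets long.
with patch("os.getcwd", return_value="/tmp"), patch(
    "builtins.open", mock_open(read_data="hello")
) as mocked_open:
    print(open("ignored.txt").read())  # -> hello
    mocked_open.assert_called_once_with("ignored.txt")

# New style (Python 3.10+): one context manager per line inside parentheses.
with (
    patch("os.getcwd", return_value="/tmp"),
    patch("builtins.open", mock_open(read_data="hello")) as mocked_open,
):
    print(open("ignored.txt").read())  # -> hello
    mocked_open.assert_called_once_with("ignored.txt")
```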

@@ -13,9 +13,7 @@ def create_input_column(disable_local_saving: bool = False) -> Tuple[gr.Column,
         )

         # Always show file upload but handle differently based on disable_local_saving
-        file_upload = gr.File(
-            label="Upload Text File (.txt)", file_types=[".txt"]
-        )
+        file_upload = gr.File(label="Upload Text File (.txt)", file_types=[".txt"])

         if not disable_local_saving:
             # Show full interface with tabs when saving is enabled
@@ -24,7 +22,9 @@ def create_input_column(disable_local_saving: bool = False) -> Tuple[gr.Column,
                 tabs.selected = 0
                 # Direct Input Tab
                 with gr.TabItem("Direct Input"):
-                    text_submit_direct = gr.Button("Generate Speech", variant="primary", size="lg")
+                    text_submit_direct = gr.Button(
+                        "Generate Speech", variant="primary", size="lg"
+                    )

                 # File Input Tab
                 with gr.TabItem("From File"):
@@ -48,7 +48,9 @@ def create_input_column(disable_local_saving: bool = False) -> Tuple[gr.Column,
                     )
         else:
             # Just show the generate button when saving is disabled
-            text_submit_direct = gr.Button("Generate Speech", variant="primary", size="lg")
+            text_submit_direct = gr.Button(
+                "Generate Speech", variant="primary", size="lg"
+            )
             tabs = None
             input_files_list = None
             file_preview = None

@@ -12,7 +12,7 @@ def create_output_column(disable_local_saving: bool = False) -> Tuple[gr.Column,
         audio_output = gr.Audio(
             label="Generated Speech",
             type="filepath",
-            waveform_options={"waveform_color": "#4C87AB"}
+            waveform_options={"waveform_color": "#4C87AB"},
         )

         # Create file-related components with visible=False when local saving is disabled

@@ -58,17 +58,21 @@ def setup_event_handlers(components: dict, disable_local_saving: bool = False):

     def handle_file_upload(file):
         if file is None:
-            return "" if disable_local_saving else [gr.update(choices=files.list_input_files())]
+            return (
+                ""
+                if disable_local_saving
+                else [gr.update(choices=files.list_input_files())]
+            )

         try:
             # Read the file content
-            with open(file.name, 'r', encoding='utf-8') as f:
+            with open(file.name, "r", encoding="utf-8") as f:
                 text_content = f.read()

             if disable_local_saving:
                 # When saving is disabled, put content directly in text input
                 # Normalize whitespace by replacing newlines with spaces
-                normalized_text = ' '.join(text_content.split())
+                normalized_text = " ".join(text_content.split())
                 return normalized_text
             else:
                 # When saving is enabled, save file and update dropdown
@@ -88,7 +92,11 @@ def setup_event_handlers(components: dict, disable_local_saving: bool = False):

         except Exception as e:
             print(f"Error handling file: {e}")
-            return "" if disable_local_saving else [gr.update(choices=files.list_input_files())]
+            return (
+                ""
+                if disable_local_saving
+                else [gr.update(choices=files.list_input_files())]
+            )

     def generate_from_text(text, voice, format, speed):
         """Generate speech from direct text input"""
@@ -203,7 +211,11 @@ def setup_event_handlers(components: dict, disable_local_saving: bool = False):
     components["input"]["file_upload"].upload(
         fn=handle_file_upload,
         inputs=[components["input"]["file_upload"]],
-        outputs=[components["input"]["text_input"] if disable_local_saving else components["input"]["file_select"]],
+        outputs=[
+            components["input"]["text_input"]
+            if disable_local_saving
+            else components["input"]["file_select"]
+        ],
     )

     if components["output"]["play_btn"] is not None:
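The last hunk wires the upload handler to the file component so that the handler's return value lands either in the text box or in the file dropdown. A compressed sketch of that wiring, keeping only the "saving disabled" branch and using simplified component names (not the fuller component dictionary used above):

```python
import gradio as gr


def handle_upload(file):
    # Mirrors the handler above: read the uploaded file and return its text,
    # normalizing whitespace so it fits a single-line style text box.
    if file is None:
        return ""
    with open(file.name, "r", encoding="utf-8") as f:
        return " ".join(f.read().split())


with gr.Blocks() as demo:
    text_input = gr.Textbox(label="Text to speak")
    file_upload = gr.File(label="Upload Text File (.txt)", file_types=[".txt"])
    # The upload event feeds the handler's return value into the text box.
    file_upload.upload(fn=handle_upload, inputs=[file_upload], outputs=[text_input])

# demo.launch()  # uncomment to try it locally
```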