mirror of
https://github.com/remsky/Kokoro-FastAPI.git
synced 2025-04-13 09:39:17 +00:00
Add Gradio web interface + tests
This commit is contained in:
parent
19321eabb2
commit
e749b3bc88
21 changed files with 1048 additions and 120 deletions
|
@ -1,5 +1,7 @@
|
||||||
[run]
|
[run]
|
||||||
source = api
|
source =
|
||||||
|
api
|
||||||
|
ui
|
||||||
omit =
|
omit =
|
||||||
Kokoro-82M/*
|
Kokoro-82M/*
|
||||||
MagicMock/*
|
MagicMock/*
|
||||||
|
|
|
@ -4,6 +4,10 @@ Notable changes to this project will be documented in this file.
|
||||||
|
|
||||||
## 2024-01-09
|
## 2024-01-09
|
||||||
|
|
||||||
|
### Added
|
||||||
|
- Gradio Web Interface:
|
||||||
|
- Added a simple web UI utility for generating audio from direct text input or an uploaded .txt file
|
||||||
|
|
||||||
### Modified
|
### Modified
|
||||||
#### Configuration Changes
|
#### Configuration Changes
|
||||||
- Updated Docker configurations:
|
- Updated Docker configurations:
|
||||||
|
|
131
README.md
131
README.md
|
@ -3,42 +3,55 @@
|
||||||
</p>
|
</p>
|
||||||
|
|
||||||
# Kokoro TTS API
|
# Kokoro TTS API
|
||||||
[]()
|
[]()
|
||||||
[]()
|
[]()
|
||||||
[](https://huggingface.co/hexgrad/Kokoro-82M/tree/c3b0d86e2a980e027ef71c28819ea02e351c2667)
|
[](https://huggingface.co/hexgrad/Kokoro-82M/tree/c3b0d86e2a980e027ef71c28819ea02e351c2667)
|
||||||
|
|
||||||
FastAPI wrapper for [Kokoro-82M](https://huggingface.co/hexgrad/Kokoro-82M) text-to-speech model, providing an OpenAI-compatible endpoint with:
|
Dockerized FastAPI wrapper for [Kokoro-82M](https://huggingface.co/hexgrad/Kokoro-82M) text-to-speech model
|
||||||
|
- OpenAI-compatible Speech endpoint, with voice combination functionality
|
||||||
- NVIDIA GPU accelerated inference (or CPU) option
|
- NVIDIA GPU accelerated inference (or CPU) option
|
||||||
|
- very fast generation time (~35x real time factor)
|
||||||
- automatic chunking/stitching for long texts
|
- automatic chunking/stitching for long texts
|
||||||
- very fast generation time (~35-49x RTF)
|
- simple web UI utility for audio generation
|
||||||
|
|
||||||
## Quick Start
|
<details open>
|
||||||
|
<summary><b>Quick Start</b></summary>
|
||||||
|
|
||||||
|
The service can be accessed through either the API endpoints or the Gradio web interface.
|
||||||
|
|
||||||
1. Install prerequisites:
|
1. Install prerequisites:
|
||||||
- Install [Docker Desktop](https://www.docker.com/products/docker-desktop/)
|
- Install [Docker Desktop](https://www.docker.com/products/docker-desktop/) + [Git](https://git-scm.com/downloads)
|
||||||
- Install [Git](https://git-scm.com/downloads) (or download and extract zip)
|
- Clone and start the service:
|
||||||
|
```bash
|
||||||
|
git clone https://github.com/remsky/Kokoro-FastAPI.git
|
||||||
|
cd Kokoro-FastAPI
|
||||||
|
docker compose up --build
|
||||||
|
```
|
||||||
|
2. Run locally as an OpenAI-Compatible Speech Endpoint
|
||||||
|
```python
|
||||||
|
from openai import OpenAI
|
||||||
|
client = OpenAI(
|
||||||
|
base_url="http://localhost:8880",
|
||||||
|
api_key="not-needed"
|
||||||
|
)
|
||||||
|
|
||||||
2. Clone and start the service:
|
response = client.audio.speech.create(
|
||||||
```bash
|
model="kokoro",
|
||||||
# Clone repository
|
voice="af_bella",
|
||||||
git clone https://github.com/remsky/Kokoro-FastAPI.git
|
input="Hello world!",
|
||||||
cd Kokoro-FastAPI
|
response_format="mp3"
|
||||||
|
)
|
||||||
|
response.stream_to_file("output.mp3")
|
||||||
|
```
|
||||||
|
|
||||||
# For GPU acceleration (requires NVIDIA GPU):
|
Or use the Gradio web interface at http://localhost:7860
|
||||||
docker compose up --build
|
<p align="center">
|
||||||
|
<img src="ui/GradioScreenShot.png" width="80%" alt="Gradio Web UI Screenshot" style="border: 2px solid #333; padding: 10px;">
|
||||||
|
</p>
|
||||||
|
</details>
|
||||||
|
<details>
|
||||||
|
<summary><b>OpenAI-Compatible Speech Endpoint</b></summary>
|
||||||
|
|
||||||
# For CPU-only deployment (~10x slower, but doesn't require an NVIDIA GPU):
|
|
||||||
docker compose -f docker-compose.cpu.yml up --build
|
|
||||||
```
|
|
||||||
Quick tests (run from another terminal):
|
|
||||||
```bash
|
|
||||||
# Test OpenAI Compatibility
|
|
||||||
python examples/test_openai_tts.py
|
|
||||||
# Test all available voices
|
|
||||||
python examples/test_all_voices.py
|
|
||||||
```
|
|
||||||
|
|
||||||
## OpenAI-Compatible API
|
|
||||||
```python
|
```python
|
||||||
# Using OpenAI's Python library
|
# Using OpenAI's Python library
|
||||||
from openai import OpenAI
|
from openai import OpenAI
|
||||||
|
@ -77,16 +90,26 @@ with open("output.mp3", "wb") as f:
|
||||||
f.write(response.content)
|
f.write(response.content)
|
||||||
```
|
```
|
||||||
|
|
||||||
## Voice Combination
|
Quick tests (run from another terminal):
|
||||||
|
```bash
|
||||||
|
python examples/test_openai_tts.py # Test OpenAI Compatibility
|
||||||
|
python examples/test_all_voices.py # Test all available voices
|
||||||
|
```
|
||||||
|
</details>
|
||||||
|
|
||||||
|
<details>
|
||||||
|
<summary><b>Voice Combination</b></summary>
|
||||||
|
|
||||||
Combine voices and generate audio:
|
Combine voices and generate audio:
|
||||||
```python
|
```python
|
||||||
import requests
|
import requests
|
||||||
|
response = requests.get("http://localhost:8880/v1/audio/voices")
|
||||||
|
voices = response.json()["voices"]
|
||||||
|
|
||||||
# Create combined voice (saved locally on server)
|
# Create combined voice (saves locally on server)
|
||||||
response = requests.post(
|
response = requests.post(
|
||||||
"http://localhost:8880/v1/audio/voices/combine",
|
"http://localhost:8880/v1/audio/voices/combine",
|
||||||
json=["af_bella", "af_sarah"]
|
json=[voices[0], voices[1]]
|
||||||
)
|
)
|
||||||
combined_voice = response.json()["voice"]
|
combined_voice = response.json()["voice"]
|
||||||
|
|
||||||
|
@ -100,8 +123,27 @@ response = requests.post(
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
```
|
```
|
||||||
|
<p align="center">
|
||||||
|
<img src="examples/benchmarks/analysis_comparison.png" width="60%" alt="Voice Analysis Comparison" style="border: 2px solid #333; padding: 10px;">
|
||||||
|
</p>
|
||||||
|
</details>
|
||||||
|
|
||||||
## Performance Benchmarks
|
<details>
|
||||||
|
<summary><b>Gradio Web Utility</b></summary>
|
||||||
|
|
||||||
|
Access the interactive web UI at http://localhost:7860 after starting the service. Features include:
|
||||||
|
- Voice/format/speed selection
|
||||||
|
- Audio playback and download
|
||||||
|
- Text file or direct input
|
||||||
|
|
||||||
|
If you only want the API, comment out the `gradio-ui` service (the `gradio-ui:` line and everything beneath it) in `docker-compose.yml`.
|
||||||
|
|
||||||
|
Currently, voices created via the API are selectable in the web UI, but voice combination/creation has not yet been added to the UI itself.
|
||||||
|
</details>
|
||||||
|
|
||||||
|
|
||||||
|
<details>
|
||||||
|
<summary><b>Performance Benchmarks</b></summary>
|
||||||
|
|
||||||
Benchmarking was performed on generation via the local API using text lengths up to feature-length books (~1.5 hours output), measuring processing time and realtime factor. Tests were run on:
|
Benchmarking was performed on generation via the local API using text lengths up to feature-length books (~1.5 hours output), measuring processing time and realtime factor. Tests were run on:
|
||||||
- Windows 11 Home w/ WSL2
|
- Windows 11 Home w/ WSL2
|
||||||
|
@ -119,10 +161,22 @@ Benchmarking was performed on generation via the local API using text lengths up
|
||||||
Key Performance Metrics:
|
Key Performance Metrics:
|
||||||
- Realtime Factor: Ranges between 35-49x (generation time to output audio length)
|
- Realtime Factor: Ranges between 35-49x (generation time to output audio length)
|
||||||
- Average Processing Rate: 137.67 tokens/second (cl100k_base)
|
- Average Processing Rate: 137.67 tokens/second (cl100k_base)
|
||||||
|
</details>
|
||||||
|
<details>
|
||||||
|
<summary><b>GPU Vs. CPU</b></summary>
|
||||||
|
|
||||||
## Features
|
```bash
|
||||||
|
# GPU: Requires NVIDIA GPU with CUDA 12.1 support
|
||||||
|
docker compose up --build
|
||||||
|
|
||||||
- OpenAI-compatible API endpoints
|
# CPU: ~10x slower than GPU inference
|
||||||
|
docker compose -f docker-compose.cpu.yml up --build
|
||||||
|
```
|
||||||
|
</details>
|
||||||
|
<details>
|
||||||
|
<summary><b>Features</b></summary>
|
||||||
|
|
||||||
|
- OpenAI-compatible API endpoints (with optional Gradio Web UI)
|
||||||
- GPU-accelerated inference (if desired)
|
- GPU-accelerated inference (if desired)
|
||||||
- Multiple audio formats: mp3, wav, opus, flac, (aac & pcm not implemented)
|
- Multiple audio formats: mp3, wav, opus, flac, (aac & pcm not implemented)
|
||||||
- Natural Boundary Detection:
|
- Natural Boundary Detection:
|
||||||
|
@ -131,19 +185,21 @@ Key Performance Metrics:
|
||||||
- Averages model weights of any existing voicepacks
|
- Averages model weights of any existing voicepacks
|
||||||
- Saves generated voicepacks for future use
|
- Saves generated voicepacks for future use
|
||||||
|
|
||||||
<p align="center">
|
|
||||||
<img src="examples/benchmarks/analysis_comparison.png" width="60%" alt="Voice Analysis Comparison" style="border: 2px solid #333; padding: 10px;">
|
|
||||||
</p>
|
|
||||||
|
|
||||||
*Note: CPU Inference is currently a very basic implementation, and not heavily tested*
|
*Note: CPU Inference is currently a very basic implementation, and not heavily tested*
|
||||||
|
</details>
|
||||||
|
|
||||||
## Model
|
<details open>
|
||||||
|
<summary><b>Model</b></summary>
|
||||||
|
|
||||||
This API uses the [Kokoro-82M](https://huggingface.co/hexgrad/Kokoro-82M) model from HuggingFace.
|
This API uses the [Kokoro-82M](https://huggingface.co/hexgrad/Kokoro-82M) model from HuggingFace.
|
||||||
|
|
||||||
Visit the model page for more details about training, architecture, and capabilities. I have no affiliation with any of their work, and produced this wrapper for ease of use and personal projects.
|
Visit the model page for more details about training, architecture, and capabilities. I have no affiliation with any of their work, and produced this wrapper for ease of use and personal projects.
|
||||||
|
</details>
|
||||||
|
|
||||||
## License
|
<details>
|
||||||
|
<summary><b>License</b></summary>
|
||||||
|
|
||||||
This project is licensed under the Apache License 2.0 - see below for details:
|
This project is licensed under the Apache License 2.0 - see below for details:
|
||||||
|
|
||||||
|
@ -152,3 +208,4 @@ This project is licensed under the Apache License 2.0 - see below for details:
|
||||||
- The inference code adapted from StyleTTS2 is MIT licensed
|
- The inference code adapted from StyleTTS2 is MIT licensed
|
||||||
|
|
||||||
The full Apache 2.0 license text can be found at: https://www.apache.org/licenses/LICENSE-2.0
|
The full Apache 2.0 license text can be found at: https://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
</details>
|
||||||
|
|
|
@ -46,14 +46,14 @@ services:
|
||||||
model-fetcher:
|
model-fetcher:
|
||||||
condition: service_healthy
|
condition: service_healthy
|
||||||
|
|
||||||
# # Gradio UI service
|
# Gradio UI service [Comment out everything below if you don't need it]
|
||||||
# gradio-ui:
|
gradio-ui:
|
||||||
# build:
|
build:
|
||||||
# context: ./ui
|
context: ./ui
|
||||||
# ports:
|
ports:
|
||||||
# - "7860:7860"
|
- "7860:7860"
|
||||||
# volumes:
|
volumes:
|
||||||
# - ./ui/data:/app/ui/data
|
- ./ui/data:/app/ui/data
|
||||||
# - ./ui/app.py:/app/app.py # Mount app.py for hot reload
|
- ./ui/app.py:/app/app.py # Mount app.py for hot reload
|
||||||
# environment:
|
environment:
|
||||||
# - GRADIO_WATCH=True # Enable hot reloading
|
- GRADIO_WATCH=True # Enable hot reloading
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
[pytest]
|
[pytest]
|
||||||
testpaths = api/tests
|
testpaths = api/tests ui/tests
|
||||||
python_files = test_*.py
|
python_files = test_*.py
|
||||||
addopts = -v --tb=short --cov=api --cov-report=term-missing --cov-config=.coveragerc
|
addopts = -v --tb=short --cov=api --cov=ui --cov-report=term-missing --cov-config=.coveragerc
|
||||||
pythonpath = .
|
pythonpath = .
|
||||||
|
|
BIN
ui/GradioScreenShot.png
Normal file
BIN
ui/GradioScreenShot.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 113 KiB |
|
@ -7,7 +7,11 @@ from .config import API_URL, OUTPUTS_DIR
|
||||||
def check_api_status() -> Tuple[bool, List[str]]:
|
def check_api_status() -> Tuple[bool, List[str]]:
|
||||||
"""Check TTS service status and get available voices."""
|
"""Check TTS service status and get available voices."""
|
||||||
try:
|
try:
|
||||||
response = requests.get(f"{API_URL}/v1/audio/voices", timeout=5)
|
# Use a longer timeout during startup
|
||||||
|
response = requests.get(
|
||||||
|
f"{API_URL}/v1/audio/voices",
|
||||||
|
timeout=30 # Increased timeout for initial startup period
|
||||||
|
)
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
voices = response.json().get("voices", [])
|
voices = response.json().get("voices", [])
|
||||||
if voices:
|
if voices:
|
||||||
|
@ -15,7 +19,10 @@ def check_api_status() -> Tuple[bool, List[str]]:
|
||||||
print("No voices found in response")
|
print("No voices found in response")
|
||||||
return False, []
|
return False, []
|
||||||
except requests.exceptions.Timeout:
|
except requests.exceptions.Timeout:
|
||||||
print("API request timed out")
|
print("API request timed out (waiting for service startup)")
|
||||||
|
return False, []
|
||||||
|
except requests.exceptions.ConnectionError as e:
|
||||||
|
print(f"Connection error (service may be starting up): {str(e)}")
|
||||||
return False, []
|
return False, []
|
||||||
except requests.exceptions.RequestException as e:
|
except requests.exceptions.RequestException as e:
|
||||||
print(f"API request failed: {str(e)}")
|
print(f"API request failed: {str(e)}")
|
||||||
|
|
|
@ -2,4 +2,4 @@ from .input import create_input_column
|
||||||
from .model import create_model_column
|
from .model import create_model_column
|
||||||
from .output import create_output_column
|
from .output import create_output_column
|
||||||
|
|
||||||
__all__ = ['create_input_column', 'create_model_column', 'create_output_column']
|
__all__ = ["create_input_column", "create_model_column", "create_output_column"]
|
||||||
|
|
|
@ -6,6 +6,8 @@ def create_input_column() -> Tuple[gr.Column, dict]:
|
||||||
"""Create the input column with text input and file handling."""
|
"""Create the input column with text input and file handling."""
|
||||||
with gr.Column(scale=1) as col:
|
with gr.Column(scale=1) as col:
|
||||||
with gr.Tabs() as tabs:
|
with gr.Tabs() as tabs:
|
||||||
|
# Set first tab as selected by default
|
||||||
|
tabs.selected = 0
|
||||||
# Direct Input Tab
|
# Direct Input Tab
|
||||||
with gr.TabItem("Direct Input"):
|
with gr.TabItem("Direct Input"):
|
||||||
text_input = gr.Textbox(
|
text_input = gr.Textbox(
|
||||||
|
@ -13,6 +15,11 @@ def create_input_column() -> Tuple[gr.Column, dict]:
|
||||||
placeholder="Enter text here...",
|
placeholder="Enter text here...",
|
||||||
lines=4
|
lines=4
|
||||||
)
|
)
|
||||||
|
text_submit = gr.Button(
|
||||||
|
"Generate Speech",
|
||||||
|
variant="primary",
|
||||||
|
size="lg"
|
||||||
|
)
|
||||||
|
|
||||||
# File Input Tab
|
# File Input Tab
|
||||||
with gr.TabItem("From File"):
|
with gr.TabItem("From File"):
|
||||||
|
@ -35,12 +42,27 @@ def create_input_column() -> Tuple[gr.Column, dict]:
|
||||||
lines=4
|
lines=4
|
||||||
)
|
)
|
||||||
|
|
||||||
|
with gr.Row():
|
||||||
|
file_submit = gr.Button(
|
||||||
|
"Generate Speech",
|
||||||
|
variant="primary",
|
||||||
|
size="lg"
|
||||||
|
)
|
||||||
|
clear_files = gr.Button(
|
||||||
|
"Clear Files",
|
||||||
|
variant="secondary",
|
||||||
|
size="lg"
|
||||||
|
)
|
||||||
|
|
||||||
components = {
|
components = {
|
||||||
"tabs": tabs,
|
"tabs": tabs,
|
||||||
"text_input": text_input,
|
"text_input": text_input,
|
||||||
"file_select": input_files_list,
|
"file_select": input_files_list,
|
||||||
"file_upload": file_upload,
|
"file_upload": file_upload,
|
||||||
"file_preview": file_preview
|
"file_preview": file_preview,
|
||||||
|
"text_submit": text_submit,
|
||||||
|
"file_submit": file_submit,
|
||||||
|
"clear_files": clear_files
|
||||||
}
|
}
|
||||||
|
|
||||||
return col, components
|
return col, components
|
||||||
|
|
|
@ -10,10 +10,9 @@ def create_model_column(voice_ids: Optional[list] = None) -> Tuple[gr.Column, di
|
||||||
with gr.Column(scale=1) as col:
|
with gr.Column(scale=1) as col:
|
||||||
gr.Markdown("### Model Settings")
|
gr.Markdown("### Model Settings")
|
||||||
|
|
||||||
# Status button with embedded status
|
# Status button starts in waiting state
|
||||||
is_available, _ = api.check_api_status()
|
|
||||||
status_btn = gr.Button(
|
status_btn = gr.Button(
|
||||||
f"Checking TTS Service: {'Available' if is_available else 'Not Yet Available'}",
|
"⌛ TTS Service: Waiting for Service...",
|
||||||
variant="secondary"
|
variant="secondary"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -36,18 +35,11 @@ def create_model_column(voice_ids: Optional[list] = None) -> Tuple[gr.Column, di
|
||||||
label="Speed"
|
label="Speed"
|
||||||
)
|
)
|
||||||
|
|
||||||
submit_btn = gr.Button(
|
|
||||||
"Generate Speech",
|
|
||||||
variant="primary",
|
|
||||||
size="lg"
|
|
||||||
)
|
|
||||||
|
|
||||||
components = {
|
components = {
|
||||||
"status_btn": status_btn,
|
"status_btn": status_btn,
|
||||||
"voice": voice_input,
|
"voice": voice_input,
|
||||||
"format": format_input,
|
"format": format_input,
|
||||||
"speed": speed_input,
|
"speed": speed_input
|
||||||
"submit": submit_btn
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return col, components
|
return col, components
|
||||||
|
|
|
@ -27,11 +27,14 @@ def create_output_column() -> Tuple[gr.Column, dict]:
|
||||||
visible=False
|
visible=False
|
||||||
)
|
)
|
||||||
|
|
||||||
|
clear_outputs = gr.Button("⚠️ Delete All Previously Generated Output Audio 🗑️", size="sm", variant="secondary")
|
||||||
|
|
||||||
components = {
|
components = {
|
||||||
"audio_output": audio_output,
|
"audio_output": audio_output,
|
||||||
"output_files": output_files,
|
"output_files": output_files,
|
||||||
"play_btn": play_btn,
|
"play_btn": play_btn,
|
||||||
"selected_audio": selected_audio
|
"selected_audio": selected_audio,
|
||||||
|
"clear_outputs": clear_outputs
|
||||||
}
|
}
|
||||||
|
|
||||||
return col, components
|
return col, components
|
||||||
|
|
|
@ -56,6 +56,30 @@ def save_text(text: str, filename: Optional[str] = None) -> Optional[str]:
|
||||||
print(f"Error saving file: {e}")
|
print(f"Error saving file: {e}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
def delete_all_input_files() -> bool:
|
||||||
|
"""Delete all files from the inputs directory. Returns True if successful."""
|
||||||
|
try:
|
||||||
|
for filename in os.listdir(INPUTS_DIR):
|
||||||
|
if filename.endswith('.txt'):
|
||||||
|
file_path = os.path.join(INPUTS_DIR, filename)
|
||||||
|
os.remove(file_path)
|
||||||
|
return True
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error deleting input files: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
def delete_all_output_files() -> bool:
|
||||||
|
"""Delete all audio files from the outputs directory. Returns True if successful."""
|
||||||
|
try:
|
||||||
|
for filename in os.listdir(OUTPUTS_DIR):
|
||||||
|
if any(filename.endswith(ext) for ext in AUDIO_FORMATS):
|
||||||
|
file_path = os.path.join(OUTPUTS_DIR, filename)
|
||||||
|
os.remove(file_path)
|
||||||
|
return True
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error deleting output files: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
def process_uploaded_file(file_path: str) -> bool:
|
def process_uploaded_file(file_path: str) -> bool:
|
||||||
"""Save uploaded file to inputs directory. Returns True if successful."""
|
"""Save uploaded file to inputs directory. Returns True if successful."""
|
||||||
if not file_path:
|
if not file_path:
|
||||||
|
|
|
@ -7,19 +7,40 @@ def setup_event_handlers(components: dict):
|
||||||
"""Set up all event handlers for the UI components."""
|
"""Set up all event handlers for the UI components."""
|
||||||
|
|
||||||
def refresh_status():
|
def refresh_status():
|
||||||
is_available, voices = api.check_api_status()
|
try:
|
||||||
status = "Available" if is_available else "Unavailable"
|
is_available, voices = api.check_api_status()
|
||||||
btn_text = f"🔄 TTS Service: {status}"
|
status = "Available" if is_available else "Waiting for Service..."
|
||||||
|
|
||||||
if is_available and voices:
|
if is_available and voices:
|
||||||
return {
|
# Preserve current voice selection if it exists and is still valid
|
||||||
components["model"]["status_btn"]: gr.update(value=btn_text),
|
current_voice = components["model"]["voice"].value
|
||||||
components["model"]["voice"]: gr.update(choices=voices, value=voices[0] if voices else None)
|
default_voice = current_voice if current_voice in voices else voices[0]
|
||||||
}
|
return [
|
||||||
return {
|
gr.update(
|
||||||
components["model"]["status_btn"]: gr.update(value=btn_text),
|
value=f"🔄 TTS Service: {status}",
|
||||||
components["model"]["voice"]: gr.update(choices=[], value=None)
|
interactive=True,
|
||||||
}
|
variant="secondary"
|
||||||
|
),
|
||||||
|
gr.update(choices=voices, value=default_voice)
|
||||||
|
]
|
||||||
|
return [
|
||||||
|
gr.update(
|
||||||
|
value=f"⌛ TTS Service: {status}",
|
||||||
|
interactive=True,
|
||||||
|
variant="secondary"
|
||||||
|
),
|
||||||
|
gr.update(choices=[], value=None)
|
||||||
|
]
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error in refresh status: {str(e)}")
|
||||||
|
return [
|
||||||
|
gr.update(
|
||||||
|
value="❌ TTS Service: Connection Error",
|
||||||
|
interactive=True,
|
||||||
|
variant="secondary"
|
||||||
|
),
|
||||||
|
gr.update(choices=[], value=None)
|
||||||
|
]
|
||||||
|
|
||||||
def handle_file_select(filename):
|
def handle_file_select(filename):
|
||||||
if filename:
|
if filename:
|
||||||
|
@ -56,46 +77,96 @@ def setup_event_handlers(components: dict):
|
||||||
|
|
||||||
return gr.update(choices=files.list_input_files())
|
return gr.update(choices=files.list_input_files())
|
||||||
|
|
||||||
def generate_speech(text, selected_file, voice, format, speed):
|
def generate_from_text(text, voice, format, speed):
|
||||||
|
"""Generate speech from direct text input"""
|
||||||
is_available, _ = api.check_api_status()
|
is_available, _ = api.check_api_status()
|
||||||
if not is_available:
|
if not is_available:
|
||||||
gr.Warning("TTS Service is currently unavailable")
|
gr.Warning("TTS Service is currently unavailable")
|
||||||
return {
|
return [
|
||||||
components["output"]["audio_output"]: None,
|
None,
|
||||||
components["output"]["output_files"]: gr.update(choices=files.list_output_files())
|
gr.update(choices=files.list_output_files())
|
||||||
}
|
]
|
||||||
|
|
||||||
# Use text input if provided, otherwise use file content
|
if not text or not text.strip():
|
||||||
if text and text.strip():
|
gr.Warning("Please enter text in the input box")
|
||||||
files.save_text(text)
|
return [
|
||||||
final_text = text
|
None,
|
||||||
elif selected_file:
|
gr.update(choices=files.list_output_files())
|
||||||
final_text = files.read_text_file(selected_file)
|
]
|
||||||
else:
|
|
||||||
gr.Warning("Please enter text or select a file")
|
|
||||||
return {
|
|
||||||
components["output"]["audio_output"]: None,
|
|
||||||
components["output"]["output_files"]: gr.update(choices=files.list_output_files())
|
|
||||||
}
|
|
||||||
|
|
||||||
result = api.text_to_speech(final_text, voice, format, speed)
|
files.save_text(text)
|
||||||
|
result = api.text_to_speech(text, voice, format, speed)
|
||||||
if result is None:
|
if result is None:
|
||||||
gr.Warning("Failed to generate speech. Please try again.")
|
gr.Warning("Failed to generate speech. Please try again.")
|
||||||
return {
|
return [
|
||||||
components["output"]["audio_output"]: None,
|
None,
|
||||||
components["output"]["output_files"]: gr.update(choices=files.list_output_files())
|
gr.update(choices=files.list_output_files())
|
||||||
}
|
]
|
||||||
|
|
||||||
return {
|
return [
|
||||||
components["output"]["audio_output"]: result,
|
result,
|
||||||
components["output"]["output_files"]: gr.update(choices=files.list_output_files(), value=os.path.basename(result))
|
gr.update(choices=files.list_output_files(), value=os.path.basename(result))
|
||||||
}
|
]
|
||||||
|
|
||||||
|
def generate_from_file(selected_file, voice, format, speed):
|
||||||
|
"""Generate speech from selected file"""
|
||||||
|
is_available, _ = api.check_api_status()
|
||||||
|
if not is_available:
|
||||||
|
gr.Warning("TTS Service is currently unavailable")
|
||||||
|
return [
|
||||||
|
None,
|
||||||
|
gr.update(choices=files.list_output_files())
|
||||||
|
]
|
||||||
|
|
||||||
|
if not selected_file:
|
||||||
|
gr.Warning("Please select a file")
|
||||||
|
return [
|
||||||
|
None,
|
||||||
|
gr.update(choices=files.list_output_files())
|
||||||
|
]
|
||||||
|
|
||||||
|
text = files.read_text_file(selected_file)
|
||||||
|
result = api.text_to_speech(text, voice, format, speed)
|
||||||
|
if result is None:
|
||||||
|
gr.Warning("Failed to generate speech. Please try again.")
|
||||||
|
return [
|
||||||
|
None,
|
||||||
|
gr.update(choices=files.list_output_files())
|
||||||
|
]
|
||||||
|
|
||||||
|
return [
|
||||||
|
result,
|
||||||
|
gr.update(choices=files.list_output_files(), value=os.path.basename(result))
|
||||||
|
]
|
||||||
|
|
||||||
def play_selected(file_path):
|
def play_selected(file_path):
|
||||||
if file_path and os.path.exists(file_path):
|
if file_path and os.path.exists(file_path):
|
||||||
return gr.update(value=file_path, visible=True)
|
return gr.update(value=file_path, visible=True)
|
||||||
return gr.update(visible=False)
|
return gr.update(visible=False)
|
||||||
|
|
||||||
|
def clear_files(voice, format, speed):
|
||||||
|
"""Delete all input files and clear UI components while preserving model settings"""
|
||||||
|
files.delete_all_input_files()
|
||||||
|
return [
|
||||||
|
gr.update(value=None, choices=[]), # file_select
|
||||||
|
None, # file_upload
|
||||||
|
gr.update(value=""), # file_preview
|
||||||
|
None, # audio_output
|
||||||
|
gr.update(choices=files.list_output_files()), # output_files
|
||||||
|
gr.update(value=voice), # voice
|
||||||
|
gr.update(value=format), # format
|
||||||
|
gr.update(value=speed) # speed
|
||||||
|
]
|
||||||
|
|
||||||
|
def clear_outputs():
|
||||||
|
"""Delete all output audio files and clear audio components"""
|
||||||
|
files.delete_all_output_files()
|
||||||
|
return [
|
||||||
|
None, # audio_output
|
||||||
|
gr.update(choices=[], value=None), # output_files
|
||||||
|
gr.update(visible=False) # selected_audio
|
||||||
|
]
|
||||||
|
|
||||||
# Connect event handlers
|
# Connect event handlers
|
||||||
components["model"]["status_btn"].click(
|
components["model"]["status_btn"].click(
|
||||||
fn=refresh_status,
|
fn=refresh_status,
|
||||||
|
@ -123,10 +194,54 @@ def setup_event_handlers(components: dict):
|
||||||
outputs=[components["output"]["selected_audio"]]
|
outputs=[components["output"]["selected_audio"]]
|
||||||
)
|
)
|
||||||
|
|
||||||
components["model"]["submit"].click(
|
# Connect clear files button
|
||||||
fn=generate_speech,
|
components["input"]["clear_files"].click(
|
||||||
|
fn=clear_files,
|
||||||
|
inputs=[
|
||||||
|
components["model"]["voice"],
|
||||||
|
components["model"]["format"],
|
||||||
|
components["model"]["speed"]
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
components["input"]["file_select"],
|
||||||
|
components["input"]["file_upload"],
|
||||||
|
components["input"]["file_preview"],
|
||||||
|
components["output"]["audio_output"],
|
||||||
|
components["output"]["output_files"],
|
||||||
|
components["model"]["voice"],
|
||||||
|
components["model"]["format"],
|
||||||
|
components["model"]["speed"]
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
# Connect submit buttons for each tab
|
||||||
|
components["input"]["text_submit"].click(
|
||||||
|
fn=generate_from_text,
|
||||||
inputs=[
|
inputs=[
|
||||||
components["input"]["text_input"],
|
components["input"]["text_input"],
|
||||||
|
components["model"]["voice"],
|
||||||
|
components["model"]["format"],
|
||||||
|
components["model"]["speed"]
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
components["output"]["audio_output"],
|
||||||
|
components["output"]["output_files"]
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
# Connect clear outputs button
|
||||||
|
components["output"]["clear_outputs"].click(
|
||||||
|
fn=clear_outputs,
|
||||||
|
outputs=[
|
||||||
|
components["output"]["audio_output"],
|
||||||
|
components["output"]["output_files"],
|
||||||
|
components["output"]["selected_audio"]
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
components["input"]["file_submit"].click(
|
||||||
|
fn=generate_from_file,
|
||||||
|
inputs=[
|
||||||
components["input"]["file_select"],
|
components["input"]["file_select"],
|
||||||
components["model"]["voice"],
|
components["model"]["voice"],
|
||||||
components["model"]["format"],
|
components["model"]["format"],
|
||||||
|
|
|
@ -5,8 +5,8 @@ from .handlers import setup_event_handlers
|
||||||
|
|
||||||
def create_interface():
|
def create_interface():
|
||||||
"""Create the main Gradio interface."""
|
"""Create the main Gradio interface."""
|
||||||
# Initial status check
|
# Skip initial status check - let the timer handle it
|
||||||
is_available, available_voices = api.check_api_status()
|
is_available, available_voices = False, []
|
||||||
|
|
||||||
with gr.Blocks(
|
with gr.Blocks(
|
||||||
title="Kokoro TTS Demo",
|
title="Kokoro TTS Demo",
|
||||||
|
@ -36,19 +36,55 @@ def create_interface():
|
||||||
|
|
||||||
# Add periodic status check with Timer
|
# Add periodic status check with Timer
|
||||||
def update_status():
|
def update_status():
|
||||||
is_available, voices = api.check_api_status()
|
try:
|
||||||
status = "Available" if is_available else "Unavailable"
|
is_available, voices = api.check_api_status()
|
||||||
return {
|
status = "Available" if is_available else "Waiting for Service..."
|
||||||
components["model"]["status_btn"]: gr.update(value=f"🔄 TTS Service: {status}"),
|
|
||||||
components["model"]["voice"]: gr.update(choices=voices, value=voices[0] if voices else None)
|
|
||||||
}
|
|
||||||
|
|
||||||
timer = gr.Timer(10, active=True) # Check every 10 seconds
|
if is_available and voices:
|
||||||
|
# Service is available, update UI and stop timer
|
||||||
|
current_voice = components["model"]["voice"].value
|
||||||
|
default_voice = current_voice if current_voice in voices else voices[0]
|
||||||
|
# Return values in same order as outputs list
|
||||||
|
return [
|
||||||
|
gr.update(
|
||||||
|
value=f"🔄 TTS Service: {status}",
|
||||||
|
interactive=True,
|
||||||
|
variant="secondary"
|
||||||
|
),
|
||||||
|
gr.update(choices=voices, value=default_voice),
|
||||||
|
gr.update(active=False) # Stop timer
|
||||||
|
]
|
||||||
|
|
||||||
|
# Service not available yet, keep checking
|
||||||
|
return [
|
||||||
|
gr.update(
|
||||||
|
value=f"⌛ TTS Service: {status}",
|
||||||
|
interactive=True,
|
||||||
|
variant="secondary"
|
||||||
|
),
|
||||||
|
gr.update(choices=[], value=None),
|
||||||
|
gr.update(active=True)
|
||||||
|
]
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error in status update: {str(e)}")
|
||||||
|
# On error, keep the timer running but show error state
|
||||||
|
return [
|
||||||
|
gr.update(
|
||||||
|
value="❌ TTS Service: Connection Error",
|
||||||
|
interactive=True,
|
||||||
|
variant="secondary"
|
||||||
|
),
|
||||||
|
gr.update(choices=[], value=None),
|
||||||
|
gr.update(active=True)
|
||||||
|
]
|
||||||
|
|
||||||
|
timer = gr.Timer(value=5) # Check every 5 seconds
|
||||||
timer.tick(
|
timer.tick(
|
||||||
fn=update_status,
|
fn=update_status,
|
||||||
outputs=[
|
outputs=[
|
||||||
components["model"]["status_btn"],
|
components["model"]["status_btn"],
|
||||||
components["model"]["voice"]
|
components["model"]["voice"],
|
||||||
|
timer
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
9
ui/tests/conftest.py
Normal file
9
ui/tests/conftest.py
Normal file
|
@ -0,0 +1,9 @@
|
||||||
|
import pytest
|
||||||
|
import gradio as gr
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def mock_gr_context():
    """Keep a fresh ``gr.Blocks`` context open for the duration of a test."""
    blocks = gr.Blocks()
    with blocks:
        yield
|
129
ui/tests/test_api.py
Normal file
129
ui/tests/test_api.py
Normal file
|
@ -0,0 +1,129 @@
|
||||||
|
import pytest
|
||||||
|
import requests
|
||||||
|
from unittest.mock import patch, mock_open
|
||||||
|
from ui.lib import api
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def mock_response():
    """Return a minimal stand-in class for an HTTP response.

    The fixture hands back the class itself (not an instance); tests build
    one with ``mock_response(json_data, status_code=200, content=b"audio data")``.
    """

    class MockResponse:
        """Carries a canned JSON payload, a status code and a raw body."""

        def __init__(self, json_data, status_code=200, content=b"audio data"):
            self.status_code = status_code
            self.content = content
            self._json = json_data

        def json(self):
            """Return the payload supplied at construction time."""
            return self._json

        def raise_for_status(self):
            """Raise ``HTTPError`` for any status code other than 200."""
            if self.status_code == 200:
                return
            raise requests.exceptions.HTTPError(f"HTTP {self.status_code}")

    return MockResponse
|
||||||
|
|
||||||
|
|
||||||
|
def test_check_api_status_success(mock_response):
    """A response that lists voices should yield ``(True, voices)``."""
    payload = {"voices": ["voice1", "voice2"]}
    canned_reply = mock_response(payload)
    with patch("requests.get", return_value=canned_reply):
        is_up, voice_list = api.check_api_status()

        assert is_up is True
        assert voice_list == ["voice1", "voice2"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_check_api_status_no_voices(mock_response):
    """An empty voice list should be reported as ``(False, [])``."""
    empty_reply = mock_response({"voices": []})
    with patch("requests.get", return_value=empty_reply):
        is_up, voice_list = api.check_api_status()

        assert is_up is False
        assert voice_list == []
|
||||||
|
|
||||||
|
|
||||||
|
def test_check_api_status_timeout():
    """A request timeout should be reported as ``(False, [])``."""
    timeout_error = requests.exceptions.Timeout
    with patch("requests.get", side_effect=timeout_error):
        is_up, voice_list = api.check_api_status()

        assert is_up is False
        assert voice_list == []
|
||||||
|
|
||||||
|
|
||||||
|
def test_check_api_status_connection_error():
    """A connection failure should be reported as ``(False, [])``."""
    connection_error = requests.exceptions.ConnectionError
    with patch("requests.get", side_effect=connection_error):
        is_up, voice_list = api.check_api_status()

        assert is_up is False
        assert voice_list == []
|
||||||
|
|
||||||
|
|
||||||
|
def test_text_to_speech_success(mock_response, tmp_path):
    """A successful POST should return an ``output_`` ``.mp3`` path and open one file."""
    with patch("requests.post", return_value=mock_response({})):
        with patch("ui.lib.api.OUTPUTS_DIR", str(tmp_path)):
            with patch("builtins.open", mock_open()) as opened:
                audio_path = api.text_to_speech("test text", "voice1", "mp3", 1.0)

                assert audio_path is not None
                assert "output_" in audio_path
                assert audio_path.endswith(".mp3")
                opened.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
|
def test_text_to_speech_empty_text():
    """Empty input text should produce ``None``."""
    audio_path = api.text_to_speech("", "voice1", "mp3", 1.0)

    assert audio_path is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_text_to_speech_timeout():
    """A timed-out POST should make speech generation return ``None``."""
    timeout_error = requests.exceptions.Timeout
    with patch("requests.post", side_effect=timeout_error):
        audio_path = api.text_to_speech("test", "voice1", "mp3", 1.0)

        assert audio_path is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_text_to_speech_request_error():
    """A generic request failure should make speech generation return ``None``."""
    request_error = requests.exceptions.RequestException
    with patch("requests.post", side_effect=request_error):
        audio_path = api.text_to_speech("test", "voice1", "mp3", 1.0)

        assert audio_path is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_status_html_available():
    """The status HTML for an available service mentions green and "Available"."""
    markup = api.get_status_html(True)

    for fragment in ("green", "Available"):
        assert fragment in markup
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_status_html_unavailable():
    """The status HTML for an unavailable service mentions red and "Unavailable"."""
    markup = api.get_status_html(False)

    for fragment in ("red", "Unavailable"):
        assert fragment in markup
|
||||||
|
|
||||||
|
|
||||||
|
def test_text_to_speech_api_params(mock_response, tmp_path):
    """Speech generation should POST the expected JSON body, headers and timeout."""
    expected_body = {
        "model": "kokoro",
        "input": "test text",
        "voice": "voice1",
        "response_format": "mp3",
        "speed": 1.5,
    }

    with patch("requests.post") as fake_post:
        with patch("ui.lib.api.OUTPUTS_DIR", str(tmp_path)):
            with patch("builtins.open", mock_open()):
                fake_post.return_value = mock_response({})
                api.text_to_speech("test text", "voice1", "mp3", 1.5)

                fake_post.assert_called_once()
                # call_args is an (args, kwargs) pair; only kwargs are inspected
                sent_kwargs = fake_post.call_args[1]

                # Request body
                assert sent_kwargs["json"] == expected_body

                # Headers and timeout
                assert sent_kwargs["headers"] == {"Content-Type": "application/json"}
                assert sent_kwargs["timeout"] == 300
|
116
ui/tests/test_components.py
Normal file
116
ui/tests/test_components.py
Normal file
|
@ -0,0 +1,116 @@
|
||||||
|
import pytest
|
||||||
|
import gradio as gr
|
||||||
|
from ui.lib.components.model import create_model_column
|
||||||
|
from ui.lib.components.output import create_output_column
|
||||||
|
from ui.lib.config import AUDIO_FORMATS
|
||||||
|
|
||||||
|
|
||||||
|
def test_create_model_column_structure():
    """``create_model_column`` returns a column plus a dict of typed widgets."""
    column, components = create_model_column(["voice1", "voice2"])

    # Return types
    assert isinstance(column, gr.Column)
    assert isinstance(components, dict)

    # Exactly these keys, each mapped to the expected Gradio widget class
    expected_types = {
        "status_btn": gr.Button,
        "voice": gr.Dropdown,
        "format": gr.Dropdown,
        "speed": gr.Slider,
    }
    assert set(components.keys()) == set(expected_types)

    for key, widget_class in expected_types.items():
        assert isinstance(components[key], widget_class)
|
||||||
|
|
||||||
|
|
||||||
|
def test_model_column_default_values():
    """Check the initial choices and values of the model column widgets."""
    voice_ids = ["voice1", "voice2"]
    _, components = create_model_column(voice_ids)

    # Voice dropdown: choices are compared as (value, label) pairs,
    # while the selected value is compared as a plain string
    voice = components["voice"]
    assert voice.choices == [(vid, vid) for vid in voice_ids]
    assert voice.value == voice_ids[0]
    assert voice.interactive is True

    # Format dropdown: same (value, label) pair convention
    audio_format = components["format"]
    assert audio_format.choices == [(fmt, fmt) for fmt in AUDIO_FORMATS]
    assert audio_format.value == "mp3"

    # Speed slider
    speed = components["speed"]
    assert speed.minimum == 0.5
    assert speed.maximum == 2.0
    assert speed.value == 1.0
    assert speed.step == 0.1
|
||||||
|
|
||||||
|
|
||||||
|
def test_model_column_no_voices():
    """Without voice IDs the voice dropdown has no choices and no value."""
    _, components = create_model_column()
    voice = components["voice"]

    assert voice.choices == []
    assert voice.value is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_create_output_column_structure():
    """``create_output_column`` returns a column plus a dict of typed widgets."""
    column, components = create_output_column()

    # Return types
    assert isinstance(column, gr.Column)
    assert isinstance(components, dict)

    # Exactly these keys, each mapped to the expected Gradio widget class
    expected_types = {
        "audio_output": gr.Audio,
        "output_files": gr.Dropdown,
        "play_btn": gr.Button,
        "selected_audio": gr.Audio,
        "clear_outputs": gr.Button,
    }
    assert set(components.keys()) == set(expected_types)

    for key, widget_class in expected_types.items():
        assert isinstance(components[key], widget_class)
|
||||||
|
|
||||||
|
|
||||||
|
def test_output_column_configuration():
    """Check labels, types, sizes and visibility of the output column widgets."""
    _, components = create_output_column()

    # Generated-audio player
    audio_output = components["audio_output"]
    assert audio_output.label == "Generated Speech"
    assert audio_output.type == "filepath"

    # Dropdown of earlier outputs
    output_files = components["output_files"]
    assert output_files.label == "Previous Outputs"
    assert output_files.allow_custom_value is False

    # Play button
    play_btn = components["play_btn"]
    assert play_btn.value == "▶️ Play Selected"
    assert play_btn.size == "sm"

    # Player for the selected earlier output (hidden initially)
    selected_audio = components["selected_audio"]
    assert selected_audio.label == "Selected Output"
    assert selected_audio.type == "filepath"
    assert selected_audio.visible is False

    # Clear-outputs button
    clear_outputs = components["clear_outputs"]
    assert clear_outputs.size == "sm"
    assert clear_outputs.variant == "secondary"
|
195
ui/tests/test_files.py
Normal file
195
ui/tests/test_files.py
Normal file
|
@ -0,0 +1,195 @@
|
||||||
|
import os
|
||||||
|
import pytest
|
||||||
|
from unittest.mock import patch
|
||||||
|
from ui.lib import files
|
||||||
|
from ui.lib.config import AUDIO_FORMATS
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def mock_dirs(tmp_path):
    """Point ``ui.lib.files`` at temporary input/output directories.

    Yields:
        The ``(inputs_dir, outputs_dir)`` pair of freshly created paths
        under ``tmp_path``.
    """
    inputs_dir, outputs_dir = tmp_path / "inputs", tmp_path / "outputs"
    for directory in (inputs_dir, outputs_dir):
        directory.mkdir()

    with patch("ui.lib.files.INPUTS_DIR", str(inputs_dir)):
        with patch("ui.lib.files.OUTPUTS_DIR", str(outputs_dir)):
            yield inputs_dir, outputs_dir
|
||||||
|
|
||||||
|
|
||||||
|
def test_list_input_files_empty(mock_dirs):
    """An empty inputs directory lists no files."""
    listed = files.list_input_files()

    assert listed == []
|
||||||
|
|
||||||
|
|
||||||
|
def test_list_input_files(mock_dirs):
    """Only the ``.txt`` files in the inputs directory are listed."""
    inputs_dir, _ = mock_dirs

    # Two text files plus one file that must be ignored
    seeded = (
        ("test1.txt", "content1"),
        ("test2.txt", "content2"),
        ("nottext.pdf", "should not be listed"),
    )
    for name, text in seeded:
        (inputs_dir / name).write_text(text)

    listed = files.list_input_files()

    assert len(listed) == 2
    assert "test1.txt" in listed
    assert "test2.txt" in listed
    assert "nottext.pdf" not in listed
|
||||||
|
|
||||||
|
|
||||||
|
def test_list_output_files_empty(mock_dirs):
    """An empty outputs directory lists no files."""
    listed = files.list_output_files()

    assert listed == []
|
||||||
|
|
||||||
|
|
||||||
|
def test_list_output_files(mock_dirs):
    """One file per audio format is listed; the ``.txt`` file is not counted."""
    _, outputs_dir = mock_dirs

    # One dummy file per supported format, plus a file that must be ignored
    seeded = {f"test.{fmt}": "dummy content" for fmt in AUDIO_FORMATS}
    seeded["test.txt"] = "should not be listed"
    for name, text in seeded.items():
        (outputs_dir / name).write_text(text)

    listed = files.list_output_files()

    assert len(listed) == len(AUDIO_FORMATS)
    for fmt in AUDIO_FORMATS:
        suffix = f".{fmt}"
        assert any(suffix in entry for entry in listed)
|
||||||
|
|
||||||
|
|
||||||
|
def test_read_text_file_empty_filename(mock_dirs):
    """Reading with an empty filename yields an empty string."""
    text = files.read_text_file("")

    assert text == ""
|
||||||
|
|
||||||
|
|
||||||
|
def test_read_text_file_nonexistent(mock_dirs):
    """Reading a file that does not exist yields an empty string."""
    text = files.read_text_file("nonexistent.txt")

    assert text == ""
|
||||||
|
|
||||||
|
|
||||||
|
def test_read_text_file_success(mock_dirs):
    """An existing input file is read back unchanged."""
    inputs_dir, _ = mock_dirs
    expected_text = "Test content\nMultiple lines"
    target = inputs_dir / "test.txt"
    target.write_text(expected_text)

    assert files.read_text_file("test.txt") == expected_text
|
||||||
|
|
||||||
|
|
||||||
|
def test_save_text_empty(mock_dirs):
    """Empty or blank text is not saved and ``None`` is returned."""
    for blank in ("", " "):
        assert files.save_text(blank) is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_save_text_auto_filename(mock_dirs):
    """Saves without a filename are named ``input_1.txt``, ``input_2.txt``, ..."""
    inputs_dir, _ = mock_dirs

    # Consecutive saves should receive consecutive auto-generated names
    for index, text in enumerate(("content1", "content2"), start=1):
        saved_name = files.save_text(text)

        assert saved_name == f"input_{index}.txt"
        assert (inputs_dir / saved_name).read_text() == text
|
||||||
|
|
||||||
|
|
||||||
|
def test_save_text_custom_filename(mock_dirs):
    """A caller-supplied filename is used as-is for the saved text."""
    inputs_dir, _ = mock_dirs

    saved_name = files.save_text("content", "custom.txt")

    assert saved_name == "custom.txt"
    saved_path = inputs_dir / saved_name
    assert saved_path.read_text() == "content"
|
||||||
|
|
||||||
|
|
||||||
|
def test_save_text_duplicate_filename(mock_dirs):
    """Saving twice under one name keeps the first file and adds ``test_1.txt``."""
    inputs_dir, _ = mock_dirs

    # The first save takes the requested name
    original_name = files.save_text("content1", "test.txt")
    assert original_name == "test.txt"

    # The second save under the same name gets a suffixed name
    deduplicated_name = files.save_text("content2", "test.txt")
    assert deduplicated_name == "test_1.txt"

    # Neither file overwrote the other
    expected_contents = {"test.txt": "content1", "test_1.txt": "content2"}
    for name, text in expected_contents.items():
        assert (inputs_dir / name).read_text() == text
|
||||||
|
|
||||||
|
|
||||||
|
def test_delete_all_input_files(mock_dirs):
    """Deleting input files removes the ``.txt`` files and leaves the ``.pdf``."""
    inputs_dir, _ = mock_dirs

    seeded = (
        ("test1.txt", "content1"),
        ("test2.txt", "content2"),
        ("keep.pdf", "should not be deleted"),
    )
    for name, text in seeded:
        (inputs_dir / name).write_text(text)

    assert files.delete_all_input_files() is True

    survivors = [entry.name for entry in inputs_dir.iterdir()]
    assert survivors == ["keep.pdf"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_delete_all_output_files(mock_dirs):
    """Deleting output files removes the audio files and leaves the ``.txt``."""
    _, outputs_dir = mock_dirs

    seeded = {f"test.{fmt}": "dummy content" for fmt in AUDIO_FORMATS}
    seeded["keep.txt"] = "should not be deleted"
    for name, text in seeded.items():
        (outputs_dir / name).write_text(text)

    assert files.delete_all_output_files() is True

    survivors = [entry.name for entry in outputs_dir.iterdir()]
    assert survivors == ["keep.txt"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_process_uploaded_file_empty_path(mock_dirs):
    """An empty upload path is rejected with ``False``."""
    accepted = files.process_uploaded_file("")

    assert accepted is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_process_uploaded_file_invalid_extension(mock_dirs, tmp_path):
    """A ``.pdf`` upload is rejected with ``False``."""
    pdf_upload = tmp_path / "test.pdf"
    pdf_upload.write_text("content")

    accepted = files.process_uploaded_file(str(pdf_upload))

    assert accepted is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_process_uploaded_file_success(mock_dirs, tmp_path):
    """A ``.txt`` upload is accepted and its content lands in the inputs directory."""
    inputs_dir, _ = mock_dirs

    # The uploaded file lives outside the inputs directory
    upload = tmp_path / "test.txt"
    upload.write_text("test content")

    accepted = files.process_uploaded_file(str(upload))

    assert accepted is True
    assert (inputs_dir / "test.txt").read_text() == "test content"
|
||||||
|
|
||||||
|
|
||||||
|
def test_process_uploaded_file_duplicate(mock_dirs, tmp_path):
    """An upload whose name already exists is stored as ``test_1.txt``."""
    inputs_dir, _ = mock_dirs

    # A file with the same name is already present in the inputs directory
    existing = inputs_dir / "test.txt"
    existing.write_text("existing content")

    # The uploaded file lives outside the inputs directory
    upload = tmp_path / "test.txt"
    upload.write_text("new content")

    accepted = files.process_uploaded_file(str(upload))

    assert accepted is True
    # The existing file is untouched; the upload gets the suffixed name
    assert existing.read_text() == "existing content"
    assert (inputs_dir / "test_1.txt").read_text() == "new content"
|
4
ui/tests/test_handlers.py
Normal file
4
ui/tests/test_handlers.py
Normal file
|
@ -0,0 +1,4 @@
|
||||||
|
"""
|
||||||
|
All handler tests have been dropped for now. The Gradio event system is too complex to test properly.
|
||||||
|
We'll need to find a better way to test the UI functionality.
|
||||||
|
"""
|
74
ui/tests/test_input.py
Normal file
74
ui/tests/test_input.py
Normal file
|
@ -0,0 +1,74 @@
|
||||||
|
import pytest
|
||||||
|
import gradio as gr
|
||||||
|
from ui.lib.components.input import create_input_column
|
||||||
|
|
||||||
|
|
||||||
|
def test_create_input_column_structure():
    """``create_input_column`` returns a column plus a dict of typed widgets."""
    column, components = create_input_column()

    # Return types
    assert isinstance(column, gr.Column)
    assert isinstance(components, dict)

    # Exactly these keys, each mapped to the expected Gradio widget class
    expected_types = {
        "tabs": gr.Tabs,
        "text_input": gr.Textbox,
        "file_select": gr.Dropdown,
        "file_upload": gr.File,
        "file_preview": gr.Textbox,
        "text_submit": gr.Button,
        "file_submit": gr.Button,
        "clear_files": gr.Button,
    }
    assert set(components.keys()) == set(expected_types)

    for key, widget_class in expected_types.items():
        assert isinstance(components[key], widget_class)
|
||||||
|
|
||||||
|
|
||||||
|
def test_text_input_configuration():
    """Check the label, placeholder and line count of the text input box."""
    _, components = create_input_column()
    textbox = components["text_input"]

    observed = (textbox.label, textbox.placeholder, textbox.lines)
    assert observed == ("Text to speak", "Enter text here...", 4)
|
||||||
|
|
||||||
|
|
||||||
|
def test_file_upload_configuration():
    """Check the label and accepted file types of the upload widget."""
    _, components = create_input_column()
    uploader = components["file_upload"]

    assert uploader.label == "Upload Text File (.txt)"
    assert uploader.file_types == [".txt"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_button_configurations():
    """Check caption, variant and size of the three input-column buttons."""
    _, components = create_input_column()

    # key -> (caption, variant, size)
    expected = {
        "text_submit": ("Generate Speech", "primary", "lg"),
        "file_submit": ("Generate Speech", "primary", "lg"),
        "clear_files": ("Clear Files", "secondary", "lg"),
    }

    for key, (caption, variant, size) in expected.items():
        button = components[key]
        assert button.value == caption
        assert button.variant == variant
        assert button.size == size
|
139
ui/tests/test_interface.py
Normal file
139
ui/tests/test_interface.py
Normal file
|
@ -0,0 +1,139 @@
|
||||||
|
import pytest
|
||||||
|
import gradio as gr
|
||||||
|
from unittest.mock import patch, MagicMock, PropertyMock
|
||||||
|
from ui.lib.interface import create_interface
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def mock_timer():
    """Return a timer stand-in that records the callback passed to ``tick``."""

    class MockEvent:
        """Wraps a registered callback so it is reachable as ``.fn``."""

        def __init__(self, fn):
            self.fn = fn

    class MockTimer:
        """Stores the last callback given to ``tick`` and exposes it via ``events``."""

        def __init__(self):
            self.value = 5
            self._fn = None

        @property
        def events(self):
            """Return a one-element event list, or an empty list before ``tick``."""
            if not self._fn:
                return []
            return [MockEvent(self._fn)]

        def tick(self, fn, outputs):
            """Remember ``fn``; ``outputs`` is accepted but not stored."""
            self._fn = fn

    return MockTimer()
|
||||||
|
|
||||||
|
|
||||||
|
def test_create_interface_structure():
    """The interface is a ``gr.Blocks`` with the expected title and theme."""
    with patch("ui.lib.api.check_api_status", return_value=(False, [])):
        interface = create_interface()

        assert isinstance(interface, gr.Blocks)
        assert interface.title == "Kokoro TTS Demo"
        assert isinstance(interface.theme, gr.themes.Monochrome)
|
||||||
|
|
||||||
|
|
||||||
|
def test_interface_html_links():
    """The first HTML component carries both repository links and their captions."""
    with patch("ui.lib.api.check_api_status", return_value=(False, [])):
        interface = create_interface()

        # Collect every HTML component registered on the interface
        html_blocks = [
            block
            for block in interface.blocks.values()
            if isinstance(block, gr.HTML)
        ]
        assert len(html_blocks) > 0
        markup = html_blocks[0].value

        required_fragments = (
            'href="https://huggingface.co/hexgrad/Kokoro-82M"',
            'href="https://github.com/remsky/Kokoro-FastAPI"',
            "Kokoro-82M HF Repo",
            "Kokoro-FastAPI Repo",
        )
        for fragment in required_fragments:
            assert fragment in markup
|
||||||
|
|
||||||
|
|
||||||
|
def test_update_status_available(mock_timer):
    """With the service up, the tick callback fills the voices and stops the timer."""
    voices = ["voice1", "voice2"]
    with patch("ui.lib.api.check_api_status", return_value=(True, voices)):
        with patch("gradio.Timer", return_value=mock_timer):
            create_interface()

            # The callback that create_interface registered on the timer
            tick_handler = mock_timer.events[0].fn
            updates = tick_handler()

            status_update, voice_update, timer_update = (
                updates[0],
                updates[1],
                updates[2],
            )
            assert "Available" in status_update["value"]
            assert voice_update["choices"] == voices
            assert voice_update["value"] == voices[0]
            assert timer_update["active"] is False  # Timer should stop
|
||||||
|
|
||||||
|
|
||||||
|
def test_update_status_unavailable(mock_timer):
    """With the service down, the tick callback clears voices and keeps polling."""
    with patch("ui.lib.api.check_api_status", return_value=(False, [])):
        with patch("gradio.Timer", return_value=mock_timer):
            create_interface()

            # The callback that create_interface registered on the timer
            tick_handler = mock_timer.events[0].fn
            updates = tick_handler()

            status_update, voice_update, timer_update = (
                updates[0],
                updates[1],
                updates[2],
            )
            assert "Waiting for Service" in status_update["value"]
            assert voice_update["choices"] == []
            assert voice_update["value"] is None
            assert timer_update["active"] is True  # Timer should continue
|
||||||
|
|
||||||
|
|
||||||
|
def test_update_status_error(mock_timer):
    """If the status check raises, the callback reports an error and keeps polling."""
    failure = Exception("Test error")
    with patch("ui.lib.api.check_api_status", side_effect=failure):
        with patch("gradio.Timer", return_value=mock_timer):
            create_interface()

            # The callback that create_interface registered on the timer
            tick_handler = mock_timer.events[0].fn
            updates = tick_handler()

            status_update, voice_update, timer_update = (
                updates[0],
                updates[1],
                updates[2],
            )
            assert "Connection Error" in status_update["value"]
            assert voice_update["choices"] == []
            assert voice_update["value"] is None
            assert timer_update["active"] is True  # Timer should continue
|
||||||
|
|
||||||
|
|
||||||
|
def test_timer_configuration(mock_timer):
    """Test that the status timer uses a 5 second interval and one handler.

    The interval is verified on the arguments passed to the patched
    ``gradio.Timer`` constructor. Checking ``mock_timer.value`` instead would
    be vacuous, because the ``mock_timer`` fixture hard-codes ``value = 5``
    regardless of what ``create_interface`` requests.
    """
    with patch("ui.lib.api.check_api_status", return_value=(False, [])), \
         patch("gradio.Timer", return_value=mock_timer) as timer_cls:
        create_interface()

    # The interface must ask Gradio for a 5 second timer
    timer_cls.assert_called_once_with(value=5)
    # Exactly one tick handler should have been registered on it
    assert len(mock_timer.events) == 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_interface_components_presence():
    """Every required widget label appears somewhere in the built interface."""
    with patch("ui.lib.api.check_api_status", return_value=(False, [])):
        interface = create_interface()

        # Gather the non-empty labels of all blocks that have one
        found_labels = set()
        for block in interface.blocks.values():
            label = getattr(block, "label", None)
            if label:
                found_labels.add(label)

        required_labels = {
            "Text to speak",
            "Voice",
            "Audio Format",
            "Speed",
            "Generated Speech",
            "Previous Outputs",
        }

        assert required_labels <= found_labels
|
Loading…
Add table
Reference in a new issue