mirror of
https://github.com/remsky/Kokoro-FastAPI.git
synced 2025-04-13 09:39:17 +00:00
Add Gradio web interface + tests
This commit is contained in:
parent
19321eabb2
commit
e749b3bc88
21 changed files with 1048 additions and 120 deletions
|
@ -1,5 +1,7 @@
|
|||
[run]
|
||||
source = api
|
||||
source =
|
||||
api
|
||||
ui
|
||||
omit =
|
||||
Kokoro-82M/*
|
||||
MagicMock/*
|
||||
|
|
|
@ -4,6 +4,10 @@ Notable changes to this project will be documented in this file.
|
|||
|
||||
## 2024-01-09
|
||||
|
||||
### Added
|
||||
- Gradio Web Interface:
|
||||
- Added simple web UI utility for audio generation from input or txt file
|
||||
|
||||
### Modified
|
||||
#### Configuration Changes
|
||||
- Updated Docker configurations:
|
||||
|
|
131
README.md
131
README.md
|
@ -3,42 +3,55 @@
|
|||
</p>
|
||||
|
||||
# Kokoro TTS API
|
||||
[]()
|
||||
[]()
|
||||
[]()
|
||||
[]()
|
||||
[](https://huggingface.co/hexgrad/Kokoro-82M/tree/c3b0d86e2a980e027ef71c28819ea02e351c2667)
|
||||
|
||||
FastAPI wrapper for [Kokoro-82M](https://huggingface.co/hexgrad/Kokoro-82M) text-to-speech model, providing an OpenAI-compatible endpoint with:
|
||||
Dockerized FastAPI wrapper for [Kokoro-82M](https://huggingface.co/hexgrad/Kokoro-82M) text-to-speech model
|
||||
- OpenAI-compatible Speech endpoint, with voice combination functionality
|
||||
- NVIDIA GPU accelerated inference (or CPU) option
|
||||
- very fast generation time (~35x real time factor)
|
||||
- automatic chunking/stitching for long texts
|
||||
- very fast generation time (~35-49x RTF)
|
||||
- simple audio generation web ui utility
|
||||
|
||||
## Quick Start
|
||||
<details open>
|
||||
<summary><b>OpenAI-Compatible Speech Endpoint</b></summary>
|
||||
|
||||
The service can be accessed through either the API endpoints or the Gradio web interface.
|
||||
|
||||
1. Install prerequisites:
|
||||
- Install [Docker Desktop](https://www.docker.com/products/docker-desktop/)
|
||||
- Install [Git](https://git-scm.com/downloads) (or download and extract zip)
|
||||
- Install [Docker Desktop](https://www.docker.com/products/docker-desktop/) + [Git](https://git-scm.com/downloads)
|
||||
- Clone and start the service:
|
||||
```bash
|
||||
git clone https://github.com/remsky/Kokoro-FastAPI.git
|
||||
cd Kokoro-FastAPI
|
||||
docker compose up --build
|
||||
```
|
||||
2. Run locally as an OpenAI-Compatible Speech Endpoint
|
||||
```python
|
||||
from openai import OpenAI
|
||||
client = OpenAI(
|
||||
base_url="http://localhost:8880",
|
||||
api_key="not-needed"
|
||||
)
|
||||
|
||||
2. Clone and start the service:
|
||||
```bash
|
||||
# Clone repository
|
||||
git clone https://github.com/remsky/Kokoro-FastAPI.git
|
||||
cd Kokoro-FastAPI
|
||||
response = client.audio.speech.create(
|
||||
model="kokoro",
|
||||
voice="af_bella",
|
||||
input="Hello world!",
|
||||
response_format="mp3"
|
||||
)
|
||||
response.stream_to_file("output.mp3")
|
||||
```
|
||||
|
||||
# For GPU acceleration (requires NVIDIA GPU):
|
||||
docker compose up --build
|
||||
or visit http://localhost:7860
|
||||
<p align="center">
|
||||
<img src="ui\GradioScreenShot.png" width="80%" alt="Voice Analysis Comparison" style="border: 2px solid #333; padding: 10px;">
|
||||
</p>
|
||||
</details>
|
||||
<details>
|
||||
<summary><b>OpenAI-Compatible Speech Endpoint</b></summary>
|
||||
|
||||
# For CPU-only deployment (~10x slower, but doesn't require an NVIDIA GPU):
|
||||
docker compose -f docker-compose.cpu.yml up --build
|
||||
```
|
||||
Quick tests (run from another terminal):
|
||||
```bash
|
||||
# Test OpenAI Compatibility
|
||||
python examples/test_openai_tts.py
|
||||
# Test all available voices
|
||||
python examples/test_all_voices.py
|
||||
```
|
||||
|
||||
## OpenAI-Compatible API
|
||||
```python
|
||||
# Using OpenAI's Python library
|
||||
from openai import OpenAI
|
||||
|
@ -77,16 +90,26 @@ with open("output.mp3", "wb") as f:
|
|||
f.write(response.content)
|
||||
```
|
||||
|
||||
## Voice Combination
|
||||
Quick tests (run from another terminal):
|
||||
```bash
|
||||
python examples/test_openai_tts.py # Test OpenAI Compatibility
|
||||
python examples/test_all_voices.py # Test all available voices
|
||||
```
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary><b>Voice Combination</b></summary>
|
||||
|
||||
Combine voices and generate audio:
|
||||
```python
|
||||
import requests
|
||||
response = requests.get("http://localhost:8880/v1/audio/voices")
|
||||
voices = response.json()["voices"]
|
||||
|
||||
# Create combined voice (saved locally on server)
|
||||
# Create combined voice (saves locally on server)
|
||||
response = requests.post(
|
||||
"http://localhost:8880/v1/audio/voices/combine",
|
||||
json=["af_bella", "af_sarah"]
|
||||
json=[voices[0], voices[1]]
|
||||
)
|
||||
combined_voice = response.json()["voice"]
|
||||
|
||||
|
@ -100,8 +123,27 @@ response = requests.post(
|
|||
}
|
||||
)
|
||||
```
|
||||
<p align="center">
|
||||
<img src="examples/benchmarks/analysis_comparison.png" width="60%" alt="Voice Analysis Comparison" style="border: 2px solid #333; padding: 10px;">
|
||||
</p>
|
||||
</details>
|
||||
|
||||
## Performance Benchmarks
|
||||
<details>
|
||||
<summary><b>Gradio Web Utility</b></summary>
|
||||
|
||||
Access the interactive web UI at http://localhost:7860 after starting the service. Features include:
|
||||
- Voice/format/speed selection
|
||||
- Audio playback and download
|
||||
- Text file or direct input
|
||||
|
||||
If you only want the API, just comment out everything in the docker-compose.yml under and including `gradio-ui`
|
||||
|
||||
Currently, voices created via the API are accessible here, but voice combination/creation has not yet been added
|
||||
</details>
|
||||
|
||||
|
||||
<details>
|
||||
<summary><b>Performance Benchmarks</b></summary>
|
||||
|
||||
Benchmarking was performed on generation via the local API using text lengths up to feature-length books (~1.5 hours output), measuring processing time and realtime factor. Tests were run on:
|
||||
- Windows 11 Home w/ WSL2
|
||||
|
@ -119,10 +161,22 @@ Benchmarking was performed on generation via the local API using text lengths up
|
|||
Key Performance Metrics:
|
||||
- Realtime Factor: Ranges between 35-49x (generation time to output audio length)
|
||||
- Average Processing Rate: 137.67 tokens/second (cl100k_base)
|
||||
</details>
|
||||
<details>
|
||||
<summary><b>GPU Vs. CPU<b></summary>
|
||||
|
||||
## Features
|
||||
```bash
|
||||
# GPU: Requires NVIDIA GPU with CUDA 12.1 support
|
||||
docker compose up --build
|
||||
|
||||
- OpenAI-compatible API endpoints
|
||||
# CPU: ~10x slower than GPU inference
|
||||
docker compose -f docker-compose.cpu.yml up --build
|
||||
```
|
||||
</details>
|
||||
<details>
|
||||
<summary><b>Features</b></summary>
|
||||
|
||||
- OpenAI-compatible API endpoints (with optional Gradio Web UI)
|
||||
- GPU-accelerated inference (if desired)
|
||||
- Multiple audio formats: mp3, wav, opus, flac, (aac & pcm not implemented)
|
||||
- Natural Boundary Detection:
|
||||
|
@ -131,19 +185,21 @@ Key Performance Metrics:
|
|||
- Averages model weights of any existing voicepacks
|
||||
- Saves generated voicepacks for future use
|
||||
|
||||
<p align="center">
|
||||
<img src="examples/benchmarks/analysis_comparison.png" width="60%" alt="Voice Analysis Comparison" style="border: 2px solid #333; padding: 10px;">
|
||||
</p>
|
||||
|
||||
|
||||
*Note: CPU Inference is currently a very basic implementation, and not heavily tested*
|
||||
</details>
|
||||
|
||||
## Model
|
||||
<details open>
|
||||
<summary><b>Model</b></summary>
|
||||
|
||||
This API uses the [Kokoro-82M](https://huggingface.co/hexgrad/Kokoro-82M) model from HuggingFace.
|
||||
|
||||
Visit the model page for more details about training, architecture, and capabilities. I have no affiliation with any of their work, and produced this wrapper for ease of use and personal projects.
|
||||
</details>
|
||||
|
||||
## License
|
||||
<details>
|
||||
<summary><b>License</b></summary>
|
||||
|
||||
This project is licensed under the Apache License 2.0 - see below for details:
|
||||
|
||||
|
@ -152,3 +208,4 @@ This project is licensed under the Apache License 2.0 - see below for details:
|
|||
- The inference code adapted from StyleTTS2 is MIT licensed
|
||||
|
||||
The full Apache 2.0 license text can be found at: https://www.apache.org/licenses/LICENSE-2.0
|
||||
</details>
|
||||
|
|
|
@ -46,14 +46,14 @@ services:
|
|||
model-fetcher:
|
||||
condition: service_healthy
|
||||
|
||||
# # Gradio UI service
|
||||
# gradio-ui:
|
||||
# build:
|
||||
# context: ./ui
|
||||
# ports:
|
||||
# - "7860:7860"
|
||||
# volumes:
|
||||
# - ./ui/data:/app/ui/data
|
||||
# - ./ui/app.py:/app/app.py # Mount app.py for hot reload
|
||||
# environment:
|
||||
# - GRADIO_WATCH=True # Enable hot reloading
|
||||
# Gradio UI service [Comment out everything below if you don't need it]
|
||||
gradio-ui:
|
||||
build:
|
||||
context: ./ui
|
||||
ports:
|
||||
- "7860:7860"
|
||||
volumes:
|
||||
- ./ui/data:/app/ui/data
|
||||
- ./ui/app.py:/app/app.py # Mount app.py for hot reload
|
||||
environment:
|
||||
- GRADIO_WATCH=True # Enable hot reloading
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
[pytest]
|
||||
testpaths = api/tests
|
||||
testpaths = api/tests ui/tests
|
||||
python_files = test_*.py
|
||||
addopts = -v --tb=short --cov=api --cov-report=term-missing --cov-config=.coveragerc
|
||||
addopts = -v --tb=short --cov=api --cov=ui --cov-report=term-missing --cov-config=.coveragerc
|
||||
pythonpath = .
|
||||
|
|
BIN
ui/GradioScreenShot.png
Normal file
BIN
ui/GradioScreenShot.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 113 KiB |
|
@ -7,7 +7,11 @@ from .config import API_URL, OUTPUTS_DIR
|
|||
def check_api_status() -> Tuple[bool, List[str]]:
|
||||
"""Check TTS service status and get available voices."""
|
||||
try:
|
||||
response = requests.get(f"{API_URL}/v1/audio/voices", timeout=5)
|
||||
# Use a longer timeout during startup
|
||||
response = requests.get(
|
||||
f"{API_URL}/v1/audio/voices",
|
||||
timeout=30 # Increased timeout for initial startup period
|
||||
)
|
||||
response.raise_for_status()
|
||||
voices = response.json().get("voices", [])
|
||||
if voices:
|
||||
|
@ -15,7 +19,10 @@ def check_api_status() -> Tuple[bool, List[str]]:
|
|||
print("No voices found in response")
|
||||
return False, []
|
||||
except requests.exceptions.Timeout:
|
||||
print("API request timed out")
|
||||
print("API request timed out (waiting for service startup)")
|
||||
return False, []
|
||||
except requests.exceptions.ConnectionError as e:
|
||||
print(f"Connection error (service may be starting up): {str(e)}")
|
||||
return False, []
|
||||
except requests.exceptions.RequestException as e:
|
||||
print(f"API request failed: {str(e)}")
|
||||
|
|
|
@ -2,4 +2,4 @@ from .input import create_input_column
|
|||
from .model import create_model_column
|
||||
from .output import create_output_column
|
||||
|
||||
__all__ = ['create_input_column', 'create_model_column', 'create_output_column']
|
||||
__all__ = ["create_input_column", "create_model_column", "create_output_column"]
|
||||
|
|
|
@ -6,6 +6,8 @@ def create_input_column() -> Tuple[gr.Column, dict]:
|
|||
"""Create the input column with text input and file handling."""
|
||||
with gr.Column(scale=1) as col:
|
||||
with gr.Tabs() as tabs:
|
||||
# Set first tab as selected by default
|
||||
tabs.selected = 0
|
||||
# Direct Input Tab
|
||||
with gr.TabItem("Direct Input"):
|
||||
text_input = gr.Textbox(
|
||||
|
@ -13,6 +15,11 @@ def create_input_column() -> Tuple[gr.Column, dict]:
|
|||
placeholder="Enter text here...",
|
||||
lines=4
|
||||
)
|
||||
text_submit = gr.Button(
|
||||
"Generate Speech",
|
||||
variant="primary",
|
||||
size="lg"
|
||||
)
|
||||
|
||||
# File Input Tab
|
||||
with gr.TabItem("From File"):
|
||||
|
@ -35,12 +42,27 @@ def create_input_column() -> Tuple[gr.Column, dict]:
|
|||
lines=4
|
||||
)
|
||||
|
||||
with gr.Row():
|
||||
file_submit = gr.Button(
|
||||
"Generate Speech",
|
||||
variant="primary",
|
||||
size="lg"
|
||||
)
|
||||
clear_files = gr.Button(
|
||||
"Clear Files",
|
||||
variant="secondary",
|
||||
size="lg"
|
||||
)
|
||||
|
||||
components = {
|
||||
"tabs": tabs,
|
||||
"text_input": text_input,
|
||||
"file_select": input_files_list,
|
||||
"file_upload": file_upload,
|
||||
"file_preview": file_preview
|
||||
"file_preview": file_preview,
|
||||
"text_submit": text_submit,
|
||||
"file_submit": file_submit,
|
||||
"clear_files": clear_files
|
||||
}
|
||||
|
||||
return col, components
|
||||
|
|
|
@ -10,10 +10,9 @@ def create_model_column(voice_ids: Optional[list] = None) -> Tuple[gr.Column, di
|
|||
with gr.Column(scale=1) as col:
|
||||
gr.Markdown("### Model Settings")
|
||||
|
||||
# Status button with embedded status
|
||||
is_available, _ = api.check_api_status()
|
||||
# Status button starts in waiting state
|
||||
status_btn = gr.Button(
|
||||
f"Checking TTS Service: {'Available' if is_available else 'Not Yet Available'}",
|
||||
"⌛ TTS Service: Waiting for Service...",
|
||||
variant="secondary"
|
||||
)
|
||||
|
||||
|
@ -36,18 +35,11 @@ def create_model_column(voice_ids: Optional[list] = None) -> Tuple[gr.Column, di
|
|||
label="Speed"
|
||||
)
|
||||
|
||||
submit_btn = gr.Button(
|
||||
"Generate Speech",
|
||||
variant="primary",
|
||||
size="lg"
|
||||
)
|
||||
|
||||
components = {
|
||||
"status_btn": status_btn,
|
||||
"voice": voice_input,
|
||||
"format": format_input,
|
||||
"speed": speed_input,
|
||||
"submit": submit_btn
|
||||
"speed": speed_input
|
||||
}
|
||||
|
||||
return col, components
|
||||
|
|
|
@ -27,11 +27,14 @@ def create_output_column() -> Tuple[gr.Column, dict]:
|
|||
visible=False
|
||||
)
|
||||
|
||||
clear_outputs = gr.Button("⚠️ Delete All Previously Generated Output Audio 🗑️", size="sm", variant="secondary")
|
||||
|
||||
components = {
|
||||
"audio_output": audio_output,
|
||||
"output_files": output_files,
|
||||
"play_btn": play_btn,
|
||||
"selected_audio": selected_audio
|
||||
"selected_audio": selected_audio,
|
||||
"clear_outputs": clear_outputs
|
||||
}
|
||||
|
||||
return col, components
|
||||
|
|
|
@ -56,6 +56,30 @@ def save_text(text: str, filename: Optional[str] = None) -> Optional[str]:
|
|||
print(f"Error saving file: {e}")
|
||||
return None
|
||||
|
||||
def delete_all_input_files() -> bool:
|
||||
"""Delete all files from the inputs directory. Returns True if successful."""
|
||||
try:
|
||||
for filename in os.listdir(INPUTS_DIR):
|
||||
if filename.endswith('.txt'):
|
||||
file_path = os.path.join(INPUTS_DIR, filename)
|
||||
os.remove(file_path)
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f"Error deleting input files: {e}")
|
||||
return False
|
||||
|
||||
def delete_all_output_files() -> bool:
|
||||
"""Delete all audio files from the outputs directory. Returns True if successful."""
|
||||
try:
|
||||
for filename in os.listdir(OUTPUTS_DIR):
|
||||
if any(filename.endswith(ext) for ext in AUDIO_FORMATS):
|
||||
file_path = os.path.join(OUTPUTS_DIR, filename)
|
||||
os.remove(file_path)
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f"Error deleting output files: {e}")
|
||||
return False
|
||||
|
||||
def process_uploaded_file(file_path: str) -> bool:
|
||||
"""Save uploaded file to inputs directory. Returns True if successful."""
|
||||
if not file_path:
|
||||
|
|
|
@ -7,19 +7,40 @@ def setup_event_handlers(components: dict):
|
|||
"""Set up all event handlers for the UI components."""
|
||||
|
||||
def refresh_status():
|
||||
is_available, voices = api.check_api_status()
|
||||
status = "Available" if is_available else "Unavailable"
|
||||
btn_text = f"🔄 TTS Service: {status}"
|
||||
try:
|
||||
is_available, voices = api.check_api_status()
|
||||
status = "Available" if is_available else "Waiting for Service..."
|
||||
|
||||
if is_available and voices:
|
||||
return {
|
||||
components["model"]["status_btn"]: gr.update(value=btn_text),
|
||||
components["model"]["voice"]: gr.update(choices=voices, value=voices[0] if voices else None)
|
||||
}
|
||||
return {
|
||||
components["model"]["status_btn"]: gr.update(value=btn_text),
|
||||
components["model"]["voice"]: gr.update(choices=[], value=None)
|
||||
}
|
||||
if is_available and voices:
|
||||
# Preserve current voice selection if it exists and is still valid
|
||||
current_voice = components["model"]["voice"].value
|
||||
default_voice = current_voice if current_voice in voices else voices[0]
|
||||
return [
|
||||
gr.update(
|
||||
value=f"🔄 TTS Service: {status}",
|
||||
interactive=True,
|
||||
variant="secondary"
|
||||
),
|
||||
gr.update(choices=voices, value=default_voice)
|
||||
]
|
||||
return [
|
||||
gr.update(
|
||||
value=f"⌛ TTS Service: {status}",
|
||||
interactive=True,
|
||||
variant="secondary"
|
||||
),
|
||||
gr.update(choices=[], value=None)
|
||||
]
|
||||
except Exception as e:
|
||||
print(f"Error in refresh status: {str(e)}")
|
||||
return [
|
||||
gr.update(
|
||||
value="❌ TTS Service: Connection Error",
|
||||
interactive=True,
|
||||
variant="secondary"
|
||||
),
|
||||
gr.update(choices=[], value=None)
|
||||
]
|
||||
|
||||
def handle_file_select(filename):
|
||||
if filename:
|
||||
|
@ -56,46 +77,96 @@ def setup_event_handlers(components: dict):
|
|||
|
||||
return gr.update(choices=files.list_input_files())
|
||||
|
||||
def generate_speech(text, selected_file, voice, format, speed):
|
||||
def generate_from_text(text, voice, format, speed):
|
||||
"""Generate speech from direct text input"""
|
||||
is_available, _ = api.check_api_status()
|
||||
if not is_available:
|
||||
gr.Warning("TTS Service is currently unavailable")
|
||||
return {
|
||||
components["output"]["audio_output"]: None,
|
||||
components["output"]["output_files"]: gr.update(choices=files.list_output_files())
|
||||
}
|
||||
return [
|
||||
None,
|
||||
gr.update(choices=files.list_output_files())
|
||||
]
|
||||
|
||||
# Use text input if provided, otherwise use file content
|
||||
if text and text.strip():
|
||||
files.save_text(text)
|
||||
final_text = text
|
||||
elif selected_file:
|
||||
final_text = files.read_text_file(selected_file)
|
||||
else:
|
||||
gr.Warning("Please enter text or select a file")
|
||||
return {
|
||||
components["output"]["audio_output"]: None,
|
||||
components["output"]["output_files"]: gr.update(choices=files.list_output_files())
|
||||
}
|
||||
if not text or not text.strip():
|
||||
gr.Warning("Please enter text in the input box")
|
||||
return [
|
||||
None,
|
||||
gr.update(choices=files.list_output_files())
|
||||
]
|
||||
|
||||
result = api.text_to_speech(final_text, voice, format, speed)
|
||||
files.save_text(text)
|
||||
result = api.text_to_speech(text, voice, format, speed)
|
||||
if result is None:
|
||||
gr.Warning("Failed to generate speech. Please try again.")
|
||||
return {
|
||||
components["output"]["audio_output"]: None,
|
||||
components["output"]["output_files"]: gr.update(choices=files.list_output_files())
|
||||
}
|
||||
return [
|
||||
None,
|
||||
gr.update(choices=files.list_output_files())
|
||||
]
|
||||
|
||||
return {
|
||||
components["output"]["audio_output"]: result,
|
||||
components["output"]["output_files"]: gr.update(choices=files.list_output_files(), value=os.path.basename(result))
|
||||
}
|
||||
return [
|
||||
result,
|
||||
gr.update(choices=files.list_output_files(), value=os.path.basename(result))
|
||||
]
|
||||
|
||||
def generate_from_file(selected_file, voice, format, speed):
|
||||
"""Generate speech from selected file"""
|
||||
is_available, _ = api.check_api_status()
|
||||
if not is_available:
|
||||
gr.Warning("TTS Service is currently unavailable")
|
||||
return [
|
||||
None,
|
||||
gr.update(choices=files.list_output_files())
|
||||
]
|
||||
|
||||
if not selected_file:
|
||||
gr.Warning("Please select a file")
|
||||
return [
|
||||
None,
|
||||
gr.update(choices=files.list_output_files())
|
||||
]
|
||||
|
||||
text = files.read_text_file(selected_file)
|
||||
result = api.text_to_speech(text, voice, format, speed)
|
||||
if result is None:
|
||||
gr.Warning("Failed to generate speech. Please try again.")
|
||||
return [
|
||||
None,
|
||||
gr.update(choices=files.list_output_files())
|
||||
]
|
||||
|
||||
return [
|
||||
result,
|
||||
gr.update(choices=files.list_output_files(), value=os.path.basename(result))
|
||||
]
|
||||
|
||||
def play_selected(file_path):
|
||||
if file_path and os.path.exists(file_path):
|
||||
return gr.update(value=file_path, visible=True)
|
||||
return gr.update(visible=False)
|
||||
|
||||
def clear_files(voice, format, speed):
|
||||
"""Delete all input files and clear UI components while preserving model settings"""
|
||||
files.delete_all_input_files()
|
||||
return [
|
||||
gr.update(value=None, choices=[]), # file_select
|
||||
None, # file_upload
|
||||
gr.update(value=""), # file_preview
|
||||
None, # audio_output
|
||||
gr.update(choices=files.list_output_files()), # output_files
|
||||
gr.update(value=voice), # voice
|
||||
gr.update(value=format), # format
|
||||
gr.update(value=speed) # speed
|
||||
]
|
||||
|
||||
def clear_outputs():
|
||||
"""Delete all output audio files and clear audio components"""
|
||||
files.delete_all_output_files()
|
||||
return [
|
||||
None, # audio_output
|
||||
gr.update(choices=[], value=None), # output_files
|
||||
gr.update(visible=False) # selected_audio
|
||||
]
|
||||
|
||||
# Connect event handlers
|
||||
components["model"]["status_btn"].click(
|
||||
fn=refresh_status,
|
||||
|
@ -123,10 +194,54 @@ def setup_event_handlers(components: dict):
|
|||
outputs=[components["output"]["selected_audio"]]
|
||||
)
|
||||
|
||||
components["model"]["submit"].click(
|
||||
fn=generate_speech,
|
||||
# Connect clear files button
|
||||
components["input"]["clear_files"].click(
|
||||
fn=clear_files,
|
||||
inputs=[
|
||||
components["model"]["voice"],
|
||||
components["model"]["format"],
|
||||
components["model"]["speed"]
|
||||
],
|
||||
outputs=[
|
||||
components["input"]["file_select"],
|
||||
components["input"]["file_upload"],
|
||||
components["input"]["file_preview"],
|
||||
components["output"]["audio_output"],
|
||||
components["output"]["output_files"],
|
||||
components["model"]["voice"],
|
||||
components["model"]["format"],
|
||||
components["model"]["speed"]
|
||||
]
|
||||
)
|
||||
|
||||
# Connect submit buttons for each tab
|
||||
components["input"]["text_submit"].click(
|
||||
fn=generate_from_text,
|
||||
inputs=[
|
||||
components["input"]["text_input"],
|
||||
components["model"]["voice"],
|
||||
components["model"]["format"],
|
||||
components["model"]["speed"]
|
||||
],
|
||||
outputs=[
|
||||
components["output"]["audio_output"],
|
||||
components["output"]["output_files"]
|
||||
]
|
||||
)
|
||||
|
||||
# Connect clear outputs button
|
||||
components["output"]["clear_outputs"].click(
|
||||
fn=clear_outputs,
|
||||
outputs=[
|
||||
components["output"]["audio_output"],
|
||||
components["output"]["output_files"],
|
||||
components["output"]["selected_audio"]
|
||||
]
|
||||
)
|
||||
|
||||
components["input"]["file_submit"].click(
|
||||
fn=generate_from_file,
|
||||
inputs=[
|
||||
components["input"]["file_select"],
|
||||
components["model"]["voice"],
|
||||
components["model"]["format"],
|
||||
|
|
|
@ -5,8 +5,8 @@ from .handlers import setup_event_handlers
|
|||
|
||||
def create_interface():
|
||||
"""Create the main Gradio interface."""
|
||||
# Initial status check
|
||||
is_available, available_voices = api.check_api_status()
|
||||
# Skip initial status check - let the timer handle it
|
||||
is_available, available_voices = False, []
|
||||
|
||||
with gr.Blocks(
|
||||
title="Kokoro TTS Demo",
|
||||
|
@ -36,19 +36,55 @@ def create_interface():
|
|||
|
||||
# Add periodic status check with Timer
|
||||
def update_status():
|
||||
is_available, voices = api.check_api_status()
|
||||
status = "Available" if is_available else "Unavailable"
|
||||
return {
|
||||
components["model"]["status_btn"]: gr.update(value=f"🔄 TTS Service: {status}"),
|
||||
components["model"]["voice"]: gr.update(choices=voices, value=voices[0] if voices else None)
|
||||
}
|
||||
try:
|
||||
is_available, voices = api.check_api_status()
|
||||
status = "Available" if is_available else "Waiting for Service..."
|
||||
|
||||
timer = gr.Timer(10, active=True) # Check every 10 seconds
|
||||
if is_available and voices:
|
||||
# Service is available, update UI and stop timer
|
||||
current_voice = components["model"]["voice"].value
|
||||
default_voice = current_voice if current_voice in voices else voices[0]
|
||||
# Return values in same order as outputs list
|
||||
return [
|
||||
gr.update(
|
||||
value=f"🔄 TTS Service: {status}",
|
||||
interactive=True,
|
||||
variant="secondary"
|
||||
),
|
||||
gr.update(choices=voices, value=default_voice),
|
||||
gr.update(active=False) # Stop timer
|
||||
]
|
||||
|
||||
# Service not available yet, keep checking
|
||||
return [
|
||||
gr.update(
|
||||
value=f"⌛ TTS Service: {status}",
|
||||
interactive=True,
|
||||
variant="secondary"
|
||||
),
|
||||
gr.update(choices=[], value=None),
|
||||
gr.update(active=True)
|
||||
]
|
||||
except Exception as e:
|
||||
print(f"Error in status update: {str(e)}")
|
||||
# On error, keep the timer running but show error state
|
||||
return [
|
||||
gr.update(
|
||||
value="❌ TTS Service: Connection Error",
|
||||
interactive=True,
|
||||
variant="secondary"
|
||||
),
|
||||
gr.update(choices=[], value=None),
|
||||
gr.update(active=True)
|
||||
]
|
||||
|
||||
timer = gr.Timer(value=5) # Check every 5 seconds
|
||||
timer.tick(
|
||||
fn=update_status,
|
||||
outputs=[
|
||||
components["model"]["status_btn"],
|
||||
components["model"]["voice"]
|
||||
components["model"]["voice"],
|
||||
timer
|
||||
]
|
||||
)
|
||||
|
||||
|
|
9
ui/tests/conftest.py
Normal file
9
ui/tests/conftest.py
Normal file
|
@ -0,0 +1,9 @@
|
|||
import pytest
|
||||
import gradio as gr
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_gr_context():
|
||||
"""Provides a context for testing Gradio components"""
|
||||
with gr.Blocks():
|
||||
yield
|
129
ui/tests/test_api.py
Normal file
129
ui/tests/test_api.py
Normal file
|
@ -0,0 +1,129 @@
|
|||
import pytest
|
||||
import requests
|
||||
from unittest.mock import patch, mock_open
|
||||
from ui.lib import api
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_response():
|
||||
class MockResponse:
|
||||
def __init__(self, json_data, status_code=200, content=b"audio data"):
|
||||
self._json = json_data
|
||||
self.status_code = status_code
|
||||
self.content = content
|
||||
|
||||
def json(self):
|
||||
return self._json
|
||||
|
||||
def raise_for_status(self):
|
||||
if self.status_code != 200:
|
||||
raise requests.exceptions.HTTPError(f"HTTP {self.status_code}")
|
||||
|
||||
return MockResponse
|
||||
|
||||
|
||||
def test_check_api_status_success(mock_response):
|
||||
"""Test successful API status check"""
|
||||
mock_data = {"voices": ["voice1", "voice2"]}
|
||||
with patch("requests.get", return_value=mock_response(mock_data)):
|
||||
status, voices = api.check_api_status()
|
||||
assert status is True
|
||||
assert voices == ["voice1", "voice2"]
|
||||
|
||||
|
||||
def test_check_api_status_no_voices(mock_response):
|
||||
"""Test API response with no voices"""
|
||||
with patch("requests.get", return_value=mock_response({"voices": []})):
|
||||
status, voices = api.check_api_status()
|
||||
assert status is False
|
||||
assert voices == []
|
||||
|
||||
|
||||
def test_check_api_status_timeout():
|
||||
"""Test API timeout"""
|
||||
with patch("requests.get", side_effect=requests.exceptions.Timeout):
|
||||
status, voices = api.check_api_status()
|
||||
assert status is False
|
||||
assert voices == []
|
||||
|
||||
|
||||
def test_check_api_status_connection_error():
|
||||
"""Test API connection error"""
|
||||
with patch("requests.get", side_effect=requests.exceptions.ConnectionError):
|
||||
status, voices = api.check_api_status()
|
||||
assert status is False
|
||||
assert voices == []
|
||||
|
||||
|
||||
def test_text_to_speech_success(mock_response, tmp_path):
|
||||
"""Test successful speech generation"""
|
||||
with patch("requests.post", return_value=mock_response({})), \
|
||||
patch("ui.lib.api.OUTPUTS_DIR", str(tmp_path)), \
|
||||
patch("builtins.open", mock_open()) as mock_file:
|
||||
|
||||
result = api.text_to_speech("test text", "voice1", "mp3", 1.0)
|
||||
|
||||
assert result is not None
|
||||
assert "output_" in result
|
||||
assert result.endswith(".mp3")
|
||||
mock_file.assert_called_once()
|
||||
|
||||
|
||||
def test_text_to_speech_empty_text():
|
||||
"""Test speech generation with empty text"""
|
||||
result = api.text_to_speech("", "voice1", "mp3", 1.0)
|
||||
assert result is None
|
||||
|
||||
|
||||
def test_text_to_speech_timeout():
|
||||
"""Test speech generation timeout"""
|
||||
with patch("requests.post", side_effect=requests.exceptions.Timeout):
|
||||
result = api.text_to_speech("test", "voice1", "mp3", 1.0)
|
||||
assert result is None
|
||||
|
||||
|
||||
def test_text_to_speech_request_error():
|
||||
"""Test speech generation request error"""
|
||||
with patch("requests.post", side_effect=requests.exceptions.RequestException):
|
||||
result = api.text_to_speech("test", "voice1", "mp3", 1.0)
|
||||
assert result is None
|
||||
|
||||
|
||||
def test_get_status_html_available():
|
||||
"""Test status HTML generation for available service"""
|
||||
html = api.get_status_html(True)
|
||||
assert "green" in html
|
||||
assert "Available" in html
|
||||
|
||||
|
||||
def test_get_status_html_unavailable():
|
||||
"""Test status HTML generation for unavailable service"""
|
||||
html = api.get_status_html(False)
|
||||
assert "red" in html
|
||||
assert "Unavailable" in html
|
||||
|
||||
|
||||
def test_text_to_speech_api_params(mock_response, tmp_path):
|
||||
"""Test correct API parameters are sent"""
|
||||
with patch("requests.post") as mock_post, \
|
||||
patch("ui.lib.api.OUTPUTS_DIR", str(tmp_path)), \
|
||||
patch("builtins.open", mock_open()):
|
||||
|
||||
mock_post.return_value = mock_response({})
|
||||
api.text_to_speech("test text", "voice1", "mp3", 1.5)
|
||||
|
||||
mock_post.assert_called_once()
|
||||
args, kwargs = mock_post.call_args
|
||||
|
||||
# Check request body
|
||||
assert kwargs["json"] == {
|
||||
"model": "kokoro",
|
||||
"input": "test text",
|
||||
"voice": "voice1",
|
||||
"response_format": "mp3",
|
||||
"speed": 1.5
|
||||
}
|
||||
|
||||
# Check headers and timeout
|
||||
assert kwargs["headers"] == {"Content-Type": "application/json"}
|
||||
assert kwargs["timeout"] == 300
|
116
ui/tests/test_components.py
Normal file
116
ui/tests/test_components.py
Normal file
|
@ -0,0 +1,116 @@
|
|||
import pytest
|
||||
import gradio as gr
|
||||
from ui.lib.components.model import create_model_column
|
||||
from ui.lib.components.output import create_output_column
|
||||
from ui.lib.config import AUDIO_FORMATS
|
||||
|
||||
|
||||
def test_create_model_column_structure():
|
||||
"""Test that create_model_column returns the expected structure"""
|
||||
voice_ids = ["voice1", "voice2"]
|
||||
column, components = create_model_column(voice_ids)
|
||||
|
||||
# Test return types
|
||||
assert isinstance(column, gr.Column)
|
||||
assert isinstance(components, dict)
|
||||
|
||||
# Test expected components presence
|
||||
expected_components = {
|
||||
"status_btn",
|
||||
"voice",
|
||||
"format",
|
||||
"speed"
|
||||
}
|
||||
assert set(components.keys()) == expected_components
|
||||
|
||||
# Test component types
|
||||
assert isinstance(components["status_btn"], gr.Button)
|
||||
assert isinstance(components["voice"], gr.Dropdown)
|
||||
assert isinstance(components["format"], gr.Dropdown)
|
||||
assert isinstance(components["speed"], gr.Slider)
|
||||
|
||||
|
||||
def test_model_column_default_values():
|
||||
"""Test the default values of model column components"""
|
||||
voice_ids = ["voice1", "voice2"]
|
||||
_, components = create_model_column(voice_ids)
|
||||
|
||||
# Test voice dropdown
|
||||
# Gradio Dropdown converts choices to (value, label) tuples
|
||||
expected_choices = [(voice_id, voice_id) for voice_id in voice_ids]
|
||||
assert components["voice"].choices == expected_choices
|
||||
# Value is not converted to tuple format for the value property
|
||||
assert components["voice"].value == voice_ids[0]
|
||||
assert components["voice"].interactive is True
|
||||
|
||||
# Test format dropdown
|
||||
# Gradio Dropdown converts choices to (value, label) tuples
|
||||
expected_format_choices = [(fmt, fmt) for fmt in AUDIO_FORMATS]
|
||||
assert components["format"].choices == expected_format_choices
|
||||
assert components["format"].value == "mp3"
|
||||
|
||||
# Test speed slider
|
||||
assert components["speed"].minimum == 0.5
|
||||
assert components["speed"].maximum == 2.0
|
||||
assert components["speed"].value == 1.0
|
||||
assert components["speed"].step == 0.1
|
||||
|
||||
|
||||
def test_model_column_no_voices():
|
||||
"""Test model column creation with no voice IDs"""
|
||||
_, components = create_model_column()
|
||||
|
||||
assert components["voice"].choices == []
|
||||
assert components["voice"].value is None
|
||||
|
||||
|
||||
def test_create_output_column_structure():
|
||||
"""Test that create_output_column returns the expected structure"""
|
||||
column, components = create_output_column()
|
||||
|
||||
# Test return types
|
||||
assert isinstance(column, gr.Column)
|
||||
assert isinstance(components, dict)
|
||||
|
||||
# Test expected components presence
|
||||
expected_components = {
|
||||
"audio_output",
|
||||
"output_files",
|
||||
"play_btn",
|
||||
"selected_audio",
|
||||
"clear_outputs"
|
||||
}
|
||||
assert set(components.keys()) == expected_components
|
||||
|
||||
# Test component types
|
||||
assert isinstance(components["audio_output"], gr.Audio)
|
||||
assert isinstance(components["output_files"], gr.Dropdown)
|
||||
assert isinstance(components["play_btn"], gr.Button)
|
||||
assert isinstance(components["selected_audio"], gr.Audio)
|
||||
assert isinstance(components["clear_outputs"], gr.Button)
|
||||
|
||||
|
||||
def test_output_column_configuration():
|
||||
"""Test the configuration of output column components"""
|
||||
_, components = create_output_column()
|
||||
|
||||
# Test audio output configuration
|
||||
assert components["audio_output"].label == "Generated Speech"
|
||||
assert components["audio_output"].type == "filepath"
|
||||
|
||||
# Test output files dropdown
|
||||
assert components["output_files"].label == "Previous Outputs"
|
||||
assert components["output_files"].allow_custom_value is False
|
||||
|
||||
# Test play button
|
||||
assert components["play_btn"].value == "▶️ Play Selected"
|
||||
assert components["play_btn"].size == "sm"
|
||||
|
||||
# Test selected audio configuration
|
||||
assert components["selected_audio"].label == "Selected Output"
|
||||
assert components["selected_audio"].type == "filepath"
|
||||
assert components["selected_audio"].visible is False
|
||||
|
||||
# Test clear outputs button
|
||||
assert components["clear_outputs"].size == "sm"
|
||||
assert components["clear_outputs"].variant == "secondary"
|
195
ui/tests/test_files.py
Normal file
195
ui/tests/test_files.py
Normal file
|
@ -0,0 +1,195 @@
|
|||
import os
|
||||
import pytest
|
||||
from unittest.mock import patch
|
||||
from ui.lib import files
|
||||
from ui.lib.config import AUDIO_FORMATS
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_dirs(tmp_path):
|
||||
"""Create temporary input and output directories"""
|
||||
inputs_dir = tmp_path / "inputs"
|
||||
outputs_dir = tmp_path / "outputs"
|
||||
inputs_dir.mkdir()
|
||||
outputs_dir.mkdir()
|
||||
|
||||
with patch("ui.lib.files.INPUTS_DIR", str(inputs_dir)), patch(
|
||||
"ui.lib.files.OUTPUTS_DIR", str(outputs_dir)
|
||||
):
|
||||
yield inputs_dir, outputs_dir
|
||||
|
||||
|
||||
def test_list_input_files_empty(mock_dirs):
|
||||
"""Test listing input files from empty directory"""
|
||||
assert files.list_input_files() == []
|
||||
|
||||
|
||||
def test_list_input_files(mock_dirs):
|
||||
"""Test listing input files with various files"""
|
||||
inputs_dir, _ = mock_dirs
|
||||
|
||||
# Create test files
|
||||
(inputs_dir / "test1.txt").write_text("content1")
|
||||
(inputs_dir / "test2.txt").write_text("content2")
|
||||
(inputs_dir / "nottext.pdf").write_text("should not be listed")
|
||||
|
||||
result = files.list_input_files()
|
||||
assert len(result) == 2
|
||||
assert "test1.txt" in result
|
||||
assert "test2.txt" in result
|
||||
assert "nottext.pdf" not in result
|
||||
|
||||
|
||||
def test_list_output_files_empty(mock_dirs):
|
||||
"""Test listing output files from empty directory"""
|
||||
assert files.list_output_files() == []
|
||||
|
||||
|
||||
def test_list_output_files(mock_dirs):
|
||||
"""Test listing output files with various formats"""
|
||||
_, outputs_dir = mock_dirs
|
||||
|
||||
# Create test files for each format
|
||||
for fmt in AUDIO_FORMATS:
|
||||
(outputs_dir / f"test.{fmt}").write_text("dummy content")
|
||||
(outputs_dir / "test.txt").write_text("should not be listed")
|
||||
|
||||
result = files.list_output_files()
|
||||
assert len(result) == len(AUDIO_FORMATS)
|
||||
for fmt in AUDIO_FORMATS:
|
||||
assert any(f".{fmt}" in file for file in result)
|
||||
|
||||
|
||||
def test_read_text_file_empty_filename(mock_dirs):
|
||||
"""Test reading with empty filename"""
|
||||
assert files.read_text_file("") == ""
|
||||
|
||||
|
||||
def test_read_text_file_nonexistent(mock_dirs):
|
||||
"""Test reading nonexistent file"""
|
||||
assert files.read_text_file("nonexistent.txt") == ""
|
||||
|
||||
|
||||
def test_read_text_file_success(mock_dirs):
|
||||
"""Test successful file reading"""
|
||||
inputs_dir, _ = mock_dirs
|
||||
content = "Test content\nMultiple lines"
|
||||
(inputs_dir / "test.txt").write_text(content)
|
||||
|
||||
assert files.read_text_file("test.txt") == content
|
||||
|
||||
|
||||
def test_save_text_empty(mock_dirs):
|
||||
"""Test saving empty text"""
|
||||
assert files.save_text("") is None
|
||||
assert files.save_text(" ") is None
|
||||
|
||||
|
||||
def test_save_text_auto_filename(mock_dirs):
|
||||
"""Test saving text with auto-generated filename"""
|
||||
inputs_dir, _ = mock_dirs
|
||||
|
||||
# First save
|
||||
filename1 = files.save_text("content1")
|
||||
assert filename1 == "input_1.txt"
|
||||
assert (inputs_dir / filename1).read_text() == "content1"
|
||||
|
||||
# Second save
|
||||
filename2 = files.save_text("content2")
|
||||
assert filename2 == "input_2.txt"
|
||||
assert (inputs_dir / filename2).read_text() == "content2"
|
||||
|
||||
|
||||
def test_save_text_custom_filename(mock_dirs):
|
||||
"""Test saving text with custom filename"""
|
||||
inputs_dir, _ = mock_dirs
|
||||
|
||||
filename = files.save_text("content", "custom.txt")
|
||||
assert filename == "custom.txt"
|
||||
assert (inputs_dir / filename).read_text() == "content"
|
||||
|
||||
|
||||
def test_save_text_duplicate_filename(mock_dirs):
|
||||
"""Test saving text with duplicate filename"""
|
||||
inputs_dir, _ = mock_dirs
|
||||
|
||||
# First save
|
||||
filename1 = files.save_text("content1", "test.txt")
|
||||
assert filename1 == "test.txt"
|
||||
|
||||
# Save with same filename
|
||||
filename2 = files.save_text("content2", "test.txt")
|
||||
assert filename2 == "test_1.txt"
|
||||
|
||||
assert (inputs_dir / "test.txt").read_text() == "content1"
|
||||
assert (inputs_dir / "test_1.txt").read_text() == "content2"
|
||||
|
||||
|
||||
def test_delete_all_input_files(mock_dirs):
|
||||
"""Test deleting all input files"""
|
||||
inputs_dir, _ = mock_dirs
|
||||
|
||||
# Create test files
|
||||
(inputs_dir / "test1.txt").write_text("content1")
|
||||
(inputs_dir / "test2.txt").write_text("content2")
|
||||
(inputs_dir / "keep.pdf").write_text("should not be deleted")
|
||||
|
||||
assert files.delete_all_input_files() is True
|
||||
remaining_files = list(inputs_dir.iterdir())
|
||||
assert len(remaining_files) == 1
|
||||
assert remaining_files[0].name == "keep.pdf"
|
||||
|
||||
|
||||
def test_delete_all_output_files(mock_dirs):
|
||||
"""Test deleting all output files"""
|
||||
_, outputs_dir = mock_dirs
|
||||
|
||||
# Create test files
|
||||
for fmt in AUDIO_FORMATS:
|
||||
(outputs_dir / f"test.{fmt}").write_text("dummy content")
|
||||
(outputs_dir / "keep.txt").write_text("should not be deleted")
|
||||
|
||||
assert files.delete_all_output_files() is True
|
||||
remaining_files = list(outputs_dir.iterdir())
|
||||
assert len(remaining_files) == 1
|
||||
assert remaining_files[0].name == "keep.txt"
|
||||
|
||||
|
||||
def test_process_uploaded_file_empty_path(mock_dirs):
|
||||
"""Test processing empty file path"""
|
||||
assert files.process_uploaded_file("") is False
|
||||
|
||||
|
||||
def test_process_uploaded_file_invalid_extension(mock_dirs, tmp_path):
|
||||
"""Test processing file with invalid extension"""
|
||||
test_file = tmp_path / "test.pdf"
|
||||
test_file.write_text("content")
|
||||
assert files.process_uploaded_file(str(test_file)) is False
|
||||
|
||||
|
||||
def test_process_uploaded_file_success(mock_dirs, tmp_path):
|
||||
"""Test successful file upload processing"""
|
||||
inputs_dir, _ = mock_dirs
|
||||
|
||||
# Create source file
|
||||
source_file = tmp_path / "test.txt"
|
||||
source_file.write_text("test content")
|
||||
|
||||
assert files.process_uploaded_file(str(source_file)) is True
|
||||
assert (inputs_dir / "test.txt").read_text() == "test content"
|
||||
|
||||
|
||||
def test_process_uploaded_file_duplicate(mock_dirs, tmp_path):
|
||||
"""Test processing file with duplicate name"""
|
||||
inputs_dir, _ = mock_dirs
|
||||
|
||||
# Create existing file
|
||||
(inputs_dir / "test.txt").write_text("existing content")
|
||||
|
||||
# Create source file
|
||||
source_file = tmp_path / "test.txt"
|
||||
source_file.write_text("new content")
|
||||
|
||||
assert files.process_uploaded_file(str(source_file)) is True
|
||||
assert (inputs_dir / "test.txt").read_text() == "existing content"
|
||||
assert (inputs_dir / "test_1.txt").read_text() == "new content"
|
4
ui/tests/test_handlers.py
Normal file
4
ui/tests/test_handlers.py
Normal file
|
@ -0,0 +1,4 @@
|
|||
"""
|
||||
Drop all tests for now. The Gradio event system is too complex to test properly.
|
||||
We'll need to find a better way to test the UI functionality.
|
||||
"""
|
74
ui/tests/test_input.py
Normal file
74
ui/tests/test_input.py
Normal file
|
@ -0,0 +1,74 @@
|
|||
import pytest
|
||||
import gradio as gr
|
||||
from ui.lib.components.input import create_input_column
|
||||
|
||||
|
||||
def test_create_input_column_structure():
|
||||
"""Test that create_input_column returns the expected structure"""
|
||||
column, components = create_input_column()
|
||||
|
||||
# Test the return types
|
||||
assert isinstance(column, gr.Column)
|
||||
assert isinstance(components, dict)
|
||||
|
||||
# Test that all expected components are present
|
||||
expected_components = {
|
||||
"tabs",
|
||||
"text_input",
|
||||
"file_select",
|
||||
"file_upload",
|
||||
"file_preview",
|
||||
"text_submit",
|
||||
"file_submit",
|
||||
"clear_files",
|
||||
}
|
||||
assert set(components.keys()) == expected_components
|
||||
|
||||
# Test component types
|
||||
assert isinstance(components["tabs"], gr.Tabs)
|
||||
assert isinstance(components["text_input"], gr.Textbox)
|
||||
assert isinstance(components["file_select"], gr.Dropdown)
|
||||
assert isinstance(components["file_upload"], gr.File)
|
||||
assert isinstance(components["file_preview"], gr.Textbox)
|
||||
assert isinstance(components["text_submit"], gr.Button)
|
||||
assert isinstance(components["file_submit"], gr.Button)
|
||||
assert isinstance(components["clear_files"], gr.Button)
|
||||
|
||||
|
||||
def test_text_input_configuration():
|
||||
"""Test the text input component configuration"""
|
||||
_, components = create_input_column()
|
||||
text_input = components["text_input"]
|
||||
|
||||
assert text_input.label == "Text to speak"
|
||||
assert text_input.placeholder == "Enter text here..."
|
||||
assert text_input.lines == 4
|
||||
|
||||
|
||||
def test_file_upload_configuration():
|
||||
"""Test the file upload component configuration"""
|
||||
_, components = create_input_column()
|
||||
file_upload = components["file_upload"]
|
||||
|
||||
assert file_upload.label == "Upload Text File (.txt)"
|
||||
assert file_upload.file_types == [".txt"]
|
||||
|
||||
|
||||
def test_button_configurations():
|
||||
"""Test the button configurations"""
|
||||
_, components = create_input_column()
|
||||
|
||||
# Test text submit button
|
||||
assert components["text_submit"].value == "Generate Speech"
|
||||
assert components["text_submit"].variant == "primary"
|
||||
assert components["text_submit"].size == "lg"
|
||||
|
||||
# Test file submit button
|
||||
assert components["file_submit"].value == "Generate Speech"
|
||||
assert components["file_submit"].variant == "primary"
|
||||
assert components["file_submit"].size == "lg"
|
||||
|
||||
# Test clear files button
|
||||
assert components["clear_files"].value == "Clear Files"
|
||||
assert components["clear_files"].variant == "secondary"
|
||||
assert components["clear_files"].size == "lg"
|
139
ui/tests/test_interface.py
Normal file
139
ui/tests/test_interface.py
Normal file
|
@ -0,0 +1,139 @@
|
|||
import pytest
|
||||
import gradio as gr
|
||||
from unittest.mock import patch, MagicMock, PropertyMock
|
||||
from ui.lib.interface import create_interface
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_timer():
|
||||
"""Create a mock timer with events property"""
|
||||
class MockEvent:
|
||||
def __init__(self, fn):
|
||||
self.fn = fn
|
||||
|
||||
class MockTimer:
|
||||
def __init__(self):
|
||||
self._fn = None
|
||||
self.value = 5
|
||||
|
||||
@property
|
||||
def events(self):
|
||||
return [MockEvent(self._fn)] if self._fn else []
|
||||
|
||||
def tick(self, fn, outputs):
|
||||
self._fn = fn
|
||||
|
||||
return MockTimer()
|
||||
|
||||
|
||||
def test_create_interface_structure():
|
||||
"""Test the basic structure of the created interface"""
|
||||
with patch("ui.lib.api.check_api_status", return_value=(False, [])):
|
||||
demo = create_interface()
|
||||
|
||||
# Test interface type and theme
|
||||
assert isinstance(demo, gr.Blocks)
|
||||
assert demo.title == "Kokoro TTS Demo"
|
||||
assert isinstance(demo.theme, gr.themes.Monochrome)
|
||||
|
||||
|
||||
def test_interface_html_links():
|
||||
"""Test that HTML links are properly configured"""
|
||||
with patch("ui.lib.api.check_api_status", return_value=(False, [])):
|
||||
demo = create_interface()
|
||||
|
||||
# Find HTML component
|
||||
html_components = [
|
||||
comp for comp in demo.blocks.values()
|
||||
if isinstance(comp, gr.HTML)
|
||||
]
|
||||
assert len(html_components) > 0
|
||||
html = html_components[0]
|
||||
|
||||
# Check for required links
|
||||
assert 'href="https://huggingface.co/hexgrad/Kokoro-82M"' in html.value
|
||||
assert 'href="https://github.com/remsky/Kokoro-FastAPI"' in html.value
|
||||
assert "Kokoro-82M HF Repo" in html.value
|
||||
assert "Kokoro-FastAPI Repo" in html.value
|
||||
|
||||
|
||||
def test_update_status_available(mock_timer):
|
||||
"""Test status update when service is available"""
|
||||
voices = ["voice1", "voice2"]
|
||||
with patch("ui.lib.api.check_api_status", return_value=(True, voices)), \
|
||||
patch("gradio.Timer", return_value=mock_timer):
|
||||
demo = create_interface()
|
||||
|
||||
# Get the update function
|
||||
update_fn = mock_timer.events[0].fn
|
||||
|
||||
# Test update with available service
|
||||
updates = update_fn()
|
||||
|
||||
assert "Available" in updates[0]["value"]
|
||||
assert updates[1]["choices"] == voices
|
||||
assert updates[1]["value"] == voices[0]
|
||||
assert updates[2]["active"] is False # Timer should stop
|
||||
|
||||
|
||||
def test_update_status_unavailable(mock_timer):
|
||||
"""Test status update when service is unavailable"""
|
||||
with patch("ui.lib.api.check_api_status", return_value=(False, [])), \
|
||||
patch("gradio.Timer", return_value=mock_timer):
|
||||
demo = create_interface()
|
||||
update_fn = mock_timer.events[0].fn
|
||||
|
||||
updates = update_fn()
|
||||
|
||||
assert "Waiting for Service" in updates[0]["value"]
|
||||
assert updates[1]["choices"] == []
|
||||
assert updates[1]["value"] is None
|
||||
assert updates[2]["active"] is True # Timer should continue
|
||||
|
||||
|
||||
def test_update_status_error(mock_timer):
|
||||
"""Test status update when an error occurs"""
|
||||
with patch("ui.lib.api.check_api_status", side_effect=Exception("Test error")), \
|
||||
patch("gradio.Timer", return_value=mock_timer):
|
||||
demo = create_interface()
|
||||
update_fn = mock_timer.events[0].fn
|
||||
|
||||
updates = update_fn()
|
||||
|
||||
assert "Connection Error" in updates[0]["value"]
|
||||
assert updates[1]["choices"] == []
|
||||
assert updates[1]["value"] is None
|
||||
assert updates[2]["active"] is True # Timer should continue
|
||||
|
||||
|
||||
def test_timer_configuration(mock_timer):
|
||||
"""Test timer configuration"""
|
||||
with patch("ui.lib.api.check_api_status", return_value=(False, [])), \
|
||||
patch("gradio.Timer", return_value=mock_timer):
|
||||
demo = create_interface()
|
||||
|
||||
assert mock_timer.value == 5 # Check interval is 5 seconds
|
||||
assert len(mock_timer.events) == 1 # Should have one event handler
|
||||
|
||||
|
||||
def test_interface_components_presence():
|
||||
"""Test that all required components are present"""
|
||||
with patch("ui.lib.api.check_api_status", return_value=(False, [])):
|
||||
demo = create_interface()
|
||||
|
||||
# Check for main component sections
|
||||
components = {
|
||||
comp.label for comp in demo.blocks.values()
|
||||
if hasattr(comp, 'label') and comp.label
|
||||
}
|
||||
|
||||
required_components = {
|
||||
"Text to speak",
|
||||
"Voice",
|
||||
"Audio Format",
|
||||
"Speed",
|
||||
"Generated Speech",
|
||||
"Previous Outputs"
|
||||
}
|
||||
|
||||
assert required_components.issubset(components)
|
Loading…
Add table
Reference in a new issue