diff --git a/.coveragerc b/.coveragerc
index 4072f19..dab8655 100644
--- a/.coveragerc
+++ b/.coveragerc
@@ -1,5 +1,7 @@
[run]
-source = api
+source =
+ api
+ ui
omit =
Kokoro-82M/*
MagicMock/*
diff --git a/CHANGELOG.md b/CHANGELOG.md
index 36715cd..44c98bf 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -4,6 +4,10 @@ Notable changes to this project will be documented in this file.
## 2024-01-09
+### Added
+- Gradio Web Interface:
+ - Added simple web UI utility for audio generation from input or txt file
+
### Modified
#### Configuration Changes
- Updated Docker configurations:
diff --git a/README.md b/README.md
index a626cc0..392ab68 100644
--- a/README.md
+++ b/README.md
@@ -3,42 +3,55 @@
# Kokoro TTS API
-[]()
-[]()
+[]()
+[]()
[](https://huggingface.co/hexgrad/Kokoro-82M/tree/c3b0d86e2a980e027ef71c28819ea02e351c2667)
-FastAPI wrapper for [Kokoro-82M](https://huggingface.co/hexgrad/Kokoro-82M) text-to-speech model, providing an OpenAI-compatible endpoint with:
+Dockerized FastAPI wrapper for [Kokoro-82M](https://huggingface.co/hexgrad/Kokoro-82M) text-to-speech model
+- OpenAI-compatible Speech endpoint, with voice combination functionality
- NVIDIA GPU accelerated inference (or CPU) option
+- very fast generation time (~35x real time factor)
- automatic chunking/stitching for long texts
-- very fast generation time (~35-49x RTF)
+- simple audio generation web ui utility
-## Quick Start
+
+OpenAI-Compatible Speech Endpoint
+
+The service can be accessed through either the API endpoints or the Gradio web interface.
1. Install prerequisites:
- - Install [Docker Desktop](https://www.docker.com/products/docker-desktop/)
- - Install [Git](https://git-scm.com/downloads) (or download and extract zip)
+ - Install [Docker Desktop](https://www.docker.com/products/docker-desktop/) + [Git](https://git-scm.com/downloads)
+ - Clone and start the service:
+ ```bash
+ git clone https://github.com/remsky/Kokoro-FastAPI.git
+ cd Kokoro-FastAPI
+ docker compose up --build
+ ```
+2. Run locally as an OpenAI-Compatible Speech Endpoint
+ ```python
+ from openai import OpenAI
+ client = OpenAI(
+ base_url="http://localhost:8880",
+ api_key="not-needed"
+ )
-2. Clone and start the service:
-```bash
-# Clone repository
-git clone https://github.com/remsky/Kokoro-FastAPI.git
-cd Kokoro-FastAPI
+ response = client.audio.speech.create(
+ model="kokoro",
+ voice="af_bella",
+ input="Hello world!",
+ response_format="mp3"
+ )
+ response.stream_to_file("output.mp3")
+ ```
-# For GPU acceleration (requires NVIDIA GPU):
-docker compose up --build
+ or visit http://localhost:7860
+
+
+
+
+
+OpenAI-Compatible Speech Endpoint
-# For CPU-only deployment (~10x slower, but doesn't require an NVIDIA GPU):
-docker compose -f docker-compose.cpu.yml up --build
-```
-Quick tests (run from another terminal):
-```bash
-# Test OpenAI Compatibility
-python examples/test_openai_tts.py
-# Test all available voices
-python examples/test_all_voices.py
-```
-
-## OpenAI-Compatible API
```python
# Using OpenAI's Python library
from openai import OpenAI
@@ -77,16 +90,26 @@ with open("output.mp3", "wb") as f:
f.write(response.content)
```
-## Voice Combination
+Quick tests (run from another terminal):
+```bash
+python examples/test_openai_tts.py # Test OpenAI Compatibility
+python examples/test_all_voices.py # Test all available voices
+```
+
+
+
+Voice Combination
Combine voices and generate audio:
```python
import requests
+response = requests.get("http://localhost:8880/v1/audio/voices")
+voices = response.json()["voices"]
-# Create combined voice (saved locally on server)
+# Create combined voice (saves locally on server)
response = requests.post(
"http://localhost:8880/v1/audio/voices/combine",
- json=["af_bella", "af_sarah"]
+ json=[voices[0], voices[1]]
)
combined_voice = response.json()["voice"]
@@ -100,8 +123,27 @@ response = requests.post(
}
)
```
+
+
+
+
-## Performance Benchmarks
+
+Gradio Web Utility
+
+Access the interactive web UI at http://localhost:7860 after starting the service. Features include:
+- Voice/format/speed selection
+- Audio playback and download
+- Text file or direct input
+
+If you only want the API, just comment out everything in the docker-compose.yml under and including `gradio-ui`
+
+Currently, voices created via the API are accessible here, but voice combination/creation has not yet been added
+
+
+
+
+Performance Benchmarks
Benchmarking was performed on generation via the local API using text lengths up to feature-length books (~1.5 hours output), measuring processing time and realtime factor. Tests were run on:
- Windows 11 Home w/ WSL2
@@ -119,10 +161,22 @@ Benchmarking was performed on generation via the local API using text lengths up
Key Performance Metrics:
- Realtime Factor: Ranges between 35-49x (generation time to output audio length)
- Average Processing Rate: 137.67 tokens/second (cl100k_base)
+
+
+GPU Vs. CPU
-## Features
+```bash
+# GPU: Requires NVIDIA GPU with CUDA 12.1 support
+docker compose up --build
-- OpenAI-compatible API endpoints
+# CPU: ~10x slower than GPU inference
+docker compose -f docker-compose.cpu.yml up --build
+```
+
+
+Features
+
+- OpenAI-compatible API endpoints (with optional Gradio Web UI)
- GPU-accelerated inference (if desired)
- Multiple audio formats: mp3, wav, opus, flac, (aac & pcm not implemented)
- Natural Boundary Detection:
@@ -131,19 +185,21 @@ Key Performance Metrics:
- Averages model weights of any existing voicepacks
- Saves generated voicepacks for future use
-
-
-
+
*Note: CPU Inference is currently a very basic implementation, and not heavily tested*
+
-## Model
+
+Model
This API uses the [Kokoro-82M](https://huggingface.co/hexgrad/Kokoro-82M) model from HuggingFace.
Visit the model page for more details about training, architecture, and capabilities. I have no affiliation with any of their work, and produced this wrapper for ease of use and personal projects.
+
-## License
+
+License
This project is licensed under the Apache License 2.0 - see below for details:
@@ -152,3 +208,4 @@ This project is licensed under the Apache License 2.0 - see below for details:
- The inference code adapted from StyleTTS2 is MIT licensed
The full Apache 2.0 license text can be found at: https://www.apache.org/licenses/LICENSE-2.0
+
diff --git a/docker-compose.yml b/docker-compose.yml
index 565d158..2722208 100644
--- a/docker-compose.yml
+++ b/docker-compose.yml
@@ -46,14 +46,14 @@ services:
model-fetcher:
condition: service_healthy
- # # Gradio UI service
- # gradio-ui:
- # build:
- # context: ./ui
- # ports:
- # - "7860:7860"
- # volumes:
- # - ./ui/data:/app/ui/data
- # - ./ui/app.py:/app/app.py # Mount app.py for hot reload
- # environment:
- # - GRADIO_WATCH=True # Enable hot reloading
+ # Gradio UI service [Comment out everything below if you don't need it]
+ gradio-ui:
+ build:
+ context: ./ui
+ ports:
+ - "7860:7860"
+ volumes:
+ - ./ui/data:/app/ui/data
+ - ./ui/app.py:/app/app.py # Mount app.py for hot reload
+ environment:
+ - GRADIO_WATCH=True # Enable hot reloading
diff --git a/pytest.ini b/pytest.ini
index 3bcd461..47be4b5 100644
--- a/pytest.ini
+++ b/pytest.ini
@@ -1,5 +1,5 @@
[pytest]
-testpaths = api/tests
+testpaths = api/tests ui/tests
python_files = test_*.py
-addopts = -v --tb=short --cov=api --cov-report=term-missing --cov-config=.coveragerc
+addopts = -v --tb=short --cov=api --cov=ui --cov-report=term-missing --cov-config=.coveragerc
pythonpath = .
diff --git a/ui/GradioScreenShot.png b/ui/GradioScreenShot.png
new file mode 100644
index 0000000..77af6b3
Binary files /dev/null and b/ui/GradioScreenShot.png differ
diff --git a/ui/lib/api.py b/ui/lib/api.py
index 20e8b1d..a9c6a19 100644
--- a/ui/lib/api.py
+++ b/ui/lib/api.py
@@ -7,7 +7,11 @@ from .config import API_URL, OUTPUTS_DIR
def check_api_status() -> Tuple[bool, List[str]]:
"""Check TTS service status and get available voices."""
try:
- response = requests.get(f"{API_URL}/v1/audio/voices", timeout=5)
+ # Use a longer timeout during startup
+ response = requests.get(
+ f"{API_URL}/v1/audio/voices",
+ timeout=30 # Increased timeout for initial startup period
+ )
response.raise_for_status()
voices = response.json().get("voices", [])
if voices:
@@ -15,7 +19,10 @@ def check_api_status() -> Tuple[bool, List[str]]:
print("No voices found in response")
return False, []
except requests.exceptions.Timeout:
- print("API request timed out")
+ print("API request timed out (waiting for service startup)")
+ return False, []
+ except requests.exceptions.ConnectionError as e:
+ print(f"Connection error (service may be starting up): {str(e)}")
return False, []
except requests.exceptions.RequestException as e:
print(f"API request failed: {str(e)}")
diff --git a/ui/lib/components/__init__.py b/ui/lib/components/__init__.py
index 637ee14..0d66be3 100644
--- a/ui/lib/components/__init__.py
+++ b/ui/lib/components/__init__.py
@@ -2,4 +2,4 @@ from .input import create_input_column
from .model import create_model_column
from .output import create_output_column
-__all__ = ['create_input_column', 'create_model_column', 'create_output_column']
+__all__ = ["create_input_column", "create_model_column", "create_output_column"]
diff --git a/ui/lib/components/input.py b/ui/lib/components/input.py
index a80ecc9..2644060 100644
--- a/ui/lib/components/input.py
+++ b/ui/lib/components/input.py
@@ -6,6 +6,8 @@ def create_input_column() -> Tuple[gr.Column, dict]:
"""Create the input column with text input and file handling."""
with gr.Column(scale=1) as col:
with gr.Tabs() as tabs:
+ # Set first tab as selected by default
+ tabs.selected = 0
# Direct Input Tab
with gr.TabItem("Direct Input"):
text_input = gr.Textbox(
@@ -13,6 +15,11 @@ def create_input_column() -> Tuple[gr.Column, dict]:
placeholder="Enter text here...",
lines=4
)
+ text_submit = gr.Button(
+ "Generate Speech",
+ variant="primary",
+ size="lg"
+ )
# File Input Tab
with gr.TabItem("From File"):
@@ -34,13 +41,28 @@ def create_input_column() -> Tuple[gr.Column, dict]:
interactive=False,
lines=4
)
+
+ with gr.Row():
+ file_submit = gr.Button(
+ "Generate Speech",
+ variant="primary",
+ size="lg"
+ )
+ clear_files = gr.Button(
+ "Clear Files",
+ variant="secondary",
+ size="lg"
+ )
components = {
"tabs": tabs,
"text_input": text_input,
"file_select": input_files_list,
"file_upload": file_upload,
- "file_preview": file_preview
+ "file_preview": file_preview,
+ "text_submit": text_submit,
+ "file_submit": file_submit,
+ "clear_files": clear_files
}
return col, components
diff --git a/ui/lib/components/model.py b/ui/lib/components/model.py
index 41bcbfc..3b7ae96 100644
--- a/ui/lib/components/model.py
+++ b/ui/lib/components/model.py
@@ -10,10 +10,9 @@ def create_model_column(voice_ids: Optional[list] = None) -> Tuple[gr.Column, di
with gr.Column(scale=1) as col:
gr.Markdown("### Model Settings")
- # Status button with embedded status
- is_available, _ = api.check_api_status()
+ # Status button starts in waiting state
status_btn = gr.Button(
- f"Checking TTS Service: {'Available' if is_available else 'Not Yet Available'}",
+ "⌛ TTS Service: Waiting for Service...",
variant="secondary"
)
@@ -35,19 +34,12 @@ def create_model_column(voice_ids: Optional[list] = None) -> Tuple[gr.Column, di
step=0.1,
label="Speed"
)
-
- submit_btn = gr.Button(
- "Generate Speech",
- variant="primary",
- size="lg"
- )
components = {
"status_btn": status_btn,
"voice": voice_input,
"format": format_input,
- "speed": speed_input,
- "submit": submit_btn
+ "speed": speed_input
}
return col, components
diff --git a/ui/lib/components/output.py b/ui/lib/components/output.py
index ff951fd..8ef4640 100644
--- a/ui/lib/components/output.py
+++ b/ui/lib/components/output.py
@@ -26,12 +26,15 @@ def create_output_column() -> Tuple[gr.Column, dict]:
type="filepath",
visible=False
)
+
+ clear_outputs = gr.Button("⚠️ Delete All Previously Generated Output Audio 🗑️", size="sm", variant="secondary")
components = {
"audio_output": audio_output,
"output_files": output_files,
"play_btn": play_btn,
- "selected_audio": selected_audio
+ "selected_audio": selected_audio,
+ "clear_outputs": clear_outputs
}
return col, components
diff --git a/ui/lib/files.py b/ui/lib/files.py
index 66d44ce..98867f3 100644
--- a/ui/lib/files.py
+++ b/ui/lib/files.py
@@ -56,6 +56,30 @@ def save_text(text: str, filename: Optional[str] = None) -> Optional[str]:
print(f"Error saving file: {e}")
return None
+def delete_all_input_files() -> bool:
+ """Delete all files from the inputs directory. Returns True if successful."""
+ try:
+ for filename in os.listdir(INPUTS_DIR):
+ if filename.endswith('.txt'):
+ file_path = os.path.join(INPUTS_DIR, filename)
+ os.remove(file_path)
+ return True
+ except Exception as e:
+ print(f"Error deleting input files: {e}")
+ return False
+
+def delete_all_output_files() -> bool:
+ """Delete all audio files from the outputs directory. Returns True if successful."""
+ try:
+ for filename in os.listdir(OUTPUTS_DIR):
+ if any(filename.endswith(ext) for ext in AUDIO_FORMATS):
+ file_path = os.path.join(OUTPUTS_DIR, filename)
+ os.remove(file_path)
+ return True
+ except Exception as e:
+ print(f"Error deleting output files: {e}")
+ return False
+
def process_uploaded_file(file_path: str) -> bool:
"""Save uploaded file to inputs directory. Returns True if successful."""
if not file_path:
diff --git a/ui/lib/handlers.py b/ui/lib/handlers.py
index bcc15d7..94c9574 100644
--- a/ui/lib/handlers.py
+++ b/ui/lib/handlers.py
@@ -7,19 +7,40 @@ def setup_event_handlers(components: dict):
"""Set up all event handlers for the UI components."""
def refresh_status():
- is_available, voices = api.check_api_status()
- status = "Available" if is_available else "Unavailable"
- btn_text = f"🔄 TTS Service: {status}"
-
- if is_available and voices:
- return {
- components["model"]["status_btn"]: gr.update(value=btn_text),
- components["model"]["voice"]: gr.update(choices=voices, value=voices[0] if voices else None)
- }
- return {
- components["model"]["status_btn"]: gr.update(value=btn_text),
- components["model"]["voice"]: gr.update(choices=[], value=None)
- }
+ try:
+ is_available, voices = api.check_api_status()
+ status = "Available" if is_available else "Waiting for Service..."
+
+ if is_available and voices:
+ # Preserve current voice selection if it exists and is still valid
+ current_voice = components["model"]["voice"].value
+ default_voice = current_voice if current_voice in voices else voices[0]
+ return [
+ gr.update(
+ value=f"🔄 TTS Service: {status}",
+ interactive=True,
+ variant="secondary"
+ ),
+ gr.update(choices=voices, value=default_voice)
+ ]
+ return [
+ gr.update(
+ value=f"⌛ TTS Service: {status}",
+ interactive=True,
+ variant="secondary"
+ ),
+ gr.update(choices=[], value=None)
+ ]
+ except Exception as e:
+ print(f"Error in refresh status: {str(e)}")
+ return [
+ gr.update(
+ value="❌ TTS Service: Connection Error",
+ interactive=True,
+ variant="secondary"
+ ),
+ gr.update(choices=[], value=None)
+ ]
def handle_file_select(filename):
if filename:
@@ -56,45 +77,95 @@ def setup_event_handlers(components: dict):
return gr.update(choices=files.list_input_files())
- def generate_speech(text, selected_file, voice, format, speed):
+ def generate_from_text(text, voice, format, speed):
+ """Generate speech from direct text input"""
is_available, _ = api.check_api_status()
if not is_available:
gr.Warning("TTS Service is currently unavailable")
- return {
- components["output"]["audio_output"]: None,
- components["output"]["output_files"]: gr.update(choices=files.list_output_files())
- }
-
- # Use text input if provided, otherwise use file content
- if text and text.strip():
- files.save_text(text)
- final_text = text
- elif selected_file:
- final_text = files.read_text_file(selected_file)
- else:
- gr.Warning("Please enter text or select a file")
- return {
- components["output"]["audio_output"]: None,
- components["output"]["output_files"]: gr.update(choices=files.list_output_files())
- }
-
- result = api.text_to_speech(final_text, voice, format, speed)
+ return [
+ None,
+ gr.update(choices=files.list_output_files())
+ ]
+
+ if not text or not text.strip():
+ gr.Warning("Please enter text in the input box")
+ return [
+ None,
+ gr.update(choices=files.list_output_files())
+ ]
+
+ files.save_text(text)
+ result = api.text_to_speech(text, voice, format, speed)
if result is None:
gr.Warning("Failed to generate speech. Please try again.")
- return {
- components["output"]["audio_output"]: None,
- components["output"]["output_files"]: gr.update(choices=files.list_output_files())
- }
+ return [
+ None,
+ gr.update(choices=files.list_output_files())
+ ]
- return {
- components["output"]["audio_output"]: result,
- components["output"]["output_files"]: gr.update(choices=files.list_output_files(), value=os.path.basename(result))
- }
+ return [
+ result,
+ gr.update(choices=files.list_output_files(), value=os.path.basename(result))
+ ]
+
+ def generate_from_file(selected_file, voice, format, speed):
+ """Generate speech from selected file"""
+ is_available, _ = api.check_api_status()
+ if not is_available:
+ gr.Warning("TTS Service is currently unavailable")
+ return [
+ None,
+ gr.update(choices=files.list_output_files())
+ ]
+
+ if not selected_file:
+ gr.Warning("Please select a file")
+ return [
+ None,
+ gr.update(choices=files.list_output_files())
+ ]
+
+ text = files.read_text_file(selected_file)
+ result = api.text_to_speech(text, voice, format, speed)
+ if result is None:
+ gr.Warning("Failed to generate speech. Please try again.")
+ return [
+ None,
+ gr.update(choices=files.list_output_files())
+ ]
+
+ return [
+ result,
+ gr.update(choices=files.list_output_files(), value=os.path.basename(result))
+ ]
def play_selected(file_path):
if file_path and os.path.exists(file_path):
return gr.update(value=file_path, visible=True)
return gr.update(visible=False)
+
+ def clear_files(voice, format, speed):
+ """Delete all input files and clear UI components while preserving model settings"""
+ files.delete_all_input_files()
+ return [
+ gr.update(value=None, choices=[]), # file_select
+ None, # file_upload
+ gr.update(value=""), # file_preview
+ None, # audio_output
+ gr.update(choices=files.list_output_files()), # output_files
+ gr.update(value=voice), # voice
+ gr.update(value=format), # format
+ gr.update(value=speed) # speed
+ ]
+
+ def clear_outputs():
+ """Delete all output audio files and clear audio components"""
+ files.delete_all_output_files()
+ return [
+ None, # audio_output
+ gr.update(choices=[], value=None), # output_files
+ gr.update(visible=False) # selected_audio
+ ]
# Connect event handlers
components["model"]["status_btn"].click(
@@ -123,10 +194,54 @@ def setup_event_handlers(components: dict):
outputs=[components["output"]["selected_audio"]]
)
- components["model"]["submit"].click(
- fn=generate_speech,
+ # Connect clear files button
+ components["input"]["clear_files"].click(
+ fn=clear_files,
+ inputs=[
+ components["model"]["voice"],
+ components["model"]["format"],
+ components["model"]["speed"]
+ ],
+ outputs=[
+ components["input"]["file_select"],
+ components["input"]["file_upload"],
+ components["input"]["file_preview"],
+ components["output"]["audio_output"],
+ components["output"]["output_files"],
+ components["model"]["voice"],
+ components["model"]["format"],
+ components["model"]["speed"]
+ ]
+ )
+
+ # Connect submit buttons for each tab
+ components["input"]["text_submit"].click(
+ fn=generate_from_text,
inputs=[
components["input"]["text_input"],
+ components["model"]["voice"],
+ components["model"]["format"],
+ components["model"]["speed"]
+ ],
+ outputs=[
+ components["output"]["audio_output"],
+ components["output"]["output_files"]
+ ]
+ )
+
+ # Connect clear outputs button
+ components["output"]["clear_outputs"].click(
+ fn=clear_outputs,
+ outputs=[
+ components["output"]["audio_output"],
+ components["output"]["output_files"],
+ components["output"]["selected_audio"]
+ ]
+ )
+
+ components["input"]["file_submit"].click(
+ fn=generate_from_file,
+ inputs=[
components["input"]["file_select"],
components["model"]["voice"],
components["model"]["format"],
diff --git a/ui/lib/interface.py b/ui/lib/interface.py
index cfdada4..5361217 100644
--- a/ui/lib/interface.py
+++ b/ui/lib/interface.py
@@ -5,8 +5,8 @@ from .handlers import setup_event_handlers
def create_interface():
"""Create the main Gradio interface."""
- # Initial status check
- is_available, available_voices = api.check_api_status()
+ # Skip initial status check - let the timer handle it
+ is_available, available_voices = False, []
with gr.Blocks(
title="Kokoro TTS Demo",
@@ -36,19 +36,55 @@ def create_interface():
# Add periodic status check with Timer
def update_status():
- is_available, voices = api.check_api_status()
- status = "Available" if is_available else "Unavailable"
- return {
- components["model"]["status_btn"]: gr.update(value=f"🔄 TTS Service: {status}"),
- components["model"]["voice"]: gr.update(choices=voices, value=voices[0] if voices else None)
- }
+ try:
+ is_available, voices = api.check_api_status()
+ status = "Available" if is_available else "Waiting for Service..."
+
+ if is_available and voices:
+ # Service is available, update UI and stop timer
+ current_voice = components["model"]["voice"].value
+ default_voice = current_voice if current_voice in voices else voices[0]
+ # Return values in same order as outputs list
+ return [
+ gr.update(
+ value=f"🔄 TTS Service: {status}",
+ interactive=True,
+ variant="secondary"
+ ),
+ gr.update(choices=voices, value=default_voice),
+ gr.update(active=False) # Stop timer
+ ]
+
+ # Service not available yet, keep checking
+ return [
+ gr.update(
+ value=f"⌛ TTS Service: {status}",
+ interactive=True,
+ variant="secondary"
+ ),
+ gr.update(choices=[], value=None),
+ gr.update(active=True)
+ ]
+ except Exception as e:
+ print(f"Error in status update: {str(e)}")
+ # On error, keep the timer running but show error state
+ return [
+ gr.update(
+ value="❌ TTS Service: Connection Error",
+ interactive=True,
+ variant="secondary"
+ ),
+ gr.update(choices=[], value=None),
+ gr.update(active=True)
+ ]
- timer = gr.Timer(10, active=True) # Check every 10 seconds
+ timer = gr.Timer(value=5) # Check every 5 seconds
timer.tick(
fn=update_status,
outputs=[
components["model"]["status_btn"],
- components["model"]["voice"]
+ components["model"]["voice"],
+ timer
]
)
diff --git a/ui/tests/conftest.py b/ui/tests/conftest.py
new file mode 100644
index 0000000..05ae58d
--- /dev/null
+++ b/ui/tests/conftest.py
@@ -0,0 +1,9 @@
+import pytest
+import gradio as gr
+
+
+@pytest.fixture
+def mock_gr_context():
+ """Provides a context for testing Gradio components"""
+ with gr.Blocks():
+ yield
diff --git a/ui/tests/test_api.py b/ui/tests/test_api.py
new file mode 100644
index 0000000..c9b37db
--- /dev/null
+++ b/ui/tests/test_api.py
@@ -0,0 +1,129 @@
+import pytest
+import requests
+from unittest.mock import patch, mock_open
+from ui.lib import api
+
+
+@pytest.fixture
+def mock_response():
+ class MockResponse:
+ def __init__(self, json_data, status_code=200, content=b"audio data"):
+ self._json = json_data
+ self.status_code = status_code
+ self.content = content
+
+ def json(self):
+ return self._json
+
+ def raise_for_status(self):
+ if self.status_code != 200:
+ raise requests.exceptions.HTTPError(f"HTTP {self.status_code}")
+
+ return MockResponse
+
+
+def test_check_api_status_success(mock_response):
+ """Test successful API status check"""
+ mock_data = {"voices": ["voice1", "voice2"]}
+ with patch("requests.get", return_value=mock_response(mock_data)):
+ status, voices = api.check_api_status()
+ assert status is True
+ assert voices == ["voice1", "voice2"]
+
+
+def test_check_api_status_no_voices(mock_response):
+ """Test API response with no voices"""
+ with patch("requests.get", return_value=mock_response({"voices": []})):
+ status, voices = api.check_api_status()
+ assert status is False
+ assert voices == []
+
+
+def test_check_api_status_timeout():
+ """Test API timeout"""
+ with patch("requests.get", side_effect=requests.exceptions.Timeout):
+ status, voices = api.check_api_status()
+ assert status is False
+ assert voices == []
+
+
+def test_check_api_status_connection_error():
+ """Test API connection error"""
+ with patch("requests.get", side_effect=requests.exceptions.ConnectionError):
+ status, voices = api.check_api_status()
+ assert status is False
+ assert voices == []
+
+
+def test_text_to_speech_success(mock_response, tmp_path):
+ """Test successful speech generation"""
+ with patch("requests.post", return_value=mock_response({})), \
+ patch("ui.lib.api.OUTPUTS_DIR", str(tmp_path)), \
+ patch("builtins.open", mock_open()) as mock_file:
+
+ result = api.text_to_speech("test text", "voice1", "mp3", 1.0)
+
+ assert result is not None
+ assert "output_" in result
+ assert result.endswith(".mp3")
+ mock_file.assert_called_once()
+
+
+def test_text_to_speech_empty_text():
+ """Test speech generation with empty text"""
+ result = api.text_to_speech("", "voice1", "mp3", 1.0)
+ assert result is None
+
+
+def test_text_to_speech_timeout():
+ """Test speech generation timeout"""
+ with patch("requests.post", side_effect=requests.exceptions.Timeout):
+ result = api.text_to_speech("test", "voice1", "mp3", 1.0)
+ assert result is None
+
+
+def test_text_to_speech_request_error():
+ """Test speech generation request error"""
+ with patch("requests.post", side_effect=requests.exceptions.RequestException):
+ result = api.text_to_speech("test", "voice1", "mp3", 1.0)
+ assert result is None
+
+
+def test_get_status_html_available():
+ """Test status HTML generation for available service"""
+ html = api.get_status_html(True)
+ assert "green" in html
+ assert "Available" in html
+
+
+def test_get_status_html_unavailable():
+ """Test status HTML generation for unavailable service"""
+ html = api.get_status_html(False)
+ assert "red" in html
+ assert "Unavailable" in html
+
+
+def test_text_to_speech_api_params(mock_response, tmp_path):
+ """Test correct API parameters are sent"""
+ with patch("requests.post") as mock_post, \
+ patch("ui.lib.api.OUTPUTS_DIR", str(tmp_path)), \
+ patch("builtins.open", mock_open()):
+
+ mock_post.return_value = mock_response({})
+ api.text_to_speech("test text", "voice1", "mp3", 1.5)
+
+ mock_post.assert_called_once()
+ args, kwargs = mock_post.call_args
+
+ # Check request body
+ assert kwargs["json"] == {
+ "model": "kokoro",
+ "input": "test text",
+ "voice": "voice1",
+ "response_format": "mp3",
+ "speed": 1.5
+ }
+
+ # Check headers and timeout
+ assert kwargs["headers"] == {"Content-Type": "application/json"}
+ assert kwargs["timeout"] == 300
diff --git a/ui/tests/test_components.py b/ui/tests/test_components.py
new file mode 100644
index 0000000..9ddb1ad
--- /dev/null
+++ b/ui/tests/test_components.py
@@ -0,0 +1,116 @@
+import pytest
+import gradio as gr
+from ui.lib.components.model import create_model_column
+from ui.lib.components.output import create_output_column
+from ui.lib.config import AUDIO_FORMATS
+
+
+def test_create_model_column_structure():
+ """Test that create_model_column returns the expected structure"""
+ voice_ids = ["voice1", "voice2"]
+ column, components = create_model_column(voice_ids)
+
+ # Test return types
+ assert isinstance(column, gr.Column)
+ assert isinstance(components, dict)
+
+ # Test expected components presence
+ expected_components = {
+ "status_btn",
+ "voice",
+ "format",
+ "speed"
+ }
+ assert set(components.keys()) == expected_components
+
+ # Test component types
+ assert isinstance(components["status_btn"], gr.Button)
+ assert isinstance(components["voice"], gr.Dropdown)
+ assert isinstance(components["format"], gr.Dropdown)
+ assert isinstance(components["speed"], gr.Slider)
+
+
+def test_model_column_default_values():
+ """Test the default values of model column components"""
+ voice_ids = ["voice1", "voice2"]
+ _, components = create_model_column(voice_ids)
+
+ # Test voice dropdown
+ # Gradio Dropdown converts choices to (value, label) tuples
+ expected_choices = [(voice_id, voice_id) for voice_id in voice_ids]
+ assert components["voice"].choices == expected_choices
+ # Value is not converted to tuple format for the value property
+ assert components["voice"].value == voice_ids[0]
+ assert components["voice"].interactive is True
+
+ # Test format dropdown
+ # Gradio Dropdown converts choices to (value, label) tuples
+ expected_format_choices = [(fmt, fmt) for fmt in AUDIO_FORMATS]
+ assert components["format"].choices == expected_format_choices
+ assert components["format"].value == "mp3"
+
+ # Test speed slider
+ assert components["speed"].minimum == 0.5
+ assert components["speed"].maximum == 2.0
+ assert components["speed"].value == 1.0
+ assert components["speed"].step == 0.1
+
+
+def test_model_column_no_voices():
+ """Test model column creation with no voice IDs"""
+ _, components = create_model_column()
+
+ assert components["voice"].choices == []
+ assert components["voice"].value is None
+
+
+def test_create_output_column_structure():
+ """Test that create_output_column returns the expected structure"""
+ column, components = create_output_column()
+
+ # Test return types
+ assert isinstance(column, gr.Column)
+ assert isinstance(components, dict)
+
+ # Test expected components presence
+ expected_components = {
+ "audio_output",
+ "output_files",
+ "play_btn",
+ "selected_audio",
+ "clear_outputs"
+ }
+ assert set(components.keys()) == expected_components
+
+ # Test component types
+ assert isinstance(components["audio_output"], gr.Audio)
+ assert isinstance(components["output_files"], gr.Dropdown)
+ assert isinstance(components["play_btn"], gr.Button)
+ assert isinstance(components["selected_audio"], gr.Audio)
+ assert isinstance(components["clear_outputs"], gr.Button)
+
+
+def test_output_column_configuration():
+ """Test the configuration of output column components"""
+ _, components = create_output_column()
+
+ # Test audio output configuration
+ assert components["audio_output"].label == "Generated Speech"
+ assert components["audio_output"].type == "filepath"
+
+ # Test output files dropdown
+ assert components["output_files"].label == "Previous Outputs"
+ assert components["output_files"].allow_custom_value is False
+
+ # Test play button
+ assert components["play_btn"].value == "▶️ Play Selected"
+ assert components["play_btn"].size == "sm"
+
+ # Test selected audio configuration
+ assert components["selected_audio"].label == "Selected Output"
+ assert components["selected_audio"].type == "filepath"
+ assert components["selected_audio"].visible is False
+
+ # Test clear outputs button
+ assert components["clear_outputs"].size == "sm"
+ assert components["clear_outputs"].variant == "secondary"
diff --git a/ui/tests/test_files.py b/ui/tests/test_files.py
new file mode 100644
index 0000000..aaa0fe8
--- /dev/null
+++ b/ui/tests/test_files.py
@@ -0,0 +1,195 @@
+import os
+import pytest
+from unittest.mock import patch
+from ui.lib import files
+from ui.lib.config import AUDIO_FORMATS
+
+
+@pytest.fixture
+def mock_dirs(tmp_path):
+ """Create temporary input and output directories"""
+ inputs_dir = tmp_path / "inputs"
+ outputs_dir = tmp_path / "outputs"
+ inputs_dir.mkdir()
+ outputs_dir.mkdir()
+
+ with patch("ui.lib.files.INPUTS_DIR", str(inputs_dir)), patch(
+ "ui.lib.files.OUTPUTS_DIR", str(outputs_dir)
+ ):
+ yield inputs_dir, outputs_dir
+
+
+def test_list_input_files_empty(mock_dirs):
+ """Test listing input files from empty directory"""
+ assert files.list_input_files() == []
+
+
+def test_list_input_files(mock_dirs):
+ """Test listing input files with various files"""
+ inputs_dir, _ = mock_dirs
+
+ # Create test files
+ (inputs_dir / "test1.txt").write_text("content1")
+ (inputs_dir / "test2.txt").write_text("content2")
+ (inputs_dir / "nottext.pdf").write_text("should not be listed")
+
+ result = files.list_input_files()
+ assert len(result) == 2
+ assert "test1.txt" in result
+ assert "test2.txt" in result
+ assert "nottext.pdf" not in result
+
+
+def test_list_output_files_empty(mock_dirs):
+ """Test listing output files from empty directory"""
+ assert files.list_output_files() == []
+
+
+def test_list_output_files(mock_dirs):
+ """Test listing output files with various formats"""
+ _, outputs_dir = mock_dirs
+
+ # Create test files for each format
+ for fmt in AUDIO_FORMATS:
+ (outputs_dir / f"test.{fmt}").write_text("dummy content")
+ (outputs_dir / "test.txt").write_text("should not be listed")
+
+ result = files.list_output_files()
+ assert len(result) == len(AUDIO_FORMATS)
+ for fmt in AUDIO_FORMATS:
+ assert any(f".{fmt}" in file for file in result)
+
+
+def test_read_text_file_empty_filename(mock_dirs):
+ """Test reading with empty filename"""
+ assert files.read_text_file("") == ""
+
+
+def test_read_text_file_nonexistent(mock_dirs):
+ """Test reading nonexistent file"""
+ assert files.read_text_file("nonexistent.txt") == ""
+
+
+def test_read_text_file_success(mock_dirs):
+ """Test successful file reading"""
+ inputs_dir, _ = mock_dirs
+ content = "Test content\nMultiple lines"
+ (inputs_dir / "test.txt").write_text(content)
+
+ assert files.read_text_file("test.txt") == content
+
+
+def test_save_text_empty(mock_dirs):
+ """Test saving empty text"""
+ assert files.save_text("") is None
+ assert files.save_text(" ") is None
+
+
+def test_save_text_auto_filename(mock_dirs):
+ """Test saving text with auto-generated filename"""
+ inputs_dir, _ = mock_dirs
+
+ # First save
+ filename1 = files.save_text("content1")
+ assert filename1 == "input_1.txt"
+ assert (inputs_dir / filename1).read_text() == "content1"
+
+ # Second save
+ filename2 = files.save_text("content2")
+ assert filename2 == "input_2.txt"
+ assert (inputs_dir / filename2).read_text() == "content2"
+
+
+def test_save_text_custom_filename(mock_dirs):
+ """Test saving text with custom filename"""
+ inputs_dir, _ = mock_dirs
+
+ filename = files.save_text("content", "custom.txt")
+ assert filename == "custom.txt"
+ assert (inputs_dir / filename).read_text() == "content"
+
+
+def test_save_text_duplicate_filename(mock_dirs):
+ """Test saving text with duplicate filename"""
+ inputs_dir, _ = mock_dirs
+
+ # First save
+ filename1 = files.save_text("content1", "test.txt")
+ assert filename1 == "test.txt"
+
+ # Save with same filename
+ filename2 = files.save_text("content2", "test.txt")
+ assert filename2 == "test_1.txt"
+
+ assert (inputs_dir / "test.txt").read_text() == "content1"
+ assert (inputs_dir / "test_1.txt").read_text() == "content2"
+
+
+def test_delete_all_input_files(mock_dirs):
+ """Test deleting all input files"""
+ inputs_dir, _ = mock_dirs
+
+ # Create test files
+ (inputs_dir / "test1.txt").write_text("content1")
+ (inputs_dir / "test2.txt").write_text("content2")
+ (inputs_dir / "keep.pdf").write_text("should not be deleted")
+
+ assert files.delete_all_input_files() is True
+ remaining_files = list(inputs_dir.iterdir())
+ assert len(remaining_files) == 1
+ assert remaining_files[0].name == "keep.pdf"
+
+
+def test_delete_all_output_files(mock_dirs):
+ """Test deleting all output files"""
+ _, outputs_dir = mock_dirs
+
+ # Create test files
+ for fmt in AUDIO_FORMATS:
+ (outputs_dir / f"test.{fmt}").write_text("dummy content")
+ (outputs_dir / "keep.txt").write_text("should not be deleted")
+
+ assert files.delete_all_output_files() is True
+ remaining_files = list(outputs_dir.iterdir())
+ assert len(remaining_files) == 1
+ assert remaining_files[0].name == "keep.txt"
+
+
+def test_process_uploaded_file_empty_path(mock_dirs):
+ """Test processing empty file path"""
+ assert files.process_uploaded_file("") is False
+
+
+def test_process_uploaded_file_invalid_extension(mock_dirs, tmp_path):
+ """Test processing file with invalid extension"""
+ test_file = tmp_path / "test.pdf"
+ test_file.write_text("content")
+ assert files.process_uploaded_file(str(test_file)) is False
+
+
+def test_process_uploaded_file_success(mock_dirs, tmp_path):
+ """Test successful file upload processing"""
+ inputs_dir, _ = mock_dirs
+
+ # Create source file
+ source_file = tmp_path / "test.txt"
+ source_file.write_text("test content")
+
+ assert files.process_uploaded_file(str(source_file)) is True
+ assert (inputs_dir / "test.txt").read_text() == "test content"
+
+
+def test_process_uploaded_file_duplicate(mock_dirs, tmp_path):
+ """Test processing file with duplicate name"""
+ inputs_dir, _ = mock_dirs
+
+ # Create existing file
+ (inputs_dir / "test.txt").write_text("existing content")
+
+ # Create source file
+ source_file = tmp_path / "test.txt"
+ source_file.write_text("new content")
+
+ assert files.process_uploaded_file(str(source_file)) is True
+ assert (inputs_dir / "test.txt").read_text() == "existing content"
+ assert (inputs_dir / "test_1.txt").read_text() == "new content"
diff --git a/ui/tests/test_handlers.py b/ui/tests/test_handlers.py
new file mode 100644
index 0000000..86a71b0
--- /dev/null
+++ b/ui/tests/test_handlers.py
@@ -0,0 +1,4 @@
+"""
+Drop all tests for now. The Gradio event system is too complex to test properly.
+We'll need to find a better way to test the UI functionality.
+"""
diff --git a/ui/tests/test_input.py b/ui/tests/test_input.py
new file mode 100644
index 0000000..807a483
--- /dev/null
+++ b/ui/tests/test_input.py
@@ -0,0 +1,74 @@
+import pytest
+import gradio as gr
+from ui.lib.components.input import create_input_column
+
+
+def test_create_input_column_structure():
+ """Test that create_input_column returns the expected structure"""
+ column, components = create_input_column()
+
+ # Test the return types
+ assert isinstance(column, gr.Column)
+ assert isinstance(components, dict)
+
+ # Test that all expected components are present
+ expected_components = {
+ "tabs",
+ "text_input",
+ "file_select",
+ "file_upload",
+ "file_preview",
+ "text_submit",
+ "file_submit",
+ "clear_files",
+ }
+ assert set(components.keys()) == expected_components
+
+ # Test component types
+ assert isinstance(components["tabs"], gr.Tabs)
+ assert isinstance(components["text_input"], gr.Textbox)
+ assert isinstance(components["file_select"], gr.Dropdown)
+ assert isinstance(components["file_upload"], gr.File)
+ assert isinstance(components["file_preview"], gr.Textbox)
+ assert isinstance(components["text_submit"], gr.Button)
+ assert isinstance(components["file_submit"], gr.Button)
+ assert isinstance(components["clear_files"], gr.Button)
+
+
+def test_text_input_configuration():
+ """Test the text input component configuration"""
+ _, components = create_input_column()
+ text_input = components["text_input"]
+
+ assert text_input.label == "Text to speak"
+ assert text_input.placeholder == "Enter text here..."
+ assert text_input.lines == 4
+
+
+def test_file_upload_configuration():
+ """Test the file upload component configuration"""
+ _, components = create_input_column()
+ file_upload = components["file_upload"]
+
+ assert file_upload.label == "Upload Text File (.txt)"
+ assert file_upload.file_types == [".txt"]
+
+
+def test_button_configurations():
+ """Test the button configurations"""
+ _, components = create_input_column()
+
+ # Test text submit button
+ assert components["text_submit"].value == "Generate Speech"
+ assert components["text_submit"].variant == "primary"
+ assert components["text_submit"].size == "lg"
+
+ # Test file submit button
+ assert components["file_submit"].value == "Generate Speech"
+ assert components["file_submit"].variant == "primary"
+ assert components["file_submit"].size == "lg"
+
+ # Test clear files button
+ assert components["clear_files"].value == "Clear Files"
+ assert components["clear_files"].variant == "secondary"
+ assert components["clear_files"].size == "lg"
diff --git a/ui/tests/test_interface.py b/ui/tests/test_interface.py
new file mode 100644
index 0000000..550591f
--- /dev/null
+++ b/ui/tests/test_interface.py
@@ -0,0 +1,139 @@
+import pytest
+import gradio as gr
+from unittest.mock import patch, MagicMock, PropertyMock
+from ui.lib.interface import create_interface
+
+
+@pytest.fixture
+def mock_timer():
+ """Create a mock timer with events property"""
+ class MockEvent:
+ def __init__(self, fn):
+ self.fn = fn
+
+ class MockTimer:
+ def __init__(self):
+ self._fn = None
+ self.value = 5
+
+ @property
+ def events(self):
+ return [MockEvent(self._fn)] if self._fn else []
+
+ def tick(self, fn, outputs):
+ self._fn = fn
+
+ return MockTimer()
+
+
+def test_create_interface_structure():
+ """Test the basic structure of the created interface"""
+ with patch("ui.lib.api.check_api_status", return_value=(False, [])):
+ demo = create_interface()
+
+ # Test interface type and theme
+ assert isinstance(demo, gr.Blocks)
+ assert demo.title == "Kokoro TTS Demo"
+ assert isinstance(demo.theme, gr.themes.Monochrome)
+
+
+def test_interface_html_links():
+ """Test that HTML links are properly configured"""
+ with patch("ui.lib.api.check_api_status", return_value=(False, [])):
+ demo = create_interface()
+
+ # Find HTML component
+ html_components = [
+ comp for comp in demo.blocks.values()
+ if isinstance(comp, gr.HTML)
+ ]
+ assert len(html_components) > 0
+ html = html_components[0]
+
+ # Check for required links
+ assert 'href="https://huggingface.co/hexgrad/Kokoro-82M"' in html.value
+ assert 'href="https://github.com/remsky/Kokoro-FastAPI"' in html.value
+ assert "Kokoro-82M HF Repo" in html.value
+ assert "Kokoro-FastAPI Repo" in html.value
+
+
+def test_update_status_available(mock_timer):
+ """Test status update when service is available"""
+ voices = ["voice1", "voice2"]
+ with patch("ui.lib.api.check_api_status", return_value=(True, voices)), \
+ patch("gradio.Timer", return_value=mock_timer):
+ demo = create_interface()
+
+ # Get the update function
+ update_fn = mock_timer.events[0].fn
+
+ # Test update with available service
+ updates = update_fn()
+
+ assert "Available" in updates[0]["value"]
+ assert updates[1]["choices"] == voices
+ assert updates[1]["value"] == voices[0]
+ assert updates[2]["active"] is False # Timer should stop
+
+
+def test_update_status_unavailable(mock_timer):
+ """Test status update when service is unavailable"""
+ with patch("ui.lib.api.check_api_status", return_value=(False, [])), \
+ patch("gradio.Timer", return_value=mock_timer):
+ demo = create_interface()
+ update_fn = mock_timer.events[0].fn
+
+ updates = update_fn()
+
+ assert "Waiting for Service" in updates[0]["value"]
+ assert updates[1]["choices"] == []
+ assert updates[1]["value"] is None
+ assert updates[2]["active"] is True # Timer should continue
+
+
+def test_update_status_error(mock_timer):
+ """Test status update when an error occurs"""
+ with patch("ui.lib.api.check_api_status", side_effect=Exception("Test error")), \
+ patch("gradio.Timer", return_value=mock_timer):
+ demo = create_interface()
+ update_fn = mock_timer.events[0].fn
+
+ updates = update_fn()
+
+ assert "Connection Error" in updates[0]["value"]
+ assert updates[1]["choices"] == []
+ assert updates[1]["value"] is None
+ assert updates[2]["active"] is True # Timer should continue
+
+
+def test_timer_configuration(mock_timer):
+ """Test timer configuration"""
+ with patch("ui.lib.api.check_api_status", return_value=(False, [])), \
+ patch("gradio.Timer", return_value=mock_timer):
+ demo = create_interface()
+
+ assert mock_timer.value == 5 # Check interval is 5 seconds
+ assert len(mock_timer.events) == 1 # Should have one event handler
+
+
+def test_interface_components_presence():
+ """Test that all required components are present"""
+ with patch("ui.lib.api.check_api_status", return_value=(False, [])):
+ demo = create_interface()
+
+ # Check for main component sections
+ components = {
+ comp.label for comp in demo.blocks.values()
+ if hasattr(comp, 'label') and comp.label
+ }
+
+ required_components = {
+ "Text to speak",
+ "Voice",
+ "Audio Format",
+ "Speed",
+ "Generated Speech",
+ "Previous Outputs"
+ }
+
+ assert required_components.issubset(components)