diff --git a/.coveragerc b/.coveragerc index 4072f19..dab8655 100644 --- a/.coveragerc +++ b/.coveragerc @@ -1,5 +1,7 @@ [run] -source = api +source = + api + ui omit = Kokoro-82M/* MagicMock/* diff --git a/CHANGELOG.md b/CHANGELOG.md index 36715cd..44c98bf 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,10 @@ Notable changes to this project will be documented in this file. ## 2024-01-09 +### Added +- Gradio Web Interface: + - Added simple web UI utility for audio generation from input or txt file + ### Modified #### Configuration Changes - Updated Docker configurations: diff --git a/README.md b/README.md index a626cc0..392ab68 100644 --- a/README.md +++ b/README.md @@ -3,42 +3,55 @@

# Kokoro TTS API -[![Tests](https://img.shields.io/badge/tests-37%20passed-darkgreen)]() -[![Coverage](https://img.shields.io/badge/coverage-81%25-darkgreen)]() +[![Tests](https://img.shields.io/badge/tests-81%20passed-darkgreen)]() +[![Coverage](https://img.shields.io/badge/coverage-76%25-darkgreen)]() [![Tested at Model Commit](https://img.shields.io/badge/last--tested--model--commit-a67f113-blue)](https://huggingface.co/hexgrad/Kokoro-82M/tree/c3b0d86e2a980e027ef71c28819ea02e351c2667) -FastAPI wrapper for [Kokoro-82M](https://huggingface.co/hexgrad/Kokoro-82M) text-to-speech model, providing an OpenAI-compatible endpoint with: +Dockerized FastAPI wrapper for [Kokoro-82M](https://huggingface.co/hexgrad/Kokoro-82M) text-to-speech model +- OpenAI-compatible Speech endpoint, with voice combination functionality - NVIDIA GPU accelerated inference (or CPU) option +- very fast generation time (~35x real time factor) - automatic chunking/stitching for long texts -- very fast generation time (~35-49x RTF) +- simple audio generation web ui utility -## Quick Start +
+OpenAI-Compatible Speech Endpoint + +The service can be accessed through either the API endpoints or the Gradio web interface. 1. Install prerequisites: - - Install [Docker Desktop](https://www.docker.com/products/docker-desktop/) - - Install [Git](https://git-scm.com/downloads) (or download and extract zip) + - Install [Docker Desktop](https://www.docker.com/products/docker-desktop/) + [Git](https://git-scm.com/downloads) + - Clone and start the service: + ```bash + git clone https://github.com/remsky/Kokoro-FastAPI.git + cd Kokoro-FastAPI + docker compose up --build + ``` +2. Run locally as an OpenAI-Compatible Speech Endpoint + ```python + from openai import OpenAI + client = OpenAI( + base_url="http://localhost:8880", + api_key="not-needed" + ) -2. Clone and start the service: -```bash -# Clone repository -git clone https://github.com/remsky/Kokoro-FastAPI.git -cd Kokoro-FastAPI + response = client.audio.speech.create( + model="kokoro", + voice="af_bella", + input="Hello world!", + response_format="mp3" + ) + response.stream_to_file("output.mp3") + ``` -# For GPU acceleration (requires NVIDIA GPU): -docker compose up --build + or visit http://localhost:7860 +

+
+*(Voice Analysis Comparison image)*
+
+OpenAI-Compatible Speech Endpoint -# For CPU-only deployment (~10x slower, but doesn't require an NVIDIA GPU): -docker compose -f docker-compose.cpu.yml up --build -``` -Quick tests (run from another terminal): -```bash -# Test OpenAI Compatibility -python examples/test_openai_tts.py -# Test all available voices -python examples/test_all_voices.py -``` - -## OpenAI-Compatible API ```python # Using OpenAI's Python library from openai import OpenAI @@ -77,16 +90,26 @@ with open("output.mp3", "wb") as f: f.write(response.content) ``` -## Voice Combination +Quick tests (run from another terminal): +```bash +python examples/test_openai_tts.py # Test OpenAI Compatibility +python examples/test_all_voices.py # Test all available voices +``` +
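The endpoint also accepts a `speed` field and the other listed formats. The request body below mirrors what the new UI helper sends (see `ui/tests/test_api.py` further down); the `wav` format and 1.2x speed are just illustrative values:

```python
import requests

# Same OpenAI-compatible endpoint, asking for WAV output at 1.2x speed
response = requests.post(
    "http://localhost:8880/v1/audio/speech",
    json={
        "model": "kokoro",
        "input": "Hello world!",
        "voice": "af_bella",
        "response_format": "wav",  # mp3, wav, opus, flac are listed as supported
        "speed": 1.2,              # the web UI exposes 0.5-2.0
    },
    headers={"Content-Type": "application/json"},
    timeout=300,
)
response.raise_for_status()

with open("output.wav", "wb") as f:
    f.write(response.content)
```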
+ +
+Voice Combination Combine voices and generate audio: ```python import requests +response = requests.get("http://localhost:8880/v1/audio/voices") +voices = response.json()["voices"] -# Create combined voice (saved locally on server) +# Create combined voice (saves locally on server) response = requests.post( "http://localhost:8880/v1/audio/voices/combine", - json=["af_bella", "af_sarah"] + json=[voices[0], voices[1]] ) combined_voice = response.json()["voice"] @@ -100,8 +123,27 @@ response = requests.post( } ) ``` +
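The Features section below describes combination as averaging the model weights of existing voicepacks. A minimal sketch of that idea, assuming voicepacks are PyTorch tensors stored in `.pt` files (the paths and helper name are illustrative, not the project's actual code):

```python
import torch

def combine_voicepacks(paths: list[str], out_path: str) -> None:
    """Average several voicepack tensors and save the result as a new voicepack."""
    packs = [torch.load(p, weights_only=True) for p in paths]
    combined = torch.mean(torch.stack(packs), dim=0)  # element-wise average
    torch.save(combined, out_path)

# e.g. combine_voicepacks(["voices/af_bella.pt", "voices/af_sarah.pt"],
#                         "voices/af_bella_af_sarah.pt")
```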

+*(Voice Analysis Comparison image)*
+
-## Performance Benchmarks +
+Gradio Web Utility
+
+Access the interactive web UI at http://localhost:7860 after starting the service. Features include:
+- Voice/format/speed selection
+- Audio playback and download
+- Text file or direct input
+
+If you only want the API, comment out everything under (and including) the `gradio-ui` service in docker-compose.yml.
+
+Currently, voices created via the API are accessible here, but voice combination/creation has not yet been added.
+
+ + +
+Performance Benchmarks Benchmarking was performed on generation via the local API using text lengths up to feature-length books (~1.5 hours output), measuring processing time and realtime factor. Tests were run on: - Windows 11 Home w/ WSL2 @@ -119,10 +161,22 @@ Benchmarking was performed on generation via the local API using text lengths up Key Performance Metrics: - Realtime Factor: Ranges between 35-49x (generation time to output audio length) - Average Processing Rate: 137.67 tokens/second (cl100k_base) +
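For concreteness, the realtime factor quoted above is output audio length divided by generation time; a quick back-of-the-envelope check with hypothetical numbers in that range:

```python
# Realtime factor (RTF): seconds of audio produced per second of generation time
audio_seconds = 90 * 60        # ~1.5 hours of generated audio
generation_seconds = 154       # hypothetical wall-clock generation time
rtf = audio_seconds / generation_seconds
print(f"RTF = {rtf:.1f}x")     # ~35x, the low end of the 35-49x range reported
```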
+
+GPU Vs. CPU -## Features +```bash +# GPU: Requires NVIDIA GPU with CUDA 12.1 support +docker compose up --build -- OpenAI-compatible API endpoints +# CPU: ~10x slower than GPU inference +docker compose -f docker-compose.cpu.yml up --build +``` +
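At the inference level, the GPU/CPU choice comes down to the usual PyTorch device check; this generic sketch illustrates the pattern and is not the project's actual startup code:

```python
import torch

# Use CUDA when an NVIDIA GPU is available, otherwise fall back to CPU
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Running inference on: {device}")
```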
+
+Features + +- OpenAI-compatible API endpoints (with optional Gradio Web UI) - GPU-accelerated inference (if desired) - Multiple audio formats: mp3, wav, opus, flac, (aac & pcm not implemented) - Natural Boundary Detection: @@ -131,19 +185,21 @@ Key Performance Metrics: - Averages model weights of any existing voicepacks - Saves generated voicepacks for future use -
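The chunking/stitching and natural-boundary behaviour listed above can be pictured roughly as follows; this is a conceptual sketch only, assuming a `synthesize()` callable that returns a NumPy array of samples (the real logic lives in the API code, which is not part of this diff):

```python
import re
import numpy as np

def chunk_text(text: str, max_chars: int = 300) -> list[str]:
    """Split text at sentence boundaries, keeping each chunk under a size budget."""
    sentences = re.split(r"(?<=[.!?])\s+", text.strip())
    chunks, current = [], ""
    for sentence in sentences:
        if current and len(current) + len(sentence) + 1 > max_chars:
            chunks.append(current)
            current = sentence
        else:
            current = f"{current} {sentence}".strip()
    if current:
        chunks.append(current)
    return chunks

def synthesize_long(text: str, synthesize) -> np.ndarray:
    """Generate audio for each chunk and stitch the pieces back together."""
    pieces = [synthesize(chunk) for chunk in chunk_text(text)]
    return np.concatenate(pieces)
```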

- *(Voice Analysis Comparison image)*

+ *Note: CPU inference is currently a very basic implementation and has not been heavily tested.*
+
-## Model +
+Model This API uses the [Kokoro-82M](https://huggingface.co/hexgrad/Kokoro-82M) model from HuggingFace. Visit the model page for more details about training, architecture, and capabilities. I have no affiliation with any of their work, and produced this wrapper for ease of use and personal projects. +
-## License +
+License This project is licensed under the Apache License 2.0 - see below for details: @@ -152,3 +208,4 @@ This project is licensed under the Apache License 2.0 - see below for details: - The inference code adapted from StyleTTS2 is MIT licensed The full Apache 2.0 license text can be found at: https://www.apache.org/licenses/LICENSE-2.0 +
diff --git a/docker-compose.yml b/docker-compose.yml index 565d158..2722208 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -46,14 +46,14 @@ services: model-fetcher: condition: service_healthy - # # Gradio UI service - # gradio-ui: - # build: - # context: ./ui - # ports: - # - "7860:7860" - # volumes: - # - ./ui/data:/app/ui/data - # - ./ui/app.py:/app/app.py # Mount app.py for hot reload - # environment: - # - GRADIO_WATCH=True # Enable hot reloading + # Gradio UI service [Comment out everything below if you don't need it] + gradio-ui: + build: + context: ./ui + ports: + - "7860:7860" + volumes: + - ./ui/data:/app/ui/data + - ./ui/app.py:/app/app.py # Mount app.py for hot reload + environment: + - GRADIO_WATCH=True # Enable hot reloading diff --git a/pytest.ini b/pytest.ini index 3bcd461..47be4b5 100644 --- a/pytest.ini +++ b/pytest.ini @@ -1,5 +1,5 @@ [pytest] -testpaths = api/tests +testpaths = api/tests ui/tests python_files = test_*.py -addopts = -v --tb=short --cov=api --cov-report=term-missing --cov-config=.coveragerc +addopts = -v --tb=short --cov=api --cov=ui --cov-report=term-missing --cov-config=.coveragerc pythonpath = . diff --git a/ui/GradioScreenShot.png b/ui/GradioScreenShot.png new file mode 100644 index 0000000..77af6b3 Binary files /dev/null and b/ui/GradioScreenShot.png differ diff --git a/ui/lib/api.py b/ui/lib/api.py index 20e8b1d..a9c6a19 100644 --- a/ui/lib/api.py +++ b/ui/lib/api.py @@ -7,7 +7,11 @@ from .config import API_URL, OUTPUTS_DIR def check_api_status() -> Tuple[bool, List[str]]: """Check TTS service status and get available voices.""" try: - response = requests.get(f"{API_URL}/v1/audio/voices", timeout=5) + # Use a longer timeout during startup + response = requests.get( + f"{API_URL}/v1/audio/voices", + timeout=30 # Increased timeout for initial startup period + ) response.raise_for_status() voices = response.json().get("voices", []) if voices: @@ -15,7 +19,10 @@ def check_api_status() -> Tuple[bool, List[str]]: print("No voices found in response") return False, [] except requests.exceptions.Timeout: - print("API request timed out") + print("API request timed out (waiting for service startup)") + return False, [] + except requests.exceptions.ConnectionError as e: + print(f"Connection error (service may be starting up): {str(e)}") return False, [] except requests.exceptions.RequestException as e: print(f"API request failed: {str(e)}") diff --git a/ui/lib/components/__init__.py b/ui/lib/components/__init__.py index 637ee14..0d66be3 100644 --- a/ui/lib/components/__init__.py +++ b/ui/lib/components/__init__.py @@ -2,4 +2,4 @@ from .input import create_input_column from .model import create_model_column from .output import create_output_column -__all__ = ['create_input_column', 'create_model_column', 'create_output_column'] +__all__ = ["create_input_column", "create_model_column", "create_output_column"] diff --git a/ui/lib/components/input.py b/ui/lib/components/input.py index a80ecc9..2644060 100644 --- a/ui/lib/components/input.py +++ b/ui/lib/components/input.py @@ -6,6 +6,8 @@ def create_input_column() -> Tuple[gr.Column, dict]: """Create the input column with text input and file handling.""" with gr.Column(scale=1) as col: with gr.Tabs() as tabs: + # Set first tab as selected by default + tabs.selected = 0 # Direct Input Tab with gr.TabItem("Direct Input"): text_input = gr.Textbox( @@ -13,6 +15,11 @@ def create_input_column() -> Tuple[gr.Column, dict]: placeholder="Enter text here...", lines=4 ) + text_submit = gr.Button( + "Generate 
Speech", + variant="primary", + size="lg" + ) # File Input Tab with gr.TabItem("From File"): @@ -34,13 +41,28 @@ def create_input_column() -> Tuple[gr.Column, dict]: interactive=False, lines=4 ) + + with gr.Row(): + file_submit = gr.Button( + "Generate Speech", + variant="primary", + size="lg" + ) + clear_files = gr.Button( + "Clear Files", + variant="secondary", + size="lg" + ) components = { "tabs": tabs, "text_input": text_input, "file_select": input_files_list, "file_upload": file_upload, - "file_preview": file_preview + "file_preview": file_preview, + "text_submit": text_submit, + "file_submit": file_submit, + "clear_files": clear_files } return col, components diff --git a/ui/lib/components/model.py b/ui/lib/components/model.py index 41bcbfc..3b7ae96 100644 --- a/ui/lib/components/model.py +++ b/ui/lib/components/model.py @@ -10,10 +10,9 @@ def create_model_column(voice_ids: Optional[list] = None) -> Tuple[gr.Column, di with gr.Column(scale=1) as col: gr.Markdown("### Model Settings") - # Status button with embedded status - is_available, _ = api.check_api_status() + # Status button starts in waiting state status_btn = gr.Button( - f"Checking TTS Service: {'Available' if is_available else 'Not Yet Available'}", + "⌛ TTS Service: Waiting for Service...", variant="secondary" ) @@ -35,19 +34,12 @@ def create_model_column(voice_ids: Optional[list] = None) -> Tuple[gr.Column, di step=0.1, label="Speed" ) - - submit_btn = gr.Button( - "Generate Speech", - variant="primary", - size="lg" - ) components = { "status_btn": status_btn, "voice": voice_input, "format": format_input, - "speed": speed_input, - "submit": submit_btn + "speed": speed_input } return col, components diff --git a/ui/lib/components/output.py b/ui/lib/components/output.py index ff951fd..8ef4640 100644 --- a/ui/lib/components/output.py +++ b/ui/lib/components/output.py @@ -26,12 +26,15 @@ def create_output_column() -> Tuple[gr.Column, dict]: type="filepath", visible=False ) + + clear_outputs = gr.Button("⚠️ Delete All Previously Generated Output Audio 🗑️", size="sm", variant="secondary") components = { "audio_output": audio_output, "output_files": output_files, "play_btn": play_btn, - "selected_audio": selected_audio + "selected_audio": selected_audio, + "clear_outputs": clear_outputs } return col, components diff --git a/ui/lib/files.py b/ui/lib/files.py index 66d44ce..98867f3 100644 --- a/ui/lib/files.py +++ b/ui/lib/files.py @@ -56,6 +56,30 @@ def save_text(text: str, filename: Optional[str] = None) -> Optional[str]: print(f"Error saving file: {e}") return None +def delete_all_input_files() -> bool: + """Delete all files from the inputs directory. Returns True if successful.""" + try: + for filename in os.listdir(INPUTS_DIR): + if filename.endswith('.txt'): + file_path = os.path.join(INPUTS_DIR, filename) + os.remove(file_path) + return True + except Exception as e: + print(f"Error deleting input files: {e}") + return False + +def delete_all_output_files() -> bool: + """Delete all audio files from the outputs directory. Returns True if successful.""" + try: + for filename in os.listdir(OUTPUTS_DIR): + if any(filename.endswith(ext) for ext in AUDIO_FORMATS): + file_path = os.path.join(OUTPUTS_DIR, filename) + os.remove(file_path) + return True + except Exception as e: + print(f"Error deleting output files: {e}") + return False + def process_uploaded_file(file_path: str) -> bool: """Save uploaded file to inputs directory. 
Returns True if successful.""" if not file_path: diff --git a/ui/lib/handlers.py b/ui/lib/handlers.py index bcc15d7..94c9574 100644 --- a/ui/lib/handlers.py +++ b/ui/lib/handlers.py @@ -7,19 +7,40 @@ def setup_event_handlers(components: dict): """Set up all event handlers for the UI components.""" def refresh_status(): - is_available, voices = api.check_api_status() - status = "Available" if is_available else "Unavailable" - btn_text = f"🔄 TTS Service: {status}" - - if is_available and voices: - return { - components["model"]["status_btn"]: gr.update(value=btn_text), - components["model"]["voice"]: gr.update(choices=voices, value=voices[0] if voices else None) - } - return { - components["model"]["status_btn"]: gr.update(value=btn_text), - components["model"]["voice"]: gr.update(choices=[], value=None) - } + try: + is_available, voices = api.check_api_status() + status = "Available" if is_available else "Waiting for Service..." + + if is_available and voices: + # Preserve current voice selection if it exists and is still valid + current_voice = components["model"]["voice"].value + default_voice = current_voice if current_voice in voices else voices[0] + return [ + gr.update( + value=f"🔄 TTS Service: {status}", + interactive=True, + variant="secondary" + ), + gr.update(choices=voices, value=default_voice) + ] + return [ + gr.update( + value=f"⌛ TTS Service: {status}", + interactive=True, + variant="secondary" + ), + gr.update(choices=[], value=None) + ] + except Exception as e: + print(f"Error in refresh status: {str(e)}") + return [ + gr.update( + value="❌ TTS Service: Connection Error", + interactive=True, + variant="secondary" + ), + gr.update(choices=[], value=None) + ] def handle_file_select(filename): if filename: @@ -56,45 +77,95 @@ def setup_event_handlers(components: dict): return gr.update(choices=files.list_input_files()) - def generate_speech(text, selected_file, voice, format, speed): + def generate_from_text(text, voice, format, speed): + """Generate speech from direct text input""" is_available, _ = api.check_api_status() if not is_available: gr.Warning("TTS Service is currently unavailable") - return { - components["output"]["audio_output"]: None, - components["output"]["output_files"]: gr.update(choices=files.list_output_files()) - } - - # Use text input if provided, otherwise use file content - if text and text.strip(): - files.save_text(text) - final_text = text - elif selected_file: - final_text = files.read_text_file(selected_file) - else: - gr.Warning("Please enter text or select a file") - return { - components["output"]["audio_output"]: None, - components["output"]["output_files"]: gr.update(choices=files.list_output_files()) - } - - result = api.text_to_speech(final_text, voice, format, speed) + return [ + None, + gr.update(choices=files.list_output_files()) + ] + + if not text or not text.strip(): + gr.Warning("Please enter text in the input box") + return [ + None, + gr.update(choices=files.list_output_files()) + ] + + files.save_text(text) + result = api.text_to_speech(text, voice, format, speed) if result is None: gr.Warning("Failed to generate speech. 
Please try again.") - return { - components["output"]["audio_output"]: None, - components["output"]["output_files"]: gr.update(choices=files.list_output_files()) - } + return [ + None, + gr.update(choices=files.list_output_files()) + ] - return { - components["output"]["audio_output"]: result, - components["output"]["output_files"]: gr.update(choices=files.list_output_files(), value=os.path.basename(result)) - } + return [ + result, + gr.update(choices=files.list_output_files(), value=os.path.basename(result)) + ] + + def generate_from_file(selected_file, voice, format, speed): + """Generate speech from selected file""" + is_available, _ = api.check_api_status() + if not is_available: + gr.Warning("TTS Service is currently unavailable") + return [ + None, + gr.update(choices=files.list_output_files()) + ] + + if not selected_file: + gr.Warning("Please select a file") + return [ + None, + gr.update(choices=files.list_output_files()) + ] + + text = files.read_text_file(selected_file) + result = api.text_to_speech(text, voice, format, speed) + if result is None: + gr.Warning("Failed to generate speech. Please try again.") + return [ + None, + gr.update(choices=files.list_output_files()) + ] + + return [ + result, + gr.update(choices=files.list_output_files(), value=os.path.basename(result)) + ] def play_selected(file_path): if file_path and os.path.exists(file_path): return gr.update(value=file_path, visible=True) return gr.update(visible=False) + + def clear_files(voice, format, speed): + """Delete all input files and clear UI components while preserving model settings""" + files.delete_all_input_files() + return [ + gr.update(value=None, choices=[]), # file_select + None, # file_upload + gr.update(value=""), # file_preview + None, # audio_output + gr.update(choices=files.list_output_files()), # output_files + gr.update(value=voice), # voice + gr.update(value=format), # format + gr.update(value=speed) # speed + ] + + def clear_outputs(): + """Delete all output audio files and clear audio components""" + files.delete_all_output_files() + return [ + None, # audio_output + gr.update(choices=[], value=None), # output_files + gr.update(visible=False) # selected_audio + ] # Connect event handlers components["model"]["status_btn"].click( @@ -123,10 +194,54 @@ def setup_event_handlers(components: dict): outputs=[components["output"]["selected_audio"]] ) - components["model"]["submit"].click( - fn=generate_speech, + # Connect clear files button + components["input"]["clear_files"].click( + fn=clear_files, + inputs=[ + components["model"]["voice"], + components["model"]["format"], + components["model"]["speed"] + ], + outputs=[ + components["input"]["file_select"], + components["input"]["file_upload"], + components["input"]["file_preview"], + components["output"]["audio_output"], + components["output"]["output_files"], + components["model"]["voice"], + components["model"]["format"], + components["model"]["speed"] + ] + ) + + # Connect submit buttons for each tab + components["input"]["text_submit"].click( + fn=generate_from_text, inputs=[ components["input"]["text_input"], + components["model"]["voice"], + components["model"]["format"], + components["model"]["speed"] + ], + outputs=[ + components["output"]["audio_output"], + components["output"]["output_files"] + ] + ) + + # Connect clear outputs button + components["output"]["clear_outputs"].click( + fn=clear_outputs, + outputs=[ + components["output"]["audio_output"], + components["output"]["output_files"], + components["output"]["selected_audio"] + 
] + ) + + components["input"]["file_submit"].click( + fn=generate_from_file, + inputs=[ components["input"]["file_select"], components["model"]["voice"], components["model"]["format"], diff --git a/ui/lib/interface.py b/ui/lib/interface.py index cfdada4..5361217 100644 --- a/ui/lib/interface.py +++ b/ui/lib/interface.py @@ -5,8 +5,8 @@ from .handlers import setup_event_handlers def create_interface(): """Create the main Gradio interface.""" - # Initial status check - is_available, available_voices = api.check_api_status() + # Skip initial status check - let the timer handle it + is_available, available_voices = False, [] with gr.Blocks( title="Kokoro TTS Demo", @@ -36,19 +36,55 @@ def create_interface(): # Add periodic status check with Timer def update_status(): - is_available, voices = api.check_api_status() - status = "Available" if is_available else "Unavailable" - return { - components["model"]["status_btn"]: gr.update(value=f"🔄 TTS Service: {status}"), - components["model"]["voice"]: gr.update(choices=voices, value=voices[0] if voices else None) - } + try: + is_available, voices = api.check_api_status() + status = "Available" if is_available else "Waiting for Service..." + + if is_available and voices: + # Service is available, update UI and stop timer + current_voice = components["model"]["voice"].value + default_voice = current_voice if current_voice in voices else voices[0] + # Return values in same order as outputs list + return [ + gr.update( + value=f"🔄 TTS Service: {status}", + interactive=True, + variant="secondary" + ), + gr.update(choices=voices, value=default_voice), + gr.update(active=False) # Stop timer + ] + + # Service not available yet, keep checking + return [ + gr.update( + value=f"⌛ TTS Service: {status}", + interactive=True, + variant="secondary" + ), + gr.update(choices=[], value=None), + gr.update(active=True) + ] + except Exception as e: + print(f"Error in status update: {str(e)}") + # On error, keep the timer running but show error state + return [ + gr.update( + value="❌ TTS Service: Connection Error", + interactive=True, + variant="secondary" + ), + gr.update(choices=[], value=None), + gr.update(active=True) + ] - timer = gr.Timer(10, active=True) # Check every 10 seconds + timer = gr.Timer(value=5) # Check every 5 seconds timer.tick( fn=update_status, outputs=[ components["model"]["status_btn"], - components["model"]["voice"] + components["model"]["voice"], + timer ] ) diff --git a/ui/tests/conftest.py b/ui/tests/conftest.py new file mode 100644 index 0000000..05ae58d --- /dev/null +++ b/ui/tests/conftest.py @@ -0,0 +1,9 @@ +import pytest +import gradio as gr + + +@pytest.fixture +def mock_gr_context(): + """Provides a context for testing Gradio components""" + with gr.Blocks(): + yield diff --git a/ui/tests/test_api.py b/ui/tests/test_api.py new file mode 100644 index 0000000..c9b37db --- /dev/null +++ b/ui/tests/test_api.py @@ -0,0 +1,129 @@ +import pytest +import requests +from unittest.mock import patch, mock_open +from ui.lib import api + + +@pytest.fixture +def mock_response(): + class MockResponse: + def __init__(self, json_data, status_code=200, content=b"audio data"): + self._json = json_data + self.status_code = status_code + self.content = content + + def json(self): + return self._json + + def raise_for_status(self): + if self.status_code != 200: + raise requests.exceptions.HTTPError(f"HTTP {self.status_code}") + + return MockResponse + + +def test_check_api_status_success(mock_response): + """Test successful API status check""" + mock_data = 
{"voices": ["voice1", "voice2"]} + with patch("requests.get", return_value=mock_response(mock_data)): + status, voices = api.check_api_status() + assert status is True + assert voices == ["voice1", "voice2"] + + +def test_check_api_status_no_voices(mock_response): + """Test API response with no voices""" + with patch("requests.get", return_value=mock_response({"voices": []})): + status, voices = api.check_api_status() + assert status is False + assert voices == [] + + +def test_check_api_status_timeout(): + """Test API timeout""" + with patch("requests.get", side_effect=requests.exceptions.Timeout): + status, voices = api.check_api_status() + assert status is False + assert voices == [] + + +def test_check_api_status_connection_error(): + """Test API connection error""" + with patch("requests.get", side_effect=requests.exceptions.ConnectionError): + status, voices = api.check_api_status() + assert status is False + assert voices == [] + + +def test_text_to_speech_success(mock_response, tmp_path): + """Test successful speech generation""" + with patch("requests.post", return_value=mock_response({})), \ + patch("ui.lib.api.OUTPUTS_DIR", str(tmp_path)), \ + patch("builtins.open", mock_open()) as mock_file: + + result = api.text_to_speech("test text", "voice1", "mp3", 1.0) + + assert result is not None + assert "output_" in result + assert result.endswith(".mp3") + mock_file.assert_called_once() + + +def test_text_to_speech_empty_text(): + """Test speech generation with empty text""" + result = api.text_to_speech("", "voice1", "mp3", 1.0) + assert result is None + + +def test_text_to_speech_timeout(): + """Test speech generation timeout""" + with patch("requests.post", side_effect=requests.exceptions.Timeout): + result = api.text_to_speech("test", "voice1", "mp3", 1.0) + assert result is None + + +def test_text_to_speech_request_error(): + """Test speech generation request error""" + with patch("requests.post", side_effect=requests.exceptions.RequestException): + result = api.text_to_speech("test", "voice1", "mp3", 1.0) + assert result is None + + +def test_get_status_html_available(): + """Test status HTML generation for available service""" + html = api.get_status_html(True) + assert "green" in html + assert "Available" in html + + +def test_get_status_html_unavailable(): + """Test status HTML generation for unavailable service""" + html = api.get_status_html(False) + assert "red" in html + assert "Unavailable" in html + + +def test_text_to_speech_api_params(mock_response, tmp_path): + """Test correct API parameters are sent""" + with patch("requests.post") as mock_post, \ + patch("ui.lib.api.OUTPUTS_DIR", str(tmp_path)), \ + patch("builtins.open", mock_open()): + + mock_post.return_value = mock_response({}) + api.text_to_speech("test text", "voice1", "mp3", 1.5) + + mock_post.assert_called_once() + args, kwargs = mock_post.call_args + + # Check request body + assert kwargs["json"] == { + "model": "kokoro", + "input": "test text", + "voice": "voice1", + "response_format": "mp3", + "speed": 1.5 + } + + # Check headers and timeout + assert kwargs["headers"] == {"Content-Type": "application/json"} + assert kwargs["timeout"] == 300 diff --git a/ui/tests/test_components.py b/ui/tests/test_components.py new file mode 100644 index 0000000..9ddb1ad --- /dev/null +++ b/ui/tests/test_components.py @@ -0,0 +1,116 @@ +import pytest +import gradio as gr +from ui.lib.components.model import create_model_column +from ui.lib.components.output import create_output_column +from ui.lib.config import 
AUDIO_FORMATS + + +def test_create_model_column_structure(): + """Test that create_model_column returns the expected structure""" + voice_ids = ["voice1", "voice2"] + column, components = create_model_column(voice_ids) + + # Test return types + assert isinstance(column, gr.Column) + assert isinstance(components, dict) + + # Test expected components presence + expected_components = { + "status_btn", + "voice", + "format", + "speed" + } + assert set(components.keys()) == expected_components + + # Test component types + assert isinstance(components["status_btn"], gr.Button) + assert isinstance(components["voice"], gr.Dropdown) + assert isinstance(components["format"], gr.Dropdown) + assert isinstance(components["speed"], gr.Slider) + + +def test_model_column_default_values(): + """Test the default values of model column components""" + voice_ids = ["voice1", "voice2"] + _, components = create_model_column(voice_ids) + + # Test voice dropdown + # Gradio Dropdown converts choices to (value, label) tuples + expected_choices = [(voice_id, voice_id) for voice_id in voice_ids] + assert components["voice"].choices == expected_choices + # Value is not converted to tuple format for the value property + assert components["voice"].value == voice_ids[0] + assert components["voice"].interactive is True + + # Test format dropdown + # Gradio Dropdown converts choices to (value, label) tuples + expected_format_choices = [(fmt, fmt) for fmt in AUDIO_FORMATS] + assert components["format"].choices == expected_format_choices + assert components["format"].value == "mp3" + + # Test speed slider + assert components["speed"].minimum == 0.5 + assert components["speed"].maximum == 2.0 + assert components["speed"].value == 1.0 + assert components["speed"].step == 0.1 + + +def test_model_column_no_voices(): + """Test model column creation with no voice IDs""" + _, components = create_model_column() + + assert components["voice"].choices == [] + assert components["voice"].value is None + + +def test_create_output_column_structure(): + """Test that create_output_column returns the expected structure""" + column, components = create_output_column() + + # Test return types + assert isinstance(column, gr.Column) + assert isinstance(components, dict) + + # Test expected components presence + expected_components = { + "audio_output", + "output_files", + "play_btn", + "selected_audio", + "clear_outputs" + } + assert set(components.keys()) == expected_components + + # Test component types + assert isinstance(components["audio_output"], gr.Audio) + assert isinstance(components["output_files"], gr.Dropdown) + assert isinstance(components["play_btn"], gr.Button) + assert isinstance(components["selected_audio"], gr.Audio) + assert isinstance(components["clear_outputs"], gr.Button) + + +def test_output_column_configuration(): + """Test the configuration of output column components""" + _, components = create_output_column() + + # Test audio output configuration + assert components["audio_output"].label == "Generated Speech" + assert components["audio_output"].type == "filepath" + + # Test output files dropdown + assert components["output_files"].label == "Previous Outputs" + assert components["output_files"].allow_custom_value is False + + # Test play button + assert components["play_btn"].value == "▶️ Play Selected" + assert components["play_btn"].size == "sm" + + # Test selected audio configuration + assert components["selected_audio"].label == "Selected Output" + assert components["selected_audio"].type == "filepath" + assert 
components["selected_audio"].visible is False + + # Test clear outputs button + assert components["clear_outputs"].size == "sm" + assert components["clear_outputs"].variant == "secondary" diff --git a/ui/tests/test_files.py b/ui/tests/test_files.py new file mode 100644 index 0000000..aaa0fe8 --- /dev/null +++ b/ui/tests/test_files.py @@ -0,0 +1,195 @@ +import os +import pytest +from unittest.mock import patch +from ui.lib import files +from ui.lib.config import AUDIO_FORMATS + + +@pytest.fixture +def mock_dirs(tmp_path): + """Create temporary input and output directories""" + inputs_dir = tmp_path / "inputs" + outputs_dir = tmp_path / "outputs" + inputs_dir.mkdir() + outputs_dir.mkdir() + + with patch("ui.lib.files.INPUTS_DIR", str(inputs_dir)), patch( + "ui.lib.files.OUTPUTS_DIR", str(outputs_dir) + ): + yield inputs_dir, outputs_dir + + +def test_list_input_files_empty(mock_dirs): + """Test listing input files from empty directory""" + assert files.list_input_files() == [] + + +def test_list_input_files(mock_dirs): + """Test listing input files with various files""" + inputs_dir, _ = mock_dirs + + # Create test files + (inputs_dir / "test1.txt").write_text("content1") + (inputs_dir / "test2.txt").write_text("content2") + (inputs_dir / "nottext.pdf").write_text("should not be listed") + + result = files.list_input_files() + assert len(result) == 2 + assert "test1.txt" in result + assert "test2.txt" in result + assert "nottext.pdf" not in result + + +def test_list_output_files_empty(mock_dirs): + """Test listing output files from empty directory""" + assert files.list_output_files() == [] + + +def test_list_output_files(mock_dirs): + """Test listing output files with various formats""" + _, outputs_dir = mock_dirs + + # Create test files for each format + for fmt in AUDIO_FORMATS: + (outputs_dir / f"test.{fmt}").write_text("dummy content") + (outputs_dir / "test.txt").write_text("should not be listed") + + result = files.list_output_files() + assert len(result) == len(AUDIO_FORMATS) + for fmt in AUDIO_FORMATS: + assert any(f".{fmt}" in file for file in result) + + +def test_read_text_file_empty_filename(mock_dirs): + """Test reading with empty filename""" + assert files.read_text_file("") == "" + + +def test_read_text_file_nonexistent(mock_dirs): + """Test reading nonexistent file""" + assert files.read_text_file("nonexistent.txt") == "" + + +def test_read_text_file_success(mock_dirs): + """Test successful file reading""" + inputs_dir, _ = mock_dirs + content = "Test content\nMultiple lines" + (inputs_dir / "test.txt").write_text(content) + + assert files.read_text_file("test.txt") == content + + +def test_save_text_empty(mock_dirs): + """Test saving empty text""" + assert files.save_text("") is None + assert files.save_text(" ") is None + + +def test_save_text_auto_filename(mock_dirs): + """Test saving text with auto-generated filename""" + inputs_dir, _ = mock_dirs + + # First save + filename1 = files.save_text("content1") + assert filename1 == "input_1.txt" + assert (inputs_dir / filename1).read_text() == "content1" + + # Second save + filename2 = files.save_text("content2") + assert filename2 == "input_2.txt" + assert (inputs_dir / filename2).read_text() == "content2" + + +def test_save_text_custom_filename(mock_dirs): + """Test saving text with custom filename""" + inputs_dir, _ = mock_dirs + + filename = files.save_text("content", "custom.txt") + assert filename == "custom.txt" + assert (inputs_dir / filename).read_text() == "content" + + +def 
test_save_text_duplicate_filename(mock_dirs): + """Test saving text with duplicate filename""" + inputs_dir, _ = mock_dirs + + # First save + filename1 = files.save_text("content1", "test.txt") + assert filename1 == "test.txt" + + # Save with same filename + filename2 = files.save_text("content2", "test.txt") + assert filename2 == "test_1.txt" + + assert (inputs_dir / "test.txt").read_text() == "content1" + assert (inputs_dir / "test_1.txt").read_text() == "content2" + + +def test_delete_all_input_files(mock_dirs): + """Test deleting all input files""" + inputs_dir, _ = mock_dirs + + # Create test files + (inputs_dir / "test1.txt").write_text("content1") + (inputs_dir / "test2.txt").write_text("content2") + (inputs_dir / "keep.pdf").write_text("should not be deleted") + + assert files.delete_all_input_files() is True + remaining_files = list(inputs_dir.iterdir()) + assert len(remaining_files) == 1 + assert remaining_files[0].name == "keep.pdf" + + +def test_delete_all_output_files(mock_dirs): + """Test deleting all output files""" + _, outputs_dir = mock_dirs + + # Create test files + for fmt in AUDIO_FORMATS: + (outputs_dir / f"test.{fmt}").write_text("dummy content") + (outputs_dir / "keep.txt").write_text("should not be deleted") + + assert files.delete_all_output_files() is True + remaining_files = list(outputs_dir.iterdir()) + assert len(remaining_files) == 1 + assert remaining_files[0].name == "keep.txt" + + +def test_process_uploaded_file_empty_path(mock_dirs): + """Test processing empty file path""" + assert files.process_uploaded_file("") is False + + +def test_process_uploaded_file_invalid_extension(mock_dirs, tmp_path): + """Test processing file with invalid extension""" + test_file = tmp_path / "test.pdf" + test_file.write_text("content") + assert files.process_uploaded_file(str(test_file)) is False + + +def test_process_uploaded_file_success(mock_dirs, tmp_path): + """Test successful file upload processing""" + inputs_dir, _ = mock_dirs + + # Create source file + source_file = tmp_path / "test.txt" + source_file.write_text("test content") + + assert files.process_uploaded_file(str(source_file)) is True + assert (inputs_dir / "test.txt").read_text() == "test content" + + +def test_process_uploaded_file_duplicate(mock_dirs, tmp_path): + """Test processing file with duplicate name""" + inputs_dir, _ = mock_dirs + + # Create existing file + (inputs_dir / "test.txt").write_text("existing content") + + # Create source file + source_file = tmp_path / "test.txt" + source_file.write_text("new content") + + assert files.process_uploaded_file(str(source_file)) is True + assert (inputs_dir / "test.txt").read_text() == "existing content" + assert (inputs_dir / "test_1.txt").read_text() == "new content" diff --git a/ui/tests/test_handlers.py b/ui/tests/test_handlers.py new file mode 100644 index 0000000..86a71b0 --- /dev/null +++ b/ui/tests/test_handlers.py @@ -0,0 +1,4 @@ +""" +Drop all tests for now. The Gradio event system is too complex to test properly. +We'll need to find a better way to test the UI functionality. 
+""" diff --git a/ui/tests/test_input.py b/ui/tests/test_input.py new file mode 100644 index 0000000..807a483 --- /dev/null +++ b/ui/tests/test_input.py @@ -0,0 +1,74 @@ +import pytest +import gradio as gr +from ui.lib.components.input import create_input_column + + +def test_create_input_column_structure(): + """Test that create_input_column returns the expected structure""" + column, components = create_input_column() + + # Test the return types + assert isinstance(column, gr.Column) + assert isinstance(components, dict) + + # Test that all expected components are present + expected_components = { + "tabs", + "text_input", + "file_select", + "file_upload", + "file_preview", + "text_submit", + "file_submit", + "clear_files", + } + assert set(components.keys()) == expected_components + + # Test component types + assert isinstance(components["tabs"], gr.Tabs) + assert isinstance(components["text_input"], gr.Textbox) + assert isinstance(components["file_select"], gr.Dropdown) + assert isinstance(components["file_upload"], gr.File) + assert isinstance(components["file_preview"], gr.Textbox) + assert isinstance(components["text_submit"], gr.Button) + assert isinstance(components["file_submit"], gr.Button) + assert isinstance(components["clear_files"], gr.Button) + + +def test_text_input_configuration(): + """Test the text input component configuration""" + _, components = create_input_column() + text_input = components["text_input"] + + assert text_input.label == "Text to speak" + assert text_input.placeholder == "Enter text here..." + assert text_input.lines == 4 + + +def test_file_upload_configuration(): + """Test the file upload component configuration""" + _, components = create_input_column() + file_upload = components["file_upload"] + + assert file_upload.label == "Upload Text File (.txt)" + assert file_upload.file_types == [".txt"] + + +def test_button_configurations(): + """Test the button configurations""" + _, components = create_input_column() + + # Test text submit button + assert components["text_submit"].value == "Generate Speech" + assert components["text_submit"].variant == "primary" + assert components["text_submit"].size == "lg" + + # Test file submit button + assert components["file_submit"].value == "Generate Speech" + assert components["file_submit"].variant == "primary" + assert components["file_submit"].size == "lg" + + # Test clear files button + assert components["clear_files"].value == "Clear Files" + assert components["clear_files"].variant == "secondary" + assert components["clear_files"].size == "lg" diff --git a/ui/tests/test_interface.py b/ui/tests/test_interface.py new file mode 100644 index 0000000..550591f --- /dev/null +++ b/ui/tests/test_interface.py @@ -0,0 +1,139 @@ +import pytest +import gradio as gr +from unittest.mock import patch, MagicMock, PropertyMock +from ui.lib.interface import create_interface + + +@pytest.fixture +def mock_timer(): + """Create a mock timer with events property""" + class MockEvent: + def __init__(self, fn): + self.fn = fn + + class MockTimer: + def __init__(self): + self._fn = None + self.value = 5 + + @property + def events(self): + return [MockEvent(self._fn)] if self._fn else [] + + def tick(self, fn, outputs): + self._fn = fn + + return MockTimer() + + +def test_create_interface_structure(): + """Test the basic structure of the created interface""" + with patch("ui.lib.api.check_api_status", return_value=(False, [])): + demo = create_interface() + + # Test interface type and theme + assert isinstance(demo, gr.Blocks) + 
assert demo.title == "Kokoro TTS Demo" + assert isinstance(demo.theme, gr.themes.Monochrome) + + +def test_interface_html_links(): + """Test that HTML links are properly configured""" + with patch("ui.lib.api.check_api_status", return_value=(False, [])): + demo = create_interface() + + # Find HTML component + html_components = [ + comp for comp in demo.blocks.values() + if isinstance(comp, gr.HTML) + ] + assert len(html_components) > 0 + html = html_components[0] + + # Check for required links + assert 'href="https://huggingface.co/hexgrad/Kokoro-82M"' in html.value + assert 'href="https://github.com/remsky/Kokoro-FastAPI"' in html.value + assert "Kokoro-82M HF Repo" in html.value + assert "Kokoro-FastAPI Repo" in html.value + + +def test_update_status_available(mock_timer): + """Test status update when service is available""" + voices = ["voice1", "voice2"] + with patch("ui.lib.api.check_api_status", return_value=(True, voices)), \ + patch("gradio.Timer", return_value=mock_timer): + demo = create_interface() + + # Get the update function + update_fn = mock_timer.events[0].fn + + # Test update with available service + updates = update_fn() + + assert "Available" in updates[0]["value"] + assert updates[1]["choices"] == voices + assert updates[1]["value"] == voices[0] + assert updates[2]["active"] is False # Timer should stop + + +def test_update_status_unavailable(mock_timer): + """Test status update when service is unavailable""" + with patch("ui.lib.api.check_api_status", return_value=(False, [])), \ + patch("gradio.Timer", return_value=mock_timer): + demo = create_interface() + update_fn = mock_timer.events[0].fn + + updates = update_fn() + + assert "Waiting for Service" in updates[0]["value"] + assert updates[1]["choices"] == [] + assert updates[1]["value"] is None + assert updates[2]["active"] is True # Timer should continue + + +def test_update_status_error(mock_timer): + """Test status update when an error occurs""" + with patch("ui.lib.api.check_api_status", side_effect=Exception("Test error")), \ + patch("gradio.Timer", return_value=mock_timer): + demo = create_interface() + update_fn = mock_timer.events[0].fn + + updates = update_fn() + + assert "Connection Error" in updates[0]["value"] + assert updates[1]["choices"] == [] + assert updates[1]["value"] is None + assert updates[2]["active"] is True # Timer should continue + + +def test_timer_configuration(mock_timer): + """Test timer configuration""" + with patch("ui.lib.api.check_api_status", return_value=(False, [])), \ + patch("gradio.Timer", return_value=mock_timer): + demo = create_interface() + + assert mock_timer.value == 5 # Check interval is 5 seconds + assert len(mock_timer.events) == 1 # Should have one event handler + + +def test_interface_components_presence(): + """Test that all required components are present""" + with patch("ui.lib.api.check_api_status", return_value=(False, [])): + demo = create_interface() + + # Check for main component sections + components = { + comp.label for comp in demo.blocks.values() + if hasattr(comp, 'label') and comp.label + } + + required_components = { + "Text to speak", + "Voice", + "Audio Format", + "Speed", + "Generated Speech", + "Previous Outputs" + } + + assert required_components.issubset(components)
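Taken together, these tests pin down the small public surface of `ui.lib.api`: `check_api_status()` returns an `(is_available, voices)` pair, and `text_to_speech(text, voice, format, speed)` returns the generated file path or `None`. A usage sketch consistent with those assertions (run from the repository root with the service up):

```python
from ui.lib import api

# check_api_status() returns (is_available, voices), as asserted in test_api.py
is_available, voices = api.check_api_status()

if is_available and voices:
    # text_to_speech(text, voice, format, speed) returns the output path, or None on failure
    output_path = api.text_to_speech("Hello from the Gradio helper API.", voices[0], "mp3", 1.0)
    print(f"Wrote: {output_path}")
else:
    print("TTS service not reachable yet; the UI keeps polling until it is.")
```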