From 8ed2f2afb6537084f51f68bd8fc6362a5079fad3 Mon Sep 17 00:00:00 2001
From: remsky
Date: Sun, 9 Feb 2025 20:55:21 -0700
Subject: [PATCH] Add model listing and retrieval endpoints with tests
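
The new /v1/models and /v1/models/{model} endpoints mirror the OpenAI
models API, so OpenAI-compatible clients can discover the available TTS
models. A minimal usage sketch (assumptions: the server listens on
http://localhost:8880 as in debug.http, the openai Python package is
installed, and no API key is enforced, so the key below is a placeholder):

    from openai import OpenAI

    # Point an OpenAI client at the local Kokoro-FastAPI server
    client = OpenAI(base_url="http://localhost:8880/v1", api_key="not-needed")

    # GET /v1/models -> {"object": "list", "data": [...]}
    for model in client.models.list():
        print(model.id)  # tts-1, tts-1-hd, kokoro

    # GET /v1/models/{model} -> single model object, 404 if unknown
    print(client.models.retrieve("tts-1").owned_by)  # kokoro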
---
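Note: for an unknown model, the retrieval endpoint responds 404 with the
FastAPI-style error body built in the handler below (the model name in
the message is whatever the client requested):

    {
      "detail": {
        "error": "model_not_found",
        "message": "Model 'nonexistent-model' not found",
        "type": "invalid_request_error"
      }
    }
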
README.md | 4 +-
api/src/routers/openai_compatible.py | 93 ++++++++++++++++++++++++++++
api/tests/test_openai_endpoints.py | 43 +++++++++++++
debug.http | 12 ++++
4 files changed, 150 insertions(+), 2 deletions(-)
diff --git a/README.md b/README.md
index 479f966..ad4e512 100644
--- a/README.md
+++ b/README.md
@@ -3,8 +3,8 @@
# _`FastKoko`_
-[]()
-[]()
+[]()
+[]()
[](https://huggingface.co/spaces/Remsky/Kokoro-TTS-Zero)
[](https://huggingface.co/hexgrad/Kokoro-82M/commit/9901c2b79161b6e898b7ea857ae5298f47b8b0d6)
diff --git a/api/src/routers/openai_compatible.py b/api/src/routers/openai_compatible.py
index a2506bc..5508d65 100644
--- a/api/src/routers/openai_compatible.py
+++ b/api/src/routers/openai_compatible.py
@@ -344,6 +344,99 @@ async def download_audio_file(filename: str):
)
+@router.get("/models")
+async def list_models():
+ """List all available models"""
+ try:
+ # Create standard model list
+ models = [
+ {
+ "id": "tts-1",
+ "object": "model",
+ "created": 1686935002,
+ "owned_by": "kokoro"
+ },
+ {
+ "id": "tts-1-hd",
+ "object": "model",
+ "created": 1686935002,
+ "owned_by": "kokoro"
+ },
+ {
+ "id": "kokoro",
+ "object": "model",
+ "created": 1686935002,
+ "owned_by": "kokoro"
+ }
+ ]
+
+ return {
+ "object": "list",
+ "data": models
+ }
+ except Exception as e:
+ logger.error(f"Error listing models: {str(e)}")
+ raise HTTPException(
+ status_code=500,
+ detail={
+ "error": "server_error",
+ "message": "Failed to retrieve model list",
+ "type": "server_error",
+ },
+ )
+
+@router.get("/models/{model}")
+async def retrieve_model(model: str):
+ """Retrieve a specific model"""
+ try:
+ # Define available models
+ models = {
+ "tts-1": {
+ "id": "tts-1",
+ "object": "model",
+ "created": 1686935002,
+ "owned_by": "kokoro"
+ },
+ "tts-1-hd": {
+ "id": "tts-1-hd",
+ "object": "model",
+ "created": 1686935002,
+ "owned_by": "kokoro"
+ },
+ "kokoro": {
+ "id": "kokoro",
+ "object": "model",
+ "created": 1686935002,
+ "owned_by": "kokoro"
+ }
+ }
+
+ # Check if requested model exists
+ if model not in models:
+ raise HTTPException(
+ status_code=404,
+ detail={
+ "error": "model_not_found",
+ "message": f"Model '{model}' not found",
+ "type": "invalid_request_error"
+ }
+ )
+
+ # Return the specific model
+ return models[model]
+ except HTTPException:
+ raise
+ except Exception as e:
+ logger.error(f"Error retrieving model {model}: {str(e)}")
+ raise HTTPException(
+ status_code=500,
+ detail={
+ "error": "server_error",
+ "message": "Failed to retrieve model information",
+ "type": "server_error",
+ },
+ )
+
@router.get("/audio/voices")
async def list_voices():
"""List all available voices for text-to-speech"""
diff --git a/api/tests/test_openai_endpoints.py b/api/tests/test_openai_endpoints.py
index 26ac04b..531480f 100644
--- a/api/tests/test_openai_endpoints.py
+++ b/api/tests/test_openai_endpoints.py
@@ -69,6 +69,49 @@ def test_load_openai_mappings_file_not_found():
assert mappings == {"models": {}, "voices": {}}
+def test_list_models(mock_openai_mappings):
+ """Test listing available models endpoint"""
+ response = client.get("/v1/models")
+ assert response.status_code == 200
+ data = response.json()
+ assert data["object"] == "list"
+ assert isinstance(data["data"], list)
+ assert len(data["data"]) == 3 # tts-1, tts-1-hd, and kokoro
+
+ # Verify all expected models are present
+ model_ids = [model["id"] for model in data["data"]]
+ assert "tts-1" in model_ids
+ assert "tts-1-hd" in model_ids
+ assert "kokoro" in model_ids
+
+ # Verify model format
+ for model in data["data"]:
+ assert model["object"] == "model"
+ assert "created" in model
+ assert model["owned_by"] == "kokoro"
+
+
+def test_retrieve_model(mock_openai_mappings):
+ """Test retrieving a specific model endpoint"""
+ # Test successful model retrieval
+ response = client.get("/v1/models/tts-1")
+ assert response.status_code == 200
+ data = response.json()
+ assert data["id"] == "tts-1"
+ assert data["object"] == "model"
+ assert data["owned_by"] == "kokoro"
+ assert "created" in data
+
+ # Test non-existent model
+ response = client.get("/v1/models/nonexistent-model")
+ assert response.status_code == 404
+ error = response.json()
+ assert error["detail"]["error"] == "model_not_found"
+ assert "not found" in error["detail"]["message"]
+ assert error["detail"]["type"] == "invalid_request_error"
+
+
+
@pytest.mark.asyncio
async def test_get_tts_service_initialization():
"""Test TTSService initialization"""
diff --git a/debug.http b/debug.http
index be70880..83c8860 100644
--- a/debug.http
+++ b/debug.http
@@ -14,4 +14,16 @@ Accept: application/json
# Shows active ONNX sessions, CUDA stream usage, and session ages
# Useful for debugging resource exhaustion issues
GET http://localhost:8880/debug/session_pools
+Accept: application/json
+
+### List Available Models
+# Returns list of all available models in OpenAI format
+# Response includes tts-1, tts-1-hd, and kokoro models
+GET http://localhost:8880/v1/models
+Accept: application/json
+
+### Get Specific Model
+# Returns details for the requested model in OpenAI format
+# Unknown model names return 404 (available: tts-1, tts-1-hd, kokoro)
+GET http://localhost:8880/v1/models/tts-1
Accept: application/json
\ No newline at end of file