From 8ed2f2afb6537084f51f68bd8fc6362a5079fad3 Mon Sep 17 00:00:00 2001 From: remsky Date: Sun, 9 Feb 2025 20:55:21 -0700 Subject: [PATCH] Add model listing and retrieval endpoints with tests --- README.md | 4 +- api/src/routers/openai_compatible.py | 93 ++++++++++++++++++++++++++++ api/tests/test_openai_endpoints.py | 43 +++++++++++++ debug.http | 12 ++++ 4 files changed, 150 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 479f966..ad4e512 100644 --- a/README.md +++ b/README.md @@ -3,8 +3,8 @@

# _`FastKoko`_ -[![Tests](https://img.shields.io/badge/tests-66%20passed-darkgreen)]() -[![Coverage](https://img.shields.io/badge/coverage-54%25-tan)]() +[![Tests](https://img.shields.io/badge/tests-69%20passed-darkgreen)]() +[![Coverage](https://img.shields.io/badge/coverage-51%25-tan)]() [![Try on Spaces](https://img.shields.io/badge/%F0%9F%A4%97%20Try%20on-Spaces-blue)](https://huggingface.co/spaces/Remsky/Kokoro-TTS-Zero) [![Tested at Model Commit](https://img.shields.io/badge/last--tested--model--commit-1.0::9901c2b-blue)](https://huggingface.co/hexgrad/Kokoro-82M/commit/9901c2b79161b6e898b7ea857ae5298f47b8b0d6) diff --git a/api/src/routers/openai_compatible.py b/api/src/routers/openai_compatible.py index a2506bc..5508d65 100644 --- a/api/src/routers/openai_compatible.py +++ b/api/src/routers/openai_compatible.py @@ -344,6 +344,99 @@ async def download_audio_file(filename: str): ) +@router.get("/models") +async def list_models(): + """List all available models""" + try: + # Create standard model list + models = [ + { + "id": "tts-1", + "object": "model", + "created": 1686935002, + "owned_by": "kokoro" + }, + { + "id": "tts-1-hd", + "object": "model", + "created": 1686935002, + "owned_by": "kokoro" + }, + { + "id": "kokoro", + "object": "model", + "created": 1686935002, + "owned_by": "kokoro" + } + ] + + return { + "object": "list", + "data": models + } + except Exception as e: + logger.error(f"Error listing models: {str(e)}") + raise HTTPException( + status_code=500, + detail={ + "error": "server_error", + "message": "Failed to retrieve model list", + "type": "server_error", + }, + ) + +@router.get("/models/{model}") +async def retrieve_model(model: str): + """Retrieve a specific model""" + try: + # Define available models + models = { + "tts-1": { + "id": "tts-1", + "object": "model", + "created": 1686935002, + "owned_by": "kokoro" + }, + "tts-1-hd": { + "id": "tts-1-hd", + "object": "model", + "created": 1686935002, + "owned_by": "kokoro" + }, + "kokoro": { + "id": "kokoro", + "object": "model", + "created": 1686935002, + "owned_by": "kokoro" + } + } + + # Check if requested model exists + if model not in models: + raise HTTPException( + status_code=404, + detail={ + "error": "model_not_found", + "message": f"Model '{model}' not found", + "type": "invalid_request_error" + } + ) + + # Return the specific model + return models[model] + except HTTPException: + raise + except Exception as e: + logger.error(f"Error retrieving model {model}: {str(e)}") + raise HTTPException( + status_code=500, + detail={ + "error": "server_error", + "message": "Failed to retrieve model information", + "type": "server_error", + }, + ) + @router.get("/audio/voices") async def list_voices(): """List all available voices for text-to-speech""" diff --git a/api/tests/test_openai_endpoints.py b/api/tests/test_openai_endpoints.py index 26ac04b..531480f 100644 --- a/api/tests/test_openai_endpoints.py +++ b/api/tests/test_openai_endpoints.py @@ -69,6 +69,49 @@ def test_load_openai_mappings_file_not_found(): assert mappings == {"models": {}, "voices": {}} +def test_list_models(mock_openai_mappings): + """Test listing available models endpoint""" + response = client.get("/v1/models") + assert response.status_code == 200 + data = response.json() + assert data["object"] == "list" + assert isinstance(data["data"], list) + assert len(data["data"]) == 3 # tts-1, tts-1-hd, and kokoro + + # Verify all expected models are present + model_ids = [model["id"] for model in data["data"]] + assert "tts-1" in model_ids + assert "tts-1-hd" in model_ids + assert "kokoro" in model_ids + + # Verify model format + for model in data["data"]: + assert model["object"] == "model" + assert "created" in model + assert model["owned_by"] == "kokoro" + + +def test_retrieve_model(mock_openai_mappings): + """Test retrieving a specific model endpoint""" + # Test successful model retrieval + response = client.get("/v1/models/tts-1") + assert response.status_code == 200 + data = response.json() + assert data["id"] == "tts-1" + assert data["object"] == "model" + assert data["owned_by"] == "kokoro" + assert "created" in data + + # Test non-existent model + response = client.get("/v1/models/nonexistent-model") + assert response.status_code == 404 + error = response.json() + assert error["detail"]["error"] == "model_not_found" + assert "not found" in error["detail"]["message"] + assert error["detail"]["type"] == "invalid_request_error" + + + @pytest.mark.asyncio async def test_get_tts_service_initialization(): """Test TTSService initialization""" diff --git a/debug.http b/debug.http index be70880..83c8860 100644 --- a/debug.http +++ b/debug.http @@ -14,4 +14,16 @@ Accept: application/json # Shows active ONNX sessions, CUDA stream usage, and session ages # Useful for debugging resource exhaustion issues GET http://localhost:8880/debug/session_pools +Accept: application/json + +### List Available Models +# Returns list of all available models in OpenAI format +# Response includes tts-1, tts-1-hd, and kokoro models +GET http://localhost:8880/v1/models +Accept: application/json + +### Get Specific Model +# Returns same model list as above for compatibility +# Works with any model name (e.g., tts-1, tts-1-hd, kokoro) +GET http://localhost:8880/v1/models/tts-1 Accept: application/json \ No newline at end of file