mirror of
https://github.com/remsky/Kokoro-FastAPI.git
synced 2025-04-13 09:39:17 +00:00
Add model listing and retrieval endpoints with tests
This commit is contained in:
parent
d73ed87987
commit
8ed2f2afb6
4 changed files with 150 additions and 2 deletions
|
@ -3,8 +3,8 @@
|
|||
</p>
|
||||
|
||||
# <sub><sub>_`FastKoko`_ </sub></sub>
|
||||
[]()
|
||||
[]()
|
||||
[]()
|
||||
[]()
|
||||
[](https://huggingface.co/spaces/Remsky/Kokoro-TTS-Zero)
|
||||
|
||||
[](https://huggingface.co/hexgrad/Kokoro-82M/commit/9901c2b79161b6e898b7ea857ae5298f47b8b0d6)
|
||||
|
|
|
@ -344,6 +344,99 @@ async def download_audio_file(filename: str):
|
|||
)
|
||||
|
||||
|
||||
@router.get("/models")
|
||||
async def list_models():
|
||||
"""List all available models"""
|
||||
try:
|
||||
# Create standard model list
|
||||
models = [
|
||||
{
|
||||
"id": "tts-1",
|
||||
"object": "model",
|
||||
"created": 1686935002,
|
||||
"owned_by": "kokoro"
|
||||
},
|
||||
{
|
||||
"id": "tts-1-hd",
|
||||
"object": "model",
|
||||
"created": 1686935002,
|
||||
"owned_by": "kokoro"
|
||||
},
|
||||
{
|
||||
"id": "kokoro",
|
||||
"object": "model",
|
||||
"created": 1686935002,
|
||||
"owned_by": "kokoro"
|
||||
}
|
||||
]
|
||||
|
||||
return {
|
||||
"object": "list",
|
||||
"data": models
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"Error listing models: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail={
|
||||
"error": "server_error",
|
||||
"message": "Failed to retrieve model list",
|
||||
"type": "server_error",
|
||||
},
|
||||
)
|
||||
|
||||
@router.get("/models/{model}")
|
||||
async def retrieve_model(model: str):
|
||||
"""Retrieve a specific model"""
|
||||
try:
|
||||
# Define available models
|
||||
models = {
|
||||
"tts-1": {
|
||||
"id": "tts-1",
|
||||
"object": "model",
|
||||
"created": 1686935002,
|
||||
"owned_by": "kokoro"
|
||||
},
|
||||
"tts-1-hd": {
|
||||
"id": "tts-1-hd",
|
||||
"object": "model",
|
||||
"created": 1686935002,
|
||||
"owned_by": "kokoro"
|
||||
},
|
||||
"kokoro": {
|
||||
"id": "kokoro",
|
||||
"object": "model",
|
||||
"created": 1686935002,
|
||||
"owned_by": "kokoro"
|
||||
}
|
||||
}
|
||||
|
||||
# Check if requested model exists
|
||||
if model not in models:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail={
|
||||
"error": "model_not_found",
|
||||
"message": f"Model '{model}' not found",
|
||||
"type": "invalid_request_error"
|
||||
}
|
||||
)
|
||||
|
||||
# Return the specific model
|
||||
return models[model]
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error retrieving model {model}: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail={
|
||||
"error": "server_error",
|
||||
"message": "Failed to retrieve model information",
|
||||
"type": "server_error",
|
||||
},
|
||||
)
|
||||
|
||||
@router.get("/audio/voices")
|
||||
async def list_voices():
|
||||
"""List all available voices for text-to-speech"""
|
||||
|
|
|
@ -69,6 +69,49 @@ def test_load_openai_mappings_file_not_found():
|
|||
assert mappings == {"models": {}, "voices": {}}
|
||||
|
||||
|
||||
def test_list_models(mock_openai_mappings):
|
||||
"""Test listing available models endpoint"""
|
||||
response = client.get("/v1/models")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["object"] == "list"
|
||||
assert isinstance(data["data"], list)
|
||||
assert len(data["data"]) == 3 # tts-1, tts-1-hd, and kokoro
|
||||
|
||||
# Verify all expected models are present
|
||||
model_ids = [model["id"] for model in data["data"]]
|
||||
assert "tts-1" in model_ids
|
||||
assert "tts-1-hd" in model_ids
|
||||
assert "kokoro" in model_ids
|
||||
|
||||
# Verify model format
|
||||
for model in data["data"]:
|
||||
assert model["object"] == "model"
|
||||
assert "created" in model
|
||||
assert model["owned_by"] == "kokoro"
|
||||
|
||||
|
||||
def test_retrieve_model(mock_openai_mappings):
|
||||
"""Test retrieving a specific model endpoint"""
|
||||
# Test successful model retrieval
|
||||
response = client.get("/v1/models/tts-1")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["id"] == "tts-1"
|
||||
assert data["object"] == "model"
|
||||
assert data["owned_by"] == "kokoro"
|
||||
assert "created" in data
|
||||
|
||||
# Test non-existent model
|
||||
response = client.get("/v1/models/nonexistent-model")
|
||||
assert response.status_code == 404
|
||||
error = response.json()
|
||||
assert error["detail"]["error"] == "model_not_found"
|
||||
assert "not found" in error["detail"]["message"]
|
||||
assert error["detail"]["type"] == "invalid_request_error"
|
||||
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_tts_service_initialization():
|
||||
"""Test TTSService initialization"""
|
||||
|
|
12
debug.http
12
debug.http
|
@ -15,3 +15,15 @@ Accept: application/json
|
|||
# Useful for debugging resource exhaustion issues
|
||||
GET http://localhost:8880/debug/session_pools
|
||||
Accept: application/json
|
||||
|
||||
### List Available Models
|
||||
# Returns list of all available models in OpenAI format
|
||||
# Response includes tts-1, tts-1-hd, and kokoro models
|
||||
GET http://localhost:8880/v1/models
|
||||
Accept: application/json
|
||||
|
||||
### Get Specific Model
|
||||
# Returns same model list as above for compatibility
|
||||
# Works with any model name (e.g., tts-1, tts-1-hd, kokoro)
|
||||
GET http://localhost:8880/v1/models/tts-1
|
||||
Accept: application/json
|
Loading…
Add table
Reference in a new issue