mirror of
https://github.com/remsky/Kokoro-FastAPI.git
synced 2025-04-13 09:39:17 +00:00
Add model listing and retrieval endpoints with tests
This commit is contained in:
parent
d73ed87987
commit
8ed2f2afb6
4 changed files with 150 additions and 2 deletions
|
@ -3,8 +3,8 @@
|
||||||
</p>
|
</p>
|
||||||
|
|
||||||
# <sub><sub>_`FastKoko`_ </sub></sub>
|
# <sub><sub>_`FastKoko`_ </sub></sub>
|
||||||
[]()
|
[]()
|
||||||
[]()
|
[]()
|
||||||
[](https://huggingface.co/spaces/Remsky/Kokoro-TTS-Zero)
|
[](https://huggingface.co/spaces/Remsky/Kokoro-TTS-Zero)
|
||||||
|
|
||||||
[](https://huggingface.co/hexgrad/Kokoro-82M/commit/9901c2b79161b6e898b7ea857ae5298f47b8b0d6)
|
[](https://huggingface.co/hexgrad/Kokoro-82M/commit/9901c2b79161b6e898b7ea857ae5298f47b8b0d6)
|
||||||
|
|
|
@ -344,6 +344,99 @@ async def download_audio_file(filename: str):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/models")
async def list_models():
    """List all available models.

    Returns:
        An OpenAI-compatible model list: ``{"object": "list", "data": [...]}``
        containing the ``tts-1``, ``tts-1-hd``, and ``kokoro`` entries.

    Raises:
        HTTPException: 500 with a structured error body if listing fails.
    """
    try:
        # All models share identical metadata; build the entries from the
        # ID list instead of repeating three literal dicts (DRY).
        models = [
            {
                "id": model_id,
                "object": "model",
                "created": 1686935002,
                "owned_by": "kokoro",
            }
            for model_id in ("tts-1", "tts-1-hd", "kokoro")
        ]

        return {
            "object": "list",
            "data": models,
        }
    except Exception as e:
        logger.error(f"Error listing models: {str(e)}")
        raise HTTPException(
            status_code=500,
            detail={
                "error": "server_error",
                "message": "Failed to retrieve model list",
                "type": "server_error",
            },
        )
|
||||||
|
|
||||||
|
@router.get("/models/{model}")
async def retrieve_model(model: str):
    """Retrieve a specific model by ID.

    Args:
        model: Model identifier taken from the URL path (e.g. ``tts-1``).

    Returns:
        The OpenAI-style model object for the requested ID.

    Raises:
        HTTPException: 404 with an ``invalid_request_error`` body for unknown
            IDs; 500 with a ``server_error`` body on unexpected failures.
    """
    try:
        # Known models share identical metadata; build the registry from
        # the ID list instead of repeating three literal dicts (DRY).
        models = {
            model_id: {
                "id": model_id,
                "object": "model",
                "created": 1686935002,
                "owned_by": "kokoro",
            }
            for model_id in ("tts-1", "tts-1-hd", "kokoro")
        }

        # Unknown IDs get an OpenAI-style 404 error payload.
        if model not in models:
            raise HTTPException(
                status_code=404,
                detail={
                    "error": "model_not_found",
                    "message": f"Model '{model}' not found",
                    "type": "invalid_request_error"
                }
            )

        # Return the specific model
        return models[model]
    except HTTPException:
        # Re-raise the intentional 404 untouched; only unexpected errors
        # fall through to the generic 500 below.
        raise
    except Exception as e:
        logger.error(f"Error retrieving model {model}: {str(e)}")
        raise HTTPException(
            status_code=500,
            detail={
                "error": "server_error",
                "message": "Failed to retrieve model information",
                "type": "server_error",
            },
        )
|
||||||
|
|
||||||
@router.get("/audio/voices")
|
@router.get("/audio/voices")
|
||||||
async def list_voices():
|
async def list_voices():
|
||||||
"""List all available voices for text-to-speech"""
|
"""List all available voices for text-to-speech"""
|
||||||
|
|
|
@ -69,6 +69,49 @@ def test_load_openai_mappings_file_not_found():
|
||||||
assert mappings == {"models": {}, "voices": {}}
|
assert mappings == {"models": {}, "voices": {}}
|
||||||
|
|
||||||
|
|
||||||
|
def test_list_models(mock_openai_mappings):
    """Test listing available models endpoint"""
    response = client.get("/v1/models")
    assert response.status_code == 200

    payload = response.json()
    assert payload["object"] == "list"
    assert isinstance(payload["data"], list)
    assert len(payload["data"]) == 3  # tts-1, tts-1-hd, and kokoro

    # Every expected model ID must appear in the listing
    returned_ids = {entry["id"] for entry in payload["data"]}
    assert "tts-1" in returned_ids
    assert "tts-1-hd" in returned_ids
    assert "kokoro" in returned_ids

    # Each entry must follow the OpenAI model-object format
    for entry in payload["data"]:
        assert entry["object"] == "model"
        assert "created" in entry
        assert entry["owned_by"] == "kokoro"
|
||||||
|
|
||||||
|
|
||||||
|
def test_retrieve_model(mock_openai_mappings):
    """Test retrieving a specific model endpoint"""
    # Happy path: a known model ID returns its OpenAI-style object
    ok_response = client.get("/v1/models/tts-1")
    assert ok_response.status_code == 200

    body = ok_response.json()
    assert body["id"] == "tts-1"
    assert body["object"] == "model"
    assert body["owned_by"] == "kokoro"
    assert "created" in body

    # Unknown model IDs must produce a structured 404 error
    missing_response = client.get("/v1/models/nonexistent-model")
    assert missing_response.status_code == 404

    detail = missing_response.json()["detail"]
    assert detail["error"] == "model_not_found"
    assert "not found" in detail["message"]
    assert detail["type"] == "invalid_request_error"
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_get_tts_service_initialization():
|
async def test_get_tts_service_initialization():
|
||||||
"""Test TTSService initialization"""
|
"""Test TTSService initialization"""
|
||||||
|
|
12
debug.http
12
debug.http
|
@ -14,4 +14,16 @@ Accept: application/json
|
||||||
# Shows active ONNX sessions, CUDA stream usage, and session ages
|
# Shows active ONNX sessions, CUDA stream usage, and session ages
|
||||||
# Useful for debugging resource exhaustion issues
|
# Useful for debugging resource exhaustion issues
|
||||||
GET http://localhost:8880/debug/session_pools
|
GET http://localhost:8880/debug/session_pools
|
||||||
|
Accept: application/json
|
||||||
|
|
||||||
|
### List Available Models
|
||||||
|
# Returns list of all available models in OpenAI format
|
||||||
|
# Response includes tts-1, tts-1-hd, and kokoro models
|
||||||
|
GET http://localhost:8880/v1/models
|
||||||
|
Accept: application/json
|
||||||
|
|
||||||
|
### Get Specific Model
|
||||||
|
# Returns a single model object (not a list) for the requested ID
|
||||||
|
# Known model names: tts-1, tts-1-hd, kokoro (unknown names return 404)
|
||||||
|
GET http://localhost:8880/v1/models/tts-1
|
||||||
Accept: application/json
|
Accept: application/json
|
Loading…
Add table
Reference in a new issue