mirror of
https://github.com/remsky/Kokoro-FastAPI.git
synced 2025-08-05 16:48:53 +00:00

- Introduced v1.0 model build system integration. - Updated imports to reflect new directory structure for versioned models. - Modified environment variables. - Added version selection in the frontend for voice management. - Enhanced Docker build scripts for multi-platform support. - Updated configuration settings for default voice and model paths.
229 lines
No EOL
6.5 KiB
Python
229 lines
No EOL
6.5 KiB
Python
"""Tests for OpenAI-compatible v2 endpoints."""
|
|
|
|
import pytest
|
|
from fastapi.testclient import TestClient
|
|
from loguru import logger
|
|
|
|
from ..main import app
|
|
|
|
|
|
@pytest.fixture
def client():
    """Provide a ``TestClient`` wrapping the imported ``app`` for each test."""
    test_client = TestClient(app)
    return test_client
|
|
|
|
|
|
def test_health_check(client):
    """GET ``/health`` must answer 200 with the healthy-status JSON body."""
    expected_body = {"status": "healthy"}

    resp = client.get("/health")

    assert resp.status_code == 200
    assert resp.json() == expected_body
|
|
|
|
|
|
def test_list_versions(client):
    """GET ``/v2/audio/versions`` must list both model versions.

    The JSON body has to contain the ``versions`` and ``current`` keys, and
    ``versions`` has to include ``v0.19`` as well as ``v1.0``.
    """
    resp = client.get("/v2/audio/versions")
    assert resp.status_code == 200

    payload = resp.json()
    # Top-level keys first, then the individual version identifiers.
    for required_key in ("versions", "current"):
        assert required_key in payload
    for version_name in ("v0.19", "v1.0"):
        assert version_name in payload["versions"]
|
|
|
|
|
|
def test_set_version(client):
    """POST ``/v2/audio/version`` must switch versions and reject unknown ones.

    The version name is sent as a bare JSON string. The test switches to
    ``v1.0``, then back to ``v0.19``, checking the echoed ``current`` value
    each time, and finally expects a 400 for an unknown version name.
    """
    endpoint = "/v2/audio/version"

    # Set to v1.0, then set back to v0.19.
    for target in ("v1.0", "v0.19"):
        resp = client.post(endpoint, json=target)
        assert resp.status_code == 200
        body = resp.json()
        assert body["current"] == target

    # Test invalid version
    rejected = client.post(endpoint, json="invalid_version")
    assert rejected.status_code == 400
|
|
|
|
|
|
def test_list_voices(client):
    """GET ``/v2/audio/voices`` must return a non-empty ``voices`` collection."""
    resp = client.get("/v2/audio/voices")
    assert resp.status_code == 200

    payload = resp.json()
    assert "voices" in payload
    voice_count = len(payload["voices"])
    assert voice_count > 0
|
|
|
|
|
|
def test_combine_voices(client):
    """POST ``/v2/audio/voices/combine`` must accept string and list bodies.

    Both request shapes are expected to yield a 200 response whose JSON body
    contains the ``voice`` and ``voices`` keys.
    """
    endpoint = "/v2/audio/voices/combine"
    request_bodies = (
        "af_bella+af_nicole",  # Test with string input
        ["af_bella", "af_nicole"],  # Test with list input
    )

    for request_body in request_bodies:
        resp = client.post(endpoint, json=request_body)
        assert resp.status_code == 200
        result = resp.json()
        assert "voice" in result
        assert "voices" in result
|
|
|
|
|
|
def test_speech_generation_v0_19(client):
    """Non-streaming speech request against v0.19 must return WAV audio.

    Expects a 200 response with ``content-type: audio/wav`` and a non-empty
    body.
    """
    payload = {
        "model": "tts-1",
        "input": "Hello, world!",
        "voice": "af_bella",
        "response_format": "wav",
        "speed": 1.0,
        "stream": False,
        "version": "v0.19",
    }

    resp = client.post("/v2/audio/speech", json=payload)

    assert resp.status_code == 200
    assert resp.headers["content-type"] == "audio/wav"
    audio_bytes = resp.content
    assert len(audio_bytes) > 0
|
|
|
|
|
|
def test_speech_generation_v1_0(client):
    """Non-streaming speech request against v1.0 must return WAV audio.

    Expects a 200 response with ``content-type: audio/wav`` and a non-empty
    body.
    """
    payload = {
        "model": "tts-1",
        "input": "Hello, world!",
        "voice": "af_bella",
        "response_format": "wav",
        "speed": 1.0,
        "stream": False,
        "version": "v1.0",
    }

    resp = client.post("/v2/audio/speech", json=payload)

    assert resp.status_code == 200
    assert resp.headers["content-type"] == "audio/wav"
    audio_bytes = resp.content
    assert len(audio_bytes) > 0
|
|
|
|
|
|
def test_streaming_speech_v0_19(client):
    """Streaming speech request against v0.19 must deliver non-empty chunks.

    Sends the request with ``stream: True`` and iterates the response bytes.
    Expects a 200 status, no empty chunk, and a non-empty total payload.
    """
    request_data = {
        "model": "tts-1",
        "input": "Hello, world!",
        "voice": "af_bella",
        "response_format": "wav",
        "speed": 1.0,
        "stream": True,
        "version": "v0.19",
    }

    with client.stream("POST", "/v2/audio/speech", json=request_data) as response:
        assert response.status_code == 200
        chunks = []
        for chunk in response.iter_bytes():
            assert len(chunk) > 0
            chunks.append(chunk)
        # Join once: repeated ``bytes +=`` copies the whole buffer per chunk.
        content = b"".join(chunks)
        assert len(content) > 0
|
|
|
|
|
|
def test_streaming_speech_v1_0(client):
    """Streaming speech request against v1.0 must deliver non-empty chunks.

    Sends the request with ``stream: True`` and iterates the response bytes.
    Expects a 200 status, no empty chunk, and a non-empty total payload.
    """
    request_data = {
        "model": "tts-1",
        "input": "Hello, world!",
        "voice": "af_bella",
        "response_format": "wav",
        "speed": 1.0,
        "stream": True,
        "version": "v1.0",
    }

    with client.stream("POST", "/v2/audio/speech", json=request_data) as response:
        assert response.status_code == 200
        chunks = []
        for chunk in response.iter_bytes():
            assert len(chunk) > 0
            chunks.append(chunk)
        # Join once: repeated ``bytes +=`` copies the whole buffer per chunk.
        content = b"".join(chunks)
        assert len(content) > 0
|
|
|
|
|
|
def test_invalid_model(client):
    """An unknown ``model`` must be rejected with 400 and ``invalid_model``."""
    payload = {
        "model": "invalid-model",
        "input": "Hello, world!",
        "voice": "af_bella",
        "response_format": "wav",
        "version": "v1.0",
    }

    resp = client.post("/v2/audio/speech", json=payload)
    assert resp.status_code == 400

    error_body = resp.json()
    # Presence check first so a missing key fails as an assertion, not a KeyError.
    assert "error" in error_body
    assert error_body["error"] == "invalid_model"
|
|
|
|
|
|
def test_invalid_voice(client):
    """An unknown ``voice`` must be rejected with 400 and ``validation_error``."""
    payload = {
        "model": "tts-1",
        "input": "Hello, world!",
        "voice": "invalid_voice",
        "response_format": "wav",
        "version": "v1.0",
    }

    resp = client.post("/v2/audio/speech", json=payload)
    assert resp.status_code == 400

    error_body = resp.json()
    # Presence check first so a missing key fails as an assertion, not a KeyError.
    assert "error" in error_body
    assert error_body["error"] == "validation_error"
|
|
|
|
|
|
def test_invalid_version(client):
    """An unknown ``version`` must be rejected with 400 and ``validation_error``."""
    payload = {
        "model": "tts-1",
        "input": "Hello, world!",
        "voice": "af_bella",
        "response_format": "wav",
        "version": "invalid_version",
    }

    resp = client.post("/v2/audio/speech", json=payload)
    assert resp.status_code == 400

    error_body = resp.json()
    # Presence check first so a missing key fails as an assertion, not a KeyError.
    assert "error" in error_body
    assert error_body["error"] == "validation_error"
|
|
|
|
|
|
def test_download_link(client):
    """A streamed request with ``return_download_link`` must expose a download.

    The streaming response has to carry an ``X-Download-Path`` header, and a
    follow-up GET on ``/download/<path>`` has to return 200 with a non-empty
    body.
    """
    payload = {
        "model": "tts-1",
        "input": "Hello, world!",
        "voice": "af_bella",
        "response_format": "wav",
        "speed": 1.0,
        "stream": True,
        "return_download_link": True,
        "version": "v1.0",
    }

    with client.stream("POST", "/v2/audio/speech", json=payload) as stream_resp:
        assert stream_resp.status_code == 200
        response_headers = stream_resp.headers
        assert "X-Download-Path" in response_headers
        download_path = response_headers["X-Download-Path"]

        # Try downloading the file
        download_url = f"/download/{download_path}"
        fetched = client.get(download_url)
        assert fetched.status_code == 200
        assert len(fetched.content) > 0