Kokoro-FastAPI/api/tests/test_openai_endpoints_v2.py
remsky 9a588a3483 WIP: 1.0 integration
- Introduced v1.0 model build system integration.
- Updated imports to reflect new directory structure for versioned models.
- Modified environment variables
- Added version selection in the frontend for voice management.
- Enhanced Docker build scripts for multi-platform support.
- Updated configuration settings for default voice and model paths.
2025-01-31 05:55:57 -07:00

229 lines
No EOL
6.5 KiB
Python

"""Tests for OpenAI-compatible v2 endpoints."""
import pytest
from fastapi.testclient import TestClient
from loguru import logger
from ..main import app
@pytest.fixture
def client():
    """Provide a ``TestClient`` bound to the application under test.

    NOTE(review): the client is returned without being entered as a context
    manager, so application lifespan (startup/shutdown) handlers are
    presumably not triggered by this fixture -- confirm that is intended.
    """
    test_client = TestClient(app)
    return test_client
def test_health_check(client):
    """The health endpoint answers 200 with a ``healthy`` status payload."""
    expected_body = {"status": "healthy"}

    health = client.get("/health")

    assert health.status_code == 200
    assert health.json() == expected_body
def test_list_versions(client):
    """The versions endpoint lists both known versions and a current one."""
    listing = client.get("/v2/audio/versions")
    assert listing.status_code == 200

    body = listing.json()
    # Top-level keys the payload must expose.
    for required_key in ("versions", "current"):
        assert required_key in body
    # Both model versions must be advertised.
    for expected_version in ("v0.19", "v1.0"):
        assert expected_version in body["versions"]
def test_set_version(client):
    """Switching succeeds for known versions and is rejected for unknown ones.

    The version is changed by POSTing to ``/v2/audio/version``. Because the
    client wraps the module-level ``app``, that change presumably remains
    visible to tests that run afterwards. The version reported as current
    before the test starts is therefore recorded and put back in a
    ``finally`` clause, so a failing assertion part-way through cannot leave
    a version chosen here active for other tests.
    """
    version_url = "/v2/audio/version"

    # Snapshot the active version before anything is changed.
    previous = client.get("/v2/audio/versions").json().get("current")

    try:
        # Known versions: switch to v1.0, then back to v0.19.
        for version in ("v1.0", "v0.19"):
            response = client.post(version_url, json=version)
            assert response.status_code == 200
            assert response.json()["current"] == version

        # An unknown version must be rejected.
        response = client.post(version_url, json="invalid_version")
        assert response.status_code == 400
    finally:
        # Best-effort restore; skipped when no prior version was reported.
        if previous is not None:
            client.post(version_url, json=previous)
def test_list_voices(client):
    """The voices endpoint returns a non-empty ``voices`` collection."""
    listing = client.get("/v2/audio/voices")
    assert listing.status_code == 200

    body = listing.json()
    assert "voices" in body
    available = body["voices"]
    assert len(available) > 0
def test_combine_voices(client):
    """Voice combination accepts a ``+``-joined string or a list of names."""
    combine_inputs = (
        "af_bella+af_nicole",  # string input
        ["af_bella", "af_nicole"],  # list input
    )

    # Both input shapes must produce the same response structure.
    for voices in combine_inputs:
        result = client.post("/v2/audio/voices/combine", json=voices)
        assert result.status_code == 200

        body = result.json()
        assert "voice" in body
        assert "voices" in body
def test_speech_generation_v0_19(client):
    """Non-streaming synthesis with v0.19 returns a non-empty WAV response."""
    result = client.post(
        "/v2/audio/speech",
        json={
            "model": "tts-1",
            "input": "Hello, world!",
            "voice": "af_bella",
            "response_format": "wav",
            "speed": 1.0,
            "stream": False,
            "version": "v0.19",
        },
    )

    assert result.status_code == 200
    assert result.headers["content-type"] == "audio/wav"
    audio = result.content
    assert len(audio) > 0
def test_speech_generation_v1_0(client):
    """Non-streaming synthesis with v1.0 returns a non-empty WAV response."""
    result = client.post(
        "/v2/audio/speech",
        json={
            "model": "tts-1",
            "input": "Hello, world!",
            "voice": "af_bella",
            "response_format": "wav",
            "speed": 1.0,
            "stream": False,
            "version": "v1.0",
        },
    )

    assert result.status_code == 200
    assert result.headers["content-type"] == "audio/wav"
    audio = result.content
    assert len(audio) > 0
def test_streaming_speech_v0_19(client):
    """Streaming synthesis with v0.19 yields only non-empty chunks."""
    payload = {
        "model": "tts-1",
        "input": "Hello, world!",
        "voice": "af_bella",
        "response_format": "wav",
        "speed": 1.0,
        "stream": True,
        "version": "v0.19",
    }

    with client.stream("POST", "/v2/audio/speech", json=payload) as stream:
        assert stream.status_code == 200

        # Collect the pieces and join once instead of growing a bytes object.
        received = []
        for piece in stream.iter_bytes():
            assert len(piece) > 0
            received.append(piece)
        assert len(b"".join(received)) > 0
def test_streaming_speech_v1_0(client):
    """Streaming synthesis with v1.0 yields only non-empty chunks."""
    payload = {
        "model": "tts-1",
        "input": "Hello, world!",
        "voice": "af_bella",
        "response_format": "wav",
        "speed": 1.0,
        "stream": True,
        "version": "v1.0",
    }

    with client.stream("POST", "/v2/audio/speech", json=payload) as stream:
        assert stream.status_code == 200

        # Collect the pieces and join once instead of growing a bytes object.
        received = []
        for piece in stream.iter_bytes():
            assert len(piece) > 0
            received.append(piece)
        assert len(b"".join(received)) > 0
def test_invalid_model(client):
    """An unknown model name is rejected with a 400 ``invalid_model`` error."""
    rejection = client.post(
        "/v2/audio/speech",
        json={
            "model": "invalid-model",
            "input": "Hello, world!",
            "voice": "af_bella",
            "response_format": "wav",
            "version": "v1.0",
        },
    )

    assert rejection.status_code == 400
    body = rejection.json()
    assert "error" in body
    assert body["error"] == "invalid_model"
def test_invalid_voice(client):
    """An unknown voice is rejected with a 400 ``validation_error``."""
    rejection = client.post(
        "/v2/audio/speech",
        json={
            "model": "tts-1",
            "input": "Hello, world!",
            "voice": "invalid_voice",
            "response_format": "wav",
            "version": "v1.0",
        },
    )

    assert rejection.status_code == 400
    body = rejection.json()
    assert "error" in body
    assert body["error"] == "validation_error"
def test_invalid_version(client):
    """An unknown version is rejected with a 400 ``validation_error``."""
    rejection = client.post(
        "/v2/audio/speech",
        json={
            "model": "tts-1",
            "input": "Hello, world!",
            "voice": "af_bella",
            "response_format": "wav",
            "version": "invalid_version",
        },
    )

    assert rejection.status_code == 400
    body = rejection.json()
    assert "error" in body
    assert body["error"] == "validation_error"
def test_download_link(client):
    """A streamed request with a download link exposes a fetchable file.

    The speech request is made with ``return_download_link`` enabled; the
    value of the ``X-Download-Path`` response header is then fetched from
    under ``/download/`` and must return a non-empty body.
    """
    payload = {
        "model": "tts-1",
        "input": "Hello, world!",
        "voice": "af_bella",
        "response_format": "wav",
        "speed": 1.0,
        "stream": True,
        "return_download_link": True,
        "version": "v1.0",
    }

    with client.stream("POST", "/v2/audio/speech", json=payload) as stream:
        assert stream.status_code == 200

        response_headers = stream.headers
        assert "X-Download-Path" in response_headers

        # NOTE(review): the URL is built by prefixing "/download/", which
        # assumes the header holds a bare file name rather than a full
        # path -- confirm against the endpoint.
        download_url = f"/download/{response_headers['X-Download-Path']}"

        # Try downloading the file
        downloaded = client.get(download_url)
        assert downloaded.status_code == 200
        assert len(downloaded.content) > 0