Kokoro-FastAPI/api/src/routers/development.py

131 lines
4 KiB
Python
Raw Normal View History

from typing import List
2025-01-09 18:41:44 -07:00
import numpy as np
from loguru import logger
2025-01-09 18:41:44 -07:00
from fastapi import Depends, Response, APIRouter, HTTPException
from ..services.audio import AudioService
from ..services.tts_model import TTSModel
2025-01-09 18:41:44 -07:00
from ..services.tts_service import TTSService
from ..structures.text_schemas import (
PhonemeRequest,
PhonemeResponse,
GenerateFromPhonemesRequest,
)
from ..services.text_processing import tokenize, phonemize
# Router that the phonemize / generate-from-phonemes endpoints below register on;
# it is created with the "text processing" tag.
router = APIRouter(tags=["text processing"])
2025-01-09 18:41:44 -07:00
def get_tts_service() -> TTSService:
    """Provide a TTSService for routes that request one via ``Depends``.

    Returns:
        The object produced by calling ``TTSService()``.
    """
    service = TTSService()
    return service
2025-01-09 18:41:44 -07:00
@router.post("/text/phonemize", response_model=PhonemeResponse, tags=["deprecated"])
@router.post("/dev/phonemize", response_model=PhonemeResponse)
async def phonemize_text(request: PhonemeRequest) -> PhonemeResponse:
    """Convert text to phonemes and token IDs.

    Registered at ``/dev/phonemize`` and, tagged as deprecated, at
    ``/text/phonemize``.

    Args:
        request: Request containing the text and its language.

    Returns:
        The phonemes together with their token IDs, wrapped in the
        start/end token ``0``.

    Raises:
        HTTPException: 400 when a ``ValueError`` is raised (empty text, or
            no phonemes produced for the text); 500 for any other failure.
    """
    try:
        if not request.text:
            raise ValueError("Text cannot be empty")

        # Get phonemes
        phonemes = phonemize(request.text, request.language)
        if not phonemes:
            raise ValueError("Failed to generate phonemes")

        # Get tokens
        tokens = tokenize(phonemes)
        tokens = [0] + tokens + [0]  # Add start/end tokens

        return PhonemeResponse(phonemes=phonemes, tokens=tokens)
    except ValueError as e:
        # A ValueError signals a problem with the submitted input, so answer
        # with 400 "Invalid request" (as generate_from_phonemes does) rather
        # than reporting it as a server fault.
        logger.error(f"Error in phoneme generation: {str(e)}")
        raise HTTPException(
            status_code=400, detail={"error": "Invalid request", "message": str(e)}
        ) from e
    except Exception as e:
        logger.error(f"Error in phoneme generation: {str(e)}")
        raise HTTPException(
            status_code=500, detail={"error": "Server error", "message": str(e)}
        ) from e
2025-01-09 18:41:44 -07:00
@router.post("/text/generate_from_phonemes", tags=["deprecated"])
@router.post("/dev/generate_from_phonemes")
async def generate_from_phonemes(
    request: GenerateFromPhonemesRequest,
    tts_service: TTSService = Depends(get_tts_service),
) -> Response:
    """Synthesize a WAV response straight from a phoneme sequence.

    Args:
        request: Carries the phonemes, the voice name and the speed.
        tts_service: TTSService supplied through dependency injection.

    Returns:
        A response whose body is the converted WAV bytes, sent as an
        attachment named ``speech.wav`` with caching disabled.

    Raises:
        HTTPException: 400 for empty phonemes, an unresolved voice, or a
            ``ValueError`` during generation; 500 for any other failure.
    """
    # Reject an empty phoneme payload before any voice lookup happens
    if not request.phonemes:
        empty_detail = {
            "error": "Invalid request",
            "message": "Phonemes cannot be empty",
        }
        raise HTTPException(status_code=400, detail=empty_detail)

    # The requested voice must resolve to a path before it can be loaded
    resolved_voice = tts_service._get_voice_path(request.voice)
    if not resolved_voice:
        missing_detail = {
            "error": "Invalid request",
            "message": f"Voice not found: {request.voice}",
        }
        raise HTTPException(status_code=400, detail=missing_detail)

    try:
        loaded_voice = tts_service._load_voice(resolved_voice)

        # Token 0 brackets the sequence as its start/end marker
        token_ids = [0] + tokenize(request.phonemes) + [0]

        # Audio comes directly from the tokens, bypassing text processing
        samples = TTSModel.generate_from_tokens(
            token_ids, loaded_voice, request.speed
        )

        # NOTE(review): 24000 is presumably the sample rate in Hz — confirm
        # against AudioService.convert_audio.
        wav_payload = AudioService.convert_audio(
            samples,
            24000,
            "wav",
            is_first_chunk=True,
            is_last_chunk=True,
            stream=False,
        )

        response_headers = {
            "Content-Disposition": "attachment; filename=speech.wav",
            "Cache-Control": "no-cache",
        }
        return Response(
            content=wav_payload,
            media_type="audio/wav",
            headers=response_headers,
        )
    except ValueError as err:
        reason = str(err)
        logger.error(f"Invalid request: {reason}")
        raise HTTPException(
            status_code=400,
            detail={"error": "Invalid request", "message": reason},
        )
    except Exception as err:
        reason = str(err)
        logger.error(f"Error generating audio: {reason}")
        raise HTTPException(
            status_code=500,
            detail={"error": "Server error", "message": reason},
        )