@router.post("/audio/voices/combine")
async def combine_voices(request: List[str], tts_service: TTSService = Depends(get_tts_service)):
    """Combine several voices into one and return the refreshed voice list.

    Args:
        request: JSON array of voice names to combine.
        tts_service: Service instance injected through ``get_tts_service``.

    Returns:
        A dict with ``"voices"`` (the result of ``tts_service.list_voices()``
        after combining) and ``"voice"`` (the value returned by
        ``tts_service.combine_voices``).

    Raises:
        HTTPException: With status 500 when the service raises.
    """
    try:
        combined_voice = tts_service.combine_voices(voices=request)
        # Listed after combining so the response reflects the service state
        # that follows the combine call.
        voices = tts_service.list_voices()
        return {"voices": voices, "voice": combined_voice}
    except Exception as e:
        # The message previously said "listing voices" (copied from the
        # list endpoint), which made combine failures look like list failures.
        logger.error(f"Error combining voices: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e)) from e
def combine_voices(self, voices: List[str]) -> str:
    """Average the requested voice tensors and save the result as a new voice.

    Requested names are matched against the ``.pt`` files found in the
    ``voices`` directory. The matching tensors are stacked, averaged along
    dim 0 and written to ``voices/<name1>_<name2>...pt``, with the names
    joined in the order they were requested.

    Args:
        voices: Voice names to combine (file names without ``.pt``).

    Returns:
        The name of the combined voice. ``"af"`` is returned when fewer than
        two names are given, when none of them match a voice file, or when
        listing/loading the voice files fails. When exactly one name
        matches, that name is returned and nothing is written.
    """
    if len(voices) < 2:
        return "af"

    try:
        # Only names that already exist on disk are turned into paths, so a
        # requested name can never point outside the voices directory.
        available = {
            stem
            for stem, ext in map(os.path.splitext, os.listdir("voices"))
            if ext == ".pt"
        }
        names = [name for name in voices if name in available]
        tensors = [
            torch.load(f"voices/{name}.pt", weights_only=True) for name in names
        ]
    except Exception as e:
        logger.error(f"Error combining voices: {str(e)}")
        return "af"

    missing = [name for name in voices if name not in available]
    if missing:
        logger.warning(f"Voices not found and skipped: {missing}")

    if not tensors:
        # torch.stack([]) would raise; fall back like the other failure paths.
        logger.error("Error combining voices: no requested voice was found")
        return "af"
    if len(tensors) == 1:
        # Averaging one tensor would only rewrite the source file in place.
        return names[0]

    combined_name = "_".join(names)
    combined = torch.mean(torch.stack(tensors), dim=0)
    torch.save(combined, f"voices/{combined_name}.pt")
    return combined_name
def submit_combine_voices(
    voices: List[str],
    base_url: str = "http://localhost:8880",
    timeout: float = 30.0,
) -> Optional[List[str]]:
    """Ask the API to combine ``voices`` and return the updated voice list.

    Args:
        voices: Voice names to combine; sent as the JSON request body.
        base_url: Base URL of the API server.
        timeout: Seconds to wait for the server before giving up.

    Returns:
        The ``"voices"`` list from the response, or ``None`` when the request
        fails or the server answers with a non-200 status.
    """
    try:
        # Without a timeout a stalled server would hang the script forever.
        response = requests.post(
            f"{base_url}/v1/audio/voices/combine", json=voices, timeout=timeout
        )
        if response.status_code != 200:
            print(f"Error submitting request: {response.text}")
            return None
        return response.json()["voices"]
    except requests.exceptions.RequestException as e:
        print(f"Error: {e}")
        return None


def main():
    """Parse command-line arguments, submit the combine request, print voices."""
    parser = argparse.ArgumentParser(description="Kokoro TTS CLI")
    # Required: without it args.voices is None and an empty body is posted.
    parser.add_argument(
        "--voices", nargs="+", type=str, required=True, help="Voices to combine"
    )
    parser.add_argument("--url", default="http://localhost:8880", help="API base URL")
    args = parser.parse_args()

    available_voices = submit_combine_voices(args.voices, args.url)
    if available_voices:
        for voice in available_voices:
            print(f"  {voice}")


if __name__ == "__main__":
    main()