Merge pull request #2 from eschmidbauer/master

Looks great
This commit is contained in:
remsky 2024-12-31 19:02:21 -07:00 committed by GitHub
commit decf9123e7
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 78 additions and 11 deletions

View file

@ -1,8 +1,10 @@
from loguru import logger from typing import List
from fastapi import Depends, Response, APIRouter, HTTPException
from fastapi import APIRouter, Depends, HTTPException, Response
from loguru import logger
from ..services.tts import TTSService
from ..services.audio import AudioService from ..services.audio import AudioService
from ..services.tts import TTSService
from ..structures.schemas import OpenAISpeechRequest from ..structures.schemas import OpenAISpeechRequest
router = APIRouter( router = APIRouter(
@ -57,3 +59,14 @@ async def list_voices(tts_service: TTSService = Depends(get_tts_service)):
except Exception as e: except Exception as e:
logger.error(f"Error listing voices: {str(e)}") logger.error(f"Error listing voices: {str(e)}")
raise HTTPException(status_code=500, detail=str(e)) raise HTTPException(status_code=500, detail=str(e))
@router.post("/audio/voices/combine")
async def combine_voices(
    request: List[str], tts_service: TTSService = Depends(get_tts_service)
):
    """Combine the requested voices and return the refreshed voice listing.

    Args:
        request: Voice names to combine.
        tts_service: Service resolved through the ``get_tts_service`` dependency.

    Returns:
        A dict with ``"voices"`` (the result of ``tts_service.list_voices()``)
        and ``"voice"`` (the value returned by ``tts_service.combine_voices``).

    Raises:
        HTTPException: Status 500 when combining or listing raises.
    """
    try:
        combined = tts_service.combine_voices(voices=request)
        voices = tts_service.list_voices()
        return {"voices": voices, "voice": combined}
    except Exception as e:
        # Log under the operation that actually failed; the previous message
        # was copied from the listing endpoint and mislabelled these errors.
        logger.error(f"Error combining voices: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))

View file

@ -1,16 +1,17 @@
import io import io
import os import os
import re import re
import time
import threading import threading
import time
from typing import List, Tuple from typing import List, Tuple
import numpy as np import numpy as np
import torch
import tiktoken
import scipy.io.wavfile as wavfile import scipy.io.wavfile as wavfile
from kokoro import generate, tokenize, phonemize, normalize_text import tiktoken
import torch
from loguru import logger from loguru import logger
from kokoro import generate, normalize_text, phonemize, tokenize
from models import build_model from models import build_model
from ..core.config import settings from ..core.config import settings
@ -93,7 +94,7 @@ class TTSService:
ps = phonemize(chunk, voice[0]) ps = phonemize(chunk, voice[0])
tokens = tokenize(ps) tokens = tokenize(ps)
logger.debug( logger.debug(
f"Processing chunk {i+1}/{len(chunks)}: {len(tokens)} tokens" f"Processing chunk {i + 1}/{len(chunks)}: {len(tokens)} tokens"
) )
# Only proceed if phonemization succeeded # Only proceed if phonemization succeeded
@ -104,11 +105,11 @@ class TTSService:
audio_chunks.append(chunk_audio) audio_chunks.append(chunk_audio)
else: else:
logger.error( logger.error(
f"No audio generated for chunk {i+1}/{len(chunks)}" f"No audio generated for chunk {i + 1}/{len(chunks)}"
) )
except Exception as e: except Exception as e:
logger.error( logger.error(
f"Failed to generate audio for chunk {i+1}/{len(chunks)}: '{chunk}'. Error: {str(e)}" f"Failed to generate audio for chunk {i + 1}/{len(chunks)}: '{chunk}'. Error: {str(e)}"
) )
continue continue
@ -141,7 +142,27 @@ class TTSService:
wavfile.write(buffer, 24000, audio) wavfile.write(buffer, 24000, audio)
return buffer.getvalue() return buffer.getvalue()
def combine_voices(self, voices: List[str]) -> str:
    """Average several stored voice tensors into one new voice file.

    Requested names are matched against the ``.pt`` files in the ``voices``
    directory. The matching tensors are loaded, averaged element-wise and
    written to ``voices/<name>_<name>....pt``, with the names joined in the
    order they were requested.

    Args:
        voices: Voice names (file names without the ``.pt`` suffix). Names
            without a matching file are skipped; repeated names are loaded
            once per occurrence.

    Returns:
        The name of the combined voice, or the fallback ``"af"`` when fewer
        than two names are given, when listing/loading fails, or when none
        of the requested names match a stored voice.

    Raises:
        Exception: Errors from stacking the tensors or saving the result
            are not caught here and propagate to the caller.
    """
    if len(voices) < 2:
        return "af"

    try:
        # Only real ``.pt`` files count as voices; slicing three characters
        # off every entry would treat e.g. ``af.py`` as the voice ``af``.
        available = {
            entry[: -len(".pt")]
            for entry in os.listdir("voices")
            if entry.endswith(".pt")
        }
        # Follow the request order so the combined name is deterministic
        # instead of depending on directory listing order.
        selected = [name for name in voices if name in available]
        tensors = [
            torch.load(f"voices/{name}.pt", weights_only=True) for name in selected
        ]
    except Exception as e:
        print(f"Error combining voices: {str(e)}")
        return "af"

    if not tensors:
        # torch.stack() raises on an empty list; use the same fallback as
        # the other failure paths instead.
        print(f"Error combining voices: none of {voices} were found")
        return "af"

    combined_name = "_".join(selected)
    combined = torch.mean(torch.stack(tensors), dim=0)
    torch.save(combined, f"voices/{combined_name}.pt")
    return combined_name
def list_voices(self) -> List[str]:
"""List all available voices""" """List all available voices"""
voices = [] voices = []
try: try:

View file

@ -0,0 +1,33 @@
#!/usr/bin/env python3
import argparse
from typing import List, Optional
import requests
def submit_combine_voices(
    voices: List[str],
    base_url: str = "http://localhost:8880",
    timeout: float = 30.0,
) -> Optional[List[str]]:
    """Ask the API to combine ``voices`` and return the resulting voice list.

    Args:
        voices: Voice names to combine; sent as the JSON request body.
        base_url: Base URL of the API server.
        timeout: Seconds to wait for the server before giving up. Without a
            timeout ``requests`` waits indefinitely on an unresponsive server.

    Returns:
        The ``"voices"`` entry of the JSON response, or ``None`` when the
        request fails, the status is not 200, or the response body is not
        the expected JSON object. Failures are reported on stdout.
    """
    try:
        response = requests.post(
            f"{base_url}/v1/audio/voices/combine", json=voices, timeout=timeout
        )
    except requests.exceptions.RequestException as e:
        print(f"Error: {e}")
        return None

    if response.status_code != 200:
        print(f"Error submitting request: {response.text}")
        return None

    try:
        return response.json()["voices"]
    except (ValueError, KeyError, TypeError) as e:
        # A 200 with a non-JSON body or without a "voices" field should be
        # reported like any other failure, not crash the CLI.
        print(f"Error: unexpected response from server: {e}")
        return None
def main():
    """Parse the command line, submit the combine request and print the voices.

    ``--voices`` is required: without it ``args.voices`` would be ``None``
    and the request body would be JSON ``null``.
    """
    parser = argparse.ArgumentParser(description="Kokoro TTS CLI")
    parser.add_argument(
        "--voices", nargs="+", type=str, required=True, help="Voices to combine"
    )
    parser.add_argument("--url", default="http://localhost:8880", help="API base URL")
    args = parser.parse_args()

    available_voices = submit_combine_voices(args.voices, args.url)
    if available_voices:
        for voice in available_voices:
            print(f" {voice}")
# Script entry point: run the CLI only when executed directly, not on import.
if __name__ == "__main__":
    main()