mirror of
https://github.com/remsky/Kokoro-FastAPI.git
synced 2025-04-13 09:39:17 +00:00
commit
decf9123e7
3 changed files with 78 additions and 11 deletions
|
@ -1,8 +1,10 @@
|
||||||
from loguru import logger
|
from typing import List
|
||||||
from fastapi import Depends, Response, APIRouter, HTTPException
|
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException, Response
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
from ..services.tts import TTSService
|
|
||||||
from ..services.audio import AudioService
|
from ..services.audio import AudioService
|
||||||
|
from ..services.tts import TTSService
|
||||||
from ..structures.schemas import OpenAISpeechRequest
|
from ..structures.schemas import OpenAISpeechRequest
|
||||||
|
|
||||||
router = APIRouter(
|
router = APIRouter(
|
||||||
|
@ -57,3 +59,14 @@ async def list_voices(tts_service: TTSService = Depends(get_tts_service)):
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error listing voices: {str(e)}")
|
logger.error(f"Error listing voices: {str(e)}")
|
||||||
raise HTTPException(status_code=500, detail=str(e))
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/audio/voices/combine")
async def combine_voices(
    request: List[str], tts_service: TTSService = Depends(get_tts_service)
):
    """Combine several voices into one and return the refreshed voice list.

    Args:
        request: JSON array of voice names to combine.
        tts_service: TTS service instance injected by FastAPI.

    Returns:
        A dict with two keys: "voices", the result of
        ``tts_service.list_voices()`` called after combining, and "voice",
        the value returned by ``tts_service.combine_voices``.

    Raises:
        HTTPException: status 500 carrying the error text when either
            service call raises.
    """
    try:
        combined = tts_service.combine_voices(voices=request)
        voices = tts_service.list_voices()
        return {"voices": voices, "voice": combined}
    except Exception as e:
        # Name this endpoint's own operation so failures are not confused
        # with failures of the voice-listing endpoint in the logs.
        logger.error(f"Error combining voices: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
|
@ -1,16 +1,17 @@
|
||||||
import io
|
import io
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
import time
|
|
||||||
import threading
|
import threading
|
||||||
|
import time
|
||||||
from typing import List, Tuple
|
from typing import List, Tuple
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
|
||||||
import tiktoken
|
|
||||||
import scipy.io.wavfile as wavfile
|
import scipy.io.wavfile as wavfile
|
||||||
from kokoro import generate, tokenize, phonemize, normalize_text
|
import tiktoken
|
||||||
|
import torch
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
|
from kokoro import generate, normalize_text, phonemize, tokenize
|
||||||
from models import build_model
|
from models import build_model
|
||||||
|
|
||||||
from ..core.config import settings
|
from ..core.config import settings
|
||||||
|
@ -93,7 +94,7 @@ class TTSService:
|
||||||
ps = phonemize(chunk, voice[0])
|
ps = phonemize(chunk, voice[0])
|
||||||
tokens = tokenize(ps)
|
tokens = tokenize(ps)
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f"Processing chunk {i+1}/{len(chunks)}: {len(tokens)} tokens"
|
f"Processing chunk {i + 1}/{len(chunks)}: {len(tokens)} tokens"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Only proceed if phonemization succeeded
|
# Only proceed if phonemization succeeded
|
||||||
|
@ -104,11 +105,11 @@ class TTSService:
|
||||||
audio_chunks.append(chunk_audio)
|
audio_chunks.append(chunk_audio)
|
||||||
else:
|
else:
|
||||||
logger.error(
|
logger.error(
|
||||||
f"No audio generated for chunk {i+1}/{len(chunks)}"
|
f"No audio generated for chunk {i + 1}/{len(chunks)}"
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(
|
logger.error(
|
||||||
f"Failed to generate audio for chunk {i+1}/{len(chunks)}: '{chunk}'. Error: {str(e)}"
|
f"Failed to generate audio for chunk {i + 1}/{len(chunks)}: '{chunk}'. Error: {str(e)}"
|
||||||
)
|
)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
@ -141,7 +142,27 @@ class TTSService:
|
||||||
wavfile.write(buffer, 24000, audio)
|
wavfile.write(buffer, 24000, audio)
|
||||||
return buffer.getvalue()
|
return buffer.getvalue()
|
||||||
|
|
||||||
def list_voices(self) -> list[str]:
|
def combine_voices(self, voices: List[str]) -> str:
    """Average several stored voice tensors into a new combined voice.

    Every requested name that has a matching ``voices/<name>.pt`` file is
    loaded; the tensors are stacked, averaged over dim 0 and saved as
    ``voices/<name1>_<name2>_....pt``. Names are joined in the order they
    were requested so the resulting file name does not depend on the
    directory listing order.

    Args:
        voices: Voice names to combine (without the ``.pt`` extension).

    Returns:
        The name of the combined voice, or the fallback "af" when fewer
        than two names are given, when reading the voices fails, or when
        none of the requested voices exists.
    """
    if len(voices) < 2:
        return "af"

    t_voices: List[torch.Tensor] = []
    v_name: List[str] = []
    try:
        # Only real ``.pt`` files count as voices; slicing blindly would turn
        # unrelated directory entries into bogus voice names.
        available = {
            entry[:-3] for entry in os.listdir("voices") if entry.endswith(".pt")
        }
        for name in voices:
            # Membership in the directory listing also guarantees that the
            # requested name cannot point outside the voices directory.
            if name in available:
                v_name.append(name)
                t_voices.append(
                    torch.load(f"voices/{name}.pt", weights_only=True)
                )
    except Exception as e:
        logger.error(f"Error combining voices: {str(e)}")
        return "af"

    if not t_voices:
        # torch.stack() raises on an empty list; use the same fallback as
        # the other failure paths instead.
        logger.error(f"Error combining voices: none of {voices} found")
        return "af"

    combined_name: str = "_".join(v_name)
    mean_voice = torch.mean(torch.stack(t_voices), dim=0)
    torch.save(mean_voice, f"voices/{combined_name}.pt")
    return combined_name
|
||||||
|
|
||||||
|
def list_voices(self) -> List[str]:
|
||||||
"""List all available voices"""
|
"""List all available voices"""
|
||||||
voices = []
|
voices = []
|
||||||
try:
|
try:
|
||||||
|
|
33
examples/test_combine_voices.py
Normal file
33
examples/test_combine_voices.py
Normal file
|
@ -0,0 +1,33 @@
|
||||||
|
#!/usr/bin/env python3
|
||||||
|
import argparse
|
||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
|
import requests
|
||||||
|
|
||||||
|
|
||||||
|
def submit_combine_voices(
    voices: List[str],
    base_url: str = "http://localhost:8880",
    timeout: float = 60.0,
) -> Optional[List[str]]:
    """Ask the API to combine voices and return the resulting voice list.

    Args:
        voices: Names of the voices to combine; sent as the JSON body.
        base_url: Base URL of the API server.
        timeout: Seconds to wait for the server before giving up. Without a
            timeout ``requests`` may block indefinitely.

    Returns:
        The "voices" entry of the JSON response, or None when the server
        answers with a non-200 status or the request itself fails (the
        error is printed in both cases).
    """
    try:
        response = requests.post(
            f"{base_url}/v1/audio/voices/combine",
            json=voices,
            timeout=timeout,
        )
        if response.status_code != 200:
            print(f"Error submitting request: {response.text}")
            return None
        return response.json()["voices"]
    except requests.exceptions.RequestException as e:
        # Timeouts and connection errors are RequestException subclasses.
        print(f"Error: {e}")
        return None
|
||||||
|
|
||||||
|
|
||||||
|
def main():
    """Parse command-line arguments, request a voice combination, print voices.

    ``--voices`` is mandatory: without it ``args.voices`` would be None and
    the request would be sent without a usable body.
    """
    parser = argparse.ArgumentParser(description="Kokoro TTS CLI")
    parser.add_argument(
        "--voices",
        nargs="+",
        type=str,
        required=True,
        help="Voices to combine",
    )
    parser.add_argument("--url", default="http://localhost:8880", help="API base URL")
    args = parser.parse_args()

    success = submit_combine_voices(args.voices, args.url)
    # None (request failed) and an empty list both print nothing.
    if success:
        for voice in success:
            print(f" {voice}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
Loading…
Add table
Reference in a new issue