mirror of
https://github.com/remsky/Kokoro-FastAPI.git
synced 2025-04-13 09:39:17 +00:00
Mostly completed work on refractoring a bunch of code as well as streaming word level time stamps
This commit is contained in:
parent
0b5ec320c7
commit
34acb17682
7 changed files with 87 additions and 84 deletions
62
README.md
62
README.md
|
@ -367,9 +367,10 @@ The model is capable of processing up to a 510 phonemized token chunk at a time,
|
|||
<details>
|
||||
<summary>Timestamped Captions & Phonemes</summary>
|
||||
|
||||
Generate audio with word-level timestamps:
|
||||
Generate audio with word-level timestamps without streaming:
|
||||
```python
|
||||
import requests
|
||||
import base64
|
||||
import json
|
||||
|
||||
response = requests.post(
|
||||
|
@ -379,19 +380,58 @@ response = requests.post(
|
|||
"input": "Hello world!",
|
||||
"voice": "af_bella",
|
||||
"speed": 1.0,
|
||||
"response_format": "wav"
|
||||
}
|
||||
"response_format": "mp3",
|
||||
"stream": False,
|
||||
},
|
||||
stream=False
|
||||
)
|
||||
|
||||
# Get timestamps from header
|
||||
timestamps = json.loads(response.headers['X-Word-Timestamps'])
|
||||
print("Word-level timestamps:")
|
||||
for ts in timestamps:
|
||||
print(f"{ts['word']}: {ts['start_time']:.3f}s - {ts['end_time']:.3f}s")
|
||||
with open("output.mp3","wb") as f:
|
||||
|
||||
# Save audio
|
||||
with open("output.wav", "wb") as f:
|
||||
f.write(response.content)
|
||||
audio_json=json.loads(response.content)
|
||||
|
||||
# Decode base 64 stream to bytes
|
||||
chunk_audio=base64.b64decode(audio_json["audio"].encode("utf-8"))
|
||||
|
||||
# Process streaming chunks
|
||||
f.write(chunk_audio)
|
||||
|
||||
# Print word level timestamps
|
||||
print(audio_json["timestamps"])
|
||||
```
|
||||
|
||||
Generate audio with word-level timestamps with streaming:
|
||||
```python
|
||||
import requests
|
||||
import base64
|
||||
import json
|
||||
|
||||
response = requests.post(
|
||||
"http://localhost:8880/dev/captioned_speech",
|
||||
json={
|
||||
"model": "kokoro",
|
||||
"input": "Hello world!",
|
||||
"voice": "af_bella",
|
||||
"speed": 1.0,
|
||||
"response_format": "mp3",
|
||||
"stream": True,
|
||||
},
|
||||
stream=True
|
||||
)
|
||||
|
||||
f=open("output.mp3","wb")
|
||||
for chunk in response.iter_lines(decode_unicode=True):
|
||||
if chunk:
|
||||
chunk_json=json.loads(chunk)
|
||||
|
||||
# Decode base 64 stream to bytes
|
||||
chunk_audio=base64.b64decode(chunk_json["audio"].encode("utf-8"))
|
||||
|
||||
# Process streaming chunks
|
||||
f.write(chunk_audio)
|
||||
|
||||
# Print word level timestamps
|
||||
print(chunk_json["timestamps"])
|
||||
```
|
||||
</details>
|
||||
|
||||
|
|
|
@ -23,7 +23,7 @@ class AudioChunk:
|
|||
for audio_chunk in audio_chunk_list[1:]:
|
||||
output.audio=np.concatenate((output.audio,audio_chunk.audio),dtype=np.int16)
|
||||
if output.word_timestamps is not None:
|
||||
output.word_timestamps+=output.word_timestamps
|
||||
output.word_timestamps+=audio_chunk.word_timestamps
|
||||
|
||||
return output
|
||||
|
||||
|
|
|
@ -1,9 +1,10 @@
|
|||
import re
|
||||
from typing import List, Union, AsyncGenerator, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from fastapi import APIRouter, Depends, Header, HTTPException, Request, Response
|
||||
from fastapi.responses import StreamingResponse, FileResponse
|
||||
from fastapi.responses import JSONResponse, StreamingResponse, FileResponse
|
||||
from kokoro import KPipeline
|
||||
from loguru import logger
|
||||
|
||||
|
@ -156,40 +157,6 @@ async def generate_from_phonemes(
|
|||
},
|
||||
)
|
||||
|
||||
|
||||
@router.get("/dev/timestamps/{filename}")
|
||||
async def get_timestamps(filename: str):
|
||||
"""Download timestamps from temp storage"""
|
||||
try:
|
||||
from ..core.paths import _find_file
|
||||
|
||||
# Search for file in temp directory
|
||||
file_path = await _find_file(
|
||||
filename=filename, search_paths=[settings.temp_file_dir]
|
||||
)
|
||||
|
||||
return FileResponse(
|
||||
file_path,
|
||||
media_type="application/json",
|
||||
filename=filename,
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Content-Disposition": f"attachment; filename={filename}",
|
||||
},
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error serving timestamps file {filename}: {e}")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail={
|
||||
"error": "server_error",
|
||||
"message": "Failed to serve timestamps file",
|
||||
"type": "server_error",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@router.post("/dev/captioned_speech")
|
||||
async def create_captioned_speech(
|
||||
request: CaptionedSpeechRequest,
|
||||
|
@ -245,8 +212,9 @@ async def create_captioned_speech(
|
|||
async for chunk,chunk_data in generator:
|
||||
if chunk: # Skip empty chunks
|
||||
await temp_writer.write(chunk)
|
||||
base64_chunk= base64.b64encode(chunk).decode("utf-8")
|
||||
|
||||
yield chunk
|
||||
yield CaptionedSpeechResponse(audio=base64_chunk,audio_format=content_type,timestamps=chunk_data.word_timestamps)
|
||||
|
||||
# Finalize the temp file
|
||||
await temp_writer.finalize()
|
||||
|
@ -272,13 +240,11 @@ async def create_captioned_speech(
|
|||
# Encode the chunk bytes into base 64
|
||||
base64_chunk= base64.b64encode(chunk).decode("utf-8")
|
||||
|
||||
yield CaptionedSpeechResponse(audio=base64_chunk,audio_format=content_type,words=chunk_data.word_timestamps)
|
||||
yield CaptionedSpeechResponse(audio=base64_chunk,audio_format=content_type,timestamps=chunk_data.word_timestamps)
|
||||
except Exception as e:
|
||||
logger.error(f"Error in single output streaming: {e}")
|
||||
raise
|
||||
|
||||
# NEED TO DO REPLACE THE RETURN WITH A JSON OBJECT CONTAINING BOTH THE FILE AND THE WORD TIMESTAMPS
|
||||
|
||||
# Standard streaming without download link
|
||||
return JSONStreamingResponse(
|
||||
single_output(),
|
||||
|
@ -296,6 +262,8 @@ async def create_captioned_speech(
|
|||
text=request.input,
|
||||
voice=voice_name,
|
||||
speed=request.speed,
|
||||
return_timestamps=request.return_timestamps,
|
||||
normalization_options=request.normalization_options,
|
||||
lang_code=request.lang_code,
|
||||
)
|
||||
|
||||
|
@ -316,9 +284,13 @@ async def create_captioned_speech(
|
|||
is_last_chunk=True,
|
||||
)
|
||||
output=content+final
|
||||
return Response(
|
||||
content=output,
|
||||
media_type=content_type,
|
||||
|
||||
base64_output= base64.b64encode(output).decode("utf-8")
|
||||
|
||||
content=CaptionedSpeechResponse(audio=base64_output,audio_format=content_type,timestamps=audio_data.word_timestamps).model_dump()
|
||||
return JSONResponse(
|
||||
content=content,
|
||||
media_type="application/json",
|
||||
headers={
|
||||
"Content-Disposition": f"attachment; filename=speech.{request.response_format}",
|
||||
"Cache-Control": "no-cache", # Prevent caching
|
||||
|
|
|
@ -282,8 +282,10 @@ async def create_speech(
|
|||
text=request.input,
|
||||
voice=voice_name,
|
||||
speed=request.speed,
|
||||
normalization_options=request.normalization_options,
|
||||
lang_code=request.lang_code,
|
||||
)
|
||||
|
||||
content, audio_data = await AudioService.convert_audio(
|
||||
audio_data,
|
||||
24000,
|
||||
|
|
|
@ -333,20 +333,17 @@ class TTSService:
|
|||
voice: str,
|
||||
speed: float = 1.0,
|
||||
return_timestamps: bool = False,
|
||||
normalization_options: Optional[NormalizationOptions] = NormalizationOptions(),
|
||||
lang_code: Optional[str] = None,
|
||||
) -> Tuple[Tuple[np.ndarray,AudioChunk]]:
|
||||
"""Generate complete audio for text using streaming internally."""
|
||||
start_time = time.time()
|
||||
audio_data_chunks=[]
|
||||
|
||||
try:
|
||||
async for _,audio_stream_data in self.generate_audio_stream(text,voice,speed=speed,return_timestamps=return_timestamps,lang_code=lang_code,output_format=None):
|
||||
async for _,audio_stream_data in self.generate_audio_stream(text,voice,speed=speed,normalization_options=normalization_options,return_timestamps=return_timestamps,lang_code=lang_code,output_format=None):
|
||||
|
||||
audio_data_chunks.append(audio_stream_data)
|
||||
|
||||
|
||||
|
||||
|
||||
combined_audio_data=AudioChunk.combine(audio_data_chunks)
|
||||
return combined_audio_data.audio,combined_audio_data
|
||||
"""
|
||||
|
|
|
@ -35,7 +35,7 @@ class CaptionedSpeechResponse(BaseModel):
|
|||
|
||||
audio: str = Field(..., description="The generated audio data encoded in base 64")
|
||||
audio_format: str = Field(..., description="The format of the output audio")
|
||||
words: List[WordTimestamp] = Field(..., description="Word-level timestamps")
|
||||
timestamps: Optional[List[WordTimestamp]] = Field(..., description="Word-level timestamps")
|
||||
|
||||
class NormalizationOptions(BaseModel):
|
||||
"""Options for the normalization system"""
|
||||
|
|
|
@ -2,6 +2,7 @@ import json
|
|||
from typing import Tuple, Optional, Dict, List
|
||||
from pathlib import Path
|
||||
|
||||
import base64
|
||||
import requests
|
||||
|
||||
# Get the directory this script is in
|
||||
|
@ -9,9 +10,9 @@ SCRIPT_DIR = Path(__file__).absolute().parent
|
|||
|
||||
def generate_captioned_speech(
|
||||
text: str,
|
||||
voice: str = "af_bella",
|
||||
voice: str = "af_heart",
|
||||
speed: float = 1.0,
|
||||
response_format: str = "wav"
|
||||
response_format: str = "mp3"
|
||||
) -> Tuple[Optional[bytes], Optional[List[Dict]]]:
|
||||
"""Generate audio with word-level timestamps."""
|
||||
response = requests.post(
|
||||
|
@ -21,40 +22,31 @@ def generate_captioned_speech(
|
|||
"input": text,
|
||||
"voice": voice,
|
||||
"speed": speed,
|
||||
"response_format": response_format
|
||||
"response_format": response_format,
|
||||
"stream": False
|
||||
}
|
||||
)
|
||||
|
||||
print(f"Response status: {response.status_code}")
|
||||
print(f"Response headers: {dict(response.headers)}")
|
||||
|
||||
if response.status_code != 200:
|
||||
print(f"Error response: {response.text}")
|
||||
return None, None
|
||||
|
||||
try:
|
||||
# Get timestamps path from header
|
||||
timestamps_filename = response.headers.get('X-Timestamps-Path')
|
||||
if not timestamps_filename:
|
||||
print("Error: No timestamps path in response headers")
|
||||
return None, None
|
||||
audio_json=json.loads(response.content)
|
||||
|
||||
# Get timestamps from the path
|
||||
timestamps_response = requests.get(f"http://localhost:8880/dev/timestamps/{timestamps_filename}")
|
||||
if timestamps_response.status_code != 200:
|
||||
print(f"Error getting timestamps: {timestamps_response.text}")
|
||||
return None, None
|
||||
# Decode base 64 stream to bytes
|
||||
chunk_audio=base64.b64decode(audio_json["audio"].encode("utf-8"))
|
||||
|
||||
word_timestamps = timestamps_response.json()
|
||||
# Print word level timestamps
|
||||
print(audio_json["timestamps"])
|
||||
|
||||
# Get audio bytes from content
|
||||
audio_bytes = response.content
|
||||
|
||||
if not audio_bytes:
|
||||
if not chunk_audio:
|
||||
print("Error: Empty audio content")
|
||||
return None, None
|
||||
|
||||
return audio_bytes, word_timestamps
|
||||
return chunk_audio, audio_json["timestamps"]
|
||||
except json.JSONDecodeError as e:
|
||||
print(f"Error parsing timestamps: {e}")
|
||||
return None, None
|
||||
|
|
Loading…
Add table
Reference in a new issue