mirror of
https://github.com/remsky/Kokoro-FastAPI.git
synced 2025-04-13 09:39:17 +00:00
Update handling in generate_captioned_speech to stream immediately, templink for caption file, and add unit tests for captioned speech generation
This commit is contained in:
parent
a91e0fe9df
commit
d73ed87987
5 changed files with 202 additions and 59 deletions
2
VERSION
2
VERSION
|
@ -1 +1 @@
|
||||||
v0.2.0
|
v0.2.1
|
||||||
|
|
|
@ -1,22 +1,28 @@
|
||||||
from typing import List
|
from typing import List, Union, AsyncGenerator, Tuple
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from fastapi import APIRouter, Depends, HTTPException, Request, Response
|
from fastapi import APIRouter, Depends, Header, HTTPException, Request, Response
|
||||||
from fastapi.responses import StreamingResponse
|
from fastapi.responses import StreamingResponse, FileResponse
|
||||||
from kokoro import KPipeline
|
from kokoro import KPipeline
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
|
from ..core.config import settings
|
||||||
from ..services.audio import AudioNormalizer, AudioService
|
from ..services.audio import AudioNormalizer, AudioService
|
||||||
from ..services.streaming_audio_writer import StreamingAudioWriter
|
from ..services.streaming_audio_writer import StreamingAudioWriter
|
||||||
from ..services.text_processing import smart_split
|
from ..services.text_processing import smart_split
|
||||||
from ..services.tts_service import TTSService
|
from ..services.tts_service import TTSService
|
||||||
|
from ..services.temp_manager import TempFileWriter
|
||||||
from ..structures import CaptionedSpeechRequest, CaptionedSpeechResponse, WordTimestamp
|
from ..structures import CaptionedSpeechRequest, CaptionedSpeechResponse, WordTimestamp
|
||||||
from ..structures.text_schemas import (
|
from ..structures.text_schemas import (
|
||||||
GenerateFromPhonemesRequest,
|
GenerateFromPhonemesRequest,
|
||||||
PhonemeRequest,
|
PhonemeRequest,
|
||||||
PhonemeResponse,
|
PhonemeResponse,
|
||||||
)
|
)
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
|
||||||
router = APIRouter(tags=["text processing"])
|
router = APIRouter(tags=["text processing"])
|
||||||
|
|
||||||
|
@ -147,86 +153,169 @@ async def generate_from_phonemes(
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/dev/timestamps/{filename}")
|
||||||
|
async def get_timestamps(filename: str):
|
||||||
|
"""Download timestamps from temp storage"""
|
||||||
|
try:
|
||||||
|
from ..core.paths import _find_file
|
||||||
|
|
||||||
|
# Search for file in temp directory
|
||||||
|
file_path = await _find_file(
|
||||||
|
filename=filename, search_paths=[settings.temp_file_dir]
|
||||||
|
)
|
||||||
|
|
||||||
|
return FileResponse(
|
||||||
|
file_path,
|
||||||
|
media_type="application/json",
|
||||||
|
filename=filename,
|
||||||
|
headers={
|
||||||
|
"Cache-Control": "no-cache",
|
||||||
|
"Content-Disposition": f"attachment; filename={filename}",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error serving timestamps file {filename}: {e}")
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=500,
|
||||||
|
detail={
|
||||||
|
"error": "server_error",
|
||||||
|
"message": "Failed to serve timestamps file",
|
||||||
|
"type": "server_error",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@router.post("/dev/captioned_speech")
|
@router.post("/dev/captioned_speech")
|
||||||
async def create_captioned_speech(
|
async def create_captioned_speech(
|
||||||
request: CaptionedSpeechRequest,
|
request: CaptionedSpeechRequest,
|
||||||
|
client_request: Request,
|
||||||
|
x_raw_response: str = Header(None, alias="x-raw-response"),
|
||||||
tts_service: TTSService = Depends(get_tts_service),
|
tts_service: TTSService = Depends(get_tts_service),
|
||||||
) -> StreamingResponse:
|
):
|
||||||
"""Generate audio with word-level timestamps using Kokoro's output"""
|
"""Generate audio with word-level timestamps using streaming approach"""
|
||||||
try:
|
try:
|
||||||
# Get voice path
|
# Set content type based on format
|
||||||
voice_name, voice_path = await tts_service._get_voice_path(request.voice)
|
content_type = {
|
||||||
|
"mp3": "audio/mpeg",
|
||||||
|
"opus": "audio/opus",
|
||||||
|
"aac": "audio/aac",
|
||||||
|
"flac": "audio/flac",
|
||||||
|
"wav": "audio/wav",
|
||||||
|
"pcm": "audio/pcm",
|
||||||
|
}.get(request.response_format, f"audio/{request.response_format}")
|
||||||
|
|
||||||
# Generate audio with timestamps
|
# Create streaming audio writer and normalizer
|
||||||
audio, _, word_timestamps = await tts_service.generate_audio(
|
|
||||||
text=request.input,
|
|
||||||
voice=voice_name,
|
|
||||||
speed=request.speed,
|
|
||||||
return_timestamps=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Create streaming audio writer
|
|
||||||
writer = StreamingAudioWriter(
|
writer = StreamingAudioWriter(
|
||||||
format=request.response_format, sample_rate=24000, channels=1
|
format=request.response_format, sample_rate=24000, channels=1
|
||||||
)
|
)
|
||||||
normalizer = AudioNormalizer()
|
normalizer = AudioNormalizer()
|
||||||
|
|
||||||
|
# Get voice path
|
||||||
|
voice_name, voice_path = await tts_service._get_voice_path(request.voice)
|
||||||
|
|
||||||
|
# Use provided lang_code or determine from voice name
|
||||||
|
pipeline_lang_code = request.lang_code if request.lang_code else request.voice[0].lower()
|
||||||
|
logger.info(
|
||||||
|
f"Using lang_code '{pipeline_lang_code}' for voice '{voice_name}' in text chunking"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get backend and pipeline
|
||||||
|
backend = tts_service.model_manager.get_backend()
|
||||||
|
pipeline = backend._get_pipeline(pipeline_lang_code)
|
||||||
|
|
||||||
|
# Create temp file writer for timestamps
|
||||||
|
temp_writer = TempFileWriter("json")
|
||||||
|
await temp_writer.__aenter__() # Initialize temp file
|
||||||
|
# Get just the filename without the path
|
||||||
|
timestamps_filename = Path(temp_writer.download_path).name
|
||||||
|
|
||||||
|
# Initialize variables for timestamps
|
||||||
|
word_timestamps = []
|
||||||
|
current_offset = 0.0
|
||||||
|
|
||||||
async def generate_chunks():
|
async def generate_chunks():
|
||||||
|
nonlocal current_offset, word_timestamps
|
||||||
try:
|
try:
|
||||||
if audio is not None:
|
# Process text in chunks with smart splitting
|
||||||
# Normalize audio before writing
|
async for chunk_text, tokens in smart_split(request.input):
|
||||||
normalized_audio = await normalizer.normalize(audio)
|
# Process chunk with pipeline
|
||||||
# Write chunk and yield bytes
|
for result in pipeline(chunk_text, voice=voice_path, speed=request.speed):
|
||||||
|
if result.audio is not None:
|
||||||
|
# Process timestamps for this chunk
|
||||||
|
if hasattr(result, "tokens") and result.tokens and result.pred_dur is not None:
|
||||||
|
try:
|
||||||
|
# Join timestamps for this chunk's tokens
|
||||||
|
KPipeline.join_timestamps(result.tokens, result.pred_dur)
|
||||||
|
|
||||||
|
# Add timestamps with offset
|
||||||
|
for token in result.tokens:
|
||||||
|
if not all(
|
||||||
|
hasattr(token, attr)
|
||||||
|
for attr in ["text", "start_ts", "end_ts"]
|
||||||
|
):
|
||||||
|
continue
|
||||||
|
if not token.text or not token.text.strip():
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Apply offset to timestamps
|
||||||
|
start_time = float(token.start_ts) + current_offset
|
||||||
|
end_time = float(token.end_ts) + current_offset
|
||||||
|
|
||||||
|
word_timestamps.append(
|
||||||
|
{
|
||||||
|
"word": str(token.text).strip(),
|
||||||
|
"start_time": start_time,
|
||||||
|
"end_time": end_time,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Update offset for next chunk
|
||||||
|
chunk_duration = float(result.pred_dur.sum()) / 80 # Convert frames to seconds
|
||||||
|
current_offset = max(current_offset + chunk_duration, end_time)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to process timestamps for chunk: {e}")
|
||||||
|
|
||||||
|
# Process audio
|
||||||
|
audio_chunk = result.audio.numpy()
|
||||||
|
normalized_audio = await normalizer.normalize(audio_chunk)
|
||||||
chunk_bytes = writer.write_chunk(normalized_audio)
|
chunk_bytes = writer.write_chunk(normalized_audio)
|
||||||
if chunk_bytes:
|
if chunk_bytes:
|
||||||
yield chunk_bytes
|
yield chunk_bytes
|
||||||
|
|
||||||
# Finalize and yield remaining bytes
|
# Write timestamps to temp file
|
||||||
|
timestamps_json = json.dumps(word_timestamps)
|
||||||
|
await temp_writer.write(timestamps_json.encode())
|
||||||
|
await temp_writer.finalize()
|
||||||
|
|
||||||
|
# Finalize audio
|
||||||
final_bytes = writer.write_chunk(finalize=True)
|
final_bytes = writer.write_chunk(finalize=True)
|
||||||
if final_bytes:
|
if final_bytes:
|
||||||
yield final_bytes
|
yield final_bytes
|
||||||
else:
|
|
||||||
raise ValueError("Failed to generate audio data")
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error in audio generation: {str(e)}")
|
logger.error(f"Error in audio generation: {str(e)}")
|
||||||
# Clean up writer on error
|
# Clean up writer on error
|
||||||
writer.write_chunk(finalize=True)
|
writer.write_chunk(finalize=True)
|
||||||
|
await temp_writer.__aexit__(type(e), e, e.__traceback__)
|
||||||
# Re-raise the original exception
|
# Re-raise the original exception
|
||||||
raise
|
raise
|
||||||
|
|
||||||
# Convert timestamps to JSON and add as header
|
|
||||||
import json
|
|
||||||
|
|
||||||
logger.debug(f"Processing {len(word_timestamps)} word timestamps")
|
|
||||||
timestamps_json = json.dumps(
|
|
||||||
[
|
|
||||||
{
|
|
||||||
"word": str(ts["word"]), # Ensure string for text
|
|
||||||
"start_time": float(
|
|
||||||
ts["start_time"]
|
|
||||||
), # Ensure float for timestamps
|
|
||||||
"end_time": float(ts["end_time"]),
|
|
||||||
}
|
|
||||||
for ts in word_timestamps
|
|
||||||
]
|
|
||||||
)
|
|
||||||
logger.debug(f"Generated timestamps JSON: {timestamps_json}")
|
|
||||||
|
|
||||||
return StreamingResponse(
|
return StreamingResponse(
|
||||||
generate_chunks(),
|
generate_chunks(),
|
||||||
media_type=f"audio/{request.response_format}",
|
media_type=content_type,
|
||||||
headers={
|
headers={
|
||||||
"Content-Disposition": f"attachment; filename=speech.{request.response_format}",
|
"Content-Disposition": f"attachment; filename=speech.{request.response_format}",
|
||||||
"X-Accel-Buffering": "no",
|
"X-Accel-Buffering": "no",
|
||||||
"Cache-Control": "no-cache",
|
"Cache-Control": "no-cache",
|
||||||
"Transfer-Encoding": "chunked",
|
"Transfer-Encoding": "chunked",
|
||||||
"X-Word-Timestamps": timestamps_json,
|
"X-Timestamps-Path": timestamps_filename,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
logger.error(f"Error in captioned speech generation: {str(e)}")
|
logger.warning(f"Invalid request: {str(e)}")
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=400,
|
status_code=400,
|
||||||
detail={
|
detail={
|
||||||
|
@ -235,8 +324,18 @@ async def create_captioned_speech(
|
||||||
"type": "invalid_request_error",
|
"type": "invalid_request_error",
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except RuntimeError as e:
|
||||||
logger.error(f"Error in captioned speech generation: {str(e)}")
|
logger.error(f"Processing error: {str(e)}")
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=500,
|
||||||
|
detail={
|
||||||
|
"error": "processing_error",
|
||||||
|
"message": str(e),
|
||||||
|
"type": "server_error",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Unexpected error in speech generation: {str(e)}")
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=500,
|
status_code=500,
|
||||||
detail={
|
detail={
|
||||||
|
|
|
@ -137,7 +137,7 @@ async def stream_audio_chunks(
|
||||||
voice=voice_name,
|
voice=voice_name,
|
||||||
speed=request.speed,
|
speed=request.speed,
|
||||||
output_format=request.response_format,
|
output_format=request.response_format,
|
||||||
lang_code=request.lang_code,
|
lang_code=request.lang_code or request.voice[0],
|
||||||
):
|
):
|
||||||
# Check if client is still connected
|
# Check if client is still connected
|
||||||
is_disconnected = client_request.is_disconnected
|
is_disconnected = client_request.is_disconnected
|
||||||
|
|
31
api/tests/test_development.py
Normal file
31
api/tests/test_development.py
Normal file
|
@ -0,0 +1,31 @@
|
||||||
|
import pytest
|
||||||
|
from unittest.mock import patch, MagicMock
|
||||||
|
import requests
|
||||||
|
|
||||||
|
def test_generate_captioned_speech():
|
||||||
|
"""Test the generate_captioned_speech function with mocked responses"""
|
||||||
|
# Mock the API responses
|
||||||
|
mock_audio_response = MagicMock()
|
||||||
|
mock_audio_response.status_code = 200
|
||||||
|
mock_audio_response.content = b"mock audio data"
|
||||||
|
mock_audio_response.headers = {"X-Timestamps-Path": "test.json"}
|
||||||
|
|
||||||
|
mock_timestamps_response = MagicMock()
|
||||||
|
mock_timestamps_response.status_code = 200
|
||||||
|
mock_timestamps_response.json.return_value = [
|
||||||
|
{"word": "test", "start_time": 0.0, "end_time": 1.0}
|
||||||
|
]
|
||||||
|
|
||||||
|
# Patch both HTTP requests
|
||||||
|
with patch('requests.post', return_value=mock_audio_response), \
|
||||||
|
patch('requests.get', return_value=mock_timestamps_response):
|
||||||
|
|
||||||
|
# Import here to avoid module-level import issues
|
||||||
|
from examples.captioned_speech_example import generate_captioned_speech
|
||||||
|
|
||||||
|
# Test the function
|
||||||
|
audio, timestamps = generate_captioned_speech("test text")
|
||||||
|
|
||||||
|
# Verify we got both audio and timestamps
|
||||||
|
assert audio == b"mock audio data"
|
||||||
|
assert timestamps == [{"word": "test", "start_time": 0.0, "end_time": 1.0}]
|
|
@ -33,9 +33,19 @@ def generate_captioned_speech(
|
||||||
return None, None
|
return None, None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Get timestamps from header
|
# Get timestamps path from header
|
||||||
timestamps_json = response.headers.get('X-Word-Timestamps', '[]')
|
timestamps_filename = response.headers.get('X-Timestamps-Path')
|
||||||
word_timestamps = json.loads(timestamps_json)
|
if not timestamps_filename:
|
||||||
|
print("Error: No timestamps path in response headers")
|
||||||
|
return None, None
|
||||||
|
|
||||||
|
# Get timestamps from the path
|
||||||
|
timestamps_response = requests.get(f"http://localhost:8880/dev/timestamps/{timestamps_filename}")
|
||||||
|
if timestamps_response.status_code != 200:
|
||||||
|
print(f"Error getting timestamps: {timestamps_response.text}")
|
||||||
|
return None, None
|
||||||
|
|
||||||
|
word_timestamps = timestamps_response.json()
|
||||||
|
|
||||||
# Get audio bytes from content
|
# Get audio bytes from content
|
||||||
audio_bytes = response.content
|
audio_bytes = response.content
|
||||||
|
@ -48,6 +58,9 @@ def generate_captioned_speech(
|
||||||
except json.JSONDecodeError as e:
|
except json.JSONDecodeError as e:
|
||||||
print(f"Error parsing timestamps: {e}")
|
print(f"Error parsing timestamps: {e}")
|
||||||
return None, None
|
return None, None
|
||||||
|
except requests.RequestException as e:
|
||||||
|
print(f"Error retrieving timestamps: {e}")
|
||||||
|
return None, None
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
# Example texts to convert
|
# Example texts to convert
|
||||||
|
|
Loading…
Add table
Reference in a new issue