mirror of
https://github.com/remsky/Kokoro-FastAPI.git
synced 2025-04-13 09:39:17 +00:00
Update handling in generate_captioned_speech to stream immediately, templink for caption file, and add unit tests for captioned speech generation
This commit is contained in:
parent
a91e0fe9df
commit
d73ed87987
5 changed files with 202 additions and 59 deletions
2
VERSION
2
VERSION
|
@ -1 +1 @@
|
|||
v0.2.0
|
||||
v0.2.1
|
||||
|
|
|
@ -1,22 +1,28 @@
|
|||
from typing import List
|
||||
from typing import List, Union, AsyncGenerator, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, Response
|
||||
from fastapi.responses import StreamingResponse
|
||||
from fastapi import APIRouter, Depends, Header, HTTPException, Request, Response
|
||||
from fastapi.responses import StreamingResponse, FileResponse
|
||||
from kokoro import KPipeline
|
||||
from loguru import logger
|
||||
|
||||
from ..core.config import settings
|
||||
from ..services.audio import AudioNormalizer, AudioService
|
||||
from ..services.streaming_audio_writer import StreamingAudioWriter
|
||||
from ..services.text_processing import smart_split
|
||||
from ..services.tts_service import TTSService
|
||||
from ..services.temp_manager import TempFileWriter
|
||||
from ..structures import CaptionedSpeechRequest, CaptionedSpeechResponse, WordTimestamp
|
||||
from ..structures.text_schemas import (
|
||||
GenerateFromPhonemesRequest,
|
||||
PhonemeRequest,
|
||||
PhonemeResponse,
|
||||
)
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
router = APIRouter(tags=["text processing"])
|
||||
|
||||
|
@ -147,86 +153,169 @@ async def generate_from_phonemes(
|
|||
)
|
||||
|
||||
|
||||
@router.get("/dev/timestamps/{filename}")
|
||||
async def get_timestamps(filename: str):
|
||||
"""Download timestamps from temp storage"""
|
||||
try:
|
||||
from ..core.paths import _find_file
|
||||
|
||||
# Search for file in temp directory
|
||||
file_path = await _find_file(
|
||||
filename=filename, search_paths=[settings.temp_file_dir]
|
||||
)
|
||||
|
||||
return FileResponse(
|
||||
file_path,
|
||||
media_type="application/json",
|
||||
filename=filename,
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Content-Disposition": f"attachment; filename={filename}",
|
||||
},
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error serving timestamps file {filename}: {e}")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail={
|
||||
"error": "server_error",
|
||||
"message": "Failed to serve timestamps file",
|
||||
"type": "server_error",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@router.post("/dev/captioned_speech")
|
||||
async def create_captioned_speech(
|
||||
request: CaptionedSpeechRequest,
|
||||
client_request: Request,
|
||||
x_raw_response: str = Header(None, alias="x-raw-response"),
|
||||
tts_service: TTSService = Depends(get_tts_service),
|
||||
) -> StreamingResponse:
|
||||
"""Generate audio with word-level timestamps using Kokoro's output"""
|
||||
):
|
||||
"""Generate audio with word-level timestamps using streaming approach"""
|
||||
try:
|
||||
# Get voice path
|
||||
voice_name, voice_path = await tts_service._get_voice_path(request.voice)
|
||||
# Set content type based on format
|
||||
content_type = {
|
||||
"mp3": "audio/mpeg",
|
||||
"opus": "audio/opus",
|
||||
"aac": "audio/aac",
|
||||
"flac": "audio/flac",
|
||||
"wav": "audio/wav",
|
||||
"pcm": "audio/pcm",
|
||||
}.get(request.response_format, f"audio/{request.response_format}")
|
||||
|
||||
# Generate audio with timestamps
|
||||
audio, _, word_timestamps = await tts_service.generate_audio(
|
||||
text=request.input,
|
||||
voice=voice_name,
|
||||
speed=request.speed,
|
||||
return_timestamps=True,
|
||||
)
|
||||
|
||||
# Create streaming audio writer
|
||||
# Create streaming audio writer and normalizer
|
||||
writer = StreamingAudioWriter(
|
||||
format=request.response_format, sample_rate=24000, channels=1
|
||||
)
|
||||
normalizer = AudioNormalizer()
|
||||
|
||||
# Get voice path
|
||||
voice_name, voice_path = await tts_service._get_voice_path(request.voice)
|
||||
|
||||
# Use provided lang_code or determine from voice name
|
||||
pipeline_lang_code = request.lang_code if request.lang_code else request.voice[0].lower()
|
||||
logger.info(
|
||||
f"Using lang_code '{pipeline_lang_code}' for voice '{voice_name}' in text chunking"
|
||||
)
|
||||
|
||||
# Get backend and pipeline
|
||||
backend = tts_service.model_manager.get_backend()
|
||||
pipeline = backend._get_pipeline(pipeline_lang_code)
|
||||
|
||||
# Create temp file writer for timestamps
|
||||
temp_writer = TempFileWriter("json")
|
||||
await temp_writer.__aenter__() # Initialize temp file
|
||||
# Get just the filename without the path
|
||||
timestamps_filename = Path(temp_writer.download_path).name
|
||||
|
||||
# Initialize variables for timestamps
|
||||
word_timestamps = []
|
||||
current_offset = 0.0
|
||||
|
||||
async def generate_chunks():
|
||||
nonlocal current_offset, word_timestamps
|
||||
try:
|
||||
if audio is not None:
|
||||
# Normalize audio before writing
|
||||
normalized_audio = await normalizer.normalize(audio)
|
||||
# Write chunk and yield bytes
|
||||
# Process text in chunks with smart splitting
|
||||
async for chunk_text, tokens in smart_split(request.input):
|
||||
# Process chunk with pipeline
|
||||
for result in pipeline(chunk_text, voice=voice_path, speed=request.speed):
|
||||
if result.audio is not None:
|
||||
# Process timestamps for this chunk
|
||||
if hasattr(result, "tokens") and result.tokens and result.pred_dur is not None:
|
||||
try:
|
||||
# Join timestamps for this chunk's tokens
|
||||
KPipeline.join_timestamps(result.tokens, result.pred_dur)
|
||||
|
||||
# Add timestamps with offset
|
||||
for token in result.tokens:
|
||||
if not all(
|
||||
hasattr(token, attr)
|
||||
for attr in ["text", "start_ts", "end_ts"]
|
||||
):
|
||||
continue
|
||||
if not token.text or not token.text.strip():
|
||||
continue
|
||||
|
||||
# Apply offset to timestamps
|
||||
start_time = float(token.start_ts) + current_offset
|
||||
end_time = float(token.end_ts) + current_offset
|
||||
|
||||
word_timestamps.append(
|
||||
{
|
||||
"word": str(token.text).strip(),
|
||||
"start_time": start_time,
|
||||
"end_time": end_time,
|
||||
}
|
||||
)
|
||||
|
||||
# Update offset for next chunk
|
||||
chunk_duration = float(result.pred_dur.sum()) / 80 # Convert frames to seconds
|
||||
current_offset = max(current_offset + chunk_duration, end_time)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to process timestamps for chunk: {e}")
|
||||
|
||||
# Process audio
|
||||
audio_chunk = result.audio.numpy()
|
||||
normalized_audio = await normalizer.normalize(audio_chunk)
|
||||
chunk_bytes = writer.write_chunk(normalized_audio)
|
||||
if chunk_bytes:
|
||||
yield chunk_bytes
|
||||
|
||||
# Finalize and yield remaining bytes
|
||||
# Write timestamps to temp file
|
||||
timestamps_json = json.dumps(word_timestamps)
|
||||
await temp_writer.write(timestamps_json.encode())
|
||||
await temp_writer.finalize()
|
||||
|
||||
# Finalize audio
|
||||
final_bytes = writer.write_chunk(finalize=True)
|
||||
if final_bytes:
|
||||
yield final_bytes
|
||||
else:
|
||||
raise ValueError("Failed to generate audio data")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in audio generation: {str(e)}")
|
||||
# Clean up writer on error
|
||||
writer.write_chunk(finalize=True)
|
||||
await temp_writer.__aexit__(type(e), e, e.__traceback__)
|
||||
# Re-raise the original exception
|
||||
raise
|
||||
|
||||
# Convert timestamps to JSON and add as header
|
||||
import json
|
||||
|
||||
logger.debug(f"Processing {len(word_timestamps)} word timestamps")
|
||||
timestamps_json = json.dumps(
|
||||
[
|
||||
{
|
||||
"word": str(ts["word"]), # Ensure string for text
|
||||
"start_time": float(
|
||||
ts["start_time"]
|
||||
), # Ensure float for timestamps
|
||||
"end_time": float(ts["end_time"]),
|
||||
}
|
||||
for ts in word_timestamps
|
||||
]
|
||||
)
|
||||
logger.debug(f"Generated timestamps JSON: {timestamps_json}")
|
||||
|
||||
return StreamingResponse(
|
||||
generate_chunks(),
|
||||
media_type=f"audio/{request.response_format}",
|
||||
media_type=content_type,
|
||||
headers={
|
||||
"Content-Disposition": f"attachment; filename=speech.{request.response_format}",
|
||||
"X-Accel-Buffering": "no",
|
||||
"Cache-Control": "no-cache",
|
||||
"Transfer-Encoding": "chunked",
|
||||
"X-Word-Timestamps": timestamps_json,
|
||||
"X-Timestamps-Path": timestamps_filename,
|
||||
},
|
||||
)
|
||||
|
||||
except ValueError as e:
|
||||
logger.error(f"Error in captioned speech generation: {str(e)}")
|
||||
logger.warning(f"Invalid request: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail={
|
||||
|
@ -235,8 +324,18 @@ async def create_captioned_speech(
|
|||
"type": "invalid_request_error",
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error in captioned speech generation: {str(e)}")
|
||||
except RuntimeError as e:
|
||||
logger.error(f"Processing error: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail={
|
||||
"error": "processing_error",
|
||||
"message": str(e),
|
||||
"type": "server_error",
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error in speech generation: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail={
|
||||
|
|
|
@ -137,7 +137,7 @@ async def stream_audio_chunks(
|
|||
voice=voice_name,
|
||||
speed=request.speed,
|
||||
output_format=request.response_format,
|
||||
lang_code=request.lang_code,
|
||||
lang_code=request.lang_code or request.voice[0],
|
||||
):
|
||||
# Check if client is still connected
|
||||
is_disconnected = client_request.is_disconnected
|
||||
|
|
31
api/tests/test_development.py
Normal file
31
api/tests/test_development.py
Normal file
|
@ -0,0 +1,31 @@
|
|||
import pytest
|
||||
from unittest.mock import patch, MagicMock
|
||||
import requests
|
||||
|
||||
def test_generate_captioned_speech():
|
||||
"""Test the generate_captioned_speech function with mocked responses"""
|
||||
# Mock the API responses
|
||||
mock_audio_response = MagicMock()
|
||||
mock_audio_response.status_code = 200
|
||||
mock_audio_response.content = b"mock audio data"
|
||||
mock_audio_response.headers = {"X-Timestamps-Path": "test.json"}
|
||||
|
||||
mock_timestamps_response = MagicMock()
|
||||
mock_timestamps_response.status_code = 200
|
||||
mock_timestamps_response.json.return_value = [
|
||||
{"word": "test", "start_time": 0.0, "end_time": 1.0}
|
||||
]
|
||||
|
||||
# Patch both HTTP requests
|
||||
with patch('requests.post', return_value=mock_audio_response), \
|
||||
patch('requests.get', return_value=mock_timestamps_response):
|
||||
|
||||
# Import here to avoid module-level import issues
|
||||
from examples.captioned_speech_example import generate_captioned_speech
|
||||
|
||||
# Test the function
|
||||
audio, timestamps = generate_captioned_speech("test text")
|
||||
|
||||
# Verify we got both audio and timestamps
|
||||
assert audio == b"mock audio data"
|
||||
assert timestamps == [{"word": "test", "start_time": 0.0, "end_time": 1.0}]
|
|
@ -33,9 +33,19 @@ def generate_captioned_speech(
|
|||
return None, None
|
||||
|
||||
try:
|
||||
# Get timestamps from header
|
||||
timestamps_json = response.headers.get('X-Word-Timestamps', '[]')
|
||||
word_timestamps = json.loads(timestamps_json)
|
||||
# Get timestamps path from header
|
||||
timestamps_filename = response.headers.get('X-Timestamps-Path')
|
||||
if not timestamps_filename:
|
||||
print("Error: No timestamps path in response headers")
|
||||
return None, None
|
||||
|
||||
# Get timestamps from the path
|
||||
timestamps_response = requests.get(f"http://localhost:8880/dev/timestamps/{timestamps_filename}")
|
||||
if timestamps_response.status_code != 200:
|
||||
print(f"Error getting timestamps: {timestamps_response.text}")
|
||||
return None, None
|
||||
|
||||
word_timestamps = timestamps_response.json()
|
||||
|
||||
# Get audio bytes from content
|
||||
audio_bytes = response.content
|
||||
|
@ -48,6 +58,9 @@ def generate_captioned_speech(
|
|||
except json.JSONDecodeError as e:
|
||||
print(f"Error parsing timestamps: {e}")
|
||||
return None, None
|
||||
except requests.RequestException as e:
|
||||
print(f"Error retrieving timestamps: {e}")
|
||||
return None, None
|
||||
|
||||
def main():
|
||||
# Example texts to convert
|
||||
|
|
Loading…
Add table
Reference in a new issue