Update generate_captioned_speech to stream audio immediately, expose the caption (timestamps) file through a temporary download link, and add unit tests for captioned speech generation

This commit is contained in:
remsky 2025-02-09 20:26:59 -07:00
parent a91e0fe9df
commit d73ed87987
5 changed files with 202 additions and 59 deletions

View file

@ -1 +1 @@
v0.2.0
v0.2.1

View file

@ -1,22 +1,28 @@
from typing import List
from typing import List, Union, AsyncGenerator, Tuple
import numpy as np
import torch
from fastapi import APIRouter, Depends, HTTPException, Request, Response
from fastapi.responses import StreamingResponse
from fastapi import APIRouter, Depends, Header, HTTPException, Request, Response
from fastapi.responses import StreamingResponse, FileResponse
from kokoro import KPipeline
from loguru import logger
from ..core.config import settings
from ..services.audio import AudioNormalizer, AudioService
from ..services.streaming_audio_writer import StreamingAudioWriter
from ..services.text_processing import smart_split
from ..services.tts_service import TTSService
from ..services.temp_manager import TempFileWriter
from ..structures import CaptionedSpeechRequest, CaptionedSpeechResponse, WordTimestamp
from ..structures.text_schemas import (
GenerateFromPhonemesRequest,
PhonemeRequest,
PhonemeResponse,
)
import json
import os
from pathlib import Path
router = APIRouter(tags=["text processing"])
@ -147,86 +153,169 @@ async def generate_from_phonemes(
)
@router.get("/dev/timestamps/{filename}")
async def get_timestamps(filename: str):
    """Download a timestamps file from temp storage.

    Args:
        filename: Bare name of the timestamps file to look up inside
            ``settings.temp_file_dir``. It is taken from the URL, so it is
            treated as untrusted input.

    Returns:
        A ``FileResponse`` that serves the file as an ``application/json``
        attachment with caching disabled.

    Raises:
        HTTPException: 400 when ``filename`` is not a plain file name
            (empty, ``.``, ``..`` or containing a path separator); 500 when
            the lookup or the response construction fails.
    """
    # Reject anything that could point outside the temp directory before it
    # reaches the filesystem lookup. This check lives outside the ``try`` so
    # the broad ``except Exception`` below cannot turn the 400 into a 500.
    if filename in ("", ".", "..") or "/" in filename or "\\" in filename:
        logger.warning(f"Rejected invalid timestamps filename: {filename!r}")
        raise HTTPException(
            status_code=400,
            detail={
                "error": "invalid_request",
                "message": "Invalid timestamps filename",
                "type": "invalid_request_error",
            },
        )

    try:
        # Imported lazily, as in the original implementation.
        from ..core.paths import _find_file

        # Search for file in temp directory
        file_path = await _find_file(
            filename=filename, search_paths=[settings.temp_file_dir]
        )
        return FileResponse(
            file_path,
            media_type="application/json",
            filename=filename,
            headers={
                "Cache-Control": "no-cache",
                "Content-Disposition": f"attachment; filename={filename}",
            },
        )
    except Exception as e:
        # Any lookup/serving failure is reported as a generic server error so
        # internal paths are not leaked to the client.
        logger.error(f"Error serving timestamps file {filename}: {e}")
        raise HTTPException(
            status_code=500,
            detail={
                "error": "server_error",
                "message": "Failed to serve timestamps file",
                "type": "server_error",
            },
        )
@router.post("/dev/captioned_speech")
async def create_captioned_speech(
request: CaptionedSpeechRequest,
client_request: Request,
x_raw_response: str = Header(None, alias="x-raw-response"),
tts_service: TTSService = Depends(get_tts_service),
) -> StreamingResponse:
"""Generate audio with word-level timestamps using Kokoro's output"""
):
"""Generate audio with word-level timestamps using streaming approach"""
try:
# Get voice path
voice_name, voice_path = await tts_service._get_voice_path(request.voice)
# Set content type based on format
content_type = {
"mp3": "audio/mpeg",
"opus": "audio/opus",
"aac": "audio/aac",
"flac": "audio/flac",
"wav": "audio/wav",
"pcm": "audio/pcm",
}.get(request.response_format, f"audio/{request.response_format}")
# Generate audio with timestamps
audio, _, word_timestamps = await tts_service.generate_audio(
text=request.input,
voice=voice_name,
speed=request.speed,
return_timestamps=True,
)
# Create streaming audio writer
# Create streaming audio writer and normalizer
writer = StreamingAudioWriter(
format=request.response_format, sample_rate=24000, channels=1
)
normalizer = AudioNormalizer()
# Get voice path
voice_name, voice_path = await tts_service._get_voice_path(request.voice)
# Use provided lang_code or determine from voice name
pipeline_lang_code = request.lang_code if request.lang_code else request.voice[0].lower()
logger.info(
f"Using lang_code '{pipeline_lang_code}' for voice '{voice_name}' in text chunking"
)
# Get backend and pipeline
backend = tts_service.model_manager.get_backend()
pipeline = backend._get_pipeline(pipeline_lang_code)
# Create temp file writer for timestamps
temp_writer = TempFileWriter("json")
await temp_writer.__aenter__() # Initialize temp file
# Get just the filename without the path
timestamps_filename = Path(temp_writer.download_path).name
# Initialize variables for timestamps
word_timestamps = []
current_offset = 0.0
async def generate_chunks():
nonlocal current_offset, word_timestamps
try:
if audio is not None:
# Normalize audio before writing
normalized_audio = await normalizer.normalize(audio)
# Write chunk and yield bytes
# Process text in chunks with smart splitting
async for chunk_text, tokens in smart_split(request.input):
# Process chunk with pipeline
for result in pipeline(chunk_text, voice=voice_path, speed=request.speed):
if result.audio is not None:
# Process timestamps for this chunk
if hasattr(result, "tokens") and result.tokens and result.pred_dur is not None:
try:
# Join timestamps for this chunk's tokens
KPipeline.join_timestamps(result.tokens, result.pred_dur)
# Add timestamps with offset
for token in result.tokens:
if not all(
hasattr(token, attr)
for attr in ["text", "start_ts", "end_ts"]
):
continue
if not token.text or not token.text.strip():
continue
# Apply offset to timestamps
start_time = float(token.start_ts) + current_offset
end_time = float(token.end_ts) + current_offset
word_timestamps.append(
{
"word": str(token.text).strip(),
"start_time": start_time,
"end_time": end_time,
}
)
# Update offset for next chunk
chunk_duration = float(result.pred_dur.sum()) / 80 # Convert frames to seconds
current_offset = max(current_offset + chunk_duration, end_time)
except Exception as e:
logger.error(f"Failed to process timestamps for chunk: {e}")
# Process audio
audio_chunk = result.audio.numpy()
normalized_audio = await normalizer.normalize(audio_chunk)
chunk_bytes = writer.write_chunk(normalized_audio)
if chunk_bytes:
yield chunk_bytes
# Finalize and yield remaining bytes
# Write timestamps to temp file
timestamps_json = json.dumps(word_timestamps)
await temp_writer.write(timestamps_json.encode())
await temp_writer.finalize()
# Finalize audio
final_bytes = writer.write_chunk(finalize=True)
if final_bytes:
yield final_bytes
else:
raise ValueError("Failed to generate audio data")
except Exception as e:
logger.error(f"Error in audio generation: {str(e)}")
# Clean up writer on error
writer.write_chunk(finalize=True)
await temp_writer.__aexit__(type(e), e, e.__traceback__)
# Re-raise the original exception
raise
# Convert timestamps to JSON and add as header
import json
logger.debug(f"Processing {len(word_timestamps)} word timestamps")
timestamps_json = json.dumps(
[
{
"word": str(ts["word"]), # Ensure string for text
"start_time": float(
ts["start_time"]
), # Ensure float for timestamps
"end_time": float(ts["end_time"]),
}
for ts in word_timestamps
]
)
logger.debug(f"Generated timestamps JSON: {timestamps_json}")
return StreamingResponse(
generate_chunks(),
media_type=f"audio/{request.response_format}",
media_type=content_type,
headers={
"Content-Disposition": f"attachment; filename=speech.{request.response_format}",
"X-Accel-Buffering": "no",
"Cache-Control": "no-cache",
"Transfer-Encoding": "chunked",
"X-Word-Timestamps": timestamps_json,
"X-Timestamps-Path": timestamps_filename,
},
)
except ValueError as e:
logger.error(f"Error in captioned speech generation: {str(e)}")
logger.warning(f"Invalid request: {str(e)}")
raise HTTPException(
status_code=400,
detail={
@ -235,8 +324,18 @@ async def create_captioned_speech(
"type": "invalid_request_error",
},
)
except Exception as e:
logger.error(f"Error in captioned speech generation: {str(e)}")
except RuntimeError as e:
logger.error(f"Processing error: {str(e)}")
raise HTTPException(
status_code=500,
detail={
"error": "processing_error",
"message": str(e),
"type": "server_error",
},
)
except Exception as e:
logger.error(f"Unexpected error in speech generation: {str(e)}")
raise HTTPException(
status_code=500,
detail={

View file

@ -137,7 +137,7 @@ async def stream_audio_chunks(
voice=voice_name,
speed=request.speed,
output_format=request.response_format,
lang_code=request.lang_code,
lang_code=request.lang_code or request.voice[0],
):
# Check if client is still connected
is_disconnected = client_request.is_disconnected

View file

@ -0,0 +1,31 @@
import pytest
from unittest.mock import patch, MagicMock
import requests
def test_generate_captioned_speech():
    """Exercise generate_captioned_speech against canned HTTP responses.

    The POST (audio) and GET (timestamps) calls are both patched, so the
    test only checks that the helper hands back the audio bytes and the
    parsed timestamps it received.
    """
    # Canned reply for the speech request: audio bytes plus the header that
    # names the timestamps file.
    audio_reply = MagicMock(
        status_code=200,
        content=b"mock audio data",
        headers={"X-Timestamps-Path": "test.json"},
    )

    # Canned reply for the follow-up timestamps download.
    timestamps_reply = MagicMock(status_code=200)
    timestamps_reply.json.return_value = [
        {"word": "test", "start_time": 0.0, "end_time": 1.0}
    ]

    # Patch both HTTP requests
    with patch("requests.post", return_value=audio_reply):
        with patch("requests.get", return_value=timestamps_reply):
            # Import here to avoid module-level import issues
            from examples.captioned_speech_example import generate_captioned_speech

            # Test the function
            audio, timestamps = generate_captioned_speech("test text")

            # Verify we got both audio and timestamps
            assert audio == b"mock audio data"
            assert timestamps == [
                {"word": "test", "start_time": 0.0, "end_time": 1.0}
            ]

View file

@ -33,9 +33,19 @@ def generate_captioned_speech(
return None, None
try:
# Get timestamps from header
timestamps_json = response.headers.get('X-Word-Timestamps', '[]')
word_timestamps = json.loads(timestamps_json)
# Get timestamps path from header
timestamps_filename = response.headers.get('X-Timestamps-Path')
if not timestamps_filename:
print("Error: No timestamps path in response headers")
return None, None
# Get timestamps from the path
timestamps_response = requests.get(f"http://localhost:8880/dev/timestamps/{timestamps_filename}")
if timestamps_response.status_code != 200:
print(f"Error getting timestamps: {timestamps_response.text}")
return None, None
word_timestamps = timestamps_response.json()
# Get audio bytes from content
audio_bytes = response.content
@ -48,6 +58,9 @@ def generate_captioned_speech(
except json.JSONDecodeError as e:
print(f"Error parsing timestamps: {e}")
return None, None
except requests.RequestException as e:
print(f"Error retrieving timestamps: {e}")
return None, None
def main():
# Example texts to convert