From d73ed8798765c275743fa8e76ed54a7bc3a8802d Mon Sep 17 00:00:00 2001
From: remsky
Date: Sun, 9 Feb 2025 20:26:59 -0700
Subject: [PATCH] Stream immediately in generate_captioned_speech, serve
 timestamps via a temp-file link, and add unit tests for captioned speech
 generation

---
 VERSION                              |   2 +-
 api/src/routers/development.py       | 207 ++++++++++++++++++++-------
 api/src/routers/openai_compatible.py |   2 +-
 api/tests/test_development.py        |  31 ++++
 examples/captioned_speech_example.py |  19 ++-
 5 files changed, 202 insertions(+), 59 deletions(-)
 create mode 100644 api/tests/test_development.py

diff --git a/VERSION b/VERSION
index 1474d00..22c08f7 100644
--- a/VERSION
+++ b/VERSION
@@ -1 +1 @@
-v0.2.0
+v0.2.1
diff --git a/api/src/routers/development.py b/api/src/routers/development.py
index 79c7585..20c2206 100644
--- a/api/src/routers/development.py
+++ b/api/src/routers/development.py
@@ -1,22 +1,28 @@
-from typing import List
+from typing import List, Union, AsyncGenerator, Tuple

 import numpy as np
 import torch
-from fastapi import APIRouter, Depends, HTTPException, Request, Response
-from fastapi.responses import StreamingResponse
+from fastapi import APIRouter, Depends, Header, HTTPException, Request, Response
+from fastapi.responses import StreamingResponse, FileResponse
 from kokoro import KPipeline
 from loguru import logger

+from ..core.config import settings
 from ..services.audio import AudioNormalizer, AudioService
 from ..services.streaming_audio_writer import StreamingAudioWriter
 from ..services.text_processing import smart_split
 from ..services.tts_service import TTSService
+from ..services.temp_manager import TempFileWriter
 from ..structures import CaptionedSpeechRequest, CaptionedSpeechResponse, WordTimestamp
 from ..structures.text_schemas import (
     GenerateFromPhonemesRequest,
     PhonemeRequest,
     PhonemeResponse,
 )
+import json
+import os
+from pathlib import Path
+

 router = APIRouter(tags=["text processing"])
@@ -147,86 +153,169 @@ async def generate_from_phonemes(
     )


+@router.get("/dev/timestamps/{filename}")
+async def get_timestamps(filename: str):
+    """Download timestamps from temp storage"""
+    try:
+        from ..core.paths import _find_file
+
+        # Search for file in temp directory
+        file_path = await _find_file(
+            filename=filename, search_paths=[settings.temp_file_dir]
+        )
+
+        return FileResponse(
+            file_path,
+            media_type="application/json",
+            filename=filename,
+            headers={
+                "Cache-Control": "no-cache",
+                "Content-Disposition": f"attachment; filename={filename}",
+            },
+        )
+
+    except Exception as e:
+        logger.error(f"Error serving timestamps file {filename}: {e}")
+        raise HTTPException(
+            status_code=500,
+            detail={
+                "error": "server_error",
+                "message": "Failed to serve timestamps file",
+                "type": "server_error",
+            },
+        )
+
+
 @router.post("/dev/captioned_speech")
 async def create_captioned_speech(
     request: CaptionedSpeechRequest,
+    client_request: Request,
+    x_raw_response: str = Header(None, alias="x-raw-response"),
     tts_service: TTSService = Depends(get_tts_service),
-) -> StreamingResponse:
-    """Generate audio with word-level timestamps using Kokoro's output"""
+):
+    """Generate audio with word-level timestamps using streaming approach"""
     try:
-        # Get voice path
-        voice_name, voice_path = await tts_service._get_voice_path(request.voice)
+        # Set content type based on format
+        content_type = {
+            "mp3": "audio/mpeg",
+            "opus": "audio/opus",
+            "aac": "audio/aac",
+            "flac": "audio/flac",
+            "wav": "audio/wav",
+            "pcm": "audio/pcm",
+        }.get(request.response_format, f"audio/{request.response_format}")
f"audio/{request.response_format}") - # Generate audio with timestamps - audio, _, word_timestamps = await tts_service.generate_audio( - text=request.input, - voice=voice_name, - speed=request.speed, - return_timestamps=True, - ) - - # Create streaming audio writer + # Create streaming audio writer and normalizer writer = StreamingAudioWriter( format=request.response_format, sample_rate=24000, channels=1 ) normalizer = AudioNormalizer() - async def generate_chunks(): - try: - if audio is not None: - # Normalize audio before writing - normalized_audio = await normalizer.normalize(audio) - # Write chunk and yield bytes - chunk_bytes = writer.write_chunk(normalized_audio) - if chunk_bytes: - yield chunk_bytes + # Get voice path + voice_name, voice_path = await tts_service._get_voice_path(request.voice) - # Finalize and yield remaining bytes - final_bytes = writer.write_chunk(finalize=True) - if final_bytes: - yield final_bytes - else: - raise ValueError("Failed to generate audio data") + # Use provided lang_code or determine from voice name + pipeline_lang_code = request.lang_code if request.lang_code else request.voice[0].lower() + logger.info( + f"Using lang_code '{pipeline_lang_code}' for voice '{voice_name}' in text chunking" + ) + + # Get backend and pipeline + backend = tts_service.model_manager.get_backend() + pipeline = backend._get_pipeline(pipeline_lang_code) + + # Create temp file writer for timestamps + temp_writer = TempFileWriter("json") + await temp_writer.__aenter__() # Initialize temp file + # Get just the filename without the path + timestamps_filename = Path(temp_writer.download_path).name + + # Initialize variables for timestamps + word_timestamps = [] + current_offset = 0.0 + + async def generate_chunks(): + nonlocal current_offset, word_timestamps + try: + # Process text in chunks with smart splitting + async for chunk_text, tokens in smart_split(request.input): + # Process chunk with pipeline + for result in pipeline(chunk_text, voice=voice_path, speed=request.speed): + if result.audio is not None: + # Process timestamps for this chunk + if hasattr(result, "tokens") and result.tokens and result.pred_dur is not None: + try: + # Join timestamps for this chunk's tokens + KPipeline.join_timestamps(result.tokens, result.pred_dur) + + # Add timestamps with offset + for token in result.tokens: + if not all( + hasattr(token, attr) + for attr in ["text", "start_ts", "end_ts"] + ): + continue + if not token.text or not token.text.strip(): + continue + + # Apply offset to timestamps + start_time = float(token.start_ts) + current_offset + end_time = float(token.end_ts) + current_offset + + word_timestamps.append( + { + "word": str(token.text).strip(), + "start_time": start_time, + "end_time": end_time, + } + ) + + # Update offset for next chunk + chunk_duration = float(result.pred_dur.sum()) / 80 # Convert frames to seconds + current_offset = max(current_offset + chunk_duration, end_time) + + except Exception as e: + logger.error(f"Failed to process timestamps for chunk: {e}") + + # Process audio + audio_chunk = result.audio.numpy() + normalized_audio = await normalizer.normalize(audio_chunk) + chunk_bytes = writer.write_chunk(normalized_audio) + if chunk_bytes: + yield chunk_bytes + + # Write timestamps to temp file + timestamps_json = json.dumps(word_timestamps) + await temp_writer.write(timestamps_json.encode()) + await temp_writer.finalize() + + # Finalize audio + final_bytes = writer.write_chunk(finalize=True) + if final_bytes: + yield final_bytes except Exception as e: 
logger.error(f"Error in audio generation: {str(e)}") # Clean up writer on error writer.write_chunk(finalize=True) + await temp_writer.__aexit__(type(e), e, e.__traceback__) # Re-raise the original exception raise - # Convert timestamps to JSON and add as header - import json - - logger.debug(f"Processing {len(word_timestamps)} word timestamps") - timestamps_json = json.dumps( - [ - { - "word": str(ts["word"]), # Ensure string for text - "start_time": float( - ts["start_time"] - ), # Ensure float for timestamps - "end_time": float(ts["end_time"]), - } - for ts in word_timestamps - ] - ) - logger.debug(f"Generated timestamps JSON: {timestamps_json}") - return StreamingResponse( generate_chunks(), - media_type=f"audio/{request.response_format}", + media_type=content_type, headers={ "Content-Disposition": f"attachment; filename=speech.{request.response_format}", "X-Accel-Buffering": "no", "Cache-Control": "no-cache", "Transfer-Encoding": "chunked", - "X-Word-Timestamps": timestamps_json, + "X-Timestamps-Path": timestamps_filename, }, ) except ValueError as e: - logger.error(f"Error in captioned speech generation: {str(e)}") + logger.warning(f"Invalid request: {str(e)}") raise HTTPException( status_code=400, detail={ @@ -235,8 +324,18 @@ async def create_captioned_speech( "type": "invalid_request_error", }, ) - except Exception as e: - logger.error(f"Error in captioned speech generation: {str(e)}") + except RuntimeError as e: + logger.error(f"Processing error: {str(e)}") + raise HTTPException( + status_code=500, + detail={ + "error": "processing_error", + "message": str(e), + "type": "server_error", + }, + ) + except Exception as e: + logger.error(f"Unexpected error in speech generation: {str(e)}") raise HTTPException( status_code=500, detail={ diff --git a/api/src/routers/openai_compatible.py b/api/src/routers/openai_compatible.py index 8c94132..a2506bc 100644 --- a/api/src/routers/openai_compatible.py +++ b/api/src/routers/openai_compatible.py @@ -137,7 +137,7 @@ async def stream_audio_chunks( voice=voice_name, speed=request.speed, output_format=request.response_format, - lang_code=request.lang_code, + lang_code=request.lang_code or request.voice[0], ): # Check if client is still connected is_disconnected = client_request.is_disconnected diff --git a/api/tests/test_development.py b/api/tests/test_development.py new file mode 100644 index 0000000..a05ba23 --- /dev/null +++ b/api/tests/test_development.py @@ -0,0 +1,31 @@ +import pytest +from unittest.mock import patch, MagicMock +import requests + +def test_generate_captioned_speech(): + """Test the generate_captioned_speech function with mocked responses""" + # Mock the API responses + mock_audio_response = MagicMock() + mock_audio_response.status_code = 200 + mock_audio_response.content = b"mock audio data" + mock_audio_response.headers = {"X-Timestamps-Path": "test.json"} + + mock_timestamps_response = MagicMock() + mock_timestamps_response.status_code = 200 + mock_timestamps_response.json.return_value = [ + {"word": "test", "start_time": 0.0, "end_time": 1.0} + ] + + # Patch both HTTP requests + with patch('requests.post', return_value=mock_audio_response), \ + patch('requests.get', return_value=mock_timestamps_response): + + # Import here to avoid module-level import issues + from examples.captioned_speech_example import generate_captioned_speech + + # Test the function + audio, timestamps = generate_captioned_speech("test text") + + # Verify we got both audio and timestamps + assert audio == b"mock audio data" + assert timestamps == 
[{"word": "test", "start_time": 0.0, "end_time": 1.0}] \ No newline at end of file diff --git a/examples/captioned_speech_example.py b/examples/captioned_speech_example.py index 384dcf2..2433134 100644 --- a/examples/captioned_speech_example.py +++ b/examples/captioned_speech_example.py @@ -33,9 +33,19 @@ def generate_captioned_speech( return None, None try: - # Get timestamps from header - timestamps_json = response.headers.get('X-Word-Timestamps', '[]') - word_timestamps = json.loads(timestamps_json) + # Get timestamps path from header + timestamps_filename = response.headers.get('X-Timestamps-Path') + if not timestamps_filename: + print("Error: No timestamps path in response headers") + return None, None + + # Get timestamps from the path + timestamps_response = requests.get(f"http://localhost:8880/dev/timestamps/{timestamps_filename}") + if timestamps_response.status_code != 200: + print(f"Error getting timestamps: {timestamps_response.text}") + return None, None + + word_timestamps = timestamps_response.json() # Get audio bytes from content audio_bytes = response.content @@ -48,6 +58,9 @@ def generate_captioned_speech( except json.JSONDecodeError as e: print(f"Error parsing timestamps: {e}") return None, None + except requests.RequestException as e: + print(f"Error retrieving timestamps: {e}") + return None, None def main(): # Example texts to convert