mirror of https://github.com/remsky/Kokoro-FastAPI.git
synced 2025-08-05 16:48:53 +00:00
Cleaned up some code and fixed an error in the readme
This commit is contained in:
parent 34acb17682
commit 9c1ced237b
3 changed files with 3 additions and 400 deletions
@@ -342,7 +342,7 @@ Key Performance Metrics:
 <summary>GPU Vs. CPU</summary>
 ```bash
-# GPU: Requires NVIDIA GPU with CUDA 12.1 support (~35x-100x realtime speed)
+# GPU: Requires NVIDIA GPU with CUDA 12.8 support (~35x-100x realtime speed)
 cd docker/gpu
 docker compose up --build
 ```

@@ -321,7 +321,7 @@ async def create_captioned_speech(
         )
     except Exception as e:
         # Handle unexpected errors
-        logger.error(f"Unexpected error in speech generation: {str(e)}")
+        logger.error(f"Unexpected error in captioned speech generation: {str(e)}")
         raise HTTPException(
             status_code=500,
             detail={
@@ -329,157 +329,4 @@ async def create_captioned_speech(
                 "message": str(e),
                 "type": "server_error",
             },
         )
-
-    """
-    try:
-        # Set content type based on format
-        content_type = {
-            "mp3": "audio/mpeg",
-            "opus": "audio/opus",
-            "aac": "audio/aac",
-            "flac": "audio/flac",
-            "wav": "audio/wav",
-            "pcm": "audio/pcm",
-        }.get(request.response_format, f"audio/{request.response_format}")
-
-        # Create streaming audio writer and normalizer
-        writer = StreamingAudioWriter(
-            format=request.response_format, sample_rate=24000, channels=1
-        )
-        normalizer = AudioNormalizer()
-
-        # Get voice path
-        voice_name, voice_path = await tts_service._get_voice_path(request.voice)
-
-        # Use provided lang_code or determine from voice name
-        pipeline_lang_code = request.lang_code if request.lang_code else request.voice[0].lower()
-        logger.info(
-            f"Using lang_code '{pipeline_lang_code}' for voice '{voice_name}' in text chunking"
-        )
-
-        # Get backend and pipeline
-        backend = tts_service.model_manager.get_backend()
-        pipeline = backend._get_pipeline(pipeline_lang_code)
-
-        # Create temp file writer for timestamps
-        temp_writer = TempFileWriter("json")
-        await temp_writer.__aenter__()  # Initialize temp file
-        # Get just the filename without the path
-        timestamps_filename = Path(temp_writer.download_path).name
-
-        # Initialize variables for timestamps
-        word_timestamps = []
-        current_offset = 0.0
-
-        async def generate_chunks():
-            nonlocal current_offset, word_timestamps
-            try:
-                # Process text in chunks with smart splitting
-                async for chunk_text, tokens in smart_split(request.input):
-                    # Process chunk with pipeline
-                    for result in pipeline(chunk_text, voice=voice_path, speed=request.speed):
-                        if result.audio is not None:
-                            # Process timestamps for this chunk
-                            if hasattr(result, "tokens") and result.tokens and result.pred_dur is not None:
-                                try:
-                                    # Join timestamps for this chunk's tokens
-                                    KPipeline.join_timestamps(result.tokens, result.pred_dur)
-
-                                    # Add timestamps with offset
-                                    for token in result.tokens:
-                                        if not all(
-                                            hasattr(token, attr)
-                                            for attr in ["text", "start_ts", "end_ts"]
-                                        ):
-                                            continue
-                                        if not token.text or not token.text.strip():
-                                            continue
-
-                                        # Apply offset to timestamps
-                                        start_time = float(token.start_ts) + current_offset
-                                        end_time = float(token.end_ts) + current_offset
-
-                                        word_timestamps.append(
-                                            {
-                                                "word": str(token.text).strip(),
-                                                "start_time": start_time,
-                                                "end_time": end_time,
-                                            }
-                                        )
-
-                                    # Update offset for next chunk
-                                    chunk_duration = float(result.pred_dur.sum()) / 80  # Convert frames to seconds
-                                    current_offset = max(current_offset + chunk_duration, end_time)
-
-                                except Exception as e:
-                                    logger.error(f"Failed to process timestamps for chunk: {e}")
-
-                            # Process audio
-                            audio_chunk = result.audio.numpy()
-                            normalized_audio = await normalizer.normalize(audio_chunk)
-                            chunk_bytes = writer.write_chunk(normalized_audio)
-                            if chunk_bytes:
-                                yield chunk_bytes
-
-                # Write timestamps to temp file
-                timestamps_json = json.dumps(word_timestamps)
-                await temp_writer.write(timestamps_json.encode())
-                await temp_writer.finalize()
-
-                # Finalize audio
-                final_bytes = writer.write_chunk(finalize=True)
-                if final_bytes:
-                    yield final_bytes
-
-            except Exception as e:
-                logger.error(f"Error in audio generation: {str(e)}")
-                # Clean up writer on error
-                writer.write_chunk(finalize=True)
-                await temp_writer.__aexit__(type(e), e, e.__traceback__)
-                # Re-raise the original exception
-                raise
-
-        return StreamingResponse(
-            generate_chunks(),
-            media_type=content_type,
-            headers={
-                "Content-Disposition": f"attachment; filename=speech.{request.response_format}",
-                "X-Accel-Buffering": "no",
-                "Cache-Control": "no-cache",
-                "Transfer-Encoding": "chunked",
-                "X-Timestamps-Path": timestamps_filename,
-            },
-        )
-
-    except ValueError as e:
-        logger.warning(f"Invalid request: {str(e)}")
-        raise HTTPException(
-            status_code=400,
-            detail={
-                "error": "validation_error",
-                "message": str(e),
-                "type": "invalid_request_error",
-            },
-        )
-    except RuntimeError as e:
-        logger.error(f"Processing error: {str(e)}")
-        raise HTTPException(
-            status_code=500,
-            detail={
-                "error": "processing_error",
-                "message": str(e),
-                "type": "server_error",
-            },
-        )
-    except Exception as e:
-        logger.error(f"Unexpected error in speech generation: {str(e)}")
-        raise HTTPException(
-            status_code=500,
-            detail={
-                "error": "processing_error",
-                "message": str(e),
-                "type": "server_error",
-            },
-        )
-    """
-    )

@@ -346,250 +346,6 @@ class TTSService:
 
             combined_audio_data=AudioChunk.combine(audio_data_chunks)
             return combined_audio_data.audio,combined_audio_data
-            """
-            # Get backend and voice path
-            backend = self.model_manager.get_backend()
-            voice_name, voice_path = await self._get_voice_path(voice)
-
-            if isinstance(backend, KokoroV1):
-                # Use provided lang_code or determine from voice name
-                pipeline_lang_code = lang_code if lang_code else voice[:1].lower()
-                logger.info(
-                    f"Using lang_code '{pipeline_lang_code}' for voice '{voice_name}' in text chunking"
-                )
-
-                # Get pipelines from backend for proper device management
-                try:
-                    # Initialize quiet pipeline for text chunking
-                    text_chunks = []
-                    current_offset = 0.0  # Track time offset for timestamps
-
-                    logger.debug("Splitting text into chunks...")
-                    # Use backend's pipeline management
-                    for result in backend._get_pipeline(pipeline_lang_code)(text):
-                        if result.graphemes and result.phonemes:
-                            text_chunks.append((result.graphemes, result.phonemes))
-                    logger.debug(f"Split text into {len(text_chunks)} chunks")
-
-                    # Process each chunk
-                    for chunk_idx, (chunk_text, chunk_phonemes) in enumerate(
-                        text_chunks
-                    ):
-                        logger.debug(
-                            f"Processing chunk {chunk_idx + 1}/{len(text_chunks)}: '{chunk_text[:50]}...'"
-                        )
-
-                        # Use backend's pipeline for generation
-                        for result in backend._get_pipeline(pipeline_lang_code)(
-                            chunk_text, voice=voice_path, speed=speed
-                        ):
-                            # Collect audio chunks
-                            if result.audio is not None:
-                                chunks.append(result.audio.numpy())
-
-                            # Process timestamps for this chunk
-                            if (
-                                return_timestamps
-                                and hasattr(result, "tokens")
-                                and result.tokens
-                            ):
-                                logger.debug(
-                                    f"Processing chunk timestamps with {len(result.tokens)} tokens"
-                                )
-                                if result.pred_dur is not None:
-                                    try:
-                                        # Join timestamps for this chunk's tokens
-                                        KPipeline.join_timestamps(
-                                            result.tokens, result.pred_dur
-                                        )
-
-                                        # Add timestamps with offset
-                                        for token in result.tokens:
-                                            if not all(
-                                                hasattr(token, attr)
-                                                for attr in [
-                                                    "text",
-                                                    "start_ts",
-                                                    "end_ts",
-                                                ]
-                                            ):
-                                                continue
-                                            if not token.text or not token.text.strip():
-                                                continue
-
-                                            # Apply offset to timestamps
-                                            start_time = (
-                                                float(token.start_ts) + current_offset
-                                            )
-                                            end_time = (
-                                                float(token.end_ts) + current_offset
-                                            )
-
-                                            word_timestamps.append(
-                                                {
-                                                    "word": str(token.text).strip(),
-                                                    "start_time": start_time,
-                                                    "end_time": end_time,
-                                                }
-                                            )
-                                            logger.debug(
-                                                f"Added timestamp for word '{token.text}': {start_time:.3f}s - {end_time:.3f}s"
-                                            )
-
-                                        # Update offset for next chunk based on pred_dur
-                                        chunk_duration = (
-                                            float(result.pred_dur.sum()) / 80
-                                        )  # Convert frames to seconds
-                                        current_offset = max(
-                                            current_offset + chunk_duration, end_time
-                                        )
-                                        logger.debug(
-                                            f"Updated time offset to {current_offset:.3f}s"
-                                        )
-
-                                    except Exception as e:
-                                        logger.error(
-                                            f"Failed to process timestamps for chunk: {e}"
-                                        )
-                                logger.debug(
-                                    f"Processing timestamps with pred_dur shape: {result.pred_dur.shape}"
-                                )
-                                try:
-                                    # Join timestamps for this chunk's tokens
-                                    KPipeline.join_timestamps(
-                                        result.tokens, result.pred_dur
-                                    )
-                                    logger.debug(
-                                        "Successfully joined timestamps for chunk"
-                                    )
-                                except Exception as e:
-                                    logger.error(
-                                        f"Failed to join timestamps for chunk: {e}"
-                                    )
-                                    continue
-
-                                # Convert tokens to timestamps
-                                for token in result.tokens:
-                                    try:
-                                        # Skip tokens without required attributes
-                                        if not all(
-                                            hasattr(token, attr)
-                                            for attr in ["text", "start_ts", "end_ts"]
-                                        ):
-                                            logger.debug(
-                                                f"Skipping token missing attributes: {dir(token)}"
-                                            )
-                                            continue
-
-                                        # Get and validate text
-                                        text = (
-                                            str(token.text).strip()
-                                            if token.text is not None
-                                            else ""
-                                        )
-                                        if not text:
-                                            logger.debug("Skipping empty token")
-                                            continue
-
-                                        # Get and validate timestamps
-                                        start_ts = getattr(token, "start_ts", None)
-                                        end_ts = getattr(token, "end_ts", None)
-                                        if start_ts is None or end_ts is None:
-                                            logger.debug(
-                                                f"Skipping token with None timestamps: {text}"
-                                            )
-                                            continue
-
-                                        # Convert timestamps to float
-                                        try:
-                                            start_time = float(start_ts)
-                                            end_time = float(end_ts)
-                                        except (TypeError, ValueError):
-                                            logger.debug(
-                                                f"Skipping token with invalid timestamps: {text}"
-                                            )
-                                            continue
-
-                                        # Add timestamp
-                                        word_timestamps.append(
-                                            {
-                                                "word": text,
-                                                "start_time": start_time,
-                                                "end_time": end_time,
-                                            }
-                                        )
-                                        logger.debug(
-                                            f"Added timestamp for word '{text}': {start_time:.3f}s - {end_time:.3f}s"
-                                        )
-                                    except Exception as e:
-                                        logger.warning(f"Error processing token: {e}")
-                                        continue
-
-                except Exception as e:
-                    logger.error(f"Failed to process text with pipeline: {e}")
-                    raise RuntimeError(f"Pipeline processing failed: {e}")
-
-                if not chunks:
-                    raise ValueError("No audio chunks were generated successfully")
-
-                # Combine chunks
-                audio = np.concatenate(chunks) if len(chunks) > 1 else chunks[0]
-                processing_time = time.time() - start_time
-
-                if return_timestamps:
-                    # Validate timestamps before returning
-                    if not word_timestamps:
-                        logger.warning("No valid timestamps were generated")
-                    else:
-                        # Sort timestamps by start time to ensure proper order
-                        word_timestamps.sort(key=lambda x: x["start_time"])
-                        # Validate timestamp sequence
-                        for i in range(1, len(word_timestamps)):
-                            prev = word_timestamps[i - 1]
-                            curr = word_timestamps[i]
-                            if curr["start_time"] < prev["end_time"]:
-                                logger.warning(
-                                    f"Overlapping timestamps detected: '{prev['word']}' ({prev['start_time']:.3f}-{prev['end_time']:.3f}) and '{curr['word']}' ({curr['start_time']:.3f}-{curr['end_time']:.3f})"
-                                )
-
-                        logger.debug(
-                            f"Returning {len(word_timestamps)} word timestamps"
-                        )
-                        logger.debug(
-                            f"First timestamp: {word_timestamps[0]['word']} at {word_timestamps[0]['start_time']:.3f}s"
-                        )
-                        logger.debug(
-                            f"Last timestamp: {word_timestamps[-1]['word']} at {word_timestamps[-1]['end_time']:.3f}s"
-                        )
-
-                    return audio, processing_time, word_timestamps
-                return audio, processing_time
-
-            else:
-                # For legacy backends
-                async for chunk in self.generate_audio_stream(
-                    text,
-                    voice,
-                    speed,  # Default to WAV for raw audio
-                ):
-                    if chunk is not None:
-                        chunks.append(chunk)
-
-                if not chunks:
-                    raise ValueError("No audio chunks were generated successfully")
-
-                # Combine chunks
-                audio = np.concatenate(chunks) if len(chunks) > 1 else chunks[0]
-                processing_time = time.time() - start_time
-
-                if return_timestamps:
-                    return (
-                        audio,
-                        processing_time,
-                        [],
-                    )  # Empty timestamps for legacy backends
-                return audio, processing_time
-            """
         except Exception as e:
             logger.error(f"Error in audio generation: {str(e)}")
             raise