Mirror of https://github.com/remsky/Kokoro-FastAPI.git, synced 2025-04-13 09:39:17 +00:00
Refactor voice tensor handling to support multiple formats in KokoroV1Wrapper

commit e19ba3a8ce
parent 9a588a3483

1 changed file with 34 additions and 16 deletions
@@ -56,16 +56,25 @@ class KokoroV1Wrapper:
-        # Validate and reshape voice tensor
-        logger.debug(f"Initial voice tensor shape: {voice_tensor.shape}")
-
-        # Ensure tensor has correct number of dimensions
-        if voice_tensor.dim() == 1:
-            voice_tensor = voice_tensor.unsqueeze(0)  # [N] -> [1, N]
-        if voice_tensor.dim() == 2:
-            voice_tensor = voice_tensor.unsqueeze(1)  # [B, N] -> [B, 1, N]
+        # Handle different voice tensor formats
+        logger.debug(f"Initial voice tensor shape: {voice_tensor.shape}")
+
+        # For v0.19 format: [510, 1, 256]
+        if voice_tensor.dim() == 3 and voice_tensor.size(1) == 1:
+            # Select embedding based on text length
+            voice_tensor = voice_tensor[len(text)-1]  # [510, 1, 256] -> [1, 256]
+
+        # For v1.0 format: [510, 256]
+        elif voice_tensor.dim() == 2 and voice_tensor.size(-1) == 256:
+            # Select embedding based on text length
+            voice_tensor = voice_tensor[len(text)-1].unsqueeze(0)  # [510, 256] -> [1, 256]
+
+        else:
+            raise RuntimeError(f"Unsupported voice tensor shape: {voice_tensor.shape}")
+
+        logger.debug(f"After reshape voice tensor shape: {voice_tensor.shape}")

         # Validate feature dimension
         if voice_tensor.size(-1) != 256:  # Expected size for style + content
             raise RuntimeError(f"Voice tensor has wrong feature size: expected 256, got {voice_tensor.size(-1)}")
         logger.debug(f"After reshape voice tensor shape: {voice_tensor.shape}")

         # Generate audio directly using KModel
         audio = self.model.forward(
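
Read on its own, the new branch logic amounts to a small normalization helper. Below is a minimal standalone sketch of that logic, assuming PyTorch; the name normalize_voice_tensor is hypothetical, not a function in this repository, and the 510-row voicepack layout and len(text)-1 indexing are taken from the diff's own comments:

    import torch

    def normalize_voice_tensor(voice_tensor: torch.Tensor, text: str) -> torch.Tensor:
        """Collapse a packed voicepack to a single [1, 256] embedding (sketch)."""
        # v0.19 voicepacks: [510, 1, 256] -- one embedding per input length.
        # Indexing by len(text)-1 assumes 1 <= len(text) <= 510.
        if voice_tensor.dim() == 3 and voice_tensor.size(1) == 1:
            voice_tensor = voice_tensor[len(text) - 1]               # -> [1, 256]
        # v1.0 voicepacks: [510, 256] -- same selection, middle axis dropped.
        elif voice_tensor.dim() == 2 and voice_tensor.size(-1) == 256:
            voice_tensor = voice_tensor[len(text) - 1].unsqueeze(0)  # -> [1, 256]
        else:
            raise RuntimeError(f"Unsupported voice tensor shape: {voice_tensor.shape}")

        # 256 is the expected style + content feature size, per the diff's comment.
        if voice_tensor.size(-1) != 256:
            raise RuntimeError(
                f"Voice tensor has wrong feature size: expected 256, got {voice_tensor.size(-1)}"
            )
        return voice_tensor

The second hunk below applies the identical change to another entry point of the wrapper.
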
@@ -111,16 +120,25 @@ class KokoroV1Wrapper:
-        # Validate and reshape voice tensor
-        logger.debug(f"Initial voice tensor shape: {voice_tensor.shape}")
-
-        # Ensure tensor has correct number of dimensions
-        if voice_tensor.dim() == 1:
-            voice_tensor = voice_tensor.unsqueeze(0)  # [N] -> [1, N]
-        if voice_tensor.dim() == 2:
-            voice_tensor = voice_tensor.unsqueeze(1)  # [B, N] -> [B, 1, N]
+        # Handle different voice tensor formats
+        logger.debug(f"Initial voice tensor shape: {voice_tensor.shape}")
+
+        # For v0.19 format: [510, 1, 256]
+        if voice_tensor.dim() == 3 and voice_tensor.size(1) == 1:
+            # Select embedding based on text length
+            voice_tensor = voice_tensor[len(text)-1]  # [510, 1, 256] -> [1, 256]
+
+        # For v1.0 format: [510, 256]
+        elif voice_tensor.dim() == 2 and voice_tensor.size(-1) == 256:
+            # Select embedding based on text length
+            voice_tensor = voice_tensor[len(text)-1].unsqueeze(0)  # [510, 256] -> [1, 256]
+
+        else:
+            raise RuntimeError(f"Unsupported voice tensor shape: {voice_tensor.shape}")
+
+        logger.debug(f"After reshape voice tensor shape: {voice_tensor.shape}")

         # Validate feature dimension
         if voice_tensor.size(-1) != 256:  # Expected size for style + content
             raise RuntimeError(f"Voice tensor has wrong feature size: expected 256, got {voice_tensor.size(-1)}")
         logger.debug(f"After reshape voice tensor shape: {voice_tensor.shape}")

         try:
             # Generate audio directly using KModel
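
Both packed layouts should collapse to the same [1, 256] shape. A quick smoke test of the sketch above, under the same assumptions and hypothetical helper name:

    text = "Hello"
    v019_pack = torch.randn(510, 1, 256)  # v0.19 layout
    v100_pack = torch.randn(510, 256)     # v1.0 layout
    assert normalize_voice_tensor(v019_pack, text).shape == (1, 256)
    assert normalize_voice_tensor(v100_pack, text).shape == (1, 256)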