Refactor voice tensor handling to support multiple formats in KokoroV1Wrapper

This commit is contained in:
remsky 2025-01-31 06:16:04 -07:00
parent 9a588a3483
commit e19ba3a8ce

View file

@ -56,16 +56,25 @@ class KokoroV1Wrapper:
# Validate and reshape voice tensor
logger.debug(f"Initial voice tensor shape: {voice_tensor.shape}")
# Ensure tensor has correct number of dimensions
if voice_tensor.dim() == 1:
voice_tensor = voice_tensor.unsqueeze(0) # [N] -> [1, N]
if voice_tensor.dim() == 2:
voice_tensor = voice_tensor.unsqueeze(1) # [B, N] -> [B, 1, N]
# Handle different voice tensor formats
logger.debug(f"Initial voice tensor shape: {voice_tensor.shape}")
# For v0.19 format: [510, 1, 256]
if voice_tensor.dim() == 3 and voice_tensor.size(1) == 1:
# Select embedding based on text length
voice_tensor = voice_tensor[len(text)-1] # [510, 1, 256] -> [1, 256]
# For v1.0 format: [510, 256]
elif voice_tensor.dim() == 2 and voice_tensor.size(-1) == 256:
# Select embedding based on text length
voice_tensor = voice_tensor[len(text)-1].unsqueeze(0) # [510, 256] -> [1, 256]
else:
raise RuntimeError(f"Unsupported voice tensor shape: {voice_tensor.shape}")
logger.debug(f"After reshape voice tensor shape: {voice_tensor.shape}")
# Validate feature dimension
if voice_tensor.size(-1) != 256: # Expected size for style + content
raise RuntimeError(f"Voice tensor has wrong feature size: expected 256, got {voice_tensor.size(-1)}")
logger.debug(f"After reshape voice tensor shape: {voice_tensor.shape}")
# Generate audio directly using KModel
audio = self.model.forward(
@ -111,16 +120,25 @@ class KokoroV1Wrapper:
# Validate and reshape voice tensor
logger.debug(f"Initial voice tensor shape: {voice_tensor.shape}")
# Ensure tensor has correct number of dimensions
if voice_tensor.dim() == 1:
voice_tensor = voice_tensor.unsqueeze(0) # [N] -> [1, N]
if voice_tensor.dim() == 2:
voice_tensor = voice_tensor.unsqueeze(1) # [B, N] -> [B, 1, N]
# Handle different voice tensor formats
logger.debug(f"Initial voice tensor shape: {voice_tensor.shape}")
# For v0.19 format: [510, 1, 256]
if voice_tensor.dim() == 3 and voice_tensor.size(1) == 1:
# Select embedding based on text length
voice_tensor = voice_tensor[len(text)-1] # [510, 1, 256] -> [1, 256]
# For v1.0 format: [510, 256]
elif voice_tensor.dim() == 2 and voice_tensor.size(-1) == 256:
# Select embedding based on text length
voice_tensor = voice_tensor[len(text)-1].unsqueeze(0) # [510, 256] -> [1, 256]
else:
raise RuntimeError(f"Unsupported voice tensor shape: {voice_tensor.shape}")
logger.debug(f"After reshape voice tensor shape: {voice_tensor.shape}")
# Validate feature dimension
if voice_tensor.size(-1) != 256: # Expected size for style + content
raise RuntimeError(f"Voice tensor has wrong feature size: expected 256, got {voice_tensor.size(-1)}")
logger.debug(f"After reshape voice tensor shape: {voice_tensor.shape}")
try:
# Generate audio directly using KModel