import typing as tp

import torch
def build_delay_indices(
    B: int, T: int, C: int, delay_pattern: tp.List[int]
) -> tp.Tuple[torch.Tensor, torch.Tensor]:
    """
    Precompute (t_idx_BxTxC, indices_BTCx3) so that out[t, c] = in[t - delay[c], c].
    Negative t_idx => BOS; t_idx >= T => PAD.
    """
    delay_arr = torch.tensor(delay_pattern, dtype=torch.int32)

    t_idx_BxT = torch.broadcast_to(
        torch.arange(T, dtype=torch.int32)[None, :],
        [B, T],
    )
    t_idx_BxTx1 = t_idx_BxT[..., None]
    t_idx_BxTxC = t_idx_BxTx1 - delay_arr.view(1, 1, C)

    b_idx_BxTxC = torch.broadcast_to(
        torch.arange(B, dtype=torch.int32).view(B, 1, 1),
        [B, T, C],
    )
    c_idx_BxTxC = torch.broadcast_to(
        torch.arange(C, dtype=torch.int32).view(1, 1, C),
        [B, T, C],
    )

    # Clamp time indices to [0..T-1] so the advanced-indexing gather below stays in bounds.
    t_clamped_BxTxC = torch.clamp(t_idx_BxTxC, 0, T - 1)

    indices_BTCx3 = torch.stack(
        [
            b_idx_BxTxC.reshape(-1),
            t_clamped_BxTxC.reshape(-1),
            c_idx_BxTxC.reshape(-1),
        ],
        dim=1,
    ).long()  # Ensure indices are long type for indexing

    return t_idx_BxTxC, indices_BTCx3
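
# Usage sketch (illustrative values, not from the original module): precompute the
# delay indices once for a fixed (B, T, C) and reuse them for every batch of that shape.
def _example_build_delay_indices() -> None:
    t_idx, gather_idx = build_delay_indices(B=2, T=5, C=4, delay_pattern=[0, 1, 2, 3])
    assert t_idx.shape == (2, 5, 4)            # per-position shifted time indices
    assert gather_idx.shape == (2 * 5 * 4, 3)  # flat (batch, time, channel) triples
    # Negative entries mark positions that apply_audio_delay will fill with BOS.
    assert bool((t_idx[:, 0, 1:] < 0).all())
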
def apply_audio_delay(
    audio_BxTxC: torch.Tensor,
    pad_value: int,
    bos_value: int,
    precomp: tp.Tuple[torch.Tensor, torch.Tensor],
) -> torch.Tensor:
    """
    Applies the delay pattern to batched audio tokens using precomputed indices,
    inserting BOS where t_idx < 0 and PAD where t_idx >= T.

    Args:
        audio_BxTxC: [B, T, C] audio tokens (int16/int32/float)
        pad_value: the padding token
        bos_value: the BOS token
        precomp: (t_idx_BxTxC, indices_BTCx3) from build_delay_indices

    Returns:
        result_BxTxC: [B, T, C] delayed audio tokens
    """
    device = audio_BxTxC.device
    t_idx_BxTxC, indices_BTCx3 = precomp

    # Move precomputed indices to the input tensor's device.
    t_idx_BxTxC = t_idx_BxTxC.to(device)
    indices_BTCx3 = indices_BTCx3.to(device)

    # Equivalent of tf.gather_nd via advanced indexing; indices are already long.
    gathered_flat = audio_BxTxC[
        indices_BTCx3[:, 0], indices_BTCx3[:, 1], indices_BTCx3[:, 2]
    ]
    gathered_BxTxC = gathered_flat.view(audio_BxTxC.shape)

    # Masks and fill values, all created on the input's device.
    mask_bos = t_idx_BxTxC < 0  # => place bos_value
    mask_pad = t_idx_BxTxC >= audio_BxTxC.shape[1]  # => place pad_value
    bos_tensor = torch.tensor(bos_value, dtype=audio_BxTxC.dtype, device=device)
    pad_tensor = torch.tensor(pad_value, dtype=audio_BxTxC.dtype, device=device)

    # If mask_bos, BOS; else if mask_pad, PAD; else the gathered value.
    result_BxTxC = torch.where(
        mask_bos, bos_tensor, torch.where(mask_pad, pad_tensor, gathered_BxTxC)
    )

    return result_BxTxC
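
# Worked example (hypothetical tokens): with delay_pattern [0, 1], channel 1 is shifted
# right by one step, so its first position becomes BOS and everything else moves down.
def _example_apply_audio_delay() -> None:
    B, T, C = 1, 3, 2
    tokens = torch.tensor([[[10, 20], [11, 21], [12, 22]]], dtype=torch.int32)
    precomp = build_delay_indices(B, T, C, delay_pattern=[0, 1])
    delayed = apply_audio_delay(tokens, pad_value=-1, bos_value=-2, precomp=precomp)
    # Channel 0 is unchanged; channel 1 starts with BOS (-2) and is shifted by one.
    expected = torch.tensor([[[10, -2], [11, 20], [12, 21]]], dtype=torch.int32)
    assert torch.equal(delayed, expected)
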
def build_revert_indices(
    B: int, T: int, C: int, delay_pattern: tp.List[int]
) -> tp.Tuple[torch.Tensor, torch.Tensor]:
    """
    Precompute indices for the revert operation using PyTorch.

    Returns:
        A tuple (t_idx_BxTxC, indices_BTCx3) where:
        - t_idx_BxTxC is a tensor of shape [B, T, C] computed as time indices plus the
          delay, clamped to T - 1.
        - indices_BTCx3 is a tensor of shape [B*T*C, 3] used for gathering, built from
          batch indices, clamped time indices, and channel indices.
    """
    # Built on the default device; callers move the results via .to(device) as needed.
    delay_arr = torch.tensor(delay_pattern, dtype=torch.int32)

    t_idx_BxT = torch.broadcast_to(torch.arange(T).unsqueeze(0), [B, T])
    t_idx_BxTx1 = t_idx_BxT.unsqueeze(-1)
    t_idx_BxTxC = torch.minimum(
        t_idx_BxTx1 + delay_arr.view(1, 1, C),
        torch.tensor(T - 1),
    )

    b_idx_BxTxC = torch.broadcast_to(torch.arange(B).view(B, 1, 1), [B, T, C])
    c_idx_BxTxC = torch.broadcast_to(torch.arange(C).view(1, 1, C), [B, T, C])

    indices_BTCx3 = torch.stack(
        [
            b_idx_BxTxC.reshape(-1),
            t_idx_BxTxC.reshape(-1),
            c_idx_BxTxC.reshape(-1),
        ],
        dim=1,
    ).long()  # Ensure indices are long type for indexing

    return t_idx_BxTxC, indices_BTCx3
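
# Shape check (illustrative values): the revert indices mirror build_delay_indices but
# shift time forward by the delay, clamping at T - 1 instead of at zero.
def _example_build_revert_indices() -> None:
    t_idx, gather_idx = build_revert_indices(B=1, T=4, C=2, delay_pattern=[0, 2])
    assert t_idx.shape == (1, 4, 2)
    assert gather_idx.shape == (1 * 4 * 2, 3)
    # Channel 1 (delay 2) reads ahead: times [2, 3, 3, 3] after clamping to T - 1.
    assert t_idx[0, :, 1].tolist() == [2, 3, 3, 3]
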
def revert_audio_delay(
    audio_BxTxC: torch.Tensor,
    pad_value: int,
    precomp: tp.Tuple[torch.Tensor, torch.Tensor],
    T: int,
) -> torch.Tensor:
    """
    Reverts a delay pattern from batched audio tokens using precomputed indices (PyTorch version).

    Args:
        audio_BxTxC: Input delayed audio tensor
        pad_value: Padding value for out-of-bounds positions
        precomp: Precomputed revert indices tuple containing:
            - t_idx_BxTxC: Time offset indices tensor
            - indices_BTCx3: Gather indices tensor for the original audio
        T: Original sequence length before padding

    Returns:
        Reverted audio tensor with the same shape as the input
    """
    t_idx_BxTxC, indices_BTCx3 = precomp
    device = audio_BxTxC.device

    # Move precomputed indices to the input tensor's device.
    t_idx_BxTxC = t_idx_BxTxC.to(device)
    indices_BTCx3 = indices_BTCx3.to(device)

    # Equivalent of tf.gather_nd via advanced indexing.
    gathered_flat = audio_BxTxC[
        indices_BTCx3[:, 0], indices_BTCx3[:, 1], indices_BTCx3[:, 2]
    ]
    gathered_BxTxC = gathered_flat.view(audio_BxTxC.size())

    # Positions whose shifted time index falls past the original length read PAD.
    pad_tensor = torch.tensor(pad_value, dtype=audio_BxTxC.dtype, device=device)
    result_BxTxC = torch.where(t_idx_BxTxC >= T, pad_tensor, gathered_BxTxC)

    return result_BxTxC
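
# Round-trip sketch (hypothetical values): apply_audio_delay followed by
# revert_audio_delay recovers the original tokens on the prefix unaffected by the
# per-channel shifts.
def _example_delay_roundtrip() -> None:
    B, T, C = 1, 6, 3
    delay_pattern = [0, 1, 2]
    tokens = torch.arange(B * T * C, dtype=torch.int32).view(B, T, C)
    delayed = apply_audio_delay(
        tokens, pad_value=-1, bos_value=-2,
        precomp=build_delay_indices(B, T, C, delay_pattern),
    )
    reverted = revert_audio_delay(
        delayed, pad_value=-1,
        precomp=build_revert_indices(B, T, C, delay_pattern), T=T,
    )
    # The last max(delay_pattern) steps read clamped (repeated) values, so compare
    # only the prefix that every channel can reconstruct exactly.
    valid = T - max(delay_pattern)
    assert torch.equal(reverted[:, :valid, :], tokens[:, :valid, :])
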
def decode(
    model,
    audio_codes,
):
    """
    Decodes the given codebook frames into an output audio waveform.
    """
    if len(audio_codes) != 1:
        raise ValueError(f"Expected one frame, got {len(audio_codes)}")

    try:
        # from_codes returns a tuple; its first element is the continuous latent,
        # which model.decode then maps to a waveform.
        audio_values = model.quantizer.from_codes(audio_codes)
        audio_values = model.decode(audio_values[0])
        return audio_values
    except Exception as e:
        print(f"Error in decode method: {str(e)}")
        raise
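
# Usage sketch (hypothetical driver; assumes a descript-audio-codec-style model that
# exposes .quantizer.from_codes and .decode as used above). codes_1xCxT is assumed to
# be a [1, C, T_frames] tensor of codebook indices for a single frame batch.
def _example_decode(model, codes_1xCxT: torch.Tensor) -> torch.Tensor:
    with torch.no_grad():  # inference only; no gradients needed
        waveform = decode(model, codes_1xCxT)
    return waveform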