import math
import typing as tp
from dataclasses import dataclass

import torch
import torch.nn as nn
import torch.nn.functional as F
from dac.nn.quantize import ResidualVectorQuantize
from torch.nn.utils.parametrizations import weight_norm
from torch.nn.utils.parametrize import remove_parametrizations


def unpad1d(x: torch.Tensor, paddings: tp.Tuple[int, int]):
    """Remove padding from x, properly handling zero padding. Only for 1D!"""
    padding_left, padding_right = paddings
    assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
    assert (padding_left + padding_right) <= x.shape[-1]
    end = x.shape[-1] - padding_right
    return x[..., padding_left:end]


def get_extra_padding_for_conv1d(
    x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0
) -> int:
    """See `pad_for_conv1d`."""
    length = x.shape[-1]
    n_frames = (length - kernel_size + padding_total) / stride + 1
    ideal_length = (math.ceil(n_frames) - 1) * stride + (kernel_size - padding_total)
    return ideal_length - length
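
# Worked example for get_extra_padding_for_conv1d: with length=9, kernel_size=4,
# stride=2 and padding_total=2, n_frames = (9 - 4 + 2) / 2 + 1 = 4.5, so
# ideal_length = (ceil(4.5) - 1) * 2 + (4 - 2) = 10 and one extra sample of
# padding is needed for the final frame to be complete.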


def pad1d(
    x: torch.Tensor,
    paddings: tp.Tuple[int, int],
    mode: str = "zeros",
    value: float = 0.0,
):
    """Tiny wrapper around F.pad, just to allow reflect padding on small inputs.
    In that case, we insert extra zero padding to the right before the reflection
    happens.
    """
    length = x.shape[-1]
    padding_left, padding_right = paddings
    assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
    if mode == "reflect":
        max_pad = max(padding_left, padding_right)
        extra_pad = 0
        if length <= max_pad:
            extra_pad = max_pad - length + 1
            x = F.pad(x, (0, extra_pad))
        padded = F.pad(x, paddings, mode, value)
        end = padded.shape[-1] - extra_pad
        return padded[..., :end]
    else:
        return F.pad(x, paddings, mode, value)


class CausalConvNet(nn.Module):
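    """Conv1d with causal (left) padding.

    The input is left-padded by (kernel_size - 1) * dilation + 1 - stride so each
    output frame depends only on current and past samples; a little extra right
    padding may be added so the last frame covers a full window
    (see get_extra_padding_for_conv1d).
    """
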
    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        dilation=1,
        stride=1,
        groups=1,
        padding=None,
    ):
        super().__init__()
        self.conv = nn.Conv1d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            dilation=dilation,
            groups=groups,
        )
        self.stride = stride
        self.kernel_size = (kernel_size - 1) * dilation + 1
        self.dilation = dilation
        self.padding = self.kernel_size - self.stride

    def forward(self, x):
        pad = self.padding
        extra_padding = get_extra_padding_for_conv1d(
            x, self.kernel_size, self.stride, pad
        )
        x = pad1d(x, (pad, extra_padding), mode="constant", value=0)
        return self.conv(x).contiguous()

    def weight_norm(self, name="weight", dim=0):
        self.conv = weight_norm(self.conv, name=name, dim=dim)
        return self

    def remove_weight_norm(self):
        self.conv = remove_parametrizations(self.conv, "weight")
        return self


class CausalTransConvNet(nn.Module):
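    """Transposed Conv1d made causal by trimming the trailing (kernel_size - stride)
    samples, so the output length is exactly stride * input_length."""
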
    def __init__(
        self, in_channels, out_channels, kernel_size, dilation=1, stride=1, padding=None
    ):
        super().__init__()
        self.conv = nn.ConvTranspose1d(
            in_channels, out_channels, kernel_size, stride=stride, dilation=dilation
        )
        self.stride = stride
        self.kernel_size = kernel_size

    def forward(self, x):
        x = self.conv(x)
        pad = self.kernel_size - self.stride
        padding_right = math.ceil(pad)
        padding_left = pad - padding_right
        x = unpad1d(x, (padding_left, padding_right))
        return x.contiguous()

    def weight_norm(self, name="weight", dim=0):
        self.conv = weight_norm(self.conv, name=name, dim=dim)
        return self

    def remove_weight_norm(self):
        self.conv = remove_parametrizations(self.conv, "weight")
        return self


# ConvNeXt Block copied from https://github.com/fishaudio/fish-diffusion/blob/main/fish_diffusion/modules/convnext.py
class ConvNeXtBlock(nn.Module):
    r"""ConvNeXt Block. There are two equivalent implementations:
    (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W)
    (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back
    We use (2) as we find it slightly faster in PyTorch.

    Args:
        dim (int): Number of input channels.
        layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.0.
        kernel_size (int): Kernel size for depthwise conv. Default: 7.
        dilation (int): Dilation for depthwise conv. Default: 1.
    """  # noqa: E501

    def __init__(
        self,
        dim: int,
        layer_scale_init_value: float = 1e-6,
        mlp_ratio: float = 4.0,
        kernel_size: int = 7,
        dilation: int = 1,
    ):
        super().__init__()
        convnet_type = CausalConvNet
        self.dwconv = convnet_type(
            dim,
            dim,
            kernel_size=kernel_size,
            # padding=int(dilation * (kernel_size - 1) / 2),
            groups=dim,
            dilation=dilation,
        )  # depthwise conv
        self.norm = nn.LayerNorm(dim, eps=1e-6)
        self.pwconv1 = nn.Linear(
            dim, int(mlp_ratio * dim)
        )  # pointwise/1x1 convs, implemented with linear layers
        self.act = nn.GELU()
        self.pwconv2 = nn.Linear(int(mlp_ratio * dim), dim)
        self.gamma = (
            nn.Parameter(layer_scale_init_value * torch.ones(dim), requires_grad=True)
            if layer_scale_init_value > 0
            else None
        )

    def forward(self, x, apply_residual: bool = True):
        input = x
        x = self.dwconv(x)
        x = x.permute(0, 2, 1)  # (N, C, L) -> (N, L, C)
        x = self.norm(x)
        x = self.pwconv1(x)
        x = self.act(x)
        x = self.pwconv2(x)
        if self.gamma is not None:
            x = self.gamma * x
        x = x.permute(0, 2, 1)  # (N, L, C) -> (N, C, L)
        if apply_residual:
            x = input + x
        return x


@dataclass
class VQResult:
    z: torch.Tensor
    codes: torch.Tensor
    latents: torch.Tensor
    codebook_loss: torch.Tensor
    commitment_loss: torch.Tensor
    semantic_distill_z: torch.Tensor | None = None


class DownsampleResidualVectorQuantize(nn.Module):
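    """Residual vector quantization with temporal downsampling.

    The latent sequence (B, D, T) is downsampled by causal strided convolutions,
    quantized by a single "semantic" codebook plus a residual vector quantizer on
    the remainder, then upsampled back with causal transposed convolutions. The
    output of forward() is padded or cropped to match the input length.
    """
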
    def __init__(
        self,
        input_dim: int = 1024,
        n_codebooks: int = 9,
        codebook_dim: int = 8,
        quantizer_dropout: float = 0.5,
        codebook_size: int = 1024,
        semantic_codebook_size: int = 4096,
        downsample_factor: tuple[int, ...] = (2, 2),
        downsample_dims: tuple[int, ...] | None = None,
        pre_module: nn.Module | None = None,
        post_module: nn.Module | None = None,
        semantic_predictor_module: nn.Module | None = None,
    ):
        super().__init__()
        if downsample_dims is None:
            downsample_dims = [input_dim for _ in range(len(downsample_factor))]
        all_dims = (input_dim,) + tuple(downsample_dims)
        self.semantic_quantizer = ResidualVectorQuantize(
            input_dim=input_dim,
            n_codebooks=1,
            codebook_size=semantic_codebook_size,
            codebook_dim=codebook_dim,
            quantizer_dropout=0.0,
        )
        self.quantizer = ResidualVectorQuantize(
            input_dim=input_dim,
            n_codebooks=n_codebooks,
            codebook_size=codebook_size,
            codebook_dim=codebook_dim,
            quantizer_dropout=quantizer_dropout,
        )
        self.downsample_factor = downsample_factor
        self.downsample_dims = downsample_dims
        convnet_type = CausalConvNet
        transconvnet_type = CausalTransConvNet
        self.downsample = nn.Sequential(
            *[
                nn.Sequential(
                    convnet_type(
                        all_dims[idx],
                        all_dims[idx + 1],
                        kernel_size=factor,
                        stride=factor,
                    ),
                    ConvNeXtBlock(dim=all_dims[idx + 1]),
                )
                for idx, factor in enumerate(downsample_factor)
            ]
        )
        self.upsample = nn.Sequential(
            *[
                nn.Sequential(
                    transconvnet_type(
                        all_dims[idx + 1],
                        all_dims[idx],
                        kernel_size=factor,
                        stride=factor,
                    ),
                    ConvNeXtBlock(dim=all_dims[idx]),
                )
                for idx, factor in reversed(list(enumerate(downsample_factor)))
            ]
        )
        self.apply(self._init_weights)
        self.pre_module = (
            pre_module if pre_module is not None else nn.Identity()
        )  # left as a hook for a Transformer, LSTM, Mamba or similar module
        self.post_module = post_module if post_module is not None else nn.Identity()
        self.semantic_predictor_module = (
            semantic_predictor_module
            if semantic_predictor_module is not None
            else nn.Identity()
        )

    def _init_weights(self, m):
        if isinstance(m, (nn.Conv1d, nn.Linear)):
            nn.init.trunc_normal_(m.weight, std=0.02)
            nn.init.constant_(m.bias, 0)

    def forward(
        self,
        z,
        n_quantizers: int | None = None,
        semantic_len: torch.Tensor | None = None,
        **kwargs,
    ):
        # z: (B, D, T)
        original_shape = z.shape
        if semantic_len is None:
            semantic_len = torch.LongTensor([z.shape[-1]])
        z = self.downsample(z)
        z = self.pre_module(z)  # B, T, D
        (
            semantic_z,
            semantic_codes,
            semantic_latents,
            semantic_commitment_loss,
            semantic_codebook_loss,
        ) = self.semantic_quantizer(z)
        residual_z = z - semantic_z
        residual_z, codes, latents, commitment_loss, codebook_loss = self.quantizer(
            residual_z, n_quantizers=n_quantizers
        )
        z = semantic_z + residual_z
        commitment_loss = commitment_loss + semantic_commitment_loss
        codebook_loss = codebook_loss + semantic_codebook_loss
        codes = torch.cat([semantic_codes, codes], dim=1)
        latents = torch.cat([semantic_latents, latents], dim=1)
        z = self.post_module(z)
        z = self.upsample(z)
        # z: (B, D, T)
        # semantic distillation (disabled here since only used in training)
        # semantic_distill_z = self.semantic_predictor_module(semantic_z, semantic_len).mT  # wav2vec target is B, T, D
        # Pad or crop z to match the original shape
        diff = original_shape[-1] - z.shape[-1]
        right = 0
        left = abs(diff) - right
        if diff > 0:
            z = F.pad(z, (left, right))
        elif diff < 0:
            z = z[..., left:]
        results = VQResult(
            z=z,
            codes=codes,
            latents=latents,
            commitment_loss=commitment_loss,
            codebook_loss=codebook_loss,
        )
        return results

    # def encode(self, z):
    #     z = self.downsample(z)
    #     z = self.pre_module(z)
    #     _, indices, _, _, _ = self.quantizer(z.mT)
    #     indices = rearrange(indices, "g b l r -> b (g r) l")
    #     return indices
    #

    def decode(self, indices: torch.Tensor):
        # indices = rearrange(indices, "b (g r) l -> g b l r", g=self.residual_fsq.groups)
        # print(f"indices: {indices.shape}, semantic_quantizer.codebook_size: {self.semantic_quantizer.codebook_size}, quantizer.codebook_size: {self.quantizer.codebook_size}, semantic min: {indices[:, 0].min()}, max: {indices[:, 0].max()}, quantizer min: {indices[:, 1:].min()}, max: {indices[:, 1:].max()}")
        new_indices = torch.zeros_like(indices)
        new_indices[:, 0] = torch.clamp(
            indices[:, 0], max=self.semantic_quantizer.codebook_size - 1
        )
        new_indices[:, 1:] = torch.clamp(
            indices[:, 1:], max=self.quantizer.codebook_size - 1
        )
        z_q_semantic = self.semantic_quantizer.from_codes(new_indices[:, :1])[0]
        z_q_residual = self.quantizer.from_codes(new_indices[:, 1:])[0]
        z_q = z_q_semantic + z_q_residual
        z_q = self.post_module(z_q)
        z_q = self.upsample(z_q)
        return z_q

    # def from_latents(self, latents: torch.Tensor):
    #     z_q, z_p, codes = super().from_latents(latents)
    #     z_q = self.upsample(z_q)
    #     return z_q, z_p, codes


if __name__ == "__main__":
    rvq = DownsampleResidualVectorQuantize(
        input_dim=512,
        n_codebooks=8,
        codebook_dim=8,
        codebook_size=1024,
        quantizer_dropout=0.5,
        downsample_factor=[2, 2],
    )
    rvq.eval()
    x = torch.randn(2, 512, 442)
    result = rvq(x)
    print(rvq)
    print(result.latents.shape, result.codes.shape, result.z.shape)
    # y = rvq.from_codes(result.codes)
    # print(y[0].shape)
    # y = rvq.from_latents(
    result1 = rvq(x[:, :, :40])
    print(result1.latents.shape, result1.codes.shape, result1.z.shape)
    assert torch.allclose(result.z[:, :, :40], result1.z, atol=1e-8)
    print("Success")