import math
import typing as tp
from dataclasses import dataclass
import torch
import torch.nn as nn
import torch.nn.functional as F
from dac.nn.quantize import ResidualVectorQuantize
from torch.nn.utils.parametrizations import weight_norm
from torch.nn.utils.parametrize import remove_parametrizations
def unpad1d(x: torch.Tensor, paddings: tp.Tuple[int, int]):
"""Remove padding from x, handling properly zero padding. Only for 1d!"""
padding_left, padding_right = paddings
assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
assert (padding_left + padding_right) <= x.shape[-1]
end = x.shape[-1] - padding_right
return x[..., padding_left:end]
def get_extra_padding_for_conv1d(
x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0
) -> int:
"""See `pad_for_conv1d`."""
length = x.shape[-1]
n_frames = (length - kernel_size + padding_total) / stride + 1
ideal_length = (math.ceil(n_frames) - 1) * stride + (kernel_size - padding_total)
return ideal_length - length
def pad1d(
x: torch.Tensor,
paddings: tp.Tuple[int, int],
mode: str = "zeros",
value: float = 0.0,
):
"""Tiny wrapper around F.pad, just to allow for reflect padding on small input.
If this is the case, we insert extra 0 padding to the right
before the reflection happen.
"""
length = x.shape[-1]
padding_left, padding_right = paddings
assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
if mode == "reflect":
max_pad = max(padding_left, padding_right)
extra_pad = 0
if length <= max_pad:
extra_pad = max_pad - length + 1
x = F.pad(x, (0, extra_pad))
padded = F.pad(x, paddings, mode, value)
end = padded.shape[-1] - extra_pad
return padded[..., :end]
else:
return F.pad(x, paddings, mode, value)
class CausalConvNet(nn.Module):
def __init__(
self,
in_channels,
out_channels,
kernel_size,
dilation=1,
stride=1,
groups=1,
padding=None,
):
        super().__init__()
self.conv = nn.Conv1d(
in_channels,
out_channels,
kernel_size,
stride=stride,
dilation=dilation,
groups=groups,
)
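        # Effective (dilated) kernel size; forward() left-pads by
        # (effective_kernel_size - stride) so the convolution stays causal.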
self.stride = stride
self.kernel_size = (kernel_size - 1) * dilation + 1
self.dilation = dilation
self.padding = self.kernel_size - self.stride
def forward(self, x):
pad = self.padding
extra_padding = get_extra_padding_for_conv1d(
x, self.kernel_size, self.stride, pad
)
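        # All of the nominal padding goes on the left (causal); extra_padding on the
        # right only completes the final window for the given stride.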
x = pad1d(x, (pad, extra_padding), mode="constant", value=0)
return self.conv(x).contiguous()
def weight_norm(self, name="weight", dim=0):
self.conv = weight_norm(self.conv, name=name, dim=dim)
return self
def remove_weight_norm(self):
        self.conv = remove_parametrizations(self.conv, "weight")
return self
class CausalTransConvNet(nn.Module):
def __init__(
self, in_channels, out_channels, kernel_size, dilation=1, stride=1, padding=None
):
        super().__init__()
self.conv = nn.ConvTranspose1d(
in_channels, out_channels, kernel_size, stride=stride, dilation=dilation
)
self.stride = stride
self.kernel_size = kernel_size
def forward(self, x):
x = self.conv(x)
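        # The transposed conv adds (kernel_size - stride) extra samples; trim them from
        # the right so the layer stays causal (math.ceil is a no-op for the int pad).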
pad = self.kernel_size - self.stride
padding_right = math.ceil(pad)
padding_left = pad - padding_right
x = unpad1d(x, (padding_left, padding_right))
return x.contiguous()
def weight_norm(self, name="weight", dim=0):
self.conv = weight_norm(self.conv, name=name, dim=dim)
return self
def remove_weight_norm(self):
        self.conv = remove_parametrizations(self.conv, "weight")
return self
# ConvNeXt Block copied from https://github.com/fishaudio/fish-diffusion/blob/main/fish_diffusion/modules/convnext.py
class ConvNeXtBlock(nn.Module):
r"""ConvNeXt Block. There are two equivalent implementations:
(1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W)
(2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back
We use (2) as we find it slightly faster in PyTorch
Args:
dim (int): Number of input channels.
drop_path (float): Stochastic depth rate. Default: 0.0
layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.0.
kernel_size (int): Kernel size for depthwise conv. Default: 7.
dilation (int): Dilation for depthwise conv. Default: 1.
""" # noqa: E501
def __init__(
self,
dim: int,
layer_scale_init_value: float = 1e-6,
mlp_ratio: float = 4.0,
kernel_size: int = 7,
dilation: int = 1,
):
super().__init__()
convnet_type = CausalConvNet
self.dwconv = convnet_type(
dim,
dim,
kernel_size=kernel_size,
# padding=int(dilation * (kernel_size - 1) / 2),
groups=dim,
dilation=dilation,
) # depthwise conv
self.norm = nn.LayerNorm(dim, eps=1e-6)
self.pwconv1 = nn.Linear(
dim, int(mlp_ratio * dim)
) # pointwise/1x1 convs, implemented with linear layers
self.act = nn.GELU()
self.pwconv2 = nn.Linear(int(mlp_ratio * dim), dim)
        self.gamma = (
            nn.Parameter(layer_scale_init_value * torch.ones(dim), requires_grad=True)
            if layer_scale_init_value > 0
            else None
        )
def forward(self, x, apply_residual: bool = True):
input = x
x = self.dwconv(x)
x = x.permute(0, 2, 1) # (N, C, L) -> (N, L, C)
x = self.norm(x)
x = self.pwconv1(x)
x = self.act(x)
x = self.pwconv2(x)
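        # LayerScale: learnable per-channel scaling, initialized small (1e-6) to
        # stabilize training of deep ConvNeXt stacks.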
if self.gamma is not None:
x = self.gamma * x
x = x.permute(0, 2, 1) # (N, L, C) -> (N, C, L)
if apply_residual:
x = input + x
return x
@dataclass
class VQResult:
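    """Container returned by DownsampleResidualVectorQuantize.forward().

    z: quantized latent after post-processing and upsampling, shape (B, D, T).
    codes: integer codebook indices, shape (B, 1 + n_codebooks, T'); the semantic
        codebook comes first, followed by the residual codebooks.
    latents: continuous pre-projection latents concatenated over codebooks.
    codebook_loss / commitment_loss: losses summed over the semantic and residual
        quantizers.
    semantic_distill_z: optional target for semantic distillation (training only).
    """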
z: torch.Tensor
codes: torch.Tensor
latents: torch.Tensor
codebook_loss: torch.Tensor
commitment_loss: torch.Tensor
semantic_distill_z: torch.Tensor | None = None
class DownsampleResidualVectorQuantize(nn.Module):
def __init__(
self,
input_dim: int = 1024,
n_codebooks: int = 9,
codebook_dim: int = 8,
quantizer_dropout: float = 0.5,
codebook_size: int = 1024,
semantic_codebook_size: int = 4096,
        downsample_factor: tuple[int, ...] = (2, 2),
        downsample_dims: tuple[int, ...] | None = None,
pre_module: nn.Module | None = None,
post_module: nn.Module | None = None,
semantic_predictor_module: nn.Module | None = None,
):
super().__init__()
if downsample_dims is None:
downsample_dims = [input_dim for _ in range(len(downsample_factor))]
all_dims = (input_dim,) + tuple(downsample_dims)
self.semantic_quantizer = ResidualVectorQuantize(
input_dim=input_dim,
n_codebooks=1,
codebook_size=semantic_codebook_size,
codebook_dim=codebook_dim,
quantizer_dropout=0.0,
)
self.quantizer = ResidualVectorQuantize(
input_dim=input_dim,
n_codebooks=n_codebooks,
codebook_size=codebook_size,
codebook_dim=codebook_dim,
quantizer_dropout=quantizer_dropout,
)
self.downsample_factor = downsample_factor
self.downsample_dims = downsample_dims
convnet_type = CausalConvNet
transconvnet_type = CausalTransConvNet
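        # Downsample: each stage is a strided causal conv (kernel_size = stride = factor)
        # followed by a ConvNeXt block; upsample mirrors the stages in reverse order
        # with causal transposed convs.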
self.downsample = nn.Sequential(
*[
nn.Sequential(
convnet_type(
all_dims[idx],
all_dims[idx + 1],
kernel_size=factor,
stride=factor,
),
ConvNeXtBlock(dim=all_dims[idx + 1]),
)
for idx, factor in enumerate(downsample_factor)
]
)
self.upsample = nn.Sequential(
*[
nn.Sequential(
transconvnet_type(
all_dims[idx + 1],
all_dims[idx],
kernel_size=factor,
stride=factor,
),
ConvNeXtBlock(dim=all_dims[idx]),
)
for idx, factor in reversed(list(enumerate(downsample_factor)))
]
)
self.apply(self._init_weights)
self.pre_module = (
pre_module if pre_module is not None else nn.Identity()
) # leave for transformer, LSTM or Mamba or something else
self.post_module = post_module if post_module is not None else nn.Identity()
self.semantic_predictor_module = (
semantic_predictor_module
if semantic_predictor_module is not None
else nn.Identity()
)
def _init_weights(self, m):
if isinstance(m, (nn.Conv1d, nn.Linear)):
nn.init.trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
    def forward(
        self,
        z,
        n_quantizers: int | None = None,
        semantic_len: torch.Tensor | None = None,
        **kwargs,
    ):
# z: (B, D, T)
original_shape = z.shape
if semantic_len is None:
semantic_len = torch.LongTensor([z.shape[-1]])
z = self.downsample(z)
        z = self.pre_module(z)  # still (B, D, T); hook for a transformer, LSTM, Mamba, etc.
(
semantic_z,
semantic_codes,
semantic_latents,
semantic_commitment_loss,
semantic_codebook_loss,
) = self.semantic_quantizer(z)
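        # Quantize with the single semantic codebook first; the remaining residual is
        # handled by the multi-codebook RVQ below.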
residual_z = z - semantic_z
residual_z, codes, latents, commitment_loss, codebook_loss = self.quantizer(
residual_z, n_quantizers=n_quantizers
)
z = semantic_z + residual_z
commitment_loss = commitment_loss + semantic_commitment_loss
codebook_loss = codebook_loss + semantic_codebook_loss
codes = torch.cat([semantic_codes, codes], dim=1)
latents = torch.cat([semantic_latents, latents], dim=1)
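        # Codebook ordering: index 0 is the semantic codebook, 1..n_codebooks are the
        # residual codebooks; decode() relies on this layout.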
z = self.post_module(z)
z = self.upsample(z)
# z: (B, D, T)
# semantic distillation (disabled here since only used in training)
# semantic_distill_z = self.semantic_predictor_module(semantic_z, semantic_len).mT # wav2vec target is B, T, D
# Pad or crop z to match original shape
diff = original_shape[-1] - z.shape[-1]
right = 0
left = abs(diff) - right
if diff > 0:
z = F.pad(z, (left, right))
elif diff < 0:
z = z[..., left:]
results = VQResult(
z=z,
codes=codes,
latents=latents,
commitment_loss=commitment_loss,
codebook_loss=codebook_loss,
)
return results
# def encode(self, z):
# z = self.downsample(z)
# z = self.pre_module(z)
# _, indices, _, _, _ = self.quantizer(z.mT)
# indices = rearrange(indices, "g b l r -> b (g r) l")
# return indices
#
def decode(self, indices: torch.Tensor):
# indices = rearrange(indices, "b (g r) l -> g b l r", g=self.residual_fsq.groups)
# print(f"indices: {indices.shape}, semantic_quantizer.codebook_size: {self.semantic_quantizer.codebook_size}, quantizer.codebook_size: {self.quantizer.codebook_size}, semantic min: {indices[:, 0].min()}, max: {indices[:, 0].max()}, quantizer min: {indices[:, 1:].min()}, max: {indices[:, 1:].max()}")
new_indices = torch.zeros_like(indices)
new_indices[:, 0] = torch.clamp(
indices[:, 0], max=self.semantic_quantizer.codebook_size - 1
)
new_indices[:, 1:] = torch.clamp(
indices[:, 1:], max=self.quantizer.codebook_size - 1
)
z_q_semantic = self.semantic_quantizer.from_codes(new_indices[:, :1])[0]
z_q_residual = self.quantizer.from_codes(new_indices[:, 1:])[0]
z_q = z_q_semantic + z_q_residual
z_q = self.post_module(z_q)
z_q = self.upsample(z_q)
return z_q
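    # Round-trip sketch (see the __main__ example below): given result = rvq(x),
    #   z_hat = rvq.decode(result.codes)
    # reconstructs the latent from codes alone. Note that decode() skips the final
    # pad/crop done in forward(), so z_hat may differ from result.z by a few frames.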
# def from_latents(self, latents: torch.Tensor):
# z_q, z_p, codes = super().from_latents(latents)
# z_q = self.upsample(z_q)
# return z_q, z_p, codes
if __name__ == "__main__":
rvq = DownsampleResidualVectorQuantize(
input_dim=512,
n_codebooks=8,
codebook_dim=8,
codebook_size=1024,
quantizer_dropout=0.5,
downsample_factor=[2, 2],
)
rvq.eval()
x = torch.randn(2, 512, 442)
result = rvq(x)
print(rvq)
print(result.latents.shape, result.codes.shape, result.z.shape)
# y = rvq.from_codes(result.codes)
# print(y[0].shape)
# y = rvq.from_latents(
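    # Causality check: since every conv in the stack is causal, the first 40 output
    # frames must be identical whether or not the later input frames are present.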
result1 = rvq(x[:, :, :40])
print(result1.latents.shape, result1.codes.shape, result1.z.shape)
assert torch.allclose(result.z[:, :, :40], result1.z, atol=1e-8)
print("Success")