Spaces:
Running
Running
import math | |
from typing import List | |
from typing import Union | |
import numpy as np | |
import torch | |
from audiotools import AudioSignal | |
from audiotools.ml import BaseModel | |
from torch import nn | |
from .base import CodecMixin | |
from dac.nn.layers import Snake1d | |
from dac.nn.layers import WNConv1d | |
from dac.nn.layers import WNConvTranspose1d | |
from dac.nn.quantize import ResidualVectorQuantize | |
from .encodec import SConv1d, SConvTranspose1d, SLSTM | |
def init_weights(m):
    """Initialise 1-D convolution layers in place (for use with ``nn.Module.apply``).

    Every ``nn.Conv1d`` (including subclasses) gets a truncated-normal weight
    with std 0.02 and, if it has one, a zero bias. All other module types,
    including ``nn.ConvTranspose1d`` (not an ``nn.Conv1d`` subclass), are left
    untouched.

    Parameters
    ----------
    m : nn.Module
        Module visited by ``apply``.
    """
    if not isinstance(m, nn.Conv1d):
        return
    # NOTE(review): for weight-normalised convs the effective weight may be
    # recomputed from separate magnitude/direction parameters on forward, in
    # which case this init might not persist — confirm against the conv wrapper.
    nn.init.trunc_normal_(m.weight, std=0.02)
    # Convolutions built with ``bias=False`` have ``m.bias is None``;
    # ``nn.init.constant_`` would raise on it.
    if m.bias is not None:
        nn.init.constant_(m.bias, 0)
class ResidualUnit(nn.Module):
    """Dilated residual unit with a skip connection.

    ``self.block`` applies Snake1d -> SConv1d(kernel 7, given dilation) ->
    Snake1d -> SConv1d(kernel 1), all with ``dim`` channels. In ``forward``
    the input is centre-cropped to the block output's length (when the block
    made it shorter) before being added back.
    """

    def __init__(self, dim: int = 16, dilation: int = 1, causal: bool = False):
        super().__init__()
        # SConv1d is used for both causal and non-causal modes.
        kernel_size = 7
        # Symmetric padding for a dilated kernel of 7. Whether SConv1d honours
        # an explicit ``padding`` is defined in .encodec (not visible here).
        padding = (kernel_size - 1) * dilation // 2
        self.block = nn.Sequential(
            Snake1d(dim),
            SConv1d(
                dim,
                dim,
                kernel_size=kernel_size,
                dilation=dilation,
                padding=padding,
                causal=causal,
                norm='weight_norm',
            ),
            Snake1d(dim),
            SConv1d(dim, dim, kernel_size=1, causal=causal, norm='weight_norm'),
        )

    def forward(self, x):
        out = self.block(x)
        trim = (x.shape[-1] - out.shape[-1]) // 2
        # Crop the skip path equally on both sides so lengths match.
        skip = x[..., trim:-trim] if trim > 0 else x
        return skip + out
class EncoderBlock(nn.Module):
    """Three dilated residual units followed by a strided SConv1d.

    The input has ``dim // 2`` channels. The residual units use dilations
    1, 3 and 9, then Snake1d and an SConv1d (kernel ``2 * stride``,
    ``stride=stride``) map ``dim // 2`` -> ``dim`` channels.
    """

    def __init__(self, dim: int = 16, stride: int = 1, causal: bool = False):
        super().__init__()
        half = dim // 2
        # Layer order matters: it fixes the nn.Sequential indices in state_dict.
        layers = [ResidualUnit(half, dilation=d, causal=causal) for d in (1, 3, 9)]
        layers.append(Snake1d(half))
        layers.append(
            SConv1d(
                half,
                dim,
                kernel_size=2 * stride,
                stride=stride,
                padding=math.ceil(stride / 2),
                causal=causal,
                norm='weight_norm',
            )
        )
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        out = self.block(x)
        return out
class Encoder(nn.Module):
    """Convolutional audio encoder.

    ``self.block`` is an ``nn.Sequential`` containing:

    - an input SConv1d (1 -> ``d_model`` channels);
    - one ``EncoderBlock`` per entry of ``strides``, each doubling the channels;
    - an optional ``SLSTM``;
    - Snake1d plus an SConv1d projecting to ``d_latent`` channels.

    Parameters
    ----------
    d_model : int
        Channel count after the input conv.
    strides : list
        Stride of each EncoderBlock.
    d_latent : int
        Output channel count.
    causal : bool
        Forwarded to every SConv1d and sub-block.
    lstm : int
        Number of SLSTM layers; a falsy value (e.g. 0) omits the LSTM.
    """

    def __init__(
        self,
        d_model: int = 64,
        strides: list = [2, 4, 8, 8],
        d_latent: int = 64,
        causal: bool = False,
        lstm: int = 2,
    ):
        super().__init__()
        # SConv1d is used for both causal and non-causal modes.
        layers = [
            SConv1d(1, d_model, kernel_size=7, padding=3, causal=causal, norm='weight_norm')
        ]

        channels = d_model
        for stride in strides:
            channels *= 2
            layers.append(EncoderBlock(channels, stride=stride, causal=causal))

        self.use_lstm = lstm
        if lstm:
            layers.append(SLSTM(channels, lstm))

        layers.extend(
            [
                Snake1d(channels),
                SConv1d(
                    channels, d_latent, kernel_size=3, padding=1, causal=causal, norm='weight_norm'
                ),
            ]
        )

        self.block = nn.Sequential(*layers)
        # Channel width before the final projection.
        self.enc_dim = channels

    def forward(self, x):
        out = self.block(x)
        return out

    def reset_cache(self):
        """Call ``reset_cache()`` on every SConv1d and SLSTM inside ``self.block``.

        The search is depth-first. Once a matching module is found, its own
        children are not visited.
        """

        def _visit(module):
            if isinstance(module, (SConv1d, SLSTM)):
                module.reset_cache()
            else:
                for child in module.children():
                    _visit(child)

        _visit(self.block)
class DecoderBlock(nn.Module):
    """Transposed-conv upsampling followed by three dilated residual units.

    Snake1d and an SConvTranspose1d (kernel ``2 * stride``, ``stride=stride``)
    map ``input_dim`` -> ``output_dim`` channels. Then come residual units
    with dilations 1, 3 and 9.
    """

    def __init__(
        self, input_dim: int = 16, output_dim: int = 8, stride: int = 1, causal: bool = False
    ):
        super().__init__()
        # SConvTranspose1d is used for both causal and non-causal modes.
        upsample = [
            Snake1d(input_dim),
            SConvTranspose1d(
                input_dim,
                output_dim,
                kernel_size=2 * stride,
                stride=stride,
                padding=math.ceil(stride / 2),
                causal=causal,
                norm='weight_norm',
            ),
        ]
        residuals = [ResidualUnit(output_dim, dilation=d, causal=causal) for d in (1, 3, 9)]
        self.block = nn.Sequential(*upsample, *residuals)

    def forward(self, x):
        out = self.block(x)
        return out
class Decoder(nn.Module):
    """Convolutional audio decoder.

    ``self.model`` is an ``nn.Sequential`` containing:

    - an input SConv1d (``input_channel`` -> ``channels``);
    - an optional ``SLSTM``;
    - one ``DecoderBlock`` per entry of ``rates``, each halving the channels;
    - Snake1d, an SConv1d to ``d_out`` channels, and ``nn.Tanh``.

    Parameters
    ----------
    input_channel : int
        Channel count of the latent input.
    channels : int
        Channel count after the input conv.
    rates : iterable of int
        Stride of each DecoderBlock; may be empty.
    d_out : int
        Output channel count.
    causal : bool
        Forwarded to every conv and sub-block.
    lstm : int
        Number of SLSTM layers; a falsy value (e.g. 0) omits the LSTM.
    """

    @staticmethod
    def _stage_dims(channels, n_stages):
        """Return ``(stage_dims, final_dim)`` for ``n_stages`` upsampling stages.

        Stage ``i`` maps ``channels // 2**i`` -> ``channels // 2**(i + 1)``
        channels. ``final_dim`` is the input width of the output conv:
        ``channels // 2**n_stages``, which is ``channels`` when there are no
        stages.
        """
        stage_dims = [
            (channels // 2**i, channels // 2 ** (i + 1)) for i in range(n_stages)
        ]
        return stage_dims, channels // 2**n_stages

    def __init__(
        self,
        input_channel,
        channels,
        rates,
        d_out: int = 1,
        causal: bool = False,
        lstm: int = 2,
    ):
        super().__init__()
        rates = list(rates)
        # Computed outside the loop so an empty ``rates`` no longer leaves the
        # final layer's width undefined (previously an UnboundLocalError).
        stage_dims, final_dim = self._stage_dims(channels, len(rates))

        # SConv1d is used for both causal and non-causal modes.
        layers = [
            SConv1d(input_channel, channels, kernel_size=7, padding=3, causal=causal, norm='weight_norm')
        ]
        if lstm:
            layers += [SLSTM(channels, num_layers=lstm)]

        # Upsampling + residual blocks.
        for (input_dim, output_dim), stride in zip(stage_dims, rates):
            layers += [DecoderBlock(input_dim, output_dim, stride, causal=causal)]

        # Output projection.
        layers += [
            Snake1d(final_dim),
            SConv1d(final_dim, d_out, kernel_size=7, padding=3, causal=causal, norm='weight_norm'),
            nn.Tanh(),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)
class DAC(BaseModel, CodecMixin):
    """Neural audio codec: ``Encoder`` -> ``ResidualVectorQuantize`` -> ``Decoder``.

    Parameters
    ----------
    encoder_dim : int
        Base channel count of the encoder.
    encoder_rates : List[int]
        Encoder strides; their product is ``hop_length``.
    latent_dim : int, optional
        Latent channel count. Defaults to ``encoder_dim * 2**len(encoder_rates)``.
    decoder_dim : int
        Channel count after the decoder's input conv.
    decoder_rates : List[int]
        Decoder strides.
    n_codebooks, codebook_size, codebook_dim, quantizer_dropout
        Forwarded to ``ResidualVectorQuantize``.
    sample_rate : int
        Sample rate (Hz) that input audio must have.
    lstm : int
        Number of SLSTM layers in encoder and decoder (a falsy value omits them).
    causal : bool
        Forwarded to encoder and decoder.
    """

    def __init__(
        self,
        encoder_dim: int = 64,
        encoder_rates: List[int] = [2, 4, 8, 8],
        latent_dim: int = None,
        decoder_dim: int = 1536,
        decoder_rates: List[int] = [8, 8, 4, 2],
        n_codebooks: int = 9,
        codebook_size: int = 1024,
        codebook_dim: Union[int, list] = 8,
        quantizer_dropout: bool = False,
        sample_rate: int = 44100,
        lstm: int = 2,
        causal: bool = False,
    ):
        super().__init__()

        self.encoder_dim = encoder_dim
        # Store private copies. The defaults are list objects shared by every
        # call, so aliasing them onto the instance would let a mutation of
        # ``self.encoder_rates`` / ``self.decoder_rates`` leak into later DAC().
        self.encoder_rates = list(encoder_rates)
        self.decoder_dim = decoder_dim
        self.decoder_rates = list(decoder_rates)
        self.sample_rate = sample_rate

        if latent_dim is None:
            latent_dim = encoder_dim * (2 ** len(self.encoder_rates))
        self.latent_dim = latent_dim

        # Total encoder downsampling factor (samples per latent frame).
        self.hop_length = np.prod(self.encoder_rates)
        self.encoder = Encoder(
            encoder_dim, self.encoder_rates, latent_dim, causal=causal, lstm=lstm
        )

        self.n_codebooks = n_codebooks
        self.codebook_size = codebook_size
        self.codebook_dim = codebook_dim
        self.quantizer = ResidualVectorQuantize(
            input_dim=latent_dim,
            n_codebooks=n_codebooks,
            codebook_size=codebook_size,
            codebook_dim=codebook_dim,
            quantizer_dropout=quantizer_dropout,
        )

        self.decoder = Decoder(
            latent_dim,
            decoder_dim,
            self.decoder_rates,
            lstm=lstm,
            causal=causal,
        )

        self.apply(init_weights)
        # Computed after all layers exist, via CodecMixin.get_delay.
        self.delay = self.get_delay()

    def preprocess(self, audio_data, sample_rate):
        """Check the sample rate and right-pad audio to a multiple of ``hop_length``.

        Parameters
        ----------
        audio_data : Tensor[B x 1 x T]
            Audio to pad (padding is applied to the last axis).
        sample_rate : int or None
            Must equal ``self.sample_rate``; None means ``self.sample_rate``.

        Returns
        -------
        Tensor
            ``audio_data`` zero-padded on the right of its last axis.
        """
        if sample_rate is None:
            sample_rate = self.sample_rate
        # NOTE(review): ``assert`` is stripped under ``python -O``.
        assert sample_rate == self.sample_rate

        length = audio_data.shape[-1]
        right_pad = math.ceil(length / self.hop_length) * self.hop_length - length
        audio_data = nn.functional.pad(audio_data, (0, right_pad))
        return audio_data

    def encode(
        self,
        audio_data: torch.Tensor,
        n_quantizers: int = None,
    ):
        """Encode audio and quantize the resulting latents.

        Parameters
        ----------
        audio_data : Tensor[B x 1 x T]
            Audio data to encode
        n_quantizers : int, optional
            Number of quantizers to use, by default None.
            If None, all quantizers are used.

        Returns
        -------
        tuple
            ``(z, codes, latents, commitment_loss, codebook_loss)`` as returned
            by ``self.quantizer``:

            z : Tensor[B x D x T]
                Quantized continuous representation of input
            codes : Tensor[B x N x T]
                Codebook indices for each codebook
                (quantized discrete representation of input)
            latents : Tensor[B x N*D x T]
                Projected latents (continuous representation of input before quantization)
            commitment_loss : Tensor[1]
                Commitment loss to train encoder to predict vectors closer to codebook
                entries
            codebook_loss : Tensor[1]
                Codebook loss to update the codebook
        """
        z = self.encoder(audio_data)
        z, codes, latents, commitment_loss, codebook_loss = self.quantizer(
            z, n_quantizers
        )
        return z, codes, latents, commitment_loss, codebook_loss

    def decode(self, z: torch.Tensor):
        """Decode a quantized latent back to audio.

        Parameters
        ----------
        z : Tensor[B x D x T]
            Quantized continuous representation of input

        Returns
        -------
        Tensor[B x 1 x T']
            Decoded audio. The decoder ends in ``nn.Tanh``, so values lie in
            (-1, 1). No length trimming is done here.
        """
        return self.decoder(z)

    def forward(
        self,
        audio_data: torch.Tensor,
        sample_rate: int = None,
        n_quantizers: int = None,
    ):
        """Model forward pass: preprocess -> encode -> decode.

        Parameters
        ----------
        audio_data : Tensor[B x 1 x T]
            Audio data to encode
        sample_rate : int, optional
            Sample rate of audio data in Hz, by default None.
            If None, defaults to `self.sample_rate`
        n_quantizers : int, optional
            Number of quantizers to use, by default None.
            If None, all quantizers are used.

        Returns
        -------
        dict
            A dictionary with the following keys:
            "audio" : Tensor[B x 1 x T]
                Decoded audio data, trimmed to the input length T.
            "z" : Tensor[B x D x T]
                Quantized continuous representation of input
            "codes" : Tensor[B x N x T]
                Codebook indices for each codebook
                (quantized discrete representation of input)
            "latents" : Tensor[B x N*D x T]
                Projected latents (continuous representation of input before quantization)
            "vq/commitment_loss" : Tensor[1]
                Commitment loss to train encoder to predict vectors closer to codebook
                entries
            "vq/codebook_loss" : Tensor[1]
                Codebook loss to update the codebook
        """
        length = audio_data.shape[-1]
        audio_data = self.preprocess(audio_data, sample_rate)
        z, codes, latents, commitment_loss, codebook_loss = self.encode(
            audio_data, n_quantizers
        )

        x = self.decode(z)
        return {
            # Drop the right padding added by preprocess.
            "audio": x[..., :length],
            "z": z,
            "codes": codes,
            "latents": latents,
            "vq/commitment_loss": commitment_loss,
            "vq/codebook_loss": codebook_loss,
        }
if __name__ == "__main__":
    from functools import partial

    def _with_param_count(base_repr, n_params):
        # Replacement extra_repr: original text plus parameter count in millions.
        return base_repr + f" {n_params/1e6:<.3f}M params."

    model = DAC().to("cpu")

    # Annotate every module's printed repr with its parameter count.
    for _, module in model.named_modules():
        base_repr = module.extra_repr()
        n_params = sum(np.prod(t.size()) for t in module.parameters())
        setattr(module, "extra_repr", partial(_with_param_count, base_repr=base_repr, n_params=n_params))

    print(model)
    total_params = sum(np.prod(t.size()) for t in model.parameters())
    print("Total # of params: ", total_params)

    length = 88200 * 2
    x = torch.randn(1, 1, length).to(model.device)
    x.requires_grad_(True)
    x.retain_grad()

    # Forward pass
    out = model(x)["audio"]
    print("Input shape:", x.shape)
    print("Output shape:", out.shape)

    # Back-propagate a unit impulse placed at the centre output sample.
    grad = torch.zeros_like(out)
    grad[:, :, grad.shape[-1] // 2] = 1
    out.backward(grad)

    # Input samples with a non-zero gradient form the receptive field.
    touched = (x.grad.squeeze(0) != 0).sum(0)
    rf = (touched != 0).sum()
    print(f"Receptive field: {rf.item()}")

    x = AudioSignal(torch.randn(1, 1, 44100 * 60), 44100)
    model.decompress(model.compress(x, verbose=True), verbose=True)