# Spaces: Running on L40S
"""!
@author Yi Luo (oulyluo)
@copyright Tencent AI Lab
"""
from __future__ import print_function

import sys

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint_sequential
from thop import profile, clever_format
class RMVN(nn.Module):
    """
    Rescaled mean-variance normalization (MVN).

    Channels are split into ``groups`` equal groups; within each group every
    (batch item, time step) column is shifted to zero mean and scaled to unit
    (unbiased) variance across the group's channels. A learnable per-channel
    scale (``std``) and shift (``mean``) are then applied.

    Args:
        dimension: number of channels N of the (B, N, T) input.
        groups: number of channel groups; must divide N.
    """
    def __init__(self, dimension, groups=1):
        super(RMVN, self).__init__()
        self.mean = nn.Parameter(torch.zeros(dimension))
        self.std = nn.Parameter(torch.ones(dimension))
        self.groups = groups
        self.eps = torch.finfo(torch.float32).eps

    def forward(self, input):
        # input size: (B, N, T)
        batch, channels, frames = input.shape
        assert channels % self.groups == 0
        grouped = input.view(batch, self.groups, -1, frames)
        centered = grouped - grouped.mean(dim=2, keepdim=True)
        # eps keeps the denominator away from zero for constant columns
        scale = (grouped.var(dim=2, keepdim=True) + self.eps).sqrt()
        normed = (centered / scale).view(batch, channels, frames)
        gain = self.std.view(1, -1, 1)
        bias = self.mean.view(1, -1, 1)
        return normed * gain + bias
class ConvActNorm1d(nn.Module):
    """
    Residual 1-D convolution block.

    Computes ``input + (Conv1d(kernel) -> RMVN -> Conv1d(1x1) -> GLU -> Conv1d(1x1))(input)``
    on a (B, N, T) tensor; the output has the same shape as the input.

    Args:
        in_channel: number of input (and output) channels N.
        hidden_channel: channel width after the GLU.
        kernel: temporal kernel size of the first convolution.
        causal: if True, output frame t only depends on input frames <= t.
            NOTE(review): with causal=False the symmetric padding
            ``(kernel - 1) // 2`` only preserves T for odd kernels.
    """
    def __init__(self, in_channel, hidden_channel, kernel=7, causal=False):
        super(ConvActNorm1d, self).__init__()
        self.in_channel = in_channel
        self.kernel = kernel
        self.causal = causal
        # Causal mode pads kernel-1 frames on both sides and forward() keeps only
        # the first T outputs; non-causal mode uses symmetric "same" padding.
        padding = kernel - 1 if causal else (kernel - 1) // 2
        self.conv = nn.Sequential(nn.Conv1d(in_channel, in_channel, kernel, padding=padding),
                                  RMVN(in_channel),
                                  nn.Conv1d(in_channel, hidden_channel*2, 1),
                                  nn.GLU(dim=1),
                                  nn.Conv1d(hidden_channel, in_channel, 1)
                                  )

    def forward(self, input):
        output = self.conv(input)
        if self.causal:
            # Keep the first T frames. Slicing by the input length (instead of
            # [..., :-(kernel-1)]) also works for kernel == 1, where ":-0" would
            # select nothing and break the residual addition.
            output = output[..., :input.shape[-1]].contiguous()
        return input + output
class ICB(nn.Module):
    """
    Stack of three residual ``ConvActNorm1d`` blocks along the last axis.

    Every block keeps ``in_channel`` channels and uses a GLU width of
    ``4 * in_channel``; input and output share the same (B, N, T) shape.

    Args:
        in_channel: number of channels N.
        kernel: temporal kernel size of each block.
        causal: whether every block is causal.
    """
    def __init__(self, in_channel, kernel=7, causal=False):
        super(ICB, self).__init__()
        hidden = in_channel * 4
        self.blocks = nn.Sequential(*[
            ConvActNorm1d(in_channel, hidden, kernel, causal=causal)
            for _ in range(3)
        ])

    def forward(self, input):
        return self.blocks(input)
class ResRNN(nn.Module):
    """
    Residual single-layer LSTM applied along the last axis of a (B, N, T) tensor.

    The input is RMVN-normalized, run through the LSTM over T, projected back to
    ``input_size`` features by a linear layer and added to the input.

    Args:
        input_size: feature dimension N.
        hidden_size: LSTM hidden size per direction.
        bidirectional: whether the LSTM runs in both directions.
    """
    def __init__(self, input_size, hidden_size, bidirectional=False):
        super(ResRNN, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.eps = torch.finfo(torch.float32).eps  # not used by forward()
        self.norm = RMVN(input_size)
        self.rnn = nn.LSTM(input_size, hidden_size, 1, batch_first=True, bidirectional=bidirectional)
        num_directions = int(bidirectional) + 1
        self.proj = nn.Linear(hidden_size * num_directions, input_size)

    def forward(self, input, use_head=1):
        # use_head is accepted for interface compatibility and ignored.
        batch, _, frames = input.shape
        seq = self.norm(input).transpose(1, 2).contiguous()  # (B, T, N)
        rnn_out = self.rnn(seq)[0]  # (B, T, hidden * num_directions)
        flat = rnn_out.contiguous().view(-1, rnn_out.shape[2])
        projected = self.proj(flat).view(batch, frames, -1)
        return input + projected.transpose(1, 2).contiguous()
class BSNet(nn.Module):
    """
    One band-split block: temporal modelling inside each band, followed by
    cross-band modelling at each frame.

    Input and output have shape (B, nband, N, T) with N == feature_dim.

    Args:
        feature_dim: per-band feature dimension N.
        kernel: temporal kernel size of the ICB convolutions.
        causal: whether the temporal convolutions are causal.
    """
    def __init__(self, feature_dim, kernel=7, causal=False):
        super(BSNet, self).__init__()
        self.feature_dim = feature_dim
        # sequence model over time, shared by all bands
        self.seq_net = ICB(self.feature_dim, kernel=kernel, causal=causal)
        # bidirectional residual LSTM over the band axis, shared by all frames
        self.band_net = ResRNN(self.feature_dim, self.feature_dim*2, bidirectional=True)

    def forward(self, input):
        batch, nband, feat, frames = input.shape
        # fold bands into the batch: (B*nband, N, T)
        per_band = self.seq_net(input.view(batch * nband, feat, frames))
        per_band = per_band.view(batch, nband, -1, frames)
        # fold frames into the batch so the RNN runs across bands: (B*T, N, nband)
        per_frame = per_band.permute(0, 3, 2, 1).contiguous().view(batch * frames, -1, nband)
        mixed = self.band_net(per_frame)
        # unfold: (B, T, N, nband) -> (B, nband, N, T)
        mixed = mixed.view(batch, frames, -1, nband).permute(0, 3, 2, 1).contiguous()
        return mixed.view(batch, nband, feat, frames)
# https://github.com/bshall/VectorQuantizedVAE/blob/master/model.py
class VQEmbeddingEMA(nn.Module):
    """
    Vector-quantization codebook trained with exponential moving averages (EMA).

    The codebook lives in buffers, so it receives no gradients. In training
    mode every forward pass does three things:
    - updates the EMA statistics;
    - recomputes the codebook as ``ema_weight / ema_count``;
    - re-initialises codes that went unused for ``stale_tolerance``
      consecutive training calls with randomly chosen input vectors.

    Args:
        num_code: number of codebook entries.
        code_dim: dimension of each entry; must equal the input's N.
        decay: EMA decay factor.
        layer: position in a residual stack; the random initial codebook is
            divided by ``(layer + 1) * code_dim``.
    """
    def __init__(self, num_code, code_dim, decay=0.99, layer=0):
        super(VQEmbeddingEMA, self).__init__()
        self.num_code = num_code
        self.code_dim = code_dim
        self.decay = decay
        self.layer = layer
        self.stale_tolerance = 100
        self.eps = torch.finfo(torch.float32).eps
        embedding = torch.empty(num_code, code_dim).normal_() / ((layer+1) * code_dim)
        self.register_buffer("embedding", embedding)
        self.register_buffer("ema_weight", self.embedding.clone())
        self.register_buffer("ema_count", torch.zeros(self.num_code))
        self.register_buffer("stale_counter", torch.zeros(self.num_code))

    def forward(self, input):
        """
        Quantize ``input`` to its nearest codebook entries.

        Args:
            input: tensor of shape (B, N, T) with N == code_dim.

        Returns:
            quantized: (B, N, T) nearest codebook vectors (taken from a buffer,
                so there is no gradient path to ``input``).
            indices: (B, T) selected code indices.
            perplexity: scalar perplexity of the code usage in this batch.
        """
        B, N, T = input.shape
        assert N == self.code_dim
        input_detach = input.detach().mT.contiguous().view(B*T, N)  # B*T, dim
        # squared Euclidean distance: |x|^2 + |e|^2 - 2 x.e
        eu_dis = input_detach.pow(2).sum(-1).unsqueeze(-1) + self.embedding.pow(2).sum(-1).unsqueeze(0)  # B*T, num_code
        eu_dis = eu_dis - 2 * input_detach.mm(self.embedding.T)  # B*T, num_code
        # best codes
        indices = torch.argmin(eu_dis, dim=-1)  # B*T
        quantized = torch.gather(self.embedding, 0, indices.unsqueeze(-1).expand(-1, self.code_dim))  # B*T, dim
        quantized = quantized.view(B, T, N).mT.contiguous()  # B, N, T
        # perplexity of the code usage in this batch
        encodings = F.one_hot(indices, self.num_code).float()  # B*T, num_code
        avg_probs = encodings.mean(0)  # num_code
        perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + self.eps), -1)).mean()
        indices = indices.view(B, T)
        if self.training:
            # EMA update of per-code usage counts and of the summed assigned vectors
            self.ema_count = self.decay * self.ema_count + (1 - self.decay) * torch.sum(encodings, dim=0)  # num_code
            update_direction = encodings.T.mm(input_detach)  # num_code, dim
            self.ema_weight = self.decay * self.ema_weight + (1 - self.decay) * update_direction  # num_code, dim
            # Laplace smoothing: add eps to every count (rescaled to keep the total n)
            # so the division below does not hit a zero count
            n = torch.sum(self.ema_count, dim=-1, keepdim=True)  # 1
            self.ema_count = (self.ema_count + self.eps) / (n + self.num_code * self.eps) * n  # num_code
            self.embedding = self.ema_weight / self.ema_count.unsqueeze(-1)
            # stale_counter[k] = consecutive training calls in which code k was unused
            stale_codes = (encodings.sum(0) == 0).float()  # num_code
            self.stale_counter = self.stale_counter * stale_codes + stale_codes
            # re-initialise codes that have been unused for stale_tolerance calls
            replace_code = (self.stale_counter == self.stale_tolerance).float()  # num_code
            if replace_code.sum() > 0:
                random_input = input_detach[torch.randperm(input_detach.shape[0])]
                if random_input.shape[0] < self.num_code:
                    # tile so there is at least one candidate vector per code
                    random_input = torch.cat([random_input]*(self.num_code // random_input.shape[0] + 1), 0)
                random_input = random_input[:self.num_code].contiguous()  # num_code, dim
                keep = (1 - replace_code).unsqueeze(-1)
                reset = replace_code.unsqueeze(-1)
                self.embedding = self.embedding * keep + random_input * reset
                self.ema_weight = self.ema_weight * keep + random_input * reset
                # A reset code must keep embedding == ema_weight / ema_count, so its
                # count becomes 1. Resetting it to 0 (as before) made the next EMA
                # step divide (decay * x + ...) by about (1 - decay) * usage,
                # inflating the new code by ~decay / ((1 - decay) * usage)
                # (about 99x for decay=0.99 and a single hit).
                self.ema_count = self.ema_count * (1 - replace_code) + replace_code
                self.stale_counter = self.stale_counter * (1 - replace_code)
        return quantized, indices, perplexity
class RVQEmbedding(nn.Module):
    """
    Residual vector quantizer built from a stack of ``VQEmbeddingEMA`` codebooks.

    Codebook ``i`` has ``2 ** bit[i]`` entries and quantizes the residual left
    after codebooks ``0 .. i-1``.

    Args:
        code_dim: feature dimension N of the (B, N, T) input.
        decay: EMA decay passed to every codebook.
        bit: per-layer codebook size in bits; any sequence of ints. The
            default is a tuple so it cannot be mutated and shared between calls.
    """
    def __init__(self, code_dim, decay=0.99, bit=(10,)):
        super(RVQEmbedding, self).__init__()
        self.code_dim = code_dim
        self.decay = decay
        self.eps = torch.finfo(torch.float32).eps
        self.VQEmbedding = nn.ModuleList(
            [VQEmbeddingEMA(2**b, code_dim, decay, layer=i) for i, b in enumerate(bit)]
        )

    def forward(self, input):
        """
        Quantize ``input`` (B, N, T) with every codebook in turn.

        Returns:
            quantized: (B, N, T, L) cumulative reconstruction after each of the L layers.
            indices: (B, T, L) code indices per layer.
            ppl: (L,) code-usage perplexity per layer.
            latent_loss: sum over layers of MSE(input, detached cumulative reconstruction).
        """
        quantized = []
        indices = []
        ppl = []
        residual_input = input
        running_sum = None
        for vq in self.VQEmbedding:
            this_quantized, this_indices, this_perplexity = vq(residual_input)
            indices.append(this_indices)
            ppl.append(this_perplexity)
            residual_input = residual_input - this_quantized
            running_sum = this_quantized if running_sum is None else running_sum + this_quantized
            quantized.append(running_sum)
        quantized = torch.stack(quantized, -1)
        indices = torch.stack(indices, -1)
        ppl = torch.stack(ppl, -1)
        latent_loss = 0
        for i in range(quantized.shape[-1]):
            latent_loss = latent_loss + F.mse_loss(input, quantized.detach()[..., i])
        return quantized, indices, ppl, latent_loss
class Codec(nn.Module):
    """
    Band-split VAE codec operating on STFT sub-bands.

    Pipeline:
    1. STFT, then per-band normalized real/imag plus log-magnitude features.
    2. Per-band 1x1 bottleneck, then the BSNet encoder.
    3. Per-band Gaussian latent (mu, logvar) and a reparameterized sample.
    4. BSNet decoder, per-band real/imag spectra, then iSTFT.

    Args:
        nch: number of audio channels.
        sr: sample rate in Hz.
        win: STFT window length in milliseconds; the hop is half the window.
        feature_dim: per-band feature dimension inside encoder/decoder.
        vae_dim: latent dimension per band.
        enc_layer: number of BSNet blocks in the encoder.
        dec_layer: number of BSNet blocks in the decoder.
        bit: per-layer codebook sizes (bits) of the RVQ codebook; any sequence
            of ints, stored as a list copy in ``self.bit``.
        causal: use causal temporal convolutions in the BSNet blocks.

    Raises:
        ValueError: if ``sr``/``win`` give a band split with a non-positive
            band width, e.g. when ``sr / 2`` is not above the 20 kHz covered by
            the fixed bands.
    """
    def __init__(self, nch=1, sr=44100, win=80, feature_dim=128, vae_dim=2, enc_layer=12, dec_layer=12, bit=(8,)*5, causal=False):
        super(Codec, self).__init__()
        self.nch = nch
        self.sr = sr
        # window / hop in samples; enc_dim is the number of one-sided STFT bins
        self.win = int(sr / 1000 * win)
        self.stride = self.win // 2
        self.enc_dim = self.win // 2 + 1
        self.feature_dim = feature_dim
        self.vae_dim = vae_dim
        # copy: storing the default list itself would share it between instances
        self.bit = list(bit)
        self.eps = torch.finfo(torch.float32).eps

        self.band_width = self._split_bands()
        self.nband = len(self.band_width)
        print(self.band_width, self.nband)

        # per-band bottleneck: [real, imag, log-magnitude] * nch channels -> feature_dim
        self.VAE_BN = nn.ModuleList([
            nn.Sequential(RMVN((bw*2+1)*self.nch),
                          nn.Conv1d((bw*2+1)*self.nch, self.feature_dim, 1))
            for bw in self.band_width
        ])
        self.VAE_encoder = nn.Sequential(*[BSNet(self.feature_dim, kernel=7, causal=causal) for _ in range(enc_layer)])
        # grouped (per-band) projection to mu and logvar
        self.vae_FC = nn.Sequential(RMVN(self.nband*self.feature_dim, groups=self.nband),
                                    nn.Conv1d(self.nband*self.feature_dim, self.nband*self.vae_dim*2, 1, groups=self.nband)
                                    )
        # NOTE(review): self.codebook is not used by forward/encode/vae_decode in this class.
        self.codebook = RVQEmbedding(self.nband*self.vae_dim*2, bit=self.bit)
        self.vae_reshape = nn.Conv1d(self.nband*self.vae_dim, self.nband*self.feature_dim, 1, groups=self.nband)
        self.VAE_decoder = nn.Sequential(*[BSNet(self.feature_dim, kernel=7, causal=causal) for _ in range(dec_layer)])
        # per-band output head: GLU halves 4*bw*nch channels to [real, imag] * bw * nch
        self.VAE_output = nn.ModuleList([
            nn.Sequential(RMVN(self.feature_dim),
                          nn.Conv1d(self.feature_dim, bw*4*self.nch, 1),
                          nn.GLU(dim=1))
            for bw in self.band_width
        ])

    def _split_bands(self):
        """Return per-band widths in STFT bins; they sum to ``self.enc_dim`` (55 bands)."""
        # 0-1k (50 Hz bands), 1k-2k (100 Hz), 2k-4k (250 Hz), 4k-8k (500 Hz),
        # 8k-12k (1 kHz), 12k-20k (2 kHz), 20k-Nyquist (one band)
        def bins(hz):
            return int(np.floor(hz / (self.sr / 2.) * self.enc_dim))
        band_width = ([bins(50)]*20 + [bins(100)]*10 + [bins(250)]*8
                      + [bins(500)]*8 + [bins(1000)]*4 + [bins(2000)]*4)
        band_width.append(self.enc_dim - sum(band_width))
        if min(band_width) <= 0:
            raise ValueError('sr=%s with a %s-sample window gives an invalid band split: %s'
                             % (self.sr, self.win, band_width))
        return band_width

    def spec_band_split(self, input):
        """
        STFT ``input`` (B, nch, nsample) and split the spectrum into bands.

        Returns:
            subband_spec: list of complex tensors, each (B*nch, BW_i, T).
            subband_spec_norm: list of [real, imag] pairs, each (B*nch, BW_i, T),
                divided by the band's per-frame L2 magnitude.
            subband_power: (B*nch, nband, T) per-frame band L2 magnitude.
        """
        B, nch, nsample = input.shape
        spec = torch.stft(input.view(B*nch, nsample).float(), n_fft=self.win, hop_length=self.stride,
                          window=torch.hann_window(self.win).to(input.device), return_complex=True)
        subband_spec = []
        subband_spec_norm = []
        subband_power = []
        band_idx = 0
        for bw in self.band_width:
            this_spec = spec[:, band_idx:band_idx+bw]
            this_power = (this_spec.abs().pow(2).sum(1) + self.eps).sqrt().unsqueeze(1)  # B*nch, 1, T
            subband_spec.append(this_spec)
            subband_power.append(this_power)
            subband_spec_norm.append([this_spec.real / this_power, this_spec.imag / this_power])
            band_idx += bw
        subband_power = torch.cat(subband_power, 1)  # B*nch, nband, T
        return subband_spec, subband_spec_norm, subband_power

    def feature_extractor(self, input):
        """Return per-band bottleneck features of shape (B, nband, feature_dim, T)."""
        _, subband_spec_norm, subband_power = self.spec_band_split(input)
        subband_feature = []
        for i in range(self.nband):
            concat_spec = torch.cat([subband_spec_norm[i][0], subband_spec_norm[i][1],
                                     torch.log(subband_power[:, i].unsqueeze(1))], 1)
            # merge the nch channels of each item into the feature axis
            concat_spec = concat_spec.view(-1, (self.band_width[i]*2+1)*self.nch, concat_spec.shape[-1])
            subband_feature.append(self.VAE_BN[i](concat_spec.type(input.type())))
        subband_feature = torch.stack(subband_feature, 1)  # B, nband, N, T
        return subband_feature

    def vae_sample(self, input):
        """
        Encode ``input`` (B, nch, nsample) and draw one reparameterized latent sample.

        Returns:
            the sample viewed as (B, nch, nband, vae_dim, -1).
        """
        B, nch, _ = input.shape
        subband_feature = self.feature_extractor(input)
        # encode with activation checkpointing, one segment per BSNet block
        enc_output = checkpoint_sequential(self.VAE_encoder, len(self.VAE_encoder), subband_feature)
        enc_output = self.vae_FC(enc_output.view(B, self.nband*self.feature_dim, -1)).view(B, self.nband, 2, self.vae_dim, -1)
        mu = enc_output[:, :, 0].contiguous()
        logvar = enc_output[:, :, 1].contiguous()
        # reparameterization: mu + noise * exp(logvar / 2)
        reparam_feature = mu + torch.randn_like(logvar) * torch.exp(0.5 * logvar)
        # NOTE(review): for nch > 1 this view splits the time axis rather than adding
        # a channel axis; vae_decode's view undoes it, but T must be divisible by nch.
        return reparam_feature.view(B, nch, self.nband, self.vae_dim, -1)

    def vae_decode(self, vae_feature):
        """
        Decode a latent (e.g. the output of ``vae_sample``) back to audio.

        ``vae_feature`` must be viewable as (B, nband*vae_dim, T). Returns a
        (B, nch, samples) waveform with the dtype/type of ``vae_feature``.
        """
        B = vae_feature.shape[0]
        dec_input = self.vae_reshape(vae_feature.contiguous().view(B, self.nband*self.vae_dim, -1))
        output = checkpoint_sequential(self.VAE_decoder, len(self.VAE_decoder),
                                       dec_input.view(B, self.nband, self.feature_dim, -1))
        est_spec = []
        for i in range(self.nband):
            this_RI = self.VAE_output[i](output[:, i]).view(B*self.nch, 2, self.band_width[i], -1)
            est_spec.append(torch.complex(this_RI[:, 0].float(), this_RI[:, 1].float()))
        est_spec = torch.cat(est_spec, 1)
        output = torch.istft(est_spec, n_fft=self.win, hop_length=self.stride,
                             window=torch.hann_window(self.win).to(vae_feature.device)).view(B, self.nch, -1)
        return output.type(vae_feature.type())

    def forward(self, input):
        """Reconstruct ``input`` (B, nch, nsample); the output is trimmed or zero-padded to nsample."""
        B, nch, nsample = input.shape
        assert nch == self.nch
        vae_feature = self.vae_sample(input)
        output = self.vae_decode(vae_feature).view(B, nch, -1)
        if output.shape[-1] > nsample:
            output = output[:, :, 0:nsample]
        elif output.shape[-1] < nsample:
            pad = torch.zeros(B, nch, nsample - output.shape[-1], device=output.device, dtype=output.dtype)
            output = torch.cat([output, pad], -1)
        return output

    def encode(self, input, do_sample=True):
        """Return a sampled latent for ``input`` (B, nch, nsample); only do_sample=True is supported."""
        assert do_sample, do_sample
        _, nch, _ = input.shape
        assert nch == self.nch
        vae_feature = self.vae_sample(input)
        return vae_feature
def get_bsrnnvae(ckpt):
    """
    Build the mono, causal, 100 ms-window ``Codec`` and load its weights.

    Args:
        ckpt: anything ``torch.load`` accepts (e.g. a file path) that holds a
            state dict for this configuration.

    Returns:
        The loaded model, switched to eval mode.
    """
    config = dict(nch=1, win=100, feature_dim=128, vae_dim=2, bit=[14]*5, causal=True)
    model = Codec(**config)
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    state_dict = torch.load(ckpt, map_location='cpu')
    model.load_state_dict(state_dict)
    model.eval()
    return model
if __name__ == '__main__':
    # Smoke test: build a randomly initialised causal codec and run one second of noise.
    model = Codec(causal=True)
    x = torch.empty(1, 1, 44100).uniform_(-1, 1)
    # np.product was removed in NumPy 2.0; Tensor.numel() counts elements directly.
    s = sum(param.numel() for param in model.parameters())
    print('# of parameters: '+str(s/1e6)+" M")
    output = model(x)
    print(output.shape)
    macs, params = profile(model, inputs=(x,))
    macs, params = clever_format([macs, params], "%.3f")
    print(macs, params)

    # Optional round trip of a real file: python <script> CKPT [AUDIO]
    # (get_bsrnnvae requires a checkpoint path; calling it without one raises TypeError.)
    if len(sys.argv) > 1:
        import torchaudio
        model = get_bsrnnvae(sys.argv[1])
        audio_path = sys.argv[2] if len(sys.argv) > 2 else '769000.mp3'
        inp, fs = torchaudio.load(audio_path)
        inp = inp[[0], :]
        if fs != 44100:
            inp = torchaudio.functional.resample(inp, fs, 44100)
            fs = 44100
        inp = inp[:, 0:30*44100]
        with torch.no_grad():
            out = model(inp[None, :, :]).detach()
        torchaudio.save('out.flac', out[0], fs)