import os
import sys

import numpy as np
import soundfile as sf
import torch
import torch.nn as nn
from einops import rearrange
from librosa.filters import mel as librosa_mel_fn
from librosa.util import pad_center, tiny
import librosa.util as librosa_util
from scipy.signal import get_window

import tools.torch_tools as torch_tools


class AttrDict(dict):
    def __init__(self, *args, **kwargs):
        super(AttrDict, self).__init__(*args, **kwargs)
        self.__dict__ = self


def init_weights(m, mean=0.0, std=0.01):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        m.weight.data.normal_(mean, std)


def get_padding(kernel_size, dilation=1):
    return int((kernel_size * dilation - dilation) / 2)


LRELU_SLOPE = 0.1
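# Illustrative note (not in the original file): for stride-1 Conv1d layers,
# get_padding() returns the padding that keeps the time dimension unchanged
# ("same" padding for odd kernel sizes), e.g.
#
#   >>> get_padding(3, dilation=1)   # 1
#   >>> get_padding(3, dilation=5)   # 5
#   >>> get_padding(7, dilation=1)   # 3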
class ResBlock(torch.nn.Module):
    def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)):
        super(ResBlock, self).__init__()
        self.h = h
        self.convs1 = nn.ModuleList(
            [
                torch.nn.utils.weight_norm(
                    nn.Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
                              padding=get_padding(kernel_size, dilation[0]))
                ),
                torch.nn.utils.weight_norm(
                    nn.Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
                              padding=get_padding(kernel_size, dilation[1]))
                ),
                torch.nn.utils.weight_norm(
                    nn.Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2],
                              padding=get_padding(kernel_size, dilation[2]))
                ),
            ]
        )
        self.convs1.apply(init_weights)

        self.convs2 = nn.ModuleList(
            [
                torch.nn.utils.weight_norm(
                    nn.Conv1d(channels, channels, kernel_size, 1, dilation=1,
                              padding=get_padding(kernel_size, 1))
                ),
                torch.nn.utils.weight_norm(
                    nn.Conv1d(channels, channels, kernel_size, 1, dilation=1,
                              padding=get_padding(kernel_size, 1))
                ),
                torch.nn.utils.weight_norm(
                    nn.Conv1d(channels, channels, kernel_size, 1, dilation=1,
                              padding=get_padding(kernel_size, 1))
                ),
            ]
        )
        self.convs2.apply(init_weights)

    def forward(self, x):
        for c1, c2 in zip(self.convs1, self.convs2):
            xt = torch.nn.functional.leaky_relu(x, LRELU_SLOPE)
            xt = c1(xt)
            xt = torch.nn.functional.leaky_relu(xt, LRELU_SLOPE)
            xt = c2(xt)
            x = xt + x
        return x

    def remove_weight_norm(self):
        for l in self.convs1:
            torch.nn.utils.remove_weight_norm(l)
        for l in self.convs2:
            torch.nn.utils.remove_weight_norm(l)


class Generator_old(torch.nn.Module):
    def __init__(self, h):
        super(Generator_old, self).__init__()
        self.h = h
        self.num_kernels = len(h.resblock_kernel_sizes)
        self.num_upsamples = len(h.upsample_rates)
        self.conv_pre = torch.nn.utils.weight_norm(
            nn.Conv1d(h.num_mels, h.upsample_initial_channel, 7, 1, padding=3)
        )
        resblock = ResBlock

        self.ups = nn.ModuleList()
        for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):
            self.ups.append(
                torch.nn.utils.weight_norm(
                    nn.ConvTranspose1d(
                        h.upsample_initial_channel // (2**i),
                        h.upsample_initial_channel // (2 ** (i + 1)),
                        k,
                        u,
                        padding=(k - u) // 2,
                    )
                )
            )

        self.resblocks = nn.ModuleList()
        for i in range(len(self.ups)):
            ch = h.upsample_initial_channel // (2 ** (i + 1))
            for j, (k, d) in enumerate(
                zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)
            ):
                self.resblocks.append(resblock(h, ch, k, d))

        self.conv_post = torch.nn.utils.weight_norm(nn.Conv1d(ch, 1, 7, 1, padding=3))
        self.ups.apply(init_weights)
        self.conv_post.apply(init_weights)

    def forward(self, x):
        x = self.conv_pre(x)
        for i in range(self.num_upsamples):
            x = torch.nn.functional.leaky_relu(x, LRELU_SLOPE)
            x = self.ups[i](x)
            xs = None
            for j in range(self.num_kernels):
                if xs is None:
                    xs = self.resblocks[i * self.num_kernels + j](x)
                else:
                    xs += self.resblocks[i * self.num_kernels + j](x)
            x = xs / self.num_kernels
        x = torch.nn.functional.leaky_relu(x)
        x = self.conv_post(x)
        x = torch.tanh(x)
        return x

    def remove_weight_norm(self):
        # print("Removing weight norm...")
        for l in self.ups:
            torch.nn.utils.remove_weight_norm(l)
        for l in self.resblocks:
            l.remove_weight_norm()
        torch.nn.utils.remove_weight_norm(self.conv_pre)
        torch.nn.utils.remove_weight_norm(self.conv_post)


def nonlinearity(x):
    # swish
    return x * torch.sigmoid(x)


def Normalize(in_channels, num_groups=32):
    return torch.nn.GroupNorm(
        num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True
    )


class Downsample(nn.Module):
    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if self.with_conv:
            # Do time downsampling here
            # no asymmetric padding in torch conv, must do it ourselves
            self.conv = torch.nn.Conv2d(
                in_channels, in_channels, kernel_size=3, stride=2, padding=0
            )

    def forward(self, x):
        if self.with_conv:
            pad = (0, 1, 0, 1)
            x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
            x = self.conv(x)
        else:
            x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
        return x


class DownsampleTimeStride4(nn.Module):
    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if self.with_conv:
            # Do time downsampling here
            # no asymmetric padding in torch conv, must do it ourselves
            self.conv = torch.nn.Conv2d(
                in_channels, in_channels, kernel_size=5, stride=(4, 2), padding=1
            )

    def forward(self, x):
        if self.with_conv:
            pad = (0, 1, 0, 1)
            x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
            x = self.conv(x)
        else:
            x = torch.nn.functional.avg_pool2d(x, kernel_size=(4, 2), stride=(4, 2))
        return x


class Upsample(nn.Module):
    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if self.with_conv:
            self.conv = torch.nn.Conv2d(
                in_channels, in_channels, kernel_size=3, stride=1, padding=1
            )

    def forward(self, x):
        x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
        if self.with_conv:
            x = self.conv(x)
        return x


class UpsampleTimeStride4(nn.Module):
    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if self.with_conv:
            self.conv = torch.nn.Conv2d(
                in_channels, in_channels, kernel_size=5, stride=1, padding=2
            )

    def forward(self, x):
        x = torch.nn.functional.interpolate(x, scale_factor=(4.0, 2.0), mode="nearest")
        if self.with_conv:
            x = self.conv(x)
        return x


class AttnBlock(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels

        self.norm = Normalize(in_channels)
        self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        h_ = x
        h_ = self.norm(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)

        # compute attention
        b, c, h, w = q.shape
        q = q.reshape(b, c, h * w).contiguous()
        q = q.permute(0, 2, 1).contiguous()  # b,hw,c
        k = k.reshape(b, c, h * w).contiguous()  # b,c,hw
        w_ = torch.bmm(q, k).contiguous()  # b,hw,hw   w[b,i,j] = sum_c q[b,i,c] k[b,c,j]
        w_ = w_ * (int(c) ** (-0.5))
        w_ = torch.nn.functional.softmax(w_, dim=2)

        # attend to values
        v = v.reshape(b, c, h * w).contiguous()
        w_ = w_.permute(0, 2, 1).contiguous()  # b,hw,hw (first hw of k, second of q)
        h_ = torch.bmm(v, w_).contiguous()  # b,c,hw (hw of q)   h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
        h_ = h_.reshape(b, c, h, w).contiguous()

        h_ = self.proj_out(h_)

        return x + h_
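# Illustrative equivalence (not in the original code): AttnBlock implements
# standard single-head scaled-dot-product attention over the h*w spatial
# positions, so its core computation could also be written roughly as:
#
#   q2 = q.reshape(b, c, h * w).permute(0, 2, 1)   # (b, hw, c)
#   k2 = k.reshape(b, c, h * w).permute(0, 2, 1)
#   v2 = v.reshape(b, c, h * w).permute(0, 2, 1)
#   out = torch.nn.functional.scaled_dot_product_attention(q2, k2, v2)
#   h_ = out.permute(0, 2, 1).reshape(b, c, h, w)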
def make_attn(in_channels, attn_type="vanilla"):
    assert attn_type in ["vanilla", "linear", "none"], f"attn_type {attn_type} unknown"
    # print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
    if attn_type == "vanilla":
        return AttnBlock(in_channels)
    elif attn_type == "none":
        return nn.Identity(in_channels)
    else:
        raise ValueError(attn_type)


class ResnetBlock(nn.Module):
    def __init__(
        self,
        *,
        in_channels,
        out_channels=None,
        conv_shortcut=False,
        dropout,
        temb_channels=512,
    ):
        super().__init__()
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut

        self.norm1 = Normalize(in_channels)
        self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        if temb_channels > 0:
            self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
        self.norm2 = Normalize(out_channels)
        self.dropout = torch.nn.Dropout(dropout)
        self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                self.conv_shortcut = torch.nn.Conv2d(
                    in_channels, out_channels, kernel_size=3, stride=1, padding=1
                )
            else:
                self.nin_shortcut = torch.nn.Conv2d(
                    in_channels, out_channels, kernel_size=1, stride=1, padding=0
                )

    def forward(self, x, temb):
        h = x
        h = self.norm1(h)
        h = nonlinearity(h)
        h = self.conv1(h)

        if temb is not None:
            h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]

        h = self.norm2(h)
        h = nonlinearity(h)
        h = self.dropout(h)
        h = self.conv2(h)

        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                x = self.conv_shortcut(x)
            else:
                x = self.nin_shortcut(x)

        return x + h


class Encoder(nn.Module):
    def __init__(
        self,
        *,
        ch,
        out_ch,
        ch_mult=(1, 2, 4, 8),
        num_res_blocks,
        attn_resolutions,
        dropout=0.0,
        resamp_with_conv=True,
        in_channels,
        resolution,
        z_channels,
        double_z=True,
        use_linear_attn=False,
        attn_type="vanilla",
        downsample_time_stride4_levels=[],
        **ignore_kwargs,
    ):
        super().__init__()
        if use_linear_attn:
            attn_type = "linear"
        self.ch = ch
        self.temb_ch = 0
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels
        self.downsample_time_stride4_levels = downsample_time_stride4_levels

        if len(self.downsample_time_stride4_levels) > 0:
            assert max(self.downsample_time_stride4_levels) < self.num_resolutions, (
                "The level at which to perform the stride-4 time downsampling needs to be "
                "smaller than the total number of resolutions %s" % str(self.num_resolutions)
            )

        # downsampling
        self.conv_in = torch.nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1)

        curr_res = resolution
        in_ch_mult = (1,) + tuple(ch_mult)
        self.in_ch_mult = in_ch_mult
        self.down = nn.ModuleList()
        for i_level in range(self.num_resolutions):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_in = ch * in_ch_mult[i_level]
            block_out = ch * ch_mult[i_level]
            for i_block in range(self.num_res_blocks):
                block.append(
                    ResnetBlock(
                        in_channels=block_in,
                        out_channels=block_out,
                        temb_channels=self.temb_ch,
                        dropout=dropout,
                    )
                )
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(make_attn(block_in, attn_type=attn_type))
            down = nn.Module()
            down.block = block
            down.attn = attn
            if i_level != self.num_resolutions - 1:
                if i_level in self.downsample_time_stride4_levels:
                    down.downsample = DownsampleTimeStride4(block_in, resamp_with_conv)
                else:
                    down.downsample = Downsample(block_in, resamp_with_conv)
                curr_res = curr_res // 2
            self.down.append(down)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(
            in_channels=block_in,
            out_channels=block_in,
            temb_channels=self.temb_ch,
            dropout=dropout,
        )
        self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
        self.mid.block_2 = ResnetBlock(
            in_channels=block_in,
            out_channels=block_in,
            temb_channels=self.temb_ch,
            dropout=dropout,
        )

        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(
            block_in,
            2 * z_channels if double_z else z_channels,
            kernel_size=3,
            stride=1,
            padding=1,
        )

    def forward(self, x):
        # timestep embedding
        temb = None

        # downsampling
        hs = [self.conv_in(x)]
        for i_level in range(self.num_resolutions):
            for i_block in range(self.num_res_blocks):
                h = self.down[i_level].block[i_block](hs[-1], temb)
                if len(self.down[i_level].attn) > 0:
                    h = self.down[i_level].attn[i_block](h)
                hs.append(h)
            if i_level != self.num_resolutions - 1:
                hs.append(self.down[i_level].downsample(hs[-1]))

        # middle
        h = hs[-1]
        h = self.mid.block_1(h, temb)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h, temb)

        # end
        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        return h


class Decoder(nn.Module):
    def __init__(
        self,
        *,
        ch,
        out_ch,
        ch_mult=(1, 2, 4, 8),
        num_res_blocks,
        attn_resolutions,
        dropout=0.0,
        resamp_with_conv=True,
        in_channels,
        resolution,
        z_channels,
        give_pre_end=False,
        tanh_out=False,
        use_linear_attn=False,
        downsample_time_stride4_levels=[],
        attn_type="vanilla",
        **ignorekwargs,
    ):
        super().__init__()
        if use_linear_attn:
            attn_type = "linear"
        self.ch = ch
        self.temb_ch = 0
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels
        self.give_pre_end = give_pre_end
        self.tanh_out = tanh_out
        self.downsample_time_stride4_levels = downsample_time_stride4_levels

        if len(self.downsample_time_stride4_levels) > 0:
            assert max(self.downsample_time_stride4_levels) < self.num_resolutions, (
                "The level at which to perform the stride-4 time downsampling needs to be "
                "smaller than the total number of resolutions %s" % str(self.num_resolutions)
            )

        # compute in_ch_mult, block_in and curr_res at lowest res
        (1,) + tuple(ch_mult)
        block_in = ch * ch_mult[self.num_resolutions - 1]
        curr_res = resolution // 2 ** (self.num_resolutions - 1)
        self.z_shape = (1, z_channels, curr_res, curr_res)
        # print(
        #     "Working with z of shape {} = {} dimensions.".format(
        #         self.z_shape, np.prod(self.z_shape)
        #     )
        # )

        # z to block_in
        self.conv_in = torch.nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(
            in_channels=block_in,
            out_channels=block_in,
            temb_channels=self.temb_ch,
            dropout=dropout,
        )
        self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
        self.mid.block_2 = ResnetBlock(
            in_channels=block_in,
            out_channels=block_in,
            temb_channels=self.temb_ch,
            dropout=dropout,
        )

        # upsampling
        self.up = nn.ModuleList()
        for i_level in reversed(range(self.num_resolutions)):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_out = ch * ch_mult[i_level]
            for i_block in range(self.num_res_blocks + 1):
                block.append(
                    ResnetBlock(
                        in_channels=block_in,
                        out_channels=block_out,
                        temb_channels=self.temb_ch,
                        dropout=dropout,
                    )
                )
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(make_attn(block_in, attn_type=attn_type))
            up = nn.Module()
            up.block = block
            up.attn = attn
            if i_level != 0:
                if i_level - 1 in self.downsample_time_stride4_levels:
                    up.upsample = UpsampleTimeStride4(block_in, resamp_with_conv)
                else:
                    up.upsample = Upsample(block_in, resamp_with_conv)
                curr_res = curr_res * 2
            self.up.insert(0, up)  # prepend to get consistent order

        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)

    def forward(self, z):
        # assert z.shape[1:] == self.z_shape[1:]
        self.last_z_shape = z.shape

        # timestep embedding
        temb = None

        # z to block_in
        h = self.conv_in(z)

        # middle
        h = self.mid.block_1(h, temb)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h, temb)

        # upsampling
        for i_level in reversed(range(self.num_resolutions)):
            for i_block in range(self.num_res_blocks + 1):
                h = self.up[i_level].block[i_block](h, temb)
                if len(self.up[i_level].attn) > 0:
                    h = self.up[i_level].attn[i_block](h)
            if i_level != 0:
                h = self.up[i_level].upsample(h)

        # end
        if self.give_pre_end:
            return h

        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        if self.tanh_out:
            h = torch.tanh(h)
        return h


class DiagonalGaussianDistribution(object):
    def __init__(self, parameters, deterministic=False):
        self.parameters = parameters
        self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
        self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
        self.deterministic = deterministic
        self.std = torch.exp(0.5 * self.logvar)
        self.var = torch.exp(self.logvar)
        if self.deterministic:
            self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)

    def sample(self):
        x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device)
        return x

    def kl(self, other=None):
        if self.deterministic:
            return torch.Tensor([0.0])
        else:
            if other is None:
                return 0.5 * torch.mean(
                    torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
                    dim=[1, 2, 3],
                )
            else:
                return 0.5 * torch.mean(
                    torch.pow(self.mean - other.mean, 2) / other.var
                    + self.var / other.var
                    - 1.0
                    - self.logvar
                    + other.logvar,
                    dim=[1, 2, 3],
                )

    def nll(self, sample, dims=[1, 2, 3]):
        if self.deterministic:
            return torch.Tensor([0.0])
        logtwopi = np.log(2.0 * np.pi)
        return 0.5 * torch.sum(
            logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
            dim=dims,
        )

    def mode(self):
        return self.mean


def get_vocoder_config_48k():
    return {
        "resblock": "1",
        "num_gpus": 8,
        "batch_size": 128,
        "learning_rate": 0.0001,
        "adam_b1": 0.8,
        "adam_b2": 0.99,
        "lr_decay": 0.999,
        "seed": 1234,
        "upsample_rates": [6, 5, 4, 2, 2],
        "upsample_kernel_sizes": [12, 10, 8, 4, 4],
        "upsample_initial_channel": 1536,
        "resblock_kernel_sizes": [3, 7, 11, 15],
        "resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5], [1, 3, 5]],
        "segment_size": 15360,
        "num_mels": 256,
        "n_fft": 2048,
        "hop_size": 480,
        "win_size": 2048,
        "sampling_rate": 48000,
        "fmin": 20,
        "fmax": 24000,
        "fmax_for_loss": None,
        "num_workers": 8,
        "dist_config": {
            "dist_backend": "nccl",
            "dist_url": "tcp://localhost:18273",
            "world_size": 1,
        },
    }
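# Consistency note (illustrative, not part of the original file): the HiFi-GAN
# upsampling factors must multiply to the STFT hop size so that one mel frame
# expands to exactly hop_size output samples.
#
#   cfg = get_vocoder_config_48k()
#   assert int(np.prod(cfg["upsample_rates"])) == cfg["hop_size"]  # 6*5*4*2*2 == 480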
def get_vocoder(config, device, mel_bins):
    name = "HiFi-GAN"
    speaker = ""
    if name == "MelGAN":
        if speaker == "LJSpeech":
            vocoder = torch.hub.load(
                "descriptinc/melgan-neurips", "load_melgan", "linda_johnson"
            )
        elif speaker == "universal":
            vocoder = torch.hub.load(
                "descriptinc/melgan-neurips", "load_melgan", "multi_speaker"
            )
        vocoder.mel2wav.eval()
        vocoder.mel2wav.to(device)
    elif name == "HiFi-GAN":
        if mel_bins == 256:
            config = get_vocoder_config_48k()
            config = AttrDict(config)
            vocoder = Generator_old(config)
            # print("Load hifigan/g_01080000")
            # ckpt = torch.load(os.path.join(ROOT, "hifigan/g_01080000"))
            # ckpt = torch.load(os.path.join(ROOT, "hifigan/g_00660000"))
            # ckpt = torch_version_orig_mod_remove(ckpt)
            # vocoder.load_state_dict(ckpt["generator"])
            vocoder.eval()
            vocoder.remove_weight_norm()
            vocoder = vocoder.to(device)
            # vocoder = vocoder.half()
        else:
            raise ValueError(mel_bins)

    return vocoder


def vocoder_infer(mels, vocoder, lengths=None):
    with torch.no_grad():
        wavs = vocoder(mels).squeeze(1)

    # wavs = (wavs.cpu().numpy() * 32768).astype("int16")
    wavs = wavs.cpu().numpy()

    if lengths is not None:
        wavs = wavs[:, :lengths]

    # wavs = [wav for wav in wavs]
    # for i in range(len(mels)):
    #     if lengths is not None:
    #         wavs[i] = wavs[i][: lengths[i]]

    return wavs


@torch.no_grad()
def vocoder_chunk_infer(mels, vocoder, lengths=None):
    chunk_size = 256 * 4
    shift_size = 256 * 1
    ov_size = chunk_size - shift_size
    # import pdb; pdb.set_trace()
    for cinx in range(0, mels.shape[2], shift_size):
        if cinx == 0:
            wavs = vocoder(mels[:, :, cinx : cinx + chunk_size]).squeeze(1).float()
            num_samples = int(wavs.shape[-1] / chunk_size) * chunk_size
            wavs = wavs[:, 0:num_samples]
            ov_sample = int(float(wavs.shape[-1]) * ov_size / chunk_size)
            # linear cross-fade window for the overlapping region; built on the
            # decoded audio's device rather than a hard-coded "cuda" so that CPU
            # inference also works
            ov_win = torch.linspace(0, 1, ov_sample, device=wavs.device).unsqueeze(0)
            ov_win = torch.cat([ov_win, 1 - ov_win], -1)
            if cinx + chunk_size >= mels.shape[2]:
                break
        else:
            cur_wav = (
                vocoder(mels[:, :, cinx : cinx + chunk_size]).squeeze(1)[:, 0:num_samples].float()
            )
            wavs[:, -ov_sample:] = (
                wavs[:, -ov_sample:] * ov_win[:, -ov_sample:]
                + cur_wav[:, 0:ov_sample] * ov_win[:, 0:ov_sample]
            )
            # wavs[:, -ov_sample:] = wavs[:, -ov_sample:] * 1.0 + cur_wav[:, 0:ov_sample] * 0.0
            wavs = torch.cat([wavs, cur_wav[:, ov_sample:]], -1)
            if cinx + chunk_size >= mels.shape[2]:
                break

    # print(wavs.shape)
    wavs = wavs.cpu().numpy()

    if lengths is not None:
        wavs = wavs[:, :lengths]
    # print(wavs.shape)
    return wavs
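# Usage sketch (illustrative; assumes a mel batch `mels` of shape
# [B, num_mels, T_frames] on the vocoder's device):
#
#   wav = vocoder_infer(mels, vocoder)        # single pass, np.ndarray [B, samples]
#   wav = vocoder_chunk_infer(mels, vocoder)  # decodes 1024-frame windows shifted by
#                                             # 256 frames and cross-fades the overlap,
#                                             # which bounds peak memory for long inputs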
def synth_one_sample(mel_input, mel_prediction, labels, vocoder):
    if vocoder is not None:
        wav_reconstruction = vocoder_infer(
            mel_input.permute(0, 2, 1),
            vocoder,
        )
        wav_prediction = vocoder_infer(
            mel_prediction.permute(0, 2, 1),
            vocoder,
        )
    else:
        wav_reconstruction = wav_prediction = None

    return wav_reconstruction, wav_prediction


class AutoencoderKL(nn.Module):
    def __init__(
        self,
        ddconfig=None,
        lossconfig=None,
        batchsize=None,
        embed_dim=None,
        time_shuffle=1,
        subband=1,
        sampling_rate=16000,
        ckpt_path=None,
        reload_from_ckpt=None,
        ignore_keys=[],
        image_key="fbank",
        colorize_nlabels=None,
        monitor=None,
        base_learning_rate=1e-5,
        scale_factor=1,
    ):
        super().__init__()
        self.automatic_optimization = False
        assert (
            "mel_bins" in ddconfig.keys()
        ), "mel_bins is not specified in the Autoencoder config"
        num_mel = ddconfig["mel_bins"]
        self.image_key = image_key
        self.sampling_rate = sampling_rate
        self.encoder = Encoder(**ddconfig)
        self.decoder = Decoder(**ddconfig)
        self.loss = None
        self.subband = int(subband)

        if self.subband > 1:
            print("Use subband decomposition %s" % self.subband)

        assert ddconfig["double_z"]
        self.quant_conv = torch.nn.Conv2d(2 * ddconfig["z_channels"], 2 * embed_dim, 1)
        self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)

        if self.image_key == "fbank":
            self.vocoder = get_vocoder(None, torch.device("cuda"), num_mel)
        self.embed_dim = embed_dim
        if colorize_nlabels is not None:
            assert type(colorize_nlabels) == int
            self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
        if monitor is not None:
            self.monitor = monitor
        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
        self.learning_rate = float(base_learning_rate)
        # print("Initial learning rate %s" % self.learning_rate)

        self.time_shuffle = time_shuffle
        self.reload_from_ckpt = reload_from_ckpt
        self.reloaded = False
        self.mean, self.std = None, None

        self.feature_cache = None
        self.flag_first_run = True
        self.train_step = 0

        self.logger_save_dir = None
        self.logger_exp_name = None

        self.scale_factor = scale_factor

        print("Num parameters:")
        print("Encoder : ", sum(p.numel() for p in self.encoder.parameters()))
        print("Decoder : ", sum(p.numel() for p in self.decoder.parameters()))
        print("Vocoder : ", sum(p.numel() for p in self.vocoder.parameters()))

    def get_log_dir(self):
        if self.logger_save_dir is None and self.logger_exp_name is None:
            return os.path.join(self.logger.save_dir, self.logger._project)
        else:
            return os.path.join(self.logger_save_dir, self.logger_exp_name)

    def set_log_dir(self, save_dir, exp_name):
        self.logger_save_dir = save_dir
        self.logger_exp_name = exp_name

    def init_from_ckpt(self, path, ignore_keys=list()):
        sd = torch.load(path, map_location="cpu")["state_dict"]
        keys = list(sd.keys())
        for k in keys:
            for ik in ignore_keys:
                if k.startswith(ik):
                    print("Deleting key {} from state_dict.".format(k))
                    del sd[k]
        self.load_state_dict(sd, strict=False)
        print(f"Restored from {path}")

    def encode(self, x):
        # x = self.time_shuffle_operation(x)
        # x = self.freq_split_subband(x)
        h = self.encoder(x)
        moments = self.quant_conv(h)
        posterior = DiagonalGaussianDistribution(moments)
        return posterior

    def decode(self, z):
        z = self.post_quant_conv(z)
        dec = self.decoder(z)
        # bs, ch, shuffled_timesteps, fbins = dec.size()
        # dec = self.time_unshuffle_operation(dec, bs, int(ch * shuffled_timesteps), fbins)
        # dec = self.freq_merge_subband(dec)
        return dec

    def decode_to_waveform(self, dec):
        if self.image_key == "fbank":
            dec = dec.squeeze(1).permute(0, 2, 1)
            wav_reconstruction = vocoder_chunk_infer(dec, self.vocoder)
        elif self.image_key == "stft":
            dec = dec.squeeze(1).permute(0, 2, 1)
            wav_reconstruction = self.wave_decoder(dec)
        return wav_reconstruction

    def mel_spectrogram_to_waveform(
        self, mel, savepath=".", bs=None, name="outwav", save=True
    ):
        # Mel: [bs, 1, t-steps, fbins]
        if len(mel.size()) == 4:
            mel = mel.squeeze(1)
        mel = mel.permute(0, 2, 1)
        waveform = self.vocoder(mel)
        waveform = waveform.cpu().detach().numpy()
        # if save:
        #     self.save_waveform(waveform, savepath, name)
        return waveform

    @torch.no_grad()
    def encode_first_stage(self, x):
        return self.encode(x)

    @torch.no_grad()
    def decode_first_stage(self, z, predict_cids=False, force_not_quantize=False):
        if predict_cids:
            if z.dim() == 4:
                z = torch.argmax(z.exp(), dim=1).long()
            z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None)
            z = rearrange(z, "b h w c -> b c h w").contiguous()

        z = 1.0 / self.scale_factor * z
        return self.decode(z)

    def decode_first_stage_withgrad(self, z):
        z = 1.0 / self.scale_factor * z
        return self.decode(z)

    def get_first_stage_encoding(self, encoder_posterior, use_mode=False):
        if isinstance(encoder_posterior, DiagonalGaussianDistribution) and not use_mode:
            z = encoder_posterior.sample()
        elif isinstance(encoder_posterior, DiagonalGaussianDistribution) and use_mode:
            z = encoder_posterior.mode()
        elif isinstance(encoder_posterior, torch.Tensor):
            z = encoder_posterior
        else:
            raise NotImplementedError(
                f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented"
            )
        return self.scale_factor * z

    def visualize_latent(self, input):
        import matplotlib.pyplot as plt

        # for i in range(10):
        #     zero_input = torch.zeros_like(input) - 11.59
        #     zero_input[:, :, i * 16 : i * 16 + 16, :16] += 13.59
        #     posterior = self.encode(zero_input)
        #     latent = posterior.sample()
        #     avg_latent = torch.mean(latent, dim=1)[0]
        #     plt.imshow(avg_latent.cpu().detach().numpy().T)
        #     plt.savefig("%s.png" % i)
        #     plt.close()

        np.save("input.npy", input.cpu().detach().numpy())
        # zero_input = torch.zeros_like(input) - 11.59

        time_input = input.clone()
        time_input[:, :, :, :32] *= 0
        time_input[:, :, :, :32] -= 11.59

        np.save("time_input.npy", time_input.cpu().detach().numpy())

        posterior = self.encode(time_input)
        latent = posterior.sample()
        np.save("time_latent.npy", latent.cpu().detach().numpy())
        avg_latent = torch.mean(latent, dim=1)
        for i in range(avg_latent.size(0)):
            plt.imshow(avg_latent[i].cpu().detach().numpy().T)
            plt.savefig("freq_%s.png" % i)
            plt.close()

        freq_input = input.clone()
        freq_input[:, :, :512, :] *= 0
        freq_input[:, :, :512, :] -= 11.59

        np.save("freq_input.npy", freq_input.cpu().detach().numpy())

        posterior = self.encode(freq_input)
        latent = posterior.sample()
        np.save("freq_latent.npy", latent.cpu().detach().numpy())
        avg_latent = torch.mean(latent, dim=1)
        for i in range(avg_latent.size(0)):
            plt.imshow(avg_latent[i].cpu().detach().numpy().T)
            plt.savefig("time_%s.png" % i)
            plt.close()

    def get_input(self, batch):
        fname, text, label_indices, waveform, stft, fbank = (
            batch["fname"],
            batch["text"],
            batch["label_vector"],
            batch["waveform"],
            batch["stft"],
            batch["log_mel_spec"],
        )
        # if self.time_shuffle != 1:
        #     if fbank.size(1) % self.time_shuffle != 0:
        #         pad_len = self.time_shuffle - (fbank.size(1) % self.time_shuffle)
        #         fbank = torch.nn.functional.pad(fbank, (0, 0, 0, pad_len))

        ret = {}

        ret["fbank"], ret["stft"], ret["fname"], ret["waveform"] = (
            fbank.unsqueeze(1),
            stft.unsqueeze(1),
            fname,
            waveform.unsqueeze(1),
        )

        return ret

    def save_wave(self, batch_wav, fname, save_dir):
        os.makedirs(save_dir, exist_ok=True)
        for wav, name in zip(batch_wav, fname):
            name = os.path.basename(name)
            sf.write(os.path.join(save_dir, name), wav, samplerate=self.sampling_rate)

    def get_last_layer(self):
        return self.decoder.conv_out.weight

    @torch.no_grad()
    def log_images(self, batch, train=True, only_inputs=False, waveform=None, **kwargs):
        log = dict()
        x = batch.to(self.device)
        if not only_inputs:
            xrec, posterior = self(x)
            log["samples"] = self.decode(posterior.sample())
            log["reconstructions"] = xrec
        log["inputs"] = x
        wavs = self._log_img(log, train=train, index=0, waveform=waveform)
        return wavs

    def _log_img(self, log, train=True, index=0, waveform=None):
        images_input = self.tensor2numpy(log["inputs"][index, 0]).T
        images_reconstruct = self.tensor2numpy(log["reconstructions"][index, 0]).T
        images_samples = self.tensor2numpy(log["samples"][index, 0]).T

        if train:
            name = "train"
        else:
            name = "val"

        if self.logger is not None:
            self.logger.log_image(
                "img_%s" % name,
                [images_input, images_reconstruct, images_samples],
                caption=["input", "reconstruct", "samples"],
            )

        inputs, reconstructions, samples = (
            log["inputs"],
            log["reconstructions"],
            log["samples"],
        )

        if self.image_key == "fbank":
            wav_original, wav_prediction = synth_one_sample(
                inputs[index],
                reconstructions[index],
                labels="validation",
                vocoder=self.vocoder,
            )
            wav_original, wav_samples = synth_one_sample(
                inputs[index], samples[index], labels="validation", vocoder=self.vocoder
            )
            wav_original, wav_samples, wav_prediction = (
                wav_original[0],
                wav_samples[0],
                wav_prediction[0],
            )
        elif self.image_key == "stft":
            wav_prediction = (
                self.decode_to_waveform(reconstructions)[index, 0].cpu().detach().numpy()
            )
            wav_samples = (
                self.decode_to_waveform(samples)[index, 0].cpu().detach().numpy()
            )
            wav_original = waveform[index, 0].cpu().detach().numpy()
        if self.logger is not None:
            # NOTE: `wandb` is assumed to be importable here; this module does not
            # import it at the top level, so audio logging only works when wandb is
            # installed and a wandb logger is attached.
            self.logger.experiment.log(
                {
                    "original_%s" % name: wandb.Audio(
                        wav_original, caption="original", sample_rate=self.sampling_rate
                    ),
                    "reconstruct_%s" % name: wandb.Audio(
                        wav_prediction,
                        caption="reconstruct",
                        sample_rate=self.sampling_rate,
                    ),
                    "samples_%s" % name: wandb.Audio(
                        wav_samples, caption="samples", sample_rate=self.sampling_rate
                    ),
                }
            )

        return wav_original, wav_prediction, wav_samples

    def tensor2numpy(self, tensor):
        return tensor.cpu().detach().numpy()

    def to_rgb(self, x):
        assert self.image_key == "segmentation"
        if not hasattr(self, "colorize"):
            self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
        x = torch.nn.functional.conv2d(x, weight=self.colorize)
        x = 2.0 * (x - x.min()) / (x.max() - x.min()) - 1.0
        return x


class IdentityFirstStage(torch.nn.Module):
    def __init__(self, *args, vq_interface=False, **kwargs):
        self.vq_interface = vq_interface  # TODO: should be True by default, but check to not break older stuff
        super().__init__()

    def encode(self, x, *args, **kwargs):
        return x

    def decode(self, x, *args, **kwargs):
        return x

    def quantize(self, x, *args, **kwargs):
        if self.vq_interface:
            return x, None, [None, None, None]
        return x

    def forward(self, x, *args, **kwargs):
        return x


def window_sumsquare(
    window,
    n_frames,
    hop_length,
    win_length,
    n_fft,
    dtype=np.float32,
    norm=None,
):
    """
    # from librosa 0.6
    Compute the sum-square envelope of a window function at a given hop length.

    This is used to estimate modulation effects induced by windowing
    observations in short-time Fourier transforms.

    Parameters
    ----------
    window : string, tuple, number, callable, or list-like
        Window specification, as in `get_window`

    n_frames : int > 0
        The number of analysis frames

    hop_length : int > 0
        The number of samples to advance between frames

    win_length : [optional]
        The length of the window function. By default, this matches `n_fft`.

    n_fft : int > 0
        The length of each analysis frame.
    dtype : np.dtype
        The data type of the output

    Returns
    -------
    wss : np.ndarray, shape=`(n_fft + hop_length * (n_frames - 1))`
        The sum-squared envelope of the window function
    """
    if win_length is None:
        win_length = n_fft

    n = n_fft + hop_length * (n_frames - 1)
    x = np.zeros(n, dtype=dtype)

    # Compute the squared window at the desired length
    win_sq = get_window(window, win_length, fftbins=True)
    win_sq = librosa_util.normalize(win_sq, norm=norm) ** 2
    win_sq = librosa_util.pad_center(win_sq, size=n_fft)  # keyword form required by librosa >= 0.10

    # Fill the envelope
    for i in range(n_frames):
        sample = i * hop_length
        x[sample : min(n, sample + n_fft)] += win_sq[: max(0, min(n_fft, n - sample))]
    return x


def dynamic_range_compression(x, normalize_fun=torch.log, C=1, clip_val=1e-5):
    """
    PARAMS
    ------
    C: compression factor
    """
    return normalize_fun(torch.clamp(x, min=clip_val) * C)


def dynamic_range_decompression(x, C=1):
    """
    PARAMS
    ------
    C: compression factor used to compress
    """
    return torch.exp(x) / C
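# Quick example (illustrative, not in the original): the log compression above
# clips small magnitudes at clip_val=1e-5, so silence maps to log(1e-5) ≈ -11.51
# rather than -inf, which is comparable to the -11.59 fill value used in
# visualize_latent() above.
#
#   >>> dynamic_range_compression(torch.zeros(3))
#   tensor([-11.5129, -11.5129, -11.5129])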
class STFT(torch.nn.Module):
    """adapted from Prem Seetharaman's https://github.com/pseeth/pytorch-stft"""

    def __init__(self, filter_length, hop_length, win_length, window="hann"):
        super(STFT, self).__init__()
        self.filter_length = filter_length
        self.hop_length = hop_length
        self.win_length = win_length
        self.window = window
        self.forward_transform = None
        scale = self.filter_length / self.hop_length
        fourier_basis = np.fft.fft(np.eye(self.filter_length))

        cutoff = int((self.filter_length / 2 + 1))
        fourier_basis = np.vstack(
            [np.real(fourier_basis[:cutoff, :]), np.imag(fourier_basis[:cutoff, :])]
        )

        forward_basis = torch.FloatTensor(fourier_basis[:, None, :])
        inverse_basis = torch.FloatTensor(
            np.linalg.pinv(scale * fourier_basis).T[:, None, :]
        )

        if window is not None:
            assert filter_length >= win_length
            # get window and zero center pad it to filter_length
            fft_window = get_window(window, win_length, fftbins=True)
            fft_window = pad_center(fft_window, size=filter_length)
            fft_window = torch.from_numpy(fft_window).float()

            # window the bases
            forward_basis *= fft_window
            inverse_basis *= fft_window

        self.register_buffer("forward_basis", forward_basis.float())
        self.register_buffer("inverse_basis", inverse_basis.float())

    def transform(self, input_data):
        device = self.forward_basis.device
        input_data = input_data.to(device)

        num_batches = input_data.size(0)
        num_samples = input_data.size(1)

        self.num_samples = num_samples

        # similar to librosa, reflect-pad the input
        input_data = input_data.view(num_batches, 1, num_samples)
        input_data = torch.nn.functional.pad(
            input_data.unsqueeze(1),
            (int(self.filter_length / 2), int(self.filter_length / 2), 0, 0),
            mode="reflect",
        )
        input_data = input_data.squeeze(1)

        forward_transform = torch.nn.functional.conv1d(
            input_data,
            torch.autograd.Variable(self.forward_basis, requires_grad=False),
            stride=self.hop_length,
            padding=0,
        )  # .cpu()

        cutoff = int((self.filter_length / 2) + 1)
        real_part = forward_transform[:, :cutoff, :]
        imag_part = forward_transform[:, cutoff:, :]

        magnitude = torch.sqrt(real_part**2 + imag_part**2)
        phase = torch.autograd.Variable(torch.atan2(imag_part.data, real_part.data))

        return magnitude, phase

    def inverse(self, magnitude, phase):
        device = self.forward_basis.device
        magnitude, phase = magnitude.to(device), phase.to(device)

        recombine_magnitude_phase = torch.cat(
            [magnitude * torch.cos(phase), magnitude * torch.sin(phase)], dim=1
        )

        inverse_transform = torch.nn.functional.conv_transpose1d(
            recombine_magnitude_phase,
            torch.autograd.Variable(self.inverse_basis, requires_grad=False),
            stride=self.hop_length,
            padding=0,
        )

        if self.window is not None:
            window_sum = window_sumsquare(
                self.window,
                magnitude.size(-1),
                hop_length=self.hop_length,
                win_length=self.win_length,
                n_fft=self.filter_length,
                dtype=np.float32,
            )
            # remove modulation effects
            approx_nonzero_indices = torch.from_numpy(
                np.where(window_sum > tiny(window_sum))[0]
            )
            window_sum = torch.autograd.Variable(
                torch.from_numpy(window_sum), requires_grad=False
            )
            # keep the window envelope on the same device as the transform
            window_sum = window_sum.to(inverse_transform.device)
            inverse_transform[:, :, approx_nonzero_indices] /= window_sum[
                approx_nonzero_indices
            ]

            # scale by hop ratio
            inverse_transform *= float(self.filter_length) / self.hop_length

        inverse_transform = inverse_transform[:, :, int(self.filter_length / 2) :]
        inverse_transform = inverse_transform[:, :, : -int(self.filter_length / 2) :]

        return inverse_transform

    def forward(self, input_data):
        self.magnitude, self.phase = self.transform(input_data)
        reconstruction = self.inverse(self.magnitude, self.phase)
        return reconstruction


class TacotronSTFT(torch.nn.Module):
    def __init__(
        self,
        filter_length,
        hop_length,
        win_length,
        n_mel_channels,
        sampling_rate,
        mel_fmin,
        mel_fmax,
    ):
        super(TacotronSTFT, self).__init__()
        self.n_mel_channels = n_mel_channels
        self.sampling_rate = sampling_rate
        self.stft_fn = STFT(filter_length, hop_length, win_length)
        mel_basis = librosa_mel_fn(
            sr=sampling_rate,
            n_fft=filter_length,
            n_mels=n_mel_channels,
            fmin=mel_fmin,
            fmax=mel_fmax,
        )
        mel_basis = torch.from_numpy(mel_basis).float()
        self.register_buffer("mel_basis", mel_basis)

    def spectral_normalize(self, magnitudes, normalize_fun):
        output = dynamic_range_compression(magnitudes, normalize_fun)
        return output

    def spectral_de_normalize(self, magnitudes):
        output = dynamic_range_decompression(magnitudes)
        return output

    def mel_spectrogram(self, y, normalize_fun=torch.log):
        """Computes mel-spectrograms from a batch of waves
        PARAMS
        ------
        y: Variable(torch.FloatTensor) with shape (B, T) in range [-1, 1]

        RETURNS
        -------
        mel_output: torch.FloatTensor of shape (B, n_mel_channels, T)
        """
        assert torch.min(y.data) >= -1, torch.min(y.data)
        assert torch.max(y.data) <= 1, torch.max(y.data)

        magnitudes, phases = self.stft_fn.transform(y)
        magnitudes = magnitudes.data
        mel_output = torch.matmul(self.mel_basis, magnitudes)
        mel_output = self.spectral_normalize(mel_output, normalize_fun)
        energy = torch.norm(magnitudes, dim=1)

        log_magnitudes = self.spectral_normalize(magnitudes, normalize_fun)

        return mel_output, log_magnitudes, energy
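# Usage sketch (illustrative): a 256-bin mel front end matching the 48 kHz
# settings used by get_vocoder_config_48k() and build_pretrained_models() below.
#
#   fn_STFT = TacotronSTFT(2048, 480, 2048, 256, 48000, 20, 24000)
#   wav = torch.randn(2, 48000).clamp(-1, 1)        # 1 s of audio in [-1, 1]
#   mel, log_mag, energy = fn_STFT.mel_spectrogram(wav)
#   # mel: (2, 256, ~100) -- roughly one frame per hop_length=480 samples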
def build_pretrained_models(ckpt):
    checkpoint = torch.load(ckpt, map_location="cpu")
    scale_factor = checkpoint["state_dict"]["scale_factor"].item()
    print("scale_factor: ", scale_factor)
    vae_state_dict = {
        k[18:]: v
        for k, v in checkpoint["state_dict"].items()
        if "first_stage_model." in k
    }

    config = {
        "preprocessing": {
            "audio": {"sampling_rate": 48000, "max_wav_value": 32768, "duration": 10.24},
            "stft": {"filter_length": 2048, "hop_length": 480, "win_length": 2048},
            "mel": {"n_mel_channels": 256, "mel_fmin": 20, "mel_fmax": 24000},
        },
        "model": {
            "params": {
                "first_stage_config": {
                    "params": {
                        "sampling_rate": 48000,
                        "batchsize": 4,
                        "monitor": "val/rec_loss",
                        "image_key": "fbank",
                        "subband": 1,
                        "embed_dim": 16,
                        "time_shuffle": 1,
                        "lossconfig": {
                            "target": "audioldm2.latent_diffusion.modules.losses.LPIPSWithDiscriminator",
                            "params": {
                                "disc_start": 50001,
                                "kl_weight": 1000,
                                "disc_weight": 0.5,
                                "disc_in_channels": 1,
                            },
                        },
                        "ddconfig": {
                            "double_z": True,
                            "mel_bins": 256,
                            "z_channels": 16,
                            "resolution": 256,
                            "downsample_time": False,
                            "in_channels": 1,
                            "out_ch": 1,
                            "ch": 128,
                            "ch_mult": [1, 2, 4, 8],
                            "num_res_blocks": 2,
                            "attn_resolutions": [],
                            "dropout": 0,
                        },
                    }
                },
            }
        },
    }

    vae_config = config["model"]["params"]["first_stage_config"]["params"]
    vae_config["scale_factor"] = scale_factor

    vae = AutoencoderKL(**vae_config)
    vae.load_state_dict(vae_state_dict)

    fn_STFT = TacotronSTFT(
        config["preprocessing"]["stft"]["filter_length"],
        config["preprocessing"]["stft"]["hop_length"],
        config["preprocessing"]["stft"]["win_length"],
        config["preprocessing"]["mel"]["n_mel_channels"],
        config["preprocessing"]["audio"]["sampling_rate"],
        config["preprocessing"]["mel"]["mel_fmin"],
        config["preprocessing"]["mel"]["mel_fmax"],
    )

    vae.eval()
    fn_STFT.eval()
    return vae, fn_STFT


if __name__ == "__main__":
    # NOTE: build_pretrained_models() requires the path to a trained checkpoint;
    # the original call omitted the argument, so fill in your .ckpt path before running.
    vae, stft = build_pretrained_models()
    vae, stft = vae.cuda(), stft.cuda()

    json_file = "outputs/wav.scp"
    out_path = "outputs/Music_inverse"

    wavform = torch.randn(2, int(48000 * 10.24))
    mel, _, waveform = torch_tools.wav_to_fbank2(wavform, target_length=-1, fn_STFT=stft)
    mel = mel.unsqueeze(1).cuda()
    print(mel.shape)
    # true_latent = torch.cat(
    #     [vae.get_first_stage_encoding(vae.encode_first_stage(mel[[m]])) for m in range(mel.shape[0])], 0
    # )
    # print(true_latent.shape)
    true_latent = vae.get_first_stage_encoding(vae.encode_first_stage(mel))
    print(true_latent.shape)
    true_latent = true_latent.reshape(
        true_latent.shape[0] // 2, -1, true_latent.shape[2], true_latent.shape[3]
    ).detach()
    true_latent = true_latent.reshape(
        true_latent.shape[0] * 2, -1, true_latent.shape[2], true_latent.shape[3]
    )
    print("111", true_latent.size())
    mel = vae.decode_first_stage(true_latent)
    print("222", mel.size())
    audio = vae.decode_to_waveform(mel)
    print("333", audio.shape)
    # out_file = out_path + "/" + os.path.basename(fname.strip())
    # sf.write(out_file, audio[0], samplerate=48000)