import soundfile as sf
import os
from librosa.filters import mel as librosa_mel_fn
import sys
import tools.torch_tools as torch_tools
import torch.nn as nn
import torch
import numpy as np
from einops import rearrange
from scipy.signal import get_window
from librosa.util import pad_center, tiny
import librosa.util as librosa_util

try:
    import wandb  # used by AutoencoderKL._log_img for audio logging
except ImportError:  # wandb is optional here; _log_img only logs when a logger is attached
    wandb = None
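# This file bundles the inference-time audio stack used by the application:
#   * a HiFi-GAN style vocoder (ResBlock / Generator_old) mapping 256-bin mel
#     spectrograms to 48 kHz waveforms,
#   * a KL-regularized convolutional autoencoder (Encoder / Decoder /
#     AutoencoderKL) over mel-spectrogram "images", and
#   * STFT / mel front-ends (STFT, TacotronSTFT) adapted from pytorch-stft.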
class AttrDict(dict):
def __init__(self, *args, **kwargs):
super(AttrDict, self).__init__(*args, **kwargs)
self.__dict__ = self
def init_weights(m, mean=0.0, std=0.01):
classname = m.__class__.__name__
if classname.find("Conv") != -1:
m.weight.data.normal_(mean, std)
def get_padding(kernel_size, dilation=1):
return int((kernel_size * dilation - dilation) / 2)
LRELU_SLOPE = 0.1
class ResBlock(torch.nn.Module):
def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)):
super(ResBlock, self).__init__()
self.h = h
self.convs1 = nn.ModuleList(
[
torch.nn.utils.weight_norm(
nn.Conv1d(
channels,
channels,
kernel_size,
1,
dilation=dilation[0],
padding=get_padding(kernel_size, dilation[0]),
)
),
torch.nn.utils.weight_norm(
nn.Conv1d(
channels,
channels,
kernel_size,
1,
dilation=dilation[1],
padding=get_padding(kernel_size, dilation[1]),
)
),
torch.nn.utils.weight_norm(
nn.Conv1d(
channels,
channels,
kernel_size,
1,
dilation=dilation[2],
padding=get_padding(kernel_size, dilation[2]),
)
),
]
)
self.convs1.apply(init_weights)
self.convs2 = nn.ModuleList(
[
torch.nn.utils.weight_norm(
nn.Conv1d(
channels,
channels,
kernel_size,
1,
dilation=1,
padding=get_padding(kernel_size, 1),
)
),
torch.nn.utils.weight_norm(
nn.Conv1d(
channels,
channels,
kernel_size,
1,
dilation=1,
padding=get_padding(kernel_size, 1),
)
),
torch.nn.utils.weight_norm(
nn.Conv1d(
channels,
channels,
kernel_size,
1,
dilation=1,
padding=get_padding(kernel_size, 1),
)
),
]
)
self.convs2.apply(init_weights)
def forward(self, x):
for c1, c2 in zip(self.convs1, self.convs2):
xt = torch.nn.functional.leaky_relu(x, LRELU_SLOPE)
xt = c1(xt)
xt = torch.nn.functional.leaky_relu(xt, LRELU_SLOPE)
xt = c2(xt)
x = xt + x
return x
def remove_weight_norm(self):
for l in self.convs1:
torch.nn.utils.remove_weight_norm(l)
for l in self.convs2:
torch.nn.utils.remove_weight_norm(l)
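# HiFi-GAN style generator: a pre-conv, weight-normalized transposed convolutions
# for upsampling, a bank of ResBlocks per stage whose outputs are averaged
# (multi-receptive-field fusion), and a tanh-bounded single-channel output.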
class Generator_old(torch.nn.Module):
def __init__(self, h):
super(Generator_old, self).__init__()
self.h = h
self.num_kernels = len(h.resblock_kernel_sizes)
self.num_upsamples = len(h.upsample_rates)
self.conv_pre = torch.nn.utils.weight_norm(
nn.Conv1d(h.num_mels, h.upsample_initial_channel, 7, 1, padding=3)
)
resblock = ResBlock
self.ups = nn.ModuleList()
for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):
self.ups.append(
torch.nn.utils.weight_norm(
nn.ConvTranspose1d(
h.upsample_initial_channel // (2**i),
h.upsample_initial_channel // (2 ** (i + 1)),
k,
u,
padding=(k - u) // 2,
)
)
)
self.resblocks = nn.ModuleList()
for i in range(len(self.ups)):
ch = h.upsample_initial_channel // (2 ** (i + 1))
for j, (k, d) in enumerate(
zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)
):
self.resblocks.append(resblock(h, ch, k, d))
self.conv_post = torch.nn.utils.weight_norm(nn.Conv1d(ch, 1, 7, 1, padding=3))
self.ups.apply(init_weights)
self.conv_post.apply(init_weights)
def forward(self, x):
x = self.conv_pre(x)
for i in range(self.num_upsamples):
x = torch.nn.functional.leaky_relu(x, LRELU_SLOPE)
x = self.ups[i](x)
xs = None
for j in range(self.num_kernels):
if xs is None:
xs = self.resblocks[i * self.num_kernels + j](x)
else:
xs += self.resblocks[i * self.num_kernels + j](x)
x = xs / self.num_kernels
x = torch.nn.functional.leaky_relu(x)
x = self.conv_post(x)
x = torch.tanh(x)
return x
def remove_weight_norm(self):
# print("Removing weight norm...")
for l in self.ups:
torch.nn.utils.remove_weight_norm(l)
for l in self.resblocks:
l.remove_weight_norm()
torch.nn.utils.remove_weight_norm(self.conv_pre)
torch.nn.utils.remove_weight_norm(self.conv_post)
def nonlinearity(x):
# swish
return x * torch.sigmoid(x)
def Normalize(in_channels, num_groups=32):
return torch.nn.GroupNorm(
num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True
)
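# The blocks below (Downsample/Upsample and their *TimeStride4 variants, AttnBlock,
# ResnetBlock) form a latent-diffusion style VAE backbone. The TimeStride4 variants
# use stride (4, 2): 4x along the time axis, 2x along frequency.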
class Downsample(nn.Module):
def __init__(self, in_channels, with_conv):
super().__init__()
self.with_conv = with_conv
if self.with_conv:
# Do time downsampling here
# no asymmetric padding in torch conv, must do it ourselves
self.conv = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=3, stride=2, padding=0
)
def forward(self, x):
if self.with_conv:
pad = (0, 1, 0, 1)
x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
x = self.conv(x)
else:
x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
return x
class DownsampleTimeStride4(nn.Module):
def __init__(self, in_channels, with_conv):
super().__init__()
self.with_conv = with_conv
if self.with_conv:
# Do time downsampling here
# no asymmetric padding in torch conv, must do it ourselves
self.conv = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=5, stride=(4, 2), padding=1
)
def forward(self, x):
if self.with_conv:
pad = (0, 1, 0, 1)
x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
x = self.conv(x)
else:
x = torch.nn.functional.avg_pool2d(x, kernel_size=(4, 2), stride=(4, 2))
return x
class Upsample(nn.Module):
def __init__(self, in_channels, with_conv):
super().__init__()
self.with_conv = with_conv
if self.with_conv:
self.conv = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=3, stride=1, padding=1
)
def forward(self, x):
x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
if self.with_conv:
x = self.conv(x)
return x
class UpsampleTimeStride4(nn.Module):
def __init__(self, in_channels, with_conv):
super().__init__()
self.with_conv = with_conv
if self.with_conv:
self.conv = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=5, stride=1, padding=2
)
def forward(self, x):
x = torch.nn.functional.interpolate(x, scale_factor=(4.0, 2.0), mode="nearest")
if self.with_conv:
x = self.conv(x)
return x
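# Single-head, non-local self-attention over all time-frequency positions:
# 1x1 convolutions produce q/k/v, weights are softmax(q k^T / sqrt(c)), and a
# 1x1 output projection feeds a residual connection.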
class AttnBlock(nn.Module):
def __init__(self, in_channels):
super().__init__()
self.in_channels = in_channels
self.norm = Normalize(in_channels)
self.q = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=1, stride=1, padding=0
)
self.k = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=1, stride=1, padding=0
)
self.v = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=1, stride=1, padding=0
)
self.proj_out = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=1, stride=1, padding=0
)
def forward(self, x):
h_ = x
h_ = self.norm(h_)
q = self.q(h_)
k = self.k(h_)
v = self.v(h_)
# compute attention
b, c, h, w = q.shape
q = q.reshape(b, c, h * w).contiguous()
q = q.permute(0, 2, 1).contiguous() # b,hw,c
k = k.reshape(b, c, h * w).contiguous() # b,c,hw
w_ = torch.bmm(q, k).contiguous() # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
w_ = w_ * (int(c) ** (-0.5))
w_ = torch.nn.functional.softmax(w_, dim=2)
# attend to values
v = v.reshape(b, c, h * w).contiguous()
w_ = w_.permute(0, 2, 1).contiguous() # b,hw,hw (first hw of k, second of q)
h_ = torch.bmm(
v, w_
).contiguous() # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
h_ = h_.reshape(b, c, h, w).contiguous()
h_ = self.proj_out(h_)
return x + h_
def make_attn(in_channels, attn_type="vanilla"):
assert attn_type in ["vanilla", "linear", "none"], f"attn_type {attn_type} unknown"
# print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
if attn_type == "vanilla":
return AttnBlock(in_channels)
elif attn_type == "none":
return nn.Identity(in_channels)
    else:
        # "linear" passes the assert above but has no implementation in this file
        raise ValueError(attn_type)
class ResnetBlock(nn.Module):
def __init__(
self,
*,
in_channels,
out_channels=None,
conv_shortcut=False,
dropout,
temb_channels=512,
):
super().__init__()
self.in_channels = in_channels
out_channels = in_channels if out_channels is None else out_channels
self.out_channels = out_channels
self.use_conv_shortcut = conv_shortcut
self.norm1 = Normalize(in_channels)
self.conv1 = torch.nn.Conv2d(
in_channels, out_channels, kernel_size=3, stride=1, padding=1
)
if temb_channels > 0:
self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
self.norm2 = Normalize(out_channels)
self.dropout = torch.nn.Dropout(dropout)
self.conv2 = torch.nn.Conv2d(
out_channels, out_channels, kernel_size=3, stride=1, padding=1
)
if self.in_channels != self.out_channels:
if self.use_conv_shortcut:
self.conv_shortcut = torch.nn.Conv2d(
in_channels, out_channels, kernel_size=3, stride=1, padding=1
)
else:
self.nin_shortcut = torch.nn.Conv2d(
in_channels, out_channels, kernel_size=1, stride=1, padding=0
)
def forward(self, x, temb):
h = x
h = self.norm1(h)
h = nonlinearity(h)
h = self.conv1(h)
if temb is not None:
h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]
h = self.norm2(h)
h = nonlinearity(h)
h = self.dropout(h)
h = self.conv2(h)
if self.in_channels != self.out_channels:
if self.use_conv_shortcut:
x = self.conv_shortcut(x)
else:
x = self.nin_shortcut(x)
return x + h
class Encoder(nn.Module):
def __init__(
self,
*,
ch,
out_ch,
ch_mult=(1, 2, 4, 8),
num_res_blocks,
attn_resolutions,
dropout=0.0,
resamp_with_conv=True,
in_channels,
resolution,
z_channels,
double_z=True,
use_linear_attn=False,
attn_type="vanilla",
downsample_time_stride4_levels=[],
**ignore_kwargs,
):
super().__init__()
if use_linear_attn:
attn_type = "linear"
self.ch = ch
self.temb_ch = 0
self.num_resolutions = len(ch_mult)
self.num_res_blocks = num_res_blocks
self.resolution = resolution
self.in_channels = in_channels
self.downsample_time_stride4_levels = downsample_time_stride4_levels
if len(self.downsample_time_stride4_levels) > 0:
            assert max(self.downsample_time_stride4_levels) < self.num_resolutions, (
                "The levels at which stride-4 time downsampling is applied must be "
                "smaller than the total number of resolutions %s"
                % str(self.num_resolutions)
            )
# downsampling
self.conv_in = torch.nn.Conv2d(
in_channels, self.ch, kernel_size=3, stride=1, padding=1
)
curr_res = resolution
in_ch_mult = (1,) + tuple(ch_mult)
self.in_ch_mult = in_ch_mult
self.down = nn.ModuleList()
for i_level in range(self.num_resolutions):
block = nn.ModuleList()
attn = nn.ModuleList()
block_in = ch * in_ch_mult[i_level]
block_out = ch * ch_mult[i_level]
for i_block in range(self.num_res_blocks):
block.append(
ResnetBlock(
in_channels=block_in,
out_channels=block_out,
temb_channels=self.temb_ch,
dropout=dropout,
)
)
block_in = block_out
if curr_res in attn_resolutions:
attn.append(make_attn(block_in, attn_type=attn_type))
down = nn.Module()
down.block = block
down.attn = attn
if i_level != self.num_resolutions - 1:
if i_level in self.downsample_time_stride4_levels:
down.downsample = DownsampleTimeStride4(block_in, resamp_with_conv)
else:
down.downsample = Downsample(block_in, resamp_with_conv)
curr_res = curr_res // 2
self.down.append(down)
# middle
self.mid = nn.Module()
self.mid.block_1 = ResnetBlock(
in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
dropout=dropout,
)
self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
self.mid.block_2 = ResnetBlock(
in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
dropout=dropout,
)
# end
self.norm_out = Normalize(block_in)
self.conv_out = torch.nn.Conv2d(
block_in,
2 * z_channels if double_z else z_channels,
kernel_size=3,
stride=1,
padding=1,
)
def forward(self, x):
# timestep embedding
temb = None
# downsampling
hs = [self.conv_in(x)]
for i_level in range(self.num_resolutions):
for i_block in range(self.num_res_blocks):
h = self.down[i_level].block[i_block](hs[-1], temb)
if len(self.down[i_level].attn) > 0:
h = self.down[i_level].attn[i_block](h)
hs.append(h)
if i_level != self.num_resolutions - 1:
hs.append(self.down[i_level].downsample(hs[-1]))
# middle
h = hs[-1]
h = self.mid.block_1(h, temb)
h = self.mid.attn_1(h)
h = self.mid.block_2(h, temb)
# end
h = self.norm_out(h)
h = nonlinearity(h)
h = self.conv_out(h)
return h
class Decoder(nn.Module):
def __init__(
self,
*,
ch,
out_ch,
ch_mult=(1, 2, 4, 8),
num_res_blocks,
attn_resolutions,
dropout=0.0,
resamp_with_conv=True,
in_channels,
resolution,
z_channels,
give_pre_end=False,
tanh_out=False,
use_linear_attn=False,
downsample_time_stride4_levels=[],
attn_type="vanilla",
**ignorekwargs,
):
super().__init__()
if use_linear_attn:
attn_type = "linear"
self.ch = ch
self.temb_ch = 0
self.num_resolutions = len(ch_mult)
self.num_res_blocks = num_res_blocks
self.resolution = resolution
self.in_channels = in_channels
self.give_pre_end = give_pre_end
self.tanh_out = tanh_out
self.downsample_time_stride4_levels = downsample_time_stride4_levels
if len(self.downsample_time_stride4_levels) > 0:
            assert max(self.downsample_time_stride4_levels) < self.num_resolutions, (
                "The levels at which stride-4 time downsampling is applied must be "
                "smaller than the total number of resolutions %s"
                % str(self.num_resolutions)
            )
        # compute block_in and curr_res at the lowest resolution
block_in = ch * ch_mult[self.num_resolutions - 1]
curr_res = resolution // 2 ** (self.num_resolutions - 1)
self.z_shape = (1, z_channels, curr_res, curr_res)
# print(
# "Working with z of shape {} = {} dimensions.".format(
# self.z_shape, np.prod(self.z_shape)
# )
# )
# z to block_in
self.conv_in = torch.nn.Conv2d(
z_channels, block_in, kernel_size=3, stride=1, padding=1
)
# middle
self.mid = nn.Module()
self.mid.block_1 = ResnetBlock(
in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
dropout=dropout,
)
self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
self.mid.block_2 = ResnetBlock(
in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
dropout=dropout,
)
# upsampling
self.up = nn.ModuleList()
for i_level in reversed(range(self.num_resolutions)):
block = nn.ModuleList()
attn = nn.ModuleList()
block_out = ch * ch_mult[i_level]
for i_block in range(self.num_res_blocks + 1):
block.append(
ResnetBlock(
in_channels=block_in,
out_channels=block_out,
temb_channels=self.temb_ch,
dropout=dropout,
)
)
block_in = block_out
if curr_res in attn_resolutions:
attn.append(make_attn(block_in, attn_type=attn_type))
up = nn.Module()
up.block = block
up.attn = attn
if i_level != 0:
if i_level - 1 in self.downsample_time_stride4_levels:
up.upsample = UpsampleTimeStride4(block_in, resamp_with_conv)
else:
up.upsample = Upsample(block_in, resamp_with_conv)
curr_res = curr_res * 2
self.up.insert(0, up) # prepend to get consistent order
# end
self.norm_out = Normalize(block_in)
self.conv_out = torch.nn.Conv2d(
block_in, out_ch, kernel_size=3, stride=1, padding=1
)
def forward(self, z):
# assert z.shape[1:] == self.z_shape[1:]
self.last_z_shape = z.shape
# timestep embedding
temb = None
# z to block_in
h = self.conv_in(z)
# middle
h = self.mid.block_1(h, temb)
h = self.mid.attn_1(h)
h = self.mid.block_2(h, temb)
# upsampling
for i_level in reversed(range(self.num_resolutions)):
for i_block in range(self.num_res_blocks + 1):
h = self.up[i_level].block[i_block](h, temb)
if len(self.up[i_level].attn) > 0:
h = self.up[i_level].attn[i_block](h)
if i_level != 0:
h = self.up[i_level].upsample(h)
# end
if self.give_pre_end:
return h
h = self.norm_out(h)
h = nonlinearity(h)
h = self.conv_out(h)
if self.tanh_out:
h = torch.tanh(h)
return h
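# Diagonal Gaussian over latents, parameterized by channel-wise concatenated
# (mean, logvar); logvar is clamped to [-30, 20] for numerical stability.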
class DiagonalGaussianDistribution(object):
def __init__(self, parameters, deterministic=False):
self.parameters = parameters
self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
self.deterministic = deterministic
self.std = torch.exp(0.5 * self.logvar)
self.var = torch.exp(self.logvar)
if self.deterministic:
self.var = self.std = torch.zeros_like(self.mean).to(
device=self.parameters.device
)
def sample(self):
x = self.mean + self.std * torch.randn(self.mean.shape).to(
device=self.parameters.device
)
return x
def kl(self, other=None):
if self.deterministic:
return torch.Tensor([0.0])
else:
if other is None:
return 0.5 * torch.mean(
torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
dim=[1, 2, 3],
)
else:
return 0.5 * torch.mean(
torch.pow(self.mean - other.mean, 2) / other.var
+ self.var / other.var
- 1.0
- self.logvar
+ other.logvar,
dim=[1, 2, 3],
)
def nll(self, sample, dims=[1, 2, 3]):
if self.deterministic:
return torch.Tensor([0.0])
logtwopi = np.log(2.0 * np.pi)
return 0.5 * torch.sum(
logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
dim=dims,
)
def mode(self):
return self.mean
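# Hard-coded HiFi-GAN config for the 48 kHz / 256-mel vocoder. The upsample rates
# multiply out to 6 * 5 * 4 * 2 * 2 = 480 = hop_size, i.e. one mel frame per
# 480 samples (100 frames per second at 48 kHz).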
def get_vocoder_config_48k():
return {
"resblock": "1",
"num_gpus": 8,
"batch_size": 128,
"learning_rate": 0.0001,
"adam_b1": 0.8,
"adam_b2": 0.99,
"lr_decay": 0.999,
"seed": 1234,
"upsample_rates": [6,5,4,2,2],
"upsample_kernel_sizes": [12,10,8,4,4],
"upsample_initial_channel": 1536,
"resblock_kernel_sizes": [3,7,11,15],
"resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5], [1,3,5]],
"segment_size": 15360,
"num_mels": 256,
"n_fft": 2048,
"hop_size": 480,
"win_size": 2048,
"sampling_rate": 48000,
"fmin": 20,
"fmax": 24000,
"fmax_for_loss": None,
"num_workers": 8,
"dist_config": {
"dist_backend": "nccl",
"dist_url": "tcp://localhost:18273",
"world_size": 1
}
}
def get_vocoder(config, device, mel_bins):
name = "HiFi-GAN"
speaker = ""
if name == "MelGAN":
if speaker == "LJSpeech":
vocoder = torch.hub.load(
"descriptinc/melgan-neurips", "load_melgan", "linda_johnson"
)
elif speaker == "universal":
vocoder = torch.hub.load(
"descriptinc/melgan-neurips", "load_melgan", "multi_speaker"
)
vocoder.mel2wav.eval()
vocoder.mel2wav.to(device)
elif name == "HiFi-GAN":
        if mel_bins == 256:
config = get_vocoder_config_48k()
config = AttrDict(config)
vocoder = Generator_old(config)
# print("Load hifigan/g_01080000")
# ckpt = torch.load(os.path.join(ROOT, "hifigan/g_01080000"))
# ckpt = torch.load(os.path.join(ROOT, "hifigan/g_00660000"))
# ckpt = torch_version_orig_mod_remove(ckpt)
# vocoder.load_state_dict(ckpt["generator"])
vocoder.eval()
vocoder.remove_weight_norm()
vocoder = vocoder.to(device)
# vocoder = vocoder.half()
else:
raise ValueError(mel_bins)
return vocoder
def vocoder_infer(mels, vocoder, lengths=None):
with torch.no_grad():
wavs = vocoder(mels).squeeze(1)
        # wavs = (wavs.cpu().numpy() * 32768).astype("int16")
        wavs = wavs.cpu().numpy()
if lengths is not None:
wavs = wavs[:, :lengths]
# wavs = [wav for wav in wavs]
# for i in range(len(mels)):
# if lengths is not None:
# wavs[i] = wavs[i][: lengths[i]]
return wavs
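# Chunked vocoding with overlap-add: the mel is split into chunk_size-frame
# windows advanced by shift_size frames, and overlapping waveform regions are
# cross-faded with a linear ramp to avoid clicks at chunk boundaries.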
@torch.no_grad()
def vocoder_chunk_infer(mels, vocoder, lengths=None):
    chunk_size = 256 * 4  # mel frames per vocoder call
    shift_size = 256 * 1  # hop between consecutive chunks, in mel frames
    ov_size = chunk_size - shift_size  # overlapping frames that are cross-faded
    for cinx in range(0, mels.shape[2], shift_size):
        if cinx == 0:
            wavs = vocoder(mels[:, :, cinx : cinx + chunk_size]).squeeze(1).float()
            # keep a whole number of chunk-aligned samples so later chunks line up
            num_samples = int(wavs.shape[-1] / chunk_size) * chunk_size
            wavs = wavs[:, 0:num_samples]
            ov_sample = int(float(wavs.shape[-1]) * ov_size / chunk_size)
            # linear cross-fade: ascending ramp for the incoming chunk, descending
            # ramp for the tail of the accumulated waveform
            ov_win = torch.linspace(0, 1, ov_sample, device=mels.device).unsqueeze(0)
            ov_win = torch.cat([ov_win, 1 - ov_win], -1)
            if cinx + chunk_size >= mels.shape[2]:
                break
        else:
            cur_wav = (
                vocoder(mels[:, :, cinx : cinx + chunk_size])
                .squeeze(1)[:, 0:num_samples]
                .float()
            )
            wavs[:, -ov_sample:] = (
                wavs[:, -ov_sample:] * ov_win[:, -ov_sample:]
                + cur_wav[:, 0:ov_sample] * ov_win[:, 0:ov_sample]
            )
            wavs = torch.cat([wavs, cur_wav[:, ov_sample:]], -1)
            if cinx + chunk_size >= mels.shape[2]:
                break
    wavs = wavs.cpu().numpy()
    if lengths is not None:
        wavs = wavs[:, :lengths]
return wavs
def synth_one_sample(mel_input, mel_prediction, labels, vocoder):
if vocoder is not None:
wav_reconstruction = vocoder_infer(
mel_input.permute(0, 2, 1),
vocoder,
)
wav_prediction = vocoder_infer(
mel_prediction.permute(0, 2, 1),
vocoder,
)
else:
wav_reconstruction = wav_prediction = None
return wav_reconstruction, wav_prediction
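# KL-regularized mel-spectrogram autoencoder. encode() returns a
# DiagonalGaussianDistribution; get_first_stage_encoding() scales latents by
# scale_factor, and decode_first_stage() undoes that scaling before decoding.
# decode_to_waveform() vocodes decoded mels with the chunked HiFi-GAN above.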
class AutoencoderKL(nn.Module):
def __init__(
self,
ddconfig=None,
lossconfig=None,
batchsize=None,
embed_dim=None,
time_shuffle=1,
subband=1,
sampling_rate=16000,
ckpt_path=None,
reload_from_ckpt=None,
ignore_keys=[],
image_key="fbank",
colorize_nlabels=None,
monitor=None,
base_learning_rate=1e-5,
scale_factor=1
):
super().__init__()
self.automatic_optimization = False
assert (
"mel_bins" in ddconfig.keys()
), "mel_bins is not specified in the Autoencoder config"
num_mel = ddconfig["mel_bins"]
self.image_key = image_key
self.sampling_rate = sampling_rate
self.encoder = Encoder(**ddconfig)
self.decoder = Decoder(**ddconfig)
self.loss = None
self.subband = int(subband)
if self.subband > 1:
print("Use subband decomposition %s" % self.subband)
assert ddconfig["double_z"]
self.quant_conv = torch.nn.Conv2d(2 * ddconfig["z_channels"], 2 * embed_dim, 1)
self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
if self.image_key == "fbank":
self.vocoder = get_vocoder(None, torch.device("cuda"), num_mel)
self.embed_dim = embed_dim
if colorize_nlabels is not None:
assert type(colorize_nlabels) == int
self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
if monitor is not None:
self.monitor = monitor
if ckpt_path is not None:
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
self.learning_rate = float(base_learning_rate)
# print("Initial learning rate %s" % self.learning_rate)
self.time_shuffle = time_shuffle
self.reload_from_ckpt = reload_from_ckpt
self.reloaded = False
self.mean, self.std = None, None
self.feature_cache = None
self.flag_first_run = True
self.train_step = 0
self.logger_save_dir = None
self.logger_exp_name = None
self.scale_factor = scale_factor
print("Num parameters:")
print("Encoder : ", sum(p.numel() for p in self.encoder.parameters()))
print("Decoder : ", sum(p.numel() for p in self.decoder.parameters()))
print("Vocoder : ", sum(p.numel() for p in self.vocoder.parameters()))
def get_log_dir(self):
if self.logger_save_dir is None and self.logger_exp_name is None:
return os.path.join(self.logger.save_dir, self.logger._project)
else:
return os.path.join(self.logger_save_dir, self.logger_exp_name)
def set_log_dir(self, save_dir, exp_name):
self.logger_save_dir = save_dir
self.logger_exp_name = exp_name
def init_from_ckpt(self, path, ignore_keys=list()):
sd = torch.load(path, map_location="cpu")["state_dict"]
keys = list(sd.keys())
for k in keys:
for ik in ignore_keys:
if k.startswith(ik):
print("Deleting key {} from state_dict.".format(k))
del sd[k]
self.load_state_dict(sd, strict=False)
print(f"Restored from {path}")
def encode(self, x):
# x = self.time_shuffle_operation(x)
# x = self.freq_split_subband(x)
h = self.encoder(x)
moments = self.quant_conv(h)
posterior = DiagonalGaussianDistribution(moments)
return posterior
def decode(self, z):
z = self.post_quant_conv(z)
dec = self.decoder(z)
# bs, ch, shuffled_timesteps, fbins = dec.size()
# dec = self.time_unshuffle_operation(dec, bs, int(ch*shuffled_timesteps), fbins)
# dec = self.freq_merge_subband(dec)
return dec
def decode_to_waveform(self, dec):
if self.image_key == "fbank":
dec = dec.squeeze(1).permute(0, 2, 1)
wav_reconstruction = vocoder_chunk_infer(dec, self.vocoder)
elif self.image_key == "stft":
dec = dec.squeeze(1).permute(0, 2, 1)
wav_reconstruction = self.wave_decoder(dec)
return wav_reconstruction
def mel_spectrogram_to_waveform(
self, mel, savepath=".", bs=None, name="outwav", save=True
):
# Mel: [bs, 1, t-steps, fbins]
if len(mel.size()) == 4:
mel = mel.squeeze(1)
mel = mel.permute(0, 2, 1)
waveform = self.vocoder(mel)
waveform = waveform.cpu().detach().numpy()
#if save:
# self.save_waveform(waveform, savepath, name)
return waveform
@torch.no_grad()
def encode_first_stage(self, x):
return self.encode(x)
@torch.no_grad()
def decode_first_stage(self, z, predict_cids=False, force_not_quantize=False):
if predict_cids:
if z.dim() == 4:
z = torch.argmax(z.exp(), dim=1).long()
z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None)
z = rearrange(z, "b h w c -> b c h w").contiguous()
z = 1.0 / self.scale_factor * z
return self.decode(z)
def decode_first_stage_withgrad(self, z):
z = 1.0 / self.scale_factor * z
return self.decode(z)
def get_first_stage_encoding(self, encoder_posterior, use_mode=False):
if isinstance(encoder_posterior, DiagonalGaussianDistribution) and not use_mode:
z = encoder_posterior.sample()
elif isinstance(encoder_posterior, DiagonalGaussianDistribution) and use_mode:
z = encoder_posterior.mode()
elif isinstance(encoder_posterior, torch.Tensor):
z = encoder_posterior
else:
raise NotImplementedError(
f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented"
)
return self.scale_factor * z
def visualize_latent(self, input):
import matplotlib.pyplot as plt
# for i in range(10):
# zero_input = torch.zeros_like(input) - 11.59
# zero_input[:,:,i * 16: i * 16 + 16,:16] += 13.59
# posterior = self.encode(zero_input)
# latent = posterior.sample()
# avg_latent = torch.mean(latent, dim=1)[0]
# plt.imshow(avg_latent.cpu().detach().numpy().T)
# plt.savefig("%s.png" % i)
# plt.close()
np.save("input.npy", input.cpu().detach().numpy())
# zero_input = torch.zeros_like(input) - 11.59
time_input = input.clone()
time_input[:, :, :, :32] *= 0
time_input[:, :, :, :32] -= 11.59
np.save("time_input.npy", time_input.cpu().detach().numpy())
posterior = self.encode(time_input)
latent = posterior.sample()
np.save("time_latent.npy", latent.cpu().detach().numpy())
avg_latent = torch.mean(latent, dim=1)
for i in range(avg_latent.size(0)):
plt.imshow(avg_latent[i].cpu().detach().numpy().T)
plt.savefig("freq_%s.png" % i)
plt.close()
freq_input = input.clone()
freq_input[:, :, :512, :] *= 0
freq_input[:, :, :512, :] -= 11.59
np.save("freq_input.npy", freq_input.cpu().detach().numpy())
posterior = self.encode(freq_input)
latent = posterior.sample()
np.save("freq_latent.npy", latent.cpu().detach().numpy())
avg_latent = torch.mean(latent, dim=1)
for i in range(avg_latent.size(0)):
plt.imshow(avg_latent[i].cpu().detach().numpy().T)
plt.savefig("time_%s.png" % i)
plt.close()
def get_input(self, batch):
fname, text, label_indices, waveform, stft, fbank = (
batch["fname"],
batch["text"],
batch["label_vector"],
batch["waveform"],
batch["stft"],
batch["log_mel_spec"],
)
# if(self.time_shuffle != 1):
# if(fbank.size(1) % self.time_shuffle != 0):
# pad_len = self.time_shuffle - (fbank.size(1) % self.time_shuffle)
# fbank = torch.nn.functional.pad(fbank, (0,0,0,pad_len))
ret = {}
ret["fbank"], ret["stft"], ret["fname"], ret["waveform"] = (
fbank.unsqueeze(1),
stft.unsqueeze(1),
fname,
waveform.unsqueeze(1),
)
return ret
def save_wave(self, batch_wav, fname, save_dir):
os.makedirs(save_dir, exist_ok=True)
for wav, name in zip(batch_wav, fname):
name = os.path.basename(name)
sf.write(os.path.join(save_dir, name), wav, samplerate=self.sampling_rate)
def get_last_layer(self):
return self.decoder.conv_out.weight
@torch.no_grad()
def log_images(self, batch, train=True, only_inputs=False, waveform=None, **kwargs):
log = dict()
x = batch.to(self.device)
if not only_inputs:
xrec, posterior = self(x)
log["samples"] = self.decode(posterior.sample())
log["reconstructions"] = xrec
log["inputs"] = x
wavs = self._log_img(log, train=train, index=0, waveform=waveform)
return wavs
def _log_img(self, log, train=True, index=0, waveform=None):
images_input = self.tensor2numpy(log["inputs"][index, 0]).T
images_reconstruct = self.tensor2numpy(log["reconstructions"][index, 0]).T
images_samples = self.tensor2numpy(log["samples"][index, 0]).T
if train:
name = "train"
else:
name = "val"
if self.logger is not None:
self.logger.log_image(
"img_%s" % name,
[images_input, images_reconstruct, images_samples],
caption=["input", "reconstruct", "samples"],
)
inputs, reconstructions, samples = (
log["inputs"],
log["reconstructions"],
log["samples"],
)
if self.image_key == "fbank":
wav_original, wav_prediction = synth_one_sample(
inputs[index],
reconstructions[index],
labels="validation",
vocoder=self.vocoder,
)
wav_original, wav_samples = synth_one_sample(
inputs[index], samples[index], labels="validation", vocoder=self.vocoder
)
wav_original, wav_samples, wav_prediction = (
wav_original[0],
wav_samples[0],
wav_prediction[0],
)
elif self.image_key == "stft":
wav_prediction = (
self.decode_to_waveform(reconstructions)[index, 0]
.cpu()
.detach()
.numpy()
)
wav_samples = (
self.decode_to_waveform(samples)[index, 0].cpu().detach().numpy()
)
wav_original = waveform[index, 0].cpu().detach().numpy()
if self.logger is not None:
self.logger.experiment.log(
{
"original_%s"
% name: wandb.Audio(
wav_original, caption="original", sample_rate=self.sampling_rate
),
"reconstruct_%s"
% name: wandb.Audio(
wav_prediction,
caption="reconstruct",
sample_rate=self.sampling_rate,
),
"samples_%s"
% name: wandb.Audio(
wav_samples, caption="samples", sample_rate=self.sampling_rate
),
}
)
return wav_original, wav_prediction, wav_samples
def tensor2numpy(self, tensor):
return tensor.cpu().detach().numpy()
def to_rgb(self, x):
assert self.image_key == "segmentation"
if not hasattr(self, "colorize"):
self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
x = torch.nn.functional.conv2d(x, weight=self.colorize)
x = 2.0 * (x - x.min()) / (x.max() - x.min()) - 1.0
return x
class IdentityFirstStage(torch.nn.Module):
    def __init__(self, *args, vq_interface=False, **kwargs):
        super().__init__()
        self.vq_interface = vq_interface  # TODO: should be True by default, but check that it does not break older stuff
def encode(self, x, *args, **kwargs):
return x
def decode(self, x, *args, **kwargs):
return x
def quantize(self, x, *args, **kwargs):
if self.vq_interface:
return x, None, [None, None, None]
return x
def forward(self, x, *args, **kwargs):
return x
def window_sumsquare(
window,
n_frames,
hop_length,
win_length,
n_fft,
dtype=np.float32,
norm=None,
):
"""
# from librosa 0.6
Compute the sum-square envelope of a window function at a given hop length.
This is used to estimate modulation effects induced by windowing
    observations in short-time Fourier transforms.
Parameters
----------
window : string, tuple, number, callable, or list-like
Window specification, as in `get_window`
n_frames : int > 0
The number of analysis frames
hop_length : int > 0
The number of samples to advance between frames
win_length : [optional]
The length of the window function. By default, this matches `n_fft`.
n_fft : int > 0
The length of each analysis frame.
dtype : np.dtype
The data type of the output
Returns
-------
wss : np.ndarray, shape=`(n_fft + hop_length * (n_frames - 1))`
The sum-squared envelope of the window function
"""
if win_length is None:
win_length = n_fft
n = n_fft + hop_length * (n_frames - 1)
x = np.zeros(n, dtype=dtype)
# Compute the squared window at the desired length
win_sq = get_window(window, win_length, fftbins=True)
win_sq = librosa_util.normalize(win_sq, norm=norm) ** 2
    win_sq = librosa_util.pad_center(win_sq, size=n_fft)
# Fill the envelope
for i in range(n_frames):
sample = i * hop_length
x[sample : min(n, sample + n_fft)] += win_sq[: max(0, min(n_fft, n - sample))]
return x
def dynamic_range_compression(x, normalize_fun=torch.log, C=1, clip_val=1e-5):
"""
PARAMS
------
C: compression factor
"""
return normalize_fun(torch.clamp(x, min=clip_val) * C)
def dynamic_range_decompression(x, C=1):
"""
PARAMS
------
C: compression factor used to compress
"""
return torch.exp(x) / C
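# STFT realized with convolutions: the forward transform is a strided Conv1d
# against a windowed Fourier basis, and the inverse is a ConvTranspose1d followed
# by window sum-square normalization to undo windowing modulation.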
class STFT(torch.nn.Module):
"""adapted from Prem Seetharaman's https://github.com/pseeth/pytorch-stft"""
def __init__(self, filter_length, hop_length, win_length, window="hann"):
super(STFT, self).__init__()
self.filter_length = filter_length
self.hop_length = hop_length
self.win_length = win_length
self.window = window
self.forward_transform = None
scale = self.filter_length / self.hop_length
fourier_basis = np.fft.fft(np.eye(self.filter_length))
cutoff = int((self.filter_length / 2 + 1))
fourier_basis = np.vstack(
[np.real(fourier_basis[:cutoff, :]), np.imag(fourier_basis[:cutoff, :])]
)
forward_basis = torch.FloatTensor(fourier_basis[:, None, :])
inverse_basis = torch.FloatTensor(
np.linalg.pinv(scale * fourier_basis).T[:, None, :]
)
if window is not None:
assert filter_length >= win_length
# get window and zero center pad it to filter_length
fft_window = get_window(window, win_length, fftbins=True)
fft_window = pad_center(fft_window, size=filter_length)
fft_window = torch.from_numpy(fft_window).float()
# window the bases
forward_basis *= fft_window
inverse_basis *= fft_window
self.register_buffer("forward_basis", forward_basis.float())
self.register_buffer("inverse_basis", inverse_basis.float())
def transform(self, input_data):
device = self.forward_basis.device
input_data = input_data.to(device)
num_batches = input_data.size(0)
num_samples = input_data.size(1)
self.num_samples = num_samples
# similar to librosa, reflect-pad the input
input_data = input_data.view(num_batches, 1, num_samples)
input_data = torch.nn.functional.pad(
input_data.unsqueeze(1),
(int(self.filter_length / 2), int(self.filter_length / 2), 0, 0),
mode="reflect",
)
input_data = input_data.squeeze(1)
forward_transform = torch.nn.functional.conv1d(
input_data,
torch.autograd.Variable(self.forward_basis, requires_grad=False),
stride=self.hop_length,
padding=0,
)#.cpu()
cutoff = int((self.filter_length / 2) + 1)
real_part = forward_transform[:, :cutoff, :]
imag_part = forward_transform[:, cutoff:, :]
magnitude = torch.sqrt(real_part**2 + imag_part**2)
phase = torch.autograd.Variable(torch.atan2(imag_part.data, real_part.data))
return magnitude, phase
def inverse(self, magnitude, phase):
device = self.forward_basis.device
magnitude, phase = magnitude.to(device), phase.to(device)
recombine_magnitude_phase = torch.cat(
[magnitude * torch.cos(phase), magnitude * torch.sin(phase)], dim=1
)
inverse_transform = torch.nn.functional.conv_transpose1d(
recombine_magnitude_phase,
torch.autograd.Variable(self.inverse_basis, requires_grad=False),
stride=self.hop_length,
padding=0,
)
if self.window is not None:
window_sum = window_sumsquare(
self.window,
magnitude.size(-1),
hop_length=self.hop_length,
win_length=self.win_length,
n_fft=self.filter_length,
dtype=np.float32,
)
# remove modulation effects
            approx_nonzero_indices = torch.from_numpy(
                np.where(window_sum > tiny(window_sum))[0]
            ).to(inverse_transform.device)
            window_sum = torch.from_numpy(window_sum).to(inverse_transform.device)
            inverse_transform[:, :, approx_nonzero_indices] /= window_sum[
                approx_nonzero_indices
            ]
# scale by hop ratio
inverse_transform *= float(self.filter_length) / self.hop_length
inverse_transform = inverse_transform[:, :, int(self.filter_length / 2) :]
        inverse_transform = inverse_transform[:, :, : -int(self.filter_length / 2)]
return inverse_transform
def forward(self, input_data):
self.magnitude, self.phase = self.transform(input_data)
reconstruction = self.inverse(self.magnitude, self.phase)
return reconstruction
class TacotronSTFT(torch.nn.Module):
def __init__(
self,
filter_length,
hop_length,
win_length,
n_mel_channels,
sampling_rate,
mel_fmin,
mel_fmax,
):
super(TacotronSTFT, self).__init__()
self.n_mel_channels = n_mel_channels
self.sampling_rate = sampling_rate
self.stft_fn = STFT(filter_length, hop_length, win_length)
        mel_basis = librosa_mel_fn(
            sr=sampling_rate,
            n_fft=filter_length,
            n_mels=n_mel_channels,
            fmin=mel_fmin,
            fmax=mel_fmax,
        )
mel_basis = torch.from_numpy(mel_basis).float()
self.register_buffer("mel_basis", mel_basis)
def spectral_normalize(self, magnitudes, normalize_fun):
output = dynamic_range_compression(magnitudes, normalize_fun)
return output
def spectral_de_normalize(self, magnitudes):
output = dynamic_range_decompression(magnitudes)
return output
def mel_spectrogram(self, y, normalize_fun=torch.log):
"""Computes mel-spectrograms from a batch of waves
PARAMS
------
y: Variable(torch.FloatTensor) with shape (B, T) in range [-1, 1]
RETURNS
-------
mel_output: torch.FloatTensor of shape (B, n_mel_channels, T)
"""
assert torch.min(y.data) >= -1, torch.min(y.data)
assert torch.max(y.data) <= 1, torch.max(y.data)
magnitudes, phases = self.stft_fn.transform(y)
magnitudes = magnitudes.data
mel_output = torch.matmul(self.mel_basis, magnitudes)
mel_output = self.spectral_normalize(mel_output, normalize_fun)
energy = torch.norm(magnitudes, dim=1)
log_magnitudes = self.spectral_normalize(magnitudes, normalize_fun)
return mel_output, log_magnitudes, energy
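# Rebuilds the pretrained 48 kHz VAE and its mel front-end from one training
# checkpoint: the "first_stage_model." prefix (18 characters) is stripped from
# matching state-dict keys, and the learned scale_factor is read directly from
# the checkpoint.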
def build_pretrained_models(ckpt):
checkpoint = torch.load(ckpt, map_location="cpu")
scale_factor = checkpoint["state_dict"]["scale_factor"].item()
print("scale_factor: ", scale_factor)
vae_state_dict = {k[18:]: v for k, v in checkpoint["state_dict"].items() if "first_stage_model." in k}
config = {
"preprocessing": {
"audio": {
"sampling_rate": 48000,
"max_wav_value": 32768,
"duration": 10.24
},
"stft": {
"filter_length": 2048,
"hop_length": 480,
"win_length": 2048
},
"mel": {
"n_mel_channels": 256,
"mel_fmin": 20,
"mel_fmax": 24000
}
},
"model": {
"params": {
"first_stage_config": {
"params": {
"sampling_rate": 48000,
"batchsize": 4,
"monitor": "val/rec_loss",
"image_key": "fbank",
"subband": 1,
"embed_dim": 16,
"time_shuffle": 1,
"lossconfig": {
"target": "audioldm2.latent_diffusion.modules.losses.LPIPSWithDiscriminator",
"params": {
"disc_start": 50001,
"kl_weight": 1000,
"disc_weight": 0.5,
"disc_in_channels": 1
}
},
"ddconfig": {
"double_z": True,
"mel_bins": 256,
"z_channels": 16,
"resolution": 256,
"downsample_time": False,
"in_channels": 1,
"out_ch": 1,
"ch": 128,
"ch_mult": [
1,
2,
4,
8
],
"num_res_blocks": 2,
"attn_resolutions": [],
"dropout": 0
}
}
},
}
}
}
vae_config = config["model"]["params"]["first_stage_config"]["params"]
vae_config["scale_factor"] = scale_factor
vae = AutoencoderKL(**vae_config)
vae.load_state_dict(vae_state_dict)
fn_STFT = TacotronSTFT(
config["preprocessing"]["stft"]["filter_length"],
config["preprocessing"]["stft"]["hop_length"],
config["preprocessing"]["stft"]["win_length"],
config["preprocessing"]["mel"]["n_mel_channels"],
config["preprocessing"]["audio"]["sampling_rate"],
config["preprocessing"]["mel"]["mel_fmin"],
config["preprocessing"]["mel"]["mel_fmax"],
)
vae.eval()
fn_STFT.eval()
return vae, fn_STFT
if __name__ == "__main__":
    # Smoke test: waveform -> mel -> VAE latent -> mel -> waveform round trip on
    # random input. The default checkpoint path below is a placeholder; supply a
    # real VAE checkpoint containing "first_stage_model.*" weights and "scale_factor".
    ckpt_path = sys.argv[1] if len(sys.argv) > 1 else "checkpoint.ckpt"
    vae, stft = build_pretrained_models(ckpt_path)
    vae, stft = vae.cuda(), stft.cuda()
    json_file = "outputs/wav.scp"
    out_path = "outputs/Music_inverse"
    # batch of two random 10.24 s waveforms at 48 kHz
    dummy_wave = torch.randn(2, int(48000 * 10.24))
    mel, _, waveform = torch_tools.wav_to_fbank2(
        dummy_wave, target_length=-1, fn_STFT=stft
    )
    mel = mel.unsqueeze(1).cuda()
    print("mel:", mel.shape)
    # true_latent = torch.cat([vae.get_first_stage_encoding(vae.encode_first_stage(mel[[m]])) for m in range(mel.shape[0])], 0)
    true_latent = vae.get_first_stage_encoding(vae.encode_first_stage(mel))
    print("latent:", true_latent.shape)
    # regroup latents across the batch/channel dims and back (net effect: a detached copy)
    true_latent = true_latent.reshape(
        true_latent.shape[0] // 2, -1, true_latent.shape[2], true_latent.shape[3]
    ).detach()
    true_latent = true_latent.reshape(
        true_latent.shape[0] * 2, -1, true_latent.shape[2], true_latent.shape[3]
    )
    print("regrouped latent:", true_latent.size())
    mel = vae.decode_first_stage(true_latent)
    print("decoded mel:", mel.size())
    audio = vae.decode_to_waveform(mel)
    print("waveform:", audio.shape)
    # out_file = out_path + "/" + os.path.basename(fname.strip())
    # sf.write(out_file, audio[0], samplerate=48000)