import os
import sys

import numpy as np
import soundfile as sf
import torch
import torch.nn as nn
from einops import rearrange
from librosa.filters import mel as librosa_mel_fn
from librosa.util import pad_center, tiny
import librosa.util as librosa_util
from scipy.signal import get_window

import tools.torch_tools as torch_tools
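
# This module collects the pieces needed to round-trip 48 kHz audio through a
# mel-spectrogram autoencoder: a HiFi-GAN style vocoder (ResBlock / Generator_old),
# the KL-regularised VAE (Encoder / Decoder / AutoencoderKL), and the STFT / mel
# front end (STFT / TacotronSTFT) used to compute the model inputs.
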
class AttrDict(dict):
    def __init__(self, *args, **kwargs):
        super(AttrDict, self).__init__(*args, **kwargs)
        self.__dict__ = self


def init_weights(m, mean=0.0, std=0.01):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        m.weight.data.normal_(mean, std)


def get_padding(kernel_size, dilation=1):
    return int((kernel_size * dilation - dilation) / 2)


LRELU_SLOPE = 0.1
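

# ResBlock is the dilated residual block from HiFi-GAN: three weight-normalised
# dilated convolutions (convs1), each followed by a plain convolution (convs2),
# with leaky-ReLU activations and a residual connection around every pair.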
class ResBlock(torch.nn.Module):
    def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)):
        super(ResBlock, self).__init__()
        self.h = h
        self.convs1 = nn.ModuleList(
            [
                torch.nn.utils.weight_norm(
                    nn.Conv1d(
                        channels,
                        channels,
                        kernel_size,
                        1,
                        dilation=dilation[0],
                        padding=get_padding(kernel_size, dilation[0]),
                    )
                ),
                torch.nn.utils.weight_norm(
                    nn.Conv1d(
                        channels,
                        channels,
                        kernel_size,
                        1,
                        dilation=dilation[1],
                        padding=get_padding(kernel_size, dilation[1]),
                    )
                ),
                torch.nn.utils.weight_norm(
                    nn.Conv1d(
                        channels,
                        channels,
                        kernel_size,
                        1,
                        dilation=dilation[2],
                        padding=get_padding(kernel_size, dilation[2]),
                    )
                ),
            ]
        )
        self.convs1.apply(init_weights)
        self.convs2 = nn.ModuleList(
            [
                torch.nn.utils.weight_norm(
                    nn.Conv1d(
                        channels,
                        channels,
                        kernel_size,
                        1,
                        dilation=1,
                        padding=get_padding(kernel_size, 1),
                    )
                ),
                torch.nn.utils.weight_norm(
                    nn.Conv1d(
                        channels,
                        channels,
                        kernel_size,
                        1,
                        dilation=1,
                        padding=get_padding(kernel_size, 1),
                    )
                ),
                torch.nn.utils.weight_norm(
                    nn.Conv1d(
                        channels,
                        channels,
                        kernel_size,
                        1,
                        dilation=1,
                        padding=get_padding(kernel_size, 1),
                    )
                ),
            ]
        )
        self.convs2.apply(init_weights)

    def forward(self, x):
        for c1, c2 in zip(self.convs1, self.convs2):
            xt = torch.nn.functional.leaky_relu(x, LRELU_SLOPE)
            xt = c1(xt)
            xt = torch.nn.functional.leaky_relu(xt, LRELU_SLOPE)
            xt = c2(xt)
            x = xt + x
        return x

    def remove_weight_norm(self):
        for l in self.convs1:
            torch.nn.utils.remove_weight_norm(l)
        for l in self.convs2:
            torch.nn.utils.remove_weight_norm(l)
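

# Generator_old is the HiFi-GAN generator: a pre-convolution over the mel channels,
# a stack of transposed-convolution upsampling layers (upsample_rates), each followed
# by num_kernels parallel ResBlocks whose outputs are averaged, and a final tanh
# convolution that emits the waveform.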
class Generator_old(torch.nn.Module):
    def __init__(self, h):
        super(Generator_old, self).__init__()
        self.h = h
        self.num_kernels = len(h.resblock_kernel_sizes)
        self.num_upsamples = len(h.upsample_rates)
        self.conv_pre = torch.nn.utils.weight_norm(
            nn.Conv1d(h.num_mels, h.upsample_initial_channel, 7, 1, padding=3)
        )
        resblock = ResBlock
        self.ups = nn.ModuleList()
        for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):
            self.ups.append(
                torch.nn.utils.weight_norm(
                    nn.ConvTranspose1d(
                        h.upsample_initial_channel // (2**i),
                        h.upsample_initial_channel // (2 ** (i + 1)),
                        k,
                        u,
                        padding=(k - u) // 2,
                    )
                )
            )
        self.resblocks = nn.ModuleList()
        for i in range(len(self.ups)):
            ch = h.upsample_initial_channel // (2 ** (i + 1))
            for j, (k, d) in enumerate(
                zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)
            ):
                self.resblocks.append(resblock(h, ch, k, d))
        self.conv_post = torch.nn.utils.weight_norm(nn.Conv1d(ch, 1, 7, 1, padding=3))
        self.ups.apply(init_weights)
        self.conv_post.apply(init_weights)

    def forward(self, x):
        x = self.conv_pre(x)
        for i in range(self.num_upsamples):
            x = torch.nn.functional.leaky_relu(x, LRELU_SLOPE)
            x = self.ups[i](x)
            xs = None
            for j in range(self.num_kernels):
                if xs is None:
                    xs = self.resblocks[i * self.num_kernels + j](x)
                else:
                    xs += self.resblocks[i * self.num_kernels + j](x)
            x = xs / self.num_kernels
        x = torch.nn.functional.leaky_relu(x)
        x = self.conv_post(x)
        x = torch.tanh(x)
        return x

    def remove_weight_norm(self):
        # print("Removing weight norm...")
        for l in self.ups:
            torch.nn.utils.remove_weight_norm(l)
        for l in self.resblocks:
            l.remove_weight_norm()
        torch.nn.utils.remove_weight_norm(self.conv_pre)
        torch.nn.utils.remove_weight_norm(self.conv_post)
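

# Building blocks for the KL autoencoder below: a swish nonlinearity, GroupNorm, and
# 2-D down/upsampling layers. The *TimeStride4 variants use stride 4 along the time
# axis (and 2 along frequency) so selected levels can compress time more aggressively.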
def nonlinearity(x):
    # swish
    return x * torch.sigmoid(x)


def Normalize(in_channels, num_groups=32):
    return torch.nn.GroupNorm(
        num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True
    )


class Downsample(nn.Module):
    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if self.with_conv:
            # Do time downsampling here
            # no asymmetric padding in torch conv, must do it ourselves
            self.conv = torch.nn.Conv2d(
                in_channels, in_channels, kernel_size=3, stride=2, padding=0
            )

    def forward(self, x):
        if self.with_conv:
            pad = (0, 1, 0, 1)
            x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
            x = self.conv(x)
        else:
            x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
        return x


class DownsampleTimeStride4(nn.Module):
    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if self.with_conv:
            # Do time downsampling here
            # no asymmetric padding in torch conv, must do it ourselves
            self.conv = torch.nn.Conv2d(
                in_channels, in_channels, kernel_size=5, stride=(4, 2), padding=1
            )

    def forward(self, x):
        if self.with_conv:
            pad = (0, 1, 0, 1)
            x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
            x = self.conv(x)
        else:
            x = torch.nn.functional.avg_pool2d(x, kernel_size=(4, 2), stride=(4, 2))
        return x


class Upsample(nn.Module):
    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if self.with_conv:
            self.conv = torch.nn.Conv2d(
                in_channels, in_channels, kernel_size=3, stride=1, padding=1
            )

    def forward(self, x):
        x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
        if self.with_conv:
            x = self.conv(x)
        return x


class UpsampleTimeStride4(nn.Module):
    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if self.with_conv:
            self.conv = torch.nn.Conv2d(
                in_channels, in_channels, kernel_size=5, stride=1, padding=2
            )

    def forward(self, x):
        x = torch.nn.functional.interpolate(x, scale_factor=(4.0, 2.0), mode="nearest")
        if self.with_conv:
            x = self.conv(x)
        return x
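

# AttnBlock implements standard (non-causal) self-attention over the flattened
# time-frequency grid of a feature map, with 1x1 convolutions for q/k/v and a
# residual projection back onto the input.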
class AttnBlock(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels
        self.norm = Normalize(in_channels)
        self.q = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0
        )
        self.k = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0
        )
        self.v = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0
        )
        self.proj_out = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0
        )

    def forward(self, x):
        h_ = x
        h_ = self.norm(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)
        # compute attention
        b, c, h, w = q.shape
        q = q.reshape(b, c, h * w).contiguous()
        q = q.permute(0, 2, 1).contiguous()  # b,hw,c
        k = k.reshape(b, c, h * w).contiguous()  # b,c,hw
        w_ = torch.bmm(q, k).contiguous()  # b,hw,hw  w[b,i,j] = sum_c q[b,i,c] k[b,c,j]
        w_ = w_ * (int(c) ** (-0.5))
        w_ = torch.nn.functional.softmax(w_, dim=2)
        # attend to values
        v = v.reshape(b, c, h * w).contiguous()
        w_ = w_.permute(0, 2, 1).contiguous()  # b,hw,hw (first hw of k, second of q)
        h_ = torch.bmm(
            v, w_
        ).contiguous()  # b,c,hw (hw of q)  h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
        h_ = h_.reshape(b, c, h, w).contiguous()
        h_ = self.proj_out(h_)
        return x + h_


def make_attn(in_channels, attn_type="vanilla"):
    assert attn_type in ["vanilla", "linear", "none"], f"attn_type {attn_type} unknown"
    # print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
    if attn_type == "vanilla":
        return AttnBlock(in_channels)
    elif attn_type == "none":
        return nn.Identity(in_channels)
    else:
        raise ValueError(attn_type)
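

# ResnetBlock: GroupNorm -> swish -> 3x3 conv, twice, with an optional timestep-embedding
# projection and a learned (1x1 or 3x3) shortcut when the channel count changes.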
class ResnetBlock(nn.Module):
    def __init__(
        self,
        *,
        in_channels,
        out_channels=None,
        conv_shortcut=False,
        dropout,
        temb_channels=512,
    ):
        super().__init__()
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut
        self.norm1 = Normalize(in_channels)
        self.conv1 = torch.nn.Conv2d(
            in_channels, out_channels, kernel_size=3, stride=1, padding=1
        )
        if temb_channels > 0:
            self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
        self.norm2 = Normalize(out_channels)
        self.dropout = torch.nn.Dropout(dropout)
        self.conv2 = torch.nn.Conv2d(
            out_channels, out_channels, kernel_size=3, stride=1, padding=1
        )
        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                self.conv_shortcut = torch.nn.Conv2d(
                    in_channels, out_channels, kernel_size=3, stride=1, padding=1
                )
            else:
                self.nin_shortcut = torch.nn.Conv2d(
                    in_channels, out_channels, kernel_size=1, stride=1, padding=0
                )

    def forward(self, x, temb):
        h = x
        h = self.norm1(h)
        h = nonlinearity(h)
        h = self.conv1(h)
        if temb is not None:
            h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]
        h = self.norm2(h)
        h = nonlinearity(h)
        h = self.dropout(h)
        h = self.conv2(h)
        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                x = self.conv_shortcut(x)
            else:
                x = self.nin_shortcut(x)
        return x + h
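

# Encoder maps a [B, 1, T, F] mel spectrogram to the moments of the latent distribution:
# a conv_in stem, num_resolutions levels of ResnetBlocks with optional attention and
# downsampling (stride 2, or stride 4 in time at the configured levels), a middle block,
# and a conv_out that emits 2 * z_channels channels when double_z is set.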
class Encoder(nn.Module):
    def __init__(
        self,
        *,
        ch,
        out_ch,
        ch_mult=(1, 2, 4, 8),
        num_res_blocks,
        attn_resolutions,
        dropout=0.0,
        resamp_with_conv=True,
        in_channels,
        resolution,
        z_channels,
        double_z=True,
        use_linear_attn=False,
        attn_type="vanilla",
        downsample_time_stride4_levels=[],
        **ignore_kwargs,
    ):
        super().__init__()
        if use_linear_attn:
            attn_type = "linear"
        self.ch = ch
        self.temb_ch = 0
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels
        self.downsample_time_stride4_levels = downsample_time_stride4_levels
        if len(self.downsample_time_stride4_levels) > 0:
            assert max(self.downsample_time_stride4_levels) < self.num_resolutions, (
                "The levels at which the stride-4 time downsampling is applied must be "
                "smaller than the total number of resolutions %s"
                % str(self.num_resolutions)
            )
        # downsampling
        self.conv_in = torch.nn.Conv2d(
            in_channels, self.ch, kernel_size=3, stride=1, padding=1
        )
        curr_res = resolution
        in_ch_mult = (1,) + tuple(ch_mult)
        self.in_ch_mult = in_ch_mult
        self.down = nn.ModuleList()
        for i_level in range(self.num_resolutions):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_in = ch * in_ch_mult[i_level]
            block_out = ch * ch_mult[i_level]
            for i_block in range(self.num_res_blocks):
                block.append(
                    ResnetBlock(
                        in_channels=block_in,
                        out_channels=block_out,
                        temb_channels=self.temb_ch,
                        dropout=dropout,
                    )
                )
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(make_attn(block_in, attn_type=attn_type))
            down = nn.Module()
            down.block = block
            down.attn = attn
            if i_level != self.num_resolutions - 1:
                if i_level in self.downsample_time_stride4_levels:
                    down.downsample = DownsampleTimeStride4(block_in, resamp_with_conv)
                else:
                    down.downsample = Downsample(block_in, resamp_with_conv)
                curr_res = curr_res // 2
            self.down.append(down)
        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(
            in_channels=block_in,
            out_channels=block_in,
            temb_channels=self.temb_ch,
            dropout=dropout,
        )
        self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
        self.mid.block_2 = ResnetBlock(
            in_channels=block_in,
            out_channels=block_in,
            temb_channels=self.temb_ch,
            dropout=dropout,
        )
        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(
            block_in,
            2 * z_channels if double_z else z_channels,
            kernel_size=3,
            stride=1,
            padding=1,
        )

    def forward(self, x):
        # timestep embedding
        temb = None
        # downsampling
        hs = [self.conv_in(x)]
        for i_level in range(self.num_resolutions):
            for i_block in range(self.num_res_blocks):
                h = self.down[i_level].block[i_block](hs[-1], temb)
                if len(self.down[i_level].attn) > 0:
                    h = self.down[i_level].attn[i_block](h)
                hs.append(h)
            if i_level != self.num_resolutions - 1:
                hs.append(self.down[i_level].downsample(hs[-1]))
        # middle
        h = hs[-1]
        h = self.mid.block_1(h, temb)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h, temb)
        # end
        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        return h
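

# Decoder mirrors the Encoder: conv_in from z_channels, a middle block, then
# num_resolutions levels of ResnetBlocks with upsampling (UpsampleTimeStride4 where the
# matching encoder level used the stride-4 time downsampling), and a conv_out back to
# out_ch channels, optionally passed through tanh.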
class Decoder(nn.Module):
    def __init__(
        self,
        *,
        ch,
        out_ch,
        ch_mult=(1, 2, 4, 8),
        num_res_blocks,
        attn_resolutions,
        dropout=0.0,
        resamp_with_conv=True,
        in_channels,
        resolution,
        z_channels,
        give_pre_end=False,
        tanh_out=False,
        use_linear_attn=False,
        downsample_time_stride4_levels=[],
        attn_type="vanilla",
        **ignorekwargs,
    ):
        super().__init__()
        if use_linear_attn:
            attn_type = "linear"
        self.ch = ch
        self.temb_ch = 0
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels
        self.give_pre_end = give_pre_end
        self.tanh_out = tanh_out
        self.downsample_time_stride4_levels = downsample_time_stride4_levels
        if len(self.downsample_time_stride4_levels) > 0:
            assert max(self.downsample_time_stride4_levels) < self.num_resolutions, (
                "The levels at which the stride-4 time downsampling is applied must be "
                "smaller than the total number of resolutions %s"
                % str(self.num_resolutions)
            )
        # compute block_in and curr_res at the lowest resolution
        block_in = ch * ch_mult[self.num_resolutions - 1]
        curr_res = resolution // 2 ** (self.num_resolutions - 1)
        self.z_shape = (1, z_channels, curr_res, curr_res)
        # print(
        #     "Working with z of shape {} = {} dimensions.".format(
        #         self.z_shape, np.prod(self.z_shape)
        #     )
        # )
        # z to block_in
        self.conv_in = torch.nn.Conv2d(
            z_channels, block_in, kernel_size=3, stride=1, padding=1
        )
        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(
            in_channels=block_in,
            out_channels=block_in,
            temb_channels=self.temb_ch,
            dropout=dropout,
        )
        self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
        self.mid.block_2 = ResnetBlock(
            in_channels=block_in,
            out_channels=block_in,
            temb_channels=self.temb_ch,
            dropout=dropout,
        )
        # upsampling
        self.up = nn.ModuleList()
        for i_level in reversed(range(self.num_resolutions)):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_out = ch * ch_mult[i_level]
            for i_block in range(self.num_res_blocks + 1):
                block.append(
                    ResnetBlock(
                        in_channels=block_in,
                        out_channels=block_out,
                        temb_channels=self.temb_ch,
                        dropout=dropout,
                    )
                )
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(make_attn(block_in, attn_type=attn_type))
            up = nn.Module()
            up.block = block
            up.attn = attn
            if i_level != 0:
                if i_level - 1 in self.downsample_time_stride4_levels:
                    up.upsample = UpsampleTimeStride4(block_in, resamp_with_conv)
                else:
                    up.upsample = Upsample(block_in, resamp_with_conv)
                curr_res = curr_res * 2
            self.up.insert(0, up)  # prepend to get consistent order
        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(
            block_in, out_ch, kernel_size=3, stride=1, padding=1
        )

    def forward(self, z):
        # assert z.shape[1:] == self.z_shape[1:]
        self.last_z_shape = z.shape
        # timestep embedding
        temb = None
        # z to block_in
        h = self.conv_in(z)
        # middle
        h = self.mid.block_1(h, temb)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h, temb)
        # upsampling
        for i_level in reversed(range(self.num_resolutions)):
            for i_block in range(self.num_res_blocks + 1):
                h = self.up[i_level].block[i_block](h, temb)
                if len(self.up[i_level].attn) > 0:
                    h = self.up[i_level].attn[i_block](h)
            if i_level != 0:
                h = self.up[i_level].upsample(h)
        # end
        if self.give_pre_end:
            return h
        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        if self.tanh_out:
            h = torch.tanh(h)
        return h
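

# DiagonalGaussianDistribution wraps the encoder output (mean and log-variance stacked
# along the channel axis) and provides reparameterised sampling, KL divergence against
# the standard normal or another Gaussian, and the negative log-likelihood of a sample.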
class DiagonalGaussianDistribution(object):
    def __init__(self, parameters, deterministic=False):
        self.parameters = parameters
        self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
        self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
        self.deterministic = deterministic
        self.std = torch.exp(0.5 * self.logvar)
        self.var = torch.exp(self.logvar)
        if self.deterministic:
            self.var = self.std = torch.zeros_like(self.mean).to(
                device=self.parameters.device
            )

    def sample(self):
        x = self.mean + self.std * torch.randn(self.mean.shape).to(
            device=self.parameters.device
        )
        return x

    def kl(self, other=None):
        if self.deterministic:
            return torch.Tensor([0.0])
        else:
            if other is None:
                return 0.5 * torch.mean(
                    torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
                    dim=[1, 2, 3],
                )
            else:
                return 0.5 * torch.mean(
                    torch.pow(self.mean - other.mean, 2) / other.var
                    + self.var / other.var
                    - 1.0
                    - self.logvar
                    + other.logvar,
                    dim=[1, 2, 3],
                )

    def nll(self, sample, dims=[1, 2, 3]):
        if self.deterministic:
            return torch.Tensor([0.0])
        logtwopi = np.log(2.0 * np.pi)
        return 0.5 * torch.sum(
            logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
            dim=dims,
        )

    def mode(self):
        return self.mean
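

# Hard-coded HiFi-GAN configuration for the 48 kHz / 256-mel vocoder: the five
# upsampling stages (6 * 5 * 4 * 2 * 2 = 480) match the 480-sample hop size, so one
# mel frame maps to exactly one hop of output audio.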
def get_vocoder_config_48k():
    return {
        "resblock": "1",
        "num_gpus": 8,
        "batch_size": 128,
        "learning_rate": 0.0001,
        "adam_b1": 0.8,
        "adam_b2": 0.99,
        "lr_decay": 0.999,
        "seed": 1234,
        "upsample_rates": [6, 5, 4, 2, 2],
        "upsample_kernel_sizes": [12, 10, 8, 4, 4],
        "upsample_initial_channel": 1536,
        "resblock_kernel_sizes": [3, 7, 11, 15],
        "resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5], [1, 3, 5]],
        "segment_size": 15360,
        "num_mels": 256,
        "n_fft": 2048,
        "hop_size": 480,
        "win_size": 2048,
        "sampling_rate": 48000,
        "fmin": 20,
        "fmax": 24000,
        "fmax_for_loss": None,
        "num_workers": 8,
        "dist_config": {
            "dist_backend": "nccl",
            "dist_url": "tcp://localhost:18273",
            "world_size": 1,
        },
    }
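

# get_vocoder builds the HiFi-GAN generator for the requested number of mel bins.
# The checkpoint loading below is commented out; the generator weights presumably
# arrive later through the AutoencoderKL state dict, since the vocoder is registered
# as a submodule of the VAE.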
def get_vocoder(config, device, mel_bins):
    name = "HiFi-GAN"
    speaker = ""
    if name == "MelGAN":
        if speaker == "LJSpeech":
            vocoder = torch.hub.load(
                "descriptinc/melgan-neurips", "load_melgan", "linda_johnson"
            )
        elif speaker == "universal":
            vocoder = torch.hub.load(
                "descriptinc/melgan-neurips", "load_melgan", "multi_speaker"
            )
        vocoder.mel2wav.eval()
        vocoder.mel2wav.to(device)
    elif name == "HiFi-GAN":
        if mel_bins == 256:
            config = get_vocoder_config_48k()
            config = AttrDict(config)
            vocoder = Generator_old(config)
            # print("Load hifigan/g_01080000")
            # ckpt = torch.load(os.path.join(ROOT, "hifigan/g_01080000"))
            # ckpt = torch.load(os.path.join(ROOT, "hifigan/g_00660000"))
            # ckpt = torch_version_orig_mod_remove(ckpt)
            # vocoder.load_state_dict(ckpt["generator"])
            vocoder.eval()
            vocoder.remove_weight_norm()
            vocoder = vocoder.to(device)
            # vocoder = vocoder.half()
        else:
            raise ValueError(mel_bins)
    return vocoder


def vocoder_infer(mels, vocoder, lengths=None):
    with torch.no_grad():
        wavs = vocoder(mels).squeeze(1)
    # wavs = (wavs.cpu().numpy() * 32768).astype("int16")
    wavs = wavs.cpu().numpy()
    if lengths is not None:
        wavs = wavs[:, :lengths]
    # wavs = [wav for wav in wavs]
    # for i in range(len(mels)):
    #     if lengths is not None:
    #         wavs[i] = wavs[i][: lengths[i]]
    return wavs
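

# vocoder_chunk_infer runs the vocoder on overlapping chunks of the mel spectrogram
# (chunks of 1024 frames, shifted by 256) and cross-fades the overlapping audio with a
# linear ramp, which keeps memory bounded for long inputs while avoiding seams at
# chunk boundaries.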
def vocoder_chunk_infer(mels, vocoder, lengths=None):
    chunk_size = 256 * 4
    shift_size = 256 * 1
    ov_size = chunk_size - shift_size
    for cinx in range(0, mels.shape[2], shift_size):
        if cinx == 0:
            # first chunk initialises the output and the cross-fade window
            wavs = vocoder(mels[:, :, cinx : cinx + chunk_size]).squeeze(1).float()
            num_samples = int(wavs.shape[-1] / chunk_size) * chunk_size
            wavs = wavs[:, 0:num_samples]
            ov_sample = int(float(wavs.shape[-1]) * ov_size / chunk_size)
            ov_win = torch.linspace(0, 1, ov_sample, device=mels.device).unsqueeze(0)
            ov_win = torch.cat([ov_win, 1 - ov_win], -1)
            if cinx + chunk_size >= mels.shape[2]:
                break
        else:
            cur_wav = (
                vocoder(mels[:, :, cinx : cinx + chunk_size])
                .squeeze(1)[:, 0:num_samples]
                .float()
            )
            # fade out the existing tail, fade in the head of the new chunk
            wavs[:, -ov_sample:] = (
                wavs[:, -ov_sample:] * ov_win[:, -ov_sample:]
                + cur_wav[:, 0:ov_sample] * ov_win[:, 0:ov_sample]
            )
            wavs = torch.cat([wavs, cur_wav[:, ov_sample:]], -1)
            if cinx + chunk_size >= mels.shape[2]:
                break
    wavs = wavs.cpu().numpy()
    if lengths is not None:
        wavs = wavs[:, :lengths]
    return wavs


def synth_one_sample(mel_input, mel_prediction, labels, vocoder):
    if vocoder is not None:
        wav_reconstruction = vocoder_infer(
            mel_input.permute(0, 2, 1),
            vocoder,
        )
        wav_prediction = vocoder_infer(
            mel_prediction.permute(0, 2, 1),
            vocoder,
        )
    else:
        wav_reconstruction = wav_prediction = None
    return wav_reconstruction, wav_prediction
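

# AutoencoderKL ties everything together: the Encoder produces the moments of a
# DiagonalGaussianDistribution, quant_conv / post_quant_conv map between the encoder
# output and the embed_dim-sized latent, the Decoder reconstructs the mel spectrogram,
# and the HiFi-GAN vocoder turns reconstructions back into waveforms. The logging
# helpers below assume a PyTorch-Lightning-style `self.logger` / `self.device` and a
# forward() pass, which this plain nn.Module does not define by itself.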
class AutoencoderKL(nn.Module):
    def __init__(
        self,
        ddconfig=None,
        lossconfig=None,
        batchsize=None,
        embed_dim=None,
        time_shuffle=1,
        subband=1,
        sampling_rate=16000,
        ckpt_path=None,
        reload_from_ckpt=None,
        ignore_keys=[],
        image_key="fbank",
        colorize_nlabels=None,
        monitor=None,
        base_learning_rate=1e-5,
        scale_factor=1,
    ):
        super().__init__()
        self.automatic_optimization = False
        assert (
            "mel_bins" in ddconfig.keys()
        ), "mel_bins is not specified in the Autoencoder config"
        num_mel = ddconfig["mel_bins"]
        self.image_key = image_key
        self.sampling_rate = sampling_rate
        self.encoder = Encoder(**ddconfig)
        self.decoder = Decoder(**ddconfig)
        self.loss = None
        self.subband = int(subband)
        if self.subband > 1:
            print("Use subband decomposition %s" % self.subband)
        assert ddconfig["double_z"]
        self.quant_conv = torch.nn.Conv2d(2 * ddconfig["z_channels"], 2 * embed_dim, 1)
        self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
        if self.image_key == "fbank":
            self.vocoder = get_vocoder(None, torch.device("cuda"), num_mel)
        self.embed_dim = embed_dim
        if colorize_nlabels is not None:
            assert type(colorize_nlabels) == int
            self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
        if monitor is not None:
            self.monitor = monitor
        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
        self.learning_rate = float(base_learning_rate)
        # print("Initial learning rate %s" % self.learning_rate)
        self.time_shuffle = time_shuffle
        self.reload_from_ckpt = reload_from_ckpt
        self.reloaded = False
        self.mean, self.std = None, None
        self.feature_cache = None
        self.flag_first_run = True
        self.train_step = 0
        self.logger_save_dir = None
        self.logger_exp_name = None
        self.scale_factor = scale_factor
        print("Num parameters:")
        print("Encoder : ", sum(p.numel() for p in self.encoder.parameters()))
        print("Decoder : ", sum(p.numel() for p in self.decoder.parameters()))
        print("Vocoder : ", sum(p.numel() for p in self.vocoder.parameters()))

    def get_log_dir(self):
        if self.logger_save_dir is None and self.logger_exp_name is None:
            return os.path.join(self.logger.save_dir, self.logger._project)
        else:
            return os.path.join(self.logger_save_dir, self.logger_exp_name)

    def set_log_dir(self, save_dir, exp_name):
        self.logger_save_dir = save_dir
        self.logger_exp_name = exp_name

    def init_from_ckpt(self, path, ignore_keys=list()):
        sd = torch.load(path, map_location="cpu")["state_dict"]
        keys = list(sd.keys())
        for k in keys:
            for ik in ignore_keys:
                if k.startswith(ik):
                    print("Deleting key {} from state_dict.".format(k))
                    del sd[k]
        self.load_state_dict(sd, strict=False)
        print(f"Restored from {path}")

    def encode(self, x):
        # x = self.time_shuffle_operation(x)
        # x = self.freq_split_subband(x)
        h = self.encoder(x)
        moments = self.quant_conv(h)
        posterior = DiagonalGaussianDistribution(moments)
        return posterior

    def decode(self, z):
        z = self.post_quant_conv(z)
        dec = self.decoder(z)
        # bs, ch, shuffled_timesteps, fbins = dec.size()
        # dec = self.time_unshuffle_operation(dec, bs, int(ch * shuffled_timesteps), fbins)
        # dec = self.freq_merge_subband(dec)
        return dec

    def decode_to_waveform(self, dec):
        if self.image_key == "fbank":
            dec = dec.squeeze(1).permute(0, 2, 1)
            wav_reconstruction = vocoder_chunk_infer(dec, self.vocoder)
        elif self.image_key == "stft":
            dec = dec.squeeze(1).permute(0, 2, 1)
            wav_reconstruction = self.wave_decoder(dec)
        return wav_reconstruction

    def mel_spectrogram_to_waveform(
        self, mel, savepath=".", bs=None, name="outwav", save=True
    ):
        # Mel: [bs, 1, t-steps, fbins]
        if len(mel.size()) == 4:
            mel = mel.squeeze(1)
        mel = mel.permute(0, 2, 1)
        waveform = self.vocoder(mel)
        waveform = waveform.cpu().detach().numpy()
        # if save:
        #     self.save_waveform(waveform, savepath, name)
        return waveform

    def encode_first_stage(self, x):
        return self.encode(x)

    def decode_first_stage(self, z, predict_cids=False, force_not_quantize=False):
        if predict_cids:
            if z.dim() == 4:
                z = torch.argmax(z.exp(), dim=1).long()
            z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None)
            z = rearrange(z, "b h w c -> b c h w").contiguous()
        z = 1.0 / self.scale_factor * z
        return self.decode(z)

    def decode_first_stage_withgrad(self, z):
        z = 1.0 / self.scale_factor * z
        return self.decode(z)

    def get_first_stage_encoding(self, encoder_posterior, use_mode=False):
        if isinstance(encoder_posterior, DiagonalGaussianDistribution) and not use_mode:
            z = encoder_posterior.sample()
        elif isinstance(encoder_posterior, DiagonalGaussianDistribution) and use_mode:
            z = encoder_posterior.mode()
        elif isinstance(encoder_posterior, torch.Tensor):
            z = encoder_posterior
        else:
            raise NotImplementedError(
                f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented"
            )
        return self.scale_factor * z

    def visualize_latent(self, input):
        import matplotlib.pyplot as plt

        # for i in range(10):
        #     zero_input = torch.zeros_like(input) - 11.59
        #     zero_input[:, :, i * 16 : i * 16 + 16, :16] += 13.59
        #     posterior = self.encode(zero_input)
        #     latent = posterior.sample()
        #     avg_latent = torch.mean(latent, dim=1)[0]
        #     plt.imshow(avg_latent.cpu().detach().numpy().T)
        #     plt.savefig("%s.png" % i)
        #     plt.close()
        np.save("input.npy", input.cpu().detach().numpy())
        # zero_input = torch.zeros_like(input) - 11.59
        time_input = input.clone()
        time_input[:, :, :, :32] *= 0
        time_input[:, :, :, :32] -= 11.59
        np.save("time_input.npy", time_input.cpu().detach().numpy())
        posterior = self.encode(time_input)
        latent = posterior.sample()
        np.save("time_latent.npy", latent.cpu().detach().numpy())
        avg_latent = torch.mean(latent, dim=1)
        for i in range(avg_latent.size(0)):
            plt.imshow(avg_latent[i].cpu().detach().numpy().T)
            plt.savefig("freq_%s.png" % i)
            plt.close()
        freq_input = input.clone()
        freq_input[:, :, :512, :] *= 0
        freq_input[:, :, :512, :] -= 11.59
        np.save("freq_input.npy", freq_input.cpu().detach().numpy())
        posterior = self.encode(freq_input)
        latent = posterior.sample()
        np.save("freq_latent.npy", latent.cpu().detach().numpy())
        avg_latent = torch.mean(latent, dim=1)
        for i in range(avg_latent.size(0)):
            plt.imshow(avg_latent[i].cpu().detach().numpy().T)
            plt.savefig("time_%s.png" % i)
            plt.close()

    def get_input(self, batch):
        fname, text, label_indices, waveform, stft, fbank = (
            batch["fname"],
            batch["text"],
            batch["label_vector"],
            batch["waveform"],
            batch["stft"],
            batch["log_mel_spec"],
        )
        # if self.time_shuffle != 1:
        #     if fbank.size(1) % self.time_shuffle != 0:
        #         pad_len = self.time_shuffle - (fbank.size(1) % self.time_shuffle)
        #         fbank = torch.nn.functional.pad(fbank, (0, 0, 0, pad_len))
        ret = {}
        ret["fbank"], ret["stft"], ret["fname"], ret["waveform"] = (
            fbank.unsqueeze(1),
            stft.unsqueeze(1),
            fname,
            waveform.unsqueeze(1),
        )
        return ret

    def save_wave(self, batch_wav, fname, save_dir):
        os.makedirs(save_dir, exist_ok=True)
        for wav, name in zip(batch_wav, fname):
            name = os.path.basename(name)
            sf.write(os.path.join(save_dir, name), wav, samplerate=self.sampling_rate)

    def get_last_layer(self):
        return self.decoder.conv_out.weight

    def log_images(self, batch, train=True, only_inputs=False, waveform=None, **kwargs):
        log = dict()
        x = batch.to(self.device)
        if not only_inputs:
            xrec, posterior = self(x)
            log["samples"] = self.decode(posterior.sample())
            log["reconstructions"] = xrec
        log["inputs"] = x
        wavs = self._log_img(log, train=train, index=0, waveform=waveform)
        return wavs

    def _log_img(self, log, train=True, index=0, waveform=None):
        images_input = self.tensor2numpy(log["inputs"][index, 0]).T
        images_reconstruct = self.tensor2numpy(log["reconstructions"][index, 0]).T
        images_samples = self.tensor2numpy(log["samples"][index, 0]).T
        if train:
            name = "train"
        else:
            name = "val"
        if self.logger is not None:
            self.logger.log_image(
                "img_%s" % name,
                [images_input, images_reconstruct, images_samples],
                caption=["input", "reconstruct", "samples"],
            )
        inputs, reconstructions, samples = (
            log["inputs"],
            log["reconstructions"],
            log["samples"],
        )
        if self.image_key == "fbank":
            wav_original, wav_prediction = synth_one_sample(
                inputs[index],
                reconstructions[index],
                labels="validation",
                vocoder=self.vocoder,
            )
            wav_original, wav_samples = synth_one_sample(
                inputs[index], samples[index], labels="validation", vocoder=self.vocoder
            )
            wav_original, wav_samples, wav_prediction = (
                wav_original[0],
                wav_samples[0],
                wav_prediction[0],
            )
        elif self.image_key == "stft":
            wav_prediction = (
                self.decode_to_waveform(reconstructions)[index, 0]
                .cpu()
                .detach()
                .numpy()
            )
            wav_samples = (
                self.decode_to_waveform(samples)[index, 0].cpu().detach().numpy()
            )
            wav_original = waveform[index, 0].cpu().detach().numpy()
        if self.logger is not None:
            import wandb  # only needed when a wandb-backed logger is attached

            self.logger.experiment.log(
                {
                    "original_%s" % name: wandb.Audio(
                        wav_original, caption="original", sample_rate=self.sampling_rate
                    ),
                    "reconstruct_%s" % name: wandb.Audio(
                        wav_prediction,
                        caption="reconstruct",
                        sample_rate=self.sampling_rate,
                    ),
                    "samples_%s" % name: wandb.Audio(
                        wav_samples, caption="samples", sample_rate=self.sampling_rate
                    ),
                }
            )
        return wav_original, wav_prediction, wav_samples

    def tensor2numpy(self, tensor):
        return tensor.cpu().detach().numpy()

    def to_rgb(self, x):
        assert self.image_key == "segmentation"
        if not hasattr(self, "colorize"):
            self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
        x = torch.nn.functional.conv2d(x, weight=self.colorize)
        x = 2.0 * (x - x.min()) / (x.max() - x.min()) - 1.0
        return x


class IdentityFirstStage(torch.nn.Module):
    def __init__(self, *args, vq_interface=False, **kwargs):
        self.vq_interface = vq_interface  # TODO: Should be true by default but check to not break older stuff
        super().__init__()

    def encode(self, x, *args, **kwargs):
        return x

    def decode(self, x, *args, **kwargs):
        return x

    def quantize(self, x, *args, **kwargs):
        if self.vq_interface:
            return x, None, [None, None, None]
        return x

    def forward(self, x, *args, **kwargs):
        return x
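

# The helpers below implement the STFT front end: window_sumsquare reproduces the
# librosa envelope used for inverse-STFT normalisation, and dynamic_range_compression /
# dynamic_range_decompression apply the log compression used for mel features.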
def window_sumsquare(
    window,
    n_frames,
    hop_length,
    win_length,
    n_fft,
    dtype=np.float32,
    norm=None,
):
    """Compute the sum-square envelope of a window function at a given hop length.

    Adapted from librosa 0.6. This is used to estimate modulation effects induced by
    windowing observations in short-time Fourier transforms.

    Parameters
    ----------
    window : string, tuple, number, callable, or list-like
        Window specification, as in `get_window`
    n_frames : int > 0
        The number of analysis frames
    hop_length : int > 0
        The number of samples to advance between frames
    win_length : [optional]
        The length of the window function. By default, this matches `n_fft`.
    n_fft : int > 0
        The length of each analysis frame.
    dtype : np.dtype
        The data type of the output

    Returns
    -------
    wss : np.ndarray, shape=`(n_fft + hop_length * (n_frames - 1))`
        The sum-squared envelope of the window function
    """
    if win_length is None:
        win_length = n_fft
    n = n_fft + hop_length * (n_frames - 1)
    x = np.zeros(n, dtype=dtype)
    # Compute the squared window at the desired length
    win_sq = get_window(window, win_length, fftbins=True)
    win_sq = librosa_util.normalize(win_sq, norm=norm) ** 2
    win_sq = librosa_util.pad_center(win_sq, size=n_fft)
    # Fill the envelope
    for i in range(n_frames):
        sample = i * hop_length
        x[sample : min(n, sample + n_fft)] += win_sq[: max(0, min(n_fft, n - sample))]
    return x


def dynamic_range_compression(x, normalize_fun=torch.log, C=1, clip_val=1e-5):
    """
    PARAMS
    ------
    C: compression factor
    """
    return normalize_fun(torch.clamp(x, min=clip_val) * C)


def dynamic_range_decompression(x, C=1):
    """
    PARAMS
    ------
    C: compression factor used to compress
    """
    return torch.exp(x) / C
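

# STFT implements the forward and inverse short-time Fourier transform as 1-D
# convolutions with fixed Fourier bases (adapted from Prem Seetharaman's pytorch-stft),
# so it runs on whatever device the registered basis buffers live on.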
class STFT(torch.nn.Module):
    """adapted from Prem Seetharaman's https://github.com/pseeth/pytorch-stft"""

    def __init__(self, filter_length, hop_length, win_length, window="hann"):
        super(STFT, self).__init__()
        self.filter_length = filter_length
        self.hop_length = hop_length
        self.win_length = win_length
        self.window = window
        self.forward_transform = None
        scale = self.filter_length / self.hop_length
        fourier_basis = np.fft.fft(np.eye(self.filter_length))
        cutoff = int((self.filter_length / 2 + 1))
        fourier_basis = np.vstack(
            [np.real(fourier_basis[:cutoff, :]), np.imag(fourier_basis[:cutoff, :])]
        )
        forward_basis = torch.FloatTensor(fourier_basis[:, None, :])
        inverse_basis = torch.FloatTensor(
            np.linalg.pinv(scale * fourier_basis).T[:, None, :]
        )
        if window is not None:
            assert filter_length >= win_length
            # get window and zero center pad it to filter_length
            fft_window = get_window(window, win_length, fftbins=True)
            fft_window = pad_center(fft_window, size=filter_length)
            fft_window = torch.from_numpy(fft_window).float()
            # window the bases
            forward_basis *= fft_window
            inverse_basis *= fft_window
        self.register_buffer("forward_basis", forward_basis.float())
        self.register_buffer("inverse_basis", inverse_basis.float())

    def transform(self, input_data):
        device = self.forward_basis.device
        input_data = input_data.to(device)
        num_batches = input_data.size(0)
        num_samples = input_data.size(1)
        self.num_samples = num_samples
        # similar to librosa, reflect-pad the input
        input_data = input_data.view(num_batches, 1, num_samples)
        input_data = torch.nn.functional.pad(
            input_data.unsqueeze(1),
            (int(self.filter_length / 2), int(self.filter_length / 2), 0, 0),
            mode="reflect",
        )
        input_data = input_data.squeeze(1)
        forward_transform = torch.nn.functional.conv1d(
            input_data,
            torch.autograd.Variable(self.forward_basis, requires_grad=False),
            stride=self.hop_length,
            padding=0,
        )
        cutoff = int((self.filter_length / 2) + 1)
        real_part = forward_transform[:, :cutoff, :]
        imag_part = forward_transform[:, cutoff:, :]
        magnitude = torch.sqrt(real_part**2 + imag_part**2)
        phase = torch.autograd.Variable(torch.atan2(imag_part.data, real_part.data))
        return magnitude, phase

    def inverse(self, magnitude, phase):
        device = self.forward_basis.device
        magnitude, phase = magnitude.to(device), phase.to(device)
        recombine_magnitude_phase = torch.cat(
            [magnitude * torch.cos(phase), magnitude * torch.sin(phase)], dim=1
        )
        inverse_transform = torch.nn.functional.conv_transpose1d(
            recombine_magnitude_phase,
            torch.autograd.Variable(self.inverse_basis, requires_grad=False),
            stride=self.hop_length,
            padding=0,
        )
        if self.window is not None:
            window_sum = window_sumsquare(
                self.window,
                magnitude.size(-1),
                hop_length=self.hop_length,
                win_length=self.win_length,
                n_fft=self.filter_length,
                dtype=np.float32,
            )
            # remove modulation effects
            approx_nonzero_indices = torch.from_numpy(
                np.where(window_sum > tiny(window_sum))[0]
            )
            window_sum = torch.autograd.Variable(
                torch.from_numpy(window_sum), requires_grad=False
            )
            # keep the envelope on the same device as the transform before dividing
            window_sum = window_sum.to(inverse_transform.device)
            inverse_transform[:, :, approx_nonzero_indices] /= window_sum[
                approx_nonzero_indices
            ]
            # scale by hop ratio
            inverse_transform *= float(self.filter_length) / self.hop_length
        inverse_transform = inverse_transform[:, :, int(self.filter_length / 2) :]
        inverse_transform = inverse_transform[:, :, : -int(self.filter_length / 2)]
        return inverse_transform

    def forward(self, input_data):
        self.magnitude, self.phase = self.transform(input_data)
        reconstruction = self.inverse(self.magnitude, self.phase)
        return reconstruction
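

# TacotronSTFT combines the STFT above with a librosa mel filterbank to produce
# log-compressed mel spectrograms, per-bin log magnitudes, and frame energies from raw
# waveforms in [-1, 1].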
class TacotronSTFT(torch.nn.Module):
    def __init__(
        self,
        filter_length,
        hop_length,
        win_length,
        n_mel_channels,
        sampling_rate,
        mel_fmin,
        mel_fmax,
    ):
        super(TacotronSTFT, self).__init__()
        self.n_mel_channels = n_mel_channels
        self.sampling_rate = sampling_rate
        self.stft_fn = STFT(filter_length, hop_length, win_length)
        mel_basis = librosa_mel_fn(
            sr=sampling_rate,
            n_fft=filter_length,
            n_mels=n_mel_channels,
            fmin=mel_fmin,
            fmax=mel_fmax,
        )
        mel_basis = torch.from_numpy(mel_basis).float()
        self.register_buffer("mel_basis", mel_basis)

    def spectral_normalize(self, magnitudes, normalize_fun):
        output = dynamic_range_compression(magnitudes, normalize_fun)
        return output

    def spectral_de_normalize(self, magnitudes):
        output = dynamic_range_decompression(magnitudes)
        return output

    def mel_spectrogram(self, y, normalize_fun=torch.log):
        """Computes mel-spectrograms from a batch of waves

        PARAMS
        ------
        y: Variable(torch.FloatTensor) with shape (B, T) in range [-1, 1]

        RETURNS
        -------
        mel_output: torch.FloatTensor of shape (B, n_mel_channels, T)
        """
        assert torch.min(y.data) >= -1, torch.min(y.data)
        assert torch.max(y.data) <= 1, torch.max(y.data)
        magnitudes, phases = self.stft_fn.transform(y)
        magnitudes = magnitudes.data
        mel_output = torch.matmul(self.mel_basis, magnitudes)
        mel_output = self.spectral_normalize(mel_output, normalize_fun)
        energy = torch.norm(magnitudes, dim=1)
        log_magnitudes = self.spectral_normalize(magnitudes, normalize_fun)
        return mel_output, log_magnitudes, energy
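

# build_pretrained_models loads a training checkpoint, extracts the scale factor and
# every parameter under the "first_stage_model." prefix, instantiates the 48 kHz VAE
# with the config embedded below, and returns it together with the matching TacotronSTFT
# front end.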
def build_pretrained_models(ckpt):
    checkpoint = torch.load(ckpt, map_location="cpu")
    scale_factor = checkpoint["state_dict"]["scale_factor"].item()
    print("scale_factor: ", scale_factor)
    vae_state_dict = {
        k[18:]: v  # strip the 18-character "first_stage_model." prefix
        for k, v in checkpoint["state_dict"].items()
        if "first_stage_model." in k
    }
    config = {
        "preprocessing": {
            "audio": {
                "sampling_rate": 48000,
                "max_wav_value": 32768,
                "duration": 10.24,
            },
            "stft": {
                "filter_length": 2048,
                "hop_length": 480,
                "win_length": 2048,
            },
            "mel": {
                "n_mel_channels": 256,
                "mel_fmin": 20,
                "mel_fmax": 24000,
            },
        },
        "model": {
            "params": {
                "first_stage_config": {
                    "params": {
                        "sampling_rate": 48000,
                        "batchsize": 4,
                        "monitor": "val/rec_loss",
                        "image_key": "fbank",
                        "subband": 1,
                        "embed_dim": 16,
                        "time_shuffle": 1,
                        "lossconfig": {
                            "target": "audioldm2.latent_diffusion.modules.losses.LPIPSWithDiscriminator",
                            "params": {
                                "disc_start": 50001,
                                "kl_weight": 1000,
                                "disc_weight": 0.5,
                                "disc_in_channels": 1,
                            },
                        },
                        "ddconfig": {
                            "double_z": True,
                            "mel_bins": 256,
                            "z_channels": 16,
                            "resolution": 256,
                            "downsample_time": False,
                            "in_channels": 1,
                            "out_ch": 1,
                            "ch": 128,
                            "ch_mult": [1, 2, 4, 8],
                            "num_res_blocks": 2,
                            "attn_resolutions": [],
                            "dropout": 0,
                        },
                    }
                },
            }
        },
    }
    vae_config = config["model"]["params"]["first_stage_config"]["params"]
    vae_config["scale_factor"] = scale_factor
    vae = AutoencoderKL(**vae_config)
    vae.load_state_dict(vae_state_dict)
    fn_STFT = TacotronSTFT(
        config["preprocessing"]["stft"]["filter_length"],
        config["preprocessing"]["stft"]["hop_length"],
        config["preprocessing"]["stft"]["win_length"],
        config["preprocessing"]["mel"]["n_mel_channels"],
        config["preprocessing"]["audio"]["sampling_rate"],
        config["preprocessing"]["mel"]["mel_fmin"],
        config["preprocessing"]["mel"]["mel_fmax"],
    )
    vae.eval()
    fn_STFT.eval()
    return vae, fn_STFT
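

# Quick smoke test: build the models, push random "audio" through the mel front end,
# encode it into the VAE latent, decode back to a mel spectrogram, and vocode the result
# to a waveform. It relies on the repository-local tools.torch_tools.wav_to_fbank2 helper
# and a CUDA device.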
if __name__ == "__main__":
    # The VAE checkpoint path is expected as the first command-line argument.
    vae, stft = build_pretrained_models(sys.argv[1])
    vae, stft = vae.cuda(), stft.cuda()
    json_file = "outputs/wav.scp"
    out_path = "outputs/Music_inverse"
    waveform_in = torch.randn(2, int(48000 * 10.24))
    mel, _, waveform = torch_tools.wav_to_fbank2(waveform_in, target_length=-1, fn_STFT=stft)
    mel = mel.unsqueeze(1).cuda()
    print(mel.shape)
    # true_latent = torch.cat(
    #     [vae.get_first_stage_encoding(vae.encode_first_stage(mel[[m]])) for m in range(mel.shape[0])], 0
    # )
    # print(true_latent.shape)
    true_latent = vae.get_first_stage_encoding(vae.encode_first_stage(mel))
    print(true_latent.shape)
    # fold pairs of batch items into the channel axis and back, as a shape sanity check
    true_latent = true_latent.reshape(
        true_latent.shape[0] // 2, -1, true_latent.shape[2], true_latent.shape[3]
    ).detach()
    true_latent = true_latent.reshape(
        true_latent.shape[0] * 2, -1, true_latent.shape[2], true_latent.shape[3]
    )
    print("111", true_latent.size())
    mel = vae.decode_first_stage(true_latent)
    print("222", mel.size())
    audio = vae.decode_to_waveform(mel)
    print("333", audio.shape)
    # out_file = out_path + "/" + os.path.basename(fname.strip())
    # sf.write(out_file, audio[0], samplerate=48000)