|
import copy |
|
import numpy as np |
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
from common.audio import stft |
|
from torch.nn.utils import weight_norm, spectral_norm |
|
from torch.nn import Conv1d |
|
from einops import rearrange |
|
|
|
class SpecDiscriminator(nn.Module):
    """Multi-resolution spectrogram discriminator.

    One ``NLayerSpecDiscriminator`` is built per STFT resolution listed in
    ``stft_params``. ``forward`` computes one spectrogram per resolution and
    feeds each to its own sub-discriminator.

    Args:
        stft_params: Dict with the keys ``'fft_sizes'``, ``'hop_sizes'``,
            ``'win_lengths'`` (parallel lists, one entry per resolution) and
            ``'window'`` (name of a window factory on the ``torch`` module,
            e.g. ``'hann_window'``). ``None`` selects three built-in
            resolutions.
        in_channels: Forwarded to every sub-discriminator.
        out_channels: Forwarded to every sub-discriminator.
        kernel_sizes: Forwarded to every sub-discriminator.
        channels: Forwarded to every sub-discriminator.
        max_downsample_channels: Forwarded to every sub-discriminator.
        downsample_scales: Forwarded to every sub-discriminator.
        use_weight_norm: If true, wrap every convolution in weight
            normalisation.
    """

    # Layer types touched by weight normalisation and re-initialisation.
    _CONV_TYPES = (nn.Conv1d, nn.ConvTranspose1d, nn.Conv2d, nn.ConvTranspose2d)

    def __init__(self,
                 stft_params=None,
                 in_channels=1,
                 out_channels=1,
                 kernel_sizes=(7, 3),
                 channels=32,
                 max_downsample_channels=512,
                 downsample_scales=(2, 2, 2),
                 use_weight_norm=True,
                 ):
        super().__init__()

        if stft_params is None:
            stft_params = {
                'fft_sizes': [1024, 2048, 512],
                'hop_sizes': [120, 240, 50],
                'win_lengths': [600, 1200, 240],
                'window': 'hann_window',
            }
        self.stft_params = stft_params

        # Keys are "disc_0", "disc_1", ...; they end up in state-dict names,
        # so they must not change.
        self.model = nn.ModuleDict()
        for i, _ in enumerate(stft_params['fft_sizes']):
            self.model[f"disc_{i}"] = NLayerSpecDiscriminator(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_sizes=kernel_sizes,
                channels=channels,
                max_downsample_channels=max_downsample_channels,
                downsample_scales=downsample_scales,
            )

        if use_weight_norm:
            self.apply_weight_norm()
        # NOTE(review): when weight norm is active, ``m.weight`` is recomputed
        # from ``weight_g``/``weight_v`` before each forward pass, so this
        # re-initialisation looks like it only takes effect when
        # ``use_weight_norm`` is false. The order is kept as-is to preserve
        # the existing initialisation behaviour -- confirm the intent.
        self.reset_parameters()

    def forward(self, x):
        """Run every sub-discriminator on its own STFT of ``x``.

        Args:
            x: Input tensor. Dimension 1 is squeezed away when it has size 1
                (presumably a ``(batch, 1, samples)`` waveform -- confirm
                against the caller).

        Returns:
            A list with one entry per resolution; each entry is the list of
            feature maps returned by the corresponding sub-discriminator.
        """
        x = x.squeeze(1)
        params = self.stft_params
        window_fn = getattr(torch, params['window'])

        results = []
        for i, disc in enumerate(self.model.values()):
            win_length = params['win_lengths'][i]
            # Create the window on the input's device so the STFT does not
            # mix CPU and accelerator tensors.
            window = window_fn(win_length, device=x.device)
            spec = stft(x, params['fft_sizes'][i], params['hop_sizes'][i],
                        win_length, window=window)
            # Swap dims 1 and 2 and insert a singleton channel axis so the
            # 2-D convolutions receive a 4-D input.
            spec = spec.transpose(1, 2).unsqueeze(1)
            results.append(disc(spec))
        return results

    def remove_weight_norm(self):
        """Strip weight normalisation from every sub-module that has it."""
        def _remove_weight_norm(m):
            try:
                torch.nn.utils.remove_weight_norm(m)
            except ValueError:
                # Raised for modules that carry no weight norm; skip them.
                return

        self.apply(_remove_weight_norm)

    def apply_weight_norm(self):
        """Wrap every (transposed) 1-D/2-D convolution in weight norm."""
        def _apply_weight_norm(m):
            if isinstance(m, self._CONV_TYPES):
                torch.nn.utils.weight_norm(m)

        self.apply(_apply_weight_norm)

    def reset_parameters(self):
        """Draw all convolution weights from N(0, 0.02)."""
        def _reset_parameters(m):
            if isinstance(m, self._CONV_TYPES):
                m.weight.data.normal_(0.0, 0.02)

        self.apply(_reset_parameters)
|
|
|
|
|
class NLayerSpecDiscriminator(nn.Module):
    """Stack of 2-D convolutions that scores a single spectrogram.

    Layers, in order (names are the ``ModuleDict`` keys):

    * ``layer_0``: stride-2 convolution with kernel ``kernel_sizes[0]``,
      followed by LeakyReLU.
    * ``layer_1`` .. ``layer_N``: one strided convolution per entry of
      ``downsample_scales`` (kernel ``2 * scale + 1``, stride ``scale``,
      padding ``scale``), each followed by LeakyReLU. The channel count is
      multiplied by ``scale`` and capped at ``max_downsample_channels``.
    * ``layer_{N+1}``: stride-1 convolution with kernel ``kernel_sizes[1]``,
      followed by LeakyReLU.
    * ``layer_{N+2}``: final stride-1 convolution to ``out_channels``, with
      no activation.

    Args:
        in_channels: Channels of the input tensor.
        out_channels: Channels produced by the final layer.
        kernel_sizes: Pair of odd kernel sizes: first layer, then the last
            two layers.
        channels: Output channels of the first layer.
        max_downsample_channels: Upper bound on the channel count.
        downsample_scales: Stride of each intermediate layer.

    Raises:
        AssertionError: If either entry of ``kernel_sizes`` is even.
    """

    def __init__(self,
                 in_channels=1,
                 out_channels=1,
                 kernel_sizes=(5, 3),
                 channels=32,
                 max_downsample_channels=512,
                 downsample_scales=(2, 2, 2)):
        super().__init__()

        # Odd kernels are needed for the ``kernel // 2`` padding used below.
        assert kernel_sizes[0] % 2 == 1, "kernel_sizes[0] must be odd"
        assert kernel_sizes[1] % 2 == 1, "kernel_sizes[1] must be odd"

        first_kernel, last_kernel = kernel_sizes[0], kernel_sizes[1]
        num_scales = len(downsample_scales)

        model = nn.ModuleDict()

        model["layer_0"] = nn.Sequential(
            nn.Conv2d(in_channels, channels,
                      kernel_size=first_kernel,
                      stride=2,
                      padding=first_kernel // 2),
            nn.LeakyReLU(0.2, True),
        )

        in_chs = channels
        for i, downsample_scale in enumerate(downsample_scales):
            out_chs = min(in_chs * downsample_scale, max_downsample_channels)
            model[f"layer_{i + 1}"] = nn.Sequential(
                nn.Conv2d(
                    in_chs,
                    out_chs,
                    kernel_size=downsample_scale * 2 + 1,
                    stride=downsample_scale,
                    padding=downsample_scale,
                ),
                nn.LeakyReLU(0.2, True),
            )
            in_chs = out_chs

        out_chs = min(in_chs * 2, max_downsample_channels)
        model[f"layer_{num_scales + 1}"] = nn.Sequential(
            nn.Conv2d(in_chs, out_chs,
                      kernel_size=last_kernel,
                      padding=last_kernel // 2),
            nn.LeakyReLU(0.2, True),
        )

        # Output projection: no activation after it.
        model[f"layer_{num_scales + 2}"] = nn.Conv2d(
            out_chs, out_channels,
            kernel_size=last_kernel,
            padding=last_kernel // 2)

        self.model = model

    def forward(self, x):
        """Apply the layers in order and collect every intermediate output.

        Args:
            x: Input tensor accepted by ``nn.Conv2d`` with ``in_channels``
                channels.

        Returns:
            List with the output of each layer in order; the last element is
            the output of the final convolution.
        """
        results = []
        for layer in self.model.values():
            x = layer(x)
            results.append(x)
        return results
|
|
|
|