Spaces:
Build error
Build error
# Copyright (c) Meta Platforms, Inc. and affiliates. | |
# All rights reserved. | |
# | |
# This source code is licensed under the license found in the | |
# LICENSE file in the root directory of this source tree. | |
import typing as tp | |
import torchaudio | |
import torch | |
from torch import nn | |
from einops import rearrange | |
from ...modules import NormConv2d | |
from .base import MultiDiscriminator, MultiDiscriminatorOutputType | |
def get_2d_padding(kernel_size: tp.Tuple[int, int], dilation: tp.Tuple[int, int] = (1, 1)):
    """Compute symmetric padding for a 2D convolution.

    Each axis gets ``((k - 1) * d) // 2`` padding, where ``k`` is the kernel size and
    ``d`` the dilation on that axis. For odd effective kernel sizes and stride 1 this
    keeps the spatial size of the output equal to the input.

    Args:
        kernel_size: Kernel size for the first and second spatial axis.
        dilation: Dilation for the first and second spatial axis.

    Returns:
        A 2-tuple with the padding for the first and second spatial axis.
    """
    first_axis = ((kernel_size[0] - 1) * dilation[0]) // 2
    second_axis = ((kernel_size[1] - 1) * dilation[1]) // 2
    return (first_axis, second_axis)
class DiscriminatorSTFT(nn.Module):
    """STFT sub-discriminator.

    The input waveform is converted to a complex spectrogram whose real and imaginary
    parts are stacked along the channel axis. The result is rearranged to
    (time, frequency) spatial layout and processed by a stack of 2D convolutions.
    The activated output of every convolution in ``self.convs`` is returned as a
    feature map, and ``self.conv_post`` produces the final logits.

    Args:
        filters (int): Number of filters of the first convolution; inner layers scale from it.
        in_channels (int): Number of input (waveform) channels.
        out_channels (int): Number of output channels of the final convolution.
        n_fft (int): FFT size of the STFT.
        hop_length (int): Hop length between STFT windows.
        win_length (int): STFT window size.
        max_filters (int): Upper bound on the number of filters of the inner convolutions.
        filters_scale (int): Growth factor of the number of filters from one dilated
            convolution to the next.
        kernel_size (tuple of int): Inner Conv2d kernel size, given as (time, frequency).
        dilations (sequence of int): One dilated Conv2d is created per entry, dilated by
            that amount on the time axis.
        stride (tuple of int): Stride of the dilated Conv2d layers, given as (time, frequency).
        normalized (bool): Whether to normalize by magnitude after stft (forwarded to
            torchaudio's Spectrogram).
        norm (str): Normalization method forwarded to ``NormConv2d``.
        activation (str): Name of a ``torch.nn`` activation class applied after every
            convolution in ``self.convs``.
        activation_params (dict, optional): Keyword arguments for the activation.
            Defaults to ``{'negative_slope': 0.2}``.
    """
    def __init__(self, filters: int, in_channels: int = 1, out_channels: int = 1,
                 n_fft: int = 1024, hop_length: int = 256, win_length: int = 1024, max_filters: int = 1024,
                 filters_scale: int = 1, kernel_size: tp.Tuple[int, int] = (3, 9),
                 dilations: tp.Sequence[int] = (1, 2, 4),
                 stride: tp.Tuple[int, int] = (1, 2), normalized: bool = True, norm: str = 'weight_norm',
                 activation: str = 'LeakyReLU', activation_params: tp.Optional[dict] = None):
        super().__init__()
        assert len(kernel_size) == 2
        assert len(stride) == 2
        # Immutable/None defaults avoid sharing one list/dict object across instances.
        if activation_params is None:
            activation_params = {'negative_slope': 0.2}
        self.filters = filters
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.win_length = win_length
        self.normalized = normalized
        self.activation = getattr(torch.nn, activation)(**activation_params)
        # power=None keeps the complex STFT so real and imaginary parts can both be used.
        self.spec_transform = torchaudio.transforms.Spectrogram(
            n_fft=self.n_fft, hop_length=self.hop_length, win_length=self.win_length, window_fn=torch.hann_window,
            normalized=self.normalized, center=False, pad_mode=None, power=None)
        # Real and imaginary parts of every input channel are stacked on the channel axis.
        spec_channels = 2 * self.in_channels
        self.convs = nn.ModuleList()
        # NOTE(review): the first conv is not given `norm`, so it uses NormConv2d's own
        # default; left as-is so parameter names in existing checkpoints do not change.
        self.convs.append(
            NormConv2d(spec_channels, self.filters, kernel_size=kernel_size, padding=get_2d_padding(kernel_size))
        )
        # The dilated stack consumes the first conv's output, so it must start from exactly
        # `self.filters` channels (previously min(filters_scale * filters, max_filters), which
        # mismatched whenever filters_scale != 1 or filters > max_filters).
        in_chs = self.filters
        for i, dilation in enumerate(dilations):
            out_chs = min((filters_scale ** (i + 1)) * self.filters, max_filters)
            self.convs.append(NormConv2d(in_chs, out_chs, kernel_size=kernel_size, stride=stride,
                                         dilation=(dilation, 1), padding=get_2d_padding(kernel_size, (dilation, 1)),
                                         norm=norm))
            in_chs = out_chs
        out_chs = min((filters_scale ** (len(dilations) + 1)) * self.filters, max_filters)
        # The last inner conv and the post conv use a square kernel built from the time-axis size.
        square_kernel = (kernel_size[0], kernel_size[0])
        self.convs.append(NormConv2d(in_chs, out_chs, kernel_size=square_kernel,
                                     padding=get_2d_padding(square_kernel),
                                     norm=norm))
        self.conv_post = NormConv2d(out_chs, self.out_channels,
                                    kernel_size=square_kernel,
                                    padding=get_2d_padding(square_kernel),
                                    norm=norm)

    def forward(self, x: torch.Tensor):
        """Score a waveform batch.

        Args:
            x: Waveform of shape [B, C, T] with C == ``in_channels`` (the 4-axis
                rearrange below requires a 3D input).

        Returns:
            A tuple ``(logits, fmap)``: the output of ``conv_post`` and the list of
            activated outputs of every layer in ``self.convs``.
        """
        fmap = []
        z = self.spec_transform(x)  # complex spectrogram: [B, C, Freq, Frames]
        z = torch.cat([z.real, z.imag], dim=1)  # [B, 2C, Freq, Frames]
        z = rearrange(z, 'b c w t -> b c t w')  # [B, 2C, Frames, Freq]
        for layer in self.convs:
            z = self.activation(layer(z))
            fmap.append(z)
        z = self.conv_post(z)
        return z, fmap
class MultiScaleSTFTDiscriminator(MultiDiscriminator):
    """Multi-Scale STFT (MS-STFT) discriminator.

    Builds one ``DiscriminatorSTFT`` per STFT configuration (scale), runs all of them
    on the same input, and collects their logits and feature maps.

    Args:
        filters (int): Number of filters in convolutions.
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        sep_channels (bool): Separate channels to distinct samples for stereo support.
            When set, a [B, C, T] input is reshaped to [B * C, 1, T] before scoring and
            every sub-discriminator is built for single-channel input.
        n_ffts (Sequence[int]): Size of FFT for each scale.
        hop_lengths (Sequence[int]): Length of hop between STFT windows for each scale.
        win_lengths (Sequence[int]): Window size for each scale.
        **kwargs: Additional args for DiscriminatorSTFT.
    """
    def __init__(self, filters: int, in_channels: int = 1, out_channels: int = 1, sep_channels: bool = False,
                 n_ffts: tp.Sequence[int] = (1024, 2048, 512), hop_lengths: tp.Sequence[int] = (256, 512, 128),
                 win_lengths: tp.Sequence[int] = (1024, 2048, 512), **kwargs):
        super().__init__()
        assert len(n_ffts) == len(hop_lengths) == len(win_lengths)
        self.sep_channels = sep_channels
        # With channel separation every sub-discriminator only ever sees one channel.
        disc_in_channels = 1 if sep_channels else in_channels
        self.discriminators = nn.ModuleList([
            DiscriminatorSTFT(filters, in_channels=disc_in_channels, out_channels=out_channels,
                              n_fft=n_fft, win_length=win_length, hop_length=hop_length, **kwargs)
            for n_fft, hop_length, win_length in zip(n_ffts, hop_lengths, win_lengths)
        ])

    # NOTE(review): plain method here; if the MultiDiscriminator base declares this as a
    # property, this definition shadows it — confirm against the base class.
    def num_discriminators(self):
        """Return the number of sub-discriminators (one per STFT scale)."""
        return len(self.discriminators)

    def _separate_channels(self, x: torch.Tensor) -> torch.Tensor:
        """Fold the channel axis into the batch axis: [B, C, T] -> [B * C, 1, T]."""
        _, _, length = x.shape  # unpacking also rejects non-3D input
        # reshape (not view) so non-contiguous inputs are handled too.
        return x.reshape(-1, 1, length)

    def forward(self, x: torch.Tensor) -> MultiDiscriminatorOutputType:
        """Run every sub-discriminator on ``x``.

        Args:
            x: Waveform batch of shape [B, C, T].

        Returns:
            ``(logits, fmaps)``: one logit tensor and one list of feature maps per scale.
            With ``sep_channels`` the batch axis of each output is B * C.
        """
        if self.sep_channels:
            x = self._separate_channels(x)
        logits = []
        fmaps = []
        for disc in self.discriminators:
            logit, fmap = disc(x)
            logits.append(logit)
            fmaps.append(fmap)
        return logits, fmaps