Spaces:
Running
Running
File size: 2,799 Bytes
9b2107c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 |
import numpy as np
from torch import nn
from torch.nn.utils.parametrizations import weight_norm
class MelganDiscriminator(nn.Module):
def __init__(
self,
in_channels=1,
out_channels=1,
kernel_sizes=(5, 3),
base_channels=16,
max_channels=1024,
downsample_factors=(4, 4, 4, 4),
groups_denominator=4,
):
super().__init__()
self.layers = nn.ModuleList()
layer_kernel_size = np.prod(kernel_sizes)
layer_padding = (layer_kernel_size - 1) // 2
# initial layer
self.layers += [
nn.Sequential(
nn.ReflectionPad1d(layer_padding),
weight_norm(nn.Conv1d(in_channels, base_channels, layer_kernel_size, stride=1)),
nn.LeakyReLU(0.2, inplace=True),
)
]
# downsampling layers
layer_in_channels = base_channels
for downsample_factor in downsample_factors:
layer_out_channels = min(layer_in_channels * downsample_factor, max_channels)
layer_kernel_size = downsample_factor * 10 + 1
layer_padding = (layer_kernel_size - 1) // 2
layer_groups = layer_in_channels // groups_denominator
self.layers += [
nn.Sequential(
weight_norm(
nn.Conv1d(
layer_in_channels,
layer_out_channels,
kernel_size=layer_kernel_size,
stride=downsample_factor,
padding=layer_padding,
groups=layer_groups,
)
),
nn.LeakyReLU(0.2, inplace=True),
)
]
layer_in_channels = layer_out_channels
# last 2 layers
layer_padding1 = (kernel_sizes[0] - 1) // 2
layer_padding2 = (kernel_sizes[1] - 1) // 2
self.layers += [
nn.Sequential(
weight_norm(
nn.Conv1d(
layer_out_channels,
layer_out_channels,
kernel_size=kernel_sizes[0],
stride=1,
padding=layer_padding1,
)
),
nn.LeakyReLU(0.2, inplace=True),
),
weight_norm(
nn.Conv1d(
layer_out_channels, out_channels, kernel_size=kernel_sizes[1], stride=1, padding=layer_padding2
)
),
]
def forward(self, x):
feats = []
for layer in self.layers:
x = layer(x)
feats.append(x)
return x, feats
|