from typing import Tuple

import torch
import torch.nn as nn  # pylint: disable=consider-using-from-import
import torch.nn.functional as F
from torch.nn.utils import parametrize

from TTS.tts.layers.delightful_tts.kernel_predictor import KernelPredictor


def calc_same_padding(kernel_size: int) -> Tuple[int, int]:
    pad = kernel_size // 2
    return (pad, pad - (kernel_size + 1) % 2)
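
# A quick sanity check of calc_same_padding (values computed from the formula
# above): an odd kernel pads symmetrically, an even kernel pads one less on the right.
#   >>> calc_same_padding(3)  # -> (1, 1)
#   >>> calc_same_padding(4)  # -> (2, 1)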


class ConvNorm(nn.Module):
    """A 1-dimensional convolutional layer with optional weight normalization.

    This layer wraps a 1D convolutional layer from PyTorch and applies
    optional weight normalization. It can be used like the convolutional
    layers in PyTorch's `torch.nn` module.

    Args:
        in_channels (int): The number of channels in the input signal.
        out_channels (int): The number of channels in the output signal.
        kernel_size (int, optional): The size of the convolving kernel.
            Defaults to 1.
        stride (int, optional): The stride of the convolution. Defaults to 1.
        padding (int, optional): Zero-padding added to both sides of the input.
            If `None`, the padding is calculated so that the output has the
            same length as the input (requires an odd `kernel_size`).
            Defaults to `None`.
        dilation (int, optional): Spacing between kernel elements. Defaults to 1.
        bias (bool, optional): If `True`, add bias after convolution. Defaults to `True`.
        w_init_gain (str, optional): The weight initialization gain to use,
            e.g. 'linear' or 'relu'. Defaults to 'linear'.
        use_weight_norm (bool, optional): If `True`, apply weight normalization
            to the convolutional weights. Defaults to `False`.

    Shapes:
        - Input: :math:`[N, C_{in}, T]`
        - Output: :math:`[N, C_{out}, T]` where :math:`C_{out}` is `out_channels`.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size=1,
        stride=1,
        padding=None,
        dilation=1,
        bias=True,
        w_init_gain="linear",
        use_weight_norm=False,
    ):
        super().__init__()
        if padding is None:
            assert kernel_size % 2 == 1
            padding = int(dilation * (kernel_size - 1) / 2)
        self.kernel_size = kernel_size
        self.dilation = dilation
        self.use_weight_norm = use_weight_norm
        conv_fn = nn.Conv1d
        self.conv = conv_fn(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            bias=bias,
        )
        nn.init.xavier_uniform_(self.conv.weight, gain=nn.init.calculate_gain(w_init_gain))
        if self.use_weight_norm:
            self.conv = nn.utils.parametrizations.weight_norm(self.conv)

    def forward(self, signal, mask=None):
        conv_signal = self.conv(signal)
        if mask is not None:
            # always re-zero output if mask is available to match zero-padding
            conv_signal = conv_signal * mask
        return conv_signal
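
# A hedged usage sketch for ConvNorm (tensor values illustrative; shapes follow
# the docstring above): with the default `padding=None` and an odd kernel, the
# sequence length is preserved.
#   >>> conv = ConvNorm(80, 256, kernel_size=5)
#   >>> conv(torch.randn(4, 80, 100)).shape  # -> [4, 256, 100]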


class ConvLSTMLinear(nn.Module):
    def __init__(
        self,
        in_dim,
        out_dim,
        n_layers=2,
        n_channels=256,
        kernel_size=3,
        p_dropout=0.1,
        lstm_type="bilstm",
        use_linear=True,
    ):
        super().__init__()
        self.out_dim = out_dim
        self.lstm_type = lstm_type
        self.use_linear = use_linear
        self.dropout = nn.Dropout(p=p_dropout)

        convolutions = []
        for i in range(n_layers):
            conv_layer = ConvNorm(
                in_dim if i == 0 else n_channels,
                n_channels,
                kernel_size=kernel_size,
                stride=1,
                padding=int((kernel_size - 1) / 2),
                dilation=1,
                w_init_gain="relu",
            )
            conv_layer = nn.utils.parametrizations.weight_norm(conv_layer.conv, name="weight")
            convolutions.append(conv_layer)
        self.convolutions = nn.ModuleList(convolutions)

        if not self.use_linear:
            n_channels = out_dim

        if self.lstm_type != "":
            use_bilstm = False
            lstm_channels = n_channels
            if self.lstm_type == "bilstm":
                use_bilstm = True
                lstm_channels = int(n_channels // 2)
            self.bilstm = nn.LSTM(n_channels, lstm_channels, 1, batch_first=True, bidirectional=use_bilstm)
            lstm_norm_fn_pntr = nn.utils.spectral_norm
            self.bilstm = lstm_norm_fn_pntr(self.bilstm, "weight_hh_l0")
            if self.lstm_type == "bilstm":
                self.bilstm = lstm_norm_fn_pntr(self.bilstm, "weight_hh_l0_reverse")

        if self.use_linear:
            self.dense = nn.Linear(n_channels, out_dim)

    def run_padded_sequence(self, context, lens):
        context_embedded = []
        for b_ind in range(context.size()[0]):  # TODO: speed up
            curr_context = context[b_ind : b_ind + 1, :, : lens[b_ind]].clone()
            for conv in self.convolutions:
                curr_context = self.dropout(F.relu(conv(curr_context)))
            context_embedded.append(curr_context[0].transpose(0, 1))
        context = nn.utils.rnn.pad_sequence(context_embedded, batch_first=True)
        return context

    def run_unsorted_inputs(self, fn, context, lens):  # pylint: disable=no-self-use
        lens_sorted, ids_sorted = torch.sort(lens, descending=True)
        unsort_ids = [0] * lens.size(0)
        for i, id_sorted in enumerate(ids_sorted):
            unsort_ids[id_sorted] = i
        lens_sorted = lens_sorted.long().cpu()
        context = context[ids_sorted]
        context = nn.utils.rnn.pack_padded_sequence(context, lens_sorted, batch_first=True)
        context = fn(context)[0]
        context = nn.utils.rnn.pad_packed_sequence(context, batch_first=True)[0]
        # map back to original indices
        context = context[unsort_ids]
        return context

    def forward(self, context, lens):
        if context.size()[0] > 1:
            context = self.run_padded_sequence(context, lens)
            # to B, D, T
            context = context.transpose(1, 2)
        else:
            for conv in self.convolutions:
                context = self.dropout(F.relu(conv(context)))

        if self.lstm_type != "":
            context = context.transpose(1, 2)
            self.bilstm.flatten_parameters()
            if lens is not None:
                context = self.run_unsorted_inputs(self.bilstm, context, lens)
            else:
                context = self.bilstm(context)[0]
            context = context.transpose(1, 2)

        x_hat = context
        if self.use_linear:
            x_hat = self.dense(context.transpose(1, 2)).transpose(1, 2)

        return x_hat
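
# A hedged usage sketch for ConvLSTMLinear (tensor values illustrative; shapes
# inferred from the code above): input is `[B, in_dim, T]` with per-item lengths,
# output is `[B, out_dim, T]`.
#   >>> layer = ConvLSTMLinear(in_dim=80, out_dim=4)
#   >>> x_hat = layer(torch.randn(2, 80, 100), torch.tensor([100, 80]))
#   >>> x_hat.shape  # -> [2, 4, 100]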


class DepthWiseConv1d(nn.Module):
    def __init__(self, in_channels: int, out_channels: int, kernel_size: int, padding: int):
        super().__init__()
        self.conv = nn.Conv1d(in_channels, out_channels, kernel_size, padding=padding, groups=in_channels)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.conv(x)


class PointwiseConv1d(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        stride: int = 1,
        padding: int = 0,
        bias: bool = True,
    ):
        super().__init__()
        self.conv = nn.Conv1d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=1,
            stride=stride,
            padding=padding,
            bias=bias,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.conv(x)
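
# A hedged sketch of how these two wrappers compose into a standard
# depthwise-separable convolution (channel counts illustrative): a per-channel
# spatial filter followed by a 1x1 channel mixer.
#   >>> dw = DepthWiseConv1d(64, 64, kernel_size=7, padding=3)
#   >>> pw = PointwiseConv1d(64, 128)
#   >>> pw(dw(torch.randn(2, 64, 100))).shape  # -> [2, 128, 100]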


class BSConv1d(nn.Module):
    """Blueprint-separable 1D convolution: a pointwise (1x1) convolution followed
    by a depthwise convolution, i.e. the reverse of the usual depthwise-separable
    ordering. See https://arxiv.org/pdf/2003.13549.pdf."""

    def __init__(self, channels_in: int, channels_out: int, kernel_size: int, padding: int):
        super().__init__()
        self.pointwise = nn.Conv1d(channels_in, channels_out, kernel_size=1)
        self.depthwise = nn.Conv1d(
            channels_out,
            channels_out,
            kernel_size=kernel_size,
            padding=padding,
            groups=channels_out,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x1 = self.pointwise(x)
        x2 = self.depthwise(x1)
        return x2


class BSConv2d(nn.Module):
    """Blueprint-separable 2D convolution (pointwise then depthwise).
    See https://arxiv.org/pdf/2003.13549.pdf."""

    def __init__(self, channels_in: int, channels_out: int, kernel_size: int, padding: int):
        super().__init__()
        self.pointwise = nn.Conv2d(channels_in, channels_out, kernel_size=1)
        self.depthwise = nn.Conv2d(
            channels_out,
            channels_out,
            kernel_size=kernel_size,
            padding=padding,
            groups=channels_out,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x1 = self.pointwise(x)
        x2 = self.depthwise(x1)
        return x2
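
# A hedged usage sketch for BSConv1d (channel counts illustrative): padding
# `(kernel_size - 1) // 2` keeps the length unchanged for odd kernels.
#   >>> bsc = BSConv1d(64, 128, kernel_size=5, padding=2)
#   >>> bsc(torch.randn(2, 64, 100)).shape  # -> [2, 128, 100]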


class Conv1dGLU(nn.Module):
    """Conv1d layer with a gated linear unit and speaker-embedding conditioning,
    from Deep Voice 3."""

    def __init__(self, d_model: int, kernel_size: int, padding: int, embedding_dim: int):
        super().__init__()
        self.conv = BSConv1d(d_model, 2 * d_model, kernel_size=kernel_size, padding=padding)
        self.embedding_proj = nn.Linear(embedding_dim, d_model)
        self.register_buffer("sqrt", torch.sqrt(torch.FloatTensor([0.5])).squeeze(0))
        self.softsign = torch.nn.Softsign()

    def forward(self, x: torch.Tensor, embeddings: torch.Tensor) -> torch.Tensor:
        x = x.permute((0, 2, 1))
        residual = x
        x = self.conv(x)
        splitdim = 1
        a, b = x.split(x.size(splitdim) // 2, dim=splitdim)
        embeddings = self.embedding_proj(embeddings).unsqueeze(2)
        softsign = self.softsign(embeddings)
        softsign = softsign.expand_as(a)
        a = a + softsign
        x = a * torch.sigmoid(b)
        x = x + residual
        x = x * self.sqrt
        x = x.permute((0, 2, 1))
        return x
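
# A hedged usage sketch for Conv1dGLU (sizes illustrative; shapes inferred from
# the permutes above): input is channel-last `[B, T, d_model]`, the embedding is
# `[B, embedding_dim]`, and the output keeps the input shape.
#   >>> glu = Conv1dGLU(d_model=256, kernel_size=5, padding=2, embedding_dim=512)
#   >>> glu(torch.randn(2, 50, 256), torch.randn(2, 512)).shape  # -> [2, 50, 256]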


class ConvTransposed(nn.Module):
    """A 1D convolution applied to channel-last input.

    Despite the name, this layer does not perform a transposed convolution.
    It transposes the input from `[B, T, C]` to `[B, C, T]`, applies a
    `BSConv1d`, and transposes the result back, so only the number of
    channels changes.

    Attributes:
        in_channels (int): The number of channels in the input tensor.
        out_channels (int): The number of channels in the output tensor.
        kernel_size (int): The size of the convolutional kernel. Default: 1.
        padding (int): The number of padding elements to add to the input tensor. Default: 0.
        conv (BSConv1d): The underlying convolutional layer.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int = 1,
        padding: int = 0,
    ):
        super().__init__()
        self.conv = BSConv1d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            padding=padding,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x.contiguous().transpose(1, 2)
        x = self.conv(x)
        x = x.contiguous().transpose(1, 2)
        return x
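
# A hedged usage sketch for ConvTransposed (channel counts illustrative):
#   >>> ct = ConvTransposed(80, 256, kernel_size=3, padding=1)
#   >>> ct(torch.randn(2, 100, 80)).shape  # -> [2, 100, 256]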


class DepthwiseConvModule(nn.Module):
    def __init__(self, dim: int, kernel_size: int = 7, expansion: int = 4, lrelu_slope: float = 0.3):
        super().__init__()
        padding = calc_same_padding(kernel_size)
        self.depthwise = nn.Conv1d(
            dim,
            dim * expansion,
            kernel_size=kernel_size,
            padding=padding[0],
            groups=dim,
        )
        self.act = nn.LeakyReLU(lrelu_slope)
        self.out = nn.Conv1d(dim * expansion, dim, 1, 1, 0)
        self.ln = nn.LayerNorm(dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.ln(x)
        x = x.permute((0, 2, 1))
        x = self.depthwise(x)
        x = self.act(x)
        x = self.out(x)
        x = x.permute((0, 2, 1))
        return x
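
# A hedged usage sketch for DepthwiseConvModule (dims illustrative): channel-last
# input `[B, T, dim]` in, same shape out; the depthwise conv expands channels by
# `expansion` and the 1x1 output conv projects back to `dim`.
#   >>> m = DepthwiseConvModule(dim=256)
#   >>> m(torch.randn(2, 100, 256)).shape  # -> [2, 100, 256]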


class AddCoords(nn.Module):
    def __init__(self, rank: int, with_r: bool = False):
        super().__init__()
        self.rank = rank
        self.with_r = with_r

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.rank == 1:
            batch_size_shape, channel_in_shape, dim_x = x.shape  # pylint: disable=unused-variable
            # coordinate channel normalized to [-1, 1] along the last axis
            xx_range = torch.arange(dim_x, dtype=torch.int32)
            xx_channel = xx_range[None, None, :]
            xx_channel = xx_channel.float() / (dim_x - 1)
            xx_channel = xx_channel * 2 - 1
            xx_channel = xx_channel.repeat(batch_size_shape, 1, 1)
            xx_channel = xx_channel.to(x.device)
            out = torch.cat([x, xx_channel], dim=1)

            if self.with_r:
                rr = torch.sqrt(torch.pow(xx_channel - 0.5, 2))
                out = torch.cat([out, rr], dim=1)

        elif self.rank == 2:
            batch_size_shape, channel_in_shape, dim_y, dim_x = x.shape
            xx_ones = torch.ones([1, 1, 1, dim_x], dtype=torch.int32)
            yy_ones = torch.ones([1, 1, 1, dim_y], dtype=torch.int32)

            xx_range = torch.arange(dim_y, dtype=torch.int32)
            yy_range = torch.arange(dim_x, dtype=torch.int32)
            xx_range = xx_range[None, None, :, None]
            yy_range = yy_range[None, None, :, None]

            # build the coordinate grids via outer products with the ones vectors
            xx_channel = torch.matmul(xx_range, xx_ones)
            yy_channel = torch.matmul(yy_range, yy_ones)

            # transpose y
            yy_channel = yy_channel.permute(0, 1, 3, 2)

            # normalize both grids to [-1, 1]
            xx_channel = xx_channel.float() / (dim_y - 1)
            yy_channel = yy_channel.float() / (dim_x - 1)
            xx_channel = xx_channel * 2 - 1
            yy_channel = yy_channel * 2 - 1

            xx_channel = xx_channel.repeat(batch_size_shape, 1, 1, 1)
            yy_channel = yy_channel.repeat(batch_size_shape, 1, 1, 1)

            xx_channel = xx_channel.to(x.device)
            yy_channel = yy_channel.to(x.device)
            out = torch.cat([x, xx_channel, yy_channel], dim=1)

            if self.with_r:
                # radial distance channel
                rr = torch.sqrt(torch.pow(xx_channel - 0.5, 2) + torch.pow(yy_channel - 0.5, 2))
                out = torch.cat([out, rr], dim=1)

        elif self.rank == 3:
            batch_size_shape, channel_in_shape, dim_z, dim_y, dim_x = x.shape
            xx_ones = torch.ones([1, 1, 1, 1, dim_x], dtype=torch.int32)
            yy_ones = torch.ones([1, 1, 1, 1, dim_y], dtype=torch.int32)
            zz_ones = torch.ones([1, 1, 1, 1, dim_z], dtype=torch.int32)

            xy_range = torch.arange(dim_y, dtype=torch.int32)
            xy_range = xy_range[None, None, None, :, None]

            yz_range = torch.arange(dim_z, dtype=torch.int32)
            yz_range = yz_range[None, None, None, :, None]

            zx_range = torch.arange(dim_x, dtype=torch.int32)
            zx_range = zx_range[None, None, None, :, None]

            xy_channel = torch.matmul(xy_range, xx_ones)
            xx_channel = torch.cat([xy_channel + i for i in range(dim_z)], dim=2)

            yz_channel = torch.matmul(yz_range, yy_ones)
            yz_channel = yz_channel.permute(0, 1, 3, 4, 2)
            yy_channel = torch.cat([yz_channel + i for i in range(dim_x)], dim=4)

            zx_channel = torch.matmul(zx_range, zz_ones)
            zx_channel = zx_channel.permute(0, 1, 4, 2, 3)
            zz_channel = torch.cat([zx_channel + i for i in range(dim_y)], dim=3)

            xx_channel = xx_channel.to(x.device)
            yy_channel = yy_channel.to(x.device)
            zz_channel = zz_channel.to(x.device)
            out = torch.cat([x, xx_channel, yy_channel, zz_channel], dim=1)

            if self.with_r:
                rr = torch.sqrt(
                    torch.pow(xx_channel - 0.5, 2) + torch.pow(yy_channel - 0.5, 2) + torch.pow(zz_channel - 0.5, 2)
                )
                out = torch.cat([out, rr], dim=1)
        else:
            raise NotImplementedError
        return out
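
# A hedged usage sketch for AddCoords (shapes illustrative): `rank` extra
# channels are appended, plus one more when `with_r=True`.
#   >>> ac = AddCoords(rank=1, with_r=True)
#   >>> ac(torch.randn(2, 4, 10)).shape  # -> [2, 6, 10]  (4 + 1 coord + 1 radius)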


class CoordConv1d(nn.modules.conv.Conv1d):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        padding: int = 0,
        dilation: int = 1,
        groups: int = 1,
        bias: bool = True,
        with_r: bool = False,
    ):
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            dilation,
            groups,
            bias,
        )
        self.rank = 1
        self.addcoords = AddCoords(self.rank, with_r)
        # the inner conv consumes the input plus the appended coordinate channels
        self.conv = nn.Conv1d(
            in_channels + self.rank + int(with_r),
            out_channels,
            kernel_size,
            stride,
            padding,
            dilation,
            groups,
            bias,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.addcoords(x)
        x = self.conv(x)
        return x
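
# A hedged usage sketch for CoordConv1d (channel counts illustrative): the call
# signature matches nn.Conv1d, with the coordinate channel added internally.
#   >>> cc = CoordConv1d(4, 8, kernel_size=3, padding=1)
#   >>> cc(torch.randn(2, 4, 10)).shape  # -> [2, 8, 10]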


class CoordConv2d(nn.modules.conv.Conv2d):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        padding: int = 0,
        dilation: int = 1,
        groups: int = 1,
        bias: bool = True,
        with_r: bool = False,
    ):
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            dilation,
            groups,
            bias,
        )
        self.rank = 2
        self.addcoords = AddCoords(self.rank, with_r)
        self.conv = nn.Conv2d(
            in_channels + self.rank + int(with_r),
            out_channels,
            kernel_size,
            stride,
            padding,
            dilation,
            groups,
            bias,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.addcoords(x)
        x = self.conv(x)
        return x


class LVCBlock(torch.nn.Module):
    """The location-variable convolution (LVC) block."""

    def __init__(  # pylint: disable=dangerous-default-value
        self,
        in_channels,
        cond_channels,
        stride,
        dilations=[1, 3, 9, 27],
        lReLU_slope=0.2,
        conv_kernel_size=3,
        cond_hop_length=256,
        kpnet_hidden_channels=64,
        kpnet_conv_size=3,
        kpnet_dropout=0.0,
    ):
        super().__init__()

        self.cond_hop_length = cond_hop_length
        self.conv_layers = len(dilations)
        self.conv_kernel_size = conv_kernel_size

        self.kernel_predictor = KernelPredictor(
            cond_channels=cond_channels,
            conv_in_channels=in_channels,
            conv_out_channels=2 * in_channels,
            conv_layers=len(dilations),
            conv_kernel_size=conv_kernel_size,
            kpnet_hidden_channels=kpnet_hidden_channels,
            kpnet_conv_size=kpnet_conv_size,
            kpnet_dropout=kpnet_dropout,
            kpnet_nonlinear_activation_params={"negative_slope": lReLU_slope},
        )

        self.convt_pre = nn.Sequential(
            nn.LeakyReLU(lReLU_slope),
            nn.utils.parametrizations.weight_norm(
                nn.ConvTranspose1d(
                    in_channels,
                    in_channels,
                    2 * stride,
                    stride=stride,
                    padding=stride // 2 + stride % 2,
                    output_padding=stride % 2,
                )
            ),
        )

        self.conv_blocks = nn.ModuleList()
        for dilation in dilations:
            self.conv_blocks.append(
                nn.Sequential(
                    nn.LeakyReLU(lReLU_slope),
                    nn.utils.parametrizations.weight_norm(
                        nn.Conv1d(
                            in_channels,
                            in_channels,
                            conv_kernel_size,
                            padding=dilation * (conv_kernel_size - 1) // 2,
                            dilation=dilation,
                        )
                    ),
                    nn.LeakyReLU(lReLU_slope),
                )
            )

    def forward(self, x, c):
        """Forward propagation of the location-variable convolutions.

        Args:
            x (Tensor): the input sequence (batch, in_channels, in_length)
            c (Tensor): the conditioning sequence (batch, cond_channels, cond_length)

        Returns:
            Tensor: the output sequence (batch, in_channels, stride * in_length)
        """
        _, in_channels, _ = x.shape  # (B, c_g, L')

        x = self.convt_pre(x)  # (B, c_g, stride * L')
        kernels, bias = self.kernel_predictor(c)

        for i, conv in enumerate(self.conv_blocks):
            output = conv(x)  # (B, c_g, stride * L')

            k = kernels[:, i, :, :, :, :]  # (B, 2 * c_g, c_g, kernel_size, cond_length)
            b = bias[:, i, :, :]  # (B, 2 * c_g, cond_length)

            output = self.location_variable_convolution(
                output, k, b, hop_size=self.cond_hop_length
            )  # (B, 2 * c_g, stride * L'): LVC
            x = x + torch.sigmoid(output[:, :in_channels, :]) * torch.tanh(
                output[:, in_channels:, :]
            )  # (B, c_g, stride * L'): GAU

        return x

    def location_variable_convolution(self, x, kernel, bias, dilation=1, hop_size=256):  # pylint: disable=no-self-use
        """Perform location-variable convolution on the input sequence (x) using the local convolution kernel.

        Time: 414 μs ± 309 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each), measured on an NVIDIA V100.

        Args:
            x (Tensor): the input sequence (batch, in_channels, in_length).
            kernel (Tensor): the local convolution kernel (batch, in_channels, out_channels, kernel_size, kernel_length)
            bias (Tensor): the bias for the local convolution (batch, out_channels, kernel_length)
            dilation (int): the dilation of convolution.
            hop_size (int): the hop_size of the conditioning sequence.

        Returns:
            (Tensor): the output sequence after performing local convolution. (batch, out_channels, in_length).
        """
        batch, _, in_length = x.shape
        batch, _, out_channels, kernel_size, kernel_length = kernel.shape
        assert in_length == (kernel_length * hop_size), "length of (x, kernel) does not match"

        padding = dilation * int((kernel_size - 1) / 2)
        x = F.pad(x, (padding, padding), "constant", 0)  # (batch, in_channels, in_length + 2*padding)
        x = x.unfold(2, hop_size + 2 * padding, hop_size)  # (batch, in_channels, kernel_length, hop_size + 2*padding)

        if hop_size < dilation:
            x = F.pad(x, (0, dilation), "constant", 0)
        x = x.unfold(
            3, dilation, dilation
        )  # (batch, in_channels, kernel_length, (hop_size + 2*padding)/dilation, dilation)
        x = x[:, :, :, :, :hop_size]
        x = x.transpose(3, 4)  # (batch, in_channels, kernel_length, dilation, (hop_size + 2*padding)/dilation)
        x = x.unfold(4, kernel_size, 1)  # (batch, in_channels, kernel_length, dilation, _, kernel_size)

        o = torch.einsum("bildsk,biokl->bolsd", x, kernel)
        o = o.to(memory_format=torch.channels_last_3d)
        bias = bias.unsqueeze(-1).unsqueeze(-1).to(memory_format=torch.channels_last_3d)
        o = o + bias
        o = o.contiguous().view(batch, out_channels, -1)
        return o

    def remove_weight_norm(self):
        self.kernel_predictor.remove_weight_norm()
        parametrize.remove_parametrizations(self.convt_pre[1], "weight")
        for block in self.conv_blocks:
            parametrize.remove_parametrizations(block[1], "weight")
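
# A hedged usage sketch for LVCBlock (all sizes illustrative, and assuming the
# KernelPredictor preserves the conditioning length): the upsampled input length
# must equal cond_length * cond_hop_length for the assert in
# location_variable_convolution to pass.
#   >>> blk = LVCBlock(in_channels=32, cond_channels=80, stride=8)
#   >>> x = torch.randn(1, 32, 128)  # upsampled to 8 * 128 = 1024 by convt_pre
#   >>> c = torch.randn(1, 80, 4)    # 4 frames * cond_hop_length 256 = 1024
#   >>> blk(x, c).shape              # -> [1, 32, 1024]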