import copy

import torch
import torch.nn as nn

from s3prl import Output
from s3prl.nn.vq_apc import VqApcLayer


class MaskConvBlock(nn.Module):
    """Masked convolution block as described in the NPC paper."""

    def __init__(self, input_size, hidden_size, kernel_size, mask_size):
        super(MaskConvBlock, self).__init__()
        assert kernel_size - mask_size > 0, "Mask must be smaller than kernel"
        # CNN for computing feature (ToDo: other activation?)
        self.act = nn.Tanh()
        self.pad_size = (kernel_size - 1) // 2
        self.conv = nn.Conv1d(
            in_channels=input_size,
            out_channels=hidden_size,
            kernel_size=kernel_size,
            padding=self.pad_size,
        )
        # Fixed mask for NPC: zero out the center taps of the kernel
        mask_head = (kernel_size - mask_size) // 2
        mask_tail = mask_head + mask_size
        conv_mask = torch.ones_like(self.conv.weight)
        conv_mask[:, :, mask_head:mask_tail] = 0
        self.register_buffer("conv_mask", conv_mask)

    def forward(self, feat):
        feat = nn.functional.conv1d(
            feat,
            self.conv_mask * self.conv.weight,
            bias=self.conv.bias,
            padding=self.pad_size,
        )
        feat = feat.permute(0, 2, 1)  # BxCxT -> BxTxC
        feat = self.act(feat)
        return feat


class ConvBlock(nn.Module):
    """Convolution block as described in the NPC paper."""

    def __init__(
        self, input_size, hidden_size, residual, dropout, batch_norm, activate
    ):
        super(ConvBlock, self).__init__()
        self.residual = residual
        if activate == "relu":
            self.act = nn.ReLU()
        elif activate == "tanh":
            self.act = nn.Tanh()
        else:
            raise NotImplementedError
        self.conv = nn.Conv1d(
            input_size, hidden_size, kernel_size=3, stride=1, padding=1
        )
        self.linear = nn.Conv1d(
            hidden_size, hidden_size, kernel_size=1, stride=1, padding=0
        )
        self.batch_norm = batch_norm
        if batch_norm:
            self.bn1 = nn.BatchNorm1d(hidden_size)
            self.bn2 = nn.BatchNorm1d(hidden_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, feat):
        res = feat
        out = self.conv(feat)
        if self.batch_norm:
            out = self.bn1(out)
        out = self.act(out)
        out = self.linear(out)
        if self.batch_norm:
            out = self.bn2(out)
        out = self.dropout(out)
        if self.residual:
            out = out + res
        return self.act(out)


class CnnNpc(nn.Module):
    """The NPC model with stacked ConvBlocks & masked ConvBlocks."""

    def __init__(
        self,
        input_size,
        hidden_size,
        n_blocks,
        dropout,
        residual,
        kernel_size,
        mask_size,
        vq=None,
        batch_norm=True,
        activate="relu",
        disable_cross_layer=False,
        dim_bottleneck=None,
    ):
        super(CnnNpc, self).__init__()
        # Setup
        assert kernel_size % 2 == 1, "Kernel size must be an odd number"
        assert mask_size % 2 == 1, "Mask size must be an odd number"
        assert n_blocks >= 1, "At least 1 block is needed"
        self.code_dim = hidden_size
        self.n_blocks = n_blocks
        self.input_mask_size = mask_size
        self.kernel_size = kernel_size
        self.disable_cross_layer = disable_cross_layer
        self.apply_vq = vq is not None
        self.apply_ae = dim_bottleneck is not None
        if self.apply_ae:
            assert not self.apply_vq
            self.dim_bottleneck = dim_bottleneck

        # Build blocks
        self.blocks, self.masked_convs = [], []
        cur_mask_size = mask_size
        for i in range(n_blocks):
            h_dim = input_size if i == 0 else hidden_size
            res = False if i == 0 else residual
            # ConvBlock
            self.blocks.append(
                ConvBlock(h_dim, hidden_size, res, dropout, batch_norm, activate)
            )
            # Masked ConvBlock on each (or last only) layer; the mask grows by 2
            # per block since each ConvBlock widens the receptive field by one
            # frame on each side
            cur_mask_size = cur_mask_size + 2
            if self.disable_cross_layer and (i != (n_blocks - 1)):
                self.masked_convs.append(None)
            else:
                self.masked_convs.append(
                    MaskConvBlock(hidden_size, hidden_size, kernel_size, cur_mask_size)
                )
        self.blocks = nn.ModuleList(self.blocks)
        self.masked_convs = nn.ModuleList(self.masked_convs)

        # Create N-group VQ
        if self.apply_vq:
            self.vq_layers = []
            vq_config = copy.deepcopy(vq)
            codebook_size = vq_config.pop("codebook_size")
            self.vq_code_dims = vq_config.pop("code_dim")
            assert len(self.vq_code_dims) == len(codebook_size)
            assert sum(self.vq_code_dims) == hidden_size
            for cs, cd in zip(codebook_size, self.vq_code_dims):
                self.vq_layers.append(
                    VqApcLayer(
                        input_size=cd, code_dim=cd, codebook_size=cs, **vq_config
                    )
                )
            self.vq_layers = nn.ModuleList(self.vq_layers)

        # Back to spectrogram
        if self.apply_ae:
            self.ae_bottleneck = nn.Linear(hidden_size, self.dim_bottleneck, bias=False)
            self.postnet = nn.Linear(self.dim_bottleneck, input_size)
        else:
            self.postnet = nn.Linear(hidden_size, input_size)

    def create_msg(self):
        msg_list = []
        msg_list.append(
            "Model spec.| Method = NPC\t| # of Blocks = {}\t".format(self.n_blocks)
        )
        msg_list.append(
            "           | Desired input mask size = {}".format(self.input_mask_size)
        )
        msg_list.append(
            "           | Receptive field size = {}".format(
                self.kernel_size + 2 * self.n_blocks
            )
        )
        return msg_list

    def report_ppx(self):
        """Returns the perplexity of the VQ distribution."""
        if self.apply_vq:
            # ToDo: support more than 2 groups
            # Pad with None so a single-group VQ still unpacks to two values
            rt = [vq_layer.report_ppx() for vq_layer in self.vq_layers] + [None]
            return rt[0], rt[1]
        else:
            return None, None

    def report_usg(self):
        """Returns the usage of the VQ codebook."""
        if self.apply_vq:
            # ToDo: support more than 2 groups
            rt = [vq_layer.report_usg() for vq_layer in self.vq_layers] + [None]
            return rt[0], rt[1]
        else:
            return None, None

    def get_unmasked_feat(self, sp_seq, n_layer):
        """Returns the unmasked feature from the n-th ConvBlock."""
        unmasked_feat = sp_seq.permute(0, 2, 1)  # BxTxC -> BxCxT
        for i in range(self.n_blocks):
            unmasked_feat = self.blocks[i](unmasked_feat)
            if i == n_layer:
                unmasked_feat = unmasked_feat.permute(0, 2, 1)  # BxCxT -> BxTxC
                break
        return unmasked_feat

    def forward(self, sp_seq, testing=False):
        # BxTxC -> BxCxT (reversed in masked ConvBlock)
        unmasked_feat = sp_seq.permute(0, 2, 1)

        # Forward through each layer
        for i in range(self.n_blocks):
            unmasked_feat = self.blocks[i](unmasked_feat)
            if self.disable_cross_layer:
                # Last-layer masked feature only
                if i == (self.n_blocks - 1):
                    feat = self.masked_convs[i](unmasked_feat)
            else:
                # Masked feature aggregation
                masked_feat = self.masked_convs[i](unmasked_feat)
                if i == 0:
                    feat = masked_feat
                else:
                    feat = feat + masked_feat

        # Apply bottleneck and predict spectrogram
        if self.apply_vq:
            q_feat = []
            offset = 0
            for vq_layer, cd in zip(self.vq_layers, self.vq_code_dims):
                q_f = vq_layer(feat[:, :, offset : offset + cd], testing).output
                q_feat.append(q_f)
                offset += cd
            q_feat = torch.cat(q_feat, dim=-1)
            pred = self.postnet(q_feat)
        elif self.apply_ae:
            feat = self.ae_bottleneck(feat)
            pred = self.postnet(feat)
        else:
            pred = self.postnet(feat)

        return Output(hidden_states=feat, prediction=pred)
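

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module): builds a small
# CnnNpc and runs one forward pass on random features. All hyperparameters
# below are illustrative, not the ones used in the NPC paper; the VQ and AE
# bottlenecks are left disabled since their configs depend on VqApcLayer's
# remaining keyword arguments. Assumes s3prl's Output supports attribute
# access on its fields.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    model = CnnNpc(
        input_size=80,  # e.g. an 80-dim log Mel spectrogram
        hidden_size=512,
        n_blocks=4,
        dropout=0.1,
        residual=True,
        kernel_size=15,  # must be odd
        mask_size=5,  # must be odd; grows by 2 per block, so it must stay < kernel_size
    )
    for line in model.create_msg():
        print(line)

    sp_seq = torch.randn(2, 100, 80)  # (batch, time, feature)
    output = model(sp_seq, testing=True)
    print(output.hidden_states.shape)  # expected: (2, 100, 512)
    print(output.prediction.shape)  # expected: (2, 100, 80)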