# coding=utf-8 # Copyright 2023 The Kakao Enterprise Authors and the HuggingFace Inc. team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """PyTorch Bert VITS2 model.""" import math from dataclasses import dataclass from typing import Any, Optional, Tuple, Union, List import numpy as np import torch import torch.utils.checkpoint from torch import nn from transformers.activations import ACT2FN from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask from transformers.modeling_outputs import ( BaseModelOutput, ModelOutput, ) from transformers.models.bert.modeling_bert import BertModel from transformers.modeling_utils import PreTrainedModel from transformers.utils import logging logger = logging.get_logger(__name__) @dataclass class BertVits2ModelOutput(ModelOutput): """ Describes the outputs for the VITS model, with potential hidden states and attentions. Args: waveform (`torch.FloatTensor` of shape `(batch_size, sequence_length)`): The final audio waveform predicted by the model. sequence_lengths (`torch.FloatTensor` of shape `(batch_size,)`): The length in samples of each element in the `waveform` batch. spectrogram (`torch.FloatTensor` of shape `(batch_size, sequence_length, num_bins)`): The log-mel spectrogram predicted at the output of the flow model. This spectrogram is passed to the Hi-Fi GAN decoder model to obtain the final audio waveform. hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, sequence_length)`. Attention weights after the attention softmax, used to compute the weighted average in the self-attention heads. """ waveform: torch.FloatTensor = None sequence_lengths: torch.FloatTensor = None spectrogram: Optional[Tuple[torch.FloatTensor]] = None hidden_states: Optional[Tuple[torch.FloatTensor]] = None attentions: Optional[Tuple[torch.FloatTensor]] = None @dataclass class BertVits2TextEncoderOutput(ModelOutput): """ Describes the outputs for the VITS text encoder model, with potential hidden states and attentions. Args: last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): Sequence of hidden-states at the output of the last layer of the model. prior_means (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): The predicted mean values of the prior distribution for the latent text variables. prior_log_variances (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): The predicted log-variance values of the prior distribution for the latent text variables. hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, sequence_length)`. Attention weights after the attention softmax, used to compute the weighted average in the self-attention heads. """ last_hidden_state: torch.FloatTensor = None prior_means: torch.FloatTensor = None prior_log_variances: torch.FloatTensor = None hidden_states: Optional[Tuple[torch.FloatTensor]] = None attentions: Optional[Tuple[torch.FloatTensor]] = None @torch.jit.script def fused_add_tanh_sigmoid_multiply(input_a, input_b, num_channels): in_act = input_a + input_b t_act = torch.tanh(in_act[:, :num_channels, :]) s_act = torch.sigmoid(in_act[:, num_channels:, :]) acts = t_act * s_act return acts def _unconstrained_rational_quadratic_spline( inputs, unnormalized_widths, unnormalized_heights, unnormalized_derivatives, reverse=False, tail_bound=5.0, min_bin_width=1e-3, min_bin_height=1e-3, min_derivative=1e-3, ): """ This transformation represents a monotonically increasing piecewise rational quadratic function. Outside of the `tail_bound`, the transform behaves as an identity function. Args: inputs (`torch.FloatTensor` of shape `(batch_size, channels, seq_len)`: Second half of the hidden-states input to the Vits convolutional flow module. unnormalized_widths (`torch.FloatTensor` of shape `(batch_size, channels, seq_len, duration_predictor_flow_bins)`): First `duration_predictor_flow_bins` of the hidden-states from the output of the convolution projection layer in the convolutional flow module unnormalized_heights (`torch.FloatTensor` of shape `(batch_size, channels, seq_len, duration_predictor_flow_bins)`): Second `duration_predictor_flow_bins` of the hidden-states from the output of the convolution projection layer in the convolutional flow module unnormalized_derivatives (`torch.FloatTensor` of shape `(batch_size, channels, seq_len, duration_predictor_flow_bins)`): Third `duration_predictor_flow_bins` of the hidden-states from the output of the convolution projection layer in the convolutional flow module reverse (`bool`, *optional*, defaults to `False`): Whether the model is being run in reverse mode. tail_bound (`float`, *optional* defaults to 5): Upper and lower limit bound for the rational quadratic function. Outside of this `tail_bound`, the transform behaves as an identity function. min_bin_width (`float`, *optional*, defaults to 1e-3): Minimum bin value across the width dimension for the piecewise rational quadratic function. min_bin_height (`float`, *optional*, defaults to 1e-3): Minimum bin value across the height dimension for the piecewise rational quadratic function. min_derivative (`float`, *optional*, defaults to 1e-3): Minimum bin value across the derivatives for the piecewise rational quadratic function. Returns: outputs (`torch.FloatTensor` of shape `(batch_size, channels, seq_len)`: Hidden-states as transformed by the piecewise rational quadratic function with the `tail_bound` limits applied. log_abs_det (`torch.FloatTensor` of shape `(batch_size, channels, seq_len)`: Logarithm of the absolute value of the determinants corresponding to the `outputs` with the `tail_bound` limits applied. """ inside_interval_mask = (inputs >= -tail_bound) & (inputs <= tail_bound) outside_interval_mask = ~inside_interval_mask outputs = torch.zeros_like(inputs) log_abs_det = torch.zeros_like(inputs) constant = np.log(np.exp(1 - min_derivative) - 1) unnormalized_derivatives = nn.functional.pad(unnormalized_derivatives, pad=(1, 1)) unnormalized_derivatives[..., 0] = constant unnormalized_derivatives[..., -1] = constant outputs[outside_interval_mask] = inputs[outside_interval_mask] log_abs_det[outside_interval_mask] = 0.0 outputs[inside_interval_mask], log_abs_det[inside_interval_mask] = _rational_quadratic_spline( inputs=inputs[inside_interval_mask], unnormalized_widths=unnormalized_widths[inside_interval_mask, :], unnormalized_heights=unnormalized_heights[inside_interval_mask, :], unnormalized_derivatives=unnormalized_derivatives[inside_interval_mask, :], reverse=reverse, tail_bound=tail_bound, min_bin_width=min_bin_width, min_bin_height=min_bin_height, min_derivative=min_derivative, ) return outputs, log_abs_det def _rational_quadratic_spline( inputs, unnormalized_widths, unnormalized_heights, unnormalized_derivatives, reverse, tail_bound, min_bin_width, min_bin_height, min_derivative, ): """ This transformation represents a monotonically increasing piecewise rational quadratic function. Unlike the function `_unconstrained_rational_quadratic_spline`, the function behaves the same across the `tail_bound`. Args: inputs (`torch.FloatTensor` of shape `(batch_size, channels, seq_len)`: Second half of the hidden-states input to the Vits convolutional flow module. unnormalized_widths (`torch.FloatTensor` of shape `(batch_size, channels, seq_len, duration_predictor_flow_bins)`): First `duration_predictor_flow_bins` of the hidden-states from the output of the convolution projection layer in the convolutional flow module unnormalized_heights (`torch.FloatTensor` of shape `(batch_size, channels, seq_len, duration_predictor_flow_bins)`): Second `duration_predictor_flow_bins` of the hidden-states from the output of the convolution projection layer in the convolutional flow module unnormalized_derivatives (`torch.FloatTensor` of shape `(batch_size, channels, seq_len, duration_predictor_flow_bins)`): Third `duration_predictor_flow_bins` of the hidden-states from the output of the convolution projection layer in the convolutional flow module reverse (`bool`): Whether the model is being run in reverse mode. tail_bound (`float`): Upper and lower limit bound for the rational quadratic function. Outside of this `tail_bound`, the transform behaves as an identity function. min_bin_width (`float`): Minimum bin value across the width dimension for the piecewise rational quadratic function. min_bin_height (`float`): Minimum bin value across the height dimension for the piecewise rational quadratic function. min_derivative (`float`): Minimum bin value across the derivatives for the piecewise rational quadratic function. Returns: outputs (`torch.FloatTensor` of shape `(batch_size, channels, seq_len)`: Hidden-states as transformed by the piecewise rational quadratic function. log_abs_det (`torch.FloatTensor` of shape `(batch_size, channels, seq_len)`: Logarithm of the absolute value of the determinants corresponding to the `outputs`. """ upper_bound = tail_bound lower_bound = -tail_bound if torch.min(inputs) < lower_bound or torch.max(inputs) > upper_bound: raise ValueError("Input to a transform is not within its domain") num_bins = unnormalized_widths.shape[-1] if min_bin_width * num_bins > 1.0: raise ValueError(f"Minimal bin width {min_bin_width} too large for the number of bins {num_bins}") if min_bin_height * num_bins > 1.0: raise ValueError(f"Minimal bin height {min_bin_height} too large for the number of bins {num_bins}") widths = nn.functional.softmax(unnormalized_widths, dim=-1) widths = min_bin_width + (1 - min_bin_width * num_bins) * widths cumwidths = torch.cumsum(widths, dim=-1) cumwidths = nn.functional.pad(cumwidths, pad=(1, 0), mode="constant", value=0.0) cumwidths = (upper_bound - lower_bound) * cumwidths + lower_bound cumwidths[..., 0] = lower_bound cumwidths[..., -1] = upper_bound widths = cumwidths[..., 1:] - cumwidths[..., :-1] derivatives = min_derivative + nn.functional.softplus(unnormalized_derivatives) heights = nn.functional.softmax(unnormalized_heights, dim=-1) heights = min_bin_height + (1 - min_bin_height * num_bins) * heights cumheights = torch.cumsum(heights, dim=-1) cumheights = nn.functional.pad(cumheights, pad=(1, 0), mode="constant", value=0.0) cumheights = (upper_bound - lower_bound) * cumheights + lower_bound cumheights[..., 0] = lower_bound cumheights[..., -1] = upper_bound heights = cumheights[..., 1:] - cumheights[..., :-1] bin_locations = cumheights if reverse else cumwidths bin_locations[..., -1] += 1e-6 bin_idx = torch.sum(inputs[..., None] >= bin_locations, dim=-1) - 1 bin_idx = bin_idx[..., None] input_cumwidths = cumwidths.gather(-1, bin_idx)[..., 0] input_bin_widths = widths.gather(-1, bin_idx)[..., 0] input_cumheights = cumheights.gather(-1, bin_idx)[..., 0] delta = heights / widths input_delta = delta.gather(-1, bin_idx)[..., 0] input_derivatives = derivatives.gather(-1, bin_idx)[..., 0] input_derivatives_plus_one = derivatives[..., 1:].gather(-1, bin_idx)[..., 0] input_heights = heights.gather(-1, bin_idx)[..., 0] intermediate1 = input_derivatives + input_derivatives_plus_one - 2 * input_delta if not reverse: theta = (inputs - input_cumwidths) / input_bin_widths theta_one_minus_theta = theta * (1 - theta) numerator = input_heights * (input_delta * theta.pow(2) + input_derivatives * theta_one_minus_theta) denominator = input_delta + intermediate1 * theta_one_minus_theta outputs = input_cumheights + numerator / denominator derivative_numerator = input_delta.pow(2) * ( input_derivatives_plus_one * theta.pow(2) + 2 * input_delta * theta_one_minus_theta + input_derivatives * (1 - theta).pow(2) ) log_abs_det = torch.log(derivative_numerator) - 2 * torch.log(denominator) return outputs, log_abs_det else: # find the roots of a quadratic equation intermediate2 = inputs - input_cumheights intermediate3 = intermediate2 * intermediate1 a = input_heights * (input_delta - input_derivatives) + intermediate3 b = input_heights * input_derivatives - intermediate3 c = -input_delta * intermediate2 discriminant = b.pow(2) - 4 * a * c if not (discriminant >= 0).all(): raise RuntimeError(f"invalid discriminant {discriminant}") root = (2 * c) / (-b - torch.sqrt(discriminant)) outputs = root * input_bin_widths + input_cumwidths theta_one_minus_theta = root * (1 - root) denominator = input_delta + intermediate1 * theta_one_minus_theta derivative_numerator = input_delta.pow(2) * ( input_derivatives_plus_one * root.pow(2) + 2 * input_delta * theta_one_minus_theta + input_derivatives * (1 - root).pow(2) ) log_abs_det = torch.log(derivative_numerator) - 2 * torch.log(denominator) return outputs, -log_abs_det class BertVits2WaveNet(torch.nn.Module): def __init__(self, config, num_layers: int): super().__init__() self.hidden_size = config.hidden_size self.num_layers = num_layers self.in_layers = torch.nn.ModuleList() self.res_skip_layers = torch.nn.ModuleList() self.dropout = nn.Dropout(config.wavenet_dropout) # if hasattr(nn.utils.parametrizations, "weight_norm"): # weight_norm = nn.utils.parametrizations.weight_norm # else: weight_norm = nn.utils.weight_norm if config.speaker_embedding_size != 0: cond_layer = torch.nn.Conv1d(config.speaker_embedding_size, 2 * config.hidden_size * num_layers, 1) self.cond_layer = weight_norm(cond_layer, name="weight") for i in range(num_layers): dilation = config.wavenet_dilation_rate**i padding = (config.wavenet_kernel_size * dilation - dilation) // 2 in_layer = torch.nn.Conv1d( in_channels=config.hidden_size, out_channels=2 * config.hidden_size, kernel_size=config.wavenet_kernel_size, dilation=dilation, padding=padding, ) in_layer = weight_norm(in_layer, name="weight") self.in_layers.append(in_layer) # last one is not necessary if i < num_layers - 1: res_skip_channels = 2 * config.hidden_size else: res_skip_channels = config.hidden_size res_skip_layer = torch.nn.Conv1d(config.hidden_size, res_skip_channels, 1) res_skip_layer = weight_norm(res_skip_layer, name="weight") self.res_skip_layers.append(res_skip_layer) def forward(self, inputs, padding_mask, global_conditioning=None): outputs = torch.zeros_like(inputs) num_channels_tensor = torch.IntTensor([self.hidden_size]) if global_conditioning is not None: global_conditioning = self.cond_layer(global_conditioning) for i in range(self.num_layers): hidden_states = self.in_layers[i](inputs) if global_conditioning is not None: cond_offset = i * 2 * self.hidden_size global_states = global_conditioning[:, cond_offset : cond_offset + 2 * self.hidden_size, :] else: global_states = torch.zeros_like(hidden_states) acts = fused_add_tanh_sigmoid_multiply(hidden_states, global_states, num_channels_tensor[0]) acts = self.dropout(acts) res_skip_acts = self.res_skip_layers[i](acts) if i < self.num_layers - 1: res_acts = res_skip_acts[:, : self.hidden_size, :] inputs = (inputs + res_acts) * padding_mask outputs = outputs + res_skip_acts[:, self.hidden_size :, :] else: outputs = outputs + res_skip_acts return outputs * padding_mask def remove_weight_norm(self): if self.speaker_embedding_size != 0: torch.nn.utils.remove_weight_norm(self.cond_layer) for layer in self.in_layers: torch.nn.utils.remove_weight_norm(layer) for layer in self.res_skip_layers: torch.nn.utils.remove_weight_norm(layer) class BertVits2PosteriorEncoder(nn.Module): def __init__(self, config): super().__init__() self.out_channels = config.flow_size self.conv_pre = nn.Conv1d(config.spectrogram_bins, config.hidden_size, 1) self.wavenet = BertVits2WaveNet(config, num_layers=config.posterior_encoder_num_wavenet_layers) self.conv_proj = nn.Conv1d(config.hidden_size, self.out_channels * 2, 1) def forward(self, inputs, padding_mask, global_conditioning=None): inputs = self.conv_pre(inputs) * padding_mask inputs = self.wavenet(inputs, padding_mask, global_conditioning) stats = self.conv_proj(inputs) * padding_mask mean, log_stddev = torch.split(stats, self.out_channels, dim=1) sampled = (mean + torch.randn_like(mean) * torch.exp(log_stddev)) * padding_mask return sampled, mean, log_stddev # Copied from transformers.models.speecht5.modeling_speecht5.HifiGanResidualBlock class HifiGanResidualBlock(nn.Module): def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5), leaky_relu_slope=0.1): super().__init__() self.leaky_relu_slope = leaky_relu_slope self.convs1 = nn.ModuleList( [ nn.Conv1d( channels, channels, kernel_size, stride=1, dilation=dilation[i], padding=self.get_padding(kernel_size, dilation[i]), ) for i in range(len(dilation)) ] ) self.convs2 = nn.ModuleList( [ nn.Conv1d( channels, channels, kernel_size, stride=1, dilation=1, padding=self.get_padding(kernel_size, 1), ) for _ in range(len(dilation)) ] ) def get_padding(self, kernel_size, dilation=1): return (kernel_size * dilation - dilation) // 2 def apply_weight_norm(self): for layer in self.convs1: nn.utils.weight_norm(layer) for layer in self.convs2: nn.utils.weight_norm(layer) def remove_weight_norm(self): for layer in self.convs1: nn.utils.remove_weight_norm(layer) for layer in self.convs2: nn.utils.remove_weight_norm(layer) def forward(self, hidden_states): for conv1, conv2 in zip(self.convs1, self.convs2): residual = hidden_states hidden_states = nn.functional.leaky_relu(hidden_states, self.leaky_relu_slope) hidden_states = conv1(hidden_states) hidden_states = nn.functional.leaky_relu(hidden_states, self.leaky_relu_slope) hidden_states = conv2(hidden_states) hidden_states = hidden_states + residual return hidden_states class BertVits2HifiGan(nn.Module): def __init__(self, config): super().__init__() self.config = config self.num_kernels = len(config.resblock_kernel_sizes) self.num_upsamples = len(config.upsample_rates) self.conv_pre = nn.Conv1d( config.flow_size, config.upsample_initial_channel, kernel_size=7, stride=1, padding=3, ) self.upsampler = nn.ModuleList() for i, (upsample_rate, kernel_size) in enumerate(zip(config.upsample_rates, config.upsample_kernel_sizes)): self.upsampler.append( nn.ConvTranspose1d( config.upsample_initial_channel // (2**i), config.upsample_initial_channel // (2 ** (i + 1)), kernel_size=kernel_size, stride=upsample_rate, padding=(kernel_size - upsample_rate) // 2, ) ) self.resblocks = nn.ModuleList() for i in range(len(self.upsampler)): channels = config.upsample_initial_channel // (2 ** (i + 1)) for kernel_size, dilation in zip(config.resblock_kernel_sizes, config.resblock_dilation_sizes): self.resblocks.append(HifiGanResidualBlock(channels, kernel_size, dilation, config.leaky_relu_slope)) self.conv_post = nn.Conv1d(channels, 1, kernel_size=7, stride=1, padding=3, bias=False) if config.speaker_embedding_size != 0: self.cond = nn.Conv1d(config.speaker_embedding_size, config.upsample_initial_channel, 1) def apply_weight_norm(self): for layer in self.upsampler: nn.utils.weight_norm(layer) for layer in self.resblocks: layer.apply_weight_norm() def remove_weight_norm(self): for layer in self.upsampler: nn.utils.remove_weight_norm(layer) for layer in self.resblocks: layer.remove_weight_norm() def forward( self, spectrogram: torch.FloatTensor, global_conditioning: Optional[torch.FloatTensor] = None ) -> torch.FloatTensor: r""" Converts a spectrogram into a speech waveform. Args: spectrogram (`torch.FloatTensor` of shape `(batch_size, config.spectrogram_bins, sequence_length)`): Tensor containing the spectrograms. global_conditioning (`torch.FloatTensor` of shape `(batch_size, config.speaker_embedding_size, 1)`, *optional*): Tensor containing speaker embeddings, for multispeaker models. Returns: `torch.FloatTensor`: Tensor of shape shape `(batch_size, 1, num_frames)` containing the speech waveform. """ hidden_states = self.conv_pre(spectrogram) if global_conditioning is not None: hidden_states = hidden_states + self.cond(global_conditioning) for i in range(self.num_upsamples): hidden_states = nn.functional.leaky_relu(hidden_states, self.config.leaky_relu_slope) hidden_states = self.upsampler[i](hidden_states) res_state = self.resblocks[i * self.num_kernels](hidden_states) for j in range(1, self.num_kernels): res_state += self.resblocks[i * self.num_kernels + j](hidden_states) hidden_states = res_state / self.num_kernels hidden_states = nn.functional.leaky_relu(hidden_states) hidden_states = self.conv_post(hidden_states) waveform = torch.tanh(hidden_states) return waveform class BertVits2ResidualCouplingLayer(nn.Module): def __init__(self, config): super().__init__() self.half_channels = config.flow_size // 2 self.conv_pre = nn.Conv1d(self.half_channels, config.hidden_size, 1) self.wavenet = BertVits2WaveNet(config, num_layers=config.prior_encoder_num_wavenet_layers) self.conv_post = nn.Conv1d(config.hidden_size, self.half_channels, 1) def forward(self, inputs, padding_mask, global_conditioning=None): first_half, second_half = torch.split(inputs, [self.half_channels] * 2, dim=1) hidden_states = self.conv_pre(first_half) * padding_mask hidden_states = self.wavenet(hidden_states, padding_mask, global_conditioning) mean = self.conv_post(hidden_states) * padding_mask log_stddev = torch.zeros_like(mean) second_half = mean + second_half * torch.exp(log_stddev) * padding_mask outputs = torch.cat([first_half, second_half], dim=1) log_determinant = torch.sum(log_stddev, [1, 2]) return outputs, log_determinant class BertVits2ResidualCouplingBlock(nn.Module): def __init__(self, config): super().__init__() self.flows = nn.ModuleList() for _ in range(config.prior_encoder_num_flows): self.flows.append(BertVits2ResidualCouplingLayer(config)) def forward(self, inputs, padding_mask, global_conditioning=None): x = inputs for flow in self.flows: x, _ = flow(x, padding_mask, global_conditioning) x = torch.flip(x, [1]) return x class BertVits2TransformerCouplingLayer(nn.Module): def __init__(self, config): super().__init__() self.half_channels = config.flow_size // 2 self.conv_pre = nn.Conv1d(self.half_channels, config.hidden_size, 1) self.encoder = BertVits2Encoder( config, kernel_size=5, n_layers=config.prior_encoder_num_flows_layers, ) self.conv_post = nn.Conv1d(config.hidden_size, self.half_channels, 1) def forward( self, inputs, padding_mask, global_conditioning=None, reverse=False, return_dict=True, ): inputs1, inputs2 = torch.split(inputs, [self.half_channels] * 2, 1) hidden_state = self.conv_pre(inputs1) * padding_mask hidden_state = self.encoder( hidden_states=hidden_state.transpose(1, 2), padding_mask=padding_mask.transpose(1, 2), global_conditioning=global_conditioning, return_dict=return_dict ) hidden_state = hidden_state.last_hidden_state if return_dict else hidden_state[0] hidden_state = hidden_state.transpose(1, 2) hidden_state = self.conv_post(hidden_state) * padding_mask logs = torch.zeros_like(hidden_state) if not reverse: inputs1 = hidden_state + inputs1 * torch.exp(logs) * padding_mask x = torch.cat([inputs1, inputs2], 1) logdet = torch.sum(logs, [1, 2]) return x, logdet else: inputs2 = (inputs2 - hidden_state) * torch.exp(-logs) * padding_mask x = torch.cat([inputs1, inputs2], 1) return x, None class BertVits2TransformerCouplingBlock(nn.Module): def __init__(self, config): super().__init__() self.flows = nn.ModuleList([ BertVits2TransformerCouplingLayer(config) for _ in range(config.prior_encoder_num_flows) ]) def forward(self, inputs, padding_mask, global_conditioning=None, reverse=False): if not reverse: for flow in self.flows: inputs, _ = flow(inputs, padding_mask, global_conditioning, reverse=False) inputs = torch.flip(inputs, [1]) else: for flow in reversed(self.flows): inputs = torch.flip(inputs, [1]) inputs, _ = flow(inputs, padding_mask, global_conditioning, reverse=True) return inputs class BertVits2DilatedDepthSeparableConv(nn.Module): def __init__(self, config, dropout_rate=0.0): super().__init__() kernel_size = config.duration_predictor_kernel_size channels = config.hidden_size self.num_layers = config.depth_separable_num_layers self.dropout = nn.Dropout(dropout_rate) self.convs_dilated = nn.ModuleList() self.convs_pointwise = nn.ModuleList() self.norms_1 = nn.ModuleList() self.norms_2 = nn.ModuleList() for i in range(self.num_layers): dilation = kernel_size**i padding = (kernel_size * dilation - dilation) // 2 self.convs_dilated.append( nn.Conv1d( in_channels=channels, out_channels=channels, kernel_size=kernel_size, groups=channels, dilation=dilation, padding=padding, ) ) self.convs_pointwise.append(nn.Conv1d(channels, channels, 1)) self.norms_1.append(nn.LayerNorm(channels)) self.norms_2.append(nn.LayerNorm(channels)) def forward(self, inputs, padding_mask, global_conditioning=None): if global_conditioning is not None: inputs = inputs + global_conditioning for i in range(self.num_layers): hidden_states = self.convs_dilated[i](inputs * padding_mask) hidden_states = self.norms_1[i](hidden_states.transpose(1, -1)).transpose(1, -1) hidden_states = nn.functional.gelu(hidden_states) hidden_states = self.convs_pointwise[i](hidden_states) hidden_states = self.norms_2[i](hidden_states.transpose(1, -1)).transpose(1, -1) hidden_states = nn.functional.gelu(hidden_states) hidden_states = self.dropout(hidden_states) inputs = inputs + hidden_states return inputs * padding_mask class BertVits2ConvFlow(nn.Module): def __init__(self, config): super().__init__() self.filter_channels = config.hidden_size self.half_channels = config.depth_separable_channels // 2 self.num_bins = config.duration_predictor_flow_bins self.tail_bound = config.duration_predictor_tail_bound self.conv_pre = nn.Conv1d(self.half_channels, self.filter_channels, 1) self.conv_dds = BertVits2DilatedDepthSeparableConv(config) self.conv_proj = nn.Conv1d(self.filter_channels, self.half_channels * (self.num_bins * 3 - 1), 1) def forward(self, inputs, padding_mask, global_conditioning=None, reverse=False): first_half, second_half = torch.split(inputs, [self.half_channels] * 2, dim=1) hidden_states = self.conv_pre(first_half) hidden_states = self.conv_dds(hidden_states, padding_mask, global_conditioning) hidden_states = self.conv_proj(hidden_states) * padding_mask batch_size, channels, length = first_half.shape hidden_states = hidden_states.reshape(batch_size, channels, -1, length).permute(0, 1, 3, 2) unnormalized_widths = hidden_states[..., : self.num_bins] / math.sqrt(self.filter_channels) unnormalized_heights = hidden_states[..., self.num_bins : 2 * self.num_bins] / math.sqrt(self.filter_channels) unnormalized_derivatives = hidden_states[..., 2 * self.num_bins :] second_half, log_abs_det = _unconstrained_rational_quadratic_spline( second_half, unnormalized_widths, unnormalized_heights, unnormalized_derivatives, reverse=reverse, tail_bound=self.tail_bound, ) outputs = torch.cat([first_half, second_half], dim=1) * padding_mask if not reverse: log_determinant = torch.sum(log_abs_det * padding_mask, [1, 2]) return outputs, log_determinant else: return outputs, None class BertVits2ElementwiseAffine(nn.Module): def __init__(self, config): super().__init__() self.channels = config.depth_separable_channels self.translate = nn.Parameter(torch.zeros(self.channels, 1)) self.log_scale = nn.Parameter(torch.zeros(self.channels, 1)) def forward(self, inputs, padding_mask, global_conditioning=None, reverse=False): if not reverse: outputs = self.translate + torch.exp(self.log_scale) * inputs outputs = outputs * padding_mask log_determinant = torch.sum(self.log_scale * padding_mask, [1, 2]) return outputs, log_determinant else: outputs = (inputs - self.translate) * torch.exp(-self.log_scale) * padding_mask return outputs, None class BertVits2StochasticDurationPredictor(nn.Module): def __init__(self, config): super().__init__() embed_dim = config.speaker_embedding_size filter_channels = config.hidden_size self.conv_pre = nn.Conv1d(filter_channels, filter_channels, 1) self.conv_proj = nn.Conv1d(filter_channels, filter_channels, 1) self.conv_dds = BertVits2DilatedDepthSeparableConv( config, dropout_rate=config.duration_predictor_dropout, ) if embed_dim != 0: self.cond = nn.Conv1d(embed_dim, filter_channels, 1) self.flows = nn.ModuleList() self.flows.append(BertVits2ElementwiseAffine(config)) for _ in range(config.duration_predictor_num_flows): self.flows.append(BertVits2ConvFlow(config)) self.post_conv_pre = nn.Conv1d(1, filter_channels, 1) self.post_conv_proj = nn.Conv1d(filter_channels, filter_channels, 1) self.post_conv_dds = BertVits2DilatedDepthSeparableConv( config, dropout_rate=config.duration_predictor_dropout, ) self.post_flows = nn.ModuleList() self.post_flows.append(BertVits2ElementwiseAffine(config)) for _ in range(config.duration_predictor_num_flows): self.post_flows.append(BertVits2ConvFlow(config)) def forward(self, inputs, padding_mask, global_conditioning=None, durations=None, reverse=False, noise_scale=1.0): inputs = torch.detach(inputs) inputs = self.conv_pre(inputs) if global_conditioning is not None: global_conditioning = torch.detach(global_conditioning) inputs = inputs + self.cond(global_conditioning) inputs = self.conv_dds(inputs, padding_mask) inputs = self.conv_proj(inputs) * padding_mask if not reverse: hidden_states = self.post_conv_pre(durations) hidden_states = self.post_conv_dds(hidden_states, padding_mask) hidden_states = self.post_conv_proj(hidden_states) * padding_mask random_posterior = ( torch.randn(durations.size(0), 2, durations.size(2)).to(device=inputs.device, dtype=inputs.dtype) * padding_mask ) log_determinant_posterior_sum = 0 latents_posterior = random_posterior for flow in self.post_flows: latents_posterior, log_determinant = flow( latents_posterior, padding_mask, global_conditioning=inputs + hidden_states ) latents_posterior = torch.flip(latents_posterior, [1]) log_determinant_posterior_sum += log_determinant first_half, second_half = torch.split(latents_posterior, [1, 1], dim=1) log_determinant_posterior_sum += torch.sum( (nn.functional.logsigmoid(first_half) + nn.functional.logsigmoid(-first_half)) * padding_mask, [1, 2] ) logq = ( torch.sum(-0.5 * (math.log(2 * math.pi) + (random_posterior**2)) * padding_mask, [1, 2]) - log_determinant_posterior_sum ) first_half = (durations - torch.sigmoid(first_half)) * padding_mask first_half = torch.log(torch.clamp_min(first_half, 1e-5)) * padding_mask log_determinant_sum = torch.sum(-first_half, [1, 2]) latents = torch.cat([first_half, second_half], dim=1) for flow in self.flows: latents, log_determinant = flow(latents, padding_mask, global_conditioning=inputs) latents = torch.flip(latents, [1]) log_determinant_sum += log_determinant nll = torch.sum(0.5 * (math.log(2 * math.pi) + (latents**2)) * padding_mask, [1, 2]) - log_determinant_sum return nll + logq else: flows = list(reversed(self.flows)) flows = flows[:-2] + [flows[-1]] # remove a useless vflow latents = ( torch.randn(inputs.size(0), 2, inputs.size(2)).to(device=inputs.device, dtype=inputs.dtype) * noise_scale ) for flow in flows: latents = torch.flip(latents, [1]) latents, _ = flow(latents, padding_mask, global_conditioning=inputs, reverse=True) log_duration, _ = torch.split(latents, [1, 1], dim=1) return log_duration class BertVits2DurationPredictor(nn.Module): def __init__(self, config): super().__init__() kernel_size = config.duration_predictor_kernel_size filter_channels = config.duration_predictor_filter_channels self.dropout = nn.Dropout(config.duration_predictor_dropout) self.conv_1 = nn.Conv1d(config.hidden_size, filter_channels, kernel_size, padding=kernel_size // 2) self.norm_1 = nn.LayerNorm(filter_channels, eps=config.layer_norm_eps) self.conv_2 = nn.Conv1d(filter_channels, filter_channels, kernel_size, padding=kernel_size // 2) self.norm_2 = nn.LayerNorm(filter_channels, eps=config.layer_norm_eps) self.proj = nn.Conv1d(filter_channels, 1, 1) if config.speaker_embedding_size != 0: self.cond = nn.Conv1d(config.speaker_embedding_size, config.hidden_size, 1) def forward(self, inputs, padding_mask, global_conditioning=None): inputs = torch.detach(inputs) if global_conditioning is not None: global_conditioning = torch.detach(global_conditioning) inputs = inputs + self.cond(global_conditioning) inputs = self.conv_1(inputs * padding_mask) inputs = torch.relu(inputs) inputs = self.norm_1(inputs.transpose(1, -1)).transpose(1, -1) inputs = self.dropout(inputs) inputs = self.conv_2(inputs * padding_mask) inputs = torch.relu(inputs) inputs = self.norm_2(inputs.transpose(1, -1)).transpose(1, -1) inputs = self.dropout(inputs) inputs = self.proj(inputs * padding_mask) return inputs * padding_mask class BertVits2Attention(nn.Module): """Multi-headed attention with relative positional representation.""" def __init__(self, config): super().__init__() self.embed_dim = config.hidden_size self.num_heads = config.num_attention_heads self.dropout = config.attention_dropout self.window_size = config.window_size self.head_dim = self.embed_dim // self.num_heads self.scaling = self.head_dim**-0.5 if (self.head_dim * self.num_heads) != self.embed_dim: raise ValueError( f"hidden_size must be divisible by num_attention_heads (got `hidden_size`: {self.embed_dim}" f" and `num_attention_heads`: {self.num_heads})." ) self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.use_bias) self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.use_bias) self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.use_bias) self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.use_bias) nn.init.xavier_uniform_(self.k_proj.weight) nn.init.xavier_uniform_(self.v_proj.weight) nn.init.xavier_uniform_(self.q_proj.weight) if self.window_size: self.emb_rel_k = nn.Parameter(torch.randn(1, self.window_size * 2 + 1, self.head_dim) * self.scaling) self.emb_rel_v = nn.Parameter(torch.randn(1, self.window_size * 2 + 1, self.head_dim) * self.scaling) def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() def forward( self, hidden_states: torch.Tensor, key_value_states: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: """Input shape: Batch x Time x Channel""" # if key_value_states are provided this layer is used as a cross-attention layer # for the decoder bsz, tgt_len, _ = hidden_states.size() # get query proj query_states = self.q_proj(hidden_states) * self.scaling # self_attention key_states = self._shape(self.k_proj(hidden_states), -1, bsz) value_states = self._shape(self.v_proj(hidden_states), -1, bsz) proj_shape = (bsz * self.num_heads, -1, self.head_dim) query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) key_states = key_states.view(*proj_shape) value_states = value_states.view(*proj_shape) src_len = key_states.size(1) attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): raise ValueError( f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" f" {attn_weights.size()}" ) if self.window_size is not None: key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, src_len) relative_logits = torch.matmul(query_states, key_relative_embeddings.transpose(-2, -1)) rel_pos_bias = self._relative_position_to_absolute_position(relative_logits) attn_weights += rel_pos_bias if attention_mask is not None: if attention_mask.size() != (bsz, 1, tgt_len, src_len): raise ValueError( f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" ) attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) attn_weights = nn.functional.softmax(attn_weights, dim=-1) if layer_head_mask is not None: if layer_head_mask.size() != (self.num_heads,): raise ValueError( f"Head mask for a single layer should be of size {(self.num_heads,)}, but is" f" {layer_head_mask.size()}" ) attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len) attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) if output_attentions: # this operation is a bit awkward, but it's required to # make sure that attn_weights keeps its gradient. # In order to do so, attn_weights have to be reshaped # twice and have to be reused in the following attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len) else: attn_weights_reshaped = None attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) attn_output = torch.bmm(attn_probs, value_states) if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim): raise ValueError( f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is" f" {attn_output.size()}" ) if self.window_size is not None: value_relative_embeddings = self._get_relative_embeddings(self.emb_rel_v, src_len) relative_weights = self._absolute_position_to_relative_position(attn_probs) rel_pos_bias = torch.matmul(relative_weights, value_relative_embeddings) attn_output += rel_pos_bias attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) attn_output = attn_output.transpose(1, 2) # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be # partitioned aross GPUs when using tensor-parallelism. attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) attn_output = self.out_proj(attn_output) return attn_output, attn_weights_reshaped def _get_relative_embeddings(self, relative_embeddings, length): pad_length = max(length - (self.window_size + 1), 0) if pad_length > 0: relative_embeddings = nn.functional.pad(relative_embeddings, [0, 0, pad_length, pad_length, 0, 0]) slice_start_position = max((self.window_size + 1) - length, 0) slice_end_position = slice_start_position + 2 * length - 1 return relative_embeddings[:, slice_start_position:slice_end_position] def _relative_position_to_absolute_position(self, x): batch_heads, length, _ = x.size() # Concat columns of pad to shift from relative to absolute indexing. x = nn.functional.pad(x, [0, 1, 0, 0, 0, 0]) # Concat extra elements so to add up to shape (len+1, 2*len-1). x_flat = x.view([batch_heads, length * 2 * length]) x_flat = nn.functional.pad(x_flat, [0, length - 1, 0, 0]) # Reshape and slice out the padded elements. x_final = x_flat.view([batch_heads, length + 1, 2 * length - 1]) x_final = x_final[:, :length, length - 1 :] return x_final def _absolute_position_to_relative_position(self, x): batch_heads, length, _ = x.size() # Pad along column x = nn.functional.pad(x, [0, length - 1, 0, 0, 0, 0]) x_flat = x.view([batch_heads, length * (2 * length - 1)]) # Add 0's in the beginning that will skew the elements after reshape x_flat = nn.functional.pad(x_flat, [length, 0, 0, 0]) x_final = x_flat.view([batch_heads, length, 2 * length])[:, :, 1:] return x_final class BertVits2FeedForward(nn.Module): def __init__(self, config, kernel_size=None): super().__init__() if kernel_size is None: kernel_size = config.ffn_kernel_size self.conv_1 = nn.Conv1d(config.hidden_size, config.ffn_dim, kernel_size) self.conv_2 = nn.Conv1d(config.ffn_dim, config.hidden_size, kernel_size) self.dropout = nn.Dropout(config.activation_dropout) if isinstance(config.hidden_act, str): self.act_fn = ACT2FN[config.hidden_act] else: self.act_fn = config.hidden_act if kernel_size > 1: pad_left = (kernel_size - 1) // 2 pad_right = kernel_size // 2 self.padding = [pad_left, pad_right, 0, 0, 0, 0] else: self.padding = None def forward(self, hidden_states, padding_mask): hidden_states = hidden_states.permute(0, 2, 1) padding_mask = padding_mask.permute(0, 2, 1) hidden_states = hidden_states * padding_mask if self.padding is not None: hidden_states = nn.functional.pad(hidden_states, self.padding) hidden_states = self.conv_1(hidden_states) hidden_states = self.act_fn(hidden_states) hidden_states = self.dropout(hidden_states) hidden_states = hidden_states * padding_mask if self.padding is not None: hidden_states = nn.functional.pad(hidden_states, self.padding) hidden_states = self.conv_2(hidden_states) hidden_states = hidden_states * padding_mask hidden_states = hidden_states.permute(0, 2, 1) return hidden_states class BertVits2EncoderLayer(nn.Module): def __init__(self, config, kernel_size=None): super().__init__() self.attention = BertVits2Attention(config) self.dropout = nn.Dropout(config.hidden_dropout) self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.feed_forward = BertVits2FeedForward(config, kernel_size=kernel_size) self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) def forward( self, hidden_states: torch.Tensor, padding_mask: torch.FloatTensor, attention_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, ): residual = hidden_states hidden_states, attn_weights = self.attention( hidden_states=hidden_states, attention_mask=attention_mask, output_attentions=output_attentions, ) hidden_states = self.dropout(hidden_states) hidden_states = self.layer_norm(residual + hidden_states) residual = hidden_states hidden_states = self.feed_forward(hidden_states, padding_mask) hidden_states = self.dropout(hidden_states) hidden_states = self.final_layer_norm(residual + hidden_states) outputs = (hidden_states,) if output_attentions: outputs += (attn_weights,) return outputs class BertVits2Encoder(nn.Module): def __init__(self, config, kernel_size=None, n_layers=None): super().__init__() self.config = config if n_layers is None: n_layers = config.num_hidden_layers self.speaker_embed_proj = nn.Linear(config.speaker_embedding_size, config.hidden_size) self.layers = nn.ModuleList([BertVits2EncoderLayer(config, kernel_size=kernel_size) for _ in range(n_layers)]) self.gradient_checkpointing = False self.layerdrop = config.layerdrop self.conditioning_layer_index = config.conditioning_layer_index def forward( self, hidden_states: torch.FloatTensor, padding_mask: torch.FloatTensor, attention_mask: Optional[torch.Tensor] = None, global_conditioning: Optional[torch.Tensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, ) -> Union[Tuple, BaseModelOutput]: all_hidden_states = () if output_hidden_states else None all_self_attentions = () if output_attentions else None # expand attention_mask if attention_mask is not None: # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype) hidden_states = hidden_states * padding_mask deepspeed_zero3_is_enabled = is_deepspeed_zero3_enabled() for i, encoder_layer in enumerate(self.layers): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) dropout_probability = np.random.uniform(0, 1) if i == self.conditioning_layer_index and global_conditioning is not None: global_conditioning = self.speaker_embed_proj(global_conditioning.transpose(1, 2)) hidden_states = hidden_states + global_conditioning hidden_states = hidden_states * padding_mask skip_the_layer = self.training and (dropout_probability < self.layerdrop) if not skip_the_layer or deepspeed_zero3_is_enabled: # under deepspeed zero3 all gpus must run in sync if self.gradient_checkpointing and self.training: layer_outputs = self._gradient_checkpointing_func( encoder_layer.__call__, hidden_states, padding_mask, attention_mask, output_attentions, ) else: layer_outputs = encoder_layer( hidden_states, attention_mask=attention_mask, padding_mask=padding_mask, output_attentions=output_attentions, ) hidden_states = layer_outputs[0] if skip_the_layer: layer_outputs = (None, None) if output_attentions: all_self_attentions = all_self_attentions + (layer_outputs[1],) hidden_states = hidden_states * padding_mask if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) if not return_dict: return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None) return BaseModelOutput( last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_self_attentions, ) class BertVits2TextEncoder(nn.Module): """ Transformer encoder that uses relative positional representation instead of absolute positional encoding. """ def __init__(self, config): super().__init__() self.config = config self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, config.pad_token_id) nn.init.normal_(self.embed_tokens.weight, 0.0, config.hidden_size**-0.5) self.embed_tones = nn.Embedding(config.num_tones, config.hidden_size) nn.init.normal_(self.embed_tones.weight, 0.0, config.hidden_size**-0.5) self.embed_languages = nn.Embedding(config.num_languages, config.hidden_size) nn.init.normal_(self.embed_languages.weight, 0.0, config.hidden_size**-0.5) self.bert_projs = nn.ModuleList() for bert in config.bert_configs: self.bert_projs.append(nn.Conv1d(bert.hidden_size, config.hidden_size, 1)) self.encoder = BertVits2Encoder(config) self.project = nn.Conv1d(config.hidden_size, config.flow_size * 2, kernel_size=1) def get_input_embeddings(self): return self.embed_tokens def set_input_embeddings(self, value): self.embed_tokens = value def forward( self, input_ids: torch.Tensor, tone_ids: torch.Tensor, language_ids: torch.Tensor, padding_mask: torch.FloatTensor, attention_mask: Optional[torch.Tensor] = None, bert_embeddings: Optional[List[torch.Tensor]] = None, global_conditioning: Optional[torch.Tensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = True, ) -> Union[Tuple[torch.Tensor], BertVits2TextEncoderOutput]: x = self.embed_tokens(input_ids) x = x + self.embed_tones(tone_ids) x = x + self.embed_languages(language_ids) for project, inputs in zip(self.bert_projs, bert_embeddings): x = x + project(inputs).transpose(1, 2) hidden_states = x * math.sqrt(self.config.hidden_size) encoder_outputs = self.encoder( hidden_states=hidden_states, padding_mask=padding_mask, attention_mask=attention_mask, global_conditioning=global_conditioning, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, ) last_hidden_state = encoder_outputs[0] if not return_dict else encoder_outputs.last_hidden_state stats = self.project(last_hidden_state.transpose(1, 2)).transpose(1, 2) * padding_mask prior_means, prior_log_variances = torch.split(stats, self.config.flow_size, dim=2) if not return_dict: outputs = (last_hidden_state, prior_means, prior_log_variances) + encoder_outputs[1:] return outputs return BertVits2TextEncoderOutput( last_hidden_state=last_hidden_state, prior_means=prior_means, prior_log_variances=prior_log_variances, hidden_states=encoder_outputs.hidden_states, attentions=encoder_outputs.attentions, ) class BertVits2ReferenceEncoder(nn.Module): def __init__(self, config): super().__init__() self.config = config ref_enc_filters = [32, 32, 64, 64, 128, 128] K = len(ref_enc_filters) filters = [1] + ref_enc_filters self.convs = nn.ModuleList([ nn.utils.weight_norm( nn.Conv2d( in_channels=filters[i], out_channels=filters[i + 1], kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), ) ) for i in range(K) ]) out_channels = self.calculate_channels(config.spectrogram_bins, 3, 2, 1, K) self.gru = nn.GRU( input_size=ref_enc_filters[-1] * out_channels, hidden_size=256 // 2, batch_first=True, ) self.proj = nn.Linear(128, self.config.speaker_embedding_size) def forward(self, input_ids, attention_mask): N = input_ids.size(0) out = input_ids.view(N, 1, -1, self.config.spectrogram_bins) for conv in self.convs: out = conv(out) # out = wn(out) out = nn.functional.relu(out) out = out.transpose(1, 2) # [N, Ty//2^K, 128, n_mels//2^K] T = out.size(1) N = out.size(0) out = out.contiguous().view(N, T, -1) # [N, Ty//2^K, 128*n_mels//2^K] self.gru.flatten_parameters() _, out = self.gru(out) # out --- [1, N, 128] return self.proj(out.squeeze(0)) def calculate_channels(self, L, kernel_size, stride, pad, n_convs): for i in range(n_convs): L = (L - kernel_size + 2 * pad) // stride + 1 return L class BertVits2PreTrainedModel(PreTrainedModel): """ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained models. """ base_model_prefix = "vits" main_input_name = "input_ids" supports_gradient_checkpointing = True def _init_weights(self, module): """Initialize the weights""" if isinstance(module, nn.Linear): module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: module.bias.data.zero_() elif isinstance(module, nn.LayerNorm): module.bias.data.zero_() module.weight.data.fill_(1.0) elif isinstance(module, nn.Conv1d): nn.init.kaiming_normal_(module.weight) if module.bias is not None: k = math.sqrt(module.groups / (module.in_channels * module.kernel_size[0])) nn.init.uniform_(module.bias, a=-k, b=k) elif isinstance(module, nn.Embedding): module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() # drop config_class for BertVits2PreTrainedModel del BertVits2PreTrainedModel.config_class class BertVits2Model(BertVits2PreTrainedModel): def __init__(self, config): super().__init__(config) self.config = config self.text_encoder = BertVits2TextEncoder(config) self.decoder = BertVits2HifiGan(config) self.bert_encoders = nn.ModuleList([BertModel(bert_config) for bert_config in config.bert_configs]) self.bert_proj = nn.ModuleList([nn.Linear(bert_config.hidden_size, config.hidden_size) for bert_config in config.bert_configs]) self.stochastic_duration_predictor = BertVits2StochasticDurationPredictor(config) self.duration_predictor = BertVits2DurationPredictor(config) if config.num_speakers > 1: self.embed_speaker = nn.Embedding(config.num_speakers, config.speaker_embedding_size) # This is used only for training. self.posterior_encoder = BertVits2PosteriorEncoder(config) if config.use_transformer_flow: self.flow = BertVits2TransformerCouplingBlock(config) else: self.flow = BertVits2ResidualCouplingBlock(config) # These parameters control the synthesised speech properties self.speaking_rate = config.speaking_rate self.noise_scale = config.noise_scale self.noise_scale_duration = config.noise_scale_duration self.stochastic_duration_prediction_ratio = config.stochastic_duration_prediction_ratio # Initialize weights and apply final processing self.post_init() def get_encoder(self): return self.text_encoder def forward( self, input_ids: Optional[torch.Tensor] = None, tone_ids: Optional[torch.Tensor] = None, language_ids: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, word_to_phoneme: Optional[torch.Tensor] = None, bert_input_ids: Optional[torch.Tensor] = None, bert_attention_mask: Optional[torch.Tensor] = None, language_id: Optional[int] = None, speaker_id: Optional[int] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, labels: Optional[torch.FloatTensor] = None, ) -> Union[Tuple[Any], BertVits2ModelOutput]: r""" labels (`torch.FloatTensor` of shape `(batch_size, config.spectrogram_bins, sequence_length)`, *optional*): Float values of target spectrogram. Timesteps set to `-100.0` are ignored (masked) for the loss computation. Returns: Example: ```python >>> from transformers import BertVits2Tokenizer, BertVits2Model, set_seed >>> import torch >>> tokenizer = BertVits2Tokenizer.from_pretrained("facebook/mms-tts-eng") >>> model = BertVits2Model.from_pretrained("facebook/mms-tts-eng") >>> inputs = tokenizer(text="Hello - my dog is cute", return_tensors="pt") >>> set_seed(555) # make deterministic >>> with torch.no_grad(): ... outputs = model(inputs["input_ids"]) >>> outputs.waveform.shape torch.Size([1, 45824]) ``` """ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict batch_size = input_ids.shape[0] if labels is not None: raise NotImplementedError("Training of VITS is not supported yet.") if attention_mask is not None: input_padding_mask = attention_mask.unsqueeze(-1).float() else: input_padding_mask = torch.ones_like(input_ids).unsqueeze(-1).float() if self.config.num_speakers > 1 and speaker_id is not None: if not 0 <= speaker_id < self.config.num_speakers: raise ValueError(f"Set `speaker_id` in the range 0-{self.config.num_speakers - 1}.") if isinstance(speaker_id, int): speaker_id = torch.full(size=(1,), fill_value=speaker_id, device=self.device) speaker_embeddings = self.embed_speaker(speaker_id).unsqueeze(-1) else: speaker_embeddings = None if language_id is None: language_id = 0 if language_ids is None: language_ids = torch.full_like(input_ids, language_id) phone_len = input_ids.shape[1] is_tuple = isinstance(bert_input_ids, tuple) bert_embeddings = [ self.bert_features(i, bert_input_ids, bert_attention_mask, word_to_phoneme) if i == language_id and not is_tuple else torch.zeros(batch_size, enc.config.hidden_size, phone_len, device=self.device) for i, enc in enumerate(self.bert_encoders) ] text_encoder_output = self.text_encoder( input_ids=input_ids, tone_ids=tone_ids, language_ids=language_ids, padding_mask=input_padding_mask, attention_mask=attention_mask, bert_embeddings=bert_embeddings, global_conditioning=speaker_embeddings, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, ) hidden_states = text_encoder_output[0] if not return_dict else text_encoder_output.last_hidden_state hidden_states = hidden_states.transpose(1, 2) input_padding_mask = input_padding_mask.transpose(1, 2) prior_means = text_encoder_output[1] if not return_dict else text_encoder_output.prior_means prior_log_variances = text_encoder_output[2] if not return_dict else text_encoder_output.prior_log_variances log_duration = \ self.stochastic_duration_predictor( hidden_states, input_padding_mask, global_conditioning=speaker_embeddings, reverse=True, noise_scale=self.noise_scale_duration, ) * self.stochastic_duration_prediction_ratio + \ self.duration_predictor( hidden_states, input_padding_mask, global_conditioning=speaker_embeddings ) * (1.0 - self.stochastic_duration_prediction_ratio) length_scale = 1.0 / self.speaking_rate duration = torch.ceil(torch.exp(log_duration) * input_padding_mask * length_scale) predicted_lengths = torch.clamp_min(torch.sum(duration, [1, 2]), 1).long() # Create a padding mask for the output lengths of shape (batch, 1, max_output_length) indices = torch.arange(predicted_lengths.max(), dtype=predicted_lengths.dtype, device=predicted_lengths.device) output_padding_mask = indices.unsqueeze(0) < predicted_lengths.unsqueeze(1) output_padding_mask = output_padding_mask.unsqueeze(1).to(input_padding_mask.dtype) # Reconstruct an attention tensor of shape (batch, 1, out_length, in_length) attn_mask = torch.unsqueeze(input_padding_mask, 2) * torch.unsqueeze(output_padding_mask, -1) batch_size, _, output_length, input_length = attn_mask.shape cum_duration = torch.cumsum(duration, -1).view(batch_size * input_length, 1) indices = torch.arange(output_length, dtype=duration.dtype, device=duration.device) valid_indices = indices.unsqueeze(0) < cum_duration valid_indices = valid_indices.to(attn_mask.dtype).view(batch_size, input_length, output_length) padded_indices = valid_indices - nn.functional.pad(valid_indices, [0, 0, 1, 0, 0, 0])[:, :-1] attn = padded_indices.unsqueeze(1).transpose(2, 3) * attn_mask # Expand prior distribution prior_means = torch.matmul(attn.squeeze(1), prior_means).transpose(1, 2) prior_log_variances = torch.matmul(attn.squeeze(1), prior_log_variances).transpose(1, 2) prior_latents = prior_means + torch.randn_like(prior_means) * torch.exp(prior_log_variances) * self.noise_scale latents = self.flow(prior_latents, output_padding_mask, global_conditioning=speaker_embeddings, reverse=True) spectrogram = latents * output_padding_mask waveform = self.decoder(spectrogram, global_conditioning=speaker_embeddings) waveform = waveform.squeeze(1) sequence_lengths = predicted_lengths * np.prod(self.config.upsample_rates) if not return_dict: outputs = (waveform, sequence_lengths, spectrogram) + text_encoder_output[3:] return outputs return BertVits2ModelOutput( waveform=waveform, sequence_lengths=sequence_lengths, spectrogram=spectrogram, hidden_states=text_encoder_output.hidden_states, attentions=text_encoder_output.attentions, ) def bert_features(self, index, input_ids, attention_mask, word2phone): is_tuple = isinstance(input_ids, tuple) if is_tuple: input_ids = input_ids[index] attention_mask = attention_mask[index] bert_model = self.bert_encoders[index] features = bert_model(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True).hidden_states x = torch.cat(features[-3:-2], dim=-1) batch_size, _, hidden_dim = x.shape x = x.flatten(0, 1) w2p_flattened = word2phone.flatten() phone_level_feature = x.repeat_interleave(w2p_flattened, dim=0) return phone_level_feature.reshape(batch_size, -1, hidden_dim).transpose(1, 2)