from typing import Optional, Tuple

import numpy as np
import torch
from torch import Tensor, nn


def maximum_path(
    value: Tensor,
    mask: Tensor,
    max_neg_val: Optional[float] = None,
):
    """Monotonic alignment search algorithm.

    Numpy-friendly version. It's about 4 times faster than the torch version.

    value: [b, t_x, t_y]
    mask: [b, t_x, t_y]
    """
    if max_neg_val is None:
        max_neg_val = -np.inf  # Patch for Sphinx complaint
    value = value * mask

    device = value.device
    dtype = value.dtype
    value = value.cpu().detach().numpy()
    mask = mask.cpu().detach().numpy().astype(bool)

    b, t_x, t_y = value.shape
    direction = np.zeros(value.shape, dtype=np.int64)
    v = np.zeros((b, t_x), dtype=np.float32)
    x_range = np.arange(t_x, dtype=np.float32).reshape(1, -1)

    # Forward pass: dynamic programming over output frames. For each frame j, every text
    # position keeps the best cumulative score reachable by either staying on the same
    # position (v1) or advancing from the previous one (v0).
    for j in range(t_y):
        v0 = np.pad(v, [[0, 0], [1, 0]], mode="constant", constant_values=max_neg_val)[:, :-1]
        v1 = v
        max_mask = v1 >= v0
        v_max = np.where(max_mask, v1, v0)
        direction[:, :, j] = max_mask

        index_mask = x_range <= j
        v = np.where(index_mask, v_max + value[:, :, j], max_neg_val)
    direction = np.where(mask, direction, 1)

    # Backtracking: starting from the last valid text position, follow the stored
    # directions from the final frame back to the first to recover the monotonic path.
    path = np.zeros(value.shape, dtype=np.float32)
    index = mask[:, :, 0].sum(1).astype(np.int64) - 1  # type: ignore
    index_range = np.arange(b)
    for j in reversed(range(t_y)):
        path[index_range, index, j] = 1
        index = index + direction[index_range, index, j] - 1
    path = path * mask.astype(np.float32)  # type: ignore
    path = torch.from_numpy(path).to(device=device, dtype=dtype)
    return path


class AlignmentNetwork(torch.nn.Module):
    r"""Aligner Network for learning alignment between the input text and the model output with Gaussian Attention.

    The architecture is as follows::

        query -> conv1d -> relu -> conv1d -> relu -> conv1d -> L2_dist -> softmax -> alignment
        key   -> conv1d -> relu -> conv1d -----------------------^

    Args:
        in_query_channels (int): Number of channels in the query network.
        in_key_channels (int): Number of channels in the key network.
        attn_channels (int): Number of inner channels in the attention layers.
        temperature (float, optional): Temperature for the softmax. Defaults to 0.0005.
    """

    def __init__(
        self,
        in_query_channels: int,
        in_key_channels: int,
        attn_channels: int,
        temperature: float = 0.0005,
    ):
        super().__init__()
        self.temperature = temperature
        self.softmax = torch.nn.Softmax(dim=3)
        self.log_softmax = torch.nn.LogSoftmax(dim=3)

        self.key_layer = nn.Sequential(
            nn.Conv1d(
                in_key_channels,
                in_key_channels * 2,
                kernel_size=3,
                padding=1,
                bias=True,
            ),
            torch.nn.ReLU(),
            nn.Conv1d(
                in_key_channels * 2,
                attn_channels,
                kernel_size=1,
                padding=0,
                bias=True,
            ),
        )

        self.query_layer = nn.Sequential(
            nn.Conv1d(
                in_query_channels,
                in_query_channels * 2,
                kernel_size=3,
                padding=1,
                bias=True,
            ),
            torch.nn.ReLU(),
            nn.Conv1d(
                in_query_channels * 2,
                in_query_channels,
                kernel_size=1,
                padding=0,
                bias=True,
            ),
            torch.nn.ReLU(),
            nn.Conv1d(
                in_query_channels,
                attn_channels,
                kernel_size=1,
                padding=0,
                bias=True,
            ),
        )

        self.init_layers()

    def init_layers(self):
        r"""Initialize the weights of the key and query layers using Xavier uniform initialization.

        The gain is calculated based on the activation function: ReLU for the first layer and linear for the rest.
        """
        torch.nn.init.xavier_uniform_(
            self.key_layer[0].weight,
            gain=torch.nn.init.calculate_gain("relu"),
        )
        torch.nn.init.xavier_uniform_(
            self.key_layer[2].weight,
            gain=torch.nn.init.calculate_gain("linear"),
        )
        torch.nn.init.xavier_uniform_(
            self.query_layer[0].weight,
            gain=torch.nn.init.calculate_gain("relu"),
        )
        torch.nn.init.xavier_uniform_(
            self.query_layer[2].weight,
            gain=torch.nn.init.calculate_gain("linear"),
        )
        torch.nn.init.xavier_uniform_(
            self.query_layer[4].weight,
            gain=torch.nn.init.calculate_gain("linear"),
        )

    def _forward_aligner(
        self,
        queries: Tensor,
        keys: Tensor,
        mask: Optional[Tensor] = None,
        attn_prior: Optional[Tensor] = None,
    ) -> Tuple[Tensor, Tensor]:
        r"""Forward pass of the aligner encoder.

        Args:
            queries (Tensor): Input queries of shape [B, C, T_de].
            keys (Tensor): Input keys of shape [B, C_emb, T_en].
            mask (Optional[Tensor], optional): Mask of shape [B, 1, T_en]. Defaults to None.
            attn_prior (Optional[Tensor], optional): Prior attention tensor. Defaults to None.

        Returns:
            Tuple[Tensor, Tensor]: A tuple containing the soft attention mask of shape
                [B, 1, T_de, T_en] and log probabilities of shape [B, 1, T_de, T_en].
        """
        key_out = self.key_layer(keys)
        query_out = self.query_layer(queries)

        # Gaussian attention: negative squared L2 distance between every decoder frame
        # and every text token, scaled by the temperature.
        attn_factor = (query_out[:, :, :, None] - key_out[:, :, None]) ** 2
        attn_logp = -self.temperature * attn_factor.sum(1, keepdim=True)

        if attn_prior is not None:
            attn_logp = self.log_softmax(attn_logp) + torch.log(
                attn_prior[:, None] + 1e-8,
            ).permute((0, 1, 3, 2))

        if mask is not None:
            attn_logp.data.masked_fill_(~mask.bool().unsqueeze(2), -float("inf"))

        attn = self.softmax(attn_logp)
        return attn, attn_logp

    def forward(
        self,
        x: Tensor,
        y: Tensor,
        x_mask: Tensor,
        y_mask: Tensor,
        attn_priors: Tensor,
    ) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
        r"""Aligner forward pass.

        1. Compute a mask to apply to the attention map.
        2. Run the alignment network.
        3. Apply MAS to compute the hard alignment map.
        4. Compute the durations from the hard alignment map.

        Args:
            x (torch.Tensor): Input sequence.
            y (torch.Tensor): Output sequence.
            x_mask (torch.Tensor): Input sequence mask.
            y_mask (torch.Tensor): Output sequence mask.
            attn_priors (torch.Tensor): Prior for the aligner network map.

        Returns:
            Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: Log scale alignment
                potentials, soft alignment potentials, hard alignment map, and durations from
                the hard alignment map.

        Shapes:
            - x: :math:`[B, T_en, C_en]`
            - y: :math:`[B, T_de, C_de]`
            - x_mask: :math:`[B, 1, T_en]`
            - y_mask: :math:`[B, 1, T_de]`
            - attn_priors: :math:`[B, T_de, T_en]`
            - aligner_durations: :math:`[B, T_en]`
            - aligner_soft: :math:`[B, T_de, T_en]`
            - aligner_logprob: :math:`[B, 1, T_de, T_en]`
            - aligner_mas: :math:`[B, T_de, T_en]`
        """
        # [B, 1, T_en, T_de]
        attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2)

        aligner_soft, aligner_logprob = self._forward_aligner(
            y.transpose(1, 2),
            x.transpose(1, 2),
            x_mask,
            attn_priors,
        )
        aligner_mas = maximum_path(
            aligner_soft.squeeze(1).transpose(1, 2).contiguous(),
            attn_mask.squeeze(1).contiguous(),
        )
        aligner_durations = torch.sum(aligner_mas, -1).int()
        # [B, T_de, T_en]
        aligner_soft = aligner_soft.squeeze(1)
        # [B, T_en, T_de] -> [B, T_de, T_en]
        aligner_mas = aligner_mas.transpose(1, 2)
        return aligner_logprob, aligner_soft, aligner_mas, aligner_durations