Spaces:
Running
Running
from typing import Optional, Tuple

import torch
from torch import nn
class AlignmentNetwork(torch.nn.Module):
    """Aligner Network for learning alignment between the input text and the model output with Gaussian Attention.

    Queries and keys are projected to ``attn_channels`` by small 1D-conv stacks. The
    score of a (query frame, key token) pair is the negative squared L2 distance of
    their projections scaled by ``temperature``; a softmax over the key axis turns the
    scores into a soft alignment.

    ::

        query -> conv1d -> relu -> conv1d -> relu -> conv1d -> L2_dist -> softmax -> alignment
        key   -> conv1d -> relu -> conv1d -----------------------^

    Args:
        in_query_channels (int): Number of channels in the query network. Defaults to 80.
        in_key_channels (int): Number of channels in the key network. Defaults to 512.
        attn_channels (int): Number of inner channels in the attention layers. Defaults to 80.
        temperature (float): Factor multiplying the squared distance before the softmax.
            Defaults to 0.0005.
    """

    def __init__(
        self,
        in_query_channels=80,
        in_key_channels=512,
        attn_channels=80,
        temperature=0.0005,
    ):
        super().__init__()
        self.temperature = temperature
        # Both normalisations run over the last axis of the score tensor (the key axis).
        self.softmax = torch.nn.Softmax(dim=3)
        self.log_softmax = torch.nn.LogSoftmax(dim=3)

        self.key_layer = nn.Sequential(
            nn.Conv1d(
                in_key_channels,
                in_key_channels * 2,
                kernel_size=3,
                padding=1,
                bias=True,
            ),
            torch.nn.ReLU(),
            nn.Conv1d(in_key_channels * 2, attn_channels, kernel_size=1, padding=0, bias=True),
        )

        self.query_layer = nn.Sequential(
            nn.Conv1d(
                in_query_channels,
                in_query_channels * 2,
                kernel_size=3,
                padding=1,
                bias=True,
            ),
            torch.nn.ReLU(),
            nn.Conv1d(in_query_channels * 2, in_query_channels, kernel_size=1, padding=0, bias=True),
            torch.nn.ReLU(),
            nn.Conv1d(in_query_channels, attn_channels, kernel_size=1, padding=0, bias=True),
        )

        self.init_layers()

    def init_layers(self):
        """Re-initialise every convolution weight with Xavier-uniform.

        The first convolution of each stack uses the "relu" gain, all later ones the
        "linear" gain. Biases are left as created by ``nn.Conv1d``.
        """
        relu_gain = torch.nn.init.calculate_gain("relu")
        linear_gain = torch.nn.init.calculate_gain("linear")
        # Order is kept fixed so that the random number stream is consumed identically
        # from run to run under a fixed seed.
        schedule = (
            (self.key_layer[0], relu_gain),
            (self.key_layer[2], linear_gain),
            (self.query_layer[0], relu_gain),
            (self.query_layer[2], linear_gain),
            (self.query_layer[4], linear_gain),
        )
        for conv, gain in schedule:
            torch.nn.init.xavier_uniform_(conv.weight, gain=gain)

    def forward(
        self,
        queries: torch.Tensor,
        keys: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        attn_prior: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Forward pass of the aligner encoder.

        Shapes:
            - queries: :math:`[B, C, T_de]`
            - keys: :math:`[B, C_emb, T_en]`
            - mask: anything that, after ``unsqueeze(2)``, broadcasts onto
              :math:`[B, 1, T_de, T_en]`, e.g. :math:`[B, 1, T_en]`. Non-zero entries are
              kept, zero entries are masked out.
            - attn_prior: anything that, after inserting an axis at position 1, broadcasts
              with :math:`[B, 1, T_de, T_en]`, e.g. :math:`[B, T_de, T_en]`.

        Output:
            attn (torch.Tensor): :math:`[B, 1, T_de, T_en]` soft attention mask, normalised
                over the last (key) axis.
            attn_logp (torch.Tensor): :math:`[B, 1, T_de, T_en]` scores fed to the softmax
                (log-probabilities plus log-prior when ``attn_prior`` is given); masked
                positions hold ``-inf``.
        """
        key_out = self.key_layer(keys)  # [B, attn_channels, T_en]
        query_out = self.query_layer(queries)  # [B, attn_channels, T_de]

        # Pairwise squared differences: [B, attn_channels, T_de, 1] - [B, attn_channels, 1, T_en].
        attn_factor = (query_out[:, :, :, None] - key_out[:, :, None]) ** 2
        # Summing over channels gives the squared L2 distance; closer pairs score higher.
        attn_logp = -self.temperature * attn_factor.sum(1, keepdim=True)

        if attn_prior is not None:
            # The epsilon keeps log() finite where the prior is exactly zero.
            attn_logp = self.log_softmax(attn_logp) + torch.log(attn_prior[:, None] + 1e-8)

        if mask is not None:
            # In-place fill on the tensor itself (not on ``.data``) so autograd records
            # the operation; the in-place form also rejects a mask that does not
            # broadcast onto the score tensor instead of silently enlarging the result.
            attn_logp.masked_fill_(~mask.bool().unsqueeze(2), -float("inf"))

        attn = self.softmax(attn_logp)
        return attn, attn_logp