File size: 3,663 Bytes
9b2107c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
from typing import Optional, Tuple

import torch
from torch import nn


class AlignmentNetwork(torch.nn.Module):
    """Aligner Network for learning alignment between the input text and the model output with Gaussian Attention.

    Queries and keys are projected to a common ``attn_channels`` space by two small
    convolutional stacks. The attention logit for a (query, key) pair is the negative
    squared L2 distance between their projections, scaled by ``temperature``.

    ::

        query -> conv1d -> relu -> conv1d -> relu -> conv1d -> L2_dist -> softmax -> alignment
        key   -> conv1d -> relu -> conv1d -----------------------^

    Args:
        in_query_channels (int): Number of channels in the query network. Defaults to 80.
        in_key_channels (int): Number of channels in the key network. Defaults to 512.
        attn_channels (int): Number of inner channels in the attention layers. Defaults to 80.
        temperature (float): Scale applied to the squared distances before the softmax. Defaults to 0.0005.
    """

    def __init__(
        self,
        in_query_channels: int = 80,
        in_key_channels: int = 512,
        attn_channels: int = 80,
        temperature: float = 0.0005,
    ):
        super().__init__()
        self.temperature = temperature
        # Both normalisations run over the last axis of the logits, i.e. the key axis.
        self.softmax = torch.nn.Softmax(dim=3)
        self.log_softmax = torch.nn.LogSoftmax(dim=3)

        # Key projection: in_key_channels -> 2 * in_key_channels -> attn_channels.
        # kernel_size=3 with padding=1 and kernel_size=1 both preserve the time length.
        self.key_layer = nn.Sequential(
            nn.Conv1d(
                in_key_channels,
                in_key_channels * 2,
                kernel_size=3,
                padding=1,
                bias=True,
            ),
            torch.nn.ReLU(),
            nn.Conv1d(in_key_channels * 2, attn_channels, kernel_size=1, padding=0, bias=True),
        )

        # Query projection: in_query_channels -> 2x -> in_query_channels -> attn_channels.
        self.query_layer = nn.Sequential(
            nn.Conv1d(
                in_query_channels,
                in_query_channels * 2,
                kernel_size=3,
                padding=1,
                bias=True,
            ),
            torch.nn.ReLU(),
            nn.Conv1d(in_query_channels * 2, in_query_channels, kernel_size=1, padding=0, bias=True),
            torch.nn.ReLU(),
            nn.Conv1d(in_query_channels, attn_channels, kernel_size=1, padding=0, bias=True),
        )

        self.init_layers()

    def init_layers(self):
        """Re-initialise every convolution weight with Xavier-uniform.

        The first convolution of each stack uses the ReLU gain, the remaining
        ones use the linear gain. Biases are left at their default initialisation.
        """
        relu_gain = torch.nn.init.calculate_gain("relu")
        linear_gain = torch.nn.init.calculate_gain("linear")
        # The order is kept fixed so that seeded initialisation stays reproducible.
        layers_and_gains = (
            (self.key_layer[0], relu_gain),
            (self.key_layer[2], linear_gain),
            (self.query_layer[0], relu_gain),
            (self.query_layer[2], linear_gain),
            (self.query_layer[4], linear_gain),
        )
        for conv, gain in layers_and_gains:
            torch.nn.init.xavier_uniform_(conv.weight, gain=gain)

    def forward(
        self,
        queries: torch.Tensor,
        keys: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        attn_prior: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Forward pass of the aligner encoder.

        Args:
            queries (torch.Tensor): Query sequence fed to the query network.
            keys (torch.Tensor): Key sequence fed to the key network.
            mask (torch.Tensor, optional): Positions where the mask is zero/False are
                filled with ``-inf`` in the logits. ``mask.unsqueeze(2)`` must broadcast
                against the logits.
            attn_prior (torch.Tensor, optional): Prior probabilities. When given, the logits
                are log-normalised over the key axis and ``log(attn_prior + 1e-8)`` is added.

        Shapes:
            - queries: :math:`[B, C, T_de]`
            - keys: :math:`[B, C_emb, T_en]`
            - mask: :math:`[B, 1, T_en]`
            - attn_prior: :math:`[B, T_de, T_en]`

        Output:
            attn (torch.Tensor): :math:`[B, 1, T_de, T_en]` soft attention, normalised over :math:`T_en`.
            attn_logp (torch.Tensor): :math:`[B, 1, T_de, T_en]` attention logits that ``attn`` is the
                softmax of (``-inf`` at masked positions).
        """
        key_out = self.key_layer(keys)  # [B, attn_channels, T_en]
        query_out = self.query_layer(queries)  # [B, attn_channels, T_de]

        # Pairwise squared differences: [B, attn_channels, T_de, T_en].
        attn_factor = (query_out[:, :, :, None] - key_out[:, :, None]) ** 2
        # Negative scaled squared L2 distance, summed over channels: [B, 1, T_de, T_en].
        attn_logp = -self.temperature * attn_factor.sum(1, keepdim=True)

        if attn_prior is not None:
            attn_logp = self.log_softmax(attn_logp) + torch.log(attn_prior[:, None] + 1e-8)

        if mask is not None:
            # Out-of-place fill instead of mutating ``.data`` so that autograd tracks
            # the masking; the forward values are the same.
            attn_logp = attn_logp.masked_fill(~mask.bool().unsqueeze(2), -float("inf"))

        attn = self.softmax(attn_logp)
        return attn, attn_logp