File size: 2,208 Bytes
14d1720
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import torch.nn as nn
from torch import Tensor

from ..layers import FFTBlock, get_sinusoid_encoding_table


class FS2TransformerDecoder(nn.Module):
    """Stack of ``FFTBlock`` layers applied on top of a fixed position encoding.

    ``forward`` takes a single tensor, adds a non-trainable position table
    built by ``get_sinusoid_encoding_table``, and runs the result through
    ``n_layers`` ``FFTBlock`` layers. The output sequence has the same length
    as the input.

    NOTE(review): any conditioning (e.g. text or speaker embeddings) is
    presumably combined into ``input`` by the caller before ``forward`` is
    called. This module does not do that combination itself.

    Args:
        input_dim: feature size of the input and of the position table.
            The original author notes it must equal the decoder output dim
            (presumably ``hidden_dim``; verify against ``FFTBlock``).
        n_layers: number of ``FFTBlock`` layers.
        n_heads: number of attention heads passed to each ``FFTBlock``.
        hidden_dim: model width passed to each ``FFTBlock``. The per-head
            key/value size is ``hidden_dim // n_heads``.
        d_inner: inner size passed to each ``FFTBlock``.
        dropout: dropout rate passed to each ``FFTBlock``.
        max_len: longest sequence ``forward`` accepts. The position table
            has ``max_len + 1`` rows.
    """

    def __init__(
            self,
            input_dim: int = 256,  # must ==  decoder output dim
            n_layers: int = 4,
            n_heads: int = 2,
            hidden_dim: int = 256,
            d_inner: int = 1024,
            dropout: float = 0.5,
            max_len: int = 2048,  # for computing position table
    ):
        super().__init__()

        self.input_dim = input_dim
        self.max_len = max_len

        # Each head gets an equal share of hidden_dim for keys and values.
        d_k = d_v = hidden_dim // n_heads

        # Stored as a Parameter with requires_grad=False so it is registered
        # on the module (device moves, state_dict) but never trained.
        pos_table = get_sinusoid_encoding_table(max_len + 1, input_dim).unsqueeze(0)
        self.position_enc = nn.Parameter(pos_table, requires_grad=False)

        self.layers = nn.ModuleList(
            FFTBlock(hidden_dim, d_inner, n_heads, d_k, d_v, dropout=dropout)
            for _ in range(n_layers)
        )

    def forward(self, input: Tensor, mask):
        """Add position encodings to ``input`` and run it through every layer.

        Args:
            input: 3-D tensor of shape (batch, seq_len, input_dim).
            mask: per-position mask. The broadcasting below assumes shape
                (batch, seq_len). TODO confirm dtype/polarity against FFTBlock.

        Returns:
            The output of the last layer. When there are no layers, this is
            the position-encoded input.

        Raises:
            ValueError: if ``input`` is not 3-D, its last dimension differs
                from ``input_dim``, or ``seq_len`` exceeds ``max_len``.
        """
        if input.dim() != 3:
            raise ValueError(
                f'Expected a 3-D input (batch, seq_len, input_dim), '
                f'got shape {tuple(input.shape)}')
        batch_size, seq_len, input_dim = input.shape
        if input_dim != self.input_dim:
            raise ValueError(
                f'The input dimension must be {self.input_dim}, got {input_dim}')
        if seq_len > self.max_len:
            raise ValueError('inputs.shape[1] > self.max_len')

        # Assuming mask is (batch, seq_len): repeat it for every query
        # position, giving a (batch, seq_len, seq_len) attention mask.
        attn_mask = mask.unsqueeze(1).expand(-1, seq_len, -1)

        pos_embed = self.position_enc[:, :seq_len, :].expand(batch_size, -1, -1)
        output = input + pos_embed
        for layer in self.layers:
            # The second return value (presumably the self-attention map)
            # is not used by this decoder.
            output, _ = layer(output, mask=mask, attn_mask=attn_mask)
        return output