import math
import torch
import torch.nn as nn

from tencentpretrain.utils.constants import *


class SinusoidalposEmbedding(nn.Module):
    """Sinusoidal positional encoding for non-recurrent neural networks.
    Implementation based on "Attention Is All You Need"
    :cite:`DBLP:journals/corr/VaswaniSPUJGKP17`
    Args:
       dropout (float): dropout parameter
       dim (int): embedding size
    """

    def __init__(self, args, _):
        super(SinusoidalposEmbedding, self).__init__()

        if "speech" in args.embedding:
            self.max_seq_length = max(args.max_seq_length, args.max_audio_frames)
            self.arrange_sincos_cross = False
        else:
            self.max_seq_length = args.max_seq_length
            self.arrange_sincos_cross = True
        self.emb_size = args.emb_size
        half_dim = self.emb_size // 2
        # Geometric frequency progression: half_exp[i] = 10000 ** (-i / (half_dim - 1)).
        value = math.log(10000) / (half_dim - 1)
        half_exp = torch.exp(torch.arange(half_dim, dtype=torch.float) * -value)
        half_mat = torch.arange(self.max_seq_length, dtype=torch.float).unsqueeze(
            1
        ) * half_exp.unsqueeze(0)
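        # Layout of the table built below, illustrated for emb_size = 6
        # (p is the position index, w_i = half_exp[i]):
        #   concatenated: [sin(p*w0), sin(p*w1), sin(p*w2), cos(p*w0), cos(p*w1), cos(p*w2)]
        #   interleaved:  [sin(p*w0), cos(p*w0), sin(p*w1), cos(p*w1), sin(p*w2), cos(p*w2)]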
        if not self.arrange_sincos_cross:  # Concatenated layout, same as huggingface/transformers and tensor2tensor.
            self.emb = torch.cat([torch.sin(half_mat), torch.cos(half_mat)], dim=1).view(
                self.max_seq_length, -1
            )
        else:  # Interleaved layout, as in "Attention Is All You Need".
            # Allocate 2 * half_dim columns so the interleaved assignment also
            # works when emb_size is odd (a zero column is appended below).
            self.emb = torch.zeros(self.max_seq_length, 2 * half_dim)
            self.emb[:, 0::2] = torch.sin(half_mat)
            self.emb[:, 1::2] = torch.cos(half_mat)
        if self.emb_size % 2 == 1:
            # zero pad
            self.emb = torch.cat([self.emb, torch.zeros(self.max_seq_length, 1)], dim=1)

        # Zero the table row at the PAD token's vocabulary id, so that row
        # encodes padding as all zeros.
        self.emb[args.tokenizer.vocab.get(PAD_TOKEN), :] = 0

    def forward(self, src, seg):
        """Embed inputs.
        Args:
            emb (FloatTensor): Sequence of word vectors
                ``(batch_size, seq_len, self.dim)``
            step (int or NoneType): If stepwise (``seq_len = 1``), use
                the encoding for this position.
        """
        if seg is not None:
            batch_size, seq_length = seg.size()
            device = seg.device
            no_pad_num = seg.sum(dim=-1)
        else:
            batch_size, seq_length = src.size()
            device = src.device
            no_pad_num = (src != 0).sum(dim=-1)
        
        # Copy the first no_pad_num[i] table rows (starting at row 2) into the
        # non-padding positions; padded positions keep an all-zero encoding.
        emb = torch.zeros(batch_size, seq_length, self.emb_size)
        for i in range(batch_size):
            emb[i, :no_pad_num[i], :] = self.emb[2: no_pad_num[i]+2]

        return emb.to(device)
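

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only): the Namespace below stands in
    # for TencentPretrain's parsed argument object, and every field value here
    # is an assumption for demonstration, not a project default.
    from argparse import Namespace

    args = Namespace(
        embedding=["word", "sinusoidalpos"],        # any list without "speech"
        max_seq_length=128,
        emb_size=512,
        tokenizer=Namespace(vocab={PAD_TOKEN: 0}),  # stand-in tokenizer
    )
    pos_emb = SinusoidalposEmbedding(args, None)

    src = torch.randint(1, 100, (2, 16))        # dummy token ids
    seg = torch.ones(2, 16, dtype=torch.long)   # 1 = real token, 0 = padding
    print(pos_emb(src, seg).shape)              # torch.Size([2, 16, 512])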