import math
import torch
import torch.nn as nn
from tencentpretrain.utils.constants import *


class SinusoidalposEmbedding(nn.Module):
    """Sinusoidal positional encoding for non-recurrent neural networks.

    Implementation based on "Attention Is All You Need"
    :cite:`DBLP:journals/corr/VaswaniSPUJGKP17`

    Args:
        args: namespace providing emb_size, max_seq_length, max_audio_frames,
            embedding (list of embedding names), and tokenizer.
        _: unused placeholder (the second constructor argument is ignored),
            kept so all embedding classes share a uniform interface.
    """

    def __init__(self, args, _):
        super(SinusoidalposEmbedding, self).__init__()
        if "speech" in args.embedding:
            # Speech inputs can be longer than text, so size the table
            # for whichever is larger.
            self.max_seq_length = max(args.max_seq_length, args.max_audio_frames)
            self.arrange_sincos_cross = False
        else:
            self.max_seq_length = args.max_seq_length
            self.arrange_sincos_cross = True
        self.emb_size = args.emb_size
        half_dim = self.emb_size // 2
        value = math.log(10000) / (half_dim - 1)
        half_exp = torch.exp(torch.arange(half_dim, dtype=torch.float) * -value)
        # Outer product: (max_seq_length, 1) * (1, half_dim) -> (max_seq_length, half_dim).
        half_mat = torch.arange(self.max_seq_length, dtype=torch.float).unsqueeze(
            1
        ) * half_exp.unsqueeze(0)
        if not self.arrange_sincos_cross:
            # Same as the implementation of huggingface/transformers and
            # tensor2tensor: all sin columns first, then all cos columns.
            self.emb = torch.cat([torch.sin(half_mat), torch.cos(half_mat)], dim=1).view(
                self.max_seq_length, -1
            )
        else:
            # Implementation based on "Attention Is All You Need":
            # sin at even indices, cos at odd indices.
            self.emb = torch.zeros(self.max_seq_length, args.emb_size)
            self.emb[:, 0::2] = torch.sin(half_mat)
            self.emb[:, 1::2] = torch.cos(half_mat)
        if self.emb_size % 2 == 1:
            # Zero-pad one extra column so the table width matches an odd emb_size.
            self.emb = torch.cat([self.emb, torch.zeros(self.max_seq_length, 1)], dim=1)
        # Zero out the row at the padding token's index.
        self.emb[args.tokenizer.vocab.get(PAD_TOKEN), :] = 0

    def forward(self, src, seg):
        """Look up positional encodings for a batch.

        Args:
            src (LongTensor): token ids of shape ``(batch_size, seq_length)``.
            seg (LongTensor or None): segment information of the same shape,
                with 0 at padding positions; if None, padding is inferred
                from zeros in ``src``.

        Returns:
            FloatTensor of shape ``(batch_size, seq_length, emb_size)`` holding
            positional encodings for the non-padding positions and zeros
            elsewhere.
        """
        if seg is not None:
            batch_size, seq_length = seg.size()
            device = seg.device
            no_pad_num = seg.sum(dim=-1)
        else:
            batch_size, seq_length = src.size()
            device = src.device
            no_pad_num = (src != 0).sum(dim=-1)
        emb = torch.zeros(batch_size, seq_length, self.emb_size)
        for i in range(batch_size):
            # Positions are offset by 2, skipping the first two table rows.
            emb[i, :no_pad_num[i], :] = self.emb[2: no_pad_num[i] + 2]
        return emb.to(device)
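

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original module. The stand-in
    # ``args`` below is an assumption: it carries only the attributes this
    # file actually reads, with illustrative values.
    from types import SimpleNamespace

    args = SimpleNamespace(
        embedding=["word", "sinusoidalpos"],
        emb_size=512,
        max_seq_length=128,
        max_audio_frames=0,
        tokenizer=SimpleNamespace(vocab={PAD_TOKEN: 0}),
    )
    embedding = SinusoidalposEmbedding(args, None)
    # Two right-padded sequences; id 0 counts as padding when seg is None.
    src = torch.tensor([[5, 6, 7, 0, 0],
                        [8, 9, 10, 11, 0]])
    pos_emb = embedding(src, None)
    print(pos_emb.shape)  # torch.Size([2, 5, 512])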