import math
from typing import Any, Optional

import torch
import torch.onnx.operators
from fairseq import utils
from torch import Tensor, nn


class SinusoidalPositionalEmbedding(nn.Module):
    """This module produces sinusoidal positional embeddings of any length.

    Padding symbols are ignored.
    """

    def __init__(self, embedding_dim, padding_idx, init_size=1024):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.padding_idx = padding_idx if padding_idx is not None else 0
        self.weights = SinusoidalPositionalEmbedding.get_embedding(
            init_size, embedding_dim, padding_idx
        )
        self.onnx_trace = False
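        # Buffer used only to track the module's device/dtype; the weights are
        # cast onto it in forward().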
        self.register_buffer("_float_tensor", torch.FloatTensor(1))
        self.max_positions = int(1e5)

    def prepare_for_onnx_export_(self):
        self.onnx_trace = True

    @staticmethod
    def get_embedding(
        num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None
    ):
        """Build sinusoidal embeddings.

        This matches the implementation in tensor2tensor, but differs slightly
        from the description in Section 3.5 of "Attention Is All You Need".
        """
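        # Channel i uses frequency 10000^(-i / (half_dim - 1)); the first half
        # of each row holds sin(pos * freq), the second half holds cos(pos * freq).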
        half_dim = embedding_dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, dtype=torch.float) * -emb)
        emb = torch.arange(num_embeddings, dtype=torch.float).unsqueeze(
            1
        ) * emb.unsqueeze(0)
        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(
            num_embeddings, -1
        )
        if embedding_dim % 2 == 1:
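            # Zero-pad the extra channel when embedding_dim is odd.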
            emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1)
        if padding_idx is not None:
            emb[padding_idx, :] = 0
        return emb

    def forward(
        self,
        input,
        incremental_state: Optional[Any] = None,
        timestep: Optional[Tensor] = None,
        positions: Optional[Any] = None,
    ):
        """Input is expected to be of size [bsz x seqlen]."""
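        # shape_as_tensor keeps bsz/seq_len symbolic so exported ONNX graphs
        # stay valid for dynamic input shapes.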
        bspair = torch.onnx.operators.shape_as_tensor(input)
        bsz, seq_len = bspair[0], bspair[1]
        max_pos = self.padding_idx + 1 + seq_len
        if self.weights is None or max_pos > self.weights.size(0):
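            # Recompute/expand the embedding table if the input is longer than
            # what has been cached so far.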
            self.weights = SinusoidalPositionalEmbedding.get_embedding(
                max_pos, self.embedding_dim, self.padding_idx
            )
        self.weights = self.weights.to(self._float_tensor)

        if incremental_state is not None:
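            # During incremental decoding the position is the same for every
            # token generated in this step.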
            pos = timestep.view(-1)[0] + 1 if timestep is not None else seq_len
            if self.onnx_trace:
                return (
                    self.weights.index_select(index=self.padding_idx + pos, dim=0)
                    .unsqueeze(1)
                    .repeat(bsz, 1, 1)
                )
            return self.weights[self.padding_idx + pos, :].expand(bsz, 1, -1)

        positions = utils.make_positions(
            input, self.padding_idx, onnx_trace=self.onnx_trace
        )
        if self.onnx_trace:
            flat_embeddings = self.weights.detach().index_select(0, positions.view(-1))
            embedding_shape = torch.cat(
                (bsz.view(1), seq_len.view(1), torch.tensor([-1], dtype=torch.long))
            )
            embeddings = torch.onnx.operators.reshape_from_tensor_shape(
                flat_embeddings, embedding_shape
            )
            return embeddings
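        # The table is a plain tensor, not a Parameter; detach keeps gradients
        # from flowing into it.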
        return (
            self.weights.index_select(0, positions.view(-1))
            .view(bsz, seq_len, -1)
            .detach()
        )
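

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only, not part of the fairseq API):
    # embed a padded batch of token ids, assuming id 1 is the padding symbol.
    PAD = 1
    pos_emb = SinusoidalPositionalEmbedding(embedding_dim=8, padding_idx=PAD, init_size=16)
    tokens = torch.tensor([[5, 6, 7, PAD], [5, 6, PAD, PAD]])
    out = pos_emb(tokens)
    print(out.shape)  # torch.Size([2, 4, 8])
    # Rows at padding positions come back as all zeros, because get_embedding
    # zeroes the padding_idx row and make_positions maps pads to padding_idx.
    print(out[1, 2].abs().sum())  # tensor(0.)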