|
import torch |
|
import torch.nn as nn |
|
|
|
|
|
class PositionalEncoder(nn.Module):
    """Sinusoidal positional encoding (Vaswani et al., 2017).

    Maps a tensor of positions of shape (batch, seq_len) — values need not
    be integers — to a table of shape (batch, seq_len, d) where even
    channels hold sin(pos / T^(2i/d)) and odd channels the matching cos.

    Args:
        d: dimensionality of the encoding produced per position.
        T: temperature / base of the geometric frequency progression.
        repeat: if not None, the d-dimensional encoding is tiled this many
            times along the last axis (output dim becomes d * repeat).
        offset: shifts the channel indices used to build the frequency
            denominators (useful when concatenating several encoders).
    """

    def __init__(self, d, T=1000, repeat=None, offset=0):
        super().__init__()
        self.d = d
        self.T = T
        self.repeat = repeat
        # denom[i] = T^(2*(i//2)/d): consecutive even/odd channels share
        # one frequency, so each (sin, cos) pair uses the same denominator.
        self.denom = torch.pow(
            T, 2 * (torch.arange(offset, offset + d).float() // 2) / d
        )
        # Kept for backward compatibility with code that inspects this
        # flag; device syncing below no longer relies on it being one-shot.
        self.updated_location = False

    def forward(self, batch_positions):
        """Encode a batch of positions.

        Args:
            batch_positions: tensor of shape (batch, seq_len) holding the
                position of each sequence element.

        Returns:
            Tensor of shape (batch, seq_len, d * (repeat or 1)).
        """
        # Re-check the device on every call: the previous one-shot flag
        # left `denom` on a stale device whenever the module or its
        # inputs moved devices after the first forward pass.
        if self.denom.device != batch_positions.device:
            self.denom = self.denom.to(batch_positions.device)
            self.updated_location = True

        # (batch, seq_len, d) table of phase angles pos / T^(2i/d).
        sinusoid_table = (
            batch_positions[:, :, None] / self.denom[None, None, :]
        )
        # In-place is safe: the division above allocated a fresh tensor.
        sinusoid_table[:, :, 0::2] = torch.sin(sinusoid_table[:, :, 0::2])
        sinusoid_table[:, :, 1::2] = torch.cos(sinusoid_table[:, :, 1::2])

        if self.repeat is not None:
            sinusoid_table = torch.cat(
                [sinusoid_table for _ in range(self.repeat)], dim=-1
            )

        return sinusoid_table
|
|