File size: 1,925 Bytes
32b542e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
import math
import torch
from torch import nn

from uniperceiver.utils.registry import Registry

# Registry for the positional-encoding classes in this module. Classes add
# themselves with ``@POSITION_ENC_REGISTRY.register()`` and are fetched by
# name in ``build_position_encoding``.
POSITION_ENC_REGISTRY = Registry("POSITION_ENC")
POSITION_ENC_REGISTRY.__doc__ = """
Registry for positional encoding
"""

# Names exported by a star-import of this module.
# NOTE(review): ``build_position_encoding`` and ``POSITION_ENC_REGISTRY`` are
# not listed here -- confirm that is intentional.
__all__ = ["SinusoidEncoding", "NNEmbeddingEncoding"]

def build_position_encoding(cfg, dim, max_len):
    """Create the positional encoding selected by the config.

    Args:
        cfg: Config object; ``cfg.MODEL.TOKEN_EMBED.POSITION`` is the name
            used to look up the encoding in ``POSITION_ENC_REGISTRY``.
        dim: Passed through as the first constructor argument.
        max_len: Passed through as the second constructor argument.

    Returns:
        The result of calling the registered entry with ``(dim, max_len)``.
    """
    registered_name = cfg.MODEL.TOKEN_EMBED.POSITION
    encoding_cls = POSITION_ENC_REGISTRY.get(registered_name)
    return encoding_cls(dim, max_len)

@POSITION_ENC_REGISTRY.register()
class SinusoidEncoding(nn.Module):
    """Fixed sine/cosine positional-encoding table.

    The table is computed once in ``__init__`` and stored as the buffer
    ``pe`` with shape ``(1, max_len, dim)``. The module has no learnable
    parameters.
    """

    def __init__(self, dim, max_len):
        """Precompute the encoding table.

        Args:
            dim: Width of each position vector. Both even and odd values
                are supported.
            max_len: Number of positions stored in the table.
        """
        super(SinusoidEncoding, self).__init__()
        pe = torch.zeros(max_len, dim)
        position = torch.arange(0, max_len).unsqueeze(1).float()
        # One frequency per (sin, cos) column pair. The wavelength base is
        # ``2 * max_len`` rather than a fixed constant.
        div_term = torch.exp(torch.arange(0, dim, 2).float() *
                             -(math.log(max_len * 2.0) / dim))
        angles = position * div_term
        # Even columns take the sine, odd columns the cosine.
        pe[:, 0::2] = torch.sin(angles)
        # With an odd ``dim`` there is one fewer odd column than even column,
        # so trim the angles to ``dim // 2`` to avoid a shape mismatch. For an
        # even ``dim`` the slice is a no-op.
        pe[:, 1::2] = torch.cos(angles[:, :dim // 2])
        # Leading axis of size 1 so the table broadcasts over a batch axis.
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Return encodings for a single position or a sequence prefix.

        Args:
            x: Either an ``int`` position index, or an object whose
                ``size(1)`` is the sequence length (only that length is
                read; the values of ``x`` are not used).

        Returns:
            For an ``int``: ``pe[:, x]``, shape ``(1, dim)``.
            Otherwise: the first ``x.size(1)`` rows of the table, shape
            ``(1, L, dim)`` where ``L`` is capped at ``max_len`` by slicing.
        """
        if isinstance(x, int):
            return self.pe[:, x]
        seq_len = x.size(1)
        return self.pe[:, :seq_len]

@POSITION_ENC_REGISTRY.register()
class NNEmbeddingEncoding(nn.Module):
    """Learned positional encoding backed by an ``nn.Embedding`` table."""

    def __init__(self, dim, max_len):
        """Create the embedding table.

        Args:
            dim: Width of each position vector.
            max_len: Number of rows (distinct position ids) in the table.
        """
        super(NNEmbeddingEncoding, self).__init__()
        self.position_embeddings = nn.Embedding(max_len, dim)

    def forward(self, x, start_time=0):
        """Look up position embeddings.

        Args:
            x: One of
                * an ``int`` position id -> result has shape ``(1, dim)``;
                * a 1-D tensor of position ids -> one row per id;
                * anything else exposing ``size(1)`` and ``device`` ->
                  embeddings for positions
                  ``start_time .. start_time + x.size(1) - 1``.
            start_time: Offset added to the generated position ids. It is
                only used in the third case; explicit ids are taken as given.

        Returns:
            Tensor of embeddings for the selected position ids.
        """
        if isinstance(x, int):
            # Build the index on the embedding table's own device instead of
            # hard-coding CUDA, so the module also works on CPU and on a
            # non-default GPU.
            table_device = self.position_embeddings.weight.device
            position_ids = torch.tensor([x], dtype=torch.long,
                                        device=table_device)
        elif isinstance(x, torch.Tensor) and x.dim() == 1:
            position_ids = x
        else:
            seq_len = x.size(1)
            position_ids = torch.arange(
                seq_len, dtype=torch.long, device=x.device) + start_time
        return self.position_embeddings(position_ids)