Spaces:
Running
on
Zero
Running
on
Zero
File size: 3,443 Bytes
c9724af |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 |
import torch
import torch.nn as nn
class FrequencyPositionalEmbedding(nn.Module):
    """Sine/cosine positional embedding.

    Given an input tensor `x` of shape [..., input_dim], every feature
    `x[..., i]` is multiplied by each frequency `f_0 ... f_{N-1}`
    (N = num_freqs) and passed through sin and cos. The result is laid out
    along the last axis as three consecutive groups:

        [
            x,                      # only present if include_input is True
            sin(f_0 * x[..., 0]), ..., sin(f_{N-1} * x[..., 0]),
            sin(f_0 * x[..., 1]), ..., sin(f_{N-1} * x[..., 1]),
            ...
            cos(f_0 * x[..., 0]), ..., cos(f_{N-1} * x[..., 0]),
            cos(f_0 * x[..., 1]), ..., cos(f_{N-1} * x[..., 1]),
            ...
        ]

    Frequencies:
        * logspace=True:  f_k = 2^k for k = 0 ... num_freqs - 1;
        * logspace=False: num_freqs values linearly spaced over
          [1.0, 2^(num_freqs - 1)];
        * include_pi=True additionally multiplies every frequency by pi.

    Args:
        num_freqs (int): the number of frequencies, default is 6. With
            num_freqs == 0 the module is an identity mapping;
        logspace (bool): choose powers of two (True) or linear spacing
            (False) for the frequencies, default is True;
        input_dim (int): the input dimension, only used to compute
            `out_dim`, default is 3;
        include_input (bool): prepend the raw input to the embedding or not,
            default is True;
        include_pi (bool): scale the frequencies by pi or not, default is
            True.

    Attributes:
        frequencies (torch.Tensor): float32 buffer of shape [num_freqs]
            holding the frequencies described above. It is registered with
            persistent=False, so it is not part of the state_dict;
        out_dim (int): the embedding size. It is
            input_dim * (num_freqs * 2 + 1) if include_input is True or
            num_freqs is 0, otherwise input_dim * num_freqs * 2.
    """

    def __init__(
        self,
        num_freqs: int = 6,
        logspace: bool = True,
        input_dim: int = 3,
        include_input: bool = True,
        include_pi: bool = True,
    ) -> None:
        """Build the frequency table and record the output size."""
        super().__init__()

        if logspace:
            frequencies = 2.0 ** torch.arange(num_freqs, dtype=torch.float32)
        else:
            frequencies = torch.linspace(
                1.0, 2.0 ** (num_freqs - 1), num_freqs, dtype=torch.float32
            )

        if include_pi:
            frequencies *= torch.pi

        # Non-persistent: the table is fully determined by the constructor
        # arguments, so it is rebuilt instead of being stored in checkpoints.
        self.register_buffer("frequencies", frequencies, persistent=False)
        self.include_input = include_input
        self.num_freqs = num_freqs

        self.out_dim = self.get_dims(input_dim)

    def get_dims(self, input_dim: int) -> int:
        """Return the size of the last axis produced by `forward`.

        Args:
            input_dim: size of the last axis of the input tensor.

        Returns:
            input_dim * (num_freqs * 2 + temp), where temp is 1 if the raw
            input is part of the output (include_input is True, or
            num_freqs is 0 and the input is returned unchanged) and 0
            otherwise.
        """
        keeps_input = self.include_input or self.num_freqs == 0
        return input_dim * (self.num_freqs * 2 + (1 if keeps_input else 0))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward process.

        Args:
            x: tensor of shape [..., dim]

        Returns:
            embedding: an embedding of `x` of shape
                [..., dim * (num_freqs * 2 + temp)] where temp is 1 if
                include_input is True and 0 otherwise. If num_freqs is 0,
                `x` itself is returned.
        """
        if self.num_freqs <= 0:
            return x

        # [..., dim, num_freqs]: every feature scaled by every frequency.
        scaled = x.unsqueeze(-1) * self.frequencies
        # Merge the last two axes using an explicit size. A `-1` placeholder
        # cannot be inferred for zero-element tensors (e.g. an empty batch),
        # and `reshape` also copes with non-contiguous products.
        embed = scaled.reshape(*x.shape[:-1], scaled.shape[-2:].numel())

        if self.include_input:
            return torch.cat((x, embed.sin(), embed.cos()), dim=-1)
        return torch.cat((embed.sin(), embed.cos()), dim=-1)
|