File size: 2,053 Bytes
616f571
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
import math

import torch
import torch.nn as nn


class PhaseModulatedFourierEmbedder(nn.Module):
    """Embed coordinates with learned-frequency and phase-shifted sinusoids.

    For every input channel the module produces ``num_freqs`` cosine and
    ``num_freqs`` sine features. Each feature is the sum of two terms:

    * a *frequency-modulated* term whose argument is ``x * weight`` (the
      ``weight`` matrix is learnable), and
    * a *phase-modulated* term whose argument is ``x * pi / 2 + carrier``
      (``carrier`` is a fixed, non-learnable buffer).

    The raw input is concatenated in front of the sinusoidal features, so the
    last dimension of the output has size ``input_dim * (2 * num_freqs + 1)``.
    """

    def __init__(self, num_freqs: int, input_dim: int = 3) -> None:
        """
        Initializes the PhaseModulatedFourierEmbedder class.
        Args:
            num_freqs (int): The number of frequencies to be used.
            input_dim (int, optional): The dimension of the input. Defaults to 3.
        Attributes:
            weight (torch.nn.Parameter): Learnable ``(input_dim, num_freqs)``
                matrix, drawn from a normal distribution and scaled by
                ``sqrt(0.5 * num_freqs)``.
            carrier (torch.Tensor): Fixed ``(num_freqs,)`` phase offsets,
                registered as a non-persistent buffer (it is recomputed from
                ``num_freqs`` and never stored in the ``state_dict``).
            out_dim (int): Size of the last output dimension,
                ``input_dim * (2 * num_freqs + 1)``.
        """
        super().__init__()

        self.weight = nn.Parameter(
            torch.randn(input_dim, num_freqs) * math.sqrt(0.5 * num_freqs)
        )

        # NOTE this is the highest frequency we can get (2 for peaks, 2 for
        # zeros, and 4 for interpolation points), see also
        # https://en.wikipedia.org/wiki/Nyquist%E2%80%93Shannon_sampling_theorem
        # The base num_freqs / 8 is raised to exponents going from 1 down to 0,
        # i.e. a geometric progression from num_freqs / 8 to 1 ...
        carrier = (num_freqs / 8) ** torch.linspace(1, 0, num_freqs)
        # ... to which a linear ramp from 0 to 1 is added before scaling by 2*pi.
        carrier = (carrier + torch.linspace(0, 1, num_freqs)) * 2 * torch.pi
        self.register_buffer("carrier", carrier, persistent=False)

        self.out_dim = input_dim * (num_freqs * 2 + 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Perform the forward pass of the embedder model.
        Args:
            x (torch.Tensor): Input tensor of shape (..., input_dim). Any number
                of leading dimensions is accepted, including empty ones, and
                the tensor does not have to be contiguous.
        Returns:
            torch.Tensor: Output tensor of shape (..., out_dim) where
                          out_dim = input_dim * (2 * num_freqs + 1). The first
                          ``input_dim`` channels are ``x`` itself, followed by
                          the cosine features and then the sine features.
        """
        # (..., input_dim, 1) so that it broadcasts against the per-frequency
        # weight (input_dim, num_freqs) and carrier (num_freqs,).
        m = x.float().unsqueeze(-1)

        # Merge the (input_dim, num_freqs) axes into one feature axis.
        # ``flatten`` is used instead of ``view(..., -1)`` because the latter
        # raises for empty batches (ambiguous -1) and for inputs whose memory
        # layout makes the broadcasted product non-viewable.
        fm = (m * self.weight).flatten(-2)
        pm = (m * 0.5 * torch.pi + self.carrier).flatten(-2)

        return torch.cat([x, fm.cos() + pm.cos(), fm.sin() + pm.sin()], dim=-1)