import torch
import torch.nn.functional as F
from torch import nn


class Embedding(nn.Module):
    r"""A simple embedding lookup layer whose weights are not learned.

    The embedding table is initialized with values drawn from a standard normal
    distribution and kept frozen throughout training: it is registered as a
    non-trainable ``nn.Parameter`` (``requires_grad=False``), so it receives no
    gradient updates but is still saved in the module's ``state_dict``.

    Args:
        num_embeddings (int): Size of the dictionary of embeddings, typically
            the size of the vocabulary.
        embedding_dim (int): The size of each embedding vector.
    """

    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
    ):
        super().__init__()
        # Frozen random table: kept as a Parameter so it appears in state_dict,
        # but with gradients disabled so it is never updated during training.
        self.embeddings = nn.Parameter(
            torch.randn(num_embeddings, embedding_dim), requires_grad=False
        )

    def forward(self, idx: torch.Tensor) -> torch.Tensor:
        r"""Look up embedding vectors for the given indices.

        Args:
            idx (torch.Tensor): An integer tensor containing the indices of the
                embeddings to look up.

        Returns:
            torch.Tensor: The embedding vectors for ``idx``, of shape
            ``(*idx.shape, embedding_dim)``.
        """
        return F.embedding(idx, self.embeddings)
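

if __name__ == "__main__":
    # Minimal usage sketch; the sizes below are arbitrary, chosen for illustration.
    emb = Embedding(num_embeddings=10, embedding_dim=4)
    idx = torch.tensor([[1, 2], [3, 9]])  # integer index tensor of shape (2, 2)
    out = emb(idx)                        # one 4-dim vector per index -> (2, 2, 4)
    print(out.shape)                      # torch.Size([2, 2, 4])
    assert not out.requires_grad          # the frozen table produces no gradients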