import torch
from torch import nn
from torch.nn import Module
import torch.nn.functional as F


class Embedding(Module):
r"""Class represents a simple embedding layer but without any learning of the embeddings.
The embeddings are initialized with random values and kept static throughout training (They are parameters, not model's state).
Args:
num_embeddings (int): Size of the dictionary of embeddings, typically size of the vocabulary.
embedding_dim (int): The size of each embedding vector.
Returns:
torch.Tensor: An output tensor resulting from the lookup operation.
"""
def __init__(
self,
num_embeddings: int,
embedding_dim: int,
):
super().__init__()
        # Randomly initialized embedding matrix of shape (num_embeddings, embedding_dim).
        self.embeddings = nn.Parameter(torch.randn(num_embeddings, embedding_dim))

    def forward(self, idx: torch.Tensor) -> torch.Tensor:
r"""Forward propagation for the Embedding implementation.
Args:
idx (torch.Tensor): A tensor containing the indices of the embeddings to be accessed.
Returns:
torch.Tensor: An output tensor resulting from the lookup operation.
"""
return F.embedding(idx, self.embeddings)
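

# Minimal usage sketch (illustrative, not part of the original module): build a
# small embedding table and look up a batch of token indices. The vocabulary
# size, embedding dimension, and index values below are assumptions chosen only
# to show the expected shapes.
if __name__ == "__main__":
    vocab_size, dim = 10, 4
    emb = Embedding(num_embeddings=vocab_size, embedding_dim=dim)

    # A batch of 2 sequences, each with 3 token indices in [0, vocab_size).
    idx = torch.tensor([[1, 2, 3], [4, 5, 6]])

    # Each index is replaced by its embedding vector: (2, 3) -> (2, 3, 4).
    out = emb(idx)
    print(out.shape)  # torch.Size([2, 3, 4])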