File size: 878 Bytes
f14e74e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
# Copyright © 2023 Apple Inc.

import math

import mlx.core as mx
from mlx.nn.layers.base import Module


class Embedding(Module):
    """A lookup table from integer indices to fixed-size vectors.

    Every index in ``[0, num_embeddings)`` selects one row of ``weight``, a
    ``(num_embeddings, dims)`` array. The usual use is turning discrete
    tokens into vectors that a neural network can consume.

    Args:
        num_embeddings (int): Number of distinct tokens that can be embedded,
                              i.e. the number of rows in the table. Commonly
                              referred to as the vocabulary size.
        dims (int): Length of each embedding vector, i.e. the number of
                    columns in the table.
    """

    def __init__(self, num_embeddings: int, dims: int):
        super().__init__()
        # Samples from ``mx.random.normal`` are multiplied by sqrt(1 / dims),
        # so the magnitude of the initial entries shrinks as ``dims`` grows.
        init_scale = math.sqrt(1 / dims)
        table_shape = (num_embeddings, dims)
        self.weight = mx.random.normal(table_shape) * init_scale

    def _extra_repr(self):
        # Report the table as "<rows>, <columns>", read from the weight
        # itself rather than from the constructor arguments.
        table_shape = self.weight.shape
        return f"{table_shape[0]}, {table_shape[1]}"

    def __call__(self, x):
        # Plain indexing into the table; ``x`` is presumably an array of
        # integer token ids (not validated here).
        table = self.weight
        return table[x]