File size: 878 Bytes
f14e74e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 |
# Copyright © 2023 Apple Inc.
import math
import mlx.core as mx
from mlx.nn.layers.base import Module
class Embedding(Module):
    """A learnable lookup table mapping integer ids to dense vectors.

    Each of the ``num_embeddings`` ids owns one row of ``self.weight``;
    calling the module selects rows from that table by indexing. This is
    the usual way to turn discrete tokens into inputs for a neural network.

    Args:
        num_embeddings (int): Number of distinct ids that can be embedded
            (often called the vocabulary size).
        dims (int): Length of each embedding vector.
    """

    def __init__(self, num_embeddings: int, dims: int):
        super().__init__()
        # The random draws are multiplied by sqrt(1 / dims), so the initial
        # vectors get smaller as the embedding dimension grows.
        std = math.sqrt(1 / dims)
        table_shape = (num_embeddings, dims)
        self.weight = mx.random.normal(table_shape) * std

    def _extra_repr(self):
        # Rendered as "<num_embeddings>, <dims>", read back from the table.
        shape = self.weight.shape
        return f"{shape[0]}, {shape[1]}"

    def __call__(self, x):
        """Look up the embeddings for the ids in ``x``.

        Args:
            x: Index (or array of indices) into the first axis of the table.

        Returns:
            ``self.weight[x]``, i.e. the selected embedding row(s).
        """
        table = self.weight
        return table[x]
|