import mlx.core as mx |
|
from mlx.nn.layers.base import Module |
|
|
|
|
|
class Dropout(Module): |
|
r"""Randomly zero a portion of the elements during training. |
|
|
|
The remaining elements are multiplied with :math:`\frac{1}{1-p}` where |
|
:math:`p` is the probability of zeroing an element. This is done so the |
|
expected value of a given element will remain the same. |
|
|
|
Args: |
|
        p (float): The probability of zeroing an element.
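
    Example:
        A minimal usage sketch (the aliases ``mx`` for ``mlx.core`` and
        ``nn`` for ``mlx.nn`` are assumed imports, not part of this module)::

            layer = nn.Dropout(p=0.2)
            layer.train()  # dropout is only applied in training mode
            x = mx.random.normal((4, 8))
            y = layer(x)   # ~20% of entries zeroed, survivors scaled by 1/0.8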
|
""" |
|
|
|
def __init__(self, p: float = 0.5): |
|
super().__init__() |
|
|
|
if p < 0 or p >= 1: |
|
raise ValueError(f"The dropout probability {p} is not in [0, 1)") |
|
|
|
self._p_1 = 1 - p |
|
|
|
def _extra_repr(self): |
|
return f"p={1-self._p_1}" |
|
|
|
def __call__(self, x): |
|
if self._p_1 == 1 or not self.training: |
|
return x |
|
|
|
        # Sample a keep-mask (each element is 1 with probability 1 - p) and
        # rescale the survivors so the expected value of the output matches
        # the input.
        mask = mx.random.bernoulli(self._p_1, x.shape)
        return (1 / self._p_1) * mask * x
|
|
|
|
|
class Dropout2d(Module): |
|
r"""Apply 2D channel-wise dropout during training. |
|
|
|
Randomly zero out entire channels independently with probability :math:`p`. |
|
    This layer expects the channels to be last, i.e., the input shape should be
    ``NWHC`` or ``WHC`` where ``N`` is the batch dimension, ``W`` is the input
    image width, ``H`` is the input image height, and ``C`` is the number of
    input channels.
|
|
|
The remaining channels are scaled by :math:`\frac{1}{1-p}` to |
|
maintain the expected value of each element. Unlike traditional dropout, |
|
which zeros individual entries, this layer zeros entire channels. This is |
|
beneficial for early convolution layers where adjacent pixels are |
|
    correlated. In such cases, traditional dropout may not effectively
|
regularize activations. For more details, see [1]. |
|
|
|
    [1]: Tompson, J., Goroshin, R., Jain, A., LeCun, Y. and Bregler, C., 2015.
    Efficient Object Localization Using Convolutional Networks. CVPR 2015.
|
|
|
Args: |
|
p (float): Probability of zeroing a channel during training. |
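
    Example:
        A minimal usage sketch (assuming ``mlx.core`` is imported as ``mx``
        and ``mlx.nn`` as ``nn``)::

            x = mx.random.normal((2, 8, 8, 16))  # NWHC batch of feature maps
            layer = nn.Dropout2d(p=0.25)
            layer.train()
            y = layer(x)  # whole channels zeroed; survivors scaled by 1/0.75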
|
""" |
|
|
|
def __init__(self, p: float = 0.5): |
|
super().__init__() |
|
|
|
if p < 0 or p >= 1: |
|
raise ValueError(f"The dropout probability {p} is not in [0, 1)") |
|
|
|
self._p_1 = 1 - p |
|
|
|
def _extra_repr(self): |
|
return f"p={1-self._p_1}" |
|
|
|
def __call__(self, x): |
|
if x.ndim not in (3, 4): |
|
raise ValueError( |
|
f"Received input with {x.ndim} dimensions. Expected 3 or 4 dimensions." |
|
) |
|
|
|
        if self._p_1 == 1 or not self.training:
            return x

        # Zero entire channels: the mask has size 1 over the spatial dims so
        # it broadcasts across every pixel of a channel.
        mask_shape = list(x.shape)
        mask_shape[-2] = mask_shape[-3] = 1
|
|
|
mask = mx.random.bernoulli(p=self._p_1, shape=mask_shape) |
|
return (1 / self._p_1) * mask * x |
|
|
|
|
|
class Dropout3d(Module): |
|
r"""Apply 3D channel-wise dropout during training. |
|
|
|
Randomly zero out entire channels independently with probability :math:`p`. |
|
This layer expects the channels to be last, i.e., the input shape should be |
|
    ``NDHWC`` or ``DHWC`` where ``N`` is the batch dimension, ``D`` is the depth,
    ``H`` is the input image height, ``W`` is the input image width, and ``C`` is
    the number of input channels.
|
|
|
The remaining channels are scaled by :math:`\frac{1}{1-p}` to |
|
maintain the expected value of each element. Unlike traditional dropout, |
|
which zeros individual entries, this layer zeros entire channels. This is |
|
often beneficial for convolutional layers processing 3D data, like in |
|
medical imaging or video processing. |
|
|
|
Args: |
|
p (float): Probability of zeroing a channel during training. |
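
    Example:
        A minimal usage sketch (assuming ``mlx.core`` is imported as ``mx``
        and ``mlx.nn`` as ``nn``)::

            x = mx.random.normal((2, 4, 8, 8, 16))  # NDHWC volume
            layer = nn.Dropout3d(p=0.5)
            layer.train()
            y = layer(x)  # each channel kept or zeroed as a whole; survivors doubled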
|
""" |
|
|
|
def __init__(self, p: float = 0.5): |
|
super().__init__() |
|
|
|
if p < 0 or p >= 1: |
|
raise ValueError(f"The dropout probability {p} is not in [0, 1)") |
|
|
|
self._p_1 = 1 - p |
|
|
|
def _extra_repr(self): |
|
return f"p={1-self._p_1}" |
|
|
|
def __call__(self, x): |
|
if x.ndim not in (4, 5): |
|
raise ValueError( |
|
f"Received input with {x.ndim} dimensions. Expected 4 or 5 dimensions." |
|
) |
|
|
|
        if self._p_1 == 1 or not self.training:
            return x

        # Zero entire channels: the mask has size 1 over the depth and
        # spatial dims so it broadcasts across every voxel of a channel.
        mask_shape = list(x.shape)
        mask_shape[-2] = mask_shape[-3] = mask_shape[-4] = 1
|
|
|
mask = mx.random.bernoulli(p=self._p_1, shape=mask_shape) |
|
return (1 / self._p_1) * mask * x |
|
|