# Copyright © 2023 Apple Inc.

import mlx.core as mx
from mlx.nn.layers.base import Module


class Dropout(Module):
r"""Randomly zero a portion of the elements during training.
The remaining elements are multiplied with :math:`\frac{1}{1-p}` where
:math:`p` is the probability of zeroing an element. This is done so the
expected value of a given element will remain the same.
Args:
p (float): The probability to zero an element
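
    Example:
        A minimal sketch of intended usage (shapes and values are
        illustrative, not prescriptive)::

            x = mx.random.normal((4, 8))
            layer = Dropout(p=0.2)
            y = layer(x)  # training mode (the default): entries are zeroed
                          # with probability 0.2, survivors scaled by 1/0.8
            layer.eval()
            z = layer(x)  # evaluation mode: the layer is the identity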
"""

    def __init__(self, p: float = 0.5):
super().__init__()
if p < 0 or p >= 1:
raise ValueError(f"The dropout probability {p} is not in [0, 1)")
self._p_1 = 1 - p

    def _extra_repr(self):
return f"p={1-self._p_1}"

    def __call__(self, x):
if self._p_1 == 1 or not self.training:
return x
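
        # Each element survives with probability 1 - p; scaling the survivors
        # by 1 / (1 - p) keeps the expected value of the output equal to the input.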
mask = mx.random.bernoulli(self._p_1, x.shape)
return (1 / self._p_1) * mask * x


class Dropout2d(Module):
r"""Apply 2D channel-wise dropout during training.
Randomly zero out entire channels independently with probability :math:`p`.
This layer expects the channels to be last, i.e. the input shape should be
``NWHC`` or ``WHC`` where:``N`` is the batch dimension,``H`` is the input
image height,``W`` is the input image width, and``C`` is the number of
input channels
The remaining channels are scaled by :math:`\frac{1}{1-p}` to
maintain the expected value of each element. Unlike traditional dropout,
which zeros individual entries, this layer zeros entire channels. This is
beneficial for early convolution layers where adjacent pixels are
correlated. In such case, traditional dropout may not effectively
regularize activations. For more details, see [1].
[1]: Thompson, J., Goroshin, R., Jain, A., LeCun, Y. and Bregler C., 2015.
Efficient Object Localization Using Convolutional Networks. CVPR 2015.
Args:
p (float): Probability of zeroing a channel during training.
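
    Example:
        A minimal sketch of intended usage (shapes are illustrative)::

            x = mx.random.normal((2, 32, 32, 16))  # NHWC image batch
            layer = Dropout2d(p=0.5)
            y = layer(x)  # each of the 16 channels is zeroed with
                          # probability 0.5; survivors are scaled by 2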
"""

    def __init__(self, p: float = 0.5):
super().__init__()
if p < 0 or p >= 1:
raise ValueError(f"The dropout probability {p} is not in [0, 1)")
self._p_1 = 1 - p

    def _extra_repr(self):
return f"p={1-self._p_1}"

    def __call__(self, x):
if x.ndim not in (3, 4):
raise ValueError(
f"Received input with {x.ndim} dimensions. Expected 3 or 4 dimensions."
)
if self._p_1 == 1 or not self.training:
return x
        # Dropout is applied to entire channels, so the mask broadcasts over
        # the spatial dimensions. Mask shapes:
        #   3D input (HWC):  (1, 1, C)
        #   4D input (NHWC): (N, 1, 1, C)
        mask_shape = list(x.shape)
        mask_shape[-2] = mask_shape[-3] = 1
mask = mx.random.bernoulli(p=self._p_1, shape=mask_shape)
return (1 / self._p_1) * mask * x


class Dropout3d(Module):
r"""Apply 3D channel-wise dropout during training.
Randomly zero out entire channels independently with probability :math:`p`.
This layer expects the channels to be last, i.e., the input shape should be
`NDHWC` or `DHWC` where: `N` is the batch dimension, `D` is the depth,
`H` is the input image height, `W` is the input image width, and `C` is
the number of input channels.
The remaining channels are scaled by :math:`\frac{1}{1-p}` to
maintain the expected value of each element. Unlike traditional dropout,
which zeros individual entries, this layer zeros entire channels. This is
often beneficial for convolutional layers processing 3D data, like in
medical imaging or video processing.
Args:
p (float): Probability of zeroing a channel during training.
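
    Example:
        A minimal sketch of intended usage (shapes are illustrative)::

            x = mx.random.normal((2, 8, 32, 32, 16))  # NDHWC volume batch
            layer = Dropout3d(p=0.5)
            y = layer(x)  # entire channels are zeroed with probability 0.5;
                          # surviving channels are scaled by 2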
"""

    def __init__(self, p: float = 0.5):
super().__init__()
if p < 0 or p >= 1:
raise ValueError(f"The dropout probability {p} is not in [0, 1)")
self._p_1 = 1 - p

    def _extra_repr(self):
return f"p={1-self._p_1}"

    def __call__(self, x):
if x.ndim not in (4, 5):
raise ValueError(
f"Received input with {x.ndim} dimensions. Expected 4 or 5 dimensions."
)
if self._p_1 == 1 or not self.training:
return x
        # Dropout is applied to entire channels, so the mask broadcasts over
        # the spatial dimensions. Mask shapes:
        #   4D input (DHWC):  (1, 1, 1, C)
        #   5D input (NDHWC): (N, 1, 1, 1, C)
mask_shape = list(x.shape)
mask_shape[-2] = mask_shape[-3] = mask_shape[-4] = 1
mask = mx.random.bernoulli(p=self._p_1, shape=mask_shape)
return (1 / self._p_1) * mask * x