File size: 4,529 Bytes
f14e74e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
# Copyright © 2023 Apple Inc.

import mlx.core as mx
from mlx.nn.layers.base import Module


class Dropout(Module):
    r"""Randomly zero a portion of the elements during training.

    The remaining elements are multiplied with :math:`\frac{1}{1-p}` where
    :math:`p` is the probability of zeroing an element. This is done so the
    expected value of a given element will remain the same.

    Args:
        p (float): The probability to zero an element. Must be in ``[0, 1)``.

    Raises:
        ValueError: If ``p`` is not in ``[0, 1)``.
    """

    def __init__(self, p: float = 0.5):
        super().__init__()

        if p < 0 or p >= 1:
            raise ValueError(f"The dropout probability {p} is not in [0, 1)")

        # Store the keep-probability (1 - p); it is both the Bernoulli
        # parameter for the mask and the inverse of the rescaling factor.
        self._p_1 = 1 - p

    def _extra_repr(self):
        return f"p={1-self._p_1}"

    def __call__(self, x):
        # Identity when nothing would be dropped or the module is in eval mode.
        if self._p_1 == 1 or not self.training:
            return x

        # Each element is kept with probability self._p_1.
        mask = mx.random.bernoulli(self._p_1, x.shape)

        # Apply the mask to x first and scale by the Python scalar last, so
        # the result keeps x's dtype. Scaling the mask first would create a
        # float32 intermediate that promotes half-precision inputs to float32.
        return (mask * x) * (1 / self._p_1)


class Dropout2d(Module):
    r"""Apply 2D channel-wise dropout during training.

    Randomly zero out entire channels independently with probability :math:`p`.
    This layer expects the channels to be last, i.e. the input shape should be
    ``NHWC`` or ``HWC`` where: ``N`` is the batch dimension, ``H`` is the input
    image height, ``W`` is the input image width, and ``C`` is the number of
    input channels.

    The remaining channels are scaled by :math:`\frac{1}{1-p}` to
    maintain the expected value of each element. Unlike traditional dropout,
    which zeros individual entries, this layer zeros entire channels. This is
    beneficial for early convolution layers where adjacent pixels are
    correlated. In such cases, traditional dropout may not effectively
    regularize activations. For more details, see [1].

    [1]: Thompson, J., Goroshin, R., Jain, A., LeCun, Y. and Bregler C., 2015.
    Efficient Object Localization Using Convolutional Networks. CVPR 2015.

    Args:
        p (float): Probability of zeroing a channel during training. Must be
            in ``[0, 1)``.

    Raises:
        ValueError: If ``p`` is not in ``[0, 1)``, or, when called, if the
            input does not have 3 or 4 dimensions.
    """

    def __init__(self, p: float = 0.5):
        super().__init__()

        if p < 0 or p >= 1:
            raise ValueError(f"The dropout probability {p} is not in [0, 1)")

        # Keep-probability (1 - p): Bernoulli parameter and inverse scale.
        self._p_1 = 1 - p

    def _extra_repr(self):
        return f"p={1-self._p_1}"

    def __call__(self, x):
        # The rank is validated even in eval mode so bad inputs fail early.
        if x.ndim not in (3, 4):
            raise ValueError(
                f"Received input with {x.ndim} dimensions. Expected 3 or 4 dimensions."
            )

        if self._p_1 == 1 or not self.training:
            return x

        # Dropout is applied on the whole channel
        # 3D input: (1, 1, C)
        # 4D input: (B, 1, 1, C)
        # Copy into a list: array.shape may be an immutable tuple, which
        # would reject the item assignment below.
        mask_shape = list(x.shape)
        mask_shape[-2] = mask_shape[-3] = 1

        mask = mx.random.bernoulli(p=self._p_1, shape=mask_shape)
        # Mask first, scale by the Python scalar last, so x's dtype is kept
        # (no float32 promotion of half-precision inputs).
        return (mask * x) * (1 / self._p_1)


class Dropout3d(Module):
    r"""Apply 3D channel-wise dropout during training.

    Randomly zero out entire channels independently with probability :math:`p`.
    This layer expects the channels to be last, i.e., the input shape should be
    ``NDHWC`` or ``DHWC`` where: ``N`` is the batch dimension, ``D`` is the
    depth, ``H`` is the input image height, ``W`` is the input image width,
    and ``C`` is the number of input channels.

    The remaining channels are scaled by :math:`\frac{1}{1-p}` to
    maintain the expected value of each element. Unlike traditional dropout,
    which zeros individual entries, this layer zeros entire channels. This is
    often beneficial for convolutional layers processing 3D data, like in
    medical imaging or video processing.

    Args:
        p (float): Probability of zeroing a channel during training. Must be
            in ``[0, 1)``.

    Raises:
        ValueError: If ``p`` is not in ``[0, 1)``, or, when called, if the
            input does not have 4 or 5 dimensions.
    """

    def __init__(self, p: float = 0.5):
        super().__init__()

        if p < 0 or p >= 1:
            raise ValueError(f"The dropout probability {p} is not in [0, 1)")

        # Keep-probability (1 - p): Bernoulli parameter and inverse scale.
        self._p_1 = 1 - p

    def _extra_repr(self):
        return f"p={1-self._p_1}"

    def __call__(self, x):
        # The rank is validated even in eval mode so bad inputs fail early.
        if x.ndim not in (4, 5):
            raise ValueError(
                f"Received input with {x.ndim} dimensions. Expected 4 or 5 dimensions."
            )

        if self._p_1 == 1 or not self.training:
            return x

        # Dropout is applied on the whole channel
        # 4D input: (1, 1, 1, C)
        # 5D input: (B, 1, 1, 1, C)
        mask_shape = list(x.shape)
        mask_shape[-2] = mask_shape[-3] = mask_shape[-4] = 1

        mask = mx.random.bernoulli(p=self._p_1, shape=mask_shape)
        # Mask first, scale by the Python scalar last, so x's dtype is kept
        # (no float32 promotion of half-precision inputs).
        return (mask * x) * (1 / self._p_1)