|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import torch |
|
import torch.nn as nn |
|
from functools import partialmethod |
|
from typing import Union, List |
|
|
|
|
|
class Dropout(nn.Module):
    """
    Implementation of dropout with the ability to share the dropout mask
    along a particular dimension.

    A single mask is sampled with size 1 along every dimension listed in
    ``batch_dim`` and then broadcast against the input, so all entries along
    those dimensions are dropped (or kept) together. Kept entries are scaled
    by ``1 / (1 - r)`` exactly as ``nn.Dropout`` does.

    If not in training mode, this module computes the identity function.
    The input tensor is never modified in place.
    """

    def __init__(self, r: float, batch_dim: Union[int, List[int]]):
        """
        Args:
            r:
                Dropout rate
            batch_dim:
                Dimension(s) along which the dropout mask is shared. May be
                a single int, a list/tuple of ints (negative indices are
                allowed), or None to disable sharing (ordinary elementwise
                dropout).
        """
        super().__init__()

        self.r = r
        if isinstance(batch_dim, int):
            batch_dim = [batch_dim]
        elif batch_dim is not None:
            # Normalize tuples/other sequences to a private list so that later
            # mutation of the caller's object cannot change this module.
            batch_dim = list(batch_dim)
        self.batch_dim = batch_dim
        self.dropout = nn.Dropout(self.r)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x:
                Tensor to which dropout is applied. Can have any shape
                compatible with self.batch_dim (every listed dim must be a
                valid index into x.shape).
        Returns:
            In training mode, a new tensor ``x * mask``. In eval mode, ``x``
            itself, unchanged.
        """
        if not self.training:
            # nn.Dropout is the identity in eval mode, so the mask would be all
            # ones; skip allocating it.
            return x

        shape = list(x.shape)
        if self.batch_dim is not None:
            for bd in self.batch_dim:
                # Size-1 dims broadcast, sharing the mask along them.
                shape[bd] = 1

        mask = self.dropout(x.new_ones(shape))

        # Out-of-place on purpose: `x *= mask` would silently mutate the
        # caller's tensor. Under autograd it also raises for leaf tensors that
        # require grad, and for tensors saved for backward by the op that
        # produced them.
        return x * mask
|
|
|
|
|
class DropoutRowwise(Dropout):
    """
    Convenience class for rowwise dropout as described in subsection
    1.11.6.

    Behaves like ``Dropout(r, batch_dim=-3)``: the dropout mask is shared
    along the third-to-last dimension of the input. ``batch_dim`` may still
    be overridden, but only by keyword.
    """

    def __init__(self, r: float, *, batch_dim: Union[int, List[int]] = -3):
        super().__init__(r, batch_dim=batch_dim)
|
|
|
|
|
class DropoutColumnwise(Dropout):
    """
    Convenience class for columnwise dropout as described in subsection
    1.11.6.

    Behaves like ``Dropout(r, batch_dim=-2)``: the dropout mask is shared
    along the second-to-last dimension of the input. ``batch_dim`` may still
    be overridden, but only by keyword.
    """

    def __init__(self, r: float, *, batch_dim: Union[int, List[int]] = -2):
        super().__init__(r, batch_dim=batch_dim)
|
|