|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from torch import nn
|
|
|
|
|
|
def drop_path(x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True):
    """Randomly zero whole samples of ``x`` (Stochastic Depth / DropPath).

    One Bernoulli(keep_prob) draw is made per entry along the first dimension.
    That draw is broadcast over all remaining dimensions, so each sample is
    either kept entirely or zeroed entirely.

    Args:
        x: Input tensor. The first dimension is treated as the per-sample axis.
        drop_prob: Probability of dropping a sample. ``None`` or ``0.0``
            disables dropping.
        training: Dropping only happens when this is true. Otherwise ``x`` is
            returned unchanged (the same object).
        scale_by_keep: When true (the default), kept samples are divided by
            ``keep_prob``. This keeps the expected value of the output equal
            to ``x``.

    Returns:
        ``x`` itself when dropping is disabled. Otherwise a new tensor of the
        same shape as ``x``.
    """
    # `None` is accepted as "disabled". DropPath defaults to drop_prob=None,
    # and without this check `1 - None` raises a TypeError in training mode.
    if drop_prob is None or drop_prob == 0.0 or not training:
        return x

    keep_prob = 1.0 - drop_prob
    # Shape (batch, 1, 1, ...) so the per-sample mask broadcasts over all other dims.
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    random_tensor = x.new_empty(shape).bernoulli_(keep_prob)

    # Skip the rescale when keep_prob == 0 to avoid dividing by zero. In that
    # case the mask is already all zeros.
    if keep_prob > 0.0 and scale_by_keep:
        random_tensor.div_(keep_prob)

    return x * random_tensor
|
|
|
|
|
|
class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    Module wrapper around ``drop_path``. Dropping is active only while the
    module is in training mode (``self.training``).
    """

    def __init__(self, drop_prob=None):
        """
        Args:
            drop_prob: Probability of dropping each sample's path. ``None``
                (the default) is treated as ``0.0``, i.e. no dropping.
        """
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        """Apply per-sample path dropping to ``x`` when in training mode."""
        # Normalize the None default here. Passing None straight through would
        # make `1 - drop_prob` raise a TypeError in training mode.
        drop_prob = 0.0 if self.drop_prob is None else self.drop_prob
        return drop_path(x, drop_prob, self.training)

    def extra_repr(self):
        """Show the configured drop probability in the module's repr."""
        return f"drop_prob={self.drop_prob}"
|
|
|