Spaces:
Runtime error
Runtime error
File size: 4,958 Bytes
61c2d32 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 |
import torch
import torch.nn as nn
import torch.nn.functional as functional
try:
from queue import Queue
except ImportError:
from Queue import Queue
from .functions import *
class ABN(nn.Module):
    """Activated Batch Normalization

    This gathers a `BatchNorm2d` and an activation function in a single module: the input is
    normalized with `torch.nn.functional.batch_norm` and the result goes through the selected
    activation, applied in place.
    """

    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, activation="leaky_relu", slope=0.01):
        """Creates an Activated Batch Normalization module

        Parameters
        ----------
        num_features : int
            Number of feature channels in the input and output.
        eps : float
            Small constant passed to `batch_norm` to prevent numerical issues.
        momentum : float
            Momentum factor passed to `batch_norm` to update the running mean and variance
            while training.
        affine : bool
            If `True` apply learned scale and shift transformation after normalization.
        activation : str
            Name of the activation function. `forward` compares it against the `ACT_RELU`,
            `ACT_LEAKY_RELU` and `ACT_ELU` constants from `.functions` (presumably `relu`,
            `leaky_relu` and `elu`); any other value, e.g. `none`, applies no activation.
        slope : float
            Negative slope for the `leaky_relu` activation.
        """
        super(ABN, self).__init__()
        self.num_features = num_features
        self.affine = affine
        self.eps = eps
        self.momentum = momentum
        self.activation = activation
        self.slope = slope
        if self.affine:
            self.weight = nn.Parameter(torch.ones(num_features))
            self.bias = nn.Parameter(torch.zeros(num_features))
        else:
            # Registered as None so `self.weight` / `self.bias` still exist and can be
            # handed to `batch_norm` unconditionally.
            self.register_parameter('weight', None)
            self.register_parameter('bias', None)
        self.register_buffer('running_mean', torch.zeros(num_features))
        self.register_buffer('running_var', torch.ones(num_features))
        self.reset_parameters()

    def reset_parameters(self):
        """Reset running statistics to (0, 1) and, if affine, the scale/shift to (1, 0)."""
        nn.init.constant_(self.running_mean, 0)
        nn.init.constant_(self.running_var, 1)
        if self.affine:
            nn.init.constant_(self.weight, 1)
            nn.init.constant_(self.bias, 0)

    def forward(self, x):
        """Batch-normalize `x` (batch statistics when training, running ones otherwise), then activate."""
        x = functional.batch_norm(x, self.running_mean, self.running_var, self.weight, self.bias,
                                  self.training, self.momentum, self.eps)
        # `batch_norm` returns a new tensor, so the activation may overwrite it in place.
        if self.activation == ACT_RELU:
            return functional.relu(x, inplace=True)
        elif self.activation == ACT_LEAKY_RELU:
            return functional.leaky_relu(x, negative_slope=self.slope, inplace=True)
        elif self.activation == ACT_ELU:
            return functional.elu(x, inplace=True)
        else:
            return x

    def __repr__(self):
        """Return e.g. `ABN(64, eps=1e-05, momentum=0.1, affine=True, activation=leaky_relu, slope=0.01)`."""
        rep = '{name}({num_features}, eps={eps}, momentum={momentum},' \
              ' affine={affine}, activation={activation}'
        # `slope` is only used by leaky_relu, so it is only shown for that activation.
        if self.activation == "leaky_relu":
            rep += ', slope={slope})'
        else:
            rep += ')'
        # Fields are passed explicitly instead of `**self.__dict__`: an instance attribute
        # named `name` would otherwise collide with the `name` keyword and raise TypeError.
        return rep.format(name=self.__class__.__name__, num_features=self.num_features,
                          eps=self.eps, momentum=self.momentum, affine=self.affine,
                          activation=self.activation, slope=self.slope)
class InPlaceABN(ABN):
    """InPlace Activated Batch Normalization

    Shares its parameters, buffers and string representation with `ABN`; only the forward
    pass differs, being delegated to `inplace_abn` from `.functions`.
    """

    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, activation="leaky_relu", slope=0.01):
        """Creates an InPlace Activated Batch Normalization module

        Parameters
        ----------
        num_features : int
            Number of feature channels in the input and output.
        eps : float
            Small constant to prevent numerical issues.
        momentum : float
            Momentum factor used to update the running mean and variance while training.
        affine : bool
            If `True` apply learned scale and shift transformation after normalization.
        activation : str
            Name of the activation functions, one of: `leaky_relu`, `elu` or `none`.
        slope : float
            Negative slope for the `leaky_relu` activation.
        """
        super(InPlaceABN, self).__init__(
            num_features,
            eps=eps,
            momentum=momentum,
            affine=affine,
            activation=activation,
            slope=slope,
        )

    def forward(self, x):
        """Apply normalization and activation through the `inplace_abn` function."""
        # `inplace_abn` yields three values; only the first is the module's output.
        normalized, _, _ = inplace_abn(x, self.weight, self.bias, self.running_mean, self.running_var,
                                       self.training, self.momentum, self.eps, self.activation, self.slope)
        return normalized
class InPlaceABNSync(ABN):
    """InPlace Activated Batch Normalization with cross-GPU synchronization

    This assumes that it will be replicated across GPUs using the same mechanism as in
    `torch.nn.parallel.DistributedDataParallel`.

    Constructor arguments, parameters, buffers and `__repr__` are inherited from `ABN`
    (this class previously re-declared an identical `__repr__`); only the forward pass differs.
    """

    def forward(self, x):
        """Run the synchronized in-place ABN on `x` via `inplace_abn_sync` and return its output."""
        # `inplace_abn_sync` returns three values; only the first is this module's output.
        x, _, _ = inplace_abn_sync(x, self.weight, self.bias, self.running_mean, self.running_var,
                                   self.training, self.momentum, self.eps, self.activation, self.slope)
        return x
|