|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Utility functions for blocks.""" |
|
|
|
from __future__ import division
from __future__ import unicode_literals

import math
import numbers

import numpy as np
import six
import tensorflow as tf
|
|
|
|
|
class RsqrtInitializer(object):
  """Gaussian initializer with standard deviation 1/sqrt(n).

  Here n is the product of the sizes of the selected `dims` of the requested
  shape. Note that tf.truncated_normal is used internally. Therefore any
  random sample outside two-sigma will be discarded and re-sampled.
  """

  def __init__(self, dims=(0,), **kwargs):
    """Creates an initializer.

    Args:
      dims: Dimension(s) index to compute standard deviation:
        1.0 / sqrt(product(shape[dims])). Either a single integer (numpy
        integer scalars included) or an iterable of integers.
      **kwargs: Extra keyword arguments to pass to tf.truncated_normal.
    """
    # numbers.Integral also matches numpy integer scalars (e.g. np.int64),
    # which six.integer_types misses on Python 3; such a value would be
    # stored as-is and then fail when iterated in __call__.
    if isinstance(dims, numbers.Integral):
      self._dims = (int(dims),)
    else:
      # Materialize once so a one-shot iterable (e.g. a generator) is not
      # exhausted by the first call, which would make every later call
      # silently fall back to stddev 1.0.
      self._dims = tuple(dims)
    self._kwargs = kwargs

  def _stddev(self, shape):
    """Returns 1 / sqrt(product of shape[d] over the configured dims)."""
    fan = np.prod([shape[d] for d in self._dims])
    return 1.0 / np.sqrt(fan)

  def __call__(self, shape, dtype):
    """Samples a `shape`/`dtype` tensor from tf.truncated_normal."""
    return tf.truncated_normal(
        shape=shape, dtype=dtype, stddev=self._stddev(shape), **self._kwargs)
|
|
|
|
|
class RectifierInitializer(object):
  """Gaussian initializer with standard deviation sqrt(scale/fan_in).

  With the default scale of 2.0 this is sqrt(2/fan_in). Note that
  tf.random_normal (not tf.truncated_normal) is used internally to ensure the
  expected weight distribution. This is intended to be used with ReLU
  activations, especially in ResNets.

  For details please refer to:
    Delving Deep into Rectifiers: Surpassing Human-Level Performance on
    ImageNet Classification (He et al., 2015, arXiv:1502.01852).
  """

  def __init__(self, dims=(0,), scale=2.0, **kwargs):
    """Creates an initializer.

    Args:
      dims: Dimension(s) index to compute standard deviation:
        sqrt(scale / product(shape[dims])). Either a single integer (numpy
        integer scalars included) or an iterable of integers.
      scale: A constant scaling for the initialization used as
        sqrt(scale / product(shape[dims])).
      **kwargs: Extra keyword arguments to pass to tf.random_normal.
    """
    # numbers.Integral also matches numpy integer scalars (e.g. np.int32),
    # which six.integer_types misses on Python 3; such a value would be
    # stored as-is and then fail when iterated in __call__.
    if isinstance(dims, numbers.Integral):
      self._dims = (int(dims),)
    else:
      # Materialize once so a one-shot iterable (e.g. a generator) is not
      # exhausted by the first call, which would make every later call
      # silently use a fan-in of 1.
      self._dims = tuple(dims)
    self._kwargs = kwargs
    self._scale = scale

  def _stddev(self, shape):
    """Returns sqrt(scale / product of shape[d] over the configured dims)."""
    fan_in = np.prod([shape[d] for d in self._dims])
    return np.sqrt(self._scale / fan_in)

  def __call__(self, shape, dtype):
    """Samples a `shape`/`dtype` tensor from tf.random_normal."""
    return tf.random_normal(
        shape=shape, dtype=dtype, stddev=self._stddev(shape), **self._kwargs)
|
|
|
|
|
class GaussianInitializer(object):
  """Gaussian initializer with a given standard deviation.

  Note that tf.truncated_normal is used internally. Therefore any random sample
  outside two-sigma will be discarded and re-sampled.
  """

  def __init__(self, stddev=1.0, **kwargs):
    """Creates an initializer.

    Args:
      stddev: Standard deviation passed to tf.truncated_normal.
      **kwargs: Extra keyword arguments to pass to tf.truncated_normal
        (e.g. `mean` or `seed`). Defaults to none, which preserves the
        previous behavior.
    """
    self._stddev = stddev
    self._kwargs = kwargs

  def __call__(self, shape, dtype):
    """Samples a `shape`/`dtype` tensor from tf.truncated_normal."""
    return tf.truncated_normal(
        shape=shape, dtype=dtype, stddev=self._stddev, **self._kwargs)
|
|