import numpy as np
import torch
import torch.nn as nn


class Initializer(object):
    """Base initializer: apply to a model with ``model.apply(initializer)``.
    Subclasses override the ``_init_*`` hooks to change the scheme."""

    def __init__(self, local_init=True, gamma=None):
        self.local_init = local_init
        self.gamma = gamma

    def __call__(self, m):
        # Skip modules that a previous pass has already initialized.
        if getattr(m, "__initialized", False):
            return

        # Normalization layers get the gamma/beta rules; any other module
        # with weight/bias tensors gets the generic weight/bias rules.
        if (
            isinstance(
                m,
                (
                    nn.BatchNorm1d,
                    nn.BatchNorm2d,
                    nn.BatchNorm3d,
                    nn.InstanceNorm1d,
                    nn.InstanceNorm2d,
                    nn.InstanceNorm3d,
                    nn.GroupNorm,
                    nn.SyncBatchNorm,
                ),
            )
            or "BatchNorm" in m.__class__.__name__
        ):
            if m.weight is not None:
                self._init_gamma(m.weight.data)
            if m.bias is not None:
                self._init_beta(m.bias.data)
        else:
            if getattr(m, "weight", None) is not None:
                self._init_weight(m.weight.data)
            if getattr(m, "bias", None) is not None:
                self._init_bias(m.bias.data)

        # Flag the module so repeated passes leave it untouched;
        # object.__setattr__ bypasses nn.Module's custom __setattr__.
        if self.local_init:
            object.__setattr__(m, "__initialized", True)

    def _init_weight(self, data):
        # Default: small uniform weights.
        nn.init.uniform_(data, -0.07, 0.07)

    def _init_bias(self, data):
        nn.init.constant_(data, 0)

    def _init_gamma(self, data):
        # Norm-layer scale: constant 1.0, or N(1.0, gamma) when gamma is set.
        if self.gamma is None:
            nn.init.constant_(data, 1.0)
        else:
            nn.init.normal_(data, 1.0, self.gamma)

    def _init_beta(self, data):
        nn.init.constant_(data, 0)


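# Example subclass (a sketch, not part of the original module): the base
# class is a template, so a custom scheme only needs to override the
# _init_weight hook, as Bilinear and XavierGluon below do:
#
#   class HeNormal(Initializer):
#       def _init_weight(self, data):
#           nn.init.kaiming_normal_(data, nonlinearity="relu")

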
class Bilinear(Initializer):
    """Initialize (transposed) convolution weights as a fixed bilinear
    upsampling kernel."""

    def __init__(self, scale, groups, in_channels, **kwargs):
        super().__init__(**kwargs)
        self.scale = scale
        self.groups = groups
        self.in_channels = in_channels

    def _init_weight(self, data):
        """Fill the weight with a bilinear upsampling kernel."""
        bilinear_kernel = self.get_bilinear_kernel(self.scale)
        weight = torch.zeros_like(data)
        for i in range(self.in_channels):
            # With groups (e.g. depthwise), each group has a single output
            # slot at index 0; otherwise channel i maps to channel i.
            j = i if self.groups == 1 else 0
            weight[i, j] = bilinear_kernel
        data[:] = weight

    @staticmethod
    def get_bilinear_kernel(scale):
        """Generate a 2D bilinear upsampling kernel for an integer scale."""
        # Standard FCN-style construction: kernel size, interpolation
        # factor, and the center of the (possibly even-sized) kernel.
        kernel_size = 2 * scale - scale % 2
        scale = (kernel_size + 1) // 2
        center = scale - 0.5 * (1 + kernel_size % 2)

        # Outer product of two 1D triangular (bilinear) profiles.
        og = np.ogrid[:kernel_size, :kernel_size]
        kernel = (1 - np.abs(og[0] - center) / scale) * (
            1 - np.abs(og[1] - center) / scale
        )

        return torch.tensor(kernel, dtype=torch.float32)
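
# Worked example (illustrative, not in the original module): for scale=2
# the helper yields a 4x4 kernel, the outer product of the 1D profile
# [0.25, 0.75, 0.75, 0.25]:
#
#   Bilinear.get_bilinear_kernel(2)
#   # tensor([[0.0625, 0.1875, 0.1875, 0.0625],
#   #         [0.1875, 0.5625, 0.5625, 0.1875],
#   #         [0.1875, 0.5625, 0.5625, 0.1875],
#   #         [0.0625, 0.1875, 0.1875, 0.0625]])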


class XavierGluon(Initializer):
    """Xavier initialization in the style of MXNet Gluon's Xavier
    initializer (configurable distribution, fan mode, and magnitude)."""

    def __init__(self, rnd_type="uniform", factor_type="avg", magnitude=3, **kwargs):
        super().__init__(**kwargs)

        self.rnd_type = rnd_type
        self.factor_type = factor_type
        self.magnitude = float(magnitude)

    def _init_weight(self, arr):
        # Note: _calculate_fan_in_and_fan_out is a private torch.nn.init
        # helper, but it computes exactly the fans needed here.
        fan_in, fan_out = nn.init._calculate_fan_in_and_fan_out(arr)

        if self.factor_type == "avg":
            factor = (fan_in + fan_out) / 2.0
        elif self.factor_type == "in":
            factor = fan_in
        elif self.factor_type == "out":
            factor = fan_out
        else:
            raise ValueError(f"Incorrect factor type: {self.factor_type}")
        # Gluon-style Xavier: the bound/std is sqrt(magnitude / factor).
        scale = np.sqrt(self.magnitude / factor)

        if self.rnd_type == "uniform":
            nn.init.uniform_(arr, -scale, scale)
        elif self.rnd_type == "gaussian":
            nn.init.normal_(arr, 0, scale)
        else:
            raise ValueError(f"Unknown random type: {self.rnd_type}")
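

if __name__ == "__main__":
    # Minimal usage sketch (an illustration, not part of the original
    # module); the toy layers below are assumptions for demonstration.
    model = nn.Sequential(
        nn.Conv2d(3, 8, kernel_size=3, padding=1),
        nn.BatchNorm2d(8),
        nn.ReLU(),
    )
    # Conv weights get Gluon-style Xavier; the BatchNorm layer falls back
    # to the gamma/beta rules from the Initializer base class.
    model.apply(XavierGluon(rnd_type="gaussian", factor_type="in", magnitude=2.0))

    # A depthwise transposed convolution initialized as fixed 2x bilinear
    # upsampling.
    up = nn.ConvTranspose2d(8, 8, kernel_size=4, stride=2, padding=1,
                            groups=8, bias=False)
    up.apply(Bilinear(scale=2, groups=8, in_channels=8))

    x = torch.randn(1, 3, 16, 16)
    y = model(x)
    print(y.shape)      # torch.Size([1, 8, 16, 16])
    print(up(y).shape)  # torch.Size([1, 8, 32, 32])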