""" Weights normalization modules """ |
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
from torch.nn import Parameter |
|
|
|
|
|
def get_var_maybe_avg(namespace, var_name, training, polyak_decay):
    """ utility for retrieving polyak averaged params

    Updates the running average of `var_name` in place, then returns the
    raw parameter during training and the polyak-averaged copy otherwise.
    """
    v = getattr(namespace, var_name)
    v_avg = getattr(namespace, var_name + '_avg')
    # exponential moving average:
    #   v_avg = polyak_decay * v_avg + (1 - polyak_decay) * v
    v_avg -= (1 - polyak_decay) * (v_avg - v.data)

    if training:
        return v
    else:
        return v_avg


def get_vars_maybe_avg(namespace, var_names, training, polyak_decay):
    """ utility for retrieving a list of polyak averaged params """
    vars = []
    for vn in var_names:
        vars.append(get_var_maybe_avg(
            namespace, vn, training, polyak_decay))
    return vars
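

# A minimal sketch (not part of the original module) showing that the
# in-place update in get_var_maybe_avg is the usual exponential moving
# average v_avg = polyak_decay * v_avg + (1 - polyak_decay) * v.
# The function below is purely illustrative and its name is hypothetical.
def _polyak_update_example(polyak_decay=0.9995):
    v = torch.randn(3)
    v_avg = torch.zeros(3)
    v_avg -= (1 - polyak_decay) * (v_avg - v)  # same update as above
    expected = (1 - polyak_decay) * v          # since v_avg started at zero
    assert torch.allclose(v_avg, expected)
    return v_avg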
class WeightNormLinear(nn.Linear):
    """
    Implementation of "Weight Normalization: A Simple Reparameterization
    to Accelerate Training of Deep Neural Networks"
    :cite:`DBLP:journals/corr/SalimansK16`

    Like batch normalization, weight normalization is a reparameterization
    of the weights that accelerates training, but it does not depend on
    the minibatch.

    NOTE: this is not used anywhere in the code at this stage
    Vincent Nguyen 05/18/2018
    """

    def __init__(self, in_features, out_features,
                 init_scale=1., polyak_decay=0.9995):
        super(WeightNormLinear, self).__init__(
            in_features, out_features, bias=True)

        # weight is reparameterized as w = g * V / ||V|| per output unit
        self.V = self.weight
        self.g = Parameter(torch.Tensor(out_features))
        self.b = self.bias

        # polyak-averaged copies used at evaluation time
        self.register_buffer(
            'V_avg', torch.zeros(out_features, in_features))
        self.register_buffer('g_avg', torch.zeros(out_features))
        self.register_buffer('b_avg', torch.zeros(out_features))

        self.init_scale = init_scale
        self.polyak_decay = polyak_decay
        self.reset_parameters()

    def reset_parameters(self):
        # parameters are initialized data-dependently on the first
        # forward pass with init=True
        return

    def forward(self, x, init=False):
        if init is True:
            # data-dependent initialization (Salimans & Kingma, 2016)
            # V: out_features * in_features
            self.V.data.copy_(torch.randn(self.V.data.size()).type_as(
                self.V.data) * 0.05)
            # normalize each row of V
            v_norm = self.V.data / \
                self.V.data.norm(2, 1).expand_as(self.V.data)
            # x_init: batch_size * out_features
            x_init = F.linear(x, v_norm).data
            # per-output-unit mean and variance over the batch
            m_init, v_init = x_init.mean(0).squeeze(
                0), x_init.var(0).squeeze(0)
            # choose g and b so the initial outputs have the desired scale
            scale_init = self.init_scale / \
                torch.sqrt(v_init + 1e-10)
            self.g.data.copy_(scale_init)
            self.b.data.copy_(-m_init * scale_init)
            x_init = scale_init.view(1, -1).expand_as(x_init) \
                * (x_init - m_init.view(1, -1).expand_as(x_init))
            self.V_avg.copy_(self.V.data)
            self.g_avg.copy_(self.g.data)
            self.b_avg.copy_(self.b.data)
            return x_init
        else:
            v, g, b = get_vars_maybe_avg(self, ['V', 'g', 'b'],
                                         self.training,
                                         polyak_decay=self.polyak_decay)
            # x: batch_size * out_features
            x = F.linear(x, v)
            scalar = g / torch.norm(v, 2, 1).squeeze(1)
            x = scalar.view(1, -1).expand_as(x) * x + \
                b.view(1, -1).expand_as(x)
            return x
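

# Illustrative usage sketch (not part of the original module, assuming the
# older PyTorch API this module was written against): the first call with
# init=True runs the data-dependent initialization on a warm-up batch;
# later calls use the raw parameters in training mode and the
# polyak-averaged ones in eval mode. Sizes below are arbitrary.
def _weight_norm_linear_example():
    layer = WeightNormLinear(16, 8)
    layer(torch.randn(32, 16), init=True)   # warm-up / initialization pass
    return layer(torch.randn(4, 16))         # regular forward pass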
class WeightNormConv2d(nn.Conv2d):
    """ Conv2d layer with weight normalization (see WeightNormLinear) """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, init_scale=1.,
                 polyak_decay=0.9995):
        super(WeightNormConv2d, self).__init__(in_channels, out_channels,
                                               kernel_size, stride, padding,
                                               dilation, groups)

        # weight is reparameterized as w = g * V / ||V|| per output channel
        self.V = self.weight
        self.g = Parameter(torch.Tensor(out_channels))
        self.b = self.bias

        # polyak-averaged copies used at evaluation time
        self.register_buffer('V_avg', torch.zeros(self.V.size()))
        self.register_buffer('g_avg', torch.zeros(out_channels))
        self.register_buffer('b_avg', torch.zeros(out_channels))

        self.init_scale = init_scale
        self.polyak_decay = polyak_decay
        self.reset_parameters()

    def reset_parameters(self):
        # parameters are initialized data-dependently on the first
        # forward pass with init=True
        return

    def forward(self, x, init=False):
        if init is True:
            # data-dependent initialization
            # V: out_channels * in_channels * kernel_size
            self.V.data.copy_(torch.randn(self.V.data.size()
                                          ).type_as(self.V.data) * 0.05)
            # normalize each output-channel filter of V
            v_norm = self.V.data / self.V.data.view(self.out_channels, -1)\
                .norm(2, 1).view(self.out_channels, *(
                    [1] * (len(self.kernel_size) + 1))).expand_as(self.V.data)
            x_init = F.conv2d(x, v_norm, None, self.stride,
                              self.padding, self.dilation, self.groups).data
            # per-output-channel mean and variance over batch and positions
            t_x_init = x_init.transpose(0, 1).contiguous().view(
                self.out_channels, -1)
            m_init, v_init = t_x_init.mean(1).squeeze(
                1), t_x_init.var(1).squeeze(1)
            # choose g and b so the initial outputs have the desired scale
            scale_init = self.init_scale / \
                torch.sqrt(v_init + 1e-10)
            self.g.data.copy_(scale_init)
            self.b.data.copy_(-m_init * scale_init)
            scale_init_shape = scale_init.view(
                1, self.out_channels, *([1] * (len(x_init.size()) - 2)))
            m_init_shape = m_init.view(
                1, self.out_channels, *([1] * (len(x_init.size()) - 2)))
            x_init = scale_init_shape.expand_as(
                x_init) * (x_init - m_init_shape.expand_as(x_init))
            self.V_avg.copy_(self.V.data)
            self.g_avg.copy_(self.g.data)
            self.b_avg.copy_(self.b.data)
            return x_init
        else:
            v, g, b = get_vars_maybe_avg(
                self, ['V', 'g', 'b'], self.training,
                polyak_decay=self.polyak_decay)

            # per-output-channel norm of V
            scalar = torch.norm(v.view(self.out_channels, -1), 2, 1)
            if len(scalar.size()) == 2:
                scalar = g / scalar.squeeze(1)
            else:
                scalar = g / scalar

            w = scalar.view(self.out_channels,
                            *([1] * (len(v.size()) - 1))).expand_as(v) * v

            x = F.conv2d(x, w, b, self.stride,
                         self.padding, self.dilation, self.groups)
            return x
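

# Illustrative usage sketch (not part of the original module, assuming the
# older PyTorch API this module was written against), mirroring the linear
# example: one warm-up batch with init=True, then regular calls. All sizes
# are arbitrary.
def _weight_norm_conv2d_example():
    conv = WeightNormConv2d(3, 8, kernel_size=3, padding=1)
    conv(torch.randn(16, 3, 32, 32), init=True)   # warm-up / initialization
    return conv(torch.randn(2, 3, 32, 32))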
class WeightNormConvTranspose2d(nn.ConvTranspose2d):
    """ ConvTranspose2d layer with weight normalization
        (see WeightNormLinear)
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, output_padding=0, groups=1, init_scale=1.,
                 polyak_decay=0.9995):
        super(WeightNormConvTranspose2d, self).__init__(
            in_channels, out_channels,
            kernel_size, stride,
            padding, output_padding,
            groups)

        # weight is reparameterized as w = g * V / ||V|| per output channel;
        # V has shape in_channels * out_channels * kernel_size
        self.V = self.weight
        self.g = Parameter(torch.Tensor(out_channels))
        self.b = self.bias

        # polyak-averaged copies used at evaluation time
        self.register_buffer('V_avg', torch.zeros(self.V.size()))
        self.register_buffer('g_avg', torch.zeros(out_channels))
        self.register_buffer('b_avg', torch.zeros(out_channels))

        self.init_scale = init_scale
        self.polyak_decay = polyak_decay
        self.reset_parameters()

    def reset_parameters(self):
        # parameters are initialized data-dependently on the first
        # forward pass with init=True
        return

    def forward(self, x, init=False):
        if init is True:
            # data-dependent initialization
            self.V.data.copy_(torch.randn(self.V.data.size()).type_as(
                self.V.data) * 0.05)
            # per-output-channel norms of V, broadcast back over the
            # in_channels and kernel dimensions
            v_norm = self.V.data / self.V.data.transpose(0, 1).contiguous() \
                .view(self.out_channels, -1).norm(2, 1).view(
                    1, self.out_channels,
                    *([1] * len(self.kernel_size))).expand_as(self.V.data)
            x_init = F.conv_transpose2d(
                x, v_norm, None, self.stride,
                self.padding, self.output_padding, self.groups).data
            # per-output-channel mean and variance over batch and positions
            t_x_init = x_init.transpose(0, 1).contiguous().view(
                self.out_channels, -1)
            m_init, v_init = t_x_init.mean(1).squeeze(
                1), t_x_init.var(1).squeeze(1)
            # choose g and b so the initial outputs have the desired scale
            scale_init = self.init_scale / \
                torch.sqrt(v_init + 1e-10)
            self.g.data.copy_(scale_init)
            self.b.data.copy_(-m_init * scale_init)
            scale_init_shape = scale_init.view(
                1, self.out_channels, *([1] * (len(x_init.size()) - 2)))
            m_init_shape = m_init.view(
                1, self.out_channels, *([1] * (len(x_init.size()) - 2)))
            x_init = scale_init_shape.expand_as(x_init) \
                * (x_init - m_init_shape.expand_as(x_init))
            self.V_avg.copy_(self.V.data)
            self.g_avg.copy_(self.g.data)
            self.b_avg.copy_(self.b.data)
            return x_init
        else:
            v, g, b = get_vars_maybe_avg(
                self, ['V', 'g', 'b'], self.training,
                polyak_decay=self.polyak_decay)
            scalar = g / \
                torch.norm(v.transpose(0, 1).contiguous().view(
                    self.out_channels, -1), 2, 1).squeeze(1)
            # scalar has one entry per output channel (dim 1 of V)
            w = scalar.view(1, self.out_channels,
                            *([1] * (len(v.size()) - 2))).expand_as(v) * v

            x = F.conv_transpose2d(x, w, b, self.stride,
                                   self.padding, self.output_padding,
                                   self.groups)
            return x
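

# Illustrative usage sketch (not part of the original module, assuming the
# older PyTorch API this module was written against): as with the other
# layers, one warm-up batch with init=True, then regular calls. All sizes
# are arbitrary.
def _weight_norm_conv_transpose2d_example():
    deconv = WeightNormConvTranspose2d(8, 3, kernel_size=4, stride=2,
                                       padding=1)
    deconv(torch.randn(16, 8, 16, 16), init=True)  # warm-up / initialization
    return deconv(torch.randn(2, 8, 16, 16))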