File size: 1,514 Bytes
c10198c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 |
import torch.nn as nn
import torch
from torch.nn.utils.weight_norm import WeightNorm
class distLinear(nn.Module):
    """Cosine-similarity classification head.

    The input rows are L2-normalised and multiplied by the weight matrix of a
    bias-free linear layer, then the result is multiplied by a constant scale
    factor so the (bounded) cosine values span a range that is useful as
    softmax logits.

    Args:
        indim: Size of each input feature vector.
        outdim: Number of output scores (classes).
        scale_factor: Multiplier applied to the cosine scores. When ``None``
            (the default) the original heuristic is used: ``2`` if
            ``outdim <= 200`` else ``10`` (a larger factor is required for
            >1000 output classes, e.g. omniglot). The original notes say to
            use ``4`` to reproduce the CUB / ResNet10 result (issue #31 of the
            upstream repository).
        class_wise_learnable_norm: When ``True`` (the default) weight
            normalisation is applied to the linear layer, which splits the
            weight into a direction and a learnable per-class norm; the scores
            are then cosine values multiplied by that norm (issues #4 and #8
            of the upstream repository). When ``False`` the weight rows are
            re-normalised to unit length on every forward pass, giving plain
            cosine scores.
    """

    def __init__(self, indim, outdim, scale_factor=None, class_wise_learnable_norm=True):
        super(distLinear, self).__init__()
        self.L = nn.Linear(indim, outdim, bias=False)
        self.class_wise_learnable_norm = class_wise_learnable_norm
        if self.class_wise_learnable_norm:
            # Split the weight update component into direction and norm.
            WeightNorm.apply(self.L, 'weight', dim=0)

        if scale_factor is None:
            # Fixed factor that scales the cosine output into a reasonably
            # large input for softmax; many-class problems need a larger one.
            scale_factor = 2 if outdim <= 200 else 10
        self.scale_factor = scale_factor

    def forward(self, x):
        """Return the scaled cosine scores for ``x``.

        ``x`` is normalised along dimension 1; a small epsilon is added to the
        norm so all-zero rows do not divide by zero.
        """
        x_norm = torch.norm(x, p=2, dim=1).unsqueeze(1).expand_as(x)
        x_normalized = x.div(x_norm + 0.00001)

        if not self.class_wise_learnable_norm:
            # Without weight normalisation, overwrite the stored weights with
            # their row-normalised version (done on .data, outside autograd).
            weight = self.L.weight.data
            L_norm = torch.norm(weight, p=2, dim=1).unsqueeze(1).expand_as(weight)
            self.L.weight.data = weight.div(L_norm + 0.00001)

        # Matrix product via the linear layer; with WeightNorm this also
        # multiplies the cosine distance by the class-wise learnable norm.
        cos_dist = self.L(x_normalized)
        scores = self.scale_factor * cos_dist
        return scores