import torch | |
from torch.autograd import Variable | |
from torch import nn | |
class Normalize(nn.Module):
    """Rescale the input so each slice along ``dim`` has unit Lp norm.

    The Lp norm is computed as ``(sum |x|**power) ** (1 / power)`` over
    ``dim``, kept as a broadcastable size-1 axis, and the input is divided
    by it.

    Args:
        power: The ``p`` of the Lp norm (default 2, i.e. Euclidean).
        dim: Dimension to reduce over (default 1, the previous hard-coded
            behaviour).
        eps: Lower bound applied to the norm before dividing. An all-zero
            slice therefore maps to zeros instead of NaN (0 / 0).
    """

    def __init__(self, power=2, dim=1, eps=1e-12):
        super(Normalize, self).__init__()
        self.power = power
        self.dim = dim
        self.eps = eps

    def forward(self, x):
        """Return ``x`` divided by its Lp norm along ``self.dim``.

        ``abs()`` is taken before raising to ``power``. For odd powers this
        keeps negative entries from cancelling out. For fractional powers
        it keeps them from producing NaN. For even integer powers it
        changes nothing.
        """
        norm = (
            x.abs()
            .pow(self.power)
            .sum(self.dim, keepdim=True)
            .pow(1.0 / self.power)
        )
        # Clamp so zero-norm slices divide by eps (yielding 0) rather than 0.
        return x.div(norm.clamp(min=self.eps))