import torch
from torch.autograd import Variable
from torch import nn


class Normalize(nn.Module):
    """Divide the input by its L-p norm taken along dimension 1.

    Args:
        power: Exponent ``p`` of the norm (``2`` gives the Euclidean norm).
        eps: Lower bound applied to the norm before dividing, so that
            all-zero slices map to zeros instead of NaN.
    """

    # Class-level default: instances that bypass __init__ (e.g. modules
    # unpickled from checkpoints saved before `eps` existed) still resolve it.
    eps = 1e-12

    def __init__(self, power=2, eps=1e-12):
        super(Normalize, self).__init__()
        self.power = power
        self.eps = eps

    def forward(self, x):
        """Return ``x`` divided by its L-p norm over dimension 1.

        The norm is computed with ``keepdim=True`` so it broadcasts against
        ``x``; the output has the same shape as ``x``.
        """
        # abs() keeps the summed terms non-negative. Without it, odd powers
        # applied to negative entries give a signed sum: the wrong norm for
        # p=1, and NaN from the fractional root of a negative number for p=3.
        # For even powers the result is unchanged.
        norm = x.abs().pow(self.power).sum(1, keepdim=True).pow(1.0 / self.power)
        # Clamp so a zero norm yields 0 / eps = 0 rather than 0 / 0 = NaN.
        return x.div(norm.clamp(min=self.eps))