import torch
from torch import nn


class Normalize(nn.Module):
    """Normalizes each row of the input by its power-norm along dim 1."""

    def __init__(self, power=2):
        super(Normalize, self).__init__()
        self.power = power

    def forward(self, x):
        # Per-row norm: (sum_i x_i^power)^(1/power); keepdim=True so it
        # broadcasts against x in the division below.
        norm = x.pow(self.power).sum(1, keepdim=True).pow(1. / self.power)
        out = x.div(norm)
        return out
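

# Usage sketch (not part of the original module): applies Normalize to a
# hypothetical batch of feature vectors so each row ends up with unit L2 norm.
if __name__ == "__main__":
    feats = torch.randn(4, 128)            # assumed batch of 4 feature vectors
    normalized = Normalize(power=2)(feats)
    print(normalized.norm(p=2, dim=1))     # each row norm should be ~1.0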