"""
An implementation of sparsemax (Martins & Astudillo, 2016). See
:cite:`DBLP:journals/corr/MartinsA16` for detailed description.
By Ben Peters and Vlad Niculae
"""
import torch
import torch.nn as nn
from torch.autograd import Function
from torch.cuda.amp import custom_fwd, custom_bwd


def _make_ix_like(input, dim=0):
    """Return a 1-based arange of size ``input.size(dim)``, shaped so that
    it broadcasts against ``input`` along ``dim``."""
    d = input.size(dim)
    rho = torch.arange(1, d + 1, device=input.device, dtype=input.dtype)
    view = [1] * input.dim()
    view[0] = -1
    return rho.view(view).transpose(0, dim)
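# For example, for an input of shape (3, 5) and dim=1, _make_ix_like returns
# tensor([1., 2., 3., 4., 5.]) viewed as shape (1, 5), ready to broadcast
# against the row-wise sorted input below.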


def _threshold_and_support(input, dim=0):
    """Sparsemax building block: compute the threshold.

    Args:
        input: tensor of any dimension
        dim: dimension along which to apply sparsemax

    Returns:
        the threshold value ``tau`` and the size of the support
    """
input_srt, _ = torch.sort(input, descending=True, dim=dim)
input_cumsum = input_srt.cumsum(dim) - 1
rhos = _make_ix_like(input, dim)
support = rhos * input_srt > input_cumsum
support_size = support.sum(dim=dim).unsqueeze(dim)
tau = input_cumsum.gather(dim, support_size - 1)
tau /= support_size.to(input.dtype)
return tau, support_size
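
# In the notation of Martins & Astudillo (2016), with z sorted in decreasing
# order, the support size is k(z) = max{k : 1 + k * z_k > sum_{j <= k} z_j}
# and the threshold is tau(z) = (sum_{j <= k(z)} z_j - 1) / k(z); the
# cumulative-sum computation above evaluates exactly these quantities.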


class SparsemaxFunction(Function):

    @staticmethod
    @custom_fwd
    def forward(ctx, input, dim=0):
        """sparsemax: normalizing, sparse transform (a la softmax).

        Computes sparsemax(z)_i = max(z_i - tau(z), 0), the Euclidean
        projection of the input onto the probability simplex.

        Parameters:
            input (Tensor): any shape
            dim: dimension along which to apply sparsemax

        Returns:
            output (Tensor): same shape as input
        """
        ctx.dim = dim
        max_val, _ = input.max(dim=dim, keepdim=True)
        # Same numerical-stability trick as for softmax; done out of place
        # so the caller's tensor is not mutated.
        input = input - max_val
        tau, supp_size = _threshold_and_support(input, dim=dim)
        output = torch.clamp(input - tau, min=0)
        ctx.save_for_backward(supp_size, output)
        return output

    @staticmethod
    @custom_bwd
def backward(ctx, grad_output):
supp_size, output = ctx.saved_tensors
dim = ctx.dim
grad_input = grad_output.clone()
grad_input[output == 0] = 0
        # Mean of the incoming gradient over the support; squeeze only `dim`
        # so that other singleton dimensions are preserved.
        v_hat = grad_input.sum(dim=dim) / supp_size.to(output.dtype).squeeze(dim)
        v_hat = v_hat.unsqueeze(dim)
        grad_input = torch.where(output != 0, grad_input - v_hat, grad_input)
return grad_input, None
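

# On the support S = {i : sparsemax(z)_i > 0}, the Jacobian of sparsemax is
# diag(s) - s s^T / |S|, with s the indicator vector of S (Martins &
# Astudillo, 2016), so the backward pass above subtracts the mean incoming
# gradient over the support and zeroes the gradient of off-support entries.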
sparsemax = SparsemaxFunction.apply
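# Functional form: e.g. `probs = sparsemax(logits, 1)` mirrors
# `torch.softmax(logits, dim=1)` but can return exact zeros.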


class Sparsemax(nn.Module):

    def __init__(self, dim=0):
        # Initialize nn.Module before assigning attributes.
        super(Sparsemax, self).__init__()
        self.dim = dim

    def forward(self, input):
        return sparsemax(input, self.dim)


class LogSparsemax(nn.Module):

    def __init__(self, dim=0):
        super(LogSparsemax, self).__init__()
        self.dim = dim

    def forward(self, input):
        return torch.log(sparsemax(input, self.dim))
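

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only, not part of the module's API):
    # compare sparsemax with softmax on a small batch of logits. Sparsemax
    # assigns exactly zero probability to low-scoring entries while the
    # outputs still sum to one.
    logits = torch.tensor([[1.5, 0.1, -2.0, 0.9]])
    sparse_probs = Sparsemax(dim=1)(logits)
    soft_probs = torch.softmax(logits, dim=1)
    print("sparsemax:", sparse_probs)  # contains exact zeros
    print("softmax:  ", soft_probs)    # strictly positive everywhere
    assert torch.allclose(sparse_probs.sum(dim=1), torch.tensor([1.0]))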