"""
An implementation of sparsemax (Martins & Astudillo, 2016). See
:cite:`DBLP:journals/corr/MartinsA16` for a detailed description.

By Ben Peters and Vlad Niculae
"""

import torch
import torch.nn as nn
from torch.autograd import Function
from torch.cuda.amp import custom_fwd, custom_bwd
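
# Sparsemax (Martins & Astudillo, 2016) is the Euclidean projection of a score
# vector z onto the probability simplex:
#
#     sparsemax(z) = argmin_{p in simplex} ||p - z||^2 = [z - tau(z)]_+
#
# where the threshold tau(z) is the unique value such that the clipped result
# sums to 1. The helpers below compute tau(z) and the size of the support
# (the set of non-zero outputs).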


def _make_ix_like(input, dim=0):
    """Build the index vector [1, ..., d] shaped to broadcast along ``dim``."""
    d = input.size(dim)
    rho = torch.arange(1, d + 1, device=input.device, dtype=input.dtype)
    view = [1] * input.dim()
    view[0] = -1
    return rho.view(view).transpose(0, dim)
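
# Hand-checked example (not from the original source): for an input of shape
# (2, 3) and dim=1, _make_ix_like returns tensor([[1., 2., 3.]]) of shape
# (1, 3), which broadcasts against the sorted input in the function below.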


def _threshold_and_support(input, dim=0):
    """Sparsemax building block: compute the threshold.

    Args:
        input: any dimension
        dim: dimension along which to apply the sparsemax

    Returns:
        tau: the threshold value
        support_size: the number of non-zero outputs along ``dim``
    """
    input_srt, _ = torch.sort(input, descending=True, dim=dim)
    input_cumsum = input_srt.cumsum(dim) - 1
    rhos = _make_ix_like(input, dim)
    support = rhos * input_srt > input_cumsum

    support_size = support.sum(dim=dim).unsqueeze(dim)
    tau = input_cumsum.gather(dim, support_size - 1)
    tau /= support_size.to(input.dtype)
    return tau, support_size
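
# Hand-worked example (illustrative, not part of the original module): for
# input = torch.tensor([2.0, 1.0, 0.1]) and dim=0, only the entry 2.0 is in
# the support, so support_size = tensor([1]) and tau = tensor([1.0]);
# clamp(input - tau, min=0) then gives the sparsemax output [1.0, 0.0, 0.0],
# which sums to 1.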


class SparsemaxFunction(Function):

    @staticmethod
    @custom_fwd
    def forward(ctx, input, dim=0):
        """sparsemax: normalizing sparse transform (a la softmax)

        Parameters:
            input (Tensor): any shape
            dim: dimension along which to apply sparsemax

        Returns:
            output (Tensor): same shape as input
        """
        ctx.dim = dim
        max_val, _ = input.max(dim=dim, keepdim=True)
        # shift by the max for numerical stability (same trick as softmax);
        # done out of place so the caller's tensor is not modified
        input = input - max_val
        tau, supp_size = _threshold_and_support(input, dim=dim)
        output = torch.clamp(input - tau, min=0)
        ctx.save_for_backward(supp_size, output)
        return output

    @staticmethod
    @custom_bwd
    def backward(ctx, grad_output):
        supp_size, output = ctx.saved_tensors
        dim = ctx.dim
        grad_input = grad_output.clone()
        # the gradient is zero outside the support
        grad_input[output == 0] = 0

        # inside the support, subtract the mean incoming gradient over the support
        v_hat = grad_input.sum(dim=dim) / supp_size.to(output.dtype).squeeze(dim)
        v_hat = v_hat.unsqueeze(dim)
        grad_input = torch.where(output != 0, grad_input - v_hat, grad_input)
        return grad_input, None


sparsemax = SparsemaxFunction.apply
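# Functional form, e.g. (with a hypothetical ``scores`` tensor):
#     probs = sparsemax(scores, 0)
# applies sparsemax along dim 0 of ``scores``.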


class Sparsemax(nn.Module):
    """Sparsemax activation module (see :class:`SparsemaxFunction`)."""

    def __init__(self, dim=0):
        super(Sparsemax, self).__init__()
        self.dim = dim

    def forward(self, input):
        return sparsemax(input, self.dim)


class LogSparsemax(nn.Module):
    """Log of the sparsemax activation (entries outside the support are -inf)."""

    def __init__(self, dim=0):
        super(LogSparsemax, self).__init__()
        self.dim = dim

    def forward(self, input):
        return torch.log(sparsemax(input, self.dim))
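

# A minimal usage sketch (not part of the original module): the tensor names,
# shapes, and seed below are illustrative assumptions. It checks that sparsemax
# outputs are non-negative and sum to 1 along the chosen dim, and prints how
# many entries are exactly zero per row.
if __name__ == "__main__":
    torch.manual_seed(0)
    scores = torch.randn(4, 6)

    probs = Sparsemax(dim=1)(scores)
    assert (probs >= 0).all()
    assert torch.allclose(probs.sum(dim=1), torch.ones(4))
    # Unlike softmax, some entries are typically exactly zero.
    print("zeros per row:", (probs == 0).sum(dim=1))

    # LogSparsemax maps zero-probability entries to -inf.
    log_probs = LogSparsemax(dim=1)(scores)
    print("first row of log-probs:", log_probs[0])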