import lightconv_cuda
import torch
import torch.nn.functional as F
from fairseq import utils
from fairseq.incremental_decoding_utils import with_incremental_state
from fairseq.modules.fairseq_dropout import FairseqDropout
from torch import nn
from torch.autograd import Function


class lightconvFunction(Function):
    @staticmethod
    def forward(ctx, x, weights, padding_l):
        # x: (B, C, T), weights: (H, K). The compiled CUDA extension does the
        # actual convolution; save the inputs for the backward pass.
        ctx.padding_l = padding_l
        outputs = lightconv_cuda.forward(x, weights, padding_l)
        variables = [x, weights]
        ctx.save_for_backward(*variables)
        return outputs[0]

    @staticmethod
    def backward(ctx, grad_output):
        outputs = lightconv_cuda.backward(
            grad_output.contiguous(), ctx.padding_l, *ctx.saved_tensors
        )
        grad_input, grad_weights = outputs
        # padding_l is a plain int, so it receives no gradient.
        return grad_input, grad_weights, None
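

# A minimal pure-PyTorch sketch of the operation we assume lightconv_cuda
# implements: a depthwise 1d convolution whose kernel is shared by all C // H
# channels of a head, with padding_l zeros of left padding. The helper name
# `_lightconv_reference` is ours, not part of fairseq; it is only meant for
# CPU sanity checks against the kernel.
def _lightconv_reference(x, weight, padding_l):
    """x: (B, C, T); weight: (H, K), e.g. already softmax-normalized."""
    B, C, T = x.size()
    H, K = weight.size()
    # Repeat each head's kernel over the C // H channels it covers, then run
    # a depthwise (groups=C) cross-correlation, which is what F.conv1d does.
    w = weight.view(H, 1, K).expand(H, C // H, K).reshape(C, 1, K)
    x = F.pad(x, (padding_l, K - 1 - padding_l))
    return F.conv1d(x, w, groups=C)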


@with_incremental_state
class LightconvLayer(nn.Module):
    def __init__(
        self,
        input_size,
        kernel_size=1,
        padding_l=None,
        weight_softmax=False,
        num_heads=1,
        weight_dropout=0.0,
        bias=False,
    ):
        super(LightconvLayer, self).__init__()
        self.input_size = input_size
        self.kernel_size = kernel_size
        self.padding_l = padding_l
        self.num_heads = num_heads
        self.weight_softmax = weight_softmax
        self.weight_dropout_module = FairseqDropout(
            weight_dropout, module_name=self.__class__.__name__
        )

        # One kernel of length kernel_size per head, shared across the
        # input_size // num_heads channels belonging to that head. Note that
        # the optional bias is registered but not applied in forward() below.
        self.weight = nn.Parameter(torch.Tensor(num_heads, kernel_size))
        if bias:
            self.bias = nn.Parameter(torch.Tensor(input_size))
        else:
            self.bias = None
        self.reset_parameters()

    def upgrade_state_dict_named(self, state_dict, name):
        prefix = name + "." if name != "" else ""
        for k, v in state_dict.items():
            if k.endswith(prefix + "weight"):
                # Older checkpoints stored the weight as (H, 1, K); squeeze
                # out the middle dimension to match the current (H, K) shape.
                if v.dim() == 3 and v.size(1) == 1:
                    state_dict[k] = v.squeeze(1)

    def reset_parameters(self):
        nn.init.xavier_uniform_(self.weight)
        if self.bias is not None:
            nn.init.constant_(self.bias, 0.0)

    def forward(self, x, incremental_state=None):
        # x: (T, B, C), time-first.

        # During inference, an incremental BMM against a rolling buffer of
        # the last K - 1 inputs is faster than calling the CUDA kernel.
        if incremental_state is not None:
            T, B, C = x.size()
            K, H = self.kernel_size, self.num_heads
            R = C // H
            input_buffer = self._get_input_buffer(incremental_state)
            if input_buffer is None:
                input_buffer = x.new()
            x_unfold = torch.cat([input_buffer, x.unsqueeze(3)], dim=3)
            if self.kernel_size > 1:
                # Keep only the most recent K - 1 timesteps for the next call.
                self._set_input_buffer(
                    incremental_state, x_unfold[:, :, :, -self.kernel_size + 1 :]
                )
            x_unfold = x_unfold.view(T * B * H, R, -1)

            weight = self.weight
            if self.weight_softmax:
                weight = F.softmax(weight.float(), dim=1).type_as(weight)

            # Early steps have fewer than K buffered inputs, so trim the
            # kernel to its most recent taps.
            weight = weight[:, -x_unfold.size(2) :]

            K = weight.size(1)

            weight = (
                weight.view(1, H, K)
                .expand(T * B, H, K)
                .contiguous()
                .view(T * B * H, K, 1)
            )

            weight = self.weight_dropout_module(weight)
            output = torch.bmm(x_unfold, weight)  # (T*B*H, R, 1)
            output = output.view(T, B, C)
            return output

        # During training, run the fused CUDA kernel over the whole sequence.
        else:
            x = x.permute(1, 2, 0).contiguous()  # (T, B, C) -> (B, C, T)
            weight = self.weight
            if self.weight_softmax:
                weight = F.softmax(self.weight, -1)
            if self.weight_dropout_module.p:
                weight = self.weight_dropout_module(weight)
            return lightconvFunction.apply(x, weight, self.padding_l).permute(2, 0, 1)

    def reorder_incremental_state(self, incremental_state, new_order):
        input_buffer = self._get_input_buffer(incremental_state)
        if input_buffer is not None:
            # The buffer is (T, B, C, K - 1); beam reordering selects along
            # the batch dimension.
            input_buffer = input_buffer.index_select(1, new_order)
            self._set_input_buffer(incremental_state, input_buffer)

    def _get_input_buffer(self, incremental_state):
        return utils.get_incremental_state(self, incremental_state, "input_buffer")

    def _set_input_buffer(self, incremental_state, new_buffer):
        return utils.set_incremental_state(
            self, incremental_state, "input_buffer", new_buffer
        )

    def half(self):
        # Cast floating-point parameters and buffers to fp16, leaving
        # integer tensors untouched.
        return self._apply(lambda t: t.half() if t.is_floating_point() else t)
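

# A rough smoke test, not part of fairseq: it assumes a CUDA device and a
# built lightconv_cuda extension are available. Inputs follow the module's
# (T, B, C) convention; padding_l = K - 1 gives a causal convolution.
if __name__ == "__main__":
    T, B, C, H, K = 10, 2, 16, 4, 3
    layer = LightconvLayer(
        input_size=C,
        kernel_size=K,
        padding_l=K - 1,
        num_heads=H,
        weight_softmax=True,
    ).cuda()
    x = torch.randn(T, B, C, device="cuda")

    # Training path: the CUDA kernel consumes the whole sequence at once.
    y_full = layer(x)

    # Inference path: feed one timestep at a time through the rolling buffer.
    incremental_state = {}
    y_steps = torch.cat(
        [layer(x[t : t + 1], incremental_state) for t in range(T)], dim=0
    )
    print(y_full.shape, y_steps.shape)  # both should be (10, 2, 16)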