|
|
|
|
|
|
|
|
|
|
|
import torch |
|
import torch.nn as nn |
|
|
|
|
|
class BeamableMM(nn.Module):
    """This module provides an optimized MM for beam decoding with attention.

    It leverages the fact that the source-side of the input is replicated beam
    times and the target-side of the input is of width one. This layer speeds up
    inference by replacing the inputs {(bsz x 1 x nhu), (bsz x nhu x sz2)}
    with smaller inputs {(bsz/beam x beam x nhu), (bsz/beam x nhu x sz2)}.

    The fast path assumes the ``beam`` copies of each source entry occupy
    consecutive rows of ``input2`` (rows ``[k*beam, (k+1)*beam)``); only the
    first row of every such group is read. In training mode, when no beam size
    is set, or when the batch cannot be split into whole beams, the module
    falls back to a plain batched matrix multiply.
    """

    def __init__(self, beam_size=None):
        super().__init__()
        # Hypotheses per source entry; ``None`` disables the fast path.
        self.beam_size = beam_size

    def forward(self, input1, input2):
        """Compute the batched product ``input1 @ input2``.

        Args:
            input1: tensor of shape (bsz, m, nhu); the fast path needs m == 1.
            input2: tensor of shape (bsz, nhu, sz2), as ``torch.bmm`` requires.

        Returns:
            Tensor of shape (bsz, m, sz2).
        """
        beam = self.beam_size
        use_beam_path = (
            not self.training  # inference only
            and beam is not None  # beam size has been set
            and input1.dim() == 3  # batched input only
            and input1.size(1) == 1  # single time-step update
            and beam > 0
            and input1.size(0) > 0  # unfold cannot split an empty batch
            # The batch must consist of whole beams; otherwise unfold would
            # silently drop trailing rows and the final view would fail.
            and input1.size(0) % beam == 0
        )
        if not use_beam_path:
            return input1.bmm(input2)

        bsz = input1.size(0)

        # bsz x 1 x nhu --> bsz/beam x beam x nhu
        input1 = input1[:, 0, :].unfold(0, beam, beam).transpose(2, 1)

        # bsz x nhu x sz2 --> bsz/beam x nhu x sz2 (first row of each group)
        input2 = input2.unfold(0, beam, beam)[:, :, :, 0]

        # A single group (bsz == beam) does not need the batched kernel.
        if input1.size(0) == 1:
            output = torch.mm(input1[0, :, :], input2[0, :, :])
        else:
            output = input1.bmm(input2)
        return output.view(bsz, 1, -1)

    def set_beam_size(self, beam_size):
        """Set the beam size used by the fast path (``None`` disables it)."""
        self.beam_size = beam_size
|
|