|
|
|
|
|
|
|
|
|
|
|
"""Mask module.""" |
|
|
|
from distutils.version import LooseVersion |
|
|
|
import torch |
|
|
|
# True when the installed torch is >= 1.2.0 (first release with torch.bool masks).
# NOTE(review): distutils.version is deprecated and removed in Python 3.12;
# consider migrating to packaging.version when dependencies allow.
is_torch_1_2_plus = LooseVersion(torch.__version__) >= LooseVersion("1.2.0")


# True only for the torch 1.2.x series, which is special-cased in
# subsequent_mask below (apparently tril on bool tensors is not usable
# there — confirm against torch 1.2 release notes).
is_torch_1_2 = (
    LooseVersion("1.3") > LooseVersion(torch.__version__) >= LooseVersion("1.2")
)

# Preferred mask dtype: bool on torch >= 1.2, uint8 on older releases.
datatype = torch.bool if is_torch_1_2_plus else torch.uint8
|
|
|
|
|
def subsequent_mask(size, device="cpu", dtype=datatype):
    """Create mask for subsequent steps (size, size).

    Builds a lower-triangular matrix of ones so that, in decoder
    self-attention, position i may attend only to positions <= i.

    :param int size: size of mask
    :param str device: "cpu" or "cuda" or torch.Tensor.device
    :param torch.dtype dtype: result dtype
    :rtype: torch.Tensor
    >>> subsequent_mask(3)
    [[1, 0, 0],
     [1, 1, 0],
     [1, 1, 1]]
    """
    if is_torch_1_2 and dtype == torch.bool:
        # torch 1.2.x workaround: build the triangle in uint8 first,
        # then cast to the requested bool dtype.
        mask = torch.tril(torch.ones(size, size, device=device, dtype=torch.uint8))
        return mask.type(dtype)
    # All other versions: construct directly in the requested dtype.
    return torch.tril(torch.ones(size, size, device=device, dtype=dtype))
|
|
|
|
|
def target_mask(ys_in_pad, ignore_id):
    """Create mask for decoder self-attention.

    Combines the padding mask (positions equal to ``ignore_id`` are
    masked out) with the lower-triangular subsequent mask, so each
    position can attend only to non-padded earlier positions.

    :param torch.Tensor ys_in_pad: batch of padded target sequences (B, Lmax)
    :param int ignore_id: index used for padding
    :rtype: torch.Tensor (B, Lmax, Lmax)
    """
    # (B, Lmax): True where the token is real, False where it is padding.
    ys_mask = ys_in_pad != ignore_id
    # (1, Lmax, Lmax): causal mask, broadcast over the batch dimension.
    m = subsequent_mask(ys_mask.size(-1), device=ys_mask.device).unsqueeze(0)
    # (B, 1, Lmax) & (1, Lmax, Lmax) -> (B, Lmax, Lmax)
    return ys_mask.unsqueeze(-2) & m
|
|