File size: 503 Bytes
ad16788 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 |
"""Create mask for subsequent steps."""
def make_history_mask(xp, block):
    """Prepare the history mask.

    Entry ``[b, i, j]`` of the result is ``True`` exactly when ``j <= i``,
    i.e. every batch item gets the same lower-triangular mask (diagonal
    included).

    Args:
        xp (module): Array module used to build the mask; it must provide
            ``arange`` and ``broadcast_to`` (e.g. ``numpy``).
        block (ndarray): Block with dimensions: (B x S). Only its shape is
            read; it must be exactly two-dimensional.

    Returns:
        ndarray: Boolean history mask with dimensions (B, S, S), as returned
            by ``xp.broadcast_to`` (with NumPy this is a read-only view, so
            copy it before writing into it).

    Raises:
        ValueError: If ``block.shape`` does not unpack into two values.

    """
    batch, length = block.shape
    arange = xp.arange(length)
    # Column indices (1, S) compared against row indices (S, 1) broadcast to
    # an (S, S) grid that is True wherever column <= row.
    history_mask = arange[None] <= arange[:, None]
    # Add a leading batch axis and repeat the same mask for every batch item.
    return xp.broadcast_to(history_mask[None], (batch, length, length))
|