File size: 503 Bytes
ad16788
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
"""Create mask for subsequent steps."""


def make_history_mask(xp, block):
    """Build a lower-triangular (history) mask for every item in a batch.

    Args:
        xp: Array module used to build the mask. The code only calls
            ``xp.arange`` and ``xp.broadcast_to`` (e.g. ``numpy``).
        block (ndarray): 2-D array of shape (B, S). Only its shape is read;
            its values are ignored.

    Returns:
        ndarray: Boolean array of shape (B, S, S) where entry ``[b, i, j]``
        is True exactly when ``j <= i``. Every batch item shares the same
        (S, S) pattern. The result comes from ``xp.broadcast_to``, which for
        NumPy is a read-only view, so copy it before writing into it.

    """
    n_batch, seq_len = block.shape
    positions = xp.arange(seq_len)
    # Row i, column j is True iff position j does not come after position i.
    lower_tri = positions[:, None] >= positions[None, :]
    # Add a leading batch axis of size 1, then expand it to n_batch without copying.
    return xp.broadcast_to(lower_tri[None, :, :], (n_batch, seq_len, seq_len))