File size: 782 Bytes
373af33
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
import torch


def create_mask_sequence(mask_cfg, seq_len):
    """Build a list of index tensors that partitions ``range(seq_len)``.

    Args:
        mask_cfg: Mapping read with ``mask_cfg['type']`` and, for the two
            chunked modes, ``mask_cfg['num_tokens']``. Supported types:

            * ``'raster order'`` - indices ``0..seq_len-1`` in ascending
              order, cut into consecutive chunks of ``num_tokens``.
            * ``'random order'`` - a ``torch.randperm(seq_len)`` permutation
              cut into consecutive chunks of ``num_tokens``.
            * ``'single'`` - one chunk holding every index;
              ``num_tokens`` is not read.
        seq_len: Number of positions to cover.

    Returns:
        A list of 1-D index tensors. In the chunked modes every chunk has
        ``num_tokens`` entries except possibly the last, which holds the
        remainder. The list is empty when ``seq_len`` is 0 in a chunked mode.

    Raises:
        NotImplementedError: If ``mask_cfg['type']`` is not a supported type.
        ValueError: If ``num_tokens`` is not positive in a chunked mode.
    """
    type_name = mask_cfg['type']

    if type_name == 'single':
        return [torch.arange(seq_len)]

    if type_name not in ('raster order', 'random order'):
        raise NotImplementedError(f"Unsupported mask type: {type_name!r}")

    num_tokens = mask_cfg['num_tokens']
    # A zero step makes range() fail with an opaque message, and a negative
    # step would silently yield no chunks at all; reject both explicitly.
    if num_tokens <= 0:
        raise ValueError(
            f"num_tokens must be a positive integer, got {num_tokens!r}"
        )

    if type_name == 'raster order':
        all_idx = torch.arange(seq_len)
    else:
        all_idx = torch.randperm(seq_len)

    # Slicing past the end is safe: the final chunk simply comes out shorter.
    return [
        all_idx[start:start + num_tokens]
        for start in range(0, seq_len, num_tokens)
    ]