File size: 782 Bytes
373af33 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 |
import torch
def create_mask_sequence(mask_cfg, seq_len):
    """Partition the positions ``0 .. seq_len - 1`` into a list of index groups.

    Args:
        mask_cfg: Mapping with a ``'type'`` key selecting the strategy:

            * ``'raster order'`` - positions in ascending order, cut into
              consecutive groups of ``mask_cfg['num_tokens']`` entries.
            * ``'random order'`` - a random permutation of the positions
              (drawn with ``torch.randperm``), cut into consecutive groups
              of ``mask_cfg['num_tokens']`` entries.
            * ``'single'`` - one group holding every position;
              ``'num_tokens'`` is not read.
        seq_len: Number of positions to partition.

    Returns:
        A list of 1-D index tensors. For the two chunked strategies the last
        group is shorter when ``seq_len`` is not a multiple of
        ``num_tokens``, and the list is empty when ``seq_len`` is 0.

    Raises:
        KeyError: If a required key is missing from ``mask_cfg``.
        ValueError: If ``num_tokens`` is not positive.
        NotImplementedError: If ``mask_cfg['type']`` is not a known strategy.
    """
    type_name = mask_cfg['type']

    if type_name == 'single':
        return [torch.arange(seq_len)]

    # The two chunked strategies differ only in how the positions are ordered.
    if type_name == 'raster order':
        make_order = torch.arange
    elif type_name == 'random order':
        make_order = torch.randperm
    else:
        raise NotImplementedError(f"unsupported mask type: {type_name!r}")

    num_tokens = mask_cfg['num_tokens']
    # A negative step would make the range below empty and silently drop every
    # position; a zero step would fail with an unrelated-looking range() error.
    if num_tokens <= 0:
        raise ValueError(f"num_tokens must be positive, got {num_tokens!r}")

    all_idx = make_order(seq_len)
    return [all_idx[start:start + num_tokens]
            for start in range(0, seq_len, num_tokens)]
|