mingyuan's picture
initial commit
373af33
raw
history blame
782 Bytes
import torch
def create_mask_sequence(mask_cfg, seq_len, *, generator=None):
    """Split the positions ``0 .. seq_len - 1`` into an ordered list of index groups.

    Supported ``mask_cfg['type']`` values:

    * ``'raster order'`` -- positions in ascending order, chunked into
      consecutive groups of ``mask_cfg['num_tokens']`` (the last group may be
      shorter).
    * ``'random order'`` -- a random permutation of the positions
      (``torch.randperm``), chunked the same way.
    * ``'single'`` -- one group containing every position in ascending order.

    Args:
        mask_cfg: Mapping with a ``'type'`` key and, for the two "order" types,
            a ``'num_tokens'`` key giving the group size.
        seq_len: Number of positions to split.
        generator: Optional ``torch.Generator`` used only by ``'random order'``
            so the permutation can be reproduced; ``None`` (the default) uses
            torch's global RNG, as before.

    Returns:
        A list of 1-D integer index tensors (``torch.arange`` / ``torch.randperm``
        output or slices of it). For the chunked types the list is empty when
        ``seq_len`` is 0.

    Raises:
        KeyError: A required config key is missing.
        ValueError: ``mask_cfg['num_tokens']`` is not positive.
        NotImplementedError: ``mask_cfg['type']`` is not one of the types above.
    """
    type_name = mask_cfg['type']
    if type_name == 'single':
        return [torch.arange(seq_len)]
    if type_name not in ('raster order', 'random order'):
        raise NotImplementedError(f'Unsupported mask type: {type_name!r}')

    num_tokens = mask_cfg['num_tokens']
    # A non-positive step would otherwise either crash inside range() (0) or
    # silently produce no groups at all (negative).
    if num_tokens <= 0:
        raise ValueError(
            f"mask_cfg['num_tokens'] must be a positive integer, got {num_tokens!r}"
        )

    if type_name == 'raster order':
        all_idx = torch.arange(seq_len)
    else:
        all_idx = torch.randperm(seq_len, generator=generator)

    # Plain slicing over range() yields no groups when seq_len == 0, matching
    # the original loop's behaviour.
    return [all_idx[i:i + num_tokens] for i in range(0, seq_len, num_tokens)]