from typing import Optional

import numpy as np
import torch


def _as_bool_mask(mask):
    """Cast *mask* to ``bool``; fall back to ``uint8`` when ``Tensor.bool`` is unavailable."""
    return mask.bool() if hasattr(mask, 'bool') else mask.byte()


def window_mask(x_len, device, m_len=0, size=(1, 1)):
    """Build a block-windowed attention mask with optional memory columns.

    An upper-triangular matrix of ones (offset by ``k``) is built at block
    resolution, each entry is expanded into a ``win_size x win_size`` block,
    and the result is cropped to ``x_len x x_len``. ``m_len`` all-zero columns
    are prepended for the memory positions.

    Args:
        x_len: Number of query positions (rows of the mask).
        device: Torch device on which the mask is created.
        m_len: Number of memory columns prepended to the mask; always zero.
        size: ``(win_size, k)`` pair. ``win_size`` is the side of each block
            and must be a positive integer (it is used as a divisor and as a
            repeat count); ``k`` is the ``diagonal`` passed to ``torch.triu``.

    Returns:
        Tensor of shape ``(1, 1, x_len, m_len + x_len)``, boolean where the
        installed torch supports it (``uint8`` otherwise). Non-zero entries
        are the masked positions; zero entries are allowed to be seen.
    """
    win_size, k = size
    # One extra block so the expanded matrix always covers x_len before cropping.
    n_blocks = x_len // win_size + 1

    mem_mask = torch.zeros((x_len, m_len), device=device)
    tri_mask = torch.triu(torch.ones((n_blocks, n_blocks), device=device), diagonal=k)
    win_mask = tri_mask.repeat_interleave(win_size, dim=0)
    win_mask = win_mask.repeat_interleave(win_size, dim=1)[:x_len, :x_len]

    if x_len:
        # Always allowing first index to see. Otherwise you'll get NaN loss
        win_mask[..., 0] = 0

    mask = torch.cat((mem_mask, win_mask), dim=1)[None, None]
    return _as_bool_mask(mask)


def rand_window_mask(x_len, m_len, device, max_size: Optional[int] = None,
                     p: float = 0.2, is_eval: bool = False):
    """Return a window mask whose block size is randomised with probability ``p``.

    The default ``(win_size, k) = (1, 1)`` mask is used when ``is_eval`` is
    true, when the random draw is ``>= p``, or when ``max_size`` is ``None``
    or smaller than 1. Otherwise ``win_size`` is drawn uniformly from
    ``1..max_size`` and ``k`` is set to 0.

    Args:
        x_len: Number of query positions.
        m_len: Number of memory columns.
        device: Torch device on which the mask is created.
        max_size: Largest window size that may be drawn; ``None`` or a value
            below 1 disables random windowing.
        p: Probability of using a random window.
        is_eval: When true, always use the default mask and draw no random
            numbers.

    Returns:
        The mask produced by ``window_mask`` for the chosen size.
    """
    # The order of the checks is kept so the NumPy RNG is consumed exactly as
    # before; max_size < 1 is rejected here because randint(0, max_size)
    # raises ValueError for an empty range.
    if is_eval or np.random.rand() >= p or max_size is None or max_size < 1:
        size = (1, 1)
    else:
        # Cast to a plain int so torch receives a Python integer.
        size = (int(np.random.randint(0, max_size)) + 1, 0)
    return window_mask(x_len, device, m_len, size=size)


def lm_mask(x_len, device):
    """Build a causal language-model mask.

    Args:
        x_len: Sequence length.
        device: Torch device on which the mask is created.

    Returns:
        Tensor of shape ``(1, 1, x_len, x_len)`` that is non-zero strictly
        above the main diagonal, boolean where the installed torch supports
        it (``uint8`` otherwise).
    """
    mask = torch.triu(torch.ones((x_len, x_len), device=device), diagonal=1)[None, None]
    return _as_bool_mask(mask)