#############################
# Imports
#############################
# Python modules
from typing import List, Optional
from random import randint
# Remote modules
import torch
# Local modules
from utils import Head_Mask
#############################
# Constants
#############################
#############################
# Stuff
#############################
def create_layers_head_mask(config,
                            head_mask_type: Head_Mask = Head_Mask.ALL,
                            specific_heads: Optional[List[int]] = None):
    # Build a (encoder_layers, encoder_attention_heads) head-mask matrix.
    # Following the usual `head_mask` convention, 1 keeps a head and 0 masks it.
    mask_heads = torch.zeros((config.encoder_layers, config.encoder_attention_heads))
    if head_mask_type == Head_Mask.RANDOM:
        # Keep exactly one randomly chosen head per encoder layer.
        for i in range(config.encoder_layers):
            rand_idx = randint(0, config.encoder_attention_heads - 1)
            mask_heads[i, rand_idx] = 1
    elif head_mask_type == Head_Mask.NONE:
        # Mask nothing: keep every head in every layer.
        mask_heads[:, :] = 1
    elif head_mask_type == Head_Mask.ALL:
        # Mask every head: the zero-initialised tensor already encodes this.
        pass
    elif head_mask_type == Head_Mask.SPECIFIC:
        if specific_heads:
            # Keep one caller-chosen head per layer (1-indexed in `specific_heads`).
            for layer_i in range(len(mask_heads)):
                specific_head = specific_heads[layer_i] - 1
                mask_heads[layer_i][specific_head] = 1
        else:
            # Fall back to a hard-coded pattern for a 12-layer, 16-head encoder.
            mask_heads = torch.Tensor([[0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 1, 1, 0],
                                       [1, 0, 1, 1, 0, 1, 1, 1, 1, 0, 1, 0, 0, 0, 1, 0],
                                       [0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 0, 0, 0],
                                       [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0],
                                       [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1],
                                       [1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1],
                                       [0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0],
                                       [0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0],
                                       [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1],
                                       [0, 0, 1, 0, 0, 1, 1, 0, 0, 0, 0, 1, 0, 0, 0, 1],
                                       [0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 1],
                                       [0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1]])
    else:
        raise NotImplementedError()
    return mask_heads.tolist()
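
#############################
# Usage sketch
#############################
# Minimal, hypothetical example: it assumes a BART-style config object exposing
# `encoder_layers` and `encoder_attention_heads`; the SimpleNamespace below is
# only a stand-in for such a config, not part of this module's dependencies.
if __name__ == "__main__":
    from types import SimpleNamespace

    # Hypothetical 12-layer, 16-head encoder configuration.
    dummy_config = SimpleNamespace(encoder_layers=12, encoder_attention_heads=16)

    # Keep one random head per layer; the result is a nested Python list.
    mask = create_layers_head_mask(dummy_config, head_mask_type=Head_Mask.RANDOM)
    print(len(mask), len(mask[0]))  # 12 rows (layers) x 16 columns (heads)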