File size: 1,684 Bytes
0e73e91
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
from torch import nn
import torch

def init_skim_predictor(module_list, mean_bias=5.0):
    for module in module_list:
        if not isinstance(module, torch.nn.Linear):
            raise ValueError("only support initialization of linear skim predictor")

        # module.bias.data[1].fill_(5.0)
        # module.bias.data[0].fill_(-5.0)
        # module.weight.data.zero_()
        module.bias.data[1].normal_(mean=mean_bias, std=0.02)
        module.bias.data[0].normal_(mean=-mean_bias, std=0.02)
        module.weight.data.normal_(mean=0.0, std=0.02)

        module._skim_initialized = True

class SkimPredictor(nn.Module):
    def __init__(self, input_size, output_size, hidden_size=None):
        super().__init__()
        
        self.hidden_size = hidden_size if hidden_size else input_size

        self.predictor = nn.Sequential(
            nn.LayerNorm(input_size),
            nn.Linear(input_size, self.hidden_size),
            # nn.GELU(),
            # nn.Linear(self.hidden_size, self.hidden_size),
            nn.LayerNorm(self.hidden_size),
            nn.GELU(),
            nn.Linear(self.hidden_size, output_size),
        )

        init_skim_predictor([self.predictor[-1]])

    def forward(self, hidden_states):
        return self.predictor(hidden_states)

def test_init_skim_predictor():
    num_layers = 12

    skim_predictors = torch.nn.ModuleList([torch.nn.Linear(768,2) for _ in range(num_layers)])
    init_skim_predictor(skim_predictors)

    print(skim_predictors[0].weight, skim_predictors[0].bias)

    rand_input = torch.rand((4, 16, 768))
    print(skim_predictors[0](rand_input))

if __name__ == "__main__":
    test_init_skim_predictor()