File size: 9,393 Bytes
ddbbf37
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
import torch.nn as nn
from .transformer import TransformerModel, EncoderLayer


# class TransformerModels(nn.Module):
class TransformerModels:
    """In-place architecture edits for an already-built transformer ``nn.Module``.

    Every public edit mutates ``self.model`` in place and returns it, so edits can
    be chained. Most edits first make every existing parameter trainable via
    :meth:`freeze_unfreeze`; newly created modules are trainable by default.

    Args:
        model: the module to modify. Some edits expect attributes on it such as
            ``transformer_layers`` and ``output_linear1`` .. ``output_linear_bin3``.
        device: passed to ``.to()`` for every newly created module.
    """

    def __init__(self, model, device):
        self.model = model
        self.device = device

    # ------------------------------------------------------------------ helpers

    def _swap_modules(self, module, should_swap, build):
        """Recursively replace each descendant of ``module`` matching ``should_swap``.

        ``build(old)`` constructs the replacement. The replacement is moved to
        ``self.device`` and put into the same train/eval mode as ``old``: a freshly
        constructed module is always in training mode, which would e.g. silently
        re-enable dropout on a model that had been switched to ``eval()``.
        Matched modules are not descended into. Returns the number of swaps made.
        """
        swapped = 0
        # Snapshot the children because we assign into ``module`` while walking it.
        for child_name, child in list(module.named_children()):
            if should_swap(child):
                replacement = build(child).to(self.device)
                replacement.train(child.training)
                # Assigning to the child's own name in its direct parent replaces it;
                # this also works for numbered children of Sequential/ModuleList.
                setattr(module, child_name, replacement)
                swapped += 1
            else:
                swapped += self._swap_modules(child, should_swap, build)
        return swapped

    # ----------------------------------------------------------- 1) Normalize

    def replace_InstanceNorm1d_LayerNorm(self):
        """Replace every ``nn.InstanceNorm1d`` with ``nn.LayerNorm(num_features)``.

        The replacement keeps the original layer's ``num_features`` and ``eps``.

        NOTE(review): ``InstanceNorm1d`` normalizes each channel of an (N, C, L)
        input over L, while ``LayerNorm(C)`` normalizes over the *last* axis. The
        swap is only shape-compatible if these layers see C as the last dimension
        — confirm against the model's forward pass.
        """
        self.freeze_unfreeze(True)
        self._swap_modules(
            self.model,
            lambda m: isinstance(m, nn.InstanceNorm1d),
            lambda old: nn.LayerNorm(normalized_shape=old.num_features, eps=old.eps),
        )
        return self.model

    def set_affine_true_for_instance_norm(self):
        """Give every non-affine ``nn.InstanceNorm1d`` a learnable scale and shift.

        Each replacement keeps the original's ``num_features``, ``eps``,
        ``momentum`` and ``track_running_stats``, so it matches the width of the
        data it normalizes. Layers that are already affine are left untouched so
        their learned weights survive. Running statistics (if tracked) are not
        copied to the replacement.
        """
        self.freeze_unfreeze(True)
        self._swap_modules(
            self.model,
            lambda m: isinstance(m, nn.InstanceNorm1d) and not m.affine,
            lambda old: nn.InstanceNorm1d(
                num_features=old.num_features,
                eps=old.eps,
                momentum=old.momentum,
                affine=True,
                track_running_stats=old.track_running_stats,
            ),
        )
        return self.model

    # ------------------------------------------------- 2) Activation function

    def replace_activation_function(self, activation):
        """Replace every ``nn.ReLU`` module with the activation named ``activation``.

        Args:
            activation: one of ``"GELU"``, ``"LeakyReLU"``, ``"ELU"``, ``"Mish"``.

        Raises:
            KeyError: if ``activation`` is not a supported name (checked before any
                parameter is touched).

        Only ReLU *modules* are found; a ReLU applied functionally inside some
        ``forward`` is not affected. Each replaced site gets its own instance.
        """
        factories = {
            "GELU": nn.GELU,
            "LeakyReLU": nn.LeakyReLU,
            "ELU": nn.ELU,
            "Mish": nn.Mish,
        }
        if activation not in factories:
            raise KeyError(
                f"Unsupported activation {activation!r}; expected one of {sorted(factories)}"
            )
        factory = factories[activation]

        self.freeze_unfreeze(True)
        self._swap_modules(
            self.model,
            lambda m: isinstance(m, nn.ReLU),
            lambda _old: factory(),
        )
        return self.model

    # -------------------------------------------------- 3) New encoder layers

    def add_encoder_layers(self, num_new_layers=2, insert_at=4):
        """Insert ``num_new_layers`` fresh encoder layers into ``model.transformer_layers``.

        Each new layer is ``EncoderLayer(512, 4, 0.1, nn.ReLU())``. They are inserted
        consecutively starting at index ``insert_at`` (default 4, i.e. after the
        first four existing layers), keeping their creation order.

        NOTE(review): the positional arguments presumably mean
        (d_model, n_heads, dropout, activation) — confirm against ``EncoderLayer``;
        they must match the existing layers' width for the stack to run.
        """
        self.freeze_unfreeze(True)
        for offset in range(num_new_layers):
            new_layer = EncoderLayer(512, 4, 0.1, nn.ReLU()).to(self.device)
            self.model.transformer_layers.insert(insert_at + offset, new_layer)
        return self.model

    # ------------------------------------------------------------- 4) Dropout

    def dropout_value_change(self, val=0.1):
        """Replace every ``nn.Dropout`` module with one whose drop probability is ``val``.

        The replacement keeps the original's ``inplace`` flag and train/eval mode.
        Dropout configured as a plain float inside another module (for example the
        ``dropout`` argument of ``nn.MultiheadAttention``) is not changed.
        """
        self.freeze_unfreeze(True)
        self._swap_modules(
            self.model,
            lambda m: isinstance(m, nn.Dropout),
            lambda old: nn.Dropout(val, inplace=old.inplace),
        )
        return self.model

    # ---------------------------------------------- 5) Output linear layers

    def change_linear_output_layers(self):
        """Train only the output heads and deepen ``output_linear1``/``output_linear2``.

        Freezes every parameter except those under the six ``output_linear*``
        attributes. It then replaces ``output_linear1`` and ``output_linear2`` each
        with ``nn.Sequential(head, nn.Linear(out, out))``. Output widths are
        unchanged, so a ``forward`` that calls ``self.output_linear1(x)`` should
        keep working. The new linears are trainable (PyTorch default).

        Call this once: a second call would try to read ``out_features`` from the
        ``nn.Sequential`` wrapper.

        NOTE(review): the two stacked linears have no nonlinearity between them,
        so together they are still one affine map. Wrapping also renames the
        state-dict keys (``output_linear1.weight`` -> ``output_linear1.0.weight``).
        """
        trainable_heads = {
            "output_linear1",
            "output_linear2",
            "output_linear3",
            "output_linear_bin1",
            "output_linear_bin2",
            "output_linear_bin3",
        }
        for name, param in self.model.named_parameters():
            param.requires_grad = name.split(".", 1)[0] in trainable_heads

        # Wrap the named heads in place rather than "inserting": the model is a
        # module with named attributes, not a list-like container.
        for head_name in ("output_linear1", "output_linear2"):
            head = getattr(self.model, head_name)
            extra = nn.Linear(head.out_features, head.out_features).to(self.device)
            setattr(self.model, head_name, nn.Sequential(head, extra))

        return self.model

    # --------------------------------------------------- 6) Cross-attention

    def add_cross_attention(self, embed_dim=512, num_heads=8, dropout=0.1):
        """Chain a ``CrossAttentionLayer`` after each encoder layer's ``gen_attn``.

        NOTE(review): likely broken as written. ``nn.Sequential`` passes a single
        positional input from module to module, but ``CrossAttentionLayer.forward``
        requires both ``query`` and ``key_value``, so invoking the composite raises
        ``TypeError``. It also fails if ``gen_attn`` itself is called with several
        arguments. A working version needs ``EncoderLayer.forward`` to call the
        cross-attention explicitly with a key/value source.
        """
        self.freeze_unfreeze(True)
        for layer in self.model.transformer_layers:
            cross_attn_layer = CrossAttentionLayer(embed_dim, num_heads, dropout).to(self.device)
            layer.gen_attn = nn.Sequential(layer.gen_attn, cross_attn_layer).to(self.device)

        return self.model

    # TODO: 7) residual connections?
    # TODO: 8) more attention heads? (check whether it works with the same params)
    # TODO: add LayerNorm before/after attention.
    # TODO: optimizer settings — Adam? weight decay? learning rate?

    # ------------------------------------------------------------- utilities

    def freeze_unfreeze(self, flag):
        """Set ``requires_grad = flag`` on every parameter of the model."""
        for param in self.model.parameters():
            param.requires_grad = flag

    def count_parameters(self):
        """Print and return ``(trainable, untrainable)`` parameter-element counts."""
        trainable_params = 0
        untrainable_params = 0
        for p in self.model.parameters():
            if p.requires_grad:
                trainable_params += p.numel()
            else:
                untrainable_params += p.numel()

        print(f"Trainable parameters: {trainable_params}")
        print(f"Untrainable parameters: {untrainable_params}")
        return trainable_params, untrainable_params


class CrossAttentionLayer(nn.Module):
    """Residual multi-head attention from ``query`` onto ``key_value``, post-normalized.

    Output is ``LayerNorm(query + Dropout(MHA(query, key_value, key_value)))``.
    ``nn.MultiheadAttention`` is built with its default layout
    (``batch_first=False``), so inputs are presumably (seq, batch, embed_dim) —
    TODO confirm against the callers.

    Args:
        embed_dim: feature width of ``query``/``key_value`` and of the LayerNorm.
        num_heads: number of attention heads (must divide ``embed_dim``).
        dropout: used both inside the attention and on its output.
    """

    def __init__(self, embed_dim, num_heads, dropout=0.1):
        super().__init__()
        # Registration order (cross_attn, norm, dropout) is kept so parameter
        # initialisation order and state-dict layout stay the same.
        self.cross_attn = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout)
        self.norm = nn.LayerNorm(embed_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, query, key_value, attn_mask=None):
        # The attention weights (second element) are not needed.
        context = self.cross_attn(query, key_value, key_value, attn_mask=attn_mask)[0]
        residual = query + self.dropout(context)
        return self.norm(residual)