File size: 57,614 Bytes
3771cfd
 
 
 
 
 
 
 
 
 
7f39bd4
3771cfd
 
 
 
 
 
 
7f39bd4
3771cfd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
1113
1114
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
1161
1162
1163
1164
1165
1166
1167
1168
1169
1170
1171
1172
1173
1174
1175
1176
1177
1178
1179
1180
1181
1182
1183
1184
1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
1206
1207
1208
1209
1210
1211
1212
1213
1214
1215
1216
1217
1218
1219
1220
1221
1222
1223
1224
1225
1226
1227
1228
1229
1230
1231
1232
1233
1234
1235
1236
1237
1238
1239
1240
1241
1242
1243
1244
1245
1246
1247
1248
1249
1250
1251
1252
1253
1254
1255
1256
1257
1258
1259
1260
1261
1262
1263
1264
1265
1266
1267
1268
1269
1270
1271
1272
1273
1274
1275
1276
1277
1278
1279
1280
1281
1282
1283
1284
1285
1286
1287
1288
1289
1290
1291
1292
1293
1294
1295
1296
1297
1298
1299
1300
1301
1302
1303
1304
1305
1306
1307
1308
1309
1310
1311
1312
1313
1314
1315
1316
1317
1318
1319
1320
1321
1322
1323
1324
1325
1326
1327
1328
1329
1330
1331
1332
1333
1334
1335
1336
1337
1338
1339
1340
1341
1342
1343
1344
1345
1346
1347
1348
1349
1350
1351
1352
1353
# modeling_mrt5.py
# Author: Julie Kallini
# Description: This file contains the implementation of the MrT5 model.
# The code is adapted from HuggingFace's modeling_t5.py. New code sequences
# are labeled with comments.

import torch
import copy
import numpy as np
from torch import nn
from .modeling_t5 import (
    T5Attention,
    T5LayerNorm,
    T5LayerFF,
    T5Stack,
    T5ForConditionalGeneration,
    softmax1,
)
from .configuration_mrt5 import MrT5Config
from transformers.modeling_outputs import (
    BaseModelOutput,
    BaseModelOutputWithPastAndCrossAttentions,
    Seq2SeqLMOutput,
)
from transformers.utils import logging
from typing import Optional, Tuple, Union
from dataclasses import dataclass

logger = logging.get_logger(__name__)

@dataclass
class MrT5BaseModelOutputWithPastAndCrossAttentions(BaseModelOutputWithPastAndCrossAttentions):
    """
    BaseModelOutputWithPastAndCrossAttentions extended with MrT5-specific
    fields. Every added field defaults to None; which ones are populated is
    decided by the code that constructs this object (not visible here).
    """

    # Delete-gate results. Presumably the gate values / gate outputs / raw
    # logits produced by a delete gate (SigmoidDeleteGate.forward returns
    # (gate_values, delete_gate_logits)) -- TODO confirm against the producer.
    delete_gate_mask: Optional[torch.FloatTensor] = None
    delete_gate_output: Optional[torch.FloatTensor] = None
    delete_gate_logits: Optional[torch.FloatTensor] = None
    # NOTE(review): presumably the attention mask after deletion has been
    # applied -- confirm against the stack that fills this in.
    attention_mask: Optional[torch.FloatTensor] = None
    # Self-attention internals. MrT5Attention returns queries, keys, values
    # and pre-delete-gate scores when output_attentions=True.
    attention_queries: Optional[torch.FloatTensor] = None
    attention_keys: Optional[torch.FloatTensor] = None
    attention_values: Optional[torch.FloatTensor] = None
    attention_scores: Optional[torch.FloatTensor] = None
    # Cross-attention internals (same quantities as above, for cross-attention).
    cross_attention_keys: Optional[torch.FloatTensor] = None
    cross_attention_queries: Optional[torch.FloatTensor] = None
    cross_attention_values: Optional[torch.FloatTensor] = None
    cross_attention_scores: Optional[torch.FloatTensor] = None


@dataclass
class MrT5Seq2SeqLMOutput(Seq2SeqLMOutput):
    """
    Seq2SeqLMOutput extended with MrT5-specific fields. Every added field
    defaults to None; which ones are populated is decided by the code that
    constructs this object (not visible here).
    """

    # Delete-gate results (presumably from the encoder's delete gate --
    # TODO confirm against the producer).
    delete_gate_mask: Optional[torch.FloatTensor] = None
    delete_gate_output: Optional[torch.FloatTensor] = None
    delete_gate_logits: Optional[torch.FloatTensor] = None
    # Encoder self-attention internals.
    encoder_keys: Optional[torch.FloatTensor] = None
    encoder_queries: Optional[torch.FloatTensor] = None
    encoder_values: Optional[torch.FloatTensor] = None
    encoder_scores: Optional[torch.FloatTensor] = None
    # Decoder self-attention internals.
    decoder_keys: Optional[torch.FloatTensor] = None
    decoder_queries: Optional[torch.FloatTensor] = None
    decoder_values: Optional[torch.FloatTensor] = None
    decoder_scores: Optional[torch.FloatTensor] = None
    # Decoder cross-attention internals.
    cross_attention_keys: Optional[torch.FloatTensor] = None
    cross_attention_queries: Optional[torch.FloatTensor] = None
    cross_attention_values: Optional[torch.FloatTensor] = None
    cross_attention_scores: Optional[torch.FloatTensor] = None


# Lookup table from an initializer's name to the matching ``torch.nn.init``
# function. SigmoidDeleteGate._init_weights selects its initializer from here.
# Both the trailing-underscore (in-place) names and the non-underscore
# variants are exposed, in this order.
_INIT_FUNCTION_NAMES = (
    "uniform_",
    "normal_",
    "trunc_normal_",
    "constant_",
    "xavier_uniform_",
    "xavier_normal_",
    "kaiming_uniform_",
    "kaiming_normal_",
    "uniform",
    "normal",
    "xavier_uniform",
    "xavier_normal",
    "kaiming_uniform",
    "kaiming_normal",
)
TORCH_INIT_FUNCTIONS = {name: getattr(nn.init, name) for name in _INIT_FUNCTION_NAMES}

class ScaledSigmoid(nn.Module):
    """Element-wise ``sigmoid_mask_scale * sigmoid(-x)``.

    Large positive inputs map to ~0 and large negative inputs map to
    ~``sigmoid_mask_scale``; an input of 0 maps to half the scale.
    """

    def __init__(self, sigmoid_mask_scale):
        super().__init__()
        self.sigmoid_mask_scale = sigmoid_mask_scale

    def forward(self, input):
        negated = torch.neg(input)
        return torch.sigmoid(negated).mul(self.sigmoid_mask_scale)

def gumbel_noise_like(x: torch.Tensor) -> torch.Tensor:
    """Sample standard Gumbel noise with the same shape, device and dtype as ``x``.

    Uses the inverse-CDF transform ``-log(-log(u))`` with ``u`` drawn uniformly
    from ``[eps, 1 - eps)``; ``eps`` keeps both logarithms finite.

    The uniform draw and both logarithms are computed in at least float32, and
    the result is cast back to ``x.dtype``. Drawing ``u`` directly in a
    half-precision dtype quantizes it heavily (bfloat16 carries only 8 bits of
    significand). That truncates the Gumbel tail and coarsens the sample.
    Float32 and float64 inputs are sampled exactly as before.
    """
    eps = 3e-4 if x.dtype == torch.float16 else 1e-10
    # float16/bfloat16 -> float32; float32/float64 unchanged.
    sample_dtype = torch.promote_types(x.dtype, torch.float32)
    uniform = torch.empty_like(x, dtype=sample_dtype).uniform_(eps, 1 - eps)
    return (-(-uniform.log()).log()).to(x.dtype)

class SigmoidDeleteGate(nn.Module):
    """Learned delete gate.

    Computes one logit per position with a linear layer, optionally after a
    T5LayerNorm, and passes it through ``self.activation``
    (ScaledSigmoid(sigmoid_mask_scale) by default). During training, Gumbel
    noise is added to the logits when ``config.use_gumbel_noise`` is set.
    Positions whose input id is 0 (pad) are forced to ``sigmoid_mask_scale``.

    forward returns ``(gate_values, delete_gate_logits)``.
    """

    def __init__(self, config):
        super().__init__()
        self.has_layer_norm = config.gate_layer_norm
        if self.has_layer_norm:
            self.layer_norm = T5LayerNorm(config.hidden_size)
        self.feed_forward = nn.Linear(config.hidden_size, 1)
        self._init_weights(self.feed_forward)
        self.activation = ScaledSigmoid(config.sigmoid_mask_scale)
        self.use_gumbel_noise = config.use_gumbel_noise
        # Keep the pad fill value on the gate itself rather than reading it
        # from self.activation. Subclasses may replace the activation with a
        # module that has no such attribute (LogSigmoidDeleteGate installs
        # nn.LogSigmoid), which previously raised AttributeError whenever
        # the batch contained pad tokens.
        self.sigmoid_mask_scale = config.sigmoid_mask_scale

    def forward(self, hidden_states, input_ids):
        """Return ``(gate_values, delete_gate_logits)`` for ``hidden_states``.

        ``input_ids`` is only used to locate pad tokens (id 0). It may be
        None, in which case no pad masking is applied.
        """
        if self.has_layer_norm:
            hidden_states = self.layer_norm(hidden_states)
        delete_gate_logits = self.feed_forward(hidden_states)

        # Add gumbel noise to the delete gate logits
        if self.training and self.use_gumbel_noise:
            delete_gate_logits += gumbel_noise_like(delete_gate_logits)

        gate_values = self.activation(delete_gate_logits)

        # Set gate values for pad tokens (input_ids == 0) to sigmoid_mask_scale.
        # masked_fill is a no-op when there are no pads, so the former
        # `.any()` pre-check (a host sync on GPU) is unnecessary. It also keeps
        # the fill value on the gate's device/dtype without building a tensor.
        if input_ids is not None:
            pad_mask = (input_ids == 0).unsqueeze(-1)
            gate_values = gate_values.masked_fill(pad_mask, self.sigmoid_mask_scale)

        return gate_values, delete_gate_logits

    def _init_weights(self, m, init_func="xavier_uniform_"):
        """Initialize a Linear layer's weight with ``init_func`` and fill its bias with 1.

        This is needed because HuggingFace disables initialization during
        ``from_pretrained``.
        """
        if isinstance(m, nn.Linear):
            TORCH_INIT_FUNCTIONS[init_func](m.weight)
            m.bias.data.fill_(1)


class LogSigmoidDeleteGate(SigmoidDeleteGate):
    """SigmoidDeleteGate variant whose activation is ``nn.LogSigmoid``.

    Construction and forward are inherited from SigmoidDeleteGate; only the
    activation module installed by the parent constructor is replaced.
    """

    def __init__(self, config):
        super(LogSigmoidDeleteGate, self).__init__(config)
        # Swap out the ScaledSigmoid set up by SigmoidDeleteGate.__init__.
        self.activation = nn.LogSigmoid()

class RandomDeleteGate(nn.Module):
    """Delete gate that removes a random subset of positions (no learned parameters).

    On each forward call a deletion fraction is drawn from
    ``N(random_deletion_probability, 0.05)`` using NumPy's global RNG. That
    fraction of all ``batch * seq_len`` positions is then chosen uniformly at
    random with ``torch.randperm``. Chosen positions get ``sigmoid_mask_scale``
    and all others get 0.
    """

    def __init__(self, config):
        super().__init__()
        # Value written at deleted positions, and the mean deletion fraction.
        self.sigmoid_mask_scale = config.sigmoid_mask_scale
        self.random_deletion_probability = config.random_deletion_probability

    def __random_mask_tensor(self, x, n):
        """Return a float32 ``(x.shape[0], x.shape[1], 1)`` tensor holding ``n`` random ones.

        ``n`` is clamped to ``[0, x.shape[0] * x.shape[1]]``. The Gaussian
        draw in forward can go negative, and without the clamp
        ``randperm(...)[:n]`` with a negative ``n`` selects all but the last
        ``|n|`` indices. That would delete almost everything instead of
        nothing.
        """
        batch_size, seq_len = x.shape[0], x.shape[1]
        total_elements = batch_size * seq_len
        n = max(0, min(int(n), total_elements))

        # Create a flattened float tensor of all 0.0, then set n random entries to 1.0
        flat_tensor = torch.zeros(total_elements, dtype=torch.float32, device=x.device)
        indices = torch.randperm(total_elements)[:n]
        flat_tensor[indices] = 1.0

        return flat_tensor.view(batch_size, seq_len, 1)

    def forward(self, hidden_states, input_ids):
        """Return ``(delete_gate_mask, delete_gate_mask)``; ``input_ids`` is unused."""
        # Calculate the number of tokens to delete using a gaussian
        deletion_percentage = np.random.normal(loc=self.random_deletion_probability, scale=0.05)
        n_deletions = int(deletion_percentage * hidden_states.shape[0] * hidden_states.shape[1])

        random_mask = self.__random_mask_tensor(hidden_states, n_deletions)

        # Scale the mask by sigmoid_mask_scale
        delete_gate_mask = random_mask * self.sigmoid_mask_scale
        return delete_gate_mask, delete_gate_mask

    
class FixedDeleteGate(nn.Module):
    """Deterministic delete gate that keeps a fixed fraction of each segment.

    Each row of ``input_ids`` is split into segments. A new segment starts
    at every token in ``self.sep_tokens``, and that separator is the first
    token of its segment. Within each segment the first
    ``ceil((1 - fixed_deletion_amount) * segment_length)`` positions get 0
    (kept); the rest get ``sigmoid_mask_scale``. Segments are computed
    independently per batch row.
    """

    def __init__(self, config):
        super().__init__()
        self.sigmoid_mask_scale = config.sigmoid_mask_scale
        self.fixed_deletion_amount = config.fixed_deletion_amount
        # NOTE(review): these look like ByT5 ids (byte value + 3) for
        # whitespace/punctuation, plus id 1 (presumably EOS) -- confirm.
        self.sep_tokens = torch.tensor([12, 13, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45,
                                        46, 47, 48, 49, 50, 61, 62, 63, 64, 65, 66, 67, 94,
                                        95, 96, 97, 98, 99, 126, 127, 128, 129, 1])

    def __create_mask(self, input_ids):
        """Return a float32 ``(batch_size, seq_len)`` mask of 0 (keep) / sigmoid_mask_scale (delete).

        The previous implementation ran ``bincount`` over the flattened batch
        and applied the resulting segment starts to every row, which produced
        wrong masks whenever batch_size > 1. All quantities here are per row.
        """
        device = input_ids.device
        batch_size, seq_len = input_ids.size()
        self.sep_tokens = self.sep_tokens.to(device)

        # Segment id of every position: number of separators seen so far
        # (inclusive). Non-decreasing along each row.
        is_sep = torch.isin(input_ids, self.sep_tokens)
        segment_ids = torch.cumsum(is_sep, dim=1)

        # Per-row segment lengths (a batched bincount). Ids range over
        # [0, seq_len], hence seq_len + 1 slots.
        segment_lengths = torch.zeros(batch_size, seq_len + 1, dtype=segment_ids.dtype, device=device)
        segment_lengths.scatter_add_(1, segment_ids, torch.ones_like(segment_ids))

        # Segments are contiguous runs, so each segment starts at the
        # exclusive prefix sum of the lengths before it.
        segment_starts = torch.cumsum(segment_lengths, dim=1) - segment_lengths
        positions = torch.arange(seq_len, device=device)
        offset_in_segment = positions - segment_starts.gather(1, segment_ids)
        length_of_segment = segment_lengths.gather(1, segment_ids)

        # Calculate number of kept (zero) positions for each segment
        num_kept = torch.ceil((1 - self.fixed_deletion_amount) * length_of_segment).long()
        keep = offset_in_segment < num_kept

        mask = torch.full((batch_size, seq_len), self.sigmoid_mask_scale, device=device)
        mask = mask.masked_fill(keep, 0)
        return mask.to(torch.float)

    def forward(self, hidden_states, input_ids):
        """Return ``(delete_gate_mask, delete_gate_mask)`` with shape ``(batch, seq_len, 1)``; ``hidden_states`` is unused."""
        delete_gate_mask = self.__create_mask(input_ids).unsqueeze(-1)
        return delete_gate_mask, delete_gate_mask


class MrT5Attention(T5Attention):
    """
    Extends the T5Attention class to include a delete gate. The constructor
    additionally reads ``config.use_softmax1``, and the forward method adds
    the ``delete_gate_mask`` it receives to the attention scores before they
    are normalized.

    Other differences visible in this class's forward:
    - When ``self.has_absolute_position_embeddings`` is truthy, no position
      bias is computed or added, and ``mask`` is added to the scores directly.
      (That attribute is not set in this class; presumably it is set by the
      project's ``T5Attention`` in ``.modeling_t5`` -- confirm there.)
    - When ``self.use_softmax1`` is truthy, ``softmax1`` from ``.modeling_t5``
      is used instead of ``nn.functional.softmax``.
    - With ``output_attentions=True`` the extra output element is a tuple of
      attention weights, keys, queries, values and pre-delete-gate scores.
    """

    def __init__(self, config: MrT5Config, has_relative_attention_bias=False):
        super().__init__(config, has_relative_attention_bias)
        #### NEW CODE ####
        # Selects softmax1 instead of the standard softmax in forward().
        self.use_softmax1 = config.use_softmax1 
        #### NEW CODE ####

    def forward(
        self,
        hidden_states,
        mask=None,
        key_value_states=None,
        position_bias=None,
        past_key_value=None,
        layer_head_mask=None,
        query_length=None,
        use_cache=False,
        output_attentions=False,
        #### NEW CODE ####
        delete_gate_mask=None,
        #### NEW CODE ####
    ):
        """
        Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states).

        ``delete_gate_mask`` (MrT5 addition): optional additive mask over key
        positions. It is reshaped with ``squeeze(-1).unsqueeze(-2).unsqueeze(-2)``
        and added to the scores, broadcasting over heads and query positions.
        It presumably has shape (batch_size, key_length, 1) -- TODO confirm
        against callers.

        Returns a tuple ``(attn_output, present_key_value_state, position_bias)``:

        - ``present_key_value_state`` is ``(key_states, value_states)`` only
          when ``self.is_decoder and use_cache``; otherwise it is None.
        - ``position_bias`` is the value passed in, or the bias built here
          when none was passed in on the non-absolute-position path.

        If ``output_attentions`` is True, a fourth element
        ``(attn_weights, key_states, query_states, value_states, scores_to_return)``
        is appended. ``scores_to_return`` are the scores after the position
        bias / attention mask but before the delete gate mask is added.
        """
        # Input is (batch_size, seq_length, dim)
        # Mask is (batch_size, key_length) (non-causal) or (batch_size, key_length, key_length)
        # past_key_value[0] is (batch_size, n_heads, q_len - 1, dim_per_head)
        batch_size, seq_length = hidden_states.shape[:2]

        real_seq_length = seq_length

        if past_key_value is not None:
            if len(past_key_value) != 2:
                raise ValueError(
                    f"past_key_value should have 2 past states: keys and values. Got { len(past_key_value)} past states"
                )
            real_seq_length += past_key_value[0].shape[2] if query_length is None else query_length

        key_length = real_seq_length if key_value_states is None else key_value_states.shape[
            1]

        def shape(states):
            """projection"""
            return states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)

        def unshape(states):
            """reshape"""
            return states.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim)

        def project(hidden_states, proj_layer, key_value_states, past_key_value):
            """projects hidden states correctly to key/query states"""
            if key_value_states is None:
                # self-attn
                # (batch_size, n_heads, seq_length, dim_per_head)
                hidden_states = shape(proj_layer(hidden_states))
            elif past_key_value is None:
                # cross-attn
                # (batch_size, n_heads, seq_length, dim_per_head)
                hidden_states = shape(proj_layer(key_value_states))

            if past_key_value is not None:
                if key_value_states is None:
                    # self-attn
                    # (batch_size, n_heads, key_length, dim_per_head)
                    hidden_states = torch.cat(
                        [past_key_value, hidden_states], dim=2)
                elif past_key_value.shape[2] != key_value_states.shape[1]:
                    # checking that the `sequence_length` of the `past_key_value` is the same as
                    # the provided `key_value_states` to support prefix tuning
                    # cross-attn
                    # (batch_size, n_heads, seq_length, dim_per_head)
                    hidden_states = shape(proj_layer(key_value_states))
                else:
                    # cross-attn
                    hidden_states = past_key_value
            return hidden_states

        # get query states
        # (batch_size, n_heads, seq_length, dim_per_head)
        query_states = shape(self.q(hidden_states))

        # get key/value states
        key_states = project(
            hidden_states, self.k, key_value_states, past_key_value[
                0] if past_key_value is not None else None
        )
        value_states = project(
            hidden_states, self.v, key_value_states, past_key_value[
                1] if past_key_value is not None else None
        )

        # compute scores
        scores = torch.matmul(
            query_states, key_states.transpose(3, 2)
        )  # equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9

        #### NEW CODE ####
        # Relative-position path (stock T5 behaviour): build or reuse
        # position_bias, fold the attention mask into it, and add it to scores.
        if not self.has_absolute_position_embeddings:
        #### NEW CODE ####
            if position_bias is None:
                if not self.has_relative_attention_bias:
                    position_bias = torch.zeros(
                        (1, self.n_heads, real_seq_length, key_length), device=scores.device, dtype=scores.dtype
                    )
                    if self.gradient_checkpointing and self.training:
                        position_bias.requires_grad = True
                else:
                    position_bias = self.compute_bias(
                        real_seq_length, key_length, device=scores.device)

                # if key and values are already calculated
                # we want only the last query position bias
                if past_key_value is not None:
                    position_bias = position_bias[:, :, -hidden_states.size(1):, :]

                if mask is not None:
                    # (batch_size, n_heads, seq_length, key_length)
                    position_bias = position_bias + mask

            if self.pruned_heads:
                # Drop bias rows for pruned heads. Note this rebinds the local
                # name `mask`; the elif branch below is not reached on this path.
                mask = torch.ones(position_bias.shape[1])
                mask[list(self.pruned_heads)] = 0
                position_bias_masked = position_bias[:, mask.bool()]
            else:
                position_bias_masked = position_bias

            scores = scores + position_bias_masked

        #### NEW CODE ####
        # If there is no position bias, add attention mask to scores directly
        elif mask is not None:
            scores = scores + mask

        #### NEW CODE ####
        # Log scores to return for loss calculation
        # (captured before the delete gate mask is added below).
        scores_to_return = scores
        #### NEW CODE ####

        # Apply the mask from the delete gate
        # (reshaped so it broadcasts over heads and query positions).
        if delete_gate_mask is not None:
            scores = scores + delete_gate_mask.squeeze(-1).unsqueeze(-2).unsqueeze(-2)

        # Normalize in float32, then cast back to the scores' dtype.
        if self.use_softmax1:
            attn_weights = softmax1(scores.float(), dim=-1).type_as(
                scores)
        else:
            attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(
                scores
            )  # (batch_size, n_heads, seq_length, key_length)

        #### NEW CODE ####

        attn_weights = nn.functional.dropout(
            attn_weights, p=self.dropout, training=self.training
        )  # (batch_size, n_heads, seq_length, key_length)

        # Mask heads if we want to
        if layer_head_mask is not None:
            attn_weights = attn_weights * layer_head_mask

        # (batch_size, seq_length, dim)
        attn_output = unshape(torch.matmul(attn_weights, value_states))
        attn_output = self.o(attn_output)

        present_key_value_state = (key_states, value_states) if (
            self.is_decoder and use_cache) else None
        outputs = (attn_output,) + \
            (present_key_value_state,) + (position_bias,)

        if output_attentions:
            # MrT5 exposes keys/queries/values and pre-delete-gate scores in
            # addition to the attention weights.
            attentions_keys_queries = (attn_weights, key_states, query_states, value_states, scores_to_return)
            outputs = outputs + (attentions_keys_queries,)

        return outputs


class MrT5LayerSelfAttention(nn.Module):
    """
    Pre-norm self-attention sub-layer used by MrT5 blocks.

    Mirrors T5LayerSelfAttention: the input is layer-normed, passed through
    the attention module, and the dropped-out attention output is added back
    onto the un-normed input. The attention module is MrT5Attention, so a
    ``delete_gate_mask`` can be threaded through to the scores.
    """

    def __init__(self, config, has_relative_attention_bias=False):
        super().__init__()
        # Attribute names (and their registration order) are load-bearing:
        # they define the state_dict keys of pretrained checkpoints.
        self.SelfAttention = MrT5Attention(
            config,
            has_relative_attention_bias=has_relative_attention_bias,
        )
        self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
        self.dropout = nn.Dropout(config.dropout_rate)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        position_bias=None,
        layer_head_mask=None,
        past_key_value=None,
        use_cache=False,
        output_attentions=False,
        delete_gate_mask=None,
    ):
        """Return ``(updated_hidden_states, *extra attention outputs)``."""
        attention_outputs = self.SelfAttention(
            self.layer_norm(hidden_states),
            mask=attention_mask,
            position_bias=position_bias,
            layer_head_mask=layer_head_mask,
            past_key_value=past_key_value,
            use_cache=use_cache,
            output_attentions=output_attentions,
            delete_gate_mask=delete_gate_mask,
        )
        attended, *extras = attention_outputs
        # Residual connection; extras carry the cache / position bias /
        # optional attention details straight through.
        return (hidden_states + self.dropout(attended), *extras)


class MrT5LayerCrossAttention(nn.Module):
    """
    Pre-norm cross-attention sub-layer used by MrT5 decoder blocks.

    Mirrors T5LayerCrossAttention: the decoder states are layer-normed and
    attend over ``key_value_states``, and the dropped-out attention output is
    added back onto the un-normed decoder states. The attention module is
    MrT5Attention, so a ``delete_gate_mask`` can be threaded through.
    """

    def __init__(self, config):
        super().__init__()
        # Attribute names (and their registration order) are load-bearing:
        # they define the state_dict keys of pretrained checkpoints.
        self.EncDecAttention = MrT5Attention(config, has_relative_attention_bias=False)
        self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
        self.dropout = nn.Dropout(config.dropout_rate)

    def forward(
        self,
        hidden_states,
        key_value_states,
        attention_mask=None,
        position_bias=None,
        layer_head_mask=None,
        past_key_value=None,
        use_cache=False,
        query_length=None,
        output_attentions=False,
        delete_gate_mask=None,
    ):
        """Return ``(updated_hidden_states, *extra attention outputs)``."""
        attention_outputs = self.EncDecAttention(
            self.layer_norm(hidden_states),
            mask=attention_mask,
            key_value_states=key_value_states,
            position_bias=position_bias,
            layer_head_mask=layer_head_mask,
            past_key_value=past_key_value,
            use_cache=use_cache,
            query_length=query_length,
            output_attentions=output_attentions,
            delete_gate_mask=delete_gate_mask,
        )
        attended, *extras = attention_outputs
        # Residual connection; extras are passed through unchanged.
        return (hidden_states + self.dropout(attended), *extras)


class MrT5Block(nn.Module):
    """
    Modified version of T5Block that uses MrT5LayerSelfAttention and
    MrT5LayerCrossAttention instead of T5LayerSelfAttention and
    T5LayerCrossAttention.

    A block built with ``has_delete_gate=True`` also owns a delete gate whose
    class is chosen by ``config.deletion_type``. In ``forward`` the gate scores
    every token. When ``hard_delete`` is truthy, tokens whose gate value is
    not strictly above the deletion threshold are removed before
    self-attention. Hidden states, position bias and attention mask are then
    shrunk to the longest surviving sequence in the batch.
    """

    def __init__(self, config, has_relative_attention_bias=False,
                 #### NEW CODE ####
                 has_delete_gate=False,
                 #### NEW CODE ####
                 ):
        super().__init__()
        self.is_decoder = config.is_decoder
        self.layer = nn.ModuleList()
        #### NEW CODE ####
        # Use MrT5LayerSelfAttention and MrT5LayerCrossAttention
        # instead of T5LayerSelfAttention and T5LayerCrossAttention
        self.layer.append(MrT5LayerSelfAttention(
            config, has_relative_attention_bias=has_relative_attention_bias))
        if self.is_decoder:
            self.layer.append(MrT5LayerCrossAttention(config))
        #### NEW CODE ####

        self.layer.append(T5LayerFF(config))

        #### NEW CODE ####
        # Add delete gate if needed
        self.has_delete_gate = has_delete_gate
        if self.has_delete_gate:
            if config.deletion_type == "scaled_sigmoid":
                self.delete_gate = SigmoidDeleteGate(config)
            elif config.deletion_type == "log_sigmoid":
                self.delete_gate = LogSigmoidDeleteGate(config)
            elif config.deletion_type == "random":
                self.delete_gate = RandomDeleteGate(config)
            elif config.deletion_type == "fixed":
                self.delete_gate = FixedDeleteGate(config)
            else:
                raise ValueError(
                    f"Invalid deletion type: {config.deletion_type}")

        # Set hard_delete flags
        self.sigmoid_mask_scale = config.sigmoid_mask_scale
        # Default threshold; forward() may override it per call.
        self.deletion_threshold = config.deletion_threshold
        #### NEW CODE ####

    #### NEW CODE ####

    @staticmethod
    def _check_not_all_deleted(delete_gate_values, threshold):
        """
        Raise ValueError if no gate value in the whole batch exceeds ``threshold``.

        A token survives iff its gate value is strictly greater than the
        threshold. This is the same rule ``__get_new_positions_and_mask`` uses,
        so the check fires exactly when hard deletion would leave nothing. An
        empty result would make the new length 0 and break the scatter below.
        """
        if not (delete_gate_values > threshold).any():
            raise ValueError("All tokens are deleted in this batch. "
                             "Please adjust the deletion rate or "
                             "alpha hyperparameter.")

    def __get_new_positions_and_mask(self, batch_size, seq_len, delete_gate_mask, deletion_threshold, device):
        """
        Map every slot of the shortened sequence back to its source token.

        Args:
            batch_size, seq_len: sizes of the sequence before deletion.
            delete_gate_mask: gate values with a trailing singleton dim that is
                squeezed away (presumably (batch, seq_len, 1) -- TODO confirm).
            deletion_threshold: tokens with a value strictly above it are kept;
                ``None`` falls back to ``self.deletion_threshold``.
            device: device for the index/mask tensors.

        Returns:
            ``(src_side_pos, new_mask)``:
            - ``src_side_pos`` is a (batch, new_len) long tensor of source
              indices. ``new_len`` is the largest kept-token count in the batch.
            - ``new_mask`` is a (batch, new_len, 1) float tensor. It is 0 for
              occupied slots and -1e9 for padding slots. Padding slots point at
              source index 0.
        """
        delete_gate_mask = delete_gate_mask.squeeze(-1)

        # Create filter from delete gate mask
        deletion_threshold = deletion_threshold if deletion_threshold is not None else self.deletion_threshold
        keep_this = delete_gate_mask > deletion_threshold

        # Calculate the target position for each token
        target_pos = torch.cumsum(keep_this, dim=1) - 1
        new_len = target_pos[:, -1].max().item() + 1

        # Clamp the target position to avoid out of bounds when deleting everything
        # NOTE(review): the mask below is built from the *clamped* positions,
        # so a row whose tokens are all deleted still leaves slot 0 unmasked
        # (pointing at source token 0). Confirm this is intended.
        target_pos = target_pos.clamp(min=0)

        # Map the positions to the src side. Do this in int32, because it's faster and we will not have sequences
        # longer than 2^31. Deleted tokens contribute 0 to whichever slot they
        # collide with, so each slot ends up holding its kept token's index.
        positions = torch.arange(seq_len, device=device, dtype=torch.int32).repeat(batch_size, 1)
        positions *= keep_this.int()

        src_side_pos = torch.zeros(batch_size, new_len, device=device, dtype=torch.int32)
        src_side_pos.scatter_add_(1, target_pos, positions)

        # Create the new mask: slots past a row's last kept token are padding
        new_mask = torch.arange(new_len, device=device).expand(batch_size, -1) <= target_pos[:, -1:]
        new_mask = (~new_mask).float() * -1e9
        new_mask = new_mask.unsqueeze(-1)

        return src_side_pos.long(), new_mask

    def __hard_delete_hidden_states(self, hidden_states, positions):
        """Gather ``hidden_states[b, positions[b, j], :]`` into slot ``j`` (dim 1)."""
        new_hidden_states = torch.gather(hidden_states, 1, positions.unsqueeze(2).expand(-1, -1, hidden_states.size(2)))
        return new_hidden_states

    def __hard_delete_4_dimensions(self, position_bias, positions):
        """Gather dim 1 of a 4-D tensor at ``positions``, keeping dims 2 and 3 whole."""
        new_position_bias = torch.gather(position_bias, 1, positions.unsqueeze(2).unsqueeze(3).expand(-1, -1, position_bias.size(2), position_bias.size(3)))
        return new_position_bias

    #### NEW CODE ####

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        position_bias=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        encoder_decoder_position_bias=None,
        layer_head_mask=None,
        cross_attn_layer_head_mask=None,
        past_key_value=None,
        use_cache=False,
        output_attentions=False,
        return_dict=True,
        #### NEW CODE ####
        delete_gate_mask=None,
        input_ids=None,
        hard_delete=None,
        deletion_threshold=None,
        #### NEW CODE ####
    ):
        """
        Run self-attention, optional cross-attention and the feed-forward layer.

        Extra arguments:
            delete_gate_mask: mask forwarded to encoder self-attention and to
                cross-attention. When this block has a delete gate, it is
                replaced by the gate's values (or, with ``hard_delete``, by the
                additive padding mask of the shortened sequence).
            input_ids: passed to the delete gate.
            hard_delete: if truthy (and the block has a gate), physically drop
                deleted tokens before self-attention.
            deletion_threshold: per-call override of ``self.deletion_threshold``.
            return_dict: accepted for signature compatibility; not used here.

        Returns:
            A tuple of hidden_states, then present_key_value_state (only if
            ``use_cache``), then the attention outputs. Blocks with a delete
            gate additionally append ``(delete_gate_values, delete_gate_logits,
            delete_gate_mask, attention_mask)``.

        Raises:
            ValueError: wrong number of past states, or every token in the
                batch is deleted.
        """
        if past_key_value is not None:
            if not self.is_decoder:
                logger.warning(
                    "`past_key_values` is passed to the encoder. Please make sure this is intended.")
            expected_num_past_key_values = 2 if encoder_hidden_states is None else 4

            if len(past_key_value) != expected_num_past_key_values:
                raise ValueError(
                    f"There should be {expected_num_past_key_values} past states. "
                    f"{'2 (key / value) for cross attention. ' if expected_num_past_key_values == 4 else ''}"
                    f"Got {len(past_key_value)} past key / value states"
                )

            self_attn_past_key_value = past_key_value[:2]
            cross_attn_past_key_value = past_key_value[2:]
        else:
            self_attn_past_key_value, cross_attn_past_key_value = None, None

        ##### NEW CODE #####
        # Initialize delete gate values and logits for logging/loss calculation
        delete_gate_values = None
        delete_gate_logits = None

        if self.has_delete_gate:
            delete_gate_values, delete_gate_logits = self.delete_gate(
                hidden_states, input_ids)
            delete_gate_mask = delete_gate_values

            # Resolve the threshold once so the sanity check and the hard
            # deletion below agree on which tokens count as deleted.
            threshold = (deletion_threshold if deletion_threshold is not None
                         else self.deletion_threshold)

            # Raise error if every token in the whole batch is deleted
            self._check_not_all_deleted(delete_gate_values, threshold)

            # Apply hard deletion
            if hard_delete:

                # Compute new token positions
                new_positions, delete_gate_mask = self.__get_new_positions_and_mask(
                    hidden_states.size(0), hidden_states.size(1), delete_gate_mask, threshold, hidden_states.device)

                # Compute new position bias: shrink the query axis, then the
                # key axis, then restore (batch, heads, query, key).
                if position_bias is not None:
                    new_position_bias = self.__hard_delete_4_dimensions(
                        position_bias.permute(0, 2, 3, 1), new_positions)
                    new_position_bias = self.__hard_delete_4_dimensions(
                        new_position_bias.permute(0, 2, 1, 3), new_positions)
                    position_bias = new_position_bias.permute(0, 3, 2, 1)

                # Compute new attention mask (shrink its last axis)
                if attention_mask is not None:
                    new_attention_mask = self.__hard_delete_4_dimensions(
                        attention_mask.permute(0, 3, 1, 2), new_positions)
                    attention_mask = new_attention_mask.permute(0, 2, 3, 1)

                # Compute new hidden states
                hidden_states = self.__hard_delete_hidden_states(
                    hidden_states, new_positions)

        ##### NEW CODE #####

        self_attention_outputs = self.layer[0](
            hidden_states,
            attention_mask=attention_mask,
            position_bias=position_bias,
            layer_head_mask=layer_head_mask,
            past_key_value=self_attn_past_key_value,
            use_cache=use_cache,
            output_attentions=output_attentions,
            #### NEW CODE ####
            # Only apply delete_gate_mask to self-attention if the block
            # is the encoder
            delete_gate_mask=None if self.is_decoder else delete_gate_mask,
            #### NEW CODE ####
        )
        hidden_states, present_key_value_state = self_attention_outputs[:2]
        # Keep self-attention outputs and relative position weights
        attention_outputs = self_attention_outputs[2:]

        # clamp inf values to enable fp16 training
        if hidden_states.dtype == torch.float16:
            clamp_value = torch.where(
                torch.isinf(hidden_states).any(),
                torch.finfo(hidden_states.dtype).max - 1000,
                torch.finfo(hidden_states.dtype).max,
            )
            hidden_states = torch.clamp(
                hidden_states, min=-clamp_value, max=clamp_value)

        do_cross_attention = self.is_decoder and encoder_hidden_states is not None
        if do_cross_attention:
            # the actual query length is unknown for cross attention
            # if using past key value states. Need to inject it here
            if present_key_value_state is not None:
                query_length = present_key_value_state[0].shape[2]
            else:
                query_length = None

            cross_attention_outputs = self.layer[1](
                hidden_states,
                key_value_states=encoder_hidden_states,
                attention_mask=encoder_attention_mask,
                position_bias=encoder_decoder_position_bias,
                layer_head_mask=cross_attn_layer_head_mask,
                past_key_value=cross_attn_past_key_value,
                query_length=query_length,
                use_cache=use_cache,
                output_attentions=output_attentions,
                #### NEW CODE ####
                delete_gate_mask=delete_gate_mask,
                #### NEW CODE ####
            )
            hidden_states = cross_attention_outputs[0]

            # clamp inf values to enable fp16 training
            if hidden_states.dtype == torch.float16:
                clamp_value = torch.where(
                    torch.isinf(hidden_states).any(),
                    torch.finfo(hidden_states.dtype).max - 1000,
                    torch.finfo(hidden_states.dtype).max,
                )
                hidden_states = torch.clamp(
                    hidden_states, min=-clamp_value, max=clamp_value)

            # Combine self attn and cross attn key value states
            if present_key_value_state is not None:
                present_key_value_state = present_key_value_state + \
                    cross_attention_outputs[1]

            # Keep cross-attention outputs and relative position weights
            attention_outputs = attention_outputs + cross_attention_outputs[2:]

        # Apply Feed Forward layer
        hidden_states = self.layer[-1](hidden_states)

        # clamp inf values to enable fp16 training
        if hidden_states.dtype == torch.float16:
            clamp_value = torch.where(
                torch.isinf(hidden_states).any(),
                torch.finfo(hidden_states.dtype).max - 1000,
                torch.finfo(hidden_states.dtype).max,
            )
            hidden_states = torch.clamp(
                hidden_states, min=-clamp_value, max=clamp_value)

        outputs = (hidden_states,)

        if use_cache:
            outputs = outputs + (present_key_value_state,) + attention_outputs
        else:
            outputs = outputs + attention_outputs

        ##### NEW CODE #####
        if self.has_delete_gate:
            outputs = outputs + \
                (delete_gate_values, delete_gate_logits, delete_gate_mask, attention_mask)
        ##### NEW CODE #####

        # hidden-states, present_key_value_states, (self-attention position bias), (self-attention weights),
        # (cross-attention position bias), (cross-attention weights),
        # (delete_gate_values, delete_gate_logits, delete_gate_mask, attention_mask) -- gate blocks only
        return outputs


class MrT5Stack(T5Stack):
    """
    T5Stack whose blocks are MrT5Blocks.

    Only block 0 owns the relative attention bias. In the encoder, the block at
    index ``config.delete_gate_layer`` also hosts a delete gate. Decoder blocks
    never have one.
    """

    def __init__(self, config, embed_tokens=None):
        super().__init__(config, embed_tokens)

        ##### NEW CODE #####
        # Replace the stock T5 blocks. The short-circuit `and` keeps decoder
        # configs from needing a `delete_gate_layer` attribute.
        self.block = nn.ModuleList(
            [
                MrT5Block(
                    config,
                    # Only the first layer has relative attention bias
                    has_relative_attention_bias=bool(i == 0),
                    # Add delete gate if specified (encoder only)
                    has_delete_gate=(not self.is_decoder)
                    and bool(i == config.delete_gate_layer),
                )
                for i in range(config.num_layers)
            ]
        )
        ##### NEW CODE #####

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        inputs_embeds=None,
        head_mask=None,
        cross_attn_head_mask=None,
        past_key_values=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        #### NEW CODE ####
        delete_gate_mask=None,
        delete_gate_output=None,
        delete_gate_logits=None,
        hard_delete=None,
        deletion_threshold=None,
        #### NEW CODE ####
    ):
        """
        Run the stack; see T5Stack.forward for the standard arguments.

        Extra arguments:
            delete_gate_mask: passed to every block. After a block with a
                delete gate runs, it is replaced by that block's returned mask.
            delete_gate_output, delete_gate_logits: initial values. They are
                overwritten by the gate block's outputs and are returned as-is
                when no block has a gate.
            hard_delete: forwarded to every block. When truthy and a gate block
                runs, the extended self-attention mask is replaced by the
                resized one. The returned ``attention_mask`` is then a 0/1 int
                mask (1 where the extended mask is 0).
            deletion_threshold: forwarded to every block (per-call override).
        """
        # Model parallel
        if self.model_parallel:
            torch.cuda.set_device(self.first_device)
            self.embed_tokens = self.embed_tokens.to(self.first_device)
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if input_ids is not None and inputs_embeds is not None:
            err_msg_prefix = "decoder_" if self.is_decoder else ""
            raise ValueError(
                f"You cannot specify both {err_msg_prefix}input_ids and {err_msg_prefix}inputs_embeds at the same time"
            )
        elif input_ids is not None:
            input_shape = input_ids.size()
            input_ids = input_ids.view(-1, input_shape[-1])
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            err_msg_prefix = "decoder_" if self.is_decoder else ""
            raise ValueError(
                f"You have to specify either {err_msg_prefix}input_ids or {err_msg_prefix}inputs_embeds")

        if inputs_embeds is None:
            if self.embed_tokens is None:
                raise ValueError(
                    "You have to initialize the model with valid token embeddings")
            inputs_embeds = self.embed_tokens(input_ids)

        # Number of positions already held in the KV cache (0 without a cache).
        past_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0

        #### NEW CODE ####
        # `absolute_pos_embed` is not set in this class; look it up defensively
        # (presumably provided by the base stack -- confirm).
        absolute_pos_embed = getattr(self, "absolute_pos_embed", None)
        if absolute_pos_embed is not None:
            # Offset by the cached length so incremental decoding steps get
            # their true absolute positions instead of restarting at 0.
            position_ids = torch.arange(
                past_length, past_length + input_shape[-1],
                dtype=torch.long, device=inputs_embeds.device)
            position_embeds = absolute_pos_embed(position_ids)
            inputs_embeds = inputs_embeds + position_embeds
        #### NEW CODE ####

        batch_size, seq_length = input_shape

        # required mask seq length can be calculated via length of past
        mask_seq_length = past_length + seq_length

        if use_cache is True:
            if not self.is_decoder:
                raise ValueError(
                    f"`use_cache` can only be set to `True` if {self} is used as a decoder")

        # initialize past_key_values with `None` if past does not exist
        if past_key_values is None:
            past_key_values = [None] * len(self.block)

        if attention_mask is None:
            attention_mask = torch.ones(
                batch_size, mask_seq_length, device=inputs_embeds.device)

        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
        # ourselves in which case we just need to make it broadcastable to all heads.
        extended_attention_mask = self.get_extended_attention_mask(
            attention_mask, input_shape)

        # If a 2D or 3D attention mask is provided for the cross-attention
        # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
        if self.is_decoder and encoder_hidden_states is not None:
            encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
            encoder_hidden_shape = (
                encoder_batch_size, encoder_sequence_length)
            if encoder_attention_mask is None:
                encoder_attention_mask = torch.ones(
                    encoder_hidden_shape, device=inputs_embeds.device, dtype=torch.long
                )
            encoder_extended_attention_mask = self.invert_attention_mask(
                encoder_attention_mask)
        else:
            encoder_extended_attention_mask = None

        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        #### NEW CODE ####
        # Return a new encoder attention mask if hard delete is enabled
        attention_mask_to_return = None
        #### NEW CODE ####

        # Prepare head mask if needed
        head_mask = self.get_head_mask(head_mask, self.config.num_layers)
        cross_attn_head_mask = self.get_head_mask(
            cross_attn_head_mask, self.config.num_layers)
        present_key_value_states = () if use_cache else None
        all_hidden_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None
        all_cross_attentions = () if (output_attentions and self.is_decoder) else None
        position_bias = None
        encoder_decoder_position_bias = None

        #### NEW CODE ####
        all_queries = () if output_attentions else None
        all_keys = () if output_attentions else None
        all_values = () if output_attentions else None
        all_scores = () if output_attentions else None
        all_cross_attn_queries = () if (output_attentions and self.is_decoder) else None
        all_cross_attn_keys = () if (output_attentions and self.is_decoder) else None
        all_cross_attn_values = () if (output_attentions and self.is_decoder) else None
        all_cross_attn_scores = () if (output_attentions and self.is_decoder) else None
        #### NEW CODE ####

        hidden_states = self.dropout(inputs_embeds)

        for i, (layer_module, past_key_value) in enumerate(zip(self.block, past_key_values)):
            layer_head_mask = head_mask[i]
            cross_attn_layer_head_mask = cross_attn_head_mask[i]
            # Model parallel
            if self.model_parallel:
                torch.cuda.set_device(hidden_states.device)
                # Ensure that attention_mask is always on the same device as hidden_states
                if attention_mask is not None:
                    attention_mask = attention_mask.to(hidden_states.device)
                # The blocks consume the extended mask and the delete-gate
                # mask, so those must follow hidden_states as well.
                if extended_attention_mask is not None:
                    extended_attention_mask = extended_attention_mask.to(
                        hidden_states.device)
                if delete_gate_mask is not None:
                    delete_gate_mask = delete_gate_mask.to(hidden_states.device)
                if position_bias is not None:
                    position_bias = position_bias.to(hidden_states.device)
                if encoder_hidden_states is not None:
                    encoder_hidden_states = encoder_hidden_states.to(
                        hidden_states.device)
                if encoder_extended_attention_mask is not None:
                    encoder_extended_attention_mask = encoder_extended_attention_mask.to(
                        hidden_states.device)
                if encoder_decoder_position_bias is not None:
                    encoder_decoder_position_bias = encoder_decoder_position_bias.to(
                        hidden_states.device)
                if layer_head_mask is not None:
                    layer_head_mask = layer_head_mask.to(hidden_states.device)
                if cross_attn_layer_head_mask is not None:
                    cross_attn_layer_head_mask = cross_attn_layer_head_mask.to(
                        hidden_states.device)
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            if self.gradient_checkpointing and self.training:
                # Arguments are positional, so they must follow
                # MrT5Block.forward's signature exactly.
                layer_outputs = self._gradient_checkpointing_func(
                    layer_module.forward,
                    hidden_states,
                    extended_attention_mask,
                    position_bias,
                    encoder_hidden_states,
                    encoder_extended_attention_mask,
                    encoder_decoder_position_bias,
                    layer_head_mask,
                    cross_attn_layer_head_mask,
                    None,  # past_key_value is always None with gradient checkpointing
                    use_cache,
                    output_attentions,
                    #### NEW CODE ####
                    True,  # return_dict (accepted but unused by MrT5Block.forward)
                    delete_gate_mask,
                    input_ids,
                    hard_delete,
                    deletion_threshold,
                    #### NEW CODE ####
                )
            else:
                layer_outputs = layer_module(
                    hidden_states,
                    attention_mask=extended_attention_mask,
                    position_bias=position_bias,
                    encoder_hidden_states=encoder_hidden_states,
                    encoder_attention_mask=encoder_extended_attention_mask,
                    encoder_decoder_position_bias=encoder_decoder_position_bias,
                    layer_head_mask=layer_head_mask,
                    cross_attn_layer_head_mask=cross_attn_layer_head_mask,
                    past_key_value=past_key_value,
                    use_cache=use_cache,
                    output_attentions=output_attentions,
                    #### NEW CODE ####
                    delete_gate_mask=delete_gate_mask,
                    input_ids=input_ids,
                    hard_delete=hard_delete,
                    deletion_threshold=deletion_threshold,
                    #### NEW CODE ####
                )

            #### NEW CODE ####
            # If this layer has a delete gate, its last four outputs are the
            # gate values, logits, updated mask and (possibly resized) mask.
            if layer_module.has_delete_gate:
                delete_gate_output, delete_gate_logits, delete_gate_mask, new_attention_mask = layer_outputs[-4], layer_outputs[-3], layer_outputs[-2], layer_outputs[-1]

                # Subsequent layers must use the resized mask after a hard deletion
                if hard_delete:
                    extended_attention_mask = new_attention_mask
                    attention_mask_to_return = extended_attention_mask.squeeze(-2).squeeze(-2)
                    attention_mask_to_return = (attention_mask_to_return == 0).int()

            #### NEW CODE ####

            # layer_outputs is a tuple with:
            # hidden-states, key-value-states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights)
            if use_cache is False:
                layer_outputs = layer_outputs[:1] + (None,) + layer_outputs[1:]

            hidden_states, present_key_value_state = layer_outputs[:2]

            # We share the position biases between the layers - the first layer store them
            # layer_outputs = hidden-states, key-value-states (self-attention position bias), (self-attention weights),
            # (cross-attention position bias), (cross-attention weights)
            position_bias = layer_outputs[2]
            if self.is_decoder and encoder_hidden_states is not None:
                #### NEW CODE ####
                index = 4 if output_attentions else 3
                encoder_decoder_position_bias = layer_outputs[index]
                #### NEW CODE ####
            # append next layer key value states
            if use_cache:
                present_key_value_states = present_key_value_states + \
                    (present_key_value_state,)

            #### NEW CODE ####
            if output_attentions:
                attn_weights, keys, queries, values, scores = layer_outputs[3]
                all_attentions = all_attentions + (attn_weights,)
                all_queries = all_queries + (queries,)
                all_keys = all_keys + (keys,)
                all_values = all_values + (values,)
                all_scores = all_scores + (scores,)

                if self.is_decoder:
                    cross_attn_weights, cross_attn_keys, cross_attn_queries, \
                        cross_attn_values, cross_attn_scores = layer_outputs[5]
                    all_cross_attentions = all_cross_attentions + \
                        (cross_attn_weights,)
                    all_cross_attn_queries = all_cross_attn_queries + \
                        (cross_attn_queries,)
                    all_cross_attn_keys = all_cross_attn_keys + \
                        (cross_attn_keys,)
                    all_cross_attn_values = all_cross_attn_values + \
                        (cross_attn_values,)
                    all_cross_attn_scores = all_cross_attn_scores + \
                        (cross_attn_scores,)
            #### NEW CODE ####

            # Model Parallel: If it's the last layer for that device, put things on the next device
            if self.model_parallel:
                for k, v in self.device_map.items():
                    if i == v[-1] and "cuda:" + str(k) != self.last_device:
                        hidden_states = hidden_states.to("cuda:" + str(k + 1))

        hidden_states = self.final_layer_norm(hidden_states)
        hidden_states = self.dropout(hidden_states)

        # Add last layer
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(
                v
                for v in [
                    hidden_states,
                    present_key_value_states,
                    all_hidden_states,
                    all_attentions,
                    all_cross_attentions,
                    #### NEW CODE ####
                    delete_gate_mask,
                    delete_gate_output,
                    delete_gate_logits,
                    attention_mask_to_return,
                    all_queries,
                    all_keys,
                    all_values,
                    all_scores,
                    all_cross_attn_queries,
                    all_cross_attn_keys,
                    all_cross_attn_values,
                    all_cross_attn_scores,
                    #### NEW CODE ####
                ]
                if v is not None
            )

        return MrT5BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=present_key_value_states,
            hidden_states=all_hidden_states,
            attentions=all_attentions,
            cross_attentions=all_cross_attentions,
            #### NEW CODE ####
            delete_gate_mask=delete_gate_mask,
            delete_gate_output=delete_gate_output,
            delete_gate_logits=delete_gate_logits,
            attention_mask=attention_mask_to_return,
            attention_queries=all_queries,
            attention_keys=all_keys,
            attention_values=all_values,
            attention_scores=all_scores,
            cross_attention_queries=all_cross_attn_queries,
            cross_attention_keys=all_cross_attn_keys,
            cross_attention_values=all_cross_attn_values,
            cross_attention_scores=all_cross_attn_scores,
            #### NEW CODE ####
        )


class MrT5ForConditionalGeneration(T5ForConditionalGeneration):
    
    config_class = MrT5Config

    def __init__(self, config: MrT5Config):
        """Build the seq2seq model, then replace both stacks with ``MrT5Stack``.

        ``T5ForConditionalGeneration.__init__`` creates stock stacks first.
        They are overwritten here by ``MrT5Stack`` instances built from
        per-stack deep copies of ``config``; ``config`` itself is not mutated.
        Both new stacks receive the parent's ``self.shared`` module.
        """
        super().__init__(config)
        #### NEW CODE ####
        # Encoder stack: is_decoder / use_cache / is_encoder_decoder all off.
        enc_cfg = copy.deepcopy(config)
        for attr_name, attr_value in (
            ("is_decoder", False),
            ("use_cache", False),
            ("is_encoder_decoder", False),
        ):
            setattr(enc_cfg, attr_name, attr_value)
        self.encoder = MrT5Stack(enc_cfg, self.shared)

        # Decoder stack: is_decoder on, depth taken from num_decoder_layers.
        dec_cfg = copy.deepcopy(config)
        for attr_name, attr_value in (
            ("is_decoder", True),
            ("is_encoder_decoder", False),
            ("num_layers", config.num_decoder_layers),
        ):
            setattr(dec_cfg, attr_name, attr_value)
        self.decoder = MrT5Stack(dec_cfg, self.shared)
        #### NEW CODE ####
        #### NEW CODE ####

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.BoolTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        decoder_head_mask: Optional[torch.FloatTensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        encoder_outputs: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        #### NEW CODE ####
        hard_delete: bool = False,
        deletion_threshold: Optional[float] = None,
        #### NEW CODE ####
    ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]:
        """Encode (if needed), decode, and optionally compute the LM loss.

        Follows ``T5ForConditionalGeneration.forward`` with MrT5 additions:
        ``hard_delete`` and ``deletion_threshold`` are passed to the encoder
        stack, the encoder's ``delete_gate_mask`` is passed to the decoder, and
        an ``attention_mask`` returned by the encoder (when present) replaces
        the caller's mask for cross-attention.

        Args:
            encoder_outputs: Optional precomputed encoder result. Accepted as a
                ``MrT5BaseModelOutputWithPastAndCrossAttentions`` (used as-is),
                any other mapping with a ``"last_hidden_state"`` key, or a
                tuple/list whose element 0 is the last hidden state.
            hard_delete: Forwarded to the encoder ``MrT5Stack``.
            deletion_threshold: Forwarded to the encoder ``MrT5Stack``.
            Other arguments: same meaning as in
                ``T5ForConditionalGeneration.forward``.

        Returns:
            ``MrT5Seq2SeqLMOutput`` when ``return_dict`` is true. Otherwise a
            tuple ``(lm_logits, *decoder_outputs[1:], *encoder_outputs)``,
            prefixed with ``loss`` when ``labels`` is given.
        """
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
        if head_mask is not None and decoder_head_mask is None:
            if self.config.num_layers == self.config.num_decoder_layers:
                decoder_head_mask = head_mask

        #### NEW CODE ####
        # Encode if needed (training, first prediction pass). Always ask the
        # encoder for its structured output: MrT5Stack's tuple form drops None
        # entries, so fields like delete_gate_mask / attention_mask could not
        # be located by position afterwards. The tuple form for the caller is
        # rebuilt below when return_dict is False.
        if encoder_outputs is None:
            encoder_outputs = self.encoder(
                input_ids=input_ids,
                attention_mask=attention_mask,
                inputs_embeds=inputs_embeds,
                head_mask=head_mask,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=True,
                hard_delete=hard_delete,
                deletion_threshold=deletion_threshold,
            )

        # A caller-supplied tuple is echoed back verbatim in tuple mode.
        encoder_outputs_tuple = (
            tuple(encoder_outputs)
            if isinstance(encoder_outputs, (tuple, list))
            else None
        )
        encoder_outputs = self._as_mrt5_encoder_output(encoder_outputs)

        hidden_states = encoder_outputs.last_hidden_state
        # The encoder may return its own mask (e.g. resized after hard
        # deletion); prefer it so it matches hidden_states.
        attention_mask = encoder_outputs.attention_mask if 'attention_mask' in encoder_outputs else attention_mask
        #### NEW CODE ####

        if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None:
            # get decoder inputs from shifting lm labels to the right
            decoder_input_ids = self._shift_right(labels)

        # Set device for model parallelism
        if self.model_parallel:
            torch.cuda.set_device(self.decoder.first_device)
            hidden_states = hidden_states.to(self.decoder.first_device)
            if decoder_input_ids is not None:
                decoder_input_ids = decoder_input_ids.to(
                    self.decoder.first_device)
            if attention_mask is not None:
                attention_mask = attention_mask.to(self.decoder.first_device)
            if decoder_attention_mask is not None:
                decoder_attention_mask = decoder_attention_mask.to(
                    self.decoder.first_device)

        # Decode
        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            inputs_embeds=decoder_inputs_embeds,
            past_key_values=past_key_values,
            encoder_hidden_states=hidden_states,
            encoder_attention_mask=attention_mask,
            head_mask=decoder_head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            #### NEW CODE ####
            delete_gate_mask=encoder_outputs.delete_gate_mask,
            #### NEW CODE ####
        )

        sequence_output = decoder_outputs[0]

        # Set device for model parallelism
        if self.model_parallel:
            torch.cuda.set_device(self.encoder.first_device)
            self.lm_head = self.lm_head.to(self.encoder.first_device)
            sequence_output = sequence_output.to(self.lm_head.weight.device)

        if self.config.tie_word_embeddings:
            # Rescale output before projecting on vocab
            # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
            sequence_output = sequence_output * (self.model_dim**-0.5)

        lm_logits = self.lm_head(sequence_output)

        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
            # move labels to correct device to enable PP
            labels = labels.to(lm_logits.device)
            loss = loss_fct(
                lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1))
            # TODO(thom): Add z_loss https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666

        if not return_dict:
            if encoder_outputs_tuple is None:
                # NOTE(review): assumes the dataclass field order of
                # MrT5BaseModelOutputWithPastAndCrossAttentions matches the
                # order MrT5Stack uses for its own tuple output — confirm.
                encoder_outputs_tuple = encoder_outputs.to_tuple()
            output = (lm_logits,) + tuple(decoder_outputs[1:]) + encoder_outputs_tuple
            return ((loss,) + output) if loss is not None else output

        ##### NEW CODE #####
        return MrT5Seq2SeqLMOutput(
            loss=loss,
            logits=lm_logits,
            past_key_values=decoder_outputs.past_key_values,
            decoder_hidden_states=decoder_outputs.hidden_states,
            decoder_attentions=decoder_outputs.attentions,
            cross_attentions=decoder_outputs.cross_attentions,
            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
            encoder_hidden_states=encoder_outputs.hidden_states,
            encoder_attentions=encoder_outputs.attentions,
            delete_gate_mask=encoder_outputs.delete_gate_mask,
            delete_gate_output=encoder_outputs.delete_gate_output,
            delete_gate_logits=encoder_outputs.delete_gate_logits,
            encoder_keys=encoder_outputs.attention_keys,
            encoder_queries=encoder_outputs.attention_queries,
            encoder_values=encoder_outputs.attention_values,
            encoder_scores=encoder_outputs.attention_scores,
            decoder_keys=decoder_outputs.attention_keys,
            decoder_queries=decoder_outputs.attention_queries,
            decoder_values=decoder_outputs.attention_values,
            decoder_scores=decoder_outputs.attention_scores,
            cross_attention_queries=decoder_outputs.cross_attention_queries,
            cross_attention_keys=decoder_outputs.cross_attention_keys,
            cross_attention_values=decoder_outputs.cross_attention_values,
            cross_attention_scores=decoder_outputs.cross_attention_scores,
        )
        ##### NEW CODE #####

    @staticmethod
    def _as_mrt5_encoder_output(encoder_outputs):
        """Normalize precomputed encoder outputs to the MrT5 output class.

        - ``MrT5BaseModelOutputWithPastAndCrossAttentions``: returned unchanged,
          so every field (attention_mask, delete-gate tensors, probes) survives.
        - tuple / list: only element 0 (last hidden state) is used, because
          ``MrT5Stack`` drops ``None`` entries from its tuple output and the
          remaining positions are therefore ambiguous.
        - any other mapping (``BaseModelOutput``, ``dict``, ...): read by key,
          keeping hidden_states, attentions and delete_gate_mask when present.
        """
        if isinstance(encoder_outputs, MrT5BaseModelOutputWithPastAndCrossAttentions):
            return encoder_outputs
        if isinstance(encoder_outputs, (tuple, list)):
            return MrT5BaseModelOutputWithPastAndCrossAttentions(
                last_hidden_state=encoder_outputs[0],
            )
        return MrT5BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=encoder_outputs["last_hidden_state"],
            hidden_states=encoder_outputs.get("hidden_states"),
            attentions=encoder_outputs.get("attentions"),
            delete_gate_mask=encoder_outputs.get("delete_gate_mask"),
        )

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        head_mask=None,
        decoder_head_mask=None,
        decoder_attention_mask=None,
        cross_attn_head_mask=None,
        use_cache=None,
        encoder_outputs=None,
        **kwargs,
    ):
        """Assemble the keyword arguments for one decoding step of ``generate``.

        When a key/value cache is present, ``input_ids`` is trimmed so only the
        not-yet-cached positions are fed to the decoder. Any extra ``kwargs``
        are ignored, and the ``attention_mask`` argument is always replaced by
        ``None`` in the returned dict (see the TODO below).
        """
        if past_key_values is not None:
            # Length of the cached self-attention keys (dim 2 of layer 0's key).
            cached_len = past_key_values[0][0].shape[2]
            seq_len = input_ids.shape[1]
            # Some generation methods already pass only the newest ids; if not,
            # fall back to the old behavior of keeping just the final id.
            start = cached_len if seq_len > cached_len else seq_len - 1
            input_ids = input_ids[:, start:]

        ##### NEW CODE #####
        # TODO: Generation will need special handling of attention masks, which
        # will need to be resized if hard delete is enabled. For now, the
        # encoder attention mask is simply omitted during generation.
        ##### NEW CODE #####

        return dict(
            decoder_input_ids=input_ids,
            past_key_values=past_key_values,
            encoder_outputs=encoder_outputs,
            attention_mask=None,
            head_mask=head_mask,
            decoder_head_mask=decoder_head_mask,
            decoder_attention_mask=decoder_attention_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            use_cache=use_cache,
        )