juliekallini committed
Commit 3771cfd · verified · 1 Parent(s): a40c637

Upload modeling_mrt5.py with huggingface_hub

Files changed (1)
  1. modeling_mrt5.py +1352 -0
modeling_mrt5.py ADDED
@@ -0,0 +1,1352 @@
1
+ # modeling_mrt5.py
2
+ # Author: Julie Kallini
3
+ # Description: This file contains the implementation of the MrT5 model.
4
+ # The code is adapted from HuggingFace's modeling_t5.py. New code sections
5
+ # are marked with comments.
6
+
7
+ import torch
8
+ import copy
9
+ import numpy as np
10
+ from torch import nn
11
+ from models.modeling_t5 import (
12
+ T5Attention,
13
+ T5LayerNorm,
14
+ T5LayerFF,
15
+ T5Stack,
16
+ T5ForConditionalGeneration,
17
+ softmax1,
18
+ )
19
+ from configuration_mrt5 import MrT5Config
20
+ from transformers.modeling_outputs import (
21
+ BaseModelOutput,
22
+ BaseModelOutputWithPastAndCrossAttentions,
23
+ Seq2SeqLMOutput,
24
+ )
25
+ from transformers.utils import logging
26
+ from typing import Optional, Tuple, Union
27
+ from dataclasses import dataclass
28
+
29
+ logger = logging.get_logger(__name__)
30
+
31
+ @dataclass
32
+ class MrT5BaseModelOutputWithPastAndCrossAttentions(BaseModelOutputWithPastAndCrossAttentions):
33
+ delete_gate_mask: torch.FloatTensor = None
34
+ delete_gate_output: torch.FloatTensor = None
35
+ delete_gate_logits: torch.FloatTensor = None
36
+ attention_mask: torch.FloatTensor = None
37
+ attention_queries: torch.FloatTensor = None
38
+ attention_keys: torch.FloatTensor = None
39
+ attention_values: torch.FloatTensor = None
40
+ attention_scores: torch.FloatTensor = None
41
+ cross_attention_keys: torch.FloatTensor = None
42
+ cross_attention_queries: torch.FloatTensor = None
43
+ cross_attention_values: torch.FloatTensor = None
44
+ cross_attention_scores: torch.FloatTensor = None
45
+
46
+
47
+ @dataclass
48
+ class MrT5Seq2SeqLMOutput(Seq2SeqLMOutput):
49
+ delete_gate_mask: torch.FloatTensor = None
50
+ delete_gate_output: torch.FloatTensor = None
51
+ delete_gate_logits: torch.FloatTensor = None
52
+ encoder_keys: torch.FloatTensor = None
53
+ encoder_queries: torch.FloatTensor = None
54
+ encoder_values: torch.FloatTensor = None
55
+ encoder_scores: torch.FloatTensor = None
56
+ decoder_keys: torch.FloatTensor = None
57
+ decoder_queries: torch.FloatTensor = None
58
+ decoder_values: torch.FloatTensor = None
59
+ decoder_scores: torch.FloatTensor = None
60
+ cross_attention_keys: torch.FloatTensor = None
61
+ cross_attention_queries: torch.FloatTensor = None
62
+ cross_attention_values: torch.FloatTensor = None
63
+ cross_attention_scores: torch.FloatTensor = None
64
+
65
+
66
+ TORCH_INIT_FUNCTIONS = {
67
+ "uniform_": nn.init.uniform_,
68
+ "normal_": nn.init.normal_,
69
+ "trunc_normal_": nn.init.trunc_normal_,
70
+ "constant_": nn.init.constant_,
71
+ "xavier_uniform_": nn.init.xavier_uniform_,
72
+ "xavier_normal_": nn.init.xavier_normal_,
73
+ "kaiming_uniform_": nn.init.kaiming_uniform_,
74
+ "kaiming_normal_": nn.init.kaiming_normal_,
75
+ "uniform": nn.init.uniform,
76
+ "normal": nn.init.normal,
77
+ "xavier_uniform": nn.init.xavier_uniform,
78
+ "xavier_normal": nn.init.xavier_normal,
79
+ "kaiming_uniform": nn.init.kaiming_uniform,
80
+ "kaiming_normal": nn.init.kaiming_normal,
81
+ }
82
+
83
+ class ScaledSigmoid(nn.Module):
84
+ def __init__(self, sigmoid_mask_scale):
85
+ super().__init__()
86
+ self.sigmoid_mask_scale = sigmoid_mask_scale
87
+
88
+ def forward(self, input):
89
+ return self.sigmoid_mask_scale * torch.sigmoid(-input)
90
+
91
+ def gumbel_noise_like(x: torch.Tensor) -> torch.Tensor:
92
+ eps = 3e-4 if x.dtype == torch.float16 else 1e-10
93
+ uniform = torch.empty_like(x).uniform_(eps, 1 - eps)
94
+ return - (- uniform.log()).log()
95
+
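# A minimal standalone sketch (not part of the uploaded file) of what ScaledSigmoid
# and gumbel_noise_like compute, assuming an illustrative sigmoid_mask_scale of
# -30.0; the real value comes from MrT5Config.
import torch

sigmoid_mask_scale = -30.0
logits = torch.tensor([-5.0, 0.0, 5.0])

# ScaledSigmoid: sigmoid_mask_scale * sigmoid(-x). Confident "keep" logits (large,
# positive) map toward 0; confident "delete" logits map toward the mask scale,
# which later acts as a large negative additive bias on attention scores.
gate = sigmoid_mask_scale * torch.sigmoid(-logits)
print(gate)  # approximately tensor([-29.7995, -15.0000,  -0.2005])

# Gumbel noise as sampled above: -log(-log(U)) with U ~ Uniform(eps, 1 - eps),
# added to the gate logits during training when use_gumbel_noise is set.
eps = 1e-10
uniform = torch.empty_like(logits).uniform_(eps, 1 - eps)
gumbel_noise = -(-uniform.log()).log()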
96
+ class SigmoidDeleteGate(nn.Module):
97
+ def __init__(self, config):
98
+ super().__init__()
99
+ self.has_layer_norm = config.gate_layer_norm
100
+ if self.has_layer_norm:
101
+ self.layer_norm = T5LayerNorm(config.hidden_size)
102
+ self.feed_forward = nn.Linear(config.hidden_size, 1)
103
+ self._init_weights(self.feed_forward)
104
+ self.activation = ScaledSigmoid(config.sigmoid_mask_scale)
105
+ self.use_gumbel_noise = config.use_gumbel_noise
106
+
107
+ def forward(self, hidden_states, input_ids):
108
+ if self.has_layer_norm:
109
+ hidden_states = self.layer_norm(hidden_states)
110
+ delete_gate_logits = self.feed_forward(hidden_states)
111
+
112
+ # Add gumbel noise to the delete gate logits
113
+ if self.training and self.use_gumbel_noise:
114
+ gumbel_noise = gumbel_noise_like(delete_gate_logits)
115
+ delete_gate_logits += gumbel_noise
116
+
117
+ gate_values = self.activation(delete_gate_logits)
118
+
119
+ # Check if there are any pad tokens in input_ids
120
+ if (input_ids == 0).any():
121
+ # Set gate values for pad tokens (input_ids == 0) to sigmoid_mask_scale
122
+ pad_mask = (input_ids == 0).unsqueeze(-1)
123
+ gate_values = torch.where(pad_mask, torch.tensor(self.activation.sigmoid_mask_scale), gate_values)
124
+
125
+ return gate_values, delete_gate_logits
126
+
127
+ def _init_weights(self, m, init_func="xavier_uniform_"):
128
+ # Initialize the weights. This is necessary because
129
+ # HuggingFace disables initialization during "from_pretrained"
130
+ if isinstance(m, nn.Linear):
131
+ TORCH_INIT_FUNCTIONS[init_func](m.weight)
132
+ m.bias.data.fill_(1)
133
+
134
+
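# A minimal shape sketch (not part of the uploaded file) of SigmoidDeleteGate's
# forward pass, with hypothetical sizes; it reproduces the gate math inline rather
# than instantiating the class, and again assumes sigmoid_mask_scale = -30.0.
import torch

batch_size, seq_len, hidden_size = 2, 8, 16
sigmoid_mask_scale = -30.0

hidden_states = torch.randn(batch_size, seq_len, hidden_size)
input_ids = torch.randint(1, 100, (batch_size, seq_len))
input_ids[:, -2:] = 0  # pretend the last two positions are padding (token id 0)

# One logit per position, then the scaled sigmoid squashes it into (scale, 0).
feed_forward = torch.nn.Linear(hidden_size, 1)
delete_gate_logits = feed_forward(hidden_states)                    # (batch, seq, 1)
gate_values = sigmoid_mask_scale * torch.sigmoid(-delete_gate_logits)

# Pad positions are forced to the full mask scale, i.e. always treated as deleted.
pad_mask = (input_ids == 0).unsqueeze(-1)
gate_values = torch.where(pad_mask, torch.tensor(sigmoid_mask_scale), gate_values)
print(gate_values.shape)  # torch.Size([2, 8, 1])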
135
+ class LogSigmoidDeleteGate(SigmoidDeleteGate):
136
+ def __init__(self, config):
137
+ super().__init__(config)
138
+ self.activation = nn.LogSigmoid()
139
+
140
+ class RandomDeleteGate(nn.Module):
141
+ def __init__(self, config):
142
+ super().__init__()
143
+ # Store the sigmoid_mask_scale and the probability of activation
144
+ self.sigmoid_mask_scale = config.sigmoid_mask_scale
145
+ self.random_deletion_probability = config.random_deletion_probability
146
+
147
+ def __random_mask_tensor(self, x, n):
148
+ # Determine the shape for the output tensor
149
+ target_shape = (x.shape[0], x.shape[1], 1)
150
+ total_elements = x.shape[0] * x.shape[1]
151
+
152
+ # Create a flattened float tensor of all 0.0
153
+ flat_tensor = torch.zeros(total_elements, dtype=torch.float32, device=x.device)
154
+
155
+ # Randomly select n indices to be set to 1.0
156
+ indices = torch.randperm(total_elements)[:n]
157
+ flat_tensor[indices] = 1.0
158
+
159
+ # Reshape it to match the desired target shape
160
+ float_tensor = flat_tensor.view(target_shape)
161
+
162
+ return float_tensor
163
+
164
+ def forward(self, hidden_states, input_ids):
165
+ # Calculate the number of tokens to delete using a gaussian
166
+ deletion_percentage = np.random.normal(loc=self.random_deletion_probability, scale=0.05)
167
+ n_deletions = int(deletion_percentage * hidden_states.shape[0] * hidden_states.shape[1])
168
+
169
+ # Create a random mask with n_deletions True values
170
+ random_mask = self.__random_mask_tensor(hidden_states, n_deletions)
171
+
172
+ # Scale the mask by sigmoid_mask_scale
173
+ delete_gate_mask = random_mask * self.sigmoid_mask_scale
174
+ return delete_gate_mask, delete_gate_mask
175
+
176
+
177
+ class FixedDeleteGate(nn.Module):
178
+ def __init__(self, config):
179
+ super().__init__()
180
+ self.sigmoid_mask_scale = config.sigmoid_mask_scale
181
+ self.fixed_deletion_amount = config.fixed_deletion_amount
182
+ self.sep_tokens = torch.tensor([12, 13, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45,
183
+ 46, 47, 48, 49, 50, 61, 62, 63, 64, 65, 66, 67, 94,
184
+ 95, 96, 97, 98, 99, 126, 127, 128, 129, 1])
185
+
186
+ def __create_mask(self, input_ids):
187
+ device = input_ids.device
188
+ batch_size, seq_len = input_ids.size()
189
+ self.sep_tokens = self.sep_tokens.to(device)
190
+
191
+ # Create an initial mask filled with sigmoid_mask_scale
192
+ mask = torch.full((batch_size, seq_len), self.sigmoid_mask_scale, device=device)
193
+
194
+ # Find sep_token indices
195
+ is_sep = torch.isin(input_ids, self.sep_tokens)
196
+
197
+ # Create a tensor of segment lengths
198
+ sep_positions = torch.cumsum(is_sep, dim=1)
199
+ segment_lengths = torch.zeros_like(input_ids, dtype=torch.float)
200
+ segment_lengths[:, 1:] = (sep_positions[:, 1:] != sep_positions[:, :-1]).float()
201
+ segment_lengths[:, 0] = 1.0
202
+ segment_lengths = torch.cumsum(segment_lengths, dim=1)
203
+
204
+ # Calculate number of zeros for each segment
205
+ segment_counts = torch.bincount(sep_positions.view(-1), minlength=seq_len)
206
+ segment_starts = torch.cumsum(torch.cat([torch.tensor([0], device=device), segment_counts[:-1]]), dim=0)
207
+ segment_ends = torch.cumsum(segment_counts, dim=0)
208
+ num_zeros = torch.ceil((1 - self.fixed_deletion_amount) * (segment_ends - segment_starts)).long()
209
+
210
+ # Create the mask based on the calculated number of zeros
211
+ for i in range(batch_size):
212
+ for start, count in zip(segment_starts, num_zeros):
213
+ mask[i, start:start + count] = 0
214
+
215
+ return mask.to(torch.float)
216
+
217
+ def forward(self, hidden_states, input_ids):
218
+ delete_gate_mask = self.__create_mask(input_ids).unsqueeze(-1)
219
+ return delete_gate_mask, delete_gate_mask
220
+
221
+
222
+ class MrT5Attention(T5Attention):
223
+ """
224
+ Extends the T5Attention class to include a delete gate. Only the forward
225
+ method is modified. The delete_gate_mask passed to the forward function
226
+ is applied to the attention scores.
227
+ """
228
+
229
+ def __init__(self, config: MrT5Config, has_relative_attention_bias=False):
230
+ super().__init__(config, has_relative_attention_bias)
231
+ #### NEW CODE ####
232
+ self.use_softmax1 = config.use_softmax1
233
+ #### NEW CODE ####
234
+
235
+ def forward(
236
+ self,
237
+ hidden_states,
238
+ mask=None,
239
+ key_value_states=None,
240
+ position_bias=None,
241
+ past_key_value=None,
242
+ layer_head_mask=None,
243
+ query_length=None,
244
+ use_cache=False,
245
+ output_attentions=False,
246
+ #### NEW CODE ####
247
+ delete_gate_mask=None,
248
+ #### NEW CODE ####
249
+ ):
250
+ """
251
+ Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states).
252
+ """
253
+ # Input is (batch_size, seq_length, dim)
254
+ # Mask is (batch_size, key_length) (non-causal) or (batch_size, key_length, key_length)
255
+ # past_key_value[0] is (batch_size, n_heads, q_len - 1, dim_per_head)
256
+ batch_size, seq_length = hidden_states.shape[:2]
257
+
258
+ real_seq_length = seq_length
259
+
260
+ if past_key_value is not None:
261
+ if len(past_key_value) != 2:
262
+ raise ValueError(
263
+ f"past_key_value should have 2 past states: keys and values. Got { len(past_key_value)} past states"
264
+ )
265
+ real_seq_length += past_key_value[0].shape[2] if query_length is None else query_length
266
+
267
+ key_length = real_seq_length if key_value_states is None else key_value_states.shape[
268
+ 1]
269
+
270
+ def shape(states):
271
+ """projection"""
272
+ return states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)
273
+
274
+ def unshape(states):
275
+ """reshape"""
276
+ return states.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim)
277
+
278
+ def project(hidden_states, proj_layer, key_value_states, past_key_value):
279
+ """projects hidden states correctly to key/query states"""
280
+ if key_value_states is None:
281
+ # self-attn
282
+ # (batch_size, n_heads, seq_length, dim_per_head)
283
+ hidden_states = shape(proj_layer(hidden_states))
284
+ elif past_key_value is None:
285
+ # cross-attn
286
+ # (batch_size, n_heads, seq_length, dim_per_head)
287
+ hidden_states = shape(proj_layer(key_value_states))
288
+
289
+ if past_key_value is not None:
290
+ if key_value_states is None:
291
+ # self-attn
292
+ # (batch_size, n_heads, key_length, dim_per_head)
293
+ hidden_states = torch.cat(
294
+ [past_key_value, hidden_states], dim=2)
295
+ elif past_key_value.shape[2] != key_value_states.shape[1]:
296
+ # checking that the `sequence_length` of the `past_key_value` is the same as
297
+ # the provided `key_value_states` to support prefix tuning
298
+ # cross-attn
299
+ # (batch_size, n_heads, seq_length, dim_per_head)
300
+ hidden_states = shape(proj_layer(key_value_states))
301
+ else:
302
+ # cross-attn
303
+ hidden_states = past_key_value
304
+ return hidden_states
305
+
306
+ # get query states
307
+ # (batch_size, n_heads, seq_length, dim_per_head)
308
+ query_states = shape(self.q(hidden_states))
309
+
310
+ # get key/value states
311
+ key_states = project(
312
+ hidden_states, self.k, key_value_states, past_key_value[
313
+ 0] if past_key_value is not None else None
314
+ )
315
+ value_states = project(
316
+ hidden_states, self.v, key_value_states, past_key_value[
317
+ 1] if past_key_value is not None else None
318
+ )
319
+
320
+ # compute scores
321
+ scores = torch.matmul(
322
+ query_states, key_states.transpose(3, 2)
323
+ ) # equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9
324
+
325
+ #### NEW CODE ####
326
+ if not self.has_absolute_position_embeddings:
327
+ #### NEW CODE ####
328
+ if position_bias is None:
329
+ if not self.has_relative_attention_bias:
330
+ position_bias = torch.zeros(
331
+ (1, self.n_heads, real_seq_length, key_length), device=scores.device, dtype=scores.dtype
332
+ )
333
+ if self.gradient_checkpointing and self.training:
334
+ position_bias.requires_grad = True
335
+ else:
336
+ position_bias = self.compute_bias(
337
+ real_seq_length, key_length, device=scores.device)
338
+
339
+ # if key and values are already calculated
340
+ # we want only the last query position bias
341
+ if past_key_value is not None:
342
+ position_bias = position_bias[:, :, -hidden_states.size(1):, :]
343
+
344
+ if mask is not None:
345
+ # (batch_size, n_heads, seq_length, key_length)
346
+ position_bias = position_bias + mask
347
+
348
+ if self.pruned_heads:
349
+ mask = torch.ones(position_bias.shape[1])
350
+ mask[list(self.pruned_heads)] = 0
351
+ position_bias_masked = position_bias[:, mask.bool()]
352
+ else:
353
+ position_bias_masked = position_bias
354
+
355
+ scores = scores + position_bias_masked
356
+
357
+ #### NEW CODE ####
358
+ # If there is no position bias, add attention mask to scores directly
359
+ elif mask is not None:
360
+ scores = scores + mask
361
+
362
+ #### NEW CODE ####
363
+ # Log scores to return for loss calculation
364
+ scores_to_return = scores
365
+ #### NEW CODE ####
366
+
367
+ # Apply the mask from the delete gate
368
+ if delete_gate_mask is not None:
369
+ scores = scores + delete_gate_mask.squeeze(-1).unsqueeze(-2).unsqueeze(-2)
370
+
371
+ if self.use_softmax1:
372
+ attn_weights = softmax1(scores.float(), dim=-1).type_as(
373
+ scores)
374
+ else:
375
+ attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(
376
+ scores
377
+ ) # (batch_size, n_heads, seq_length, key_length)
378
+
379
+ #### NEW CODE ####
380
+
381
+ attn_weights = nn.functional.dropout(
382
+ attn_weights, p=self.dropout, training=self.training
383
+ ) # (batch_size, n_heads, seq_length, key_length)
384
+
385
+ # Mask heads if we want to
386
+ if layer_head_mask is not None:
387
+ attn_weights = attn_weights * layer_head_mask
388
+
389
+ # (batch_size, seq_length, dim)
390
+ attn_output = unshape(torch.matmul(attn_weights, value_states))
391
+ attn_output = self.o(attn_output)
392
+
393
+ present_key_value_state = (key_states, value_states) if (
394
+ self.is_decoder and use_cache) else None
395
+ outputs = (attn_output,) + \
396
+ (present_key_value_state,) + (position_bias,)
397
+
398
+ if output_attentions:
399
+ attentions_keys_queries = (attn_weights, key_states, query_states, value_states, scores_to_return)
400
+ outputs = outputs + (attentions_keys_queries,)
401
+
402
+ return outputs
403
+
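# A small sketch (not part of the uploaded file) of how MrT5Attention folds the
# delete gate into attention: the (batch, key_len, 1) gate mask is reshaped to
# (batch, 1, 1, key_len) and added to the raw scores before the softmax, so keys
# gated toward sigmoid_mask_scale receive essentially no attention. Values here
# are illustrative.
import torch

batch, heads, q_len, k_len = 1, 2, 4, 4
scores = torch.zeros(batch, heads, q_len, k_len)  # stand-in for QK^T (+ position bias)

# Gate output per source position: 0.0 keeps the token, -30.0 deletes it.
delete_gate_mask = torch.tensor([[[0.0], [-30.0], [0.0], [-30.0]]])

# Same reshape as in MrT5Attention.forward.
additive = delete_gate_mask.squeeze(-1).unsqueeze(-2).unsqueeze(-2)  # (1, 1, 1, 4)
attn_weights = torch.softmax(scores + additive, dim=-1)
print(attn_weights[0, 0, 0])  # roughly [0.5, 0.0, 0.5, 0.0]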
404
+
405
+ class MrT5LayerSelfAttention(nn.Module):
406
+ """
407
+ Modified version of T5LayerSelfAttention that uses MrT5Attention instead
408
+ of T5Attention.
409
+ """
410
+
411
+ def __init__(self, config, has_relative_attention_bias=False):
412
+ super().__init__()
413
+ #### NEW CODE ####
414
+ # Use MrT5Attention instead of T5Attention
415
+ self.SelfAttention = MrT5Attention(
416
+ config, has_relative_attention_bias=has_relative_attention_bias)
417
+ #### NEW CODE ####
418
+ self.layer_norm = T5LayerNorm(
419
+ config.d_model, eps=config.layer_norm_epsilon)
420
+ self.dropout = nn.Dropout(config.dropout_rate)
421
+
422
+ def forward(
423
+ self,
424
+ hidden_states,
425
+ attention_mask=None,
426
+ position_bias=None,
427
+ layer_head_mask=None,
428
+ past_key_value=None,
429
+ use_cache=False,
430
+ output_attentions=False,
431
+ #### NEW CODE ####
432
+ delete_gate_mask=None,
433
+ #### NEW CODE ####
434
+ ):
435
+ normed_hidden_states = self.layer_norm(hidden_states)
436
+ attention_output = self.SelfAttention(
437
+ normed_hidden_states,
438
+ mask=attention_mask,
439
+ position_bias=position_bias,
440
+ layer_head_mask=layer_head_mask,
441
+ past_key_value=past_key_value,
442
+ use_cache=use_cache,
443
+ output_attentions=output_attentions,
444
+ #### NEW CODE ####
445
+ delete_gate_mask=delete_gate_mask,
446
+ #### NEW CODE ####
447
+ )
448
+ hidden_states = hidden_states + self.dropout(attention_output[0])
449
+ # add attentions if we output them
450
+ outputs = (hidden_states,) + attention_output[1:]
451
+ return outputs
452
+
453
+
454
+ class MrT5LayerCrossAttention(nn.Module):
455
+ """
456
+ Modified version of T5LayerCrossAttention that uses MrT5Attention instead
457
+ of T5Attention.
458
+ """
459
+
460
+ def __init__(self, config):
461
+ super().__init__()
462
+ #### NEW CODE ####
463
+ # Use MrT5Attention instead of T5Attention
464
+ self.EncDecAttention = MrT5Attention(
465
+ config, has_relative_attention_bias=False)
466
+ #### NEW CODE ####
467
+ self.layer_norm = T5LayerNorm(
468
+ config.d_model, eps=config.layer_norm_epsilon)
469
+ self.dropout = nn.Dropout(config.dropout_rate)
470
+
471
+ def forward(
472
+ self,
473
+ hidden_states,
474
+ key_value_states,
475
+ attention_mask=None,
476
+ position_bias=None,
477
+ layer_head_mask=None,
478
+ past_key_value=None,
479
+ use_cache=False,
480
+ query_length=None,
481
+ output_attentions=False,
482
+ #### NEW CODE ####
483
+ delete_gate_mask=None,
484
+ #### NEW CODE ####
485
+ ):
486
+ normed_hidden_states = self.layer_norm(hidden_states)
487
+ attention_output = self.EncDecAttention(
488
+ normed_hidden_states,
489
+ mask=attention_mask,
490
+ key_value_states=key_value_states,
491
+ position_bias=position_bias,
492
+ layer_head_mask=layer_head_mask,
493
+ past_key_value=past_key_value,
494
+ use_cache=use_cache,
495
+ query_length=query_length,
496
+ output_attentions=output_attentions,
497
+ #### NEW CODE ####
498
+ delete_gate_mask=delete_gate_mask,
499
+ #### NEW CODE ####
500
+ )
501
+ layer_output = hidden_states + self.dropout(attention_output[0])
502
+ # add attentions if we output them
503
+ outputs = (layer_output,) + attention_output[1:]
504
+ return outputs
505
+
506
+
507
+ class MrT5Block(nn.Module):
508
+ """
509
+ Modified version of T5Block that uses MrT5LayerSelfAttention and
510
+ MrT5LayerCrossAttention instead of T5LayerSelfAttention and
511
+ T5LayerCrossAttention.
512
+ """
513
+
514
+ def __init__(self, config, has_relative_attention_bias=False,
515
+ #### NEW CODE ####
516
+ has_delete_gate=False,
517
+ #### NEW CODE ####
518
+ ):
519
+ super().__init__()
520
+ self.is_decoder = config.is_decoder
521
+ self.layer = nn.ModuleList()
522
+ #### NEW CODE ####
523
+ # Use MrT5LayerSelfAttention and MrT5LayerCrossAttention
524
+ # instead of T5LayerSelfAttention and T5LayerCrossAttention
525
+ self.layer.append(MrT5LayerSelfAttention(
526
+ config, has_relative_attention_bias=has_relative_attention_bias))
527
+ if self.is_decoder:
528
+ self.layer.append(MrT5LayerCrossAttention(config))
529
+ #### NEW CODE ####
530
+
531
+ self.layer.append(T5LayerFF(config))
532
+
533
+ #### NEW CODE ####
534
+ # Add delete gate if needed
535
+ self.has_delete_gate = has_delete_gate
536
+ if self.has_delete_gate:
537
+ if config.deletion_type == "scaled_sigmoid":
538
+ self.delete_gate = SigmoidDeleteGate(config)
539
+ elif config.deletion_type == "log_sigmoid":
540
+ self.delete_gate = LogSigmoidDeleteGate(config)
541
+ elif config.deletion_type == "random":
542
+ self.delete_gate = RandomDeleteGate(config)
543
+ elif config.deletion_type == "fixed":
544
+ self.delete_gate = FixedDeleteGate(config)
545
+ else:
546
+ raise ValueError(
547
+ f"Invalid deletion type: {config.deletion_type}")
548
+
549
+ # Set hard_delete flags
550
+ self.sigmoid_mask_scale = config.sigmoid_mask_scale
551
+ self.deletion_threshold = config.deletion_threshold
552
+ #### NEW CODE ####
553
+
554
+ #### NEW CODE ####
555
+
556
+ def __get_new_positions_and_mask(self, batch_size, seq_len, delete_gate_mask, deletion_threshold, device):
557
+ delete_gate_mask = delete_gate_mask.squeeze(-1)
558
+
559
+ # Create filter from delete gate mask
560
+ deletion_threshold = deletion_threshold if deletion_threshold is not None else self.deletion_threshold
561
+ keep_this = delete_gate_mask > deletion_threshold
562
+
563
+ # Calculate the target position for each token
564
+ target_pos = torch.cumsum(keep_this, dim=1) - 1
565
+ new_len = target_pos[:, -1].max().item() + 1
566
+
567
+ # Clamp the target position to avoid out of bounds when deleting everything
568
+ target_pos = target_pos.clamp(min=0)
569
+
570
+ # Map the positions to the src side. Do this in int32, because it's faster and we will not have sequences
571
+ # longer than 2^31
572
+ positions = torch.arange(seq_len, device=device, dtype=torch.int32).repeat(batch_size, 1)
573
+ positions *= keep_this.int()
574
+
575
+ src_side_pos = torch.zeros(batch_size, new_len, device=device, dtype=torch.int32)
576
+ src_side_pos.scatter_add_(1, target_pos, positions)
577
+
578
+ # Create the new mask
579
+ new_mask = torch.arange(new_len, device=device).expand(batch_size, -1) <= target_pos[:, -1:]
580
+ new_mask = (~new_mask).float() * -1e9
581
+ new_mask = new_mask.unsqueeze(-1)
582
+
583
+ return src_side_pos.long(), new_mask
584
+
585
+ def __hard_delete_hidden_states(self, hidden_states, positions):
586
+ new_hidden_states = torch.gather(hidden_states, 1, positions.unsqueeze(2).expand(-1, -1, hidden_states.size(2)))
587
+ return new_hidden_states
588
+
589
+ def __hard_delete_4_dimensions(self, position_bias, positions):
590
+ new_position_bias = torch.gather(position_bias, 1, positions.unsqueeze(2).unsqueeze(3).expand(-1, -1, position_bias.size(2), position_bias.size(3)))
591
+ return new_position_bias
592
+
593
+ #### NEW CODE ####
594
+
595
+ def forward(
596
+ self,
597
+ hidden_states,
598
+ attention_mask=None,
599
+ position_bias=None,
600
+ encoder_hidden_states=None,
601
+ encoder_attention_mask=None,
602
+ encoder_decoder_position_bias=None,
603
+ layer_head_mask=None,
604
+ cross_attn_layer_head_mask=None,
605
+ past_key_value=None,
606
+ use_cache=False,
607
+ output_attentions=False,
608
+ return_dict=True,
609
+ #### NEW CODE ####
610
+ delete_gate_mask=None,
611
+ input_ids=None,
612
+ hard_delete=None,
613
+ deletion_threshold=None,
614
+ #### NEW CODE ####
615
+ ):
616
+ if past_key_value is not None:
617
+ if not self.is_decoder:
618
+ logger.warning(
619
+ "`past_key_values` is passed to the encoder. Please make sure this is intended.")
620
+ expected_num_past_key_values = 2 if encoder_hidden_states is None else 4
621
+
622
+ if len(past_key_value) != expected_num_past_key_values:
623
+ raise ValueError(
624
+ f"There should be {expected_num_past_key_values} past states. "
625
+ f"{'2 (key / value) for cross attention. ' if expected_num_past_key_values == 4 else ''}"
626
+ f"Got {len(past_key_value)} past key / value states"
627
+ )
628
+
629
+ self_attn_past_key_value = past_key_value[:2]
630
+ cross_attn_past_key_value = past_key_value[2:]
631
+ else:
632
+ self_attn_past_key_value, cross_attn_past_key_value = None, None
633
+
634
+ ##### NEW CODE #####
635
+ # Initialize delete gate values and logits for logging/loss calculation
636
+ delete_gate_values = None
637
+ delete_gate_logits = None
638
+
639
+ if self.has_delete_gate:
640
+ delete_gate_values, delete_gate_logits = self.delete_gate(
641
+ hidden_states, input_ids)
642
+ delete_gate_mask = delete_gate_values
643
+
644
+ # Raise an error if all tokens in the batch are deleted
645
+ if (delete_gate_values < self.deletion_threshold).all():
646
+ raise ValueError("All tokens are deleted in this batch. " + \
647
+ "Please adjust the deletion rate or " + \
648
+ "alpha hyperparameter.")
649
+
650
+ # Apply hard deletion
651
+ if hard_delete:
652
+
653
+ # Compute new token positions
654
+ new_positions, delete_gate_mask = self.__get_new_positions_and_mask(
655
+ hidden_states.size(0), hidden_states.size(1), delete_gate_mask, deletion_threshold, hidden_states.device)
656
+
657
+ # Compute new position bias
658
+ if position_bias is not None:
659
+ new_position_bias = self.__hard_delete_4_dimensions(
660
+ position_bias.permute(0, 2, 3, 1), new_positions)
661
+ new_position_bias = self.__hard_delete_4_dimensions(
662
+ new_position_bias.permute(0, 2, 1, 3), new_positions)
663
+ position_bias = new_position_bias.permute(0, 3, 2, 1)
664
+
665
+ # Compute new attention mask
666
+ new_attention_mask = self.__hard_delete_4_dimensions(
667
+ attention_mask.permute(0, 3, 1, 2), new_positions)
668
+ attention_mask = new_attention_mask.permute(0, 2, 3, 1)
669
+
670
+ # Compute new hidden states and delete gate mask
671
+ hidden_states = self.__hard_delete_hidden_states(
672
+ hidden_states, new_positions)
673
+
674
+ ##### NEW CODE #####
675
+
676
+ self_attention_outputs = self.layer[0](
677
+ hidden_states,
678
+ attention_mask=attention_mask,
679
+ position_bias=position_bias,
680
+ layer_head_mask=layer_head_mask,
681
+ past_key_value=self_attn_past_key_value,
682
+ use_cache=use_cache,
683
+ output_attentions=output_attentions,
684
+ #### NEW CODE ####
685
+ # Only apply delete_gate_mask to self-attention if the block
686
+ # is the encoder
687
+ delete_gate_mask=None if self.is_decoder else delete_gate_mask,
688
+ #### NEW CODE ####
689
+ )
690
+ hidden_states, present_key_value_state = self_attention_outputs[:2]
691
+ # Keep self-attention outputs and relative position weights
692
+ attention_outputs = self_attention_outputs[2:]
693
+
694
+ # clamp inf values to enable fp16 training
695
+ if hidden_states.dtype == torch.float16:
696
+ clamp_value = torch.where(
697
+ torch.isinf(hidden_states).any(),
698
+ torch.finfo(hidden_states.dtype).max - 1000,
699
+ torch.finfo(hidden_states.dtype).max,
700
+ )
701
+ hidden_states = torch.clamp(
702
+ hidden_states, min=-clamp_value, max=clamp_value)
703
+
704
+ do_cross_attention = self.is_decoder and encoder_hidden_states is not None
705
+ if do_cross_attention:
706
+ # the actual query length is unknown for cross attention
707
+ # if using past key value states. Need to inject it here
708
+ if present_key_value_state is not None:
709
+ query_length = present_key_value_state[0].shape[2]
710
+ else:
711
+ query_length = None
712
+
713
+ cross_attention_outputs = self.layer[1](
714
+ hidden_states,
715
+ key_value_states=encoder_hidden_states,
716
+ attention_mask=encoder_attention_mask,
717
+ position_bias=encoder_decoder_position_bias,
718
+ layer_head_mask=cross_attn_layer_head_mask,
719
+ past_key_value=cross_attn_past_key_value,
720
+ query_length=query_length,
721
+ use_cache=use_cache,
722
+ output_attentions=output_attentions,
723
+ #### NEW CODE ####
724
+ delete_gate_mask=delete_gate_mask,
725
+ #### NEW CODE ####
726
+ )
727
+ hidden_states = cross_attention_outputs[0]
728
+
729
+ # clamp inf values to enable fp16 training
730
+ if hidden_states.dtype == torch.float16:
731
+ clamp_value = torch.where(
732
+ torch.isinf(hidden_states).any(),
733
+ torch.finfo(hidden_states.dtype).max - 1000,
734
+ torch.finfo(hidden_states.dtype).max,
735
+ )
736
+ hidden_states = torch.clamp(
737
+ hidden_states, min=-clamp_value, max=clamp_value)
738
+
739
+ # Combine self attn and cross attn key value states
740
+ if present_key_value_state is not None:
741
+ present_key_value_state = present_key_value_state + \
742
+ cross_attention_outputs[1]
743
+
744
+ # Keep cross-attention outputs and relative position weights
745
+ attention_outputs = attention_outputs + cross_attention_outputs[2:]
746
+
747
+ # Apply Feed Forward layer
748
+ hidden_states = self.layer[-1](hidden_states)
749
+
750
+ # clamp inf values to enable fp16 training
751
+ if hidden_states.dtype == torch.float16:
752
+ clamp_value = torch.where(
753
+ torch.isinf(hidden_states).any(),
754
+ torch.finfo(hidden_states.dtype).max - 1000,
755
+ torch.finfo(hidden_states.dtype).max,
756
+ )
757
+ hidden_states = torch.clamp(
758
+ hidden_states, min=-clamp_value, max=clamp_value)
759
+
760
+ outputs = (hidden_states,)
761
+
762
+ if use_cache:
763
+ outputs = outputs + (present_key_value_state,) + attention_outputs
764
+ else:
765
+ outputs = outputs + attention_outputs
766
+
767
+ ##### NEW CODE #####
768
+ if self.has_delete_gate:
769
+ outputs = outputs + \
770
+ (delete_gate_values, delete_gate_logits, delete_gate_mask, attention_mask)
771
+ ##### NEW CODE #####
772
+
773
+ # hidden-states, present_key_value_states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights), (delete_gate_mask), (delete_gate_logits)
774
+ return outputs
775
+
776
+
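# A toy walkthrough (not part of the uploaded file) of the hard-deletion mechanics
# in MrT5Block: threshold the gate mask, use a cumulative sum to assign each kept
# token a compacted slot, scatter the source indices into those slots, then gather
# the hidden states. Uses one sequence and int64 indices for simplicity; the model
# code above does the same thing batched (and in int32 where possible).
import torch

delete_gate_mask = torch.tensor([[0.0, -30.0, 0.0, 0.0, -30.0, 0.0]])  # (1, 6)
deletion_threshold = -15.0
hidden_states = torch.arange(6, dtype=torch.float).view(1, 6, 1)  # position index as feature

keep_this = delete_gate_mask > deletion_threshold        # keep tokens 0, 2, 3, 5
target_pos = torch.cumsum(keep_this, dim=1) - 1           # compacted slot per token
new_len = target_pos[:, -1].max().item() + 1              # 4 surviving positions
target_pos = target_pos.clamp(min=0)

positions = torch.arange(6).unsqueeze(0) * keep_this.long()
src_side_pos = torch.zeros(1, new_len, dtype=torch.long)
src_side_pos.scatter_add_(1, target_pos, positions)        # source index per kept slot

new_hidden = torch.gather(
    hidden_states, 1, src_side_pos.unsqueeze(2).expand(-1, -1, hidden_states.size(2)))
print(src_side_pos)             # tensor([[0, 2, 3, 5]])
print(new_hidden.squeeze(-1))   # tensor([[0., 2., 3., 5.]])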
777
+ class MrT5Stack(T5Stack):
778
+ def __init__(self, config, embed_tokens=None):
779
+ super().__init__(config, embed_tokens)
780
+
781
+ ##### NEW CODE #####
782
+ if self.is_decoder:
783
+ self.block = nn.ModuleList(
784
+ [
785
+ MrT5Block(
786
+ config, has_relative_attention_bias=bool(i == 0))
787
+ for i in range(config.num_layers)
788
+ ]
789
+ )
790
+ else:
791
+ blocks = []
792
+ for i in range(config.num_layers):
793
+ blocks.append(
794
+ MrT5Block(
795
+ config,
796
+ # Only the first layer has relative attention bias
797
+ has_relative_attention_bias=bool(i == 0),
798
+ # Add delete gate if specified
799
+ has_delete_gate=bool(i == config.delete_gate_layer),
800
+ )
801
+ )
802
+ self.block = nn.ModuleList(blocks)
803
+ ##### NEW CODE #####
804
+
805
+ def forward(
806
+ self,
807
+ input_ids=None,
808
+ attention_mask=None,
809
+ encoder_hidden_states=None,
810
+ encoder_attention_mask=None,
811
+ inputs_embeds=None,
812
+ head_mask=None,
813
+ cross_attn_head_mask=None,
814
+ past_key_values=None,
815
+ use_cache=None,
816
+ output_attentions=None,
817
+ output_hidden_states=None,
818
+ return_dict=None,
819
+ #### NEW CODE ####
820
+ delete_gate_mask=None,
821
+ delete_gate_output=None,
822
+ delete_gate_logits=None,
823
+ hard_delete=None,
824
+ deletion_threshold=None,
825
+ #### NEW CODE ####
826
+ ):
827
+ # Model parallel
828
+ if self.model_parallel:
829
+ torch.cuda.set_device(self.first_device)
830
+ self.embed_tokens = self.embed_tokens.to(self.first_device)
831
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
832
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
833
+ output_hidden_states = (
834
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
835
+ )
836
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
837
+
838
+ if input_ids is not None and inputs_embeds is not None:
839
+ err_msg_prefix = "decoder_" if self.is_decoder else ""
840
+ raise ValueError(
841
+ f"You cannot specify both {err_msg_prefix}input_ids and {err_msg_prefix}inputs_embeds at the same time"
842
+ )
843
+ elif input_ids is not None:
844
+ input_shape = input_ids.size()
845
+ input_ids = input_ids.view(-1, input_shape[-1])
846
+ elif inputs_embeds is not None:
847
+ input_shape = inputs_embeds.size()[:-1]
848
+ else:
849
+ err_msg_prefix = "decoder_" if self.is_decoder else ""
850
+ raise ValueError(
851
+ f"You have to specify either {err_msg_prefix}input_ids or {err_msg_prefix}inputs_embeds")
852
+
853
+ if inputs_embeds is None:
854
+ if self.embed_tokens is None:
855
+ raise ValueError(
856
+ "You have to initialize the model with valid token embeddings")
857
+ inputs_embeds = self.embed_tokens(input_ids)
858
+
859
+ #### NEW CODE ####
860
+ if self.absolute_pos_embed is not None:
861
+ position_ids = torch.arange(input_shape[-1], dtype=torch.long, device=inputs_embeds.device)
862
+ position_embeds = self.absolute_pos_embed(position_ids)
863
+ inputs_embeds = inputs_embeds + position_embeds
864
+ #### NEW CODE ####
865
+
866
+ batch_size, seq_length = input_shape
867
+
868
+ # required mask seq length can be calculated via length of past
869
+ mask_seq_length = past_key_values[0][0].shape[2] + \
870
+ seq_length if past_key_values is not None else seq_length
871
+
872
+ if use_cache is True:
873
+ if not self.is_decoder:
874
+ raise ValueError(
875
+ f"`use_cache` can only be set to `True` if {self} is used as a decoder")
876
+
877
+ # initialize past_key_values with `None` if past does not exist
878
+ if past_key_values is None:
879
+ past_key_values = [None] * len(self.block)
880
+
881
+ if attention_mask is None:
882
+ attention_mask = torch.ones(
883
+ batch_size, mask_seq_length, device=inputs_embeds.device)
884
+
885
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
886
+ # ourselves in which case we just need to make it broadcastable to all heads.
887
+ extended_attention_mask = self.get_extended_attention_mask(
888
+ attention_mask, input_shape)
889
+
890
+ # If a 2D or 3D attention mask is provided for the cross-attention
891
+ # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
892
+ if self.is_decoder and encoder_hidden_states is not None:
893
+ encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
894
+ encoder_hidden_shape = (
895
+ encoder_batch_size, encoder_sequence_length)
896
+ if encoder_attention_mask is None:
897
+ encoder_attention_mask = torch.ones(
898
+ encoder_hidden_shape, device=inputs_embeds.device, dtype=torch.long
899
+ )
900
+ encoder_extended_attention_mask = self.invert_attention_mask(
901
+ encoder_attention_mask)
902
+ else:
903
+ encoder_extended_attention_mask = None
904
+
905
+ if self.gradient_checkpointing and self.training:
906
+ if use_cache:
907
+ logger.warning_once(
908
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
909
+ )
910
+ use_cache = False
911
+
912
+ #### NEW CODE ####
913
+ # Return a new encoder attention mask if hard delete is enabled
914
+ attention_mask_to_return = None
915
+ #### NEW CODE ####
916
+
917
+ # Prepare head mask if needed
918
+ head_mask = self.get_head_mask(head_mask, self.config.num_layers)
919
+ cross_attn_head_mask = self.get_head_mask(
920
+ cross_attn_head_mask, self.config.num_layers)
921
+ present_key_value_states = () if use_cache else None
922
+ all_hidden_states = () if output_hidden_states else None
923
+ all_attentions = () if output_attentions else None
924
+ all_cross_attentions = () if (output_attentions and self.is_decoder) else None
925
+ position_bias = None
926
+ encoder_decoder_position_bias = None
927
+
928
+ #### NEW CODE ####
929
+ all_queries = () if output_attentions else None
930
+ all_keys = () if output_attentions else None
931
+ all_values = () if output_attentions else None
932
+ all_scores = () if output_attentions else None
933
+ all_cross_attn_queries = () if (output_attentions and self.is_decoder) else None
934
+ all_cross_attn_keys = () if (output_attentions and self.is_decoder) else None
935
+ all_cross_attn_values = () if (output_attentions and self.is_decoder) else None
936
+ all_cross_attn_scores = () if (output_attentions and self.is_decoder) else None
937
+ #### NEW CODE ####
938
+
939
+ hidden_states = self.dropout(inputs_embeds)
940
+
941
+ for i, (layer_module, past_key_value) in enumerate(zip(self.block, past_key_values)):
942
+ layer_head_mask = head_mask[i]
943
+ cross_attn_layer_head_mask = cross_attn_head_mask[i]
944
+ # Model parallel
945
+ if self.model_parallel:
946
+ torch.cuda.set_device(hidden_states.device)
947
+ # Ensure that attention_mask is always on the same device as hidden_states
948
+ if attention_mask is not None:
949
+ attention_mask = attention_mask.to(hidden_states.device)
950
+ if position_bias is not None:
951
+ position_bias = position_bias.to(hidden_states.device)
952
+ if encoder_hidden_states is not None:
953
+ encoder_hidden_states = encoder_hidden_states.to(
954
+ hidden_states.device)
955
+ if encoder_extended_attention_mask is not None:
956
+ encoder_extended_attention_mask = encoder_extended_attention_mask.to(
957
+ hidden_states.device)
958
+ if encoder_decoder_position_bias is not None:
959
+ encoder_decoder_position_bias = encoder_decoder_position_bias.to(
960
+ hidden_states.device)
961
+ if layer_head_mask is not None:
962
+ layer_head_mask = layer_head_mask.to(hidden_states.device)
963
+ if cross_attn_layer_head_mask is not None:
964
+ cross_attn_layer_head_mask = cross_attn_layer_head_mask.to(
965
+ hidden_states.device)
966
+ if output_hidden_states:
967
+ all_hidden_states = all_hidden_states + (hidden_states,)
968
+
969
+ if self.gradient_checkpointing and self.training:
970
+ layer_outputs = self._gradient_checkpointing_func(
971
+ layer_module.forward,
972
+ hidden_states,
973
+ extended_attention_mask,
974
+ position_bias,
975
+ encoder_hidden_states,
976
+ encoder_extended_attention_mask,
977
+ encoder_decoder_position_bias,
978
+ layer_head_mask,
979
+ cross_attn_layer_head_mask,
980
+ None, # past_key_value is always None with gradient checkpointing
981
+ use_cache,
982
+ output_attentions,
983
+ #### NEW CODE ####
984
+ delete_gate_mask,
985
+ #### NEW CODE ####
986
+ )
987
+ else:
988
+ layer_outputs = layer_module(
989
+ hidden_states,
990
+ attention_mask=extended_attention_mask,
991
+ position_bias=position_bias,
992
+ encoder_hidden_states=encoder_hidden_states,
993
+ encoder_attention_mask=encoder_extended_attention_mask,
994
+ encoder_decoder_position_bias=encoder_decoder_position_bias,
995
+ layer_head_mask=layer_head_mask,
996
+ cross_attn_layer_head_mask=cross_attn_layer_head_mask,
997
+ past_key_value=past_key_value,
998
+ use_cache=use_cache,
999
+ output_attentions=output_attentions,
1000
+ #### NEW CODE ####
1001
+ delete_gate_mask=delete_gate_mask,
1002
+ input_ids=input_ids,
1003
+ hard_delete=hard_delete,
1004
+ deletion_threshold=deletion_threshold,
1005
+ #### NEW CODE ####
1006
+ )
1007
+
1008
+ #### NEW CODE ####
1009
+ # Update delete_gate_mask if the previous layer had a delete gate
1010
+ if layer_module.has_delete_gate:
1011
+ delete_gate_output, delete_gate_logits, delete_gate_mask, new_attention_mask = layer_outputs[-4], layer_outputs[-3], layer_outputs[-2], layer_outputs[-1]
1012
+
1013
+ # Update resized masks if the previous layer did a hard deletion
1014
+ if hard_delete:
1015
+ extended_attention_mask = new_attention_mask
1016
+ attention_mask_to_return = extended_attention_mask.squeeze(-2).squeeze(-2)
1017
+ attention_mask_to_return = (attention_mask_to_return == 0).int()
1018
+
1019
+ #### NEW CODE ####
1020
+
1021
+ # layer_outputs is a tuple with:
1022
+ # hidden-states, key-value-states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights)
1023
+ if use_cache is False:
1024
+ layer_outputs = layer_outputs[:1] + (None,) + layer_outputs[1:]
1025
+
1026
+ hidden_states, present_key_value_state = layer_outputs[:2]
1027
+
1028
+ # We share the position biases between the layers - the first layer stores them
1029
+ # layer_outputs = hidden-states, key-value-states (self-attention position bias), (self-attention weights),
1030
+ # (cross-attention position bias), (cross-attention weights)
1031
+ position_bias = layer_outputs[2]
1032
+ if self.is_decoder and encoder_hidden_states is not None:
1033
+ #### NEW CODE ####
1034
+ index = 4 if output_attentions else 3
1035
+ encoder_decoder_position_bias = layer_outputs[index]
1036
+ #### NEW CODE ####
1037
+ # append next layer key value states
1038
+ if use_cache:
1039
+ present_key_value_states = present_key_value_states + \
1040
+ (present_key_value_state,)
1041
+
1042
+ #### NEW CODE ####
1043
+ if output_attentions:
1044
+ attn_weights, keys, queries, values, scores = layer_outputs[3]
1045
+ all_attentions = all_attentions + (attn_weights,)
1046
+ all_queries = all_queries + (queries,)
1047
+ all_keys = all_keys + (keys,)
1048
+ all_values = all_values + (values,)
1049
+ all_scores = all_scores + (scores,)
1050
+
1051
+ if self.is_decoder:
1052
+ cross_attn_weights, cross_attn_keys, cross_attn_queries, \
1053
+ cross_attn_values, cross_attn_scores = layer_outputs[5]
1054
+ all_cross_attentions = all_cross_attentions + \
1055
+ (cross_attn_weights,)
1056
+ all_cross_attn_queries = all_cross_attn_queries + \
1057
+ (cross_attn_queries,)
1058
+ all_cross_attn_keys = all_cross_attn_keys + \
1059
+ (cross_attn_keys,)
1060
+ all_cross_attn_values = all_cross_attn_values + \
1061
+ (cross_attn_values,)
1062
+ all_cross_attn_scores = all_cross_attn_scores + \
1063
+ (cross_attn_scores,)
1064
+ #### NEW CODE ####
1065
+
1066
+ # Model Parallel: If it's the last layer for that device, put things on the next device
1067
+ if self.model_parallel:
1068
+ for k, v in self.device_map.items():
1069
+ if i == v[-1] and "cuda:" + str(k) != self.last_device:
1070
+ hidden_states = hidden_states.to("cuda:" + str(k + 1))
1071
+
1072
+ hidden_states = self.final_layer_norm(hidden_states)
1073
+ hidden_states = self.dropout(hidden_states)
1074
+
1075
+ # Add last layer
1076
+ if output_hidden_states:
1077
+ all_hidden_states = all_hidden_states + (hidden_states,)
1078
+
1079
+ if not return_dict:
1080
+ return tuple(
1081
+ v
1082
+ for v in [
1083
+ hidden_states,
1084
+ present_key_value_states,
1085
+ all_hidden_states,
1086
+ all_attentions,
1087
+ all_cross_attentions,
1088
+ #### NEW CODE ####
1089
+ delete_gate_mask,
1090
+ delete_gate_output,
1091
+ delete_gate_logits,
1092
+ attention_mask_to_return,
1093
+ all_queries,
1094
+ all_keys,
1095
+ all_values,
1096
+ all_scores,
1097
+ all_cross_attn_queries,
1098
+ all_cross_attn_keys,
1099
+ all_cross_attn_values,
1100
+ all_cross_attn_scores,
1101
+ #### NEW CODE ####
1102
+ ]
1103
+ if v is not None
1104
+ )
1105
+
1106
+ return MrT5BaseModelOutputWithPastAndCrossAttentions(
1107
+ last_hidden_state=hidden_states,
1108
+ past_key_values=present_key_value_states,
1109
+ hidden_states=all_hidden_states,
1110
+ attentions=all_attentions,
1111
+ cross_attentions=all_cross_attentions,
1112
+ #### NEW CODE ####
1113
+ delete_gate_mask=delete_gate_mask,
1114
+ delete_gate_output=delete_gate_output,
1115
+ delete_gate_logits=delete_gate_logits,
1116
+ attention_mask=attention_mask_to_return,
1117
+ attention_queries=all_queries,
1118
+ attention_keys=all_keys,
1119
+ attention_values=all_values,
1120
+ attention_scores=all_scores,
1121
+ cross_attention_queries=all_cross_attn_queries,
1122
+ cross_attention_keys=all_cross_attn_keys,
1123
+ cross_attention_values=all_cross_attn_values,
1124
+ cross_attention_scores=all_cross_attn_scores,
1125
+ #### NEW CODE ####
1126
+ )
1127
+
1128
+
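# A quick sketch (not part of the uploaded file) of how the encoder stack above is
# wired: only one encoder layer owns a delete gate, the one whose index equals
# config.delete_gate_layer, and only the first layer carries relative attention
# bias. The num_layers and delete_gate_layer values here are illustrative, not
# MrT5Config defaults.
num_layers = 6
delete_gate_layer = 2

encoder_layout = [
    {
        "layer": i,
        "has_relative_attention_bias": i == 0,
        "has_delete_gate": i == delete_gate_layer,
    }
    for i in range(num_layers)
]
for entry in encoder_layout:
    print(entry)
# Decoder blocks are always built without a delete gate.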
1129
+ class MrT5ForConditionalGeneration(T5ForConditionalGeneration):
1130
+
1131
+ config_class = MrT5Config
1132
+
1133
+ def __init__(self, config: MrT5Config):
1134
+ super().__init__(config)
1135
+ #### NEW CODE ####
1136
+ encoder_config = copy.deepcopy(config)
1137
+ encoder_config.is_decoder = False
1138
+ encoder_config.use_cache = False
1139
+ encoder_config.is_encoder_decoder = False
1140
+ self.encoder = MrT5Stack(encoder_config, self.shared)
1141
+
1142
+ decoder_config = copy.deepcopy(config)
1143
+ decoder_config.is_decoder = True
1144
+ decoder_config.is_encoder_decoder = False
1145
+ decoder_config.num_layers = config.num_decoder_layers
1146
+ self.decoder = MrT5Stack(decoder_config, self.shared)
1147
+ #### NEW CODE ####
1148
+
1149
+ def forward(
1150
+ self,
1151
+ input_ids: Optional[torch.LongTensor] = None,
1152
+ attention_mask: Optional[torch.FloatTensor] = None,
1153
+ decoder_input_ids: Optional[torch.LongTensor] = None,
1154
+ decoder_attention_mask: Optional[torch.BoolTensor] = None,
1155
+ head_mask: Optional[torch.FloatTensor] = None,
1156
+ decoder_head_mask: Optional[torch.FloatTensor] = None,
1157
+ cross_attn_head_mask: Optional[torch.Tensor] = None,
1158
+ encoder_outputs: Optional[Tuple[Tuple[torch.Tensor]]] = None,
1159
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
1160
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1161
+ decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
1162
+ labels: Optional[torch.LongTensor] = None,
1163
+ use_cache: Optional[bool] = None,
1164
+ output_attentions: Optional[bool] = None,
1165
+ output_hidden_states: Optional[bool] = None,
1166
+ return_dict: Optional[bool] = None,
1167
+ #### NEW CODE ####
1168
+ hard_delete: bool = False,
1169
+ deletion_threshold: Optional[float] = None,
1170
+ #### NEW CODE ####
1171
+ ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]:
1172
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
1173
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1174
+
1175
+ # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
1176
+ if head_mask is not None and decoder_head_mask is None:
1177
+ if self.config.num_layers == self.config.num_decoder_layers:
1178
+ decoder_head_mask = head_mask
1179
+
1180
+ # Encode if needed (training, first prediction pass)
1181
+ if encoder_outputs is None:
1182
+ # Convert encoder inputs in embeddings if needed
1183
+ encoder_outputs = self.encoder(
1184
+ input_ids=input_ids,
1185
+ attention_mask=attention_mask,
1186
+ inputs_embeds=inputs_embeds,
1187
+ head_mask=head_mask,
1188
+ output_attentions=output_attentions,
1189
+ output_hidden_states=output_hidden_states,
1190
+ return_dict=return_dict,
1191
+ #### NEW CODE ####
1192
+ hard_delete=hard_delete,
1193
+ deletion_threshold=deletion_threshold,
1194
+ #### NEW CODE ####
1195
+ )
1196
+ elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
1197
+ #### NEW CODE ####
1198
+ encoder_outputs = MrT5BaseModelOutputWithPastAndCrossAttentions(
1199
+ last_hidden_state=encoder_outputs.last_hidden_state,
1200
+ hidden_states=encoder_outputs.hidden_states if 'hidden_states' in encoder_outputs else None,
1201
+ attentions=encoder_outputs.attentions if 'attentions' in encoder_outputs else None,
1202
+ delete_gate_mask=encoder_outputs.delete_gate_mask if 'delete_gate_mask' in encoder_outputs else None,
1203
+ )
1204
+ #### NEW CODE ####
1205
+
1206
+ #### NEW CODE ####
1207
+
1208
+ hidden_states = encoder_outputs.last_hidden_state
1209
+ attention_mask = encoder_outputs.attention_mask if 'attention_mask' in encoder_outputs else attention_mask
1210
+
1211
+ #### NEW CODE ####
1212
+
1213
+ if self.model_parallel:
1214
+ torch.cuda.set_device(self.decoder.first_device)
1215
+
1216
+ if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None:
1217
+ # get decoder inputs from shifting lm labels to the right
1218
+ decoder_input_ids = self._shift_right(labels)
1219
+
1220
+ # Set device for model parallelism
1221
+ if self.model_parallel:
1222
+ torch.cuda.set_device(self.decoder.first_device)
1223
+ hidden_states = hidden_states.to(self.decoder.first_device)
1224
+ if decoder_input_ids is not None:
1225
+ decoder_input_ids = decoder_input_ids.to(
1226
+ self.decoder.first_device)
1227
+ if attention_mask is not None:
1228
+ attention_mask = attention_mask.to(self.decoder.first_device)
1229
+ if decoder_attention_mask is not None:
1230
+ decoder_attention_mask = decoder_attention_mask.to(
1231
+ self.decoder.first_device)
1232
+
1233
+ # Decode
1234
+ decoder_outputs = self.decoder(
1235
+ input_ids=decoder_input_ids,
1236
+ attention_mask=decoder_attention_mask,
1237
+ inputs_embeds=decoder_inputs_embeds,
1238
+ past_key_values=past_key_values,
1239
+ encoder_hidden_states=hidden_states,
1240
+ encoder_attention_mask=attention_mask,
1241
+ head_mask=decoder_head_mask,
1242
+ cross_attn_head_mask=cross_attn_head_mask,
1243
+ use_cache=use_cache,
1244
+ output_attentions=output_attentions,
1245
+ output_hidden_states=output_hidden_states,
1246
+ return_dict=return_dict,
1247
+ #### NEW CODE ####
1248
+ delete_gate_mask=encoder_outputs.delete_gate_mask,
1249
+ #### NEW CODE ####
1250
+ )
1251
+
1252
+ sequence_output = decoder_outputs[0]
1253
+
1254
+ # Set device for model parallelism
1255
+ if self.model_parallel:
1256
+ torch.cuda.set_device(self.encoder.first_device)
1257
+ self.lm_head = self.lm_head.to(self.encoder.first_device)
1258
+ sequence_output = sequence_output.to(self.lm_head.weight.device)
1259
+
1260
+ if self.config.tie_word_embeddings:
1261
+ # Rescale output before projecting on vocab
1262
+ # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
1263
+ sequence_output = sequence_output * (self.model_dim**-0.5)
1264
+
1265
+ lm_logits = self.lm_head(sequence_output)
1266
+
1267
+ loss = None
1268
+ if labels is not None:
1269
+ loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
1270
+ # move labels to correct device to enable PP
1271
+ labels = labels.to(lm_logits.device)
1272
+ loss = loss_fct(
1273
+ lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1))
1274
+ # TODO(thom): Add z_loss https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666
1275
+
1276
+ if not return_dict:
1277
+ output = (lm_logits,) + decoder_outputs[1:] + encoder_outputs
1278
+ return ((loss,) + output) if loss is not None else output
1279
+
1280
+ ##### NEW CODE #####
1281
+ return MrT5Seq2SeqLMOutput(
1282
+ loss=loss,
1283
+ logits=lm_logits,
1284
+ past_key_values=decoder_outputs.past_key_values,
1285
+ decoder_hidden_states=decoder_outputs.hidden_states,
1286
+ decoder_attentions=decoder_outputs.attentions,
1287
+ cross_attentions=decoder_outputs.cross_attentions,
1288
+ encoder_last_hidden_state=encoder_outputs.last_hidden_state,
1289
+ encoder_hidden_states=encoder_outputs.hidden_states,
1290
+ encoder_attentions=encoder_outputs.attentions,
1291
+ delete_gate_mask=encoder_outputs.delete_gate_mask,
1292
+ delete_gate_output=encoder_outputs.delete_gate_output,
1293
+ delete_gate_logits=encoder_outputs.delete_gate_logits,
1294
+ encoder_keys=encoder_outputs.attention_keys,
1295
+ encoder_queries=encoder_outputs.attention_queries,
1296
+ encoder_values=encoder_outputs.attention_values,
1297
+ encoder_scores=encoder_outputs.attention_scores,
1298
+ decoder_keys=decoder_outputs.attention_keys,
1299
+ decoder_queries=decoder_outputs.attention_queries,
1300
+ decoder_values=decoder_outputs.attention_values,
1301
+ decoder_scores=decoder_outputs.attention_scores,
1302
+ cross_attention_queries=decoder_outputs.cross_attention_queries,
1303
+ cross_attention_keys=decoder_outputs.cross_attention_keys,
1304
+ cross_attention_values=decoder_outputs.cross_attention_values,
1305
+ cross_attention_scores=decoder_outputs.cross_attention_scores,
1306
+ )
1307
+ ##### NEW CODE #####
1308
+
1309
+ def prepare_inputs_for_generation(
1310
+ self,
1311
+ input_ids,
1312
+ past_key_values=None,
1313
+ attention_mask=None,
1314
+ head_mask=None,
1315
+ decoder_head_mask=None,
1316
+ decoder_attention_mask=None,
1317
+ cross_attn_head_mask=None,
1318
+ use_cache=None,
1319
+ encoder_outputs=None,
1320
+ **kwargs,
1321
+ ):
1322
+ # cut decoder_input_ids if past_key_values is used
1323
+ if past_key_values is not None:
1324
+ past_length = past_key_values[0][0].shape[2]
1325
+
1326
+ # Some generation methods already pass only the last input ID
1327
+ if input_ids.shape[1] > past_length:
1328
+ remove_prefix_length = past_length
1329
+ else:
1330
+ # Default to old behavior: keep only final ID
1331
+ remove_prefix_length = input_ids.shape[1] - 1
1332
+
1333
+ input_ids = input_ids[:, remove_prefix_length:]
1334
+
1335
+ ##### NEW CODE #####
1336
+ # TODO: Generation will need special handling of attention masks, which
1337
+ # will need to be resized if hard delete is enabled. For now, we will
1338
+ # simply omit the encoder attention mask for generation.
1339
+ attention_mask = None
1340
+ ##### NEW CODE #####
1341
+
1342
+ return {
1343
+ "decoder_input_ids": input_ids,
1344
+ "past_key_values": past_key_values,
1345
+ "encoder_outputs": encoder_outputs,
1346
+ "attention_mask": attention_mask,
1347
+ "head_mask": head_mask,
1348
+ "decoder_head_mask": decoder_head_mask,
1349
+ "decoder_attention_mask": decoder_attention_mask,
1350
+ "cross_attn_head_mask": cross_attn_head_mask,
1351
+ "use_cache": use_cache,
1352
+ }
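# A hypothetical end-to-end usage sketch, not part of modeling_mrt5.py. The
# checkpoint id is a placeholder, the byte-level ByT5 tokenizer is an assumption,
# and loading through the auto classes assumes the repository's config registers
# MrT5 for trust_remote_code; otherwise import MrT5ForConditionalGeneration directly.
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

tokenizer = AutoTokenizer.from_pretrained("google/byt5-small")
model = AutoModelForSeq2SeqLM.from_pretrained(
    "<mrt5-checkpoint>",        # placeholder repo id
    trust_remote_code=True,     # loads modeling_mrt5.py from the repo
)

inputs = tokenizer(["Translate to French: hello"], return_tensors="pt")
labels = tokenizer(["bonjour"], return_tensors="pt").input_ids

with torch.no_grad():
    out = model(
        input_ids=inputs.input_ids,
        attention_mask=inputs.attention_mask,
        labels=labels,
        hard_delete=True,   # physically drop gated-out tokens after the gate layer
    )

print(out.loss)
print(out.delete_gate_mask.shape)           # per-position gate mask from the encoder
print(out.encoder_last_hidden_state.shape)  # may be shorter than the input sequence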