mohdelgaar committed
Commit f727cfc · 1 Parent(s): 56ce0db

upload file
Files changed (1)
  1. lingconv_t5.py +453 -0

lingconv_t5.py (ADDED)
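
"""LingConv T5: a T5 variant that can inject a linguistic embedding into its decoder.

`LingConvT5Stack` extends `T5Stack` with two config-driven attributes,
`ling_injection_layer` and `ling_injection_type`, and adds the `ling_embed`
tensor to the decoder hidden states after the configured block.
`LingConvT5ForConditionalGeneration` swaps this stack in as the decoder and
threads `ling_embed` through its `forward` pass.
"""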
import warnings
import copy
from typing import Optional, Tuple, Union

import torch
from torch.nn import CrossEntropyLoss

from transformers.modeling_outputs import (
    BaseModelOutput,
    Seq2SeqLMOutput,
    BaseModelOutputWithPastAndCrossAttentions,
)
from transformers.models.t5.modeling_t5 import T5Stack, T5ForConditionalGeneration, __HEAD_MASK_WARNING_MSG
from transformers import T5Config

from transformers.cache_utils import Cache, DynamicCache, EncoderDecoderCache
from transformers.utils import (
    is_torchdynamo_compiling,
)
import logging

logger = logging.getLogger(__name__)

class LingConvT5Stack(T5Stack):
    def __init__(self, config: T5Config, embed_tokens=None):
        super().__init__(config, embed_tokens)

        # Add new attributes for ling injection
        self.ling_injection_layer = getattr(config, 'ling_injection_layer', -1)
        self.ling_injection_type = getattr(config, 'ling_injection_type', 'none')  # 'none', 'first', 'all'

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        inputs_embeds=None,
        head_mask=None,
        cross_attn_head_mask=None,
        past_key_values=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        cache_position=None,
        ling_embed=None,
    ):
        # Model parallel
        if self.model_parallel:
            torch.cuda.set_device(self.first_device)
            self.embed_tokens = self.embed_tokens.to(self.first_device)
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if input_ids is not None and inputs_embeds is not None:
            err_msg_prefix = "decoder_" if self.is_decoder else ""
            raise ValueError(
                f"You cannot specify both {err_msg_prefix}input_ids and {err_msg_prefix}inputs_embeds at the same time"
            )
        elif input_ids is not None:
            input_shape = input_ids.size()
            input_ids = input_ids.view(-1, input_shape[-1])
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            err_msg_prefix = "decoder_" if self.is_decoder else ""
            raise ValueError(f"You have to specify either {err_msg_prefix}input_ids or {err_msg_prefix}inputs_embeds")

        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        if inputs_embeds is None:
            if self.embed_tokens is None:
                raise ValueError("You have to initialize the model with valid token embeddings")
            inputs_embeds = self.embed_tokens(input_ids)

        batch_size, seq_length = input_shape

        if use_cache is True:
            if not self.is_decoder:
                raise ValueError(f"`use_cache` can only be set to `True` if {self} is used as a decoder")

        # initialize past_key_values
        return_legacy_cache = False
        return_self_attention_cache = False
        if self.is_decoder and (use_cache or past_key_values is not None):
            if isinstance(past_key_values, Cache) and not isinstance(past_key_values, EncoderDecoderCache):
                return_self_attention_cache = True
                past_key_values = EncoderDecoderCache(past_key_values, DynamicCache())
            elif not isinstance(past_key_values, EncoderDecoderCache):
                return_legacy_cache = True
                logger.warning_once(
                    "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.48.0. "
                    "You should pass an instance of `EncoderDecoderCache` instead, e.g. "
                    "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`."
                )
                past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values)
            elif past_key_values is None:
                past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache())
        elif not self.is_decoder:
            # do not pass cache object down the line for encoder stack
            # it messes indexing later in decoder-stack because cache object is modified in-place
            past_key_values = None

        past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0
        if cache_position is None:
            cache_position = torch.arange(
                past_key_values_length, past_key_values_length + seq_length, device=inputs_embeds.device
            )

        if attention_mask is None and not is_torchdynamo_compiling():
            # required mask seq length can be calculated via length of past cache
            mask_seq_length = past_key_values_length + seq_length
            attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)

        if self.config.is_decoder:
            causal_mask = self._update_causal_mask(
                attention_mask,
                inputs_embeds,
                cache_position,
                past_key_values.self_attention_cache if past_key_values is not None else None,
                output_attentions,
            )
        elif attention_mask is not None:
            causal_mask = attention_mask[:, None, None, :]
            causal_mask = causal_mask.to(dtype=inputs_embeds.dtype)
            causal_mask = (1.0 - causal_mask) * torch.finfo(inputs_embeds.dtype).min
        else:
            causal_mask = None

        # If a 2D or 3D attention mask is provided for the cross-attention
        # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
        if self.is_decoder and encoder_hidden_states is not None:
            encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
            encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
            if encoder_attention_mask is None:
                encoder_attention_mask = torch.ones(
                    encoder_hidden_shape, device=inputs_embeds.device, dtype=torch.long
                )
            encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
        else:
            encoder_extended_attention_mask = None

        # Prepare head mask if needed
        head_mask = self.get_head_mask(head_mask, self.config.num_layers)
        cross_attn_head_mask = self.get_head_mask(cross_attn_head_mask, self.config.num_layers)
        all_hidden_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None
        all_cross_attentions = () if (output_attentions and self.is_decoder) else None
        position_bias = None
        encoder_decoder_position_bias = None

        hidden_states = self.dropout(inputs_embeds)

        for i, layer_module in enumerate(self.block):
            layer_head_mask = head_mask[i]
            cross_attn_layer_head_mask = cross_attn_head_mask[i]
            # Model parallel
            if self.model_parallel:
                torch.cuda.set_device(hidden_states.device)
                # Ensure that attention_mask is always on the same device as hidden_states
                if causal_mask is not None:
                    causal_mask = causal_mask.to(hidden_states.device)
                if position_bias is not None:
                    position_bias = position_bias.to(hidden_states.device)
                if encoder_hidden_states is not None:
                    encoder_hidden_states = encoder_hidden_states.to(hidden_states.device)
                if encoder_extended_attention_mask is not None:
                    encoder_extended_attention_mask = encoder_extended_attention_mask.to(hidden_states.device)
                if encoder_decoder_position_bias is not None:
                    encoder_decoder_position_bias = encoder_decoder_position_bias.to(hidden_states.device)
                if layer_head_mask is not None:
                    layer_head_mask = layer_head_mask.to(hidden_states.device)
                if cross_attn_layer_head_mask is not None:
                    cross_attn_layer_head_mask = cross_attn_layer_head_mask.to(hidden_states.device)
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    layer_module.forward,
                    hidden_states,
                    causal_mask,
                    position_bias,
                    encoder_hidden_states,
                    encoder_extended_attention_mask,
                    encoder_decoder_position_bias,
                    layer_head_mask,
                    cross_attn_layer_head_mask,
                    None,  # past_key_value is always None with gradient checkpointing
                    use_cache,
                    output_attentions,
                    return_dict,
                    cache_position,
                )
            else:
                layer_outputs = layer_module(
                    hidden_states,
                    attention_mask=causal_mask,
                    position_bias=position_bias,
                    encoder_hidden_states=encoder_hidden_states,
                    encoder_attention_mask=encoder_extended_attention_mask,
                    encoder_decoder_position_bias=encoder_decoder_position_bias,
                    layer_head_mask=layer_head_mask,
                    cross_attn_layer_head_mask=cross_attn_layer_head_mask,
                    past_key_value=past_key_values,
                    use_cache=use_cache,
                    output_attentions=output_attentions,
                    return_dict=return_dict,
                    cache_position=cache_position,
                )

            # layer_outputs is a tuple with:
            # hidden-states, key-value-states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights)
            if use_cache is False:
                layer_outputs = layer_outputs[:1] + (None,) + layer_outputs[1:]

            hidden_states, next_decoder_cache = layer_outputs[:2]

            # Add linguistic embedding injection after the specified layer.
            # `ling_embed` must be broadcastable to `hidden_states`,
            # e.g. (batch_size, 1, d_model) or (batch_size, seq_length, d_model).
            if (self.is_decoder and
                    self.ling_injection_layer == i and
                    ling_embed is not None and
                    self.ling_injection_type != 'none'):
                hidden_states = hidden_states + ling_embed

            # We share the position biases between the layers - the first layer stores them
            # layer_outputs = hidden-states, key-value-states (self-attention position bias), (self-attention weights),
            # (cross-attention position bias), (cross-attention weights)
            position_bias = layer_outputs[2]
            if self.is_decoder and encoder_hidden_states is not None:
                encoder_decoder_position_bias = layer_outputs[4 if output_attentions else 3]

            if output_attentions:
                all_attentions = all_attentions + (layer_outputs[3],)
                if self.is_decoder:
                    all_cross_attentions = all_cross_attentions + (layer_outputs[5],)

            # Model Parallel: If it's the last layer for that device, put things on the next device
            if self.model_parallel:
                for k, v in self.device_map.items():
                    if i == v[-1] and "cuda:" + str(k) != self.last_device:
                        hidden_states = hidden_states.to("cuda:" + str(k + 1))

        hidden_states = self.final_layer_norm(hidden_states)
        hidden_states = self.dropout(hidden_states)

        # Add last layer
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        next_cache = next_decoder_cache if use_cache else None
        if return_self_attention_cache:
            next_cache = past_key_values.self_attention_cache
        if return_legacy_cache:
            next_cache = past_key_values.to_legacy_cache()

        if not return_dict:
            return tuple(
                v
                for v in [
                    hidden_states,
                    next_cache,
                    all_hidden_states,
                    all_attentions,
                    all_cross_attentions,
                ]
                if v is not None
            )
        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_attentions,
            cross_attentions=all_cross_attentions,
        )

class LingConvT5ForConditionalGeneration(T5ForConditionalGeneration):
    def __init__(self, config):
        super().__init__(config)
        # Replace default decoder with our custom decoder
        decoder_config = copy.deepcopy(config)
        decoder_config.is_decoder = True
        decoder_config.is_encoder_decoder = False
        decoder_config.num_layers = config.num_decoder_layers
        self.decoder = LingConvT5Stack(decoder_config, embed_tokens=self.shared)

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.BoolTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        decoder_head_mask: Optional[torch.FloatTensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        encoder_outputs: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        ling_embed: Optional[torch.FloatTensor] = None,
    ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the sequence-to-sequence language modeling loss. Indices should be in `[-100, 0, ...,
            config.vocab_size - 1]`. All labels set to `-100` are ignored (masked); the loss is only computed for
            labels in `[0, ..., config.vocab_size - 1]`.

        Returns:

        Examples:

        ```python
        >>> from transformers import AutoTokenizer, T5ForConditionalGeneration

        >>> tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-small")
        >>> model = T5ForConditionalGeneration.from_pretrained("google-t5/t5-small")

        >>> # training
        >>> input_ids = tokenizer("The <extra_id_0> walks in <extra_id_1> park", return_tensors="pt").input_ids
        >>> labels = tokenizer("<extra_id_0> cute dog <extra_id_1> the <extra_id_2>", return_tensors="pt").input_ids
        >>> outputs = model(input_ids=input_ids, labels=labels)
        >>> loss = outputs.loss
        >>> logits = outputs.logits

        >>> # inference
        >>> input_ids = tokenizer(
        ...     "summarize: studies have shown that owning a dog is good for you", return_tensors="pt"
        ... ).input_ids  # Batch size 1
        >>> outputs = model.generate(input_ids)
        >>> print(tokenizer.decode(outputs[0], skip_special_tokens=True))
        >>> # studies have shown that owning a dog is good for you.
        ```"""
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
        if head_mask is not None and decoder_head_mask is None:
            if self.config.num_layers == self.config.num_decoder_layers:
                warnings.warn(__HEAD_MASK_WARNING_MSG, FutureWarning)
                decoder_head_mask = head_mask

        # Encode if needed (training, first prediction pass)
        if encoder_outputs is None:
            # Convert encoder inputs into embeddings if needed
            encoder_outputs = self.encoder(
                input_ids=input_ids,
                attention_mask=attention_mask,
                inputs_embeds=inputs_embeds,
                head_mask=head_mask,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )
        elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
            encoder_outputs = BaseModelOutput(
                last_hidden_state=encoder_outputs[0],
                hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
                attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
            )

        hidden_states = encoder_outputs[0]

        if self.model_parallel:
            torch.cuda.set_device(self.decoder.first_device)

        if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None:
            # get decoder inputs from shifting lm labels to the right
            decoder_input_ids = self._shift_right(labels)

        # Set device for model parallelism
        if self.model_parallel:
            torch.cuda.set_device(self.decoder.first_device)
            hidden_states = hidden_states.to(self.decoder.first_device)
            if decoder_input_ids is not None:
                decoder_input_ids = decoder_input_ids.to(self.decoder.first_device)
            if attention_mask is not None:
                attention_mask = attention_mask.to(self.decoder.first_device)
            if decoder_attention_mask is not None:
                decoder_attention_mask = decoder_attention_mask.to(self.decoder.first_device)

        # Decode
        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            inputs_embeds=decoder_inputs_embeds,
            past_key_values=past_key_values,
            encoder_hidden_states=hidden_states,
            encoder_attention_mask=attention_mask,
            head_mask=decoder_head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
            ling_embed=ling_embed,
        )

        sequence_output = decoder_outputs[0]

        # Set device for model parallelism
        if self.model_parallel:
            torch.cuda.set_device(self.encoder.first_device)
            self.lm_head = self.lm_head.to(self.encoder.first_device)
            sequence_output = sequence_output.to(self.lm_head.weight.device)

        if self.config.tie_word_embeddings:
            # Rescale output before projecting on vocab
            # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
            sequence_output = sequence_output * (self.model_dim**-0.5)

        lm_logits = self.lm_head(sequence_output)

        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss(ignore_index=-100)
            # move labels to correct device to enable PP
            labels = labels.to(lm_logits.device)
            loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1))
            # TODO(thom): Add z_loss https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666

        if not return_dict:
            output = (lm_logits,) + decoder_outputs[1:] + encoder_outputs
            return ((loss,) + output) if loss is not None else output

        return Seq2SeqLMOutput(
            loss=loss,
            logits=lm_logits,
            past_key_values=decoder_outputs.past_key_values,
            decoder_hidden_states=decoder_outputs.hidden_states,
            decoder_attentions=decoder_outputs.attentions,
            cross_attentions=decoder_outputs.cross_attentions,
            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
            encoder_hidden_states=encoder_outputs.hidden_states,
            encoder_attentions=encoder_outputs.attentions,
        )
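

# Minimal usage sketch, assuming a stock "google-t5/t5-small" checkpoint and a
# linguistic embedding of shape (batch_size, 1, d_model) that broadcasts over
# the decoder sequence; how `ling_embed` is produced is outside this module.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    checkpoint = "google-t5/t5-small"
    config = T5Config.from_pretrained(checkpoint)
    config.ling_injection_layer = 2      # inject after decoder block index 2
    config.ling_injection_type = "all"   # any value other than 'none' enables injection

    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    model = LingConvT5ForConditionalGeneration.from_pretrained(checkpoint, config=config)

    inputs = tokenizer("paraphrase: The dog chased the ball.", return_tensors="pt")
    labels = tokenizer("The ball was chased by the dog.", return_tensors="pt").input_ids

    # Dummy linguistic embedding; a real one would come from a separate encoder.
    ling_embed = torch.zeros(1, 1, config.d_model)

    outputs = model(**inputs, labels=labels, ling_embed=ling_embed)
    print(outputs.loss)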