michael-guenther commited on
Commit
681845d
·
1 Parent(s): 54b019f

add option to output hidden states

Browse files
Files changed (1) hide show
  1. modeling_xlm_roberta.py +43 -10
modeling_xlm_roberta.py CHANGED
@@ -22,13 +22,13 @@ import torch.nn.functional as F
22
  import torch.utils.checkpoint
23
  from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
24
  from transformers import AutoTokenizer, PretrainedConfig
25
- from transformers.modeling_outputs import (MaskedLMOutput,
26
- SequenceClassifierOutput)
27
  from transformers.modeling_utils import PreTrainedModel
28
  from transformers.models.bert.modeling_bert import (
29
- BaseModelOutputWithPoolingAndCrossAttentions, BertForPreTrainingOutput)
30
- from transformers.models.xlm_roberta.modeling_xlm_roberta import \
31
- XLMRobertaLMHead
 
32
 
33
  from .rotary import RotaryEmbedding
34
  from .block import Block
@@ -195,17 +195,30 @@ class XLMRobertaEncoder(nn.Module):
195
  self._grad_checkpointing = value
196
 
197
  def forward(
198
- self, hidden_states, key_padding_mask=None, subset_mask=None, adapter_mask=None
 
 
 
 
 
199
  ):
200
  """If subset_mask is not None, we only want output for the subset of the sequence.
201
  This means that we only compute the last layer output for these tokens.
202
  subset_mask: (batch, seqlen), dtype=torch.bool
203
  """
 
 
 
 
 
 
204
  if key_padding_mask is None or not self.use_flash_attn:
205
  mixer_kwargs = {"adapter_mask": adapter_mask}
206
  if key_padding_mask is not None:
207
  mixer_kwargs["key_padding_mask"] = key_padding_mask.bool()
208
  for layer in self.layers:
 
 
209
  if self._grad_checkpointing:
210
  hidden_states = torch.utils.checkpoint.checkpoint(
211
  layer,
@@ -215,10 +228,14 @@ class XLMRobertaEncoder(nn.Module):
215
  )
216
  else:
217
  hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
 
 
218
  if subset_mask is not None:
219
  hidden_states = hidden_states[subset_mask]
220
  else:
221
  batch, seqlen = hidden_states.shape[:2]
 
 
222
  hidden_states, indices, cu_seqlens, max_seqlen_in_batch, cu_adapter_mask = (
223
  unpad_input(hidden_states, key_padding_mask, adapter_mask)
224
  )
@@ -239,6 +256,10 @@ class XLMRobertaEncoder(nn.Module):
239
  )
240
  else:
241
  hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
 
 
 
 
242
  hidden_states = pad_input(hidden_states, indices, batch, seqlen)
243
  else:
244
  for layer in self.layers[:-1]:
@@ -291,7 +312,7 @@ class XLMRobertaEncoder(nn.Module):
291
  hidden_states = self.layers[-1](
292
  hidden_states_subset, mixer_kwargs=mixer_kwargs
293
  )
294
- return hidden_states
295
 
296
 
297
  class XLMRobertaPooler(nn.Module):
@@ -588,7 +609,7 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
588
  embeddings = self.mean_pooling(
589
  token_embs, encoded_input["attention_mask"]
590
  )
591
-
592
  all_embeddings.extend(embeddings)
593
 
594
  all_embeddings = [all_embeddings[idx] for idx in inverse_permutation]
@@ -596,9 +617,12 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
596
  truncate_dim = truncate_dim or self.config.truncate_dim
597
  if truncate_dim:
598
  all_embeddings = self.truncate_embeddings(all_embeddings, truncate_dim)
599
-
600
  if normalize_embeddings:
601
- all_embeddings = [torch.nn.functional.normalize(embedding, p=2, dim=0) for embedding in all_embeddings]
 
 
 
602
 
603
  if convert_to_tensor:
604
  all_embeddings = torch.stack(all_embeddings)
@@ -659,6 +683,7 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
659
  attention_mask=None,
660
  masked_tokens_mask=None,
661
  return_dict=None,
 
662
  **kwargs,
663
  ):
664
  """If masked_tokens_mask is not None (i.e. last_layer_subset == True in XLMForPreTraining),
@@ -711,8 +736,15 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
711
  key_padding_mask=attention_mask,
712
  subset_mask=subset_mask,
713
  adapter_mask=adapter_mask,
 
714
  )
715
 
 
 
 
 
 
 
716
  if masked_tokens_mask is None:
717
  pooled_output = (
718
  self.pooler(sequence_output, adapter_mask=adapter_mask)
@@ -742,6 +774,7 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
742
  return BaseModelOutputWithPoolingAndCrossAttentions(
743
  last_hidden_state=sequence_output,
744
  pooler_output=pooled_output,
 
745
  )
746
 
747
 
 
22
  import torch.utils.checkpoint
23
  from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
24
  from transformers import AutoTokenizer, PretrainedConfig
25
+ from transformers.modeling_outputs import MaskedLMOutput, SequenceClassifierOutput
 
26
  from transformers.modeling_utils import PreTrainedModel
27
  from transformers.models.bert.modeling_bert import (
28
+ BaseModelOutputWithPoolingAndCrossAttentions,
29
+ BertForPreTrainingOutput,
30
+ )
31
+ from transformers.models.xlm_roberta.modeling_xlm_roberta import XLMRobertaLMHead
32
 
33
  from .rotary import RotaryEmbedding
34
  from .block import Block
 
195
  self._grad_checkpointing = value
196
 
197
  def forward(
198
+ self,
199
+ hidden_states,
200
+ key_padding_mask=None,
201
+ subset_mask=None,
202
+ adapter_mask=None,
203
+ output_hidden_states: Optional[bool] = None,
204
  ):
205
  """If subset_mask is not None, we only want output for the subset of the sequence.
206
  This means that we only compute the last layer output for these tokens.
207
  subset_mask: (batch, seqlen), dtype=torch.bool
208
  """
209
+
210
+ all_hidden_states = () if output_hidden_states else None
211
+
212
+ if output_hidden_states and subset_mask:
213
+ raise ValueError('output_hidden_states is not supported for subset_masks')
214
+
215
  if key_padding_mask is None or not self.use_flash_attn:
216
  mixer_kwargs = {"adapter_mask": adapter_mask}
217
  if key_padding_mask is not None:
218
  mixer_kwargs["key_padding_mask"] = key_padding_mask.bool()
219
  for layer in self.layers:
220
+ if output_hidden_states:
221
+ all_hidden_states = all_hidden_states + (hidden_states,)
222
  if self._grad_checkpointing:
223
  hidden_states = torch.utils.checkpoint.checkpoint(
224
  layer,
 
228
  )
229
  else:
230
  hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
231
+ if output_hidden_states:
232
+ all_hidden_states = all_hidden_states + (hidden_states,)
233
  if subset_mask is not None:
234
  hidden_states = hidden_states[subset_mask]
235
  else:
236
  batch, seqlen = hidden_states.shape[:2]
237
+ if output_hidden_states:
238
+ all_hidden_states = all_hidden_states + (hidden_states,)
239
  hidden_states, indices, cu_seqlens, max_seqlen_in_batch, cu_adapter_mask = (
240
  unpad_input(hidden_states, key_padding_mask, adapter_mask)
241
  )
 
256
  )
257
  else:
258
  hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
259
+ if output_hidden_states:
260
+ all_hidden_states = all_hidden_states + (
261
+ pad_input(hidden_states, indices, batch, seqlen),
262
+ )
263
  hidden_states = pad_input(hidden_states, indices, batch, seqlen)
264
  else:
265
  for layer in self.layers[:-1]:
 
312
  hidden_states = self.layers[-1](
313
  hidden_states_subset, mixer_kwargs=mixer_kwargs
314
  )
315
+ return all_hidden_states if output_hidden_states else hidden_states
316
 
317
 
318
  class XLMRobertaPooler(nn.Module):
 
609
  embeddings = self.mean_pooling(
610
  token_embs, encoded_input["attention_mask"]
611
  )
612
+
613
  all_embeddings.extend(embeddings)
614
 
615
  all_embeddings = [all_embeddings[idx] for idx in inverse_permutation]
 
617
  truncate_dim = truncate_dim or self.config.truncate_dim
618
  if truncate_dim:
619
  all_embeddings = self.truncate_embeddings(all_embeddings, truncate_dim)
620
+
621
  if normalize_embeddings:
622
+ all_embeddings = [
623
+ torch.nn.functional.normalize(embedding, p=2, dim=0)
624
+ for embedding in all_embeddings
625
+ ]
626
 
627
  if convert_to_tensor:
628
  all_embeddings = torch.stack(all_embeddings)
 
683
  attention_mask=None,
684
  masked_tokens_mask=None,
685
  return_dict=None,
686
+ output_hidden_states=None,
687
  **kwargs,
688
  ):
689
  """If masked_tokens_mask is not None (i.e. last_layer_subset == True in XLMForPreTraining),
 
736
  key_padding_mask=attention_mask,
737
  subset_mask=subset_mask,
738
  adapter_mask=adapter_mask,
739
+ output_hidden_states=output_hidden_states,
740
  )
741
 
742
+ if output_hidden_states:
743
+ all_hidden_states = sequence_output
744
+ sequence_output = sequence_output[-1]
745
+ else:
746
+ all_hidden_states = None
747
+
748
  if masked_tokens_mask is None:
749
  pooled_output = (
750
  self.pooler(sequence_output, adapter_mask=adapter_mask)
 
774
  return BaseModelOutputWithPoolingAndCrossAttentions(
775
  last_hidden_state=sequence_output,
776
  pooler_output=pooled_output,
777
+ hidden_states=all_hidden_states,
778
  )
779
 
780