stallone commited on
Commit
b6e04ba
·
1 Parent(s): 3872a50

Support for PT2C

Browse files
Files changed (1) hide show
  1. modeling_mpt.py +1 -1
modeling_mpt.py CHANGED
@@ -152,7 +152,7 @@ class MPTModel(MPTPreTrainedModel):
152
  if output_attentions:
153
  if self.attn_impl != 'torch':
154
  raise NotImplementedError('output_attentions is not implemented for MPT when using attn_impl `flash` or `triton`.')
155
- if attention_mask is not None and attention_mask[:, 0].sum() != attention_mask.shape[0] and self.training:
156
  raise NotImplementedError('MPT does not support training with left padding.')
157
  if self.prefix_lm and prefix_mask is None:
158
  raise ValueError('prefix_mask is a required argument when MPT is configured with prefix_lm=True.')
 
152
  if output_attentions:
153
  if self.attn_impl != 'torch':
154
  raise NotImplementedError('output_attentions is not implemented for MPT when using attn_impl `flash` or `triton`.')
155
+ if self.training and attention_mask is not None and attention_mask[:, 0].sum() != attention_mask.shape[0]:
156
  raise NotImplementedError('MPT does not support training with left padding.')
157
  if self.prefix_lm and prefix_mask is None:
158
  raise ValueError('prefix_mask is a required argument when MPT is configured with prefix_lm=True.')