Support for PT2C
modeling_mpt.py +1 -1
modeling_mpt.py
CHANGED
@@ -152,7 +152,7 @@ class MPTModel(MPTPreTrainedModel):
         if output_attentions:
             if self.attn_impl != 'torch':
                 raise NotImplementedError('output_attentions is not implemented for MPT when using attn_impl `flash` or `triton`.')
-        if attention_mask is not None and attention_mask[:, 0].sum() != attention_mask.shape[0]:
+        if self.training and attention_mask is not None and attention_mask[:, 0].sum() != attention_mask.shape[0]:
             raise NotImplementedError('MPT does not support training with left padding.')
         if self.prefix_lm and prefix_mask is None:
             raise ValueError('prefix_mask is a required argument when MPT is configured with prefix_lm=True.')
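For context, here is a minimal standalone sketch of what the changed guard does; the helper check_left_padding is hypothetical and not part of modeling_mpt.py, it only mirrors the condition shown in the diff. With self.training added, a left-padded batch trips the NotImplementedError only in training mode, while inference with left padding proceeds.

# Minimal sketch of the left-padding guard after this commit (assumes only torch).
import torch

def check_left_padding(attention_mask: torch.Tensor, training: bool) -> None:
    # attention_mask[:, 0].sum() != attention_mask.shape[0] is true whenever at
    # least one sequence in the batch has a 0 in its first position, i.e. is
    # left-padded. After this commit the error is raised only while training.
    if training and attention_mask is not None and attention_mask[:, 0].sum() != attention_mask.shape[0]:
        raise NotImplementedError('MPT does not support training with left padding.')

# Left-padded batch: the first row starts with padding tokens.
mask = torch.tensor([[0, 0, 1, 1], [1, 1, 1, 1]])

check_left_padding(mask, training=False)   # inference: passes after this change
try:
    check_left_padding(mask, training=True)  # training: still raises
except NotImplementedError as e:
    print(e)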