Fix Flash with batch inputs
Browse files- modeling_eurobert.py +1 -1
modeling_eurobert.py
CHANGED
@@ -526,7 +526,7 @@ class EuroBertModel(EuroBertPreTrainedModel):
|
|
526 |
if attention_mask is not None and self.config._attn_implementation != "flash_attention_2":
|
527 |
mask = self.mask_converter.to_4d(attention_mask, attention_mask.shape[1], inputs_embeds.dtype)
|
528 |
else:
|
529 |
-
mask =
|
530 |
|
531 |
hidden_states = inputs_embeds
|
532 |
|
|
|
526 |
if attention_mask is not None and self.config._attn_implementation != "flash_attention_2":
|
527 |
mask = self.mask_converter.to_4d(attention_mask, attention_mask.shape[1], inputs_embeds.dtype)
|
528 |
else:
|
529 |
+
mask = attention_mask
|
530 |
|
531 |
hidden_states = inputs_embeds
|
532 |
|