Update modeling_ltgbert.py
Browse files- modeling_ltgbert.py +4 -1
modeling_ltgbert.py
CHANGED
@@ -346,7 +346,10 @@ class LtgbertModel(LtgbertPreTrainedModel):
|
|
346 |
if self.config.is_decoder:
|
347 |
attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) | torch.triu(torch.ones(seq_length, seq_length, dtype=torch.bool, device=device), 1).unsqueeze(0).unsqueeze(0)
|
348 |
else:
|
349 |
-
|
|
|
|
|
|
|
350 |
|
351 |
static_embeddings, relative_embedding = self.embedding(input_ids.t())
|
352 |
contextualized_embeddings, attention_probs = self.transformer(static_embeddings, attention_mask, relative_embedding)
|
|
|
346 |
if self.config.is_decoder:
|
347 |
attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) | torch.triu(torch.ones(seq_length, seq_length, dtype=torch.bool, device=device), 1).unsqueeze(0).unsqueeze(0)
|
348 |
else:
|
349 |
+
if len(attention_mask.size()) == 2:
|
350 |
+
attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
|
351 |
+
elif len(attention_mask.size()) == 3:
|
352 |
+
attention_mask = attention_mask.unsqueeze(1)
|
353 |
|
354 |
static_embeddings, relative_embedding = self.embedding(input_ids.t())
|
355 |
contextualized_embeddings, attention_probs = self.transformer(static_embeddings, attention_mask, relative_embedding)
|