ltg
/

lgcharpe commited on
Commit
acf5fc7
·
verified ·
1 Parent(s): 691a253

Update modeling_ltgbert.py

Browse files
Files changed (1) hide show
  1. modeling_ltgbert.py +4 -1
modeling_ltgbert.py CHANGED
@@ -346,7 +346,10 @@ class LtgbertModel(LtgbertPreTrainedModel):
346
  if self.config.is_decoder:
347
  attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) | torch.triu(torch.ones(seq_length, seq_length, dtype=torch.bool, device=device), 1).unsqueeze(0).unsqueeze(0)
348
  else:
349
- attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
 
 
 
350
 
351
  static_embeddings, relative_embedding = self.embedding(input_ids.t())
352
  contextualized_embeddings, attention_probs = self.transformer(static_embeddings, attention_mask, relative_embedding)
 
346
  if self.config.is_decoder:
347
  attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) | torch.triu(torch.ones(seq_length, seq_length, dtype=torch.bool, device=device), 1).unsqueeze(0).unsqueeze(0)
348
  else:
349
+ if len(attention_mask.size()) == 2:
350
+ attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
351
+ elif len(attention_mask.size()) == 3:
352
+ attention_mask = attention_mask.unsqueeze(1)
353
 
354
  static_embeddings, relative_embedding = self.embedding(input_ids.t())
355
  contextualized_embeddings, attention_probs = self.transformer(static_embeddings, attention_mask, relative_embedding)