ltg
/

davda54 committed on
Commit
3586af3
1 Parent(s): 803f977

Update modeling_deberta.py

Browse files
Files changed (1) hide show
  1. modeling_deberta.py +18 -14
modeling_deberta.py CHANGED
@@ -1251,20 +1251,24 @@ class DebertaV2ForCausalLM(DebertaV2ForMaskedLM):
1251
  ],
1252
  dim=-1
1253
  )
1254
- attention_mask = torch.cat(
1255
- [
1256
- attention_mask,
1257
- torch.full((batch_size, self.n_masks + 1), attention_mask[0, -1], device=attention_mask.device),
1258
- ],
1259
- dim=-1
1260
- )
1261
- position_ids = torch.cat(
1262
- [
1263
- position_ids,
1264
- torch.arange(0, self.n_masks + 1, device=position_ids.device).unsqueeze(0) + position_ids[:, -1:],
1265
- ],
1266
- dim=-1
1267
- )
 
 
 
 
1268
 
1269
  outputs = super().forward(
1270
  input_ids,
 
1251
  ],
1252
  dim=-1
1253
  )
1254
+
1255
+ if attention_mask is not None:
1256
+ attention_mask = torch.cat(
1257
+ [
1258
+ attention_mask,
1259
+ torch.full((batch_size, self.n_masks + 1), attention_mask[0, -1], device=attention_mask.device),
1260
+ ],
1261
+ dim=-1
1262
+ )
1263
+
1264
+ if position_ids is not None:
1265
+ position_ids = torch.cat(
1266
+ [
1267
+ position_ids,
1268
+ torch.arange(0, self.n_masks + 1, device=position_ids.device).unsqueeze(0) + position_ids[:, -1:],
1269
+ ],
1270
+ dim=-1
1271
+ )
1272
 
1273
  outputs = super().forward(
1274
  input_ids,