Update modeling_deberta.py
Browse files- modeling_deberta.py +18 -14
modeling_deberta.py
CHANGED
@@ -1251,20 +1251,24 @@ class DebertaV2ForCausalLM(DebertaV2ForMaskedLM):
|
|
1251 |
],
|
1252 |
dim=-1
|
1253 |
)
|
1254 |
-
|
1255 |
-
|
1256 |
-
|
1257 |
-
|
1258 |
-
|
1259 |
-
|
1260 |
-
|
1261 |
-
|
1262 |
-
|
1263 |
-
|
1264 |
-
|
1265 |
-
|
1266 |
-
|
1267 |
-
|
|
|
|
|
|
|
|
|
1268 |
|
1269 |
outputs = super().forward(
|
1270 |
input_ids,
|
|
|
1251 |
],
|
1252 |
dim=-1
|
1253 |
)
|
1254 |
+
|
1255 |
+
if attention_mask is not None:
|
1256 |
+
attention_mask = torch.cat(
|
1257 |
+
[
|
1258 |
+
attention_mask,
|
1259 |
+
torch.full((batch_size, self.n_masks + 1), attention_mask[0, -1], device=attention_mask.device),
|
1260 |
+
],
|
1261 |
+
dim=-1
|
1262 |
+
)
|
1263 |
+
|
1264 |
+
if position_ids is not None:
|
1265 |
+
position_ids = torch.cat(
|
1266 |
+
[
|
1267 |
+
position_ids,
|
1268 |
+
torch.arange(0, self.n_masks + 1, device=position_ids.device).unsqueeze(0) + position_ids[:, -1:],
|
1269 |
+
],
|
1270 |
+
dim=-1
|
1271 |
+
)
|
1272 |
|
1273 |
outputs = super().forward(
|
1274 |
input_ids,
|