Update modeling_deberta.py
Browse files- modeling_deberta.py +394 -0
modeling_deberta.py
CHANGED
@@ -1348,3 +1348,397 @@ class DebertaV2OnlyMLMHead(nn.Module):
|
|
1348 |
prediction_scores = self.predictions(sequence_output)
|
1349 |
return prediction_scores
|
1350 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1348 |
prediction_scores = self.predictions(sequence_output)
|
1349 |
return prediction_scores
|
1350 |
|
1351 |
+
|
1352 |
+
@add_start_docstrings(
|
1353 |
+
"""
|
1354 |
+
DeBERTa Model transformer with a sequence classification/regression head on top (a linear layer on top of the
|
1355 |
+
pooled output) e.g. for GLUE tasks.
|
1356 |
+
""",
|
1357 |
+
DEBERTA_START_DOCSTRING,
|
1358 |
+
)
|
1359 |
+
class DebertaV2ForSequenceClassification(DebertaV2PreTrainedModel):
|
1360 |
+
def __init__(self, config):
|
1361 |
+
super().__init__(config)
|
1362 |
+
|
1363 |
+
num_labels = getattr(config, "num_labels", 2)
|
1364 |
+
self.num_labels = num_labels
|
1365 |
+
|
1366 |
+
self.deberta = DebertaV2Model(config)
|
1367 |
+
self.pooler = ContextPooler(config)
|
1368 |
+
output_dim = self.pooler.output_dim
|
1369 |
+
|
1370 |
+
self.classifier = nn.Linear(output_dim, num_labels)
|
1371 |
+
drop_out = getattr(config, "cls_dropout", None)
|
1372 |
+
drop_out = self.config.hidden_dropout_prob if drop_out is None else drop_out
|
1373 |
+
self.dropout = StableDropout(drop_out)
|
1374 |
+
|
1375 |
+
# Initialize weights and apply final processing
|
1376 |
+
self.post_init()
|
1377 |
+
|
1378 |
+
def get_input_embeddings(self):
|
1379 |
+
return self.deberta.get_input_embeddings()
|
1380 |
+
|
1381 |
+
def set_input_embeddings(self, new_embeddings):
|
1382 |
+
self.deberta.set_input_embeddings(new_embeddings)
|
1383 |
+
|
1384 |
+
@add_start_docstrings_to_model_forward(DEBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
1385 |
+
@add_code_sample_docstrings(
|
1386 |
+
checkpoint=_CHECKPOINT_FOR_DOC,
|
1387 |
+
output_type=SequenceClassifierOutput,
|
1388 |
+
config_class=_CONFIG_FOR_DOC,
|
1389 |
+
)
|
1390 |
+
# Copied from transformers.models.deberta.modeling_deberta.DebertaForSequenceClassification.forward with Deberta->DebertaV2
|
1391 |
+
def forward(
|
1392 |
+
self,
|
1393 |
+
input_ids: Optional[torch.Tensor] = None,
|
1394 |
+
attention_mask: Optional[torch.Tensor] = None,
|
1395 |
+
token_type_ids: Optional[torch.Tensor] = None,
|
1396 |
+
position_ids: Optional[torch.Tensor] = None,
|
1397 |
+
inputs_embeds: Optional[torch.Tensor] = None,
|
1398 |
+
labels: Optional[torch.Tensor] = None,
|
1399 |
+
output_attentions: Optional[bool] = None,
|
1400 |
+
output_hidden_states: Optional[bool] = None,
|
1401 |
+
return_dict: Optional[bool] = None,
|
1402 |
+
) -> Union[Tuple, SequenceClassifierOutput]:
|
1403 |
+
r"""
|
1404 |
+
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
1405 |
+
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
|
1406 |
+
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
|
1407 |
+
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
1408 |
+
"""
|
1409 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
1410 |
+
|
1411 |
+
outputs = self.deberta(
|
1412 |
+
input_ids,
|
1413 |
+
token_type_ids=token_type_ids,
|
1414 |
+
attention_mask=attention_mask,
|
1415 |
+
position_ids=position_ids,
|
1416 |
+
inputs_embeds=inputs_embeds,
|
1417 |
+
output_attentions=output_attentions,
|
1418 |
+
output_hidden_states=output_hidden_states,
|
1419 |
+
return_dict=return_dict,
|
1420 |
+
)
|
1421 |
+
|
1422 |
+
encoder_layer = outputs[0]
|
1423 |
+
pooled_output = self.pooler(encoder_layer)
|
1424 |
+
pooled_output = self.dropout(pooled_output)
|
1425 |
+
logits = self.classifier(pooled_output)
|
1426 |
+
|
1427 |
+
loss = None
|
1428 |
+
if labels is not None:
|
1429 |
+
if self.config.problem_type is None:
|
1430 |
+
if self.num_labels == 1:
|
1431 |
+
# regression task
|
1432 |
+
loss_fn = nn.MSELoss()
|
1433 |
+
logits = logits.view(-1).to(labels.dtype)
|
1434 |
+
loss = loss_fn(logits, labels.view(-1))
|
1435 |
+
elif labels.dim() == 1 or labels.size(-1) == 1:
|
1436 |
+
label_index = (labels >= 0).nonzero()
|
1437 |
+
labels = labels.long()
|
1438 |
+
if label_index.size(0) > 0:
|
1439 |
+
labeled_logits = torch.gather(
|
1440 |
+
logits, 0, label_index.expand(label_index.size(0), logits.size(1))
|
1441 |
+
)
|
1442 |
+
labels = torch.gather(labels, 0, label_index.view(-1))
|
1443 |
+
loss_fct = CrossEntropyLoss()
|
1444 |
+
loss = loss_fct(labeled_logits.view(-1, self.num_labels).float(), labels.view(-1))
|
1445 |
+
else:
|
1446 |
+
loss = torch.tensor(0).to(logits)
|
1447 |
+
else:
|
1448 |
+
log_softmax = nn.LogSoftmax(-1)
|
1449 |
+
loss = -((log_softmax(logits) * labels).sum(-1)).mean()
|
1450 |
+
elif self.config.problem_type == "regression":
|
1451 |
+
loss_fct = MSELoss()
|
1452 |
+
if self.num_labels == 1:
|
1453 |
+
loss = loss_fct(logits.squeeze(), labels.squeeze())
|
1454 |
+
else:
|
1455 |
+
loss = loss_fct(logits, labels)
|
1456 |
+
elif self.config.problem_type == "single_label_classification":
|
1457 |
+
loss_fct = CrossEntropyLoss()
|
1458 |
+
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
1459 |
+
elif self.config.problem_type == "multi_label_classification":
|
1460 |
+
loss_fct = BCEWithLogitsLoss()
|
1461 |
+
loss = loss_fct(logits, labels)
|
1462 |
+
if not return_dict:
|
1463 |
+
output = (logits,) + outputs[1:]
|
1464 |
+
return ((loss,) + output) if loss is not None else output
|
1465 |
+
|
1466 |
+
return SequenceClassifierOutput(
|
1467 |
+
loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions
|
1468 |
+
)
|
1469 |
+
|
1470 |
+
|
1471 |
+
@add_start_docstrings(
|
1472 |
+
"""
|
1473 |
+
DeBERTa Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
|
1474 |
+
Named-Entity-Recognition (NER) tasks.
|
1475 |
+
""",
|
1476 |
+
DEBERTA_START_DOCSTRING,
|
1477 |
+
)
|
1478 |
+
# Copied from transformers.models.deberta.modeling_deberta.DebertaForTokenClassification with Deberta->DebertaV2
|
1479 |
+
class DebertaV2ForTokenClassification(DebertaV2PreTrainedModel):
|
1480 |
+
def __init__(self, config):
|
1481 |
+
super().__init__(config)
|
1482 |
+
self.num_labels = config.num_labels
|
1483 |
+
|
1484 |
+
self.deberta = DebertaV2Model(config)
|
1485 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
1486 |
+
self.classifier = nn.Linear(config.hidden_size, config.num_labels)
|
1487 |
+
|
1488 |
+
# Initialize weights and apply final processing
|
1489 |
+
self.post_init()
|
1490 |
+
|
1491 |
+
@add_start_docstrings_to_model_forward(DEBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
1492 |
+
@add_code_sample_docstrings(
|
1493 |
+
checkpoint=_CHECKPOINT_FOR_DOC,
|
1494 |
+
output_type=TokenClassifierOutput,
|
1495 |
+
config_class=_CONFIG_FOR_DOC,
|
1496 |
+
)
|
1497 |
+
def forward(
|
1498 |
+
self,
|
1499 |
+
input_ids: Optional[torch.Tensor] = None,
|
1500 |
+
attention_mask: Optional[torch.Tensor] = None,
|
1501 |
+
token_type_ids: Optional[torch.Tensor] = None,
|
1502 |
+
position_ids: Optional[torch.Tensor] = None,
|
1503 |
+
inputs_embeds: Optional[torch.Tensor] = None,
|
1504 |
+
labels: Optional[torch.Tensor] = None,
|
1505 |
+
output_attentions: Optional[bool] = None,
|
1506 |
+
output_hidden_states: Optional[bool] = None,
|
1507 |
+
return_dict: Optional[bool] = None,
|
1508 |
+
) -> Union[Tuple, TokenClassifierOutput]:
|
1509 |
+
r"""
|
1510 |
+
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
1511 |
+
Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
|
1512 |
+
"""
|
1513 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
1514 |
+
|
1515 |
+
outputs = self.deberta(
|
1516 |
+
input_ids,
|
1517 |
+
attention_mask=attention_mask,
|
1518 |
+
token_type_ids=token_type_ids,
|
1519 |
+
position_ids=position_ids,
|
1520 |
+
inputs_embeds=inputs_embeds,
|
1521 |
+
output_attentions=output_attentions,
|
1522 |
+
output_hidden_states=output_hidden_states,
|
1523 |
+
return_dict=return_dict,
|
1524 |
+
)
|
1525 |
+
|
1526 |
+
sequence_output = outputs[0]
|
1527 |
+
|
1528 |
+
sequence_output = self.dropout(sequence_output)
|
1529 |
+
logits = self.classifier(sequence_output)
|
1530 |
+
|
1531 |
+
loss = None
|
1532 |
+
if labels is not None:
|
1533 |
+
loss_fct = CrossEntropyLoss()
|
1534 |
+
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
1535 |
+
|
1536 |
+
if not return_dict:
|
1537 |
+
output = (logits,) + outputs[1:]
|
1538 |
+
return ((loss,) + output) if loss is not None else output
|
1539 |
+
|
1540 |
+
return TokenClassifierOutput(
|
1541 |
+
loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions
|
1542 |
+
)
|
1543 |
+
|
1544 |
+
|
1545 |
+
@add_start_docstrings(
|
1546 |
+
"""
|
1547 |
+
DeBERTa Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear
|
1548 |
+
layers on top of the hidden-states output to compute `span start logits` and `span end logits`).
|
1549 |
+
""",
|
1550 |
+
DEBERTA_START_DOCSTRING,
|
1551 |
+
)
|
1552 |
+
class DebertaV2ForQuestionAnswering(DebertaV2PreTrainedModel):
|
1553 |
+
def __init__(self, config):
|
1554 |
+
super().__init__(config)
|
1555 |
+
self.num_labels = config.num_labels
|
1556 |
+
|
1557 |
+
self.deberta = DebertaV2Model(config)
|
1558 |
+
self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
|
1559 |
+
|
1560 |
+
# Initialize weights and apply final processing
|
1561 |
+
self.post_init()
|
1562 |
+
|
1563 |
+
@add_start_docstrings_to_model_forward(DEBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
1564 |
+
@add_code_sample_docstrings(
|
1565 |
+
checkpoint=_CHECKPOINT_FOR_DOC,
|
1566 |
+
output_type=QuestionAnsweringModelOutput,
|
1567 |
+
config_class=_CONFIG_FOR_DOC,
|
1568 |
+
qa_target_start_index=_QA_TARGET_START_INDEX,
|
1569 |
+
qa_target_end_index=_QA_TARGET_END_INDEX,
|
1570 |
+
)
|
1571 |
+
# Copied from transformers.models.deberta.modeling_deberta.DebertaForQuestionAnswering.forward with Deberta->DebertaV2
|
1572 |
+
def forward(
|
1573 |
+
self,
|
1574 |
+
input_ids: Optional[torch.Tensor] = None,
|
1575 |
+
attention_mask: Optional[torch.Tensor] = None,
|
1576 |
+
token_type_ids: Optional[torch.Tensor] = None,
|
1577 |
+
position_ids: Optional[torch.Tensor] = None,
|
1578 |
+
inputs_embeds: Optional[torch.Tensor] = None,
|
1579 |
+
start_positions: Optional[torch.Tensor] = None,
|
1580 |
+
end_positions: Optional[torch.Tensor] = None,
|
1581 |
+
output_attentions: Optional[bool] = None,
|
1582 |
+
output_hidden_states: Optional[bool] = None,
|
1583 |
+
return_dict: Optional[bool] = None,
|
1584 |
+
) -> Union[Tuple, QuestionAnsweringModelOutput]:
|
1585 |
+
r"""
|
1586 |
+
start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
1587 |
+
Labels for position (index) of the start of the labelled span for computing the token classification loss.
|
1588 |
+
Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
|
1589 |
+
are not taken into account for computing the loss.
|
1590 |
+
end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
1591 |
+
Labels for position (index) of the end of the labelled span for computing the token classification loss.
|
1592 |
+
Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
|
1593 |
+
are not taken into account for computing the loss.
|
1594 |
+
"""
|
1595 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
1596 |
+
|
1597 |
+
outputs = self.deberta(
|
1598 |
+
input_ids,
|
1599 |
+
attention_mask=attention_mask,
|
1600 |
+
token_type_ids=token_type_ids,
|
1601 |
+
position_ids=position_ids,
|
1602 |
+
inputs_embeds=inputs_embeds,
|
1603 |
+
output_attentions=output_attentions,
|
1604 |
+
output_hidden_states=output_hidden_states,
|
1605 |
+
return_dict=return_dict,
|
1606 |
+
)
|
1607 |
+
|
1608 |
+
sequence_output = outputs[0]
|
1609 |
+
|
1610 |
+
logits = self.qa_outputs(sequence_output)
|
1611 |
+
start_logits, end_logits = logits.split(1, dim=-1)
|
1612 |
+
start_logits = start_logits.squeeze(-1).contiguous()
|
1613 |
+
end_logits = end_logits.squeeze(-1).contiguous()
|
1614 |
+
|
1615 |
+
total_loss = None
|
1616 |
+
if start_positions is not None and end_positions is not None:
|
1617 |
+
# If we are on multi-GPU, split add a dimension
|
1618 |
+
if len(start_positions.size()) > 1:
|
1619 |
+
start_positions = start_positions.squeeze(-1)
|
1620 |
+
if len(end_positions.size()) > 1:
|
1621 |
+
end_positions = end_positions.squeeze(-1)
|
1622 |
+
# sometimes the start/end positions are outside our model inputs, we ignore these terms
|
1623 |
+
ignored_index = start_logits.size(1)
|
1624 |
+
start_positions = start_positions.clamp(0, ignored_index)
|
1625 |
+
end_positions = end_positions.clamp(0, ignored_index)
|
1626 |
+
|
1627 |
+
loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
|
1628 |
+
start_loss = loss_fct(start_logits, start_positions)
|
1629 |
+
end_loss = loss_fct(end_logits, end_positions)
|
1630 |
+
total_loss = (start_loss + end_loss) / 2
|
1631 |
+
|
1632 |
+
if not return_dict:
|
1633 |
+
output = (start_logits, end_logits) + outputs[1:]
|
1634 |
+
return ((total_loss,) + output) if total_loss is not None else output
|
1635 |
+
|
1636 |
+
return QuestionAnsweringModelOutput(
|
1637 |
+
loss=total_loss,
|
1638 |
+
start_logits=start_logits,
|
1639 |
+
end_logits=end_logits,
|
1640 |
+
hidden_states=outputs.hidden_states,
|
1641 |
+
attentions=outputs.attentions,
|
1642 |
+
)
|
1643 |
+
|
1644 |
+
|
1645 |
+
@add_start_docstrings(
|
1646 |
+
"""
|
1647 |
+
DeBERTa Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a
|
1648 |
+
softmax) e.g. for RocStories/SWAG tasks.
|
1649 |
+
""",
|
1650 |
+
DEBERTA_START_DOCSTRING,
|
1651 |
+
)
|
1652 |
+
class DebertaV2ForMultipleChoice(DebertaV2PreTrainedModel):
|
1653 |
+
def __init__(self, config):
|
1654 |
+
super().__init__(config)
|
1655 |
+
|
1656 |
+
num_labels = getattr(config, "num_labels", 2)
|
1657 |
+
self.num_labels = num_labels
|
1658 |
+
|
1659 |
+
self.deberta = DebertaV2Model(config)
|
1660 |
+
self.pooler = ContextPooler(config)
|
1661 |
+
output_dim = self.pooler.output_dim
|
1662 |
+
|
1663 |
+
self.classifier = nn.Linear(output_dim, 1)
|
1664 |
+
drop_out = getattr(config, "cls_dropout", None)
|
1665 |
+
drop_out = self.config.hidden_dropout_prob if drop_out is None else drop_out
|
1666 |
+
self.dropout = StableDropout(drop_out)
|
1667 |
+
|
1668 |
+
self.init_weights()
|
1669 |
+
|
1670 |
+
def get_input_embeddings(self):
|
1671 |
+
return self.deberta.get_input_embeddings()
|
1672 |
+
|
1673 |
+
def set_input_embeddings(self, new_embeddings):
|
1674 |
+
self.deberta.set_input_embeddings(new_embeddings)
|
1675 |
+
|
1676 |
+
@add_start_docstrings_to_model_forward(DEBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
1677 |
+
@add_code_sample_docstrings(
|
1678 |
+
checkpoint=_CHECKPOINT_FOR_DOC,
|
1679 |
+
output_type=MultipleChoiceModelOutput,
|
1680 |
+
config_class=_CONFIG_FOR_DOC,
|
1681 |
+
)
|
1682 |
+
def forward(
|
1683 |
+
self,
|
1684 |
+
input_ids: Optional[torch.Tensor] = None,
|
1685 |
+
attention_mask: Optional[torch.Tensor] = None,
|
1686 |
+
token_type_ids: Optional[torch.Tensor] = None,
|
1687 |
+
position_ids: Optional[torch.Tensor] = None,
|
1688 |
+
inputs_embeds: Optional[torch.Tensor] = None,
|
1689 |
+
labels: Optional[torch.Tensor] = None,
|
1690 |
+
output_attentions: Optional[bool] = None,
|
1691 |
+
output_hidden_states: Optional[bool] = None,
|
1692 |
+
return_dict: Optional[bool] = None,
|
1693 |
+
) -> Union[Tuple, MultipleChoiceModelOutput]:
|
1694 |
+
r"""
|
1695 |
+
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
1696 |
+
Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
|
1697 |
+
num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. (See
|
1698 |
+
`input_ids` above)
|
1699 |
+
"""
|
1700 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
1701 |
+
num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
|
1702 |
+
|
1703 |
+
flat_input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
|
1704 |
+
flat_position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
|
1705 |
+
flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
|
1706 |
+
flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
|
1707 |
+
flat_inputs_embeds = (
|
1708 |
+
inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
|
1709 |
+
if inputs_embeds is not None
|
1710 |
+
else None
|
1711 |
+
)
|
1712 |
+
|
1713 |
+
outputs = self.deberta(
|
1714 |
+
flat_input_ids,
|
1715 |
+
position_ids=flat_position_ids,
|
1716 |
+
token_type_ids=flat_token_type_ids,
|
1717 |
+
attention_mask=flat_attention_mask,
|
1718 |
+
inputs_embeds=flat_inputs_embeds,
|
1719 |
+
output_attentions=output_attentions,
|
1720 |
+
output_hidden_states=output_hidden_states,
|
1721 |
+
return_dict=return_dict,
|
1722 |
+
)
|
1723 |
+
|
1724 |
+
encoder_layer = outputs[0]
|
1725 |
+
pooled_output = self.pooler(encoder_layer)
|
1726 |
+
pooled_output = self.dropout(pooled_output)
|
1727 |
+
logits = self.classifier(pooled_output)
|
1728 |
+
reshaped_logits = logits.view(-1, num_choices)
|
1729 |
+
|
1730 |
+
loss = None
|
1731 |
+
if labels is not None:
|
1732 |
+
loss_fct = CrossEntropyLoss()
|
1733 |
+
loss = loss_fct(reshaped_logits, labels)
|
1734 |
+
|
1735 |
+
if not return_dict:
|
1736 |
+
output = (reshaped_logits,) + outputs[1:]
|
1737 |
+
return ((loss,) + output) if loss is not None else output
|
1738 |
+
|
1739 |
+
return MultipleChoiceModelOutput(
|
1740 |
+
loss=loss,
|
1741 |
+
logits=reshaped_logits,
|
1742 |
+
hidden_states=outputs.hidden_states,
|
1743 |
+
attentions=outputs.attentions,
|
1744 |
+
)
|