"""
Original work:

https://github.com/sangHa0411/CloneDetection/blob/main/models/codebert.py#L169

Copyright (c) 2022 Sangha Park (sangha110495), Young Jin Ahn (snoop2head)

All credits to the original authors.
"""
import torch.nn as nn
from transformers import (
    RobertaPreTrainedModel,
    RobertaModel,
)
from transformers.modeling_outputs import SequenceClassifierOutput


class CloneDetectionModel(RobertaPreTrainedModel):
    _keys_to_ignore_on_load_missing = [r"position_ids"]

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.config = config

        self.roberta = RobertaModel(config, add_pooling_layer=False)
        # Shared projection applied to each special-token hidden state.
        self.net = nn.Sequential(
            nn.Dropout(config.hidden_dropout_prob),
            nn.Linear(config.hidden_size, config.hidden_size),
            nn.ReLU(),
        )
        # A RoBERTa-style pair encoding "<s> code1 </s></s> code2 </s>" contains
        # four special tokens (one <s>, three </s>), so the classifier takes the
        # four pooled special-token vectors concatenated along the feature axis.
        self.classifier = nn.Linear(config.hidden_size * 4, config.num_labels)

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        outputs = self.roberta(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = outputs[0]
        batch_size, _, hidden_size = hidden_states.shape

        # Locate the special tokens. tokenizer_cls_token_id and
        # tokenizer_sep_token_id are custom attributes that must be set on the
        # config (e.g. from tokenizer.cls_token_id / tokenizer.sep_token_id)
        # before calling the model.
        cls_flag = input_ids == self.config.tokenizer_cls_token_id
        sep_flag = input_ids == self.config.tokenizer_sep_token_id

        # Gather only the special-token hidden states:
        # (batch_size, num_special_tokens, hidden_size).
        special_token_states = hidden_states[cls_flag | sep_flag].view(
            batch_size, -1, hidden_size
        )
        special_hidden_states = self.net(special_token_states)

        # Flatten into one feature vector per example:
        # (batch_size, num_special_tokens * hidden_size).
        pooled_output = special_hidden_states.view(batch_size, -1)
        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
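

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original file). It assumes the
# "microsoft/codebert-base" checkpoint (suggested by the original file name,
# codebert.py) and num_labels=2 (clone / non-clone), and shows how the custom
# config attributes tokenizer_cls_token_id / tokenizer_sep_token_id, which
# forward() reads, might be populated from the tokenizer.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from transformers import AutoConfig, AutoTokenizer

    checkpoint = "microsoft/codebert-base"  # assumed checkpoint
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    config = AutoConfig.from_pretrained(checkpoint, num_labels=2)
    # forward() expects these two custom attributes on the config.
    config.tokenizer_cls_token_id = tokenizer.cls_token_id
    config.tokenizer_sep_token_id = tokenizer.sep_token_id

    model = CloneDetectionModel.from_pretrained(checkpoint, config=config)

    # Encode a code pair; RoBERTa tokenizers produce "<s> A </s></s> B </s>",
    # i.e. the four special tokens the classifier head expects.
    inputs = tokenizer(
        "def add(a, b): return a + b",
        "def sum_two(x, y): return x + y",
        return_tensors="pt",
    )
    outputs = model(**inputs)
    print(outputs.logits.shape)  # torch.Size([1, 2])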