from typing import Any, List

import torch
import torch.nn as nn
from BERT_explainability.modules.BERT.BERT import BertModel
from BERT_explainability.modules.layers_ours import *
from BERT_rationale_benchmark.models.model_utils import PaddedSequence
from torch.nn import CrossEntropyLoss, MSELoss
from transformers import BertPreTrainedModel
from transformers.modeling_outputs import SequenceClassifierOutput
from transformers.utils import logging


class BertForSequenceClassification(BertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        self.bert = BertModel(config)
        self.dropout = Dropout(config.hidden_dropout_prob)
        self.classifier = Linear(config.hidden_size, config.num_labels)

        self.init_weights()

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
            Labels for computing the sequence classification/regression loss.
            Indices should be in :obj:`[0, ..., config.num_labels - 1]`.
            If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss),
            if :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        pooled_output = outputs[1]

        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            if self.num_labels == 1:
                # We are doing regression
                loss_fct = MSELoss()
                loss = loss_fct(logits.view(-1), labels.view(-1))
            else:
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def relprop(self, cam=None, **kwargs):
        # Propagate relevance back through the classification head, then
        # through the BERT encoder.
        cam = self.classifier.relprop(cam, **kwargs)
        cam = self.dropout.relprop(cam, **kwargs)
        cam = self.bert.relprop(cam, **kwargs)
        # print("conservation: ", cam.sum())
        return cam


# this is the actual classifier we will be using
class BertClassifier(nn.Module):
    """Thin wrapper around BertForSequenceClassification"""

    def __init__(
        self,
        bert_dir: str,
        pad_token_id: int,
        cls_token_id: int,
        sep_token_id: int,
        num_labels: int,
        max_length: int = 512,
        use_half_precision=True,
    ):
        super(BertClassifier, self).__init__()
        bert = BertForSequenceClassification.from_pretrained(
            bert_dir, num_labels=num_labels
        )
        if use_half_precision:
            import apex

            bert = bert.half()
        self.bert = bert
        self.pad_token_id = pad_token_id
        self.cls_token_id = cls_token_id
        self.sep_token_id = sep_token_id
        self.max_length = max_length

    def forward(
        self,
        query: List[torch.Tensor],
        docids: List[Any],
        document_batch: List[torch.Tensor],
    ):
        assert len(query) == len(document_batch)
        print(query)
        # note about device management:
        # since distributed training is enabled, the inputs to this module can be
        # on *any* device (preferably cpu, since we wrap and unwrap the module).
        # we want to keep these params on the input device (assuming CPU) for as
        # long as possible for cheap memory access.
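        # Each example below is assembled as [CLS] query [SEP] document, with the
        # document truncated so the whole sequence fits within max_length, and
        # with position ids that restart at 0 for the [SEP] + document segment.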
        target_device = next(self.parameters()).device
        cls_token = torch.tensor([self.cls_token_id]).to(
            device=document_batch[0].device
        )
        sep_token = torch.tensor([self.sep_token_id]).to(
            device=document_batch[0].device
        )
        input_tensors = []
        position_ids = []
        for q, d in zip(query, document_batch):
            if len(q) + len(d) + 2 > self.max_length:
                d = d[: (self.max_length - len(q) - 2)]
            input_tensors.append(torch.cat([cls_token, q, sep_token, d]))
            position_ids.append(
                torch.tensor(list(range(0, len(q) + 1)) + list(range(0, len(d) + 1)))
            )
        bert_input = PaddedSequence.autopad(
            input_tensors,
            batch_first=True,
            padding_value=self.pad_token_id,
            device=target_device,
        )
        positions = PaddedSequence.autopad(
            position_ids, batch_first=True, padding_value=0, device=target_device
        )
        (classes,) = self.bert(
            bert_input.data,
            attention_mask=bert_input.mask(
                on=0.0, off=float("-inf"), device=target_device
            ),
            position_ids=positions.data,
        )
        assert torch.all(classes == classes)  # for nans
        print(input_tensors[0])
        print(self.relprop()[0])
        return classes

    def relprop(self, cam=None, **kwargs):
        return self.bert.relprop(cam, **kwargs)


if __name__ == "__main__":
    import os

    from transformers import BertTokenizer

    class Config:
        def __init__(
            self,
            hidden_size,
            num_attention_heads,
            attention_probs_dropout_prob,
            num_labels,
            hidden_dropout_prob,
        ):
            self.hidden_size = hidden_size
            self.num_attention_heads = num_attention_heads
            self.attention_probs_dropout_prob = attention_probs_dropout_prob
            self.num_labels = num_labels
            self.hidden_dropout_prob = hidden_dropout_prob

    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    x = tokenizer.encode_plus(
        "In this movie the acting is great. The movie is perfect! [sep]",
        add_special_tokens=True,
        max_length=512,
        return_token_type_ids=False,
        return_attention_mask=True,
        pad_to_max_length=True,
        return_tensors="pt",
        truncation=True,
    )

    print(x["input_ids"])

    model = BertForSequenceClassification.from_pretrained(
        "bert-base-uncased", num_labels=2
    )
    model_save_file = os.path.join(
        "./BERT_explainability/output_bert/movies/classifier/", "classifier.pt"
    )
    model.load_state_dict(torch.load(model_save_file))

    # x = torch.randint(100, (2, 20))
    # x = torch.tensor([[101, 2054, 2003, 1996, 15792, 1997, 2023, 3319, 1029, 102,
    #                    101, 4079, 102, 101, 6732, 102, 101, 2643, 102, 101,
    #                    2038, 102, 101, 1037, 102, 101, 2933, 102, 101, 2005,
    #                    102, 101, 2032, 102, 101, 1010, 102, 101, 1037, 102,
    #                    101, 3800, 102, 101, 2005, 102, 101, 2010, 102, 101,
    #                    2166, 102, 101, 1010, 102, 101, 1998, 102, 101, 2010,
    #                    102, 101, 4650, 102, 101, 1010, 102, 101, 2002, 102,
    #                    101, 2074, 102, 101, 2515, 102, 101, 1050, 102, 101,
    #                    1005, 102, 101, 1056, 102, 101, 2113, 102, 101, 2054,
    #                    102, 101, 1012, 102]])
    # x.requires_grad_()

    model.eval()

    y = model(x["input_ids"], x["attention_mask"])
    print(y)

    cam, _ = model.relprop()
    # print(cam.shape)

    cam = cam.sum(-1)
    # print(cam)
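
    # Minimal sketch: map the summed relevance back to the input tokens for
    # inspection. This assumes `cam` has shape [batch_size, seq_len] after the
    # sum above; the normalization is only for readability.
    tokens = tokenizer.convert_ids_to_tokens(x["input_ids"][0].tolist())
    scores = cam[0].detach().float()
    scores = scores / (scores.abs().max() + 1e-9)
    for token, score in zip(tokens, scores.tolist()):
        if token == "[PAD]":
            break
        print(f"{token:>12s}  {score:+.3f}")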