import torch
from captum.attr import LayerIntegratedGradients, visualization
from transformers import AutoModelForSequenceClassification, AutoTokenizer

from util import visualize_text

classifications = ["NEGATIVE", "POSITIVE"]


class IntegratedGradientsExplainer:
    def __init__(self):
        self.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
        self.model = AutoModelForSequenceClassification.from_pretrained(
            "textattack/roberta-base-SST-2"
        ).to(self.device)
        self.model.eval()  # attribution only; disable dropout
        self.tokenizer = AutoTokenizer.from_pretrained("textattack/roberta-base-SST-2")
        # The <unk> token id is used as the reference (baseline) input for integrated gradients.
        self.ref_token_id = self.tokenizer.unk_token_id

    def tokens_from_ids(self, ids):
        # Strip the BPE word-boundary marker "Ġ" that the RoBERTa tokenizer prepends to tokens.
        return [s[1:] if s.startswith("Ġ") else s for s in self.tokenizer.convert_ids_to_tokens(ids)]

    def custom_forward(self, inputs, attention_mask=None):
        # Captum calls this wrapper; it must return the tensor that attributions are taken against.
        result = self.model(inputs, attention_mask=attention_mask, return_dict=True)
        return result.logits

    @staticmethod
    def summarize_attributions(attributions):
        # Collapse the embedding dimension and normalize to unit L2 norm across tokens.
        attributions = attributions.sum(dim=-1).squeeze(0)
        attributions = attributions / torch.norm(attributions)
        return attributions

    def run_attribution_model(self, input_ids, attention_mask, index=None, layer=None, steps=20):
        # Forward pass to get the logits; default to the predicted class if no target is given.
        output = self.model(input_ids=input_ids, attention_mask=attention_mask)[0]
        if index is None:
            index = output.argmax(dim=-1).item()
        ablator = LayerIntegratedGradients(self.custom_forward, layer)
        attributions = ablator.attribute(
            inputs=input_ids,
            baselines=self.ref_token_id,
            additional_forward_args=(attention_mask,),
            target=index,
            n_steps=steps,
        )
        return self.summarize_attributions(attributions).unsqueeze_(0), output, index

    def build_visualization(self, input_ids, attention_mask, **kwargs):
        vis_data_records = []
        attributions, output, index = self.run_attribution_model(input_ids, attention_mask, **kwargs)
        for record in range(input_ids.size(0)):
            classification = output[record].argmax(dim=-1).item()
            class_name = classifications[classification]
            attr = attributions[record]
            # Drop the leading <s> token as well as the trailing </s> and any padding tokens.
            tokens = self.tokens_from_ids(input_ids[record].flatten())[
                1 : 0 - ((attention_mask[record] == 0).sum().item() + 1)
            ]
            vis_data_records.append(
                # Positional arguments: word attributions, prediction score, predicted class,
                # true class, attributed class, attribution score, tokens, convergence score.
                visualization.VisualizationDataRecord(
                    attr,
                    output[record][classification],
                    class_name,
                    class_name,
                    index,
                    1,
                    tokens,
                    1,
                )
            )
        return visualize_text(vis_data_records)

    def __call__(self, input_text, layer):
        text_batch = [input_text]
        encoding = self.tokenizer(text_batch, return_tensors="pt")
        input_ids = encoding["input_ids"].to(self.device)
        attention_mask = encoding["attention_mask"].to(self.device)
        # layer 0 attributes against the embeddings; layer N (1-based) against encoder block N - 1.
        layer = int(layer)
        if layer == 0:
            layer = self.model.roberta.embeddings
        else:
            layer = self.model.roberta.encoder.layer[layer - 1]
        return self.build_visualization(input_ids, attention_mask, layer=layer)
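

# Example usage: a minimal sketch, assuming this module is run directly and that
# `util.visualize_text` is importable. The sample sentence and layer choice are illustrative.
if __name__ == "__main__":
    explainer = IntegratedGradientsExplainer()
    # layer=0 attributes against the embedding layer; 1-12 select a specific RoBERTa encoder block.
    html_visualization = explainer("This movie was absolutely wonderful!", layer=0)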