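"""Integrated Gradients explainer for a RoBERTa SST-2 sentiment classifier.

Wraps Captum's LayerIntegratedGradients to attribute a prediction either to the
embedding layer (layer 0) or to one of the encoder layers, and renders the
token-level attributions with Captum's text visualization utilities.
"""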
import torch

from transformers import AutoModelForSequenceClassification, AutoTokenizer
from captum.attr import LayerIntegratedGradients, visualization

from util import visualize_text

# Label order of the textattack/roberta-base-SST-2 classification head (0 = negative, 1 = positive).
classifications = ["NEGATIVE", "POSITIVE"]

class IntegratedGradientsExplainer:
    def __init__(self):
        self.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
        self.model = AutoModelForSequenceClassification.from_pretrained("textattack/roberta-base-SST-2").to(self.device)
        self.tokenizer = AutoTokenizer.from_pretrained("textattack/roberta-base-SST-2")
        self.ref_token_id = self.tokenizer.unk_token_id

    def tokens_from_ids(self, ids):
        # Convert ids back to tokens and strip the "Ġ" marker that the RoBERTa
        # byte-level BPE tokenizer uses to denote a leading space.
        return [t[1:] if t.startswith("Ġ") else t for t in self.tokenizer.convert_ids_to_tokens(ids)]

    def custom_forward(self, inputs, attention_mask=None):
        # Forward function handed to Captum: return only the classification logits.
        return self.model(inputs, attention_mask=attention_mask, return_dict=True).logits

    @staticmethod
    def summarize_attributions(attributions):
        # Collapse the embedding dimension and L2-normalize to get one score per token.
        attributions = attributions.sum(dim=-1).squeeze(0)
        return attributions / torch.norm(attributions)

    def run_attribution_model(self, input_ids, attention_mask, index=None, layer=None, steps=20):
        # Run the model once; if no target class is given, explain the predicted one.
        output = self.model(input_ids=input_ids, attention_mask=attention_mask)[0]
        if index is None:
            index = output.argmax(dim=-1).item()

        # Layer Integrated Gradients on the chosen layer, using the unknown token as the baseline.
        ablator = LayerIntegratedGradients(self.custom_forward, layer)
        attributions = ablator.attribute(
            inputs=input_ids,
            baselines=self.ref_token_id,
            additional_forward_args=(attention_mask,),
            target=index,
            n_steps=steps,
        )
        return self.summarize_attributions(attributions).unsqueeze_(0), output, index

    def build_visualization(self, input_ids, attention_mask, **kwargs):
        vis_data_records = []
        attributions, output, index = self.run_attribution_model(input_ids, attention_mask, **kwargs)
        for record in range(input_ids.size(0)):
            classification = output[record].argmax(dim=-1).item()
            class_name = classifications[classification]
            attr = attributions[record]
            # Drop the leading <s> token and the trailing </s> token plus any padding.
            pad_count = (attention_mask[record] == 0).sum().item()
            tokens = self.tokens_from_ids(input_ids[record].flatten())[1 : -(pad_count + 1)]
            # Report a probability rather than a raw logit for the predicted class.
            pred_prob = torch.softmax(output[record], dim=-1)[classification]
            vis_data_records.append(
                visualization.VisualizationDataRecord(
                    attr,
                    pred_prob,
                    class_name,
                    class_name,
                    classifications[index],
                    1,
                    tokens,
                    1,
                )
            )
        return visualize_text(vis_data_records)

    def __call__(self, input_text, layer):
        text_batch = [input_text]
        encoding = self.tokenizer(text_batch, return_tensors="pt")
        input_ids = encoding["input_ids"].to(self.device)
        attention_mask = encoding["attention_mask"].to(self.device)
        # Layer 0 attributes to the embeddings; layer i (1-based) attributes to encoder layer i-1.
        layer = int(layer)
        if layer == 0:
            layer = self.model.roberta.embeddings
        else:
            layer = self.model.roberta.encoder.layer[layer - 1]

        return self.build_visualization(input_ids, attention_mask, layer=layer)
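

if __name__ == "__main__":
    # Minimal usage sketch: attribute a sample sentence to encoder layer 5.
    # The return value depends on util.visualize_text (typically HTML markup
    # built from Captum's visualization records), so it is simply printed here.
    explainer = IntegratedGradientsExplainer()
    result = explainer("A gripping, beautifully shot film.", layer=5)
    print(result)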