File size: 900 Bytes
4a1df2e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
from anonymous_demo import TADCheckpointManager

from textattack.model_args import DEMO_MODELS
from textattack.reactive_defense.reactive_defender import ReactiveDefender


class TADReactiveDefender(ReactiveDefender):
    """Reactive defender backed by a TAD text classifier.

    On construction a classifier is obtained from
    ``TADCheckpointManager.get_tad_text_classifier`` using the checkpoint
    registered in ``DEMO_MODELS`` under ``ckpt``.  :meth:`reactive_defense`
    forwards text to that classifier's ``infer`` method, selecting the
    ``"pwws"`` defense unless the caller asks for a different one.
    """

    def __init__(self, ckpt="tad-sst2", **kwargs):
        """Load the TAD classifier identified by ``ckpt``.

        Args:
            ckpt: Key looked up in ``DEMO_MODELS`` to find the checkpoint
                handed to the checkpoint manager.
            **kwargs: Forwarded unchanged to ``ReactiveDefender.__init__``.

        NOTE(review): an unknown ``ckpt`` presumably raises ``KeyError``
        from the ``DEMO_MODELS`` lookup -- confirm that ``DEMO_MODELS`` is
        a plain mapping.
        """
        super().__init__(**kwargs)
        self.tad_classifier = TADCheckpointManager.get_tad_text_classifier(
            checkpoint=DEMO_MODELS[ckpt], auto_device=True
        )

    def reactive_defense(self, text, **kwargs):
        """Run the TAD classifier on ``text`` with a reactive defense.

        Args:
            text: Input passed as the first positional argument to the
                classifier's ``infer`` method.
            **kwargs: Extra keyword arguments forwarded to ``infer``.
                ``defense`` defaults to ``"pwws"`` and ``print_result``
                defaults to ``False``; both may be overridden here.

        Returns:
            Whatever the classifier's ``infer`` call returns, unmodified.
        """
        # Fill in defaults instead of hard-coding the keywords next to
        # ``**kwargs``: a caller-supplied ``defense`` or ``print_result``
        # would otherwise fail with "got multiple values for keyword
        # argument".  ``kwargs`` is a fresh dict per call, so updating it
        # in place does not leak into the caller.
        kwargs.setdefault("defense", "pwws")
        kwargs.setdefault("print_result", False)
        return self.tad_classifier.infer(text, **kwargs)