|
from anonymous_demo import TADCheckpointManager |
|
|
|
from textattack.model_args import DEMO_MODELS |
|
from textattack.reactive_defense.reactive_defender import ReactiveDefender |
|
|
|
|
|
class TADReactiveDefender(ReactiveDefender):
    """Reactive defender that delegates to a TAD text classifier.

    The classifier is obtained once, at construction time, from
    ``TADCheckpointManager.get_tad_text_classifier``. Each call to
    :meth:`reactive_defense` forwards the text to that classifier's
    ``infer`` method and returns its result unchanged.
    """

    def __init__(self, ckpt="tad-sst2", **kwargs):
        """Load the TAD classifier for the requested checkpoint.

        Args:
            ckpt: Key looked up in ``DEMO_MODELS``; the value found there is
                passed as ``checkpoint`` to the checkpoint manager.
            **kwargs: Forwarded unchanged to ``ReactiveDefender.__init__``.
        """
        super().__init__(**kwargs)
        # NOTE(review): ``auto_device=True`` presumably lets the checkpoint
        # manager pick the device itself -- confirm against its docs.
        self.tad_classifier = TADCheckpointManager.get_tad_text_classifier(
            checkpoint=DEMO_MODELS[ckpt], auto_device=True
        )

    def reactive_defense(self, text, **kwargs):
        """Run the TAD classifier on ``text`` with a defense enabled.

        Args:
            text: Input handed to ``tad_classifier.infer`` as is.
            **kwargs: Extra keyword arguments for ``infer``. ``defense``
                defaults to ``"pwws"`` and ``print_result`` to ``False``;
                either may be overridden by passing it explicitly.

        Returns:
            Whatever ``tad_classifier.infer`` returns, unmodified.
        """
        # Apply the defaults through ``setdefault`` so that a caller-supplied
        # ``defense`` / ``print_result`` wins instead of raising
        # "TypeError: got multiple values for keyword argument".
        kwargs.setdefault("defense", "pwws")
        kwargs.setdefault("print_result", False)
        return self.tad_classifier.infer(text, **kwargs)
|
|