|
""" |
|
|
|
Perplexity Metric: |
|
------------------------------------------------------- |
|
Class for calculating perplexity from AttackResults |
|
|
|
""" |
|
|
|
import torch |
|
|
|
from textattack.attack_results import FailedAttackResult, SkippedAttackResult |
|
from textattack.metrics import Metric |
|
import textattack.shared.utils |
|
|
|
|
|
class Perplexity(Metric):
    """Perplexity of original vs. perturbed texts for successful attacks.

    The texts of every result that is neither failed nor skipped are
    lower-cased, joined with spaces, and scored by a pre-trained language
    model using a strided sliding window.

    Args:
        model_name (str): ``"gpt2"`` loads ``GPT2LMHeadModel`` and
            ``GPT2Tokenizer``; any other value is passed to
            ``AutoModelForMaskedLM.from_pretrained`` and
            ``AutoTokenizer.from_pretrained``.
    """

    def __init__(self, model_name="gpt2"):
        self.all_metrics = {}
        self.original_candidates = []
        self.successful_candidates = []

        if model_name == "gpt2":
            from transformers import GPT2LMHeadModel, GPT2Tokenizer

            self.ppl_model = GPT2LMHeadModel.from_pretrained("gpt2")
            self.ppl_model.to(textattack.shared.utils.device)
            self.ppl_tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
            self.ppl_model.eval()
            self.max_length = self.ppl_model.config.n_positions
        else:
            from transformers import AutoModelForMaskedLM, AutoTokenizer

            self.ppl_model = AutoModelForMaskedLM.from_pretrained(model_name)
            self.ppl_tokenizer = AutoTokenizer.from_pretrained(model_name)
            self.ppl_model.to(textattack.shared.utils.device)
            self.ppl_model.eval()
            self.max_length = self.ppl_model.config.max_position_embeddings

        # The window advances by ``stride`` tokens and may never be wider than
        # ``max_length``. Clamping keeps ``begin_loc`` at 0 for the first
        # window when the model's context is shorter than 512 tokens; for
        # models with a context of 512 or more this is still 512.
        self.stride = min(512, self.max_length)

    def calculate(self, results):
        """Calculates average perplexity on all successful attacks using the
        language model loaded in ``__init__`` (small GPT-2 by default).

        Args:
            results (``AttackResult`` objects):
                Attack results for each instance in dataset

        Returns:
            dict: ``self.all_metrics`` with the keys
            ``"avg_original_perplexity"`` and ``"avg_attack_perplexity"``,
            each rounded to two decimals. Both values are ``nan`` when
            ``results`` contains no successful attack.

        Example::


            >> import textattack
            >> import transformers
            >> model = transformers.AutoModelForSequenceClassification.from_pretrained("distilbert-base-uncased-finetuned-sst-2-english")
            >> tokenizer = transformers.AutoTokenizer.from_pretrained("distilbert-base-uncased-finetuned-sst-2-english")
            >> model_wrapper = textattack.models.wrappers.HuggingFaceModelWrapper(model, tokenizer)
            >> attack = textattack.attack_recipes.DeepWordBugGao2018.build(model_wrapper)
            >> dataset = textattack.datasets.HuggingFaceDataset("glue", "sst2", split="train")
            >> attack_args = textattack.AttackArgs(
                num_examples=1,
                log_to_csv="log.csv",
                checkpoint_interval=5,
                checkpoint_dir="checkpoints",
                disable_stdout=True
            )
            >> attacker = textattack.Attacker(attack, dataset, attack_args)
            >> results = attacker.attack_dataset()
            >> ppl = textattack.metrics.quality_metrics.Perplexity().calculate(results)
        """
        self.results = results
        self.original_candidates_ppl = []
        self.successful_candidates_ppl = []
        # Start from empty candidate lists on every call; otherwise a second
        # call on the same instance would also score the texts collected by
        # the previous call.
        self.original_candidates = []
        self.successful_candidates = []

        for result in self.results:
            # Only successful attacks contribute to the metric.
            if isinstance(result, (FailedAttackResult, SkippedAttackResult)):
                continue
            self.original_candidates.append(
                result.original_result.attacked_text.text.lower()
            )
            self.successful_candidates.append(
                result.perturbed_result.attacked_text.text.lower()
            )

        ppl_orig = self.calc_ppl(self.original_candidates)
        ppl_attack = self.calc_ppl(self.successful_candidates)

        self.all_metrics["avg_original_perplexity"] = round(ppl_orig, 2)
        self.all_metrics["avg_attack_perplexity"] = round(ppl_attack, 2)

        return self.all_metrics

    def calc_ppl(self, texts):
        """Return the perplexity of ``texts`` joined into a single string.

        The joined text is tokenized once and scored in windows that advance
        by ``self.stride`` tokens and span at most ``self.max_length`` tokens.
        Each window's loss is weighted by the number of tokens newly covered
        by that window; the weighted sum is divided by the total token count
        and exponentiated.

        Args:
            texts (list[str]): Texts to score.

        Returns:
            float: The perplexity, or ``nan`` when tokenization yields no
            tokens (for example when ``texts`` is empty).
        """
        with torch.no_grad():
            text = " ".join(texts)
            token_ids = self.ppl_tokenizer.encode(text, add_special_tokens=True)
            if len(token_ids) == 0:
                # Nothing to score: perplexity is undefined. Returning NaN
                # avoids stacking an empty list of losses below.
                return float("nan")

            input_ids = torch.tensor(token_ids).unsqueeze(0)
            total_len = input_ids.size(1)
            eval_loss = []

            for i in range(0, total_len, self.stride):
                begin_loc = max(i + self.stride - self.max_length, 0)
                end_loc = min(i + self.stride, total_len)
                # Number of tokens in this window not scored by earlier ones.
                trg_len = end_loc - i
                input_ids_t = input_ids[:, begin_loc:end_loc].to(
                    textattack.shared.utils.device
                )
                target_ids = input_ids_t.clone()
                # Label the leading context tokens with -100 so that only the
                # last ``trg_len`` positions are used as targets.
                target_ids[:, :-trg_len] = -100

                outputs = self.ppl_model(input_ids_t, labels=target_ids)
                log_likelihood = outputs[0] * trg_len

                eval_loss.append(log_likelihood)

            return torch.exp(torch.stack(eval_loss).sum() / total_len).item()
|
|