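"""Gradio demo (Hugging Face Space) for Reactive Perturbation Defocusing (RPD).

The app samples or accepts a natural example, crafts an adversarial example with a
chosen TextAttack recipe, and shows how the TAD classifier detects and repairs it.
"""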
import os
import random
import zipfile
from difflib import Differ

import gradio as gr
import nltk
import pandas as pd
from findfile import find_files

from anonymous_demo import TADCheckpointManager
from textattack import Attacker
from textattack.attack_recipes import (
    BAEGarg2019,
    PWWSRen2019,
    TextFoolerJin2019,
    PSOZang2020,
    IGAWang2019,
    GeneticAlgorithmAlzantot2018,
    DeepWordBugGao2018,
    CLARE2020,
)
from textattack.attack_results import SuccessfulAttackResult
from textattack.datasets import Dataset
from textattack.models.wrappers import HuggingFaceModelWrapper
# Unpack the bundled TAD checkpoints into the working directory.
z = zipfile.ZipFile("checkpoints.zip", "r")
z.extractall(os.getcwd())
class ModelWrapper(HuggingFaceModelWrapper):
    """Adapts the TAD text classifier to the TextAttack model-wrapper interface."""

    def __init__(self, model):
        self.model = model

    def __call__(self, text_inputs, **kwargs):
        outputs = []
        for text_input in text_inputs:
            raw_outputs = self.model.infer(text_input, print_result=False, **kwargs)
            outputs.append(raw_outputs["probs"])
        return outputs
class SentAttacker:
    """Builds a TextAttack attacker from a given attack recipe around a TAD classifier."""

    def __init__(self, model, recipe_class=BAEGarg2019):
        model_wrapper = ModelWrapper(model)
        recipe = recipe_class.build(model_wrapper)
        # WordNet-based transformations default to English; uncomment to set the language explicitly.
        # recipe.transformation.language = "en"

        # A placeholder dataset; examples are attacked on demand via simple_attack().
        _dataset = [("", 0)]
        _dataset = Dataset(_dataset)
        self.attacker = Attacker(recipe, _dataset)
def diff_texts(text1, text2):
    """Return character-level diff tuples for gr.HighlightedText: (char, '+'/'-'/None)."""
    d = Differ()
    return [
        (token[2:], token[0] if token[0] != " " else None)
        for token in d.compare(text1, text2)
    ]
def get_ensembled_tad_results(results):
    """Majority-vote the predicted label across multiple TAD inference results."""
    label_counts = {}
    for r in results:
        label_counts[r["label"]] = label_counts.get(r["label"], 0) + 1
    return max(label_counts, key=label_counts.get)
# Open Multilingual WordNet data, required by the WordNet-based attack transformations.
nltk.download("omw-1.4")
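# Pre-build one TAD classifier per dataset and one TextAttack attacker per
# (dataset, attacker) pair so the UI does not pay the setup cost on every request.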
sent_attackers = {}
tad_classifiers = {}

attack_recipes = {
    "bae": BAEGarg2019,
    "pwws": PWWSRen2019,
    "textfooler": TextFoolerJin2019,
    "pso": PSOZang2020,
    "iga": IGAWang2019,
    "ga": GeneticAlgorithmAlzantot2018,
    "deepwordbug": DeepWordBugGao2018,
    "clare": CLARE2020,
}

for attacker in ["pwws", "bae", "textfooler", "deepwordbug"]:
    for dataset in [
        "agnews10k",
        "amazon",
        "sst2",
        # "imdb"
    ]:
        if "tad-{}".format(dataset) not in tad_classifiers:
            tad_classifiers[
                "tad-{}".format(dataset)
            ] = TADCheckpointManager.get_tad_text_classifier(
                "tad-{}".format(dataset).upper()
            )

        sent_attackers["tad-{}{}".format(dataset, attacker)] = SentAttacker(
            tad_classifiers["tad-{}".format(dataset)], attack_recipes[attacker]
        )
        # The PWWS attacker doubles as the reactive (defense-side) attacker for each dataset.
        tad_classifiers["tad-{}".format(dataset)].sent_attacker = sent_attackers[
            "tad-{}pwws".format(dataset)
        ]
def _get_example(dataset):
    """Pick a random (text, label) pair from the local test split of the given dataset."""
    filter_key_words = [
        ".py",
        ".md",
        "readme",
        "log",
        "result",
        "zip",
        ".state_dict",
        ".model",
        ".png",
        "acc_",
        "f1_",
        ".origin",
        ".adv",
        ".csv",
    ]
    dataset_file = {"train": [], "test": [], "valid": []}
    search_path = "./"
    task = "text_defense"
    dataset_file["test"] += find_files(
        search_path,
        [dataset, "test", task],
        exclude_key=[".adv", ".org", ".defense", ".inference", "train."]
        + filter_key_words,
    )

    data = []
    for data_file in dataset_file["test"]:
        with open(data_file, mode="r", encoding="utf8") as fin:
            for line in fin:
                text, label = line.split("$LABEL$")
                data.append((text.strip(), int(label.strip())))
    return random.choice(data)


def get_sst2_example():
    return _get_example("sst2")


def get_agnews_example():
    return _get_example("agnews")


def get_amazon_example():
    return _get_example("amazon")


def get_imdb_example():
    return _get_example("imdb")
# Texts already shown to the user; avoids repeatedly sampling the same example.
cache = set()
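# End-to-end pipeline behind the "Generate" button:
#   1. take the user-provided example, or sample a fresh one from the test set,
#   2. craft an adversarial example with the selected TextAttack recipe,
#   3. run TAD inference with the PWWS-based reactive defense,
#   4. return the texts, character diffs, and result tables for the UI.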
def generate_adversarial_example(dataset, attacker, text=None, label=None):
    # Fall back to a random test example when no text is given or the text was already used.
    if not text or text in cache:
        if "agnews" in dataset.lower():
            text, label = get_agnews_example()
        elif "sst2" in dataset.lower():
            text, label = get_sst2_example()
        elif "amazon" in dataset.lower():
            text, label = get_amazon_example()
        elif "imdb" in dataset.lower():
            text, label = get_imdb_example()

    cache.add(text)

    result = None
    attack_result = sent_attackers[
        "tad-{}{}".format(dataset.lower(), attacker.lower())
    ].attacker.simple_attack(text, int(label))

    if isinstance(attack_result, SuccessfulAttackResult):
        if (
            attack_result.perturbed_result.output
            != attack_result.original_result.ground_truth_output
        ) and (
            attack_result.original_result.output
            == attack_result.original_result.ground_truth_output
        ):
            # The attack flipped a correctly classified example; run inference with defense.
            result = tad_classifiers["tad-{}".format(dataset.lower())].infer(
                attack_result.perturbed_result.attacked_text.text
                + "!ref!{},{},{}".format(
                    attack_result.original_result.ground_truth_output,
                    1,
                    attack_result.perturbed_result.output,
                ),
                print_result=True,
                defense="pwws",
            )

    if result:
        classification_df = {}
        classification_df["is_repaired"] = result["is_fixed"]
        classification_df["pred_label"] = result["label"]
        classification_df["confidence"] = round(result["confidence"], 3)
        classification_df["is_correct"] = result["ref_label_check"]

        advdetection_df = {}
        if result["is_adv_label"] != "0":
            advdetection_df["is_adversarial"] = {
                "0": False,
                "1": True,
                0: False,
                1: True,
            }[result["is_adv_label"]]
            advdetection_df["perturbed_label"] = result["perturbed_label"]
            advdetection_df["confidence"] = round(result["is_adv_confidence"], 3)
            # advdetection_df['ref_is_attack'] = result['ref_is_adv_label']
            # advdetection_df['is_correct'] = result['ref_is_adv_check']
    else:
        # The attack failed (or the original prediction was already wrong); try another example.
        return generate_adversarial_example(dataset, attacker)

    return (
        text,
        label,
        result["restored_text"],
        result["label"],
        attack_result.perturbed_result.attacked_text.text,
        diff_texts(text, text),
        diff_texts(text, attack_result.perturbed_result.attacked_text.text),
        diff_texts(text, result["restored_text"]),
        attack_result.perturbed_result.output,
        pd.DataFrame(classification_df, index=[0]),
        pd.DataFrame(advdetection_df, index=[0]),
    )
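# Gradio Blocks UI: clarifications, example inputs, outputs, and diff views.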
demo = gr.Blocks()

with demo:
    gr.Markdown(
        "# <p align='center'>Reactive Perturbation Defocusing for Textual Adversarial Defense</p>"
    )
    gr.Markdown("## <p align='center'>Clarifications</p>")
    gr.Markdown(
        "- This demo does not guarantee that the adversarial example will be correctly repaired by RPD."
        " The repair success rate corresponds to the performance reported in the paper (up to approximately 97%)."
    )
    gr.Markdown(
        "- In the character-edit views, the red (+) and green (-) highlights mark characters that were added"
        " or deleted in the adversarial example compared to the original natural example."
    )
    gr.Markdown(
        "- The adversarial example and the repaired adversarial example may read unnaturally; "
        "this is because attackers usually generate unnatural perturbations. "
        "RPD does not introduce additional unnatural perturbations."
    )
    gr.Markdown(
        "- To the best of our knowledge, Reactive Perturbation Defocusing (RPD) is a novel approach to "
        "adversarial defense. RPD significantly outperforms state-of-the-art methods "
        "(by more than 10% defense accuracy)."
    )
    gr.Markdown(
        "- DeepWordBug is an attacker unseen by RPD's adversarial detector, which demonstrates the robustness of RPD."
    )
gr.Markdown("## <p align='center'>Natural Example Input</p>") | |
with gr.Group(): | |
with gr.Row(): | |
input_dataset = gr.Radio( | |
choices=["SST2", "AGNews10K", "Amazon"], | |
value="SST2", | |
label="Select a testing dataset and an adversarial attacker to generate an adversarial example.", | |
) | |
input_attacker = gr.Radio( | |
choices=[ | |
"BAE", | |
"PWWS", | |
"TextFooler", | |
"DeepWordBug" | |
], | |
value="TextFooler", | |
label="Choose an Adversarial Attacker for generating an adversarial example to attack the model.", | |
) | |
with gr.Group(): | |
with gr.Row(): | |
input_sentence = gr.Textbox( | |
placeholder="Input a natural example...", | |
label="Alternatively, input a natural example and its original label to generate an adversarial example.", | |
) | |
input_label = gr.Textbox( | |
placeholder="Original label...", label="Original Label" | |
) | |
button_gen = gr.Button( | |
"Generate an adversarial example and repair using RPD (No GPU, Time:3-10 mins )", | |
variant="primary", | |
) | |
    gr.Markdown(
        "## <p align='center'>Generated Adversarial Example and Repaired Adversarial Example</p>"
    )
    with gr.Group():
        with gr.Column():
            with gr.Row():
                output_original_example = gr.Textbox(label="Original Example")
                output_original_label = gr.Textbox(label="Original Label")
            with gr.Row():
                output_adv_example = gr.Textbox(label="Adversarial Example")
                output_adv_label = gr.Textbox(label="Perturbed Label")
            with gr.Row():
                output_repaired_example = gr.Textbox(
                    label="Repaired Adversarial Example by RPD"
                )
                output_repaired_label = gr.Textbox(label="Repaired Label")

    gr.Markdown(
        "## <p align='center'>The Output of Reactive Perturbation Defocusing</p>"
    )
    with gr.Group():
        output_is_adv_df = gr.DataFrame(label="Adversarial Example Detection Result")
        gr.Markdown(
            "The is_adversarial field indicates whether an adversarial example was detected. "
            "The perturbed_label field is the predicted label of the adversarial example. "
            "The confidence field is the confidence of the adversarial example detection."
        )
        output_df = gr.DataFrame(label="Repaired Standard Classification Result")
        gr.Markdown(
            "If is_repaired is true, the example has been repaired by RPD. "
            "The pred_label field is the standard classification result. "
            "The confidence field is the confidence of the predicted label. "
            "The is_correct field indicates whether the predicted label is correct."
        )

    gr.Markdown("## <p align='center'>Example Comparisons</p>")
    ori_text_diff = gr.HighlightedText(
        label="The Original Natural Example",
        combine_adjacent=True,
    )
    adv_text_diff = gr.HighlightedText(
        label="Character Edits in the Adversarial Example Compared to the Natural Example",
        combine_adjacent=True,
    )
    restored_text_diff = gr.HighlightedText(
        label="Character Edits in the Repaired Adversarial Example Compared to the Natural Example",
        combine_adjacent=True,
    )
    # Bind the generation-and-repair pipeline to the button.
    button_gen.click(
        fn=generate_adversarial_example,
        inputs=[input_dataset, input_attacker, input_sentence, input_label],
        outputs=[
            output_original_example,
            output_original_label,
            output_repaired_example,
            output_repaired_label,
            output_adv_example,
            ori_text_diff,
            adv_text_diff,
            restored_text_diff,
            output_adv_label,
            output_df,
            output_is_adv_df,
        ],
    )

demo.launch()