Spaces:
Running
Running
"""
Attacker Class
==============
"""
import collections | |
import logging | |
import multiprocessing as mp | |
import os | |
import queue | |
import random | |
import traceback | |
import torch | |
import tqdm | |
import textattack | |
from textattack.attack_results import ( | |
FailedAttackResult, | |
MaximizedAttackResult, | |
SkippedAttackResult, | |
SuccessfulAttackResult, | |
) | |
from textattack.shared.utils import logger | |
from .attack import Attack | |
from .attack_args import AttackArgs | |
class Attacker:
    """Class for running attacks on a dataset with specified parameters.

    This class uses the :class:`~textattack.Attack` to actually run the
    attacks, while also providing useful features such as parallel
    processing, saving/resuming from a checkpoint, and logging to files and
    stdout.

    Args:
        attack (:class:`~textattack.Attack`):
            :class:`~textattack.Attack` used to actually carry out the attack.
        dataset (:class:`~textattack.datasets.Dataset`):
            Dataset to attack.
        attack_args (:class:`~textattack.AttackArgs`):
            Arguments for attacking the dataset. For default settings, look at the `AttackArgs` class.

    Example::

        >>> import textattack
        >>> import transformers

        >>> model = transformers.AutoModelForSequenceClassification.from_pretrained("textattack/bert-base-uncased-imdb")
        >>> tokenizer = transformers.AutoTokenizer.from_pretrained("textattack/bert-base-uncased-imdb")
        >>> model_wrapper = textattack.models.wrappers.HuggingFaceModelWrapper(model, tokenizer)

        >>> attack = textattack.attack_recipes.TextFoolerJin2019.build(model_wrapper)
        >>> dataset = textattack.datasets.HuggingFaceDataset("imdb", split="test")

        >>> # Attack 20 samples with CSV logging and a checkpoint saved every 5 results
        >>> attack_args = textattack.AttackArgs(
        ...     num_examples=20,
        ...     log_to_csv="log.csv",
        ...     checkpoint_interval=5,
        ...     checkpoint_dir="checkpoints",
        ...     disable_stdout=True
        ... )

        >>> attacker = textattack.Attacker(attack, dataset, attack_args)
        >>> attacker.attack_dataset()
    """

    def __init__(self, attack, dataset, attack_args=None):
        assert isinstance(
            attack, Attack
        ), f"`attack` argument must be of type `textattack.Attack`, but got type of `{type(attack)}`."
        assert isinstance(
            dataset, textattack.datasets.Dataset
        ), f"`dataset` must be of type `textattack.datasets.Dataset`, but got type `{type(dataset)}`."

        if attack_args:
            assert isinstance(
                attack_args, AttackArgs
            ), f"`attack_args` must be of type `textattack.AttackArgs`, but got type `{type(attack_args)}`."
        else:
            attack_args = AttackArgs()

        self.attack = attack
        self.dataset = dataset
        self.attack_args = attack_args
        self.attack_log_manager = None

        # This is to be set if loading from a checkpoint
        self._checkpoint = None

    # ------------------------------------------------------------------
    # Private helpers shared by the sequential and parallel code paths
    # ------------------------------------------------------------------

    def _get_worklist(self, start, end, num_examples, shuffle):
        """Split the index range ``[start, end)`` into a worklist and spares.

        Args:
            start (int): First dataset index to consider.
            end (int): One past the last dataset index to consider.
            num_examples (int): How many indices go into the worklist.
            shuffle (bool): Shuffle the indices before splitting.

        Returns:
            tuple: ``(worklist, candidates)`` as two ``collections.deque``
            objects. ``worklist`` holds the first ``num_examples`` indices and
            ``candidates`` the remaining ones.
        """
        if end - start < num_examples:
            logger.warning(
                f"Attempting to attack {num_examples} samples when only {end-start} are available."
            )
        candidates = list(range(start, end))
        if shuffle:
            random.shuffle(candidates)
        worklist = collections.deque(candidates[:num_examples])
        candidates = collections.deque(candidates[num_examples:])
        assert (len(worklist) + len(candidates)) == (end - start)
        return worklist, candidates

    def _initial_worklist(self):
        """Return ``(num_remaining_attacks, worklist, worklist_candidates)``.

        The values come from the loaded checkpoint when there is one;
        otherwise they are built from the attack args.
        """
        if self._checkpoint:
            num_remaining_attacks = self._checkpoint.num_remaining_attacks
            worklist = self._checkpoint.worklist
            worklist_candidates = self._checkpoint.worklist_candidates
            logger.info(
                f"Recovered from checkpoint previously saved at {self._checkpoint.datetime}."
            )
            return num_remaining_attacks, worklist, worklist_candidates

        # `num_successful_examples` takes precedence over `num_examples`.
        num_remaining_attacks = (
            self.attack_args.num_successful_examples or self.attack_args.num_examples
        )
        # `worklist` is a deque for cheap pops/appends; the candidates are
        # spare samples we can fall back on if we need more.
        worklist, worklist_candidates = self._get_worklist(
            self.attack_args.num_examples_offset,
            len(self.dataset),
            num_remaining_attacks,
            self.attack_args.shuffle,
        )
        return num_remaining_attacks, worklist, worklist_candidates

    def _initial_counts(self):
        """Return the running result counters, restored from the checkpoint if any."""
        if self._checkpoint:
            return {
                "total": self._checkpoint.results_count,
                "failures": self._checkpoint.num_failed_attacks,
                "skipped": self._checkpoint.num_skipped_attacks,
                "successes": self._checkpoint.num_successful_attacks,
            }
        return {"total": 0, "failures": 0, "skipped": 0, "successes": 0}

    def _wrap_example(self, text):
        """Wrap a raw dataset input in ``AttackedText``, attaching label names if the dataset has them."""
        example = textattack.shared.AttackedText(text)
        if self.dataset.label_names is not None:
            example.attack_attrs["label_names"] = self.dataset.label_names
        return example

    def _needs_replacement(self, result):
        """Return a truthy value when ``result`` does not count towards the target.

        That is the case for a skipped result while ``attack_n`` is set, or
        for any non-successful result while ``num_successful_examples`` is set.
        """
        return (
            isinstance(result, SkippedAttackResult) and self.attack_args.attack_n
        ) or (
            not isinstance(result, SuccessfulAttackResult)
            and self.attack_args.num_successful_examples
        )

    @staticmethod
    def _tally_result(result, counts, pbar):
        """Update ``counts`` in place for ``result`` and refresh the progress-bar description."""
        counts["total"] += 1
        if isinstance(result, SkippedAttackResult):
            counts["skipped"] += 1
        if isinstance(result, (SuccessfulAttackResult, MaximizedAttackResult)):
            counts["successes"] += 1
        if isinstance(result, FailedAttackResult):
            counts["failures"] += 1
        pbar.set_description(
            "[Succeeded / Failed / Skipped / Total] "
            f"{counts['successes']} / {counts['failures']} / {counts['skipped']} / {counts['total']}"
        )

    def _save_checkpoint_if_due(self, worklist, worklist_candidates):
        """Save a checkpoint when the number of logged results is a multiple of the interval."""
        interval = self.attack_args.checkpoint_interval
        if interval and len(self.attack_log_manager.results) % interval == 0:
            new_checkpoint = textattack.shared.AttackCheckpoint(
                self.attack_args,
                self.attack_log_manager,
                worklist,
                worklist_candidates,
            )
            new_checkpoint.save()
            self.attack_log_manager.flush()

    def _report_summary(self, pbar):
        """Close the progress bar and emit the final summary through the log manager."""
        pbar.close()
        print()
        # Enable summary stdout
        if not self.attack_args.silent and self.attack_args.disable_stdout:
            self.attack_log_manager.enable_stdout()

        if self.attack_args.enable_advance_metrics:
            self.attack_log_manager.enable_advance_metrics = True

        self.attack_log_manager.log_summary()
        self.attack_log_manager.flush()
        print()

    # ------------------------------------------------------------------
    # Attack drivers
    # ------------------------------------------------------------------

    def simple_attack(self, text, label):
        """Attack a single ``(text, label)`` pair without touching the loggers.

        No parallel processing is involved.

        Args:
            text: Raw input passed to ``textattack.shared.AttackedText``.
            label: Ground-truth output passed to the attack.

        Returns:
            The attack result, or ``None`` when the result is one that the
            dataset loop would replace with another sample (skipped while
            ``attack_n`` is set, or unsuccessful while
            ``num_successful_examples`` is set).
        """
        if torch.cuda.is_available():
            self.attack.cuda_()

        example = self._wrap_example(text)
        # Exceptions (including KeyboardInterrupt) propagate to the caller.
        result = self.attack.attack(example, label)
        if self._needs_replacement(result):
            return None
        return result

    def _attack(self):
        """Internal method that carries out attack.

        No parallel processing is involved.
        """
        if torch.cuda.is_available():
            self.attack.cuda_()

        num_remaining_attacks, worklist, worklist_candidates = self._initial_worklist()

        if not self.attack_args.silent:
            print(self.attack, "\n")

        pbar = tqdm.tqdm(total=num_remaining_attacks, smoothing=0, dynamic_ncols=True)
        counts = self._initial_counts()

        sample_exhaustion_warned = False
        while worklist:
            idx = worklist.popleft()
            try:
                text, ground_truth_output = self.dataset[idx]
            except IndexError:
                # Index outside the dataset: drop it and move on.
                continue
            example = self._wrap_example(text)
            result = self.attack.attack(example, ground_truth_output)

            if self._needs_replacement(result):
                if worklist_candidates:
                    # Pull in a spare sample to make up for this one.
                    worklist.append(worklist_candidates.popleft())
                elif not sample_exhaustion_warned:
                    logger.warning("Ran out of samples to attack!")
                    sample_exhaustion_warned = True
            else:
                pbar.update(1)

            self.attack_log_manager.log_result(result)
            if not self.attack_args.disable_stdout and not self.attack_args.silent:
                print("\n")
            self._tally_result(result, counts, pbar)
            self._save_checkpoint_if_due(worklist, worklist_candidates)

        self._report_summary(pbar)

    def _attack_parallel(self):
        """Internal method that carries out the attack with a pool of worker processes.

        Samples are sent to the workers through ``in_queue`` and results are
        collected from ``out_queue``. If a worker reports an exception, the
        error is logged, the pool is terminated and the method returns early
        without logging a summary.
        """
        pytorch_multiprocessing_workaround()

        num_remaining_attacks, worklist, worklist_candidates = self._initial_worklist()

        in_queue = torch.multiprocessing.Queue()
        out_queue = torch.multiprocessing.Queue()
        for i in worklist:
            try:
                text, ground_truth_output = self.dataset[i]
                example = self._wrap_example(text)
                in_queue.put((i, example, ground_truth_output))
            except IndexError:
                raise IndexError(
                    f"Tried to access element at {i} in dataset of size {len(self.dataset)}."
                )

        # Workers are spread over every visible GPU.
        num_gpus = torch.cuda.device_count()
        num_workers = self.attack_args.num_workers_per_device * num_gpus
        logger.info(f"Running {num_workers} worker(s) on {num_gpus} GPU(s).")

        # Lock for synchronization
        lock = mp.Lock()

        # We move Attacker (and its components) to CPU b/c we don't want models using wrong GPU in worker processes.
        self.attack.cpu_()
        torch.cuda.empty_cache()

        # Start workers. `attack_from_queue` is the pool initializer, so each
        # worker process runs it with the arguments below.
        worker_pool = torch.multiprocessing.Pool(
            num_workers,
            attack_from_queue,
            (
                self.attack,
                self.attack_args,
                num_gpus,
                mp.Value("i", 1, lock=False),
                lock,
                in_queue,
                out_queue,
            ),
        )

        # Log results asynchronously and update progress bar.
        counts = self._initial_counts()

        logger.info(f"Worklist size: {len(worklist)}")
        logger.info(f"Worklist candidate size: {len(worklist_candidates)}")

        sample_exhaustion_warned = False
        pbar = tqdm.tqdm(total=num_remaining_attacks, smoothing=0, dynamic_ncols=True)
        while worklist:
            idx, result = out_queue.get(block=True)
            worklist.remove(idx)

            if isinstance(result, tuple) and isinstance(result[0], Exception):
                # A worker failed: report it and shut everything down.
                logger.error(
                    f'Exception encountered for input "{self.dataset[idx][0]}".'
                )
                error_trace = result[1]
                logger.error(error_trace)
                in_queue.close()
                in_queue.join_thread()
                out_queue.close()
                out_queue.join_thread()
                worker_pool.terminate()
                worker_pool.join()
                return
            elif self._needs_replacement(result):
                if worklist_candidates:
                    # Pull in a spare sample and hand it to the workers.
                    next_sample = worklist_candidates.popleft()
                    text, ground_truth_output = self.dataset[next_sample]
                    example = self._wrap_example(text)
                    worklist.append(next_sample)
                    in_queue.put((next_sample, example, ground_truth_output))
                elif not sample_exhaustion_warned:
                    logger.warning("Ran out of samples to attack!")
                    sample_exhaustion_warned = True
            else:
                pbar.update()

            self.attack_log_manager.log_result(result)
            self._tally_result(result, counts, pbar)
            self._save_checkpoint_if_due(worklist, worklist_candidates)

        # Send sentinel values to worker processes
        for _ in range(num_workers):
            in_queue.put(("END", "END", "END"))
        worker_pool.close()
        worker_pool.join()

        self._report_summary(pbar)

    def attack_dataset(self):
        """Attack the dataset.

        Returns:
            :obj:`list[AttackResult]` - List of :class:`~textattack.attack_results.AttackResult` obtained after attacking the given dataset.

        Raises:
            ValueError: If ``checkpoint_interval`` is set and the dataset has
                been internally shuffled.
            Exception: If ``parallel`` is set but no GPU is found.
        """
        if self.attack_args.silent:
            logger.setLevel(logging.ERROR)

        if self.attack_args.query_budget:
            self.attack.goal_function.query_budget = self.attack_args.query_budget

        if not self.attack_log_manager:
            self.attack_log_manager = AttackArgs.create_loggers_from_args(
                self.attack_args
            )

        textattack.shared.utils.set_seed(self.attack_args.random_seed)
        if self.dataset.shuffled and self.attack_args.checkpoint_interval:
            # Not allowed b/c we cannot recover order of shuffled data
            raise ValueError(
                "Cannot use `--checkpoint-interval` with dataset that has been internally shuffled."
            )

        # `-1` means "attack the whole dataset".
        self.attack_args.num_examples = (
            len(self.dataset)
            if self.attack_args.num_examples == -1
            else self.attack_args.num_examples
        )
        if self.attack_args.parallel:
            if torch.cuda.device_count() == 0:
                raise Exception(
                    "Found no GPU on your system. To run attacks in parallel, GPU is required."
                )
            self._attack_parallel()
        else:
            self._attack()

        if self.attack_args.silent:
            logger.setLevel(logging.INFO)

        return self.attack_log_manager.results

    def update_attack_args(self, **kwargs):
        """To update any attack args, pass the new argument as keyword argument
        to this function.

        Raises:
            ValueError: If a keyword does not name an existing field of the
                attack args. Fields given before the offending keyword have
                already been updated at that point.

        Examples::

            >>> attacker = #some instance of Attacker
            >>> # To switch to parallel mode and increase checkpoint interval from 100 to 500
            >>> attacker.update_attack_args(parallel=True, checkpoint_interval=500)
        """
        for name, value in kwargs.items():
            if not hasattr(self.attack_args, name):
                raise ValueError(
                    f"`textattack.AttackArgs` does not have field {name}."
                )
            # Use `setattr` so the field called `name` is updated, rather than
            # a literal attribute named "k".
            setattr(self.attack_args, name, value)

    @classmethod
    def from_checkpoint(cls, attack, dataset, checkpoint):
        """Resume attacking from a saved checkpoint. Attacker and dataset must
        be recovered by the user again, while attack args are loaded from the
        saved checkpoint.

        Args:
            attack (:class:`~textattack.Attack`):
                Attack object for carrying out the attack.
            dataset (:class:`~textattack.datasets.Dataset`):
                Dataset to attack.
            checkpoint (:obj:`Union[str, :class:`~textattack.shared.AttackCheckpoint`]`):
                Path of saved checkpoint or the actual saved checkpoint.

        Returns:
            A new instance of ``cls`` that uses the checkpoint's attack args
            and log manager.
        """
        assert isinstance(
            checkpoint, (str, textattack.shared.AttackCheckpoint)
        ), f"`checkpoint` must be of type `str` or `textattack.shared.AttackCheckpoint`, but got type `{type(checkpoint)}`."

        if isinstance(checkpoint, str):
            checkpoint = textattack.shared.AttackCheckpoint.load(checkpoint)
        attacker = cls(attack, dataset, checkpoint.attack_args)
        attacker.attack_log_manager = checkpoint.attack_log_manager
        attacker._checkpoint = checkpoint
        return attacker

    @staticmethod
    def attack_interactive(attack):
        """Read sentences from stdin and attack each one until the user enters ``q``.

        The model's own output for the sentence is used as the ground truth.

        Args:
            attack: Attack object; its ``goal_function.get_output`` and
                ``attack`` methods are called for every sentence.
        """
        print(attack, "\n")

        print("Running in interactive mode")
        print("----------------------------")

        while True:
            print('Enter a sentence to attack or "q" to quit:')
            text = input()

            if text == "q":
                break

            if not text:
                continue

            print("Attacking...")

            example = textattack.shared.attacked_text.AttackedText(text)
            output = attack.goal_function.get_output(example)
            result = attack.attack(example, output)
            print(result.__str__(color_method="ansi") + "\n")
#
# Helper Methods for multiprocess attacks
#
def pytorch_multiprocessing_workaround():
    """Configure ``torch.multiprocessing`` before worker processes are created.

    Forces the ``"spawn"`` start method and then selects the
    ``"file_system"`` sharing strategy. A ``RuntimeError`` raised by either
    call is ignored; if the first call raises, the second is not attempted.
    """
    torch_mp = torch.multiprocessing
    try:
        torch_mp.set_start_method("spawn", force=True)
        torch_mp.set_sharing_strategy("file_system")
    except RuntimeError:
        # Best effort: this is a workaround for a known bug, so a failure
        # here is deliberately not fatal.
        return
def set_env_variables(gpu_id):
    """Set up the current process to run on the GPU with index ``gpu_id``.

    Sets the TensorFlow log level (unless the user already set one), the
    torch sharing strategy, ``CUDA_VISIBLE_DEVICES`` and the current torch
    CUDA device. If TensorFlow is installed, it is also restricted to that
    GPU and memory growth is enabled on it.

    Args:
        gpu_id (int): Index of the GPU this process should use.
    """
    # Disable tensorflow logs, except in the case of an error. A value the
    # user has already exported is left untouched.
    os.environ.setdefault("TF_CPP_MIN_LOG_LEVEL", "3")

    # Set sharing strategy to file_system to avoid file descriptor leaks
    torch.multiprocessing.set_sharing_strategy("file_system")

    # Only use one GPU, if we have one.
    # For Tensorflow
    # TODO: Using USE with `--parallel` raises similar issue as https://github.com/tensorflow/tensorflow/issues/38518#
    os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
    # For PyTorch
    torch.cuda.set_device(gpu_id)

    # Fix TensorFlow GPU memory growth
    try:
        import tensorflow as tf

        tf_experimental = tf.config.experimental
        physical_gpus = tf_experimental.list_physical_devices("GPU")
        if physical_gpus:
            try:
                # Currently, memory growth needs to be the same across GPUs
                chosen_gpu = physical_gpus[gpu_id]
                tf_experimental.set_visible_devices(chosen_gpu, "GPU")
                tf_experimental.set_memory_growth(chosen_gpu, True)
            except RuntimeError as err:
                print(err)
    except ModuleNotFoundError:
        # TensorFlow is optional; nothing to configure without it.
        pass
def attack_from_queue(
    attack, attack_args, num_gpus, first_to_start, lock, in_queue, out_queue
):
    """Worker loop: attack samples taken from ``in_queue`` and report on ``out_queue``.

    Each item on ``in_queue`` is a tuple ``(i, example, ground_truth_output)``.
    For every item, ``(i, result)`` is put on ``out_queue``. If the attack
    raises, ``(i, (exception, traceback_string))`` is put instead. The loop
    ends when the sentinel ``("END", "END", "END")`` is received.

    Args:
        attack: The ``Attack`` to run; it is moved to the GPU with ``cuda_()``.
        attack_args: Attack arguments; ``random_seed`` and ``silent`` are read.
        num_gpus (int): Number of GPUs the worker processes are spread over.
        first_to_start: Shared integer flag, non-zero until the first worker
            has printed the attack description.
        lock: Lock guarding updates of ``first_to_start``.
        in_queue: Queue of samples to attack.
        out_queue: Queue on which results are reported.
    """
    assert isinstance(
        attack, Attack
    ), f"`attack` must be of type `Attack`, but got type `{type(attack)}`."

    worker_rank = torch.multiprocessing.current_process()._identity[0]
    gpu_id = (worker_rank - 1) % num_gpus
    set_env_variables(gpu_id)
    textattack.shared.utils.set_seed(attack_args.random_seed)
    if worker_rank > 1:
        # Only the first worker keeps logging enabled.
        logging.disable()

    attack.cuda_()

    # Simple non-synchronized check to see if it's the first process to reach this point.
    # This let us avoid waiting for lock.
    if bool(first_to_start.value):
        # If it's first process to reach this step, we first try to acquire the lock to update the value.
        with lock:
            # Because another process could have changed `first_to_start=False` while we wait, we check again.
            if bool(first_to_start.value):
                first_to_start.value = 0
                if not attack_args.silent:
                    print(attack, "\n")

    while True:
        # Dequeue separately from attacking, so that the error report below
        # always carries the index of the sample that actually failed.
        try:
            i, example, ground_truth_output = in_queue.get(timeout=5)
        except queue.Empty:
            # Nothing to do yet; keep polling until the sentinel arrives.
            continue

        if i == "END" and example == "END" and ground_truth_output == "END":
            # End process when sentinel value is received
            break

        try:
            result = attack.attack(example, ground_truth_output)
            out_queue.put((i, result))
        except Exception as e:
            # Ship the exception and its traceback back to the parent process.
            out_queue.put((i, (e, traceback.format_exc())))