Spaces:
Runtime error
Runtime error
""" | |
Misc Checkpoints | |
=================== | |
The ``AttackCheckpoint`` class saves in-progress attacks and loads saved attacks from disk. | |
""" | |
import copy
import datetime
import os
import pickle
import time

import textattack
from textattack.attack_results import (
    FailedAttackResult,
    MaximizedAttackResult,
    SkippedAttackResult,
    SuccessfulAttackResult,
)
from textattack.shared import logger, utils
# TODO: Consider keeping the old `Checkpoint` class so that older checkpoints can still be loaded by newer TextAttack versions.
class AttackCheckpoint:
    """An object that stores necessary information for saving and loading
    checkpoints.

    Args:
        attack_args (textattack.AttackArgs): Arguments of the original attack.
        attack_log_manager (textattack.loggers.AttackLogManager): Object for
            storing attack results.
        worklist (deque[int]): List of examples that will be attacked. Examples
            are represented by their indices within the dataset.
        worklist_candidates (list of int): List of other available examples we
            can attack. Used to get the next dataset element when
            `attack_n=True`.
        chkpt_time (float, optional): Epoch time representing when the
            checkpoint was made. Defaults to the current time.
    """

    def __init__(
        self,
        attack_args,
        attack_log_manager,
        worklist,
        worklist_candidates,
        chkpt_time=None,
    ):
        assert isinstance(
            attack_args, textattack.AttackArgs
        ), "`attack_args` must be of type `textattack.AttackArgs`."
        assert isinstance(
            attack_log_manager, textattack.loggers.AttackLogManager
        ), "`attack_log_manager` must be of type `textattack.loggers.AttackLogManager`."
        # Deep-copy so later mutation of the caller's args cannot alter the
        # arguments stored in this checkpoint.
        self.attack_args = copy.deepcopy(attack_args)
        self.attack_log_manager = attack_log_manager
        self.worklist = worklist
        self.worklist_candidates = worklist_candidates
        # Compare against None explicitly: an epoch timestamp of 0.0 is valid.
        self.time = chkpt_time if chkpt_time is not None else time.time()
        self._verify()

    def __repr__(self):
        main_str = "AttackCheckpoint("
        lines = []
        lines.append(utils.add_indent(f"(Time): {self.datetime}", 2))

        # --- attack_args section ---
        args_dict = self.attack_args.__dict__
        args_lines = []
        # An attack is described either by a `recipe` or by the
        # (search, transformation, constraints) triple; show whichever is used.
        mutually_exclusive_args = ["search", "transformation", "constraints", "recipe"]
        if args_dict.get("recipe"):
            args_lines.append(utils.add_indent(f'(recipe): {args_dict["recipe"]}', 2))
        else:
            for key in ("search", "transformation", "constraints"):
                # Skip components the args object does not define instead of
                # raising KeyError.
                if key in args_dict:
                    args_lines.append(
                        utils.add_indent(f"({key}): {args_dict[key]}", 2)
                    )
        for key, value in args_dict.items():
            if key not in mutually_exclusive_args:
                args_lines.append(utils.add_indent(f"({key}): {value}", 2))
        args_str = utils.add_indent("\n" + "\n".join(args_lines), 2)
        lines.append(utils.add_indent(f"(attack_args): {args_str}", 2))

        # --- previous attack summary section ---
        summary_lines = [
            utils.add_indent(
                f"(Total number of examples to attack): {self.attack_args.num_examples}",
                2,
            ),
            utils.add_indent(f"(Number of attacks performed): {self.results_count}", 2),
            utils.add_indent(
                f"(Number of remaining attacks): {self.num_remaining_attacks}", 2
            ),
        ]
        breakdown = (
            ("Number of successful attacks", self.num_successful_attacks),
            ("Number of failed attacks", self.num_failed_attacks),
            ("Number of maximized attacks", self.num_maximized_attacks),
            ("Number of skipped attacks", self.num_skipped_attacks),
        )
        breakdown_lines = [
            utils.add_indent(f"({label}): {count}", 2) for label, count in breakdown
        ]
        breakdown_str = utils.add_indent("\n" + "\n".join(breakdown_lines), 2)
        summary_lines.append(
            utils.add_indent(f"(Latest result breakdown): {breakdown_str}", 2)
        )
        summary_str = utils.add_indent("\n" + "\n".join(summary_lines), 2)
        lines.append(utils.add_indent(f"(Previous attack summary): {summary_str}", 2))

        main_str += "\n " + "\n ".join(lines) + "\n"
        main_str += ")"
        return main_str

    __str__ = __repr__

    def _count_results(self, result_type):
        """Return how many logged results are instances of ``result_type``."""
        return sum(
            isinstance(result, result_type)
            for result in self.attack_log_manager.results
        )

    @property
    def results_count(self):
        """Return number of attacks made so far."""
        return len(self.attack_log_manager.results)

    @property
    def num_skipped_attacks(self):
        """Number of logged results that are ``SkippedAttackResult``."""
        return self._count_results(SkippedAttackResult)

    @property
    def num_failed_attacks(self):
        """Number of logged results that are ``FailedAttackResult``."""
        return self._count_results(FailedAttackResult)

    @property
    def num_successful_attacks(self):
        """Number of logged results that are ``SuccessfulAttackResult``."""
        return self._count_results(SuccessfulAttackResult)

    @property
    def num_maximized_attacks(self):
        """Number of logged results that are ``MaximizedAttackResult``."""
        return self._count_results(MaximizedAttackResult)

    @property
    def num_remaining_attacks(self):
        """Number of attacks still to be performed.

        With ``attack_args.attack_n`` set, only successful and failed results
        count toward ``num_examples``; otherwise every logged result counts.
        """
        if self.attack_args.attack_n:
            non_skipped_attacks = self.num_successful_attacks + self.num_failed_attacks
            return self.attack_args.num_examples - non_skipped_attacks
        return self.attack_args.num_examples - self.results_count

    @property
    def dataset_offset(self):
        """Calculate offset into the dataset to start from."""
        # Original offset + # of results processed so far
        return self.attack_args.num_examples_offset + self.results_count

    @property
    def datetime(self):
        """Checkpoint time formatted as local ``YYYY-MM-DD HH:MM:SS``."""
        # Inside a method body `datetime` resolves to the module, not this property.
        return datetime.datetime.fromtimestamp(self.time).strftime("%Y-%m-%d %H:%M:%S")

    def save(self, quiet=False):
        """Pickle this checkpoint into ``attack_args.checkpoint_dir``.

        The file is named ``<checkpoint time in ms>.ta.chkpt``; the directory is
        created if it does not exist.

        Args:
            quiet (bool): If True, do not print/log the "Saving checkpoint" banner.
        """
        file_name = "{}.ta.chkpt".format(int(self.time * 1000))
        # exist_ok avoids a race between an existence check and creation.
        os.makedirs(self.attack_args.checkpoint_dir, exist_ok=True)
        path = os.path.join(self.attack_args.checkpoint_dir, file_name)
        if not quiet:
            print("\n\n" + "=" * 125)
            logger.info(
                'Saving checkpoint under "{}" at {} after {} attacks.'.format(
                    path, self.datetime, self.results_count
                )
            )
            print("=" * 125 + "\n")
        with open(path, "wb") as f:
            pickle.dump(self, f, protocol=pickle.HIGHEST_PROTOCOL)

    @classmethod
    def load(cls, path):
        """Load a checkpoint previously written by :meth:`save`.

        Args:
            path (str): Path to a ``.ta.chkpt`` file.

        Returns:
            AttackCheckpoint: The unpickled checkpoint.

        Raises:
            AssertionError: If the file does not contain an ``AttackCheckpoint``.
        """
        # SECURITY: unpickling can execute arbitrary code; only load checkpoint
        # files from trusted sources.
        with open(path, "rb") as f:
            checkpoint = pickle.load(f)
        assert isinstance(checkpoint, cls)
        return checkpoint

    def _verify(self):
        """Check that the checkpoint has no duplicates and is consistent.

        Raises:
            AssertionError: If the worklist size disagrees with the number of
                remaining attacks, or two results share the same original text.
        """
        # NOTE(review): `assert` is stripped under `python -O`, which would
        # silently disable these checks.
        assert self.num_remaining_attacks == len(
            self.worklist
        ), "Recorded number of remaining attacks and size of worklist are different."
        results_set = {
            result.original_text for result in self.attack_log_manager.results
        }
        assert (
            len(results_set) == self.results_count
        ), "Duplicate `AttackResults` found."