"""
Source code for textattack.shared.checkpoint

Misc Checkpoints
================

The ``AttackCheckpoint`` class saves in-progress attacks and loads saved attacks from disk.
"""
import copy
import datetime
import os
import pickle
import time

import textattack
from textattack.attack_results import (
    FailedAttackResult,
    MaximizedAttackResult,
    SkippedAttackResult,
    SuccessfulAttackResult,
)
from textattack.shared import logger, utils

# TODO: Consider still keeping the old `Checkpoint` class and allow older checkpoints to be loaded to new TextAttack

class AttackCheckpoint:
    """An object that stores necessary information for saving and loading checkpoints.

    Args:
        attack_args (textattack.AttackArgs):
            Arguments of the original attack. A deep copy is stored.
        attack_log_manager (textattack.loggers.AttackLogManager):
            Object for storing attack results.
        worklist (deque[int]):
            List of examples that will be attacked. Examples are represented
            by their indices within the dataset.
        worklist_candidates:
            Other available examples we can attack. Used to get the next
            dataset element when ``attack_n=True``.
            NOTE(review): previously documented as ``int`` while described as
            a list -- confirm the container type against the caller.
        chkpt_time (float):
            Epoch time representing when the checkpoint was made. When
            ``None`` (the default) the current time is used.
    """

    def __init__(
        self,
        attack_args,
        attack_log_manager,
        worklist,
        worklist_candidates,
        chkpt_time=None,
    ):
        assert isinstance(
            attack_args, textattack.AttackArgs
        ), "`attack_args` must be of type `textattack.AttackArgs`."
        assert isinstance(
            attack_log_manager, textattack.loggers.AttackLogManager
        ), "`attack_log_manager` must be of type `textattack.loggers.AttackLogManager`."

        self.attack_args = copy.deepcopy(attack_args)
        self.attack_log_manager = attack_log_manager
        self.worklist = worklist
        self.worklist_candidates = worklist_candidates
        # Compare against None so that an explicit epoch time of 0 is kept
        # instead of being silently replaced by the current time.
        self.time = chkpt_time if chkpt_time is not None else time.time()

        self._verify()

    def __repr__(self):
        """Return a multi-line, indented summary of the checkpoint."""
        indent = utils.add_indent
        args = self.attack_args.__dict__

        lines = [indent(f"(Time): {self.datetime}", 2)]

        # A recipe replaces the search/transformation/constraints triple, so
        # only one of the two groups is reported.
        mutually_exclusive_args = ["search", "transformation", "constraints", "recipe"]
        if args.get("recipe"):
            args_lines = [indent(f'(recipe): {args["recipe"]}', 2)]
        else:
            args_lines = [
                indent(f'(search): {args["search"]}', 2),
                indent(f'(transformation): {args["transformation"]}', 2),
                indent(f'(constraints): {args["constraints"]}', 2),
            ]
        args_lines.extend(
            indent(f"({key}): {value}", 2)
            for key, value in args.items()
            if key not in mutually_exclusive_args
        )
        args_str = indent("\n" + "\n".join(args_lines), 2)
        lines.append(indent(f"(attack_args): {args_str}", 2))

        breakdown_lines = [
            indent(f"(Number of successful attacks): {self.num_successful_attacks}", 2),
            indent(f"(Number of failed attacks): {self.num_failed_attacks}", 2),
            indent(f"(Number of maximized attacks): {self.num_maximized_attacks}", 2),
            indent(f"(Number of skipped attacks): {self.num_skipped_attacks}", 2),
        ]
        breakdown_str = indent("\n" + "\n".join(breakdown_lines), 2)

        attack_logger_lines = [
            indent(
                f"(Total number of examples to attack): {self.attack_args.num_examples}",
                2,
            ),
            indent(f"(Number of attacks performed): {self.results_count}", 2),
            indent(f"(Number of remaining attacks): {self.num_remaining_attacks}", 2),
            indent(f"(Latest result breakdown): {breakdown_str}", 2),
        ]
        attack_logger_str = indent("\n" + "\n".join(attack_logger_lines), 2)
        lines.append(indent(f"(Previous attack summary): {attack_logger_str}", 2))

        main_str = "AttackCheckpoint("
        main_str += "\n " + "\n ".join(lines) + "\n"
        main_str += ")"
        return main_str

    __str__ = __repr__

    @property
    def results_count(self):
        """Return number of attacks made so far."""
        return len(self.attack_log_manager.results)

    @property
    def num_skipped_attacks(self):
        """Return the number of logged results that are ``SkippedAttackResult``."""
        return sum(
            isinstance(r, SkippedAttackResult) for r in self.attack_log_manager.results
        )

    @property
    def num_failed_attacks(self):
        """Return the number of logged results that are ``FailedAttackResult``."""
        return sum(
            isinstance(r, FailedAttackResult) for r in self.attack_log_manager.results
        )

    @property
    def num_successful_attacks(self):
        """Return the number of logged results that are ``SuccessfulAttackResult``."""
        return sum(
            isinstance(r, SuccessfulAttackResult)
            for r in self.attack_log_manager.results
        )

    @property
    def num_maximized_attacks(self):
        """Return the number of logged results that are ``MaximizedAttackResult``."""
        return sum(
            isinstance(r, MaximizedAttackResult)
            for r in self.attack_log_manager.results
        )

    @property
    def num_remaining_attacks(self):
        """Return how many attacks are still to be performed.

        With ``attack_n`` set, only successful and failed results count
        towards ``num_examples``; otherwise every logged result counts.
        """
        if self.attack_args.attack_n:
            non_skipped_attacks = self.num_successful_attacks + self.num_failed_attacks
            return self.attack_args.num_examples - non_skipped_attacks
        return self.attack_args.num_examples - self.results_count

    @property
    def dataset_offset(self):
        """Calculate offset into the dataset to start from."""
        # Original offset + # of results processed so far
        return self.attack_args.num_examples_offset + self.results_count

    @property
    def datetime(self):
        """Return the checkpoint time formatted as ``YYYY-MM-DD HH:MM:SS``."""
        return datetime.datetime.fromtimestamp(self.time).strftime("%Y-%m-%d %H:%M:%S")

    def save(self, quiet=False):
        """Pickle this checkpoint into ``attack_args.checkpoint_dir``.

        The file is named ``<time in milliseconds>.ta.chkpt``. The directory
        is created if it does not exist yet.

        Args:
            quiet (bool): If ``True``, do not print or log the save message.
        """
        file_name = "{}.ta.chkpt".format(int(self.time * 1000))
        checkpoint_dir = self.attack_args.checkpoint_dir
        # exist_ok avoids the check-then-create race of exists() + makedirs().
        os.makedirs(checkpoint_dir, exist_ok=True)
        path = os.path.join(checkpoint_dir, file_name)
        if not quiet:
            print("\n\n" + "=" * 125)
            logger.info(
                'Saving checkpoint under "{}" at {} after {} attacks.'.format(
                    path, self.datetime, self.results_count
                )
            )
            print("=" * 125 + "\n")
        with open(path, "wb") as f:
            pickle.dump(self, f, protocol=pickle.HIGHEST_PROTOCOL)

    @classmethod
    def load(cls, path):
        """Load a checkpoint previously written by :meth:`save`.

        Args:
            path (str): Path of the pickled checkpoint file.

        Returns:
            The unpickled checkpoint, asserted to be an instance of ``cls``.
        """
        # SECURITY: unpickling can execute arbitrary code. Only load
        # checkpoint files that come from a trusted source.
        with open(path, "rb") as f:
            checkpoint = pickle.load(f)
        assert isinstance(checkpoint, cls)
        return checkpoint

    def _verify(self):
        """Check that the checkpoint has no duplicates and is consistent."""
        assert self.num_remaining_attacks == len(
            self.worklist
        ), "Recorded number of remaining attacks and size of worklist are different."
        results_set = {
            result.original_text for result in self.attack_log_manager.results
        }
        assert (
            len(results_set) == self.results_count
        ), "Duplicate `AttackResults` found."