# Source code for textattack.commands.attack_resume_command

"""

AttackResumeCommand class
===========================

"""

from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser
import os

import textattack
from textattack import Attacker, CommandLineAttackArgs, DatasetArgs, ModelArgs
from textattack.commands import TextAttackCommand


class AttackResumeCommand(TextAttackCommand):
    """The TextAttack attack resume recipe module:

    A command line parser to resume a checkpointed attack from user
    specifications.
    """

    def run(self, args):
        """Resume a checkpointed attack described by the parsed CLI ``args``.

        Loads the checkpoint selected by ``args.checkpoint_file``, overrides
        the saved ``parallel`` flag (always) and ``checkpoint_dir`` /
        ``checkpoint_interval`` (only when supplied on the command line),
        rebuilds the model, attack and dataset from the saved arguments, and
        continues the attack from the checkpoint.

        Args:
            args (argparse.Namespace): Arguments parsed by the
                ``attack-resume`` sub-parser built in ``register_subcommand``.

        Raises:
            AssertionError: If the checkpoint's ``attack_args`` is not a
                ``CommandLineAttackArgs`` instance.
        """
        checkpoint = self._parse_checkpoint_from_args(args)
        attack_args = checkpoint.attack_args
        assert isinstance(attack_args, CommandLineAttackArgs), (
            f"Expect `attack_args` to be of type `textattack.args.CommandLineAttackArgs`, but got type `{type(attack_args)}`. "
            f"If saved `attack_args` is not of type `textattack.args.CommandLineAttackArgs`, cannot resume attack from command line."
        )

        # Merge/update arguments. ``--parallel`` is a store_true flag, so the
        # command-line value always replaces the saved one (omitting it turns
        # parallelism off); directory and interval override only when given.
        attack_args.parallel = args.parallel
        if args.checkpoint_dir:
            attack_args.checkpoint_dir = args.checkpoint_dir
        if args.checkpoint_interval:
            attack_args.checkpoint_interval = args.checkpoint_interval

        # The same CommandLineAttackArgs object is handed to every factory.
        # NOTE(review): assumes CommandLineAttackArgs carries the ModelArgs and
        # DatasetArgs fields directly (it subclasses both in textattack 0.3)
        # and that the dataset factory is named ``_create_dataset_from_args``
        # like its model/attack siblings — confirm against textattack's args.
        model_wrapper = ModelArgs._create_model_from_args(attack_args)
        attack = CommandLineAttackArgs._create_attack_from_args(
            attack_args, model_wrapper
        )
        dataset = DatasetArgs._create_dataset_from_args(attack_args)

        attacker = Attacker.from_checkpoint(attack, dataset, checkpoint)
        attacker.attack_dataset()

    def _parse_checkpoint_from_args(self, args):
        """Load the checkpoint referenced by ``args.checkpoint_file``.

        When the base name of ``args.checkpoint_file`` is ``latest`` (case
        insensitive), the newest checkpoint in its directory (the current
        directory if none is given) is loaded instead. Checkpoint files are
        expected to be named ``<timestamp>.ta.chkpt``; the one with the
        largest integer timestamp wins. Files with that suffix whose stem is
        not an integer are ignored.

        Args:
            args (argparse.Namespace): Must provide ``checkpoint_file``.

        Returns:
            Whatever ``textattack.shared.AttackCheckpoint.load`` returns for
            the selected path.

        Raises:
            AssertionError: If "latest" is requested and the directory holds
                no ``<timestamp>.ta.chkpt`` files.
        """
        suffix = ".ta.chkpt"
        file_name = os.path.basename(args.checkpoint_file)
        if file_name.lower() != "latest":
            return textattack.shared.AttackCheckpoint.load(args.checkpoint_file)

        dir_path = os.path.dirname(args.checkpoint_file) or "."

        # Remember each file's real name next to its parsed timestamp so that
        # zero-padded stems (e.g. "0123.ta.chkpt") still resolve to the file
        # that actually exists on disk.
        timestamps = {}
        for name in os.listdir(dir_path):
            if not name.endswith(suffix):
                continue
            try:
                timestamps[name] = int(name[: -len(suffix)])
            except ValueError:
                # Not a "<timestamp>.ta.chkpt" file; skip rather than crash.
                continue

        assert timestamps, (
            f"AttackCheckpoint directory `{dir_path}` contains no "
            f"`<timestamp>{suffix}` files"
        )
        latest_file = max(timestamps, key=timestamps.get)
        checkpoint_path = os.path.join(dir_path, latest_file)
        return textattack.shared.AttackCheckpoint.load(checkpoint_path)

    @staticmethod
    def register_subcommand(main_parser: ArgumentParser):
        """Register the ``attack-resume`` sub-command on ``main_parser``.

        Args:
            main_parser: Object whose ``add_parser`` method creates the
                sub-parser (presumably the action returned by
                ``ArgumentParser.add_subparsers`` — verify against caller).
        """
        resume_parser = main_parser.add_parser(
            "attack-resume",
            help="resume a checkpointed attack",
            formatter_class=ArgumentDefaultsHelpFormatter,
        )

        # Parser for parsing args for resume
        resume_parser.add_argument(
            "--checkpoint-file",
            "-f",
            type=str,
            required=True,
            help='Path of checkpoint file to resume attack from. If "latest" (or "{directory path}/latest") is entered, '
            "recover latest checkpoint from either current path or specified directory.",
        )

        resume_parser.add_argument(
            "--checkpoint-dir",
            "-d",
            required=False,
            type=str,
            default=None,
            help="The directory to save checkpoint files. If not set, use directory from recovered arguments.",
        )

        # ``run`` keeps the recovered interval when this flag is absent.
        resume_parser.add_argument(
            "--checkpoint-interval",
            "-i",
            required=False,
            type=int,
            help="If set, checkpoint will be saved after attacking every N examples. If not set, use interval from recovered arguments.",
        )

        resume_parser.add_argument(
            "--parallel",
            action="store_true",
            default=False,
            help="Run attack using multiple GPUs.",
        )

        resume_parser.set_defaults(func=AttackResumeCommand())