Training API Reference

Trainer

The Trainer class provides an API for adversarial training, with built-in features for standard use cases. It is designed to be similar to the Trainer class provided by the 🤗 Transformers library. Custom behaviors can be added by subclassing Trainer and overriding these methods (a sketch of one such override follows the list):

  • training_step(): Perform a single training step. Override this for a custom forward pass or custom loss.

  • evaluate_step(): Perform a single evaluation step. Override this for a custom forward pass.

  • get_train_dataloader(): Creates the PyTorch DataLoader for training. Override this for custom batch setup.

  • get_eval_dataloader(): Creates the PyTorch DataLoader for evaluation. Override this for custom batch setup.

  • get_optimizer_and_scheduler(): Creates the optimizer and scheduler for training. Override this for a custom optimizer and scheduler.
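
For example, here is a minimal sketch of a subclass that overrides training_step() to use label-smoothed cross-entropy. The subclass name and the loss choice are hypothetical; the batch layout follows the default described under training_step() below, and the sketch assumes a Hugging Face sequence-classification model:

import torch
import textattack

class LabelSmoothingTrainer(textattack.Trainer):
    # Hypothetical subclass: replaces the default loss with label-smoothed
    # cross-entropy. Assumes a Hugging Face sequence-classification model.
    def training_step(self, model, tokenizer, batch):
        input_texts, targets, is_adv_sample = batch  # default batch layout; flag unused here
        device = next(model.parameters()).device
        inputs = tokenizer(
            input_texts, padding=True, truncation=True, return_tensors="pt"
        ).to(device)
        targets = targets.to(device)
        logits = model(**inputs).logits
        loss = torch.nn.functional.cross_entropy(logits, targets, label_smoothing=0.1)
        preds = logits.argmax(dim=-1)
        return loss, preds, targets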

In pseudocode, the training and evaluation loops proceed as follows:

train_preds = []
train_targets = []
for batch in train_dataloader:
   loss, preds, targets = training_step(model, tokenizer, batch)
   train_preds.append(preds)
   train_targets.append(targets)

   # clear gradients
   optimizer.zero_grad()

   # backward
   loss.backward()

   # update parameters
   optimizer.step()
   if scheduler:
      scheduler.step()

# Calculate training accuracy using `train_preds` and `train_targets`

eval_preds = []
eval_targets = []
for batch in eval_dataloader:
   preds, targets = evaluate_step(model, tokenizer, batch)
   eval_preds.append(preds)
   eval_targets.append(targets)

# Calculate eval accuracy using `eval_preds` and `eval_targets`
class textattack.Trainer(model_wrapper, task_type='classification', attack=None, train_dataset=None, eval_dataset=None, training_args=None)[source]

Trainer is the training and evaluation loop for adversarial training.

It is designed to work with PyTorch and Transformers models.

Parameters:
  • model_wrapper (ModelWrapper) – Model wrapper containing both the model and the tokenizer.

  • task_type (str, optional, defaults to "classification") – The task that the model is trained to perform. Currently, Trainer supports two tasks: (1) "classification", (2) "regression".

  • attack (Attack) – Attack used to generate adversarial examples for training.

  • train_dataset (Dataset) – Dataset for training.

  • eval_dataset (Dataset) – Dataset for evaluation.

  • training_args (TrainingArgs) – Arguments for training.

Example:

>>> import textattack
>>> import transformers

>>> model = transformers.AutoModelForSequenceClassification.from_pretrained("bert-base-uncased")
>>> tokenizer = transformers.AutoTokenizer.from_pretrained("bert-base-uncased")
>>> model_wrapper = textattack.models.wrappers.HuggingFaceModelWrapper(model, tokenizer)

>>> # We use DeepWordBugGao2018 for demonstration purposes only.
>>> attack = textattack.attack_recipes.DeepWordBugGao2018.build(model_wrapper)
>>> train_dataset = textattack.datasets.HuggingFaceDataset("imdb", split="train")
>>> eval_dataset = textattack.datasets.HuggingFaceDataset("imdb", split="test")

>>> # Train for 3 epochs with 1 initial clean epoch, 1000 adversarial examples per epoch, a learning rate of 5e-5, and an effective batch size of 32 (8x4).
>>> training_args = textattack.TrainingArgs(
...     num_epochs=3,
...     num_clean_epochs=1,
...     num_train_adv_examples=1000,
...     learning_rate=5e-5,
...     per_device_train_batch_size=8,
...     gradient_accumulation_steps=4,
...     log_to_tb=True,
... )

>>> trainer = textattack.Trainer(
...     model_wrapper,
...     "classification",
...     attack,
...     train_dataset,
...     eval_dataset,
...     training_args
... )
>>> trainer.train()

Note

When using Trainer with parallel=True in TrainingArgs, make sure to protect the “entry point” of the program with an if __name__ == '__main__': guard, as sketched below. Otherwise, each worker process used for generating adversarial examples will execute the training code again.
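
A minimal sketch of the guard, reusing the example above:

import textattack
import transformers

def main():
    model = transformers.AutoModelForSequenceClassification.from_pretrained("bert-base-uncased")
    tokenizer = transformers.AutoTokenizer.from_pretrained("bert-base-uncased")
    model_wrapper = textattack.models.wrappers.HuggingFaceModelWrapper(model, tokenizer)
    attack = textattack.attack_recipes.DeepWordBugGao2018.build(model_wrapper)
    train_dataset = textattack.datasets.HuggingFaceDataset("imdb", split="train")
    eval_dataset = textattack.datasets.HuggingFaceDataset("imdb", split="test")
    training_args = textattack.TrainingArgs(parallel=True)
    trainer = textattack.Trainer(
        model_wrapper, "classification", attack, train_dataset, eval_dataset, training_args
    )
    trainer.train()

# Without this guard, each worker process spawned to generate adversarial
# examples would re-run the module-level training code on import.
if __name__ == "__main__":
    main()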

evaluate()[source]

Evaluate the model on the given evaluation dataset.

evaluate_step(model, tokenizer, batch)[source]

Perform a single evaluation step on a batch of inputs.

Parameters:
  • model (torch.nn.Module) – Model to evaluate.

  • tokenizer – Tokenizer used to tokenize input text.

  • batch (tuple[list[str], torch.Tensor]) –

    By default, this will be a tuple of input texts and target tensors.

    Note

    If you override the get_eval_dataloader() method, then the shape/type of batch will depend on how you created your batch.

Returns:

tuple[torch.Tensor, torch.Tensor] where

  • preds: torch.FloatTensor of the model’s predictions for the batch.

  • targets: torch.Tensor of model’s targets (e.g. labels, target values).

get_eval_dataloader(dataset, batch_size)[source]

Returns the torch.utils.data.DataLoader for evaluation.

Parameters:
  • dataset (Dataset) – Dataset to use for evaluation.

  • batch_size (int) – Batch size for evaluation.

Returns:

torch.utils.data.DataLoader
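
For example, a sketch of an override that keeps the default batching contract but pins memory and uses background workers. The subclass name is hypothetical, and the sketch assumes single-text (text, label) examples so that PyTorch's default collation yields the (input texts, targets) batches described above:

import torch
import textattack

class PinnedEvalTrainer(textattack.Trainer):
    # Hypothetical override: same (input_texts, targets) batch contract,
    # but with pinned memory and background workers for faster transfer.
    def get_eval_dataloader(self, dataset, batch_size):
        return torch.utils.data.DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=False,
            num_workers=2,
            pin_memory=True,
        )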

get_optimizer_and_scheduler(model, num_training_steps)[source]

Returns the optimizer and scheduler to use for training. If you are overriding this method and do not want to use a scheduler, simply return None for the scheduler.

Parameters:
  • model (torch.nn.Module) – Model to be trained. Pass its parameters to optimizer for training.

  • num_training_steps (int) – Number of total training steps.

Returns:

tuple[torch.optim.Optimizer, torch.optim.lr_scheduler._LRScheduler] containing the optimizer and the scheduler.
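
A sketch of an override that uses plain AdamW with no scheduler. The subclass name is hypothetical, and the sketch assumes the training arguments are available as self.training_args:

import torch
import textattack

class AdamWTrainer(textattack.Trainer):
    # Hypothetical override: plain AdamW, no learning-rate scheduler.
    def get_optimizer_and_scheduler(self, model, num_training_steps):
        optimizer = torch.optim.AdamW(
            model.parameters(),
            lr=self.training_args.learning_rate,
            weight_decay=self.training_args.weight_decay,
        )
        return optimizer, None  # None: Trainer skips scheduler.step()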

get_train_dataloader(dataset, adv_dataset, batch_size)[source]

Returns the torch.utils.data.DataLoader for training.

Parameters:
  • dataset (Dataset) – Original training dataset.

  • adv_dataset (Dataset) – Adversarial examples generated from the original training dataset. None if no adversarial attack takes place.

  • batch_size (int) – Batch size for training.

Returns:

torch.utils.data.DataLoader
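
As with evaluation, one low-risk way to customize the training loader is to delegate to super() for the default dataset combination and collation (which preserves the is_adv_sample flag in each batch), then rewrap with different loader options. A sketch, with a hypothetical subclass name, assuming the standard torch.utils.data.DataLoader attributes dataset and collate_fn:

import torch
import textattack

class WorkerTrainer(textattack.Trainer):
    # Hypothetical override: keep the default dataset combination and
    # collation from super(), but load batches with background workers.
    def get_train_dataloader(self, dataset, adv_dataset, batch_size):
        base = super().get_train_dataloader(dataset, adv_dataset, batch_size)
        return torch.utils.data.DataLoader(
            base.dataset,                # clean + adversarial examples combined
            batch_size=batch_size,
            shuffle=True,
            num_workers=2,
            collate_fn=base.collate_fn,  # preserves the default batch layout
        )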

train()[source]

Train the model on the given training dataset.

training_step(model, tokenizer, batch)[source]

Perform a single training step on a batch of inputs.

Parameters:
  • model (torch.nn.Module) – Model to train.

  • tokenizer – Tokenizer used to tokenize input text.

  • batch (tuple[list[str], torch.Tensor, torch.Tensor]) –

    By default, this will be a tuple of input texts, targets, and a boolean tensor indicating whether each sample is an adversarial example.

    Note

    If you override the get_train_dataloader() method, then the shape/type of batch will depend on how you created your batch.

Returns:

tuple[torch.Tensor, torch.Tensor, torch.Tensor] where

  • loss: scalar torch.FloatTensor containing the loss.

  • preds: torch.FloatTensor of the model’s predictions for the batch.

  • targets: torch.Tensor of model’s targets (e.g. labels, target values).

TrainingArgs

Training arguments to be passed to the Trainer class.

class textattack.TrainingArgs(num_epochs: int = 3, num_clean_epochs: int = 1, attack_epoch_interval: int = 1, early_stopping_epochs: int | None = None, learning_rate: float = 5e-05, num_warmup_steps: int | float = 500, weight_decay: float = 0.01, per_device_train_batch_size: int = 8, per_device_eval_batch_size: int = 32, gradient_accumulation_steps: int = 1, random_seed: int = 786, parallel: bool = False, load_best_model_at_end: bool = False, alpha: float = 1.0, num_train_adv_examples: int | float = -1, query_budget_train: int | None = None, attack_num_workers_per_device: int = 1, output_dir: str = <factory>, checkpoint_interval_steps: int | None = None, checkpoint_interval_epochs: int | None = None, save_last: bool = True, log_to_tb: bool = False, tb_log_dir: str | None = None, log_to_wandb: bool = False, wandb_project: str = 'textattack', logging_interval_step: int = 1)[source]

Arguments for the Trainer class used for adversarial training.

Parameters:
  • num_epochs (int, optional, defaults to 3) – Total number of epochs for training.

  • num_clean_epochs (int, optional, defaults to 1) – Number of epochs to train on just the original training dataset before adversarial training.

  • attack_epoch_interval (int, optional, defaults to 1) – Generate a new adversarial training set every N epochs.

  • early_stopping_epochs (int, optional, defaults to None) – Stop training early if the validation score does not improve for this many consecutive epochs (None for no early stopping).

  • learning_rate (float, optional, defaults to 5e-5) – Learning rate for optimizer.

  • num_warmup_steps (int or float, optional, defaults to 500) – The number of steps for the warmup phase of the linear scheduler. If num_warmup_steps is a float between 0 and 1, the number of warmup steps will be math.ceil(num_training_steps * num_warmup_steps).

  • weight_decay (float, optional, defaults to 0.01) – Weight decay (L2 penalty).

  • per_device_train_batch_size (int, optional, defaults to 8) – The batch size per GPU/CPU for training.

  • per_device_eval_batch_size (int, optional, defaults to 32) – The batch size per GPU/CPU for evaluation.

  • gradient_accumulation_steps (int, optional, defaults to 1) – Number of update steps to accumulate gradients before performing a backward/update pass.

  • random_seed (int, optional, defaults to 786) – Random seed for reproducibility.

  • parallel (bool, optional, defaults to False) – If True, train using multiple GPUs via torch.nn.DataParallel.

  • load_best_model_at_end (bool, optional, defaults to False) – If True, keep track of the best model across training and load it at the end.

  • alpha (float, optional, defaults to 1.0) – The weight of the adversarial loss.

  • num_train_adv_examples (int or float, optional, defaults to -1) – The number of samples to attack successfully when generating the adversarial training set before the start of every epoch. If num_train_adv_examples is a float between 0 and 1, the number of adversarial examples generated is that fraction of the original training set.

  • query_budget_train (int, optional, defaults to None) – The max query budget to use when generating adversarial training set. None means infinite query budget.

  • attack_num_workers_per_device (int, optional, defaults to 1) – Number of worker processes to run per device for attack. Same as the num_workers_per_device argument for AttackArgs.

  • output_dir (str, optional) – Directory to output training logs and checkpoints. Defaults to a timestamped directory of the form /outputs/%Y-%m-%d-%H-%M-%S-%f.

  • checkpoint_interval_steps (int, optional, defaults to None) – If set, save model checkpoint after every N updates to the model.

  • checkpoint_interval_epochs (int, optional, defaults to None) – If set, save model checkpoint after every N epochs.

  • save_last (bool, optional, defaults to True) – If True, save the model at end of training. Can be used with load_best_model_at_end to save the best model at the end.

  • log_to_tb (bool, optional, defaults to False) – If True, log to TensorBoard.

  • tb_log_dir (str, optional, defaults to "./runs") – Path of the TensorBoard log directory.

  • log_to_wandb (bool, optional, defaults to False) – If True, log to Weights & Biases.

  • wandb_project (str, optional, defaults to "textattack") – Name of the Weights & Biases project for logging.

  • logging_interval_step (int, optional, defaults to 1) – Log to TensorBoard/Weights & Biases every N training steps.
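
For example, the float forms of num_warmup_steps and num_train_adv_examples described above can be combined with checkpointing and early stopping; the values below are purely illustrative:

>>> import textattack
>>> training_args = textattack.TrainingArgs(
...     num_epochs=10,
...     early_stopping_epochs=2,        # stop if eval score stalls for 2 epochs
...     num_warmup_steps=0.1,           # float: 10% of total training steps
...     num_train_adv_examples=0.2,     # float: attack 20% of the training set
...     checkpoint_interval_epochs=1,   # save a checkpoint every epoch
...     load_best_model_at_end=True,
... )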