textattack package

Welcome to the API references for TextAttack!

What is TextAttack?

TextAttack is a Python framework for adversarial attacks, adversarial training, and data augmentation in NLP.

TextAttack makes experimenting with the robustness of NLP models seamless, fast, and easy. It’s also useful for NLP model training, adversarial training, and data augmentation.

TextAttack provides components for common NLP tasks like sentence encoding, grammar-checking, and word replacement that can be used on their own.

Attack Class

class textattack.attack.Attack(goal_function: GoalFunction, constraints: List[Constraint | PreTransformationConstraint], transformation: Transformation, search_method: SearchMethod, transformation_cache_size=32768, constraint_cache_size=32768)

Bases: object

An attack generates adversarial examples on text.

An attack consists of a goal function, constraints, a transformation, and a search method. Use the attack() method to attack one sample at a time.

Parameters:
  • goal_function (GoalFunction) – A function for determining how well a perturbation is doing at achieving the attack’s goal.

  • constraints (list of Constraint or PreTransformationConstraint) – A list of constraints to add to the attack, defining which perturbations are valid.

  • transformation (Transformation) – The transformation applied at each step of the attack.

  • search_method (SearchMethod) – The method for exploring the search space of possible perturbations.

  • transformation_cache_size (int, optional, defaults to 2**15) – The number of items to keep in the transformations cache.

  • constraint_cache_size (int, optional, defaults to 2**15) – The number of items to keep in the constraints cache.
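The division of labor among these four components can be illustrated with a toy, pure-Python sketch that does not use TextAttack at all. The keyword "model", the synonym and stopword tables, and every function name below are hypothetical; the greedy loop is a simplified stand-in for a real search method such as GreedyWordSwapWIR, not its implementation.

```python
from typing import Dict, List, Optional

# Hypothetical toy "model": a keyword counter, not a TextAttack model wrapper.
POSITIVE = {"enjoyed", "great", "good"}
NEGATIVE = {"disliked", "bad", "awful"}


def margin(words: List[str]) -> int:
    return sum(w in POSITIVE for w in words) - sum(w in NEGATIVE for w in words)


def predict(words: List[str]) -> int:
    return 1 if margin(words) > 0 else 0  # 1 = positive, 0 = negative


# Goal function (untargeted): a higher score means closer to flipping the label.
def goal_score(words: List[str], label: int) -> int:
    return -margin(words) if label == 1 else margin(words)


# Transformation: replace the word at index i with each entry of a toy synonym table.
SYNONYMS: Dict[str, List[str]] = {"enjoyed": ["endured", "disliked"], "movie": ["film"]}


def transform(words: List[str], i: int) -> List[List[str]]:
    return [words[:i] + [s] + words[i + 1:] for s in SYNONYMS.get(words[i], [])]


# Constraints: one checked before transforming (skip stopwords) and one checked
# on each candidate afterwards (change at most `max_changes` words overall).
STOPWORDS = {"i", "the", "that", "out", "last"}


def index_allowed(words: List[str], i: int) -> bool:
    return words[i] not in STOPWORDS


def within_budget(original: List[str], candidate: List[str], max_changes: int = 2) -> bool:
    return sum(a != b for a, b in zip(original, candidate)) <= max_changes


# Search method: visit eligible indices left to right, greedily keep the best
# valid candidate, and stop as soon as the prediction no longer matches the label.
def greedy_attack(text: str, label: int) -> Optional[str]:
    original = text.lower().split()
    current = list(original)
    for i in range(len(current)):
        if not index_allowed(current, i):
            continue
        candidates = [c for c in transform(current, i) if within_budget(original, c)]
        if candidates:
            best = max(candidates, key=lambda c: goal_score(c, label))
            if goal_score(best, label) > goal_score(current, label):
                current = best
        if predict(current) != label:
            return " ".join(current)
    return None  # the attack failed
```

On "I really enjoyed the new movie that came out last month." with label 1, both toy synonyms of "enjoyed" flip the prediction, but the search keeps "disliked" because it has the higher goal score.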

Example:

>>> import textattack
>>> import transformers

>>> # Load model, tokenizer, and model_wrapper
>>> model = transformers.AutoModelForSequenceClassification.from_pretrained("textattack/bert-base-uncased-imdb")
>>> tokenizer = transformers.AutoTokenizer.from_pretrained("textattack/bert-base-uncased-imdb")
>>> model_wrapper = textattack.models.wrappers.HuggingFaceModelWrapper(model, tokenizer)

>>> # Construct our four components for `Attack`
>>> from textattack.constraints.pre_transformation import RepeatModification, StopwordModification
>>> from textattack.constraints.semantics import WordEmbeddingDistance
>>> from textattack.transformations import WordSwapEmbedding
>>> from textattack.search_methods import GreedyWordSwapWIR

>>> goal_function = textattack.goal_functions.UntargetedClassification(model_wrapper)
>>> constraints = [
...     RepeatModification(),
...     StopwordModification(),
...     WordEmbeddingDistance(min_cos_sim=0.9)
... ]
>>> transformation = WordSwapEmbedding(max_candidates=50)
>>> search_method = GreedyWordSwapWIR(wir_method="delete")

>>> # Construct the actual attack
>>> attack = textattack.Attack(goal_function, constraints, transformation, search_method)

>>> input_text = "I really enjoyed the new movie that came out last month."
>>> label = 1  # Positive
>>> attack_result = attack.attack(input_text, label)
attack(example, ground_truth_output)

Attack a single example.

Parameters:
  • example (str, OrderedDict[str, str] or AttackedText) – Example to attack. It can be a single string or an OrderedDict where keys represent the input fields (e.g. “premise”, “hypothesis”) and the values are the actual input texts. Also accepts an AttackedText that wraps around the input.

  • ground_truth_output (int, float or str) – Ground truth output of example. For classification tasks, it should be an integer representing the ground truth label. For regression tasks (e.g. STS), it should be the target value. For seq2seq tasks (e.g. translation), it should be the target string.

Returns:

AttackResult that represents the result of the attack.

clear_cache(recursive=True)
cpu_()

Move any torch.nn.Module models that are part of Attack to CPU.

cuda_()

Move any torch.nn.Module models that are part of Attack to GPU.

filter_transformations(transformed_texts, current_text, original_text=None)

Filters a list of potential transformed texts based on self.constraints. Uses an LRU cache to avoid recomputing constraint results for commonly seen transformations.

Parameters:
  • transformed_texts – A list of candidate transformed AttackedText to filter.

  • current_text – The current AttackedText on which the transformation was applied.

  • original_text – The original AttackedText from which the attack started.
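The LRU caching mentioned above can be sketched as follows. This is a standalone illustration, not TextAttack's code: it assumes each constraint is a deterministic predicate over a (current text, candidate text) pair, so the combined result for a pair can be memoized. The 2**15 default mirrors constraint_cache_size; the LRUCache class and the constraint in the usage test are hypothetical.

```python
from collections import OrderedDict
from typing import Callable, List, Optional, Tuple

Key = Tuple[str, str]  # (current text, candidate text)


class LRUCache:
    """Minimal least-recently-used cache built on OrderedDict."""

    def __init__(self, max_size: int = 2**15):
        self.max_size = max_size
        self._data: "OrderedDict[Key, bool]" = OrderedDict()

    def get(self, key: Key) -> Optional[bool]:
        if key not in self._data:
            return None
        self._data.move_to_end(key)  # mark as most recently used
        return self._data[key]

    def put(self, key: Key, value: bool) -> None:
        self._data[key] = value
        self._data.move_to_end(key)
        if len(self._data) > self.max_size:
            self._data.popitem(last=False)  # evict the least recently used entry


def filter_transformations(candidates: List[str], current: str,
                           constraints: List[Callable[[str, str], bool]],
                           cache: LRUCache) -> List[str]:
    kept = []
    for cand in candidates:
        key = (current, cand)
        passed = cache.get(key)
        if passed is None:  # cache miss: run every constraint once for this pair
            passed = all(check(current, cand) for check in constraints)
            cache.put(key, passed)
        if passed:
            kept.append(cand)
    return kept
```

Storing the boolean outcome (rather than the candidate) lets a rejected candidate be skipped just as cheaply as an accepted one on later search steps.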

get_indices_to_order(current_text, **kwargs)

Applies pre_transformation_constraints to text to get all the indices that can be used to search and order.

Parameters:

current_text – The current AttackedText for which to find the indices that are eligible to be ordered.

Returns:

The length and the filtered list of indices which search methods can use to search/order.
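One way to picture how several pre-transformation constraints combine: each one yields the set of word indices it allows, and only indices allowed by all of them are returned for the search method to order. The sketch below uses two hypothetical constraints modeled loosely on RepeatModification and StopwordModification; representing constraints as index-set functions is an assumption of this illustration, not TextAttack's implementation.

```python
from typing import Callable, List, Set, Tuple

# Assumption: a pre-transformation constraint returns the set of word indices
# it permits modifying, given the words and the indices already modified.
PreConstraint = Callable[[List[str], Set[int]], Set[int]]

STOPWORDS = {"a", "an", "the", "of", "and"}


def no_repeat(words: List[str], modified: Set[int]) -> Set[int]:
    # Loosely modeled on RepeatModification: never touch a word twice.
    return set(range(len(words))) - modified


def no_stopwords(words: List[str], modified: Set[int]) -> Set[int]:
    # Loosely modeled on StopwordModification: leave stopwords alone.
    return {i for i, w in enumerate(words) if w.lower() not in STOPWORDS}


def get_indices_to_order(words: List[str], modified: Set[int],
                         constraints: List[PreConstraint]) -> Tuple[int, List[int]]:
    allowed = set(range(len(words)))
    for constraint in constraints:
        allowed &= constraint(words, modified)  # an index must pass every constraint
    indices = sorted(allowed)
    return len(indices), indices
```

For "the plot of the film was dull" with index 4 ("film") already modified, only indices 1, 5, and 6 survive both constraints.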

get_transformations(current_text, original_text=None, **kwargs)

Applies self.transformation to text, then filters the list of possible transformations through the applicable constraints.

Parameters:
  • current_text – The current AttackedText on which to perform the transformations.

  • original_text – The original AttackedText from which the attack started.

Returns:

A filtered list of transformations, each of which satisfies the constraints.

AttackArgs Class

class textattack.attack_args.AttackArgs(num_examples: int = 10, num_successful_examples: int | None = None, num_examples_offset: int = 0, attack_n: bool = False, shuffle: bool = False, query_budget: int | None = None, checkpoint_interval: int | None = None, checkpoint_dir: str = 'checkpoints', random_seed: int = 765, parallel: bool = False, num_workers_per_device: int = 1, log_to_txt: str | None = None, log_to_csv: str | None = None, log_summary_to_json: str | None = None, csv_coloring_style: str = 'file', log_to_visdom: dict | None = None, log_to_wandb: dict | None = None, disable_stdout: bool = False, silent: bool = False, enable_advance_metrics: bool = False, metrics: Dict | None = None)

Bases: object

Attack arguments to be passed to Attacker.

Parameters:
  • num_examples (int, optional, defaults to 10) – The number of examples to attack. Use -1 to attack the entire dataset.

  • num_successful_examples (int, optional, defaults to None) –

    The number of successful adversarial examples we want. This is different from num_examples: num_examples only cares about attacking N samples, while num_successful_examples keeps attacking until there are N successful cases.

    Note: If set, this argument overrides the num_examples argument.

  • num_examples_offset (int, optional, defaults to 0) – The offset index to start at in the dataset.

  • attack_n (bool, optional, defaults to False) – Whether to run attack until total of N examples have been attacked (and not skipped).

  • shuffle (bool, optional, defaults to False) – If True, randomly shuffle the dataset before attacking. Rather than shuffling the dataset internally, this shuffles the list of indices of the examples to attack, which means shuffle can be used together with checkpoint saving.

  • query_budget (int, optional, defaults to None) –

    The maximum number of model queries allowed per example attacked. If not set, the query budget set in the GoalFunction object is used (by default float("inf")).

    Note: Setting this overwrites the query budget set in the GoalFunction object.

  • checkpoint_interval (int, optional, defaults to None) – If set, a checkpoint is saved after every N attacked examples. If None, no checkpoints are saved.

  • checkpoint_dir (str, optional, defaults to "checkpoints") – The directory to save checkpoint files.

  • random_seed (int, optional, defaults to 765) – Random seed for reproducibility.

  • parallel (bool, optional, defaults to False) – If True, run the attack using multiple CPUs/GPUs.

  • num_workers_per_device (int, optional, defaults to 1) – Number of worker processes to run per device in parallel mode (i.e. parallel=True). For example, if you are using GPUs and num_workers_per_device=2, then 2 processes will be running in each GPU.

  • log_to_txt (str, optional, defaults to None) – If set, save attack logs as a .txt file to the directory specified by this argument. If the last part of the provided path ends with the .txt extension, it is assumed to be the desired path of the log file.

  • log_to_csv (str, optional, defaults to None) – If set, save attack logs as a CSV file to the directory specified by this argument. If the last part of the provided path ends with the .csv extension, it is assumed to be the desired path of the log file.

  • csv_coloring_style (str, optional, defaults to "file") – Method for choosing how to mark perturbed parts of the text. Options are "file", "plain", and "html". "file" wraps perturbed parts with double brackets [[ <text> ]] while "plain" does not mark the text in any way.

  • log_to_visdom (dict, optional, defaults to None) – If set, the Visdom logger is used, with the provided dictionary passed as keyword arguments to VisdomLogger. Pass an empty dictionary to use the default arguments. For a custom logger, the dictionary should have the following three keys and their corresponding values: "env", "port", "hostname".

  • log_to_wandb (dict, optional, defaults to None) – If set, the WandB logger is used, with the provided dictionary passed as keyword arguments to WeightsAndBiasesLogger. Pass an empty dictionary to use the default arguments. For a custom logger, the dictionary should have the following key and its corresponding value: "project".

  • disable_stdout (bool, optional, defaults to False) – Disable displaying individual attack results to stdout.

  • silent (bool, optional, defaults to False) – Disable all logging (except for errors). This is stronger than disable_stdout.

  • enable_advance_metrics (bool, optional, defaults to False) – Enable calculation and display of optional advanced post-hoc metrics like perplexity, grammar errors, etc.
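The interplay between num_examples, num_examples_offset, shuffle, and random_seed can be illustrated with a small helper. This is a hypothetical sketch of the behavior described above (shuffling a list of indices rather than the dataset itself), not TextAttack's internal code; in particular, applying the offset before shuffling is an assumption.

```python
import random
from typing import List


def select_indices(dataset_size: int, num_examples: int = 10,
                   num_examples_offset: int = 0, shuffle: bool = False,
                   random_seed: int = 765) -> List[int]:
    """Hypothetical helper: choose which dataset indices to attack."""
    indices = list(range(num_examples_offset, dataset_size))
    if shuffle:
        # Shuffle the index list, never the dataset, so the order can be
        # reproduced from the seed.
        random.Random(random_seed).shuffle(indices)
    if num_examples == -1:  # -1 means the entire (remaining) dataset
        return indices
    return indices[:num_examples]
```

Because the shuffled order in this sketch is a pure function of random_seed, a checkpoint only has to record how far into the index list the attack has progressed.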

classmethod create_loggers_from_args(args)

Creates AttackLogManager from an AttackArgs object.
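The log_to_txt and log_to_csv descriptions distinguish between a directory and an explicit file path. A minimal sketch of that rule follows; the default file name used when a directory is given ("attack_log") is a hypothetical placeholder, since the name TextAttack actually generates is not documented here.

```python
import os


def resolve_log_path(path: str, extension: str, default_name: str = "attack_log") -> str:
    """Return the log file path for a directory-or-file argument (sketch)."""
    if path.endswith(extension):
        return path  # the argument already names the log file
    # Otherwise treat the argument as a directory; `default_name` is hypothetical.
    return os.path.join(path, default_name + extension)
```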

attack_n: bool = False
checkpoint_dir: str = 'checkpoints'
checkpoint_interval: int = None
csv_coloring_style: str = 'file'
disable_stdout: bool = False
enable_advance_metrics: bool = False
log_summary_to_json: str = None
log_to_csv: str = None
log_to_txt: str = None
log_to_visdom: dict = None
log_to_wandb: dict = None
metrics: Dict | None = None
num_examples: int = 10
num_examples_offset: int = 0
num_successful_examples: int = None
num_workers_per_device: int = 1
parallel: bool = False
query_budget: int = None
random_seed: int = 765
shuffle: bool = False
silent: bool = False
class textattack.attack_args.CommandLineAttackArgs(model: str = None, model_from_file: str = None, model_from_huggingface: str = None, dataset_by_model: str = None, dataset_from_huggingface: str = None, dataset_from_file: str = None, dataset_split: str = None, filter_by_labels: list = None, transformation: str = 'word-swap-embedding', constraints: list = <factory>, goal_function: str = 'untargeted-classification', search_method: str = 'greedy-word-wir', attack_recipe: str = None, attack_from_file: str = None, interactive: bool = False, parallel: bool = False, model_batch_size: int = 32, model_cache_size: int = 262144, constraint_cache_size: int = 262144, num_examples: int = 10, num_successful_examples: int = None, num_examples_offset: int = 0, attack_n: bool = False, shuffle: bool = False, query_budget: int = None, checkpoint_interval: int = None, checkpoint_dir: str = 'checkpoints', random_seed: int = 765, num_workers_per_device: int = 1, log_to_txt: str = None, log_to_csv: str = None, log_summary_to_json: str = None, csv_coloring_style: str = 'file', log_to_visdom: dict = None, log_to_wandb: dict = None, disable_stdout: bool = False, silent: bool = False, enable_advance_metrics: bool = False, metrics: Union[Dict, NoneType] = None)

Bases: AttackArgs, _CommandLineAttackArgs, DatasetArgs, ModelArgs

Attacker Class

class textattack.attacker.Attacker(attack, dataset, attack_args=None)

Bases: object

Class for running attacks on a dataset with specified parameters. This class uses the Attack to actually run the attacks, while also providing useful features such as parallel processing, saving/resuming from a checkpoint, and logging to files and stdout.

Parameters:
  • attack (Attack) – Attack used to actually carry out the attack.

  • dataset (Dataset) – Dataset to attack.

  • attack_args (AttackArgs) – Arguments for attacking the dataset. For default settings, look at the AttackArgs class.

Example:

>>> import textattack
>>> import transformers

>>> model = transformers.AutoModelForSequenceClassification.from_pretrained("textattack/bert-base-uncased-imdb")
>>> tokenizer = transformers.AutoTokenizer.from_pretrained("textattack/bert-base-uncased-imdb")
>>> model_wrapper = textattack.models.wrappers.HuggingFaceModelWrapper(model, tokenizer)

>>> attack = textattack.attack_recipes.TextFoolerJin2019.build(model_wrapper)
>>> dataset = textattack.datasets.HuggingFaceDataset("imdb", split="test")

>>> # Attack 20 samples with CSV logging, saving a checkpoint every 5 examples
>>> attack_args = textattack.AttackArgs(
...     num_examples=20,
...     log_to_csv="log.csv",
...     checkpoint_interval=5,
...     checkpoint_dir="checkpoints",
...     disable_stdout=True
... )

>>> attacker = textattack.Attacker(attack, dataset, attack_args)
>>> attacker.attack_dataset()
attack_dataset()

Attack the dataset.

Returns:

list[AttackResult] - List of AttackResult obtained after attacking the given dataset.

static attack_interactive(attack)
classmethod from_checkpoint(attack, dataset, checkpoint)

Resume attacking from a saved checkpoint. The attack and dataset must be recreated by the user, while the attack args are loaded from the saved checkpoint.

Parameters:
  • attack (Attack) – Attack object for carrying out the attack.

  • dataset (Dataset) – Dataset to attack.

  • checkpoint (Union[str, textattack.shared.AttackCheckpoint]) – Path of the saved checkpoint or the actual saved checkpoint.
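The division of responsibilities in from_checkpoint (the user rebuilds the attack and dataset, the checkpoint restores the arguments and progress) can be sketched with a JSON file. The Checkpoint class, its fields, and the save_checkpoint/resume helpers are hypothetical and far simpler than textattack.shared.AttackCheckpoint; they only illustrate why the attack itself does not need to be serialized.

```python
import json
from dataclasses import asdict, dataclass, field
from typing import Callable, Dict, List, Union


@dataclass
class Checkpoint:  # hypothetical; much simpler than textattack.shared.AttackCheckpoint
    attack_args: Dict
    remaining_indices: List[int]
    results: List[str] = field(default_factory=list)


def save_checkpoint(ckpt: Checkpoint, path: str) -> None:
    with open(path, "w") as f:
        json.dump(asdict(ckpt), f)


def resume(attack_fn: Callable[[str], str], dataset: List[str],
           checkpoint: Union[str, Checkpoint]) -> Checkpoint:
    if isinstance(checkpoint, str):  # a path was given: load it from disk
        with open(checkpoint) as f:
            checkpoint = Checkpoint(**json.load(f))
    # The attack and dataset come from the caller; only args and progress are restored.
    while checkpoint.remaining_indices:
        i = checkpoint.remaining_indices.pop(0)
        checkpoint.results.append(attack_fn(dataset[i]))
    return checkpoint
```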

update_attack_args(**kwargs)

To update any attack args, pass the new values as keyword arguments to this function.

Examples:

>>> attacker = textattack.Attacker(attack, dataset, attack_args)  # some instance of Attacker
>>> # To switch to parallel mode and increase checkpoint interval from 100 to 500
>>> attacker.update_attack_args(parallel=True, checkpoint_interval=500)
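A plausible way to apply keyword-based updates to a dataclass of arguments is sketched below. It is an illustration only: SimpleAttackArgs is a hypothetical two-field stand-in for AttackArgs, and raising ValueError for unknown names is a design choice of this sketch rather than documented TextAttack behavior.

```python
from dataclasses import dataclass, fields


@dataclass
class SimpleAttackArgs:  # hypothetical stand-in for AttackArgs
    parallel: bool = False
    checkpoint_interval: int = 100


def update_attack_args(args: SimpleAttackArgs, **kwargs) -> None:
    valid = {f.name for f in fields(args)}
    for name, value in kwargs.items():
        if name not in valid:  # reject typos instead of silently adding attributes
            raise ValueError(f"{name} is not a valid attack argument")
        setattr(args, name, value)
```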
textattack.attacker.attack_from_queue(attack, attack_args, num_gpus, first_to_start, lock, in_queue, out_queue)
textattack.attacker.pytorch_multiprocessing_workaround()
textattack.attacker.set_env_variables(gpu_id)

AugmenterArgs Class

class textattack.augment_args.AugmenterArgs(input_csv: str, output_csv: str, input_column: str, recipe: str = 'embedding', pct_words_to_swap: float = 0.1, transformations_per_example: int = 2, random_seed: int = 42, exclude_original: bool = False, overwrite: bool = False, interactive: bool = False, fast_augment: bool = False, high_yield: bool = False, enable_advanced_metrics: bool = False)

Bases: object

Arguments for performing data augmentation.

Parameters:
  • input_csv (str) – Path of input CSV file to augment.

  • output_csv (str) – Path of CSV file to output augmented data.

enable_advanced_metrics: bool = False
exclude_original: bool = False
fast_augment: bool = False
high_yield: bool = False
input_column: str
input_csv: str
interactive: bool = False
output_csv: str
overwrite: bool = False
pct_words_to_swap: float = 0.1
random_seed: int = 42
recipe: str = 'embedding'
transformations_per_example: int = 2
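To show how pct_words_to_swap, transformations_per_example, random_seed, and exclude_original relate, here is a toy random-swap augmenter. The synonym table, the rounding rule (at least one word, otherwise the floor of the percentage times the word count), and the choice to append the original text last are assumptions of this sketch, not a statement of how TextAttack's augmentation recipes choose or count words.

```python
import random
from typing import Dict, List

SYNONYMS: Dict[str, List[str]] = {  # hypothetical toy synonym table
    "good": ["fine", "great"], "movie": ["film", "picture"], "plot": ["story"],
}


def augment(text: str, pct_words_to_swap: float = 0.1,
            transformations_per_example: int = 2, random_seed: int = 42,
            exclude_original: bool = False) -> List[str]:
    rng = random.Random(random_seed)
    words = text.split()
    swappable = [i for i, w in enumerate(words) if w in SYNONYMS]
    # Swap at least one word, but never more than the number of swappable words.
    num_to_swap = min(len(swappable), max(1, int(pct_words_to_swap * len(words))))
    outputs = []
    for _ in range(transformations_per_example):
        new_words = list(words)
        for i in rng.sample(swappable, num_to_swap):
            new_words[i] = rng.choice(SYNONYMS[words[i]])
        outputs.append(" ".join(new_words))
    return outputs if exclude_original else outputs + [text]
```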

DatasetArgs Class

class textattack.dataset_args.DatasetArgs(dataset_by_model: str | None = None, dataset_from_huggingface: str | None = None, dataset_from_file: str | None = None, dataset_split: str | None = None, filter_by_labels: list | None = None)

Bases: object

Arguments for loading a dataset from command-line input.

dataset_by_model: str = None
dataset_from_file: str = None
dataset_from_huggingface: str = None
dataset_split: str = None
filter_by_labels: list = None

ModelArgs Class

class textattack.model_args.ModelArgs(model: str | None = None, model_from_file: str | None = None, model_from_huggingface: str | None = None)

Bases: object

Arguments for loading base/pretrained or trained models.

model: str = None
model_from_file: str = None
model_from_huggingface: str = None

Trainer Class

class textattack.trainer.Trainer(model_wrapper, task_type='classification', attack=None, train_dataset=None, eval_dataset=None, training_args=None)

Bases: object

Trainer is a training and evaluation loop for adversarial training.

It is designed to work with PyTorch and Transformers models.

Parameters:
  • model_wrapper (ModelWrapper) – Model wrapper containing both the model and the tokenizer.

  • task_type (str, optional, defaults to "classification") – The task that the model is trained to perform. Currently, Trainer supports two tasks: (1) "classification", (2) "regression".

  • attack (Attack, optional) – Attack used to generate adversarial examples for training.

  • train_dataset (Dataset, optional) – Dataset for training.

  • eval_dataset (Dataset, optional) – Dataset for evaluation.

  • training_args (TrainingArgs, optional) – Arguments for training.

Example:

>>> import textattack
>>> import transformers

>>> model = transformers.AutoModelForSequenceClassification.from_pretrained("bert-base-uncased")
>>> tokenizer = transformers.AutoTokenizer.from_pretrained("bert-base-uncased")
>>> model_wrapper = textattack.models.wrappers.HuggingFaceModelWrapper(model, tokenizer)

>>> # We use DeepWordBugGao2018 only for demonstration purposes.
>>> attack = textattack.attack_recipes.DeepWordBugGao2018.build(model_wrapper)
>>> train_dataset = textattack.datasets.HuggingFaceDataset("imdb", split="train")
>>> eval_dataset = textattack.datasets.HuggingFaceDataset("imdb", split="test")

>>> # Train for 3 epochs with 1 initial clean epoch, 1000 adversarial examples per epoch, learning rate of 5e-5, and effective batch size of 32 (8x4).
>>> training_args = textattack.TrainingArgs(
...     num_epochs=3,
...     num_clean_epochs=1,
...     num_train_adv_examples=1000,
...     learning_rate=5e-5,
...     per_device_train_batch_size=8,
...     gradient_accumulation_steps=4,
...     log_to_tb=True,
... )

>>> trainer = textattack.Trainer(
...     model_wrapper,
...     "classification",
...     attack,
...     train_dataset,
...     eval_dataset,
...     training_args
... )
>>> trainer.train()

Note

When using Trainer with parallel=True in TrainingArgs, make sure to protect the “entry point” of the program by using if __name__ == '__main__':. If not, each worker process used for generating adversarial examples will execute the training code again.

evaluate()

Evaluate the model on the given evaluation dataset.

evaluate_step(model, tokenizer, batch)

Perform a single evaluation step on a batch of inputs.

Parameters:
  • model (torch.nn.Module) – Model to evaluate.

  • tokenizer – Tokenizer used to tokenize input text.

  • batch (tuple[list[str], torch.Tensor]) –

    By default, this will be a tuple of input texts and target tensors.

    Note

    If you override the get_eval_dataloader() method, then shape/type of batch will depend on how you created your batch.

Returns:

tuple[torch.Tensor, torch.Tensor] where

  • preds: torch.FloatTensor of model’s prediction for the batch.

  • targets: torch.Tensor of model’s targets (e.g. labels, target values).
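As an illustration of what can be done with the (preds, targets) pairs returned by evaluate_step, the sketch below computes classification accuracy across batches. Plain Python lists stand in for torch.Tensor so it runs without PyTorch, and treating preds as per-class logits to be arg-maxed is an assumption for the "classification" task type.

```python
from typing import List, Sequence, Tuple

Batch = Tuple[List[List[float]], List[int]]  # (logits per example, gold labels)


def argmax(row: Sequence[float]) -> int:
    return max(range(len(row)), key=lambda i: row[i])


def accuracy(batches: List[Batch]) -> float:
    correct = total = 0
    for preds, targets in batches:
        for logits, gold in zip(preds, targets):
            correct += int(argmax(logits) == gold)
            total += 1
    return correct / total if total else 0.0
```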

get_eval_dataloader(dataset, batch_size)

Returns the torch.utils.data.DataLoader for evaluation.

Parameters:
  • dataset (Dataset) – Dataset to use for evaluation.

  • batch_size (int) – Batch size for evaluation.

Returns:

torch.utils.data.DataLoader

get_optimizer_and_scheduler(model, num_training_steps)

Returns optimizer and scheduler to use for training. If you are overriding this method and do not want to use a scheduler, simply return None for scheduler.

Parameters:
  • model (torch.nn.Module) – Model to be trained. Pass its parameters to optimizer for training.

  • num_training_steps (int) – Number of total training steps.

Returns:

Tuple of optimizer and scheduler: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler._LRScheduler]
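The num_warmup_steps argument of TrainingArgs feeds the linear scheduler returned here. The multiplier below follows the common linear-warmup, linear-decay shape (as in Hugging Face's get_linear_schedule_with_warmup); that TextAttack uses exactly this schedule is an assumption. The fractional case uses math.ceil, as documented for num_warmup_steps.

```python
import math
from typing import Union


def resolve_warmup_steps(num_warmup_steps: Union[int, float], num_training_steps: int) -> int:
    # A float strictly between 0 and 1 is a fraction of total training steps, rounded up.
    if isinstance(num_warmup_steps, float) and 0 < num_warmup_steps < 1:
        return math.ceil(num_training_steps * num_warmup_steps)
    return int(num_warmup_steps)


def linear_lr_multiplier(step: int, warmup_steps: int, num_training_steps: int) -> float:
    if step < warmup_steps:
        return step / max(1, warmup_steps)  # ramp up linearly from 0 to 1
    remaining = num_training_steps - step
    return max(0.0, remaining / max(1, num_training_steps - warmup_steps))  # decay to 0
```

With PyTorch, a function like linear_lr_multiplier could be passed to torch.optim.lr_scheduler.LambdaLR as its lr_lambda (for example via functools.partial), which scales the optimizer's base learning rate at every step.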

get_train_dataloader(dataset, adv_dataset, batch_size)

Returns the torch.utils.data.DataLoader for training.

Parameters:
  • dataset (Dataset) – Original training dataset.

  • adv_dataset (Dataset) – Adversarial examples generated from the original training dataset. None if no adversarial attack takes place.

  • batch_size (int) – Batch size for training.

Returns:

torch.utils.data.DataLoader
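The (texts, targets, is_adv) batches that training_step receives by default can be pictured as coming from a loader like the one below, which concatenates the original and adversarial examples, tags each with a flag, shuffles, and batches them. This framework-free sketch assumes both datasets are lists of (text, label) pairs and borrows 786 (the TrainingArgs random_seed default) as its seed; TextAttack's actual DataLoader construction may differ.

```python
import random
from typing import Iterator, List, Optional, Tuple

Example = Tuple[str, int]


def train_batches(dataset: List[Example], adv_dataset: Optional[List[Example]],
                  batch_size: int, seed: int = 786
                  ) -> Iterator[Tuple[List[str], List[int], List[bool]]]:
    tagged = [(t, y, False) for t, y in dataset]
    if adv_dataset is not None:  # None when no adversarial attack has taken place
        tagged += [(t, y, True) for t, y in adv_dataset]
    random.Random(seed).shuffle(tagged)
    for start in range(0, len(tagged), batch_size):
        texts, targets, is_adv = zip(*tagged[start:start + batch_size])
        yield list(texts), list(targets), list(is_adv)
```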

train()

Train the model on the given training dataset.

training_step(model, tokenizer, batch)

Perform a single training step on a batch of inputs.

Parameters:
  • model (torch.nn.Module) – Model to train.

  • tokenizer – Tokenizer used to tokenize input text.

  • batch (tuple[list[str], torch.Tensor, torch.Tensor]) –

    By default, this will be a tuple of input texts, targets, and a boolean tensor indicating whether each sample is an adversarial example.

    Note

    If you override the get_train_dataloader() method, then shape/type of batch will depend on how you created your batch.

Returns:

tuple[torch.Tensor, torch.Tensor, torch.Tensor] where

  • loss: torch.FloatTensor of shape 1 containing the loss.

  • preds: torch.FloatTensor of model’s prediction for the batch.

  • targets: torch.Tensor of model’s targets (e.g. labels, target values).
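TrainingArgs.alpha is described as the weight for the adversarial loss. One way to apply it inside a training step is sketched below: compute a per-example cross-entropy, multiply the losses of adversarial examples by alpha, and average. Pure Python stands in for torch here, and this exact weighting scheme is an assumption rather than a transcription of TextAttack's training_step.

```python
import math
from typing import List


def cross_entropy(logits: List[float], target: int) -> float:
    # Numerically stable: logsumexp(logits) - logits[target].
    m = max(logits)
    log_sum = m + math.log(sum(math.exp(x - m) for x in logits))
    return log_sum - logits[target]


def weighted_loss(batch_logits: List[List[float]], targets: List[int],
                  is_adv: List[bool], alpha: float = 1.0) -> float:
    losses = [cross_entropy(l, t) for l, t in zip(batch_logits, targets)]
    weights = [alpha if adv else 1.0 for adv in is_adv]  # up/down-weight adversarial samples
    return sum(w * l for w, l in zip(weights, losses)) / len(losses)
```

With alpha=1.0 this reduces to the plain mean loss over clean and adversarial examples alike.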

TrainingArgs Class

class textattack.training_args.CommandLineTrainingArgs(model_name_or_path: str, attack: str, dataset: str, task_type: str = 'classification', model_max_length: int = None, model_num_labels: int = None, dataset_train_split: str = None, dataset_eval_split: str = None, filter_train_by_labels: list = None, filter_eval_by_labels: list = None, num_epochs: int = 3, num_clean_epochs: int = 1, attack_epoch_interval: int = 1, early_stopping_epochs: int = None, learning_rate: float = 5e-05, num_warmup_steps: Union[int, float] = 500, weight_decay: float = 0.01, per_device_train_batch_size: int = 8, per_device_eval_batch_size: int = 32, gradient_accumulation_steps: int = 1, random_seed: int = 786, parallel: bool = False, load_best_model_at_end: bool = False, alpha: float = 1.0, num_train_adv_examples: Union[int, float] = -1, query_budget_train: int = None, attack_num_workers_per_device: int = 1, output_dir: str = <factory>, checkpoint_interval_steps: int = None, checkpoint_interval_epochs: int = None, save_last: bool = True, log_to_tb: bool = False, tb_log_dir: str = None, log_to_wandb: bool = False, wandb_project: str = 'textattack', logging_interval_step: int = 1)

Bases: TrainingArgs, _CommandLineTrainingArgs

output_dir: str
class textattack.training_args.TrainingArgs(num_epochs: int = 3, num_clean_epochs: int = 1, attack_epoch_interval: int = 1, early_stopping_epochs: int | None = None, learning_rate: float = 5e-05, num_warmup_steps: int | float = 500, weight_decay: float = 0.01, per_device_train_batch_size: int = 8, per_device_eval_batch_size: int = 32, gradient_accumulation_steps: int = 1, random_seed: int = 786, parallel: bool = False, load_best_model_at_end: bool = False, alpha: float = 1.0, num_train_adv_examples: int | float = -1, query_budget_train: int | None = None, attack_num_workers_per_device: int = 1, output_dir: str = <factory>, checkpoint_interval_steps: int | None = None, checkpoint_interval_epochs: int | None = None, save_last: bool = True, log_to_tb: bool = False, tb_log_dir: str | None = None, log_to_wandb: bool = False, wandb_project: str = 'textattack', logging_interval_step: int = 1)

Bases: object

Arguments for Trainer class that is used for adversarial training.

Parameters:
  • num_epochs (int, optional, defaults to 3) – Total number of epochs for training.

  • num_clean_epochs (int, optional, defaults to 1) – Number of epochs to train on just the original training dataset before adversarial training.

  • attack_epoch_interval (int, optional, defaults to 1) – Generate a new adversarial training set every N epochs.

  • early_stopping_epochs (int, optional, defaults to None) – Number of epochs without improvement in the validation score before stopping early (None for no early stopping).

  • learning_rate (float, optional, defaults to 5e-5) – Learning rate for optimizer.

  • num_warmup_steps (int or float, optional, defaults to 500) – The number of steps for the warmup phase of linear scheduler. If num_warmup_steps is a float between 0 and 1, the number of warmup steps will be math.ceil(num_training_steps * num_warmup_steps).

  • weight_decay (float, optional, defaults to 0.01) – Weight decay (L2 penalty).

  • per_device_train_batch_size (int, optional, defaults to 8) – The batch size per GPU/CPU for training.

  • per_device_eval_batch_size (int, optional, defaults to 32) – The batch size per GPU/CPU for evaluation.

  • gradient_accumulation_steps (int, optional, defaults to 1) – Number of update steps over which to accumulate gradients before performing an optimizer update.

  • random_seed (int, optional, defaults to 786) – Random seed for reproducibility.

  • parallel (bool, optional, defaults to False) – If True, train on multiple GPUs using torch.nn.DataParallel.

  • load_best_model_at_end (bool, optional, defaults to False) – If True, keep track of the best model across training and load it at the end.

  • alpha (float, optional, defaults to 1.0) – The weight for adversarial loss.

  • num_train_adv_examples (int or float, optional, defaults to -1) – The number of samples to successfully attack when generating the adversarial training set before the start of every epoch. If num_train_adv_examples is a float between 0 and 1, the number of adversarial examples generated is that fraction of the original training set.

  • query_budget_train (int, optional, defaults to None) – The max query budget to use when generating adversarial training set. None means infinite query budget.

  • attack_num_workers_per_device (int, optional, defaults to 1) – Number of worker processes to run per device for the attack. Same as the num_workers_per_device argument of AttackArgs.

  • output_dir (str, optional) – Directory to output training logs and checkpoints. Defaults to /outputs/%Y-%m-%d-%H-%M-%S-%f format.

  • checkpoint_interval_steps (int, optional, defaults to None) – If set, save model checkpoint after every N updates to the model.

  • checkpoint_interval_epochs (int, optional, defaults to None) – If set, save model checkpoint after every N epochs.

  • save_last (bool, optional, defaults to True) – If True, save the model at end of training. Can be used with load_best_model_at_end to save the best model at the end.

  • log_to_tb (bool, optional, defaults to False) – If True, log to Tensorboard.

  • tb_log_dir (str, optional, defaults to "./runs") – Path of Tensorboard log directory.

  • log_to_wandb (bool, optional, defaults to False) – If True, log to Wandb.

  • wandb_project (str, optional, defaults to "textattack") – Name of Wandb project for logging.

  • logging_interval_step (int, optional, defaults to 1) – Log to Tensorboard/Wandb every N training steps.
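Two of the arguments above lend themselves to a quick worked check. The effective batch size per optimizer update is per_device_train_batch_size × gradient_accumulation_steps × number of devices (the Trainer example uses 8 × 4 = 32 on one device), and early_stopping_epochs can be read as a patience counter. Both helpers below are hypothetical; the early-stopping one assumes a higher validation score is better and stops once that many epochs pass without a new best.

```python
from typing import List, Optional


def effective_batch_size(per_device_train_batch_size: int,
                         gradient_accumulation_steps: int, num_devices: int = 1) -> int:
    return per_device_train_batch_size * gradient_accumulation_steps * num_devices


def early_stop_epoch(eval_scores: List[float],
                     early_stopping_epochs: Optional[int]) -> Optional[int]:
    """Return the epoch index at which training would stop, or None."""
    if early_stopping_epochs is None:
        return None  # early stopping disabled
    best, best_epoch = float("-inf"), 0
    for epoch, score in enumerate(eval_scores):
        if score > best:
            best, best_epoch = score, epoch
        elif epoch - best_epoch >= early_stopping_epochs:
            return epoch
    return None
```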

alpha: float = 1.0
attack_epoch_interval: int = 1
attack_num_workers_per_device: int = 1
checkpoint_interval_epochs: int = None
checkpoint_interval_steps: int = None
early_stopping_epochs: int = None
gradient_accumulation_steps: int = 1
learning_rate: float = 5e-05
load_best_model_at_end: bool = False
log_to_tb: bool = False
log_to_wandb: bool = False
logging_interval_step: int = 1
num_clean_epochs: int = 1
num_epochs: int = 3
num_train_adv_examples: int | float = -1
num_warmup_steps: int | float = 500
output_dir: str
parallel: bool = False
per_device_eval_batch_size: int = 32
per_device_train_batch_size: int = 8
query_budget_train: int = None
random_seed: int = 786
save_last: bool = True
tb_log_dir: str = None
wandb_project: str = 'textattack'
weight_decay: float = 0.01
textattack.training_args.default_output_dir()