textattack package
Welcome to the API references for TextAttack!
What is TextAttack?
TextAttack is a Python framework for adversarial attacks, adversarial training, and data augmentation in NLP.
TextAttack makes experimenting with the robustness of NLP models seamless, fast, and easy. It’s also useful for NLP model training, adversarial training, and data augmentation.
TextAttack provides components for common NLP tasks like sentence encoding, grammar-checking, and word replacement that can be used on their own.
- textattack.attack_recipes package
- Attack Recipes Package:
- A2T (A2T: Attack for Adversarial Training Recipe)
- Attack Recipe Class
- BAE (BAE: BERT-Based Adversarial Examples)
- BERT-Attack:
- CheckList:
- CLARE Recipe
- DeepWordBug
- Faster Alzantot Genetic Algorithm
- Alzantot Genetic Algorithm
- HotFlip
- Improved Genetic Algorithm
- Input Reduction
- Kuleshov2017
- MORPHEUS2020
- Pruthi2019: Combating with Robust Word Recognition
- Particle Swarm Optimization
- PWWS
- Seq2Sick
- TextBugger
- TextFooler (Is BERT Really Robust?)
- textattack.attack_results package
- textattack.augmentation package
- textattack.commands package
- textattack.constraints package
- Constraints
- textattack.constraints.grammaticality package
- Grammaticality:
- textattack.constraints.grammaticality.language_models package
- CoLA for Grammaticality
- LanguageTool Grammar Checker
- Part of Speech Constraint
- textattack.constraints.overlap package
- textattack.constraints.pre_transformation package
- textattack.constraints.semantics package
- Semantic Constraints
- textattack.constraints.semantics.sentence_encoders package
- BERT Score
- Word Embedding Distance
- TextAttack Constraint Class
- Pre-Transformation Constraint Class
- textattack.datasets package
- textattack.goal_function_results package
- textattack.goal_functions package
- textattack.loggers package
- textattack.metrics package
- textattack.models package
- textattack.search_methods package
- Search Methods
- Reimplementation of search method from Generating Natural Language Adversarial Examples
- Beam Search
- Genetic Algorithm Word Swap
- Greedy Search
- Greedy Word Swap with Word Importance Ranking
- Reimplementation of search method from Xiaosen Wang, Hao Jin, Kun He (2019).
- Particle Swarm Optimization
- Population based Search
- Search Method Abstract Class
- textattack.shared package
- textattack.transformations package
- Transformations
- textattack.transformations.sentence_transformations package
- textattack.transformations.word_insertions package
- textattack.transformations.word_merges package
- textattack.transformations.word_swaps package
- word_swaps package
- Word Swap
- Word Swap by Changing Location
- Word Swap by Changing Name
- Word Swap by Changing Number
- Word Swap by Contraction
- Word Swap by Embedding
- Word Swap by Extension
- Word Swap by Gradient
- Word Swap by Homoglyph
- Word Swap by OpenHowNet
- Word Swap by inflections
- Word Swap by BERT-Masked LM.
- Word Swap by Neighboring Character Swap
- Word Swap by swaps characters with QWERTY adjacent keys
- Word Swap by Random Character Deletion
- Word Swap by Random Character Insertion
- Word Swap by Random Character Substitution
- Word Swap by swapping synonyms in WordNet
- Composite Transformation
- Transformation Abstract Class
- word deletion Transformation
- Word Swap Transformation by swapping the order of words
Attack Class
- class textattack.attack.Attack(goal_function: textattack.goal_functions.goal_function.GoalFunction, constraints: List[Union[textattack.constraints.constraint.Constraint, textattack.constraints.pre_transformation_constraint.PreTransformationConstraint]], transformation: textattack.transformations.transformation.Transformation, search_method: textattack.search_methods.search_method.SearchMethod, transformation_cache_size=32768, constraint_cache_size=32768)[source]
Bases:
object
An attack generates adversarial examples on text.
An attack is comprised of a goal function, constraints, transformation, and a search method. Use the attack() method to attack one sample at a time.
- Parameters
goal_function (GoalFunction) – A function for determining how well a perturbation is doing at achieving the attack’s goal.
constraints (list of Constraint or PreTransformationConstraint) – A list of constraints to add to the attack, defining which perturbations are valid.
transformation (Transformation) – The transformation applied at each step of the attack.
search_method (SearchMethod) – The method for exploring the search space of possible perturbations.
transformation_cache_size (int, optional, defaults to 2**15) – The number of items to keep in the transformations cache.
constraint_cache_size (int, optional, defaults to 2**15) – The number of items to keep in the constraints cache.
Example:
>>> import textattack
>>> import transformers
>>> # Load model, tokenizer, and model_wrapper
>>> model = transformers.AutoModelForSequenceClassification.from_pretrained("textattack/bert-base-uncased-imdb")
>>> tokenizer = transformers.AutoTokenizer.from_pretrained("textattack/bert-base-uncased-imdb")
>>> model_wrapper = textattack.models.wrappers.HuggingFaceModelWrapper(model, tokenizer)
>>> # Construct our four components for `Attack`
>>> from textattack.constraints.pre_transformation import RepeatModification, StopwordModification
>>> from textattack.constraints.semantics import WordEmbeddingDistance
>>> from textattack.transformations import WordSwapEmbedding
>>> from textattack.search_methods import GreedyWordSwapWIR
>>> goal_function = textattack.goal_functions.UntargetedClassification(model_wrapper)
>>> constraints = [
...     RepeatModification(),
...     StopwordModification(),
...     WordEmbeddingDistance(min_cos_sim=0.9),
... ]
>>> transformation = WordSwapEmbedding(max_candidates=50)
>>> search_method = GreedyWordSwapWIR(wir_method="delete")
>>> # Construct the actual attack
>>> attack = textattack.Attack(goal_function, constraints, transformation, search_method)
>>> input_text = "I really enjoyed the new movie that came out last month."
>>> label = 1  # Positive
>>> attack_result = attack.attack(input_text, label)
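To make the division of labor between the four components concrete, here is a minimal, self-contained sketch in plain Python. It does not use TextAttack. The toy classifier, the synonym table, and every function name (toy_model, goal_reached, swap_words, max_changes_constraint, run_attack) are hypothetical stand-ins chosen for illustration, and the simple breadth-first search is not one of the library's search methods.

```python
from typing import Callable, List, Optional

# Hypothetical synonym table standing in for a real transformation.
SYNONYMS = {"enjoyed": ["liked", "endured"], "great": ["fine", "okay"]}
POSITIVE_WORDS = {"enjoyed", "liked", "great", "fine"}


def toy_model(text: str) -> int:
    """Toy classifier: 1 (positive) if any positive word appears, else 0."""
    return int(any(word in POSITIVE_WORDS for word in text.split()))


def goal_reached(text: str, ground_truth: int) -> bool:
    """Goal function: an untargeted goal, met when the prediction flips."""
    return toy_model(text) != ground_truth


def swap_words(text: str) -> List[str]:
    """Transformation: replace one word at a time with each listed synonym."""
    words = text.split()
    candidates = []
    for i, word in enumerate(words):
        for synonym in SYNONYMS.get(word, []):
            candidates.append(" ".join(words[:i] + [synonym] + words[i + 1:]))
    return candidates


def max_changes_constraint(candidate: str, original: str, limit: int = 2) -> bool:
    """Constraint: allow at most `limit` changed word positions."""
    changed = sum(a != b for a, b in zip(candidate.split(), original.split()))
    return changed <= limit


def run_attack(
    text: str,
    ground_truth: int,
    constraints: List[Callable[[str, str], bool]],
    max_steps: int = 3,
) -> Optional[str]:
    """Search method: breadth-first search over constraint-satisfying candidates."""
    frontier = [text]
    for _ in range(max_steps):
        next_frontier = []
        for current in frontier:
            for candidate in swap_words(current):
                # Discard candidates that violate any constraint.
                if not all(check(candidate, text) for check in constraints):
                    continue
                if goal_reached(candidate, ground_truth):
                    return candidate
                next_frontier.append(candidate)
        frontier = next_frontier
    return None  # no adversarial example found within the step budget


result = run_attack("I enjoyed the great movie", 1, [max_changes_constraint])
print(result)  # I endured the okay movie
```

Tightening the constraint to a single changed word makes the same attack fail, which shows how constraints shrink the search space that the search method explores.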
- attack(example, ground_truth_output)[source]
Attack a single example.
- Parameters
example (str, OrderedDict[str, str] or AttackedText) – Example to attack. It can be a single string or an OrderedDict where keys represent the input fields (e.g. “premise”, “hypothesis”) and the values are the actual input text. Also accepts an AttackedText that wraps around the input.
ground_truth_output (int, float or str) – Ground truth output of example. For classification tasks, it should be an integer representing the ground truth label. For regression tasks (e.g. STS), it should be the target value. For seq2seq tasks (e.g. translation), it should be the target string.
- Returns
AttackResult that represents the result of the attack.
- filter_transformations(transformed_texts, current_text, original_text=None)[source]
Filters a list of potential transformed texts based on self.constraints. Uses an LRU cache to avoid recomputing common transformations.
- Parameters
transformed_texts – A list of candidate transformed AttackedText to filter.
current_text – The current AttackedText on which the transformation was applied.
original_text – The original AttackedText from which the attack started.
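The caching behavior mentioned for filter_transformations can be illustrated with a small least-recently-used cache. This is only a sketch of the general technique: this reference does not show TextAttack's internal cache, and LRUCache, expensive_check, and filter_with_cache are hypothetical names.

```python
from collections import OrderedDict


class LRUCache:
    """Minimal least-recently-used cache with a fixed capacity."""

    def __init__(self, capacity: int):
        self.capacity = capacity
        self._data = OrderedDict()

    def __contains__(self, key) -> bool:
        return key in self._data

    def get(self, key):
        self._data.move_to_end(key)  # mark as most recently used
        return self._data[key]

    def put(self, key, value) -> None:
        self._data[key] = value
        self._data.move_to_end(key)
        if len(self._data) > self.capacity:
            self._data.popitem(last=False)  # evict the least recently used entry


calls = []  # records which texts actually reached the expensive check


def expensive_check(text: str) -> bool:
    """Hypothetical stand-in for evaluating all constraints on one text."""
    calls.append(text)
    return len(text.split()) <= 4


def filter_with_cache(texts, cache: LRUCache):
    """Keep texts that pass the check, consulting the cache first."""
    kept = []
    for text in texts:
        if text in cache:
            ok = cache.get(text)
        else:
            ok = expensive_check(text)
            cache.put(text, ok)
        if ok:
            kept.append(text)
    return kept


cache = LRUCache(capacity=2)
first = filter_with_cache(["a b c", "a b c d e", "a b c"], cache)
print(first)       # ['a b c', 'a b c']
print(len(calls))  # 2
```

The repeated candidate is answered from the cache, so the expensive check runs only twice for three inputs. A bounded capacity keeps memory use predictable during long attacks.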
- get_transformations(current_text, original_text=None, **kwargs)[source]
Applies self.transformation to current_text, then filters the list of possible transformations through the applicable constraints.
- Parameters
current_text – The current AttackedText on which to perform the transformations.
original_text – The original AttackedText from which the attack started.
- Returns
A filtered list of transformations where each transformation matches the constraints.
AttackArgs Class
- class textattack.attack_args.AttackArgs(num_examples: int = 10, num_successful_examples: Optional[int] = None, num_examples_offset: int = 0, attack_n: bool = False, shuffle: bool = False, query_budget: Optional[int] = None, checkpoint_interval: Optional[int] = None, checkpoint_dir: str = 'checkpoints', random_seed: int = 765, parallel: bool = False, num_workers_per_device: int = 1, log_to_txt: Optional[str] = None, log_to_csv: Optional[str] = None, csv_coloring_style: str = 'file', log_to_visdom: Optional[dict] = None, log_to_wandb: Optional[dict] = None, disable_stdout: bool = False, silent: bool = False, enable_advance_metrics: bool = False)[source]
Bases:
object
Attack arguments to be passed to Attacker.
- Parameters
num_examples (int, optional, defaults to 10) – The number of examples to attack. -1 for entire dataset.
num_successful_examples (int, optional, defaults to None) – The number of successful adversarial examples we want. This is different from num_examples, as num_examples only cares about attacking N samples while num_successful_examples aims to keep attacking until we have N successful cases.
Note
If set, this argument overrides the num_examples argument.
num_examples_offset (int, optional, defaults to 0) – The offset index to start at in the dataset.
attack_n (bool, optional, defaults to False) – Whether to run the attack until a total of N examples have been attacked (and not skipped).
shuffle (bool, optional, defaults to False) – If True, we randomly shuffle the dataset before attacking. However, this avoids actually shuffling the dataset internally and opts for shuffling the list of indices of examples we want to attack. This means shuffle can now be used with checkpoint saving.
query_budget (int, optional, defaults to None) – The maximum number of model queries allowed per example attacked. If not set, we use the query budget set in the GoalFunction object (which by default is float("inf")).
Note
Setting this overwrites the query budget set in the GoalFunction object.
checkpoint_interval (int, optional, defaults to None) – If set, a checkpoint will be saved after attacking every N examples. If None is passed, no checkpoints will be saved.
checkpoint_dir (str, optional, defaults to "checkpoints") – The directory to save checkpoint files.
random_seed (int, optional, defaults to 765) – Random seed for reproducibility.
parallel (bool, optional, defaults to False) – If True, run attack using multiple CPUs/GPUs.
num_workers_per_device (int, optional, defaults to 1) – Number of worker processes to run per device in parallel mode (i.e. parallel=True). For example, if you are using GPUs and num_workers_per_device=2, then 2 processes will be running on each GPU.
log_to_txt (str, optional, defaults to None) – If set, save attack logs as a .txt file to the directory specified by this argument. If the last part of the provided path ends with the .txt extension, it is assumed to be the desired path of the log file.
log_to_csv (str, optional, defaults to None) – If set, save attack logs as a CSV file to the directory specified by this argument. If the last part of the provided path ends with the .csv extension, it is assumed to be the desired path of the log file.
csv_coloring_style (str, optional, defaults to "file") – Method for choosing how to mark perturbed parts of the text. Options are "file", "plain", and "html". "file" wraps perturbed parts with double brackets [[ <text> ]] while "plain" does not mark the text in any way.
log_to_visdom (dict, optional, defaults to None) – If set, the Visdom logger is used with the provided dictionary passed as keyword arguments to VisdomLogger. Pass in an empty dictionary to use default arguments. For a custom logger, the dictionary should have the following three keys and their corresponding values: "env", "port", "hostname".
log_to_wandb (dict, optional, defaults to None) – If set, the WandB logger is used with the provided dictionary passed as keyword arguments to WeightsAndBiasesLogger. Pass in an empty dictionary to use default arguments. For a custom logger, the dictionary should have the following key and its corresponding value: "project".
disable_stdout (bool, optional, defaults to False) – Disable displaying individual attack results to stdout.
silent (bool, optional, defaults to False) – Disable all logging (except for errors). This is stronger than disable_stdout.
enable_advance_metrics (bool, optional, defaults to False) – Enable calculation and display of optional advanced post-hoc metrics like perplexity, grammar errors, etc.
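The difference between num_examples, attack_n, and num_successful_examples is easiest to see on a fixed sequence of outcomes. The sketch below is a plain-Python simulation of the stopping rules as described in the parameter list. The function attacked_prefix and the outcome labels are hypothetical, and the real Attacker loop may differ in detail.

```python
from typing import List, Optional


def attacked_prefix(
    outcomes: List[str],
    num_examples: int = 10,
    num_successful_examples: Optional[int] = None,
    attack_n: bool = False,
) -> List[str]:
    """Return the outcomes that would be processed before the run stops.

    `outcomes` holds one label per dataset example, in order:
    "success", "failed", or "skipped" (hypothetical labels).
    """
    if num_examples == -1:  # -1 means the entire dataset
        num_examples = len(outcomes)
    # num_successful_examples, if set, overrides num_examples.
    if num_successful_examples is not None:
        target = num_successful_examples
    else:
        target = num_examples

    processed = []
    counted = 0
    for outcome in outcomes:
        if counted >= target:
            break
        processed.append(outcome)
        if num_successful_examples is not None:
            counted += outcome == "success"  # count only successes
        elif attack_n:
            counted += outcome != "skipped"  # count only non-skipped examples
        else:
            counted += 1  # count every example
    return processed


outcomes = ["skipped", "failed", "success", "skipped", "success", "failed"]
print(len(attacked_prefix(outcomes, num_examples=2)))                  # 2
print(len(attacked_prefix(outcomes, num_examples=2, attack_n=True)))   # 3
print(len(attacked_prefix(outcomes, num_successful_examples=2)))       # 5
```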
- attack_n: bool = False
- checkpoint_dir: str = 'checkpoints'
- checkpoint_interval: int = None
- csv_coloring_style: str = 'file'
- disable_stdout: bool = False
- enable_advance_metrics: bool = False
- log_to_csv: str = None
- log_to_txt: str = None
- log_to_visdom: dict = None
- log_to_wandb: dict = None
- num_examples: int = 10
- num_examples_offset: int = 0
- num_successful_examples: int = None
- num_workers_per_device: int = 1
- parallel: bool = False
- query_budget: int = None
- random_seed: int = 765
- shuffle: bool = False
- silent: bool = False
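The shuffle option is described as shuffling a list of example indices rather than the dataset itself. A minimal sketch of that idea follows. The function build_worklist and the way it combines num_examples_offset, random_seed, and num_examples are assumptions made for illustration, not the library's actual code.

```python
import random
from typing import List


def build_worklist(
    dataset_size: int,
    num_examples: int,
    num_examples_offset: int = 0,
    shuffle: bool = False,
    random_seed: int = 765,
) -> List[int]:
    """Return the dataset indices to attack, leaving the dataset untouched."""
    indices = list(range(num_examples_offset, dataset_size))
    if shuffle:
        # A seeded, local RNG makes the order reproducible across runs.
        random.Random(random_seed).shuffle(indices)
    return indices[:num_examples]


worklist = build_worklist(dataset_size=100, num_examples=5, shuffle=True)

# After attacking the first 2 examples, a checkpoint only needs to store
# the remaining indices in order to resume later.
done = 2
remaining = worklist[done:]
print(len(remaining))  # 3
```

Because the order lives in a plain list of integers, the remaining work can be saved and restored without re-shuffling or copying the dataset, which is one way shuffling and checkpointing can coexist.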
- class textattack.attack_args.CommandLineAttackArgs(model: str = None, model_from_file: str = None, model_from_huggingface: str = None, dataset_by_model: str = None, dataset_from_huggingface: str = None, dataset_from_file: str = None, dataset_split: str = None, filter_by_labels: list = None, transformation: str = 'word-swap-embedding', constraints: list = <factory>, goal_function: str = 'untargeted-classification', search_method: str = 'greedy-word-wir', attack_recipe: str = None, attack_from_file: str = None, interactive: bool = False, parallel: bool = False, model_batch_size: int = 32, model_cache_size: int = 262144, constraint_cache_size: int = 262144, num_examples: int = 10, num_successful_examples: int = None, num_examples_offset: int = 0, attack_n: bool = False, shuffle: bool = False, query_budget: int = None, checkpoint_interval: int = None, checkpoint_dir: str = 'checkpoints', random_seed: int = 765, num_workers_per_device: int = 1, log_to_txt: str = None, log_to_csv: str = None, csv_coloring_style: str = 'file', log_to_visdom: dict = None, log_to_wandb: dict = None, disable_stdout: bool = False, silent: bool = False, enable_advance_metrics: bool = False)[source]
Bases:
textattack.attack_args.AttackArgs, textattack.attack_args._CommandLineAttackArgs, textattack.dataset_args.DatasetArgs, textattack.model_args.ModelArgs
Attacker Class
- class textattack.attacker.Attacker(attack, dataset, attack_args=None)[source]
Bases:
object
Class for running attacks on a dataset with specified parameters. This class uses the Attack to actually run the attacks, while also providing useful features such as parallel processing, saving/resuming from a checkpoint, and logging to files and stdout.
- Parameters
attack (Attack) – Attack used to actually carry out the attack.
dataset (Dataset) – Dataset to attack.
attack_args (AttackArgs) – Arguments for attacking the dataset. For default settings, look at the AttackArgs class.
Example:
>>> import textattack
>>> import transformers
>>> model = transformers.AutoModelForSequenceClassification.from_pretrained("textattack/bert-base-uncased-imdb")
>>> tokenizer = transformers.AutoTokenizer.from_pretrained("textattack/bert-base-uncased-imdb")
>>> model_wrapper = textattack.models.wrappers.HuggingFaceModelWrapper(model, tokenizer)
>>> attack = textattack.attack_recipes.TextFoolerJin2019.build(model_wrapper)
>>> dataset = textattack.datasets.HuggingFaceDataset("imdb", split="test")
>>> # Attack 20 samples with CSV logging and a checkpoint saved every 5 examples
>>> attack_args = textattack.AttackArgs(
...     num_examples=20,
...     log_to_csv="log.csv",
...     checkpoint_interval=5,
...     checkpoint_dir="checkpoints",
...     disable_stdout=True,
... )
>>> attacker = textattack.Attacker(attack, dataset, attack_args)
>>> attacker.attack_dataset()
- attack_dataset()[source]
Attack the dataset.
- Returns
list[AttackResult] - List of AttackResult obtained after attacking the given dataset.
- classmethod from_checkpoint(attack, dataset, checkpoint)[source]
Resume attacking from a saved checkpoint. The attack and dataset must be supplied again by the user, while the attack args are loaded from the saved checkpoint.
- update_attack_args(**kwargs)[source]
To update any attack args, pass the new argument as keyword argument to this function.
Examples:
>>> # `attacker` is an existing instance of Attacker
>>> # To switch to parallel mode and increase checkpoint interval from 100 to 500
>>> attacker.update_attack_args(parallel=True, checkpoint_interval=500)
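The keyword-argument update pattern can be sketched on a small dataclass. SketchArgs and update_args are hypothetical names, and rejecting unknown names is an assumption made here for safety. This reference does not state what validation update_attack_args performs.

```python
from dataclasses import dataclass, fields
from typing import Optional


@dataclass
class SketchArgs:
    """Hypothetical stand-in with two fields borrowed from AttackArgs."""

    parallel: bool = False
    checkpoint_interval: Optional[int] = None


def update_args(args: SketchArgs, **kwargs) -> SketchArgs:
    """Set each keyword argument on `args`, rejecting unknown names."""
    valid_names = {f.name for f in fields(args)}
    for name, value in kwargs.items():
        if name not in valid_names:
            raise ValueError(f"Unknown argument: {name}")
        setattr(args, name, value)
    return args


args = SketchArgs(checkpoint_interval=100)
update_args(args, parallel=True, checkpoint_interval=500)
print(args)  # SketchArgs(parallel=True, checkpoint_interval=500)
```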
AugmenterArgs Class
- class textattack.augment_args.AugmenterArgs(input_csv: str, output_csv: str, input_column: str, recipe: str = 'embedding', pct_words_to_swap: float = 0.1, transformations_per_example: int = 2, random_seed: int = 42, exclude_original: bool = False, overwrite: bool = False, interactive: bool = False, fast_augment: bool = False, high_yield: bool = False, enable_advanced_metrics: bool = False)[source]
Bases:
object
Arguments for performing data augmentation.
- Parameters
input_csv (str) – Path of input CSV file to augment.
output_csv (str) – Path of CSV file to output augmented data.
- enable_advanced_metrics: bool = False
- exclude_original: bool = False
- fast_augment: bool = False
- high_yield: bool = False
- input_column: str
- input_csv: str
- interactive: bool = False
- output_csv: str
- overwrite: bool = False
- pct_words_to_swap: float = 0.1
- random_seed: int = 42
- recipe: str = 'embedding'
- transformations_per_example: int = 2
DatasetArgs Class
- class textattack.dataset_args.DatasetArgs(dataset_by_model: Optional[str] = None, dataset_from_huggingface: Optional[str] = None, dataset_from_file: Optional[str] = None, dataset_split: Optional[str] = None, filter_by_labels: Optional[list] = None)[source]
Bases:
object
Arguments for loading dataset from command line input.
- dataset_by_model: str = None
- dataset_from_file: str = None
- dataset_from_huggingface: str = None
- dataset_split: str = None
- filter_by_labels: list = None
ModelArgs Class
- class textattack.model_args.ModelArgs(model: Optional[str] = None, model_from_file: Optional[str] = None, model_from_huggingface: Optional[str] = None)[source]
Bases:
object
Arguments for loading base/pretrained or trained models.
- model: str = None
- model_from_file: str = None
- model_from_huggingface: str = None
Trainer Class
- class textattack.trainer.Trainer(model_wrapper, task_type='classification', attack=None, train_dataset=None, eval_dataset=None, training_args=None)[source]
Bases:
object
Trainer is the training and evaluation loop for adversarial training.
It is designed to work with PyTorch and Transformers models.
- Parameters
model_wrapper (ModelWrapper) – Model wrapper containing both the model and the tokenizer.
task_type (str, optional, defaults to "classification") – The task that the model is trained to perform. Currently, Trainer supports two tasks: (1) "classification", (2) "regression".
attack (Attack) – Attack used to generate adversarial examples for training.
train_dataset (Dataset) – Dataset for training.
eval_dataset (Dataset) – Dataset for evaluation.
training_args (TrainingArgs) – Arguments for training.
Example:
>>> import textattack
>>> import transformers
>>> model = transformers.AutoModelForSequenceClassification.from_pretrained("bert-base-uncased")
>>> tokenizer = transformers.AutoTokenizer.from_pretrained("bert-base-uncased")
>>> model_wrapper = textattack.models.wrappers.HuggingFaceModelWrapper(model, tokenizer)
>>> # We only use DeepWordBugGao2018 for demonstration purposes.
>>> attack = textattack.attack_recipes.DeepWordBugGao2018.build(model_wrapper)
>>> train_dataset = textattack.datasets.HuggingFaceDataset("imdb", split="train")
>>> eval_dataset = textattack.datasets.HuggingFaceDataset("imdb", split="test")
>>> # Train for 3 epochs with 1 initial clean epoch, 1000 adversarial examples per epoch,
>>> # learning rate of 5e-5, and effective batch size of 32 (8x4).
>>> training_args = textattack.TrainingArgs(
...     num_epochs=3,
...     num_clean_epochs=1,
...     num_train_adv_examples=1000,
...     learning_rate=5e-5,
...     per_device_train_batch_size=8,
...     gradient_accumulation_steps=4,
...     log_to_tb=True,
... )
>>> trainer = textattack.Trainer(
...     model_wrapper,
...     "classification",
...     attack,
...     train_dataset,
...     eval_dataset,
...     training_args,
... )
>>> trainer.train()
Note
When using Trainer with parallel=True in TrainingArgs, make sure to protect the “entry point” of the program by using if __name__ == '__main__':. If not, each worker process used for generating adversarial examples will execute the training code again.
- evaluate_step(model, tokenizer, batch)[source]
Perform a single evaluation step on a batch of inputs.
- Parameters
model (torch.nn.Module) – Model to evaluate.
tokenizer – Tokenizer used to tokenize input text.
batch (tuple[list[str], torch.Tensor]) – By default, this will be a tuple of input texts and target tensors.
Note
If you override the get_eval_dataloader() method, then the shape/type of batch will depend on how you created your batch.
- Returns
tuple[torch.Tensor, torch.Tensor] where
preds: torch.FloatTensor of model’s prediction for the batch.
targets: torch.Tensor of model’s targets (e.g. labels, target values).
- get_eval_dataloader(dataset, batch_size)[source]
Returns the torch.utils.data.DataLoader for evaluation.
- Parameters
dataset (Dataset) – Dataset to use for evaluation.
batch_size (int) – Batch size for evaluation.
- Returns
torch.utils.data.DataLoader
- get_optimizer_and_scheduler(model, num_training_steps)[source]
Returns optimizer and scheduler to use for training. If you are overriding this method and do not want to use a scheduler, simply return None for scheduler.
- Parameters
model (torch.nn.Module) – Model to be trained. Pass its parameters to optimizer for training.
num_training_steps (int) – Number of total training steps.
- Returns
Tuple of optimizer and scheduler: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler._LRScheduler]
- get_train_dataloader(dataset, adv_dataset, batch_size)[source]
Returns the
torch.utils.data.DataLoader
for training.
- training_step(model, tokenizer, batch)[source]
Perform a single training step on a batch of inputs.
- Parameters
model (torch.nn.Module) – Model to train.
tokenizer – Tokenizer used to tokenize input text.
batch (tuple[list[str], torch.Tensor, torch.Tensor]) – By default, this will be a tuple of input texts, targets, and a boolean tensor indicating if the sample is an adversarial example.
Note
If you override the get_train_dataloader() method, then the shape/type of batch will depend on how you created your batch.
- Returns
tuple[torch.Tensor, torch.Tensor, torch.Tensor] where
loss: torch.FloatTensor of shape 1 containing the loss.
preds: torch.FloatTensor of model’s prediction for the batch.
targets: torch.Tensor of model’s targets (e.g. labels, target values).
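By default, training_step receives a boolean tensor marking which samples are adversarial, and TrainingArgs exposes alpha as the weight for adversarial loss. One plausible way to combine the two is to scale the per-sample loss of adversarial examples by alpha before averaging. The sketch below uses plain Python lists instead of tensors. The function weighted_mean_loss is hypothetical, and this weighting scheme is an assumption rather than a description of Trainer internals.

```python
from typing import List


def weighted_mean_loss(
    losses: List[float], is_adv_sample: List[bool], alpha: float = 1.0
) -> float:
    """Average per-sample losses, scaling adversarial samples by `alpha`."""
    weights = [alpha if is_adv else 1.0 for is_adv in is_adv_sample]
    weighted = [loss * weight for loss, weight in zip(losses, weights)]
    return sum(weighted) / len(weighted)


losses = [1.0, 2.0, 3.0]
is_adv_sample = [False, True, False]

# With alpha=1.0, adversarial and clean samples contribute equally.
print(weighted_mean_loss(losses, is_adv_sample, alpha=1.0))  # 2.0

# With alpha=0.5, the adversarial sample's loss is halved before averaging.
down_weighted = weighted_mean_loss(losses, is_adv_sample, alpha=0.5)
```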
TrainingArgs Class
- class textattack.training_args.CommandLineTrainingArgs(model_name_or_path: str, attack: str, dataset: str, task_type: str = 'classification', model_max_length: int = None, model_num_labels: int = None, dataset_train_split: str = None, dataset_eval_split: str = None, filter_train_by_labels: list = None, filter_eval_by_labels: list = None, num_epochs: int = 3, num_clean_epochs: int = 1, attack_epoch_interval: int = 1, early_stopping_epochs: int = None, learning_rate: float = 5e-05, num_warmup_steps: Union[int, float] = 500, weight_decay: float = 0.01, per_device_train_batch_size: int = 8, per_device_eval_batch_size: int = 32, gradient_accumulation_steps: int = 1, random_seed: int = 786, parallel: bool = False, load_best_model_at_end: bool = False, alpha: float = 1.0, num_train_adv_examples: Union[int, float] = -1, query_budget_train: int = None, attack_num_workers_per_device: int = 1, output_dir: str = <factory>, checkpoint_interval_steps: int = None, checkpoint_interval_epochs: int = None, save_last: bool = True, log_to_tb: bool = False, tb_log_dir: str = None, log_to_wandb: bool = False, wandb_project: str = 'textattack', logging_interval_step: int = 1)[source]
Bases:
textattack.training_args.TrainingArgs, textattack.training_args._CommandLineTrainingArgs
- output_dir: str
- class textattack.training_args.TrainingArgs(num_epochs: int = 3, num_clean_epochs: int = 1, attack_epoch_interval: int = 1, early_stopping_epochs: Optional[int] = None, learning_rate: float = 5e-05, num_warmup_steps: Union[int, float] = 500, weight_decay: float = 0.01, per_device_train_batch_size: int = 8, per_device_eval_batch_size: int = 32, gradient_accumulation_steps: int = 1, random_seed: int = 786, parallel: bool = False, load_best_model_at_end: bool = False, alpha: float = 1.0, num_train_adv_examples: Union[int, float] = -1, query_budget_train: Optional[int] = None, attack_num_workers_per_device: int = 1, output_dir: str = <factory>, checkpoint_interval_steps: Optional[int] = None, checkpoint_interval_epochs: Optional[int] = None, save_last: bool = True, log_to_tb: bool = False, tb_log_dir: Optional[str] = None, log_to_wandb: bool = False, wandb_project: str = 'textattack', logging_interval_step: int = 1)[source]
Bases:
object
Arguments for the Trainer class that is used for adversarial training.
- Parameters
num_epochs (int, optional, defaults to 3) – Total number of epochs for training.
num_clean_epochs (int, optional, defaults to 1) – Number of epochs to train on just the original training dataset before adversarial training.
attack_epoch_interval (int, optional, defaults to 1) – Generate a new adversarial training set every N epochs.
early_stopping_epochs (int, optional, defaults to None) – Number of epochs validation must increase before stopping early (None for no early stopping).
learning_rate (float, optional, defaults to 5e-5) – Learning rate for optimizer.
num_warmup_steps (int or float, optional, defaults to 500) – The number of steps for the warmup phase of linear scheduler. If num_warmup_steps is a float between 0 and 1, the number of warmup steps will be math.ceil(num_training_steps * num_warmup_steps).
weight_decay (float, optional, defaults to 0.01) – Weight decay (L2 penalty).
per_device_train_batch_size (int, optional, defaults to 8) – The batch size per GPU/CPU for training.
per_device_eval_batch_size (int, optional, defaults to 32) – The batch size per GPU/CPU for evaluation.
gradient_accumulation_steps (int, optional, defaults to 1) – Number of update steps over which to accumulate gradients before performing a backward/update pass.
random_seed (int, optional, defaults to 786) – Random seed for reproducibility.
parallel (bool, optional, defaults to False) – If True, train using multiple GPUs using torch.nn.DataParallel.
load_best_model_at_end (bool, optional, defaults to False) – If True, keep track of the best model across training and load it at the end.
alpha (float, optional, defaults to 1.0) – The weight for adversarial loss.
num_train_adv_examples (int or float, optional, defaults to -1) – The number of samples to successfully attack when generating the adversarial training set before the start of every epoch. If num_train_adv_examples is a float between 0 and 1, the number of adversarial examples generated is that fraction of the original training set.
query_budget_train (int, optional, defaults to None) – The max query budget to use when generating the adversarial training set. None means infinite query budget.
attack_num_workers_per_device (int, optional, defaults to 1) – Number of worker processes to run per device for attack. Same as the num_workers_per_device argument for AttackArgs.
output_dir (str, optional) – Directory to output training logs and checkpoints. Defaults to /outputs/%Y-%m-%d-%H-%M-%S-%f format.
checkpoint_interval_steps (int, optional, defaults to None) – If set, save a model checkpoint after every N updates to the model.
checkpoint_interval_epochs (int, optional, defaults to None) – If set, save a model checkpoint after every N epochs.
save_last (bool, optional, defaults to True) – If True, save the model at the end of training. Can be used with load_best_model_at_end to save the best model at the end.
log_to_tb (bool, optional, defaults to False) – If True, log to Tensorboard.
tb_log_dir (str, optional, defaults to "./runs") – Path of Tensorboard log directory.
log_to_wandb (bool, optional, defaults to False) – If True, log to Wandb.
wandb_project (str, optional, defaults to "textattack") – Name of Wandb project for logging.
logging_interval_step (int, optional, defaults to 1) – Log to Tensorboard/Wandb every N training steps.
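The num_warmup_steps rule above can be written out directly. In the sketch below, resolve_warmup_steps follows the documented formula. linear_schedule_factor is an assumption modeled on a common linear warmup followed by linear decay, not necessarily the exact scheduler Trainer builds. Both function names are hypothetical.

```python
import math
from typing import Union


def resolve_warmup_steps(
    num_warmup_steps: Union[int, float], num_training_steps: int
) -> int:
    """Turn a fractional warmup setting into an absolute step count."""
    if isinstance(num_warmup_steps, float):
        return math.ceil(num_training_steps * num_warmup_steps)
    return num_warmup_steps


def linear_schedule_factor(
    step: int, warmup_steps: int, num_training_steps: int
) -> float:
    """Learning-rate multiplier: linear ramp up, then linear decay to zero."""
    if step < warmup_steps:
        return step / max(1, warmup_steps)
    remaining = num_training_steps - step
    return max(0.0, remaining / max(1, num_training_steps - warmup_steps))


warmup = resolve_warmup_steps(0.25, 1000)
print(warmup)                                     # 250
print(linear_schedule_factor(125, warmup, 1000))  # 0.5
print(linear_schedule_factor(250, warmup, 1000))  # 1.0
```

Multiplying learning_rate by this factor at each step gives the effective learning rate under the assumed schedule.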
- alpha: float = 1.0
- attack_epoch_interval: int = 1
- attack_num_workers_per_device: int = 1
- checkpoint_interval_epochs: int = None
- checkpoint_interval_steps: int = None
- early_stopping_epochs: int = None
- gradient_accumulation_steps: int = 1
- learning_rate: float = 5e-05
- load_best_model_at_end: bool = False
- log_to_tb: bool = False
- log_to_wandb: bool = False
- logging_interval_step: int = 1
- num_clean_epochs: int = 1
- num_epochs: int = 3
- num_train_adv_examples: Union[int, float] = -1
- num_warmup_steps: Union[int, float] = 500
- output_dir: str
- parallel: bool = False
- per_device_eval_batch_size: int = 32
- per_device_train_batch_size: int = 8
- query_budget_train: int = None
- random_seed: int = 786
- save_last: bool = True
- tb_log_dir: str = None
- wandb_project: str = 'textattack'
- weight_decay: float = 0.01
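Two of the TrainingArgs settings reduce to simple arithmetic. The sketch below shows the effective batch size under gradient accumulation, matching the "8x4" figure in the Trainer example, and resolves num_train_adv_examples when it is given as a fraction. The function names are hypothetical. Rounding the fraction up, and reading -1 as "the whole training set" by analogy with num_examples in AttackArgs, are assumptions.

```python
import math
from typing import Union


def effective_batch_size(
    per_device_train_batch_size: int,
    gradient_accumulation_steps: int,
    num_devices: int = 1,
) -> int:
    """Number of samples contributing to each optimizer update."""
    return per_device_train_batch_size * gradient_accumulation_steps * num_devices


def resolve_num_adv_examples(
    num_train_adv_examples: Union[int, float], train_set_size: int
) -> int:
    """Turn the setting into an absolute number of adversarial examples."""
    if isinstance(num_train_adv_examples, float):
        # Fraction of the original training set (rounding up is an assumption).
        return math.ceil(train_set_size * num_train_adv_examples)
    if num_train_adv_examples == -1:
        # Assumption: -1 means the entire training set.
        return train_set_size
    return num_train_adv_examples


print(effective_batch_size(8, 4))           # 32
print(resolve_num_adv_examples(0.5, 1000))  # 500
```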