textattack package
Welcome to the API references for TextAttack!
What is TextAttack?
TextAttack is a Python framework for adversarial attacks, adversarial training, and data augmentation in NLP.
TextAttack makes experimenting with the robustness of NLP models seamless, fast, and easy. It’s also useful for NLP model training, adversarial training, and data augmentation.
TextAttack provides components for common NLP tasks like sentence encoding, grammar-checking, and word replacement that can be used on their own.
Subpackages
- textattack.attack_recipes package
  - Attack Recipes Package:
    - Submodules
      - A2T (A2T: Attack for Adversarial Training Recipe) – A2TYoo2021
      - Attack Recipe Class – AttackRecipe
      - Imperceptible Perturbations Algorithm – BadCharacters2021
      - BAE (BAE: BERT-Based Adversarial Examples) – BAEGarg2019
      - BERT-Attack – BERTAttackLi2020
      - CheckList – CheckList2020
      - Attack Chinese Recipe – ChineseRecipe
      - CLARE Recipe – CLARE2020
      - DeepWordBug – DeepWordBugGao2018
      - Faster Alzantot Genetic Algorithm – FasterGeneticAlgorithmJia2019
      - Attack French Recipe – FrenchRecipe
      - Alzantot Genetic Algorithm – GeneticAlgorithmAlzantot2018
      - HotFlip – HotFlipEbrahimi2017
      - Improved Genetic Algorithm – IGAWang2019
      - Input Reduction – InputReductionFeng2018
      - Kuleshov2017 – Kuleshov2017
      - MORPHEUS2020 – MorpheusTan2020
      - Pruthi2019: Combating with Robust Word Recognition – Pruthi2019
      - Particle Swarm Optimization – PSOZang2020
      - PWWS – PWWSRen2019
      - Seq2Sick – Seq2SickCheng2018BlackBox
      - Attack Spanish Recipe – SpanishRecipe
      - TextBugger – TextBuggerLi2018
      - TextFooler (Is BERT Really Robust?) – TextFoolerJin2019
- textattack.attack_results package
- textattack.augmentation package
  - TextAttack augmentation package:
    - Submodules
- textattack.commands package
  - TextAttack commands Package
    - Submodules
      - AttackCommand class – AttackCommand
      - AttackResumeCommand class – AttackResumeCommand
      - AugmentCommand class – AugmentCommand
      - BenchmarkRecipeCommand class – BenchmarkRecipeCommand
      - EvalModelCommand class – EvalModelCommand, ModelEvalArgs
      - ListThingsCommand class – ListThingsCommand
      - PeekDatasetCommand class – PeekDatasetCommand
      - TextAttack CLI main class – main(), TextAttackCommand
      - TrainModelCommand class – TrainModelCommand
- textattack.constraints package
  - Constraints
    - Subpackages
      - textattack.constraints.grammaticality package
      - textattack.constraints.overlap package
      - textattack.constraints.pre_transformation package
        - Pre-Transformation:
          - Submodules
            - Input Column Modification – InputColumnModification
            - Max Modification Rate – MaxModificationRate
            - Max Modification Rate – MaxNumWordsModified
            - Max Word Index Modification – MaxWordIndexModification
            - Min Word Length – MinWordLength
            - Repeat Modification – RepeatModification
            - Stopword Modification – StopwordModification
            - UnmodifiableIndices
            - UnmodifablePhrases
      - textattack.constraints.semantics package
    - Submodules
- textattack.datasets package
- textattack.goal_function_results package
  - Goal Function Result package:
    - Subpackages
      - textattack.goal_function_results.custom package
        - Custom Goal Function Result package:
          - Submodules
    - Submodules
- textattack.goal_functions package
  - Goal Functions
    - Subpackages
      - textattack.goal_functions.classification package
        - Goal Function for Classification
          - Submodules
            - Determine if an attack has been successful in Classification – ClassificationGoalFunction
            - Determine if an attack has been successful in Hard Label Classification – HardLabelClassification
            - Determine if maintaining the same predicted label (input reduction) – InputReduction
            - Determine if an attack has been successful in targeted Classification – TargetedClassification
            - Determine if an attack has been successful in untargeted Classification – UntargetedClassification
      - textattack.goal_functions.custom package
      - textattack.goal_functions.text package
    - Submodules
- textattack.llms package
- textattack.loggers package
  - Misc Loggers: Loggers track, visualize, and export attack results.
    - Submodules
      - Managing Attack Logs – AttackLogManager, AttackLogManager.add_output_csv(), AttackLogManager.add_output_file(), AttackLogManager.add_output_summary_json(), AttackLogManager.disable_color(), AttackLogManager.enable_stdout(), AttackLogManager.enable_visdom(), AttackLogManager.enable_wandb(), AttackLogManager.flush(), AttackLogManager.log_attack_details(), AttackLogManager.log_result(), AttackLogManager.log_results(), AttackLogManager.log_sep(), AttackLogManager.log_summary(), AttackLogManager.log_summary_rows(), AttackLogManager.metrics
      - Attack Logs to CSV – CSVLogger
      - Attack Logs to file – FileLogger
      - Attack Summary Results Logs to Json – JsonSummaryLogger
      - Attack Logger Wrapper – Logger
      - Attack Logs to Visdom – VisdomLogger, port_is_open()
      - Attack Logs to WandB – WeightsAndBiasesLogger
- textattack.metrics package
  - Metrics package: to calculate advanced metrics for evaluating attacks and augmented text
    - Subpackages
      - textattack.metrics.attack_metrics package
      - textattack.metrics.quality_metrics package
    - Submodules
- textattack.models package
  - Models
    - Subpackages
      - textattack.models.helpers package
        - Model Helpers
          - Submodules
      - textattack.models.tokenizers package
      - textattack.models.wrappers package
- textattack.prompt_augmentation package
- textattack.search_methods package
  - Search Methods
    - Submodules
      - Reimplementation of search method from Generating Natural Language Adversarial Examples – AlzantotGeneticAlgorithm
      - Beam Search – BeamSearch
      - DifferentialEvolution
      - Genetic Algorithm Word Swap – GeneticAlgorithm
      - Greedy Search – GreedySearch
      - Greedy Word Swap with Word Importance Ranking – GreedyWordSwapWIR
      - Reimplementation of search method from Xiaosen Wang, Hao Jin, Kun He (2019) – ImprovedGeneticAlgorithm
      - Particle Swarm Optimization – ParticleSwarmOptimization, normalize()
      - Population based Search abstract class – PopulationBasedSearch, PopulationMember
      - Search Method Abstract Class – SearchMethod
- textattack.shared package
  - Shared TextAttack Functions
    - Subpackages
      - textattack.shared.utils package
        - Submodules
          - LazyLoader, load_module_from_file(), download_from_s3(), download_from_url(), http_get(), path_in_cache(), s3_url(), set_cache_dir(), unzip_file(), get_textattack_model_num_labels(), hashable(), html_style_from_dict(), html_table_from_rows(), load_textattack_model_from_path(), set_seed(), sigmoid(), ANSI_ESCAPE_CODES, ANSI_ESCAPE_CODES.BOLD, ANSI_ESCAPE_CODES.BROWN, ANSI_ESCAPE_CODES.CYAN, ANSI_ESCAPE_CODES.FAIL, ANSI_ESCAPE_CODES.GRAY, ANSI_ESCAPE_CODES.HEADER, ANSI_ESCAPE_CODES.OKBLUE, ANSI_ESCAPE_CODES.OKGREEN, ANSI_ESCAPE_CODES.ORANGE, ANSI_ESCAPE_CODES.PINK, ANSI_ESCAPE_CODES.PURPLE, ANSI_ESCAPE_CODES.STOP, ANSI_ESCAPE_CODES.UNDERLINE, ANSI_ESCAPE_CODES.WARNING, ANSI_ESCAPE_CODES.YELLOW
          - ReprMixin, TextAttackFlairTokenizer, add_indent(), check_if_punctuations(), check_if_subword(), color_from_label(), color_from_output(), color_text(), default_class_repr(), flair_tag(), has_letter(), is_one_word(), process_label_name(), strip_BPE_artifacts(), words_from_text(), zip_flair_result(), zip_stanza_result(), batch_model_predict()
    - Submodules
      - Attacked Text Class – AttackedText, AttackedText.align_with_model_tokens(), AttackedText.all_words_diff(), AttackedText.convert_from_original_idxs(), AttackedText.delete_word_at_index(), AttackedText.first_word_diff(), AttackedText.first_word_diff_index(), AttackedText.free_memory(), AttackedText.generate_new_attacked_text(), AttackedText.get_deletion_indices(), AttackedText.insert_text_after_word_index(), AttackedText.insert_text_before_word_index(), AttackedText.ith_word_diff(), AttackedText.ner_of_word_index(), AttackedText.pos_of_word_index(), AttackedText.printable_text(), AttackedText.replace_word_at_index(), AttackedText.replace_words_at_indices(), AttackedText.text_after_word_index(), AttackedText.text_until_word_index(), AttackedText.text_window_around_index(), AttackedText.words_diff_num(), AttackedText.words_diff_ratio(), AttackedText.SPLIT_TOKEN, AttackedText.column_labels, AttackedText.newly_swapped_words, AttackedText.num_words, AttackedText.text, AttackedText.tokenizer_input, AttackedText.words, AttackedText.words_per_input
      - Misc Checkpoints – AttackCheckpoint, AttackCheckpoint.load(), AttackCheckpoint.save(), AttackCheckpoint.dataset_offset, AttackCheckpoint.datetime, AttackCheckpoint.num_failed_attacks, AttackCheckpoint.num_maximized_attacks, AttackCheckpoint.num_remaining_attacks, AttackCheckpoint.num_skipped_attacks, AttackCheckpoint.num_successful_attacks, AttackCheckpoint.results_count
      - Shared data fields
      - Misc Validators – transformation_consists_of(), transformation_consists_of_word_swaps(), transformation_consists_of_word_swaps_and_deletions(), transformation_consists_of_word_swaps_differential_evolution(), validate_model_goal_function_compatibility(), validate_model_gradient_word_swap_compatibility()
      - Shared loads word embeddings and related distances – AbstractWordEmbedding, GensimWordEmbedding, WordEmbedding
- textattack.transformations package
  - Transformations
    - Subpackages
      - textattack.transformations.sentence_transformations package
      - textattack.transformations.word_insertions package
      - textattack.transformations.word_merges package
      - textattack.transformations.word_swaps package
        - word_swaps package
          - Subpackages
          - Submodules
            - Word Swap – WordSwap
            - Word Swap by Changing Location – WordSwapChangeLocation, idx_to_words()
            - Word Swap by Changing Name – WordSwapChangeName
            - Word Swap by Changing Number – WordSwapChangeNumber, idx_to_words()
            - Word Swap by Contraction – WordSwapContract
            - Word Swap by Invisible Deletions – WordSwapDeletions
            - Word Swap for Differential Evolution – WordSwapDifferentialEvolution
            - Word Swap by Embedding – WordSwapEmbedding, recover_word_case()
            - Word Swap by Extension – WordSwapExtend
            - Word Swap by Gradient – WordSwapGradientBased
            - Word Swap by Homoglyph – WordSwapHomoglyphSwap
            - Word Swap by OpenHowNet – WordSwapHowNet, recover_word_case()
            - Word Swap by Inflections – WordSwapInflections
            - Word Swap by Invisible Characters – WordSwapInvisibleCharacters
            - Word Swap by BERT-Masked LM – WordSwapMaskedLM, recover_word_case()
            - Word Swap by Neighboring Character Swap – WordSwapNeighboringCharacterSwap
            - Word Swap by swapping characters with QWERTY-adjacent keys – WordSwapQWERTY
            - Word Swap by Random Character Deletion – WordSwapRandomCharacterDeletion
            - Word Swap by Random Character Insertion – WordSwapRandomCharacterInsertion
            - Word Swap by Random Character Substitution – WordSwapRandomCharacterSubstitution
            - Word Swap by Invisible Reorderings – WordSwapReorderings
            - Word Swap by swapping synonyms in WordNet – WordSwapWordNet
    - Submodules
Submodules
Attack Class
- class textattack.attack.Attack(goal_function: GoalFunction, constraints: List[Constraint | PreTransformationConstraint], transformation: Transformation, search_method: SearchMethod, transformation_cache_size=32768, constraint_cache_size=32768)
Bases: object
An attack generates adversarial examples on text.
An attack consists of a goal function, constraints, a transformation, and a search method. Use the attack() method to attack one sample at a time.
- Parameters:
  - goal_function (GoalFunction) – A function for determining how well a perturbation is doing at achieving the attack’s goal.
  - constraints (list of Constraint or PreTransformationConstraint) – A list of constraints to add to the attack, defining which perturbations are valid.
  - transformation (Transformation) – The transformation applied at each step of the attack.
  - search_method (SearchMethod) – The method for exploring the search space of possible perturbations.
  - transformation_cache_size (int, optional, defaults to 2**15) – The number of items to keep in the transformations cache.
  - constraint_cache_size (int, optional, defaults to 2**15) – The number of items to keep in the constraints cache.
Example:
>>> import textattack
>>> import transformers
>>> # Load model, tokenizer, and model_wrapper
>>> model = transformers.AutoModelForSequenceClassification.from_pretrained("textattack/bert-base-uncased-imdb")
>>> tokenizer = transformers.AutoTokenizer.from_pretrained("textattack/bert-base-uncased-imdb")
>>> model_wrapper = textattack.models.wrappers.HuggingFaceModelWrapper(model, tokenizer)
>>> # Construct our four components for `Attack`
>>> from textattack.constraints.pre_transformation import RepeatModification, StopwordModification
>>> from textattack.constraints.semantics import WordEmbeddingDistance
>>> from textattack.transformations import WordSwapEmbedding
>>> from textattack.search_methods import GreedyWordSwapWIR
>>> goal_function = textattack.goal_functions.UntargetedClassification(model_wrapper)
>>> constraints = [
...     RepeatModification(),
...     StopwordModification(),
...     WordEmbeddingDistance(min_cos_sim=0.9)
... ]
>>> transformation = WordSwapEmbedding(max_candidates=50)
>>> search_method = GreedyWordSwapWIR(wir_method="delete")
>>> # Construct the actual attack
>>> attack = textattack.Attack(goal_function, constraints, transformation, search_method)
>>> input_text = "I really enjoyed the new movie that came out last month."
>>> label = 1  # Positive
>>> attack_result = attack.attack(input_text, label)
- attack(example, ground_truth_output)
Attack a single example.
- Parameters:
  - example (str, OrderedDict[str, str] or AttackedText) – Example to attack. It can be a single string or an OrderedDict where keys represent the input fields (e.g. “premise”, “hypothesis”) and the values are the actual input texts. Also accepts an AttackedText that wraps around the input.
  - ground_truth_output (int, float or str) – Ground truth output of example. For classification tasks, it should be an integer representing the ground truth label. For regression tasks (e.g. STS), it should be the target value. For seq2seq tasks (e.g. translation), it should be the target string.
- Returns:
  AttackResult that represents the result of the attack.
- filter_transformations(transformed_texts, current_text, original_text=None)
Filters a list of potential transformed texts based on self.constraints. Utilizes an LRU cache to attempt to avoid recomputing common transformations.
- Parameters:
  - transformed_texts – A list of candidate transformed AttackedText to filter.
  - current_text – The current AttackedText on which the transformation was applied.
  - original_text – The original AttackedText from which the attack started.
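filter_transformations is described as using an LRU cache so that common transformations are not re-checked against the constraints. The sketch below illustrates that idea with collections.OrderedDict. It assumes candidates are plain strings and each constraint is a simple predicate. CachedConstraintFilter, no_insertions, and the misses counter are hypothetical names, not TextAttack's implementation. The default capacity mirrors the documented constraint_cache_size of 2**15.

```python
from collections import OrderedDict
from typing import Callable, List, Tuple

# Hypothetical constraint signature: (candidate, current) -> bool.
Constraint = Callable[[str, str], bool]


class CachedConstraintFilter:
    """Filters candidate texts through constraints, memoizing verdicts in an LRU cache."""

    def __init__(self, constraints: List[Constraint], max_size: int = 2**15):
        self.constraints = constraints
        self.max_size = max_size
        self.cache: "OrderedDict[Tuple[str, str], bool]" = OrderedDict()
        self.misses = 0  # how many times the constraints were actually evaluated

    def _passes(self, candidate: str, current: str) -> bool:
        key = (candidate, current)
        if key in self.cache:
            self.cache.move_to_end(key)  # mark as most recently used
            return self.cache[key]
        self.misses += 1
        verdict = all(check(candidate, current) for check in self.constraints)
        self.cache[key] = verdict
        if len(self.cache) > self.max_size:
            self.cache.popitem(last=False)  # evict the least recently used entry
        return verdict

    def filter(self, candidates: List[str], current: str) -> List[str]:
        return [c for c in candidates if self._passes(c, current)]


def no_insertions(candidate: str, current: str) -> bool:
    """Hypothetical constraint: reject candidates that add words."""
    return len(candidate.split()) <= len(current.split())


current = "the movie was great"
candidates = ["the film was great", "the movie was really great"]
f = CachedConstraintFilter([no_insertions])
first = f.filter(candidates, current)
second = f.filter(candidates, current)  # answered from the cache, no new misses
```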
- get_indices_to_order(current_text, **kwargs)
Applies pre_transformation_constraints to current_text to get all the indices that search methods can search over and order.
- Parameters:
  - current_text – The current AttackedText for which to find the indices that are eligible to be ordered.
- Returns:
  The length and the filtered list of indices which search methods can use to search/order.
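get_indices_to_order narrows the word indices a search may consider by applying the pre-transformation constraints. The sketch below is minimal and assumes each pre-transformation constraint can be modeled as a function that returns the set of indices it allows. no_stopwords and not_already_modified are hypothetical stand-ins loosely inspired by StopwordModification and RepeatModification; they are not those classes. As in the documented return value, a count is returned alongside the filtered list.

```python
from typing import Callable, List, Set, Tuple

# Hypothetical pre-transformation constraint: words -> set of allowed indices.
PreConstraint = Callable[[List[str]], Set[int]]


def no_stopwords(stopwords: Set[str]) -> PreConstraint:
    """Allow only indices whose word is not a stopword."""
    def allowed(words: List[str]) -> Set[int]:
        return {i for i, w in enumerate(words) if w.lower() not in stopwords}
    return allowed


def not_already_modified(modified: Set[int]) -> PreConstraint:
    """Allow only indices that have not been modified yet."""
    def allowed(words: List[str]) -> Set[int]:
        return set(range(len(words))) - modified
    return allowed


def indices_to_order(words: List[str], constraints: List[PreConstraint]) -> Tuple[int, List[int]]:
    """Intersect the indices permitted by every constraint."""
    allowed = set(range(len(words)))
    for constraint in constraints:
        allowed &= constraint(words)
    ordered = sorted(allowed)
    return len(ordered), ordered


words = "I really enjoyed the new movie".split()
count, indices = indices_to_order(
    words, [no_stopwords({"i", "the"}), not_already_modified({2})]
)
```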
- get_transformations(current_text, original_text=None, **kwargs)
Applies self.transformation to current_text, then filters the list of possible transformations through the applicable constraints.
- Parameters:
  - current_text – The current AttackedText on which to perform the transformations.
  - original_text – The original AttackedText from which the attack started.
- Returns:
  A filtered list of transformations where each transformation matches the constraints.
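A search method typically calls get_transformations on the current text and then uses the goal function to choose among the surviving candidates. The toy sketch below shows that loop with a greedy strategy. Everything in it is hypothetical:

- a small synonym table stands in for the transformation;
- a length check stands in for the constraints;
- a keyword count stands in for the goal function score.

It illustrates how the components interact and is not TextAttack's GreedySearch.

```python
from typing import Callable, Dict, List

# Hypothetical stand-in for a transformation: a tiny synonym table.
SYNONYMS: Dict[str, List[str]] = {
    "enjoyed": ["liked", "tolerated"],
    "great": ["fine", "okay"],
}

# Hypothetical constraint signature: (candidate, current, original) -> bool.
ConstraintFn = Callable[[str, str, str], bool]


def transform(text: str) -> List[str]:
    """Produce candidates by replacing one word with each of its listed synonyms."""
    words = text.split()
    candidates = []
    for i, word in enumerate(words):
        for synonym in SYNONYMS.get(word, []):
            candidates.append(" ".join(words[:i] + [synonym] + words[i + 1:]))
    return candidates


def get_transformations(current: str, original: str, constraints: List[ConstraintFn]) -> List[str]:
    """Transform, then keep only candidates that satisfy every constraint."""
    return [t for t in transform(current) if all(c(t, current, original) for c in constraints)]


def greedy_search(original: str, score: Callable[[str], float], constraints: List[ConstraintFn],
                  threshold: float, max_steps: int = 5) -> str:
    """Repeatedly take the best-scoring valid candidate until the goal is met."""
    current = original
    for _ in range(max_steps):
        if score(current) >= threshold:
            break  # goal reached
        candidates = get_transformations(current, original, constraints)
        if not candidates:
            break  # nothing valid left to try
        best = max(candidates, key=score)
        if score(best) <= score(current):
            break  # no candidate improves the score
        current = best
    return current


def toy_score(text: str) -> float:
    """Hypothetical goal score: counts lukewarm words."""
    return sum(1 for word in text.split() if word in {"tolerated", "okay"})


def same_length(candidate: str, current: str, original: str) -> bool:
    """Hypothetical constraint: keep the original word count."""
    return len(candidate.split()) == len(original.split())


result = greedy_search("I enjoyed the great movie", toy_score, [same_length], threshold=2)
```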
AttackArgs Class
- class textattack.attack_args.AttackArgs(num_examples: int = 10, num_successful_examples: int | None = None, num_examples_offset: int = 0, attack_n: bool = False, shuffle: bool = False, query_budget: int | None = None, checkpoint_interval: int | None = None, checkpoint_dir: str = 'checkpoints', random_seed: int = 765, parallel: bool = False, num_workers_per_device: int = 1, log_to_txt: str | None = None, log_to_csv: str | None = None, log_summary_to_json: str | None = None, csv_coloring_style: str = 'file', log_to_visdom: dict | None = None, log_to_wandb: dict | None = None, disable_stdout: bool = False, silent: bool = False, enable_advance_metrics: bool = False, metrics: Dict | None = None)
Bases: object
Attack arguments to be passed to Attacker.
- Parameters:
  - num_examples (int, optional, defaults to 10) – The number of examples to attack. Use -1 for the entire dataset.
  - num_successful_examples (int, optional, defaults to None) – The number of successful adversarial examples we want. This differs from num_examples: num_examples only requires attacking N samples, while num_successful_examples keeps attacking until N attacks have succeeded. Note: If set, this argument overrides the num_examples argument.
  - num_examples_offset (int, optional, defaults to 0) – The offset index to start at in the dataset.
  - attack_n (bool, optional, defaults to False) – Whether to run the attack until a total of N examples have been attacked (and not skipped).
  - shuffle (bool, optional, defaults to False) – If True, randomly shuffle the dataset before attacking. Rather than shuffling the dataset internally, this shuffles the list of indices of the examples to attack, which means shuffle can be used with checkpoint saving.
  - query_budget (int, optional, defaults to None) – The maximum number of model queries allowed per example attacked. If not set, the query budget set in the GoalFunction object is used (by default float("inf")). Note: Setting this overwrites the query budget set in the GoalFunction object.
  - checkpoint_interval (int, optional, defaults to None) – If set, a checkpoint is saved after attacking every N examples. If None is passed, no checkpoints are saved.
  - checkpoint_dir (str, optional, defaults to "checkpoints") – The directory to save checkpoint files.
  - random_seed (int, optional, defaults to 765) – Random seed for reproducibility.
  - parallel (bool, optional, defaults to False) – If True, run the attack using multiple CPUs/GPUs.
  - num_workers_per_device (int, optional, defaults to 1) – Number of worker processes to run per device in parallel mode (i.e. parallel=True). For example, if you are using GPUs and num_workers_per_device=2, then 2 processes will run on each GPU.
  - log_to_txt (str, optional, defaults to None) – If set, save attack logs as a .txt file to the directory specified by this argument. If the last part of the provided path ends with the .txt extension, it is assumed to be the desired path of the log file.
  - log_to_csv (str, optional, defaults to None) – If set, save attack logs as a CSV file to the directory specified by this argument. If the last part of the provided path ends with the .csv extension, it is assumed to be the desired path of the log file.
  - csv_coloring_style (str, optional, defaults to "file") – Method for choosing how to mark perturbed parts of the text. Options are "file", "plain", and "html". "file" wraps perturbed parts with double brackets [[ <text> ]] while "plain" does not mark the text in any way.
  - log_to_visdom (dict, optional, defaults to None) – If set, the Visdom logger is used, with the provided dictionary passed as keyword arguments to VisdomLogger. Pass in an empty dictionary to use default arguments. For a custom logger, the dictionary should have the following three keys and their corresponding values: "env", "port", "hostname".
  - log_to_wandb (dict, optional, defaults to None) – If set, the WandB logger is used, with the provided dictionary passed as keyword arguments to WeightsAndBiasesLogger. Pass in an empty dictionary to use default arguments. For a custom logger, the dictionary should have the following key and its corresponding value: "project".
  - disable_stdout (bool, optional, defaults to False) – Disable displaying individual attack results to stdout.
  - silent (bool, optional, defaults to False) – Disable all logging (except for errors). This is stronger than disable_stdout.
  - enable_advance_metrics (bool, optional, defaults to False) – Enable calculation and display of optional advanced post-hoc metrics like perplexity, grammar errors, etc.
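num_examples, num_successful_examples, and attack_n together decide when attacking stops. The function below models one reading of the descriptions above. It assumes each example ends as "success", "failed", or "skipped". consumed_outcomes and those labels are hypothetical, and the real Attacker may differ in details.

```python
from typing import Iterable, List, Optional


def consumed_outcomes(outcomes: Iterable[str], num_examples: int = 10,
                      num_successful_examples: Optional[int] = None,
                      attack_n: bool = False) -> List[str]:
    """Return the per-example outcomes processed before the modeled attacker stops.

    Outcomes use hypothetical labels: "success", "failed", or "skipped".
    """
    consumed: List[str] = []
    successes = 0
    counted = 0
    for outcome in outcomes:
        consumed.append(outcome)
        if outcome == "success":
            successes += 1
        # Assumption: with attack_n, skipped examples do not count toward num_examples.
        if not (attack_n and outcome == "skipped"):
            counted += 1
        if num_successful_examples is not None:
            # num_successful_examples overrides num_examples.
            if successes >= num_successful_examples:
                break
        elif num_examples != -1 and counted >= num_examples:
            break  # -1 means "attack the entire dataset"
    return consumed


outcomes = ["skipped", "success", "failed", "skipped", "success", "success"]
```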
- classmethod create_loggers_from_args(args)
Creates an AttackLogManager from an AttackArgs object.
- attack_n: bool = False
- checkpoint_dir: str = 'checkpoints'
- checkpoint_interval: int = None
- csv_coloring_style: str = 'file'
- disable_stdout: bool = False
- enable_advance_metrics: bool = False
- log_summary_to_json: str = None
- log_to_csv: str = None
- log_to_txt: str = None
- log_to_visdom: dict = None
- log_to_wandb: dict = None
- metrics: Dict | None = None
- num_examples: int = 10
- num_examples_offset: int = 0
- num_successful_examples: int = None
- num_workers_per_device: int = 1
- parallel: bool = False
- query_budget: int = None
- random_seed: int = 765
- shuffle: bool = False
- silent: bool = False
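The "file" csv_coloring_style wraps perturbed text in double brackets, while "plain" leaves it unmarked. The sketch below shows that marking. It assumes perturbed positions are known as word indices and that each perturbed word is wrapped individually. mark_perturbed is a hypothetical helper. The "html" style is not modeled because its output is not described here.

```python
from typing import Iterable, List


def mark_perturbed(words: List[str], changed: Iterable[int], style: str = "file") -> str:
    """Join words, marking perturbed ones according to a coloring style."""
    changed_set = set(changed)
    if style == "plain":
        return " ".join(words)  # no marking at all
    if style == "file":
        # Assumption: each perturbed word gets its own [[ ]] pair.
        return " ".join(f"[[{w}]]" if i in changed_set else w for i, w in enumerate(words))
    raise ValueError(f"style not modeled in this sketch: {style!r}")


marked = mark_perturbed(["I", "liked", "the", "film"], [1, 3])
```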
- class textattack.attack_args.CommandLineAttackArgs(model: str = None, model_from_file: str = None, model_from_huggingface: str = None, dataset_by_model: str = None, dataset_from_huggingface: str = None, dataset_from_file: str = None, dataset_split: str = None, filter_by_labels: list = None, transformation: str = 'word-swap-embedding', constraints: list = <factory>, goal_function: str = 'untargeted-classification', search_method: str = 'greedy-word-wir', attack_recipe: str = None, attack_from_file: str = None, interactive: bool = False, parallel: bool = False, model_batch_size: int = 32, model_cache_size: int = 262144, constraint_cache_size: int = 262144, num_examples: int = 10, num_successful_examples: int = None, num_examples_offset: int = 0, attack_n: bool = False, shuffle: bool = False, query_budget: int = None, checkpoint_interval: int = None, checkpoint_dir: str = 'checkpoints', random_seed: int = 765, num_workers_per_device: int = 1, log_to_txt: str = None, log_to_csv: str = None, log_summary_to_json: str = None, csv_coloring_style: str = 'file', log_to_visdom: dict = None, log_to_wandb: dict = None, disable_stdout: bool = False, silent: bool = False, enable_advance_metrics: bool = False, metrics: Union[Dict, NoneType] = None)
Bases: AttackArgs, _CommandLineAttackArgs, DatasetArgs, ModelArgs
Attacker Class
- class textattack.attacker.Attacker(attack, dataset, attack_args=None)
Bases: object
Class for running attacks on a dataset with specified parameters. This class uses the Attack to actually run the attacks. It also provides:
  - parallel processing;
  - saving/resuming from a checkpoint;
  - logging to files and stdout.
- Parameters:
  - attack (Attack) – Attack used to actually carry out the attack.
  - dataset (Dataset) – Dataset to attack.
  - attack_args (AttackArgs) – Arguments for attacking the dataset. For default settings, see the AttackArgs class.
Example:
>>> import textattack
>>> import transformers
>>> model = transformers.AutoModelForSequenceClassification.from_pretrained("textattack/bert-base-uncased-imdb")
>>> tokenizer = transformers.AutoTokenizer.from_pretrained("textattack/bert-base-uncased-imdb")
>>> model_wrapper = textattack.models.wrappers.HuggingFaceModelWrapper(model, tokenizer)
>>> attack = textattack.attack_recipes.TextFoolerJin2019.build(model_wrapper)
>>> dataset = textattack.datasets.HuggingFaceDataset("imdb", split="test")
>>> # Attack 20 samples with CSV logging and a checkpoint saved every 5 examples
>>> attack_args = textattack.AttackArgs(
...     num_examples=20,
...     log_to_csv="log.csv",
...     checkpoint_interval=5,
...     checkpoint_dir="checkpoints",
...     disable_stdout=True
... )
>>> attacker = textattack.Attacker(attack, dataset, attack_args)
>>> attacker.attack_dataset()
- attack_dataset()
Attack the dataset.
- Returns:
  list[AttackResult] – List of AttackResult obtained after attacking the given dataset.
- classmethod from_checkpoint(attack, dataset, checkpoint)
Resume attacking from a saved checkpoint. The attack and dataset must be reconstructed by the user, while the attack args are loaded from the saved checkpoint.
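shuffle permutes a list of example indices instead of the dataset itself. A resumed run can therefore rebuild exactly the same order and skip what was already attacked. The sketch below makes two assumptions:

- a checkpoint only needs to record how many examples were processed;
- the offset is applied before shuffling.

attack_order and resume_order are hypothetical. The default seed mirrors the documented random_seed of 765.

```python
import random
from typing import List


def attack_order(dataset_size: int, offset: int = 0, shuffle: bool = False, seed: int = 765) -> List[int]:
    """Indices to attack, starting at `offset`, optionally shuffled reproducibly."""
    indices = list(range(offset, dataset_size))
    if shuffle:
        # A dedicated RNG with a fixed seed makes the permutation repeatable.
        random.Random(seed).shuffle(indices)
    return indices


def resume_order(dataset_size: int, num_done: int, offset: int = 0,
                 shuffle: bool = False, seed: int = 765) -> List[int]:
    """Rebuild the same order with the same seed, then skip what was already attacked."""
    return attack_order(dataset_size, offset, shuffle, seed)[num_done:]


full = attack_order(100, shuffle=True)
resumed = resume_order(100, num_done=40, shuffle=True)
```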
- update_attack_args(**kwargs)
To update any attack args, pass the new values as keyword arguments to this function.
Examples:
>>> attacker = ...  # some instance of Attacker
>>> # To switch to parallel mode and increase checkpoint interval from 100 to 500
>>> attacker.update_attack_args(parallel=True, checkpoint_interval=500)
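update_attack_args accepts keyword arguments naming the fields to change. The sketch below shows one way such an update can work, assuming unknown field names should be rejected. MiniAttackArgs is a hypothetical stand-in holding a few AttackArgs fields with their documented defaults. update_args is not the Attacker method itself.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class MiniAttackArgs:
    """Hypothetical stand-in holding a few AttackArgs fields and their documented defaults."""
    num_examples: int = 10
    parallel: bool = False
    checkpoint_interval: Optional[int] = None


def update_args(args: MiniAttackArgs, **kwargs) -> MiniAttackArgs:
    """Set each named field in place, rejecting names the args object does not define."""
    for name, value in kwargs.items():
        if not hasattr(args, name):
            raise ValueError(f"unknown attack argument: {name!r}")
        setattr(args, name, value)
    return args


args = MiniAttackArgs(checkpoint_interval=100)
update_args(args, parallel=True, checkpoint_interval=500)
```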
- textattack.attacker.attack_from_queue(attack, attack_args, num_gpus, first_to_start, lock, in_queue, out_queue)
AugmenterArgs Class
- class textattack.augment_args.AugmenterArgs(input_csv: str, output_csv: str, input_column: str, recipe: str = 'embedding', pct_words_to_swap: float = 0.1, transformations_per_example: int = 2, random_seed: int = 42, exclude_original: bool = False, overwrite: bool = False, interactive: bool = False, fast_augment: bool = False, high_yield: bool = False, enable_advanced_metrics: bool = False)
Bases: object
Arguments for performing data augmentation.
- Parameters:
  - input_csv (str) – Path of input CSV file to augment.
  - output_csv (str) – Path of CSV file to output augmented data.
- enable_advanced_metrics: bool = False
- exclude_original: bool = False
- fast_augment: bool = False
- high_yield: bool = False
- input_column: str
- input_csv: str
- interactive: bool = False
- output_csv: str
- overwrite: bool = False
- pct_words_to_swap: float = 0.1
- random_seed: int = 42
- recipe: str = 'embedding'
- transformations_per_example: int = 2
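pct_words_to_swap controls how strongly each augmented text differs from its input, and transformations_per_example controls how many variants are produced. The sketch below is minimal and rests on these assumptions:

- at least one word is always swapped, computed as max(1, int(pct_words_to_swap * num_words));
- a hypothetical synonym table stands in for a real recipe such as 'embedding';
- the original text is excluded from the output.

augment is not TextAttack's Augmenter. The seed mirrors the documented random_seed default of 42.

```python
import random
from typing import Dict, List


def augment(text: str, synonyms: Dict[str, List[str]], pct_words_to_swap: float = 0.1,
            transformations_per_example: int = 2, seed: int = 42) -> List[str]:
    """Return up to `transformations_per_example` distinct augmented versions of `text`."""
    rng = random.Random(seed)
    words = text.split()
    swappable = [i for i, w in enumerate(words) if w in synonyms]
    # Assumption: swap at least one word, but never more than can be swapped.
    num_to_swap = min(max(1, int(pct_words_to_swap * len(words))), len(swappable))
    results = set()
    attempts = 0
    while len(results) < transformations_per_example and attempts < 100:
        attempts += 1
        new_words = list(words)
        for i in rng.sample(swappable, num_to_swap):
            new_words[i] = rng.choice(synonyms[words[i]])
        candidate = " ".join(new_words)
        if candidate != text:  # exclude the original text
            results.add(candidate)
    return sorted(results)


synonyms = {"good": ["fine", "nice"], "film": ["movie", "picture"]}  # hypothetical table
augmented = augment("a good film", synonyms)
```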
DatasetArgs Class
- class textattack.dataset_args.DatasetArgs(dataset_by_model: str | None = None, dataset_from_huggingface: str | None = None, dataset_from_file: str | None = None, dataset_split: str | None = None, filter_by_labels: list | None = None)
Bases: object
Arguments for loading a dataset from command line input.
- dataset_by_model: str = None
- dataset_from_file: str = None
- dataset_from_huggingface: str = None
- dataset_split: str = None
- filter_by_labels: list = None
ModelArgs Class
- class textattack.model_args.ModelArgs(model: str | None = None, model_from_file: str | None = None, model_from_huggingface: str | None = None)
Bases: object
Arguments for loading base/pretrained or trained models.
- model: str = None
- model_from_file: str = None
- model_from_huggingface: str = None
Trainer Class
- class textattack.trainer.Trainer(model_wrapper, task_type='classification', attack=None, train_dataset=None, eval_dataset=None, training_args=None)
Bases: object
Trainer is a training and evaluation loop for adversarial training.
It is designed to work with PyTorch and Transformers models.
- Parameters:
  - model_wrapper (ModelWrapper) – Model wrapper containing both the model and the tokenizer.
  - task_type (str, optional, defaults to "classification") – The task that the model is trained to perform. Currently, Trainer supports two tasks: (1) "classification", (2) "regression".
  - attack (Attack) – Attack used to generate adversarial examples for training.
  - train_dataset (Dataset) – Dataset for training.
  - eval_dataset (Dataset) – Dataset for evaluation.
  - training_args (TrainingArgs) – Arguments for training.
Example:
>>> import textattack
>>> import transformers
>>> model = transformers.AutoModelForSequenceClassification.from_pretrained("bert-base-uncased")
>>> tokenizer = transformers.AutoTokenizer.from_pretrained("bert-base-uncased")
>>> model_wrapper = textattack.models.wrappers.HuggingFaceModelWrapper(model, tokenizer)
>>> # We only use DeepWordBugGao2018 for demonstration purposes.
>>> attack = textattack.attack_recipes.DeepWordBugGao2018.build(model_wrapper)
>>> train_dataset = textattack.datasets.HuggingFaceDataset("imdb", split="train")
>>> eval_dataset = textattack.datasets.HuggingFaceDataset("imdb", split="test")
>>> # Train for 3 epochs with 1 initial clean epoch, 1000 adversarial examples per epoch, learning rate of 5e-5, and effective batch size of 32 (8x4).
>>> training_args = textattack.TrainingArgs(
...     num_epochs=3,
...     num_clean_epochs=1,
...     num_train_adv_examples=1000,
...     learning_rate=5e-5,
...     per_device_train_batch_size=8,
...     gradient_accumulation_steps=4,
...     log_to_tb=True,
... )
>>> trainer = textattack.Trainer(
...     model_wrapper,
...     "classification",
...     attack,
...     train_dataset,
...     eval_dataset,
...     training_args
... )
>>> trainer.train()
Note
When using Trainer with parallel=True in TrainingArgs, make sure to protect the “entry point” of the program by using if __name__ == '__main__':. If not, each worker process used for generating adversarial examples will execute the training code again.
- evaluate_step(model, tokenizer, batch)
Perform a single evaluation step on a batch of inputs.
- Parameters:
  - model (torch.nn.Module) – Model to evaluate.
  - tokenizer – Tokenizer used to tokenize input text.
  - batch (tuple[list[str], torch.Tensor]) – By default, this will be a tuple of input texts and target tensors.
Note
If you override the get_eval_dataloader() method, then the shape/type of batch will depend on how you created your batch.
- Returns:
  tuple[torch.Tensor, torch.Tensor] where
  - preds: torch.FloatTensor of the model’s predictions for the batch.
  - targets: torch.Tensor of the model’s targets (e.g. labels, target values).
- get_eval_dataloader(dataset, batch_size)
Returns the torch.utils.data.DataLoader for evaluation.
- Parameters:
  - dataset (Dataset) – Dataset to use for evaluation.
  - batch_size (int) – Batch size for evaluation.
- Returns:
  torch.utils.data.DataLoader
- get_optimizer_and_scheduler(model, num_training_steps)
Returns the optimizer and scheduler to use for training. If you are overriding this method and do not want to use a scheduler, simply return None for the scheduler.
- Parameters:
  - model (torch.nn.Module) – Model to be trained. Pass its parameters to the optimizer for training.
  - num_training_steps (int) – Number of total training steps.
- Returns:
  Tuple of optimizer and scheduler
  tuple[torch.optim.Optimizer, torch.optim.lr_scheduler._LRScheduler]
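TrainingArgs documents num_warmup_steps for the warmup phase of a linear scheduler. It is either an absolute step count or, for a float between 0 and 1, resolved as math.ceil(num_training_steps * num_warmup_steps). The sketch below does two things:

- resolves that value;
- computes a learning-rate multiplier that rises linearly during warmup and then decays linearly to zero.

This is the same shape as transformers.get_linear_schedule_with_warmup. That the default get_optimizer_and_scheduler uses exactly this schedule is an assumption here, and both function names are hypothetical.

```python
import math
from typing import Union


def resolve_warmup_steps(num_warmup_steps: Union[int, float], num_training_steps: int) -> int:
    """Interpret a float in (0, 1) as a fraction of the total training steps."""
    if isinstance(num_warmup_steps, float) and 0 < num_warmup_steps < 1:
        return math.ceil(num_training_steps * num_warmup_steps)
    return int(num_warmup_steps)


def linear_lr_multiplier(step: int, warmup_steps: int, num_training_steps: int) -> float:
    """Factor applied to the base learning rate at a given optimizer step."""
    if step < warmup_steps:
        # Ramp up linearly from 0 to 1 during warmup.
        return step / max(1, warmup_steps)
    # Then decay linearly to 0 at the final training step.
    return max(0.0, (num_training_steps - step) / max(1, num_training_steps - warmup_steps))


warmup = resolve_warmup_steps(0.1, 1000)
lr_at_step_550 = 5e-5 * linear_lr_multiplier(550, warmup, 1000)
```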
- get_train_dataloader(dataset, adv_dataset, batch_size)
Returns the torch.utils.data.DataLoader for training.
- training_step(model, tokenizer, batch)
Perform a single training step on a batch of inputs.
- Parameters:
  - model (torch.nn.Module) – Model to train.
  - tokenizer – Tokenizer used to tokenize input text.
  - batch (tuple[list[str], torch.Tensor, torch.Tensor]) – By default, this will be a tuple of input texts, targets, and a boolean tensor indicating whether each sample is an adversarial example.
Note
If you override the get_train_dataloader() method, then the shape/type of batch will depend on how you created your batch.
- Returns:
  tuple[torch.Tensor, torch.Tensor, torch.Tensor] where
  - loss: torch.FloatTensor of shape 1 containing the loss.
  - preds: torch.FloatTensor of the model’s predictions for the batch.
  - targets: torch.Tensor of the model’s targets (e.g. labels, target values).
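training_step receives a boolean flag marking adversarial samples, and TrainingArgs exposes alpha as the weight for the adversarial loss. The sketch below shows one plausible way the two combine. It assumes per-sample losses are scaled by alpha for adversarial samples and then averaged. It uses plain Python floats instead of torch tensors, and weighted_batch_loss is a hypothetical name.

```python
from typing import List


def weighted_batch_loss(per_sample_losses: List[float], is_adv: List[bool], alpha: float = 1.0) -> float:
    """Average per-sample losses, scaling adversarial samples by alpha (assumed scheme)."""
    if len(per_sample_losses) != len(is_adv):
        raise ValueError("losses and adversarial flags must have the same length")
    if not per_sample_losses:
        raise ValueError("batch must not be empty")
    weights = [alpha if adv else 1.0 for adv in is_adv]
    weighted = [w * loss for w, loss in zip(weights, per_sample_losses)]
    return sum(weighted) / len(weighted)


loss = weighted_batch_loss([1.0, 2.0, 3.0], [False, True, False], alpha=0.5)
```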
TrainingArgs Class
- class textattack.training_args.CommandLineTrainingArgs(model_name_or_path: str, attack: str, dataset: str, task_type: str = 'classification', model_max_length: int = None, model_num_labels: int = None, dataset_train_split: str = None, dataset_eval_split: str = None, filter_train_by_labels: list = None, filter_eval_by_labels: list = None, num_epochs: int = 3, num_clean_epochs: int = 1, attack_epoch_interval: int = 1, early_stopping_epochs: int = None, learning_rate: float = 5e-05, num_warmup_steps: Union[int, float] = 500, weight_decay: float = 0.01, per_device_train_batch_size: int = 8, per_device_eval_batch_size: int = 32, gradient_accumulation_steps: int = 1, random_seed: int = 786, parallel: bool = False, load_best_model_at_end: bool = False, alpha: float = 1.0, num_train_adv_examples: Union[int, float] = -1, query_budget_train: int = None, attack_num_workers_per_device: int = 1, output_dir: str = <factory>, checkpoint_interval_steps: int = None, checkpoint_interval_epochs: int = None, save_last: bool = True, log_to_tb: bool = False, tb_log_dir: str = None, log_to_wandb: bool = False, wandb_project: str = 'textattack', logging_interval_step: int = 1)
Bases: TrainingArgs, _CommandLineTrainingArgs
- output_dir: str
- class textattack.training_args.TrainingArgs(num_epochs: int = 3, num_clean_epochs: int = 1, attack_epoch_interval: int = 1, early_stopping_epochs: int | None = None, learning_rate: float = 5e-05, num_warmup_steps: int | float = 500, weight_decay: float = 0.01, per_device_train_batch_size: int = 8, per_device_eval_batch_size: int = 32, gradient_accumulation_steps: int = 1, random_seed: int = 786, parallel: bool = False, load_best_model_at_end: bool = False, alpha: float = 1.0, num_train_adv_examples: int | float = -1, query_budget_train: int | None = None, attack_num_workers_per_device: int = 1, output_dir: str = <factory>, checkpoint_interval_steps: int | None = None, checkpoint_interval_epochs: int | None = None, save_last: bool = True, log_to_tb: bool = False, tb_log_dir: str | None = None, log_to_wandb: bool = False, wandb_project: str = 'textattack', logging_interval_step: int = 1)
Bases: object
Arguments for the Trainer class that is used for adversarial training.
Parameters:
- num_epochs (int, optional, defaults to 3) – Total number of epochs for training.
- num_clean_epochs (int, optional, defaults to 1) – Number of epochs to train on only the original training dataset before adversarial training begins.
- attack_epoch_interval (int, optional, defaults to 1) – Generate a new adversarial training set every N epochs.
- early_stopping_epochs (int, optional, defaults to None) – Number of epochs without an increase in validation score before training stops early (None for no early stopping).
- learning_rate (float, optional, defaults to 5e-5) – Learning rate for the optimizer.
- num_warmup_steps (int or float, optional, defaults to 500) – The number of steps for the warmup phase of the linear scheduler. If num_warmup_steps is a float between 0 and 1, the number of warmup steps is math.ceil(num_training_steps * num_warmup_steps).
- weight_decay (float, optional, defaults to 0.01) – Weight decay (L2 penalty).
- per_device_train_batch_size (int, optional, defaults to 8) – The batch size per GPU/CPU for training.
- per_device_eval_batch_size (int, optional, defaults to 32) – The batch size per GPU/CPU for evaluation.
- gradient_accumulation_steps (int, optional, defaults to 1) – Number of steps over which gradients are accumulated before the optimizer performs an update.
- random_seed (int, optional, defaults to 786) – Random seed for reproducibility.
- parallel (bool, optional, defaults to False) – If True, train on multiple GPUs using torch.nn.DataParallel.
- load_best_model_at_end (bool, optional, defaults to False) – If True, keep track of the best model across training and load it at the end.
- alpha (float, optional, defaults to 1.0) – The weight for the adversarial loss.
- num_train_adv_examples (int or float, optional, defaults to -1) – The number of samples to successfully attack each time the adversarial training set is generated. If num_train_adv_examples is a float between 0 and 1, the number of adversarial examples generated is that fraction of the original training set.
- query_budget_train (int, optional, defaults to None) – The maximum query budget to use when generating the adversarial training set. None means an unlimited query budget.
- attack_num_workers_per_device (int, optional, defaults to 1) – Number of worker processes to run per device for the attack. Same as the num_workers_per_device argument of AttackArgs.
- output_dir (str, optional) – Directory to output training logs and checkpoints. Defaults to a timestamped directory of the form /outputs/%Y-%m-%d-%H-%M-%S-%f.
- checkpoint_interval_steps (int, optional, defaults to None) – If set, save a model checkpoint after every N updates to the model.
- checkpoint_interval_epochs (int, optional, defaults to None) – If set, save a model checkpoint after every N epochs.
- save_last (bool, optional, defaults to True) – If True, save the model at the end of training. Can be used with load_best_model_at_end to save the best model at the end.
- log_to_tb (bool, optional, defaults to False) – If True, log to TensorBoard.
- tb_log_dir (str, optional, defaults to "./runs") – Path of the TensorBoard log directory.
- log_to_wandb (bool, optional, defaults to False) – If True, log to Weights & Biases (wandb).
- wandb_project (str, optional, defaults to "textattack") – Name of the wandb project for logging.
- logging_interval_step (int, optional, defaults to 1) – Log to TensorBoard/wandb every N training steps.
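The interaction between num_clean_epochs, attack_epoch_interval, and early_stopping_epochs is easiest to see as a per-epoch schedule. The following is a minimal sketch under stated assumptions, not TextAttack's implementation: the hypothetical `plan_epochs` assumes the first adversarial epoch always regenerates the adversarial set and that later regenerations happen every attack_epoch_interval epochs, while the hypothetical `should_stop_early` assumes training stops once the best validation score is early_stopping_epochs epochs old.

```python
from typing import List, Optional


def plan_epochs(num_epochs: int = 3, num_clean_epochs: int = 1,
                attack_epoch_interval: int = 1) -> List[str]:
    """Label each 1-based epoch as "clean", "attack", or "reuse".

    Assumed schedule (hypothetical): epochs 1..num_clean_epochs use only the
    original data; the first epoch after that generates a fresh adversarial
    set, as does every attack_epoch_interval-th epoch after it; the epochs
    in between reuse the most recent adversarial set.
    """
    plan = []
    for epoch in range(1, num_epochs + 1):
        if epoch <= num_clean_epochs:
            plan.append("clean")
        elif (epoch - num_clean_epochs - 1) % attack_epoch_interval == 0:
            plan.append("attack")
        else:
            plan.append("reuse")
    return plan


def should_stop_early(eval_scores: List[float],
                      early_stopping_epochs: Optional[int]) -> bool:
    """Hypothetical patience check over per-epoch validation scores.

    Returns True once early_stopping_epochs epochs have passed without the
    score rising above its best value so far. None disables early stopping.
    """
    if early_stopping_epochs is None or not eval_scores:
        return False
    # max() keeps the first index on ties, so an equal score does not
    # count as an increase.
    best_epoch = max(range(len(eval_scores)), key=eval_scores.__getitem__)
    epochs_since_best = len(eval_scores) - 1 - best_epoch
    return epochs_since_best >= early_stopping_epochs


print(plan_epochs())  # ['clean', 'attack', 'attack']
print(plan_epochs(num_epochs=6, attack_epoch_interval=2))
# ['clean', 'attack', 'reuse', 'attack', 'reuse', 'attack']
print(should_stop_early([0.80, 0.82, 0.81, 0.79], early_stopping_epochs=2))  # True
```

Whether the library counts the patience window with `>=` or `>` is an assumption here; check the Trainer source before relying on the exact boundary.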
- alpha: float = 1.0
- attack_epoch_interval: int = 1
- attack_num_workers_per_device: int = 1
- checkpoint_interval_epochs: int = None
- checkpoint_interval_steps: int = None
- early_stopping_epochs: int = None
- gradient_accumulation_steps: int = 1
- learning_rate: float = 5e-05
- load_best_model_at_end: bool = False
- log_to_tb: bool = False
- log_to_wandb: bool = False
- logging_interval_step: int = 1
- num_clean_epochs: int = 1
- num_epochs: int = 3
- num_train_adv_examples: int | float = -1
- num_warmup_steps: int | float = 500
- output_dir: str
- parallel: bool = False
- per_device_eval_batch_size: int = 32
- per_device_train_batch_size: int = 8
- query_budget_train: int = None
- random_seed: int = 786
- save_last: bool = True
- tb_log_dir: str = None
- wandb_project: str = 'textattack'
- weight_decay: float = 0.01
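Several of these attributes accept either an absolute count or a fraction, and batch size interacts with gradient accumulation. The helpers below are hypothetical and only sketch the arithmetic: `resolve_warmup_steps` follows the documented `math.ceil(num_training_steps * num_warmup_steps)` rule; `resolve_num_adv_examples` assumes the fraction is also rounded up, which the docs do not state; and `effective_batch_size` assumes data-parallel training in which every device processes a full per-device batch per step.

```python
import math
from typing import Union


def resolve_warmup_steps(num_warmup_steps: Union[int, float],
                         num_training_steps: int) -> int:
    """Turn num_warmup_steps into an absolute step count.

    A float in [0, 1] is read as a fraction of num_training_steps and rounded
    up with math.ceil, as the parameter description states.
    """
    if isinstance(num_warmup_steps, float):
        if not 0.0 <= num_warmup_steps <= 1.0:
            raise ValueError("a float num_warmup_steps must be between 0 and 1")
        return math.ceil(num_training_steps * num_warmup_steps)
    return num_warmup_steps


def resolve_num_adv_examples(num_train_adv_examples: Union[int, float],
                             train_set_size: int) -> int:
    """Turn a fractional num_train_adv_examples into a sample count.

    Rounding up is an assumption; the docs only say "fraction of the
    original training set". Integers (including the -1 default) pass through.
    """
    if isinstance(num_train_adv_examples, float):
        if not 0.0 <= num_train_adv_examples <= 1.0:
            raise ValueError("a float num_train_adv_examples must be between 0 and 1")
        return math.ceil(train_set_size * num_train_adv_examples)
    return num_train_adv_examples


def effective_batch_size(per_device_train_batch_size: int,
                         gradient_accumulation_steps: int,
                         num_devices: int = 1) -> int:
    """Examples contributing to one optimizer update, assuming every device
    processes a full per-device batch on each accumulation step."""
    return per_device_train_batch_size * gradient_accumulation_steps * num_devices


print(resolve_warmup_steps(500, 10_000))  # 500
print(resolve_warmup_steps(0.1, 1234))  # 124
print(resolve_num_adv_examples(0.25, 1000))  # 250
print(effective_batch_size(8, 4, num_devices=2))  # 64
```

Type, not magnitude, decides the interpretation in this sketch: `1` is one step, while `1.0` is the whole run.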