textattack package
Welcome to the API references for TextAttack!
What is TextAttack?
TextAttack is a Python framework for adversarial attacks, adversarial training, and data augmentation in NLP.
TextAttack makes experimenting with the robustness of NLP models seamless, fast, and easy. It’s also useful for NLP model training, adversarial training, and data augmentation.
TextAttack provides components for common NLP tasks like sentence encoding, grammar-checking, and word replacement that can be used on their own.
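For example, the augmentation components can be used standalone. A minimal sketch (the recipe and input sentence are arbitrary choices):
>>> from textattack.augmentation import EmbeddingAugmenter
>>> augmenter = EmbeddingAugmenter()
>>> augmenter.augment("What I cannot create, I do not understand.")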
- textattack.attack_recipes package
- Attack Recipes Package:
- A2T (A2T: Attack for Adversarial Training Recipe)
A2TYoo2021
- Attack Recipe Class
AttackRecipe
- BAE (BAE: BERT-Based Adversarial Examples)
BAEGarg2019
- BERT-Attack:
BERTAttackLi2020
- CheckList:
CheckList2020
- CLARE Recipe
CLARE2020
- DeepWordBug
DeepWordBugGao2018
- Faster Alzantot Genetic Algorithm
FasterGeneticAlgorithmJia2019
- Alzantot Genetic Algorithm
GeneticAlgorithmAlzantot2018
- HotFlip
HotFlipEbrahimi2017
- Improved Genetic Algorithm
IGAWang2019
- Input Reduction
InputReductionFeng2018
- Kuleshov2017
Kuleshov2017
- MORPHEUS2020
MorpheusTan2020
- Pruthi2019: Combating with Robust Word Recognition
Pruthi2019
- Particle Swarm Optimization
PSOZang2020
- PWWS
PWWSRen2019
- Seq2Sick
Seq2SickCheng2018BlackBox
- TextBugger
TextBuggerLi2018
- TextFooler (Is BERT Really Robust?)
TextFoolerJin2019
- textattack.attack_results package
- textattack.augmentation package
- textattack.commands package
- TextAttack commands Package
- AttackCommand class
AttackCommand
- AttackResumeCommand class
AttackResumeCommand
- AugmentCommand class
AugmentCommand
- BenchmarkRecipeCommand class
BenchmarkRecipeCommand
- EvalModelCommand class
EvalModelCommand
ModelEvalArgs
- ListThingsCommand class
ListThingsCommand
- PeekDatasetCommand class
PeekDatasetCommand
- TextAttack CLI main class
main()
TextAttackCommand
- TrainModelCommand class
TrainModelCommand
- textattack.constraints package
- Constraints
- textattack.constraints.grammaticality package
- Grammaticality:
- textattack.constraints.grammaticality.language_models package
- Language Models:
- textattack.constraints.grammaticality.language_models.google_language_model package
- textattack.constraints.grammaticality.language_models.learning_to_write package
- GPT2 Language Models:
GPT2
- Language Models Constraint
LanguageModelConstraint
- CoLA for Grammaticality
COLA
- LanguageTool Grammar Checker
LanguageTool
- Part of Speech Constraint
PartOfSpeech
- textattack.constraints.overlap package
- textattack.constraints.pre_transformation package
- textattack.constraints.semantics package
- Semantic Constraints
- textattack.constraints.semantics.sentence_encoders package
- Sentence Encoder Constraint
- textattack.constraints.semantics.sentence_encoders.bert package
- textattack.constraints.semantics.sentence_encoders.infer_sent package
- InferSent
- InferSent for sentence similarity
InferSent
- InferSent model
InferSentModel
InferSentModel.build_vocab()
InferSentModel.build_vocab_k_words()
InferSentModel.encode()
InferSentModel.forward()
InferSentModel.get_batch()
InferSentModel.get_w2v()
InferSentModel.get_w2v_k()
InferSentModel.get_word_dict()
InferSentModel.is_cuda()
InferSentModel.prepare_samples()
InferSentModel.set_w2v_path()
InferSentModel.tokenize()
InferSentModel.update_vocab()
InferSentModel.training
- textattack.constraints.semantics.sentence_encoders.universal_sentence_encoder package
- Sentence Encoder Class
SentenceEncoder
get_angular_sim()
get_neg_euclidean_dist()
- Thought Vector Class
ThoughtVector
- BERT Score
BERTScore
- Word Embedding Distance
WordEmbeddingDistance
- TextAttack Constraint Class
Constraint
- Pre-Transformation Constraint Class
PreTransformationConstraint
- textattack.datasets package
- textattack.goal_function_results package
- textattack.goal_functions package
- Goal Functions
- textattack.goal_functions.classification package
- Goal function for Classification
- Determine if an attack has been successful in Classification
ClassificationGoalFunction
- Determine whether the same predicted label is maintained (input reduction)
InputReduction
- Determine if an attack has been successful in targeted Classification
TargetedClassification
- Determine if an attack has been successful in untargeted Classification
UntargetedClassification
- textattack.goal_functions.text package
- GoalFunction Class
GoalFunction
- textattack.loggers package
- Misc Loggers: Loggers track, visualize, and export attack results.
- Managing Attack Logs.
AttackLogManager
AttackLogManager.add_output_csv()
AttackLogManager.add_output_file()
AttackLogManager.add_output_summary_json()
AttackLogManager.disable_color()
AttackLogManager.enable_stdout()
AttackLogManager.enable_visdom()
AttackLogManager.enable_wandb()
AttackLogManager.flush()
AttackLogManager.log_attack_details()
AttackLogManager.log_result()
AttackLogManager.log_results()
AttackLogManager.log_sep()
AttackLogManager.log_summary()
AttackLogManager.log_summary_rows()
AttackLogManager.metrics
- Attack Logs to CSV
CSVLogger
- Attack Logs to file
FileLogger
- Attack Logger Wrapper
Logger
- Attack Logs to Visdom
VisdomLogger
port_is_open()
- Attack Logs to WandB
WeightsAndBiasesLogger
- textattack.metrics package
- metrics package: to calculate advanced metrics for evaluating attacks and augmented text
- textattack.metrics.attack_metrics package
- textattack.metrics.quality_metrics package
- Metric Class
Metric
- textattack.models package
- Models
- textattack.models.helpers package
- textattack.models.tokenizers package
- textattack.models.wrappers package
- textattack.search_methods package
- Search Methods
- Reimplementation of search method from Generating Natural Language Adversarial Examples
AlzantotGeneticAlgorithm
- Beam Search
BeamSearch
- Genetic Algorithm Word Swap
GeneticAlgorithm
- Greedy Search
GreedySearch
- Greedy Word Swap with Word Importance Ranking
GreedyWordSwapWIR
- Reimplementation of search method from Xiaosen Wang, Hao Jin, Kun He (2019).
ImprovedGeneticAlgorithm
- Particle Swarm Optimization
ParticleSwarmOptimization
normalize()
- Population based Search abstract class
PopulationBasedSearch
PopulationMember
- Search Method Abstract Class
SearchMethod
- textattack.shared package
- Shared TextAttack Functions
- textattack.shared.utils package
LazyLoader
load_module_from_file()
download_from_s3()
download_from_url()
http_get()
path_in_cache()
s3_url()
set_cache_dir()
unzip_file()
get_textattack_model_num_labels()
hashable()
html_style_from_dict()
html_table_from_rows()
load_textattack_model_from_path()
set_seed()
sigmoid()
ANSI_ESCAPE_CODES
ANSI_ESCAPE_CODES.BOLD
ANSI_ESCAPE_CODES.BROWN
ANSI_ESCAPE_CODES.CYAN
ANSI_ESCAPE_CODES.FAIL
ANSI_ESCAPE_CODES.GRAY
ANSI_ESCAPE_CODES.HEADER
ANSI_ESCAPE_CODES.OKBLUE
ANSI_ESCAPE_CODES.OKGREEN
ANSI_ESCAPE_CODES.ORANGE
ANSI_ESCAPE_CODES.PINK
ANSI_ESCAPE_CODES.PURPLE
ANSI_ESCAPE_CODES.STOP
ANSI_ESCAPE_CODES.UNDERLINE
ANSI_ESCAPE_CODES.WARNING
ANSI_ESCAPE_CODES.YELLOW
ReprMixin
TextAttackFlairTokenizer
add_indent()
check_if_punctuations()
check_if_subword()
color_from_label()
color_from_output()
color_text()
default_class_repr()
flair_tag()
has_letter()
is_one_word()
process_label_name()
strip_BPE_artifacts()
words_from_text()
zip_flair_result()
zip_stanza_result()
batch_model_predict()
- Attacked Text Class
AttackedText
AttackedText.align_with_model_tokens()
AttackedText.all_words_diff()
AttackedText.convert_from_original_idxs()
AttackedText.delete_word_at_index()
AttackedText.first_word_diff()
AttackedText.first_word_diff_index()
AttackedText.free_memory()
AttackedText.generate_new_attacked_text()
AttackedText.get_deletion_indices()
AttackedText.insert_text_after_word_index()
AttackedText.insert_text_before_word_index()
AttackedText.ith_word_diff()
AttackedText.ner_of_word_index()
AttackedText.pos_of_word_index()
AttackedText.printable_text()
AttackedText.replace_word_at_index()
AttackedText.replace_words_at_indices()
AttackedText.text_after_word_index()
AttackedText.text_until_word_index()
AttackedText.text_window_around_index()
AttackedText.words_diff_num()
AttackedText.words_diff_ratio()
AttackedText.SPLIT_TOKEN
AttackedText.column_labels
AttackedText.newly_swapped_words
AttackedText.num_words
AttackedText.text
AttackedText.tokenizer_input
AttackedText.words
AttackedText.words_per_input
- Misc Checkpoints
AttackCheckpoint
AttackCheckpoint.load()
AttackCheckpoint.save()
AttackCheckpoint.dataset_offset
AttackCheckpoint.datetime
AttackCheckpoint.num_failed_attacks
AttackCheckpoint.num_maximized_attacks
AttackCheckpoint.num_remaining_attacks
AttackCheckpoint.num_skipped_attacks
AttackCheckpoint.num_successful_attacks
AttackCheckpoint.results_count
- Shared data fields
- Misc Validators
transformation_consists_of()
transformation_consists_of_word_swaps()
transformation_consists_of_word_swaps_and_deletions()
validate_model_goal_function_compatibility()
validate_model_gradient_word_swap_compatibility()
- Shared word embeddings and related distances
AbstractWordEmbedding
GensimWordEmbedding
WordEmbedding
- textattack.transformations package
- Transformations
- textattack.transformations.sentence_transformations package
- textattack.transformations.word_insertions package
- textattack.transformations.word_merges package
- textattack.transformations.word_swaps package
- word_swaps package
- Word Swap
WordSwap
- Word Swap by Changing Location
WordSwapChangeLocation
idx_to_words()
- Word Swap by Changing Name
WordSwapChangeName
- Word Swap by Changing Number
WordSwapChangeNumber
idx_to_words()
- Word Swap by Contraction
WordSwapContract
- Word Swap by Embedding
WordSwapEmbedding
recover_word_case()
- Word Swap by Extension
WordSwapExtend
- Word Swap by Gradient
WordSwapGradientBased
- Word Swap by Homoglyph
WordSwapHomoglyphSwap
- Word Swap by OpenHowNet
WordSwapHowNet
recover_word_case()
- Word Swap by inflections
WordSwapInflections
- Word Swap by BERT-Masked LM.
WordSwapMaskedLM
recover_word_case()
- Word Swap by Neighboring Character Swap
WordSwapNeighboringCharacterSwap
- Word Swap by swapping characters with QWERTY-adjacent keys
WordSwapQWERTY
- Word Swap by Random Character Deletion
WordSwapRandomCharacterDeletion
- Word Swap by Random Character Insertion
WordSwapRandomCharacterInsertion
- Word Swap by Random Character Substitution
WordSwapRandomCharacterSubstitution
- Word Swap by swapping synonyms in WordNet
WordSwapWordNet
- Composite Transformation
CompositeTransformation
- Transformation Abstract Class
Transformation
- word deletion Transformation
WordDeletion
- Word Swap Transformation by swapping the order of words
WordInnerSwapRandom
Attack Class
- class textattack.attack.Attack(goal_function: GoalFunction, constraints: List[Constraint | PreTransformationConstraint], transformation: Transformation, search_method: SearchMethod, transformation_cache_size=32768, constraint_cache_size=32768)[source]
Bases:
object
An attack generates adversarial examples on text.
An attack is comprised of a goal function, constraints, a transformation, and a search method. Use the attack() method to attack one sample at a time.
- Parameters:
  - goal_function (GoalFunction) – A function for determining how well a perturbation is doing at achieving the attack’s goal.
  - constraints (list of Constraint or PreTransformationConstraint) – A list of constraints to add to the attack, defining which perturbations are valid.
  - transformation (Transformation) – The transformation applied at each step of the attack.
  - search_method (SearchMethod) – The method for exploring the search space of possible perturbations.
  - transformation_cache_size (int, optional, defaults to 2**15) – The number of items to keep in the transformations cache.
  - constraint_cache_size (int, optional, defaults to 2**15) – The number of items to keep in the constraints cache.
Example:
>>> import textattack
>>> import transformers
>>> # Load model, tokenizer, and model_wrapper
>>> model = transformers.AutoModelForSequenceClassification.from_pretrained("textattack/bert-base-uncased-imdb")
>>> tokenizer = transformers.AutoTokenizer.from_pretrained("textattack/bert-base-uncased-imdb")
>>> model_wrapper = textattack.models.wrappers.HuggingFaceModelWrapper(model, tokenizer)
>>> # Construct our four components for `Attack`
>>> from textattack.constraints.pre_transformation import RepeatModification, StopwordModification
>>> from textattack.constraints.semantics import WordEmbeddingDistance
>>> from textattack.transformations import WordSwapEmbedding
>>> from textattack.search_methods import GreedyWordSwapWIR
>>> goal_function = textattack.goal_functions.UntargetedClassification(model_wrapper)
>>> constraints = [
...     RepeatModification(),
...     StopwordModification(),
...     WordEmbeddingDistance(min_cos_sim=0.9)
... ]
>>> transformation = WordSwapEmbedding(max_candidates=50)
>>> search_method = GreedyWordSwapWIR(wir_method="delete")
>>> # Construct the actual attack
>>> attack = textattack.Attack(goal_function, constraints, transformation, search_method)
>>> input_text = "I really enjoyed the new movie that came out last month."
>>> label = 1  # Positive
>>> attack_result = attack.attack(input_text, label)
- attack(example, ground_truth_output)[source]
Attack a single example.
- Parameters:
  - example (str, OrderedDict[str, str], or AttackedText) – Example to attack. It can be a single string or an OrderedDict where keys represent the input fields (e.g. “premise”, “hypothesis”) and the values are the actual input texts. Also accepts an AttackedText that wraps around the input.
  - ground_truth_output (int, float, or str) – Ground truth output of example. For classification tasks, it should be an integer representing the ground truth label. For regression tasks (e.g. STS), it should be the target value. For seq2seq tasks (e.g. translation), it should be the target string.
- Returns:
  An AttackResult that represents the result of the attack.
- filter_transformations(transformed_texts, current_text, original_text=None)[source]
Filters a list of potential transformed texts based on self.constraints. Utilizes an LRU cache to attempt to avoid recomputing common transformations.
- Parameters:
  - transformed_texts – A list of candidate transformed AttackedText objects to filter.
  - current_text – The current AttackedText on which the transformation was applied.
  - original_text – The original AttackedText from which the attack started.
- get_indices_to_order(current_text, **kwargs)[source]
Applies pre_transformation_constraints to text to get all the indices that can be used to search and order.
- Parameters:
  - current_text – The current AttackedText for which we need to find the indices eligible to be ordered.
- Returns:
  The length and the filtered list of indices which search methods can use to search/order.
- get_transformations(current_text, original_text=None, **kwargs)[source]
Applies self.transformation to text, then filters the list of possible transformations through the applicable constraints.
- Parameters:
  - current_text – The current AttackedText on which to perform the transformations.
  - original_text – The original AttackedText from which the attack started.
- Returns:
  A filtered list of transformations where each transformation matches the constraints.
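For illustration, a minimal sketch of calling get_transformations() directly, reusing the attack object built in the Example above (the input sentence is arbitrary):
>>> text = textattack.shared.AttackedText("I really enjoyed the new movie.")
>>> candidates = attack.get_transformations(text, original_text=text)
>>> # `candidates` is the list of AttackedText perturbations that satisfy the constraints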
AttackArgs Class
- class textattack.attack_args.AttackArgs(num_examples: int = 10, num_successful_examples: int | None = None, num_examples_offset: int = 0, attack_n: bool = False, shuffle: bool = False, query_budget: int | None = None, checkpoint_interval: int | None = None, checkpoint_dir: str = 'checkpoints', random_seed: int = 765, parallel: bool = False, num_workers_per_device: int = 1, log_to_txt: str | None = None, log_to_csv: str | None = None, log_summary_to_json: str | None = None, csv_coloring_style: str = 'file', log_to_visdom: dict | None = None, log_to_wandb: dict | None = None, disable_stdout: bool = False, silent: bool = False, enable_advance_metrics: bool = False, metrics: Dict | None = None)[source]
Bases:
object
Attack arguments to be passed to Attacker.
- Parameters:
  - num_examples (int, optional, defaults to 10) – The number of examples to attack. -1 for the entire dataset.
  - num_successful_examples (int, optional, defaults to None) – The number of successful adversarial examples we want. This is different from num_examples, as num_examples only cares about attacking N samples while num_successful_examples aims to keep attacking until we have N successful cases. Note: if set, this argument overrides the num_examples argument.
  - num_examples_offset (int, optional, defaults to 0) – The offset index to start at in the dataset.
  - attack_n (bool, optional, defaults to False) – Whether to run the attack until a total of N examples have been attacked (and not skipped).
  - shuffle (bool, optional, defaults to False) – If True, we randomly shuffle the dataset before attacking. However, this avoids actually shuffling the dataset internally and opts for shuffling the list of indices of examples we want to attack. This means shuffle can now be used with checkpoint saving.
  - query_budget (int, optional, defaults to None) – The maximum number of model queries allowed per example attacked. If not set, we use the query budget set in the GoalFunction object (which by default is float("inf")). Note: setting this overwrites the query budget set in the GoalFunction object.
  - checkpoint_interval (int, optional, defaults to None) – If set, a checkpoint will be saved after attacking every N examples. If None is passed, no checkpoints will be saved.
  - checkpoint_dir (str, optional, defaults to "checkpoints") – The directory to save checkpoint files.
  - random_seed (int, optional, defaults to 765) – Random seed for reproducibility.
  - parallel (bool, optional, defaults to False) – If True, run the attack using multiple CPUs/GPUs.
  - num_workers_per_device (int, optional, defaults to 1) – Number of worker processes to run per device in parallel mode (i.e. parallel=True). For example, if you are using GPUs and num_workers_per_device=2, then 2 processes will run on each GPU.
  - log_to_txt (str, optional, defaults to None) – If set, save attack logs as a .txt file in the directory specified by this argument. If the last part of the provided path ends with a .txt extension, it is assumed to be the desired path of the log file.
  - log_to_csv (str, optional, defaults to None) – If set, save attack logs as a CSV file in the directory specified by this argument. If the last part of the provided path ends with a .csv extension, it is assumed to be the desired path of the log file.
  - csv_coloring_style (str, optional, defaults to "file") – Method for choosing how to mark perturbed parts of the text. Options are "file", "plain", and "html". "file" wraps perturbed parts with double brackets [[ <text> ]] while "plain" does not mark the text in any way.
  - log_to_visdom (dict, optional, defaults to None) – If set, the Visdom logger is used with the provided dictionary passed as keyword arguments to VisdomLogger. Pass in an empty dictionary to use the default arguments. For a custom logger, the dictionary should have the following three keys and their corresponding values: "env", "port", "hostname".
  - log_to_wandb (dict, optional, defaults to None) – If set, the WandB logger is used with the provided dictionary passed as keyword arguments to WeightsAndBiasesLogger. Pass in an empty dictionary to use the default arguments. For a custom logger, the dictionary should have the following key and its corresponding value: "project".
  - disable_stdout (bool, optional, defaults to False) – Disable displaying individual attack results to stdout.
  - silent (bool, optional, defaults to False) – Disable all logging (except for errors). This is stronger than disable_stdout.
  - enable_advance_metrics (bool, optional, defaults to False) – Enable calculation and display of optional advanced post-hoc metrics like perplexity, grammar errors, etc.
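For illustration, a minimal sketch of constructing AttackArgs (the argument values are arbitrary):
>>> import textattack
>>> attack_args = textattack.AttackArgs(
...     num_examples=-1,           # attack the entire dataset
...     log_to_csv="results.csv",  # save per-example results as CSV
...     checkpoint_interval=100,   # save a checkpoint every 100 examples
...     random_seed=765,
... )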
- classmethod create_loggers_from_args(args)[source]
Creates an AttackLogManager from an AttackArgs object.
- attack_n: bool = False
- checkpoint_dir: str = 'checkpoints'
- checkpoint_interval: int = None
- csv_coloring_style: str = 'file'
- disable_stdout: bool = False
- enable_advance_metrics: bool = False
- log_summary_to_json: str = None
- log_to_csv: str = None
- log_to_txt: str = None
- log_to_visdom: dict = None
- log_to_wandb: dict = None
- metrics: Dict | None = None
- num_examples: int = 10
- num_examples_offset: int = 0
- num_successful_examples: int = None
- num_workers_per_device: int = 1
- parallel: bool = False
- query_budget: int = None
- random_seed: int = 765
- shuffle: bool = False
- silent: bool = False
- class textattack.attack_args.CommandLineAttackArgs(model: str = None, model_from_file: str = None, model_from_huggingface: str = None, dataset_by_model: str = None, dataset_from_huggingface: str = None, dataset_from_file: str = None, dataset_split: str = None, filter_by_labels: list = None, transformation: str = 'word-swap-embedding', constraints: list = <factory>, goal_function: str = 'untargeted-classification', search_method: str = 'greedy-word-wir', attack_recipe: str = None, attack_from_file: str = None, interactive: bool = False, parallel: bool = False, model_batch_size: int = 32, model_cache_size: int = 262144, constraint_cache_size: int = 262144, num_examples: int = 10, num_successful_examples: int = None, num_examples_offset: int = 0, attack_n: bool = False, shuffle: bool = False, query_budget: int = None, checkpoint_interval: int = None, checkpoint_dir: str = 'checkpoints', random_seed: int = 765, num_workers_per_device: int = 1, log_to_txt: str = None, log_to_csv: str = None, log_summary_to_json: str = None, csv_coloring_style: str = 'file', log_to_visdom: dict = None, log_to_wandb: dict = None, disable_stdout: bool = False, silent: bool = False, enable_advance_metrics: bool = False, metrics: Union[Dict, NoneType] = None)[source]
Bases:
AttackArgs
,_CommandLineAttackArgs
,DatasetArgs
,ModelArgs
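This dataclass is what the textattack attack CLI command builds from its parsed flags; constructing it directly is equivalent. A minimal sketch (the field values are illustrative):
>>> import textattack
>>> cli_args = textattack.attack_args.CommandLineAttackArgs(
...     model="bert-base-uncased-imdb",
...     attack_recipe="textfooler",
...     num_examples=10,
... )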
Attacker Class
- class textattack.attacker.Attacker(attack, dataset, attack_args=None)[source]
Bases:
object
Class for running attacks on a dataset with specified parameters. This class uses the Attack to actually run the attacks, while also providing useful features such as parallel processing, saving/resuming from a checkpoint, and logging to files and stdout.
- Parameters:
  - attack (Attack) – The Attack used to actually carry out the attack.
  - dataset (Dataset) – Dataset to attack.
  - attack_args (AttackArgs) – Arguments for attacking the dataset. For default settings, look at the AttackArgs class.
Example:
>>> import textattack
>>> import transformers
>>> model = transformers.AutoModelForSequenceClassification.from_pretrained("textattack/bert-base-uncased-imdb")
>>> tokenizer = transformers.AutoTokenizer.from_pretrained("textattack/bert-base-uncased-imdb")
>>> model_wrapper = textattack.models.wrappers.HuggingFaceModelWrapper(model, tokenizer)
>>> attack = textattack.attack_recipes.TextFoolerJin2019.build(model_wrapper)
>>> dataset = textattack.datasets.HuggingFaceDataset("imdb", split="test")
>>> # Attack 20 samples with CSV logging and a checkpoint saved every 5 examples
>>> attack_args = textattack.AttackArgs(
...     num_examples=20,
...     log_to_csv="log.csv",
...     checkpoint_interval=5,
...     checkpoint_dir="checkpoints",
...     disable_stdout=True
... )
>>> attacker = textattack.Attacker(attack, dataset, attack_args)
>>> attacker.attack_dataset()
- attack_dataset()[source]
Attack the dataset.
- Returns:
  list[AttackResult] – List of AttackResult objects obtained after attacking the given dataset.
- classmethod from_checkpoint(attack, dataset, checkpoint)[source]
Resume attacking from a saved checkpoint. The attack and dataset must be supplied by the user again, while the attack args are loaded from the saved checkpoint.
- update_attack_args(**kwargs)[source]
To update any attack args, pass the new value as a keyword argument to this function.
Examples:
>>> attacker = ...  # some instance of Attacker
>>> # To switch to parallel mode and increase checkpoint interval from 100 to 500
>>> attacker.update_attack_args(parallel=True, checkpoint_interval=500)
- textattack.attacker.attack_from_queue(attack, attack_args, num_gpus, first_to_start, lock, in_queue, out_queue)[source]
AugmenterArgs Class
- class textattack.augment_args.AugmenterArgs(input_csv: str, output_csv: str, input_column: str, recipe: str = 'embedding', pct_words_to_swap: float = 0.1, transformations_per_example: int = 2, random_seed: int = 42, exclude_original: bool = False, overwrite: bool = False, interactive: bool = False, fast_augment: bool = False, high_yield: bool = False, enable_advanced_metrics: bool = False)[source]
Bases:
object
Arguments for performing data augmentation.
- Parameters:
input_csv (str) – Path of input CSV file to augment.
output_csv (str) – Path of CSV file to output augmented data.
- enable_advanced_metrics: bool = False
- exclude_original: bool = False
- fast_augment: bool = False
- high_yield: bool = False
- input_column: str
- input_csv: str
- interactive: bool = False
- output_csv: str
- overwrite: bool = False
- pct_words_to_swap: float = 0.1
- random_seed: int = 42
- recipe: str = 'embedding'
- transformations_per_example: int = 2
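These fields mirror the flags of the textattack augment CLI command; a minimal sketch of constructing the dataclass directly (the file and column names are placeholders):
>>> import textattack
>>> augment_args = textattack.augment_args.AugmenterArgs(
...     input_csv="examples.csv",
...     output_csv="output.csv",
...     input_column="text",
...     recipe="embedding",
...     pct_words_to_swap=0.1,
...     transformations_per_example=2,
... )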
DatasetArgs Class
- class textattack.dataset_args.DatasetArgs(dataset_by_model: str | None = None, dataset_from_huggingface: str | None = None, dataset_from_file: str | None = None, dataset_split: str | None = None, filter_by_labels: list | None = None)[source]
Bases:
object
Arguments for loading dataset from command line input.
- dataset_by_model: str = None
- dataset_from_file: str = None
- dataset_from_huggingface: str = None
- dataset_split: str = None
- filter_by_labels: list = None
ModelArgs Class
- class textattack.model_args.ModelArgs(model: str | None = None, model_from_file: str | None = None, model_from_huggingface: str | None = None)[source]
Bases:
object
Arguments for loading base/pretrained or trained models.
- model: str = None
- model_from_file: str = None
- model_from_huggingface: str = None
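Together with DatasetArgs above, these fields mirror the model and dataset CLI flags; a minimal sketch (the model and dataset names are placeholders):
>>> import textattack
>>> model_args = textattack.model_args.ModelArgs(model_from_huggingface="textattack/bert-base-uncased-imdb")
>>> dataset_args = textattack.dataset_args.DatasetArgs(dataset_from_huggingface="imdb", dataset_split="test")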
Trainer Class
- class textattack.trainer.Trainer(model_wrapper, task_type='classification', attack=None, train_dataset=None, eval_dataset=None, training_args=None)[source]
Bases:
object
Trainer is a training and eval loop for adversarial training.
It is designed to work with PyTorch and Transformers models.
- Parameters:
  - model_wrapper (ModelWrapper) – Model wrapper containing both the model and the tokenizer.
  - task_type (str, optional, defaults to "classification") – The task that the model is trained to perform. Currently, Trainer supports two tasks: (1) "classification", (2) "regression".
  - attack (Attack) – The Attack used to generate adversarial examples for training.
  - train_dataset (Dataset) – Dataset for training.
  - eval_dataset (Dataset) – Dataset for evaluation.
  - training_args (TrainingArgs) – Arguments for training.
Example:
>>> import textattack
>>> import transformers
>>> model = transformers.AutoModelForSequenceClassification.from_pretrained("bert-base-uncased")
>>> tokenizer = transformers.AutoTokenizer.from_pretrained("bert-base-uncased")
>>> model_wrapper = textattack.models.wrappers.HuggingFaceModelWrapper(model, tokenizer)
>>> # We only use DeepWordBugGao2018 for demonstration purposes.
>>> attack = textattack.attack_recipes.DeepWordBugGao2018.build(model_wrapper)
>>> train_dataset = textattack.datasets.HuggingFaceDataset("imdb", split="train")
>>> eval_dataset = textattack.datasets.HuggingFaceDataset("imdb", split="test")
>>> # Train for 3 epochs with 1 initial clean epoch, 1000 adversarial examples per epoch,
>>> # a learning rate of 5e-5, and an effective batch size of 32 (8x4).
>>> training_args = textattack.TrainingArgs(
...     num_epochs=3,
...     num_clean_epochs=1,
...     num_train_adv_examples=1000,
...     learning_rate=5e-5,
...     per_device_train_batch_size=8,
...     gradient_accumulation_steps=4,
...     log_to_tb=True,
... )
>>> trainer = textattack.Trainer(
...     model_wrapper,
...     "classification",
...     attack,
...     train_dataset,
...     eval_dataset,
...     training_args
... )
>>> trainer.train()
Note
When using Trainer with parallel=True in TrainingArgs, make sure to protect the “entry point” of the program by using if __name__ == '__main__':. If not, each worker process used for generating adversarial examples will execute the training code again.
- evaluate_step(model, tokenizer, batch)[source]
Perform a single evaluation step on a batch of inputs.
- Parameters:
  - model (torch.nn.Module) – Model to evaluate.
  - tokenizer – Tokenizer used to tokenize input text.
  - batch (tuple[list[str], torch.Tensor]) – By default, this will be a tuple of input texts and target tensors. Note: if you override the get_eval_dataloader() method, then the shape/type of batch will depend on how you created your batch.
- Returns:
  tuple[torch.Tensor, torch.Tensor] where
  - preds: torch.FloatTensor of the model’s predictions for the batch.
  - targets: torch.Tensor of the model’s targets (e.g. labels, target values).
- get_eval_dataloader(dataset, batch_size)[source]
Returns the torch.utils.data.DataLoader for evaluation.
- Parameters:
  - dataset (Dataset) – Dataset to use for evaluation.
  - batch_size (int) – Batch size for evaluation.
- Returns:
  torch.utils.data.DataLoader
- get_optimizer_and_scheduler(model, num_training_steps)[source]
Returns the optimizer and scheduler to use for training. If you are overriding this method and do not want to use a scheduler, simply return None for the scheduler.
- Parameters:
  - model (torch.nn.Module) – Model to be trained. Pass its parameters to the optimizer for training.
  - num_training_steps (int) – Number of total training steps.
- Returns:
  Tuple of optimizer and scheduler: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler._LRScheduler]
- get_train_dataloader(dataset, adv_dataset, batch_size)[source]
Returns the torch.utils.data.DataLoader for training.
- training_step(model, tokenizer, batch)[source]
Perform a single training step on a batch of inputs.
- Parameters:
  - model (torch.nn.Module) – Model to train.
  - tokenizer – Tokenizer used to tokenize input text.
  - batch (tuple[list[str], torch.Tensor, torch.Tensor]) – By default, this will be a tuple of input texts, targets, and a boolean tensor indicating whether each sample is an adversarial example. Note: if you override the get_train_dataloader() method, then the shape/type of batch will depend on how you created your batch.
- Returns:
  tuple[torch.Tensor, torch.Tensor, torch.Tensor] where
  - loss: torch.FloatTensor of shape 1 containing the loss.
  - preds: torch.FloatTensor of the model’s predictions for the batch.
  - targets: torch.Tensor of the model’s targets (e.g. labels, target values).
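These hooks are meant to be overridden. For illustration, a minimal sketch of a hypothetical subclass that swaps in a plain AdamW optimizer and skips the scheduler (the class name and learning rate below are illustrative, not part of the library):
import torch
import textattack

class AdamWTrainer(textattack.Trainer):
    """Hypothetical subclass overriding the optimizer/scheduler hook."""

    def get_optimizer_and_scheduler(self, model, num_training_steps):
        # Plain AdamW over all model parameters; returning None disables the scheduler.
        optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)
        return optimizer, None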
TrainingArgs Class
- class textattack.training_args.CommandLineTrainingArgs(model_name_or_path: str, attack: str, dataset: str, task_type: str = 'classification', model_max_length: int = None, model_num_labels: int = None, dataset_train_split: str = None, dataset_eval_split: str = None, filter_train_by_labels: list = None, filter_eval_by_labels: list = None, num_epochs: int = 3, num_clean_epochs: int = 1, attack_epoch_interval: int = 1, early_stopping_epochs: int = None, learning_rate: float = 5e-05, num_warmup_steps: Union[int, float] = 500, weight_decay: float = 0.01, per_device_train_batch_size: int = 8, per_device_eval_batch_size: int = 32, gradient_accumulation_steps: int = 1, random_seed: int = 786, parallel: bool = False, load_best_model_at_end: bool = False, alpha: float = 1.0, num_train_adv_examples: Union[int, float] = -1, query_budget_train: int = None, attack_num_workers_per_device: int = 1, output_dir: str = <factory>, checkpoint_interval_steps: int = None, checkpoint_interval_epochs: int = None, save_last: bool = True, log_to_tb: bool = False, tb_log_dir: str = None, log_to_wandb: bool = False, wandb_project: str = 'textattack', logging_interval_step: int = 1)[source]
Bases:
TrainingArgs
,_CommandLineTrainingArgs
- output_dir: str
- class textattack.training_args.TrainingArgs(num_epochs: int = 3, num_clean_epochs: int = 1, attack_epoch_interval: int = 1, early_stopping_epochs: int | None = None, learning_rate: float = 5e-05, num_warmup_steps: int | float = 500, weight_decay: float = 0.01, per_device_train_batch_size: int = 8, per_device_eval_batch_size: int = 32, gradient_accumulation_steps: int = 1, random_seed: int = 786, parallel: bool = False, load_best_model_at_end: bool = False, alpha: float = 1.0, num_train_adv_examples: int | float = -1, query_budget_train: int | None = None, attack_num_workers_per_device: int = 1, output_dir: str = <factory>, checkpoint_interval_steps: int | None = None, checkpoint_interval_epochs: int | None = None, save_last: bool = True, log_to_tb: bool = False, tb_log_dir: str | None = None, log_to_wandb: bool = False, wandb_project: str = 'textattack', logging_interval_step: int = 1)[source]
Bases:
object
Arguments for the Trainer class, which is used for adversarial training.
- Parameters:
  - num_epochs (int, optional, defaults to 3) – Total number of epochs for training.
  - num_clean_epochs (int, optional, defaults to 1) – Number of epochs to train on just the original training dataset before adversarial training.
  - attack_epoch_interval (int, optional, defaults to 1) – Generate a new adversarial training set every N epochs.
  - early_stopping_epochs (int, optional, defaults to None) – Number of epochs validation must increase before stopping early (None for no early stopping).
  - learning_rate (float, optional, defaults to 5e-5) – Learning rate for the optimizer.
  - num_warmup_steps (int or float, optional, defaults to 500) – The number of steps for the warmup phase of the linear scheduler. If num_warmup_steps is a float between 0 and 1, the number of warmup steps will be math.ceil(num_training_steps * num_warmup_steps).
  - weight_decay (float, optional, defaults to 0.01) – Weight decay (L2 penalty).
  - per_device_train_batch_size (int, optional, defaults to 8) – The batch size per GPU/CPU for training.
  - per_device_eval_batch_size (int, optional, defaults to 32) – The batch size per GPU/CPU for evaluation.
  - gradient_accumulation_steps (int, optional, defaults to 1) – Number of update steps to accumulate the gradients for before performing a backward/update pass.
  - random_seed (int, optional, defaults to 786) – Random seed for reproducibility.
  - parallel (bool, optional, defaults to False) – If True, train using multiple GPUs via torch.DataParallel.
  - load_best_model_at_end (bool, optional, defaults to False) – If True, keep track of the best model across training and load it at the end.
  - alpha (float, optional, defaults to 1.0) – The weight for the adversarial loss.
  - num_train_adv_examples (int or float, optional, defaults to -1) – The number of samples to attack successfully when generating the adversarial training set before the start of every epoch. If num_train_adv_examples is a float between 0 and 1, the number of adversarial examples generated is that fraction of the original training set.
  - query_budget_train (int, optional, defaults to None) – The max query budget to use when generating the adversarial training set. None means an infinite query budget.
  - attack_num_workers_per_device (int, optional, defaults to 1) – Number of worker processes to run per device for the attack. Same as the num_workers_per_device argument for AttackArgs.
  - output_dir (str, optional) – Directory to output training logs and checkpoints. Defaults to the /outputs/%Y-%m-%d-%H-%M-%S-%f format.
  - checkpoint_interval_steps (int, optional, defaults to None) – If set, save a model checkpoint after every N updates to the model.
  - checkpoint_interval_epochs (int, optional, defaults to None) – If set, save a model checkpoint after every N epochs.
  - save_last (bool, optional, defaults to True) – If True, save the model at the end of training. Can be used with load_best_model_at_end to save the best model at the end.
  - log_to_tb (bool, optional, defaults to False) – If True, log to TensorBoard.
  - tb_log_dir (str, optional, defaults to "./runs") – Path of the TensorBoard log directory.
  - log_to_wandb (bool, optional, defaults to False) – If True, log to WandB.
  - wandb_project (str, optional, defaults to "textattack") – Name of the WandB project for logging.
  - logging_interval_step (int, optional, defaults to 1) – Log to TensorBoard/WandB every N training steps.
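For illustration, a minimal sketch showing the fractional forms of num_warmup_steps and num_train_adv_examples described above (the values are arbitrary):
>>> import textattack
>>> training_args = textattack.TrainingArgs(
...     num_epochs=5,
...     num_warmup_steps=0.1,        # warm up over 10% of total training steps
...     num_train_adv_examples=0.2,  # attack 20% of the training set each epoch
...     load_best_model_at_end=True,
... )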
- alpha: float = 1.0
- attack_epoch_interval: int = 1
- attack_num_workers_per_device: int = 1
- checkpoint_interval_epochs: int = None
- checkpoint_interval_steps: int = None
- early_stopping_epochs: int = None
- gradient_accumulation_steps: int = 1
- learning_rate: float = 5e-05
- load_best_model_at_end: bool = False
- log_to_tb: bool = False
- log_to_wandb: bool = False
- logging_interval_step: int = 1
- num_clean_epochs: int = 1
- num_epochs: int = 3
- num_train_adv_examples: int | float = -1
- num_warmup_steps: int | float = 500
- output_dir: str
- parallel: bool = False
- per_device_eval_batch_size: int = 32
- per_device_train_batch_size: int = 8
- query_budget_train: int = None
- random_seed: int = 786
- save_last: bool = True
- tb_log_dir: str = None
- wandb_project: str = 'textattack'
- weight_decay: float = 0.01