Source code for textattack.augmentation.augmenter

"""
Augmenter Class
===================
"""
import random

import tqdm

from textattack.constraints import PreTransformationConstraint
from textattack.metrics.quality_metrics import Perplexity, USEMetric
from textattack.shared import AttackedText, utils


class Augmenter:
    """A class for performing data augmentation using TextAttack.

    Returns all possible transformations for a given string. Currently only
    supports transformations which are word swaps.

    Args:
        transformation (textattack.Transformation): the transformation
            that suggests new texts from an input.
        constraints: (list(textattack.Constraint)): constraints
            that each transformation must meet. ``None`` (the default) means
            no constraints.
        pct_words_to_swap: (float): [0., 1.], percentage of words to swap per
            augmented example
        transformations_per_example: (int): Maximum number of augmentations
            per input
        high_yield: Whether to return a set of augmented texts that will be
            relatively similar, or to return only a single one.
        fast_augment: Stops additional transformation runs when number of
            successful augmentations reaches transformations_per_example
        enable_advanced_metrics: return perplexity and USE Score of
            augmentation. Stored on the instance as ``advanced_metrics``.

    Example::
        >>> from textattack.transformations import WordSwapRandomCharacterDeletion, WordSwapQWERTY, CompositeTransformation
        >>> from textattack.constraints.pre_transformation import RepeatModification, StopwordModification
        >>> from textattack.augmentation import Augmenter

        >>> transformation = CompositeTransformation([WordSwapRandomCharacterDeletion(), WordSwapQWERTY()])
        >>> constraints = [RepeatModification(), StopwordModification()]

        >>> # initiate augmenter
        >>> augmenter = Augmenter(
        ...     transformation=transformation,
        ...     constraints=constraints,
        ...     pct_words_to_swap=0.5,
        ...     transformations_per_example=3
        ... )

        >>> # additional parameters can be modified if not during initiation
        >>> augmenter.advanced_metrics = True
        >>> augmenter.fast_augment = True
        >>> augmenter.high_yield = True

        >>> s = 'What I cannot create, I do not understand.'
        >>> results = augmenter.augment(s)

        >>> augmentations = results[0]
        >>> perplexity_score = results[1]
        >>> use_score = results[2]
    """

    def __init__(
        self,
        transformation,
        constraints=None,
        pct_words_to_swap=0.1,
        transformations_per_example=1,
        high_yield=False,
        fast_augment=False,
        enable_advanced_metrics=False,
    ):
        # Kept as assertions (not ValueError) so the exception type seen by
        # existing callers does not change.
        assert (
            transformations_per_example > 0
        ), "transformations_per_example must be a positive integer"
        assert 0.0 <= pct_words_to_swap <= 1.0, "pct_words_to_swap must be in [0., 1.]"

        # ``None`` replaces the former mutable ``[]`` default.
        if constraints is None:
            constraints = []

        self.transformation = transformation
        self.pct_words_to_swap = pct_words_to_swap
        self.transformations_per_example = transformations_per_example

        self.constraints = []
        self.pre_transformation_constraints = []
        self.high_yield = high_yield
        self.fast_augment = fast_augment
        self.advanced_metrics = enable_advanced_metrics

        # Pre-transformation constraints are handed to the transformation
        # itself; every other constraint is applied to its output.
        for constraint in constraints:
            if isinstance(constraint, PreTransformationConstraint):
                self.pre_transformation_constraints.append(constraint)
            else:
                self.constraints.append(constraint)

    def _filter_transformations(self, transformed_texts, current_text, original_text):
        """Filters a list of ``AttackedText`` objects to include only the ones
        that pass ``self.constraints``.

        Args:
            transformed_texts: candidate texts produced by the transformation.
            current_text: the text the candidates were derived from.
            original_text: the unmodified input text; used as the reference
                for constraints whose ``compare_against_original`` is truthy.

        Returns:
            The candidates that every constraint's ``call_many`` kept.

        Raises:
            ValueError: if a constraint compares against the original text
                but ``original_text`` is falsy.
        """
        for C in self.constraints:
            # Nothing left to filter; skip the remaining constraints.
            if len(transformed_texts) == 0:
                break
            if C.compare_against_original:
                if not original_text:
                    raise ValueError(
                        f"Missing `original_text` argument when constraint {type(C)} is set to compare against "
                        f"`original_text` "
                    )

                transformed_texts = C.call_many(transformed_texts, original_text)
            else:
                transformed_texts = C.call_many(transformed_texts, current_text)
        return transformed_texts

    def augment(self, text):
        """Returns all possible augmentations of ``text`` according to
        ``self.transformation``.

        Args:
            text (str): the string to augment.

        Returns:
            A sorted list of augmented strings. When ``self.advanced_metrics``
            is truthy, a tuple ``(augmented_strings, perplexity_stats,
            use_stats)`` is returned instead.
        """
        attacked_text = AttackedText(text)
        original_text = attacked_text
        all_transformed_texts = set()
        # Always swap at least one word, even for very short inputs.
        num_words_to_swap = max(
            int(self.pct_words_to_swap * len(attacked_text.words)), 1
        )

        for _ in range(self.transformations_per_example):
            current_text = attacked_text
            words_swapped = len(current_text.attack_attrs["modified_indices"])

            while words_swapped < num_words_to_swap:
                transformed_texts = self.transformation(
                    current_text, self.pre_transformation_constraints
                )

                # Get rid of transformations we already have
                transformed_texts = [
                    t for t in transformed_texts if t not in all_transformed_texts
                ]

                # Filter out transformations that don't match the constraints.
                transformed_texts = self._filter_transformations(
                    transformed_texts, current_text, original_text
                )

                # if there's no more transformed texts after filter, terminate
                if not len(transformed_texts):
                    break

                if self.high_yield or self.fast_augment:
                    # Split the candidates (order preserved) into those that
                    # already have enough words swapped and those that do not.
                    ready_texts = []
                    unfinished_texts = []
                    for candidate in transformed_texts:
                        num_modified = len(candidate.attack_attrs["modified_indices"])
                        if num_modified >= num_words_to_swap:
                            ready_texts.append(candidate)
                        else:
                            unfinished_texts.append(candidate)

                    all_transformed_texts.update(ready_texts)

                    if unfinished_texts:
                        current_text = random.choice(unfinished_texts)
                    else:
                        # no need for further augmentations if all of
                        # transformed_texts meet `num_words_to_swap`
                        break
                else:
                    current_text = random.choice(transformed_texts)

                # update words_swapped based on modified indices; the
                # `+ 1` guarantees progress so the loop always terminates
                words_swapped = max(
                    len(current_text.attack_attrs["modified_indices"]),
                    words_swapped + 1,
                )

            # NOTE(review): if the loop above exited without choosing a new
            # candidate, `current_text` is still the unmodified input here —
            # confirm that including it in the results is intended.
            all_transformed_texts.add(current_text)

            # when with fast_augment, terminate early if there're enough
            # successful augmentations
            if (
                self.fast_augment
                and len(all_transformed_texts) >= self.transformations_per_example
            ):
                if not self.high_yield:
                    # random.sample() rejects sets on Python 3.11+, so the
                    # population is materialized as a list first.
                    all_transformed_texts = random.sample(
                        list(all_transformed_texts), self.transformations_per_example
                    )
                break

        perturbed_texts = sorted(at.printable_text() for at in all_transformed_texts)

        if self.advanced_metrics:
            augmentation_results = [
                AugmentationResult(original_text, transformed_text)
                for transformed_text in all_transformed_texts
            ]
            perplexity_stats = Perplexity().calculate(augmentation_results)
            use_stats = USEMetric().calculate(augmentation_results)
            return perturbed_texts, perplexity_stats, use_stats

        return perturbed_texts

    def augment_many(self, text_list, show_progress=False):
        """Returns all possible augmentations of a list of strings according to
        ``self.transformation``.

        Args:
            text_list (list(string)): a list of strings for data augmentation
            show_progress (bool): show a progress bar during augmentation

        Returns:
            A list with one ``self.augment`` result per input string, in
            input order.
        """
        if show_progress:
            text_list = tqdm.tqdm(text_list, desc="Augmenting data...")
        return [self.augment(text) for text in text_list]

    def augment_text_with_ids(self, text_list, id_list, show_progress=True):
        """Supplements a list of text with more text data.

        Each original text is emitted once, immediately followed by its
        augmentations; every emitted text is paired with the ID of the
        original it came from.

        Args:
            text_list (list(string)): the texts to augment.
            id_list (list): one ID per entry of ``text_list``.
            show_progress (bool): show a progress bar during augmentation

        Returns:
            A tuple ``(all_text_list, all_id_list)`` of equal length.

        Raises:
            ValueError: if ``text_list`` and ``id_list`` differ in length.
        """
        if len(text_list) != len(id_list):
            raise ValueError("List of text must be same length as list of IDs")
        if self.transformations_per_example == 0:
            return text_list, id_list
        all_text_list = []
        all_id_list = []
        if show_progress:
            text_list = tqdm.tqdm(text_list, desc="Augmenting data...")
        for text, _id in zip(text_list, id_list):
            augmented_texts = self.augment(text)
            if self.advanced_metrics:
                # With advanced metrics enabled, augment() returns
                # (texts, perplexity_stats, use_stats); only the texts
                # belong in the supplemented dataset.
                augmented_texts = augmented_texts[0]
            all_text_list.append(text)
            all_text_list.extend(augmented_texts)
            all_id_list.extend([_id] * (1 + len(augmented_texts)))
        return all_text_list, all_id_list

    def __repr__(self):
        """Returns a multi-line summary of the transformation and constraints."""
        main_str = "Augmenter" + "("
        lines = []
        # self.transformation
        lines.append(utils.add_indent(f"(transformation): {self.transformation}", 2))
        # self.constraints
        constraints_lines = []
        constraints = self.constraints + self.pre_transformation_constraints
        if len(constraints):
            for i, constraint in enumerate(constraints):
                constraints_lines.append(utils.add_indent(f"({i}): {constraint}", 2))
            constraints_str = utils.add_indent("\n" + "\n".join(constraints_lines), 2)
        else:
            constraints_str = "None"
        lines.append(utils.add_indent(f"(constraints): {constraints_str}", 2))
        # NOTE(review): separator width reproduced exactly as found — confirm
        # it matches the intended indentation of the nested lines.
        main_str += "\n " + "\n ".join(lines) + "\n"
        main_str += ")"
        return main_str
class AugmentationResult:
    """Pairs an original text with one of its augmentations.

    Each text is wrapped so it can be read back as
    ``original_result.attacked_text`` and ``perturbed_result.attacked_text``.

    Args:
        text1: the original text.
        text2: the augmented (perturbed) text.
    """

    class tempResult:
        """Minimal wrapper exposing a text as ``attacked_text``."""

        def __init__(self, text):
            self.attacked_text = text

    def __init__(self, text1, text2):
        wrap = self.tempResult
        self.original_result, self.perturbed_result = wrap(text1), wrap(text2)