Source code for textattack.augment_args

"""
AugmenterArgs Class
===================
"""

from dataclasses import dataclass

AUGMENTATION_RECIPE_NAMES = {
    "wordnet": "textattack.augmentation.WordNetAugmenter",
    "embedding": "textattack.augmentation.EmbeddingAugmenter",
    "charswap": "textattack.augmentation.CharSwapAugmenter",
    "eda": "textattack.augmentation.EasyDataAugmenter",
    "checklist": "textattack.augmentation.CheckListAugmenter",
    "clare": "textattack.augmentation.CLAREAugmenter",
    "back_trans": "textattack.augmentation.BackTranslationAugmenter",
    "back_transcription": "textattack.augmentation.BackTranscriptionAugmenter",
}


[docs]@dataclass class AugmenterArgs: """Arguments for performing data augmentation. Args: input_csv (str): Path of input CSV file to augment. output_csv (str): Path of CSV file to output augmented data. """ input_csv: str output_csv: str input_column: str recipe: str = "embedding" pct_words_to_swap: float = 0.1 transformations_per_example: int = 2 random_seed: int = 42 exclude_original: bool = False overwrite: bool = False interactive: bool = False fast_augment: bool = False high_yield: bool = False enable_advanced_metrics: bool = False @classmethod def _add_parser_args(cls, parser): parser.add_argument( "--input-csv", type=str, help="Path of input CSV file to augment.", ) parser.add_argument( "--output-csv", type=str, help="Path of CSV file to output augmented data.", ) parser.add_argument( "--input-column", "--i", type=str, help="CSV input column to be augmented", ) parser.add_argument( "--recipe", "-r", help="Name of augmentation recipe", type=str, default="embedding", choices=AUGMENTATION_RECIPE_NAMES.keys(), ) parser.add_argument( "--pct-words-to-swap", "--p", help="Percentage of words to modify when generating each augmented example.", type=float, default=0.1, ) parser.add_argument( "--transformations-per-example", "--t", help="number of augmentations to return for each input", type=int, default=2, ) parser.add_argument( "--random-seed", default=42, type=int, help="random seed to set" ) parser.add_argument( "--exclude-original", default=False, action="store_true", help="exclude original example from augmented CSV", ) parser.add_argument( "--overwrite", default=False, action="store_true", help="overwrite output file, if it exists", ) parser.add_argument( "--interactive", default=False, action="store_true", help="Whether to run attacks interactively.", ) parser.add_argument( "--high_yield", default=False, action="store_true", help="run attacks with high yield.", ) parser.add_argument( "--fast_augment", default=False, action="store_true", help="faster augmentation but may use only a few transformations.", ) parser.add_argument( "--enable_advanced_metrics", default=False, action="store_true", help="return perplexity and USE score", ) return parser