# Source code for textattack.commands.augment_command

"""

AugmentCommand class
===========================

"""

from argparse import ArgumentDefaultsHelpFormatter, ArgumentError, ArgumentParser
import csv
import os
import time

import tqdm

import textattack
from textattack.augment_args import AUGMENTATION_RECIPE_NAMES
from textattack.commands import TextAttackCommand


class AugmentCommand(TextAttackCommand):
    """The TextAttack augment module: a command-line parser that runs data
    augmentation from user specifications.

    ``run`` works in one of two modes:

    * interactive (``args.interactive``): sentences are read from stdin and
      their augmentations are printed;
    * batch: a CSV is read, one column of every row is augmented, and the
      result is written to a new comma-separated CSV.
    """

    # Visual divider printed between interactive prompts and results.
    _SEPARATOR = "--------------------------------------------------------"

    def run(self, args):
        """Reads in a CSV, performs augmentation, and outputs an augmented CSV.

        Preserves all columns except for the input (augmented) column. When
        ``args.interactive`` is set, sentences are augmented from stdin
        instead of a CSV.

        Args:
            args: parsed command-line namespace; its attributes are forwarded
                to ``textattack.AugmenterArgs``.

        Raises:
            ArgumentError: batch mode without an input CSV, input column, or
                output CSV.
            FileNotFoundError: the input CSV does not exist.
            OSError: the output CSV exists and ``--overwrite`` was not set.
            ValueError: the input column is not in the CSV header.
        """
        args = textattack.AugmenterArgs(**vars(args))
        if args.interactive:
            self._run_interactive(args)
        else:
            self._run_batch(args)

    @staticmethod
    def _build_augmenter(args, include_metrics=False):
        """Instantiate the augmentation recipe named by ``args.recipe``.

        ``enable_advanced_metrics`` is forwarded only when ``include_metrics``
        is true. Only the interactive loop unpacks the extra metric results
        (``results[1]``/``results[2]``); the batch path iterates the return
        value of ``augment`` directly.
        """
        kwargs = {
            "pct_words_to_swap": args.pct_words_to_swap,
            "transformations_per_example": args.transformations_per_example,
            "high_yield": args.high_yield,
            "fast_augment": args.fast_augment,
        }
        if include_metrics:
            kwargs["enable_advanced_metrics"] = args.enable_advanced_metrics
        # NOTE(review): `eval` resolves the class expression stored in the
        # internal AUGMENTATION_RECIPE_NAMES table. The user-supplied recipe
        # name is only used as a dictionary key and is never evaluated.
        augmenter_class = eval(AUGMENTATION_RECIPE_NAMES[args.recipe])
        return augmenter_class(**kwargs)

    def _run_interactive(self, args):
        """Augment sentences typed on stdin until the user enters "q"."""
        print("\nRunning in interactive mode...\n")
        augmenter = self._build_augmenter(args, include_metrics=True)
        print(self._SEPARATOR)
        while True:
            print(
                '\nEnter a sentence to augment, "q" to quit, "c" to view/change arguments:\n'
            )
            text = input()
            if text == "q":
                break
            if text == "c":
                if self._change_arguments(args):
                    print("\nGenerating new augmenter...\n")
                    # Rebuild with the same options as the initial augmenter so
                    # the result format matches what the printing code expects.
                    augmenter = self._build_augmenter(args, include_metrics=True)
                    print(self._SEPARATOR)
                continue
            if not text:
                continue

            print("\nAugmenting...\n")
            print(self._SEPARATOR)
            if args.enable_advanced_metrics:
                results = augmenter.augment(text)
                print("Augmentations:\n")
                for augmentation in results[0]:
                    print(augmentation, "\n")
                print()
                print(
                    f"Average Original Perplexity Score: {results[1]['avg_original_perplexity']}"
                )
                print(
                    f"Average Augment Perplexity Score: {results[1]['avg_attack_perplexity']}"
                )
                print(
                    f"Average Augment USE Score: {results[2]['avg_attack_use_score']}\n"
                )
            else:
                for augmentation in augmenter.augment(text):
                    print(augmentation, "\n")
            print(self._SEPARATOR)

    @staticmethod
    def _change_arguments(args):
        """Show the current arguments and optionally let the user edit them.

        Invalid input (unknown recipe, non-numeric values) leaves ``args``
        untouched instead of crashing the interactive session.

        Returns:
            True if ``args`` was updated and the augmenter must be rebuilt.
        """
        print(
            f"\nCurrent Arguments:\n\n\t augmentation recipe: {args.recipe}, "
            f"\n\t pct_words_to_swap: {args.pct_words_to_swap}, "
            f"\n\t transformations_per_example: {args.transformations_per_example}\n"
        )
        change = input(
            "Enter 'c' again to change arguments, any other keys to opt out\n"
        )
        if change != "c":
            return False

        print("\nChanging augmenter arguments...\n")
        recipe = input("\tAugmentation recipe name ('r' to see available recipes): ")
        if recipe == "r":
            recipe_display = " ".join(AUGMENTATION_RECIPE_NAMES.keys())
            print(f"\n\t{recipe_display}\n")
            recipe = input("\tAugmentation recipe name: ")
        if recipe not in AUGMENTATION_RECIPE_NAMES:
            print(f"\nUnknown augmentation recipe {recipe!r}; arguments unchanged.\n")
            return False
        try:
            pct_words_to_swap = float(
                input("\tPercentage of words to swap (0.0 ~ 1.0): ")
            )
            transformations_per_example = int(
                input("\tTransformations per input example: ")
            )
        except ValueError as err:
            print(f"\nInvalid value ({err}); arguments unchanged.\n")
            return False

        # Only commit once every value has been validated.
        args.recipe = recipe
        args.pct_words_to_swap = pct_words_to_swap
        args.transformations_per_example = transformations_per_example
        return True

    def _run_batch(self, args):
        """Augment ``args.input_column`` of every row of a CSV file."""
        textattack.shared.utils.set_seed(args.random_seed)
        start_time = time.time()
        if not (args.input_csv and args.input_column and args.output_csv):
            # argparse.ArgumentError takes (argument, message); passing only the
            # message raised a TypeError instead of the intended error.
            raise ArgumentError(
                None, "The following arguments are required: --csv, --input-column/--i"
            )

        # Validate input/output paths.
        if not os.path.exists(args.input_csv):
            raise FileNotFoundError(f"Can't find CSV at location {args.input_csv}")
        if os.path.exists(args.output_csv):
            if args.overwrite:
                textattack.shared.logger.info(
                    f"Preparing to overwrite {args.output_csv}."
                )
            else:
                raise OSError(
                    f"Outfile {args.output_csv} exists and --overwrite not set."
                )

        fieldnames, rows = self._read_csv(args.input_csv)

        # Validate input column.
        row_keys = set(fieldnames)
        if args.input_column not in row_keys:
            raise ValueError(
                f"Could not find input column {args.input_column} in CSV. Found keys: {row_keys}"
            )
        textattack.shared.logger.info(
            f"Read {len(rows)} rows from {args.input_csv}. Found columns {row_keys}."
        )

        augmenter = self._build_augmenter(args)

        output_rows = []
        for row in tqdm.tqdm(rows, desc="Augmenting rows"):
            text_input = row[args.input_column]
            if not args.exclude_original:
                output_rows.append(row)
            for augmentation in augmenter.augment(text_input):
                augmented_row = row.copy()
                augmented_row[args.input_column] = augmentation
                output_rows.append(augmented_row)

        self._write_csv(args.output_csv, fieldnames, output_rows)
        textattack.shared.logger.info(
            f"Wrote {len(output_rows)} augmentations to {args.output_csv} in {time.time() - start_time}s."
        )

    @staticmethod
    def _read_csv(path):
        """Read a CSV whose delimiter (";" or ",") is sniffed from the header.

        Returns:
            ``(fieldnames, rows)`` where ``rows`` is a list of dicts keyed by
            the header's column names; ``fieldnames`` is empty for an empty file.
        """
        # newline="" is required by the csv module so quoted fields containing
        # newlines are parsed correctly.
        with open(path, "r", newline="") as csv_file:
            try:
                dialect = csv.Sniffer().sniff(csv_file.readline(), delimiters=";,")
            except csv.Error:
                # The sniffer cannot infer a delimiter from e.g. a single-column
                # header or an empty file; fall back to the standard dialect.
                dialect = csv.excel
            csv_file.seek(0)
            # When the sniffed line contains no quote characters, csv.Sniffer
            # reports doublequote=False, which mis-parses standard "" escapes
            # inside quoted fields of later rows. Force the RFC 4180 behaviour.
            reader = csv.DictReader(
                csv_file, dialect=dialect, doublequote=True, skipinitialspace=True
            )
            rows = list(reader)
            fieldnames = list(reader.fieldnames or [])
        return fieldnames, rows

    @staticmethod
    def _write_csv(path, fieldnames, rows):
        """Write ``rows`` as a comma-separated CSV with a header line.

        Fields containing commas, quotes or newlines are quoted with '"' by the
        csv module, so the output round-trips through any CSV reader.
        """
        with open(path, "w", newline="") as outfile:
            csv_writer = csv.DictWriter(
                outfile,
                fieldnames=fieldnames,
                lineterminator="\n",
                # Cells beyond the header have no column name and are dropped.
                extrasaction="ignore",
            )
            csv_writer.writeheader()
            csv_writer.writerows(rows)

    @staticmethod
    def register_subcommand(main_parser: ArgumentParser):
        """Register the ``augment`` subcommand and its arguments on ``main_parser``."""
        parser = main_parser.add_parser(
            "augment",
            help="augment text data",
            formatter_class=ArgumentDefaultsHelpFormatter,
        )
        parser = textattack.AugmenterArgs._add_parser_args(parser)
        parser.set_defaults(func=AugmentCommand())