# Source code for textattack.commands.eval_model_command

"""

EvalModelCommand class
==============================

"""

from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser
from dataclasses import dataclass

import scipy
import torch

import textattack
from textattack import DatasetArgs, ModelArgs
from textattack.commands import TextAttackCommand
from textattack.model_args import HUGGINGFACE_MODELS, TEXTATTACK_MODELS

logger = textattack.shared.utils.logger


def _cb(s):
    """Wrap ``str(s)`` in blue ANSI color codes for highlighted log output."""
    text = str(s)
    return textattack.shared.utils.color_text(text, method="ansi", color="blue")


[docs]@dataclass class ModelEvalArgs(ModelArgs, DatasetArgs): random_seed: int = 765 batch_size: int = 32 num_examples: int = 5 num_examples_offset: int = 0
class EvalModelCommand(TextAttackCommand):
    """The TextAttack model benchmarking module:

    A command line parser to evaluate a model from user specifications.
    """

    def get_preds(self, model, inputs):
        """Return predictions of ``model`` on ``inputs`` with gradients disabled.

        Batching is delegated to ``textattack.shared.utils.batch_model_predict``.
        """
        with torch.no_grad():
            preds = textattack.shared.utils.batch_model_predict(model, inputs)
        return preds

    def test_model_on_dataset(self, args):
        """Evaluate one model on a window of its dataset and log the metrics.

        The model and dataset are built from ``args``. Examples
        ``[num_examples_offset, num_examples_offset + num_examples)`` are
        predicted in batches of ``batch_size``; ``num_examples == -1`` means
        "every example from the offset to the end of the dataset".

        If every prediction is a single number the task is treated as
        regression and Pearson/Spearman correlations are logged. Otherwise the
        argmax of each prediction row is compared with the ground-truth labels
        and accuracy is logged.

        ``args`` is only read, never modified, so ``run`` can reuse the same
        object for several models.

        Raises:
            ValueError: if ``args.batch_size`` is smaller than 1 (the batching
                loop would otherwise never advance).
        """
        if args.batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {args.batch_size}.")

        model = ModelArgs._create_model_from_args(args)
        dataset = DatasetArgs._create_dataset_from_args(args)

        # Compute the [start, end) window locally instead of overwriting
        # ``args.num_examples``: ``run`` reuses ``args`` for every model, and
        # each model may load a dataset of a different size.
        start = max(0, args.num_examples_offset)
        end = len(dataset)
        if args.num_examples != -1:
            end = min(end, start + args.num_examples)

        preds = []
        ground_truth_outputs = []
        with torch.no_grad():
            for batch_start in range(start, end, args.batch_size):
                batch_end = min(end, batch_start + args.batch_size)
                dataset_batch = dataset[batch_start:batch_end]
                batch_inputs = []
                for text_input, ground_truth_output in dataset_batch:
                    attacked_text = textattack.shared.AttackedText(text_input)
                    batch_inputs.append(attacked_text.tokenizer_input)
                    ground_truth_outputs.append(ground_truth_output)
                batch_preds = model(batch_inputs)
                if not isinstance(batch_preds, torch.Tensor):
                    batch_preds = torch.Tensor(batch_preds)
                preds.extend(batch_preds)

        if not preds:
            # torch.stack([]) raises; report the empty window instead.
            logger.warning(
                f"No examples to evaluate (offset={start}, dataset size={len(dataset)})."
            )
            return

        preds = torch.stack(preds).cpu()
        # Flatten each example's output into one row, then drop a width-1
        # column. Unlike a bare ``.squeeze()`` this never removes the example
        # axis, so a single classification example keeps shape
        # (1, num_classes) and is not mistaken for regression output.
        preds = preds.reshape(len(preds), -1)
        if preds.shape[1] == 1:
            preds = preds.squeeze(1)
        ground_truth_outputs = torch.tensor(ground_truth_outputs).cpu()

        logger.info(f"Got {len(preds)} predictions.")

        if preds.ndim == 1:
            # if preds is just a list of numbers, assume regression for now
            # TODO integrate with `textattack.metrics` package
            if len(preds) < 2:
                # Correlation is undefined for fewer than two points.
                logger.warning("Need at least 2 predictions to compute correlation.")
                return
            # Import the submodule explicitly: a bare ``import scipy`` does not
            # load ``scipy.stats`` on older SciPy releases.
            from scipy import stats

            gold = ground_truth_outputs.numpy()
            guesses = preds.numpy()
            pearson_correlation, _ = stats.pearsonr(gold, guesses)
            spearman_correlation, _ = stats.spearmanr(gold, guesses)

            logger.info(f"Pearson correlation = {_cb(pearson_correlation)}")
            logger.info(f"Spearman correlation = {_cb(spearman_correlation)}")
        else:
            guess_labels = preds.argmax(dim=1)
            successes = (guess_labels == ground_truth_outputs).sum().item()
            perc_accuracy = "{:.2f}%".format(successes / len(preds) * 100.0)
            logger.info(f"Correct {successes}/{len(preds)} ({_cb(perc_accuracy)})")

    def run(self, args):
        """Entry point for ``textattack eval``.

        Converts the parsed namespace into ``ModelEvalArgs`` and calls
        ``set_seed`` with ``random_seed``. It then evaluates the model
        selected on the command line. If no model was given, it evaluates
        every model in ``HUGGINGFACE_MODELS`` followed by every model in
        ``TEXTATTACK_MODELS``.
        """
        args = ModelEvalArgs(**vars(args))
        textattack.shared.utils.set_seed(args.random_seed)

        # Default to 'all' if no model chosen.
        if not (args.model or args.model_from_huggingface or args.model_from_file):
            for model_name in list(HUGGINGFACE_MODELS.keys()) + list(
                TEXTATTACK_MODELS.keys()
            ):
                args.model = model_name
                self.test_model_on_dataset(args)
                logger.info("-" * 50)
        else:
            self.test_model_on_dataset(args)

    @staticmethod
    def register_subcommand(main_parser: ArgumentParser):
        """Add the ``eval`` subcommand and its options to ``main_parser``.

        The model and dataset options come from ``ModelArgs`` and
        ``DatasetArgs``. The remaining options mirror the extra fields of
        ``ModelEvalArgs``. An ``EvalModelCommand`` instance is stored as the
        ``func`` default.
        """
        parser = main_parser.add_parser(
            "eval",
            help="evaluate a model with TextAttack",
            formatter_class=ArgumentDefaultsHelpFormatter,
        )

        parser = ModelArgs._add_parser_args(parser)
        parser = DatasetArgs._add_parser_args(parser)

        parser.add_argument("--random-seed", default=765, type=int)
        parser.add_argument(
            "--batch-size",
            type=int,
            default=32,
            help="The batch size for evaluating the model.",
        )
        parser.add_argument(
            "--num-examples",
            "-n",
            type=int,
            required=False,
            default=5,
            help="The number of examples to process, -1 for entire dataset",
        )
        parser.add_argument(
            "--num-examples-offset",
            "-o",
            type=int,
            required=False,
            default=0,
            help="The offset to start at in the dataset.",
        )

        parser.set_defaults(func=EvalModelCommand())