textattack.metrics.quality_metrics package

Metrics on Quality package

TextAttack provides users with common metrics for measuring the quality of text examples.

Perplexity Metric:

Class for calculating perplexity from AttackResults

class textattack.metrics.quality_metrics.perplexity.Perplexity(model_name='gpt2')

Bases: Metric

calc_ppl(texts)
calculate(results)

Calculates the average perplexity over all successful attacks using a pre-trained small GPT-2 model.

Parameters:

results (list of AttackResult objects) – Attack results for each instance in the dataset.

Example:

>>> import textattack
>>> import transformers
>>> # Wrap a Hugging Face sentiment classifier so TextAttack can query it
>>> model = transformers.AutoModelForSequenceClassification.from_pretrained("distilbert-base-uncased-finetuned-sst-2-english")
>>> tokenizer = transformers.AutoTokenizer.from_pretrained("distilbert-base-uncased-finetuned-sst-2-english")
>>> model_wrapper = textattack.models.wrappers.HuggingFaceModelWrapper(model, tokenizer)
>>> # Build an attack recipe and load the dataset to attack
>>> attack = textattack.attack_recipes.DeepWordBugGao2018.build(model_wrapper)
>>> dataset = textattack.datasets.HuggingFaceDataset("glue", "sst2", split="train")
>>> attack_args = textattack.AttackArgs(
...     num_examples=1,
...     log_to_csv="log.csv",
...     checkpoint_interval=5,
...     checkpoint_dir="checkpoints",
...     disable_stdout=True
... )
>>> attacker = textattack.Attacker(attack, dataset, attack_args)
>>> results = attacker.attack_dataset()
>>> # Compute perplexity over the attack results
>>> ppl = textattack.metrics.quality_metrics.Perplexity().calculate(results)
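Perplexity is the exponential of the average negative log-likelihood that a language model assigns to the tokens of a text, so lower values indicate more fluent text. The sketch below illustrates that formula in plain Python, assuming the per-token natural-log probabilities (for example from GPT-2) have already been computed. `perplexity_from_logprobs` and `average_perplexity` are hypothetical helpers, not part of the TextAttack API, and the library's own tokenization and aggregation across texts may differ.

```python
import math
from typing import List, Sequence


def perplexity_from_logprobs(token_logprobs: Sequence[float]) -> float:
    """Hypothetical helper: exp of the mean negative log-likelihood per token.

    `token_logprobs` holds natural-log probabilities that a language model
    assigned to each token of a single text (assumed precomputed).
    """
    if not token_logprobs:
        raise ValueError("need at least one token log-probability")
    mean_nll = -sum(token_logprobs) / len(token_logprobs)
    return math.exp(mean_nll)


def average_perplexity(texts_logprobs: List[Sequence[float]]) -> float:
    """Hypothetical helper: mean of per-text perplexities, e.g. over the
    perturbed texts of all successful attacks."""
    if not texts_logprobs:
        raise ValueError("need at least one text")
    scores = [perplexity_from_logprobs(lp) for lp in texts_logprobs]
    return sum(scores) / len(scores)


# Usage: four tokens, each assigned probability 0.5, give perplexity 2.
print(perplexity_from_logprobs([math.log(0.5)] * 4))  # close to 2.0
```

Averaging per-text scores keeps each adversarial example equally weighted regardless of its length; pooling all tokens into one corpus-level perplexity is a common alternative.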

USEMetric class:

Class for calculating Universal Sentence Encoder (USE) similarity on AttackResults

class textattack.metrics.quality_metrics.use.USEMetric(**kwargs)

Bases: Metric

calculate(results)

Calculates the average USE similarity over all successful attacks.

Parameters:

results (list of AttackResult objects) – Attack results for each instance in the dataset.

Example:

>>> import textattack
>>> import transformers
>>> # Wrap a Hugging Face sentiment classifier so TextAttack can query it
>>> model = transformers.AutoModelForSequenceClassification.from_pretrained("distilbert-base-uncased-finetuned-sst-2-english")
>>> tokenizer = transformers.AutoTokenizer.from_pretrained("distilbert-base-uncased-finetuned-sst-2-english")
>>> model_wrapper = textattack.models.wrappers.HuggingFaceModelWrapper(model, tokenizer)
>>> # Build an attack recipe and load the dataset to attack
>>> attack = textattack.attack_recipes.DeepWordBugGao2018.build(model_wrapper)
>>> dataset = textattack.datasets.HuggingFaceDataset("glue", "sst2", split="train")
>>> attack_args = textattack.AttackArgs(
...     num_examples=1,
...     log_to_csv="log.csv",
...     checkpoint_interval=5,
...     checkpoint_dir="checkpoints",
...     disable_stdout=True
... )
>>> attacker = textattack.Attacker(attack, dataset, attack_args)
>>> results = attacker.attack_dataset()
>>> # Compute USE similarity over the attack results
>>> usem = textattack.metrics.quality_metrics.USEMetric().calculate(results)
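USE similarity is typically measured between the sentence embedding of each original text and that of its adversarial counterpart, with higher scores meaning the perturbed text stays closer in meaning. As a rough sketch, the averaging step can be written as below, assuming the embeddings are already available as plain vectors and using cosine similarity as a stand-in. `cosine_similarity` and `average_similarity` are hypothetical helpers, and the exact similarity function USEMetric applies is not specified on this page.

```python
import math
from typing import List, Sequence, Tuple


def cosine_similarity(u: Sequence[float], v: Sequence[float]) -> float:
    """Hypothetical helper: cosine of the angle between two embedding vectors."""
    if len(u) != len(v):
        raise ValueError("embeddings must have the same dimension")
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    if norm_u == 0 or norm_v == 0:
        raise ValueError("cannot compare a zero-length embedding")
    return dot / (norm_u * norm_v)


def average_similarity(
    pairs: List[Tuple[Sequence[float], Sequence[float]]]
) -> float:
    """Hypothetical helper: mean similarity over (original, perturbed)
    embedding pairs, e.g. one pair per successful attack."""
    if not pairs:
        raise ValueError("need at least one pair")
    scores = [cosine_similarity(orig, adv) for orig, adv in pairs]
    return sum(scores) / len(scores)


# Usage: one identical pair (similarity 1) and one orthogonal pair (similarity 0).
print(average_similarity([([1.0, 0.0], [1.0, 0.0]), ([1.0, 0.0], [0.0, 1.0])]))  # 0.5
```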