textattack.metrics.quality_metrics package
Metrics on Quality package
TextAttack provides users with common metrics for measuring the quality of text examples.
Perplexity Metric:
Class for calculating perplexity from AttackResults
- class textattack.metrics.quality_metrics.perplexity.Perplexity(model_name='gpt2')
Bases: Metric
- calculate(results)
Calculates the average perplexity over all successful attacks using a pre-trained small GPT-2 model.
- Parameters:
results (AttackResult objects) – Attack results for each instance in the dataset
Example:
>> import textattack
>> import transformers
>> model = transformers.AutoModelForSequenceClassification.from_pretrained("distilbert-base-uncased-finetuned-sst-2-english")
>> tokenizer = transformers.AutoTokenizer.from_pretrained("distilbert-base-uncased-finetuned-sst-2-english")
>> model_wrapper = textattack.models.wrappers.HuggingFaceModelWrapper(model, tokenizer)
>> attack = textattack.attack_recipes.DeepWordBugGao2018.build(model_wrapper)
>> dataset = textattack.datasets.HuggingFaceDataset("glue", "sst2", split="train")
>> attack_args = textattack.AttackArgs(
    num_examples=1,
    log_to_csv="log.csv",
    checkpoint_interval=5,
    checkpoint_dir="checkpoints",
    disable_stdout=True
)
>> attacker = textattack.Attacker(attack, dataset, attack_args)
>> results = attacker.attack_dataset()
>> ppl = textattack.metrics.quality_metrics.Perplexity().calculate(results)
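For intuition about the number this metric reports: perplexity is conventionally defined as the exponential of the average negative log-likelihood per token. The sketch below illustrates that definition only. The helper name `perplexity` and the sample log-probabilities are hypothetical, and the code does not reproduce how the Perplexity class tokenizes text or queries GPT-2.

```python
import math


def perplexity(token_log_probs):
    """Perplexity of one sequence from per-token natural-log probabilities.

    Hypothetical helper for illustration; not part of TextAttack.
    """
    if not token_log_probs:
        raise ValueError("need at least one token log-probability")
    # Average negative log-likelihood per token
    mean_nll = -sum(token_log_probs) / len(token_log_probs)
    # Perplexity is the exponential of that average
    return math.exp(mean_nll)


# Made-up example: a model that assigns probability 0.25 to every token
# is as uncertain as a uniform choice among four options.
uniform_log_probs = [math.log(0.25)] * 4
print(perplexity(uniform_log_probs))
```

Lower perplexity means the language model finds the text more predictable, so a large rise in perplexity after an attack suggests the perturbed text is less fluent.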
USEMetric class:
Class for calculating USE similarity on AttackResults
- class textattack.metrics.quality_metrics.use.USEMetric(**kwargs)
Bases: Metric
- calculate(results)
Calculates the average USE similarity over all successful attacks.
- Parameters:
results (AttackResult objects) – Attack results for each instance in the dataset
Example:
>> import textattack
>> import transformers
>> model = transformers.AutoModelForSequenceClassification.from_pretrained("distilbert-base-uncased-finetuned-sst-2-english")
>> tokenizer = transformers.AutoTokenizer.from_pretrained("distilbert-base-uncased-finetuned-sst-2-english")
>> model_wrapper = textattack.models.wrappers.HuggingFaceModelWrapper(model, tokenizer)
>> attack = textattack.attack_recipes.DeepWordBugGao2018.build(model_wrapper)
>> dataset = textattack.datasets.HuggingFaceDataset("glue", "sst2", split="train")
>> attack_args = textattack.AttackArgs(
    num_examples=1,
    log_to_csv="log.csv",
    checkpoint_interval=5,
    checkpoint_dir="checkpoints",
    disable_stdout=True
)
>> attacker = textattack.Attacker(attack, dataset, attack_args)
>> results = attacker.attack_dataset()
>> usem = textattack.metrics.quality_metrics.USEMetric().calculate(results)
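USE here refers to the Universal Sentence Encoder, which maps a sentence to an embedding vector. A common way to compare two such vectors is cosine similarity, averaged over all original/perturbed pairs. The sketch below assumes that measure; the similarity function USEMetric actually applies may differ. The helper names and the toy two-dimensional vectors are hypothetical stand-ins for real sentence embeddings.

```python
import math


def cosine_similarity(u, v):
    """Cosine of the angle between two equal-length vectors.

    Hypothetical helper for illustration; not part of TextAttack.
    """
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    if norm_u == 0 or norm_v == 0:
        raise ValueError("cosine similarity is undefined for a zero vector")
    return dot / (norm_u * norm_v)


def average_similarity(embedding_pairs):
    """Mean cosine similarity over (original, perturbed) embedding pairs."""
    scores = [cosine_similarity(u, v) for u, v in embedding_pairs]
    return sum(scores) / len(scores)


# Toy stand-ins for sentence embeddings (assumed values, not real USE output)
pairs = [
    ([1.0, 0.0], [1.0, 0.0]),  # identical direction
    ([1.0, 0.0], [0.0, 1.0]),  # orthogonal
]
print(average_similarity(pairs))
```

A score near 1 indicates that the perturbed text stays close in meaning to the original, while lower scores indicate semantic drift.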