# Source code for textattack.constraints.grammaticality.cola

"""
CoLA for Grammaticality
--------------------------

"""

import lru
import nltk
from transformers import AutoModelForSequenceClassification, AutoTokenizer

from textattack.constraints import Constraint
from textattack.models.wrappers import HuggingFaceModelWrapper


class COLA(Constraint):
    """Constrains an attack to text that has a similar number of
    linguistically acceptable sentences as the original text.

    Linguistic acceptability is determined by a model pre-trained on the
    `CoLA dataset <https://nyu-mll.github.io/CoLA/>`_. By default a BERT
    model is used, see the `pre-trained models README
    <https://github.com/QData/TextAttack/tree/master/textattack/models>`_
    for a full list of available models or provide your own model from the
    huggingface model hub.

    Args:
        max_diff (float or int): The absolute (if int or greater than or
            equal to 1) or percent (if float and less than 1) maximum
            difference allowed between the number of valid sentences in the
            reference text and the number of valid sentences in the
            attacked text.
        model_name (str): The name of the pre-trained model to use for
            classification. The model must be in huggingface model hub.
        compare_against_original (bool): If `True`, compare against the
            original text. Otherwise, compare against the most recent text.

    Raises:
        TypeError: If ``max_diff`` is neither a float nor an int.
        ValueError: If ``max_diff`` is negative.
    """

    def __init__(
        self,
        max_diff,
        model_name="textattack/bert-base-uncased-CoLA",
        compare_against_original=True,
    ):
        super().__init__(compare_against_original)
        if not isinstance(max_diff, float) and not isinstance(max_diff, int):
            raise TypeError("max_diff must be a float or int")
        if max_diff < 0.0:
            raise ValueError("max_diff must be a value greater or equal to than 0.0")

        self.max_diff = max_diff
        self.model_name = model_name
        # Bounded cache: reference text -> number of valid sentences in it.
        self._reference_score_cache = lru.LRU(2**10)
        model = AutoModelForSequenceClassification.from_pretrained(model_name)
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = HuggingFaceModelWrapper(model, tokenizer)

    def clear_cache(self):
        """Drop every cached reference score."""
        self._reference_score_cache.clear()

    def _count_valid_sentences(self, attacked_text):
        """Return how many sentences of ``attacked_text`` the model accepts.

        The text is split into sentences before predicting validity; a
        predicted label of 1 indicates the sentence is valid, so summing the
        argmax labels yields the count of valid sentences.
        """
        sentences = nltk.sent_tokenize(attacked_text.text)
        if not sentences:
            # Nothing to classify; avoid sending an empty batch to the model.
            return 0
        predictions = self.model(sentences)
        # Convert to a plain int so the cache and the comparison below do not
        # hold on to (or depend on) the array type returned by the wrapper.
        return int(predictions.argmax(axis=1).sum())

    def _check_constraint(self, transformed_text, reference_text):
        """Return True if ``transformed_text`` keeps enough valid sentences.

        The number of valid sentences in ``transformed_text`` must not fall
        below the reference count minus the allowed difference, where the
        allowance is absolute (int or >= 1) or a fraction of the reference
        count (float < 1).
        """
        if reference_text not in self._reference_score_cache:
            self._reference_score_cache[reference_text] = (
                self._count_valid_sentences(reference_text)
            )

        num_valid = self._count_valid_sentences(transformed_text)
        reference_score = self._reference_score_cache[reference_text]

        if isinstance(self.max_diff, int) or self.max_diff >= 1:
            threshold = reference_score - self.max_diff
        else:
            threshold = reference_score - (reference_score * self.max_diff)

        return not (num_valid < threshold)

    def extra_repr_keys(self):
        """Return the attribute names shown in this constraint's repr."""
        return [
            "max_diff",
            "model_name",
        ] + super().extra_repr_keys()

    def __getstate__(self):
        """Return picklable state; the cache is replaced by its size only."""
        state = self.__dict__.copy()
        # Cached entries are not serialized, only the value of get_size()
        # needed to rebuild an empty cache on load.
        state["_reference_score_cache"] = self._reference_score_cache.get_size()
        return state

    def __setstate__(self, state):
        """Restore attributes and recreate an empty cache of the saved size."""
        cache_size = state["_reference_score_cache"]
        # Update rather than alias, so the caller's dict is left untouched.
        self.__dict__.update(state)
        self._reference_score_cache = lru.LRU(cache_size)