Source code for textattack.constraints.constraint

"""

TextAttack Constraint Class
=====================================
"""

from abc import ABC, abstractmethod

import textattack
from textattack.shared.utils import ReprMixin


class Constraint(ReprMixin, ABC):
    """An abstract class that represents constraints on adversarial text
    examples. Constraints evaluate whether transformations from a
    ``AttackedText`` to another ``AttackedText`` meet certain conditions.

    Args:
        compare_against_original (bool):
            If `True`, the reference text should be the original text under attack.
            If `False`, the reference text is the most recent text from which the
            transformed text was generated. All constraints must have this attribute.
    """

    def __init__(self, compare_against_original):
        self.compare_against_original = compare_against_original

    def call_many(self, transformed_texts, reference_text):
        """Filters ``transformed_texts`` based on which transformations fulfill
        the constraint. First checks compatibility with latest
        ``Transformation``, then calls ``_check_constraint_many``

        Args:
            transformed_texts (list[AttackedText]): The candidate transformed ``AttackedText``'s.
            reference_text (AttackedText): The ``AttackedText`` to compare against.

        Returns:
            list: The compatible texts kept by ``_check_constraint_many``,
            followed by every text whose last transformation is not compatible
            with this constraint (those are passed through unchecked).

        Raises:
            KeyError: If a transformed text has no ``last_transformation``
                entry in its ``attack_attrs``.
        """
        incompatible_transformed_texts = []
        compatible_transformed_texts = []
        for transformed_text in transformed_texts:
            # Only the dictionary lookup is guarded: a KeyError raised from
            # inside ``check_compatibility`` must surface with its own message
            # instead of being reported as a missing ``last_transformation``.
            try:
                last_transformation = transformed_text.attack_attrs[
                    "last_transformation"
                ]
            except KeyError as err:
                raise KeyError(
                    "transformed_text must have `last_transformation` attack_attr to apply constraint"
                ) from err

            if self.check_compatibility(last_transformation):
                compatible_transformed_texts.append(transformed_text)
            else:
                incompatible_transformed_texts.append(transformed_text)

        filtered_texts = self._check_constraint_many(
            compatible_transformed_texts, reference_text
        )
        # Incompatible texts are not subject to this constraint, so they are
        # always kept; they are appended after the filtered compatible ones.
        return list(filtered_texts) + incompatible_transformed_texts

    def _check_constraint_many(self, transformed_texts, reference_text):
        """Filters ``transformed_texts`` based on which transformations fulfill
        the constraint. Calls ``_check_constraint`` on each candidate.

        Args:
            transformed_texts (list[AttackedText]): The candidate transformed ``AttackedText``
            reference_text (AttackedText): The ``AttackedText`` to compare against.

        Returns:
            list: The texts for which ``_check_constraint`` returned a truthy
            value, in their original order.
        """
        return [
            transformed_text
            for transformed_text in transformed_texts
            if self._check_constraint(transformed_text, reference_text)
        ]

    def __call__(self, transformed_text, reference_text):
        """Returns True if the constraint is fulfilled, False otherwise. First
        checks compatibility with latest ``Transformation``, then calls
        ``_check_constraint``

        Args:
            transformed_text (AttackedText): The candidate transformed ``AttackedText``.
            reference_text (AttackedText): The ``AttackedText`` to compare against.

        Returns:
            ``True`` when the last transformation is not compatible with this
            constraint (the constraint does not apply); otherwise the result
            of ``_check_constraint``.

        Raises:
            TypeError: If either argument is not an ``AttackedText``.
            KeyError: If ``transformed_text`` has no ``last_transformation``
                entry in its ``attack_attrs``.
        """
        if not isinstance(transformed_text, textattack.shared.AttackedText):
            raise TypeError("transformed_text must be of type AttackedText")
        if not isinstance(reference_text, textattack.shared.AttackedText):
            raise TypeError("reference_text must be of type AttackedText")

        # Guard only the lookup so that a KeyError originating inside
        # ``check_compatibility`` is not mislabelled as a missing attribute.
        try:
            last_transformation = transformed_text.attack_attrs["last_transformation"]
        except KeyError as err:
            raise KeyError(
                "`transformed_text` must have `last_transformation` attack_attr to apply constraint."
            ) from err

        if not self.check_compatibility(last_transformation):
            # Constraint does not apply to this transformation: accept.
            return True
        return self._check_constraint(transformed_text, reference_text)

    @abstractmethod
    def _check_constraint(self, transformed_text, reference_text):
        """Returns True if the constraint is fulfilled, False otherwise. Must
        be overridden by the specific constraint.

        Args:
            transformed_text: The candidate transformed ``AttackedText``.
            reference_text (AttackedText): The ``AttackedText`` to compare against.
        """
        raise NotImplementedError()

    def check_compatibility(self, transformation):
        """Checks if this constraint is compatible with the given
        transformation. For example, the ``WordEmbeddingDistance`` constraint
        compares the embedding of the word inserted with that of the word
        deleted. Therefore it can only be applied in the case of word swaps,
        and not for transformations which involve only one of insertion or
        deletion.

        The base implementation accepts every transformation.

        Args:
            transformation: The ``Transformation`` to check compatibility with.

        Returns:
            bool: Always ``True`` in this base class.
        """
        return True

    def extra_repr_keys(self):
        """Set the extra representation of the constraint using these keys.

        To print customized extra information, you should reimplement
        this method in your own constraint. Both single-line and multi-
        line strings are acceptable.

        Returns:
            list[str]: The attribute names to include; here only
            ``"compare_against_original"``.
        """
        return ["compare_against_original"]