# Source code for textattack.constraints.grammaticality.language_models.language_model_constraint

"""
Language Models Constraint
---------------------------

"""

from abc import ABC, abstractmethod

from textattack.constraints import Constraint


class LanguageModelConstraint(Constraint, ABC):
    """Constraint on how far a swapped word's language-model log-probability
    may fall between two texts.

    The language model itself is supplied by subclasses through
    :meth:`get_log_probs_at_index`.

    Args:
        max_log_prob_diff (float): the maximum decrease in log-probability
            in swapped words from `x` to `x_adv`. Must be provided.
        compare_against_original (bool): If `True`, compare new `x_adv`
            against the original `x`. Otherwise, compare it against the
            previous `x_adv`.

    Raises:
        ValueError: if `max_log_prob_diff` is left as `None`.
    """

    def __init__(self, max_log_prob_diff=None, compare_against_original=True):
        # The threshold has no usable default, so refuse to build without it.
        if max_log_prob_diff is None:
            raise ValueError("Must set max_log_prob_diff")

        self.max_log_prob_diff = max_log_prob_diff
        super().__init__(compare_against_original)

    @abstractmethod
    def get_log_probs_at_index(self, text_list, word_index):
        """Gets the log-probability of items in `text_list` at index
        `word_index` according to a language model.

        `_check_constraint` calls this with a 2-tuple of texts and requires
        the result to have exactly two entries, in the same order.
        """
        raise NotImplementedError()

    def _check_constraint(self, transformed_text, reference_text):
        """Checks every newly modified index against the allowed drop.

        Returns:
            `False` as soon as one index has a transformed log-probability
            less than or equal to the reference log-probability minus
            `max_log_prob_diff`; `True` if no index does (including when
            there are no newly modified indices).

        Raises:
            KeyError: if `transformed_text.attack_attrs` has no
                `newly_modified_indices` entry.
            ValueError: if `get_log_probs_at_index` does not return exactly
                two values for the two texts passed in.
        """
        try:
            modified_indices = transformed_text.attack_attrs[
                "newly_modified_indices"
            ]
        except KeyError:
            raise KeyError(
                "Cannot apply language model constraint without `newly_modified_indices`"
            )

        # Reference text first, transformed text second; the unpacking below
        # relies on this order.
        text_pair = (reference_text, transformed_text)

        for word_index in modified_indices:
            log_probs = self.get_log_probs_at_index(text_pair, word_index)
            if len(log_probs) != 2:
                raise ValueError(
                    f"Error: get_log_probs_at_index returned {len(log_probs)} values for 2 inputs"
                )

            reference_log_prob, transformed_log_prob = log_probs
            # Lowest value the transformed log-probability must stay above.
            lower_bound = reference_log_prob - self.max_log_prob_diff
            if transformed_log_prob <= lower_bound:
                return False

        return True

    def extra_repr_keys(self):
        """Returns `max_log_prob_diff` followed by the keys reported by the
        parent class."""
        inherited_keys = super().extra_repr_keys()
        return ["max_log_prob_diff"] + inherited_keys