# Source code for textattack.shared.validators

"""
Misc Validators
=================
Validators ensure compatibility between search methods, transformations, constraints, and goal functions.

"""
import re

import textattack
from textattack.goal_functions import (
    InputReduction,
    MinimizeBleu,
    NonOverlappingOutput,
    TargetedClassification,
    UntargetedClassification,
)

from . import logger

# Maps tuples of goal function classes to regular expressions that are
# matched (with ``re.match``) against a model's fully qualified class path,
# i.e. ``"<model_class.__module__>.<model_class.__name__>"``. A model whose
# path matches one of the patterns is treated as compatible with every goal
# function in the corresponding key.
MODELS_BY_GOAL_FUNCTIONS = {
    (TargetedClassification, UntargetedClassification, InputReduction): [
        r"^textattack.models.helpers.lstm_for_classification.*",
        r"^textattack.models.helpers.word_cnn_for_classification.*",
        # Older ``transformers`` releases define models in flat modules
        # (e.g. ``transformers.modeling_bert``); newer releases nest them in
        # subpackages (e.g. ``transformers.models.bert.modeling_bert``).
        # The optional group accepts both layouts.
        r"^transformers\.(?:models(?:\.\w+)+\.)?modeling_\w*\.\w*ForSequenceClassification$",
    ],
    (NonOverlappingOutput, MinimizeBleu): [
        r"^textattack.models.helpers.t5_for_text_to_text.*",
    ],
}

# Unroll `MODELS_BY_GOAL_FUNCTIONS` into a dictionary with one key per goal
# function class. (Note the singular/plural names that distinguish the two
# variables from one another.) Goal functions that shared a key in the table
# above share the same pattern list object here.
MODELS_BY_GOAL_FUNCTION = {
    goal_function: matching_model_globs
    for goal_functions, matching_model_globs in MODELS_BY_GOAL_FUNCTIONS.items()
    for goal_function in goal_functions
}


def validate_model_goal_function_compatibility(goal_function_class, model_class):
    """Determines if ``model_class`` is task-compatible with
    ``goal_function_class``.

    For example, a text-generative model like one intended for translation
    or summarization would not be compatible with a goal function that
    requires probability scores, like ``UntargetedClassification``.

    The check is advisory only and never raises. It logs at INFO level when
    the model matches a pattern registered for ``goal_function_class`` in
    ``MODELS_BY_GOAL_FUNCTION``; otherwise it logs a warning.

    Args:
        goal_function_class: The goal function class (not an instance) used
            as a key into ``MODELS_BY_GOAL_FUNCTION``.
        model_class: The model class (not an instance). Its ``__module__``
            and ``__name__`` are joined into the path matched against the
            registered patterns.

    Returns:
        None
    """
    # Look up the model patterns registered for this goal function.
    try:
        matching_model_globs = MODELS_BY_GOAL_FUNCTION[goal_function_class]
    except KeyError:
        matching_model_globs = []
        logger.warning(f"No entry found for goal function {goal_function_class}.")

    # Fully qualified class path, e.g. ``package.module.ClassName``.
    model_module_path = ".".join((model_class.__module__, model_class.__name__))

    # Compatible: the model matches one of this goal function's patterns.
    for glob in matching_model_globs:
        if re.match(glob, model_module_path):
            logger.info(
                f"Goal function {goal_function_class} compatible with model {model_class.__name__}."
            )
            return

    # The model matches a pattern registered for some other goal function(s):
    # warn that compatibility with the requested one is unknown.
    for goal_functions, globs in MODELS_BY_GOAL_FUNCTIONS.items():
        for glob in globs:
            if re.match(glob, model_module_path):
                logger.warning(
                    f"Unknown if model {model_class.__name__} compatible with provided goal function {goal_function_class}."
                    f" Found match with other goal functions: {goal_functions}."
                )
                return

    # Otherwise this is an unknown model - perhaps user-provided, or the
    # pattern tables were not updated for it. Warn the user and return.
    logger.warning(
        f"Unknown if model of class {model_class} compatible with goal function {goal_function_class}."
    )


def validate_model_gradient_word_swap_compatibility(model):
    """Determines if ``model`` is task-compatible with
    ``GradientBasedWordSwap``.

    We can only take the gradient with respect to an individual word if the
    model uses a word-based tokenizer. Only instances of
    ``textattack.models.helpers.LSTMForClassification`` are accepted here.

    Args:
        model: The model instance to check.

    Returns:
        True if ``model`` is an ``LSTMForClassification``.

    Raises:
        ValueError: If ``model`` is any other kind of model.
    """
    supported_class = textattack.models.helpers.LSTMForClassification
    if not isinstance(model, supported_class):
        raise ValueError(f"Cannot perform GradientBasedWordSwap on model {model}.")
    return True


def transformation_consists_of(transformation, transformation_classes):
    """Determines if ``transformation`` is, or consists only of, instances of
    a class in ``transformation_classes``.

    A ``CompositeTransformation`` qualifies when every one of its child
    transformations qualifies, checked recursively. An empty composite
    therefore qualifies vacuously. Any other transformation qualifies when
    it is an instance of at least one class in ``transformation_classes``.

    Args:
        transformation: The transformation to inspect.
        transformation_classes: Iterable of accepted transformation classes.

    Returns:
        bool: ``True`` if the condition above holds, otherwise ``False``.
    """
    # Imported inside the function (as originally written), presumably to
    # avoid an import cycle at module load time - TODO confirm.
    from textattack.transformations import CompositeTransformation

    if not isinstance(transformation, CompositeTransformation):
        return any(isinstance(transformation, cls) for cls in transformation_classes)

    return all(
        transformation_consists_of(child, transformation_classes)
        for child in transformation.transformations
    )


def transformation_consists_of_word_swaps(transformation):
    """Determines if ``transformation`` is a word swap or consists only of
    word swaps.

    Instances of ``WordSwap`` and ``WordSwapGradientBased`` are accepted.
    Composite transformations are inspected recursively by
    ``transformation_consists_of``.

    Args:
        transformation: The transformation to inspect.

    Returns:
        bool: Whether ``transformation`` is made up solely of word swaps.
    """
    from textattack.transformations import WordSwap, WordSwapGradientBased

    accepted_classes = [WordSwap, WordSwapGradientBased]
    return transformation_consists_of(transformation, accepted_classes)


def transformation_consists_of_word_swaps_and_deletions(transformation):
    """Determines if ``transformation`` consists only of word swaps and/or
    word deletions.

    Instances of ``WordDeletion``, ``WordSwap`` and ``WordSwapGradientBased``
    are accepted, so a lone deletion qualifies as well. Composite
    transformations are inspected recursively by
    ``transformation_consists_of``.

    Args:
        transformation: The transformation to inspect.

    Returns:
        bool: Whether ``transformation`` is made up solely of word swaps and
        word deletions.
    """
    from textattack.transformations import WordDeletion, WordSwap, WordSwapGradientBased

    accepted_classes = [WordDeletion, WordSwap, WordSwapGradientBased]
    return transformation_consists_of(transformation, accepted_classes)