# Source code for textattack.goal_functions.classification.targeted_classification

"""

Determine if an attack has been successful in targeted Classification
-----------------------------------------------------------------------
"""

from .classification_goal_function import ClassificationGoalFunction


class TargetedClassification(ClassificationGoalFunction):
    """A targeted attack on classification models which attempts to maximize
    the score of the target label.

    Complete when the target label is the predicted label.

    Args:
        target_class (int): Index of the label the attack tries to make the
            model predict. Defaults to ``0``. All other positional and
            keyword arguments are forwarded unchanged to the parent
            ``ClassificationGoalFunction`` constructor.
    """

    def __init__(self, *args, target_class=0, **kwargs):
        super().__init__(*args, **kwargs)
        self.target_class = target_class

    def _is_goal_complete(self, model_output, _):
        """Return whether the targeted goal is met for ``model_output``.

        The goal counts as complete when either of these holds:
        - the highest-scoring entry of ``model_output`` (its ``argmax()``) is
          ``target_class``;
        - ``self.ground_truth_output`` already equals ``target_class``.
          That attribute is not set in this class; presumably the parent sets
          it for the current example. Verify against the base class.

        The second argument is unused.
        """
        # NOTE(review): model_output is presumably a 1-D tensor/array of
        # per-class scores (it must support ``argmax()``). Confirm with caller.
        return (
            self.target_class == model_output.argmax()
        ) or self.ground_truth_output == self.target_class

    def _get_score(self, model_output, _):
        """Return the score ``model_output`` assigns to ``target_class``.

        The second argument is unused.

        Raises:
            ValueError: If ``target_class`` is negative or not less than
                ``len(model_output)``.
        """
        # Validate explicitly. Otherwise a negative index would silently read
        # a score from the end of ``model_output`` instead of failing.
        if self.target_class < 0 or self.target_class >= len(model_output):
            raise ValueError(
                f"target class set to {self.target_class} with {len(model_output)} classes."
            )
        return model_output[self.target_class]

    def extra_repr_keys(self):
        """Return the attribute names to include in this object's repr.

        ``"maximizable"`` is listed only when ``self.maximizable`` is truthy.
        """
        if self.maximizable:
            return ["maximizable", "target_class"]
        return ["target_class"]