Source code for textattack.goal_functions.classification.hardlabel_classification

"""
Determine if an attack has been successful in Hard Label Classficiation.
------------------------------------------------------------------------
"""

from .classification_goal_function import ClassificationGoalFunction


[docs]class HardLabelClassification(ClassificationGoalFunction): """An hard label attack on classification models which attempts to maximize the semantic similarity of the label such that the target is outside of the decision boundary. Args: target_max_score (float): If set, goal is to reduce model output to below this score. Otherwise, goal is to change the overall predicted class. """ def __init__(self, *args, target_max_score=None, **kwargs): self.target_max_score = target_max_score super().__init__(*args, **kwargs) def _is_goal_complete(self, model_output, _): if self.target_max_score: return model_output[self.ground_truth_output] < self.target_max_score elif (model_output.numel() == 1) and isinstance( self.ground_truth_output, float ): return abs(self.ground_truth_output - model_output.item()) >= 0.5 else: return model_output.argmax() != self.ground_truth_output def _get_score(self, model_output, _): # If the model outputs a single number and the ground truth output is # a float, we assume that this is a regression task. if (model_output.numel() == 1) and isinstance(self.ground_truth_output, float): return max(model_output.item(), self.ground_truth_output) else: return 1 - model_output[self.ground_truth_output]