Source code for textattack.search_methods.genetic_algorithm

Genetic Algorithm Word Swap
from abc import ABC, abstractmethod

import numpy as np
import torch

from textattack.goal_function_results import GoalFunctionResultStatus
from textattack.search_methods import PopulationBasedSearch, PopulationMember
from textattack.shared.validators import transformation_consists_of_word_swaps

class GeneticAlgorithm(PopulationBasedSearch, ABC):
    """Base class for attacking a model with word substitutions using a
    genetic algorithm.

    Args:
        pop_size (int): The population size. Defaults to 60.
        max_iters (int): The maximum number of iterations to use. Defaults to 20.
        temp (float): Temperature for softmax function used to normalize
            probability dist when sampling parents. Higher temperature increases
            the sensitivity to lower probability candidates. Defaults to 0.3.
        give_up_if_no_improvement (bool): If True, stop the search early if no
            candidate that improves the score is found. Defaults to False.
        post_crossover_check (bool): If True, check if child produced from
            crossover step passes the constraints. Defaults to True.
        max_crossover_retries (int): Maximum number of crossover retries if
            resulting child fails to pass the constraints. Applied only when
            `post_crossover_check` is set to `True`. Setting it to 0 means we
            immediately take one of the parents at random as the child upon
            failure. Negative values are treated as 0. Defaults to 20.
    """

    def __init__(
        self,
        pop_size=60,
        max_iters=20,
        temp=0.3,
        give_up_if_no_improvement=False,
        post_crossover_check=True,
        max_crossover_retries=20,
    ):
        self.max_iters = max_iters
        self.pop_size = pop_size
        self.temp = temp
        self.give_up_if_no_improvement = give_up_if_no_improvement
        self.post_crossover_check = post_crossover_check
        self.max_crossover_retries = max_crossover_retries

        # internal flag to indicate if search should end immediately
        self._search_over = False

    @abstractmethod
    def _modify_population_member(self, pop_member, new_text, new_result, word_idx):
        """Modify `pop_member` by returning a new copy with `new_text`,
        `new_result`, and `attributes` altered appropriately for given
        `word_idx`."""
        raise NotImplementedError()

    @abstractmethod
    def _get_word_select_prob_weights(self, pop_member):
        """Get the attribute of `pop_member` that is used for determining
        probability of each word being selected for perturbation."""
        raise NotImplementedError()

    def _perturb(self, pop_member, original_result, index=None):
        """Perturb `pop_member` and return it. Replaces a word at a random
        index (unless `index` is specified) in `pop_member`.

        Args:
            pop_member (PopulationMember): The population member being perturbed.
            original_result (GoalFunctionResult): Result of original sample being attacked
            index (int): Index of word to perturb. If ``None`` (the default),
                word indices are sampled in proportion to the weights returned
                by `_get_word_select_prob_weights`. Index 0 is a valid choice.
        Returns:
            Perturbed `PopulationMember`. If no candidate transformation
            improves the score, `pop_member` is returned unchanged.
        """
        num_words = pop_member.attacked_text.num_words
        # `word_select_prob_weights` is a list of values used for sampling one
        # word to transform. Copy it so zeroing out indices below does not
        # mutate the weights stored on `pop_member`.
        word_select_prob_weights = np.copy(
            self._get_word_select_prob_weights(pop_member)
        )
        non_zero_indices = np.count_nonzero(word_select_prob_weights)
        if non_zero_indices == 0:
            return pop_member

        # Each failed attempt zeroes one previously non-zero weight, so at most
        # `non_zero_indices` attempts can be made before the weights run out.
        for _ in range(non_zero_indices):
            # Compare against None: `index=0` is a legitimate request for the
            # first word and must not fall back to random sampling.
            if index is not None:
                idx = index
            else:
                w_select_probs = word_select_prob_weights / np.sum(
                    word_select_prob_weights
                )
                idx = np.random.choice(num_words, 1, p=w_select_probs)[0]

            transformed_texts = self.get_transformations(
                pop_member.attacked_text,
                original_text=original_result.attacked_text,
                indices_to_modify=[idx],
            )

            if transformed_texts:
                new_results, self._search_over = self.get_goal_results(
                    transformed_texts
                )
                diff_scores = (
                    torch.Tensor([r.score for r in new_results])
                    - pop_member.result.score
                )
                if len(diff_scores) and diff_scores.max() > 0:
                    best = int(diff_scores.argmax())
                    return self._modify_population_member(
                        pop_member,
                        transformed_texts[best],
                        new_results[best],
                        idx,
                    )

            # This word either has no transformations or none that improve
            # the score; exclude it so it is not sampled again.
            word_select_prob_weights[idx] = 0

            # Retrying a fixed `index` would only repeat the same attempt.
            if index is not None or self._search_over:
                break

        return pop_member

    @abstractmethod
    def _crossover_operation(self, pop_member1, pop_member2):
        """Actual operation that takes `pop_member1` text and `pop_member2`
        text and mixes the two to generate crossover between `pop_member1` and
        `pop_member2`.

        Args:
            pop_member1 (PopulationMember): The first population member.
            pop_member2 (PopulationMember): The second population member.
        Returns:
            Tuple of `AttackedText` and a dictionary of attributes.
        """
        raise NotImplementedError()

    def _post_crossover_check(
        self, new_text, parent_text1, parent_text2, original_text
    ):
        """Check if `new_text` that has been produced by performing crossover
        between `parent_text1` and `parent_text2` aligns with the constraints.

        Args:
            new_text (AttackedText): Text produced by crossover operation
            parent_text1 (AttackedText): Parent text of `new_text`
            parent_text2 (AttackedText): Second parent text of `new_text`
            original_text (AttackedText): Original text
        Returns:
            `True` if `new_text` meets the constraints. If otherwise, return `False`.
        """
        if "last_transformation" not in new_text.attack_attrs:
            # `new_text` has not been actually transformed, so return True
            return True

        # Check constraints against a parent that was itself transformed,
        # preferring the first parent.
        previous_text = (
            parent_text1
            if "last_transformation" in parent_text1.attack_attrs
            else parent_text2
        )
        return self._check_constraints(
            new_text, previous_text, original_text=original_text
        )

    def _crossover(self, pop_member1, pop_member2, original_text):
        """Generates a crossover between pop_member1 and pop_member2.

        If the child fails to satisfy the constraints, we re-try crossover for
        a fixed number of times, before taking one of the parents at random as
        the resulting child.

        Args:
            pop_member1 (PopulationMember): The first population member.
            pop_member2 (PopulationMember): The second population member.
            original_text (AttackedText): Original text
        Returns:
            A population member containing the crossover.
        """
        x1_text = pop_member1.attacked_text
        x2_text = pop_member2.attacked_text

        # Always make at least one attempt so `new_text` is defined even if
        # `max_crossover_retries` is configured as a negative number.
        num_attempts = max(self.max_crossover_retries, 0) + 1
        passed_constraints = False
        for _ in range(num_attempts):
            new_text, attributes = self._crossover_operation(pop_member1, pop_member2)

            # Indices taken from the second parent come from x2; all other
            # modified indices are inherited from x1.
            replaced_indices = new_text.attack_attrs["newly_modified_indices"]
            new_text.attack_attrs["modified_indices"] = (
                x1_text.attack_attrs["modified_indices"] - replaced_indices
            ) | (x2_text.attack_attrs["modified_indices"] & replaced_indices)

            if "last_transformation" in x1_text.attack_attrs:
                new_text.attack_attrs["last_transformation"] = x1_text.attack_attrs[
                    "last_transformation"
                ]
            elif "last_transformation" in x2_text.attack_attrs:
                new_text.attack_attrs["last_transformation"] = x2_text.attack_attrs[
                    "last_transformation"
                ]

            if self.post_crossover_check:
                passed_constraints = self._post_crossover_check(
                    new_text, x1_text, x2_text, original_text
                )

            if not self.post_crossover_check or passed_constraints:
                break

        if self.post_crossover_check and not passed_constraints:
            # If we cannot find a child that passes the constraints,
            # we just randomly pick one of the parents to be the child for the next iteration.
            return pop_member1 if np.random.uniform() < 0.5 else pop_member2

        new_results, self._search_over = self.get_goal_results([new_text])
        return PopulationMember(new_text, result=new_results[0], attributes=attributes)

    @abstractmethod
    def _initialize_population(self, initial_result, pop_size):
        """
        Initialize a population of size `pop_size` with `initial_result`
        Args:
            initial_result (GoalFunctionResult): Original text
            pop_size (int): size of population
        Returns:
            population as `list[PopulationMember]`
        """
        raise NotImplementedError()

    def check_transformation_compatibility(self, transformation):
        """The genetic algorithm is specifically designed for word
        substitutions."""
        return transformation_consists_of_word_swaps(transformation)

    @property
    def is_black_box(self):
        return True

    def extra_repr_keys(self):
        return [
            "pop_size",
            "max_iters",
            "temp",
            "give_up_if_no_improvement",
            "post_crossover_check",
            "max_crossover_retries",
        ]