Source code for textattack.transformations.word_swaps.word_swap_masked_lm

"""
Word Swap by BERT-Masked LM.
-------------------------------
"""

import itertools
import re

import torch
from transformers import AutoModelForMaskedLM, AutoTokenizer

from textattack.shared import utils

from .word_swap import WordSwap


[docs]class WordSwapMaskedLM(WordSwap): """Generate potential replacements for a word using a masked language model. Based off of following papers - "Robustness to Modification with Shared Words in Paraphrase Identification" (Shi et al., 2019) https://arxiv.org/abs/1909.02560 - "BAE: BERT-based Adversarial Examples for Text Classification" (Garg et al., 2020) https://arxiv.org/abs/2004.01970 - "BERT-ATTACK: Adversarial Attack Against BERT Using BERT" (Li et al, 2020) https://arxiv.org/abs/2004.09984 - "CLARE: Contextualized Perturbation for Textual Adversarial Attack" (Li et al, 2020): https://arxiv.org/abs/2009.07502 BAE and CLARE simply masks the word we want to replace and selects replacements predicted by the masked language model. BERT-Attack instead performs replacement on token level. For words that are consisted of two or more sub-word tokens, it takes the top-K replacements for seach sub-word token and produces all possible combinations of the top replacments. Then, it selects the top-K combinations based on their perplexity calculated using the masked language model. Choose which method to use by specifying "bae" or "bert-attack" for `method` argument. Args: method (str): the name of replacement method (e.g. "bae", "bert-attack") masked_language_model (Union[str|transformers.AutoModelForMaskedLM]): Either the name of pretrained masked language model from `transformers` model hub or the actual model. Default is `bert-base-uncased`. tokenizer (obj): The tokenizer of the corresponding model. If you passed in name of a pretrained model for `masked_language_model`, you can skip this argument as the correct tokenizer can be infered from the name. However, if you're passing the actual model, you must provide a tokenizer. max_length (int): the max sequence length the masked language model is designed to work with. Default is 512. window_size (int): The number of surrounding words to include when making top word prediction. For each word to swap, we take `window_size // 2` words to the left and `window_size // 2` words to the right and pass the text within the window to the masked language model. Default is `float("inf")`, which is equivalent to using the whole text. max_candidates (int): maximum number of candidates to consider as replacements for each word. Replacements are ranked by model's confidence. min_confidence (float): minimum confidence threshold each replacement word must pass. batch_size (int): Size of batch for "bae" replacement method. """ def __init__( self, method="bae", masked_language_model="bert-base-uncased", tokenizer=None, max_length=512, window_size=float("inf"), max_candidates=50, min_confidence=5e-4, batch_size=16, **kwargs, ): super().__init__(**kwargs) self.method = method self.max_length = max_length self.window_size = window_size self.max_candidates = max_candidates self.min_confidence = min_confidence self.batch_size = batch_size if isinstance(masked_language_model, str): self._language_model = AutoModelForMaskedLM.from_pretrained( masked_language_model ) self._lm_tokenizer = AutoTokenizer.from_pretrained( masked_language_model, use_fast=True ) else: self._language_model = masked_language_model if tokenizer is None: raise ValueError( "`tokenizer` argument must be provided when passing an actual model as `masked_language_model`." ) self._lm_tokenizer = tokenizer self._language_model.to(utils.device) self._language_model.eval() self.masked_lm_name = self._language_model.__class__.__name__ def _encode_text(self, text): """Encodes ``text`` using an ``AutoTokenizer``, ``self._lm_tokenizer``. Returns a ``dict`` where keys are strings (like 'input_ids') and values are ``torch.Tensor``s. Moves tensors to the same device as the language model. """ encoding = self._lm_tokenizer( text, max_length=self.max_length, truncation=True, padding="max_length", return_tensors="pt", ) return encoding.to(utils.device) def _bae_replacement_words(self, current_text, indices_to_modify): """Get replacement words for the word we want to replace using BAE method. Args: current_text (AttackedText): Text we want to get replacements for. index (int): index of word we want to replace """ masked_texts = [] for index in indices_to_modify: masked_text = current_text.replace_word_at_index( index, self._lm_tokenizer.mask_token ) masked_texts.append(masked_text.text) i = 0 # 2-D list where for each index to modify we have a list of replacement words replacement_words = [] while i < len(masked_texts): inputs = self._encode_text(masked_texts[i : i + self.batch_size]) ids = inputs["input_ids"].tolist() with torch.no_grad(): preds = self._language_model(**inputs)[0] for j in range(len(ids)): try: # Need try-except b/c mask-token located past max_length might be truncated by tokenizer masked_index = ids[j].index(self._lm_tokenizer.mask_token_id) except ValueError: replacement_words.append([]) continue mask_token_logits = preds[j, masked_index] mask_token_probs = torch.softmax(mask_token_logits, dim=0) ranked_indices = torch.argsort(mask_token_probs, descending=True) top_words = [] for _id in ranked_indices: _id = _id.item() word = self._lm_tokenizer.convert_ids_to_tokens(_id) if utils.check_if_subword( word, self._language_model.config.model_type, (masked_index == 1), ): word = utils.strip_BPE_artifacts( word, self._language_model.config.model_type ) if ( mask_token_probs[_id] >= self.min_confidence and utils.is_one_word(word) and not utils.check_if_punctuations(word) ): top_words.append(word) if ( len(top_words) >= self.max_candidates or mask_token_probs[_id] < self.min_confidence ): break replacement_words.append(top_words) i += self.batch_size return replacement_words def _bert_attack_replacement_words( self, current_text, index, id_preds, masked_lm_logits, ): """Get replacement words for the word we want to replace using BERT- Attack method. Args: current_text (AttackedText): Text we want to get replacements for. index (int): index of word we want to replace id_preds (torch.Tensor): N x K tensor of top-K ids for each token-position predicted by the masked language model. N is equivalent to `self.max_length`. masked_lm_logits (torch.Tensor): N x V tensor of the raw logits outputted by the masked language model. N is equivlaent to `self.max_length` and V is dictionary size of masked language model. """ # We need to find which BPE tokens belong to the word we want to replace masked_text = current_text.replace_word_at_index( index, self._lm_tokenizer.mask_token ) current_inputs = self._encode_text(masked_text.text) current_ids = current_inputs["input_ids"].tolist()[0] word_tokens = self._lm_tokenizer.encode( current_text.words[index], add_special_tokens=False ) try: # Need try-except b/c mask-token located past max_length might be truncated by tokenizer masked_index = current_ids.index(self._lm_tokenizer.mask_token_id) except ValueError: return [] # List of indices of tokens that are part of the target word target_ids_pos = list( range(masked_index, min(masked_index + len(word_tokens), self.max_length)) ) if not len(target_ids_pos): return [] elif len(target_ids_pos) == 1: # Word to replace is tokenized as a single word top_preds = id_preds[target_ids_pos[0]].tolist() replacement_words = [] for id in top_preds: token = self._lm_tokenizer.convert_ids_to_tokens(id) if utils.is_one_word(token) and not utils.check_if_subword( token, self._language_model.config.model_type, index == 0 ): replacement_words.append(token) return replacement_words else: # Word to replace is tokenized as multiple sub-words top_preds = [id_preds[i] for i in target_ids_pos] products = itertools.product(*top_preds) combination_results = [] # Original BERT-Attack implement uses cross-entropy loss cross_entropy_loss = torch.nn.CrossEntropyLoss(reduction="none") target_ids_pos_tensor = torch.tensor(target_ids_pos) word_tensor = torch.zeros(len(target_ids_pos), dtype=torch.long) for bpe_tokens in products: for i in range(len(bpe_tokens)): word_tensor[i] = bpe_tokens[i] logits = torch.index_select(masked_lm_logits, 0, target_ids_pos_tensor) loss = cross_entropy_loss(logits, word_tensor) perplexity = torch.exp(torch.mean(loss, dim=0)).item() word = "".join( self._lm_tokenizer.convert_ids_to_tokens(word_tensor) ).replace("##", "") if utils.is_one_word(word): combination_results.append((word, perplexity)) # Sort to get top-K results sorted(combination_results, key=lambda x: x[1]) top_replacements = [ x[0] for x in combination_results[: self.max_candidates] ] return top_replacements def _get_transformations(self, current_text, indices_to_modify): indices_to_modify = list(indices_to_modify) if self.method == "bert-attack": current_inputs = self._encode_text(current_text.text) with torch.no_grad(): pred_probs = self._language_model(**current_inputs)[0][0] top_probs, top_ids = torch.topk(pred_probs, self.max_candidates) id_preds = top_ids.cpu() masked_lm_logits = pred_probs.cpu() transformed_texts = [] for i in indices_to_modify: word_at_index = current_text.words[i] replacement_words = self._bert_attack_replacement_words( current_text, i, id_preds=id_preds, masked_lm_logits=masked_lm_logits, ) for r in replacement_words: r = r.strip("Ġ") if r != word_at_index: transformed_texts.append( current_text.replace_word_at_index(i, r) ) return transformed_texts elif self.method == "bae": replacement_words = self._bae_replacement_words( current_text, indices_to_modify ) transformed_texts = [] for i in range(len(replacement_words)): index_to_modify = indices_to_modify[i] word_at_index = current_text.words[index_to_modify] for word in replacement_words[i]: word = word.strip("Ġ") if ( word != word_at_index and re.search("[a-zA-Z]", word) and len(utils.words_from_text(word)) == 1 ): transformed_texts.append( current_text.replace_word_at_index(index_to_modify, word) ) return transformed_texts else: raise ValueError(f"Unrecognized value {self.method} for `self.method`.")
[docs] def extra_repr_keys(self): return [ "method", "masked_lm_name", "max_length", "max_candidates", "min_confidence", ]
[docs]def recover_word_case(word, reference_word): """Makes the case of `word` like the case of `reference_word`. Supports lowercase, UPPERCASE, and Capitalized. """ if reference_word.islower(): return word.lower() elif reference_word.isupper() and len(reference_word) > 1: return word.upper() elif reference_word[0].isupper() and reference_word[1:].islower(): return word.capitalize() else: # if other, just do not alter the word's case return word