# Source code for textattack.shared.attacked_text

""".. _attacked_text:

Attacked Text Class
=====================

A helper class that represents a string that can be attacked.
"""

from __future__ import annotations

from collections import OrderedDict
import math
from typing import Dict, Iterable, List, Optional, Set, Tuple

import flair
from flair.data import Sentence
import numpy as np
import torch

import textattack

from .utils import device, words_from_text

# Point flair's global device at the same device textattack uses, so flair
# taggers run on the same hardware as the rest of the attack.
setattr(flair, "device", device)


class AttackedText:
    """A helper class that represents a string that can be attacked.

    Models that take multiple sentences as input separate them by
    ``SPLIT_TOKEN``. Attacks "see" the entire input, joined into one string,
    without the split token.

    ``AttackedText`` instances that were perturbed from other ``AttackedText``
    objects contain a pointer to the previous text
    (``attack_attrs["previous_attacked_text"]``), so that the full chain of
    perturbations might be reconstructed by using this key to form a linked
    list.

    Args:
        text (string):
            The string that this AttackedText represents
        attack_attrs (dict):
            Dictionary of various attributes stored during the course of an
            attack.
    """

    SPLIT_TOKEN = "<SPLIT>"

    def __init__(self, text_input, attack_attrs=None):
        # Read in ``text_input`` as a string or OrderedDict.
        if isinstance(text_input, str):
            self._text_input = OrderedDict([("text", text_input)])
        elif isinstance(text_input, OrderedDict):
            self._text_input = text_input
        else:
            raise TypeError(
                f"Invalid text_input type {type(text_input)} (required str or OrderedDict)"
            )
        # Process input lazily.
        self._words = None
        self._words_per_input = None
        self._pos_tags = None
        self._ner_tags = None
        # Format text inputs: take a private copy so later mutation of the
        # caller's OrderedDict does not leak into this instance.
        self._text_input = OrderedDict(self._text_input.items())
        if attack_attrs is None:
            self.attack_attrs = dict()
        elif isinstance(attack_attrs, dict):
            # NOTE: the caller's dict is used (and mutated below) as-is.
            self.attack_attrs = attack_attrs
        else:
            raise TypeError(f"Invalid type for attack_attrs: {type(attack_attrs)}")
        # Indices of words from the *original* text. Allows us to map
        # indices between original text and this text, and vice-versa.
        self.attack_attrs.setdefault("original_index_map", np.arange(self.num_words))
        # A set of all indices in *this* text that have been modified.
        self.attack_attrs.setdefault("modified_indices", set())

    def __eq__(self, other: object) -> bool:
        """Compares two AttackedText instances.

        Note: Does not compute true equality across attack attributes.
        We found this caused large performance issues with caching,
        and it's actually much faster (cache-wise) to just compare
        by the text, and this works for lots of use cases.

        Comparing against a non-``AttackedText`` returns ``NotImplemented``
        so Python falls back to its default handling instead of raising.
        """
        if not isinstance(other, AttackedText):
            return NotImplemented
        if self.text != other.text:
            return False
        return len(self.attack_attrs) == len(other.attack_attrs)

    def __hash__(self) -> int:
        # Consistent with ``__eq__``: equal objects always share ``text``.
        return hash(self.text)

    def free_memory(self):
        """Delete items that take up memory.

        Can be called once the AttackedText is only needed to display.
        Recursively frees the chain of previous texts, then drops
        ``last_transformation`` and every ``torch.Tensor`` value.
        """
        if "previous_attacked_text" in self.attack_attrs:
            self.attack_attrs["previous_attacked_text"].free_memory()
            self.attack_attrs.pop("previous_attacked_text", None)

        self.attack_attrs.pop("last_transformation", None)

        # Iterate over a snapshot of the keys: deleting from a dict while
        # iterating it directly raises RuntimeError.
        for key in list(self.attack_attrs):
            if isinstance(self.attack_attrs[key], torch.Tensor):
                del self.attack_attrs[key]

    def text_window_around_index(self, index: int, window_size: int) -> str:
        """The text window of ``window_size`` words centered around ``index``.

        Near either end of the text the window is shifted so that it still
        contains up to ``window_size`` words.
        """
        length = self.num_words
        half_size = (window_size - 1) / 2.0
        if index - half_size < 0:
            start = 0
            end = min(window_size - 1, length - 1)
        elif index + half_size >= length:
            start = max(0, length - window_size)
            end = length - 1
        else:
            start = index - math.ceil(half_size)
            end = index + math.floor(half_size)
        text_idx_start = self._text_index_of_word_index(start)
        text_idx_end = self._text_index_of_word_index(end) + len(self.words[end])
        return self.text[text_idx_start:text_idx_end]

    def pos_of_word_index(self, desired_word_idx: int) -> str:
        """Returns the part-of-speech of the word at index ``desired_word_idx``.

        Uses FLAIR part-of-speech tagger. Tags are computed on first use and
        cached on the instance.

        Throws: ValueError, if no POS tag found for index. An AssertionError
        is raised if a word of ``self.words`` is missing from the tags.
        """
        if not self._pos_tags:
            sentence = Sentence(
                self.text,
                use_tokenizer=textattack.shared.utils.TextAttackFlairTokenizer(),
            )
            textattack.shared.utils.flair_tag(sentence)
            self._pos_tags = sentence
        flair_word_list, flair_pos_list = textattack.shared.utils.zip_flair_result(
            self._pos_tags
        )

        for word_idx, word in enumerate(self.words):
            assert (
                word in flair_word_list
            ), "word absent in flair returned part-of-speech tags"
            word_idx_in_flair_tags = flair_word_list.index(word)
            if word_idx == desired_word_idx:
                return flair_pos_list[word_idx_in_flair_tags]
            else:
                # Drop everything up to and including this match so repeated
                # words resolve to their next occurrence.
                flair_word_list = flair_word_list[word_idx_in_flair_tags + 1 :]
                flair_pos_list = flair_pos_list[word_idx_in_flair_tags + 1 :]

        raise ValueError(
            f"Did not find word from index {desired_word_idx} in flair POS tag"
        )

    def ner_of_word_index(self, desired_word_idx: int, model_name="ner") -> str:
        """Returns the NER tag of the word at index ``desired_word_idx``.

        Uses FLAIR NER tagger. Tags are computed on first use and cached on
        the instance; NOTE(review): the cache ignores ``model_name``, so later
        calls with a different model reuse the first model's tags.

        Throws: ValueError, if no NER tag is found for the index.
        """
        if not self._ner_tags:
            sentence = Sentence(
                self.text,
                use_tokenizer=textattack.shared.utils.TextAttackFlairTokenizer(),
            )
            textattack.shared.utils.flair_tag(sentence, model_name)
            self._ner_tags = sentence
        flair_word_list, flair_ner_list = textattack.shared.utils.zip_flair_result(
            self._ner_tags, "ner"
        )
        # Walking the flair word list position by position always resolves
        # position k to the k-th tag, so index directly. Negative values are
        # not valid word positions.
        if 0 <= desired_word_idx < len(flair_word_list):
            return flair_ner_list[desired_word_idx]

        raise ValueError(
            f"Did not find word from index {desired_word_idx} in flair NER tag"
        )

    def _text_index_of_word_index(self, i: int) -> int:
        """Returns the index of word ``i`` in self.text."""
        pre_words = self.words[: i + 1]
        lower_text = self.text.lower()
        # Find all words until `i` in string, scanning left to right so that
        # repeated words are matched in order (case-insensitively).
        look_after_index = 0
        for word in pre_words:
            look_after_index = lower_text.find(word.lower(), look_after_index) + len(
                word
            )
        # The loop leaves us at the *end* of word ``i``; step back to its start.
        look_after_index -= len(self.words[i])
        return look_after_index

    def text_until_word_index(self, i: int) -> str:
        """Returns the text before the beginning of word at index ``i``."""
        look_after_index = self._text_index_of_word_index(i)
        return self.text[:look_after_index]

    def text_after_word_index(self, i: int) -> str:
        """Returns the text after the end of word at index ``i``."""
        # Get index of beginning of word then jump to end of word.
        look_after_index = self._text_index_of_word_index(i) + len(self.words[i])
        return self.text[look_after_index:]

    def first_word_diff(self, other_attacked_text: AttackedText) -> Optional[str]:
        """Returns the first word in self.words that differs from
        other_attacked_text, or None if all words are the same (only the
        common prefix length is compared).

        Useful for word swap strategies.
        """
        for own_word, other_word in zip(self.words, other_attacked_text.words):
            if own_word != other_word:
                return own_word
        return None

    def first_word_diff_index(self, other_attacked_text: AttackedText) -> Optional[int]:
        """Returns the index of the first word in self.words that differs from
        other_attacked_text, or None if no difference exists over the common
        prefix.

        Useful for word swap strategies.
        """
        pairs = zip(self.words, other_attacked_text.words)
        for i, (own_word, other_word) in enumerate(pairs):
            if own_word != other_word:
                return i
        return None

    def all_words_diff(self, other_attacked_text: AttackedText) -> Set[int]:
        """Returns the set of indices for which this and other_attacked_text
        have different words (only positions present in both are compared)."""
        pairs = zip(self.words, other_attacked_text.words)
        return {i for i, (w1, w2) in enumerate(pairs) if w1 != w2}

    def ith_word_diff(self, other_attacked_text: AttackedText, i: int) -> bool:
        """Returns bool representing whether the word at index i differs from
        other_attacked_text. An index past the end of either text counts as a
        difference."""
        w1 = self.words
        w2 = other_attacked_text.words
        if len(w1) - 1 < i or len(w2) - 1 < i:
            return True
        return w1[i] != w2[i]

    def words_diff_num(self, other_attacked_text: AttackedText) -> int:
        """The number of words different between two AttackedText objects.

        Computed as the word-level Levenshtein (edit) distance: the minimum
        number of word insertions, deletions and substitutions needed to turn
        ``self.words`` into ``other_attacked_text.words``.
        """
        w1 = self.words
        w2 = other_attacked_text.words
        # Classic DP, keeping only the previous row of the distance matrix.
        previous = list(range(len(w2) + 1))
        for i, word1 in enumerate(w1, start=1):
            current = [i]
            for j, word2 in enumerate(w2, start=1):
                current.append(
                    min(
                        previous[j] + 1,  # delete word1
                        current[j - 1] + 1,  # insert word2
                        previous[j - 1] + (word1 != word2),  # substitute
                    )
                )
            previous = current
        return previous[-1]

    def convert_from_original_idxs(self, idxs: Iterable[int]) -> List[int]:
        """Takes indices of words from original string and converts them to
        indices of the same words in the current string.

        Uses information from
        ``self.attack_attrs['original_index_map']``, which maps word
        indices from the original to perturbed text. If the map is empty,
        ``idxs`` is returned unchanged. Accepts a set, list or numpy array;
        other types raise TypeError.
        """
        if len(self.attack_attrs["original_index_map"]) == 0:
            return idxs
        elif isinstance(idxs, set):
            idxs = list(idxs)

        elif not isinstance(idxs, (list, np.ndarray)):
            raise TypeError(
                f"convert_from_original_idxs got invalid idxs type {type(idxs)}"
            )

        return [self.attack_attrs["original_index_map"][i] for i in idxs]

    def get_deletion_indices(self) -> Iterable[int]:
        """Returns the original-text word indices that have been deleted.

        Deleted words are marked with ``-1`` in
        ``attack_attrs["original_index_map"]``; this returns their positions
        in that map (i.e. their indices in the original text).
        """
        index_map = np.asarray(self.attack_attrs["original_index_map"])
        return np.flatnonzero(index_map == -1)

    def replace_words_at_indices(
        self, indices: Iterable[int], new_words: Iterable[str]
    ) -> AttackedText:
        """Returns a new AttackedText object where the word at each of
        ``indices`` is replaced with the corresponding entry of ``new_words``.

        Raises:
            ValueError: if the lengths differ or an index is out of range.
            TypeError: if a replacement is not a ``str``.
        """
        if len(indices) != len(new_words):
            raise ValueError(
                f"Cannot replace {len(new_words)} words at {len(indices)} indices."
            )
        words = self.words[:]
        for i, new_word in zip(indices, new_words):
            if not isinstance(new_word, str):
                raise TypeError(
                    f"replace_words_at_indices requires ``str`` words, got {type(new_word)}"
                )
            # Valid positions are 0 .. len(words) - 1.
            if (i < 0) or (i >= len(words)):
                raise ValueError(f"Cannot assign word at index {i}")
            words[i] = new_word
        return self.generate_new_attacked_text(words)

    def replace_word_at_index(self, index: int, new_word: str) -> AttackedText:
        """Returns a new AttackedText object where the word at ``index`` is
        replaced with a new word."""
        if not isinstance(new_word, str):
            raise TypeError(
                f"replace_word_at_index requires ``str`` new_word, got {type(new_word)}"
            )
        return self.replace_words_at_indices([index], [new_word])

    def delete_word_at_index(self, index: int) -> AttackedText:
        """Returns a new AttackedText object where the word at ``index`` is
        removed (replaced with the empty string)."""
        return self.replace_word_at_index(index, "")

    def insert_text_after_word_index(self, index: int, text: str) -> AttackedText:
        """Inserts a string after word at index ``index`` and attempts to add
        appropriate spacing."""
        if not isinstance(text, str):
            raise TypeError(f"text must be an str, got type {type(text)}")
        word_at_index = self.words[index]
        new_text = " ".join((word_at_index, text))
        return self.replace_word_at_index(index, new_text)

    def insert_text_before_word_index(self, index: int, text: str) -> AttackedText:
        """Inserts a string before word at index ``index`` and attempts to add
        appropriate spacing."""
        if not isinstance(text, str):
            raise TypeError(f"text must be an str, got type {type(text)}")
        word_at_index = self.words[index]
        # TODO if ``word_at_index`` is at the beginning of a sentence, we should
        # optionally capitalize ``text``.
        new_text = " ".join((text, word_at_index))
        return self.replace_word_at_index(index, new_text)

    def generate_new_attacked_text(self, new_words: Iterable[str]) -> AttackedText:
        """Returns a new AttackedText object and replaces old list of words
        with a new list of words, but preserves the punctuation and spacing of
        the original message.

        ``self.words`` is a list of the words in the current text with
        punctuation removed. However, each "word" in ``new_words`` could
        be an empty string, representing a word deletion, or a string
        with multiple space-separated words, representing an insertion
        of one or more words.
        """
        perturbed_text = ""
        original_text = AttackedText.SPLIT_TOKEN.join(self._text_input.values())
        new_attack_attrs = dict()
        if "label_names" in self.attack_attrs:
            new_attack_attrs["label_names"] = self.attack_attrs["label_names"]
        new_attack_attrs["newly_modified_indices"] = set()
        # Point to previously monitored text.
        new_attack_attrs["previous_attacked_text"] = self
        # Use `new_attack_attrs` to track indices with respect to the original
        # text.
        new_attack_attrs["modified_indices"] = self.attack_attrs[
            "modified_indices"
        ].copy()
        new_attack_attrs["original_index_map"] = self.attack_attrs[
            "original_index_map"
        ].copy()
        # Position of the current word in the *new* text. It diverges from
        # ``i`` once an earlier word in this call was expanded or deleted.
        new_i = 0
        # Create the new attacked text by swapping out words from the original
        # text with a sequence of 0+ words in the new text.
        for i, (input_word, adv_word_seq) in enumerate(zip(self.words, new_words)):
            word_start = original_text.index(input_word)
            word_end = word_start + len(input_word)
            perturbed_text += original_text[:word_start]
            original_text = original_text[word_end:]
            adv_words = words_from_text(adv_word_seq)
            adv_num_words = len(adv_words)
            num_words_diff = adv_num_words - len(words_from_text(input_word))
            # Track indices on insertions and deletions.
            if num_words_diff != 0:
                # Both ``modified_indices`` and the values of
                # ``original_index_map`` are already in new-text coordinates
                # (earlier changes in this loop shifted them), so compare
                # against ``new_i`` rather than the old-text index ``i``.
                shifted_modified_indices = set()
                for modified_idx in new_attack_attrs["modified_indices"]:
                    if modified_idx < new_i:
                        shifted_modified_indices.add(modified_idx)
                    elif modified_idx > new_i:
                        shifted_modified_indices.add(modified_idx + num_words_diff)
                    # The word being replaced itself is re-added below if changed.
                new_attack_attrs["modified_indices"] = shifted_modified_indices
                # Track insertions and deletions wrt original text.
                new_idx_map = new_attack_attrs["original_index_map"].copy()
                if num_words_diff == -1:
                    # Word deletion
                    new_idx_map[new_idx_map == new_i] = -1
                new_idx_map[new_idx_map > new_i] += num_words_diff

                if num_words_diff > 0 and input_word != adv_words[0]:
                    # If insertion happens before the `input_word`
                    new_idx_map[new_idx_map == new_i] += num_words_diff

                new_attack_attrs["original_index_map"] = new_idx_map
            # Move pointer and save indices of new modified words.
            if input_word != adv_word_seq:
                new_positions = range(new_i, new_i + adv_num_words)
                new_attack_attrs["modified_indices"].update(new_positions)
                new_attack_attrs["newly_modified_indices"].update(new_positions)
            new_i += adv_num_words
            # Check spaces for deleted text.
            if adv_num_words == 0 and len(original_text):
                # Remove extra space (or else there would be two spaces for each
                # deleted word).
                # @TODO What to do with punctuation in this case? This behavior is undefined.
                if i == 0 or not perturbed_text:
                    # Nothing precedes the deleted word in the output (first
                    # word, or every earlier word was deleted too), so take a
                    # subsequent space instead of indexing an empty string.
                    if original_text[0] == " ":
                        original_text = original_text[1:]
                elif perturbed_text[-1] == " ":
                    # Otherwise take a preceding space.
                    perturbed_text = perturbed_text[:-1]
            # Add substitute word(s) to new sentence.
            perturbed_text += adv_word_seq
        perturbed_text += original_text  # Add all of the ending punctuation.
        # Reform perturbed_text into an OrderedDict.
        perturbed_input_texts = perturbed_text.split(AttackedText.SPLIT_TOKEN)
        perturbed_input = OrderedDict(
            zip(self._text_input.keys(), perturbed_input_texts)
        )
        return AttackedText(perturbed_input, attack_attrs=new_attack_attrs)

    def words_diff_ratio(self, x: AttackedText) -> float:
        """Get the ratio of words difference between current text and `x`.

        Note that current text and `x` must have same number of words.
        An empty text yields 0.0.
        """
        assert self.num_words == x.num_words
        if self.num_words == 0:
            return 0.0
        # Compare word by word: ``list != list`` would give a single bool.
        num_diff = sum(w1 != w2 for w1, w2 in zip(self.words, x.words))
        return float(num_diff) / self.num_words

    def align_with_model_tokens(
        self, model_wrapper: textattack.models.wrappers.ModelWrapper
    ) -> Dict[int, Iterable[int]]:
        """Align AttackedText's `words` with target model's tokenization scheme
        (e.g. word, character, subword). Specifically, we map each word to list
        of indices of tokens that compose the word (e.g. embedding --> ["em",
        "##bed", "##ding"])

        Args:
            model_wrapper (textattack.models.wrappers.ModelWrapper): ModelWrapper of the target model

        Returns:
            word2token_mapping (dict[int, list[int]]): Dictionary that maps i-th word to list of indices.
            Words that match no token map to ``None``.
        """
        tokens = model_wrapper.tokenize([self.tokenizer_input], strip_prefix=True)[0]
        word2token_mapping = {}
        j = 0
        last_matched = 0

        for i, word in enumerate(self.words):
            matched_tokens = []
            # Greedily consume tokens that are prefixes of the remaining word.
            while j < len(tokens) and len(word) > 0:
                token = tokens[j].lower()
                idx = word.lower().find(token)
                if idx == 0:
                    word = word[idx + len(token) :]
                    matched_tokens.append(j)
                    last_matched = j
                j += 1

            if not matched_tokens:
                word2token_mapping[i] = None
                # Rewind so the next word can still match tokens we skipped.
                j = last_matched
            else:
                word2token_mapping[i] = matched_tokens

        return word2token_mapping

    @property
    def tokenizer_input(self) -> Tuple[str]:
        """The inputs to be passed to the tokenizer: a single ``str`` for
        single-input texts, otherwise a tuple of strings."""
        input_tuple = tuple(self._text_input.values())
        # Prefer to return a string instead of a tuple with a single value.
        if len(input_tuple) == 1:
            return input_tuple[0]
        else:
            return input_tuple

    @property
    def column_labels(self) -> List[str]:
        """Returns the labels for this text's columns.

        For single-sequence inputs, this simply returns ['text'].
        """
        return list(self._text_input.keys())

    @property
    def words_per_input(self) -> List[List[str]]:
        """Returns a list of lists of words corresponding to each input."""
        if not self._words_per_input:
            self._words_per_input = [
                words_from_text(_input) for _input in self._text_input.values()
            ]
        return self._words_per_input

    @property
    def words(self) -> List[str]:
        """The words of ``self.text``, computed lazily and cached."""
        if not self._words:
            self._words = words_from_text(self.text)
        return self._words

    @property
    def text(self) -> str:
        """Represents full text input.

        Multiple inputs are joined with a line break.
        """
        return "\n".join(self._text_input.values())

    @property
    def num_words(self) -> int:
        """Returns the number of words in the sequence."""
        return len(self.words)

    @property
    def newly_swapped_words(self) -> List[str]:
        """Words of the previous text at the newly modified indices, in text
        order."""
        attrs = self.attack_attrs
        # ``generate_new_attacked_text`` stores the parent under
        # "previous_attacked_text"; fall back to the legacy key if present.
        if "previous_attacked_text" in attrs:
            previous = attrs["previous_attacked_text"]
        else:
            previous = attrs["prev_attacked_text"]
        return [previous.words[i] for i in sorted(attrs["newly_modified_indices"])]

    def printable_text(self, key_color="bold", key_color_method=None) -> str:
        """Represents full text input. Adds field descriptions.

        For example, entailment inputs look like:
            ```
            premise: ...
            hypothesis: ...
            ```
        """
        # For single-sequence inputs, don't show a prefix.
        if len(self._text_input) == 1:
            return next(iter(self._text_input.values()))
        # For multiple-sequence inputs, show a prefix and a colon. Optionally,
        # color the key.
        else:
            if key_color_method:

                def ck(k):
                    return textattack.shared.utils.color_text(
                        k, key_color, key_color_method
                    )

            else:

                def ck(k):
                    return k

            return "\n".join(
                f"{ck(key.capitalize())}: {value}"
                for key, value in self._text_input.items()
            )

    def __repr__(self) -> str:
        return f'<AttackedText "{self.text}">'