"""
Google 1-Billion Words Language Model
--------------------------------------
"""
from collections import defaultdict
import numpy as np
from textattack.constraints import Constraint
from textattack.transformations import WordSwap
from .alzantot_goog_lm import GoogLMHelper
[docs]class GoogleLanguageModel(Constraint):
"""Constraint that uses the Google 1 Billion Words Language Model to
determine the difference in perplexity between x and x_adv.
Args:
top_n (int):
top_n_per_index (int):
compare_against_original (bool): If `True`, compare new `x_adv` against the original `x`.
Otherwise, compare it against the previous `x_adv`.
"""
def __init__(self, top_n=None, top_n_per_index=None, compare_against_original=True):
if not (top_n or top_n_per_index):
raise ValueError(
"Cannot instantiate GoogleLanguageModel without top_n or top_n_per_index"
)
self.lm = GoogLMHelper()
self.top_n = top_n
self.top_n_per_index = top_n_per_index
super().__init__(compare_against_original)
[docs] def check_compatibility(self, transformation):
return isinstance(transformation, WordSwap)
def _check_constraint_many(self, transformed_texts, reference_text):
"""Returns the `top_n` of transformed_texts, as evaluated by the
language model."""
if not len(transformed_texts):
return []
def get_probs(reference_text, transformed_texts):
word_swap_index = reference_text.first_word_diff_index(transformed_texts[0])
if word_swap_index is None:
return []
prefix = reference_text.words[word_swap_index - 1]
swapped_words = np.array(
[t.words[word_swap_index] for t in transformed_texts]
)
probs = self.lm.get_words_probs(prefix, swapped_words)
return probs
# This creates a dictionary where each new key is initialized to [].
word_swap_index_map = defaultdict(list)
for idx, transformed_text in enumerate(transformed_texts):
word_swap_index = reference_text.first_word_diff_index(transformed_text)
word_swap_index_map[word_swap_index].append((idx, transformed_text))
probs = []
for word_swap_index, item_list in word_swap_index_map.items():
# zip(*some_list) is the inverse operator of zip!
item_indices, this_transformed_texts = zip(*item_list)
# t1 = time.time()
probs_of_swaps_at_index = list(
zip(item_indices, get_probs(reference_text, this_transformed_texts))
)
# Sort by probability in descending order and take the top n for this index.
probs_of_swaps_at_index.sort(key=lambda x: -x[1])
if self.top_n_per_index:
probs_of_swaps_at_index = probs_of_swaps_at_index[
: self.top_n_per_index
]
probs.extend(probs_of_swaps_at_index)
# t2 = time.time()
# Probs is a list of (index, prob) where index is the corresponding
# position in transformed_texts.
probs.sort(key=lambda x: x[0])
# Now that they're in order, reduce to just a list of probabilities.
probs = np.array(list(map(lambda x: x[1], probs)))
# Get the indices of the maximum elements.
max_el_indices = np.argsort(-probs)
if self.top_n:
max_el_indices = max_el_indices[: self.top_n]
# Put indices in order, now, so that the examples are returned in the
# same order they were passed in.
max_el_indices.sort()
return [transformed_texts[i] for i in max_el_indices]
def _check_constraint(self, transformed_text, reference_text):
return self._check_constraint_many([transformed_text], reference_text)