# Source code for textattack.constraints.grammaticality.language_models.google_language_model.alzantot_goog_lm


"""Google Language Models from Alzantot.

    Author: Moustafa Alzantot
    All rights reserved.
"""

import os

import lru
import numpy as np

from textattack.shared import utils

from . import lm_data_utils, lm_utils

# TensorFlow handle obtained through the project's ``utils.LazyLoader`` helper.
# Presumably this defers the real ``import tensorflow`` until ``tf`` is first
# used, so importing this module stays cheap; confirm against utils.LazyLoader.
tf = utils.LazyLoader("tensorflow", globals(), "tensorflow")

# @TODO automatically choose between GPU and CPU.

class GoogLMHelper:
    """Query next-word probabilities from Alzantot's Google language model.

    On construction the model folder (graph ``.pbtxt``, ``ckpt-*`` checkpoint
    files and vocabulary) is fetched with ``utils.download_from_s3`` and loaded
    into a ``tf.compat.v1.Session`` through ``lm_utils.LoadModel``.
    Probabilities are memoised in an LRU cache (``2**18`` entries) keyed by
    ``(prefix, word)``.
    """

    CACHE_PATH = "constraints/semantics/language-models/alzantot-goog-lm"

    def __init__(self):
        tf.get_logger().setLevel("INFO")
        lm_folder = utils.download_from_s3(GoogLMHelper.CACHE_PATH)
        self.PBTXT_PATH = os.path.join(lm_folder, "graph-2016-09-10-gpu.pbtxt")
        self.CKPT_PATH = os.path.join(lm_folder, "ckpt-*")
        self.VOCAB_PATH = os.path.join(lm_folder, "vocab-2016-09-10.txt")

        # The model is queried with a single sequence, one timestep at a time.
        self.BATCH_SIZE = 1
        self.NUM_TIMESTEPS = 1
        self.MAX_WORD_LEN = 50

        self.vocab = lm_data_utils.CharsVocabulary(self.VOCAB_PATH, self.MAX_WORD_LEN)
        # NOTE(review): the device string is hard-coded; confirm whether this
        # scope actually influences placement of the separately created graph.
        with tf.device("/gpu:1"):
            self.graph = tf.Graph()
            self.sess = tf.compat.v1.Session(graph=self.graph)
        with self.graph.as_default():
            # Mapping of tensor names ("softmax_out", "inputs_in", ...) used below.
            self.t = lm_utils.LoadModel(
                self.sess, self.graph, self.PBTXT_PATH, self.CKPT_PATH
            )

        self.lm_cache = lru.LRU(2**18)

    def clear_cache(self):
        """Drop every memoised ``(prefix, word)`` probability."""
        self.lm_cache.clear()

    def get_words_probs_uncached(self, prefix_words, list_words):
        """Run the model once and return the probability of each candidate word.

        Args:
            prefix_words (str): whitespace-separated context. ``"<S> "`` is
                prepended unless the string already starts with ``"<S>"``.
            list_words (list[str]): candidate next words.

        Returns:
            numpy.ndarray: one probability per entry of ``list_words``, in order,
            read from row 0 of the model's ``softmax_out`` tensor.
        """
        if not prefix_words.startswith("<S>"):
            prefix_words = "<S> " + prefix_words
        # Only the final token of the prefix is fed to the graph.
        # NOTE(review): earlier tokens are not fed here; presumably any context
        # is carried by state inside the loaded graph — confirm.
        last_word = prefix_words.split()[-1]

        shape = [self.BATCH_SIZE, self.NUM_TIMESTEPS]
        # targets/weights are fed because the graph declares these inputs;
        # their values are presumably irrelevant to softmax_out — confirm.
        targets = np.zeros(shape, np.int32)
        weights = np.ones(shape, np.float32)
        inputs = np.zeros(shape, np.int32)
        char_ids_inputs = np.zeros(shape + [self.vocab.max_word_length], np.int32)
        inputs[0, 0] = self.vocab.word_to_id(last_word)
        char_ids_inputs[0, 0, :] = self.vocab.word_to_char_ids(last_word)

        softmax = self.sess.run(
            self.t["softmax_out"],
            feed_dict={
                self.t["char_inputs_in"]: char_ids_inputs,
                self.t["inputs_in"]: inputs,
                self.t["targets_in"]: targets,
                self.t["target_weights_in"]: weights,
            },
        )
        words_ids = [self.vocab.word_to_id(w) for w in list_words]
        return np.array([softmax[0][w_id] for w_id in words_ids])

    def get_words_probs(self, prefix, list_words):
        """Retrieve the probability of each word following ``prefix``, with caching.

        Args:
            prefix (str): context string, used verbatim as part of the cache key
                and passed to :meth:`get_words_probs_uncached` on a miss.
            list_words (list[str]): candidate next words; duplicates are allowed.

        Returns:
            list: probabilities aligned one-to-one with ``list_words``.
        """
        # Results are gathered locally: re-reading the LRU after inserting new
        # entries could fail, because an insert may evict keys seen moments ago.
        probs_by_word = {}
        uncached_words = []
        for word in dict.fromkeys(list_words):  # unique words, first-seen order
            key = (prefix, word)
            if key in self.lm_cache:
                probs_by_word[word] = self.lm_cache[key]
            else:
                uncached_words.append(word)

        # Only run the model when there is something it has not answered yet.
        if uncached_words:
            probs = self.get_words_probs_uncached(prefix, uncached_words)
            for word, prob in zip(uncached_words, probs):
                self.lm_cache[prefix, word] = prob
                probs_by_word[word] = prob

        return [probs_by_word[word] for word in list_words]

    def __getstate__(self):
        """Pickle support: replace the LRU cache by its capacity (entries dropped)."""
        # NOTE(review): the copied state still contains self.sess / self.graph /
        # self.t; confirm these survive pickling in the intended use.
        state = self.__dict__.copy()
        state["lm_cache"] = self.lm_cache.get_size()
        return state

    def __setstate__(self, state):
        """Pickle support: restore attributes and rebuild an empty LRU cache."""
        self.__dict__ = state
        self.lm_cache = lru.LRU(state["lm_cache"])