Source code for textattack.constraints.grammaticality.language_models.learning_to_write.language_model_helpers

Language model helpers

import os

import numpy as np
import torch

from textattack.shared.utils import LazyLoader

from .rnn_model import RNNModel

# torchfile is bound through LazyLoader rather than a plain ``import``; in this
# file it is only used by QueryHandler.load_model to read the ".th7" files.
# NOTE(review): presumably LazyLoader defers the real import until the module is
# first used, so torchfile is only required when a model is loaded -- confirm.
torchfile = LazyLoader("torchfile", globals(), "torchfile")

class QueryHandler:
    """Score tokenized sentences with a pre-trained RNN language model.

    Attributes:
        model: Object called as ``model(source, hidden)`` that returns
            ``(decode, hidden)`` and exposes ``init_hidden(batch_size)``.
        word_to_idx: Mapping from word to raw vocabulary index. It must
            contain the start token ``"<S>"``.
        mapto: Tensor indexed with raw vocabulary indices to obtain the
            indices fed to the model.
        device: Device on which the index tensors are created.
    """

    def __init__(self, model, word_to_idx, mapto, device):
        self.model = model
        self.word_to_idx = word_to_idx
        self.mapto = mapto
        self.device = device

    def query(self, sentences, swapped_words, batch_size=32):
        """Return one log-probability per sentence, batching when possible.

        Since we don't filter prefixes for OOV ahead of time, it's possible
        that some of them will have different lengths. When this is the case,
        we can't do RNN prediction in batch.

        This method _tries_ to do prediction in batch, and, when it fails,
        just does prediction sequentially and concatenates all of the results.

        Args:
            sentences: Sequence of sentences, each a list of words.
            swapped_words: One word per sentence; see :meth:`try_query`.
            batch_size: Maximum number of sentences per model call.

        Returns:
            A list of floats, one per sentence. A sentence whose individual
            query raises ``RuntimeError`` is scored ``float("-inf")``.
        """
        try:
            return self.try_query(sentences, swapped_words, batch_size=batch_size)
        except Exception:
            # Deliberate best-effort: any batch failure (typically ragged
            # lengths after OOV filtering) falls back to one-by-one queries.
            probs = []
            for s, w in zip(sentences, swapped_words):
                try:
                    probs.append(self.try_query([s], [w], batch_size=1)[0])
                except RuntimeError:
                    print(
                        "WARNING:  got runtime error trying languag emodel on language model w s/w",
                        s,
                        w,
                    )
                    probs.append(float("-inf"))
            return probs

    def try_query(self, sentences, swapped_words, batch_size=32):
        """Score ``sentences`` in batches of at most ``batch_size``.

        Words missing from ``word_to_idx`` are dropped from each sentence
        before scoring. If that leaves sentences of different lengths inside
        one batch, building the index tensor fails and the error propagates
        (``query`` relies on this to fall back to sequential scoring).

        Args:
            sentences: Sequence of sentences, each a list of words. All
                sentences must have the same length.
            swapped_words: One word per sentence. If the word is not in
                ``word_to_idx`` the sentence is scored ``float("-inf")``.
            batch_size: Maximum number of sentences per model call.

        Returns:
            A list with exactly one float per input sentence: the sum of the
            model outputs selected by each target token, or ``float("-inf")``.

        Raises:
            ValueError: If the sentences do not all have the same length.
        """
        # TODO use caching
        if not sentences:
            # Nothing to score; also avoids indexing sentences[0] below.
            return []

        sentence_length = len(sentences[0])
        if any(len(s) != sentence_length for s in sentences):
            raise ValueError("Only same length batches are allowed")

        log_probs = []
        for start in range(0, len(sentences), batch_size):
            stop = min(len(sentences), start + batch_size)
            swapped_words_batch = swapped_words[start:stop]
            batch = sentences[start:stop]

            # raw_idx_list[t] collects the raw index of token t of every
            # sentence in the batch (time-major layout); slot 0 is "<S>".
            raw_idx_list = [[] for _ in range(sentence_length + 1)]
            for s in batch:
                known_words = [word for word in s if word in self.word_to_idx]
                word_idxs = [self.word_to_idx[w] for w in ["<S>"] + known_words]
                for t, idx in enumerate(word_idxs):
                    raw_idx_list[t].append(idx)

            # Trailing time steps left empty by OOV filtering are removed.
            orig_num_idxs = len(raw_idx_list)
            raw_idx_list = [x for x in raw_idx_list if x]
            num_idxs_dropped = orig_num_idxs - len(raw_idx_list)
            num_steps = sentence_length - num_idxs_dropped

            # Raises ValueError when the rows are ragged.
            all_raw_idxs = torch.tensor(
                raw_idx_list, device=self.device, dtype=torch.long
            )
            word_idxs = self.mapto[all_raw_idxs]
            hidden = self.model.init_hidden(len(batch))
            source = word_idxs[:-1, :]
            target = word_idxs[1:, :]

            if (not len(source)) or not len(hidden):
                # Nothing to feed the model for this batch. Record -inf for
                # each of its sentences and keep going, so the result still
                # lines up one-to-one with ``sentences``.
                log_probs.extend([float("-inf")] * len(batch))
                continue

            # Inference only: no autograd graph is needed.
            with torch.no_grad():
                decode, hidden = self.model(source, hidden)
            decode = decode.view(num_steps, len(batch), -1)

            for i in range(len(batch)):
                if swapped_words_batch[i] not in self.word_to_idx:
                    log_probs.append(float("-inf"))
                else:
                    log_probs.append(
                        sum(
                            decode[t, i, target[t, i]].item()
                            for t in range(num_steps)
                        )
                    )
        return log_probs

    @staticmethod
    def load_model(lm_folder_path, device):
        """Build a :class:`QueryHandler` from the files in ``lm_folder_path``.

        Reads ``word_map.th7`` and ``word_freq.th7`` with ``torchfile`` and a
        state dict for the ``RNNModel``.

        Args:
            lm_folder_path: Folder containing the vocabulary files and the
                model state dict.
            device: Device used for ``map_location`` and for the ``mapto``
                tensor.

        Returns:
            A ``QueryHandler`` wrapping the loaded model in eval mode.
        """
        word_map = torchfile.load(os.path.join(lm_folder_path, "word_map.th7"))
        word_map = [w.decode("utf-8") for w in word_map]
        word_to_idx = {w: i for i, w in enumerate(word_map)}
        word_freq = torchfile.load(os.path.join(lm_folder_path, "word_freq.th7"))
        # argsort(-freq) lists raw indices by decreasing frequency; reversing
        # that permutation maps each raw index to its frequency rank.
        mapto = torch.from_numpy(util_reverse(np.argsort(-word_freq))).long().to(device)

        # NOTE(review): presumably these fixed hyper-parameters must match the
        # stored state dict -- confirm against RNNModel before changing them.
        model = RNNModel(
            "GRU",
            793471,
            256,
            2048,
            1,
            [4200, 35000, 180000, 793471],
            dropout=0.01,
            proj=True,
            lm1b=True,
        )

        # NOTE(review): the file name was an empty string here, which made
        # open() target the folder itself. "lm-state-dict.pt" is believed to
        # be the checkpoint name used upstream -- confirm against the model
        # folder contents.
        state_dict_path = os.path.join(lm_folder_path, "lm-state-dict.pt")
        # ``with`` closes the file even if loading raises.
        with open(state_dict_path, "rb") as model_file:
            model.load_state_dict(torch.load(model_file, map_location=device))

        model.full = True  # Use real softmax--important!
        model.eval()
        return QueryHandler(model, word_to_idx, mapto, device)
def util_reverse(item):
    """Return the inverse of an index permutation.

    For every position ``idx`` holding value ``val`` in ``item``, the result
    holds ``idx`` at position ``val``.

    Args:
        item: 1-D sequence or array of integer indices, each a valid index
            into an array of length ``len(item)``.

    Returns:
        A float64 NumPy array of length ``len(item)`` (the dtype produced by
        ``np.zeros``); callers convert it to an integer type themselves.
    """
    positions = np.asarray(item, dtype=np.intp)
    new_item = np.zeros(len(positions))
    # Single vectorized scatter instead of a Python-level loop per element.
    new_item[positions] = np.arange(len(positions))
    return new_item