Source code for textattack.constraints.grammaticality.language_models.learning_to_write.learning_to_write

"""
"Learning To Write" Language Model
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

"""

import torch

import textattack
from textattack.constraints.grammaticality.language_models import (
    LanguageModelConstraint,
)

from .language_model_helpers import QueryHandler


class LearningToWriteLanguageModel(LanguageModelConstraint):
    """A constraint based on the L2W language model.

    The RNN-based language model from "Learning to Write With Cooperative
    Discriminators" (Holtzman et al, 2018).

    https://arxiv.org/pdf/1805.06087.pdf

    https://github.com/windweller/l2w

    Reused by Jia et al., 2019, as a substitution for the Google 1-billion
    words language model (in a revised version of the attack of Alzantot et
    al., 2018).

    https://worksheets.codalab.org/worksheets/0x79feda5f1998497db75422eca8fcd689

    Args:
        window_size (int): Window size passed to
            ``AttackedText.text_window_around_index`` when building the
            context that is fed to the language model. Defaults to 5.
        **kwargs: Forwarded unchanged to ``LanguageModelConstraint.__init__``.
    """

    # Path handed to ``download_from_s3`` to locate the model files.
    CACHE_PATH = "constraints/grammaticality/language-models/learning-to-write"

    def __init__(self, window_size=5, **kwargs):
        self.window_size = window_size
        # Resolve the local folder holding the L2W model files.
        lm_folder_path = textattack.shared.utils.download_from_s3(
            LearningToWriteLanguageModel.CACHE_PATH
        )
        # Load the model onto the device configured by textattack.
        self.query_handler = QueryHandler.load_model(
            lm_folder_path, textattack.shared.utils.device
        )
        super().__init__(**kwargs)

    def get_log_probs_at_index(self, text_list, word_index):
        """Gets the log-probability of the word at index ``word_index``
        according to the language model, for each text in ``text_list``.

        For each text, the word at ``word_index`` is scored given a window
        of surrounding text. The window size is controlled by
        ``self.window_size``.

        Args:
            text_list: Iterable of ``AttackedText`` objects. It is iterated
                exactly once.
            word_index (int): Index into each text's ``words``.

        Returns:
            torch.Tensor: ``torch.tensor`` built from the scores returned by
            ``QueryHandler.query``. Presumably this has one log-probability
            per text, in input order; confirm against ``QueryHandler``.
        """
        queries = []
        query_words = []
        # Single pass so that one-shot iterables (e.g. generators) also work.
        for attacked_text in text_list:
            word = attacked_text.words[word_index]
            window_text = attacked_text.text_window_around_index(
                word_index, self.window_size
            )
            # The language model consumes the context as a list of words.
            queries.append(textattack.shared.utils.words_from_text(window_text))
            query_words.append(word)
        log_probs = self.query_handler.query(queries, query_words)
        return torch.tensor(log_probs)