"""
"Learning To Write" Language Model
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
"""
import torch
import textattack
from textattack.constraints.grammaticality.language_models import (
LanguageModelConstraint,
)
from .language_model_helpers import QueryHandler
class LearningToWriteLanguageModel(LanguageModelConstraint):
    """A constraint based on the L2W language model.

    Wraps the RNN-based language model from "Learning to Write With
    Cooperative Discriminators" (Holtzman et al, 2018).

        https://arxiv.org/pdf/1805.06087.pdf
        https://github.com/windweller/l2w

    Jia et al., 2019 reused it as a substitution for the Google
    1-billion words language model (in a revised version of the attack
    of Alzantot et al., 2018).

        https://worksheets.codalab.org/worksheets/0x79feda5f1998497db75422eca8fcd689

    Args:
        window_size (int): Window size handed to
            ``text_window_around_index`` when a query is built for each
            input text. Defaults to 5.
        **kwargs: Forwarded unchanged to ``LanguageModelConstraint``.
    """

    # Path passed to ``download_from_s3`` to obtain the model folder.
    CACHE_PATH = "constraints/grammaticality/language-models/learning-to-write"

    def __init__(self, window_size=5, **kwargs):
        self.window_size = window_size
        model_dir = textattack.shared.utils.download_from_s3(
            LearningToWriteLanguageModel.CACHE_PATH
        )
        self.query_handler = QueryHandler.load_model(
            model_dir, textattack.shared.utils.device
        )
        # The parent constructor runs last, after the model has been loaded.
        super().__init__(**kwargs)

    def get_log_probs_at_index(self, text_list, word_index):
        """Score the word at ``word_index`` of every text with the language
        model.

        For each text, the word at ``word_index`` is paired with the words
        of the text window around that index; all pairs are then sent to
        the query handler in a single call.

        Args:
            text_list: Iterable of text objects exposing ``words`` and
                ``text_window_around_index``.
            word_index (int): Index of the word to score in each text.

        Returns:
            torch.Tensor: Tensor built from the values returned by
            ``self.query_handler.query`` (log-probabilities, going by the
            method and variable names).
        """
        # One pass over the inputs; each tuple is evaluated left to right,
        # so the word lookup happens before the window is tokenized.
        scored_items = [
            (
                text.words[word_index],
                textattack.shared.utils.words_from_text(
                    text.text_window_around_index(word_index, self.window_size)
                ),
            )
            for text in text_list
        ]
        query_words = [word for word, _ in scored_items]
        queries = [tokens for _, tokens in scored_items]
        log_probs = self.query_handler.query(queries, query_words)
        return torch.tensor(log_probs)