# Source code for textattack.constraints.semantics.sentence_encoders.infer_sent.infer_sent_model

"""
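InferSent model
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

This file contains the definition of the sentence encoder introduced in
"Supervised Learning of Universal Sentence Representations from Natural
Language Inference Data" (Conneau et al., 2017),
https://arxiv.org/pdf/1705.02364.pdf.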

"""


# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#

import time

import numpy as np
import torch
from torch import nn as nn

import textattack


class InferSentModel(nn.Module):
    """Bidirectional-LSTM sentence encoder (InferSent).

    Each word is looked up in a word-vector table (``self.word_vec``, built
    from the text file given to :meth:`set_w2v_path`). A single-layer BiLSTM
    runs over the sequence, and its hidden states are pooled (``"mean"`` or
    ``"max"``) into one vector of size ``2 * enc_lstm_dim`` per sentence.

    Args:
        config (dict): hyper-parameters. Required keys: ``"bsize"``,
            ``"word_emb_dim"``, ``"enc_lstm_dim"``, ``"pool_type"`` and
            ``"dpout_model"``. The optional key ``"version"`` (1 or 2,
            default 1) selects the sentence-boundary tokens, the max-pooling
            padding behaviour and the tokenizer flavour.

    Raises:
        ValueError: if ``config["version"]`` is not 1 or 2.
    """

    def __init__(self, config):
        super().__init__()
        self.bsize = config["bsize"]
        self.word_emb_dim = config["word_emb_dim"]
        self.enc_lstm_dim = config["enc_lstm_dim"]
        self.pool_type = config["pool_type"]
        self.dpout_model = config["dpout_model"]
        self.version = config["version"] if "version" in config else 1

        # NOTE: PyTorch applies LSTM dropout only between stacked layers, so
        # with a single layer ``dpout_model`` has no effect (and PyTorch warns
        # when it is non-zero). It is kept to preserve the original setup.
        self.enc_lstm = nn.LSTM(
            self.word_emb_dim,
            self.enc_lstm_dim,
            1,
            bidirectional=True,
            dropout=self.dpout_model,
        )

        if self.version == 1:
            self.bos = "<s>"
            self.eos = "</s>"
            self.max_pad = True
            self.moses_tok = False
        elif self.version == 2:
            self.bos = "<p>"
            self.eos = "</p>"
            self.max_pad = False
            self.moses_tok = True
        else:
            # An explicit error (not ``assert``) so validation survives ``-O``.
            raise ValueError(
                "InferSent version must be 1 or 2, got %r" % (self.version,)
            )

    def is_cuda(self):
        """Return True if the model's weights are on a CUDA device."""
        # Either all weights are on the CPU or all are on the GPU, so checking
        # one parameter is enough.
        return self.enc_lstm.bias_hh_l0.data.is_cuda

    def forward(self, sent_tuple):
        """Encode a padded batch of word embeddings into sentence vectors.

        Args:
            sent_tuple: pair ``(sent, sent_len)``. ``sent`` is a float tensor
                of shape ``(max_len, bsize, word_emb_dim)``. ``sent_len`` holds
                the true length (>= 1) of each sentence, as an int array or
                list in any order.

        Returns:
            torch.Tensor: ``(bsize, 2 * enc_lstm_dim)`` sentence embeddings, in
            the same order as the input batch.

        Raises:
            ValueError: if ``pool_type`` is neither ``"mean"`` nor ``"max"``.
        """
        sent, sent_len = sent_tuple
        sent_len = np.asarray(sent_len)

        # pack_padded_sequence needs the batch sorted by decreasing length;
        # keep the permutation so the output can be put back in input order.
        idx_sort = np.argsort(-sent_len)
        idx_unsort = np.argsort(idx_sort)
        sent_len_sorted = np.sort(sent_len)[::-1].copy()  # copy: no negative strides

        # Index tensors must live on the same device as the tensor they index.
        sent = sent.index_select(1, torch.from_numpy(idx_sort).to(sent.device))

        # Packing makes the LSTM skip padded time steps.
        sent_packed = nn.utils.rnn.pack_padded_sequence(sent, sent_len_sorted)
        sent_output = self.enc_lstm(sent_packed)[0]
        # (max_len, bsize, 2 * enc_lstm_dim), zero beyond each sentence's length
        sent_output = nn.utils.rnn.pad_packed_sequence(sent_output)[0]

        # Undo the length sort.
        sent_output = sent_output.index_select(
            1, torch.from_numpy(idx_unsort).to(sent_output.device)
        )

        if self.pool_type == "mean":
            lengths = torch.tensor(
                sent_len.copy(), dtype=torch.float32, device=sent_output.device
            ).unsqueeze(1)
            # Padded positions are zero, so they do not affect the sum. No
            # squeeze here: it would drop the batch axis when bsize == 1.
            emb = torch.sum(sent_output, 0) / lengths
        elif self.pool_type == "max":
            if not self.max_pad:
                # Stop zero padding from winning the max.
                sent_output[sent_output == 0] = -1e9
            emb = torch.max(sent_output, 0)[0]
            if emb.ndimension() == 3:
                emb = emb.squeeze(0)
        else:
            raise ValueError(
                "pool_type must be 'mean' or 'max', got %r" % (self.pool_type,)
            )
        assert emb.ndimension() == 2
        return emb

    def set_w2v_path(self, w2v_path):
        """Set the path of the word-vector text file used to build the vocab."""
        self.w2v_path = w2v_path

    def get_word_dict(self, sentences, tokenize=True):
        """Collect the vocabulary of ``sentences``.

        Args:
            sentences: iterable of raw sentence strings.
            tokenize: use :meth:`tokenize` if True, else ``str.split``.

        Returns:
            dict: every word (plus the BOS/EOS tokens) mapped to ``""``. Only
            the keys are meaningful.
        """
        word_dict = {}
        for s in sentences:
            for word in self.tokenize(s) if tokenize else s.split():
                word_dict[word] = ""
        word_dict[self.bos] = ""
        word_dict[self.eos] = ""
        return word_dict

    def get_w2v(self, word_dict):
        """Load vectors for the words in ``word_dict`` from the w2v file.

        Each line of the file is expected to be ``word v1 v2 ...`` separated by
        single spaces. Words absent from the file are simply missing from the
        result.

        Returns:
            dict: word -> ``np.ndarray`` vector.
        """
        assert hasattr(self, "w2v_path"), "w2v path not set"
        word_vec = {}
        with open(self.w2v_path, encoding="utf-8") as f:
            for line in f:
                word, vec = line.split(" ", 1)
                if word in word_dict:
                    word_vec[word] = np.fromstring(vec, sep=" ")
        print("Found %s(/%s) words with w2v vectors" % (len(word_vec), len(word_dict)))
        return word_vec

    def get_w2v_k(self, K):
        """Load the first ``K`` vectors of the w2v file plus the BOS/EOS tokens.

        Reading continues past the first ``K`` lines only until both boundary
        tokens have been found.

        Returns:
            dict: word -> ``np.ndarray`` vector.
        """
        assert hasattr(self, "w2v_path"), "w2v path not set"
        boundary = (self.bos, self.eos)
        loaded = 0
        word_vec = {}
        with open(self.w2v_path, encoding="utf-8") as f:
            for line in f:
                word, vec = line.split(" ", 1)
                if loaded < K:  # exactly K leading words (was K + 1)
                    word_vec[word] = np.fromstring(vec, sep=" ")
                    loaded += 1
                elif word in boundary:
                    word_vec[word] = np.fromstring(vec, sep=" ")
                if loaded >= K and all(w in word_vec for w in boundary):
                    break
        return word_vec

    def build_vocab(self, sentences, tokenize=True):
        """Set ``self.word_vec`` to the w2v vectors of the words in ``sentences``."""
        assert hasattr(self, "w2v_path"), "w2v path not set"
        word_dict = self.get_word_dict(sentences, tokenize)
        self.word_vec = self.get_w2v(word_dict)

    def build_vocab_k_words(self, K):
        """Set ``self.word_vec`` to the first ``K`` w2v vectors (+ BOS/EOS)."""
        assert hasattr(self, "w2v_path"), "w2v path not set"
        self.word_vec = self.get_w2v_k(K)

    def update_vocab(self, sentences, tokenize=True):
        """Add w2v vectors for words of ``sentences`` not already in the vocab.

        Prints the new vocabulary size and how many words were added.
        """
        assert hasattr(self, "w2v_path"), "warning : w2v path not set"
        assert hasattr(self, "word_vec"), "build_vocab before updating it"
        word_dict = self.get_word_dict(sentences, tokenize)
        # Keep only new words. Scan the (small) new dict, not the whole vocab.
        word_dict = {w: v for w, v in word_dict.items() if w not in self.word_vec}
        if word_dict:
            new_word_vec = self.get_w2v(word_dict)
            self.word_vec.update(new_word_vec)
        else:
            new_word_vec = {}
        print(
            "New vocab size : %s (added %s words)"
            % (len(self.word_vec), len(new_word_vec))
        )

    def get_batch(self, batch):
        """Stack the word vectors of a batch of token lists.

        Args:
            batch: sequence of token lists. Every token must be a key of
                ``self.word_vec``. Sentences may come in any length order.

        Returns:
            torch.FloatTensor: ``(max_len, len(batch), word_emb_dim)``,
            zero-padded beyond each sentence's length.
        """
        max_len = max((len(sent) for sent in batch), default=0)
        embed = np.zeros((max_len, len(batch), self.word_emb_dim))
        for i, sent in enumerate(batch):
            for j, word in enumerate(sent):
                embed[j, i, :] = self.word_vec[word]
        return torch.FloatTensor(embed)

    def tokenize(self, s):
        """Tokenize ``s`` with NLTK's ``word_tokenize``.

        With ``moses_tok`` (version 2), ``" n't "`` is rewritten to
        ``"n 't "`` to roughly approximate Moses tokenization.
        """
        from nltk.tokenize import word_tokenize

        if self.moses_tok:
            s = " ".join(word_tokenize(s))
            s = s.replace(" n't ", "n 't ")  # HACK to get ~MOSES tokenization
            return s.split()
        return word_tokenize(s)

    def prepare_samples(self, sentences, bsize, tokenize, verbose):
        """Tokenize, filter and length-sort sentences for :meth:`encode`.

        BOS/EOS are added and words without a w2v vector are dropped. A
        sentence left empty becomes ``[self.eos]`` (with a warning).

        Args:
            sentences: iterable of raw sentence strings.
            bsize: unused; kept for interface compatibility.
            tokenize: use :meth:`tokenize` if True, else ``str.split``.
            verbose: print how many words were kept.

        Returns:
            tuple: ``(sentences, lengths, idx_sort)``. ``sentences`` is a list
            of token lists sorted by decreasing length, ``lengths`` holds their
            lengths, and ``idx_sort`` is the permutation applied to the input
            order.
        """
        import warnings

        sentences = [
            [self.bos] + (self.tokenize(s) if tokenize else s.split()) + [self.eos]
            for s in sentences
        ]
        n_w = np.sum([len(x) for x in sentences])

        # Filter out words without w2v vectors.
        filtered = []
        for i, sent in enumerate(sentences):
            s_f = [word for word in sent if word in self.word_vec]
            if not s_f:
                warnings.warn(
                    'No words in "%s" (idx=%s) have w2v vectors. Replacing by "%s"..'
                    % (sent, i, self.eos)
                )
                s_f = [self.eos]
            filtered.append(s_f)
        sentences = filtered

        lengths = np.array([len(s) for s in sentences])
        n_wk = np.sum(lengths)
        if verbose:
            print(
                "Nb words kept : %s/%s (%.1f%s)" % (n_wk, n_w, 100.0 * n_wk / n_w, "%")
            )

        # Sort by decreasing length. Reorder the Python list directly:
        # np.array() on ragged token lists raises ValueError on NumPy >= 1.24.
        lengths, idx_sort = np.sort(lengths)[::-1], np.argsort(-lengths)
        sentences = [sentences[i] for i in idx_sort]
        return sentences, lengths, idx_sort

    def encode(self, sentences, bsize=64, tokenize=True, verbose=False):
        """Embed raw sentences.

        Args:
            sentences: iterable of sentence strings.
            bsize: number of sentences per forward pass.
            tokenize: use :meth:`tokenize` if True, else ``str.split``.
            verbose: print word-coverage and throughput statistics.

        Returns:
            np.ndarray: ``(len(sentences), 2 * enc_lstm_dim)`` embeddings in
            input order. An empty input yields a ``(0, 2 * enc_lstm_dim)``
            array.
        """
        tic = time.time()
        sentences = list(sentences)
        if not sentences:
            return np.zeros((0, 2 * self.enc_lstm_dim), dtype=np.float32)

        sentences, lengths, idx_sort = self.prepare_samples(
            sentences, bsize, tokenize, verbose
        )
        # Feed batches to whatever device the weights are on.
        device = self.enc_lstm.bias_hh_l0.device

        embeddings = []
        for stidx in range(0, len(sentences), bsize):
            batch = self.get_batch(sentences[stidx : stidx + bsize]).to(device)
            with torch.no_grad():
                batch_emb = self.forward((batch, lengths[stidx : stidx + bsize]))
            embeddings.append(batch_emb.cpu().numpy())
        embeddings = np.vstack(embeddings)

        # Restore the caller's sentence order.
        embeddings = embeddings[np.argsort(idx_sort)]

        if verbose:
            print(
                "Speed : %.1f sentences/s (%s mode, bsize=%s)"
                % (
                    len(embeddings) / (time.time() - tic),
                    "gpu" if self.is_cuda() else "cpu",
                    bsize,
                )
            )
        return embeddings