"""
Infer sent model
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
This file contains the definition of encoders used in
https://arxiv.org/pdf/1705.02364.pdf.
"""
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
import time
import warnings

import numpy as np
import torch
from torch import nn

import textattack


class InferSentModel(nn.Module):
def __init__(self, config):
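        """Configure the BiLSTM sentence encoder from a dict of
        hyperparameters: ``bsize``, ``word_emb_dim``, ``enc_lstm_dim``,
        ``pool_type`` ("mean" or "max"), ``dpout_model``, and optionally
        ``version`` (1 or 2, defaults to 1)."""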
super().__init__()
self.bsize = config["bsize"]
self.word_emb_dim = config["word_emb_dim"]
self.enc_lstm_dim = config["enc_lstm_dim"]
self.pool_type = config["pool_type"]
self.dpout_model = config["dpout_model"]
self.version = 1 if "version" not in config else config["version"]
self.enc_lstm = nn.LSTM(
self.word_emb_dim,
self.enc_lstm_dim,
1,
bidirectional=True,
dropout=self.dpout_model,
)
assert self.version in [1, 2]
if self.version == 1:
self.bos = "<s>"
self.eos = "</s>"
self.max_pad = True
self.moses_tok = False
elif self.version == 2:
self.bos = "<p>"
self.eos = "</p>"
self.max_pad = False
self.moses_tok = True

    def is_cuda(self):
# either all weights are on cpu or they are on gpu
return self.enc_lstm.bias_hh_l0.data.is_cuda

    def forward(self, sent_tuple):
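        """Run the BiLSTM over a packed batch of word embeddings and pool the
        hidden states over time ("mean" or "max") to get one fixed-size
        embedding per sentence."""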
# sent_len: [max_len, ..., min_len] (bsize)
# sent: (seqlen x bsize x worddim)
sent, sent_len = sent_tuple
# Sort by length (keep idx)
sent_len_sorted, idx_sort = np.sort(sent_len)[::-1], np.argsort(-sent_len)
sent_len_sorted = sent_len_sorted.copy()
idx_unsort = np.argsort(idx_sort)
idx_sort = (
torch.from_numpy(idx_sort).to(textattack.shared.utils.device)
if self.is_cuda()
else torch.from_numpy(idx_sort)
)
sent = sent.index_select(1, idx_sort)
# Handling padding in Recurrent Networks
sent_packed = nn.utils.rnn.pack_padded_sequence(sent, sent_len_sorted)
sent_output = self.enc_lstm(sent_packed)[0] # seqlen x batch x 2*nhid
sent_output = nn.utils.rnn.pad_packed_sequence(sent_output)[0]
# Un-sort by length
idx_unsort = (
torch.from_numpy(idx_unsort).to(textattack.shared.utils.device)
if self.is_cuda()
else torch.from_numpy(idx_unsort)
)
sent_output = sent_output.index_select(1, idx_unsort)
# Pooling
if self.pool_type == "mean":
sent_len = (
torch.FloatTensor(sent_len.copy())
.unsqueeze(1)
.to(textattack.shared.utils.device)
)
            emb = torch.sum(sent_output, 0)  # sum over time: (bsize, 2 * enc_lstm_dim)
emb = emb / sent_len.expand_as(emb)
elif self.pool_type == "max":
if not self.max_pad:
sent_output[sent_output == 0] = -1e9
emb = torch.max(sent_output, 0)[0]
if emb.ndimension() == 3:
emb = emb.squeeze(0)
assert emb.ndimension() == 2
return emb

    def set_w2v_path(self, w2v_path):
self.w2v_path = w2v_path

    def get_word_dict(self, sentences, tokenize=True):
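        """Collect the set of words appearing in ``sentences`` (plus the
        begin/end-of-sentence tokens) as a dict of empty strings."""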
# create vocab of words
word_dict = {}
sentences = [s.split() if not tokenize else self.tokenize(s) for s in sentences]
for sent in sentences:
for word in sent:
if word not in word_dict:
word_dict[word] = ""
word_dict[self.bos] = ""
word_dict[self.eos] = ""
return word_dict

    def get_w2v(self, word_dict):
assert hasattr(self, "w2v_path"), "w2v path not set"
# create word_vec with w2v vectors
word_vec = {}
with open(self.w2v_path, encoding="utf-8") as f:
for line in f:
word, vec = line.split(" ", 1)
if word in word_dict:
word_vec[word] = np.fromstring(vec, sep=" ")
print("Found %s(/%s) words with w2v vectors" % (len(word_vec), len(word_dict)))
return word_vec

    def get_w2v_k(self, K):
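        """Load vectors for (approximately) the first ``K`` words of the w2v
        file, then keep scanning until the begin/end-of-sentence tokens are
        also covered."""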
assert hasattr(self, "w2v_path"), "w2v path not set"
# create word_vec with k first w2v vectors
k = 0
word_vec = {}
with open(self.w2v_path, encoding="utf-8") as f:
for line in f:
word, vec = line.split(" ", 1)
if k <= K:
word_vec[word] = np.fromstring(vec, sep=" ")
k += 1
if k > K:
if word in [self.bos, self.eos]:
word_vec[word] = np.fromstring(vec, sep=" ")
if k > K and all([w in word_vec for w in [self.bos, self.eos]]):
break
return word_vec

    def build_vocab(self, sentences, tokenize=True):
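        """Build ``self.word_vec`` from every word in ``sentences`` that has a
        w2v vector."""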
assert hasattr(self, "w2v_path"), "w2v path not set"
word_dict = self.get_word_dict(sentences, tokenize)
self.word_vec = self.get_w2v(word_dict)
        # print('Vocab size : %s' % (len(self.word_vec)))

    # build w2v vocab with k most frequent words
    def build_vocab_k_words(self, K):
assert hasattr(self, "w2v_path"), "w2v path not set"
self.word_vec = self.get_w2v_k(K)
# print('Vocab size : %s' % (K))

    def update_vocab(self, sentences, tokenize=True):
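        """Extend an existing vocabulary with vectors for words in
        ``sentences`` that are not yet in ``self.word_vec``."""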
assert hasattr(self, "w2v_path"), "warning : w2v path not set"
assert hasattr(self, "word_vec"), "build_vocab before updating it"
word_dict = self.get_word_dict(sentences, tokenize)
# keep only new words
for word in self.word_vec:
if word in word_dict:
del word_dict[word]
        # update vocabulary
if word_dict:
new_word_vec = self.get_w2v(word_dict)
self.word_vec.update(new_word_vec)
else:
new_word_vec = []
print(
"New vocab size : %s (added %s words)"
% (len(self.word_vec), len(new_word_vec))
)

    def get_batch(self, batch):
        # sentences in batch must be sorted by decreasing length
        # returns a FloatTensor of shape (max_len, bsize, word_emb_dim)
embed = np.zeros((len(batch[0]), len(batch), self.word_emb_dim))
for i in range(len(batch)):
for j in range(len(batch[i])):
embed[j, i, :] = self.word_vec[batch[i][j]]
return torch.FloatTensor(embed)

    def tokenize(self, s):
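        """Tokenize a sentence with NLTK; for V2 models, post-process the
        tokens to approximate Moses tokenization."""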
from nltk.tokenize import word_tokenize
if self.moses_tok:
s = " ".join(word_tokenize(s))
s = s.replace(" n't ", "n 't ") # HACK to get ~MOSES tokenization
return s.split()
else:
return word_tokenize(s)

    def prepare_samples(self, sentences, bsize, tokenize, verbose):
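        """Tokenize sentences, wrap them in begin/end-of-sentence tokens, drop
        words without w2v vectors, and sort by decreasing length, returning
        the sort permutation so the caller can restore the original order."""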
sentences = [
(
[self.bos] + s.split() + [self.eos]
if not tokenize
else [self.bos] + self.tokenize(s) + [self.eos]
)
for s in sentences
]
n_w = np.sum([len(x) for x in sentences])
# filters words without w2v vectors
for i in range(len(sentences)):
s_f = [word for word in sentences[i] if word in self.word_vec]
if not s_f:
                warnings.warn(
                    'No words in "%s" (idx=%s) have w2v vectors. '
                    'Replacing with "</s>".' % (sentences[i], i)
                )
s_f = [self.eos]
sentences[i] = s_f
lengths = np.array([len(s) for s in sentences])
n_wk = np.sum(lengths)
if verbose:
            print(
                "Nb words kept : %s/%s (%.1f%%)" % (n_wk, n_w, 100.0 * n_wk / n_w)
            )
# sort by decreasing length
lengths, idx_sort = np.sort(lengths)[::-1], np.argsort(-lengths)
        # dtype=object since sentences are ragged lists of tokens
        sentences = np.array(sentences, dtype=object)[idx_sort]
return sentences, lengths, idx_sort

    def encode(self, sentences, bsize=64, tokenize=True, verbose=False):
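        """Encode ``sentences`` into a (n_sentences, 2 * enc_lstm_dim) numpy
        array of embeddings, batching by ``bsize`` and restoring the original
        sentence order."""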
tic = time.time()
sentences, lengths, idx_sort = self.prepare_samples(
sentences, bsize, tokenize, verbose
)
embeddings = []
for stidx in range(0, len(sentences), bsize):
batch = self.get_batch(sentences[stidx : stidx + bsize])
if self.is_cuda():
batch = batch.to(textattack.shared.utils.device)
with torch.no_grad():
batch = (
self.forward((batch, lengths[stidx : stidx + bsize]))
.data.cpu()
.numpy()
)
embeddings.append(batch)
embeddings = np.vstack(embeddings)
# unsort
idx_unsort = np.argsort(idx_sort)
embeddings = embeddings[idx_unsort]
if verbose:
print(
"Speed : %.1f sentences/s (%s mode, bsize=%s)"
% (
len(embeddings) / (time.time() - tic),
"gpu" if self.is_cuda() else "cpu",
bsize,
)
)
return embeddings
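

# A minimal usage sketch (not part of the original module), assuming the
# InferSent V1 hyperparameters from the paper; "infersent1.pkl" and
# "glove.840B.300d.txt" are hypothetical local paths to pretrained weights
# and to GloVe vectors in word2vec text format.
if __name__ == "__main__":
    config = {
        "bsize": 64,
        "word_emb_dim": 300,
        "enc_lstm_dim": 2048,
        "pool_type": "max",
        "dpout_model": 0.0,
        "version": 1,
    }
    model = InferSentModel(config)
    model.load_state_dict(torch.load("infersent1.pkl"))  # hypothetical path
    model.set_w2v_path("glove.840B.300d.txt")  # hypothetical path
    model.build_vocab_k_words(K=100000)  # keep the 100k most frequent words
    emb = model.encode(["A man is playing a guitar."], tokenize=True)
    print(emb.shape)  # (1, 2 * enc_lstm_dim) == (1, 4096)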