Source code for textattack.constraints.semantics.sentence_encoders.infer_sent.infer_sent

"""
InferSent for sentence similarity
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
"""

import os

import torch

from textattack.constraints.semantics.sentence_encoders import SentenceEncoder
from textattack.shared import utils

from .infer_sent_model import InferSentModel


class InferSent(SentenceEncoder):
    """Constraint using similarity between sentence encodings of x and x_adv
    where the text embeddings are created using InferSent."""

    # Locations handed to ``utils.download_from_s3`` for the encoder weights
    # and for the word-embedding folder, respectively.
    MODEL_PATH = "constraints/semantics/sentence-encoders/infersent-encoder"
    WORD_EMBEDDING_PATH = "word_embeddings"

    def __init__(self, *args, **kwargs):
        """Initialise the parent encoder, then load the InferSent model.

        All positional and keyword arguments are forwarded unchanged to
        ``SentenceEncoder.__init__``. The loaded model is moved to
        ``utils.device``.
        """
        super().__init__(*args, **kwargs)
        self.model = self.get_infersent_model()
        self.model.to(utils.device)

    def get_infersent_model(self):
        """Retrieves the InferSent model.

        Fetches the encoder weights and the word-embedding folder through
        ``utils.download_from_s3``, restores the weights into a fresh
        ``InferSentModel`` and builds its vocabulary.

        Returns:
            The pretrained InferSent model.
        """
        infersent_version = 2
        model_folder_path = utils.download_from_s3(InferSent.MODEL_PATH)
        model_path = os.path.join(
            model_folder_path, f"infersent{infersent_version}.pkl"
        )
        params_model = {
            "bsize": 64,
            # Matches the 300-dimensional vectors in crawl-300d-2M.vec below.
            "word_emb_dim": 300,
            "enc_lstm_dim": 2048,
            "pool_type": "max",
            "dpout_model": 0.0,
            "version": infersent_version,
        }
        infersent = InferSentModel(params_model)
        # Remap storages onto the device this process actually uses. Without
        # ``map_location`` a checkpoint serialized from CUDA tensors cannot be
        # deserialized on a machine that has no GPU.
        state_dict = torch.load(model_path, map_location=utils.device)
        infersent.load_state_dict(state_dict)

        word_embedding_path = utils.download_from_s3(InferSent.WORD_EMBEDDING_PATH)
        w2v_path = os.path.join(word_embedding_path, "fastText", "crawl-300d-2M.vec")
        infersent.set_w2v_path(w2v_path)
        # Build the vocabulary via the model's own helper, with K=100000.
        infersent.build_vocab_k_words(K=100000)
        return infersent

    def encode(self, sentences):
        """Encode ``sentences`` with the loaded InferSent model.

        Args:
            sentences: The sentences to encode; passed straight through to
                ``self.model.encode``.

        Returns:
            Whatever ``self.model.encode(sentences, tokenize=True)`` returns.
        """
        return self.model.encode(sentences, tokenize=True)