Source code for textattack.constraints.semantics.sentence_encoders.thought_vector

"""
Thought Vector Class
---------------------
"""

import functools

import torch

from textattack.shared import AbstractWordEmbedding, WordEmbedding, utils

from .sentence_encoder import SentenceEncoder


class ThoughtVector(SentenceEncoder):
    """A constraint on the distance between two sentences' thought vectors.

    A sentence's thought vector is the element-wise mean of the embeddings of
    those of its words that the word embedding knows about.

    Args:
        embedding (textattack.shared.AbstractWordEmbedding): The word
            embedding to use. When ``None``, the embedding returned by
            ``WordEmbedding.counterfitted_GLOVE_embedding()`` is used.
        **kwargs: Forwarded unchanged to ``SentenceEncoder.__init__``.

    Raises:
        ValueError: If ``embedding`` is not an ``AbstractWordEmbedding``.
    """

    def __init__(self, embedding=None, **kwargs):
        if embedding is None:
            embedding = WordEmbedding.counterfitted_GLOVE_embedding()
        if not isinstance(embedding, AbstractWordEmbedding):
            raise ValueError(
                "`embedding` object must be of type `textattack.shared.AbstractWordEmbedding`."
            )
        self.word_embedding = embedding
        super().__init__(**kwargs)

    def clear_cache(self):
        """Discard every memoized thought vector.

        The LRU cache is attached to ``_get_thought_vector`` at class level,
        so this empties the entries of all ``ThoughtVector`` instances, not
        only those of ``self``.
        """
        self._get_thought_vector.cache_clear()

    # NOTE: the cache key is ``(self, text)``; the cache therefore requires
    # hashable instances and keeps a reference to each instance it has seen
    # until the entry is evicted or ``clear_cache`` is called.
    @functools.lru_cache(maxsize=2**10)
    def _get_thought_vector(self, text):
        """Average the embeddings of the words in ``text`` into a "thought
        vector".

        Args:
            text (str): The raw text to embed.

        Returns:
            torch.Tensor or None: The element-wise mean of the embeddings of
            all in-vocabulary words, or ``None`` when ``text`` contains no
            word with an embedding (taking the mean of nothing would
            otherwise produce a shapeless NaN scalar).
        """
        lookups = (self.word_embedding[word] for word in utils.words_from_text(text))
        # Out-of-vocab words do not have embeddings.
        embeddings = [vector for vector in lookups if vector is not None]
        if not embeddings:
            return None
        return torch.mean(torch.tensor(embeddings), dim=0)

    def encode(self, raw_text_list):
        """Encode each text in ``raw_text_list`` as its thought vector.

        Texts without any in-vocabulary word are encoded as a zero vector
        with the same shape and dtype as the other vectors of the batch, so
        that the batch can still be stacked.

        Args:
            raw_text_list (list[str]): The texts to encode.

        Returns:
            torch.Tensor: The thought vectors stacked along a new first
            dimension, one row per input text.

        Raises:
            ValueError: If ``raw_text_list`` is non-empty but none of its
                texts contains a word with a known embedding, in which case
                the vector shape cannot be determined.
        """
        vectors = [self._get_thought_vector(text) for text in raw_text_list]
        reference = next((vector for vector in vectors if vector is not None), None)
        if vectors and reference is None:
            raise ValueError(
                "Cannot compute thought vectors: none of the given texts contains "
                "a word with a known embedding."
            )
        # An empty ``raw_text_list`` falls through to ``torch.stack``, which
        # rejects an empty sequence exactly as it did before.
        return torch.stack(
            [
                torch.zeros_like(reference) if vector is None else vector
                for vector in vectors
            ]
        )

    def extra_repr_keys(self):
        """Set the extra representation of the constraint using these keys."""
        return ["word_embedding"] + super().extra_repr_keys()