Source code for textattack.transformations.word_swaps.word_swap_inflections

Word Swap by inflections


import random

import lemminflect

from .word_swap import WordSwap

class WordSwapInflections(WordSwap):
    """Transforms an input by replacing its words with their inflections.

    For example, the inflections of 'schedule' are {'schedule', 'schedules',
    'scheduling'}.

    Based on ``It’s Morphin’ Time! Combating Linguistic Discrimination with
    Inflectional Perturbations``.
    """

    def __init__(self, **kwargs):
        """Set up the POS-tag mapping used to query ``lemminflect``.

        Args:
            **kwargs: Forwarded unchanged to ``WordSwap.__init__``.
        """
        super().__init__(**kwargs)
        # Fine-grained en-ptb POS tags -> the universal POS tags that are
        # used below both as keys into ``lemminflect.getAllLemmas`` results
        # and as the ``upos`` argument of ``lemminflect.getAllInflections``.
        # Any tag missing from this table is never perturbed.
        # NOTE(review): "JJRJR" does not look like a Penn Treebank tag
        # (possibly a typo); it is kept because it is harmless.
        # NOTE(review): "RB", "TO", "SYM" and "MD" are mapped onto ADJ / VERB /
        # NOUN rather than their own universal categories -- confirm intent.
        self._enptb_to_universal = {
            "JJRJR": "ADJ",
            "VBN": "VERB",
            "VBP": "VERB",
            "JJ": "ADJ",
            "VBZ": "VERB",
            "VBG": "VERB",
            "NN": "NOUN",
            "VBD": "VERB",
            "NP": "NOUN",
            "NNP": "NOUN",
            "VB": "VERB",
            "NNS": "NOUN",
            "VP": "VERB",
            "TO": "VERB",
            "SYM": "NOUN",
            "MD": "VERB",
            "NNPS": "NOUN",
            "JJS": "ADJ",
            "JJR": "ADJ",
            "RB": "ADJ",
        }

    def _get_replacement_words(self, word, word_part_of_speech):
        """Return the inflected forms that may replace ``word``.

        Args:
            word: The word to inflect.
            word_part_of_speech: Fine-grained POS tag of ``word``. Tags that
                are not keys of ``self._enptb_to_universal`` produce no
                replacements.

        Returns:
            A sorted list of unique inflections of the chosen lemma, never
            containing ``word`` itself. The list is empty when the tag is
            unsupported, when ``lemminflect`` knows no lemma for ``word``, or
            when no inflection differs from ``word``.
        """
        # Only tags present in the mapping are considered for replacement.
        lemminflect_pos = self._enptb_to_universal.get(word_part_of_speech)
        if lemminflect_pos is None:
            return []

        # Dict mapping a universal POS tag to the lemmas available for it.
        lemmas_by_pos = lemminflect.getAllLemmas(word)
        if not lemmas_by_pos:
            return []

        # Prefer a lemma with the same POS as the word; otherwise fall back
        # to the first lemma of a randomly chosen POS.
        if lemminflect_pos in lemmas_by_pos:
            lemma = lemmas_by_pos[lemminflect_pos][0]
        else:
            lemma = random.choice(list(lemmas_by_pos.values()))[0]

        inflections = lemminflect.getAllInflections(lemma, upos=lemminflect_pos)

        # Flatten the per-tag tuples, drop duplicates and the original word.
        # Sorting fixes the order: plain set iteration over strings depends
        # on the per-process hash seed, which made the output order differ
        # between runs even with a seeded ``random``.
        return sorted(
            {
                form
                for forms in inflections.values()
                for form in forms
                if form != word
            }
        )

    def _get_transformations(self, current_text, indices_to_modify):
        """Build one transformed text per (index, replacement word) pair.

        Args:
            current_text: Text object exposing ``words``,
                ``pos_of_word_index(i)`` and ``replace_word_at_index(i, w)``.
            indices_to_modify: Iterable of word indices that may be swapped.

        Returns:
            A list with the results of ``current_text.replace_word_at_index``
            for every candidate replacement of every index, in iteration
            order of ``indices_to_modify``.
        """
        transformed_texts = []
        for i in indices_to_modify:
            word_to_replace = current_text.words[i]
            word_to_replace_pos = current_text.pos_of_word_index(i)
            # ``or []`` tolerates an override that returns None.
            replacement_words = (
                self._get_replacement_words(word_to_replace, word_to_replace_pos)
                or []
            )
            transformed_texts.extend(
                current_text.replace_word_at_index(i, replacement)
                for replacement in replacement_words
            )
        return transformed_texts