"""
Source code for textattack.constraints.grammaticality.part_of_speech

Part of Speech Constraint
"""
import flair
from flair.data import Sentence
from flair.models import SequenceTagger
import lru
import nltk

import textattack
from textattack.constraints import Constraint
from textattack.shared.utils import LazyLoader, device
from textattack.shared.validators import transformation_consists_of_word_swaps

# Set global flair device to be TextAttack's current device
# (`device` comes from `textattack.shared.utils`). This assignment runs at
# import time, so it changes flair's module-level setting for the whole process.
flair.device = device

# Bind `stanza` through TextAttack's `LazyLoader` rather than a plain import.
# NOTE(review): presumably this defers the real `import stanza` until the name
# is first used, so stanza is only needed when the "stanza" tagger is selected
# -- confirm against `LazyLoader`.
stanza = LazyLoader("stanza", globals(), "stanza")

class PartOfSpeech(Constraint):
    """Constrains word swaps to only swap words with the same part of speech.

    Uses the NLTK universal part-of-speech tagger by default. POS taggers
    from Flair and Stanza are also available.

    This is an implementation adapted from prior work.
    TODO: restore the reference links that were stripped from this docstring.

    Args:
        tagger_type (str): Name of the tagger to use
            (available choices: "nltk", "flair", "stanza").
        tagset (str): tagset to use for POS tagging (e.g. "universal")
        allow_verb_noun_swap (bool): If `True`, allow verbs to be swapped
            with nouns and vice versa.
        compare_against_original (bool): If `True`, compare against the
            original text. Otherwise, compare against the most recent text.
        language_nltk: Language to be used for nltk POS-Tagger
            (available choices: "eng", "rus")
        language_stanza: Language to be used for stanza POS-Tagger
            (a language code accepted by `stanza.Pipeline`, e.g. "en").
    """

    def __init__(
        self,
        tagger_type="nltk",
        tagset="universal",
        allow_verb_noun_swap=True,
        compare_against_original=True,
        language_nltk="eng",
        language_stanza="en",
    ):
        super().__init__(compare_against_original)
        self.tagger_type = tagger_type
        self.tagset = tagset
        self.allow_verb_noun_swap = allow_verb_noun_swap
        self.language_nltk = language_nltk
        self.language_stanza = language_stanza

        # Maps a space-joined context window to its (word_list, pos_list)
        # tagging result so repeated windows are tagged only once.
        self._pos_tag_cache = lru.LRU(2**14)

        # Only the selected backend's model is loaded; NLTK needs no setup here.
        if tagger_type == "flair":
            if tagset == "universal":
                self._flair_pos_tagger = SequenceTagger.load("upos-fast")
            else:
                self._flair_pos_tagger = SequenceTagger.load("pos-fast")

        if tagger_type == "stanza":
            self._stanza_pos_tagger = stanza.Pipeline(
                lang=self.language_stanza,
                processors="tokenize, pos",
                tokenize_pretokenized=True,
            )

    def clear_cache(self):
        """Drop every cached tagging result."""
        self._pos_tag_cache.clear()

    def _can_replace_pos(self, pos_a, pos_b):
        """Return whether a word tagged ``pos_a`` may be swapped for one tagged ``pos_b``.

        Identical tags are always allowed. When ``allow_verb_noun_swap`` is
        set, any combination of "NOUN" and "VERB" is allowed as well.
        """
        return (pos_a == pos_b) or (
            self.allow_verb_noun_swap and {pos_a, pos_b} <= {"NOUN", "VERB"}
        )

    def _tag_context(self, context_words, context_key):
        """Run the configured tagger over one context window.

        Args:
            context_words (list[str]): The words of the window, in order.
            context_key (str): The same words joined by single spaces.

        Returns:
            A ``(word_list, pos_list)`` pair of parallel sequences.

        Raises:
            ValueError: If ``self.tagger_type`` is not a supported tagger.
        """
        if self.tagger_type == "nltk":
            return zip(
                *nltk.pos_tag(
                    context_words, tagset=self.tagset, lang=self.language_nltk
                )
            )
        elif self.tagger_type == "flair":
            context_key_sentence = Sentence(
                context_key,
                use_tokenizer=textattack.shared.utils.TextAttackFlairTokenizer(),
            )
            self._flair_pos_tagger.predict(context_key_sentence)
            return textattack.shared.utils.zip_flair_result(context_key_sentence)
        elif self.tagger_type == "stanza":
            return textattack.shared.utils.zip_stanza_result(
                self._stanza_pos_tagger(context_key), tagset=self.tagset
            )
        else:
            # Previously this case fell through every branch and failed later
            # with an opaque UnboundLocalError.
            raise ValueError(
                f"Unsupported tagger_type {self.tagger_type!r}; "
                'expected one of "nltk", "flair", "stanza".'
            )

    def _get_pos(self, before_ctx, word, after_ctx):
        """Return the POS tag of ``word`` when it sits between the two contexts.

        Args:
            before_ctx (list[str]): Words preceding ``word``.
            word (str): The word whose tag is wanted.
            after_ctx (list[str]): Words following ``word``.

        Returns:
            The tag the configured tagger assigned to ``word``.

        Raises:
            AssertionError: If ``word`` does not appear in the tagger's output.
            ValueError: If ``self.tagger_type`` is not a supported tagger.
        """
        context_words = before_ctx + [word] + after_ctx
        context_key = " ".join(context_words)
        if context_key in self._pos_tag_cache:
            word_list, pos_list = self._pos_tag_cache[context_key]
        else:
            word_list, pos_list = self._tag_context(context_words, context_key)
            self._pos_tag_cache[context_key] = (word_list, pos_list)

        # `word` was placed at this position in `context_words`. Prefer it over
        # a text search: searching returns the first occurrence, which is the
        # wrong token whenever the same word also appears earlier in the window.
        expected_idx = len(before_ctx)
        if expected_idx < len(word_list) and word_list[expected_idx] == word:
            return pos_list[expected_idx]

        # The tagger's tokenization did not line up with `context_words`;
        # fall back to the matching token closest to the expected position.
        matches = [idx for idx, token in enumerate(word_list) if token == word]
        if not matches:
            # Raised explicitly (not via `assert`) so the check survives `-O`.
            raise AssertionError("POS list not matched with original word list.")
        nearest_idx = min(matches, key=lambda idx: abs(idx - expected_idx))
        return pos_list[nearest_idx]

    def _check_constraint(self, transformed_text, reference_text):
        """Return ``True`` if every newly modified word keeps a compatible POS.

        Both the reference word and its replacement are tagged inside the same
        window of reference words, so only the swapped word differs.

        Raises:
            KeyError: If ``transformed_text.attack_attrs`` has no
                ``"newly_modified_indices"`` entry.
        """
        try:
            indices = transformed_text.attack_attrs["newly_modified_indices"]
        except KeyError:
            raise KeyError(
                "Cannot apply part-of-speech constraint without `newly_modified_indices`"
            )

        reference_words = reference_text.words
        transformed_words = transformed_text.words
        for i in indices:
            reference_word = reference_words[i]
            transformed_word = transformed_words[i]
            # Window: up to 4 words before and up to 3 words after position i
            # (slicing clamps at the end of the list).
            before_ctx = reference_words[max(i - 4, 0) : i]
            after_ctx = reference_words[i + 1 : i + 4]
            ref_pos = self._get_pos(before_ctx, reference_word, after_ctx)
            replace_pos = self._get_pos(before_ctx, transformed_word, after_ctx)
            if not self._can_replace_pos(ref_pos, replace_pos):
                return False

        return True

    def check_compatibility(self, transformation):
        """Return whether ``transformation`` consists only of word swaps."""
        return transformation_consists_of_word_swaps(transformation)

    def extra_repr_keys(self):
        """Return the attribute names to show in the repr, followed by the parent's keys."""
        return [
            "tagger_type",
            "tagset",
            "allow_verb_noun_swap",
        ] + super().extra_repr_keys()

    def __getstate__(self):
        """Return the instance state with the cache replaced by its ``get_size()`` value.

        Cached entries are not carried over.
        """
        state = self.__dict__.copy()
        state["_pos_tag_cache"] = self._pos_tag_cache.get_size()
        return state

    def __setstate__(self, state):
        """Restore the instance state and create a fresh, empty cache of the saved size."""
        cache_size = state["_pos_tag_cache"]
        self.__dict__ = state
        self._pos_tag_cache = lru.LRU(cache_size)