TextAttack with Custom Dataset and Word Embedding
This tutorial shows you how to use TextAttack with any dataset and word embedding you want.
Please remember to run pip3 install textattack[tensorflow] in your notebook environment before running the following code:
Importing the Model
We start by choosing a pretrained model to attack. In this example we use the albert-base-v2 model from HuggingFace, fine-tuned on IMDB, a set of movie reviews labeled either positive or negative.
[1]:
!pip3 install textattack[tensorflow]
[2]:
import transformers
from textattack.models.wrappers import HuggingFaceModelWrapper
# https://huggingface.co/textattack
model = transformers.AutoModelForSequenceClassification.from_pretrained(
    "textattack/albert-base-v2-imdb"
)
tokenizer = transformers.AutoTokenizer.from_pretrained("textattack/albert-base-v2-imdb")
# We wrap the model so it can be used by textattack
model_wrapper = HuggingFaceModelWrapper(model, tokenizer)
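The goal function queries the model through this wrapper: it passes in a list of strings and gets back one row of class scores per string. As a minimal sketch of that contract that needs no model download, here is a hypothetical lexicon-based classifier (ToySentimentWrapper and its word scores are illustrative only, not part of TextAttack). For a real custom model you would instead subclass textattack.models.wrappers.ModelWrapper and implement __call__.

```python
import math
from typing import List


class ToySentimentWrapper:
    """Hypothetical stand-in for a model wrapper: maps texts to [P(negative), P(positive)]."""

    # Hypothetical word scores; positive values push toward label 1 (positive).
    LEXICON = {"love": 2.0, "like": 1.0, "hate": -2.0, "despise": -2.5}

    def __call__(self, text_input_list: List[str]) -> List[List[float]]:
        outputs = []
        for text in text_input_list:
            # Sum the lexicon scores of the words in the text (unknown words score 0).
            score = sum(self.LEXICON.get(w, 0.0) for w in text.lower().split())
            p_pos = 1.0 / (1.0 + math.exp(-score))  # logistic squashing to a probability
            outputs.append([1.0 - p_pos, p_pos])
        return outputs


wrapper = ToySentimentWrapper()
scores = wrapper(["I hate this movie", "I love this movie"])
# The predicted label is the index of the highest score in each row.
predicted = [max(range(2), key=lambda k: row[k]) for row in scores]
print(predicted)  # [0, 1]
```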
Creating A Custom Dataset
TextAttack takes in a dataset in the form of a list of tuples. Each tuple can be in the form (“string”, label) or (“string”, label, label). Here we use the former, since we want to create a custom movie review dataset in which label 0 represents a negative review and label 1 represents a positive review.
For simplicity, I created a dataset of 4 reviews: the 1st and 4th reviews have “correct” labels, while the 2nd and 3rd reviews have “incorrect” labels. We will see how this affects perturbation later in this tutorial.
[3]:
# dataset: An iterable of (text, ground_truth_output) pairs.
# 0 means the review is negative
# 1 means the review is positive
custom_dataset = [
("I hate this movie", 0), # A negative comment, with a negative label
("I hate this movie", 1), # A negative comment, with a positive label
("I love this movie", 0), # A positive comment, with a negative label
("I love this movie", 1), # A positive comment, with a positive label
]
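Since the attack only sees these tuples, a quick sanity check before attacking can catch malformed entries. A minimal sketch, assuming the (text, label) form used here with 0 = negative and 1 = positive; check_dataset is a hypothetical helper, not part of TextAttack:

```python
from collections import Counter
from typing import List, Tuple


def check_dataset(dataset: List[Tuple[str, int]], num_labels: int = 2) -> Counter:
    """Hypothetical helper: validate (text, label) pairs and count examples per label."""
    counts = Counter()
    for i, example in enumerate(dataset):
        if len(example) != 2:
            raise ValueError(f"example {i} is not a (text, label) pair: {example!r}")
        text, label = example
        if not isinstance(text, str):
            raise TypeError(f"example {i}: text must be a str, got {type(text).__name__}")
        if not (isinstance(label, int) and 0 <= label < num_labels):
            raise ValueError(f"example {i}: label {label!r} is outside 0..{num_labels - 1}")
        counts[label] += 1
    return counts


custom_dataset = [
    ("I hate this movie", 0),
    ("I hate this movie", 1),
    ("I love this movie", 0),
    ("I love this movie", 1),
]
print(check_dataset(custom_dataset))  # Counter({0: 2, 1: 2})
```

A balanced count like this one is expected here, since each review text appears once with each label.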
Creating An Attack
[4]:
from textattack import Attack
from textattack.search_methods import GreedySearch
from textattack.constraints.pre_transformation import (
    RepeatModification,
    StopwordModification,
)
from textattack.goal_functions import UntargetedClassification
from textattack.transformations import WordSwapEmbedding
# We'll use untargeted classification as the goal function.
goal_function = UntargetedClassification(model_wrapper)
# We'll use WordSwapEmbedding as the attack transformation.
transformation = WordSwapEmbedding()
# We'll constrain modification of already modified indices and stopwords
constraints = [RepeatModification(), StopwordModification()]
# We'll use the Greedy search method
search_method = GreedySearch()
# Now, let's make the attack from the 4 components:
attack = Attack(goal_function, constraints, transformation, search_method)
textattack: Unknown if model of class <class 'transformers.models.albert.modeling_albert.AlbertForSequenceClassification'> compatible with goal function <class 'textattack.goal_functions.classification.untargeted_classification.UntargetedClassification'>.
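To make the roles of the four components concrete, here is a minimal pure-Python sketch of how they interact. It is an illustration under stated assumptions, not TextAttack's implementation: a hypothetical lexicon "model", a neighbour table standing in for WordSwapEmbedding, stopword and repeat-modification filters standing in for the two constraints, and a greedy loop that keeps whichever single swap most lowers the probability of the true label until the prediction flips (the untargeted goal). LEXICON, NEIGHBOURS, STOPWORDS, prob_of_label and greedy_attack are all hypothetical names.

```python
import math
from typing import List, Tuple

# Hypothetical toy components; none of these names come from TextAttack.
LEXICON = {"love": 2.0, "like": 1.0, "hate": -2.0, "despise": -2.5}  # toy "model" weights
NEIGHBOURS = {  # stands in for WordSwapEmbedding: candidate swaps per word
    "hate": ["despise", "like", "love"],
    "despise": ["hate", "like", "love"],
    "like": ["despise", "love", "hate"],
    "love": ["like", "despise", "hate"],
}
STOPWORDS = {"i", "this", "the", "a", "an"}  # stands in for StopwordModification


def prob_of_label(words: List[str], label: int) -> float:
    """Toy model: P(label | words) from a logistic over summed lexicon scores."""
    score = sum(LEXICON.get(w.lower(), 0.0) for w in words)
    p_pos = 1.0 / (1.0 + math.exp(-score))
    return p_pos if label == 1 else 1.0 - p_pos


def greedy_attack(text: str, label: int) -> Tuple[str, str]:
    """Return (status, text) where status is 'skipped', 'succeeded' or 'failed'."""
    words = text.split()
    # Untargeted goal: met as soon as the true label is no longer the top prediction.
    if prob_of_label(words, label) < 0.5:
        return "skipped", text  # goal already met before any change
    modified = set()  # indices already swapped (RepeatModification-style)
    while True:
        candidates = []
        for i, w in enumerate(words):
            if i in modified or w.lower() in STOPWORDS:
                continue  # constraints filter out these positions
            for swap in NEIGHBOURS.get(w.lower(), []):
                candidates.append((i, words[:i] + [swap] + words[i + 1 :]))
        if not candidates:
            return "failed", " ".join(words)  # no allowed swaps left
        # Greedy search: keep the single swap that most lowers P(true label).
        i, words = min(candidates, key=lambda c: prob_of_label(c[1], label))
        modified.add(i)
        if prob_of_label(words, label) < 0.5:
            return "succeeded", " ".join(words)


for text, label in [
    ("I hate this movie", 0),
    ("I hate this movie", 1),
    ("I love this movie", 0),
    ("I love this movie", 1),
]:
    print(greedy_attack(text, label))
# ('succeeded', 'I love this movie')
# ('skipped', 'I hate this movie')
# ('skipped', 'I love this movie')
# ('succeeded', 'I despise this movie')
```

Because the toy model's scores differ from ALBERT's, its perturbations need not match the ones TextAttack produces on the real model.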
Attack Results With Custom Dataset
As you can see, the attack fools the model by changing a few words in the 1st and 4th reviews.
The attack skipped the 2nd and 3rd reviews: because they were labeled incorrectly, the model's prediction already disagreed with the given label, so the untargeted goal was met without any modifications.
[5]:
for example, label in custom_dataset:
    result = attack.attack(example, label)
    print(result.__str__(color_method="ansi"))
0 (99%) --> 1 (81%)
I hate this movie
did hateful this footage
0 (99%) --> [SKIPPED]
I hate this movie
1 (96%) --> [SKIPPED]
I love this movie
1 (96%) --> 0 (99%)
I love this movie
I iove this movie
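The skip decision can be reproduced from the scores printed above. As a sketch of the untargeted goal check (untargeted_goal_met is a hypothetical helper, and the probabilities are the rounded values from the output, not exact model scores), an example is skipped when the model's top class already differs from its label:

```python
from typing import Sequence


def untargeted_goal_met(scores: Sequence[float], ground_truth: int) -> bool:
    """True when the top-scoring class differs from the ground-truth label."""
    predicted = max(range(len(scores)), key=lambda k: scores[k])
    return predicted != ground_truth


# Rounded class probabilities read off the output above.
hate_scores = [0.99, 0.01]  # "I hate this movie" -> 0 (99%)
love_scores = [0.04, 0.96]  # "I love this movie" -> 1 (96%)

dataset = [(hate_scores, 0), (hate_scores, 1), (love_scores, 0), (love_scores, 1)]
status = ["SKIPPED" if untargeted_goal_met(s, y) else "attack" for s, y in dataset]
print(status)  # ['attack', 'SKIPPED', 'SKIPPED', 'attack']
```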
Creating A Custom Word Embedding
In TextAttack, a pre-trained word embedding is used by embedding-based transformations such as WordSwapEmbedding to find synonym replacements, and by some constraints to check the semantic validity of a transformation. To use custom pre-trained word embeddings, you can either create a new class that inherits from the AbstractWordEmbedding class, or use the WordEmbedding class, which takes 4 parameters: embedding_matrix, word2index, index2word, and nn_matrix.
[6]:
from textattack.shared import WordEmbedding
embedding_matrix = [
[1.0],
[2.0],
[3.0],
[4.0],
] # 2-D array of shape N x D where N represents size of vocab and D is the dimension of embedding vectors.
word2index = {
"hate": 0,
"despise": 1,
"like": 2,
"love": 3,
} # dictionary that maps each word to its index within the embedding matrix.
index2word = {
0: "hate",
1: "despise",
2: "like",
3: "love",
} # dictionary that maps index to its word.
nn_matrix = [
[0, 1, 2, 3],
[1, 0, 2, 3],
[2, 1, 3, 0],
[3, 2, 1, 0],
] # 2-D integer array of shape N x K where N represents size of vocab and K is the top-K nearest neighbours.
embedding = WordEmbedding(embedding_matrix, word2index, index2word, nn_matrix)
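The nn_matrix above is written by hand, but for a real vocabulary it would be computed from the embedding matrix. As a sketch, assuming Euclidean distance and that each row lists the word itself first, the following brute-force helper (nearest_neighbour_matrix is a hypothetical name) reproduces exactly the hand-written nn_matrix from the 1-D embedding_matrix above, with ties broken by ascending index:

```python
from typing import List


def nearest_neighbour_matrix(matrix: List[List[float]]) -> List[List[int]]:
    """For every row, list all indices ordered by Euclidean distance (self first)."""

    def dist(a: List[float], b: List[float]) -> float:
        return sum((x - y) ** 2 for x, y in zip(a, b)) ** 0.5

    nn = []
    for vec in matrix:
        # sorted() is stable, so equally distant words keep ascending index order.
        order = sorted(range(len(matrix)), key=lambda j: dist(vec, matrix[j]))
        nn.append(order)
    return nn


embedding_matrix = [[1.0], [2.0], [3.0], [4.0]]
print(nearest_neighbour_matrix(embedding_matrix))
# [[0, 1, 2, 3], [1, 0, 2, 3], [2, 1, 3, 0], [3, 2, 1, 0]]
```

This compares every pair of words, so it is quadratic in the vocabulary size; for a large vocabulary a vectorized or approximate nearest-neighbour search would be the usual choice.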
Attack Results With Custom Dataset and Word Embedding
Now if we run the attack again with the custom word embedding, you will notice the modifications are limited to the vocab provided by our custom word embedding.
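To see why the modifications stay within the custom vocabulary, here is a sketch of the candidate lookup, assuming (as the nn_matrix above suggests) that column 0 of each row is the word itself and that up to max_candidates neighbours are taken after it. It mirrors the idea behind WordSwapEmbedding(3, embedding) rather than TextAttack's exact code; swap_candidates is a hypothetical helper.

```python
from typing import Dict, List

word2index = {"hate": 0, "despise": 1, "like": 2, "love": 3}
index2word = {i: w for w, i in word2index.items()}
nn_matrix = [[0, 1, 2, 3], [1, 0, 2, 3], [2, 1, 3, 0], [3, 2, 1, 0]]


def swap_candidates(text: str, max_candidates: int = 3) -> Dict[int, List[str]]:
    """Map each word position to its replacement candidates (the word itself excluded)."""
    candidates = {}
    for i, word in enumerate(text.split()):
        idx = word2index.get(word.lower())
        if idx is None:
            continue  # out-of-vocabulary words get no candidates at all
        neighbours = nn_matrix[idx][1 : max_candidates + 1]  # skip column 0 (the word itself)
        candidates[i] = [index2word[j] for j in neighbours]
    return candidates


print(swap_candidates("I hate this movie"))  # {1: ['despise', 'like', 'love']}
print(swap_candidates("I love this movie"))  # {1: ['like', 'despise', 'hate']}
```

Only "hate" and "love" are in the vocabulary, so they are the only words the attack can ever replace, and only with the other three vocabulary words.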
[7]:
# Restrict swaps to the custom embedding, with at most 3 candidates per word
transformation = WordSwapEmbedding(max_candidates=3, embedding=embedding)
attack = Attack(goal_function, constraints, transformation, search_method)
# The lower-level way: run the attack on each example directly and print the result
for example, label in custom_dataset:
    result = attack.attack(example, label)
    print(result.__str__(color_method="ansi"))
0 (99%) --> 1 (98%)
I hate this movie
I like this movie
0 (99%) --> [SKIPPED]
I hate this movie
1 (96%) --> [SKIPPED]
I love this movie
1 (96%) --> 0 (99%)
I love this movie
I despise this movie
[ ]:
# Here is the currently recommended, API-centric way to run a customized attack
from textattack.loggers import CSVLogger  # tracks a dataframe for us
from textattack.attack_results import SuccessfulAttackResult
from textattack.datasets import Dataset
from textattack import Attacker, AttackArgs

# Attacker expects a textattack.datasets.Dataset, so wrap the list of tuples
dataset = Dataset(custom_dataset)
attack_args = AttackArgs(
    num_successful_examples=5, log_to_csv="results.csv", csv_coloring_style="html"
)
attacker = Attacker(attack, dataset, attack_args)
attack_results = attacker.attack_dataset()
[ ]:
# Now we visualize the attack results
import pandas as pd
from IPython.display import display, HTML

# Increase column width so we can actually read the examples
pd.options.display.max_colwidth = 480

logger = CSVLogger(color_method="html")
for result in attack_results:
    if isinstance(result, SuccessfulAttackResult):
        logger.log_attack_result(result)

results = pd.DataFrame.from_records(logger.row_list)
display(HTML(results[["original_text", "perturbed_text"]].to_html(escape=False)))