TensorFlow and TextAttack

Remember to run pip3 install textattack[tensorflow] in your notebook environment before running the code below.

Running TextAttack on a trained TensorFlow model:

First: Training

The following code trains a text classification model with TensorFlow, using the Keras API on top of it. It comes from the TensorFlow documentation.

This cell loads the IMDB dataset (using tensorflow_datasets, not datasets), initializes a simple classifier, and trains it using Keras.

[1]:
import numpy as np

import tensorflow as tf
import tensorflow_hub as hub
import tensorflow_datasets as tfds

import matplotlib.pyplot as plt

print("Version: ", tf.__version__)
print("Eager mode: ", tf.executing_eagerly())
print("Hub version: ", hub.__version__)
print(
    "GPU is", "available" if tf.config.list_physical_devices("GPU") else "NOT AVAILABLE"
)

train_data, test_data = tfds.load(
    name="imdb_reviews", split=["train", "test"], batch_size=-1, as_supervised=True
)

train_examples, train_labels = tfds.as_numpy(train_data)
test_examples, test_labels = tfds.as_numpy(test_data)

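# URL of the pretrained text-embedding module; kept separate from the Keras `model` built below.
embedding_url = "https://tfhub.dev/google/tf2-preview/gnews-swivel-20dim/1"
hub_layer = hub.KerasLayer(
    embedding_url, output_shape=[20], input_shape=[], dtype=tf.string, trainable=True
)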
hub_layer(train_examples[:3])

model = tf.keras.Sequential()
model.add(hub_layer)
model.add(tf.keras.layers.Dense(16, activation="relu"))
model.add(tf.keras.layers.Dense(1))

model.summary()

x_val = train_examples[:10000]
partial_x_train = train_examples[10000:]

y_val = train_labels[:10000]
partial_y_train = train_labels[10000:]

model.compile(
    optimizer="adam",
    loss=tf.losses.BinaryCrossentropy(from_logits=True),
    metrics=["accuracy"],
)

history = model.fit(
    partial_x_train,
    partial_y_train,
    epochs=40,
    batch_size=512,
    validation_data=(x_val, y_val),
    verbose=1,
)
Version:  2.3.2
Eager mode:  True
Hub version:  0.12.0
GPU is NOT AVAILABLE
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
keras_layer (KerasLayer)     (None, 20)                400020
_________________________________________________________________
dense (Dense)                (None, 16)                336
_________________________________________________________________
dense_1 (Dense)              (None, 1)                 17
=================================================================
Total params: 400,373
Trainable params: 400,373
Non-trainable params: 0
_________________________________________________________________
Epoch 1/40
30/30 [==============================] - 2s 60ms/step - loss: 1.1498 - accuracy: 0.5131 - val_loss: 0.7321 - val_accuracy: 0.5698
Epoch 2/40
30/30 [==============================] - 1s 39ms/step - loss: 0.6914 - accuracy: 0.5994 - val_loss: 0.6636 - val_accuracy: 0.6008
Epoch 3/40
30/30 [==============================] - 1s 37ms/step - loss: 0.6423 - accuracy: 0.6219 - val_loss: 0.6330 - val_accuracy: 0.6367
Epoch 4/40
30/30 [==============================] - 1s 37ms/step - loss: 0.6107 - accuracy: 0.6524 - val_loss: 0.6043 - val_accuracy: 0.6602
Epoch 5/40
30/30 [==============================] - 1s 38ms/step - loss: 0.5759 - accuracy: 0.6832 - val_loss: 0.5736 - val_accuracy: 0.6943
Epoch 6/40
30/30 [==============================] - 1s 38ms/step - loss: 0.5390 - accuracy: 0.7149 - val_loss: 0.5391 - val_accuracy: 0.7141
Epoch 7/40
30/30 [==============================] - 1s 37ms/step - loss: 0.5004 - accuracy: 0.7480 - val_loss: 0.5068 - val_accuracy: 0.7393
Epoch 8/40
30/30 [==============================] - 1s 43ms/step - loss: 0.4632 - accuracy: 0.7733 - val_loss: 0.4773 - val_accuracy: 0.7634
Epoch 9/40
30/30 [==============================] - 1s 44ms/step - loss: 0.4292 - accuracy: 0.7986 - val_loss: 0.4502 - val_accuracy: 0.7739
Epoch 10/40
30/30 [==============================] - 1s 43ms/step - loss: 0.3963 - accuracy: 0.8205 - val_loss: 0.4271 - val_accuracy: 0.7863
Epoch 11/40
30/30 [==============================] - 1s 41ms/step - loss: 0.3664 - accuracy: 0.8357 - val_loss: 0.4048 - val_accuracy: 0.8109
Epoch 12/40
30/30 [==============================] - 1s 40ms/step - loss: 0.3396 - accuracy: 0.8533 - val_loss: 0.3864 - val_accuracy: 0.8146
Epoch 13/40
30/30 [==============================] - 1s 42ms/step - loss: 0.3147 - accuracy: 0.8681 - val_loss: 0.3699 - val_accuracy: 0.8347
Epoch 14/40
30/30 [==============================] - 1s 38ms/step - loss: 0.2918 - accuracy: 0.8788 - val_loss: 0.3563 - val_accuracy: 0.8356
Epoch 15/40
30/30 [==============================] - 1s 43ms/step - loss: 0.2717 - accuracy: 0.8881 - val_loss: 0.3452 - val_accuracy: 0.8393
Epoch 16/40
30/30 [==============================] - 1s 43ms/step - loss: 0.2540 - accuracy: 0.8965 - val_loss: 0.3347 - val_accuracy: 0.8493
Epoch 17/40
30/30 [==============================] - 1s 40ms/step - loss: 0.2379 - accuracy: 0.9045 - val_loss: 0.3272 - val_accuracy: 0.8533
Epoch 18/40
30/30 [==============================] - 1s 46ms/step - loss: 0.2230 - accuracy: 0.9114 - val_loss: 0.3204 - val_accuracy: 0.8616
Epoch 19/40
30/30 [==============================] - 1s 47ms/step - loss: 0.2103 - accuracy: 0.9177 - val_loss: 0.3181 - val_accuracy: 0.8562
Epoch 20/40
30/30 [==============================] - 1s 45ms/step - loss: 0.1977 - accuracy: 0.9243 - val_loss: 0.3119 - val_accuracy: 0.8660
Epoch 21/40
30/30 [==============================] - 1s 38ms/step - loss: 0.1861 - accuracy: 0.9303 - val_loss: 0.3093 - val_accuracy: 0.8729
Epoch 22/40
30/30 [==============================] - 1s 37ms/step - loss: 0.1759 - accuracy: 0.9337 - val_loss: 0.3075 - val_accuracy: 0.8704
Epoch 23/40
30/30 [==============================] - 1s 37ms/step - loss: 0.1661 - accuracy: 0.9393 - val_loss: 0.3061 - val_accuracy: 0.8719
Epoch 24/40
30/30 [==============================] - 1s 40ms/step - loss: 0.1564 - accuracy: 0.9439 - val_loss: 0.3077 - val_accuracy: 0.8745
Epoch 25/40
30/30 [==============================] - 2s 51ms/step - loss: 0.1461 - accuracy: 0.9475 - val_loss: 0.3077 - val_accuracy: 0.8739
Epoch 26/40
30/30 [==============================] - 1s 47ms/step - loss: 0.1363 - accuracy: 0.9524 - val_loss: 0.3098 - val_accuracy: 0.8714
Epoch 27/40
30/30 [==============================] - 1s 40ms/step - loss: 0.1283 - accuracy: 0.9551 - val_loss: 0.3113 - val_accuracy: 0.8727
Epoch 28/40
30/30 [==============================] - 1s 41ms/step - loss: 0.1226 - accuracy: 0.9577 - val_loss: 0.3142 - val_accuracy: 0.8746
Epoch 29/40
30/30 [==============================] - 1s 36ms/step - loss: 0.1132 - accuracy: 0.9623 - val_loss: 0.3166 - val_accuracy: 0.8733
Epoch 30/40
30/30 [==============================] - 1s 39ms/step - loss: 0.1059 - accuracy: 0.9661 - val_loss: 0.3207 - val_accuracy: 0.8699
Epoch 31/40
30/30 [==============================] - 1s 38ms/step - loss: 0.0994 - accuracy: 0.9683 - val_loss: 0.3240 - val_accuracy: 0.8692
Epoch 32/40
30/30 [==============================] - 1s 38ms/step - loss: 0.0945 - accuracy: 0.9711 - val_loss: 0.3285 - val_accuracy: 0.8687
Epoch 33/40
30/30 [==============================] - 1s 36ms/step - loss: 0.0877 - accuracy: 0.9744 - val_loss: 0.3327 - val_accuracy: 0.8694
Epoch 34/40
30/30 [==============================] - 1s 40ms/step - loss: 0.0823 - accuracy: 0.9761 - val_loss: 0.3414 - val_accuracy: 0.8658
Epoch 35/40
30/30 [==============================] - 1s 39ms/step - loss: 0.0773 - accuracy: 0.9785 - val_loss: 0.3423 - val_accuracy: 0.8712
Epoch 36/40
30/30 [==============================] - 1s 36ms/step - loss: 0.0723 - accuracy: 0.9805 - val_loss: 0.3481 - val_accuracy: 0.8674
Epoch 37/40
30/30 [==============================] - 1s 37ms/step - loss: 0.0681 - accuracy: 0.9823 - val_loss: 0.3597 - val_accuracy: 0.8624
Epoch 38/40
30/30 [==============================] - 1s 37ms/step - loss: 0.0639 - accuracy: 0.9831 - val_loss: 0.3589 - val_accuracy: 0.8691
Epoch 39/40
30/30 [==============================] - 1s 42ms/step - loss: 0.0595 - accuracy: 0.9857 - val_loss: 0.3650 - val_accuracy: 0.8720
Epoch 40/40
30/30 [==============================] - 1s 43ms/step - loss: 0.0558 - accuracy: 0.9868 - val_loss: 0.3711 - val_accuracy: 0.8696

Attacking

For each input, our classifier outputs a single number (a raw logit, since the model was trained with from_logits=True) that indicates how positive or negative the model finds the input. For binary classification, TextAttack expects two numbers for each input: one score per class, negative and positive. We therefore have to post-process each output to fit this format. To add the post-processing, we implement a custom model wrapper class instead of using the built-in textattack.models.wrappers.TensorFlowModelWrapper.

Each ModelWrapper must implement a single method, __call__, which takes a list of strings and returns a List, np.ndarray, or torch.Tensor of predictions.

[2]:
import numpy as np
import torch

from textattack.models.wrappers import ModelWrapper


class CustomTensorFlowModelWrapper(ModelWrapper):
    def __init__(self, model):
        self.model = model

    def __call__(self, text_input_list):
        text_array = np.array(text_input_list)
        # The Keras model returns one raw logit per input, shape (batch, 1).
        logits = torch.tensor(self.model(text_array).numpy())
        # The sigmoid turns each logit into P(positive), shape (batch,).
        positive_prob = torch.sigmoid(logits).squeeze(dim=-1)
        # TextAttack expects one score per class, so we stack
        # [P(negative), P(positive)] into shape (batch, 2).
        final_preds = torch.stack((1 - positive_prob, positive_prob), dim=1)
        return final_preds
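
The post-processing in the wrapper is just the sigmoid function followed by pairing each probability with its complement. As a minimal sketch of that arithmetic for a single input, using only the standard library (the function name and the sample logits are hypothetical, chosen for illustration):

```python
import math


def logit_to_class_scores(logit):
    """Map one raw logit to [P(negative), P(positive)] via the sigmoid."""
    positive = 1.0 / (1.0 + math.exp(-logit))
    return [1.0 - positive, positive]


# Hypothetical logits, not taken from the trained model.
for z in (-2.0, 0.0, 2.0):
    negative, positive = logit_to_class_scores(z)
    print(f"logit={z:+.1f} -> negative={negative:.4f}, positive={positive:.4f}")
```

A logit of zero maps to an even split between the two classes, and the two scores always sum to one, which is the property the wrapper relies on.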

Let’s test our model wrapper out to make sure it can use our model to return predictions in the correct format.

[3]:
CustomTensorFlowModelWrapper(model)(["I hate you so much", "I love you"])
[3]:
tensor([[0.1409, 0.8591],
        [0.0213, 0.9787]])
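
To make "the correct format" concrete, here is a small, hedged sketch of a format check. It assumes the output has been converted to nested Python lists and that the two scores should behave like probabilities; the helper name and the tolerance are hypothetical and not part of TextAttack.

```python
def check_two_class_scores(rows, tolerance=1e-4):
    """Return True if every row holds two scores in [0, 1] that sum to 1."""
    for row in rows:
        if len(row) != 2:
            return False
        if any(score < 0.0 or score > 1.0 for score in row):
            return False
        if abs(sum(row) - 1.0) > tolerance:
            return False
    return True


# The rounded values printed by the cell above.
printed = [[0.1409, 0.8591], [0.0213, 0.9787]]
print(check_two_class_scores(printed))
```

With the real wrapper, the same check could be applied to the tensor by first calling .tolist() on it.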

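The output has the right shape: one row per input and one score per class. Now we can initialize our model wrapper with the model we trained and pass it to an attack recipe, which builds the attack for us.

We’ll use the PWWSRen2019 recipe as our attack, and attack 10 samples from the rotten_tomatoes test split.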

[4]:
model_wrapper = CustomTensorFlowModelWrapper(model)

from textattack.datasets import HuggingFaceDataset
from textattack.attack_recipes import PWWSRen2019
from textattack import Attacker

dataset = HuggingFaceDataset("rotten_tomatoes", None, "test", shuffle=True)
attack = PWWSRen2019.build(model_wrapper)

attacker = Attacker(attack, dataset)
attacker.attack_dataset()
WARNING:datasets.builder:Using custom data configuration default
WARNING:datasets.builder:Reusing dataset rotten_tomatoes_movie_review (/p/qdata/jy2ma/.cache/textattack/datasets/rotten_tomatoes_movie_review/default/1.0.0/9c411f7ecd9f3045389de0d9ce984061a1056507703d2e3183b1ac1a90816e4d)
textattack: Loading datasets dataset rotten_tomatoes, split test.
textattack: Unknown if model of class <class 'tensorflow.python.keras.engine.sequential.Sequential'> compatible with goal function <class 'textattack.goal_functions.classification.untargeted_classification.UntargetedClassification'>.
[Succeeded / Failed / Skipped / Total] 2 / 0 / 3 / 5:  50%|█████     | 5/10 [00:00<00:00, 43.42it/s]
Attack(
  (search_method): GreedyWordSwapWIR(
    (wir_method):  weighted-saliency
  )
  (goal_function):  UntargetedClassification
  (transformation):  WordSwapWordNet
  (constraints):
    (0): RepeatModification
    (1): StopwordModification
  (is_black_box):  True
)

--------------------------------------------- Result 1 ---------------------------------------------
Negative (90%) --> [SKIPPED]

lovingly photographed in the manner of a golden book sprung to life , stuart little 2 manages sweetness largely without stickiness .


--------------------------------------------- Result 2 ---------------------------------------------
Positive (52%) --> Negative (97%)

consistently clever and suspenseful .

consistently clever and cliff-hanging .


--------------------------------------------- Result 3 ---------------------------------------------
Positive (89%) --> Negative (86%)

it's like a " big chill " reunion of the baader-meinhof gang , only these guys are more harmless pranksters than political activists .

it's like a " big chill " reunion of the baader-meinhof gang , only these roast are more harmless pranksters than political activists .


--------------------------------------------- Result 4 ---------------------------------------------
Negative (60%) --> [SKIPPED]

the story gives ample opportunity for large-scale action and suspense , which director shekhar kapur supplies with tremendous skill .


--------------------------------------------- Result 5 ---------------------------------------------
Negative (99%) --> [SKIPPED]

red dragon " never cuts corners .


[Succeeded / Failed / Skipped / Total] 4 / 0 / 3 / 7:  70%|███████   | 7/10 [00:00<00:00, 18.97it/s]
--------------------------------------------- Result 6 ---------------------------------------------
Positive (99%) --> Negative (85%)

fresnadillo has something serious to say about the ways in which extravagant chance can distort our perspective and throw us off the path of good sense .

fresnadillo has something serious to say about the manner in which exuberant chance can distort our perspective and throw us off the path of ripe sense .


--------------------------------------------- Result 7 ---------------------------------------------
Positive (99%) --> Negative (73%)

throws in enough clever and unexpected twists to make the formula feel fresh .

flip in enough clever and unexpected construction to make the formula feel fresh .


[Succeeded / Failed / Skipped / Total] 6 / 0 / 4 / 10: 100%|██████████| 10/10 [00:00<00:00, 17.90it/s]
--------------------------------------------- Result 8 ---------------------------------------------
Positive (96%) --> Negative (93%)

weighty and ponderous but every bit as filling as the treat of the title .

weighty and ponderous but every bite as filling as the cover of the title .


--------------------------------------------- Result 9 ---------------------------------------------
Positive (84%) --> Negative (70%)

a real audience-pleaser that will strike a chord with anyone who's ever waited in a doctor's office , emergency room , hospital bed or insurance company office .

a material audience-pleaser that will strike a chord with anyone who's ever waited in a doctor's office , emergency room , hospital screw or insurance company office .


--------------------------------------------- Result 10 ---------------------------------------------
Negative (99%) --> [SKIPPED]

generates an enormous feeling of empathy for its characters .



+-------------------------------+--------+
| Attack Results                |        |
+-------------------------------+--------+
| Number of successful attacks: | 6      |
| Number of failed attacks:     | 0      |
| Number of skipped attacks:    | 4      |
| Original accuracy:            | 60.0%  |
| Accuracy under attack:        | 0.0%   |
| Attack success rate:          | 100.0% |
| Average perturbed word %:     | 13.2%  |
| Average num. words per input: | 15.4   |
| Avg num queries:              | 139.0  |
+-------------------------------+--------+

[4]:
[<textattack.attack_results.skipped_attack_result.SkippedAttackResult at 0x7f74758a3e50>,
 <textattack.attack_results.successful_attack_result.SuccessfulAttackResult at 0x7f74758a3c70>,
 <textattack.attack_results.successful_attack_result.SuccessfulAttackResult at 0x7f74758a3490>,
 <textattack.attack_results.skipped_attack_result.SkippedAttackResult at 0x7f74758a3fd0>,
 <textattack.attack_results.skipped_attack_result.SkippedAttackResult at 0x7f74758a39d0>,
 <textattack.attack_results.successful_attack_result.SuccessfulAttackResult at 0x7f7475903100>,
 <textattack.attack_results.successful_attack_result.SuccessfulAttackResult at 0x7f7475284460>,
 <textattack.attack_results.successful_attack_result.SuccessfulAttackResult at 0x7f74758a3910>,
 <textattack.attack_results.successful_attack_result.SuccessfulAttackResult at 0x7f74758a3790>,
 <textattack.attack_results.skipped_attack_result.SkippedAttackResult at 0x7f7489297c40>]
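
The percentages in the summary table follow from the three result counts. The sketch below reproduces them under the usual definitions, which are inferred here from the table rather than taken from TextAttack's source: skipped inputs are the ones the model misclassified before any attack, so they count against original accuracy and are excluded from the success rate. The function name is hypothetical.

```python
def summarize_attack(successful, failed, skipped):
    """Compute summary percentages from attack result counts."""
    total = successful + failed + skipped
    if total == 0:
        raise ValueError("at least one attack result is required")
    # Inputs the model classified correctly, i.e. the ones actually attacked.
    attempted = successful + failed
    return {
        "original_accuracy": 100.0 * attempted / total,
        "accuracy_under_attack": 100.0 * failed / total,
        "attack_success_rate": 100.0 * successful / attempted if attempted else 0.0,
    }


# Counts from the run above: 6 succeeded, 0 failed, 4 skipped.
summary = summarize_attack(successful=6, failed=0, skipped=4)
for name, value in summary.items():
    print(f"{name}: {value:.1f}%")
```

Read this way, the 100% success rate applies only to the 6 of 10 reviews the classifier got right in the first place.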

Conclusion

We successfully trained a model, wrapped it in a custom ModelWrapper, and used that wrapper in an attack. This is the general pattern for adapting any model, whether it uses TensorFlow or another library, for use with TextAttack.