Please remember to run pip3 install textattack[tensorflow] in your notebook environment before running the following code:

Explaining Attacked BERT Models Using Captum

Captum is a PyTorch library for explaining neural networks. Here we show a minimal example that uses Captum to explain BERT models from TextAttack.

[1]:
import torch
from copy import deepcopy
[2]:
from textattack.datasets import HuggingFaceDataset
from textattack.models.wrappers import HuggingFaceModelWrapper
from textattack.models.wrappers import ModelWrapper
from transformers import AutoModelForSequenceClassification, AutoTokenizer
[3]:
# Optional: install the Captum dependency
!pip3 install captum
Requirement already satisfied: captum in /Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages (0.4.0)
Requirement already satisfied: torch>=1.2 in /Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages (from captum) (1.6.0)
Requirement already satisfied: numpy in /Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages (from captum) (1.18.5)
Requirement already satisfied: matplotlib in /Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages (from captum) (3.3.3)
Requirement already satisfied: future in /Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages (from torch>=1.2->captum) (0.18.2)
Requirement already satisfied: cycler>=0.10 in /Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages (from matplotlib->captum) (0.10.0)
Requirement already satisfied: kiwisolver>=1.0.1 in /Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages (from matplotlib->captum) (1.3.1)
Requirement already satisfied: python-dateutil>=2.1 in /Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages (from matplotlib->captum) (2.8.1)
Requirement already satisfied: pillow>=6.2.0 in /Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages (from matplotlib->captum) (8.0.1)
Requirement already satisfied: pyparsing!=2.0.4,!=2.1.2,!=2.1.6,>=2.0.3 in /Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages (from matplotlib->captum) (2.4.7)
Requirement already satisfied: six in /Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages (from cycler>=0.10->matplotlib->captum) (1.15.0)
[4]:
from captum.attr import (
    IntegratedGradients,
    LayerConductance,
    LayerIntegratedGradients,
    LayerDeepLiftShap,
    InternalInfluence,
    LayerGradientXActivation,
)
from captum.attr import visualization as viz
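Cell [4] imports several Captum attribution methods; the notebook ends up using `LayerIntegratedGradients` (cell [7]), which applies Integrated Gradients to the output of the BERT embedding layer. As a rough, library-free illustration of the idea, and not Captum's implementation, the sketch below approximates the path integral with a midpoint Riemann sum over `n_steps` points and computes a completeness check analogous to the `delta` returned when `return_convergence_delta=True`. The toy function `toy_f` and its hand-written gradient `toy_grad` are assumptions for illustration; Captum obtains gradients from PyTorch autograd and may use a different quadrature rule.

```python
from typing import Callable, List, Sequence

Vector = List[float]


def integrated_gradients(
    f: Callable[[Sequence[float]], float],
    grad_f: Callable[[Sequence[float]], Vector],
    x: Sequence[float],
    baseline: Sequence[float],
    n_steps: int = 10,
) -> Vector:
    """Approximate IG_i = (x_i - x'_i) * integral_0^1 dF/dx_i(x' + a(x - x')) da
    with a midpoint Riemann sum over n_steps points on the straight path."""
    diff = [xi - bi for xi, bi in zip(x, baseline)]
    total_grad = [0.0] * len(x)
    for k in range(n_steps):
        alpha = (k + 0.5) / n_steps  # midpoint of the k-th sub-interval
        point = [bi + alpha * di for bi, di in zip(baseline, diff)]
        for i, g in enumerate(grad_f(point)):
            total_grad[i] += g
    # Average gradient along the path, scaled by the input-baseline difference.
    return [di * g / n_steps for di, g in zip(diff, total_grad)]


def convergence_delta(
    f: Callable[[Sequence[float]], float],
    attributions: Sequence[float],
    x: Sequence[float],
    baseline: Sequence[float],
) -> float:
    """Completeness check: sum of attributions minus (F(x) - F(baseline))."""
    return sum(attributions) - (f(x) - f(baseline))


# Toy differentiable "model" with a hand-written gradient (illustration only).
def toy_f(v: Sequence[float]) -> float:
    return v[0] ** 2 * v[1] + 3.0 * v[1]


def toy_grad(v: Sequence[float]) -> Vector:
    return [2.0 * v[0] * v[1], v[0] ** 2 + 3.0]


if __name__ == "__main__":
    x, baseline = [2.0, 1.0], [0.0, 0.0]
    for steps in (2, 10, 100):
        attr = integrated_gradients(toy_f, toy_grad, x, baseline, n_steps=steps)
        delta = convergence_delta(toy_f, attr, x, baseline)
        print(steps, [round(a, 4) for a in attr], round(delta, 6))
```

For this toy function the attributions approach 8/3 and 13/3, which sum to F(x) - F(baseline) = 7, and the delta equals -1/n_steps**2, so it shrinks quickly as `n_steps` grows. The same trade-off applies to the `n_steps=10` used in cell [7]: more steps give a smaller delta at the cost of more forward and backward passes.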
[5]:
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

print(device)
cpu
[6]:
dataset = HuggingFaceDataset("ag_news", None, "train")
original_model = AutoModelForSequenceClassification.from_pretrained(
    "textattack/bert-base-uncased-ag-news"
)
original_tokenizer = AutoTokenizer.from_pretrained(
    "textattack/bert-base-uncased-ag-news"
)
model = HuggingFaceModelWrapper(original_model, original_tokenizer)
Using custom data configuration default
Reusing dataset ag_news (/Users/ccy/.cache/huggingface/datasets/ag_news/default/0.0.0/fb5c5e74a110037311ef5e904583ce9f8b9fbc1354290f97b4929f01b3f48b1a)
textattack: Loading datasets dataset ag_news, split train.
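Cell [7] tokenizes the first `sel = 2` headlines with `padding=True` and uses an all-zero tensor of token ids as the Integrated Gradients baseline. In the `bert-base-uncased` vocabulary, id 0 is `[PAD]`, so that baseline is effectively an all-padding input (including the positions of `[CLS]` and `[SEP]`; some Captum BERT tutorials keep those two tokens in the baseline instead). The sketch below shows, on made-up token ids, what right-padding and the attention mask look like. `pad_batch` and `zero_baseline` are hypothetical helpers for illustration, not TextAttack or transformers APIs.

```python
from typing import Dict, List


def pad_batch(sequences: List[List[int]], pad_id: int = 0) -> Dict[str, List[List[int]]]:
    """Right-pad token-id lists to a common length and build the matching
    attention mask (1 = real token, 0 = padding), similar in spirit to what a
    Hugging Face tokenizer returns with padding=True."""
    max_len = max(len(seq) for seq in sequences)
    input_ids, attention_mask = [], []
    for seq in sequences:
        n_pad = max_len - len(seq)
        input_ids.append(seq + [pad_id] * n_pad)
        attention_mask.append([1] * len(seq) + [0] * n_pad)
    return {"input_ids": input_ids, "attention_mask": attention_mask}


def zero_baseline(input_ids: List[List[int]]) -> List[List[int]]:
    """All-zero ids with the same shape, like the `bsl` tensor in cell [7]."""
    return [[0] * len(row) for row in input_ids]


if __name__ == "__main__":
    # Made-up ids; 101 and 102 are [CLS] and [SEP] in bert-base-uncased.
    batch = pad_batch([[101, 5, 102], [101, 5, 6, 7, 102]])
    print(batch["input_ids"])
    print(batch["attention_mask"])
    print(zero_baseline(batch["input_ids"]))
```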
[7]:
def get_text(tokenizer, input_ids, token_type_ids, attention_mask):
    # Decode each row of a batch of token ids back into text, dropping special
    # tokens such as [CLS], [SEP] and [PAD]. token_type_ids and attention_mask
    # are accepted for symmetry with calculate() but are not needed here.
    list_of_text = []
    for ii in input_ids.cpu().numpy():
        txt = tokenizer.decode(ii, skip_special_tokens=True)
        list_of_text.append(txt)
    return list_of_text


sel = 2
batch_encoded = model.tokenizer(
    [dataset[i][0]["text"] for i in range(sel)], padding=True, return_tensors="pt"
)
batch_encoded.to(device)
labels = [dataset[i][1] for i in range(sel)]

clone = deepcopy(model)
clone.model.to(device)


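def calculate(input_ids, token_type_ids, attention_mask):
    # Forward pass that returns the classification logits. The tensors are
    # passed by keyword because BertForSequenceClassification.forward expects
    # attention_mask before token_type_ids; positional arguments would swap them.
    return clone.model(
        input_ids=input_ids,
        token_type_ids=token_type_ids,
        attention_mask=attention_mask,
    )[0]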


# x = calculate(**batch_encoded)

lig = LayerIntegratedGradients(calculate, clone.model.bert.embeddings)
# lig = InternalInfluence(calculate, clone.model.bert.embeddings)
# lig = LayerGradientXActivation(calculate, clone.model.bert.embeddings)

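# Baseline of all-zero token ids (id 0 is [PAD] in the bert-base-uncased vocabulary),
# with the same shape, dtype and device as the real input ids.
bsl = torch.zeros_like(batch_encoded["input_ids"])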
labels = torch.tensor(labels).to(device)

attributions, delta = lig.attribute(
    inputs=batch_encoded["input_ids"],
    baselines=bsl,
    additional_forward_args=(
        batch_encoded["token_type_ids"],
        batch_encoded["attention_mask"],
    ),
    n_steps=10,
    target=labels,
    return_convergence_delta=True,
)
# Collapse the embedding dimension to get one score per token, then
# L2-normalize each example separately (the batch holds `sel` sentences).
atts = attributions.sum(dim=-1)
atts = atts / torch.norm(atts, dim=-1, keepdim=True)
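The post-processing at the end of cell [7] turns the layer attributions, shaped (batch, tokens, embedding dim), into one score per token by summing over the last axis and then dividing by an L2 norm. Because the batch here holds two sentences (`sel = 2`), it matters whether that norm is taken per example or over the whole batch. The pure-Python sketch below mimics both variants on small nested lists; the attribution values are made up for illustration.

```python
import math
from typing import List

Matrix = List[List[float]]


def token_scores(attributions: List[Matrix]) -> Matrix:
    """Sum a [batch][token][embedding] nested list over the embedding axis,
    mirroring attributions.sum(dim=-1) on a (batch, tokens, dim) tensor."""
    return [[sum(vec) for vec in sentence] for sentence in attributions]


def normalize_per_example(scores: Matrix) -> Matrix:
    """Scale each sentence by its own L2 norm, mirroring
    atts / torch.norm(atts, dim=-1, keepdim=True)."""
    result = []
    for row in scores:
        norm = math.sqrt(sum(s * s for s in row))
        # Guard against an all-zero row to avoid dividing by zero.
        result.append([s / norm if norm > 0 else 0.0 for s in row])
    return result


def normalize_jointly(scores: Matrix) -> Matrix:
    """Scale every score by one norm over the whole batch, mirroring
    atts / torch.norm(atts) on the full tensor."""
    norm = math.sqrt(sum(s * s for row in scores for s in row))
    return [[s / norm if norm > 0 else 0.0 for s in row] for row in scores]


if __name__ == "__main__":
    # Made-up attributions: 2 sentences x 3 tokens x embedding dim 2.
    fake_attributions = [
        [[1.0, 2.0], [0.0, 0.0], [-3.0, -1.0]],
        [[0.5, 0.25], [0.5, 0.5], [0.0, 0.0]],
    ]
    scores = token_scores(fake_attributions)
    print("token scores:", scores)
    print("per example:", normalize_per_example(scores))
    print("jointly:    ", normalize_jointly(scores))
```

In this made-up batch, per-example normalization gives both sentences unit-length score vectors, while the joint version shrinks the second sentence's scores because the first sentence dominates the shared norm.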
[9]:
from textattack.attack_recipes import PWWSRen2019

attack = PWWSRen2019.build(model)
textattack: Unknown if model of class <class 'transformers.models.bert.modeling_bert.BertForSequenceClassification'> compatible with goal function <class 'textattack.goal_functions.classification.untargeted_classification.UntargetedClassification'>.
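Cell [9] builds TextAttack's PWWS recipe, and the attack printed at the start of the cell [10] output shows its pieces: a greedy search ordered by `weighted-saliency`, WordNet synonym swaps, and the `RepeatModification` and `StopwordModification` constraints. The sketch below is a simplified, library-free rendering of the weighted-saliency idea from PWWS (Ren et al., 2019), not TextAttack's implementation: each position's saliency is the drop in P(true label) when the word is replaced by a placeholder, the best synonym's drop is weighted by the softmax of the saliencies, and swaps are applied greedily in that order until the prediction flips. The toy classifier, the `[UNK]` placeholder, the synonym table and all function names are assumptions for illustration. This sketch also picks each synonym once against the original sentence, whereas TextAttack's `GreedyWordSwapWIR` search queries the model on the current, partially perturbed text at each step.

```python
import math
from typing import Callable, Dict, List, Optional, Sequence, Set, Tuple

# Assumed black-box interface: P(original label | words).
ProbFn = Callable[[Sequence[str]], float]
UNK = "[UNK]"  # hypothetical placeholder used to "delete" a word


def softmax(xs: Sequence[float]) -> List[float]:
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def pwws_rank(
    words: Sequence[str],
    prob_true: ProbFn,
    synonyms: Dict[str, List[str]],
    stopwords: Set[str],
) -> Tuple[List[int], List[Tuple[Optional[str], float]]]:
    """Rank positions by softmax(saliency) * best probability drop."""
    base = prob_true(words)
    saliency: List[float] = []
    best: List[Tuple[Optional[str], float]] = []
    for i, word in enumerate(words):
        masked = list(words)
        masked[i] = UNK
        saliency.append(base - prob_true(masked))
        # Best single substitution at position i: largest drop in P(true label).
        best_word, best_drop = None, 0.0
        for cand in synonyms.get(word, []):
            swapped = list(words)
            swapped[i] = cand
            drop = base - prob_true(swapped)
            if best_word is None or drop > best_drop:
                best_word, best_drop = cand, drop
        best.append((best_word, best_drop))
    weights = softmax(saliency)
    scores = [w * drop for w, (_, drop) in zip(weights, best)]
    # Never rank stopwords, in the spirit of StopwordModification.
    positions = [i for i, w in enumerate(words) if w.lower() not in stopwords]
    return sorted(positions, key=lambda i: scores[i], reverse=True), best


def greedy_swap_attack(words, prob_true, synonyms, stopwords, threshold=0.5):
    """Apply the ranked swaps until P(true label) < threshold. Each position
    is changed at most once (cf. RepeatModification); None means failure."""
    order, best = pwws_rank(words, prob_true, synonyms, stopwords)
    current = list(words)
    for i in order:
        cand, drop = best[i]
        if cand is None or drop <= 0:
            continue
        current[i] = cand
        if prob_true(current) < threshold:
            return current
    return None


# Toy logistic "Business" classifier over hand-picked keyword weights.
WEIGHTS = {"stocks": 2.0, "market": 1.5, "earnings": 1.0}


def toy_prob_business(words: Sequence[str]) -> float:
    logit = -1.0 + sum(WEIGHTS.get(w.lower(), 0.0) for w in words)
    return 1.0 / (1.0 + math.exp(-logit))


if __name__ == "__main__":
    sentence = ["stocks", "hit", "the", "market"]
    syns = {"stocks": ["inventory", "stock"], "hit": ["strike"], "market": ["grocery"]}
    print(greedy_swap_attack(sentence, toy_prob_business, syns, {"the", "a", "of"}))
```

With these toy weights, swapping `stocks` alone leaves P(Business) above 0.5, so the search moves on to `market` and then succeeds; the stopword `the` is never ranked.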
[10]:
from textattack import Attacker

attacker = Attacker(attack, dataset)
attacker.attack_dataset()

Attack(
  (search_method): GreedyWordSwapWIR(
    (wir_method):  weighted-saliency
  )
  (goal_function):  UntargetedClassification
  (transformation):  WordSwapWordNet
  (constraints):
    (0): RepeatModification
    (1): StopwordModification
  (is_black_box):  True
)

[Succeeded / Failed / Skipped / Total] 1 / 0 / 0 / 1:  10%|██████████▌                                                                                               | 1/10 [11:02<1:39:26, 662.96s/it]
--------------------------------------------- Result 1 ---------------------------------------------
[[Business (96%)]] --> [[Sci/tech (68%)]]

Wall St. [[Bears]] Claw Back Into the [[Black]] (Reuters) Reuters - Short-sellers, Wall Street's [[dwindling]]\[[band]] of ultra-cynics, are [[seeing]] [[green]] again.

Wall St. [[suffer]] Claw Back Into the [[lightlessness]] (Reuters) Reuters - Short-sellers, Wall Street's [[dwindle]]\[[isthmus]] of ultra-cynics, are [[examine]] [[greenish]] again.


[Succeeded / Failed / Skipped / Total] 2 / 0 / 0 / 2:  20%|█████████████████████▏                                                                                    | 2/10 [21:10<1:24:40, 635.08s/it]
--------------------------------------------- Result 2 ---------------------------------------------
[[Business (100%)]] --> [[Sci/tech (50%)]]

Carlyle Looks Toward Commercial Aerospace (Reuters) Reuters - Private [[investment]] firm Carlyle Group,\which has a reputation for [[making]] well-timed and occasionally\controversial plays in the [[defense]] industry, has quietly [[placed]]\its bets on another part of the market.

Carlyle Looks Toward Commercial Aerospace (Reuters) Reuters - Private [[investiture]] firm Carlyle Group,\which has a reputation for [[ca-ca]] well-timed and occasionally\controversial plays in the [[denial]] industry, has quietly [[site]]\its bets on another part of the market.


[Succeeded / Failed / Skipped / Total] 3 / 0 / 0 / 3:  30%|███████████████████████████████▊                                                                          | 3/10 [28:14<1:05:52, 564.68s/it]
--------------------------------------------- Result 3 ---------------------------------------------
[[Business (100%)]] --> [[Sci/tech (100%)]]

Oil and Economy [[Cloud]] Stocks' Outlook (Reuters) Reuters - Soaring crude prices plus worries\about the economy and the outlook for [[earnings]] are expected to\hang over the [[stock]] market next week during the depth of the\summer doldrums.

Oil and Economy [[swarm]] Stocks' Outlook (Reuters) Reuters - Soaring crude prices plus worries\about the economy and the outlook for [[lucre]] are expected to\hang over the [[gillyflower]] market next week during the depth of the\summer doldrums.


[Succeeded / Failed / Skipped / Total] 4 / 0 / 0 / 4:  40%|███████████████████████████████████████████▏                                                                | 4/10 [36:03<54:05, 541.00s/it]
--------------------------------------------- Result 4 ---------------------------------------------
[[Business (78%)]] --> [[World (58%)]]

Iraq [[Halts]] Oil Exports from Main Southern Pipeline (Reuters) Reuters - Authorities have halted oil export\[[flows]] from the main pipeline in southern [[Iraq]] after\intelligence showed a rebel militia could [[strike]]\infrastructure, an oil official said on Saturday.

Iraq [[kibosh]] Oil Exports from Main Southern Pipeline (Reuters) Reuters - Authorities have halted oil export\[[hang]] from the main pipeline in southern [[Irak]] after\intelligence showed a rebel militia could [[fall]]\infrastructure, an oil official said on Saturday.


[Succeeded / Failed / Skipped / Total] 5 / 0 / 0 / 5:  50%|██████████████████████████████████████████████████████                                                      | 5/10 [42:59<42:59, 515.83s/it]
--------------------------------------------- Result 5 ---------------------------------------------
[[Business (99%)]] --> [[World (82%)]]

Oil prices soar to all-time record, posing new menace to [[US]] [[economy]] (AFP) AFP - Tearaway world oil prices, toppling records and straining wallets, present a new economic menace barely three months before the US presidential elections.

Oil prices soar to all-time record, posing new menace to [[uranium]] [[thriftiness]] (AFP) AFP - Tearaway world oil prices, toppling records and straining wallets, present a new economic menace barely three months before the US presidential elections.


[Succeeded / Failed / Skipped / Total] 6 / 0 / 0 / 6:  60%|████████████████████████████████████████████████████████████████▊                                           | 6/10 [56:56<37:57, 569.40s/it]
--------------------------------------------- Result 6 ---------------------------------------------
[[Business (100%)]] --> [[Sci/tech (55%)]]

[[Stocks]] [[End]] Up, But Near Year Lows (Reuters) Reuters - [[Stocks]] [[ended]] slightly higher on Friday\but [[stayed]] [[near]] [[lows]] for the year as [[oil]] prices surged past  #36;[[46]]\a [[barrel]], [[offsetting]] a [[positive]] [[outlook]] from computer [[maker]]\Dell Inc. (DELL.O)

[[stock]] [[terminate]] Up, But Near Year Lows (Reuters) Reuters - [[inventory]] [[terminate]] slightly higher on Friday\but [[continue]] [[virtually]] [[moo]] for the year as [[embrocate]] prices surged past  #36;[[xlvi]]\a [[drum]], [[cancel]] a [[incontrovertible]] [[mindset]] from computer [[Creator]]\Dell Inc. (DELL.O)


[Succeeded / Failed / Skipped / Total] 7 / 0 / 0 / 7:  70%|██████████████████████████████████████████████████████████████████████████▏                               | 7/10 [1:06:29<28:29, 569.87s/it]
--------------------------------------------- Result 7 ---------------------------------------------
[[Business (100%)]] --> [[World (76%)]]

Money Funds [[Fell]] in [[Latest]] [[Week]] (AP) AP - [[Assets]] of the nation's retail money [[market]] [[mutual]] funds [[fell]] by  #36;1.17 [[billion]] in the latest week to  #36;849.98 [[trillion]], the Investment Company [[Institute]] [[said]] Thursday.

Money Funds [[hide]] in [[up-to-the-minute]] [[hebdomad]] (AP) AP - [[asset]] of the nation's retail money [[commercialise]] [[common]] funds [[cruel]] by  #36;1.17 [[gazillion]] in the latest week to  #36;849.98 [[1000000000000]], the Investment Company [[constitute]] [[state]] Thursday.


[Succeeded / Failed / Skipped / Total] 8 / 0 / 0 / 8:  80%|████████████████████████████████████████████████████████████████████████████████████▊                     | 8/10 [1:17:31<19:22, 581.46s/it]
--------------------------------------------- Result 8 ---------------------------------------------
[[Business (100%)]] --> [[World (76%)]]

[[Fed]] [[minutes]] [[show]] [[dissent]] over inflation (USATODAY.com) USATODAY.com - Retail sales [[bounced]] [[back]] a bit in July, and [[new]] claims for jobless [[benefits]] fell [[last]] [[week]], the government said Thursday, indicating the economy is [[improving]] from a midsummer slump.

[[eat]] [[hour]] [[testify]] [[protest]] over inflation (USATODAY.com) USATODAY.com - Retail sales [[bound]] [[hind]] a bit in July, and [[young]] claims for jobless [[welfare]] fell [[death]] [[workweek]], the government said Thursday, indicating the economy is [[amend]] from a midsummer slump.


[Succeeded / Failed / Skipped / Total] 9 / 0 / 0 / 9:  90%|███████████████████████████████████████████████████████████████████████████████████████████████▍          | 9/10 [1:28:33<09:50, 590.35s/it]
--------------------------------------------- Result 9 ---------------------------------------------
[[Business (100%)]] --> [[Sci/tech (100%)]]

Safety [[Net]] (Forbes.com) Forbes.com - After earning a PH.D. in Sociology, Danny Bazil Riley started to work as the general manager at a commercial real estate firm at an annual base salary of  #36;70,000. Soon after, a financial planner stopped by his desk to drop off brochures about insurance benefits available through his employer. But, at 32, "buying insurance was the furthest thing from my mind," says Riley.

Safety [[cyberspace]] (Forbes.com) Forbes.com - After earning a PH.D. in Sociology, Danny Bazil Riley started to work as the general manager at a commercial real estate firm at an annual base salary of  #36;70,000. Soon after, a financial planner stopped by his desk to drop off brochures about insurance benefits available through his employer. But, at 32, "buying insurance was the furthest thing from my mind," says Riley.


[Succeeded / Failed / Skipped / Total] 10 / 0 / 0 / 10: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [1:34:51<00:00, 569.19s/it]
--------------------------------------------- Result 10 ---------------------------------------------
[[Business (100%)]] --> [[World (71%)]]

Wall St. Bears [[Claw]] Back Into the Black  [[NEW]] YORK (Reuters) - Short-sellers, Wall Street's dwindling  band of ultra-cynics, are seeing green again.

Wall St. Bears [[chela]] Back Into the Black  [[novel]] YORK (Reuters) - Short-sellers, Wall Street's dwindling  band of ultra-cynics, are seeing green again.



+-------------------------------+--------+
| Attack Results                |        |
+-------------------------------+--------+
| Number of successful attacks: | 10     |
| Number of failed attacks:     | 0      |
| Number of skipped attacks:    | 0      |
| Original accuracy:            | 100.0% |
| Accuracy under attack:        | 0.0%   |
| Attack success rate:          | 100.0% |
| Average perturbed word %:     | 16.33% |
| Average num. words per input: | 38.5   |
| Avg num queries:              | 331.8  |
+-------------------------------+--------+


[10]:
[<textattack.attack_results.successful_attack_result.SuccessfulAttackResult at 0x7fd4d869f190>,
 <textattack.attack_results.successful_attack_result.SuccessfulAttackResult at 0x7fd4f18a2d90>,
 <textattack.attack_results.successful_attack_result.SuccessfulAttackResult at 0x7fd4d943bf10>,
 <textattack.attack_results.successful_attack_result.SuccessfulAttackResult at 0x7fd4d803a890>,
 <textattack.attack_results.successful_attack_result.SuccessfulAttackResult at 0x7fd4e2c9d6d0>,
 <textattack.attack_results.successful_attack_result.SuccessfulAttackResult at 0x7fd4d63f17d0>,
 <textattack.attack_results.successful_attack_result.SuccessfulAttackResult at 0x7fd4f2a46950>,
 <textattack.attack_results.successful_attack_result.SuccessfulAttackResult at 0x7fd4d6f3f710>,
 <textattack.attack_results.successful_attack_result.SuccessfulAttackResult at 0x7fd4de90bd10>,
 <textattack.attack_results.successful_attack_result.SuccessfulAttackResult at 0x7fd4d470fc90>]
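The percentages in the "Attack Results" table printed by `attack_dataset()` can be reproduced from the succeeded, failed and skipped counts. The sketch below assumes the usual definitions: original accuracy counts every non-skipped example as initially correct (skipped examples were already misclassified), accuracy under attack counts only failed attacks as still correct, and the success rate ignores skipped examples. These agree with the 10 / 0 / 0 run above, but check them against the TextAttack version you use. `AttackCounts` and `summarize` are hypothetical names.

```python
from dataclasses import dataclass
from typing import Dict


@dataclass
class AttackCounts:
    succeeded: int
    failed: int
    skipped: int

    @property
    def total(self) -> int:
        return self.succeeded + self.failed + self.skipped


def summarize(counts: AttackCounts) -> Dict[str, float]:
    """Percentages shown in the "Attack Results" table (assumed definitions)."""
    attempted = counts.succeeded + counts.failed  # skipped inputs are not attacked
    return {
        "original_accuracy": 100.0 * (counts.total - counts.skipped) / counts.total,
        "accuracy_under_attack": 100.0 * counts.failed / counts.total,
        "attack_success_rate": 100.0 * counts.succeeded / attempted if attempted else 0.0,
    }


if __name__ == "__main__":
    # Counts from the run above: 10 succeeded, 0 failed, 0 skipped.
    print(summarize(AttackCounts(succeeded=10, failed=0, skipped=0)))
    # A hypothetical run, to show how failed and skipped attacks change the numbers.
    print(summarize(AttackCounts(succeeded=6, failed=2, skipped=2)))
```

For the hypothetical 6 / 2 / 2 run this gives 80% original accuracy, 20% accuracy under attack and a 75% attack success rate.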