Source code for textattack.models.helpers.word_cnn_for_classification

"""
Word CNN for Classification
---------------------------------------------------------------------

"""

import json
import os

import torch
from torch import nn as nn
from torch.nn import functional as F

import textattack
from textattack.model_args import TEXTATTACK_MODELS
from textattack.models.helpers import GloveEmbeddingLayer
from textattack.models.helpers.utils import load_cached_state_dict
from textattack.shared import utils


class WordCNNForClassification(nn.Module):
    """A convolutional neural network for text classification.

    We use different versions of this network to pretrain models for
    text classification.
    """

    def __init__(
        self,
        hidden_size=150,
        dropout=0.3,
        num_labels=2,
        max_seq_length=128,
        model_path=None,
        emb_layer_trainable=True,
    ):
        super().__init__()
        self._config = {
            "architectures": "WordCNNForClassification",
            "hidden_size": hidden_size,
            "dropout": dropout,
            "num_labels": num_labels,
            "max_seq_length": max_seq_length,
            "model_path": model_path,
            "emb_layer_trainable": emb_layer_trainable,
        }
        self.drop = nn.Dropout(dropout)
        self.emb_layer = GloveEmbeddingLayer(emb_layer_trainable=emb_layer_trainable)
        self.word2id = self.emb_layer.word2id
        # Three parallel convolution widths, each producing `hidden_size`
        # feature maps, so the encoder output has 3 * hidden_size features.
        self.encoder = CNNTextLayer(
            self.emb_layer.n_d, widths=[3, 4, 5], filters=hidden_size
        )
        d_out = 3 * hidden_size
        self.out = nn.Linear(d_out, num_labels)
        self.tokenizer = textattack.models.tokenizers.GloveTokenizer(
            word_id_map=self.word2id,
            unk_token_id=self.emb_layer.oovid,
            pad_token_id=self.emb_layer.padid,
            max_length=max_seq_length,
        )

        if model_path is not None:
            self.load_from_disk(model_path)
        self.eval()

    def load_from_disk(self, model_path):
        # TODO: Consider removing this in the future as well as loading via `model_path` in `__init__`.
        import warnings

        warnings.warn(
            "`load_from_disk` method is deprecated. Please save and load using `save_pretrained` and `from_pretrained` methods.",
            DeprecationWarning,
            stacklevel=2,
        )
        self.load_state_dict(load_cached_state_dict(model_path))
        self.eval()

    def save_pretrained(self, output_path):
        if not os.path.exists(output_path):
            os.makedirs(output_path)
        state_dict = {k: v.cpu() for k, v in self.state_dict().items()}
        torch.save(state_dict, os.path.join(output_path, "pytorch_model.bin"))
        with open(os.path.join(output_path, "config.json"), "w") as f:
            json.dump(self._config, f)

    @classmethod
    def from_pretrained(cls, name_or_path):
        """Load trained Word CNN model by name or from path.

        Args:
            name_or_path (:obj:`str`): Name of the model (e.g. "cnn-imdb") or model saved via :meth:`save_pretrained`.

        Returns:
            :class:`~textattack.models.helpers.WordCNNForClassification` model
        """
        if name_or_path in TEXTATTACK_MODELS:
            path = utils.download_from_s3(TEXTATTACK_MODELS[name_or_path])
        else:
            path = name_or_path

        config_path = os.path.join(path, "config.json")

        if os.path.exists(config_path):
            with open(config_path, "r") as f:
                config = json.load(f)
        else:
            # Default config
            config = {
                "architectures": "WordCNNForClassification",
                "hidden_size": 150,
                "dropout": 0.3,
                "num_labels": 2,
                "max_seq_length": 128,
                "model_path": None,
                "emb_layer_trainable": True,
            }
        del config["architectures"]
        model = cls(**config)
        state_dict = load_cached_state_dict(path)
        model.load_state_dict(state_dict)
        return model
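
    # Illustrative sketch (not part of the library source): the save/load
    # round trip defined above, assuming the pretrained "cnn-imdb" weights
    # named in the docstring are downloadable; the output directory name
    # below is made up for illustration.
    #
    #   model = WordCNNForClassification.from_pretrained("cnn-imdb")
    #   model.save_pretrained("example_cnn_imdb")  # writes pytorch_model.bin + config.json
    #   reloaded = WordCNNForClassification.from_pretrained("example_cnn_imdb")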

    def forward(self, _input):
        emb = self.emb_layer(_input)
        emb = self.drop(emb)

        output = self.encoder(emb)

        output = self.drop(output)
        pred = self.out(output)
        return pred

    def get_input_embeddings(self):
        return self.emb_layer.embedding
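
# Illustrative usage sketch (not part of the library source). The token ids
# below are random stand-ins; in practice they would come from the model's
# GloveTokenizer. `torch` is imported at the top of this module, and the
# helper name `_example_forward_pass` is introduced here only for illustration.
def _example_forward_pass():
    model = WordCNNForClassification(num_labels=2, max_seq_length=128)
    input_ids = torch.randint(0, 100, (4, 128))  # (batch, seq_len) of word ids
    with torch.no_grad():
        logits = model(input_ids)  # shape (4, num_labels)
    return logits.argmax(dim=-1)  # predicted label per example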

class CNNTextLayer(nn.Module):
    def __init__(self, n_in, widths=[3, 4, 5], filters=100):
        super().__init__()
        Ci = 1
        Co = filters
        h = n_in
        self.convs1 = nn.ModuleList([nn.Conv2d(Ci, Co, (w, h)) for w in widths])

    def forward(self, x):
        x = x.unsqueeze(1)  # (batch, Ci, len, d)
        x = [
            F.relu(conv(x)).squeeze(3) for conv in self.convs1
        ]  # [(batch, Co, len), ...]
        x = [F.max_pool1d(i, i.size(2)).squeeze(2) for i in x]  # [(N, Co), ...]
        x = torch.cat(x, 1)
        return x
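
# Illustrative shape trace for CNNTextLayer (not part of the library source).
# The batch size, sequence length, and embedding dimension are assumptions
# chosen only to show how shapes flow through the layer; the helper name
# `_example_cnn_shapes` is introduced here for illustration.
def _example_cnn_shapes():
    layer = CNNTextLayer(n_in=200, widths=[3, 4, 5], filters=150)
    x = torch.randn(8, 30, 200)  # (batch, seq_len, embedding_dim)
    out = layer(x)  # each width w: conv -> (8, 150, 30 - w + 1), then max-pool over length
    assert out.shape == (8, 450)  # 3 widths x 150 filters, concatenated
    return out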