# Source code for textattack.datasets.huggingface_dataset

"""

HuggingFaceDataset Class
=========================

TextAttack allows users to provide their own dataset or load from HuggingFace.


"""

import collections

import datasets

import textattack

from .dataset import Dataset


def _cb(s):
    """Render ``s`` (converted to ``str``) as blue text using textattack's ANSI color helper."""
    colorize = textattack.shared.utils.color_text
    text = str(s)
    return colorize(text, color="blue", method="ansi")


# Known column layouts for datasets on the dataset hub, tried in order.
# Each entry is (input_columns, output_column). A dataset matches an entry when
# its columns include every input column plus the output column. Order matters
# because some layouts overlap. For example, a dataset with "question",
# "sentence" and "label" must resolve to ("question", "sentence") rather than
# to ("sentence",).
_KNOWN_DATASET_SCHEMAS = (
    (("premise", "hypothesis"), "label"),
    (("question", "sentence"), "label"),
    (("sentence1", "sentence2"), "label"),
    (("question1", "question2"), "label"),
    # Common schema for SQUAD dataset
    (("title", "context", "question"), "answers"),
    (("text",), "label"),
    (("sentence",), "label"),
    (("document",), "summary"),
    (("content",), "summary"),
    (("review",), "label"),
)


def get_datasets_dataset_columns(dataset):
    """Common schemas for datasets found in dataset hub.

    Args:
        dataset: Object exposing a ``column_names`` attribute (e.g. a
            :obj:`datasets.Dataset`).

    Returns:
        tuple: ``(input_columns, output_column)``. ``input_columns`` is a tuple
        of column names and ``output_column`` is a single column name. The first
        matching entry of ``_KNOWN_DATASET_SCHEMAS`` wins.

    Raises:
        ValueError: If the dataset's columns match none of the known schemas.
    """
    schema = set(dataset.column_names)
    for input_columns, output_column in _KNOWN_DATASET_SCHEMAS:
        if {*input_columns, output_column} <= schema:
            return input_columns, output_column
    raise ValueError(
        f"Unsupported dataset schema {schema}. Try passing your own `dataset_columns` argument."
    )
class HuggingFaceDataset(Dataset):
    """Loads a dataset from 🤗 Datasets and prepares it as a TextAttack dataset.

    Args:
        name_or_dataset (:obj:`Union[str, datasets.Dataset]`):
            The dataset name as :obj:`str` or actual :obj:`datasets.Dataset` object.
            If it's your custom :obj:`datasets.Dataset` object, please pass the input and output columns via :obj:`dataset_columns` argument.
        subset (:obj:`str`, `optional`, defaults to :obj:`None`):
            The subset of the main dataset. Dataset will be loaded as :obj:`datasets.load_dataset(name, subset)`.
        split (:obj:`str`, `optional`, defaults to :obj:`"train"`):
            The split of the dataset.
        dataset_columns (:obj:`tuple(list[str], str)`, `optional`, defaults to :obj:`None`):
            Pair of :obj:`list[str]` representing list of input column names (e.g. :obj:`["premise", "hypothesis"]`)
            and :obj:`str` representing the output column name (e.g. :obj:`label`). If not set, we will try to automatically determine column names from known designs.
        label_map (:obj:`dict[int, int]`, `optional`, defaults to :obj:`None`):
            Mapping if output labels of the dataset should be re-mapped. Useful if model was trained with a different label arrangement.
            For example, if dataset's arrangement is 0 for `Negative` and 1 for `Positive`, but model's label
            arrangement is 1 for `Negative` and 0 for `Positive`, passing :obj:`{0: 1, 1: 0}` will remap the dataset's label to match with model's arrangements.
            Could also be used to remap literal labels to numerical labels (e.g. :obj:`{"positive": 1, "negative": 0}`).
        label_names (:obj:`list[str]`, `optional`, defaults to :obj:`None`):
            List of label names in corresponding order (e.g. :obj:`["World", "Sports", "Business", "Sci/Tech"]` for AG-News dataset).
            If not set, labels will be printed as is (e.g. "0", "1", ...). This should be set to :obj:`None` for non-classification datasets.
        output_scale_factor (:obj:`float`, `optional`, defaults to :obj:`None`):
            Factor to divide ground-truth outputs by. Generally, TextAttack goal functions require model outputs between 0 and 1.
            Some datasets are regression tasks, in which case this is necessary.
        shuffle (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether to shuffle the underlying dataset.

            .. note::
                Generally not recommended to shuffle the underlying dataset. Shuffling can be performed using DataLoader or by shuffling the order of indices we attack.
    """

    def __init__(
        self,
        name_or_dataset,
        subset=None,
        split="train",
        dataset_columns=None,
        label_map=None,
        label_names=None,
        output_scale_factor=None,
        shuffle=False,
    ):
        if isinstance(name_or_dataset, datasets.Dataset):
            # A ready-made dataset is used as-is; `subset` and `split` are ignored
            # and `self._name` / `self._subset` are not set in this case.
            self._dataset = name_or_dataset
        else:
            self._name = name_or_dataset
            self._subset = subset
            self._dataset = datasets.load_dataset(self._name, subset)[split]
            subset_print_str = f", subset {_cb(subset)}" if subset else ""
            textattack.shared.logger.info(
                f"Loading {_cb('datasets')} dataset {_cb(self._name)}{subset_print_str}, split {_cb(split)}."
            )
        # Input/output column order, like (('premise', 'hypothesis'), 'label')
        (
            self.input_columns,
            self.output_column,
        ) = dataset_columns or get_datasets_dataset_columns(self._dataset)

        if not isinstance(self.input_columns, (list, tuple)):
            raise ValueError(
                "First element of `dataset_columns` must be a list or a tuple."
            )

        self.label_map = label_map
        self.output_scale_factor = output_scale_factor
        if label_names:
            self.label_names = label_names
        else:
            try:
                self.label_names = self._dataset.features[self.output_column].names
            except (KeyError, AttributeError):
                # This happens when the dataset doesn't have 'features' or a 'label' column.
                self.label_names = None

        # If labels are remapped, the label names have to be remapped as well.
        # NOTE(review): entry k of the result is label_names[label_map[key_k]].
        # For maps that are not their own inverse (e.g. {0: 1, 1: 2, 2: 0}) this
        # yields the names of the inverse permutation; confirm that is intended.
        if self.label_names and label_map:
            self.label_names = [
                self.label_names[self.label_map[i]] for i in self.label_map
            ]

        self.shuffled = shuffle
        if shuffle:
            # `datasets.Dataset.shuffle` returns a new dataset instead of shuffling
            # in place, so the result must be reassigned or shuffling is a no-op.
            self._dataset = self._dataset.shuffle()

    def _format_as_dict(self, example):
        """Convert a raw dataset row into an ``(OrderedDict of inputs, output)`` pair.

        The input columns keep the order of ``self.input_columns``. The output is
        passed through ``label_map`` and then divided by ``output_scale_factor``
        when either is set.
        """
        input_dict = collections.OrderedDict(
            [(c, example[c]) for c in self.input_columns]
        )
        output = example[self.output_column]
        if self.label_map:
            output = self.label_map[output]
        if self.output_scale_factor:
            output = output / self.output_scale_factor
        return (input_dict, output)

    def filter_by_labels_(self, labels_to_keep):
        """Filter items by their labels for classification datasets. Performs
        in-place filtering.

        Args:
            labels_to_keep (:obj:`Union[Set, Tuple, List, Iterable]`):
                Set, tuple, list, or iterable of integers representing labels.
        """
        if not isinstance(labels_to_keep, set):
            labels_to_keep = set(labels_to_keep)
        self._dataset = self._dataset.filter(
            lambda x: x[self.output_column] in labels_to_keep
        )

    def __getitem__(self, i):
        """Return the i-th sample, or a list of samples when ``i`` is a slice."""
        if isinstance(i, int):
            return self._format_as_dict(self._dataset[i])
        # Otherwise `i` is expected to be a slice. `slice.indices` resolves open
        # bounds (e.g. `ds[:5]`), negative indices and the step against the
        # dataset length, mirroring Python list slicing.
        start, stop, step = i.indices(len(self._dataset))
        return [
            self._format_as_dict(self._dataset[j]) for j in range(start, stop, step)
        ]

    def shuffle(self):
        """Shuffle the underlying dataset and mark this dataset as shuffled."""
        # `datasets.Dataset.shuffle` is not in-place; keep the returned dataset.
        self._dataset = self._dataset.shuffle()
        self.shuffled = True