# Source code for textattack.models.tokenizers.t5_tokenizer

"""
T5 Tokenizer
---------------------------------------------------------------------

"""

import transformers


class T5Tokenizer:
    """Uses the T5 tokenizer to convert an input for processing.

    For more information, please see the T5 paper, "Exploring the Limits
    of Transfer Learning with a Unified Text-to-Text Transformer".
    Appendix D contains information about the various tasks supported
    by T5.

    Supports the following modes:

    * summarization: summarize English text
    * english_to_german: translate English to German
    * english_to_french: translate English to French
    * english_to_romanian: translate English to Romanian

    Args:
        mode (str): one of the task modes listed above.
        max_length (int): maximum sequence length passed to the
            underlying tokenizer on every call.

    Raises:
        ValueError: if ``mode`` is not one of the supported modes.
    """

    # Task prefix that T5 expects to be prepended to the input text,
    # keyed by the supported mode names.
    _MODE_PREFIXES = {
        "english_to_german": "translate English to German: ",
        "english_to_french": "translate English to French: ",
        "english_to_romanian": "translate English to Romanian: ",
        "summarization": "summarize: ",
    }

    def __init__(self, mode="english_to_german", max_length=64):
        try:
            self.tokenization_prefix = self._MODE_PREFIXES[mode]
        except KeyError:
            # Preserve the original error message exactly.
            raise ValueError(f"Invalid t5 tokenizer mode {mode}.") from None
        self.tokenizer = transformers.AutoTokenizer.from_pretrained(
            "t5-base", use_fast=True
        )
        self.max_length = max_length

    def __call__(self, text, *args, **kwargs):
        """Prepend the task prefix to ``text`` and tokenize it.

        Args:
            text (:obj:`str`, :obj:`List[str]`):
                The sequence or batch of sequences to be encoded. Each
                sequence can be a string or a list of strings.

        Returns:
            Whatever the underlying HuggingFace tokenizer returns
            (a ``BatchEncoding``), with ``max_length`` forced to
            ``self.max_length``.
        """
        assert isinstance(text, str) or (
            isinstance(text, (list, tuple))
            and (len(text) == 0 or isinstance(text[0], str))
        ), "`text` must be a string or a list of strings."
        if isinstance(text, str):
            text = self.tokenization_prefix + text
        else:
            # Build a new list rather than assigning into `text` in
            # place: the original implementation mutated the caller's
            # list, and crashed with TypeError on tuples even though
            # the assert above explicitly admits them.
            text = [self.tokenization_prefix + seq for seq in text]
        return self.tokenizer(text, *args, max_length=self.max_length, **kwargs)

    def decode(self, ids):
        """Converts IDs (typically generated by the model) back to a string."""
        return self.tokenizer.decode(ids)