# Source code for textattack.models.wrappers.sklearn_model_wrapper

"""
scikit-learn Model Wrapper
--------------------------
"""

import pandas as pd

from .model_wrapper import ModelWrapper


class SklearnModelWrapper(ModelWrapper):
    """Loads a scikit-learn model and tokenizer (tokenizer implements
    `transform` and model implements `predict_proba`).

    May need to be extended and modified for different types of
    tokenizers.
    """

    def __init__(self, model, tokenizer):
        self.model = model
        self.tokenizer = tokenizer

    def _feature_names(self):
        """Return the tokenizer's output feature names as a list.

        scikit-learn 1.0 deprecated ``get_feature_names`` in favour of
        ``get_feature_names_out`` and 1.2 removed it, so prefer the new
        method and fall back to the old one for older installations.
        """
        if hasattr(self.tokenizer, "get_feature_names_out"):
            return list(self.tokenizer.get_feature_names_out())
        return self.tokenizer.get_feature_names()

    def __call__(self, text_input_list, batch_size=None):
        """Return the model's class probabilities for each input text.

        Args:
            text_input_list: texts to encode with ``self.tokenizer.transform``.
            batch_size: accepted for interface compatibility; unused here.

        Returns:
            Whatever ``self.model.predict_proba`` returns for the encoded
            inputs.
        """
        encoded = self.tokenizer.transform(text_input_list)
        # Sparse outputs (e.g. scipy matrices) expose ``toarray``; densify
        # those so they can back a DataFrame. Dense outputs are used as-is.
        if hasattr(encoded, "toarray"):
            encoded = encoded.toarray()
        # Column names are attached so the model sees the feature names
        # produced by the tokenizer.
        tokenized_text_df = pd.DataFrame(encoded, columns=self._feature_names())
        return self.model.predict_proba(tokenized_text_df)

    def get_grad(self, text_input):
        """Gradients are not supported for scikit-learn models."""
        raise NotImplementedError()