Source code for textattack.shared.utils.tensor

import numpy as np
import torch

[docs]def batch_model_predict(model_predict, inputs, batch_size=32): """Runs prediction on iterable ``inputs`` using batch size ``batch_size``. Aggregates all predictions into an ``np.ndarray``. """ outputs = [] i = 0 while i < len(inputs): batch = inputs[i : i + batch_size] batch_preds = model_predict(batch) # Some seq-to-seq models will return a single string as a prediction # for a single-string list. Wrap these in a list. if isinstance(batch_preds, str): batch_preds = [batch_preds] # Get PyTorch tensors off of other devices. if isinstance(batch_preds, torch.Tensor): batch_preds = batch_preds.cpu() # Cast all predictions iterables to ``np.ndarray`` types. if not isinstance(batch_preds, np.ndarray): batch_preds = np.array(batch_preds) outputs.append(batch_preds) i += batch_size return np.concatenate(outputs, axis=0)