Source code for textattack.constraints.grammaticality.language_models.learning_to_write.adaptive_softmax

"""
AdaptiveSoftmax
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
"""


import torch
from torch import nn
from torch.autograd import Variable
from torch.nn.functional import log_softmax

import textattack


class AdaptiveSoftmax(nn.Module):
    """Two-level softmax over a vocabulary partitioned by ``cutoffs``.

    Token ids below ``cutoffs[0]`` are scored directly by ``head``. Every
    further range ``[cutoffs[i], cutoffs[i + 1])`` is a cluster: ``head``
    emits one extra logit for it (column ``cutoffs[0] + i``) and ``tail[i]``
    scores the tokens inside the cluster through a bias-free bottleneck of
    width ``input_size // scale_down``.

    Args:
        input_size: Number of features in each input row.
        cutoffs: Sequence of token-id boundaries, expected to be strictly
            increasing (each difference is used as a layer width).
            ``cutoffs[-1]`` is the number of columns ``log_prob`` returns.
        scale_down: Divisor applied to ``input_size`` to get the width of
            the hidden projection inside every tail.
    """

    def __init__(self, input_size, cutoffs, scale_down=4):
        super().__init__()

        self.input_size = input_size
        self.cutoffs = cutoffs
        # One head column per shortlist token plus one per cluster.
        self.output_size = cutoffs[0] + len(cutoffs) - 1

        self.head = nn.Linear(input_size, self.output_size)

        hidden_size = input_size // scale_down
        self.tail = nn.ModuleList(
            [
                nn.Sequential(
                    nn.Linear(input_size, hidden_size, bias=False),
                    nn.Linear(hidden_size, upper - lower, bias=False),
                )
                for lower, upper in zip(cutoffs, cutoffs[1:])
            ]
        )

        # Per-cluster row indices filled in by ``set_target``. Starting empty
        # lets ``forward`` run (head only) before any target has been set.
        self.id = []

    def reset(self, init=0.1):
        """Re-draw the head and tail weights uniformly from ``[-init, init)``.

        Biases are left untouched.
        """
        nn.init.uniform_(self.head.weight, -init, init)
        for tail in self.tail:
            for layer in tail:
                nn.init.uniform_(layer.weight, -init, init)

    def set_target(self, target):
        """Record, for each cluster, which positions of ``target`` fall in it.

        After the call ``self.id[i]`` holds the indices of the entries of
        ``target`` lying in ``[cutoffs[i], cutoffs[i + 1])``, or ``None`` when
        no entry does. ``forward`` uses these to select input rows, so
        ``target`` is assumed to be a 1-D tensor aligned with those rows.
        """
        self.id = []
        for lower, upper in zip(self.cutoffs, self.cutoffs[1:]):
            mask = target.ge(lower) & target.lt(upper)
            if mask.any():
                self.id.append(mask.nonzero().squeeze(1))
            else:
                self.id.append(None)

    def forward(self, inp):
        """Compute head logits plus tail logits for the rows that need them.

        Args:
            inp: 2-D tensor whose rows are scored.

        Returns:
            A list whose first element is ``head(inp)``. It is followed by one
            entry per index set stored by ``set_target``: the logits of
            ``tail[i]`` on the selected rows, or ``None`` for a cluster with
            no rows.
        """
        assert inp.dim() == 2

        output = [self.head(inp)]
        for rows, tail in zip(self.id, self.tail):
            if rows is None:
                output.append(None)
            else:
                output.append(tail(inp.index_select(0, rows)))
        return output

    def log_prob(self, inp):
        """Return log-probabilities over all ``cutoffs[-1]`` token ids.

        Args:
            inp: 2-D tensor whose rows are scored.

        Returns:
            A tensor of shape ``(inp.size(0), cutoffs[-1])`` that is detached
            from the autograd graph and lives on the same device as the head
            output.
        """
        assert inp.dim() == 2

        shortlist = self.cutoffs[0]
        with torch.no_grad():
            lsm_head = log_softmax(self.head(inp), dim=-1)
            # Allocate next to the activations rather than on a global device,
            # so the in-place writes below never mix devices.
            prob = torch.zeros(
                inp.size(0), self.cutoffs[-1], device=lsm_head.device
            )
            prob[:, :shortlist] = lsm_head[:, :shortlist]

            for i, tail in enumerate(self.tail):
                lower, upper = self.cutoffs[i], self.cutoffs[i + 1]
                # log p(token) = log p(cluster) + log p(token | cluster)
                cluster_lp = lsm_head[:, shortlist + i : shortlist + i + 1]
                prob[:, lower:upper] = cluster_lp + log_softmax(tail(inp), dim=-1)

        return prob
class AdaptiveLoss(nn.Module):
    """Cross-entropy loss over a head output and per-cluster tail outputs.

    Holds one summed ``CrossEntropyLoss`` per entry of ``cutoffs``: index 0 is
    applied to the head logits, index ``i + 1`` to the logits of the cluster
    covering token ids ``[cutoffs[i], cutoffs[i + 1])``.

    Args:
        cutoffs: Sequence of token-id boundaries that define the clusters.
    """

    def __init__(self, cutoffs):
        super().__init__()

        self.cutoffs = cutoffs
        # ``reduction="sum"`` replaces the deprecated ``size_average=False``.
        self.criterions = nn.ModuleList(
            [nn.CrossEntropyLoss(reduction="sum") for _ in cutoffs]
        )

    def reset(self):
        """Call ``zero_grad()`` on every stored criterion."""
        for criterion in self.criterions:
            criterion.zero_grad()

    def remap_target(self, target):
        """Split token-id targets into a head target and per-cluster targets.

        Args:
            target: Tensor of token ids. It is not modified.

        Returns:
            A list whose first element is a copy of ``target`` in which every
            id inside cluster ``i`` is replaced by ``cutoffs[0] + i``. It is
            followed by one entry per cluster: the ids of that cluster shifted
            by ``-cutoffs[i]``, or ``None`` if no target falls in the cluster.
        """
        new_target = [target.clone()]

        for i, (lower, upper) in enumerate(zip(self.cutoffs, self.cutoffs[1:])):
            mask = target.ge(lower) & target.lt(upper)
            if mask.any():
                new_target[0][mask] = self.cutoffs[0] + i
                new_target.append(target[mask] - lower)
            else:
                new_target.append(None)

        return new_target

    def forward(self, inp, target):
        """Return the summed cross-entropy divided by the number of head rows.

        Args:
            inp: List of logit tensors: head logits first, then one entry per
                cluster, where ``None`` marks a cluster with no rows.
            target: Tensor of token ids, remapped with ``remap_target``.

        Returns:
            The sum of all per-output losses divided by ``inp[0].size(0)``.
        """
        n = inp[0].size(0)
        targets = self.remap_target(target.detach())

        loss = 0
        for i, logits in enumerate(inp):
            if logits is None:
                continue
            current = targets[i]
            # Valid class indices are 0 .. C-1, so the upper bound is strict.
            assert current.min() >= 0 and current.max() < logits.size(1)
            loss += self.criterions[i](logits, current)

        loss /= n
        return loss