"""
AdaptiveSoftmax
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
"""
import torch
from torch import nn
from torch.nn.functional import log_softmax

import textattack


class AdaptiveSoftmax(nn.Module):
    """Adaptive softmax output layer.

    ``cutoffs`` partitions the vocabulary: ids in ``[0, cutoffs[0])`` form the
    frequent "head"; ids in ``[cutoffs[i], cutoffs[i + 1])`` form tail cluster
    ``i``, scored by a smaller, down-projected classifier.
    """

    def __init__(self, input_size, cutoffs, scale_down=4):
        super().__init__()
        self.input_size = input_size
        self.cutoffs = cutoffs
        # The head scores the frequent tokens plus one pseudo-token per tail cluster.
        self.output_size = cutoffs[0] + len(cutoffs) - 1
        self.head = nn.Linear(input_size, self.output_size)
        self.tail = nn.ModuleList()
        for i in range(len(cutoffs) - 1):
            # Each tail cluster uses a bias-free bottleneck projection.
            seq = nn.Sequential(
                nn.Linear(input_size, input_size // scale_down, False),
                nn.Linear(input_size // scale_down, cutoffs[i + 1] - cutoffs[i], False),
            )
            self.tail.append(seq)

    def reset(self, init=0.1):
        self.head.weight.data.uniform_(-init, init)
        for tail in self.tail:
            for layer in tail:
                layer.weight.data.uniform_(-init, init)

    def set_target(self, target):
        # For each tail cluster, record which rows of the batch have a target
        # token in that cluster (or None if the cluster is unused).
        self.id = []
        for i in range(len(self.cutoffs) - 1):
            mask = target.ge(self.cutoffs[i]).mul(target.lt(self.cutoffs[i + 1]))
            if mask.sum() > 0:
                self.id.append(mask.float().nonzero().squeeze(1))
            else:
                self.id.append(None)

    def forward(self, inp):
        assert len(inp.size()) == 2
        # Head logits for every row; tail logits only for the rows whose
        # target falls in that cluster (as selected by ``set_target``).
        output = [self.head(inp)]
        for i in range(len(self.id)):
            if self.id[i] is not None:
                output.append(self.tail[i](inp.index_select(0, self.id[i])))
            else:
                output.append(None)
        return output

    def log_prob(self, inp):
        assert len(inp.size()) == 2
        head_out = self.head(inp)
        n = inp.size(0)
        # Full-vocabulary log-probabilities, assembled cluster by cluster.
        prob = torch.zeros(n, self.cutoffs[-1]).to(textattack.shared.utils.device)
        lsm_head = log_softmax(head_out, dim=head_out.dim() - 1)
        prob.narrow(1, 0, self.output_size).add_(
            lsm_head.narrow(1, 0, self.output_size).data
        )
        for i in range(len(self.tail)):
            pos = self.cutoffs[i]
            i_size = self.cutoffs[i + 1] - pos
            # Log-probability of the cluster's pseudo-token in the head ...
            buff = lsm_head.narrow(1, self.cutoffs[0] + i, 1)
            buff = buff.expand(n, i_size)
            temp = self.tail[i](inp)
            lsm_tail = log_softmax(temp, dim=temp.dim() - 1)
            # ... plus the within-cluster log-probability from the tail classifier.
            prob.narrow(1, pos, i_size).copy_(buff.data).add_(lsm_tail.data)
        return prob


class AdaptiveLoss(nn.Module):
    """Cross-entropy loss over the head and tail outputs of ``AdaptiveSoftmax``."""

    def __init__(self, cutoffs):
        super().__init__()
        self.cutoffs = cutoffs
        self.criterions = nn.ModuleList()
        for i in self.cutoffs:
            self.criterions.append(nn.CrossEntropyLoss(reduction="sum"))

    def reset(self):
        for criterion in self.criterions:
            criterion.zero_grad()

    def remap_target(self, target):
        # Head target: the original id for head tokens, otherwise the cluster's
        # pseudo-token id. Tail targets: within-cluster offsets.
        new_target = [target.clone()]
        for i in range(len(self.cutoffs) - 1):
            mask = target.ge(self.cutoffs[i]).mul(target.lt(self.cutoffs[i + 1]))
            if mask.sum() > 0:
                new_target[0][mask] = self.cutoffs[0] + i
                new_target.append(target[mask].add(-self.cutoffs[i]))
            else:
                new_target.append(None)
        return new_target

    def forward(self, inp, target):
        n = inp[0].size(0)
        target = self.remap_target(target.data)
        loss = 0
        for i in range(len(inp)):
            if inp[i] is not None:
                # Class indices must lie inside this cluster's logit range.
                assert target[i].min() >= 0 and target[i].max() < inp[i].size(1)
                criterion = self.criterions[i]
                loss += criterion(inp[i], target[i])
        loss /= n
        return loss
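

# Minimal usage sketch. The hidden size, cutoffs, and batch size below are
# illustrative assumptions, not values used elsewhere in TextAttack: a
# 100-token vocabulary split into a 20-token head and two tail clusters.
if __name__ == "__main__":
    device = textattack.shared.utils.device
    hidden_size = 64
    cutoffs = [20, 50, 100]

    adaptive_softmax = AdaptiveSoftmax(hidden_size, cutoffs).to(device)
    criterion = AdaptiveLoss(cutoffs)

    hidden = torch.randn(8, hidden_size, device=device)          # batch of hidden states
    target = torch.randint(0, cutoffs[-1], (8,), device=device)  # gold token ids

    adaptive_softmax.set_target(target)      # pick which tail clusters to run
    logits = adaptive_softmax(hidden)        # [head logits, tail logits or None, ...]
    loss = criterion(logits, target)         # cross-entropy averaged over the batch
    loss.backward()

    with torch.no_grad():
        log_probs = adaptive_softmax.log_prob(hidden)  # (8, 100) full-vocab log-probs
    print(loss.item(), log_probs.shape)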