While going through the training code (the PyTorch version, see below), I noticed that in the function memoryEfficientLoss() the outputs argument is first rewrapped in a new Variable under the same name, and at the end of the function its .grad is unwrapped and returned. Why are these two operations necessary?
def memoryEfficientLoss(outputs, targets, generator, crit, eval=False):
    # compute generations one piece at a time
    num_correct, loss = 0, 0
    outputs = Variable(outputs.data, requires_grad=(not eval), volatile=eval)  # todo: why repack here?

    batch_size = outputs.size(1)
    outputs_split = torch.split(outputs, opt.max_generator_batches)
    targets_split = torch.split(targets, opt.max_generator_batches)
    for i, (out_t, targ_t) in enumerate(zip(outputs_split, targets_split)):
        out_t = out_t.view(-1, out_t.size(2))
        scores_t = generator(out_t)
        loss_t = crit(scores_t, targ_t.view(-1))
        pred_t = scores_t.max(1)[1]
        num_correct_t = pred_t.data.eq(targ_t.data).masked_select(targ_t.ne(onmt.Constants.PAD).data).sum()
        num_correct += num_correct_t
        loss += loss_t.data[0]
        if not eval:
            loss_t.div(batch_size).backward()

    grad_output = None if outputs.grad is None else outputs.grad.data  # todo: why unpack here?
    return loss, grad_output, num_correct
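To make sure I'm asking about the right thing, here is a minimal standalone sketch of what I think the repack/unpack pair is doing, written with the newer detach()/requires_grad_() API instead of the Variable rewrap. The modules and names below are made up for illustration and are not from the OpenNMT code:

# Minimal sketch of the "cut the graph, backprop the small part,
# then hand the gradient back to the big graph" pattern (my reading).
import torch
import torch.nn as nn

encoder = nn.Linear(4, 8)      # stands in for the decoder stack that produced `outputs`
generator = nn.Linear(8, 10)   # stands in for the generator / output projection
crit = nn.CrossEntropyLoss()

x = torch.randn(16, 4)
targets = torch.randint(0, 10, (16,))

outputs = encoder(x)                          # node in the "big" graph
detached = outputs.detach().requires_grad_()  # repack: new leaf, graph is cut here

loss = crit(generator(detached), targets)     # small graph: generator + loss only
loss.backward()                               # fills detached.grad, only the small graph is traversed

outputs.backward(detached.grad)               # unpack: feed the gradient back into the big graph

Is this reading of the pattern correct, i.e. the repacked Variable acts as a leaf so the per-chunk backward calls only touch the generator, and the returned grad_output is what the caller later feeds into the real outputs?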