While going through the training code (PyTorch version, see below), I noticed that in the function memoryEfficientLoss() the outputs parameter is first rewrapped in a new Variable under the same name, and at the end of the function its .grad is unwrapped and returned. I'm wondering why these two operations are necessary.
def memoryEfficientLoss(outputs, targets, generator, crit, eval=False):
    # compute generations one piece at a time
    num_correct, loss = 0, 0
    outputs = Variable(outputs.data, requires_grad=(not eval), volatile=eval)  # TODO: why repack here?

    batch_size = outputs.size(1)
    outputs_split = torch.split(outputs, opt.max_generator_batches)
    targets_split = torch.split(targets, opt.max_generator_batches)
    for i, (out_t, targ_t) in enumerate(zip(outputs_split, targets_split)):
        out_t = out_t.view(-1, out_t.size(2))
        scores_t = generator(out_t)
        loss_t = crit(scores_t, targ_t.view(-1))
        pred_t = scores_t.max(1)[1]
        num_correct_t = pred_t.data.eq(targ_t.data).masked_select(targ_t.ne(onmt.Constants.PAD).data).sum()
        num_correct += num_correct_t
        loss += loss_t.data[0]
        if not eval:
            loss_t.div(batch_size).backward()

    grad_output = None if outputs.grad is None else outputs.grad.data  # TODO: why unpack here?
    return loss, grad_output, num_correct
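For context, my current understanding of how the returned grad_output is consumed by the caller is sketched below (the surrounding names model, batch, targets and optim are placeholders, not the exact code from the repo): the decoder outputs are detached into a fresh leaf Variable, so each chunk's backward() only reaches that detached copy, and the gradient accumulated in outputs.grad is then pushed back through the encoder/decoder part of the graph with a single second backward call.

# Rough sketch of the caller side (hypothetical names, same old Variable API as above).
# `outputs` here still carries the full encoder/decoder graph.
outputs = model(batch)  # e.g. (seq_len, batch, rnn_size)

loss, grad_output, num_correct = memoryEfficientLoss(
    outputs, targets, generator, crit)

if grad_output is not None:
    # The per-chunk backward() calls above stopped at the detached copy;
    # feed its accumulated gradient back through the rest of the graph.
    outputs.backward(grad_output)

optim.step()

Is the point of the repack/unpack simply to keep the per-chunk backward passes from traversing the whole encoder/decoder graph, or is there something else going on?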