I’m having difficulties getting my head around how to efficiently do batched training in pre-transformer (RNN) networks.
On the decoder side, I think we would typically see something like this (for simplicity's sake, we'll use a GRU, which has no cell state, only a hidden state). Please correct me if this is wrong!
# first input is beginning-of-sentence token
dec_in = torch.tensor([[BOS]])
# first hidden state is the last encoder hidden state
dec_h = enc_h
# iterate for the expected length
for idx in range(target_len):
    # pass input and hidden state to decoder
    dec_out, dec_h = decoder(dec_in, dec_h)
    # get highest scoring prob and its corresponding word idx
    topv, topi = dec_out.topk(1)
    # use best predicted word as input for next word
    dec_in = topi.squeeze().detach()
    # compare the decoder's output scores with the actual idx at this position
    # (the criterion needs the logits, not the argmax index)
    loss += criterion(dec_out, target_idxs[idx])
    # stop generating new words when we hit end-of-sentence token
    if dec_in.item() == EOS:
        break
My question is how this works in a batched setting, particularly because the position of the EOS token differs between items in a batch. Should we add an additional inner loop that checks, for each item in the batch, whether it has reached an EOS token, and if so removes it from the batch?
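For what it's worth, here is one common way to sketch this: keep the full batch every step, but track a boolean `finished` mask, force already-finished sequences to emit PAD, and stop the loop only once every sequence has produced EOS. This is a minimal toy sketch, not code from the post; the decoder (embedding + GRU + linear), the sizes, and all names are assumptions.

```python
import torch
import torch.nn as nn

# Hypothetical toy setup; token ids and sizes are made up for illustration.
PAD, BOS, EOS = 0, 1, 2
vocab_size, hidden_size, batch_size, max_len = 10, 16, 4, 8

torch.manual_seed(0)
embedding = nn.Embedding(vocab_size, hidden_size)
gru = nn.GRU(hidden_size, hidden_size, batch_first=True)
out_proj = nn.Linear(hidden_size, vocab_size)

dec_in = torch.full((batch_size, 1), BOS)          # (B, 1) current input tokens
dec_h = torch.zeros(1, batch_size, hidden_size)    # stand-in for the encoder state
finished = torch.zeros(batch_size, dtype=torch.bool)
generated = []

for _ in range(max_len):
    emb = embedding(dec_in)                        # (B, 1, H)
    out, dec_h = gru(emb, dec_h)                   # (B, 1, H)
    logits = out_proj(out.squeeze(1))              # (B, V)
    next_tok = logits.argmax(dim=-1)               # (B,) greedy choice
    # sequences that already produced EOS keep emitting PAD from here on
    next_tok = torch.where(finished, torch.full_like(next_tok, PAD), next_tok)
    generated.append(next_tok)
    finished |= next_tok.eq(EOS)
    if finished.all():                             # the only batch-level stop
        break
    dec_in = next_tok.unsqueeze(1).detach()

generated = torch.stack(generated, dim=1)          # (B, steps), steps <= max_len
```

No item is ever removed from the batch, so tensor shapes stay fixed; the mask just freezes finished rows, and the loop exits early only when all rows are done.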
My assumption is that we simply do not check for EOS at all. The problem I have with that is that you will then still count "finished" sentences (those that have already predicted EOS) towards the loss. So you would penalize the model if it predicts anything but PAD after EOS. I am not sure whether that is what you'd want or need.
All helpful comments or additional resources are welcome.