Batched seq2seq in RNNs

I’m having difficulties getting my head around how to efficiently do batched training in pre-transformer (RNN) networks.

On the decoder side, I think we would typically see something like this (for simplicity’s sake, we’ll use a GRU, which has only a hidden state and no cell state). Please correct me if this is wrong!

import torch

# first input is beginning-of-sentence token
dec_in = torch.tensor([[BOS]])
# first hidden state is the last encoder hidden state
dec_h = enc_h
# accumulate the loss over all decoding steps
loss = 0

# iterate for the expected length
for idx in range(target_len):
    # pass input and hidden state to decoder
    dec_out, dec_h = decoder(dec_in, dec_h)

    # get highest scoring prob and its corresponding word idx
    topv, topi = dec_out.topk(1)
    # use best predicted word as input for next word
    dec_in = topi.squeeze().detach()

    # compare the predicted scores (not the argmax idx) with the actual idx at this position
    loss += criterion(dec_out.view(1, -1), target_idxs[idx].view(1))
    # stop generating new words when we hit end-of-sentence token
    if dec_in.item() == EOS:
        break

My question is how this works in a batched setting, particularly because the position of the EOS token differs between items in a batch. Should we add an inner loop that checks, for each item in the batch, whether it has reached an EOS token, and if so, removes it from the batch?

My assumption is that we just do not check for EOS. The problem that I have with that is that you will then still count “finished” sentences (that had EOS predicted) towards loss. So you will penalize the model if it predicts anything but PAD after EOS. I am not sure if that is what you’d want or need.

All helpful comments or additional resources are welcome.

Usually we simply mask the cross entropy at padding positions, so that positions past the end of a sentence do not contribute to the loss.

In PyTorch, you can set the ignore_index argument on the loss object, for example:

https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html#torch.nn.CrossEntropyLoss
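As a minimal sketch of that masking (the PAD index, vocabulary size, and random logits are made up for illustration, not taken from any particular model), ignore_index gives the same loss as averaging by hand over the non-padding positions:

```python
import torch
import torch.nn as nn

PAD = 0          # hypothetical padding index
vocab_size = 6   # hypothetical vocabulary size

torch.manual_seed(0)
# Decoder scores for a batch of 2 sentences, 4 steps each: (batch, time, vocab)
logits = torch.randn(2, 4, vocab_size)

# The second sentence ends early (index 2 plays the role of EOS) and is padded
targets = torch.tensor([[3, 4, 5, 2],
                        [3, 2, PAD, PAD]])

criterion = nn.CrossEntropyLoss(ignore_index=PAD)

# CrossEntropyLoss expects (N, C) scores and (N,) targets, so flatten batch and time
flat_logits = logits.reshape(-1, vocab_size)
flat_targets = targets.reshape(-1)
loss = criterion(flat_logits, flat_targets)

# Same quantity computed by hand over the non-padding positions only
mask = flat_targets != PAD
manual = nn.functional.cross_entropy(flat_logits[mask], flat_targets[mask])
```

Here loss and manual agree, so whatever the model predicts after EOS has no effect on the gradient.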


Surely that must add quite some overhead, especially for larger batches. Even for “finished” sentences you still do a forward pass?

Looking at the code snippet that I posted, I seem to have made a mistake: I only feed the previously predicted token to the decoder. Should we not feed all t-1 previously predicted tokens to it? Or does the hidden state already capture that information?

During training, batches are usually prepared with sentences having the same (or similar) length. So in practice this is not an issue.
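A rough sketch of that kind of batch preparation, assuming sentences are plain lists of token ids (the function name is hypothetical, and real toolkits typically also shuffle the resulting batches):

```python
def length_bucketed_batches(sentences, batch_size):
    """Group sentences of similar length so each batch needs little padding."""
    ordered = sorted(sentences, key=len)
    return [ordered[i:i + batch_size] for i in range(0, len(ordered), batch_size)]


# Toy token-id sentences of different lengths
sentences = [[1, 2, 3, 4, 5], [1, 2], [1, 2, 3, 4], [1], [1, 2, 3], [1, 2, 3, 4, 5, 6]]
batches = length_bucketed_batches(sentences, batch_size=2)

# Lengths inside each batch differ by at most one token here
lengths = [[len(s) for s in batch] for batch in batches]
```

With batches like these, only a few padded steps per batch are wasted on "finished" sentences.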

For an RNN, you only feed the last step; the states capture the context of the previous steps.
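To illustrate with a small sketch (layer sizes and random inputs are arbitrary assumptions), feeding a GRU one step at a time while carrying the hidden state forward gives the same outputs as feeding it the whole sequence at once:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
gru = nn.GRU(input_size=8, hidden_size=16, batch_first=True)
x = torch.randn(3, 5, 8)  # (batch, time, features)

# Whole sequence in a single call
full_out, full_h = gru(x)

# One step at a time, passing the hidden state from step to step
h = None  # None means "start from a zero hidden state"
step_outs = []
for t in range(x.size(1)):
    out, h = gru(x[:, t:t + 1, :], h)
    step_outs.append(out)
step_out = torch.cat(step_outs, dim=1)
```

The two results match up to floating-point tolerance, which is why the decoder only needs the latest token plus the latest state.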

Yes, that is a good point about sorted batches. Hadn’t thought about that, thanks.

After thinking about it more, I think I understand it. It makes sense that in an RNN the previous tokens are not necessary, because the last states already contain their information. If we did not have access to the previous states, we would need to pass all tokens through the RNN again, but since we always have access to the latest states, we can just use those. So basically we can use an RNNCell rather than a full RNN layer.

As a final question: the process above works for RNN and GRU, but an LSTM has an additional state (the cell state) next to the hidden state. What I am trying to get my head around is the semantic difference: what do the two represent? I know the code, but not what the concepts are supposed to mean intuitively. Related to that, should the encoder state be used to initialize the hidden state, the cell state, or both? (I’d think the hidden state, but I am not sure why.)

This blog post has some information:

http://colah.github.io/posts/2015-08-Understanding-LSTMs/

“The cell state is kind of like a conveyor belt. It runs straight down the entire chain, with only some minor linear interactions. It’s very easy for information to just flow along it unchanged.”

So the cell state can carry information over longer distances.

OpenNMT typically copies both the hidden and cell states to the decoder. I’ve seen other implementations that only copy the hidden state. I’m not sure which one is best.
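A small sketch of the two options in plain PyTorch (this is not OpenNMT's actual code; the layer sizes and random inputs are assumptions for illustration):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
encoder = nn.LSTM(input_size=8, hidden_size=16, batch_first=True)
decoder = nn.LSTM(input_size=8, hidden_size=16, batch_first=True)

src = torch.randn(3, 7, 8)     # (batch, source length, features)
dec_in = torch.randn(3, 1, 8)  # first decoder input, one step per sentence

# An LSTM returns its final state as a (hidden, cell) pair
_, (enc_h, enc_c) = encoder(src)

# Option A: copy both the hidden and the cell state to the decoder
out_both, _ = decoder(dec_in, (enc_h, enc_c))

# Option B: copy only the hidden state and start the cell state at zero
out_hidden_only, _ = decoder(dec_in, (enc_h, torch.zeros_like(enc_c)))
```

Both calls are valid and produce different decoder outputs, so the choice is a modelling decision rather than an API constraint.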

Thanks for the information, I’ll dig into it later when I find the time!

My question comes from some things that I wanted to try, namely using pretrained sentence embeddings instead of an encoder. In an RNN/GRU this is straightforward, but in an LSTM I wasn’t sure whether the sentence embedding should be used to initialize the hidden state or the cell state. I’ve got a good response on the original topic over at the PyTorch Forums, but your answers to the more detailed MT questions were very helpful, so thanks!