Why exclude last target from inputs?


(Nikhil Verma) #1

In the PyTorch version of OpenNMT, the last target token is excluded from the inputs to the decoder. What is the reason for this?

At every iteration i during decoding, let the decoder cell dec[i] receive an input inp[i] and produce an output out[i]. I earlier suspected that tgt[-1] is not needed because we feed tgt[i-1] as inp[i] at every iteration i. So dec[0] gets some input inp[0], dec[1] receives some encoding of tgt[0] as inp[1], and so on.

Assuming the last iteration is t, the last decoder cell dec[t] would then receive tgt[t-1] as inp[t], so there would be no need for the last target tgt[t]. But it turns out I was wrong.

This code suggests that inp[i] is some combination of tgt[i] and not tgt[i-1]. If this is the case, why do we drop the last target?

(Nikhil Verma) #2

So it turns out that tgt[0] is the start-of-sequence tag <s>, which is passed to dec[0]. tgt[-1] is either the end-of-sequence tag </s> or the padding tag <blank>. In either case it never needs to be fed into the decoder, because there is no further token for it to predict.
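A minimal sketch of that alignment in plain Python, not the actual OpenNMT-py code; the helper names `decoder_inputs` and `step_alignment` and the toy sentence are hypothetical. Step i feeds tgt[i] and should predict tgt[i+1], so the decoder only ever needs tgt[:-1].

```python
BOS, EOS = "<s>", "</s>"

def decoder_inputs(tgt):
    """Tokens fed to the decoder under teacher forcing: every token except the last."""
    return tgt[:-1]

def step_alignment(tgt):
    """Pair each decoder input inp[i] = tgt[i] with the token out[i] should predict, tgt[i + 1]."""
    inputs = decoder_inputs(tgt)
    return [(inputs[i], tgt[i + 1]) for i in range(len(inputs))]

tgt = [BOS, "the", "cat", "sat", EOS]
for inp, nxt in step_alignment(tgt):
    print(f"{inp} -> {nxt}")
```

In a padded batch the shorter sequences end in <blank>, so their </s> does get fed as an input; the step after it is supposed to predict <blank>, which the loss usually ignores.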

Aside: when computing the loss, the decoder outputs should be compared against tgt[1:]. This is indeed what the code does.
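A toy version of that loss alignment, again only a sketch: `nll_loss` and the hand-written probability tables are hypothetical stand-ins for the model's softmax outputs. Each output distribution is scored against the matching token of tgt[1:], and positions whose gold token is <blank> are skipped; that masking is an assumption based on the common practice of ignoring the padding index in the loss.

```python
import math

PAD = "<blank>"

def nll_loss(step_probs, tgt, pad=PAD):
    """Sum of -log p(gold), where step i's distribution is scored against tgt[i + 1].

    step_probs[i] maps token -> probability, produced after feeding tgt[i].
    Steps whose gold token is padding contribute nothing.
    """
    gold = tgt[1:]  # outputs are compared against the target shifted left by one
    if len(step_probs) != len(gold):
        raise ValueError("need one output distribution per decoder input (len(tgt) - 1)")
    total = 0.0
    for probs, g in zip(step_probs, gold):
        if g == pad:
            continue  # padded position: ignored by the loss
        total -= math.log(probs[g])
    return total

tgt = ["<s>", "hi", "</s>", PAD]
step_probs = [
    {"hi": 0.5, "</s>": 0.5},   # after feeding <s>, scored against "hi"
    {"hi": 0.25, "</s>": 0.75}, # after feeding "hi", scored against </s>
    {"hi": 0.5, "</s>": 0.5},   # after feeding </s>, gold is <blank>: skipped
]
print(nll_loss(step_probs, tgt))
```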