Question about the implementation of decoding


(Rylanchiu) #1

In, only the last predicted token is fed into the decoder at each decoding step. This may work for an RNN-based decoder, but as far as I know, the Transformer decoder needs to attend to all previous target-side tokens during decoding. Which part of the code handles this? Please tell me if I have misunderstood anything. Thanks in advance!
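To make the concern concrete, here is a simplified sketch in plain Python of causal self-attention over a target prefix. It is an illustration under assumptions (single head, no learned projections, so keys and values are the inputs themselves), not OpenNMT code. It shows that the output at the last position depends on every earlier target token, so something has to keep those earlier tokens around when only the last one is fed in.

```python
import math

def softmax(scores):
    m = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def attend(query, keys, values):
    """Scaled dot-product attention of one query over lists of keys/values."""
    d = len(query)
    scores = [sum(q * k for q, k in zip(query, key)) / math.sqrt(d) for key in keys]
    weights = softmax(scores)
    return [sum(w * v[i] for w, v in zip(weights, values)) for i in range(len(values[0]))]

def causal_self_attention(inputs):
    """Position t attends to positions 0..t of the target prefix.

    Simplification: no projections, so keys/values are the inputs themselves.
    """
    return [attend(inputs[t], inputs[: t + 1], inputs[: t + 1]) for t in range(len(inputs))]

prefix = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
out = causal_self_attention(prefix)

# Changing only the FIRST target token changes the output at the LAST position,
# so the last step cannot be computed from the last token alone.
changed = causal_self_attention([[0.5, -1.0]] + prefix[1:])
print(out[-1] != changed[-1])  # True
```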

(Guillaume Klein) #2

Previous timesteps are cached in the TransformerDecoder instance. There are 2 modes:

  • cache all previous inputs and concatenate them before running the self-attention layer
  • cache all previous projected keys and values and concatenate them before running the dot-product attention
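A minimal sketch of the two caching modes for a single self-attention head, in plain Python. The weight matrices, the dict-based caches, and the function names are hypothetical stand-ins, not OpenNMT's implementation; the sketch only shows that both modes produce identical outputs while mode 2 avoids re-projecting past timesteps.

```python
import math

def softmax(scores):
    m = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def attend(query, keys, values):
    """Scaled dot-product attention of one query over lists of keys/values."""
    d = len(query)
    scores = [sum(q * k for q, k in zip(query, key)) / math.sqrt(d) for key in keys]
    weights = softmax(scores)
    return [sum(w * v[i] for w, v in zip(weights, values)) for i in range(len(values[0]))]

def matvec(matrix, vec):
    """Plain matrix-vector product, standing in for a dense projection layer."""
    return [sum(m * x for m, x in zip(row, vec)) for row in matrix]

# Hypothetical fixed weights standing in for learned query/key/value projections.
W_Q = [[1.0, 0.5], [0.0, 1.0]]
W_K = [[0.5, 0.0], [1.0, 1.0]]
W_V = [[1.0, -1.0], [0.5, 2.0]]

def step_cache_inputs(x, cache):
    """Mode 1: cache raw inputs, concatenate, then run the whole attention layer.

    Every cached input is re-projected to a key and a value at every step.
    """
    cache["inputs"].append(x)
    keys = [matvec(W_K, h) for h in cache["inputs"]]
    values = [matvec(W_V, h) for h in cache["inputs"]]
    return attend(matvec(W_Q, x), keys, values)

def step_cache_keys_values(x, cache):
    """Mode 2: cache projected keys/values; only the new input is projected."""
    cache["keys"].append(matvec(W_K, x))
    cache["values"].append(matvec(W_V, x))
    return attend(matvec(W_Q, x), cache["keys"], cache["values"])

# Feed one target vector per step, as in incremental decoding.
tokens = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
cache1 = {"inputs": []}
cache2 = {"keys": [], "values": []}
out1 = [step_cache_inputs(x, cache1) for x in tokens]
out2 = [step_cache_keys_values(x, cache2) for x in tokens]
print(out1 == out2)  # True
```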

The second mode was recently added and is faster, since the keys and values of previous timesteps are not re-projected at every step. See:

(Rylanchiu) #3

Thanks for your reply. But I still have a problem with the first mode. Could you please point me to the code that caches the previous inputs? I could not find it. As far as I know, calling TransformerDecoder() once corresponds to one decoding step. But in the, I could not find any value being passed to cache.

(Guillaume Klein) #4

Reading the cached inputs of each layer:

Writing each layer's input to the cache:
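The first mode, as described in this answer, amounts to: each layer reads back the inputs it received at earlier steps, appends its current input, and attends over the result. Below is a toy sketch of that idea in plain Python, under assumptions: the dict-of-lists cache, key names like "layer_0", and the projection-free layer (self-attention plus residual only) are hypothetical illustrations, not the linked OpenNMT code. It checks that feeding one token per step with the cache gives the same outputs as running the whole prefix through every layer at once.

```python
import math

NUM_LAYERS = 2  # hypothetical decoder depth

def softmax(scores):
    m = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def attend(query, keys, values):
    """Scaled dot-product attention of one query over lists of keys/values."""
    d = len(query)
    scores = [sum(q * k for q, k in zip(query, key)) / math.sqrt(d) for key in keys]
    weights = softmax(scores)
    return [sum(w * v[i] for w, v in zip(weights, values)) for i in range(len(values[0]))]

def layer(query_input, context_inputs):
    """Toy decoder layer: projection-free self-attention plus a residual connection."""
    attn = attend(query_input, context_inputs, context_inputs)
    return [a + x for a, x in zip(attn, query_input)]

def decode_step(x, cache):
    """One decoding step: only the newest target vector x is fed in.

    cache["layer_{i}"] (hypothetical key) holds the inputs layer i saw at earlier steps.
    """
    for i in range(NUM_LAYERS):
        cached = cache.setdefault(f"layer_{i}", [])  # read this layer's cached inputs
        cached.append(x)                             # write this layer's input to the cache
        x = layer(x, cached)                         # attend over previous inputs + current one
    return x

def full_forward(inputs):
    """Reference without cache: run the whole prefix through every layer, causally."""
    h = inputs
    for _ in range(NUM_LAYERS):
        h = [layer(h[t], h[: t + 1]) for t in range(len(h))]
    return h

tokens = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
cache = {}
incremental = [decode_step(x, cache) for x in tokens]
print(incremental == full_forward(tokens))  # True
```

Real decoder layers also include projections, encoder attention, feed-forward sublayers, and layer normalization; they are omitted here because they do not change how the per-layer cache is read and written.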