I am running into an issue during translation. Training works fine, but during translation the multi-head attention in the decoder throws an error about the mask size.
Traceback (most recent call last):
File "/home/bram/.local/share/virtualenvs/nfr-experiments-3R5lX5O6/bin/onmt_translate", line 11, in <module>
load_entry_point('OpenNMT-py', 'console_scripts', 'onmt_translate')()
File "/home/bram/Python/projects/nfr-experiments/OpenNMT-py/onmt/bin/translate.py", line 48, in main
translate(opt)
File "/home/bram/Python/projects/nfr-experiments/OpenNMT-py/onmt/bin/translate.py", line 25, in translate
translator.translate(
File "/home/bram/Python/projects/nfr-experiments/OpenNMT-py/onmt/translate/translator.py", line 361, in translate
batch_data = self.translate_batch(
File "/home/bram/Python/projects/nfr-experiments/OpenNMT-py/onmt/translate/translator.py", line 550, in translate_batch
return self._translate_batch_with_strategy(batch, src_vocabs,
File "/home/bram/Python/projects/nfr-experiments/OpenNMT-py/onmt/translate/translator.py", line 674, in _translate_batch_with_strategy
log_probs, attn = self._decode_and_generate(
File "/home/bram/Python/projects/nfr-experiments/OpenNMT-py/onmt/translate/translator.py", line 589, in _decode_and_generate
dec_out, dec_attn = self.model.decoder(
File "/home/bram/.local/share/virtualenvs/nfr-experiments-3R5lX5O6/lib/python3.8/site-packages/torch/nn/modules/module.py", line 532, in __call__
result = self.forward(*input, **kwargs)
File "/home/bram/Python/projects/nfr-experiments/OpenNMT-py/onmt/decoders/transformer.py", line 319, in forward
output, attn, attn_align = layer(
File "/home/bram/.local/share/virtualenvs/nfr-experiments-3R5lX5O6/lib/python3.8/site-packages/torch/nn/modules/module.py", line 532, in __call__
result = self.forward(*input, **kwargs)
File "/home/bram/Python/projects/nfr-experiments/OpenNMT-py/onmt/decoders/transformer.py", line 93, in forward
output, attns = self._forward(*args, **kwargs)
File "/home/bram/Python/projects/nfr-experiments/OpenNMT-py/onmt/decoders/transformer.py", line 165, in _forward
mid, attns = self.context_attn(memory_bank, memory_bank, query_norm,
File "/home/bram/.local/share/virtualenvs/nfr-experiments-3R5lX5O6/lib/python3.8/site-packages/torch/nn/modules/module.py", line 532, in __call__
result = self.forward(*input, **kwargs)
File "/home/bram/Python/projects/nfr-experiments/OpenNMT-py/onmt/modules/multi_headed_attn.py", line 202, in forward
scores = scores.masked_fill(mask, -1e18)
RuntimeError: The size of tensor a (30) must match the size of tensor b (150) at non-singleton dimension 0
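For completeness, the failure can be reproduced outside OpenNMT as a plain broadcasting problem in masked_fill. This is just a minimal standalone sketch using the shapes that show up in the decoder logs further down:

import torch

# Shapes as logged in decoder layer 0 below: context-attention scores
# vs. the src_pad_mask coming from the encoder.
scores = torch.randn(150, 8, 1, 223)                  # (batch?, heads, tgt_len, src_len)
mask = torch.zeros(30, 1, 1, 223, dtype=torch.bool)   # (batch, 1, 1, src_len)

# masked_fill only broadcasts over singleton dimensions, so 30 vs. 150 at
# dimension 0 raises the same RuntimeError as in the traceback above.
scores.masked_fill(mask, -1e18)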
After adding some logging I found that the encoder has no issues, but that the decoder cannot simply reuse the src_mask that was created in the encoder, and I am not sure why. In the logs below you can see that multi-head attention fails in the first decoder layer because of a shape mismatch in the code snippet that you linked: the mask should be filled for the padded tokens, but its dimensions do not match the scores.
Entering encoder layer 0
score torch.Size([30, 8, 223, 223])
mask torch.Size([30, 1, 1, 223])
Entering encoder layer 1
score torch.Size([30, 8, 223, 223])
mask torch.Size([30, 1, 1, 223])
Entering encoder layer 2
score torch.Size([30, 8, 223, 223])
mask torch.Size([30, 1, 1, 223])
Entering encoder layer 3
score torch.Size([30, 8, 223, 223])
mask torch.Size([30, 1, 1, 223])
Entering encoder layer 4
score torch.Size([30, 8, 223, 223])
mask torch.Size([30, 1, 1, 223])
Entering encoder layer 5
score torch.Size([30, 8, 223, 223])
mask torch.Size([30, 1, 1, 223])
decoder_input torch.Size([1, 150, 1])
src_pad_mask torch.Size([30, 1, 223]) # here
tgt_pad_mask torch.Size([150, 1, 1])
Entering decoder layer 0
score torch.Size([150, 8, 1, 223])
mask torch.Size([30, 1, 1, 223]) # here
So it seems that the decoder expects a much larger attention mask than the one the encoder returns, but I do not understand why. Do you have any thoughts on this? It is especially odd to me since this does work during training.
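The only thing I can think of is that during translation the batch gets expanded for beam search (150 = 30 × 5, which would match the default beam size of 5; that beam size is just my assumption), while src_pad_mask keeps the original batch size. Purely as a sketch of that idea (not what my branch currently does), repeating the mask along the batch dimension would at least make the shapes broadcastable:

import torch

batch_size, beam_size, heads, src_len = 30, 5, 8, 223   # beam_size = 5 is an assumption
scores = torch.randn(batch_size * beam_size, heads, 1, src_len)
src_pad_mask = torch.zeros(batch_size, 1, 1, src_len, dtype=torch.bool)

# Repeat every batch entry beam_size times along dim 0 (30 -> 150), mirroring
# how the decoder input apparently gets expanded during translation.
expanded_mask = src_pad_mask.repeat_interleave(beam_size, dim=0)
print(expanded_mask.shape)        # torch.Size([150, 1, 1, 223])

# With matching batch sizes, masked_fill broadcasts without complaining.
scores = scores.masked_fill(expanded_mask, -1e18)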
I only made a few changes to the encoder/decoder, which you can find here: https://github.com/BramVanroy/OpenNMT-py/tree/masked_tokens