Here is a minimal working example (MWE) that reproduces an error I am getting.
There are four input files, as usual for preprocess.py:
en_tra_2cls.txt contains:
g│1 o│0 o│0 r│0 t│0 y│0
w│1 h│0 i│0 t│0 t│0 e│0 r│0 i│0 d│0 g│0 e│0
en_dev_2cls.txt contains:
h│1 u│0 b│0 e│0 r│0
l│1 a│0 p│0 a│0 d│0 i│0 s│0
ch_tra_2cls.txt contains:
古│1 尔│1 蒂│0
惠│1 特│1 里│1 奇│0
ch_dev_2cls.txt contains:
休│1 伯│0
拉│1 帕│1 迪│1 斯│0
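For anyone reproducing this, a quick sanity check is to confirm that every token in a file carries the same number of features. The following is a minimal sketch, assuming Python 3 and assuming the feature separator is exactly the `│` character shown in the samples above; `feature_counts` and `is_consistent` are hypothetical helper names, not OpenNMT-py functions.

```python
# Hypothetical helpers, not part of OpenNMT-py.
# Assumption: the feature separator is the character used in the sample files.
SEP = "│"


def feature_counts(line, sep=SEP):
    """Return the set of feature counts (separators per token) in one line."""
    return {token.count(sep) for token in line.split()}


def is_consistent(lines, sep=SEP):
    """True if every token across all lines has the same number of features."""
    counts = set()
    for line in lines:
        counts |= feature_counts(line, sep)
    return len(counts) <= 1


# Sample lines copied from en_tra_2cls.txt and ch_tra_2cls.txt above.
en_tra = [
    "g│1 o│0 o│0 r│0 t│0 y│0",
    "w│1 h│0 i│0 t│0 t│0 e│0 r│0 i│0 d│0 g│0 e│0",
]
ch_tra = [
    "古│1 尔│1 蒂│0",
    "惠│1 特│1 里│1 奇│0",
]

print(is_consistent(en_tra), is_consistent(ch_tra))
```

Both sample sets pass this check, so the files themselves look well-formed; the check only rules out malformed input and says nothing about what happens inside translate.py.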
Then I run preprocess.py:
python nmt-py/OpenNMT-py/preprocess.py -train_src en_tra_2cls.txt -train_tgt ch_tra_2cls.txt -valid_src en_dev_2cls.txt -valid_tgt ch_dev_2cls.txt -save_data 2cls
This runs without errors.
Then I run train.py:
python nmt-py/OpenNMT-py/train.py -data 2cls -save_model 2cls -train_steps 10 -seed 7 -start_decay_step 5000 -save_checkpoint_steps 5 -keep_checkpoint 5 -decay_steps 1000 -gpuid 1
This also runs without errors.
However, when I run:
python nmt-py/OpenNMT-py/translate.py -src en_dev_2cls.txt -model 2cls_step_5.pt -output rrr.txt -verbose -replace_unk -gpu 1
I get the following error:
Traceback (most recent call last):
  File "/disk/ocean/lhe/transliteration/nmt-py/OpenNMT-py/translate.py", line 36, in <module>
    main(opt)
  File "/disk/ocean/lhe/transliteration/nmt-py/OpenNMT-py/translate.py", line 24, in main
    attn_debug=opt.attn_debug)
  File "/disk/ocean/lhe/transliteration/nmt-py/OpenNMT-py/onmt/translate/translator.py", line 213, in translate
    batch_data = self.translate_batch(batch, data, fast=self.fast)
  File "/disk/ocean/lhe/transliteration/nmt-py/OpenNMT-py/onmt/translate/translator.py", line 314, in translate_batch
    return self._translate_batch(batch, data)
  File "/disk/ocean/lhe/transliteration/nmt-py/OpenNMT-py/onmt/translate/translator.py", line 559, in _translate_batch
    step=i)
  File "/disk/ocean/lhe/condaPytorch/envs/pytorch/lib/python2.7/site-packages/torch/nn/modules/module.py", line 491, in __call__
    result = self.forward(*input, **kwargs)
  File "/disk/ocean/lhe/transliteration/nmt-py/OpenNMT-py/onmt/decoders/decoder.py", line 136, in forward
    tgt, memory_bank, state, memory_lengths=memory_lengths)
  File "/disk/ocean/lhe/transliteration/nmt-py/OpenNMT-py/onmt/decoders/decoder.py", line 314, in run_forward_pass
    emb = self.embeddings(tgt)
  File "/disk/ocean/lhe/condaPytorch/envs/pytorch/lib/python2.7/site-packages/torch/nn/modules/module.py", line 491, in __call__
    result = self.forward(*input, **kwargs)
  File "/disk/ocean/lhe/transliteration/nmt-py/OpenNMT-py/onmt/modules/embeddings.py", line 184, in forward
    emb = self.make_embedding(source)
  File "/disk/ocean/lhe/condaPytorch/envs/pytorch/lib/python2.7/site-packages/torch/nn/modules/module.py", line 491, in __call__
    result = self.forward(*input, **kwargs)
  File "/disk/ocean/lhe/condaPytorch/envs/pytorch/lib/python2.7/site-packages/torch/nn/modules/container.py", line 91, in forward
    input = module(input)
  File "/disk/ocean/lhe/condaPytorch/envs/pytorch/lib/python2.7/site-packages/torch/nn/modules/module.py", line 491, in __call__
    result = self.forward(*input, **kwargs)
  File "/disk/ocean/lhe/transliteration/nmt-py/OpenNMT-py/onmt/modules/util_class.py", line 42, in forward
    assert len(self) == len(inputs)
AssertionError
Any ideas why?
I am using the latest version of PyTorch as of today (installed with "pip install") and, I believe, the latest version of OpenNMT-py.
If I remove the “word features” (from both source and target), everything works well.
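Removing the word features can be done mechanically. Below is a minimal sketch, assuming Python 3 and assuming the separator is the `│` character used in the files above; `strip_features` is a hypothetical helper name, not something OpenNMT-py provides.

```python
# Hypothetical helper, not part of OpenNMT-py.
# Assumption: the feature separator is the character used in the sample files.
SEP = "│"


def strip_features(line, sep=SEP):
    """Keep only the part of each token before the first separator."""
    return " ".join(token.split(sep, 1)[0] for token in line.split())


# Sample lines copied from en_tra_2cls.txt and ch_tra_2cls.txt.
print(strip_features("g│1 o│0 o│0 r│0 t│0 y│0"))
print(strip_features("古│1 尔│1 蒂│0"))
```

Applying this to each line of all four files gives the feature-free inputs for which preprocessing, training, and translation all work.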