Index out of range when computing batch loss with transformer architecture


(Justin Grace) #1


I am getting an index error when training a transformer model:

Traceback (most recent call last):
  File "OpenNMT-py/", line 40, in <module>
  File "OpenNMT-py/", line 27, in main
  File "/disk/ocean/jgrace/OpenNMT-py/onmt/", line 128, in main
  File "/disk/ocean/jgrace/OpenNMT-py/onmt/", line 175, in train
  File "/disk/ocean/jgrace/OpenNMT-py/onmt/", line 287, in _gradient_accumulation
    trunc_size, self.shard_size, normalization)
  File "/disk/ocean/jgrace/OpenNMT-py/onmt/utils/", line 143, in sharded_compute_loss
    loss, stats = self._compute_loss(batch, **shard)
  File "/disk/ocean/jgrace/OpenNMT-py/onmt/modules/", line 197, in _compute_loss
    batch, self.tgt_vocab, self.cur_dataset.src_vocabs)
  File "/disk/ocean/jgrace/OpenNMT-py/onmt/inputters/", line 116, in collapse_copy_scores
    src_vocab = src_vocabs[index]
IndexError: list index out of range

It seems to be related to the following function:

def collapse_copy_scores(scores, batch, tgt_vocab, src_vocabs):
        """
        Given scores from an expanded dictionary
        corresponding to a batch, sums together copies,
        with a dictionary word when it is ambiguous.
        """
        offset = len(tgt_vocab)
        for b in range(batch.batch_size):
            blank = []
            fill = []
            # position of example b in the dataset (assumed to be batch.indices.data)
            index = batch.indices.data[b]
            src_vocab = src_vocabs[index]
            for i in range(1, len(src_vocab)):
                sw = src_vocab.itos[i]
                ti = tgt_vocab.stoi[sw]
                if ti != 0:
                    blank.append(offset + i)
                    fill.append(ti)
            if blank:
                blank = torch.Tensor(blank).type_as(batch.indices.data)
                fill = torch.Tensor(fill).type_as(batch.indices.data)
                scores[:, b].index_add_(1, fill,
                                        scores[:, b].index_select(1, blank))
                scores[:, b].index_fill_(1, blank, 1e-10)
        return scores

This is where a vocab subset is created for computing the loss on sharded data. I can't see why this might not be working; any help would be appreciated.
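
One way to narrow this down might be to check, before the loss is computed, whether every index in the batch actually has an entry in src_vocabs. The sketch below uses plain Python lists as stand-ins for the real objects. The helper name find_bad_indices is hypothetical and not part of OpenNMT-py, and it assumes the batch indices are per-example positions into the list of source vocabularies.

```python
def find_bad_indices(indices, src_vocabs):
    """Return (batch position, index) pairs that fall outside
    the range 0 .. len(src_vocabs) - 1."""
    limit = len(src_vocabs)
    return [(pos, int(idx)) for pos, idx in enumerate(indices)
            if not 0 <= int(idx) < limit]


# Toy stand-ins: three per-example vocabularies, but one batch entry
# points at example 5, which this list does not contain.
src_vocabs = [["<unk>", "a"], ["<unk>", "b"], ["<unk>", "c"]]
batch_indices = [0, 2, 5, 1]

bad = find_bad_indices(batch_indices, src_vocabs)
print(bad)  # [(2, 5)]
```

If a check like this reports any pairs on the real data, the batch and the list of source vocabularies were probably built from different datasets or shards, which would explain the IndexError.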