I found that to use a custom loss you have to implement the `_make_shard_state` method of `LossComputeBase`:
```python
import torch.nn as nn

import onmt


class LossComputeBase(nn.Module):
    """
    ...
    tgt_vocab (:obj:`Vocab`) :
        torchtext vocab object representing the target output
    normalization (str): normalize by "sents" or "tokens"
    """

    def __init__(self, generator, tgt_vocab):
        super(LossComputeBase, self).__init__()
        self.generator = generator
        self.tgt_vocab = tgt_vocab
        self.padding_idx = tgt_vocab.stoi[onmt.io.PAD_WORD]

    def _make_shard_state(self, batch, output, range_, attns=None):
        """
        Make shard state dictionary for shards() to return iterable
        shards for efficient loss computation. Subclass must define
        this method to match its own _compute_loss() interface.

        Args:
            batch: the current batch.
            output: the predicted output from the model.
            range_: the range of examples for computing, the whole
                batch or a truncated part of it.
            attns: the attns dictionary returned from the model.
        """
        raise NotImplementedError  # each subclass provides its own version
```
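Here is my current guess at what the docstring means, written as a plain-Python sketch (no PyTorch, so the per-shard backward pass is left out). All names below (`make_shard_state`, `shards`, `compute_loss`, `sharded_loss`, `PADDING_IDX`) are my own hypothetical stand-ins, not OpenNMT-py's API. The idea as I read it: the state dict holds everything that has to be split along the time dimension, its keys match the keyword arguments of the loss function, and the loss is computed one chunk ("shard") at a time and summed, which should give the same total as computing it in one go.

```python
from typing import Dict, Iterator, Sequence, Tuple

PADDING_IDX = 0  # hypothetical padding id, stands in for self.padding_idx


def make_shard_state(target: Sequence[int], output: Sequence[Sequence[float]],
                     range_: Tuple[int, int]) -> Dict[str, Sequence]:
    """Hypothetical stand-in for _make_shard_state().

    The keys of the returned dict must match the keyword arguments of
    compute_loss(), because each shard is later passed as compute_loss(**shard).
    """
    start, end = range_
    return {"output": output, "target": target[start:end]}


def shards(state: Dict[str, Sequence], shard_size: int) -> Iterator[Dict[str, Sequence]]:
    """Hypothetical stand-in for shards(): cut every value in the state dict
    into consecutive time-step chunks of at most shard_size."""
    length = len(next(iter(state.values())))
    for start in range(0, length, shard_size):
        yield {key: value[start:start + shard_size] for key, value in state.items()}


def compute_loss(output: Sequence[Sequence[float]], target: Sequence[int]) -> float:
    """Negative log-likelihood over one shard, skipping padding positions.
    `output` holds log-probabilities, one row per time step."""
    loss = 0.0
    for log_probs, gold in zip(output, target):
        if gold != PADDING_IDX:
            loss -= log_probs[gold]
    return loss


def sharded_loss(state: Dict[str, Sequence], shard_size: int) -> float:
    # As far as I understand, OpenNMT-py backpropagates each shard's loss
    # before moving to the next one; this sketch only accumulates the values.
    return sum(compute_loss(**shard) for shard in shards(state, shard_size))


# Log-probabilities over a 3-word vocabulary for 5 decoder time steps.
output = [
    [-1.0, -0.5, -2.0],
    [-2.0, -0.25, -1.0],
    [-0.5, -1.0, -0.5],
    [-1.0, -2.0, -0.25],
    [-0.5, -0.5, -1.0],
]
target = [1, 2, 2, 1, 0]  # last position is padding

state = make_shard_state(target, output, (0, len(target)))
print(compute_loss(**state))   # whole sequence at once: 4.0
print(sharded_loss(state, 2))  # same value, 2 time steps at a time: 4.0
```

If this guess is right, the benefit would be memory: only one chunk of the output has to go through the loss (and, in the real code, the generator) at a time, while the total stays the same.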
But I don't really understand the idea of shards, or what exactly I should write in the `_make_shard_state()` method.
Any explanation of “shards” would be helpful. Thank you.