How does computing the loss in shards help reduce memory cost?

pytorch

In OpenNMT-py, the loss is computed in shards: the output is split into truncated chunks, and backpropagation is done in two stages:
first from the loss back to the output layer, then from the output layer back through the rest of the network.
The default truncation size is 32.
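
To make the question concrete, here is a minimal sketch of how I understand the scheme. It is not the actual OpenNMT-py code: the function name `sharded_loss_backward`, the `generator` argument (standing in for the output layer), and the tensor shapes are my own assumptions for illustration.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def sharded_loss_backward(hidden, target, generator, shard_size=32):
    """Two-stage backward pass with the loss computed shard by shard.

    hidden:    (T, B, H) decoder output, still attached to the network graph
    target:    (T, B) gold token indices
    generator: output layer mapping H -> vocabulary size (hypothetical name)
    """
    # Cut the graph at the decoder output; this leaf tensor collects gradients.
    detached = hidden.detach().requires_grad_(True)
    total_loss = 0.0

    # Stage 1: loss -> output layer, one shard of time steps at a time.
    for start in range(0, detached.size(0), shard_size):
        h = detached[start:start + shard_size]
        t = target[start:start + shard_size]
        logits = generator(h)  # (shard, B, V), built for this shard only
        loss = F.cross_entropy(
            logits.reshape(-1, logits.size(-1)), t.reshape(-1), reduction="sum"
        )
        # Accumulates into detached.grad and the generator's parameter grads.
        loss.backward()
        total_loss += loss.item()

    # Stage 2: output layer -> rest of the network, using the collected gradient.
    hidden.backward(detached.grad)
    return total_loss
```

In this sketch the full `(T, B, V)` logits tensor is never built in one piece; only one shard's logits exist at any time, and the rest of the network is backpropagated through once at the end.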

May I ask how this approach saves memory, and whether it is also faster?

Thanks in advance.