OpenNMT Forum

Quantization for training

Hi everyone,

I know that CTranslate2 supports int16 and int8, but is there fixed-point support for the training process?


I’m not an expert on quantization, but I don’t think that’s a thing.
“Quantization aware training” exists, but the model is still trained in ‘normal’ (floating-point) precision.

PyTorch supports multiple approaches to quantizing a deep learning model. In most cases the model is trained in FP32 and then the model is converted to INT8. In addition, PyTorch also supports quantization aware training, which models quantization errors in both the forward and backward passes using fake-quantization modules. Note that the entire computation is carried out in floating point. At the end of quantization aware training, PyTorch provides conversion functions to convert the trained model into lower precision.

(from the torch docs: Quantization — PyTorch 1.7.0 documentation)
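
To make the "entire computation is carried out in floating point" part concrete, here is a rough sketch of what a fake-quantization step does, written in plain Python rather than with the actual PyTorch modules. The function name, the scale, and the example weights are made up for illustration. The value is rounded onto an int8 grid and clamped, then immediately mapped back to a float, so the rest of the training step still sees floats that merely carry the rounding error.

```python
def fake_quantize(x, scale, zero_point=0, qmin=-128, qmax=127):
    """Simulate int8 quantization while staying in floating point.

    Hypothetical helper for illustration only; not the PyTorch API.
    """
    # Quantize: map the float onto the integer grid.
    q = round(x / scale) + zero_point
    # Clamp to the representable int8 range.
    q = max(qmin, min(qmax, q))
    # Dequantize: go straight back to a float.
    return (q - zero_point) * scale


# Illustrative values, not taken from any real model.
weights = [0.5, -1.27, 0.003, 2.0]
scale = 0.01

fq = [fake_quantize(w, scale) for w in weights]
for w, v in zip(weights, fq):
    print(f"{w:>7} -> {v:.2f}")
```

Small values collapse to zero and values outside the int8 range saturate, which is the error the model learns to tolerate during quantization aware training. The weights themselves are still stored and updated as floats.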

Thanks a lot for the quick reply.
I was just wondering whether it is supported by the OpenNMT training script, e.g. by setting -model_dtype to int8.

No, it is not. Feel free to contribute if you feel like it!