This discussion started here ( https://github.com/OpenNMT/OpenNMT-py/issues/1418 ) but belongs in the forum.
I’m pretty new to pytorch, so I’d like to understand what work would be involved in implementing the feature of freezing the encoder side of a trained transformer and continuing training only on the decoder side. Also, does anyone know of a paper where this has been done successfully? The goal is to “fine-tune” a trained model on a new dataset, roughly analogous to what’s done with pretraining on ImageNet.
In any case, it was mentioned in the issue above that one would need to “implement a detach_state() function in the encoder in the same way as it is for the decoder (used in the case of bptt)”
When I looked into pytorch weight-freezing, I saw that setting requires_grad=False on a parameter prevents gradients from being computed for it. As a pytorch noob I don’t know why one would need to use “detach” instead.
Can someone help explain why requires_grad=False doesn’t do the job here, and what this “detach” thing is doing? I read that requires_grad=False should use less compute, so it sounds preferable; we don’t want the encoder gradients at all if we’re only training the decoder.
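To make the question concrete, here’s a toy sketch of the two mechanisms as I understand them. The encoder/decoder modules are just placeholder Linear layers, not the OpenNMT-py API:

```python
import torch
import torch.nn as nn

# Toy stand-ins for an encoder and decoder (hypothetical, for illustration).
encoder = nn.Linear(4, 4)
decoder = nn.Linear(4, 2)

# Mechanism 1: freeze the encoder's parameters. Autograd computes no
# gradients for them, so an optimizer would never update them.
for p in encoder.parameters():
    p.requires_grad = False

x = torch.randn(3, 4)
out = decoder(encoder(x))
out.sum().backward()
assert encoder.weight.grad is None       # frozen: no gradient stored
assert decoder.weight.grad is not None   # decoder still trains

# Mechanism 2: detach the encoder output. Backprop stops at the detach
# point, so nothing flows back into the encoder even though its
# parameters still have requires_grad=True.
encoder2 = nn.Linear(4, 4)
hidden = encoder2(x).detach()
decoder(hidden).sum().backward()
assert encoder2.weight.grad is None      # cut off by detach
```

If I’m reading this right, both approaches keep the encoder from updating, but requires_grad=False also skips the gradient computation inside the encoder, while detach only severs the graph at the encoder/decoder boundary.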
Furthermore, to be clear I’m assuming we can do this:
- Train a transformer model, save weights to disk
- Create a new transformer model with the same architecture, except that the encoder parts are frozen somehow
- Load the weights for this model from disk (saved in step 1) and train on a new dataset.
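The three steps above, as I imagine them in plain pytorch (a minimal sketch with a toy model, not the actual OpenNMT-py code):

```python
import torch
import torch.nn as nn

class TinyTransformer(nn.Module):
    """Toy stand-in with an encoder and decoder part (hypothetical)."""
    def __init__(self):
        super().__init__()
        self.encoder = nn.Linear(8, 8)
        self.decoder = nn.Linear(8, 4)

    def forward(self, x):
        return self.decoder(self.encoder(x))

# Step 1: train (omitted here) and save the weights to disk.
model = TinyTransformer()
torch.save(model.state_dict(), "model.pt")

# Steps 2-3: build a model with the same architecture, load the saved
# weights, then freeze the encoder before training on the new dataset.
model2 = TinyTransformer()
model2.load_state_dict(torch.load("model.pt"))
for p in model2.encoder.parameters():
    p.requires_grad = False

# Pass only the still-trainable parameters to the optimizer.
optimizer = torch.optim.SGD(
    (p for p in model2.parameters() if p.requires_grad), lr=0.1
)
```

After this, a normal training loop should update only the decoder: the encoder weights stay exactly as loaded.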
Can someone help correct any of my misconceptions here so I can understand how to implement the weight-freezing correctly? Thanks!