Transformer: Freezing Encoder while training Decoder

This discussion started in an OpenNMT-py GitHub issue (https://github.com/OpenNMT/OpenNMT-py/issues/1418), but it belongs in the forum.

I’m pretty new to PyTorch, so I’d like to understand what work would be involved in implementing a feature that freezes the encoder side of a trained Transformer and continues training only the decoder side. Also, does anyone know of a paper where this has been done successfully? The goal is to “fine-tune” a trained model on a new dataset, roughly analogous to what’s done with ImageNet pretraining.

In any case, the issue above suggested that one would need to “implement a detach_state() function in the encoder in the same way as it is for the decoder (used in the case of bptt)”, where BPTT is backpropagation through time.

When I looked into weight freezing in PyTorch, I saw that setting requires_grad=False on a model's parameters prevents gradients from being computed for them. As a PyTorch noob, I don’t see why one would need to use “detach” instead.
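A minimal sketch of the requires_grad approach, using plain torch.nn.Transformer as a stand-in for the real model (OpenNMT-py has its own model classes; I'm only assuming they have a comparable encoder/decoder split). Freezing the encoder amounts to switching requires_grad off on its parameters and giving the optimizer only the parameters that are still trainable. The sizes, random inputs and dummy loss are arbitrary choices for illustration:

```python
import torch
import torch.nn as nn

# Toy Transformer standing in for a trained model; sizes are illustrative only.
model = nn.Transformer(d_model=32, nhead=4, num_encoder_layers=2,
                       num_decoder_layers=2, dim_feedforward=64,
                       batch_first=True)

# Freeze every encoder parameter: autograd will not compute gradients for them.
for p in model.encoder.parameters():
    p.requires_grad_(False)

# Hand the optimizer only the parameters that are still trainable.
trainable = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.Adam(trainable, lr=1e-4)

src = torch.randn(3, 7, 32)  # (batch, src_len, d_model), random stand-in data
tgt = torch.randn(3, 5, 32)  # (batch, tgt_len, d_model), random stand-in data
out = model(src, tgt)
loss = out.pow(2).mean()     # dummy loss, just to have something to backprop
optimizer.zero_grad()
loss.backward()
optimizer.step()

# Frozen encoder parameters never receive a .grad; decoder parameters do.
print(all(p.grad is None for p in model.encoder.parameters()))      # True
print(all(p.grad is not None for p in model.decoder.parameters()))  # True
```

One related detail: requires_grad=False does not turn off dropout inside the encoder. If the frozen encoder should behave as it does at inference time, calling model.encoder.eval() after model.train() would do that.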

Can someone help explain why requires_grad=False doesn’t do the job here, and what this “detach” thing is doing? I read that requires_grad=False should use less compute, so it sounds preferable: we don’t want the encoder gradients at all if we’re only training the decoder.
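As far as I understand it, detach() acts on a tensor rather than on parameters: it returns a tensor that shares the same data but is cut out of the autograd graph, so backward() stops there. In truncated BPTT the decoder state is detached for that reason, so gradients don't flow back into earlier chunks; applied to the encoder output, the same trick keeps gradients out of the encoder. The sketch below uses plain torch.nn.Transformer again (it is not OpenNMT-py's detach_state(), whose exact form I'm not assuming) and compares detaching the encoder output with running the encoder under torch.no_grad():

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Toy Transformer standing in for the real model; sizes are illustrative only.
model = nn.Transformer(d_model=32, nhead=4, num_encoder_layers=2,
                       num_decoder_layers=2, dim_feedforward=64,
                       batch_first=True)
src = torch.randn(3, 7, 32)  # random stand-in source batch
tgt = torch.randn(3, 5, 32)  # random stand-in target batch

# Option A: detach the encoder output ("memory"). The encoder forward pass
# still records an autograd graph, but backward stops at the detached tensor.
memory = model.encoder(src)
out_a = model.decoder(tgt, memory.detach())
out_a.pow(2).mean().backward()  # dummy loss
print(all(p.grad is None for p in model.encoder.parameters()))  # True
model.zero_grad(set_to_none=True)

# Option B: run the encoder under torch.no_grad(). No graph is recorded for
# the encoder at all, so its activations are not kept for backward.
with torch.no_grad():
    memory = model.encoder(src)
print(memory.requires_grad)  # False
out_b = model.decoder(tgt, memory)
out_b.pow(2).mean().backward()  # dummy loss
print(all(p.grad is None for p in model.encoder.parameters()))  # True
```

Both options keep the encoder from being updated. The difference is cost: with detach() the encoder's forward pass still builds a graph and holds on to its activations until memory is released, whereas torch.no_grad(), or requires_grad=False on every encoder parameter when the encoder input itself doesn't require gradients, avoids that bookkeeping.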

Furthermore, to be clear, I’m assuming we can do this:

  1. Train a Transformer model and save its weights to disk.
  2. Create a new Transformer model with the same architecture, except that the encoder parts are frozen somehow.
  3. Load the weights saved in step 1 into this new model and train it on a new dataset.
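The three steps above could look roughly like this. It is a sketch under several assumptions: torch.nn.Transformer stands in for the real model, a temporary file stands in for the checkpoint path (this is not OpenNMT-py's own checkpoint format), and random tensors with a dummy loss stand in for real training data:

```python
import os
import tempfile

import torch
import torch.nn as nn


def build_model() -> nn.Transformer:
    # Same architecture for pretraining and fine-tuning (sizes are illustrative).
    return nn.Transformer(d_model=32, nhead=4, num_encoder_layers=2,
                          num_decoder_layers=2, dim_feedforward=64,
                          batch_first=True)


# Step 1: "train" a model (training loop omitted) and save its weights.
pretrained = build_model()
ckpt_path = os.path.join(tempfile.mkdtemp(), "pretrained.pt")  # hypothetical path
torch.save(pretrained.state_dict(), ckpt_path)

# Step 2: create a fresh model with the same architecture and freeze its encoder.
model = build_model()
for p in model.encoder.parameters():
    p.requires_grad_(False)

# Step 3: load the saved weights, then fine-tune only the decoder.
model.load_state_dict(torch.load(ckpt_path))
optimizer = torch.optim.Adam(
    (p for p in model.parameters() if p.requires_grad), lr=1e-4)

# One fine-tuning step on a random stand-in batch from the "new dataset".
src, tgt = torch.randn(3, 7, 32), torch.randn(3, 5, 32)
loss = model(src, tgt).pow(2).mean()  # dummy loss
optimizer.zero_grad()
loss.backward()
optimizer.step()
```

Since load_state_dict() copies values into the existing parameters without touching their requires_grad flags, it shouldn't matter whether the encoder is frozen before or after the weights are loaded.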

Can someone help correct any of my misconceptions here so I can understand how to implement the weight-freezing correctly? Thanks!

Did you find this one?

https://www.aclweb.org/anthology/W18-6313

Thanks. I was most interested in a paper exploring this with the Transformer model in particular, but this paper is useful as well, since similar results may plausibly hold for the Transformer.

@eraoul have you been able to do this experiment?