Successful Domain Adaptation with OpenNMT-py

I managed to achieve good results retraining a model for Domain Adaptation using OpenNMT-py. As this was done with the help of colleagues here, I am elaborating on the path I took; hopefully, it will be useful for others.

The base model is a vertical (in-domain) model trained on approx. 13 million segments, which was then retrained on approx. 123,000 institution-specific segments. Language pair: French to English. Tokenization: whole words (no subword segmentation).


First Step: Training the base model

1- Preprocessing was with the default options of OpenNMT-py.
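
For reference, a default preprocessing call looks roughly like this (the training and validation file names here are placeholders, not my actual files; -save_data basedata matches the -data name used in training below):

python3 preprocess.py -train_src basedata.fr -train_tgt basedata.en -valid_src basevalid.fr -valid_tgt basevalid.en -save_data basedata -log_file preprocess-base.log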

2- Training was with the recommended Transformer model options, except that I had only 2 GPUs.

CUDA_VISIBLE_DEVICES=0,1 python3 train.py -data basedata -save_model basemodel -layers 6 -rnn_size 512 -word_vec_size 512 -transformer_ff 2048 -heads 8 -encoder_type transformer -decoder_type transformer -position_encoding -train_steps 200000 -max_generator_batches 2 -dropout 0.1 -batch_size 4096 -batch_type tokens -normalization tokens -accum_count 2 -optim adam -adam_beta2 0.998 -decay_method noam -warmup_steps 8000 -learning_rate 2 -max_grad_norm 0 -param_init 0 -param_init_glorot -label_smoothing 0.1 -valid_steps 10000 -save_checkpoint_steps 10000 -log_file train.log -world_size 2 -gpu_ranks 0 1 ; sudo shutdown


Second Step: Retraining with the new data

1- Preprocessing:

I passed the basedata.vocab.pt file to the -src_vocab parameter. There is no need for -tgt_vocab, but use -share_vocab as well (reference). In fact, only -src_vocab supports *.vocab.pt files; passing the file to -tgt_vocab will cause an error.

I also used -src_seq_length 200 because I have long sentences, but you can keep the default (50) or whatever suits your data.

python3 preprocess.py -train_src newdata.fr -train_tgt newdata.en -save_data newdata -src_seq_length 200 -tgt_seq_length 200 -src_vocab basedata.vocab.pt -dynamic_dict -share_vocab -log_file preprocess-new.log

2- Retraining

I used -train_from with the last checkpoint of the base model and retrained the model for an extra 10,000 steps. Note that the old model was trained for 200,000 steps, so to get an extra 10,000 steps in retraining, -train_steps must be set to 210,000, because retraining picks up the previous training state and arguments unless you use the -reset_optim option.

Note also that the second machine has 8 GPUs, so with the same batch size, 10,000 steps on 8 GPUs are roughly equivalent to 40,000 steps on 2 GPUs (reference). Calculating steps was tricky in the first place because the batch type here depends on tokens, not sentences, and there are multiple GPUs (reference). I used the sequence length as a reference (actually half of it, because not many sentences reach 200 tokens). This is not very accurate since it is a maximum, not an exact count, but it helps to understand what one is doing. The ultimate purpose was to retrain on the new data long enough for the model to learn the new vocabulary (reference).
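
To make the numbers concrete, a very rough calculation along these lines (taking ~100 tokens per segment, i.e. half of the 200-token maximum) gives:

tokens per step ≈ 4,096 (batch_size) × 2 (accum_count) × 8 (GPUs) = 65,536
segments per step ≈ 65,536 / 100 ≈ 655
segments seen in 10,000 steps ≈ 655 × 10,000 ≈ 6.5 million
passes over the 123,000 new segments ≈ 6,500,000 / 123,000 ≈ 53

These are only orders of magnitude, of course, but they show that 10,000 steps already correspond to dozens of passes over the in-domain data.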

CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python3 train.py -data newdata -train_from basemodel_step_200000.pt -save_model newmodel -layers 6 -rnn_size 512 -word_vec_size 512 -transformer_ff 2048 -heads 8 -encoder_type transformer -decoder_type transformer -position_encoding -train_steps 210000 -max_generator_batches 2 -dropout 0.1 -batch_size 4096 -batch_type tokens -normalization tokens -accum_count 2 -optim adam -adam_beta2 0.998 -decay_method noam -warmup_steps 8000 -learning_rate 2 -max_grad_norm 0 -param_init 0 -param_init_glorot -label_smoothing 0.1 -save_checkpoint_steps 10000 -log_file retrain.log -world_size 8 -gpu_ranks 0 1 2 3 4 5 6 7 ; sudo shutdown


Outcomes

When I started retraining, I was not sure whether the model would only learn new vocabulary or would also replace existing vocabulary choices, because it was often said that OpenNMT-py is not the best for retraining as it does not have an update-vocabulary option.

However, the outcome is very promising: the model learnt to use the institution's terminology. Here is one simple example to give an idea: the base model translates the French words “président” and “vice-président” as “president” and “vice-president” respectively, while the retrained model translates them as “chairperson” and “deputy chairperson” respectively, which are the English terms adopted in the institution.


Further Research

The issue I noticed, though, is that some sentences are translated badly by the retrained model (e.g. unidiomatic structure or UNKs) while they are translated better by the base model. I am not sure why; I wonder if this could be due to an excessive number of retraining steps, so I have to test this. Still, as a workaround, I could offer translations from both models and let the user select, or automatically select the best translation based on automatic evaluation.

So that is it. If you have questions or suggestions, please let me know.

Best regards,
Yasmin


Nice post!

This may be a case of “catastrophic forgetting”. Usually the retraining should be done on a combination of in-domain and generic data.


Thanks, Guillaume!

Ah! I will try this. Many thanks!

Kind regards,
Yasmin

Could you please elaborate on this point, @guillaumekln?

Generally, if you train a model on a task A and then continue training on a task B, the model will progressively forget task A as time passes. That is why you probably want to use data from both tasks A and B (here generic and in-domain), so that it can discover the new task B while still consolidating its representation of task A.


One thing to try (whether training with two corpora from scratch or retraining) is giving the in-domain corpus a higher weight than the generic corpus. OpenNMT-py supports weighting two corpora differently. (Instructions)
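
If you want to try this, a minimal sketch of the two commands could look like the following. The option names (-train_ids in preprocess.py, and -data_ids with -data_weights in train.py) reflect my reading of the instructions linked above, and the corpus file names are only placeholders, so please check them against your OpenNMT-py version:

python3 preprocess.py -train_src generic.fr indomain.fr -train_tgt generic.en indomain.en -train_ids generic indomain -save_data mixeddata

python3 train.py -data mixeddata -data_ids generic indomain -data_weights 1 9 -save_model mixedmodel

(The train.py call would of course also take the usual Transformer options shown in the commands above.) With -data_weights 1 9, if I understand correctly, examples from the in-domain corpus are sampled nine times as often as examples from the generic corpus.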


Another method to try for Domain Adaptation is the one adopted in the paper Fast Domain Adaptation for Neural Machine Translation, which is to “ensemble” the baseline model trained on generic data and the continued model retrained on in-domain data. “Ensemble” here simply means combining models during translation (not combining data during training). The paper does not mention, though, which ensembling technique was used (unless I missed something), as there are several methods of ensembling.

Anyhow, I checked OpenNMT-py and it supports ensemble decoding. (Reference)
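
In practice, this seems to simply mean passing more than one checkpoint to the -model option of translate.py, along these lines (the test file names are placeholders):

python3 translate.py -model basemodel_step_200000.pt newmodel_step_210000.pt -src test.fr -output test.ensemble.en -gpu 0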

I have also checked the OpenNMT-py code, and it has an ensemble.py file that starts with the following:

Ensemble decoding.
Decodes using multiple models simultaneously,
combining their prediction distributions by averaging.
All models in the ensemble must share a target vocabulary.

I tried ensemble decoding in OpenNMT-py with my two models (described in the original post) and got this error: “AssertionError: Ensemble models must use the same preprocessed data”. Apparently, I need to apply something like this solution, or continue training on the new dataset without creating a new vocabulary, i.e. by reusing the old vocabulary file of the original dataset. In my case, the original dataset most likely covers all the vocabulary; the difference in the specialized dataset is mainly a matter of terminology choice.

I will send updates when I have any.

Kind regards,
Yasmin


Hi Yasmin,
Yes, for ensembling you need to have the same vocab file for both models.
When using BPE/subwords it should not really be an issue, but ideally it is better to train your BPE, and then the vocab, on the combined data.
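
For example, with a tool like SentencePiece, training one joint BPE model and vocab on the concatenated generic and in-domain data could look like this (just a sketch; the file names are placeholders):

spm_train --input=generic.fr,generic.en,indomain.fr,indomain.en --model_prefix=joint_bpe --model_type=bpe --vocab_size=32000
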
Good luck.

Many thanks, Vincent, for the information and advice; I will try this in my next model.

Have a great weekend!

Kind regards,
Yasmin

Hi,
Any idea why I get these BLEU results when continuing training with mixed out-of-domain and in-domain data? The out-of-domain data is a small subset of the original out-of-domain data, included to avoid catastrophic forgetting.

[Image: BLEU scores at each checkpoint of the continued training]

It seems the best result is obtained directly at the first checkpoint, after just 10,000 training steps from the previous best out-of-domain model (24,000 steps). Does this mean there is enough information to adapt the model in just 10,000 steps, so there is no need for further training?

Dear Anderleich,

This makes perfect sense: you do not want to train the model too much on the small dataset; rather, you want to retain what was learnt from the large (generic) data, so that you hopefully get translations that handle both the generic parts of the text and the in-domain vocabulary.

Kind regards,
Yasmin

Thanks @ymoslem !
What would be the optimal percentage of new data when mixing out-of-domain and in-domain data?

Thanks to everyone who took part here! I would just like to update the discussion with this useful paper about “mixed fine-tuning”.