I managed to get successful results retraining a model for Domain Adaptation with OpenNMT-py. As this was done with the help of colleagues here, I am detailing the path I took; hopefully, it will be useful for others.
The base model is a vertical (in-domain) model trained on approx. 13 million segments; it was then retrained on approx. 123,000 institution-specific segments. Language Pair: French-English. Tokenization: complete words.
First Step: Training the base model
1- Preprocessing was with the default options of OpenNMT-py.
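Something along these lines (the file names here are just placeholders; only -save_data needs to match the -data value used for training, everything else is left at its default value):
python3 preprocess.py -train_src basedata.fr -train_tgt basedata.en -valid_src basevalid.fr -valid_tgt basevalid.en -save_data basedata -log_file preprocess-base.log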
2- Training was with the recommended Transformer model options, except that I had only 2 GPUs.
CUDA_VISIBLE_DEVICES=0,1 python3 train.py -data basedata -save_model basemodel -layers 6 -rnn_size 512 -word_vec_size 512 -transformer_ff 2048 -heads 8 -encoder_type transformer -decoder_type transformer -position_encoding -train_steps 200000 -max_generator_batches 2 -dropout 0.1 -batch_size 4096 -batch_type tokens -normalization tokens -accum_count 2 -optim adam -adam_beta2 0.998 -decay_method noam -warmup_steps 8000 -learning_rate 2 -max_grad_norm 0 -param_init 0 -param_init_glorot -label_smoothing 0.1 -valid_steps 10000 -save_checkpoint_steps 10000 -log_file train.log -world_size 2 -gpu_ranks 0 1 ; sudo shutdown
Second Step: Retraining with the new data
1- Preprocessing:
I passed the basedata.vocab.pt file to the parameter -src_vocab. There is no need for -tgt_vocab, but use -share_vocab as well (reference). Actually, only -src_vocab supports *.vocab.pt files, and adding the file to -tgt_vocab will cause an error.
I also used -src_seq_length 200 because I have long sentences, but you can use the default (50) or whatever you need.
python3 preprocess.py -train_src newdata.fr -train_tgt newdata.en -save_data newdata -src_seq_length 200 -tgt_seq_length 200 -src_vocab basedata.vocab.pt -dynamic_dict -share_vocab -log_file preprocess-new.log
2- Retraining
I used -train_from with the last checkpoint of the base model (basemodel_step_200000.pt), retraining the model for an extra 10,000 steps. Note that the old model was trained for 200,000 steps; so to add the extra 10,000 steps in retraining, -train_steps must be set to 210,000, because retraining carries over the previous training state unless you use the option -reset_optim.
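For completeness, if you did want to discard the previous optimizer state, the flag would be added roughly like this (an untested sketch on my side; -reset_optim accepts the values none, states, keep_states and all):
python3 train.py -data newdata -train_from basemodel_step_200000.pt -reset_optim all -save_model newmodel [remaining options as in the full command below]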
Note also that the second machine has 8 GPUs; so with the same batch size, 10,000 steps on 8 GPUs are roughly equivalent to 40,000 steps on 2 GPUs (reference). Calculating steps in the first place was tricky, because the batch type here counts tokens rather than sentences and there are multiple GPUs (reference). I used the sequence length as a reference (actually half of it, since not many sentences reach 200 tokens); this is not very accurate, as it is a maximum rather than an exact figure, but it helps one understand what one is doing. The ultimate purpose was to retrain on the new data long enough to learn the new vocabulary (reference).
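To make the step arithmetic concrete, here is a rough back-of-the-envelope estimate in Python; the 100-token average is my own assumption, and this ignores padding and the exact batching details:
# rough estimate of how much new data 10,000 steps on 8 GPUs cover
tokens_per_step = 4096 * 2 * 8              # batch_size * accum_count * number of GPUs
segments_per_step = tokens_per_step / 100   # assuming ~100 tokens per segment on average
total_segments = 10000 * segments_per_step  # segment samples seen during retraining
passes_over_new_data = total_segments / 123000
print(round(segments_per_step), round(total_segments), round(passes_over_new_data))
# roughly 655 segments per step, ~6.5 million samples, i.e. ~53 passes over the 123,000 new segments
So even 10,000 steps already mean many passes over the new data.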
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python3 train.py -data newdata -train_from basemodel_step_200000.pt -save_model newmodel -layers 6 -rnn_size 512 -word_vec_size 512 -transformer_ff 2048 -heads 8 -encoder_type transformer -decoder_type transformer -position_encoding -train_steps 210000 -max_generator_batches 2 -dropout 0.1 -batch_size 4096 -batch_type tokens -normalization tokens -accum_count 2 -optim adam -adam_beta2 0.998 -decay_method noam -warmup_steps 8000 -learning_rate 2 -max_grad_norm 0 -param_init 0 -param_init_glorot -label_smoothing 0.1 -save_checkpoint_steps 10000 -log_file retrain.log -world_size 8 -gpu_ranks 0 1 2 3 4 5 6 7 ; sudo shutdown
Outcomes
When I started retraining, I was not sure whether the model would only learn new vocabulary or would also replace existing vocabulary, because it is often said that OpenNMT-py is not the best choice for retraining, as it does not have an option to update the vocabulary.
However, the outcome is very promising. The model learnt to use the institution's terminology. Here is one simple example to give an idea: the base model translates the French words “président” and “vice-président” as “president” and “vice-president” respectively, while the retrained model translates them as “chairperson” and “deputy chairperson”, which are the English terms adopted in the institution.
Further Research
The issue I noticed, though, is that some sentences are translated badly by the retrained model (e.g. unidiomatic structure or UNKs) while they are translated better by the base model. I am not sure why; I wonder if this could be due to an excessive number of retraining steps, so I have to test this. Still, as a workaround, I believe I can offer translations from the two models and let the user select, or automatically pick the better translation based on automatic evaluation.
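As a first step towards that workaround, translating the same input with both checkpoints is straightforward; something like the following should work (the checkpoint and file names here are illustrative, following the -save_model prefixes above):
python3 translate.py -model basemodel_step_200000.pt -src input.fr -output output-base.en -gpu 0
python3 translate.py -model newmodel_step_210000.pt -src input.fr -output output-new.en -gpu 0
The two output files can then be shown side by side, or scored against a reference set to choose between the models per sentence.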
So that is it. If you have questions or suggestions, please let me know.
Best regards,
Yasmin