Fine-tuning bigger models with LoRA (Low-Rank Adaptation) in OpenNMT-py

I got a huge memory reduction by using fused Adam (fusedadam) and LoRA: I fine-tuned a 3.3B model on a 32 GB GPU. I could not do this with plain Adam or without LoRA; I needed both.

The memory used by the loaded model alone drops from 13,664 MiB to 7,830 MiB when I use LoRA and fused Adam.

I had to install Apex following its GitHub guide, modifying the pip install command like this:

pip install -v --disable-pip-version-check --no-cache-dir --global-option="--deprecated_fused_adam" --global-option="--cpp_ext" --global-option="--cuda_ext"  ./

If you do not add

--global-option="--deprecated_fused_adam"

it will not work with OpenNMT-py. I installed OpenNMT-py from master this morning.
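After building, a quick sanity check is to confirm that the compiled extension can be found by Python. This sketch assumes (not verified here) that the --deprecated_fused_adam option builds a module named fused_adam_cuda and that this is the module OpenNMT-py's fusedadam optimizer loads; the helper name is hypothetical.

```python
import importlib.util


def extension_available(name="fused_adam_cuda"):
    """Return True if a module with this name can be found on the import path.

    The default name is an assumption about what Apex's deprecated fused Adam
    build installs; adjust it if your build produces a different module."""
    return importlib.util.find_spec(name) is not None


if __name__ == "__main__":
    print("fused_adam_cuda available:", extension_available())
```

If this prints False, the Apex build most likely skipped the deprecated fused Adam extension.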

In a previous experiment with the 600M model, training memory went from 17,070 MiB with plain Adam to 10,460 MiB when I switched to LoRA+Adam. The loaded model alone went from 3,210 MiB with plain Adam to 3,214 MiB with LoRA+Adam. With fused Adam, it went from 2,020 MiB with plain Adam to 8,682 MiB when I switched to LoRA+Adam.

I have been using these LoRA options:

# LoRA
lora_layers: ['linear_values', 'linear_query', 'linnear_keys', 'final_linear']
lora_rank: 4
lora_dropout: 0.0
lora_alpha: 1
lora_embedding: false
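To see why LoRA shrinks the optimizer footprint so much, the number of trainable parameters these options add can be estimated by hand. This sketch assumes the NLLB-200 3.3B layout from the training config later in this thread (hidden_size 2048, 24 encoder and 24 decoder layers), square 2048 x 2048 attention projections, one attention block per encoder layer and two (self_attn and context_attn) per decoder layer as in the key names of the tracebacks, and fp32 Adam moments. Layer names are assumed to be matched literally, so the sketch also shows what the 'linnear_keys' spelling in this config would mean. The helper names are hypothetical, not OpenNMT-py functions.

```python
# Hypothetical helpers (not part of OpenNMT-py) estimating LoRA trainable
# parameters and Adam optimizer state for the NLLB-200 3.3B layout.
ATTN_PROJECTIONS = {"linear_keys", "linear_values", "linear_query", "final_linear"}


def lora_trainable_params(lora_layers, hidden_size=2048, rank=4,
                          enc_layers=24, dec_layers=24):
    """Count LoRA parameters, assuming every targeted projection is a square
    hidden_size x hidden_size linear layer. Each adapter adds A (rank x d_in)
    and B (d_out x rank). Names are matched literally."""
    matched = ATTN_PROJECTIONS.intersection(lora_layers)
    per_adapter = rank * (hidden_size + hidden_size)
    # One attention block per encoder layer (self_attn), two per decoder
    # layer (self_attn and context_attn).
    attention_blocks = enc_layers + 2 * dec_layers
    return len(matched) * attention_blocks * per_adapter


def adam_state_bytes(num_trainable):
    """Adam keeps two moment estimates per trainable parameter (assumed fp32)."""
    return num_trainable * 2 * 4


if __name__ == "__main__":
    as_written = ["linear_values", "linear_query", "linnear_keys", "final_linear"]
    all_four = ["linear_values", "linear_query", "linear_keys", "final_linear"]
    for layers in (as_written, all_four):
        n = lora_trainable_params(layers)
        print(layers, n, f"{adam_state_bytes(n) / 2**20:.1f} MiB of Adam state")
```

Under these assumptions, all four projections give 4,718,592 trainable parameters (36 MiB of Adam moments), and only three of them match if 'linnear_keys' is not recognized (3,538,944 parameters, 27 MiB). By comparison, fp32 Adam moments for all ~3.3B parameters would need roughly 3.3e9 x 8 bytes, about 26 GB, on their own.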

Martin,

I am still working on a PR because there is a bug in the original LoRA implementation when merging weights.
I should be done tonight.
I am testing everything with Llama 7B which should be the same procedure.

Training went fine, but I can't run inference with it. Do you think I need to repeat the training?

No, but you will have to run the modified lora_weights.py script to merge or concat the LoRA weights.

Martin,

I just committed.
If your last training went through you should have:

  1. your base model (in the OpenNMT-py format) before finetuning

  2. your LoRA weights (a much smaller file)

If you git pull master you should be able to run:

python tools/lora_weights.py --action merge --base_model nllb-onmt.pt --lora_weights your_lora.pt --output nllb-finetuned.pt

Then you can try to run inference with the new nllb-finetuned.pt

Let me know how it goes.
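What the merge action computes can be pictured with a toy example. This plain-Python sketch assumes the standard LoRA formulation from the LoRA paper: a frozen weight W plus a low-rank update B A scaled by lora_alpha / lora_rank. Whether tools/lora_weights.py uses exactly this scaling is an assumption, and the function names are hypothetical. It checks that running the adapter separately and folding it into W give the same output, which is why a merged checkpoint can be used for inference like a regular model.

```python
# Hypothetical helpers illustrating the standard LoRA merge
# W' = W + (alpha / r) * B @ A, written with plain lists (no torch needed).


def matmul(a, b):
    """Multiply matrix a (n x k) by matrix b (k x m)."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b)))
             for j in range(len(b[0]))] for i in range(len(a))]


def matvec(m, x):
    """Multiply matrix m by vector x."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in m]


def merge_lora(weight, lora_a, lora_b, alpha, rank):
    """Fold the low-rank update into the frozen weight (d_out x d_in)."""
    scale = alpha / rank
    delta = matmul(lora_b, lora_a)  # (d_out x r) @ (r x d_in) -> d_out x d_in
    return [[w + scale * d for w, d in zip(w_row, d_row)]
            for w_row, d_row in zip(weight, delta)]


def lora_forward(weight, lora_a, lora_b, alpha, rank, x):
    """Unmerged forward pass: W x + (alpha / r) * B (A x)."""
    scale = alpha / rank
    base = matvec(weight, x)
    update = matvec(lora_b, matvec(lora_a, x))
    return [b + scale * u for b, u in zip(base, update)]


if __name__ == "__main__":
    W = [[1, 0, 2], [0, 1, 1]]  # frozen 2x3 weight
    A = [[1, 2, 0]]             # rank-1 down-projection (1x3)
    B = [[0.5], [1.0]]          # rank-1 up-projection (2x1)
    x = [1, 1, 1]
    merged = merge_lora(W, A, B, alpha=2, rank=1)
    print(merged)                                           # [[2.0, 2.0, 2.0], [2.0, 5.0, 1.0]]
    print(lora_forward(W, A, B, 2, 1, x), matvec(merged, x))  # [6.0, 8.0] [6.0, 8.0]
```

The toy values are arbitrary; with the options used in this thread (lora_alpha: 1, lora_rank: 4) the scale would be 0.25 under the same assumption.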

After doing this, I get the following error when I run inference with the new nllb-finetuned.pt:

Traceback (most recent call last):
  File "/home/m.barroso/anaconda3/envs/nllb_3_3B/bin/onmt_translate", line 33, in <module>
    sys.exit(load_entry_point('OpenNMT-py', 'console_scripts', 'onmt_translate')())
  File "/home/m.barroso/OpenNMT_nllb/OpenNMT-py/onmt/bin/translate.py", line 60, in main
    translate(opt)
  File "/home/m.barroso/OpenNMT_nllb/OpenNMT-py/onmt/bin/translate.py", line 23, in translate
    translator = build_translator(opt, logger=logger,
  File "/home/m.barroso/OpenNMT_nllb/OpenNMT-py/onmt/translate/translator.py", line 33, in build_translator
    vocabs, model, model_opt = load_test_model(opt)
  File "/home/m.barroso/OpenNMT_nllb/OpenNMT-py/onmt/model_builder.py", line 171, in load_test_model
    model = build_base_model(model_opt, vocabs, checkpoint)
  File "/home/m.barroso/OpenNMT_nllb/OpenNMT-py/onmt/model_builder.py", line 402, in build_base_model
    model.load_state_dict(checkpoint['model'],
  File "/home/m.barroso/anaconda3/envs/nllb_3_3B/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1671, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for NMTModel:
	Missing key(s) in state_dict: "encoder.transformer.0.self_attn.linear_keys.bias", "encoder.transformer.0.self_attn.linear_values.bias", "encoder.transformer.0.self_attn.linear_query.bias", "encoder.transformer.0.self_attn.final_linear.bias", "encoder.transformer.1.self_attn.linear_keys.bias", "encoder.transformer.1.self_attn.linear_values.bias", "encoder.transformer.1.self_attn.linear_query.bias", "encoder.transformer.1.self_attn.final_linear.bias", "encoder.transformer.2.self_attn.linear_keys.bias", "encoder.transformer.2.self_attn.linear_values.bias", "encoder.transformer.2.self_attn.linear_query.bias", "encoder.transformer.2.self_attn.final_linear.bias", "encoder.transformer.3.self_attn.linear_keys.bias", "encoder.transformer.3.self_attn.linear_values.bias", "encoder.transformer.3.self_attn.linear_query.bias", "encoder.transformer.3.self_attn.final_linear.bias", "encoder.transformer.4.self_attn.linear_keys.bias", "encoder.transformer.4.self_attn.linear_values.bias", "encoder.transformer.4.self_attn.linear_query.bias", "encoder.transformer.4.self_attn.final_linear.bias", "encoder.transformer.5.self_attn.linear_keys.bias", "encoder.transformer.5.self_attn.linear_values.bias", "encoder.transformer.5.self_attn.linear_query.bias", "encoder.transformer.5.self_attn.final_linear.bias", "encoder.transformer.6.self_attn.linear_keys.bias", "encoder.transformer.6.self_attn.linear_values.bias", "encoder.transformer.6.self_attn.linear_query.bias", "encoder.transformer.6.self_attn.final_linear.bias", "encoder.transformer.7.self_attn.linear_keys.bias", "encoder.transformer.7.self_attn.linear_values.bias", "encoder.transformer.7.self_attn.linear_query.bias", "encoder.transformer.7.self_attn.final_linear.bias", "encoder.transformer.8.self_attn.linear_keys.bias", "encoder.transformer.8.self_attn.linear_values.bias", "encoder.transformer.8.self_attn.linear_query.bias", "encoder.transformer.8.self_attn.final_linear.bias", "encoder.transformer.9.self_attn.linear_keys.bias", 
"encoder.transformer.9.self_attn.linear_values.bias", "encoder.transformer.9.self_attn.linear_query.bias", "encoder.transformer.9.self_attn.final_linear.bias", "encoder.transformer.10.self_attn.linear_keys.bias", "encoder.transformer.10.self_attn.linear_values.bias", "encoder.transformer.10.self_attn.linear_query.bias", "encoder.transformer.10.self_attn.final_linear.bias", "encoder.transformer.11.self_attn.linear_keys.bias", "encoder.transformer.11.self_attn.linear_values.bias", "encoder.transformer.11.self_attn.linear_query.bias", "encoder.transformer.11.self_attn.final_linear.bias", "encoder.transformer.12.self_attn.linear_keys.bias", "encoder.transformer.12.self_attn.linear_values.bias", "encoder.transformer.12.self_attn.linear_query.bias", "encoder.transformer.12.self_attn.final_linear.bias", "encoder.transformer.13.self_attn.linear_keys.bias", "encoder.transformer.13.self_attn.linear_values.bias", "encoder.transformer.13.self_attn.linear_query.bias", "encoder.transformer.13.self_attn.final_linear.bias", "encoder.transformer.14.self_attn.linear_keys.bias", "encoder.transformer.14.self_attn.linear_values.bias", "encoder.transformer.14.self_attn.linear_query.bias", "encoder.transformer.14.self_attn.final_linear.bias", "encoder.transformer.15.self_attn.linear_keys.bias", "encoder.transformer.15.self_attn.linear_values.bias", "encoder.transformer.15.self_attn.linear_query.bias", "encoder.transformer.15.self_attn.final_linear.bias", "encoder.transformer.16.self_attn.linear_keys.bias", "encoder.transformer.16.self_attn.linear_values.bias", "encoder.transformer.16.self_attn.linear_query.bias", "encoder.transformer.16.self_attn.final_linear.bias", "encoder.transformer.17.self_attn.linear_keys.bias", "encoder.transformer.17.self_attn.linear_values.bias", "encoder.transformer.17.self_attn.linear_query.bias", "encoder.transformer.17.self_attn.final_linear.bias", "encoder.transformer.18.self_attn.linear_keys.bias", "encoder.transformer.18.self_attn.linear_values.bias", 
"encoder.transformer.18.self_attn.linear_query.bias", "encoder.transformer.18.self_attn.final_linear.bias", "encoder.transformer.19.self_attn.linear_keys.bias", "encoder.transformer.19.self_attn.linear_values.bias", "encoder.transformer.19.self_attn.linear_query.bias", "encoder.transformer.19.self_attn.final_linear.bias", "encoder.transformer.20.self_attn.linear_keys.bias", "encoder.transformer.20.self_attn.linear_values.bias", "encoder.transformer.20.self_attn.linear_query.bias", "encoder.transformer.20.self_attn.final_linear.bias", "encoder.transformer.21.self_attn.linear_keys.bias", "encoder.transformer.21.self_attn.linear_values.bias", "encoder.transformer.21.self_attn.linear_query.bias", "encoder.transformer.21.self_attn.final_linear.bias", "encoder.transformer.22.self_attn.linear_keys.bias", "encoder.transformer.22.self_attn.linear_values.bias", "encoder.transformer.22.self_attn.linear_query.bias", "encoder.transformer.22.self_attn.final_linear.bias", "encoder.transformer.23.self_attn.linear_keys.bias", "encoder.transformer.23.self_attn.linear_values.bias", "encoder.transformer.23.self_attn.linear_query.bias", "encoder.transformer.23.self_attn.final_linear.bias", "decoder.transformer_layers.0.self_attn.linear_keys.bias", "decoder.transformer_layers.0.self_attn.linear_values.bias", "decoder.transformer_layers.0.self_attn.linear_query.bias", "decoder.transformer_layers.0.self_attn.final_linear.bias", "decoder.transformer_layers.0.context_attn.linear_keys.bias", "decoder.transformer_layers.0.context_attn.linear_values.bias", "decoder.transformer_layers.0.context_attn.linear_query.bias", "decoder.transformer_layers.0.context_attn.final_linear.bias", "decoder.transformer_layers.1.self_attn.linear_keys.bias", "decoder.transformer_layers.1.self_attn.linear_values.bias", "decoder.transformer_layers.1.self_attn.linear_query.bias", "decoder.transformer_layers.1.self_attn.final_linear.bias", "decoder.transformer_layers.1.context_attn.linear_keys.bias", 
"decoder.transformer_layers.1.context_attn.linear_values.bias", "decoder.transformer_layers.1.context_attn.linear_query.bias", "decoder.transformer_layers.1.context_attn.final_linear.bias", "decoder.transformer_layers.2.self_attn.linear_keys.bias", "decoder.transformer_layers.2.self_attn.linear_values.bias", "decoder.transformer_layers.2.self_attn.linear_query.bias", "decoder.transformer_layers.2.self_attn.final_linear.bias", "decoder.transformer_layers.2.context_attn.linear_keys.bias", "decoder.transformer_layers.2.context_attn.linear_values.bias", "decoder.transformer_layers.2.context_attn.linear_query.bias", "decoder.transformer_layers.2.context_attn.final_linear.bias", "decoder.transformer_layers.3.self_attn.linear_keys.bias", "decoder.transformer_layers.3.self_attn.linear_values.bias", "decoder.transformer_layers.3.self_attn.linear_query.bias", "decoder.transformer_layers.3.self_attn.final_linear.bias", "decoder.transformer_layers.3.context_attn.linear_keys.bias", "decoder.transformer_layers.3.context_attn.linear_values.bias", "decoder.transformer_layers.3.context_attn.linear_query.bias", "decoder.transformer_layers.3.context_attn.final_linear.bias", "decoder.transformer_layers.4.self_attn.linear_keys.bias", "decoder.transformer_layers.4.self_attn.linear_values.bias", "decoder.transformer_layers.4.self_attn.linear_query.bias", "decoder.transformer_layers.4.self_attn.final_linear.bias", "decoder.transformer_layers.4.context_attn.linear_keys.bias", "decoder.transformer_layers.4.context_attn.linear_values.bias", "decoder.transformer_layers.4.context_attn.linear_query.bias", "decoder.transformer_layers.4.context_attn.final_linear.bias", "decoder.transformer_layers.5.self_attn.linear_keys.bias", "decoder.transformer_layers.5.self_attn.linear_values.bias", "decoder.transformer_layers.5.self_attn.linear_query.bias", "decoder.transformer_layers.5.self_attn.final_linear.bias", "decoder.transformer_layers.5.context_attn.linear_keys.bias", 
"decoder.transformer_layers.5.context_attn.linear_values.bias", "decoder.transformer_layers.5.context_attn.linear_query.bias", "decoder.transformer_layers.5.context_attn.final_linear.bias", "decoder.transformer_layers.6.self_attn.linear_keys.bias", "decoder.transformer_layers.6.self_attn.linear_values.bias", "decoder.transformer_layers.6.self_attn.linear_query.bias", "decoder.transformer_layers.6.self_attn.final_linear.bias", "decoder.transformer_layers.6.context_attn.linear_keys.bias", "decoder.transformer_layers.6.context_attn.linear_values.bias", "decoder.transformer_layers.6.context_attn.linear_query.bias", "decoder.transformer_layers.6.context_attn.final_linear.bias", "decoder.transformer_layers.7.self_attn.linear_keys.bias", "decoder.transformer_layers.7.self_attn.linear_values.bias", "decoder.transformer_layers.7.self_attn.linear_query.bias", "decoder.transformer_layers.7.self_attn.final_linear.bias", "decoder.transformer_layers.7.context_attn.linear_keys.bias", "decoder.transformer_layers.7.context_attn.linear_values.bias", "decoder.transformer_layers.7.context_attn.linear_query.bias", "decoder.transformer_layers.7.context_attn.final_linear.bias", "decoder.transformer_layers.8.self_attn.linear_keys.bias", "decoder.transformer_layers.8.self_attn.linear_values.bias", "decoder.transformer_layers.8.self_attn.linear_query.bias", "decoder.transformer_layers.8.self_attn.final_linear.bias", "decoder.transformer_layers.8.context_attn.linear_keys.bias", "decoder.transformer_layers.8.context_attn.linear_values.bias", "decoder.transformer_layers.8.context_attn.linear_query.bias", "decoder.transformer_layers.8.context_attn.final_linear.bias", "decoder.transformer_layers.9.self_attn.linear_keys.bias", "decoder.transformer_layers.9.self_attn.linear_values.bias", "decoder.transformer_layers.9.self_attn.linear_query.bias", "decoder.transformer_layers.9.self_attn.final_linear.bias", "decoder.transformer_layers.9.context_attn.linear_keys.bias", 
"decoder.transformer_layers.9.context_attn.linear_values.bias", "decoder.transformer_layers.9.context_attn.linear_query.bias", "decoder.transformer_layers.9.context_attn.final_linear.bias", "decoder.transformer_layers.10.self_attn.linear_keys.bias", "decoder.transformer_layers.10.self_attn.linear_values.bias", "decoder.transformer_layers.10.self_attn.linear_query.bias", "decoder.transformer_layers.10.self_attn.final_linear.bias", "decoder.transformer_layers.10.context_attn.linear_keys.bias", "decoder.transformer_layers.10.context_attn.linear_values.bias", "decoder.transformer_layers.10.context_attn.linear_query.bias", "decoder.transformer_layers.10.context_attn.final_linear.bias", "decoder.transformer_layers.11.self_attn.linear_keys.bias", "decoder.transformer_layers.11.self_attn.linear_values.bias", "decoder.transformer_layers.11.self_attn.linear_query.bias", "decoder.transformer_layers.11.self_attn.final_linear.bias", "decoder.transformer_layers.11.context_attn.linear_keys.bias", "decoder.transformer_layers.11.context_attn.linear_values.bias", "decoder.transformer_layers.11.context_attn.linear_query.bias", "decoder.transformer_layers.11.context_attn.final_linear.bias", "decoder.transformer_layers.12.self_attn.linear_keys.bias", "decoder.transformer_layers.12.self_attn.linear_values.bias", "decoder.transformer_layers.12.self_attn.linear_query.bias", "decoder.transformer_layers.12.self_attn.final_linear.bias", "decoder.transformer_layers.12.context_attn.linear_keys.bias", "decoder.transformer_layers.12.context_attn.linear_values.bias", "decoder.transformer_layers.12.context_attn.linear_query.bias", "decoder.transformer_layers.12.context_attn.final_linear.bias", "decoder.transformer_layers.13.self_attn.linear_keys.bias", "decoder.transformer_layers.13.self_attn.linear_values.bias", "decoder.transformer_layers.13.self_attn.linear_query.bias", "decoder.transformer_layers.13.self_attn.final_linear.bias", "decoder.transformer_layers.13.context_attn.linear_keys.bias", 
"decoder.transformer_layers.13.context_attn.linear_values.bias", "decoder.transformer_layers.13.context_attn.linear_query.bias", "decoder.transformer_layers.13.context_attn.final_linear.bias", "decoder.transformer_layers.14.self_attn.linear_keys.bias", "decoder.transformer_layers.14.self_attn.linear_values.bias", "decoder.transformer_layers.14.self_attn.linear_query.bias", "decoder.transformer_layers.14.self_attn.final_linear.bias", "decoder.transformer_layers.14.context_attn.linear_keys.bias", "decoder.transformer_layers.14.context_attn.linear_values.bias", "decoder.transformer_layers.14.context_attn.linear_query.bias", "decoder.transformer_layers.14.context_attn.final_linear.bias", "decoder.transformer_layers.15.self_attn.linear_keys.bias", "decoder.transformer_layers.15.self_attn.linear_values.bias", "decoder.transformer_layers.15.self_attn.linear_query.bias", "decoder.transformer_layers.15.self_attn.final_linear.bias", "decoder.transformer_layers.15.context_attn.linear_keys.bias", "decoder.transformer_layers.15.context_attn.linear_values.bias", "decoder.transformer_layers.15.context_attn.linear_query.bias", "decoder.transformer_layers.15.context_attn.final_linear.bias", "decoder.transformer_layers.16.self_attn.linear_keys.bias", "decoder.transformer_layers.16.self_attn.linear_values.bias", "decoder.transformer_layers.16.self_attn.linear_query.bias", "decoder.transformer_layers.16.self_attn.final_linear.bias", "decoder.transformer_layers.16.context_attn.linear_keys.bias", "decoder.transformer_layers.16.context_attn.linear_values.bias", "decoder.transformer_layers.16.context_attn.linear_query.bias", "decoder.transformer_layers.16.context_attn.final_linear.bias", "decoder.transformer_layers.17.self_attn.linear_keys.bias", "decoder.transformer_layers.17.self_attn.linear_values.bias", "decoder.transformer_layers.17.self_attn.linear_query.bias", "decoder.transformer_layers.17.self_attn.final_linear.bias", "decoder.transformer_layers.17.context_attn.linear_keys.bias", 
"decoder.transformer_layers.17.context_attn.linear_values.bias", "decoder.transformer_layers.17.context_attn.linear_query.bias", "decoder.transformer_layers.17.context_attn.final_linear.bias", "decoder.transformer_layers.18.self_attn.linear_keys.bias", "decoder.transformer_layers.18.self_attn.linear_values.bias", "decoder.transformer_layers.18.self_attn.linear_query.bias", "decoder.transformer_layers.18.self_attn.final_linear.bias", "decoder.transformer_layers.18.context_attn.linear_keys.bias", "decoder.transformer_layers.18.context_attn.linear_values.bias", "decoder.transformer_layers.18.context_attn.linear_query.bias", "decoder.transformer_layers.18.context_attn.final_linear.bias", "decoder.transformer_layers.19.self_attn.linear_keys.bias", "decoder.transformer_layers.19.self_attn.linear_values.bias", "decoder.transformer_layers.19.self_attn.linear_query.bias", "decoder.transformer_layers.19.self_attn.final_linear.bias", "decoder.transformer_layers.19.context_attn.linear_keys.bias", "decoder.transformer_layers.19.context_attn.linear_values.bias", "decoder.transformer_layers.19.context_attn.linear_query.bias", "decoder.transformer_layers.19.context_attn.final_linear.bias", "decoder.transformer_layers.20.self_attn.linear_keys.bias", "decoder.transformer_layers.20.self_attn.linear_values.bias", "decoder.transformer_layers.20.self_attn.linear_query.bias", "decoder.transformer_layers.20.self_attn.final_linear.bias", "decoder.transformer_layers.20.context_attn.linear_keys.bias", "decoder.transformer_layers.20.context_attn.linear_values.bias", "decoder.transformer_layers.20.context_attn.linear_query.bias", "decoder.transformer_layers.20.context_attn.final_linear.bias", "decoder.transformer_layers.21.self_attn.linear_keys.bias", "decoder.transformer_layers.21.self_attn.linear_values.bias", "decoder.transformer_layers.21.self_attn.linear_query.bias", "decoder.transformer_layers.21.self_attn.final_linear.bias", "decoder.transformer_layers.21.context_attn.linear_keys.bias", 
"decoder.transformer_layers.21.context_attn.linear_values.bias", "decoder.transformer_layers.21.context_attn.linear_query.bias", "decoder.transformer_layers.21.context_attn.final_linear.bias", "decoder.transformer_layers.22.self_attn.linear_keys.bias", "decoder.transformer_layers.22.self_attn.linear_values.bias", "decoder.transformer_layers.22.self_attn.linear_query.bias", "decoder.transformer_layers.22.self_attn.final_linear.bias", "decoder.transformer_layers.22.context_attn.linear_keys.bias", "decoder.transformer_layers.22.context_attn.linear_values.bias", "decoder.transformer_layers.22.context_attn.linear_query.bias", "decoder.transformer_layers.22.context_attn.final_linear.bias", "decoder.transformer_layers.23.self_attn.linear_keys.bias", "decoder.transformer_layers.23.self_attn.linear_values.bias", "decoder.transformer_layers.23.self_attn.linear_query.bias", "decoder.transformer_layers.23.self_attn.final_linear.bias", "decoder.transformer_layers.23.context_attn.linear_keys.bias", "decoder.transformer_layers.23.context_attn.linear_values.bias", "decoder.transformer_layers.23.context_attn.linear_query.bias", "decoder.transformer_layers.23.context_attn.final_linear.bias". 

I am using the following inference.yaml:

batch_size: 8192
batch_type: tokens
beam_size: 5
fp16: null
gpu: 0
log_file: translate.log
max_length: 512
model: test_3_3B/nllb-200-lora-3_3B_step_10200.pt
report_time: true
src_prefix: </s> eng_Latn
src_subword_alpha: 0.0
src_subword_model: flores200_sacrebleu_tokenizer_spm.model
src_subword_nbest: 1
src_suffix: ''
tgt_file_prefix: true
tgt_prefix: spa_Latn
tgt_subword_alpha: 0.0
tgt_subword_model: flores200_sacrebleu_tokenizer_spm.model
tgt_subword_nbest: 1
tgt_suffix: ''
transforms:
- sentencepiece
- prefix
- suffix

I installed master 1 hour ago:
-e git+https://github.com/OpenNMT/OpenNMT-py.git@07534c5b9a181d24165ab218fda986e1fff0fef4#egg=OpenNMT_py

Can you post your fine-tuning config?

Maybe you had override_opts: true but did not set add_qkvbias.

Looking at the error above, I think this is the issue.

If you want to try inference without retraining, do the following:

Open a Python (or, better, IPython) console.

import torch
m = torch.load("test_3_3B/nllb-200-lora-3_3B_step_10200.pt")
m['opt'].add_qkvbias=False
torch.save(m, "test_3_3B/nllb-200-lora-3_3B_step_10200.pt")

Then your model will be consistent (no QKV biases, and the option set to False), and you can try inference.
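To check a checkpoint before running inference, a small pure-Python helper can compare the attention bias keys stored in checkpoint['model'] with the add_qkvbias option. It is a hypothetical sketch: it assumes the key naming seen in the tracebacks in this thread (linear_keys, linear_values, linear_query and final_linear modules with .weight and .bias parameters) and only reports attention-projection biases, not every key that load_state_dict checks.

```python
# Hypothetical consistency check between attention bias keys and add_qkvbias.
PROJECTIONS = ("linear_keys", "linear_values", "linear_query", "final_linear")


def attention_bias_mismatches(state_dict_keys, add_qkvbias):
    """Return (missing, unexpected) attention bias keys for the given option."""
    keys = set(state_dict_keys)
    missing, unexpected = [], []
    for key in sorted(keys):
        module, _, param = key.rpartition(".")
        if module.rsplit(".", 1)[-1] not in PROJECTIONS:
            continue  # not an attention projection
        if param == "weight" and add_qkvbias and f"{module}.bias" not in keys:
            missing.append(f"{module}.bias")
        elif param == "bias" and not add_qkvbias:
            unexpected.append(key)
    return missing, unexpected


if __name__ == "__main__":
    keys = [
        "encoder.transformer.0.self_attn.linear_keys.weight",
        "encoder.transformer.0.self_attn.linear_keys.bias",
        "encoder.transformer.0.self_attn.linear_query.weight",
    ]
    print(attention_bias_mismatches(keys, add_qkvbias=True))
    print(attention_bias_mismatches(keys, add_qkvbias=False))
```

With a real checkpoint, pass m['model'].keys() and m['opt'].add_qkvbias from the torch.load call shown above; both lists should be empty for a consistent model.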

Yes, that was the case.

This fixed my model; now it works.

With this setup, my first 3.3B experiment scored 1.2 BLEU higher than my best 1.3B setup using the same GPU memory.

Great.
Bear in mind that your fine-tuning may not have been optimal without QKV biases, because the original model had them. You may want to redo it with add_qkvbias: true.

You can also convert your model with CTranslate2 (CT2) and quantize it to int8 (see the docs) to check whether it is much faster.

I will also add an option to load the model in int8 mode in OpenNMT-py (with bitsandbytes). It will not speed things up, but it will let the model fit in memory on smaller GPUs.
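The CT2 conversion can also be driven from Python with CTranslate2's OpenNMTPyConverter. In this minimal sketch the wrapper name and the restriction to the quantization types mentioned in this thread (int8 and int8_float16, or None for no quantization) are illustrative assumptions; CTranslate2 itself accepts more types.

```python
# Hypothetical wrapper around CTranslate2's OpenNMT-py converter.
ALLOWED_QUANTIZATION = {None, "int8", "int8_float16"}  # types discussed in this thread


def convert_onmt_to_ct2(model_path, output_dir, quantization="int8"):
    """Convert an OpenNMT-py checkpoint to a CTranslate2 model directory."""
    if quantization not in ALLOWED_QUANTIZATION:
        raise ValueError(f"unsupported quantization for this sketch: {quantization!r}")
    import ctranslate2  # imported lazily so the check above works without CTranslate2

    converter = ctranslate2.converters.OpenNMTPyConverter(model_path)
    # force=True overwrites output_dir if it already exists
    return converter.convert(output_dir, quantization=quantization, force=True)
```

For example, convert_onmt_to_ct2("nllb-finetuned.pt", "nllb-finetuned-ct2") would write an int8 model to a (hypothetical) nllb-finetuned-ct2 directory. The command-line equivalent is ct2-opennmt-py-converter --model_path nllb-finetuned.pt --output_dir nllb-finetuned-ct2 --quantization int8.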

After adding it to the training config and merging, I get the following error:

Traceback (most recent call last):
  File "/home/m.barroso/anaconda3/envs/nllb_3_3B/bin/onmt_translate", line 33, in <module>
    sys.exit(load_entry_point('OpenNMT-py', 'console_scripts', 'onmt_translate')())
  File "/home/m.barroso/OpenNMT_nllb/OpenNMT-py/onmt/bin/translate.py", line 60, in main
    translate(opt)
  File "/home/m.barroso/OpenNMT_nllb/OpenNMT-py/onmt/bin/translate.py", line 23, in translate
    translator = build_translator(opt, logger=logger,
  File "/home/m.barroso/OpenNMT_nllb/OpenNMT-py/onmt/translate/translator.py", line 33, in build_translator
    vocabs, model, model_opt = load_test_model(opt)
  File "/home/m.barroso/OpenNMT_nllb/OpenNMT-py/onmt/model_builder.py", line 171, in load_test_model
    model = build_base_model(model_opt, vocabs, checkpoint)
  File "/home/m.barroso/OpenNMT_nllb/OpenNMT-py/onmt/model_builder.py", line 402, in build_base_model
    model.load_state_dict(checkpoint['model'],
  File "/home/m.barroso/anaconda3/envs/nllb_3_3B/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1671, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for NMTModel:
	Missing key(s) in state_dict: "encoder.transformer.0.self_attn.linear_values.bias", "encoder.transformer.0.self_attn.linear_query.bias", "encoder.transformer.0.self_attn.final_linear.bias", "encoder.transformer.1.self_attn.linear_values.bias", "encoder.transformer.1.self_attn.linear_query.bias", "encoder.transformer.1.self_attn.final_linear.bias", "encoder.transformer.2.self_attn.linear_values.bias", "encoder.transformer.2.self_attn.linear_query.bias", "encoder.transformer.2.self_attn.final_linear.bias", "encoder.transformer.3.self_attn.linear_values.bias", "encoder.transformer.3.self_attn.linear_query.bias", "encoder.transformer.3.self_attn.final_linear.bias", "encoder.transformer.4.self_attn.linear_values.bias", "encoder.transformer.4.self_attn.linear_query.bias", "encoder.transformer.4.self_attn.final_linear.bias", "encoder.transformer.5.self_attn.linear_values.bias"...

My training YAML is this one:

# Vocab creation options
share_vocab: true

## Where the vocab(s) is
src_vocab: "dictionary.txt"
src_words_min_frequency: 1
src_vocab_size: 256206

tgt_vocab: "dictionary.txt"
tgt_words_min_frequency: 1
tgt_vocab_size: 256206

vocab_size_multiple: 1

decoder_start_token: '</s>'


### Transform related opts:

#### Subword
src_subword_model: "flores200_sacrebleu_tokenizer_spm.model"
tgt_subword_model: "flores200_sacrebleu_tokenizer_spm.model"
src_subword_nbest: 1
src_subword_alpha: 0.0
tgt_subword_nbest: 1
tgt_subword_alpha: 0.0
  


#### Filter
src_seq_length: 150
tgt_seq_length: 150

# silently ignore empty lines in the data
#skip_empty_level: silent

# General opts
update_vocab: true # option to perform vocabulary update

train_from: "nllb-200-3.3B-onmt.pt"           # 3.3B

reset_optim: all # states option to perform vocabulary update
save_data: "/nllb-200"
save_model: "trained_models_3_3B_es_en_test/nllb-200-600M-onmt-1"
log_file: "train.log"

keep_checkpoint: -1
save_checkpoint_steps: 100 # 200 500

average_decay: 0.0005
seed: 1234
report_every: 1
train_steps: 10200 # 2000 100000
valid_steps: 600 # 400 5000

# Batching
bucket_size: 262144
num_workers: 1
prefetch_factor:  400
world_size: 1
gpu_ranks: [0]
batch_type: "tokens"


batch_size: 1024                              # 3.3B
valid_batch_size: 1024                        # 3.3B
batch_size_multiple: 2                        # 3.3B

accum_count: [2]
accum_steps: [0]

# LoRA
lora_layers: ['linear_values', 'linear_query', 'linnear_keys', 'final_linear']
lora_rank: 4
lora_dropout: 0.0
lora_alpha: 1
lora_embedding: false

# Optimization
model_dtype: "fp16"
optim: "fusedadam" 
learning_rate: 0.1 
warmup_steps: 40 
decay_method: "noam"
adam_beta2: 0.98
max_grad_norm: 0
label_smoothing: 0.1
dropout: 0.1
param_init: 0
param_init_glorot: true
normalization: "tokens"


# Model
override_opts: true

# To make it work with LoRA and override_opts
add_qkvbias: true

encoder_type: transformer
decoder_type: transformer

enc_layers: 24          # 3.3B
dec_layers: 24          # 3.3B
transformer_ff: 8192    # 3.3B

heads: 16
hidden_size: 2048
word_vec_size: 2048
dropout_steps: [0, 15000, 30000]
dropout: [0.1, 0.1, 0.1]
attention_dropout: [0.1, 0.1, 0.1]
share_decoder_embeddings: true
share_embeddings: true
position_encoding: true
position_encoding_type: 'SinusoidalConcat'
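With decay_method: "noam", learning_rate: 0.1 acts as a multiplier rather than the actual step size. This sketch assumes OpenNMT-py's noam schedule follows the Transformer paper formula scaled by learning_rate, with hidden_size as the model size; the function name is hypothetical.

```python
# Sketch of the "noam" schedule with this config's values, assuming
# lr = learning_rate * hidden_size^-0.5 * min(step^-0.5, step * warmup^-1.5).


def noam_lr(step, learning_rate=0.1, warmup_steps=40, hidden_size=2048):
    """Effective learning rate at a given (1-based) training step."""
    step = max(step, 1)  # avoid 0 ** -0.5 at the very first step
    return (learning_rate * hidden_size ** -0.5
            * min(step ** -0.5, step * warmup_steps ** -1.5))


if __name__ == "__main__":
    for step in (1, 20, 40, 1000, 10200):
        print(step, f"{noam_lr(step):.2e}")
```

Under this assumption the rate rises linearly to about 3.5e-4 at step 40 (0.1 / sqrt(2048 x 40)) and then decays with the inverse square root of the step.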

Now I cannot use the previous trick to make it work.

Do I have something wrong in my yaml?

Sorry, my fault; I pushed a fix.
Maybe run a few training steps to make sure it works, but it should hopefully be fine.
Thanks for your patience.

Thank you, now it works. I also get a slight improvement in BLEU by using add_qkvbias.

Converting to CTranslate2 does not work with my new LoRA models. It also does not work with the base models.

I am just executing this:

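import ctranslate2
import sentencepiece as spm

# Placeholders (assumed; adjust to your setup): converted CT2 model directory,
# SentencePiece model, FLORES-200 language codes and input sentences.
tokenizer = spm.SentencePieceProcessor(model_file="flores200_sacrebleu_tokenizer_spm.model")
model = ctranslate2.Translator("nllb_ct2", device="cuda")
src, tgt = "eng_Latn", "spa_Latn"
batch = ["Hello world!"]

source_sents = [sent.strip() for sent in batch]
target_prefix = [[tgt]] * len(source_sents)

# Subword the source sentences
source_sents_subworded = tokenizer.encode_as_pieces(source_sents)
source_sents_subworded = [[src] + sent + ["</s>"] for sent in source_sents_subworded]

print(source_sents_subworded)

# Translate the source sentences
translations = model.translate_batch(
    source_sents_subworded, batch_type="tokens", max_batch_size=1024,
    beam_size=5, target_prefix=target_prefix
)
print(translations)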

I tried the onmt_release_model script and got the same bad results.
I have tried it with ctranslate2==3.11.0 and with the latest version (3.13.0).

Does it work when you convert the model without quantization?

I downloaded the 600M and 1.3B models again and converted them without quantization, using CTranslate2 3.13.0. I got the same bad results.

Did you do exactly the same thing that was working with 3.11 in the other thread?

I tried both ways (old and new) and something else as well. All gave me the same bad results.

It seems to be a missing flag in the converted checkpoints.
I am currently uploading corrected checkpoints.

On your end you can just try the following:

import torch
m = torch.load("your_checkpoint.pt")
# Set the missing option so the decoder starts from </s>
m['opt'].decoder_start_token = '</s>'
torch.save(m, "your_checkpoint.pt")

Then you can convert it to CT2.
Let me know if it fixes the issue.
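Both checkpoint fixes in this thread follow the same load, set option, save pattern. A hypothetical helper that applies the overrides to checkpoint['opt'] and returns the previous values makes it easy to see what changed; like the snippets above, it assumes checkpoint['opt'] is an argparse.Namespace.

```python
import argparse


def patch_checkpoint_opts(checkpoint, **overrides):
    """Set attributes on checkpoint['opt'] in place and return the old values.

    Hypothetical helper; options missing from the checkpoint are reported as None."""
    opt = checkpoint["opt"]
    previous = {name: getattr(opt, name, None) for name in overrides}
    for name, value in overrides.items():
        setattr(opt, name, value)
    return previous


if __name__ == "__main__":
    # Stand-in for m = torch.load("your_checkpoint.pt"); call torch.save(m, ...) afterwards.
    ckpt = {"opt": argparse.Namespace(), "model": {}}
    print(patch_checkpoint_opts(ckpt, decoder_start_token="</s>"))  # {'decoder_start_token': None}
```

The same call with add_qkvbias=False reproduces the earlier fix for the bias mismatch.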

That solved it, thank you so much! Let us know once the original checkpoints are fixed.

They are uploaded now.
If you get a chance to try int8 and int8_float16 quantization, let us know whether the results are comparable.
Cheers.