Finetuning bigger models with LoRa (Low-Rank Adaptation) in OpenNMT-py

Any idea when CTranslate2 will support these LoRa adapters after the OpenNMT-py model is LoRa finetuned?

If it’s a model already supported by CT2, it will work out of the box, because OpenNMT-py saves the full model at training time.

For instance, if you train an NLLB-200 3.3B with LoRa, it will work, since the conversion from OpenNMT-py for NLLB models is already supported.

For Llama, once we commit the PR, it should take only a few adaptations, since CT2 already supports the Llama architecture.

I see. I thought the (smaller) LoRa weights would be used in place of multiple (large) finetuned models, so you could just ‘call’ them with the base model loaded in VRAM, incurring only a tiny amount of additional VRAM per finetuned model, since all of them share the same base model.

Regardless, great work!

I see your point; the same question was raised on Gitter.im.

I think it is a question of both serving and CT2, so it is not so obvious to do.

It might be interesting to develop but honestly it won’t happen in the very short term.

I have tried to use the LoRA options, but cannot see much difference.

I am trying to fine-tune the 3.3B NLLB model. Without LoRA, the model takes 13650 MiB on the GPU. When I add the LoRA options to the config.yaml file, I get 13664 MiB. The parameter counts are 3,344,529,614 for the normal setup vs. 3,348,510,926 with LoRA. I cannot train it even with a batch size of 1 (I tried with and without LoRA).

To me, it seems that the in-memory size of the loaded model is doubled (instead of 3B x 2 bytes it is something like 3B x 4 bytes). The increase in parameters with LoRA is consistent with the increase in memory under the 3B x 4 bytes formula I mention.
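The arithmetic behind that estimate can be checked with a short sketch. It uses only the parameter counts reported above; the helper name `weights_mib` is made up for illustration, and the assumption is that weights are stored at either 2 bytes (fp16) or 4 bytes (fp32) each.

```python
MIB = 1024 ** 2  # bytes per MiB

def weights_mib(n_params, bytes_per_param):
    """Memory needed to hold n_params weights, in MiB."""
    return n_params * bytes_per_param / MIB

base_params = 3_344_529_614   # reported count without LoRA
lora_params = 3_348_510_926   # reported count with LoRA

fp16 = weights_mib(base_params, 2)                  # weights at 2 bytes each
fp32 = weights_mib(base_params, 4)                  # weights at 4 bytes each
extra = weights_mib(lora_params - base_params, 4)   # added LoRA weights at 4 bytes each

print(f"fp16 weights: {fp16:,.0f} MiB")
print(f"fp32 weights: {fp32:,.0f} MiB")
print(f"extra LoRA params at 4 bytes: {extra:.1f} MiB")
```

At 4 bytes per weight this gives roughly 12,758 MiB for the base model and about 15 MiB for the added parameters, which is in the same range as the 13650 MiB and the 14 MiB increase observed. The remaining gap is not explained by this sketch.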

My config.yaml

keep_checkpoint: -1
save_checkpoint_steps: 600 

average_decay: 0.0005
seed: 1234
report_every: 1
train_steps: 10200 
valid_steps: 600

# Batching
bucket_size: 262144
num_workers: 1
prefetch_factor:  400
world_size: 1
gpu_ranks: [0]
batch_type: "tokens"


batch_size: 1                             
valid_batch_size: 1                       
batch_size_multiple: 1                      

accum_count: [4]
accum_steps: [0]

#LoRa
lora_layers: ['linear_values', 'linear_query', 'linnear_keys', 'final_linear']
lora_rank: 4
lora_dropout: 0.0
lora_alpha: 1
lora_embedding: false

# Optimization
model_dtype: "fp16"
optim: "sgd" 
learning_rate: 45 
warmup_steps: 40 
decay_method: "noam"
adam_beta2: 0.98
max_grad_norm: 0
label_smoothing: 0.1
dropout: 0.1
param_init: 0
param_init_glorot: true
normalization: "tokens"


# Model
override_opts: true
encoder_type: transformer
decoder_type: transformer

enc_layers: 24         
dec_layers: 24         
transformer_ff: 8192    

heads: 16
hidden_size: 2048
word_vec_size: 2048
dropout_steps: [0, 15000, 30000]
dropout: [0.1, 0.1, 0.1]
attention_dropout: [0.1, 0.1, 0.1]
share_decoder_embeddings: true
share_embeddings: true
position_encoding: true
position_encoding_type: 'SinusoidalConcat'

Do I have to do something else to make it work? I am using OpenNMT-py==3.1.1, torch==1.13.1 and Python 3.10.6, and I also ran pip install -r requirements.opt.txt (just in case).

There will be no gain with SGD, because its memory use is proportional to the size of the model.

You need to work with Adam.

But at the moment, I think you will need to use “fusedadam”, which performs a “model.half()” from the very beginning (hence you need to install apex). BTW, Apex/legacy fusedadam is the recommended optimizer anyway.

You can give Adam a try too.

I will shortly commit a PR that saves only the LoRa weights, plus a tool to merge them after training.
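One way to see why the optimizer matters is a rough sketch of the optimizer state alone. Adam keeps two running averages per trainable parameter, while plain SGD without momentum keeps none, so there is nothing for LoRA to shrink there. The helper `adam_state_mib` is hypothetical. The 4-byte state values are an assumption, as is treating the difference between the two parameter counts reported earlier in the thread as the number of trainable LoRA weights.

```python
def adam_state_mib(n_trainable, bytes_per_value=4):
    """Adam keeps two running averages (first and second moment) per trainable parameter."""
    return 2 * n_trainable * bytes_per_value / 1024 ** 2

full_finetune = 3_344_529_614               # every weight trainable (count reported earlier)
lora_only = 3_348_510_926 - 3_344_529_614   # assumption: only the added LoRA weights train

print(f"Adam state, full finetuning: {adam_state_mib(full_finetune):,.0f} MiB")
print(f"Adam state, LoRA only:       {adam_state_mib(lora_only):,.1f} MiB")
```

Under these assumptions the Adam state drops from about 25,517 MiB to about 30 MiB. This covers optimizer state only; weights, gradients and activations are not counted.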

Has Adam been confirmed to work? I attempted to finetune a pre-existing model with LoRa in OpenNMT-py, using Adam and the LoRa config above, but the parameter count seems to remain the same, as do performance and memory usage. I’m using 3.1.1.

Can you try with master, please?
Also, post your config.

Downloaded master [I had to delete the clean transform because gcld3 installation on my Windows machine doesn’t work for some reason]

There doesn’t seem to be much difference memory or training-speed wise [tok/s]. I’m using a 12 layer encoder / 1 layer decoder with 16 heads [config file below]

Mapped out some of the parameter counts too

LoRa embeddings off: 66,020,824 parameters
LoRa embeddings on: 90,095,762 parameters
No LoRa [normal]: 65,963,480 parameters

save_data: "run"


data:
    corpus_1:
        path_src: corpus.src
        path_tgt: corpus.tgt

    valid:
        path_src: validation.src
        path_tgt: validation.tgt

# Vocabulary files, generated by onmt_build_vocab
src_vocab: run/source.vocab
tgt_vocab: run/source.vocab

# override_opts: true
tensorboard: true
tensorboard_log_dir: ./tensorboard_logs

model_dtype: "fp16"
src_vocab_size: 64000
tgt_vocab_size: 64000
share_vocab: true

override_opts: true
train_from: "./averaged_original.pt"

lora_layers: ['linear_values', 'linear_query']
lora_rank: 2
lora_dropout: 0.1
lora_alpha: 1
lora_embedding: false

aan_useffn: true
save_model: "./saves_lora/model"
save_checkpoint_steps: 3000

src_subword_model: "general_multi.model"
tgt_subword_model: "general_multi.model"

seed: 3435

train_steps: 450000
bucket_size: 600000
valid_steps: 3000
train_eval_steps: 3000
train_metrics: "BLEU"


warmup_steps: 16000
report_every: 1


decoder_type: "transformer"
encoder_type: "transformer"
word_vec_size: 512
hidden_size: 512

transformer_ff: 2048
enc_layers: 12
dec_layers: 1
heads: 16


accum_count: [4]
optim: "adam"
adam_beta1: 0.9
adam_beta2: 0.98
decay_method: "noam"
learning_rate: 2.0
max_grad_norm: 0.0


pos_ffn_activation_fn: "gelu"
context_gate: "both"
full_context_alignment: true

bridge: true

# batch_size: 8192

valid_batch_size: 8192
batch_size: 14000
batch_type: "tokens"
normalization: "tokens"
dropout: 0.2
attention_dropout: [0.1]
dropout_steps: [0]
label_smoothing: 0.1

max_generator_batches: 2

num_workers: 0 # anything else doesnt work on windows

batch_size_multiple: 8
vocab_size_multiple: 8

param_init: 0.0
param_init_glorot: true

max_relative_positions: 32
share_embeddings: true
share_decoder_embeddings: true

world_size: 1
gpu_ranks: [0]

I got a huge memory reduction by using fusedadam and LoRA. I fine-tuned a 3.3B model on a 32 GB GPU. I could not do this with Adam or without LoRA; I needed both.

The loaded model alone goes from 13,664 MiB to 7,830 MiB when I use LoRA and fused Adam.

I had to install apex by following the GitHub guide and modifying the pip install like this:

pip install -v --disable-pip-version-check --no-cache-dir --global-option="--deprecated_fused_adam" --global-option="--cpp_ext" --global-option="--cuda_ext"  ./

If you do not add

--global-option="--deprecated_fused_adam" 

it will not work with OpenNMT-py. I installed OpenNMT-py from master this morning.

I did a previous experiment with the 600M model, where I went from 17,070 MiB with just Adam to 10,460 MiB during training when I changed to LoRA+Adam. Just the loaded model goes from 3,210 MiB with just Adam to 3,214 MiB when I changed to LoRA+Adam. With fused Adam, it goes from 2,020 MiB with just Adam to 8,682 MiB when I changed to LoRA+Adam.

I have been using these LoRA options:

#LoRa
lora_layers: ['linear_values', 'linear_query', 'linnear_keys', 'final_linear']
lora_rank: 4
lora_dropout: 0.0
lora_alpha: 1
lora_embedding: false
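A quick sanity check on the `lora_layers` names can be sketched as follows. The list of valid names is taken from the attention keys that appear in the state_dict error messages quoted elsewhere in this thread; the helper `check_lora_layers` is hypothetical, and whether OpenNMT-py rejects or silently skips an unknown name is not confirmed here.

```python
import difflib

# Attention sub-layer names as they appear in the checkpoint keys quoted in this thread.
KNOWN_LAYERS = ["linear_keys", "linear_values", "linear_query", "final_linear"]

def check_lora_layers(lora_layers):
    """Return {unknown_name: closest known name or None} for names not in KNOWN_LAYERS."""
    problems = {}
    for name in lora_layers:
        if name not in KNOWN_LAYERS:
            close = difflib.get_close_matches(name, KNOWN_LAYERS, n=1)
            problems[name] = close[0] if close else None
    return problems

# The lora_layers value from the config above.
configured = ["linear_values", "linear_query", "linnear_keys", "final_linear"]
print(check_lora_layers(configured))
```

Run against the options above, this flags `linnear_keys` and suggests `linear_keys`.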

Martin,

I am still working on a PR, because there is a bug in the weight-merging code of the original LoRa implementation.
I should be done tonight.
I am testing everything with Llama 7B, which should follow the same procedure.

The training is OK, but I can’t run inference with it. Do you think I need to repeat the training?

No, but you will have to run the modified lora_weights.py script to merge or concat.

Martin,

I just committed.
If your last training went through you should have:

  1. your base model (in the OpenNMT-py format) before finetuning

  2. your LoRa weights (much smaller file)

If you git pull master you should be able to run:

python tools/lora_weights.py --action merge --base_model nllb-onmt.pt --lora_weights your_lora.pt --output nllb-finetuned.pt

Then you can try to run inference with the new nllb-finetuned.pt
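For readers wondering what a merge does conceptually, here is a minimal sketch of the usual LoRA formula, W' = W + (alpha / rank) * B A, on toy matrices in plain Python. It is not the code of tools/lora_weights.py; the function names and the toy values are assumptions chosen for illustration.

```python
def matmul(a, b):
    """Multiply two matrices given as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]

def merge_lora(w, lora_b, lora_a, alpha, rank):
    """Return W + (alpha / rank) * (B @ A), the usual LoRA merge."""
    scale = alpha / rank
    delta = matmul(lora_b, lora_a)
    return [[wv + scale * dv for wv, dv in zip(w_row, d_row)]
            for w_row, d_row in zip(w, delta)]

# Toy 2x2 weight with a rank-1 update (values chosen only for illustration).
w = [[1.0, 0.0], [0.0, 1.0]]
lora_b = [[1.0], [2.0]]      # shape (out_features, rank)
lora_a = [[0.5, 0.5]]        # shape (rank, in_features)
merged = merge_lora(w, lora_b, lora_a, alpha=1, rank=1)
print(merged)  # [[1.5, 0.5], [1.0, 2.0]]
```

After merging, the result has the same shape as the original weight, which is why the merged checkpoint can be loaded like any ordinary model.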

Let me know how it goes.

After doing this, I get the following error when I run inference with the new nllb-finetuned.pt:

Traceback (most recent call last):
  File "/home/m.barroso/anaconda3/envs/nllb_3_3B/bin/onmt_translate", line 33, in <module>
    sys.exit(load_entry_point('OpenNMT-py', 'console_scripts', 'onmt_translate')())
  File "/home/m.barroso/OpenNMT_nllb/OpenNMT-py/onmt/bin/translate.py", line 60, in main
    translate(opt)
  File "/home/m.barroso/OpenNMT_nllb/OpenNMT-py/onmt/bin/translate.py", line 23, in translate
    translator = build_translator(opt, logger=logger,
  File "/home/m.barroso/OpenNMT_nllb/OpenNMT-py/onmt/translate/translator.py", line 33, in build_translator
    vocabs, model, model_opt = load_test_model(opt)
  File "/home/m.barroso/OpenNMT_nllb/OpenNMT-py/onmt/model_builder.py", line 171, in load_test_model
    model = build_base_model(model_opt, vocabs, checkpoint)
  File "/home/m.barroso/OpenNMT_nllb/OpenNMT-py/onmt/model_builder.py", line 402, in build_base_model
    model.load_state_dict(checkpoint['model'],
  File "/home/m.barroso/anaconda3/envs/nllb_3_3B/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1671, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for NMTModel:
	Missing key(s) in state_dict: "encoder.transformer.0.self_attn.linear_keys.bias", "encoder.transformer.0.self_attn.linear_values.bias", "encoder.transformer.0.self_attn.linear_query.bias", "encoder.transformer.0.self_attn.final_linear.bias", "encoder.transformer.1.self_attn.linear_keys.bias", "encoder.transformer.1.self_attn.linear_values.bias", "encoder.transformer.1.self_attn.linear_query.bias", "encoder.transformer.1.self_attn.final_linear.bias", "encoder.transformer.2.self_attn.linear_keys.bias", "encoder.transformer.2.self_attn.linear_values.bias", "encoder.transformer.2.self_attn.linear_query.bias", "encoder.transformer.2.self_attn.final_linear.bias", "encoder.transformer.3.self_attn.linear_keys.bias", "encoder.transformer.3.self_attn.linear_values.bias", "encoder.transformer.3.self_attn.linear_query.bias", "encoder.transformer.3.self_attn.final_linear.bias", "encoder.transformer.4.self_attn.linear_keys.bias", "encoder.transformer.4.self_attn.linear_values.bias", "encoder.transformer.4.self_attn.linear_query.bias", "encoder.transformer.4.self_attn.final_linear.bias", "encoder.transformer.5.self_attn.linear_keys.bias", "encoder.transformer.5.self_attn.linear_values.bias", "encoder.transformer.5.self_attn.linear_query.bias", "encoder.transformer.5.self_attn.final_linear.bias", "encoder.transformer.6.self_attn.linear_keys.bias", "encoder.transformer.6.self_attn.linear_values.bias", "encoder.transformer.6.self_attn.linear_query.bias", "encoder.transformer.6.self_attn.final_linear.bias", "encoder.transformer.7.self_attn.linear_keys.bias", "encoder.transformer.7.self_attn.linear_values.bias", "encoder.transformer.7.self_attn.linear_query.bias", "encoder.transformer.7.self_attn.final_linear.bias", "encoder.transformer.8.self_attn.linear_keys.bias", "encoder.transformer.8.self_attn.linear_values.bias", "encoder.transformer.8.self_attn.linear_query.bias", "encoder.transformer.8.self_attn.final_linear.bias", "encoder.transformer.9.self_attn.linear_keys.bias", 
"encoder.transformer.9.self_attn.linear_values.bias", "encoder.transformer.9.self_attn.linear_query.bias", "encoder.transformer.9.self_attn.final_linear.bias", "encoder.transformer.10.self_attn.linear_keys.bias", "encoder.transformer.10.self_attn.linear_values.bias", "encoder.transformer.10.self_attn.linear_query.bias", "encoder.transformer.10.self_attn.final_linear.bias", "encoder.transformer.11.self_attn.linear_keys.bias", "encoder.transformer.11.self_attn.linear_values.bias", "encoder.transformer.11.self_attn.linear_query.bias", "encoder.transformer.11.self_attn.final_linear.bias", "encoder.transformer.12.self_attn.linear_keys.bias", "encoder.transformer.12.self_attn.linear_values.bias", "encoder.transformer.12.self_attn.linear_query.bias", "encoder.transformer.12.self_attn.final_linear.bias", "encoder.transformer.13.self_attn.linear_keys.bias", "encoder.transformer.13.self_attn.linear_values.bias", "encoder.transformer.13.self_attn.linear_query.bias", "encoder.transformer.13.self_attn.final_linear.bias", "encoder.transformer.14.self_attn.linear_keys.bias", "encoder.transformer.14.self_attn.linear_values.bias", "encoder.transformer.14.self_attn.linear_query.bias", "encoder.transformer.14.self_attn.final_linear.bias", "encoder.transformer.15.self_attn.linear_keys.bias", "encoder.transformer.15.self_attn.linear_values.bias", "encoder.transformer.15.self_attn.linear_query.bias", "encoder.transformer.15.self_attn.final_linear.bias", "encoder.transformer.16.self_attn.linear_keys.bias", "encoder.transformer.16.self_attn.linear_values.bias", "encoder.transformer.16.self_attn.linear_query.bias", "encoder.transformer.16.self_attn.final_linear.bias", "encoder.transformer.17.self_attn.linear_keys.bias", "encoder.transformer.17.self_attn.linear_values.bias", "encoder.transformer.17.self_attn.linear_query.bias", "encoder.transformer.17.self_attn.final_linear.bias", "encoder.transformer.18.self_attn.linear_keys.bias", "encoder.transformer.18.self_attn.linear_values.bias", 
"encoder.transformer.18.self_attn.linear_query.bias", "encoder.transformer.18.self_attn.final_linear.bias", "encoder.transformer.19.self_attn.linear_keys.bias", "encoder.transformer.19.self_attn.linear_values.bias", "encoder.transformer.19.self_attn.linear_query.bias", "encoder.transformer.19.self_attn.final_linear.bias", "encoder.transformer.20.self_attn.linear_keys.bias", "encoder.transformer.20.self_attn.linear_values.bias", "encoder.transformer.20.self_attn.linear_query.bias", "encoder.transformer.20.self_attn.final_linear.bias", "encoder.transformer.21.self_attn.linear_keys.bias", "encoder.transformer.21.self_attn.linear_values.bias", "encoder.transformer.21.self_attn.linear_query.bias", "encoder.transformer.21.self_attn.final_linear.bias", "encoder.transformer.22.self_attn.linear_keys.bias", "encoder.transformer.22.self_attn.linear_values.bias", "encoder.transformer.22.self_attn.linear_query.bias", "encoder.transformer.22.self_attn.final_linear.bias", "encoder.transformer.23.self_attn.linear_keys.bias", "encoder.transformer.23.self_attn.linear_values.bias", "encoder.transformer.23.self_attn.linear_query.bias", "encoder.transformer.23.self_attn.final_linear.bias", "decoder.transformer_layers.0.self_attn.linear_keys.bias", "decoder.transformer_layers.0.self_attn.linear_values.bias", "decoder.transformer_layers.0.self_attn.linear_query.bias", "decoder.transformer_layers.0.self_attn.final_linear.bias", "decoder.transformer_layers.0.context_attn.linear_keys.bias", "decoder.transformer_layers.0.context_attn.linear_values.bias", "decoder.transformer_layers.0.context_attn.linear_query.bias", "decoder.transformer_layers.0.context_attn.final_linear.bias", "decoder.transformer_layers.1.self_attn.linear_keys.bias", "decoder.transformer_layers.1.self_attn.linear_values.bias", "decoder.transformer_layers.1.self_attn.linear_query.bias", "decoder.transformer_layers.1.self_attn.final_linear.bias", "decoder.transformer_layers.1.context_attn.linear_keys.bias", 
"decoder.transformer_layers.1.context_attn.linear_values.bias", "decoder.transformer_layers.1.context_attn.linear_query.bias", "decoder.transformer_layers.1.context_attn.final_linear.bias", "decoder.transformer_layers.2.self_attn.linear_keys.bias", "decoder.transformer_layers.2.self_attn.linear_values.bias", "decoder.transformer_layers.2.self_attn.linear_query.bias", "decoder.transformer_layers.2.self_attn.final_linear.bias", "decoder.transformer_layers.2.context_attn.linear_keys.bias", "decoder.transformer_layers.2.context_attn.linear_values.bias", "decoder.transformer_layers.2.context_attn.linear_query.bias", "decoder.transformer_layers.2.context_attn.final_linear.bias", "decoder.transformer_layers.3.self_attn.linear_keys.bias", "decoder.transformer_layers.3.self_attn.linear_values.bias", "decoder.transformer_layers.3.self_attn.linear_query.bias", "decoder.transformer_layers.3.self_attn.final_linear.bias", "decoder.transformer_layers.3.context_attn.linear_keys.bias", "decoder.transformer_layers.3.context_attn.linear_values.bias", "decoder.transformer_layers.3.context_attn.linear_query.bias", "decoder.transformer_layers.3.context_attn.final_linear.bias", "decoder.transformer_layers.4.self_attn.linear_keys.bias", "decoder.transformer_layers.4.self_attn.linear_values.bias", "decoder.transformer_layers.4.self_attn.linear_query.bias", "decoder.transformer_layers.4.self_attn.final_linear.bias", "decoder.transformer_layers.4.context_attn.linear_keys.bias", "decoder.transformer_layers.4.context_attn.linear_values.bias", "decoder.transformer_layers.4.context_attn.linear_query.bias", "decoder.transformer_layers.4.context_attn.final_linear.bias", "decoder.transformer_layers.5.self_attn.linear_keys.bias", "decoder.transformer_layers.5.self_attn.linear_values.bias", "decoder.transformer_layers.5.self_attn.linear_query.bias", "decoder.transformer_layers.5.self_attn.final_linear.bias", "decoder.transformer_layers.5.context_attn.linear_keys.bias", 
"decoder.transformer_layers.5.context_attn.linear_values.bias", "decoder.transformer_layers.5.context_attn.linear_query.bias", "decoder.transformer_layers.5.context_attn.final_linear.bias", "decoder.transformer_layers.6.self_attn.linear_keys.bias", "decoder.transformer_layers.6.self_attn.linear_values.bias", "decoder.transformer_layers.6.self_attn.linear_query.bias", "decoder.transformer_layers.6.self_attn.final_linear.bias", "decoder.transformer_layers.6.context_attn.linear_keys.bias", "decoder.transformer_layers.6.context_attn.linear_values.bias", "decoder.transformer_layers.6.context_attn.linear_query.bias", "decoder.transformer_layers.6.context_attn.final_linear.bias", "decoder.transformer_layers.7.self_attn.linear_keys.bias", "decoder.transformer_layers.7.self_attn.linear_values.bias", "decoder.transformer_layers.7.self_attn.linear_query.bias", "decoder.transformer_layers.7.self_attn.final_linear.bias", "decoder.transformer_layers.7.context_attn.linear_keys.bias", "decoder.transformer_layers.7.context_attn.linear_values.bias", "decoder.transformer_layers.7.context_attn.linear_query.bias", "decoder.transformer_layers.7.context_attn.final_linear.bias", "decoder.transformer_layers.8.self_attn.linear_keys.bias", "decoder.transformer_layers.8.self_attn.linear_values.bias", "decoder.transformer_layers.8.self_attn.linear_query.bias", "decoder.transformer_layers.8.self_attn.final_linear.bias", "decoder.transformer_layers.8.context_attn.linear_keys.bias", "decoder.transformer_layers.8.context_attn.linear_values.bias", "decoder.transformer_layers.8.context_attn.linear_query.bias", "decoder.transformer_layers.8.context_attn.final_linear.bias", "decoder.transformer_layers.9.self_attn.linear_keys.bias", "decoder.transformer_layers.9.self_attn.linear_values.bias", "decoder.transformer_layers.9.self_attn.linear_query.bias", "decoder.transformer_layers.9.self_attn.final_linear.bias", "decoder.transformer_layers.9.context_attn.linear_keys.bias", 
"decoder.transformer_layers.9.context_attn.linear_values.bias", "decoder.transformer_layers.9.context_attn.linear_query.bias", "decoder.transformer_layers.9.context_attn.final_linear.bias", "decoder.transformer_layers.10.self_attn.linear_keys.bias", "decoder.transformer_layers.10.self_attn.linear_values.bias", "decoder.transformer_layers.10.self_attn.linear_query.bias", "decoder.transformer_layers.10.self_attn.final_linear.bias", "decoder.transformer_layers.10.context_attn.linear_keys.bias", "decoder.transformer_layers.10.context_attn.linear_values.bias", "decoder.transformer_layers.10.context_attn.linear_query.bias", "decoder.transformer_layers.10.context_attn.final_linear.bias", "decoder.transformer_layers.11.self_attn.linear_keys.bias", "decoder.transformer_layers.11.self_attn.linear_values.bias", "decoder.transformer_layers.11.self_attn.linear_query.bias", "decoder.transformer_layers.11.self_attn.final_linear.bias", "decoder.transformer_layers.11.context_attn.linear_keys.bias", "decoder.transformer_layers.11.context_attn.linear_values.bias", "decoder.transformer_layers.11.context_attn.linear_query.bias", "decoder.transformer_layers.11.context_attn.final_linear.bias", "decoder.transformer_layers.12.self_attn.linear_keys.bias", "decoder.transformer_layers.12.self_attn.linear_values.bias", "decoder.transformer_layers.12.self_attn.linear_query.bias", "decoder.transformer_layers.12.self_attn.final_linear.bias", "decoder.transformer_layers.12.context_attn.linear_keys.bias", "decoder.transformer_layers.12.context_attn.linear_values.bias", "decoder.transformer_layers.12.context_attn.linear_query.bias", "decoder.transformer_layers.12.context_attn.final_linear.bias", "decoder.transformer_layers.13.self_attn.linear_keys.bias", "decoder.transformer_layers.13.self_attn.linear_values.bias", "decoder.transformer_layers.13.self_attn.linear_query.bias", "decoder.transformer_layers.13.self_attn.final_linear.bias", "decoder.transformer_layers.13.context_attn.linear_keys.bias", 
"decoder.transformer_layers.13.context_attn.linear_values.bias", "decoder.transformer_layers.13.context_attn.linear_query.bias", "decoder.transformer_layers.13.context_attn.final_linear.bias", "decoder.transformer_layers.14.self_attn.linear_keys.bias", "decoder.transformer_layers.14.self_attn.linear_values.bias", "decoder.transformer_layers.14.self_attn.linear_query.bias", "decoder.transformer_layers.14.self_attn.final_linear.bias", "decoder.transformer_layers.14.context_attn.linear_keys.bias", "decoder.transformer_layers.14.context_attn.linear_values.bias", "decoder.transformer_layers.14.context_attn.linear_query.bias", "decoder.transformer_layers.14.context_attn.final_linear.bias", "decoder.transformer_layers.15.self_attn.linear_keys.bias", "decoder.transformer_layers.15.self_attn.linear_values.bias", "decoder.transformer_layers.15.self_attn.linear_query.bias", "decoder.transformer_layers.15.self_attn.final_linear.bias", "decoder.transformer_layers.15.context_attn.linear_keys.bias", "decoder.transformer_layers.15.context_attn.linear_values.bias", "decoder.transformer_layers.15.context_attn.linear_query.bias", "decoder.transformer_layers.15.context_attn.final_linear.bias", "decoder.transformer_layers.16.self_attn.linear_keys.bias", "decoder.transformer_layers.16.self_attn.linear_values.bias", "decoder.transformer_layers.16.self_attn.linear_query.bias", "decoder.transformer_layers.16.self_attn.final_linear.bias", "decoder.transformer_layers.16.context_attn.linear_keys.bias", "decoder.transformer_layers.16.context_attn.linear_values.bias", "decoder.transformer_layers.16.context_attn.linear_query.bias", "decoder.transformer_layers.16.context_attn.final_linear.bias", "decoder.transformer_layers.17.self_attn.linear_keys.bias", "decoder.transformer_layers.17.self_attn.linear_values.bias", "decoder.transformer_layers.17.self_attn.linear_query.bias", "decoder.transformer_layers.17.self_attn.final_linear.bias", "decoder.transformer_layers.17.context_attn.linear_keys.bias", 
"decoder.transformer_layers.17.context_attn.linear_values.bias", "decoder.transformer_layers.17.context_attn.linear_query.bias", "decoder.transformer_layers.17.context_attn.final_linear.bias", "decoder.transformer_layers.18.self_attn.linear_keys.bias", "decoder.transformer_layers.18.self_attn.linear_values.bias", "decoder.transformer_layers.18.self_attn.linear_query.bias", "decoder.transformer_layers.18.self_attn.final_linear.bias", "decoder.transformer_layers.18.context_attn.linear_keys.bias", "decoder.transformer_layers.18.context_attn.linear_values.bias", "decoder.transformer_layers.18.context_attn.linear_query.bias", "decoder.transformer_layers.18.context_attn.final_linear.bias", "decoder.transformer_layers.19.self_attn.linear_keys.bias", "decoder.transformer_layers.19.self_attn.linear_values.bias", "decoder.transformer_layers.19.self_attn.linear_query.bias", "decoder.transformer_layers.19.self_attn.final_linear.bias", "decoder.transformer_layers.19.context_attn.linear_keys.bias", "decoder.transformer_layers.19.context_attn.linear_values.bias", "decoder.transformer_layers.19.context_attn.linear_query.bias", "decoder.transformer_layers.19.context_attn.final_linear.bias", "decoder.transformer_layers.20.self_attn.linear_keys.bias", "decoder.transformer_layers.20.self_attn.linear_values.bias", "decoder.transformer_layers.20.self_attn.linear_query.bias", "decoder.transformer_layers.20.self_attn.final_linear.bias", "decoder.transformer_layers.20.context_attn.linear_keys.bias", "decoder.transformer_layers.20.context_attn.linear_values.bias", "decoder.transformer_layers.20.context_attn.linear_query.bias", "decoder.transformer_layers.20.context_attn.final_linear.bias", "decoder.transformer_layers.21.self_attn.linear_keys.bias", "decoder.transformer_layers.21.self_attn.linear_values.bias", "decoder.transformer_layers.21.self_attn.linear_query.bias", "decoder.transformer_layers.21.self_attn.final_linear.bias", "decoder.transformer_layers.21.context_attn.linear_keys.bias", 
"decoder.transformer_layers.21.context_attn.linear_values.bias", "decoder.transformer_layers.21.context_attn.linear_query.bias", "decoder.transformer_layers.21.context_attn.final_linear.bias", "decoder.transformer_layers.22.self_attn.linear_keys.bias", "decoder.transformer_layers.22.self_attn.linear_values.bias", "decoder.transformer_layers.22.self_attn.linear_query.bias", "decoder.transformer_layers.22.self_attn.final_linear.bias", "decoder.transformer_layers.22.context_attn.linear_keys.bias", "decoder.transformer_layers.22.context_attn.linear_values.bias", "decoder.transformer_layers.22.context_attn.linear_query.bias", "decoder.transformer_layers.22.context_attn.final_linear.bias", "decoder.transformer_layers.23.self_attn.linear_keys.bias", "decoder.transformer_layers.23.self_attn.linear_values.bias", "decoder.transformer_layers.23.self_attn.linear_query.bias", "decoder.transformer_layers.23.self_attn.final_linear.bias", "decoder.transformer_layers.23.context_attn.linear_keys.bias", "decoder.transformer_layers.23.context_attn.linear_values.bias", "decoder.transformer_layers.23.context_attn.linear_query.bias", "decoder.transformer_layers.23.context_attn.final_linear.bias". 

I am using the following inference.yaml:

batch_size: 8192
batch_type: tokens
beam_size: 5
fp16: null
gpu: 0
log_file: translate.log
max_length: 512
model: test_3_3B/nllb-200-lora-3_3B_step_10200.pt
report_time: true
src_prefix: </s> eng_Latn
src_subword_alpha: 0.0
src_subword_model: flores200_sacrebleu_tokenizer_spm.model
src_subword_nbest: 1
src_suffix: ''
tgt_file_prefix: true
tgt_prefix: spa_Latn
tgt_subword_alpha: 0.0
tgt_subword_model: flores200_sacrebleu_tokenizer_spm.model
tgt_subword_nbest: 1
tgt_suffix: ''
transforms:
- sentencepiece
- prefix
- suffix

I installed master 1 hour ago:
-e git+https://github.com/OpenNMT/OpenNMT-py.git@07534c5b9a181d24165ab218fda986e1fff0fef4#egg=OpenNMT_py

Can you post your finetuning config?

Maybe you had override_opts: true but did not have add_qkvbias.

Looking above, I think this is the issue.

If you want to try inference without retraining, do the following:

Open a Python (or, better, IPython) console.

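import torch

# Load the finetuned checkpoint (a dict holding the training options and the weights)
m = torch.load("test_3_3B/nllb-200-lora-3_3B_step_10200.pt")
# The saved weights have no QKV biases, so make the stored option agree with them
m['opt'].add_qkvbias = False
# Overwrite the checkpoint in place with the corrected option
torch.save(m, "test_3_3B/nllb-200-lora-3_3B_step_10200.pt")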

Then your model will be consistent (no QKV bias and the option set to false).

You can then try inference.
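To decide which value of add_qkvbias matches a given checkpoint, one option is to look for attention bias tensors among its keys. The sketch below works on a plain list of key names; the helper `has_attention_bias` and the two sample lists are hypothetical. In practice the list would come from something like list(m['model'].keys()), assuming the weights live under the 'model' entry as the traceback above suggests.

```python
def has_attention_bias(state_dict_keys):
    """True if any attention projection in the checkpoint carries a bias tensor."""
    suffixes = (
        "attn.linear_keys.bias",
        "attn.linear_values.bias",
        "attn.linear_query.bias",
        "attn.final_linear.bias",
    )
    # str.endswith accepts a tuple, so one call checks all four projections.
    return any(key.endswith(suffixes) for key in state_dict_keys)

# Hypothetical key lists standing in for the keys of a real checkpoint.
with_bias = [
    "encoder.transformer.0.self_attn.linear_query.weight",
    "encoder.transformer.0.self_attn.linear_query.bias",
]
without_bias = ["encoder.transformer.0.self_attn.linear_query.weight"]

print(has_attention_bias(with_bias))     # True
print(has_attention_bias(without_bias))  # False
```

If the function returns False while the stored option says add_qkvbias is true, loading fails with the kind of "Missing key(s)" error shown above.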

It was the case, yes.

This fixed my model, now it works.

In my first experiment with the 3.3B model and this setup, I got 1.2 BLEU more than with my best setup for the 1.3B model, using the same GPU memory.

Great.
Bear in mind that your finetuning may not have been optimal without QKV biases, because the original model had those biases. You may try to redo it with add_qkvbias=True.

Also, you can try to convert your model with CT2 and quantize it to int8 (cf. doc) to check whether this is much faster.

I will also add an option to load the model in int8 mode in OpenNMT-py (with bnb), but it will not accelerate anything; it will just let the model fit in memory on smaller GPUs.

After adding it to the training config and doing the merge, I get the following error:

Traceback (most recent call last):                                                                                                                                                                         
  File "/home/m.barroso/anaconda3/envs/nllb_3_3B/bin/onmt_translate", line 33, in <module>                                                                                                                 
    sys.exit(load_entry_point('OpenNMT-py', 'console_scripts', 'onmt_translate')())                                                                                                                        
  File "/home/m.barroso/OpenNMT_nllb/OpenNMT-py/onmt/bin/translate.py", line 60, in main                                                                                                                   
    translate(opt)                                                                                                                                                                                         
  File "/home/m.barroso/OpenNMT_nllb/OpenNMT-py/onmt/bin/translate.py", line 23, in translate                                                                                                              
    translator = build_translator(opt, logger=logger,                                                                                                                                                      
  File "/home/m.barroso/OpenNMT_nllb/OpenNMT-py/onmt/translate/translator.py", line 33, in build_translator                                                                                                
    vocabs, model, model_opt = load_test_model(opt)                                                                                                                                                        
  File "/home/m.barroso/OpenNMT_nllb/OpenNMT-py/onmt/model_builder.py", line 171, in load_test_model                                                                                                       
    model = build_base_model(model_opt, vocabs, checkpoint)                                                                                                                                                
  File "/home/m.barroso/OpenNMT_nllb/OpenNMT-py/onmt/model_builder.py", line 402, in build_base_model                                                                                                      
    model.load_state_dict(checkpoint['model'],                                                                                                                                                             
  File "/home/m.barroso/anaconda3/envs/nllb_3_3B/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1671, in load_state_dict                                                                   
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(                                                                                                                              
RuntimeError: Error(s) in loading state_dict for NMTModel:                                                                                                                                                 
        Missing key(s) in state_dict: "encoder.transformer.0.self_attn.linear_values.bias", "encoder.transformer.0.self_attn.linear_query.bias", "encoder.transformer.0.self_attn.final_linear.bias", "encoder.transformer.1.self_attn.linear_values.bias", "encoder.transformer.1.self_attn.linear_query.bias", "encoder.transformer.1.self_attn.final_linear.bias", "encoder.transformer.2.self_attn.linear_values.bias", "encoder.transformer.2.self_attn.linear_query.bias", "encoder.transformer.2.self_attn.final_linear.bias", "encoder.transformer.3.self_attn.linear_values.bias", "encoder.transformer.3.self_attn.linear_query.bias", "encoder.transformer.3.self_attn.final_linear.bias", "encoder.transformer.4.self_attn.linear_values.bias", "encoder.transformer.4.self_attn.linear_query.bias", "encoder.transformer.4.self_attn.final_linear.bias", "encoder.transformer.5.self_attn.linear_values.bias"...

My training yaml is this one:

# Vocab creation options
share_vocab: true

## Where the vocab(s) is
src_vocab: "dictionary.txt"
src_words_min_frequency: 1
src_vocab_size: 256206

tgt_vocab: "dictionary.txt"
tgt_words_min_frequency: 1
tgt_vocab_size: 256206

vocab_size_multiple: 1

decoder_start_token: '</s>'


### Transform related opts:

#### Subword
src_subword_model: "flores200_sacrebleu_tokenizer_spm.model"
tgt_subword_model: "flores200_sacrebleu_tokenizer_spm.model"
src_subword_nbest: 1
src_subword_alpha: 0.0
tgt_subword_nbest: 1
tgt_subword_alpha: 0.0
  


#### Filter
src_seq_length: 150
tgt_seq_length: 150

# silently ignore empty lines in the data
#skip_empty_level: silent

# General opts
update_vocab: true # option to perform vocabulary update

train_from: "nllb-200-3.3B-onmt.pt"           # 3.3B

reset_optim: all # states option to perform vocabulary update
save_data: "/nllb-200"
save_model: "trained_models_3_3B_es_en_test/nllb-200-600M-onmt-1"
log_file: "train.log"

keep_checkpoint: -1
save_checkpoint_steps: 100 # 200 500

average_decay: 0.0005
seed: 1234
report_every: 1
train_steps: 10200 # 2000 100000
valid_steps: 600 # 400 5000

# Batching
bucket_size: 262144
num_workers: 1
prefetch_factor:  400
world_size: 1
gpu_ranks: [0]
batch_type: "tokens"


batch_size: 1024                              # 3.3B
valid_batch_size: 1024                        # 3.3B
batch_size_multiple: 2                        # 3.3B

accum_count: [2]
accum_steps: [0]

#LoRa
lora_layers: ['linear_values', 'linear_query', 'linnear_keys', 'final_linear']
lora_rank: 4
lora_dropout: 0.0
lora_alpha: 1
lora_embedding: false

# Optimization
model_dtype: "fp16"
optim: "fusedadam" 
learning_rate: 0.1 
warmup_steps: 40 
decay_method: "noam"
adam_beta2: 0.98
max_grad_norm: 0
label_smoothing: 0.1
dropout: 0.1
param_init: 0
param_init_glorot: true
normalization: "tokens"


# Model
override_opts: true

# To make it works with Lora and override
add_qkvbias: true

encoder_type: transformer
decoder_type: transformer

enc_layers: 24          # 3.3B
dec_layers: 24          # 3.3B
transformer_ff: 8192    # 3.3B

heads: 16
hidden_size: 2048
word_vec_size: 2048
dropout_steps: [0, 15000, 30000]
dropout: [0.1, 0.1, 0.1]
attention_dropout: [0.1, 0.1, 0.1]
share_decoder_embeddings: true
share_embeddings: true
position_encoding: true
position_encoding_type: 'SinusoidalConcat'

Now I cannot use the add_qkvbias=False trick to make it work.

Do I have something wrong in my yaml?

Sorry, my fault; I pushed a fix.
Maybe you can try a few steps of training to make sure it works, but it should hopefully be fine.
Thanks for your patience.