NLLB-200 with CTranslate2

I updated the 600M checkpoint on S3.

The key options are below (for everything else, use whatever you normally use: adam, …)

share_vocab: true
src_vocab: "/nllb-200/dictionary.txt"
src_words_min_frequency: 1
src_vocab_size: 257000
tgt_vocab: "/nllb-200/dictionary.txt"
tgt_words_min_frequency: 1
tgt_vocab_size: 257000
src_vocab_multiple: 8

Corpus opts:

data:
    mydataset:
        path_src: "/en-de/cc-matrix-ende.en"
        path_tgt: "/en-de/cc-matrix-ende.de"
        transforms: [sentencepiece, prefix, suffix, filtertoolong]
        weight: 10
        src_prefix: ""
        tgt_prefix: "deu_Latn"
        src_suffix: " eng_Latn"
        tgt_suffix: ""

update_vocab: true
train_from: "/nllb-200/nllb-200-600M-onmt.pt"
reset_optim: all
save_data: "/nllb-200"
save_model: "/nllb-200/nllb-200-600M-onmt"
decoder_start_token: '</s>'

Subword

src_subword_model: "/nllb-200/flores200_sacrebleu_tokenizer_spm.model"
tgt_subword_model: "/nllb-200/flores200_sacrebleu_tokenizer_spm.model"
encoder_type: transformer
decoder_type: transformer
enc_layers: 12
dec_layers: 12
heads: 16
hidden_size: 1024
word_vec_size: 1024
transformer_ff: 4096
dropout_steps: [0, 15000, 30000]
dropout: [0.1, 0.1, 0.1]
attention_dropout: [0.1, 0.1, 0.1]
share_decoder_embeddings: true
share_embeddings: true
position_encoding: true
position_encoding_type: 'SinusoidalConcat'
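
For anyone unsure what the sentencepiece + prefix/suffix transforms actually feed the model, here is a minimal sketch (not the OpenNMT-py transform code itself) of what one training pair looks like after tokenization; the sentence pair is made up and the paths are the ones from the config above:

import sentencepiece as spm

# Load the NLLB SentencePiece model referenced in the config.
sp = spm.SentencePieceProcessor(
    model_file="/nllb-200/flores200_sacrebleu_tokenizer_spm.model"
)

src = "The patient was discharged after three days."
tgt = "Der Patient wurde nach drei Tagen entlassen."

# src_suffix appends the source language token; tgt_prefix prepends the target one.
src_tokens = sp.encode(src, out_type=str) + ["eng_Latn"]
tgt_tokens = ["deu_Latn"] + sp.encode(tgt, out_type=str)

print(" ".join(src_tokens))
print(" ".join(tgt_tokens))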

I did some experiments using train, dev, and test splits of some medical data, and testing the base model already gives good results: 58 chrF and 38 BLEU for EN > ES. I am using the 600M NLLB model.
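
As a side note (not from the original post), scores like these can be reproduced with sacrebleu; a rough sketch, with hyp.es and ref.es as placeholder files containing detokenized hypotheses and references:

import sacrebleu

# Placeholder file names: one sentence per line, hypotheses and references aligned.
with open("hyp.es", encoding="utf-8") as f:
    hyps = [line.strip() for line in f]
with open("ref.es", encoding="utf-8") as f:
    refs = [[line.strip() for line in f]]

print(sacrebleu.corpus_bleu(hyps, refs).score)   # corpus BLEU
print(sacrebleu.corpus_chrf(hyps, refs).score)   # corpus chrF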

As I understand it, you first use onmt_build_vocab to update the vocabulary with your new tokens, and then you run onmt_train.

Is there a way to skip the vocab-building step? I am getting really high perplexity values at the beginning of training (13k ppl and 9.5 xent in the first step), which makes no sense at all. I would like to check whether that is related to this vocab step or not. As far as I know, the model itself has the vocab in the .pt file, but the train script asks me for the -src_vocab argument anyway and says it is required.
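
One quick way to check what the checkpoint actually contains is to load it and inspect its keys; this is only an exploratory sketch, and the exact key names depend on the OpenNMT-py version:

import torch

# Load the checkpoint on CPU just to look inside it.
ckpt = torch.load("nllb-200-600M-onmt.pt", map_location="cpu")
print(list(ckpt.keys()))   # expect something like 'model', 'opt', 'vocab', ...

vocab = ckpt.get("vocab")
if vocab is not None:
    for side, v in vocab.items():
        # In recent versions each side is list-like; adjust if your version differs.
        print(side, len(v))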

I might be missing something; I do not have much experience fine-tuning pretrained models in this framework.

I did not use build_vocab

You have to provide the vocab and spm model.

https://opennmt-models.s3.amazonaws.com/nllb-200/dictionary.txt
https://opennmt-models.s3.amazonaws.com/nllb-200/flores200_sacrebleu_tokenizer_spm.model

I have followed those steps and I keep getting very high perplexity values at the beginning of training. Does that make sense? I am using the flores200 test set just to check that the fine-tuning works.

I use the following config file:

Vocab options

share_vocab: true

Where the vocab(s) is

src_vocab: "dictionary.txt"
tgt_vocab: "dictionary.txt"

src_words_min_frequency: 1
src_vocab_size: 257000

tgt_words_min_frequency: 1
tgt_vocab_size: 257000

src_vocab_multiple: 8

save_data: "nllb-200"

Corpus opts:

data:
    corpus_1:
        path_src: "flores200_dataset/devtest/eng_Latn.devtest"
        path_tgt: "flores200_dataset/devtest/spa_Latn.devtest"
        transforms: [sentencepiece, prefix, suffix, filtertoolong]
        weight: 10
        src_prefix: ""
        tgt_prefix: "spa_Latn"
        src_suffix: " eng_Latn"
        tgt_suffix: ""

Subword

src_subword_model: "flores200_sacrebleu_tokenizer_spm.model"
tgt_subword_model: "flores200_sacrebleu_tokenizer_spm.model"

General opts

save_model: "trained_models/nllb-200-600M-onmt-1"

train_from: "nllb-200-600M-onmt.pt"

update_vocab: true
reset_optim: all

Model

encoder_type: transformer
decoder_type: transformer
enc_layers: 12
dec_layers: 12
heads: 16
hidden_size: 1024
word_vec_size: 1024
transformer_ff: 4096
dropout_steps: [0, 15000, 30000]
dropout: [0.1, 0.1, 0.1]
attention_dropout: [0.1, 0.1, 0.1]
share_decoder_embeddings: true
share_embeddings: true
position_encoding: true
position_encoding_type: 'SinusoidalConcat'

decoder_start_token: '</s>'

NEW OPTIONS

Filter

src_seq_length: 200
tgt_seq_length: 200

report_every: 1
train_steps: 2500
valid_steps: 500
save_checkpoint_steps: 250
log_file: "train.log"

Batching

bucket_size: 262144
world_size: 1
gpu_ranks: [0]
num_workers: 1
batch_type: "tokens"
batch_size: 1024
valid_batch_size: 2048
batch_size_multiple: 4
accum_count: [12]
accum_steps: [0]

Optimization

optim: "sgd"
learning_rate: 0.05
label_smoothing: 0.1

I run the following command:

onmt_train -config config.yaml

And I am getting the following training logs:

[2023-02-22 08:25:04,273 INFO] Parsed 1 corpora from -data.
[2023-02-22 08:25:04,274 INFO] Loading checkpoint from nllb-200-600M-onmt.pt
[2023-02-22 08:25:06,588 WARNING] configured transforms is different from checkpoint: +{'suffix', 'sentencepiece', 'prefix'}
[2023-02-22 08:25:06,588 INFO] Get suffix for corpus_1: {'src': ' eng_Latn', 'tgt': ''}
[2023-02-22 08:25:06,588 INFO] Get prefix for corpus_1: {'src': '', 'tgt': 'spa_Latn'}
[2023-02-22 08:25:06,588 INFO] Get prefix for src infer:
[2023-02-22 08:25:06,588 INFO] Get prefix for tgt infer:
[2023-02-22 08:25:06,588 INFO] Get special vocabs from Transforms: {'src': ['eng_Latn'], 'tgt': ['spa_Latn']}.
[2023-02-22 08:25:07,440 INFO] Updating checkpoint vocabulary with new vocabulary
[2023-02-22 08:25:07,443 INFO] Get suffix for corpus_1: {'src': ' eng_Latn', 'tgt': ''}
[2023-02-22 08:25:07,446 INFO] Get prefix for corpus_1: {'src': '', 'tgt': 'spa_Latn'}
[2023-02-22 08:25:07,449 INFO] Get prefix for src infer:
[2023-02-22 08:25:07,452 INFO] Get prefix for tgt infer:
[2023-02-22 08:25:07,455 INFO] Get special vocabs from Transforms: {'src': ['eng_Latn'], 'tgt': ['spa_Latn']}.
[2023-02-22 08:25:08,435 INFO] Building model…
[2023-02-22 08:25:28,822 INFO] Updating vocabulary embeddings with checkpoint embeddings
[2023-02-22 08:32:45,470 INFO] src: 2 new tokens
[2023-02-22 08:40:24,422 INFO] tgt: 2 new tokens
[2023-02-22 08:40:29,150 INFO] NMTModel(
(encoder): TransformerEncoder(
(embeddings): Embeddings(
(make_embedding): Sequential(
(emb_luts): Elementwise(
(0): Embedding(256208, 1024, padding_idx=1)
)
(pe): PositionalEncoding()
)
(dropout): Dropout(p=0.1, inplace=False)
)
(transformer): ModuleList(
(0): TransformerEncoderLayer(
(self_attn): MultiHeadedAttention(
(linear_keys): Linear(in_features=1024, out_features=1024, bias=True)
(linear_values): Linear(in_features=1024, out_features=1024, bias=True)
(linear_query): Linear(in_features=1024, out_features=1024, bias=True)
(softmax): Softmax(dim=-1)
(dropout): Dropout(p=0.1, inplace=False)
(final_linear): Linear(in_features=1024, out_features=1024, bias=True)
)
(feed_forward): PositionwiseFeedForward(
(w_1): Linear(in_features=1024, out_features=4096, bias=True)
(w_2): Linear(in_features=4096, out_features=1024, bias=True)
(layer_norm): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)
(dropout_1): Dropout(p=0.1, inplace=False)
(dropout_2): Dropout(p=0.1, inplace=False)
)
(layer_norm): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)

(11): TransformerEncoderLayer(
(self_attn): MultiHeadedAttention(
(linear_keys): Linear(in_features=1024, out_features=1024, bias=True)
(linear_values): Linear(in_features=1024, out_features=1024, bias=True)
(linear_query): Linear(in_features=1024, out_features=1024, bias=True)
(softmax): Softmax(dim=-1)
(dropout): Dropout(p=0.1, inplace=False)
(final_linear): Linear(in_features=1024, out_features=1024, bias=True)
)
(feed_forward): PositionwiseFeedForward(
(w_1): Linear(in_features=1024, out_features=4096, bias=True)
(w_2): Linear(in_features=4096, out_features=1024, bias=True)
(layer_norm): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)
(dropout_1): Dropout(p=0.1, inplace=False)
(dropout_2): Dropout(p=0.1, inplace=False)
)
(layer_norm): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
)
(layer_norm): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)
)
(decoder): TransformerDecoder(
(embeddings): Embeddings(
(make_embedding): Sequential(
(emb_luts): Elementwise(
(0): Embedding(256208, 1024, padding_idx=1)
)
(pe): PositionalEncoding()
)
(dropout): Dropout(p=0.1, inplace=False)
)
(layer_norm): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)
(transformer_layers): ModuleList(
(0): TransformerDecoderLayer(
(self_attn): MultiHeadedAttention(
(linear_keys): Linear(in_features=1024, out_features=1024, bias=True)
(linear_values): Linear(in_features=1024, out_features=1024, bias=True)
(linear_query): Linear(in_features=1024, out_features=1024, bias=True)
(softmax): Softmax(dim=-1)
(dropout): Dropout(p=0.1, inplace=False)
(final_linear): Linear(in_features=1024, out_features=1024, bias=True)
)
(feed_forward): PositionwiseFeedForward(
(w_1): Linear(in_features=1024, out_features=4096, bias=True)
(w_2): Linear(in_features=4096, out_features=1024, bias=True)
(layer_norm): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)
(dropout_1): Dropout(p=0.1, inplace=False)
(dropout_2): Dropout(p=0.1, inplace=False)
)
(layer_norm_1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)
(drop): Dropout(p=0.1, inplace=False)
(context_attn): MultiHeadedAttention(
(linear_keys): Linear(in_features=1024, out_features=1024, bias=True)
(linear_values): Linear(in_features=1024, out_features=1024, bias=True)
(linear_query): Linear(in_features=1024, out_features=1024, bias=True)
(softmax): Softmax(dim=-1)
(dropout): Dropout(p=0.1, inplace=False)
(final_linear): Linear(in_features=1024, out_features=1024, bias=True)
)
(layer_norm_2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)
)

(11): TransformerDecoderLayer(
(self_attn): MultiHeadedAttention(
(linear_keys): Linear(in_features=1024, out_features=1024, bias=True)
(linear_values): Linear(in_features=1024, out_features=1024, bias=True)
(linear_query): Linear(in_features=1024, out_features=1024, bias=True)
(softmax): Softmax(dim=-1)
(dropout): Dropout(p=0.1, inplace=False)
(final_linear): Linear(in_features=1024, out_features=1024, bias=True)
)
(feed_forward): PositionwiseFeedForward(
(w_1): Linear(in_features=1024, out_features=4096, bias=True)
(w_2): Linear(in_features=4096, out_features=1024, bias=True)
(layer_norm): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)
(dropout_1): Dropout(p=0.1, inplace=False)
(dropout_2): Dropout(p=0.1, inplace=False)
)
(layer_norm_1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)
(drop): Dropout(p=0.1, inplace=False)
(context_attn): MultiHeadedAttention(
(linear_keys): Linear(in_features=1024, out_features=1024, bias=True)
(linear_values): Linear(in_features=1024, out_features=1024, bias=True)
(linear_query): Linear(in_features=1024, out_features=1024, bias=True)
(softmax): Softmax(dim=-1)
(dropout): Dropout(p=0.1, inplace=False)
(final_linear): Linear(in_features=1024, out_features=1024, bias=True)
)
(layer_norm_2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)
)
)
)
(generator): Linear(in_features=1024, out_features=256208, bias=True)
)
[2023-02-22 08:40:29,160 INFO] encoder: 413513728
[2023-02-22 08:40:29,160 INFO] decoder: 201818320
[2023-02-22 08:40:29,160 INFO] * number of parameters: 615332048
[2023-02-22 08:40:29,160 INFO] * src vocab size = 256208
[2023-02-22 08:40:29,160 INFO] * tgt vocab size = 256208
[2023-02-22 08:40:29,163 INFO] Get suffix for corpus_1: {'src': ' eng_Latn', 'tgt': ''}
[2023-02-22 08:40:29,369 INFO] Get prefix for corpus_1: {'src': '', 'tgt': 'spa_Latn'}
[2023-02-22 08:40:29,369 INFO] Get prefix for src infer:
[2023-02-22 08:40:29,369 INFO] Get prefix for tgt infer:
[2023-02-22 08:40:29,369 INFO] Get suffix for corpus_1: {'src': ' eng_Latn', 'tgt': ''}
[2023-02-22 08:40:29,546 INFO] Get prefix for corpus_1: {'src': '', 'tgt': 'spa_Latn'}
[2023-02-22 08:40:29,546 INFO] Get prefix for src infer:
[2023-02-22 08:40:29,546 INFO] Get prefix for tgt infer:
[2023-02-22 08:40:29,676 INFO] Starting training on GPU: [0]
[2023-02-22 08:40:29,676 INFO] Start training loop without validation…
[2023-02-22 08:40:29,676 INFO] Scoring with: TransformPipe()
[2023-02-22 08:41:46,168 INFO] Step 1/ 2500; acc: 1.3; ppl: 102504.3; xent: 11.5; lr: 0.05000; sents: 268; bsz: 727/ 920/22; 114/144 tok/s; 76 sec;
[2023-02-22 08:41:49,316 INFO] Step 2/ 2500; acc: 3.6; ppl: 9609.7; xent: 9.2; lr: 0.05000; sents: 284; bsz: 727/ 928/24; 2770/3538 tok/s; 80 sec;
[2023-02-22 08:41:52,439 INFO] Step 3/ 2500; acc: 4.2; ppl: 6028.7; xent: 8.7; lr: 0.05000; sents: 292; bsz: 726/ 918/24; 2789/3529 tok/s; 83 sec;
[2023-02-22 08:41:55,574 INFO] Step 4/ 2500; acc: 5.0; ppl: 4079.2; xent: 8.3; lr: 0.05000; sents: 296; bsz: 730/ 922/25; 2796/3529 tok/s; 86 sec;
[2023-02-22 08:41:58,978 INFO] Step 5/ 2500; acc: 7.8; ppl: 1338.7; xent: 7.2; lr: 0.05000; sents: 284; bsz: 744/ 937/24; 2624/3303 tok/s; 89 sec;
[2023-02-22 08:42:02,087 INFO] Step 6/ 2500; acc: 11.9; ppl: 687.4; xent: 6.5; lr: 0.05000; sents: 280; bsz: 728/ 920/23; 2811/3551 tok/s; 92 sec;
[2023-02-22 08:42:05,226 INFO] Step 7/ 2500; acc: 18.3; ppl: 363.5; xent: 5.9; lr: 0.05000; sents: 280; bsz: 730/ 934/23; 2793/3572 tok/s; 96 sec;
[2023-02-22 08:42:08,374 INFO] Step 8/ 2500; acc: 24.1; ppl: 243.5; xent: 5.5; lr: 0.05000; sents: 316; bsz: 738/ 922/26; 2813/3515 tok/s; 99 sec;
[2023-02-22 08:42:11,476 INFO] Step 9/ 2500; acc: 24.1; ppl: 233.5; xent: 5.5; lr: 0.05000; sents: 272; bsz: 725/ 912/23; 2805/3530 tok/s; 102 sec;
[2023-02-22 08:42:14,585 INFO] Step 10/ 2500; acc: 27.9; ppl: 186.3; xent: 5.2; lr: 0.05000; sents: 288; bsz: 733/ 926/24; 2829/3575 tok/s; 105 sec;
[2023-02-22 08:42:17,623 INFO] Step 11/ 2500; acc: 29.3; ppl: 162.3; xent: 5.1; lr: 0.05000; sents: 264; bsz: 719/ 908/22; 2842/3589 tok/s; 108 sec;

I don't think SGD is the right optimizer to use here, and you also need to accumulate batches, because updates computed from very small batches may not be well aligned with the current weights.

Try these:

Batching

bucket_size: 262144
num_workers: 4
prefetch_factor: 400
world_size: 1
gpu_ranks: [0]
batch_type: "tokens"
batch_size: 1024
valid_batch_size: 512
batch_size_multiple: 1
accum_count: [32, 32, 32]
accum_steps: [0, 15000, 30000]

Optimization

model_dtype: "fp16"
optim: "adam"
learning_rate: 2
warmup_steps: 4000
decay_method: "noam"
adam_beta2: 0.998
max_grad_norm: 0
label_smoothing: 0.1
param_init: 0
param_init_glorot: true
normalization: "tokens"
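
For context, a quick sketch of what these settings imply, assuming the standard "noam" formula (which, as far as I know, is what OpenNMT-py implements): with learning_rate 2, warmup_steps 4000, and hidden_size 1024, the effective learning rate is essentially zero for the first steps, which is why the log below shows lr: 0.00000, and the batching settings give roughly 32k tokens per optimizer update.

# Assumed "noam" schedule:
# lr(step) = base_lr * d_model**-0.5 * min(step**-0.5, step * warmup**-1.5)
def noam_lr(step, base_lr=2.0, warmup=4000, d_model=1024):
    return base_lr * d_model ** -0.5 * min(step ** -0.5, step * warmup ** -1.5)

for step in (1, 10, 100, 1000, 4000, 10000):
    print(step, f"{noam_lr(step):.6f}")   # ~0.000000 at step 1, peaking near step 4000

# Effective tokens per optimizer update with these settings:
print(1024 * 32)   # batch_size (tokens) * accum_count = 32768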

I keep getting similar results. I can fine-tune for a while and see good results on a test set that is similar to the training data.
The thing is, it seems to me that I am only using the NLLB architecture and not the learned weights. Getting such high losses at the beginning on flores200 is not normal, is it?

[2023-02-23 08:01:08,367 INFO] Parsed 1 corpora from -data.
[2023-02-23 08:01:08,367 INFO] Loading checkpoint from nllb-200-600M-onmt.pt
[2023-02-23 08:01:10,336 WARNING] configured transforms is different from checkpoint: +{'suffix', 'prefix', 'sentencepiece'}
[2023-02-23 08:01:10,337 INFO] Get suffix for corpus_1: {'src': ' eng_Latn', 'tgt': ''}
[2023-02-23 08:01:10,337 INFO] Get prefix for corpus_1: {'src': '', 'tgt': 'spa_Latn'}
[2023-02-23 08:01:10,337 INFO] Get prefix for src infer:
[2023-02-23 08:01:10,337 INFO] Get prefix for tgt infer:
[2023-02-23 08:01:10,337 INFO] Get special vocabs from Transforms: {'src': ['eng_Latn'], 'tgt': ['spa_Latn']}.
[2023-02-23 08:01:11,036 INFO] Updating checkpoint vocabulary with new vocabulary
[2023-02-23 08:01:11,037 INFO] Get suffix for corpus_1: {'src': ' eng_Latn', 'tgt': ''}
[2023-02-23 08:01:11,039 INFO] Get prefix for corpus_1: {'src': '', 'tgt': 'spa_Latn'}
[2023-02-23 08:01:11,040 INFO] Get prefix for src infer:
[2023-02-23 08:01:11,042 INFO] Get prefix for tgt infer:
[2023-02-23 08:01:11,043 INFO] Get special vocabs from Transforms: {'src': ['eng_Latn'], 'tgt': ['spa_Latn']}.
[2023-02-23 08:01:11,894 INFO] Building model…
[2023-02-23 08:01:30,847 INFO] Updating vocabulary embeddings with checkpoint embeddings
[2023-02-23 08:08:38,206 INFO] src: 2 new tokens
[2023-02-23 08:16:12,937 INFO] tgt: 2 new tokens
[2023-02-23 08:16:20,105 INFO] NMTModel(
(encoder): TransformerEncoder(
(embeddings): Embeddings(
(make_embedding): Sequential(
(emb_luts): Elementwise(
(0): Embedding(256208, 1024, padding_idx=1)
)
(pe): PositionalEncoding()
)
(dropout): Dropout(p=0.1, inplace=False)
)
(transformer): ModuleList(
(0): TransformerEncoderLayer(
(self_attn): MultiHeadedAttention(
(linear_keys): Linear(in_features=1024, out_features=1024, bias=True)
(linear_values): Linear(in_features=1024, out_features=1024, bias=True)
(linear_query): Linear(in_features=1024, out_features=1024, bias=True)
(softmax): Softmax(dim=-1)
(dropout): Dropout(p=0.1, inplace=False)
(final_linear): Linear(in_features=1024, out_features=1024, bias=True)
)
(feed_forward): PositionwiseFeedForward(
(w_1): Linear(in_features=1024, out_features=4096, bias=True)
(w_2): Linear(in_features=4096, out_features=1024, bias=True)
(layer_norm): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)
(dropout_1): Dropout(p=0.1, inplace=False)
(dropout_2): Dropout(p=0.1, inplace=False)
)
(layer_norm): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)

(11): TransformerEncoderLayer(
(self_attn): MultiHeadedAttention(
(linear_keys): Linear(in_features=1024, out_features=1024, bias=True)
(linear_values): Linear(in_features=1024, out_features=1024, bias=True)
(linear_query): Linear(in_features=1024, out_features=1024, bias=True)
(softmax): Softmax(dim=-1)
(dropout): Dropout(p=0.1, inplace=False)
(final_linear): Linear(in_features=1024, out_features=1024, bias=True)
)
(feed_forward): PositionwiseFeedForward(
(w_1): Linear(in_features=1024, out_features=4096, bias=True)
(w_2): Linear(in_features=4096, out_features=1024, bias=True)
(layer_norm): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)
(dropout_1): Dropout(p=0.1, inplace=False)
(dropout_2): Dropout(p=0.1, inplace=False)
)
(layer_norm): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
)
(layer_norm): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)
)
(decoder): TransformerDecoder(
(embeddings): Embeddings(
(make_embedding): Sequential(
(emb_luts): Elementwise(
(0): Embedding(256208, 1024, padding_idx=1)
)
(pe): PositionalEncoding()
)
(dropout): Dropout(p=0.1, inplace=False)
)
(layer_norm): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)
(transformer_layers): ModuleList(
(0): TransformerDecoderLayer(
(self_attn): MultiHeadedAttention(
(linear_keys): Linear(in_features=1024, out_features=1024, bias=True)
(linear_values): Linear(in_features=1024, out_features=1024, bias=True)
(linear_query): Linear(in_features=1024, out_features=1024, bias=True)
(softmax): Softmax(dim=-1)
(dropout): Dropout(p=0.1, inplace=False)
(final_linear): Linear(in_features=1024, out_features=1024, bias=True)
)
(feed_forward): PositionwiseFeedForward(
(w_1): Linear(in_features=1024, out_features=4096, bias=True)
(w_2): Linear(in_features=4096, out_features=1024, bias=True)
(layer_norm): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)
(dropout_1): Dropout(p=0.1, inplace=False)
(dropout_2): Dropout(p=0.1, inplace=False)
)
(layer_norm_1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)
(drop): Dropout(p=0.1, inplace=False)
(context_attn): MultiHeadedAttention(
(linear_keys): Linear(in_features=1024, out_features=1024, bias=True)
(linear_values): Linear(in_features=1024, out_features=1024, bias=True)
(linear_query): Linear(in_features=1024, out_features=1024, bias=True)
(softmax): Softmax(dim=-1)
(dropout): Dropout(p=0.1, inplace=False)
(final_linear): Linear(in_features=1024, out_features=1024, bias=True)
)
(layer_norm_2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)
)

(11): TransformerDecoderLayer(
(self_attn): MultiHeadedAttention(
(linear_keys): Linear(in_features=1024, out_features=1024, bias=True)
(linear_values): Linear(in_features=1024, out_features=1024, bias=True)
(linear_query): Linear(in_features=1024, out_features=1024, bias=True)
(softmax): Softmax(dim=-1)
(dropout): Dropout(p=0.1, inplace=False)
(final_linear): Linear(in_features=1024, out_features=1024, bias=True)
)
(feed_forward): PositionwiseFeedForward(
(w_1): Linear(in_features=1024, out_features=4096, bias=True)
(w_2): Linear(in_features=4096, out_features=1024, bias=True)
(layer_norm): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)
(dropout_1): Dropout(p=0.1, inplace=False)
(dropout_2): Dropout(p=0.1, inplace=False)
)
(layer_norm_1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)
(drop): Dropout(p=0.1, inplace=False)
(context_attn): MultiHeadedAttention(
(linear_keys): Linear(in_features=1024, out_features=1024, bias=True)
(linear_values): Linear(in_features=1024, out_features=1024, bias=True)
(linear_query): Linear(in_features=1024, out_features=1024, bias=True)
(softmax): Softmax(dim=-1)
(dropout): Dropout(p=0.1, inplace=False)
(final_linear): Linear(in_features=1024, out_features=1024, bias=True)
)
(layer_norm_2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)
)
)
)
(generator): Linear(in_features=1024, out_features=256208, bias=True)
)
[2023-02-23 08:16:20,115 INFO] encoder: 413513728
[2023-02-23 08:16:20,115 INFO] decoder: 201818320
[2023-02-23 08:16:20,115 INFO] * number of parameters: 615332048
[2023-02-23 08:16:20,115 INFO] * src vocab size = 256208
[2023-02-23 08:16:20,115 INFO] * tgt vocab size = 256208
[2023-02-23 08:16:20,120 INFO] Get suffix for corpus_1: {'src': ' eng_Latn', 'tgt': ''}
[2023-02-23 08:16:20,120 INFO] Get prefix for corpus_1: {'src': '', 'tgt': 'spa_Latn'}
[2023-02-23 08:16:20,120 INFO] Get prefix for src infer:
[2023-02-23 08:16:20,120 INFO] Get prefix for tgt infer:
[2023-02-23 08:16:20,335 INFO] Get suffix for corpus_1: {'src': ' eng_Latn', 'tgt': ''}
[2023-02-23 08:16:20,336 INFO] Get prefix for corpus_1: {'src': '', 'tgt': 'spa_Latn'}
[2023-02-23 08:16:20,336 INFO] Get prefix for src infer:
[2023-02-23 08:16:20,336 INFO] Get prefix for tgt infer:
[2023-02-23 08:16:20,693 INFO] Starting training on GPU: [0]
[2023-02-23 08:16:20,694 INFO] Start training loop without validation…
[2023-02-23 08:16:20,694 INFO] Scoring with: TransformPipe()
[2023-02-23 08:17:55,460 INFO] Step 1/ 2500; acc: 1.4; ppl: 97656.5; xent: 11.5; lr: 0.00000; sents: 700; bsz: 684/ 866/22; 231/292 tok/s; 95 sec;
[2023-02-23 08:19:12,669 INFO] Step 2/ 2500; acc: 1.3; ppl: 92510.4; xent: 11.4; lr: 0.00000; sents: 677; bsz: 705/ 894/21; 292/370 tok/s; 172 sec;
[2023-02-23 08:19:18,157 INFO] Step 3/ 2500; acc: 1.3; ppl: 101661.7; xent: 11.5; lr: 0.00000; sents: 720; bsz: 705/ 886/22; 4109/5168 tok/s; 177 sec;
[2023-02-23 08:20:38,075 INFO] Step 4/ 2500; acc: 1.4; ppl: 101701.7; xent: 11.5; lr: 0.00000; sents: 808; bsz: 714/ 898/25; 286/360 tok/s; 257 sec;
[2023-02-23 08:21:44,974 INFO] Step 5/ 2500; acc: 1.2; ppl: 94642.0; xent: 11.5; lr: 0.00000; sents: 674; bsz: 689/ 875/21; 329/419 tok/s; 324 sec;
[2023-02-23 08:22:09,613 INFO] Step 6/ 2500; acc: 1.3; ppl: 98980.9; xent: 11.5; lr: 0.00000; sents: 683; bsz: 705/ 890/21; 916/1156 tok/s; 349 sec;
[2023-02-23 08:23:11,690 INFO] Step 7/ 2500; acc: 1.3; ppl: 100228.3; xent: 11.5; lr: 0.00000; sents: 740; bsz: 707/ 891/23; 365/459 tok/s; 411 sec;
[2023-02-23 08:24:25,100 INFO] Step 8/ 2500; acc: 1.3; ppl: 96861.0; xent: 11.5; lr: 0.00000; sents: 667; bsz: 714/ 910/21; 311/397 tok/s; 484 sec;
[2023-02-23 08:24:50,678 INFO] Step 9/ 2500; acc: 1.3; ppl: 90289.4; xent: 11.4; lr: 0.00000; sents: 760; bsz: 704/ 888/24; 880/1111 tok/s; 510 sec;
[2023-02-23 08:25:36,562 INFO] Step 10/ 2500; acc: 1.5; ppl: 87589.0; xent: 11.4; lr: 0.00000; sents: 644; bsz: 680/ 862/20; 474/601 tok/s; 556 sec;
[2023-02-23 08:26:48,038 INFO] Step 11/ 2500; acc: 1.4; ppl: 89265.1; xent: 11.4; lr: 0.00000; sents: 812; bsz: 715/ 896/25; 320/401 tok/s; 627 sec;
[2023-02-23 08:26:52,863 INFO] Step 12/ 2500; acc: 1.6; ppl: 74492.3; xent: 11.2; lr: 0.00000; sents: 709; bsz: 727/ 922/22; 4822/6115 tok/s; 632 sec;
[2023-02-23 08:28:00,292 INFO] Step 13/ 2500; acc: 1.6; ppl: 67707.8; xent: 11.1; lr: 0.00000; sents: 669; bsz: 676/ 856/21; 321/406 tok/s; 700 sec;
[2023-02-23 08:29:11,725 INFO] Step 14/ 2500; acc: 1.4; ppl: 71751.9; xent: 11.2; lr: 0.00000; sents: 693; bsz: 687/ 869/22; 308/389 tok/s; 771 sec;
[2023-02-23 08:29:16,466 INFO] Step 15/ 2500; acc: 1.7; ppl: 60757.3; xent: 11.0; lr: 0.00000; sents: 681; bsz: 678/ 861/21; 4576/5814 tok/s; 776 sec;
[2023-02-23 08:30:23,037 INFO] Step 16/ 2500; acc: 1.7; ppl: 58641.2; xent: 11.0; lr: 0.00000; sents: 788; bsz: 728/ 913/25; 350/439 tok/s; 842 sec;
[2023-02-23 08:31:34,815 INFO] Step 17/ 2500; acc: 2.1; ppl: 36724.6; xent: 10.5; lr: 0.00000; sents: 690; bsz: 710/ 898/22; 317/400 tok/s; 914 sec;
[2023-02-23 08:31:39,694 INFO] Step 18/ 2500; acc: 2.1; ppl: 34736.9; xent: 10.5; lr: 0.00000; sents: 699; bsz: 706/ 894/22; 4632/5863 tok/s; 919 sec;
[2023-02-23 08:32:47,880 INFO] Step 19/ 2500; acc: 2.2; ppl: 30110.3; xent: 10.3; lr: 0.00000; sents: 684; bsz: 677/ 863/21; 318/405 tok/s; 987 sec;
[2023-02-23 08:33:59,701 INFO] Step 20/ 2500; acc: 2.2; ppl: 28336.6; xent: 10.3; lr: 0.00001; sents: 795; bsz: 738/ 922/25; 329/411 tok/s; 1059 sec;
[2023-02-23 08:34:04,615 INFO] Step 21/ 2500; acc: 2.6; ppl: 23510.7; xent: 10.1; lr: 0.00001; sents: 670; bsz: 681/ 860/21; 4432/5603 tok/s; 1064 sec;
[2023-02-23 08:35:11,330 INFO] Step 22/ 2500; acc: 2.4; ppl: 23757.8; xent: 10.1; lr: 0.00001; sents: 690; bsz: 668/ 850/22; 320/407 tok/s; 1131 sec;
[2023-02-23 08:36:23,794 INFO] Step 23/ 2500; acc: 3.0; ppl: 13052.1; xent: 9.5; lr: 0.00001; sents: 712; bsz: 734/ 925/22; 324/409 tok/s; 1203 sec;
[2023-02-23 08:36:28,659 INFO] Step 24/ 2500; acc: 3.6; ppl: 9205.9; xent: 9.1; lr: 0.00001; sents: 704; bsz: 690/ 874/22; 4539/5746 tok/s; 1208 sec;
[2023-02-23 08:37:38,313 INFO] Step 25/ 2500; acc: 3.4; ppl: 8910.7; xent: 9.1; lr: 0.00001; sents: 712; bsz: 728/ 920/22; 335/423 tok/s; 1278 sec;
[2023-02-23 08:38:51,074 INFO] Step 26/ 2500; acc: 3.5; ppl: 7804.1; xent: 9.0; lr: 0.00001; sents: 715; bsz: 681/ 865/22; 300/381 tok/s; 1350 sec;
[2023-02-23 08:38:55,949 INFO] Step 27/ 2500; acc: 3.6; ppl: 7562.1; xent: 8.9; lr: 0.00001; sents: 714; bsz: 710/ 902/22; 4660/5917 tok/s; 1355 sec;

Can you share an extract of your data?
When I run it on EN to DE, the ppl at the beginning is not that high, and it drops very quickly.

I am using the flores200 test set (just to check whether the numbers make sense). Do you get a reasonable acc/ppl?

The ppl drops quickly. I have tried a medical dataset and it gives excellent results at the end (despite these high values at the beginning); the problem is that the model completely forgets everything else. So I wonder: am I training from scratch? I do not know why the ppl is so high at the beginning.

@martin_bombin
There was a bug for training NLLB-200.

I will commit a fix and new files (the .pt and the dictionary.txt) later today.

Nice! However, I cannot download the dictionary and the model at the moment.

I am using this:
wget --trust-server-names https://s3.amazonaws.com/opennmt-models/nllb-200/nllb-200-600M-onmt.pt
wget --trust-server-names https://opennmt-models.s3.amazonaws.com/nllb-200/dictionary.txt

And I am getting this:

--2023-03-20 12:06:40--  https://s3.amazonaws.com/opennmt-models/nllb-200/nllb-200-600M-onmt.pt
Resolving s3.amazonaws.com (s3.amazonaws.com)... 52.217.196.160, 52.217.198.32, 54.231.202.136, ...
Connecting to s3.amazonaws.com (s3.amazonaws.com)[52.217.196.160]:443... connected.
HTTP request sent, awaiting response... 403 Forbidden
2023-03-20 12:06:41 ERROR 403: Forbidden.

Sorry, this is fixed; the files are public now.