Finetuning and Curating NLLB-200 with OpenNMT-py
I followed this tutorial to add the Ge'ez language, which is not in the original model, to NLLB.
I first trained a SentencePiece model:
import sentencepiece as spm

spm.SentencePieceTrainer.train(input='gmmt/shared_train.txt',
                               model_prefix='shared_gmmt_spm',
                               vocab_size=8000, model_type='bpe')
I then built the vocab from the trained SPM model with the OpenNMT build_vocab tool, so that the vocab is in OpenNMT format (this may not be strictly necessary). Here is the config I used to build the vocab:
share_vocab: true
src_vocab: "gmmt/dictionary1.txt"
src_words_min_frequency: 1
src_vocab_size: 256232
tgt_vocab: "gmmt/dictionary1.txt"
tgt_words_min_frequency: 1
tgt_vocab_size: 8000
vocab_size_multiple: 1
decoder_start_token: '</s>'
#### Subword
src_subword_model: "shared_gmmt_spm.model"
tgt_subword_model: "shared_gmmt_spm.model"
src_subword_nbest: 1
src_subword_alpha: 0.0
tgt_subword_nbest: 1
tgt_subword_alpha: 0.0
# Corpus opts:
data:
    en-gez-gmmt:
        path_src: "gmmt/shared_train.txt"
        path_tgt: "gmmt/shared_train.txt"
        transforms: [sentencepiece, prefix, suffix, filtertoolong]
        weight: 10
        src_prefix: "</s> eng_Latn"
        tgt_prefix: "gez_Ethi"
        src_suffix: ""
        tgt_suffix: ""
update_vocab: true
save_data: "gmmt"
overwrite: true
onmt_build_vocab -config en_gez.yaml -n_sample -1
I then added the tokens from the new dictionary to the NLLB dictionary like this…
with open('gmmt/dictionary1.txt', 'r') as file:
    gmmt_tokens = [line.strip().split()[0] for line in file]

with open('nllb-200/dictionary.txt', 'r') as file:
    nllb_tokens = [line.strip().split()[0] for line in file]

# insert the new tokens just before the last three entries of the NLLB dictionary
added_tokens = set(gmmt_tokens).difference(set(nllb_tokens))
newtokens = nllb_tokens[:-3] + list(added_tokens) + nllb_tokens[-3:]

with open('newdictionary.txt', 'w') as file:
    for line in newtokens:
        file.write(f"{line} 1 \n")
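As a quick sanity check (not part of the tutorial), one can confirm that the merged dictionary keeps the three trailing NLLB entries at the end and has the size used in the training config further below:

# optional check on the merged dictionary
with open('newdictionary.txt', 'r', encoding='utf-8') as f:
    merged = [line.strip().split()[0] for line in f]

print(len(merged))   # should match src_vocab_size / tgt_vocab_size below (260926)
print(merged[-3:])   # should be the same three entries that close the original NLLB dictionary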
Here I added the new tokens to the NLLB SPM model; the script is the same as the one in Finetuning and Curating NLLB-200 with OpenNMT-py. I just added the new language token (gez_Ethi) to the tok_exclusion list.
from unicodedata2 import *
from collections import Counter
from tqdm import tqdm
import sentencepiece as spm
import sentencepiece_model_pb2 as model
tok_exclusion = ['<s>', '<blank>', '</s>', '<unk>', 'gez_Ethi', 'ace_Arab', 'ace_Latn', 'acm_Arab', 'acq_Arab', 'aeb_Arab', 'afr_Latn', 'ajp_Arab', 'aka_Latn', 'amh_Ethi', 'apc_Arab', 'arb_Arab', 'ars_Arab', 'ary_Arab', 'arz_Arab', 'asm_Beng', 'ast_Latn', 'awa_Deva', 'ayr_Latn', 'azb_Arab', 'azj_Latn', 'bak_Cyrl', 'bam_Latn', 'ban_Latn', 'bel_Cyrl', 'bem_Latn', 'ben_Beng', 'bho_Deva', 'bjn_Arab', 'bjn_Latn', 'bod_Tibt', 'bos_Latn', 'bug_Latn', 'bul_Cyrl', 'cat_Latn', 'ceb_Latn', 'ces_Latn', 'cjk_Latn', 'ckb_Arab', 'crh_Latn', 'cym_Latn', 'dan_Latn', 'deu_Latn', 'dik_Latn', 'dyu_Latn', 'dzo_Tibt', 'ell_Grek', 'eng_Latn', 'epo_Latn', 'est_Latn', 'eus_Latn', 'ewe_Latn', 'fao_Latn', 'pes_Arab', 'fij_Latn', 'fin_Latn', 'fon_Latn', 'fra_Latn', 'fur_Latn', 'fuv_Latn', 'gla_Latn', 'gle_Latn', 'glg_Latn', 'grn_Latn', 'guj_Gujr', 'hat_Latn', 'hau_Latn', 'heb_Hebr', 'hin_Deva', 'hne_Deva', 'hrv_Latn', 'hun_Latn', 'hye_Armn', 'ibo_Latn', 'ilo_Latn', 'ind_Latn', 'isl_Latn', 'ita_Latn', 'jav_Latn', 'jpn_Jpan', 'kab_Latn', 'kac_Latn', 'kam_Latn', 'kan_Knda', 'kas_Arab', 'kas_Deva', 'kat_Geor', 'knc_Arab', 'knc_Latn', 'kaz_Cyrl', 'kbp_Latn', 'kea_Latn', 'khm_Khmr', 'kik_Latn', 'kin_Latn', 'kir_Cyrl', 'kmb_Latn', 'kon_Latn', 'kor_Hang', 'kmr_Latn', 'lao_Laoo', 'lvs_Latn', 'lij_Latn', 'lim_Latn', 'lin_Latn', 'lit_Latn', 'lmo_Latn', 'ltg_Latn', 'ltz_Latn', 'lua_Latn', 'lug_Latn', 'luo_Latn', 'lus_Latn', 'mag_Deva', 'mai_Deva', 'mal_Mlym', 'mar_Deva', 'min_Latn', 'mkd_Cyrl', 'plt_Latn', 'mlt_Latn', 'mni_Beng', 'khk_Cyrl', 'mos_Latn', 'mri_Latn', 'zsm_Latn', 'mya_Mymr', 'nld_Latn', 'nno_Latn', 'nob_Latn', 'npi_Deva', 'nso_Latn', 'nus_Latn', 'nya_Latn', 'oci_Latn', 'gaz_Latn', 'ory_Orya', 'pag_Latn', 'pan_Guru', 'pap_Latn', 'pol_Latn', 'por_Latn', 'prs_Arab', 'pbt_Arab', 'quy_Latn', 'ron_Latn', 'run_Latn', 'rus_Cyrl', 'sag_Latn', 'san_Deva', 'sat_Beng', 'scn_Latn', 'shn_Mymr', 'sin_Sinh', 'slk_Latn', 'slv_Latn', 'smo_Latn', 'sna_Latn', 'snd_Arab', 'som_Latn', 'sot_Latn', 'spa_Latn', 'als_Latn', 'srd_Latn', 'srp_Cyrl', 'ssw_Latn', 'sun_Latn', 'swe_Latn', 'swh_Latn', 'szl_Latn', 'tam_Taml', 'tat_Cyrl', 'tel_Telu', 'tgk_Cyrl', 'tgl_Latn', 'tha_Thai', 'tir_Ethi', 'taq_Latn', 'taq_Tfng', 'tpi_Latn', 'tsn_Latn', 'tso_Latn', 'tuk_Latn', 'tum_Latn', 'tur_Latn', 'twi_Latn', 'tzm_Tfng', 'uig_Arab', 'ukr_Cyrl', 'umb_Latn', 'urd_Arab', 'uzn_Latn', 'vec_Latn', 'vie_Latn', 'war_Latn', 'wol_Latn', 'xho_Latn', 'ydd_Hebr', 'yor_Latn', 'yue_Hant', 'zho_Hans', 'zho_Hant', 'zul_Latn', '<pad1>', '<pad2>', '<pad3>', '<inv>']
newdict2 = []
with open('newdictionary.txt', 'r', encoding='utf-8') as f:
    for line in f:
        token = line.strip().split()[0]
        newdict2.append(token)

serializedStr = open('nllb-200/flores200_sacrebleu_tokenizer_spm.model', 'rb').read()
m = model.ModelProto()
m.ParseFromString(serializedStr)

curdict = []
for i in tqdm(range(len(m.pieces) - 1, 2, -1)):
    curdict.append(m.pieces[i].piece)
    if m.pieces[i].piece not in newdict2:
        hex_string = "".join("{:02x}".format(ord(c)) for c in m.pieces[i].piece)
        print("Removing: ", hex_string, " from spm model, not in dict. Index: ", i)
        m.pieces.pop(i)

for tok in tqdm(newdict2):
    if (tok not in curdict) and (tok not in tok_exclusion):
        print("Adding: ", tok, " to spm model")
        newtoken = m.SentencePiece()
        newtoken.piece = tok
        newtoken.score = 0
        m.pieces.append(newtoken)

print(len(m.pieces))
with open('flores200_sacrebleu_tokenizer_spm2.model', 'wb') as f:
    f.write(m.SerializeToString())
0%| | 228/255997 [00:00<09:45, 436.81it/s]
Removing: 85 from spm model, not in dict. Index: 255860
100%|██████████| 255997/255997 [04:58<00:00, 857.77it/s]
98%|█████████▊| 256024/260926 [04:43<00:00, 21191.42it/s]
Adding: ▁weep to spm model
Adding: ▁ወለያ to spm model
Adding: ልአ to spm model
Adding: ▁ወፋ to spm model
Adding: ፃረ to spm model
Adding: ▁ይቅት to spm model
Adding: baal to spm model
Adding: ▁ጽድቀ to spm model
Adding: ▁calf to spm model
Adding: ▁ወአና to spm model
Adding: ይቴ to spm model
Adding: ▁ሞጻ to spm model
Adding: መፃ to spm model
Adding: ▁ዐራ to spm model
Adding: ኒአ to spm model
Adding: ላዕሌ to spm model
Adding: ▁ወዲበ to spm model
Adding: ዐኒ to spm model
Adding: ▁Jezreel to spm model
Adding: ሁኒ to spm model
Adding: ▁Gilgal to spm model
Adding: ሠሥ to spm model
Adding: ▁priests to spm model
Adding: ▁ዘውስተ to spm model
Adding: ጽሖ to spm model
Adding: ▁Cursed to spm model
Adding: ▁ካህን to spm model
Adding: ▁ይጸል to spm model
Adding: ይከ to spm model
Adding: ፸ to spm model
Adding: ▁prophesied to spm model
Adding: ዐር to spm model
Adding: ሞር to spm model
Adding: ልፈ to spm model
Adding: ▁ገብርከ to spm model
Adding: servants to spm model
Adding: ▁bullock to spm model
Adding: ▁በውእቶን to spm model
Adding: ▁ርኢኩ to spm model
Adding: ባአ to spm model
Adding: ሕየ to spm model
Adding: ▁እብል to spm model
Adding: ▁ወባ to spm model
Adding: ▁hired to spm model
Adding: ▁በልቡ to spm model
Adding: ▁መልአኮሙ to spm model
Adding: ▁ለይእቲ to spm model
Adding: aroth to spm model
Adding: ፷ to spm model
Adding: ▁smitten to spm model
Adding: ▁ዘመጽአ to spm model
Adding: ብስ to spm model
Adding: ▁ውእተ to spm model
Adding: ንዋ to spm model
Adding: ▁ዕረጉ to spm model
Adding: ▁haste to spm model
Adding: ▁ርስ to spm model
Adding: ould to spm model
Adding: ▁ወኢምንተ to spm model
Adding: ▁ለአምላክ to spm model
Adding: ክሉ to spm model
Adding: ዓዕ to spm model
Adding: ቦሙ to spm model
Adding: ▁አህጉር to spm model
Adding: ዕለተ to spm model
Adding: ▁መታክ to spm model
Adding: ▁ወኢነ to spm model
Adding: ▁በዕለተ to spm model
Adding: ▁ቀሠ to spm model
Adding: አነ to spm model
Adding: ▁ዕቅ to spm model
Adding: ▁ርስተ to spm model
Adding: ▁አዋልዲ to spm model
Adding: ▁ኀፍረተ to spm model
Adding: ▁በቤተ to spm model
Adding: ▁ዓመተ to spm model
Adding: ▁ገጹ to spm model
.
.
.
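(Side note: before training, a quick way to sanity-check the rebuilt SPM model is to load it and segment some Ge'ez text; the sample string below reuses pieces that were added above.)

import sentencepiece as spm

# optional check: the rebuilt model should load, and the piece count should
# match the len(m.pieces) printed by the script above
sp = spm.SentencePieceProcessor(model_file='flores200_sacrebleu_tokenizer_spm2.model')
print(sp.get_piece_size())
print(sp.encode('ወለያ ጽድቀ', out_type=str))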
I then trained the new model using LoRA weights and fusedadam as the optimizer. Here is the config:
share_vocab: true
src_vocab: "newdictionary.txt"
src_words_min_frequency: 1
src_vocab_size: 260926
tgt_vocab: "newdictionary.txt"
tgt_words_min_frequency: 1
tgt_vocab_size: 260926
vocab_size_multiple: 1
decoder_start_token: '</s>'
#LoRa
lora_layers: ['linear_values', 'linear_query', 'linear_keys', 'final_linear']
lora_rank: 2
lora_dropout: 0.0
lora_alpha: 1
lora_embedding: false
#### Subword
src_subword_model: "flores200_sacrebleu_tokenizer_spm2.model"
tgt_subword_model: "flores200_sacrebleu_tokenizer_spm2.model"
src_subword_nbest: 1
src_subword_alpha: 0.0
tgt_subword_nbest: 1
tgt_subword_alpha: 0.0
# Corpus opts:
data:
    cc-matrix-enzh:
        path_src: "gmmt/en_train.txt"
        path_tgt: "gmmt/gez_train.txt"
        transforms: [sentencepiece, prefix, suffix, filtertoolong]
        weight: 10
        src_prefix: "</s> eng_Latn"
        tgt_prefix: "gez_Ethi"
        src_suffix: "</s>"
        tgt_suffix: ""
update_vocab: true
train_from: "nllb-200/nllb-200-1.3Bdst-onmt.pt.1"
reset_optim: all
save_data: "finetuned"
save_model: "finetuned/gez_nllb"
log_file: "finetuned/finetuned.log"
keep_checkpoint: 50
save_checkpoint_steps: 100
average_decay: 0.0005
seed: 1234
report_every: 10
train_steps: 20000
valid_steps: 100
# Batching
bucket_size: 262144
num_workers: 2
prefetch_factor: 400
world_size: 1
gpu_ranks: [0]
batch_type: "tokens"
batch_size: 256
valid_batch_size: 256
batch_size_multiple: 1
accum_count: [32, 32, 32]
accum_steps: [0, 15000, 30000]
# Optimization
model_dtype: "fp16"
optim: "fusedadam"
learning_rate: 0.1
warmup_steps: 50
decay_method: "noam"
adam_beta2: 0.98
max_grad_norm: 0
label_smoothing: 0.1
param_init: 0
param_init_glorot: true
normalization: "tokens"
# Model
override_opts: true
encoder_type: transformer
decoder_type: transformer
enc_layers: 24
dec_layers: 24
heads: 16
hidden_size: 1024
word_vec_size: 1024
transformer_ff: 8192
dropout_steps: [0, 15000, 30000]
dropout: [0.1, 0.1, 0.1]
attention_dropout: [0.1, 0.1, 0.1]
share_decoder_embeddings: true
share_embeddings: true
position_encoding: true
position_encoding_type: 'SinusoidalConcat'
python3 ../OpenNMT-py/train.py --config finetuned/finetune.yaml
[2023-05-19 12:57:38,706 INFO] Loading checkpoint from nllb-200/nllb-200-1.3Bdst-onmt.pt.1
[2023-05-19 12:57:40,337 WARNING] configured transforms is different from checkpoint: +{'sentencepiece', 'suffix', 'prefix'}
[2023-05-19 12:57:40,337 INFO] Get suffix for cc-matrix-enzh: {'src': '</s>', 'tgt': ''}
[2023-05-19 12:57:40,337 INFO] Get suffix for src infer:
[2023-05-19 12:57:40,337 INFO] Get suffix for tgt infer:
[2023-05-19 12:57:40,337 INFO] Get prefix for cc-matrix-enzh: {'src': '</s> eng_Latn', 'tgt': 'gez_Ethi'}
[2023-05-19 12:57:40,337 INFO] Get prefix for src infer:
[2023-05-19 12:57:40,337 INFO] Get prefix for tgt infer:
[2023-05-19 12:57:40,337 INFO] Get special vocabs from Transforms: {'src': ['</s>', '</s>', 'eng_Latn'], 'tgt': ['gez_Ethi']}.
[2023-05-19 12:57:40,902 INFO] Updating checkpoint vocabulary with new vocabulary
[2023-05-19 12:57:40,903 INFO] Get suffix for cc-matrix-enzh: {'src': '</s>', 'tgt': ''}
[2023-05-19 12:57:40,904 INFO] Get suffix for src infer:
[2023-05-19 12:57:40,905 INFO] Get suffix for tgt infer:
[2023-05-19 12:57:40,906 INFO] Get prefix for cc-matrix-enzh: {'src': '</s> eng_Latn', 'tgt': 'gez_Ethi'}
[2023-05-19 12:57:40,908 INFO] Get prefix for src infer:
[2023-05-19 12:57:40,909 INFO] Get prefix for tgt infer:
[2023-05-19 12:57:40,911 INFO] Get special vocabs from Transforms: {'src': ['</s>', '</s>', 'eng_Latn'], 'tgt': ['gez_Ethi']}.
[2023-05-19 12:57:41,534 INFO] Over-ride model option set to true - use with care
[2023-05-19 12:57:41,534 INFO] Option: config , value: finetuned/finetune.yaml overiding model:
[2023-05-19 12:57:41,534 INFO] Option: data , value: {'cc-matrix-enzh': {'path_src': 'gmmt/en_train.txt', 'path_tgt': 'gmmt/gez_train.txt', 'transforms': ['sentencepiece', 'prefix', 'suffix', 'filtertoolong'], 'weight': 10, 'src_prefix': '</s> eng_Latn', 'tgt_prefix': 'gez_Ethi', 'src_suffix': '</s>', 'tgt_suffix': '', 'path_align': None}} overiding model: {}
[2023-05-19 12:57:41,534 INFO] Option: skip_empty_level , value: warning overiding model: silent
[2023-05-19 12:57:41,534 INFO] Option: save_data , value: finetuned overiding model:
[2023-05-19 12:57:41,534 INFO] Option: src_vocab , value: newdictionary.txt overiding model:
[2023-05-19 12:57:41,534 INFO] Option: tgt_vocab , value: newdictionary.txt overiding model:
[2023-05-19 12:57:41,534 INFO] Option: src_vocab_size , value: 260926 overiding model: 256206
[2023-05-19 12:57:41,534 INFO] Option: tgt_vocab_size , value: 260926 overiding model: 256206
[2023-05-19 12:57:41,534 INFO] Option: src_subword_model , value: flores200_sacrebleu_tokenizer_spm2.model overiding model:
[2023-05-19 12:57:41,534 INFO] Option: tgt_subword_model , value: flores200_sacrebleu_tokenizer_spm2.model overiding model:
[2023-05-19 12:57:41,535 INFO] Option: src_seq_length , value: 192 overiding model: 150
[2023-05-19 12:57:41,535 INFO] Option: tgt_seq_length , value: 192 overiding model: 150
[2023-05-19 12:57:41,535 INFO] Option: update_vocab , value: True overiding model: False
[2023-05-19 12:57:41,535 INFO] Option: add_qkvbias , value: False overiding model: True
[2023-05-19 12:57:41,535 INFO] Option: save_model , value: finetuned/gez_nllb overiding model: nllb
[2023-05-19 12:57:41,535 INFO] Option: save_checkpoint_steps , value: 100 overiding model: 5000
[2023-05-19 12:57:41,535 INFO] Option: train_from , value: nllb-200/nllb-200-1.3Bdst-onmt.pt.1 overiding model:
[2023-05-19 12:57:41,535 INFO] Option: reset_optim , value: all overiding model: none
[2023-05-19 12:57:41,535 INFO] Option: num_workers , value: 2 overiding model: 4
[2023-05-19 12:57:41,535 INFO] Option: batch_size , value: 256 overiding model: 8192
[2023-05-19 12:57:41,535 INFO] Option: accum_count , value: [32, 32, 32] overiding model: [4]
[2023-05-19 12:57:41,535 INFO] Option: accum_steps , value: [0, 15000, 30000] overiding model: [0]
[2023-05-19 12:57:41,535 INFO] Option: valid_steps , value: 100 overiding model: 5000
[2023-05-19 12:57:41,535 INFO] Option: valid_batch_size , value: 256 overiding model: 4096
[2023-05-19 12:57:41,535 INFO] Option: train_steps , value: 20000 overiding model: 100000
[2023-05-19 12:57:41,535 INFO] Option: optim , value: fusedadam overiding model:
[2023-05-19 12:57:41,535 INFO] Option: dropout , value: [0.1, 0.1, 0.1] overiding model: [0.1]
[2023-05-19 12:57:41,536 INFO] Option: attention_dropout , value: [0.1, 0.1, 0.1] overiding model: [0.1]
[2023-05-19 12:57:41,536 INFO] Option: dropout_steps , value: [0, 15000, 30000] overiding model: [0]
[2023-05-19 12:57:41,536 INFO] Option: average_decay , value: 0.0005 overiding model: 0.0
[2023-05-19 12:57:41,536 INFO] Option: learning_rate , value: 0.1 overiding model: 5e-05
[2023-05-19 12:57:41,536 INFO] Option: decay_method , value: noam overiding model: none
[2023-05-19 12:57:41,536 INFO] Option: warmup_steps , value: 50 overiding model: 4000
[2023-05-19 12:57:41,536 INFO] Option: log_file , value: finetuned/finetuned.log overiding model:
[2023-05-19 12:57:41,536 INFO] Option: report_every , value: 10 overiding model: 100
[2023-05-19 12:57:41,536 INFO] Option: _all_transform , value: {'sentencepiece', 'filtertoolong', 'suffix', 'prefix'} overiding model: {'filtertoolong'}
[2023-05-19 12:57:41,536 INFO] Building model...
[2023-05-19 12:57:51,128 INFO] Adding LoRa layers for linear_values
[2023-05-19 12:57:51,924 INFO] Adding LoRa layers for linear_query
[2023-05-19 12:57:52,723 INFO] Adding LoRa layers for linear_keys
[2023-05-19 12:57:53,521 INFO] Adding LoRa layers for final_linear
[2023-05-19 12:58:03,997 INFO] Updating vocabulary embeddings with checkpoint embeddings
[2023-05-19 12:58:04,384 INFO] src: 260921 new tokens
[2023-05-19 12:58:04,830 INFO] tgt: 260921 new tokens
[2023-05-19 12:58:07,084 INFO] NMTModel(
(encoder): TransformerEncoder(
(embeddings): Embeddings(
(make_embedding): Sequential(
(emb_luts): Elementwise(
(0): Embedding(260926, 1024, padding_idx=1)
)
(pe): PositionalEncoding()
)
(dropout): Dropout(p=0.1, inplace=False)
)
(transformer): ModuleList(
(0-23): 24 x TransformerEncoderLayer(
(self_attn): MultiHeadedAttention(
(linear_keys): Linear(in_features=1024, out_features=1024, bias=False)
(linear_values): Linear(in_features=1024, out_features=1024, bias=False)
(linear_query): Linear(in_features=1024, out_features=1024, bias=False)
(softmax): Softmax(dim=-1)
(dropout): Dropout(p=0.1, inplace=False)
(final_linear): Linear(in_features=1024, out_features=1024, bias=False)
)
(feed_forward): PositionwiseFeedForward(
(w_1): Linear(in_features=1024, out_features=8192, bias=True)
(w_2): Linear(in_features=8192, out_features=1024, bias=True)
(layer_norm): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)
(dropout_1): Dropout(p=0.1, inplace=False)
(dropout_2): Dropout(p=0.1, inplace=False)
)
(layer_norm): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
)
(layer_norm): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)
)
(decoder): TransformerDecoder(
(embeddings): Embeddings(
(make_embedding): Sequential(
(emb_luts): Elementwise(
(0): Embedding(260926, 1024, padding_idx=1)
)
(pe): PositionalEncoding()
)
(dropout): Dropout(p=0.1, inplace=False)
)
(layer_norm): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)
(transformer_layers): ModuleList(
(0-23): 24 x TransformerDecoderLayer(
(self_attn): MultiHeadedAttention(
(linear_keys): Linear(in_features=1024, out_features=1024, bias=False)
(linear_values): Linear(in_features=1024, out_features=1024, bias=False)
(linear_query): Linear(in_features=1024, out_features=1024, bias=False)
(softmax): Softmax(dim=-1)
(dropout): Dropout(p=0.1, inplace=False)
(final_linear): Linear(in_features=1024, out_features=1024, bias=False)
)
(feed_forward): PositionwiseFeedForward(
(w_1): Linear(in_features=1024, out_features=8192, bias=True)
(w_2): Linear(in_features=8192, out_features=1024, bias=True)
(layer_norm): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)
(dropout_1): Dropout(p=0.1, inplace=False)
(dropout_2): Dropout(p=0.1, inplace=False)
)
(layer_norm_1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)
(drop): Dropout(p=0.1, inplace=False)
(context_attn): MultiHeadedAttention(
(linear_keys): Linear(in_features=1024, out_features=1024, bias=False)
(linear_values): Linear(in_features=1024, out_features=1024, bias=False)
(linear_query): Linear(in_features=1024, out_features=1024, bias=False)
(softmax): Softmax(dim=-1)
(dropout): Dropout(p=0.1, inplace=False)
(final_linear): Linear(in_features=1024, out_features=1024, bias=False)
)
(layer_norm_2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)
)
)
)
(generator): Linear(in_features=1024, out_features=260926, bias=True)
)
[2023-05-19 12:58:07,092 INFO] encoder: 771219456
[2023-05-19 12:58:07,092 INFO] decoder: 605397822
[2023-05-19 12:58:07,092 INFO] * number of parameters: 1376617278
[2023-05-19 12:58:07,092 INFO] * src vocab size = 260926
[2023-05-19 12:58:07,092 INFO] * tgt vocab size = 260926
[2023-05-19 12:58:07,195 INFO] Get suffix for cc-matrix-enzh: {'src': '</s>', 'tgt': ''}
[2023-05-19 12:58:07,195 INFO] Get suffix for src infer:
[2023-05-19 12:58:07,195 INFO] Get suffix for tgt infer:
[2023-05-19 12:58:07,196 INFO] Get prefix for cc-matrix-enzh: {'src': '</s> eng_Latn', 'tgt': 'gez_Ethi'}
[2023-05-19 12:58:07,196 INFO] Get prefix for src infer:
[2023-05-19 12:58:07,196 INFO] Get prefix for tgt infer:
[2023-05-19 12:58:07,274 INFO] Get suffix for cc-matrix-enzh: {'src': '</s>', 'tgt': ''}
[2023-05-19 12:58:07,274 INFO] Get suffix for src infer:
[2023-05-19 12:58:07,274 INFO] Get suffix for tgt infer:
[2023-05-19 12:58:07,274 INFO] Get prefix for cc-matrix-enzh: {'src': '</s> eng_Latn', 'tgt': 'gez_Ethi'}
[2023-05-19 12:58:07,274 INFO] Get prefix for src infer:
[2023-05-19 12:58:07,274 INFO] Get prefix for tgt infer:
[2023-05-19 12:58:07,316 INFO] Starting training on GPU: [0]
[2023-05-19 12:58:07,316 INFO] Start training loop without validation...
[2023-05-19 12:58:07,316 INFO] Scoring with: TransformPipe()
[2023-05-19 13:00:29,289 INFO] Step 10/20000; acc: 83.8; ppl: 38.9; xent: 3.7; lr: 0.00010; sents: 2130; bsz: 229/ 168/ 7; 517/378 tok/s; 142 sec;
[2023-05-19 13:01:39,009 INFO] Step 20/20000; acc: 86.3; ppl: 29.0; xent: 3.4; lr: 0.00019; sents: 1961; bsz: 230/ 167/ 6; 1055/767 tok/s; 212 sec;
[2023-05-19 13:02:48,279 INFO] Step 30/20000; acc: 89.5; ppl: 18.7; xent: 2.9; lr: 0.00027; sents: 1936; bsz: 228/ 166/ 6; 1056/767 tok/s; 281 sec;
[2023-05-19 13:03:57,596 INFO] Step 40/20000; acc: 91.5; ppl: 12.0; xent: 2.5; lr: 0.00036; sents: 2027; bsz: 230/ 169/ 6; 1063/782 tok/s; 350 sec;
[2023-05-19 13:05:06,485 INFO] Step 50/20000; acc: 92.2; ppl: 9.7; xent: 2.3; lr: 0.00044; sents: 2007; bsz: 229/ 167/ 6; 1064/777 tok/s; 419 sec;
[2023-05-19 13:06:15,215 INFO] Step 60/20000; acc: 92.4; ppl: 8.8; xent: 2.2; lr: 0.00040; sents: 1999; bsz: 231/ 167/ 6; 1075/778 tok/s; 488 sec;
I merged the LoRA weights with the base model as follows and tried to run inference with the config below.
python3 ../OpenNMT-py/tools/lora_weights.py --action merge --base_model nllb-200/nllb-200-1.3Bdst-onmt.pt --lora_weights finetuned/gez_nllb_step_20000.pt --output geez_nllb_finetuned.pt
transforms: [sentencepiece, prefix, suffix]
# nllb-200 specific prefixing and suffixing
src_prefix: "eng_Latn"
tgt_prefix: "fra_Latn"
tgt_file_prefix: true
src_suffix: "</s>"
tgt_suffix: ""
#### Subword
src_subword_model: "flores200_sacrebleu_tokenizer_spm2.model"
tgt_subword_model: "flores200_sacrebleu_tokenizer_spm2.model"
src_subword_nbest: 1
src_subword_alpha: 0.0
tgt_subword_nbest: 1
tgt_subword_alpha: 0.0
# Model info
model: "geez_nllb_finetuned_1.pt"
# Inference
max_length: 512
gpu: 0
batch_type: tokens
batch_size: 32
fp16:
beam_size: 5
report_time: true
python3 ../OpenNMT-py/translate.py --config finetuned/geez_nllb_inference.yaml -src en_text.src -output gez_hyp.txt
But it raised the following error:
Traceback (most recent call last):
  File "../OpenNMT-py/translate.py", line 6, in <module>
    main()
  File "/home/aman/Documents/geeztranslation/OpenNMT-py/onmt/bin/translate.py", line 60, in main
    translate(opt)
  File "/home/aman/Documents/geeztranslation/OpenNMT-py/onmt/bin/translate.py", line 23, in translate
    translator = build_translator(opt, logger=logger,
  File "/home/aman/Documents/geeztranslation/OpenNMT-py/onmt/translate/translator.py", line 33, in build_translator
    vocabs, model, model_opt = load_test_model(opt)
  File "/home/aman/Documents/geeztranslation/OpenNMT-py/onmt/model_builder.py", line 171, in load_test_model
    model = build_base_model(model_opt, vocabs, checkpoint)
  File "/home/aman/Documents/geeztranslation/OpenNMT-py/onmt/model_builder.py", line 402, in build_base_model
    model.load_state_dict(checkpoint['model'],
  File "/home/aman/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 2041, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for NMTModel:
Missing key(s) in state_dict: "encoder.transformer.0.self_attn.linear_keys.bias", "encoder.transformer.0.self_attn.linear_values.bias", "encoder.transformer.0.self_attn.linear_query.bias", "encoder.transformer.0.self_attn.final_linear.bias", "encoder.transformer.1.self_attn.linear_keys.bias", "encoder.transformer.1.self_attn.linear_values.bias", "encoder.transformer.1.self_attn.linear_query.bias", "encoder.transformer.1.self_attn.final_linear.bias", "encoder.transformer.2.self_attn.linear_keys.bias", "encoder.transformer.2.self_attn.linear_values.bias", "encoder.transformer.2.self_attn.linear_query.bias", "encoder.transformer.2.self_attn.final_linear.bias", "encoder.transformer.3.self_attn.linear_keys.bias", "encoder.transformer.3.self_attn.linear_values.bias", "encoder.transformer.3.self_attn.linear_query.bias", "encoder.transformer.3.self_attn.final_linear.bias", "encoder.transformer.4.self_attn.linear_keys.bias", "encoder.transformer.4.self_attn.linear_values.bias", "encoder.transformer.4.self_attn.linear_query.bias", "encoder.transformer.4.self_attn.final_linear.bias", "encoder.transformer.5.self_attn.linear_keys.bias", "encoder.transformer.5.self_attn.linear_values.bias", "encoder.transformer.5.self_attn.linear_query.bias", "encoder.transformer.5.self_attn.final_linear.bias", "encoder.transformer.6.self_attn.linear_keys.bias", "encoder.transformer.6.self_attn.linear_values.bias", "encoder.transformer.6.self_attn.linear_query.bias", "encoder.transformer.6.self_attn.final_linear.bias", "encoder.transformer.7.self_attn.linear_keys.bias", "encoder.transformer.7.self_attn.linear_values.bias", "encoder.transformer.7.self_attn.linear_query.bias", "encoder.transformer.7.self_attn.final_linear.bias", "encoder.transformer.8.self_attn.linear_keys.bias", "encoder.transformer.8.self_attn.linear_values.bias", "encoder.transformer.8.self_attn.linear_query.bias", "encoder.transformer.8.self_attn.final_linear.bias", "encoder.transformer.9.self_attn.linear_keys.bias", "encoder.transformer.9.self_attn.linear_values.bias", "encoder.transformer.9.self_attn.linear_query.bias", "encoder.transformer.9.self_attn.final_linear.bias", "encoder.transformer.10.self_attn.linear_keys.bias", "encoder.transformer.10.self_attn.linear_values.bias", "encoder.transformer.10.self_attn.linear_query.bias", "encoder.transformer.10.self_attn.final_linear.bias", "encoder.transformer.11.self_attn.linear_keys.bias", "encoder.transformer.11.self_attn.linear_values.bias", "encoder.transformer.11.self_attn.linear_query.bias", "encoder.transformer.11.self_attn.final_linear.bias", "encoder.transformer.12.self_attn.linear_keys.bias", "encoder.transformer.12.self_attn.linear_values.bias", "encoder.transformer.12.self_attn.linear_query.bias", "encoder.transformer.12.self_attn.final_linear.bias", "encoder.transformer.13.self_attn.linear_keys.bias", "encoder.transformer.13.self_attn.linear_values.bias", "encoder.transformer.13.self_attn.linear_query.bias", "encoder.transformer.13.self_attn.final_linear.bias", "encoder.transformer.14.self_attn.linear_keys.bias", "encoder.transformer.14.self_attn.linear_values.bias", "encoder.transformer.14.self_attn.linear_query.bias", "encoder.transformer.14.self_attn.final_linear.bias", "encoder.transformer.15.self_attn.linear_keys.bias", "encoder.transformer.15.self_attn.linear_values.bias", "encoder.transformer.15.self_attn.linear_query.bias", "encoder.transformer.15.self_attn.final_linear.bias", "encoder.transformer.16.self_attn.linear_keys.bias", "encoder.transformer.16.self_attn.linear_values.bias", 
"encoder.transformer.16.self_attn.linear_query.bias", "encoder.transformer.16.self_attn.final_linear.bias", "encoder.transformer.17.self_attn.linear_keys.bias", "encoder.transformer.17.self_attn.linear_values.bias", "encoder.transformer.17.self_attn.linear_query.bias", "encoder.transformer.17.self_attn.final_linear.bias", "encoder.transformer.18.self_attn.linear_keys.bias", "encoder.transformer.18.self_attn.linear_values.bias", "encoder.transformer.18.self_attn.linear_query.bias", "encoder.transformer.18.self_attn.final_linear.bias", "encoder.transformer.19.self_attn.linear_keys.bias", "encoder.transformer.19.self_attn.linear_values.bias", "encoder.transformer.19.self_attn.linear_query.bias", "encoder.transformer.19.self_attn.final_linear.bias", "encoder.transformer.20.self_attn.linear_keys.bias", "encoder.transformer.20.self_attn.linear_values.bias", "encoder.transformer.20.self_attn.linear_query.bias", "encoder.transformer.20.self_attn.final_linear.bias", "encoder.transformer.21.self_attn.linear_keys.bias", "encoder.transformer.21.self_attn.linear_values.bias", "encoder.transformer.21.self_attn.linear_query.bias", "encoder.transformer.21.self_attn.final_linear.bias", "encoder.transformer.22.self_attn.linear_keys.bias", "encoder.transformer.22.self_attn.linear_values.bias", "encoder.transformer.22.self_attn.linear_query.bias", "encoder.transformer.22.self_attn.final_linear.bias"
Based on a discussion here, Finetuning bigger models with LoRa, I fixed it in the following way and the inference ran successfully. I also had to reduce the batch_size to 32 because of an OOM issue.
import torch

m = torch.load("geez_nllb_finetuned.pt")
# build the model without q/k/v biases at load time, matching the finetuned weights
m['opt'].add_qkvbias = False
torch.save(m, "geez_nllb_finetuned_1.pt")
But the translation is weird.
⁇ ⁇ ⁇ ⁇ ⁇ ⁇ ⁇ ⁇ ⁇ ⁇ ⁇ ⁇ ⁇ ⁇ ⁇ ⁇ ⁇ ⁇ ⁇ ⁇
⁇ ⁇ ⁇ ⁇ ⁇ ⁇ ⁇ ⁇ ⁇ ⁇ ⁇ ⁇ ⁇ ⁇ ⁇ ⁇ ⁇ ⁇ ⁇ ⁇ ⁇ ⁇ ⁇ ⁇ ⁇ ⁇ ⁇ ⁇
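In case it helps with diagnosing this, here is a rough way to inspect the patched checkpoint and confirm that the generator/vocab sizes still match the new dictionary (a sketch only; the key names are assumptions and may vary across OpenNMT-py versions):

import torch

# rough inspection of the patched checkpoint; adjust key names if they differ
ckpt = torch.load("geez_nllb_finetuned_1.pt", map_location="cpu")
print(ckpt.keys())
if ckpt.get('generator') is not None:
    for name, tensor in ckpt['generator'].items():
        print(name, tuple(tensor.shape))  # the generator dimension should be 260926, as in the model printout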
Please help…