Training a multilingual model with OpenNMT-py

I was able to get M2M-100 working with CTranslate2 and have been trying to train a similar multilingual model from scratch using OpenNMT-py.

What is the best format to use for the tokens that tell the model what the source and target languages are? For M2M-100 I prepended the source language token to the source text and then called ctranslate2.Translator.translate_batch with target_prefix=[[target_code_token]] * len(tokenized_sentences). Another option is to prepend the target code token to the source text, like in this tutorial.

M2M-100 format:

__en__William Caxton (c. 1422 – c. 1491) was an English merchant, diplomat and writer.
__fr__William Caxton (c. 1422 – c. 1491) était un marchand, diplomate et écrivain anglais.

Prepend the target code to source text:

__fr__William Caxton (c. 1422 – c. 1491) was an English merchant, diplomat and writer.
William Caxton (c. 1422 – c. 1491) was an English merchant, diplomat and writer.

Prepend the source and target code to source text:

__en__ __fr__William Caxton (c. 1422 – c. 1491) was an English merchant, diplomat and writer.
William Caxton (c. 1422 – c. 1491) was an English merchant, diplomat and writer.

Is there any sort of industry standard for this? I think I prefer prepending the source and target code tokens to the source text but I also want to maximize compatibility with models trained by other people.

Additionally, how does the target_prefix parameter work in CTranslate2? My understanding is that it forces the decoder to start each output with the provided prefix tokens. The target_prefix parameter works with M2M-100 but doesn’t seem to work with my OpenNMT-py model.
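
For reference, this is roughly how I’m calling it (a minimal sketch; the model directory and the tokens are placeholders):

import ctranslate2

translator = ctranslate2.Translator("model_dir", device="cpu")

# Source sentences, already tokenized with SentencePiece and
# prefixed with the source language token (M2M-100 style).
batch = [["__en__", "▁Cheese"]]

# Force every hypothesis to start with the target language token.
results = translator.translate_batch(
    batch,
    target_prefix=[["__fr__"]] * len(batch),
)
print(results[0].hypotheses[0])  # target tokens, beginning with "__fr__"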

I created a multilingual dataset with 95,949,184 lines of data from Opus and formatted it in the M2M-100 format. I then trained a model with OpenNMT-py like I would for an individual language pair. I ran the model with CTranslate2, prepending the source code token to the source text and using the target_prefix parameter for the target language. I get completely incorrect output, and the target_prefix doesn’t seem to affect the translation at all.

$ argos-translate -f en -t de "Cheese"
es ies.
$ argos-translate -f en -t fr "Cheese"
es ies.
$ argos-translate -f en -t es "Cheese"
es ies.
$ argos-translate -f en -t es "I'm flying to Miami next week."
Miami i Miami.

Will most pretrained models just need custom logic? I know this is often true when running models from Hugging Face: language models from different companies often need a custom tokenizer or other custom logic.

There is no industry-standard format.

Most papers reflect the “old way”: putting the target language code in the source sentence (either as a prefix or a suffix).
The logic was to ask the model to learn which language the sentence should be translated into.

It works fine.

However, the most recent paper (NLLB from Meta) does things differently.

They say they want the model to have some zero-shot capability.
Hence they put the source language token in the source sentence, so that the encoder takes the language information into account in the encoded embeddings of the sentence.

They also prefix the target sentence with the target language code, so that when using the target prefix (OpenNMT-py and CTranslate2 can do this) it will force the first token to be the language code, and the decoder will then decode the following tokens.

Hope this helps.

FYI: there are two posts on NLLB-200, the last one on finetuning, which might be a shorter route to what you are trying to accomplish.


Thanks, I’m going to look into fine-tuning. All of the multilingual models I’ve tried to train from scratch have generated total gibberish.

However, have you trained it like this? If you want the language token to be treated as a single token at inference, you should add it as a special token to SentencePiece with the option --user_defined_symbols. Otherwise, it will be split into multiple tokens.
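
For example, a quick sketch with the SentencePiece Python API (the file name, vocabulary size, and exact tokens here are just placeholders; the equivalent command-line flag is --user_defined_symbols on spm_train):

import sentencepiece as spm

# Register the language codes as single, atomic pieces so the
# subword model never splits them.
lang_tokens = ["__en__", "__fr__", "__es__", "__de__"]

spm.SentencePieceTrainer.train(
    input="corpus.txt",
    model_prefix="multilingual",
    vocab_size=32000,
    user_defined_symbols=lang_tokens,
)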

Another factor for a multilingual model is the size of the data. If it is huge, consider using TransformerBig or an even deeper and/or wider network.

All the best,
Yasmin


NLLB-200 uses “Sparsely Gated Mixture of Experts”…

Sparse Mixture-of-Experts models are a way to drastically increase model capacity without needing a proportional amount of compute. The recently released NLLB-200 is an example of such a model.


The MoE model is the huge one (54 billion parameters); it is out of scope for the OpenNMT-py project. We converted only the dense models (from 600M to 3.3B parameters). The 1.3B and 3.3B already work very well.


I’ve trained multiple multilingual models using OpenNMT-py with significant success. I use a “src_token tgt_token SENTENCE” format for the source and a tgt_token at the start of the target sentence, as in NLLB, like vince said.

I use a “medium” transformer of sorts (I’ve fooled around with some settings in some models to increase throughput), which results in better translations than one-way models (the ones I’ve trained, anyway).

I haven’t had to implement any custom logic; I just added the src and tgt tokens to the SentencePiece special vocab (each language has a src and a tgt token).
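
To make that format concrete, here’s a rough sketch of how a training pair could be built (the __xx__ token style is borrowed from earlier in the thread; the exact codes don’t matter):

def make_example(src_lang, tgt_lang, src_text, tgt_text):
    # Source line: source token + target token + sentence.
    src_line = f"__{src_lang}__ __{tgt_lang}__ {src_text}"
    # Target line: target token + sentence, so the decoder's first token
    # is the language code (which can then be forced with target_prefix).
    tgt_line = f"__{tgt_lang}__ {tgt_text}"
    return src_line, tgt_line

print(make_example("en", "fr",
                   "William Caxton was an English merchant.",
                   "William Caxton était un marchand anglais."))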


Thanks for the suggestions.

Are there any example configurations for using a larger Transformer model with OpenNMT-py? Searching for “TransformerBig” on this forum I only see OpenNMT-tf, and I don’t see anything in the documentation/source for OpenNMT-py. I’ve experimented a little with just increasing the enc_layers, dec_layers, and heads parameters in config.yml.


I haven’t experimented with bigger models personally; my multilingual models only come out to about 110M parameters (using a 1-layer decoder and a 12-layer encoder with 16 heads). Speed is a big thing for me as the models are hosted on a resource-limited server.
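
Roughly, that setup in config terms (just a sketch; the layer/head counts are the ones I mentioned, while the other values are standard base-Transformer numbers rather than my exact settings):

# layers/heads as described above; the remaining values are placeholders
enc_layers: 12
dec_layers: 1
heads: 16
hidden_size: 512
word_vec_size: 512
transformer_ff: 2048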

A basic TransformerBig would just mean increasing the hidden_size, word embedding dimension, and transformer feed-forward size (probably to 1024, 1024, and 4096 respectively). A reasonable change to layers/heads (for me anyway) would probably be 8/16 for the vanilla TransformerBig.

There’s no one standard really; the “Attention Is All You Need” paper uses a 6-layer/16-head setup with the same hidden size, embedding dimension, and transformer feed-forward size as above. Also, increasing dropout to 0.3 is probably a good idea.


If I work with OpenNMT-tf I just use the pre-config TransformerBigRelative. If I use OpenNMT-py, I copy these values to the config file.

The Transformer Base config can be found here. However, note that you have to set position_encoding: 'false' if you use max_relative_positions.

decoder_type: transformer
encoder_type: transformer
word_vec_size: 1024
rnn_size: 1024
layers: 6
transformer_ff: 4096
heads: 16
max_relative_positions: 16 # or 20
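position_encoding: false # as noted above, required when using max_relative_positions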

Thanks for the suggestion.

I tried these values but it breaks something. Do I need to also set rnn_size to 1024?

I used a config.yml similar to the documentation with these changes:

-enc_layers: 6
+enc_layers: 8
-dec_layers: 6
+dec_layers: 16
-hidden_size: 512
+hidden_size: 1024
-word_vec_size: 512
+word_vec_size: 1024
-transformer_ff: 2048
+transformer_ff: 4096
-dropout: [0.1]
+dropout: [0.3]

With OpenNMT-py v2 I get the error: RuntimeError: Given normalized_shape=[512], expected input with shape [*, 512], but got input of size[282, 13, 1024].

Traceback (most recent call last):
  File "/home/argosopentech/OpenNMT-py/onmt/trainer.py", line 426, in _gradient_accumulation                      
    outputs, attns = self.model(                                         
  File "/home/argosopentech/env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)                        
  File "/home/argosopentech/OpenNMT-py/onmt/models/model.py", line 63, in forward                                 
    enc_state, memory_bank, lengths = self.encoder(src, lengths)
  File "/home/argosopentech/env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/argosopentech/OpenNMT-py/onmt/encoders/transformer.py", line 138, in forward                        
    out = layer(out, mask)                                               
  File "/home/argosopentech/env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/argosopentech/OpenNMT-py/onmt/encoders/transformer.py", line 53, in forward                         
    input_norm = self.layer_norm(inputs)
  File "/home/argosopentech/env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl   
    return forward_call(*args, **kwargs)
  File "/home/argosopentech/env/lib/python3.10/site-packages/torch/nn/modules/normalization.py", line 190, in forward
    return F.layer_norm(                                                                                                                          
  File "/home/argosopentech/env/lib/python3.10/site-packages/torch/nn/functional.py", line 2515, in layer_norm      
    return torch.layer_norm(input, normalized_shape, weight, bias, eps, torch.backends.cudnn.enabled)
RuntimeError: Given normalized_shape=[512], expected input with shape [*, 512], but got input of size[282, 13, 1024]

If you’re using OpenNMT-py v2, hidden_size should be rnn_size; the option was renamed to hidden_size in the OpenNMT-py v3 release.

Also, just wanted to clarify from my previous post (just in case I said things weirdly :sweat_smile:) that I meant the setup below (I use OpenNMT-py v3):

heads: 16
enc_layers: 8
dec_layers: 8
hidden_size: 1024
word_vec_size: 1024
dropout: [0.3]
transformer_ff: 4096

This config works for me, but I think it causes an OOM error with 2x RTX 4090 GPUs.

Setting rnn_size: 1024 with OpenNMT-py v2 fixed the “expected input with shape [*, 512], but got input of size [282, 13, 1024]” error.

Traceback (most recent call last):
  File "/home/argosopentech/env/bin/onmt_train", line 33, in <module>
    sys.exit(load_entry_point('OpenNMT-py', 'console_scripts', 'onmt_train')())
  File "/home/argosopentech/OpenNMT-py/onmt/bin/train.py", line 172, in main
    train(opt)
  File "/home/argosopentech/OpenNMT-py/onmt/bin/train.py", line 157, in train
    train_process(opt, device_id=0)
  File "/home/argosopentech/OpenNMT-py/onmt/train_single.py", line 112, in main
    trainer.train(
  File "/home/argosopentech/OpenNMT-py/onmt/trainer.py", line 277, in train
    self._gradient_accumulation(
  File "/home/argosopentech/OpenNMT-py/onmt/trainer.py", line 506, in _gradient_accumulation
    self.optim.step()
  File "/home/argosopentech/OpenNMT-py/onmt/utils/optimizers.py", line 367, in step
    self._optimizer.step()
  File "/home/argosopentech/env/lib/python3.10/site-packages/torch/optim/optimizer.py", line 280, in wrapper
    out = func(*args, **kwargs)
  File "/home/argosopentech/env/lib/python3.10/site-packages/torch/optim/optimizer.py", line 33, in _use_grad
    ret = func(self, *args, **kwargs)
  File "/home/argosopentech/env/lib/python3.10/site-packages/torch/optim/adam.py", line 141, in step
    adam(
  File "/home/argosopentech/env/lib/python3.10/site-packages/torch/optim/adam.py", line 281, in adam
    func(params,
  File "/home/argosopentech/env/lib/python3.10/site-packages/torch/optim/adam.py", line 505, in _multi_tensor_adam
    exp_avg_sq_sqrt = torch._foreach_sqrt(device_exp_avg_sqs)
torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 196.00 MiB (GPU 0; 23.65 GiB total capacity; 14.90 GiB already allocated; 47.69 MiB free; 15.85 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

I’m going to experiment with setting export 'PYTORCH_CUDA_ALLOC_CONF=max_split_size_mb:512' and decreasing the batch size.

Edit: export 'PYTORCH_CUDA_ALLOC_CONF=max_split_size_mb:512' works

I’d suggest using model_dtype: "fp16" (the linked config.yml uses fp32); that may save enough memory and will speed up training a bit.
It’s strange that you’re getting OOM errors though. I haven’t had any issues with any models on a 3090, so it’s odd that two 4090s don’t work.


Are you really able to use 2x 4090s in parallel training and get good performance? Last I checked, p2p was broken with multiple 4090s, and this led either to crashes or to abysmal performance when using 2x 4090s. This is a known issue that I have experienced myself. NVIDIA is not very explicit about if/when they’ll fix p2p for 4090s. For more info, see here: Standard nVidia CUDA tests fail with dual RTX 4090 Linux box - #43 by abchauhan - Linux - NVIDIA Developer Forums
