Custom Transformer model (auto_config)

Hello,

I’m trying to create my own TransformerSmall. I followed the instructions from this post:
https://forum.opennmt.net/t/how-to-change-model-dimension/4072/2

But I also wanted to include the config, so I assumed I could reuse the same code as in the GitHub repo:

opennmt/models/catalog.py

So my result is this:

import opennmt

class MyTransformerSmall(opennmt.models.Transformer):
    def __init__(self):
        super().__init__(
            source_inputter=opennmt.inputters.WordEmbedder(embedding_size=512),
            target_inputter=opennmt.inputters.WordEmbedder(embedding_size=512),
            num_layers=5,
            num_units=512,
            num_heads=2,
            ffn_inner_dim=512,
            dropout=0.3,
            attention_dropout=0.3,
            ffn_dropout=0.1)
            
    def auto_config(self, num_replicas=1):
        config = super().auto_config(num_replicas=num_replicas)
        return misc.merge_dict(
            config,
            {
                "params": {
                    "decay_type": "NoamDecay",
                    "decay_params": {
                        "model_dim": 512,
                        "warmup_steps": 4000,
                    },
                    # label_smoothing is a "params" option, not a decay parameter.
                    "label_smoothing": 0.6,
                },
            },
        )

model = MyTransformerSmall

But I’m getting this error when I try to run it:

  File "/usr/local/bin/onmt-main", line 8, in <module>
    sys.exit(main())
  File "/usr/local/lib/python3.7/dist-packages/opennmt/bin/main.py", line 326, in main
    hvd=hvd,
  File "/usr/local/lib/python3.7/dist-packages/opennmt/runner.py", line 199, in train
    training=True, num_replicas=num_replicas, num_devices=num_devices
  File "/usr/local/lib/python3.7/dist-packages/opennmt/runner.py", line 109, in _finalize_config
    model_config = self._model.auto_config(num_replicas=num_replicas)
  File "/content/gdrive/MyDrive/VGR/TransformerSmall.py", line 18, in auto_config
    return misc.merge_dict(
NameError: name 'misc' is not defined

As a second question, I would like to change the decay_type to a square root decay, but I’m not sure what it is called in OpenNMT.
https://opennmt.net/OpenNMT-tf/package/opennmt.schedules.RsqrtDecay.html

Based on this link, I should just put:

"decay_type": "RsqrtDecay"
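To compare the two schedules, here is a minimal pure-Python sketch of both formulas as I understand them from the OpenNMT-tf API reference: NoamDecay is scale * model_dim^-0.5 * min(step^-0.5, step * warmup_steps^-1.5), and RsqrtDecay is scale / sqrt(max(step, warmup_steps)). The scale, model_dim and warmup values below are illustrative assumptions. OpenNMT-tf computes these schedules in TensorFlow, so this is only a way to compare the shape of the curves.

```python
import math


def noam_decay(step: int, scale: float, model_dim: int, warmup_steps: int) -> float:
    """Noam schedule: linear warmup, then decay proportional to 1/sqrt(step)."""
    return scale * model_dim ** -0.5 * min(step ** -0.5, step * warmup_steps ** -1.5)


def rsqrt_decay(step: int, scale: float, warmup_steps: int) -> float:
    """Reciprocal square root schedule: constant during warmup, then 1/sqrt(step)."""
    return scale / math.sqrt(max(step, warmup_steps))


if __name__ == "__main__":
    # Illustrative values only (assumed): scale=2.0, model_dim=512, warmup_steps=4000.
    for step in (1, 1000, 4000, 16000, 64000):
        noam = noam_decay(step, scale=2.0, model_dim=512, warmup_steps=4000)
        rsqrt = rsqrt_decay(step, scale=2.0, warmup_steps=4000)
        print(f"step {step:>6}: NoamDecay={noam:.6f}  RsqrtDecay={rsqrt:.6f}")
```

With the same scale, RsqrtDecay stays flat during warmup instead of ramping up, and after warmup it comes out sqrt(model_dim) times larger than NoamDecay (about 22.6x for model_dim=512). Assuming learning_rate is passed to the schedule as its scale, it would probably need to be lowered when switching from NoamDecay to RsqrtDecay.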

You should import this module:

from opennmt.utils import misc
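To see what misc.merge_dict does with the dictionary returned by super().auto_config(), here is a stand-in written in plain Python so it runs without OpenNMT-tf installed. It assumes merge_dict merges nested dictionaries recursively, which is how I read its docstring, and base_config is a hypothetical, trimmed-down version of what the parent Transformer config contains. The consequence for the second question: keys from the parent's NoamDecay decay_params, such as model_dim, survive the merge, and as far as I can tell from its API page RsqrtDecay does not accept model_dim.

```python
import copy


def merge_dict(dict1: dict, dict2: dict) -> dict:
    """Recursively merge dict2 into dict1 (a stand-in for opennmt.utils.misc.merge_dict)."""
    for key, value in dict2.items():
        if isinstance(value, dict) and isinstance(dict1.get(key), dict):
            merge_dict(dict1[key], value)  # descend into nested dicts
        else:
            dict1[key] = value  # scalars (and new keys) simply overwrite
    return dict1


# Hypothetical, trimmed-down version of what the parent auto_config() might return.
base_config = {
    "params": {
        "decay_type": "NoamDecay",
        "decay_params": {"model_dim": 512, "warmup_steps": 8000},
        "label_smoothing": 0.1,
    }
}

override = {
    "params": {
        "decay_type": "RsqrtDecay",
        "decay_params": {"warmup_steps": 4000},
        "label_smoothing": 0.6,
    }
}

merged = merge_dict(copy.deepcopy(base_config), override)
# "model_dim" is still present after the merge:
print(merged["params"]["decay_params"])  # {'model_dim': 512, 'warmup_steps': 4000}

# One option: replace decay_params entirely instead of merging into it.
merged["params"]["decay_params"] = {"warmup_steps": 4000}
print(merged["params"]["decay_params"])  # {'warmup_steps': 4000}
```

In a real auto_config, the same idea would mean overwriting config["params"]["decay_params"] after calling misc.merge_dict, rather than relying on the merge alone.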