Sharing word embeddings, the wrong way—curious about what happened

Hi,

I defined a multi-source Transformer model, edited from the MultiSourceTransformer example, and tried to share the word embedding weights across all sources and the target by simply reusing one WordEmbedder instance:

```python
import opennmt as onmt


class MultiSourceTransformer(onmt.models.Transformer):
    def __init__(self):
        # Don't use this, it doesn't work
        embedder = onmt.inputters.WordEmbedder(embedding_size=512)
        super().__init__(
            source_inputter=onmt.inputters.ParallelInputter(
                [embedder] * 3  # three references to the same instance
            ),
            target_inputter=embedder,  # ...and a fourth one for the target
            num_layers=6,
            num_units=768,
            num_heads=8,
            ffn_inner_dim=1024,
            dropout=0.1,
            attention_dropout=0.1,
            ffn_dropout=0.1,
            share_encoders=True,
        )


model = MultiSourceTransformer
```

Well, after training a model, it is clear that something went wrong: training itself seemed to run normally, but after loading, the model produces garbage.

I have since found another post that describes what is presumably a working way of sharing embeddings:

I don’t understand enough about the internals of OpenNMT to figure out what actually happened with my broken approach, and it would be nice to know :slight_smile:

Hi,

The source and target inputters are configured differently by the model, so it is not possible to reuse the same WordEmbedder instance for both.
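The general mechanism can be illustrated in plain Python. This is a toy sketch, not OpenNMT-tf code: `ToyEmbedder`, `configure`, `role` and `decoder_mode` are hypothetical names, and the exact settings OpenNMT-tf applies to each input are not reproduced here. The point is that `[embedder] * 3` builds a list of three references to one object, so any per-input configuration the model writes lands on that single instance and the last write wins.

```python
class ToyEmbedder:
    """Hypothetical stand-in for an inputter that the model configures in place."""

    def __init__(self, embedding_size):
        self.embedding_size = embedding_size
        self.role = None
        self.decoder_mode = False

    def configure(self, role, decoder_mode=False):
        # Configuration is stored on the instance itself.
        self.role = role
        self.decoder_mode = decoder_mode


embedder = ToyEmbedder(embedding_size=512)
sources = [embedder] * 3  # a list of three references to ONE object
target = embedder

# The "model" configures each input according to its role.
for i, inputter in enumerate(sources, start=1):
    inputter.configure(role=f"source_{i}")
target.configure(role="target", decoder_mode=True)

# All three "sources" now report the target's configuration.
for inputter in sources:
    print(inputter.role, inputter.decoder_mode)  # prints "target True" three times
```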

However, your model makes sense from a design point of view, so we should probably look into supporting this type of definition. Thanks!


The following configuration should currently work to share all input embeddings:

```python
import opennmt as onmt


class MultiSourceTransformer(onmt.models.Transformer):
    def __init__(self):
        super().__init__(
            # One WordEmbedder instance per input; the weights are shared
            # through share_embeddings below, not by reusing an instance.
            source_inputter=onmt.inputters.ParallelInputter(
                [
                    onmt.inputters.WordEmbedder(embedding_size=512),
                    onmt.inputters.WordEmbedder(embedding_size=512),
                    onmt.inputters.WordEmbedder(embedding_size=512),
                ]
            ),
            target_inputter=onmt.inputters.WordEmbedder(embedding_size=512),
            num_layers=6,
            num_units=768,
            num_heads=8,
            ffn_inner_dim=1024,
            dropout=0.1,
            attention_dropout=0.1,
            ffn_dropout=0.1,
            share_encoders=True,
            share_embeddings=onmt.models.EmbeddingsSharingLevel.SOURCE_TARGET_INPUT,
        )
```
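The idea behind this configuration can be sketched with toy classes: each input keeps its own inputter object, so it can be configured independently, while all of them point at one shared weight table. This is an assumption-labelled illustration, not how OpenNMT-tf implements `share_embeddings`; `SharedEmbeddingTable`, `ToyEmbedder`, `role` and `decoder_mode` are hypothetical names, and nested lists stand in for a real trainable matrix.

```python
class SharedEmbeddingTable:
    """Hypothetical stand-in for one embedding weight matrix."""

    def __init__(self, vocab_size, embedding_size):
        # Nested lists stand in for a trainable [vocab_size, embedding_size] matrix.
        self.weights = [[0.0] * embedding_size for _ in range(vocab_size)]


class ToyEmbedder:
    """Hypothetical inputter: its own settings, a shared weight table."""

    def __init__(self, role, table, decoder_mode=False):
        self.role = role  # per-input configuration lives on this instance
        self.decoder_mode = decoder_mode
        self.table = table  # weights are shared by reference


table = SharedEmbeddingTable(vocab_size=8, embedding_size=4)
sources = [ToyEmbedder(f"source_{i}", table) for i in range(1, 4)]
target = ToyEmbedder("target", table, decoder_mode=True)
inputters = sources + [target]

print(len({id(x) for x in inputters}))  # 4 distinct inputter objects
print([x.role for x in inputters])  # ['source_1', 'source_2', 'source_3', 'target']

# An update to the shared table is visible from every inputter.
table.weights[0][0] = 1.0
print(all(x.table.weights[0][0] == 1.0 for x in inputters))  # True
```

Configuration is per object and weights are per table, so setting up the target can no longer overwrite anything the sources rely on.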

Yeah, I got something like that to work. Thanks for the explanation! Even without actual support, it would be (mildly) useful to detect this kind of definition and fail with an error, so that people who assume it works and misconfigure it like I did don't get surprised later. Of course, it would be even better if it just worked. (And I do appreciate that Python offers essentially infinitely many ways to misuse it, and that it's impossible to anticipate every kind of definition people come up with :slight_smile:)
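A check like the one suggested could be as simple as comparing object identities before building the model. The sketch below is hypothetical: `check_distinct_inputters` is not an OpenNMT-tf function, and where such a guard would live in the library is left open. It compares identity rather than equality, because identical settings are fine and only reusing the same instance is the problem.

```python
def check_distinct_inputters(source_inputters, target_inputter):
    """Hypothetical guard: raise if any inputter instance is used twice."""
    seen = set()
    for inputter in list(source_inputters) + [target_inputter]:
        # id() is safe here because every object stays alive in the list.
        if id(inputter) in seen:
            raise ValueError(
                "The same inputter instance is used more than once; create one "
                "instance per input and share the weights via share_embeddings."
            )
        seen.add(id(inputter))


class DummyInputter:
    """Hypothetical placeholder used only to exercise the check."""


shared = DummyInputter()
try:
    check_distinct_inputters([shared] * 3, shared)
except ValueError as error:
    print("rejected:", error)

# Distinct instances pass silently.
check_distinct_inputters([DummyInputter() for _ in range(3)], DummyInputter())
```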
