Convert to Keras Model for CoreMlTools

I would like to convert an OpenNMT-tf model to a CoreML model. Below is an example for Hugging Face Transformers. I am trying to find a good way to wrap the Transformer model in a Keras Model with the correct inputs. It would be great if we could wrap it as a Keras Model and run it on-device. Can anyone help with this? Thanks

[huggingface transformers to coreml]

import tensorflow as tf
from transformers import DistilBertTokenizer, TFDistilBertForMaskedLM
import coremltools as ct

tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-cased')
distilbert_model = TFDistilBertForMaskedLM.from_pretrained('distilbert-base-cased')

max_seq_length = 10
input_shape = (1, max_seq_length)  # (batch_size, maximum_sequence_length)
input_layer = tf.keras.layers.Input(shape=input_shape[1:], dtype=tf.int32, name='input')
prediction_model = distilbert_model(input_layer)
tf_model = tf.keras.models.Model(inputs=input_layer, outputs=prediction_model)
mlmodel = ct.convert(tf_model)

[opennmt-tf to coreml]

import tensorflow as tf

class WrapTransformerModel(tf.keras.models.Model):

    def __init__(self, transformer):
        super(WrapTransformerModel, self).__init__()
        self.transformer = transformer

    def call(self, features):
        res = self.transformer.call(features, training=False)
        return res

# https://opennmt-models.s3.amazonaws.com/averaged-ende-export500k-v2-beam1.tar.gz
model = tf.keras.models.load_model('./averaged-ende-export500k-v2-beam1')
m = WrapTransformerModel(model)
token_input_layer = tf.keras.layers.Input(shape=(10,), dtype=tf.string, name='tokens')
length_input_layer = tf.keras.layers.Input(shape=(1,), dtype=tf.int32, name='length')
outputs = m({"tokens": token_input_layer, "length": length_input_layer})
tf_model = tf.keras.models.Model(inputs=[token_input_layer, length_input_layer], outputs=outputs)

Hi,

Looks like I can create a Keras Model this way:

import tensorflow as tf

class Translator(tf.keras.layers.Layer):
    def __init__(self, translate_fn):
        super().__init__()
        self._translate_fn = translate_fn

    def call(self, tokens, length):
        return self._translate_fn(tokens=tokens, length=length)

saved_model_path = "/home/klein/dev/OpenNMT-tf/models/averaged-ende-export500k-v2-beam1"
saved_model = tf.saved_model.load(saved_model_path)
translator = Translator(saved_model.signatures["serving_default"])

tokens = tf.keras.Input(shape=(4,), dtype=tf.string, name="tokens")
length = tf.keras.Input(shape=(), dtype=tf.int32, name="length")
outputs = translator(tokens, length)

model = tf.keras.models.Model(inputs=[tokens, length], outputs=outputs)

test_tokens = tf.constant([["▁H", "ello", "▁world", "!"]])
test_length = tf.constant([4], dtype=tf.int32)
print(model.predict_on_batch([test_tokens, test_length]))

Can you try?

However, I’m not sure CoreML supports all operators commonly used in seq2seq models such as vocabulary lookup, decoding loop, variable shapes, etc.
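For completeness, the conversion call itself would be along these lines (a minimal sketch; it will most likely stop on one of the unsupported operators mentioned above):

import coremltools as ct

# Attempt to convert the wrapped Keras model; expect failures on the string
# vocabulary lookup, the decoding while-loop, etc.
mlmodel = ct.convert(model)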

Thanks a lot, this is very helpful. Starting with the lookup table: it is not well supported in TF2 and fails with the following error.

tensorflow.python.framework.errors_impl.InvalidArgumentError: Cannot convert a Tensor of dtype resource to a NumPy array.

But someone posted a solution to create a custom StaticHashTable: https://stackoverflow.com/questions/59962509/valueerror-cannot-convert-a-tensor-of-dtype-resource-to-a-numpy-array

I am not sure how we can pull this lookup out of the model and use the converted ids as the input.
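For reference, a minimal sketch of doing the token-to-id lookup outside the model with plain TensorFlow, assuming a one-token-per-line vocabulary file (the file name wmtende.vocab and the OOV bucket count are assumptions):

import tensorflow as tf

# Build a string-to-id table from a plain text vocabulary (one token per line).
initializer = tf.lookup.TextFileInitializer(
    "wmtende.vocab",
    key_dtype=tf.string, key_index=tf.lookup.TextFileIndex.WHOLE_LINE,
    value_dtype=tf.int64, value_index=tf.lookup.TextFileIndex.LINE_NUMBER)
tokens_to_ids = tf.lookup.StaticVocabularyTable(initializer, num_oov_buckets=1)

tokens = tf.constant([["▁H", "ello", "▁world", "!"]])
ids = tokens_to_ids.lookup(tokens)  # int64 ids that could be fed to the model instead of strings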

Another question is about the model. When I tried to understand where tokens become ids in the graph, I lost track at this line (sequence_to_sequence.py).

Is source_inputs already of int type? I only see that WordEmbedder has tokens_to_ids in the function make_features. Does this mean the Transformer model is using WordEmbedder as the source_inputter instead of TextInputter?

Here source_inputs is already the word embeddings. The string lookup is done in make_features, which is either called as a dataset transformation (during training and evaluation) or in the serving function.
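Conceptually (a simplified sketch only, not the actual OpenNMT-tf implementation), make_features adds the ids to the feature dictionary via the vocabulary table, and WordEmbedder.call then embeds features["ids"]:

# Conceptual sketch of what make_features roughly does; not the real code.
def make_features_sketch(tokens, length, tokens_to_ids):
    return {
        "tokens": tokens,
        "length": length,
        "ids": tokens_to_ids.lookup(tokens),  # string lookup happens here, outside call()
    }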

Correct, the base Transformer model is using WordEmbedder.


I see that you started to explore the code so maybe you’ll find a way to change the inputs/outputs. For reference, here is how you can call the conversion for the pretrained checkpoint (CoreML conveniently accepts a TensorFlow function so you don’t need to build a Keras model):

import os
import tensorflow as tf
import opennmt
import coremltools

checkpoint_dir = "/home/klein/dev/OpenNMT-tf/models/averaged-ende-ckpt500k-v2"
vocabulary = os.path.join(checkpoint_dir, "wmtende.vocab")

model = opennmt.models.TransformerBase()
model.initialize({
    "source_vocabulary": vocabulary,
    "target_vocabulary": vocabulary,
})

checkpoint = tf.train.Checkpoint(model=model)
checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))

serve_function = model.serve_function().get_concrete_function()
mlmodel = coremltools.convert([serve_function], source="tensorflow")
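If the conversion were to succeed, the result could then be saved for Xcode with something like this (the file name is arbitrary):

mlmodel.save("ende_transformer.mlmodel")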

Thanks a lot. To skip tokens_to_ids, which we can do outside the model, do you suggest writing a new tf.function? The problem here is that the call function already expects features with tokens and length.

Do we need to create another similar function in sequence_to_sequence.py that expects features with ids?

A minimal diff would look like this:

diff --git a/opennmt/inputters/text_inputter.py b/opennmt/inputters/text_inputter.py
index 737e738..d7ef1b8 100644
--- a/opennmt/inputters/text_inputter.py
+++ b/opennmt/inputters/text_inputter.py
@@ -252,7 +252,7 @@ class TextInputter(Inputter):
   def initialize(self, data_config, asset_prefix=""):
     self.vocabulary_file = _get_field(
         data_config, "vocabulary", prefix=asset_prefix, required=True)
-    self.vocabulary_size, self.tokens_to_ids, self.ids_to_tokens = _create_vocabulary_tables(
+    self.vocabulary_size, _, _ = _create_vocabulary_tables(
         self.vocabulary_file,
         self.num_oov_buckets,
         as_asset=data_config.get("export_vocabulary_assets", True))
@@ -415,6 +415,11 @@ class WordEmbedder(TextInputter):
       features["length"] -= 1
     return features
 
+  def input_signature(self):
+    signature = super().input_signature()
+    signature["ids"] = tf.TensorSpec([None, None], dtype=tf.int64)
+    return signature
+
   def build(self, input_shape):
     if self.embedding_file:
       pretrained = load_pretrained_embeddings(
diff --git a/opennmt/models/sequence_to_sequence.py b/opennmt/models/sequence_to_sequence.py
index be188b8..ebf71a2 100644
--- a/opennmt/models/sequence_to_sequence.py
+++ b/opennmt/models/sequence_to_sequence.py
@@ -247,7 +247,7 @@ class SequenceToSequence(model.SequenceGenerator):
         sampler=decoding.Sampler.from_params(params),
         maximum_iterations=params.get("maximum_decoding_length", 250),
         minimum_iterations=params.get("minimum_decoding_length", 0))
-    target_tokens = self.labels_inputter.ids_to_tokens.lookup(tf.cast(sampled_ids, tf.int64))
+    target_tokens = tf.cast(sampled_ids, tf.int64)
 
     # Maybe replace unknown targets by the source tokens with the highest attention weight.
     if params.get("replace_unknown_target", False):

which adds the ids as an input of the model. But then there is another error on conversion.

If you mean the error below, it might be related to the following TF issue.

 while name_to_node[node_name].op == "Identity":
KeyError: 'self_attention_decoder_self_attention_decoder_layer_transformer_layer_wrapper_13_multi_head_attention_7_cond_input_1_0'

I also tried greedy search by setting the beam width to 1, and it showed a similar issue.

model.initialize({
    "source_vocabulary": vocabulary,
    "target_vocabulary": vocabulary,
    "beam_width": 1
})

Actually greedy decoding is the default when no parameters are defined.
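As a side note, beam_width is a decoding parameter rather than a data configuration entry, so (assuming initialize accepts a separate params argument in your OpenNMT-tf version) it would be passed more like this:

model.initialize(
    {
        "source_vocabulary": vocabulary,
        "target_vocabulary": vocabulary,
    },
    params={"beam_width": 1})  # decoding parameter, separate from the data config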

So it’s not clear to me whether we should change something on our side or wait for TF & CoreML to improve control flow support. It seems the conversion is going further with TensorFlow 2.3.0 RC0.

Control flow support seems to be resolved already, and here is a document about Switch and Merge.

Features: https://www.tensorflow.org/api_docs/python/tf/compat/v1/enable_control_flow_v2

But when I added tf.compat.v1.enable_control_flow_v2(), it still failed with the error:

while name_to_node[node_name].op == "Identity":
KeyError: 'self_attention_decoder_self_attention_decoder_layer_transformer_layer_wrapper_13_multi_head_attention_7_cond_input_1_0'

Control flow V2 is enabled by default in TensorFlow 2.

I went further to see what could be done. The latest error is now in the position encoding layer: the converter struggles to resolve the shape and thinks there is a mismatch.

Here is the diff so far:

diff --git a/opennmt/inputters/text_inputter.py b/opennmt/inputters/text_inputter.py
index 737e738..b9c3abd 100644
--- a/opennmt/inputters/text_inputter.py
+++ b/opennmt/inputters/text_inputter.py
@@ -252,7 +252,7 @@ class TextInputter(Inputter):
   def initialize(self, data_config, asset_prefix=""):
     self.vocabulary_file = _get_field(
         data_config, "vocabulary", prefix=asset_prefix, required=True)
-    self.vocabulary_size, self.tokens_to_ids, self.ids_to_tokens = _create_vocabulary_tables(
+    self.vocabulary_size, _, _ = _create_vocabulary_tables(
         self.vocabulary_file,
         self.num_oov_buckets,
         as_asset=data_config.get("export_vocabulary_assets", True))
@@ -415,6 +415,11 @@ class WordEmbedder(TextInputter):
       features["length"] -= 1
     return features
 
+  def input_signature(self):
+    signature = super().input_signature()
+    signature["ids"] = tf.TensorSpec([None, None], dtype=tf.int64)
+    return signature
+
   def build(self, input_shape):
     if self.embedding_file:
       pretrained = load_pretrained_embeddings(
@@ -435,7 +440,8 @@ class WordEmbedder(TextInputter):
     super(WordEmbedder, self).build(input_shape)
 
   def call(self, features, training=None):
-    outputs = tf.nn.embedding_lookup(self.embedding, features["ids"])
+    # The converter complains about the type of the indices, so help it.
+    outputs = tf.nn.embedding_lookup(self.embedding, tf.cast(features["ids"], tf.int64))
     outputs = common.dropout(outputs, self.dropout, training=training)
     return outputs
 
diff --git a/opennmt/layers/common.py b/opennmt/layers/common.py
index bffeae7..89df094 100644
--- a/opennmt/layers/common.py
+++ b/opennmt/layers/common.py
@@ -80,6 +80,10 @@ class Dense(tf.keras.layers.Dense):
 class LayerNorm(tf.keras.layers.LayerNormalization):
   """Layer normalization."""
 
+  def __init__(self, *args, **kwargs):
+    # Reduce epsilon to disable fused BatchNorm that raised a conversion error.
+    super().__init__(*args, epsilon=1e-6, **kwargs)
+
   def map_v1_weights(self, weights):
     return [
         (self.beta, weights["beta"]),
diff --git a/opennmt/layers/position.py b/opennmt/layers/position.py
index 8b3c83a..83344f3 100644
--- a/opennmt/layers/position.py
+++ b/opennmt/layers/position.py
@@ -35,12 +36,16 @@ class PositionEncoder(tf.keras.layers.Layer):
     Returns:
       A ``tf.Tensor`` whose shape depends on the configured ``reducer``.
     """
-    batch_size = tf.shape(inputs)[0]
-    timesteps = tf.shape(inputs)[1]
+    input_shape = tf.shape(inputs)
+    batch_size = input_shape[0]
+    timesteps = input_shape[1]
     input_dim = inputs.shape[-1]
     positions = tf.range(timesteps) + 1 if position is None else [position]
     position_encoding = self._encode([positions], input_dim)
     position_encoding = tf.tile(position_encoding, [batch_size, 1, 1])
+
+    # The converter thinks the shapes mismatch between position_encoding and inputs.
+    position_encoding = tf.reshape(position_encoding, input_shape)
     return self.reducer([inputs, position_encoding])
 
   @abc.abstractmethod
@@ -99,7 +104,8 @@ class SinusoidalPositionEncoder(PositionEncoder):
 
     log_timescale_increment = math.log(10000) / (depth / 2 - 1)
     inv_timescales = tf.exp(tf.range(depth / 2, dtype=tf.float32) * -log_timescale_increment)
-    inv_timescales = tf.reshape(tf.tile(inv_timescales, [batch_size]), [batch_size, -1])
+    # Keep static depth shape.
+    inv_timescales = tf.reshape(tf.tile(inv_timescales, [batch_size]), [batch_size, depth // 2])
     scaled_time = tf.expand_dims(positions, -1) * tf.expand_dims(inv_timescales, 1)
     encoding = tf.concat([tf.sin(scaled_time), tf.cos(scaled_time)], axis=2)
     return tf.cast(encoding, self.dtype)
diff --git a/opennmt/layers/transformer.py b/opennmt/layers/transformer.py
index d20bfb4..63d6953 100644
--- a/opennmt/layers/transformer.py
+++ b/opennmt/layers/transformer.py
@@ -256,13 +256,8 @@ class MultiHeadAttention(tf.keras.layers.Layer):
         keys = tf.concat([cache[0], keys], axis=2)
         values = tf.concat([cache[1], values], axis=2)
     else:
-      if cache:
-        keys, values = tf.cond(
-            tf.equal(tf.shape(cache[0])[2], 0),
-            true_fn=lambda: _compute_kv(memory),
-            false_fn=lambda: cache)
-      else:
-        keys, values = _compute_kv(memory)
+      # Ignore cache and always recompute projections to avoid using problematic tf.cond
+      keys, values = _compute_kv(memory)
 
     if self.maximum_relative_position is not None:
       if memory is not None:
diff --git a/opennmt/models/sequence_to_sequence.py b/opennmt/models/sequence_to_sequence.py
index be188b8..ebf71a2 100644
--- a/opennmt/models/sequence_to_sequence.py
+++ b/opennmt/models/sequence_to_sequence.py
@@ -247,7 +247,7 @@ class SequenceToSequence(model.SequenceGenerator):
         sampler=decoding.Sampler.from_params(params),
         maximum_iterations=params.get("maximum_decoding_length", 250),
         minimum_iterations=params.get("minimum_decoding_length", 0))
-    target_tokens = self.labels_inputter.ids_to_tokens.lookup(tf.cast(sampled_ids, tf.int64))
+    target_tokens = tf.cast(sampled_ids, tf.int64)
 
     # Maybe replace unknown targets by the source tokens with the highest attention weight.
     if params.get("replace_unknown_target", False):
diff --git a/opennmt/utils/decoding.py b/opennmt/utils/decoding.py
index 80bfb30..feaf8f8 100644
--- a/opennmt/utils/decoding.py
+++ b/opennmt/utils/decoding.py
@@ -184,7 +184,8 @@ class GreedySearch(DecodingStrategy):
   """A basic greedy search strategy."""
 
   def _initialize(self, batch_size, start_ids, attention_size=None):
-    finished = tf.zeros([batch_size], dtype=tf.bool)
+    # tf.zeros(..., dtype=tf.bool) seems to be unsupported.
+    finished = tf.zeros([batch_size], dtype=tf.int32) == 0
     initial_log_probs = tf.zeros([batch_size], dtype=tf.float32)
     return start_ids, finished, initial_log_probs, []

Thanks a lot for the help. Can I ask what the current error is now? I still see the following error with the above change. Thanks

TypeError: Input value has type <class 'coremltools.converters.mil.mil.types.type_bool.bool'> not compatible with expected type IntOrFloatInputType

New errors keep coming, so it seems to me the converter can't handle graphs with slightly complex control flow.

I think it could succeed with a “lite” version of the decoding loop that uses fewer features.

Do you suggest we replace the decoding step with a simpler one? For example, use greedy search and remove the unnecessary if/else conditions there.

Yes, basically a decoding loop that ignores the scores, the attention, EOS penalty, etc.
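For reference, such a stripped-down greedy loop might look roughly like this (a conceptual sketch only, not the OpenNMT-tf decoding code; symbol_to_logits_fn, start_ids, and max_length are placeholders):

import tensorflow as tf

def greedy_decode_sketch(symbol_to_logits_fn, start_ids, max_length):
    # start_ids: int64 tensor of shape [batch_size] with the start token id.
    ids = tf.expand_dims(start_ids, 1)  # [batch_size, 1]

    def condition(step, ids):
        # Only stop on maximum length: no finished mask, scores, attention
        # history, or EOS penalty.
        return step < max_length

    def body(step, ids):
        logits = symbol_to_logits_fn(ids)      # [batch_size, vocab_size]
        next_ids = tf.argmax(logits, axis=-1)  # greedy pick, int64
        ids = tf.concat([ids, tf.expand_dims(next_ids, 1)], axis=1)
        return step + 1, ids

    _, ids = tf.while_loop(
        condition,
        body,
        loop_vars=[tf.constant(0), ids],
        shape_invariants=[tf.TensorShape([]), tf.TensorShape([None, None])])
    return ids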

Hi Guoli, did you find a solution to this? I am just starting on the journey to convert a SavedModel to a TensorFlow Lite model and came up against this same inability to convert a tensor of dtype resource to a NumPy array. Is it a worthwhile effort as a means of doing inference on an Android device? Thanks, Terence

It is now possible to export selected models to TensorFlow Lite directly from OpenNMT-tf with the tflite export format. However, it only supports RNN-based sequence to sequence models at the moment.
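For reference, the generic route Terence describes (converting an exported SavedModel with TensorFlow’s own converter, rather than the dedicated OpenNMT-tf exporter) would look roughly like the sketch below; it typically hits the same unsupported ops, such as the vocabulary hash tables:

import tensorflow as tf

# Generic TFLite conversion of an exported SavedModel; not the OpenNMT-tf
# tflite export format mentioned above.
converter = tf.lite.TFLiteConverter.from_saved_model("averaged-ende-export500k-v2-beam1")
converter.target_spec.supported_ops = [
    tf.lite.OpsSet.TFLITE_BUILTINS,
    tf.lite.OpsSet.SELECT_TF_OPS,  # allow falling back to full TensorFlow ops
]
tflite_model = converter.convert()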
