Control flow V2 is enabled by default in TensorFlow 2.
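For what it's worth, you can confirm that at runtime with the stock API:

```python
import tensorflow as tf

# True by default in TF2; tf.compat.v1.disable_control_flow_v2() would flip it.
print(tf.compat.v1.control_flow_v2_enabled())
```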
I went further to see what could be done, but the latest error is in the position encoding layer: the converter struggles to resolve the shapes and reports a mismatch.
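To make the failure mode concrete, here is a minimal standalone sketch of what I believe the converter sees in that layer (the depth of 512 and the function name are made up, and the zeros stand in for the actual encoding):

```python
import tensorflow as tf

@tf.function(input_signature=[tf.TensorSpec([None, None, 512], tf.float32)])
def add_position_encoding(inputs):
    shape = tf.shape(inputs)
    # Stand-in for the computed encoding of shape [1, timesteps, depth].
    encoding = tf.zeros([1, shape[1], 512])
    # Tiling by a dynamic batch size loses static shape information, which is
    # presumably what trips the converter's shape checks.
    encoding = tf.tile(encoding, [shape[0], 1, 1])
    # The workaround from the diff below: pin the shape back explicitly.
    encoding = tf.reshape(encoding, shape)
    return inputs + encoding

# add_position_encoding(tf.ones([2, 3, 512]))  # runs fine eagerly either way
```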
Here is the diff so far; a couple of standalone notes on the trickier workarounds follow it:
diff --git a/opennmt/inputters/text_inputter.py b/opennmt/inputters/text_inputter.py
index 737e738..b9c3abd 100644
--- a/opennmt/inputters/text_inputter.py
+++ b/opennmt/inputters/text_inputter.py
@@ -252,7 +252,7 @@ class TextInputter(Inputter):
def initialize(self, data_config, asset_prefix=""):
self.vocabulary_file = _get_field(
data_config, "vocabulary", prefix=asset_prefix, required=True)
- self.vocabulary_size, self.tokens_to_ids, self.ids_to_tokens = _create_vocabulary_tables(
+ self.vocabulary_size, _, _ = _create_vocabulary_tables(
self.vocabulary_file,
self.num_oov_buckets,
as_asset=data_config.get("export_vocabulary_assets", True))
@@ -415,6 +415,11 @@ class WordEmbedder(TextInputter):
features["length"] -= 1
return features
+ def input_signature(self):
+ signature = super().input_signature()
+ signature["ids"] = tf.TensorSpec([None, None], dtype=tf.int64)
+ return signature
+
def build(self, input_shape):
if self.embedding_file:
pretrained = load_pretrained_embeddings(
@@ -435,7 +440,8 @@ class WordEmbedder(TextInputter):
super(WordEmbedder, self).build(input_shape)
def call(self, features, training=None):
- outputs = tf.nn.embedding_lookup(self.embedding, features["ids"])
+ # The converter complains about the type of the indices, so help it.
+ outputs = tf.nn.embedding_lookup(self.embedding, tf.cast(features["ids"], tf.int64))
outputs = common.dropout(outputs, self.dropout, training=training)
return outputs
diff --git a/opennmt/layers/common.py b/opennmt/layers/common.py
index bffeae7..89df094 100644
--- a/opennmt/layers/common.py
+++ b/opennmt/layers/common.py
@@ -80,6 +80,10 @@ class Dense(tf.keras.layers.Dense):
class LayerNorm(tf.keras.layers.LayerNormalization):
"""Layer normalization."""
+ def __init__(self, *args, **kwargs):
+ # Reduce epsilon to disable the fused BatchNorm path, which raised a conversion error.
+ super().__init__(*args, epsilon=1e-6, **kwargs)
+
def map_v1_weights(self, weights):
return [
(self.beta, weights["beta"]),
diff --git a/opennmt/layers/position.py b/opennmt/layers/position.py
index 8b3c83a..83344f3 100644
--- a/opennmt/layers/position.py
+++ b/opennmt/layers/position.py
@@ -35,12 +36,16 @@ class PositionEncoder(tf.keras.layers.Layer):
Returns:
A ``tf.Tensor`` whose shape depends on the configured ``reducer``.
"""
- batch_size = tf.shape(inputs)[0]
- timesteps = tf.shape(inputs)[1]
+ input_shape = tf.shape(inputs)
+ batch_size = input_shape[0]
+ timesteps = input_shape[1]
input_dim = inputs.shape[-1]
positions = tf.range(timesteps) + 1 if position is None else [position]
position_encoding = self._encode([positions], input_dim)
position_encoding = tf.tile(position_encoding, [batch_size, 1, 1])
+
+ # The converter thinks the shapes of position_encoding and inputs mismatch, so pin them explicitly.
+ position_encoding = tf.reshape(position_encoding, input_shape)
return self.reducer([inputs, position_encoding])
@abc.abstractmethod
@@ -99,7 +104,8 @@ class SinusoidalPositionEncoder(PositionEncoder):
log_timescale_increment = math.log(10000) / (depth / 2 - 1)
inv_timescales = tf.exp(tf.range(depth / 2, dtype=tf.float32) * -log_timescale_increment)
- inv_timescales = tf.reshape(tf.tile(inv_timescales, [batch_size]), [batch_size, -1])
+ # Keep static depth shape.
+ inv_timescales = tf.reshape(tf.tile(inv_timescales, [batch_size]), [batch_size, depth // 2])
scaled_time = tf.expand_dims(positions, -1) * tf.expand_dims(inv_timescales, 1)
encoding = tf.concat([tf.sin(scaled_time), tf.cos(scaled_time)], axis=2)
return tf.cast(encoding, self.dtype)
diff --git a/opennmt/layers/transformer.py b/opennmt/layers/transformer.py
index d20bfb4..63d6953 100644
--- a/opennmt/layers/transformer.py
+++ b/opennmt/layers/transformer.py
@@ -256,13 +256,8 @@ class MultiHeadAttention(tf.keras.layers.Layer):
keys = tf.concat([cache[0], keys], axis=2)
values = tf.concat([cache[1], values], axis=2)
else:
- if cache:
- keys, values = tf.cond(
- tf.equal(tf.shape(cache[0])[2], 0),
- true_fn=lambda: _compute_kv(memory),
- false_fn=lambda: cache)
- else:
- keys, values = _compute_kv(memory)
+ # Ignore the cache and always recompute the projections to avoid the problematic tf.cond (slower decoding, but it converts).
+ keys, values = _compute_kv(memory)
if self.maximum_relative_position is not None:
if memory is not None:
diff --git a/opennmt/models/sequence_to_sequence.py b/opennmt/models/sequence_to_sequence.py
index be188b8..ebf71a2 100644
--- a/opennmt/models/sequence_to_sequence.py
+++ b/opennmt/models/sequence_to_sequence.py
@@ -247,7 +247,7 @@ class SequenceToSequence(model.SequenceGenerator):
sampler=decoding.Sampler.from_params(params),
maximum_iterations=params.get("maximum_decoding_length", 250),
minimum_iterations=params.get("minimum_decoding_length", 0))
- target_tokens = self.labels_inputter.ids_to_tokens.lookup(tf.cast(sampled_ids, tf.int64))
+ target_tokens = tf.cast(sampled_ids, tf.int64)  # return ids; the ids-to-tokens mapping now happens outside the graph
# Maybe replace unknown targets by the source tokens with the highest attention weight.
if params.get("replace_unknown_target", False):
diff --git a/opennmt/utils/decoding.py b/opennmt/utils/decoding.py
index 80bfb30..feaf8f8 100644
--- a/opennmt/utils/decoding.py
+++ b/opennmt/utils/decoding.py
@@ -184,7 +184,8 @@ class GreedySearch(DecodingStrategy):
"""A basic greedy search strategy."""
def _initialize(self, batch_size, start_ids, attention_size=None):
- finished = tf.zeros([batch_size], dtype=tf.bool)
+ # tf.zeros(..., dtype=tf.bool) seems to be unsupported.
+ finished = tf.zeros([batch_size], dtype=tf.int32) != 0  # all False at step 0, like the original
initial_log_probs = tf.zeros([batch_size], dtype=tf.float32)
return start_ids, finished, initial_log_probs, []
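One note on the last hunk: the comparison has to yield all False (nothing is finished at step 0), hence the `!= 0`. In isolation:

```python
import tensorflow as tf

batch_size = 4  # illustration only

# Rejected by the converter in my tests:
# finished = tf.zeros([batch_size], dtype=tf.bool)

# Workaround: build the all-False tensor from an integer fill.
finished = tf.zeros([batch_size], dtype=tf.int32) != 0

# Possible alternative, assuming Fill with a bool constant converts cleanly:
finished_alt = tf.fill([batch_size], False)
```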
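Also, since the diff drops the `ids_to_tokens` lookup table, the exported model now returns ids instead of tokens, so the mapping back to tokens has to happen on the client side. A minimal sketch of that step (the vocabulary path and the `<unk>` fallback are my assumptions, not OpenNMT-tf API):

```python
def ids_to_tokens(ids, vocabulary_path):
    # Load the same vocabulary file the target inputter was configured with,
    # one token per line; ids beyond the file fall into OOV buckets.
    with open(vocabulary_path, encoding="utf-8") as vocab_file:
        vocab = [line.rstrip("\r\n") for line in vocab_file]
    return [vocab[i] if i < len(vocab) else "<unk>" for i in ids]

# Example: tokens = ids_to_tokens(outputs["ids"][0], "target_vocab.txt")
```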