I have implemented a dual-source Transformer model for a translation task, and to improve its performance I plan to incorporate POS tag data for both source input sentences. As I understand it, this can be done by adding the POS data as features for both sources. Since I am new to this field, I tried to find a similar implementation to study first, but couldn't find any. It would be really helpful if anyone could suggest a reference for such an implementation, or any other suggestions on how to implement this. Thank you.
You need to add a level of input nesting to your model and configuration.
Here's how the custom model definition could look. It defines 2 input sources, each with 2 input features (tokens and POS tags). The embeddings of the features are concatenated:
```python
import opennmt as onmt
from opennmt.utils import misc


class DualSourceTransformer(onmt.models.Transformer):
    def __init__(self):
        super().__init__(
            source_inputter=onmt.inputters.ParallelInputter(
                [
                    # Source 1: token embeddings (480) + POS tag embeddings (32),
                    # concatenated on the depth axis to 512 = num_units.
                    onmt.inputters.ParallelInputter(
                        [
                            onmt.inputters.WordEmbedder(embedding_size=480),
                            onmt.inputters.WordEmbedder(embedding_size=32),
                        ],
                        reducer=onmt.layers.ConcatReducer(),
                    ),
                    # Source 2: same structure as source 1.
                    onmt.inputters.ParallelInputter(
                        [
                            onmt.inputters.WordEmbedder(embedding_size=480),
                            onmt.inputters.WordEmbedder(embedding_size=32),
                        ],
                        reducer=onmt.layers.ConcatReducer(),
                    ),
                ]
            ),
            target_inputter=onmt.inputters.WordEmbedder(embedding_size=512),
            num_layers=6,
            num_units=512,
            num_heads=8,
            ffn_inner_dim=2048,
            dropout=0.1,
            attention_dropout=0.1,
            ffn_dropout=0.1,
            share_encoders=True,
        )

    def auto_config(self, num_replicas=1):
        config = super().auto_config(num_replicas=num_replicas)
        max_length = config["train"]["maximum_features_length"]
        # One maximum length per source.
        return misc.merge_dict(
            config, {"train": {"maximum_features_length": [max_length, max_length]}}
        )


# Entry point picked up when this file is passed to onmt-main with --model.
model = DualSourceTransformer
```
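`ConcatReducer` joins the token and POS embeddings along the last (depth) axis, so each source ends up with 480 + 32 = 512 dimensions, the same as `num_units`. Below is a minimal pure-Python sketch of that shape arithmetic. It assumes, as the chosen sizes suggest, that the encoder input depth must equal `num_units` (as far as I understand, the Transformer encoder adds its inputs to residual connections without a projection); `concat_shape` and `check_sources` are hypothetical helpers, not OpenNMT-tf APIs.

```python
from typing import List, Sequence, Tuple

Shape = Tuple[int, int, int]  # (batch, time, depth)


def concat_shape(shapes: Sequence[Shape]) -> Shape:
    """Shape obtained by concatenating feature embeddings on the last (depth) axis.

    Batch and time dimensions must be identical for every feature; otherwise
    raise, as TensorFlow's ConcatOp does.
    """
    batch, time, _ = shapes[0]
    for shape in shapes[1:]:
        if shape[:2] != (batch, time):
            raise ValueError(
                f"Dimensions of inputs should match: {shapes[0]} vs. {shape}"
            )
    return (batch, time, sum(depth for _, _, depth in shapes))


def check_sources(embedding_sizes: List[List[int]], num_units: int) -> None:
    """Raise if any source's concatenated embedding depth differs from num_units."""
    for index, sizes in enumerate(embedding_sizes):
        depth = sum(sizes)
        if depth != num_units:
            raise ValueError(
                f"source {index}: {sizes} sums to {depth}, expected {num_units}"
            )


# Depths from the model above: each source has token (480) + POS (32) embeddings.
check_sources([[480, 32], [480, 32]], num_units=512)
print(concat_shape([(64, 7, 480), (64, 7, 32)]))  # (64, 7, 512)
```

The same check also explains the error reported later in this thread: `concat_shape([(64, 2, 480), (64, 3, 32)])` raises because the token sequence and the POS sequence have different lengths.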
The data configuration should have the same level of nesting:
```yaml
data:
  train_features_file:
    - - source_1.txt
      - source_1.txt.pos
    - - source_2.txt
      - source_2.txt.pos
```
See the related documentation: Data — OpenNMT-tf 2.15.0 documentation
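To make the correspondence explicit: each outer entry of `train_features_file` is one source, and a source that has POS tags becomes an inner list (token file first, then POS file), in the same order as its `WordEmbedder` layers. A small sketch that builds this nested structure, assuming the POS files simply add a `.pos` suffix as in the example above; `build_features_files` is a hypothetical helper that only produces the Python data, not part of OpenNMT-tf.

```python
from typing import List, Sequence, Tuple, Union

FileEntry = Union[str, List[str]]


def build_features_files(sources: Sequence[Tuple[str, bool]]) -> List[FileEntry]:
    """Return the nested list expected for train_features_file / eval_features_file.

    Each item in `sources` is (token_file, has_pos). A source with POS tags
    becomes [token_file, token_file + ".pos"], matching a ParallelInputter of
    two WordEmbedders; a source without features stays a plain path, matching
    a single WordEmbedder.
    """
    entries: List[FileEntry] = []
    for path, has_pos in sources:
        entries.append([path, path + ".pos"] if has_pos else path)
    return entries


# Both sources with POS tags, as in the configuration above.
print(build_features_files([("source_1.txt", True), ("source_2.txt", True)]))
# [['source_1.txt', 'source_1.txt.pos'], ['source_2.txt', 'source_2.txt.pos']]
```

For a model where only the first source has POS tags, `build_features_files([("src_1_train.txt", True), ("src_2_train.txt", False)])` returns `[['src_1_train.txt', 'src_1_train.txt.pos'], 'src_2_train.txt']`.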
Thank you so much for this clarification. I was not sure how to incorporate the features in the ParallelInputter module. This is really helpful.
Hi, I tried to implement this model using POS tag information for only one of the source inputs, as follows:
```python
import opennmt as onmt
from opennmt.utils import misc


class DualSourceTransformer(onmt.models.Transformer):
    def __init__(self):
        super(DualSourceTransformer, self).__init__(
            source_inputter=onmt.inputters.ParallelInputter(
                [
                    # Source 1: tokens (480) + POS tags (32), concatenated.
                    onmt.inputters.ParallelInputter(
                        [
                            onmt.inputters.WordEmbedder(embedding_size=480),
                            onmt.inputters.WordEmbedder(embedding_size=32),
                        ],
                        reducer=onmt.layers.ConcatReducer(),
                    ),
                    # Source 2: tokens only.
                    onmt.inputters.WordEmbedder(embedding_size=512),
                ]
            ),
            target_inputter=onmt.inputters.WordEmbedder(embedding_size=512),
            num_layers=6,
            num_units=512,
            num_heads=8,
            ffn_inner_dim=2048,
            dropout=0.1,
            attention_dropout=0.1,
            ffn_dropout=0.1,
            share_encoders=True,
        )

    def auto_config(self, num_replicas=1):
        config = super(DualSourceTransformer, self).auto_config(num_replicas=num_replicas)
        max_length = config["train"]["maximum_features_length"]
        return misc.merge_dict(config, {
            "train": {
                "maximum_features_length": [max_length, max_length]
            }
        })
```
The corresponding configuration file is as follows:
```yaml
data:
  train_features_file:
    - - src_1_train.txt
      - src_1_train.txt.pos
    - src_2_train.txt
  train_labels_file: tgt_train.txt
  eval_features_file:
    - - src_1_val.txt
      - src_1_val.txt.pos
    - src_2_val.txt
  eval_labels_file: tgt_val.txt
```
But when training begins, it fails with the following error. It reports a shape mismatch, but I was not able to find what causes it:
```
INFO:tensorflow:Training on 74468 examples
2021-02-23 10:18:59.597253: I tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc:116] None of the MLIR optimization passes are enabled (registered 2)
2021-02-23 10:18:59.601808: I tensorflow/core/platform/profile_utils/cpu_utils.cc:112] CPU Frequency: 2199995000 Hz
INFO:tensorflow:Number of model parameters: 117466697
INFO:tensorflow:Number of model weights: 322 (trainable = 322, non trainable = 0)
2021-02-23 10:19:50.828144: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcublas.so.10
2021-02-23 10:19:51.382938: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcudnn.so.7
Traceback (most recent call last):
  File "/usr/local/bin/onmt-main", line 33, in <module>
    sys.exit(load_entry_point('OpenNMT-tf==2.15.0', 'console_scripts', 'onmt-main')())
  File "/usr/local/lib/python3.6/dist-packages/OpenNMT_tf-2.15.0-py3.6.egg/opennmt/bin/main.py", line 323, in main
    hvd=hvd,
  File "/usr/local/lib/python3.6/dist-packages/OpenNMT_tf-2.15.0-py3.6.egg/opennmt/runner.py", line 273, in train
    moving_average_decay=train_config.get("moving_average_decay"),
  File "/usr/local/lib/python3.6/dist-packages/OpenNMT_tf-2.15.0-py3.6.egg/opennmt/training.py", line 109, in __call__
    dataset, accum_steps=accum_steps, report_steps=report_steps
  File "/usr/local/lib/python3.6/dist-packages/OpenNMT_tf-2.15.0-py3.6.egg/opennmt/training.py", line 248, in _steps
    loss = forward_fn()
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/def_function.py", line 828, in __call__
    result = self._call(*args, **kwds)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/def_function.py", line 888, in _call
    return self._stateless_fn(*args, **kwds)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/function.py", line 2943, in __call__
    filtered_flat_args, captured_inputs=graph_function.captured_inputs)  # pylint: disable=protected-access
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/function.py", line 1919, in _call_flat
    ctx, args, cancellation_manager=cancellation_manager))
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/function.py", line 560, in call
    ctx=ctx)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/execute.py", line 60, in quick_execute
    inputs, attrs, num_outputs)
tensorflow.python.framework.errors_impl.InvalidArgumentError: 2 root error(s) found.
  (0) Invalid argument: ConcatOp : Dimensions of inputs should match: shape[0] = [64,2,480] vs. shape[1] = [64,3,32]
     [[node dual_source_transformer/parallel_inputter_1/parallel_inputter/concat_reducer_6/concat (defined at /lib/python3.6/dist-packages/OpenNMT_tf-2.15.0-py3.6.egg/opennmt/layers/reducer.py:187) ]]
     [[Where/_44]]
  (1) Invalid argument: ConcatOp : Dimensions of inputs should match: shape[0] = [64,2,480] vs. shape[1] = [64,3,32]
     [[node dual_source_transformer/parallel_inputter_1/parallel_inputter/concat_reducer_6/concat (defined at /lib/python3.6/dist-packages/OpenNMT_tf-2.15.0-py3.6.egg/opennmt/layers/reducer.py:187) ]]
0 successful operations.
0 derived errors ignored. [Op:__inference__forward_47800]

Errors may have originated from an input operation.
Input Source operations connected to node dual_source_transformer/parallel_inputter_1/parallel_inputter/concat_reducer_6/concat:
 dual_source_transformer/parallel_inputter_1/parallel_inputter/word_embedder/embedding_lookup/Identity (defined at /lib/python3.6/dist-packages/OpenNMT_tf-2.15.0-py3.6.egg/opennmt/inputters/text_inputter.py:441)

Input Source operations connected to node dual_source_transformer/parallel_inputter_1/parallel_inputter/concat_reducer_6/concat:
 dual_source_transformer/parallel_inputter_1/parallel_inputter/word_embedder/embedding_lookup/Identity (defined at /lib/python3.6/dist-packages/OpenNMT_tf-2.15.0-py3.6.egg/opennmt/inputters/text_inputter.py:441)

Function call stack:
_forward -> _forward
```
If anyone can help me find the cause of this error, it would be a great help. Thank you.
Can you check that the files containing the tokens and the POS tags are correctly aligned? There should be the same number of POS tags as there are tokens.
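In the traceback, the embeddings being concatenated have shapes `[64,2,480]` and `[64,3,32]`: same batch size, but a time dimension of 2 for the tokens and 3 for the POS tags, which is what you would expect if some line has one more POS tag than tokens. A minimal sketch for finding such lines, assuming one sentence per line in both files and whitespace-separated tokens and tags (`str.split()` splits on any whitespace, which may differ slightly from the tokenizer actually configured); `find_misaligned` and `check_files` are hypothetical helpers, not part of OpenNMT-tf.

```python
import itertools
from typing import Iterable, List, Optional, Tuple

# (line number, number of tokens, number of tags); None means that file has no such line.
Mismatch = Tuple[int, Optional[int], Optional[int]]


def find_misaligned(token_lines: Iterable[str], pos_lines: Iterable[str]) -> List[Mismatch]:
    """Return every line whose token count differs from its POS tag count."""
    mismatches: List[Mismatch] = []
    # zip_longest also catches files with a different number of lines.
    pairs = itertools.zip_longest(token_lines, pos_lines)
    for line_no, (tokens, tags) in enumerate(pairs, start=1):
        n_tokens = None if tokens is None else len(tokens.split())
        n_tags = None if tags is None else len(tags.split())
        if n_tokens != n_tags:
            mismatches.append((line_no, n_tokens, n_tags))
    return mismatches


def check_files(tokens_path: str, pos_path: str) -> List[Mismatch]:
    """Run find_misaligned on a token file and its POS file."""
    with open(tokens_path, encoding="utf-8") as tok, open(pos_path, encoding="utf-8") as pos:
        return find_misaligned(tok, pos)


if __name__ == "__main__":
    print(find_misaligned(["the cat", "a dog"], ["DET NOUN", "DET NOUN PUNCT"]))
    # [(2, 2, 3)]
```

Running `check_files("src_1_train.txt", "src_1_train.txt.pos")`, and the same for the validation files, should return an empty list once the files are aligned. A frequent source of mismatches is a POS tagger that tokenizes differently from the token file, for example by splitting punctuation or contractions.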
Hi @guillaumekln. Thank you for showing the multi-source Transformer in OpenNMT-tf. We are working on multilingual Indic-English translation using OpenNMT-py. Is a multi-source Transformer available in the PyTorch version? It would be really helpful for us. Thank you.