Best way to translate locally using Java

I would like to use a trained model for translation from my Java code. Unfortunately, running a separate server is not an option.

I see 2 solutions to the problem:

  1. use OpenNMT (Lua version) for training and CTranslate for translation, calling the C++ code via JNI.
  2. use OpenNMT-tf for training and the TensorFlow Java API for translation. This seems better, but of course the TensorFlow Java API has no built-in support for OpenNMT-tf.

So could anyone please tell me whether I’ve considered all the options, and whether it is feasible for one person to patch the TensorFlow Java API with code for OpenNMT-tf in a couple of days?


The second option should already be possible without patching TensorFlow. OpenNMT-tf can export SavedModels (see the documentation) that can be loaded in Java via the SavedModelBundle class.
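For example, loading such an export in Java comes down to a single call (a minimal sketch; the export directory name here is hypothetical):

```java
import org.tensorflow.SavedModelBundle;

public class LoadModel {
    public static void main(String[] args) {
        // Path to a directory produced by the OpenNMT-tf export step (hypothetical name);
        // it should contain saved_model.pb plus the variables/ and assets/ folders.
        try (SavedModelBundle model = SavedModelBundle.load("export/my-model", "serve")) {
            // model.session().runner() can now be used to feed inputs and fetch outputs.
        }
    }
}
```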

Let me know if that works for you!


It’s nice to hear that using the TensorFlow Java API should be possible. (I tried to load a Graph before posting here, and it failed, so I assumed it was not supported.)
However, SavedModelBundle did not work for me either.
I called SavedModelBundle.load on one of the directories produced by training (containing saved_model.pb and the assets and variables folders), and it failed with:

2018-04-20 19:47:05.250986: I tensorflow/cc/saved_model/] Loading SavedModel from: deep-api-model
2018-04-20 19:47:05.313399: I tensorflow/cc/saved_model/] Loading SavedModel: fail. Took 62418 microseconds.
Exception in thread "main" org.tensorflow.TensorFlowException: Op type not registered 'GatherTree' in binary running on alex-N56JRH. Make sure the Op and Kernel are registered in the binary running in this process.
	at org.tensorflow.SavedModelBundle.load(Native Method)
	at org.tensorflow.SavedModelBundle.load(
	at com.github.awesomelemon.HelloTF.main(

I use TensorFlow 1.4.0 for both training and translating.

Would you kindly help me to solve this?

Indeed, ops from tf.contrib are not part of the Java binaries. However, it is possible to dynamically load the missing ops from Java, see:

Following the template of this Stack Overflow answer, the path to the library that contains the GatherTree op should look like this:
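As a hedged sketch of that step, assuming the TensorFlow.loadLibrary method from the Java API and a hypothetical path to the shared library that ships with the Python package:

```java
import org.tensorflow.SavedModelBundle;
import org.tensorflow.TensorFlow;

public class LoadCustomOps {
    public static void main(String[] args) {
        // Hypothetical path: locate the shared library that registers GatherTree
        // inside your Python installation of tensorflow.contrib.seq2seq.
        TensorFlow.loadLibrary(
            "/path/to/site-packages/tensorflow/contrib/seq2seq/python/ops/_beam_search_ops.so");
        // Load the library before importing the graph so GatherTree is registered.
        try (SavedModelBundle model = SavedModelBundle.load("deep-api-model", "serve")) {
            // The SavedModel should now load without the "Op type not registered" error.
        }
    }
}
```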


Alternatively, you can train a Transformer model, as it only uses ops from TensorFlow core. It might be simpler and could also yield better performance than the RNN models.

Thanks, I was able to load the model. Now I’m trying to work with it.
To find out the input names, I used saved_model_cli and got this:

The given SavedModel SignatureDef contains the following input(s):
inputs['length'] tensor_info:
    dtype: DT_INT32
    shape: (-1)
    name: Placeholder_1:0
inputs['tokens'] tensor_info:
    dtype: DT_STRING
    shape: (-1, -1)
    name: Placeholder:0
The given SavedModel SignatureDef contains the following output(s):
outputs['length'] tensor_info:
    dtype: DT_INT32
    shape: (-1, 5)
    name: seq2seq/decoder_1/decoder_1/while/Exit_16:0
outputs['log_probs'] tensor_info:
    dtype: DT_FLOAT
    shape: (-1, 5)
    name: seq2seq/decoder_1/decoder_1/while/Exit_11:0
outputs['tokens'] tensor_info:
    dtype: DT_STRING
    shape: (-1, 5, -1)
    name: seq2seq/index_to_string_Lookup:0
Method name is: tensorflow/serving/predict

So I tried to pass the input like so:

        Session.Runner runner = model.session().runner();
        String input = "random";
        byte[][][] matrix = new byte[1][1][];
        matrix[0][0] = input.getBytes("UTF-8");
        Tensor<?> inputCallsTensor = Tensor.create(matrix, String.class);
        Tensor<?> inputLengthTensor = Tensor.create(1);
        runner.feed("Placeholder_1:0", inputLengthTensor);
        runner.feed("Placeholder:0", inputCallsTensor);
        List<Tensor<?>> run = runner.fetch("seq2seq/decoder_1/decoder_1/while/Exit_16:0").run();

This works well up to the last line, where an exception is thrown about a wrong number of dimensions:

Exception in thread "main" java.lang.IllegalArgumentException: Tried to expand dim index 1 for tensor with 0 dimensions.
	 [[Node: seq2seq/decoder_1/tile_batch_2/ExpandDims = ExpandDims[T=DT_INT32, Tdim=DT_INT32, _output_shapes=[[?,1]], _device="/job:localhost/replica:0/task:0/device:CPU:0"](_arg_Placeholder_1_0_1, seq2seq/decoder_1/tile_batch/ExpandDims_2/dim)]]
	at Method)

I tried making the matrix 2-dimensional, to no avail.
Do I understand correctly that the input is a list of strings, and that I am passing it the right way?

Huh, the problem lay not in the String tensor but in the int one. After replacing Tensor<?> inputLengthTensor = Tensor.create(1); with

int[] ints = {1};
Tensor<?> inputLengthTensor = Tensor.create(ints);

everything worked.
It seems my troubles have come to an end. Thanks for the help!
P.S. It would be great if all the needed steps were described in the documentation. There are a number of caveats, like this int thing (which I guess is used for batching?) or the strange names of the inputs (‘Placeholder’ in the name is not a friendly API).
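For later readers, here is how the whole call chain might look with the fix applied (a sketch assuming the model directory and tensor names from the saved_model_cli output above):

```java
import java.nio.charset.StandardCharsets;
import java.util.List;
import org.tensorflow.SavedModelBundle;
import org.tensorflow.Tensor;

public class Translate {
    public static void main(String[] args) {
        try (SavedModelBundle model = SavedModelBundle.load("deep-api-model", "serve")) {
            // Batch of one sentence; shape of the token matrix is [batch, max_length].
            String[] tokens = {"random"};
            byte[][][] matrix = new byte[1][tokens.length][];
            for (int i = 0; i < tokens.length; i++) {
                matrix[0][i] = tokens[i].getBytes(StandardCharsets.UTF_8);
            }
            // The length input must be a 1D int vector (one entry per batch item),
            // not a scalar -- that was the cause of the ExpandDims error above.
            int[] lengths = {tokens.length};
            try (Tensor<?> tokensTensor = Tensor.create(matrix, String.class);
                 Tensor<?> lengthsTensor = Tensor.create(lengths)) {
                List<Tensor<?>> outputs = model.session().runner()
                        .feed("Placeholder:0", tokensTensor)
                        .feed("Placeholder_1:0", lengthsTensor)
                        .fetch("seq2seq/index_to_string_Lookup:0")
                        .run();
                // outputs.get(0) holds the predicted tokens, with shape
                // [batch, beam_width, max_decoded_length].
            }
        }
    }
}
```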

Great to hear that you got it working!

The saved_model_cli output actually showed that the input length is a 1D int vector.

For referencing the input tensors, I think the pattern is to look up their names via the signature_def field in the meta graph. I’m not sure how it works in Java, but you can look at the C++ example:
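In Java, a sketch of that lookup, assuming the TensorFlow proto classes (the separate proto artifact) are on the classpath and that the export uses the serving_default signature key:

```java
import org.tensorflow.SavedModelBundle;
import org.tensorflow.framework.MetaGraphDef;
import org.tensorflow.framework.SignatureDef;

public class SignatureLookup {
    public static void main(String[] args) throws Exception {
        try (SavedModelBundle model = SavedModelBundle.load("deep-api-model", "serve")) {
            // metaGraphDef() returns the serialized proto; parsing it requires
            // the TensorFlow proto classes on the classpath.
            MetaGraphDef meta = MetaGraphDef.parseFrom(model.metaGraphDef());
            SignatureDef sig = meta.getSignatureDefMap().get("serving_default");
            // Resolve the real tensor names instead of hard-coding "Placeholder:0".
            String tokensName = sig.getInputsMap().get("tokens").getName();
            String lengthName = sig.getInputsMap().get("length").getName();
            // Feed the runner by tokensName and lengthName.
        }
    }
}
```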

If you believe all this needs better documentation, I welcome a PR adding an examples/java directory.