Best way to translate locally using Java

ctranslate
tensorflow

(Alexander Chebykin) #1

I would like to use a trained model for translation from my Java code. Unfortunately, using a separate server is not an option.

I see two possible solutions:

  1. Use OpenNMT (Lua version) for training and CTranslate for translation, calling the C++ code via JNI.
  2. Use OpenNMT-tf for training and the TensorFlow Java API for translation. This seems better, but the TensorFlow Java API does not, of course, support OpenNMT-tf.

So could anyone tell me whether I’ve considered all the options, and whether it is realistic for one person to patch the TensorFlow Java API to support OpenNMT-tf in a couple of days?


(Guillaume Klein) #2

Hi,

The second option should already be possible without patching TensorFlow. OpenNMT-tf can export SavedModels (see the documentation) that can be loaded in Java via the SavedModelBundle class.
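For reference, a minimal sketch of loading such an export from Java with the 1.x API. It assumes the export was written with the standard "serve" tag (saved_model_cli show --dir <dir> lists the tag-sets) and uses a hypothetical default path export/1; both should be adjusted to the actual export.

```java
import org.tensorflow.SavedModelBundle;
import org.tensorflow.Session;

public class LoadExportedModel {

    // Hypothetical default location of an OpenNMT-tf export; pass the real
    // directory (the one containing saved_model.pb) as the first argument.
    static String exportDir(String[] args) {
        return args.length > 0 ? args[0] : "export/1";
    }

    public static void main(String[] args) {
        String dir = exportDir(args);
        // "serve" is the standard serving tag; it is an assumption that the
        // OpenNMT-tf export uses it.
        try (SavedModelBundle model = SavedModelBundle.load(dir, "serve")) {
            // The bundle owns the graph and the session; inputs are fed and
            // outputs fetched through a runner created from this session.
            Session.Runner runner = model.session().runner();
            System.out.println("Loaded SavedModel from " + dir);
        }
    }
}
```

Closing the bundle (here via try-with-resources) releases both the session and the graph.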

Let me know if that works for you!


(Alexander Chebykin) #3

Hello.
It’s good to hear that the TensorFlow Java API should work. (I had tried to load the Graph before posting here; it failed, so I assumed this was not supported.)
However, SavedModelBundle did not work for me either.
I called SavedModelBundle.load on one of the directories produced by training (containing saved_model.pb and the assets and variables folders), and it failed with:

2018-04-20 19:47:05.250986: I tensorflow/cc/saved_model/loader.cc:236] Loading SavedModel from: deep-api-model
2018-04-20 19:47:05.313399: I tensorflow/cc/saved_model/loader.cc:284] Loading SavedModel: fail. Took 62418 microseconds.
Exception in thread "main" org.tensorflow.TensorFlowException: Op type not registered 'GatherTree' in binary running on alex-N56JRH. Make sure the Op and Kernel are registered in the binary running in this process.
	at org.tensorflow.SavedModelBundle.load(Native Method)
	at org.tensorflow.SavedModelBundle.load(SavedModelBundle.java:39)
	at com.github.awesomelemon.HelloTF.main(HelloTF.java:14)

I use TensorFlow 1.4.0 for both training and translation.

Could you help me solve this?


(Guillaume Klein) #4

Indeed, ops from tf.contrib are not part of the Java binaries. However, it is possible to dynamically load the missing ops from Java, see:

Following the template of this Stack Overflow answer, the path to the library that contains the GatherTree op should look like this:

/usr/local/lib/python2.7/dist-packages/tensorflow/contrib/seq2seq/python/ops/_beam_search_ops.so
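As a rough sketch, the op library can be registered with TensorFlow.loadLibrary before the SavedModel is loaded. The path is the one above and depends on the local Python installation; the export directory default is hypothetical. The .so file is assumed to come from the same TensorFlow version as the Java library, since a mismatched build is likely to fail to load.

```java
import org.tensorflow.SavedModelBundle;
import org.tensorflow.TensorFlow;

public class LoadWithContribOps {

    // Location of the tf.contrib beam search ops (provides GatherTree).
    // Adjust to your machine and TensorFlow version.
    static final String BEAM_SEARCH_OPS =
            "/usr/local/lib/python2.7/dist-packages/tensorflow/contrib/seq2seq/python/ops/_beam_search_ops.so";

    public static void main(String[] args) {
        // Register the GatherTree op and kernel in this process. This has to
        // happen before loading, otherwise "Op type not registered" is raised.
        TensorFlow.loadLibrary(BEAM_SEARCH_OPS);

        String exportDir = args.length > 0 ? args[0] : "export/1"; // hypothetical default
        try (SavedModelBundle model = SavedModelBundle.load(exportDir, "serve")) {
            System.out.println("Loaded with TensorFlow " + TensorFlow.version());
        }
    }
}
```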

Alternatively, you can train a Transformer model, as it only uses ops from TensorFlow core. It might be simpler and could also yield better performance than the RNNs.


(Alexander Chebykin) #5

Thanks, I was able to load the model. Now I’m trying to use it.
To find the input names, I used saved_model_cli and got this:

The given SavedModel SignatureDef contains the following input(s):
inputs['length'] tensor_info:
    dtype: DT_INT32
    shape: (-1)
    name: Placeholder_1:0
inputs['tokens'] tensor_info:
    dtype: DT_STRING
    shape: (-1, -1)
    name: Placeholder:0
The given SavedModel SignatureDef contains the following output(s):
outputs['length'] tensor_info:
    dtype: DT_INT32
    shape: (-1, 5)
    name: seq2seq/decoder_1/decoder_1/while/Exit_16:0
outputs['log_probs'] tensor_info:
    dtype: DT_FLOAT
    shape: (-1, 5)
    name: seq2seq/decoder_1/decoder_1/while/Exit_11:0
outputs['tokens'] tensor_info:
    dtype: DT_STRING
    shape: (-1, 5, -1)
    name: seq2seq/index_to_string_Lookup:0
Method name is: tensorflow/serving/predict
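The numeric outputs in this listing map directly onto Java arrays. The sketch below fetches 'length' and 'log_probs' (both of shape [batch, 5], i.e. 5 beam hypotheses per input) from a runner whose inputs have already been fed, and picks the highest-scoring hypothesis per input. The tensor names are copied from the listing; the 'tokens' string output is not decoded here.

```java
import java.util.List;
import org.tensorflow.Session;
import org.tensorflow.Tensor;

public class ReadBeamScores {

    // Index of the highest log probability in one row of 'log_probs'.
    static int best(float[] logProbs) {
        int best = 0;
        for (int i = 1; i < logProbs.length; i++) {
            if (logProbs[i] > logProbs[best]) {
                best = i;
            }
        }
        return best;
    }

    // Expects a runner whose 'tokens' and 'length' inputs are already fed.
    static void printBestHypotheses(Session.Runner runner) {
        List<Tensor<?>> outputs = runner
                .fetch("seq2seq/decoder_1/decoder_1/while/Exit_16:0") // length
                .fetch("seq2seq/decoder_1/decoder_1/while/Exit_11:0") // log_probs
                .run();
        try (Tensor<?> lengthT = outputs.get(0); Tensor<?> logProbsT = outputs.get(1)) {
            // Both outputs share the shape [batch, beam].
            long[] shape = lengthT.shape();
            int[][] lengths = lengthT.copyTo(new int[(int) shape[0]][(int) shape[1]]);
            float[][] logProbs = logProbsT.copyTo(new float[(int) shape[0]][(int) shape[1]]);
            for (int b = 0; b < lengths.length; b++) {
                int k = best(logProbs[b]);
                System.out.printf("input %d: hypothesis %d, %d tokens, log prob %.3f%n",
                        b, k, lengths[b][k], logProbs[b][k]);
            }
        }
    }
}
```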

So I tried to pass the input like this:

        Session.Runner runner = model.session().runner();

        // 'tokens' input (Placeholder:0): a batch of 1 sentence with 1 token
        String input = "random";
        byte[][][] matrix = new byte[1][1][];
        matrix[0][0] = input.getBytes("UTF-8");
        Tensor<?> inputCallsTensor = Tensor.create(matrix, String.class);

        // 'length' input (Placeholder_1:0)
        Tensor<?> inputLengthTensor = Tensor.create(1);
        runner.feed("Placeholder_1:0", inputLengthTensor);
        runner.feed("Placeholder:0", inputCallsTensor);
        List<Tensor<?>> run = runner.fetch("seq2seq/decoder_1/decoder_1/while/Exit_16:0").run();

This works up to the last line, where an exception about the wrong number of dimensions is thrown:

Exception in thread "main" java.lang.IllegalArgumentException: Tried to expand dim index 1 for tensor with 0 dimensions.
	 [[Node: seq2seq/decoder_1/tile_batch_2/ExpandDims = ExpandDims[T=DT_INT32, Tdim=DT_INT32, _output_shapes=[[?,1]], _device="/job:localhost/replica:0/task:0/device:CPU:0"](_arg_Placeholder_1_0_1, seq2seq/decoder_1/tile_batch/ExpandDims_2/dim)]]
	at org.tensorflow.Session.run(Native Method)

I tried making the matrix 2-dimensional, to no avail.
Do I understand correctly that the input is a list of strings and should be passed the way I pass it?


(Alexander Chebykin) #6

Huh, the problem lay not in the String tensor but in the int one. After replacing Tensor<?> inputLengthTensor = Tensor.create(1); with

int[] ints = {1};
Tensor<?> inputLengthTensor = Tensor.create(ints);

everything worked.
It seems my troubles have come to an end. Thanks for the help!
P.S. It would be great if all the needed steps were described in the documentation. There are a number of caveats, like this int thing (which I guess is used for batching?) or the strange input names (‘Placeholder’ in the name is not a friendly API).


(Guillaume Klein) #7

Great to hear that you got it working!

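The saved_model_cli output actually showed that the input length is a 1D int vector (shape (-1)): one length per sentence in the batch, so even a single sentence needs a vector of size 1 rather than a scalar.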
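Generalizing the working snippet to several sentences, a sketch of building both inputs for a batch: 'tokens' becomes a [batch, maxLen] string matrix and 'length' a 1-D int vector. Padding short sentences with empty strings is an assumption; positions beyond each sentence's length are expected not to affect decoding. The class and method names are hypothetical.

```java
import java.nio.charset.StandardCharsets;
import java.util.List;
import org.tensorflow.Tensor;

public class BatchInputs {

    // 'tokens' input: shape [batch, maxLen], short sentences padded with "".
    static byte[][][] tokenMatrix(List<String[]> batch) {
        int maxLen = 0;
        for (String[] sentence : batch) {
            maxLen = Math.max(maxLen, sentence.length);
        }
        byte[][][] matrix = new byte[batch.size()][maxLen][];
        for (int i = 0; i < batch.size(); i++) {
            String[] sentence = batch.get(i);
            for (int j = 0; j < maxLen; j++) {
                String token = j < sentence.length ? sentence[j] : "";
                matrix[i][j] = token.getBytes(StandardCharsets.UTF_8);
            }
        }
        return matrix;
    }

    // 'length' input: 1-D int vector, one entry per sentence.
    static int[] lengths(List<String[]> batch) {
        int[] result = new int[batch.size()];
        for (int i = 0; i < result.length; i++) {
            result[i] = batch.get(i).length;
        }
        return result;
    }

    static Tensor<String> tokensTensor(List<String[]> batch) {
        return Tensor.create(tokenMatrix(batch), String.class);
    }

    static Tensor<Integer> lengthTensor(List<String[]> batch) {
        return Tensor.create(lengths(batch), Integer.class);
    }
}
```

The two tensors are then fed as before, tokensTensor to Placeholder:0 and lengthTensor to Placeholder_1:0, and should be closed after the run.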

For referencing the input tensors, I think the pattern is to look up their names via the signature_def field of the meta graph. I’m not sure how this is done in Java, but you can look at the C++ example:

If you believe all this needs better documentation, I welcome a PR adding an examples/java directory.
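On the Java side, one way to do the signature_def lookup is to parse the serialized meta graph that SavedModelBundle exposes. This sketch assumes the generated protobuf classes from the org.tensorflow:proto artifact are on the classpath, and that the signature key is "serving_default" (saved_model_cli show --dir <dir> --all prints the actual key).

```java
import com.google.protobuf.InvalidProtocolBufferException;
import java.util.Map;
import org.tensorflow.SavedModelBundle;
import org.tensorflow.framework.MetaGraphDef;
import org.tensorflow.framework.SignatureDef;
import org.tensorflow.framework.TensorInfo;

public class SignatureNames {

    // Returns the inputs (or outputs) map of one signature, keyed by logical
    // names such as "tokens" and "length".
    static Map<String, TensorInfo> signatureMap(SavedModelBundle model, String signatureKey,
            boolean inputs) throws InvalidProtocolBufferException {
        MetaGraphDef metaGraph = MetaGraphDef.parseFrom(model.metaGraphDef());
        SignatureDef signature = metaGraph.getSignatureDefOrThrow(signatureKey);
        return inputs ? signature.getInputsMap() : signature.getOutputsMap();
    }

    // Resolves a logical name to the graph tensor name, e.g. "Placeholder:0".
    static String tensorName(Map<String, TensorInfo> infos, String key) {
        TensorInfo info = infos.get(key);
        if (info == null) {
            throw new IllegalArgumentException("no tensor '" + key + "' in signature");
        }
        return info.getName();
    }
}
```

With this, runner.feed(tensorName(inputs, "tokens"), tokensTensor) avoids hard-coding Placeholder:0, so the Java code keeps working if the graph's internal names change between exports.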