I would like to use a trained model for translation from my Java code. Unfortunately, using a separate server is not an option.
I see 2 solutions to the problem:
- use OpenNMT (Lua version) for training and CTranslate for translation, calling the C++ code via JNI.
- use OpenNMT-tf for training and the TensorFlow Java API for translation. This seems better, but the TensorFlow Java API does not, of course, support OpenNMT-tf out of the box.
So could anyone please tell me whether I’ve considered all the options, and whether it is possible for one person to patch the TensorFlow Java API with code for OpenNMT-tf in a couple of days?
Hi,
The second option should already be possible without patching TensorFlow. OpenNMT-tf can export SavedModels (see the documentation) that can be loaded in Java via the SavedModelBundle class.
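A minimal sketch of the Java side, assuming an export directory produced by OpenNMT-tf (the path is a placeholder; "serve" is the standard serving tag):

import org.tensorflow.SavedModelBundle;
import org.tensorflow.Session;

public class LoadModel {
  public static void main(String[] args) {
    // The export directory contains saved_model.pb plus the assets/ and variables/ folders.
    try (SavedModelBundle model = SavedModelBundle.load("path/to/export", "serve")) {
      Session session = model.session();
      // feed inputs and fetch outputs via session.runner() here
    }
  }
}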
Let me know if that works for you!
Hello.
It’s nice to hear that using the TensorFlow Java API should be possible. (I tried to load a Graph before posting here, and it failed, so I assumed it wasn’t supported.)
However, SavedModelBundle did not work for me either.
I called SavedModelBundle.load on one of the directories produced by training (containing saved_model.pb and the assets and variables folders), and it failed with:
2018-04-20 19:47:05.250986: I tensorflow/cc/saved_model/loader.cc:236] Loading SavedModel from: deep-api-model
2018-04-20 19:47:05.313399: I tensorflow/cc/saved_model/loader.cc:284] Loading SavedModel: fail. Took 62418 microseconds.
Exception in thread "main" org.tensorflow.TensorFlowException: Op type not registered 'GatherTree' in binary running on alex-N56JRH. Make sure the Op and Kernel are registered in the binary running in this process.
at org.tensorflow.SavedModelBundle.load(Native Method)
at org.tensorflow.SavedModelBundle.load(SavedModelBundle.java:39)
at com.github.awesomelemon.HelloTF.main(HelloTF.java:14)
I use TensorFlow 1.4.0 for both training and translation.
Would you kindly help me to solve this?
Indeed, ops from tf.contrib are not part of the Java binaries. However, it is possible to dynamically load the missing ops from Java. Following the template of this Stack Overflow answer, the path to the library that contains the GatherTree op should look like this:
/usr/local/lib/python2.7/dist-packages/tensorflow/contrib/seq2seq/python/ops/_beam_search_ops.so
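For example, something along these lines should register the op before the model is restored, assuming your TensorFlow Java version provides TensorFlow.loadLibrary (the .so path is just an example and depends on your Python installation):

import org.tensorflow.SavedModelBundle;
import org.tensorflow.TensorFlow;

public class LoadWithContribOps {
  public static void main(String[] args) {
    // Registers GatherTree and the other beam search ops with the runtime.
    TensorFlow.loadLibrary(
        "/usr/local/lib/python2.7/dist-packages/tensorflow/contrib/seq2seq/python/ops/_beam_search_ops.so");
    try (SavedModelBundle model = SavedModelBundle.load("deep-api-model", "serve")) {
      // The SavedModel should now load without the "Op type not registered" error.
    }
  }
}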
Alternatively, you can train a Transformer model as it only uses ops from TensorFlow core. It might be simpler and could also yield better performance than the RNNs.
Thanks, I was able to load the model. Now I’m trying to work with it.
To find out the input names I used saved_model_cli and got this:
The given SavedModel SignatureDef contains the following input(s):
inputs['length'] tensor_info:
dtype: DT_INT32
shape: (-1)
name: Placeholder_1:0
inputs['tokens'] tensor_info:
dtype: DT_STRING
shape: (-1, -1)
name: Placeholder:0
The given SavedModel SignatureDef contains the following output(s):
outputs['length'] tensor_info:
dtype: DT_INT32
shape: (-1, 5)
name: seq2seq/decoder_1/decoder_1/while/Exit_16:0
outputs['log_probs'] tensor_info:
dtype: DT_FLOAT
shape: (-1, 5)
name: seq2seq/decoder_1/decoder_1/while/Exit_11:0
outputs['tokens'] tensor_info:
dtype: DT_STRING
shape: (-1, 5, -1)
name: seq2seq/index_to_string_Lookup:0
Method name is: tensorflow/serving/predict
So I tried to pass the input like so:
Session.Runner runner = model.session().runner();
String input = "random";
byte[][][] matrix = new byte[1][1][];
matrix[0][0] = input.getBytes("UTF-8");
Tensor<?> inputCallsTensor = Tensor.create(matrix, String.class);
Tensor<?> inputLengthTensor = Tensor.create(1);
runner.feed("Placeholder_1:0", inputLengthTensor);
runner.feed("Placeholder:0", inputCallsTensor);
List<Tensor<?>> run = runner.fetch("seq2seq/decoder_1/decoder_1/while/Exit_16:0").run();
This works well up to the last line, where an exception is thrown about a wrong number of dimensions:
Exception in thread "main" java.lang.IllegalArgumentException: Tried to expand dim index 1 for tensor with 0 dimensions.
[[Node: seq2seq/decoder_1/tile_batch_2/ExpandDims = ExpandDims[T=DT_INT32, Tdim=DT_INT32, _output_shapes=[[?,1]], _device="/job:localhost/replica:0/task:0/device:CPU:0"](_arg_Placeholder_1_0_1, seq2seq/decoder_1/tile_batch/ExpandDims_2/dim)]]
at org.tensorflow.Session.run(Native Method)
I tried to make matrix 2-dimensional, to no avail.
Do I understand correctly that the input is a list of strings, and that it should be passed the way I pass it?
Huh, the problem lay not in the String tensor, but in the int one. After replacing
Tensor<?> inputLengthTensor = Tensor.create(1);
with
int[] ints = {1};
Tensor<?> inputLengthTensor = Tensor.create(ints);
everything worked.
It seems that my troubles have come to an end. Thanks for the help!
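For anyone who lands here later, this is roughly what my working snippet looks like now. It is only a sketch: the tensor names are the ones from the signature above, and the way the outputs are read may need adjusting.

import java.util.List;
import org.tensorflow.SavedModelBundle;
import org.tensorflow.Tensor;

public class Translate {
  public static void main(String[] args) throws Exception {
    try (SavedModelBundle model = SavedModelBundle.load("deep-api-model", "serve")) {
      // One sentence of one token; the tokens tensor has shape [batch, time].
      byte[][][] tokens = new byte[1][1][];
      tokens[0][0] = "random".getBytes("UTF-8");
      // The length must be a 1D int vector (one entry per batch element),
      // not a scalar -- that was the cause of the ExpandDims error above.
      int[] lengths = {1};
      try (Tensor<?> tokensTensor = Tensor.create(tokens, String.class);
           Tensor<?> lengthsTensor = Tensor.create(lengths)) {
        List<Tensor<?>> outputs = model.session().runner()
            .feed("Placeholder:0", tokensTensor)
            .feed("Placeholder_1:0", lengthsTensor)
            .fetch("seq2seq/index_to_string_Lookup:0")            // predicted tokens
            .fetch("seq2seq/decoder_1/decoder_1/while/Exit_16:0") // predicted lengths
            .run();
        // outputs.get(0) holds the tokens of the 5 beam hypotheses,
        // outputs.get(1) their lengths.
      }
    }
  }
}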
P.S. It would be great if all the needed steps were described in the documentation. There are a number of caveats, like this int thing (which I guess is used for batching?) or the strange names of the inputs (‘Placeholder’ in the name is not a friendly API).
Great to hear that you got it working!
The saved_model_cli output actually showed that the input length is a 1D int vector.
For referencing the input tensors, I think the pattern is to look up their names via the signature_def field in the meta graph. I’m not sure how it is done in Java, but you can look at the C++ example.
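Something like the following might work (untested), assuming the org.tensorflow proto classes (the org.tensorflow:proto artifact) are on the classpath and the model was exported under the default "serving_default" signature key:

import java.util.Map;
import org.tensorflow.SavedModelBundle;
import org.tensorflow.framework.MetaGraphDef;
import org.tensorflow.framework.SignatureDef;
import org.tensorflow.framework.TensorInfo;

public class SignatureLookup {
  public static void main(String[] args) throws Exception {
    try (SavedModelBundle model = SavedModelBundle.load("deep-api-model", "serve")) {
      MetaGraphDef metaGraph = MetaGraphDef.parseFrom(model.metaGraphDef());
      SignatureDef signature = metaGraph.getSignatureDefMap().get("serving_default");
      // Maps the logical names ("tokens", "length") to the real tensor names
      // ("Placeholder:0", "Placeholder_1:0"), so nothing has to be hard-coded.
      Map<String, TensorInfo> inputs = signature.getInputsMap();
      String tokensName = inputs.get("tokens").getName();
      String lengthName = inputs.get("length").getName();
      Map<String, TensorInfo> outputs = signature.getOutputsMap();
      String outputTokensName = outputs.get("tokens").getName();
    }
  }
}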
If you believe all this needs better documentation, I welcome a PR adding an examples/java directory.