[JAVA] Use TensorFlow from JRuby

A few days ago, piyo7 posted "Call TensorFlow Java API from Scala". I wanted to see what the same code would look like in JRuby.

Code

example.rb


require 'pp'
require 'java'
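# Load the TensorFlow Java bindings from the jar in the current directory
require './libtensorflow-1.0.0-PREVIEW1.jar'

# Expose the classes in org.tensorflow as TF::Graph, TF::Session, etc.
module TF
  include_package 'org.tensorflow'
end

graph = TF::Graph.new

# Constant node "a": INT32 tensor [1, 2, 3]
a = graph.opBuilder("Const", "a").
  setAttr("dtype", TF::DataType::INT32).
  setAttr("value", TF::Tensor.create([1, 2, 3].to_java(:int))).
  build().
  output(0)

# Constant node "b": INT32 tensor [4, 5, 6]
b = graph.opBuilder("Const", "b").
  setAttr("dtype", TF::DataType::INT32).
  setAttr("value", TF::Tensor.create([4, 5, 6].to_java(:int))).
  build().
  output(0)

# Node "c": element-wise product of a and b
c = graph.opBuilder("Mul", "c").
  addInput(a).
  addInput(b).
  build().
  output(0)

# Run the graph, fetch the output of "c", and copy it into a Java int[3]
session = TF::Session.new(graph)
out = Array.new(3).to_java(:int)
session.runner().fetch("c").run().get(0).copyTo(out)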

pp out #=> int[4, 10, 18]@71623278

Running it with

$ ruby -J-Djava.library.path=./jni example.rb

does indeed give [4, 10, 18].
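For reference, the Mul node multiplies its two inputs element by element (TensorFlow's Mul also broadcasts, but here both inputs have the same shape). Redoing the computation in plain Ruby, without TensorFlow, shows where [4, 10, 18] comes from:

```ruby
# What the "Mul" node in example.rb computes, redone in plain Ruby:
# multiply the two constant vectors element by element.
a = [1, 2, 3]
b = [4, 5, 6]
c = a.zip(b).map { |x, y| x * y }
p c # prints [4, 10, 18]
```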

I got a little stuck on how to import org.tensorflow. Instead of writing

  import 'org.tensorflow.*'

I ended up defining a namespace module TF and calling include_package inside it.
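As far as I understand JRuby's implementation, the TF namespace works because include_package installs a const_missing hook on the module. The first time TF::Graph or a similar name is referenced, the hook looks that name up in the registered Java packages. The toy below reproduces that mechanism in plain Ruby. KNOWN_CLASSES and add_package are hypothetical stand-ins, not JRuby's actual code, and no Java is involved:

```ruby
# Toy illustration of lazy constant resolution via const_missing.
# KNOWN_CLASSES is a hypothetical stand-in for the Java class lookup
# that JRuby performs; no Java is involved here.
module Toy
  KNOWN_CLASSES = {
    "org.tensorflow.Graph"   => Class.new,
    "org.tensorflow.Session" => Class.new
  }.freeze

  @packages = []

  # Hypothetical counterpart of include_package: remember a package name.
  def self.add_package(name)
    @packages << name
  end

  # Called by Ruby when Toy::Something is not defined yet.
  def self.const_missing(const)
    @packages.each do |pkg|
      klass = KNOWN_CLASSES["#{pkg}.#{const}"]
      # Cache the result so later lookups bypass const_missing.
      return const_set(const, klass) if klass
    end
    super # no match: raise the usual NameError
  end
end

Toy.add_package("org.tensorflow")
graph_class = Toy::Graph # resolved on first use, then cached as a constant
```

Because every class is resolved on demand inside the module, nothing is copied into the top-level namespace. That is the property that makes TF::Graph, TF::Session and TF::DataType read naturally in example.rb.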

Comparison

A library for running TensorFlow from Ruby, released last June, seems to be the go-to choice:

somaticio/tensorflow.rb

Compared with the sample code bundled with that library, example.rb is considerably more verbose.

If this code were developed further, the natural step would be to replace module TF with a smarter wrapper. Possible directions include:

(1) generating module TF through metaprogramming (http://qiita.com/tags/%E3%83%A1%E3%82%BF%E3%83%97%E3%83%AD%E3%82%B0%E3%83%A9%E3%83%9F%E3%83%B3%E3%82%B0), using symbol information from the Java side;

(2) providing a more abstract, Ruby-like API along the lines of Keras.
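As a rough idea of what such a wrapper could look like, here is a sketch that hides the opBuilder chain behind helper methods and uses define_method to generate one method per op type. TFOps, op, const_int32 and the hardcoded op list are hypothetical names introduced for this sketch. The only Java calls it makes are the ones already used in example.rb, and const_int32 needs JRuby plus the TF module defined above:

```ruby
# Hypothetical thin wrapper around the opBuilder chain from example.rb.
# Constants under TF are only looked up when a method is called, so the
# module itself loads without TensorFlow on the classpath.
module TFOps
  extend self

  # Build a single-output operation and return its first Output.
  def op(graph, type, name, inputs: [], attrs: {})
    builder = graph.opBuilder(type, name)
    inputs.each { |input| builder = builder.addInput(input) }
    attrs.each { |key, value| builder = builder.setAttr(key, value) }
    builder.build().output(0)
  end

  # INT32 constant, as used for "a" and "b" in example.rb (JRuby only).
  def const_int32(graph, name, values)
    op(graph, "Const", name,
       attrs: { "dtype" => TF::DataType::INT32,
                "value" => TF::Tensor.create(values.to_java(:int)) })
  end

  # Metaprogramming in the spirit of idea (1): generate one method per op
  # type. This list is hardcoded; a real wrapper would read it from the
  # Java side instead.
  %w[Add Sub Mul].each do |type|
    define_method(type.downcase) do |graph, name, *inputs|
      op(graph, type, name, inputs: inputs)
    end
  end
end
```

With this, the three definitions in example.rb shrink to TFOps.const_int32(graph, "a", [1, 2, 3]), TFOps.const_int32(graph, "b", [4, 5, 6]) and TFOps.mul(graph, "c", a, b). This is still far from a Keras-style API, but it shows where the verbosity can be absorbed.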

However, Keras is slated to be integrated into TensorFlow with Keras 2 (see "Spring 2017 roadmap: Keras 2, PR freeze, TF integration"), and the Java API specification still appears to be unstable. Unfortunately, both (1) and (2) seem premature for now.
