If you do machine learning in Java or Scala, you have probably wished for a good plotting tool like the one Python has in Matplotlib.
So I wrote Matplotlib4j, a library that lets you call Matplotlib from Java, and I would like to introduce it here.
The examples below are in Java, but the library also works from other JVM languages such as Scala and Kotlin. A Scala example appears later.
First, add Matplotlib4j to the Java project where you want to use Matplotlib.
For Maven, add the following dependency.
Maven
<dependency>
    <groupId>com.github.sh0nk</groupId>
    <artifactId>matplotlib4j</artifactId>
    <version>0.4.0</version>
</dependency>
For Gradle, add the following.
Gradle
implementation 'com.github.sh0nk:matplotlib4j:0.4.0'
The API mirrors Matplotlib's, so it should feel intuitive. First create a Plot object, add a graph by calling a pyplot-style method on it, and finally call show(). Because the API follows the builder pattern, you can chain options onto each call, with IDE completion to guide you.
As a starting point, let's draw a scatter plot.
ScatterPlot
List<Double> x = NumpyUtils.linspace(-3, 3, 100);
List<Double> y = x.stream().map(xi -> Math.sin(xi) + Math.random()).collect(Collectors.toList());
Plot plt = Plot.create();
plt.plot().add(x, y, "o").label("sin");
plt.legend().loc("upper right");
plt.title("scatter");
plt.show();
Some NumPy functions, such as linspace and meshgrid, are provided in the NumpyUtils class to help with plotting. The first block generates the x and y data: a sine curve with random noise added. The code then creates a Plot object, passes x and y to the plot() method, and finally calls show() to draw the graph.
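To make the role of linspace concrete, here is a minimal plain-Java sketch of what such a helper computes: evenly spaced values with both endpoints included, as numpy.linspace does by default. The class LinspaceSketch is hypothetical and written only for illustration. It is not Matplotlib4j's own code, and NumpyUtils.linspace may handle edge cases differently.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical stand-in for illustration; not part of Matplotlib4j.
class LinspaceSketch {
    // Returns `num` evenly spaced values from start to end, both inclusive.
    static List<Double> linspace(double start, double end, int num) {
        if (num < 1) {
            throw new IllegalArgumentException("num must be at least 1");
        }
        List<Double> values = new ArrayList<>(num);
        if (num == 1) {
            values.add(start);
            return values;
        }
        double step = (end - start) / (num - 1);
        for (int i = 0; i < num; i++) {
            values.add(start + step * i);
        }
        // Pin the last value so rounding error cannot move the endpoint.
        values.set(num - 1, end);
        return values;
    }

    public static void main(String[] args) {
        // Seven points covering [-3, 3], the same range as the scatter example.
        System.out.println(linspace(-3, 3, 7));
    }
}
```

Returning a List<Double> rather than a double[] matches the types used in the scatter example, where the x values are passed straight into a stream.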
This is roughly equivalent to the Python implementation below (only roughly, because the NumPy data generation differs in the details). The method names match, so the API should feel familiar to Python users.
PythonScatterPlot
import matplotlib.pyplot as plt
import numpy as np
x = np.linspace(-3, 3, 100)
y = np.sin(x) + np.random.rand(100)
plt.plot(x, y, "o", label="sin")
plt.legend(loc="upper right")
plt.title("scatter")
plt.show()
With the above Java code, you can draw the following graph.
Next, let's draw a contour plot.
ContourPlot
List<Double> x = NumpyUtils.linspace(-1, 1, 100);
List<Double> y = NumpyUtils.linspace(-1, 1, 100);
NumpyUtils.Grid<Double> grid = NumpyUtils.meshgrid(x, y);
List<List<Double>> zCalced = grid.calcZ((xi, yj) -> Math.sqrt(xi * xi + yj * yj));
Plot plt = Plot.create();
ContourBuilder contour = plt.contour().add(x, y, zCalced);
plt.clabel(contour)
    .inline(true)
    .fontsize(10);
plt.title("contour");
plt.show();
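The contour example evaluates a function at every (x, y) pair of a grid. As a rough plain-Java sketch of that step, the hypothetical GridSketch class below builds a nested list with one row per y value and one column per x value. That row and column orientation is an assumption borrowed from NumPy's default meshgrid indexing. I have not confirmed that Grid.calcZ uses the same layout, and for a symmetric function such as the distance from the origin the result looks the same either way.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.function.DoubleBinaryOperator;

// Hypothetical helper for illustration; not part of Matplotlib4j.
class GridSketch {
    // Evaluates f at every grid point; z.get(j).get(i) holds f(x_i, y_j).
    static List<List<Double>> calcZ(List<Double> x, List<Double> y, DoubleBinaryOperator f) {
        List<List<Double>> z = new ArrayList<>(y.size());
        for (double yj : y) {
            List<Double> row = new ArrayList<>(x.size());
            for (double xi : x) {
                row.add(f.applyAsDouble(xi, yj));
            }
            z.add(row);
        }
        return z;
    }

    public static void main(String[] args) {
        List<Double> x = Arrays.asList(0.0, 3.0);
        List<Double> y = Arrays.asList(0.0, 4.0);
        // Distance from the origin, the same function as in the contour example.
        System.out.println(calcZ(x, y, (xi, yj) -> Math.sqrt(xi * xi + yj * yj)));
    }
}
```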
You can draw a histogram in the same way.
HistogramPlot
Random rand = new Random();
List<Double> x1 = IntStream.range(0, 1000).mapToObj(i -> rand.nextGaussian())
    .collect(Collectors.toList());
List<Double> x2 = IntStream.range(0, 1000).mapToObj(i -> 4.0 + rand.nextGaussian())
    .collect(Collectors.toList());
Plot plt = Plot.create();
plt.hist()
    .add(x1).add(x2)
    .bins(20)
    .stacked(true)
    .color("#66DD66", "#6688FF");
plt.xlim(-6, 10);
plt.title("histogram");
plt.show();
Matplotlib4j can also save plots to a file. This is convenient for environments without a GUI, such as periodic machine learning jobs running on a server.
As in Matplotlib itself, calling .savefig() instead of .show() saves the image to a file rather than opening a plot window. The one difference is that .savefig() must be followed by plt.executeSilently(). (Because savefig is itself a builder that accepts chained options, a separate call is needed to finish the chain and run the plot.)
Random rand = new Random();
List<Double> x = IntStream.range(0, 1000).mapToObj(i -> rand.nextGaussian())
    .collect(Collectors.toList());
Plot plt = Plot.create();
plt.hist().add(x).orientation(HistBuilder.Orientation.horizontal);
plt.ylim(-5, 5);
plt.title("histogram");
plt.savefig("/tmp/histogram.png").dpi(200);
// Required to actually write the file
plt.executeSilently();
As a result, the following image will be output.
Matplotlib4j needs a Python installation with Matplotlib available. By default it uses the Python found on the PATH, but in many cases Matplotlib is not installed in the system default Python.
In that case, you can switch to a Python environment that has Matplotlib installed, such as Anaconda, via pyenv or pyenv-virtualenv.
To use a pyenv environment, pass a PythonConfig when creating the Plot object, as follows.
pyenv
Plot plot = Plot.create(PythonConfig.pyenvConfig("Any pyenv environment name"));
Similarly, you can specify the environment name for pyenv-virtualenv.
pyenv-virtualenv
Plot plot = Plot.create(PythonConfig.pyenvVirtualenvConfig("Any pyenv environment name", "Any virtualenv environment name"));
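Before choosing a configuration, it can help to check whether a given interpreter can import Matplotlib at all. The sketch below does this with the standard ProcessBuilder by running the interpreter with -c "import matplotlib" and checking the exit code. The class PythonCheckSketch and its method name are hypothetical and not part of Matplotlib4j, and the 30-second timeout is an arbitrary choice.

```java
import java.io.IOException;
import java.util.concurrent.TimeUnit;

// Hypothetical helper for illustration; not part of Matplotlib4j.
class PythonCheckSketch {
    // Returns true if `pythonBinary -c "import matplotlib"` exits with code 0.
    static boolean canImportMatplotlib(String pythonBinary) {
        try {
            Process process = new ProcessBuilder(pythonBinary, "-c", "import matplotlib")
                    .inheritIO() // let any Python traceback show up on the console
                    .start();
            if (!process.waitFor(30, TimeUnit.SECONDS)) {
                // The interpreter hung or is very slow; give up on it.
                process.destroyForcibly();
                return false;
            }
            return process.exitValue() == 0;
        } catch (IOException e) {
            // The binary does not exist or is not executable.
            return false;
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            return false;
        }
    }

    public static void main(String[] args) {
        System.out.println("matplotlib available: " + canImportMatplotlib("python"));
    }
}
```

If the check fails for the default python, that is a hint to pass one of the PythonConfig variants shown above instead.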
Scala
When using the library from Scala, the scatter plot example above can be written as follows. Watch out for boxing and unboxing of Double values and for the difference between the Scala and Java List classes.
ScalaScatter
import scala.collection.JavaConverters._
val x = NumpyUtils.linspace(-3, 3, 100).asScala.toList
val y = x.map(xi => Math.sin(xi) + Math.random()).map(Double.box)
val plt = Plot.create()
plt.plot().add(x.asJava, y.asJava, "o")
plt.title("scatter")
plt.show()
I recently started reading "Deep Learning from Scratch: The theory and implementation of deep learning learned with Python". Copying the Python code as-is would not have been much fun, so I decided to implement it in Scala, which I have been using a lot lately. I was able to write it in a functional style and was quite pleased with it, until I reached backpropagation with the steepest descent method: the loss would not go down at all, and I suspected a bug somewhere.
The proper approach would of course be to add more tests, but first I wanted to quickly plot graphs as in the book to see what was going on. Scala, however, has no good plotting tool, and writing one from scratch would be far too much work. So I decided to make Matplotlib, familiar from Python, usable from the JVM, and that is how this library came about.
Matplotlib4j calls Matplotlib by generating Python code, without using JNI or Jython. At first I planned to build it on Jython, but Jython only supports Python up to 2.7 and cannot use NumPy, so Matplotlib, which depends on NumPy, does not work either. I gave up on that route.
There is also a library that lets Java code use CPython, which was a candidate because both Python 3 and NumPy work with it. However, because it relies on JNI, a separate platform-dependent native library has to be installed, and on the Python side a package has to be installed with pip. That is too much setup just to draw a graph, so in the end I decided to implement Matplotlib4j without depending on either library.
Of course, because the code is executed via a file, passing variables in and getting return values out takes some care, and performance is a concern. Fortunately, the only goal is to draw graphs, so one-way output to a file covers the basic functionality, and I think a short wait is acceptable.
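The general idea of generating a Python script, writing it to a file, and running it with an external interpreter can be sketched in plain Java as follows. This is only an illustration of the approach under my own assumptions, not Matplotlib4j's actual implementation. The class ScriptGenerationSketch, its method names, and the use of the Agg backend are all hypothetical choices made for the example.

```java
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

// Hypothetical sketch of the approach; not Matplotlib4j's real code.
class ScriptGenerationSketch {
    // Builds the lines of a Python script that plots x against y and saves a PNG.
    // Assumes finite values (NaN or Infinity would not be valid Python literals)
    // and an output path without single quotes.
    static List<String> buildScript(List<Double> x, List<Double> y, String outputPath) {
        List<String> lines = new ArrayList<>();
        lines.add("import matplotlib");
        lines.add("matplotlib.use('Agg')"); // non-GUI backend, suitable for servers
        lines.add("import matplotlib.pyplot as plt");
        // List.toString() yields "[1.0, 2.0]", which is also a Python list literal.
        lines.add("plt.plot(" + x + ", " + y + ", 'o')");
        lines.add("plt.savefig(r'" + outputPath + "')");
        return lines;
    }

    // Writes the script to a temporary file, runs it, and returns the exit code.
    static int run(String pythonBinary, List<String> script)
            throws IOException, InterruptedException {
        Path file = Files.createTempFile("plot", ".py");
        try {
            Files.write(file, script, StandardCharsets.UTF_8);
            Process process = new ProcessBuilder(pythonBinary, file.toString())
                    .inheritIO()
                    .start();
            return process.waitFor();
        } finally {
            Files.deleteIfExists(file);
        }
    }

    public static void main(String[] args) {
        List<String> script = buildScript(
                Arrays.asList(1.0, 2.0), Arrays.asList(3.0, 4.0), "out.png");
        script.forEach(System.out::println);
    }
}
```

Data flows one way here, from Java into the script text, which is the trade-off described above: nothing comes back from Python except the exit code and the image file.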