Use Matplotlib from Java or Scala with Matplotlib4j

Anyone who has tried machine learning in Java or Scala has probably wished for a good plotting tool like Python's Matplotlib.

So I wrote Matplotlib4j, a library that lets you call Matplotlib from Java, and this post introduces it.

How to use

Add library

The examples below are in Java, but the library can also be used from other JVM languages such as Scala and Kotlin. A Scala example appears later in this post.

First, add Matplotlib4j to the Java project where you want to use Matplotlib.

For Maven, add the following dependency.

Maven


<dependency>
    <groupId>com.github.sh0nk</groupId>
    <artifactId>matplotlib4j</artifactId>
    <version>0.4.0</version>
</dependency>

Similarly, in the case of Gradle, it will be as follows.

Gradle


implementation 'com.github.sh0nk:matplotlib4j:0.4.0'

Drawing a graph

The API mirrors Matplotlib's, so it should feel intuitive. Create a Plot object, add a graph by calling a pyplot-style method on it, and finally call show(). The methods follow the Builder pattern, so you can chain options onto them with the help of IDE completion.

Scatter plot

As a starting point, let's draw a scatter plot.

ScatterPlot


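import java.util.List;
import java.util.stream.Collectors;

import com.github.sh0nk.matplotlib4j.NumpyUtils;
import com.github.sh0nk.matplotlib4j.Plot;

// 100 evenly spaced x values, and a sine curve with uniform noise in [0, 1) added
List<Double> x = NumpyUtils.linspace(-3, 3, 100);
List<Double> y = x.stream().map(xi -> Math.sin(xi) + Math.random()).collect(Collectors.toList());

Plot plt = Plot.create();
plt.plot().add(x, y, "o").label("sin");  // "o" draws circle markers instead of a line
plt.legend().loc("upper right");
plt.title("scatter");
plt.show();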

Some NumPy functions, such as linspace and meshgrid, are provided in the NumpyUtils class to help you prepare data. The first block generates the x and y data to plot: a sine curve with random noise added. The code then creates a Plot object, passes x and y to the builder returned by plot(), and finally calls show() to draw the graph.

This is roughly equivalent to the Python implementation below (only roughly, because the NumPy data generation differs slightly). The method names match, so it should feel familiar to Python users.

PythonScatterPlot


import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(-3, 3, 100)
y = np.sin(x) + np.random.rand(100)

plt.plot(x, y, "o", label="sin")
plt.legend(loc="upper right")
plt.title("scatter")
plt.show()

With the above Java code, you can draw the following graph.

scatter.png
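
For readers who have not used NumPy, the linspace helper in the scatter example produces evenly spaced values between two endpoints. The following is a minimal plain-Java sketch of that idea. It is not the NumpyUtils implementation; the class name LinspaceSketch and the choice to reject num < 2 are assumptions made for illustration.

```java
import java.util.ArrayList;
import java.util.List;

class LinspaceSketch {
    // Returns `num` evenly spaced values from start to end, both inclusive.
    // For num >= 2 this matches numpy.linspace with its default endpoint=True.
    static List<Double> linspace(double start, double end, int num) {
        if (num < 2) {
            // Assumption for this sketch: at least two points are required.
            throw new IllegalArgumentException("num must be at least 2");
        }
        List<Double> values = new ArrayList<>(num);
        double step = (end - start) / (num - 1);
        for (int i = 0; i < num; i++) {
            values.add(start + step * i);
        }
        return values;
    }

    public static void main(String[] args) {
        // Five points between -3 and 3: -3.0, -1.5, 0.0, 1.5, 3.0
        System.out.println(linspace(-3, 3, 5));
    }
}
```

Computing each value as start + step * i, rather than repeatedly adding step, keeps rounding error from accumulating across the list.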

Contour diagram

Next, let's draw a contour plot (contour lines).

ContourPlot


List<Double> x = NumpyUtils.linspace(-1, 1, 100);
List<Double> y = NumpyUtils.linspace(-1, 1, 100);
NumpyUtils.Grid<Double> grid = NumpyUtils.meshgrid(x, y);

List<List<Double>> zCalced = grid.calcZ((xi, yj) -> Math.sqrt(xi * xi + yj * yj));

Plot plt = Plot.create();
ContourBuilder contour = plt.contour().add(x, y, zCalced);
plt.clabel(contour)
    .inline(true)
    .fontsize(10);
plt.title("contour");
plt.show();

contour.png
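
A contour plot needs a z value for every combination of x and y, which is what the meshgrid and calcZ calls above prepare. As a rough plain-Java sketch of that step (not the library's code), the grid can be built with two nested loops. The class name GridSketch and the row/column layout are assumptions: rows follow y and columns follow x, as NumPy's meshgrid does with its default "xy" indexing.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.function.DoubleBinaryOperator;

class GridSketch {
    // Evaluates f at every (x, y) pair.
    // The result has one row per y value and one column per x value.
    static List<List<Double>> evaluate(List<Double> xs, List<Double> ys, DoubleBinaryOperator f) {
        List<List<Double>> z = new ArrayList<>(ys.size());
        for (double y : ys) {
            List<Double> row = new ArrayList<>(xs.size());
            for (double x : xs) {
                row.add(f.applyAsDouble(x, y));
            }
            z.add(row);
        }
        return z;
    }

    public static void main(String[] args) {
        List<Double> xs = Arrays.asList(-1.0, 0.0, 1.0);
        List<Double> ys = Arrays.asList(0.0, 1.0);
        // Same distance-from-origin function as the contour example.
        List<List<Double>> z = evaluate(xs, ys, (x, y) -> Math.sqrt(x * x + y * y));
        System.out.println(z.size() + " rows x " + z.get(0).size() + " columns");
    }
}
```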

Histogram

You can draw a histogram in the same way.

HistogramPlot


Random rand = new Random();
List<Double> x1 = IntStream.range(0, 1000).mapToObj(i -> rand.nextGaussian())
        .collect(Collectors.toList());
List<Double> x2 = IntStream.range(0, 1000).mapToObj(i -> 4.0 + rand.nextGaussian())
        .collect(Collectors.toList());

Plot plt = Plot.create();
plt.hist()
    .add(x1).add(x2)
    .bins(20)
    .stacked(true)
    .color("#66DD66", "#6688FF");
plt.xlim(-6, 10);
plt.title("histogram");
plt.show();

histogram.png

Save image to file

Matplotlib4j can also save plots to a file. This is convenient for environments without a GUI, such as periodic machine learning jobs running on a server.

As in Matplotlib itself, calling savefig() instead of show() saves the image to a file rather than opening a plot window. The only difference is that savefig() must be followed by plt.executeSilently(). (savefig() is itself a chainable builder, so a separate terminating call is needed to actually run the plot.)

Random rand = new Random();
List<Double> x = IntStream.range(0, 1000).mapToObj(i -> rand.nextGaussian())
        .collect(Collectors.toList());

Plot plt = Plot.create();
plt.hist().add(x).orientation(HistBuilder.Orientation.horizontal);
plt.ylim(-5, 5);
plt.title("histogram");
plt.savefig("/tmp/histogram.png").dpi(200);

// Required to actually write the file
plt.executeSilently();

As a result, the following image will be output.

histogram.png

Switching Python with pyenv, pyenv-virtualenv

Matplotlib4j requires a Python installation that has Matplotlib installed. By default it uses the Python found on the PATH, but the system default Python often does not have Matplotlib.

In that case, you can use pyenv or pyenv-virtualenv to switch to a Python environment that does have Matplotlib, such as Anaconda.

To use the Python from a pyenv environment, pass a PythonConfig when creating the Plot object, as follows.

pyenv


Plot plot = Plot.create(PythonConfig.pyenvConfig("Any pyenv environment name"));

Similarly, you can specify the environment name for pyenv-virtualenv.

pyenv-virtualenv


Plot plot = Plot.create(PythonConfig.pyenvVirtualenvConfig("Any pyenv environment name", "Any virtualenv environment name"));
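
Before deciding whether a pyenv configuration is needed, it can help to check whether a given interpreter can import Matplotlib at all. The following is a small sketch using only the Java standard library. The class and method names (MatplotlibCheck, checkCommand, canImportMatplotlib) are hypothetical and are not part of Matplotlib4j.

```java
import java.io.IOException;
import java.io.InputStream;
import java.util.Arrays;
import java.util.List;

class MatplotlibCheck {
    // Command that exits with status 0 only if the interpreter can import matplotlib.
    static List<String> checkCommand(String pythonBinary) {
        return Arrays.asList(pythonBinary, "-c", "import matplotlib");
    }

    // Returns true if the interpreter starts and the import succeeds.
    static boolean canImportMatplotlib(String pythonBinary) {
        try {
            Process process = new ProcessBuilder(checkCommand(pythonBinary))
                    .redirectErrorStream(true)  // merge stderr (tracebacks) into stdout
                    .start();
            // Drain the output so the child process never blocks on a full pipe.
            try (InputStream in = process.getInputStream()) {
                while (in.read() != -1) {
                    // discard
                }
            }
            return process.waitFor() == 0;
        } catch (IOException e) {
            return false;  // interpreter not found or not executable
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            return false;
        }
    }

    public static void main(String[] args) {
        // "python" is an assumption; use whatever name is on your PATH.
        System.out.println("matplotlib available: " + canImportMatplotlib("python"));
    }
}
```

If the check fails for the default interpreter, passing one of the PythonConfig variants shown above is the likely fix.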

Scala

From Scala, the scatter plot example above can be written as follows. Watch out for the boxing and unboxing of Double values and for the difference between the Scala and Java List classes.

ScalaScatter


import scala.collection.JavaConverters._

val x = NumpyUtils.linspace(-3, 3, 100).asScala.toList
val y = x.map(xi => Math.sin(xi) + Math.random()).map(Double.box)

val plt = Plot.create()
plt.plot().add(x.asJava, y.asJava, "o")
plt.title("scatter")
plt.show()

Bonus

Motivation

I recently started reading "Deep Learning from scratch: The theory and implementation of deep learning learned with Python". Simply copying the Python code is no fun, so I decided to implement it in Scala, which I have been using a lot lately. I was happy to be able to write it in a functional style, but when I reached backpropagation with steepest descent, the loss would not go down at all, and I suspected a bug somewhere.

The usual approach would be to add more tests, but first I wanted to quickly plot graphs like the ones in the book to see what was going on. Scala has no good plotting tool, though, and writing one from scratch would be far too much work. So I decided to use Matplotlib, which I already knew from Python, and that is how this library came about.

Design

Matplotlib4j calls Matplotlib by generating Python code, without using JNI or Jython. At first I considered Jython, but it only supports Python up to 2.7 and cannot use numpy, so Matplotlib, which depends on numpy, does not work either. I gave up on that route.

There is also a library that makes CPython usable from Java code, and it was a candidate because it supports both Python 3 and numpy. However, it relies on JNI, which means installing a separate platform-dependent library, plus a pip package on the Python side. That is too much setup just to draw a graph, so in the end I decided to implement Matplotlib4j without depending on these libraries.

Of course, because the Python code is executed via a file, passing variables in and getting return values back takes some care, and performance is a concern. Fortunately, the only goal is to draw graphs, so one-way output to a file covers the basic functionality, and I think a short wait is acceptable.
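
To make the "generate Python code and run it via a file" idea concrete, here is a simplified sketch of the general technique. It is not Matplotlib4j's actual implementation. The class and method names are hypothetical, and it assumes finite values and an output path that contains no quotes or backslashes.

```java
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.List;
import java.util.stream.Collectors;

class ScriptGenerationSketch {
    // Renders a Java list as a Python list literal, e.g. [1.0, 2.5].
    // Assumes finite values: NaN and Infinity would not be valid Python literals.
    static String toPythonList(List<Double> values) {
        return values.stream()
                .map(v -> Double.toString(v))
                .collect(Collectors.joining(", ", "[", "]"));
    }

    // Builds a complete Python script that plots x against y and saves an image.
    static String buildScript(List<Double> x, List<Double> y, String outputPath) {
        return String.join("\n",
                "import matplotlib",
                "matplotlib.use('Agg')  # headless backend, no GUI needed",
                "import matplotlib.pyplot as plt",
                "plt.plot(" + toPythonList(x) + ", " + toPythonList(y) + ", 'o')",
                "plt.savefig('" + outputPath + "')",
                "");
    }

    // Writes the script to a temporary file, runs it, and returns the exit code.
    static int run(String pythonBinary, String script) throws IOException, InterruptedException {
        Path file = Files.createTempFile("plot", ".py");
        try {
            Files.write(file, script.getBytes(StandardCharsets.UTF_8));
            Process process = new ProcessBuilder(pythonBinary, file.toString())
                    .inheritIO()  // show Python's output and errors in the Java console
                    .start();
            return process.waitFor();
        } finally {
            Files.deleteIfExists(file);
        }
    }
}
```

Data flows in only one direction here, from Java into the script text, which is why this approach suits plotting but would be awkward for calls that need to return values to Java.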
