This time, I will introduce quantum-inspired machine learning, which has attracted some attention in recent years. To be precise, "quantum-inspired" here means "inspired by the methods used to simulate quantum systems as efficiently as possible on classical computers". Specifically, the method in question is the tensor network.
Using tensor networks to simulate quantum systems on classical computers has a long history: computing the states of many-body quantum systems efficiently with matrix product states [1], simulating gate-based quantum computation efficiently by combining undirected graphs with tensor network contractions [2], and so on.
As these methods show, a tensor network lets classical computation handle the very high-dimensional state space of a quantum system, albeit approximately. In machine learning, such high dimensionality translates into high expressive power.
"Quantum-inspired machine learning" is to apply this feature not only to the simulation of quantum systems but also to classical machine learning problems.
A matrix product state (MPS) is represented as shown in the figure below [1].
Each black circle is called a "site"; an N-qubit system has N sites.
You can picture each site as holding one matrix for physical index $\sigma_i = 0$ and another for physical index $\sigma_i = 1$. For example, if $\sigma_i = 0$ for all $i$, multiplying together the matrices for physical index 0 at every site gives the coefficient of $|00\ldots 0\rangle$ in the original quantum state.
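To make this concrete, here is a minimal sketch in plain NumPy of how an MPS stores coefficients (random tensors rather than a real physical state; the shapes and names are just illustrative):

```python
import numpy as np

N = 4          # number of sites (qubits)
bond_dim = 3   # bond dimension shared by neighbouring sites

# Each site holds two matrices, one per physical index value.
# Boundary sites are (1 x D) and (D x 1) so the full product is a scalar.
A = [np.random.randn(2, 1, bond_dim)]                       # site 0: (physical, left, right)
A += [np.random.randn(2, bond_dim, bond_dim) for _ in range(N - 2)]
A += [np.random.randn(2, bond_dim, 1)]                      # site N-1

def coefficient(bits):
    """Coefficient of the basis state |bits[0] bits[1] ... bits[N-1]>."""
    M = A[0][bits[0]]
    for i in range(1, N):
        M = M @ A[i][bits[i]]          # multiply the matrix selected by each physical index
    return M[0, 0]

print(coefficient([0, 0, 0, 0]))       # coefficient of |0000>
```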
Here I introduce the approach based on two representative papers, [3] and [4], which tackle classification problems with supervised learning. The overall flow is shown in the figure.
First, encode the input data $x$ into the physical indices $\sigma_i\ (i \in \{0, \ldots, n-1\})$ of the matrix product state diagram.
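As a concrete illustration, the local feature map used in [3] sends each pixel value, normalized to $[0, 1]$, to a two-dimensional vector. The sketch below is a minimal version of such an encoding (the function name and toy input are my own):

```python
import numpy as np

def encode(pixels):
    """Map each normalized pixel x in [0, 1] to a 2-dim local vector,
    following the cos/sin feature map of [3]."""
    pixels = np.asarray(pixels, dtype=float)
    return np.stack([np.cos(np.pi * pixels / 2),
                     np.sin(np.pi * pixels / 2)], axis=-1)   # shape (n_pixels, 2)

x = np.array([0.0, 0.5, 1.0, 0.25])   # a toy 4-pixel "image"
print(encode(x))
```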
Next, take the tensor contraction between the encoded data and the MPS. If the edges between the MPS sites are also contracted as-is, only a single scalar value remains, which cannot be used for classification. Therefore, the MPS carries a "label index" in advance so that it can output values corresponding to the probability that $x$ belongs to each class. The label index is attached either to one of the existing sites or to a newly added site. With this, contracting the $\sigma_i$ and all edges of the MPS leaves a tensor with as many elements as there are classes, and that value can be fed into the loss function.
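A rough sketch of this forward pass (plain NumPy, random tensors, with the label index placed on a middle site purely for illustration) might look like this:

```python
import numpy as np

N, D, L = 4, 3, 10   # sites, bond dimension, number of classes

# MPS tensors: (physical, left, right); one site carries an extra label index.
A = [np.random.randn(2, 1, D)]
A += [np.random.randn(2, D, D) for _ in range(N - 2)]
A += [np.random.randn(2, D, 1)]
label_site = N // 2
A[label_site] = np.random.randn(2, D, D, L)   # (physical, left, right, label)

def forward(phi):
    """phi: encoded input of shape (N, 2). Returns a length-L vector of class scores."""
    left = np.ones((1,))                       # running left boundary vector
    scores = None
    for i in range(N):
        Ai = np.tensordot(phi[i], A[i], axes=(0, 0))          # contract the physical index
        if i == label_site:
            scores = np.einsum('l,lrL->rL', left, Ai)         # keep the label index open
        elif scores is None:
            left = left @ Ai                                   # before the label site
        else:
            scores = np.einsum('rL,rq->qL', scores, Ai)        # after the label site
    return scores[0]                                           # shape (L,)

phi = np.random.rand(N, 2)
print(forward(phi))
```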
During training, the elements of the MPS are updated so that the loss becomes small. There are two main update strategies. One is the method adopted in [3], an application of the established DMRG algorithm: sweep back and forth along the chain, repeatedly performing a local optimization with only two adjacent sites as variables. The other, adopted in [4], uses error backpropagation to update all elements of the MPS at once.
The former has the advantage that unnecessary bond dimensions can be trimmed dynamically with an SVD at each update. The latter, on the other hand, fits naturally into existing deep learning and automatic differentiation frameworks, and probably offers more freedom in defining the network structure and loss function.
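The trimming step of the former method can be sketched as follows: after optimizing a merged two-site tensor, an SVD splits it back into two site tensors and small singular values are discarded, which caps the bond dimension (the shapes below are arbitrary, not taken from [3]):

```python
import numpy as np

D_left, D_right, D_max = 8, 8, 4   # current bonds and the maximum bond to keep

# Merged two-site tensor: (left bond, phys_1, phys_2, right bond)
B = np.random.randn(D_left, 2, 2, D_right)

# Group (left, phys_1) vs (phys_2, right) and take the SVD of the resulting matrix.
M = B.reshape(D_left * 2, 2 * D_right)
U, S, Vh = np.linalg.svd(M, full_matrices=False)

k = min(D_max, len(S))                                   # keep at most D_max singular values
A1 = U[:, :k].reshape(D_left, 2, k)                      # new tensor for the left site
A2 = (np.diag(S[:k]) @ Vh[:k]).reshape(k, 2, D_right)    # new tensor for the right site

print(A1.shape, A2.shape)   # the shared bond is now k <= D_max
```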
This time, I implemented MNIST training with the backpropagation approach of [4]. For the implementation, I used a Python module called TensorNetwork developed by the authors.
TensorNetwork is, as the name says, a library for computing tensor networks. You can choose "tensorflow" or "jax" as the backend. With the "tensorflow" backend you can train in combination with the TensorFlow framework. It is convenient to have TensorFlow's automatic differentiation and built-in functions available, but on the other hand most of the model ends up as custom layers, so writing everything to fit the framework is tedious, and the overhead of the framework itself is also a concern.
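For reference, the basic usage pattern of the library (as I understand it; a toy contraction of two vectors with the jax backend selected) looks roughly like this:

```python
import numpy as np
import tensornetwork as tn

tn.set_default_backend("jax")        # backend can also be "tensorflow", "numpy", ...

# Tensors become nodes; connecting edges with ^ defines the contraction to perform.
a = tn.Node(np.ones((10,)))
b = tn.Node(np.ones((10,)))
edge = a[0] ^ b[0]                   # connect the single leg of each vector
result = tn.contract(edge)           # contract the connected edge
print(result.tensor)                 # inner product: 10.0
```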
So this time I use the jax backend. In fact, the follow-up study [5] to [4] also seems to use the jax backend. jax is another Python framework, roughly a numpy that supports automatic differentiation, JIT compilation, and vectorized/parallel computation. It can be a good option if you simply want fast automatic differentiation like TensorFlow's without the rest of the framework (I think Julia's Flux occupies a similar position, and there is certainly demand for this).
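To show why that is convenient, here is a toy training step written directly in jax.numpy rather than with the TensorNetwork library (shapes, names, and the placement of the label index on the last site are all my own simplifications, not the setup of the linked notebook): the whole MPS forward pass is plain array code, so `jax.grad` and `jax.jit` give compiled gradient updates over every MPS tensor at once.

```python
import jax
import jax.numpy as jnp

def forward(params, phi):
    """params: MPS tensors of shape (physical, left, right); the last carries a label index.
    phi: encoded input of shape (N, 2). Returns a vector of class scores."""
    left = jnp.ones((1,))
    for i, A in enumerate(params[:-1]):
        left = left @ jnp.tensordot(phi[i], A, axes=(0, 0))      # contract physical index, absorb bond
    return left @ jnp.tensordot(phi[-1], params[-1], axes=(0, 0))  # shape (num_classes,)

def loss(params, phi, label):
    return -jax.nn.log_softmax(forward(params, phi))[label]      # cross-entropy for one sample

@jax.jit
def train_step(params, phi, label, lr=1e-2):
    grads = jax.grad(loss)(params, phi, label)                   # gradients for every MPS tensor
    return [p - lr * g for p, g in zip(params, grads)]

# Toy setup: 4 "pixels", bond dimension 3, 10 classes.
keys = jax.random.split(jax.random.PRNGKey(0), 4)
params = [jax.random.normal(keys[0], (2, 1, 3)),
          jax.random.normal(keys[1], (2, 3, 3)),
          jax.random.normal(keys[2], (2, 3, 3)),
          jax.random.normal(keys[3], (2, 3, 10))]   # last site: (physical, left, label)
phi = jnp.array([[1.0, 0.0]] * 4)                   # an encoded toy input
params = train_step(params, phi, 3)                 # one gradient step toward class 3
```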
My implementation is slightly different from [4] in the following points.
Regarding point 1 (reducing the image size), training with the original image size was difficult. One matrix is multiplied per pixel, and as the number of pixels grows the output tends to diverge or collapse to 0 and the gradients tend to vanish, which is a practical obstacle. This could probably be handled with careful tuning, but this time I compromised and downsampled. The authors also apply pooling in [5] (perhaps because the network structure and task are somewhat different).
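For concreteness, the kind of pixel reduction mentioned above can be as simple as average-pooling each image before encoding (the 2x2 pooling below is an illustrative choice, not necessarily the exact preprocessing in the notebook):

```python
import numpy as np

def avg_pool_2x2(img):
    """Downsample an image (e.g. 28x28 MNIST) by averaging 2x2 blocks."""
    h, w = img.shape
    return img.reshape(h // 2, 2, w // 2, 2).mean(axis=(1, 3))

img = np.random.rand(28, 28)
print(avg_pool_2x2(img).shape)   # (14, 14)
```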
I put the implementation code below. https://github.com/ryuNagai/MPS/blob/master/TN_ML/MNIST_ML_jax.ipynb
The learning process is like this.
In the end, train accuracy = 0.962 and test accuracy = 0.952. In [4], the train accuracy reached about 0.98 at around 50 epochs, so I did not quite reproduce it. Behind the scenes I spent some time trying to reproduce the result of [4], but it was difficult, so I am content with this value.
I implemented quantum-inspired machine learning with a tensor network, a (potentially) up-and-coming approach. While quantum computers still face many hardware restrictions, this method runs on classical computers and can therefore tackle problems of substantial size. Whether it can yield models more useful than conventional machine learning models is, I think, still a matter for future research.
In addition, such methods may let us see, approximately or indirectly, whether machine learning that exploits a quantum state space has an advantage over classical machine learning, which I think would be valuable in itself.
[1] https://arxiv.org/abs/1008.3477
[2] https://arxiv.org/abs/1805.01450
[3] https://papers.nips.cc/paper/6211-supervised-learning-with-tensor-networks
[4] https://arxiv.org/abs/1906.06329
[5] https://arxiv.org/abs/2006.02516