Differentiation and generalization of the sort

0. Introduction

This article describes how to "differentiate" and generalize the sort. By considering the generalization of sorting, we can also compute the "derivative" of the quantile function.

We will also experiment with a machine learning task called least quantile regression, using the differentiation and generalization of the sort.

This is an introduction to the following paper, recently released by Google and attracting attention:

Differentiable Ranks and Sorting using Optimal Transport

Since it uses a technique similar to the article Differentiate the transport problem, some of the explanations are shared with that article.

Differentiation of the sort?

When we sort, we typically work with two functions: the sort function $S(x)$ and the rank function $R(x)$.

x=(x_1,\dots ,x_n) \in R^n \\
S(x)=(x_{\sigma _1},\dots ,x_{\sigma _n}) \ \ \ \ \ \ x_{\sigma _1}\leq x_{\sigma_2}\leq \dots \leq x_{\sigma _n} \\
R(x)=(rank(x_1),\dots ,rank(x_n)) \ \ (=\sigma^{-1})

For example, $S((3.4, 2.3, -1)) = (-1, 2.3, 3.4)$ and $R((3.4, 2.3, -1)) = (3, 2, 1)$.
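As a quick sanity check, here is a minimal NumPy sketch (the variable names are my own) that computes $S(x)$ and $R(x)$ for this example:

import numpy as np

x = np.array([3.4, 2.3, -1.0])
sigma = np.argsort(x)          # permutation that sorts x: (2, 1, 0)
S = x[sigma]                   # S(x) = (-1.0, 2.3, 3.4)
R = np.argsort(sigma) + 1      # inverse permutation, 1-indexed: R(x) = (3, 2, 1)
print(S, R)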

In this article, we consider (something like) the "derivatives" of the sort:

\frac{\partial S(x)}{\partial x_i}, \ \ \ \frac{\partial R(x)}{\partial x_j}

Applications of sort differentiation

Why would it be nice to be able to "differentiate" the sort?

The biggest advantage is that, as mentioned in the paper above, you can train end-to-end models that sort the output of a neural network and then do something with the result.

For example:

- When multiple images with numbers on them are input, return the ids of the images in descending order of the numbers.
- Assign a recommendation score to each product and train the recommendation model by directly differentiating a ranking loss.
- Use differentiable beam search decoders for NLP tasks.

Problems like these become tractable. Moreover, if the sort can be differentiated, the quantile function can also be differentiated, as described later, so a task such as "minimize the n% point of the regression error" can be solved by gradient descent, differentiating the objective function directly.

Notation

The following notation is used below.

- $1_n \in R^n$: the vector whose entries are all 1.
- $\Sigma_n = \{ a \in R^n_+ \mid \sum_i a_i = 1 \}$: the probability simplex.
- $O^n = \{ y \in R^n \mid y_1 < y_2 < \dots < y_n \}$: the set of strictly increasing vectors.
- $U(a,b) = \{ P \in R^{n \times m}_+ \mid P 1_m = a, \ P^T 1_n = b \}$: the set of transport plans with marginals $a$ and $b$.
- $\langle P, C \rangle = \sum_{i,j} P_{i,j} C_{i,j}$: the element-wise inner product of matrices.

1. Sorting, transportation problems and differentiation

By the way, $S(x)$ and $R(x)$ are not differentiable as they are. Worse, if $x$ is sampled at random from $R^n$, then $rank(x_1)$ does not change when $x_1$ is increased or decreased slightly. In other words, even if "something like a derivative" could be defined, it would always be 0 (the small check below illustrates this).
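A minimal NumPy sketch (not from the paper; values are illustrative) showing that a small perturbation of $x_1$ leaves the ranks unchanged:

import numpy as np

x = np.array([3.4, 2.3, -1.0])
for eps in (0.0, 1e-3, -1e-3):
    xp = x.copy()
    xp[0] += eps
    print(eps, np.argsort(np.argsort(xp)) + 1)   # rank stays (3, 2, 1) every time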

To solve these problems and calculate the derivative of the sort, take the following steps:

  1. Regard sorting as a kind of transportation problem.
  2. Add a regularization term to this transportation problem (so that the optimal solution is unique).
  3. For the regularized transportation problem, there is an approximation algorithm (the Sinkhorn algorithm) that computes the optimal solution in a differentiable form.
  4. Solve 2 with the Sinkhorn algorithm and differentiate the output.

The explanations are given below in order.

Sorting and the transportation problem

Transportation problems

As the name suggests, the transportation problem asks for the optimal way to transport goods from multiple factories to multiple stores. Shipping between each factory and store incurs a cost, and the amounts transported are chosen to minimize the total shipping cost.

Written as a formula: given $a \in R^n_+$, $b \in R^m_+$, and $C \in R^{n \times m}_+$, find $P \in U(a,b)$ that minimizes $\langle P, C \rangle$. In this article we write

L_C(a,b)=\min_{P\in U(a,b)}\langle P,C\rangle

Here $a$ is the supply of each factory, $b$ is the demand of each store, $C_{i,j}$ is the transportation cost per unit amount between factory $i$ and store $j$, $P_{i,j}$ is the amount transported between factory $i$ and store $j$, and $\langle P, C \rangle$ is the total transportation cost. The total cost under the optimal plan is $L_C(a,b)$.
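A small instance can be solved exactly as a linear program. The following is a minimal sketch (assuming scipy is available; the supplies, demands, and costs are made up for illustration):

import numpy as np
from scipy.optimize import linprog

a = np.array([0.5, 0.5])                 # factory supplies
b = np.array([0.3, 0.7])                 # store demands
C = np.array([[1.0, 2.0],
              [3.0, 0.5]])               # C[i, j]: unit cost factory i -> store j
n, m = C.shape

# Equality constraints on the flattened P: row sums equal a, column sums equal b
A_eq = np.zeros((n + m, n * m))
for i in range(n):
    A_eq[i, i * m:(i + 1) * m] = 1.0     # sum_j P[i, j] = a[i]
for j in range(m):
    A_eq[n + j, j::m] = 1.0              # sum_i P[i, j] = b[j]
b_eq = np.concatenate([a, b])

res = linprog(C.ravel(), A_eq=A_eq, b_eq=b_eq, bounds=(0, None))
P_opt = res.x.reshape(n, m)
print(P_opt, res.fun)                    # optimal plan P and L_C(a, b)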

A special transportation problem and sorting

Now consider the following simple transportation problem.

- The numbers of factories and stores are both $n$.
- The factories and stores are lined up on a straight line.
- Let $x, y \in R^n$ be the coordinates of the factories and stores, respectively.
- The store locations satisfy $y_1 < y_2 < \dots < y_n$ (i.e. $y \in O^n$).
- The shipping cost grows with the distance between factory and store: $C_{i,j} = h(y_j - x_i)$ for a differentiable, non-negative, strictly convex function $h$.
- All factory supplies and store demands are equal to $1/n$, i.e. $a = b = 1_n/n$.

At this time, the following proposition that connects sorting and transportation problems holds.

Proposition 1.

Under the above assumptions, let $P_*$ be one of the optimal solutions of the transportation problem $L_C(a,b)=\min_{P\in U(a,b)}\langle P,C\rangle$. Then the following hold.

R(x)=n^2 P_* \hat{b} \\
S(x)=n P_*^T x \\

where $\hat{b} = (b_1,\ b_1+b_2,\ \dots,\ \sum_i b_i)^T = (1/n,\ 2/n,\ \dots,\ 1)^T$.


In fact, consider the following transportation problem:

| Factory id | Factory coordinate | Supply | Store id | Store coordinate | Demand |
|:--:|:--:|:--:|:--:|:--:|:--:|
| 1 | 2 | 1/3 | a | 0 | 1/3 |
| 2 | 1 | 1/3 | b | 1 | 1/3 |
| 3 | 0 | 1/3 | c | 2 | 1/3 |

Transportation cost (= squared distance):

| Factory \ Store | a | b | c |
|:--:|:--:|:--:|:--:|
| 1 | 4 | 1 | 0 |
| 2 | 1 | 0 | 1 |
| 3 | 0 | 1 | 4 |

The optimal transport plan is:

| Factory \ Store | a | b | c |
|:--:|:--:|:--:|:--:|
| 1 | 0 | 0 | 1/3 |
| 2 | 0 | 1/3 | 0 |
| 3 | 1/3 | 0 | 0 |

Let's substitute these into the formulas of the proposition.

3^2
\left(
  \begin{array}{ccc}
    0 & 0 & 1/3 \\
    0 & 1/3 & 0 \\
    1/3 & 0 & 0
  \end{array}
\right)
\left(
  \begin{array}{c}
    1/3 \\ 2/3 \\ 1
  \end{array}
\right)
=
\left(
  \begin{array}{c}
    3 \\ 2 \\ 1
  \end{array}
\right)
= R\left(
  \begin{array}{c}
    2 \\ 1 \\ 0
  \end{array}
\right)

3
\left(
  \begin{array}{ccc}
    0 & 0 & 1/3 \\
    0 & 1/3 & 0 \\
    1/3 & 0 & 0
  \end{array}
\right)
\left(
  \begin{array}{c}
    2 \\ 1 \\ 0
  \end{array}
\right)
=
\left(
  \begin{array}{c}
    0 \\ 1 \\ 2
  \end{array}
\right)
= S\left(
  \begin{array}{c}
    2 \\ 1 \\ 0
  \end{array}
\right)

You can see that the formulas of the proposition hold.
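The same check in a few lines of NumPy (a quick sketch reusing the numbers above):

import numpy as np

x = np.array([2.0, 1.0, 0.0])            # factory coordinates
P = np.array([[0, 0, 1/3],
              [0, 1/3, 0],
              [1/3, 0, 0]])              # optimal transport plan from the table
b_hat = np.array([1/3, 2/3, 1.0])
print(3**2 * P @ b_hat)                  # R(x) = (3, 2, 1)
print(3 * P.T @ x)                       # S(x) = (0, 1, 2)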

Transport problems and differentiation

In the previous section, we confirmed that $S(x)$ and $R(x)$ can be written down using the solution of a special transportation problem. Therefore, if the derivative of that solution ($P_*$ in Proposition 1) with respect to $C_{i,j}$ can be computed, the derivatives of $S(x)$ and $R(x)$ with respect to $x_i$ can be computed as well. $P_*$ itself is not differentiable, but there is a way to find an approximation of $P_*$ in a differentiable form.

That is, we first consider the following transportation problem with a "regularization term": using the entropy of the transport plan

H(P)=-\sum_{i,j}P_{i,j}(\log(P_{i,j})-1)

we consider, instead of the original problem,

L_C^{\epsilon}(a,b)=\min_{P\in U(a,b)}\langle P,C\rangle - \epsilon H(P) ★

As $\epsilon \to 0$, the solution of this regularized transportation problem converges to a solution of the original transportation problem.

In addition, the following Sinkhorn algorithm can be used to find an approximate solution in a differentiable manner.

Sinkhorn algorithm

init u = u^0, v = v^0, l = 0; compute K = exp(-C / ε)

while l < MAX_ITER:
    u = a / (Kv)
    v = b / (K^T u)
    l++

P = diag(u) K diag(v)

return P


As MAX_ITER $\to \infty$, the output of the Sinkhorn algorithm converges to the optimal solution of ★. I explained the Sinkhorn algorithm itself in detail in the article

Differentiate the transport problem

mentioned at the beginning.

Implementation in PyTorch

Let's implement the Sinkhorn algorithm in PyTorch and use it to compute the derivative of the sort.

import torch
from torch import nn

# Sinkhorn algorithm
class OTLayer(nn.Module):
    def __init__(self, epsilon):
        super(OTLayer,self).__init__()
        self.epsilon = epsilon

    def forward(self, C, a, b, L):
        K = torch.exp(-C/self.epsilon)        # Gibbs kernel
        u = torch.ones_like(a)
        v = torch.ones_like(b)
        l = 0
        while l < L:
            u = a / torch.mv(K,v)             # row scaling
            v = b / torch.mv(torch.t(K),u)    # column scaling
            l += 1

        # P = diag(u) K diag(v)
        return u, v, u.view(-1,1)*(K * v.view(1,-1))

# sort & rank
class SortLayer(nn.Module):
    def __init__(self, epsilon):
        super(SortLayer,self).__init__()
        self.ot = OTLayer(epsilon)
        
    def forward(self, x, L):
        l = x.shape[0]
        # Evenly spaced, sorted target points y (treated as constants, hence detach)
        y = (x.min() + (torch.arange(l, dtype=torch.float) * x.max() / l)).detach()
        C = ( y.repeat((l,1)) - torch.t(x.repeat((l,1))) ) **2   # C_ij = (y_j - x_i)^2
        a = torch.ones_like(x) / l
        b = torch.ones_like(y) / l
        _, _, P = self.ot(C, a, b, L)

        b_hat = torch.cumsum(b, dim=0)        # (1/n, 2/n, ..., 1)

        # Proposition 1: R(x) = n^2 P b_hat, S(x) = n P^T x
        return l**2 * torch.mv(P, b_hat), l * torch.mv(torch.t(P), x)
sl = SortLayer(0.1)
x = torch.tensor([2., 8., 1.], requires_grad=True)
r, s = sl(x, 10)
print(r,s)
tensor([2.0500, 3.0000, 0.9500], grad_fn=<MulBackward0>) tensor([1.0500, 2.0000, 8.0000], grad_fn=<MulBackward0>)

(Computing the gradient)

r[0].backward()
print(x.grad)
tensor([ 6.5792e-06,  0.0000e+00, -1.1853e-20])

2. Generalization of the sort and the quantile function

Generalization of sorting

In the previous chapter, we saw that the sort and rank functions $S(x)$ and $R(x)$ can be written down using the solution of a special transportation problem, where

- the numbers of factories and stores are equal, and
- all factory supplies and store demands are the same.

In general, however, a transportation problem makes sense for any numbers of factories and stores and for arbitrary supplies and demands. By considering the analogues of $S(x)$ and $R(x)$ for the solutions of these general transportation problems, we should be able to generalize sorting and ranking.

Based on this idea, Differentiable Ranks and Sorting using Optimal Transport introduces the K-sort and K-rank (K for Kantorovich) as generalizations of the sort and rank functions.

Definition 1. K-sort and K-rank

For any $x \in R^n$, $y \in O^m$, $a \in \Sigma_n$, $b \in \Sigma_m$ and a strictly convex function $h$, consider the transportation problem

L_C(a,b)=\min_{P\in U(a,b)}\langle P,C\rangle, \ \ \ C_{i,j}=h(y_j - x_i)

and let $P_*$ be one of its optimal solutions. The K-sort and K-rank of $x$ are then defined as follows.

\hat{S}(a,x,b,y)=diag(b^{-1})P^T_* x \\
\hat{R}(a,x,b,y)=n \, diag(a^{-1})P_* \hat{b}

Here $a^{-1}$ and $b^{-1}$ denote the element-wise reciprocals of the vectors.


Since $P_* \in U(a,b)$, it satisfies

P_* 1_m = a \\
P^T_* 1_n = b

Therefore, the factors $diag(b^{-1})P^T_*$ and $diag(a^{-1})P_*$ in $\hat S$ and $\hat R$ can be viewed as normalizing $P^T_*$ and $P_*$ so that every row sums to 1. Also, when $a = b = 1_n/n$, the original $S$ and $R$ are recovered, so this is a natural extension of the sort and rank functions.

Also note that when $b = 1_n/n$, the entries of $n\hat b$ are exactly the "ranks" $(1, 2, \dots, n)$ in order. For general $b$, then, $\hat b$ can be regarded as a generalization of "rank" that takes real values. With this in mind, you can also check that

- $\hat S_i$ is a weighted average of the $x_j$ assigned to the $i$-th generalized rank, and
- $\hat R_j$ is a weighted average of the generalized ranks ($= n\hat b$) to which $x_j$ is assigned.
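To make the definition concrete, here is a minimal sketch computing a K-sort and K-rank with the OTLayer from the previous chapter (the values of x and y are made up; a and b match the schematic below):

import torch

ot = OTLayer(0.1)
x = torch.tensor([0.5, 1.2, -0.3, 0.8, 0.1])
y = torch.tensor([0.0, 0.5, 1.0])              # y in O^3, m = 3 target points
a = torch.ones(5) / 5                          # uniform supply, n = 5
b = torch.tensor([0.48, 0.16, 0.36])           # non-uniform demand
C = (y.repeat((5, 1)) - x.view(-1, 1)) ** 2    # C_ij = h(y_j - x_i) with h(u) = u^2
_, _, P = ot(C, a, b, 100)
b_hat = torch.cumsum(b, dim=0)
k_sort = torch.mv(torch.t(P), x) / b           # diag(b^-1) P^T x
k_rank = 5 * torch.mv(P, b_hat) / a            # n diag(a^-1) P b_hat
print(k_sort, k_rank)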

The schematic diagram below shows the difference between the ordinary sort/rank and the K-sort/K-rank. The K-sort and K-rank are drawn for the case $a = 1_5/5$, $b = (0.48, 0.16, 0.36)$.

(Figure: K-sort and K-rank)

The quantile function

The quantile function is the function $q$ given by

q(x, \tau) = (\text{the } \tau\% \text{ point of } x)

In fact, the K-sort $\hat S$ can be used to compute the quantile function efficiently. For example, let $x \in R^n$ be a set of data points, and consider the K-sort $\hat S(a,x,b,y)$ with $a = 1_n/n$, $y = (0, 1/2, 1)$ and $b = (0.29, 0.02, 0.69)$.

(Figure: how the data points are transported onto $y = (0, 1/2, 1)$)

As the figure suggests, the lower ~30% of $x$ is matched to $y_1$, the upper ~70% to $y_3$, and only the points near the 30% point of $x$ are matched to $y_2$.

Therefore, using the K-sort and a suitably small $t$, we can expect the quantile function to be computed as

q(x,\tau ;t)=\hat S\left(1_n/n,\ x,\ (\tau/100 - t/2,\ t,\ 1-\tau/100 - t/2),\ (0, 1/2, 1)\right)[2]

where $[2]$ denotes the second coordinate. As explained in the previous chapter, the Sinkhorn algorithm gives a differentiable approximation of $\hat S$, so the derivative of the quantile function can be computed as well.

Also note that this K-sort can be computed in $O(nl)$ time, where $l$ is the number of Sinkhorn iterations.

Implementation in PyTorch

Let's implement the K-sort in PyTorch and use it to compute the quantile function. We reuse the OTLayer from the previous chapter.

class QuantileLayer(nn.Module):
    def __init__(self, epsilon):
        super(QuantileLayer,self).__init__()
        self.ot = OTLayer(epsilon)
        self.y = torch.Tensor([0.,0.5,1.])
        
    def forward(self, x, tau, t, L):
        # tau is given as a fraction (e.g. 0.3 for the 30% point),
        # t is the mass of the middle target point
        l = x.shape[0]

        C = ( self.y.repeat((l,1)) - torch.t(x.repeat((3,1))) ) **2
        a = torch.ones_like(x) / l
        b = torch.Tensor([tau-t/2, t, 1-tau-t/2])

        _, _, P = self.ot(C, a, b, L)

        # middle coordinate of the K-sort diag(b^-1) P^T x
        return (torch.mv(torch.t(P), x) / b)[1]

(Computing the 30% point)

import numpy as np

np.random.seed(47)

x = np.random.rand(1000)

quantile = QuantileLayer(0.1)
print(quantile(torch.tensor(x,dtype=torch.float), 0.3, 0.1, 10))
tensor(0.3338)
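As a sanity check, this can be compared with the exact empirical quantile (the small gap comes from the entropic smoothing $\epsilon$ and the window $t$):

print(np.quantile(x, 0.3))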

3. Applications of sort differentiation and generalization

As a simple machine learning application of the differentiation and generalization of the sort, we solve a least quantile regression task by differentiating the objective function directly via the quantile function and running gradient descent.

least quantile regression
Whereas ordinary linear regression optimizes the model to minimize the "mean" of the prediction errors, least quantile regression trains the model to minimize the "n% point" of the errors. This includes minimizing the median, i.e. the "50% point" of the errors. It is useful, for example, when

- the noise in the data is not Gaussian, and the target values deviate strongly in one particular direction,

and you want to ignore the outliers by minimizing the "median" of the error.
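Written as a formula, with $f_\theta$ the model and $q$ the quantile function from the previous chapter ($\tau = 50$ for the median), the objective is

\min_\theta \ q\left((\ell_1(\theta),\dots ,\ell_n(\theta)),\ \tau \right), \ \ \ \ \ell_i(\theta)=(y_i-f_\theta (x_i))^2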

In this article, we experiment with least quantile regression on annual income data, the kind often used to illustrate the difference between "mean" and "median". We build a model that predicts annual income from age, using data on the average age and average annual income of listed companies. The model is trained by gradient descent, differentiating the "median of the error" directly with the method described in this article.

I used the data published by yutakikuchi on github.

Experiment

Companies with an average age of 45 or younger are extracted from the data and used (to make the difference from ordinary linear regression easier to observe).

The distribution of the data is as follows.

(Figure: distribution of the data)

Age and annual income appear roughly proportional, but the noise is clearly not normally distributed: the spread is larger toward higher incomes. We can expect an ordinary linear regression model to be dragged by the outliers and to overestimate the slope.

The code that trains a linear model by directly differentiating the "median" (= 50% point) of the error with the method developed in this article is shown below. (The data preprocessing part is omitted.)

(Layer that executes the Sinkhorn algorithm)

import torch
from torch import nn
import torch.optim as optim


class OTLayer(nn.Module):
    def __init__(self, epsilon):
        super(OTLayer,self).__init__()
        self.epsilon = epsilon

    def forward(self, 
                C, # batch_size, n, m 
                a, # batch_size, n
                b, # batch_size, m
                L):
        bs = C.shape[0]
        K = torch.exp(-C/self.epsilon) # batch_size, n, m
        u = torch.ones_like(a).view(bs,-1,1) # batch_size, n, 1
        v = torch.ones_like(b).view(bs,-1,1) # batch_size, m, 1
        l = 0
        while l < L:
            u = a.view(bs,-1,1) / (torch.bmm(K,v) + 1e-8) # batch_size, n, 1
            v = b.view(bs,-1,1) / (torch.bmm(torch.transpose(K, 1, 2),u) + 1e-8) # batch_size, m, 1
            l += 1
                
        return u * K * v.view(bs,1,-1) # batch_size, n, m

(Layer that computes the quantile function)

class QuantileLayer(nn.Module):
    def __init__(self, epsilon, y):
        super(QuantileLayer,self).__init__()
        self.ot = OTLayer(epsilon)
        self.y = y.detach()
        
    def forward(self, x, # batch_size, seq_len
                tau, t, L):
        bs = x.shape[0]
        seq_len = x.shape[1]
        C = ( self.y.repeat((bs,seq_len,1)) - x.unsqueeze(-1).expand(bs,seq_len,3) ) **2 # batch_size, seq_len, 3
        
        a = torch.ones_like(x) / seq_len  # batch_size, seq_len
        b = torch.Tensor([tau-t/2, t, 1-tau-t/2]).expand([bs, 3]) # batch_size, 3
        
        P = self.ot(C, a, b, L) # batch_size, seq_len, 3
        
        k_sort =  torch.bmm(
            torch.transpose(P,1,2), # batch_size, 3, seq_len 
            x.unsqueeze(-1) # batch_size, seq_len, 1
        ).view(-1) / b.view(-1) # 3, 

        return k_sort[1]

(Preparing the data: age = average age, income = average annual income)

import pandas as pd

data = pd.DataFrame({"age":ages, "income":incomes})
data_2 = data[data.age <= 45]
ppd_data = (data_2- data_2.mean(axis=0))/data_2.std(axis=0)

X = torch.Tensor(ppd_data.age.values.reshape(-1,1))
ans =  torch.Tensor(ppd_data.income.values.reshape(-1,1))

(Running the training)


model = nn.Linear(1,1)
loss = nn.MSELoss(reduction='none')

y = [0, ppd_data.income.max()/4., ppd_data.income.max()/2.]
quantile = QuantileLayer(0.1, torch.Tensor(y))

optimizer = optim.Adam(model.parameters(), lr=0.1)

MAX_ITER = 100
for i in range(MAX_ITER):
    optimizer.zero_grad()
    
    pred = model(X)
    loss_value = loss(pred, ans).view(1,-1) # 1, seq_len( = data size)
    
    #Calculate median error
    quantile_loss = quantile(loss_value, 0.5, 0.1, 10)
    print(quantile_loss)
    quantile_loss.backward()
    optimizer.step()


Below is a comparison of the fit of the trained model and an ordinary regression model that minimizes the "mean" of the error. You can see that the model that minimizes the median of the error is less affected by the outlier data.

(Figure: comparison of the model minimizing the median error and the model minimizing the mean error)

4. Summary

In this article, I explained how to differentiate the sort and, as a generalization of the sort, how to compute the quantile function in a differentiable form. As a simple application, we solved a least quantile regression task with the gradient method and observed the difference between the model minimizing the "mean" and the model minimizing the "median" of the error.

As mentioned at the beginning, the differentiation of sorting is expected to find an even wider range of applications, such as differentiable beam search, and we may see it in various papers in the future.
