This article explains how to "differentiate" a "sort" and how to generalize sorting. Using this generalization, we can also compute the "derivative" of the quantile function.
In addition, we run an actual experiment on a machine learning task called least quantile regression using the differentiable, generalized sort.
This is an introduction to the content of the following recent paper from Google that has been attracting attention:
Differentiable Ranks and Sorting using Optimal Transport
Since it uses a technique similar to the one in my earlier article on differentiating the transportation problem, part of the explanation overlaps with that article.
When we sort, we typically use two functions: the sort function $S(x)$ and the rank function $R(x)$:
x=(x_1,\dots ,x_n) \in R^n \\
S(x)=(x_{\sigma _1},\dots ,x_{\sigma _n}) \ \ \ \ \ \ x_{\sigma _1}\leq x_{\sigma_2}\leq \dots \leq x_{\sigma _n} \\
R(x)=(rank(x_1),\dots ,rank(x_n)) \ \ (=\sigma^{-1})
For example, $S((3.4, 2.3, -1)) = (-1, 2.3, 3.4)$ and $R((3.4, 2.3, -1)) = (3, 2, 1)$.
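For reference, here is a minimal, non-differentiable implementation of these two functions using `torch.sort` and plain indexing, just to pin down the convention used above (the function name is my own):

import torch

def hard_sort_and_rank(x):
    # S(x): values in ascending order; sigma is the permutation that sorts x
    s, sigma = torch.sort(x)
    # R(x): rank of each element (1 = smallest), i.e. the inverse permutation
    r = torch.empty_like(sigma)
    r[sigma] = torch.arange(1, x.numel() + 1)
    return s, r

s, r = hard_sort_and_rank(torch.tensor([3.4, 2.3, -1.0]))
print(s)  # tensor([-1.0000,  2.3000,  3.4000])
print(r)  # tensor([3, 2, 1])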
In this article, we consider what (something like) the "derivatives" $\partial S(x)/\partial x$ and $\partial R(x)/\partial x$ of the sort and rank functions could be.
Why would it be nice to be able to "differentiate" a sort?
The biggest advantage is that, as mentioned in the paper above, you can train end-to-end on tasks that sort the output of a neural network and then do something with the result.
For example:

- When multiple images containing numbers are input, return the ids of the images in descending order of the numbers.
- Assign a recommendation score to each product and train the recommendation model by directly differentiating a ranking loss.
- Use differentiable beam search decoders for NLP tasks.
Problems like these become tractable. Moreover, if the sort can be differentiated, the quantile function can also be differentiated as described later, so a task such as "minimize the n% point of the regression error" can be solved by gradient descent, directly differentiating the objective function.
Note, however, that $S(x)$ and $R(x)$ are not differentiable as they are. Furthermore, if $x$ is sampled randomly from $R^n$, then $rank(x_1)$ does not change when $x_1$ is increased or decreased slightly; in other words, even if "something like a derivative" of $R$ could be defined, it would be 0 almost everywhere.
To work around these issues and compute a derivative of the sort, we take the following steps:

- Express $S(x)$ and $R(x)$ using the solution of a special transportation problem.
- Replace that transportation problem with an entropy-regularized version.
- Solve the regularized problem approximately with the (differentiable) Sinkhorn algorithm.

These steps are explained in order below.
As the name implies, the transportation problem is the problem of determining the optimal way to transport goods from multiple factories to multiple stores. Shipping between each factory and each store incurs a cost, and the transport amounts are chosen so as to minimize the total shipping cost.
Written as a formula: given $a \in R^n_+$, $b \in R^m_+$, and $C \in R^{n \times m}_+$, find $P \in U(a,b)$ that minimizes $\langle P, C \rangle$. In this article we write

U(a,b) = \{ P \in R^{n \times m}_+ \mid P 1_m = a,\ P^T 1_n = b \} \\
L_C(a,b) = \min_{P \in U(a,b)} \langle P, C \rangle

Here $a$ is the supply of each factory, $b$ is the demand of each store, $C_{i,j}$ is the transportation cost per unit amount between factory $i$ and store $j$, $P_{i,j}$ is the amount transported between factory $i$ and store $j$, and $\langle P, C \rangle = \sum_{i,j} P_{i,j} C_{i,j}$ is the total transportation cost. The total cost under the optimal transport plan is $L_C(a,b)$.
Now consider the following simple transportation problem.
- The numbers of factories and stores are the same, $n$ each.
- The factories and stores are lined up on a straight line.
- Let $x, y \in R^n$ be the coordinates of the factories and the stores, respectively.
- The store locations satisfy $y_1 < y_2 < \dots < y_n$.
- The supplies and demands are all equal: $a = b = 1_n/n$.
- The transportation cost is the squared distance: $C_{i,j} = (x_i - y_j)^2$.
At this time, the following proposition that connects sorting and transportation problems holds.
(Proposition 1) In the situation above, let $P_*$ be an optimal solution of the transportation problem $L_C(a,b) = \min_{P \in U(a,b)} \langle P, C \rangle$. Then the following holds:
R(x)=n^2 P_* \hat{b} \\
S(x)=n P_*^T x \\
where $\hat{b} = (b_1, b_1 + b_2, \dots, \sum_i b_i)^T = (1/n, 2/n, \dots, 1)^T$.
In fact, consider the following transportation problem:
| Factory id | Factory coordinate | Supply | Store id | Store coordinate | Demand |
|---|---|---|---|---|---|
| 1 | 2 | 1/3 | a | 0 | 1/3 |
| 2 | 1 | 1/3 | b | 1 | 1/3 |
| 3 | 0 | 1/3 | c | 2 | 1/3 |
Transportation cost (= squared distance):

| Factory \ Store | a | b | c |
|---|---|---|---|
| 1 | 4 | 1 | 0 |
| 2 | 1 | 0 | 1 |
| 3 | 0 | 1 | 4 |
The optimal transport plan is:

| Factory \ Store | a | b | c |
|---|---|---|---|
| 1 | 0 | 0 | 1/3 |
| 2 | 0 | 1/3 | 0 |
| 3 | 1/3 | 0 | 0 |
Let's substitute these into the formula of the proposition.
3^2
\begin{pmatrix} 0 & 0 & 1/3 \\ 0 & 1/3 & 0 \\ 1/3 & 0 & 0 \end{pmatrix}
\begin{pmatrix} 1/3 \\ 2/3 \\ 1 \end{pmatrix}
=
\begin{pmatrix} 3 \\ 2 \\ 1 \end{pmatrix}
= R\left( \begin{pmatrix} 2 \\ 1 \\ 0 \end{pmatrix} \right)
\\
3
\begin{pmatrix} 0 & 0 & 1/3 \\ 0 & 1/3 & 0 \\ 1/3 & 0 & 0 \end{pmatrix}
\begin{pmatrix} 2 \\ 1 \\ 0 \end{pmatrix}
=
\begin{pmatrix} 0 \\ 1 \\ 2 \end{pmatrix}
= S\left( \begin{pmatrix} 2 \\ 1 \\ 0 \end{pmatrix} \right)
You can see that the formula of the proposition holds.
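As a sanity check, the following small script (my own brute-force verification, not from the paper) enumerates all $3! = 6$ permutation plans for the example above, picks the cheapest one, and evaluates the formulas of Proposition 1. With uniform marginals $a = b = 1_n/n$, an optimal plan can always be taken to be a permutation matrix scaled by $1/n$, so this enumeration is sufficient:

import itertools
import numpy as np

x = np.array([2., 1., 0.])          # factory coordinates
y = np.array([0., 1., 2.])          # store coordinates (already sorted)
n = len(x)
C = (x[:, None] - y[None, :]) ** 2  # cost = squared distance

best_P, best_cost = None, np.inf
for perm in itertools.permutations(range(n)):
    P = np.zeros((n, n))
    P[np.arange(n), list(perm)] = 1.0 / n   # candidate transport plan
    cost = (P * C).sum()
    if cost < best_cost:
        best_P, best_cost = P, cost

b_hat = np.cumsum(np.ones(n) / n)           # (1/n, 2/n, ..., 1)
print(n**2 * best_P @ b_hat)                # [3. 2. 1.] = R(x)
print(n * best_P.T @ x)                     # [0. 1. 2.] = S(x)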
In the previous chapter, we confirmed that $S(x)$ and $R(x)$ can be written using the solution of a special transportation problem. Therefore, if the derivative of the solution of the transportation problem ($P_*$ of Proposition 1) with respect to $C_{i,j}$ can be computed, then the derivatives of $S(x)$ and $R(x)$ with respect to $x_i$ can also be computed. $P_*$ itself is not differentiable, but there is a way to compute an approximation of $P_*$ in a differentiable manner.
Namely, we first consider a transportation problem with a "regularization term". That is, using the entropy of the transport plan

H(P) = -\sum_{i,j} P_{i,j} (\log P_{i,j} - 1)

we consider, instead of the original problem,

L^{\epsilon}_C(a,b) = \min_{P \in U(a,b)} \langle P, C \rangle - \epsilon H(P) \ \ \ (★)

As $\epsilon \to 0$, the solution of this regularized transportation problem converges to the solution of the original transportation problem.
Moreover, the following Sinkhorn algorithm can be used to find an approximate solution in a differentiable manner.
- init: $K = e^{-C/\epsilon}$, $u = 1_n$, $v = 1_m$, $l = 0$
- while $l <$ MAX_ITER: $u \leftarrow a / (Kv)$, $v \leftarrow b / (K^T u)$, $l \leftarrow l + 1$
- return $P = \mathrm{diag}(u)\, K\, \mathrm{diag}(v)$
As MAX_ITER $\to \infty$, the output of the Sinkhorn algorithm converges to the optimal solution of (★). The Sinkhorn algorithm itself is explained in more detail in my earlier article on differentiating the transportation problem.
Let's implement the Sinkhorn algorithm in PyTorch and use it to compute the derivative of the sort.
import torch
from torch import nn

# Sinkhorn algorithm
class OTLayer(nn.Module):
    def __init__(self, epsilon):
        super(OTLayer, self).__init__()
        self.epsilon = epsilon

    def forward(self, C, a, b, L):
        K = torch.exp(-C / self.epsilon)
        u = torch.ones_like(a)
        v = torch.ones_like(b)
        l = 0
        while l < L:
            u = a / torch.mv(K, v)
            v = b / torch.mv(torch.t(K), u)
            l += 1
        # approximate optimal plan P = diag(u) K diag(v)
        return u, v, u.view(-1, 1) * (K * v.view(1, -1))
# sort & rank
class SortLayer(nn.Module):
def __init__(self, epsilon):
super(SortLayer,self).__init__()
self.ot = OTLayer(epsilon)
def forward(self, x, L):
l = x.shape[0]
y = (x.min() + (torch.arange(l, dtype=torch.float) * x.max() / l)).detach()
C = ( y.repeat((l,1)) - torch.t(x.repeat((l,1))) ) **2
a = torch.ones_like(x) / l
b = torch.ones_like(y) / l
_, _, P = self.ot(C, a, b, L)
b_hat = torch.cumsum(b, dim=0)
return l**2 * torch.mv(P, b_hat), l * torch.mv(torch.t(P), x)
sl = SortLayer(0.1)
x = torch.tensor([2., 8., 1.], requires_grad=True)
r, s = sl(x, 10)
print(r,s)
tensor([2.0500, 3.0000, 0.9500], grad_fn=<MulBackward0>) tensor([1.0500, 2.0000, 8.0000], grad_fn=<MulBackward0>)
(Computing the derivative)
r[0].backward()
print(x.grad)
tensor([ 6.5792e-06, 0.0000e+00, -1.1853e-20])
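As a rough check of the convergence claim (smaller $\epsilon$ and more iterations bring the soft outputs closer to the exact sort and rank), we can compare a few settings against `torch.sort`; the particular values of $\epsilon$ and $L$ below are arbitrary, and for very small $\epsilon$ this plain Sinkhorn iteration becomes numerically unstable in float32 (log-domain variants are used in practice):

x = torch.tensor([2., 8., 1.])
print(torch.sort(x).values)               # exact sort: [1., 2., 8.]
for eps, L in [(1.0, 10), (0.3, 50), (0.1, 100)]:
    r, s = SortLayer(eps)(x, L)
    print(eps, L, r.detach(), s.detach())
# expectation: r approaches [2, 3, 1] and s approaches [1, 2, 8]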
In the previous chapter, we saw that the sort and rank functions $S(x)$ and $R(x)$ can be written down using the solution of a special transportation problem. That special transportation problem was one in which

- the number of factories and the number of stores are the same, and
- the supplies of the factories and the demands of the stores are all equal.

In general, however, a transportation problem can have different numbers of factories and stores, and arbitrary supplies and demands. By considering the analogues of $S(x)$ and $R(x)$ built from the solutions of such general transportation problems, we should be able to generalize sorting and ranking.
Based on this idea, Differentiable Ranks and Sorting using Optimal Transport introduces generalized sort and rank functions, the K(Kantorovich)-sort and K-rank, defined as follows.
For any $x \in R^n$, $y \in O^m$ (an increasing sequence of $m$ values), $a \in \Sigma_n$, $b \in \Sigma_m$ and a strictly convex function $h$, consider the transportation problem

\min_{P \in U(a,b)} \sum_{i,j} P_{i,j}\, h(y_j - x_i)

and let $P_*$ be one of its optimal solutions. The K-sort and K-rank of $x$ are then defined as follows.
\hat{S} (a,x,b,y)=diag(b^{-1})P^T _* x \\
\hat{R} (a,x,b,y)=n* diag(a^{-1})P_* \hat{b}
Here $a^{-1}$ and $b^{-1}$ denote the elementwise reciprocals of the vectors.
Since $P_* \in U(a,b)$, it satisfies

P_* 1_m = a \\
P^T_* 1_n = b

Therefore, the $\mathrm{diag}(b^{-1}) P^T_*$ and $\mathrm{diag}(a^{-1}) P_*$ parts of $\hat S$ and $\hat R$ can be viewed as normalizing $P^T_*$ and $P_*$ so that every row sums to 1. Moreover, when $a = b = 1_n/n$, the original $S$ and $R$ are recovered, so this is a natural extension of the sort and rank functions.
Also note that when $b = 1_n/n$, $n \hat b$ equals $(1, 2, \dots, n)$, i.e. exactly the sequence of ranks. So for a general $b$, $n \hat b$ can be regarded as generalized, real-valued "ranks". Taking this into account, one can also check that

- $\hat S_i$ is a linear combination of the $x_j$ assigned to the $i$-th rank, and
- $\hat R_j$ is a linear combination of the generalized ranks ($n \hat b$) to which $x_j$ is assigned.
The schematic diagram below illustrates the difference between the ordinary sort/rank and the K-sort/K-rank. The K-sort and K-rank are shown for the case $a = 1_5/5$, $b = (0.48, 0.16, 0.36)$.
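To make the definition concrete, here is a small sketch computing the K-sort and K-rank with the Sinkhorn-based OTLayer defined earlier. The weights $a = 1_5/5$, $b = (0.48, 0.16, 0.36)$ match the schematic, but the values of $x$ and $y$ below are made up purely for illustration:

x = torch.tensor([0.7, 0.1, 0.5, 0.9, 0.3])   # 5 input values (made up)
y = torch.tensor([0.2, 0.5, 0.8])             # 3 increasing target values (made up)
a = torch.ones(5) / 5
b = torch.tensor([0.48, 0.16, 0.36])

C = (x.view(-1, 1) - y.view(1, -1)) ** 2      # cost h(y_j - x_i) with h = squared distance
_, _, P = OTLayer(0.01)(C, a, b, 100)         # approximate optimal plan

n = x.shape[0]
b_hat = torch.cumsum(b, dim=0)
k_sort = torch.mv(torch.t(P), x) / b          # \hat S: 3 values
k_rank = n * torch.mv(P, b_hat) / a           # \hat R: 5 generalized ranks
print(k_sort, k_rank)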
The quantile function $q(\tau)$ is the function that, for a level $\tau \in (0, 1)$, returns the value below which a fraction $\tau$ of the data points lies. In fact, the K-sort $\hat S$ can be used to compute the quantile function efficiently. For example, let $x \in R^n$ be a set of data points, and consider the K-sort $\hat S(a, x, b, y)$ with $a = 1_n/n$, $y = (0, 1/2, 1)$ and $b = (0.29, 0.02, 0.69)$.
As the figure above suggests, roughly the lower 30% of $x$ is associated with $y_1$, roughly the upper 70% with $y_3$, and only the points near the 30% quantile of $x$ with $y_2$.
Therefore, using the K-sort with a reasonably small value $t$, we can expect

q(\tau) \approx \hat{S}\left(1_n/n,\ x,\ (\tau - t/2,\ t,\ 1 - \tau - t/2),\ (0, 1/2, 1)\right)_2

to hold. As explained in the previous chapter, a differentiable approximation of $\hat S$ can be obtained with the Sinkhorn algorithm, so the derivative of the quantile function can also be computed.
Also note that the K-sort above can be computed in $O(nl)$ time, where $l$ is the number of Sinkhorn iterations.
Let's implement the K-sort in PyTorch and compute the quantile function with it. We reuse the OTLayer from the previous chapter.
class QuantileLayer(nn.Module):
    def __init__(self, epsilon):
        super(QuantileLayer, self).__init__()
        self.ot = OTLayer(epsilon)
        self.y = torch.Tensor([0., 0.5, 1.])   # three target points

    def forward(self, x, tau, t, L):
        l = x.shape[0]
        C = (self.y.repeat((l, 1)) - torch.t(x.repeat((3, 1)))) ** 2
        a = torch.ones_like(x) / l
        # masses: tau - t/2 below, t around, 1 - tau - t/2 above the quantile
        b = torch.Tensor([tau - t / 2, t, 1 - tau - t / 2])
        _, _, P = self.ot(C, a, b, L)
        # K-sort = diag(b^{-1}) P^T x; its middle element approximates q(tau)
        return (torch.mv(torch.t(P), x) / b)[1]
(Computing the 30% quantile)
import numpy as np
np.random.seed(47)
x = np.random.rand(1000)
quantile = QuantileLayer(0.1)
print(quantile(torch.tensor(x,dtype=torch.float), 0.3, 0.1, 10))
tensor(0.3338)
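For comparison, we can check against the exact empirical quantile; since the soft value depends on $\epsilon$, $t$ and the number of Sinkhorn iterations, only rough agreement is expected:

# exact empirical 30% quantile of the same data, for comparison
print(np.quantile(x, 0.3))
# the same differentiable layer at another level, e.g. the median
print(quantile(torch.tensor(x, dtype=torch.float), 0.5, 0.1, 10))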
As a simple machine learning application of differentiating and generalizing the sort, we solve a least quantile regression task by directly differentiating an objective function built from the quantile function and running gradient descent.
Whereas ordinary linear regression fits the model by minimizing the "mean" of the prediction error, least quantile regression trains the model to minimize the "n% point" of the error. This includes, as a special case, minimizing the median, i.e. the 50% point of the error. For example, when

- the noise in the data is not Gaussian and the target values deviate strongly in one particular direction,

this is useful if you want to minimize the "median" of the error while effectively ignoring the outlier data.
In this article, we experiment with least quantile regression on annual income data, which is often used to illustrate the difference between "mean" and "median". Using data on the average employee age and average annual income of listed companies, we build a model that predicts annual income from age. The model is trained by gradient descent, directly differentiating the "median of the error" using the method described in this article.
I used the data published by yutakikuchi on github.
Only companies with an average age of 45 or younger are extracted and used (to make the difference from ordinary linear regression easier to observe).
The distribution of the data is as follows.
Age and annual income appear to be roughly proportional, but the noise is clearly not normally distributed: the spread is larger toward higher incomes. We can therefore expect an ordinary linear regression model to be pulled by the outliers and overestimate the slope.
The code that trains a linear model by directly differentiating the "median" (= 50% point) of the error with the method developed in this article is as follows (the data formatting part is omitted).
(Layer that executes the Sinkhorn algorithm)
import torch
from torch import nn
import torch.optim as optim

class OTLayer(nn.Module):
    def __init__(self, epsilon):
        super(OTLayer, self).__init__()
        self.epsilon = epsilon

    def forward(self,
                C,  # batch_size, n, m
                a,  # batch_size, n
                b,  # batch_size, m
                L):
        bs = C.shape[0]
        K = torch.exp(-C / self.epsilon)          # batch_size, n, m
        u = torch.ones_like(a).view(bs, -1, 1)    # batch_size, n, 1
        v = torch.ones_like(b).view(bs, -1, 1)    # batch_size, m, 1
        l = 0
        while l < L:
            u = a.view(bs, -1, 1) / (torch.bmm(K, v) + 1e-8)                          # batch_size, n, 1
            v = b.view(bs, -1, 1) / (torch.bmm(torch.transpose(K, 1, 2), u) + 1e-8)   # batch_size, m, 1
            l += 1
        return u * K * v.view(bs, 1, -1)           # batch_size, n, m
(layer that calculates the quantile function)
class QuantileLayer(nn.Module):
    def __init__(self, epsilon, y):
        super(QuantileLayer, self).__init__()
        self.ot = OTLayer(epsilon)
        self.y = y.detach()   # three target points

    def forward(self, x,  # batch_size, seq_len
                tau, t, L):
        bs = x.shape[0]
        seq_len = x.shape[1]
        C = (self.y.repeat((bs, seq_len, 1)) - x.unsqueeze(-1).expand(bs, seq_len, 3)) ** 2  # batch_size, seq_len, 3
        a = torch.ones_like(x) / seq_len                                     # batch_size, seq_len
        b = torch.Tensor([tau - t / 2, t, 1 - tau - t / 2]).expand([bs, 3])  # batch_size, 3
        P = self.ot(C, a, b, L)                                              # batch_size, seq_len, 3
        k_sort = torch.bmm(
            torch.transpose(P, 1, 2),   # batch_size, 3, seq_len
            x.unsqueeze(-1)             # batch_size, seq_len, 1
        ).view(-1) / b.view(-1)         # 3,  (assumes batch_size = 1)
        return k_sort[1]                # middle element ~ q(tau)
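Before training, here is a quick sanity check of the batched layer on uniform random data. This is my own test, not part of the original code; the three target points y are chosen arbitrarily so that they roughly cover the value range, in the same spirit as the training code below:

torch.manual_seed(0)
x_test = torch.rand(1, 1000)                                 # batch_size = 1, 1000 values in [0, 1)
q_layer = QuantileLayer(0.1, torch.Tensor([0., 0.25, 0.5]))  # hypothetical target points
print(q_layer(x_test, 0.5, 0.1, 10))   # soft median, should be roughly 0.5
print(torch.median(x_test))            # exact median, for comparison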
(Data preparation: age = average age, income = average annual income)
import pandas as pd
data = pd.DataFrame({"age":ages, "income":incomes})
data_2 = data[data.age <= 45]
ppd_data = (data_2- data_2.mean(axis=0))/data_2.std(axis=0)
X = torch.Tensor(ppd_data.age.values.reshape(-1,1))
ans = torch.Tensor(ppd_data.income.values.reshape(-1,1))
(Learning execution)
model = nn.Linear(1, 1)
loss = nn.MSELoss(reduction='none')   # keep per-sample squared errors
y = [0, ppd_data.income.max() / 4., ppd_data.income.max() / 2.]
quantile = QuantileLayer(0.1, torch.Tensor(y))
optimizer = optim.Adam(model.parameters(), lr=0.1)

MAX_ITER = 100
for i in range(MAX_ITER):
    optimizer.zero_grad()
    pred = model(X)
    loss_value = loss(pred, ans).view(1, -1)   # 1, seq_len (= data size)
    # compute the median of the squared errors differentiably
    quantile_loss = quantile(loss_value, 0.5, 0.1, 10)
    print(quantile_loss)
    quantile_loss.backward()
    optimizer.step()
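For reference, the ordinary regression baseline used in the comparison below can be trained with the same loop but with the usual mean squared error; this is a sketch of what I did, and the exact settings may differ:

baseline = nn.Linear(1, 1)
mse = nn.MSELoss()                                   # reduction='mean' -> ordinary least squares fit
baseline_opt = optim.Adam(baseline.parameters(), lr=0.1)
for i in range(MAX_ITER):
    baseline_opt.zero_grad()
    mse_loss = mse(baseline(X), ans)                 # mean of squared errors
    mse_loss.backward()
    baseline_opt.step()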
Here is a comparison of the fitted model with an ordinary regression model that minimizes the "mean" of the error. The model that minimizes the median (= quantile) is clearly less affected by the outlier data.
In this article, I explained how to differentiate a sort and, as a generalization of sorting, how to compute the quantile function in a differentiable form. As a simple application, we solved a least quantile regression task with gradient descent and observed the difference between a model that minimizes the "mean" of the error and one that minimizes the "median".
As mentioned at the beginning, differentiable sorting is expected to have a broader range of applications, such as differentiable beam search, and will likely appear in more papers in the future.