Examples and solutions that cannot be optimized well with scipy.optimize.least_squares

With scipy, you can use `optimize.least_squares` to fit the parameters of a nonlinear function to your data. However, depending on the form of the nonlinear function, it may fail to find the optimal parameters. This is because `optimize.least_squares` is a local method: it converges to a local minimum near the initial value, which is not necessarily the global one.

This time, I will show an example where `optimize.least_squares` gets stuck in a local optimum, and then find the global optimum using `optimize.basinhopping`.

version:

An example that cannot be optimized well

Consider the following function with $ a $ as a parameter.

$$ y(x)=\frac{1}{100}(x-3a)(2x-a)(3x+a)(x+2a) $$
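As a quick sanity check (our own addition, not part of the original post): each of the four factors vanishes at one point, so the curve crosses zero at $ x = 3a, a/2, -a/3, -2a $. For $ a = 2 $ that is $ x = 6, 1, -2/3, -4 $. A minimal Python check:

```python
import numpy as np

def y(x, a):
    # Same model as in the article
    return (x - 3.*a) * (2.*x - a) * (3.*x + a) * (x + 2.*a) / 100.

a = 2.0
# Each factor vanishes at one of these x values
roots = np.array([3.*a, a / 2., -a / 3., -2.*a])
values = y(roots, a)
print(roots)
print(np.allclose(values, 0.0))  # the model is (numerically) zero at every root
```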

Suppose we observe noisy data generated with $ a = 2 $.


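import numpy as np
import matplotlib.pyplot as plt
plt.style.use('ggplot')

# Fix the random seed so the noisy data are reproducible
seed = 0
np.random.seed(seed)

def y(x, a):
    # Model: y(x) = (x - 3a)(2x - a)(3x + a)(x + 2a) / 100
    return (x-3.*a) * (2.*x-a) * (3.*x+a) * (x+2.*a) / 100.

a_orig = 2.                     # true parameter value
xs = np.linspace(-5, 7, 1000)   # dense grid for plotting the curve
ys = y(xs, a_orig)

# 30 noisy samples: x uniform on [-5, 5], Gaussian noise with std 0.5
num_data = 30
data_x = np.random.uniform(-5, 5, num_data)
data_y = y(data_x, a_orig) + np.random.normal(0, 0.5, num_data)

plt.plot(xs, ys, label='true a = %.2f'%(a_orig))
plt.plot(data_x, data_y, 'o', label='data')
plt.legend()
plt.show()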

qiita_1.png

Now let's try to estimate the parameter with `optimize.least_squares`.

from scipy.optimize import least_squares

def calc_residuals(params, data_x, data_y):
    # Residuals between the model prediction and the observed data
    model_y = y(data_x, params[0])
    return model_y - data_y

a_init = -3.
res = least_squares(calc_residuals, np.array([a_init]), args=(data_x, data_y))

a_fit = res.x[0]
ys_fit = y(xs, a_fit)

plt.plot(xs, ys, label='true a = %.2f'%(a_orig))
plt.plot(xs, ys_fit, label='fit a = %.2f'%(a_fit))
plt.plot(data_x, data_y, 'o')
plt.legend()
plt.show()

qiita_2.png

I set the initial value of the parameter to $ a_0 = -3 $, but the resulting fit does not match the data.
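To confirm that the failure comes from the starting point rather than from the data, here is a minimal, self-contained sketch (not from the original post) that regenerates the same synthetic data with seed 0 and runs `least_squares` from $ a_0 = -3 $ and, as an assumed comparison point, from $ a_0 = 3 $. Note that `res.cost` is half the sum of squared residuals at the returned solution.

```python
import numpy as np
from scipy.optimize import least_squares

def y(x, a):
    return (x - 3.*a) * (2.*x - a) * (3.*x + a) * (x + 2.*a) / 100.

def calc_residuals(params, data_x, data_y):
    return y(data_x, params[0]) - data_y

# Same synthetic data as in the article: seed 0, true a = 2, noise std 0.5
np.random.seed(0)
num_data = 30
data_x = np.random.uniform(-5, 5, num_data)
data_y = y(data_x, 2.) + np.random.normal(0, 0.5, num_data)

fits = {}
for a_init in (-3., 3.):
    res = least_squares(calc_residuals, np.array([a_init]), args=(data_x, data_y))
    # res.cost = 0.5 * sum(residuals**2) at the returned solution
    fits[a_init] = (res.x[0], res.cost)
    print(f"a_init = {a_init:+.1f} -> a_fit = {res.x[0]:+.3f}, cost = {res.cost:.2f}")
```

Comparing the two `cost` values shows which of the two solutions is only a local optimum.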

Why the optimization fails

Let's look at how the result changes depending on the initial value of the parameter.

# Run least_squares from 1000 initial values in [-4, 4] and record where each run ends up
a_inits = np.linspace(-4, 4, 1000)
a_fits = np.zeros(1000)
for i, a_init in enumerate(a_inits):
    res = least_squares(calc_residuals, np.array([a_init]), args=(data_x, data_y))
    a_fits[i] = res.x[0]

plt.plot(a_inits, a_fits)
plt.xlabel("initial value")
plt.ylabel("optimized value")
plt.show()

qiita_3.png

When the initial value is negative, the optimizer converges to a local optimum instead of the true parameter. The reason becomes clear when we look at how the sum of squared residuals depends on the parameter. As shown in the figure below, it has two local minima, so the result depends on which of the two basins the initial value lies in.

def calc_cost(params, data_x, data_y):
    # Sum of squared residuals for a given parameter value
    residuals = calc_residuals(params, data_x, data_y)
    return (residuals * residuals).sum()

# Evaluate the cost on the same grid of parameter values
costs = np.zeros(1000)
for i, a in enumerate(a_inits):
    costs[i] = calc_cost(np.array([a]), data_x, data_y)
plt.plot(a_inits, costs)
plt.xlabel("parameter")
plt.ylabel("sum of squares")
plt.show()

qiita_4.png
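One way to see where the second minimum comes from (an observation of ours, not from the original post): replacing $ a $ with $ -a $ mirrors the curve about $ x = 0 $, i.e. $ y(x, -a) = y(-x, a) $, because each of the four factors simply changes sign. Since the data cover the symmetric range $ [-5, 5] $, a negative $ a $ gives a mirrored quartic that plausibly follows part of the data, producing a second, worse minimum. A numerical check of the identity:

```python
import numpy as np

def y(x, a):
    return (x - 3.*a) * (2.*x - a) * (3.*x + a) * (x + 2.*a) / 100.

xs = np.linspace(-5, 5, 101)
# Check y(x, -a) == y(-x, a) for a few parameter values
mirrored = [np.allclose(y(xs, -a), y(-xs, a)) for a in (0.5, 2.0, 3.7)]
print(mirrored)  # [True, True, True]
```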

How to find the global optimum

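To find the global optimum, one simple approach is to run the local optimization from many different initial values and keep the best result. SciPy provides `optimize.basinhopping`, which automates a more refined version of this idea: it repeatedly perturbs the current solution at random, runs a local minimization from the perturbed point, and accepts or rejects the new minimum with the Metropolis criterion. Let's try it.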
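For comparison, the brute-force multi-start idea can be written in a few lines. The sketch below is our own illustration, not from the original post: the grid of 17 evenly spaced initial values in $ [-4, 4] $ is an arbitrary choice. It runs `least_squares` from each start and keeps the result with the lowest cost.

```python
import numpy as np
from scipy.optimize import least_squares

def y(x, a):
    return (x - 3.*a) * (2.*x - a) * (3.*x + a) * (x + 2.*a) / 100.

def calc_residuals(params, data_x, data_y):
    return y(data_x, params[0]) - data_y

# Same synthetic data as in the article: seed 0, true a = 2, noise std 0.5
np.random.seed(0)
num_data = 30
data_x = np.random.uniform(-5, 5, num_data)
data_y = y(data_x, 2.) + np.random.normal(0, 0.5, num_data)

best = None
for a_init in np.linspace(-4, 4, 17):
    res = least_squares(calc_residuals, np.array([a_init]), args=(data_x, data_y))
    # Keep the run with the smallest cost (0.5 * sum of squared residuals)
    if best is None or res.cost < best.cost:
        best = res

print(f"best a = {best.x[0]:.3f}, cost = {best.cost:.2f}")
```

This only works if at least one grid point falls inside the basin of the global minimum, which is why a randomized search such as basin hopping is attractive when the landscape is unknown.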

from scipy.optimize import basinhopping

a_init = -3.0
# Extra arguments passed through to calc_cost by the local minimizer
minimizer_kwargs = {"args": (data_x, data_y)}
res = basinhopping(calc_cost, np.array([a_init]), stepsize=2., minimizer_kwargs=minimizer_kwargs)
print(res.x)

a_fit = res.x[0]
ys_fit = y(xs, a_fit)

plt.plot(xs, ys, label='true a = %.2f'%(a_orig))
plt.plot(xs, ys_fit, label='fit by basin-hopping a = %.2f'%(a_fit))
plt.plot(data_x, data_y, 'o')
plt.legend()
plt.show()

qiita_5.png

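This recovers the correct parameter. The key is the `stepsize` argument (default 0.5), which sets the maximum size of the random displacement applied to the parameter at each basin-hopping step. The jumps must be large enough to carry the search out of the basin of the local minimum and into the basin of the global one.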
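To make the role of `stepsize` concrete, here is a deliberately simplified basin-hopping loop. It is a sketch of the idea, not SciPy's implementation: the helper `simple_basinhopping` is hypothetical, and its step size stays fixed, whereas SciPy's `basinhopping` by default adapts the step size every `interval` iterations. Because of that, we assume a somewhat larger fixed step of 4 so that a single jump can cross from the negative basin to the positive one (judging from the initial-value scan above, the two minima lie on opposite sides of zero), and we fix the random seed so the run is reproducible.

```python
import numpy as np
from scipy.optimize import minimize

def y(x, a):
    return (x - 3.*a) * (2.*x - a) * (3.*x + a) * (x + 2.*a) / 100.

def calc_cost(params, data_x, data_y):
    # Sum of squared residuals, as in the article
    r = y(data_x, params[0]) - data_y
    return (r * r).sum()

# Same synthetic data as in the article: seed 0, true a = 2, noise std 0.5
np.random.seed(0)
num_data = 30
data_x = np.random.uniform(-5, 5, num_data)
data_y = y(data_x, 2.) + np.random.normal(0, 0.5, num_data)

def simple_basinhopping(func, x0, args, stepsize, niter=100, T=1.0, seed=0):
    """Hypothetical, simplified basin hopping (fixed step size, no callbacks)."""
    rng = np.random.default_rng(seed)
    current = minimize(func, x0, args=args)  # local minimum near x0
    best = current
    for _ in range(niter):
        # 1. Random jump: uniform displacement in [-stepsize, stepsize]
        x_trial = current.x + rng.uniform(-stepsize, stepsize, size=current.x.shape)
        # 2. Local minimization from the perturbed point
        trial = minimize(func, x_trial, args=args)
        # 3. Metropolis criterion: always accept downhill, sometimes accept uphill
        if trial.fun < current.fun or rng.random() < np.exp(-(trial.fun - current.fun) / T):
            current = trial
        # Track the lowest minimum seen so far
        if trial.fun < best.fun:
            best = trial
    return best

best = simple_basinhopping(calc_cost, np.array([-3.0]), args=(data_x, data_y), stepsize=4.0)
print(f"a = {best.x[0]:.3f}, sum of squares = {best.fun:.2f}")
```

With a much smaller step, most jumps land back in the basin of the local minimum, so the search can stay stuck there for many iterations.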
