Consistency of scikit-learn API design

Introduction

scikit-learn is a convenient machine learning library for Python that works easily with NumPy, SciPy, and Matplotlib. It becomes much easier to use once you know the design patterns behind its API.

This time, I will explain the appeal of scikit-learn based on the paper by the scikit-learn authors.

Basic design

scikit-learn objects are designed according to a few patterns that keep the API consistent. **By understanding these patterns, you can use any object without difficulty.**

Estimator

scikit-learn is built around an interface called Estimator. **An Estimator learns a model (its parameters) from data.** It always has a method called fit, and learning is performed by passing data to fit. You can also set the hyperparameters required for training, either in the constructor or with the set_params method.

A class called LogisticRegression that performs logistic regression is also one of the Estimators.

from sklearn.linear_model import LogisticRegression
# penalty="l1" needs a solver that supports it, such as liblinear or saga
clf = LogisticRegression(penalty="l1", solver="liblinear")  # Set hyperparameters
clf.fit(X_train, y_train)  # Learn the model from the training data
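The two ways of setting hyperparameters mentioned above can be sketched as follows. This is a minimal example, assuming a small synthetic dataset from make_classification; the dataset and the values of C are my own illustrative choices, not taken from the paper.

```python
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression

# Synthetic data, used here only for illustration
X, y = make_classification(n_samples=200, n_features=5, random_state=0)

clf = LogisticRegression(C=1.0)  # hyperparameter set in the constructor
clf.set_params(C=10.0)           # hyperparameter changed afterwards
params = clf.get_params()        # dict of all current hyperparameters

fitted = clf.fit(X, y)           # fit returns the estimator itself
```

Because fit returns the estimator, calls can be chained, for example LogisticRegression().fit(X, y).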

Predictor

Many Estimators also implement the Predictor interface. **A Predictor makes predictions (outputs) based on the model learned in fit.** Passing data to the predict method returns predictions. It also has a method called score, which evaluates the model when you pass a dataset and its labels.

For example, LogisticRegression is a Predictor, so you can use the predict and score methods without any problems.

clf.predict(X_test)  # Predict labels for the test data
clf.score(X_test, y_test)  # Compare the predictions with the true labels of the test data
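To make the relationship between predict and score concrete, here is a small sketch on synthetic data (the dataset and the train/test split are assumptions for illustration). For classifiers, score returns the mean accuracy, so it should agree with accuracy_score computed from the output of predict.

```python
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split

# Synthetic data, used here only for illustration
X, y = make_classification(n_samples=300, n_features=5, random_state=0)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

clf = LogisticRegression().fit(X_train, y_train)
y_pred = clf.predict(X_test)             # one predicted label per test sample
accuracy = clf.score(X_test, y_test)     # mean accuracy on the test data
manual = accuracy_score(y_test, y_pred)  # same value, computed by hand
```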

Transformer

In addition to Predictor, some classes implement an interface called Transformer. As the name implies, **a Transformer transforms data.** It is used more often in data processing APIs than in machine learning models. The transform method returns the transformed data. In addition, the fit_transform method performs learning and transformation in a single call.

The example below uses StandardScaler, which standardizes a dataset. Rather than a complex model, StandardScaler learns the mean and variance of each feature.

from sklearn.preprocessing import StandardScaler
scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)  # Learn from the training data and transform it
X_test = scaler.transform(X_test)  # Transform the test data without refitting (uses the mean / variance of the training data)

Also, a single class can implement the Predictor and Transformer interfaces at the same time.
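As one illustration of a class that offers both interfaces, KMeans has predict as well as transform. The sketch below assumes synthetic data from make_blobs; the choice of KMeans and its settings is mine, not from the paper.

```python
from sklearn.cluster import KMeans
from sklearn.datasets import make_blobs

# Synthetic 2-D data with three clusters, used only for illustration
X, _ = make_blobs(n_samples=100, centers=3, random_state=0)

km = KMeans(n_clusters=3, n_init=10, random_state=0).fit(X)
labels = km.predict(X)       # Predictor: cluster index for each sample
distances = km.transform(X)  # Transformer: distance to each cluster center
```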

Get parameters / hyperparameters

The hyperparameters you set and the learned parameters are stored in the object. The names of learned parameters end with an underscore. For information on how to access parameters and hyperparameters, see "Attributes" in the documentation for each object.

Example: Get the mean and variance learned by StandardScaler

# Continuing from the previous example
mean = scaler.mean_
variance = scaler.var_
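The distinction between hyperparameters and learned parameters can be checked with a small sketch. The tiny array below is an assumption for illustration; get_params returns the hyperparameters, while mean_ and var_ hold what was learned in fit.

```python
import numpy as np
from sklearn.preprocessing import StandardScaler

# Tiny dataset, used here only for illustration: 3 samples, 2 features
X = np.array([[1.0, 10.0], [2.0, 20.0], [3.0, 30.0]])

scaler = StandardScaler().fit(X)
hyperparams = scaler.get_params()  # set by the user (here: the defaults)
mean = scaler.mean_                # learned: per-feature mean
variance = scaler.var_             # learned: per-feature population variance

# The learned values match what NumPy computes column by column
same_mean = np.allclose(mean, X.mean(axis=0))
same_var = np.allclose(variance, X.var(axis=0))
```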

Therefore, with any Estimator, you can easily build a workflow by following this procedure:

  1. Create an instance and set hyperparameters
  2. Learn with fit
  3. Check the learned parameters, or achieve your goal with predict, score, transform, etc.

Everything from data processing to model training and evaluation can be implemented using Estimators.
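The procedure above can be sketched as a loop in which exactly the same code handles two different models. The synthetic data and the choice of DecisionTreeClassifier as the second model are assumptions for illustration.

```python
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

# Synthetic data, used here only for illustration
X, y = make_classification(n_samples=300, n_features=5, random_state=0)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

scores = {}
# Step 1: create instances and set hyperparameters
for estimator in (LogisticRegression(), DecisionTreeClassifier(random_state=0)):
    estimator.fit(X_train, y_train)           # Step 2: learn with fit
    name = type(estimator).__name__
    scores[name] = estimator.score(X_test, y_test)  # Step 3: evaluate
```

Nothing in the loop body depends on which model is being trained, which is the consistency this article is about.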

Applied design

Estimator synthesis

Since all Estimators share the same methods, **you can easily combine multiple Estimators.** Use Pipeline for sequential processing and FeatureUnion for parallel processing.

For example, if you want to standardize your data and perform logistic regression, you can implement the process neatly by using a pipeline.

from sklearn.pipeline import Pipeline
pipe = Pipeline([
    ('std_scaler', StandardScaler()),
    ('log_reg', LogisticRegression())  # Receives the data transformed by the previous step
])
pipe.fit(X_train, y_train)
pipe.score(X_test, y_test)
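The parallel case can be sketched in the same way with FeatureUnion, which applies several Transformers to the same input and concatenates their outputs column-wise. The synthetic data and the combination of StandardScaler with PCA are assumptions for illustration.

```python
from sklearn.datasets import make_classification
from sklearn.decomposition import PCA
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import FeatureUnion, Pipeline
from sklearn.preprocessing import StandardScaler

# Synthetic data, used here only for illustration
X, y = make_classification(n_samples=200, n_features=5, random_state=0)

union = FeatureUnion([
    ("scaled", StandardScaler()),  # 5 standardized columns
    ("pca", PCA(n_components=2)),  # 2 principal components
])
X_combined = union.fit_transform(X)  # outputs are stacked side by side

# A FeatureUnion is itself a Transformer, so it can be a Pipeline step
pipe = Pipeline([("features", union), ("log_reg", LogisticRegression())])
pipe.fit(X, y)
y_pred = pipe.predict(X)
```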

Cross Validation is also an Estimator

scikit-learn lets you search for good hyperparameters using classes such as GridSearchCV and RandomizedSearchCV. These also implement the Estimator interface and use fit to learn.

Example: Find the best hyperparameters for logistic regression using Grid Search

from sklearn.model_selection import GridSearchCV
clf = GridSearchCV(
    estimator=LogisticRegression(),
    param_grid={
        'C': [1, 3, 10, 30, 100]
    }
)
clf.fit(X_train, y_train)  # Train a model for each hyperparameter value in param_grid
best_clf = clf.best_estimator_  # Get the best Estimator
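Because a Pipeline is also an Estimator, it can be passed to GridSearchCV directly. In the sketch below, the hyperparameters of a step are addressed as step name, two underscores, parameter name. The synthetic data, the grid values, and cv=3 are assumptions for illustration.

```python
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import GridSearchCV
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler

# Synthetic data, used here only for illustration
X, y = make_classification(n_samples=200, n_features=5, random_state=0)

pipe = Pipeline([
    ("std_scaler", StandardScaler()),
    ("log_reg", LogisticRegression()),
])

# "log_reg__C" means: parameter C of the step named "log_reg"
search = GridSearchCV(pipe, param_grid={"log_reg__C": [0.1, 1, 10]}, cv=3)
search.fit(X, y)

best_C = search.best_params_["log_reg__C"]
y_pred = search.predict(X)  # delegates to the refitted best pipeline
```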

Make your own Estimator

By creating a class that defines the methods of the interface, such as fit, **you can easily use it in pipelines and hyperparameter search.** When creating an Estimator, inherit from BaseEstimator, and when creating a Transformer or similar, also inherit from the appropriate Mixin.

Transformer example:

from sklearn.base import BaseEstimator, TransformerMixin

# Mixins are listed before BaseEstimator, as the scikit-learn docs recommend
class MyTransformer(TransformerMixin, BaseEstimator):
    def __init__(self, param_1=None, param_2=None):
        # Store each hyperparameter under the same name as its argument
        self.param_1 = param_1
        self.param_2 = param_2

    def fit(self, X, y=None):
        # Learn from X here and store the results in attributes ending with "_"
        return self

    def transform(self, X):
        # Process the NumPy array here
        # X = ...
        return X

    # fit_transform is provided automatically by TransformerMixin

transformer = MyTransformer()
X_transformed = transformer.fit_transform(X_train)
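As a concrete sketch of the skeleton above, here is a hypothetical ClipTransformer (the class, its lower and upper parameters, and the sample array are all my own illustration). It shows that a class written this way gets fit_transform and get_params for free, can be copied with clone, and works as a Pipeline step.

```python
import numpy as np
from sklearn.base import BaseEstimator, TransformerMixin, clone
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler

class ClipTransformer(TransformerMixin, BaseEstimator):
    """Hypothetical transformer that clips values to [lower, upper]."""

    def __init__(self, lower=0.0, upper=1.0):
        self.lower = lower
        self.upper = upper

    def fit(self, X, y=None):
        return self  # nothing to learn for clipping

    def transform(self, X):
        return np.clip(X, self.lower, self.upper)

X = np.array([[-5.0], [0.5], [5.0]])

clipper = ClipTransformer(lower=-1.0, upper=1.0)
X_clipped = clipper.fit_transform(X)  # provided by TransformerMixin
copied = clone(clipper)               # relies on get_params from BaseEstimator

# The custom class can be combined with built-in Estimators
pipe = Pipeline([("clip", clipper), ("scale", StandardScaler())])
X_out = pipe.fit_transform(X)
```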

Conclusion

scikit-learn provides objects that implement a wide variety of machine learning methods, but even if you do not understand their internals, you can use them all once you understand the design patterns of Estimator, Predictor, and Transformer. The scikit-learn API is attractive because it is highly consistent, which makes it easy to get on with machine learning.
