Bibliothèque GBDT: j'ai essayé la prédiction de la consommation de carburant (retour) avec CatBoost

Je ne connaissais pas CatBoost, qui est parfois utilisé avec XGBoost et lightGBM, qui sont des bibliothèques GBDT (Gradient Boosting Decision Trees), jusqu'à récemment, alors j'ai essayé de le déplacer avec une tâche de régression. C'était.

CatBoost?

Je vais coller le texte d'introduction du site officiel. (Google Traduction)

CatBoost est un algorithme d'amplification de gradient pour les arbres de décision.
Développé par des chercheurs et ingénieurs Yandex
Il est utilisé pour la recherche, les systèmes de recommandation, les assistants personnels, les voitures autonomes, les prévisions météorologiques et de nombreuses autres tâches telles que Yandex, CERN, Cloudflare, Careem taxi.

Articles que j'ai utilisés comme référence

L'ensemble de données utilisé cette fois

--Ensemble de données MPG automatique -Ceci est l'ensemble de données utilisé dans TensorFlow Tutorials. ―― Prévoyez la consommation de carburant de la voiture. Les variables explicatives comprennent le nombre de cylindres, le volume d'échappement, la puissance, le poids, etc.

Contenu

Le code suivant a fonctionné sur Google Colab. (CPU)

Installez CatBoost avec pip

!pip install catboost -U

Télécharger le jeu de données

import urllib.request

url = 'https://archive.ics.uci.edu/ml/machine-learning-databases/auto-mpg/auto-mpg.data'
file_path = './auto-mpg.data'
urllib.request.urlretrieve(url, file_path)

Prétraitement des données

import pandas as pd

column_names = ['MPG','Cylinders','Displacement','Horsepower','Weight',
                   'Acceleration', 'Model Year', 'Origin'] 
dataset = pd.read_csv(file_path, names=column_names,
                      na_values = "?", comment='\t',
                      sep=" ", skipinitialspace=True)

#Cette fois le but est de bouger, donc nan va tomber
dataset = dataset.dropna().reset_index(drop=True)

#Variable de catégorie:L'origine est gérée par Cat Boost, alors faites-en un type String
dataset['Origin'] = dataset['Origin'].astype(str)

train_dataset = dataset.sample(frac=0.8,random_state=0)
test_dataset = dataset.drop(train_dataset.index)
train_labels = train_dataset.pop('MPG')
test_labels = test_dataset.pop('MPG')

Préparer un ensemble de données à utiliser par CatBoost

import numpy as np
from catboost import CatBoostRegressor, FeaturesData, Pool

def split_features(df):
    cfc = []
    nfc = []
    for column in df:
        if column == 'Origin':
            cfc.append(column)
        else:
            nfc.append(column)
    return df[cfc], df[nfc]

cat_train, num_train = split_features(train_dataset)
cat_test, num_test = split_features(test_dataset)

train_pool = Pool(
    data = FeaturesData(num_feature_data = np.array(num_train.values, dtype=np.float32), 
                    cat_feature_data = np.array(cat_train.values, dtype=object), 
                    num_feature_names = list(num_train.columns.values), 
                    cat_feature_names = list(cat_train.columns.values)),
    label =  np.array(train_labels, dtype=np.float32)
)

test_pool = Pool(
    data = FeaturesData(num_feature_data = np.array(num_test.values, dtype=np.float32), 
                    cat_feature_data = np.array(cat_test.values, dtype=object), 
                    num_feature_names = list(num_test.columns.values), 
                    cat_feature_names = list(cat_test.columns.values))
)

Apprentissage

model = CatBoostRegressor(iterations=2000, learning_rate=0.05, depth=5)
model.fit(train_pool)

Les paramètres ci-dessus sont les valeurs telles qu'elles figurent dans l'article de référence. À propos, l'apprentissage s'est terminé avec «total: 4,3 s».

Diagramme d'inférence / résultat

import matplotlib.pyplot as plt

preds = model.predict(test_pool)

xs = list(range(len(test_labels)))
plt.plot(xs, test_labels.values, color = 'r')
plt.plot(xs, preds, color = 'k');
plt.legend(['Target', 'Prediction'], loc = 'upper left');
plt.show()

Lorsqu'il est tracé, le résultat est le suivant. catboost_result.png

Impressions etc.

Cette fois, je viens de déplacer l'article de référence presque tel quel, mais je suis heureux d'avoir compris l'utilisation approximative de la régression.

Recommended Posts

Bibliothèque GBDT: j'ai essayé la prédiction de la consommation de carburant (retour) avec CatBoost
J'ai essayé l'analyse de régression multiple avec régression polypoly
J'ai essayé fp-growth avec python
J'ai essayé de gratter avec Python
J'ai essayé d'utiliser la bibliothèque Python de Ruby avec PyCall
J'ai essayé Learning-to-Rank avec Elasticsearch!
J'ai essayé le clustering avec PyCaret
J'ai essayé la bibliothèque changefinder!
J'ai essayé gRPC avec Python
J'ai essayé de gratter avec du python
J'ai essayé de résumer des phrases avec summpy
J'ai essayé webScraping avec python.
J'ai essayé de déplacer de la nourriture avec SinGAN
J'ai essayé d'implémenter DeepPose avec PyTorch
J'ai essayé la détection de visage avec MTCNN
J'ai essayé d'exécuter prolog avec python 3.8.2.
J'ai essayé la communication SMTP avec Python
J'ai essayé la génération de phrases avec GPT-2
J'ai essayé d'apprendre LightGBM avec Yellowbrick
J'ai essayé la reconnaissance faciale avec OpenCV
Prédiction des ondes de Sin (retour) avec Pytorch
(Apprentissage automatique) J'ai essayé de comprendre attentivement la régression linéaire bayésienne avec l'implémentation
J'ai essayé de visualiser le modèle avec la bibliothèque d'apprentissage automatique low-code "PyCaret"
J'ai essayé d'envoyer un SMS avec Twilio
J'ai essayé d'utiliser Amazon SQS avec django-celery
J'ai essayé d'implémenter Autoencoder avec TensorFlow
J'ai essayé de visualiser AutoEncoder avec TensorFlow
J'ai essayé de commencer avec Hy
J'ai essayé d'utiliser du sélénium avec du chrome sans tête
J'ai essayé l'analyse factorielle avec des données Titanic!
J'ai essayé d'apprendre avec le Titanic de Kaggle (kaggle②)
J'ai essayé le rendu non réaliste avec Python + opencv
J'ai essayé un langage fonctionnel avec Python
J'ai essayé la récurrence avec Python ② (séquence de nombres Fibonatch)
J'ai essayé d'implémenter DeepPose avec PyTorch PartⅡ
J'ai essayé d'implémenter CVAE avec PyTorch
J'ai essayé de jouer avec l'image avec Pillow
J'ai essayé de résoudre TSP avec QAOA
J'ai essayé la reconnaissance d'image simple avec Jupyter
J'ai essayé le réglage fin de CNN avec Resnet
J'ai essayé le traitement du langage naturel avec des transformateurs.
# J'ai essayé quelque chose comme Vlookup avec Python # 2
J'ai essayé d'extraire des expressions uniques avec la bibliothèque de traitement du langage naturel GiNZA
J'ai essayé Hello World avec un langage OS + C 64 bits sans utiliser de bibliothèque
J'ai essayé d'implémenter Cifar10 avec la bibliothèque SONY Deep Learning NNabla [Nippon Hurray]