[JAVA] [Machine learning with Apache Spark] Associate the importance (Feature Importance) of variables in a tree model with variable names (explanatory variable names)

Overview

- This post explains how to link the importance of each variable (**Feature Importance**) in a tree-based model trained with **Apache Spark** to the variable (explanatory variable) names of the training data set. Feature Importance indicates which variables mattered in the learning result.
- In Scala this is neat to write with zipWithIndex, so it is not much trouble, but in Java it is a little tedious, so I turned it into a library: https://github.com/riversun/spark-ml-feature-importance-helper

Environment

Models that support the **featureImportances** method

In **Apache Spark**, **Feature Importances** can be calculated with the following tree-based models:

| Model | Description |
|---|---|
| GBTRegressionModel | Gradient-boosted tree regression model |
| GBTClassificationModel | Gradient-boosted tree classification model |
| RandomForestRegressionModel | Random forest regression model |
| RandomForestClassificationModel | Random forest classification model |
| DecisionTreeRegressionModel | Decision tree regression model |
| DecisionTreeClassificationModel | Decision tree classification model |

What is the importance of a variable (**Feature Importance**)?

A convenient property of tree-based algorithms is that, after training, you can compute which variables were important in building the model, as the "**Variable Importance**" (see https://web.stanford.edu/~hastie/ElemStatLearn/printings/ESLII_print12.pdf, p. 593).

Note, however, that random forests and gradient-boosted trees are **ensemble learning** methods that combine many models, so the importance of a variable (**Feature Importance**) is not an index that can be interpreted directly, the way a regression coefficient can.
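To make the idea of a tree-ensemble importance score more concrete, here is a minimal, Spark-free sketch of the scheme the Spark API documentation describes for random forests: within each tree, a feature's importance is the sum of the gains of the nodes that split on it, each gain scaled by the number of instances passing through the node; each tree's vector is normalized to sum to 1; the per-tree vectors are added up over all trees; and the result is normalized to sum to 1. The `Split` and `EnsembleImportance` classes are hypothetical stand-ins for illustration, not Spark's API, and each tree is reduced to the list of its split nodes.

```java
import java.util.Arrays;
import java.util.List;

// Hypothetical model of one internal split node: the feature it splits on,
// its impurity gain, and how many training instances reached it.
final class Split {
    final int feature;
    final double gain;
    final long count;

    Split(int feature, double gain, long count) {
        this.feature = feature;
        this.gain = gain;
        this.count = count;
    }
}

final class EnsembleImportance {

    // Each tree is represented only by the list of its split nodes.
    static double[] importances(List<List<Split>> trees, int numFeatures) {
        double[] total = new double[numFeatures];
        for (List<Split> tree : trees) {
            double[] perTree = new double[numFeatures];
            for (Split s : tree) {
                // Gain scaled by the number of instances passing through the node
                perTree[s.feature] += s.gain * s.count;
            }
            double treeSum = Arrays.stream(perTree).sum();
            if (treeSum != 0) {
                // Normalize this tree's importances to sum to 1, then accumulate
                for (int j = 0; j < numFeatures; j++) {
                    total[j] += perTree[j] / treeSum;
                }
            }
        }
        double totalSum = Arrays.stream(total).sum();
        if (totalSum != 0) {
            // Final normalization so the whole vector sums to 1
            for (int j = 0; j < numFeatures; j++) {
                total[j] /= totalSum;
            }
        }
        return total;
    }
}
```

Because every tree is normalized before averaging, a feature that dominates one tree cannot outweigh the others just because that tree happened to have large gains; the resulting scores are relative shares, not effect sizes.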

How to get the importance of a variable (**Feature Importance**)

For example, when Random Forest is used as the algorithm, training a random forest (**RandomForestRegressor**) in **spark.ml** yields a **RandomForestRegressionModel**, and that class provides the **RandomForestRegressionModel#featureImportances** method.

Calling this method computes the importance of each variable (**Feature Importance**).

However, **RandomForestRegressionModel#featureImportances** returns a **Vector** (in practice a **SparseVector**) that contains only the numerical importance values, so it is hard to tell which value belongs to which variable.
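A sparse vector stores only its size, the indices of its stored entries, and their values, so features that never appear in a split are typically not stored at all. The pure-Java sketch below (the `SparseToDense` class is hypothetical, not part of Spark) shows the densification that `Vector#toArray()` performs: after it, position `i` of the array is feature index `i`, and missing features read as `0.0`.

```java
import java.util.Arrays;

final class SparseToDense {

    // Expands (size, indices, values) into a dense array; missing positions stay 0.0
    static double[] toDense(int size, int[] indices, double[] values) {
        if (indices.length != values.length) {
            throw new IllegalArgumentException("indices and values must have the same length");
        }
        double[] dense = new double[size];
        for (int k = 0; k < indices.length; k++) {
            dense[indices[k]] = values[k];
        }
        return dense;
    }

    // Convenience for printing, e.g. while debugging
    static String describe(int size, int[] indices, double[] values) {
        return Arrays.toString(toDense(size, indices, values));
    }
}
```

For example, `SparseToDense.toDense(5, new int[] {0, 2}, new double[] {0.7, 0.3})` returns `[0.7, 0.0, 0.3, 0.0, 0.0]`. With Spark on the classpath, `Vectors.sparse(5, new int[] {0, 2}, new double[] {0.7, 0.3}).toArray()` (from `org.apache.spark.ml.linalg.Vectors`) gives the same array, which is why the association below works on dense positions.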

The purpose of this article is therefore to associate the **Feature Importances** with the variable names.

To make this association, prepare the method's three arguments in the following three steps.

The steps assume that training has finished and that the trained model is available as an object named randomForestRegressionModel.

**STEP1** First, get the **Feature Importances** as follows.

Vector importances = randomForestRegressionModel.featureImportances();

**STEP2** Next, get the name of the column that holds the features. (You chose this column name yourself when building the pipeline, so hard-coding it is also fine.)

String featuresCol = randomForestRegressionModel.getFeaturesCol();

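**STEP3** Finally, get the schema of a DataFrame that contains the features column (here `predictions`, the output of the trained model).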

StructType schema = predictions.schema();

Pass these three values to the following method, which links the importances to the variable names (and also sorts them).

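import java.util.Arrays;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collector;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

import org.apache.spark.ml.linalg.Vector;
import org.apache.spark.sql.types.Metadata;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;

// Links each importance score to its feature name using the ml_attr metadata of the features column.
// Importance holds (index, name, score, rank); this.sort is the requested output order.
List<Importance> zipImportances(Vector featureImportances, String featuresCol, StructType schema) {

    // Position of the features column in the schema (throws IllegalArgumentException if absent)
    final int indexOfFeaturesCol = schema.fieldIndex(featuresCol);

    final StructField featuresField = schema.fields()[indexOfFeaturesCol];

    final Metadata metadata = featuresField
            .metadata();

    // Per-feature attributes written by the pipeline, grouped by attribute type
    final Metadata featuresFieldAttrs = metadata
            .getMetadata("ml_attr")
            .getMetadata("attrs");

    // Maps position in the feature vector (idx) to the original column name
    final Map<Integer, String> idNameMap = new HashMap<>();

    final String[] fieldKeys = { "nominal", "numeric", "binary" };

    // Collects an array of attribute Metadata into idx -> name (later duplicates win)
    final Collector<Metadata, ?, HashMap<Integer, String>> metaDataMapperFunc = Collectors
            .toMap(
                    metaData -> (int) metaData.getLong("idx"), // key of map
                    metaData -> metaData.getString("name"), // value of map
                    (oldVal, newVal) -> newVal,
                    HashMap::new);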

    for (String fieldKey : fieldKeys) {
        if (featuresFieldAttrs.contains(fieldKey)) {
            idNameMap.putAll(Arrays
                    .stream(featuresFieldAttrs.getMetadataArray(fieldKey))
                    .collect(metaDataMapperFunc));
        }
    }

    // Dense copy of the importances: position i corresponds to idx i in the metadata
    final double[] importanceScores = featureImportances.toArray();

    final List<Importance> rawImportanceList = IntStream
            .range(0, importanceScores.length)
            .mapToObj(idx -> new Importance(idx, idNameMap.get(idx), importanceScores[idx], 0))
            .collect(Collectors.toList());

    // Rank 0 = highest score; the rank is kept whatever output order is chosen below
    final List<Importance> descSortedImportanceList = rawImportanceList
            .stream()
            .sorted(Comparator.comparingDouble((Importance ifeature) -> ifeature.score).reversed())
            .collect(Collectors.toList());

    for (int i = 0; i < descSortedImportanceList.size(); i++) {
        descSortedImportanceList.get(i).rank = i;
    }

    final List<Importance> finalImportanceList;

    switch (this.sort) {
    case ASCENDING:
        final List<Importance> ascSortedImportantFeatureList = descSortedImportanceList
                .stream()
                .sorted(Comparator.comparingDouble((Importance ifeature) -> ifeature.score))
                .collect(Collectors.toList());

        finalImportanceList = ascSortedImportantFeatureList;
        break;
    case DESCENDING:
        finalImportanceList = descSortedImportanceList;
        break;
    case UNSORTED:
    default:
        finalImportanceList = rawImportanceList;
        break;
    }
    return finalImportanceList;

}
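The ranking and sorting part of the method can be tried without Spark. The sketch below is a simplified, hypothetical stand-in (`Order`, `RankedScore` and `RankAndSort` are names made up for this example, and feature names are left out): rank 0 always goes to the highest score, and the requested order only changes how the list is returned, so an ascending list still carries the descending-order ranks.

```java
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

enum Order { ASCENDING, DESCENDING, UNSORTED }

final class RankedScore {
    final int index;    // position in the importance vector
    final double score;
    int rank;           // 0 = most important

    RankedScore(int index, double score) {
        this.index = index;
        this.score = score;
    }
}

final class RankAndSort {

    static List<RankedScore> rankAndSort(double[] scores, Order order) {
        // One entry per vector position, in the original (unsorted) order
        List<RankedScore> raw = IntStream.range(0, scores.length)
                .mapToObj(i -> new RankedScore(i, scores[i]))
                .collect(Collectors.toList());

        // Ranks always come from the descending order; the objects are shared,
        // so every returned list sees the same ranks
        List<RankedScore> desc = new ArrayList<>(raw);
        desc.sort(Comparator.comparingDouble((RankedScore r) -> r.score).reversed());
        for (int i = 0; i < desc.size(); i++) {
            desc.get(i).rank = i;
        }

        switch (order) {
            case ASCENDING:
                List<RankedScore> asc = new ArrayList<>(raw);
                asc.sort(Comparator.comparingDouble((RankedScore r) -> r.score));
                return asc;
            case DESCENDING:
                return desc;
            case UNSORTED:
            default:
                return raw;
        }
    }
}
```

Keeping the rank independent of the output order means a caller can, for example, request the unsorted list to line scores up with the original feature order and still see how each feature ranks overall.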

The key to this method

The essential part is the following lookup.

final StructField featuresField = schema.fields()[indexOfFeaturesCol];

final Metadata metadata = featuresField
        .metadata();

final Metadata featuresFieldAttrs = metadata
        .getMetadata("ml_attr")
        .getMetadata("attrs");

The **schema** describes the structure of the data set. It includes the features column used for training, and that column's metadata (design information) records, for each position (index) in the feature vector, which variable it came from. The method reads this metadata to associate each index with the corresponding variable.
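To see the shape of that lookup without Spark, the sketch below models `ml_attr.attrs` as a plain `Map` from attribute type (`nominal`, `numeric`, `binary`) to a list of `(idx, name)` entries and merges all groups into a single index-to-name map, as the method does with `getMetadataArray`. The `Attr` and `AttrIndex` classes are hypothetical stand-ins for Spark's `Metadata` objects.

```java
import java.util.HashMap;
import java.util.List;
import java.util.Map;

final class Attr {
    final int idx;      // position in the feature vector
    final String name;  // name of the source column

    Attr(int idx, String name) {
        this.idx = idx;
        this.name = name;
    }
}

final class AttrIndex {

    // Attribute types that may appear under ml_attr.attrs
    static final String[] FIELD_KEYS = {"nominal", "numeric", "binary"};

    static Map<Integer, String> indexToName(Map<String, List<Attr>> attrs) {
        Map<Integer, String> idNameMap = new HashMap<>();
        for (String key : FIELD_KEYS) {
            // A type is simply absent when no feature of that type exists
            List<Attr> group = attrs.get(key);
            if (group != null) {
                for (Attr a : group) {
                    idNameMap.put(a.idx, a.name);
                }
            }
        }
        return idNameMap;
    }
}
```

The groups are only a way of organizing attributes by type; the `idx` values across all groups together cover the positions of the feature vector, which is why they can simply be merged into one map.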

Metadata stored in the schema

The actual metadata can be dumped as JSON, as shown below. (Its layout depends on which Transformers are used in the pipeline, so this is only an example.)

Metadata example (predicting the price of jewelry and accessories)



{
   "ml_attr":{
      "attrs":{
         "numeric":[
            {
               "idx":0,
               "name":"weight"
            }
         ],
         "nominal":[
            {
               "vals":[
                  "platinum",
                  "gold",
                  "Silver",
                  "Dia"
               ],
               "idx":1,
               "name":"materialIndex"
            },
            {
               "vals":[
                  "brooch",
                  "necklace",
                  "Earrings",
                  "ring",
                  "bracelet"
               ],
               "idx":2,
               "name":"shapeIndex"
            },
            {
               "vals":[
                  "Domestic famous brand",
                  "Overseas super famous brand",
                  "No brand",
                  "Famous overseas brands"
               ],
               "idx":3,
               "name":"brandIndex"
            },
            {
               "vals":[
                  "Department store",
                  "Directly managed store",
                  "Cheap shop"
               ],
               "idx":4,
               "name":"shopIndex"
            }
         ]
      },
      "num_attrs":5
   }
}

As this data shows, the metadata provides pairs of **idx** and **name**. Mapping them onto the **Vector** returned by the **featureImportances** method gives the ranking below, which shows how important (**score**) each variable (**name**) was in one particular training run.
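Putting the pieces together, the sketch below joins an index-to-name map with the importance scores and formats them in descending order, in the same `FeatureInfo [...]` layout as the ranking that follows. In the usage check, the scores are the ones shown in that ranking, placed at the `idx` positions from the metadata example. `FeatureRanking` and `format` are hypothetical names; the helper library's own classes may differ.

```java
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.Map;

final class FeatureRanking {

    // Returns one line per feature, most important first, ranks starting at 0
    static List<String> format(double[] scores, Map<Integer, String> idNameMap) {
        List<Integer> order = new ArrayList<>();
        for (int i = 0; i < scores.length; i++) {
            order.add(i);
        }
        // Sort vector positions by score, highest first
        order.sort(Comparator.comparingDouble((Integer i) -> scores[i]).reversed());

        List<String> lines = new ArrayList<>();
        for (int rank = 0; rank < order.size(); rank++) {
            int idx = order.get(rank);
            lines.add("FeatureInfo [rank=" + rank + ", score=" + scores[idx]
                    + ", name=" + idNameMap.get(idx) + "]");
        }
        return lines;
    }
}
```

Sorting the indices rather than the scores keeps the link between each score and its `idx`, which is the whole point of the association.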

Explanatory variable importance ranking


FeatureInfo [rank=0, score=0.3015643580333446, name=weight]
FeatureInfo [rank=1, score=0.2707593044437997, name=materialIndex]
FeatureInfo [rank=2, score=0.20696065038166056, name=brandIndex]
FeatureInfo [rank=3, score=0.11316392134864546, name=shapeIndex]
FeatureInfo [rank=4, score=0.10755176579254973, name=shopIndex]

Summary

- I wrote a Java helper for Apache Spark's spark.ml that links the importance of each explanatory variable in a tree model to its variable name.
- Scala users may say "zipWithIndex does this in no time!", but I wanted a library that saves that time and effort in Java.

Full source code

The source code is available at https://github.com/riversun/spark-ml-feature-importance-helper

Maven

The library is also available from the Maven repository:

<dependency>
	<groupId>org.riversun</groupId>
	<artifactId>spark-ml-feature-importance-helper</artifactId>
	<version>1.0.0</version>
</dependency>
