Osez défier Kaggle avec Java (1)

Aperçu

Est-il possible de se battre à Kaggle en utilisant uniquement Java? J'ai osé l'essayer.

Que faire dans cet article

Je voudrais implémenter le défi bien connu de Kaggle pour les débutants Titanic: Machine Learning from Disaster en Java uniquement. .. Le défi est de prédire la survie en fonction du nom, du sexe, de l'âge, des informations sur les billets, etc. des clients à bord du Titanic.

titanic.png

Il construit un modèle appris à partir des données d'entraînement données (train.csv), prédit la survie de la personne incluse dans les données de test ( test.csv), et est en concurrence pour le taux de réponse correct du résultat. Est-ce cette compétition (même si je pense que prédire la vie ou la mort comme un jeu ...).

Premièrement, prédisons les survivants avec une implémentation minimale.

--Construire un environnement de développement

Ensuite, implémentez ce qui suit:

--L'analyse des données

Bibliothèque d'apprentissage automatique Java

Et avant cela, quelles sont les bibliothèques d'apprentissage automatique Java en premier lieu? Je pense que les suivants sont célèbres.

Cette fois, je vais essayer d'utiliser Tribuo à partir de cela.

Prédire les survivants avec une mise en œuvre minimale

Construire un environnement de développement

Puisqu'il s'agit de Java, construisons un modèle de prédiction de survie avec IntelliJ pour le moment.

Créez un projet Maven avec IntelliJ comme indiqué ci-dessous

Screenshot from 2020-09-23 22-04-30.png

Ajoutez ce qui suit à pom.xml.

<dependencies>
    <dependency>
        <groupId>org.tribuo</groupId>
        <artifactId>tribuo-all</artifactId>
        <version>4.0.0</version>
        <type>pom</type>
    </dependency>
</dependencies>

Ensuite, cliquez sur le bouton «Tout télécharger» sur cette page pour télécharger les données nécessaires à la prédiction, décompressez-les, puis créez le Maven. Copiez-le dans votre répertoire de projet.

La structure des répertoires est la suivante.

Screenshot from 2020-10-07 22-04-01.png

Lecture des données de la bibliothèque

Lisons d'abord le fichier CSV. Implémentez et exécutez les éléments suivants:

LabelFactory labelFactory = new LabelFactory();
CSVLoader csvLoader = new CSVLoader<>(',',labelFactory);
ListDataSource dataource = csvLoader.loadDataSource(Paths.get("titanic/train.csv"),"Survived");

Mais le NumberFormatException est ...

Exception in thread "main" java.lang.NumberFormatException: For input string: "S"
	at sun.misc.FloatingDecimal.readJavaFormatString(FloatingDecimal.java:2043)
	at sun.misc.FloatingDecimal.parseDouble(FloatingDecimal.java:110)
	at java.lang.Double.parseDouble(Double.java:538)
	at org.tribuo.data.csv.CSVLoader.innerLoadFromCSV(CSVLoader.java:260)
	at org.tribuo.data.csv.CSVLoader.loadDataSource(CSVLoader.java:238)
	at org.tribuo.data.csv.CSVLoader.loadDataSource(CSVLoader.java:209)
	at org.tribuo.data.csv.CSVLoader.loadDataSource(CSVLoader.java:184)
	at org.tribuo.data.csv.CSVLoader.loadDataSource(CSVLoader.java:138)
	at TitanicSurvivalClassifier.main(TitanicSurvivalClassifier.java:23)

Quand je lis le code source de Tribuo, il semble que ce comportement puisse être changé avec l'implémentation qui suppose que le fichier CSV se compose uniquement de valeurs numériques (toujours Double.parseDouble ()). Il n'y en a pas. Peut-être que la philosophie de conception actuelle de Tribuo est que le prétraitement des données ne relève pas de la responsabilité.

Vous devez au moins supprimer les colonnes non numériques du fichier CSV. Vous pouvez utiliser des fichiers CSV avec Apache Commons CSV, etc., mais en prévision de l'avenir, nous allons introduire "DFLib" qui est une bibliothèque qui semble être capable d'effectuer un prétraitement. DFLib est une implémentation Java légère de Pandas en Python qui utilise Apache Commons CSV en interne.

<dependency>
    <groupId>com.nhl.dflib</groupId>
    <artifactId>dflib-csv</artifactId>
    <version>0.8</version>
</dependency>

Avant de charger le fichier CSV avec CSVLoader, supprimez les colonnes CSV" Nom "," Sexe "," Ticket "," Cabine "et" Embarqué "(restreindre aux colonnes requises) comme indiqué ci-dessous. , Enregistrer dans un fichier CSV.

DataFrame df = Csv.loader().load("titanic/train.csv");
DataFrame selectColumns = df.selectColumns("Survived", "Pclass", "Age", "SibSp", "Parch", "Fare");
Csv.save(selectColumns, "titanic/train_removed.csv");

LabelFactory labelFactory = new LabelFactory();
CSVLoader csvLoader = new CSVLoader<>(',',labelFactory);
ListDataSource dataource = csvLoader.loadDataSource(Paths.get("titanic/train_removed.csv"),"Survived");

Réessayer.

Exception in thread "main" java.lang.NumberFormatException: empty String
	at sun.misc.FloatingDecimal.readJavaFormatString(FloatingDecimal.java:1842)
	at sun.misc.FloatingDecimal.parseDouble(FloatingDecimal.java:110)
	at java.lang.Double.parseDouble(Double.java:538)
	at org.tribuo.data.csv.CSVLoader.innerLoadFromCSV(CSVLoader.java:260)
	at org.tribuo.data.csv.CSVLoader.loadDataSource(CSVLoader.java:238)
	at org.tribuo.data.csv.CSVLoader.loadDataSource(CSVLoader.java:209)
	at org.tribuo.data.csv.CSVLoader.loadDataSource(CSVLoader.java:184)
	at org.tribuo.data.csv.CSVLoader.loadDataSource(CSVLoader.java:138)
	at TitanicSurvivalClassifier.main(TitanicSurvivalClassifier.java:31)

Mais encore une fois, NumberFormatException. Il semble que la valeur manquante contenue dans la colonne "Age" soit traitée comme un caractère vide. Puisque DataFrame a une méthode appelée fillNulls (), vous pouvez remplacer null par zéro etc. à la fois, mais la valeur manquante de lecture semble être interprétée comme un caractère vide (" "), donc l'intention Cela ne fonctionne pas (bien qu'il puisse y avoir un problème avec la façon de charger le fichier CSV ...). Donc, ici, nous effaçons également la colonne "Age".

DataFrame selectedDataFrame = df.selectColumns("Survived", "Pclass", "SibSp", "Parch", "Fare");

Création d'un modèle de base

Cette fois, cela a fonctionné. Construisons maintenant un modèle et formons-nous. La mise en œuvre minimale lors de l'utilisation de la régression logistique est la suivante:

TrainTestSplitter dataSplitter = new TrainTestSplitter<>(dataource, 0.7, 1L);
MutableDataset trainingDataset = new MutableDataset<>(dataSplitter.getTrain());
MutableDataset testingDataset = new MutableDataset<>(dataSplitter.getTest());

Trainer<Label> trainer = new LogisticRegressionTrainer();
Model<Label> irisModel = trainer.train(trainingDataset);

LabelEvaluator evaluator = new LabelEvaluator();
LabelEvaluation evaluation = evaluator.evaluate(irisModel, testingDataset);
System.out.println(evaluation);

J'ai enfin obtenu le résultat. Sur les 268 données de vérification, 163 sont correctes, avec un taux de réponse correcte de 60,8%.

Class                           n          tp          fn          fp      recall        prec          f1
0                             170          89          81          24       0.524       0.788       0.629
1                              98          74          24          81       0.755       0.477       0.585
Total                         268         163         105         105
Accuracy                                                                    0.608
Micro Average                                                               0.608       0.608       0.608
Macro Average                                                               0.639       0.633       0.607
Balanced Error Rate                                                         0.361

Sortie / soumission des prévisions

Maintenant que le modèle est construit, lisons les données de test et faisons une prédiction. Implémentons et vérifions comme suit.

DataFrame dfTest = Csv.loader().load("titanic/test.csv");
DataFrame selectedDfTest = dfTest.selectColumns("Pclass", "SibSp", "Parch", "Fare");
Csv.save(selectedDfTest, "titanic/test_removed.csv");

ListDataSource dataource4test = csvLoader.loadDataSource(Paths.get("titanic/test_removed.csv"),"Survived");
List<Prediction> predicts = model.predict(dataource4test);
System.out.println(predicts);

Cependant, CsvLoader.loadDataSource () semble exiger le nom de la variable objective dans le deuxième argument, et j'ai passé "Survived", mais j'obtiens une erreur si test.csv n'a pas ce "Survived". J'ai.

Exception in thread "main" java.lang.IllegalArgumentException: Response Survived not found in file file:/home/tamura/git/tribuo-examples/titanic/test_removed.csv
	at org.tribuo.data.csv.CSVLoader.validateResponseNames(CSVLoader.java:286)
	at org.tribuo.data.csv.CSVLoader.innerLoadFromCSV(CSVLoader.java:244)
	at org.tribuo.data.csv.CSVLoader.loadDataSource(CSVLoader.java:238)
	at org.tribuo.data.csv.CSVLoader.loadDataSource(CSVLoader.java:209)
	at org.tribuo.data.csv.CSVLoader.loadDataSource(CSVLoader.java:184)
	at org.tribuo.data.csv.CSVLoader.loadDataSource(CSVLoader.java:138)
	at TitanicSurvivalClassifier.main(TitanicSurvivalClassifier.java:74)

Tout en se plaignant "Pourquoi la variable objective devrait-elle être incluse dans le fichier CSV", il n'y a aucune aide pour cela, alors ajoutez la colonne "Survived" dans le fichier csv pour soumission (gender_submission.csv) à DataFrame. , J'ai décidé de tromper.

DataFrame dfTest = Csv.loader().load("titanic/test.csv");
DataFrame dfSubmission = Csv.loader().load("titanic/gender_submission.csv");
DataFrame selectedDfTest = dfTest.selectColumns("Pclass", "SibSp", "Parch", "Fare");
selectedDfTest = selectedDfTest.hConcat(dfSubmission.dropColumns("PassengerId"));
Csv.save(selectedDfTest, "titanic/test_removed.csv");

ListDataSource dataource4test = csvLoader.loadDataSource(Paths.get("titanic/test_removed.csv"),"Survived");
List<Prediction> predicts = model.predict(dataource4test);
System.out.println(predicts);

Quand je le lance en pensant cette fois, l'exception familière est ...

Exception in thread "main" java.lang.NumberFormatException: empty String
	at sun.misc.FloatingDecimal.readJavaFormatString(FloatingDecimal.java:1842)
	at sun.misc.FloatingDecimal.parseDouble(FloatingDecimal.java:110)
	at java.lang.Double.parseDouble(Double.java:538)
	at org.tribuo.data.csv.CSVLoader.innerLoadFromCSV(CSVLoader.java:260)
	at org.tribuo.data.csv.CSVLoader.loadDataSource(CSVLoader.java:238)
	at org.tribuo.data.csv.CSVLoader.loadDataSource(CSVLoader.java:209)
	at org.tribuo.data.csv.CSVLoader.loadDataSource(CSVLoader.java:184)
	at org.tribuo.data.csv.CSVLoader.loadDataSource(CSVLoader.java:138)
	at TitanicSurvivalClassifier.main(TitanicSurvivalClassifier.java:73)

En regardant test.csv, il n'y a qu'une seule valeur manquante dans la colonne "Fare" ... Cette valeur manquante est aussi un caractère vide (" "), non nul, donc DataFrame.fillNulls () ne peut pas être utilisé. .. À la suite de l'examen du code source, j'ai trouvé qu'il est possible de remplacer les caractères vides par le style d'écriture suivant.

selectedDfTest = selectedDfTest.convertColumn("Fare", s -> "".equals(s) ? "0": s);

Le résultat prévu est maintenant sorti.

[Prediction(maxLabel=(0,0.5474041397777752),outputScores={0=(0,0.5474041397777752)1=(1,0.4525958602222247}), Prediction(maxLabel=(0,0.6969779586356148),outputScores={0=(0,0.6969779586356148)1=(1,0.303022041364385}), Prediction(maxLabel=(1,0.5302004352989867),outputScores={0=(0,0.46979956470101314)1=(1,0.530200435298986}), Prediction(maxLabel=(0,0.52713643586377),outputScores={0=(0,0.52713643586377)1=(1,0.4728635641362}), Prediction(maxLabel=(0,0.5071805368465395),outputScores={0=(0,0.5071805368465395)1=(1,0.492819463153460}), Prediction(maxLabel=(0,0.5134002908191431),outputScores={0=(0,0.5134002908191431)1=(1,0.4865997091808569}),
・ ・ ・

Tout ce que vous avez à faire est de l'enregistrer dans un fichier CSV pour soumission. C'est facile à mettre en œuvre avec Python, mais ce n'est pas étonnamment facile à implémenter avec DFLib ou Tribuo, et j'ai regardé la documentation et le code source pendant un moment, mais cela peut prendre un certain temps, donc le standard Java Je l'ai implémenté dans l'API comme suit.

AtomicInteger counter = new AtomicInteger(891);
StringBuilder sb = new StringBuilder();
predicts.stream().forEach(p -> sb.append(String.valueOf(counter.addAndGet(1) + "," + p.getOutput().toString().substring(1,2)) + "\n"));
try (FileWriter fw = new FileWriter("titanic/submission.csv");){
    fw.write("PassengerId,Survived\n");
    fw.write(sb.toString());
} catch (IOException ex) {
    ex.printStackTrace();
}

Puisque le deuxième caractère de p.getOutput (). ToString () est la valeur prédite (0 ou 1), c'est une mauvaise implémentation de la retirer et de l'écrire avec java.io.FileWriter. À propos, «compteur» est de 892 à 1309 de «Passenger Id» inclus dans le fichier soumis.

Maintenant que le fichier CSV à soumettre a été sorti, téléchargez-le sur le site de Kaggle et vérifiez le score.

Screenshot from 2020-10-02 17-27-35.png

Le score sur Kaggle était de 0,56220. C'est bas, mais pour le moment, j'ai pu le télécharger.

À propos de la suite

Cette fois, je n'ai essayé que le minimum de sortie d'un fichier CSV pour soumission. La prochaine fois, j'aimerais vérifier si les données peuvent être visualisées sur le Jupyter Notebook, le prétraitement, le réglage, l'assemblage, etc.

Supplément

Dans cette implémentation, le fichier CSV a été chargé en utilisant CSVLoader. J'ai pris pour acquis d'utiliser cette classe lors du chargement de fichiers CSV. Cependant, selon le développeur Tribuo, dans ce cas, il est en fait approprié d'utiliser «RowProcessor» et «CSVDataSource» au lieu de «CSVLoader». La mise en œuvre de ce qui suit donnera le même résultat que le code ci-dessus.

Tokenizer tokenizer = new BreakIteratorTokenizer(Locale.US);
LabelFactory labelFactory = new LabelFactory();
ResponseProcessor<Label> responseProcessor = new FieldResponseProcessor<>("Survived","0",labelFactory);
Map<String, FieldProcessor> fieldProcessors = new HashMap<>();
fieldProcessors.put("Pclass", new DoubleFieldProcessor("Pclass"));
fieldProcessors.put("SibSp", new DoubleFieldProcessor("SibSp"));
fieldProcessors.put("Parch", new DoubleFieldProcessor("Parch"));
fieldProcessors.put("Fare", new DoubleFieldProcessor("Fare"));

RowProcessor<Label> rp = new RowProcessor<>(responseProcessor,fieldProcessors);
Path path = Paths.get("titanic/train.csv");
CSVDataSource<Label> source = new CSVDataSource<>(path,rp,true);
TrainTestSplitter dataSplitter = new TrainTestSplitter<>(source, 0.7, 1L);

De plus, en utilisant ʻIdentity Processor` pour la chaîne de texte, il semble que l'encodage One-hot soit effectué automatiquement. Par exemple, si vous l'appliquez à la colonne "Sexe", qui a deux valeurs, "homme" et "femme",

fieldProcessors.put("Sex", new IdentityProcessor("Sex"));

Les colonnes «Sex @ male» et «Sex @ female» avec une valeur de 0/1 sont ajoutées (bien qu'une colonne soit suffisante car il y a deux choix). Cela augmente également la précision à "0,746".

De même, si vous l'appliquez à la colonne "Embarqué", qui a trois valeurs, "C", "S" et "K",

fieldProcessors.put("Embarked", new IdentityProcessor("Embarked"));

Je voudrais dire que la précision augmentera encore, mais en réalité elle diminuera. Cela est dû au fait que la colonne «Embarqué» n'a qu'une seule ligne avec une valeur vide, ce qui entraîne un décalage du nombre de colonnes de données d'entraînement et de données de vérification et ne fonctionne pas correctement. Je voudrais remplacer la valeur vide par la valeur la plus fréquente S, mais c'est un endroit douloureux qui ne peut pas être réalisé rapidement comme Python ... Je vais l'inclure à nouveau la prochaine fois.

Impressions / conclusions

Ma réponse à la question "Puis-je combattre dans Kaggle en utilisant uniquement Java?" Est "assez difficile à ce stade" (même si je connaissais la conclusion depuis le début). Depuis que Tribuo vient de sortir, il y a très peu d'informations sur Internet, et il n'y a pas assez de guides et de Javadocs, donc il faut beaucoup de temps pour enquêter. Et, en premier lieu, la fonction ne suffit pas. Il n'y a même pas de méthode qui remplit la valeur nulle dans les données avec une médiane.

La situation actuelle est que les choses qui peuvent être faites facilement avec Python ne peuvent pas être faites facilement, mais à l'inverse, il y a beaucoup de fonctions manquantes, donc je pense qu'il existe de nombreuses opportunités de contribuer (pull request). Si vous êtes un programmeur Java intéressé par l'apprentissage automatique, pourquoi ne contribuez-vous pas à ce projet en plus d'étudier? Je commence juste à y toucher, alors je vais enquêter un peu plus pour voir ce que je peux faire.

Recommended Posts

Osez défier Kaggle avec Java (1)
Défi pour gérer les caractères déformés avec Java AudioSystem.getMixerInfo ()
Connectez-vous à DB avec Java
Connectez-vous à MySQL 8 avec Java
Java pour apprendre avec les ramen [Partie 1]
[Java] Points à noter avec Arrays.asList ()
J'ai essayé d'interagir avec Java
Java, des tableaux pour débuter avec les débutants
Une histoire que j'ai eu du mal à défier le pro de la concurrence avec Java
Comment compiler Java avec VsCode & Ant
[Java] Résumez comment comparer avec la méthode equals
Si vous osez comparer Integer avec "==" ...
Introduction aux algorithmes avec java-Search (recherche prioritaire en profondeur)
Facile à parcourir avec les expressions régulières Java
Introduction aux algorithmes avec java --Search (recherche de priorité de largeur)
[Java] Introduction à Java
Introduction à Java
[Java] Comment tester s'il est nul dans JUnit
J'ai essayé de faire une authentification de base avec Java
Introduction aux algorithmes avec java --Search (bit full search)
Déployez des applications Web Java sur Azure avec maven
Comment utiliser le framework Java avec AWS Lambda! ??
Je veux utiliser java8 forEach avec index
Comment utiliser l'API Java avec des expressions lambda
Premiers pas avec Kotlin à envoyer aux développeurs Java
Essayez d'implémenter TCP / IP + NIO avec JAVA
[Java] Article pour ajouter une validation avec Spring Boot 2.3.1.
Facile à créer LINE BOT avec Java Servlet
J'ai essayé de casser le bloc avec java (1)
Installez java avec Homebrew
Comment appeler des fonctions en bloc avec la réflexion Java
Traitement des listes à comprendre avec des images - java8 stream / javaslang-
Soumettre une tâche à AWS Batch avec Java (Eclipse)
Changements de Java 8 à Java 11
Somme de Java_1 à 100
J'ai essayé d'implémenter TCP / IP + BIO avec JAVA
Changer de siège avec Java
Installez Java avec Ansible
[Java 11] J'ai essayé d'exécuter Java sans compiler avec javac
[Java] Comment omettre l'injection de constructeur de ressort avec Lombok
Comment déployer Java sur AWS Lambda avec Serverless Framework
[Java] Connectez-vous à MySQL
[Java] Comment chiffrer avec le chiffrement AES avec une bibliothèque standard
Téléchargement confortable avec JAVA
[Java] Se référer et définir des variables privées avec réflexion
Changer java avec direnv
Améliorations de Kotlin à Java
De Java à Ruby !!
Connexion HTTPS avec Java au serveur de certificats auto-signé
J'ai essayé d'implémenter Sterling Sort avec Java Collector
Téléchargement Java avec Ansible
Je veux faire des transitions d'écran avec kotlin et java!
Essayez de vous connecter à l'émulateur AzureCosmosDB pour Docker avec Java
Comment créer un environnement de développement Java avec VS Code
Raclons avec Java! !!
Introduction à la commande java
Construire Java avec Wercker
J'avais l'habitude de faire nc (netcat) avec JAVA normalement
[Java] Comment rompre une ligne avec StringBuilder
Conversion Endian avec JAVA