Wagen Sie es, Kaggle mit Java herauszufordern (1)

Überblick

Ist es möglich, in Kaggle nur mit Java zu kämpfen? Ich habe es gewagt, es zu versuchen.

Was ist in diesem Artikel zu tun?

Ich möchte Titanic: Maschinelles Lernen aus einer Katastrophe, das als Herausforderung für Anfänger von Kaggle bekannt ist, nur in Java implementieren. .. Die Herausforderung besteht darin, das Überleben anhand des Namens, des Geschlechts, des Alters, der Ticketinformationen usw. des Kunden an Bord der Titanic vorherzusagen.

titanic.png

Es erstellt ein Modell, das aus den angegebenen Trainingsdaten ("train.csv") gelernt wurde, sagt das Überleben der in den Testdaten enthaltenen Person ("test.csv") voraus und konkurriert um die korrekte Antwortrate des Ergebnisses. Ist das ein Wettbewerb (obwohl ich denke, dass die Vorhersage von Leben oder Tod wie ein Spiel ist ...).

Lassen Sie uns zunächst die Überlebenden mit einer minimalen Implementierung vorhersagen.

Implementieren Sie dann Folgendes:

--Datenanalyse --Datenvorverarbeitung --Feature Quantity Engineering

Java-Bibliothek für maschinelles Lernen

Und was sind vorher überhaupt die Java-Bibliotheken für maschinelles Lernen? Ich denke, die folgenden sind berühmt.

Dieses Mal werde ich versuchen, Tribuo daraus zu verwenden.

Prognostizieren Sie Überlebende mit minimaler Implementierung

Aufbau einer Entwicklungsumgebung

Da es sich um Java handelt, erstellen wir vorerst mit IntelliJ ein Überlebensvorhersagemodell.

Erstellen Sie ein Maven-Projekt mit IntelliJ wie unten gezeigt

Screenshot from 2020-09-23 22-04-30.png

Fügen Sie Folgendes zu pom.xml hinzu.

<dependencies>
    <dependency>
        <groupId>org.tribuo</groupId>
        <artifactId>tribuo-all</artifactId>
        <version>4.0.0</version>
        <type>pom</type>
    </dependency>
</dependencies>

Klicken Sie anschließend auf die Schaltfläche "Alle herunterladen" auf dieser Seite, um die für die Vorhersage erforderlichen Daten herunterzuladen, zu dekomprimieren und anschließend den Maven zu erstellen. Kopieren Sie es in Ihr Projektverzeichnis.

Die Verzeichnisstruktur ist wie folgt.

Screenshot from 2020-10-07 22-04-01.png

Lesen von Bibliotheksdaten

Lesen wir zuerst die CSV-Datei. Implementieren und führen Sie Folgendes aus:

LabelFactory labelFactory = new LabelFactory();
CSVLoader csvLoader = new CSVLoader<>(',',labelFactory);
ListDataSource dataource = csvLoader.loadDataSource(Paths.get("titanic/train.csv"),"Survived");

Aber die NumberFormatException ist ...

Exception in thread "main" java.lang.NumberFormatException: For input string: "S"
	at sun.misc.FloatingDecimal.readJavaFormatString(FloatingDecimal.java:2043)
	at sun.misc.FloatingDecimal.parseDouble(FloatingDecimal.java:110)
	at java.lang.Double.parseDouble(Double.java:538)
	at org.tribuo.data.csv.CSVLoader.innerLoadFromCSV(CSVLoader.java:260)
	at org.tribuo.data.csv.CSVLoader.loadDataSource(CSVLoader.java:238)
	at org.tribuo.data.csv.CSVLoader.loadDataSource(CSVLoader.java:209)
	at org.tribuo.data.csv.CSVLoader.loadDataSource(CSVLoader.java:184)
	at org.tribuo.data.csv.CSVLoader.loadDataSource(CSVLoader.java:138)
	at TitanicSurvivalClassifier.main(TitanicSurvivalClassifier.java:23)

Wenn ich den Quellcode von Tribuo lese, scheint es, dass dieses Verhalten mit der Implementierung geändert werden kann, die davon ausgeht, dass die CSV-Datei nur aus numerischen Werten besteht (immer "Double.parseDouble ()"). Da ist gar nichts. Vielleicht ist Tribuos aktuelle Designphilosophie, dass die Vorverarbeitung von Daten außerhalb des Verantwortungsbereichs liegt.

Sie sollten mindestens die nicht numerischen Spalten in der CSV-Datei löschen. Sie können CSV-Dateien mit Apache Commons CSV usw. betreiben. In Erwartung der Zukunft werden wir jedoch "DFLib" einführen, eine Bibliothek, die anscheinend in der Lage ist, Vorverarbeitung durchzuführen. DFLib ist eine leichtgewichtige Java-Implementierung von Pandas in Python, die Apache Commons CSV intern verwendet.

<dependency>
    <groupId>com.nhl.dflib</groupId>
    <artifactId>dflib-csv</artifactId>
    <version>0.8</version>
</dependency>

Löschen Sie vor dem Laden der CSV-Datei mit "CSVLoader" die CSV-Spalten "Name", "Geschlecht", "Ticket", "Kabine" und "Eingeschifft" (eingrenzen auf die erforderlichen Spalten), wie unten gezeigt. , In CSV-Datei speichern.

DataFrame df = Csv.loader().load("titanic/train.csv");
DataFrame selectColumns = df.selectColumns("Survived", "Pclass", "Age", "SibSp", "Parch", "Fare");
Csv.save(selectColumns, "titanic/train_removed.csv");

LabelFactory labelFactory = new LabelFactory();
CSVLoader csvLoader = new CSVLoader<>(',',labelFactory);
ListDataSource dataource = csvLoader.loadDataSource(Paths.get("titanic/train_removed.csv"),"Survived");

Versuchen Sie es nochmal.

Exception in thread "main" java.lang.NumberFormatException: empty String
	at sun.misc.FloatingDecimal.readJavaFormatString(FloatingDecimal.java:1842)
	at sun.misc.FloatingDecimal.parseDouble(FloatingDecimal.java:110)
	at java.lang.Double.parseDouble(Double.java:538)
	at org.tribuo.data.csv.CSVLoader.innerLoadFromCSV(CSVLoader.java:260)
	at org.tribuo.data.csv.CSVLoader.loadDataSource(CSVLoader.java:238)
	at org.tribuo.data.csv.CSVLoader.loadDataSource(CSVLoader.java:209)
	at org.tribuo.data.csv.CSVLoader.loadDataSource(CSVLoader.java:184)
	at org.tribuo.data.csv.CSVLoader.loadDataSource(CSVLoader.java:138)
	at TitanicSurvivalClassifier.main(TitanicSurvivalClassifier.java:31)

Aber nochmal "NumberFormatException". Es scheint, dass der fehlende Wert in der Spalte "Alter" als leeres Zeichen behandelt wird. Da DataFrame eine Methode namens fillNulls () hat, können Sie null durch null usw. auf einmal ersetzen, aber der fehlende Lesewert scheint als leeres Zeichen (" ") interpretiert zu werden, also die Absicht Es funktioniert nicht (obwohl möglicherweise ein Problem beim Laden der CSV-Datei vorliegt ...). Hier löschen wir also auch die Spalte "Alter".

DataFrame selectedDataFrame = df.selectColumns("Survived", "Pclass", "SibSp", "Parch", "Fare");

Erstellen eines Basismodells

Diesmal hat es funktioniert. Jetzt bauen wir ein Modell und trainieren. Die Mindestimplementierung bei Verwendung der logistischen Regression lautet wie folgt:

TrainTestSplitter dataSplitter = new TrainTestSplitter<>(dataource, 0.7, 1L);
MutableDataset trainingDataset = new MutableDataset<>(dataSplitter.getTrain());
MutableDataset testingDataset = new MutableDataset<>(dataSplitter.getTest());

Trainer<Label> trainer = new LogisticRegressionTrainer();
Model<Label> irisModel = trainer.train(trainingDataset);

LabelEvaluator evaluator = new LabelEvaluator();
LabelEvaluation evaluation = evaluator.evaluate(irisModel, testingDataset);
System.out.println(evaluation);

Ich habe endlich das Ergebnis bekommen. Von den 268 Verifizierungsdaten sind 163 korrekt, mit einer korrekten Antwortrate von 60,8%.

Class                           n          tp          fn          fp      recall        prec          f1
0                             170          89          81          24       0.524       0.788       0.629
1                              98          74          24          81       0.755       0.477       0.585
Total                         268         163         105         105
Accuracy                                                                    0.608
Micro Average                                                               0.608       0.608       0.608
Macro Average                                                               0.639       0.633       0.607
Balanced Error Rate                                                         0.361

Ausgabe / Einreichung der Prognose

Nachdem das Modell erstellt wurde, lesen wir die Testdaten und machen eine Vorhersage. Lassen Sie uns Folgendes implementieren und überprüfen.

DataFrame dfTest = Csv.loader().load("titanic/test.csv");
DataFrame selectedDfTest = dfTest.selectColumns("Pclass", "SibSp", "Parch", "Fare");
Csv.save(selectedDfTest, "titanic/test_removed.csv");

ListDataSource dataource4test = csvLoader.loadDataSource(Paths.get("titanic/test_removed.csv"),"Survived");
List<Prediction> predicts = model.predict(dataource4test);
System.out.println(predicts);

"CsvLoader.loadDataSource ()" scheint jedoch den Namen der Zielvariablen im zweiten Argument zu erfordern, und ich habe "Survived" übergeben, aber ich erhalte eine Fehlermeldung, wenn test.csv dieses "Survived" nicht hat. Ich habe.

Exception in thread "main" java.lang.IllegalArgumentException: Response Survived not found in file file:/home/tamura/git/tribuo-examples/titanic/test_removed.csv
	at org.tribuo.data.csv.CSVLoader.validateResponseNames(CSVLoader.java:286)
	at org.tribuo.data.csv.CSVLoader.innerLoadFromCSV(CSVLoader.java:244)
	at org.tribuo.data.csv.CSVLoader.loadDataSource(CSVLoader.java:238)
	at org.tribuo.data.csv.CSVLoader.loadDataSource(CSVLoader.java:209)
	at org.tribuo.data.csv.CSVLoader.loadDataSource(CSVLoader.java:184)
	at org.tribuo.data.csv.CSVLoader.loadDataSource(CSVLoader.java:138)
	at TitanicSurvivalClassifier.main(TitanicSurvivalClassifier.java:74)

Während Sie sich über "Warum sollte die Zielvariable in die CSV-Datei aufgenommen werden" beschweren, gibt es keine Hilfe dafür. Fügen Sie daher die Spalte "Überlebt" in der CSV-Datei zur Übermittlung (gender_submission.csv) zu "DataFrame" hinzu. Ich beschloss zu täuschen.

DataFrame dfTest = Csv.loader().load("titanic/test.csv");
DataFrame dfSubmission = Csv.loader().load("titanic/gender_submission.csv");
DataFrame selectedDfTest = dfTest.selectColumns("Pclass", "SibSp", "Parch", "Fare");
selectedDfTest = selectedDfTest.hConcat(dfSubmission.dropColumns("PassengerId"));
Csv.save(selectedDfTest, "titanic/test_removed.csv");

ListDataSource dataource4test = csvLoader.loadDataSource(Paths.get("titanic/test_removed.csv"),"Survived");
List<Prediction> predicts = model.predict(dataource4test);
System.out.println(predicts);

Wenn ich diesmal nachdenklich bin, ist die bekannte Ausnahme ...

Exception in thread "main" java.lang.NumberFormatException: empty String
	at sun.misc.FloatingDecimal.readJavaFormatString(FloatingDecimal.java:1842)
	at sun.misc.FloatingDecimal.parseDouble(FloatingDecimal.java:110)
	at java.lang.Double.parseDouble(Double.java:538)
	at org.tribuo.data.csv.CSVLoader.innerLoadFromCSV(CSVLoader.java:260)
	at org.tribuo.data.csv.CSVLoader.loadDataSource(CSVLoader.java:238)
	at org.tribuo.data.csv.CSVLoader.loadDataSource(CSVLoader.java:209)
	at org.tribuo.data.csv.CSVLoader.loadDataSource(CSVLoader.java:184)
	at org.tribuo.data.csv.CSVLoader.loadDataSource(CSVLoader.java:138)
	at TitanicSurvivalClassifier.main(TitanicSurvivalClassifier.java:73)

In test.csv fehlt nur ein Wert in der Spalte "Tarif" ... Dieser fehlende Wert ist ebenfalls ein leeres Zeichen ("" ") und nicht null, sodass" DataFrame.fillNulls () "nicht verwendet werden kann. .. Als Ergebnis der Prüfung des Quellcodes stellte ich fest, dass es möglich ist, leere Zeichen durch den folgenden Schreibstil zu ersetzen.

selectedDfTest = selectedDfTest.convertColumn("Fare", s -> "".equals(s) ? "0": s);

Das vorhergesagte Ergebnis wird jetzt ausgegeben.

[Prediction(maxLabel=(0,0.5474041397777752),outputScores={0=(0,0.5474041397777752)1=(1,0.4525958602222247}), Prediction(maxLabel=(0,0.6969779586356148),outputScores={0=(0,0.6969779586356148)1=(1,0.303022041364385}), Prediction(maxLabel=(1,0.5302004352989867),outputScores={0=(0,0.46979956470101314)1=(1,0.530200435298986}), Prediction(maxLabel=(0,0.52713643586377),outputScores={0=(0,0.52713643586377)1=(1,0.4728635641362}), Prediction(maxLabel=(0,0.5071805368465395),outputScores={0=(0,0.5071805368465395)1=(1,0.492819463153460}), Prediction(maxLabel=(0,0.5134002908191431),outputScores={0=(0,0.5134002908191431)1=(1,0.4865997091808569}),
・ ・ ・

Sie müssen es lediglich zur Übermittlung in einer CSV-Datei speichern. Es ist einfach mit Python zu implementieren, aber es ist nicht überraschend einfach mit DFLib oder Tribuo zu implementieren, und ich habe mir die Dokumentation und den Quellcode eine Weile angesehen, aber es kann einige Zeit dauern, also der Java-Standard Ich habe es wie folgt in die API implementiert.

AtomicInteger counter = new AtomicInteger(891);
StringBuilder sb = new StringBuilder();
predicts.stream().forEach(p -> sb.append(String.valueOf(counter.addAndGet(1) + "," + p.getOutput().toString().substring(1,2)) + "\n"));
try (FileWriter fw = new FileWriter("titanic/submission.csv");){
    fw.write("PassengerId,Survived\n");
    fw.write(sb.toString());
} catch (IOException ex) {
    ex.printStackTrace();
}

Da das zweite Zeichen von "p.getOutput (). ToString ()" der vorhergesagte Wert (0 oder 1) ist, ist es eine schlechte Implementierung, es herauszunehmen und mit "java.io.FileWriter" zu schreiben. Übrigens ist "Zähler" 892 bis 1309 der "Passagier-ID", die in der eingereichten Datei enthalten ist.

Nachdem die CSV-Datei zur Übermittlung ausgegeben wurde, laden Sie sie auf die Website von Kaggle hoch und überprüfen Sie die Punktzahl.

Screenshot from 2020-10-02 17-27-35.png

Die Punktzahl auf Kaggle betrug 0,56220. Es ist niedrig, aber vorerst konnte ich es hochladen.

Über die Fortsetzung

Dieses Mal habe ich nur das Minimum versucht, eine CSV-Datei zur Übermittlung auszugeben. Beim nächsten Mal möchte ich überprüfen, ob Daten auf dem Jupyter-Notizbuch visualisiert, vorverarbeitet, optimiert, zusammengesetzt usw. werden können.

Ergänzung

In dieser Implementierung wurde die CSV-Datei mit "CSVLoader" geladen. Ich habe es als selbstverständlich angesehen, diese Klasse beim Laden von CSV-Dateien zu verwenden. Laut dem Tribuo-Entwickler ist es in diesem Fall jedoch angemessen, "RowProcessor" und "CSVDataSource" anstelle von "CSVLoader" zu verwenden. Wenn Sie Folgendes implementieren, erhalten Sie das gleiche Ergebnis wie im obigen Code.

Tokenizer tokenizer = new BreakIteratorTokenizer(Locale.US);
LabelFactory labelFactory = new LabelFactory();
ResponseProcessor<Label> responseProcessor = new FieldResponseProcessor<>("Survived","0",labelFactory);
Map<String, FieldProcessor> fieldProcessors = new HashMap<>();
fieldProcessors.put("Pclass", new DoubleFieldProcessor("Pclass"));
fieldProcessors.put("SibSp", new DoubleFieldProcessor("SibSp"));
fieldProcessors.put("Parch", new DoubleFieldProcessor("Parch"));
fieldProcessors.put("Fare", new DoubleFieldProcessor("Fare"));

RowProcessor<Label> rp = new RowProcessor<>(responseProcessor,fieldProcessors);
Path path = Paths.get("titanic/train.csv");
CSVDataSource<Label> source = new CSVDataSource<>(path,rp,true);
TrainTestSplitter dataSplitter = new TrainTestSplitter<>(source, 0.7, 1L);

Durch die Verwendung von "Identity Processor" für die Textzeichenfolge scheint es außerdem, dass die One-Hot-Codierung automatisch durchgeführt wird. Wenn Sie es beispielsweise auf die Spalte "Geschlecht" anwenden, die zwei Werte hat: "männlich" und "weiblich",

fieldProcessors.put("Sex", new IdentityProcessor("Sex"));

Die Spalten "Sex @ männlich" und "Sex @ weiblich" mit einem Wert von 0/1 werden hinzugefügt (obwohl eine Spalte ausreicht, da zwei Auswahlmöglichkeiten bestehen). Dies erhöht auch die Genauigkeit auf "0,746".

Wenn Sie es auf die Spalte "Eingeschifft" anwenden, die drei Werte hat: "C", "S" und "K",

fieldProcessors.put("Embarked", new IdentityProcessor("Embarked"));

Ich möchte sagen, dass die Genauigkeit weiter zunehmen wird, aber in Wirklichkeit wird sie abnehmen. Dies liegt daran, dass die Spalte "Eingeschifft" nur eine Zeile mit einem leeren Wert enthält, was dazu führt, dass sich die Anzahl der Spalten mit Trainingsdaten und Verifizierungsdaten verschiebt und nicht ordnungsgemäß funktioniert. Ich möchte den leeren Wert durch den häufigsten Wert "S" ersetzen, aber es ist ein schmerzhafter Ort, der wie Python nicht schnell realisiert werden kann ... das werde ich beim nächsten Mal wieder aufnehmen.

Eindrücke / Schlussfolgerungen

Meine Antwort auf die Frage "Kann ich in Kaggle nur mit Java kämpfen?" Ist "zu diesem Zeitpunkt ziemlich schwierig" (obwohl ich die Schlussfolgerung von Anfang an kannte). Da Tribuo gerade veröffentlicht wurde, gibt es im Internet nur sehr wenige Informationen, und es gibt nicht genügend Anleitungen und Javadocs. Daher dauert die Untersuchung sehr lange. Und erstens reicht die Funktion nicht aus. Es gibt nicht einmal eine Methode, die den Nullwert in den Daten mit einem Median füllt.

Die aktuelle Situation ist, dass Dinge, die mit Python einfach erledigt werden können, nicht einfach erledigt werden können, aber umgekehrt fehlen viele Funktionen, sodass ich denke, dass es viele Möglichkeiten gibt, einen Beitrag zu leisten (Pull-Anfrage). Wenn Sie ein Java-Programmierer sind, der sich für maschinelles Lernen interessiert, warum tragen Sie nicht zu diesem Projekt bei und studieren? Ich fange gerade an, es zu berühren, also werde ich etwas mehr untersuchen, um zu sehen, wie viel ich tun kann.

Recommended Posts

Wagen Sie es, Kaggle mit Java herauszufordern (1)
Herausforderung, mit verstümmelten Zeichen mit Java AudioSystem.getMixerInfo () umzugehen
Stellen Sie mit Java eine Verbindung zur Datenbank her
Stellen Sie mit Java eine Verbindung zu MySQL 8 her
Java mit Ramen lernen [Teil 1]
[Java] Mit Arrays.asList () zu beachtende Punkte
Ich habe versucht, mit Java zu interagieren
Java, Arrays für Anfänger
Eine Geschichte, die ich mit Java nur schwer herausfordern konnte
So kompilieren Sie Java mit VsCode & Ant
[Java] Fassen Sie zusammen, wie Sie mit der Methode equals vergleichen können
Wenn Sie es wagen, Integer mit "==" zu vergleichen ...
Einführung in Algorithmen mit Java-Suche (Tiefenprioritätssuche)
Einfach mit regulären Java-Ausdrücken zu stolpern
Einführung in Algorithmen mit Java --Search (Breitenprioritätssuche)
[Java] Einführung in Java
Einführung in Java
[Java] So testen Sie, ob es in JUnit null ist
Ich habe versucht, eine Standardauthentifizierung mit Java durchzuführen
Einführung in Algorithmen mit Java --Search (Bit Full Search)
Stellen Sie Java-Webanwendungen mit maven in Azure bereit
Verwendung des Java-Frameworks mit AWS Lambda! ??
Ich möchte Java8 für jeden mit Index verwenden
Verwendung der Java-API mit Lambda-Ausdrücken
Erste Schritte mit Kotlin zum Senden an Java-Entwickler
Versuchen Sie, TCP / IP + NIO mit JAVA zu implementieren
[Java] Artikel zum Hinzufügen einer Validierung mit Spring Boot 2.3.1.
Einfacher LINE BOT mit Java Servlet
Ich habe versucht, den Block mit Java zu brechen (1)
Installieren Sie Java mit Homebrew
Aufrufen von Funktionen in großen Mengen mit Java Reflection
Listenverarbeitung zum Verstehen mit Bildern --java8 stream / javaslang-
Senden Sie einen Job an AWS Batch mit Java (Eclipse)
Änderungen von Java 8 zu Java 11
Summe von Java_1 bis 100
Ich habe versucht, TCP / IP + BIO mit JAVA zu implementieren
Wechseln Sie die Plätze mit Java
Installieren Sie Java mit Ansible
[Java 11] Ich habe versucht, Java auszuführen, ohne mit Javac zu kompilieren
[Java] So lassen Sie die Federkonstruktorinjektion mit Lombok weg
So stellen Sie Java mit Serverless Framework für AWS Lambda bereit
[Java] Stellen Sie eine Verbindung zu MySQL her
[Java] Verschlüsselung mit AES-Verschlüsselung mit Standardbibliothek
Bequemer Download mit JAVA
[Java] Verweisen Sie auf und setzen Sie private Variablen mit Reflektion
Schalten Sie Java mit direnv
Kotlins Verbesserungen an Java
Von Java zu Ruby !!
HTTPS-Verbindung mit Java zum selbstsignierten Zertifikatsserver
Ich habe versucht, Sterling Sort mit Java Collector zu implementieren
Java-Download mit Ansible
Ich möchte Bildschirmübergänge mit Kotlin und Java machen!
Versuchen Sie, mit Java eine Verbindung zu AzureCosmosDB Emulator for Docker herzustellen
So erstellen Sie eine Java-Entwicklungsumgebung mit VS Code
Lass uns mit Java kratzen! !!
Einführung in den Java-Befehl
Erstellen Sie Java mit Wercker
Ich habe nc (netcat) normalerweise mit JAVA gemacht
[Java] So unterbrechen Sie eine Zeile mit StringBuilder
Endian-Konvertierung mit JAVA