Ist es möglich, in Kaggle nur mit Java zu kämpfen? Ich habe es gewagt, es zu versuchen.
Ich möchte Titanic: Maschinelles Lernen aus einer Katastrophe, das als Herausforderung für Anfänger von Kaggle bekannt ist, nur in Java implementieren. .. Die Herausforderung besteht darin, das Überleben anhand des Namens, des Geschlechts, des Alters, der Ticketinformationen usw. des Kunden an Bord der Titanic vorherzusagen.
Es erstellt ein Modell, das aus den angegebenen Trainingsdaten ("train.csv") gelernt wurde, sagt das Überleben der in den Testdaten enthaltenen Person ("test.csv") voraus und konkurriert um die korrekte Antwortrate des Ergebnisses. Ist das ein Wettbewerb (obwohl ich denke, dass die Vorhersage von Leben oder Tod wie ein Spiel ist ...).
Lassen Sie uns zunächst die Überlebenden mit einer minimalen Implementierung vorhersagen.
Implementieren Sie dann Folgendes:
--Datenanalyse --Datenvorverarbeitung --Feature Quantity Engineering
Und was sind vorher überhaupt die Java-Bibliotheken für maschinelles Lernen? Ich denke, die folgenden sind berühmt.
Dieses Mal werde ich versuchen, Tribuo daraus zu verwenden.
Da es sich um Java handelt, erstellen wir vorerst mit IntelliJ ein Überlebensvorhersagemodell.
Erstellen Sie ein Maven-Projekt mit IntelliJ wie unten gezeigt
Fügen Sie Folgendes zu pom.xml
hinzu.
<dependencies>
<dependency>
<groupId>org.tribuo</groupId>
<artifactId>tribuo-all</artifactId>
<version>4.0.0</version>
<type>pom</type>
</dependency>
</dependencies>
Klicken Sie anschließend auf die Schaltfläche "Alle herunterladen" auf dieser Seite, um die für die Vorhersage erforderlichen Daten herunterzuladen, zu dekomprimieren und anschließend den Maven zu erstellen. Kopieren Sie es in Ihr Projektverzeichnis.
Die Verzeichnisstruktur ist wie folgt.
Lesen wir zuerst die CSV-Datei. Implementieren und führen Sie Folgendes aus:
LabelFactory labelFactory = new LabelFactory();
CSVLoader csvLoader = new CSVLoader<>(',',labelFactory);
ListDataSource dataource = csvLoader.loadDataSource(Paths.get("titanic/train.csv"),"Survived");
Aber die NumberFormatException
ist ...
Exception in thread "main" java.lang.NumberFormatException: For input string: "S"
at sun.misc.FloatingDecimal.readJavaFormatString(FloatingDecimal.java:2043)
at sun.misc.FloatingDecimal.parseDouble(FloatingDecimal.java:110)
at java.lang.Double.parseDouble(Double.java:538)
at org.tribuo.data.csv.CSVLoader.innerLoadFromCSV(CSVLoader.java:260)
at org.tribuo.data.csv.CSVLoader.loadDataSource(CSVLoader.java:238)
at org.tribuo.data.csv.CSVLoader.loadDataSource(CSVLoader.java:209)
at org.tribuo.data.csv.CSVLoader.loadDataSource(CSVLoader.java:184)
at org.tribuo.data.csv.CSVLoader.loadDataSource(CSVLoader.java:138)
at TitanicSurvivalClassifier.main(TitanicSurvivalClassifier.java:23)
Wenn ich den Quellcode von Tribuo lese, scheint es, dass dieses Verhalten mit der Implementierung geändert werden kann, die davon ausgeht, dass die CSV-Datei nur aus numerischen Werten besteht (immer "Double.parseDouble ()"). Da ist gar nichts. Vielleicht ist Tribuos aktuelle Designphilosophie, dass die Vorverarbeitung von Daten außerhalb des Verantwortungsbereichs liegt.
Sie sollten mindestens die nicht numerischen Spalten in der CSV-Datei löschen. Sie können CSV-Dateien mit Apache Commons CSV usw. betreiben. In Erwartung der Zukunft werden wir jedoch "DFLib" einführen, eine Bibliothek, die anscheinend in der Lage ist, Vorverarbeitung durchzuführen. DFLib ist eine leichtgewichtige Java-Implementierung von Pandas in Python, die Apache Commons CSV intern verwendet.
<dependency>
<groupId>com.nhl.dflib</groupId>
<artifactId>dflib-csv</artifactId>
<version>0.8</version>
</dependency>
Löschen Sie vor dem Laden der CSV-Datei mit "CSVLoader" die CSV-Spalten "Name", "Geschlecht", "Ticket", "Kabine" und "Eingeschifft" (eingrenzen auf die erforderlichen Spalten), wie unten gezeigt. , In CSV-Datei speichern.
DataFrame df = Csv.loader().load("titanic/train.csv");
DataFrame selectColumns = df.selectColumns("Survived", "Pclass", "Age", "SibSp", "Parch", "Fare");
Csv.save(selectColumns, "titanic/train_removed.csv");
LabelFactory labelFactory = new LabelFactory();
CSVLoader csvLoader = new CSVLoader<>(',',labelFactory);
ListDataSource dataource = csvLoader.loadDataSource(Paths.get("titanic/train_removed.csv"),"Survived");
Versuchen Sie es nochmal.
Exception in thread "main" java.lang.NumberFormatException: empty String
at sun.misc.FloatingDecimal.readJavaFormatString(FloatingDecimal.java:1842)
at sun.misc.FloatingDecimal.parseDouble(FloatingDecimal.java:110)
at java.lang.Double.parseDouble(Double.java:538)
at org.tribuo.data.csv.CSVLoader.innerLoadFromCSV(CSVLoader.java:260)
at org.tribuo.data.csv.CSVLoader.loadDataSource(CSVLoader.java:238)
at org.tribuo.data.csv.CSVLoader.loadDataSource(CSVLoader.java:209)
at org.tribuo.data.csv.CSVLoader.loadDataSource(CSVLoader.java:184)
at org.tribuo.data.csv.CSVLoader.loadDataSource(CSVLoader.java:138)
at TitanicSurvivalClassifier.main(TitanicSurvivalClassifier.java:31)
Aber nochmal "NumberFormatException". Es scheint, dass der fehlende Wert in der Spalte "Alter" als leeres Zeichen behandelt wird. Da DataFrame
eine Methode namens fillNulls ()
hat, können Sie null durch null usw. auf einmal ersetzen, aber der fehlende Lesewert scheint als leeres Zeichen (" "
) interpretiert zu werden, also die Absicht Es funktioniert nicht (obwohl möglicherweise ein Problem beim Laden der CSV-Datei vorliegt ...). Hier löschen wir also auch die Spalte "Alter".
DataFrame selectedDataFrame = df.selectColumns("Survived", "Pclass", "SibSp", "Parch", "Fare");
Diesmal hat es funktioniert. Jetzt bauen wir ein Modell und trainieren. Die Mindestimplementierung bei Verwendung der logistischen Regression lautet wie folgt:
TrainTestSplitter dataSplitter = new TrainTestSplitter<>(dataource, 0.7, 1L);
MutableDataset trainingDataset = new MutableDataset<>(dataSplitter.getTrain());
MutableDataset testingDataset = new MutableDataset<>(dataSplitter.getTest());
Trainer<Label> trainer = new LogisticRegressionTrainer();
Model<Label> irisModel = trainer.train(trainingDataset);
LabelEvaluator evaluator = new LabelEvaluator();
LabelEvaluation evaluation = evaluator.evaluate(irisModel, testingDataset);
System.out.println(evaluation);
Ich habe endlich das Ergebnis bekommen. Von den 268 Verifizierungsdaten sind 163 korrekt, mit einer korrekten Antwortrate von 60,8%.
Class n tp fn fp recall prec f1
0 170 89 81 24 0.524 0.788 0.629
1 98 74 24 81 0.755 0.477 0.585
Total 268 163 105 105
Accuracy 0.608
Micro Average 0.608 0.608 0.608
Macro Average 0.639 0.633 0.607
Balanced Error Rate 0.361
Nachdem das Modell erstellt wurde, lesen wir die Testdaten und machen eine Vorhersage. Lassen Sie uns Folgendes implementieren und überprüfen.
DataFrame dfTest = Csv.loader().load("titanic/test.csv");
DataFrame selectedDfTest = dfTest.selectColumns("Pclass", "SibSp", "Parch", "Fare");
Csv.save(selectedDfTest, "titanic/test_removed.csv");
ListDataSource dataource4test = csvLoader.loadDataSource(Paths.get("titanic/test_removed.csv"),"Survived");
List<Prediction> predicts = model.predict(dataource4test);
System.out.println(predicts);
"CsvLoader.loadDataSource ()" scheint jedoch den Namen der Zielvariablen im zweiten Argument zu erfordern, und ich habe "Survived" übergeben, aber ich erhalte eine Fehlermeldung, wenn test.csv dieses "Survived" nicht hat. Ich habe.
Exception in thread "main" java.lang.IllegalArgumentException: Response Survived not found in file file:/home/tamura/git/tribuo-examples/titanic/test_removed.csv
at org.tribuo.data.csv.CSVLoader.validateResponseNames(CSVLoader.java:286)
at org.tribuo.data.csv.CSVLoader.innerLoadFromCSV(CSVLoader.java:244)
at org.tribuo.data.csv.CSVLoader.loadDataSource(CSVLoader.java:238)
at org.tribuo.data.csv.CSVLoader.loadDataSource(CSVLoader.java:209)
at org.tribuo.data.csv.CSVLoader.loadDataSource(CSVLoader.java:184)
at org.tribuo.data.csv.CSVLoader.loadDataSource(CSVLoader.java:138)
at TitanicSurvivalClassifier.main(TitanicSurvivalClassifier.java:74)
Während Sie sich über "Warum sollte die Zielvariable in die CSV-Datei aufgenommen werden" beschweren, gibt es keine Hilfe dafür. Fügen Sie daher die Spalte "Überlebt" in der CSV-Datei zur Übermittlung (gender_submission.csv) zu "DataFrame" hinzu. Ich beschloss zu täuschen.
DataFrame dfTest = Csv.loader().load("titanic/test.csv");
DataFrame dfSubmission = Csv.loader().load("titanic/gender_submission.csv");
DataFrame selectedDfTest = dfTest.selectColumns("Pclass", "SibSp", "Parch", "Fare");
selectedDfTest = selectedDfTest.hConcat(dfSubmission.dropColumns("PassengerId"));
Csv.save(selectedDfTest, "titanic/test_removed.csv");
ListDataSource dataource4test = csvLoader.loadDataSource(Paths.get("titanic/test_removed.csv"),"Survived");
List<Prediction> predicts = model.predict(dataource4test);
System.out.println(predicts);
Wenn ich diesmal nachdenklich bin, ist die bekannte Ausnahme ...
Exception in thread "main" java.lang.NumberFormatException: empty String
at sun.misc.FloatingDecimal.readJavaFormatString(FloatingDecimal.java:1842)
at sun.misc.FloatingDecimal.parseDouble(FloatingDecimal.java:110)
at java.lang.Double.parseDouble(Double.java:538)
at org.tribuo.data.csv.CSVLoader.innerLoadFromCSV(CSVLoader.java:260)
at org.tribuo.data.csv.CSVLoader.loadDataSource(CSVLoader.java:238)
at org.tribuo.data.csv.CSVLoader.loadDataSource(CSVLoader.java:209)
at org.tribuo.data.csv.CSVLoader.loadDataSource(CSVLoader.java:184)
at org.tribuo.data.csv.CSVLoader.loadDataSource(CSVLoader.java:138)
at TitanicSurvivalClassifier.main(TitanicSurvivalClassifier.java:73)
In test.csv fehlt nur ein Wert in der Spalte "Tarif" ... Dieser fehlende Wert ist ebenfalls ein leeres Zeichen ("" ") und nicht null, sodass" DataFrame.fillNulls () "nicht verwendet werden kann. .. Als Ergebnis der Prüfung des Quellcodes stellte ich fest, dass es möglich ist, leere Zeichen durch den folgenden Schreibstil zu ersetzen.
selectedDfTest = selectedDfTest.convertColumn("Fare", s -> "".equals(s) ? "0": s);
Das vorhergesagte Ergebnis wird jetzt ausgegeben.
[Prediction(maxLabel=(0,0.5474041397777752),outputScores={0=(0,0.5474041397777752)1=(1,0.4525958602222247}), Prediction(maxLabel=(0,0.6969779586356148),outputScores={0=(0,0.6969779586356148)1=(1,0.303022041364385}), Prediction(maxLabel=(1,0.5302004352989867),outputScores={0=(0,0.46979956470101314)1=(1,0.530200435298986}), Prediction(maxLabel=(0,0.52713643586377),outputScores={0=(0,0.52713643586377)1=(1,0.4728635641362}), Prediction(maxLabel=(0,0.5071805368465395),outputScores={0=(0,0.5071805368465395)1=(1,0.492819463153460}), Prediction(maxLabel=(0,0.5134002908191431),outputScores={0=(0,0.5134002908191431)1=(1,0.4865997091808569}),
・ ・ ・
Sie müssen es lediglich zur Übermittlung in einer CSV-Datei speichern. Es ist einfach mit Python zu implementieren, aber es ist nicht überraschend einfach mit DFLib oder Tribuo zu implementieren, und ich habe mir die Dokumentation und den Quellcode eine Weile angesehen, aber es kann einige Zeit dauern, also der Java-Standard Ich habe es wie folgt in die API implementiert.
AtomicInteger counter = new AtomicInteger(891);
StringBuilder sb = new StringBuilder();
predicts.stream().forEach(p -> sb.append(String.valueOf(counter.addAndGet(1) + "," + p.getOutput().toString().substring(1,2)) + "\n"));
try (FileWriter fw = new FileWriter("titanic/submission.csv");){
fw.write("PassengerId,Survived\n");
fw.write(sb.toString());
} catch (IOException ex) {
ex.printStackTrace();
}
Da das zweite Zeichen von "p.getOutput (). ToString ()" der vorhergesagte Wert (0 oder 1) ist, ist es eine schlechte Implementierung, es herauszunehmen und mit "java.io.FileWriter" zu schreiben. Übrigens ist "Zähler" 892 bis 1309 der "Passagier-ID", die in der eingereichten Datei enthalten ist.
Nachdem die CSV-Datei zur Übermittlung ausgegeben wurde, laden Sie sie auf die Website von Kaggle hoch und überprüfen Sie die Punktzahl.
Die Punktzahl auf Kaggle betrug 0,56220. Es ist niedrig, aber vorerst konnte ich es hochladen.
Dieses Mal habe ich nur das Minimum versucht, eine CSV-Datei zur Übermittlung auszugeben. Beim nächsten Mal möchte ich überprüfen, ob Daten auf dem Jupyter-Notizbuch visualisiert, vorverarbeitet, optimiert, zusammengesetzt usw. werden können.
In dieser Implementierung wurde die CSV-Datei mit "CSVLoader" geladen. Ich habe es als selbstverständlich angesehen, diese Klasse beim Laden von CSV-Dateien zu verwenden. Laut dem Tribuo-Entwickler ist es in diesem Fall jedoch angemessen, "RowProcessor" und "CSVDataSource" anstelle von "CSVLoader" zu verwenden. Wenn Sie Folgendes implementieren, erhalten Sie das gleiche Ergebnis wie im obigen Code.
Tokenizer tokenizer = new BreakIteratorTokenizer(Locale.US);
LabelFactory labelFactory = new LabelFactory();
ResponseProcessor<Label> responseProcessor = new FieldResponseProcessor<>("Survived","0",labelFactory);
Map<String, FieldProcessor> fieldProcessors = new HashMap<>();
fieldProcessors.put("Pclass", new DoubleFieldProcessor("Pclass"));
fieldProcessors.put("SibSp", new DoubleFieldProcessor("SibSp"));
fieldProcessors.put("Parch", new DoubleFieldProcessor("Parch"));
fieldProcessors.put("Fare", new DoubleFieldProcessor("Fare"));
RowProcessor<Label> rp = new RowProcessor<>(responseProcessor,fieldProcessors);
Path path = Paths.get("titanic/train.csv");
CSVDataSource<Label> source = new CSVDataSource<>(path,rp,true);
TrainTestSplitter dataSplitter = new TrainTestSplitter<>(source, 0.7, 1L);
Durch die Verwendung von "Identity Processor" für die Textzeichenfolge scheint es außerdem, dass die One-Hot-Codierung automatisch durchgeführt wird. Wenn Sie es beispielsweise auf die Spalte "Geschlecht" anwenden, die zwei Werte hat: "männlich" und "weiblich",
fieldProcessors.put("Sex", new IdentityProcessor("Sex"));
Die Spalten "Sex @ männlich" und "Sex @ weiblich" mit einem Wert von 0/1 werden hinzugefügt (obwohl eine Spalte ausreicht, da zwei Auswahlmöglichkeiten bestehen). Dies erhöht auch die Genauigkeit auf "0,746".
Wenn Sie es auf die Spalte "Eingeschifft" anwenden, die drei Werte hat: "C", "S" und "K",
fieldProcessors.put("Embarked", new IdentityProcessor("Embarked"));
Ich möchte sagen, dass die Genauigkeit weiter zunehmen wird, aber in Wirklichkeit wird sie abnehmen. Dies liegt daran, dass die Spalte "Eingeschifft" nur eine Zeile mit einem leeren Wert enthält, was dazu führt, dass sich die Anzahl der Spalten mit Trainingsdaten und Verifizierungsdaten verschiebt und nicht ordnungsgemäß funktioniert. Ich möchte den leeren Wert durch den häufigsten Wert "S" ersetzen, aber es ist ein schmerzhafter Ort, der wie Python nicht schnell realisiert werden kann ... das werde ich beim nächsten Mal wieder aufnehmen.
Meine Antwort auf die Frage "Kann ich in Kaggle nur mit Java kämpfen?" Ist "zu diesem Zeitpunkt ziemlich schwierig" (obwohl ich die Schlussfolgerung von Anfang an kannte). Da Tribuo gerade veröffentlicht wurde, gibt es im Internet nur sehr wenige Informationen, und es gibt nicht genügend Anleitungen und Javadocs. Daher dauert die Untersuchung sehr lange. Und erstens reicht die Funktion nicht aus. Es gibt nicht einmal eine Methode, die den Nullwert in den Daten mit einem Median füllt.
Die aktuelle Situation ist, dass Dinge, die mit Python einfach erledigt werden können, nicht einfach erledigt werden können, aber umgekehrt fehlen viele Funktionen, sodass ich denke, dass es viele Möglichkeiten gibt, einen Beitrag zu leisten (Pull-Anfrage). Wenn Sie ein Java-Programmierer sind, der sich für maschinelles Lernen interessiert, warum tragen Sie nicht zu diesem Projekt bei und studieren? Ich fange gerade an, es zu berühren, also werde ich etwas mehr untersuchen, um zu sehen, wie viel ich tun kann.
Recommended Posts