[JAVA] Try using GloVe with Deeplearning4j

I'm going to try GloVe word embeddings with Deeplearning4j (DL4J), a deep learning library for Java.

Prerequisites

Corpus

Prepare the corpus you want to train on in advance. For a Japanese corpus, segment the text beforehand into words separated by spaces (wakati-gaki), because the tokenizer used below splits words on whitespace. When segmenting, it may help to convert verbs to their base (dictionary) form.
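The sketch below is a rough, plain-Java approximation of how DefaultTokenizerFactory combined with CommonPreprocessor turns one line into tokens. As far as I know, DL4J splits on whitespace using java.util.StringTokenizer, then lowercases each token and strips digits and punctuation. This is not DL4J's actual code. The class name and the exact punctuation pattern are my own assumptions, and so is dropping tokens that end up empty. The sketch shows that an unsegmented Japanese line becomes a single token.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Locale;
import java.util.StringTokenizer;
import java.util.regex.Pattern;

/** Hypothetical approximation (not DL4J code) of DefaultTokenizerFactory + CommonPreprocessor. */
public class TokenizeSketch {
    // Assumed character class: digits and common ASCII punctuation.
    private static final Pattern PUNCT = Pattern.compile("[\\d.:,\"'()\\[\\]|/?!;]+");

    public static List<String> tokenize(String line) {
        List<String> tokens = new ArrayList<>();
        // Split on whitespace only, using StringTokenizer's default delimiters.
        StringTokenizer st = new StringTokenizer(line);
        while (st.hasMoreTokens()) {
            // Strip digits/punctuation, then lowercase.
            String token = PUNCT.matcher(st.nextToken()).replaceAll("").toLowerCase(Locale.ROOT);
            if (!token.isEmpty()) { // this sketch drops tokens that become empty
                tokens.add(token);
            }
        }
        return tokens;
    }

    public static void main(String[] args) {
        System.out.println(tokenize("Sunny weather, then rain!"));
        System.out.println(tokenize("今日は晴れ"));   // unsegmented: one token
        System.out.println(tokenize("今日 は 晴れ")); // segmented: three tokens
    }
}
```

StringTokenizer's default delimiters are space, tab, newline, carriage return and form feed. If DL4J really relies on them, a full-width space (U+3000) will not separate words, so segment with ordinary ASCII spaces.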

Training

Save the corpus text file as **input.txt**. The program writes the trained model to **model.txt**.
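BasicLineIterator reads the corpus line by line and treats each line as one sentence. That means input.txt should hold one sentence per line, with words separated by single spaces. Below is a minimal sketch that writes already-segmented sentences in that layout. The class name, the toy sentences and the choice of UTF-8 are my own assumptions.

```java
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.List;

/** Hypothetical helper: writes segmented sentences as one space-separated sentence per line. */
public class WriteCorpusSketch {
    /** Joins each non-empty sentence's words with single spaces. */
    public static List<String> toLines(List<List<String>> sentences) {
        List<String> lines = new ArrayList<>();
        for (List<String> words : sentences) {
            if (!words.isEmpty()) {
                lines.add(String.join(" ", words));
            }
        }
        return lines;
    }

    /** Writes the lines to path as UTF-8 (an assumption about the encoding you want). */
    public static void write(List<List<String>> sentences, Path path) throws IOException {
        Files.write(path, toLines(sentences), StandardCharsets.UTF_8);
    }

    public static void main(String[] args) throws IOException {
        // Toy sentences for illustration only; a real corpus needs far more text.
        List<List<String>> sentences = List.of(
                List.of("today", "the", "weather", "is", "sunny"),
                List.of("tomorrow", "there", "will", "be", "rain"));
        write(sentences, Paths.get("input.txt"));
    }
}
```

By default, Files.write creates the file or truncates an existing one. Point it somewhere else if you don't want to overwrite your real corpus.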

ModelBuild.java


import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer;
import org.deeplearning4j.models.glove.Glove;
import org.deeplearning4j.text.sentenceiterator.BasicLineIterator;
import org.deeplearning4j.text.sentenceiterator.SentenceIterator;
import org.deeplearning4j.text.tokenization.tokenizer.preprocessor.CommonPreprocessor;
import org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory;
import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory;

import java.io.*;

public class ModelBuild {
    public static void main(String[] args) throws Exception {

        // Corpus file: one sentence per line, words separated by spaces
        System.out.println("Reading data...");
        File inputFile = new File("input.txt");

        // Iterate over the corpus one line (sentence) at a time
        SentenceIterator iter = new BasicLineIterator(inputFile);

        // Tokenizer: splits each line on whitespace; CommonPreprocessor lowercases
        // tokens and strips digits and punctuation
        System.out.println("Creating the tokenizer...");
        TokenizerFactory t = new DefaultTokenizerFactory();
        t.setTokenPreProcessor(new CommonPreprocessor());

        // Configure the GloVe model
        System.out.println("Building the model...");
        Glove glove = new Glove.Builder()
                .iterate(iter)         // sentence iterator
                .tokenizerFactory(t)   // tokenizer
                .alpha(0.75)           // exponent alpha of the weighting function
                .learningRate(0.1)     // initial learning rate
                .epochs(25)            // number of passes over the training corpus
                .layerSize(300)        // dimensionality of the word vectors
                .maxMemory(2)          // maximum memory for co-occurrence counting (GB)
                .xMax(100)             // cutoff x_max of the weighting function
                .batchSize(1000)       // mini-batch size
                .windowSize(10)        // context window size
                .shuffle(true)         // shuffle batches before training
                .symmetric(true)       // count word pairs in both directions
                .build();

        // Train
        System.out.println("Training...");
        glove.fit();

        // Save the word vectors as a text file
        System.out.println("Saving the model...");
        WordVectorSerializer.writeWordVectors(glove, "model.txt");

        System.out.println("Done.");
    }
}
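To get a feel for what windowSize, symmetric, xMax and alpha control, here is a plain-Java sketch of the quantities GloVe trains on. It follows the formulation in the original GloVe paper:

- Co-occurrence values X_ij, accumulated inside the window with weight 1/distance.
- The weighting function f(x) = (x / x_max)^alpha for x < x_max, and 1 otherwise.
- One term of the cost, f(X_ij) (w_i · c_j + b_i + b_j − log X_ij)².

This is not Deeplearning4j's internal implementation, whose details (for example the distance weighting) may differ. The class and method names are hypothetical.

```java
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/** Hypothetical sketch of GloVe's training quantities; NOT Deeplearning4j's code. */
public class GloveSketch {

    /** Accumulates X("w|c") for pairs within windowSize, weighting each pair by 1/distance. */
    public static Map<String, Double> coOccurrences(List<String> tokens, int windowSize, boolean symmetric) {
        Map<String, Double> counts = new HashMap<>();
        for (int i = 0; i < tokens.size(); i++) {
            for (int j = i + 1; j < tokens.size() && j - i <= windowSize; j++) {
                double weight = 1.0 / (j - i); // closer pairs count more
                counts.merge(tokens.get(i) + "|" + tokens.get(j), weight, Double::sum);
                if (symmetric) { // also count the pair in the right-to-left direction
                    counts.merge(tokens.get(j) + "|" + tokens.get(i), weight, Double::sum);
                }
            }
        }
        return counts;
    }

    /** Weighting function f(x): (x / xMax)^alpha below the cutoff, 1 at or above it. */
    public static double weight(double x, double xMax, double alpha) {
        return x < xMax ? Math.pow(x / xMax, alpha) : 1.0;
    }

    /** One cost term: f(x) * (w . c + bw + bc - log x)^2, for a pair with co-occurrence x > 0. */
    public static double lossTerm(double[] w, double[] c, double bw, double bc,
                                  double x, double xMax, double alpha) {
        double dot = 0.0;
        for (int k = 0; k < w.length; k++) {
            dot += w[k] * c[k];
        }
        double diff = dot + bw + bc - Math.log(x);
        return weight(x, xMax, alpha) * diff * diff;
    }

    public static void main(String[] args) {
        List<String> tokens = List.of("sunny", "weather", "then", "rain");
        System.out.println(coOccurrences(tokens, 2, true));
        System.out.println(weight(50, 100, 0.75));  // below the cutoff: weight < 1
        System.out.println(weight(300, 100, 0.75)); // at or above the cutoff: capped at 1.0
    }
}
```

Take alpha = 0.75 and xMax = 100, as in ModelBuild.java. A pair with co-occurrence value 50 then gets weight 0.5^0.75 ≈ 0.59, while any pair at or above 100 gets the full weight 1. This keeps very frequent pairs from dominating training.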

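Evaluation

Evaluation.java takes the path of the saved model file (for example, model.txt) as its first command-line argument.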
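The saved model.txt is a plain text file, so you can also inspect it without DL4J. I'm assuming each line holds a word followed by its vector components separated by spaces; check the first lines of your own file. The reader below also tolerates an optional "vocabulary size, dimensions" header line, just in case. It is a minimal sketch, and the class name is hypothetical.

```java
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

/** Hypothetical reader for the assumed "word v1 v2 ... vn" text format. */
public class ModelTxtSketch {
    public static Map<String, double[]> parse(List<String> lines) {
        Map<String, double[]> vectors = new LinkedHashMap<>();
        for (String line : lines) {
            String[] fields = line.trim().split("\\s+");
            // Skip blank lines and a possible "vocabSize dimensions" header line.
            if (fields.length < 2
                    || (fields.length == 2 && isInteger(fields[0]) && isInteger(fields[1]))) {
                continue;
            }
            double[] vec = new double[fields.length - 1];
            for (int i = 1; i < fields.length; i++) {
                vec[i - 1] = Double.parseDouble(fields[i]);
            }
            vectors.put(fields[0], vec);
        }
        return vectors;
    }

    private static boolean isInteger(String s) {
        return s.matches("\\d+");
    }

    public static void main(String[] args) throws IOException {
        Path path = Paths.get(args.length > 0 ? args[0] : "model.txt");
        Map<String, double[]> vectors = parse(Files.readAllLines(path, StandardCharsets.UTF_8));
        System.out.println(vectors.size() + " words loaded");
    }
}
```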

Evaluation.java


import java.io.File;
import java.io.FileNotFoundException;
import java.io.UnsupportedEncodingException;
import java.util.Collection;

import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer;
import org.deeplearning4j.models.embeddings.wordvectors.WordVectors;

public class Evaluation {

    public static void main(String[] args) throws FileNotFoundException, UnsupportedEncodingException {
        // Load the model file (path from the first argument, model.txt if omitted)
        System.out.println("Loading model file...");
        File inputFile = new File(args.length > 0 ? args[0] : "model.txt");
        WordVectors vec = WordVectorSerializer.loadTxtVectors(inputFile);

        // Show the 10 words most similar to a given word (here, "weather")
        System.out.println("Top 10 similar words...");
        String word = "weather";
        int ranking = 10;
        Collection<String> similarTop10 = vec.wordsNearest(word, ranking);
        System.out.println(String.format("Words similar to \"%s\": %s", word, similarTop10));

        // Show the cosine similarity of two words (here, "sunny" and "rain").
        // Use lowercase: CommonPreprocessor lowercased every token during training.
        System.out.println("Cosine similarity...");
        String word1 = "sunny";
        String word2 = "rain";
        double similarity = vec.similarity(word1, word2);
        System.out.println(String.format("Similarity between \"%s\" and \"%s\": %f", word1, word2, similarity));
    }
}
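vec.similarity returns the cosine similarity of two word vectors. As far as I know, wordsNearest likewise ranks the vocabulary by cosine similarity to the query word. The sketch below illustrates that measure in plain Java over a hypothetical in-memory map of made-up 2-dimensional vectors. It is not DL4J's implementation, and the class name is my own.

```java
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.Map;

/** Hypothetical illustration of cosine similarity and nearest-word ranking. */
public class CosineSketch {
    /** Cosine similarity: dot(a, b) / (|a| * |b|). */
    public static double cosine(double[] a, double[] b) {
        double dot = 0.0, na = 0.0, nb = 0.0;
        for (int i = 0; i < a.length; i++) {
            dot += a[i] * b[i];
            na += a[i] * a[i];
            nb += b[i] * b[i];
        }
        return dot / (Math.sqrt(na) * Math.sqrt(nb));
    }

    /** The k words most cosine-similar to word (assumed present), excluding word itself. */
    public static List<String> nearest(Map<String, double[]> vectors, String word, int k) {
        double[] query = vectors.get(word);
        List<String> candidates = new ArrayList<>(vectors.keySet());
        candidates.remove(word);
        candidates.sort(Comparator.comparingDouble((String w) -> cosine(query, vectors.get(w))).reversed());
        return candidates.subList(0, Math.min(k, candidates.size()));
    }

    public static void main(String[] args) {
        // Made-up vectors for illustration only.
        Map<String, double[]> vectors = Map.of(
                "sunny", new double[] {1.0, 0.2},
                "rain", new double[] {0.6, 0.8},
                "weather", new double[] {0.9, 0.4},
                "java", new double[] {-0.5, 1.0});
        System.out.println(nearest(vectors, "weather", 2));
        System.out.printf("%.4f%n", cosine(vectors.get("sunny"), vectors.get("rain")));
    }
}
```

Cosine similarity ignores vector length and compares only direction. Its value ranges from −1 (opposite) to 1 (same direction).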

Page I referred to when writing the code
