This is a continuation of document classification using the Microsoft Cognitive Toolkit (CNTK).
In Part 2, the document data prepared in Part 1 is used to classify documents with CNTK. It is assumed that CNTK and the NVIDIA GPU CUDA toolkit are installed.
In Natural Language : Doc2Vec Part1 - livedoor NEWS Corpus, we prepared the training and validation data.
In Part 2, we will create a Doc2Vec model and classify documents.
Doc2Vec
Doc2Vec [1][2][3] is an extension of Word2Vec. The Doc2Vec implemented here is a simple model that averages the embedding-layer outputs of all words contained in one document and classifies which category the document belongs to.
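The averaged-embedding classifier described above can be sketched with NumPy. This is a minimal illustration, not the trained model: the vocabulary size, embedding width, and weights below are placeholders.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical sizes: vocabulary, embedding width, and the 9 livedoor categories.
vocab_size, embed_dim, num_classes = 10000, 100, 9

# Embedding matrix and output-layer weights (randomly initialized here).
E = rng.normal(scale=0.1, size=(vocab_size, embed_dim))
W = rng.normal(scale=0.1, size=(embed_dim, num_classes))
b = np.zeros(num_classes)

def classify(word_ids):
    """Average the embeddings of all words in a document, then score each category."""
    doc_vec = E[word_ids].mean(axis=0)   # (embed_dim,)
    logits = doc_vec @ W + b             # (num_classes,)
    exp = np.exp(logits - logits.max())  # softmax, shifted for numerical stability
    return exp / exp.sum()

probs = classify([1, 42, 7, 512])  # category probabilities for a toy document
```

Because the document vector is just the mean of the word vectors, the model ignores word order; it is essentially a bag-of-words classifier over learned embeddings.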
Each parameter was initialized with CNTK's default settings, which in most cases use the Glorot uniform distribution [4].
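For reference, Glorot (Xavier) uniform initialization draws weights from a uniform distribution whose bounds depend on the layer's fan-in and fan-out. A small sketch:

```python
import numpy as np

def glorot_uniform(fan_in, fan_out, rng=None):
    # Glorot uniform: U(-limit, limit) with limit = sqrt(6 / (fan_in + fan_out)),
    # which keeps activation variance roughly constant across layers.
    rng = rng or np.random.default_rng(0)
    limit = np.sqrt(6.0 / (fan_in + fan_out))
    return rng.uniform(-limit, limit, size=(fan_in, fan_out))

W = glorot_uniform(100, 9)  # e.g. an embedding-to-category output layer
```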
Word2Vec adopted Sampled Softmax [5] to speed up the output layer when predicting words, but since this document classification task has only 9 categories, the standard Softmax function with cross-entropy error was used.
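A minimal sketch of the softmax plus cross-entropy loss used here (NumPy, for illustration only):

```python
import numpy as np

def softmax(z):
    # Subtract the max for numerical stability before exponentiating.
    e = np.exp(z - z.max())
    return e / e.sum()

def cross_entropy(probs, label):
    # Negative log-likelihood of the correct category.
    return -np.log(probs[label])

# With 9 equally likely categories, the loss starts around ln(9) ≈ 2.197.
loss = cross_entropy(softmax(np.zeros(9)), 0)
```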
Adam [6] was used as the optimization algorithm. Adam's learning rate was set to 0.01, the hyperparameter $β_1$ to 0.9, and $β_2$ to CNTK's default value.
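For reference, one Adam parameter update can be sketched as follows. The $β_2$ and $ε$ values below are the commonly used Adam defaults, assumed here for illustration; the article itself keeps CNTK's default for $β_2$.

```python
import numpy as np

def adam_step(param, grad, m, v, t, lr=0.01, beta1=0.9, beta2=0.999, eps=1e-8):
    # Exponential moving averages of the gradient and squared gradient.
    m = beta1 * m + (1 - beta1) * grad
    v = beta2 * v + (1 - beta2) * grad ** 2
    m_hat = m / (1 - beta1 ** t)  # bias-corrected first moment
    v_hat = v / (1 - beta2 ** t)  # bias-corrected second moment
    param = param - lr * m_hat / (np.sqrt(v_hat) + eps)
    return param, m, v

# First step on a single weight with gradient 0.5 moves it by roughly lr.
p, m, v = adam_step(np.array([1.0]), np.array([0.5]), np.zeros(1), np.zeros(1), t=1)
```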
Model training ran for 10 epochs of mini-batch learning.
- CPU: Intel(R) Core(TM) i7-6700K 4.00GHz
- GPU: NVIDIA GeForce GTX 1060 6GB
- Windows 10 Pro 1909
- CUDA 10.0
- cuDNN 7.6
- Python 3.6.6
- cntk-gpu 2.7
- pandas 0.25.0
The training program is available on GitHub.
doc2vec_training.py
Training loss and error
The figure below visualizes the logs of the loss function and error rate during training. The left graph shows the loss function and the right graph the error rate; the horizontal axis is the number of epochs and the vertical axis the respective value.
Validation accuracy and confusion matrix
Evaluating performance on the validation data held out during data preparation in Part 1 gave the following result.
Accuracy 90.00%
The figure below visualizes the confusion matrix for the validation data. Columns are the predictions and rows the correct answers.
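The accuracy reported above can be read off the diagonal of such a confusion matrix. A minimal sketch of building one (the labels below are toy values, not the actual validation results):

```python
import numpy as np

def confusion_matrix(y_true, y_pred, num_classes):
    # Rows: correct label, columns: predicted label (matching the figure's layout).
    cm = np.zeros((num_classes, num_classes), dtype=int)
    for t, p in zip(y_true, y_pred):
        cm[t, p] += 1
    return cm

cm = confusion_matrix([0, 1, 1, 2], [0, 1, 2, 2], num_classes=3)
accuracy = np.trace(cm) / cm.sum()  # fraction of samples on the diagonal
```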
Using back propagation of gradients, I investigated which words in a document are important when classifying it.
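For the averaged-embedding model, gradient-based word importance has a convenient closed form: the gradient of the target-class logit with respect to each word's one-hot input is that word's embedding dotted with the output weights, divided by the document length. The sketch below illustrates this idea with placeholder weights; it is not the article's training code.

```python
import numpy as np

rng = np.random.default_rng(0)
vocab_size, embed_dim, num_classes = 1000, 50, 9
E = rng.normal(scale=0.1, size=(vocab_size, embed_dim))  # embedding matrix
W = rng.normal(scale=0.1, size=(embed_dim, num_classes)) # output weights

def word_saliency(word_ids, target_class):
    # d(logit_c)/d(one-hot of word w) = (E[w] @ W[:, c]) / n for the mean-pooled
    # model, so each word's contribution to the class score can be read directly.
    n = len(word_ids)
    return {w: abs(E[w] @ W[:, target_class]) / n for w in word_ids}

scores = word_saliency([3, 14, 159], target_class=2)
top_words = sorted(scores, key=scores.get, reverse=True)  # most important first
```

Ranking the words of each validation document by this magnitude produces lists like those below.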
dokujo-tsushin
1 single woman
2 Lady Girls
3 Saori Abe
4 woman
5 never get married
6 age
7 me
8 married
9 Values
10 copies
Words related to women are emphasized in Dokujo Tsushin articles.
it-life-hack
1 smartphone
2 services
3 services
4 apps
5 google
6 google
7 google
8 google
9 google
10 google
Words about IT are emphasized in IT Lifehack articles.
sports-watch
1 training
2 number of places
3 clubs
4 clubs
5 home
6 Top
7 Vision
8 Yoshida
9 Yoshida
10 Yoshida
Sports Watch articles emphasize sports-related words.
Natural Language : Doc2Vec Part1 - livedoor NEWS Corpus
Natural Language : Word2Vec Part2 - Skip-gram model