10 methods to improve the accuracy of BERT

Introduction

It has become commonplace to fine-tune and use BERT in natural language processing tasks. There are an increasing number of situations, such as Kaggle competitions or projects with strict accuracy requirements, where we want to improve accuracy as much as possible. This article therefore summarizes accuracy improvement methods. A classification task is assumed throughout.

Adjusting the number of tokens

Pretrained BERT accepts input of up to 512 tokens, so special handling is required for text longer than that. It is worth examining this carefully, as changing how over-long text is handled often contributes to improved accuracy.

As an example, consider taking 6 tokens from the following text (counting each punctuation mark as one token): `I / am / a / cat / . / I / have / no / name / yet / .`

1. Head-Tail

`I / am / a / name / yet / .`

From [How to Fine-tune BERT for Text Classification]. Take tokens from both the beginning and the end of the text. The example above uses the first 3 tokens and the last 3 tokens. It is easy to implement, performs well, and is a popular technique on Kaggle (see the sketch after this list). How many tokens to take from the head and from the tail is case-by-case.

2. Random

`a / cat / . / I / have / no`
`. / I / have / no / name / yet`

Take consecutive tokens starting from a random position. By changing the position each epoch, you can expect an augmentation-like effect. However, it does not seem to reach the accuracy of Head-Tail. Combining it with TTA (Test Time Augmentation) could also be worth trying.

3. Sliding Window

`I / am / a / cat / . / I`
`cat / . / I / have / no / name`
`have / no / name / yet / .`

This is a technique often used on Google's Natural Questions dataset, for example in A BERT Baseline for the Natural Questions. The example above shifts the window by 3 tokens at a time. Its strength is that it covers the data completely. The disadvantage is that for long texts the training data becomes very large. It is mainly used when every token matters, as in QA tasks, but it can also help improve the accuracy of classification tasks.
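For reference, here is a minimal sketch of Head-Tail truncation using the Hugging Face transformers library. The model name and the 128/382 head/tail split are illustrative assumptions (128 + 382 = 510 tokens, leaving room for [CLS] and [SEP] within the 512-token limit):

from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

def head_tail_truncate(text, head_len=128, tail_len=382):
    # Keep the first head_len and the last tail_len tokens of over-long text
    tokens = tokenizer.tokenize(text)
    if len(tokens) > head_len + tail_len:
        tokens = tokens[:head_len] + tokens[-tail_len:]
    ids = tokenizer.convert_tokens_to_ids(tokens)
    # Wrap the truncated sequence in [CLS] ... [SEP]
    return tokenizer.build_inputs_with_special_tokens(ids)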

Adding meta information

Consider inputting a title, a question, and an answer, such as:

Title: About President Trump
Question: Where is President Trump from?
Answer: New York.

4. Add separator

[CLS] About President Trump [NEW_SEP] Where is President Trump from? [SEP] New York. [SEP]

From the Google QUEST Q&A Labeling 19th place solution. BERT supports separating two sentences with the [SEP] token, but it does not support more than two. By defining a token with an appropriate name, such as [NEW_SEP], and using it as a separator, you can express additional sentence breaks. Such tokens can be added with tokenizer.add_special_tokens. In the English version of BERT, the vocabulary also contains unused tokens from [unused0] to [unused993], which can be repurposed for this.
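A minimal sketch of registering such a token with transformers, assuming bert-base-uncased (note that the embedding matrix must be resized after adding tokens):

from transformers import BertTokenizer, BertModel

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = BertModel.from_pretrained("bert-base-uncased")

# Register [NEW_SEP] so the tokenizer never splits it into subwords
tokenizer.add_special_tokens({"additional_special_tokens": ["[NEW_SEP]"]})
# Grow the embedding matrix to cover the newly added token
model.resize_token_embeddings(len(tokenizer))

ids = tokenizer.encode(
    "About President Trump [NEW_SEP] Where is President Trump from?",
    "New York.",
)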

5. Add category information

[CLS] [CATEGORY_0] Where is President Trump's birthplace? [SEP] New York. [SEP]

From the Jigsaw Unintended Bias in Toxicity Classification 1st place solution. Suppose you are solving the task of determining whether the sentence above is a proper question-answer pair. Question-answering logs are often categorized, so you may want to add the category as a feature. In that case, you can improve accuracy by defining new tokens [CATEGORY_0] through [CATEGORY_n] (where n is the number of categories) and embedding them in the text as shown above.

It is also effective to perform category classification as an auxiliary subtask, using the output vector at the [CATEGORY_0] position as a feature.
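A minimal sketch of both ideas, category tokens in the input plus an auxiliary head on the category-token position; the category count, the auxiliary head, and the position index are illustrative assumptions:

import torch
from transformers import BertTokenizer, BertModel

n_categories = 5  # illustrative assumption
cat_tokens = [f"[CATEGORY_{i}]" for i in range(n_categories)]

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = BertModel.from_pretrained("bert-base-uncased")
tokenizer.add_special_tokens({"additional_special_tokens": cat_tokens})
model.resize_token_embeddings(len(tokenizer))

# Auxiliary head predicting the category from the category-token vector
aux_head = torch.nn.Linear(model.config.hidden_size, n_categories)

inputs = tokenizer("[CATEGORY_0] Where is President Trump's birthplace?",
                   "New York.", return_tensors="pt")
hidden = model(**inputs).last_hidden_state  # (1, seq_len, 768)
cat_vec = hidden[:, 1, :]                   # position 1 follows [CLS] here
aux_logits = aux_head(cat_vec)              # auxiliary category prediction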

Model building

The standard BERT model consists of 12 Transformer layers. When fine-tuning BERT, the default implementation uses the vector at the [CLS] position in the last layer's output as the feature. That is often sufficient for accuracy, but you can expect a slight improvement by using other features as well.

6. Use the last 4 layers

(figure: last_4_layers.png) From [How to Fine-tune BERT for Text Classification]. We aim to improve fine-tuning accuracy by combining the [CLS] vectors from the last 4 of the 12 layers. These four vectors are finally combined into a single feature vector using average pooling, max pooling, concatenation, and so on (average or max pooling keeps it at 768 dimensions).
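A minimal sketch of this, assuming the Hugging Face transformers API; output_hidden_states=True makes BERT return the embedding output plus all 12 layer outputs (13 tensors in total):

import torch
from transformers import BertModel

model = BertModel.from_pretrained("bert-base-uncased",
                                  output_hidden_states=True)

def last4_cls_features(input_ids, attention_mask):
    outputs = model(input_ids=input_ids, attention_mask=attention_mask)
    # [CLS] vector (position 0) from each of the last 4 layers
    cls_vecs = [h[:, 0, :] for h in outputs.hidden_states[-4:]]
    # Combine by average pooling over the 4 layers -> (batch, 768)
    return torch.stack(cls_vecs, dim=0).mean(dim=0)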

7. Learnable Weighted Sum

(figure: weighted_sum.png) From the Google QUEST Q&A Labeling 1st place solution. Compute a weighted sum of the [CLS] vectors of all BERT layers, using trainable weights defined in the model. Simply averaging all layers is already a powerful method; this is a further development of it.
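A minimal sketch of a learnable weighted sum module; the softmax parameterization over per-layer weights is one common formulation and an assumption here, not necessarily the solution's exact code:

import torch
import torch.nn as nn

class WeightedLayerSum(nn.Module):
    def __init__(self, n_layers=13):  # 12 layers + the embedding output
        super().__init__()
        self.layer_weights = nn.Parameter(torch.zeros(n_layers))

    def forward(self, hidden_states):  # tuple of (batch, seq, 768) tensors
        # Stack the [CLS] vector of every layer: (n_layers, batch, 768)
        cls = torch.stack([h[:, 0, :] for h in hidden_states], dim=0)
        w = torch.softmax(self.layer_weights, dim=0).view(-1, 1, 1)
        return (w * cls).sum(dim=0)  # weighted sum -> (batch, 768)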

8. Add CNN layer

(figure: bert_with_cnn.png) From Identifying Russian Trolls on Reddit with Deep Learning and BERT Word Embeddings. Feeding the vectors of all tokens into a CNN, not just the [CLS] vector, is another powerful technique. Compute a one-dimensional convolution over the sequence of up to 512 tokens, as shown in the figure. After the convolution, max pooling or average pooling extracts a feature vector whose dimensionality equals the number of filters; feed that into a Dense layer. Whereas attention relates tokens globally, a CNN aggregates the features of neighboring tokens, so combining them can improve accuracy. Combining with an LSTM as well as a CNN is also effective.
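A minimal sketch of a 1D-CNN head on top of BERT's token vectors; the filter count, kernel size, and class count are illustrative assumptions:

import torch
import torch.nn as nn

class BertCnnHead(nn.Module):
    def __init__(self, hidden=768, n_filters=256, kernel=3, n_classes=2):
        super().__init__()
        self.conv = nn.Conv1d(hidden, n_filters, kernel_size=kernel,
                              padding=kernel // 2)
        self.fc = nn.Linear(n_filters, n_classes)

    def forward(self, last_hidden_state):      # (batch, seq, 768)
        x = last_hidden_state.transpose(1, 2)  # Conv1d expects (batch, C, seq)
        x = torch.relu(self.conv(x))           # (batch, n_filters, seq)
        x = x.max(dim=2).values                # global max pooling over time
        return self.fc(x)                      # (batch, n_classes)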

Training

9. Freeze the BERT weights

from torch.optim import AdamW

model_params = list(model.named_parameters())

# Freeze BERT: optimize only the task-specific (non-BERT) parameters
params = [p for n, p in model_params if "bert" not in n]
optimizer = AdamW(params, lr=2e-5)

# Unfreeze BERT: add its parameters to the optimizer to resume full training
params = [p for n, p in model_params if "bert" in n]
optimizer.add_param_group({'params': params})

From the Google QUEST Q&A Labeling 19th place solution. As with pretrained image models, BERT accuracy may improve if you freeze its weights and train only the task-dependent layers. In the 19th place solution, the weights are frozen only for the first epoch, and all layers are trained afterwards. The code above first freezes the BERT weights and starts training, then releases the freeze partway through and resumes training.

10. Use different learning rates for BERT and the other layers

model_params = list(model.named_parameters())

bert_params = [p for n, p in model_params if "bert" in n]
other_params = [p for n, p in model_params if "bert" not in n]

base_lr = 2e-5  # learning rate for the BERT layers, as in the example above
params = [
    {'params': bert_params, 'lr': base_lr},
    {'params': other_params, 'lr': base_lr * 500}
]
optimizer = AdamW(params)

From the Google QUEST Q&A Labeling 1st place solution. Adopting different learning rates is just as effective as it is for pretrained image models. In the 1st place solution, the task-specific layers are trained at a learning rate 500 times higher than the BERT layers, as the code above shows.

Conclusion

We have introduced techniques that may improve accuracy on BERT classification tasks. However, we could not show how much each one improves accuracy, so we would like to compare them on a suitable dataset. There are still more accuracy-improvement methods beyond those listed above, and we will continue to investigate them.
