[PyTorch] How to use BERT - Fine-tuning Japanese pre-trained models to solve classification problems

Introduction

BERT keeps setting new state-of-the-art results on a variety of natural language processing tasks, but the implementation Google published on GitHub is based on TensorFlow. PyTorch users naturally want a PyTorch version, and Google's FAQ essentially says: we have not made a PyTorch version, so use the one made by HuggingFace, but since we are not involved in its development, please ask them for details.

HuggingFace's BERT, however, had no Japanese pre-trained models until December 2019. English was easy to try, but for Japanese you had to prepare pre-trained models yourself. In December 2019, Japanese pre-trained models were finally added: https://huggingface.co/transformers/pretrained_models.html

  1. bert-base-japanese
  2. bert-base-japanese-whole-word-masking
  3. bert-base-japanese-char
  4. bert-base-japanese-char-whole-word-masking

These four models, created by the Inui Laboratory at Tohoku University, are now available. Unless you have special circumstances, it is better to use the second one, bert-base-japanese-whole-word-masking: compared with the normal version, the Whole Word Masking version tends to give slightly higher accuracy on fine-tuned tasks [^ 1].

This makes it easier to try the PyTorch version of BERT in Japanese.

What is BERT?

The mechanism of BERT has already been explained in many blogs and books, so I will omit a detailed explanation. Simply put, the processing flow is:

  1. Create pre-trained models from a large amount of unlabeled text
  2. Pre-train by solving two kinds of language tasks: the Masked Language Model and Next Sentence Prediction
  3. Fine-tune the pre-trained models to solve the downstream task

Creating pre-trained models requires a large amount of computing resources and time, but one of BERT's selling points is that fine-tuning them can improve task accuracy.
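
To make the flow concrete, here is a minimal sketch (my own illustration, not from the original article) of what a single fine-tuning step looks like with transformers: the pre-trained Japanese weights are loaded into a model with a classification head, which is then trained on labeled examples. The sentence and label below are placeholders.

import torch
from transformers import BertJapaneseTokenizer, BertForSequenceClassification

# Load the pre-trained Japanese weights with a new classification head on top
tokenizer = BertJapaneseTokenizer.from_pretrained('bert-base-japanese-whole-word-masking')
model = BertForSequenceClassification.from_pretrained(
    'bert-base-japanese-whole-word-masking', num_labels=2)

# One labeled example (placeholder sentence and label)
inputs = tokenizer.encode('テレビでサッカーの試合を見る。', return_tensors='pt')
labels = torch.tensor([0])

# A single training step: the forward pass returns the loss when labels are given
loss = model(inputs, labels=labels)[0]
loss.backward()  # in practice an optimizer such as AdamW would then update the weights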

Japanese Pre-trained models

First, let's check the accuracy of the pre-trained Japanese models. This time we will check it through the Masked Language Model. Simply put, the Masked Language Model masks a word in a sentence and predicts the masked word.

Using BertJapaneseTokenizer and BertForMaskedLM, you can write the following. Here we mask the word "サッカー" (soccer) in the sentence "テレビでサッカーの試合を見る。" ("Watch a soccer match on TV.") and predict it.

import torch
from transformers import BertJapaneseTokenizer, BertForMaskedLM

# Load pre-trained tokenizer
tokenizer = BertJapaneseTokenizer.from_pretrained('bert-base-japanese-whole-word-masking')

# Tokenize input
text = 'テレビでサッカーの試合を見る。'  # "Watch a soccer match on TV."
tokenized_text = tokenizer.tokenize(text)
# ['テレビ', 'で', 'サッカー', 'の', '試合', 'を', '見る', '。']

# Mask a token that we will try to predict back with `BertForMaskedLM`
masked_index = 2
tokenized_text[masked_index] = '[MASK]'
# ['テレビ', 'で', '[MASK]', 'の', '試合', 'を', '見る', '。']

# Convert token to vocabulary indices
indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text)
# [571, 12, 4, 5, 608, 11, 2867, 8]

# Convert inputs to PyTorch tensors
tokens_tensor = torch.tensor([indexed_tokens])
# tensor([[ 571,   12,    4,    5,  608,   11, 2867,    8]])

# Load pre-trained model
model = BertForMaskedLM.from_pretrained('bert-base-japanese-whole-word-masking')
model.eval()

# Predict
with torch.no_grad():
    outputs = model(tokens_tensor)
    predictions = outputs[0][0, masked_index].topk(5)  # Extract the top 5 predicted tokens

# Show results
for i, index_t in enumerate(predictions.indices):
    index = index_t.item()
    token = tokenizer.convert_ids_to_tokens([index])[0]
    print(i, token)

The execution result of the above program is as follows. "サッカー" (soccer) appears in third place, and the other candidates would also make natural Japanese. The reason that "クリケット" (cricket) and Major League team names appear, even though they are not especially familiar in Japan, is presumably that the model was pre-trained on Wikipedia data.

0 クリケット
1 タイガース
2 サッカー
3 メッツ
4 カブス

This confirms that the pre-trained models have been trained correctly. Next, let's fine-tune these pre-trained models to solve a task.

Fine-tuning with BERT

Modifying the source code to work with original Japanese data

HuggingFace's GitHub repository has several examples of fine-tuning to solve tasks. However, these are all for English datasets; none are for Japanese datasets [^ 2].

Therefore, I will modify the existing source code so that it works with original Japanese data. Since text classification is a basic task in natural language processing, I target the source code used for GLUE text classification, namely:

  1. transformers/data/processors/glue.py
  2. transformers/data/metrics/__init__.py

Modify these two files as described below.

Caution: the files to edit are not the ones you downloaded with git clone etc.; you need to change the files in the installation directory. For example, if you are using venv, the installation directory will be [venv directory]/lib/python3.7/site-packages/transformers.
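
If you are not sure where the library is installed, one way to check (my own suggestion, not from the original article) is to print the package path from Python:

import transformers

# Prints the location of the installed package, e.g. .../site-packages/transformers/__init__.py
print(transformers.__file__)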

  1. transformers/data/processors/glue.py

This is the part that reads the training data (train.tsv) and validation data (dev.tsv). Add the task "original" to glue_tasks_num_labels, glue_processors, and glue_output_modes, and then add the class OriginalProcessor, as follows:

glue_tasks_num_labels = {
    "cola": 2,
    "mnli": 3,
    "mrpc": 2,
    "sst-2": 2,
    "sts-b": 1,
    "qqp": 2,
    "qnli": 2,
    "rte": 2,
    "wnli": 2,
    "original": 2, #add to
}

glue_processors = {
    "cola": ColaProcessor,
    "mnli": MnliProcessor,
    "mnli-mm": MnliMismatchedProcessor,
    "mrpc": MrpcProcessor,
    "sst-2": Sst2Processor,
    "sts-b": StsbProcessor,
    "qqp": QqpProcessor,
    "qnli": QnliProcessor,
    "rte": RteProcessor,
    "wnli": WnliProcessor,
    "original": OriginalProcessor, #add to
}

glue_output_modes = {
    "cola": "classification",
    "mnli": "classification",
    "mnli-mm": "classification",
    "mrpc": "classification",
    "sst-2": "classification",
    "sts-b": "regression",
    "qqp": "classification",
    "qnli": "classification",
    "rte": "classification",
    "wnli": "classification",
    "original": "classification", #add to
}
class OriginalProcessor(DataProcessor):
    """Processor for the original data set."""

    def get_example_from_tensor_dict(self, tensor_dict):
        """See base class."""
        return InputExample(
            tensor_dict["idx"].numpy(),
            tensor_dict["sentence"].numpy().decode("utf-8"),
            None,
            str(tensor_dict["label"].numpy()),
        )

    def get_train_examples(self, data_dir):
        """See base class."""
        return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
 
    def get_dev_examples(self, data_dir):
        """See base class."""
        return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
 
    def get_labels(self):
        """See base class."""
        return ["0", "1"]

    def _create_examples(self, lines, set_type):
        """Creates examples for the training and dev sets."""
        examples = []
        for (i, line) in enumerate(lines):
            # If the TSV file has a header row, uncomment the following two lines to skip it
            # if i == 0:
            #     continue
            guid = "%s-%s" % (set_type, i)
            text_a = line[0]
            label = line[1]
            examples.append(InputExample(guid=guid, text_a=text_a, text_b=None, label=label))
        return examples

The training and validation data are assumed to be TSV files consisting of two columns:

  1. Text
  2. Label

train.tsv


It was interesting 0
It was fun 0
Was boring 1
I was sad 1

dev.tsv


Enjoyed 0
It was painful 1
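
For reference, here is a minimal sketch (not part of the original article) of producing such a header-less, tab-separated file with Python's csv module, using the example sentences above:

import csv

# (text, label) pairs in the format expected by OriginalProcessor
rows = [
    ('It was interesting', '0'),
    ('It was fun', '0'),
    ('Was boring', '1'),
    ('I was sad', '1'),
]

# Write a header-less, tab-separated train.tsv under data/original/
with open('data/original/train.tsv', 'w', encoding='utf-8', newline='') as f:
    csv.writer(f, delimiter='\t').writerows(rows)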

The above program assumes binary classification; for multi-class classification, modify the number of labels and their values as appropriate, as shown in the sketch below.
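
For example, for a hypothetical three-class task, the corresponding changes would look roughly like this (a sketch, not from the original article):

# In glue_tasks_num_labels, set the number of classes
glue_tasks_num_labels = {
    # ...other tasks unchanged...
    "original": 3,  # three classes instead of two
}

# In OriginalProcessor, return one label string per class
class OriginalProcessor(DataProcessor):
    # ...other methods unchanged...

    def get_labels(self):
        """See base class."""
        return ["0", "1", "2"]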

  2. transformers/data/metrics/__init__.py

This is the part that computes accuracy on the validation data. Just add a case for task_name == "original" to the conditional expression, as follows:

    def glue_compute_metrics(task_name, preds, labels):
        assert len(preds) == len(labels)
        if task_name == "cola":
            return {"mcc": matthews_corrcoef(labels, preds)}
        elif task_name == "sst-2":
            return {"acc": simple_accuracy(preds, labels)}
        elif task_name == "mrpc":
            return acc_and_f1(preds, labels)
        elif task_name == "sts-b":
            return pearson_and_spearman(preds, labels)
        elif task_name == "qqp":
            return acc_and_f1(preds, labels)
        elif task_name == "mnli":
            return {"acc": simple_accuracy(preds, labels)}
        elif task_name == "mnli-mm":
            return {"acc": simple_accuracy(preds, labels)}
        elif task_name == "qnli":
            return {"acc": simple_accuracy(preds, labels)}
        elif task_name == "rte":
            return {"acc": simple_accuracy(preds, labels)}
        elif task_name == "wnli":
            return {"acc": simple_accuracy(preds, labels)}
        # added
        elif task_name == "original":
            return {"acc": simple_accuracy(preds, labels)}
        else:
            raise KeyError(task_name)
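
After this change, glue_compute_metrics accepts the new task name. A minimal check (my own example, not from the article, and assuming scikit-learn is installed) would be:

import numpy as np
from transformers.data.metrics import glue_compute_metrics

# Two toy predictions that both match their labels
preds = np.array([0, 1])
labels = np.array([0, 1])
print(glue_compute_metrics("original", preds, labels))  # {'acc': 1.0}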

Fine-tuning to solve the classification problem

Now that the code works with original Japanese data, all that remains is to fine-tune the model and solve the classification problem. Put the training and validation data files under data/original/, then simply run the following command:

$ python examples/run_glue.py \
    --data_dir=data/original/ \
    --model_type=bert \
    --model_name_or_path=bert-base-japanese-whole-word-masking \
    --task_name=original \
    --do_train \
    --do_eval \
    --output_dir=output/original

If the command finishes without problems, a log like the following is output. The value of acc is 1.0, so you can see that both validation examples are classified correctly.

01/18/2020 17:08:39 - INFO - __main__ -   Saving features into cached file data/original/cached_dev_bert-base-japanese-whole-word-masking_128_original
01/18/2020 17:08:39 - INFO - __main__ -   ***** Running evaluation  *****
01/18/2020 17:08:39 - INFO - __main__ -     Num examples = 2
01/18/2020 17:08:39 - INFO - __main__ -     Batch size = 8
Evaluating: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  2.59it/s]
01/18/2020 17:08:40 - INFO - __main__ -   ***** Eval results  *****
01/18/2020 17:08:40 - INFO - __main__ -     acc = 1.0

And you can see that the model files are created under output/original/.

$ find output/original 
output/original
output/original/added_tokens.json
output/original/tokenizer_config.json
output/original/special_tokens_map.json
output/original/config.json
output/original/training_args.bin
output/original/vocab.txt
output/original/pytorch_model.bin
output/original/eval_results.txt
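
As a final sketch (not in the original article), the saved files can be loaded back to classify new text roughly as follows; the example sentence is just a placeholder.

import torch
from transformers import BertJapaneseTokenizer, BertForSequenceClassification

# Load the fine-tuned model and tokenizer saved by run_glue.py
tokenizer = BertJapaneseTokenizer.from_pretrained('output/original')
model = BertForSequenceClassification.from_pretrained('output/original')
model.eval()

# Classify a new sentence (placeholder text)
inputs = tokenizer.encode('テレビでサッカーの試合を見る。', return_tensors='pt')
with torch.no_grad():
    logits = model(inputs)[0]
print(logits.argmax(dim=-1).item())  # prints 0 or 1, the labels defined in OriginalProcessor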

In conclusion

I introduced how to classify Japanese text using the PyTorch version of BERT. By modifying other example code in the same way, you can tackle tasks such as text generation and question answering as well as text classification.

Until now, the hurdle for running BERT on Japanese text with PyTorch was high, but I think the release of Japanese pre-trained models has lowered it considerably. By all means, try the PyTorch version of BERT on Japanese tasks.

Reference article

https://techlife.cookpad.com/entry/2018/12/04/093000
http://kento1109.hatenablog.com/entry/2019/08/23/092944
