J'ai essayé d'implémenter la classification des phrases et la visualisation de l'attention par le japonais BERT avec PyTorch

introduction

Grâce aux transformateurs de huggingface, le modèle japonais BERT peut désormais être manipulé très facilement à l'aide de PyTorch.

De nombreuses personnes ont déjà publié des articles sur le BERT japonais en utilisant un visage / des transformateurs étreignant, mais j'ai décidé de publier un article après avoir étudié.

référence

[Apprenez en créant! Deep learning by PyTorch](https://www.amazon.co.jp/%E3%81%A4%E3%81%8F%E3%82%8A%E3%81%AA%E3 % 81% 8C% E3% 82% 89% E5% AD% A6% E3% 81% B6-PyTorch% E3% 81% AB% E3% 82% 88% E3% 82% 8B% E7% 99% BA% E5 % B1% 95% E3% 83% 87% E3% 82% A3% E3% 83% BC% E3% 83% 97% E3% 83% A9% E3% 83% BC% E3% 83% 8B% E3% 83 Publié par l'auteur de% B3% E3% 82% B0-% E5% B0% 8F% E5% B7% 9D% E9% 9B% 84% E5% A4% AA% E9% 83% 8E / dp / 4839970254) Les articles suivants qui ont été publiés sont extrêmement faciles à comprendre. Il explique poliment les endroits où les débutants BERT comme moi sont susceptibles de rester coincés.

En référence aux livres ci-dessus et aux articles Qiita (ou presque à copier), je vais également essayer de mettre en œuvre la classification des phrases par BERT. J'aborderai également la visualisation par Attention. Pour ceux qui veulent classer des phrases en utilisant BERT pour le moment, et veulent voir la visualisation de Attention. La théorie de BERT ne touche pas du tout à l'histoire.

Problème de réglage

Traitez le corpus d'actualités de livingoor comme des données de validation comme d'habitude. Le texte des news de livingoor est utilisé dans l'article de référence, mais ce n'est pas intéressant s'il est exactement le même, donc le titre du corpus de news de livingoor est le même que article écrit dans le passé. Je vais essayer de classer les phrases en utilisant uniquement.

la mise en oeuvre

Il est implémenté sur Google Colab ainsi que l'article de référence.

Préparation des données

Tout d'abord, montez Google Drive sur colab

from google.colab import drive
drive.mount('/content/drive')

Obtenez le corpus d'actualités de livingoor en vous référant à ici. Enregistrez l'ensemble de données avec le titre et la catégorie du corpus d'actualités Liveoor extrait dans Google Drive en tant que DataFrame et stockez-le dans Google Drive. Après stockage, l'état de vérification du contenu des données est le suivant.

import pickle
import pandas as pd

#Emplacement de stockage de l'ensemble de données
drive_dir = "drive/My Drive/Colab Notebooks/livedoor_data/"

with open(drive_dir + "livedoor_title_category.pickle", 'rb') as f:
  livedoor_data = pickle.load(f)

livedoor_data.head()
#title	category
#0 Internet confortable même à l'étranger! KDDI, "au Wi-Élargir le service Fi SPOT-life-hack
#1 [Fonction spéciale/VOYAGE] Dans un pays arabe passionnant et doux (4)/8)	livedoor-homme
#2 Twitter pour femme célibataire, façon surprenante de profiter de dokujo-tsushin
#3 L'histoire de la construction de la pyramide en 20 ans est un film de mensonge-enter
#4 Ayame Goriki présente un gâteau au chocolat fait à la main avec un film "beaucoup d'amour"-enter

Identifions la catégorie.

#Obtenir une liste de catégories à partir d'un ensemble de données
categories = list(set(livedoor_data['category']))
print(categories)
#['topic-news', 'movie-enter', 'livedoor-homme', 'it-life-hack', 'dokujo-tsushin', 'sports-watch', 'kaden-channel', 'peachy', 'smax']

#Créer un dictionnaire d'identifiants pour une catégorie
id2cat = dict(zip(list(range(len(categories))), categories))
cat2id = dict(zip(categories, list(range(len(categories)))))
print(id2cat)
print(cat2id)
#{0: 'topic-news', 1: 'movie-enter', 2: 'livedoor-homme', 3: 'it-life-hack', 4: 'dokujo-tsushin', 5: 'sports-watch', 6: 'kaden-channel', 7: 'peachy', 8: 'smax'}
#{'topic-news': 0, 'movie-enter': 1, 'livedoor-homme': 2, 'it-life-hack': 3, 'dokujo-tsushin': 4, 'sports-watch': 5, 'kaden-channel': 6, 'peachy': 7, 'smax': 8}

#Ajout de la colonne ID de catégorie à DataFrame
livedoor_data['category_id'] = livedoor_data['category'].map(cat2id)

#Mélangez juste au cas où
livedoor_data = livedoor_data.sample(frac=1).reset_index(drop=True)

#Faire de l'ensemble de données uniquement des colonnes de titre et d'ID de catégorie
livedoor_data = livedoor_data[['title', 'category_id']]
livedoor_data.head()
#title	category_id
#0 Nainai Okamura refuse de demander l'apparition du numéro spécial AKB "Qui apparaît dans un tel endroit ..." 0
#1	C-"Star Wars in Concert" où 3PO présente des scènes célèbres débarquées au Japon 1
#2 Livrer la scène de voyeur!?Un moment choquant a été découvert qui devrait être la diffusion d'un événement gratuit [Sujet] 6
#3 "À Mitsuhiro Oikawa dans le dernier épisode de mon partenaire," Traitement impitoyable "" et à la femme elle-même 0
#4 Au-dessus de Hasebe et Kazu? Il y a 5 joueurs surprenants dans "Les athlètes qui aiment les élèves du primaire"

Étant donné que torchtext est utilisé pour le prétraitement des données, séparez l'ensemble de données pour l'entraînement et le test et enregistrez-le dans un fichier tsv.

#Divisez en données d'entraînement et données de test
from sklearn.model_selection import train_test_split

train_df, test_df = train_test_split(livedoor_data, train_size=0.8)
print("Taille des données d'entraînement", train_df.shape[0])
print("Taille des données de test", test_df.shape[0])
#Taille des données d'entraînement 5900
#Taille des données de test 1476

#Enregistrer en tant que fichier tsv
train_df.to_csv(drive_dir + 'train.tsv', sep='\t', index=False, header=None)
test_df.to_csv(drive_dir + 'test.tsv', sep='\t', index=False, header=None)

Installez MeCab et huggingface / transformers

Je l'ai mentionné dans ici, mais il semble qu'une certaine prudence soit requise lors de l'installation de MeCab. À l'heure actuelle, si vous installez pip comme indiqué ci-dessous, cela fonctionne sans erreur.

#Préparer MeCab et les transformateurs
!apt install aptitude swig
!aptitude install mecab libmecab-dev mecab-ipadic-utf8 git make curl xz-utils file -y
#Mecab comme indiqué ci-dessous-python3 version 0.996.Si vous ne le réglez pas sur 5, il tombera avec tokezer
# https://stackoverflow.com/questions/62860717/huggingface-for-japanese-tokenizer
!pip install mecab-python3==0.996.5
!pip install unidic-lite #Sans cela, il échouera avec une erreur lors de l'exécution de MeCab
!pip install transformers

Créer un itérateur avec torchtext

Avec tokenizer.encode, vous pouvez exécuter l'écriture de division qui peut être utilisée avec le modèle japonais BERT, et avec tokenizer.convert_ids_to_tokens, vous pouvez convertir la chaîne d'identification divisée en morphologie et sous-mots. Très pratique.

import torch
import torchtext
from transformers.modeling_bert import BertModel
from transformers.tokenization_bert_japanese import BertJapaneseTokenizer

#Déclaré un tokenizer pour le partage de BERT japonais
tokenizer = BertJapaneseTokenizer.from_pretrained('cl-tohoku/bert-base-japanese-whole-word-masking')

#J'essaierai de le partager.
text = list(train_df['title'])[0]
wakati_ids = tokenizer.encode(text, return_tensors='pt')
print(tokenizer.convert_ids_to_tokens(wakati_ids[0].tolist()))
print(wakati_ids)
print(wakati_ids.size())
#['[CLS]', 'la taille', 'Mais', 'Faible', 'Femme', 'Est', 'mariage', 'À', 'Défavorable', '?', '[SEP]']
#tensor([[   2, 7236,   14, 3458,  969,    9, 1519,    7, 9839, 2935,    3]])
#torch.Size([1, 11])

Le modèle de pré-apprentissage japonais de l'Université de Tohoku qui peut être manipulé à partir du visage serré a jusqu'à 512 nombres premiers morphologiques (nombre de sous-mots) de phrases. Donc, si l'élément de formulaire des données à traiter et que le nombre de sous-mots dépasse 512, spécifiez max_length à 512. Cependant, en ce qui concerne les titres de ce corpus d'actualités livesoor, le nombre maximum est de 76 comme indiqué ci-dessous, donc max_length n'est pas spécifié cette fois.

#La longueur des phrases qui peuvent être traitées par le BERT japonais est de 512, mais la longueur maximale des titres d'actualités en direct est CLS.,76 même avec des jetons SEP
import seaborn as sns
title_length = livedoor_data['title'].map(tokenizer.encode).map(len)
print(max(title_length))
# 76

sns.distplot(title_length)

Créez un itérateur comme celui-ci. Puisque la taille de tokenizer.encode est (1 x longueur de phrase), il est nécessaire de spécifier[0].

#Créer un itérateur de données d'entraînement et de données de test à l'aide de torchtext
def bert_tokenizer(text):
  return tokenizer.encode(text, return_tensors='pt')[0]

TEXT = torchtext.data.Field(sequential=True, tokenize=bert_tokenizer, use_vocab=False, lower=False,
                            include_lengths=True, batch_first=True, pad_token=0)
LABEL = torchtext.data.Field(sequential=False, use_vocab=False)

train_data, test_data = torchtext.data.TabularDataset.splits(
    path=drive_dir, train='train.tsv', test='test.tsv', format='tsv', fields=[('Text', TEXT), ('Label', LABEL)])

#BERT semble utiliser une taille de mini-lot de 16 ou 32, mais le titre de livesoor a une longueur de phrase courte, donc même 32 fonctionnera sur colab.
BATCH_SIZE = 32
train_iter, test_iter = torchtext.data.Iterator.splits((train_data, test_data), batch_sizes=(BATCH_SIZE, BATCH_SIZE), repeat=False, sort=False)

Déclaration du modèle de classification

Vérifions les formats d'entrée et de sortie du BERT japonais appris avant. Le modèle BERT peut être facilement déclaré en une seule ligne comme suit. Trop pratique

from transformers.modeling_bert import BertModel
model = BertModel.from_pretrained('cl-tohoku/bert-base-japanese-whole-word-masking')

Vous pouvez voir la structure de BERT en imprimant le modèle lui-même. La sortie est longue, alors gardez-la fermée.

<détails>

Structure du modèle BERT </ summary>

BertModel(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(32000, 768, padding_idx=0)
    (position_embeddings): Embedding(512, 768)
    (token_type_embeddings): Embedding(2, 768)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0): BertLayer(
        (attention): BertAttention(
          (self): BertSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
        (intermediate): BertIntermediate(
          (dense): Linear(in_features=768, out_features=3072, bias=True)
        )
        (output): BertOutput(
          (dense): Linear(in_features=3072, out_features=768, bias=True)
          (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
      (1): BertLayer(
        (attention): BertAttention(
          (self): BertSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
        (intermediate): BertIntermediate(
          (dense): Linear(in_features=768, out_features=3072, bias=True)
        )
        (output): BertOutput(
          (dense): Linear(in_features=3072, out_features=768, bias=True)
          (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
      (2): BertLayer(
        (attention): BertAttention(
          (self): BertSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
        (intermediate): BertIntermediate(
          (dense): Linear(in_features=768, out_features=3072, bias=True)
        )
        (output): BertOutput(
          (dense): Linear(in_features=3072, out_features=768, bias=True)
          (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
      (3): BertLayer(
        (attention): BertAttention(
          (self): BertSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
        (intermediate): BertIntermediate(
          (dense): Linear(in_features=768, out_features=3072, bias=True)
        )
        (output): BertOutput(
          (dense): Linear(in_features=3072, out_features=768, bias=True)
          (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
      (4): BertLayer(
        (attention): BertAttention(
          (self): BertSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
        (intermediate): BertIntermediate(
          (dense): Linear(in_features=768, out_features=3072, bias=True)
        )
        (output): BertOutput(
          (dense): Linear(in_features=3072, out_features=768, bias=True)
          (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
      (5): BertLayer(
        (attention): BertAttention(
          (self): BertSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
        (intermediate): BertIntermediate(
          (dense): Linear(in_features=768, out_features=3072, bias=True)
        )
        (output): BertOutput(
          (dense): Linear(in_features=3072, out_features=768, bias=True)
          (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
      (6): BertLayer(
        (attention): BertAttention(
          (self): BertSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
        (intermediate): BertIntermediate(
          (dense): Linear(in_features=768, out_features=3072, bias=True)
        )
        (output): BertOutput(
          (dense): Linear(in_features=3072, out_features=768, bias=True)
          (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
      (7): BertLayer(
        (attention): BertAttention(
          (self): BertSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
        (intermediate): BertIntermediate(
          (dense): Linear(in_features=768, out_features=3072, bias=True)
        )
        (output): BertOutput(
          (dense): Linear(in_features=3072, out_features=768, bias=True)
          (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
      (8): BertLayer(
        (attention): BertAttention(
          (self): BertSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
        (intermediate): BertIntermediate(
          (dense): Linear(in_features=768, out_features=3072, bias=True)
        )
        (output): BertOutput(
          (dense): Linear(in_features=3072, out_features=768, bias=True)
          (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
      (9): BertLayer(
        (attention): BertAttention(
          (self): BertSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
        (intermediate): BertIntermediate(
          (dense): Linear(in_features=768, out_features=3072, bias=True)
        )
        (output): BertOutput(
          (dense): Linear(in_features=3072, out_features=768, bias=True)
          (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
      (10): BertLayer(
        (attention): BertAttention(
          (self): BertSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
        (intermediate): BertIntermediate(
          (dense): Linear(in_features=768, out_features=3072, bias=True)
        )
        (output): BertOutput(
          (dense): Linear(in_features=3072, out_features=768, bias=True)
          (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
      (11): BertLayer(
        (attention): BertAttention(
          (self): BertSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
        (intermediate): BertIntermediate(
          (dense): Linear(in_features=768, out_features=3072, bias=True)
        )
        (output): BertOutput(
          (dense): Linear(in_features=3072, out_features=768, bias=True)
          (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
  )
  (pooler): BertPooler(
    (dense): Linear(in_features=768, out_features=768, bias=True)
    (activation): Tanh()
  )
)

Comme vous pouvez le voir à partir de ce résultat, il existe un calque Embedding qui transforme d'abord les mots en vecteurs, puis il y a 12 calques Bert. Vous pouvez également confirmer que le nombre de dimensions vectorielles du mot et le nombre de dimensions du calque caché à l'intérieur sont de 768 dimensions.

Vérifions les formats d'entrée et de sortie de BertModel avec référence.

  • https://huggingface.co/transformers/model_doc/bert.html#bertmodel

Le format d'entrée du modèle BERT s'écrit (batch_size, sequence_length). La sortie semble retourner last_hidden_state et pooler_output par défaut, mais le poids Attention semble être obtenu en spécifiant ʻoutput_attentions = True`. Attention renvoie tous les résultats de chacune des 12 attentions multi-têtes dans le BertLayer 12 couches.

#À partir de l'itérateur de données de test créé ci-dessus
batch = next(iter(test_iter))
print(batch.Text[0].size())
# torch.Size([32, 48]) ←(batch_size, sequence_length)

#Sortie pendant la propagation directe BERT_attentions=Vous pouvez obtenir un poids Attention avec True
last_hidden_state, pooler_output, attentions = model(batch.Text[0], output_attentions=True)
print(last_hidden_state.size())
print(pooler_output.size())
print(len(attentions), attentions[-1].size())
#torch.Size([32, 48, 768]) ← (batch_size, sequence_length×hidden_size)
#torch.Size([32, 768])
#12 torch.Size([32, 12, 48, 48]) ← (batch_size, num_heads, sequence_length, sequence_length)

Lors de l'acquisition d'un vecteur de phrase avec BERT, le vecteur du jeton cls au début de chaque vecteur de mot de last_hidden_state est considéré comme un vecteur de phrase et utilisé.

Maintenant que nous avons en quelque sorte compris les formats d'entrée et de sortie du modèle BERT, nous allons construire un modèle qui classe réellement les phrases en utilisant BERT. Comme le fait l'article de référence, je pense qu'il est plus facile de comprendre la structure et il est plus facile de l'étudier en l'implémentant par vous-même au lieu d'utiliser la bibliothèque de classification fournie par huggingface, donc classification Implémentez sans utiliser la bibliothèque.

from torch import nn
import torch.nn.functional as F
from transformers.modeling_bert import BertModel

class BertClassifier(nn.Module):
  def __init__(self):
    super(BertClassifier, self).__init__()
    self.bert = BertModel.from_pretrained('cl-tohoku/bert-base-japanese-whole-word-masking')
    #Le calque caché de BERT a 768 dimensions,9 catégories d'actualités Liveoor
    self.linear = nn.Linear(768, 9)
    #Traitement d'initialisation du poids
    nn.init.normal_(self.linear.weight, std=0.02)
    nn.init.normal_(self.linear.bias, 0)

  def forward(self, input_ids):
    # last_hidden_Recevoir état et attentions
    vec, _, attentions = self.bert(input_ids, output_attentions=True)
    #Obtenez uniquement le vecteur du premier jeton cls
    vec = vec[:,0,:]
    vec = vec.view(-1, 768)
    #Convertir les dimensions pour la classification dans des couches entièrement connectées
    out = self.linear(vec)
    return F.log_softmax(out), attentions

classifier = BertClassifier()

Paramètres de réglage précis

Je n'ai pas encore effectué de réglage fin, mais comme dans l'article de référence, je désactive tous les paramètres une fois, puis je ne mets à jour que les parties pour lesquelles je souhaite mettre à jour les paramètres. J'ai beaucoup appris. De plus, en ce qui concerne le taux d'apprentissage, la dernière couche de BERT a déjà été pré-apprise, elle ne sera donc que peu mise à jour, et la dernière couche entièrement connectée insérée pour la classification aura un taux d'apprentissage plus élevé. Je vois je vois.

#Paramètres de réglage précis
#Effectuer le calcul du gradient uniquement pour le dernier module BertLayer et l'adaptateur de classification ajouté

#Tout d'abord OFF
for param in classifier.parameters():
    param.requires_grad = False

#Mettre à jour uniquement la dernière couche de BERT ON
for param in classifier.bert.encoder.layer[-1].parameters():
    param.requires_grad = True

#La classification de classe est également activée
for param in classifier.linear.parameters():
    param.requires_grad = True

import torch.optim as optim

#Le taux d'apprentissage est faible pour la partie pré-apprise et élevé pour la dernière couche entièrement connectée.
optimizer = optim.Adam([
    {'params': classifier.bert.encoder.layer[-1].parameters(), 'lr': 5e-5},
    {'params': classifier.linear.parameters(), 'lr': 1e-4}
])

#Paramètres de la fonction de perte
loss_function = nn.NLLLoss()

Apprentissage

Comme dans l'article de référence, il est en fait préférable d'écrire séparément en mode apprentissage et en mode vérification, mais pour le moment, je veux le déplacer, donc je boucle avec seulement le code minimum pour apprendre comme suit. .. La précision finale n'a pas beaucoup changé, que le nombre d'époques soit 5 ou 10, j'ai donc fixé le nombre d'époques à 5. La perte diminue régulièrement, donc ça va pour le moment.

#Paramètres GPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
#Envoyer le réseau au GPU
classifier.to(device)
losses = []

#Le nombre d'époques est de 5
for epoch in range(5):
  all_loss = 0
  for idx, batch in enumerate(train_iter):
    batch_loss = 0
    classifier.zero_grad()
    input_ids = batch.Text[0].to(device)
    label_ids = batch.Label.to(device)
    out, _ = classifier(input_ids)
    batch_loss = loss_function(out, label_ids)
    batch_loss.backward()
    optimizer.step()
    all_loss += batch_loss.item()
  print("epoch", epoch, "\t" , "loss", all_loss)
#epoch 0 	 loss 246.03703904151917
#epoch 1 	 loss 108.01931090652943
#epoch 2 	 loss 80.69403756409883
#epoch 3 	 loss 62.87365382164717
#epoch 4 	 loss 50.78619819134474

Contrôle de précision

Regardons le score F. Le texte de l'article semble dépasser 90%, mais le classement du seul titre aboutit à 85%. Bien que le titre ait une signification sommaire pour l'article, j'ai souvent été intéressé par cette courte phrase jusqu'à 85%.

from sklearn.metrics import classification_report

answer = []
prediction = []
with torch.no_grad():
    for batch in test_iter:

        text_tensor = batch.Text[0].to(device)
        label_tensor = batch.Label.to(device)

        score, _ = classifier(text_tensor)
        _, pred = torch.max(score, 1)

        prediction += list(pred.cpu().numpy())
        answer += list(label_tensor.cpu().numpy())
print(classification_report(prediction, answer, target_names=categories))
#                precision    recall  f1-score   support
#
#    topic-news       0.80      0.82      0.81       158
#   movie-enter       0.85      0.82      0.83       178
#livedoor-homme       0.68      0.73      0.70       108
#  it-life-hack       0.88      0.82      0.85       179
#dokujo-tsushin       0.82      0.85      0.84       144
#  sports-watch       0.89      0.87      0.88       180
# kaden-channel       0.91      0.97      0.94       180
#        peachy       0.78      0.77      0.78       172
#          smax       0.94      0.91      0.92       177
#
#      accuracy                           0.85      1476
#     macro avg       0.84      0.84      0.84      1476
#  weighted avg       0.85      0.85      0.85      1476

Visualisation de l'attention

Enfin, vérifions la base de jugement de la classification des phrases en visualisant Attention. Le poids Attention à visualiser mettait à jour les paramètres de la dernière couche de BertLayer lors de la configuration du réglage fin, c'est-à-dire que le poids Attention de la dernière couche a été appris pour cette classification de titre, donc le poids Attention de la dernière couche est Il semble qu'il puisse servir de base pour juger de cette tâche.

Le modèle BertClassifer déclaré cette fois renvoie tous les poids d'Attention, donc n'obtenez que la dernière couche comme suit et vérifiez à nouveau la taille.

batch = next(iter(test_iter))
score, attentions = classifier(batch.Text[0].to(device))
#Obtenez uniquement le poids Attention de la dernière couche et vérifiez la taille
print(attentions[-1].size())
# torch.Size([32, 12, 48, 48])

Quand j'ai vérifié à nouveau la Reference, la signification de cette taille était (batch_size, num_heads, sequence_length, sequence_length). Puisque l'attention de BertEncoder est l'auto-attention, combien d'attention est accordée à chaque mot de la deuxième longueur_séquence pour chaque mot de la première longueur_séquence. Cette fois, les phrases ont été classées en utilisant le premier jeton cls, donc en visualisant à quel mot le vecteur du premier jeton est Attention, il semble que cela puisse être considéré comme la base pour juger cette tâche. De plus, l'auto-attention de BERT est de 12 attentions multi-têtes, donc lorsque je la visualise, je vais ajouter les 12 poids d'attention et l'utiliser.

J'ai essayé d'implémenter la partie visualisation comme suit en référence au livre de référence.

def highlight(word, attn):
  html_color = '#%02X%02X%02X' % (255, int(255*(1 - attn)), int(255*(1 - attn)))
  return '<span style="background-color: {}">{}</span>'.format(html_color, word)

def mk_html(index, batch, preds, attention_weight):
  sentence = batch.Text[0][index]
  label =batch.Label[index].item()
  pred = preds[index].item()

  label_str = id2cat[label]
  pred_str = id2cat[pred]

  html = "Catégorie de réponse correcte: {}<br>Catégorie de prédiction: {}<br>".format(label_str, pred_str)

  #Déclarer zéro tenseur pour la longueur de la phrase
  seq_len = attention_weight.size()[2]
  all_attens = torch.zeros(seq_len).to(device)

  for i in range(12):
    all_attens += attention_weight[index, i, 0, :]

  for word, attn in zip(sentence, all_attens):
    if tokenizer.convert_ids_to_tokens([word.tolist()])[0] == "[SEP]":
      break
    html += highlight(tokenizer.convert_ids_to_tokens([word.numpy().tolist()])[0], attn)
  html += "<br><br>"
  return html

batch = next(iter(test_iter))
score, attentions = classifier(batch.Text[0].to(device))
_, pred = torch.max(score, 1)

from IPython.display import display, HTML
for i in range(BATCH_SIZE):
  html_output = mk_html(i, batch, pred, attentions[-1])
  display(HTML(html_output))

Voici quelques résultats de visualisation.

――Le magasin Yodobashi Camera Umeda est divisé en sous-mots, mais il est lié aux appareils ménagers, donc c'est une attention partielle mais. image.png

--Il est intéressant de le juger comme un canal kaden basé sur Takahashi Meijin (une personne qui frappe rapidement) image.png

--peachy (articles sur l'amour des femmes). C'est aussi sympa. image.png

――Avez-vous été traîné par Peachy dans le vrai discours? image.png

J'ai principalement présenté les bons, mais honnêtement, j'ai trouvé que c'était une attention délicate dans l'ensemble. (Je suis inquiet si la mise en œuvre est vraiment correcte ...) Cependant, j'étais intéressé par le fait qu'il est étonnant de prêter attention aux parties qui ne sont pas si bonnes même si elles sont divisées en sous-mots.

en conclusion

Grâce à huggingface / transformers et articles de référence, je suis capable de déplacer BERT, quoique d'une manière ou d'une autre. Je souhaite utiliser BERT pour diverses tâches

fin

Recommended Posts

J'ai essayé d'implémenter la classification des phrases et la visualisation de l'attention par le japonais BERT avec PyTorch
J'ai essayé d'implémenter la classification des phrases par Self Attention avec PyTorch
J'ai essayé de comparer la précision de la classification des phrases BERT japonaises et japonaises Distil BERT avec PyTorch et introduction de la technique d'amélioration de la précision BERT
J'ai essayé d'implémenter PLSA en Python
J'ai essayé d'implémenter la permutation en Python
J'ai essayé d'implémenter PLSA dans Python 2
J'ai essayé d'implémenter ADALINE en Python
J'ai essayé d'implémenter PPO en Python
J'ai essayé d'implémenter CVAE avec PyTorch
J'ai essayé d'implémenter la régression linéaire bayésienne par échantillonnage de Gibbs en python
J'ai essayé d'implémenter la lecture de Dataset avec PyTorch
[PyTorch] Introduction à la classification des documents japonais à l'aide de BERT
J'ai essayé d'implémenter le tri sélectif en python
J'ai essayé d'implémenter un pseudo pachislot en Python
J'ai essayé d'implémenter le poker de Drakue en Python
J'ai essayé d'implémenter GA (algorithme génétique) en Python
J'ai essayé d'implémenter SSD avec PyTorch maintenant (Dataset)
J'ai essayé d'implémenter PCANet
J'ai essayé d'implémenter StarGAN (1)
[Django] J'ai essayé d'implémenter des restrictions d'accès par héritage de classe.
J'ai essayé de classer MNIST par GNN (avec PyTorch géométrique)
J'ai essayé d'implémenter la fonction d'envoi de courrier en Python
J'ai essayé d'implémenter le blackjack du jeu Trump en Python
J'ai essayé d'implémenter SSD avec PyTorch maintenant (édition du modèle)
J'ai essayé d'implémenter Deep VQE
J'ai essayé de mettre en place une validation contradictoire
J'ai essayé d'expliquer l'ensemble de données de Pytorch
J'ai essayé d'implémenter Realness GAN
J'ai essayé de mettre en œuvre un jeu de dilemme de prisonnier mal compris en Python
J'ai essayé d'implémenter Autoencoder avec TensorFlow
[PyTorch] Introduction à la classification de documents à l'aide de BERT
J'ai essayé d'implémenter le jeu de cartes de Trump en Python
J'ai essayé de résumer tous les outils de visualisation Python utilisés dans la recherche par des étudiants diplômés en sciences actifs [Application]
J'ai essayé d'implémenter le tri par fusion en Python avec le moins de lignes possible
[PyTorch] Comment utiliser BERT - Réglage fin des modèles pré-entraînés japonais pour résoudre les problèmes de classification
J'ai essayé de prédire l'évolution de la quantité de neige pendant 2 ans par apprentissage automatique
J'ai essayé d'implémenter ce qui semble être un outil de snipper Windows avec Python
J'ai essayé de programmer la bulle de tri par langue
J'ai essayé d'obtenir une image en grattant
J'ai essayé d'intégrer Keras dans TFv1.1
J'ai essayé de classer les boules de dragon par adaline
J'ai essayé de mettre en œuvre le problème du voyageur de commerce
[Keras] J'ai essayé de résoudre le problème de classification des zones de type beignet par apprentissage automatique [Étude]
[Introduction] J'ai essayé de l'implémenter moi-même tout en expliquant l'arbre de dichotomie
[Série pour les gens occupés] J'ai essayé de résumer avec une analyse de syntaxe pour appeler les actualités en 30 secondes
[Introduction] J'ai essayé de l'implémenter moi-même tout en expliquant pour comprendre la dichotomie
[Explication de la mise en œuvre] Comment utiliser la version japonaise de BERT dans Google Colaboratory (PyTorch)
J'ai essayé de représenter graphiquement les packages installés en Python
J'ai essayé de mettre en œuvre la gestion des processus statistiques multivariés (MSPC)
J'ai essayé d'implémenter Mine Sweeper sur un terminal avec python
J'ai essayé d'implémenter le perceptron artificiel avec python
[Introduction à Pytorch] J'ai essayé de catégoriser Cifar10 avec VGG16 ♬
J'ai essayé de résumer comment utiliser les pandas de python
J'ai essayé d'implémenter Grad-CAM avec keras et tensorflow
J'ai essayé d'implémenter le calcul automatique de la preuve de séquence
Django super introduction par les débutants Python! Partie 6 J'ai essayé d'implémenter la fonction de connexion