Le flux d'implémentation pour classer les documents à l'aide de torchtext est expliqué avec Official tutorial. De plus, le Google Colabolatry qui accompagne le tutoriel officiel donne une erreur. Je posterai le code après avoir corrigé la partie qui est. Enfin, j'expliquerai le code source de torchtext.datasets.text_classification.
Google Colabolatry
Termes de base pour le traitement du langage naturel tels que N-gramme
Lors de la classification de documents à l'aide de torchtext, le flux de mise en œuvre est le suivant. Nous examinerons le code dans la section suivante, donc cette théorie ne donne qu'un aperçu.
Vérifions le flux ci-dessus avec le code dans le didacticiel.
!pip install torch<=1.2.0
!pip install torchtext
%matplotlib inline
Si vous l'exécutez tel quel, l'erreur suivante se produira lors de l'importation du module décrit plus loin.
from torchtext.datasets import text_classification
ImportError: cannot import name 'text_classification'
Le code correct ressemble à ceci: En outre, il peut être nécessaire d'initialiser le runtime en raison de changements dans la version de torchtext. Dans ce cas, exécutez simplement le redémarrage et réexécutez les cellules de haut en bas (vous n'avez pas besoin d'appuyer sur le redémarrage après la deuxième installation de pip).
!pip install torch<=1.2.0
!pip install torchtext==0.5
%matplotlib inline
La cause est la version de torchtext. Si vous installez pip sans rien spécifier, la version 0.3.1 sera installée. Étant donné que text_classification est implémentée dans la version 0.4 ou ultérieure, elle ne peut pas être utilisée telle quelle dans la version 0.3. Dans ce qui précède, il est fixé à 0,5, mais s'il est égal ou supérieur à 0,4, il n'y a pas de problème.
import torch
import torchtext
from torchtext.datasets import text_classification
NGRAMS = 2
import os
if not os.path.isdir('./.data'):
os.mkdir('./.data')
train_dataset, test_dataset = text_classification.DATASETS['AG_NEWS'](
root='./.data', ngrams=NGRAMS, vocab=None)
BATCH_SIZE = 16
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
Il a un flux simple d'intégration → linéaire. Dans init_weight, les poids sont initialisés avec les poids générés à partir de la distribution uniforme.
import torch.nn as nn
import torch.nn.functional as F
class TextSentiment(nn.Module):
def __init__(self, vocab_size, embed_dim, num_class):
super().__init__()
self.embedding = nn.EmbeddingBag(vocab_size, embed_dim, sparse=True)
self.fc = nn.Linear(embed_dim, num_class)
self.init_weights()
def init_weights(self):
initrange = 0.5
self.embedding.weight.data.uniform_(-initrange, initrange)
self.fc.weight.data.uniform_(-initrange, initrange)
self.fc.bias.data.zero_()
def forward(self, text, offsets):
embedded = self.embedding(text, offsets)
return self.fc(embedded)
VOCAB_SIZE = len(train_dataset.get_vocab())
EMBED_DIM = 32
NUN_CLASS = len(train_dataset.get_labels())
model = TextSentiment(VOCAB_SIZE, EMBED_DIM, NUN_CLASS).to(device)
def generate_batch(batch):
label = torch.tensor([entry[0] for entry in batch])
text = [entry[1] for entry in batch]
offsets = [0] + [len(entry) for entry in text]
offsets = torch.tensor(offsets[:-1]).cumsum(dim=0)
text = torch.cat(text)
return text, offsets, label
from torch.utils.data import DataLoader
def train_func(sub_train_):
# Train the model
train_loss = 0
train_acc = 0
data = DataLoader(sub_train_, batch_size=BATCH_SIZE, shuffle=True,
collate_fn=generate_batch)
for i, (text, offsets, cls) in enumerate(data):
optimizer.zero_grad()
text, offsets, cls = text.to(device), offsets.to(device), cls.to(device)
output = model(text, offsets)
loss = criterion(output, cls)
train_loss += loss.item()
loss.backward()
optimizer.step()
train_acc += (output.argmax(1) == cls).sum().item()
# Adjust the learning rate
scheduler.step()
return train_loss / len(sub_train_), train_acc / len(sub_train_)
def test(data_):
loss = 0
acc = 0
data = DataLoader(data_, batch_size=BATCH_SIZE, collate_fn=generate_batch)
for text, offsets, cls in data:
text, offsets, cls = text.to(device), offsets.to(device), cls.to(device)
with torch.no_grad():
output = model(text, offsets)
loss = criterion(output, cls)
loss += loss.item()
acc += (output.argmax(1) == cls).sum().item()
return loss / len(data_), acc / len(data_)
Si vous apprenez correctement, vous pouvez obtenir une précision de 0,9 ou plus.
import time
from torch.utils.data.dataset import random_split
N_EPOCHS = 5
min_valid_loss = float('inf')
criterion = torch.nn.CrossEntropyLoss().to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=4.0)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1, gamma=0.9)
train_len = int(len(train_dataset) * 0.95)
sub_train_, sub_valid_ = \
random_split(train_dataset, [train_len, len(train_dataset) - train_len])
for epoch in range(N_EPOCHS):
start_time = time.time()
train_loss, train_acc = train_func(sub_train_)
valid_loss, valid_acc = test(sub_valid_)
secs = int(time.time() - start_time)
mins = secs / 60
secs = secs % 60
print('Epoch: %d' %(epoch + 1), " | time in %d minutes, %d seconds" %(mins, secs))
print(f'\tLoss: {train_loss:.4f}(train)\t|\tAcc: {train_acc * 100:.1f}%(train)')
print(f'\tLoss: {valid_loss:.4f}(valid)\t|\tAcc: {valid_acc * 100:.1f}%(valid)')
Dans TORCHTEXT.DATASETS.TEXT_CLASSIFICATION, le traitement est effectué pour fournir littéralement les données nécessaires. Au contraire, aucune autre opération n'est effectuée. En d'autres termes, l'objectif de ce module est de formater les données nécessaires à la formation pour divers ensembles de données. Par conséquent, cette fois, nous nous concentrerons sur le flux de fourniture d'ensembles de données de train et de test. Le code source décrit dans la description suivante est ici. Tout d'abord, je republierai le code suivant.
if not os.path.isdir('./.data'):
os.mkdir('./.data')
train_dataset, test_dataset = text_classification.DATASETS['AG_NEWS'](
root='./.data', ngrams=NGRAMS, vocab=None)
BATCH_SIZE = 16
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
Ici, vous pouvez voir qu'un répertoire appelé .data est créé et que ce répertoire est utilisé comme racine pour générer des ensembles de données de train et de test. Cependant, cela seul a divers points peu clairs, y compris les données. Alors, lisons le code et voyons un traitement plus spécifique.
Certaines données sont fournies pour la classification des documents. Les données actuellement fournies sont les suivantes.
Si vous souhaitez obtenir directement chaque donnée, vous pouvez la télécharger à partir de l'url décrite dans la variable URLS.
URLS = {
'AG_NEWS':
'https://drive.google.com/uc?export=download&id=0Bz8a_Dbh9QhbUDNpeUdjb0wxRms',
'SogouNews':
'https://drive.google.com/uc?export=download&id=0Bz8a_Dbh9QhbUkVqNEszd0pHaFE',
'DBpedia':
'https://drive.google.com/uc?export=download&id=0Bz8a_Dbh9QhbQ2Vic1kxMmZZQ1k',
'YelpReviewPolarity':
'https://drive.google.com/uc?export=download&id=0Bz8a_Dbh9QhbNUpYQ2N3SGlFaDg',
'YelpReviewFull':
'https://drive.google.com/uc?export=download&id=0Bz8a_Dbh9QhbZlU4dXhHTFhZQU0',
'YahooAnswers':
'https://drive.google.com/uc?export=download&id=0Bz8a_Dbh9Qhbd2JNdDBsQUdocVU',
'AmazonReviewPolarity':
'https://drive.google.com/uc?export=download&id=0Bz8a_Dbh9QhbaW12WVVZS2drcnM',
'AmazonReviewFull':
'https://drive.google.com/uc?export=download&id=0Bz8a_Dbh9QhbZVhsUnRWRDhETzA'
}
Maintenant, suivons réellement le traitement des données à travers le code source. La première chose à faire est la définition de la fonction.
def AG_NEWS(*args, **kwargs):
""" Defines AG_NEWS datasets.
The labels includes:
- 1 : World
- 2 : Sports
- 3 : Business
- 4 : Sci/Tech
Create supervised learning dataset: AG_NEWS
Separately returns the training and test dataset
Arguments:
root: Directory where the datasets are saved. Default: ".data"
ngrams: a contiguous sequence of n items from s string text.
Default: 1
vocab: Vocabulary used for dataset. If None, it will generate a new
vocabulary based on the train data set.
include_unk: include unknown token in the data (Default: False)
Examples:
>>> train_dataset, test_dataset = torchtext.datasets.AG_NEWS(ngrams=3)
"""
return _setup_datasets(*(("AG_NEWS",) + args), **kwargs)
Vous pouvez voir que les données formatées sont renvoyées à l'aide de la fonction _setup_datasets. Désormais, seul AG_NEWS est ciblé, mais le même traitement est effectué pour les autres ensembles de données. Ensuite, enregistrez la fonction définie dans la variable DATASETS au format dict.
DATASETS = {
'AG_NEWS': AG_NEWS,
'SogouNews': SogouNews,
'DBpedia': DBpedia,
'YelpReviewPolarity': YelpReviewPolarity,
'YelpReviewFull': YelpReviewFull,
'YahooAnswers': YahooAnswers,
'AmazonReviewPolarity': AmazonReviewPolarity,
'AmazonReviewFull': AmazonReviewFull
}
En outre, la variable LABELS stocke les informations d'étiquette pour chaque ensemble de données au format dict.
LABELS = {
'AG_NEWS': {1: 'World',
2: 'Sports',
3: 'Business',
4: 'Sci/Tech'},
}
Bien que omis ici, les étiquettes autres que AG_NEWS sont stockées dans le même format. Puisque la fonction est enregistrée au format dict avec la variable DATASETS ci-dessus, les deux suivantes font référence à la même chose.
text_classification.DATASETS['AG_NEWS']
text_classification.AG_NEWS
Vérifiez le traitement des données en regardant la fonction _setup_datasets.
def _setup_datasets(dataset_name, root='.data', ngrams=1, vocab=None, include_unk=False):
dataset_tar = download_from_url(URLS[dataset_name], root=root)
extracted_files = extract_archive(dataset_tar)
for fname in extracted_files:
if fname.endswith('train.csv'):
train_csv_path = fname
if fname.endswith('test.csv'):
test_csv_path = fname
if vocab is None:
logging.info('Building Vocab based on {}'.format(train_csv_path))
vocab = build_vocab_from_iterator(_csv_iterator(train_csv_path, ngrams))
else:
if not isinstance(vocab, Vocab):
raise TypeError("Passed vocabulary is not of type Vocab")
logging.info('Vocab has {} entries'.format(len(vocab)))
logging.info('Creating training data')
train_data, train_labels = _create_data_from_iterator(
vocab, _csv_iterator(train_csv_path, ngrams, yield_cls=True), include_unk)
logging.info('Creating testing data')
test_data, test_labels = _create_data_from_iterator(
vocab, _csv_iterator(test_csv_path, ngrams, yield_cls=True), include_unk)
if len(train_labels ^ test_labels) > 0:
raise ValueError("Training and test labels don't match")
return (TextClassificationDataset(vocab, train_data, train_labels),
TextClassificationDataset(vocab, test_data, test_labels))
Le traitement principal est le suivant.
class TextClassificationDataset(torch.utils.data.Dataset):
"""Defines an abstract text classification datasets.
Currently, we only support the following datasets:
- AG_NEWS
- SogouNews
- DBpedia
- YelpReviewPolarity
- YelpReviewFull
- YahooAnswers
- AmazonReviewPolarity
- AmazonReviewFull
"""
[docs] def __init__(self, vocab, data, labels):
"""Initiate text-classification dataset.
Arguments:
vocab: Vocabulary object used for dataset.
data: a list of label/tokens tuple. tokens are a tensor after
numericalizing the string tokens. label is an integer.
[(label1, tokens1), (label2, tokens2), (label2, tokens3)]
label: a set of the labels.
{label1, label2}
Examples:
See the examples in examples/text_classification/
"""
super(TextClassificationDataset, self).__init__()
self._data = data
self._labels = labels
self._vocab = vocab
def __getitem__(self, i):
return self._data[i]
def __len__(self):
return len(self._data)
def __iter__(self):
for x in self._data:
yield x
def get_labels(self):
return self._labels
def get_vocab(self):
return self._vocab
Vous pouvez voir qu'il s'agit d'une classe pour récupérer chaque donnée, pas pour traiter de nouvelles données. Comme vous pouvez le voir à partir de la fonction _setup_datasets et de la classe TextClassificationDataset, l'ensemble de données est converti en N-gramme et en état stocké plutôt que dans le document brut. Par conséquent, si vous souhaitez utiliser un format de données autre que N-gramme, vous devez écrire votre propre traitement en fonction des données enregistrées dans .data ou des données téléchargées à partir de l'url décrite dans URLS.
Les informations difficiles à comprendre par simple impression peuvent être comprises en traçant le code source. Je souhaite continuer à lire le code source et à compiler des informations.
Recommended Posts