Remarques sur l'utilisation de tf.data

Résumez ce que vous avez appris sur tf.data pour gérer les fichiers volumineux

J'ai renvoyé au site suivant https://qiita.com/Suguru_Toyohara/items/820b0dad955ecd91c7f3 https://qiita.com/wasnot/items/9b64550237a3c5267bfd https://qiita.com/everylittle/items/a7c31b08d2f76c886a92

Qu'est-ce que tf.data

Une bibliothèque pour la fourniture de données tensorflow. Il semble avoir les avantages suivants lorsqu'il est utilisé.

  1. Vous pouvez réduire la latence du GPU et maximiser la vitesse d'apprentissage
  2. Les données qui ne rentrent pas dans la mémoire peuvent être lues séquentiellement
  3. Le prétraitement tel que l’augumentation des données peut être accéléré
  4. Mettre en place un pipeline de prétraitement

1. Convertissez et utilisez numpy.array

Convertir

Vous pouvez l'utiliser en convertissant de np.array en objet tf.data

python


import numpy as np
import tensorflow as tf

arr = np.arange(25).reshape(5, 5)
dataset = tf.data.Dataset.from_tensor_slices(arr)

for item in dataset:
    print(item)

output


tf.Tensor([0 1 2 3 4], shape=(5,), dtype=int32)
tf.Tensor([5 6 7 8 9], shape=(5,), dtype=int32)
tf.Tensor([10 11 12 13 14], shape=(5,), dtype=int32)
tf.Tensor([15 16 17 18 19], shape=(5,), dtype=int32)
tf.Tensor([20 21 22 23 24], shape=(5,), dtype=int32)

repeat

Sortie en répétant les heures des arguments

python


arr = np.arange(25).reshape(5, 5)
dataset = tf.data.Dataset.from_tensor_slices(arr).repeat(3)

for item in dataset:
    print(item)

output


tf.Tensor([0 1 2 3 4], shape=(5,), dtype=int32)
tf.Tensor([5 6 7 8 9], shape=(5,), dtype=int32)
tf.Tensor([10 11 12 13 14], shape=(5,), dtype=int32)
tf.Tensor([15 16 17 18 19], shape=(5,), dtype=int32)
tf.Tensor([20 21 22 23 24], shape=(5,), dtype=int32)
tf.Tensor([0 1 2 3 4], shape=(5,), dtype=int32)
tf.Tensor([5 6 7 8 9], shape=(5,), dtype=int32)
tf.Tensor([10 11 12 13 14], shape=(5,), dtype=int32)
tf.Tensor([15 16 17 18 19], shape=(5,), dtype=int32)
tf.Tensor([20 21 22 23 24], shape=(5,), dtype=int32)
tf.Tensor([0 1 2 3 4], shape=(5,), dtype=int32)
tf.Tensor([5 6 7 8 9], shape=(5,), dtype=int32)
tf.Tensor([10 11 12 13 14], shape=(5,), dtype=int32)
tf.Tensor([15 16 17 18 19], shape=(5,), dtype=int32)
tf.Tensor([20 21 22 23 24], shape=(5,), dtype=int32)

batch

Les arguments sont groupés et sortis

python


arr = np.arange(25).reshape(5, 5)
dataset = tf.data.Dataset.from_tensor_slices(arr).batch(2)

for item in dataset:
    print(item)

output


tf.Tensor(
[[0 1 2 3 4]
 [5 6 7 8 9]], shape=(2, 5), dtype=int32)

tf.Tensor(
[[10 11 12 13 14]
 [15 16 17 18 19]], shape=(2, 5), dtype=int32)

tf.Tensor([[20 21 22 23 24]], shape=(1, 5), dtype=int32)

shuffle

L'argument spécifie dans quelle mesure les données doivent être remplacées. Si l'argument est 1, il n'y aura pas de remplacement, et s'il s'agit d'une petite valeur, il ne sera pas suffisamment mélangé, donc je pense qu'il est préférable d'entrer la même valeur que la taille des données.

Cliquez ici pour plus de détails sur la taille de la lecture aléatoire https://qiita.com/exy81/items/d1388f6f02a11c8f1d7e

python


arr = np.arange(25).reshape(5, 5)
dataset = tf.data.Dataset.from_tensor_slices(arr).shuffle(5)

for item in dataset:
    print(item)

output


tf.Tensor(
[[0 1 2 3 4]
 [5 6 7 8 9]], shape=(2, 5), dtype=int32)

tf.Tensor(
[[10 11 12 13 14]
 [15 16 17 18 19]], shape=(2, 5), dtype=int32)

tf.Tensor([[20 21 22 23 24]], shape=(1, 5), dtype=int32)

combinaison

Vous pouvez utiliser les éléments ci-dessus en combinaison. Il sera exécuté dans l'ordre, alors faites attention à ne pas faire la chose insignifiante de couper le lot, puis de le mélanger.

python


arr = np.arange(25).reshape(5, 5)
dataset = tf.data.Dataset.from_tensor_slices(arr).repeat(2).shuffle(5).batch(4)

for item in dataset:
    print(item)
    print()

output


tf.Tensor(
[[15 16 17 18 19]
 [ 0  1  2  3  4]
 [ 5  6  7  8  9]
 [10 11 12 13 14]], shape=(4, 5), dtype=int32)

tf.Tensor(
[[20 21 22 23 24]
 [ 0  1  2  3  4]
 [20 21 22 23 24]
 [10 11 12 13 14]], shape=(4, 5), dtype=int32)

tf.Tensor(
[[15 16 17 18 19]
 [ 5  6  7  8  9]], shape=(2, 5), dtype=int32)

argumentation

Vous pouvez appliquer la fonction avec dataset.map ().

Il est souhaitable que la fonction à appliquer soit composée de la fonction tensorflow, mais il semble qu'il soit également possible de convertir une fonction normalement écrite avec @ tf.function et tf.py_function.

python


import numpy as np
import matplotlib.pyplot as plt
from scipy import ndimage

def rotate(image):
    return ndimage.rotate(image, np.random.uniform(-30, 30), reshape=False)

@tf.function
def rotate_tf(image):
    rotated = tf.py_function(rotate,[image],[tf.int32])
    return rotated[0]

[train_x, train_y], [test_x, test_y] =  tf.keras.datasets.mnist.load_data()
train_x = train_x.reshape(-1,28,28,1)
dataset = tf.data.Dataset.from_tensor_slices(train_x)
dataset = dataset.map(rotate_tf).batch(16)

first_batch = next(iter(dataset))
images = first_batch.numpy().reshape((-1,28,28))

plt.figure(figsize=(4, 4))
for i, image in enumerate(sample_images):
    plt.subplot(4, 4,i+1)
    plt.xticks([])
    plt.yticks([])
    plt.imshow(image)
    plt.grid(False)
plt.show()

image.png

Mettre x et y ensemble dans un ensemble de données

Vous pouvez également combiner plusieurs données en un seul ensemble de données

python


def make_model():
    tf.keras.backend.clear_session()

    inputs = tf.keras.layers.Input(shape=(28, 28))
    network = tf.keras.layers.Flatten()(inputs)
    network = tf.keras.layers.Dense(100, activation='relu')(network)
    outputs = tf.keras.layers.Dense(10, activation='softmax')(network)

    model = tf.keras.models.Model(inputs=inputs, outputs=outputs)
    model.compile(optimizer='rmsprop', loss='sparse_categorical_crossentropy', 
                  metrics=['accuracy'])
    model.summary()
    return model

[x_train, y_train], [x_test, y_test] = tf.keras.datasets.mnist.load_data()

train_data = tf.data.Dataset.from_tensor_slices((x_train, y_train)).shuffle(x_train.shape[0]).batch(64)
test_data = tf.data.Dataset.from_tensor_slices((x_test, y_test)).shuffle(x_test.shape[0]).batch(64)

model = make_model()
hist = model.fit(train_data, validation_data=test_data,
                 epochs=10, verbose=False)

plt.figure(figsize=(4,4))
plt.plot(hist.history['loss'], label='loss')
plt.plot(hist.history['val_loss'], label='val_loss')
plt.show()

2. Sérialiser et utiliser

Afin de lire les données efficacement, il est recommandé de sérialiser les données et de les enregistrer sous la forme d'un ensemble de fichiers de 100 à 200 Mo pouvant être lus en continu. Vous pouvez facilement le faire avec TFRecord.

sauvegarder

Exporter le fichier TFRecord avec tf.io.TFRecordWriter ()

python


[x_train, y_train], [x_test, y_test] = tf.keras.datasets.mnist.load_data()

def make_example(image, label):
    return tf.train.Example(features=tf.train.Features(feature={
        'x' : tf.train.Feature(float_list=tf.train.FloatList(value=image)),
        'y' : tf.train.Feature(int64_list=tf.train.Int64List(value=label))
    }))

def write_tfrecord(images, labels, filename):
    writer = tf.io.TFRecordWriter(filename)
    for image, label in zip(images, labels):
        ex = make_example(image.ravel().tolist(), [int(label)])
        writer.write(ex.SerializeToString())
    writer.close()

write_tfrecord(x_train, y_train, '../mnist_train.tfrecord')
write_tfrecord(x_test, y_test, '../mnist_test.tfrecord')

image.png La taille du fichier semble être légèrement supérieure à npz.

Lire (1 enregistrement à la fois)

Lire avec tf.data.TFRecordDataset ()

Les données lues sont sérialisées et doivent être analysées. Dans l'exemple ci-dessous, tf.io.parse_single_example () est utilisé pour l'analyse. Si vous l'appelez avec la même clé que lorsque vous l'avez écrite et remodelée, ce sera la même chose que tf.data avant la sérialisation.

python


def parse_features(example):
    features = tf.io.parse_single_example(example, features={
        'x' : tf.io.FixedLenFeature([28, 28], tf.float32),
        'y' : tf.io.FixedLenFeature([1], tf.int64),
    })
    x = features['x']
    y = features['y']
    return x, y

train_dataset = tf.data.TFRecordDataset(filenames='../mnist_train.tfrecord')
train_dataset = train_dataset.map(parse_features).shuffle(60000).batch(512)

test_dataset = tf.data.TFRecordDataset(filenames='../mnist_test.tfrecord')
test_dataset = test_dataset.map(parse_features).shuffle(12000).batch(512)

model = make_model()
hist = model.fit(train_dataset, validation_data=test_dataset,
                 epochs=10, verbose=False)

plt.figure(figsize=(4, 4))
plt.plot(hist.history['loss'], label='loss')
plt.plot(hist.history['val_loss'], label='val_loss')
plt.show()

Lire (unité de lot)

En fait, il est plus rapide d'analyser par lot que d'analyser un enregistrement à la fois avec tf.io.parse_single_example (), il est donc recommandé d'analyser par lot.

python


def dict2tuple(feat):
    return feat["x"], feat["y"]

train_dataset = tf.data.TFRecordDataset(filenames='../mnist_train.tfrecord').batch(512).apply(
                    tf.data.experimental.parse_example_dataset({
                    "x": tf.io.FixedLenFeature([28, 28], dtype=tf.float32),
                    "y": tf.io.FixedLenFeature([1], dtype=tf.int64)})).map(dict2tuple)

test_dataset = tf.data.TFRecordDataset(filenames='../mnist_test.tfrecord')
test_dataset = test_dataset.batch(512).apply(
                    tf.data.experimental.parse_example_dataset({
                    "x": tf.io.FixedLenFeature([28, 28], dtype=tf.float32),
                    "y": tf.io.FixedLenFeature([1], dtype=tf.int64)})).map(dict2tuple)

model = make_model()
hist = model.fit(train_dataset, validation_data=test_dataset,
                 epochs=10, verbose=False)

plt.figure(figsize=(4, 4))
plt.plot(hist.history['loss'], label='loss')
plt.plot(hist.history['val_loss'], label='val_loss')
plt.show()

temps de traitement

J'ai mesuré le temps de traitement lorsque les données mnist ont été entraînées avec le même modèle. Après tout, il semble que l'analyse d'un enregistrement à la fois sera assez lente. Si vous traitez par lots, vous pouvez obtenir la même vitesse que tf.data en mémoire, on peut donc dire que c'est assez rapide.

De plus, ici, le résultat est que numpy.array est le plus rapide en l'état, mais en pratique tf.data est évidemment plus rapide, donc numpy.array n'est pas plus rapide s'il est en mémoire. Je pense. Nous vous serions reconnaissants si vous pouviez essayer différentes choses dans votre propre environnement.

Recommended Posts

Remarques sur l'utilisation de tf.data
Comment utiliser xml.etree.ElementTree
Comment utiliser virtualenv
Comment utiliser Seaboan
Comment utiliser la correspondance d'image
Comment utiliser le shogun
Comment utiliser Pandas 2
Comment utiliser Virtualenv
Comment utiliser numpy.vectorize
Comment utiliser pytest_report_header
Comment utiliser partiel
Comment utiliser Bio.Phylo
Comment utiliser SymPy
Comment utiliser x-means
Comment utiliser WikiExtractor.py
Comment utiliser IPython
Comment utiliser virtualenv
Comment utiliser Matplotlib
Comment utiliser iptables
Comment utiliser numpy
Comment utiliser TokyoTechFes2015
Comment utiliser venv
Comment utiliser le dictionnaire {}
Comment utiliser Pyenv
Comment utiliser la liste []
Comment utiliser python-kabusapi
Comment utiliser OptParse
Comment utiliser le retour
Comment utiliser pyenv-virtualenv
Comment utiliser imutils
Comment utiliser Qt Designer
Comment utiliser la recherche triée
[gensim] Comment utiliser Doc2Vec
python3: Comment utiliser la bouteille (2)
Comprendre comment utiliser django-filter
Comment utiliser le générateur
[Python] Comment utiliser la liste 1
Comment utiliser FastAPI ③ OpenAPI
Comment utiliser Python Argparse
Comment utiliser IPython Notebook
Comment utiliser Pandas Rolling
[Note] Comment utiliser virtualenv
Comment utiliser les dictionnaires redis-py
Python: comment utiliser pydub
[Python] Comment utiliser checkio
[Aller] Comment utiliser "... (3 périodes)"
Comment faire fonctionner GeoIp2 de Django
[Python] Comment utiliser input ()
Comment utiliser le décorateur
[Introduction] Comment utiliser open3d
Comment utiliser Python lambda
Comment utiliser Jupyter Notebook
[Python] Comment utiliser virtualenv
python3: Comment utiliser la bouteille (3)
python3: Comment utiliser la bouteille
Comment utiliser Google Colaboratory
Comment utiliser les octets Python
Comment utiliser cron (mémo personnel)