Following the TensorFlow guide "Better performance with the tf.data API", I designed a data loader for a CNN and tuned its speed in TensorFlow. In conclusion, I tried several acceleration techniques, but unfortunately none of them made it faster than the baseline implementation.
tf.data
TensorFlow provides an API for building input pipelines called tf.data. When you feed a model data that does not fit in RAM, such as image files, tf.data can process it quickly because data preprocessing and neural-network training run in parallel internally. The rough mechanism is as follows.
If you implement the loader with a Python generator or similar, it is inefficient: while the CPU is preparing data the GPU sits idle, and vice versa. With tf.data, this idle time can be shortened.
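The overlap can be seen in a toy sketch that uses sleeps instead of real work. It assumes TensorFlow 2.4 or later (for output_signature and tf.data.AUTOTUNE). slow_reader and run are hypothetical names, and the sleeps merely simulate loading and training. A Python generator is wrapped in tf.data, and prefetch lets a background thread prepare the next element while the loop is still busy with the current one.

```python
import time

import tensorflow as tf

def slow_reader(n, delay):
    """Hypothetical stand-in for loading and preprocessing one sample at a time."""
    for i in range(n):
        time.sleep(delay)  # simulated file read + decode
        yield i

def run(use_prefetch, n=10, delay=0.05):
    """Consume n elements with `delay` seconds of simulated training each.

    Returns the elements seen and the elapsed wall-clock time in seconds.
    """
    ds = tf.data.Dataset.from_generator(
        lambda: slow_reader(n, delay),
        output_signature=tf.TensorSpec(shape=(), dtype=tf.int32))
    if use_prefetch:
        # prefetch fills a buffer from a background thread, so producing
        # the next element overlaps with the "training" below.
        ds = ds.prefetch(tf.data.AUTOTUNE)
    start = time.perf_counter()
    seen = []
    for x in ds:
        time.sleep(delay)  # simulated training step
        seen.append(int(x))
    return seen, time.perf_counter() - start
```

Compare the elapsed times from run(False) and run(True). With prefetch, the total should come out noticeably shorter, because loading and training no longer strictly alternate. The exact numbers depend on the machine and the TensorFlow version.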
How to implement this is explained on that page, and it takes relatively little effort.
In this experiment, ImageNet images stored as JPEG files are fed into MobileNet. The experimental environment is Google Colaboratory.
Introduction
Basic usage of tf.data: read the paths of the saved images one by one and randomly crop each image to 244x244.
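import glob
import os
import time

import tensorflow as tf
import tensorflow_hub as hub

# IMAGE_DIR, IMAGE_SIZE (height, width, channels), batch_size, epochs and
# steps_per_epoch are set elsewhere; their values are not shown in this article.
AUTOTUNE = tf.data.experimental.AUTOTUNE

train_img_paths = glob.glob(os.path.join(IMAGE_DIR, '*.jpg'))
train_img_paths.sort()
num_train_imgs = len(train_img_paths)
train_label = [1 for path in train_img_paths]  # every image gets label 1; only loading speed matters here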
m = tf.keras.Sequential([
    hub.KerasLayer("https://tfhub.dev/google/tf2-preview/mobilenet_v2/classification/4", output_shape=[1], trainable=True)
])
m.build([None, IMAGE_SIZE[0], IMAGE_SIZE[1], IMAGE_SIZE[2]])
m.compile(loss=tf.keras.losses.BinaryCrossentropy(), optimizer='Adam')
def preprocessing(img_path, label):
    # channels=3 makes grayscale JPEGs decode to 3 channels, matching IMAGE_SIZE
    img = tf.io.decode_jpeg(tf.io.read_file(img_path), channels=3)
    img = tf.image.random_crop(img, size=IMAGE_SIZE)
    img = tf.cast(img, tf.float32)
    img = img / 255.0  # scale pixel values to [0, 1]
    label = tf.cast(label, tf.float32)
    img.set_shape(IMAGE_SIZE)
    return img, label
train_data = tf.data.Dataset.from_tensor_slices((train_img_paths, train_label))
train_data = train_data.shuffle(num_train_imgs).map(preprocessing).repeat().batch(batch_size).prefetch(buffer_size=AUTOTUNE)
time_start = time.time()
m.fit(train_data, epochs=epochs, steps_per_epoch=steps_per_epoch)
time_end = time.time()
print(f'Total time:{(time_end-time_start)/60.0:.3f}[min]')
print(f'Time per step:{(time_end-time_start)/steps_per_epoch*epochs:.3f} [sec]')
Total time:0.446[min]
Time per step:0.803 [sec]
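By this measure, the baseline took about 0.8 seconds per step. Note that operator precedence makes (time_end-time_start)/steps_per_epoch*epochs evaluate to (total time / steps_per_epoch) × epochs, not total time / (steps_per_epoch × epochs). The "Time per step" values in this article are therefore not true per-step times. Because every run uses the same expression, they can still be compared with each other. From here, I will try ways to speed up training.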
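Instead of computing the per-step time by hand, you can record each step directly with a small Keras callback. This is a sketch assuming TensorFlow 2.x Keras. StepTimer is a hypothetical helper, not part of Keras.

```python
import statistics
import time

import tensorflow as tf

class StepTimer(tf.keras.callbacks.Callback):
    """Hypothetical helper: records the wall-clock duration of every training step."""

    def on_train_begin(self, logs=None):
        self.step_times = []

    def on_train_batch_begin(self, batch, logs=None):
        self._start = time.perf_counter()

    def on_train_batch_end(self, batch, logs=None):
        # In Keras' fit loop the batch is fetched from the dataset inside the
        # step, so this interval also includes waiting on the input pipeline.
        self.step_times.append(time.perf_counter() - self._start)

    def median_step_time(self):
        # The median damps the slow first step (graph tracing, pipeline warm-up).
        return statistics.median(self.step_times)
```

For example, with timer = StepTimer(), calling m.fit(train_data, epochs=epochs, steps_per_epoch=steps_per_epoch, callbacks=[timer]) and then timer.median_step_time() would give a per-step time in seconds.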
Next, run the Dataset's map function in parallel by passing num_parallel_calls. This should be faster because several images are then preprocessed concurrently on multiple threads.
Take this part of the code from the previous section:
train_data = tf.data.Dataset.from_tensor_slices((train_img_paths, train_label))
train_data = train_data.shuffle(num_train_imgs).map(preprocessing).repeat().batch(batch_size).prefetch(buffer_size=AUTOTUNE)
and rewrite it as follows:
train_data = tf.data.Dataset.from_tensor_slices((train_img_paths, train_label))
train_data = train_data.shuffle(num_train_imgs).repeat().map(preprocessing, num_parallel_calls=AUTOTUNE).batch(batch_size).prefetch(buffer_size=AUTOTUNE)
Total time:3.726[min]
Time per step:6.707 [sec]
For some reason, it became slower. Is this due to how Google Colaboratory works? (Needs investigation.)
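One way to narrow this down is to time the input pipeline alone, without the model, in the spirit of the benchmarks in the tf.data guide. That shows whether the slowdown comes from the pipeline itself or from its interaction with training. The sketch below uses a hypothetical CPU-heavy function on synthetic data instead of the real JPEG pipeline. benchmark and heavy_preprocess are illustrative names, and TensorFlow 2.4+ is assumed for tf.data.AUTOTUNE.

```python
import time

import tensorflow as tf

def benchmark(dataset, num_steps):
    """Time the input pipeline alone (no model) and return seconds per element."""
    iterator = iter(dataset)
    next(iterator)  # warm-up: exclude one-time start-up cost from the measurement
    start = time.perf_counter()
    for _ in range(num_steps):
        next(iterator)
    return (time.perf_counter() - start) / num_steps

def heavy_preprocess(x):
    """Hypothetical CPU-bound stand-in for decoding and cropping one image."""
    m = tf.fill([200, 200], tf.cast(x, tf.float32))
    return tf.reduce_sum(tf.linalg.matmul(m, m))

source = tf.data.Dataset.range(1000).repeat()
serial = source.map(heavy_preprocess)
parallel = source.map(heavy_preprocess, num_parallel_calls=tf.data.AUTOTUNE)

if __name__ == "__main__":
    print(f"serial:   {benchmark(serial, 200) * 1e3:.2f} ms/element")
    print(f"parallel: {benchmark(parallel, 200) * 1e3:.2f} ms/element")
```

Passing the two real train_data pipelines to benchmark instead (each element is then one batch) would show whether the parallel map is slower on Colab even when no training is involved.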
Caching keeps data that has already been read in RAM (or in a file on storage, if a filename is passed to cache()) so that it does not have to be read and processed again.
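Here is a minimal sketch of what cache() does, assuming TensorFlow 2.x. The cache is filled during the first complete pass over the dataset and reused on later passes, so the mapped function only runs during the first pass. expensive_load and the calls list are hypothetical. They exist only to count real invocations through tf.py_function.

```python
import tensorflow as tf

calls = []  # hypothetical counter: one entry per real call of the expensive step

def expensive_load(x):
    calls.append(int(x))  # runs only when TensorFlow actually executes the step
    return x * 2

def load(x):
    y = tf.py_function(expensive_load, [x], tf.int64)
    y.set_shape([])  # py_function loses static shape information
    return y

# In-memory cache; cache("/tmp/train_cache") would store it in files instead.
ds = tf.data.Dataset.range(3).map(load).cache()

for epoch in range(2):
    values = [int(v) for v in ds.as_numpy_iterator()]
    print(f"epoch {epoch}: values={values}, calls so far={calls}")
```

After the second epoch, calls still holds only the three invocations from the first epoch, because the second pass is served from the cache.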
train_data = tf.data.Dataset.from_tensor_slices((train_img_paths, train_label))
train_data = train_data.shuffle(num_train_imgs).repeat().map(preprocessing, num_parallel_calls=AUTOTUNE).batch(batch_size).cache()
Total time:7.014[min]
Time per step:12.625 [sec]
Once again, I couldn't speed it up. I think the cause is that the map function reads an image and converts the image data in a single step. The pipeline needs a structure that separates reading images from converting them. (Future work)
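Here is a hedged sketch of such a separated structure, assuming TensorFlow 2.4+, that the decoded images fit in RAM (otherwise pass a filename to cache()), and that every image is at least as large as the crop. read_and_decode, augment and build_dataset are hypothetical names. The deterministic, expensive step (file read and JPEG decode) is cached. The random crop runs after the cache, so each epoch still gets new crops. Compared with the caching experiment above, two differences may matter. There, cache() came after repeat(), so the cached dataset never ends and the cache is never completed or read back. The prefetch step had also been dropped.

```python
import tensorflow as tf

AUTOTUNE = tf.data.AUTOTUNE

def read_and_decode(img_path, label):
    """Deterministic, expensive part: file read + JPEG decode (safe to cache)."""
    img = tf.io.decode_jpeg(tf.io.read_file(img_path), channels=3)
    return img, label

def augment(img, label, crop_size):
    """Random part: runs after the cache so every epoch sees fresh crops."""
    img = tf.image.random_crop(img, size=crop_size)
    img = tf.cast(img, tf.float32) / 255.0
    return img, tf.cast(label, tf.float32)

def build_dataset(paths, labels, crop_size, batch_size):
    ds = tf.data.Dataset.from_tensor_slices((paths, labels))
    ds = ds.map(read_and_decode, num_parallel_calls=AUTOTUNE)
    ds = ds.cache()              # filled on the first full pass, reused afterwards
    ds = ds.shuffle(len(paths))  # shuffle after cache so the order changes every epoch
    ds = ds.repeat()             # repeat after cache so the cache can be completed
    ds = ds.map(lambda x, y: augment(x, y, crop_size), num_parallel_calls=AUTOTUNE)
    ds = ds.batch(batch_size)
    return ds.prefetch(AUTOTUNE)
```

With the article's variables, this might be called as build_dataset(train_img_paths, train_label, IMAGE_SIZE, batch_size).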
Calling the user-defined function passed to map has scheduling and execution overhead for every element. Vectorizing the user-defined function, that is, having it process many inputs in a single call, should therefore make it faster. Specifically, the guide recommends batching first and then converting the data, instead of converting the data and then batching.
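Here is a hedged sketch of "batch first, then convert" for this pipeline, assuming TensorFlow 2.4+ and images at least as large as the crop. The JPEG files have different sizes, so reading, decoding and cropping still run per element. Only the cast and normalization, which the original preprocessing function did per image, move after batch(), so a single call handles a whole batch. read_and_crop, normalize_batch and build_vectorized are hypothetical names.

```python
import tensorflow as tf

AUTOTUNE = tf.data.AUTOTUNE

def read_and_crop(img_path, label, crop_size):
    """Per-element part: each file has its own size, so decode and crop one at a time."""
    img = tf.io.decode_jpeg(tf.io.read_file(img_path), channels=3)
    img = tf.image.random_crop(img, size=crop_size)  # still uint8
    return img, label

def normalize_batch(imgs, labels):
    """Vectorized part: one call converts a whole batch of images."""
    return tf.cast(imgs, tf.float32) / 255.0, tf.cast(labels, tf.float32)

def build_vectorized(paths, labels, crop_size, batch_size):
    ds = tf.data.Dataset.from_tensor_slices((paths, labels))
    ds = ds.shuffle(len(paths)).repeat()
    ds = ds.map(lambda p, y: read_and_crop(p, y, crop_size), num_parallel_calls=AUTOTUNE)
    ds = ds.batch(batch_size)                                  # batch first ...
    ds = ds.map(normalize_batch, num_parallel_calls=AUTOTUNE)  # ... then convert
    return ds.prefetch(AUTOTUNE)
```

Keeping the cropped images as uint8 until after batching also means the batch is assembled from smaller tensors than float32 images would be.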
I haven't tried this yet due to time constraints, but in the experiment in the guide linked at the beginning of this article, it was up to 30 times faster.
I experimented with parallelizing the map function and with caching, but neither led to a speedup.
There are probably several causes, so further investigation is needed. If you have any advice, I would be grateful if you could let me know.