When a colleague was training a model for tens of thousands of epochs on Google Colaboratory with Keras (tf.keras), he complained that the browser became sluggish and the display eventually stopped updating altogether.
The cause was the progress log produced by specifying `verbose=1` in `model.fit`: the log grew so large that once some threshold (?) was reached, the display stopped updating entirely, even though training itself kept running.
Specifying `verbose=0` stops the logging, but then there is no way to check progress.
I remembered running into the same problem in the past and solving it with a callback, so I would like to share the approach.
You can customize behavior during training by passing a class that inherits from `tf.keras.callbacks.Callback` to the `callbacks` argument of `model.fit`.
Check the official documentation for details.
【tf.keras.callbacks.Callback - TensorFlow】
[Callback - Keras Documentation]
`tf.keras.callbacks.Callback` provides several methods, each of which is called at a specific point during training.
By overriding these methods, you can change the training behavior.
This time I overrode the following methods.
Method | When it is called
---|---
on_train_begin | At the start of training
on_train_end | At the end of training
on_batch_begin | At the start of each batch
on_batch_end | At the end of each batch
on_epoch_begin | At the start of each epoch
on_epoch_end | At the end of each epoch
In addition to the above, there are methods that are called during inference and testing.
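As a minimal sketch of the mechanism (the class name `EpochLogger` is my own illustration, not code from this article), overriding one of these methods looks like this:

```python
import tensorflow as tf

# Minimal sketch: override on_epoch_end to run code at the end of every epoch.
# The class name EpochLogger is illustrative only.
class EpochLogger(tf.keras.callbacks.Callback):
    def on_epoch_end(self, epoch, logs=None):
        # epoch is 0-based; logs holds this epoch's metrics
        print('epoch %d finished' % (epoch + 1))

# It would be passed to fit like:
# model.fit(x, y, callbacks=[EpochLogger()])
```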
By overwriting the progress display on the same line instead of printing a new line each time, the output cell never grows and overflows. The following code keeps overwriting the same line.
print('\rTest Print', end='')
The `\r` in the code above is a carriage return (CR), which moves the cursor back to the beginning of the line.
This lets you overwrite the line that is already displayed.
However, left as-is, a line break would still occur every time the print statement executes.
Therefore, `end=''` is specified as an argument to print: `end` is the string output after the first argument, so passing an empty string suppresses the line break. By default, print uses `end='\n'`.
`\n` is a line feed (LF), which moves the cursor to a new line (that is, a newline).
If you run the following code as a test, it keeps overwriting the digits 0 through 9, producing a count-up effect.
Overwrite sample
from time import sleep

for i in range(10):
    print('\r%d' % i, end='')
    sleep(1)
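One caveat the article does not cover: in a plain terminal (as opposed to a notebook cell), stdout is often line-buffered, so a line with no newline may not appear until the loop ends. Passing `flush=True` to print forces each update out immediately:

```python
from time import sleep

for i in range(10):
    # flush=True pushes the partial line to the terminal right away
    print('\r%d' % i, end='', flush=True)
    sleep(0.2)
print()  # final newline so the prompt does not land on the counter line
```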
At this point you might think it would be cleaner to specify `end='\r'` instead of printing `'\r'` first.
However, that approach does not work. When a line ends with `'\r'`, the content output so far seems to be cleared.
For example, running `print('Test Print', end='\r')` displays nothing at all, which defeats the purpose here.
So there is no choice but to output `'\r'` first, immediately followed by the string you actually want to show.
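The difference between the two orderings is easy to see by capturing the raw character stream each style writes (a small experiment of mine, not from the article): printing `'\r'` first means the visible text is the last thing written, while `end='\r'` leaves the carriage return, and therefore the cursor at column 0, after the text:

```python
import io
import contextlib

# Capture exactly what each style writes to stdout.
buf = io.StringIO()
with contextlib.redirect_stdout(buf):
    print('\rTest Print', end='')   # CR first: text is written last
out_cr_first = buf.getvalue()

buf = io.StringIO()
with contextlib.redirect_stdout(buf):
    print('Test Print', end='\r')   # CR last: the stream ends with '\r'
out_cr_last = buf.getvalue()

print(repr(out_cr_first))  # '\rTest Print'
print(repr(out_cr_last))   # 'Test Print\r'
```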
So, using the technique above, the code follows this policy:
- Show the start/end of training and the time of execution. These lines break normally.
- Show the epoch number, the number of processed samples, acc, and loss. This line is overwritten without line breaks to keep the output cell small.
We implement it based on the above policy. The model itself is taken from the TensorFlow tutorial 【TensorFlow 2 quickstart for beginners】.
import datetime
import tensorflow as tf

# Callback class definition for custom progress display
class DisplayCallBack(tf.keras.callbacks.Callback):
    """
    Callback for displaying training progress.
    Data is collected and displayed at the end of each batch and each epoch.
    The point is to return the cursor to the start of the line with '\r'
    and suppress the line break with end='' when printing.
    """
    # Constructor
    def __init__(self):
        self.last_acc, self.last_loss, self.last_val_acc, self.last_val_loss = None, None, None, None
        self.now_batch, self.now_epoch = None, None
        self.epochs, self.samples, self.batch_size = None, None, None

    # Custom progress display (display body)
    def print_progress(self):
        epoch = self.now_epoch
        batch = self.now_batch
        epochs = self.epochs
        samples = self.samples
        batch_size = self.batch_size
        sample = batch_size * batch

        # Avoid line breaks by using '\r' and end=''
        if self.last_val_acc and self.last_val_loss:
            # val_acc / val_loss are available
            print("\rEpoch %d/%d (%d/%d) -- acc: %f loss: %f - val_acc: %f val_loss: %f" % (epoch+1, epochs, sample, samples, self.last_acc, self.last_loss, self.last_val_acc, self.last_val_loss), end='')
        else:
            # val_acc / val_loss are not yet available
            print("\rEpoch %d/%d (%d/%d) -- acc: %f loss: %f" % (epoch+1, epochs, sample, samples, self.last_acc, self.last_loss), end='')

    # At the start of fit
    def on_train_begin(self, logs={}):
        print('\n##### Train Start ##### ' + str(datetime.datetime.now()))
        # Get parameters (note: the keys available in self.params vary between
        # TensorFlow versions; newer versions expose 'steps' instead of 'samples')
        self.epochs = self.params['epochs']
        self.samples = self.params['samples']
        self.batch_size = self.params['batch_size']
        # Suppress the standard progress display
        self.params['verbose'] = 0

    # At the start of each batch
    def on_batch_begin(self, batch, logs={}):
        self.now_batch = batch

    # At the end of each batch (progress display)
    def on_batch_end(self, batch, logs={}):
        # Update the latest values (the metric key may be 'accuracy' rather
        # than 'acc' depending on the TensorFlow version)
        self.last_acc = logs.get('acc') if logs.get('acc') else 0.0
        self.last_loss = logs.get('loss') if logs.get('loss') else 0.0
        # Progress display
        self.print_progress()

    # At the start of each epoch
    def on_epoch_begin(self, epoch, logs={}):
        self.now_epoch = epoch

    # At the end of each epoch (progress display)
    def on_epoch_end(self, epoch, logs={}):
        # Update the latest values
        self.last_val_acc = logs.get('val_acc') if logs.get('val_acc') else 0.0
        self.last_val_loss = logs.get('val_loss') if logs.get('val_loss') else 0.0
        # Progress display
        self.print_progress()

    # When fit is completed
    def on_train_end(self, logs={}):
        print('\n##### Train Complete ##### ' + str(datetime.datetime.now()))


# Instantiate the callback
cbDisplay = DisplayCallBack()

# Load and normalize the MNIST dataset
mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0

# Build a tf.keras Sequential model
model = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(input_shape=(28, 28)),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dropout(0.2),
    tf.keras.layers.Dense(10, activation='softmax')
])

model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

# Train the model
# This is where the callback is used
history = model.fit(x_train, y_train,
                    validation_data=(x_test, y_test),
                    batch_size=128,
                    epochs=5,
                    verbose=1,              # the standard progress display is suppressed inside the callback
                    callbacks=[cbDisplay])  # set the custom progress display as a callback

# Evaluate the model
import pandas as pd
results = pd.DataFrame(history.history)
results.plot();
If you run the above, no matter how many epochs you train, only the following three lines are displayed. The second line is rewritten with the latest information at the end of each batch and each epoch, and the last line is output when training completes.
##### Train Start ##### 2019-12-24 02:17:27.484038
Epoch 5/5 (59904/60000) -- acc: 0.970283 loss: 0.066101 - val_acc: 0.973900 val_loss: 0.087803
##### Train Complete ##### 2019-12-24 02:17:34.443442