ML Zoomcamp 2023 – Deep Learning – Part 9

  1. Checkpointing
    1. Saving the best model only
    2. Training a model with callbacks

Checkpointing

Checkpointing is a way of saving our model after each iteration, or only when certain conditions are met, e.g. when the model achieves its best performance so far. This is useful because validation performance often oscillates during training, so the model after 10 epochs is not necessarily the best one.

Saving the best model only

How can we do this? After each training epoch we evaluate the model on the validation dataset. Based on these numbers we can invoke a callback: a piece of code that Keras runs after each epoch finishes, in which we can do anything we want, such as saving the model. The evaluation on validation data is itself a kind of callback, and so is the history object that collects all the training metrics.
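To make the mechanism concrete, here is a minimal sketch (plain Python, no Keras dependency) of what a callback does: the training loop hands it each epoch's metrics via an `on_epoch_end(epoch, logs)` method, which mirrors the method name Keras callbacks use. The class name `BestModelTracker` and the metric values are illustrative, not from the original post.

```python
# Sketch of the callback mechanism: the training loop calls
# on_epoch_end(epoch, logs) after every epoch with that epoch's metrics.
class BestModelTracker:
    def __init__(self):
        self.best_val_accuracy = 0.0
        self.best_epoch = None

    def on_epoch_end(self, epoch, logs):
        # Keep track of the epoch with the highest validation accuracy.
        val_acc = logs['val_accuracy']
        if val_acc > self.best_val_accuracy:
            self.best_val_accuracy = val_acc
            self.best_epoch = epoch

# Simulated training loop feeding made-up per-epoch metrics to the callback.
tracker = BestModelTracker()
for epoch, val_acc in enumerate([0.79, 0.82, 0.81, 0.84, 0.83], start=1):
    tracker.on_epoch_end(epoch, {'val_accuracy': val_acc})

print(tracker.best_epoch, tracker.best_val_accuracy)  # 4 0.84
```

Keras's built-in `ModelCheckpoint`, used below, is essentially this idea plus saving the model to disk at each improvement.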

Training a model with callbacks

# Save the current weights manually, in HDF5 format.
model.save_weights('model_v1.h5', save_format='h5')

# Keras uses this template for saving files.
'xception_v1_{epoch:02d}_{val_accuracy:.3f}.h5'.format(epoch=3, val_accuracy=0.84)

# Output: 'xception_v1_03_0.840.h5'

Setting save_best_only=True saves a checkpoint only when the monitored metric improves on the best result so far. We use mode='max' because we monitor accuracy, which we want to maximize; if we monitored a loss value instead, we would use mode='min'.

checkpoint = keras.callbacks.ModelCheckpoint(
    'xception_v1_{epoch:02d}_{val_accuracy:.3f}.h5',
    save_best_only=True,
    monitor='val_accuracy',
    mode='max'
)

Now we can pass this callback to model.fit() and train the model again; a checkpoint is written every time the validation accuracy improves.

learning_rate = 0.001

model = make_model(learning_rate=learning_rate)

history = model.fit(
    train_ds,
    epochs=10,
    validation_data=val_ds,
    callbacks=[checkpoint]
)

# Output:
# Epoch 1/10
# 96/96 [==============================] - ETA: 0s - loss: 1.1087 - accuracy: 0.6222
# 96/96 [==============================] - 130s 1s/step - loss: 1.1087 - accuracy: 0.6222 - val_loss: 0.7088 - val_accuracy: 0.7947
# Epoch 2/10
# 96/96 [==============================] - 129s 1s/step - loss: 0.6355 - accuracy: 0.7816 - val_loss: 0.6039 - val_accuracy: 0.8240
# Epoch 3/10
# 96/96 [==============================] - 118s 1s/step - loss: 0.5099 - accuracy: 0.8279 - val_loss: 0.5655 - val_accuracy: 0.8152
# Epoch 4/10
# 96/96 [==============================] - 129s 1s/step - loss: 0.4284 - accuracy: 0.8647 - val_loss: 0.5607 - val_accuracy: 0.8094
# Epoch 5/10
# 96/96 [==============================] - 134s 1s/step - loss: 0.3777 - accuracy: 0.8814 - val_loss: 0.5315 - val_accuracy: 0.8270
# Epoch 6/10
# 96/96 [==============================] - 122s 1s/step - loss: 0.3263 - accuracy: 0.9055 - val_loss: 0.5289 - val_accuracy: 0.8328
# Epoch 7/10
# 96/96 [==============================] - 119s 1s/step - loss: 0.2888 - accuracy: 0.9201 - val_loss: 0.5177 - val_accuracy: 0.8328
# Epoch 8/10
# 96/96 [==============================] - 118s 1s/step - loss: 0.2590 - accuracy: 0.9345 - val_loss: 0.5317 - val_accuracy: 0.8299
# Epoch 9/10
# 96/96 [==============================] - 112s 1s/step - loss: 0.2324 - accuracy: 0.9420 - val_loss: 0.5248 - val_accuracy: 0.8387
# Epoch 10/10
# 96/96 [==============================] - 111s 1s/step - loss: 0.2121 - accuracy: 0.9527 - val_loss: 0.5264 - val_accuracy: 0.8270
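With save_best_only=True, the run above would leave one file per improvement in val_accuracy (epochs 1, 2, 5, 6, and 9; epoch 7 ties epoch 6 and is not saved). A small sketch of picking the best checkpoint from those filenames, where the file list is reconstructed from the log above and the helper `val_accuracy_of` is ours, not part of Keras:

```python
import re

# Checkpoint files the run above would leave on disk, one per improvement
# in val_accuracy (names follow the template 'xception_v1_{epoch}_{val_accuracy}.h5').
files = [
    'xception_v1_01_0.795.h5',
    'xception_v1_02_0.824.h5',
    'xception_v1_05_0.827.h5',
    'xception_v1_06_0.833.h5',
    'xception_v1_09_0.839.h5',
]

def val_accuracy_of(name):
    # The template encodes val_accuracy as the last number before '.h5'.
    return float(re.search(r'_(\d\.\d+)\.h5$', name).group(1))

best = max(files, key=val_accuracy_of)
print(best)  # xception_v1_09_0.839.h5
```

The best checkpoint can then be restored with keras.models.load_model(best).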
