
Save and load models in Tensorflow

Last Updated : 08 Apr, 2025

Training a machine learning or deep learning model is time-consuming, and shutting down the notebook discards the trained weights because the memory is cleared. We therefore save models for reuse, collaboration and continued training.

  • Saving a model lets you reuse it later without repeating lengthy training.
  • It also lets you share the model with others so they can replicate your results.
  • When sharing machine learning models it’s common to include:
    • the code to create the model
    • the trained weights for the model

Below are the methods for saving and loading machine learning models in TensorFlow.

Methods to Save and Load Models

Each method for saving a model has a matching function for loading it back.

1. Using the save() Method

The save() method allows you to save the complete model including:

  • Model architecture
  • Model weights
  • Optimizer state, so you can resume training from where you left off

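model.save('location/model_name.keras')

Here model can be a Sequential model, a Functional model or a Model subclass. The path specifies where the model is stored; a relative path (or a bare file name) is resolved against the current working directory. In Keras 3 (the default from TensorFlow 2.16), the path must end in .keras (the native format) or .h5 (the legacy HDF5 format); older TensorFlow releases saved a SavedModel directory when no extension was given.

To load the model, use the load_model() function:

tensorflow.keras.models.load_model('location/model_name.keras')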
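The round trip above can be sketched end to end as follows. This is a minimal example assuming TensorFlow 2.13 or later; the toy data, the layer sizes and the file name model_name.keras are hypothetical and only serve to show that load_model() returns a compiled model that predicts like the original and can keep training.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf

# Hypothetical toy data: 64 samples with 4 features and a binary label
rng = np.random.default_rng(0)
x = rng.random((64, 4)).astype("float32")
y = (x.sum(axis=1) > 2.0).astype("float32").reshape(-1, 1)

# A small Sequential model, compiled and trained so an optimizer state exists
model = tf.keras.Sequential([
    tf.keras.Input(shape=(4,)),
    tf.keras.layers.Dense(8, activation="relu"),
    tf.keras.layers.Dense(1, activation="sigmoid"),
])
model.compile(optimizer="adam", loss="binary_crossentropy")
model.fit(x, y, epochs=2, verbose=0)

# Save architecture, weights and optimizer state in the native .keras format
path = os.path.join(tempfile.mkdtemp(), "model_name.keras")
model.save(path)

# Restore the complete model; it does not need to be rebuilt or recompiled
restored = tf.keras.models.load_model(path)

# The restored model produces the same predictions as the original
same = np.allclose(model.predict(x, verbose=0), restored.predict(x, verbose=0))
print("Predictions match:", same)

# Training can continue directly on the restored model
restored.fit(x, y, epochs=1, verbose=0)
```

Because the optimizer is restored together with the weights, calling fit() on the loaded model continues training rather than starting over.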

2. Using the save_weights() Method

In some cases you might want to save just the weights of the model instead of the entire model. This can be done using the save_weights() method, which saves the weights of all the layers in the model.

model.save_weights('location/weights_name.weights.h5')

Here weights_name is the file name for the saved weights. In Keras 3 the file name must end in .weights.h5; older TensorFlow releases also accepted a name without an extension and stored it in the TensorFlow checkpoint format. As with save(), a relative path is resolved against the current working directory.

To load the saved weights into a model with the same architecture, use the load_weights() method:

model.load_weights('location/weights_name.weights.h5')

Note: When loading weights, ensure that the model's architecture is the same as the one used to save them. For example, you cannot load the weights of a model with two dense layers into a model with just one dense layer.
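A minimal sketch of the weights-only workflow, assuming TensorFlow 2.13 or later. The helper build_model() and the file name weights_name.weights.h5 are hypothetical; the helper plays the role of "the code to create the model", so the architecture is rebuilt from code and only the weights come from the file.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf


def build_model():
    """Hypothetical helper: the architecture must be identical at save and load time."""
    return tf.keras.Sequential([
        tf.keras.Input(shape=(4,)),
        tf.keras.layers.Dense(8, activation="relu"),
        tf.keras.layers.Dense(1),
    ])


original = build_model()
x = np.ones((3, 4), dtype="float32")

# Keras 3 requires weight files to end in ".weights.h5"
weights_path = os.path.join(tempfile.mkdtemp(), "weights_name.weights.h5")
original.save_weights(weights_path)

# Rebuild the same architecture from code, then load only the weights into it
clone = build_model()
clone.load_weights(weights_path)

# Both models now compute the same outputs
match = np.allclose(original(x).numpy(), clone(x).numpy())
print("Outputs match:", match)
```

The weight file contains no architecture or optimizer information, which is why the same building code has to be available wherever the weights are loaded.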

3. HDF5 Format (.h5)

If you save your model with the .h5 extension, it is stored in HDF5 format, a portable file format commonly used for storing large numerical data and models. Keras chooses the format from the extension, so no extra argument is needed:

model.save('my_model.h5')

In Keras 3 (the default from TensorFlow 2.16), HDF5 is treated as a legacy format and the recommended native format uses the .keras extension; a path with neither extension raises an error. In older TensorFlow releases, omitting the extension saved the model in the TensorFlow SavedModel format.

Code to Save and Load Models

Here we will build a convolutional neural network, save it to a file and load it back.

1. Import Necessary Module

We will import TensorFlow and the Keras components used in this example:

  • Input, Conv2D, Dense, Flatten, Dropout, MaxPooling2D and BatchNormalization are the layers used to build the neural network.
  • Model defines the model architecture.
  • load_model loads a saved model.
Python
import tensorflow as tf
from tensorflow.keras.layers import Input, Conv2D, Dense, Flatten, Dropout, MaxPooling2D, BatchNormalization
from tensorflow.keras.models import Model, load_model

2. Load and Preprocess Data

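  • We use the CIFAR-10 dataset, which contains 60,000 color images of size 32x32 in 10 classes: 50,000 images for training and 10,000 for testing.
  • Pixel values are scaled from the range 0 to 255 down to 0 to 1, and the label arrays are flattened from shape (N, 1) to (N,).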
Python
cifar10 = tf.keras.datasets.cifar10
(x_train, y_train), (x_test, y_test) = cifar10.load_data()

x_train, x_test = x_train / 255.0, x_test / 255.0

y_train, y_test = y_train.flatten(), y_test.flatten()
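Flattening the labels matters for the next step, where the number of classes is computed as len(set(y_train)). The sketch below uses small hypothetical stand-in arrays with CIFAR-10's layout (not the real dataset) to show both preprocessing steps: each row of an (N, 1) label array is itself a NumPy array, which is unhashable, so set() only works after flatten().

```python
import numpy as np

# Hypothetical stand-in arrays with CIFAR-10's layout (not the real data):
# uint8 images of shape (N, 32, 32, 3) and labels of shape (N, 1)
x_fake = np.random.randint(0, 256, size=(5, 32, 32, 3), dtype=np.uint8)
y_fake = np.array([[3], [8], [8], [0], [6]], dtype=np.uint8)

# Scaling turns the uint8 pixels into floats between 0 and 1
x_scaled = x_fake / 255.0
print(x_scaled.dtype, x_scaled.min(), x_scaled.max())

# Each row of an (N, 1) array is a 1-element ndarray, which set() cannot hash
try:
    set(y_fake)
    hashable = True
except TypeError:
    hashable = False
print("set() works on (N, 1) labels:", hashable)

# After flattening, the labels are scalars and set() counts the distinct classes
y_flat = y_fake.flatten()
K = len(set(y_flat))
print(y_fake.shape, "->", y_flat.shape, "| distinct classes:", K)
```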

3. Defining the Model

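The model is built with the Keras Functional API and contains the following layers:

  • three convolutional blocks with 32, 64 and 128 filters respectively, each made of two Conv2D + BatchNormalization pairs followed by a MaxPooling2D layer
  • a Flatten layer and a Dropout layer
  • a Dense layer with 1024 units and a second Dropout layer
  • a Dense output layer with K units and softmax activation, where K is the number of classes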

Python
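# Number of classes (10 for CIFAR-10)
K = len(set(y_train))

# Input layer sized to one image: (32, 32, 3)
i = Input(shape=x_train[0].shape)

# Block 1: two 32-filter convolutions with batch normalization, then 2x2 max pooling
x = Conv2D(32, (3, 3), activation='relu', padding='same')(i)
x = BatchNormalization()(x)
x = Conv2D(32, (3, 3), activation='relu', padding='same')(x)
x = BatchNormalization()(x)
x = MaxPooling2D((2, 2))(x)

# Block 2: same pattern with 64 filters
x = Conv2D(64, (3, 3), activation='relu', padding='same')(x)
x = BatchNormalization()(x)
x = Conv2D(64, (3, 3), activation='relu', padding='same')(x)
x = BatchNormalization()(x)
x = MaxPooling2D((2, 2))(x)

# Block 3: same pattern with 128 filters
x = Conv2D(128, (3, 3), activation='relu', padding='same')(x)
x = BatchNormalization()(x)
x = Conv2D(128, (3, 3), activation='relu', padding='same')(x)
x = BatchNormalization()(x)
x = MaxPooling2D((2, 2))(x)

# Classifier head: flatten the feature maps and regularize with dropout
x = Flatten()(x)
x = Dropout(0.2)(x)

x = Dense(1024, activation='relu')(x)
x = Dropout(0.2)(x)

# Output layer: one softmax probability per class
x = Dense(K, activation='softmax')(x)

# Assemble the Functional model and print its layers and parameter counts
model = Model(i, x)
model.summary()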

Output:

[Image: Model Summary]

4. Saving and Loading the Model

Python
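# Save the complete model in HDF5 format
model.save('my_model.h5')
print("Model saved!")

# load_model() raises an exception on failure instead of returning None
try:
    saved_model = load_model('my_model.h5')
    print("Model loaded successfully!")
except (OSError, ValueError) as e:
    print("Failed to load the model:", e)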

Output:

Model saved!
Model loaded successfully!
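Printing a message only shows that loading did not fail. A stronger check is to compare every weight array of the original and the reloaded model. The sketch below does this on a scaled-down, hypothetical version of the Functional model above (input size, filter count and file name are assumptions), assuming TensorFlow 2.13 or later. It passes compile=False because this model was never compiled, so there is no training configuration to restore.

```python
import os
import tempfile

import numpy as np
from tensorflow.keras.layers import Input, Conv2D, BatchNormalization, Flatten, Dense
from tensorflow.keras.models import Model, load_model

# A scaled-down version of the article's Functional model (hypothetical sizes)
i = Input(shape=(8, 8, 3))
x = Conv2D(4, (3, 3), activation='relu', padding='same')(i)
x = BatchNormalization()(x)
x = Flatten()(x)
x = Dense(10, activation='softmax')(x)
model = Model(i, x)

path = os.path.join(tempfile.mkdtemp(), 'small_model.h5')
model.save(path)

# compile=False: the model was never compiled, so skip restoring a training config
restored = load_model(path, compile=False)

# Compare every weight array, including BatchNormalization's moving mean and variance
original_weights = model.get_weights()
restored_weights = restored.get_weights()
all_equal = len(original_weights) == len(restored_weights) and all(
    np.array_equal(a, b) for a, b in zip(original_weights, restored_weights)
)
print("Weights identical:", all_equal)
```

The comparison covers non-trainable values such as the BatchNormalization moving statistics, which the saved file must also contain for inference to behave the same way.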

Saving and loading models is essential for efficient machine learning workflows, enabling you to resume training without starting from scratch and share models with others.

