Save and load models in TensorFlow
Training a machine learning or deep learning model is time-consuming, and shutting down the notebook causes all the weights and activations to disappear as the memory is flushed. Hence, we save models for reusability, collaboration, and continuation of training.
- Saving the model lets you avoid repeating lengthy training runs by reusing the already trained weights.
- It also allows you to share the model with others so they can replicate your results.
- When sharing machine learning models it’s common to include the following:
- code to create the model
- trained weights for the model
Below are the methods for saving and loading machine learning models in TensorFlow.

Methods to Save and Load Models
Here are the methods that can be used to save a model.
1. Using the save() Method
The save() method allows you to save the complete model, including:
- Model architecture
- Model weights
- Model optimizer state to resume training from where you left off.
tensorflow.keras.X.save('location/model_name')
Where X can be a Sequential model, a Functional model or a Model subclass. The location specifies where the model is stored; if no path is specified, the model is saved in the same location as the Python file.
To load the model, use the load_model() method:
tensorflow.keras.models.load_model('location/model_name')
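For example, here is a minimal sketch of the full save/load workflow, using a small hypothetical Sequential model built only for illustration:
Python
import tensorflow as tf

# A tiny illustrative model (hypothetical, only to demonstrate save/load)
model = tf.keras.Sequential([
    tf.keras.Input(shape=(4,)),
    tf.keras.layers.Dense(16, activation='relu'),
    tf.keras.layers.Dense(3, activation='softmax')
])
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy')

# Save the complete model: architecture, weights and optimizer state
model.save('small_model.h5')

# Later, restore it without rebuilding the architecture
restored = tf.keras.models.load_model('small_model.h5')
restored.summary()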
2. Using the save_weights() Method
In some cases you might want to save just the weights of the model instead of the entire model. This can be done using the save_weights() method, which saves the weights of all the layers in the model.
tensorflow.keras.Model.save_weights('location/weights_name')
The weights_name is the file name for the saved weights; if no path is provided, it is saved in the same location as the Python file.
To load the saved weights, use the load_weights() method:
tensorflow.keras.Model.load_weights('location/weights_name')
Note: When loading weights, ensure that the model's architecture is the same as the one used to save the weights. For example, you cannot load the weights of a model with two dense layers into a model with just one dense layer.
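A minimal sketch of the weights-only workflow, assuming a hypothetical build_model() helper so the loading side can recreate the identical architecture (the model.weights.h5 file name is just an example):
Python
import tensorflow as tf

def build_model():
    # The exact same architecture must exist before the weights can be loaded
    return tf.keras.Sequential([
        tf.keras.Input(shape=(4,)),
        tf.keras.layers.Dense(16, activation='relu'),
        tf.keras.layers.Dense(3, activation='softmax')
    ])

model = build_model()
model.save_weights('model.weights.h5')   # saves only the layer weights

# Elsewhere: rebuild the identical architecture, then load the weights into it
new_model = build_model()
new_model.load_weights('model.weights.h5')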
3. HDF5 Format (.h5)
If you save your model with the .h5 extension, the model is saved in HDF5 format. This format is portable and commonly used for storing large data and models. You can specify the .h5 extension when saving the model and TensorFlow will automatically save the model in this format.
model.save('my_model.h5')
If you don’t specify an extension, TensorFlow saves the model in its native format (a SavedModel directory in most TensorFlow 2.x releases).
Code to Save and Load Models
Here we will build a neural network and then save it.
1. Import Necessary Modules
We will import TensorFlow for building the model. Input, Conv2D, Dense, Flatten, Dropout, MaxPooling2D and BatchNormalization are imported to build the neural network, Model to define the model architecture and load_model to load saved models.
Python
import tensorflow as tf
from tensorflow.keras.layers import Input, Conv2D, Dense, Flatten, Dropout, MaxPooling2D, BatchNormalization
from tensorflow.keras.models import Model, load_model
2. Load and Preprocess Data
- Here we will use the CIFAR-10 dataset, which contains 60,000 32x32 images in 10 classes; 50,000 images are used for training and 10,000 for testing.
- The pixel values are scaled to the range [0, 1] and the label arrays are flattened to 1-D.
Python
cifar10 = tf.keras.datasets.cifar10
(x_train, y_train), (x_test, y_test) = cifar10.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0
y_train, y_test = y_train.flatten(), y_test.flatten()
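As a quick sanity check (not part of the original walkthrough), printing the array shapes confirms the train/test split and image size described above:
Python
print(x_train.shape, y_train.shape)  # (50000, 32, 32, 3) (50000,)
print(x_test.shape, y_test.shape)    # (10000, 32, 32, 3) (10000,)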
3. Defining the Model
The model stacks three convolutional blocks (Conv2D and BatchNormalization layers followed by MaxPooling2D), then a Flatten layer, Dropout and Dense layers for classification:
Python
K = len(set(y_train))  # number of output classes (10 for CIFAR-10)

i = Input(shape=x_train[0].shape)  # input layer matching the 32x32x3 images

# First convolutional block
x = Conv2D(32, (3, 3), activation='relu', padding='same')(i)
x = BatchNormalization()(x)
x = Conv2D(32, (3, 3), activation='relu', padding='same')(x)
x = BatchNormalization()(x)
x = MaxPooling2D((2, 2))(x)

# Second convolutional block
x = Conv2D(64, (3, 3), activation='relu', padding='same')(x)
x = BatchNormalization()(x)
x = Conv2D(64, (3, 3), activation='relu', padding='same')(x)
x = BatchNormalization()(x)
x = MaxPooling2D((2, 2))(x)

# Third convolutional block
x = Conv2D(128, (3, 3), activation='relu', padding='same')(x)
x = BatchNormalization()(x)
x = Conv2D(128, (3, 3), activation='relu', padding='same')(x)
x = BatchNormalization()(x)
x = MaxPooling2D((2, 2))(x)

# Classification head
x = Flatten()(x)
x = Dropout(0.2)(x)
x = Dense(1024, activation='relu')(x)
x = Dropout(0.2)(x)
x = Dense(K, activation='softmax')(x)

model = Model(i, x)
model.summary()
Output:
Model Summary
4. Saving and Loading the Model
Python
model.save('my_model.h5')
print("Model saved!")

saved_model = load_model('my_model.h5')
if saved_model is not None:
    print("Model loaded successfully!")
else:
    print("Failed to load the model.")
Output:
Model saved!
Model loaded successfully!
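To confirm the reloaded model is usable, you can run a quick prediction on a few test images (a small sketch; since the model here was saved without training, the predicted classes are not meaningful yet):
Python
import numpy as np

preds = saved_model.predict(x_test[:5])
print(preds.shape)               # (5, 10): one softmax vector per image
print(np.argmax(preds, axis=1))  # predicted class indices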
Saving and loading models is essential for efficient machine learning workflows, enabling you to resume training without starting from scratch and share models with others.