Save and load models in TensorFlow
Last Updated :
08 Apr, 2025
Training a machine learning or deep learning model is time-consuming, and shutting down the notebook causes all the weights and activations to disappear as the memory is flushed. Hence, we save models for reusability, collaboration and continuation of training.
- Saving the model allows us to avoid lengthy retraining and to resume training later.
- It also allows you to share the model with others so they can replicate your results.
- When sharing machine learning models it’s common to include the following:
- code to create the model
- trained weights for the model
Below are the methods for saving and loading machine learning models in TensorFlow.

Methods to Save and Load Models
Here are the methods that can be used to save a model.
1. Using the save() Method
The save() method allows you to save the complete model, including:
- Model architecture
- Model weights
- Model optimizer state to resume training from where you left off.
tensorflow.keras.X.save('location/model_name')
Where X can be a Sequential model, a Functional model or a Model subclass. The location specifies where the model is stored; if no path is given, it is saved in the same directory as the Python file.
To load the model, use the load_model() method:
tensorflow.keras.models.load_model('location/model_name')
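For instance, here is a minimal sketch of this workflow (the file name tiny_model.keras and the toy architecture are purely illustrative; the .keras extension assumes a reasonably recent TensorFlow release, otherwise .h5 can be used):
Python
import numpy as np
import tensorflow as tf

# Build and briefly train a tiny model on random data (illustrative only)
model = tf.keras.Sequential([
    tf.keras.layers.Input(shape=(4,)),
    tf.keras.layers.Dense(16, activation='relu'),
    tf.keras.layers.Dense(3, activation='softmax')
])
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy')
model.fit(np.random.rand(100, 4), np.random.randint(0, 3, 100), epochs=1, verbose=0)

# save() stores architecture, weights and optimizer state in one file
model.save('tiny_model.keras')

# load_model() restores the complete model, even in a fresh Python session
restored = tf.keras.models.load_model('tiny_model.keras')
restored.summary()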
2. Using the save_weights() Method
In some cases you might want to save just the weights of the model instead of the entire model. This can be done with the save_weights() method, which saves the weights of all the layers in the model.
tensorflow.keras.Model.save_weights('location/weights_name')
The weights_name is the file name for the saved weights; if no path is provided, the file is saved in the same directory as the Python file.
To load the saved weights, use the load_weights() method:
tensorflow.keras.Model.load_weights('location/weights_name')
Note: When loading weights, ensure that the model's architecture is identical to the one used to save the weights. For example, you cannot load the weights of a model with two dense layers into a model with just one dense layer.
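A minimal sketch of the weights-only workflow (the build_model() helper, the toy architecture and the file name tiny.weights.h5 are illustrative assumptions; recent Keras versions expect the weight file name to end in .weights.h5):
Python
import tensorflow as tf

def build_model():
    # The architecture must match exactly when the weights are loaded back
    return tf.keras.Sequential([
        tf.keras.layers.Input(shape=(4,)),
        tf.keras.layers.Dense(16, activation='relu'),
        tf.keras.layers.Dense(3, activation='softmax')
    ])

model = build_model()
model.save_weights('tiny.weights.h5')   # weights only: no architecture, no optimizer state

# Later: rebuild the same architecture, then load the saved weights into it
restored = build_model()
restored.load_weights('tiny.weights.h5')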
3. HDF5 Format (.h5)
If you save your model with the .h5 extension, it is stored in HDF5 format. This format is portable and commonly used for storing large datasets and models. Simply specify the .h5 extension when saving and TensorFlow will automatically use this format.
model.save('my_model.h5')
If you don’t specify the extension, TensorFlow saves the model in its native format.
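For illustration, the same model can be written to disk in either format (the toy model and file names are assumptions; the .keras extension requires a fairly recent TensorFlow/Keras release):
Python
import tensorflow as tf
from tensorflow.keras.models import load_model

# A tiny placeholder model, just to have something to save
model = tf.keras.Sequential([
    tf.keras.layers.Input(shape=(4,)),
    tf.keras.layers.Dense(2, activation='softmax')
])

model.save('my_model.h5')      # .h5 extension -> single HDF5 file
model.save('my_model.keras')   # .keras extension -> native Keras format

# load_model() recognises the format from the file you point it at
h5_model = load_model('my_model.h5')
native_model = load_model('my_model.keras')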
Code to Save and Load Models
Here we will build a neural network and then save it.
1. Import Necessary Modules
We import TensorFlow for building the model:
- Input, Conv2D, Dense, Flatten, Dropout, MaxPooling2D and BatchNormalization to build the neural network.
- Model to define the model architecture.
- load_model to load saved models.
Python
import tensorflow as tf
from tensorflow.keras.layers import Input, Conv2D, Dense, Flatten, Dropout, MaxPooling2D, BatchNormalization
from tensorflow.keras.models import Model, load_model
2. Load and Preprocess Data
- Here we will use the CIFAR-10 dataset, which contains 60,000 32x32 images across 10 classes; 50,000 images are used for training and 10,000 for testing. The pixel values are scaled to the range [0, 1] and the label arrays are flattened.
Python
cifar10 = tf.keras.datasets.cifar10
(x_train, y_train), (x_test, y_test) = cifar10.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0
y_train, y_test = y_train.flatten(), y_test.flatten()
3. Defining the Model
The model is built with the Keras functional API: three blocks, each with two Conv2D + BatchNormalization pairs followed by MaxPooling2D, then Flatten, Dropout and a Dense layer of 1024 units, ending in a softmax output over the K classes:
Python
K = len(set(y_train))
i = Input(shape=x_train[0].shape)
x = Conv2D(32, (3, 3), activation='relu', padding='same')(i)
x = BatchNormalization()(x)
x = Conv2D(32, (3, 3), activation='relu', padding='same')(x)
x = BatchNormalization()(x)
x = MaxPooling2D((2, 2))(x)
x = Conv2D(64, (3, 3), activation='relu', padding='same')(x)
x = BatchNormalization()(x)
x = Conv2D(64, (3, 3), activation='relu', padding='same')(x)
x = BatchNormalization()(x)
x = MaxPooling2D((2, 2))(x)
x = Conv2D(128, (3, 3), activation='relu', padding='same')(x)
x = BatchNormalization()(x)
x = Conv2D(128, (3, 3), activation='relu', padding='same')(x)
x = BatchNormalization()(x)
x = MaxPooling2D((2, 2))(x)
x = Flatten()(x)
x = Dropout(0.2)(x)
x = Dense(1024, activation='relu')(x)
x = Dropout(0.2)(x)
x = Dense(K, activation='softmax')(x)
model = Model(i, x)
model.summary()
Output:
Model summary
4. Saving and Loading the Model
Python
model.save('my_model.h5')
print("Model saved!")
saved_model = load_model('my_model.h5')
if saved_model is not None:
    print("Model loaded successfully!")
else:
    print("Failed to load the model.")
Output:
Model saved!
Model loaded successfully!
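As an optional sanity check (this sketch assumes model, saved_model and the x_test array from step 2 are still in memory), you can confirm that the reloaded model reproduces the original model's predictions:
Python
import numpy as np

# Predict with both the original and the reloaded model on a few test images
original_preds = model.predict(x_test[:5])
reloaded_preds = saved_model.predict(x_test[:5])

# The reloaded model should give numerically identical outputs
print(np.allclose(original_preds, reloaded_preds))   # expected: True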
Saving and loading models is essential for efficient machine learning workflows, enabling you to resume training without starting from scratch and share models with others.