Categorical Cross-Entropy in Multi-Class Classification
Categorical Cross-Entropy (CCE), also known as softmax loss or log loss, is one of the most commonly used loss functions in machine learning, particularly for classification problems. It measures the difference between the predicted probability distribution and the actual (true) distribution of classes. The function helps a machine learning model determine how far its predictions are from the true labels and guides it in learning to make more accurate predictions.
In this article, we will explore the mathematical representation, working, and applications of Categorical Cross-Entropy.
Introduction to Loss Functions
In machine learning, the goal of training a model is to minimize the error in its predictions. To do this, models use a loss function, which calculates how well the model’s predictions match the actual values. The lower the value of the loss function, the better the model is performing. For classification tasks, cross-entropy is a popular choice due to its effectiveness in quantifying the performance of a classification model.
Understanding Categorical Cross-Entropy
Categorical cross-entropy is used when you have more than two classes in your classification problem (multi-class classification). It measures the difference between two probability distributions: the predicted probability distribution and the true distribution, which is represented by a one-hot encoded vector.
In a one-hot encoded vector, the correct class is represented as "1" and all other classes as "0." Categorical cross-entropy penalizes predictions based on how confident the model is about the correct class.
If the model assigns a high probability to the true class, the cross-entropy will be low. Conversely, if the model assigns low probability to the correct class, the cross-entropy will be high.
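For example, integer class labels can be converted to one-hot vectors with a couple of lines of NumPy (a minimal sketch; the labels below are the ones used in the worked example later in this article, written 0-indexed):
Python
import numpy as np

# Integer class labels for three samples: Class 2, Class 1, Class 3 (0-indexed as 1, 0, 2)
labels = np.array([1, 0, 2])
num_classes = 3

# Row i of the identity matrix is the one-hot vector for class i
y_true = np.eye(num_classes)[labels]
print(y_true)
# [[0. 1. 0.]
#  [1. 0. 0.]
#  [0. 0. 1.]]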
Mathematical Representation of Categorical Cross-Entropy
The categorical cross-entropy formula is expressed as:
L(y, \hat{y}) = - \sum_{i=1}^{C} y_i \log(\hat{y}_i)
Where:
- L(y, \hat{y}) is the categorical cross-entropy loss.
- y_i is the true label (0 or 1 for each class) from the one-hot encoded target vector.
- \hat{y}_i is the predicted probability for class i.
- C is the number of classes.
In this formula, the logarithm heavily penalizes confident but incorrect predictions: as the predicted probability \hat{y}_i of the true class approaches zero, -\log(\hat{y}_i) grows without bound.
Example: Calculating Categorical Cross-Entropy
Let's break down the categorical cross-entropy calculation with a mathematical example using the following true labels and predicted probabilities.
We have 3 samples, each belonging to one of 3 classes (Class 1, Class 2, or Class 3). The true labels are one-hot encoded.
- True Labels (y_true):
  - Example 1: Class 2 → [0, 1, 0]
  - Example 2: Class 1 → [1, 0, 0]
  - Example 3: Class 3 → [0, 0, 1]
- Predicted Probabilities (y_pred):
  - Example 1: [0.1, 0.8, 0.1]
  - Example 2: [0.7, 0.2, 0.1]
  - Example 3: [0.2, 0.3, 0.5]
Step-by-Step Calculation
Example 1: True Label [0, 1, 0], Predicted [0.1, 0.8, 0.1]
The true class is Class 2, so y_2 = 1, and we focus on the predicted probability for Class 2, which is \hat{y}_2 = 0.8.
L_1 = -\left( 0 \cdot \log(0.1) + 1 \cdot \log(0.8) + 0 \cdot \log(0.1) \right)
Simplifying:
L_1 = -\log(0.8) = -(-0.22314355) = 0.22314355
Example 2: True Label [1, 0, 0], Predicted [0.7, 0.2, 0.1]
The true class is Class 1, so y_1 = 1, and we focus on the predicted probability for Class 1, which is \hat{y}_1 = 0.7.
L_2 = -\left( 1 \cdot \log(0.7) + 0 \cdot \log(0.2) + 0 \cdot \log(0.1) \right)
Simplifying:
L_2 = -\log(0.7) = -(-0.35667494) = 0.35667494
Example 3: True Label [0, 0, 1], Predicted [0.2, 0.3, 0.5]
The true class is Class 3, so y_3 = 1, and we focus on the predicted probability for Class 3, which is \hat{y}_3 = 0.5.
L_3 = -\left( 0 \cdot \log(0.2) + 0 \cdot \log(0.3) + 1 \cdot \log(0.5) \right)
Simplifying:
L_3 = -\log(0.5) = -(-0.69314718) = 0.69314718
Final Losses:
- For Example 1, the loss is: 0.22314355
- For Example 2, the loss is: 0.35667494
- For Example 3, the loss is: 0.69314718
Thus, the per-sample categorical cross-entropy losses are:
\text{Loss}: [0.22314355, 0.35667494, 0.69314718]
This loss function is crucial in guiding the model to learn better during training by adjusting its weights to minimize the error.
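These hand-computed values can be verified directly from the formula with a few lines of NumPy (a minimal sketch; the variable names are illustrative):
Python
import numpy as np

# True one-hot labels and predicted probabilities from the example above
y_true = np.array([[0, 1, 0], [1, 0, 0], [0, 0, 1]])
y_pred = np.array([[0.1, 0.8, 0.1], [0.7, 0.2, 0.1], [0.2, 0.3, 0.5]])

# Per-sample loss: L = -sum_i y_i * log(y_hat_i)
losses = -np.sum(y_true * np.log(y_pred), axis=1)
print(losses)        # [0.22314355 0.35667494 0.69314718]
print(losses.mean()) # ~0.4243; frameworks usually report this batch average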
How Categorical Cross-Entropy Works
To understand how CCE works, let's break it down:
- Prediction of Probabilities: The model outputs probabilities for each class. These probabilities are the likelihood of a data point belonging to each class. Typically, this is done using a softmax function, which converts raw scores into probabilities.
- Comparison with True Class: Categorical cross-entropy compares the predicted probabilities with the actual class labels (one-hot encoded).
- Calculation of Loss: The logarithm of the predicted probability for the correct class is taken, and the loss function penalizes the model based on how far the prediction was from the actual class.
For example, if the true label is class 1 and the predicted probability for class 1 is 0.9, the categorical cross-entropy loss will be small (-log(0.9) ≈ 0.105). If the predicted probability is 0.1, the loss will be much larger (-log(0.1) ≈ 2.303), forcing the model to correct its weights.
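Because the first step above relies on softmax, here is a small sketch of how raw model scores (logits) are typically converted into the probabilities that CCE consumes; the logit values below are made up for illustration:
Python
import numpy as np

def softmax(z):
    # Subtract the maximum for numerical stability before exponentiating
    z = z - np.max(z)
    exp_z = np.exp(z)
    return exp_z / np.sum(exp_z)

# Hypothetical raw scores for three classes
logits = np.array([2.0, 1.0, 0.1])
probs = softmax(logits)
print(probs)       # approximately [0.659 0.242 0.099]
print(probs.sum()) # 1.0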
Application of Categorical Cross-Entropy in Multi-Class Classification
Categorical cross-entropy is essential in multi-class classification, where a model must classify an instance into one of several classes. For example, in an image classification task, the model might need to identify whether an image is of a cat, dog, or bird. CCE helps the model adjust its weights during training to make better predictions.
It's important to note that the CCE loss function assumes that each data point belongs to exactly one class. If a data point can belong to multiple classes simultaneously (multi-label classification), binary cross-entropy applied to each class independently is the better choice, as sketched below.
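As an illustration of that distinction, a multi-label setup pairs multi-hot targets with per-class sigmoid outputs and binary cross-entropy. The sketch below uses made-up labels and probabilities:
Python
import tensorflow as tf
import numpy as np

# Multi-hot targets: the first sample belongs to classes 1 and 3 simultaneously
y_true = np.array([[1., 0., 1.], [0., 1., 0.]])

# Independent per-class sigmoid probabilities (rows need not sum to 1)
y_pred = np.array([[0.9, 0.2, 0.8], [0.1, 0.7, 0.3]])

# Binary cross-entropy averages the per-class losses for each sample
loss = tf.keras.losses.binary_crossentropy(y_true, y_pred)
print(loss.numpy())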
Differences Between Categorical and Binary Cross-Entropy
While both binary and categorical cross-entropy are used to calculate loss in classification problems, they differ in use cases and how they handle multiple classes:
- Binary Cross-Entropy is used for binary classification problems where there are only two possible outcomes (e.g., "yes" or "no").
- Categorical Cross-Entropy is used for multi-class classification where there are three or more categories, and the model assigns probabilities to each.
The key distinction lies in the number of classes the model is predicting and how those classes are encoded in the target labels.
Implementing Categorical Cross-Entropy in Python
Implementing categorical cross-entropy in Python, especially with libraries like TensorFlow or PyTorch, is straightforward since these libraries have built-in functions to handle this.
Here’s an example in TensorFlow:
Python
import tensorflow as tf
import numpy as np
# True labels (one-hot encoded)
y_true = np.array([[0, 1, 0], [1, 0, 0], [0, 0, 1]])
# Predicted probabilities
y_pred = np.array([[0.1, 0.8, 0.1], [0.7, 0.2, 0.1], [0.2, 0.3, 0.5]])
# Categorical Cross-Entropy loss calculation
loss = tf.keras.losses.categorical_crossentropy(y_true, y_pred)
print("Loss:", loss.numpy())
Output:
Loss: [0.22314355 0.35667494 0.69314718]
The output represents the categorical cross-entropy loss for each of the three examples, matching the values computed by hand earlier.
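A roughly equivalent computation in PyTorch is sketched below. Note that torch.nn.CrossEntropyLoss expects raw logits and integer class indices (it applies log-softmax internally), so to reproduce the same per-sample values from already-normalized probabilities we apply NLLLoss to their logarithm:
Python
import torch
import torch.nn.functional as F

# Predicted probabilities and integer class targets (same data as above)
y_pred = torch.tensor([[0.1, 0.8, 0.1], [0.7, 0.2, 0.1], [0.2, 0.3, 0.5]])
targets = torch.tensor([1, 0, 2])  # Class 2, Class 1, Class 3 (0-indexed)

# NLLLoss on log-probabilities reproduces the categorical cross-entropy values
loss = F.nll_loss(torch.log(y_pred), targets, reduction='none')
print(loss)  # tensor([0.2231, 0.3567, 0.6931])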
Conclusion
Categorical cross-entropy is a powerful loss function commonly used in multi-class classification problems. By comparing the predicted probabilities to the true one-hot encoded labels, it guides the model’s learning process, pushing it to make better predictions. Understanding how to use CCE and implementing it correctly can significantly impact the performance of your classification models.