Time-related feature engineering#

This notebook introduces different strategies to leverage time-related features for a bike sharing demand regression task that is highly dependent on business cycles (days, weeks, months) and yearly season cycles.

In the process, we introduce how to perform periodic feature engineering using the sklearn.preprocessing.SplineTransformer class and its extrapolation="periodic" option.

# Authors: The scikit-learn developers
# SPDX-License-Identifier: BSD-3-Clause

Data exploration on the Bike Sharing Demand dataset#

We start by loading the data from the OpenML repository.

from sklearn.datasets import fetch_openml

bike_sharing = fetch_openml("Bike_Sharing_Demand", version=2, as_frame=True)
df = bike_sharing.frame

To get a quick understanding of the periodic patterns of the data, let us have a look at the average demand per hour during a week.

Note that the week starts on a Sunday, during the weekend. We can clearly distinguish the commute patterns in the mornings and evenings of the working days, and the leisure use of the bikes on the weekends, with a more spread-out peak demand around the middle of the day:

import matplotlib.pyplot as plt

fig, ax = plt.subplots(figsize=(12, 4))
average_week_demand = df.groupby(["weekday", "hour"])["count"].mean()
average_week_demand.plot(ax=ax)
_ = ax.set(
    title="Average hourly bike demand during the week",
    xticks=[i * 24 for i in range(7)],
    xticklabels=["Sun", "Mon", "Tue", "Wed", "Thu", "Fri", "Sat"],
    xlabel="Time of the week",
    ylabel="Number of bike rentals",
)
Figure: Average hourly bike demand during the week

The target of the prediction problem is the absolute count of bike rentals on an hourly basis:

df["count"].max()
np.int64(977)

Let us rescale the target variable (number of hourly bike rentals) to predict a relative demand so that the mean absolute error is more easily interpreted as a fraction of the maximum demand.

Note

The fit methods of the models used in this notebook all minimize the mean squared error, which estimates the conditional mean of the target. Minimizing the absolute error, however, would estimate the conditional median.

Nevertheless, when reporting performance measures on the test set in the discussion, we choose to focus on the mean absolute error instead of the (root) mean squared error because it is more intuitive to interpret. Note, however, that in this study the best models for one metric are also the best ones in terms of the other metric.
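
As a small, hedged illustration of this point (an addition to the note, not part of the original notebook; the sample values are made up), we can check numerically that the squared error is minimized by the mean while the absolute error is minimized by the median:

import numpy as np

sample = np.array([0.0, 0.1, 0.1, 0.2, 1.0])  # a small, skewed sample
candidates = np.linspace(0, 1, 1001)
best_squared = candidates[np.argmin([np.square(sample - c).mean() for c in candidates])]
best_absolute = candidates[np.argmin([np.abs(sample - c).mean() for c in candidates])]
print(best_squared, sample.mean())       # both ~0.28: the mean minimizes the squared error
print(best_absolute, np.median(sample))  # both 0.1: the median minimizes the absolute error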

y = df["count"] / df["count"].max()
fig, ax = plt.subplots(figsize=(12, 4))
y.hist(bins=30, ax=ax)
_ = ax.set(
    xlabel="Fraction of rented fleet demand",
    ylabel="Number of hours",
)
Figure: histogram of the relative demand (fraction of rented fleet demand vs. number of hours)

The input feature data frame is a time-annotated hourly log of variables describing the weather conditions. It includes both numerical and categorical variables. Note that the time information has already been expanded into several complementary columns.

X = df.drop("count", axis="columns")
X
season year month hour holiday weekday workingday weather temp feel_temp humidity windspeed
0 spring 0 1 0 False 6 False clear 9.84 14.395 0.81 0.0000
1 spring 0 1 1 False 6 False clear 9.02 13.635 0.80 0.0000
2 spring 0 1 2 False 6 False clear 9.02 13.635 0.80 0.0000
3 spring 0 1 3 False 6 False clear 9.84 14.395 0.75 0.0000
4 spring 0 1 4 False 6 False clear 9.84 14.395 0.75 0.0000
... ... ... ... ... ... ... ... ... ... ... ... ...
17374 spring 1 12 19 False 1 True misty 10.66 12.880 0.60 11.0014
17375 spring 1 12 20 False 1 True misty 10.66 12.880 0.60 11.0014
17376 spring 1 12 21 False 1 True clear 10.66 12.880 0.60 11.0014
17377 spring 1 12 22 False 1 True clear 10.66 13.635 0.56 8.9981
17378 spring 1 12 23 False 1 True clear 10.66 13.635 0.65 8.9981

17379 rows × 12 columns



Note

If the time information were only present as a date or datetime column, we could have expanded it into hour-in-the-day, day-in-the-week, day-in-the-month and month-in-the-year using pandas (see https://pandas.pydata.org/pandas-docs/stable/user_guide/timeseries.html#time-date-components), as the sketch below illustrates.
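
For illustration, here is a minimal, hedged sketch of that expansion on a hypothetical "datetime" column (the column name and the sample date range are assumptions, not part of this dataset):

import pandas as pd

raw_df = pd.DataFrame({"datetime": pd.date_range("2011-01-01", periods=48, freq="h")})
raw_df["hour"] = raw_df["datetime"].dt.hour          # hour-in-the-day
raw_df["weekday"] = raw_df["datetime"].dt.dayofweek  # day-in-the-week
raw_df["day"] = raw_df["datetime"].dt.day            # day-in-the-month
raw_df["month"] = raw_df["datetime"].dt.month        # month-in-the-year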

We now introspect the distribution of the categorical variables, starting with "weather":

X["weather"].value_counts()
weather
clear         11413
misty          4544
rain           1419
heavy_rain        3
Name: count, dtype: int64

Since there are only 3 "heavy_rain" events, we cannot use this category to train machine learning models with cross-validation. Instead, we simplify the representation by collapsing those events into the "rain" category.

X["weather"] = (
    X["weather"]
    .astype(object)
    .replace(to_replace="heavy_rain", value="rain")
    .astype("category")
)
X["weather"].value_counts()
weather
clear    11413
misty     4544
rain      1422
Name: count, dtype: int64

As expected, the "season" variable is well balanced:

X["season"].value_counts()
season
fall      4496
summer    4409
spring    4242
winter    4232
Name: count, dtype: int64

Time-based cross-validation#

Since the dataset is a time-ordered event log (hourly demand), we will use a time-sensitive cross-validation splitter to evaluate our demand forecasting model as realistically as possible. We use a gap of 2 days between the train and test sides of the splits. We also limit the training set size to make the performance of the CV folds more stable.

1000 test datapoints should be enough to quantify the performance of the model. This represents a bit less than a month and a half of contiguous test data:

from sklearn.model_selection import TimeSeriesSplit

ts_cv = TimeSeriesSplit(
    n_splits=5,
    gap=48,
    max_train_size=10000,
    test_size=1000,
)
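
As a quick, hedged sanity check (an addition to the original example), we can print the integer index boundaries of each fold to verify the configured gap and fold sizes:

for fold, (train_idx, test_idx) in enumerate(ts_cv.split(X, y)):
    print(
        f"Fold {fold}: train [{train_idx[0]}, {train_idx[-1]}], "
        f"test [{test_idx[0]}, {test_idx[-1]}], "
        f"gap = {test_idx[0] - train_idx[-1] - 1} hours"
    )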

Let us manually inspect the various splits to check that the TimeSeriesSplit works as we expect, starting with the first split:

all_splits = list(ts_cv.split(X, y))
train_0, test_0 = all_splits[0]
X.iloc[test_0]
season year month hour holiday weekday workingday weather temp feel_temp humidity windspeed
12379 summer 1 6 0 False 2 True clear 22.14 25.760 0.68 27.9993
12380 summer 1 6 1 False 2 True misty 21.32 25.000 0.77 22.0028
12381 summer 1 6 2 False 2 True rain 21.32 25.000 0.72 19.9995
12382 summer 1 6 3 False 2 True rain 20.50 24.240 0.82 12.9980
12383 summer 1 6 4 False 2 True rain 20.50 24.240 0.82 12.9980
... ... ... ... ... ... ... ... ... ... ... ... ...
13374 fall 1 7 11 False 1 True clear 34.44 40.150 0.53 15.0013
13375 fall 1 7 12 False 1 True clear 34.44 39.395 0.49 8.9981
13376 fall 1 7 13 False 1 True clear 34.44 39.395 0.49 19.0012
13377 fall 1 7 14 False 1 True clear 36.08 40.910 0.42 7.0015
13378 fall 1 7 15 False 1 True clear 35.26 40.150 0.47 16.9979

1000 rows × 12 columns



X.iloc[train_0]
season year month hour holiday weekday workingday weather temp feel_temp humidity windspeed
2331 summer 0 4 1 False 2 True misty 25.42 31.060 0.50 6.0032
2332 summer 0 4 2 False 2 True misty 24.60 31.060 0.53 8.9981
2333 summer 0 4 3 False 2 True misty 23.78 27.275 0.56 8.9981
2334 summer 0 4 4 False 2 True misty 22.96 26.515 0.64 8.9981
2335 summer 0 4 5 False 2 True misty 22.14 25.760 0.68 8.9981
... ... ... ... ... ... ... ... ... ... ... ... ...
12326 summer 1 6 19 False 6 False clear 26.24 31.060 0.36 11.0014
12327 summer 1 6 20 False 6 False clear 25.42 31.060 0.35 19.0012
12328 summer 1 6 21 False 6 False clear 24.60 31.060 0.40 7.0015
12329 summer 1 6 22 False 6 False clear 23.78 27.275 0.46 8.9981
12330 summer 1 6 23 False 6 False clear 22.96 26.515 0.52 7.0015

10000 rows × 12 columns



We now inspect the last split:

train_4, test_4 = all_splits[4]
X.iloc[test_4]
season year month hour holiday weekday workingday weather temp feel_temp humidity windspeed
16379 winter 1 11 5 False 2 True misty 13.94 16.665 0.66 8.9981
16380 winter 1 11 6 False 2 True misty 13.94 16.665 0.71 11.0014
16381 winter 1 11 7 False 2 True clear 13.12 16.665 0.76 6.0032
16382 winter 1 11 8 False 2 True clear 13.94 16.665 0.71 8.9981
16383 winter 1 11 9 False 2 True misty 14.76 18.940 0.71 0.0000
... ... ... ... ... ... ... ... ... ... ... ... ...
17374 spring 1 12 19 False 1 True misty 10.66 12.880 0.60 11.0014
17375 spring 1 12 20 False 1 True misty 10.66 12.880 0.60 11.0014
17376 spring 1 12 21 False 1 True clear 10.66 12.880 0.60 11.0014
17377 spring 1 12 22 False 1 True clear 10.66 13.635 0.56 8.9981
17378 spring 1 12 23 False 1 True clear 10.66 13.635 0.65 8.9981

1000 rows × 12 columns



X.iloc[train_4]
season year month hour holiday weekday workingday weather temp feel_temp humidity windspeed
6331 winter 0 9 9 False 1 True misty 26.24 28.790 0.89 12.9980
6332 winter 0 9 10 False 1 True misty 26.24 28.790 0.89 12.9980
6333 winter 0 9 11 False 1 True clear 27.88 31.820 0.79 15.0013
6334 winter 0 9 12 False 1 True misty 27.88 31.820 0.79 11.0014
6335 winter 0 9 13 False 1 True misty 28.70 33.335 0.74 11.0014
... ... ... ... ... ... ... ... ... ... ... ... ...
16326 winter 1 11 0 False 0 False misty 12.30 15.150 0.70 11.0014
16327 winter 1 11 1 False 0 False clear 12.30 14.395 0.70 12.9980
16328 winter 1 11 2 False 0 False clear 11.48 14.395 0.81 7.0015
16329 winter 1 11 3 False 0 False misty 12.30 15.150 0.81 11.0014
16330 winter 1 11 4 False 0 False misty 12.30 14.395 0.81 12.9980

10000 rows × 12 columns



All is well. We are now ready to do some predictive modeling!

Gradient Boosting#

Gradient Boosting Regression with decision trees is often flexible enough to efficiently handle heterogeneous tabular data with a mix of categorical and numerical features as long as the number of samples is large enough.

Here, we use the modern HistGradientBoostingRegressor, which has native support for categorical features: we only need to set categorical_features="from_dtype" so that features with a categorical dtype are treated as categorical. For reference, we also extract the list of categorical columns from the dataframe based on their dtype. The internal trees use a dedicated splitting rule for these features.

The numerical variables need no preprocessing and, for the sake of simplicity, we only try the default hyper-parameters for this model:

from sklearn.compose import ColumnTransformer
from sklearn.ensemble import HistGradientBoostingRegressor
from sklearn.model_selection import cross_validate
from sklearn.pipeline import make_pipeline

gbrt = HistGradientBoostingRegressor(categorical_features="from_dtype", random_state=42)
categorical_columns = X.columns[X.dtypes == "category"]
print("Categorical features:", categorical_columns.tolist())
Categorical features: ['season', 'holiday', 'workingday', 'weather']

Let’s evaluate our gradient boosting model with the mean absolute error of the relative demand averaged across our 5 time-based cross-validation splits:

import numpy as np


def evaluate(model, X, y, cv, model_prop=None, model_step=None):
    cv_results = cross_validate(
        model,
        X,
        y,
        cv=cv,
        scoring=["neg_mean_absolute_error", "neg_root_mean_squared_error"],
        return_estimator=model_prop is not None,
    )
    if model_prop is not None:
        if model_step is not None:
            values = [
                getattr(m[model_step], model_prop) for m in cv_results["estimator"]
            ]
        else:
            values = [getattr(m, model_prop) for m in cv_results["estimator"]]
        print(f"Mean model.{model_prop} = {np.mean(values)}")
    mae = -cv_results["test_neg_mean_absolute_error"]
    rmse = -cv_results["test_neg_root_mean_squared_error"]
    print(
        f"Mean Absolute Error:     {mae.mean():.3f} +/- {mae.std():.3f}\n"
        f"Root Mean Squared Error: {rmse.mean():.3f} +/- {rmse.std():.3f}"
    )


evaluate(gbrt, X, y, cv=ts_cv, model_prop="n_iter_")
Mean model.n_iter_ = 100.0
Mean Absolute Error:     0.044 +/- 0.003
Root Mean Squared Error: 0.068 +/- 0.005

Note that the mean n_iter_ equals the default max_iter of 100, i.e. the models used their full iteration budget: with at most 10,000 training samples per fold, the default early_stopping="auto" setting does not activate early stopping here.
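
As a hedged side note (these settings are illustrative assumptions, not part of the evaluation above), early stopping could be enabled explicitly together with a larger iteration budget, in which case n_iter_ would reflect the iteration at which the validation loss stopped improving:

gbrt_early_stopping = HistGradientBoostingRegressor(
    categorical_features="from_dtype",
    max_iter=1000,            # larger budget than the default of 100
    early_stopping=True,      # force early stopping regardless of the sample count
    validation_fraction=0.1,  # hold out 10% of the training data for validation
    n_iter_no_change=10,      # stop after 10 iterations without improvement
    random_state=42,
)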

This model has an average error around 4 to 5% of the maximum demand. This is quite good for a first trial without any hyper-parameter tuning! We just had to make the categorical variables explicit. Note that the time-related features are passed as-is, i.e. without any processing. But this is not much of a problem for tree-based models, as they can learn a non-monotonic relationship between ordinal input features and the target.

This is not the case for linear regression models as we will see in the following.

Naive linear regression#

As usual for linear models, categorical variables need to be one-hot encoded. For consistency, we scale the numerical features to the same 0-1 range using MinMaxScaler, although in this case it does not impact the results much because they are already on comparable scales:

from sklearn.linear_model import RidgeCV
from sklearn.preprocessing import MinMaxScaler, OneHotEncoder

one_hot_encoder = OneHotEncoder(handle_unknown="ignore", sparse_output=False)
alphas = np.logspace(-6, 6, 25)
naive_linear_pipeline = make_pipeline(
    ColumnTransformer(
        transformers=[
            ("categorical", one_hot_encoder, categorical_columns),
        ],
        remainder=MinMaxScaler(),
    ),
    RidgeCV(alphas=alphas),
)


evaluate(
    naive_linear_pipeline, X, y, cv=ts_cv, model_prop="alpha_", model_step="ridgecv"
)
Mean model.alpha_ = 2.7298221281347037
Mean Absolute Error:     0.142 +/- 0.014
Root Mean Squared Error: 0.184 +/- 0.020

It is reassuring to see that the selected alpha_ lies within our specified range.

The performance is not good: the average error is around 14% of the maximum demand. This is more than three times higher than the average error of the gradient boosting model. We can suspect that the naive original encoding (merely min-max scaled) of the periodic time-related features might prevent the linear regression model from properly leveraging the time information: linear regression does not automatically model non-monotonic relationships between the input features and the target. Non-linear terms have to be engineered into the input features.

For example, the raw numerical encoding of the "hour" feature prevents the linear model from recognizing that an increase of hour in the morning from 6 to 8 should have a strong positive impact on the number of bike rentals while an increase of similar magnitude in the evening from 18 to 20 should have a strong negative impact on the predicted number of bike rentals.

Time-steps as categories#

Since the time features are encoded in a discrete manner using integers (24 unique values in the "hour" feature), we could decide to treat those as categorical variables using a one-hot encoding and thereby ignore any assumption implied by the ordering of the hour values.

Using one-hot encoding for the time features gives the linear model a lot more flexibility as we introduce one additional feature per discrete time level.

one_hot_linear_pipeline = make_pipeline(
    ColumnTransformer(
        transformers=[
            ("categorical", one_hot_encoder, categorical_columns),
            ("one_hot_time", one_hot_encoder, ["hour", "weekday", "month"]),
        ],
        remainder=MinMaxScaler(),
    ),
    RidgeCV(alphas=alphas),
)

evaluate(one_hot_linear_pipeline, X, y, cv=ts_cv)
Mean Absolute Error:     0.099 +/- 0.011
Root Mean Squared Error: 0.131 +/- 0.011

The average error of this model is around 10% of the maximum demand, which is much better than with the original (ordinal) encoding of the time features, confirming our intuition that the linear regression model benefits from the added flexibility of not treating time progression in a monotonic manner.

However, this introduces a very large number of new features. If the time of day were represented in minutes since the start of the day instead of hours, one-hot encoding would have introduced 1440 features instead of 24. This could cause some significant overfitting. To avoid this, we could instead use sklearn.preprocessing.KBinsDiscretizer to re-bin fine-grained ordinal or numerical variables into fewer levels while still benefiting from the non-monotonic expressivity of one-hot encoding, as sketched below.
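
A minimal, hedged sketch of that idea (not evaluated in this example; the number of bins and the binning strategy are illustrative assumptions):

from sklearn.preprocessing import KBinsDiscretizer

hour_binner = KBinsDiscretizer(n_bins=6, encode="onehot-dense", strategy="uniform")
binned_hours = hour_binner.fit_transform(X[["hour"]])
print(binned_hours.shape)  # (17379, 6): 6 coarse time-of-day levels instead of 24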

Finally, we also observe that one-hot encoding completely ignores the ordering of the hour levels, while this could be an interesting inductive bias to preserve to some extent. In the following, we explore smooth, non-monotonic encodings that locally preserve the relative ordering of the time features.

Trigonometric features#

As a first attempt, we can try to encode each of those periodic features using a sine and cosine transformation with the matching period.

Each ordinal time feature is transformed into 2 features that together encode equivalent information in a non-monotonic way, and more importantly without any jump between the first and the last value of the periodic range.

from sklearn.preprocessing import FunctionTransformer


def sin_transformer(period):
    return FunctionTransformer(lambda x: np.sin(x / period * 2 * np.pi))


def cos_transformer(period):
    return FunctionTransformer(lambda x: np.cos(x / period * 2 * np.pi))

Let us visualize the effect of this feature expansion on some synthetic hour data with a bit of extrapolation beyond hour=23:

import pandas as pd

hour_df = pd.DataFrame(
    np.arange(26).reshape(-1, 1),
    columns=["hour"],
)
hour_df["hour_sin"] = sin_transformer(24).fit_transform(hour_df)["hour"]
hour_df["hour_cos"] = cos_transformer(24).fit_transform(hour_df)["hour"]
hour_df.plot(x="hour")
_ = plt.title("Trigonometric encoding for the 'hour' feature")
Figure: Trigonometric encoding for the 'hour' feature

Let's use a 2D scatter plot with the hours encoded as colors to better see how this representation maps the 24 hours of the day to a 2D space, akin to a sort of 24-hour version of an analog clock. Note that the "25th" hour is mapped back to the 1st hour because of the periodic nature of the sine/cosine representation.

fig, ax = plt.subplots(figsize=(7, 5))
sp = ax.scatter(hour_df["hour_sin"], hour_df["hour_cos"], c=hour_df["hour"])
ax.set(
    xlabel="sin(hour)",
    ylabel="cos(hour)",
)
_ = fig.colorbar(sp)
Figure: the 24 hours of the day mapped to the (sin(hour), cos(hour)) plane, with the hour encoded as color

We can now build a feature extraction pipeline using this strategy:

cyclic_cossin_transformer = ColumnTransformer(
    transformers=[
        ("categorical", one_hot_encoder, categorical_columns),
        ("month_sin", sin_transformer(12), ["month"]),
        ("month_cos", cos_transformer(12), ["month"]),
        ("weekday_sin", sin_transformer(7), ["weekday"]),
        ("weekday_cos", cos_transformer(7), ["weekday"]),
        ("hour_sin", sin_transformer(24), ["hour"]),
        ("hour_cos", cos_transformer(24), ["hour"]),
    ],
    remainder=MinMaxScaler(),
)
cyclic_cossin_linear_pipeline = make_pipeline(
    cyclic_cossin_transformer,
    RidgeCV(alphas=alphas),
)
evaluate(cyclic_cossin_linear_pipeline, X, y, cv=ts_cv)
Mean Absolute Error:     0.125 +/- 0.014
Root Mean Squared Error: 0.166 +/- 0.020

The performance of our linear regression model with this simple feature engineering is a bit better than using the original ordinal time features but worse than using the one-hot encoded time features. We will further analyze possible reasons for this disappointing outcome at the end of this notebook.

Periodic spline features#

We can try an alternative encoding of the periodic time-related features using spline transformations with a large enough number of splines, and as a result a larger number of expanded features compared to the sine/cosine transformation:

from sklearn.preprocessing import SplineTransformer


def periodic_spline_transformer(period, n_splines=None, degree=3):
    if n_splines is None:
        n_splines = period
    n_knots = n_splines + 1  # periodic and include_bias is True
    return SplineTransformer(
        degree=degree,
        n_knots=n_knots,
        knots=np.linspace(0, period, n_knots).reshape(n_knots, 1),
        extrapolation="periodic",
        include_bias=True,
    )

Again, let us visualize the effect of this feature expansion on some synthetic hour data with a bit of extrapolation beyond hour=23:

hour_df = pd.DataFrame(
    np.linspace(0, 26, 1000).reshape(-1, 1),
    columns=["hour"],
)
splines = periodic_spline_transformer(24, n_splines=12).fit_transform(hour_df)
splines_df = pd.DataFrame(
    splines,
    columns=[f"spline_{i}" for i in range(splines.shape[1])],
)
pd.concat([hour_df, splines_df], axis="columns").plot(x="hour", cmap=plt.cm.tab20b)
_ = plt.title("Periodic spline-based encoding for the 'hour' feature")
Figure: Periodic spline-based encoding for the 'hour' feature

Thanks to the use of the extrapolation="periodic" parameter, we observe that the feature encoding stays smooth when extrapolating beyond midnight.

We can now build a predictive pipeline using this alternative periodic feature engineering strategy.

It is possible to use fewer splines than discrete levels for those ordinal values. This makes spline-based encoding more efficient than one-hot encoding while preserving most of the expressivity:

cyclic_spline_transformer = ColumnTransformer(
    transformers=[
        ("categorical", one_hot_encoder, categorical_columns),
        ("cyclic_month", periodic_spline_transformer(12, n_splines=6), ["month"]),
        ("cyclic_weekday", periodic_spline_transformer(7, n_splines=3), ["weekday"]),
        ("cyclic_hour", periodic_spline_transformer(24, n_splines=12), ["hour"]),
    ],
    remainder=MinMaxScaler(),
)
cyclic_spline_linear_pipeline = make_pipeline(
    cyclic_spline_transformer,
    RidgeCV(alphas=alphas),
)
evaluate(cyclic_spline_linear_pipeline, X, y, cv=ts_cv)
Mean Absolute Error:     0.097 +/- 0.011
Root Mean Squared Error: 0.132 +/- 0.013

Spline features make it possible for the linear model to successfully leverage the periodic time-related features and reduce the error from ~14% to ~10% of the maximum demand, which is similar to what we observed with the one-hot encoded features.

Qualitative analysis of the impact of features on linear model predictions#

Here, we want to visualize the impact of the feature engineering choices on the time-related shape of the predictions.

To do so, we consider an arbitrary time-based split to compare the predictions on a range of held-out data points.

naive_linear_pipeline.fit(X.iloc[train_0], y.iloc[train_0])
naive_linear_predictions = naive_linear_pipeline.predict(X.iloc[test_0])

one_hot_linear_pipeline.fit(X.iloc[train_0], y.iloc[train_0])
one_hot_linear_predictions = one_hot_linear_pipeline.predict(X.iloc[test_0])

cyclic_cossin_linear_pipeline.fit(X.iloc[train_0], y.iloc[train_0])
cyclic_cossin_linear_predictions = cyclic_cossin_linear_pipeline.predict(X.iloc[test_0])

cyclic_spline_linear_pipeline.fit(X.iloc[train_0], y.iloc[train_0])
cyclic_spline_linear_predictions = cyclic_spline_linear_pipeline.predict(X.iloc[test_0])

We visualize those predictions by zooming on the last 96 hours (4 days) of the test set to get some qualitative insights:

last_hours = slice(-96, None)
fig, ax = plt.subplots(figsize=(12, 4))
fig.suptitle("Predictions by linear models")
ax.plot(
    y.iloc[test_0].values[last_hours],
    "x-",
    alpha=0.2,
    label="Actual demand",
    color="black",
)
ax.plot(naive_linear_predictions[last_hours], "x-", label="Ordinal time features")
ax.plot(
    cyclic_cossin_linear_predictions[last_hours],
    "x-",
    label="Trigonometric time features",
)
ax.plot(
    cyclic_spline_linear_predictions[last_hours],
    "x-",
    label="Spline-based time features",
)
ax.plot(
    one_hot_linear_predictions[last_hours],
    "x-",
    label="One-hot time features",
)
_ = ax.legend()
Figure: Predictions by linear models (last 96 hours of the test set)

We can draw the following conclusions from the above plot:

  • The raw ordinal time-related features are problematic because they do not capture the natural periodicity: we observe a big jump in the predictions at the end of each day, when the hour feature goes from 23 back to 0. We can expect similar artifacts at the end of each week or each year.

  • As expected, the trigonometric features (sine and cosine) do not have these discontinuities at midnight, but the linear regression model fails to leverage those features to properly model intra-day variations. Using trigonometric features for higher harmonics or additional trigonometric features for the natural period with different phases could potentially fix this problem (see the sketch after this list).

  • The periodic spline-based features fix those two problems at once: they give more expressivity to the linear model by making it possible to focus on specific hours thanks to the use of 12 splines. Furthermore, the extrapolation="periodic" option enforces a smooth representation between hour=23 and hour=0.

  • The one-hot encoded features behave similarly to the periodic spline-based features but are more spiky: for instance, they can better model the morning peak during the week days since this peak lasts shorter than an hour. However, we will see in the following that what can be an advantage for linear models is not necessarily one for more expressive models.
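
To make the second bullet point concrete, here is a hedged sketch (an addition with assumed settings, not evaluated in this example) of extra sine/cosine harmonics of the 24-hour period that could be appended to the transformers list of cyclic_cossin_transformer to give the linear model sharper intra-day flexibility while staying periodic:

import numpy as np
from sklearn.preprocessing import FunctionTransformer


def sin_harmonic_transformer(period, harmonic):
    return FunctionTransformer(lambda x: np.sin(x / period * 2 * np.pi * harmonic))


def cos_harmonic_transformer(period, harmonic):
    return FunctionTransformer(lambda x: np.cos(x / period * 2 * np.pi * harmonic))


# Second and third harmonics of the daily period for the "hour" column:
extra_hour_harmonics = [
    (f"hour_sin_{k}", sin_harmonic_transformer(24, k), ["hour"]) for k in (2, 3)
] + [
    (f"hour_cos_{k}", cos_harmonic_transformer(24, k), ["hour"]) for k in (2, 3)
]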

We can also compare the number of features extracted by each feature engineering pipeline:

naive_linear_pipeline[:-1].transform(X).shape
(17379, 19)
one_hot_linear_pipeline[:-1].transform(X).shape
(17379, 59)
cyclic_cossin_linear_pipeline[:-1].transform(X).shape
(17379, 22)
cyclic_spline_linear_pipeline[:-1].transform(X).shape
(17379, 37)

This confirms that the one-hot encoding and the spline encoding strategies create a lot more features for the time representation than the alternatives, which in turn gives the downstream linear model more flexibility (degrees of freedom) to avoid underfitting.

Finally, we observe that none of the linear models can approximate the true bike rental demand, especially for the peaks, which can be very sharp at rush hours during the working days but much flatter during the weekends: the most accurate linear models based on splines or one-hot encoding tend to forecast commuting-related peaks even on the weekends and under-estimate the commuting-related events during the working days.

These systematic prediction errors reveal a form of under-fitting and can be explained by the lack of interaction terms between features, e.g. "workingday" and features derived from "hour". This issue will be addressed in the following section.

Modeling pairwise interactions with splines and polynomial features#

Linear models do not automatically capture interaction effects between input features. This remains true even when some features are individually non-linear transformations of the raw inputs, as is the case for the features constructed by SplineTransformer (or by one-hot encoding or binning).

However, it is possible to use the PolynomialFeatures class on coarse-grained, spline-encoded hours to model the "workingday"/"hour" interaction explicitly without introducing too many new variables:

from sklearn.pipeline import FeatureUnion
from sklearn.preprocessing import PolynomialFeatures

hour_workday_interaction = make_pipeline(
    ColumnTransformer(
        [
            ("cyclic_hour", periodic_spline_transformer(24, n_splines=8), ["hour"]),
            ("workingday", FunctionTransformer(lambda x: x == "True"), ["workingday"]),
        ]
    ),
    PolynomialFeatures(degree=2, interaction_only=True, include_bias=False),
)

Those features are then combined with the ones already computed in the previous spline-based pipeline. We can observe a nice performance improvement by modeling this pairwise interaction explicitly:

cyclic_spline_interactions_pipeline = make_pipeline(
    FeatureUnion(
        [
            ("marginal", cyclic_spline_transformer),
            ("interactions", hour_workday_interaction),
        ]
    ),
    RidgeCV(alphas=alphas),
)
evaluate(cyclic_spline_interactions_pipeline, X, y, cv=ts_cv)
Mean Absolute Error:     0.078 +/- 0.009
Root Mean Squared Error: 0.104 +/- 0.009

Modeling non-linear feature interactions with kernels#

The previous analysis highlighted the need to model the interaction between "workingday" and "hour". Another example of such a non-linear interaction that we would like to model is the impact of rain, which might not be the same during working days as during weekends and holidays.

To model all such interactions, we could apply a polynomial expansion to all marginal features at once, after their spline-based expansion. However, this would create a quadratic number of features, which can cause overfitting and computational tractability issues, as the sketch below illustrates.
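
A hedged sketch of this rejected option (only the number of generated columns is inspected; the features are not fed into a model here):

full_poly_features = make_pipeline(
    cyclic_spline_transformer,
    PolynomialFeatures(degree=2, interaction_only=True, include_bias=False),
)
# 37 marginal features plus all of their pairwise products: roughly 700 columns.
print(full_poly_features.fit_transform(X).shape)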

Alternatively, we can use the Nyström method to compute an approximate polynomial kernel expansion. Let us try the latter:

from sklearn.kernel_approximation import Nystroem

cyclic_spline_poly_pipeline = make_pipeline(
    cyclic_spline_transformer,
    Nystroem(kernel="poly", degree=2, n_components=300, random_state=0),
    RidgeCV(alphas=alphas),
)
evaluate(cyclic_spline_poly_pipeline, X, y, cv=ts_cv)
Mean Absolute Error:     0.053 +/- 0.002
Root Mean Squared Error: 0.076 +/- 0.004

We observe that this model can almost rival the performance of the gradient boosted trees with an average error around 5% of the maximum demand.

Note that while the final step of this pipeline is a linear regression model, the intermediate steps such as the spline feature extraction and the Nyström kernel approximation are highly non-linear. As a result the compound pipeline is much more expressive than a simple linear regression model with raw features.

For the sake of completeness, we also evaluate the combination of one-hot encoding and kernel approximation:

one_hot_poly_pipeline = make_pipeline(
    ColumnTransformer(
        transformers=[
            ("categorical", one_hot_encoder, categorical_columns),
            ("one_hot_time", one_hot_encoder, ["hour", "weekday", "month"]),
        ],
        remainder="passthrough",
    ),
    Nystroem(kernel="poly", degree=2, n_components=300, random_state=0),
    RidgeCV(alphas=alphas),
)
evaluate(one_hot_poly_pipeline, X, y, cv=ts_cv)
Mean Absolute Error:     0.082 +/- 0.006
Root Mean Squared Error: 0.111 +/- 0.011

While one-hot encoded features were competitive with spline-based features when using linear models, this is no longer the case when using a low-rank approximation of a non-linear kernel: this can be explained by the fact that spline features are smoother and allow the kernel approximation to find a more expressive decision function.

Let us now have a qualitative look at the predictions of the kernel models and of the gradient boosted trees that should be able to better model non-linear interactions between features:

gbrt.fit(X.iloc[train_0], y.iloc[train_0])
gbrt_predictions = gbrt.predict(X.iloc[test_0])

one_hot_poly_pipeline.fit(X.iloc[train_0], y.iloc[train_0])
one_hot_poly_predictions = one_hot_poly_pipeline.predict(X.iloc[test_0])

cyclic_spline_poly_pipeline.fit(X.iloc[train_0], y.iloc[train_0])
cyclic_spline_poly_predictions = cyclic_spline_poly_pipeline.predict(X.iloc[test_0])

Again we zoom on the last 4 days of the test set:

last_hours = slice(-96, None)
fig, ax = plt.subplots(figsize=(12, 4))
fig.suptitle("Predictions by non-linear regression models")
ax.plot(
    y.iloc[test_0].values[last_hours],
    "x-",
    alpha=0.2,
    label="Actual demand",
    color="black",
)
ax.plot(
    gbrt_predictions[last_hours],
    "x-",
    label="Gradient Boosted Trees",
)
ax.plot(
    one_hot_poly_predictions[last_hours],
    "x-",
    label="One-hot + polynomial kernel",
)
ax.plot(
    cyclic_spline_poly_predictions[last_hours],
    "x-",
    label="Splines + polynomial kernel",
)
_ = ax.legend()
Figure: Predictions by non-linear regression models (last 96 hours of the test set)

First, note that trees can naturally model non-linear feature interactions since, by default, decision trees are allowed to grow beyond a depth of 2 levels.

Here, we can observe that the combination of spline features and a non-linear kernel works quite well and can almost rival the accuracy of the gradient boosting regression trees.

By contrast, the one-hot encoded time features do not perform as well with the low-rank kernel model. In particular, they over-estimate the low-demand hours significantly more than the competing models.

We also observe that none of the models can successfully predict some of the peak rentals at rush hours during the working days. It is possible that access to additional features would be required to further improve the accuracy of the predictions. For instance, it could be useful to have access to the geographical distribution of the fleet at any point in time, or to the fraction of bikes that are immobilized because they need servicing.

Let us finally get a more quantitative look at the prediction errors of those three models using the true vs predicted demand scatter plots:

from sklearn.metrics import PredictionErrorDisplay

fig, axes = plt.subplots(nrows=2, ncols=3, figsize=(13, 7), sharex=True, sharey="row")
fig.suptitle("Non-linear regression models", y=1.0)
predictions = [
    one_hot_poly_predictions,
    cyclic_spline_poly_predictions,
    gbrt_predictions,
]
labels = [
    "One hot +\npolynomial kernel",
    "Splines +\npolynomial kernel",
    "Gradient Boosted\nTrees",
]
plot_kinds = ["actual_vs_predicted", "residual_vs_predicted"]
for axis_idx, kind in enumerate(plot_kinds):
    for ax, pred, label in zip(axes[axis_idx], predictions, labels):
        disp = PredictionErrorDisplay.from_predictions(
            y_true=y.iloc[test_0],
            y_pred=pred,
            kind=kind,
            scatter_kwargs={"alpha": 0.3},
            ax=ax,
        )
        ax.set_xticks(np.linspace(0, 1, num=5))
        if axis_idx == 0:
            ax.set_yticks(np.linspace(0, 1, num=5))
            ax.legend(
                ["Best model", label],
                loc="upper center",
                bbox_to_anchor=(0.5, 1.3),
                ncol=2,
            )
        ax.set_aspect("equal", adjustable="box")
plt.show()
Figure: actual vs. predicted and residual vs. predicted plots for the three non-linear regression models

This visualization confirms the conclusions we drew from the previous plot.

All models under-estimate the high-demand events (working-day rush hours), gradient boosting a bit less so. The low-demand events are well predicted on average by gradient boosting, while the one-hot polynomial regression pipeline seems to systematically over-estimate demand in that regime. Overall, the predictions of the gradient boosted trees are closer to the diagonal than those of the kernel models.

Concluding remarks#

We note that we could have obtained slightly better results for kernel models by using more components (higher rank kernel approximation) at the cost of longer fit and prediction durations. For large values of n_components, the performance of the one-hot encoded features would even match the spline features.

The Nystroem + RidgeCV regressor could also have been replaced by an MLPRegressor with one or two hidden layers, and we would have obtained quite similar results.
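
A hedged sketch of that alternative (the hidden-layer sizes and other settings below are illustrative assumptions, not tuned or evaluated here):

from sklearn.neural_network import MLPRegressor

cyclic_spline_mlp_pipeline = make_pipeline(
    cyclic_spline_transformer,
    MLPRegressor(hidden_layer_sizes=(64, 32), max_iter=500, random_state=0),
)
# evaluate(cyclic_spline_mlp_pipeline, X, y, cv=ts_cv)  # expected to be broadly comparable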

The dataset we used in this case study is sampled on an hourly basis. However, cyclic spline-based features could model time-within-day or time-within-week very efficiently at finer-grained time resolutions (for instance with measurements taken every minute instead of every hour) without introducing more features. One-hot encoded time representations would not offer this flexibility.

Finally, in this notebook we used RidgeCV because it is very efficient from a computational point of view. However, it models the target variable as a Gaussian random variable with constant variance. For positive regression problems, it is likely that using a Poisson or Gamma distribution would make more sense. This could be achieved by using GridSearchCV(TweedieRegressor(power=2), param_grid={"alpha": alphas}) instead of RidgeCV, as sketched below.
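
A hedged sketch of this substitution (not evaluated here; note that a Gamma deviance, power=2, assumes strictly positive targets, so hours with zero rentals would call for a Poisson loss, power=1, or a compound Poisson-Gamma loss, 1 < power < 2, instead):

from sklearn.linear_model import TweedieRegressor
from sklearn.model_selection import GridSearchCV

tweedie_search = GridSearchCV(TweedieRegressor(power=2), param_grid={"alpha": alphas})
cyclic_spline_tweedie_pipeline = make_pipeline(
    cyclic_spline_transformer,
    Nystroem(kernel="poly", degree=2, n_components=300, random_state=0),
    tweedie_search,
)
# evaluate(cyclic_spline_tweedie_pipeline, X, y, cv=ts_cv)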
