{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "6XvCUmCEd4Dm"
},
"source": [
"# TensorFlow Datasets\n",
"\n",
"TFDS provides a collection of ready-to-use datasets for use with TensorFlow, JAX, and other machine learning frameworks.\n",
"\n",
"It handles downloading and preparing the data deterministically and constructing a `tf.data.Dataset` (or `np.array`).\n",
"\n",
"Note: Do not confuse [TFDS](https://www.tensorflow.org/datasets) (this library) with `tf.data` (the TensorFlow API for building efficient data pipelines). TFDS is a high-level wrapper around `tf.data`. If you're not familiar with this API, we encourage you to read [the official tf.data guide](https://www.tensorflow.org/guide/data) first.\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "J8y9ZkLXmAZc"
},
"source": [
"Copyright 2018 The TensorFlow Datasets Authors, Licensed under the Apache License, Version 2.0"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "OGw9EgE0tC0C"
},
"source": [
],
"text/plain": [
"BenchmarkResult(stats= duration num_examples avg\n",
"first+lasts 0.300411 60032 199832.925534\n",
"first 0.013519 32 2367.040208\n",
"lasts 0.286892 60000 209137.961140, raw_stats= duration\n",
"start_time 2465.173582\n",
"first_batch_time 2465.187101\n",
"end_time 2465.473993\n",
"num_iter 1875.000000)"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"ds = tfds.load('mnist', split='train')\n",
"ds = ds.batch(32).prefetch(1)\n",
"\n",
"tfds.benchmark(ds, batch_size=32)\n",
"tfds.benchmark(ds, batch_size=32) # Second epoch much faster due to auto-caching"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "MT0yEX_4kYnV"
},
"source": [
"* Do not forget to normalize the results per batch size with the `batch_size=` kwarg.\n",
"* In the summary, the first warmup batch is separated from the other ones to capture the extra setup time of `tf.data.Dataset` (e.g. buffer initialization).\n",
"* Notice how the second iteration is much faster due to [TFDS auto-caching](https://www.tensorflow.org/datasets/performances#auto-caching).\n",
"* `tfds.benchmark` returns a `tfds.core.BenchmarkResult` which can be inspected for further analysis."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "o-cuwvVbeb43"
},
"source": [
"### Build end-to-end pipeline\n",
"\n",
"To go further, see:\n",
"\n",
"* Our [end-to-end Keras example](https://www.tensorflow.org/datasets/keras_example) for a full training pipeline (with batching, shuffling, etc.).\n",
"* Our [performance guide](https://www.tensorflow.org/datasets/performances) to improve the speed of your pipelines (tip: use `tfds.benchmark(ds)` to benchmark your datasets).\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "gTRTEQqscxAE"
},
"source": [
"## Visualization\n",
"\n",
"### tfds.as_dataframe\n",
"\n",
"`tf.data.Dataset` objects can be converted to [`pandas.DataFrame`](https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.DataFrame.html) with `tfds.as_dataframe` to be visualized on [Colab](https://colab.research.google.com).\n",
"\n",
"* Pass the `tfds.core.DatasetInfo` as the second argument of `tfds.as_dataframe` to visualize images, audio, text, videos, etc.\n",
"* Use `ds.take(x)` to display only the first `x` examples. `pandas.DataFrame` loads the full dataset in memory, which can be very expensive to display."
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"execution": {
"iopub.execute_input": "2024-12-14T12:42:04.926139Z",
"iopub.status.busy": "2024-12-14T12:42:04.925607Z",
"iopub.status.idle": "2024-12-14T12:42:06.157855Z",
"shell.execute_reply": "2024-12-14T12:42:06.157137Z"
},
"id": "FKouwN_yVSGQ"
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"2024-12-14 12:42:05.775782: W tensorflow/core/kernels/data/cache_dataset_ops.cc:914] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
"  <thead>\n",
"    <tr><th></th><th>image</th><th>label</th></tr>\n",
"  </thead>\n",
"  <tbody>\n",
"    <tr><th>0</th><td>[image]</td><td>4</td></tr>\n",
"    <tr><th>1</th><td>[image]</td><td>1</td></tr>\n",
"    <tr><th>2</th><td>[image]</td><td>0</td></tr>\n",
"    <tr><th>3</th><td>[image]</td><td>7</td></tr>\n",
"  </tbody>\n",
"</table>\n"
],
"text/plain": [
" image label\n",
"0 [[[0], [0], [0], [0], [0], [0], [0], [0], [0],... 4\n",
"1 [[[0], [0], [0], [0], [0], [0], [0], [0], [0],... 1\n",
"2 [[[0], [0], [0], [0], [0], [0], [0], [0], [0],... 0\n",
"3 [[[0], [0], [0], [0], [0], [0], [0], [0], [0],... 7"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"ds, info = tfds.load('mnist', split='train', with_info=True)\n",
"\n",
"tfds.as_dataframe(ds.take(4), info)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "b-eDO_EXVGWC"
},
"source": [
"### tfds.show_examples\n",
"\n",
"`tfds.show_examples` returns a `matplotlib.figure.Figure` (only image datasets are currently supported):"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {
"execution": {
"iopub.execute_input": "2024-12-14T12:42:06.161149Z",
"iopub.status.busy": "2024-12-14T12:42:06.160454Z",
"iopub.status.idle": "2024-12-14T12:42:07.144860Z",
"shell.execute_reply": "2024-12-14T12:42:07.144136Z"
},
"id": "DpE2FD56cSQR"
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"2024-12-14 12:42:06.859043: W tensorflow/core/kernels/data/cache_dataset_ops.cc:914] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.\n"
]
},
{
"data": {
"image/png": "",
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"ds, info = tfds.load('mnist', split='train', with_info=True)\n",
"\n",
"fig = tfds.show_examples(ds, info)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Y0iVVStvk0oI"
},
"source": [
"## Access the dataset metadata\n",
"\n",
"All builders include a `tfds.core.DatasetInfo` object containing the dataset metadata.\n",
"\n",
"It can be accessed through:\n",
"\n",
"* The `tfds.load` API:\n"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {
"execution": {
"iopub.execute_input": "2024-12-14T12:42:07.147922Z",
"iopub.status.busy": "2024-12-14T12:42:07.147344Z",
"iopub.status.idle": "2024-12-14T12:42:07.609154Z",
"shell.execute_reply": "2024-12-14T12:42:07.608367Z"
},
"id": "UgLgtcd1ljzt"
},
"outputs": [],
"source": [
"ds, info = tfds.load('mnist', with_info=True)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "XodyqNXrlxTM"
},
"source": [
"* The `tfds.core.DatasetBuilder` API:"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {
"execution": {
"iopub.execute_input": "2024-12-14T12:42:07.612570Z",
"iopub.status.busy": "2024-12-14T12:42:07.612038Z",
"iopub.status.idle": "2024-12-14T12:42:08.008414Z",
"shell.execute_reply": "2024-12-14T12:42:08.007595Z"
},
"id": "nmq97QkilxeL"
},
"outputs": [],
"source": [
"builder = tfds.builder('mnist')\n",
"info = builder.info"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "zMGOk_ZsmPeu"
},
"source": [
"The dataset info contains additional information about the dataset (version, citation, homepage, description, etc.)."
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {
"execution": {
"iopub.execute_input": "2024-12-14T12:42:08.011773Z",
"iopub.status.busy": "2024-12-14T12:42:08.011185Z",
"iopub.status.idle": "2024-12-14T12:42:08.015475Z",
"shell.execute_reply": "2024-12-14T12:42:08.014814Z"
},
"id": "O-wLIKD-mZQT"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tfds.core.DatasetInfo(\n",
" name='mnist',\n",
" full_name='mnist/3.0.1',\n",
" description=\"\"\"\n",
" The MNIST database of handwritten digits.\n",
" \"\"\",\n",
" homepage='http://yann.lecun.com/exdb/mnist/',\n",
" data_dir='gs://tensorflow-datasets/datasets/mnist/3.0.1',\n",
" file_format=tfrecord,\n",
" download_size=11.06 MiB,\n",
" dataset_size=21.00 MiB,\n",
" features=FeaturesDict({\n",
" 'image': Image(shape=(28, 28, 1), dtype=uint8),\n",
" 'label': ClassLabel(shape=(), dtype=int64, num_classes=10),\n",
" }),\n",
" supervised_keys=('image', 'label'),\n",
" disable_shuffling=False,\n",
" splits={\n",
" 'test': <SplitInfo num_examples=10000, num_shards=1>,\n",
" 'train': <SplitInfo num_examples=60000, num_shards=1>,\n",
" },\n",
" citation=\"\"\"@article{lecun2010mnist,\n",
" title={MNIST handwritten digit database},\n",
" author={LeCun, Yann and Cortes, Corinna and Burges, CJ},\n",
" journal={ATT Labs [Online]. Available: http://yann.lecun.com/exdb/mnist},\n",
" volume={2},\n",
" year={2010}\n",
" }\"\"\",\n",
")\n"
]
}
],
"source": [
"print(info)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "1zvAfRtwnAFk"
},
"source": [
"### Features metadata (label names, image shape,...)\n",
"\n",
"Access the `tfds.features.FeaturesDict`:"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {
"execution": {
"iopub.execute_input": "2024-12-14T12:42:08.018310Z",
"iopub.status.busy": "2024-12-14T12:42:08.017780Z",
"iopub.status.idle": "2024-12-14T12:42:08.022366Z",
"shell.execute_reply": "2024-12-14T12:42:08.021732Z"
},
"id": "RcyZXncqoFab"
},
"outputs": [
{
"data": {
"text/plain": [
"FeaturesDict({\n",
" 'image': Image(shape=(28, 28, 1), dtype=uint8),\n",
" 'label': ClassLabel(shape=(), dtype=int64, num_classes=10),\n",
"})"
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"info.features"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "KAm9AV7loyw5"
},
"source": [
"Number of classes, label names:"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {
"execution": {
"iopub.execute_input": "2024-12-14T12:42:08.025203Z",
"iopub.status.busy": "2024-12-14T12:42:08.024690Z",
"iopub.status.idle": "2024-12-14T12:42:08.029083Z",
"shell.execute_reply": "2024-12-14T12:42:08.028430Z"
},
"id": "HhfzBH6qowpz"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"10\n",
"['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']\n",
"7\n",
"7\n"
]
}
],
"source": [
"print(info.features[\"label\"].num_classes)\n",
"print(info.features[\"label\"].names)\n",
"print(info.features[\"label\"].int2str(7)) # Human readable version (8 -> 'cat')\n",
"print(info.features[\"label\"].str2int('7'))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "g5eWtk9ro_AK"
},
"source": [
"Shapes, dtypes:"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {
"execution": {
"iopub.execute_input": "2024-12-14T12:42:08.031764Z",
"iopub.status.busy": "2024-12-14T12:42:08.031338Z",
"iopub.status.idle": "2024-12-14T12:42:08.039042Z",
"shell.execute_reply": "2024-12-14T12:42:08.038359Z"
},
"id": "SergV_wQowLY"
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"WARNING:absl:`FeatureConnector.dtype` is deprecated. Please change your code to use NumPy with the field `FeatureConnector.np_dtype` or use TensorFlow with the field `FeatureConnector.tf_dtype`.\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"WARNING:absl:`FeatureConnector.dtype` is deprecated. Please change your code to use NumPy with the field `FeatureConnector.np_dtype` or use TensorFlow with the field `FeatureConnector.tf_dtype`.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'image': (28, 28, 1), 'label': ()}\n",
"{'image': tf.uint8, 'label': tf.int64}\n",
"(28, 28, 1)\n",
"<dtype: 'uint8'>\n"
]
}
],
"source": [
"print(info.features.shape)\n",
"print(info.features.dtype)\n",
"print(info.features['image'].shape)\n",
"print(info.features['image'].dtype)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "thMOZ4IKm55N"
},
"source": [
"### Split metadata (e.g. split names, number of examples,...)\n",
"\n",
"Access the `tfds.core.SplitDict`:"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {
"execution": {
"iopub.execute_input": "2024-12-14T12:42:08.041835Z",
"iopub.status.busy": "2024-12-14T12:42:08.041327Z",
"iopub.status.idle": "2024-12-14T12:42:08.044950Z",
"shell.execute_reply": "2024-12-14T12:42:08.044294Z"
},
"id": "FBbfwA8Sp4ax"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'test': <SplitInfo num_examples=10000, num_shards=1>, 'train': <SplitInfo num_examples=60000, num_shards=1>}\n"
]
}
],
"source": [
"print(info.splits)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "EVw1UVYa2HgN"
},
"source": [
"Available splits:"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {
"execution": {
"iopub.execute_input": "2024-12-14T12:42:08.047658Z",
"iopub.status.busy": "2024-12-14T12:42:08.047126Z",
"iopub.status.idle": "2024-12-14T12:42:08.050901Z",
"shell.execute_reply": "2024-12-14T12:42:08.050227Z"
},
"id": "fRBieOOquDzX"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"['test', 'train']\n"
]
}
],
"source": [
"print(list(info.splits.keys()))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "iHW0VfA0t3dO"
},
"source": [
"Get info on an individual split:"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {
"execution": {
"iopub.execute_input": "2024-12-14T12:42:08.053692Z",
"iopub.status.busy": "2024-12-14T12:42:08.053151Z",
"iopub.status.idle": "2024-12-14T12:42:08.057249Z",
"shell.execute_reply": "2024-12-14T12:42:08.056605Z"
},
"id": "-h_OSpRsqKpP"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"60000\n",
"['mnist-train.tfrecord-00000-of-00001']\n",
"1\n"
]
}
],
"source": [
"print(info.splits['train'].num_examples)\n",
"print(info.splits['train'].filenames)\n",
"print(info.splits['train'].num_shards)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "fWhSkHFNuLwW"
},
"source": [
"It also works with the subsplit API:"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {
"execution": {
"iopub.execute_input": "2024-12-14T12:42:08.060009Z",
"iopub.status.busy": "2024-12-14T12:42:08.059502Z",
"iopub.status.idle": "2024-12-14T12:42:08.063597Z",
"shell.execute_reply": "2024-12-14T12:42:08.062947Z"
},
"id": "HO5irBZ3uIzQ"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"36000\n",
"[FileInstruction(filename='gs://tensorflow-datasets/datasets/mnist/3.0.1/mnist-train.tfrecord-00000-of-00001', skip=9000, take=36000, examples_in_shard=60000)]\n"
]
}
],
"source": [
"print(info.splits['train[15%:75%]'].num_examples)\n",
"print(info.splits['train[15%:75%]'].file_instructions)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "YZp2XJwQQrI0"
},
"source": [
"## Troubleshooting\n",
"\n",
"### Manual download (if download fails)\n",
"\n",
"If the download fails for some reason (e.g. you are offline), you can always download the data manually and place it in the `manual_dir` (defaults to `~/tensorflow_datasets/downloads/manual/`).\n",
"\n",
"To find out which URLs to download, look into:\n",
"\n",
" * For new datasets (implemented as a folder): [`tensorflow_datasets/`](https://github.com/tensorflow/datasets/tree/master/tensorflow_datasets/)`<type>/<dataset_name>/checksums.tsv`. For example: [`tensorflow_datasets/datasets/bool_q/checksums.tsv`](https://github.com/tensorflow/datasets/blob/master/tensorflow_datasets/datasets/bool_q/checksums.tsv).\n",
"\n",
" You can find the dataset source location in [our catalog](https://www.tensorflow.org/datasets/catalog/overview).\n",
" * For old datasets: [`tensorflow_datasets/url_checksums/<dataset_name>.txt`](https://github.com/tensorflow/datasets/tree/master/tensorflow_datasets/url_checksums)\n",
"\n",
"### Fixing `NonMatchingChecksumError`\n",
"\n",
"TFDS ensures determinism by validating the checksums of downloaded URLs.\n",
"If `NonMatchingChecksumError` is raised, it might indicate one of the following:\n",
"\n",
" * The website may be down (e.g. `503 status code`). Please check the URL.\n",
" * For Google Drive URLs, try again later, as Drive sometimes rejects downloads when too many people access the same URL. See this [bug](https://github.com/tensorflow/datasets/issues/1482).\n",
" * The original dataset files may have been updated. In this case, the TFDS dataset builder should be updated. Please open a new GitHub issue or PR:\n",
"     * Register the new checksums with `tfds build --register_checksums`.\n",
"     * If needed, update the dataset generation code.\n",
"     * Update the dataset `VERSION`.\n",
"     * Update the dataset `RELEASE_NOTES`: What caused the checksums to change? Did some examples change?\n",
"     * Make sure the dataset can still be built.\n",
"     * Send us a PR.\n",
"\n",
"Note: You can also inspect the downloaded file in `~/tensorflow_datasets/downloads/`.\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "GmeeOokMODg2"
},
"source": [
"## Citation\n",
"\n",
"If you're using `tensorflow-datasets` for a paper, please include the following citation, in addition to any citation specific to the datasets used (which can be found in the [dataset catalog](https://www.tensorflow.org/datasets/catalog/overview)).\n",
"\n",
"```\n",
"@misc{TFDS,\n",
" title = { {TensorFlow Datasets}, A collection of ready-to-use datasets},\n",
" howpublished = {\\url{https://www.tensorflow.org/datasets}},\n",
"}\n",
"```"
]
}
],
"metadata": {
"colab": {
"collapsed_sections": [],
"name": "tensorflow/datasets",
"private_outputs": true,
"provenance": [],
"toc_visible": true
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.21"
}
},
"nbformat": 4,
"nbformat_minor": 0
}