{ "cells": [ { "cell_type": "markdown", "id": "13fcd9c9", "metadata": {}, "source": [ "# Integration of ZarrDataset with PyTorch's DataLoader" ] }, { "cell_type": "code", "execution_count": null, "id": "0a27ba15", "metadata": {}, "outputs": [], "source": [ "import zarrdataset as zds\n", "\n", "import torch\n", "from torch.utils.data import DataLoader" ] }, { "cell_type": "code", "execution_count": null, "id": "33b4ef7a", "metadata": {}, "outputs": [], "source": [ "# These are images from the Image Data Resource (IDR) \n", "# https://idr.openmicroscopy.org/ that are publicly available and were \n", "# converted to the OME-NGFF (Zarr) format by the OME group. More examples\n", "# can be found at Public OME-Zarr data (Nov. 2020)\n", "# https://www.openmicroscopy.org/2020/11/04/zarr-data.html\n", "\n", "filenames = [\n", " \"https://uk1s3.embassy.ebi.ac.uk/idr/zarr/v0.4/idr0073A/9798462.zarr\"\n", "]" ] }, { "cell_type": "code", "execution_count": null, "id": "5a2d0543", "metadata": {}, "outputs": [], "source": [ "import random\n", "import numpy as np\n", "\n", "# For reproducibility\n", "np.random.seed(478963)\n", "torch.manual_seed(478964)\n", "random.seed(478965)" ] }, { "cell_type": "markdown", "id": "179c38c5", "metadata": {}, "source": [ "## Extracting patches of size 1024x1024 pixels from a Whole Slide Image (WSI)" ] }, { "cell_type": "markdown", "id": "65ef4459", "metadata": {}, "source": [ "Sample the image randomly" ] }, { "cell_type": "code", "execution_count": null, "id": "9604a906", "metadata": {}, "outputs": [], "source": [ "patch_size = dict(Y=1024, X=1024)\n", "patch_sampler = zds.BlueNoisePatchSampler(patch_size=patch_size)" ] }, { "cell_type": "markdown", "id": "270e5e74", "metadata": {}, "source": [ "Create a dataset from the list of filenames. 
The image data in each file is expected to be stored in group \"0\".\n", "\n", "Also, specify that the axis order of the image is Time-Channel-Depth-Height-Width (TCZYX), so that the data is handled correctly." ] }, { "cell_type": "code", "execution_count": null, "id": "dccf0e2b", "metadata": {}, "outputs": [], "source": [ "image_specs = zds.ImagesDatasetSpecs(\n", "    filenames=filenames,\n", "    data_group=\"0\",\n", "    source_axes=\"TCZYX\",\n", ")\n", "\n", "my_dataset = zds.ZarrDataset(image_specs,\n", "                             patch_sampler=patch_sampler,\n", "                             shuffle=True)" ] }, { "cell_type": "code", "execution_count": null, "id": "af770aaf", "metadata": {}, "outputs": [], "source": [ "my_dataset" ] }, { "cell_type": "markdown", "id": "86cdc3ef", "metadata": {}, "source": [ "Add a pre-processing step that runs before the image batches are created, in which the input arrays are cast from int16 to float32." ] }, { "cell_type": "code", "execution_count": null, "id": "bc222884", "metadata": {}, "outputs": [], "source": [ "import torchvision\n", "\n", "img_preprocessing = torchvision.transforms.Compose([\n", "    zds.ToDtype(dtype=np.float32),\n", "])\n", "\n", "my_dataset.add_transform(\"images\", img_preprocessing)" ] }, { "cell_type": "code", "execution_count": null, "id": "fbba1ca4", "metadata": {}, "outputs": [], "source": [ "my_dataset" ] }, { "cell_type": "markdown", "id": "398848d6", "metadata": {}, "source": [ "## Create a DataLoader from the dataset object" ] }, { "cell_type": "markdown", "id": "1f1379a4", "metadata": {}, "source": [ "ZarrDataset is compatible with PyTorch's DataLoader because it inherits from the IterableDataset class of the torch.utils.data module."
] }, { "cell_type": "code", "execution_count": null, "id": "1322f0d6", "metadata": {}, "outputs": [], "source": [ "my_dataloader = DataLoader(my_dataset, num_workers=0)" ] }, { "cell_type": "code", "execution_count": null, "id": "98187912", "metadata": {}, "outputs": [], "source": [ "samples = []\n", "for i, sample in enumerate(my_dataloader):\n", "    # Samples generated by a DataLoader have Batch (B) as the first axis.\n", "    # Take the first batch item, time point, and Z plane, then move the\n", "    # channel axis last so the patch can be displayed as an image.\n", "    samples.append(np.moveaxis(sample[0, 0, :, 0].numpy(), 0, -1))\n", "\n", "    print(f\"Sample {i+1} with size {sample.shape}\")\n", "\n", "    if i >= 4:\n", "        # Take only five samples for illustration purposes\n", "        break\n", "\n", "samples = np.hstack(samples)" ] }, { "cell_type": "code", "execution_count": null, "id": "36650cd2", "metadata": {}, "outputs": [], "source": [ "import matplotlib.pyplot as plt\n", "\n", "plt.imshow(samples / 255.0)\n", "plt.show()" ] }, { "cell_type": "markdown", "id": "c6e0bcc9", "metadata": {}, "source": [ "## Multi-process data loading with Torch's DataLoader" ] }, { "cell_type": "markdown", "id": "e3e53816", "metadata": {}, "source": [ "This example uses multiple workers to load patches of size 256x256 from the same image." ] }, { "cell_type": "code", "execution_count": null, "id": "01a7ec2b", "metadata": {}, "outputs": [], "source": [ "patch_size = dict(Y=256, X=256)\n", "patch_sampler = zds.BlueNoisePatchSampler(patch_size=patch_size)" ] }, { "cell_type": "markdown", "id": "1f2b4694", "metadata": {}, "source": [ "Create a dataset from the list of filenames. 
The image data in each file is expected to be stored in group \"3\".\n", "\n", "Also, specify that the axis order of the image is Time-Channel-Depth-Height-Width (TCZYX), so that the data is handled correctly." ] }, { "cell_type": "code", "execution_count": null, "id": "4d59b9ae", "metadata": {}, "outputs": [], "source": [ "image_specs = zds.ImagesDatasetSpecs(\n", "    filenames=filenames,\n", "    data_group=\"3\",\n", "    source_axes=\"TCZYX\",\n", ")\n", "\n", "my_dataset = zds.ZarrDataset(image_specs,\n", "                             patch_sampler=patch_sampler,\n", "                             shuffle=True)" ] }, { "cell_type": "markdown", "id": "81cf1be3", "metadata": {}, "source": [ "ZarrDataset performs some special operations to enable data loading with multiple worker processes without replicating the full dataset on each worker.\n", "\n", "For this reason, ZarrDataset requires its own `worker_init_fn` initialization function: `zarrdataset_worker_init_fn`." ] }, { "cell_type": "code", "execution_count": null, "id": "0a371fdf", "metadata": {}, "outputs": [], "source": [ "my_dataloader = DataLoader(my_dataset, num_workers=4,\n", "                           worker_init_fn=zds.zarrdataset_worker_init_fn)" ] }, { "cell_type": "markdown", "id": "5c62ab89", "metadata": {}, "source": [ "Now the data can be safely loaded using multiple workers."
] }, { "cell_type": "code", "execution_count": null, "id": "07954b39", "metadata": {}, "outputs": [], "source": [ "samples = []\n", "for i, sample in enumerate(my_dataloader):\n", "    # Samples generated by a DataLoader have Batch (B) as the first axis.\n", "    # Take the first batch item, time point, and Z plane, then move the\n", "    # channel axis last so the patch can be displayed as an image.\n", "    samples.append(np.moveaxis(sample[0, 0, :, 0].numpy(), 0, -1))\n", "\n", "    print(f\"Sample {i+1} with size {sample.shape}\")\n", "\n", "    if i >= 4:\n", "        # Take only five samples for illustration purposes\n", "        break\n", "\n", "samples = np.hstack(samples)" ] }, { "cell_type": "code", "execution_count": null, "id": "a3a14273", "metadata": {}, "outputs": [], "source": [ "plt.imshow(samples)\n", "plt.show()" ] } ], "metadata": { "execution": { "timeout": 600 }, "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" } }, "nbformat": 4, "nbformat_minor": 5 }