{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "34e29748",
   "metadata": {},
   "source": [
    "# Integration of ZarrDataset with PyTorch's DataLoader (Advanced)\n",
    "\n",
    "This notebook chains several `ZarrDataset` objects with PyTorch's `ChainDataset` and draws patches from them with a multi-worker `DataLoader`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a0873b31",
   "metadata": {},
   "outputs": [],
   "source": [
    "# All imports used by the notebook are gathered here so that a fresh\n",
    "# kernel exposes every dependency up front.\n",
    "import random\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import torch\n",
    "import torchvision\n",
    "from torch.utils.data import ChainDataset, DataLoader\n",
    "\n",
    "import zarrdataset as zds"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fa485c1c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# These are images from the Image Data Resource (IDR)\n",
    "# https://idr.openmicroscopy.org/ that are publicly available and were\n",
    "# converted to the OME-NGFF (Zarr) format by the OME group. More examples\n",
    "# can be found at Public OME-Zarr data (Nov. 2020)\n",
    "# https://www.openmicroscopy.org/2020/11/04/zarr-data.html\n",
    "\n",
    "filenames = [\n",
    "    \"https://uk1s3.embassy.ebi.ac.uk/idr/zarr/v0.1/6001240.zarr\",\n",
    "    \"https://uk1s3.embassy.ebi.ac.uk/idr/zarr/v0.1/6001241.zarr\",\n",
    "    \"https://uk1s3.embassy.ebi.ac.uk/idr/zarr/v0.1/6001242.zarr\",\n",
    "    \"https://uk1s3.embassy.ebi.ac.uk/idr/zarr/v0.1/6001243.zarr\",\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fab3089e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# For reproducibility: seed each random number generator used below\n",
    "np.random.seed(478963)\n",
    "torch.manual_seed(478964)\n",
    "random.seed(478965)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4b6aa9b3",
   "metadata": {},
   "source": [
    "### Extracting patches of size 128x128x32 voxels from a three-dimensional image"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c42a65b0",
   "metadata": {},
   "source": [
    "Sample the image randomly"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "365fd8df",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Patch extent along each spatial axis; reused later to pick the Z-slice to display\n",
    "patch_size = dict(Z=32, Y=128, X=128)\n",
    "patch_sampler = zds.PatchSampler(patch_size=patch_size)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cf5c1089",
   "metadata": {},
   "source": [
    "Transform the input data from uint16 to float16 with a torchvision pre-processing pipeline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e06ef22f",
   "metadata": {},
   "outputs": [],
   "source": [
    "img_preprocessing = torchvision.transforms.Compose([\n",
    "    zds.ToDtype(dtype=np.float16)\n",
    "])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "486b006b",
   "metadata": {},
   "source": [
    "Pass the pre-processing function to ZarrDataset to be used when generating the samples.\n",
    "\n",
    "Also, enable returning the position of each patch and the ID of the worker that generated it."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "681ab078",
   "metadata": {},
   "outputs": [],
   "source": [
    "# One ZarrDataset per image file; they are chained together in the next step\n",
    "my_datasets = [\n",
    "    zds.ZarrDataset(\n",
    "        [\n",
    "            zds.ImagesDatasetSpecs(\n",
    "                filenames=fn,\n",
    "                data_group=\"0\",\n",
    "                source_axes=\"TCZYX\",\n",
    "                transform=img_preprocessing,\n",
    "            )\n",
    "        ],\n",
    "        patch_sampler=patch_sampler,\n",
    "        shuffle=True,\n",
    "        return_positions=True,\n",
    "        return_worker_id=True\n",
    "    )\n",
    "    for fn in filenames\n",
    "]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0a98874c",
   "metadata": {},
   "source": [
    "### Create a ChainDataset that puts a set of ZarrDatasets together as a single large dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "164e9457",
   "metadata": {},
   "outputs": [],
   "source": [
    "my_chain_dataset = ChainDataset(my_datasets)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a85d5f69",
   "metadata": {},
   "source": [
    "Make sure the `chained_zarrdataset_worker_init_fn` function is passed to the DataLoader, so the workers can initialize the dataset correctly."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "da6637fc",
   "metadata": {},
   "outputs": [],
   "source": [
    "my_dataloader = DataLoader(\n",
    "    my_chain_dataset,\n",
    "    num_workers=4,\n",
    "    worker_init_fn=zds.chained_zarrdataset_worker_init_fn,\n",
    "    batch_size=2,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "84d085c6",
   "metadata": {},
   "outputs": [],
   "source": [
    "sample_list = []\n",
    "positions = []\n",
    "wids = []\n",
    "for i, (wid, pos, sample) in enumerate(my_dataloader):\n",
    "    # Un-batch every returned element so each list holds one entry per patch\n",
    "    wids.extend(wid)\n",
    "    positions.extend(pos)\n",
    "    sample_list.extend(sample)\n",
    "\n",
    "    print(f\"Batch {i+1} with size {sample.shape} extracted by worker {wid}.\")\n",
    "\n",
    "    if i >= 4:\n",
    "        # Take five batches for illustration purposes\n",
    "        break\n",
    "\n",
    "# NOTE(review): assumes every patch keeps a leading T axis of length 1, so\n",
    "# concatenating along dim 0 yields one entry per patch - confirm\n",
    "samples = torch.cat(sample_list, dim=0)\n",
    "\n",
    "samples.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2f706067",
   "metadata": {},
   "source": [
    "### Generate a grid with the sampled patches using `torchvision` utilities"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bc2f7363",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Display the middle Z-slice of every patch; derived from the patch depth\n",
    "# instead of hard-coding the index (32 // 2 == 16)\n",
    "z_mid = patch_size[\"Z\"] // 2\n",
    "samples_grid = torchvision.utils.make_grid(samples[:, :, z_mid, :, :])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "50d7bc83",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.imshow(samples_grid[0], cmap=\"gray\")\n",
    "plt.title(\"Sampled patches (channel 0)\")\n",
    "plt.axis('off');"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2cfab3b7",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.imshow(samples_grid[1], cmap=\"gray\")\n",
    "plt.title(\"Sampled patches (channel 1)\")\n",
    "plt.axis('off');"
   ]
  }
 ],
 "metadata": {
  "execution": {
   "timeout": 600
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}