{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "e996e157",
   "metadata": {},
   "source": [
    "# Integration of ZarrDataset with PyTorch's DataLoader for inference (Advanced)\n",
    "\n",
    "```python\n",
    "import zarrdataset as zds\n",
    "\n",
    "import torch\n",
    "from torch.utils.data import DataLoader\n",
    "```\n",
    "\n",
    "\n",
    "```python\n",
    "# These are images from the Image Data Resource (IDR)\n",
    "# https://idr.openmicroscopy.org/ that are publicly available and were\n",
    "# converted to the OME-NGFF (Zarr) format by the OME group. More examples\n",
    "# can be found at Public OME-Zarr data (Nov. 2020)\n",
    "# https://www.openmicroscopy.org/2020/11/04/zarr-data.html\n",
    "\n",
    "filenames = [\n",
    "    \"https://uk1s3.embassy.ebi.ac.uk/idr/zarr/v0.4/idr0073A/9798462.zarr\"\n",
    "]\n",
    "```\n",
    "\n",
    "\n",
    "```python\n",
    "import random\n",
    "import numpy as np\n",
    "\n",
    "# For reproducibility\n",
    "np.random.seed(478963)\n",
    "torch.manual_seed(478964)\n",
    "random.seed(478965)\n",
    "```\n",
    "\n",
    "## Extracting patches of size 128x128 pixels from a Whole Slide Image (WSI)\n",
    "\n",
    "Retrieve samples for inference. Add a padding of 16 pixels to each side of every patch to avoid edge artifacts when stitching the inference result back together.\n",
    "Finally, setting `allow_incomplete_patches=True` lets the PatchSampler also retrieve patches from the edges of the image that would otherwise be smaller than the patch size.\n",
    "\n",
    "\n",
    "```python\n",
    "patch_size = dict(Y=128, X=128)\n",
    "pad = dict(Y=16, X=16)\n",
    "patch_sampler = zds.PatchSampler(patch_size=patch_size, pad=pad, allow_incomplete_patches=True)\n",
    "```\n",
    "\n",
    "Create a dataset from the list of filenames. The image array is read from group \"4\" of each file, which in these OME-NGFF files holds one of the downsampled levels of the multiscale pyramid (group \"0\" is the full resolution).\n",
    "\n",
    "Also, specify that the axes order in the image is Time-Channel-Depth-Height-Width (TCZYX), so the data can be handled correctly. Setting `axes=\"YXC\"` returns each patch in Height-Width-Channel order, and `roi` restricts sampling to the first time point and the first z-slice while keeping all channels and the whole image plane.\n",
    "\n",
    "\n",
    "```python\n",
    "image_specs = zds.ImagesDatasetSpecs(\n",
    "    filenames=filenames,\n",
    "    data_group=\"4\",\n",
    "    source_axes=\"TCZYX\",\n",
    "    axes=\"YXC\",\n",
    "    roi=\"0,0,0,0,0:1,-1,1,-1,-1\"\n",
    ")\n",
    "\n",
    "my_dataset = zds.ZarrDataset(image_specs,\n",
    "                             patch_sampler=patch_sampler,\n",
    "                             return_positions=True)\n",
    "```\n",
    "\n",
    "\n",
    "```python\n",
    "my_dataset\n",
    "```\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "    ZarrDataset (PyTorch support:True, tqdm support :True)\n",
    "    Modalities: images\n",
    "    Transforms order: []\n",
    "    Using images modality as reference.\n",
    "    Using for sampling patches of size {'Z': 1, 'Y': 128, 'X': 128}.\n",
    "\n",
    "\n",
    "\n",
    "Add a pre-processing step before creating the image batches, where the input arrays are cast to float32, converted into PyTorch tensors, and normalized. Because `zds.ToDtype` has already produced float32 arrays, `ToTensor` only moves the channel axis first (YXC to CYX) without rescaling the values, and `Normalize(127, 255)` then subtracts 127 and divides by 255 on every channel.\n",
    "\n",
    "\n",
    "```python\n",
    "import torchvision\n",
    "\n",
    "img_preprocessing = torchvision.transforms.Compose([\n",
    "    zds.ToDtype(dtype=np.float32),\n",
    "    torchvision.transforms.ToTensor(),\n",
    "    torchvision.transforms.Normalize(127, 255)\n",
    "])\n",
    "\n",
    "my_dataset.add_transform(\"images\", img_preprocessing)\n",
    "```\n",
    "\n",
    "\n",
    "```python\n",
    "my_dataset\n",
    "```\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "    ZarrDataset (PyTorch support:True, tqdm support :True)\n",
    "    Modalities: images\n",
    "    Transforms order: [('images',)]\n",
    "    Using images modality as reference.\n",
    "    Using for sampling patches of size {'Z': 1, 'Y': 128, 'X': 128}.\n",
    "\n",
    "\n",
    "\n",
    "## Create a DataLoader from the dataset object\n",
    "\n",
    "ZarrDataset is compatible with PyTorch's DataLoader because it inherits from the `IterableDataset` class of the `torch.utils.data` module.\n",
    "\n",
    "\n",
    "```python\n",
    "# The default batch_size=1 is assumed by the inference loop below\n",
    "my_dataloader = DataLoader(my_dataset, num_workers=0)\n",
    "```\n",
    "\n",
    "Open the same resolution level (group \"4\") directly with zarr to get the image size, and allocate an output array whose height and width are rounded up to multiples of the patch size, so that the incomplete patches at the image edges fit.\n",
    "\n",
    "\n",
    "```python\n",
    "import zarr\n",
    "\n",
    "z_arr = zarr.open(\"https://uk1s3.embassy.ebi.ac.uk/idr/zarr/v0.4/idr0073A/9798462.zarr/4\", mode=\"r\")\n",
    "\n",
    "H = z_arr.shape[-2]\n",
    "W = z_arr.shape[-1]\n",
    "\n",
    "# Amount needed to round H and W up to multiples of the patch size\n",
    "pad_H = (patch_size[\"Y\"] - H) % patch_size[\"Y\"]\n",
    "pad_W = (patch_size[\"X\"] - W) % patch_size[\"X\"]\n",
    "z_prediction = zarr.zeros((H + pad_H, W + pad_W), dtype=np.float32,\n",
    "                          chunks=(patch_size[\"Y\"], patch_size[\"X\"]))\n",
    "z_prediction\n",
    "```\n",
    "\n",
    "Set up a simple model for illustration purposes: a 1x1 convolution that maps the three input channels to a single output channel, followed by a ReLU.\n",
    "\n",
    "\n",
    "```python\n",
    "model = torch.nn.Sequential(\n",
    "    torch.nn.Conv2d(in_channels=3, out_channels=1, kernel_size=1),\n",
    "    torch.nn.ReLU()\n",
    ")\n",
    "```\n",
    "\n",
    "Run the model on every patch and store only the central region of each prediction. The patch positions include the padding, so the padding is removed from both the positions and the predicted patch before writing into `z_prediction`.\n",
    "\n",
    "\n",
    "```python\n",
    "model.eval()\n",
    "\n",
    "pad_Y, pad_X = pad[\"Y\"], pad[\"X\"]\n",
    "\n",
    "with torch.no_grad():\n",
    "    for pos, sample in my_dataloader:\n",
    "        # Remove the padding from the patch coordinates (Y, then X)\n",
    "        pred_pos = (\n",
    "            slice(pos[0, 0, 0].item() + pad_Y,\n",
    "                  pos[0, 0, 1].item() - pad_Y),\n",
    "            slice(pos[0, 1, 0].item() + pad_X,\n",
    "                  pos[0, 1, 1].item() - pad_X)\n",
    "        )\n",
    "        pred = model(sample)\n",
    "        # Keep the central region of the single sample in the batch\n",
    "        z_prediction[pred_pos] = pred.cpu().numpy()[0, 0, pad_Y:-pad_Y, pad_X:-pad_X]\n",
    "```\n",
    "\n",
    "## Visualize the result\n",
    "\n",
    "\n",
    "```python\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "plt.subplot(2, 1, 1)\n",
    "plt.title(\"Input image\")\n",
    "# Take the first time point and z-slice, and move channels last for imshow\n",
    "plt.imshow(np.moveaxis(z_arr[0, :, 0, ...], 0, -1))\n",
    "plt.subplot(2, 1, 2)\n",
    "plt.title(\"Prediction\")\n",
    "# Crop away the extra rows and columns added to fit whole patches\n",
    "plt.imshow(z_prediction[:H, :W])\n",
    "plt.show()\n",
    "```"
   ]
  }
 ],
 "metadata": {
  "execution": {
   "timeout": 600
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}