Integration of ZarrDataset with PyTorch’s DataLoader

import zarrdataset as zds

import torch
from torch.utils.data import DataLoader
# Publicly available images from the Image Data Resource (IDR,
# https://idr.openmicroscopy.org/), converted to the OME-NGFF (Zarr) format
# by the OME group. More examples are listed in "Public OME-Zarr data
# (Nov. 2020)": https://www.openmicroscopy.org/2020/11/04/zarr-data.html
filenames = [
    "https://uk1s3.embassy.ebi.ac.uk/idr/zarr/v0.4/idr0073A/9798462.zarr",
]
import random
import numpy as np

# Seed PyTorch's default generator, NumPy's global state and Python's
# `random` module so the sampled patches are reproducible between runs.
torch.manual_seed(478964)
np.random.seed(478963)
random.seed(478965)

Extracting patches of size 1024x1024 pixels from a Whole Slide Image (WSI)

Sample patch locations randomly across the image.

# Random patch sampler: every patch spans 1024 x 1024 pixels along the
# Y (height) and X (width) axes.
patch_size = {"Y": 1024, "X": 1024}
patch_sampler = zds.BlueNoisePatchSampler(patch_size=patch_size)

Create a dataset from the list of filenames. All those files should be stored within their respective group “0”.

Also, specify that the axes order in the image is Time-Channel-Depth-Height-Width (TCZYX), so the data can be handled correctly.

# Describe the "images" modality: read the arrays from group "0" of each file
# and interpret their axes in TCZYX order.
image_specs = zds.ImagesDatasetSpecs(
    filenames=filenames,
    data_group="0",
    source_axes="TCZYX",
)

# Build the dataset; patches are drawn with `patch_sampler` and shuffled.
my_dataset = zds.ZarrDataset(
    image_specs,
    patch_sampler=patch_sampler,
    shuffle=True,
)
my_dataset
ZarrDataset (PyTorch support:True, tqdm support :False)
Modalities: images
Transforms order: []
Using images modality as reference.
Using <class 'zarrdataset._samplers.BlueNoisePatchSampler'> for sampling patches of size {'Z': 1, 'Y': 1024, 'X': 1024}.

Add a pre-processing step, applied before the image batches are created, that casts the input arrays from int16 to float32.

import torchvision

# Cast each "images" patch to float32. Compose applies the listed transforms
# in order; here the list holds a single dtype conversion.
img_preprocessing = torchvision.transforms.Compose(
    [zds.ToDtype(dtype=np.float32)]
)

my_dataset.add_transform("images", img_preprocessing)
my_dataset
ZarrDataset (PyTorch support:True, tqdm support :False)
Modalities: images
Transforms order: [('images',)]
Using images modality as reference.
Using <class 'zarrdataset._samplers.BlueNoisePatchSampler'> for sampling patches of size {'Z': 1, 'Y': 1024, 'X': 1024}.

Create a DataLoader from the dataset object.

ZarrDataset is compatible with PyTorch’s DataLoader because it inherits from the IterableDataset class of the torch.utils.data module.

# num_workers=0: the DataLoader reads every patch in the main process.
my_dataloader = DataLoader(my_dataset, num_workers=0)

samples = []
for i, sample in enumerate(my_dataloader):
    # The DataLoader prepends a batch axis, so `sample` is (B, T, C, Z, Y, X).
    # Keep the first batch/time/depth entry and put channels last -> (Y, X, C).
    samples.append(sample[0, 0, :, 0].numpy().transpose(1, 2, 0))

    print(f"Sample {i+1} with size {sample.shape}")

    # Five patches are enough for the illustration.
    if i == 4:
        break

# Lay the collected patches side by side along the width (X) axis.
samples = np.hstack(samples)
Sample 1 with size torch.Size([1, 1, 3, 1, 1024, 1024])
Sample 2 with size torch.Size([1, 1, 3, 1, 1024, 1024])
Sample 3 with size torch.Size([1, 1, 3, 1, 1024, 1024])
Sample 4 with size torch.Size([1, 1, 3, 1, 1024, 1024])
import matplotlib.pyplot as plt

# The patches were cast to float32 and matplotlib clips float RGB data to
# [0, 1], so rescale first (assumes the source values span 0-255 — confirm).
plt.imshow(np.divide(samples, 255.0))
plt.show()
../_images/5e5194f3aacd660f1d602d50e46aff1b64a2b0f19f6be64ea92299b4973e7d97.png

Multi-worker (multiprocess) data loading with PyTorch’s DataLoader

This example uses multiple worker processes to load patches of size 256x256 from the same image.

# Smaller patches for the multi-worker example: 256 x 256 pixels in Y and X.
patch_size = {"Y": 256, "X": 256}
patch_sampler = zds.BlueNoisePatchSampler(patch_size=patch_size)

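Create a dataset from the list of filenames. This time the image data is read from group “3” of each file instead of group “0” (in OME-NGFF, the numbered groups hold the levels of the image’s multiscale pyramid).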

Also, specify that the axes order in the image is Time-Channel-Depth-Height-Width (TCZYX), so the data can be handled correctly.

# Same file as before, but the arrays are read from group "3" instead of "0";
# the axes are still interpreted in TCZYX order.
image_specs = zds.ImagesDatasetSpecs(
    filenames=filenames,
    data_group="3",
    source_axes="TCZYX",
)

my_dataset = zds.ZarrDataset(
    image_specs,
    patch_sampler=patch_sampler,
    shuffle=True,
)

ZarrDataset performs some special operations to enable multi-worker data loading without replicating the full dataset on each worker.

For this reason, ZarrDataset requires its own worker_init_fn initialization function: zarrdataset_worker_init_fn.

# Four worker processes; ZarrDataset's own worker_init_fn configures each
# worker so the full dataset is not replicated on every one of them.
# NOTE(review): the recorded run of the loop that iterates this loader was
# interrupted by hand while the main process waited on the workers' data
# queue. DataLoader's `timeout=` argument would turn such a stall into an
# error instead of an indefinite wait — confirm the root cause first.
my_dataloader = DataLoader(
    my_dataset,
    num_workers=4,
    worker_init_fn=zds.zarrdataset_worker_init_fn,
)

Now the data can be safely loaded using multiple workers.

samples = []
for i, sample in enumerate(my_dataloader):
    # The DataLoader prepends a batch axis, so `sample` is (B, T, C, Z, Y, X).
    # Keep the first batch/time/depth entry and put channels last -> (Y, X, C).
    samples.append(sample[0, 0, :, 0].numpy().transpose(1, 2, 0))

    print(f"Sample {i+1} with size {sample.shape}")

    # Five patches are enough for the illustration.
    if i == 4:
        break

# Lay the collected patches side by side along the width (X) axis.
samples = np.hstack(samples)
---------------------------------------------------------------------------
KeyboardInterrupt                         Traceback (most recent call last)
Cell In[15], line 2
      1 samples = []
----> 2 for i, sample in enumerate(my_dataloader):
      3     # Samples generated by DataLoaders have Batch (B) as first axes
      4     samples.append(np.moveaxis(sample[0, 0, :, 0].numpy(), 0, -1))
      6     print(f"Sample {i+1} with size {sample.shape}")

File ~/miniforge-pypy3/envs/zarrdataset-docs/lib/python3.9/site-packages/torch/utils/data/dataloader.py:630, in _BaseDataLoaderIter.__next__(self)
    627 if self._sampler_iter is None:
    628     # TODO(https://github.com/pytorch/pytorch/issues/76750)
    629     self._reset()  # type: ignore[call-arg]
--> 630 data = self._next_data()
    631 self._num_yielded += 1
    632 if self._dataset_kind == _DatasetKind.Iterable and \
    633         self._IterableDataset_len_called is not None and \
    634         self._num_yielded > self._IterableDataset_len_called:

File ~/miniforge-pypy3/envs/zarrdataset-docs/lib/python3.9/site-packages/torch/utils/data/dataloader.py:1327, in _MultiProcessingDataLoaderIter._next_data(self)
   1324     return self._process_data(data)
   1326 assert not self._shutdown and self._tasks_outstanding > 0
-> 1327 idx, data = self._get_data()
   1328 self._tasks_outstanding -= 1
   1329 if self._dataset_kind == _DatasetKind.Iterable:
   1330     # Check for _IterableDatasetStopIteration

File ~/miniforge-pypy3/envs/zarrdataset-docs/lib/python3.9/site-packages/torch/utils/data/dataloader.py:1293, in _MultiProcessingDataLoaderIter._get_data(self)
   1289     # In this case, `self._data_queue` is a `queue.Queue`,. But we don't
   1290     # need to call `.task_done()` because we don't use `.join()`.
   1291 else:
   1292     while True:
-> 1293         success, data = self._try_get_data()
   1294         if success:
   1295             return data

File ~/miniforge-pypy3/envs/zarrdataset-docs/lib/python3.9/site-packages/torch/utils/data/dataloader.py:1131, in _MultiProcessingDataLoaderIter._try_get_data(self, timeout)
   1118 def _try_get_data(self, timeout=_utils.MP_STATUS_CHECK_INTERVAL):
   1119     # Tries to fetch data from `self._data_queue` once for a given timeout.
   1120     # This can also be used as inner loop of fetching without timeout, with
   (...)
   1128     # Returns a 2-tuple:
   1129     #   (bool: whether successfully get data, any: data if successful else None)
   1130     try:
-> 1131         data = self._data_queue.get(timeout=timeout)
   1132         return (True, data)
   1133     except Exception as e:
   1134         # At timeout and error, we manually check whether any worker has
   1135         # failed. Note that this is the only mechanism for Windows to detect
   1136         # worker failures.

File ~/miniforge-pypy3/envs/zarrdataset-docs/lib/python3.9/multiprocessing/queues.py:113, in Queue.get(self, block, timeout)
    111 if block:
    112     timeout = deadline - time.monotonic()
--> 113     if not self._poll(timeout):
    114         raise Empty
    115 elif not self._poll():

File ~/miniforge-pypy3/envs/zarrdataset-docs/lib/python3.9/multiprocessing/connection.py:257, in _ConnectionBase.poll(self, timeout)
    255 self._check_closed()
    256 self._check_readable()
--> 257 return self._poll(timeout)

File ~/miniforge-pypy3/envs/zarrdataset-docs/lib/python3.9/multiprocessing/connection.py:424, in Connection._poll(self, timeout)
    423 def _poll(self, timeout):
--> 424     r = wait([self], timeout)
    425     return bool(r)

File ~/miniforge-pypy3/envs/zarrdataset-docs/lib/python3.9/multiprocessing/connection.py:931, in wait(object_list, timeout)
    928     deadline = time.monotonic() + timeout
    930 while True:
--> 931     ready = selector.select(timeout)
    932     if ready:
    933         return [key.fileobj for (key, events) in ready]

File ~/miniforge-pypy3/envs/zarrdataset-docs/lib/python3.9/selectors.py:416, in _PollLikeSelector.select(self, timeout)
    414 ready = []
    415 try:
--> 416     fd_event_list = self._selector.poll(timeout)
    417 except InterruptedError:
    418     return ready

KeyboardInterrupt: 
# No ToDtype transform was added to this second dataset, so the patches
# presumably keep their source dtype and are shown without rescaling.
plt.imshow(X=samples)
plt.show()