Integration of ZarrDataset with PyTorch’s DataLoader
import zarrdataset as zds
import torch
from torch.utils.data import DataLoader
# Publicly available images from the Image Data Resource (IDR),
# https://idr.openmicroscopy.org/, converted to the OME-NGFF (Zarr) format
# by the OME group. More examples are listed in "Public OME-Zarr data
# (Nov. 2020)": https://www.openmicroscopy.org/2020/11/04/zarr-data.html
filenames = [
    "https://uk1s3.embassy.ebi.ac.uk/idr/zarr/v0.4/idr0073A/9798462.zarr",
]
import random

import numpy as np

# Seed Python's, NumPy's and PyTorch's global random number generators so
# that the sampled patches are reproducible from run to run.
random.seed(478965)
np.random.seed(478963)
torch.manual_seed(478964)
Extracting patches of size 1024x1024 pixels from a Whole Slide Image (WSI)
Sample patch locations randomly across the image using a blue-noise patch sampler.
# Patches of 1024x1024 pixels (Y by X), located by a blue-noise patch sampler.
patch_size = {"Y": 1024, "X": 1024}
patch_sampler = zds.BlueNoisePatchSampler(patch_size=patch_size)
Create a dataset from the list of filenames. The image data of every file is stored in its group “0”.
Also, specify that the axes of the image are ordered Time-Channel-Depth-Height-Width (TCZYX), so the data can be handled correctly.
# Read the arrays stored in group "0" of each file; their axes are ordered
# Time, Channel, Depth, Height, Width.
image_specs = zds.ImagesDatasetSpecs(
    source_axes="TCZYX",
    data_group="0",
    filenames=filenames,
)
# Build the dataset from the image specs; patch locations come from
# `patch_sampler` and shuffling is enabled.
my_dataset = zds.ZarrDataset(
    image_specs,
    patch_sampler=patch_sampler,
    shuffle=True,
)
# Display the dataset summary.
my_dataset
ZarrDataset (PyTorch support:True, tqdm support :False)
Modalities: images
Transforms order: []
Using images modality as reference.
Using <class 'zarrdataset._samplers.BlueNoisePatchSampler'> for sampling patches of size {'Z': 1, 'Y': 1024, 'X': 1024}.
Add a pre-processing step, applied before the image batches are created, that casts the input arrays from int16 to float32.
import torchvision

# A single-step pipeline that converts every "images" patch to float32.
img_preprocessing = torchvision.transforms.Compose(
    [zds.ToDtype(dtype=np.float32)]
)
my_dataset.add_transform("images", img_preprocessing)
# Display the dataset summary, which now lists the registered transform.
my_dataset
ZarrDataset (PyTorch support:True, tqdm support :False)
Modalities: images
Transforms order: [('images',)]
Using images modality as reference.
Using <class 'zarrdataset._samplers.BlueNoisePatchSampler'> for sampling patches of size {'Z': 1, 'Y': 1024, 'X': 1024}.
Create a DataLoader from the dataset object.
ZarrDataset is compatible with PyTorch’s DataLoader because it inherits from the IterableDataset class of the torch.utils.data module.
# num_workers=0: every batch is loaded in the main process (no workers).
my_dataloader = DataLoader(dataset=my_dataset, num_workers=0)
# Collect up to five patches produced by the DataLoader.
samples = []
for i, sample in enumerate(my_dataloader):
    # The DataLoader prepends a batch (B) axis. Keep the first batch item,
    # the first T index, every channel and the first Z plane (assuming the
    # (B, T, C, Z, Y, X) layout shown by the printed shapes), then move the
    # channel axis last so that the patch becomes a (Y, X, C) array.
    samples.append(np.moveaxis(sample[0, 0, :, 0].numpy(), 0, -1))
    print(f"Sample {i+1} with size {sample.shape}")
    # Stop after the fifth sample (i == 4); a few are enough to illustrate.
    if i == 4:
        break
# Concatenate the (Y, X, C) patches side by side along their X axis.
samples = np.hstack(samples)
Sample 1 with size torch.Size([1, 1, 3, 1, 1024, 1024])
Sample 2 with size torch.Size([1, 1, 3, 1, 1024, 1024])
Sample 3 with size torch.Size([1, 1, 3, 1, 1024, 1024])
Sample 4 with size torch.Size([1, 1, 3, 1, 1024, 1024])
import matplotlib.pyplot as plt

# Show the collected patches. The float32 intensities are divided by 255 to
# map them into [0, 1], the range imshow expects for float RGB data
# (assumes the pixel values lie in [0, 255] — TODO confirm for this image).
plt.imshow(samples / 255.0)
plt.show()
![../_images/5e5194f3aacd660f1d602d50e46aff1b64a2b0f19f6be64ea92299b4973e7d97.png](../_images/5e5194f3aacd660f1d602d50e46aff1b64a2b0f19f6be64ea92299b4973e7d97.png)
Multi-worker data loading with PyTorch’s DataLoader
This example uses multiple workers to load patches of size 256x256 from the same image.
# Smaller patches of 256x256 pixels (Y by X), again from a blue-noise sampler.
patch_size = {"Y": 256, "X": 256}
patch_sampler = zds.BlueNoisePatchSampler(patch_size=patch_size)
Create a dataset from the list of filenames. This time the image data is read from group “3” of every file.
Also, specify that the axes of the image are ordered Time-Channel-Depth-Height-Width (TCZYX), so the data can be handled correctly.
# Same files, but the arrays are read from group "3" (presumably a
# lower-resolution level of the OME-NGFF multiscale pyramid — confirm);
# axes are still ordered Time, Channel, Depth, Height, Width.
image_specs = zds.ImagesDatasetSpecs(
    source_axes="TCZYX",
    data_group="3",
    filenames=filenames,
)
# Rebuild the dataset with the new specs and the 256x256 patch sampler,
# with shuffling enabled. No dtype transform is registered this time.
my_dataset = zds.ZarrDataset(
    image_specs,
    patch_sampler=patch_sampler,
    shuffle=True,
)
ZarrDataset performs some special operations to enable multi-worker data loading without replicating the full dataset on each worker.
For this reason, ZarrDataset requires its own worker_init_fn
initialization function: zarrdataset_worker_init_fn.
# Four worker processes. ZarrDataset's own worker_init_fn is required so that
# each worker is set up without replicating the full dataset.
my_dataloader = DataLoader(
    my_dataset,
    num_workers=4,
    worker_init_fn=zds.zarrdataset_worker_init_fn,
)
Now the data can be safely loaded using multiple workers.
# Collect up to five patches, this time produced by the worker processes.
samples = []
for i, sample in enumerate(my_dataloader):
    # The DataLoader prepends a batch (B) axis. Keep the first batch item,
    # the first T index, every channel and the first Z plane (assuming a
    # (B, T, C, Z, Y, X) layout), then move the channel axis last so the
    # patch becomes a (Y, X, C) array.
    samples.append(np.moveaxis(sample[0, 0, :, 0].numpy(), 0, -1))
    print(f"Sample {i+1} with size {sample.shape}")
    # Stop after the fifth sample (i == 4); a few are enough to illustrate.
    if i == 4:
        break
# Concatenate the (Y, X, C) patches side by side along their X axis.
samples = np.hstack(samples)
---------------------------------------------------------------------------
KeyboardInterrupt Traceback (most recent call last)
Cell In[15], line 2
1 samples = []
----> 2 for i, sample in enumerate(my_dataloader):
3 # Samples generated by DataLoaders have Batch (B) as first axes
4 samples.append(np.moveaxis(sample[0, 0, :, 0].numpy(), 0, -1))
6 print(f"Sample {i+1} with size {sample.shape}")
File ~/miniforge-pypy3/envs/zarrdataset-docs/lib/python3.9/site-packages/torch/utils/data/dataloader.py:630, in _BaseDataLoaderIter.__next__(self)
627 if self._sampler_iter is None:
628 # TODO(https://github.com/pytorch/pytorch/issues/76750)
629 self._reset() # type: ignore[call-arg]
--> 630 data = self._next_data()
631 self._num_yielded += 1
632 if self._dataset_kind == _DatasetKind.Iterable and \
633 self._IterableDataset_len_called is not None and \
634 self._num_yielded > self._IterableDataset_len_called:
File ~/miniforge-pypy3/envs/zarrdataset-docs/lib/python3.9/site-packages/torch/utils/data/dataloader.py:1327, in _MultiProcessingDataLoaderIter._next_data(self)
1324 return self._process_data(data)
1326 assert not self._shutdown and self._tasks_outstanding > 0
-> 1327 idx, data = self._get_data()
1328 self._tasks_outstanding -= 1
1329 if self._dataset_kind == _DatasetKind.Iterable:
1330 # Check for _IterableDatasetStopIteration
File ~/miniforge-pypy3/envs/zarrdataset-docs/lib/python3.9/site-packages/torch/utils/data/dataloader.py:1293, in _MultiProcessingDataLoaderIter._get_data(self)
1289 # In this case, `self._data_queue` is a `queue.Queue`,. But we don't
1290 # need to call `.task_done()` because we don't use `.join()`.
1291 else:
1292 while True:
-> 1293 success, data = self._try_get_data()
1294 if success:
1295 return data
File ~/miniforge-pypy3/envs/zarrdataset-docs/lib/python3.9/site-packages/torch/utils/data/dataloader.py:1131, in _MultiProcessingDataLoaderIter._try_get_data(self, timeout)
1118 def _try_get_data(self, timeout=_utils.MP_STATUS_CHECK_INTERVAL):
1119 # Tries to fetch data from `self._data_queue` once for a given timeout.
1120 # This can also be used as inner loop of fetching without timeout, with
(...)
1128 # Returns a 2-tuple:
1129 # (bool: whether successfully get data, any: data if successful else None)
1130 try:
-> 1131 data = self._data_queue.get(timeout=timeout)
1132 return (True, data)
1133 except Exception as e:
1134 # At timeout and error, we manually check whether any worker has
1135 # failed. Note that this is the only mechanism for Windows to detect
1136 # worker failures.
File ~/miniforge-pypy3/envs/zarrdataset-docs/lib/python3.9/multiprocessing/queues.py:113, in Queue.get(self, block, timeout)
111 if block:
112 timeout = deadline - time.monotonic()
--> 113 if not self._poll(timeout):
114 raise Empty
115 elif not self._poll():
File ~/miniforge-pypy3/envs/zarrdataset-docs/lib/python3.9/multiprocessing/connection.py:257, in _ConnectionBase.poll(self, timeout)
255 self._check_closed()
256 self._check_readable()
--> 257 return self._poll(timeout)
File ~/miniforge-pypy3/envs/zarrdataset-docs/lib/python3.9/multiprocessing/connection.py:424, in Connection._poll(self, timeout)
423 def _poll(self, timeout):
--> 424 r = wait([self], timeout)
425 return bool(r)
File ~/miniforge-pypy3/envs/zarrdataset-docs/lib/python3.9/multiprocessing/connection.py:931, in wait(object_list, timeout)
928 deadline = time.monotonic() + timeout
930 while True:
--> 931 ready = selector.select(timeout)
932 if ready:
933 return [key.fileobj for (key, events) in ready]
File ~/miniforge-pypy3/envs/zarrdataset-docs/lib/python3.9/selectors.py:416, in _PollLikeSelector.select(self, timeout)
414 ready = []
415 try:
--> 416 fd_event_list = self._selector.poll(timeout)
417 except InterruptedError:
418 return ready
KeyboardInterrupt:
# Show the collected patches. No rescaling is applied here because this
# dataset has no ToDtype transform, so the arrays keep their stored dtype
# (presumably integer — imshow reads integer RGB data as values in [0, 255]).
plt.imshow(samples)
plt.show()