Integration of ZarrDataset with PyTorch’s DataLoader
```python
import random

import numpy as np
import torch
from torch.utils.data import DataLoader

import zarrdataset as zds

# These are images from the Image Data Resource (IDR)
# https://idr.openmicroscopy.org/ that are publicly available and were
# converted to the OME-NGFF (Zarr) format by the OME group. More examples
# can be found at Public OME-Zarr data (Nov. 2020)
# https://www.openmicroscopy.org/2020/11/04/zarr-data.html
filenames = [
    "https://uk1s3.embassy.ebi.ac.uk/idr/zarr/v0.4/idr0073A/9798462.zarr"
]

# For reproducibility
np.random.seed(478963)
torch.manual_seed(478964)
random.seed(478965)
```
Extracting patches of size 1024x1024 pixels from a Whole Slide Image (WSI)
Sample patches from random locations of the image using blue-noise sampling.
```python
patch_size = dict(Y=1024, X=1024)
patch_sampler = zds.BlueNoisePatchSampler(patch_size=patch_size)
```
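Blue-noise sampling aims for random positions that still keep a minimum spacing between samples, so the patches cover the image more evenly than independent uniform draws would. A classic way to produce such point sets is Poisson-disk sampling by dart throwing, sketched below for patch top-left corners. This is a generic illustration of the technique, not ZarrDataset's `BlueNoisePatchSampler` implementation; the helper `poisson_disk_corners`, its parameters, and the 8192x8192 image size are hypothetical.

```python
import math
import random


def poisson_disk_corners(height, width, patch_h, patch_w, min_dist,
                         max_tries=1000, seed=None):
    """Hypothetical helper: pick top-left patch corners by dart throwing.

    A candidate corner is accepted only if it lies at least ``min_dist``
    pixels away from every corner accepted so far.
    """
    rng = random.Random(seed)
    corners = []
    for _ in range(max_tries):
        # Candidate corner chosen so the whole patch fits inside the image
        y = rng.randint(0, height - patch_h)
        x = rng.randint(0, width - patch_w)
        # Reject candidates that fall too close to an accepted corner
        if all(math.hypot(y - cy, x - cx) >= min_dist for cy, cx in corners):
            corners.append((y, x))
    return corners


# Hypothetical 8192x8192 image, 1024x1024 patches, corners >= 1024 px apart
corners = poisson_disk_corners(8192, 8192, 1024, 1024, min_dist=1024, seed=0)
print(len(corners), corners[:3])
```

Raising `min_dist` spreads the accepted corners further apart at the cost of fitting fewer of them into the image.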
Create a dataset from the list of filenames. The image data in each of these files is stored in its group “0”, the highest-resolution level of the OME-NGFF image.
Also specify that the axes order of the image is Time-Channel-Depth-Height-Width (TCZYX) so that the data is handled correctly.
```python
image_specs = zds.ImagesDatasetSpecs(
    filenames=filenames,
    data_group="0",
    source_axes="TCZYX",
)
```
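The `source_axes` string assigns one letter to each dimension of the stored array, in order. The sketch below makes that correspondence explicit with a hypothetical helper, `label_axes`, applied to the per-patch shape this example ends up producing (T=1, C=3, Z=1, Y=X=1024); it is an illustration only, not part of the ZarrDataset API.

```python
def label_axes(source_axes, shape):
    """Hypothetical helper: pair each axis letter with its length."""
    if len(source_axes) != len(shape):
        raise ValueError(
            f"{source_axes!r} has {len(source_axes)} axes but the array "
            f"has {len(shape)} dimensions"
        )
    return dict(zip(source_axes, shape))


patch_shape = (1, 3, 1, 1024, 1024)
axes = label_axes("TCZYX", patch_shape)
print(axes)  # {'T': 1, 'C': 3, 'Z': 1, 'Y': 1024, 'X': 1024}
```

The helper raises `ValueError` when the axes string and the array disagree on the number of dimensions, so a wrong axes specification is caught before any axis is misread.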
```python
my_dataset = zds.ZarrDataset(image_specs,
                             patch_sampler=patch_sampler,
                             shuffle=True)
my_dataset
```
```
ZarrDataset (PyTorch support:True, tqdm support :True)
Modalities: images
Transforms order: []
Using images modality as reference.
Using <class 'zarrdataset._samplers.BlueNoisePatchSampler'> for sampling patches of size {'Z': 1, 'Y': 1024, 'X': 1024}.
```
Add a pre-processing step, applied before the image batches are created, that casts the input arrays from int16 to float32.
```python
import torchvision

img_preprocessing = torchvision.transforms.Compose([
    zds.ToDtype(dtype=np.float32),
])

my_dataset.add_transform("images", img_preprocessing)
my_dataset
```
```
ZarrDataset (PyTorch support:True, tqdm support :True)
Modalities: images
Transforms order: [('images',)]
Using images modality as reference.
Using <class 'zarrdataset._samplers.BlueNoisePatchSampler'> for sampling patches of size {'Z': 1, 'Y': 1024, 'X': 1024}.
```
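`torchvision.transforms.Compose` calls each transform in the order given, feeding the output of one into the next. The pure-Python sketch below mimics that behavior so it runs without NumPy or torchvision: `compose` is a hypothetical stand-in for `Compose`, `to_float` a hypothetical stand-in for `zds.ToDtype(dtype=np.float32)`, and `scale_to_unit` a hypothetical extra step; a nested list plays the role of the image array.

```python
def compose(*transforms):
    """Hypothetical stand-in for torchvision.transforms.Compose."""
    def apply(x):
        # Apply the transforms one after another, in the order given
        for t in transforms:
            x = t(x)
        return x
    return apply


def to_float(rows):
    """Hypothetical stand-in for zds.ToDtype: cast every value to float."""
    return [[float(v) for v in row] for row in rows]


def scale_to_unit(rows):
    """Hypothetical second step: map 0-255 values to [0, 1]."""
    return [[v / 255.0 for v in row] for row in rows]


pipeline = compose(to_float, scale_to_unit)
print(pipeline([[0, 255], [51, 102]]))  # [[0.0, 1.0], [0.2, 0.4]]
```

The second step mirrors the `/ 255.0` rescaling that this example applies at plotting time rather than inside the dataset.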
Create a DataLoader from the dataset object
ZarrDataset is compatible with PyTorch’s DataLoader because it inherits from the IterableDataset class of the torch.utils.data module.
```python
my_dataloader = DataLoader(my_dataset, num_workers=0)
```
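Because the dataset is an `IterableDataset`, the `DataLoader` does not index into it; it iterates over it and groups consecutive items into batches. With the default `batch_size=1`, every patch therefore gains a leading batch axis of length 1. The stdlib sketch below imitates that flow; `PatchStream` and `batch_iter` are hypothetical names, and patches are represented only by their shapes.

```python
from itertools import islice


class PatchStream:
    """Hypothetical iterable that yields the shape of each sampled patch."""

    def __init__(self, num_patches, patch_shape):
        self.num_patches = num_patches
        self.patch_shape = patch_shape

    def __iter__(self):
        for _ in range(self.num_patches):
            yield self.patch_shape


def batch_iter(iterable, batch_size=1):
    """Group consecutive items and report the shape of each stacked batch."""
    it = iter(iterable)
    while True:
        batch = list(islice(it, batch_size))
        if not batch:
            return
        # Stacking adds a leading batch axis with one entry per item
        yield (len(batch),) + batch[0]


stream = PatchStream(num_patches=5, patch_shape=(1, 3, 1, 1024, 1024))
for shape in batch_iter(stream, batch_size=1):
    print(shape)  # (1, 1, 3, 1, 1024, 1024)
```

As with PyTorch's default `drop_last=False`, a final batch may hold fewer items than `batch_size`.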
```python
samples = []
for i, sample in enumerate(my_dataloader):
    # Samples generated by DataLoaders have Batch (B) as the first axis.
    # Keep the C, Y and X axes (drop B, T and Z) and move channels last.
    samples.append(np.moveaxis(sample[0, 0, :, 0].numpy(), 0, -1))
    print(f"Sample {i+1} with size {sample.shape}")
    if i >= 4:
        # Take only five samples for illustration purposes
        break

samples = np.hstack(samples)
```
```
Sample 1 with size torch.Size([1, 1, 3, 1, 1024, 1024])
Sample 2 with size torch.Size([1, 1, 3, 1, 1024, 1024])
Sample 3 with size torch.Size([1, 1, 3, 1, 1024, 1024])
Sample 4 with size torch.Size([1, 1, 3, 1, 1024, 1024])
Sample 5 with size torch.Size([1, 1, 3, 1, 1024, 1024])
```
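The loop above turns each batched sample of shape (B, T, C, Z, Y, X) into a (Y, X, C) image that `plt.imshow` accepts: `sample[0, 0, :, 0]` drops the B, T and Z axes, `np.moveaxis(..., 0, -1)` moves the channel axis to the end, and `np.hstack` joins the patches side by side along axis 1 (the width). The stdlib sketch below tracks only the shapes through those steps; the helper names are hypothetical.

```python
def index_shape(shape, keep):
    """Shape after integer-indexing every axis whose position is not in keep."""
    return tuple(n for i, n in enumerate(shape) if i in keep)


def moveaxis_shape(shape, source, destination):
    """Shape after np.moveaxis(a, source, destination) for a single axis."""
    ndim = len(shape)
    dims = list(shape)
    size = dims.pop(source % ndim)
    dims.insert(destination % ndim, size)
    return tuple(dims)


def hstack_shape(shapes):
    """Shape after np.hstack on arrays with 2+ dims (joins along axis 1)."""
    first = shapes[0]
    return (first[0], sum(s[1] for s in shapes)) + first[2:]


batched = (1, 1, 3, 1, 1024, 1024)            # B, T, C, Z, Y, X
patch = index_shape(batched, keep={2, 4, 5})  # sample[0, 0, :, 0] -> C, Y, X
image = moveaxis_shape(patch, 0, -1)          # -> Y, X, C
mosaic = hstack_shape([image] * 5)            # five patches side by side
print(patch, image, mosaic)
# (3, 1024, 1024) (1024, 1024, 3) (1024, 5120, 3)
```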
```python
import matplotlib.pyplot as plt

# The patches were cast to float32, so rescale them to [0, 1] for display
plt.imshow(samples / 255.0)
plt.show()
```
![../_images/e880a407e638d92dfbb3ecddfb7e77f00928d4513ef46732aad1a34f268aaba6.png](../_images/e880a407e638d92dfbb3ecddfb7e77f00928d4513ef46732aad1a34f268aaba6.png)
Multiprocess data loading with PyTorch’s DataLoader
This example uses multiple workers to load patches of 256x256 pixels from the same image.
```python
patch_size = dict(Y=256, X=256)
patch_sampler = zds.BlueNoisePatchSampler(patch_size=patch_size)
```
Create a dataset from the list of filenames. This time, the image data is read from group “3” of each file, a lower-resolution level of the multiscale image.
Also specify that the axes order of the image is Time-Channel-Depth-Height-Width (TCZYX) so that the data is handled correctly.
```python
image_specs = zds.ImagesDatasetSpecs(
    filenames=filenames,
    data_group="3",
    source_axes="TCZYX",
)
```
```python
my_dataset = zds.ZarrDataset(image_specs,
                             patch_sampler=patch_sampler,
                             shuffle=True)
```
ZarrDataset performs some special operations to enable multiprocess data loading without replicating the full dataset in each worker. For this reason, ZarrDataset requires its own worker initialization function, `zarrdataset_worker_init_fn`.
```python
my_dataloader = DataLoader(my_dataset, num_workers=4,
                           worker_init_fn=zds.zarrdataset_worker_init_fn)
```
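With multiprocess loading, PyTorch replicates an `IterableDataset` in every worker process, so without extra care each worker would yield the same items. The PyTorch `IterableDataset` documentation describes giving each replica a disjoint share of the work based on its worker id. The sketch below shows that general idea with a hypothetical round-robin `shard` function over hypothetical patch indices; how `zarrdataset_worker_init_fn` actually partitions the files or patches may differ.

```python
def shard(items, worker_id, num_workers):
    """Hypothetical helper: the items assigned to one worker (round-robin)."""
    if not 0 <= worker_id < num_workers:
        raise ValueError("worker_id must be in [0, num_workers)")
    return items[worker_id::num_workers]


# Hypothetical list of patch indices distributed over 4 workers
patch_ids = list(range(10))
shards = [shard(patch_ids, w, 4) for w in range(4)]
print(shards)  # [[0, 4, 8], [1, 5, 9], [2, 6], [3, 7]]
```

Every index lands in exactly one shard, so the workers together cover the data once without duplicating it.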
Now the data can be safely loaded using multiple workers.
```python
samples = []
for i, sample in enumerate(my_dataloader):
    # Samples generated by DataLoaders have Batch (B) as the first axis.
    # Keep the C, Y and X axes (drop B, T and Z) and move channels last.
    samples.append(np.moveaxis(sample[0, 0, :, 0].numpy(), 0, -1))
    print(f"Sample {i+1} with size {sample.shape}")
    if i >= 4:
        # Take only five samples for illustration purposes
        break

samples = np.hstack(samples)
```
```
Sample 1 with size torch.Size([1, 1, 3, 1, 256, 256])
Sample 2 with size torch.Size([1, 1, 3, 1, 256, 256])
Sample 3 with size torch.Size([1, 1, 3, 1, 256, 256])
Sample 4 with size torch.Size([1, 1, 3, 1, 256, 256])
Sample 5 with size torch.Size([1, 1, 3, 1, 256, 256])
```
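```python
# No ToDtype transform was added to this dataset, so the patches keep
# their stored integer dtype
plt.imshow(samples)
plt.show()
```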
![../_images/436ea478c2ee0d47a17897babbdadb6f03141f6ce5da68079b1003253690ac07.png](../_images/436ea478c2ee0d47a17897babbdadb6f03141f6ce5da68079b1003253690ac07.png)