{ "cells": [ { "cell_type": "markdown", "id": "6193287e", "metadata": {}, "source": [ "# Custom masks for sampling specific regions from images with ZarrDataset" ] }, { "cell_type": "code", "execution_count": null, "id": "0f4a6713", "metadata": {}, "outputs": [], "source": [ "import zarrdataset as zds\n", "import zarr" ] }, { "cell_type": "code", "execution_count": null, "id": "02ec8f70", "metadata": {}, "outputs": [], "source": [ "# These are images from the Image Data Resource (IDR) \n", "# https://idr.openmicroscopy.org/ that are publicly available and were \n", "# converted to the OME-NGFF (Zarr) format by the OME group. More examples\n", "# can be found at Public OME-Zarr data (Nov. 2020)\n", "# https://www.openmicroscopy.org/2020/11/04/zarr-data.html\n", "\n", "filenames = [\"https://uk1s3.embassy.ebi.ac.uk/idr/zarr/v0.4/idr0073A/9798462.zarr\"]" ] }, { "cell_type": "code", "execution_count": null, "id": "69c4aeb4", "metadata": {}, "outputs": [], "source": [ "import random\n", "import numpy as np\n", "\n", "# For reproducibility\n", "np.random.seed(478963)\n", "random.seed(478965)" ] }, { "cell_type": "code", "execution_count": null, "id": "43c67512", "metadata": {}, "outputs": [], "source": [ "z_img = zarr.open(filenames[0], mode=\"r\")\n", "z_img[\"0\"].info" ] }, { "cell_type": "code", "execution_count": null, "id": "fc96ed90", "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "import matplotlib.pyplot as plt\n", "plt.imshow(np.moveaxis(z_img[\"4\"][0, :, 0], 0, -1))\n", "plt.show()" ] }, { "cell_type": "markdown", "id": "fe865601", "metadata": {}, "source": [ "Define a mask from which patches can be extracted" ] }, { "cell_type": "code", "execution_count": null, "id": "46e58845", "metadata": {}, "outputs": [], "source": [ "from skimage import color, filters, morphology" ] }, { "cell_type": "code", "execution_count": null, "id": "aaca8473", "metadata": {}, "outputs": [], "source": [ "im_gray = color.rgb2gray(z_img[\"4\"][0, :, 0], 
channel_axis=0)\n", "thresh = filters.threshold_otsu(im_gray)\n", "\n", "# Select the pixels at or below the Otsu threshold and drop small objects\n", "mask = im_gray > thresh\n", "mask = morphology.remove_small_objects(mask == 0, min_size=16 ** 2,\n", "                                       connectivity=2)\n", "# Fill small holes, then erode and dilate (a morphological opening)\n", "mask = morphology.remove_small_holes(mask, area_threshold=128)\n", "mask = morphology.binary_erosion(mask, morphology.disk(8))\n", "mask = morphology.binary_dilation(mask, morphology.disk(8))" ] }, { "cell_type": "code", "execution_count": null, "id": "430549af", "metadata": {}, "outputs": [], "source": [ "plt.imshow(mask)\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": null, "id": "ca708846", "metadata": {}, "outputs": [], "source": [ "plt.imshow(np.moveaxis(z_img[\"4\"][0, :, 0], 0, -1))\n", "plt.imshow(mask, cmap=\"gray\", alpha=1.0*(mask < 1))\n", "plt.show()" ] }, { "cell_type": "markdown", "id": "ceb754da", "metadata": {}, "source": [ "## Extract patches of size 512x512 pixels from a Whole Slide Image (WSI)" ] }, { "cell_type": "markdown", "id": "c1f558c8", "metadata": {}, "source": [ "Sample the image uniformly in a square grid pattern" ] }, { "cell_type": "code", "execution_count": null, "id": "874ad5dd", "metadata": {}, "outputs": [], "source": [ "patch_size = dict(Y=512, X=512)\n", "patch_sampler = zds.PatchSampler(patch_size=patch_size)" ] }, { "cell_type": "markdown", "id": "b1133e72", "metadata": {}, "source": [ "Use the ZarrDataset class to enable extraction of samples from masked regions.\n", "\n", "An extra dimension is added to the mask (`source_axes=\"YX\"` becomes `axes=\"ZYX\"`) so that it matches the number of spatial axes in the image." ] }, { "cell_type": "code", "execution_count": null, "id": "3aeaafe0", "metadata": {}, "outputs": [], "source": [ "image_specs = zds.ImagesDatasetSpecs(\n", "    filenames=filenames,\n", "    data_group=\"1\",\n", "    source_axes=\"TCZYX\",\n", ")\n", "\n", "# Add the specifications of the masks. LabelsDatasetSpecs with\n", "# modality=\"masks\" is used here instead of MasksDatasetSpecs.\n", "#masks_specs = zds.MasksDatasetSpecs(\n", "masks_specs = zds.LabelsDatasetSpecs(\n", "    filenames=[mask],\n", " 
source_axes=\"YX\",\n", "    axes=\"ZYX\",\n", "    modality=\"masks\",\n", ")\n", "\n", "my_dataset = zds.ZarrDataset([image_specs, masks_specs],\n", "                             patch_sampler=patch_sampler)" ] }, { "cell_type": "code", "execution_count": null, "id": "c9276695", "metadata": {}, "outputs": [], "source": [ "ds_iterator = iter(my_dataset)" ] }, { "cell_type": "code", "execution_count": null, "id": "95973947", "metadata": {}, "outputs": [], "source": [ "sample = next(ds_iterator)\n", "# Print both entries; a cell only displays its last bare expression\n", "print(type(sample[0]), sample[0].shape, sample[0].dtype)\n", "print(type(sample[1]), sample[1].shape, sample[1].dtype)" ] }, { "cell_type": "code", "execution_count": null, "id": "bf88917a", "metadata": {}, "outputs": [], "source": [ "plt.imshow(np.moveaxis(sample[0][0, :, 0], 0, -1))\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": null, "id": "bfeb0a04", "metadata": {}, "outputs": [], "source": [ "plt.imshow(sample[1][0])\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": null, "id": "448656be", "metadata": {}, "outputs": [], "source": [ "samples = []\n", "labels = []\n", "for i, sample in enumerate(my_dataset):\n", "    samples.append(np.moveaxis(sample[0][0, :, 0], 0, -1))\n", "    labels.append(sample[1][0])\n", "\n", "    if i >= 4:\n", "        # Take only five samples for illustration purposes\n", "        break\n", "\n", "samples = np.hstack(samples)\n", "labels = np.hstack(labels)" ] }, { "cell_type": "code", "execution_count": null, "id": "89cefb01", "metadata": {}, "outputs": [], "source": [ "plt.imshow(samples)\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": null, "id": "d20af5bc", "metadata": {}, "outputs": [], "source": [ "plt.imshow(labels)\n", "plt.show()" ] }, { "cell_type": "markdown", "id": "c07aa2b2", "metadata": {}, "source": [ "## Use a function to generate the masks for each image in the dataset\n", "\n", "Keep only patches that have at least 1/16 of their area covered by the mask" ] }, { "cell_type": "code", "execution_count": null, "id": "f03aabc0", "metadata": {}, 
"outputs": [], "source": [ "patch_size = dict(Y=512, X=512)\n", "patch_sampler = zds.PatchSampler(patch_size=patch_size, min_area=1/16)" ] }, { "cell_type": "markdown", "id": "d9ef53eb", "metadata": {}, "source": [ "Apply the `WSITissueMaskGenerator` transform to each image in the dataset to define its sampling mask" ] }, { "cell_type": "code", "execution_count": null, "id": "9e6867a5", "metadata": {}, "outputs": [], "source": [ "mask_func = zds.WSITissueMaskGenerator(mask_scale=1,\n", "                                       min_size=16,\n", "                                       area_threshold=128,\n", "                                       axes=\"ZYX\")" ] }, { "cell_type": "markdown", "id": "14b9edae", "metadata": {}, "source": [ "Because the input image (zarr group \"1\") is large, computing the mask directly on it could be computationally expensive.\n", "\n", "For that reason, compute the mask on a downsampled version of that image instead, by setting `data_group=\"4\"` in the mask specification so that it points to a 1:16 downsampled version of the input image.\n", "\n", "The `axes` of the mask specification should match the ones that `WSITissueMaskGenerator` requires as input (\"YXC\"). To do that, specify a ROI that takes just the spatial and channel axes from the input image with `roi=\"(0,0,0,0,0):(1,-1,1,-1,-1)\"`, and rearrange the output axes with `axes=\"YXC\"`." 
] }, { "cell_type": "code", "execution_count": null, "id": "bbb7544e", "metadata": {}, "outputs": [], "source": [ "image_specs = zds.ImagesDatasetSpecs(\n", "    filenames=filenames,\n", "    data_group=\"1\",\n", "    source_axes=\"TCZYX\",\n", ")\n", "\n", "# Use the MasksDatasetSpecs to add the specifications of the masks.\n", "# The mask generation function is passed as the `image_loader_func`\n", "# parameter of the dataset specification for masks.\n", "masks_specs = zds.MasksDatasetSpecs(\n", "    filenames=filenames,\n", "    data_group=\"4\",\n", "    source_axes=\"TCZYX\",\n", "    axes=\"YXC\",\n", "    roi=\"(0,0,0,0,0):(1,-1,1,-1,-1)\",\n", "    image_loader_func=mask_func,\n", ")\n", "\n", "my_dataset = zds.ZarrDataset([image_specs, masks_specs],\n", "                             patch_sampler=patch_sampler,\n", "                             shuffle=True)" ] }, { "cell_type": "code", "execution_count": null, "id": "d50fd9a3", "metadata": {}, "outputs": [], "source": [ "samples = []\n", "\n", "for i, sample in enumerate(my_dataset):\n", "    samples.append(np.moveaxis(sample[0, :, 0], 0, -1))\n", "\n", "    if i >= 4:\n", "        # Take only five samples for illustration purposes\n", "        break\n", "\n", "samples = np.hstack(samples)" ] }, { "cell_type": "code", "execution_count": null, "id": "059f72a6", "metadata": {}, "outputs": [], "source": [ "plt.imshow(samples)\n", "plt.show()" ] } ], "metadata": { "execution": { "timeout": 600 }, "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" } }, "nbformat": 4, "nbformat_minor": 5 }