Advanced Machine Learning with Python (Session 2 - Part 2)

Fernando Cervantes (fernando.cervantes@jax.org)

Tissue classification with the MoNuSAC dataset

Tissue classification with the MoNuSAC dataset

  • R. Verma, et al. “MoNuSAC2020: A Multi-organ Nuclei Segmentation and Classification Challenge.” IEEE Transactions on Medical Imaging (2021)

Open notebook in Colab View solutions

Data preparation

Note

More information about the type of tissue of each image can be found here.

Tissue classification with the MoNuSAC dataset

from skimage.io import imread
import matplotlib.pyplot as plt

# Preview the first image of the training split and then of the test split.
# Each figure is titled with the `tissue_classes` entry selected by the
# image's label.
for split_fns, split_labels in (
    (train_images_fns, train_labels),
    (test_images_fns, test_labels),
):
    img = imread(split_fns[0])
    plt.imshow(img)
    plt.title(tissue_classes[split_labels[0]])
    plt.show()

Tissue classification with the MoNuSAC dataset

import torchvision
from torchvision.transforms.v2 import Compose, ToTensor

# Pre-trained weight set; its bundled preprocessing transforms are reused in
# the pipeline below.
inception_weights = torchvision.models.inception.Inception_V3_Weights.IMAGENET1K_V1

# Preprocessing applied to every image: convert it to a tensor first, then run
# the transforms that ship with the selected weights.
# NOTE(review): the torchvision v2 documentation marks `ToTensor` as
# deprecated in favour of `ToImage` + `ToDtype(..., scale=True)` -- confirm
# against the installed torchvision version before switching.
preprocessing_steps = [ToTensor(), inception_weights.transforms()]
pipeline = Compose(preprocessing_steps)

pipeline

Tissue classification with the MoNuSAC dataset

from torch.utils.data import Dataset

class CustomImageDataset(Dataset):
    """Dataset that reads one image from disk per item and pairs it with a label.

    Args:
        image_filenames: indexable collection of image paths, one per sample.
        image_labels: indexable collection of labels; its length defines the
            length of the dataset.
        transform: optional callable applied to each image right after it is
            read. When ``None`` the image is returned as read.
    """

    def __init__(self, image_filenames, image_labels, transform=None):
        self.image_filenames, self.image_labels = image_filenames, image_labels
        self.transform = transform

    def __len__(self):
        """Return the number of labels held by the dataset."""
        return len(self.image_labels)

    def __getitem__(self, idx):
        """Return the ``(image, label)`` pair stored at position ``idx``."""
        # The image is read lazily, on every access, rather than cached.
        pixels = imread(self.image_filenames[idx])

        if self.transform is None:
            return pixels, self.image_labels[idx]

        return self.transform(pixels), self.image_labels[idx]

Tissue classification with the MoNuSAC dataset

from torch.utils.data import random_split

# Wrap the file lists of both splits in datasets that share the same
# preprocessing pipeline.
full_train_ds = CustomImageDataset(
    image_filenames=train_images_fns,
    image_labels=train_labels,
    transform=pipeline,
)
test_ds = CustomImageDataset(
    image_filenames=test_images_fns,
    image_labels=test_labels,
    transform=pipeline,
)

# Randomly hold out 20% of the training images for validation.
train_ds, val_ds = random_split(full_train_ds, [0.8, 0.2])

print(f"Training images={len(train_ds)}")
print(f"Validation images={len(val_ds)}")
print(f"Test images={len(test_ds)}")
Training images=168
Validation images=41
Test images=101
from torch.utils.data import DataLoader

# Only the training loader shuffles its samples; the validation and test
# loaders use a larger batch size and keep the dataset order.
train_dl = DataLoader(dataset=train_ds, batch_size=32, shuffle=True)
val_dl = DataLoader(dataset=val_ds, batch_size=128)
test_dl = DataLoader(dataset=test_ds, batch_size=128)

Transfer learning from ImageNet to MoNuSAC

Tissue classification with the MoNuSAC dataset

import torch
import torch.nn as nn

# Build InceptionV3 initialised with the pre-trained weights selected earlier.
# `weights` is a keyword-only parameter of the torchvision model builders;
# passing it positionally only works through a deprecated compatibility shim
# that emits a warning, so it is passed by keyword here.
dl_model = torchvision.models.inception_v3(
    weights=inception_weights,
    progress=True
)

# Replace the final fully-connected layer with a pass-through so the model
# returns the input of that layer instead of its class scores.
dl_model.fc = nn.Identity()

dl_model.eval()

Tissue classification with the MoNuSAC dataset

# New classification head: 2048 input features -> 32 hidden units -> 4 outputs.
dl_model.fc = nn.Sequential(
    nn.Linear(2048, 32, bias=True),
    nn.ReLU(),
    nn.Linear(32, 4, bias=True),
)

# Leave gradients enabled only for the parameters of the new head (their
# names start with "fc."); every other parameter of the model is frozen.
for param_name, param in dl_model.named_parameters():
    param.requires_grad = param_name.startswith("fc.")

Tissue classification with the MoNuSAC dataset

import torch.optim as optim

# Move the model to the GPU, when one is available, before the optimizer is
# built from its parameters.
if torch.cuda.is_available():
    dl_model.cuda()

loss_fun = nn.CrossEntropyLoss()

# The optimizer only receives the parameters of the classification head.
optimizer = optim.Adam(
    params=dl_model.fc.parameters(),
    lr=0.001,
    weight_decay=0.0001,
)

Tissue classification with the MoNuSAC dataset

from torchmetrics.classification import Accuracy

# Resolve the compute device once instead of querying CUDA availability on
# every batch of every epoch.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Separate accuracy trackers for the training and validation passes; both are
# reset at the end of every epoch.
train_acc_metric = Accuracy(task="multiclass", num_classes=4).to(device)
val_acc_metric = Accuracy(task="multiclass", num_classes=4).to(device)

num_epochs = 100
for e in range(num_epochs):
    # Despite the name, this accumulates the *sum* of per-sample losses until
    # it is divided by the sample count after the pass.
    avg_train_loss = 0
    total_train_samples = 0

    # ---- Training pass ----
    dl_model.train()
    for x, y in train_dl:
        optimizer.zero_grad()

        x = x.to(device)
        y = y.to(device)

        # In training mode the class scores are read from the `.logits` field
        # of the model output; the evaluation pass below uses the output
        # directly.
        y_hat = dl_model(x).logits

        loss = loss_fun(y_hat, y)

        loss.backward()

        optimizer.step()

        # Weight the batch loss by the batch size so that a smaller final
        # batch does not skew the epoch average.
        avg_train_loss += loss.item() * len(y)
        total_train_samples += len(y)

        train_acc_metric(y_hat.softmax(dim=1), y)

    avg_train_loss /= total_train_samples
    train_acc = train_acc_metric.compute()

    avg_val_loss = 0
    total_val_samples = 0

    # ---- Validation pass (no gradient tracking) ----
    dl_model.eval()
    with torch.no_grad():
        for x, y in val_dl:
            x = x.to(device)
            y = y.to(device)

            y_hat = dl_model(x)

            loss = loss_fun(y_hat, y)

            avg_val_loss += loss.item() * len(y)
            total_val_samples += len(y)

            val_acc_metric(y_hat.softmax(dim=1), y)

    avg_val_loss /= total_val_samples
    val_acc = val_acc_metric.compute()

    print(f"[{(e + 1) / num_epochs: 2.2%}] Train loss={avg_train_loss: 2.4} (Acc={train_acc: 2.2%}), Validation loss={avg_val_loss: 2.4} (Acc={val_acc: 2.2%})")

    # Clear the accumulated state so the next epoch starts from scratch.
    train_acc_metric.reset()
    val_acc_metric.reset()
[ 1.00%] Train loss= 1.435 (Acc= 29.76%), Validation loss= 1.458 (Acc= 21.95%)
[ 2.00%] Train loss= 1.357 (Acc= 28.57%), Validation loss= 1.408 (Acc= 24.39%)
[ 3.00%] Train loss= 1.31 (Acc= 47.62%), Validation loss= 1.39 (Acc= 26.83%)
[ 4.00%] Train loss= 1.291 (Acc= 33.33%), Validation loss= 1.455 (Acc= 24.39%)
[ 5.00%] Train loss= 1.239 (Acc= 41.07%), Validation loss= 1.364 (Acc= 31.71%)
[ 6.00%] Train loss= 1.223 (Acc= 61.31%), Validation loss= 1.381 (Acc= 34.15%)
[ 7.00%] Train loss= 1.19 (Acc= 57.74%), Validation loss= 1.406 (Acc= 39.02%)
[ 8.00%] Train loss= 1.166 (Acc= 53.57%), Validation loss= 1.428 (Acc= 36.59%)
[ 9.00%] Train loss= 1.088 (Acc= 64.88%), Validation loss= 1.37 (Acc= 29.27%)
[ 10.00%] Train loss= 1.069 (Acc= 67.86%), Validation loss= 1.352 (Acc= 39.02%)
[ 11.00%] Train loss= 1.011 (Acc= 65.48%), Validation loss= 1.358 (Acc= 43.90%)
[ 12.00%] Train loss= 0.9853 (Acc= 66.07%), Validation loss= 1.382 (Acc= 36.59%)
[ 13.00%] Train loss= 0.946 (Acc= 73.81%), Validation loss= 1.325 (Acc= 41.46%)
[ 14.00%] Train loss= 0.9135 (Acc= 75.60%), Validation loss= 1.329 (Acc= 41.46%)
[ 15.00%] Train loss= 0.8952 (Acc= 75.00%), Validation loss= 1.306 (Acc= 41.46%)
[ 16.00%] Train loss= 0.8605 (Acc= 76.79%), Validation loss= 1.299 (Acc= 43.90%)
[ 17.00%] Train loss= 0.8389 (Acc= 76.79%), Validation loss= 1.275 (Acc= 41.46%)
[ 18.00%] Train loss= 0.8207 (Acc= 76.19%), Validation loss= 1.357 (Acc= 39.02%)
[ 19.00%] Train loss= 0.7796 (Acc= 73.81%), Validation loss= 1.31 (Acc= 43.90%)
[ 20.00%] Train loss= 0.6983 (Acc= 86.31%), Validation loss= 1.325 (Acc= 41.46%)
[ 21.00%] Train loss= 0.7003 (Acc= 81.55%), Validation loss= 1.351 (Acc= 41.46%)
[ 22.00%] Train loss= 0.6698 (Acc= 85.12%), Validation loss= 1.3 (Acc= 43.90%)
[ 23.00%] Train loss= 0.7269 (Acc= 74.40%), Validation loss= 1.33 (Acc= 43.90%)
[ 24.00%] Train loss= 0.696 (Acc= 81.55%), Validation loss= 1.325 (Acc= 36.59%)
[ 25.00%] Train loss= 0.6905 (Acc= 80.36%), Validation loss= 1.51 (Acc= 39.02%)
[ 26.00%] Train loss= 0.5891 (Acc= 79.76%), Validation loss= 1.303 (Acc= 43.90%)
[ 27.00%] Train loss= 0.635 (Acc= 76.79%), Validation loss= 1.276 (Acc= 36.59%)
[ 28.00%] Train loss= 0.5582 (Acc= 86.90%), Validation loss= 1.47 (Acc= 39.02%)
[ 29.00%] Train loss= 0.6385 (Acc= 75.00%), Validation loss= 1.395 (Acc= 39.02%)
[ 30.00%] Train loss= 0.5849 (Acc= 82.74%), Validation loss= 1.29 (Acc= 41.46%)
[ 31.00%] Train loss= 0.6524 (Acc= 79.76%), Validation loss= 1.51 (Acc= 36.59%)
[ 32.00%] Train loss= 0.6003 (Acc= 80.95%), Validation loss= 1.336 (Acc= 41.46%)
[ 33.00%] Train loss= 0.4973 (Acc= 87.50%), Validation loss= 1.292 (Acc= 46.34%)
[ 34.00%] Train loss= 0.5211 (Acc= 83.93%), Validation loss= 1.427 (Acc= 46.34%)
[ 35.00%] Train loss= 0.4815 (Acc= 86.31%), Validation loss= 1.304 (Acc= 46.34%)
[ 36.00%] Train loss= 0.522 (Acc= 86.90%), Validation loss= 1.346 (Acc= 46.34%)
[ 37.00%] Train loss= 0.5074 (Acc= 80.95%), Validation loss= 1.431 (Acc= 43.90%)
[ 38.00%] Train loss= 0.4129 (Acc= 91.67%), Validation loss= 1.297 (Acc= 41.46%)
[ 39.00%] Train loss= 0.517 (Acc= 85.12%), Validation loss= 1.315 (Acc= 41.46%)
[ 40.00%] Train loss= 0.4432 (Acc= 85.71%), Validation loss= 1.414 (Acc= 41.46%)
[ 41.00%] Train loss= 0.4182 (Acc= 88.10%), Validation loss= 1.33 (Acc= 43.90%)
[ 42.00%] Train loss= 0.4394 (Acc= 88.10%), Validation loss= 1.345 (Acc= 41.46%)
[ 43.00%] Train loss= 0.4791 (Acc= 83.93%), Validation loss= 1.416 (Acc= 43.90%)
[ 44.00%] Train loss= 0.4631 (Acc= 85.71%), Validation loss= 1.45 (Acc= 39.02%)
[ 45.00%] Train loss= 0.4765 (Acc= 80.95%), Validation loss= 1.37 (Acc= 46.34%)
[ 46.00%] Train loss= 0.4197 (Acc= 87.50%), Validation loss= 1.345 (Acc= 41.46%)
[ 47.00%] Train loss= 0.4168 (Acc= 85.71%), Validation loss= 1.396 (Acc= 41.46%)
[ 48.00%] Train loss= 0.4388 (Acc= 88.10%), Validation loss= 1.397 (Acc= 43.90%)
[ 49.00%] Train loss= 0.4328 (Acc= 88.10%), Validation loss= 1.459 (Acc= 41.46%)
[ 50.00%] Train loss= 0.316 (Acc= 94.64%), Validation loss= 1.349 (Acc= 46.34%)
[ 51.00%] Train loss= 0.3433 (Acc= 88.69%), Validation loss= 1.355 (Acc= 48.78%)
[ 52.00%] Train loss= 0.3498 (Acc= 93.45%), Validation loss= 1.412 (Acc= 48.78%)
[ 53.00%] Train loss= 0.4013 (Acc= 87.50%), Validation loss= 1.493 (Acc= 46.34%)
[ 54.00%] Train loss= 0.371 (Acc= 87.50%), Validation loss= 1.328 (Acc= 48.78%)
[ 55.00%] Train loss= 0.3857 (Acc= 89.29%), Validation loss= 1.527 (Acc= 51.22%)
[ 56.00%] Train loss= 0.4304 (Acc= 83.93%), Validation loss= 1.565 (Acc= 48.78%)
[ 57.00%] Train loss= 0.4279 (Acc= 84.52%), Validation loss= 1.412 (Acc= 41.46%)
[ 58.00%] Train loss= 0.3969 (Acc= 85.71%), Validation loss= 1.428 (Acc= 41.46%)
[ 59.00%] Train loss= 0.4317 (Acc= 85.12%), Validation loss= 1.45 (Acc= 46.34%)
[ 60.00%] Train loss= 0.3444 (Acc= 89.29%), Validation loss= 1.392 (Acc= 48.78%)
[ 61.00%] Train loss= 0.3795 (Acc= 89.88%), Validation loss= 1.534 (Acc= 43.90%)
[ 62.00%] Train loss= 0.3485 (Acc= 88.10%), Validation loss= 1.568 (Acc= 34.15%)
[ 63.00%] Train loss= 0.3696 (Acc= 89.88%), Validation loss= 1.35 (Acc= 51.22%)
[ 64.00%] Train loss= 0.3626 (Acc= 88.69%), Validation loss= 1.393 (Acc= 41.46%)
[ 65.00%] Train loss= 0.3637 (Acc= 89.29%), Validation loss= 1.425 (Acc= 39.02%)
[ 66.00%] Train loss= 0.3611 (Acc= 88.10%), Validation loss= 1.518 (Acc= 46.34%)
[ 67.00%] Train loss= 0.3005 (Acc= 91.07%), Validation loss= 1.632 (Acc= 34.15%)
[ 68.00%] Train loss= 0.3463 (Acc= 89.29%), Validation loss= 1.438 (Acc= 51.22%)
[ 69.00%] Train loss= 0.3334 (Acc= 91.07%), Validation loss= 1.385 (Acc= 48.78%)
[ 70.00%] Train loss= 0.3236 (Acc= 89.88%), Validation loss= 1.557 (Acc= 46.34%)
[ 71.00%] Train loss= 0.3246 (Acc= 89.29%), Validation loss= 1.571 (Acc= 34.15%)
[ 72.00%] Train loss= 0.2648 (Acc= 91.07%), Validation loss= 1.613 (Acc= 39.02%)
[ 73.00%] Train loss= 0.3297 (Acc= 94.05%), Validation loss= 1.544 (Acc= 43.90%)
[ 74.00%] Train loss= 0.3139 (Acc= 91.67%), Validation loss= 1.499 (Acc= 41.46%)
[ 75.00%] Train loss= 0.3491 (Acc= 88.10%), Validation loss= 1.674 (Acc= 41.46%)
[ 76.00%] Train loss= 0.3893 (Acc= 85.71%), Validation loss= 1.603 (Acc= 43.90%)
[ 77.00%] Train loss= 0.3154 (Acc= 88.69%), Validation loss= 1.612 (Acc= 48.78%)
[ 78.00%] Train loss= 0.4907 (Acc= 83.93%), Validation loss= 1.438 (Acc= 39.02%)
[ 79.00%] Train loss= 0.2842 (Acc= 92.86%), Validation loss= 2.057 (Acc= 29.27%)
[ 80.00%] Train loss= 0.4921 (Acc= 81.55%), Validation loss= 1.594 (Acc= 41.46%)
[ 81.00%] Train loss= 0.3356 (Acc= 89.29%), Validation loss= 1.426 (Acc= 41.46%)
[ 82.00%] Train loss= 0.2576 (Acc= 93.45%), Validation loss= 1.452 (Acc= 48.78%)
[ 83.00%] Train loss= 0.3234 (Acc= 89.29%), Validation loss= 1.633 (Acc= 39.02%)
[ 84.00%] Train loss= 0.2623 (Acc= 90.48%), Validation loss= 1.415 (Acc= 51.22%)
[ 85.00%] Train loss= 0.3058 (Acc= 89.29%), Validation loss= 1.477 (Acc= 48.78%)
[ 86.00%] Train loss= 0.2876 (Acc= 89.88%), Validation loss= 1.536 (Acc= 41.46%)
[ 87.00%] Train loss= 0.3292 (Acc= 89.29%), Validation loss= 1.528 (Acc= 41.46%)
[ 88.00%] Train loss= 0.2311 (Acc= 92.86%), Validation loss= 1.574 (Acc= 43.90%)
[ 89.00%] Train loss= 0.2575 (Acc= 93.45%), Validation loss= 1.494 (Acc= 48.78%)
[ 90.00%] Train loss= 0.3032 (Acc= 90.48%), Validation loss= 1.603 (Acc= 39.02%)
[ 91.00%] Train loss= 0.3307 (Acc= 89.88%), Validation loss= 1.552 (Acc= 41.46%)
[ 92.00%] Train loss= 0.3068 (Acc= 90.48%), Validation loss= 1.69 (Acc= 43.90%)
[ 93.00%] Train loss= 0.3022 (Acc= 88.10%), Validation loss= 1.665 (Acc= 46.34%)
[ 94.00%] Train loss= 0.3037 (Acc= 91.07%), Validation loss= 1.711 (Acc= 39.02%)
[ 95.00%] Train loss= 0.2834 (Acc= 92.86%), Validation loss= 1.71 (Acc= 39.02%)
[ 96.00%] Train loss= 0.3441 (Acc= 89.88%), Validation loss= 1.539 (Acc= 41.46%)
[ 97.00%] Train loss= 0.2494 (Acc= 92.26%), Validation loss= 1.643 (Acc= 41.46%)
[ 98.00%] Train loss= 0.3221 (Acc= 88.10%), Validation loss= 1.902 (Acc= 36.59%)
[ 99.00%] Train loss= 0.2736 (Acc= 91.67%), Validation loss= 1.635 (Acc= 36.59%)
[ 100.00%] Train loss= 0.2898 (Acc= 88.10%), Validation loss= 1.58 (Acc= 39.02%)

Tissue classification with the MoNuSAC dataset

# Accuracy tracker for the test split, placed on the GPU when one is available.
test_acc_metric = Accuracy(task="multiclass", num_classes=4)
if torch.cuda.is_available():
    test_acc_metric.cuda()

# Running sum of per-sample losses and the number of samples seen so far.
avg_test_loss, total_test_samples = 0, 0

dl_model.eval()
with torch.no_grad():
    for x, y in test_dl:
        if torch.cuda.is_available():
            x, y = x.cuda(), y.cuda()

        y_hat = dl_model(x)
        loss = loss_fun(y_hat, y)

        # Weight the batch loss by the batch size before accumulating it.
        samples_in_batch = len(y)
        avg_test_loss += loss.item() * samples_in_batch
        total_test_samples += samples_in_batch

        test_acc_metric(y_hat.softmax(dim=1), y)

# Turn the accumulated sum into the per-sample average.
avg_test_loss /= total_test_samples
test_acc = test_acc_metric.compute()

print(f"Test loss={avg_test_loss: 2.4} (Acc={test_acc: 2.2%})")

test_acc_metric.reset()
Test loss= 1.737 (Acc= 37.62%)

Tissue classification with the MoNuSAC dataset

# Persist the model's state dict (not the module object itself) to disk.
checkpoint = dl_model.state_dict()
torch.save(obj=checkpoint, f="monusac_checkpoint.pt")

Explore the MoNuSAC dataset in the embedded feature space

Explore the MoNuSAC dataset in the embedded feature space

import numpy as np

if torch.cuda.is_available():
    # Moving the whole model also moves its `fc` sub-module.
    dl_model.cuda()

# Swap the last layer of the head for a pass-through, so the model now returns
# the output of the layers that precede it.
dl_model.fc[-1] = nn.Identity()

feature_batches = []
label_batches = []

dl_model.eval()
with torch.no_grad():
    for x, y in train_dl:
        if torch.cuda.is_available():
            x = x.cuda()

        # Features and labels are appended batch by batch in the same order,
        # so they stay aligned after concatenation.
        feature_batches.append(dl_model(x).cpu().detach().numpy())
        label_batches.append(y.numpy())

train_features = np.concatenate(feature_batches, axis=0)
train_labels = np.concatenate(label_batches, axis=0)

Explore the MoNuSAC dataset in the embedded feature space

import umap

# Fit a UMAP reducer with its default settings on the training features and
# keep the fitted reducer in `reducer`.
reducer = umap.UMAP()
embedding = reducer.fit_transform(X=train_features)

embedding.shape

Explore the MoNuSAC dataset in the embedded feature space

Code
import matplotlib.pyplot as plt

# Scatter the two embedding coordinates, coloured by label.
emb_plot = plt.scatter(x=embedding[:, 0], y=embedding[:, 1], c=train_labels, marker="o")

# One legend handle per colour drawn by the scatter, labelled with the entries
# of `tissue_classes`.
class_handles = emb_plot.legend_elements()[0]
plt.legend(handles=class_handles, labels=tissue_classes)
plt.gca().set_aspect('equal', 'datalim')
plt.title('UMAP projection of InceptionV3 features of the MoNuSAC dataset', fontsize=24)
Text(0.5, 1.0, 'UMAP projection of InceptionV3 features of the MoNuSAC dataset')

Explore the MoNuSAC dataset in the embedded feature space

def _extract_features(model, loader):
    """Run `model` over every batch of `loader` and collect its outputs.

    Args:
        model: callable module applied to each batch of inputs.
        loader: iterable yielding `(inputs, labels)` batches of tensors.

    Returns:
        A `(features, labels)` pair of NumPy arrays, each the concatenation
        along axis 0 of the per-batch results, in the order the loader
        yielded them.
    """
    feature_batches = []
    label_batches = []

    # The availability check does not change between batches; query it once.
    use_cuda = torch.cuda.is_available()

    with torch.no_grad():
        for x, y in loader:
            if use_cuda:
                x = x.cuda()

            # Outputs produced under `no_grad` carry no autograd history, so
            # they can be converted to NumPy right after moving to the CPU.
            feature_batches.append(model(x).cpu().numpy())
            label_batches.append(y.numpy())

    return np.concatenate(feature_batches, 0), np.concatenate(label_batches, 0)


# The same extraction is applied to both held-out splits.
val_features, val_labels = _extract_features(dl_model, val_dl)
test_features, test_labels = _extract_features(dl_model, test_dl)

Explore the MoNuSAC dataset in the embedded feature space

# Project the validation and test features with the reducer that was fitted
# on the training features (validation first, then test).
embedding_val, embedding_test = (
    reducer.transform(split_features)
    for split_features in (val_features, test_features)
)
Code
fig, (ax_0, ax_1) = plt.subplots(1, 2)

# Each panel overlays one held-out split on top of the training embedding:
# (axes, held-out embedding, held-out labels, marker, legend title).
panels = (
    (ax_0, embedding_val, val_labels, "v", "Validation dataset"),
    (ax_1, embedding_test, test_labels, "s", "Test dataset"),
)

for panel_ax, held_out_emb, held_out_labels, held_out_marker, held_out_title in panels:
    train_points = panel_ax.scatter(embedding[:, 0], embedding[:, 1], c=train_labels, marker="o", alpha=0.5)
    train_legend = panel_ax.legend(
        handles=train_points.legend_elements()[0],
        labels=tissue_classes,
        loc="lower left",
        title="Train dataset",
    )
    # Re-add the legend as a plain artist so that the next `legend` call on
    # the same axes does not replace it.
    panel_ax.add_artist(train_legend)

    held_out_points = panel_ax.scatter(held_out_emb[:, 0], held_out_emb[:, 1], c=held_out_labels, marker=held_out_marker, alpha=0.5)
    held_out_legend = panel_ax.legend(
        handles=held_out_points.legend_elements()[0],
        labels=tissue_classes,
        loc="lower right",
        title=held_out_title,
    )
    panel_ax.add_artist(held_out_legend)

    panel_ax.set_aspect('equal', 'datalim')

plt.title('UMAP projection of InceptionV3 features of the MoNuSAC dataset', fontsize=24)
Text(0.5, 1.0, 'UMAP projection of InceptionV3 features of the MoNuSAC dataset')