2D Classification Example on MedNIST with rising

Welcome to this rising example, where we will build a 2D classification pipeline with rising and PyTorch Lightning. The dataset part of this notebook was inspired by the MONAI MedNIST example, so make sure to check that out, too :D


Let’s start with some basic preparations of our environment and download the MedNIST data.

First, we will install rising’s master branch to get the latest features (if you are not planning to extend rising, you can easily install our PyPI package with pip install rising).

[ ]:
!pip install --upgrade --quiet git+ # for data handling
!pip install --upgrade --quiet pytorch-lightning # for easy training
!pip install --upgrade --quiet scikit-learn # for classification metrics

Next, we will add some magic to our notebook in case you are running it locally and do not want to refresh it all the time.

[ ]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline

Finally, we download the MedNIST dataset and unpack it.

[ ]:
import os

# Only check after the else statement for the data download :)
if 'CI' in os.environ:
    # our notebooks are executed in CI to test our examples
    # for this we need to create some dummy data
    import matplotlib.pyplot as plt
    from pathlib import Path
    import numpy as np
    from PIL import Image

    # create dummy data for our CI
    base_dir = Path("./MedNIST")
    cls_path1 = base_dir / "AbdomenCT"
    cls_path2 = base_dir / "BreastMRI"
    cls_path1.mkdir(parents=True, exist_ok=True)
    cls_path2.mkdir(parents=True, exist_ok=True)

    for i in range(100):
        np_array = np.zeros((64, 64)).astype(np.uint8)
        img = Image.fromarray(np_array)
        img.save(str(cls_path1 / f"img{i}.png"))
        # plt.imsave(str(cls_path1 / f"img{i}.png"), np_array, cmap='Greys')
    for i in range(100):
        np_array = np.ones((64, 64)).astype(np.uint8)
        img = Image.fromarray(np_array)
        img.save(str(cls_path2 / f"img{i}.png"))
        # plt.imsave(str(cls_path2 / f"img{i}.png"), np_array, cmap='Greys')
else:
    # download MedNIST
    !curl -L -o MedNIST.tar.gz ''

    # unpack the '.tar.gz' file to the current directory
    import tarfile
    datafile = tarfile.open("MedNIST.tar.gz")
    datafile.extractall()
    datafile.close()

Preparing our datasets

If you have already written your own datasets with PyTorch, this will be very familiar because rising uses the same dataset structure as PyTorch. The only difference between native PyTorch and rising is the transformation part. While PyTorch embeds its transformations into the dataset, we opted to move the transformations to our dataloader (which is a direct subclass of PyTorch’s dataloader) to make our datasets easily interchangeable between multiple tasks.
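To make that contrast concrete, here is a minimal sketch (with made-up toy data, independent of MedNIST) of the map-style dataset protocol that both PyTorch and rising expect: just `__len__` and `__getitem__`, with no transforms baked in.

```python
# minimal map-style dataset: only __len__ and __getitem__ are required;
# transforms are deliberately absent -- with rising they live in the DataLoader
class ToyDataset:
    def __init__(self, values, labels):
        assert len(values) == len(labels)
        self.values = values
        self.labels = labels

    def __len__(self):
        return len(self.values)

    def __getitem__(self, index):
        # return one sample as a dict, the same layout our MedNIST dataset uses
        return {"data": self.values[index], "label": self.labels[index]}

ds = ToyDataset([0.1, 0.2, 0.3], [0, 1, 0])
print(len(ds))  # 3
print(ds[1])    # {'data': 0.2, 'label': 1}
```

Because the dataset carries no transforms, the same class can be reused for training, validation and testing with different transform stacks attached at the dataloader level.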

Let’s start by searching for the paths of the image files and defining their classes.

[ ]:
import os
from pathlib import Path
from PIL import Image

data_dir = Path('./MedNIST/')
class_names = sorted([p.stem for p in data_dir.iterdir() if p.is_dir()])
num_class = len(class_names)

image_files = [[x for x in (data_dir / class_name).iterdir()] for class_name in class_names]

image_file_list = []
image_label_list = []
for i, class_name in enumerate(class_names):
    image_file_list.extend(image_files[i])
    image_label_list.extend([i] * len(image_files[i]))

num_total = len(image_label_list)

print('Total image count:', num_total)
print("Label names:", class_names)
print("Label counts:", [len(image_files[i]) for i in range(num_class)])

The output should look like this:

Total image count: 58954
Label names: ['AbdomenCT', 'BreastMRI', 'CXR', 'ChestCT', 'Hand', 'HeadCT']
Label counts: [10000, 8954, 10000, 10000, 10000, 10000]

The downloaded data needs to be divided into 3 subsets for training, validation and testing. Because the dataset is fairly large we can opt for an 80/10/10 split.
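As a quick sanity check on those fractions, the expected subset sizes for the 58954 images work out as follows (the actual counts will vary slightly, because each sample is assigned by an independent random draw):

```python
# expected subset sizes for an 80/10/10 split of the 58954 MedNIST images
num_total = 58954
valid_frac, test_frac = 0.1, 0.1
train_frac = 1.0 - valid_frac - test_frac

expected = {"train": round(num_total * train_frac),
            "val": round(num_total * valid_frac),
            "test": round(num_total * test_frac)}
print(expected)  # {'train': 47163, 'val': 5895, 'test': 5895}
```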

[ ]:
import numpy as np

valid_frac, test_frac = 0.1, 0.1
trainX, trainY = [], []
valX, valY = [], []
testX, testY = [], []

for i in range(num_total):
    rann = np.random.random()
    if rann < valid_frac:
        valX.append(image_file_list[i])
        valY.append(image_label_list[i])
    elif rann < test_frac + valid_frac:
        testX.append(image_file_list[i])
        testY.append(image_label_list[i])
    else:
        trainX.append(image_file_list[i])
        trainY.append(image_label_list[i])

print("Training count =", len(trainX), "Validation count =", len(valX), "Test count =", len(testX))

The MedNIST dataset now just needs to load the specified files. We use PIL to load the individual image files and convert them to tensors afterwards.

[ ]:
import torch
import numpy as np

from typing import Sequence, Dict
from PIL import Image
from torch.utils.data import Dataset

class MedNISTDataset(Dataset):
  """
  Simple dataset to load individual samples from the dataset
  """

  def __init__(self, image_files: Sequence[str], labels: Sequence[int]):
    """
    Args:
      image_files: paths to the image files
      labels: label for each file
    """
    assert len(image_files) == len(labels), "Every file needs a label"
    self.image_files = image_files
    self.labels = labels

  def __len__(self) -> int:
    """
    Number of samples inside the dataset

    Returns:
      int: length
    """
    return len(self.image_files)

  def __getitem__(self, index: int) -> Dict[str, torch.Tensor]:
    """
    Select an individual sample from the dataset

    Args:
      index: index of sample to draw

    Returns:
      Dict[str, torch.Tensor]: single sample
        * `data`: image data
        * `label`: label for sample
    """
    data_np = np.array(Image.open(self.image_files[index]))
    return {"data": torch.from_numpy(data_np)[None].float(),
            "label": torch.tensor(self.labels[index]).long()}

train_ds = MedNISTDataset(trainX, trainY)
val_ds = MedNISTDataset(valX, valY)
test_ds = MedNISTDataset(testX, testY)

Let's look at some basic statistics of a single sample.

[ ]:
print(f'Single image min: {train_ds[0]["data"].min()}')
print(f'Single image max: {train_ds[0]["data"].max()}')
print(f'Single image shape: {train_ds[0]["data"].shape} (C, W, H)')
print(f'Example label: {train_ds[0]["label"]}')
print(f'Example data: {train_ds[0]["data"]}')

The output could look something like this:

Single image min: 87.0
Single image max: 255.0
Single image shape: torch.Size([1, 64, 64]) (C, W, H)
Example label: 0
Example data: tensor([[[101., 101., 101.,  ..., 101., 101., 101.],
         [101., 101., 101.,  ..., 101., 101., 101.],
         [101., 101., 101.,  ..., 101., 101., 101.],
         [102., 101.,  99.,  ..., 111., 103.,  98.],
         [102., 101., 100.,  ...,  99.,  98.,  98.],
         [ 99., 100., 102.,  ..., 101., 103., 105.]]])

Setting Up our Dataloading and Transformations

In this section we will define our transformations and plug our dataset into the dataloader of rising.

First, we set up our transformations. In general, these can be split into two parts: transformations which are applied as preprocessing and transformations which are applied as augmentations. All transformations are applied in a batched fashion to the dataset to fully utilize vectorization to speed up augmentation. In case your dataset needs additional preprocessing on a per-sample basis, you can also add those to the dataloader with sample_transforms. Check out our 3D Segmentation Tutorial for more information about that.
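The speed argument for batched transforms can be sketched with plain NumPy (a standalone illustration, not rising's actual implementation): adding Gaussian noise to a whole batch is a single vectorized operation instead of a Python loop over 32 individual samples.

```python
import numpy as np

rng = np.random.default_rng(0)
batch = np.zeros((32, 1, 64, 64), dtype=np.float32)  # [N, C, H, W]

# one vectorized op over the full batch -- no per-sample python loop
noisy = batch + rng.normal(0.0, 0.01, size=batch.shape).astype(np.float32)

print(noisy.shape)  # (32, 1, 64, 64)
```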

[ ]:
import rising.transforms as rtr
from rising.random import UniformParameter

transforms_prep = []
transforms_augment = []

# preprocessing transforms
# transforms_prep.append(rtr.NormZeroMeanUnitStd())
transforms_prep.append(rtr.NormMinMax()) # visualization looks nicer :)

# augmentation transforms
transforms_augment.append(rtr.GaussianNoise(0., 0.01))
transforms_augment.append(rtr.GaussianSmoothing(
    in_channels=1, kernel_size=3, std=0.5, padding=1))
transforms_augment.append(rtr.Rot90((0, 1)))
transforms_augment.append(rtr.Mirror(dims=(0, 1)))
transforms_augment.append(rtr.BaseAffine(
    scale=UniformParameter(0.8, 1.2),
    rotation=UniformParameter(-30, 30), degree=True,
    # translation in the base affine is normalized to the image size
    # the Translation transform offers the option to switch to pixels
    translation=UniformParameter(-0.02, 0.02),
))

In contrast to native PyTorch, we add our transformations to the dataloader of rising. There are three main types of transformations which can be added:

* sample_transforms: these transforms are applied per sample. In case the transformation assumes a batch of data, pseudo_batch_dim can be activated to automatically add a batch dim to single samples.
* batch_transforms: these transforms are executed per batch inside the multiprocessing context of the CPU (like sample_transforms).
* gpu_transforms: these transforms are executed on the GPU. In case you have multiple GPUs, make sure to set the correct device, otherwise rising could use the wrong GPU.

[ ]:
from rising.loading import DataLoader

tr_transform = rtr.Compose(transforms_prep + transforms_augment)
dataloader_tr = DataLoader(train_ds, batch_size=32, shuffle=True,
                           batch_transforms=tr_transform, num_workers=4)

val_transform = rtr.Compose(transforms_prep)
dataloader_val = DataLoader(val_ds, batch_size=32,
                            batch_transforms=val_transform, num_workers=4)

test_transform = rtr.Compose(transforms_prep)
dataloader_ts = DataLoader(test_ds, batch_size=32,
                           batch_transforms=test_transform, num_workers=4)

Looking at some example outputs

In this short section we will visualize some of the batches to look at the influence of the augmentations.

[ ]:
# helper function to visualize batches of images
import torch
import torchvision
import matplotlib.pyplot as plt

def show_batch(batch: torch.Tensor, norm: bool = True):
  """
  Visualize a single batch of images

  Args:
    batch: batch of data
    norm: normalize to range [0, 1] for visualization purposes
  """
  grid = torchvision.utils.make_grid(batch.cpu(), nrow=8)

  if norm:
    grid -= grid.min()
    m = grid.max()
    if m > 1e-6:
      grid = grid / m

  plt.imshow(grid[0], cmap='gray', vmin=0, vmax=1)
  plt.show()
[ ]:
# make dataset iterable
_iter = iter(dataloader_tr)
[ ]:
# visualize batch of images
batch = next(_iter)
print({f'{key}_shape: {tuple(batch[key].shape)}' for key, item in batch.items()})
print(f'Batch labels: \n{batch["label"]}')
print(f'Batch mean {batch["data"].mean()}')
print(f'Batch min {batch["data"].min()}')
print(f'Batch max {batch["data"].max()}')

show_batch(batch["data"], norm=True)

The output of the visualization could look something like this:

Example Batch

The exact images will vary because the batch was selected from the training dataloader which shuffles the data.

Defining our Lightning Module

We will use pytorch-lightning as our trainer framework to save some time and to standardize our pipeline.

In lightning, training models are derived from pytorch_lightning.LightningModule, which enforces a specific structure of the code to increase reproducibility and standardization across the community. For simplicity, we will simply load a torchvision model and overwrite the basic *_step functions of lightning. If you want more information on how to build pipelines with pytorch lightning, please check out their documentation.

[ ]:
import torch.nn as nn
import torchvision.models as models

if 'CI' in os.environ:
    # use a very small model for CI
    class SuperSmallModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.conv1 = nn.Conv2d(1, 16, 3, 1, 1)
            self.conv2 = nn.Conv2d(16, 32, 3, 1, 1)
            self.pool1 = nn.AdaptiveAvgPool2d((1, 1))
            self.fc = nn.Linear(32, num_class)

        def forward(self, x):
            x = self.conv1(x)
            x = self.conv2(x)
            x = torch.flatten(self.pool1(x), 1)
            return self.fc(x)

    resnet = SuperSmallModel()
else:
    # resnet18 for normal runs
    resnet = models.resnet18(pretrained=False)
    # change first layer to accept single-channel input
    resnet.conv1 = torch.nn.Conv2d(
        1, 64, kernel_size=7, stride=2, padding=3, bias=False)
    # change last layer to output num_class logits
    fc_in = resnet.fc.in_features
    resnet.fc = nn.Linear(fc_in, num_class)
[ ]:
import torch.nn.functional as F
import pytorch_lightning as pl

from sklearn.metrics import classification_report
from typing import Dict, Optional

class SimpleClassifier(pl.LightningModule):
  def __init__(self, hparams: Optional[dict] = None):
    """
    Hyperparameters for our model

    Args:
      hparams: hyperparameters for model
        `lr`: learning rate for optimizer
    """
    super().__init__()
    if hparams is None:
        hparams = {}
    self.hparams = hparams
    self.model = resnet

  def forward(self, x: torch.Tensor) -> torch.Tensor:
    """
    Forward input batch of data through model

    Args:
      x: input batch of data [N, C, H, W]
        N batch size (here 32); C number of channels (here 1);
        H,W spatial dimensions of images (here 64x64)

    Returns:
      torch.Tensor: classification logits [N, num_classes]
    """
    return self.model(x)

  def training_step(self, batch: Dict[str, torch.Tensor], batch_idx: int) -> Dict:
    """
    Forward batch and compute loss for a single step (used for training)

    Args:
      batch: batch to process
        `data`: input data
        `label`: expected labels
      batch_idx: index of batch
    """
    x, y = batch["data"], batch["label"]
    y_hat = self(x)
    loss = F.cross_entropy(y_hat, y)
    tensorboard_logs = {'train_loss': loss}
    return {'loss': loss, 'log': tensorboard_logs}

  def validation_step(self, batch: Dict[str, torch.Tensor], batch_idx: int) -> Dict:
    """
    Forward batch and compute loss for a single step (used for validation)

    Args:
      batch: batch to process
        `data`: input data
        `label`: expected labels
      batch_idx: index of batch
    """
    x, y = batch["data"], batch["label"]
    y_hat = self(x)
    val_loss = F.cross_entropy(y_hat, y)
    return {'val_loss': val_loss}

  def validation_epoch_end(self, outputs):
    """
    Compute average validation loss during epoch
    """
    avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
    tensorboard_logs = {'val_loss': avg_loss}
    return {'val_loss': avg_loss, 'log': tensorboard_logs}

  def test_step(self, batch: Dict[str, torch.Tensor], batch_idx: int) -> Dict:
    """
    Forward batch and compute loss for a single step (used for testing)

    Args:
      batch: batch to process
        `data`: input data
        `label`: expected labels
      batch_idx: index of batch
    """
    x, y = batch["data"], batch["label"]
    y_hat = self(x)
    test_loss = F.cross_entropy(y_hat, y)
    return {'test_loss': test_loss,
            "pred_label": y_hat.max(dim=1)[1].detach().cpu(),
            "label": y.detach().cpu()}

  def test_epoch_end(self, outputs):
    """
    Compute average test loss and classification metrics
    """
    avg_loss = torch.stack([x['test_loss'] for x in outputs]).mean()
    tensorboard_logs = {'test_loss': avg_loss}

    all_pred_label = torch.cat([x['pred_label'] for x in outputs])
    all_label = torch.cat([x['label'] for x in outputs])
    print(classification_report(all_label.numpy(), all_pred_label.numpy(),
                                target_names=class_names, digits=4))

    return {'test_loss': avg_loss, 'log': tensorboard_logs}

  def configure_optimizers(self):
    """
    Setup optimizer for training
    """
    return torch.optim.Adam(self.parameters(), lr=self.hparams.get("lr", 1e-5))

We can visualize our training progress and hyperparameters in tensorboard to easily compare multiple runs of our classifier.

[ ]:
# Start tensorboard.
%reload_ext tensorboard
%tensorboard --logdir lightning_logs/

Let’s start our training :D

[ ]:
from pytorch_lightning import Trainer

model = SimpleClassifier()

if torch.cuda.is_available():
    gpus = [0]
else:
    gpus = None

# most basic trainer, uses good defaults
trainer = Trainer(gpus=gpus, progress_bar_refresh_rate=10, max_epochs=4, weights_summary=None)
trainer.fit(model, train_dataloader=dataloader_tr, val_dataloaders=dataloader_val)

After training our model we can test it on our test data.

[ ]:
trainer.test(model, test_dataloaders=dataloader_ts)
The results on the test data should look similar to this:

              precision    recall  f1-score   support

   AbdomenCT     0.9536    0.9990    0.9758      1008
   BreastMRI     1.0000    1.0000    1.0000       830
         CXR     0.9960    0.9872    0.9916      1013
     ChestCT     1.0000    0.9490    0.9738       961
        Hand     0.9877    0.9887    0.9882       975
      HeadCT     0.9912    1.0000    0.9956      1019

    accuracy                         0.9873      5806
   macro avg     0.9881    0.9873    0.9875      5806
weighted avg     0.9876    0.9873    0.9872      5806
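To read the table: each per-class row is derived from that class's true positives, false positives and false negatives. A small worked example with made-up counts (not the actual MedNIST confusion matrix):

```python
# hypothetical counts for one class: 95 correct detections,
# 5 wrong detections, 3 missed samples
tp, fp, fn = 95, 5, 3

precision = tp / (tp + fp)                          # fraction of predictions that are right
recall = tp / (tp + fn)                             # fraction of true samples found
f1 = 2 * precision * recall / (precision + recall)  # harmonic mean of the two

print(round(precision, 4), round(recall, 4), round(f1, 4))  # 0.95 0.9694 0.9596
```

The support column is simply the number of true samples per class, and the weighted average weights each class's metrics by its support.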

© Copyright (c) 2019-2020, Justus Schock, Michael Baumgartner. Revision ca0cf77f.
