2D Classification Example on MedNIST and rising
Welcome to this rising example, where we will build a 2D classification pipeline with rising and PyTorch Lightning. The dataset part of this notebook was inspired by the MONAI MedNIST example, so make sure to check that out, too :D
Preparation
Let’s start with some basic preparations of our environment and download the MedNIST data.
First, we will install rising's master branch to get the latest features (if you are not planning to extend rising, you can simply install our PyPI package with pip install rising).
[ ]:
!pip install --upgrade --quiet git+https://github.com/PhoenixDL/rising # for data handling
!pip install --upgrade --quiet pytorch-lightning # for easy training
!pip install --upgrade --quiet scikit-learn # for classification metrics
Next, we will add some magic to our notebook in case you are running it locally and do not want to reload it all the time.
[ ]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline
Finally, we download the MedNIST dataset and unpack it.
[ ]:
# Only check after the else statement for the data download :)
if True:  # 'CI' in os.environ:
    # our notebooks are executed to test our example
    # for this we need to create some dummy data
    import matplotlib.pyplot as plt
    from pathlib import Path
    import numpy as np
    from PIL import Image
    # create dummy data for our CI
    base_dir = Path("./MedNIST")
    base_dir.mkdir(exist_ok=True)
    cls_path1 = base_dir / "AbdomenCT"
    cls_path1.mkdir(exist_ok=True)
    cls_path2 = base_dir / "BreastMRI"
    cls_path2.mkdir(exist_ok=True)
    for i in range(100):
        np_array = np.zeros((64, 64)).astype(np.uint8)
        img = Image.fromarray(np_array)
        img.save(cls_path1 / f"img{i}.png")
        # plt.imsave(str(cls_path1 / f"img{i}.png"), np_array, cmap='Greys')
    for i in range(100):
        np_array = np.ones((64, 64)).astype(np.uint8)
        img = Image.fromarray(np_array)
        img.save(cls_path2 / f"img{i}.png")
        # plt.imsave(str(cls_path2 / f"img{i}.png"), np_array, cmap='Greys')
else:
    # download MedNIST
    !curl -L -o MedNIST.tar.gz 'https://www.dropbox.com/s/5wwskxctvcxiuea/MedNIST.tar.gz'
    # unzip the '.tar.gz' file to the current directory
    import tarfile
    datafile = tarfile.open("MedNIST.tar.gz")
    datafile.extractall()
    datafile.close()
Preparing our datasets
If you have already written your own datasets with PyTorch, this will be very familiar because rising uses the same dataset structure as PyTorch. The only difference between native PyTorch and rising is the transformation part. While PyTorch embeds its transformations into the dataset, we opted to move the transformations to our dataloader (which is a direct subclass of PyTorch's dataloader) to make our datasets easily interchangeable between multiple tasks.
Let’s start by searching for the paths of the image files and defining their classes.
[ ]:
import os
from pathlib import Path
from PIL import Image
data_dir = Path('./MedNIST/')
class_names = sorted([p.stem for p in data_dir.iterdir() if p.is_dir()])
num_class = len(class_names)
image_files = [[x for x in (data_dir / class_name).iterdir()] for class_name in class_names]
image_file_list = []
image_label_list = []
for i, class_name in enumerate(class_names):
    image_file_list.extend(image_files[i])
    image_label_list.extend([i] * len(image_files[i]))
num_total = len(image_label_list)
print('Total image count:', num_total)
print("Label names:", class_names)
print("Label counts:", [len(image_files[i]) for i in range(num_class)])
The output should look like this:
Total image count: 58954
Label names: ['AbdomenCT', 'BreastMRI', 'CXR', 'ChestCT', 'Hand', 'HeadCT']
Label counts: [10000, 8954, 10000, 10000, 10000, 10000]
The downloaded data needs to be divided into 3 subsets for training, validation and testing. Because the dataset is fairly large we can opt for an 80/10/10 split.
[ ]:
import numpy as np
valid_frac, test_frac = 0.1, 0.1
trainX, trainY = [], []
valX, valY = [], []
testX, testY = [], []
for i in range(num_total):
    rann = np.random.random()
    if rann < valid_frac:
        valX.append(image_file_list[i])
        valY.append(image_label_list[i])
    elif rann < test_frac + valid_frac:
        testX.append(image_file_list[i])
        testY.append(image_label_list[i])
    else:
        trainX.append(image_file_list[i])
        trainY.append(image_label_list[i])
print("Training count =",len(trainX),"Validation count =", len(valX), "Test count =",len(testX))
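The per-sample random draw above yields split sizes that only approximate 80/10/10 and that change on every run. A minimal sketch of a seeded alternative, assuming that shuffling the indices once is acceptable for your use case (the helper `split_indices` and its parameters are hypothetical and not part of rising):

```python
import random
from typing import List, Tuple


def split_indices(num_total: int, valid_frac: float = 0.1, test_frac: float = 0.1,
                  seed: int = 0) -> Tuple[List[int], List[int], List[int]]:
    """Shuffle indices once with a fixed seed and cut them into three disjoint parts."""
    indices = list(range(num_total))
    random.Random(seed).shuffle(indices)  # local RNG, the global random state is untouched
    num_val = int(num_total * valid_frac)
    num_test = int(num_total * test_frac)
    val_idx = indices[:num_val]
    test_idx = indices[num_val:num_val + num_test]
    train_idx = indices[num_val + num_test:]
    return train_idx, val_idx, test_idx


train_idx, val_idx, test_idx = split_indices(200)
print(len(train_idx), len(val_idx), len(test_idx))
```

The returned index lists can then be used to pick the matching entries from image_file_list and image_label_list.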
The MedNIST dataset now just needs to load the specified files. We use PIL to load the individual image file and convert it to a tensor afterwards.
[ ]:
import torch
import numpy as np
from typing import Sequence, Dict
from torch.utils.data import Dataset
class MedNISTDataset(Dataset):
    """
    Simple dataset to load individual samples from the dataset
    """
    def __init__(self, image_files: Sequence[str], labels: Sequence[int]):
        """
        Args:
            image_files: paths to the image files
            labels: label for each file
        """
        assert len(image_files) == len(labels), "Every file needs a label"
        self.image_files = image_files
        self.labels = labels

    def __len__(self) -> int:
        """
        Number of samples inside the dataset
        Returns:
            int: length
        """
        return len(self.image_files)

    def __getitem__(self, index: int) -> Dict[str, torch.Tensor]:
        """
        Select an individual sample from the dataset
        Args:
            index: index of sample to draw
        Return:
            Dict[str, torch.Tensor]: single sample
                * `data`: image data
                * `label`: label for sample
        """
        data_np = np.array(Image.open(self.image_files[index]))
        # [None] adds the channel dimension: (H, W) -> (1, H, W)
        return {"data": torch.from_numpy(data_np)[None].float(),
                "label": torch.tensor(self.labels[index]).long()}
train_ds = MedNISTDataset(trainX, trainY)
val_ds = MedNISTDataset(valX, valY)
test_ds = MedNISTDataset(testX, testY)
Let's see some basic statistics of a single sample.
[ ]:
print(f'Single image min: {train_ds[0]["data"].min()}')
print(f'Single image max: {train_ds[0]["data"].max()}')
print(f'Single image shape: {train_ds[0]["data"].shape} (C, H, W)')
print(f'Example label {train_ds[0]["label"]}')
print(f'Example data: {train_ds[0]["data"]}')
The output could look something like this:
Single image min: 87.0
Single image max: 255.0
Single image shape: torch.Size([1, 64, 64]) (C, H, W)
Example label 0
Example data: tensor([[[101., 101., 101., ..., 101., 101., 101.],
[101., 101., 101., ..., 101., 101., 101.],
[101., 101., 101., ..., 101., 101., 101.],
...,
[102., 101., 99., ..., 111., 103., 98.],
[102., 101., 100., ..., 99., 98., 98.],
[ 99., 100., 102., ..., 101., 103., 105.]]])
Setting Up our Dataloading and Transformations
In this section we will define our transformations and plug our dataset into the dataloader of rising.
First we set up our transformations. In general these can be split into two parts: transformations which are applied as preprocessing and transformations which are applied as augmentations. All transformations are applied in a batched fashion to the dataset to fully utilize vectorization to speed up augmentation. In case your dataset needs additional preprocessing on a per-sample basis, you can also add those to the dataloader with sample_transforms. Check out our 3D Segmentation Tutorial for more information about that.
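To illustrate what a batched preprocessing transform does, here is a minimal NumPy sketch of min-max scaling applied to a whole batch at once. It assumes that every sample is rescaled to [0, 1] independently; the function `min_max_batch` is hypothetical and only mimics the idea, it is not the implementation of rtr.NormMinMax.

```python
import numpy as np


def min_max_batch(batch: np.ndarray, eps: float = 1e-8) -> np.ndarray:
    """Rescale every sample of a batch [N, C, H, W] to the range [0, 1]."""
    # reduce over all non-batch axes, keep the dims so broadcasting works
    reduce_axes = tuple(range(1, batch.ndim))
    mins = batch.min(axis=reduce_axes, keepdims=True)
    maxs = batch.max(axis=reduce_axes, keepdims=True)
    return (batch - mins) / (maxs - mins + eps)


rng = np.random.default_rng(0)
batch = rng.uniform(0, 255, size=(4, 1, 64, 64)).astype(np.float32)
scaled = min_max_batch(batch)
print(scaled.min(), scaled.max())
```

Because the minimum and maximum are computed with one vectorized reduction per batch, no Python loop over the samples is needed.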
[ ]:
import rising.transforms as rtr
from rising.random import UniformParameter
transforms_prep = []
transforms_augment = []
# preprocessing transforms
# transforms_prep.append(rtr.NormZeroMeanUnitStd())
transforms_prep.append(rtr.NormMinMax()) # visualization looks nicer :)
# augmentation transforms
transforms_augment.append(rtr.GaussianNoise(0., 0.01))
transforms_augment.append(rtr.GaussianSmoothing(
    in_channels=1, kernel_size=3, std=0.5, padding=1))
transforms_augment.append(rtr.Rot90((0, 1)))
transforms_augment.append(rtr.Mirror(dims=(0, 1)))
transforms_augment.append(rtr.BaseAffine(
    scale=UniformParameter(0.8, 1.2),
    rotation=UniformParameter(-30, 30), degree=True,
    # translation in base affine is normalized to image size
    # the Translation transform offers the option to switch to pixels
    translation=UniformParameter(-0.02, 0.02),
))
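For intuition about the affine augmentation, the sketch below samples a scale, a rotation angle and a normalized translation from the same uniform ranges and combines them into a single 2x3 matrix. This is a plain-Python illustration under the assumption of an isotropic scale and a counter-clockwise rotation about the origin; the helpers `sample_affine_2d` and `apply_affine` are hypothetical, and the exact parametrization used inside rising may differ.

```python
import math
import random
from typing import List


def sample_affine_2d(rng: random.Random) -> List[List[float]]:
    """Sample scale, rotation and translation and build a 2x3 affine matrix."""
    scale = rng.uniform(0.8, 1.2)                # isotropic scaling factor
    angle = math.radians(rng.uniform(-30, 30))   # rotation, sampled in degrees
    tx = rng.uniform(-0.02, 0.02)                # translation as fraction of image size
    ty = rng.uniform(-0.02, 0.02)
    cos_a, sin_a = math.cos(angle), math.sin(angle)
    return [[scale * cos_a, -scale * sin_a, tx],
            [scale * sin_a, scale * cos_a, ty]]


def apply_affine(matrix: List[List[float]], x: float, y: float) -> List[float]:
    """Apply the 2x3 matrix to a single point (x, y)."""
    return [matrix[0][0] * x + matrix[0][1] * y + matrix[0][2],
            matrix[1][0] * x + matrix[1][1] * y + matrix[1][2]]


matrix = sample_affine_2d(random.Random(0))
print(apply_affine(matrix, 1.0, 0.0))
```

The determinant of the left 2x2 part equals the squared scale, so with the ranges above an augmented image is shrunk or enlarged by at most 20 % per axis.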
In contrast to native PyTorch we add our transformations to the dataloader of rising. There are three main types of transformations which can be added:
* sample_transforms: these transforms are applied per sample. In case the transformation assumes a batch of data, pseudo_batch_dim can be activated to automatically add a batch dim to single samples.
* batch_transforms: these transforms are executed per batch inside the multiprocessing context of the CPU (like sample_transforms).
* gpu_transforms: these transforms are executed on the GPU. In case you have multiple GPUs, make sure to set the correct device, otherwise rising could use the wrong GPU.
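The order in which these three hooks run can be sketched with a toy loader in plain Python. This is only an illustration of the data flow described above, under the assumption that the hooks run in the order sample, batch, GPU; `toy_loader` and its arguments are hypothetical and do not reproduce rising's DataLoader.

```python
from typing import Callable, Dict, Iterator, List, Sequence

Sample = Dict[str, float]
Batch = Dict[str, List[float]]


def toy_loader(samples: Sequence[Sample], batch_size: int,
               sample_transform: Callable[[Sample], Sample],
               batch_transform: Callable[[Batch], Batch],
               gpu_transform: Callable[[Batch], Batch]) -> Iterator[Batch]:
    """Yield batches after applying per-sample, per-batch and 'GPU' transforms."""
    for start in range(0, len(samples), batch_size):
        chunk = [sample_transform(s) for s in samples[start:start + batch_size]]
        # collate: list of dicts -> dict of lists
        batch = {key: [s[key] for s in chunk] for key in chunk[0]}
        batch = batch_transform(batch)
        # a real loader would move the batch to the GPU before this step
        yield gpu_transform(batch)


samples = [{"data": float(i), "label": float(i % 2)} for i in range(4)]
batches = list(toy_loader(
    samples, batch_size=2,
    sample_transform=lambda s: {**s, "data": s["data"] + 1.0},
    batch_transform=lambda b: {**b, "data": [x * 2.0 for x in b["data"]]},
    gpu_transform=lambda b: {**b, "data": [x - 1.0 for x in b["data"]]},
))
print(batches)
```

Each transform sees the output of the previous stage, so cheap per-sample fixes belong at the front and expensive batched augmentations at the end.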
[ ]:
from rising.loading import DataLoader
tr_transform = rtr.Compose(transforms_prep + transforms_augment)
dataloader_tr = DataLoader(train_ds, batch_size=32, shuffle=True,
                           gpu_transforms=tr_transform)
val_transform = rtr.Compose(transforms_prep)
dataloader_val = DataLoader(val_ds, batch_size=32,
                            gpu_transforms=val_transform)
test_transform = rtr.Compose(transforms_prep)
dataloader_ts = DataLoader(test_ds, batch_size=32,
                           gpu_transforms=test_transform)
Looking at some example outputs
In this short section we will visualize some of the batches to look at the influence of the augmentations.
[ ]:
# helper function to visualize batches of images
import torch
import torchvision
import matplotlib.pyplot as plt
def show_batch(batch: torch.Tensor, norm: bool = True):
    """
    Visualize a single batch of images
    Args:
        batch: batch of data
        norm: normalize to range 0,1 for visualization purposes
    """
    grid = torchvision.utils.make_grid(batch.cpu(), nrow=8)
    if norm:
        # rescale to [0, 1]; skip the division for (almost) constant images
        grid = grid - grid.min()
        m = grid.max()
        if m > 1e-6:
            grid = grid / m
    plt.figure(figsize=(10, 5))
    plt.imshow(grid[0], cmap='gray', vmin=0, vmax=1)
    plt.tight_layout()
    plt.show()
[ ]:
# make dataset iterable
_iter = iter(dataloader_tr)
[ ]:
# visualize batch of images
batch = next(_iter)
print({f'{key}_shape': tuple(item.shape) for key, item in batch.items()})
print(f'Batch labels: \n{batch["label"]}')
print(f'Batch mean {batch["data"].mean()}')
print(f'Batch min {batch["data"].min()}')
print(f'Batch max {batch["data"].max()}')
show_batch(batch["data"], norm=True)
The output of the visualization could look something like this:
The exact images will vary because the batch was selected from the training dataloader which shuffles the data.
Defining our Lightning Module
We will use pytorch-lightning as our trainer framework to save some time and to standardize our pipeline.
In lightning the training models are derived from pytorch_lightning.LightningModule, which enforces a specific structure of the code to increase reproducibility and standardization across the community. For simplicity we will simply load a torchvision model and override the basic *_step functions of lightning. If you want more information on how to build pipelines with pytorch lightning, please check out their documentation.
[ ]:
import torch.nn as nn
import torchvision.models as models
if 'CI' in os.environ:
    # use a very small model for CI
    class SuperSmallModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.conv1 = nn.Conv2d(1, 16, 3, 1, 1)
            self.conv2 = nn.Conv2d(16, 32, 3, 1, 1)
            self.pool1 = nn.AdaptiveAvgPool2d((1, 1))
            self.fc = nn.Linear(32, num_class)

        def forward(self, x):
            x = self.conv1(x)
            x = self.conv2(x)
            x = torch.flatten(self.pool1(x), 1)
            return self.fc(x)
    resnet = SuperSmallModel()
else:
    # resnet18 for normal runs
    resnet = models.resnet18(pretrained=False)
    # change first layer
    resnet.conv1 = torch.nn.Conv2d(
        1, 64, kernel_size=7, stride=2, padding=3, bias=False)
    # change last layer
    fc_in = resnet.fc.in_features
    resnet.fc = nn.Linear(fc_in, num_class)
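To see why a 64x64 input still works with the adapted ResNet-18, the spatial size after each strided layer can be worked out with the standard convolution output formula, floor((n + 2p - k) / s) + 1. The sketch below assumes the usual torchvision ResNet-18 layout (7x7 stem convolution with stride 2, 3x3 max pooling with stride 2, then four stages of which the last three halve the resolution); the helper `conv_out` is hypothetical.

```python
def conv_out(size: int, kernel: int, stride: int, padding: int) -> int:
    """Output size of a convolution or pooling layer along one spatial axis."""
    return (size + 2 * padding - kernel) // stride + 1


size = 64
size = conv_out(size, kernel=7, stride=2, padding=3)  # stem convolution
size = conv_out(size, kernel=3, stride=2, padding=1)  # max pooling
for stage_stride in (1, 2, 2, 2):                     # first 3x3 conv of each stage
    size = conv_out(size, kernel=3, stride=stage_stride, padding=1)
print(size)  # 2
```

Global average pooling then reduces the remaining 2x2 map to one value per channel, which is why the final linear layer only depends on the channel count (resnet.fc.in_features) and not on the input resolution.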
[ ]:
import torch.nn.functional as F
import pytorch_lightning as pl
from sklearn.metrics import classification_report
from typing import Dict, Optional
class SimpleClassifier(pl.LightningModule):
    def __init__(self, hparams: Optional[dict] = None):
        """
        Hyperparameters for our model
        Args:
            hparams: hyperparameters for model
                `lr`: learning rate for optimizer
        """
        super().__init__()
        if hparams is None:
            hparams = {}
        self.hparams = hparams
        self.model = resnet

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward input batch of data through model
        Args:
            x: input batch of data [N, C, H, W]
                N batch size (here 32); C number of channels (here 1);
                H,W spatial dimensions of images (here 64x64)
        Returns:
            torch.Tensor: classification logits [N, num_classes]
        """
        return self.model(x)

    def training_step(self, batch: Dict[str, torch.Tensor], batch_idx: int) -> Dict:
        """
        Forward batch and compute loss for a single step (used for training)
        Args:
            batch: batch to process
                `data`: input data
                `label`: expected labels
            batch_idx: index of batch
        """
        x, y = batch["data"], batch["label"]
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        tensorboard_logs = {'train_loss': loss}
        return {'loss': loss, 'log': tensorboard_logs}

    def validation_step(self, batch: Dict[str, torch.Tensor], batch_idx: int) -> Dict:
        """
        Forward batch and compute loss for a single step (used for validation)
        Args:
            batch: batch to process
                `data`: input data
                `label`: expected labels
            batch_idx: index of batch
        """
        x, y = batch["data"], batch["label"]
        y_hat = self(x)
        val_loss = F.cross_entropy(y_hat, y)
        return {'val_loss': val_loss}

    def validation_epoch_end(self, outputs):
        """
        Compute average validation loss during epoch
        """
        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
        tensorboard_logs = {'val_loss': avg_loss}
        return {'val_loss': avg_loss, 'log': tensorboard_logs}

    def test_step(self, batch: Dict[str, torch.Tensor], batch_idx: int) -> Dict:
        """
        Forward batch and compute loss for a single step (used for testing)
        Args:
            batch: batch to process
                `data`: input data
                `label`: expected labels
            batch_idx: index of batch
        """
        x, y = batch["data"], batch["label"]
        y_hat = self(x)
        test_loss = F.cross_entropy(y_hat, y)
        # predicted class = index of the largest logit
        return {'test_loss': test_loss,
                "pred_label": y_hat.max(dim=1)[1].detach().cpu(),
                "label": y.detach().cpu()}

    def test_epoch_end(self, outputs):
        """
        Compute average test loss and classification metrics
        """
        avg_loss = torch.stack([x['test_loss'] for x in outputs]).mean()
        tensorboard_logs = {'test_loss': avg_loss}
        all_pred_label = torch.cat([x['pred_label'] for x in outputs])
        all_label = torch.cat([x['label'] for x in outputs])
        print(classification_report(all_label.numpy(),
                                    all_pred_label.numpy(),
                                    target_names=class_names, digits=4))
        return {'test_loss': avg_loss, 'log': tensorboard_logs}

    def configure_optimizers(self):
        """
        Setup optimizer for training
        """
        return torch.optim.Adam(self.parameters(), lr=self.hparams.get("lr", 1e-5))
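The loss used in all three step functions is the cross entropy between the predicted logits and the integer labels. A minimal plain-Python sketch of that computation for a single sample, written with the log-sum-exp trick for numerical stability (the function `cross_entropy_single` is hypothetical and only mirrors the math, it is not how F.cross_entropy is implemented):

```python
import math
from typing import Sequence


def cross_entropy_single(logits: Sequence[float], label: int) -> float:
    """Negative log-softmax probability of the true class for one sample."""
    max_logit = max(logits)  # subtract the maximum to avoid overflow in exp
    log_sum_exp = max_logit + math.log(sum(math.exp(z - max_logit) for z in logits))
    return log_sum_exp - logits[label]


# equal logits for 6 classes give a loss of ln(6)
print(cross_entropy_single([0.0] * 6, label=0))
# a confident, correct prediction gives a loss close to 0
print(cross_entropy_single([5.0, 0.0, 0.0, 0.0, 0.0, 0.0], label=0))
```

A loss close to ln(6) ≈ 1.79 would therefore correspond to uniform predictions over the six MedNIST classes.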
We can visualize our training progress and hyperparameters in tensorboard to easily compare multiple runs of our classifier.
[ ]:
# Start tensorboard.
%reload_ext tensorboard
%tensorboard --logdir lightning_logs/
Let’s start our training :D
[ ]:
from pytorch_lightning import Trainer
model = SimpleClassifier()
if torch.cuda.is_available():
    gpus = [0]
else:
    gpus = None
# most basic trainer, uses good defaults
trainer = Trainer(gpus=gpus, progress_bar_refresh_rate=10, max_epochs=4, weights_summary=None)
trainer.fit(model, train_dataloader=dataloader_tr, val_dataloaders=dataloader_val)
After training our model we can test it on our test data.
[ ]:
trainer.test(test_dataloaders=dataloader_ts)
The results on the test data should look similar to this:
              precision    recall  f1-score   support

   AbdomenCT     0.9536    0.9990    0.9758      1008
   BreastMRI     1.0000    1.0000    1.0000       830
         CXR     0.9960    0.9872    0.9916      1013
     ChestCT     1.0000    0.9490    0.9738       961
        Hand     0.9877    0.9887    0.9882       975
      HeadCT     0.9912    1.0000    0.9956      1019

    accuracy                         0.9873      5806
   macro avg     0.9881    0.9873    0.9875      5806
weighted avg     0.9876    0.9873    0.9872      5806
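For reference, the per-class columns of this report follow from simple counts of true positives, false positives and false negatives. A minimal plain-Python sketch (the function `per_class_scores` is hypothetical; scikit-learn's classification_report computes the same per-class quantities with more options):

```python
from typing import Dict, Sequence


def per_class_scores(labels: Sequence[int], preds: Sequence[int], cls: int) -> Dict[str, float]:
    """Precision, recall and F1 score of one class from label/prediction pairs."""
    tp = sum(1 for y, p in zip(labels, preds) if y == cls and p == cls)
    fp = sum(1 for y, p in zip(labels, preds) if y != cls and p == cls)
    fn = sum(1 for y, p in zip(labels, preds) if y == cls and p != cls)
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return {"precision": precision, "recall": recall, "f1": f1}


labels = [0, 0, 0, 1, 1, 2]
preds = [0, 0, 1, 1, 1, 2]
print(per_class_scores(labels, preds, cls=1))
```

As a consistency check on the table above, the AbdomenCT row gives 2 * 0.9536 * 0.9990 / (0.9536 + 0.9990) ≈ 0.9758, which matches the reported f1-score.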