Segmentation with rising and PyTorch Lightning

This example will show you how to build a proper training pipeline with PyTorch Lightning and rising. But first, let’s configure this notebook:

[ ]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline

and install all our dependencies:

[ ]:
!pip install --upgrade --quiet pytorch-lightning # for training
!pip install --upgrade --quiet git+ # for data handling
!pip install --upgrade --quiet SimpleITK # for loading medical data
!pip install --upgrade --quiet tensorboard # for monitoring training
!pip install --upgrade --quiet gdown # to download data cross platform


Once this is done, we need to take care of our training data. To show rising’s full capabilities, we will be using 3D data from the Medical Segmentation Decathlon (specifically Task 4: Hippocampus).


We will use the data provided on Google Drive and download it:

[ ]:
import os
import SimpleITK as sitk
import json
import tempfile
import numpy as np
import tarfile
import time
import gdown

temp_dir = tempfile.mkdtemp()

# generate dummy data for ci/cd
if 'CI' in os.environ:
    data_dir = os.path.join(temp_dir, 'DummyData')
    os.makedirs(os.path.join(data_dir, 'training'), exist_ok=True)
    data_paths = []

    for idx in range(50):
        img = np.random.randint(-500, 500, (32, 64, 32), np.int16)
        # the upper bound is exclusive, so (0, 2) yields a random binary mask
        mask = np.random.randint(0, 2, (32, 64, 32), np.int16)

        img_path = os.path.join(data_dir, 'training', 'img_%03d.nii.gz' % idx)
        mask_path = os.path.join(data_dir, 'training', 'mask_%03d.nii.gz' % idx)
        sitk.WriteImage(sitk.GetImageFromArray(img), img_path)
        sitk.WriteImage(sitk.GetImageFromArray(mask), mask_path)

        data_paths.append({'image': img_path, 'label': mask_path})

    with open(os.path.join(data_dir, 'dataset.json'), 'w') as f:
        json.dump({'training': data_paths}, f, sort_keys=True, indent=4)


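else:
    # download and extract the Hippocampus data
    data_url = ""

    data_dir = os.path.join(temp_dir, 'Task04_Hippocampus')
    download_path = os.path.join(temp_dir, 'data.tar')
    gdown.download(data_url, download_path)

    # extracting into temp_dir creates the Task04_Hippocampus folder used as data_dir
    with tarfile.open(download_path) as archive:
        archive.extractall(temp_dir)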


Great! We got our data. Now we can work on loading it. For loading data, rising follows the same principle as PyTorch: it separates the dataset, which provides the logic for loading a single sample, from the dataloader, which automatically handles parallel loading and batching.

In fact, we at rising saw no need to reinvent the wheel. This is why we internally use PyTorch’s data structures and only extend them a bit. We’ll come to these extensions later.


Our dataset is fairly simple. It just loads the NIfTI data we downloaded before and returns each sample as a dict:

[ ]:
import os
import json
import SimpleITK as sitk
import torch
from rising.loading import Dataset


class NiiDataset(Dataset):
    def __init__(self, train: bool, data_dir: str):
        """
        Args:
            train: whether to use the training or the validation split
            data_dir: directory containing the data
        """
        super().__init__()
        with open(os.path.join(data_dir, 'dataset.json')) as f:
            content = json.load(f)['training']

        num_train_samples = int(len(content) * 0.9)

        # Split train data into training and validation,
        # since test data contains no ground truth
        if train:
            data = content[:num_train_samples]
        else:
            data = content[num_train_samples:]

        self.data = data
        self.data_dir = data_dir

    def __getitem__(self, item: int) -> dict:
        """
        Loads and returns a single sample

        Args:
            item: index specifying which item to load

        Returns:
            dict: the loaded sample
        """
        sample = self.data[item]
        img = sitk.GetArrayFromImage(
            sitk.ReadImage(os.path.join(self.data_dir, sample['image'])))

        # add channel dim if necessary
        if img.ndim == 3:
            img = img[None]

        label = sitk.GetArrayFromImage(
            sitk.ReadImage(os.path.join(self.data_dir, sample['label'])))

        # convert multiclass to binary task by combining all positives
        label = label > 0

        # add channel dim if necessary
        if label.ndim == 3:
            label = label[None]
        return {'data': torch.from_numpy(img).float(),
                'label': torch.from_numpy(label).float()}

    def __len__(self) -> int:
        """
        Adds a length to the dataset

        Returns:
            int: dataset's length
        """
        return len(self.data)

For compatibility, each rising dataset must provide the same interface as a PyTorch dataset, which basically comes down to being indexable. This means that any sequence-like data (e.g. lists, tuples, tensors or arrays) can also be used directly as a dataset. Ideally, each dataset also has a length, since the dataloader uses this length to calculate (or estimate) its own length.
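To illustrate that any indexable object with a length works, here is a small sketch that feeds a plain Python list of dict samples to PyTorch’s own torch.utils.data.DataLoader (used instead of rising’s loader to keep the example minimal; the sample shapes are arbitrary):

```python
import torch
from torch.utils.data import DataLoader

# A plain list is indexable and has a length, so it can serve as a map-style dataset.
samples = [{'data': torch.rand(1, 4, 4, 4), 'label': torch.zeros(1, 4, 4, 4)}
           for _ in range(10)]

loader = DataLoader(samples, batch_size=4)
batches = list(loader)

print(len(loader))                 # 3 batches (4 + 4 + 2 samples), derived from len(samples)
print(batches[0]['data'].shape)    # torch.Size([4, 1, 4, 4, 4])
print(batches[-1]['label'].shape)  # torch.Size([2, 1, 4, 4, 4])
```

The loader’s length is only available because the list defines `__len__`; the dict samples are collated key by key into batched tensors.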

Integration With PyTorch Lightning: Model and Training

After obtaining our data and implementing a way to load it, we now need a model we can train. For this, we will use a fairly simple implementation of the U-Net, which basically is an encoder-decoder network with skip connections. In Lightning, all models should derive from LightningModule, which itself is a subclass of torch.nn.Module. For further details on the LightningModule, please refer to the project itself or its documentation.


For now we will only define the network’s logic and omit the training logic, which we’ll add later.

[ ]:
import pytorch_lightning as pl
import torch


class Unet(pl.LightningModule):
    """Simple U-Net without training logic"""
    def __init__(self, hparams: dict):
        """
        Args:
            hparams: the hyperparameters needed to construct the network.
                Specifically these are:
                * start_filts (int)
                * depth (int)
                * in_channels (int)
                * num_classes (int)
        """
        super().__init__()
        # number of filters, resolution stages and input/output channels
        out_filts = hparams.get('start_filts', 16)
        depth = hparams.get('depth', 3)
        in_filts = hparams.get('in_channels', 1)
        num_classes = hparams.get('num_classes', 2)

        # encoder: two conv + ReLU layers per stage, doubling the filters each stage
        for idx in range(depth):
            down_block = torch.nn.Sequential(
                torch.nn.Conv3d(in_filts, out_filts, kernel_size=3, padding=1),
                torch.nn.ReLU(inplace=True),
                torch.nn.Conv3d(out_filts, out_filts, kernel_size=3, padding=1),
                torch.nn.ReLU(inplace=True))
            in_filts = out_filts
            out_filts *= 2

            setattr(self, 'down_block_%d' % idx, down_block)

        # undo the last doubling; afterwards in_filts holds the bottleneck's filters
        # and out_filts the filters of the deepest skip connection
        out_filts = out_filts // 2
        in_filts = in_filts // 2
        out_filts, in_filts = in_filts, out_filts

        # decoder: upsampled features are concatenated with the skip connection
        for idx in range(depth - 1):
            up_block = torch.nn.Sequential(
                torch.nn.Conv3d(in_filts + out_filts, out_filts, kernel_size=3, padding=1),
                torch.nn.ReLU(inplace=True),
                torch.nn.Conv3d(out_filts, out_filts, kernel_size=3, padding=1),
                torch.nn.ReLU(inplace=True))

            in_filts = out_filts
            out_filts = out_filts // 2

            setattr(self, 'up_block_%d' % idx, up_block)

        self.final_conv = torch.nn.Conv3d(in_filts, num_classes, kernel_size=1)
        self.max_pool = torch.nn.MaxPool3d(2, stride=2)
        self.up_sample = torch.nn.Upsample(scale_factor=2)
        self.hparams = hparams

    def forward(self, input_tensor: torch.Tensor) -> torch.Tensor:
        """
        Forwards the :attr:`input_tensor` through the network to obtain a prediction

        Args:
            input_tensor: the network's input

        Returns:
            torch.Tensor: the network's output given the :attr:`input_tensor`
        """
        depth = self.hparams.get('depth', 3)

        intermediate_outputs = []

        # Compute all the encoder blocks' outputs
        for idx in range(depth):
            intermed = getattr(self, 'down_block_%d' % idx)(input_tensor)
            if idx < depth - 1:
                # store intermediate values for usage in decoder
                intermediate_outputs.append(intermed)
                input_tensor = self.max_pool(intermed)
            else:
                input_tensor = intermed

        # Compute all the decoder blocks' outputs
        for idx in range(depth - 1):
            input_tensor = self.up_sample(input_tensor)

            # use intermediate values from encoder
            from_down = intermediate_outputs.pop(-1)
            intermed = torch.cat([input_tensor, from_down], dim=1)
            input_tensor = getattr(self, 'up_block_%d' % idx)(intermed)

        return self.final_conv(input_tensor)

Okay, that was easy, right? Now let’s just check if everything in our network is fine:

[ ]:
net = Unet({'num_classes': 2, 'in_channels': 1, 'depth': 2, 'start_filts': 2})
print(net(torch.rand(1, 1, 16, 16, 16)).shape)

So what did we do here? We initialized a network that accepts input images with one channel and predicts a segmentation map for 2 classes (one of which is the background class). Since we set depth to 2, it does so with 2 resolution stages.

When we tested the network, we forwarded a tensor of random values with size (1, 1, 16, 16, 16) through it. The first 1 is the batch dimension, the second 1 the channel dimension (as we specified one input channel), and the three 16s are the spatial dimensions (depth, height and width).

The output has the same dimensions, except that the channel dimension now holds 2 channels (one per class), so the printed shape is torch.Size([1, 2, 16, 16, 16]).
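The U-Net halves the spatial size at every encoder stage and doubles it again in the decoder, so the skip connections can only be concatenated if each spatial size is divisible by 2 ** (depth - 1). The following pure-Python sketch traces these sizes; the helper names are hypothetical and not part of rising or Lightning, and it assumes the MaxPool3d(2, stride=2) and Upsample(scale_factor=2) layers defined above:

```python
def encoder_sizes(spatial_size: int, depth: int) -> list:
    """Hypothetical helper: spatial size at each encoder stage (MaxPool3d(2, stride=2))."""
    sizes = [spatial_size]
    for _ in range(depth - 1):
        sizes.append(sizes[-1] // 2)  # pooling with kernel 2 / stride 2 floors odd sizes
    return sizes


def skips_match(spatial_size: int, depth: int) -> bool:
    """True if upsampling by 2 restores the size of every skip connection."""
    sizes = encoder_sizes(spatial_size, depth)
    return all(2 * low == high for low, high in zip(sizes[1:], sizes[:-1]))


print(encoder_sizes(16, 2))  # [16, 8]
print(skips_match(16, 2))    # True
print(skips_match(15, 2))    # False: 15 -> 7 -> 14 after upsampling, cannot be concatenated with 15
print(all(skips_match(s, 3) for s in (32, 64, 32)))  # True
```

The last line shows that the target size (32, 64, 32) used for resizing later on satisfies this constraint for the default depth of 3.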

Training Criteria and Metrics

For training, we will use a combination of the CrossEntropyLoss and the SoftDiceLoss (see below).

For more details on this, I’d recommend Jeremy Jordan’s blog post on semantic segmentation.
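For reference, with its defaults (square_nom=False, square_denom=False), the SoftDiceLoss in the next cell computes for each sample the negative mean over the C classes of a smoothed Dice ratio. Here p_{i,c} is the softmaxed prediction for voxel i and class c, t_{i,c} the one-hot target, w_c the optional class weight (1 if no weight is given) and s the smoothing term:

```latex
\mathcal{L}_{\text{Dice}} = -\frac{1}{C}\sum_{c=1}^{C} w_c \,
  \frac{2\sum_{i} p_{i,c}\, t_{i,c} + s}{\sum_{i} p_{i,c} + \sum_{i} t_{i,c} + s}
```

During training, this value is averaged with the cross-entropy loss, i.e. the total loss is (CE + Dice) / 2.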

[ ]:
from typing import Sequence, Optional, Union
import torch

# Taken from
class SoftDiceLoss(torch.nn.Module):
    """Soft Dice Loss"""
    def __init__(self, square_nom: bool = False,
                 square_denom: bool = False,
                 weight: Optional[Union[Sequence, torch.Tensor]] = None,
                 smooth: float = 1.):
        """
        Args:
            square_nom: whether to square the numerator
            square_denom: whether to square the denominator
            weight: additional weighting of individual classes
            smooth: smoothing for numerator and denominator
        """
        super().__init__()
        self.square_nom = square_nom
        self.square_denom = square_denom

        self.smooth = smooth

        if weight is not None:
            if not isinstance(weight, torch.Tensor):
                weight = torch.tensor(weight)

            self.register_buffer("weight", weight)
        else:
            self.weight = None

    def forward(self, predictions: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """
        Computes SoftDice Loss

        Args:
            predictions: the (softmaxed) predictions obtained by the network,
                of shape NxCx(Dx)HxW
            targets: the targets (ground truth class indices) for the
                :attr:`predictions`, of shape Nx(Dx)HxW

        Returns:
            torch.Tensor: the computed loss value (averaged over the batch)
        """
        # number of classes for onehot
        n_classes = predictions.shape[1]
        with torch.no_grad():
            # one_hot appends the class dim last: (N, *spatial) -> (N, *spatial, C);
            # move it to position 1 to match the predictions: (N, C, *spatial)
            targets_onehot = torch.nn.functional.one_hot(targets.long(), num_classes=n_classes)
            targets_onehot = targets_onehot.permute(0, targets_onehot.dim() - 1,
                                                    *range(1, targets_onehot.dim() - 1))
        # sum over spatial dimensions
        dims = tuple(range(2, predictions.dim()))

        # compute numerator
        if self.square_nom:
            nom = torch.sum((predictions * targets_onehot.float()) ** 2, dim=dims)
        else:
            nom = torch.sum(predictions * targets_onehot.float(), dim=dims)
        nom = 2 * nom + self.smooth

        # compute denominator
        if self.square_denom:
            i_sum = torch.sum(predictions ** 2, dim=dims)
            t_sum = torch.sum(targets_onehot ** 2, dim=dims)
        else:
            i_sum = torch.sum(predictions, dim=dims)
            t_sum = torch.sum(targets_onehot, dim=dims)

        denom = i_sum + t_sum.float() + self.smooth

        # compute loss per sample and class
        frac = nom / denom

        # apply weight for individual classes
        if self.weight is not None:
            frac = self.weight * frac

        # average over classes, then over the batch to obtain a scalar loss
        frac = - torch.mean(frac, dim=1)

        return frac.mean()

Okay, now that we are able to properly calculate the loss function, we still lack a metric that describes our performance and that we can monitor. For segmentation tasks, this usually comes down to the Dice coefficient. So let’s implement this one as well:

[ ]:
# Taken from
def binary_dice_coefficient(pred: torch.Tensor, gt: torch.Tensor,
                            thresh: float = 0.5, smooth: float = 1e-7) -> torch.Tensor:
    """
    Computes the Dice coefficient for a binary segmentation task

    Args:
        pred: predicted segmentation (of shape Nx(Dx)HxW)
        gt: target segmentation (of shape Nx(Dx)HxW)
        thresh: segmentation threshold
        smooth: smoothing value to avoid division by zero

    Returns:
        torch.Tensor: dice score
    """
    assert pred.shape == gt.shape

    # binarize the prediction
    pred_bool = pred > thresh

    intersec = (pred_bool * gt).float()
    return 2 * intersec.sum() / (pred_bool.float().sum()
                                 + gt.float().sum() + smooth)
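As a quick sanity check of the formula behind binary_dice_coefficient, 2 * |P ∩ G| / (|P| + |G|), here is a worked example on two made-up 2x3 masks, written with NumPy so it runs without the rest of the notebook (the tiny smooth term is omitted since it does not change the result noticeably):

```python
import numpy as np

# Hypothetical binary masks (1 = foreground)
pred = np.array([[1, 1, 0],
                 [0, 1, 0]])
gt = np.array([[1, 0, 0],
               [0, 1, 1]])

intersection = np.logical_and(pred, gt).sum()      # 2 voxels are foreground in both masks
dice = 2 * intersection / (pred.sum() + gt.sum())  # 2 * 2 / (3 + 3)
print(round(float(dice), 4))  # 0.6667
```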

Neat! So far, we have defined all criteria and metrics necessary for proper training and monitoring. But two major parts of our pipeline are still missing:

1.) Data Preprocessing and Augmentation

2.) Parameter updates (the actual training logic)

Let’s deal with the first point now.

Data Preprocessing

Since the samples in our dataset differ in size, we cannot collate them to a batch directly. Instead, we need to resize them. Frameworks like torchvision do this inside the dataset. With rising, however, we opted to move this part outside the dataset (while still applying it to each sample separately before batching) for the following reasons:

1.) The dataset becomes more reusable for different settings

2.) The transforms don’t have to be implemented in each dataset, which makes it easier to switch datasets without code duplication

3.) Applying different transforms is as easy as changing an argument of the loader; no need to deal with this manually in the dataset
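The collation problem described above is easy to reproduce with PyTorch’s default collate function (imported from torch.utils.data.dataloader; the shapes below are made up for illustration):

```python
import torch
from torch.utils.data.dataloader import default_collate

# Two hypothetical samples of different spatial size, shaped (C, D, H, W) like NiiDataset's output
samples = [{'data': torch.rand(1, 35, 51, 35)}, {'data': torch.rand(1, 36, 50, 33)}]

try:
    default_collate(samples)
    collated = True
except RuntimeError:  # stacking tensors of different shapes fails
    collated = False
print(collated)  # False

# Once every sample has the same size, collation stacks them along a new batch dim
resized = [{'data': torch.rand(1, 32, 64, 32)} for _ in samples]
print(default_collate(resized)['data'].shape)  # torch.Size([2, 1, 32, 64, 32])
```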

These transforms can be passed to the dataloader via sample_transforms. If you have an implementation that usually works on batched data, we’ve got you covered: all you need to do is specify pseudo_batch_dim, and we will take care of the rest. We will then automatically add a pseudo batch dim to all kinds of data (tensors, arrays and all kinds of built-in Python containers containing a mixture thereof) before applying these transforms and remove it afterwards.
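rising adds and removes the pseudo batch dimension internally. Purely to illustrate the idea (this is not rising’s implementation, and with_pseudo_batch_dim is a hypothetical helper), a batched resize written with torch.nn.functional.interpolate, which expects an (N, C, D, H, W) input for trilinear mode, can be applied to a single (C, D, H, W) sample like this:

```python
import torch
import torch.nn.functional as F


def with_pseudo_batch_dim(batched_fn, sample: dict, keys=('data',)) -> dict:
    """Hypothetical helper: apply a transform written for batches to a single sample."""
    out = dict(sample)
    for key in keys:
        # add a pseudo batch dim, apply the batched transform, remove the dim again
        out[key] = batched_fn(sample[key][None])[0]
    return out


def resize_batch(batch: torch.Tensor) -> torch.Tensor:
    # trilinear interpolation needs a 5D (N, C, D, H, W) input
    return F.interpolate(batch, size=(32, 64, 32), mode='trilinear', align_corners=False)


sample = {'data': torch.rand(1, 35, 51, 35), 'label': torch.zeros(1, 35, 51, 35)}
resized = with_pseudo_batch_dim(resize_batch, sample)
print(resized['data'].shape)   # torch.Size([1, 32, 64, 32])
print(resized['label'].shape)  # torch.Size([1, 35, 51, 35]): not in keys, so untouched
```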

For now, we use our batched implementation of native torch resizing:

[ ]:
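from rising.transforms import Compose, ResizeNative

def common_per_sample_trafos():
    # resize images with trilinear interpolation and labels with nearest
    # neighbour interpolation (which keeps the labels binary)
    return Compose(ResizeNative(size=(32, 64, 32), keys=('data',), mode='trilinear'),
                   ResizeNative(size=(32, 64, 32), keys=('label',), mode='nearest'))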

Data Augmentation

Now that we have defined our preprocessing, let’s come to data augmentation. To enrich our dataset, we randomly apply an affine transformation. While rising already contains an implementation of affine transforms that also handles random inputs pretty well, we will implement a basic random parameter sampling ourselves, since this also serves as an educational example.

This is really straightforward: we just derive from the BaseAffine class and override the way the matrix is assembled by sampling the parameters before calling the actual assembly method. We leave the rest to the already defined class:
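The sample-then-assemble idea can also be illustrated independently of rising. The NumPy sketch below is a hypothetical 2D example (it is not rising’s assemble_matrix, and the composition order of scale, then rotation, then translation is just one possible choice): it samples parameters uniformly from the given ranges and builds a 3x3 homogeneous affine matrix.

```python
import math
import random

import numpy as np


def sample_affine_2d(scale_range=(0.9, 1.1), rotation_range=(-10, 10),
                     translation_range=(-0.1, 0.1), rng=random):
    """Hypothetical sketch: sample parameters, then assemble a 3x3 homogeneous 2D affine."""
    sx, sy = (rng.uniform(*scale_range) for _ in range(2))
    angle = math.radians(rng.uniform(*rotation_range))  # ranges given in degrees
    tx, ty = (rng.uniform(*translation_range) for _ in range(2))

    scale = np.diag([sx, sy, 1.0])
    rotation = np.array([[math.cos(angle), -math.sin(angle), 0.0],
                         [math.sin(angle), math.cos(angle), 0.0],
                         [0.0, 0.0, 1.0]])
    translation = np.array([[1.0, 0.0, tx],
                            [0.0, 1.0, ty],
                            [0.0, 0.0, 1.0]])
    # applied right to left: scale first, then rotation, then translation
    return translation @ rotation @ scale


matrix = sample_affine_2d(rng=random.Random(0))
print(matrix.shape)  # (3, 3)

# degenerate ranges give the identity matrix
identity = sample_affine_2d((1, 1), (0, 0), (0, 0))
print(np.allclose(identity, np.eye(3)))  # True
```

Because each call draws new parameters, every batch gets a different transformation, which is exactly what the subclass below achieves by sampling inside assemble_matrix.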

[ ]:
from rising.transforms.affine import BaseAffine
import random
from typing import Optional, Sequence
import torch


class RandomAffine(BaseAffine):
    """Base Affine with random parameters for scale, rotation and translation"""
    def __init__(self, scale_range: Optional[tuple] = None,
                 rotation_range: Optional[tuple] = None,
                 translation_range: Optional[tuple] = None,
                 degree: bool = True,
                 image_transform: bool = True,
                 keys: Sequence = ('data',),
                 grad: bool = False,
                 output_size: Optional[tuple] = None,
                 adjust_size: bool = False,
                 interpolation_mode: str = 'nearest',
                 padding_mode: str = 'zeros',
                 align_corners: bool = False,
                 reverse_order: bool = False,
                 **kwargs):
        """
        Args:
            scale_range: tuple containing minimum and maximum values for scale.
                Actual values will be sampled from uniform distribution with these
                bounds.
            rotation_range: tuple containing minimum and maximum values for rotation.
                Actual values will be sampled from uniform distribution with these
                bounds.
            translation_range: tuple containing minimum and maximum values for translation.
                Actual values will be sampled from uniform distribution with these
                bounds.
            keys: keys which should be augmented
            grad: enable gradient computation inside transformation
            degree: whether the given rotation(s) are in degrees.
                Only valid for rotation parameters, which aren't passed
                as full transformation matrix.
            output_size: if given, this will be the resulting image size.
                Defaults to ``None``
            adjust_size: if True, the resulting image size will be
                calculated dynamically to ensure that the whole image fits.
            interpolation_mode: interpolation mode to calculate output values
                ``'bilinear'`` | ``'nearest'``. Default: ``'nearest'``
            padding_mode: padding mode for outside grid values
                ``'zeros'`` | ``'border'`` | ``'reflection'``.
                Default: ``'zeros'``
            align_corners: Geometrically, we consider the pixels of the
                input as squares rather than points. If set to True,
                the extrema (-1 and 1) are considered as referring to the
                center points of the input’s corner pixels. If set to False,
                they are instead considered as referring to the corner points
                of the input’s corner pixels, making the sampling more
                resolution agnostic.
            reverse_order: reverses the coordinate order of the
                transformation to conform to the pytorch convention:
                transformation params order [W,H(,D)] and
                batch order [(D,)H,W]
            **kwargs: additional keyword arguments passed to the
                affine transformation
        """
        super().__init__(scale=None, rotation=None, translation=None,
                         degree=degree,
                         image_transform=image_transform,
                         keys=keys,
                         grad=grad,
                         output_size=output_size,
                         adjust_size=adjust_size,
                         interpolation_mode=interpolation_mode,
                         padding_mode=padding_mode,
                         align_corners=align_corners,
                         reverse_order=reverse_order,
                         **kwargs)

        self.scale_range = scale_range
        self.rotation_range = rotation_range
        self.translation_range = translation_range

    def assemble_matrix(self, **data) -> torch.Tensor:
        """
        Samples parameters for scale, rotation and translation
        before actual matrix assembly.

        Args:
            **data: dictionary containing a batch

        Returns:
            torch.Tensor: assembled affine matrix
        """
        # number of spatial dims (without batch and channel dim)
        ndim = data[self.keys[0]].ndim - 2

        if self.scale_range is not None:
            self.scale = [random.uniform(*self.scale_range) for _ in range(ndim)]

        if self.translation_range is not None:
            self.translation = [random.uniform(*self.translation_range) for _ in range(ndim)]

        if self.rotation_range is not None:
            if ndim == 3:
                # one rotation angle per axis in 3D
                self.rotation = [random.uniform(*self.rotation_range) for _ in range(ndim)]
            elif ndim == 2:
                # a single rotation angle in 2D
                self.rotation = random.uniform(*self.rotation_range)

        return super().assemble_matrix(**data)

Also not that hard. We now have a custom implementation of a randomly parametrized affine transformation, which is all we will use for data augmentation for now.

Batched transforms that should be executed on the CPU in a multiprocessed way should be passed to the dataloader as batch_transforms. Transforms that should be executed on the GPU can be passed as gpu_transforms instead. Unfortunately, it is not possible to apply GPU transforms in a multiprocessing environment. Thus, the internal computation order is as follows:

1.) Extract sample from dataset

2.) Apply per-sample transforms to it (with or without pseudo batch dim)

3.) Collate to batch

4.) Apply batch transforms

5.) Apply GPU transforms

Steps 1 to 4 can be executed in a multiprocessing environment. In that case, the results are synced back to the main process before the GPU transforms are applied.
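To make this order concrete, here is a single-process, hypothetical sketch (it is not rising’s DataLoader, load_batch is an illustrative name, and transforms are modelled as plain callables that take and return a dict):

```python
import torch
from torch.utils.data.dataloader import default_collate


def load_batch(dataset, indices, sample_transforms=None, batch_transforms=None,
               gpu_transforms=None, device='cpu'):
    """Hypothetical sketch of the five steps listed above (single process, no workers)."""
    samples = [dataset[i] for i in indices]                # 1.) extract samples
    if sample_transforms is not None:
        samples = [sample_transforms(s) for s in samples]  # 2.) per-sample transforms
    batch = default_collate(samples)                       # 3.) collate to batch
    if batch_transforms is not None:
        batch = batch_transforms(batch)                    # 4.) batch transforms
    # with worker processes, the batch would be sent back to the main process here
    batch = {key: value.to(device) for key, value in batch.items()}
    if gpu_transforms is not None:
        batch = gpu_transforms(batch)                      # 5.) device ("GPU") transforms
    return batch


dataset = [{'data': torch.full((1, 2, 2), float(i))} for i in range(4)]
batch = load_batch(dataset, [0, 1, 2],
                   sample_transforms=lambda s: {'data': s['data'] * 2},
                   batch_transforms=lambda b: {'data': b['data'] + 1})
print(batch['data'].shape)        # torch.Size([3, 1, 2, 2])
print(batch['data'][:, 0, 0, 0])  # tensor([1., 3., 5.])
```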

Training Logic

The only remaining step is to integrate all of this into the training logic of PyTorch Lightning.

The only things we have not discussed yet are how to set up the optimizers, logging, and the training/validation steps.

The optimizer setup is done in the method configure_optimizers, which should return the created optimizer(s).
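Lightning takes care of calling the optimizer returned by configure_optimizers around each training_step. Roughly, one such iteration corresponds to the plain-PyTorch sketch below (manual_training_step is a hypothetical helper; it uses a tiny stand-in model and only the cross-entropy term, and it is not Lightning’s actual loop):

```python
import torch


def manual_training_step(model, batch, optimizer, criterion):
    """Hypothetical sketch of one optimization step that Lightning automates."""
    x = batch['data']
    y = batch['label'][:, 0].long()  # drop the label's channel dim, as in training_step
    optimizer.zero_grad()            # reset gradients from the previous step
    loss = criterion(model(x), y)    # forward pass and loss
    loss.backward()                  # backpropagation
    optimizer.step()                 # parameter update
    return loss.item()


torch.manual_seed(0)
model = torch.nn.Conv3d(1, 2, kernel_size=1)  # tiny stand-in for the U-Net
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
batch = {'data': torch.rand(2, 1, 4, 4, 4),
         'label': torch.randint(0, 2, (2, 1, 4, 4, 4)).float()}
before = model.weight.detach().clone()
loss = manual_training_step(model, batch, optimizer, torch.nn.CrossEntropyLoss())
print(loss > 0, not torch.equal(before, model.weight))  # True True
```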

Logging can be done either automatically (all values stored under the key log in the dict returned from validation_epoch_end and training_epoch_end are logged automatically) or manually (by explicitly calling the logger in any of these functions). We show both examples below.

To set up the actual training logic, we need to implement training_step (and validation_step for validation). The complete example is below:

[ ]:
from typing import Optional
import torch
from tqdm import tqdm
from rising.loading import DataLoader
from rising.transforms import Compose, NormZeroMeanUnitStd


class TrainableUNet(Unet):
    """A trainable UNet (extends the base class by training logic)"""
    def __init__(self, hparams: Optional[dict] = None):
        """
        Args:
            hparams: the hyperparameters needed to construct and train the network.
                Specifically these are:
                * start_filts (int)
                * depth (int)
                * in_channels (int)
                * num_classes (int)
                * min_scale (float)
                * max_scale (float)
                * min_rotation (int, float)
                * max_rotation (int, float)
                * batch_size (int)
                * num_workers(int)
                * learning_rate (float)

                For all of them exist usable default parameters.
        """
        if hparams is None:
            hparams = {}
        super().__init__(hparams)

        # define loss functions
        self.dice_loss = SoftDiceLoss(weight=[0., 1.])
        self.ce_loss = torch.nn.CrossEntropyLoss()

    def train_dataloader(self) -> DataLoader:
        """
        Specifies the train dataloader

        Returns:
            DataLoader: the train dataloader
        """
        # construct dataset
        dataset = NiiDataset(train=True, data_dir=data_dir)

        # specify batch transforms (augmentation and normalization)
        batch_transforms = Compose([
            RandomAffine(scale_range=(self.hparams.get('min_scale', 0.9), self.hparams.get('max_scale', 1.1)),
                         rotation_range=(self.hparams.get('min_rotation', -10), self.hparams.get('max_rotation', 10)),
                         keys=('data', 'label')),
            NormZeroMeanUnitStd(keys=('data',))
        ])

        # construct loader; per-sample resizing is applied with a pseudo batch dim
        dataloader = DataLoader(dataset,
                                batch_size=self.hparams.get('batch_size', 1),
                                shuffle=True,
                                batch_transforms=batch_transforms,
                                sample_transforms=common_per_sample_trafos(),
                                pseudo_batch_dim=True,
                                num_workers=self.hparams.get('num_workers', 4))
        return dataloader

    def val_dataloader(self) -> DataLoader:
        """
        Specifies the validation dataloader

        Returns:
            DataLoader: the validation dataloader
        """
        # construct dataset
        dataset = NiiDataset(train=False, data_dir=data_dir)

        # specify batch transforms (no augmentation here)
        batch_transforms = NormZeroMeanUnitStd(keys=('data',))

        # construct loader
        dataloader = DataLoader(dataset,
                                batch_size=self.hparams.get('batch_size', 1),
                                batch_transforms=batch_transforms,
                                sample_transforms=common_per_sample_trafos(),
                                pseudo_batch_dim=True,
                                num_workers=self.hparams.get('num_workers', 4))

        return dataloader

    def configure_optimizers(self) -> torch.optim.Optimizer:
        """
        Configures the optimizer to use for training

        Returns:
            torch.optim.Optimizer: the optimizer for updating the model's parameters
        """
        return torch.optim.Adam(self.parameters(), lr=self.hparams.get('learning_rate', 1e-3))

    def training_step(self, batch: dict, batch_idx: int) -> dict:
        """
        Defines the training logic

        Args:
            batch: contains the data (inputs and ground truth)
            batch_idx: the number of the current batch

        Returns:
            dict: the current loss value
        """
        x, y = batch['data'], batch['label']

        # remove channel dim from gt (was necessary for augmentation)
        y = y[:, 0].long()

        # obtain predictions
        pred = self(x)
        softmaxed_pred = torch.nn.functional.softmax(pred, dim=1)

        # Calculate losses
        ce_loss = self.ce_loss(pred, y)
        dice_loss = self.dice_loss(softmaxed_pred, y)
        total_loss = (ce_loss + dice_loss) / 2

        # calculate dice coefficient
        dice_coeff = binary_dice_coefficient(torch.argmax(softmaxed_pred, dim=1), y)

        # log values manually (with the step, so TensorBoard can plot them over time)
        step = self.global_step
        self.logger.experiment.add_scalar('Train/DiceCoeff', dice_coeff, step)
        self.logger.experiment.add_scalar('Train/CE', ce_loss, step)
        self.logger.experiment.add_scalar('Train/SoftDiceLoss', dice_loss, step)
        self.logger.experiment.add_scalar('Train/TotalLoss', total_loss, step)

        return {'loss': total_loss}

    def validation_step(self, batch: dict, batch_idx: int) -> dict:
        """
        Defines the validation logic

        Args:
            batch: contains the data (inputs and ground truth)
            batch_idx: the number of the current batch

        Returns:
            dict: the current loss and metric values
        """
        x, y = batch['data'], batch['label']

        # remove channel dim from gt (was necessary for augmentation)
        y = y[:, 0].long()

        # obtain predictions
        pred = self(x)
        softmaxed_pred = torch.nn.functional.softmax(pred, dim=1)

        # calculate losses
        ce_loss = self.ce_loss(pred, y)
        dice_loss = self.dice_loss(softmaxed_pred, y)
        total_loss = (ce_loss + dice_loss) / 2

        # calculate dice coefficient
        dice_coeff = binary_dice_coefficient(torch.argmax(softmaxed_pred, dim=1), y)

        # log values manually
        step = self.global_step
        self.logger.experiment.add_scalar('Val/DiceCoeff', dice_coeff, step)
        self.logger.experiment.add_scalar('Val/CE', ce_loss, step)
        self.logger.experiment.add_scalar('Val/SoftDiceLoss', dice_loss, step)
        self.logger.experiment.add_scalar('Val/TotalLoss', total_loss, step)

        return {'val_loss': total_loss, 'dice': dice_coeff}

    def validation_epoch_end(self, outputs: list) -> dict:
        """
        Aggregates data from each validation step

        Args:
            outputs: the returned values from each validation step

        Returns:
            dict: the aggregated outputs
        """
        mean_outputs = {}
        for k in outputs[0].keys():
            mean_outputs[k] = torch.stack([x[k] for x in outputs]).mean()

        tqdm.write('Dice: \t%.3f' % mean_outputs['dice'].item())
        # values stored under the 'log' key are logged automatically
        return {**mean_outputs, 'log': dict(mean_outputs)}

Most of this code is specific to PyTorch Lightning. But the dataloader setup nicely shows how rising integrates with any existing framework built on PyTorch DataLoaders (like PyTorch Lightning or PyTorch Ignite) for batch and sample transforms.


We’ve finally finished defining the pipeline. Now let’s load the TensorBoard extension to monitor our training. For this, we define a common output directory for Lightning:

[ ]:
output_dir = 'logs'
os.makedirs(output_dir, exist_ok=True)
[ ]:
# Start tensorboard.

%reload_ext tensorboard
%tensorboard --logdir {output_dir}

And now it’s finally time to train!

On a GPU in Colab, training takes approximately 40 seconds per epoch, which amounts to about 33 minutes (2000 seconds) for all 50 epochs if early stopping doesn’t kick in. For me, it kicks in after 25 epochs, which takes around 16 minutes on a Colab GPU.

[ ]:
from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning import Trainer

early_stop_callback = EarlyStopping(monitor='dice', min_delta=0.001, patience=10, verbose=False, mode='max')

if torch.cuda.is_available():
    gpus = 1
else:
    gpus = None

nb_epochs = 50
num_start_filts = 16
num_workers = 4

if 'CI' in os.environ:
    nb_epochs = 1
    num_start_filts = 2
    num_workers = 0

model = TrainableUNet({'start_filts': num_start_filts, 'num_workers': num_workers})

trainer = Trainer(gpus=gpus, default_save_path=output_dir, early_stop_callback=early_stop_callback, max_nb_epochs=nb_epochs)

# start training
trainer.fit(model)
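The EarlyStopping callback above monitors the validation Dice ('dice' with mode='max') and stops once it has not improved by more than min_delta for patience consecutive epochs. A hypothetical pure-Python sketch of this rule (not Lightning’s implementation; the score sequence is made up):

```python
def stopping_epoch(scores, min_delta=0.001, patience=10):
    """Hypothetical sketch of 'max'-mode early stopping.

    Returns the index of the epoch after which training would stop, or None.
    """
    best = float('-inf')
    wait = 0
    for epoch, score in enumerate(scores):
        if score > best + min_delta:  # only a gain larger than min_delta counts
            best = score
            wait = 0
        else:
            wait += 1
            if wait >= patience:
                return epoch
    return None


# made-up Dice values: improvement for 3 epochs, then a plateau
scores = [0.5, 0.7, 0.8] + [0.8005] * 10
print(stopping_epoch(scores))           # 12
print(stopping_epoch([0.1, 0.2, 0.3]))  # None
```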

In the end, you should see a Dice coefficient of about 0.88 after 25 epochs.

© Copyright 2019-2020, Justus Schock, Michael Baumgartner. Revision 2a580e9c.
