Segmentation with rising and PyTorch Lightning¶
This example will show you how to build a proper training pipeline with PyTorch Lightning and rising. But first let’s configure this notebook correctly:
[ ]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline
and install all our dependencies:
[ ]:
!pip install --upgrade --quiet pytorch-lightning # for training
!pip install --upgrade --quiet git+https://github.com/PhoenixDL/rising # for data handling
!pip install --upgrade --quiet SimpleITK # for loading medical data
!pip install --upgrade --quiet tensorboard # for monitoring training
!pip install --upgrade --quiet gdown # to download data cross platform
Data¶
Once this is done, we need to take care of our training data. To show rising's full capabilities, we will be using 3D data from the Medical Segmentation Decathlon (specifically Task 4: Hippocampus).
Download¶
We will use the data provided on Google Drive and download it:
[ ]:
import os
import SimpleITK as sitk
import json
import tempfile
import numpy as np
import tarfile
import time
import gdown
temp_dir = tempfile.mkdtemp()
# generate dummy data for ci/cd
if 'CI' in os.environ:
data_dir = os.path.join(temp_dir, 'DummyData')
os.makedirs(os.path.join(data_dir, 'training'), exist_ok=True)
data_paths = []
for idx in range(50):
img = np.random.randint(-500, 500, (32, 64, 32), np.int16)
mask = np.random.randint(0, 2, (32, 64, 32), np.int16)  # high bound is exclusive, so this yields 0/1 labels
img_path = os.path.join(data_dir, 'training', 'img_%03d.nii.gz' % idx)
mask_path = os.path.join(data_dir, 'training', 'mask_%03d.nii.gz' % idx)
sitk.WriteImage(sitk.GetImageFromArray(img), img_path)
sitk.WriteImage(sitk.GetImageFromArray(mask), mask_path)
data_paths.append({'image': img_path, 'label': mask_path})
with open(os.path.join(data_dir, 'dataset.json'), 'w') as f:
json.dump({'training': data_paths}, f, sort_keys=True, indent=4)
else:
data_url = "https://drive.google.com/uc?export=download&id=1RzPB1_bqzQhlWvU-YGvZzhx2omcDh38C"
data_dir = os.path.join(temp_dir, 'Task04_Hippocampus')
download_path = os.path.join(temp_dir, 'data.tar')
gdown.download(data_url, download_path)
tarfile.open(download_path).extractall(temp_dir)
Great! We got our data. Now we can work on loading it. For loading data, rising follows the same principle as PyTorch: it separates the dataset, which provides the logic for loading a single sample, from the dataloader, which automatically handles parallel loading and batching.
In fact, we at rising thought there was no need to reinvent the wheel. This is why we internally use PyTorch’s data structures and just extend them a bit. We’ll come to these extensions later.
Dataset¶
Our dataset is fairly simple. It just loads the NIfTI data we downloaded before and returns each sample as a dict:
[ ]:
import SimpleITK as sitk
import json
from rising import loading
from rising.loading import Dataset
import torch
class NiiDataset(Dataset):
def __init__(self, train: bool, data_dir: str):
"""
Args:
train: whether to use the training or the validation split
data_dir: directory containing the data
"""
with open(os.path.join(data_dir, 'dataset.json')) as f:
content = json.load(f)['training']
num_train_samples = int(len(content) * 0.9)
# Split train data into training and validation,
# since test data contains no ground truth
if train:
data = content[:num_train_samples]
else:
data = content[num_train_samples:]
self.data = data
self.data_dir = data_dir
def __getitem__(self, item: int) -> dict:
"""
Loads and returns a single sample
Args:
item: index specifying which item to load
Returns:
dict: the loaded sample
"""
sample = self.data[item]
img = sitk.GetArrayFromImage(
sitk.ReadImage(os.path.join(self.data_dir, sample['image'])))
# add channel dim if necessary
if img.ndim == 3:
img = img[None]
label = sitk.GetArrayFromImage(
sitk.ReadImage(os.path.join(self.data_dir, sample['label'])))
# convert multiclass to binary task by combining all positives
label = label > 0
# add channel dim if necessary
if label.ndim == 3:
label = label[None]
return {'data': torch.from_numpy(img).float(),
'label': torch.from_numpy(label).float()}
def __len__(self) -> int:
"""
Specifies the dataset's length
Returns:
int: dataset's length
"""
return len(self.data)
For compatibility, each rising dataset must provide the same interface as a PyTorch dataset, which basically comes down to being indexable. This means any sequence-like data (e.g. lists, tuples, tensors or arrays) can also be used directly as a dataset. Ideally, a dataset also has a length, since the dataloader uses this length to calculate/estimate its own length.
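Since rising's DataLoader extends PyTorch's, this indexability requirement can be sketched with plain PyTorch (a minimal illustration; the behavior with rising's loader is the same):

```python
import torch
from torch.utils.data import DataLoader  # rising.loading.DataLoader extends this class

# Any indexable object with a length can act as a dataset,
# e.g. a plain list of sample dicts:
samples = [{'data': torch.full((1, 4), float(i)), 'label': torch.tensor([i % 2])}
           for i in range(6)]

loader = DataLoader(samples, batch_size=2)
batch = next(iter(loader))
print(batch['data'].shape)   # samples are collated along a new batch dim
```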
Integration With PyTorch Lightning: Model and Training¶
After obtaining our data and implementing a way to load it, we now need a model we can train. For this, we will use a fairly simple implementation of the U-Net, which is basically an encoder-decoder network with skip connections. In Lightning, all modules should be derived from LightningModule, which itself is a subclass of torch.nn.Module. For further details on the LightningModule, please refer to the project itself or its documentation.
Model¶
For now we will only define the network’s logic and omit the training logic, which we’ll add later.
[ ]:
import pytorch_lightning as pl
import torch
class Unet(pl.LightningModule):
"""Simple U-Net without training logic"""
def __init__(self, hparams: dict):
"""
Args:
hparams: the hyperparameters needed to construct the network.
Specifically these are:
* start_filts (int)
* depth (int)
* in_channels (int)
* num_classes (int)
"""
super().__init__()
# build the encoder (downsampling) blocks
out_filts = hparams.get('start_filts', 16)
depth = hparams.get('depth', 3)
in_filts = hparams.get('in_channels', 1)
num_classes = hparams.get('num_classes', 2)
for idx in range(depth):
down_block = torch.nn.Sequential(
torch.nn.Conv3d(in_filts, out_filts, kernel_size=3, padding=1),
torch.nn.ReLU(inplace=True),
torch.nn.Conv3d(out_filts, out_filts, kernel_size=3, padding=1),
torch.nn.ReLU(inplace=True)
)
in_filts = out_filts
out_filts *= 2
setattr(self, 'down_block_%d' % idx, down_block)
out_filts = out_filts // 2
in_filts = in_filts // 2
out_filts, in_filts = in_filts, out_filts
for idx in range(depth-1):
up_block = torch.nn.Sequential(
torch.nn.Conv3d(in_filts + out_filts, out_filts, kernel_size=3, padding=1),
torch.nn.ReLU(inplace=True),
torch.nn.Conv3d(out_filts, out_filts, kernel_size=3, padding=1),
torch.nn.ReLU(inplace=True)
)
in_filts = out_filts
out_filts = out_filts // 2
setattr(self, 'up_block_%d' % idx, up_block)
self.final_conv = torch.nn.Conv3d(in_filts, num_classes, kernel_size=1)
self.max_pool = torch.nn.MaxPool3d(2, stride=2)
self.up_sample = torch.nn.Upsample(scale_factor=2)
self.hparams = hparams
def forward(self, input_tensor: torch.Tensor) -> torch.Tensor:
"""
Forwards the :attr`input_tensor` through the network to obtain a prediction
Args:
input_tensor: the network's input
Returns:
torch.Tensor: the networks output given the :attr`input_tensor`
"""
depth = self.hparams.get('depth', 3)
intermediate_outputs = []
# Compute all the encoder blocks' outputs
for idx in range(depth):
intermed = getattr(self, 'down_block_%d' % idx)(input_tensor)
if idx < depth - 1:
# store intermediate values for usage in decoder
intermediate_outputs.append(intermed)
input_tensor = getattr(self, 'max_pool')(intermed)
else:
input_tensor = intermed
# Compute all the decoder blocks' outputs
for idx in range(depth-1):
input_tensor = getattr(self, 'up_sample')(input_tensor)
# use intermediate values from encoder
from_down = intermediate_outputs.pop(-1)
intermed = torch.cat([input_tensor, from_down], dim=1)
input_tensor = getattr(self, 'up_block_%d' % idx)(intermed)
return getattr(self, 'final_conv')(input_tensor)
Okay, that was easy, right? Now let’s just check if everything in our network is fine:
[ ]:
net = Unet({'num_classes': 2, 'in_channels': 1, 'depth': 2, 'start_filts': 2})
print(net(torch.rand(1, 1, 16, 16, 16)).shape)
So what did we do here? We initialized a network accepting input images with one channel. This network will then predict a segmentation map for 2 classes (of which one is the background class). It does so with two resolution stages, since we set depth=2.
When we tested the network, we forwarded a tensor with random values of size (1, 1, 16, 16, 16) through it. The first 1 here is the batch dim, the second 1 the channel dim (as we specified one input channel), and the three 16s are the spatial dimensions (depth, height and width).
The output has the same dimensions, except that the channel dimension now holds 2 channels (one per class).
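To turn this two-channel output into an actual segmentation map, one typically takes the softmax over the channel dimension and then the argmax, as the training logic below will also do. A minimal sketch:

```python
import torch

logits = torch.rand(1, 2, 16, 16, 16)   # (N, C, D, H, W): the network's raw output
probs = torch.softmax(logits, dim=1)    # per-class probabilities along the channel dim
seg_map = torch.argmax(probs, dim=1)    # (N, D, H, W): one class index per voxel
print(seg_map.shape)
```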
Training Criterions and Metrics¶
For training we will use a combination of the CrossEntropyLoss and the SoftDiceLoss (see below).
For more details on this, I’d recommend Jeremy Jordan’s blog post on semantic segmentation.
[ ]:
import rising
from typing import Sequence, Optional, Union
import torch
# Taken from https://github.com/justusschock/dl-utils/blob/master/dlutils/losses/soft_dice.py
class SoftDiceLoss(torch.nn.Module):
"""Soft Dice Loss"""
def __init__(self, square_nom: bool = False,
square_denom: bool = False,
weight: Optional[Union[Sequence, torch.Tensor]] = None,
smooth: float = 1.):
"""
Args:
square_nom: whether to square the numerator
square_denom: whether to square the denominator
weight: additional weighting of individual classes
smooth: smoothing for numerator and denominator
"""
super().__init__()
self.square_nom = square_nom
self.square_denom = square_denom
self.smooth = smooth
if weight is not None:
if not isinstance(weight, torch.Tensor):
weight = torch.tensor(weight)
self.register_buffer("weight", weight)
else:
self.weight = None
def forward(self, predictions: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
"""
Computes SoftDice Loss
Args:
predictions: the predictions obtained by the network
targets: the targets (ground truth) for the :attr:`predictions`
Returns:
torch.Tensor: the computed loss value
"""
# number of classes for onehot
n_classes = predictions.shape[1]
with torch.no_grad():
targets_onehot = rising.transforms.functional.channel.one_hot_batch(
targets.unsqueeze(1), num_classes=n_classes)
# sum over spatial dimensions
dims = tuple(range(2, predictions.dim()))
# compute numerator
if self.square_nom:
nom = torch.sum((predictions * targets_onehot.float()) ** 2, dim=dims)
else:
nom = torch.sum(predictions * targets_onehot.float(), dim=dims)
nom = 2 * nom + self.smooth
# compute denominator
if self.square_denom:
i_sum = torch.sum(predictions ** 2, dim=dims)
t_sum = torch.sum(targets_onehot ** 2, dim=dims)
else:
i_sum = torch.sum(predictions, dim=dims)
t_sum = torch.sum(targets_onehot, dim=dims)
denom = i_sum + t_sum.float() + self.smooth
# compute loss
frac = nom / denom
# apply weight to individual classes
if self.weight is not None:
frac = self.weight * frac
# average over classes
frac = - torch.mean(frac, dim=1)
return frac
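To build some intuition for this loss: for a perfect prediction each per-class fraction becomes 1, so the (negated) loss approaches -1. A stripped-down sketch without the weighting and squaring options, where the inputs are assumed to be one-hot encoded already:

```python
import torch

# Minimal per-class soft dice (no weighting or squaring), mirroring the loss above;
# the inputs are assumed to be one-hot encoded already.
def soft_dice(pred_onehot, target_onehot, smooth=1.):
    dims = tuple(range(2, pred_onehot.dim()))
    nom = 2 * (pred_onehot * target_onehot).sum(dim=dims) + smooth
    denom = pred_onehot.sum(dim=dims) + target_onehot.sum(dim=dims) + smooth
    return -(nom / denom).mean(dim=1)

t = torch.zeros(1, 2, 4, 4)
t[:, 0] = 1.                       # every pixel belongs to class 0
print(float(soft_dice(t, t)))      # perfect prediction -> -1.0
```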
Okay, now that we are able to properly calculate the loss function, we still lack a metric to monitor that describes our performance. For segmentation tasks, this usually comes down to the dice coefficient. So let’s implement this one as well:
[ ]:
# Taken from https://github.com/justusschock/dl-utils/blob/master/dlutils/metrics/dice.py
def binary_dice_coefficient(pred: torch.Tensor, gt: torch.Tensor,
thresh: float = 0.5, smooth: float = 1e-7) -> torch.Tensor:
"""
Computes the dice coefficient for a binary segmentation task
Args:
pred: predicted segmentation (of shape Nx(Dx)HxW)
gt: target segmentation (of shape Nx(Dx)HxW)
thresh: segmentation threshold
smooth: smoothing value to avoid division by zero
Returns:
torch.Tensor: dice score
"""
assert pred.shape == gt.shape
pred_bool = pred > thresh
intersec = (pred_bool * gt).float()
return 2 * intersec.sum() / (pred_bool.float().sum()
+ gt.float().sum() + smooth)
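As a quick sanity check (the computation is restated inline so the snippet is self-contained): with two thresholded positives, of which one is correct, the score is 2·1/(2+1) ≈ 0.667:

```python
import torch

def binary_dice(pred, gt, thresh=0.5, smooth=1e-7):
    # same computation as binary_dice_coefficient above, restated for a self-contained check
    pred_bool = pred > thresh
    intersec = (pred_bool * gt).float()
    return 2 * intersec.sum() / (pred_bool.float().sum() + gt.float().sum() + smooth)

pred = torch.tensor([0.9, 0.8, 0.2, 0.1])   # two thresholded positives
gt   = torch.tensor([1.,  0.,  0.,  0. ])   # of which one is correct
print(float(binary_dice(pred, gt)))         # 2*1 / (2 + 1) ≈ 0.667
```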
Neat! So far we have defined all the loss functions and metrics necessary for proper training and monitoring. But there are still two major parts of our pipeline missing:
1.) Data preprocessing and augmentation
2.) The actual parameter-update logic
Let’s deal with the first point now.
Data Preprocessing¶
Since all samples in our dataset have different sizes, we cannot collate them into a batch directly. Instead we need to resize them. Frameworks like torchvision do this inside the dataset. With rising, however, we opted to move this part outside the dataset (while still applying it to each sample separately before batching) for these reasons:
1.) The dataset becomes more reusable in different settings
2.) The transforms don’t have to be implemented in each dataset, which makes it easier to switch datasets without code duplication
3.) Applying different transforms is as easy as changing an argument of the loader; no need to deal with this manually in the dataset
Transforms of this kind can be passed to the dataloader via sample_transforms. If you have an implementation that usually works on batched data, we’ve got you covered: all you need to do is specify pseudo_batch_dim and we will take care of the rest. We then automatically add a pseudo batch dim to all kinds of data (tensors, arrays and all kinds of built-in Python containers containing a mixture thereof) before applying these transforms and remove it again afterwards.
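Conceptually, this pseudo-batching just adds and removes a leading dim around the transform call, roughly like this (a hypothetical sketch, not rising's actual internals):

```python
import torch

sample = {'data': torch.rand(1, 32, 64, 32)}          # one sample: (C, D, H, W)
pseudo = {k: v[None] for k, v in sample.items()}      # add pseudo batch dim: (1, C, D, H, W)
# ... a batched transform would run on `pseudo` here ...
restored = {k: v[0] for k, v in pseudo.items()}       # remove the pseudo batch dim again
print(pseudo['data'].shape, restored['data'].shape)
```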
For now, we use our batched implementation of native torch resizing:
[ ]:
from rising.transforms import Compose, ResizeNative
def common_per_sample_trafos():
return Compose(ResizeNative(size=(32, 64, 32), keys=('data',), mode='trilinear'),
ResizeNative(size=(32, 64, 32), keys=('label',), mode='nearest'))
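ResizeNative builds on torch's native interpolation; assuming that, a rough per-sample equivalent of the image branch above (on a pseudo-batched sample) would be:

```python
import torch
import torch.nn.functional as F

img = torch.rand(1, 1, 20, 50, 28)   # pseudo-batched sample: (N, C, D, H, W)
resized = F.interpolate(img, size=(32, 64, 32), mode='trilinear', align_corners=False)
print(resized.shape)                 # torch.Size([1, 1, 32, 64, 32])
```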
Data Augmentation¶
Now that we have defined our preprocessing, let’s turn to data augmentation. To enrich our dataset, we randomly apply an affine transformation. While rising already contains an implementation of affine transforms that can also handle random inputs pretty well, we will implement basic random parameter sampling ourselves, since this also serves as an educational example.
Basically this is really straightforward. We just derive from the BaseAffine class and overwrite the way the matrix is assembled, adding the parameter sampling before calling the actual assembly method. We leave the rest to the already defined class:
[ ]:
from rising.transforms.affine import BaseAffine
import random
from typing import Optional, Sequence
class RandomAffine(BaseAffine):
"""Base Affine with random parameters for scale, rotation and translation"""
def __init__(self, scale_range: Optional[tuple] = None,
rotation_range: Optional[tuple] = None,
translation_range: Optional[tuple] = None,
degree: bool = True,
image_transform: bool = True,
keys: Sequence = ('data',),
grad: bool = False,
output_size: Optional[tuple] = None,
adjust_size: bool = False,
interpolation_mode: str = 'nearest',
padding_mode: str = 'zeros',
align_corners: bool = False,
reverse_order: bool = False,
**kwargs,):
"""
Args:
scale_range: tuple containing minimum and maximum values for scale.
Actual values will be sampled from uniform distribution with these
constraints.
rotation_range: tuple containing minimum and maximum values for rotation.
Actual values will be sampled from uniform distribution with these
constraints.
translation_range: tuple containing minimum and maximum values for translation.
Actual values will be sampled from uniform distribution with these
constraints.
keys: keys which should be augmented
grad: enable gradient computation inside transformation
degree: whether the given rotation(s) are in degrees.
Only valid for rotation parameters, which aren't passed
as full transformation matrix.
output_size: if given, this will be the resulting image size.
Defaults to ``None``
adjust_size: if True, the resulting image size will be
calculated dynamically to ensure that the whole image fits.
interpolation_mode: interpolation mode to calculate output values
``'bilinear'`` | ``'nearest'``. Default: ``'nearest'``
padding_mode: padding mode for outside grid values
``'zeros'`` | ``'border'`` | ``'reflection'``.
Default: ``'zeros'``
align_corners: Geometrically, we consider the pixels of the
input as squares rather than points. If set to True,
the extrema (-1 and 1) are considered as referring to the
center points of the input’s corner pixels. If set to False,
they are instead considered as referring to the corner points
of the input’s corner pixels, making the sampling more
resolution agnostic.
reverse_order: reverses the coordinate order of the
transformation to conform to the pytorch convention:
transformation params order [W,H(,D)] and
batch order [(D,)H,W]
**kwargs: additional keyword arguments passed to the
affine transformation
"""
super().__init__(scale=None, rotation=None, translation=None,
degree=degree,
image_transform=image_transform,
keys=keys,
grad=grad,
output_size=output_size,
adjust_size=adjust_size,
interpolation_mode=interpolation_mode,
padding_mode=padding_mode,
align_corners=align_corners,
reverse_order=reverse_order,
**kwargs)
self.scale_range = scale_range
self.rotation_range = rotation_range
self.translation_range = translation_range
def assemble_matrix(self, **data) -> torch.Tensor:
"""
Samples Parameters for scale, rotation and translation
before actual matrix assembly.
Args:
**data: dictionary containing a batch
Returns:
torch.Tensor: assembled affine matrix
"""
ndim = data[self.keys[0]].ndim - 2
if self.scale_range is not None:
self.scale = [random.uniform(*self.scale_range) for _ in range(ndim)]
if self.translation_range is not None:
self.translation = [random.uniform(*self.translation_range) for _ in range(ndim)]
if self.rotation_range is not None:
if ndim == 3:
self.rotation = [random.uniform(*self.rotation_range) for _ in range(ndim)]
elif ndim == 2:  # a single rotation angle suffices in 2D
self.rotation = random.uniform(*self.rotation_range)
return super().assemble_matrix(**data)
Also not that hard… So now we have a custom implementation of a randomly parametrized affine transformation. This is all we will use as data augmentation for now.
Batched transforms that should be executed on the CPU in a multiprocessed way are passed to the dataloader as batch_transforms. If they should be executed on the GPU, you can pass them as gpu_transforms. Unfortunately it is not possible to apply GPU transforms in a multiprocessing environment, so the internal computation order is as follows:
1.) Extract sample from dataset
2.) Apply per-sample transforms to it (with or without pseudo batch dim)
3.) Collate to batch
4.) Apply batch transforms
5.) Apply GPU transforms
Steps 1.–4. can be executed in a multiprocessing environment. If this is the case, the results are synced back to the main process before the GPU transforms are applied.
Training Logic¶
The only remaining step is to integrate this into the training logic of PyTorch Lightning.
The only things we have not yet discussed are how to set up the optimizers, the logging, and the train/validation steps.
The optimizer setup is done in configure_optimizers, which should return the created optimizers.
Logging can either be done automatically (all values under the key log in the dicts returned from validation_epoch_end and training_epoch_end will automatically be logged) or manually (by explicitly calling the logger in any of these functions). We show both styles below.
For setting up the actual training logic we need to implement training_step (and validation_step for validation). The complete example is below:
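As a side note, the automatic logging style can be sketched with a standalone toy function mimicking the epoch-end hook (as provided by the Lightning version used here; the full example below sticks to the manual style via self.logger.experiment):

```python
import torch

# Values returned under the key 'log' are picked up and written to the logger
# automatically by Lightning; here we just call the hook by hand for illustration.
def validation_epoch_end(outputs):
    avg_loss = torch.stack([o['val_loss'] for o in outputs]).mean()
    return {'val_loss': avg_loss, 'log': {'Val/Loss': avg_loss}}

outputs = [{'val_loss': torch.tensor(0.4)}, {'val_loss': torch.tensor(0.2)}]
res = validation_epoch_end(outputs)
print(float(res['log']['Val/Loss']))   # ≈ 0.3
```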
[ ]:
from rising.transforms import NormZeroMeanUnitStd
from rising.loading import DataLoader
import torch
from tqdm import tqdm
class TrainableUNet(Unet):
"""A trainable UNet (extends the base class by training logic)"""
def __init__(self, hparams: Optional[dict] = None):
"""
Args:
hparams: the hyperparameters needed to construct and train the network.
Specifically these are:
* start_filts (int)
* depth (int)
* in_channels (int)
* num_classes (int)
* min_scale (float)
* max_scale (float)
* min_rotation (int, float)
* max_rotation (int, float)
* batch_size (int)
* num_workers (int)
* learning_rate (float)
Usable default values exist for all of them.
"""
if hparams is None:
hparams = {}
super().__init__(hparams)
# define loss functions
self.dice_loss = SoftDiceLoss(weight=[0., 1.])
self.ce_loss = torch.nn.CrossEntropyLoss()
def train_dataloader(self) -> DataLoader:
"""
Specifies the train dataloader
Returns:
DataLoader: the train dataloader
"""
# construct dataset
dataset = NiiDataset(train=True, data_dir=data_dir)
# specify batch transforms
batch_transforms = Compose([
RandomAffine(scale_range=(self.hparams.get('min_scale', 0.9), self.hparams.get('max_scale', 1.1)),
rotation_range=(self.hparams.get('min_rotation', -10), self.hparams.get('max_rotation', 10)),
keys=('data', 'label')),
NormZeroMeanUnitStd(keys=('data',))
])
# construct loader
dataloader = DataLoader(dataset,
batch_size=self.hparams.get('batch_size', 1),
batch_transforms=batch_transforms,
shuffle=True,
sample_transforms=common_per_sample_trafos(),
pseudo_batch_dim=True,
num_workers=self.hparams.get('num_workers', 4))
return dataloader
def val_dataloader(self) -> DataLoader:
# construct dataset
dataset = NiiDataset(train=False, data_dir=data_dir)
# specify batch transforms (no augmentation here)
batch_transforms = NormZeroMeanUnitStd(keys=('data',))
# construct loader
dataloader = DataLoader(dataset,
batch_size=self.hparams.get('batch_size', 1),
batch_transforms=batch_transforms,
shuffle=False,
sample_transforms=common_per_sample_trafos(),
pseudo_batch_dim=True,
num_workers=self.hparams.get('num_workers', 4))
return dataloader
def configure_optimizers(self) -> torch.optim.Optimizer:
"""
Configures the optimizer to use for training
Returns:
torch.optim.Optimizer: the optimizer for updating the model's parameters
"""
return torch.optim.Adam(self.parameters(), lr=self.hparams.get('learning_rate', 1e-3))
def training_step(self, batch: dict, batch_idx: int) -> dict:
"""
Defines the training logic
Args:
batch: contains the data (inputs and ground truth)
batch_idx: the number of the current batch
Returns:
dict: the current loss value
"""
x, y = batch['data'], batch['label']
# remove channel dim from gt (was necessary for augmentation)
y = y[:, 0].long()
# obtain predictions
pred = self(x)
softmaxed_pred = torch.nn.functional.softmax(pred, dim=1)
# Calculate losses
ce_loss = self.ce_loss(pred, y)
dice_loss = self.dice_loss(softmaxed_pred, y)
total_loss = (ce_loss + dice_loss) / 2
# calculate dice coefficient
dice_coeff = binary_dice_coefficient(torch.argmax(softmaxed_pred, dim=1), y)
# log values
self.logger.experiment.add_scalar('Train/DiceCoeff', dice_coeff)
self.logger.experiment.add_scalar('Train/CE', ce_loss)
self.logger.experiment.add_scalar('Train/SoftDiceLoss', dice_loss)
self.logger.experiment.add_scalar('Train/TotalLoss', total_loss)
return {'loss': total_loss}
def validation_step(self, batch: dict, batch_idx: int) -> dict:
"""
Defines the validation logic
Args:
batch: contains the data (inputs and ground truth)
batch_idx: the number of the current batch
Returns:
dict: the current loss and metric values
"""
x, y = batch['data'], batch['label']
# remove channel dim from gt (was necessary for augmentation)
y = y[:, 0].long()
# obtain predictions
pred = self(x)
softmaxed_pred = torch.nn.functional.softmax(pred, dim=1)
# calculate losses
ce_loss = self.ce_loss(pred, y)
dice_loss = self.dice_loss(softmaxed_pred, y)
total_loss = (ce_loss + dice_loss) / 2
# calculate dice coefficient
dice_coeff = binary_dice_coefficient(torch.argmax(softmaxed_pred, dim=1), y)
# log values
self.logger.experiment.add_scalar('Val/DiceCoeff', dice_coeff)
self.logger.experiment.add_scalar('Val/CE', ce_loss)
self.logger.experiment.add_scalar('Val/SoftDiceLoss', dice_loss)
self.logger.experiment.add_scalar('Val/TotalLoss', total_loss)
return {'val_loss': total_loss, 'dice': dice_coeff}
def validation_epoch_end(self, outputs: list) -> dict:
"""Aggregates data from each validation step
Args:
outputs: the returned values from each validation step
Returns:
dict: the aggregated outputs
"""
mean_outputs = {}
for k in outputs[0].keys():
mean_outputs[k] = torch.stack([x[k] for x in outputs]).mean()
tqdm.write('Dice: \t%.3f' % mean_outputs['dice'].item())
return mean_outputs
Most of this is standard PyTorch Lightning code. But the dataloader setup nicely shows how rising integrates with any existing framework that works on PyTorch dataloaders (such as PyTorch Lightning or PyTorch Ignite) for both per-sample and batched transforms.
Training¶
We’ve finally finished the whole pipeline definition. Now let’s load the tensorboard extension to monitor our training. For this we will define a common output dir for lightning:
[ ]:
output_dir = 'logs'
os.makedirs(output_dir, exist_ok=True)
[ ]:
# Start tensorboard.
%reload_ext tensorboard
%tensorboard --logdir {output_dir}
And now it’s finally time to train!
On a GPU in Colab, training takes approximately 40 seconds per epoch, i.e. a total of about 33 minutes (2000 seconds) if early stopping doesn’t kick in. For me it kicks in after 25 epochs, which takes around 16 minutes on a Colab GPU.
[ ]:
from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning import Trainer
early_stop_callback = EarlyStopping(monitor='dice', min_delta=0.001, patience=10, verbose=False, mode='max')
if torch.cuda.is_available():
gpus = 1
else:
gpus = None
nb_epochs = 50
num_start_filts = 16
num_workers = 4
if 'CI' in os.environ:
nb_epochs = 1
num_start_filts = 2
num_workers = 0
model = TrainableUNet({'start_filts': num_start_filts, 'num_workers': num_workers})
trainer = Trainer(gpus=gpus, default_save_path=output_dir, early_stop_callback=early_stop_callback, max_nb_epochs=nb_epochs)
trainer.fit(model)
In the end, you should see a dice coefficient of about 0.88 after 25 epochs.