Transformations
[ ]:
# Install napari (interactive viewer used by view_batch) and SimpleITK
# (reads the .nii.gz volumes loaded below).
# NOTE(review): torch and rising are imported in later cells but are not
# installed here — presumably they are preinstalled in the kernel; confirm.
!pip install napari
!pip install SimpleITK
[ ]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline
%gui qt
import os

# TEST_ENV is True only when the environment variable TEST_ENV is set to
# "true" (case-insensitive). It is always a bool, so it can be printed and
# compared directly; the viewer cell uses it to skip opening napari.
TEST_ENV = os.environ.get("TEST_ENV", "").lower() == "true"
print(f"Running test environment: {TEST_ENV}")
[ ]:
from io import BytesIO
from zipfile import ZipFile
from urllib.request import urlopen
# Download the FMRIB example archive and extract the T1 brain image and its
# segmentation into ./ExBox3/ (relative to the working directory).
# The response and the in-memory archive are closed once extraction is done.
DATA_URL = "http://www.fmrib.ox.ac.uk/primers/intro_primer/ExBox3/ExBox3.zip"
with urlopen(DATA_URL) as resp, ZipFile(BytesIO(resp.read())) as archive:
    img_file = archive.extract("ExBox3/T1_brain.nii.gz")
    mask_file = archive.extract("ExBox3/T1_brain_seg.nii.gz")
[ ]:
import SimpleITK as sitk
import numpy as np
# Load the T1 image and its segmentation mask as float32 NumPy arrays.
# The paths are spelled out here so this cell does not depend on variables
# set by the download cell.
img_file = "./ExBox3/T1_brain.nii.gz"
mask_file = "./ExBox3/T1_brain_seg.nii.gz"
img = sitk.GetArrayFromImage(sitk.ReadImage(img_file)).astype(np.float32)
mask = sitk.GetArrayFromImage(sitk.ReadImage(mask_file)).astype(np.float32)
assert mask.shape == img.shape, (
    f"Mask shape {mask.shape} does not match image shape {img.shape}"
)
print(f"Image shape {img.shape}")
print(f"Mask shape {mask.shape}")
[ ]:
if TEST_ENV:
def view_batch(batch):
pass
else:
%gui qt
import napari
def view_batch(batch):
viewer = napari.view_image(batch["data"].cpu().numpy(), name="data")
viewer.add_image(batch["mask"].cpu().numpy(), name="mask", opacity=0.2)
[ ]:
import torch
from rising.transforms import *
# Wrap the NumPy volumes as tensors and prepend two singleton axes
# (presumably the batch and channel dims rising transforms expect — confirm).
# The image stays floating point; the mask is stored as integer (long) labels.
batch = dict(
    data=torch.from_numpy(img).float().unsqueeze(0).unsqueeze(0),
    mask=torch.from_numpy(mask).long().unsqueeze(0).unsqueeze(0),
)
def apply_transform(trafo, batch):
    """Run ``trafo`` on ``batch`` and print summary statistics of the result.

    Args:
        trafo: callable that receives the entries of ``batch`` as keyword
            arguments (e.g. a rising transform).
        batch: mapping unpacked into keyword arguments for ``trafo``.

    Returns:
        Whatever ``trafo`` returns; it must contain "data" and "mask" entries,
        whose shapes (and the data's min/max/mean) are printed.
    """
    result = trafo(**batch)
    data = result["data"]
    print(f"Transformed data shape: {data.shape}")
    print(f"Transformed mask shape: {result['mask'].shape}")
    for stat_name in ("min", "max", "mean"):
        print(f"Transformed data {stat_name}: {getattr(data, stat_name)()}")
    return result
[ ]:
# Statistics of the untransformed batch, as a reference for the transformed
# outputs printed by the cells below.
print(f"Original data shape: {batch['data'].shape}")
print(f"Original mask shape: {batch['mask'].shape}")
print(f"Original data min: {batch['data'].min()}")
print(f"Original data max: {batch['data'].max()}")
print(f"Original data mean: {batch['data'].mean()}")
[ ]:
# Scale by a factor of 1.5; adjust_size=False keeps the output shape equal
# to the input shape.
trafo = Scale(1.5, adjust_size=False)
transformed = apply_transform(trafo=trafo, batch=batch)
view_batch(batch=transformed)
[ ]:
# Rotate by 45 degrees around a single axis (only the last of the three
# angles is non-zero); degree=True means the angles are given in degrees.
trafo = Rotate([0, 0, 45], degree=True, adjust_size=False)
transformed = apply_transform(trafo=trafo, batch=batch)
view_batch(batch=transformed)
[ ]:
# Shift by 0.1 along the first spatial axis only.
# NOTE(review): confirm which unit rising's Translate applies by default
# (pixel vs. relative to image size); a 0.1-pixel shift would be barely visible.
trafo = Translate([0.1, 0, 0], adjust_size=False)
transformed = apply_transform(trafo=trafo, batch=batch)
view_batch(batch=transformed)
[ ]: