
Source code for rising.transforms.spatial

# from __future__ import annotations
import torch

from itertools import combinations
from typing import Union, Sequence, Callable, Optional
from torch.multiprocessing import Value

from rising.random import AbstractParameter, DiscreteParameter
from rising.transforms.abstract import AbstractTransform, BaseTransform
from rising.transforms.functional.spatial import *

# Public API exported by ``from rising.transforms.spatial import *``.
__all__ = [
    "Mirror",
    "Rot90",
    "ResizeNative",
    "Zoom",
    "ProgressiveResize",
    "SizeStepScheduler",
]

# Signature of a size scheduler as used by ProgressiveResize: it receives the
# current step and returns the spatial output size (a single int or a
# sequence of ints).
scheduler_type = Callable[[int], Union[int, Sequence[int]]]

class Mirror(BaseTransform):
    """Random mirror transform"""

    def __init__(self,
                 dims: Union[int, DiscreteParameter,
                             Sequence[Union[int, DiscreteParameter]]],
                 keys: Sequence[str] = ('data',),
                 grad: bool = False, **kwargs):
        """
        Args:
            dims: axes which should be mirrored. ``dims`` is registered as a
                sampled property (``property_names=('dims',)``), so a
                :class:`DiscreteParameter` may be passed to choose the axes
                randomly (see example below)
            keys: keys which should be mirrored
            grad: enable gradient computation inside transformation
            **kwargs: keyword arguments passed to superclass

        Examples:
            >>> # Use mirror transform for augmentations
            >>> from rising.random import DiscreteCombinationsParameter
            >>> # We sample from all possible mirror combination for
            >>> # volumetric data
            >>> trafo = Mirror(DiscreteCombinationsParameter((0, 1, 2)))
        """
        super().__init__(augment_fn=mirror, dims=dims, keys=keys, grad=grad,
                         property_names=('dims',), **kwargs)
class Rot90(AbstractTransform):
    """Rotate 90 degree around dims"""

    def __init__(self, dims: Union[Sequence[int], DiscreteParameter],
                 keys: Sequence[str] = ('data',),
                 num_rots: Sequence[int] = (0, 1, 2, 3),
                 prob: float = 0.5, grad: bool = False, **kwargs):
        """
        Args:
            dims: dims/axis to rotate. If more than two dims are provided,
                2 dimensions are randomly chosen at each call
            keys: keys which should be rotated
            num_rots: possible values for number of rotations
            prob: probability for rotation
            grad: enable gradient computation inside transformation
            kwargs: keyword arguments passed to superclass

        See Also:
            :func:`torch.Tensor.rot90`
        """
        super().__init__(grad=grad, **kwargs)
        self.keys = keys
        self.prob = prob
        if not isinstance(dims, DiscreteParameter):
            # Turn a plain sequence into a sampler over candidate dim pairs:
            # every 2-combination when more than two dims are given,
            # otherwise the given dims form the single candidate.
            if len(dims) > 2:
                dims = list(combinations(dims, 2))
            else:
                dims = (dims,)
            dims = DiscreteParameter(dims)
        # The registered samplers are read back as ``self.dims`` and
        # ``self.num_rots`` in ``forward``.
        self.register_sampler("dims", dims)
        self.register_sampler("num_rots", DiscreteParameter(num_rots))

    def forward(self, **data) -> dict:
        """
        Apply transformation

        With probability ``prob`` every key in ``keys`` is rotated. The
        number of rotations and the rotation dims are read once per call,
        so all keys receive the same rotation.

        Args:
            data: dict with tensors

        Returns:
            dict: dict with augmented data
        """
        if torch.rand(1) < self.prob:
            num_rots = self.num_rots
            rand_dims = self.dims
            for key in self.keys:
                data[key] = rot90(data[key], k=num_rots, dims=rand_dims)
        return data
class ResizeNative(BaseTransform):
    """Resize data to given size"""

    def __init__(self, size: Union[int, Sequence[int]], mode: str = 'nearest',
                 align_corners: Optional[bool] = None,
                 preserve_range: bool = False,
                 keys: Sequence = ('data',), grad: bool = False, **kwargs):
        """
        Args:
            size: spatial output size (excluding batch size and number of
                channels)
            mode: one of ``nearest``, ``linear``, ``bilinear``, ``bicubic``,
                ``trilinear``, ``area`` (for more information see
                :func:`torch.nn.functional.interpolate`)
            align_corners: input and output tensors are aligned by the center
                points of their corners pixels, preserving the values at the
                corner pixels.
            preserve_range: output tensor has same range as input tensor
            keys: keys which should be augmented
            grad: enable gradient computation inside transformation
            **kwargs: keyword arguments passed to augment_fn
        """
        super().__init__(augment_fn=resize_native, size=size, mode=mode,
                         align_corners=align_corners,
                         preserve_range=preserve_range, keys=keys, grad=grad,
                         **kwargs)
class Zoom(BaseTransform):
    """Resize data by a scale factor.

    ``scale_factor`` is forwarded to :func:`resize_native` and registered as
    a sampled property (``property_names=('scale_factor',)``), so an
    :class:`AbstractParameter` may be passed to draw the factor randomly.
    """

    def __init__(self,
                 scale_factor: Union[Sequence, AbstractParameter] = (0.75, 1.25),
                 mode: str = 'nearest',
                 align_corners: Optional[bool] = None,
                 preserve_range: bool = False,
                 keys: Sequence = ('data',), grad: bool = False, **kwargs):
        """
        Args:
            scale_factor: scale factor(s) passed to :func:`resize_native`.
                An :class:`AbstractParameter` is presumably re-sampled on
                each call (TODO confirm against BaseTransform).
                NOTE(review): earlier documentation described a plain
                ``Sequence`` (such as the default ``(0.75, 1.25)``) as a
                uniform sampling range, with ``Sequence[Sequence]`` giving
                one range per axis. Confirm this against BaseTransform's
                sampler handling before relying on it.
            mode: one of ``nearest``, ``linear``, ``bilinear``, ``bicubic``,
                ``trilinear``, ``area`` (for more information see
                :func:`torch.nn.functional.interpolate`)
            align_corners: input and output tensors are aligned by the center
                points of their corners pixels, preserving the values at the
                corner pixels.
            preserve_range: output tensor has same range as input tensor
            keys: keys which should be augmented
            grad: enable gradient computation inside transformation
            **kwargs: keyword arguments passed to augment_fn

        See Also:
            :func:`torch.nn.functional.interpolate`
        """
        super().__init__(augment_fn=resize_native, scale_factor=scale_factor,
                         mode=mode, align_corners=align_corners,
                         preserve_range=preserve_range, keys=keys, grad=grad,
                         property_names=('scale_factor',), **kwargs)
class ProgressiveResize(ResizeNative):
    """Resize data to sizes specified by scheduler"""

    def __init__(self, scheduler: scheduler_type, mode: str = 'nearest',
                 align_corners: Optional[bool] = None,
                 preserve_range: bool = False,
                 keys: Sequence = ('data',), grad: bool = False, **kwargs):
        """
        Args:
            scheduler: scheduler which determines the current size.
                The scheduler is called with the current iteration of the
                transform
            mode: one of ``nearest``, ``linear``, ``bilinear``, ``bicubic``,
                ``trilinear``, ``area`` (for more information see
                :func:`torch.nn.functional.interpolate`)
            align_corners: input and output tensors are aligned by the center
                points of their corners pixels, preserving the values at the
                corner pixels.
            preserve_range: output tensor has same range as input tensor
            keys: keys which should be augmented
            grad: enable gradient computation inside transformation
            **kwargs: keyword arguments passed to augment_fn

        Warnings:
            When this transformation is used in combination with
            multiprocessing, the step counter is not perfectly synchronized
            between multiple processes.
            As a result the step count may jump between values
            in a range of the number of processes used.
        """
        # ``size`` is only a placeholder: ``forward`` overwrites it with the
        # scheduler's output before delegating to ResizeNative.
        super().__init__(size=0, mode=mode, align_corners=align_corners,
                         preserve_range=preserve_range, keys=keys, grad=grad,
                         **kwargs)
        self.scheduler = scheduler
        # Step counter wrapped in a lock-protected multiprocessing Value
        # (accessed below via ``get_lock()``).
        self._step = Value('i', 0)

    def reset_step(self) -> ResizeNative:
        """
        Reset step to 0

        Returns:
            ResizeNative: returns self to allow chaining
        """
        with self._step.get_lock():
            self._step.value = 0
        return self

    def increment(self) -> ResizeNative:
        """
        Increment step by 1

        Returns:
            ResizeNative: returns self to allow chaining
        """
        with self._step.get_lock():
            self._step.value += 1
        return self

    @property
    def step(self) -> int:
        """
        Current step

        Returns:
            int: number of steps
        """
        return self._step.value

    def forward(self, **data) -> dict:
        """
        Resize data to the size the scheduler returns for the current step,
        then advance the step counter.

        Args:
            **data: input batch

        Returns:
            dict: augmented batch
        """
        # Read the step, query the scheduler and advance the counter while
        # holding the counter's lock. Otherwise two callers sharing the
        # counter could both observe the same step and skip the next one.
        # ``increment`` re-acquires the same lock; this is safe because a
        # multiprocessing Value's default lock is recursive. If the scheduler
        # raises, the step is left unchanged (as before).
        with self._step.get_lock():
            size = self.scheduler(self.step)
            self.increment()
        self.kwargs["size"] = size
        return super().forward(**data)
class SizeStepScheduler:
    """Scheduler return size when milestone is reached"""

    def __init__(self, milestones: Sequence[int],
                 sizes: Union[Sequence[int], Sequence[Sequence[int]]]):
        """
        Args:
            milestones: contains number of iterations where size should be
                changed
            sizes: sizes corresponding to milestones. The first entry is the
                initial size, used from step 0 until the first milestone is
                reached

        Raises:
            TypeError: if ``sizes`` does not have exactly one element more
                than ``milestones``
        """
        if len(milestones) != len(sizes) - 1:
            raise TypeError("Sizes must include initial size and thus "
                            "have one element more than milestones.")
        # (milestone, size) pairs with the implicit initial milestone 0,
        # largest milestone first, so ``__call__`` can return the first pair
        # whose milestone has been reached.
        self.targets = sorted(zip((0, *milestones), sizes),
                              key=lambda x: x[0], reverse=True)

    def __call__(self, step) -> Union[int, Sequence[int]]:
        """
        Return size with regard to milestones

        Args:
            step: current step

        Returns:
            Union[int, Sequence[int]]: size belonging to the largest
                milestone that is less than or equal to ``step``
        """
        for milestone, size in self.targets:
            if step >= milestone:
                return size
        # ``step`` is below every milestone (e.g. a negative step): fall back
        # to the size of the smallest milestone.
        return self.targets[-1][1]

© Copyright (c) 2019-2020, Justus Schock, Michael Baumgartner. Revision 871c839e.

Read the Docs v: v0.2.0
On Read the Docs
Project Home

Free document hosting provided by Read the Docs.