
Source code for rising.transforms.abstract

from typing import Any, Callable, Sequence, Tuple, Union

import torch

from rising.random import AbstractParameter, DiscreteParameter

# Names exported by ``from <this module> import *``.
__all__ = ["AbstractTransform", "BaseTransform", "PerSampleTransform", "PerChannelTransform", "BaseTransformSeeded"]

# Type alias: a callable which receives a single tensor; its return type is unconstrained.
augment_callable = Callable[[torch.Tensor], Any]
# Type alias: a callable which receives a tensor and either a float or a sequence.
augment_axis_callable = Callable[[torch.Tensor, Union[float, Sequence]], Any]

class AbstractTransform(torch.nn.Module):
    """Base class for all transforms

    Extra keyword arguments passed to the constructor are stored as
    attributes. Samplers added through :meth:`register_sampler` are exposed
    as attributes which return a freshly sampled value on every access.
    """

    def __init__(self, grad: bool = False, **kwargs):
        """
        Args:
            grad: enable gradient computation inside transformation
            **kwargs: additional values; each one is set as an attribute
                named after its keyword
        """
        super().__init__()
        self.grad = grad
        self._registered_samplers = []
        for key, item in kwargs.items():
            setattr(self, key, item)

    def register_sampler(self, name: str, sampler: Union[Sequence, "AbstractParameter"], *args, **kwargs):
        """
        Registers a parameter sampler to the transform.
        Internally a property is created to forward calls to the attribute to
        calls of the sampler.

        Args:
            name : the property name
            sampler : the sampler. A tuple or list is treated as a collection
                of samplers. Every element which is not already an
                :class:`AbstractParameter` is wrapped into a sampler which
                always returns that element
            *args : additional positional arguments (will be forwarded to
                sampler call)
            **kwargs : additional keyword arguments (will be forwarded to
                sampler call)

        Raises:
            NameError: if an attribute called :attr:`name` already exists
        """
        # Validate before touching any state: a rejected registration must
        # not leave ``name`` behind in the list of registered samplers.
        if hasattr(self, name):
            raise NameError("Name %s already exists" % name)

        if not isinstance(sampler, (tuple, list)):
            sampler = [sampler]

        samplers = [
            _sampler if isinstance(_sampler, AbstractParameter) else DiscreteParameter([_sampler], replacement=True)
            for _sampler in sampler
        ]

        def sample(self):
            """
            Sample random values
            """
            sample_result = tuple(_sampler(*args, **kwargs) for _sampler in samplers)
            # a single sampler yields its bare value instead of a 1-tuple
            if len(sample_result) == 1:
                return sample_result[0]
            return sample_result

        # The property object lives on the instance; ``__getattribute__``
        # below is what turns attribute access into a call of ``sample``.
        setattr(self, name, property(sample))
        self._registered_samplers.append(name)

    def __getattribute__(self, item) -> Any:
        """
        Automatically dereference registered samplers

        Args:
            item: name of attribute

        Returns:
            Any: attribute
        """
        res = super().__getattribute__(item)
        # by first checking the type we reduce the lookup
        # time for all non property objects
        if isinstance(res, property) and item in self._registered_samplers:
            return res.__get__(self)
        return res

    def __call__(self, *args, **kwargs) -> Any:
        """
        Call super class with correct torch context

        Args:
            *args: forwarded positional arguments
            **kwargs: forwarded keyword arguments

        Returns:
            Any: transformed data
        """
        context = torch.enable_grad() if self.grad else torch.no_grad()
        with context:
            return super().__call__(*args, **kwargs)

    def forward(self, **data) -> dict:
        """
        Implement transform functionality here

        Args:
            **data: dict with data

        Returns:
            dict: dict with transformed data

        Raises:
            NotImplementedError: always; subclasses provide the implementation
        """
        raise NotImplementedError
class BaseTransform(AbstractTransform):
    """
    Transform to apply a functional interface to given keys

    .. warning:: This transform should not be used
        with functions which have randomness built in because it will result
        in different augmentations per key.
    """

    def __init__(
        self,
        augment_fn: augment_callable,
        *args,
        keys: Sequence = ("data",),
        grad: bool = False,
        property_names: Sequence[str] = (),
        **kwargs
    ):
        """
        Args:
            augment_fn: function for augmentation
            *args: positional arguments passed to augment_fn
            keys: keys which should be augmented
            grad: enable gradient computation inside transformation
            property_names: a tuple containing all the properties to call
                during forward pass
            **kwargs: keyword arguments passed to augment_fn
        """
        # Take the sampler definitions out of kwargs first, so that they are
        # neither set as plain attributes nor stored as static keyword
        # arguments for augment_fn.
        sampler_definitions = []
        for prop in property_names:
            sampler_definitions.append(kwargs.pop(prop))

        super().__init__(grad=grad, **kwargs)

        self.augment_fn = augment_fn
        self.keys = keys
        self.property_names = property_names
        self.args = args
        self.kwargs = kwargs

        for prop, definition in zip(property_names, sampler_definitions):
            self.register_sampler(prop, definition)

    def forward(self, **data) -> dict:
        """
        Apply transformation

        Args:
            data: dict with tensors

        Returns:
            dict: dict with augmented data
        """
        # every registered property is read exactly once per call, so all
        # keys are augmented with the same sampled values
        call_kwargs = {prop: getattr(self, prop) for prop in self.property_names}
        call_kwargs.update(self.kwargs)

        for key in self.keys:
            data[key] = self.augment_fn(data[key], *self.args, **call_kwargs)
        return data
class BaseTransformSeeded(BaseTransform):
    """
    Transform to apply a functional interface to given keys and use the same
    pytorch(!) seed for every key.
    """

    def forward(self, **data) -> dict:
        """
        Apply transformation and use same seed for every key

        Args:
            data: dict with tensors

        Returns:
            dict: dict with augmented data
        """
        call_kwargs = {prop: getattr(self, prop) for prop in self.property_names}
        call_kwargs.update(self.kwargs)

        # Snapshot the torch RNG state once and restore it before each key,
        # so augment_fn starts from the same random state for every key.
        rng_state = torch.random.get_rng_state()
        for key in self.keys:
            torch.random.set_rng_state(rng_state)
            data[key] = self.augment_fn(data[key], *self.args, **call_kwargs)
        return data
class PerSampleTransform(BaseTransform):
    """
    Apply transformation to each sample in batch individually
    :attr:`augment_fn` must be callable with option :attr:`out`
    where results are saved in.

    .. warning:: This transform should not be used
        with functions which have randomness built in because it will result
        in different augmentations per sample and key.
    """

    def forward(self, **data) -> dict:
        """
        Apply the augmentation sample by sample along the first dimension

        Args:
            data: dict with tensors

        Returns:
            dict: dict with augmented data
        """
        # properties are read once per call, so all samples and keys share
        # the same sampled values
        kwargs = {name: getattr(self, name) for name in self.property_names}
        kwargs.update(self.kwargs)

        for _key in self.keys:
            batch = data[_key]
            out = torch.empty_like(batch)
            for _i in range(batch.shape[0]):
                # forward the positional arguments given to the constructor,
                # exactly like BaseTransform.forward does; previously they
                # were silently dropped here
                out[_i] = self.augment_fn(batch[_i], *self.args, out=out[_i], **kwargs)
            data[_key] = out
        return data
class PerChannelTransform(BaseTransform):
    """
    Apply transformation per channel (but still to whole batch)

    .. warning:: This transform should not be used
        with functions which have randomness built in because it will result
        in different augmentations per channel and key.
    """

    def __init__(
        self,
        augment_fn: augment_callable,
        per_channel: bool = False,
        keys: Sequence = ("data",),
        grad: bool = False,
        property_names: Tuple[str, ...] = (),
        **kwargs
    ):
        """
        Args:
            augment_fn: function for augmentation
            per_channel: enable transformation per channel
            keys: keys which should be augmented
            grad: enable gradient computation inside transformation
            property_names: a tuple containing all the properties to call
                during forward pass
            kwargs: keyword arguments passed to augment_fn
        """
        super().__init__(augment_fn=augment_fn, keys=keys, grad=grad, property_names=property_names, **kwargs)
        self.per_channel = per_channel

    def forward(self, **data) -> dict:
        """
        Apply transformation

        Args:
            data: dict with tensors

        Returns:
            dict: dict with augmented data
        """
        if not self.per_channel:
            # fall back to the whole-tensor behaviour of the parent class
            return super().forward(**data)

        # properties are read once per call, so all channels and keys share
        # the same sampled values
        kwargs = {name: getattr(self, name) for name in self.property_names}
        kwargs.update(self.kwargs)

        for _key in self.keys:
            out = torch.empty_like(data[_key])
            # channels are taken from the second dimension
            for _i in range(data[_key].shape[1]):
                out[:, _i] = self.augment_fn(data[_key][:, _i], out=out[:, _i], **kwargs)
            data[_key] = out
        return data

© Copyright 2019-2020, Justus Schock, Michael Baumgartner. Revision b9cd7e8f.

Read the Docs v: stable
On Read the Docs
Project Home

Free document hosting provided by Read the Docs.