Shortcuts

Source code for rising.transforms.compose

from random import shuffle
from typing import Sequence, Union, Callable, Any, Mapping, Optional

import torch

from rising.utils import check_scalar
from rising.transforms import AbstractTransform
from rising.random import ContinuousParameter, UniformParameter


__all__ = ["Compose", "DropoutCompose", "OneOf"]


def dict_call(batch: Union[dict, Sequence], transform: Callable) -> Any:
    """
    Calls ``transform`` with the unpacked ``batch``.

    Mappings are unpacked as keyword arguments (``transform(**batch)``).
    Non-string sequences, such as the tuple of positional arguments collected
    by ``Compose.forward(*seq_like)``, are unpacked as positional arguments
    (``transform(*batch)``). Anything else falls back to keyword unpacking,
    exactly like before.

    Args:
        batch: current batch which is passed to transform
        transform: transform to perform

    Returns:
        Any: transformed batch

    Raises:
        TypeError: if ``batch`` can be unpacked neither as a sequence nor
            as keyword arguments
    """
    if isinstance(batch, Mapping):
        return transform(**batch)
    # strings/bytes are Sequences too, but unpacking them char-wise is never
    # intended here; let them hit the keyword path (which raises TypeError)
    if isinstance(batch, Sequence) and not isinstance(batch, (str, bytes)):
        return transform(*batch)
    # preserve the original behaviour (and error) for all other inputs
    return transform(**batch)
class _TransformWrapper(torch.nn.Module): """ Helper Class to wrap all non-module transforms into modules to use the torch.nn.ModuleList as container for the transforms. This enables forwarding of all model specific calls as ``.to()`` to all transforms """ def __init__(self, trafo: Callable): """ Args: trafo: the actual transform, which will be wrapped by this class. Since this transform is no subclass of ``torch.nn.Module``, its internal state won't be affected by module specific calls """ super().__init__() self.trafo = trafo def forward(self, *args, **kwargs) -> Any: """ Forwards calls to this wrapper to the internal transform Args: *args: positional arguments **kwargs: keyword arguments Returns: Any: trafo return """ return self.trafo(*args, **kwargs)
class Compose(AbstractTransform):
    """
    Compose multiple transforms
    """

    def __init__(self, *transforms: Union[AbstractTransform, Sequence[AbstractTransform]],
                 shuffle: bool = False,
                 transform_call: Callable[[Any, Callable], Any] = dict_call):
        """
        Args:
            transforms: one or multiple transformations which are applied in
                consecutive order. A single sequence of transforms may also be
                passed as the only positional argument.
            shuffle: apply transforms in random order
            transform_call: function which determines how transforms are
                called; it receives ``(data, transform)`` and returns the
                transformed data.

        Raises:
            ValueError: if no transformation was given
        """
        super().__init__(grad=True)
        # support ``Compose([t1, t2])`` as well as ``Compose(t1, t2)``
        if len(transforms) > 0 and isinstance(transforms[0], Sequence):
            transforms = transforms[0]
        if not transforms:
            raise ValueError("At least one transformation needs to be selected.")
        self.transforms = transforms
        self.transform_call = transform_call
        self.shuffle = shuffle

    def forward(self, *seq_like, **map_like) -> Union[Sequence, Mapping]:
        """
        Apply transforms in a consecutive order. Can either handle
        Sequence like or Mapping like data.

        Args:
            *seq_like: data which is unpacked like a Sequence
            **map_like: data which is unpacked like a dict

        Returns:
            Union[Sequence, Mapping]: transformed data
        """
        assert not (seq_like and map_like), \
            "Data must be passed either positionally or by keyword, not both."
        assert len(self.transforms) == len(self.transform_order)
        data = seq_like if seq_like else map_like

        # shuffles the stored order in place, so each call starts from the
        # previous call's permutation
        if self.shuffle:
            shuffle(self.transform_order)

        for idx in self.transform_order:
            data = self.transform_call(data, self.transforms[idx])
        return data

    @property
    def transforms(self) -> torch.nn.ModuleList:
        """
        Transforms getter

        Returns:
            torch.nn.ModuleList: transforms to compose
        """
        return self._transforms

    @transforms.setter
    def transforms(self, transforms: Union[AbstractTransform, Sequence[AbstractTransform]]):
        """
        Transforms setter

        Non-module callables are wrapped into ``_TransformWrapper`` so that
        all transforms can be held by a ``torch.nn.ModuleList``. The order
        is reset to the identity order.

        Args:
            transforms: one or multiple transformations which are applied in
                consecutive order
        """
        # a lone callable (non-module; modules bypass this setter through
        # torch.nn.Module.__setattr__) is treated as a one-element sequence
        if callable(transforms):
            transforms = [transforms]
        # build a fresh list: the caller's list must not be mutated by the
        # wrapping below, and tuples cannot be mutated at all
        wrapped = [
            trafo if isinstance(trafo, torch.nn.Module) else _TransformWrapper(trafo)
            for trafo in transforms
        ]
        self._transforms = torch.nn.ModuleList(wrapped)
        self.transform_order = list(range(len(self.transforms)))

    @property
    def shuffle(self) -> bool:
        """
        Getter for attribute shuffle

        Returns:
            bool: True if shuffle is enabled, False otherwise
        """
        return self._shuffle

    @shuffle.setter
    def shuffle(self, shuffle: bool):
        """
        Setter for shuffle; also resets the transform order to identity.

        Args:
            shuffle: new status of shuffle
        """
        self._shuffle = shuffle
        self.transform_order = list(range(len(self.transforms)))
class DropoutCompose(Compose):
    """
    Compose multiple transforms to one and randomly apply them
    """

    def __init__(self, *transforms: Union[AbstractTransform, Sequence[AbstractTransform]],
                 dropout: Union[float, Sequence[float]] = 0.5,
                 shuffle: bool = False,
                 random_sampler: Optional[ContinuousParameter] = None,
                 transform_call: Callable[[Any, Callable], Any] = dict_call,
                 **kwargs):
        """
        Args:
            *transforms: one or multiple transformations which are applied in
                consecutive order
            dropout: if provided as float, each transform is skipped with the
                given probability; if :attr:`dropout` is a sequence, it needs
                to specify the dropout probability for each given transform
            shuffle: apply transforms in random order
            random_sampler: a continuous parameter sampler. Samples a random
                value for each of the transforms. Defaults to
                ``UniformParameter(0., 1.)``.
            transform_call: function which determines how transforms are
                called; it receives ``(data, transform)``.

        Raises:
            TypeError: if dropout is a sequence whose length differs from the
                number of transforms
        """
        super().__init__(*transforms, transform_call=transform_call,
                         shuffle=shuffle, **kwargs)

        if random_sampler is None:
            random_sampler = UniformParameter(0., 1.)
        # one random value per transform, compared against its dropout prob
        self.register_sampler('prob', random_sampler, size=(len(self.transforms),))

        if check_scalar(dropout):
            dropout = [dropout] * len(self.transforms)
        self.dropout = dropout
        if len(dropout) != len(self.transforms):
            raise TypeError(f"If dropout is a sequence it must specify the "
                            f"dropout probability for each transform, "
                            f"found {len(dropout)} probabilities "
                            f"and {len(self.transforms)} transforms.")

    def forward(self, *seq_like, **map_like) -> Union[Sequence, Mapping]:
        """
        Apply transforms in a consecutive (or, if ``shuffle`` is enabled,
        random) order, skipping each one according to its dropout
        probability. Can either handle Sequence like or Mapping like data.

        Args:
            *seq_like: data which is unpacked like a Sequence
            **map_like: data which is unpacked like a dict

        Returns:
            Union[Sequence, Mapping]: transformed data
        """
        assert not (seq_like and map_like), \
            "Data must be passed either positionally or by keyword, not both."
        assert len(self.transforms) == len(self.transform_order)
        data = seq_like if seq_like else map_like

        # this method overrides Compose.forward, so shuffling has to be done
        # here as well; otherwise ``shuffle=True`` would have no effect
        if self.shuffle:
            shuffle(self.transform_order)

        # values from the 'prob' sampler registered in __init__; presumably
        # resampled on each access (see AbstractTransform.register_sampler)
        rand = self.prob
        for idx in self.transform_order:
            # rand/dropout are indexed by transform, not by position, so each
            # transform keeps its own probability regardless of shuffling
            if rand[idx] > self.dropout[idx]:
                data = self.transform_call(data, self.transforms[idx])
        return data
class OneOf(AbstractTransform):
    """
    Apply one of the given transforms.
    """

    def __init__(self, *transforms: Union[AbstractTransform, Sequence[AbstractTransform]],
                 weights: Optional[Sequence[float]] = None, p: float = 1.,
                 transform_call: Callable[[Any, Callable], Any] = dict_call):
        """
        Args:
            *transforms: transforms to choose from
            weights: additional (non-negative, not necessarily normalized)
                weights for transforms; uniform if not given
            p: probability that one transform is applied
            transform_call: function which determines how transforms are
                called; it receives ``(data, transform)``.

        Raises:
            ValueError: if no transformation was given or the number of
                weights does not match the number of transforms
        """
        super().__init__(grad=True)
        # support ``OneOf([t1, t2])`` as well as ``OneOf(t1, t2)``
        if len(transforms) > 0 and isinstance(transforms[0], Sequence):
            transforms = transforms[0]
        if not transforms:
            raise ValueError("At least one transformation needs to be selected.")
        self.transforms = transforms

        if weights is not None and len(weights) != len(transforms):
            raise ValueError("If weights are provided, every transform needs a weight. "
                             f"Found {len(weights)} weights and {len(transforms)} transforms")
        if weights is None:
            num_trafos = len(self.transforms)
            self.weights = torch.tensor([1 / num_trafos] * num_trafos)
        else:
            # torch.multinomial only supports floating point weights, so
            # integer weights such as [1, 2] must be cast explicitly
            self.weights = torch.tensor(weights, dtype=torch.float)
        self.p = p
        self.transform_call = transform_call

    def forward(self, **data) -> dict:
        """
        With probability ``p``, apply exactly one transform drawn according
        to ``weights``; otherwise return the data unchanged.

        Args:
            **data: data which is unpacked like a dict

        Returns:
            dict: (possibly) transformed data
        """
        if torch.rand(1) < self.p:
            index = torch.multinomial(self.weights, 1)
            data = self.transform_call(data, self.transforms[int(index)])
        return data

© Copyright (c) 2019-2020, Justus Schock, Michael Baumgartner. Revision 871c839e.

Read the Docs v: v0.2.0
Versions
latest
stable
v0.2.0
Downloads
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.