Source code for rising.transforms.compose
from random import shuffle
from typing import Sequence, Union, Callable, Any, Mapping, Optional
import torch
from rising.utils import check_scalar
from rising.transforms import AbstractTransform
from rising.random import ContinuousParameter, UniformParameter
# Public names exported by ``from rising.transforms.compose import *``;
# helpers such as ``dict_call`` and ``_TransformWrapper`` are not listed.
__all__ = [
    "Compose",
    "DropoutCompose",
    "OneOf",
]
def dict_call(batch: dict, transform: Callable) -> Any:
    """
    Invoke ``transform`` with the entries of ``batch`` as keyword arguments.

    This is the default ``transform_call`` of the compose transforms: every
    key of the batch becomes a keyword argument of the transform.

    Args:
        batch: current batch (mapping of argument names to values) which is
            passed to the transform
        transform: transform to perform

    Returns:
        Any: whatever the transform returns (the transformed batch)
    """
    result = transform(**batch)
    return result
class _TransformWrapper(torch.nn.Module):
    """
    Adapter that turns an arbitrary callable into a ``torch.nn.Module``.

    ``torch.nn.ModuleList`` only accepts modules, so plain callables are
    wrapped with this class before being stored in one. Keeping every
    transform in a ``ModuleList`` lets module specific calls such as
    ``.to()`` be forwarded to all transforms that are modules.
    """

    def __init__(self, trafo: Callable):
        """
        Args:
            trafo: the callable to wrap. As it is not a ``torch.nn.Module``,
                module specific calls on the wrapper leave its internal
                state untouched.
        """
        super().__init__()
        self.trafo = trafo

    def forward(self, *args, **kwargs) -> Any:
        """
        Delegate the call to the wrapped callable.

        Args:
            *args: positional arguments passed through unchanged
            **kwargs: keyword arguments passed through unchanged

        Returns:
            Any: the return value of the wrapped callable
        """
        wrapped = self.trafo
        return wrapped(*args, **kwargs)
class Compose(AbstractTransform):
    """
    Compose multiple transforms and apply them one after another.
    """

    def __init__(self,
                 *transforms: Union[AbstractTransform, Sequence[AbstractTransform]],
                 shuffle: bool = False,
                 transform_call: Callable[[Any, Callable], Any] = dict_call):
        """
        Args:
            transforms: one or multiple transformations which are applied
                in consecutive order. A single sequence of transforms may
                also be passed as the only positional argument.
            shuffle: apply transforms in random order
            transform_call: function which determines how transforms are
                called. The default, :func:`dict_call`, unpacks the data as
                keyword arguments and therefore expects Mapping-like data;
                sequence-like data needs a custom ``transform_call``.

        Raises:
            ValueError: if no transformation is given
        """
        super().__init__(grad=True)
        # accept ``Compose([t1, t2])`` as well as ``Compose(t1, t2)``
        if len(transforms) > 0 and isinstance(transforms[0], Sequence):
            transforms = transforms[0]
        if not transforms:
            raise ValueError("At least one transformation needs to be selected.")
        # ``transforms`` must be assigned before ``shuffle``: the shuffle
        # setter reads ``self.transforms`` to rebuild ``transform_order``.
        self.transforms = transforms
        self.transform_call = transform_call
        self.shuffle = shuffle

    def forward(self, *seq_like, **map_like) -> Union[Sequence, Mapping]:
        """
        Apply transforms in a consecutive order. Can either handle
        Sequence like or Mapping like data.

        Args:
            *seq_like: data which is unpacked like a Sequence
            **map_like: data which is unpacked like a dict

        Returns:
            Union[Sequence, Mapping]: transformed data
        """
        assert not (seq_like and map_like)
        # ``transform_order`` is only rebuilt by the ``transforms`` and
        # ``shuffle`` setters; this catches in-place edits of the ModuleList.
        assert len(self.transforms) == len(self.transform_order)
        data = seq_like if seq_like else map_like

        if self.shuffle:
            # ``shuffle`` here is ``random.shuffle``; it permutes the stored
            # order in place, so each call uses a new random order.
            shuffle(self.transform_order)
        for idx in self.transform_order:
            data = self.transform_call(data, self.transforms[idx])
        return data

    @property
    def transforms(self) -> torch.nn.ModuleList:
        """
        Transforms getter

        Returns:
            torch.nn.ModuleList: transforms to compose
        """
        return self._transforms

    @transforms.setter
    def transforms(self, transforms: Union[AbstractTransform,
                                           Sequence[AbstractTransform]]):
        """
        Transforms setter

        Non-module callables are wrapped into :class:`_TransformWrapper` so
        that all transforms can be stored in a ``torch.nn.ModuleList``.
        Resets the application order to the given order.

        Args:
            transforms: one or multiple transformations which are applied in
                consecutive order
        """
        # Build a new list rather than assigning into ``transforms``: the
        # container may belong to the caller (who must not see its entries
        # replaced by wrappers) or may be immutable.
        wrapped = [trafo if isinstance(trafo, torch.nn.Module) else _TransformWrapper(trafo)
                   for trafo in transforms]
        self._transforms = torch.nn.ModuleList(wrapped)
        self.transform_order = list(range(len(self.transforms)))

    @property
    def shuffle(self) -> bool:
        """
        Getter for attribute shuffle

        Returns:
            bool: True if shuffle is enabled, False otherwise
        """
        return self._shuffle

    @shuffle.setter
    def shuffle(self, shuffle: bool):
        """
        Setter for shuffle; also resets the application order to the
        original transform order.

        Args:
            shuffle: new status of shuffle
        """
        self._shuffle = shuffle
        self.transform_order = list(range(len(self.transforms)))
class DropoutCompose(Compose):
    """
    Compose multiple transforms to one and randomly skip each of them
    """

    def __init__(self,
                 *transforms: Union[AbstractTransform, Sequence[AbstractTransform]],
                 dropout: Union[float, Sequence[float]] = 0.5,
                 shuffle: bool = False,
                 random_sampler: Optional[ContinuousParameter] = None,
                 transform_call: Callable[[Any, Callable], Any] = dict_call,
                 **kwargs):
        """
        Args:
            *transforms: one or multiple transformations which are applied in
                consecutive order
            dropout: if provided as float, each transform is skipped with the
                given probability (with the default uniform sampler).
                If :attr:`dropout` is a sequence, it needs to specify the
                dropout probability for each given transform
            shuffle: apply transforms in random order
            random_sampler: a continuous parameter sampler. Samples a
                random value for each of the transforms; a transform is
                applied when its sample exceeds its dropout value.
                Defaults to ``UniformParameter(0., 1.)``.
            transform_call: function which determines how transforms are
                called. The default, :func:`dict_call`, unpacks the data as
                keyword arguments and therefore expects Mapping-like data.
            **kwargs: forwarded to :class:`Compose`

        Raises:
            TypeError: if dropout is a sequence whose length differs from
                the number of transforms
        """
        super().__init__(*transforms, transform_call=transform_call,
                         shuffle=shuffle, **kwargs)
        if random_sampler is None:
            random_sampler = UniformParameter(0., 1.)
        # one random value per transform, read back as ``self.prob``
        self.register_sampler('prob', random_sampler, size=(len(self.transforms),))

        if check_scalar(dropout):
            dropout = [dropout] * len(self.transforms)
        if len(dropout) != len(self.transforms):
            raise TypeError(f"If dropout is a sequence it must specify the "
                            f"dropout probability for each transform, "
                            f"found {len(dropout)} probabilities "
                            f"and {len(self.transforms)} transforms.")
        self.dropout = dropout

    def forward(self, *seq_like, **map_like) -> Union[Sequence, Mapping]:
        """
        Apply transforms in a consecutive (or, if ``shuffle`` is enabled,
        random) order, skipping each one according to its dropout value.
        Can either handle Sequence like or Mapping like data.

        Args:
            *seq_like: data which is unpacked like a Sequence
            **map_like: data which is unpacked like a dict

        Returns:
            Union[Sequence, Mapping]: dict with transformed data
        """
        assert not (seq_like and map_like)
        assert len(self.transforms) == len(self.transform_order)
        data = seq_like if seq_like else map_like

        # Honour the ``shuffle`` flag exactly like ``Compose.forward``;
        # ``shuffle`` here is ``random.shuffle`` and permutes in place.
        if self.shuffle:
            shuffle(self.transform_order)

        # Sample once per call. Indexing by the transform's original position
        # keeps ``rand`` and ``self.dropout`` aligned even after shuffling.
        rand = self.prob
        for idx in self.transform_order:
            if rand[idx] > self.dropout[idx]:
                data = self.transform_call(data, self.transforms[idx])
        return data
class OneOf(AbstractTransform):
    """
    Apply one of the given transforms, chosen at random.
    """

    def __init__(self, *transforms: Union[AbstractTransform, Sequence[AbstractTransform]],
                 weights: Optional[Sequence[float]] = None,
                 p: float = 1.,
                 transform_call: Callable[[Any, Callable], Any] = dict_call):
        """
        Args:
            *transforms: transforms to choose from. A single sequence of
                transforms may also be passed as the only positional argument.
            weights: additional weights for transforms, one per transform.
                If not given, every transform gets the same weight.
            p: probability that one transform is applied
            transform_call: function which determines how transforms are
                called. The default, :func:`dict_call`, unpacks the data as
                keyword arguments and therefore expects Mapping-like data.

        Raises:
            ValueError: if no transformation is given or if the number of
                weights does not match the number of transforms
        """
        super().__init__(grad=True)
        # accept ``OneOf([t1, t2])`` as well as ``OneOf(t1, t2)``
        if len(transforms) > 0 and isinstance(transforms[0], Sequence):
            transforms = transforms[0]
        if not transforms:
            raise ValueError("At least one transformation needs to be selected.")
        if weights is not None and len(weights) != len(transforms):
            raise ValueError("If weights are provided, every transform needs a weight. "
                             f"Found {len(weights)} weights and {len(transforms)} transforms")

        # Store the transforms in a ModuleList (wrapping plain callables),
        # like Compose does, so module calls such as ``.to()`` reach them.
        self.transforms = torch.nn.ModuleList(
            [trafo if isinstance(trafo, torch.nn.Module) else _TransformWrapper(trafo)
             for trafo in transforms])

        if weights is None:
            weights = [1 / len(self.transforms)] * len(self.transforms)
        weight_tensor = torch.tensor(weights)
        # torch.multinomial only accepts floating point input; integer weights
        # such as ``[1, 2]`` would otherwise fail on the first forward call.
        if not torch.is_floating_point(weight_tensor):
            weight_tensor = weight_tensor.float()
        self.weights = weight_tensor
        self.p = p
        self.transform_call = transform_call

    def forward(self, **data) -> dict:
        """
        With probability ``p``, apply one transform drawn according to
        ``weights``; otherwise return the data unchanged.

        Args:
            **data: data which is unpacked like a dict

        Returns:
            dict: the transformed data, or the input data if no transform
                was applied
        """
        if torch.rand(1) < self.p:
            index = torch.multinomial(self.weights, 1)
            data = self.transform_call(data, self.transforms[int(index)])
        return data