Source code for rising.transforms.crop
import torch
from typing import Sequence, Union
from rising.transforms.abstract import BaseTransform, BaseTransformSeeded
from rising.random import AbstractParameter
from rising.transforms.functional.crop import random_crop, center_crop
__all__ = ["CenterCrop", "RandomCrop"]
[docs]class CenterCrop(BaseTransform):
def __init__(self, size: Union[int, Sequence, AbstractParameter],
keys: Sequence = ('data',),
grad: bool = False, **kwargs):
"""
Args:
size: size of crop
keys: keys which should be augmented
grad: enable gradient computation inside transformation
**kwargs: keyword arguments passed to augment_fn
"""
super().__init__(augment_fn=center_crop, keys=keys,
grad=grad, property_names=('size', ), size=size,
**kwargs)
[docs]class RandomCrop(BaseTransformSeeded):
def __init__(self, size: Union[int, Sequence, AbstractParameter],
dist: Union[int, Sequence, AbstractParameter] = 0,
keys: Sequence = ('data',), grad: bool = False, **kwargs):
"""
Args:
size: size of crop
dist: minimum distance to border. By default zero
keys: keys which should be augmented
grad: enable gradient computation inside transformation
**kwargs: keyword arguments passed to augment_fn
"""
super().__init__(augment_fn=random_crop,
keys=keys, size=size, dist=dist, grad=grad,
property_names=('size', 'dist'), **kwargs)