
Source code for rising.transforms.utility

from typing import Sequence, Mapping, Hashable, Union

import torch
from rising.transforms.abstract import AbstractTransform
from rising.transforms.functional.utility import seg_to_box, box_to_seg, instance_to_semantic

# Public API of this module: the names exported by a star-import.
__all__ = ["DoNothing", "SegToBox", "BoxToSeg", "InstanceToSemantic"]

class DoNothing(AbstractTransform):
    """Identity transform: hands every input back exactly as received."""

    def __init__(self, grad: bool = False, **kwargs):
        """
        Args:
            grad: enable gradient computation inside transformation
            **kwargs: additional keyword arguments forwarded to the base class
        """
        super().__init__(grad=grad, **kwargs)

    def forward(self, **data) -> dict:
        """
        Pass the input through unchanged.

        Args:
            **data: input data

        Returns:
            dict: the very same keyword-argument dict that was received
        """
        return data
class SegToBox(AbstractTransform):
    """Convert instance segmentations into per-sample bounding boxes."""

    def __init__(self, keys: Mapping[Hashable, Hashable], grad: bool = False, **kwargs):
        """
        Args:
            keys: maps the key of each instance segmentation to the key
                under which the resulting bounding boxes are stored
            grad: enable gradient computation inside transformation
            **kwargs: additional keyword arguments forwarded to the base class
        """
        super().__init__(grad=grad, **kwargs)
        self.keys = keys

    def forward(self, **data) -> dict:
        """
        Args:
            **data: input data

        Returns:
            dict: transformed data; every target key holds a list with one
            ``seg_to_box`` result per size-1 chunk (along dim 0) of the source
        """
        for seg_key, box_key in self.keys.items():
            boxes = []
            for sample in data[seg_key].split(1):
                # second argument is ``sample.ndim - 2``; presumably the number
                # of spatial dims (excluding batch and channel) -- TODO confirm
                boxes.append(seg_to_box(sample, sample.ndim - 2))
            data[box_key] = boxes
        return data
class BoxToSeg(AbstractTransform):
    """Convert per-sample bounding boxes into an instance segmentation tensor."""

    def __init__(self, keys: Mapping[Hashable, Hashable], shape: Sequence[int],
                 dtype: torch.dtype, device: Union[torch.device, str],
                 grad: bool = False, **kwargs):
        """
        Args:
            keys: maps the key holding the bounding boxes to the key under
                which the resulting segmentation is stored
            shape: spatial shape of the output tensor (the batch size is
                derived from the number of box entries and one channel is added)
            dtype: dtype of segmentation
            device: device of segmentation
            grad: enable gradient computation inside transformation
            **kwargs: additional keyword arguments forwarded to the base class
        """
        super().__init__(grad=grad, **kwargs)
        self.keys = keys
        self.seg_shape = shape
        self.seg_dtype = dtype
        self.seg_device = device

    def forward(self, **data) -> dict:
        """
        Forward input

        Args:
            **data: input data

        Returns:
            dict: transformed data; every target key holds a zero-initialised
            tensor of shape ``(num_samples, 1, *shape)`` filled by ``box_to_seg``
        """
        for box_key, seg_key in self.keys.items():
            batch_boxes = data[box_key]
            seg = torch.zeros((len(batch_boxes), 1, *self.seg_shape),
                              dtype=self.seg_dtype, device=self.seg_device)
            for idx, sample_boxes in enumerate(batch_boxes):
                # the return value is ignored: this relies on box_to_seg
                # filling the ``out`` view (seg[idx]) in place
                box_to_seg(sample_boxes, out=seg[idx])
            data[seg_key] = seg
        return data
class InstanceToSemantic(AbstractTransform):
    """Convert an instance segmentation to a semantic segmentation"""

    def __init__(self, keys: Mapping[str, str], cls_key: Hashable, grad: bool = False, **kwargs):
        """
        Args:
            keys: the key specifies which item to use as instance segmentation
                and the item specifies where to save the semantic segmentation
            cls_key: key where the class mapping is saved. Mapping needs to be
                a Sequence[Sequence[int]] with one entry per sample.
            grad: enable gradient computation inside transformation
            **kwargs: additional keyword arguments forwarded to the base class
        """
        super().__init__(grad=grad, **kwargs)
        self.cls_key = cls_key
        self.keys = keys

    def forward(self, **data) -> dict:
        """
        Forward input

        Args:
            **data: input data

        Returns:
            dict: transformed data; every target key holds the per-sample
            semantic segmentations concatenated back along dim 0
        """
        for source, target in self.keys.items():
            # Pair each size-1 chunk (along dim 0) with its class mapping.
            # Names differ from ``data`` so the input dict is not shadowed.
            # NOTE(review): zip silently truncates if the number of mappings
            # differs from the batch size.
            semantic = [
                instance_to_semantic(instance, mapping)
                for instance, mapping in zip(data[source].split(1), data[self.cls_key])
            ]
            # Chunks keep a leading dim of size 1, so concatenating along dim 0
            # restores the batch (assumes instance_to_semantic preserves shape).
            data[target] = torch.cat(semantic)
        return data

© Copyright 2019-2020, Justus Schock, Michael Baumgartner. Revision 871c839e.

Read the Docs v: v0.2.0
On Read the Docs
Project Home

Free document hosting provided by Read the Docs.